├── .gitignore ├── LICENSE ├── README.md ├── assets ├── frame0037_pred.png └── frame_0037_frame.png ├── code.v.1.0 ├── README.md ├── alt_cuda_corr │ ├── correlation.cpp │ ├── correlation_kernel.cu │ └── setup.py ├── chairs_split.txt ├── core │ ├── __init__.py │ ├── attention.py │ ├── corr.py │ ├── datasets.py │ ├── deq.py │ ├── deq_demo.py │ ├── extractor.py │ ├── gma.py │ ├── lib │ │ ├── grad.py │ │ ├── jacobian.py │ │ ├── layer_utils.py │ │ ├── optimizations.py │ │ └── solvers.py │ ├── metrics.py │ ├── update.py │ └── utils │ │ ├── __init__.py │ │ ├── augmentor.py │ │ ├── flow_viz.py │ │ ├── frame_utils.py │ │ ├── grid_sample.py │ │ └── utils.py ├── evaluate.py ├── main.py ├── ref │ ├── B_1_step_grad.txt │ ├── H_1_step_grad.txt │ └── val.txt ├── train_B.sh ├── train_B_demo.sh ├── train_H_demo.sh ├── train_H_full.sh ├── val.sh ├── viz.py └── viz.sh └── code.v.2.0 ├── chairs_split.txt ├── core ├── __init__.py ├── corr.py ├── datasets.py ├── deq │ ├── __init__.py │ ├── arg_utils.py │ ├── deq_class.py │ ├── dropout.py │ ├── grad.py │ ├── jacobian.py │ ├── layer_utils.py │ ├── norm │ │ ├── __init__.py │ │ └── weight_norm.py │ └── solvers.py ├── deq_flow.py ├── extractor.py ├── gma.py ├── metrics.py ├── update.py └── utils │ ├── __init__.py │ ├── augmentor.py │ ├── flow_viz.py │ ├── frame_utils.py │ ├── grid_sample.py │ └── utils.py ├── evaluate.py ├── log └── val.txt ├── main.py ├── train_H.sh ├── train_H_1_step_grad.sh ├── val.sh └── viz.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | pip-wheel-metadata/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 94 | __pypackages__/ 95 | 96 | # Celery stuff 97 | celerybeat-schedule 98 | celerybeat.pid 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Equilibrium Optical Flow Estimation 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-equilibrium-optical-flow-estimation/optical-flow-estimation-on-kitti-2015-train)](https://paperswithcode.com/sota/optical-flow-estimation-on-kitti-2015-train?p=deep-equilibrium-optical-flow-estimation) 4 | 5 | (🌟Version 2.0 released!🌟) 6 | 7 | This is the official repo for the paper [*Deep Equilibrium Optical Flow Estimation*](https://arxiv.org/abs/2204.08442) (CVPR 2022), by [Shaojie Bai](https://jerrybai1995.github.io/)\*, [Zhengyang Geng](https://gsunshine.github.io/)\*, [Yash Savani](https://yashsavani.com/) and [J. Zico Kolter](http://zicokolter.com/). 8 | 9 |
10 | 11 | > A deep equilibrium (DEQ) flow estimator directly models the flow as a path-independent, “infinite-level” fixed-point solving process. We propose to use this implicit framework to replace the existing recurrent approach to optical flow estimation. The DEQ flows converge faster, require less memory, are often more accurate, and are compatible with prior model designs (e.g., RAFT and GMA). 12 | 13 | ## Demo 14 | 15 | We provide a demo video of the DEQ flow results below. 16 | 17 | https://user-images.githubusercontent.com/18630903/163676562-e14a433f-4c71-4994-8e3d-97b3c33d98ab.mp4 18 | 19 | --- 20 | 21 | ## Update 22 | 23 | 🌟 2022.xx.xx - Support visualization and demo on your own datasets and videos! Coming soon! 24 | 25 | 🌟 2022.08.08 - Release the **version 2.0** of DEQ-Flow! DEQ-Flow will be merged into [DEQ](https://github.com/locuslab/deq) after further upgrading and unit testing. 26 | 27 | - A clean and decoupled **[DEQ lib](https://github.com/locuslab/deq-flow/blob/main/code.v.2.0/core/deq)**. This is a fully featured and out-of-the-box lib. You're welcome to implement **your own DEQ** using our DEQ lib! We support the following features. (*The DEQ lib will be available on PyPI soon for easy installation via `pip`.*) 28 | - Automatic arg parser decorator. You can call this function to add the DEQ args to your program. See the explanation for args [here](https://github.com/locuslab/deq-flow/blob/main/code.v.2.0/core/deq/arg_utils.py)! 29 | 30 | ```Python 31 | add_deq_args(parser) 32 | ``` 33 | 34 | - Automatic DEQ definition. Call `get_deq` to get your DEQ class! **It's highly decoupled implementation agnostic to your model design!** 35 | 36 | ```Python 37 | DEQ = get_deq(args) 38 | self.deq = DEQ(args) 39 | ``` 40 | 41 | - Automatic normalization for DEQ. You now do not need to add normalization manually to each weight in the DEQ func! 42 | 43 | ```Python 44 | if args.wnorm: 45 | apply_weight_norm(self.update_block) 46 | ``` 47 | 48 | - Easy DEQ forward. Even for a multi-equilibria system, you can call the DEQ function using several lines! 49 | 50 | ```Python 51 | # Assume args is a list [z1, z2, ..., zn] 52 | # of to-be-solved equilibrium variables. 53 | def func(*args): 54 | # A functor defined in the Pytorch forward function. 55 | # Having the same input and output tensor shapes. 56 | return args 57 | 58 | deq_func = DEQWrapper(func, args) 59 | z_init = deq_func.list2vec(*args) # will be merged into self.deq(...) 60 | z_out, info = self.deq(deq_func, z_init) 61 | ``` 62 | 63 | - Automatic DEQ training. Gradients (both exact and inexact grad) are tracked automatically! Fixed point correction can be customized through your arg parser. Just post-process `z_out` as you want! 64 | 65 | - Benchmarked results and [checkpoints](https://drive.google.com/drive/folders/1a_eX_wYN1qTw2Rj1naEXhcsG4D3KKxFw?usp=sharing). Using the release code base v.2.0, we've trained DEQ-Flow-H on FlyingChairs and FlyingThings for two schedules, *120k+120k (1x)* and *120k+360k (3x)*. This implementation demonstrated a new SOTA, surpassing our previous results in performance, training speed, and memory usage. 66 | 67 | Notably, we also benchmark RAFT using the same model size. DEQ-Flow demonstrates a clear performance and efficiency margin and **much stronger scaling property** (scale up to larger models) over RAFT! 
68 | 69 | | Checkpoint Name | Sintel (clean) | Sintel (final) | KITTI AEPE | KITTI F1-all | 70 | | :--------------: | :------------: | :------------: | :---------: | :----------: | 71 | | RAFT-H-1x | 1.36 | 2.59 | 4.47 | 16.16 | 72 | | DEQ-Flow-H-1x | 1.27 | 2.58 | 3.76 | 12.95 | 73 | | DEQ-Flow-H-3x | 1.27 | 2.48 | 3.77 | 13.41 | 74 | 75 | - 1x=120k iterations on FlyingThings, 3x=360k iterations on FlyingThings, using a batch size of 6. 76 | - Increasing the batch size on FlyingThings can further improve these results, e.g., a batch size of 12 can reduce the F1-all of DEQ-Flow-H-1x to around 12.5 on KITTI. 77 | 78 | To validate our results, download the pretrained [checkpoints](https://drive.google.com/drive/folders/1a_eX_wYN1qTw2Rj1naEXhcsG4D3KKxFw?usp=sharing) into the `checkpoints` directory. Run the following command in [code.v.2.0](https://github.com/locuslab/deq-flow/blob/main/code.v.2.0/) to infer over the Sintel train set and the KITTI train set. This is a reference [log](https://github.com/locuslab/deq-flow/blob/main/code.v.2.0/log/val.txt). 79 | 80 | ```bash 81 | bash val.sh 82 | ``` 83 | 84 | --- 85 | 86 | ## Requirements 87 | 88 | The code in this repo has been tested on PyTorch v1.10.0. Install required environments through the following commands. 89 | 90 | ```bash 91 | conda create --name deq python==3.6.10 92 | conda activate deq 93 | conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 cudatoolkit=11.3 -c pytorch -c conda-forge 94 | conda install tensorboard scipy opencv matplotlib einops termcolor -c conda-forge 95 | ``` 96 | 97 | Download the following datasets into the `datasets` directory. 98 | 99 | - [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) 100 | - [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 101 | - [MPI Sintel](http://sintel.is.tue.mpg.de/) 102 | - [KITTI 2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) 103 | - [HD1k](http://hci-benchmark.iwr.uni-heidelberg.de/) 104 | 105 | --- 106 | The following README doc is for version 1.0, i.e., [code.v.1.0](https://github.com/locuslab/deq-flow/blob/main/code.v.1.0/). You can follow this to reproduce all the results. 107 | 108 | ## Inference 109 | 110 | Download the pretrained [checkpoints](https://drive.google.com/drive/folders/1PeyOr4kmSuMWrh4iwYKbVLqDU6WPX-HM?usp=sharing) into the `checkpoints` directory. Run the following command to infer over the Sintel train set and the KITTI train set. 111 | 112 | ```bash 113 | bash val.sh 114 | ``` 115 | 116 | You may expect the following performance statistics of given checkpoints. This is a reference [log](https://github.com/locuslab/deq-flow/blob/main/code.v.1.0/ref/val.txt). 117 | 118 | | Checkpoint Name | Sintel (clean) | Sintel (final) | KITTI AEPE | KITTI F1-all | 119 | | :--------------: | :------------: | :------------: | :---------: | :----------: | 120 | | DEQ-Flow-B | 1.43 | 2.79 | 5.43 | 16.67 | 121 | | DEQ-Flow-H-1 | 1.45 | 2.58 | 3.97 | 13.41 | 122 | | DEQ-Flow-H-2 | 1.37 | 2.62 | 3.97 | 13.62 | 123 | | DEQ-Flow-H-3 | 1.36 | 2.62 | 4.02 | 13.92 | 124 | 125 | ## Visualization 126 | 127 | Download the pretrained [checkpoints](https://drive.google.com/drive/folders/1PeyOr4kmSuMWrh4iwYKbVLqDU6WPX-HM?usp=sharing) into the `checkpoints` directory. Run the following command to visualize the optical flow estimation over the KITTI test set. 
128 | 129 | ```bash 130 | bash viz.sh 131 | ``` 132 | 133 | ## Training 134 | 135 | Download *FlyingChairs*-pretrained [checkpoints](https://drive.google.com/drive/folders/1PeyOr4kmSuMWrh4iwYKbVLqDU6WPX-HM?usp=sharing) into the `checkpoints` directory. 136 | 137 | For the efficiency mode, you can run 1-step gradient to train DEQ-Flow-B via the following command. Memory overhead per GPU is about 5800 MB. 138 | 139 | You may expect best results of about 1.46 (AEPE) on Sintel (clean), 2.85 (AEPE) on Sintel (final), 5.29 (AEPE) and 16.24 (F1-all) on KITTI. This is a reference [log](https://github.com/locuslab/deq-flow/blob/main/code.v.1.0/ref/B_1_step_grad.txt). 140 | 141 | ```bash 142 | bash train_B_demo.sh 143 | ``` 144 | 145 | For training a demo of DEQ-Flow-H, you can run this command. Memory overhead per GPU is about 6300 MB. It can be further reduced to about **4200 MB** per GPU when combined with `--mixed-precision`. You can further reduce the memory cost if you employ the CUDA implementation of cost volumn by [RAFT](https://github.com/princeton-vl/RAFT). 146 | 147 | You may expect best results of about 1.41 (AEPE) on Sintel (clean), 2.76 (AEPE) on Sintel (final), 4.44 (AEPE) and 14.81 (F1-all) on KITTI. This is a reference [log](https://github.com/locuslab/deq-flow/blob/main/code.v.1.0/ref/H_1_step_grad.txt). 148 | 149 | ```bash 150 | bash train_H_demo.sh 151 | ``` 152 | 153 | To train DEQ-Flow-B on Chairs and Things, use the following command. 154 | 155 | ```bash 156 | bash train_B.sh 157 | ``` 158 | 159 | For the performance mode, you can run this command to train DEQ-Flow-H using the ``C+T`` and ``C+T+S+K+H`` schedule. You may expect the performance of <1.40 (AEPE) on Sintel (clean), around 2.60 (AEPE) on Sintel (final), around 4.00 (AEPE) and 13.6 (F1-all) on KITTI. DEQ-Flow-H-1,2,3 are checkpoints from three runs. 160 | 161 | Currently, this training protocol could entail resources slightly more than two 11 GB GPUs. In the near future, we will upload an implementation revision (of the DEQ models) that shall further reduce this overhead to **less than two 11 GB GPUs**. 162 | 163 | ```bash 164 | bash train_H_full.sh 165 | ``` 166 | 167 | ## Code Usage 168 | 169 | Under construction. We will provide more detailed instructions on the code usage (e.g., argparse flags, fixed-point solvers, backward IFT modes) in the coming days. 170 | 171 | --- 172 | 173 | ## A Tutorial on DEQ 174 | 175 | If you hope to learn more about DEQ models, here is an official NeurIPS [tutorial](https://implicit-layers-tutorial.org/) on implicit deep learning. Enjoy yourself! 176 | 177 | ## Reference 178 | 179 | If you find our work helpful to your research, please consider citing this paper. :) 180 | 181 | ```bib 182 | @inproceedings{deq-flow, 183 | author = {Bai, Shaojie and Geng, Zhengyang and Savani, Yash and Kolter, J. Zico}, 184 | title = {Deep Equilibrium Optical Flow Estimation}, 185 | booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 186 | year = {2022} 187 | } 188 | ``` 189 | 190 | ## Credit 191 | 192 | A lot of the utility code in this repo were adapted from the [RAFT](https://github.com/princeton-vl/RAFT) repo and the [DEQ](https://github.com/locuslab/deq) repo. 193 | 194 | ## Contact 195 | 196 | Feel free to contact us if you have additional questions. Please drop an email through zhengyanggeng@gmail.com (or [Twitter](https://twitter.com/ZhengyangGeng)). 
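For quick reference, here is a minimal end-to-end sketch of how the DEQ lib snippets from the Update section above fit together in one module. The import paths (`deq`, `deq.arg_utils`, `deq.layer_utils`) and the toy convolutional update function are assumptions made for illustration; see `code.v.2.0/core/deq_flow.py` and `code.v.2.0/core/deq/` for the actual usage inside DEQ-Flow.

```Python
# Sketch only: the import paths and the toy update function are assumptions.
# The calls themselves (add_deq_args, get_deq, apply_weight_norm, DEQWrapper)
# follow the snippets shown in the Update section above.
import argparse

import torch
import torch.nn as nn

from deq import get_deq, apply_weight_norm   # assumed exports of core/deq
from deq.layer_utils import DEQWrapper       # assumed location
from deq.arg_utils import add_deq_args

parser = argparse.ArgumentParser()
add_deq_args(parser)                  # register all DEQ-related flags
args = parser.parse_args([])          # defaults only, for this sketch


class TinyDEQ(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.update_block = nn.Conv2d(64, 64, 3, padding=1)
        if args.wnorm:
            apply_weight_norm(self.update_block)   # automatic normalization

        DEQ = get_deq(args)
        self.deq = DEQ(args)

    def forward(self, x):
        z1 = torch.zeros_like(x)      # the to-be-solved equilibrium variable

        def func(z1):
            # Same input and output tensor shapes, as required by the DEQ func.
            return torch.tanh(self.update_block(z1) + x)

        deq_func = DEQWrapper(func, [z1])
        z_init = deq_func.list2vec(z1)
        z_out, info = self.deq(deq_func, z_init)

        # Post-process z_out as you want, e.g., decode each returned state
        # and attach a loss to it for fixed point correction.
        return z_out, info
```

During training, each element of `z_out` can be decoded and supervised (fixed point correction); gradients through the equilibrium are handled by the lib according to the flags registered by `add_deq_args`.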
-------------------------------------------------------------------------------- /assets/frame0037_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/deq-flow/869d4bcd9c1d227cbfa16639994d1b118510a9bb/assets/frame0037_pred.png -------------------------------------------------------------------------------- /assets/frame_0037_frame.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/deq-flow/869d4bcd9c1d227cbfa16639994d1b118510a9bb/assets/frame_0037_frame.png -------------------------------------------------------------------------------- /code.v.1.0/README.md: -------------------------------------------------------------------------------- 1 | # Deep Equilibrium Optical Flow Estimation 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-equilibrium-optical-flow-estimation/optical-flow-estimation-on-kitti-2015-train)](https://paperswithcode.com/sota/optical-flow-estimation-on-kitti-2015-train?p=deep-equilibrium-optical-flow-estimation) 4 | 5 | This is the official repo for the paper [*Deep Equilibrium Optical Flow Estimation*](https://arxiv.org/abs/2204.08442) (CVPR 2022), by [Shaojie Bai](https://jerrybai1995.github.io/)\*, [Zhengyang Geng](https://gsunshine.github.io/)\*, [Yash Savani](https://yashsavani.com/) and [J. Zico Kolter](http://zicokolter.com/). 6 | 7 |
8 | 9 | > A deep equilibrium (DEQ) flow estimator directly models the flow as a path-independent, “infinite-level” fixed-point solving process. We propose to use this implicit framework to replace the existing recurrent approach to optical flow estimation. The DEQ flows converge faster, require less memory, are often more accurate, and are compatible with prior model designs (e.g., RAFT and GMA). 10 | 11 | ## Demo 12 | 13 | We provide a demo video of the DEQ flow results below. 14 | 15 | https://user-images.githubusercontent.com/18630903/163676562-e14a433f-4c71-4994-8e3d-97b3c33d98ab.mp4 16 | 17 | ## Requirements 18 | 19 | The code in this repo has been tested on PyTorch v1.10.0. Install required environments through the following commands. 20 | 21 | ```bash 22 | conda create --name deq python==3.6.10 23 | conda activate deq 24 | conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 cudatoolkit=11.3 -c pytorch -c conda-forge 25 | conda install tensorboard scipy opencv matplotlib einops termcolor -c conda-forge 26 | ``` 27 | 28 | Download the following datasets into the `datasets` directory. 29 | 30 | - [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) 31 | - [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 32 | - [MPI Sintel](http://sintel.is.tue.mpg.de/) 33 | - [KITTI 2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) 34 | - [HD1k](http://hci-benchmark.iwr.uni-heidelberg.de/) 35 | 36 | ## Inference 37 | 38 | Download the pretrained [checkpoints](https://drive.google.com/drive/folders/1PeyOr4kmSuMWrh4iwYKbVLqDU6WPX-HM?usp=sharing) into the `checkpoints` directory. Run the following command to infer over the Sintel train set and the KITTI train set. 39 | 40 | ```bash 41 | bash val.sh 42 | ``` 43 | 44 | You may expect the following performance statistics of given checkpoints. This is a reference [log](https://github.com/locuslab/deq-flow/blob/main/code.v.1.0/ref/val.txt). 45 | 46 | | Checkpoint Name | Sintel (clean) | Sintel (final) | KITTI AEPE | KITTI F1-all | 47 | | :--------------: | :------------: | :------------: | :---------: | :----------: | 48 | | DEQ-Flow-B | 1.43 | 2.79 | 5.43 | 16.67 | 49 | | DEQ-Flow-H-1 | 1.45 | 2.58 | 3.97 | 13.41 | 50 | | DEQ-Flow-H-2 | 1.37 | 2.62 | 3.97 | 13.62 | 51 | | DEQ-Flow-H-3 | 1.36 | 2.62 | 4.02 | 13.92 | 52 | 53 | ## Visualization 54 | 55 | Download the pretrained [checkpoints](https://drive.google.com/drive/folders/1PeyOr4kmSuMWrh4iwYKbVLqDU6WPX-HM?usp=sharing) into the `checkpoints` directory. Run the following command to visualize the optical flow estimation over the KITTI test set. 56 | 57 | ```bash 58 | bash viz.sh 59 | ``` 60 | 61 | ## Training 62 | 63 | Download *FlyingChairs*-pretrained [checkpoints](https://drive.google.com/drive/folders/1PeyOr4kmSuMWrh4iwYKbVLqDU6WPX-HM?usp=sharing) into the `checkpoints` directory. 64 | 65 | For the efficiency mode, you can run 1-step gradient to train DEQ-Flow-B via the following command. Memory overhead per GPU is about 5800 MB. 66 | 67 | You may expect best results of about 1.46 (AEPE) on Sintel (clean), 2.85 (AEPE) on Sintel (final), 5.29 (AEPE) and 16.24 (F1-all) on KITTI. This is a reference [log](https://github.com/locuslab/deq-flow/blob/main/code.v.1.0/ref/B_1_step_grad.txt). 68 | 69 | ```bash 70 | bash train_B_demo.sh 71 | ``` 72 | 73 | For training a demo of DEQ-Flow-H, you can run this command. Memory overhead per GPU is about 6300 MB. 
It can be further reduced to about **4200 MB** per GPU when combined with `--mixed-precision`. You can further reduce the memory cost if you employ the CUDA implementation of cost volumn by [RAFT](https://github.com/princeton-vl/RAFT). 74 | 75 | You may expect best results of about 1.41 (AEPE) on Sintel (clean), 2.76 (AEPE) on Sintel (final), 4.44 (AEPE) and 14.81 (F1-all) on KITTI. This is a reference [log](https://github.com/locuslab/deq-flow/blob/main/code.v.1.0/ref/H_1_step_grad.txt). 76 | 77 | ```bash 78 | bash train_H_demo.sh 79 | ``` 80 | 81 | To train DEQ-Flow-B on Chairs and Things, use the following command. 82 | 83 | ```bash 84 | bash train_B.sh 85 | ``` 86 | 87 | For the performance mode, you can run this command to train DEQ-Flow-H using the ``C+T`` and ``C+T+S+K+H`` schedule. You may expect the performance of <1.40 (AEPE) on Sintel (clean), around 2.60 (AEPE) on Sintel (final), around 4.00 (AEPE) and 13.6 (F1-all) on KITTI. DEQ-Flow-H-1,2,3 are checkpoints from three runs. 88 | 89 | Currently, this training protocol could entail resources slightly more than two 11 GB GPUs. In the near future, we will upload an implementation revision (of the DEQ models) that shall further reduce this overhead to **less than two 11 GB GPUs**. 90 | 91 | ```bash 92 | bash train_H_full.sh 93 | ``` 94 | 95 | ## Code Usage 96 | 97 | Under construction. We will provide more detailed instructions on the code usage (e.g., argparse flags, fixed-point solvers, backward IFT modes) in the coming days. 98 | 99 | ## A Tutorial on DEQ 100 | 101 | If you hope to learn more about DEQ models, here is an official NeurIPS [tutorial](https://implicit-layers-tutorial.org/) on implicit deep learning. Enjoy yourself! 102 | 103 | ## Reference 104 | 105 | If you find our work helpful to your research, please consider citing this paper. :) 106 | 107 | ```bib 108 | @inproceedings{deq-flow, 109 | author = {Bai, Shaojie and Geng, Zhengyang and Savani, Yash and Kolter, J. Zico}, 110 | title = {Deep Equilibrium Optical Flow Estimation}, 111 | booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 112 | year = {2022} 113 | } 114 | ``` 115 | 116 | ## Credit 117 | 118 | A lot of the utility code in this repo were adapted from the [RAFT](https://github.com/princeton-vl/RAFT) repo and the [DEQ](https://github.com/locuslab/deq) repo. 119 | 120 | ## Contact 121 | 122 | Feel free to contact us if you have additional questions. Please drop an email through zhengyanggeng@gmail.com (or [Twitter](https://twitter.com/ZhengyangGeng)). 
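While the detailed code-usage docs are under construction, the sketch below illustrates how the v1.0 fixed-point solver and gradient factory are wired together, mirroring `core/deq_demo.py`. The toy fixed-point map `g`, the tensor shapes, and the placeholder `trainer` module are illustrative assumptions only; the actual call sites live in `core/deq_demo.py`, `core/lib/solvers.py`, and `core/lib/grad.py`.

```Python
# Sketch only (mirrors core/deq_demo.py): the toy map `g`, the shapes, and the
# placeholder trainer module are illustrative assumptions.
import torch
import torch.nn as nn

from lib.solvers import anderson           # fixed-point solver (or broyden)
from lib.grad import backward_factory      # builds the gradient functor

# A toy contractive map g(z) = 0.5 * tanh(z) + b whose fixed point we solve.
b = nn.Parameter(torch.randn(2, 64, 100))
def g(z):
    return 0.5 * torch.tanh(z) + b

z0 = torch.zeros(2, 64, 100)

# Forward: solve z* = g(z*) with no gradient tracking, as in deq_demo.py.
with torch.no_grad():
    result = anderson(g, x0=z0, threshold=40, stop_mode='abs', eps=1e-3)
z_star = result['result']

# Backward: a 1-step phantom-grad functor (grad_type may also be 'ift').
produce_grad = backward_factory(grad_type=1, tau=1.0, sup_all=False)
trainer = nn.Module()                      # placeholder for the owning module
z_out = produce_grad(trainer, z_star, g)   # list of states to supervise

loss = sum(z.pow(2).mean() for z in z_out)
loss.backward()                            # gradients reach b via phantom grad
```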
123 | -------------------------------------------------------------------------------- /code.v.1.0/alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | // CUDA forward declarations 5 | std::vector corr_cuda_forward( 6 | torch::Tensor fmap1, 7 | torch::Tensor fmap2, 8 | torch::Tensor coords, 9 | int radius); 10 | 11 | std::vector corr_cuda_backward( 12 | torch::Tensor fmap1, 13 | torch::Tensor fmap2, 14 | torch::Tensor coords, 15 | torch::Tensor corr_grad, 16 | int radius); 17 | 18 | // C++ interface 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | std::vector corr_forward( 24 | torch::Tensor fmap1, 25 | torch::Tensor fmap2, 26 | torch::Tensor coords, 27 | int radius) { 28 | CHECK_INPUT(fmap1); 29 | CHECK_INPUT(fmap2); 30 | CHECK_INPUT(coords); 31 | 32 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 33 | } 34 | 35 | 36 | std::vector corr_backward( 37 | torch::Tensor fmap1, 38 | torch::Tensor fmap2, 39 | torch::Tensor coords, 40 | torch::Tensor corr_grad, 41 | int radius) { 42 | CHECK_INPUT(fmap1); 43 | CHECK_INPUT(fmap2); 44 | CHECK_INPUT(coords); 45 | CHECK_INPUT(corr_grad); 46 | 47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 48 | } 49 | 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &corr_forward, "CORR forward"); 53 | m.def("backward", &corr_backward, "CORR backward"); 54 | } -------------------------------------------------------------------------------- /code.v.1.0/alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='correlation', 7 | ext_modules=[ 8 | CUDAExtension('alt_cuda_corr', 9 | sources=['correlation.cpp', 'correlation_kernel.cu'], 10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | 16 | -------------------------------------------------------------------------------- /code.v.1.0/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/deq-flow/869d4bcd9c1d227cbfa16639994d1b118510a9bb/code.v.1.0/core/__init__.py -------------------------------------------------------------------------------- /code.v.1.0/core/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange 4 | 5 | from lib.optimizations import weight_norm 6 | 7 | 8 | class Attention(nn.Module): 9 | def __init__( 10 | self, 11 | dim, 12 | heads = 4, 13 | dim_head = 128, 14 | ): 15 | super().__init__() 16 | self.heads = heads 17 | self.scale = dim_head ** -0.5 18 | inner_dim = heads * dim_head 19 | 20 | self.to_q = nn.Conv2d(dim, inner_dim, 1, bias=False) 21 | self.to_k = nn.Conv2d(dim, inner_dim, 1, bias=False) 22 | self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) 23 | 24 | self.gamma = nn.Parameter(torch.zeros(1)) 25 | self.out = nn.Conv2d(inner_dim, dim, 1, bias=False) 26 | 27 | def _wnorm(self): 28 | self.to_q, self.to_q_fn = weight_norm(module=self.to_q, names=['weight'], dim=0) 29 | self.to_k, 
self.to_k_fn = weight_norm(module=self.to_k, names=['weight'], dim=0) 30 | self.to_v, self.to_v_fn = weight_norm(module=self.to_v, names=['weight'], dim=0) 31 | 32 | self.out, self.out_fn = weight_norm(module=self.out, names=['weight'], dim=0) 33 | 34 | def reset(self): 35 | for name in ['to_q', 'to_k', 'to_v', 'out']: 36 | if name + '_fn' in self.__dict__: 37 | eval(f'self.{name}_fn').reset(eval(f'self.{name}')) 38 | 39 | def forward(self, q, k, v): 40 | heads, b, c, h, w = self.heads, *v.shape 41 | 42 | input_q = q 43 | q = self.to_q(q) 44 | k = self.to_k(k) 45 | v = self.to_v(v) 46 | 47 | q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k)) 48 | v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads) 49 | 50 | sim = self.scale * einsum('b h x y d, b h u v d -> b h x y u v', q, k) 51 | sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)') 52 | attn = sim.softmax(dim=-1) 53 | 54 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 55 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) 56 | 57 | out = self.out(out) 58 | out = input_q + self.gamma * out 59 | 60 | return out 61 | 62 | 63 | if __name__ == "__main__": 64 | att = Attention(dim=128, heads=1) 65 | x = torch.randn(2, 128, 40, 90) 66 | out = att(x, x, x) 67 | 68 | print(out.shape) 69 | -------------------------------------------------------------------------------- /code.v.1.0/core/corr.py: -------------------------------------------------------------------------------- 1 | # This code was originally taken from RAFT without modification 2 | # https://github.com/princeton-vl/RAFT/blob/master/core/corr.py 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from utils.utils import bilinear_sampler, coords_grid 7 | 8 | try: 9 | import alt_cuda_corr 10 | except: 11 | # alt_cuda_corr is not compiled 12 | pass 13 | 14 | 15 | class CorrBlock: 16 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 17 | self.num_levels = num_levels 18 | self.radius = radius 19 | self.corr_pyramid = [] 20 | 21 | # all pairs correlation 22 | corr = CorrBlock.corr(fmap1, fmap2) 23 | 24 | batch, h1, w1, dim, h2, w2 = corr.shape 25 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 26 | 27 | self.corr_pyramid.append(corr) 28 | for i in range(self.num_levels-1): 29 | corr = F.avg_pool2d(corr, 2, stride=2) 30 | self.corr_pyramid.append(corr) 31 | 32 | def __call__(self, coords): 33 | r = self.radius 34 | coords = coords.permute(0, 2, 3, 1) 35 | batch, h1, w1, _ = coords.shape 36 | 37 | out_pyramid = [] 38 | for i in range(self.num_levels): 39 | corr = self.corr_pyramid[i] 40 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 41 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 42 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 43 | 44 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 45 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 46 | coords_lvl = centroid_lvl + delta_lvl 47 | 48 | corr = bilinear_sampler(corr, coords_lvl) 49 | corr = corr.view(batch, h1, w1, -1) 50 | out_pyramid.append(corr) 51 | 52 | out = torch.cat(out_pyramid, dim=-1) 53 | return out.permute(0, 3, 1, 2).contiguous().float() 54 | 55 | @staticmethod 56 | def corr(fmap1, fmap2): 57 | batch, dim, ht, wd = fmap1.shape 58 | fmap1 = fmap1.view(batch, dim, ht*wd) 59 | fmap2 = fmap2.view(batch, dim, ht*wd) 60 | 61 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 62 | corr = corr.view(batch, ht, wd, 1, ht, wd) 63 | return corr / torch.sqrt(torch.tensor(dim).float()) 64 | 65 | 66 | class ConstCorrBlock: 67 | def 
__init__(self, fmap1, fmap2, num_levels=4, radius=4): 68 | self.num_levels = num_levels 69 | self.radius = radius 70 | self.corr_pyramid = [] 71 | 72 | # all pairs correlation 73 | corr = CorrBlock.corr(fmap1, fmap2) 74 | 75 | batch, h1, w1, dim, h2, w2 = corr.shape 76 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 77 | 78 | self.corr_pyramid.append(corr.view(batch, h1*w1, dim, h2, w2)) 79 | for i in range(self.num_levels-1): 80 | corr = F.avg_pool2d(corr, 2, stride=2) 81 | self.corr_pyramid.append(corr.view(batch, h1*w1, *corr.shape[1:])) 82 | 83 | def __call__(self, coords, corr_pyramid=None): 84 | r = self.radius 85 | coords = coords.permute(0, 2, 3, 1) 86 | batch, h1, w1, _ = coords.shape 87 | 88 | corr_pyramid = corr_pyramid if corr_pyramid else self.corr_pyramid 89 | 90 | out_pyramid = [] 91 | for i in range(self.num_levels): 92 | corr = corr_pyramid[i].flatten(0, 1) 93 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 94 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 95 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 96 | 97 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 98 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 99 | coords_lvl = centroid_lvl + delta_lvl 100 | 101 | corr = bilinear_sampler(corr, coords_lvl) 102 | corr = corr.view(batch, h1, w1, -1) 103 | out_pyramid.append(corr) 104 | 105 | out = torch.cat(out_pyramid, dim=-1) 106 | return out.permute(0, 3, 1, 2).contiguous().float() 107 | 108 | @staticmethod 109 | def corr(fmap1, fmap2): 110 | batch, dim, ht, wd = fmap1.shape 111 | fmap1 = fmap1.view(batch, dim, ht*wd) 112 | fmap2 = fmap2.view(batch, dim, ht*wd) 113 | 114 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 115 | corr = corr.view(batch, ht, wd, 1, ht, wd) 116 | return corr / torch.sqrt(torch.tensor(dim).float()) 117 | 118 | 119 | class AlternateCorrBlock: 120 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 121 | self.num_levels = num_levels 122 | self.radius = radius 123 | 124 | self.pyramid = [(fmap1, fmap2)] 125 | for i in range(self.num_levels): 126 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 127 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 128 | self.pyramid.append((fmap1, fmap2)) 129 | 130 | def __call__(self, coords): 131 | coords = coords.permute(0, 2, 3, 1) 132 | B, H, W, _ = coords.shape 133 | dim = self.pyramid[0][0].shape[1] 134 | 135 | corr_list = [] 136 | for i in range(self.num_levels): 137 | r = self.radius 138 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 139 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 140 | 141 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 142 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 143 | corr_list.append(corr.squeeze(1)) 144 | 145 | corr = torch.stack(corr_list, dim=1) 146 | corr = corr.reshape(B, -1, H, W) 147 | return corr / torch.sqrt(torch.tensor(dim).float()) 148 | -------------------------------------------------------------------------------- /code.v.1.0/core/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | 8 | import os 9 | import math 10 | import random 11 | from glob import glob 12 | import os.path as osp 13 | 14 | from utils import frame_utils 15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor 16 | 17 | 18 | class FlowDataset(data.Dataset): 
19 | def __init__(self, aug_params=None, sparse=False): 20 | self.augmentor = None 21 | self.sparse = sparse 22 | if aug_params is not None: 23 | if sparse: 24 | self.augmentor = SparseFlowAugmentor(**aug_params) 25 | else: 26 | self.augmentor = FlowAugmentor(**aug_params) 27 | 28 | self.is_test = False 29 | self.init_seed = False 30 | self.flow_list = [] 31 | self.image_list = [] 32 | self.extra_info = [] 33 | 34 | def __getitem__(self, index): 35 | 36 | if self.is_test: 37 | img1 = frame_utils.read_gen(self.image_list[index][0]) 38 | img2 = frame_utils.read_gen(self.image_list[index][1]) 39 | img1 = np.array(img1).astype(np.uint8)[..., :3] 40 | img2 = np.array(img2).astype(np.uint8)[..., :3] 41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 43 | return img1, img2, self.extra_info[index] 44 | 45 | if not self.init_seed: 46 | worker_info = torch.utils.data.get_worker_info() 47 | if worker_info is not None: 48 | torch.manual_seed(worker_info.id) 49 | np.random.seed(worker_info.id) 50 | random.seed(worker_info.id) 51 | self.init_seed = True 52 | 53 | index = index % len(self.image_list) 54 | valid = None 55 | if self.sparse: 56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 57 | else: 58 | flow = frame_utils.read_gen(self.flow_list[index]) 59 | 60 | img1 = frame_utils.read_gen(self.image_list[index][0]) 61 | img2 = frame_utils.read_gen(self.image_list[index][1]) 62 | 63 | flow = np.array(flow).astype(np.float32) 64 | img1 = np.array(img1).astype(np.uint8) 65 | img2 = np.array(img2).astype(np.uint8) 66 | 67 | # grayscale images 68 | if len(img1.shape) == 2: 69 | img1 = np.tile(img1[...,None], (1, 1, 3)) 70 | img2 = np.tile(img2[...,None], (1, 1, 3)) 71 | else: 72 | img1 = img1[..., :3] 73 | img2 = img2[..., :3] 74 | 75 | if self.augmentor is not None: 76 | if self.sparse: 77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 78 | else: 79 | img1, img2, flow = self.augmentor(img1, img2, flow) 80 | 81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 84 | 85 | if valid is not None: 86 | valid = torch.from_numpy(valid) 87 | else: 88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 89 | 90 | return img1, img2, flow, valid.float() 91 | 92 | 93 | def __rmul__(self, v): 94 | self.flow_list = v * self.flow_list 95 | self.image_list = v * self.image_list 96 | return self 97 | 98 | def __len__(self): 99 | return len(self.image_list) 100 | 101 | 102 | class MpiSintel(FlowDataset): 103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): 104 | super(MpiSintel, self).__init__(aug_params) 105 | flow_root = osp.join(root, split, 'flow') 106 | image_root = osp.join(root, split, dstype) 107 | 108 | if split == 'test': 109 | self.is_test = True 110 | 111 | for scene in os.listdir(image_root): 112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 113 | for i in range(len(image_list)-1): 114 | self.image_list += [ [image_list[i], image_list[i+1]] ] 115 | self.extra_info += [ (scene, i) ] # scene and frame_id 116 | 117 | if split != 'test': 118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 119 | 120 | 121 | class FlyingChairs(FlowDataset): 122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): 123 | super(FlyingChairs, self).__init__(aug_params) 124 
| 125 | images = sorted(glob(osp.join(root, '*.ppm'))) 126 | flows = sorted(glob(osp.join(root, '*.flo'))) 127 | assert (len(images)//2 == len(flows)) 128 | 129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) 130 | for i in range(len(flows)): 131 | xid = split_list[i] 132 | if (split=='training' and xid==1) or (split=='validation' and xid==2): 133 | self.flow_list += [ flows[i] ] 134 | self.image_list += [ [images[2*i], images[2*i+1]] ] 135 | 136 | 137 | class FlyingThings3D(FlowDataset): 138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', split='train', dstype='frames_cleanpass'): 139 | super(FlyingThings3D, self).__init__(aug_params) 140 | 141 | if split == 'train': 142 | dir_prefix = 'TRAIN' 143 | elif split == 'test': 144 | dir_prefix = 'TEST' 145 | else: 146 | raise ValueError('Unknown split for FlyingThings3D.') 147 | 148 | for cam in ['left']: 149 | for direction in ['into_future', 'into_past']: 150 | image_dirs = sorted(glob(osp.join(root, dstype, f'{dir_prefix}/*/*'))) 151 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 152 | 153 | flow_dirs = sorted(glob(osp.join(root, f'optical_flow/{dir_prefix}/*/*'))) 154 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 155 | 156 | for idir, fdir in zip(image_dirs, flow_dirs): 157 | images = sorted(glob(osp.join(idir, '*.png')) ) 158 | flows = sorted(glob(osp.join(fdir, '*.pfm')) ) 159 | for i in range(len(flows)-1): 160 | if direction == 'into_future': 161 | self.image_list += [ [images[i], images[i+1]] ] 162 | self.flow_list += [ flows[i] ] 163 | elif direction == 'into_past': 164 | self.image_list += [ [images[i+1], images[i]] ] 165 | self.flow_list += [ flows[i+1] ] 166 | 167 | 168 | class KITTI(FlowDataset): 169 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): 170 | super(KITTI, self).__init__(aug_params, sparse=True) 171 | if split == 'testing': 172 | self.is_test = True 173 | 174 | root = osp.join(root, split) 175 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) 176 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 177 | 178 | for img1, img2 in zip(images1, images2): 179 | frame_id = img1.split('/')[-1] 180 | self.extra_info += [ [frame_id] ] 181 | self.image_list += [ [img1, img2] ] 182 | 183 | if split == 'training': 184 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 185 | 186 | 187 | class HD1K(FlowDataset): 188 | def __init__(self, aug_params=None, root='datasets/HD1k'): 189 | super(HD1K, self).__init__(aug_params, sparse=True) 190 | 191 | seq_ix = 0 192 | while 1: 193 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 194 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 195 | 196 | if len(flows) == 0: 197 | break 198 | 199 | for i in range(len(flows)-1): 200 | self.flow_list += [flows[i]] 201 | self.image_list += [ [images[i], images[i+1]] ] 202 | 203 | seq_ix += 1 204 | 205 | 206 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): 207 | """ Create the data loader for the corresponding trainign set """ 208 | 209 | if args.stage == 'chairs': 210 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} 211 | train_dataset = FlyingChairs(aug_params, split='training') 212 | 213 | elif args.stage == 'things': 214 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} 215 | clean_dataset = FlyingThings3D(aug_params, 
dstype='frames_cleanpass') 216 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') 217 | train_dataset = clean_dataset + final_dataset 218 | 219 | elif args.stage == 'sintel': 220 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} 221 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass') 222 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') 223 | sintel_final = MpiSintel(aug_params, split='training', dstype='final') 224 | 225 | if TRAIN_DS == 'C+T+K+S+H': 226 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) 227 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) 228 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things 229 | 230 | elif TRAIN_DS == 'C+T+K/S': 231 | train_dataset = 100*sintel_clean + 100*sintel_final + things 232 | 233 | elif args.stage == 'kitti': 234 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} 235 | train_dataset = KITTI(aug_params, split='training') 236 | 237 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 238 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True) 239 | 240 | print('Training with %d image pairs' % len(train_dataset)) 241 | return train_loader 242 | 243 | -------------------------------------------------------------------------------- /code.v.1.0/core/deq_demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from update import BasicUpdateBlock 7 | from extractor import BasicEncoder 8 | from corr import CorrBlock 9 | from utils.utils import bilinear_sampler, coords_grid, upflow8 10 | 11 | from termcolor import colored 12 | 13 | from lib.solvers import anderson, broyden 14 | from lib.grad import make_pair, backward_factory 15 | 16 | from metrics import process_metrics 17 | 18 | 19 | try: 20 | autocast = torch.cuda.amp.autocast 21 | except: 22 | # dummy autocast for PyTorch < 1.6 23 | class autocast: 24 | def __init__(self, enabled): 25 | pass 26 | def __enter__(self): 27 | pass 28 | def __exit__(self, *args): 29 | pass 30 | 31 | 32 | class DEQFlowDemo(nn.Module): 33 | def __init__(self, args): 34 | super(DEQFlowDemo, self).__init__() 35 | self.args = args 36 | 37 | odim = 256 38 | self.hidden_dim = hdim = 128 39 | self.context_dim = cdim = 128 40 | args.corr_levels = 4 41 | args.corr_radius = 4 42 | 43 | # feature network, context network, and update block 44 | self.fnet = BasicEncoder(output_dim=odim, norm_fn='instance', dropout=args.dropout) 45 | self.cnet = BasicEncoder(output_dim=cdim, norm_fn='batch', dropout=args.dropout) 46 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 47 | 48 | # Added the following for the DEQ models 49 | if args.wnorm: 50 | self.update_block._wnorm() 51 | 52 | self.f_solver = eval(args.f_solver) 53 | self.f_thres = args.f_thres 54 | self.eval_f_thres = int(self.f_thres * args.eval_factor) 55 | self.stop_mode = args.stop_mode 56 | 57 | # Define gradient functions through the backward factory 58 | if args.n_losses > 1: 59 | n_losses = min(args.f_thres, args.n_losses) 60 | delta = int(args.f_thres // n_losses) 61 | self.indexing = [(k+1)*delta for k in range(n_losses)] 62 | else: 63 | self.indexing = [*args.indexing, args.f_thres] 64 | 65 | # By 
default, we use the same phantom grad for all corrections. 66 | # You can also set different grad steps a, b, and c for different terms by ``args.phantom_grad a b c ...''. 67 | indexing_pg = make_pair(self.indexing, args.phantom_grad) 68 | produce_grad = [ 69 | backward_factory(grad_type=pg, tau=args.tau, sup_all=args.sup_all) for pg in indexing_pg 70 | ] 71 | if args.ift: 72 | # Enabling args.ift will replace the last gradient function by IFT. 73 | produce_grad[-1] = backward_factory( 74 | grad_type='ift', safe_ift=args.safe_ift, b_solver=eval(args.b_solver), 75 | b_solver_kwargs=dict(threshold=args.b_thres, stop_mode=args.stop_mode) 76 | ) 77 | 78 | self.produce_grad = produce_grad 79 | self.hook = None 80 | 81 | def freeze_bn(self): 82 | for m in self.modules(): 83 | if isinstance(m, nn.BatchNorm2d): 84 | m.eval() 85 | 86 | def _initialize_flow(self, img): 87 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 88 | N, _, H, W = img.shape 89 | coords0 = coords_grid(N, H//8, W//8, device=img.device) 90 | coords1 = coords_grid(N, H//8, W//8, device=img.device) 91 | 92 | # optical flow computed as difference: flow = coords1 - coords0 93 | return coords0, coords1 94 | 95 | def _upsample_flow(self, flow, mask): 96 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 97 | N, _, H, W = flow.shape 98 | mask = mask.view(N, 1, 9, 8, 8, H, W) 99 | mask = torch.softmax(mask, dim=2) 100 | 101 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 102 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 103 | 104 | up_flow = torch.sum(mask * up_flow, dim=2) 105 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 106 | return up_flow.reshape(N, 2, 8*H, 8*W) 107 | 108 | def _decode(self, z_out, vec2list, coords0): 109 | flow_predictions = [] 110 | 111 | for z_pred in z_out: 112 | net, coords1 = vec2list(z_pred) 113 | up_mask = .25 * self.update_block.mask(net) 114 | 115 | if up_mask is None: 116 | flow_up = upflow8(coords1 - coords0) 117 | else: 118 | flow_up = self._upsample_flow(coords1 - coords0, up_mask) 119 | 120 | flow_predictions.append(flow_up) 121 | 122 | return flow_predictions 123 | 124 | def _fixed_point_solve(self, deq_func, z_star, f_thres=None, **kwargs): 125 | if f_thres is None: f_thres = self.f_thres 126 | indexing = self.indexing if self.training else None 127 | 128 | with torch.no_grad(): 129 | result = self.f_solver(deq_func, x0=z_star, threshold=f_thres, # To reuse previous coarse fixed points 130 | eps=(1e-3 if self.stop_mode == "abs" else 1e-6), stop_mode=self.stop_mode, indexing=indexing) 131 | 132 | z_star, trajectory = result['result'], result['indexing'] 133 | 134 | return z_star, trajectory, min(result['rel_trace']), min(result['abs_trace']) 135 | 136 | def forward(self, image1, image2, 137 | flow_gt=None, valid=None, step_seq_loss=None, 138 | flow_init=None, cached_result=None, 139 | **kwargs): 140 | """ Estimate optical flow between pair of frames """ 141 | 142 | image1 = 2 * (image1 / 255.0) - 1.0 143 | image2 = 2 * (image2 / 255.0) - 1.0 144 | 145 | image1 = image1.contiguous() 146 | image2 = image2.contiguous() 147 | 148 | hdim = self.hidden_dim 149 | cdim = self.context_dim 150 | 151 | # run the feature network 152 | with autocast(enabled=self.args.mixed_precision): 153 | fmap1, fmap2 = self.fnet([image1, image2]) 154 | 155 | fmap1 = fmap1.float() 156 | fmap2 = fmap2.float() 157 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 158 | 159 | # run the context network 160 | with 
autocast(enabled=self.args.mixed_precision): 161 | inp = self.cnet(image1) 162 | inp = torch.relu(inp) 163 | 164 | bsz, _, H, W = inp.shape 165 | coords0, coords1 = self._initialize_flow(image1) 166 | net = torch.zeros(bsz, hdim, H, W, device=inp.device) 167 | if cached_result: 168 | net, flow_pred_prev = cached_result 169 | coords1 = coords0 + flow_pred_prev 170 | 171 | if flow_init is not None: 172 | coords1 = coords1 + flow_init 173 | 174 | seed = (inp.get_device() == 0 and np.random.uniform(0,1) < 2e-3) 175 | 176 | def list2vec(h, c): # h is net, c is coords1 177 | return torch.cat([h.view(bsz, h.shape[1], -1), c.view(bsz, c.shape[1], -1)], dim=1) 178 | 179 | def vec2list(hidden): 180 | return hidden[:,:net.shape[1]].view_as(net), hidden[:,net.shape[1]:].view_as(coords1) 181 | 182 | def deq_func(hidden): 183 | h, c = vec2list(hidden) 184 | c = c.detach() 185 | 186 | with autocast(enabled=self.args.mixed_precision): 187 | # corr_fn(coords1) produces the index correlation volumes 188 | new_h, delta_flow = self.update_block(h, inp, corr_fn(c), c-coords0, None) 189 | new_c = c + delta_flow # F(t+1) = F(t) + \Delta(t) 190 | return list2vec(new_h, new_c) 191 | 192 | self.update_block.reset() # In case we use weight normalization, we need to recompute the weight 193 | z_star = list2vec(net, coords1) 194 | 195 | # The code for DEQ version, where we use a wrapper. 196 | if self.training: 197 | _, trajectory, rel_error, abs_error = self._fixed_point_solve(deq_func, z_star, *kwargs) 198 | 199 | z_out = [] 200 | for z_pred, produce_grad in zip(trajectory, self.produce_grad): 201 | z_out += produce_grad(self, z_pred, deq_func) # See lib/grad.py for the backward pass implementations 202 | 203 | flow_predictions = self._decode(z_out, vec2list, coords0) 204 | 205 | flow_loss, epe = step_seq_loss(flow_predictions, flow_gt, valid) 206 | metrics = process_metrics(epe, rel_error, abs_error) 207 | 208 | return flow_loss, metrics 209 | else: 210 | # During inference, we directly solve for fixed point 211 | z_star, _, rel_error, abs_error = self._fixed_point_solve(deq_func, z_star, f_thres=self.eval_f_thres) 212 | flow_up = self._decode([z_star], vec2list, coords0)[0] 213 | net, coords1 = vec2list(z_star) 214 | 215 | return coords1 - coords0, flow_up, {"sradius": torch.zeros(1, device=z_star.device), "cached_result": (net, coords1 - coords0)} 216 | 217 | 218 | def get_model(args): 219 | return DEQFlowDemo 220 | -------------------------------------------------------------------------------- /code.v.1.0/core/extractor.py: -------------------------------------------------------------------------------- 1 | # This code was originally taken from RAFT without modification 2 | # https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class ResidualBlock(nn.Module): 10 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 11 | super(ResidualBlock, self).__init__() 12 | 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 15 | self.relu = nn.ReLU(inplace=True) 16 | 17 | num_groups = planes // 8 18 | 19 | if norm_fn == 'group': 20 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 22 | if not stride == 1: 23 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 24 | 25 | elif norm_fn 
== 'batch': 26 | self.norm1 = nn.BatchNorm2d(planes) 27 | self.norm2 = nn.BatchNorm2d(planes) 28 | if not stride == 1: 29 | self.norm3 = nn.BatchNorm2d(planes) 30 | 31 | elif norm_fn == 'instance': 32 | self.norm1 = nn.InstanceNorm2d(planes) 33 | self.norm2 = nn.InstanceNorm2d(planes) 34 | if not stride == 1: 35 | self.norm3 = nn.InstanceNorm2d(planes) 36 | 37 | elif norm_fn == 'none': 38 | self.norm1 = nn.Sequential() 39 | self.norm2 = nn.Sequential() 40 | if not stride == 1: 41 | self.norm3 = nn.Sequential() 42 | 43 | if stride == 1: 44 | self.downsample = None 45 | 46 | else: 47 | self.downsample = nn.Sequential( 48 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 49 | 50 | 51 | def forward(self, x): 52 | y = x 53 | y = self.relu(self.norm1(self.conv1(y))) 54 | y = self.relu(self.norm2(self.conv2(y))) 55 | 56 | if self.downsample is not None: 57 | x = self.downsample(x) 58 | 59 | return self.relu(x+y) 60 | 61 | 62 | 63 | class BottleneckBlock(nn.Module): 64 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 65 | super(BottleneckBlock, self).__init__() 66 | 67 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 68 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 69 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 70 | self.relu = nn.ReLU(inplace=True) 71 | 72 | num_groups = planes // 8 73 | 74 | if norm_fn == 'group': 75 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 76 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 77 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 78 | if not stride == 1: 79 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 80 | 81 | elif norm_fn == 'batch': 82 | self.norm1 = nn.BatchNorm2d(planes//4) 83 | self.norm2 = nn.BatchNorm2d(planes//4) 84 | self.norm3 = nn.BatchNorm2d(planes) 85 | if not stride == 1: 86 | self.norm4 = nn.BatchNorm2d(planes) 87 | 88 | elif norm_fn == 'instance': 89 | self.norm1 = nn.InstanceNorm2d(planes//4) 90 | self.norm2 = nn.InstanceNorm2d(planes//4) 91 | self.norm3 = nn.InstanceNorm2d(planes) 92 | if not stride == 1: 93 | self.norm4 = nn.InstanceNorm2d(planes) 94 | 95 | elif norm_fn == 'none': 96 | self.norm1 = nn.Sequential() 97 | self.norm2 = nn.Sequential() 98 | self.norm3 = nn.Sequential() 99 | if not stride == 1: 100 | self.norm4 = nn.Sequential() 101 | 102 | if stride == 1: 103 | self.downsample = None 104 | 105 | else: 106 | self.downsample = nn.Sequential( 107 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 108 | 109 | 110 | def forward(self, x): 111 | y = x 112 | y = self.relu(self.norm1(self.conv1(y))) 113 | y = self.relu(self.norm2(self.conv2(y))) 114 | y = self.relu(self.norm3(self.conv3(y))) 115 | 116 | if self.downsample is not None: 117 | x = self.downsample(x) 118 | 119 | return self.relu(x+y) 120 | 121 | class BasicEncoder(nn.Module): 122 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 123 | super(BasicEncoder, self).__init__() 124 | self.norm_fn = norm_fn 125 | 126 | if self.norm_fn == 'group': 127 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 128 | 129 | elif self.norm_fn == 'batch': 130 | self.norm1 = nn.BatchNorm2d(64) 131 | 132 | elif self.norm_fn == 'instance': 133 | self.norm1 = nn.InstanceNorm2d(64) 134 | 135 | elif self.norm_fn == 'none': 136 | self.norm1 = nn.Sequential() 137 | 138 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, 
stride=2, padding=3) 139 | self.relu1 = nn.ReLU(inplace=True) 140 | 141 | self.in_planes = 64 142 | self.layer1 = self._make_layer(64, stride=1) 143 | self.layer2 = self._make_layer(96, stride=2) 144 | self.layer3 = self._make_layer(128, stride=2) 145 | 146 | # output convolution 147 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 148 | 149 | self.dropout = None 150 | if dropout > 0: 151 | self.dropout = nn.Dropout2d(p=dropout) 152 | 153 | for m in self.modules(): 154 | if isinstance(m, nn.Conv2d): 155 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 156 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 157 | if m.weight is not None: 158 | nn.init.constant_(m.weight, 1) 159 | if m.bias is not None: 160 | nn.init.constant_(m.bias, 0) 161 | 162 | def _make_layer(self, dim, stride=1): 163 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 164 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 165 | layers = (layer1, layer2) 166 | 167 | self.in_planes = dim 168 | return nn.Sequential(*layers) 169 | 170 | 171 | def forward(self, x): 172 | 173 | # if input is list, combine batch dimension 174 | is_list = isinstance(x, tuple) or isinstance(x, list) 175 | if is_list: 176 | batch_dim = x[0].shape[0] 177 | x = torch.cat(x, dim=0) 178 | 179 | x = self.conv1(x) 180 | x = self.norm1(x) 181 | x = self.relu1(x) 182 | 183 | x = self.layer1(x) 184 | x = self.layer2(x) 185 | x = self.layer3(x) 186 | 187 | x = self.conv2(x) 188 | 189 | if self.training and self.dropout is not None: 190 | x = self.dropout(x) 191 | 192 | if is_list: 193 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 194 | 195 | return x 196 | 197 | 198 | class SmallEncoder(nn.Module): 199 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 200 | super(SmallEncoder, self).__init__() 201 | self.norm_fn = norm_fn 202 | 203 | if self.norm_fn == 'group': 204 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 205 | 206 | elif self.norm_fn == 'batch': 207 | self.norm1 = nn.BatchNorm2d(32) 208 | 209 | elif self.norm_fn == 'instance': 210 | self.norm1 = nn.InstanceNorm2d(32) 211 | 212 | elif self.norm_fn == 'none': 213 | self.norm1 = nn.Sequential() 214 | 215 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 216 | self.relu1 = nn.ReLU(inplace=True) 217 | 218 | self.in_planes = 32 219 | self.layer1 = self._make_layer(32, stride=1) 220 | self.layer2 = self._make_layer(64, stride=2) 221 | self.layer3 = self._make_layer(96, stride=2) 222 | 223 | self.dropout = None 224 | if dropout > 0: 225 | self.dropout = nn.Dropout2d(p=dropout) 226 | 227 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 228 | 229 | for m in self.modules(): 230 | if isinstance(m, nn.Conv2d): 231 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 232 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 233 | if m.weight is not None: 234 | nn.init.constant_(m.weight, 1) 235 | if m.bias is not None: 236 | nn.init.constant_(m.bias, 0) 237 | 238 | def _make_layer(self, dim, stride=1): 239 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 240 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 241 | layers = (layer1, layer2) 242 | 243 | self.in_planes = dim 244 | return nn.Sequential(*layers) 245 | 246 | 247 | def forward(self, x): 248 | 249 | # if input is list, combine batch dimension 250 | is_list = isinstance(x, tuple) or isinstance(x, list) 251 | if is_list: 252 | batch_dim 
= x[0].shape[0] 253 | x = torch.cat(x, dim=0) 254 | 255 | x = self.conv1(x) 256 | x = self.norm1(x) 257 | x = self.relu1(x) 258 | 259 | x = self.layer1(x) 260 | x = self.layer2(x) 261 | x = self.layer3(x) 262 | x = self.conv2(x) 263 | 264 | if self.training and self.dropout is not None: 265 | x = self.dropout(x) 266 | 267 | if is_list: 268 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 269 | 270 | return x 271 | -------------------------------------------------------------------------------- /code.v.1.0/core/gma.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange 4 | 5 | from lib.optimizations import weight_norm 6 | 7 | 8 | class RelPosEmb(nn.Module): 9 | def __init__( 10 | self, 11 | max_pos_size, 12 | dim_head 13 | ): 14 | super().__init__() 15 | self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head) 16 | self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head) 17 | 18 | deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(max_pos_size).view(-1, 1) 19 | rel_ind = deltas + max_pos_size - 1 20 | self.register_buffer('rel_ind', rel_ind) 21 | 22 | def forward(self, q): 23 | batch, heads, h, w, c = q.shape 24 | height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1)) 25 | width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1)) 26 | 27 | height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h) 28 | width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w) 29 | 30 | height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb) 31 | width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb) 32 | 33 | return height_score + width_score 34 | 35 | 36 | class Attention(nn.Module): 37 | def __init__( 38 | self, 39 | *, 40 | dim, 41 | max_pos_size = 100, 42 | heads = 4, 43 | dim_head = 128, 44 | ): 45 | super().__init__() 46 | self.heads = heads 47 | self.scale = dim_head ** -0.5 48 | inner_dim = heads * dim_head 49 | 50 | self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) 51 | 52 | def _wnorm(self): 53 | self.to_qk, self.to_qk_fn = weight_norm(module=self.to_qk, names=['weight'], dim=0) 54 | 55 | def reset(self): 56 | for name in ['to_qk']: 57 | if name + '_fn' in self.__dict__: 58 | eval(f'self.{name}_fn').reset(eval(f'self.{name}')) 59 | 60 | def forward(self, fmap): 61 | heads, b, c, h, w = self.heads, *fmap.shape 62 | 63 | q, k = self.to_qk(fmap).chunk(2, dim=1) 64 | 65 | q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k)) 66 | q = self.scale * q 67 | 68 | sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k) 69 | 70 | sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)') 71 | attn = sim.softmax(dim=-1) 72 | 73 | return attn 74 | 75 | 76 | class Aggregate(nn.Module): 77 | def __init__( 78 | self, 79 | dim, 80 | heads = 4, 81 | dim_head = 128, 82 | ): 83 | super().__init__() 84 | self.heads = heads 85 | self.scale = dim_head ** -0.5 86 | inner_dim = heads * dim_head 87 | 88 | self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) 89 | 90 | self.gamma = nn.Parameter(torch.zeros(1)) 91 | 92 | if dim != inner_dim: 93 | self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) 94 | else: 95 | self.project = None 96 | 97 | def _wnorm(self): 98 | self.to_v, self.to_v_fn = weight_norm(module=self.to_v, names=['weight'], dim=0) 99 | 100 | if self.project: 101 | self.project, self.project_fn = weight_norm(module=self.project, names=['weight'], dim=0) 102 | 103 | def reset(self): 104 | 
for name in ['to_v', 'project']: 105 | if name + '_fn' in self.__dict__: 106 | eval(f'self.{name}_fn').reset(eval(f'self.{name}')) 107 | 108 | def forward(self, attn, fmap): 109 | heads, b, c, h, w = self.heads, *fmap.shape 110 | 111 | v = self.to_v(fmap) 112 | v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads) 113 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 114 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) 115 | 116 | if self.project is not None: 117 | out = self.project(out) 118 | 119 | out = fmap + self.gamma * out 120 | 121 | return out 122 | 123 | 124 | if __name__ == "__main__": 125 | att = Attention(dim=128, heads=1) 126 | fmap = torch.randn(2, 128, 40, 90) 127 | out = att(fmap) 128 | 129 | print(out.shape) 130 | -------------------------------------------------------------------------------- /code.v.1.0/core/lib/grad.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch import autograd 6 | 7 | from .solvers import anderson, broyden 8 | 9 | 10 | def make_pair(target, source): 11 | if len(target) == len(source): 12 | return source 13 | elif len(source) == 1: 14 | return [source[0] for _ in range(len(target))] 15 | else: 16 | raise ValueError('Unable to align the arg squence!') 17 | 18 | 19 | def backward_factory( 20 | grad_type='ift', 21 | safe_ift=False, 22 | b_solver=anderson, 23 | b_solver_kwargs=dict(), 24 | sup_all=False, 25 | tau=1.0): 26 | """ 27 | [2019-NeurIPS] Deep Equilibrium Models 28 | [2021-ICLR] Is Attention Better Than Matrix Decomposition? 29 | [2021-NeurIPS] On Training Implicit Models 30 | [2022-AAAI] JFB: Jacobian-Free Backpropagation for Implicit Networks 31 | 32 | This function implements a factory for the backward pass of implicit deep learning, 33 | e.g., DEQ (implicit models), Hamburger (optimization layer), etc. 34 | It now supports IFT, 1-step Grad, and Phantom Grad. 35 | 36 | Kwargs: 37 | grad_type (string, int): 38 | grad_type should be ``ift`` or an int. Default ``ift``. 39 | Set to ``ift`` to enable the implicit differentiation mode. 40 | When passing a number k to this function, it runs UPG with steps k and damping tau. 41 | safe_ift (bool): 42 | Replace the O(1) hook implementeion with a safer one. Default ``False``. 43 | Set to ``True`` to avoid the (potential) segment fault (under previous versions of Pytorch). 44 | b_solver (type): 45 | Solver for the IFT backward pass. Default ``anderson``. 46 | Supported solvers: anderson, broyden. 47 | b_solver_kwargs (dict): 48 | Colllection of backward solver kwargs, e.g., 49 | threshold (int), max steps for the backward solver, 50 | stop_mode (string), criterion for convergence, 51 | etc. 52 | See solver.py to check all the kwargs. 53 | sup_all (bool): 54 | Indicate whether to supervise all the trajectories by Phantom Grad. 55 | Set ``True`` to return all trajectory in Phantom Grad. 56 | tau (float): 57 | Damping factor for Phantom Grad. Default ``1.0``. 58 | 0.5 is recommended for CIFAR-10. 1.0 for DEQ flow. 59 | For DEQ flow, the gating function in GRU naturally produces adaptive tau values. 60 | 61 | Return: 62 | A gradient functor for implicit deep learning. 63 | Args: 64 | trainer (nn.Module): the module that employs implicit deep learning. 65 | z_pred (torch.Tensor): latent state to run the backward pass. 66 | func (type): function that defines the ``f`` in ``z = f(z)``. 
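                A minimal usage sketch (illustrative only; ``deq_core``, ``z_star``, ``func`` and
                ``criterion`` are placeholder names, not defined in this file):

                    core_grad = backward_factory(grad_type=1, tau=1.0)   # 1-step Phantom Grad
                    z_out = core_grad(deq_core, z_star, func)            # func defines f in z = f(z)
                    loss = criterion(z_out[-1])
                    loss.backward()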
67 | 68 | Return: 69 | (list(torch.Tensor)): a list of tensors that tracks the gradient info. 70 | 71 | """ 72 | 73 | if grad_type == 'ift': 74 | assert b_solver in [anderson, broyden] 75 | 76 | if safe_ift: 77 | def plain_ift_grad(trainer, z_pred, func): 78 | z_pred = z_pred.requires_grad_() 79 | new_z_pred = func(z_pred) # 1-step grad for df/dtheta 80 | 81 | z_pred_copy = new_z_pred.clone().detach().requires_grad_() 82 | new_z_pred_copy = func(z_pred_copy) 83 | def backward_hook(grad): 84 | result = b_solver(lambda y: autograd.grad(new_z_pred_copy, z_pred_copy, y, retain_graph=True)[0] + grad, 85 | torch.zeros_like(grad), **b_solver_kwargs) 86 | return result['result'] 87 | new_z_pred.register_hook(backward_hook) 88 | 89 | return [new_z_pred] 90 | return plain_ift_grad 91 | else: 92 | def hook_ift_grad(trainer, z_pred, func): 93 | z_pred = z_pred.requires_grad_() 94 | new_z_pred = func(z_pred) # 1-step grad for df/dtheta 95 | 96 | def backward_hook(grad): 97 | if trainer.hook is not None: 98 | trainer.hook.remove() # To avoid infinite loop 99 | result = b_solver(lambda y: autograd.grad(new_z_pred, z_pred, y, retain_graph=True)[0] + grad, 100 | torch.zeros_like(grad), **b_solver_kwargs) 101 | return result['result'] 102 | trainer.hook = new_z_pred.register_hook(backward_hook) 103 | 104 | return [new_z_pred] 105 | return hook_ift_grad 106 | else: 107 | assert type(grad_type) is int and grad_type >= 1 108 | n_phantom_grad = grad_type 109 | 110 | if sup_all: 111 | def sup_all_phantom_grad(trainer, z_pred, func): 112 | z_out = [] 113 | for _ in range(n_phantom_grad): 114 | z_pred = (1 - tau) * z_pred + tau * func(z_pred) 115 | z_out.append(z_pred) 116 | 117 | return z_out 118 | return sup_all_phantom_grad 119 | else: 120 | def phantom_grad(trainer, z_pred, func): 121 | for _ in range(n_phantom_grad): 122 | z_pred = (1 - tau) * z_pred + tau * func(z_pred) 123 | 124 | return [z_pred] 125 | return phantom_grad 126 | -------------------------------------------------------------------------------- /code.v.1.0/core/lib/jacobian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | 7 | def jac_loss_estimate(f0, z0, vecs=2, create_graph=True): 8 | """Estimating tr(J^TJ)=tr(JJ^T) via Hutchinson estimator 9 | 10 | Args: 11 | f0 (torch.Tensor): Output of the function f (whose J is to be analyzed) 12 | z0 (torch.Tensor): Input to the function f 13 | vecs (int, optional): Number of random Gaussian vectors to use. Defaults to 2. 14 | create_graph (bool, optional): Whether to create backward graph (e.g., to train on this loss). 15 | Defaults to True. 16 | 17 | Returns: 18 | torch.Tensor: A 1x1 torch tensor that encodes the (shape-normalized) jacobian loss 19 | """ 20 | vecs = vecs 21 | result = 0 22 | for i in range(vecs): 23 | v = torch.randn(*z0.shape).to(z0) 24 | vJ = torch.autograd.grad(f0, z0, v, retain_graph=True, create_graph=create_graph)[0] 25 | result += vJ.norm()**2 26 | return result / vecs / np.prod(z0.shape) 27 | 28 | def power_method(f0, z0, n_iters=200): 29 | """Estimating the spectral radius of J using power method 30 | 31 | Args: 32 | f0 (torch.Tensor): Output of the function f (whose J is to be analyzed) 33 | z0 (torch.Tensor): Input to the function f 34 | n_iters (int, optional): Number of power method iterations. Defaults to 200. 35 | 36 | Returns: 37 | tuple: (largest eigenvector, largest (abs.) 
eigenvalue) 38 | """ 39 | evector = torch.randn_like(z0) 40 | bsz = evector.shape[0] 41 | for i in range(n_iters): 42 | vTJ = torch.autograd.grad(f0, z0, evector, retain_graph=(i < n_iters-1), create_graph=False)[0] 43 | evalue = (vTJ * evector).reshape(bsz, -1).sum(1, keepdim=True) / (evector * evector).reshape(bsz, -1).sum(1, keepdim=True) 44 | evector = (vTJ.reshape(bsz, -1) / vTJ.reshape(bsz, -1).norm(dim=1, keepdim=True)).reshape_as(z0) 45 | return (evector, torch.abs(evalue)) -------------------------------------------------------------------------------- /code.v.1.0/core/lib/layer_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | 6 | def list2vec(z1_list): 7 | """Convert list of tensors to a vector""" 8 | bsz = z1_list[0].size(0) 9 | return torch.cat([elem.reshape(bsz, -1, 1) for elem in z1_list], dim=1) 10 | 11 | 12 | def vec2list(z1, cutoffs): 13 | """Convert a vector back to a list, via the cutoffs specified""" 14 | bsz = z1.shape[0] 15 | z1_list = [] 16 | start_idx, end_idx = 0, cutoffs[0][0] * cutoffs[0][1] * cutoffs[0][2] 17 | for i in range(len(cutoffs)): 18 | z1_list.append(z1[:, start_idx:end_idx].view(bsz, *cutoffs[i])) 19 | if i < len(cutoffs)-1: 20 | start_idx = end_idx 21 | end_idx += cutoffs[i + 1][0] * cutoffs[i + 1][1] * cutoffs[i + 1][2] 22 | return z1_list 23 | 24 | 25 | def conv3x3(in_planes, out_planes, stride=1, bias=False): 26 | """3x3 convolution with padding""" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=bias) 28 | 29 | def conv5x5(in_planes, out_planes, stride=1, bias=False): 30 | """5x5 convolution with padding""" 31 | return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride, padding=2, bias=bias) 32 | 33 | 34 | def norm_diff(new, old, show_list=False): 35 | if show_list: 36 | return [(new[i] - old[i]).norm().item() for i in range(len(new))] 37 | return np.sqrt(sum((new[i] - old[i]).norm().item()**2 for i in range(len(new)))) -------------------------------------------------------------------------------- /code.v.1.0/core/lib/optimizations.py: -------------------------------------------------------------------------------- 1 | from torch.nn.parameter import Parameter 2 | import torch.nn as nn 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | ############################################################################################################## 8 | # 9 | # Temporal DropConnect in a feed-forward setting 10 | # 11 | ############################################################################################################## 12 | 13 | class WeightDrop(torch.nn.Module): 14 | def __init__(self, module, weights, dropout=0, temporal=True): 15 | """ 16 | Weight DropConnect, adapted from a recurrent setting by Merity et al. 2017 17 | 18 | :param module: The module whose weights are to be applied dropout on 19 | :param weights: A 2D list identifying the weights to be regularized. Each element of weights should be a 20 | list containing the "path" to the weight kernel. For instance, if we want to regularize 21 | module.layer2.weight3, then this should be ["layer2", "weight3"]. 
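                       For instance (an illustrative sketch; ``update_block``, ``convz`` and
                       ``convr`` are placeholder names, not defined here):

                           block = WeightDrop(update_block,
                                              weights=[['convz', 'weight'], ['convr', 'weight']],
                                              dropout=0.05)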
22 | :param dropout: The dropout rate (0 means no dropout) 23 | :param temporal: Whether we apply DropConnect only to the temporal parts of the weight (empirically we found 24 | this not very important) 25 | """ 26 | super(WeightDrop, self).__init__() 27 | self.module = module 28 | self.weights = weights 29 | self.dropout = dropout 30 | self.temporal = temporal 31 | if self.dropout > 0.0: 32 | self._setup() 33 | 34 | def _setup(self): 35 | for path in self.weights: 36 | full_name_w = '.'.join(path) 37 | 38 | module = self.module 39 | name_w = path[-1] 40 | for i in range(len(path) - 1): 41 | module = getattr(module, path[i]) 42 | w = getattr(module, name_w) 43 | del module._parameters[name_w] 44 | module.register_parameter(name_w + '_raw', Parameter(w.data)) 45 | 46 | def _setweights(self): 47 | for path in self.weights: 48 | module = self.module 49 | name_w = path[-1] 50 | for i in range(len(path) - 1): 51 | module = getattr(module, path[i]) 52 | raw_w = getattr(module, name_w + '_raw') 53 | 54 | if len(raw_w.size()) > 2 and raw_w.size(2) > 1 and self.temporal: 55 | # Drop the temporal parts of the weight; if 1x1 convolution then drop the whole kernel 56 | w = torch.cat([F.dropout(raw_w[:, :, :-1], p=self.dropout, training=self.training), 57 | raw_w[:, :, -1:]], dim=2) 58 | else: 59 | w = F.dropout(raw_w, p=self.dropout, training=self.training) 60 | 61 | setattr(module, name_w, w) 62 | 63 | def forward(self, *args, **kwargs): 64 | if self.dropout > 0.0: 65 | self._setweights() 66 | return self.module.forward(*args, **kwargs) 67 | 68 | 69 | def matrix_diag(a, dim=2): 70 | """ 71 | a has dimension (N, (L,) C), we want a matrix/batch diag that produces (N, (L,) C, C) from the last dimension of a 72 | """ 73 | if dim == 2: 74 | res = torch.zeros(a.size(0), a.size(1), a.size(1)) 75 | res.as_strided(a.size(), [res.stride(0), res.size(2)+1]).copy_(a) 76 | else: 77 | res = torch.zeros(a.size(0), a.size(1), a.size(2), a.size(2)) 78 | res.as_strided(a.size(), [res.stride(0), res.stride(1), res.size(3)+1]).copy_(a) 79 | return res 80 | 81 | ############################################################################################################## 82 | # 83 | # Embedding dropout 84 | # 85 | ############################################################################################################## 86 | 87 | def embedded_dropout(embed, words, dropout=0.1, scale=None): 88 | """ 89 | Apply embedding encoder (whose weight we apply a dropout) 90 | 91 | :param embed: The embedding layer 92 | :param words: The input sequence 93 | :param dropout: The embedding weight dropout rate 94 | :param scale: Scaling factor for the dropped embedding weight 95 | :return: The embedding output 96 | """ 97 | if dropout: 98 | mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as( 99 | embed.weight) / (1 - dropout) 100 | mask = Variable(mask) 101 | masked_embed_weight = mask * embed.weight 102 | else: 103 | masked_embed_weight = embed.weight 104 | 105 | if scale: 106 | masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight 107 | 108 | padding_idx = embed.padding_idx 109 | if padding_idx is None: 110 | padding_idx = -1 111 | 112 | X = F.embedding(words, masked_embed_weight, padding_idx, embed.max_norm, embed.norm_type, 113 | embed.scale_grad_by_freq, embed.sparse) 114 | return X 115 | 116 | 117 | 118 | ############################################################################################################## 119 | # 120 | # Variational dropout 
(for input/output layers, and for hidden layers) 121 | # 122 | ############################################################################################################## 123 | 124 | 125 | class VariationalHidDropout2d(nn.Module): 126 | def __init__(self, dropout=0.0): 127 | super(VariationalHidDropout2d, self).__init__() 128 | self.dropout = dropout 129 | self.mask = None 130 | 131 | def forward(self, x): 132 | if not self.training or self.dropout == 0: 133 | return x 134 | bsz, d, H, W = x.shape 135 | if self.mask is None: 136 | m = torch.zeros(bsz, d, H, W).bernoulli_(1 - self.dropout).to(x) 137 | self.mask = m.requires_grad_(False) / (1 - self.dropout) 138 | return self.mask * x 139 | 140 | ############################################################################################################## 141 | # 142 | # Weight normalization. Modified from the original PyTorch's implementation of weight normalization. 143 | # 144 | ############################################################################################################## 145 | 146 | def _norm(p, dim): 147 | """Computes the norm over all dimensions except dim""" 148 | if dim is None: 149 | return p.norm() 150 | elif dim == 0: 151 | output_size = (p.size(0),) + (1,) * (p.dim() - 1) 152 | return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size) 153 | elif dim == p.dim() - 1: 154 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),) 155 | return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size) 156 | else: 157 | return _norm(p.transpose(0, dim), 0).transpose(0, dim) 158 | 159 | 160 | class WeightNorm(object): 161 | def __init__(self, names, dim): 162 | """ 163 | Weight normalization module 164 | 165 | :param names: The list of weight names to apply weightnorm on 166 | :param dim: The dimension of the weights to be normalized 167 | """ 168 | self.names = names 169 | self.dim = dim 170 | 171 | def compute_weight(self, module, name): 172 | g = getattr(module, name + '_g') 173 | v = getattr(module, name + '_v') 174 | return v * (g / _norm(v, self.dim)) 175 | 176 | @staticmethod 177 | def apply(module, names, dim): 178 | fn = WeightNorm(names, dim) 179 | 180 | for name in names: 181 | weight = getattr(module, name) 182 | 183 | # remove w from parameter list 184 | del module._parameters[name] 185 | 186 | # add g and v as new parameters and express w as g/||v|| * v 187 | module.register_parameter(name + '_g', Parameter(_norm(weight, dim).data)) 188 | module.register_parameter(name + '_v', Parameter(weight.data)) 189 | setattr(module, name, fn.compute_weight(module, name)) 190 | 191 | # recompute weight before every forward() 192 | module.register_forward_pre_hook(fn) 193 | return fn 194 | 195 | def remove(self, module): 196 | for name in self.names: 197 | weight = self.compute_weight(module, name) 198 | delattr(module, name) 199 | del module._parameters[name + '_g'] 200 | del module._parameters[name + '_v'] 201 | module.register_parameter(name, Parameter(weight.data)) 202 | 203 | def reset(self, module): 204 | for name in self.names: 205 | setattr(module, name, self.compute_weight(module, name)) 206 | 207 | def __call__(self, module, inputs): 208 | # Typically, every time the module is called we need to recompute the weight. However, 209 | # in the case of TrellisNet, the same weight is shared across layers, and we can save 210 | # a lot of intermediate memory by just recomputing once (at the beginning of first call). 
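        # A usage sketch (illustrative), mirroring the _wnorm()/reset() pattern used by
        # modules in this repo (e.g., Attention and Aggregate in gma.py):
        #
        #     conv, conv_fn = weight_norm(module=conv, names=['weight'], dim=0)
        #     ...
        #     conv_fn.reset(conv)  # re-materialize 'weight' from 'weight_g' / 'weight_v'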
211 | pass 212 | 213 | 214 | def weight_norm(module, names, dim=0): 215 | fn = WeightNorm.apply(module, names, dim) 216 | return module, fn 217 | -------------------------------------------------------------------------------- /code.v.1.0/core/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | MAX_FLOW = 400 4 | 5 | @torch.no_grad() 6 | def compute_epe(flow_pred, flow_gt, valid, max_flow=MAX_FLOW): 7 | # exlude invalid pixels and extremely large diplacements 8 | mag = torch.sum(flow_gt**2, dim=1).sqrt() 9 | valid = (valid >= 0.5) & (mag < max_flow) 10 | 11 | epe = torch.sum((flow_pred - flow_gt)**2, dim=1).sqrt() 12 | epe = torch.masked_fill(epe, ~valid, 0) 13 | 14 | return epe 15 | 16 | 17 | @torch.no_grad() 18 | def process_metrics(epe, rel_error, abs_error, **kwargs): 19 | epe = epe.flatten(1) 20 | metrics = torch.stack( 21 | [ 22 | epe.mean(dim=1), 23 | (epe < 1).float().mean(dim=1), 24 | (epe < 3).float().mean(dim=1), 25 | (epe < 5).float().mean(dim=1), 26 | torch.tensor(rel_error).cuda().repeat(epe.shape[0]), 27 | torch.tensor(abs_error).cuda().repeat(epe.shape[0]), 28 | ], 29 | dim=1 30 | ) 31 | 32 | # (B // N_GPU, N_Metrics) 33 | return metrics 34 | 35 | 36 | def merge_metrics(metrics): 37 | metrics = metrics.mean(dim=0) 38 | metrics = { 39 | 'epe': metrics[0].item(), 40 | '1px': metrics[1].item(), 41 | '3px': metrics[2].item(), 42 | '5px': metrics[3].item(), 43 | 'rel': metrics[4].item(), 44 | 'abs': metrics[5].item(), 45 | } 46 | 47 | return metrics 48 | -------------------------------------------------------------------------------- /code.v.1.0/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/deq-flow/869d4bcd9c1d227cbfa16639994d1b118510a9bb/code.v.1.0/core/utils/__init__.py -------------------------------------------------------------------------------- /code.v.1.0/core/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | # This code was originally taken from RAFT without modification 2 | # https://github.com/princeton-vl/RAFT/blob/master/utils/augmentor.py 3 | 4 | import numpy as np 5 | import random 6 | import math 7 | from PIL import Image 8 | 9 | import cv2 10 | cv2.setNumThreads(0) 11 | cv2.ocl.setUseOpenCL(False) 12 | 13 | import torch 14 | from torchvision.transforms import ColorJitter 15 | import torch.nn.functional as F 16 | 17 | 18 | class FlowAugmentor: 19 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 20 | 21 | # spatial augmentation params 22 | self.crop_size = crop_size 23 | self.min_scale = min_scale 24 | self.max_scale = max_scale 25 | self.spatial_aug_prob = 0.8 26 | self.stretch_prob = 0.8 27 | self.max_stretch = 0.2 28 | 29 | # flip augmentation params 30 | self.do_flip = do_flip 31 | self.h_flip_prob = 0.5 32 | self.v_flip_prob = 0.1 33 | 34 | # photometric augmentation params 35 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 36 | self.asymmetric_color_aug_prob = 0.2 37 | self.eraser_aug_prob = 0.5 38 | 39 | def color_transform(self, img1, img2): 40 | """ Photometric augmentation """ 41 | 42 | # asymmetric 43 | if np.random.rand() < self.asymmetric_color_aug_prob: 44 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 45 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 46 | 47 | # symmetric 48 | else: 49 | image_stack = 
np.concatenate([img1, img2], axis=0) 50 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 51 | img1, img2 = np.split(image_stack, 2, axis=0) 52 | 53 | return img1, img2 54 | 55 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 56 | """ Occlusion augmentation """ 57 | 58 | ht, wd = img1.shape[:2] 59 | if np.random.rand() < self.eraser_aug_prob: 60 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 61 | for _ in range(np.random.randint(1, 3)): 62 | x0 = np.random.randint(0, wd) 63 | y0 = np.random.randint(0, ht) 64 | dx = np.random.randint(bounds[0], bounds[1]) 65 | dy = np.random.randint(bounds[0], bounds[1]) 66 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 67 | 68 | return img1, img2 69 | 70 | def spatial_transform(self, img1, img2, flow): 71 | # randomly sample scale 72 | ht, wd = img1.shape[:2] 73 | min_scale = np.maximum( 74 | (self.crop_size[0] + 8) / float(ht), 75 | (self.crop_size[1] + 8) / float(wd)) 76 | 77 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 78 | scale_x = scale 79 | scale_y = scale 80 | if np.random.rand() < self.stretch_prob: 81 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 82 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 83 | 84 | scale_x = np.clip(scale_x, min_scale, None) 85 | scale_y = np.clip(scale_y, min_scale, None) 86 | 87 | if np.random.rand() < self.spatial_aug_prob: 88 | # rescale the images 89 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 90 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 91 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 92 | flow = flow * [scale_x, scale_y] 93 | 94 | if self.do_flip: 95 | if np.random.rand() < self.h_flip_prob: # h-flip 96 | img1 = img1[:, ::-1] 97 | img2 = img2[:, ::-1] 98 | flow = flow[:, ::-1] * [-1.0, 1.0] 99 | 100 | if np.random.rand() < self.v_flip_prob: # v-flip 101 | img1 = img1[::-1, :] 102 | img2 = img2[::-1, :] 103 | flow = flow[::-1, :] * [1.0, -1.0] 104 | 105 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 106 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 107 | 108 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 109 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 110 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 111 | 112 | return img1, img2, flow 113 | 114 | def __call__(self, img1, img2, flow): 115 | img1, img2 = self.color_transform(img1, img2) 116 | img1, img2 = self.eraser_transform(img1, img2) 117 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 118 | 119 | img1 = np.ascontiguousarray(img1) 120 | img2 = np.ascontiguousarray(img2) 121 | flow = np.ascontiguousarray(flow) 122 | 123 | return img1, img2, flow 124 | 125 | class SparseFlowAugmentor: 126 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): 127 | # spatial augmentation params 128 | self.crop_size = crop_size 129 | self.min_scale = min_scale 130 | self.max_scale = max_scale 131 | self.spatial_aug_prob = 0.8 132 | self.stretch_prob = 0.8 133 | self.max_stretch = 0.2 134 | 135 | # flip augmentation params 136 | self.do_flip = do_flip 137 | self.h_flip_prob = 0.5 138 | self.v_flip_prob = 0.1 139 | 140 | # photometric augmentation params 141 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 142 | self.asymmetric_color_aug_prob = 0.2 143 | 
self.eraser_aug_prob = 0.5 144 | 145 | def color_transform(self, img1, img2): 146 | image_stack = np.concatenate([img1, img2], axis=0) 147 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 148 | img1, img2 = np.split(image_stack, 2, axis=0) 149 | return img1, img2 150 | 151 | def eraser_transform(self, img1, img2): 152 | ht, wd = img1.shape[:2] 153 | if np.random.rand() < self.eraser_aug_prob: 154 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 155 | for _ in range(np.random.randint(1, 3)): 156 | x0 = np.random.randint(0, wd) 157 | y0 = np.random.randint(0, ht) 158 | dx = np.random.randint(50, 100) 159 | dy = np.random.randint(50, 100) 160 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 161 | 162 | return img1, img2 163 | 164 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 165 | ht, wd = flow.shape[:2] 166 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 167 | coords = np.stack(coords, axis=-1) 168 | 169 | coords = coords.reshape(-1, 2).astype(np.float32) 170 | flow = flow.reshape(-1, 2).astype(np.float32) 171 | valid = valid.reshape(-1).astype(np.float32) 172 | 173 | coords0 = coords[valid>=1] 174 | flow0 = flow[valid>=1] 175 | 176 | ht1 = int(round(ht * fy)) 177 | wd1 = int(round(wd * fx)) 178 | 179 | coords1 = coords0 * [fx, fy] 180 | flow1 = flow0 * [fx, fy] 181 | 182 | xx = np.round(coords1[:,0]).astype(np.int32) 183 | yy = np.round(coords1[:,1]).astype(np.int32) 184 | 185 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 186 | xx = xx[v] 187 | yy = yy[v] 188 | flow1 = flow1[v] 189 | 190 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 191 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 192 | 193 | flow_img[yy, xx] = flow1 194 | valid_img[yy, xx] = 1 195 | 196 | return flow_img, valid_img 197 | 198 | def spatial_transform(self, img1, img2, flow, valid): 199 | # randomly sample scale 200 | 201 | ht, wd = img1.shape[:2] 202 | min_scale = np.maximum( 203 | (self.crop_size[0] + 1) / float(ht), 204 | (self.crop_size[1] + 1) / float(wd)) 205 | 206 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 207 | scale_x = np.clip(scale, min_scale, None) 208 | scale_y = np.clip(scale, min_scale, None) 209 | 210 | if np.random.rand() < self.spatial_aug_prob: 211 | # rescale the images 212 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 213 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 214 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 215 | 216 | if self.do_flip: 217 | if np.random.rand() < 0.5: # h-flip 218 | img1 = img1[:, ::-1] 219 | img2 = img2[:, ::-1] 220 | flow = flow[:, ::-1] * [-1.0, 1.0] 221 | valid = valid[:, ::-1] 222 | 223 | margin_y = 20 224 | margin_x = 50 225 | 226 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 227 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 228 | 229 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 230 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 231 | 232 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 233 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 234 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 235 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 236 | return img1, img2, flow, valid 237 | 238 | 239 | def __call__(self, img1, img2, flow, valid): 240 | img1, img2 = self.color_transform(img1, img2) 241 | img1, img2 = 
self.eraser_transform(img1, img2) 242 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 243 | 244 | img1 = np.ascontiguousarray(img1) 245 | img2 = np.ascontiguousarray(img2) 246 | flow = np.ascontiguousarray(flow) 247 | valid = np.ascontiguousarray(valid) 248 | 249 | return img1, img2, flow, valid 250 | -------------------------------------------------------------------------------- /code.v.1.0/core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
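        Example (illustrative; u and v are expected to be pre-normalized to roughly
        [-1, 1], as done by flow_to_image below):

            u = np.random.uniform(-1, 1, (64, 64)).astype(np.float32)
            v = np.random.uniform(-1, 1, (64, 64)).astype(np.float32)
            rgb = flow_uv_to_colors(u, v)  # uint8 array of shape [64, 64, 3]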
81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /code.v.1.0/core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | # This code was originally taken from RAFT without modification 2 | # https://github.com/princeton-vl/RAFT/blob/master/utils/frame_utils.py 3 | 4 | import numpy as np 5 | from PIL import Image 6 | from os.path import * 7 | import re 8 | 9 | import cv2 10 | cv2.setNumThreads(0) 11 | cv2.ocl.setUseOpenCL(False) 12 | 13 | TAG_CHAR = np.array([202021.25], np.float32) 14 | 15 | def readFlow(fn): 16 | """ Read .flo file in Middlebury format""" 17 | # Code adapted from: 18 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 19 | 20 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 21 | # print 'fn = %s'%(fn) 22 | with open(fn, 'rb') as f: 23 | magic = np.fromfile(f, np.float32, count=1) 24 | if 202021.25 != magic: 25 | print('Magic number incorrect. 
Invalid .flo file') 26 | return None 27 | else: 28 | w = np.fromfile(f, np.int32, count=1) 29 | h = np.fromfile(f, np.int32, count=1) 30 | # print 'Reading %d x %d flo file\n' % (w, h) 31 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 32 | # Reshape data into 3D array (columns, rows, bands) 33 | # The reshape here is for visualization, the original code is (w,h,2) 34 | return np.resize(data, (int(h), int(w), 2)) 35 | 36 | def readPFM(file): 37 | file = open(file, 'rb') 38 | 39 | color = None 40 | width = None 41 | height = None 42 | scale = None 43 | endian = None 44 | 45 | header = file.readline().rstrip() 46 | if header == b'PF': 47 | color = True 48 | elif header == b'Pf': 49 | color = False 50 | else: 51 | raise Exception('Not a PFM file.') 52 | 53 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 54 | if dim_match: 55 | width, height = map(int, dim_match.groups()) 56 | else: 57 | raise Exception('Malformed PFM header.') 58 | 59 | scale = float(file.readline().rstrip()) 60 | if scale < 0: # little-endian 61 | endian = '<' 62 | scale = -scale 63 | else: 64 | endian = '>' # big-endian 65 | 66 | data = np.fromfile(file, endian + 'f') 67 | shape = (height, width, 3) if color else (height, width) 68 | 69 | data = np.reshape(data, shape) 70 | data = np.flipud(data) 71 | return data 72 | 73 | def writeFlow(filename,uv,v=None): 74 | """ Write optical flow to file. 75 | 76 | If v is None, uv is assumed to contain both u and v channels, 77 | stacked in depth. 78 | Original code by Deqing Sun, adapted from Daniel Scharstein. 79 | """ 80 | nBands = 2 81 | 82 | if v is None: 83 | assert(uv.ndim == 3) 84 | assert(uv.shape[2] == 2) 85 | u = uv[:,:,0] 86 | v = uv[:,:,1] 87 | else: 88 | u = uv 89 | 90 | assert(u.shape == v.shape) 91 | height,width = u.shape 92 | f = open(filename,'wb') 93 | # write the header 94 | f.write(TAG_CHAR) 95 | np.array(width).astype(np.int32).tofile(f) 96 | np.array(height).astype(np.int32).tofile(f) 97 | # arrange into matrix form 98 | tmp = np.zeros((height, width*nBands)) 99 | tmp[:,np.arange(width)*2] = u 100 | tmp[:,np.arange(width)*2 + 1] = v 101 | tmp.astype(np.float32).tofile(f) 102 | f.close() 103 | 104 | 105 | def readFlowKITTI(filename): 106 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 107 | flow = flow[:,:,::-1].astype(np.float32) 108 | flow, valid = flow[:, :, :2], flow[:, :, 2] 109 | flow = (flow - 2**15) / 64.0 110 | return flow, valid 111 | 112 | def readDispKITTI(filename): 113 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 114 | valid = disp > 0.0 115 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 116 | return flow, valid 117 | 118 | 119 | def writeFlowKITTI(filename, uv): 120 | uv = 64.0 * uv + 2**15 121 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 122 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 123 | cv2.imwrite(filename, uv[..., ::-1]) 124 | 125 | 126 | def read_gen(file_name, pil=False): 127 | ext = splitext(file_name)[-1] 128 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 129 | return Image.open(file_name) 130 | elif ext == '.bin' or ext == '.raw': 131 | return np.load(file_name) 132 | elif ext == '.flo': 133 | return readFlow(file_name).astype(np.float32) 134 | elif ext == '.pfm': 135 | flow = readPFM(file_name).astype(np.float32) 136 | if len(flow.shape) == 2: 137 | return flow 138 | else: 139 | return flow[:, :, :-1] 140 | return [] 141 | -------------------------------------------------------------------------------- 
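A quick round-trip sketch for the Middlebury .flo helpers above (illustrative only; the
output path is hypothetical, and `core/` is assumed to be on the Python path, as in viz.py):

    import numpy as np
    from utils import frame_utils

    flow = np.random.randn(128, 256, 2).astype(np.float32)   # [H, W, 2] flow field
    frame_utils.writeFlow('/tmp/example.flo', flow)
    flow_back = frame_utils.readFlow('/tmp/example.flo')
    assert flow_back.shape == (128, 256, 2)
    assert np.allclose(flow, flow_back)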
/code.v.1.0/core/utils/grid_sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def grid_sample(image, optical): 6 | N, C, IH, IW = image.shape 7 | _, H, W, _ = optical.shape 8 | 9 | ix = optical[..., 0] 10 | iy = optical[..., 1] 11 | 12 | ix = ((ix + 1) / 2) * (IW-1) 13 | iy = ((iy + 1) / 2) * (IH-1) 14 | with torch.no_grad(): 15 | ix_nw = torch.floor(ix) 16 | iy_nw = torch.floor(iy) 17 | ix_ne = ix_nw + 1 18 | iy_ne = iy_nw 19 | ix_sw = ix_nw 20 | iy_sw = iy_nw + 1 21 | ix_se = ix_nw + 1 22 | iy_se = iy_nw + 1 23 | 24 | nw = (ix_se - ix) * (iy_se - iy) 25 | ne = (ix - ix_sw) * (iy_sw - iy) 26 | sw = (ix_ne - ix) * (iy - iy_ne) 27 | se = (ix - ix_nw) * (iy - iy_nw) 28 | 29 | with torch.no_grad(): 30 | torch.clamp(ix_nw, 0, IW-1, out=ix_nw) 31 | torch.clamp(iy_nw, 0, IH-1, out=iy_nw) 32 | 33 | torch.clamp(ix_ne, 0, IW-1, out=ix_ne) 34 | torch.clamp(iy_ne, 0, IH-1, out=iy_ne) 35 | 36 | torch.clamp(ix_sw, 0, IW-1, out=ix_sw) 37 | torch.clamp(iy_sw, 0, IH-1, out=iy_sw) 38 | 39 | torch.clamp(ix_se, 0, IW-1, out=ix_se) 40 | torch.clamp(iy_se, 0, IH-1, out=iy_se) 41 | 42 | image = image.view(N, C, IH * IW) 43 | 44 | 45 | nw_val = torch.gather(image, 2, (iy_nw * IW + ix_nw).long().view(N, 1, H * W).repeat(1, C, 1)) 46 | ne_val = torch.gather(image, 2, (iy_ne * IW + ix_ne).long().view(N, 1, H * W).repeat(1, C, 1)) 47 | sw_val = torch.gather(image, 2, (iy_sw * IW + ix_sw).long().view(N, 1, H * W).repeat(1, C, 1)) 48 | se_val = torch.gather(image, 2, (iy_se * IW + ix_se).long().view(N, 1, H * W).repeat(1, C, 1)) 49 | 50 | out_val = (nw_val.view(N, C, H, W) * nw.view(N, 1, H, W) + 51 | ne_val.view(N, C, H, W) * ne.view(N, 1, H, W) + 52 | sw_val.view(N, C, H, W) * sw.view(N, 1, H, W) + 53 | se_val.view(N, C, H, W) * se.view(N, 1, H, W)) 54 | 55 | return out_val -------------------------------------------------------------------------------- /code.v.1.0/core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | from .grid_sample import grid_sample 7 | 8 | 9 | class InputPadder: 10 | """ Pads images such that dimensions are divisible by 8 """ 11 | def __init__(self, dims, mode='sintel'): 12 | self.ht, self.wd = dims[-2:] 13 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 14 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 15 | if mode == 'sintel': 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 17 | else: 18 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 19 | 20 | def pad(self, *inputs): 21 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 22 | 23 | def unpad(self,x): 24 | ht, wd = x.shape[-2:] 25 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 26 | return x[..., c[0]:c[1], c[2]:c[3]] 27 | 28 | 29 | def forward_interpolate(flow): 30 | flow = flow.detach().cpu().numpy() 31 | dx, dy = flow[0], flow[1] 32 | 33 | ht, wd = dx.shape 34 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 35 | 36 | x1 = x0 + dx 37 | y1 = y0 + dy 38 | 39 | x1 = x1.reshape(-1) 40 | y1 = y1.reshape(-1) 41 | dx = dx.reshape(-1) 42 | dy = dy.reshape(-1) 43 | 44 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 45 | x1 = x1[valid] 46 | y1 = y1[valid] 47 | dx = dx[valid] 48 | dy = dy[valid] 49 | 50 | flow_x = interpolate.griddata( 51 | (x1, y1), dx, (x0, y0), method='nearest', 
fill_value=0) 52 | 53 | flow_y = interpolate.griddata( 54 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 55 | 56 | flow = np.stack([flow_x, flow_y], axis=0) 57 | return torch.from_numpy(flow).float() 58 | 59 | 60 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 61 | """ Wrapper for grid_sample, uses pixel coordinates """ 62 | H, W = img.shape[-2:] 63 | xgrid, ygrid = coords.split([1,1], dim=-1) 64 | xgrid = 2*xgrid/(W-1) - 1 65 | ygrid = 2*ygrid/(H-1) - 1 66 | 67 | grid = torch.cat([xgrid, ygrid], dim=-1) 68 | img = F.grid_sample(img, grid, align_corners=True) 69 | 70 | # Enable higher order grad for JR 71 | # img = grid_sample(img, grid) 72 | 73 | if mask: 74 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 75 | return img, mask.float() 76 | 77 | return img 78 | 79 | 80 | def coords_grid(batch, ht, wd, device): 81 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 82 | coords = torch.stack(coords[::-1], dim=0).float() 83 | return coords[None].repeat(batch, 1, 1, 1) 84 | 85 | 86 | def upflow8(flow, mode='bilinear'): 87 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 88 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 89 | -------------------------------------------------------------------------------- /code.v.1.0/ref/val.txt: -------------------------------------------------------------------------------- 1 | Applying weight normalization to BasicUpdateBlock 2 | Parameter Count: 5.243 M 3 | VALID | FORWARD abs_diff: 3.1601126194000244; rel_diff: 0.0005030641914345324; nstep: 53 4 | Validation (clean) EPE: 1.432517170906067 (1.432517170906067), 1px: 0.906150862584164, 3px: 0.9568943325276342, 5px: 0.9685142592463305 5 | VALID | FORWARD abs_diff: 0.7322859764099121; rel_diff: 0.00010898241453105584; nstep: 59 6 | VALID | FORWARD abs_diff: 55.754356384277344; rel_diff: 0.00757289445027709; nstep: 48 7 | Validation (final) EPE: 2.7917442321777344 (2.7917442321777344), 1px: 0.8529280729345681, 3px: 0.9170532694536889, 5px: 0.9365406933832148 8 | Validation KITTI: EPE: 5.427668504863977 (5.427668504863977), F1: 16.667239367961884 (16.667239367961884) 9 | Applying weight normalization to BasicUpdateBlock 10 | Parameter Count: 12.800 M 11 | VALID | FORWARD abs_diff: 3.439692735671997; rel_diff: 0.0005472839525020399; nstep: 53 12 | Validation (clean) EPE: 1.453827977180481 (1.453827977180481), 1px: 0.9146167279857274, 3px: 0.9612179145570596, 5px: 0.9716016637976287 13 | VALID | FORWARD abs_diff: 2.7573065757751465; rel_diff: 0.00041015823748834397; nstep: 44 14 | VALID | FORWARD abs_diff: 40.787933349609375; rel_diff: 0.005504987121547275; nstep: 54 15 | Validation (final) EPE: 2.578798294067383 (2.578798294067383), 1px: 0.8648370765776335, 3px: 0.9247630216423374, 5px: 0.9420040652278926 16 | Validation KITTI: EPE: 3.972053589001298 (3.972053589001298), F1: 13.408777117729187 (13.408777117729187) 17 | Applying weight normalization to BasicUpdateBlock 18 | Parameter Count: 12.800 M 19 | VALID | FORWARD abs_diff: 3.5625743865966797; rel_diff: 0.0005668736362606357; nstep: 51 20 | Validation (clean) EPE: 1.3733853101730347 (1.3733853101730347), 1px: 0.9143194620474535, 3px: 0.961116638444476, 5px: 0.9717625759844098 21 | VALID | FORWARD abs_diff: 1.6139500141143799; rel_diff: 0.00024008954221575165; nstep: 52 22 | VALID | FORWARD abs_diff: 27.838302612304688; rel_diff: 0.0037760218577109006; nstep: 52 23 | Validation (final) EPE: 2.6194019317626953 (2.6194019317626953), 1px: 
0.8644313481614472, 3px: 0.9246781902573611, 5px: 0.9420075314657803 24 | Validation KITTI: EPE: 3.973951353058219 (3.973951353058219), F1: 13.621531426906586 (13.621531426906586) 25 | Applying weight normalization to BasicUpdateBlock 26 | Parameter Count: 12.800 M 27 | VALID | FORWARD abs_diff: 3.938171148300171; rel_diff: 0.0006266361620816458; nstep: 52 28 | Validation (clean) EPE: 1.3641184568405151 (1.3641184568405151), 1px: 0.9150200509059743, 3px: 0.9611289370265778, 5px: 0.9716694996437628 29 | VALID | FORWARD abs_diff: 2.300001859664917; rel_diff: 0.0003421608202927882; nstep: 53 30 | VALID | FORWARD abs_diff: 32.62757873535156; rel_diff: 0.004405397824122709; nstep: 52 31 | Validation (final) EPE: 2.622418165206909 (2.622418165206909), 1px: 0.8643445910887555, 3px: 0.9246018275951196, 5px: 0.9417596599553072 32 | Validation KITTI: EPE: 4.018190502673388 (4.018190502673388), F1: 13.917067646980286 (13.917067646980286) 33 | -------------------------------------------------------------------------------- /code.v.1.0/train_B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -u main.py --name deq-flow-B-chairs --stage chairs --validation chairs \ 4 | --gpus 0 1 --num_steps 120000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 \ 5 | --wnorm --f_solver anderson --f_thres 36 \ 6 | --n_losses 6 --phantom_grad 1 7 | 8 | python -u main.py --name deq-flow-B-things --stage things \ 9 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-B-chairs.pth \ 10 | --gpus 0 1 --num_steps 120000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 \ 11 | --wnorm --f_solver anderson --f_thres 40 \ 12 | --n_losses 2 --phantom_grad 3 13 | 14 | -------------------------------------------------------------------------------- /code.v.1.0/train_B_demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -u main.py --name deq-flow-B-1-step-grad-things --stage things \ 4 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-B-chairs.pth \ 5 | --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 \ 6 | --wnorm --f_thres 40 --f_solver anderson \ 7 | -------------------------------------------------------------------------------- /code.v.1.0/train_H_demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -u main.py --name deq-flow-H-1-step-grad-ad-things --stage things \ 4 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-H-chairs.pth \ 5 | --gpus 0 1 --num_steps 120000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 \ 6 | --wnorm --huge --f_solver anderson \ 7 | --f_thres 60 --phantom_grad 1 8 | -------------------------------------------------------------------------------- /code.v.1.0/train_H_full.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | python -u main.py --name deq-flow-H-chairs --stage chairs --validation chairs \ 5 | --gpus 0 1 2 --num_steps 120000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 \ 6 | --wnorm --huge --f_solver broyden \ 7 | --f_thres 36 --n_losses 6 --phantom_grad 1 --sliced_core 8 | 9 | python -u main.py --name deq-flow-H-things --stage things \ 10 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-H-chairs.pth \ 11 | --gpus 0 1 2 --num_steps 120000 --batch_size 6 --lr 0.000125 
--image_size 400 720 --wdecay 0.0001 \ 12 | --wnorm --huge --f_solver broyden \ 13 | --f_thres 36 --n_losses 6 --phantom_grad 3 --sliced_core 14 | 15 | python -u main.py --name deq-flow-H-sintel --stage sintel \ 16 | --validation sintel --restore_ckpt checkpoints/deq-flow-H-things.pth \ 17 | --gpus 0 1 2 --num_steps 120000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.0001 --gamma=0.90 \ 18 | --wnorm --huge --f_solver broyden \ 19 | --f_thres 36 --n_losses 6 --phantom_grad 3 --sliced_core 20 | 21 | python -u main.py --name deq-flow-H-kitti --stage kitti \ 22 | --validation kitti --restore_ckpt checkpoints/deq-flow-H-sintel.pth \ 23 | --gpus 0 1 2 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.0001 --gamma=0.90 \ 24 | --wnorm --huge --f_solver broyden \ 25 | --f_thres 36 --n_losses 6 --phantom_grad 1 --sliced_core 26 | -------------------------------------------------------------------------------- /code.v.1.0/val.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -u main.py --eval --name deq-flow-B-things --stage things \ 4 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-B-things-test.pth --gpus 0 \ 5 | --wnorm --f_thres 40 --f_solver anderson 6 | 7 | python -u main.py --eval --name deq-flow-H-things --stage things \ 8 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-H-things-test-1.pth --gpus 0 \ 9 | --wnorm --f_thres 36 --f_solver broyden --huge 10 | 11 | python -u main.py --eval --name deq-flow-H-things --stage things \ 12 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-H-things-test-2.pth --gpus 0 \ 13 | --wnorm --f_thres 36 --f_solver broyden --huge 14 | 15 | python -u main.py --eval --name deq-flow-H-things --stage things \ 16 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-H-things-test-3.pth --gpus 0 \ 17 | --wnorm --f_thres 36 --f_solver broyden --huge 18 | -------------------------------------------------------------------------------- /code.v.1.0/viz.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('core') 4 | 5 | import argparse 6 | import copy 7 | import os 8 | import time 9 | 10 | import datasets 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | import cv2 17 | from PIL import Image 18 | 19 | from utils import flow_viz, frame_utils 20 | from utils.utils import InputPadder, forward_interpolate 21 | 22 | 23 | @torch.no_grad() 24 | def sintel_visualization(model, split='test', warm_start=False, fixed_point_reuse=False, output_path='sintel_viz', **kwargs): 25 | """ Create visualization for the Sintel dataset """ 26 | model.eval() 27 | for dstype in ['clean', 'final']: 28 | split = 'test' if split == 'test' else 'training' 29 | test_dataset = datasets.MpiSintel(split=split, aug_params=None, dstype=dstype) 30 | 31 | flow_prev, sequence_prev, fixed_point = None, None, None 32 | for test_id in range(len(test_dataset)): 33 | image1, image2, (sequence, frame) = test_dataset[test_id] 34 | if sequence != sequence_prev: 35 | flow_prev = None 36 | fixed_point = None 37 | 38 | padder = InputPadder(image1.shape) 39 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 40 | 41 | flow_low, flow_pr, info = model(image1, image2, flow_init=flow_prev, cached_result=fixed_point, **kwargs) 42 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 43 | 44 
| if warm_start: 45 | flow_prev = forward_interpolate(flow_low[0])[None].cuda() 46 | 47 | if fixed_point_reuse: 48 | net, flow_pred_low = info['cached_result'] 49 | flow_pred_low = forward_interpolate(flow_pred_low[0])[None].cuda() 50 | fixed_point = (net, flow_pred_low) 51 | 52 | output_dir = os.path.join(output_path, dstype, sequence) 53 | output_file = os.path.join(output_dir, 'frame%04d.png' % (frame+1)) 54 | 55 | if not os.path.exists(output_dir): 56 | os.makedirs(output_dir) 57 | 58 | # visualizaion 59 | img_flow = flow_viz.flow_to_image(flow) 60 | img_flow = cv2.cvtColor(img_flow, cv2.COLOR_RGB2BGR) 61 | cv2.imwrite(output_file, img_flow, [int(cv2.IMWRITE_PNG_COMPRESSION), 1]) 62 | 63 | sequence_prev = sequence 64 | 65 | 66 | @torch.no_grad() 67 | def kitti_visualization(model, split='test', output_path='kitti_viz'): 68 | """ Create visualization for the KITTI dataset """ 69 | model.eval() 70 | split = 'testing' if split == 'test' else 'training' 71 | test_dataset = datasets.KITTI(split='testing', aug_params=None) 72 | 73 | if not os.path.exists(output_path): 74 | os.makedirs(output_path) 75 | 76 | for test_id in range(len(test_dataset)): 77 | image1, image2, (frame_id, ) = test_dataset[test_id] 78 | padder = InputPadder(image1.shape, mode='kitti') 79 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 80 | 81 | _, flow_pr, _ = model(image1, image2) 82 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 83 | 84 | output_filename = os.path.join(output_path, frame_id) 85 | 86 | # visualizaion 87 | img_flow = flow_viz.flow_to_image(flow) 88 | img_flow = cv2.cvtColor(img_flow, cv2.COLOR_RGB2BGR) 89 | cv2.imwrite(output_filename, img_flow, [int(cv2.IMWRITE_PNG_COMPRESSION), 1]) 90 | 91 | 92 | -------------------------------------------------------------------------------- /code.v.1.0/viz.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -u main.py --viz --name deq-flow-H-kitti --stage kitti \ 4 | --viz_set kitti --restore_ckpt checkpoints/deq-flow-H-kitti.pth --gpus 0 \ 5 | --wnorm --f_thres 36 --f_solver broyden --huge 6 | -------------------------------------------------------------------------------- /code.v.2.0/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/deq-flow/869d4bcd9c1d227cbfa16639994d1b118510a9bb/code.v.2.0/core/__init__.py -------------------------------------------------------------------------------- /code.v.2.0/core/corr.py: -------------------------------------------------------------------------------- 1 | # This code was originally taken from RAFT without modification 2 | # https://github.com/princeton-vl/RAFT/blob/master/core/corr.py 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from utils.utils import bilinear_sampler, coords_grid 7 | 8 | 9 | class CorrBlock: 10 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 11 | self.num_levels = num_levels 12 | self.radius = radius 13 | self.corr_pyramid = [] 14 | 15 | # all pairs correlation 16 | corr = CorrBlock.corr(fmap1, fmap2) 17 | 18 | batch, h1, w1, dim, h2, w2 = corr.shape 19 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 20 | 21 | self.corr_pyramid.append(corr) 22 | for i in range(self.num_levels-1): 23 | corr = F.avg_pool2d(corr, 2, stride=2) 24 | self.corr_pyramid.append(corr) 25 | 26 | def __call__(self, coords): 27 | r = self.radius 28 | coords = coords.permute(0, 2, 3, 1) 29 | batch, h1, w1, _ = 
coords.shape 30 | 31 | out_pyramid = [] 32 | for i in range(self.num_levels): 33 | corr = self.corr_pyramid[i] 34 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 35 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 36 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 37 | 38 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 39 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 40 | coords_lvl = centroid_lvl + delta_lvl 41 | 42 | corr = bilinear_sampler(corr, coords_lvl) 43 | corr = corr.view(batch, h1, w1, -1) 44 | out_pyramid.append(corr) 45 | 46 | out = torch.cat(out_pyramid, dim=-1) 47 | return out.permute(0, 3, 1, 2).contiguous().float() 48 | 49 | @staticmethod 50 | def corr(fmap1, fmap2): 51 | batch, dim, ht, wd = fmap1.shape 52 | fmap1 = fmap1.view(batch, dim, ht*wd) 53 | fmap2 = fmap2.view(batch, dim, ht*wd) 54 | 55 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 56 | corr = corr.view(batch, ht, wd, 1, ht, wd) 57 | return corr / torch.sqrt(torch.tensor(dim).float()) 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /code.v.2.0/core/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | 8 | import os 9 | import math 10 | import random 11 | from glob import glob 12 | import os.path as osp 13 | 14 | from utils import frame_utils 15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor 16 | 17 | 18 | class FlowDataset(data.Dataset): 19 | def __init__(self, aug_params=None, sparse=False): 20 | self.augmentor = None 21 | self.sparse = sparse 22 | if aug_params is not None: 23 | if sparse: 24 | self.augmentor = SparseFlowAugmentor(**aug_params) 25 | else: 26 | self.augmentor = FlowAugmentor(**aug_params) 27 | 28 | self.is_test = False 29 | self.init_seed = False 30 | self.flow_list = [] 31 | self.image_list = [] 32 | self.extra_info = [] 33 | 34 | def __getitem__(self, index): 35 | 36 | if self.is_test: 37 | img1 = frame_utils.read_gen(self.image_list[index][0]) 38 | img2 = frame_utils.read_gen(self.image_list[index][1]) 39 | img1 = np.array(img1).astype(np.uint8)[..., :3] 40 | img2 = np.array(img2).astype(np.uint8)[..., :3] 41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 43 | return img1, img2, self.extra_info[index] 44 | 45 | if not self.init_seed: 46 | worker_info = torch.utils.data.get_worker_info() 47 | if worker_info is not None: 48 | torch.manual_seed(worker_info.id) 49 | np.random.seed(worker_info.id) 50 | random.seed(worker_info.id) 51 | self.init_seed = True 52 | 53 | index = index % len(self.image_list) 54 | valid = None 55 | if self.sparse: 56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 57 | else: 58 | flow = frame_utils.read_gen(self.flow_list[index]) 59 | 60 | img1 = frame_utils.read_gen(self.image_list[index][0]) 61 | img2 = frame_utils.read_gen(self.image_list[index][1]) 62 | 63 | flow = np.array(flow).astype(np.float32) 64 | img1 = np.array(img1).astype(np.uint8) 65 | img2 = np.array(img2).astype(np.uint8) 66 | 67 | # grayscale images 68 | if len(img1.shape) == 2: 69 | img1 = np.tile(img1[...,None], (1, 1, 3)) 70 | img2 = np.tile(img2[...,None], (1, 1, 3)) 71 | else: 72 | img1 = img1[..., :3] 73 | img2 = img2[..., :3] 74 | 75 | if self.augmentor is not None: 76 | 
if self.sparse: 77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 78 | else: 79 | img1, img2, flow = self.augmentor(img1, img2, flow) 80 | 81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 84 | 85 | if valid is not None: 86 | valid = torch.from_numpy(valid) 87 | else: 88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 89 | 90 | return img1, img2, flow, valid.float() 91 | 92 | 93 | def __rmul__(self, v): 94 | self.flow_list = v * self.flow_list 95 | self.image_list = v * self.image_list 96 | return self 97 | 98 | def __len__(self): 99 | return len(self.image_list) 100 | 101 | 102 | class MpiSintel(FlowDataset): 103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): 104 | super(MpiSintel, self).__init__(aug_params) 105 | flow_root = osp.join(root, split, 'flow') 106 | image_root = osp.join(root, split, dstype) 107 | 108 | if split == 'test': 109 | self.is_test = True 110 | 111 | for scene in os.listdir(image_root): 112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 113 | for i in range(len(image_list)-1): 114 | self.image_list += [ [image_list[i], image_list[i+1]] ] 115 | self.extra_info += [ (scene, i) ] # scene and frame_id 116 | 117 | if split != 'test': 118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 119 | 120 | 121 | class FlyingChairs(FlowDataset): 122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): 123 | super(FlyingChairs, self).__init__(aug_params) 124 | 125 | images = sorted(glob(osp.join(root, '*.ppm'))) 126 | flows = sorted(glob(osp.join(root, '*.flo'))) 127 | assert (len(images)//2 == len(flows)) 128 | 129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) 130 | for i in range(len(flows)): 131 | xid = split_list[i] 132 | if (split=='training' and xid==1) or (split=='validation' and xid==2): 133 | self.flow_list += [ flows[i] ] 134 | self.image_list += [ [images[2*i], images[2*i+1]] ] 135 | 136 | 137 | class FlyingThings3D(FlowDataset): 138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', split='train', dstype='frames_cleanpass'): 139 | super(FlyingThings3D, self).__init__(aug_params) 140 | 141 | if split == 'train': 142 | dir_prefix = 'TRAIN' 143 | elif split == 'test': 144 | dir_prefix = 'TEST' 145 | else: 146 | raise ValueError('Unknown split for FlyingThings3D.') 147 | 148 | for cam in ['left']: 149 | for direction in ['into_future', 'into_past']: 150 | image_dirs = sorted(glob(osp.join(root, dstype, f'{dir_prefix}/*/*'))) 151 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 152 | 153 | flow_dirs = sorted(glob(osp.join(root, f'optical_flow/{dir_prefix}/*/*'))) 154 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 155 | 156 | for idir, fdir in zip(image_dirs, flow_dirs): 157 | images = sorted(glob(osp.join(idir, '*.png')) ) 158 | flows = sorted(glob(osp.join(fdir, '*.pfm')) ) 159 | for i in range(len(flows)-1): 160 | if direction == 'into_future': 161 | self.image_list += [ [images[i], images[i+1]] ] 162 | self.flow_list += [ flows[i] ] 163 | elif direction == 'into_past': 164 | self.image_list += [ [images[i+1], images[i]] ] 165 | self.flow_list += [ flows[i+1] ] 166 | 167 | 168 | class KITTI(FlowDataset): 169 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): 170 | super(KITTI, 
self).__init__(aug_params, sparse=True) 171 | if split == 'testing': 172 | self.is_test = True 173 | 174 | root = osp.join(root, split) 175 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) 176 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 177 | 178 | for img1, img2 in zip(images1, images2): 179 | frame_id = img1.split('/')[-1] 180 | self.extra_info += [ [frame_id] ] 181 | self.image_list += [ [img1, img2] ] 182 | 183 | if split == 'training': 184 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 185 | 186 | 187 | class HD1K(FlowDataset): 188 | def __init__(self, aug_params=None, root='datasets/HD1k'): 189 | super(HD1K, self).__init__(aug_params, sparse=True) 190 | 191 | seq_ix = 0 192 | while 1: 193 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 194 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 195 | 196 | if len(flows) == 0: 197 | break 198 | 199 | for i in range(len(flows)-1): 200 | self.flow_list += [flows[i]] 201 | self.image_list += [ [images[i], images[i+1]] ] 202 | 203 | seq_ix += 1 204 | 205 | 206 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): 207 | """ Create the data loader for the corresponding trainign set """ 208 | 209 | if args.stage == 'chairs': 210 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} 211 | train_dataset = FlyingChairs(aug_params, split='training') 212 | 213 | elif args.stage == 'things': 214 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} 215 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') 216 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') 217 | train_dataset = clean_dataset + final_dataset 218 | 219 | elif args.stage == 'sintel': 220 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} 221 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass') 222 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') 223 | sintel_final = MpiSintel(aug_params, split='training', dstype='final') 224 | 225 | if TRAIN_DS == 'C+T+K+S+H': 226 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) 227 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) 228 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things 229 | 230 | elif TRAIN_DS == 'C+T+K/S': 231 | train_dataset = 100*sintel_clean + 100*sintel_final + things 232 | 233 | elif args.stage == 'kitti': 234 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} 235 | train_dataset = KITTI(aug_params, split='training') 236 | 237 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 238 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True) 239 | 240 | print('Training with %d image pairs' % len(train_dataset)) 241 | return train_loader 242 | 243 | -------------------------------------------------------------------------------- /code.v.2.0/core/deq/__init__.py: -------------------------------------------------------------------------------- 1 | from .deq_class import get_deq 2 | -------------------------------------------------------------------------------- /code.v.2.0/core/deq/arg_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 
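# Command-line options shared by the DEQ solver components. The entry point
# (e.g. main.py) is expected to build an argparse.ArgumentParser and pass it to
# add_deq_args() below; the resulting flags (f_solver, f_thres, phantom_grad,
# ift, ...) are then consumed by deq_class.py and grad.py.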
3 | 4 | def add_deq_args(parser): 5 | parser.add_argument('--wnorm', action='store_true', help="use weight normalization") 6 | parser.add_argument('--f_solver', default='anderson', type=str, choices=['anderson', 'broyden', 'naive_solver'], 7 | help='forward solver to use (only anderson and broyden supported now)') 8 | parser.add_argument('--b_solver', default='broyden', type=str, choices=['anderson', 'broyden', 'naive_solver'], 9 | help='backward solver to use') 10 | parser.add_argument('--f_thres', type=int, default=40, help='forward pass solver threshold') 11 | parser.add_argument('--b_thres', type=int, default=40, help='backward pass solver threshold') 12 | parser.add_argument('--f_eps', type=float, default=1e-3, help='forward pass solver stopping criterion') 13 | parser.add_argument('--b_eps', type=float, default=1e-3, help='backward pass solver stopping criterion') 14 | parser.add_argument('--f_stop_mode', type=str, default="abs", help="forward pass fixed-point convergence stop mode") 15 | parser.add_argument('--b_stop_mode', type=str, default="abs", help="backward pass fixed-point convergence stop mode") 16 | parser.add_argument('--eval_factor', type=float, default=1.5, help="factor to scale up the f_thres at test for better convergence.") 17 | parser.add_argument('--eval_f_thres', type=int, default=0, help="directly set the f_thres at test.") 18 | 19 | parser.add_argument('--indexing_core', action='store_true', help="use the indexing core implementation.") 20 | parser.add_argument('--ift', action='store_true', help="use implicit differentiation.") 21 | parser.add_argument('--safe_ift', action='store_true', help="use a safer function for IFT to avoid potential segment fault in older pytorch versions.") 22 | parser.add_argument('--n_losses', type=int, default=1, help="number of loss terms (uniform spaced, 1 + fixed point correction).") 23 | parser.add_argument('--indexing', type=int, nargs='+', default=[], help="indexing for fixed point correction.") 24 | parser.add_argument('--phantom_grad', type=int, nargs='+', default=[1], help="steps of Phantom Grad") 25 | parser.add_argument('--tau', type=float, default=1.0, help="damping factor for unrolled Phantom Grad") 26 | parser.add_argument('--sup_all', action='store_true', help="supervise all the trajectories by Phantom Grad.") 27 | 28 | parser.add_argument('--sradius_mode', action='store_true', help="monitor the spectral radius during validation") 29 | 30 | 31 | -------------------------------------------------------------------------------- /code.v.2.0/core/deq/deq_class.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from termcolor import colored 7 | 8 | from .solvers import get_solver 9 | from .norm import reset_weight_norm 10 | from .grad import make_pair, backward_factory 11 | from .jacobian import power_method 12 | 13 | 14 | class DEQBase(nn.Module): 15 | def __init__(self, args): 16 | super(DEQBase, self).__init__() 17 | 18 | self.args = args 19 | self.f_solver = get_solver(args.f_solver) 20 | self.b_solver = get_solver(args.b_solver) 21 | 22 | self.f_thres = args.f_thres 23 | self.b_thres = args.b_thres 24 | 25 | self.f_eps = args.f_eps 26 | self.b_eps = args.b_eps 27 | 28 | self.f_stop_mode = args.f_stop_mode 29 | self.b_stop_mode = args.b_stop_mode 30 | 31 | self.eval_f_thres = args.eval_f_thres if args.eval_f_thres > 0 else int(self.f_thres * args.eval_factor) 32 | 33 | self.hook = 
None 34 | 35 | def _log_convergence(self, info, name='FORWARD', color='yellow'): 36 | state = 'TRAIN' if self.training else 'VALID' 37 | alt_mode = 'rel' if self.f_stop_mode == 'abs' else 'abs' 38 | 39 | rel_lowest, abs_lowest = info['rel_lowest'].mean().item(), info['abs_lowest'].mean().item() 40 | nstep = info['nstep'] 41 | 42 | show_str = f'{state} | {name} | rel: {rel_lowest}; abs: {abs_lowest}; nstep: {nstep}' 43 | print(colored(show_str, color)) 44 | 45 | def _sradius(self, deq_func, z_star): 46 | with torch.enable_grad(): 47 | new_z_star = deq_func(z_star.requires_grad_()) 48 | _, sradius = power_method(new_z_star, z_star, n_iters=75) 49 | 50 | return sradius 51 | 52 | def _solve_fixed_point( 53 | self, deq_func, z_init, 54 | log=False, f_thres=None, 55 | **kwargs 56 | ): 57 | raise NotImplementedError 58 | 59 | def forward( 60 | self, deq_func, z_init, 61 | log=False, sradius_mode=False, writer=None, 62 | **kwargs 63 | ): 64 | raise NotImplementedError 65 | 66 | 67 | class DEQIndexing(DEQBase): 68 | def __init__(self, args): 69 | super(DEQIndexing, self).__init__(args) 70 | 71 | # Define gradient functions through the backward factory 72 | if args.n_losses > 1: 73 | n_losses = min(args.f_thres, args.n_losses) 74 | delta = int(args.f_thres // n_losses) 75 | self.indexing = [(k+1)*delta for k in range(n_losses)] 76 | else: 77 | self.indexing = [*args.indexing, args.f_thres] 78 | 79 | # By default, we use the same phantom grad for all corrections. 80 | # You can also set different grad steps a, b, and c for different terms by ``args.phantom_grad a b c ...''. 81 | indexing_pg = make_pair(self.indexing, args.phantom_grad) 82 | produce_grad = [ 83 | backward_factory(grad_type=pg, tau=args.tau, sup_all=args.sup_all) for pg in indexing_pg 84 | ] 85 | if args.ift: 86 | # Enabling args.ift will replace the last gradient function by IFT. 
87 | produce_grad[-1] = backward_factory( 88 | grad_type='ift', safe_ift=args.safe_ift, b_solver=self.b_solver, 89 | b_solver_kwargs=dict(threshold=args.b_thres, eps=args.b_eps, stop_mode=args.b_stop_mode) 90 | ) 91 | 92 | self.produce_grad = produce_grad 93 | 94 | def _solve_fixed_point( 95 | self, deq_func, z_init, 96 | log=False, f_thres=None, 97 | **kwargs 98 | ): 99 | if f_thres is None: f_thres = self.f_thres 100 | indexing = self.indexing if self.training else None 101 | 102 | with torch.no_grad(): 103 | z_star, trajectory, info = self.f_solver( 104 | deq_func, x0=z_init, threshold=f_thres, # To reuse previous coarse fixed points 105 | eps=self.f_eps, stop_mode=self.f_stop_mode, indexing=indexing 106 | ) 107 | 108 | if log: self._log_convergence(info, name="FORWARD", color="yellow") 109 | 110 | return z_star, trajectory, info 111 | 112 | def forward( 113 | self, deq_func, z_init, 114 | log=False, sradius_mode=False, writer=None, 115 | **kwargs 116 | ): 117 | if self.training: 118 | _, trajectory, info = self._solve_fixed_point(deq_func, z_init, log=log, *kwargs) 119 | 120 | z_out = [] 121 | for z_pred, produce_grad in zip(trajectory, self.produce_grad): 122 | z_out += produce_grad(self, deq_func, z_pred) # See lib/grad.py for the backward pass implementations 123 | 124 | z_out = [deq_func.vec2list(each) for each in z_out] 125 | else: 126 | # During inference, we directly solve for fixed point 127 | z_star, _, info = self._solve_fixed_point(deq_func, z_init, log=log, f_thres=self.eval_f_thres) 128 | 129 | sradius = self._sradius(deq_func, z_star) if sradius_mode else torch.zeros(1, device=z_star.device) 130 | info['sradius'] = sradius 131 | 132 | z_out = [deq_func.vec2list(z_star)] 133 | 134 | return z_out, info 135 | 136 | 137 | class DEQSliced(DEQBase): 138 | def __init__(self, args): 139 | super(DEQSliced, self).__init__(args) 140 | 141 | # Define gradient functions through the backward factory 142 | if args.n_losses > 1: 143 | self.indexing = [int(args.f_thres // args.n_losses) for _ in range(args.n_losses)] 144 | else: 145 | self.indexing = np.diff([0, *args.indexing, args.f_thres]).tolist() 146 | 147 | # By default, we use the same phantom grad for all corrections. 148 | # You can also set different grad steps a, b, and c for different terms by ``args.phantom_grad a b c ...''. 149 | indexing_pg = make_pair(self.indexing, args.phantom_grad) 150 | produce_grad = [ 151 | backward_factory(grad_type=pg, tau=args.tau, sup_all=args.sup_all) for pg in indexing_pg 152 | ] 153 | if args.ift: 154 | # Enabling args.ift will replace the last gradient function by IFT. 
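            # As in DEQIndexing above, only this final term switches to IFT;
            # any earlier correction terms keep their Phantom Grad functions.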
155 | produce_grad[-1] = backward_factory( 156 | grad_type='ift', safe_ift=args.safe_ift, b_solver=self.b_solver, 157 | b_solver_kwargs=dict(threshold=args.b_thres, eps=args.b_eps, stop_mode=args.b_stop_mode) 158 | ) 159 | 160 | self.produce_grad = produce_grad 161 | 162 | def _solve_fixed_point( 163 | self, deq_func, z_init, 164 | log=False, f_thres=None, 165 | **kwargs 166 | ): 167 | with torch.no_grad(): 168 | z_star, _, info = self.f_solver( 169 | deq_func, x0=z_init, threshold=f_thres, # To reuse previous coarse fixed points 170 | eps=self.f_eps, stop_mode=self.f_stop_mode 171 | ) 172 | 173 | if log: self._log_convergence(info, name="FORWARD", color="yellow") 174 | 175 | return z_star, info 176 | 177 | def forward( 178 | self, deq_func, z_star, 179 | log=False, sradius_mode=False, writer=None, 180 | **kwargs 181 | ): 182 | if self.training: 183 | z_out = [] 184 | for f_thres, produce_grad in zip(self.indexing, self.produce_grad): 185 | z_star, info = self._solve_fixed_point(deq_func, z_star, f_thres=f_thres, log=log) 186 | z_out += produce_grad(self, deq_func, z_star, writer=writer) # See lib/grad.py for implementations 187 | z_star = z_out[-1] # Add the gradient chain to the solver. 188 | 189 | z_out = [deq_func.vec2list(each) for each in z_out] 190 | else: 191 | # During inference, we directly solve for fixed point 192 | z_star, info = self._solve_fixed_point(deq_func, z_star, f_thres=self.eval_f_thres, log=log) 193 | 194 | sradius = self._sradius(deq_func, z_star) if sradius_mode else torch.zeros(1, device=z_star.device) 195 | info['sradius'] = sradius 196 | 197 | z_out = [deq_func.vec2list(z_star)] 198 | 199 | return z_out, info 200 | 201 | 202 | def get_deq(args): 203 | if args.indexing_core: 204 | return DEQIndexing 205 | else: 206 | return DEQSliced 207 | -------------------------------------------------------------------------------- /code.v.2.0/core/deq/dropout.py: -------------------------------------------------------------------------------- 1 | import types 2 | 3 | from torch.nn.parameter import Parameter 4 | import torch.nn as nn 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | 10 | 11 | ############################################################################################################## 12 | # 13 | # Temporal DropConnect in a feed-forward setting 14 | # 15 | ############################################################################################################## 16 | 17 | 18 | class WeightDrop(torch.nn.Module): 19 | def __init__(self, module, weights, dropout=0, temporal=True): 20 | """ 21 | Weight DropConnect, adapted from a recurrent setting by Merity et al. 2017 22 | 23 | :param module: The module whose weights are to be applied dropout on 24 | :param weights: A 2D list identifying the weights to be regularized. Each element of weights should be a 25 | list containing the "path" to the weight kernel. For instance, if we want to regularize 26 | module.layer2.weight3, then this should be ["layer2", "weight3"]. 
27 | :param dropout: The dropout rate (0 means no dropout) 28 | :param temporal: Whether we apply DropConnect only to the temporal parts of the weight (empirically we found 29 | this not very important) 30 | """ 31 | super(WeightDrop, self).__init__() 32 | self.module = module 33 | self.weights = weights 34 | self.dropout = dropout 35 | self.temporal = temporal 36 | if self.dropout > 0.0: 37 | self._setup() 38 | 39 | def _setup(self): 40 | for path in self.weights: 41 | full_name_w = '.'.join(path) 42 | 43 | module = self.module 44 | name_w = path[-1] 45 | for i in range(len(path) - 1): 46 | module = getattr(module, path[i]) 47 | w = getattr(module, name_w) 48 | del module._parameters[name_w] 49 | module.register_parameter(name_w + '_raw', Parameter(w.data)) 50 | 51 | def _setweights(self): 52 | for path in self.weights: 53 | module = self.module 54 | name_w = path[-1] 55 | for i in range(len(path) - 1): 56 | module = getattr(module, path[i]) 57 | raw_w = getattr(module, name_w + '_raw') 58 | 59 | if len(raw_w.size()) > 2 and raw_w.size(2) > 1 and self.temporal: 60 | # Drop the temporal parts of the weight; if 1x1 convolution then drop the whole kernel 61 | w = torch.cat([F.dropout(raw_w[:, :, :-1], p=self.dropout, training=self.training), 62 | raw_w[:, :, -1:]], dim=2) 63 | else: 64 | w = F.dropout(raw_w, p=self.dropout, training=self.training) 65 | 66 | setattr(module, name_w, w) 67 | 68 | def forward(self, *args, **kwargs): 69 | if self.dropout > 0.0: 70 | self._setweights() 71 | return self.module.forward(*args, **kwargs) 72 | 73 | 74 | def matrix_diag(a, dim=2): 75 | """ 76 | a has dimension (N, (L,) C), we want a matrix/batch diag that produces (N, (L,) C, C) from the last dimension of a 77 | """ 78 | if dim == 2: 79 | res = torch.zeros(a.size(0), a.size(1), a.size(1)) 80 | res.as_strided(a.size(), [res.stride(0), res.size(2)+1]).copy_(a) 81 | else: 82 | res = torch.zeros(a.size(0), a.size(1), a.size(2), a.size(2)) 83 | res.as_strided(a.size(), [res.stride(0), res.stride(1), res.size(3)+1]).copy_(a) 84 | return res 85 | 86 | 87 | ############################################################################################################## 88 | # 89 | # Embedding dropout 90 | # 91 | ############################################################################################################## 92 | 93 | 94 | def embedded_dropout(embed, words, dropout=0.1, scale=None): 95 | """ 96 | Apply embedding encoder (whose weight we apply a dropout) 97 | 98 | :param embed: The embedding layer 99 | :param words: The input sequence 100 | :param dropout: The embedding weight dropout rate 101 | :param scale: Scaling factor for the dropped embedding weight 102 | :return: The embedding output 103 | """ 104 | if dropout: 105 | mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as( 106 | embed.weight) / (1 - dropout) 107 | mask = Variable(mask) 108 | masked_embed_weight = mask * embed.weight 109 | else: 110 | masked_embed_weight = embed.weight 111 | 112 | if scale: 113 | masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight 114 | 115 | padding_idx = embed.padding_idx 116 | if padding_idx is None: 117 | padding_idx = -1 118 | 119 | X = F.embedding(words, masked_embed_weight, padding_idx, embed.max_norm, embed.norm_type, 120 | embed.scale_grad_by_freq, embed.sparse) 121 | return X 122 | 123 | 124 | 125 | ############################################################################################################## 126 | # 127 | # 
Variational dropout (for input/output layers, and for hidden layers) 128 | # 129 | ############################################################################################################## 130 | 131 | 132 | class VariationalHidDropout2d(nn.Module): 133 | def __init__(self, dropout=0.0): 134 | super(VariationalHidDropout2d, self).__init__() 135 | self.dropout = dropout 136 | self.mask = None 137 | 138 | def forward(self, x): 139 | if not self.training or self.dropout == 0: 140 | return x 141 | bsz, d, H, W = x.shape 142 | if self.mask is None: 143 | m = torch.zeros(bsz, d, H, W).bernoulli_(1 - self.dropout).to(x) 144 | self.mask = m.requires_grad_(False) / (1 - self.dropout) 145 | return self.mask * x 146 | 147 | 148 | 149 | -------------------------------------------------------------------------------- /code.v.2.0/core/deq/grad.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch import autograd 6 | 7 | from .solvers import anderson, broyden, naive_solver 8 | 9 | 10 | def make_pair(target, source): 11 | if len(target) == len(source): 12 | return source 13 | elif len(source) == 1: 14 | return [source[0] for _ in range(len(target))] 15 | else: 16 | raise ValueError('Unable to align the arg squence!') 17 | 18 | 19 | def backward_factory( 20 | grad_type='ift', 21 | safe_ift=False, 22 | b_solver=anderson, 23 | b_solver_kwargs=dict(), 24 | sup_all=False, 25 | tau=1.0): 26 | """ 27 | [2019-NeurIPS] Deep Equilibrium Models 28 | [2021-NeurIPS] On Training Implicit Models 29 | 30 | This function implements a factory for the backward pass of implicit deep learning, 31 | e.g., DEQ (implicit models), Hamburger (optimization layer), etc. 32 | It now supports IFT, 1-step Grad, and Phantom Grad. 33 | 34 | Args: 35 | grad_type (string, int): 36 | grad_type should be ``ift`` or an int. Default ``ift``. 37 | Set to ``ift`` to enable the implicit differentiation mode. 38 | When passing a number k to this function, it runs UPG with steps k and damping tau. 39 | safe_ift (bool): 40 | Replace the O(1) hook implementeion with a safer one. Default ``False``. 41 | Set to ``True`` to avoid the (potential) segment fault (under previous versions of Pytorch). 42 | b_solver (type): 43 | Solver for the IFT backward pass. Default ``anderson``. 44 | Supported solvers: anderson, broyden. 45 | b_solver_kwargs (dict): 46 | Colllection of backward solver kwargs, e.g., 47 | threshold (int), max steps for the backward solver, 48 | stop_mode (string), criterion for convergence, 49 | etc. 50 | See solver.py to check all the kwargs. 51 | sup_all (bool): 52 | Indicate whether to supervise all the trajectories by Phantom Grad. 53 | Set ``True`` to return all trajectory in Phantom Grad. 54 | tau (float): 55 | Damping factor for Phantom Grad. Default ``1.0``. 56 | 0.5 is recommended for CIFAR-10. 1.0 for DEQ flow. 57 | For DEQ flow, the gating function in GRU naturally produces adaptive tau values. 58 | 59 | Returns: 60 | A gradient functor for implicit deep learning. 61 | Args: 62 | trainer (nn.Module): the module that employs implicit deep learning. 63 | func (type): function that defines the ``f`` in ``z = f(z)``. 64 | z_pred (torch.Tensor): latent state to run the backward pass. 65 | 66 | Returns: 67 | (list(torch.Tensor)): a list of tensors that tracks the gradient info. 
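    Example (a minimal sketch; ``trainer``, ``func``, and ``z0`` stand for any
    module, fixed-point function, and batched latent tensor):
        produce_grad = backward_factory(grad_type=1, tau=1.0)  # 1-step Phantom Grad
        z_out = produce_grad(trainer, func, z0)                # list with one tensor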
68 | 69 | """ 70 | 71 | if grad_type == 'ift': 72 | assert b_solver in [naive_solver, anderson, broyden] 73 | 74 | if safe_ift: 75 | def plain_ift_grad(trainer, func, z_pred, **kwargs): 76 | z_pred = z_pred.requires_grad_() 77 | new_z_pred = func(z_pred) # 1-step grad for df/dtheta 78 | 79 | z_pred_copy = new_z_pred.clone().detach().requires_grad_() 80 | new_z_pred_copy = func(z_pred_copy) 81 | def backward_hook(grad): 82 | grad_star, _, info = b_solver( 83 | lambda y: autograd.grad(new_z_pred_copy, z_pred_copy, y, retain_graph=True)[0] + grad, 84 | torch.zeros_like(grad), **b_solver_kwargs 85 | ) 86 | return grad_star 87 | new_z_pred.register_hook(backward_hook) 88 | 89 | return [new_z_pred] 90 | return plain_ift_grad 91 | else: 92 | def hook_ift_grad(trainer, func, z_pred, **kwargs): 93 | z_pred = z_pred.requires_grad_() 94 | new_z_pred = func(z_pred) # 1-step grad for df/dtheta 95 | 96 | def backward_hook(grad): 97 | if trainer.hook is not None: 98 | trainer.hook.remove() # To avoid infinite loop 99 | grad_star, _, info = b_solver( 100 | lambda y: autograd.grad(new_z_pred, z_pred, y, retain_graph=True)[0] + grad, 101 | torch.zeros_like(grad), **b_solver_kwargs 102 | ) 103 | return grad_star 104 | trainer.hook = new_z_pred.register_hook(backward_hook) 105 | 106 | return [new_z_pred] 107 | return hook_ift_grad 108 | else: 109 | assert type(grad_type) is int and grad_type >= 1 110 | n_phantom_grad = grad_type 111 | 112 | if sup_all: 113 | def sup_all_phantom_grad(trainer, func, z_pred, **kwargs): 114 | z_out = [] 115 | for _ in range(n_phantom_grad): 116 | z_pred = (1 - tau) * z_pred + tau * func(z_pred) 117 | z_out.append(z_pred) 118 | 119 | return z_out 120 | return sup_all_phantom_grad 121 | else: 122 | def phantom_grad(trainer, func, z_pred, **kwargs): 123 | for _ in range(n_phantom_grad): 124 | z_pred = (1 - tau) * z_pred + tau * func(z_pred) 125 | 126 | return [z_pred] 127 | return phantom_grad 128 | -------------------------------------------------------------------------------- /code.v.2.0/core/deq/jacobian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | 7 | def jac_loss_estimate(f0, z0, vecs=2, create_graph=True): 8 | """Estimating tr(J^TJ)=tr(JJ^T) via Hutchinson estimator 9 | 10 | Args: 11 | f0 (torch.Tensor): Output of the function f (whose J is to be analyzed) 12 | z0 (torch.Tensor): Input to the function f 13 | vecs (int, optional): Number of random Gaussian vectors to use. Defaults to 2. 14 | create_graph (bool, optional): Whether to create backward graph (e.g., to train on this loss). 15 | Defaults to True. 16 | 17 | Returns: 18 | torch.Tensor: A 1x1 torch tensor that encodes the (shape-normalized) jacobian loss 19 | """ 20 | vecs = vecs 21 | result = 0 22 | for i in range(vecs): 23 | v = torch.randn(*z0.shape).to(z0) 24 | vJ = torch.autograd.grad(f0, z0, v, retain_graph=True, create_graph=create_graph)[0] 25 | result += vJ.norm()**2 26 | return result / vecs / np.prod(z0.shape) 27 | 28 | 29 | def power_method(f0, z0, n_iters=200): 30 | """Estimating the spectral radius of J using power method 31 | 32 | Args: 33 | f0 (torch.Tensor): Output of the function f (whose J is to be analyzed) 34 | z0 (torch.Tensor): Input to the function f 35 | n_iters (int, optional): Number of power method iterations. Defaults to 200. 36 | 37 | Returns: 38 | tuple: (largest eigenvector, largest (abs.) 
eigenvalue) 39 | """ 40 | evector = torch.randn_like(z0) 41 | bsz = evector.shape[0] 42 | for i in range(n_iters): 43 | vTJ = torch.autograd.grad(f0, z0, evector, retain_graph=(i < n_iters-1), create_graph=False)[0] 44 | evalue = (vTJ * evector).reshape(bsz, -1).sum(1, keepdim=True) / (evector * evector).reshape(bsz, -1).sum(1, keepdim=True) 45 | evector = (vTJ.reshape(bsz, -1) / vTJ.reshape(bsz, -1).norm(dim=1, keepdim=True)).reshape_as(z0) 46 | return (evector, torch.abs(evalue)) 47 | -------------------------------------------------------------------------------- /code.v.2.0/core/deq/layer_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | 7 | 8 | class DEQWrapper: 9 | def __init__(self, func, z_init=list()): 10 | z_shape = [] 11 | z_indexing = [0] 12 | for each in z_init: 13 | z_shape.append(each.shape) 14 | z_indexing.append(np.prod(each.shape[1:])) 15 | 16 | self.func = func 17 | self.z_shape = z_shape 18 | self.z_indexing = np.cumsum(z_indexing) 19 | 20 | def list2vec(self, *z_list): 21 | '''Convert list of tensors to a batched vector (B, ...)''' 22 | 23 | z_list = [each.flatten(start_dim=1) for each in z_list] 24 | return torch.cat(z_list, dim=1) 25 | 26 | def vec2list(self, z_hidden): 27 | '''Convert a batched vector back to a list''' 28 | 29 | z_list = [] 30 | z_indexing = self.z_indexing 31 | for i, shape in enumerate(self.z_shape): 32 | z_list.append(z_hidden[:, z_indexing[i]:z_indexing[i+1]].view(shape)) 33 | return z_list 34 | 35 | def __call__(self, z_hidden): 36 | '''A function call to the DEQ f''' 37 | 38 | z_list = self.vec2list(z_hidden) 39 | z_list = self.func(*z_list) 40 | z_hidden = self.list2vec(*z_list) 41 | 42 | return z_hidden 43 | 44 | def norm_diff(self, z_new, z_old, show_list=False): 45 | if show_list: 46 | z_new, z_old = self.vec2list(z_new), self.vec2list() 47 | return [(z_new[i] - z_old[i]).norm().item() for i in range(len(z_new))] 48 | 49 | return (z_new - z_old).norm().item() 50 | -------------------------------------------------------------------------------- /code.v.2.0/core/deq/norm/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .weight_norm import apply_weight_norm, reset_weight_norm, remove_weight_norm, register_wn_module 3 | -------------------------------------------------------------------------------- /code.v.2.0/core/deq/norm/weight_norm.py: -------------------------------------------------------------------------------- 1 | import types 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from torch.nn import functional as F 7 | from torch.nn.parameter import Parameter 8 | 9 | 10 | def _norm(p, dim): 11 | """Computes the norm over all dimensions except dim""" 12 | if dim is None: 13 | return p.norm() 14 | elif dim == 0: 15 | output_size = (p.size(0),) + (1,) * (p.dim() - 1) 16 | return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size) 17 | elif dim == p.dim() - 1: 18 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),) 19 | return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size) 20 | else: 21 | return _norm(p.transpose(0, dim), 0).transpose(0, dim) 22 | 23 | 24 | def compute_weight(module, name, dim): 25 | g = getattr(module, name + '_g') 26 | v = getattr(module, name + '_v') 27 | return v * (g / _norm(v, dim)) 28 | 29 | 30 | def apply_atom_wn(module, names, dims): 31 | if type(names) is str: 32 | names 
= [names] 33 | 34 | if type(dims) is int: 35 | dims = [dims] 36 | 37 | assert len(names) == len(dims) 38 | 39 | for name, dim in zip(names, dims): 40 | weight = getattr(module, name) 41 | 42 | # remove w from parameter list 43 | del module._parameters[name] 44 | 45 | # add g and v as new parameters and express w as g/||v|| * v 46 | module.register_parameter(name + '_g', Parameter(_norm(weight, dim).data)) 47 | module.register_parameter(name + '_v', Parameter(weight.data)) 48 | setattr(module, name, compute_weight(module, name, dim)) 49 | 50 | module._wn_names = names 51 | module._wn_dims = dims 52 | 53 | 54 | def reset_atom_wn(module): 55 | # Typically, every time the module is called we need to recompute the weight. However, 56 | # in the case of DEQ, the same weight is shared across layers, and we can save 57 | # a lot of intermediate memory by just recomputing once (at the beginning of first call). 58 | 59 | for name, dim in zip(module._wn_names, module._wn_dims): 60 | setattr(module, name, compute_weight(module, name, dim)) 61 | 62 | 63 | def remove_atom_wn(module): 64 | for name, dim in zip(module._wn_names, module._wn_dims): 65 | weight = compute_weight(module, name, dim) 66 | delattr(module, name) 67 | del module._parameters[name + '_g'] 68 | del module._parameters[name + '_v'] 69 | module.register_parameter(name, Parameter(weight.data)) 70 | 71 | del module._wn_names 72 | del module._wn_dims 73 | 74 | 75 | target_modules = { 76 | nn.Linear: ('weight', 0), 77 | nn.Conv1d: ('weight', 0), 78 | nn.Conv2d: ('weight', 0), 79 | nn.Conv3d: ('weight', 0) 80 | } 81 | 82 | 83 | def register_wn_module(module_class, names='weight', dims=0): 84 | ''' 85 | Register your self-defined module class for ``nested_weight_norm''. 86 | This module class will be automatically indexed for WN. 87 | 88 | Args: 89 | module_class (type): module class to be indexed for weight norm (WN). 90 | names (string): attribute name of ``module_class'' for WN to be applied. 
91 | dims (int, optional): dimension over which to compute the norm 92 | 93 | Returns: 94 | None 95 | ''' 96 | target_modules[module_class] = (names, dims) 97 | 98 | 99 | def _is_skip_name(name, filter_out): 100 | for skip_name in filter_out: 101 | if name.startswith(skip_name): 102 | return True 103 | 104 | return False 105 | 106 | 107 | def apply_weight_norm(model, filter_out=None): 108 | if type(filter_out) is str: 109 | filter_out = [filter_out] 110 | 111 | for name, module in model.named_modules(): 112 | if filter_out and _is_skip_name(name, filter_out): 113 | continue 114 | 115 | class_type = type(module) 116 | if class_type in target_modules: 117 | apply_atom_wn(module, *target_modules[class_type]) 118 | 119 | 120 | def reset_weight_norm(model): 121 | for module in model.modules(): 122 | if hasattr(module, '_wn_names'): 123 | reset_atom_wn(module) 124 | 125 | 126 | def remove_weight_norm(model): 127 | for module in model.modules(): 128 | if hasattr(module, '_wn_names'): 129 | remove_atom_wn(module) 130 | 131 | 132 | if __name__ == '__main__': 133 | z = torch.randn(8, 128, 32, 32) 134 | 135 | net = nn.Conv2d(128, 256, 3, padding=1) 136 | z_orig = net(z) 137 | 138 | apply_weight_norm(net) 139 | z_wn = net(z) 140 | 141 | reset_weight_norm(net) 142 | z_wn_reset = net(z) 143 | 144 | remove_weight_norm(net) 145 | z_back = net(z) 146 | 147 | print((z_orig - z_wn).abs().mean().item()) 148 | print((z_orig - z_wn_reset).abs().mean().item()) 149 | print((z_orig - z_back).abs().mean().item()) 150 | 151 | net = nn.Sequential( 152 | nn.Conv2d(128, 256, 3, padding=1), 153 | nn.GELU(), 154 | nn.Conv2d(256, 128, 3, padding=1) 155 | ) 156 | z_orig = net(z) 157 | 158 | apply_weight_norm(net) 159 | z_wn = net(z) 160 | 161 | reset_weight_norm(net) 162 | z_wn_reset = net(z) 163 | 164 | remove_weight_norm(net) 165 | z_back = net(z) 166 | 167 | print((z_orig - z_wn).abs().mean().item()) 168 | print((z_orig - z_wn_reset).abs().mean().item()) 169 | print((z_orig - z_back).abs().mean().item()) 170 | 171 | -------------------------------------------------------------------------------- /code.v.2.0/core/deq_flow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from update import UpdateBlock 7 | from extractor import Encoder 8 | from corr import CorrBlock 9 | from utils.utils import bilinear_sampler, coords_grid 10 | 11 | from gma import Attention 12 | 13 | from deq import get_deq 14 | from deq.norm import apply_weight_norm, reset_weight_norm 15 | from deq.layer_utils import DEQWrapper 16 | 17 | from metrics import process_metrics 18 | 19 | 20 | try: 21 | autocast = torch.cuda.amp.autocast 22 | except: 23 | # dummy autocast for PyTorch < 1.6 24 | class autocast: 25 | def __init__(self, enabled): 26 | pass 27 | def __enter__(self): 28 | pass 29 | def __exit__(self, *args): 30 | pass 31 | 32 | 33 | class DEQFlow(nn.Module): 34 | def __init__(self, args): 35 | super(DEQFlow, self).__init__() 36 | self.args = args 37 | 38 | odim = 256 39 | args.corr_levels = 4 40 | args.corr_radius = 4 41 | 42 | if args.tiny: 43 | odim = 64 44 | self.hidden_dim = hdim = 32 45 | self.context_dim = cdim = 32 46 | elif args.large: 47 | self.hidden_dim = hdim = 192 48 | self.context_dim = cdim = 192 49 | elif args.huge: 50 | self.hidden_dim = hdim = 256 51 | self.context_dim = cdim = 256 52 | elif args.gigantic: 53 | self.hidden_dim = hdim = 384 54 | self.context_dim = cdim = 384 55 | else: 56 
| self.hidden_dim = hdim = 128 57 | self.context_dim = cdim = 128 58 | 59 | if 'dropout' not in self.args: 60 | self.args.dropout = 0 61 | 62 | # feature network, context network, and update block 63 | self.fnet = Encoder(output_dim=odim, norm_fn='instance', dropout=args.dropout) 64 | self.cnet = Encoder(output_dim=cdim, norm_fn='batch', dropout=args.dropout) 65 | self.update_block = UpdateBlock(self.args, hidden_dim=hdim) 66 | 67 | self.mask = nn.Sequential( 68 | nn.Conv2d(hdim, 256, 3, padding=1), 69 | nn.ReLU(inplace=True), 70 | nn.Conv2d(256, 64*9, 1, padding=0) 71 | ) 72 | 73 | if args.gma: 74 | self.attn = Attention(dim=cdim, heads=1, max_pos_size=160, dim_head=cdim) 75 | else: 76 | self.attn = None 77 | 78 | # Added the following for DEQ 79 | if args.wnorm: 80 | apply_weight_norm(self.update_block) 81 | 82 | DEQ = get_deq(args) 83 | self.deq = DEQ(args) 84 | 85 | def freeze_bn(self): 86 | for m in self.modules(): 87 | if isinstance(m, nn.BatchNorm2d): 88 | m.eval() 89 | 90 | def _initialize_flow(self, img): 91 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 92 | N, _, H, W = img.shape 93 | coords0 = coords_grid(N, H//8, W//8, device=img.device) 94 | coords1 = coords_grid(N, H//8, W//8, device=img.device) 95 | 96 | # optical flow computed as difference: flow = coords1 - coords0 97 | return coords0, coords1 98 | 99 | def _upsample_flow(self, flow, mask): 100 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 101 | N, _, H, W = flow.shape 102 | mask = mask.view(N, 1, 9, 8, 8, H, W) 103 | mask = torch.softmax(mask, dim=2) 104 | 105 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 106 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 107 | 108 | up_flow = torch.sum(mask * up_flow, dim=2) 109 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 110 | return up_flow.reshape(N, 2, 8*H, 8*W) 111 | 112 | def _decode(self, z_out, coords0): 113 | net, coords1 = z_out 114 | up_mask = .25 * self.mask(net) 115 | flow_up = self._upsample_flow(coords1 - coords0, up_mask) 116 | 117 | return flow_up 118 | 119 | def forward(self, image1, image2, 120 | flow_gt=None, valid=None, fc_loss=None, 121 | flow_init=None, cached_result=None, 122 | writer=None, sradius_mode=False, 123 | **kwargs): 124 | """ Estimate optical flow between pair of frames """ 125 | 126 | image1 = 2 * (image1 / 255.0) - 1.0 127 | image2 = 2 * (image2 / 255.0) - 1.0 128 | 129 | image1 = image1.contiguous() 130 | image2 = image2.contiguous() 131 | 132 | hdim = self.hidden_dim 133 | cdim = self.context_dim 134 | 135 | # run the feature network 136 | with autocast(enabled=self.args.mixed_precision): 137 | fmap1, fmap2 = self.fnet([image1, image2]) 138 | 139 | fmap1 = fmap1.float() 140 | fmap2 = fmap2.float() 141 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 142 | 143 | # run the context network 144 | with autocast(enabled=self.args.mixed_precision): 145 | # cnet = self.cnet(image1) 146 | # net, inp = torch.split(cnet, [hdim, cdim], dim=1) 147 | # net = torch.tanh(net) 148 | inp = self.cnet(image1) 149 | inp = torch.relu(inp) 150 | 151 | if self.attn: 152 | attn = self.attn(inp) 153 | else: 154 | attn = None 155 | 156 | bsz, _, H, W = inp.shape 157 | coords0, coords1 = self._initialize_flow(image1) 158 | net = torch.zeros(bsz, hdim, H, W, device=inp.device) 159 | 160 | if cached_result: 161 | net, flow_pred_prev = cached_result 162 | coords1 = coords0 + flow_pred_prev 163 | 164 | if flow_init is not None: 165 | coords1 = coords1 + flow_init 166 | 167 | 
if self.args.wnorm: 168 | reset_weight_norm(self.update_block) # Reset weights for WN 169 | 170 | def func(h,c): 171 | if not self.args.all_grad: 172 | c = c.detach() 173 | with autocast(enabled=self.args.mixed_precision): 174 | new_h, delta_flow = self.update_block(h, inp, corr_fn(c), c-coords0, attn) # corr_fn(coords1) produces the index correlation volumes 175 | new_c = c + delta_flow # F(t+1) = F(t) + \Delta(t) 176 | return new_h, new_c 177 | 178 | deq_func = DEQWrapper(func, (net, coords1)) 179 | z_init = deq_func.list2vec(net, coords1) 180 | log = (inp.get_device() == 0 and np.random.uniform(0,1) < 2e-3) 181 | 182 | z_out, info = self.deq(deq_func, z_init, log, sradius_mode, **kwargs) 183 | flow_pred = [self._decode(z, coords0) for z in z_out] 184 | 185 | if self.training: 186 | flow_loss, epe = fc_loss(flow_pred, flow_gt, valid) 187 | metrics = process_metrics(epe, info) 188 | 189 | return flow_loss, metrics 190 | else: 191 | (net, coords1), flow_up = z_out[-1], flow_pred[-1] 192 | 193 | return coords1 - coords0, flow_up, {"sradius": info['sradius'], "cached_result": (net, coords1 - coords0)} 194 | 195 | -------------------------------------------------------------------------------- /code.v.2.0/core/extractor.py: -------------------------------------------------------------------------------- 1 | # This code was originally taken from RAFT without modification 2 | # https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class ResidualBlock(nn.Module): 10 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 11 | super(ResidualBlock, self).__init__() 12 | 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 15 | self.relu = nn.ReLU(inplace=True) 16 | 17 | num_groups = planes // 8 18 | 19 | if norm_fn == 'group': 20 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 22 | if not stride == 1: 23 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 24 | 25 | elif norm_fn == 'batch': 26 | self.norm1 = nn.BatchNorm2d(planes) 27 | self.norm2 = nn.BatchNorm2d(planes) 28 | if not stride == 1: 29 | self.norm3 = nn.BatchNorm2d(planes) 30 | 31 | elif norm_fn == 'instance': 32 | self.norm1 = nn.InstanceNorm2d(planes) 33 | self.norm2 = nn.InstanceNorm2d(planes) 34 | if not stride == 1: 35 | self.norm3 = nn.InstanceNorm2d(planes) 36 | 37 | elif norm_fn == 'none': 38 | self.norm1 = nn.Sequential() 39 | self.norm2 = nn.Sequential() 40 | if not stride == 1: 41 | self.norm3 = nn.Sequential() 42 | 43 | if stride == 1: 44 | self.downsample = None 45 | 46 | else: 47 | self.downsample = nn.Sequential( 48 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 49 | 50 | 51 | def forward(self, x): 52 | y = x 53 | y = self.relu(self.norm1(self.conv1(y))) 54 | y = self.relu(self.norm2(self.conv2(y))) 55 | 56 | if self.downsample is not None: 57 | x = self.downsample(x) 58 | 59 | return self.relu(x+y) 60 | 61 | 62 | class BottleneckBlock(nn.Module): 63 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 64 | super(BottleneckBlock, self).__init__() 65 | 66 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 67 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 68 | self.conv3 = 
nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 69 | self.relu = nn.ReLU(inplace=True) 70 | 71 | num_groups = planes // 8 72 | 73 | if norm_fn == 'group': 74 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 75 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 76 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 77 | if not stride == 1: 78 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 79 | 80 | elif norm_fn == 'batch': 81 | self.norm1 = nn.BatchNorm2d(planes//4) 82 | self.norm2 = nn.BatchNorm2d(planes//4) 83 | self.norm3 = nn.BatchNorm2d(planes) 84 | if not stride == 1: 85 | self.norm4 = nn.BatchNorm2d(planes) 86 | 87 | elif norm_fn == 'instance': 88 | self.norm1 = nn.InstanceNorm2d(planes//4) 89 | self.norm2 = nn.InstanceNorm2d(planes//4) 90 | self.norm3 = nn.InstanceNorm2d(planes) 91 | if not stride == 1: 92 | self.norm4 = nn.InstanceNorm2d(planes) 93 | 94 | elif norm_fn == 'none': 95 | self.norm1 = nn.Sequential() 96 | self.norm2 = nn.Sequential() 97 | self.norm3 = nn.Sequential() 98 | if not stride == 1: 99 | self.norm4 = nn.Sequential() 100 | 101 | if stride == 1: 102 | self.downsample = None 103 | 104 | else: 105 | self.downsample = nn.Sequential( 106 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 107 | 108 | def forward(self, x): 109 | y = x 110 | y = self.relu(self.norm1(self.conv1(y))) 111 | y = self.relu(self.norm2(self.conv2(y))) 112 | y = self.relu(self.norm3(self.conv3(y))) 113 | 114 | if self.downsample is not None: 115 | x = self.downsample(x) 116 | 117 | return self.relu(x+y) 118 | 119 | 120 | class Encoder(nn.Module): 121 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 122 | super(Encoder, self).__init__() 123 | self.norm_fn = norm_fn 124 | 125 | if self.norm_fn == 'group': 126 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 127 | 128 | elif self.norm_fn == 'batch': 129 | self.norm1 = nn.BatchNorm2d(64) 130 | 131 | elif self.norm_fn == 'instance': 132 | self.norm1 = nn.InstanceNorm2d(64) 133 | 134 | elif self.norm_fn == 'none': 135 | self.norm1 = nn.Sequential() 136 | 137 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 138 | self.relu1 = nn.ReLU(inplace=True) 139 | 140 | self.in_planes = 64 141 | self.layer1 = self._make_layer(64, stride=1) 142 | self.layer2 = self._make_layer(96, stride=2) 143 | self.layer3 = self._make_layer(128, stride=2) 144 | 145 | # output convolution 146 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 147 | 148 | self.dropout = None 149 | if dropout > 0: 150 | self.dropout = nn.Dropout2d(p=dropout) 151 | 152 | for m in self.modules(): 153 | if isinstance(m, nn.Conv2d): 154 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 155 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 156 | if m.weight is not None: 157 | nn.init.constant_(m.weight, 1) 158 | if m.bias is not None: 159 | nn.init.constant_(m.bias, 0) 160 | 161 | def _make_layer(self, dim, stride=1): 162 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 163 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 164 | layers = (layer1, layer2) 165 | 166 | self.in_planes = dim 167 | return nn.Sequential(*layers) 168 | 169 | 170 | def forward(self, x): 171 | 172 | # if input is list, combine batch dimension 173 | is_list = isinstance(x, tuple) or isinstance(x, list) 174 | if is_list: 175 | batch_dim = x[0].shape[0] 176 | x = torch.cat(x, 
dim=0) 177 | 178 | x = self.conv1(x) 179 | x = self.norm1(x) 180 | x = self.relu1(x) 181 | 182 | x = self.layer1(x) 183 | x = self.layer2(x) 184 | x = self.layer3(x) 185 | 186 | x = self.conv2(x) 187 | 188 | if self.training and self.dropout is not None: 189 | x = self.dropout(x) 190 | 191 | if is_list: 192 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 193 | 194 | return x 195 | 196 | 197 | -------------------------------------------------------------------------------- /code.v.2.0/core/gma.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange 4 | 5 | 6 | class RelPosEmb(nn.Module): 7 | def __init__( 8 | self, 9 | max_pos_size, 10 | dim_head 11 | ): 12 | super().__init__() 13 | self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head) 14 | self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head) 15 | 16 | deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(max_pos_size).view(-1, 1) 17 | rel_ind = deltas + max_pos_size - 1 18 | self.register_buffer('rel_ind', rel_ind) 19 | 20 | def forward(self, q): 21 | batch, heads, h, w, c = q.shape 22 | height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1)) 23 | width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1)) 24 | 25 | height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h) 26 | width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w) 27 | 28 | height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb) 29 | width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb) 30 | 31 | return height_score + width_score 32 | 33 | 34 | class Attention(nn.Module): 35 | def __init__( 36 | self, 37 | *, 38 | dim, 39 | max_pos_size = 100, 40 | heads = 4, 41 | dim_head = 128, 42 | ): 43 | super().__init__() 44 | self.heads = heads 45 | self.scale = dim_head ** -0.5 46 | inner_dim = heads * dim_head 47 | 48 | self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) 49 | 50 | def forward(self, fmap): 51 | heads, b, c, h, w = self.heads, *fmap.shape 52 | 53 | q, k = self.to_qk(fmap).chunk(2, dim=1) 54 | 55 | q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k)) 56 | q = self.scale * q 57 | 58 | sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k) 59 | 60 | sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)') 61 | attn = sim.softmax(dim=-1) 62 | 63 | return attn 64 | 65 | 66 | class Aggregate(nn.Module): 67 | def __init__( 68 | self, 69 | dim, 70 | heads = 4, 71 | dim_head = 128, 72 | ): 73 | super().__init__() 74 | self.heads = heads 75 | self.scale = dim_head ** -0.5 76 | inner_dim = heads * dim_head 77 | 78 | self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) 79 | 80 | self.gamma = nn.Parameter(torch.zeros(1)) 81 | 82 | if dim != inner_dim: 83 | self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) 84 | else: 85 | self.project = None 86 | 87 | def forward(self, attn, fmap): 88 | heads, b, c, h, w = self.heads, *fmap.shape 89 | 90 | v = self.to_v(fmap) 91 | v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads) 92 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 93 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) 94 | 95 | if self.project is not None: 96 | out = self.project(out) 97 | 98 | out = fmap + self.gamma * out 99 | 100 | return out 101 | 102 | 103 | if __name__ == "__main__": 104 | att = Attention(dim=128, heads=1) 105 | fmap = torch.randn(2, 128, 40, 90) 106 | out = att(fmap) 107 | 108 | print(out.shape) 109 
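    # The attention map can also be fed to Aggregate, which is how update.py
    # combines them (self.gma(attn, motion_features)); same dim/heads as above.
    agg = Aggregate(dim=128, heads=1, dim_head=128)
    out_agg = agg(out, fmap)
    print(out_agg.shape)  # (2, 128, 40, 90)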
| -------------------------------------------------------------------------------- /code.v.2.0/core/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | MAX_FLOW = 400 4 | 5 | @torch.no_grad() 6 | def compute_epe(flow_pred, flow_gt, valid, max_flow=MAX_FLOW): 7 | # exlude invalid pixels and extremely large diplacements 8 | mag = torch.sum(flow_gt**2, dim=1).sqrt() 9 | valid = (valid >= 0.5) & (mag < max_flow) 10 | 11 | epe = torch.sum((flow_pred - flow_gt)**2, dim=1).sqrt() 12 | epe = torch.masked_fill(epe, ~valid, 0) 13 | 14 | return epe 15 | 16 | 17 | @torch.no_grad() 18 | def process_metrics(epe, info, **kwargs): 19 | epe = epe.flatten(1) 20 | metrics = { 21 | 'epe': epe.mean(dim=1), 22 | '1px': (epe < 1).float().mean(dim=1), 23 | '3px': (epe < 3).float().mean(dim=1), 24 | '5px': (epe < 5).float().mean(dim=1), 25 | 'rel': info['rel_lowest'], 26 | 'abs': info['abs_lowest'], 27 | } 28 | 29 | # dict: N_Metrics -> B // N_GPU 30 | return metrics 31 | 32 | 33 | @torch.no_grad() 34 | def merge_metrics(metrics): 35 | out = dict() 36 | 37 | for key, value in metrics.items(): 38 | out[key] = value.mean().item() 39 | 40 | return out 41 | -------------------------------------------------------------------------------- /code.v.2.0/core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from gma import Aggregate 6 | 7 | 8 | class FlowHead(nn.Module): 9 | def __init__(self, input_dim=128, hidden_dim=256): 10 | super(FlowHead, self).__init__() 11 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 12 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 13 | self.relu = nn.ReLU(inplace=True) 14 | 15 | def forward(self, x): 16 | return self.conv2(self.relu(self.conv1(x))) 17 | 18 | 19 | class ConvGRU(nn.Module): 20 | def __init__(self, hidden_dim=128, input_dim=192+128): 21 | super(ConvGRU, self).__init__() 22 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 23 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 24 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 25 | 26 | def forward(self, h, x): 27 | hx = torch.cat([h, x], dim=1) 28 | 29 | z = torch.sigmoid(self.convz(hx)) 30 | r = torch.sigmoid(self.convr(hx)) 31 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 32 | 33 | h = (1-z) * h + z * q 34 | return h 35 | 36 | 37 | class SepConvGRU(nn.Module): 38 | def __init__(self, hidden_dim=128, input_dim=192+128): 39 | super(SepConvGRU, self).__init__() 40 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 41 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 42 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 43 | 44 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 45 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 46 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 47 | 48 | def forward(self, h, x): 49 | # horizontal 50 | hx = torch.cat([h, x], dim=1) 51 | z = torch.sigmoid(self.convz1(hx)) 52 | r = torch.sigmoid(self.convr1(hx)) 53 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 54 | h = (1-z) * h + z * q 55 | 56 | # vertical 57 | hx = torch.cat([h, x], dim=1) 58 | z = torch.sigmoid(self.convz2(hx)) 59 | r = 
torch.sigmoid(self.convr2(hx)) 60 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 61 | h = (1-z) * h + z * q 62 | 63 | return h 64 | 65 | 66 | class MotionEncoder(nn.Module): 67 | def __init__(self, args): 68 | super(MotionEncoder, self).__init__() 69 | 70 | if args.large: 71 | c_dim_1 = 256 + 128 72 | c_dim_2 = 192 + 96 73 | 74 | f_dim_1 = 128 + 64 75 | f_dim_2 = 64 + 32 76 | 77 | cat_dim = 128 + 64 78 | elif args.huge: 79 | c_dim_1 = 256 + 256 80 | c_dim_2 = 192 + 192 81 | 82 | f_dim_1 = 128 + 128 83 | f_dim_2 = 64 + 64 84 | 85 | cat_dim = 128 + 128 86 | elif args.gigantic: 87 | c_dim_1 = 256 + 384 88 | c_dim_2 = 192 + 288 89 | 90 | f_dim_1 = 128 + 192 91 | f_dim_2 = 64 + 96 92 | 93 | cat_dim = 128 + 192 94 | elif args.tiny: 95 | c_dim_1 = 64 96 | c_dim_2 = 48 97 | 98 | f_dim_1 = 32 99 | f_dim_2 = 16 100 | 101 | cat_dim = 32 102 | else: 103 | c_dim_1 = 256 104 | c_dim_2 = 192 105 | 106 | f_dim_1 = 128 107 | f_dim_2 = 64 108 | 109 | cat_dim = 128 110 | 111 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 112 | self.convc1 = nn.Conv2d(cor_planes, c_dim_1, 1, padding=0) 113 | self.convc2 = nn.Conv2d(c_dim_1, c_dim_2, 3, padding=1) 114 | self.convf1 = nn.Conv2d(2, f_dim_1, 7, padding=3) 115 | self.convf2 = nn.Conv2d(f_dim_1, f_dim_2, 3, padding=1) 116 | self.conv = nn.Conv2d(c_dim_2+f_dim_2, cat_dim-2, 3, padding=1) 117 | 118 | def forward(self, flow, corr): 119 | cor = F.relu(self.convc1(corr)) 120 | cor = F.relu(self.convc2(cor)) 121 | flo = F.relu(self.convf1(flow)) 122 | flo = F.relu(self.convf2(flo)) 123 | 124 | cor_flo = torch.cat([cor, flo], dim=1) 125 | out = F.relu(self.conv(cor_flo)) 126 | return torch.cat([out, flow], dim=1) 127 | 128 | 129 | class UpdateBlock(nn.Module): 130 | def __init__(self, args, hidden_dim=128, input_dim=128): 131 | super(UpdateBlock, self).__init__() 132 | self.args = args 133 | 134 | if args.tiny: 135 | cat_dim = 32 136 | elif args.large: 137 | cat_dim = 128 + 64 138 | elif args.huge: 139 | cat_dim = 128 + 128 140 | elif args.gigantic: 141 | cat_dim = 128 + 192 142 | else: 143 | cat_dim = 128 144 | 145 | if args.old_version: 146 | flow_head_dim = min(256, 2*cat_dim) 147 | else: 148 | flow_head_dim = 2*cat_dim 149 | 150 | self.encoder = MotionEncoder(args) 151 | 152 | if args.gma: 153 | self.gma = Aggregate(dim=cat_dim, dim_head=cat_dim, heads=1) 154 | 155 | gru_in_dim = 2 * cat_dim + hidden_dim 156 | else: 157 | self.gma = None 158 | 159 | gru_in_dim = cat_dim + hidden_dim 160 | 161 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=gru_in_dim) 162 | self.flow_head = FlowHead(hidden_dim, hidden_dim=flow_head_dim) 163 | 164 | def forward(self, net, inp, corr, flow, attn=None, upsample=True): 165 | motion_features = self.encoder(flow, corr) 166 | 167 | if self.gma: 168 | motion_features_global = self.gma(attn, motion_features) 169 | inp = torch.cat([inp, motion_features, motion_features_global], dim=1) 170 | else: 171 | inp = torch.cat([inp, motion_features], dim=1) 172 | 173 | net = self.gru(net, inp) 174 | delta_flow = self.flow_head(net) 175 | 176 | return net, delta_flow 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /code.v.2.0/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/locuslab/deq-flow/869d4bcd9c1d227cbfa16639994d1b118510a9bb/code.v.2.0/core/utils/__init__.py -------------------------------------------------------------------------------- 
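A quick sanity check for the recurrent update defined in core/update.py above. This sketch is not part of the repo: the batch size and the 1/8-resolution grid size are illustrative assumptions, and it assumes it is run from the code.v.2.0 directory so that the core package is importable (mirroring the sys.path convention used by viz.py below).

```python
import sys
sys.path.append('core')   # same import convention as the repo's top-level scripts

import torch
from update import SepConvGRU   # defined in core/update.py above

# Hypothetical shapes: batch 2, hidden dim 128, input dim 192+128, 46x62 grid (1/8 resolution).
gru = SepConvGRU(hidden_dim=128, input_dim=192 + 128)
net = torch.zeros(2, 128, 46, 62)          # hidden state ("net")
inp = torch.randn(2, 192 + 128, 46, 62)    # context features + motion features
net = gru(net, inp)                        # one horizontal pass followed by one vertical pass
print(net.shape)                           # torch.Size([2, 128, 46, 62])
```

The spatial size is preserved because every convolution in SepConvGRU pads to keep the grid fixed; only the hidden state is updated.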
/code.v.2.0/core/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | # This code was originally taken from RAFT without modification 2 | # https://github.com/princeton-vl/RAFT/blob/master/utils/augmentor.py 3 | 4 | import numpy as np 5 | import random 6 | import math 7 | from PIL import Image 8 | 9 | import cv2 10 | cv2.setNumThreads(0) 11 | cv2.ocl.setUseOpenCL(False) 12 | 13 | import torch 14 | from torchvision.transforms import ColorJitter 15 | import torch.nn.functional as F 16 | 17 | 18 | class FlowAugmentor: 19 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 20 | 21 | # spatial augmentation params 22 | self.crop_size = crop_size 23 | self.min_scale = min_scale 24 | self.max_scale = max_scale 25 | self.spatial_aug_prob = 0.8 26 | self.stretch_prob = 0.8 27 | self.max_stretch = 0.2 28 | 29 | # flip augmentation params 30 | self.do_flip = do_flip 31 | self.h_flip_prob = 0.5 32 | self.v_flip_prob = 0.1 33 | 34 | # photometric augmentation params 35 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 36 | self.asymmetric_color_aug_prob = 0.2 37 | self.eraser_aug_prob = 0.5 38 | 39 | def color_transform(self, img1, img2): 40 | """ Photometric augmentation """ 41 | 42 | # asymmetric 43 | if np.random.rand() < self.asymmetric_color_aug_prob: 44 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 45 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 46 | 47 | # symmetric 48 | else: 49 | image_stack = np.concatenate([img1, img2], axis=0) 50 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 51 | img1, img2 = np.split(image_stack, 2, axis=0) 52 | 53 | return img1, img2 54 | 55 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 56 | """ Occlusion augmentation """ 57 | 58 | ht, wd = img1.shape[:2] 59 | if np.random.rand() < self.eraser_aug_prob: 60 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 61 | for _ in range(np.random.randint(1, 3)): 62 | x0 = np.random.randint(0, wd) 63 | y0 = np.random.randint(0, ht) 64 | dx = np.random.randint(bounds[0], bounds[1]) 65 | dy = np.random.randint(bounds[0], bounds[1]) 66 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 67 | 68 | return img1, img2 69 | 70 | def spatial_transform(self, img1, img2, flow): 71 | # randomly sample scale 72 | ht, wd = img1.shape[:2] 73 | min_scale = np.maximum( 74 | (self.crop_size[0] + 8) / float(ht), 75 | (self.crop_size[1] + 8) / float(wd)) 76 | 77 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 78 | scale_x = scale 79 | scale_y = scale 80 | if np.random.rand() < self.stretch_prob: 81 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 82 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 83 | 84 | scale_x = np.clip(scale_x, min_scale, None) 85 | scale_y = np.clip(scale_y, min_scale, None) 86 | 87 | if np.random.rand() < self.spatial_aug_prob: 88 | # rescale the images 89 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 90 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 91 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 92 | flow = flow * [scale_x, scale_y] 93 | 94 | if self.do_flip: 95 | if np.random.rand() < self.h_flip_prob: # h-flip 96 | img1 = img1[:, ::-1] 97 | img2 = img2[:, ::-1] 98 | flow = flow[:, ::-1] * [-1.0, 1.0] 99 | 100 | if 
np.random.rand() < self.v_flip_prob: # v-flip 101 | img1 = img1[::-1, :] 102 | img2 = img2[::-1, :] 103 | flow = flow[::-1, :] * [1.0, -1.0] 104 | 105 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 106 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 107 | 108 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 109 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 110 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 111 | 112 | return img1, img2, flow 113 | 114 | def __call__(self, img1, img2, flow): 115 | img1, img2 = self.color_transform(img1, img2) 116 | img1, img2 = self.eraser_transform(img1, img2) 117 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 118 | 119 | img1 = np.ascontiguousarray(img1) 120 | img2 = np.ascontiguousarray(img2) 121 | flow = np.ascontiguousarray(flow) 122 | 123 | return img1, img2, flow 124 | 125 | class SparseFlowAugmentor: 126 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): 127 | # spatial augmentation params 128 | self.crop_size = crop_size 129 | self.min_scale = min_scale 130 | self.max_scale = max_scale 131 | self.spatial_aug_prob = 0.8 132 | self.stretch_prob = 0.8 133 | self.max_stretch = 0.2 134 | 135 | # flip augmentation params 136 | self.do_flip = do_flip 137 | self.h_flip_prob = 0.5 138 | self.v_flip_prob = 0.1 139 | 140 | # photometric augmentation params 141 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 142 | self.asymmetric_color_aug_prob = 0.2 143 | self.eraser_aug_prob = 0.5 144 | 145 | def color_transform(self, img1, img2): 146 | image_stack = np.concatenate([img1, img2], axis=0) 147 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 148 | img1, img2 = np.split(image_stack, 2, axis=0) 149 | return img1, img2 150 | 151 | def eraser_transform(self, img1, img2): 152 | ht, wd = img1.shape[:2] 153 | if np.random.rand() < self.eraser_aug_prob: 154 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 155 | for _ in range(np.random.randint(1, 3)): 156 | x0 = np.random.randint(0, wd) 157 | y0 = np.random.randint(0, ht) 158 | dx = np.random.randint(50, 100) 159 | dy = np.random.randint(50, 100) 160 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 161 | 162 | return img1, img2 163 | 164 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 165 | ht, wd = flow.shape[:2] 166 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 167 | coords = np.stack(coords, axis=-1) 168 | 169 | coords = coords.reshape(-1, 2).astype(np.float32) 170 | flow = flow.reshape(-1, 2).astype(np.float32) 171 | valid = valid.reshape(-1).astype(np.float32) 172 | 173 | coords0 = coords[valid>=1] 174 | flow0 = flow[valid>=1] 175 | 176 | ht1 = int(round(ht * fy)) 177 | wd1 = int(round(wd * fx)) 178 | 179 | coords1 = coords0 * [fx, fy] 180 | flow1 = flow0 * [fx, fy] 181 | 182 | xx = np.round(coords1[:,0]).astype(np.int32) 183 | yy = np.round(coords1[:,1]).astype(np.int32) 184 | 185 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 186 | xx = xx[v] 187 | yy = yy[v] 188 | flow1 = flow1[v] 189 | 190 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 191 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 192 | 193 | flow_img[yy, xx] = flow1 194 | valid_img[yy, xx] = 1 195 | 196 | return flow_img, valid_img 197 | 198 | def spatial_transform(self, img1, img2, flow, valid): 199 | # randomly sample scale 200 | 201 | ht, wd = img1.shape[:2] 202 | min_scale = np.maximum( 203 | (self.crop_size[0] 
+ 1) / float(ht), 204 | (self.crop_size[1] + 1) / float(wd)) 205 | 206 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 207 | scale_x = np.clip(scale, min_scale, None) 208 | scale_y = np.clip(scale, min_scale, None) 209 | 210 | if np.random.rand() < self.spatial_aug_prob: 211 | # rescale the images 212 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 213 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 214 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 215 | 216 | if self.do_flip: 217 | if np.random.rand() < 0.5: # h-flip 218 | img1 = img1[:, ::-1] 219 | img2 = img2[:, ::-1] 220 | flow = flow[:, ::-1] * [-1.0, 1.0] 221 | valid = valid[:, ::-1] 222 | 223 | margin_y = 20 224 | margin_x = 50 225 | 226 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 227 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 228 | 229 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 230 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 231 | 232 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 233 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 234 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 235 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 236 | return img1, img2, flow, valid 237 | 238 | 239 | def __call__(self, img1, img2, flow, valid): 240 | img1, img2 = self.color_transform(img1, img2) 241 | img1, img2 = self.eraser_transform(img1, img2) 242 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 243 | 244 | img1 = np.ascontiguousarray(img1) 245 | img2 = np.ascontiguousarray(img2) 246 | flow = np.ascontiguousarray(flow) 247 | valid = np.ascontiguousarray(valid) 248 | 249 | return img1, img2, flow, valid 250 | -------------------------------------------------------------------------------- /code.v.2.0/core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the Matlab source code of Deqing Sun.
28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /code.v.2.0/core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | # This code was originally taken from RAFT without modification 2 | # https://github.com/princeton-vl/RAFT/blob/master/utils/frame_utils.py 3 | 4 | import numpy as np 5 | from PIL import Image 6 | from os.path import * 7 | import re 8 | 9 | import cv2 10 | cv2.setNumThreads(0) 11 | cv2.ocl.setUseOpenCL(False) 12 | 13 | TAG_CHAR = np.array([202021.25], np.float32) 14 | 15 | def readFlow(fn): 16 | """ Read .flo file in Middlebury format""" 17 | # Code adapted from: 18 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 19 | 20 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 21 | # print 'fn = %s'%(fn) 22 | with open(fn, 'rb') as f: 23 | magic = np.fromfile(f, np.float32, count=1) 24 | if 202021.25 != magic: 25 | print('Magic number incorrect. Invalid .flo file') 26 | return None 27 | else: 28 | w = np.fromfile(f, np.int32, count=1) 29 | h = np.fromfile(f, np.int32, count=1) 30 | # print 'Reading %d x %d flo file\n' % (w, h) 31 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 32 | # Reshape data into 3D array (columns, rows, bands) 33 | # The reshape here is for visualization, the original code is (w,h,2) 34 | return np.resize(data, (int(h), int(w), 2)) 35 | 36 | def readPFM(file): 37 | file = open(file, 'rb') 38 | 39 | color = None 40 | width = None 41 | height = None 42 | scale = None 43 | endian = None 44 | 45 | header = file.readline().rstrip() 46 | if header == b'PF': 47 | color = True 48 | elif header == b'Pf': 49 | color = False 50 | else: 51 | raise Exception('Not a PFM file.') 52 | 53 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 54 | if dim_match: 55 | width, height = map(int, dim_match.groups()) 56 | else: 57 | raise Exception('Malformed PFM header.') 58 | 59 | scale = float(file.readline().rstrip()) 60 | if scale < 0: # little-endian 61 | endian = '<' 62 | scale = -scale 63 | else: 64 | endian = '>' # big-endian 65 | 66 | data = np.fromfile(file, endian + 'f') 67 | shape = (height, width, 3) if color else (height, width) 68 | 69 | data = np.reshape(data, shape) 70 | data = np.flipud(data) 71 | return data 72 | 73 | def writeFlow(filename,uv,v=None): 74 | """ Write optical flow to file. 75 | 76 | If v is None, uv is assumed to contain both u and v channels, 77 | stacked in depth. 78 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
79 | """ 80 | nBands = 2 81 | 82 | if v is None: 83 | assert(uv.ndim == 3) 84 | assert(uv.shape[2] == 2) 85 | u = uv[:,:,0] 86 | v = uv[:,:,1] 87 | else: 88 | u = uv 89 | 90 | assert(u.shape == v.shape) 91 | height,width = u.shape 92 | f = open(filename,'wb') 93 | # write the header 94 | f.write(TAG_CHAR) 95 | np.array(width).astype(np.int32).tofile(f) 96 | np.array(height).astype(np.int32).tofile(f) 97 | # arrange into matrix form 98 | tmp = np.zeros((height, width*nBands)) 99 | tmp[:,np.arange(width)*2] = u 100 | tmp[:,np.arange(width)*2 + 1] = v 101 | tmp.astype(np.float32).tofile(f) 102 | f.close() 103 | 104 | 105 | def readFlowKITTI(filename): 106 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 107 | flow = flow[:,:,::-1].astype(np.float32) 108 | flow, valid = flow[:, :, :2], flow[:, :, 2] 109 | flow = (flow - 2**15) / 64.0 110 | return flow, valid 111 | 112 | def readDispKITTI(filename): 113 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 114 | valid = disp > 0.0 115 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 116 | return flow, valid 117 | 118 | 119 | def writeFlowKITTI(filename, uv): 120 | uv = 64.0 * uv + 2**15 121 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 122 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 123 | cv2.imwrite(filename, uv[..., ::-1]) 124 | 125 | 126 | def read_gen(file_name, pil=False): 127 | ext = splitext(file_name)[-1] 128 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 129 | return Image.open(file_name) 130 | elif ext == '.bin' or ext == '.raw': 131 | return np.load(file_name) 132 | elif ext == '.flo': 133 | return readFlow(file_name).astype(np.float32) 134 | elif ext == '.pfm': 135 | flow = readPFM(file_name).astype(np.float32) 136 | if len(flow.shape) == 2: 137 | return flow 138 | else: 139 | return flow[:, :, :-1] 140 | return [] 141 | -------------------------------------------------------------------------------- /code.v.2.0/core/utils/grid_sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def grid_sample(image, optical): 6 | N, C, IH, IW = image.shape 7 | _, H, W, _ = optical.shape 8 | 9 | ix = optical[..., 0] 10 | iy = optical[..., 1] 11 | 12 | ix = ((ix + 1) / 2) * (IW-1) 13 | iy = ((iy + 1) / 2) * (IH-1) 14 | with torch.no_grad(): 15 | ix_nw = torch.floor(ix) 16 | iy_nw = torch.floor(iy) 17 | ix_ne = ix_nw + 1 18 | iy_ne = iy_nw 19 | ix_sw = ix_nw 20 | iy_sw = iy_nw + 1 21 | ix_se = ix_nw + 1 22 | iy_se = iy_nw + 1 23 | 24 | nw = (ix_se - ix) * (iy_se - iy) 25 | ne = (ix - ix_sw) * (iy_sw - iy) 26 | sw = (ix_ne - ix) * (iy - iy_ne) 27 | se = (ix - ix_nw) * (iy - iy_nw) 28 | 29 | with torch.no_grad(): 30 | torch.clamp(ix_nw, 0, IW-1, out=ix_nw) 31 | torch.clamp(iy_nw, 0, IH-1, out=iy_nw) 32 | 33 | torch.clamp(ix_ne, 0, IW-1, out=ix_ne) 34 | torch.clamp(iy_ne, 0, IH-1, out=iy_ne) 35 | 36 | torch.clamp(ix_sw, 0, IW-1, out=ix_sw) 37 | torch.clamp(iy_sw, 0, IH-1, out=iy_sw) 38 | 39 | torch.clamp(ix_se, 0, IW-1, out=ix_se) 40 | torch.clamp(iy_se, 0, IH-1, out=iy_se) 41 | 42 | image = image.view(N, C, IH * IW) 43 | 44 | 45 | nw_val = torch.gather(image, 2, (iy_nw * IW + ix_nw).long().view(N, 1, H * W).repeat(1, C, 1)) 46 | ne_val = torch.gather(image, 2, (iy_ne * IW + ix_ne).long().view(N, 1, H * W).repeat(1, C, 1)) 47 | sw_val = torch.gather(image, 2, (iy_sw * IW + ix_sw).long().view(N, 1, H * W).repeat(1, C, 1)) 48 | se_val = torch.gather(image, 2, (iy_se * IW + 
ix_se).long().view(N, 1, H * W).repeat(1, C, 1)) 49 | 50 | out_val = (nw_val.view(N, C, H, W) * nw.view(N, 1, H, W) + 51 | ne_val.view(N, C, H, W) * ne.view(N, 1, H, W) + 52 | sw_val.view(N, C, H, W) * sw.view(N, 1, H, W) + 53 | se_val.view(N, C, H, W) * se.view(N, 1, H, W)) 54 | 55 | return out_val -------------------------------------------------------------------------------- /code.v.2.0/core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | from .grid_sample import grid_sample 7 | 8 | 9 | class InputPadder: 10 | """ Pads images such that dimensions are divisible by 8 """ 11 | def __init__(self, dims, mode='sintel'): 12 | self.ht, self.wd = dims[-2:] 13 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 14 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 15 | if mode == 'sintel': 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 17 | else: 18 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 19 | 20 | def pad(self, *inputs): 21 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 22 | 23 | def unpad(self,x): 24 | ht, wd = x.shape[-2:] 25 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 26 | return x[..., c[0]:c[1], c[2]:c[3]] 27 | 28 | 29 | def forward_interpolate(flow): 30 | flow = flow.detach().cpu().numpy() 31 | dx, dy = flow[0], flow[1] 32 | 33 | ht, wd = dx.shape 34 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 35 | 36 | x1 = x0 + dx 37 | y1 = y0 + dy 38 | 39 | x1 = x1.reshape(-1) 40 | y1 = y1.reshape(-1) 41 | dx = dx.reshape(-1) 42 | dy = dy.reshape(-1) 43 | 44 | # valid = (x1 > 0.1 * wd) & (x1 < 0.9 * wd) & (y1 > 0.1 * ht) & (y1 < 0.9 * ht) 45 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 46 | x1 = x1[valid] 47 | y1 = y1[valid] 48 | dx = dx[valid] 49 | dy = dy[valid] 50 | 51 | flow_x = interpolate.griddata( 52 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 53 | 54 | flow_y = interpolate.griddata( 55 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 56 | 57 | flow = np.stack([flow_x, flow_y], axis=0) 58 | return torch.from_numpy(flow).float() 59 | 60 | 61 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 62 | """ Wrapper for grid_sample, uses pixel coordinates """ 63 | H, W = img.shape[-2:] 64 | xgrid, ygrid = coords.split([1,1], dim=-1) 65 | xgrid = 2*xgrid/(W-1) - 1 66 | ygrid = 2*ygrid/(H-1) - 1 67 | 68 | grid = torch.cat([xgrid, ygrid], dim=-1) 69 | img = F.grid_sample(img, grid, align_corners=True) 70 | 71 | # Enable higher order grad for JR 72 | # img = grid_sample(img, grid) 73 | 74 | if mask: 75 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 76 | return img, mask.float() 77 | 78 | return img 79 | 80 | 81 | def coords_grid(batch, ht, wd, device): 82 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 83 | coords = torch.stack(coords[::-1], dim=0).float() 84 | return coords[None].repeat(batch, 1, 1, 1) 85 | 86 | 87 | def upflow8(flow, mode='bilinear'): 88 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 89 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 90 | -------------------------------------------------------------------------------- /code.v.2.0/log/val.txt: -------------------------------------------------------------------------------- 1 | Parameter Count: 13.395 M 2 | Load from 
checkpoints/deq-flow-H-things-test-1x.pth 3 | Validation KITTI: EPE: 3.763 (3.763), F1: 12.95 (12.95) 4 | VALID | FORWARD | rel: 0.0005664261989295483; abs: 3.7999179363250732; nstep: tensor([60], device='cuda:0') 5 | VALID | FORWARD | rel: 0.0002732036809902638; abs: 1.8421716690063477; nstep: tensor([60], device='cuda:0') 6 | Validation (clean) EPE: 1.266 (1.266), 1px: 91.50, 3px: 96.20, 5px: 97.24 7 | VALID | FORWARD | rel: 0.0008356361649930477; abs: 5.66709566116333; nstep: tensor([60], device='cuda:0') 8 | Validation (final) EPE: 2.584 (2.584), 1px: 86.48, 3px: 92.51, 5px: 94.21 9 | Parameter Count: 13.395 M 10 | Load from checkpoints/deq-flow-H-things-test-3x.pth 11 | Validation KITTI: EPE: 3.863 (3.863), F1: 13.52 (13.52) 12 | VALID | FORWARD | rel: 0.0006333217606879771; abs: 4.249484539031982; nstep: tensor([60], device='cuda:0') 13 | VALID | FORWARD | rel: 0.00028998314519412816; abs: 1.9555532932281494; nstep: tensor([60], device='cuda:0') 14 | Validation (clean) EPE: 1.270 (1.270), 1px: 91.87, 3px: 96.38, 5px: 97.38 15 | VALID | FORWARD | rel: 0.0012048647040501237; abs: 8.165122032165527; nstep: tensor([60], device='cuda:0') 16 | Validation (final) EPE: 2.500 (2.500), 1px: 86.74, 3px: 92.69, 5px: 94.43 17 | Parameter Count: 13.395 M 18 | Load from checkpoints/deq-flow-H-things-test-3x.pth 19 | Validation KITTI: EPE: 3.768 (3.768), F1: 13.41 (13.41) 20 | VALID | FORWARD | rel: 0.0006324453861452639; abs: 4.2437052726745605; nstep: tensor([120], device='cuda:0') 21 | VALID | FORWARD | rel: 0.0002864457492250949; abs: 1.931783676147461; nstep: tensor([120], device='cuda:0') 22 | Validation (clean) EPE: 1.275 (1.275), 1px: 91.86, 3px: 96.38, 5px: 97.39 23 | VALID | FORWARD | rel: 0.001575868227519095; abs: 10.679956436157227; nstep: tensor([120], device='cuda:0') 24 | Validation (final) EPE: 2.481 (2.481), 1px: 86.77, 3px: 92.74, 5px: 94.49 25 | -------------------------------------------------------------------------------- /code.v.2.0/train_H.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -u main.py --total_run 1 --start_run 1 --name deq-flow-H-naive-120k-C-36-6-1 \ 4 | --stage chairs --validation chairs kitti \ 5 | --gpus 0 1 2 --num_steps 120000 --eval_interval 20000 \ 6 | --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 \ 7 | --f_thres 36 --f_solver naive_solver \ 8 | --n_losses 6 --phantom_grad 1 \ 9 | --huge --wnorm 10 | 11 | python -u main.py --total_run 1 --start_run 1 --name deq-flow-H-naive-120k-T-40-2-3 \ 12 | --stage things --validation sintel kitti \ 13 | --restore_name deq-flow-H-naive-120k-C-36-6-1 \ 14 | --gpus 0 1 2 --num_steps 120000 \ 15 | --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 \ 16 | --f_thres 40 --f_solver naive_solver \ 17 | --n_losses 2 --phantom_grad 3 \ 18 | --huge --wnorm --all_grad 19 | 20 | -------------------------------------------------------------------------------- /code.v.2.0/train_H_1_step_grad.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -u main.py --total_run 1 --start_run 1 --name deq-flow-H-naive-120k-C-36-1-1 \ 4 | --stage chairs --validation chairs kitti \ 5 | --gpus 0 1 2 --num_steps 120000 --eval_interval 20000 \ 6 | --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 \ 7 | --f_thres 36 --f_solver naive_solver \ 8 | --n_losses 1 --phantom_grad 1 \ 9 | --huge --wnorm 10 | 11 | python -u main.py --total_run 1 --start_run 1 --name 
deq-flow-H-naive-120k-T-40-1-1 \ 12 | --stage things --validation sintel kitti --restore_name deq-flow-H-naive-120k-C-36-1-1 \ 13 | --gpus 0 1 2 --num_steps 120000 \ 14 | --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 \ 15 | --f_thres 40 --f_solver naive_solver \ 16 | --n_losses 1 --phantom_grad 1 \ 17 | --huge --wnorm --all_grad 18 | 19 | -------------------------------------------------------------------------------- /code.v.2.0/val.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -u main.py --eval --name deq-flow-H-all-grad --stage things \ 4 | --validation kitti sintel --restore_ckpt checkpoints/deq-flow-H-things-test-1x.pth --gpus 0 \ 5 | --wnorm --f_thres 40 --f_solver naive_solver \ 6 | --eval_factor 1.5 --huge 7 | 8 | python -u main.py --eval --name deq-flow-H-all-grad --stage things \ 9 | --validation kitti sintel --restore_ckpt checkpoints/deq-flow-H-things-test-3x.pth --gpus 0 \ 10 | --wnorm --f_thres 40 --f_solver naive_solver \ 11 | --eval_factor 1.5 --huge 12 | 13 | python -u main.py --eval --name deq-flow-H-all-grad --stage things \ 14 | --validation kitti sintel --restore_ckpt checkpoints/deq-flow-H-things-test-3x.pth --gpus 0 \ 15 | --wnorm --f_thres 40 --f_solver naive_solver \ 16 | --eval_factor 3.0 --huge 17 | 18 | 19 | -------------------------------------------------------------------------------- /code.v.2.0/viz.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('core') 4 | 5 | import argparse 6 | import copy 7 | import os 8 | import time 9 | 10 | import datasets 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | import cv2 17 | from PIL import Image 18 | 19 | from utils import flow_viz, frame_utils 20 | from utils.utils import InputPadder, forward_interpolate 21 | 22 | 23 | @torch.no_grad() 24 | def sintel_visualization(model, split='test', warm_start=False, fixed_point_reuse=False, output_path='sintel_viz', **kwargs): 25 | """ Create visualization for the Sintel dataset """ 26 | model.eval() 27 | for dstype in ['clean', 'final']: 28 | split = 'test' if split == 'test' else 'training' 29 | test_dataset = datasets.MpiSintel(split=split, aug_params=None, dstype=dstype) 30 | 31 | flow_prev, sequence_prev, fixed_point = None, None, None 32 | for test_id in range(len(test_dataset)): 33 | image1, image2, (sequence, frame) = test_dataset[test_id] 34 | if sequence != sequence_prev: 35 | flow_prev = None 36 | fixed_point = None 37 | 38 | padder = InputPadder(image1.shape) 39 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 40 | 41 | flow_low, flow_pr, info = model(image1, image2, flow_init=flow_prev, cached_result=fixed_point, **kwargs) 42 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 43 | 44 | if warm_start: 45 | flow_prev = forward_interpolate(flow_low[0])[None].cuda() 46 | 47 | if fixed_point_reuse: 48 | net, flow_pred_low = info['cached_result'] 49 | flow_pred_low = forward_interpolate(flow_pred_low[0])[None].cuda() 50 | fixed_point = (net, flow_pred_low) 51 | 52 | output_dir = os.path.join(output_path, dstype, sequence) 53 | output_file = os.path.join(output_dir, 'frame%04d.png' % (frame+1)) 54 | 55 | if not os.path.exists(output_dir): 56 | os.makedirs(output_dir) 57 | 58 | # visualization 59 | img_flow = flow_viz.flow_to_image(flow) 60 | img_flow = cv2.cvtColor(img_flow, cv2.COLOR_RGB2BGR) 61 |
cv2.imwrite(output_file, img_flow, [int(cv2.IMWRITE_PNG_COMPRESSION), 1]) 62 | 63 | sequence_prev = sequence 64 | 65 | 66 | @torch.no_grad() 67 | def kitti_visualization(model, split='test', output_path='kitti_viz'): 68 | """ Create visualization for the KITTI dataset """ 69 | model.eval() 70 | split = 'testing' if split == 'test' else 'training' 71 | test_dataset = datasets.KITTI(split=split, aug_params=None) 72 | 73 | if not os.path.exists(output_path): 74 | os.makedirs(output_path) 75 | 76 | for test_id in range(len(test_dataset)): 77 | image1, image2, (frame_id, ) = test_dataset[test_id] 78 | padder = InputPadder(image1.shape, mode='kitti') 79 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) 80 | 81 | _, flow_pr, _ = model(image1, image2) 82 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() 83 | 84 | output_filename = os.path.join(output_path, frame_id) 85 | 86 | # visualization 87 | img_flow = flow_viz.flow_to_image(flow) 88 | img_flow = cv2.cvtColor(img_flow, cv2.COLOR_RGB2BGR) 89 | cv2.imwrite(output_filename, img_flow, [int(cv2.IMWRITE_PNG_COMPRESSION), 1]) 90 | 91 | 92 | --------------------------------------------------------------------------------
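To make the utilities above easier to try in isolation, here is a short, self-contained sketch (not part of the repo). It uses random images and a synthetic flow field instead of a trained model, and assumes it is run from the code.v.2.0 directory so the core package is on the path, mirroring viz.py:

```python
import sys
sys.path.append('core')

import numpy as np
import torch

from utils import flow_viz
from utils.utils import InputPadder

# 1) Pad an image pair to a multiple of 8, as the network expects (Sintel-style symmetric padding).
img1 = torch.randn(1, 3, 436, 1024)
img2 = torch.randn(1, 3, 436, 1024)
padder = InputPadder(img1.shape)            # mode='sintel' by default
img1, img2 = padder.pad(img1, img2)
print(img1.shape)                           # torch.Size([1, 3, 440, 1024])

# 2) Colorize an (H, W, 2) flow field; a synthetic radial flow stands in for a model prediction.
u, v = np.meshgrid(np.linspace(-1, 1, 128), np.linspace(-1, 1, 128))
flow = np.stack([u, v], axis=-1).astype(np.float32)
rgb = flow_viz.flow_to_image(flow)          # uint8 RGB image (convert to BGR before cv2.imwrite, as viz.py does)
print(rgb.shape, rgb.dtype)                 # (128, 128, 3) uint8
```

padder.unpad inverts the padding on a prediction, which is how viz.py recovers flow at the original resolution before calling flow_to_image.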