├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── frame0037_pred.png
│   └── frame_0037_frame.png
├── code.v.1.0
│   ├── README.md
│   ├── alt_cuda_corr
│   │   ├── correlation.cpp
│   │   ├── correlation_kernel.cu
│   │   └── setup.py
│   ├── chairs_split.txt
│   ├── core
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── corr.py
│   │   ├── datasets.py
│   │   ├── deq.py
│   │   ├── deq_demo.py
│   │   ├── extractor.py
│   │   ├── gma.py
│   │   ├── lib
│   │   │   ├── grad.py
│   │   │   ├── jacobian.py
│   │   │   ├── layer_utils.py
│   │   │   ├── optimizations.py
│   │   │   └── solvers.py
│   │   ├── metrics.py
│   │   ├── update.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── augmentor.py
│   │       ├── flow_viz.py
│   │       ├── frame_utils.py
│   │       ├── grid_sample.py
│   │       └── utils.py
│   ├── evaluate.py
│   ├── main.py
│   ├── ref
│   │   ├── B_1_step_grad.txt
│   │   ├── H_1_step_grad.txt
│   │   └── val.txt
│   ├── train_B.sh
│   ├── train_B_demo.sh
│   ├── train_H_demo.sh
│   ├── train_H_full.sh
│   ├── val.sh
│   ├── viz.py
│   └── viz.sh
└── code.v.2.0
    ├── chairs_split.txt
    ├── core
    │   ├── __init__.py
    │   ├── corr.py
    │   ├── datasets.py
    │   ├── deq
    │   │   ├── __init__.py
    │   │   ├── arg_utils.py
    │   │   ├── deq_class.py
    │   │   ├── dropout.py
    │   │   ├── grad.py
    │   │   ├── jacobian.py
    │   │   ├── layer_utils.py
    │   │   ├── norm
    │   │   │   ├── __init__.py
    │   │   │   └── weight_norm.py
    │   │   └── solvers.py
    │   ├── deq_flow.py
    │   ├── extractor.py
    │   ├── gma.py
    │   ├── metrics.py
    │   ├── update.py
    │   └── utils
    │       ├── __init__.py
    │       ├── augmentor.py
    │       ├── flow_viz.py
    │       ├── frame_utils.py
    │       ├── grid_sample.py
    │       └── utils.py
    ├── evaluate.py
    ├── log
    │   └── val.txt
    ├── main.py
    ├── train_H.sh
    ├── train_H_1_step_grad.sh
    ├── val.sh
    └── viz.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib64/
18 | parts/
19 | sdist/
20 | var/
21 | wheels/
22 | pip-wheel-metadata/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 | db.sqlite3-journal
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # pipenv
87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
90 | # install all needed dependencies.
91 | #Pipfile.lock
92 |
93 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
94 | __pypackages__/
95 |
96 | # Celery stuff
97 | celerybeat-schedule
98 | celerybeat.pid
99 |
100 | # SageMath parsed files
101 | *.sage.py
102 |
103 | # Environments
104 | .env
105 | .venv
106 | env/
107 | venv/
108 | ENV/
109 | env.bak/
110 | venv.bak/
111 |
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 |
116 | # Rope project settings
117 | .ropeproject
118 |
119 | # mkdocs documentation
120 | /site
121 |
122 | # mypy
123 | .mypy_cache/
124 | .dmypy.json
125 | dmypy.json
126 |
127 | # Pyre type checker
128 | .pyre/
129 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Equilibrium Optical Flow Estimation
2 |
3 | [PapersWithCode SOTA: Optical Flow Estimation on KITTI 2015 (train)](https://paperswithcode.com/sota/optical-flow-estimation-on-kitti-2015-train?p=deep-equilibrium-optical-flow-estimation)
4 |
5 | (🌟Version 2.0 released!🌟)
6 |
7 | This is the official repo for the paper [*Deep Equilibrium Optical Flow Estimation*](https://arxiv.org/abs/2204.08442) (CVPR 2022), by [Shaojie Bai](https://jerrybai1995.github.io/)\*, [Zhengyang Geng](https://gsunshine.github.io/)\*, [Yash Savani](https://yashsavani.com/) and [J. Zico Kolter](http://zicokolter.com/).
8 |
9 | *(figure: input frame `assets/frame_0037_frame.png` and predicted flow `assets/frame0037_pred.png`)*
10 |
11 | > A deep equilibrium (DEQ) flow estimator directly models the flow as a path-independent, “infinite-level” fixed-point solving process. We propose to use this implicit framework to replace the existing recurrent approach to optical flow estimation. The DEQ flows converge faster, require less memory, are often more accurate, and are compatible with prior model designs (e.g., RAFT and GMA).
12 |
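To make the fixed-point view concrete, here is a minimal, self-contained sketch (illustrative only, not the repo's solver) contrasting a fixed number of unrolled recurrent steps with iterating the same weight-tied update `f` to an equilibrium `z* = f(z*, x)`:

```Python
import torch

def f(z, x):
    # Stand-in for the weight-tied update cell (the RAFT/GMA update block in DEQ-Flow).
    return torch.tanh(0.5 * z + x)

x = torch.randn(4)

# Recurrent view: a fixed number of unrolled iterations.
z = torch.zeros_like(x)
for _ in range(12):
    z = f(z, x)

# DEQ view: iterate (or use a root solver such as the Anderson/Broyden solvers shipped
# in this repo) until the update stops changing, i.e. z is approximately a fixed point.
z = torch.zeros_like(x)
for _ in range(100):
    z_next = f(z, x)
    if (z_next - z).norm() < 1e-6:
        break
    z = z_next
```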
13 | ## Demo
14 |
15 | We provide a demo video of the DEQ flow results below.
16 |
17 | https://user-images.githubusercontent.com/18630903/163676562-e14a433f-4c71-4994-8e3d-97b3c33d98ab.mp4
18 |
19 | ---
20 |
21 | ## Update
22 |
23 | 🌟 2022.xx.xx - Support for visualization and demos on your own datasets and videos is coming soon!
24 |
25 | 🌟 2022.08.08 - Release the **version 2.0** of DEQ-Flow! DEQ-Flow will be merged into [DEQ](https://github.com/locuslab/deq) after further upgrading and unit testing.
26 |
27 | - A clean and decoupled **[DEQ lib](https://github.com/locuslab/deq-flow/blob/main/code.v.2.0/core/deq)**. This is a fully featured and out-of-the-box lib. You're welcome to implement **your own DEQ** using our DEQ lib! We support the following features. (*The DEQ lib will be available on PyPI soon for easy installation via `pip`.*)
28 | - Automatic arg parser decorator. You can call this function to add the DEQ args to your program. See the explanation for args [here](https://github.com/locuslab/deq-flow/blob/main/code.v.2.0/core/deq/arg_utils.py)!
29 |
30 | ```Python
31 | add_deq_args(parser)
32 | ```
33 |
34 | - Automatic DEQ definition. Call `get_deq` to get your DEQ class! **It's a highly decoupled implementation, agnostic to your model design!**
35 |
36 | ```Python
37 | DEQ = get_deq(args)
38 | self.deq = DEQ(args)
39 | ```
40 |
41 | - Automatic normalization for DEQ. You no longer need to add normalization manually to each weight in the DEQ function!
42 |
43 | ```Python
44 | if args.wnorm:
45 | apply_weight_norm(self.update_block)
46 | ```
47 |
48 | - Easy DEQ forward. Even for a multi-equilibria system, you can call the DEQ function in just a few lines!
49 |
50 | ```Python
51 | # Assume args is a list [z1, z2, ..., zn]
52 | # of to-be-solved equilibrium variables.
53 | def func(*args):
54 | # A callable defined in the PyTorch forward function,
55 | # with the same input and output tensor shapes.
56 | return args
57 |
58 | deq_func = DEQWrapper(func, args)
59 | z_init = deq_func.list2vec(*args) # will be merged into self.deq(...)
60 | z_out, info = self.deq(deq_func, z_init)
61 | ```
62 |
63 | - Automatic DEQ training. Gradients (both exact and inexact) are tracked automatically! Fixed-point correction can be customized through your arg parser; just post-process `z_out` as you want (a minimal sketch follows).
64 |
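For example, a minimal sketch of that post-processing step, mirroring the decode-and-supervise pattern in `core/deq_demo.py`; `decode`, `seq_loss`, `flow_gt`, and `gamma` are placeholders for your own output head, loss, targets, and weighting:

```Python
# Hypothetical sketch: z_out is the list of corrected fixed-point estimates returned by self.deq(...).
# Decode every estimate and supervise it, weighting later (more accurate) estimates more heavily.
flow_predictions = [decode(z) for z in z_out]          # decode(...) is your own output head
loss = sum(
    gamma ** (len(flow_predictions) - 1 - i) * seq_loss(pred, flow_gt)
    for i, pred in enumerate(flow_predictions)
)
loss.backward()
```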
65 | - Benchmarked results and [checkpoints](https://drive.google.com/drive/folders/1a_eX_wYN1qTw2Rj1naEXhcsG4D3KKxFw?usp=sharing). Using the released code base v.2.0, we've trained DEQ-Flow-H on FlyingChairs and FlyingThings under two schedules, *120k+120k (1x)* and *120k+360k (3x)*. This implementation establishes a new SOTA, surpassing our previous results in performance, training speed, and memory usage.
66 |
67 | Notably, we also benchmark RAFT at the same model size. DEQ-Flow shows a clear performance and efficiency margin over RAFT, along with a **much stronger scaling property** (it scales up better to larger models)!
68 |
69 | | Checkpoint Name | Sintel (clean) | Sintel (final) | KITTI AEPE | KITTI F1-all |
70 | | :--------------: | :------------: | :------------: | :---------: | :----------: |
71 | | RAFT-H-1x | 1.36 | 2.59 | 4.47 | 16.16 |
72 | | DEQ-Flow-H-1x | 1.27 | 2.58 | 3.76 | 12.95 |
73 | | DEQ-Flow-H-3x | 1.27 | 2.48 | 3.77 | 13.41 |
74 |
75 | - 1x=120k iterations on FlyingThings, 3x=360k iterations on FlyingThings, using a batch size of 6.
76 | - Increasing the batch size on FlyingThings can further improve these results, e.g., a batch size of 12 can reduce the F1-all of DEQ-Flow-H-1x to around 12.5 on KITTI.
77 |
78 | To validate our results, download the pretrained [checkpoints](https://drive.google.com/drive/folders/1a_eX_wYN1qTw2Rj1naEXhcsG4D3KKxFw?usp=sharing) into the `checkpoints` directory. Run the following command in [code.v.2.0](https://github.com/locuslab/deq-flow/blob/main/code.v.2.0/) to infer over the Sintel train set and the KITTI train set. This is a reference [log](https://github.com/locuslab/deq-flow/blob/main/code.v.2.0/log/val.txt).
79 |
80 | ```bash
81 | bash val.sh
82 | ```
83 |
84 | ---
85 |
86 | ## Requirements
87 |
88 | The code in this repo has been tested on PyTorch v1.10.0. Install the required environment with the following commands.
89 |
90 | ```bash
91 | conda create --name deq python==3.6.10
92 | conda activate deq
93 | conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 cudatoolkit=11.3 -c pytorch -c conda-forge
94 | conda install tensorboard scipy opencv matplotlib einops termcolor -c conda-forge
95 | ```
96 |
97 | Download the following datasets into the `datasets` directory (the expected layout is sketched after the list).
98 |
99 | - [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
100 | - [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
101 | - [MPI Sintel](http://sintel.is.tue.mpg.de/)
102 | - [KITTI 2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
103 | - [HD1k](http://hci-benchmark.iwr.uni-heidelberg.de/)
104 |
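The data loaders in `core/datasets.py` look for roughly the following layout under `datasets`; the top-level names come from the dataset roots used in that file, while the exact contents of each dataset follow the official downloads:

```
datasets
├── FlyingChairs_release
│   └── data
├── FlyingThings3D
│   ├── frames_cleanpass
│   ├── frames_finalpass
│   └── optical_flow
├── Sintel
│   ├── training
│   └── test
├── KITTI
│   ├── training
│   └── testing
└── HD1k
    ├── hd1k_input
    └── hd1k_flow_gt
```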
105 | ---
106 | The following README is for version 1.0, i.e., [code.v.1.0](https://github.com/locuslab/deq-flow/blob/main/code.v.1.0/). You can follow it to reproduce all the results.
107 |
108 | ## Inference
109 |
110 | Download the pretrained [checkpoints](https://drive.google.com/drive/folders/1PeyOr4kmSuMWrh4iwYKbVLqDU6WPX-HM?usp=sharing) into the `checkpoints` directory. Run the following command to infer over the Sintel train set and the KITTI train set.
111 |
112 | ```bash
113 | bash val.sh
114 | ```
115 |
116 | You may expect the following performance statistics for the given checkpoints. This is a reference [log](https://github.com/locuslab/deq-flow/blob/main/code.v.1.0/ref/val.txt).
117 |
118 | | Checkpoint Name | Sintel (clean) | Sintel (final) | KITTI AEPE | KITTI F1-all |
119 | | :--------------: | :------------: | :------------: | :---------: | :----------: |
120 | | DEQ-Flow-B | 1.43 | 2.79 | 5.43 | 16.67 |
121 | | DEQ-Flow-H-1 | 1.45 | 2.58 | 3.97 | 13.41 |
122 | | DEQ-Flow-H-2 | 1.37 | 2.62 | 3.97 | 13.62 |
123 | | DEQ-Flow-H-3 | 1.36 | 2.62 | 4.02 | 13.92 |
124 |
125 | ## Visualization
126 |
127 | Download the pretrained [checkpoints](https://drive.google.com/drive/folders/1PeyOr4kmSuMWrh4iwYKbVLqDU6WPX-HM?usp=sharing) into the `checkpoints` directory. Run the following command to visualize the optical flow estimation over the KITTI test set.
128 |
129 | ```bash
130 | bash viz.sh
131 | ```
132 |
133 | ## Training
134 |
135 | Download *FlyingChairs*-pretrained [checkpoints](https://drive.google.com/drive/folders/1PeyOr4kmSuMWrh4iwYKbVLqDU6WPX-HM?usp=sharing) into the `checkpoints` directory.
136 |
137 | For the efficiency mode, you can run 1-step gradient to train DEQ-Flow-B via the following command. Memory overhead per GPU is about 5800 MB.
138 |
139 | You may expect best results of about 1.46 (AEPE) on Sintel (clean), 2.85 (AEPE) on Sintel (final), 5.29 (AEPE) and 16.24 (F1-all) on KITTI. This is a reference [log](https://github.com/locuslab/deq-flow/blob/main/code.v.1.0/ref/B_1_step_grad.txt).
140 |
141 | ```bash
142 | bash train_B_demo.sh
143 | ```
144 |
145 | For training a demo of DEQ-Flow-H, you can run this command. Memory overhead per GPU is about 6300 MB. It can be further reduced to about **4200 MB** per GPU when combined with `--mixed-precision`. You can further reduce the memory cost if you employ the CUDA implementation of the cost volume by [RAFT](https://github.com/princeton-vl/RAFT).
146 |
147 | You may expect best results of about 1.41 (AEPE) on Sintel (clean), 2.76 (AEPE) on Sintel (final), 4.44 (AEPE) and 14.81 (F1-all) on KITTI. This is a reference [log](https://github.com/locuslab/deq-flow/blob/main/code.v.1.0/ref/H_1_step_grad.txt).
148 |
149 | ```bash
150 | bash train_H_demo.sh
151 | ```
152 |
153 | To train DEQ-Flow-B on Chairs and Things, use the following command.
154 |
155 | ```bash
156 | bash train_B.sh
157 | ```
158 |
159 | For the performance mode, you can run this command to train DEQ-Flow-H using the ``C+T`` and ``C+T+S+K+H`` schedules. You may expect performance of below 1.40 (AEPE) on Sintel (clean), around 2.60 (AEPE) on Sintel (final), and around 4.00 (AEPE) with 13.6 (F1-all) on KITTI. DEQ-Flow-H-1, 2, and 3 are checkpoints from three independent runs.
160 |
161 | Currently, this training protocol requires slightly more than two 11 GB GPUs. In the near future, we will upload a revised implementation of the DEQ models that further reduces this overhead to **less than two 11 GB GPUs**.
162 |
163 | ```bash
164 | bash train_H_full.sh
165 | ```
166 |
167 | ## Code Usage
168 |
169 | Under construction. We will provide more detailed instructions on the code usage (e.g., argparse flags, fixed-point solvers, backward IFT modes) in the coming days. A tentative sketch of the DEQ-related options is given below.
170 |
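Until then, here is a tentative sketch of the DEQ-related knobs. Attribute names are taken from the `args.*` accesses in `core/deq_demo.py`; the values shown are illustrative placeholders, not the repo's actual defaults.

```Python
# Hypothetical sketch of the DEQ-related options read by code.v.1.0/core/deq_demo.py.
from types import SimpleNamespace

args = SimpleNamespace(
    mixed_precision=False,   # autocast for the feature/context/update networks
    dropout=0.0,             # encoder dropout
    wnorm=True,              # apply weight normalization to the update block
    f_solver='anderson',     # forward fixed-point solver (anderson or broyden, see lib/solvers.py)
    f_thres=36,              # forward solver iteration budget
    eval_factor=1.5,         # eval_f_thres = int(f_thres * eval_factor)
    stop_mode='abs',         # 'abs' or 'rel' convergence criterion
    n_losses=1,              # number of evenly spaced fixed-point correction terms
    indexing=[],             # explicit correction indices (used when n_losses == 1)
    phantom_grad=[1],        # inexact (phantom) gradient steps per correction term
    tau=1.0,                 # damping passed to the phantom-gradient backward factory
    sup_all=False,           # supervise all phantom-grad steps
    ift=False,               # use the implicit function theorem for the last gradient
    safe_ift=False,          # safer IFT variant
    b_solver='broyden',      # backward solver used with IFT
    b_thres=36,              # backward solver iteration budget
)
```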
171 | ---
172 |
173 | ## A Tutorial on DEQ
174 |
175 | If you hope to learn more about DEQ models, here is an official NeurIPS [tutorial](https://implicit-layers-tutorial.org/) on implicit deep learning. Enjoy yourself!
176 |
177 | ## Reference
178 |
179 | If you find our work helpful to your research, please consider citing this paper. :)
180 |
181 | ```bib
182 | @inproceedings{deq-flow,
183 | author = {Bai, Shaojie and Geng, Zhengyang and Savani, Yash and Kolter, J. Zico},
184 | title = {Deep Equilibrium Optical Flow Estimation},
185 | booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
186 | year = {2022}
187 | }
188 | ```
189 |
190 | ## Credit
191 |
192 | A lot of the utility code in this repo was adapted from the [RAFT](https://github.com/princeton-vl/RAFT) repo and the [DEQ](https://github.com/locuslab/deq) repo.
193 |
194 | ## Contact
195 |
196 | Feel free to contact us if you have additional questions. Please drop an email to zhengyanggeng@gmail.com (or reach out on [Twitter](https://twitter.com/ZhengyangGeng)).
--------------------------------------------------------------------------------
/assets/frame0037_pred.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/deq-flow/869d4bcd9c1d227cbfa16639994d1b118510a9bb/assets/frame0037_pred.png
--------------------------------------------------------------------------------
/assets/frame_0037_frame.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/deq-flow/869d4bcd9c1d227cbfa16639994d1b118510a9bb/assets/frame_0037_frame.png
--------------------------------------------------------------------------------
/code.v.1.0/README.md:
--------------------------------------------------------------------------------
1 | # Deep Equilibrium Optical Flow Estimation
2 |
3 | [PapersWithCode SOTA: Optical Flow Estimation on KITTI 2015 (train)](https://paperswithcode.com/sota/optical-flow-estimation-on-kitti-2015-train?p=deep-equilibrium-optical-flow-estimation)
4 |
5 | This is the official repo for the paper [*Deep Equilibrium Optical Flow Estimation*](https://arxiv.org/abs/2204.08442) (CVPR 2022), by [Shaojie Bai](https://jerrybai1995.github.io/)\*, [Zhengyang Geng](https://gsunshine.github.io/)\*, [Yash Savani](https://yashsavani.com/) and [J. Zico Kolter](http://zicokolter.com/).
6 |
7 | *(figure: input frame `assets/frame_0037_frame.png` and predicted flow `assets/frame0037_pred.png`)*
8 |
9 | > A deep equilibrium (DEQ) flow estimator directly models the flow as a path-independent, “infinite-level” fixed-point solving process. We propose to use this implicit framework to replace the existing recurrent approach to optical flow estimation. The DEQ flows converge faster, require less memory, are often more accurate, and are compatible with prior model designs (e.g., RAFT and GMA).
10 |
11 | ## Demo
12 |
13 | We provide a demo video of the DEQ flow results below.
14 |
15 | https://user-images.githubusercontent.com/18630903/163676562-e14a433f-4c71-4994-8e3d-97b3c33d98ab.mp4
16 |
17 | ## Requirements
18 |
19 | The code in this repo has been tested on PyTorch v1.10.0. Install the required environment with the following commands.
20 |
21 | ```bash
22 | conda create --name deq python==3.6.10
23 | conda activate deq
24 | conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 cudatoolkit=11.3 -c pytorch -c conda-forge
25 | conda install tensorboard scipy opencv matplotlib einops termcolor -c conda-forge
26 | ```
27 |
28 | Download the following datasets into the `datasets` directory.
29 |
30 | - [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
31 | - [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
32 | - [MPI Sintel](http://sintel.is.tue.mpg.de/)
33 | - [KITTI 2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
34 | - [HD1k](http://hci-benchmark.iwr.uni-heidelberg.de/)
35 |
36 | ## Inference
37 |
38 | Download the pretrained [checkpoints](https://drive.google.com/drive/folders/1PeyOr4kmSuMWrh4iwYKbVLqDU6WPX-HM?usp=sharing) into the `checkpoints` directory. Run the following command to infer over the Sintel train set and the KITTI train set.
39 |
40 | ```bash
41 | bash val.sh
42 | ```
43 |
44 | You may expect the following performance statistics for the given checkpoints. This is a reference [log](https://github.com/locuslab/deq-flow/blob/main/code.v.1.0/ref/val.txt).
45 |
46 | | Checkpoint Name | Sintel (clean) | Sintel (final) | KITTI AEPE | KITTI F1-all |
47 | | :--------------: | :------------: | :------------: | :---------: | :----------: |
48 | | DEQ-Flow-B | 1.43 | 2.79 | 5.43 | 16.67 |
49 | | DEQ-Flow-H-1 | 1.45 | 2.58 | 3.97 | 13.41 |
50 | | DEQ-Flow-H-2 | 1.37 | 2.62 | 3.97 | 13.62 |
51 | | DEQ-Flow-H-3 | 1.36 | 2.62 | 4.02 | 13.92 |
52 |
53 | ## Visualization
54 |
55 | Download the pretrained [checkpoints](https://drive.google.com/drive/folders/1PeyOr4kmSuMWrh4iwYKbVLqDU6WPX-HM?usp=sharing) into the `checkpoints` directory. Run the following command to visualize the optical flow estimation over the KITTI test set.
56 |
57 | ```bash
58 | bash viz.sh
59 | ```
60 |
61 | ## Training
62 |
63 | Download *FlyingChairs*-pretrained [checkpoints](https://drive.google.com/drive/folders/1PeyOr4kmSuMWrh4iwYKbVLqDU6WPX-HM?usp=sharing) into the `checkpoints` directory.
64 |
65 | For the efficiency mode, you can run 1-step gradient to train DEQ-Flow-B via the following command. Memory overhead per GPU is about 5800 MB.
66 |
67 | You may expect best results of about 1.46 (AEPE) on Sintel (clean), 2.85 (AEPE) on Sintel (final), 5.29 (AEPE) and 16.24 (F1-all) on KITTI. This is a reference [log](https://github.com/locuslab/deq-flow/blob/main/code.v.1.0/ref/B_1_step_grad.txt).
68 |
69 | ```bash
70 | bash train_B_demo.sh
71 | ```
72 |
73 | For training a demo of DEQ-Flow-H, you can run this command. Memory overhead per GPU is about 6300 MB. It can be further reduced to about **4200 MB** per GPU when combined with `--mixed-precision`. You can further reduce the memory cost if you employ the CUDA implementation of the cost volume by [RAFT](https://github.com/princeton-vl/RAFT).
74 |
75 | You may expect best results of about 1.41 (AEPE) on Sintel (clean), 2.76 (AEPE) on Sintel (final), 4.44 (AEPE) and 14.81 (F1-all) on KITTI. This is a reference [log](https://github.com/locuslab/deq-flow/blob/main/code.v.1.0/ref/H_1_step_grad.txt).
76 |
77 | ```bash
78 | bash train_H_demo.sh
79 | ```
80 |
81 | To train DEQ-Flow-B on Chairs and Things, use the following command.
82 |
83 | ```bash
84 | bash train_B.sh
85 | ```
86 |
87 | For the performance mode, you can run this command to train DEQ-Flow-H using the ``C+T`` and ``C+T+S+K+H`` schedules. You may expect performance of below 1.40 (AEPE) on Sintel (clean), around 2.60 (AEPE) on Sintel (final), and around 4.00 (AEPE) with 13.6 (F1-all) on KITTI. DEQ-Flow-H-1, 2, and 3 are checkpoints from three independent runs.
88 |
89 | Currently, this training protocol requires slightly more than two 11 GB GPUs. In the near future, we will upload a revised implementation of the DEQ models that further reduces this overhead to **less than two 11 GB GPUs**.
90 |
91 | ```bash
92 | bash train_H_full.sh
93 | ```
94 |
95 | ## Code Usage
96 |
97 | Under construction. We will provide more detailed instructions on the code usage (e.g., argparse flags, fixed-point solvers, backward IFT modes) in the coming days.
98 |
99 | ## A Tutorial on DEQ
100 |
101 | If you hope to learn more about DEQ models, here is an official NeurIPS [tutorial](https://implicit-layers-tutorial.org/) on implicit deep learning. Enjoy yourself!
102 |
103 | ## Reference
104 |
105 | If you find our work helpful to your research, please consider citing this paper. :)
106 |
107 | ```bib
108 | @inproceedings{deq-flow,
109 | author = {Bai, Shaojie and Geng, Zhengyang and Savani, Yash and Kolter, J. Zico},
110 | title = {Deep Equilibrium Optical Flow Estimation},
111 | booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
112 | year = {2022}
113 | }
114 | ```
115 |
116 | ## Credit
117 |
118 | A lot of the utility code in this repo was adapted from the [RAFT](https://github.com/princeton-vl/RAFT) repo and the [DEQ](https://github.com/locuslab/deq) repo.
119 |
120 | ## Contact
121 |
122 | Feel free to contact us if you have additional questions. Please drop an email to zhengyanggeng@gmail.com (or reach out on [Twitter](https://twitter.com/ZhengyangGeng)).
123 |
--------------------------------------------------------------------------------
/code.v.1.0/alt_cuda_corr/correlation.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 | // CUDA forward declarations
5 | std::vector<torch::Tensor> corr_cuda_forward(
6 | torch::Tensor fmap1,
7 | torch::Tensor fmap2,
8 | torch::Tensor coords,
9 | int radius);
10 |
11 | std::vector<torch::Tensor> corr_cuda_backward(
12 | torch::Tensor fmap1,
13 | torch::Tensor fmap2,
14 | torch::Tensor coords,
15 | torch::Tensor corr_grad,
16 | int radius);
17 |
18 | // C++ interface
19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
22 |
23 | std::vector<torch::Tensor> corr_forward(
24 | torch::Tensor fmap1,
25 | torch::Tensor fmap2,
26 | torch::Tensor coords,
27 | int radius) {
28 | CHECK_INPUT(fmap1);
29 | CHECK_INPUT(fmap2);
30 | CHECK_INPUT(coords);
31 |
32 | return corr_cuda_forward(fmap1, fmap2, coords, radius);
33 | }
34 |
35 |
36 | std::vector<torch::Tensor> corr_backward(
37 | torch::Tensor fmap1,
38 | torch::Tensor fmap2,
39 | torch::Tensor coords,
40 | torch::Tensor corr_grad,
41 | int radius) {
42 | CHECK_INPUT(fmap1);
43 | CHECK_INPUT(fmap2);
44 | CHECK_INPUT(coords);
45 | CHECK_INPUT(corr_grad);
46 |
47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
48 | }
49 |
50 |
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52 | m.def("forward", &corr_forward, "CORR forward");
53 | m.def("backward", &corr_backward, "CORR backward");
54 | }
--------------------------------------------------------------------------------
/code.v.1.0/alt_cuda_corr/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 |
5 | setup(
6 | name='correlation',
7 | ext_modules=[
8 | CUDAExtension('alt_cuda_corr',
9 | sources=['correlation.cpp', 'correlation_kernel.cu'],
10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
11 | ],
12 | cmdclass={
13 | 'build_ext': BuildExtension
14 | })
15 |
16 |
--------------------------------------------------------------------------------
/code.v.1.0/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/deq-flow/869d4bcd9c1d227cbfa16639994d1b118510a9bb/code.v.1.0/core/__init__.py
--------------------------------------------------------------------------------
/code.v.1.0/core/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from einops import rearrange
4 |
5 | from lib.optimizations import weight_norm
6 |
7 |
8 | class Attention(nn.Module):
9 | def __init__(
10 | self,
11 | dim,
12 | heads = 4,
13 | dim_head = 128,
14 | ):
15 | super().__init__()
16 | self.heads = heads
17 | self.scale = dim_head ** -0.5
18 | inner_dim = heads * dim_head
19 |
20 | self.to_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
21 | self.to_k = nn.Conv2d(dim, inner_dim, 1, bias=False)
22 | self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False)
23 |
24 | self.gamma = nn.Parameter(torch.zeros(1))
25 | self.out = nn.Conv2d(inner_dim, dim, 1, bias=False)
26 |
27 | def _wnorm(self):
28 | self.to_q, self.to_q_fn = weight_norm(module=self.to_q, names=['weight'], dim=0)
29 | self.to_k, self.to_k_fn = weight_norm(module=self.to_k, names=['weight'], dim=0)
30 | self.to_v, self.to_v_fn = weight_norm(module=self.to_v, names=['weight'], dim=0)
31 |
32 | self.out, self.out_fn = weight_norm(module=self.out, names=['weight'], dim=0)
33 |
34 | def reset(self):
35 | for name in ['to_q', 'to_k', 'to_v', 'out']:
36 | if name + '_fn' in self.__dict__:
37 | eval(f'self.{name}_fn').reset(eval(f'self.{name}'))
38 |
39 | def forward(self, q, k, v):
40 | heads, b, c, h, w = self.heads, *v.shape
41 |
42 | input_q = q
43 | q = self.to_q(q)
44 | k = self.to_k(k)
45 | v = self.to_v(v)
46 |
47 | q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k))
48 | v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads)
49 |
50 | sim = self.scale * einsum('b h x y d, b h u v d -> b h x y u v', q, k)
51 | sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)')
52 | attn = sim.softmax(dim=-1)
53 |
54 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
55 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
56 |
57 | out = self.out(out)
58 | out = input_q + self.gamma * out
59 |
60 | return out
61 |
62 |
63 | if __name__ == "__main__":
64 | att = Attention(dim=128, heads=1)
65 | x = torch.randn(2, 128, 40, 90)
66 | out = att(x, x, x)
67 |
68 | print(out.shape)
69 |
--------------------------------------------------------------------------------
/code.v.1.0/core/corr.py:
--------------------------------------------------------------------------------
1 | # This code was originally taken from RAFT without modification
2 | # https://github.com/princeton-vl/RAFT/blob/master/core/corr.py
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from utils.utils import bilinear_sampler, coords_grid
7 |
8 | try:
9 | import alt_cuda_corr
10 | except:
11 | # alt_cuda_corr is not compiled
12 | pass
13 |
14 |
15 | class CorrBlock:
16 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
17 | self.num_levels = num_levels
18 | self.radius = radius
19 | self.corr_pyramid = []
20 |
21 | # all pairs correlation
22 | corr = CorrBlock.corr(fmap1, fmap2)
23 |
24 | batch, h1, w1, dim, h2, w2 = corr.shape
25 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
26 |
27 | self.corr_pyramid.append(corr)
28 | for i in range(self.num_levels-1):
29 | corr = F.avg_pool2d(corr, 2, stride=2)
30 | self.corr_pyramid.append(corr)
31 |
32 | def __call__(self, coords):
33 | r = self.radius
34 | coords = coords.permute(0, 2, 3, 1)
35 | batch, h1, w1, _ = coords.shape
36 |
37 | out_pyramid = []
38 | for i in range(self.num_levels):
39 | corr = self.corr_pyramid[i]
40 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
41 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
42 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
43 |
44 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
45 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
46 | coords_lvl = centroid_lvl + delta_lvl
47 |
48 | corr = bilinear_sampler(corr, coords_lvl)
49 | corr = corr.view(batch, h1, w1, -1)
50 | out_pyramid.append(corr)
51 |
52 | out = torch.cat(out_pyramid, dim=-1)
53 | return out.permute(0, 3, 1, 2).contiguous().float()
54 |
55 | @staticmethod
56 | def corr(fmap1, fmap2):
57 | batch, dim, ht, wd = fmap1.shape
58 | fmap1 = fmap1.view(batch, dim, ht*wd)
59 | fmap2 = fmap2.view(batch, dim, ht*wd)
60 |
61 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
62 | corr = corr.view(batch, ht, wd, 1, ht, wd)
63 | return corr / torch.sqrt(torch.tensor(dim).float())
64 |
65 |
66 | class ConstCorrBlock:
67 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
68 | self.num_levels = num_levels
69 | self.radius = radius
70 | self.corr_pyramid = []
71 |
72 | # all pairs correlation
73 | corr = CorrBlock.corr(fmap1, fmap2)
74 |
75 | batch, h1, w1, dim, h2, w2 = corr.shape
76 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
77 |
78 | self.corr_pyramid.append(corr.view(batch, h1*w1, dim, h2, w2))
79 | for i in range(self.num_levels-1):
80 | corr = F.avg_pool2d(corr, 2, stride=2)
81 | self.corr_pyramid.append(corr.view(batch, h1*w1, *corr.shape[1:]))
82 |
83 | def __call__(self, coords, corr_pyramid=None):
84 | r = self.radius
85 | coords = coords.permute(0, 2, 3, 1)
86 | batch, h1, w1, _ = coords.shape
87 |
88 | corr_pyramid = corr_pyramid if corr_pyramid else self.corr_pyramid
89 |
90 | out_pyramid = []
91 | for i in range(self.num_levels):
92 | corr = corr_pyramid[i].flatten(0, 1)
93 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
94 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
95 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
96 |
97 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
98 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
99 | coords_lvl = centroid_lvl + delta_lvl
100 |
101 | corr = bilinear_sampler(corr, coords_lvl)
102 | corr = corr.view(batch, h1, w1, -1)
103 | out_pyramid.append(corr)
104 |
105 | out = torch.cat(out_pyramid, dim=-1)
106 | return out.permute(0, 3, 1, 2).contiguous().float()
107 |
108 | @staticmethod
109 | def corr(fmap1, fmap2):
110 | batch, dim, ht, wd = fmap1.shape
111 | fmap1 = fmap1.view(batch, dim, ht*wd)
112 | fmap2 = fmap2.view(batch, dim, ht*wd)
113 |
114 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
115 | corr = corr.view(batch, ht, wd, 1, ht, wd)
116 | return corr / torch.sqrt(torch.tensor(dim).float())
117 |
118 |
119 | class AlternateCorrBlock:
120 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
121 | self.num_levels = num_levels
122 | self.radius = radius
123 |
124 | self.pyramid = [(fmap1, fmap2)]
125 | for i in range(self.num_levels):
126 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
127 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
128 | self.pyramid.append((fmap1, fmap2))
129 |
130 | def __call__(self, coords):
131 | coords = coords.permute(0, 2, 3, 1)
132 | B, H, W, _ = coords.shape
133 | dim = self.pyramid[0][0].shape[1]
134 |
135 | corr_list = []
136 | for i in range(self.num_levels):
137 | r = self.radius
138 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
139 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
140 |
141 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
142 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
143 | corr_list.append(corr.squeeze(1))
144 |
145 | corr = torch.stack(corr_list, dim=1)
146 | corr = corr.reshape(B, -1, H, W)
147 | return corr / torch.sqrt(torch.tensor(dim).float())
148 |
--------------------------------------------------------------------------------
/code.v.1.0/core/datasets.py:
--------------------------------------------------------------------------------
1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2 |
3 | import numpy as np
4 | import torch
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 |
8 | import os
9 | import math
10 | import random
11 | from glob import glob
12 | import os.path as osp
13 |
14 | from utils import frame_utils
15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16 |
17 |
18 | class FlowDataset(data.Dataset):
19 | def __init__(self, aug_params=None, sparse=False):
20 | self.augmentor = None
21 | self.sparse = sparse
22 | if aug_params is not None:
23 | if sparse:
24 | self.augmentor = SparseFlowAugmentor(**aug_params)
25 | else:
26 | self.augmentor = FlowAugmentor(**aug_params)
27 |
28 | self.is_test = False
29 | self.init_seed = False
30 | self.flow_list = []
31 | self.image_list = []
32 | self.extra_info = []
33 |
34 | def __getitem__(self, index):
35 |
36 | if self.is_test:
37 | img1 = frame_utils.read_gen(self.image_list[index][0])
38 | img2 = frame_utils.read_gen(self.image_list[index][1])
39 | img1 = np.array(img1).astype(np.uint8)[..., :3]
40 | img2 = np.array(img2).astype(np.uint8)[..., :3]
41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
43 | return img1, img2, self.extra_info[index]
44 |
45 | if not self.init_seed:
46 | worker_info = torch.utils.data.get_worker_info()
47 | if worker_info is not None:
48 | torch.manual_seed(worker_info.id)
49 | np.random.seed(worker_info.id)
50 | random.seed(worker_info.id)
51 | self.init_seed = True
52 |
53 | index = index % len(self.image_list)
54 | valid = None
55 | if self.sparse:
56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
57 | else:
58 | flow = frame_utils.read_gen(self.flow_list[index])
59 |
60 | img1 = frame_utils.read_gen(self.image_list[index][0])
61 | img2 = frame_utils.read_gen(self.image_list[index][1])
62 |
63 | flow = np.array(flow).astype(np.float32)
64 | img1 = np.array(img1).astype(np.uint8)
65 | img2 = np.array(img2).astype(np.uint8)
66 |
67 | # grayscale images
68 | if len(img1.shape) == 2:
69 | img1 = np.tile(img1[...,None], (1, 1, 3))
70 | img2 = np.tile(img2[...,None], (1, 1, 3))
71 | else:
72 | img1 = img1[..., :3]
73 | img2 = img2[..., :3]
74 |
75 | if self.augmentor is not None:
76 | if self.sparse:
77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
78 | else:
79 | img1, img2, flow = self.augmentor(img1, img2, flow)
80 |
81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float()
84 |
85 | if valid is not None:
86 | valid = torch.from_numpy(valid)
87 | else:
88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
89 |
90 | return img1, img2, flow, valid.float()
91 |
92 |
93 | def __rmul__(self, v):
94 | self.flow_list = v * self.flow_list
95 | self.image_list = v * self.image_list
96 | return self
97 |
98 | def __len__(self):
99 | return len(self.image_list)
100 |
101 |
102 | class MpiSintel(FlowDataset):
103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
104 | super(MpiSintel, self).__init__(aug_params)
105 | flow_root = osp.join(root, split, 'flow')
106 | image_root = osp.join(root, split, dstype)
107 |
108 | if split == 'test':
109 | self.is_test = True
110 |
111 | for scene in os.listdir(image_root):
112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
113 | for i in range(len(image_list)-1):
114 | self.image_list += [ [image_list[i], image_list[i+1]] ]
115 | self.extra_info += [ (scene, i) ] # scene and frame_id
116 |
117 | if split != 'test':
118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
119 |
120 |
121 | class FlyingChairs(FlowDataset):
122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
123 | super(FlyingChairs, self).__init__(aug_params)
124 |
125 | images = sorted(glob(osp.join(root, '*.ppm')))
126 | flows = sorted(glob(osp.join(root, '*.flo')))
127 | assert (len(images)//2 == len(flows))
128 |
129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
130 | for i in range(len(flows)):
131 | xid = split_list[i]
132 | if (split=='training' and xid==1) or (split=='validation' and xid==2):
133 | self.flow_list += [ flows[i] ]
134 | self.image_list += [ [images[2*i], images[2*i+1]] ]
135 |
136 |
137 | class FlyingThings3D(FlowDataset):
138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', split='train', dstype='frames_cleanpass'):
139 | super(FlyingThings3D, self).__init__(aug_params)
140 |
141 | if split == 'train':
142 | dir_prefix = 'TRAIN'
143 | elif split == 'test':
144 | dir_prefix = 'TEST'
145 | else:
146 | raise ValueError('Unknown split for FlyingThings3D.')
147 |
148 | for cam in ['left']:
149 | for direction in ['into_future', 'into_past']:
150 | image_dirs = sorted(glob(osp.join(root, dstype, f'{dir_prefix}/*/*')))
151 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
152 |
153 | flow_dirs = sorted(glob(osp.join(root, f'optical_flow/{dir_prefix}/*/*')))
154 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
155 |
156 | for idir, fdir in zip(image_dirs, flow_dirs):
157 | images = sorted(glob(osp.join(idir, '*.png')) )
158 | flows = sorted(glob(osp.join(fdir, '*.pfm')) )
159 | for i in range(len(flows)-1):
160 | if direction == 'into_future':
161 | self.image_list += [ [images[i], images[i+1]] ]
162 | self.flow_list += [ flows[i] ]
163 | elif direction == 'into_past':
164 | self.image_list += [ [images[i+1], images[i]] ]
165 | self.flow_list += [ flows[i+1] ]
166 |
167 |
168 | class KITTI(FlowDataset):
169 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
170 | super(KITTI, self).__init__(aug_params, sparse=True)
171 | if split == 'testing':
172 | self.is_test = True
173 |
174 | root = osp.join(root, split)
175 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
176 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
177 |
178 | for img1, img2 in zip(images1, images2):
179 | frame_id = img1.split('/')[-1]
180 | self.extra_info += [ [frame_id] ]
181 | self.image_list += [ [img1, img2] ]
182 |
183 | if split == 'training':
184 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
185 |
186 |
187 | class HD1K(FlowDataset):
188 | def __init__(self, aug_params=None, root='datasets/HD1k'):
189 | super(HD1K, self).__init__(aug_params, sparse=True)
190 |
191 | seq_ix = 0
192 | while 1:
193 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
194 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
195 |
196 | if len(flows) == 0:
197 | break
198 |
199 | for i in range(len(flows)-1):
200 | self.flow_list += [flows[i]]
201 | self.image_list += [ [images[i], images[i+1]] ]
202 |
203 | seq_ix += 1
204 |
205 |
206 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
207 | """ Create the data loader for the corresponding training set """
208 |
209 | if args.stage == 'chairs':
210 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
211 | train_dataset = FlyingChairs(aug_params, split='training')
212 |
213 | elif args.stage == 'things':
214 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
215 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
216 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
217 | train_dataset = clean_dataset + final_dataset
218 |
219 | elif args.stage == 'sintel':
220 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
221 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
222 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
223 | sintel_final = MpiSintel(aug_params, split='training', dstype='final')
224 |
225 | if TRAIN_DS == 'C+T+K+S+H':
226 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
227 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
228 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
229 |
230 | elif TRAIN_DS == 'C+T+K/S':
231 | train_dataset = 100*sintel_clean + 100*sintel_final + things
232 |
233 | elif args.stage == 'kitti':
234 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
235 | train_dataset = KITTI(aug_params, split='training')
236 |
237 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
238 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
239 |
240 | print('Training with %d image pairs' % len(train_dataset))
241 | return train_loader
242 |
243 |
--------------------------------------------------------------------------------
/code.v.1.0/core/deq_demo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from update import BasicUpdateBlock
7 | from extractor import BasicEncoder
8 | from corr import CorrBlock
9 | from utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 | from termcolor import colored
12 |
13 | from lib.solvers import anderson, broyden
14 | from lib.grad import make_pair, backward_factory
15 |
16 | from metrics import process_metrics
17 |
18 |
19 | try:
20 | autocast = torch.cuda.amp.autocast
21 | except:
22 | # dummy autocast for PyTorch < 1.6
23 | class autocast:
24 | def __init__(self, enabled):
25 | pass
26 | def __enter__(self):
27 | pass
28 | def __exit__(self, *args):
29 | pass
30 |
31 |
32 | class DEQFlowDemo(nn.Module):
33 | def __init__(self, args):
34 | super(DEQFlowDemo, self).__init__()
35 | self.args = args
36 |
37 | odim = 256
38 | self.hidden_dim = hdim = 128
39 | self.context_dim = cdim = 128
40 | args.corr_levels = 4
41 | args.corr_radius = 4
42 |
43 | # feature network, context network, and update block
44 | self.fnet = BasicEncoder(output_dim=odim, norm_fn='instance', dropout=args.dropout)
45 | self.cnet = BasicEncoder(output_dim=cdim, norm_fn='batch', dropout=args.dropout)
46 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
47 |
48 | # Added the following for the DEQ models
49 | if args.wnorm:
50 | self.update_block._wnorm()
51 |
52 | self.f_solver = eval(args.f_solver)
53 | self.f_thres = args.f_thres
54 | self.eval_f_thres = int(self.f_thres * args.eval_factor)
55 | self.stop_mode = args.stop_mode
56 |
57 | # Define gradient functions through the backward factory
58 | if args.n_losses > 1:
59 | n_losses = min(args.f_thres, args.n_losses)
60 | delta = int(args.f_thres // n_losses)
61 | self.indexing = [(k+1)*delta for k in range(n_losses)]
62 | else:
63 | self.indexing = [*args.indexing, args.f_thres]
64 |
65 | # By default, we use the same phantom grad for all corrections.
66 | # You can also set different grad steps a, b, and c for different terms by ``args.phantom_grad a b c ...''.
67 | indexing_pg = make_pair(self.indexing, args.phantom_grad)
68 | produce_grad = [
69 | backward_factory(grad_type=pg, tau=args.tau, sup_all=args.sup_all) for pg in indexing_pg
70 | ]
71 | if args.ift:
72 | # Enabling args.ift will replace the last gradient function by IFT.
73 | produce_grad[-1] = backward_factory(
74 | grad_type='ift', safe_ift=args.safe_ift, b_solver=eval(args.b_solver),
75 | b_solver_kwargs=dict(threshold=args.b_thres, stop_mode=args.stop_mode)
76 | )
77 |
78 | self.produce_grad = produce_grad
79 | self.hook = None
80 |
81 | def freeze_bn(self):
82 | for m in self.modules():
83 | if isinstance(m, nn.BatchNorm2d):
84 | m.eval()
85 |
86 | def _initialize_flow(self, img):
87 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
88 | N, _, H, W = img.shape
89 | coords0 = coords_grid(N, H//8, W//8, device=img.device)
90 | coords1 = coords_grid(N, H//8, W//8, device=img.device)
91 |
92 | # optical flow computed as difference: flow = coords1 - coords0
93 | return coords0, coords1
94 |
95 | def _upsample_flow(self, flow, mask):
96 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
97 | N, _, H, W = flow.shape
98 | mask = mask.view(N, 1, 9, 8, 8, H, W)
99 | mask = torch.softmax(mask, dim=2)
100 |
101 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
102 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
103 |
104 | up_flow = torch.sum(mask * up_flow, dim=2)
105 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
106 | return up_flow.reshape(N, 2, 8*H, 8*W)
107 |
108 | def _decode(self, z_out, vec2list, coords0):
109 | flow_predictions = []
110 |
111 | for z_pred in z_out:
112 | net, coords1 = vec2list(z_pred)
113 | up_mask = .25 * self.update_block.mask(net)
114 |
115 | if up_mask is None:
116 | flow_up = upflow8(coords1 - coords0)
117 | else:
118 | flow_up = self._upsample_flow(coords1 - coords0, up_mask)
119 |
120 | flow_predictions.append(flow_up)
121 |
122 | return flow_predictions
123 |
124 | def _fixed_point_solve(self, deq_func, z_star, f_thres=None, **kwargs):
125 | if f_thres is None: f_thres = self.f_thres
126 | indexing = self.indexing if self.training else None
127 |
128 | with torch.no_grad():
129 | result = self.f_solver(deq_func, x0=z_star, threshold=f_thres, # To reuse previous coarse fixed points
130 | eps=(1e-3 if self.stop_mode == "abs" else 1e-6), stop_mode=self.stop_mode, indexing=indexing)
131 |
132 | z_star, trajectory = result['result'], result['indexing']
133 |
134 | return z_star, trajectory, min(result['rel_trace']), min(result['abs_trace'])
135 |
136 | def forward(self, image1, image2,
137 | flow_gt=None, valid=None, step_seq_loss=None,
138 | flow_init=None, cached_result=None,
139 | **kwargs):
140 | """ Estimate optical flow between pair of frames """
141 |
142 | image1 = 2 * (image1 / 255.0) - 1.0
143 | image2 = 2 * (image2 / 255.0) - 1.0
144 |
145 | image1 = image1.contiguous()
146 | image2 = image2.contiguous()
147 |
148 | hdim = self.hidden_dim
149 | cdim = self.context_dim
150 |
151 | # run the feature network
152 | with autocast(enabled=self.args.mixed_precision):
153 | fmap1, fmap2 = self.fnet([image1, image2])
154 |
155 | fmap1 = fmap1.float()
156 | fmap2 = fmap2.float()
157 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
158 |
159 | # run the context network
160 | with autocast(enabled=self.args.mixed_precision):
161 | inp = self.cnet(image1)
162 | inp = torch.relu(inp)
163 |
164 | bsz, _, H, W = inp.shape
165 | coords0, coords1 = self._initialize_flow(image1)
166 | net = torch.zeros(bsz, hdim, H, W, device=inp.device)
167 | if cached_result:
168 | net, flow_pred_prev = cached_result
169 | coords1 = coords0 + flow_pred_prev
170 |
171 | if flow_init is not None:
172 | coords1 = coords1 + flow_init
173 |
174 | seed = (inp.get_device() == 0 and np.random.uniform(0,1) < 2e-3)
175 |
176 | def list2vec(h, c): # h is net, c is coords1
177 | return torch.cat([h.view(bsz, h.shape[1], -1), c.view(bsz, c.shape[1], -1)], dim=1)
178 |
179 | def vec2list(hidden):
180 | return hidden[:,:net.shape[1]].view_as(net), hidden[:,net.shape[1]:].view_as(coords1)
181 |
182 | def deq_func(hidden):
183 | h, c = vec2list(hidden)
184 | c = c.detach()
185 |
186 | with autocast(enabled=self.args.mixed_precision):
187 | # corr_fn(coords1) produces the index correlation volumes
188 | new_h, delta_flow = self.update_block(h, inp, corr_fn(c), c-coords0, None)
189 | new_c = c + delta_flow # F(t+1) = F(t) + \Delta(t)
190 | return list2vec(new_h, new_c)
191 |
192 | self.update_block.reset() # In case we use weight normalization, we need to recompute the weight
193 | z_star = list2vec(net, coords1)
194 |
195 | # The code for DEQ version, where we use a wrapper.
196 | if self.training:
197 | _, trajectory, rel_error, abs_error = self._fixed_point_solve(deq_func, z_star, **kwargs)
198 |
199 | z_out = []
200 | for z_pred, produce_grad in zip(trajectory, self.produce_grad):
201 | z_out += produce_grad(self, z_pred, deq_func) # See lib/grad.py for the backward pass implementations
202 |
203 | flow_predictions = self._decode(z_out, vec2list, coords0)
204 |
205 | flow_loss, epe = step_seq_loss(flow_predictions, flow_gt, valid)
206 | metrics = process_metrics(epe, rel_error, abs_error)
207 |
208 | return flow_loss, metrics
209 | else:
210 | # During inference, we directly solve for fixed point
211 | z_star, _, rel_error, abs_error = self._fixed_point_solve(deq_func, z_star, f_thres=self.eval_f_thres)
212 | flow_up = self._decode([z_star], vec2list, coords0)[0]
213 | net, coords1 = vec2list(z_star)
214 |
215 | return coords1 - coords0, flow_up, {"sradius": torch.zeros(1, device=z_star.device), "cached_result": (net, coords1 - coords0)}
216 |
217 |
218 | def get_model(args):
219 | return DEQFlowDemo
220 |
--------------------------------------------------------------------------------
/code.v.1.0/core/extractor.py:
--------------------------------------------------------------------------------
1 | # This code was originally taken from RAFT without modification
2 | # https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class ResidualBlock(nn.Module):
10 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
11 | super(ResidualBlock, self).__init__()
12 |
13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
15 | self.relu = nn.ReLU(inplace=True)
16 |
17 | num_groups = planes // 8
18 |
19 | if norm_fn == 'group':
20 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
22 | if not stride == 1:
23 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
24 |
25 | elif norm_fn == 'batch':
26 | self.norm1 = nn.BatchNorm2d(planes)
27 | self.norm2 = nn.BatchNorm2d(planes)
28 | if not stride == 1:
29 | self.norm3 = nn.BatchNorm2d(planes)
30 |
31 | elif norm_fn == 'instance':
32 | self.norm1 = nn.InstanceNorm2d(planes)
33 | self.norm2 = nn.InstanceNorm2d(planes)
34 | if not stride == 1:
35 | self.norm3 = nn.InstanceNorm2d(planes)
36 |
37 | elif norm_fn == 'none':
38 | self.norm1 = nn.Sequential()
39 | self.norm2 = nn.Sequential()
40 | if not stride == 1:
41 | self.norm3 = nn.Sequential()
42 |
43 | if stride == 1:
44 | self.downsample = None
45 |
46 | else:
47 | self.downsample = nn.Sequential(
48 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
49 |
50 |
51 | def forward(self, x):
52 | y = x
53 | y = self.relu(self.norm1(self.conv1(y)))
54 | y = self.relu(self.norm2(self.conv2(y)))
55 |
56 | if self.downsample is not None:
57 | x = self.downsample(x)
58 |
59 | return self.relu(x+y)
60 |
61 |
62 |
63 | class BottleneckBlock(nn.Module):
64 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
65 | super(BottleneckBlock, self).__init__()
66 |
67 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
68 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
69 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
70 | self.relu = nn.ReLU(inplace=True)
71 |
72 | num_groups = planes // 8
73 |
74 | if norm_fn == 'group':
75 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
76 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
77 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
78 | if not stride == 1:
79 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
80 |
81 | elif norm_fn == 'batch':
82 | self.norm1 = nn.BatchNorm2d(planes//4)
83 | self.norm2 = nn.BatchNorm2d(planes//4)
84 | self.norm3 = nn.BatchNorm2d(planes)
85 | if not stride == 1:
86 | self.norm4 = nn.BatchNorm2d(planes)
87 |
88 | elif norm_fn == 'instance':
89 | self.norm1 = nn.InstanceNorm2d(planes//4)
90 | self.norm2 = nn.InstanceNorm2d(planes//4)
91 | self.norm3 = nn.InstanceNorm2d(planes)
92 | if not stride == 1:
93 | self.norm4 = nn.InstanceNorm2d(planes)
94 |
95 | elif norm_fn == 'none':
96 | self.norm1 = nn.Sequential()
97 | self.norm2 = nn.Sequential()
98 | self.norm3 = nn.Sequential()
99 | if not stride == 1:
100 | self.norm4 = nn.Sequential()
101 |
102 | if stride == 1:
103 | self.downsample = None
104 |
105 | else:
106 | self.downsample = nn.Sequential(
107 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
108 |
109 |
110 | def forward(self, x):
111 | y = x
112 | y = self.relu(self.norm1(self.conv1(y)))
113 | y = self.relu(self.norm2(self.conv2(y)))
114 | y = self.relu(self.norm3(self.conv3(y)))
115 |
116 | if self.downsample is not None:
117 | x = self.downsample(x)
118 |
119 | return self.relu(x+y)
120 |
121 | class BasicEncoder(nn.Module):
122 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
123 | super(BasicEncoder, self).__init__()
124 | self.norm_fn = norm_fn
125 |
126 | if self.norm_fn == 'group':
127 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
128 |
129 | elif self.norm_fn == 'batch':
130 | self.norm1 = nn.BatchNorm2d(64)
131 |
132 | elif self.norm_fn == 'instance':
133 | self.norm1 = nn.InstanceNorm2d(64)
134 |
135 | elif self.norm_fn == 'none':
136 | self.norm1 = nn.Sequential()
137 |
138 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
139 | self.relu1 = nn.ReLU(inplace=True)
140 |
141 | self.in_planes = 64
142 | self.layer1 = self._make_layer(64, stride=1)
143 | self.layer2 = self._make_layer(96, stride=2)
144 | self.layer3 = self._make_layer(128, stride=2)
145 |
146 | # output convolution
147 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
148 |
149 | self.dropout = None
150 | if dropout > 0:
151 | self.dropout = nn.Dropout2d(p=dropout)
152 |
153 | for m in self.modules():
154 | if isinstance(m, nn.Conv2d):
155 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
156 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
157 | if m.weight is not None:
158 | nn.init.constant_(m.weight, 1)
159 | if m.bias is not None:
160 | nn.init.constant_(m.bias, 0)
161 |
162 | def _make_layer(self, dim, stride=1):
163 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
164 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
165 | layers = (layer1, layer2)
166 |
167 | self.in_planes = dim
168 | return nn.Sequential(*layers)
169 |
170 |
171 | def forward(self, x):
172 |
173 | # if input is list, combine batch dimension
174 | is_list = isinstance(x, tuple) or isinstance(x, list)
175 | if is_list:
176 | batch_dim = x[0].shape[0]
177 | x = torch.cat(x, dim=0)
178 |
179 | x = self.conv1(x)
180 | x = self.norm1(x)
181 | x = self.relu1(x)
182 |
183 | x = self.layer1(x)
184 | x = self.layer2(x)
185 | x = self.layer3(x)
186 |
187 | x = self.conv2(x)
188 |
189 | if self.training and self.dropout is not None:
190 | x = self.dropout(x)
191 |
192 | if is_list:
193 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
194 |
195 | return x
196 |
197 |
198 | class SmallEncoder(nn.Module):
199 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
200 | super(SmallEncoder, self).__init__()
201 | self.norm_fn = norm_fn
202 |
203 | if self.norm_fn == 'group':
204 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
205 |
206 | elif self.norm_fn == 'batch':
207 | self.norm1 = nn.BatchNorm2d(32)
208 |
209 | elif self.norm_fn == 'instance':
210 | self.norm1 = nn.InstanceNorm2d(32)
211 |
212 | elif self.norm_fn == 'none':
213 | self.norm1 = nn.Sequential()
214 |
215 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
216 | self.relu1 = nn.ReLU(inplace=True)
217 |
218 | self.in_planes = 32
219 | self.layer1 = self._make_layer(32, stride=1)
220 | self.layer2 = self._make_layer(64, stride=2)
221 | self.layer3 = self._make_layer(96, stride=2)
222 |
223 | self.dropout = None
224 | if dropout > 0:
225 | self.dropout = nn.Dropout2d(p=dropout)
226 |
227 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
228 |
229 | for m in self.modules():
230 | if isinstance(m, nn.Conv2d):
231 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
232 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
233 | if m.weight is not None:
234 | nn.init.constant_(m.weight, 1)
235 | if m.bias is not None:
236 | nn.init.constant_(m.bias, 0)
237 |
238 | def _make_layer(self, dim, stride=1):
239 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
240 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
241 | layers = (layer1, layer2)
242 |
243 | self.in_planes = dim
244 | return nn.Sequential(*layers)
245 |
246 |
247 | def forward(self, x):
248 |
249 | # if input is list, combine batch dimension
250 | is_list = isinstance(x, tuple) or isinstance(x, list)
251 | if is_list:
252 | batch_dim = x[0].shape[0]
253 | x = torch.cat(x, dim=0)
254 |
255 | x = self.conv1(x)
256 | x = self.norm1(x)
257 | x = self.relu1(x)
258 |
259 | x = self.layer1(x)
260 | x = self.layer2(x)
261 | x = self.layer3(x)
262 | x = self.conv2(x)
263 |
264 | if self.training and self.dropout is not None:
265 | x = self.dropout(x)
266 |
267 | if is_list:
268 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
269 |
270 | return x
271 |
--------------------------------------------------------------------------------
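A minimal shape-check sketch for BasicEncoder above (not part of the repository; assumes core/ is on sys.path, as in the repo's entry scripts): an (img1, img2) pair is batched along dim 0, encoded at 1/8 resolution, and split back into two feature maps.

import torch
from extractor import BasicEncoder

fnet = BasicEncoder(output_dim=256, norm_fn='instance')
img1 = torch.randn(2, 3, 368, 496)
img2 = torch.randn(2, 3, 368, 496)
fmap1, fmap2 = fnet([img1, img2])   # list input -> concatenated, encoded, then split
print(fmap1.shape)                  # torch.Size([2, 256, 46, 62]), i.e. 1/8 resolution
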
/code.v.1.0/core/gma.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from einops import rearrange
4 |
5 | from lib.optimizations import weight_norm
6 |
7 |
8 | class RelPosEmb(nn.Module):
9 | def __init__(
10 | self,
11 | max_pos_size,
12 | dim_head
13 | ):
14 | super().__init__()
15 | self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head)
16 | self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head)
17 |
18 | deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(max_pos_size).view(-1, 1)
19 | rel_ind = deltas + max_pos_size - 1
20 | self.register_buffer('rel_ind', rel_ind)
21 |
22 | def forward(self, q):
23 | batch, heads, h, w, c = q.shape
24 | height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1))
25 | width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1))
26 |
27 | height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h)
28 | width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w)
29 |
30 | height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb)
31 | width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb)
32 |
33 | return height_score + width_score
34 |
35 |
36 | class Attention(nn.Module):
37 | def __init__(
38 | self,
39 | *,
40 | dim,
41 | max_pos_size = 100,
42 | heads = 4,
43 | dim_head = 128,
44 | ):
45 | super().__init__()
46 | self.heads = heads
47 | self.scale = dim_head ** -0.5
48 | inner_dim = heads * dim_head
49 |
50 | self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)
51 |
52 | def _wnorm(self):
53 | self.to_qk, self.to_qk_fn = weight_norm(module=self.to_qk, names=['weight'], dim=0)
54 |
55 | def reset(self):
56 | for name in ['to_qk']:
57 | if name + '_fn' in self.__dict__:
58 | eval(f'self.{name}_fn').reset(eval(f'self.{name}'))
59 |
60 | def forward(self, fmap):
61 | heads, b, c, h, w = self.heads, *fmap.shape
62 |
63 | q, k = self.to_qk(fmap).chunk(2, dim=1)
64 |
65 | q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k))
66 | q = self.scale * q
67 |
68 | sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k)
69 |
70 | sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)')
71 | attn = sim.softmax(dim=-1)
72 |
73 | return attn
74 |
75 |
76 | class Aggregate(nn.Module):
77 | def __init__(
78 | self,
79 | dim,
80 | heads = 4,
81 | dim_head = 128,
82 | ):
83 | super().__init__()
84 | self.heads = heads
85 | self.scale = dim_head ** -0.5
86 | inner_dim = heads * dim_head
87 |
88 | self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False)
89 |
90 | self.gamma = nn.Parameter(torch.zeros(1))
91 |
92 | if dim != inner_dim:
93 | self.project = nn.Conv2d(inner_dim, dim, 1, bias=False)
94 | else:
95 | self.project = None
96 |
97 | def _wnorm(self):
98 | self.to_v, self.to_v_fn = weight_norm(module=self.to_v, names=['weight'], dim=0)
99 |
100 | if self.project:
101 | self.project, self.project_fn = weight_norm(module=self.project, names=['weight'], dim=0)
102 |
103 | def reset(self):
104 | for name in ['to_v', 'project']:
105 | if name + '_fn' in self.__dict__:
106 | eval(f'self.{name}_fn').reset(eval(f'self.{name}'))
107 |
108 | def forward(self, attn, fmap):
109 | heads, b, c, h, w = self.heads, *fmap.shape
110 |
111 | v = self.to_v(fmap)
112 | v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads)
113 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
114 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
115 |
116 | if self.project is not None:
117 | out = self.project(out)
118 |
119 | out = fmap + self.gamma * out
120 |
121 | return out
122 |
123 |
124 | if __name__ == "__main__":
125 | att = Attention(dim=128, heads=1)
126 | fmap = torch.randn(2, 128, 40, 90)
127 | out = att(fmap)
128 |
129 | print(out.shape)
130 |
--------------------------------------------------------------------------------
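The __main__ demo above only exercises Attention; a small companion sketch (not from the repo, same sys.path assumption) feeds its attention map into Aggregate:

import torch
from gma import Attention, Aggregate

att = Attention(dim=128, heads=1, dim_head=128)
agg = Aggregate(dim=128, heads=1, dim_head=128)
fmap = torch.randn(2, 128, 40, 90)
attn = att(fmap)                # (2, 1, 3600, 3600): attention over all spatial positions
out = agg(attn, fmap)           # residual aggregation, gated by the learned gamma
print(out.shape)                # torch.Size([2, 128, 40, 90])
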
/code.v.1.0/core/lib/grad.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch import autograd
6 |
7 | from .solvers import anderson, broyden
8 |
9 |
10 | def make_pair(target, source):
11 | if len(target) == len(source):
12 | return source
13 | elif len(source) == 1:
14 | return [source[0] for _ in range(len(target))]
15 | else:
16 |         raise ValueError('Unable to align the arg sequence!')
17 |
18 |
19 | def backward_factory(
20 | grad_type='ift',
21 | safe_ift=False,
22 | b_solver=anderson,
23 | b_solver_kwargs=dict(),
24 | sup_all=False,
25 | tau=1.0):
26 | """
27 | [2019-NeurIPS] Deep Equilibrium Models
28 | [2021-ICLR] Is Attention Better Than Matrix Decomposition?
29 | [2021-NeurIPS] On Training Implicit Models
30 | [2022-AAAI] JFB: Jacobian-Free Backpropagation for Implicit Networks
31 |
32 | This function implements a factory for the backward pass of implicit deep learning,
33 | e.g., DEQ (implicit models), Hamburger (optimization layer), etc.
34 | It now supports IFT, 1-step Grad, and Phantom Grad.
35 |
36 | Kwargs:
37 | grad_type (string, int):
38 | grad_type should be ``ift`` or an int. Default ``ift``.
39 | Set to ``ift`` to enable the implicit differentiation mode.
40 |             When passing an int k, it runs (unrolled) Phantom Grad for k steps with damping tau.
41 | safe_ift (bool):
42 |             Replace the O(1) hook implementation with a safer one. Default ``False``.
43 |             Set to ``True`` to avoid a (potential) segmentation fault (under previous versions of PyTorch).
44 | b_solver (type):
45 | Solver for the IFT backward pass. Default ``anderson``.
46 | Supported solvers: anderson, broyden.
47 | b_solver_kwargs (dict):
48 |             Collection of backward solver kwargs, e.g.,
49 | threshold (int), max steps for the backward solver,
50 | stop_mode (string), criterion for convergence,
51 | etc.
52 |             See solvers.py to check all the kwargs.
53 | sup_all (bool):
54 | Indicate whether to supervise all the trajectories by Phantom Grad.
55 |             Set ``True`` to return all intermediate states along the Phantom Grad trajectory.
56 | tau (float):
57 | Damping factor for Phantom Grad. Default ``1.0``.
58 | 0.5 is recommended for CIFAR-10. 1.0 for DEQ flow.
59 | For DEQ flow, the gating function in GRU naturally produces adaptive tau values.
60 |
61 | Return:
62 | A gradient functor for implicit deep learning.
63 | Args:
64 | trainer (nn.Module): the module that employs implicit deep learning.
65 | z_pred (torch.Tensor): latent state to run the backward pass.
66 | func (type): function that defines the ``f`` in ``z = f(z)``.
67 |
68 | Return:
69 | (list(torch.Tensor)): a list of tensors that tracks the gradient info.
70 |
71 | """
72 |
73 | if grad_type == 'ift':
74 | assert b_solver in [anderson, broyden]
75 |
76 | if safe_ift:
77 | def plain_ift_grad(trainer, z_pred, func):
78 | z_pred = z_pred.requires_grad_()
79 | new_z_pred = func(z_pred) # 1-step grad for df/dtheta
80 |
81 | z_pred_copy = new_z_pred.clone().detach().requires_grad_()
82 | new_z_pred_copy = func(z_pred_copy)
83 | def backward_hook(grad):
84 | result = b_solver(lambda y: autograd.grad(new_z_pred_copy, z_pred_copy, y, retain_graph=True)[0] + grad,
85 | torch.zeros_like(grad), **b_solver_kwargs)
86 | return result['result']
87 | new_z_pred.register_hook(backward_hook)
88 |
89 | return [new_z_pred]
90 | return plain_ift_grad
91 | else:
92 | def hook_ift_grad(trainer, z_pred, func):
93 | z_pred = z_pred.requires_grad_()
94 | new_z_pred = func(z_pred) # 1-step grad for df/dtheta
95 |
96 | def backward_hook(grad):
97 | if trainer.hook is not None:
98 | trainer.hook.remove() # To avoid infinite loop
99 | result = b_solver(lambda y: autograd.grad(new_z_pred, z_pred, y, retain_graph=True)[0] + grad,
100 | torch.zeros_like(grad), **b_solver_kwargs)
101 | return result['result']
102 | trainer.hook = new_z_pred.register_hook(backward_hook)
103 |
104 | return [new_z_pred]
105 | return hook_ift_grad
106 | else:
107 | assert type(grad_type) is int and grad_type >= 1
108 | n_phantom_grad = grad_type
109 |
110 | if sup_all:
111 | def sup_all_phantom_grad(trainer, z_pred, func):
112 | z_out = []
113 | for _ in range(n_phantom_grad):
114 | z_pred = (1 - tau) * z_pred + tau * func(z_pred)
115 | z_out.append(z_pred)
116 |
117 | return z_out
118 | return sup_all_phantom_grad
119 | else:
120 | def phantom_grad(trainer, z_pred, func):
121 | for _ in range(n_phantom_grad):
122 | z_pred = (1 - tau) * z_pred + tau * func(z_pred)
123 |
124 | return [z_pred]
125 | return phantom_grad
126 |
--------------------------------------------------------------------------------
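A toy sketch of the Phantom Grad functor returned by backward_factory (names are illustrative, not the repo's real call sites; core/ assumed on sys.path):

import torch
from lib.grad import backward_factory

produce_grad = backward_factory(grad_type=2, tau=0.8)   # 2 unrolled Phantom Grad steps
w = torch.nn.Parameter(torch.tensor(0.5))
func = lambda z: torch.tanh(w * z)                      # stand-in for the DEQ cell f(z)
z_star = torch.randn(4, 8)                              # pretend fixed point from the forward solver
z_out = produce_grad(None, z_star, func)                # the trainer argument is unused by Phantom Grad
z_out[-1].sum().backward()                              # gradients reach w through the unrolled steps
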
/code.v.1.0/core/lib/jacobian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 |
7 | def jac_loss_estimate(f0, z0, vecs=2, create_graph=True):
8 | """Estimating tr(J^TJ)=tr(JJ^T) via Hutchinson estimator
9 |
10 | Args:
11 | f0 (torch.Tensor): Output of the function f (whose J is to be analyzed)
12 | z0 (torch.Tensor): Input to the function f
13 | vecs (int, optional): Number of random Gaussian vectors to use. Defaults to 2.
14 | create_graph (bool, optional): Whether to create backward graph (e.g., to train on this loss).
15 | Defaults to True.
16 |
17 | Returns:
18 | torch.Tensor: A 1x1 torch tensor that encodes the (shape-normalized) jacobian loss
19 | """
20 | vecs = vecs
21 | result = 0
22 | for i in range(vecs):
23 | v = torch.randn(*z0.shape).to(z0)
24 | vJ = torch.autograd.grad(f0, z0, v, retain_graph=True, create_graph=create_graph)[0]
25 | result += vJ.norm()**2
26 | return result / vecs / np.prod(z0.shape)
27 |
28 | def power_method(f0, z0, n_iters=200):
29 | """Estimating the spectral radius of J using power method
30 |
31 | Args:
32 | f0 (torch.Tensor): Output of the function f (whose J is to be analyzed)
33 | z0 (torch.Tensor): Input to the function f
34 | n_iters (int, optional): Number of power method iterations. Defaults to 200.
35 |
36 | Returns:
37 | tuple: (largest eigenvector, largest (abs.) eigenvalue)
38 | """
39 | evector = torch.randn_like(z0)
40 | bsz = evector.shape[0]
41 | for i in range(n_iters):
42 | vTJ = torch.autograd.grad(f0, z0, evector, retain_graph=(i < n_iters-1), create_graph=False)[0]
43 | evalue = (vTJ * evector).reshape(bsz, -1).sum(1, keepdim=True) / (evector * evector).reshape(bsz, -1).sum(1, keepdim=True)
44 | evector = (vTJ.reshape(bsz, -1) / vTJ.reshape(bsz, -1).norm(dim=1, keepdim=True)).reshape_as(z0)
45 | return (evector, torch.abs(evalue))
--------------------------------------------------------------------------------
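A toy sketch (not from the repo) of using jac_loss_estimate to regularize a fixed-point map; with create_graph=True the estimate can be weighted and added to the task loss:

import torch
from lib.jacobian import jac_loss_estimate   # assumes core/ is on sys.path

z0 = torch.randn(2, 8, requires_grad=True)
W = torch.nn.Parameter(0.1 * torch.randn(8, 8))
f0 = torch.tanh(z0 @ W)                       # f(z0); its Jacobian w.r.t. z0 is probed
jac_loss = jac_loss_estimate(f0, z0, vecs=2)  # Hutchinson estimate of tr(J^T J), shape-normalized
jac_loss.backward()                           # W.grad and z0.grad are populated
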
/code.v.1.0/core/lib/layer_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
5 |
6 | def list2vec(z1_list):
7 | """Convert list of tensors to a vector"""
8 | bsz = z1_list[0].size(0)
9 | return torch.cat([elem.reshape(bsz, -1, 1) for elem in z1_list], dim=1)
10 |
11 |
12 | def vec2list(z1, cutoffs):
13 | """Convert a vector back to a list, via the cutoffs specified"""
14 | bsz = z1.shape[0]
15 | z1_list = []
16 | start_idx, end_idx = 0, cutoffs[0][0] * cutoffs[0][1] * cutoffs[0][2]
17 | for i in range(len(cutoffs)):
18 | z1_list.append(z1[:, start_idx:end_idx].view(bsz, *cutoffs[i]))
19 | if i < len(cutoffs)-1:
20 | start_idx = end_idx
21 | end_idx += cutoffs[i + 1][0] * cutoffs[i + 1][1] * cutoffs[i + 1][2]
22 | return z1_list
23 |
24 |
25 | def conv3x3(in_planes, out_planes, stride=1, bias=False):
26 | """3x3 convolution with padding"""
27 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=bias)
28 |
29 | def conv5x5(in_planes, out_planes, stride=1, bias=False):
30 | """5x5 convolution with padding"""
31 | return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride, padding=2, bias=bias)
32 |
33 |
34 | def norm_diff(new, old, show_list=False):
35 | if show_list:
36 | return [(new[i] - old[i]).norm().item() for i in range(len(new))]
37 |     return sum((new[i] - old[i]).norm().item()**2 for i in range(len(new))) ** 0.5  # plain sqrt: numpy is not imported in this module
--------------------------------------------------------------------------------
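A round-trip sketch for list2vec / vec2list (not from the repo): cutoffs are the per-tensor (C, H, W) shapes, so vec2list can slice the flat vector back apart.

import torch
from lib.layer_utils import list2vec, vec2list   # assumes core/ is on sys.path

z_list = [torch.randn(2, 128, 46, 62), torch.randn(2, 2, 46, 62)]
cutoffs = [t.shape[1:] for t in z_list]           # [(128, 46, 62), (2, 46, 62)]
z_vec = list2vec(z_list)                          # shape (2, 128*46*62 + 2*46*62, 1)
z_back = vec2list(z_vec, cutoffs)
assert all(torch.equal(a, b) for a, b in zip(z_list, z_back))
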
/code.v.1.0/core/lib/optimizations.py:
--------------------------------------------------------------------------------
1 | from torch.nn.parameter import Parameter
2 | import torch.nn as nn
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 | ##############################################################################################################
8 | #
9 | # Temporal DropConnect in a feed-forward setting
10 | #
11 | ##############################################################################################################
12 |
13 | class WeightDrop(torch.nn.Module):
14 | def __init__(self, module, weights, dropout=0, temporal=True):
15 | """
16 | Weight DropConnect, adapted from a recurrent setting by Merity et al. 2017
17 |
18 | :param module: The module whose weights are to be applied dropout on
19 | :param weights: A 2D list identifying the weights to be regularized. Each element of weights should be a
20 | list containing the "path" to the weight kernel. For instance, if we want to regularize
21 | module.layer2.weight3, then this should be ["layer2", "weight3"].
22 | :param dropout: The dropout rate (0 means no dropout)
23 | :param temporal: Whether we apply DropConnect only to the temporal parts of the weight (empirically we found
24 | this not very important)
25 | """
26 | super(WeightDrop, self).__init__()
27 | self.module = module
28 | self.weights = weights
29 | self.dropout = dropout
30 | self.temporal = temporal
31 | if self.dropout > 0.0:
32 | self._setup()
33 |
34 | def _setup(self):
35 | for path in self.weights:
36 | full_name_w = '.'.join(path)
37 |
38 | module = self.module
39 | name_w = path[-1]
40 | for i in range(len(path) - 1):
41 | module = getattr(module, path[i])
42 | w = getattr(module, name_w)
43 | del module._parameters[name_w]
44 | module.register_parameter(name_w + '_raw', Parameter(w.data))
45 |
46 | def _setweights(self):
47 | for path in self.weights:
48 | module = self.module
49 | name_w = path[-1]
50 | for i in range(len(path) - 1):
51 | module = getattr(module, path[i])
52 | raw_w = getattr(module, name_w + '_raw')
53 |
54 | if len(raw_w.size()) > 2 and raw_w.size(2) > 1 and self.temporal:
55 | # Drop the temporal parts of the weight; if 1x1 convolution then drop the whole kernel
56 | w = torch.cat([F.dropout(raw_w[:, :, :-1], p=self.dropout, training=self.training),
57 | raw_w[:, :, -1:]], dim=2)
58 | else:
59 | w = F.dropout(raw_w, p=self.dropout, training=self.training)
60 |
61 | setattr(module, name_w, w)
62 |
63 | def forward(self, *args, **kwargs):
64 | if self.dropout > 0.0:
65 | self._setweights()
66 | return self.module.forward(*args, **kwargs)
67 |
68 |
69 | def matrix_diag(a, dim=2):
70 | """
71 | a has dimension (N, (L,) C), we want a matrix/batch diag that produces (N, (L,) C, C) from the last dimension of a
72 | """
73 | if dim == 2:
74 | res = torch.zeros(a.size(0), a.size(1), a.size(1))
75 | res.as_strided(a.size(), [res.stride(0), res.size(2)+1]).copy_(a)
76 | else:
77 | res = torch.zeros(a.size(0), a.size(1), a.size(2), a.size(2))
78 | res.as_strided(a.size(), [res.stride(0), res.stride(1), res.size(3)+1]).copy_(a)
79 | return res
80 |
81 | ##############################################################################################################
82 | #
83 | # Embedding dropout
84 | #
85 | ##############################################################################################################
86 |
87 | def embedded_dropout(embed, words, dropout=0.1, scale=None):
88 | """
89 |     Apply the embedding layer, with dropout applied to its weight matrix
90 |
91 | :param embed: The embedding layer
92 | :param words: The input sequence
93 | :param dropout: The embedding weight dropout rate
94 | :param scale: Scaling factor for the dropped embedding weight
95 | :return: The embedding output
96 | """
97 | if dropout:
98 | mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(
99 | embed.weight) / (1 - dropout)
100 | mask = Variable(mask)
101 | masked_embed_weight = mask * embed.weight
102 | else:
103 | masked_embed_weight = embed.weight
104 |
105 | if scale:
106 | masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight
107 |
108 | padding_idx = embed.padding_idx
109 | if padding_idx is None:
110 | padding_idx = -1
111 |
112 | X = F.embedding(words, masked_embed_weight, padding_idx, embed.max_norm, embed.norm_type,
113 | embed.scale_grad_by_freq, embed.sparse)
114 | return X
115 |
116 |
117 |
118 | ##############################################################################################################
119 | #
120 | # Variational dropout (for input/output layers, and for hidden layers)
121 | #
122 | ##############################################################################################################
123 |
124 |
125 | class VariationalHidDropout2d(nn.Module):
126 | def __init__(self, dropout=0.0):
127 | super(VariationalHidDropout2d, self).__init__()
128 | self.dropout = dropout
129 | self.mask = None
130 |
131 | def forward(self, x):
132 | if not self.training or self.dropout == 0:
133 | return x
134 | bsz, d, H, W = x.shape
135 | if self.mask is None:
136 | m = torch.zeros(bsz, d, H, W).bernoulli_(1 - self.dropout).to(x)
137 | self.mask = m.requires_grad_(False) / (1 - self.dropout)
138 | return self.mask * x
139 |
140 | ##############################################################################################################
141 | #
142 | # Weight normalization. Modified from the original PyTorch's implementation of weight normalization.
143 | #
144 | ##############################################################################################################
145 |
146 | def _norm(p, dim):
147 | """Computes the norm over all dimensions except dim"""
148 | if dim is None:
149 | return p.norm()
150 | elif dim == 0:
151 | output_size = (p.size(0),) + (1,) * (p.dim() - 1)
152 | return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size)
153 | elif dim == p.dim() - 1:
154 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
155 | return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size)
156 | else:
157 | return _norm(p.transpose(0, dim), 0).transpose(0, dim)
158 |
159 |
160 | class WeightNorm(object):
161 | def __init__(self, names, dim):
162 | """
163 | Weight normalization module
164 |
165 | :param names: The list of weight names to apply weightnorm on
166 | :param dim: The dimension of the weights to be normalized
167 | """
168 | self.names = names
169 | self.dim = dim
170 |
171 | def compute_weight(self, module, name):
172 | g = getattr(module, name + '_g')
173 | v = getattr(module, name + '_v')
174 | return v * (g / _norm(v, self.dim))
175 |
176 | @staticmethod
177 | def apply(module, names, dim):
178 | fn = WeightNorm(names, dim)
179 |
180 | for name in names:
181 | weight = getattr(module, name)
182 |
183 | # remove w from parameter list
184 | del module._parameters[name]
185 |
186 | # add g and v as new parameters and express w as g/||v|| * v
187 | module.register_parameter(name + '_g', Parameter(_norm(weight, dim).data))
188 | module.register_parameter(name + '_v', Parameter(weight.data))
189 | setattr(module, name, fn.compute_weight(module, name))
190 |
191 | # recompute weight before every forward()
192 | module.register_forward_pre_hook(fn)
193 | return fn
194 |
195 | def remove(self, module):
196 | for name in self.names:
197 | weight = self.compute_weight(module, name)
198 | delattr(module, name)
199 | del module._parameters[name + '_g']
200 | del module._parameters[name + '_v']
201 | module.register_parameter(name, Parameter(weight.data))
202 |
203 | def reset(self, module):
204 | for name in self.names:
205 | setattr(module, name, self.compute_weight(module, name))
206 |
207 | def __call__(self, module, inputs):
208 | # Typically, every time the module is called we need to recompute the weight. However,
209 | # in the case of TrellisNet, the same weight is shared across layers, and we can save
210 | # a lot of intermediate memory by just recomputing once (at the beginning of first call).
211 | pass
212 |
213 |
214 | def weight_norm(module, names, dim=0):
215 | fn = WeightNorm.apply(module, names, dim)
216 | return module, fn
217 |
--------------------------------------------------------------------------------
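Unlike torch.nn.utils.weight_norm, the pre-forward hook registered here is deliberately a no-op, so callers re-materialize the normalized weight themselves via fn.reset(module) (this is what _wnorm/reset do in gma.py). A minimal sketch, assuming core/ is on sys.path:

import torch
import torch.nn as nn
from lib.optimizations import weight_norm

conv = nn.Conv2d(128, 256, 1, bias=False)
conv, conv_fn = weight_norm(module=conv, names=['weight'], dim=0)
conv_fn.reset(conv)                          # recompute weight = g * v / ||v|| before running the module
y = conv(torch.randn(1, 128, 8, 8))
print(y.shape)                               # torch.Size([1, 256, 8, 8])
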
/code.v.1.0/core/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | MAX_FLOW = 400
4 |
5 | @torch.no_grad()
6 | def compute_epe(flow_pred, flow_gt, valid, max_flow=MAX_FLOW):
7 |     # exclude invalid pixels and extremely large displacements
8 | mag = torch.sum(flow_gt**2, dim=1).sqrt()
9 | valid = (valid >= 0.5) & (mag < max_flow)
10 |
11 | epe = torch.sum((flow_pred - flow_gt)**2, dim=1).sqrt()
12 | epe = torch.masked_fill(epe, ~valid, 0)
13 |
14 | return epe
15 |
16 |
17 | @torch.no_grad()
18 | def process_metrics(epe, rel_error, abs_error, **kwargs):
19 | epe = epe.flatten(1)
20 | metrics = torch.stack(
21 | [
22 | epe.mean(dim=1),
23 | (epe < 1).float().mean(dim=1),
24 | (epe < 3).float().mean(dim=1),
25 | (epe < 5).float().mean(dim=1),
26 | torch.tensor(rel_error).cuda().repeat(epe.shape[0]),
27 | torch.tensor(abs_error).cuda().repeat(epe.shape[0]),
28 | ],
29 | dim=1
30 | )
31 |
32 | # (B // N_GPU, N_Metrics)
33 | return metrics
34 |
35 |
36 | def merge_metrics(metrics):
37 | metrics = metrics.mean(dim=0)
38 | metrics = {
39 | 'epe': metrics[0].item(),
40 | '1px': metrics[1].item(),
41 | '3px': metrics[2].item(),
42 | '5px': metrics[3].item(),
43 | 'rel': metrics[4].item(),
44 | 'abs': metrics[5].item(),
45 | }
46 |
47 | return metrics
48 |
--------------------------------------------------------------------------------
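A toy sketch of compute_epe (not from the repo): flow tensors are (B, 2, H, W) and valid is (B, H, W) with 1 marking usable pixels.

import torch
from metrics import compute_epe   # assumes core/ is on sys.path

flow_gt = torch.zeros(1, 2, 4, 4)
flow_pred = flow_gt + 1.0                      # off by (1, 1) everywhere -> EPE = sqrt(2)
valid = torch.ones(1, 4, 4)
epe = compute_epe(flow_pred, flow_gt, valid)
print(epe.mean())                              # tensor(1.4142)
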
/code.v.1.0/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/deq-flow/869d4bcd9c1d227cbfa16639994d1b118510a9bb/code.v.1.0/core/utils/__init__.py
--------------------------------------------------------------------------------
/code.v.1.0/core/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | # This code was originally taken from RAFT without modification
2 | # https://github.com/princeton-vl/RAFT/blob/master/utils/augmentor.py
3 |
4 | import numpy as np
5 | import random
6 | import math
7 | from PIL import Image
8 |
9 | import cv2
10 | cv2.setNumThreads(0)
11 | cv2.ocl.setUseOpenCL(False)
12 |
13 | import torch
14 | from torchvision.transforms import ColorJitter
15 | import torch.nn.functional as F
16 |
17 |
18 | class FlowAugmentor:
19 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
20 |
21 | # spatial augmentation params
22 | self.crop_size = crop_size
23 | self.min_scale = min_scale
24 | self.max_scale = max_scale
25 | self.spatial_aug_prob = 0.8
26 | self.stretch_prob = 0.8
27 | self.max_stretch = 0.2
28 |
29 | # flip augmentation params
30 | self.do_flip = do_flip
31 | self.h_flip_prob = 0.5
32 | self.v_flip_prob = 0.1
33 |
34 | # photometric augmentation params
35 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
36 | self.asymmetric_color_aug_prob = 0.2
37 | self.eraser_aug_prob = 0.5
38 |
39 | def color_transform(self, img1, img2):
40 | """ Photometric augmentation """
41 |
42 | # asymmetric
43 | if np.random.rand() < self.asymmetric_color_aug_prob:
44 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
45 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
46 |
47 | # symmetric
48 | else:
49 | image_stack = np.concatenate([img1, img2], axis=0)
50 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
51 | img1, img2 = np.split(image_stack, 2, axis=0)
52 |
53 | return img1, img2
54 |
55 | def eraser_transform(self, img1, img2, bounds=[50, 100]):
56 | """ Occlusion augmentation """
57 |
58 | ht, wd = img1.shape[:2]
59 | if np.random.rand() < self.eraser_aug_prob:
60 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
61 | for _ in range(np.random.randint(1, 3)):
62 | x0 = np.random.randint(0, wd)
63 | y0 = np.random.randint(0, ht)
64 | dx = np.random.randint(bounds[0], bounds[1])
65 | dy = np.random.randint(bounds[0], bounds[1])
66 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
67 |
68 | return img1, img2
69 |
70 | def spatial_transform(self, img1, img2, flow):
71 | # randomly sample scale
72 | ht, wd = img1.shape[:2]
73 | min_scale = np.maximum(
74 | (self.crop_size[0] + 8) / float(ht),
75 | (self.crop_size[1] + 8) / float(wd))
76 |
77 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
78 | scale_x = scale
79 | scale_y = scale
80 | if np.random.rand() < self.stretch_prob:
81 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
82 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
83 |
84 | scale_x = np.clip(scale_x, min_scale, None)
85 | scale_y = np.clip(scale_y, min_scale, None)
86 |
87 | if np.random.rand() < self.spatial_aug_prob:
88 | # rescale the images
89 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
90 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
91 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
92 | flow = flow * [scale_x, scale_y]
93 |
94 | if self.do_flip:
95 | if np.random.rand() < self.h_flip_prob: # h-flip
96 | img1 = img1[:, ::-1]
97 | img2 = img2[:, ::-1]
98 | flow = flow[:, ::-1] * [-1.0, 1.0]
99 |
100 | if np.random.rand() < self.v_flip_prob: # v-flip
101 | img1 = img1[::-1, :]
102 | img2 = img2[::-1, :]
103 | flow = flow[::-1, :] * [1.0, -1.0]
104 |
105 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
106 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
107 |
108 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
109 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
110 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
111 |
112 | return img1, img2, flow
113 |
114 | def __call__(self, img1, img2, flow):
115 | img1, img2 = self.color_transform(img1, img2)
116 | img1, img2 = self.eraser_transform(img1, img2)
117 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
118 |
119 | img1 = np.ascontiguousarray(img1)
120 | img2 = np.ascontiguousarray(img2)
121 | flow = np.ascontiguousarray(flow)
122 |
123 | return img1, img2, flow
124 |
125 | class SparseFlowAugmentor:
126 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
127 | # spatial augmentation params
128 | self.crop_size = crop_size
129 | self.min_scale = min_scale
130 | self.max_scale = max_scale
131 | self.spatial_aug_prob = 0.8
132 | self.stretch_prob = 0.8
133 | self.max_stretch = 0.2
134 |
135 | # flip augmentation params
136 | self.do_flip = do_flip
137 | self.h_flip_prob = 0.5
138 | self.v_flip_prob = 0.1
139 |
140 | # photometric augmentation params
141 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
142 | self.asymmetric_color_aug_prob = 0.2
143 | self.eraser_aug_prob = 0.5
144 |
145 | def color_transform(self, img1, img2):
146 | image_stack = np.concatenate([img1, img2], axis=0)
147 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
148 | img1, img2 = np.split(image_stack, 2, axis=0)
149 | return img1, img2
150 |
151 | def eraser_transform(self, img1, img2):
152 | ht, wd = img1.shape[:2]
153 | if np.random.rand() < self.eraser_aug_prob:
154 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
155 | for _ in range(np.random.randint(1, 3)):
156 | x0 = np.random.randint(0, wd)
157 | y0 = np.random.randint(0, ht)
158 | dx = np.random.randint(50, 100)
159 | dy = np.random.randint(50, 100)
160 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
161 |
162 | return img1, img2
163 |
164 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
165 | ht, wd = flow.shape[:2]
166 | coords = np.meshgrid(np.arange(wd), np.arange(ht))
167 | coords = np.stack(coords, axis=-1)
168 |
169 | coords = coords.reshape(-1, 2).astype(np.float32)
170 | flow = flow.reshape(-1, 2).astype(np.float32)
171 | valid = valid.reshape(-1).astype(np.float32)
172 |
173 | coords0 = coords[valid>=1]
174 | flow0 = flow[valid>=1]
175 |
176 | ht1 = int(round(ht * fy))
177 | wd1 = int(round(wd * fx))
178 |
179 | coords1 = coords0 * [fx, fy]
180 | flow1 = flow0 * [fx, fy]
181 |
182 | xx = np.round(coords1[:,0]).astype(np.int32)
183 | yy = np.round(coords1[:,1]).astype(np.int32)
184 |
185 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
186 | xx = xx[v]
187 | yy = yy[v]
188 | flow1 = flow1[v]
189 |
190 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
191 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
192 |
193 | flow_img[yy, xx] = flow1
194 | valid_img[yy, xx] = 1
195 |
196 | return flow_img, valid_img
197 |
198 | def spatial_transform(self, img1, img2, flow, valid):
199 | # randomly sample scale
200 |
201 | ht, wd = img1.shape[:2]
202 | min_scale = np.maximum(
203 | (self.crop_size[0] + 1) / float(ht),
204 | (self.crop_size[1] + 1) / float(wd))
205 |
206 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
207 | scale_x = np.clip(scale, min_scale, None)
208 | scale_y = np.clip(scale, min_scale, None)
209 |
210 | if np.random.rand() < self.spatial_aug_prob:
211 | # rescale the images
212 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
213 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
214 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
215 |
216 | if self.do_flip:
217 | if np.random.rand() < 0.5: # h-flip
218 | img1 = img1[:, ::-1]
219 | img2 = img2[:, ::-1]
220 | flow = flow[:, ::-1] * [-1.0, 1.0]
221 | valid = valid[:, ::-1]
222 |
223 | margin_y = 20
224 | margin_x = 50
225 |
226 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
227 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
228 |
229 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
230 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
231 |
232 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
233 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
234 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
235 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
236 | return img1, img2, flow, valid
237 |
238 |
239 | def __call__(self, img1, img2, flow, valid):
240 | img1, img2 = self.color_transform(img1, img2)
241 | img1, img2 = self.eraser_transform(img1, img2)
242 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
243 |
244 | img1 = np.ascontiguousarray(img1)
245 | img2 = np.ascontiguousarray(img2)
246 | flow = np.ascontiguousarray(flow)
247 | valid = np.ascontiguousarray(valid)
248 |
249 | return img1, img2, flow, valid
250 |
--------------------------------------------------------------------------------
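A small call sketch (not from the repo): FlowAugmentor expects uint8 H x W x 3 images and a float32 H x W x 2 flow, and returns crops of crop_size.

import numpy as np
from utils.augmentor import FlowAugmentor   # assumes core/ is on sys.path

aug = FlowAugmentor(crop_size=(368, 496))
img1 = np.random.randint(0, 255, (440, 1024, 3), dtype=np.uint8)
img2 = np.random.randint(0, 255, (440, 1024, 3), dtype=np.uint8)
flow = np.random.randn(440, 1024, 2).astype(np.float32)
img1, img2, flow = aug(img1, img2, flow)    # photometric + eraser + spatial aug, then crop
print(img1.shape, flow.shape)               # (368, 496, 3) (368, 496, 2)
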
/code.v.1.0/core/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 |     Expects a two-dimensional flow field of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
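A quick sketch (not from the repo): colorize a synthetic flow field with flow_to_image.

import numpy as np
from utils.flow_viz import flow_to_image   # assumes core/ is on sys.path

flow = np.zeros((100, 200, 2), dtype=np.float32)
flow[..., 0] = 3.0                         # uniform rightward motion
img = flow_to_image(flow)                  # uint8 RGB image of shape (100, 200, 3)
print(img.shape, img.dtype)
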
/code.v.1.0/core/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | # This code was originally taken from RAFT without modification
2 | # https://github.com/princeton-vl/RAFT/blob/master/utils/frame_utils.py
3 |
4 | import numpy as np
5 | from PIL import Image
6 | from os.path import *
7 | import re
8 |
9 | import cv2
10 | cv2.setNumThreads(0)
11 | cv2.ocl.setUseOpenCL(False)
12 |
13 | TAG_CHAR = np.array([202021.25], np.float32)
14 |
15 | def readFlow(fn):
16 | """ Read .flo file in Middlebury format"""
17 | # Code adapted from:
18 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
19 |
20 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
21 | # print 'fn = %s'%(fn)
22 | with open(fn, 'rb') as f:
23 | magic = np.fromfile(f, np.float32, count=1)
24 | if 202021.25 != magic:
25 | print('Magic number incorrect. Invalid .flo file')
26 | return None
27 | else:
28 | w = np.fromfile(f, np.int32, count=1)
29 | h = np.fromfile(f, np.int32, count=1)
30 | # print 'Reading %d x %d flo file\n' % (w, h)
31 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
32 | # Reshape data into 3D array (columns, rows, bands)
33 | # The reshape here is for visualization, the original code is (w,h,2)
34 | return np.resize(data, (int(h), int(w), 2))
35 |
36 | def readPFM(file):
37 | file = open(file, 'rb')
38 |
39 | color = None
40 | width = None
41 | height = None
42 | scale = None
43 | endian = None
44 |
45 | header = file.readline().rstrip()
46 | if header == b'PF':
47 | color = True
48 | elif header == b'Pf':
49 | color = False
50 | else:
51 | raise Exception('Not a PFM file.')
52 |
53 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
54 | if dim_match:
55 | width, height = map(int, dim_match.groups())
56 | else:
57 | raise Exception('Malformed PFM header.')
58 |
59 | scale = float(file.readline().rstrip())
60 | if scale < 0: # little-endian
61 | endian = '<'
62 | scale = -scale
63 | else:
64 | endian = '>' # big-endian
65 |
66 | data = np.fromfile(file, endian + 'f')
67 | shape = (height, width, 3) if color else (height, width)
68 |
69 | data = np.reshape(data, shape)
70 | data = np.flipud(data)
71 | return data
72 |
73 | def writeFlow(filename,uv,v=None):
74 | """ Write optical flow to file.
75 |
76 | If v is None, uv is assumed to contain both u and v channels,
77 | stacked in depth.
78 | Original code by Deqing Sun, adapted from Daniel Scharstein.
79 | """
80 | nBands = 2
81 |
82 | if v is None:
83 | assert(uv.ndim == 3)
84 | assert(uv.shape[2] == 2)
85 | u = uv[:,:,0]
86 | v = uv[:,:,1]
87 | else:
88 | u = uv
89 |
90 | assert(u.shape == v.shape)
91 | height,width = u.shape
92 | f = open(filename,'wb')
93 | # write the header
94 | f.write(TAG_CHAR)
95 | np.array(width).astype(np.int32).tofile(f)
96 | np.array(height).astype(np.int32).tofile(f)
97 | # arrange into matrix form
98 | tmp = np.zeros((height, width*nBands))
99 | tmp[:,np.arange(width)*2] = u
100 | tmp[:,np.arange(width)*2 + 1] = v
101 | tmp.astype(np.float32).tofile(f)
102 | f.close()
103 |
104 |
105 | def readFlowKITTI(filename):
106 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
107 | flow = flow[:,:,::-1].astype(np.float32)
108 | flow, valid = flow[:, :, :2], flow[:, :, 2]
109 | flow = (flow - 2**15) / 64.0
110 | return flow, valid
111 |
112 | def readDispKITTI(filename):
113 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
114 | valid = disp > 0.0
115 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
116 | return flow, valid
117 |
118 |
119 | def writeFlowKITTI(filename, uv):
120 | uv = 64.0 * uv + 2**15
121 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
122 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
123 | cv2.imwrite(filename, uv[..., ::-1])
124 |
125 |
126 | def read_gen(file_name, pil=False):
127 | ext = splitext(file_name)[-1]
128 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
129 | return Image.open(file_name)
130 | elif ext == '.bin' or ext == '.raw':
131 | return np.load(file_name)
132 | elif ext == '.flo':
133 | return readFlow(file_name).astype(np.float32)
134 | elif ext == '.pfm':
135 | flow = readPFM(file_name).astype(np.float32)
136 | if len(flow.shape) == 2:
137 | return flow
138 | else:
139 | return flow[:, :, :-1]
140 | return []
141 |
--------------------------------------------------------------------------------
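A round-trip sketch for the Middlebury .flo helpers (not from the repo; the /tmp path is just an example location):

import numpy as np
from utils.frame_utils import readFlow, writeFlow   # assumes core/ is on sys.path

flow = np.random.randn(64, 128, 2).astype(np.float32)
writeFlow('/tmp/example.flo', flow)
flow_back = readFlow('/tmp/example.flo')
assert np.allclose(flow, flow_back)
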
/code.v.1.0/core/utils/grid_sample.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def grid_sample(image, optical):
6 | N, C, IH, IW = image.shape
7 | _, H, W, _ = optical.shape
8 |
9 | ix = optical[..., 0]
10 | iy = optical[..., 1]
11 |
12 | ix = ((ix + 1) / 2) * (IW-1)
13 | iy = ((iy + 1) / 2) * (IH-1)
14 | with torch.no_grad():
15 | ix_nw = torch.floor(ix)
16 | iy_nw = torch.floor(iy)
17 | ix_ne = ix_nw + 1
18 | iy_ne = iy_nw
19 | ix_sw = ix_nw
20 | iy_sw = iy_nw + 1
21 | ix_se = ix_nw + 1
22 | iy_se = iy_nw + 1
23 |
24 | nw = (ix_se - ix) * (iy_se - iy)
25 | ne = (ix - ix_sw) * (iy_sw - iy)
26 | sw = (ix_ne - ix) * (iy - iy_ne)
27 | se = (ix - ix_nw) * (iy - iy_nw)
28 |
29 | with torch.no_grad():
30 | torch.clamp(ix_nw, 0, IW-1, out=ix_nw)
31 | torch.clamp(iy_nw, 0, IH-1, out=iy_nw)
32 |
33 | torch.clamp(ix_ne, 0, IW-1, out=ix_ne)
34 | torch.clamp(iy_ne, 0, IH-1, out=iy_ne)
35 |
36 | torch.clamp(ix_sw, 0, IW-1, out=ix_sw)
37 | torch.clamp(iy_sw, 0, IH-1, out=iy_sw)
38 |
39 | torch.clamp(ix_se, 0, IW-1, out=ix_se)
40 | torch.clamp(iy_se, 0, IH-1, out=iy_se)
41 |
42 | image = image.view(N, C, IH * IW)
43 |
44 |
45 | nw_val = torch.gather(image, 2, (iy_nw * IW + ix_nw).long().view(N, 1, H * W).repeat(1, C, 1))
46 | ne_val = torch.gather(image, 2, (iy_ne * IW + ix_ne).long().view(N, 1, H * W).repeat(1, C, 1))
47 | sw_val = torch.gather(image, 2, (iy_sw * IW + ix_sw).long().view(N, 1, H * W).repeat(1, C, 1))
48 | se_val = torch.gather(image, 2, (iy_se * IW + ix_se).long().view(N, 1, H * W).repeat(1, C, 1))
49 |
50 | out_val = (nw_val.view(N, C, H, W) * nw.view(N, 1, H, W) +
51 | ne_val.view(N, C, H, W) * ne.view(N, 1, H, W) +
52 | sw_val.view(N, C, H, W) * sw.view(N, 1, H, W) +
53 | se_val.view(N, C, H, W) * se.view(N, 1, H, W))
54 |
55 | return out_val
--------------------------------------------------------------------------------
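A sanity-check sketch (not from the repo): for in-range coordinates this hand-written sampler matches F.grid_sample with align_corners=True, while remaining usable when higher-order gradients through the grid are needed (see the comment in utils.py).

import torch
import torch.nn.functional as F
from utils.grid_sample import grid_sample   # assumes core/ is on sys.path

image = torch.randn(1, 3, 16, 20)
grid = torch.rand(1, 8, 10, 2) * 2 - 1      # normalized (x, y) coordinates in [-1, 1]
out_manual = grid_sample(image, grid)
out_ref = F.grid_sample(image, grid, align_corners=True)
print((out_manual - out_ref).abs().max())   # close to zero (float32 precision)
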
/code.v.1.0/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 | from .grid_sample import grid_sample
7 |
8 |
9 | class InputPadder:
10 | """ Pads images such that dimensions are divisible by 8 """
11 | def __init__(self, dims, mode='sintel'):
12 | self.ht, self.wd = dims[-2:]
13 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
14 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
15 | if mode == 'sintel':
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
17 | else:
18 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
19 |
20 | def pad(self, *inputs):
21 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
22 |
23 | def unpad(self,x):
24 | ht, wd = x.shape[-2:]
25 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
26 | return x[..., c[0]:c[1], c[2]:c[3]]
27 |
28 |
29 | def forward_interpolate(flow):
30 | flow = flow.detach().cpu().numpy()
31 | dx, dy = flow[0], flow[1]
32 |
33 | ht, wd = dx.shape
34 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
35 |
36 | x1 = x0 + dx
37 | y1 = y0 + dy
38 |
39 | x1 = x1.reshape(-1)
40 | y1 = y1.reshape(-1)
41 | dx = dx.reshape(-1)
42 | dy = dy.reshape(-1)
43 |
44 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
45 | x1 = x1[valid]
46 | y1 = y1[valid]
47 | dx = dx[valid]
48 | dy = dy[valid]
49 |
50 | flow_x = interpolate.griddata(
51 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow_y = interpolate.griddata(
54 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
55 |
56 | flow = np.stack([flow_x, flow_y], axis=0)
57 | return torch.from_numpy(flow).float()
58 |
59 |
60 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
61 | """ Wrapper for grid_sample, uses pixel coordinates """
62 | H, W = img.shape[-2:]
63 | xgrid, ygrid = coords.split([1,1], dim=-1)
64 | xgrid = 2*xgrid/(W-1) - 1
65 | ygrid = 2*ygrid/(H-1) - 1
66 |
67 | grid = torch.cat([xgrid, ygrid], dim=-1)
68 | img = F.grid_sample(img, grid, align_corners=True)
69 |
70 | # Enable higher order grad for JR
71 | # img = grid_sample(img, grid)
72 |
73 | if mask:
74 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
75 | return img, mask.float()
76 |
77 | return img
78 |
79 |
80 | def coords_grid(batch, ht, wd, device):
81 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
82 | coords = torch.stack(coords[::-1], dim=0).float()
83 | return coords[None].repeat(batch, 1, 1, 1)
84 |
85 |
86 | def upflow8(flow, mode='bilinear'):
87 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
88 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
89 |
--------------------------------------------------------------------------------
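A padding sketch (not from the repo): InputPadder pads to multiples of 8 and unpad restores the original spatial size.

import torch
from utils.utils import InputPadder   # assumes core/ is on sys.path

image1 = torch.randn(1, 3, 436, 1024)
image2 = torch.randn(1, 3, 436, 1024)
padder = InputPadder(image1.shape)            # 'sintel' mode: symmetric padding
image1, image2 = padder.pad(image1, image2)   # -> (1, 3, 440, 1024)
flow = torch.randn(1, 2, 440, 1024)
flow = padder.unpad(flow)                     # -> (1, 2, 436, 1024)
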
/code.v.1.0/ref/val.txt:
--------------------------------------------------------------------------------
1 | Applying weight normalization to BasicUpdateBlock
2 | Parameter Count: 5.243 M
3 | VALID | FORWARD abs_diff: 3.1601126194000244; rel_diff: 0.0005030641914345324; nstep: 53
4 | Validation (clean) EPE: 1.432517170906067 (1.432517170906067), 1px: 0.906150862584164, 3px: 0.9568943325276342, 5px: 0.9685142592463305
5 | VALID | FORWARD abs_diff: 0.7322859764099121; rel_diff: 0.00010898241453105584; nstep: 59
6 | VALID | FORWARD abs_diff: 55.754356384277344; rel_diff: 0.00757289445027709; nstep: 48
7 | Validation (final) EPE: 2.7917442321777344 (2.7917442321777344), 1px: 0.8529280729345681, 3px: 0.9170532694536889, 5px: 0.9365406933832148
8 | Validation KITTI: EPE: 5.427668504863977 (5.427668504863977), F1: 16.667239367961884 (16.667239367961884)
9 | Applying weight normalization to BasicUpdateBlock
10 | Parameter Count: 12.800 M
11 | VALID | FORWARD abs_diff: 3.439692735671997; rel_diff: 0.0005472839525020399; nstep: 53
12 | Validation (clean) EPE: 1.453827977180481 (1.453827977180481), 1px: 0.9146167279857274, 3px: 0.9612179145570596, 5px: 0.9716016637976287
13 | VALID | FORWARD abs_diff: 2.7573065757751465; rel_diff: 0.00041015823748834397; nstep: 44
14 | VALID | FORWARD abs_diff: 40.787933349609375; rel_diff: 0.005504987121547275; nstep: 54
15 | Validation (final) EPE: 2.578798294067383 (2.578798294067383), 1px: 0.8648370765776335, 3px: 0.9247630216423374, 5px: 0.9420040652278926
16 | Validation KITTI: EPE: 3.972053589001298 (3.972053589001298), F1: 13.408777117729187 (13.408777117729187)
17 | Applying weight normalization to BasicUpdateBlock
18 | Parameter Count: 12.800 M
19 | VALID | FORWARD abs_diff: 3.5625743865966797; rel_diff: 0.0005668736362606357; nstep: 51
20 | Validation (clean) EPE: 1.3733853101730347 (1.3733853101730347), 1px: 0.9143194620474535, 3px: 0.961116638444476, 5px: 0.9717625759844098
21 | VALID | FORWARD abs_diff: 1.6139500141143799; rel_diff: 0.00024008954221575165; nstep: 52
22 | VALID | FORWARD abs_diff: 27.838302612304688; rel_diff: 0.0037760218577109006; nstep: 52
23 | Validation (final) EPE: 2.6194019317626953 (2.6194019317626953), 1px: 0.8644313481614472, 3px: 0.9246781902573611, 5px: 0.9420075314657803
24 | Validation KITTI: EPE: 3.973951353058219 (3.973951353058219), F1: 13.621531426906586 (13.621531426906586)
25 | Applying weight normalization to BasicUpdateBlock
26 | Parameter Count: 12.800 M
27 | VALID | FORWARD abs_diff: 3.938171148300171; rel_diff: 0.0006266361620816458; nstep: 52
28 | Validation (clean) EPE: 1.3641184568405151 (1.3641184568405151), 1px: 0.9150200509059743, 3px: 0.9611289370265778, 5px: 0.9716694996437628
29 | VALID | FORWARD abs_diff: 2.300001859664917; rel_diff: 0.0003421608202927882; nstep: 53
30 | VALID | FORWARD abs_diff: 32.62757873535156; rel_diff: 0.004405397824122709; nstep: 52
31 | Validation (final) EPE: 2.622418165206909 (2.622418165206909), 1px: 0.8643445910887555, 3px: 0.9246018275951196, 5px: 0.9417596599553072
32 | Validation KITTI: EPE: 4.018190502673388 (4.018190502673388), F1: 13.917067646980286 (13.917067646980286)
33 |
--------------------------------------------------------------------------------
/code.v.1.0/train_B.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -u main.py --name deq-flow-B-chairs --stage chairs --validation chairs \
4 | --gpus 0 1 --num_steps 120000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 \
5 | --wnorm --f_solver anderson --f_thres 36 \
6 | --n_losses 6 --phantom_grad 1
7 |
8 | python -u main.py --name deq-flow-B-things --stage things \
9 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-B-chairs.pth \
10 | --gpus 0 1 --num_steps 120000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 \
11 | --wnorm --f_solver anderson --f_thres 40 \
12 | --n_losses 2 --phantom_grad 3
13 |
14 |
--------------------------------------------------------------------------------
/code.v.1.0/train_B_demo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -u main.py --name deq-flow-B-1-step-grad-things --stage things \
4 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-B-chairs.pth \
5 | --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 \
6 | --wnorm --f_thres 40 --f_solver anderson
7 |
--------------------------------------------------------------------------------
/code.v.1.0/train_H_demo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -u main.py --name deq-flow-H-1-step-grad-ad-things --stage things \
4 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-H-chairs.pth \
5 | --gpus 0 1 --num_steps 120000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 \
6 | --wnorm --huge --f_solver anderson \
7 | --f_thres 60 --phantom_grad 1
8 |
--------------------------------------------------------------------------------
/code.v.1.0/train_H_full.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | python -u main.py --name deq-flow-H-chairs --stage chairs --validation chairs \
5 | --gpus 0 1 2 --num_steps 120000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 \
6 | --wnorm --huge --f_solver broyden \
7 | --f_thres 36 --n_losses 6 --phantom_grad 1 --sliced_core
8 |
9 | python -u main.py --name deq-flow-H-things --stage things \
10 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-H-chairs.pth \
11 | --gpus 0 1 2 --num_steps 120000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 \
12 | --wnorm --huge --f_solver broyden \
13 | --f_thres 36 --n_losses 6 --phantom_grad 3 --sliced_core
14 |
15 | python -u main.py --name deq-flow-H-sintel --stage sintel \
16 | --validation sintel --restore_ckpt checkpoints/deq-flow-H-things.pth \
17 | --gpus 0 1 2 --num_steps 120000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.0001 --gamma=0.90 \
18 | --wnorm --huge --f_solver broyden \
19 | --f_thres 36 --n_losses 6 --phantom_grad 3 --sliced_core
20 |
21 | python -u main.py --name deq-flow-H-kitti --stage kitti \
22 | --validation kitti --restore_ckpt checkpoints/deq-flow-H-sintel.pth \
23 | --gpus 0 1 2 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.0001 --gamma=0.90 \
24 | --wnorm --huge --f_solver broyden \
25 | --f_thres 36 --n_losses 6 --phantom_grad 1 --sliced_core
26 |
--------------------------------------------------------------------------------
/code.v.1.0/val.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -u main.py --eval --name deq-flow-B-things --stage things \
4 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-B-things-test.pth --gpus 0 \
5 | --wnorm --f_thres 40 --f_solver anderson
6 |
7 | python -u main.py --eval --name deq-flow-H-things --stage things \
8 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-H-things-test-1.pth --gpus 0 \
9 | --wnorm --f_thres 36 --f_solver broyden --huge
10 |
11 | python -u main.py --eval --name deq-flow-H-things --stage things \
12 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-H-things-test-2.pth --gpus 0 \
13 | --wnorm --f_thres 36 --f_solver broyden --huge
14 |
15 | python -u main.py --eval --name deq-flow-H-things --stage things \
16 | --validation sintel kitti --restore_ckpt checkpoints/deq-flow-H-things-test-3.pth --gpus 0 \
17 | --wnorm --f_thres 36 --f_solver broyden --huge
18 |
--------------------------------------------------------------------------------
/code.v.1.0/viz.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append('core')
4 |
5 | import argparse
6 | import copy
7 | import os
8 | import time
9 |
10 | import datasets
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | import cv2
17 | from PIL import Image
18 |
19 | from utils import flow_viz, frame_utils
20 | from utils.utils import InputPadder, forward_interpolate
21 |
22 |
23 | @torch.no_grad()
24 | def sintel_visualization(model, split='test', warm_start=False, fixed_point_reuse=False, output_path='sintel_viz', **kwargs):
25 | """ Create visualization for the Sintel dataset """
26 | model.eval()
27 | for dstype in ['clean', 'final']:
28 | split = 'test' if split == 'test' else 'training'
29 | test_dataset = datasets.MpiSintel(split=split, aug_params=None, dstype=dstype)
30 |
31 | flow_prev, sequence_prev, fixed_point = None, None, None
32 | for test_id in range(len(test_dataset)):
33 | image1, image2, (sequence, frame) = test_dataset[test_id]
34 | if sequence != sequence_prev:
35 | flow_prev = None
36 | fixed_point = None
37 |
38 | padder = InputPadder(image1.shape)
39 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
40 |
41 | flow_low, flow_pr, info = model(image1, image2, flow_init=flow_prev, cached_result=fixed_point, **kwargs)
42 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
43 |
44 | if warm_start:
45 | flow_prev = forward_interpolate(flow_low[0])[None].cuda()
46 |
47 | if fixed_point_reuse:
48 | net, flow_pred_low = info['cached_result']
49 | flow_pred_low = forward_interpolate(flow_pred_low[0])[None].cuda()
50 | fixed_point = (net, flow_pred_low)
51 |
52 | output_dir = os.path.join(output_path, dstype, sequence)
53 | output_file = os.path.join(output_dir, 'frame%04d.png' % (frame+1))
54 |
55 | if not os.path.exists(output_dir):
56 | os.makedirs(output_dir)
57 |
58 | # visualization
59 | img_flow = flow_viz.flow_to_image(flow)
60 | img_flow = cv2.cvtColor(img_flow, cv2.COLOR_RGB2BGR)
61 | cv2.imwrite(output_file, img_flow, [int(cv2.IMWRITE_PNG_COMPRESSION), 1])
62 |
63 | sequence_prev = sequence
64 |
65 |
66 | @torch.no_grad()
67 | def kitti_visualization(model, split='test', output_path='kitti_viz'):
68 | """ Create visualization for the KITTI dataset """
69 | model.eval()
70 | split = 'testing' if split == 'test' else 'training'
71 | test_dataset = datasets.KITTI(split=split, aug_params=None)
72 |
73 | if not os.path.exists(output_path):
74 | os.makedirs(output_path)
75 |
76 | for test_id in range(len(test_dataset)):
77 | image1, image2, (frame_id, ) = test_dataset[test_id]
78 | padder = InputPadder(image1.shape, mode='kitti')
79 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
80 |
81 | _, flow_pr, _ = model(image1, image2)
82 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
83 |
84 | output_filename = os.path.join(output_path, frame_id)
85 |
86 | # visualization
87 | img_flow = flow_viz.flow_to_image(flow)
88 | img_flow = cv2.cvtColor(img_flow, cv2.COLOR_RGB2BGR)
89 | cv2.imwrite(output_filename, img_flow, [int(cv2.IMWRITE_PNG_COMPRESSION), 1])
90 |
91 |
92 |
--------------------------------------------------------------------------------
/code.v.1.0/viz.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -u main.py --viz --name deq-flow-H-kitti --stage kitti \
4 | --viz_set kitti --restore_ckpt checkpoints/deq-flow-H-kitti.pth --gpus 0 \
5 | --wnorm --f_thres 36 --f_solver broyden --huge
6 |
--------------------------------------------------------------------------------
/code.v.2.0/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/deq-flow/869d4bcd9c1d227cbfa16639994d1b118510a9bb/code.v.2.0/core/__init__.py
--------------------------------------------------------------------------------
/code.v.2.0/core/corr.py:
--------------------------------------------------------------------------------
1 | # This code was originally taken from RAFT without modification
2 | # https://github.com/princeton-vl/RAFT/blob/master/core/corr.py
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from utils.utils import bilinear_sampler, coords_grid
7 |
8 |
9 | class CorrBlock:
10 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
11 | self.num_levels = num_levels
12 | self.radius = radius
13 | self.corr_pyramid = []
14 |
15 | # all pairs correlation
16 | corr = CorrBlock.corr(fmap1, fmap2)
17 |
18 | batch, h1, w1, dim, h2, w2 = corr.shape
19 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
20 |
21 | self.corr_pyramid.append(corr)
22 | for i in range(self.num_levels-1):
23 | corr = F.avg_pool2d(corr, 2, stride=2)
24 | self.corr_pyramid.append(corr)
25 |
26 | def __call__(self, coords):
27 | r = self.radius
28 | coords = coords.permute(0, 2, 3, 1)
29 | batch, h1, w1, _ = coords.shape
30 |
31 | out_pyramid = []
32 | for i in range(self.num_levels):
33 | corr = self.corr_pyramid[i]
34 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
35 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
36 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
37 |
38 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
39 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
40 | coords_lvl = centroid_lvl + delta_lvl
41 |
42 | corr = bilinear_sampler(corr, coords_lvl)
43 | corr = corr.view(batch, h1, w1, -1)
44 | out_pyramid.append(corr)
45 |
46 | out = torch.cat(out_pyramid, dim=-1)
47 | return out.permute(0, 3, 1, 2).contiguous().float()
48 |
49 | @staticmethod
50 | def corr(fmap1, fmap2):
51 | batch, dim, ht, wd = fmap1.shape
52 | fmap1 = fmap1.view(batch, dim, ht*wd)
53 | fmap2 = fmap2.view(batch, dim, ht*wd)
54 |
55 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
56 | corr = corr.view(batch, ht, wd, 1, ht, wd)
57 | return corr / torch.sqrt(torch.tensor(dim).float())
58 |
59 |
60 |
61 |
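A minimal shape sketch of how CorrBlock is queried (not part of corr.py; the 46x62 feature-map size is an arbitrary assumption, and it presumes `core/` is on the Python path so that corr and its utils import):

import torch
from corr import CorrBlock

# Two random 256-channel feature maps at 1/8 input resolution (size chosen arbitrarily).
fmap1 = torch.randn(1, 256, 46, 62)
fmap2 = torch.randn(1, 256, 46, 62)
corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)

# Per-pixel lookup centers (x, y); DEQ-Flow passes its current coords1 estimate here.
coords = torch.zeros(1, 2, 46, 62)
out = corr_fn(coords)
print(out.shape)  # num_levels * (2*radius+1)**2 channels -> torch.Size([1, 324, 46, 62])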
--------------------------------------------------------------------------------
/code.v.2.0/core/datasets.py:
--------------------------------------------------------------------------------
1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2 |
3 | import numpy as np
4 | import torch
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 |
8 | import os
9 | import math
10 | import random
11 | from glob import glob
12 | import os.path as osp
13 |
14 | from utils import frame_utils
15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16 |
17 |
18 | class FlowDataset(data.Dataset):
19 | def __init__(self, aug_params=None, sparse=False):
20 | self.augmentor = None
21 | self.sparse = sparse
22 | if aug_params is not None:
23 | if sparse:
24 | self.augmentor = SparseFlowAugmentor(**aug_params)
25 | else:
26 | self.augmentor = FlowAugmentor(**aug_params)
27 |
28 | self.is_test = False
29 | self.init_seed = False
30 | self.flow_list = []
31 | self.image_list = []
32 | self.extra_info = []
33 |
34 | def __getitem__(self, index):
35 |
36 | if self.is_test:
37 | img1 = frame_utils.read_gen(self.image_list[index][0])
38 | img2 = frame_utils.read_gen(self.image_list[index][1])
39 | img1 = np.array(img1).astype(np.uint8)[..., :3]
40 | img2 = np.array(img2).astype(np.uint8)[..., :3]
41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
43 | return img1, img2, self.extra_info[index]
44 |
45 | if not self.init_seed:
46 | worker_info = torch.utils.data.get_worker_info()
47 | if worker_info is not None:
48 | torch.manual_seed(worker_info.id)
49 | np.random.seed(worker_info.id)
50 | random.seed(worker_info.id)
51 | self.init_seed = True
52 |
53 | index = index % len(self.image_list)
54 | valid = None
55 | if self.sparse:
56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
57 | else:
58 | flow = frame_utils.read_gen(self.flow_list[index])
59 |
60 | img1 = frame_utils.read_gen(self.image_list[index][0])
61 | img2 = frame_utils.read_gen(self.image_list[index][1])
62 |
63 | flow = np.array(flow).astype(np.float32)
64 | img1 = np.array(img1).astype(np.uint8)
65 | img2 = np.array(img2).astype(np.uint8)
66 |
67 | # grayscale images
68 | if len(img1.shape) == 2:
69 | img1 = np.tile(img1[...,None], (1, 1, 3))
70 | img2 = np.tile(img2[...,None], (1, 1, 3))
71 | else:
72 | img1 = img1[..., :3]
73 | img2 = img2[..., :3]
74 |
75 | if self.augmentor is not None:
76 | if self.sparse:
77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
78 | else:
79 | img1, img2, flow = self.augmentor(img1, img2, flow)
80 |
81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float()
84 |
85 | if valid is not None:
86 | valid = torch.from_numpy(valid)
87 | else:
88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
89 |
90 | return img1, img2, flow, valid.float()
91 |
92 |
93 | def __rmul__(self, v):
94 | self.flow_list = v * self.flow_list
95 | self.image_list = v * self.image_list
96 | return self
97 |
98 | def __len__(self):
99 | return len(self.image_list)
100 |
101 |
102 | class MpiSintel(FlowDataset):
103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
104 | super(MpiSintel, self).__init__(aug_params)
105 | flow_root = osp.join(root, split, 'flow')
106 | image_root = osp.join(root, split, dstype)
107 |
108 | if split == 'test':
109 | self.is_test = True
110 |
111 | for scene in os.listdir(image_root):
112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
113 | for i in range(len(image_list)-1):
114 | self.image_list += [ [image_list[i], image_list[i+1]] ]
115 | self.extra_info += [ (scene, i) ] # scene and frame_id
116 |
117 | if split != 'test':
118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
119 |
120 |
121 | class FlyingChairs(FlowDataset):
122 | def __init__(self, aug_params=None, split='training', root='datasets/FlyingChairs_release/data'):
123 | super(FlyingChairs, self).__init__(aug_params)
124 |
125 | images = sorted(glob(osp.join(root, '*.ppm')))
126 | flows = sorted(glob(osp.join(root, '*.flo')))
127 | assert (len(images)//2 == len(flows))
128 |
129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
130 | for i in range(len(flows)):
131 | xid = split_list[i]
132 | if (split=='training' and xid==1) or (split=='validation' and xid==2):
133 | self.flow_list += [ flows[i] ]
134 | self.image_list += [ [images[2*i], images[2*i+1]] ]
135 |
136 |
137 | class FlyingThings3D(FlowDataset):
138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', split='train', dstype='frames_cleanpass'):
139 | super(FlyingThings3D, self).__init__(aug_params)
140 |
141 | if split == 'train':
142 | dir_prefix = 'TRAIN'
143 | elif split == 'test':
144 | dir_prefix = 'TEST'
145 | else:
146 | raise ValueError('Unknown split for FlyingThings3D.')
147 |
148 | for cam in ['left']:
149 | for direction in ['into_future', 'into_past']:
150 | image_dirs = sorted(glob(osp.join(root, dstype, f'{dir_prefix}/*/*')))
151 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
152 |
153 | flow_dirs = sorted(glob(osp.join(root, f'optical_flow/{dir_prefix}/*/*')))
154 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
155 |
156 | for idir, fdir in zip(image_dirs, flow_dirs):
157 | images = sorted(glob(osp.join(idir, '*.png')) )
158 | flows = sorted(glob(osp.join(fdir, '*.pfm')) )
159 | for i in range(len(flows)-1):
160 | if direction == 'into_future':
161 | self.image_list += [ [images[i], images[i+1]] ]
162 | self.flow_list += [ flows[i] ]
163 | elif direction == 'into_past':
164 | self.image_list += [ [images[i+1], images[i]] ]
165 | self.flow_list += [ flows[i+1] ]
166 |
167 |
168 | class KITTI(FlowDataset):
169 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
170 | super(KITTI, self).__init__(aug_params, sparse=True)
171 | if split == 'testing':
172 | self.is_test = True
173 |
174 | root = osp.join(root, split)
175 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
176 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
177 |
178 | for img1, img2 in zip(images1, images2):
179 | frame_id = img1.split('/')[-1]
180 | self.extra_info += [ [frame_id] ]
181 | self.image_list += [ [img1, img2] ]
182 |
183 | if split == 'training':
184 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
185 |
186 |
187 | class HD1K(FlowDataset):
188 | def __init__(self, aug_params=None, root='datasets/HD1k'):
189 | super(HD1K, self).__init__(aug_params, sparse=True)
190 |
191 | seq_ix = 0
192 | while 1:
193 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
194 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
195 |
196 | if len(flows) == 0:
197 | break
198 |
199 | for i in range(len(flows)-1):
200 | self.flow_list += [flows[i]]
201 | self.image_list += [ [images[i], images[i+1]] ]
202 |
203 | seq_ix += 1
204 |
205 |
206 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
207 | """ Create the data loader for the corresponding trainign set """
208 |
209 | if args.stage == 'chairs':
210 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
211 | train_dataset = FlyingChairs(aug_params, split='training')
212 |
213 | elif args.stage == 'things':
214 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
215 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
216 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
217 | train_dataset = clean_dataset + final_dataset
218 |
219 | elif args.stage == 'sintel':
220 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
221 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
222 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
223 | sintel_final = MpiSintel(aug_params, split='training', dstype='final')
224 |
225 | if TRAIN_DS == 'C+T+K+S+H':
226 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
227 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
228 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
229 |
230 | elif TRAIN_DS == 'C+T+K/S':
231 | train_dataset = 100*sintel_clean + 100*sintel_final + things
232 |
233 | elif args.stage == 'kitti':
234 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
235 | train_dataset = KITTI(aug_params, split='training')
236 |
237 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
238 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
239 |
240 | print('Training with %d image pairs' % len(train_dataset))
241 | return train_loader
242 |
243 |
--------------------------------------------------------------------------------
/code.v.2.0/core/deq/__init__.py:
--------------------------------------------------------------------------------
1 | from .deq_class import get_deq
2 |
--------------------------------------------------------------------------------
/code.v.2.0/core/deq/arg_utils.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | def add_deq_args(parser):
5 | parser.add_argument('--wnorm', action='store_true', help="use weight normalization")
6 | parser.add_argument('--f_solver', default='anderson', type=str, choices=['anderson', 'broyden', 'naive_solver'],
7 | help='forward solver to use')
8 | parser.add_argument('--b_solver', default='broyden', type=str, choices=['anderson', 'broyden', 'naive_solver'],
9 | help='backward solver to use')
10 | parser.add_argument('--f_thres', type=int, default=40, help='forward pass solver threshold')
11 | parser.add_argument('--b_thres', type=int, default=40, help='backward pass solver threshold')
12 | parser.add_argument('--f_eps', type=float, default=1e-3, help='forward pass solver stopping criterion')
13 | parser.add_argument('--b_eps', type=float, default=1e-3, help='backward pass solver stopping criterion')
14 | parser.add_argument('--f_stop_mode', type=str, default="abs", help="forward pass fixed-point convergence stop mode")
15 | parser.add_argument('--b_stop_mode', type=str, default="abs", help="backward pass fixed-point convergence stop mode")
16 | parser.add_argument('--eval_factor', type=float, default=1.5, help="factor to scale up the f_thres at test for better convergence.")
17 | parser.add_argument('--eval_f_thres', type=int, default=0, help="directly set the f_thres at test.")
18 |
19 | parser.add_argument('--indexing_core', action='store_true', help="use the indexing core implementation.")
20 | parser.add_argument('--ift', action='store_true', help="use implicit differentiation.")
21 | parser.add_argument('--safe_ift', action='store_true', help="use a safer function for IFT to avoid a potential segmentation fault in older PyTorch versions.")
22 | parser.add_argument('--n_losses', type=int, default=1, help="number of loss terms (uniformly spaced, 1 + fixed point correction).")
23 | parser.add_argument('--indexing', type=int, nargs='+', default=[], help="indexing for fixed point correction.")
24 | parser.add_argument('--phantom_grad', type=int, nargs='+', default=[1], help="steps of Phantom Grad")
25 | parser.add_argument('--tau', type=float, default=1.0, help="damping factor for unrolled Phantom Grad")
26 | parser.add_argument('--sup_all', action='store_true', help="supervise all the trajectories by Phantom Grad.")
27 |
28 | parser.add_argument('--sradius_mode', action='store_true', help="monitor the spectral radius during validation")
29 |
30 |
31 |
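A short sketch of how add_deq_args is attached to an argparse parser (main.py does the real wiring; the --name flag and the parsed values below are illustrative only):

import argparse
from deq.arg_utils import add_deq_args

parser = argparse.ArgumentParser()
parser.add_argument('--name', default='deq-flow-demo')  # illustrative extra flag
add_deq_args(parser)  # registers the DEQ flags defined above

# Roughly mirrors the solver settings used by the H-model evaluation scripts.
args = parser.parse_args(['--wnorm', '--f_solver', 'broyden', '--f_thres', '36'])
print(args.f_solver, args.f_thres, args.phantom_grad)  # broyden 36 [1]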
--------------------------------------------------------------------------------
/code.v.2.0/core/deq/deq_class.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from termcolor import colored
7 |
8 | from .solvers import get_solver
9 | from .norm import reset_weight_norm
10 | from .grad import make_pair, backward_factory
11 | from .jacobian import power_method
12 |
13 |
14 | class DEQBase(nn.Module):
15 | def __init__(self, args):
16 | super(DEQBase, self).__init__()
17 |
18 | self.args = args
19 | self.f_solver = get_solver(args.f_solver)
20 | self.b_solver = get_solver(args.b_solver)
21 |
22 | self.f_thres = args.f_thres
23 | self.b_thres = args.b_thres
24 |
25 | self.f_eps = args.f_eps
26 | self.b_eps = args.b_eps
27 |
28 | self.f_stop_mode = args.f_stop_mode
29 | self.b_stop_mode = args.b_stop_mode
30 |
31 | self.eval_f_thres = args.eval_f_thres if args.eval_f_thres > 0 else int(self.f_thres * args.eval_factor)
32 |
33 | self.hook = None
34 |
35 | def _log_convergence(self, info, name='FORWARD', color='yellow'):
36 | state = 'TRAIN' if self.training else 'VALID'
37 | alt_mode = 'rel' if self.f_stop_mode == 'abs' else 'abs'
38 |
39 | rel_lowest, abs_lowest = info['rel_lowest'].mean().item(), info['abs_lowest'].mean().item()
40 | nstep = info['nstep']
41 |
42 | show_str = f'{state} | {name} | rel: {rel_lowest}; abs: {abs_lowest}; nstep: {nstep}'
43 | print(colored(show_str, color))
44 |
45 | def _sradius(self, deq_func, z_star):
46 | with torch.enable_grad():
47 | new_z_star = deq_func(z_star.requires_grad_())
48 | _, sradius = power_method(new_z_star, z_star, n_iters=75)
49 |
50 | return sradius
51 |
52 | def _solve_fixed_point(
53 | self, deq_func, z_init,
54 | log=False, f_thres=None,
55 | **kwargs
56 | ):
57 | raise NotImplementedError
58 |
59 | def forward(
60 | self, deq_func, z_init,
61 | log=False, sradius_mode=False, writer=None,
62 | **kwargs
63 | ):
64 | raise NotImplementedError
65 |
66 |
67 | class DEQIndexing(DEQBase):
68 | def __init__(self, args):
69 | super(DEQIndexing, self).__init__(args)
70 |
71 | # Define gradient functions through the backward factory
72 | if args.n_losses > 1:
73 | n_losses = min(args.f_thres, args.n_losses)
74 | delta = int(args.f_thres // n_losses)
75 | self.indexing = [(k+1)*delta for k in range(n_losses)]
76 | else:
77 | self.indexing = [*args.indexing, args.f_thres]
78 |
79 | # By default, we use the same phantom grad for all corrections.
80 | # You can also set different grad steps a, b, and c for different terms by ``args.phantom_grad a b c ...''.
81 | indexing_pg = make_pair(self.indexing, args.phantom_grad)
82 | produce_grad = [
83 | backward_factory(grad_type=pg, tau=args.tau, sup_all=args.sup_all) for pg in indexing_pg
84 | ]
85 | if args.ift:
86 | # Enabling args.ift will replace the last gradient function by IFT.
87 | produce_grad[-1] = backward_factory(
88 | grad_type='ift', safe_ift=args.safe_ift, b_solver=self.b_solver,
89 | b_solver_kwargs=dict(threshold=args.b_thres, eps=args.b_eps, stop_mode=args.b_stop_mode)
90 | )
91 |
92 | self.produce_grad = produce_grad
93 |
94 | def _solve_fixed_point(
95 | self, deq_func, z_init,
96 | log=False, f_thres=None,
97 | **kwargs
98 | ):
99 | if f_thres is None: f_thres = self.f_thres
100 | indexing = self.indexing if self.training else None
101 |
102 | with torch.no_grad():
103 | z_star, trajectory, info = self.f_solver(
104 | deq_func, x0=z_init, threshold=f_thres, # To reuse previous coarse fixed points
105 | eps=self.f_eps, stop_mode=self.f_stop_mode, indexing=indexing
106 | )
107 |
108 | if log: self._log_convergence(info, name="FORWARD", color="yellow")
109 |
110 | return z_star, trajectory, info
111 |
112 | def forward(
113 | self, deq_func, z_init,
114 | log=False, sradius_mode=False, writer=None,
115 | **kwargs
116 | ):
117 | if self.training:
118 | _, trajectory, info = self._solve_fixed_point(deq_func, z_init, log=log, **kwargs)
119 |
120 | z_out = []
121 | for z_pred, produce_grad in zip(trajectory, self.produce_grad):
122 | z_out += produce_grad(self, deq_func, z_pred) # See lib/grad.py for the backward pass implementations
123 |
124 | z_out = [deq_func.vec2list(each) for each in z_out]
125 | else:
126 | # During inference, we directly solve for fixed point
127 | z_star, _, info = self._solve_fixed_point(deq_func, z_init, log=log, f_thres=self.eval_f_thres)
128 |
129 | sradius = self._sradius(deq_func, z_star) if sradius_mode else torch.zeros(1, device=z_star.device)
130 | info['sradius'] = sradius
131 |
132 | z_out = [deq_func.vec2list(z_star)]
133 |
134 | return z_out, info
135 |
136 |
137 | class DEQSliced(DEQBase):
138 | def __init__(self, args):
139 | super(DEQSliced, self).__init__(args)
140 |
141 | # Define gradient functions through the backward factory
142 | if args.n_losses > 1:
143 | self.indexing = [int(args.f_thres // args.n_losses) for _ in range(args.n_losses)]
144 | else:
145 | self.indexing = np.diff([0, *args.indexing, args.f_thres]).tolist()
146 |
147 | # By default, we use the same phantom grad for all corrections.
148 | # You can also set different grad steps a, b, and c for different terms by ``args.phantom_grad a b c ...''.
149 | indexing_pg = make_pair(self.indexing, args.phantom_grad)
150 | produce_grad = [
151 | backward_factory(grad_type=pg, tau=args.tau, sup_all=args.sup_all) for pg in indexing_pg
152 | ]
153 | if args.ift:
154 | # Enabling args.ift will replace the last gradient function by IFT.
155 | produce_grad[-1] = backward_factory(
156 | grad_type='ift', safe_ift=args.safe_ift, b_solver=self.b_solver,
157 | b_solver_kwargs=dict(threshold=args.b_thres, eps=args.b_eps, stop_mode=args.b_stop_mode)
158 | )
159 |
160 | self.produce_grad = produce_grad
161 |
162 | def _solve_fixed_point(
163 | self, deq_func, z_init,
164 | log=False, f_thres=None,
165 | **kwargs
166 | ):
167 | with torch.no_grad():
168 | z_star, _, info = self.f_solver(
169 | deq_func, x0=z_init, threshold=f_thres, # To reuse previous coarse fixed points
170 | eps=self.f_eps, stop_mode=self.f_stop_mode
171 | )
172 |
173 | if log: self._log_convergence(info, name="FORWARD", color="yellow")
174 |
175 | return z_star, info
176 |
177 | def forward(
178 | self, deq_func, z_star,
179 | log=False, sradius_mode=False, writer=None,
180 | **kwargs
181 | ):
182 | if self.training:
183 | z_out = []
184 | for f_thres, produce_grad in zip(self.indexing, self.produce_grad):
185 | z_star, info = self._solve_fixed_point(deq_func, z_star, f_thres=f_thres, log=log)
186 | z_out += produce_grad(self, deq_func, z_star, writer=writer) # See lib/grad.py for implementations
187 | z_star = z_out[-1] # Add the gradient chain to the solver.
188 |
189 | z_out = [deq_func.vec2list(each) for each in z_out]
190 | else:
191 | # During inference, we directly solve for fixed point
192 | z_star, info = self._solve_fixed_point(deq_func, z_star, f_thres=self.eval_f_thres, log=log)
193 |
194 | sradius = self._sradius(deq_func, z_star) if sradius_mode else torch.zeros(1, device=z_star.device)
195 | info['sradius'] = sradius
196 |
197 | z_out = [deq_func.vec2list(z_star)]
198 |
199 | return z_out, info
200 |
201 |
202 | def get_deq(args):
203 | if args.indexing_core:
204 | return DEQIndexing
205 | else:
206 | return DEQSliced
207 |
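A small worked example of how the two schedulers split the same solver budget (f_thres=36 and n_losses=6 are taken from the training scripts purely for illustration): DEQIndexing supervises intermediate states of one long solve, while DEQSliced chains several short solves and supervises each segment's endpoint.

f_thres, n_losses = 36, 6
delta = f_thres // n_losses

indexing_steps = [(k + 1) * delta for k in range(n_losses)]       # as in DEQIndexing.__init__
sliced_budgets = [f_thres // n_losses for _ in range(n_losses)]   # as in DEQSliced.__init__

print(indexing_steps)  # [6, 12, 18, 24, 30, 36] -> snapshot steps inside a single solve
print(sliced_budgets)  # [6, 6, 6, 6, 6, 6]      -> per-segment step budgets, one solve each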
--------------------------------------------------------------------------------
/code.v.2.0/core/deq/dropout.py:
--------------------------------------------------------------------------------
1 | import types
2 |
3 | from torch.nn.parameter import Parameter
4 | import torch.nn as nn
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 |
9 |
10 |
11 | ##############################################################################################################
12 | #
13 | # Temporal DropConnect in a feed-forward setting
14 | #
15 | ##############################################################################################################
16 |
17 |
18 | class WeightDrop(torch.nn.Module):
19 | def __init__(self, module, weights, dropout=0, temporal=True):
20 | """
21 | Weight DropConnect, adapted from a recurrent setting by Merity et al. 2017
22 |
23 | :param module: The module whose weights are to be applied dropout on
24 | :param weights: A 2D list identifying the weights to be regularized. Each element of weights should be a
25 | list containing the "path" to the weight kernel. For instance, if we want to regularize
26 | module.layer2.weight3, then this should be ["layer2", "weight3"].
27 | :param dropout: The dropout rate (0 means no dropout)
28 | :param temporal: Whether we apply DropConnect only to the temporal parts of the weight (empirically we found
29 | this not very important)
30 | """
31 | super(WeightDrop, self).__init__()
32 | self.module = module
33 | self.weights = weights
34 | self.dropout = dropout
35 | self.temporal = temporal
36 | if self.dropout > 0.0:
37 | self._setup()
38 |
39 | def _setup(self):
40 | for path in self.weights:
41 | full_name_w = '.'.join(path)
42 |
43 | module = self.module
44 | name_w = path[-1]
45 | for i in range(len(path) - 1):
46 | module = getattr(module, path[i])
47 | w = getattr(module, name_w)
48 | del module._parameters[name_w]
49 | module.register_parameter(name_w + '_raw', Parameter(w.data))
50 |
51 | def _setweights(self):
52 | for path in self.weights:
53 | module = self.module
54 | name_w = path[-1]
55 | for i in range(len(path) - 1):
56 | module = getattr(module, path[i])
57 | raw_w = getattr(module, name_w + '_raw')
58 |
59 | if len(raw_w.size()) > 2 and raw_w.size(2) > 1 and self.temporal:
60 | # Drop the temporal parts of the weight; if 1x1 convolution then drop the whole kernel
61 | w = torch.cat([F.dropout(raw_w[:, :, :-1], p=self.dropout, training=self.training),
62 | raw_w[:, :, -1:]], dim=2)
63 | else:
64 | w = F.dropout(raw_w, p=self.dropout, training=self.training)
65 |
66 | setattr(module, name_w, w)
67 |
68 | def forward(self, *args, **kwargs):
69 | if self.dropout > 0.0:
70 | self._setweights()
71 | return self.module.forward(*args, **kwargs)
72 |
73 |
74 | def matrix_diag(a, dim=2):
75 | """
76 | a has dimension (N, (L,) C), we want a matrix/batch diag that produces (N, (L,) C, C) from the last dimension of a
77 | """
78 | if dim == 2:
79 | res = torch.zeros(a.size(0), a.size(1), a.size(1))
80 | res.as_strided(a.size(), [res.stride(0), res.size(2)+1]).copy_(a)
81 | else:
82 | res = torch.zeros(a.size(0), a.size(1), a.size(2), a.size(2))
83 | res.as_strided(a.size(), [res.stride(0), res.stride(1), res.size(3)+1]).copy_(a)
84 | return res
85 |
86 |
87 | ##############################################################################################################
88 | #
89 | # Embedding dropout
90 | #
91 | ##############################################################################################################
92 |
93 |
94 | def embedded_dropout(embed, words, dropout=0.1, scale=None):
95 | """
96 | Apply embedding encoder (whose weight we apply a dropout)
97 |
98 | :param embed: The embedding layer
99 | :param words: The input sequence
100 | :param dropout: The embedding weight dropout rate
101 | :param scale: Scaling factor for the dropped embedding weight
102 | :return: The embedding output
103 | """
104 | if dropout:
105 | mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(
106 | embed.weight) / (1 - dropout)
107 | mask = Variable(mask)
108 | masked_embed_weight = mask * embed.weight
109 | else:
110 | masked_embed_weight = embed.weight
111 |
112 | if scale:
113 | masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight
114 |
115 | padding_idx = embed.padding_idx
116 | if padding_idx is None:
117 | padding_idx = -1
118 |
119 | X = F.embedding(words, masked_embed_weight, padding_idx, embed.max_norm, embed.norm_type,
120 | embed.scale_grad_by_freq, embed.sparse)
121 | return X
122 |
123 |
124 |
125 | ##############################################################################################################
126 | #
127 | # Variational dropout (for input/output layers, and for hidden layers)
128 | #
129 | ##############################################################################################################
130 |
131 |
132 | class VariationalHidDropout2d(nn.Module):
133 | def __init__(self, dropout=0.0):
134 | super(VariationalHidDropout2d, self).__init__()
135 | self.dropout = dropout
136 | self.mask = None
137 |
138 | def forward(self, x):
139 | if not self.training or self.dropout == 0:
140 | return x
141 | bsz, d, H, W = x.shape
142 | if self.mask is None:
143 | m = torch.zeros(bsz, d, H, W).bernoulli_(1 - self.dropout).to(x)
144 | self.mask = m.requires_grad_(False) / (1 - self.dropout)
145 | return self.mask * x
146 |
147 |
148 |
149 |
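A hedged usage sketch for WeightDrop (the Conv2d, shapes, and 0.1 rate are arbitrary choices; it assumes this file is importable as deq.dropout). Each entry of `weights` is an attribute path, so regularizing a bare module's own kernel is just [['weight']]:

import torch
import torch.nn as nn
from deq.dropout import WeightDrop

conv = nn.Conv2d(32, 32, 3, padding=1)
# temporal=False drops entries of the whole kernel rather than only its temporal slice.
wd_conv = WeightDrop(conv, [['weight']], dropout=0.1, temporal=False)

x = torch.randn(2, 32, 16, 16)
y = wd_conv(x)   # a fresh dropout mask is applied to the raw weight on every training-mode call
print(y.shape)   # torch.Size([2, 32, 16, 16])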
--------------------------------------------------------------------------------
/code.v.2.0/core/deq/grad.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch import autograd
6 |
7 | from .solvers import anderson, broyden, naive_solver
8 |
9 |
10 | def make_pair(target, source):
11 | if len(target) == len(source):
12 | return source
13 | elif len(source) == 1:
14 | return [source[0] for _ in range(len(target))]
15 | else:
16 | raise ValueError('Unable to align the arg sequence!')
17 |
18 |
19 | def backward_factory(
20 | grad_type='ift',
21 | safe_ift=False,
22 | b_solver=anderson,
23 | b_solver_kwargs=dict(),
24 | sup_all=False,
25 | tau=1.0):
26 | """
27 | [2019-NeurIPS] Deep Equilibrium Models
28 | [2021-NeurIPS] On Training Implicit Models
29 |
30 | This function implements a factory for the backward pass of implicit deep learning,
31 | e.g., DEQ (implicit models), Hamburger (optimization layer), etc.
32 | It now supports IFT, 1-step Grad, and Phantom Grad.
33 |
34 | Args:
35 | grad_type (string, int):
36 | grad_type should be ``ift`` or an int. Default ``ift``.
37 | Set to ``ift`` to enable the implicit differentiation mode.
38 | When passing a number k to this function, it runs UPG with steps k and damping tau.
39 | safe_ift (bool):
40 | Replace the O(1) hook implementation with a safer one. Default ``False``.
41 | Set to ``True`` to avoid a (potential) segmentation fault (under previous versions of PyTorch).
42 | b_solver (type):
43 | Solver for the IFT backward pass. Default ``anderson``.
44 | Supported solvers: anderson, broyden.
45 | b_solver_kwargs (dict):
46 | Collection of backward solver kwargs, e.g.,
47 | threshold (int), max steps for the backward solver,
48 | stop_mode (string), criterion for convergence,
49 | etc.
50 | See solver.py to check all the kwargs.
51 | sup_all (bool):
52 | Indicate whether to supervise all the trajectories by Phantom Grad.
53 | Set ``True`` to return all trajectory states from Phantom Grad.
54 | tau (float):
55 | Damping factor for Phantom Grad. Default ``1.0``.
56 | 0.5 is recommended for CIFAR-10. 1.0 for DEQ flow.
57 | For DEQ flow, the gating function in GRU naturally produces adaptive tau values.
58 |
59 | Returns:
60 | A gradient functor for implicit deep learning.
61 | Args:
62 | trainer (nn.Module): the module that employs implicit deep learning.
63 | func (type): function that defines the ``f`` in ``z = f(z)``.
64 | z_pred (torch.Tensor): latent state to run the backward pass.
65 |
66 | Returns:
67 | (list(torch.Tensor)): a list of tensors that tracks the gradient info.
68 |
69 | """
70 |
71 | if grad_type == 'ift':
72 | assert b_solver in [naive_solver, anderson, broyden]
73 |
74 | if safe_ift:
75 | def plain_ift_grad(trainer, func, z_pred, **kwargs):
76 | z_pred = z_pred.requires_grad_()
77 | new_z_pred = func(z_pred) # 1-step grad for df/dtheta
78 |
79 | z_pred_copy = new_z_pred.clone().detach().requires_grad_()
80 | new_z_pred_copy = func(z_pred_copy)
81 | def backward_hook(grad):
82 | grad_star, _, info = b_solver(
83 | lambda y: autograd.grad(new_z_pred_copy, z_pred_copy, y, retain_graph=True)[0] + grad,
84 | torch.zeros_like(grad), **b_solver_kwargs
85 | )
86 | return grad_star
87 | new_z_pred.register_hook(backward_hook)
88 |
89 | return [new_z_pred]
90 | return plain_ift_grad
91 | else:
92 | def hook_ift_grad(trainer, func, z_pred, **kwargs):
93 | z_pred = z_pred.requires_grad_()
94 | new_z_pred = func(z_pred) # 1-step grad for df/dtheta
95 |
96 | def backward_hook(grad):
97 | if trainer.hook is not None:
98 | trainer.hook.remove() # To avoid infinite loop
99 | grad_star, _, info = b_solver(
100 | lambda y: autograd.grad(new_z_pred, z_pred, y, retain_graph=True)[0] + grad,
101 | torch.zeros_like(grad), **b_solver_kwargs
102 | )
103 | return grad_star
104 | trainer.hook = new_z_pred.register_hook(backward_hook)
105 |
106 | return [new_z_pred]
107 | return hook_ift_grad
108 | else:
109 | assert type(grad_type) is int and grad_type >= 1
110 | n_phantom_grad = grad_type
111 |
112 | if sup_all:
113 | def sup_all_phantom_grad(trainer, func, z_pred, **kwargs):
114 | z_out = []
115 | for _ in range(n_phantom_grad):
116 | z_pred = (1 - tau) * z_pred + tau * func(z_pred)
117 | z_out.append(z_pred)
118 |
119 | return z_out
120 | return sup_all_phantom_grad
121 | else:
122 | def phantom_grad(trainer, func, z_pred, **kwargs):
123 | for _ in range(n_phantom_grad):
124 | z_pred = (1 - tau) * z_pred + tau * func(z_pred)
125 |
126 | return [z_pred]
127 | return phantom_grad
128 |
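A minimal sketch of driving a backward_factory functor by hand on a toy contraction (the 2-step setting, the map f(z) = 0.5*z + b, and the detached "solver output" are invented for illustration; inside the repo these functors are only called from deq_class.py):

import torch
from deq.grad import backward_factory

b = torch.nn.Parameter(torch.ones(3))
func = lambda z: 0.5 * z + b        # toy fixed-point map; its exact fixed point is 2 * b

z_star = (2 * b).detach()           # stands in for the gradient-free forward solve
produce_grad = backward_factory(grad_type=2, tau=1.0)   # 2-step Phantom Grad
z_out = produce_grad(None, func, z_star)                # trainer is unused by Phantom Grad

z_out[-1].sum().backward()
print(b.grad)   # tensor([1.5, 1.5, 1.5]): d/db of sum(f(f(z_star))) through the 2 unrolled steps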
--------------------------------------------------------------------------------
/code.v.2.0/core/deq/jacobian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 |
7 | def jac_loss_estimate(f0, z0, vecs=2, create_graph=True):
8 | """Estimating tr(J^TJ)=tr(JJ^T) via Hutchinson estimator
9 |
10 | Args:
11 | f0 (torch.Tensor): Output of the function f (whose J is to be analyzed)
12 | z0 (torch.Tensor): Input to the function f
13 | vecs (int, optional): Number of random Gaussian vectors to use. Defaults to 2.
14 | create_graph (bool, optional): Whether to create backward graph (e.g., to train on this loss).
15 | Defaults to True.
16 |
17 | Returns:
18 | torch.Tensor: A 1x1 torch tensor that encodes the (shape-normalized) jacobian loss
19 | """
20 | vecs = vecs
21 | result = 0
22 | for i in range(vecs):
23 | v = torch.randn(*z0.shape).to(z0)
24 | vJ = torch.autograd.grad(f0, z0, v, retain_graph=True, create_graph=create_graph)[0]
25 | result += vJ.norm()**2
26 | return result / vecs / np.prod(z0.shape)
27 |
28 |
29 | def power_method(f0, z0, n_iters=200):
30 | """Estimating the spectral radius of J using power method
31 |
32 | Args:
33 | f0 (torch.Tensor): Output of the function f (whose J is to be analyzed)
34 | z0 (torch.Tensor): Input to the function f
35 | n_iters (int, optional): Number of power method iterations. Defaults to 200.
36 |
37 | Returns:
38 | tuple: (largest eigenvector, largest (abs.) eigenvalue)
39 | """
40 | evector = torch.randn_like(z0)
41 | bsz = evector.shape[0]
42 | for i in range(n_iters):
43 | vTJ = torch.autograd.grad(f0, z0, evector, retain_graph=(i < n_iters-1), create_graph=False)[0]
44 | evalue = (vTJ * evector).reshape(bsz, -1).sum(1, keepdim=True) / (evector * evector).reshape(bsz, -1).sum(1, keepdim=True)
45 | evector = (vTJ.reshape(bsz, -1) / vTJ.reshape(bsz, -1).norm(dim=1, keepdim=True)).reshape_as(z0)
46 | return (evector, torch.abs(evalue))
47 |
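A quick sanity-check sketch for power_method on a linear map whose spectral radius is known in closed form (the symmetric 8x8 map and batch size are arbitrary; torch.linalg.eigvalsh assumes a reasonably recent PyTorch):

import torch
from deq.jacobian import power_method

A = torch.randn(8, 8)
W = 0.05 * (A + A.t())                       # symmetric, so the spectral radius is max |eigenvalue|

z0 = torch.randn(4, 8, requires_grad=True)   # small batch of toy states
f0 = z0 @ W                                  # f(z) = z W, whose Jacobian is W
_, sradius = power_method(f0, z0, n_iters=100)

print(sradius.mean().item())                        # power-method estimate
print(torch.linalg.eigvalsh(W).abs().max().item())  # reference value; the two should agree closely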
--------------------------------------------------------------------------------
/code.v.2.0/core/deq/layer_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 |
7 |
8 | class DEQWrapper:
9 | def __init__(self, func, z_init=list()):
10 | z_shape = []
11 | z_indexing = [0]
12 | for each in z_init:
13 | z_shape.append(each.shape)
14 | z_indexing.append(np.prod(each.shape[1:]))
15 |
16 | self.func = func
17 | self.z_shape = z_shape
18 | self.z_indexing = np.cumsum(z_indexing)
19 |
20 | def list2vec(self, *z_list):
21 | '''Convert list of tensors to a batched vector (B, ...)'''
22 |
23 | z_list = [each.flatten(start_dim=1) for each in z_list]
24 | return torch.cat(z_list, dim=1)
25 |
26 | def vec2list(self, z_hidden):
27 | '''Convert a batched vector back to a list'''
28 |
29 | z_list = []
30 | z_indexing = self.z_indexing
31 | for i, shape in enumerate(self.z_shape):
32 | z_list.append(z_hidden[:, z_indexing[i]:z_indexing[i+1]].view(shape))
33 | return z_list
34 |
35 | def __call__(self, z_hidden):
36 | '''A function call to the DEQ f'''
37 |
38 | z_list = self.vec2list(z_hidden)
39 | z_list = self.func(*z_list)
40 | z_hidden = self.list2vec(*z_list)
41 |
42 | return z_hidden
43 |
44 | def norm_diff(self, z_new, z_old, show_list=False):
45 | if show_list:
46 | z_new, z_old = self.vec2list(z_new), self.vec2list(z_old)
47 | return [(z_new[i] - z_old[i]).norm().item() for i in range(len(z_new))]
48 |
49 | return (z_new - z_old).norm().item()
50 |
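A hedged sketch of the DEQWrapper round trip used by deq_flow.py (the two toy tensors stand in for the hidden state and coords1; shapes and the toy update are arbitrary):

import torch
from deq.layer_utils import DEQWrapper

net = torch.randn(2, 4, 8, 8)                 # stand-in for the hidden state
coords = torch.randn(2, 2, 8, 8)              # stand-in for coords1

func = lambda h, c: (0.5 * h, 0.5 * c)        # toy update defined in list space
deq_func = DEQWrapper(func, (net, coords))

z = deq_func.list2vec(net, coords)            # (2, 4*8*8 + 2*8*8) = (2, 384)
z_next = deq_func(z)                          # vec -> list -> func -> vec
net_next, coords_next = deq_func.vec2list(z_next)
print(z.shape, net_next.shape, coords_next.shape)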
--------------------------------------------------------------------------------
/code.v.2.0/core/deq/norm/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .weight_norm import apply_weight_norm, reset_weight_norm, remove_weight_norm, register_wn_module
3 |
--------------------------------------------------------------------------------
/code.v.2.0/core/deq/norm/weight_norm.py:
--------------------------------------------------------------------------------
1 | import types
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from torch.nn import functional as F
7 | from torch.nn.parameter import Parameter
8 |
9 |
10 | def _norm(p, dim):
11 | """Computes the norm over all dimensions except dim"""
12 | if dim is None:
13 | return p.norm()
14 | elif dim == 0:
15 | output_size = (p.size(0),) + (1,) * (p.dim() - 1)
16 | return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size)
17 | elif dim == p.dim() - 1:
18 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
19 | return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size)
20 | else:
21 | return _norm(p.transpose(0, dim), 0).transpose(0, dim)
22 |
23 |
24 | def compute_weight(module, name, dim):
25 | g = getattr(module, name + '_g')
26 | v = getattr(module, name + '_v')
27 | return v * (g / _norm(v, dim))
28 |
29 |
30 | def apply_atom_wn(module, names, dims):
31 | if type(names) is str:
32 | names = [names]
33 |
34 | if type(dims) is int:
35 | dims = [dims]
36 |
37 | assert len(names) == len(dims)
38 |
39 | for name, dim in zip(names, dims):
40 | weight = getattr(module, name)
41 |
42 | # remove w from parameter list
43 | del module._parameters[name]
44 |
45 | # add g and v as new parameters and express w as g/||v|| * v
46 | module.register_parameter(name + '_g', Parameter(_norm(weight, dim).data))
47 | module.register_parameter(name + '_v', Parameter(weight.data))
48 | setattr(module, name, compute_weight(module, name, dim))
49 |
50 | module._wn_names = names
51 | module._wn_dims = dims
52 |
53 |
54 | def reset_atom_wn(module):
55 | # Typically, every time the module is called we need to recompute the weight. However,
56 | # in the case of DEQ, the same weight is shared across layers, and we can save
57 | # a lot of intermediate memory by just recomputing once (at the beginning of first call).
58 |
59 | for name, dim in zip(module._wn_names, module._wn_dims):
60 | setattr(module, name, compute_weight(module, name, dim))
61 |
62 |
63 | def remove_atom_wn(module):
64 | for name, dim in zip(module._wn_names, module._wn_dims):
65 | weight = compute_weight(module, name, dim)
66 | delattr(module, name)
67 | del module._parameters[name + '_g']
68 | del module._parameters[name + '_v']
69 | module.register_parameter(name, Parameter(weight.data))
70 |
71 | del module._wn_names
72 | del module._wn_dims
73 |
74 |
75 | target_modules = {
76 | nn.Linear: ('weight', 0),
77 | nn.Conv1d: ('weight', 0),
78 | nn.Conv2d: ('weight', 0),
79 | nn.Conv3d: ('weight', 0)
80 | }
81 |
82 |
83 | def register_wn_module(module_class, names='weight', dims=0):
84 | '''
85 | Register your self-defined module class for ``nested_weight_norm''.
86 | This module class will be automatically indexed for WN.
87 |
88 | Args:
89 | module_class (type): module class to be indexed for weight norm (WN).
90 | names (string): attribute name of ``module_class'' for WN to be applied.
91 | dims (int, optional): dimension over which to compute the norm
92 |
93 | Returns:
94 | None
95 | '''
96 | target_modules[module_class] = (names, dims)
97 |
98 |
99 | def _is_skip_name(name, filter_out):
100 | for skip_name in filter_out:
101 | if name.startswith(skip_name):
102 | return True
103 |
104 | return False
105 |
106 |
107 | def apply_weight_norm(model, filter_out=None):
108 | if type(filter_out) is str:
109 | filter_out = [filter_out]
110 |
111 | for name, module in model.named_modules():
112 | if filter_out and _is_skip_name(name, filter_out):
113 | continue
114 |
115 | class_type = type(module)
116 | if class_type in target_modules:
117 | apply_atom_wn(module, *target_modules[class_type])
118 |
119 |
120 | def reset_weight_norm(model):
121 | for module in model.modules():
122 | if hasattr(module, '_wn_names'):
123 | reset_atom_wn(module)
124 |
125 |
126 | def remove_weight_norm(model):
127 | for module in model.modules():
128 | if hasattr(module, '_wn_names'):
129 | remove_atom_wn(module)
130 |
131 |
132 | if __name__ == '__main__':
133 | z = torch.randn(8, 128, 32, 32)
134 |
135 | net = nn.Conv2d(128, 256, 3, padding=1)
136 | z_orig = net(z)
137 |
138 | apply_weight_norm(net)
139 | z_wn = net(z)
140 |
141 | reset_weight_norm(net)
142 | z_wn_reset = net(z)
143 |
144 | remove_weight_norm(net)
145 | z_back = net(z)
146 |
147 | print((z_orig - z_wn).abs().mean().item())
148 | print((z_orig - z_wn_reset).abs().mean().item())
149 | print((z_orig - z_back).abs().mean().item())
150 |
151 | net = nn.Sequential(
152 | nn.Conv2d(128, 256, 3, padding=1),
153 | nn.GELU(),
154 | nn.Conv2d(256, 128, 3, padding=1)
155 | )
156 | z_orig = net(z)
157 |
158 | apply_weight_norm(net)
159 | z_wn = net(z)
160 |
161 | reset_weight_norm(net)
162 | z_wn_reset = net(z)
163 |
164 | remove_weight_norm(net)
165 | z_back = net(z)
166 |
167 | print((z_orig - z_wn).abs().mean().item())
168 | print((z_orig - z_wn_reset).abs().mean().item())
169 | print((z_orig - z_back).abs().mean().item())
170 |
171 |
--------------------------------------------------------------------------------
/code.v.2.0/core/deq_flow.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from update import UpdateBlock
7 | from extractor import Encoder
8 | from corr import CorrBlock
9 | from utils.utils import bilinear_sampler, coords_grid
10 |
11 | from gma import Attention
12 |
13 | from deq import get_deq
14 | from deq.norm import apply_weight_norm, reset_weight_norm
15 | from deq.layer_utils import DEQWrapper
16 |
17 | from metrics import process_metrics
18 |
19 |
20 | try:
21 | autocast = torch.cuda.amp.autocast
22 | except:
23 | # dummy autocast for PyTorch < 1.6
24 | class autocast:
25 | def __init__(self, enabled):
26 | pass
27 | def __enter__(self):
28 | pass
29 | def __exit__(self, *args):
30 | pass
31 |
32 |
33 | class DEQFlow(nn.Module):
34 | def __init__(self, args):
35 | super(DEQFlow, self).__init__()
36 | self.args = args
37 |
38 | odim = 256
39 | args.corr_levels = 4
40 | args.corr_radius = 4
41 |
42 | if args.tiny:
43 | odim = 64
44 | self.hidden_dim = hdim = 32
45 | self.context_dim = cdim = 32
46 | elif args.large:
47 | self.hidden_dim = hdim = 192
48 | self.context_dim = cdim = 192
49 | elif args.huge:
50 | self.hidden_dim = hdim = 256
51 | self.context_dim = cdim = 256
52 | elif args.gigantic:
53 | self.hidden_dim = hdim = 384
54 | self.context_dim = cdim = 384
55 | else:
56 | self.hidden_dim = hdim = 128
57 | self.context_dim = cdim = 128
58 |
59 | if 'dropout' not in self.args:
60 | self.args.dropout = 0
61 |
62 | # feature network, context network, and update block
63 | self.fnet = Encoder(output_dim=odim, norm_fn='instance', dropout=args.dropout)
64 | self.cnet = Encoder(output_dim=cdim, norm_fn='batch', dropout=args.dropout)
65 | self.update_block = UpdateBlock(self.args, hidden_dim=hdim)
66 |
67 | self.mask = nn.Sequential(
68 | nn.Conv2d(hdim, 256, 3, padding=1),
69 | nn.ReLU(inplace=True),
70 | nn.Conv2d(256, 64*9, 1, padding=0)
71 | )
72 |
73 | if args.gma:
74 | self.attn = Attention(dim=cdim, heads=1, max_pos_size=160, dim_head=cdim)
75 | else:
76 | self.attn = None
77 |
78 | # Added the following for DEQ
79 | if args.wnorm:
80 | apply_weight_norm(self.update_block)
81 |
82 | DEQ = get_deq(args)
83 | self.deq = DEQ(args)
84 |
85 | def freeze_bn(self):
86 | for m in self.modules():
87 | if isinstance(m, nn.BatchNorm2d):
88 | m.eval()
89 |
90 | def _initialize_flow(self, img):
91 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
92 | N, _, H, W = img.shape
93 | coords0 = coords_grid(N, H//8, W//8, device=img.device)
94 | coords1 = coords_grid(N, H//8, W//8, device=img.device)
95 |
96 | # optical flow computed as difference: flow = coords1 - coords0
97 | return coords0, coords1
98 |
99 | def _upsample_flow(self, flow, mask):
100 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
101 | N, _, H, W = flow.shape
102 | mask = mask.view(N, 1, 9, 8, 8, H, W)
103 | mask = torch.softmax(mask, dim=2)
104 |
105 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
106 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
107 |
108 | up_flow = torch.sum(mask * up_flow, dim=2)
109 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
110 | return up_flow.reshape(N, 2, 8*H, 8*W)
111 |
112 | def _decode(self, z_out, coords0):
113 | net, coords1 = z_out
114 | up_mask = .25 * self.mask(net)
115 | flow_up = self._upsample_flow(coords1 - coords0, up_mask)
116 |
117 | return flow_up
118 |
119 | def forward(self, image1, image2,
120 | flow_gt=None, valid=None, fc_loss=None,
121 | flow_init=None, cached_result=None,
122 | writer=None, sradius_mode=False,
123 | **kwargs):
124 | """ Estimate optical flow between pair of frames """
125 |
126 | image1 = 2 * (image1 / 255.0) - 1.0
127 | image2 = 2 * (image2 / 255.0) - 1.0
128 |
129 | image1 = image1.contiguous()
130 | image2 = image2.contiguous()
131 |
132 | hdim = self.hidden_dim
133 | cdim = self.context_dim
134 |
135 | # run the feature network
136 | with autocast(enabled=self.args.mixed_precision):
137 | fmap1, fmap2 = self.fnet([image1, image2])
138 |
139 | fmap1 = fmap1.float()
140 | fmap2 = fmap2.float()
141 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
142 |
143 | # run the context network
144 | with autocast(enabled=self.args.mixed_precision):
145 | # cnet = self.cnet(image1)
146 | # net, inp = torch.split(cnet, [hdim, cdim], dim=1)
147 | # net = torch.tanh(net)
148 | inp = self.cnet(image1)
149 | inp = torch.relu(inp)
150 |
151 | if self.attn:
152 | attn = self.attn(inp)
153 | else:
154 | attn = None
155 |
156 | bsz, _, H, W = inp.shape
157 | coords0, coords1 = self._initialize_flow(image1)
158 | net = torch.zeros(bsz, hdim, H, W, device=inp.device)
159 |
160 | if cached_result:
161 | net, flow_pred_prev = cached_result
162 | coords1 = coords0 + flow_pred_prev
163 |
164 | if flow_init is not None:
165 | coords1 = coords1 + flow_init
166 |
167 | if self.args.wnorm:
168 | reset_weight_norm(self.update_block) # Reset weights for WN
169 |
170 | def func(h,c):
171 | if not self.args.all_grad:
172 | c = c.detach()
173 | with autocast(enabled=self.args.mixed_precision):
174 | new_h, delta_flow = self.update_block(h, inp, corr_fn(c), c-coords0, attn) # corr_fn(c) looks up the indexed correlation volumes
175 | new_c = c + delta_flow # F(t+1) = F(t) + \Delta(t)
176 | return new_h, new_c
177 |
178 | deq_func = DEQWrapper(func, (net, coords1))
179 | z_init = deq_func.list2vec(net, coords1)
180 | log = (inp.get_device() == 0 and np.random.uniform(0,1) < 2e-3)
181 |
182 | z_out, info = self.deq(deq_func, z_init, log, sradius_mode, **kwargs)
183 | flow_pred = [self._decode(z, coords0) for z in z_out]
184 |
185 | if self.training:
186 | flow_loss, epe = fc_loss(flow_pred, flow_gt, valid)
187 | metrics = process_metrics(epe, info)
188 |
189 | return flow_loss, metrics
190 | else:
191 | (net, coords1), flow_up = z_out[-1], flow_pred[-1]
192 |
193 | return coords1 - coords0, flow_up, {"sradius": info['sradius'], "cached_result": (net, coords1 - coords0)}
194 |
195 |
--------------------------------------------------------------------------------
/code.v.2.0/core/extractor.py:
--------------------------------------------------------------------------------
1 | # This code was originally taken from RAFT without modification
2 | # https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class ResidualBlock(nn.Module):
10 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
11 | super(ResidualBlock, self).__init__()
12 |
13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
15 | self.relu = nn.ReLU(inplace=True)
16 |
17 | num_groups = planes // 8
18 |
19 | if norm_fn == 'group':
20 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
22 | if not stride == 1:
23 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
24 |
25 | elif norm_fn == 'batch':
26 | self.norm1 = nn.BatchNorm2d(planes)
27 | self.norm2 = nn.BatchNorm2d(planes)
28 | if not stride == 1:
29 | self.norm3 = nn.BatchNorm2d(planes)
30 |
31 | elif norm_fn == 'instance':
32 | self.norm1 = nn.InstanceNorm2d(planes)
33 | self.norm2 = nn.InstanceNorm2d(planes)
34 | if not stride == 1:
35 | self.norm3 = nn.InstanceNorm2d(planes)
36 |
37 | elif norm_fn == 'none':
38 | self.norm1 = nn.Sequential()
39 | self.norm2 = nn.Sequential()
40 | if not stride == 1:
41 | self.norm3 = nn.Sequential()
42 |
43 | if stride == 1:
44 | self.downsample = None
45 |
46 | else:
47 | self.downsample = nn.Sequential(
48 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
49 |
50 |
51 | def forward(self, x):
52 | y = x
53 | y = self.relu(self.norm1(self.conv1(y)))
54 | y = self.relu(self.norm2(self.conv2(y)))
55 |
56 | if self.downsample is not None:
57 | x = self.downsample(x)
58 |
59 | return self.relu(x+y)
60 |
61 |
62 | class BottleneckBlock(nn.Module):
63 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
64 | super(BottleneckBlock, self).__init__()
65 |
66 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
67 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
68 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
69 | self.relu = nn.ReLU(inplace=True)
70 |
71 | num_groups = planes // 8
72 |
73 | if norm_fn == 'group':
74 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
75 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
76 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
77 | if not stride == 1:
78 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
79 |
80 | elif norm_fn == 'batch':
81 | self.norm1 = nn.BatchNorm2d(planes//4)
82 | self.norm2 = nn.BatchNorm2d(planes//4)
83 | self.norm3 = nn.BatchNorm2d(planes)
84 | if not stride == 1:
85 | self.norm4 = nn.BatchNorm2d(planes)
86 |
87 | elif norm_fn == 'instance':
88 | self.norm1 = nn.InstanceNorm2d(planes//4)
89 | self.norm2 = nn.InstanceNorm2d(planes//4)
90 | self.norm3 = nn.InstanceNorm2d(planes)
91 | if not stride == 1:
92 | self.norm4 = nn.InstanceNorm2d(planes)
93 |
94 | elif norm_fn == 'none':
95 | self.norm1 = nn.Sequential()
96 | self.norm2 = nn.Sequential()
97 | self.norm3 = nn.Sequential()
98 | if not stride == 1:
99 | self.norm4 = nn.Sequential()
100 |
101 | if stride == 1:
102 | self.downsample = None
103 |
104 | else:
105 | self.downsample = nn.Sequential(
106 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
107 |
108 | def forward(self, x):
109 | y = x
110 | y = self.relu(self.norm1(self.conv1(y)))
111 | y = self.relu(self.norm2(self.conv2(y)))
112 | y = self.relu(self.norm3(self.conv3(y)))
113 |
114 | if self.downsample is not None:
115 | x = self.downsample(x)
116 |
117 | return self.relu(x+y)
118 |
119 |
120 | class Encoder(nn.Module):
121 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
122 | super(Encoder, self).__init__()
123 | self.norm_fn = norm_fn
124 |
125 | if self.norm_fn == 'group':
126 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
127 |
128 | elif self.norm_fn == 'batch':
129 | self.norm1 = nn.BatchNorm2d(64)
130 |
131 | elif self.norm_fn == 'instance':
132 | self.norm1 = nn.InstanceNorm2d(64)
133 |
134 | elif self.norm_fn == 'none':
135 | self.norm1 = nn.Sequential()
136 |
137 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
138 | self.relu1 = nn.ReLU(inplace=True)
139 |
140 | self.in_planes = 64
141 | self.layer1 = self._make_layer(64, stride=1)
142 | self.layer2 = self._make_layer(96, stride=2)
143 | self.layer3 = self._make_layer(128, stride=2)
144 |
145 | # output convolution
146 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
147 |
148 | self.dropout = None
149 | if dropout > 0:
150 | self.dropout = nn.Dropout2d(p=dropout)
151 |
152 | for m in self.modules():
153 | if isinstance(m, nn.Conv2d):
154 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
155 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
156 | if m.weight is not None:
157 | nn.init.constant_(m.weight, 1)
158 | if m.bias is not None:
159 | nn.init.constant_(m.bias, 0)
160 |
161 | def _make_layer(self, dim, stride=1):
162 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
163 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
164 | layers = (layer1, layer2)
165 |
166 | self.in_planes = dim
167 | return nn.Sequential(*layers)
168 |
169 |
170 | def forward(self, x):
171 |
172 | # if input is list, combine batch dimension
173 | is_list = isinstance(x, tuple) or isinstance(x, list)
174 | if is_list:
175 | batch_dim = x[0].shape[0]
176 | x = torch.cat(x, dim=0)
177 |
178 | x = self.conv1(x)
179 | x = self.norm1(x)
180 | x = self.relu1(x)
181 |
182 | x = self.layer1(x)
183 | x = self.layer2(x)
184 | x = self.layer3(x)
185 |
186 | x = self.conv2(x)
187 |
188 | if self.training and self.dropout is not None:
189 | x = self.dropout(x)
190 |
191 | if is_list:
192 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
193 |
194 | return x
195 |
196 |
197 |
--------------------------------------------------------------------------------
/code.v.2.0/core/gma.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from einops import rearrange
4 |
5 |
6 | class RelPosEmb(nn.Module):
7 | def __init__(
8 | self,
9 | max_pos_size,
10 | dim_head
11 | ):
12 | super().__init__()
13 | self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head)
14 | self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head)
15 |
16 | deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(max_pos_size).view(-1, 1)
17 | rel_ind = deltas + max_pos_size - 1
18 | self.register_buffer('rel_ind', rel_ind)
19 |
20 | def forward(self, q):
21 | batch, heads, h, w, c = q.shape
22 | height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1))
23 | width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1))
24 |
25 | height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h)
26 | width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w)
27 |
28 | height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb)
29 | width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb)
30 |
31 | return height_score + width_score
32 |
33 |
34 | class Attention(nn.Module):
35 | def __init__(
36 | self,
37 | *,
38 | dim,
39 | max_pos_size = 100,
40 | heads = 4,
41 | dim_head = 128,
42 | ):
43 | super().__init__()
44 | self.heads = heads
45 | self.scale = dim_head ** -0.5
46 | inner_dim = heads * dim_head
47 |
48 | self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)
49 |
50 | def forward(self, fmap):
51 | heads, b, c, h, w = self.heads, *fmap.shape
52 |
53 | q, k = self.to_qk(fmap).chunk(2, dim=1)
54 |
55 | q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k))
56 | q = self.scale * q
57 |
58 | sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k)
59 |
60 | sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)')
61 | attn = sim.softmax(dim=-1)
62 |
63 | return attn
64 |
65 |
66 | class Aggregate(nn.Module):
67 | def __init__(
68 | self,
69 | dim,
70 | heads = 4,
71 | dim_head = 128,
72 | ):
73 | super().__init__()
74 | self.heads = heads
75 | self.scale = dim_head ** -0.5
76 | inner_dim = heads * dim_head
77 |
78 | self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False)
79 |
80 | self.gamma = nn.Parameter(torch.zeros(1))
81 |
82 | if dim != inner_dim:
83 | self.project = nn.Conv2d(inner_dim, dim, 1, bias=False)
84 | else:
85 | self.project = None
86 |
87 | def forward(self, attn, fmap):
88 | heads, b, c, h, w = self.heads, *fmap.shape
89 |
90 | v = self.to_v(fmap)
91 | v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads)
92 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
93 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
94 |
95 | if self.project is not None:
96 | out = self.project(out)
97 |
98 | out = fmap + self.gamma * out
99 |
100 | return out
101 |
102 |
103 | if __name__ == "__main__":
104 | att = Attention(dim=128, heads=1)
105 | fmap = torch.randn(2, 128, 40, 90)
106 | out = att(fmap)
107 |
108 | print(out.shape)
109 |
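A minimal usage sketch (not part of the repository) pairing Attention with Aggregate, mirroring how update.py wires GMA: the attention map computed from a context feature map is used to aggregate motion features of the same spatial size. Shapes below are illustrative assumptions.

import torch
from gma import Attention, Aggregate  # assumes 'core' is on sys.path, as in main.py

att = Attention(dim=128, heads=1, dim_head=128)
agg = Aggregate(dim=128, heads=1, dim_head=128)  # dim == heads * dim_head, so no output projection

fmap = torch.randn(2, 128, 40, 90)      # context features (B, C, H, W)
motion = torch.randn(2, 128, 40, 90)    # motion features with matching spatial size

attn = att(fmap)                        # (B, heads, H*W, H*W) attention map
out = agg(attn, motion)                 # (B, 128, 40, 90): residual update scaled by the learned gamma
print(attn.shape, out.shape)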
--------------------------------------------------------------------------------
/code.v.2.0/core/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | MAX_FLOW = 400
4 |
5 | @torch.no_grad()
6 | def compute_epe(flow_pred, flow_gt, valid, max_flow=MAX_FLOW):
7 |     # exclude invalid pixels and extremely large displacements
8 | mag = torch.sum(flow_gt**2, dim=1).sqrt()
9 | valid = (valid >= 0.5) & (mag < max_flow)
10 |
11 | epe = torch.sum((flow_pred - flow_gt)**2, dim=1).sqrt()
12 | epe = torch.masked_fill(epe, ~valid, 0)
13 |
14 | return epe
15 |
16 |
17 | @torch.no_grad()
18 | def process_metrics(epe, info, **kwargs):
19 | epe = epe.flatten(1)
20 | metrics = {
21 | 'epe': epe.mean(dim=1),
22 | '1px': (epe < 1).float().mean(dim=1),
23 | '3px': (epe < 3).float().mean(dim=1),
24 | '5px': (epe < 5).float().mean(dim=1),
25 | 'rel': info['rel_lowest'],
26 | 'abs': info['abs_lowest'],
27 | }
28 |
29 | # dict: N_Metrics -> B // N_GPU
30 | return metrics
31 |
32 |
33 | @torch.no_grad()
34 | def merge_metrics(metrics):
35 | out = dict()
36 |
37 | for key, value in metrics.items():
38 | out[key] = value.mean().item()
39 |
40 | return out
41 |
--------------------------------------------------------------------------------
/code.v.2.0/core/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from gma import Aggregate
6 |
7 |
8 | class FlowHead(nn.Module):
9 | def __init__(self, input_dim=128, hidden_dim=256):
10 | super(FlowHead, self).__init__()
11 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
12 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
13 | self.relu = nn.ReLU(inplace=True)
14 |
15 | def forward(self, x):
16 | return self.conv2(self.relu(self.conv1(x)))
17 |
18 |
19 | class ConvGRU(nn.Module):
20 | def __init__(self, hidden_dim=128, input_dim=192+128):
21 | super(ConvGRU, self).__init__()
22 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
23 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
24 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
25 |
26 | def forward(self, h, x):
27 | hx = torch.cat([h, x], dim=1)
28 |
29 | z = torch.sigmoid(self.convz(hx))
30 | r = torch.sigmoid(self.convr(hx))
31 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
32 |
33 | h = (1-z) * h + z * q
34 | return h
35 |
36 |
37 | class SepConvGRU(nn.Module):
38 | def __init__(self, hidden_dim=128, input_dim=192+128):
39 | super(SepConvGRU, self).__init__()
40 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
41 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
42 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
43 |
44 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
45 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
46 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
47 |
48 | def forward(self, h, x):
49 | # horizontal
50 | hx = torch.cat([h, x], dim=1)
51 | z = torch.sigmoid(self.convz1(hx))
52 | r = torch.sigmoid(self.convr1(hx))
53 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
54 | h = (1-z) * h + z * q
55 |
56 | # vertical
57 | hx = torch.cat([h, x], dim=1)
58 | z = torch.sigmoid(self.convz2(hx))
59 | r = torch.sigmoid(self.convr2(hx))
60 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
61 | h = (1-z) * h + z * q
62 |
63 | return h
64 |
65 |
66 | class MotionEncoder(nn.Module):
67 | def __init__(self, args):
68 | super(MotionEncoder, self).__init__()
69 |
70 | if args.large:
71 | c_dim_1 = 256 + 128
72 | c_dim_2 = 192 + 96
73 |
74 | f_dim_1 = 128 + 64
75 | f_dim_2 = 64 + 32
76 |
77 | cat_dim = 128 + 64
78 | elif args.huge:
79 | c_dim_1 = 256 + 256
80 | c_dim_2 = 192 + 192
81 |
82 | f_dim_1 = 128 + 128
83 | f_dim_2 = 64 + 64
84 |
85 | cat_dim = 128 + 128
86 | elif args.gigantic:
87 | c_dim_1 = 256 + 384
88 | c_dim_2 = 192 + 288
89 |
90 | f_dim_1 = 128 + 192
91 | f_dim_2 = 64 + 96
92 |
93 | cat_dim = 128 + 192
94 | elif args.tiny:
95 | c_dim_1 = 64
96 | c_dim_2 = 48
97 |
98 | f_dim_1 = 32
99 | f_dim_2 = 16
100 |
101 | cat_dim = 32
102 | else:
103 | c_dim_1 = 256
104 | c_dim_2 = 192
105 |
106 | f_dim_1 = 128
107 | f_dim_2 = 64
108 |
109 | cat_dim = 128
110 |
111 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
112 | self.convc1 = nn.Conv2d(cor_planes, c_dim_1, 1, padding=0)
113 | self.convc2 = nn.Conv2d(c_dim_1, c_dim_2, 3, padding=1)
114 | self.convf1 = nn.Conv2d(2, f_dim_1, 7, padding=3)
115 | self.convf2 = nn.Conv2d(f_dim_1, f_dim_2, 3, padding=1)
116 | self.conv = nn.Conv2d(c_dim_2+f_dim_2, cat_dim-2, 3, padding=1)
117 |
118 | def forward(self, flow, corr):
119 | cor = F.relu(self.convc1(corr))
120 | cor = F.relu(self.convc2(cor))
121 | flo = F.relu(self.convf1(flow))
122 | flo = F.relu(self.convf2(flo))
123 |
124 | cor_flo = torch.cat([cor, flo], dim=1)
125 | out = F.relu(self.conv(cor_flo))
126 | return torch.cat([out, flow], dim=1)
127 |
128 |
129 | class UpdateBlock(nn.Module):
130 | def __init__(self, args, hidden_dim=128, input_dim=128):
131 | super(UpdateBlock, self).__init__()
132 | self.args = args
133 |
134 | if args.tiny:
135 | cat_dim = 32
136 | elif args.large:
137 | cat_dim = 128 + 64
138 | elif args.huge:
139 | cat_dim = 128 + 128
140 | elif args.gigantic:
141 | cat_dim = 128 + 192
142 | else:
143 | cat_dim = 128
144 |
145 | if args.old_version:
146 | flow_head_dim = min(256, 2*cat_dim)
147 | else:
148 | flow_head_dim = 2*cat_dim
149 |
150 | self.encoder = MotionEncoder(args)
151 |
152 | if args.gma:
153 | self.gma = Aggregate(dim=cat_dim, dim_head=cat_dim, heads=1)
154 |
155 | gru_in_dim = 2 * cat_dim + hidden_dim
156 | else:
157 | self.gma = None
158 |
159 | gru_in_dim = cat_dim + hidden_dim
160 |
161 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=gru_in_dim)
162 | self.flow_head = FlowHead(hidden_dim, hidden_dim=flow_head_dim)
163 |
164 | def forward(self, net, inp, corr, flow, attn=None, upsample=True):
165 | motion_features = self.encoder(flow, corr)
166 |
167 | if self.gma:
168 | motion_features_global = self.gma(attn, motion_features)
169 | inp = torch.cat([inp, motion_features, motion_features_global], dim=1)
170 | else:
171 | inp = torch.cat([inp, motion_features], dim=1)
172 |
173 | net = self.gru(net, inp)
174 | delta_flow = self.flow_head(net)
175 |
176 | return net, delta_flow
177 |
178 |
179 |
180 |
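A minimal sketch (not part of the repository) instantiating UpdateBlock with a hand-built argument namespace. The flag values below are assumptions chosen to exercise the default (non-tiny/large/huge/gigantic, no GMA) branch; the channel counts follow the definitions above.

import argparse
import torch
from update import UpdateBlock  # assumes 'core' is on sys.path, as in main.py

args = argparse.Namespace(
    tiny=False, large=False, huge=False, gigantic=False,
    old_version=False, gma=False,
    corr_levels=4, corr_radius=4,   # cor_planes = 4 * (2*4 + 1)**2 = 324
)

block = UpdateBlock(args, hidden_dim=128, input_dim=128)

B, H, W = 1, 46, 62
net = torch.randn(B, 128, H, W)    # GRU hidden state
inp = torch.randn(B, 128, H, W)    # context features
corr = torch.randn(B, 324, H, W)   # sampled correlation volume
flow = torch.randn(B, 2, H, W)     # current flow estimate

net, delta_flow = block(net, inp, corr, flow)
print(net.shape, delta_flow.shape)  # (B, 128, H, W), (B, 2, H, W)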
--------------------------------------------------------------------------------
/code.v.2.0/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/locuslab/deq-flow/869d4bcd9c1d227cbfa16639994d1b118510a9bb/code.v.2.0/core/utils/__init__.py
--------------------------------------------------------------------------------
/code.v.2.0/core/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | # This code was originally taken from RAFT without modification
2 | # https://github.com/princeton-vl/RAFT/blob/master/utils/augmentor.py
3 |
4 | import numpy as np
5 | import random
6 | import math
7 | from PIL import Image
8 |
9 | import cv2
10 | cv2.setNumThreads(0)
11 | cv2.ocl.setUseOpenCL(False)
12 |
13 | import torch
14 | from torchvision.transforms import ColorJitter
15 | import torch.nn.functional as F
16 |
17 |
18 | class FlowAugmentor:
19 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
20 |
21 | # spatial augmentation params
22 | self.crop_size = crop_size
23 | self.min_scale = min_scale
24 | self.max_scale = max_scale
25 | self.spatial_aug_prob = 0.8
26 | self.stretch_prob = 0.8
27 | self.max_stretch = 0.2
28 |
29 | # flip augmentation params
30 | self.do_flip = do_flip
31 | self.h_flip_prob = 0.5
32 | self.v_flip_prob = 0.1
33 |
34 | # photometric augmentation params
35 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
36 | self.asymmetric_color_aug_prob = 0.2
37 | self.eraser_aug_prob = 0.5
38 |
39 | def color_transform(self, img1, img2):
40 | """ Photometric augmentation """
41 |
42 | # asymmetric
43 | if np.random.rand() < self.asymmetric_color_aug_prob:
44 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
45 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
46 |
47 | # symmetric
48 | else:
49 | image_stack = np.concatenate([img1, img2], axis=0)
50 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
51 | img1, img2 = np.split(image_stack, 2, axis=0)
52 |
53 | return img1, img2
54 |
55 | def eraser_transform(self, img1, img2, bounds=[50, 100]):
56 | """ Occlusion augmentation """
57 |
58 | ht, wd = img1.shape[:2]
59 | if np.random.rand() < self.eraser_aug_prob:
60 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
61 | for _ in range(np.random.randint(1, 3)):
62 | x0 = np.random.randint(0, wd)
63 | y0 = np.random.randint(0, ht)
64 | dx = np.random.randint(bounds[0], bounds[1])
65 | dy = np.random.randint(bounds[0], bounds[1])
66 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
67 |
68 | return img1, img2
69 |
70 | def spatial_transform(self, img1, img2, flow):
71 | # randomly sample scale
72 | ht, wd = img1.shape[:2]
73 | min_scale = np.maximum(
74 | (self.crop_size[0] + 8) / float(ht),
75 | (self.crop_size[1] + 8) / float(wd))
76 |
77 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
78 | scale_x = scale
79 | scale_y = scale
80 | if np.random.rand() < self.stretch_prob:
81 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
82 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
83 |
84 | scale_x = np.clip(scale_x, min_scale, None)
85 | scale_y = np.clip(scale_y, min_scale, None)
86 |
87 | if np.random.rand() < self.spatial_aug_prob:
88 | # rescale the images
89 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
90 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
91 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
92 | flow = flow * [scale_x, scale_y]
93 |
94 | if self.do_flip:
95 | if np.random.rand() < self.h_flip_prob: # h-flip
96 | img1 = img1[:, ::-1]
97 | img2 = img2[:, ::-1]
98 | flow = flow[:, ::-1] * [-1.0, 1.0]
99 |
100 | if np.random.rand() < self.v_flip_prob: # v-flip
101 | img1 = img1[::-1, :]
102 | img2 = img2[::-1, :]
103 | flow = flow[::-1, :] * [1.0, -1.0]
104 |
105 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
106 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
107 |
108 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
109 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
110 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
111 |
112 | return img1, img2, flow
113 |
114 | def __call__(self, img1, img2, flow):
115 | img1, img2 = self.color_transform(img1, img2)
116 | img1, img2 = self.eraser_transform(img1, img2)
117 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
118 |
119 | img1 = np.ascontiguousarray(img1)
120 | img2 = np.ascontiguousarray(img2)
121 | flow = np.ascontiguousarray(flow)
122 |
123 | return img1, img2, flow
124 |
125 | class SparseFlowAugmentor:
126 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
127 | # spatial augmentation params
128 | self.crop_size = crop_size
129 | self.min_scale = min_scale
130 | self.max_scale = max_scale
131 | self.spatial_aug_prob = 0.8
132 | self.stretch_prob = 0.8
133 | self.max_stretch = 0.2
134 |
135 | # flip augmentation params
136 | self.do_flip = do_flip
137 | self.h_flip_prob = 0.5
138 | self.v_flip_prob = 0.1
139 |
140 | # photometric augmentation params
141 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
142 | self.asymmetric_color_aug_prob = 0.2
143 | self.eraser_aug_prob = 0.5
144 |
145 | def color_transform(self, img1, img2):
146 | image_stack = np.concatenate([img1, img2], axis=0)
147 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
148 | img1, img2 = np.split(image_stack, 2, axis=0)
149 | return img1, img2
150 |
151 | def eraser_transform(self, img1, img2):
152 | ht, wd = img1.shape[:2]
153 | if np.random.rand() < self.eraser_aug_prob:
154 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
155 | for _ in range(np.random.randint(1, 3)):
156 | x0 = np.random.randint(0, wd)
157 | y0 = np.random.randint(0, ht)
158 | dx = np.random.randint(50, 100)
159 | dy = np.random.randint(50, 100)
160 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
161 |
162 | return img1, img2
163 |
164 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
165 | ht, wd = flow.shape[:2]
166 | coords = np.meshgrid(np.arange(wd), np.arange(ht))
167 | coords = np.stack(coords, axis=-1)
168 |
169 | coords = coords.reshape(-1, 2).astype(np.float32)
170 | flow = flow.reshape(-1, 2).astype(np.float32)
171 | valid = valid.reshape(-1).astype(np.float32)
172 |
173 | coords0 = coords[valid>=1]
174 | flow0 = flow[valid>=1]
175 |
176 | ht1 = int(round(ht * fy))
177 | wd1 = int(round(wd * fx))
178 |
179 | coords1 = coords0 * [fx, fy]
180 | flow1 = flow0 * [fx, fy]
181 |
182 | xx = np.round(coords1[:,0]).astype(np.int32)
183 | yy = np.round(coords1[:,1]).astype(np.int32)
184 |
185 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
186 | xx = xx[v]
187 | yy = yy[v]
188 | flow1 = flow1[v]
189 |
190 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
191 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
192 |
193 | flow_img[yy, xx] = flow1
194 | valid_img[yy, xx] = 1
195 |
196 | return flow_img, valid_img
197 |
198 | def spatial_transform(self, img1, img2, flow, valid):
199 | # randomly sample scale
200 |
201 | ht, wd = img1.shape[:2]
202 | min_scale = np.maximum(
203 | (self.crop_size[0] + 1) / float(ht),
204 | (self.crop_size[1] + 1) / float(wd))
205 |
206 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
207 | scale_x = np.clip(scale, min_scale, None)
208 | scale_y = np.clip(scale, min_scale, None)
209 |
210 | if np.random.rand() < self.spatial_aug_prob:
211 | # rescale the images
212 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
213 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
214 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
215 |
216 | if self.do_flip:
217 | if np.random.rand() < 0.5: # h-flip
218 | img1 = img1[:, ::-1]
219 | img2 = img2[:, ::-1]
220 | flow = flow[:, ::-1] * [-1.0, 1.0]
221 | valid = valid[:, ::-1]
222 |
223 | margin_y = 20
224 | margin_x = 50
225 |
226 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
227 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
228 |
229 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
230 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
231 |
232 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
233 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
234 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
235 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
236 | return img1, img2, flow, valid
237 |
238 |
239 | def __call__(self, img1, img2, flow, valid):
240 | img1, img2 = self.color_transform(img1, img2)
241 | img1, img2 = self.eraser_transform(img1, img2)
242 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
243 |
244 | img1 = np.ascontiguousarray(img1)
245 | img2 = np.ascontiguousarray(img2)
246 | flow = np.ascontiguousarray(flow)
247 | valid = np.ascontiguousarray(valid)
248 |
249 | return img1, img2, flow, valid
250 |
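A minimal sketch (not part of the repository) applying FlowAugmentor to a synthetic image pair; real training data is wired up in core/datasets.py. The crop size matches the FlyingChairs setting used in train_H.sh.

import numpy as np
from utils.augmentor import FlowAugmentor  # assumes 'core' is on sys.path

aug = FlowAugmentor(crop_size=(368, 496))

H, W = 400, 720
img1 = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)
img2 = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)
flow = np.random.randn(H, W, 2).astype(np.float32)

img1, img2, flow = aug(img1, img2, flow)
print(img1.shape, img2.shape, flow.shape)  # all cropped to (368, 496)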
--------------------------------------------------------------------------------
/code.v.2.0/core/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 |     Expects a two-dimensional flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
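A minimal sketch (not part of the repository) rendering a synthetic radial flow field with flow_to_image; the result is an (H, W, 3) uint8 RGB visualization.

import numpy as np
from utils.flow_viz import flow_to_image  # assumes 'core' is on sys.path

H, W = 100, 150
y, x = np.mgrid[0:H, 0:W]
flow = np.stack([x - W / 2, y - H / 2], axis=-1).astype(np.float32)  # flow pointing away from center

img = flow_to_image(flow)
print(img.shape, img.dtype)  # (100, 150, 3) uint8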
--------------------------------------------------------------------------------
/code.v.2.0/core/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | # This code was originally taken from RAFT without modification
2 | # https://github.com/princeton-vl/RAFT/blob/master/utils/frame_utils.py
3 |
4 | import numpy as np
5 | from PIL import Image
6 | from os.path import *
7 | import re
8 |
9 | import cv2
10 | cv2.setNumThreads(0)
11 | cv2.ocl.setUseOpenCL(False)
12 |
13 | TAG_CHAR = np.array([202021.25], np.float32)
14 |
15 | def readFlow(fn):
16 | """ Read .flo file in Middlebury format"""
17 | # Code adapted from:
18 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
19 |
20 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
21 | # print 'fn = %s'%(fn)
22 | with open(fn, 'rb') as f:
23 | magic = np.fromfile(f, np.float32, count=1)
24 | if 202021.25 != magic:
25 | print('Magic number incorrect. Invalid .flo file')
26 | return None
27 | else:
28 | w = np.fromfile(f, np.int32, count=1)
29 | h = np.fromfile(f, np.int32, count=1)
30 | # print 'Reading %d x %d flo file\n' % (w, h)
31 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
32 | # Reshape data into 3D array (columns, rows, bands)
33 | # The reshape here is for visualization, the original code is (w,h,2)
34 | return np.resize(data, (int(h), int(w), 2))
35 |
36 | def readPFM(file):
37 | file = open(file, 'rb')
38 |
39 | color = None
40 | width = None
41 | height = None
42 | scale = None
43 | endian = None
44 |
45 | header = file.readline().rstrip()
46 | if header == b'PF':
47 | color = True
48 | elif header == b'Pf':
49 | color = False
50 | else:
51 | raise Exception('Not a PFM file.')
52 |
53 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
54 | if dim_match:
55 | width, height = map(int, dim_match.groups())
56 | else:
57 | raise Exception('Malformed PFM header.')
58 |
59 | scale = float(file.readline().rstrip())
60 | if scale < 0: # little-endian
61 | endian = '<'
62 | scale = -scale
63 | else:
64 | endian = '>' # big-endian
65 |
66 | data = np.fromfile(file, endian + 'f')
67 | shape = (height, width, 3) if color else (height, width)
68 |
69 | data = np.reshape(data, shape)
70 | data = np.flipud(data)
71 | return data
72 |
73 | def writeFlow(filename,uv,v=None):
74 | """ Write optical flow to file.
75 |
76 | If v is None, uv is assumed to contain both u and v channels,
77 | stacked in depth.
78 | Original code by Deqing Sun, adapted from Daniel Scharstein.
79 | """
80 | nBands = 2
81 |
82 | if v is None:
83 | assert(uv.ndim == 3)
84 | assert(uv.shape[2] == 2)
85 | u = uv[:,:,0]
86 | v = uv[:,:,1]
87 | else:
88 | u = uv
89 |
90 | assert(u.shape == v.shape)
91 | height,width = u.shape
92 | f = open(filename,'wb')
93 | # write the header
94 | f.write(TAG_CHAR)
95 | np.array(width).astype(np.int32).tofile(f)
96 | np.array(height).astype(np.int32).tofile(f)
97 | # arrange into matrix form
98 | tmp = np.zeros((height, width*nBands))
99 | tmp[:,np.arange(width)*2] = u
100 | tmp[:,np.arange(width)*2 + 1] = v
101 | tmp.astype(np.float32).tofile(f)
102 | f.close()
103 |
104 |
105 | def readFlowKITTI(filename):
106 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
107 | flow = flow[:,:,::-1].astype(np.float32)
108 | flow, valid = flow[:, :, :2], flow[:, :, 2]
109 | flow = (flow - 2**15) / 64.0
110 | return flow, valid
111 |
112 | def readDispKITTI(filename):
113 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
114 | valid = disp > 0.0
115 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
116 | return flow, valid
117 |
118 |
119 | def writeFlowKITTI(filename, uv):
120 | uv = 64.0 * uv + 2**15
121 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
122 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
123 | cv2.imwrite(filename, uv[..., ::-1])
124 |
125 |
126 | def read_gen(file_name, pil=False):
127 | ext = splitext(file_name)[-1]
128 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
129 | return Image.open(file_name)
130 | elif ext == '.bin' or ext == '.raw':
131 | return np.load(file_name)
132 | elif ext == '.flo':
133 | return readFlow(file_name).astype(np.float32)
134 | elif ext == '.pfm':
135 | flow = readPFM(file_name).astype(np.float32)
136 | if len(flow.shape) == 2:
137 | return flow
138 | else:
139 | return flow[:, :, :-1]
140 | return []
141 |
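A minimal sketch (not part of the repository): write a random flow field in the Middlebury .flo format and read it back (little-endian machines only, per the warning in readFlow). The temporary filename is an arbitrary choice.

import numpy as np
from utils.frame_utils import writeFlow, readFlow  # assumes 'core' is on sys.path

flow = np.random.randn(120, 160, 2).astype(np.float32)
writeFlow('tmp_flow.flo', flow)

flow_back = readFlow('tmp_flow.flo')
print(np.allclose(flow, flow_back))  # True: exact float32 round-trip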
--------------------------------------------------------------------------------
/code.v.2.0/core/utils/grid_sample.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def grid_sample(image, optical):
6 | N, C, IH, IW = image.shape
7 | _, H, W, _ = optical.shape
8 |
9 | ix = optical[..., 0]
10 | iy = optical[..., 1]
11 |
12 | ix = ((ix + 1) / 2) * (IW-1)
13 | iy = ((iy + 1) / 2) * (IH-1)
14 | with torch.no_grad():
15 | ix_nw = torch.floor(ix)
16 | iy_nw = torch.floor(iy)
17 | ix_ne = ix_nw + 1
18 | iy_ne = iy_nw
19 | ix_sw = ix_nw
20 | iy_sw = iy_nw + 1
21 | ix_se = ix_nw + 1
22 | iy_se = iy_nw + 1
23 |
24 | nw = (ix_se - ix) * (iy_se - iy)
25 | ne = (ix - ix_sw) * (iy_sw - iy)
26 | sw = (ix_ne - ix) * (iy - iy_ne)
27 | se = (ix - ix_nw) * (iy - iy_nw)
28 |
29 | with torch.no_grad():
30 | torch.clamp(ix_nw, 0, IW-1, out=ix_nw)
31 | torch.clamp(iy_nw, 0, IH-1, out=iy_nw)
32 |
33 | torch.clamp(ix_ne, 0, IW-1, out=ix_ne)
34 | torch.clamp(iy_ne, 0, IH-1, out=iy_ne)
35 |
36 | torch.clamp(ix_sw, 0, IW-1, out=ix_sw)
37 | torch.clamp(iy_sw, 0, IH-1, out=iy_sw)
38 |
39 | torch.clamp(ix_se, 0, IW-1, out=ix_se)
40 | torch.clamp(iy_se, 0, IH-1, out=iy_se)
41 |
42 | image = image.view(N, C, IH * IW)
43 |
44 |
45 | nw_val = torch.gather(image, 2, (iy_nw * IW + ix_nw).long().view(N, 1, H * W).repeat(1, C, 1))
46 | ne_val = torch.gather(image, 2, (iy_ne * IW + ix_ne).long().view(N, 1, H * W).repeat(1, C, 1))
47 | sw_val = torch.gather(image, 2, (iy_sw * IW + ix_sw).long().view(N, 1, H * W).repeat(1, C, 1))
48 | se_val = torch.gather(image, 2, (iy_se * IW + ix_se).long().view(N, 1, H * W).repeat(1, C, 1))
49 |
50 | out_val = (nw_val.view(N, C, H, W) * nw.view(N, 1, H, W) +
51 | ne_val.view(N, C, H, W) * ne.view(N, 1, H, W) +
52 | sw_val.view(N, C, H, W) * sw.view(N, 1, H, W) +
53 | se_val.view(N, C, H, W) * se.view(N, 1, H, W))
54 |
55 | return out_val
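A minimal sketch (not part of the repository): for sampling locations that stay inside the image, this pure-PyTorch grid_sample matches F.grid_sample with align_corners=True while remaining twice-differentiable, which is why utils.py keeps it as an alternative for higher-order gradients.

import torch
import torch.nn.functional as F
from utils.grid_sample import grid_sample  # assumes 'core' is on sys.path

image = torch.randn(1, 3, 32, 48)
grid = torch.rand(1, 20, 30, 2) * 1.8 - 0.9   # normalized coords kept inside [-0.9, 0.9]

ours = grid_sample(image, grid)
ref = F.grid_sample(image, grid, mode='bilinear', align_corners=True)
print(torch.allclose(ours, ref, atol=1e-5))   # True for in-range coordinates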
--------------------------------------------------------------------------------
/code.v.2.0/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 | from .grid_sample import grid_sample
7 |
8 |
9 | class InputPadder:
10 | """ Pads images such that dimensions are divisible by 8 """
11 | def __init__(self, dims, mode='sintel'):
12 | self.ht, self.wd = dims[-2:]
13 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
14 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
15 | if mode == 'sintel':
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
17 | else:
18 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
19 |
20 | def pad(self, *inputs):
21 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
22 |
23 | def unpad(self,x):
24 | ht, wd = x.shape[-2:]
25 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
26 | return x[..., c[0]:c[1], c[2]:c[3]]
27 |
28 |
29 | def forward_interpolate(flow):
30 | flow = flow.detach().cpu().numpy()
31 | dx, dy = flow[0], flow[1]
32 |
33 | ht, wd = dx.shape
34 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
35 |
36 | x1 = x0 + dx
37 | y1 = y0 + dy
38 |
39 | x1 = x1.reshape(-1)
40 | y1 = y1.reshape(-1)
41 | dx = dx.reshape(-1)
42 | dy = dy.reshape(-1)
43 |
44 | # valid = (x1 > 0.1 * wd) & (x1 < 0.9 * wd) & (y1 > 0.1 * ht) & (y1 < 0.9 * ht)
45 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
46 | x1 = x1[valid]
47 | y1 = y1[valid]
48 | dx = dx[valid]
49 | dy = dy[valid]
50 |
51 | flow_x = interpolate.griddata(
52 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
53 |
54 | flow_y = interpolate.griddata(
55 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
56 |
57 | flow = np.stack([flow_x, flow_y], axis=0)
58 | return torch.from_numpy(flow).float()
59 |
60 |
61 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
62 | """ Wrapper for grid_sample, uses pixel coordinates """
63 | H, W = img.shape[-2:]
64 | xgrid, ygrid = coords.split([1,1], dim=-1)
65 | xgrid = 2*xgrid/(W-1) - 1
66 | ygrid = 2*ygrid/(H-1) - 1
67 |
68 | grid = torch.cat([xgrid, ygrid], dim=-1)
69 | img = F.grid_sample(img, grid, align_corners=True)
70 |
71 | # Enable higher order grad for JR
72 | # img = grid_sample(img, grid)
73 |
74 | if mask:
75 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
76 | return img, mask.float()
77 |
78 | return img
79 |
80 |
81 | def coords_grid(batch, ht, wd, device):
82 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
83 | coords = torch.stack(coords[::-1], dim=0).float()
84 | return coords[None].repeat(batch, 1, 1, 1)
85 |
86 |
87 | def upflow8(flow, mode='bilinear'):
88 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
89 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
90 |
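A minimal sketch (not part of the repository): pad a Sintel-sized image pair to dimensions divisible by 8, then recover the original size with unpad.

import torch
from utils.utils import InputPadder  # assumes 'core' is on sys.path

image1 = torch.randn(1, 3, 436, 1024)
image2 = torch.randn(1, 3, 436, 1024)

padder = InputPadder(image1.shape)       # mode='sintel' pads height symmetrically
image1_p, image2_p = padder.pad(image1, image2)
print(image1_p.shape)                    # (1, 3, 440, 1024)

restored = padder.unpad(image1_p)
print(restored.shape == image1.shape)    # True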
--------------------------------------------------------------------------------
/code.v.2.0/log/val.txt:
--------------------------------------------------------------------------------
1 | Parameter Count: 13.395 M
2 | Load from checkpoints/deq-flow-H-things-test-1x.pth
3 | Validation KITTI: EPE: 3.763 (3.763), F1: 12.95 (12.95)
4 | VALID | FORWARD | rel: 0.0005664261989295483; abs: 3.7999179363250732; nstep: tensor([60], device='cuda:0')
5 | VALID | FORWARD | rel: 0.0002732036809902638; abs: 1.8421716690063477; nstep: tensor([60], device='cuda:0')
6 | Validation (clean) EPE: 1.266 (1.266), 1px: 91.50, 3px: 96.20, 5px: 97.24
7 | VALID | FORWARD | rel: 0.0008356361649930477; abs: 5.66709566116333; nstep: tensor([60], device='cuda:0')
8 | Validation (final) EPE: 2.584 (2.584), 1px: 86.48, 3px: 92.51, 5px: 94.21
9 | Parameter Count: 13.395 M
10 | Load from checkpoints/deq-flow-H-things-test-3x.pth
11 | Validation KITTI: EPE: 3.863 (3.863), F1: 13.52 (13.52)
12 | VALID | FORWARD | rel: 0.0006333217606879771; abs: 4.249484539031982; nstep: tensor([60], device='cuda:0')
13 | VALID | FORWARD | rel: 0.00028998314519412816; abs: 1.9555532932281494; nstep: tensor([60], device='cuda:0')
14 | Validation (clean) EPE: 1.270 (1.270), 1px: 91.87, 3px: 96.38, 5px: 97.38
15 | VALID | FORWARD | rel: 0.0012048647040501237; abs: 8.165122032165527; nstep: tensor([60], device='cuda:0')
16 | Validation (final) EPE: 2.500 (2.500), 1px: 86.74, 3px: 92.69, 5px: 94.43
17 | Parameter Count: 13.395 M
18 | Load from checkpoints/deq-flow-H-things-test-3x.pth
19 | Validation KITTI: EPE: 3.768 (3.768), F1: 13.41 (13.41)
20 | VALID | FORWARD | rel: 0.0006324453861452639; abs: 4.2437052726745605; nstep: tensor([120], device='cuda:0')
21 | VALID | FORWARD | rel: 0.0002864457492250949; abs: 1.931783676147461; nstep: tensor([120], device='cuda:0')
22 | Validation (clean) EPE: 1.275 (1.275), 1px: 91.86, 3px: 96.38, 5px: 97.39
23 | VALID | FORWARD | rel: 0.001575868227519095; abs: 10.679956436157227; nstep: tensor([120], device='cuda:0')
24 | Validation (final) EPE: 2.481 (2.481), 1px: 86.77, 3px: 92.74, 5px: 94.49
25 |
--------------------------------------------------------------------------------
/code.v.2.0/train_H.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -u main.py --total_run 1 --start_run 1 --name deq-flow-H-naive-120k-C-36-6-1 \
4 | --stage chairs --validation chairs kitti \
5 | --gpus 0 1 2 --num_steps 120000 --eval_interval 20000 \
6 | --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 \
7 | --f_thres 36 --f_solver naive_solver \
8 | --n_losses 6 --phantom_grad 1 \
9 | --huge --wnorm
10 |
11 | python -u main.py --total_run 1 --start_run 1 --name deq-flow-H-naive-120k-T-40-2-3 \
12 | --stage things --validation sintel kitti \
13 | --restore_name deq-flow-H-naive-120k-C-36-6-1 \
14 | --gpus 0 1 2 --num_steps 120000 \
15 | --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 \
16 | --f_thres 40 --f_solver naive_solver \
17 | --n_losses 2 --phantom_grad 3 \
18 | --huge --wnorm --all_grad
19 |
20 |
--------------------------------------------------------------------------------
/code.v.2.0/train_H_1_step_grad.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -u main.py --total_run 1 --start_run 1 --name deq-flow-H-naive-120k-C-36-1-1 \
4 | --stage chairs --validation chairs kitti \
5 | --gpus 0 1 2 --num_steps 120000 --eval_interval 20000 \
6 | --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 \
7 | --f_thres 36 --f_solver naive_solver \
8 | --n_losses 1 --phantom_grad 1 \
9 | --huge --wnorm
10 |
11 | python -u main.py --total_run 1 --start_run 1 --name deq-flow-H-naive-120k-T-40-1-1 \
12 | --stage things --validation sintel kitti --restore_name deq-flow-H-naive-120k-C-36-1-1 \
13 | --gpus 0 1 2 --num_steps 120000 \
14 | --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 \
15 | --f_thres 40 --f_solver naive_solver \
16 | --n_losses 1 --phantom_grad 1 \
17 | --huge --wnorm --all_grad
18 |
19 |
--------------------------------------------------------------------------------
/code.v.2.0/val.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -u main.py --eval --name deq-flow-H-all-grad --stage things \
4 | --validation kitti sintel --restore_ckpt checkpoints/deq-flow-H-things-test-1x.pth --gpus 0 \
5 | --wnorm --f_thres 40 --f_solver naive_solver \
6 | --eval_factor 1.5 --huge
7 |
8 | python -u main.py --eval --name deq-flow-H-all-grad --stage things \
9 | --validation kitti sintel --restore_ckpt checkpoints/deq-flow-H-things-test-3x.pth --gpus 0 \
10 | --wnorm --f_thres 40 --f_solver naive_solver \
11 | --eval_factor 1.5 --huge
12 |
13 | python -u main.py --eval --name deq-flow-H-all-grad --stage things \
14 | --validation kitti sintel --restore_ckpt checkpoints/deq-flow-H-things-test-3x.pth --gpus 0 \
15 | --wnorm --f_thres 40 --f_solver naive_solver \
16 | --eval_factor 3.0 --huge
17 |
18 |
19 |
--------------------------------------------------------------------------------
/code.v.2.0/viz.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append('core')
4 |
5 | import argparse
6 | import copy
7 | import os
8 | import time
9 |
10 | import datasets
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | import cv2
17 | from PIL import Image
18 |
19 | from utils import flow_viz, frame_utils
20 | from utils.utils import InputPadder, forward_interpolate
21 |
22 |
23 | @torch.no_grad()
24 | def sintel_visualization(model, split='test', warm_start=False, fixed_point_reuse=False, output_path='sintel_viz', **kwargs):
25 | """ Create visualization for the Sintel dataset """
26 | model.eval()
27 | for dstype in ['clean', 'final']:
28 | split = 'test' if split == 'test' else 'training'
29 | test_dataset = datasets.MpiSintel(split=split, aug_params=None, dstype=dstype)
30 |
31 | flow_prev, sequence_prev, fixed_point = None, None, None
32 | for test_id in range(len(test_dataset)):
33 | image1, image2, (sequence, frame) = test_dataset[test_id]
34 | if sequence != sequence_prev:
35 | flow_prev = None
36 | fixed_point = None
37 |
38 | padder = InputPadder(image1.shape)
39 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
40 |
41 | flow_low, flow_pr, info = model(image1, image2, flow_init=flow_prev, cached_result=fixed_point, **kwargs)
42 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
43 |
44 | if warm_start:
45 | flow_prev = forward_interpolate(flow_low[0])[None].cuda()
46 |
47 | if fixed_point_reuse:
48 | net, flow_pred_low = info['cached_result']
49 | flow_pred_low = forward_interpolate(flow_pred_low[0])[None].cuda()
50 | fixed_point = (net, flow_pred_low)
51 |
52 | output_dir = os.path.join(output_path, dstype, sequence)
53 | output_file = os.path.join(output_dir, 'frame%04d.png' % (frame+1))
54 |
55 | if not os.path.exists(output_dir):
56 | os.makedirs(output_dir)
57 |
58 |             # visualization
59 | img_flow = flow_viz.flow_to_image(flow)
60 | img_flow = cv2.cvtColor(img_flow, cv2.COLOR_RGB2BGR)
61 | cv2.imwrite(output_file, img_flow, [int(cv2.IMWRITE_PNG_COMPRESSION), 1])
62 |
63 | sequence_prev = sequence
64 |
65 |
66 | @torch.no_grad()
67 | def kitti_visualization(model, split='test', output_path='kitti_viz'):
68 | """ Create visualization for the KITTI dataset """
69 | model.eval()
70 | split = 'testing' if split == 'test' else 'training'
71 |     test_dataset = datasets.KITTI(split=split, aug_params=None)
72 |
73 | if not os.path.exists(output_path):
74 | os.makedirs(output_path)
75 |
76 | for test_id in range(len(test_dataset)):
77 | image1, image2, (frame_id, ) = test_dataset[test_id]
78 | padder = InputPadder(image1.shape, mode='kitti')
79 | image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
80 |
81 | _, flow_pr, _ = model(image1, image2)
82 | flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
83 |
84 | output_filename = os.path.join(output_path, frame_id)
85 |
86 |         # visualization
87 | img_flow = flow_viz.flow_to_image(flow)
88 | img_flow = cv2.cvtColor(img_flow, cv2.COLOR_RGB2BGR)
89 | cv2.imwrite(output_filename, img_flow, [int(cv2.IMWRITE_PNG_COMPRESSION), 1])
90 |
91 |
92 |
--------------------------------------------------------------------------------