├── .gitignore ├── DATASETS.md ├── LICENSE ├── MODEL_ZOO.md ├── README.md ├── conda_environment.yml ├── dataloader ├── __init__.py ├── depth │ ├── __init__.py │ ├── augmentation.py │ ├── datasets.py │ ├── download_demon_test.sh │ ├── download_demon_train.sh │ ├── prepare_demon_test.py │ ├── prepare_demon_train.py │ ├── scannet_banet_test_pairs.txt │ └── scannet_banet_train_pairs.txt ├── flow │ ├── __init__.py │ ├── chairs_split.txt │ ├── datasets.py │ └── transforms.py └── stereo │ ├── __init__.py │ ├── datasets.py │ └── transforms.py ├── demo ├── depth-scannet │ ├── color │ │ ├── 0048.png │ │ ├── 0054.png │ │ ├── 0060.png │ │ └── 0066.png │ ├── intrinsic │ │ └── intrinsic_depth.txt │ └── pose │ │ ├── 0048.txt │ │ ├── 0054.txt │ │ ├── 0060.txt │ │ └── 0066.txt ├── flow-davis │ ├── 00000.jpg │ ├── 00001.jpg │ └── 00002.jpg ├── kitti.mp4 └── stereo-middlebury │ ├── im0.png │ └── im1.png ├── evaluate_depth.py ├── evaluate_flow.py ├── evaluate_stereo.py ├── loss ├── __init__.py ├── depth_loss.py ├── flow_loss.py └── stereo_metric.py ├── main_depth.py ├── main_flow.py ├── main_stereo.py ├── pip_install.sh ├── scripts ├── depthsplat_depth_demo.sh ├── gmdepth_demo.sh ├── gmdepth_evaluate.sh ├── gmdepth_scale1_regrefine1_train.sh ├── gmdepth_scale1_train.sh ├── gmflow_demo.sh ├── gmflow_evaluate.sh ├── gmflow_scale1_train.sh ├── gmflow_scale2_regrefine6_train.sh ├── gmflow_scale2_train.sh ├── gmflow_submission.sh ├── gmstereo_demo.sh ├── gmstereo_evaluate.sh ├── gmstereo_scale1_train.sh ├── gmstereo_scale2_regrefine3_train.sh ├── gmstereo_scale2_train.sh └── gmstereo_submission.sh ├── unimatch ├── __init__.py ├── attention.py ├── backbone.py ├── dpt_head.py ├── geometry.py ├── ldm_unet │ ├── __init__.py │ ├── attention.py │ ├── cross_attention.py │ ├── unet.py │ └── util.py ├── matching.py ├── position.py ├── reg_refine.py ├── transformer.py ├── trident_conv.py ├── unimatch.py ├── unimatch_depthsplat.py ├── utils.py └── vit_fpn.py └── utils ├── dist_utils.py ├── file_io.py ├── flow_viz.py ├── frame_utils.py ├── logger.py ├── misc.py ├── utils.py └── visualization.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | output/ 165 | pretrained/ -------------------------------------------------------------------------------- /DATASETS.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | 4 | 5 | ## Optical Flow 6 | 7 | The datasets used to train and evaluate our GMFlow model are as follows: 8 | 9 | - [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) 10 | - [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 11 | - [Sintel](http://sintel.is.tue.mpg.de/) 12 | - [Virtual KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/) 13 | - [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) 14 | - [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) 15 | 16 | By default the dataloader [dataloader/flow/datasets.py](dataloader/flow/datasets.py) assumes the datasets are located in the `datasets` directory. 17 | 18 | It is recommended to symlink your dataset root to `datasets`: 19 | 20 | ``` 21 | ln -s $YOUR_DATASET_ROOT datasets 22 | ``` 23 | 24 | Otherwise, you may need to change the corresponding paths in [dataloader/flow/datasets.py](dataloader/flow/datasets.py). 25 | 26 | 27 | 28 | ## Stereo Matching 29 | 30 | The datasets used to train and evaluate our GMStereo model are as follows: 31 | 32 | - [Scene Flow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 33 | - [Virtual KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/) 34 | - [KITTI](https://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo) 35 | - [TartanAir](https://github.com/castacks/tartanair_tools) 36 | - [Falling Things](https://research.nvidia.com/publication/2018-06_Falling-Things) 37 | - [HR-VS](https://drive.google.com/file/d/1SgEIrH_IQTKJOToUwR1rx4-237sThUqX/view) 38 | - [CREStereo Dataset](https://github.com/megvii-research/CREStereo/blob/master/dataset_download.sh) 39 | - [InStereo2K](https://github.com/YuhuaXu/StereoDataset) 40 | - [Middlebury](https://vision.middlebury.edu/stereo/data/) 41 | - [Sintel Stereo](http://sintel.is.tue.mpg.de/stereo) 42 | - [ETH3D](https://www.eth3d.net/datasets#low-res-two-view-training-data) 43 | 44 | By default the dataloader [dataloader/stereo/datasets.py](dataloader/stereo/datasets.py) assumes the datasets are located in the `datasets` directory. 45 | 46 | It is recommended to symlink your dataset root to `datasets`: 47 | 48 | ``` 49 | ln -s $YOUR_DATASET_ROOT datasets 50 | ``` 51 | 52 | Otherwise, you may need to change the corresponding paths in [dataloader/stereo/datasets.py](dataloader/stereo/datasets.py). 53 | 54 | 55 | 56 | ## Depth Estimation 57 | 58 | The datasets used to train and evaluate our GMDepth model are as follows: 59 | 60 | - [DeMoN](https://github.com/lmb-freiburg/demon) 61 | - [ScanNet](http://www.scan-net.org/) 62 | 63 | We support downloading and extracting the DeMoN dataset in our code: [dataloader/depth/download_demon_train.sh](dataloader/depth/download_demon_train.sh), [dataloader/depth/download_demon_test.sh](dataloader/depth/download_demon_test.sh), [dataloader/depth/prepare_demon_train.py](dataloader/depth/prepare_demon_train.py) and [dataloader/depth/prepare_demon_test.py](dataloader/depth/prepare_demon_test.py).
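For reference, these scripts can be run as follows (a rough sketch; it assumes you run them from `dataloader/depth`, since the download scripts write the `traindata`/`testdata` folders relative to the current directory and the prepare scripts look for them next to the script files):

```
cd dataloader/depth
bash download_demon_train.sh    # asks for confirmation, then downloads and extracts the *_train.h5 files
bash download_demon_test.sh     # same for the *_test.h5 files
python prepare_demon_train.py   # converts the h5 files into per-sequence folders (images, depth .npy, cam.txt, poses.txt)
python prepare_demon_test.py
```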
64 | 65 | By default the dataloader [dataloader/depth/datasets.py](dataloader/depth/datasets.py) assumes the datasets are located in the `datasets` directory. 66 | 67 | It is recommended to symlink your dataset root to `datasets`: 68 | 69 | ``` 70 | ln -s $YOUR_DATASET_ROOT datasets 71 | ``` 72 | 73 | Otherwise, you may need to change the corresponding paths in [dataloader/depth/datasets.py](dataloader/depth/datasets.py). 74 | 75 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 autonomousvision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MODEL_ZOO.md: -------------------------------------------------------------------------------- 1 | # Model Zoo 2 | 3 | - The models are named as `model-dataset`. 4 | - Model definition: `scale1` denotes the 1/8 feature resolution model, `scale2` denotes the 1/8 & 1/4 model, `scaleX-regrefineY` denotes the `X`-scale model with additional `Y` local regression refinements. 5 | - The inference time is averaged over 100 runs, measured with batch size 1 on a single NVIDIA A100 GPU. 6 | - All pretrained models can be downloaded together at [pretrained.zip](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained.zip), or they can be downloaded individually as listed below. 7 | 8 | 9 | 10 | ## Optical Flow 11 | 12 | - The inference time is measured for Sintel resolution: 448x1024 13 | 14 | - The `*-mixdata` models are trained on several mixed public datasets, which are recommended for in-the-wild use cases. 
15 | 16 | 17 | 18 | | Model | Params (M) | Time (ms) | Download | 19 | | --------------------------------- | :--------: | :-------: | :----------------------------------------------------------: | 20 | | GMFlow-scale1-things | 4.7 | 26 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmflow-scale1-things-e9887eda.pth) | 21 | | GMFlow-scale1-mixdata | 4.7 | 26 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmflow-scale1-mixdata-train320x576-4c3a6e9a.pth) | 22 | | GMFlow-scale2-things | 4.7 | 66 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmflow-scale2-things-36579974.pth) | 23 | | GMFlow-scale2-sintel | 4.7 | 66 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmflow-scale2-sintel-3ed1cf48.pth) | 24 | | GMFlow-scale2-mixdata | 4.7 | 66 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmflow-scale2-mixdata-train320x576-9ff1c094.pth) | 25 | | GMFlow-scale2-regrefine6-things | 7.4 | 122 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmflow-scale2-regrefine6-things-776ed612.pth) | 26 | | GMFlow-scale2-regrefine6-sintelft | 7.4 | 122 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmflow-scale2-regrefine6-sintelft-6e39e2b9.pth) | 27 | | GMFlow-scale2-regrefine6-kitti | 7.4 | 122 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmflow-scale2-regrefine6-kitti15-25b554d7.pth) | 28 | | GMFlow-scale2-regrefine6-mixdata | 7.4 | 122 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth) | 29 | 30 | 31 | 32 | ## Stereo Matching 33 | 34 | - The inference time is measured for KITTI resolution: 384x1248 35 | - The `*-resumeflowthings-*` denotes that the models are trained with GMFlow model as initialization, where GMFlow is trained on Chairs and Things dataset for optical flow task. 36 | - The `*-mixdata` models are trained on several mixed public datasets, which are recommended for in-the-wild use cases. 
37 | 38 | | Model | Params (M) | Time (ms) | Download | 39 | | ------------------------------------------------------ | :--------: | :-------: | :--------: | 40 | | GMStereo-scale1-sceneflow | 4.7 | 23 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmstereo-scale1-sceneflow-124a438f.pth) | 41 | | GMStereo-scale1-resumeflowthings-sceneflow | 4.7 | 23 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmstereo-scale1-resumeflowthings-sceneflow-16e38788.pth) | 42 | | GMStereo-scale2-sceneflow | 4.7 | 58 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmstereo-scale2-sceneflow-ab93ba6a.pth) | 43 | | GMStereo-scale2-resumeflowthings-sceneflow | 4.7 | 58 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmstereo-scale2-resumeflowthings-sceneflow-48020649.pth) | 44 | | GMStereo-scale2-regrefine3-sceneflow | 7.4 | 86 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmstereo-scale2-regrefine3-sceneflow-2dd12e97.pth) | 45 | | GMStereo-scale2-regrefine3-resumeflowthings-sceneflow | 7.4 | 86 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmstereo-scale2-regrefine3-resumeflowthings-sceneflow-f724fee6.pth) | 46 | | GMStereo-scale2-regrefine3-resumeflowthings-kitti | 7.4 | 86 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmstereo-scale2-regrefine3-resumeflowthings-kitti15-04487ebf.pth) | 47 | | GMStereo-scale2-regrefine3-resumeflowthings-middlebury | 7.4 | 86 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmstereo-scale2-regrefine3-resumeflowthings-middleburyfthighres-a82bec03.pth) | 48 | | GMStereo-scale2-regrefine3-resumeflowthings-eth3dft | 7.4 | 86 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmstereo-scale2-regrefine3-resumeflowthings-eth3dft-a807cb16.pth) | 49 | | GMStereo-scale2-regrefine3-resumeflowthings-mixdata | 7.4 | 86 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmstereo-scale2-regrefine3-resumeflowthings-mixdata-train320x640-ft640x960-e4e291fd.pth) | 50 | 51 | 52 | 53 | ## Depth Estimation 54 | 55 | - The inference time is measured for ScanNet resolution: 480x640 56 | 57 | - The `*-resumeflowthings-*` models are trained with a pretrained GMFlow model as initialization, where GMFlow is trained on Chairs and Things dataset for optical flow task. 
58 | 59 | 60 | 61 | | Model | Params (M) | Time (ms) | Download | 62 | | -------------------------------------------------- | :--------: | :-------: | :----------------------------------------------------------: | 63 | | GMDepth-scale1-scannet | 4.7 | 17 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmdepth-scale1-scannet-d3d1efb5.pth) | 64 | | GMDepth-scale1-resumeflowthings-scannet | 4.7 | 17 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmdepth-scale1-resumeflowthings-scannet-5d9d7964.pth) | 65 | | GMDepth-scale1-regrefine1-resumeflowthings-scannet | 4.7 | 20 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmdepth-scale1-regrefine1-resumeflowthings-scannet-90325722.pth) | 66 | | GMDepth-scale1-demon | 7.3 | 17 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmdepth-scale1-demon-bd64786e.pth) | 67 | | GMDepth-scale1-resumeflowthings-demon | 7.3 | 17 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmdepth-scale1-resumeflowthings-demon-a2fe127b.pth) | 68 | | GMDepth-scale1-regrefine1-resumeflowthings-demon | 7.3 | 20 | [download](https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained/gmdepth-scale1-regrefine1-resumeflowthings-demon-7c23f230.pth) | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Unifying Flow, Stereo and Depth Estimation
4 | Haofei Xu · Jing Zhang · Jianfei Cai · Hamid Rezatofighi · Fisher Yu · Dacheng Tao · Andreas Geiger
18 | TPAMI 2023
19 | Paper | Slides | Project Page | Colab | Demo
24 | [image]
30 | A unified model for three motion and 3D perception tasks.
34 | [image]
39 | We achieve 1st place on the Sintel (clean), Middlebury (rms metric) and Argoverse benchmarks.

41 | 42 | This project is developed based on our previous works: 43 | 44 | - [GMFlow: Learning Optical Flow via Global Matching, CVPR 2022, Oral](https://github.com/haofeixu/gmflow) 45 | 46 | - [High-Resolution Optical Flow from 1D Attention and Correlation, ICCV 2021, Oral](https://github.com/haofeixu/flow1d) 47 | 48 | - [AANet: Adaptive Aggregation Network for Efficient Stereo Matching, CVPR 2020](https://github.com/haofeixu/aanet) 49 | 50 | 51 | ## Updates 52 | 53 | - 2025-01-04: Check out [DepthSplat](https://haofeixu.github.io/depthsplat/) for a modern multi-view depth model, which leverages monocular depth ([Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2)) to significantly improve the robustness of UniMatch. 54 | 55 | - 2025-01-04: The UniMatch depth model served as the foundational backbone of [MVSplat (ECCV 2024, Oral)](https://donydchen.github.io/mvsplat/) for sparse-view feed-forward 3DGS reconstruction. 56 | 57 | ## Installation 58 | 59 | Our code is developed based on pytorch 1.9.0, CUDA 10.2 and python 3.8. Higher version pytorch should also work well. 60 | 61 | We recommend using [conda](https://www.anaconda.com/distribution/) for installation: 62 | 63 | ``` 64 | conda env create -f conda_environment.yml 65 | conda activate unimatch 66 | ``` 67 | 68 | Alternatively, we also support installing with pip: 69 | 70 | ``` 71 | bash pip_install.sh 72 | ``` 73 | 74 | 75 | To use the [depth models from DepthSplat](https://github.com/cvg/depthsplat/blob/main/MODEL_ZOO.md), you need to create a new conda environment with higher version dependencies: 76 | 77 | ``` 78 | conda create -y -n depthsplat-depth python=3.10 79 | conda activate depthsplat-depth 80 | pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124 81 | pip install tensorboard==2.9.1 einops opencv-python>=4.8.1.78 matplotlib 82 | ``` 83 | 84 | 85 | ## Model Zoo 86 | 87 | A large number of pretrained models with different speed-accuracy trade-offs for flow, stereo and depth are available at [MODEL_ZOO.md](MODEL_ZOO.md). 88 | 89 | Check out [DepthSplat's Model Zoo](https://github.com/cvg/depthsplat/blob/main/MODEL_ZOO.md) for better depth models. 90 | 91 | We assume the downloaded weights are located under the `pretrained` directory. 92 | 93 | Otherwise, you may need to change the corresponding paths in the scripts. 94 | 95 | 96 | 97 | ## Demo 98 | 99 | Given an image pair or a video sequence, our code supports generating prediction results of optical flow, disparity and depth. 100 | 101 | Please refer to [scripts/gmflow_demo.sh](scripts/gmflow_demo.sh), [scripts/gmstereo_demo.sh](scripts/gmstereo_demo.sh), [scripts/gmdepth_demo.sh](scripts/gmdepth_demo.sh) and [scripts/depthsplat_depth_demo.sh](scripts/depthsplat_depth_demo.sh) for example usages. 102 | 103 | 104 | 105 | 106 | https://user-images.githubusercontent.com/19343475/199893756-998cb67e-37d7-4323-ab6e-82fd3cbcd529.mp4 107 | 108 | 109 | 110 | ## Datasets 111 | 112 | The datasets used to train and evaluate our models for all three tasks are given in [DATASETS.md](DATASETS.md) 113 | 114 | 115 | 116 | ## Evaluation 117 | 118 | The evaluation scripts used to reproduce the numbers in our paper are given in [scripts/gmflow_evaluate.sh](scripts/gmflow_evaluate.sh), [scripts/gmstereo_evaluate.sh](scripts/gmstereo_evaluate.sh) and [scripts/gmdepth_evaluate.sh](scripts/gmdepth_evaluate.sh). 
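For example, a complete evaluation workflow might look like the following (a sketch only; it assumes the archive extracts into the `pretrained` directory and that the datasets have been set up as described in [DATASETS.md](DATASETS.md); otherwise adjust the paths inside the scripts):

```
# download all pretrained weights listed in MODEL_ZOO.md
wget https://s3.eu-central-1.amazonaws.com/avg-projects/unimatch/pretrained.zip
unzip pretrained.zip

# reproduce the paper numbers for each task
bash scripts/gmflow_evaluate.sh
bash scripts/gmstereo_evaluate.sh
bash scripts/gmdepth_evaluate.sh
```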
119 | 120 | For submission to KITTI, Sintel, Middlebury and ETH3D online test sets, you can run [scripts/gmflow_submission.sh](scripts/gmflow_submission.sh) and [scripts/gmstereo_submission.sh](scripts/gmstereo_submission.sh) to generate the prediction results. The results can be submitted directly. 121 | 122 | 123 | 124 | ## Training 125 | 126 | All training scripts for different model variants on different datasets can be found in [scripts/*_train.sh](scripts). 127 | 128 | We support using tensorboard to monitor and visualize the training process. You can first start a tensorboard session with 129 | 130 | ``` 131 | tensorboard --logdir checkpoints 132 | ``` 133 | 134 | and then access [http://localhost:6006](http://localhost:6006/) in your browser. 135 | 136 | 137 | 138 | ## Citation 139 | 140 | ``` 141 | @article{xu2023unifying, 142 | title={Unifying Flow, Stereo and Depth Estimation}, 143 | author={Xu, Haofei and Zhang, Jing and Cai, Jianfei and Rezatofighi, Hamid and Yu, Fisher and Tao, Dacheng and Geiger, Andreas}, 144 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 145 | year={2023} 146 | } 147 | ``` 148 | 149 | This work is a substantial extension of our previous conference paper [GMFlow (CVPR 2022, Oral)](https://arxiv.org/abs/2111.13680), please consider citing GMFlow as well if you found this work useful in your research. 150 | 151 | ``` 152 | @inproceedings{xu2022gmflow, 153 | title={GMFlow: Learning Optical Flow via Global Matching}, 154 | author={Xu, Haofei and Zhang, Jing and Cai, Jianfei and Rezatofighi, Hamid and Tao, Dacheng}, 155 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 156 | pages={8121-8130}, 157 | year={2022} 158 | } 159 | ``` 160 | 161 | Please consider citing [DepthSplat](https://arxiv.org/abs/2410.13862) if DepthSplat's depth model is used in your research. 162 | 163 | ``` 164 | @article{xu2024depthsplat, 165 | title = {DepthSplat: Connecting Gaussian Splatting and Depth}, 166 | author = {Xu, Haofei and Peng, Songyou and Wang, Fangjinhua and Blum, Hermann and Barath, Daniel and Geiger, Andreas and Pollefeys, Marc}, 167 | journal = {arXiv preprint arXiv:2410.13862}, 168 | year = {2024} 169 | } 170 | ``` 171 | 172 | 173 | ## Acknowledgements 174 | 175 | This project would not have been possible without relying on some awesome repos: [RAFT](https://github.com/princeton-vl/RAFT), [LoFTR](https://github.com/zju3dv/LoFTR), [DETR](https://github.com/facebookresearch/detr), [Swin](https://github.com/microsoft/Swin-Transformer), [mmdetection](https://github.com/open-mmlab/mmdetection) and [Detectron2](https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py). We thank the original authors for their excellent work. 
176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | -------------------------------------------------------------------------------- /conda_environment.yml: -------------------------------------------------------------------------------- 1 | name: unimatch 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1 7 | - _openmp_mutex=5.1 8 | - blas=1.0 9 | - brotli=1.0.9 10 | - brotli-bin=1.0.9 11 | - bzip2=1.0.8 12 | - ca-certificates=2022.10.11 13 | - certifi=2022.9.24 14 | - cloudpickle=2.0.0 15 | - cudatoolkit=10.2.89 16 | - cycler=0.11.0 17 | - cytoolz=0.12.0 18 | - dask-core=2022.7.0 19 | - dbus=1.13.18 20 | - expat=2.4.9 21 | - ffmpeg=4.3 22 | - fftw=3.3.9 23 | - fontconfig=2.13.1 24 | - fonttools=4.25.0 25 | - freetype=2.12.1 26 | - fsspec=2022.10.0 27 | - giflib=5.2.1 28 | - glib=2.69.1 29 | - gmp=6.2.1 30 | - gnutls=3.6.15 31 | - gst-plugins-base=1.14.0 32 | - gstreamer=1.14.0 33 | - icu=58.2 34 | - imageio=2.9.0 35 | - intel-openmp=2021.4.0 36 | - jpeg=9b 37 | - kiwisolver=1.4.2 38 | - lame=3.100 39 | - lcms2=2.12 40 | - ld_impl_linux-64=2.38 41 | - libbrotlicommon=1.0.9 42 | - libbrotlidec=1.0.9 43 | - libbrotlienc=1.0.9 44 | - libffi=3.3 45 | - libgcc-ng=11.2.0 46 | - libgfortran-ng=11.2.0 47 | - libgfortran5=11.2.0 48 | - libgomp=11.2.0 49 | - libiconv=1.16 50 | - libidn2=2.3.2 51 | - libpng=1.6.37 52 | - libstdcxx-ng=11.2.0 53 | - libtasn1=4.16.0 54 | - libtiff=4.1.0 55 | - libunistring=0.9.10 56 | - libuuid=1.41.5 57 | - libuv=1.40.0 58 | - libwebp=1.2.0 59 | - libxcb=1.15 60 | - libxml2=2.9.14 61 | - locket=1.0.0 62 | - lz4-c=1.9.3 63 | - matplotlib=3.5.1 64 | - matplotlib-base=3.5.1 65 | - mkl=2021.4.0 66 | - mkl-service=2.4.0 67 | - mkl_fft=1.3.1 68 | - mkl_random=1.2.2 69 | - munkres=1.1.4 70 | - ncurses=6.3 71 | - nettle=3.7.3 72 | - networkx=2.8.4 73 | - ninja=1.10.2 74 | - ninja-base=1.10.2 75 | - numpy=1.19.2 76 | - numpy-base=1.19.2 77 | - openh264=2.1.1 78 | - openssl=1.1.1s 79 | - packaging=21.3 80 | - partd=1.2.0 81 | - pcre=8.45 82 | - pillow=9.0.1 83 | - pip=22.2.2 84 | - pyparsing=3.0.9 85 | - pyqt=5.9.2 86 | - python=3.8.13 87 | - python-dateutil=2.8.2 88 | - pytorch=1.9.0 89 | - pywavelets=1.3.0 90 | - pyyaml=6.0 91 | - qt=5.9.7 92 | - readline=8.2 93 | - scikit-image=0.19.2 94 | - scipy=1.9.3 95 | - sip=4.19.13 96 | - six=1.16.0 97 | - sqlite=3.39.3 98 | - tifffile=2020.10.1 99 | - tk=8.6.12 100 | - toolz=0.12.0 101 | - torchvision=0.10.0 102 | - tornado=6.2 103 | - typing_extensions=4.3.0 104 | - wheel=0.37.1 105 | - xz=5.2.6 106 | - yaml=0.2.5 107 | - zlib=1.2.13 108 | - zstd=1.4.9 109 | - pip: 110 | - absl-py==1.3.0 111 | - cachetools==5.2.0 112 | - charset-normalizer==2.1.1 113 | - google-auth==2.14.1 114 | - google-auth-oauthlib==0.4.6 115 | - grpcio==1.50.0 116 | - h5py==3.7.0 117 | - idna==3.4 118 | - imageio-ffmpeg==0.4.7 119 | - importlib-metadata==5.0.0 120 | - joblib==1.2.0 121 | - lz4==4.0.2 122 | - markdown==3.4.1 123 | - markupsafe==2.1.1 124 | - oauthlib==3.2.2 125 | - opencv-python==4.6.0.66 126 | - path==16.5.0 127 | - protobuf==3.19.6 128 | - pyasn1==0.4.8 129 | - pyasn1-modules==0.2.8 130 | - requests==2.28.1 131 | - requests-oauthlib==1.3.1 132 | - rsa==4.9 133 | - setuptools==59.5.0 134 | - tensorboard==2.9.1 135 | - tensorboard-data-server==0.6.1 136 | - tensorboard-plugin-wit==1.8.1 137 | - tqdm==4.64.1 138 | - urllib3==1.26.12 139 | - werkzeug==2.2.2 140 | - zipp==3.10.0 141 | -------------------------------------------------------------------------------- /dataloader/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/dataloader/__init__.py -------------------------------------------------------------------------------- /dataloader/depth/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/dataloader/depth/__init__.py -------------------------------------------------------------------------------- /dataloader/depth/augmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | import cv2 5 | import torchvision.transforms.functional as F 6 | import random 7 | 8 | 9 | class Compose(object): 10 | def __init__(self, transforms): 11 | self.transforms = transforms 12 | 13 | def __call__(self, sample): 14 | for t in self.transforms: 15 | sample = t(sample) 16 | return sample 17 | 18 | 19 | class ToTensor(object): 20 | """Convert numpy array to torch tensor""" 21 | 22 | def __call__(self, sample): 23 | for key, value in sample.items(): 24 | if isinstance(value, np.ndarray): 25 | sample[key] = torch.from_numpy(value) 26 | 27 | if isinstance(value, list): # multi-frame target images 28 | sample[key] = [torch.from_numpy(v) for v in value] 29 | 30 | sample['img_ref'] = sample['img_ref'].permute((2, 0, 1)) / 255. # [3, H, W] 31 | sample['img_tgt'] = sample['img_tgt'].permute((2, 0, 1)) / 255. # [3, H, W] 32 | 33 | return sample 34 | 35 | 36 | class Normalize(object): 37 | """Normalize image, with type tensor""" 38 | 39 | def __init__(self, mean, std): 40 | self.mean = mean 41 | self.std = std 42 | 43 | def __call__(self, sample): 44 | norm_keys = ['img_ref', 'img_tgt'] 45 | 46 | assert isinstance(sample['img_ref'], torch.Tensor) 47 | assert sample['img_ref'].size(0) == 3 # [3, H, W] 48 | 49 | for key in norm_keys: 50 | # multi-frame inference 51 | if key == 'img_tgt' and isinstance(sample['img_tgt'], list): 52 | for i in range(len(sample['img_tgt'])): 53 | # Images have converted to tensor, with shape [C, H, W] 54 | tgt = sample['img_tgt'][i] 55 | for t, m, s in zip(tgt, self.mean, self.std): 56 | t.sub_(m).div_(s) 57 | sample['img_tgt'][i] = tgt 58 | else: 59 | # Images have converted to tensor, with shape [C, H, W] 60 | for t, m, s in zip(sample[key], self.mean, self.std): 61 | t.sub_(m).div_(s) 62 | 63 | return sample 64 | 65 | 66 | class RandomCrop(object): 67 | def __init__(self, crop_size): 68 | self.crop_size = crop_size 69 | 70 | def __call__(self, sample): 71 | crop_h, crop_w = self.crop_size 72 | 73 | ori_h, ori_w = sample['img_ref'].shape[:2] 74 | 75 | out_intrinsics = sample['intrinsics'].copy() 76 | 77 | offset_y = np.random.randint(ori_h - crop_h + 1) 78 | offset_x = np.random.randint(ori_w - crop_w + 1) 79 | 80 | for key in ['img_ref', 'img_tgt', 'depth']: 81 | sample[key] = sample[key][offset_y:offset_y + crop_h, offset_x:offset_x + crop_w] 82 | 83 | # valid mask for sparse data 84 | if 'valid' in sample: 85 | sample['valid'] = sample['valid'][offset_y:offset_y + crop_h, offset_x:offset_x + crop_w] 86 | 87 | out_intrinsics[0, 2] -= offset_x 88 | out_intrinsics[1, 2] -= offset_y 89 | 90 | sample['intrinsics'] = out_intrinsics 91 | 92 | return sample 93 | 94 | 95 | class RandomColor(object): 96 | def __init__(self, asymmetric=True): 97 | self.asymmetric = asymmetric 98 | 99 | def 
__call__(self, sample): 100 | transforms = [RandomContrast(asymmetric=self.asymmetric), 101 | RandomGamma(asymmetric=self.asymmetric), 102 | RandomBrightness(asymmetric=self.asymmetric), 103 | RandomHue(asymmetric=self.asymmetric), 104 | RandomSaturation(asymmetric=self.asymmetric)] 105 | 106 | sample = ToPILImage()(sample) 107 | 108 | if np.random.random() < 0.5: 109 | # A single transform 110 | t = random.choice(transforms) 111 | sample = t(sample) 112 | else: 113 | # Combination of transforms 114 | # Random order 115 | random.shuffle(transforms) 116 | for t in transforms: 117 | sample = t(sample) 118 | 119 | sample = ToNumpyArray()(sample) 120 | 121 | return sample 122 | 123 | 124 | class RandomResize(object): 125 | def __init__(self, min_size, min_scale=-0.2, max_scale=0.2): 126 | # min_size bigger than crop_size 127 | self.min_size = min_size 128 | 129 | self.min_scale = min_scale 130 | self.max_scale = max_scale 131 | 132 | self.stretch_prob = 0.4 133 | self.max_stretch = 0.2 134 | 135 | def __call__(self, sample): 136 | if np.random.random() < 0.5: 137 | min_h, min_w = self.min_size 138 | ori_h, ori_w = sample['img_ref'].shape[:2] 139 | 140 | min_scale = np.maximum(min_h / float(ori_h), min_w / float(ori_w), dtype=np.float32) 141 | 142 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 143 | scale_x = scale 144 | scale_y = scale 145 | 146 | if np.random.random() < self.stretch_prob: 147 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 148 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 149 | 150 | scale_x = np.clip(scale_x, min_scale, None).astype(np.float32) 151 | scale_y = np.clip(scale_y, min_scale, None).astype(np.float32) 152 | 153 | # Resize 154 | sample['img_ref'] = cv2.resize(sample['img_ref'], None, fx=scale_x, fy=scale_y, 155 | interpolation=cv2.INTER_LINEAR) 156 | sample['img_tgt'] = cv2.resize(sample['img_tgt'], None, fx=scale_x, fy=scale_y, 157 | interpolation=cv2.INTER_LINEAR) 158 | 159 | if 'depth' in sample: 160 | sample['depth'] = cv2.resize(sample['depth'], None, fx=scale_x, fy=scale_y, 161 | interpolation=cv2.INTER_LINEAR) 162 | 163 | if 'valid' in sample: 164 | sample['valid'] = cv2.resize(sample['valid'], None, fx=scale_x, fy=scale_y, 165 | interpolation=cv2.INTER_LINEAR) 166 | sample['valid'] = (sample['valid'] > 0.99).astype(np.float32) 167 | 168 | out_intrinsics = sample['intrinsics'].copy() 169 | out_intrinsics[0] = out_intrinsics[0] * scale_x 170 | out_intrinsics[1] = out_intrinsics[1] * scale_y 171 | sample['intrinsics'] = out_intrinsics 172 | 173 | return sample 174 | 175 | 176 | class ToPILImage(object): 177 | 178 | def __call__(self, sample): 179 | sample['img_ref'] = Image.fromarray(sample['img_ref'].astype('uint8')) 180 | sample['img_tgt'] = Image.fromarray(sample['img_tgt'].astype('uint8')) 181 | 182 | return sample 183 | 184 | 185 | class ToNumpyArray(object): 186 | 187 | def __call__(self, sample): 188 | sample['img_ref'] = np.array(sample['img_ref']).astype(np.float32) 189 | sample['img_tgt'] = np.array(sample['img_tgt']).astype(np.float32) 190 | 191 | return sample 192 | 193 | 194 | # Random coloring 195 | class RandomContrast(object): 196 | """Random contrast""" 197 | 198 | def __init__(self, asymmetric=False): 199 | self.asymmetric = asymmetric 200 | 201 | def __call__(self, sample): 202 | if np.random.random() < 0.5: 203 | contrast_factor = np.random.uniform(0.8, 1.2) 204 | 205 | sample['img_ref'] = F.adjust_contrast(sample['img_ref'], contrast_factor) 206 | 207 | if 
self.asymmetric and np.random.random() < 0.2: 208 | contrast_factor = np.random.uniform(0.8, 1.2) 209 | 210 | sample['img_tgt'] = F.adjust_contrast(sample['img_tgt'], contrast_factor) 211 | 212 | return sample 213 | 214 | 215 | class RandomGamma(object): 216 | def __init__(self, asymmetric=False): 217 | self.asymmetric = asymmetric 218 | 219 | def __call__(self, sample): 220 | if np.random.random() < 0.5: 221 | gamma = np.random.uniform(0.7, 1.5) # adopted from FlowNet 222 | 223 | sample['img_ref'] = F.adjust_gamma(sample['img_ref'], gamma) 224 | 225 | if self.asymmetric and np.random.random() < 0.2: 226 | gamma = np.random.uniform(0.7, 1.5) 227 | 228 | sample['img_tgt'] = F.adjust_gamma(sample['img_tgt'], gamma) 229 | 230 | return sample 231 | 232 | 233 | class RandomBrightness(object): 234 | def __init__(self, asymmetric=False): 235 | self.asymmetric = asymmetric 236 | 237 | def __call__(self, sample): 238 | if np.random.random() < 0.5: 239 | brightness = np.random.uniform(0.5, 2.0) 240 | 241 | sample['img_ref'] = F.adjust_brightness(sample['img_ref'], brightness) 242 | 243 | if self.asymmetric and np.random.random() < 0.2: 244 | brightness = np.random.uniform(0.5, 2.0) 245 | 246 | sample['img_tgt'] = F.adjust_brightness(sample['img_tgt'], brightness) 247 | 248 | return sample 249 | 250 | 251 | class RandomHue(object): 252 | def __init__(self, asymmetric=False): 253 | self.asymmetric = asymmetric 254 | 255 | def __call__(self, sample): 256 | if np.random.random() < 0.5: 257 | hue = np.random.uniform(-0.1, 0.1) 258 | 259 | sample['img_ref'] = F.adjust_hue(sample['img_ref'], hue) 260 | 261 | if self.asymmetric and np.random.random() < 0.2: 262 | hue = np.random.uniform(-0.1, 0.1) 263 | 264 | sample['img_tgt'] = F.adjust_hue(sample['img_tgt'], hue) 265 | 266 | return sample 267 | 268 | 269 | class RandomSaturation(object): 270 | def __init__(self, asymmetric=False): 271 | self.asymmetric = asymmetric 272 | 273 | def __call__(self, sample): 274 | if np.random.random() < 0.5: 275 | saturation = np.random.uniform(0.8, 1.2) 276 | 277 | sample['img_ref'] = F.adjust_saturation(sample['img_ref'], saturation) 278 | 279 | if self.asymmetric and np.random.random() < 0.2: 280 | saturation = np.random.uniform(0.8, 1.2) 281 | 282 | sample['img_tgt'] = F.adjust_saturation(sample['img_tgt'], saturation) 283 | 284 | return sample 285 | -------------------------------------------------------------------------------- /dataloader/depth/datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset 3 | 4 | import os 5 | from glob import glob 6 | from PIL import Image 7 | 8 | from utils.file_io import read_img 9 | 10 | 11 | class ScannetDataset(Dataset): 12 | def __init__(self, 13 | data_dir='datasets/Scannet', 14 | mode='train', 15 | transforms=None, 16 | return_filename=False, 17 | ): 18 | 19 | self.data_dir = data_dir 20 | self.transforms = transforms 21 | self.return_filename = return_filename 22 | 23 | assert mode in ['train', 'test', 'demo'] 24 | 25 | self.mode = mode 26 | 27 | self.samples = [] 28 | 29 | # following BA-Net's splits 30 | dir_path = os.path.dirname(os.path.realpath(__file__)) 31 | split_file = 'scannet_banet_' + mode + '_pairs.txt' 32 | 33 | split_file = os.path.join(dir_path, split_file) 34 | 35 | with open(split_file) as f: 36 | pairs = f.readlines() 37 | 38 | pairs = [pair.rstrip() for pair in pairs] 39 | 40 | for i in range(len(pairs)): 41 | scene, img_ref_id, img_tgt_id = pairs[i].split(' ') 
42 | key = 'scannet_' + mode + '_' + scene + '_' + img_ref_id 43 | 44 | scene = os.path.join(data_dir, mode, scene) 45 | 46 | intrinsics = os.path.join(scene, 'intrinsic', 'intrinsic_depth.txt') 47 | 48 | img_ref = os.path.join(scene, 'color', img_ref_id + '.jpg') 49 | img_tgt = os.path.join(scene, 'color', img_tgt_id + '.jpg') 50 | 51 | depth = os.path.join(scene, 'depth', img_ref_id + '.png') 52 | pose_ref = os.path.join(scene, 'pose', img_ref_id + '.txt') 53 | pose_tgt = os.path.join(scene, 'pose', img_tgt_id + '.txt') 54 | 55 | if not os.path.isfile(img_ref) or not os.path.isfile(img_tgt) or not os.path.isfile(depth) or \ 56 | not os.path.isfile(pose_ref) or not os.path.isfile(pose_tgt): 57 | continue 58 | 59 | sample = (img_ref, img_tgt, pose_ref, pose_tgt, depth, intrinsics, key) 60 | 61 | self.samples.append(sample) 62 | 63 | def __getitem__(self, i): 64 | img_ref, img_tgt, pose_ref, pose_tgt, depth, intrinsics, key = self.samples[i] 65 | 66 | img_ref_filename, img_tgt_filename = img_ref, img_tgt 67 | 68 | img_ref = self._read_image(img_ref) 69 | img_tgt = self._read_image(img_tgt) 70 | depth = self._read_depth(depth) 71 | valid = (depth > 0).astype(np.float32) # invalid depth is 0 72 | 73 | # pose: camera to world 74 | pose_ref = np.loadtxt(pose_ref, delimiter=' ').astype(np.float32).reshape((4, 4)) 75 | pose_tgt = np.loadtxt(pose_tgt, delimiter=' ').astype(np.float32).reshape((4, 4)) 76 | 77 | # relative pose 78 | pose = np.linalg.inv(pose_tgt) @ pose_ref 79 | 80 | intrinsics = np.loadtxt(intrinsics).astype(np.float32).reshape((4, 4))[:3, :3] # [3, 3] 81 | 82 | sample = { 83 | 'img_ref': img_ref, 84 | 'img_tgt': img_tgt, 85 | 'intrinsics': intrinsics, 86 | 'pose': pose, 87 | 'depth': depth, 88 | 'valid': valid, 89 | } 90 | 91 | if self.transforms is not None: 92 | sample = self.transforms(sample) 93 | 94 | if self.return_filename: 95 | return img_ref_filename, img_tgt_filename, sample 96 | 97 | return sample 98 | 99 | def __len__(self): 100 | 101 | return len(self.samples) 102 | 103 | def _read_image(self, filename): 104 | img = Image.open(filename).resize((640, 480)) # resize to depth shape 105 | img = np.array(img).astype(np.float32) 106 | 107 | return img 108 | 109 | def _read_depth(self, filename): 110 | depth = np.array(Image.open(filename)).astype(np.float32) / 1000. 
111 | 112 | return depth 113 | 114 | def __rmul__(self, v): 115 | self.samples = v * self.samples 116 | 117 | return self 118 | 119 | 120 | class DemonDataset(Dataset): 121 | def __init__(self, 122 | data_dir='datasets/Demon', 123 | mode='train', 124 | transforms=None, 125 | sequence_length=2, 126 | ): 127 | 128 | if 'test' in mode: 129 | data_dir = os.path.join(data_dir, 'test') 130 | else: 131 | data_dir = os.path.join(data_dir, 'train') 132 | 133 | self.data_dir = data_dir 134 | self.transforms = transforms 135 | 136 | assert sequence_length == 2 # only support two input views currently 137 | 138 | self.samples = [] 139 | 140 | scenes = [os.path.join(data_dir, scene_dir) for scene_dir in sorted(os.listdir(data_dir)) 141 | if os.path.isdir(os.path.join(os.path.join(data_dir, scene_dir))) and mode in scene_dir] 142 | 143 | demi_length = sequence_length // 2 144 | 145 | for scene in scenes: 146 | intrinsics = np.genfromtxt(os.path.join(scene, 'cam.txt')).astype(np.float32).reshape((3, 3)) # [3, 3] 147 | poses = np.genfromtxt(os.path.join(scene, 'poses.txt')).astype(np.float32) 148 | imgs = sorted(glob(os.path.join(scene, '*.jpg'))) 149 | if len(imgs) < sequence_length: 150 | continue 151 | for i in range(len(imgs)): 152 | if i < demi_length: 153 | shifts = list(range(0, sequence_length)) 154 | shifts.pop(i) 155 | elif i >= len(imgs) - demi_length: 156 | shifts = list(range(len(imgs) - sequence_length, len(imgs))) 157 | shifts.pop(i - len(imgs)) 158 | else: 159 | shifts = list(range(i - demi_length, i + (sequence_length + 1) // 2)) 160 | shifts.pop(demi_length) 161 | 162 | img_ref = imgs[i] 163 | depth = os.path.join(os.path.dirname(img_ref), os.path.basename(img_ref)[:-4] + '.npy') 164 | pose_ref = np.concatenate((poses[i, :].reshape((3, 4)), np.array([[0, 0, 0, 1]])), axis=0) # [4, 4] 165 | 166 | assert len(shifts) < 2 # only support two input images currently 167 | 168 | for j in shifts: 169 | img_tgt = imgs[j] 170 | pose_tgt = np.concatenate((poses[j, :].reshape((3, 4)), np.array([[0, 0, 0, 1]])), axis=0) 171 | pose = (pose_tgt @ np.linalg.inv(pose_ref)).astype(np.float32) # [4, 4] 172 | 173 | sample = (img_ref, img_tgt, pose, depth, intrinsics) 174 | 175 | self.samples.append(sample) 176 | 177 | def __getitem__(self, i): 178 | img_ref, img_tgt, pose, depth, intrinsics = self.samples[i] 179 | 180 | img_ref = read_img(img_ref) 181 | img_tgt = read_img(img_tgt) 182 | depth = np.load(depth) 183 | valid = (depth > 0).astype(np.float32) # invalid depth is 0 184 | 185 | sample = { 186 | 'img_ref': img_ref, 187 | 'img_tgt': img_tgt, 188 | 'intrinsics': intrinsics, 189 | 'pose': pose, 190 | 'depth': depth, 191 | 'valid': valid, 192 | } 193 | 194 | if self.transforms is not None: 195 | sample = self.transforms(sample) 196 | 197 | return sample 198 | 199 | def __len__(self): 200 | 201 | return len(self.samples) 202 | -------------------------------------------------------------------------------- /dataloader/depth/download_demon_test.sh: -------------------------------------------------------------------------------- 1 | # Source from https://github.com/lmb-freiburg/demon 2 | #!/bin/bash 3 | clear 4 | cat << EOF 5 | ================================================================================ 6 | The test datasets are provided for research purposes only. 7 | Some of the test datasets build upon other publicly available data. 8 | Make sure to cite the respective original source of the data if you use the 9 | provided files for your research. 
10 | * sun3d_test.h5 is based on the SUN3D dataset http://sun3d.cs.princeton.edu/ 11 | J. Xiao, A. Owens, and A. Torralba, “SUN3D: A Database of Big Spaces Reconstructed Using SfM and Object Labels,” in 2013 IEEE International Conference on Computer Vision (ICCV), 2013, pp. 1625–1632. 12 | 13 | * rgbd_test.h5 is based on the RGBD SLAM benchmark http://vision.in.tum.de/data/datasets/rgbd-dataset (licensed under CC-BY 3.0) 14 | 15 | J. Sturm, N. Engelhard, F. Endres, W. Burgard, and D. Cremers, “A benchmark for the evaluation of RGB-D SLAM systems,” in 2012 IEEE/RSJ International Conference on Intelligent Robots and Systems, 2012, pp. 573–580. 16 | * scenes11_test.h5 uses objects from shapenet https://www.shapenet.org/ 17 | 18 | A. X. Chang et al., “ShapeNet: An Information-Rich 3D Model Repository,” arXiv:1512.03012 [cs], Dec. 2015. 19 | * mvs_test.h5 contains scenes from https://colmap.github.io/datasets.html 20 | 21 | J. L. Schönberger and J. M. Frahm, “Structure-from-Motion Revisited,” in 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, pp. 4104–4113. 22 | J. L. Schönberger, E. Zheng, J.-M. Frahm, and M. Pollefeys, “Pixelwise View Selection for Unstructured Multi-View Stereo,” in Computer Vision – ECCV 2016, 2016, pp. 501–518. 23 | * nyu2_test.h5 is based on http://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html 24 | 25 | N. Silberman, D. Hoiem, P. Kohli, and R. Fergus, “Indoor Segmentation and Support Inference from RGBD Images,” in Computer Vision – ECCV 2012, 2012, pp. 746–760. 26 | ================================================================================ 27 | type Y to start the download. 28 | EOF 29 | 30 | read -s -n 1 answer 31 | if [ "$answer" != "Y" -a "$answer" != "y" ]; then 32 | exit 0 33 | fi 34 | echo 35 | 36 | datasets=(sun3d rgbd mvs scenes11) 37 | 38 | OLD_PWD="$PWD" 39 | DESTINATION=testdata 40 | mkdir $DESTINATION 41 | cd $DESTINATION 42 | 43 | for ds in ${datasets[@]}; do 44 | if [ -e "${ds}_test.h5" ]; then 45 | echo "${ds}_test.h5 already exists, skipping ${ds}" 46 | else 47 | wget "https://lmb.informatik.uni-freiburg.de/data/demon/testdata/${ds}_test.tgz" 48 | tar -xvf "${ds}_test.tgz" 49 | fi 50 | done 51 | 52 | cd "$OLD_PWD" -------------------------------------------------------------------------------- /dataloader/depth/download_demon_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | clear 3 | cat << EOF 4 | 5 | ================================================================================ 6 | 7 | 8 | The train datasets are provided for research purposes only. 9 | 10 | Some of the test datasets build upon other publicly available data. 11 | Make sure to cite the respective original source of the data if you use the 12 | provided files for your research. 13 | 14 | * sun3d_train.h5 is based on the SUN3D dataset http://sun3d.cs.princeton.edu/ 15 | 16 | J. Xiao, A. Owens, and A. Torralba, “SUN3D: A Database of Big Spaces Reconstructed Using SfM and Object Labels,” in 2013 IEEE International Conference on Computer Vision (ICCV), 2013, pp. 1625–1632. 17 | 18 | 19 | 20 | 21 | * rgbd_bugfix_train.h5 is based on the RGBD SLAM benchmark http://vision.in.tum.de/data/datasets/rgbd-dataset (licensed under CC-BY 3.0) 22 | 23 | J. Sturm, N. Engelhard, F. Endres, W. Burgard, and D. Cremers, “A benchmark for the evaluation of RGB-D SLAM systems,” in 2012 IEEE/RSJ International Conference on Intelligent Robots and Systems, 2012, pp. 573–580. 
24 | 25 | 26 | 27 | * scenes11_train.h5 uses objects from shapenet https://www.shapenet.org/ 28 | 29 | A. X. Chang et al., “ShapeNet: An Information-Rich 3D Model Repository,” arXiv:1512.03012 [cs], Dec. 2015. 30 | 31 | 32 | 33 | * mvs_train.h5 contains the Citywall and Achteck-Turm scenes from MVE (Multi-View Environment) http://www.gcc.tu-darmstadt.de/home/proj/mve/ 34 | 35 | S. Fuhrmann, F. Langguth, and M. Goesele, “MVE: A Multi-view Reconstruction Environment,” in Proceedings of the Eurographics Workshop on Graphics and Cultural Heritage, Aire-la-Ville, Switzerland, Switzerland, 2014, pp. 11–18. 36 | 37 | 38 | 39 | ================================================================================ 40 | 41 | type Y to start the download. 42 | 43 | EOF 44 | 45 | read -s -n 1 answer 46 | if [ "$answer" != "Y" -a "$answer" != "y" ]; then 47 | exit 0 48 | fi 49 | echo 50 | 51 | datasets=(sun3d rgbd mvs scenes11) 52 | 53 | OLD_PWD="$PWD" 54 | DESTINATION=traindata 55 | mkdir $DESTINATION 56 | cd $DESTINATION 57 | 58 | if [ ! -e "README_traindata" ]; then 59 | wget --no-check-certificate "https://lmb.informatik.uni-freiburg.de/data/demon/traindata/README_traindata" 60 | fi 61 | 62 | for ds in ${datasets[@]}; do 63 | if [ -e "${ds}_train.h5" ]; then 64 | echo "${ds}_train.h5 already exists, skipping ${ds}" 65 | else 66 | wget --no-check-certificate "https://lmb.informatik.uni-freiburg.de/data/demon/traindata/${ds}_train.tgz" 67 | tar -xvf "${ds}_train.tgz" 68 | fi 69 | done 70 | 71 | cd "$OLD_PWD" -------------------------------------------------------------------------------- /dataloader/depth/prepare_demon_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | 5 | from joblib import Parallel, delayed 6 | import numpy as np 7 | import imageio 8 | 9 | imageio.plugins.freeimage.download() 10 | from imageio.plugins import freeimage 11 | import h5py 12 | from lz4.block import decompress 13 | import scipy.misc 14 | 15 | import cv2 16 | 17 | from path import Path 18 | 19 | path = os.path.join(os.path.dirname(os.path.abspath(__file__))) 20 | 21 | 22 | def dump_example(dataset_name): 23 | print("Converting {:}.h5 ...".format(dataset_name)) 24 | file = h5py.File(os.path.join(path, "testdata", "{:}.h5".format(dataset_name)), "r") 25 | 26 | for (seq_idx, seq_name) in enumerate(file): 27 | if dataset_name == 'scenes11_test': 28 | scale = 0.4 29 | else: 30 | scale = 1 31 | 32 | print("Processing sequence {:d}/{:d}".format(seq_idx, len(file))) 33 | dump_dir = os.path.join(path, 'test', dataset_name + "_" + "{:05d}".format(seq_idx)) 34 | if not os.path.isdir(dump_dir): 35 | os.mkdir(dump_dir) 36 | dump_dir = Path(dump_dir) 37 | sequence = file[seq_name]["frames"]["t0"] 38 | poses = [] 39 | for (f_idx, f_name) in enumerate(sequence): 40 | frame = sequence[f_name] 41 | for dt_type in frame: 42 | dataset = frame[dt_type] 43 | img = dataset[...] 
44 | if dt_type == "camera": 45 | if f_idx == 0: 46 | intrinsics = np.array([[img[0], 0, img[3]], [0, img[1], img[4]], [0, 0, 1]]) 47 | pose = np.array( 48 | [[img[5], img[8], img[11], img[14] * scale], [img[6], img[9], img[12], img[15] * scale], 49 | [img[7], img[10], img[13], img[16] * scale]]) 50 | poses.append(pose.tolist()) 51 | elif dt_type == "depth": 52 | dimension = dataset.attrs["extents"] 53 | depth = np.array(np.frombuffer(decompress(img.tobytes(), dimension[0] * dimension[1] * 2), 54 | dtype=np.float16)).astype(np.float32) 55 | depth = depth.reshape(dimension[0], dimension[1]) * scale 56 | 57 | dump_depth_file = dump_dir / '{:04d}.npy'.format(f_idx) 58 | np.save(dump_depth_file, depth) 59 | elif dt_type == "image": 60 | img = imageio.imread(img.tobytes()) 61 | dump_img_file = dump_dir / '{:04d}.jpg'.format(f_idx) 62 | imageio.imsave(dump_img_file, img) 63 | 64 | dump_cam_file = dump_dir / 'cam.txt' 65 | np.savetxt(dump_cam_file, intrinsics) 66 | poses_file = dump_dir / 'poses.txt' 67 | np.savetxt(poses_file, np.array(poses).reshape(-1, 12), fmt='%.6e') 68 | 69 | if len(dump_dir.files('*.jpg')) < 2: 70 | dump_dir.rmtree() 71 | 72 | 73 | def preparedata(): 74 | num_threads = 1 75 | SUB_DATASET_NAMES = (["rgbd_test", "scenes11_test", "sun3d_test"]) 76 | 77 | dump_root = os.path.join(path, 'test') 78 | if not os.path.isdir(dump_root): 79 | os.mkdir(dump_root) 80 | 81 | if num_threads == 1: 82 | for scene in SUB_DATASET_NAMES: 83 | dump_example(scene) 84 | else: 85 | Parallel(n_jobs=num_threads)(delayed(dump_example)(scene) for scene in SUB_DATASET_NAMES) 86 | 87 | dump_root = Path(dump_root) 88 | subdirs = dump_root.dirs() 89 | subdirs = [subdir.basename() for subdir in subdirs] 90 | subdirs = sorted(subdirs) 91 | with open(dump_root / 'test.txt', 'w') as tf: 92 | for subdir in subdirs: 93 | tf.write('{}\n'.format(subdir)) 94 | 95 | print("Finished Converting Data.") 96 | 97 | 98 | if __name__ == "__main__": 99 | preparedata() 100 | -------------------------------------------------------------------------------- /dataloader/depth/prepare_demon_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | 5 | from joblib import Parallel, delayed 6 | import numpy as np 7 | import imageio 8 | 9 | imageio.plugins.freeimage.download() 10 | from imageio.plugins import freeimage 11 | import h5py 12 | from lz4.block import decompress 13 | import scipy.misc 14 | import cv2 15 | 16 | from path import Path 17 | 18 | path = os.path.join(os.path.dirname(os.path.abspath(__file__))) 19 | 20 | 21 | def dump_example(dataset_name): 22 | print("Converting {:}.h5 ...".format(dataset_name)) 23 | file = h5py.File(os.path.join(path, "traindata", "{:}.h5".format(dataset_name)), "r") 24 | 25 | for (seq_idx, seq_name) in enumerate(file): 26 | if dataset_name == 'scenes11_train': 27 | scale = 0.4 28 | else: 29 | scale = 1 30 | 31 | if ((dataset_name == 'sun3d_train_1.6m_to_infm' and seq_idx == 7) or \ 32 | (dataset_name == 'sun3d_train_0.4m_to_0.8m' and seq_idx == 15) or \ 33 | (dataset_name == 'scenes11_train' and ( 34 | seq_idx == 2758 or seq_idx == 4691 or seq_idx == 7023 or seq_idx == 11157 or seq_idx == 17168 or seq_idx == 19595))): 35 | continue # Skip error files 36 | 37 | print("Processing sequence {:d}/{:d}".format(seq_idx, len(file))) 38 | dump_dir = os.path.join(path, '../train', dataset_name + "_" + "{:05d}".format(seq_idx)) 39 | if not os.path.isdir(dump_dir): 40 | os.mkdir(dump_dir) 41 | dump_dir = 
Path(dump_dir) 42 | sequence = file[seq_name]["frames"]["t0"] 43 | poses = [] 44 | for (f_idx, f_name) in enumerate(sequence): 45 | frame = sequence[f_name] 46 | for dt_type in frame: 47 | dataset = frame[dt_type] 48 | img = dataset[...] 49 | if dt_type == "camera": 50 | if f_idx == 0: 51 | intrinsics = np.array([[img[0], 0, img[3]], [0, img[1], img[4]], [0, 0, 1]]) 52 | pose = np.array( 53 | [[img[5], img[8], img[11], img[14] * scale], [img[6], img[9], img[12], img[15] * scale], 54 | [img[7], img[10], img[13], img[16] * scale]]) 55 | poses.append(pose.tolist()) 56 | elif dt_type == "depth": 57 | dimension = dataset.attrs["extents"] 58 | depth = np.array(np.frombuffer(decompress(img.tobytes(), dimension[0] * dimension[1] * 2), 59 | dtype=np.float16)).astype(np.float32) 60 | depth = depth.reshape(dimension[0], dimension[1]) * scale 61 | 62 | dump_depth_file = dump_dir / '{:04d}.npy'.format(f_idx) 63 | np.save(dump_depth_file, depth) 64 | elif dt_type == "image": 65 | img = imageio.imread(img.tobytes()) 66 | dump_img_file = dump_dir / '{:04d}.jpg'.format(f_idx) 67 | imageio.imsave(dump_img_file, img) 68 | 69 | dump_cam_file = dump_dir / 'cam.txt' 70 | np.savetxt(dump_cam_file, intrinsics) 71 | poses_file = dump_dir / 'poses.txt' 72 | np.savetxt(poses_file, np.array(poses).reshape(-1, 12), fmt='%.6e') 73 | 74 | if len(dump_dir.files('*.jpg')) < 2: 75 | dump_dir.rmtree() 76 | 77 | 78 | def preparedata(): 79 | num_threads = 1 80 | SUB_DATASET_NAMES = ([ 81 | "rgbd_10_to_20_3d_train", "rgbd_10_to_20_handheld_train", "rgbd_10_to_20_simple_train", 82 | "rgbd_20_to_inf_3d_train", "rgbd_20_to_inf_handheld_train", "rgbd_20_to_inf_simple_train", 83 | "sun3d_train_0.01m_to_0.1m", "sun3d_train_0.1m_to_0.2m", "sun3d_train_0.2m_to_0.4m", "sun3d_train_0.4m_to_0.8m", 84 | "sun3d_train_0.8m_to_1.6m", "sun3d_train_1.6m_to_infm", 85 | "scenes11_train", 86 | ]) 87 | 88 | dump_root = os.path.join(path, 'train') 89 | if not os.path.isdir(dump_root): 90 | os.mkdir(dump_root) 91 | 92 | if num_threads == 1: 93 | for scene in SUB_DATASET_NAMES: 94 | dump_example(scene) 95 | else: 96 | Parallel(n_jobs=num_threads)(delayed(dump_example)(scene) for scene in SUB_DATASET_NAMES) 97 | 98 | np.random.seed(8964) 99 | dump_root = Path(dump_root) 100 | subdirs = dump_root.dirs() 101 | canonic_prefixes = set([subdir.basename()[:-2] for subdir in subdirs]) 102 | with open(dump_root / 'train.txt', 'w') as tf: 103 | with open(dump_root / 'val.txt', 'w') as vf: 104 | for pr in canonic_prefixes: 105 | corresponding_dirs = dump_root.dirs('{}*'.format(pr)) 106 | if np.random.random() < 0.1: 107 | for s in corresponding_dirs: 108 | vf.write('{}\n'.format(s.name)) 109 | else: 110 | for s in corresponding_dirs: 111 | tf.write('{}\n'.format(s.name)) 112 | 113 | print("Finished Converting Data.") 114 | 115 | 116 | if __name__ == "__main__": 117 | preparedata() 118 | -------------------------------------------------------------------------------- /dataloader/flow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/dataloader/flow/__init__.py -------------------------------------------------------------------------------- /dataloader/stereo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/dataloader/stereo/__init__.py 
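As a usage note for the depth dataloader shown above (a hypothetical sketch, not a file from the repository: the crop/resize sizes and the normalization statistics are illustrative assumptions, and ScanNet is assumed to live under `datasets/Scannet`), the dataset and augmentation classes are typically combined like this:

```
from torch.utils.data import DataLoader

from dataloader.depth import augmentation
from dataloader.depth.datasets import ScannetDataset

# Training-time transform pipeline built from the classes in augmentation.py.
# RandomColor operates on numpy images, so it is applied before ToTensor/Normalize.
train_transforms = augmentation.Compose([
    augmentation.RandomResize(min_size=(480, 640)),
    augmentation.RandomCrop(crop_size=(448, 576)),
    augmentation.RandomColor(asymmetric=True),
    augmentation.ToTensor(),
    augmentation.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_set = ScannetDataset(data_dir='datasets/Scannet', mode='train', transforms=train_transforms)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)

for sample in train_loader:
    img_ref, img_tgt = sample['img_ref'], sample['img_tgt']  # [B, 3, H, W], normalized images
    depth, valid = sample['depth'], sample['valid']          # [B, H, W], depth in meters + valid mask
    intrinsics, pose = sample['intrinsics'], sample['pose']  # [B, 3, 3] intrinsics, [B, 4, 4] relative pose
    break
```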
-------------------------------------------------------------------------------- /demo/depth-scannet/color/0048.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/demo/depth-scannet/color/0048.png -------------------------------------------------------------------------------- /demo/depth-scannet/color/0054.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/demo/depth-scannet/color/0054.png -------------------------------------------------------------------------------- /demo/depth-scannet/color/0060.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/demo/depth-scannet/color/0060.png -------------------------------------------------------------------------------- /demo/depth-scannet/color/0066.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/demo/depth-scannet/color/0066.png -------------------------------------------------------------------------------- /demo/depth-scannet/intrinsic/intrinsic_depth.txt: -------------------------------------------------------------------------------- 1 | 577.590698 0.000000 318.905426 0.000000 2 | 0.000000 578.729797 242.683609 0.000000 3 | 0.000000 0.000000 1.000000 0.000000 4 | 0.000000 0.000000 0.000000 1.000000 5 | -------------------------------------------------------------------------------- /demo/depth-scannet/pose/0048.txt: -------------------------------------------------------------------------------- 1 | 0.703694 -0.367391 0.608144 1.896290 2 | -0.708482 -0.427345 0.561630 2.467417 3 | 0.053549 -0.826075 -0.561010 1.399475 4 | 0.000000 0.000000 0.000000 1.000000 5 | -------------------------------------------------------------------------------- /demo/depth-scannet/pose/0054.txt: -------------------------------------------------------------------------------- 1 | 0.750884 -0.329503 0.572363 1.915689 2 | -0.658300 -0.443024 0.608580 2.368469 3 | 0.053042 -0.833760 -0.549572 1.413484 4 | 0.000000 0.000000 0.000000 1.000000 5 | -------------------------------------------------------------------------------- /demo/depth-scannet/pose/0060.txt: -------------------------------------------------------------------------------- 1 | 0.776779 -0.282017 0.563098 1.923212 2 | -0.625761 -0.446388 0.639656 2.259246 3 | 0.070966 -0.849237 -0.523221 1.407526 4 | 0.000000 0.000000 0.000000 1.000000 5 | -------------------------------------------------------------------------------- /demo/depth-scannet/pose/0066.txt: -------------------------------------------------------------------------------- 1 | 0.794580 -0.275900 0.540852 1.915812 2 | -0.604505 -0.442681 0.662273 2.192295 3 | 0.056703 -0.853178 -0.518529 1.423975 4 | 0.000000 0.000000 0.000000 1.000000 5 | -------------------------------------------------------------------------------- /demo/flow-davis/00000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/demo/flow-davis/00000.jpg 
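The demo/depth-scannet intrinsic and pose files listed above are plain whitespace-separated 4x4 matrices. A hedged sketch of parsing them with NumPy follows; treating the poses as camera-to-world transforms is an assumption based on the usual ScanNet export convention, not something main_depth.py is shown doing here.

import numpy as np

# 4x4 intrinsic matrix; the top-left 3x3 block holds fx, fy, cx, cy
K = np.loadtxt('demo/depth-scannet/intrinsic/intrinsic_depth.txt')[:3, :3]

# per-frame poses, assumed camera-to-world as in standard ScanNet exports
T_ref = np.loadtxt('demo/depth-scannet/pose/0048.txt')  # [4, 4]
T_src = np.loadtxt('demo/depth-scannet/pose/0054.txt')  # [4, 4]

# relative transform mapping the source camera into the reference camera frame
# (only valid under the camera-to-world assumption above)
T_src_to_ref = np.linalg.inv(T_ref) @ T_src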
-------------------------------------------------------------------------------- /demo/flow-davis/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/demo/flow-davis/00001.jpg -------------------------------------------------------------------------------- /demo/flow-davis/00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/demo/flow-davis/00002.jpg -------------------------------------------------------------------------------- /demo/kitti.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/demo/kitti.mp4 -------------------------------------------------------------------------------- /demo/stereo-middlebury/im0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/demo/stereo-middlebury/im0.png -------------------------------------------------------------------------------- /demo/stereo-middlebury/im1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/demo/stereo-middlebury/im1.png -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/loss/__init__.py -------------------------------------------------------------------------------- /loss/depth_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def compute_errors(gt, pred): 7 | """Computation of error metrics between predicted and ground truth depths 8 | """ 9 | thresh = np.maximum((gt / pred), (pred / gt)) 10 | a1 = (thresh < 1.25).mean() 11 | a2 = (thresh < 1.25 ** 2).mean() 12 | a3 = (thresh < 1.25 ** 3).mean() 13 | 14 | rmse = (gt - pred) ** 2 15 | rmse = np.sqrt(rmse.mean()) 16 | 17 | rmse_log = (np.log(gt) - np.log(pred)) ** 2 18 | rmse_log = np.sqrt(rmse_log.mean()) 19 | 20 | abs_rel = np.mean(np.abs(gt - pred) / gt) 21 | 22 | sq_rel = np.mean(((gt - pred) ** 2) / gt) 23 | 24 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 25 | 26 | 27 | def get_depth_grad_loss(depth_pred, depth_gt, valid, inverse_depth_loss=True): 28 | # default is on inverse depth 29 | # both: [B, H, W] 30 | assert depth_pred.dim() == 3 and depth_gt.dim() == 3 and valid.dim() == 3 31 | 32 | valid = valid > 0.5 33 | valid_x = valid[:, :, :-1] & valid[:, :, 1:] 34 | valid_y = valid[:, :-1, :] & valid[:, 1:, :] 35 | 36 | if valid_x.max() < 0.5 or valid_y.max() < 0.5: # no valid pixel 37 | return 0. 38 | 39 | if inverse_depth_loss: 40 | grad_pred_x = torch.abs(1. / depth_pred[:, :, :-1][valid_x] - 1. / depth_pred[:, :, 1:][valid_x]) 41 | grad_pred_y = torch.abs(1. / depth_pred[:, :-1, :][valid_y] - 1. / depth_pred[:, 1:, :][valid_y]) 42 | 43 | grad_gt_x = torch.abs(1. / depth_gt[:, :, :-1][valid_x] - 1. 
/ depth_gt[:, :, 1:][valid_x]) 44 | grad_gt_y = torch.abs(1. / depth_gt[:, :-1, :][valid_y] - 1. / depth_gt[:, 1:, :][valid_y]) 45 | else: 46 | grad_pred_x = torch.abs((depth_pred[:, :, :-1] - depth_pred[:, :, 1:])[valid_x]) 47 | grad_pred_y = torch.abs((depth_pred[:, :-1, :] - depth_pred[:, 1:, :])[valid_y]) 48 | 49 | grad_gt_x = torch.abs((depth_gt[:, :, :-1] - depth_gt[:, :, 1:])[valid_x]) 50 | grad_gt_y = torch.abs((depth_gt[:, :-1, :] - depth_gt[:, 1:, :])[valid_y]) 51 | 52 | loss_grad_x = torch.abs(grad_pred_x - grad_gt_x).mean() 53 | loss_grad_y = torch.abs(grad_pred_y - grad_gt_y).mean() 54 | 55 | return loss_grad_x + loss_grad_y 56 | 57 | 58 | def depth_grad_loss_func(depth_preds, depth_gt, valid, 59 | inverse_depth_loss=True, 60 | gamma=0.9): 61 | num = len(depth_preds) 62 | loss = 0. 63 | 64 | for i in range(num): 65 | weight = gamma ** (num - i - 1) 66 | loss += weight * get_depth_grad_loss(depth_preds[i], depth_gt, valid, 67 | inverse_depth_loss=inverse_depth_loss) 68 | 69 | return loss 70 | 71 | 72 | def depth_loss_func(depth_preds, depth_gt, valid, gamma=0.9, 73 | ): 74 | """ loss function defined over multiple depth predictions """ 75 | 76 | n_predictions = len(depth_preds) 77 | depth_loss = 0.0 78 | 79 | for i in range(n_predictions): 80 | i_weight = gamma ** (n_predictions - i - 1) 81 | 82 | # inverse depth loss 83 | valid_bool = valid > 0.5 84 | if valid_bool.max() < 0.5: # no valid pixel 85 | i_loss = 0. 86 | else: 87 | i_loss = (1. / depth_preds[i][valid_bool] - 1. / depth_gt[valid_bool]).abs().mean() 88 | 89 | depth_loss += i_weight * i_loss 90 | 91 | return depth_loss 92 | -------------------------------------------------------------------------------- /loss/flow_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def flow_loss_func(flow_preds, flow_gt, valid, 5 | gamma=0.9, 6 | max_flow=400, 7 | **kwargs, 8 | ): 9 | n_predictions = len(flow_preds) 10 | flow_loss = 0.0 11 | 12 | # exclude invalid pixels and extremely large displacements 13 | mag = torch.sum(flow_gt ** 2, dim=1).sqrt() # [B, H, W] 14 | valid = (valid >= 0.5) & (mag < max_flow) 15 | 16 | for i in range(n_predictions): 17 | i_weight = gamma ** (n_predictions - i - 1) 18 | 19 | i_loss = (flow_preds[i] - flow_gt).abs() 20 | 21 | flow_loss += i_weight * (valid[:, None] * i_loss).mean() 22 | 23 | epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt() 24 | 25 | if valid.max() < 0.5: 26 | pass 27 | 28 | epe = epe.view(-1)[valid.view(-1)] 29 | 30 | metrics = { 31 | 'epe': epe.mean().item(), 32 | '1px': (epe > 1).float().mean().item(), 33 | '3px': (epe > 3).float().mean().item(), 34 | '5px': (epe > 5).float().mean().item(), 35 | } 36 | 37 | return flow_loss, metrics 38 | -------------------------------------------------------------------------------- /loss/stereo_metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def epe_metric(d_est, d_gt, mask, use_np=False): 6 | d_est, d_gt = d_est[mask], d_gt[mask] 7 | if use_np: 8 | epe = np.mean(np.abs(d_est - d_gt)) 9 | else: 10 | epe = torch.mean(torch.abs(d_est - d_gt)) 11 | 12 | return epe 13 | 14 | 15 | def d1_metric(d_est, d_gt, mask, use_np=False): 16 | d_est, d_gt = d_est[mask], d_gt[mask] 17 | if use_np: 18 | e = np.abs(d_gt - d_est) 19 | else: 20 | e = torch.abs(d_gt - d_est) 21 | err_mask = (e > 3) & (e / d_gt > 0.05) 22 | 23 | if use_np: 24 | mean = np.mean(err_mask.astype('float')) 25 | else: 26
| mean = torch.mean(err_mask.float()) 27 | 28 | return mean 29 | 30 | 31 | def bad_pixel_metric(d_est, d_gt, mask, 32 | abs_threshold=10, 33 | rel_threshold=0.1, 34 | use_np=False): 35 | d_est, d_gt = d_est[mask], d_gt[mask] 36 | if use_np: 37 | e = np.abs(d_gt - d_est) 38 | else: 39 | e = torch.abs(d_gt - d_est) 40 | 41 | err_mask = (e > abs_threshold) & (e / torch.maximum(d_gt, torch.ones_like(d_gt)) > rel_threshold) 42 | 43 | if use_np: 44 | mean = np.mean(err_mask.astype('float')) 45 | else: 46 | mean = torch.mean(err_mask.float()) 47 | 48 | return mean 49 | 50 | 51 | def thres_metric(d_est, d_gt, mask, thres, use_np=False): 52 | assert isinstance(thres, (int, float)) 53 | d_est, d_gt = d_est[mask], d_gt[mask] 54 | if use_np: 55 | e = np.abs(d_gt - d_est) 56 | else: 57 | e = torch.abs(d_gt - d_est) 58 | err_mask = e > thres 59 | 60 | if use_np: 61 | mean = np.mean(err_mask.astype('float')) 62 | else: 63 | mean = torch.mean(err_mask.float()) 64 | 65 | return mean 66 | -------------------------------------------------------------------------------- /pip_install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 5 | 6 | pip install imageio==2.9.0 imageio-ffmpeg matplotlib opencv-python pillow scikit-image scipy tensorboard==2.9.1 setuptools==59.5.0 7 | -------------------------------------------------------------------------------- /scripts/depthsplat_depth_demo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # depthsplat-depth-small 5 | CUDA_VISIBLE_DEVICES=0 python main_depth.py \ 6 | --inference_dir demo/depth-scannet \ 7 | --output_path output/depthsplat-depth-small \ 8 | --resume pretrained/depthsplat-depth-small-3d79dd5e.pth \ 9 | --depthsplat_depth 10 | 11 | # predict depth for both images 12 | # --pred_bidir_depth 13 | 14 | 15 | 16 | # depthsplat-depth-base 17 | CUDA_VISIBLE_DEVICES=0 python main_depth.py \ 18 | --inference_dir demo/depth-scannet \ 19 | --output_path output/depthsplat-depth-base \ 20 | --resume pretrained/depthsplat-depth-base-f57113bd.pth \ 21 | --depthsplat_depth \ 22 | --vit_type vitb \ 23 | --num_scales 2 \ 24 | --upsample_factor 4 25 | 26 | 27 | 28 | # depthsplat-depth-large 29 | CUDA_VISIBLE_DEVICES=0 python main_depth.py \ 30 | --inference_dir demo/depth-scannet \ 31 | --output_path output/depthsplat-depth-large \ 32 | --resume pretrained/depthsplat-depth-large-50d3d7cf.pth \ 33 | --depthsplat_depth \ 34 | --vit_type vitl \ 35 | --num_scales 2 \ 36 | --upsample_factor 4 37 | 38 | -------------------------------------------------------------------------------- /scripts/gmdepth_demo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # gmdepth-scale1-regrefine1 5 | CUDA_VISIBLE_DEVICES=0 python main_depth.py \ 6 | --inference_dir demo/depth-scannet \ 7 | --output_path output/gmdepth-scale1-regrefine1-scannet \ 8 | --resume pretrained/gmdepth-scale1-regrefine1-resumeflowthings-scannet-90325722.pth \ 9 | --reg_refine \ 10 | --num_reg_refine 1 11 | 12 | # --pred_bidir_depth 13 | 14 | -------------------------------------------------------------------------------- /scripts/gmdepth_evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # gmdepth-scale1 5 | 
CUDA_VISIBLE_DEVICES=0 python main_depth.py \ 6 | --eval \ 7 | --resume pretrained/gmdepth-scale1-resumeflowthings-demon-a2fe127b.pth \ 8 | --val_dataset demon \ 9 | --demon_split scenes11 10 | 11 | 12 | # gmdepth-scale1-regrefine1, this is our final model 13 | CUDA_VISIBLE_DEVICES=0 python main_depth.py \ 14 | --eval \ 15 | --resume pretrained/gmdepth-scale1-regrefine1-resumeflowthings-demon-7c23f230.pth \ 16 | --val_dataset demon \ 17 | --demon_split scenes11 \ 18 | --reg_refine \ 19 | --num_reg_refine 1 20 | 21 | -------------------------------------------------------------------------------- /scripts/gmdepth_scale1_regrefine1_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # GMDepth (1/8 feature only), with additional 1 local regression refinement at 1/8 resolution 4 | 5 | # number of gpus for training, please set according to your hardware 6 | # trained on 8x 40GB A100 gpus 7 | NUM_GPUS=8 8 | 9 | 10 | # scannet 11 | CHECKPOINT_DIR=checkpoints_depth/scannet-gmdepth-scale1-regrefine1-resumeflowthings && \ 12 | mkdir -p ${CHECKPOINT_DIR} && \ 13 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_depth.py \ 14 | --launcher pytorch \ 15 | --checkpoint_dir ${CHECKPOINT_DIR} \ 16 | --resume pretrained/gmdepth-scale1-resumeflowthings-scannet-5d9d7964.pth \ 17 | --no_resume_optimizer \ 18 | --dataset scannet \ 19 | --val_dataset scannet \ 20 | --image_size 480 640 \ 21 | --batch_size 64 \ 22 | --lr 4e-4 \ 23 | --reg_refine \ 24 | --num_reg_refine 1 \ 25 | --summary_freq 100 \ 26 | --val_freq 5000 \ 27 | --save_ckpt_freq 5000 \ 28 | --num_steps 100000 \ 29 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 30 | 31 | 32 | # demon 33 | CHECKPOINT_DIR=checkpoints_depth/demon-gmdepth-scale1-regrefine1-resumeflowthings && \ 34 | mkdir -p ${CHECKPOINT_DIR} && \ 35 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_depth.py \ 36 | --launcher pytorch \ 37 | --checkpoint_dir ${CHECKPOINT_DIR} \ 38 | --resume pretrained/gmdepth-scale1-resumeflowthings-demon-a2fe127b.pth \ 39 | --no_resume_optimizer \ 40 | --dataset demon \ 41 | --val_dataset demon \ 42 | --demon_split rgbd \ 43 | --image_size 448 576 \ 44 | --batch_size 64 \ 45 | --lr 4e-4 \ 46 | --reg_refine \ 47 | --num_reg_refine 1 \ 48 | --summary_freq 100 \ 49 | --val_freq 5000 \ 50 | --save_ckpt_freq 5000 \ 51 | --num_steps 100000 \ 52 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 53 | 54 | 55 | -------------------------------------------------------------------------------- /scripts/gmdepth_scale1_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # basic GMDepth without any refinement (1/8 feature only) 4 | 5 | # number of gpus for training, please set according to your hardware 6 | # trained on 8x 40GB A100 gpus 7 | NUM_GPUS=8 8 | 9 | 10 | # scannet (our final model is trained for 100K steps, for ablation, we train for 50K) 11 | # resume flow things model (our ablations are trained from random init) 12 | CHECKPOINT_DIR=checkpoints_depth/scannet-gmdepth-scale1-resumeflowthings && \ 13 | mkdir -p ${CHECKPOINT_DIR} && \ 14 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_depth.py \ 15 | --launcher pytorch \ 16 | --checkpoint_dir ${CHECKPOINT_DIR} \ 17 | --resume pretrained/gmflow-scale1-things-e9887eda.pth \ 18 | --no_resume_optimizer \ 19 | --dataset scannet \ 20 | --val_dataset scannet \ 21 | 
--image_size 480 640 \ 22 | --batch_size 80 \ 23 | --lr 4e-4 \ 24 | --summary_freq 100 \ 25 | --val_freq 5000 \ 26 | --save_ckpt_freq 5000 \ 27 | --num_steps 100000 \ 28 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 29 | 30 | 31 | # demon, resume flow things model 32 | CHECKPOINT_DIR=checkpoints_depth/demon-gmdepth-scale1-resumeflowthings && \ 33 | mkdir -p ${CHECKPOINT_DIR} && \ 34 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_depth.py \ 35 | --launcher pytorch \ 36 | --checkpoint_dir ${CHECKPOINT_DIR} \ 37 | --resume pretrained/gmflow-scale1-things-e9887eda.pth \ 38 | --no_resume_optimizer \ 39 | --dataset demon \ 40 | --val_dataset demon \ 41 | --demon_split rgbd \ 42 | --image_size 448 576 \ 43 | --batch_size 80 \ 44 | --lr 4e-4 \ 45 | --summary_freq 100 \ 46 | --val_freq 5000 \ 47 | --save_ckpt_freq 5000 \ 48 | --num_steps 100000 \ 49 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 50 | 51 | 52 | -------------------------------------------------------------------------------- /scripts/gmflow_demo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # gmflow-scale2-regrefine6, inference on image dir 5 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 6 | --inference_dir demo/flow-davis \ 7 | --resume pretrained/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth \ 8 | --output_path output/gmflow-scale2-regrefine6-davis \ 9 | --padding_factor 32 \ 10 | --upsample_factor 4 \ 11 | --num_scales 2 \ 12 | --attn_splits_list 2 8 \ 13 | --corr_radius_list -1 4 \ 14 | --prop_radius_list -1 1 \ 15 | --reg_refine \ 16 | --num_reg_refine 6 17 | 18 | 19 | # gmflow-scale2-regrefine6, inference on video, save as video 20 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 21 | --inference_video demo/kitti.mp4 \ 22 | --resume pretrained/gmflow-scale2-regrefine6-kitti15-25b554d7.pth \ 23 | --output_path output/kitti \ 24 | --padding_factor 32 \ 25 | --upsample_factor 4 \ 26 | --num_scales 2 \ 27 | --attn_splits_list 2 8 \ 28 | --corr_radius_list -1 4 \ 29 | --prop_radius_list -1 1 \ 30 | --reg_refine \ 31 | --num_reg_refine 6 \ 32 | --save_video \ 33 | --concat_flow_img 34 | 35 | 36 | 37 | # gmflow-scale1, inference on image dir 38 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 39 | --inference_dir demo/flow-davis \ 40 | --resume pretrained/gmflow-scale1-mixdata-train320x576-4c3a6e9a.pth \ 41 | --output_path output/gmflow-scale1-davis 42 | 43 | # optionally predict bidirectional flow and run forward-backward consistency check 44 | #--pred_bidir_flow 45 | #--fwd_bwd_check 46 | 47 | 48 | # gmflow-scale2, inference on image dir 49 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 50 | --inference_dir demo/flow-davis \ 51 | --resume pretrained/gmflow-scale2-mixdata-train320x576-9ff1c094.pth \ 52 | --output_path output/gmflow-scale2-davis \ 53 | --padding_factor 32 \ 54 | --upsample_factor 4 \ 55 | --num_scales 2 \ 56 | --attn_splits_list 2 8 \ 57 | --corr_radius_list -1 4 \ 58 | --prop_radius_list -1 1 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /scripts/gmflow_evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # gmflow-scale1 5 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 6 | --eval \ 7 | --resume pretrained/gmflow-scale1-things-e9887eda.pth \ 8 | --val_dataset sintel \ 9 | --with_speed_metric 10 | 11 | 12 | # gmflow-scale2 13 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 14 |
--eval \ 15 | --resume pretrained/gmflow-scale2-things-36579974.pth \ 16 | --val_dataset kitti \ 17 | --padding_factor 32 \ 18 | --upsample_factor 4 \ 19 | --num_scales 2 \ 20 | --attn_splits_list 2 8 \ 21 | --corr_radius_list -1 4 \ 22 | --prop_radius_list -1 1 \ 23 | --with_speed_metric 24 | 25 | 26 | # gmflow-scale2-regrefine6 27 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 28 | --eval \ 29 | --resume pretrained/gmflow-scale2-regrefine6-things-776ed612.pth \ 30 | --val_dataset kitti \ 31 | --padding_factor 32 \ 32 | --upsample_factor 4 \ 33 | --num_scales 2 \ 34 | --attn_splits_list 2 8 \ 35 | --corr_radius_list -1 4 \ 36 | --prop_radius_list -1 1 \ 37 | --reg_refine \ 38 | --num_reg_refine 6 \ 39 | --with_speed_metric 40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /scripts/gmflow_scale1_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # basic GMFlow without any refinement (1/8 feature only) 4 | 5 | # number of gpus for training, please set according to your hardware 6 | # can be trained on 4x 16GB V100 or 2x 32GB V100 or 2x 40GB A100 gpus 7 | NUM_GPUS=4 8 | 9 | # chairs 10 | CHECKPOINT_DIR=checkpoints_flow/chairs-gmflow-scale1 && \ 11 | mkdir -p ${CHECKPOINT_DIR} && \ 12 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 13 | --launcher pytorch \ 14 | --checkpoint_dir ${CHECKPOINT_DIR} \ 15 | --stage chairs \ 16 | --batch_size 16 \ 17 | --val_dataset chairs sintel kitti \ 18 | --lr 4e-4 \ 19 | --image_size 384 512 \ 20 | --padding_factor 16 \ 21 | --upsample_factor 8 \ 22 | --with_speed_metric \ 23 | --val_freq 10000 \ 24 | --save_ckpt_freq 10000 \ 25 | --num_steps 100000 \ 26 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 27 | 28 | # things (our final model is trained for 800K iterations, for ablation study, you can train for 200K) 29 | CHECKPOINT_DIR=checkpoints_flow/things-gmflow-scale1 && \ 30 | mkdir -p ${CHECKPOINT_DIR} && \ 31 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 32 | --launcher pytorch \ 33 | --checkpoint_dir ${CHECKPOINT_DIR} \ 34 | --resume checkpoints_flow/chairs-gmflow-scale1/step_100000.pth \ 35 | --stage things \ 36 | --batch_size 8 \ 37 | --val_dataset things sintel kitti \ 38 | --lr 2e-4 \ 39 | --image_size 384 768 \ 40 | --padding_factor 16 \ 41 | --upsample_factor 8 \ 42 | --with_speed_metric \ 43 | --val_freq 40000 \ 44 | --save_ckpt_freq 50000 \ 45 | --num_steps 800000 \ 46 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 47 | 48 | # a final note: if your training is terminated unexpectedly, you can resume from the latest checkpoint 49 | # an example: resume chairs training 50 | # CHECKPOINT_DIR=checkpoints_flow/chairs-gmflow-scale1 && \ 51 | # mkdir -p ${CHECKPOINT_DIR} && \ 52 | # python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 53 | # --launcher pytorch \ 54 | # --checkpoint_dir ${CHECKPOINT_DIR} \ 55 | # --resume checkpoints_flow/chairs-gmflow-scale1/checkpoint_latest.pth \ 56 | # --stage chairs \ 57 | # --batch_size 16 \ 58 | # --val_dataset chairs sintel kitti \ 59 | # --lr 4e-4 \ 60 | # --image_size 384 512 \ 61 | # --padding_factor 16 \ 62 | # --upsample_factor 8 \ 63 | # --with_speed_metric \ 64 | # --val_freq 10000 \ 65 | # --save_ckpt_freq 10000 \ 66 | # --num_steps 100000 \ 67 | # 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 68 | 69 | 70 | 
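# Single-GPU smoke test for the chairs stage (a sketch, not an official script from this repo):
# it assumes main_flow.py runs in ordinary single-process mode when launched directly without
# --launcher / torch.distributed.launch, as the demo and evaluation commands suggest; batch size,
# validation frequency and step count are shrunk only to verify data loading and logging
# before committing to the full multi-gpu run above.
# CHECKPOINT_DIR=checkpoints_flow/chairs-gmflow-scale1-debug && \
# mkdir -p ${CHECKPOINT_DIR} && \
# CUDA_VISIBLE_DEVICES=0 python main_flow.py \
# --checkpoint_dir ${CHECKPOINT_DIR} \
# --stage chairs \
# --batch_size 4 \
# --val_dataset chairs \
# --lr 4e-4 \
# --image_size 384 512 \
# --padding_factor 16 \
# --upsample_factor 8 \
# --with_speed_metric \
# --val_freq 1000 \
# --save_ckpt_freq 1000 \
# --num_steps 2000 \
# 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log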
-------------------------------------------------------------------------------- /scripts/gmflow_scale2_regrefine6_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # GMFlow with hierarchical matching refinement (1/8 + 1/4 features) 4 | # with additional 6 local regression refinements 5 | 6 | # number of gpus for training, please set according to your hardware 7 | # can be trained on 8x 32G V100 or 8x 40GB A100 gpus 8 | NUM_GPUS=8 9 | 10 | # chairs, resume from scale2 model 11 | CHECKPOINT_DIR=checkpoints_flow/chairs-gmflow-scale2-regrefine6 && \ 12 | mkdir -p ${CHECKPOINT_DIR} && \ 13 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 14 | --launcher pytorch \ 15 | --checkpoint_dir ${CHECKPOINT_DIR} \ 16 | --resume pretrained/gmflow-scale2-chairs-020cc9be.pth \ 17 | --no_resume_optimizer \ 18 | --stage chairs \ 19 | --batch_size 16 \ 20 | --val_dataset chairs sintel kitti \ 21 | --lr 4e-4 \ 22 | --image_size 384 512 \ 23 | --padding_factor 32 \ 24 | --upsample_factor 4 \ 25 | --num_scales 2 \ 26 | --attn_splits_list 2 8 \ 27 | --corr_radius_list -1 4 \ 28 | --prop_radius_list -1 1 \ 29 | --reg_refine \ 30 | --num_reg_refine 6 \ 31 | --with_speed_metric \ 32 | --val_freq 10000 \ 33 | --save_ckpt_freq 10000 \ 34 | --num_steps 100000 \ 35 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 36 | 37 | # things 38 | CHECKPOINT_DIR=checkpoints_flow/things-gmflow-scale2-regrefine6 && \ 39 | mkdir -p ${CHECKPOINT_DIR} && \ 40 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 41 | --launcher pytorch \ 42 | --checkpoint_dir ${CHECKPOINT_DIR} \ 43 | --resume checkpoints_flow/chairs-gmflow-scale2-regrefine6/step_100000.pth \ 44 | --stage things \ 45 | --batch_size 8 \ 46 | --val_dataset things sintel kitti \ 47 | --lr 2e-4 \ 48 | --image_size 384 768 \ 49 | --padding_factor 32 \ 50 | --upsample_factor 4 \ 51 | --num_scales 2 \ 52 | --attn_splits_list 2 8 \ 53 | --corr_radius_list -1 4 \ 54 | --prop_radius_list -1 1 \ 55 | --reg_refine \ 56 | --num_reg_refine 6 \ 57 | --with_speed_metric \ 58 | --val_freq 40000 \ 59 | --save_ckpt_freq 50000 \ 60 | --num_steps 800000 \ 61 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 62 | 63 | # sintel, resume from things model 64 | CHECKPOINT_DIR=checkpoints_flow/sintel-gmflow-scale2-regrefine6 && \ 65 | mkdir -p ${CHECKPOINT_DIR} && \ 66 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 67 | --launcher pytorch \ 68 | --checkpoint_dir ${CHECKPOINT_DIR} \ 69 | --resume checkpoints_flow/things-gmflow-scale2-regrefine6/step_800000.pth \ 70 | --stage sintel \ 71 | --batch_size 8 \ 72 | --val_dataset sintel kitti \ 73 | --lr 2e-4 \ 74 | --image_size 320 896 \ 75 | --padding_factor 32 \ 76 | --upsample_factor 4 \ 77 | --num_scales 2 \ 78 | --attn_splits_list 2 8 \ 79 | --corr_radius_list -1 4 \ 80 | --prop_radius_list -1 1 \ 81 | --reg_refine \ 82 | --num_reg_refine 6 \ 83 | --with_speed_metric \ 84 | --val_freq 20000 \ 85 | --save_ckpt_freq 20000 \ 86 | --num_steps 200000 \ 87 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 88 | 89 | 90 | # sintel finetune, resume from sintel model, this is our final model for sintel benchmark submission 91 | CHECKPOINT_DIR=checkpoints_flow/sintel-gmflow-scale2-regrefine6-ft && \ 92 | mkdir -p ${CHECKPOINT_DIR} && \ 93 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 94 | --launcher 
pytorch \ 95 | --checkpoint_dir ${CHECKPOINT_DIR} \ 96 | --resume checkpoints_flow/sintel-gmflow-scale2-regrefine6/step_200000.pth \ 97 | --stage sintel_ft \ 98 | --batch_size 8 \ 99 | --val_dataset sintel \ 100 | --lr 1e-4 \ 101 | --image_size 416 1024 \ 102 | --padding_factor 32 \ 103 | --upsample_factor 4 \ 104 | --num_scales 2 \ 105 | --attn_splits_list 2 8 \ 106 | --corr_radius_list -1 4 \ 107 | --prop_radius_list -1 1 \ 108 | --reg_refine \ 109 | --num_reg_refine 6 \ 110 | --with_speed_metric \ 111 | --val_freq 1000 \ 112 | --save_ckpt_freq 1000 \ 113 | --num_steps 5000 \ 114 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 115 | 116 | 117 | # vkitti2, resume from things model 118 | CHECKPOINT_DIR=checkpoints_flow/vkitti2-gmflow-scale2-regrefine6 && \ 119 | mkdir -p ${CHECKPOINT_DIR} && \ 120 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 121 | --launcher pytorch \ 122 | --checkpoint_dir ${CHECKPOINT_DIR} \ 123 | --resume checkpoints_flow/things-gmflow-scale2-regrefine6/step_800000.pth \ 124 | --stage vkitti2 \ 125 | --batch_size 16 \ 126 | --val_dataset kitti \ 127 | --lr 2e-4 \ 128 | --image_size 320 832 \ 129 | --padding_factor 32 \ 130 | --upsample_factor 4 \ 131 | --num_scales 2 \ 132 | --attn_splits_list 2 8 \ 133 | --corr_radius_list -1 4 \ 134 | --prop_radius_list -1 1 \ 135 | --reg_refine \ 136 | --num_reg_refine 6 \ 137 | --with_speed_metric \ 138 | --val_freq 10000 \ 139 | --save_ckpt_freq 10000 \ 140 | --num_steps 40000 \ 141 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 142 | 143 | 144 | # kitti, resume from vkitti2 model, this is our final model for kitti benchmark submission 145 | CHECKPOINT_DIR=checkpoints_flow/kitti-gmflow-scale2-regrefine6 && \ 146 | mkdir -p ${CHECKPOINT_DIR} && \ 147 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 148 | --launcher pytorch \ 149 | --checkpoint_dir ${CHECKPOINT_DIR} \ 150 | --resume checkpoints_flow/vkitti2-gmflow-scale2-regrefine6/step_040000.pth \ 151 | --stage kitti_mix \ 152 | --batch_size 8 \ 153 | --val_dataset kitti \ 154 | --lr 2e-4 \ 155 | --image_size 352 1216 \ 156 | --padding_factor 32 \ 157 | --upsample_factor 4 \ 158 | --num_scales 2 \ 159 | --attn_splits_list 2 8 \ 160 | --corr_radius_list -1 4 \ 161 | --prop_radius_list -1 1 \ 162 | --reg_refine \ 163 | --num_reg_refine 6 \ 164 | --with_speed_metric \ 165 | --val_freq 5000 \ 166 | --save_ckpt_freq 10000 \ 167 | --num_steps 30000 \ 168 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 169 | 170 | 171 | -------------------------------------------------------------------------------- /scripts/gmflow_scale2_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # GMFlow with hierarchical matching refinement (1/8 + 1/4 features) 4 | 5 | # number of gpus for training, please set according to your hardware 6 | # can be trained on 4x 32G V100 or 4x 40GB A100 or 8x 16G V100 gpus 7 | NUM_GPUS=4 8 | 9 | # chairs 10 | CHECKPOINT_DIR=checkpoints_flow/chairs-gmflow-scale2 && \ 11 | mkdir -p ${CHECKPOINT_DIR} && \ 12 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 13 | --launcher pytorch \ 14 | --checkpoint_dir ${CHECKPOINT_DIR} \ 15 | --stage chairs \ 16 | --batch_size 16 \ 17 | --val_dataset chairs sintel kitti \ 18 | --lr 4e-4 \ 19 | --image_size 384 512 \ 20 | --padding_factor 32 \ 21 | --upsample_factor 4 \ 22 | --num_scales 2 \ 23 | --attn_splits_list 2 8 \ 24 | 
--corr_radius_list -1 4 \ 25 | --prop_radius_list -1 1 \ 26 | --with_speed_metric \ 27 | --val_freq 10000 \ 28 | --save_ckpt_freq 10000 \ 29 | --num_steps 100000 \ 30 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 31 | 32 | # things (our final model is trained for 800K iterations, for ablation study, you can train for 200K) 33 | CHECKPOINT_DIR=checkpoints_flow/things-gmflow-scale2 && \ 34 | mkdir -p ${CHECKPOINT_DIR} && \ 35 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 36 | --launcher pytorch \ 37 | --checkpoint_dir ${CHECKPOINT_DIR} \ 38 | --resume checkpoints_flow/chairs-gmflow-scale2/step_100000.pth \ 39 | --stage things \ 40 | --batch_size 8 \ 41 | --val_dataset things sintel kitti \ 42 | --lr 2e-4 \ 43 | --image_size 384 768 \ 44 | --padding_factor 32 \ 45 | --upsample_factor 4 \ 46 | --num_scales 2 \ 47 | --attn_splits_list 2 8 \ 48 | --corr_radius_list -1 4 \ 49 | --prop_radius_list -1 1 \ 50 | --with_speed_metric \ 51 | --val_freq 40000 \ 52 | --save_ckpt_freq 50000 \ 53 | --num_steps 800000 \ 54 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 55 | 56 | # sintel 57 | CHECKPOINT_DIR=checkpoints_flow/sintel-gmflow-scale2 && \ 58 | mkdir -p ${CHECKPOINT_DIR} && \ 59 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 60 | --launcher pytorch \ 61 | --checkpoint_dir ${CHECKPOINT_DIR} \ 62 | --resume checkpoints_flow/things-gmflow-scale2/step_800000.pth \ 63 | --stage sintel \ 64 | --batch_size 8 \ 65 | --val_dataset sintel kitti \ 66 | --lr 2e-4 \ 67 | --image_size 320 896 \ 68 | --padding_factor 32 \ 69 | --upsample_factor 4 \ 70 | --num_scales 2 \ 71 | --attn_splits_list 2 8 \ 72 | --corr_radius_list -1 4 \ 73 | --prop_radius_list -1 1 \ 74 | --with_speed_metric \ 75 | --val_freq 20000 \ 76 | --save_ckpt_freq 20000 \ 77 | --num_steps 200000 \ 78 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 79 | 80 | # kitti 81 | CHECKPOINT_DIR=checkpoints_flow/kitti-gmflow-scale2 && \ 82 | mkdir -p ${CHECKPOINT_DIR} && \ 83 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 84 | --launcher pytorch \ 85 | --checkpoint_dir ${CHECKPOINT_DIR} \ 86 | --resume checkpoints_flow/sintel-gmflow-scale2/step_200000.pth \ 87 | --stage kitti \ 88 | --batch_size 8 \ 89 | --val_dataset kitti \ 90 | --lr 2e-4 \ 91 | --image_size 320 1152 \ 92 | --padding_factor 32 \ 93 | --upsample_factor 4 \ 94 | --num_scales 2 \ 95 | --attn_splits_list 2 8 \ 96 | --corr_radius_list -1 4 \ 97 | --prop_radius_list -1 1 \ 98 | --with_speed_metric \ 99 | --val_freq 10000 \ 100 | --save_ckpt_freq 10000 \ 101 | --num_steps 100000 \ 102 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 103 | 104 | -------------------------------------------------------------------------------- /scripts/gmflow_submission.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # generate prediction results for submission on sintel and kitti online servers 5 | 6 | 7 | # submission to sintel 8 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 9 | --submission \ 10 | --output_path submission/sintel-gmflow-scale2-regrefine6-sintelft \ 11 | --val_dataset sintel \ 12 | --resume pretrained/gmflow-scale2-regrefine6-sintelft-6e39e2b9.pth \ 13 | --inference_size 416 1024 \ 14 | --padding_factor 32 \ 15 | --upsample_factor 4 \ 16 | --num_scales 2 \ 17 | --attn_splits_list 2 8 \ 18 | --corr_radius_list -1 4 \ 19 | --prop_radius_list -1 1 \ 20 | --reg_refine \ 21 | 
--num_reg_refine 6 22 | 23 | 24 | # you can also visualize the predictions before submission 25 | #CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 26 | #--submission \ 27 | #--output_path submission/sintel-gmflow-scale2-regrefine6-sintelft-vis \ 28 | #--val_dataset sintel \ 29 | #--resume pretrained/gmflow-scale2-regrefine6-sintelft-6e39e2b9.pth \ 30 | #--inference_size 416 1024 \ 31 | #--save_vis_flow \ 32 | #--no_save_flo \ 33 | #--padding_factor 32 \ 34 | #--upsample_factor 4 \ 35 | #--num_scales 2 \ 36 | #--attn_splits_list 2 8 \ 37 | #--corr_radius_list -1 4 \ 38 | #--prop_radius_list -1 1 \ 39 | #--reg_refine \ 40 | #--num_reg_refine 6 41 | 42 | 43 | # submission to kitti 44 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 45 | --submission \ 46 | --output_path submission/kitti-gmflow-scale2-regrefine6 \ 47 | --val_dataset kitti \ 48 | --resume pretrained/gmflow-scale2-regrefine6-kitti15-25b554d7.pth \ 49 | --inference_size 352 1216 \ 50 | --padding_factor 32 \ 51 | --upsample_factor 4 \ 52 | --num_scales 2 \ 53 | --attn_splits_list 2 8 \ 54 | --corr_radius_list -1 4 \ 55 | --prop_radius_list -1 1 \ 56 | --reg_refine \ 57 | --num_reg_refine 6 58 | 59 | 60 | -------------------------------------------------------------------------------- /scripts/gmstereo_demo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # gmstereo-scale2-regrefine3 model 5 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \ 6 | --inference_dir demo/stereo-middlebury \ 7 | --inference_size 1024 1536 \ 8 | --output_path output/gmstereo-scale2-regrefine3-middlebury \ 9 | --resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-middleburyfthighres-a82bec03.pth \ 10 | --padding_factor 32 \ 11 | --upsample_factor 4 \ 12 | --num_scales 2 \ 13 | --attn_type self_swin2d_cross_swin1d \ 14 | --attn_splits_list 2 8 \ 15 | --corr_radius_list -1 4 \ 16 | --prop_radius_list -1 1 \ 17 | --reg_refine \ 18 | --num_reg_refine 3 19 | 20 | # optionally predict both left and right disparities 21 | #--pred_bidir_disp 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /scripts/gmstereo_evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # gmstereo-scale1 5 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \ 6 | --eval \ 7 | --resume pretrained/gmstereo-scale1-resumeflowthings-sceneflow-16e38788.pth \ 8 | --val_dataset kitti15 9 | 10 | 11 | # gmstereo-scale2 12 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \ 13 | --eval \ 14 | --resume pretrained/gmstereo-scale2-resumeflowthings-sceneflow-48020649.pth \ 15 | --val_dataset kitti15 \ 16 | --padding_factor 32 \ 17 | --upsample_factor 4 \ 18 | --num_scales 2 \ 19 | --attn_type self_swin2d_cross_swin1d \ 20 | --attn_splits_list 2 8 \ 21 | --corr_radius_list -1 4 \ 22 | --prop_radius_list -1 1 23 | 24 | 25 | # gmstereo-scale2-regrefine3 26 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \ 27 | --eval \ 28 | --resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-sceneflow-f724fee6.pth \ 29 | --val_dataset kitti15 \ 30 | --padding_factor 32 \ 31 | --upsample_factor 4 \ 32 | --num_scales 2 \ 33 | --attn_type self_swin2d_cross_swin1d \ 34 | --attn_splits_list 2 8 \ 35 | --corr_radius_list -1 4 \ 36 | --prop_radius_list -1 1 \ 37 | --reg_refine \ 38 | --num_reg_refine 3 39 | 40 | 41 | -------------------------------------------------------------------------------- 
/scripts/gmstereo_scale1_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # basic GMStereo without any refinement (1/8 feature only) 4 | 5 | # number of gpus for training, please set according to your hardware 6 | # trained on 8x 40GB A100 gpus 7 | NUM_GPUS=8 8 | 9 | # sceneflow (our final model is trained for 100K steps, for ablation, we train for 50K) 10 | # resume flow things model (our ablations are trained from random init) 11 | CHECKPOINT_DIR=checkpoints_stereo/sceneflow-gmstereo-scale1-resumeflowthings && \ 12 | mkdir -p ${CHECKPOINT_DIR} && \ 13 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_stereo.py \ 14 | --launcher pytorch \ 15 | --checkpoint_dir ${CHECKPOINT_DIR} \ 16 | --resume pretrained/gmflow-scale1-things-e9887eda.pth \ 17 | --no_resume_optimizer \ 18 | --stage sceneflow \ 19 | --batch_size 64 \ 20 | --val_dataset things kitti15 \ 21 | --img_height 384 \ 22 | --img_width 768 \ 23 | --padding_factor 16 \ 24 | --upsample_factor 8 \ 25 | --attn_type self_swin2d_cross_1d \ 26 | --summary_freq 1000 \ 27 | --val_freq 10000 \ 28 | --save_ckpt_freq 1000 \ 29 | --save_latest_ckpt_freq 1000 \ 30 | --num_steps 100000 \ 31 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 32 | 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /scripts/gmstereo_scale2_regrefine3_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # GMFlow with hierarchical matching refinement (1/8 + 1/4 features) 4 | # with additional 3 local regression refinements 5 | 6 | # number of gpus for training, please set according to your hardware 7 | # trained on 8x 40GB A100 gpus 8 | NUM_GPUS=8 9 | 10 | # sceneflow 11 | # resume gmstereo scale2 model, which is trained from flow things model 12 | CHECKPOINT_DIR=checkpoints_stereo/sceneflow-gmstereo-scale2-regrefine3-resumeflowthings && \ 13 | mkdir -p ${CHECKPOINT_DIR} && \ 14 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_stereo.py \ 15 | --launcher pytorch \ 16 | --checkpoint_dir ${CHECKPOINT_DIR} \ 17 | --resume pretrained/gmstereo-scale2-resumeflowthings-sceneflow-48020649.pth \ 18 | --no_resume_optimizer \ 19 | --stage sceneflow \ 20 | --lr 4e-4 \ 21 | --batch_size 16 \ 22 | --val_dataset things kitti15 \ 23 | --img_height 384 \ 24 | --img_width 768 \ 25 | --padding_factor 32 \ 26 | --upsample_factor 4 \ 27 | --num_scales 2 \ 28 | --attn_type self_swin2d_cross_swin1d \ 29 | --attn_splits_list 2 8 \ 30 | --corr_radius_list -1 4 \ 31 | --prop_radius_list -1 1 \ 32 | --reg_refine \ 33 | --num_reg_refine 3 \ 34 | --summary_freq 100 \ 35 | --val_freq 10000 \ 36 | --save_ckpt_freq 1000 \ 37 | --save_latest_ckpt_freq 1000 \ 38 | --num_steps 100000 \ 39 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 40 | 41 | 42 | # vkitti2 43 | CHECKPOINT_DIR=checkpoints_stereo/vkitti2-gmstereo-scale2-regrefine3-resumeflowthings && \ 44 | mkdir -p ${CHECKPOINT_DIR} && \ 45 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_stereo.py \ 46 | --launcher pytorch \ 47 | --checkpoint_dir ${CHECKPOINT_DIR} \ 48 | --resume checkpoints_stereo/sceneflow-gmstereo-scale2-regrefine3-resumeflowthings/step_100000.pth \ 49 | --no_resume_optimizer \ 50 | --stage vkitti2 \ 51 | --val_dataset kitti15 \ 52 | --lr 4e-4 \ 53 | --batch_size 16 \ 54 | --img_height 320 \ 55 | --img_width 832 \ 56 | 
--padding_factor 32 \ 57 | --upsample_factor 4 \ 58 | --num_scales 2 \ 59 | --attn_type self_swin2d_cross_swin1d \ 60 | --attn_splits_list 2 8 \ 61 | --corr_radius_list -1 4 \ 62 | --prop_radius_list -1 1 \ 63 | --reg_refine \ 64 | --num_reg_refine 3 \ 65 | --summary_freq 100 \ 66 | --val_freq 5000 \ 67 | --save_ckpt_freq 1000 \ 68 | --save_latest_ckpt_freq 1000 \ 69 | --num_steps 30000 \ 70 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 71 | 72 | # kitti, this is our final model for kitti submission 73 | CHECKPOINT_DIR=checkpoints_stereo/kitti-gmstereo-scale2-regrefine3-resumeflowthings && \ 74 | mkdir -p ${CHECKPOINT_DIR} && \ 75 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_stereo.py \ 76 | --launcher pytorch \ 77 | --checkpoint_dir ${CHECKPOINT_DIR} \ 78 | --resume checkpoints_stereo/vkitti2-gmstereo-scale2-regrefine3-resumeflowthings/step_030000.pth \ 79 | --no_resume_optimizer \ 80 | --stage kitti15mix \ 81 | --val_dataset kitti15 \ 82 | --lr 4e-4 \ 83 | --batch_size 16 \ 84 | --img_height 352 \ 85 | --img_width 1216 \ 86 | --padding_factor 32 \ 87 | --upsample_factor 4 \ 88 | --num_scales 2 \ 89 | --attn_type self_swin2d_cross_swin1d \ 90 | --attn_splits_list 2 8 \ 91 | --corr_radius_list -1 4 \ 92 | --prop_radius_list -1 1 \ 93 | --reg_refine \ 94 | --num_reg_refine 3 \ 95 | --summary_freq 100 \ 96 | --val_freq 2000 \ 97 | --save_ckpt_freq 2000 \ 98 | --save_latest_ckpt_freq 1000 \ 99 | --num_steps 10000 \ 100 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 101 | 102 | 103 | # middlebury, train on 480x640 first 104 | CHECKPOINT_DIR=checkpoints_stereo/middlebury-gmstereo-scale2-regrefine3-resumeflowthings && \ 105 | mkdir -p ${CHECKPOINT_DIR} && \ 106 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_stereo.py \ 107 | --launcher pytorch \ 108 | --checkpoint_dir ${CHECKPOINT_DIR} \ 109 | --resume checkpoints_stereo/sceneflow-gmstereo-scale2-regrefine3-resumeflowthings/step_100000.pth \ 110 | --no_resume_optimizer \ 111 | --stage middlebury \ 112 | --val_dataset middlebury \ 113 | --inference_size 768 1024 \ 114 | --lr 4e-4 \ 115 | --batch_size 16 \ 116 | --img_height 480 \ 117 | --img_width 640 \ 118 | --padding_factor 32 \ 119 | --upsample_factor 4 \ 120 | --num_scales 2 \ 121 | --attn_type self_swin2d_cross_swin1d \ 122 | --attn_splits_list 2 8 \ 123 | --corr_radius_list -1 4 \ 124 | --prop_radius_list -1 1 \ 125 | --reg_refine \ 126 | --num_reg_refine 3 \ 127 | --summary_freq 100 \ 128 | --val_freq 10000 \ 129 | --save_ckpt_freq 10000 \ 130 | --save_latest_ckpt_freq 1000 \ 131 | --num_steps 100000 \ 132 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 133 | 134 | 135 | # middlebury, finetune on 768x1024 resolution, max disparity range 600 in loss 136 | # this is our final model for middlebury submission 137 | CHECKPOINT_DIR=checkpoints_stereo/middlebury-gmstereo-scale2-regrefine3-resumeflowthings-fthighres && \ 138 | mkdir -p ${CHECKPOINT_DIR} && \ 139 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_stereo.py \ 140 | --launcher pytorch \ 141 | --checkpoint_dir ${CHECKPOINT_DIR} \ 142 | --resume checkpoints_stereo/middlebury-gmstereo-scale2-regrefine3-resumeflowthings/step_100000.pth \ 143 | --no_resume_optimizer \ 144 | --max_disp 600 \ 145 | --stage middlebury_ft \ 146 | --val_dataset middlebury \ 147 | --inference_size 1536 2048 \ 148 | --lr 4e-4 \ 149 | --batch_size 8 \ 150 | --img_height 768 \ 151 | --img_width 1024 \ 152 | --padding_factor 32 \ 153 | 
--upsample_factor 4 \ 154 | --num_scales 2 \ 155 | --attn_type self_swin2d_cross_swin1d \ 156 | --attn_splits_list 2 8 \ 157 | --corr_radius_list -1 4 \ 158 | --prop_radius_list -1 1 \ 159 | --reg_refine \ 160 | --num_reg_refine 3 \ 161 | --summary_freq 100 \ 162 | --val_freq 5000 \ 163 | --save_ckpt_freq 10000 \ 164 | --save_latest_ckpt_freq 1000 \ 165 | --num_steps 50000 \ 166 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 167 | 168 | 169 | # eth3d 170 | CHECKPOINT_DIR=checkpoints_stereo/eth3d-gmstereo-scale2-regrefine3-resumeflowthings && \ 171 | mkdir -p ${CHECKPOINT_DIR} && \ 172 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_stereo.py \ 173 | --launcher pytorch \ 174 | --checkpoint_dir ${CHECKPOINT_DIR} \ 175 | --resume checkpoints_stereo/sceneflow-gmstereo-scale2-regrefine3-resumeflowthings/step_100000.pth \ 176 | --no_resume_optimizer \ 177 | --stage eth3d \ 178 | --val_dataset eth3d \ 179 | --lr 4e-4 \ 180 | --batch_size 24 \ 181 | --img_height 416 \ 182 | --img_width 640 \ 183 | --padding_factor 32 \ 184 | --upsample_factor 4 \ 185 | --num_scales 2 \ 186 | --attn_type self_swin2d_cross_swin1d \ 187 | --attn_splits_list 2 8 \ 188 | --corr_radius_list -1 4 \ 189 | --prop_radius_list -1 1 \ 190 | --reg_refine \ 191 | --num_reg_refine 3 \ 192 | --summary_freq 100 \ 193 | --val_freq 10000 \ 194 | --save_ckpt_freq 10000 \ 195 | --save_latest_ckpt_freq 1000 \ 196 | --num_steps 100000 \ 197 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 198 | 199 | 200 | # eth3d, finetune, this is our final model for eth3d submission 201 | CHECKPOINT_DIR=checkpoints_stereo/eth3d-gmstereo-scale2-regrefine3-resumeflowthings-ft && \ 202 | mkdir -p ${CHECKPOINT_DIR} && \ 203 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_stereo.py \ 204 | --launcher pytorch \ 205 | --checkpoint_dir ${CHECKPOINT_DIR} \ 206 | --resume checkpoints_stereo/eth3d-gmstereo-scale2-regrefine3-resumeflowthings/step_100000.pth \ 207 | --no_resume_optimizer \ 208 | --stage eth3d_ft \ 209 | --val_dataset eth3d \ 210 | --lr 4e-4 \ 211 | --batch_size 24 \ 212 | --img_height 416 \ 213 | --img_width 640 \ 214 | --padding_factor 32 \ 215 | --upsample_factor 4 \ 216 | --num_scales 2 \ 217 | --attn_type self_swin2d_cross_swin1d \ 218 | --attn_splits_list 2 8 \ 219 | --corr_radius_list -1 4 \ 220 | --prop_radius_list -1 1 \ 221 | --reg_refine \ 222 | --num_reg_refine 3 \ 223 | --summary_freq 100 \ 224 | --val_freq 3000 \ 225 | --save_ckpt_freq 3000 \ 226 | --save_latest_ckpt_freq 1000 \ 227 | --num_steps 30000 \ 228 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | -------------------------------------------------------------------------------- /scripts/gmstereo_scale2_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # GMFlow with hierarchical matching refinement (1/8 + 1/4 features) 4 | 5 | # number of gpus for training, please set according to your hardware 6 | # trained on 8x 40GB A100 gpus 7 | NUM_GPUS=8 8 | 9 | # sceneflow 10 | # resume flow things model 11 | CHECKPOINT_DIR=checkpoints_stereo/sceneflow-gmstereo-scale2-resumeflowthings && \ 12 | mkdir -p ${CHECKPOINT_DIR} && \ 13 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_stereo.py \ 14 | --launcher pytorch \ 15 | --checkpoint_dir ${CHECKPOINT_DIR} \ 16 | 
--resume pretrained/gmflow-scale2-things-36579974.pth \ 17 | --no_resume_optimizer \ 18 | --stage sceneflow \ 19 | --batch_size 32 \ 20 | --val_dataset things kitti15 \ 21 | --img_height 384 \ 22 | --img_width 768 \ 23 | --padding_factor 32 \ 24 | --upsample_factor 4 \ 25 | --num_scales 2 \ 26 | --attn_type self_swin2d_cross_swin1d \ 27 | --attn_splits_list 2 8 \ 28 | --corr_radius_list -1 4 \ 29 | --prop_radius_list -1 1 \ 30 | --summary_freq 100 \ 31 | --val_freq 10000 \ 32 | --save_ckpt_freq 1000 \ 33 | --save_latest_ckpt_freq 1000 \ 34 | --num_steps 100000 \ 35 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 36 | 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /scripts/gmstereo_submission.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # generate prediction results for submission on kitti, middlebury and eth3d online servers 4 | 5 | 6 | # submission to kitti 7 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \ 8 | --submission \ 9 | --val_dataset kitti15 \ 10 | --inference_size 352 1216 \ 11 | --output_path submission/kitti-gmstereo-scale2-regrefine3 \ 12 | --resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-kitti15-04487ebf.pth \ 13 | --padding_factor 32 \ 14 | --upsample_factor 4 \ 15 | --num_scales 2 \ 16 | --attn_type self_swin2d_cross_swin1d \ 17 | --attn_splits_list 2 8 \ 18 | --corr_radius_list -1 4 \ 19 | --prop_radius_list -1 1 \ 20 | --reg_refine \ 21 | --num_reg_refine 3 22 | 23 | 24 | # submission to middlebury 25 | # set --eth_submission_mode to train and test to generate results on both train and test sets 26 | # use --save_vis_disp to visualize disparity 27 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \ 28 | --submission \ 29 | --val_dataset middlebury \ 30 | --middlebury_resolution F \ 31 | --middlebury_submission_mode test \ 32 | --inference_size 1024 1536 \ 33 | --output_path submission/middlebury-test-gmstereo-scale2-regrefine3 \ 34 | --resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-middleburyfthighres-a82bec03.pth \ 35 | --padding_factor 32 \ 36 | --upsample_factor 4 \ 37 | --num_scales 2 \ 38 | --attn_type self_swin2d_cross_swin1d \ 39 | --attn_splits_list 2 8 \ 40 | --corr_radius_list -1 4 \ 41 | --prop_radius_list -1 1 \ 42 | --reg_refine \ 43 | --num_reg_refine 3 44 | 45 | 46 | # submission to eth3d 47 | # set --eth_submission_mode to train and test to generate results on both train and test sets 48 | # use --save_vis_disp to visualize disparity 49 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \ 50 | --submission \ 51 | --eth_submission_mode test \ 52 | --val_dataset eth3d \ 53 | --inference_size 512 768 \ 54 | --output_path submission/eth3d-test-gmstereo-scale2-regrefine3 \ 55 | --resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-eth3dft-46effc13.pth \ 56 | --padding_factor 32 \ 57 | --upsample_factor 4 \ 58 | --num_scales 2 \ 59 | --attn_type self_swin2d_cross_swin1d \ 60 | --attn_splits_list 2 8 \ 61 | --corr_radius_list -1 4 \ 62 | --prop_radius_list -1 1 \ 63 | --reg_refine \ 64 | --num_reg_refine 3 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /unimatch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/unimatch/__init__.py 
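One small note on gmstereo_submission.sh above: the comment under the Middlebury command mentions --eth_submission_mode, but the flag actually passed in that command is --middlebury_submission_mode (the comment reads like a copy-paste from the ETH3D block). Following the script's own hints (submission mode set to train, --save_vis_disp to save disparity visualizations), a hedged variant for inspecting results on the Middlebury training split might look like the sketch below; only the submission mode, the visualization flag, and the illustrative output path differ from the original command.

CUDA_VISIBLE_DEVICES=0 python main_stereo.py \
--submission \
--val_dataset middlebury \
--middlebury_resolution F \
--middlebury_submission_mode train \
--save_vis_disp \
--inference_size 1024 1536 \
--output_path submission/middlebury-train-gmstereo-scale2-regrefine3 \
--resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-middleburyfthighres-a82bec03.pth \
--padding_factor 32 \
--upsample_factor 4 \
--num_scales 2 \
--attn_type self_swin2d_cross_swin1d \
--attn_splits_list 2 8 \
--corr_radius_list -1 4 \
--prop_radius_list -1 1 \
--reg_refine \
--num_reg_refine 3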
-------------------------------------------------------------------------------- /unimatch/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .utils import split_feature, merge_splits, split_feature_1d, merge_splits_1d 6 | 7 | 8 | def single_head_full_attention(q, k, v): 9 | # q, k, v: [B, L, C] 10 | assert q.dim() == k.dim() == v.dim() == 3 11 | 12 | scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L] 13 | attn = torch.softmax(scores, dim=2) # [B, L, L] 14 | out = torch.matmul(attn, v) # [B, L, C] 15 | 16 | return out 17 | 18 | 19 | def single_head_full_attention_1d(q, k, v, 20 | h=None, 21 | w=None, 22 | ): 23 | # q, k, v: [B, L, C] 24 | 25 | assert h is not None and w is not None 26 | assert q.size(1) == h * w 27 | 28 | b, _, c = q.size() 29 | 30 | q = q.view(b, h, w, c) # [B, H, W, C] 31 | k = k.view(b, h, w, c) 32 | v = v.view(b, h, w, c) 33 | 34 | scale_factor = c ** 0.5 35 | 36 | scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor # [B, H, W, W] 37 | 38 | attn = torch.softmax(scores, dim=-1) 39 | 40 | out = torch.matmul(attn, v).view(b, -1, c) # [B, H*W, C] 41 | 42 | return out 43 | 44 | 45 | def single_head_split_window_attention(q, k, v, 46 | num_splits=1, 47 | with_shift=False, 48 | h=None, 49 | w=None, 50 | attn_mask=None, 51 | ): 52 | # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py 53 | # q, k, v: [B, L, C] 54 | assert q.dim() == k.dim() == v.dim() == 3 55 | 56 | assert h is not None and w is not None 57 | assert q.size(1) == h * w 58 | 59 | b, _, c = q.size() 60 | 61 | b_new = b * num_splits * num_splits 62 | 63 | window_size_h = h // num_splits 64 | window_size_w = w // num_splits 65 | 66 | q = q.view(b, h, w, c) # [B, H, W, C] 67 | k = k.view(b, h, w, c) 68 | v = v.view(b, h, w, c) 69 | 70 | scale_factor = c ** 0.5 71 | 72 | if with_shift: 73 | assert attn_mask is not None # compute once 74 | shift_size_h = window_size_h // 2 75 | shift_size_w = window_size_w // 2 76 | 77 | q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) 78 | k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) 79 | v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) 80 | 81 | q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C] 82 | k = split_feature(k, num_splits=num_splits, channel_last=True) 83 | v = split_feature(v, num_splits=num_splits, channel_last=True) 84 | 85 | scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) 86 | ) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K] 87 | 88 | if with_shift: 89 | scores += attn_mask.repeat(b, 1, 1) 90 | 91 | attn = torch.softmax(scores, dim=-1) 92 | 93 | out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C] 94 | 95 | out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c), 96 | num_splits=num_splits, channel_last=True) # [B, H, W, C] 97 | 98 | # shift back 99 | if with_shift: 100 | out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2)) 101 | 102 | out = out.view(b, -1, c) 103 | 104 | return out 105 | 106 | 107 | def single_head_split_window_attention_1d(q, k, v, 108 | relative_position_bias=None, 109 | num_splits=1, 110 | with_shift=False, 111 | h=None, 112 | w=None, 113 | attn_mask=None, 114 | ): 115 | # q, k, v: [B, L, C] 116 | 117 | assert h is not None and w is not None 118 | assert 
q.size(1) == h * w 119 | 120 | b, _, c = q.size() 121 | 122 | b_new = b * num_splits * h 123 | 124 | window_size_w = w // num_splits 125 | 126 | q = q.view(b * h, w, c) # [B*H, W, C] 127 | k = k.view(b * h, w, c) 128 | v = v.view(b * h, w, c) 129 | 130 | scale_factor = c ** 0.5 131 | 132 | if with_shift: 133 | assert attn_mask is not None # compute once 134 | shift_size_w = window_size_w // 2 135 | 136 | q = torch.roll(q, shifts=-shift_size_w, dims=1) 137 | k = torch.roll(k, shifts=-shift_size_w, dims=1) 138 | v = torch.roll(v, shifts=-shift_size_w, dims=1) 139 | 140 | q = split_feature_1d(q, num_splits=num_splits) # [B*H*K, W/K, C] 141 | k = split_feature_1d(k, num_splits=num_splits) 142 | v = split_feature_1d(v, num_splits=num_splits) 143 | 144 | scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) 145 | ) / scale_factor # [B*H*K, W/K, W/K] 146 | 147 | if with_shift: 148 | # attn_mask: [K, W/K, W/K] 149 | scores += attn_mask.repeat(b * h, 1, 1) # [B*H*K, W/K, W/K] 150 | 151 | attn = torch.softmax(scores, dim=-1) 152 | 153 | out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*H*K, W/K, C] 154 | 155 | out = merge_splits_1d(out, h, num_splits=num_splits) # [B, H, W, C] 156 | 157 | # shift back 158 | if with_shift: 159 | out = torch.roll(out, shifts=shift_size_w, dims=2) 160 | 161 | out = out.view(b, -1, c) 162 | 163 | return out 164 | 165 | 166 | class SelfAttnPropagation(nn.Module): 167 | """ 168 | flow propagation with self-attention on feature 169 | query: feature0, key: feature0, value: flow 170 | """ 171 | 172 | def __init__(self, in_channels, 173 | **kwargs, 174 | ): 175 | super(SelfAttnPropagation, self).__init__() 176 | 177 | self.q_proj = nn.Linear(in_channels, in_channels) 178 | self.k_proj = nn.Linear(in_channels, in_channels) 179 | 180 | for p in self.parameters(): 181 | if p.dim() > 1: 182 | nn.init.xavier_uniform_(p) 183 | 184 | def forward(self, feature0, flow, 185 | local_window_attn=False, 186 | local_window_radius=1, 187 | **kwargs, 188 | ): 189 | # q, k: feature [B, C, H, W], v: flow [B, 2, H, W] 190 | if local_window_attn: 191 | return self.forward_local_window_attn(feature0, flow, 192 | local_window_radius=local_window_radius) 193 | 194 | b, c, h, w = feature0.size() 195 | 196 | query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C] 197 | 198 | # a note: the ``correct'' implementation should be: 199 | # ``query = self.q_proj(query), key = self.k_proj(query)'' 200 | # this problem is observed while cleaning up the code 201 | # however, this doesn't affect the performance since the projection is a linear operation, 202 | # thus the two projection matrices for key can be merged 203 | # so I just leave it as is in order to not re-train all models :) 204 | query = self.q_proj(query) # [B, H*W, C] 205 | key = self.k_proj(query) # [B, H*W, C] 206 | 207 | value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2] 208 | 209 | scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W] 210 | prob = torch.softmax(scores, dim=-1) 211 | 212 | out = torch.matmul(prob, value) # [B, H*W, 2] 213 | out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W] 214 | 215 | return out 216 | 217 | def forward_local_window_attn(self, feature0, flow, 218 | local_window_radius=1, 219 | ): 220 | assert flow.size(1) == 2 or flow.size(1) == 1 # flow or disparity or depth 221 | assert local_window_radius > 0 222 | 223 | b, c, h, w = feature0.size() 224 | 225 | value_channel = flow.size(1) 226 | 227 | 
feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1) 228 | ).reshape(b * h * w, 1, c) # [B*H*W, 1, C] 229 | 230 | kernel_size = 2 * local_window_radius + 1 231 | 232 | feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w) 233 | 234 | feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size, 235 | padding=local_window_radius) # [B, C*(2R+1)^2), H*W] 236 | 237 | feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute( 238 | 0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2] 239 | 240 | flow_window = F.unfold(flow, kernel_size=kernel_size, 241 | padding=local_window_radius) # [B, 2*(2R+1)^2), H*W] 242 | 243 | flow_window = flow_window.view(b, value_channel, kernel_size ** 2, h, w).permute( 244 | 0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, value_channel) # [B*H*W, (2R+1)^2, 2] 245 | 246 | scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2] 247 | 248 | prob = torch.softmax(scores, dim=-1) 249 | 250 | out = torch.matmul(prob, flow_window).view(b, h, w, value_channel 251 | ).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W] 252 | 253 | return out 254 | -------------------------------------------------------------------------------- /unimatch/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .trident_conv import MultiScaleTridentConv 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1, 8 | ): 9 | super(ResidualBlock, self).__init__() 10 | 11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, 12 | dilation=dilation, padding=dilation, stride=stride, bias=False) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 14 | dilation=dilation, padding=dilation, bias=False) 15 | self.relu = nn.ReLU(inplace=True) 16 | 17 | self.norm1 = norm_layer(planes) 18 | self.norm2 = norm_layer(planes) 19 | if not stride == 1 or in_planes != planes: 20 | self.norm3 = norm_layer(planes) 21 | 22 | if stride == 1 and in_planes == planes: 23 | self.downsample = None 24 | else: 25 | self.downsample = nn.Sequential( 26 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 27 | 28 | def forward(self, x): 29 | y = x 30 | y = self.relu(self.norm1(self.conv1(y))) 31 | y = self.relu(self.norm2(self.conv2(y))) 32 | 33 | if self.downsample is not None: 34 | x = self.downsample(x) 35 | 36 | return self.relu(x + y) 37 | 38 | 39 | class CNNEncoder(nn.Module): 40 | def __init__(self, output_dim=128, 41 | norm_layer=nn.InstanceNorm2d, 42 | num_output_scales=1, 43 | return_all_scales=False, 44 | **kwargs, 45 | ): 46 | super(CNNEncoder, self).__init__() 47 | self.num_branch = num_output_scales 48 | self.return_all_scales = return_all_scales 49 | 50 | feature_dims = [64, 96, 128] 51 | 52 | self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 53 | self.norm1 = norm_layer(feature_dims[0]) 54 | self.relu1 = nn.ReLU(inplace=True) 55 | 56 | self.in_planes = feature_dims[0] 57 | self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2 58 | self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4 59 | 60 | # highest resolution 1/4 or 1/8 61 | if return_all_scales: # depthsplat 62 | stride = 2 63 | else: 64 | stride = 2 if num_output_scales == 1 else 1 65 | self.layer3 = 
self._make_layer(feature_dims[2], stride=stride, 66 | norm_layer=norm_layer, 67 | ) # 1/4 or 1/8 68 | 69 | self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) 70 | 71 | if self.num_branch > 1 and not return_all_scales: 72 | if self.num_branch == 4: 73 | strides = (1, 2, 4, 8) 74 | elif self.num_branch == 3: 75 | strides = (1, 2, 4) 76 | elif self.num_branch == 2: 77 | strides = (1, 2) 78 | else: 79 | raise ValueError 80 | 81 | self.trident_conv = MultiScaleTridentConv(output_dim, output_dim, 82 | kernel_size=3, 83 | strides=strides, 84 | paddings=1, 85 | num_branch=self.num_branch, 86 | ) 87 | 88 | for m in self.modules(): 89 | if isinstance(m, nn.Conv2d): 90 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 91 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 92 | if m.weight is not None: 93 | nn.init.constant_(m.weight, 1) 94 | if m.bias is not None: 95 | nn.init.constant_(m.bias, 0) 96 | 97 | def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): 98 | layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation) 99 | layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation) 100 | 101 | layers = (layer1, layer2) 102 | 103 | self.in_planes = dim 104 | return nn.Sequential(*layers) 105 | 106 | def forward(self, x): 107 | output_all_scales = [] 108 | x = self.conv1(x) 109 | x = self.norm1(x) 110 | x = self.relu1(x) 111 | 112 | x = self.layer1(x) # 1/2 113 | if self.return_all_scales: 114 | output_all_scales.append(x) 115 | 116 | x = self.layer2(x) # 1/4 117 | if self.return_all_scales: 118 | output_all_scales.append(x) 119 | 120 | x = self.layer3(x) # 1/8 or 1/4 121 | 122 | x = self.conv2(x) 123 | 124 | if self.return_all_scales: 125 | output_all_scales.append(x) 126 | return output_all_scales 127 | 128 | if self.num_branch > 1: 129 | out = self.trident_conv([x] * self.num_branch) # high to low res 130 | else: 131 | out = [x] 132 | 133 | return out 134 | -------------------------------------------------------------------------------- /unimatch/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def coords_grid(b, h, w, homogeneous=False, device=None): 6 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] 7 | 8 | stacks = [x, y] 9 | 10 | if homogeneous: 11 | ones = torch.ones_like(x) # [H, W] 12 | stacks.append(ones) 13 | 14 | grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] 15 | 16 | grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] 17 | 18 | if device is not None: 19 | grid = grid.to(device) 20 | 21 | return grid 22 | 23 | 24 | def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): 25 | assert device is not None 26 | 27 | x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), 28 | torch.linspace(h_min, h_max, len_h, device=device)], 29 | ) 30 | grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] 31 | 32 | return grid 33 | 34 | 35 | def normalize_coords(coords, h, w): 36 | # coords: [B, H, W, 2] 37 | c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) 38 | return (coords - c) / c # [-1, 1] 39 | 40 | 41 | def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): 42 | # img: [B, C, H, W] 43 | # sample_coords: [B, 2, H, W] in image scale 44 | if sample_coords.size(1) != 2: # 
[B, H, W, 2] 45 | sample_coords = sample_coords.permute(0, 3, 1, 2) 46 | 47 | b, _, h, w = sample_coords.shape 48 | 49 | # Normalize to [-1, 1] 50 | x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 51 | y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 52 | 53 | grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] 54 | 55 | img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) 56 | 57 | if return_mask: 58 | mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] 59 | 60 | return img, mask 61 | 62 | return img 63 | 64 | 65 | def flow_warp(feature, flow, mask=False, padding_mode='zeros'): 66 | b, c, h, w = feature.size() 67 | assert flow.size(1) == 2 68 | 69 | grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] 70 | 71 | return bilinear_sample(feature, grid, padding_mode=padding_mode, 72 | return_mask=mask) 73 | 74 | 75 | def forward_backward_consistency_check(fwd_flow, bwd_flow, 76 | alpha=0.01, 77 | beta=0.5 78 | ): 79 | # fwd_flow, bwd_flow: [B, 2, H, W] 80 | # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) 81 | assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 82 | assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 83 | flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] 84 | 85 | warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] 86 | warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] 87 | 88 | diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] 89 | diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) 90 | 91 | threshold = alpha * flow_mag + beta 92 | 93 | fwd_occ = (diff_fwd > threshold).float() # [B, H, W] 94 | bwd_occ = (diff_bwd > threshold).float() 95 | 96 | return fwd_occ, bwd_occ 97 | 98 | 99 | def back_project(depth, intrinsics): 100 | # Back project 2D pixel coords to 3D points 101 | # depth: [B, H, W] 102 | # intrinsics: [B, 3, 3] 103 | b, h, w = depth.shape 104 | grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] 105 | 106 | intrinsics_inv = torch.inverse(intrinsics) # [B, 3, 3] 107 | 108 | points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(1) # [B, 3, H, W] 109 | 110 | return points 111 | 112 | 113 | def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None): 114 | # Transform 3D points from reference camera to target camera 115 | # points_ref: [B, 3, H, W] 116 | # extrinsics_ref: [B, 4, 4] 117 | # extrinsics_tgt: [B, 4, 4] 118 | # extrinsics_rel: [B, 4, 4], relative pose transform 119 | b, _, h, w = points_ref.shape 120 | 121 | if extrinsics_rel is None: 122 | extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref)) # [B, 4, 4] 123 | 124 | points_tgt = torch.bmm(extrinsics_rel[:, :3, :3], 125 | points_ref.view(b, 3, -1)) + extrinsics_rel[:, :3, -1:] # [B, 3, H*W] 126 | 127 | points_tgt = points_tgt.view(b, 3, h, w) # [B, 3, H, W] 128 | 129 | return points_tgt 130 | 131 | 132 | def reproject(points_tgt, intrinsics, return_mask=False): 133 | # reproject to target view 134 | # points_tgt: [B, 3, H, W] 135 | # intrinsics: [B, 3, 3] 136 | 137 | b, _, h, w = points_tgt.shape 138 | 139 | proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w) # [B, 3, H, W] 140 | 141 | X = proj_points[:, 0] 142 | Y = proj_points[:, 1] 143 | Z = proj_points[:, 2].clamp(min=1e-3) 144 | 145 | pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(b, 2, h, w) # [B, 2, H, W] in 
image scale 146 | 147 | if return_mask: 148 | # valid mask in pixel space 149 | mask = (pixel_coords[:, 0] >= 0) & (pixel_coords[:, 0] <= (w - 1)) & ( 150 | pixel_coords[:, 1] >= 0) & (pixel_coords[:, 1] <= (h - 1)) # [B, H, W] 151 | 152 | return pixel_coords, mask 153 | 154 | return pixel_coords 155 | 156 | 157 | def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, 158 | return_mask=False): 159 | # Compute reprojection sample coords 160 | points_ref = back_project(depth_ref, intrinsics) # [B, 3, H, W] 161 | points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel) 162 | 163 | if return_mask: 164 | reproj_coords, mask = reproject(points_tgt, intrinsics, 165 | return_mask=return_mask) # [B, 2, H, W] in image scale 166 | 167 | return reproj_coords, mask 168 | 169 | reproj_coords = reproject(points_tgt, intrinsics, 170 | return_mask=return_mask) # [B, 2, H, W] in image scale 171 | 172 | return reproj_coords 173 | 174 | 175 | def compute_flow_with_depth_pose(depth_ref, intrinsics, 176 | extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, 177 | return_mask=False): 178 | b, h, w = depth_ref.shape 179 | coords_init = coords_grid(b, h, w, device=depth_ref.device) # [B, 2, H, W] 180 | 181 | if return_mask: 182 | reproj_coords, mask = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, 183 | extrinsics_rel=extrinsics_rel, 184 | return_mask=return_mask) # [B, 2, H, W] 185 | rigid_flow = reproj_coords - coords_init 186 | 187 | return rigid_flow, mask 188 | 189 | reproj_coords = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, 190 | extrinsics_rel=extrinsics_rel, 191 | return_mask=return_mask) # [B, 2, H, W] 192 | 193 | rigid_flow = reproj_coords - coords_init 194 | 195 | return rigid_flow 196 | -------------------------------------------------------------------------------- /unimatch/ldm_unet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/unimatch/943a99006950ab6c84ea72a72beb147fdd03c19a/unimatch/ldm_unet/__init__.py -------------------------------------------------------------------------------- /unimatch/ldm_unet/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | 9 | def exists(val): 10 | return val is not None 11 | 12 | 13 | def uniq(arr): 14 | return{el: True for el in arr}.keys() 15 | 16 | 17 | def default(val, d): 18 | if exists(val): 19 | return val 20 | return d() if isfunction(d) else d 21 | 22 | 23 | def max_neg_value(t): 24 | return -torch.finfo(t.dtype).max 25 | 26 | 27 | def init_(tensor): 28 | dim = tensor.shape[-1] 29 | std = 1 / math.sqrt(dim) 30 | tensor.uniform_(-std, std) 31 | return tensor 32 | 33 | 34 | # feedforward 35 | class GEGLU(nn.Module): 36 | def __init__(self, dim_in, dim_out): 37 | super().__init__() 38 | self.proj = nn.Linear(dim_in, dim_out * 2) 39 | 40 | def forward(self, x): 41 | x, gate = self.proj(x).chunk(2, dim=-1) 42 | return x * F.gelu(gate) 43 | 44 | 45 | class FeedForward(nn.Module): 46 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 47 | super().__init__() 48 | inner_dim = int(dim * mult) 49 | dim_out = default(dim_out, dim) 50 | project_in = nn.Sequential( 51 
| nn.Linear(dim, inner_dim), 52 | nn.GELU() 53 | ) if not glu else GEGLU(dim, inner_dim) 54 | 55 | self.net = nn.Sequential( 56 | project_in, 57 | nn.Dropout(dropout), 58 | nn.Linear(inner_dim, dim_out) 59 | ) 60 | 61 | def forward(self, x): 62 | return self.net(x) 63 | 64 | 65 | def zero_module(module): 66 | """ 67 | Zero out the parameters of a module and return it. 68 | """ 69 | for p in module.parameters(): 70 | p.detach().zero_() 71 | return module 72 | 73 | 74 | def Normalize(in_channels): 75 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 76 | 77 | 78 | class LinearAttention(nn.Module): 79 | def __init__(self, dim, heads=4, dim_head=32): 80 | super().__init__() 81 | self.heads = heads 82 | hidden_dim = dim_head * heads 83 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 84 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 85 | 86 | def forward(self, x): 87 | b, c, h, w = x.shape 88 | qkv = self.to_qkv(x) 89 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 90 | k = k.softmax(dim=-1) 91 | context = torch.einsum('bhdn,bhen->bhde', k, v) 92 | out = torch.einsum('bhde,bhdn->bhen', context, q) 93 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 94 | return self.to_out(out) 95 | 96 | 97 | class SpatialSelfAttention(nn.Module): 98 | def __init__(self, in_channels): 99 | super().__init__() 100 | self.in_channels = in_channels 101 | 102 | self.norm = Normalize(in_channels) 103 | self.q = torch.nn.Conv2d(in_channels, 104 | in_channels, 105 | kernel_size=1, 106 | stride=1, 107 | padding=0) 108 | self.k = torch.nn.Conv2d(in_channels, 109 | in_channels, 110 | kernel_size=1, 111 | stride=1, 112 | padding=0) 113 | self.v = torch.nn.Conv2d(in_channels, 114 | in_channels, 115 | kernel_size=1, 116 | stride=1, 117 | padding=0) 118 | self.proj_out = torch.nn.Conv2d(in_channels, 119 | in_channels, 120 | kernel_size=1, 121 | stride=1, 122 | padding=0) 123 | 124 | def forward(self, x): 125 | h_ = x 126 | h_ = self.norm(h_) 127 | q = self.q(h_) 128 | k = self.k(h_) 129 | v = self.v(h_) 130 | 131 | # compute attention 132 | b,c,h,w = q.shape 133 | q = rearrange(q, 'b c h w -> b (h w) c') 134 | k = rearrange(k, 'b c h w -> b c (h w)') 135 | w_ = torch.einsum('bij,bjk->bik', q, k) 136 | 137 | w_ = w_ * (int(c)**(-0.5)) 138 | w_ = torch.nn.functional.softmax(w_, dim=2) 139 | 140 | # attend to values 141 | v = rearrange(v, 'b c h w -> b c (h w)') 142 | w_ = rearrange(w_, 'b i j -> b j i') 143 | h_ = torch.einsum('bij,bjk->bik', v, w_) 144 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 145 | h_ = self.proj_out(h_) 146 | 147 | return x+h_ 148 | 149 | 150 | class CrossAttention(nn.Module): 151 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 152 | super().__init__() 153 | inner_dim = dim_head * heads 154 | context_dim = default(context_dim, query_dim) 155 | 156 | self.scale = dim_head ** -0.5 157 | self.heads = heads 158 | 159 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 160 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 161 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 162 | 163 | self.to_out = nn.Sequential( 164 | nn.Linear(inner_dim, query_dim), 165 | nn.Dropout(dropout) 166 | ) 167 | 168 | def forward(self, x, context=None, mask=None): 169 | h = self.heads 170 | 171 | q = self.to_q(x) 172 | context = default(context, x) 173 | k = self.to_k(context) 174 | v = self.to_v(context) 175 | 176 | 
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 177 | 178 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 179 | 180 | if exists(mask): 181 | mask = rearrange(mask, 'b ... -> b (...)') 182 | max_neg_value = -torch.finfo(sim.dtype).max 183 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 184 | sim.masked_fill_(~mask, max_neg_value) 185 | 186 | # attention, what we cannot get enough of 187 | attn = sim.softmax(dim=-1) 188 | 189 | out = einsum('b i j, b j d -> b i d', attn, v) 190 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 191 | return self.to_out(out) 192 | 193 | 194 | class BasicTransformerBlock(nn.Module): 195 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=False): 196 | super().__init__() 197 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 198 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 199 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 200 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 201 | self.norm1 = nn.LayerNorm(dim) 202 | self.norm2 = nn.LayerNorm(dim) 203 | self.norm3 = nn.LayerNorm(dim) 204 | # self.checkpoint = checkpoint 205 | 206 | def forward(self, x, context=None): 207 | # return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 208 | 209 | return self._forward(x, context) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action.
224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape 253 | x_in = x 254 | x = self.norm(x) 255 | x = self.proj_in(x) 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context) 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /unimatch/ldm_unet/cross_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | import warnings 6 | 7 | 8 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 9 | try: 10 | if XFORMERS_ENABLED: 11 | from xformers.ops import memory_efficient_attention, unbind 12 | 13 | XFORMERS_AVAILABLE = True 14 | # warnings.warn("xFormers is available (Attention)") 15 | else: 16 | # warnings.warn("xFormers is disabled (Attention)") 17 | raise ImportError 18 | except ImportError: 19 | XFORMERS_AVAILABLE = False 20 | # warnings.warn("xFormers is not available (Attention)") 21 | 22 | 23 | class CrossAttention(nn.Module): 24 | def __init__( 25 | self, 26 | in_dim1, 27 | in_dim2, 28 | dim=128, 29 | out_dim=None, 30 | num_heads=4, 31 | qkv_bias=False, 32 | proj_bias=False, 33 | ): 34 | super().__init__() 35 | 36 | assert XFORMERS_AVAILABLE 37 | 38 | if out_dim is None: 39 | out_dim = in_dim1 40 | 41 | self.num_heads = num_heads 42 | self.dim = dim 43 | self.q = nn.Linear(in_dim1, dim, bias=qkv_bias) 44 | self.kv = nn.Linear(in_dim2, dim * 2, bias=qkv_bias) 45 | self.proj = nn.Linear(dim, out_dim, bias=proj_bias) 46 | 47 | def forward(self, x, y): 48 | c = self.dim 49 | b, n1, c1 = x.shape 50 | n2, c2 = y.shape[1:] 51 | 52 | q = self.q(x).reshape(b, n1, self.num_heads, c // self.num_heads) 53 | kv = self.kv(y).reshape(b, n2, 2, self.num_heads, c // self.num_heads) 54 | k, v = unbind(kv, 2) 55 | 56 | x = memory_efficient_attention(q, k, v) 57 | x = x.reshape(b, n1, c) 58 | 59 | x = self.proj(x) 60 | 61 | return x 62 | 63 | 64 | class UNetCrossAttentionBlock(nn.Module): 65 | def __init__(self, 66 | in_dim1, 67 | in_dim2, 68 | dim=128, 69 | out_dim=None, 70 | num_heads=4, 71 | qkv_bias=False, 72 | proj_bias=False, 73 | with_ffn=False, 74 | concat_cross_attn=False, 75 | concat_output=False, 76 | no_cross_attn=False, 77 | with_norm=False, 78 | concat_conv3x3=False, 79 | ): 80 | super().__init__() 81 | 82 | out_dim = out_dim or in_dim1 83 | 84 | self.no_cross_attn = no_cross_attn 85 | self.with_norm = with_norm 86 | 87 | if no_cross_attn: 88 | if concat_conv3x3: 89 | self.proj = nn.Conv2d(in_dim1 + in_dim2, out_dim, 3, 1, 1) 90 | else: 
91 | self.proj = nn.Conv2d(in_dim1 + in_dim2, out_dim, 1) 92 | else: 93 | self.with_ffn = with_ffn 94 | self.concat_cross_attn = concat_cross_attn 95 | self.concat_output = concat_output 96 | 97 | self.cross_attn = CrossAttention( 98 | in_dim1=in_dim1, 99 | in_dim2=in_dim2, 100 | dim=dim, 101 | out_dim=out_dim, 102 | num_heads=num_heads, 103 | qkv_bias=qkv_bias, 104 | proj_bias=proj_bias, 105 | ) 106 | 107 | if with_norm: 108 | self.norm1 = nn.LayerNorm(out_dim) 109 | else: 110 | self.norm1 = nn.Identity() 111 | 112 | if with_ffn: 113 | in_channels = out_dim + in_dim1 if concat_cross_attn else in_dim1 114 | ffn_dim_expansion = 4 115 | self.mlp = nn.Sequential( 116 | nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False), 117 | nn.GELU(), 118 | nn.Linear(in_channels * ffn_dim_expansion, in_dim1, bias=False), 119 | ) 120 | 121 | if with_norm: 122 | self.norm2 = nn.LayerNorm(in_dim1) 123 | else: 124 | self.norm2 = nn.Identity() 125 | 126 | if self.concat_output: 127 | self.out = nn.Linear(out_dim + in_dim1, in_dim1) 128 | 129 | def forward(self, x, y): 130 | # x: [B, C, H, W] 131 | # y: [B, N, C] or [B, C, H, W] 132 | 133 | if self.no_cross_attn: 134 | assert x.dim() == 4 and y.dim() == 4 135 | if y.shape[2:] != x.shape[2:]: 136 | y = F.interpolate(y, x.shape[2:], mode='bilinear', align_corners=True) 137 | return self.proj(torch.cat((x, y), dim=1)) 138 | 139 | identity = x 140 | 141 | b, c, h, w = x.size() 142 | x = x.view(b, c, -1).permute(0, 2, 1) 143 | 144 | cross_attn = self.norm1(self.cross_attn(x, y)) 145 | 146 | if self.with_ffn: 147 | if self.concat_cross_attn: 148 | concat = torch.cat((x, cross_attn), dim=-1) 149 | else: 150 | concat = x + cross_attn 151 | 152 | cross_attn = self.norm2(self.mlp(concat)) 153 | 154 | if self.concat_output: 155 | return self.out(torch.cat((x, cross_attn), dim=-1)) 156 | 157 | # reshape back 158 | cross_attn = cross_attn.view(b, h, w, c).permute(0, 3, 1, 2) # [B, C, H, W] 159 | 160 | return identity + cross_attn 161 | 162 | -------------------------------------------------------------------------------- /unimatch/ldm_unet/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | # from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 
186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels, channels_per_group=None): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | # return GroupNorm32(32, channels) # original 206 | 207 | if channels_per_group is not None: 208 | return GroupNorm(channels // channels_per_group, channels) 209 | # return GroupNorm4(4, channels) 210 | # if channels % channels_per_group == 0: 211 | # # adjust group number according to the channels 212 | # return GroupNorm8(channels // channels_per_group, channels) 213 | # else: 214 | # return GroupNorm4(4, channels) 215 | 216 | if channels % 8 != 0: 217 | return GroupNorm4(4, channels) 218 | 219 | return GroupNorm8(8, channels) 220 | 221 | 222 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 223 | class SiLU(nn.Module): 224 | def forward(self, x): 225 | return x * torch.sigmoid(x) 226 | 227 | 228 | class GroupNorm(nn.GroupNorm): 229 | def forward(self, x): 230 | return super().forward(x.float()).type(x.dtype) 231 | 232 | class GroupNorm32(nn.GroupNorm): 233 | def forward(self, x): 234 | return super().forward(x.float()).type(x.dtype) 235 | 236 | 237 | class GroupNorm8(nn.GroupNorm): 238 | def forward(self, x): 239 | return super().forward(x.float()).type(x.dtype) 240 | 241 | class GroupNorm4(nn.GroupNorm): 242 | def forward(self, x): 243 | return super().forward(x.float()).type(x.dtype) 244 | 245 | def conv_nd(dims, *args, **kwargs): 246 | """ 247 | Create a 1D, 2D, or 3D convolution module. 248 | """ 249 | if dims == 1: 250 | return nn.Conv1d(*args, **kwargs) 251 | elif dims == 2: 252 | return nn.Conv2d(*args, **kwargs) 253 | elif dims == 3: 254 | return nn.Conv3d(*args, **kwargs) 255 | raise ValueError(f"unsupported dimensions: {dims}") 256 | 257 | 258 | def linear(*args, **kwargs): 259 | """ 260 | Create a linear module. 261 | """ 262 | return nn.Linear(*args, **kwargs) 263 | 264 | 265 | def avg_pool_nd(dims, *args, **kwargs): 266 | """ 267 | Create a 1D, 2D, or 3D average pooling module. 
268 | """ 269 | if dims == 1: 270 | return nn.AvgPool1d(*args, **kwargs) 271 | elif dims == 2: 272 | return nn.AvgPool2d(*args, **kwargs) 273 | elif dims == 3: 274 | return nn.AvgPool3d(*args, **kwargs) 275 | raise ValueError(f"unsupported dimensions: {dims}") 276 | 277 | 278 | # class HybridConditioner(nn.Module): 279 | 280 | # def __init__(self, c_concat_config, c_crossattn_config): 281 | # super().__init__() 282 | # self.concat_conditioner = instantiate_from_config(c_concat_config) 283 | # self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 284 | 285 | # def forward(self, c_concat, c_crossattn): 286 | # c_concat = self.concat_conditioner(c_concat) 287 | # c_crossattn = self.crossattn_conditioner(c_crossattn) 288 | # return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 289 | 290 | 291 | def noise_like(shape, device, repeat=False): 292 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 293 | noise = lambda: torch.randn(shape, device=device) 294 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /unimatch/position.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py 3 | 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | 9 | class PositionEmbeddingSine(nn.Module): 10 | """ 11 | This is a more standard version of the position embedding, very similar to the one 12 | used by the Attention is all you need paper, generalized to work on images. 13 | """ 14 | 15 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): 16 | super().__init__() 17 | self.num_pos_feats = num_pos_feats 18 | self.temperature = temperature 19 | self.normalize = normalize 20 | if scale is not None and normalize is False: 21 | raise ValueError("normalize should be True if scale is passed") 22 | if scale is None: 23 | scale = 2 * math.pi 24 | self.scale = scale 25 | 26 | def forward(self, x): 27 | # x = tensor_list.tensors # [B, C, H, W] 28 | # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 29 | b, c, h, w = x.size() 30 | mask = torch.ones((b, h, w), device=x.device) # [B, H, W] 31 | y_embed = mask.cumsum(1, dtype=torch.float32) 32 | x_embed = mask.cumsum(2, dtype=torch.float32) 33 | if self.normalize: 34 | eps = 1e-6 35 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 36 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 37 | 38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 40 | 41 | pos_x = x_embed[:, :, :, None] / dim_t 42 | pos_y = y_embed[:, :, :, None] / dim_t 43 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 44 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 46 | return pos 47 | -------------------------------------------------------------------------------- /unimatch/reg_refine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, 
input_dim=128, hidden_dim=256, 8 | out_dim=2, 9 | ): 10 | super(FlowHead, self).__init__() 11 | 12 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 13 | self.conv2 = nn.Conv2d(hidden_dim, out_dim, 3, padding=1) 14 | self.relu = nn.ReLU(inplace=True) 15 | 16 | def forward(self, x): 17 | out = self.conv2(self.relu(self.conv1(x))) 18 | 19 | return out 20 | 21 | 22 | class SepConvGRU(nn.Module): 23 | def __init__(self, hidden_dim=128, input_dim=192 + 128, 24 | kernel_size=5, 25 | ): 26 | padding = (kernel_size - 1) // 2 27 | 28 | super(SepConvGRU, self).__init__() 29 | self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) 30 | self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) 31 | self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) 32 | 33 | self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) 34 | self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) 35 | self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) 36 | 37 | def forward(self, h, x): 38 | # horizontal 39 | hx = torch.cat([h, x], dim=1) 40 | z = torch.sigmoid(self.convz1(hx)) 41 | r = torch.sigmoid(self.convr1(hx)) 42 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 43 | h = (1 - z) * h + z * q 44 | 45 | # vertical 46 | hx = torch.cat([h, x], dim=1) 47 | z = torch.sigmoid(self.convz2(hx)) 48 | r = torch.sigmoid(self.convr2(hx)) 49 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 50 | h = (1 - z) * h + z * q 51 | 52 | return h 53 | 54 | 55 | class BasicMotionEncoder(nn.Module): 56 | def __init__(self, corr_channels=324, 57 | flow_channels=2, 58 | ): 59 | super(BasicMotionEncoder, self).__init__() 60 | 61 | self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0) 62 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 63 | self.convf1 = nn.Conv2d(flow_channels, 128, 7, padding=3) 64 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 65 | self.conv = nn.Conv2d(64 + 192, 128 - flow_channels, 3, padding=1) 66 | 67 | def forward(self, flow, corr): 68 | cor = F.relu(self.convc1(corr)) 69 | cor = F.relu(self.convc2(cor)) 70 | flo = F.relu(self.convf1(flow)) 71 | flo = F.relu(self.convf2(flo)) 72 | 73 | cor_flo = torch.cat([cor, flo], dim=1) 74 | out = F.relu(self.conv(cor_flo)) 75 | return torch.cat([out, flow], dim=1) 76 | 77 | 78 | class BasicUpdateBlock(nn.Module): 79 | def __init__(self, corr_channels=324, 80 | hidden_dim=128, 81 | context_dim=128, 82 | downsample_factor=8, 83 | flow_dim=2, 84 | bilinear_up=False, 85 | ): 86 | super(BasicUpdateBlock, self).__init__() 87 | 88 | self.encoder = BasicMotionEncoder(corr_channels=corr_channels, 89 | flow_channels=flow_dim, 90 | ) 91 | 92 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=context_dim + hidden_dim) 93 | 94 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256, 95 | out_dim=flow_dim, 96 | ) 97 | 98 | if bilinear_up: 99 | self.mask = None 100 | else: 101 | self.mask = nn.Sequential( 102 | nn.Conv2d(hidden_dim, 256, 3, padding=1), 103 | nn.ReLU(inplace=True), 104 | nn.Conv2d(256, downsample_factor ** 2 * 9, 1, padding=0)) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | 109 | inp = torch.cat([inp, motion_features], dim=1) 110 | 111 | net = self.gru(net, inp) 112 | delta_flow = self.flow_head(net) 113 | 114 | 
if self.mask is not None: 115 | mask = self.mask(net) 116 | else: 117 | mask = None 118 | 119 | return net, mask, delta_flow 120 | -------------------------------------------------------------------------------- /unimatch/trident_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn.modules.utils import _pair 8 | 9 | 10 | class MultiScaleTridentConv(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | stride=1, 17 | strides=1, 18 | paddings=0, 19 | dilations=1, 20 | dilation=1, 21 | groups=1, 22 | num_branch=1, 23 | test_branch_idx=-1, 24 | bias=False, 25 | norm=None, 26 | activation=None, 27 | ): 28 | super(MultiScaleTridentConv, self).__init__() 29 | self.in_channels = in_channels 30 | self.out_channels = out_channels 31 | self.kernel_size = _pair(kernel_size) 32 | self.num_branch = num_branch 33 | self.stride = _pair(stride) 34 | self.groups = groups 35 | self.with_bias = bias 36 | self.dilation = dilation 37 | if isinstance(paddings, int): 38 | paddings = [paddings] * self.num_branch 39 | if isinstance(dilations, int): 40 | dilations = [dilations] * self.num_branch 41 | if isinstance(strides, int): 42 | strides = [strides] * self.num_branch 43 | self.paddings = [_pair(padding) for padding in paddings] 44 | self.dilations = [_pair(dilation) for dilation in dilations] 45 | self.strides = [_pair(stride) for stride in strides] 46 | self.test_branch_idx = test_branch_idx 47 | self.norm = norm 48 | self.activation = activation 49 | 50 | assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 51 | 52 | self.weight = nn.Parameter( 53 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) 54 | ) 55 | if bias: 56 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 57 | else: 58 | self.bias = None 59 | 60 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") 61 | if self.bias is not None: 62 | nn.init.constant_(self.bias, 0) 63 | 64 | def forward(self, inputs): 65 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 66 | assert len(inputs) == num_branch 67 | 68 | if self.training or self.test_branch_idx == -1: 69 | outputs = [ 70 | F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups) 71 | for input, stride, padding in zip(inputs, self.strides, self.paddings) 72 | ] 73 | else: 74 | outputs = [ 75 | F.conv2d( 76 | inputs[0], 77 | self.weight, 78 | self.bias, 79 | self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1], 80 | self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1], 81 | self.dilation, 82 | self.groups, 83 | ) 84 | ] 85 | 86 | if self.norm is not None: 87 | outputs = [self.norm(x) for x in outputs] 88 | if self.activation is not None: 89 | outputs = [self.activation(x) for x in outputs] 90 | return outputs 91 | -------------------------------------------------------------------------------- /unimatch/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .position import PositionEmbeddingSine 4 | 5 | 6 | def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, 
device=None): 7 | assert device is not None 8 | 9 | x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), 10 | torch.linspace(h_min, h_max, len_h, device=device)], 11 | ) 12 | grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] 13 | 14 | return grid 15 | 16 | 17 | def normalize_coords(coords, h, w): 18 | # coords: [B, H, W, 2] 19 | c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) 20 | return (coords - c) / c # [-1, 1] 21 | 22 | 23 | def normalize_img(img0, img1): 24 | # loaded images are in [0, 255] 25 | # normalize by ImageNet mean and std 26 | mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device) 27 | std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device) 28 | img0 = (img0 / 255. - mean) / std 29 | img1 = (img1 / 255. - mean) / std 30 | 31 | return img0, img1 32 | 33 | 34 | def split_feature(feature, 35 | num_splits=2, 36 | channel_last=False, 37 | ): 38 | if channel_last: # [B, H, W, C] 39 | b, h, w, c = feature.size() 40 | assert h % num_splits == 0 and w % num_splits == 0 41 | 42 | b_new = b * num_splits * num_splits 43 | h_new = h // num_splits 44 | w_new = w // num_splits 45 | 46 | feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c 47 | ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C] 48 | else: # [B, C, H, W] 49 | b, c, h, w = feature.size() 50 | assert h % num_splits == 0 and w % num_splits == 0 51 | 52 | b_new = b * num_splits * num_splits 53 | h_new = h // num_splits 54 | w_new = w // num_splits 55 | 56 | feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits 57 | ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K] 58 | 59 | return feature 60 | 61 | 62 | def merge_splits(splits, 63 | num_splits=2, 64 | channel_last=False, 65 | ): 66 | if channel_last: # [B*K*K, H/K, W/K, C] 67 | b, h, w, c = splits.size() 68 | new_b = b // num_splits // num_splits 69 | 70 | splits = splits.view(new_b, num_splits, num_splits, h, w, c) 71 | merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view( 72 | new_b, num_splits * h, num_splits * w, c) # [B, H, W, C] 73 | else: # [B*K*K, C, H/K, W/K] 74 | b, c, h, w = splits.size() 75 | new_b = b // num_splits // num_splits 76 | 77 | splits = splits.view(new_b, num_splits, num_splits, c, h, w) 78 | merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view( 79 | new_b, c, num_splits * h, num_splits * w) # [B, C, H, W] 80 | 81 | return merge 82 | 83 | 84 | def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w, 85 | shift_size_h, shift_size_w, device=torch.device('cuda')): 86 | # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py 87 | # calculate attention mask for SW-MSA 88 | h, w = input_resolution 89 | img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1 90 | h_slices = (slice(0, -window_size_h), 91 | slice(-window_size_h, -shift_size_h), 92 | slice(-shift_size_h, None)) 93 | w_slices = (slice(0, -window_size_w), 94 | slice(-window_size_w, -shift_size_w), 95 | slice(-shift_size_w, None)) 96 | cnt = 0 97 | for h in h_slices: 98 | for w in w_slices: 99 | img_mask[:, h, w, :] = cnt 100 | cnt += 1 101 | 102 | mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True) 103 | 104 | mask_windows = mask_windows.view(-1, window_size_h * window_size_w) 105 | attn_mask = mask_windows.unsqueeze(1) - 
mask_windows.unsqueeze(2) 106 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 107 | 108 | return attn_mask 109 | 110 | 111 | def feature_add_position(feature0, feature1, attn_splits, feature_channels): 112 | pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) 113 | 114 | if attn_splits > 1: # add position in splited window 115 | feature0_splits = split_feature(feature0, num_splits=attn_splits) 116 | feature1_splits = split_feature(feature1, num_splits=attn_splits) 117 | 118 | position = pos_enc(feature0_splits) 119 | 120 | feature0_splits = feature0_splits + position 121 | feature1_splits = feature1_splits + position 122 | 123 | feature0 = merge_splits(feature0_splits, num_splits=attn_splits) 124 | feature1 = merge_splits(feature1_splits, num_splits=attn_splits) 125 | else: 126 | position = pos_enc(feature0) 127 | 128 | feature0 = feature0 + position 129 | feature1 = feature1 + position 130 | 131 | return feature0, feature1 132 | 133 | 134 | def upsample_flow_with_mask(flow, up_mask, upsample_factor, 135 | is_depth=False): 136 | # convex upsampling following raft 137 | 138 | mask = up_mask 139 | b, flow_channel, h, w = flow.shape 140 | mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w) # [B, 1, 9, K, K, H, W] 141 | mask = torch.softmax(mask, dim=2) 142 | 143 | multiplier = 1 if is_depth else upsample_factor 144 | up_flow = F.unfold(multiplier * flow, [3, 3], padding=1) 145 | up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W] 146 | 147 | up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W] 148 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W] 149 | up_flow = up_flow.reshape(b, flow_channel, upsample_factor * h, 150 | upsample_factor * w) # [B, 2, K*H, K*W] 151 | 152 | return up_flow 153 | 154 | 155 | def split_feature_1d(feature, 156 | num_splits=2, 157 | ): 158 | # feature: [B, W, C] 159 | b, w, c = feature.size() 160 | assert w % num_splits == 0 161 | 162 | b_new = b * num_splits 163 | w_new = w // num_splits 164 | 165 | feature = feature.view(b, num_splits, w // num_splits, c 166 | ).view(b_new, w_new, c) # [B*K, W/K, C] 167 | 168 | return feature 169 | 170 | 171 | def merge_splits_1d(splits, 172 | h, 173 | num_splits=2, 174 | ): 175 | b, w, c = splits.size() 176 | new_b = b // num_splits // h 177 | 178 | splits = splits.view(new_b, h, num_splits, w, c) 179 | merge = splits.view( 180 | new_b, h, num_splits * w, c) # [B, H, W, C] 181 | 182 | return merge 183 | 184 | 185 | def window_partition_1d(x, window_size_w): 186 | """ 187 | Args: 188 | x: (B, W, C) 189 | window_size (int): window size 190 | 191 | Returns: 192 | windows: (num_windows*B, window_size, C) 193 | """ 194 | B, W, C = x.shape 195 | x = x.view(B, W // window_size_w, window_size_w, C).view(-1, window_size_w, C) 196 | return x 197 | 198 | 199 | def generate_shift_window_attn_mask_1d(input_w, window_size_w, 200 | shift_size_w, device=torch.device('cuda')): 201 | # calculate attention mask for SW-MSA 202 | img_mask = torch.zeros((1, input_w, 1)).to(device) # 1 W 1 203 | w_slices = (slice(0, -window_size_w), 204 | slice(-window_size_w, -shift_size_w), 205 | slice(-shift_size_w, None)) 206 | cnt = 0 207 | for w in w_slices: 208 | img_mask[:, w, :] = cnt 209 | cnt += 1 210 | 211 | mask_windows = window_partition_1d(img_mask, window_size_w) # nW, window_size, 1 212 | mask_windows = mask_windows.view(-1, window_size_w) 213 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 
# nW, window_size, window_size 214 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 215 | 216 | return attn_mask 217 | -------------------------------------------------------------------------------- /unimatch/vit_fpn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | # Ref: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py#L363 5 | 6 | 7 | class ViTFeaturePyramid(nn.Module): 8 | """ 9 | This module implements SimpleFeaturePyramid in :paper:`vitdet`. 10 | It creates pyramid features built on top of the input feature map. 11 | """ 12 | 13 | def __init__( 14 | self, 15 | in_channels, 16 | scale_factors, 17 | ): 18 | """ 19 | Args: 20 | scale_factors (list[float]): list of scaling factors to upsample or downsample 21 | the input features for creating pyramid features. 22 | """ 23 | super(ViTFeaturePyramid, self).__init__() 24 | 25 | self.scale_factors = scale_factors 26 | 27 | out_dim = dim = in_channels 28 | self.stages = nn.ModuleList() 29 | for idx, scale in enumerate(scale_factors): 30 | if scale == 4.0: 31 | layers = [ 32 | nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2), 33 | nn.GELU(), 34 | nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2), 35 | ] 36 | out_dim = dim // 4 37 | elif scale == 2.0: 38 | layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)] 39 | out_dim = dim // 2 40 | elif scale == 1.0: 41 | layers = [] 42 | elif scale == 0.5: 43 | layers = [nn.MaxPool2d(kernel_size=2, stride=2)] 44 | else: 45 | raise NotImplementedError(f"scale_factor={scale} is not supported yet.") 46 | 47 | if scale != 1.0: 48 | layers.extend( 49 | [ 50 | nn.GELU(), 51 | nn.Conv2d(out_dim, out_dim, 3, 1, 1), 52 | ] 53 | ) 54 | layers = nn.Sequential(*layers) 55 | 56 | self.stages.append(layers) 57 | 58 | def forward(self, x): 59 | results = [] 60 | 61 | for stage in self.stages: 62 | results.append(stage(x)) 63 | 64 | return results 65 | -------------------------------------------------------------------------------- /utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | # https://github.com/open-mmlab/mmcv/blob/7540cf73ac7e5d1e14d0ffbd9b6759e83929ecfc/mmcv/runner/dist_utils.py 3 | 4 | import os 5 | import subprocess 6 | 7 | import torch 8 | import torch.multiprocessing as mp 9 | from torch import distributed as dist 10 | 11 | 12 | def init_dist(launcher, backend='nccl', **kwargs): 13 | if mp.get_start_method(allow_none=True) is None: 14 | mp.set_start_method('spawn') 15 | if launcher == 'pytorch': 16 | _init_dist_pytorch(backend, **kwargs) 17 | elif launcher == 'mpi': 18 | _init_dist_mpi(backend, **kwargs) 19 | elif launcher == 'slurm': 20 | _init_dist_slurm(backend, **kwargs) 21 | else: 22 | raise ValueError(f'Invalid launcher type: {launcher}') 23 | 24 | 25 | def _init_dist_pytorch(backend, **kwargs): 26 | # TODO: use local_rank instead of rank % num_gpus 27 | rank = int(os.environ['RANK']) 28 | num_gpus = torch.cuda.device_count() 29 | torch.cuda.set_device(rank % num_gpus) 30 | dist.init_process_group(backend=backend, **kwargs) 31 | 32 | 33 | def _init_dist_mpi(backend, **kwargs): 34 | # TODO: use local_rank instead of rank % num_gpus 35 | rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 36 | num_gpus = torch.cuda.device_count() 37 | torch.cuda.set_device(rank % num_gpus) 38 | dist.init_process_group(backend=backend, **kwargs) 39 | 40 | 41 | def _init_dist_slurm(backend, port=None): 42 | """Initialize slurm distributed training environment. 43 | If argument ``port`` is not specified, then the master port will be system 44 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 45 | environment variable, then a default port ``29500`` will be used. 46 | Args: 47 | backend (str): Backend of torch.distributed. 48 | port (int, optional): Master port. Defaults to None. 49 | """ 50 | proc_id = int(os.environ['SLURM_PROCID']) 51 | ntasks = int(os.environ['SLURM_NTASKS']) 52 | node_list = os.environ['SLURM_NODELIST'] 53 | num_gpus = torch.cuda.device_count() 54 | torch.cuda.set_device(proc_id % num_gpus) 55 | addr = subprocess.getoutput( 56 | f'scontrol show hostname {node_list} | head -n1') 57 | # specify master port 58 | if port is not None: 59 | os.environ['MASTER_PORT'] = str(port) 60 | elif 'MASTER_PORT' in os.environ: 61 | pass # use MASTER_PORT in the environment variable 62 | else: 63 | # 29500 is torch.distributed default port 64 | os.environ['MASTER_PORT'] = '29500' 65 | # use MASTER_ADDR in the environment variable if it already exists 66 | if 'MASTER_ADDR' not in os.environ: 67 | os.environ['MASTER_ADDR'] = addr 68 | os.environ['WORLD_SIZE'] = str(ntasks) 69 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 70 | os.environ['RANK'] = str(proc_id) 71 | dist.init_process_group(backend=backend) 72 | 73 | 74 | def get_dist_info(): 75 | # if (TORCH_VERSION != 'parrots' 76 | # and digit_version(TORCH_VERSION) < digit_version('1.0')): 77 | # initialized = dist._initialized 78 | # else: 79 | if dist.is_available(): 80 | initialized = dist.is_initialized() 81 | else: 82 | initialized = False 83 | if initialized: 84 | rank = dist.get_rank() 85 | world_size = dist.get_world_size() 86 | else: 87 | rank = 0 88 | world_size = 1 89 | return rank, world_size 90 | 91 | 92 | # from DETR repo 93 | def setup_for_distributed(is_master): 94 | """ 95 | This function disables printing when not in master process 96 | """ 97 | import builtins as __builtin__ 98 | builtin_print = __builtin__.print 99 | 100 | def print(*args, **kwargs): 101 | force = kwargs.pop('force', False) 102 | if is_master or force: 103 | builtin_print(*args, **kwargs) 104 | 105 | 
__builtin__.print = print 106 | -------------------------------------------------------------------------------- /utils/file_io.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import re 7 | from PIL import Image 8 | import sys 9 | import cv2 10 | import json 11 | import os 12 | 13 | 14 | def read_img(filename): 15 | # convert to RGB for scene flow finalpass data 16 | img = np.array(Image.open(filename).convert('RGB')).astype(np.float32) 17 | return img 18 | 19 | 20 | def read_disp(filename, subset=False, vkitti2=False, sintel=False, 21 | tartanair=False, instereo2k=False, crestereo=False, 22 | fallingthings=False, 23 | argoverse=False, 24 | raw_disp_png=False, 25 | ): 26 | # Scene Flow dataset 27 | if filename.endswith('pfm'): 28 | # For finalpass and cleanpass, gt disparity is positive, subset is negative 29 | disp = np.ascontiguousarray(_read_pfm(filename)[0]) 30 | if subset: 31 | disp = -disp 32 | # VKITTI2 dataset 33 | elif vkitti2: 34 | disp = _read_vkitti2_disp(filename) 35 | # Sintel 36 | elif sintel: 37 | disp = _read_sintel_disparity(filename) 38 | elif tartanair: 39 | disp = _read_tartanair_disp(filename) 40 | elif instereo2k: 41 | disp = _read_instereo2k_disp(filename) 42 | elif crestereo: 43 | disp = _read_crestereo_disp(filename) 44 | elif fallingthings: 45 | disp = _read_fallingthings_disp(filename) 46 | elif argoverse: 47 | disp = _read_argoverse_disp(filename) 48 | elif raw_disp_png: 49 | disp = np.array(Image.open(filename)).astype(np.float32) 50 | # KITTI 51 | elif filename.endswith('png'): 52 | disp = _read_kitti_disp(filename) 53 | elif filename.endswith('npy'): 54 | disp = np.load(filename) 55 | else: 56 | raise Exception('Invalid disparity file format!') 57 | return disp # [H, W] 58 | 59 | 60 | def _read_pfm(file): 61 | file = open(file, 'rb') 62 | 63 | color = None 64 | width = None 65 | height = None 66 | scale = None 67 | endian = None 68 | 69 | header = file.readline().rstrip() 70 | if header.decode("ascii") == 'PF': 71 | color = True 72 | elif header.decode("ascii") == 'Pf': 73 | color = False 74 | else: 75 | raise Exception('Not a PFM file.') 76 | 77 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) 78 | if dim_match: 79 | width, height = list(map(int, dim_match.groups())) 80 | else: 81 | raise Exception('Malformed PFM header.') 82 | 83 | scale = float(file.readline().decode("ascii").rstrip()) 84 | if scale < 0: # little-endian 85 | endian = '<' 86 | scale = -scale 87 | else: 88 | endian = '>' # big-endian 89 | 90 | data = np.fromfile(file, endian + 'f') 91 | shape = (height, width, 3) if color else (height, width) 92 | 93 | data = np.reshape(data, shape) 94 | data = np.flipud(data) 95 | return data, scale 96 | 97 | 98 | def write_pfm(file, image, scale=1): 99 | file = open(file, 'wb') 100 | 101 | color = None 102 | 103 | if image.dtype.name != 'float32': 104 | raise Exception('Image dtype must be float32.') 105 | 106 | image = np.flipud(image) 107 | 108 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 109 | color = True 110 | elif len(image.shape) == 2 or len( 111 | image.shape) == 3 and image.shape[2] == 1: # greyscale 112 | color = False 113 | else: 114 | raise Exception( 115 | 'Image must have H x W x 3, H x W x 1 or H x W dimensions.') 116 | 117 | file.write(b'PF\n' if color else b'Pf\n') 118 | file.write(b'%d %d\n' % 
(image.shape[1], image.shape[0])) 119 | 120 | endian = image.dtype.byteorder 121 | 122 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 123 | scale = -scale 124 | 125 | file.write(b'%f\n' % scale) 126 | 127 | image.tofile(file) 128 | 129 | 130 | def _read_kitti_disp(filename): 131 | depth = np.array(Image.open(filename)) 132 | depth = depth.astype(np.float32) / 256. 133 | return depth 134 | 135 | 136 | def _read_vkitti2_disp(filename): 137 | # read depth 138 | depth = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) # in cm 139 | depth = (depth / 100).astype(np.float32) # depth clipped to 655.35m for sky 140 | 141 | valid = (depth > 0) & (depth < 655) # depth clipped to 655.35m for sky 142 | 143 | # convert to disparity 144 | focal_length = 725.0087 # in pixels 145 | baseline = 0.532725 # meter 146 | 147 | disp = baseline * focal_length / depth 148 | 149 | disp[~valid] = 0.000001 # invalid as very small value 150 | 151 | return disp 152 | 153 | 154 | def _read_sintel_disparity(filename): 155 | """ Return disparity read from filename. """ 156 | f_in = np.array(Image.open(filename)) 157 | 158 | d_r = f_in[:, :, 0].astype('float32') 159 | d_g = f_in[:, :, 1].astype('float32') 160 | d_b = f_in[:, :, 2].astype('float32') 161 | 162 | depth = d_r * 4 + d_g / (2 ** 6) + d_b / (2 ** 14) 163 | return depth 164 | 165 | 166 | def _read_tartanair_disp(filename): 167 | # the infinite distant object such as the sky has a large depth value (e.g. 10000) 168 | depth = np.load(filename) 169 | 170 | # change to disparity image 171 | disparity = 80.0 / depth 172 | 173 | return disparity 174 | 175 | 176 | def _read_instereo2k_disp(filename): 177 | disp = np.array(Image.open(filename)) 178 | disp = disp.astype(np.float32) / 100. 179 | return disp 180 | 181 | 182 | def _read_crestereo_disp(filename): 183 | disp = np.array(Image.open(filename)) 184 | return disp.astype(np.float32) / 32. 185 | 186 | 187 | def _read_fallingthings_disp(filename): 188 | depth = np.array(Image.open(filename)) 189 | camera_file = os.path.join(os.path.dirname(filename), '_camera_settings.json') 190 | with open(camera_file, 'r') as f: 191 | intrinsics = json.load(f) 192 | fx = intrinsics['camera_settings'][0]['intrinsic_settings']['fx'] 193 | disp = (fx * 6.0 * 100) / depth.astype(np.float32) 194 | 195 | return disp 196 | 197 | 198 | def _read_argoverse_disp(filename): 199 | disparity_map = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) 200 | return np.float32(disparity_map) / 256. 
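# Editor's note (illustrative usage sketch, not part of the original file_io.py):
# read_disp() defined earlier in this module dispatches to the dataset-specific
# readers above based on the file extension and the boolean flags. The paths
# below are hypothetical placeholders.
#
#   disp_kitti = read_disp('disp_occ_0/000000_10.png')                      # KITTI .png
#   disp_sceneflow = read_disp('frames_finalpass/0006.pfm')                 # Scene Flow .pfm
#   disp_vkitti2 = read_disp('depth/Camera_0/depth_00000.png', vkitti2=True)
#
# Each call returns a float32 numpy array of disparities with shape [H, W].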
201 | 202 | 203 | def extract_video(video_name): 204 | cap = cv2.VideoCapture(video_name) 205 | assert cap.isOpened(), f'Failed to load video file {video_name}' 206 | # get video info 207 | size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), 208 | int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))) 209 | fps = cap.get(cv2.CAP_PROP_FPS) 210 | 211 | print('video size (hxw): %dx%d' % (size[1], size[0])) 212 | print('fps: %d' % fps) 213 | 214 | imgs = [] 215 | while cap.isOpened(): 216 | # get frames 217 | flag, img = cap.read() 218 | if not flag: 219 | break 220 | # to rgb format 221 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 222 | imgs.append(img) 223 | 224 | return imgs, fps 225 | -------------------------------------------------------------------------------- /utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-08-03 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | from PIL import Image 21 | 22 | 23 | def make_colorwheel(): 24 | ''' 25 | Generates a color wheel for optical flow visualization as presented in: 26 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 27 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 28 | According to the C++ source code of Daniel Scharstein 29 | According to the Matlab source code of Deqing Sun 30 | ''' 31 | 32 | RY = 15 33 | YG = 6 34 | GC = 4 35 | CB = 11 36 | BM = 13 37 | MR = 6 38 | 39 | ncols = RY + YG + GC + CB + BM + MR 40 | colorwheel = np.zeros((ncols, 3)) 41 | col = 0 42 | 43 | # RY 44 | colorwheel[0:RY, 0] = 255 45 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 46 | col = col + RY 47 | # YG 48 | colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 49 | colorwheel[col:col + YG, 1] = 255 50 | col = col + YG 51 | # GC 52 | colorwheel[col:col + GC, 1] = 255 53 | colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 54 | col = col + GC 55 | # CB 56 | colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 57 | colorwheel[col:col + CB, 2] = 255 58 | col = col + CB 59 | # BM 60 | colorwheel[col:col + BM, 2] = 255 61 | colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 62 | col = col + BM 63 | # MR 64 | colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 65 | colorwheel[col:col + MR, 0] = 255 66 | return colorwheel 67 | 68 | 69 | def flow_compute_color(u, v, convert_to_bgr=False): 70 | ''' 71 | Applies the flow color wheel to (possibly clipped) flow components u and v. 
72 | According to the C++ source code of Daniel Scharstein 73 | According to the Matlab source code of Deqing Sun 74 | :param u: np.ndarray, input horizontal flow 75 | :param v: np.ndarray, input vertical flow 76 | :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB 77 | :return: 78 | ''' 79 | 80 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 81 | 82 | colorwheel = make_colorwheel() # shape [55x3] 83 | ncols = colorwheel.shape[0] 84 | 85 | rad = np.sqrt(np.square(u) + np.square(v)) 86 | a = np.arctan2(-v, -u) / np.pi 87 | 88 | fk = (a + 1) / 2 * (ncols - 1) + 1 89 | k0 = np.floor(fk).astype(np.int32) 90 | k1 = k0 + 1 91 | k1[k1 == ncols] = 1 92 | f = fk - k0 93 | 94 | for i in range(colorwheel.shape[1]): 95 | tmp = colorwheel[:, i] 96 | col0 = tmp[k0] / 255.0 97 | col1 = tmp[k1] / 255.0 98 | col = (1 - f) * col0 + f * col1 99 | 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range? 103 | 104 | # Note the 2-i => BGR instead of RGB 105 | ch_idx = 2 - i if convert_to_bgr else i 106 | flow_image[:, :, ch_idx] = np.floor(255 * col) 107 | 108 | return flow_image 109 | 110 | 111 | def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False): 112 | ''' 113 | Expects a two dimensional flow image of shape [H,W,2] 114 | According to the C++ source code of Daniel Scharstein 115 | According to the Matlab source code of Deqing Sun 116 | :param flow_uv: np.ndarray of shape [H,W,2] 117 | :param clip_flow: float, maximum clipping value for flow 118 | :return: 119 | ''' 120 | 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | 124 | if clip_flow is not None: 125 | flow_uv = np.clip(flow_uv, 0, clip_flow) 126 | 127 | u = flow_uv[:, :, 0] 128 | v = flow_uv[:, :, 1] 129 | 130 | rad = np.sqrt(np.square(u) + np.square(v)) 131 | rad_max = np.max(rad) 132 | 133 | epsilon = 1e-5 134 | u = u / (rad_max + epsilon) 135 | v = v / (rad_max + epsilon) 136 | 137 | return flow_compute_color(u, v, convert_to_bgr) 138 | 139 | 140 | UNKNOWN_FLOW_THRESH = 1e7 141 | SMALLFLOW = 0.0 142 | LARGEFLOW = 1e8 143 | 144 | 145 | def make_color_wheel(): 146 | """ 147 | Generate color wheel according Middlebury color code 148 | :return: Color wheel 149 | """ 150 | RY = 15 151 | YG = 6 152 | GC = 4 153 | CB = 11 154 | BM = 13 155 | MR = 6 156 | 157 | ncols = RY + YG + GC + CB + BM + MR 158 | 159 | colorwheel = np.zeros([ncols, 3]) 160 | 161 | col = 0 162 | 163 | # RY 164 | colorwheel[0:RY, 0] = 255 165 | colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY)) 166 | col += RY 167 | 168 | # YG 169 | colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG)) 170 | colorwheel[col:col + YG, 1] = 255 171 | col += YG 172 | 173 | # GC 174 | colorwheel[col:col + GC, 1] = 255 175 | colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC)) 176 | col += GC 177 | 178 | # CB 179 | colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB)) 180 | colorwheel[col:col + CB, 2] = 255 181 | col += CB 182 | 183 | # BM 184 | colorwheel[col:col + BM, 2] = 255 185 | colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM)) 186 | col += + BM 187 | 188 | # MR 189 | colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) 190 | colorwheel[col:col + MR, 0] = 255 191 | 192 | return colorwheel 193 | 
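# Editor's note (illustrative usage sketch, not part of the original flow_viz.py):
# flow_to_color() above turns a [H, W, 2] flow field into the Middlebury-style
# color coding, normalizing by the maximum flow magnitude. The random flow and
# output path below are placeholders.
#
#   flow = np.random.randn(480, 640, 2).astype(np.float32) * 5.0
#   flow_rgb = flow_to_color(flow)            # uint8 array of shape [480, 640, 3]
#   Image.fromarray(flow_rgb).save('flow_vis.png')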
194 | 195 | def compute_color(u, v): 196 | """ 197 | compute optical flow color map 198 | :param u: optical flow horizontal map 199 | :param v: optical flow vertical map 200 | :return: optical flow in color code 201 | """ 202 | [h, w] = u.shape 203 | img = np.zeros([h, w, 3]) 204 | nanIdx = np.isnan(u) | np.isnan(v) 205 | u[nanIdx] = 0 206 | v[nanIdx] = 0 207 | 208 | colorwheel = make_color_wheel() 209 | ncols = np.size(colorwheel, 0) 210 | 211 | rad = np.sqrt(u ** 2 + v ** 2) 212 | 213 | a = np.arctan2(-v, -u) / np.pi 214 | 215 | fk = (a + 1) / 2 * (ncols - 1) + 1 216 | 217 | k0 = np.floor(fk).astype(int) 218 | 219 | k1 = k0 + 1 220 | k1[k1 == ncols + 1] = 1 221 | f = fk - k0 222 | 223 | for i in range(0, np.size(colorwheel, 1)): 224 | tmp = colorwheel[:, i] 225 | col0 = tmp[k0 - 1] / 255 226 | col1 = tmp[k1 - 1] / 255 227 | col = (1 - f) * col0 + f * col1 228 | 229 | idx = rad <= 1 230 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 231 | notidx = np.logical_not(idx) 232 | 233 | col[notidx] *= 0.75 234 | img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx))) 235 | 236 | return img 237 | 238 | 239 | # from https://github.com/gengshan-y/VCN 240 | def flow_to_image(flow): 241 | """ 242 | Convert flow into middlebury color code image 243 | :param flow: optical flow map 244 | :return: optical flow image in middlebury color 245 | """ 246 | u = flow[:, :, 0] 247 | v = flow[:, :, 1] 248 | 249 | maxu = -999. 250 | maxv = -999. 251 | minu = 999. 252 | minv = 999. 253 | 254 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) 255 | u[idxUnknow] = 0 256 | v[idxUnknow] = 0 257 | 258 | maxu = max(maxu, np.max(u)) 259 | minu = min(minu, np.min(u)) 260 | 261 | maxv = max(maxv, np.max(v)) 262 | minv = min(minv, np.min(v)) 263 | 264 | rad = np.sqrt(u ** 2 + v ** 2) 265 | maxrad = max(-1, np.max(rad)) 266 | 267 | u = u / (maxrad + np.finfo(float).eps) 268 | v = v / (maxrad + np.finfo(float).eps) 269 | 270 | img = compute_color(u, v) 271 | 272 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) 273 | img[idx] = 0 274 | 275 | return np.uint8(img) 276 | 277 | 278 | def save_vis_flow_tofile(flow, output_path): 279 | vis_flow = flow_to_image(flow) 280 | Image.fromarray(vis_flow).save(output_path) 281 | 282 | 283 | def flow_tensor_to_image(flow): 284 | """Used for tensorboard visualization""" 285 | flow = flow.permute(1, 2, 0) # [H, W, 2] 286 | flow = flow.detach().cpu().numpy() 287 | flow = flow_to_image(flow) # [H, W, 3] 288 | flow = np.transpose(flow, (2, 0, 1)) # [3, H, W] 289 | 290 | return flow 291 | -------------------------------------------------------------------------------- /utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | import cv2 6 | 7 | TAG_CHAR = np.array([202021.25], np.float32) 8 | 9 | 10 | def readFlow(fn): 11 | """ Read .flo file in Middlebury format""" 12 | # Code adapted from: 13 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 14 | 15 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 16 | # print 'fn = %s'%(fn) 17 | with open(fn, 'rb') as f: 18 | magic = np.fromfile(f, np.float32, count=1) 19 | if 202021.25 != magic: 20 | print('Magic number incorrect. 
Invalid .flo file') 21 | return None 22 | else: 23 | w = np.fromfile(f, np.int32, count=1) 24 | h = np.fromfile(f, np.int32, count=1) 25 | # print 'Reading %d x %d flo file\n' % (w, h) 26 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 27 | # Reshape testdata into 3D array (columns, rows, bands) 28 | # The reshape here is for visualization, the original code is (w,h,2) 29 | return np.resize(data, (int(h), int(w), 2)) 30 | 31 | 32 | def readPFM(file): 33 | file = open(file, 'rb') 34 | 35 | color = None 36 | width = None 37 | height = None 38 | scale = None 39 | endian = None 40 | 41 | header = file.readline().rstrip() 42 | if header == b'PF': 43 | color = True 44 | elif header == b'Pf': 45 | color = False 46 | else: 47 | raise Exception('Not a PFM file.') 48 | 49 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 50 | if dim_match: 51 | width, height = map(int, dim_match.groups()) 52 | else: 53 | raise Exception('Malformed PFM header.') 54 | 55 | scale = float(file.readline().rstrip()) 56 | if scale < 0: # little-endian 57 | endian = '<' 58 | scale = -scale 59 | else: 60 | endian = '>' # big-endian 61 | 62 | data = np.fromfile(file, endian + 'f') 63 | shape = (height, width, 3) if color else (height, width) 64 | 65 | data = np.reshape(data, shape) 66 | data = np.flipud(data) 67 | return data 68 | 69 | 70 | def writeFlow(filename, uv, v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert (uv.ndim == 3) 81 | assert (uv.shape[2] == 2) 82 | u = uv[:, :, 0] 83 | v = uv[:, :, 1] 84 | else: 85 | u = uv 86 | 87 | assert (u.shape == v.shape) 88 | height, width = u.shape 89 | f = open(filename, 'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width * nBands)) 96 | tmp[:, np.arange(width) * 2] = u 97 | tmp[:, np.arange(width) * 2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) 104 | flow = flow[:, :, ::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2 ** 15) / 64.0 107 | return flow, valid 108 | 109 | 110 | def readDispKITTI(filename): 111 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 112 | valid = disp > 0.0 113 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 114 | return flow, valid 115 | 116 | 117 | def writeFlowKITTI(filename, uv): 118 | uv = 64.0 * uv + 2 ** 15 119 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 120 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 121 | cv2.imwrite(filename, uv[..., ::-1]) 122 | 123 | 124 | def read_gen(file_name, pil=False): 125 | ext = splitext(file_name)[-1] 126 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 127 | return Image.open(file_name) 128 | elif ext == '.bin' or ext == '.raw': 129 | return np.load(file_name) 130 | elif ext == '.flo': 131 | return readFlow(file_name).astype(np.float32) 132 | elif ext == '.pfm': 133 | flow = readPFM(file_name).astype(np.float32) 134 | if len(flow.shape) == 2: 135 | return flow 136 | else: 137 | return flow[:, :, :-1] 138 | return [] 139 | 140 | 141 | def read_vkitti2_flow(filename): 142 | # In 
R, flow along x-axis normalized by image width and quantized to [0;2^16 – 1] 143 | # In G, flow along y-axis normalized by image height and quantized to [0;2^16 – 1] 144 | # B = 0 for invalid flow (e.g., sky pixels) 145 | bgr = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) 146 | h, w, _c = bgr.shape 147 | assert bgr.dtype == np.uint16 and _c == 3 148 | # b == invalid flow flag == 0 for sky or other invalid flow 149 | invalid = bgr[:, :, 0] == 0 150 | # g,r == flow_y,x normalized by height,width and scaled to [0;2**16 – 1] 151 | out_flow = 2.0 / (2 ** 16 - 1.0) * bgr[:, :, 2:0:-1].astype('f4') - 1 # [H, W, 2] 152 | out_flow[..., 0] *= (w - 1) 153 | out_flow[..., 1] *= (h - 1) 154 | 155 | out_flow[invalid] = 0.000001 # invalid as very small value to add supervision on the sky 156 | valid = (np.logical_or(invalid, ~invalid)).astype(np.float32) 157 | 158 | return out_flow, valid 159 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils.flow_viz import flow_tensor_to_image 4 | from .visualization import viz_depth_tensor 5 | 6 | 7 | class Logger: 8 | def __init__(self, lr_scheduler, 9 | summary_writer, 10 | summary_freq=100, 11 | start_step=0, 12 | img_mean=None, 13 | img_std=None, 14 | ): 15 | self.lr_scheduler = lr_scheduler 16 | self.total_steps = start_step 17 | self.running_loss = {} 18 | self.summary_writer = summary_writer 19 | self.summary_freq = summary_freq 20 | 21 | self.img_mean = img_mean 22 | self.img_std = img_std 23 | 24 | def print_training_status(self, mode='train', is_depth=False): 25 | if is_depth: 26 | print('step: %06d \t loss: %.3f' % (self.total_steps, self.running_loss['total_loss'] / self.summary_freq)) 27 | else: 28 | print('step: %06d \t epe: %.3f' % (self.total_steps, self.running_loss['epe'] / self.summary_freq)) 29 | 30 | for k in self.running_loss: 31 | self.summary_writer.add_scalar(mode + '/' + k, 32 | self.running_loss[k] / self.summary_freq, self.total_steps) 33 | self.running_loss[k] = 0.0 34 | 35 | def lr_summary(self): 36 | lr = self.lr_scheduler.get_last_lr()[0] 37 | self.summary_writer.add_scalar('lr', lr, self.total_steps) 38 | 39 | def add_image_summary(self, img1, img2, flow_preds=None, flow_gt=None, mode='train', 40 | is_depth=False, 41 | ): 42 | if self.total_steps % self.summary_freq == 0: 43 | if is_depth: 44 | img1 = self.unnormalize_image(img1.detach().cpu()) # [3, H, W], range [0, 1] 45 | img2 = self.unnormalize_image(img2.detach().cpu()) 46 | 47 | concat = torch.cat((img1, img2), dim=-1) # [3, H, W*2] 48 | 49 | self.summary_writer.add_image(mode + '/img', concat, self.total_steps) 50 | else: 51 | img_concat = torch.cat((img1[0].detach().cpu(), img2[0].detach().cpu()), dim=-1) 52 | img_concat = img_concat.type(torch.uint8) # convert to uint8 to visualize in tensorboard 53 | 54 | flow_pred = flow_tensor_to_image(flow_preds[-1][0]) 55 | forward_flow_gt = flow_tensor_to_image(flow_gt[0]) 56 | flow_concat = torch.cat((torch.from_numpy(flow_pred), 57 | torch.from_numpy(forward_flow_gt)), dim=-1) 58 | 59 | concat = torch.cat((img_concat, flow_concat), dim=-2) 60 | 61 | self.summary_writer.add_image(mode + '/img_pred_gt', concat, self.total_steps) 62 | 63 | def add_depth_summary(self, depth_pred, depth_gt, mode='train'): 64 | # assert depth_pred.dim() == 2 # [H, W] 65 | if self.total_steps % self.summary_freq == 0 or 'val' in mode: 66 | pred_viz = 
viz_depth_tensor(depth_pred.detach().cpu()) # [3, H, W] 67 | gt_viz = viz_depth_tensor(depth_gt.detach().cpu()) 68 | 69 | concat = torch.cat((pred_viz, gt_viz), dim=-1) # [3, H, W*2] 70 | 71 | self.summary_writer.add_image(mode + '/depth_pred_gt', concat, self.total_steps) 72 | 73 | def unnormalize_image(self, img): 74 | # img: [3, H, W], used for visualizing image 75 | mean = torch.tensor(self.img_mean).view(3, 1, 1).type_as(img) 76 | std = torch.tensor(self.img_std).view(3, 1, 1).type_as(img) 77 | 78 | out = img * std + mean 79 | 80 | return out 81 | 82 | def push(self, metrics, mode='train', is_depth=False, ): 83 | self.total_steps += 1 84 | 85 | self.lr_summary() 86 | 87 | for key in metrics: 88 | if key not in self.running_loss: 89 | self.running_loss[key] = 0.0 90 | 91 | self.running_loss[key] += metrics[key] 92 | 93 | if self.total_steps % self.summary_freq == 0: 94 | self.print_training_status(mode, is_depth=is_depth) 95 | self.running_loss = {} 96 | 97 | def write_dict(self, results): 98 | for key in results: 99 | tag = key.split('_')[0] 100 | tag = tag + '/' + key 101 | self.summary_writer.add_scalar(tag, results[key], self.total_steps) 102 | 103 | def close(self): 104 | self.summary_writer.close() 105 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | 5 | 6 | def read_text_lines(filepath): 7 | with open(filepath, 'r') as f: 8 | lines = f.readlines() 9 | lines = [l.rstrip() for l in lines] 10 | return lines 11 | 12 | 13 | def check_path(path): 14 | if not os.path.exists(path): 15 | os.makedirs(path, exist_ok=True) # explicitly set exist_ok when multi-processing 16 | 17 | 18 | def save_command(save_path, filename='command_train.txt'): 19 | check_path(save_path) 20 | command = sys.argv 21 | save_file = os.path.join(save_path, filename) 22 | # Save all training commands when resuming training 23 | with open(save_file, 'a') as f: 24 | f.write(' '.join(command)) 25 | f.write('\n\n') 26 | 27 | 28 | def save_args(args, filename='args.json'): 29 | args_dict = vars(args) 30 | check_path(args.checkpoint_dir) 31 | save_path = os.path.join(args.checkpoint_dir, filename) 32 | 33 | # save all training args when resuming training 34 | with open(save_path, 'a') as f: 35 | json.dump(args_dict, f, indent=4, sort_keys=False) 36 | f.write('\n\n') 37 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | class InputPadder: 7 | """ Pads images such that dimensions are divisible by 8 """ 8 | 9 | def __init__(self, dims, mode='sintel', padding_factor=8): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor 12 | pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor 13 | if mode == 'sintel': 14 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2] 15 | else: 16 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self, x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] 24 | return x[..., c[0]:c[1], 
c[2]:c[3]] 25 | 26 | 27 | def bilinear_sampler(img, coords, mode='bilinear', mask=False, padding_mode='zeros'): 28 | """ Wrapper for grid_sample, uses pixel coordinates """ 29 | if coords.size(-1) != 2: # [B, 2, H, W] -> [B, H, W, 2] 30 | coords = coords.permute(0, 2, 3, 1) 31 | 32 | H, W = img.shape[-2:] 33 | # H = height if height is not None else img.shape[-2] 34 | # W = width if width is not None else img.shape[-1] 35 | 36 | xgrid, ygrid = coords.split([1, 1], dim=-1) 37 | 38 | # To handle H or W equals to 1 by explicitly defining height and width 39 | if H == 1: 40 | assert ygrid.abs().max() < 1e-8 41 | H = 10 42 | if W == 1: 43 | assert xgrid.abs().max() < 1e-8 44 | W = 10 45 | 46 | xgrid = 2 * xgrid / (W - 1) - 1 47 | ygrid = 2 * ygrid / (H - 1) - 1 48 | 49 | grid = torch.cat([xgrid, ygrid], dim=-1) 50 | img = F.grid_sample(img, grid, mode=mode, 51 | padding_mode=padding_mode, 52 | align_corners=True) 53 | 54 | if mask: 55 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 56 | return img, mask.squeeze(-1).float() 57 | 58 | return img 59 | 60 | 61 | def coords_grid(batch, ht, wd, normalize=False): 62 | if normalize: # [-1, 1] 63 | coords = torch.meshgrid(2 * torch.arange(ht) / (ht - 1) - 1, 64 | 2 * torch.arange(wd) / (wd - 1) - 1) 65 | else: 66 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 67 | coords = torch.stack(coords[::-1], dim=0).float() 68 | return coords[None].repeat(batch, 1, 1, 1) # [B, 2, H, W] 69 | 70 | 71 | def coords_grid_np(h, w): # used for accumulating high speed sintel flow testdata 72 | coords = np.meshgrid(np.arange(h, dtype=np.float32), 73 | np.arange(w, dtype=np.float32), indexing='ij') 74 | coords = np.stack(coords[::-1], axis=-1) # [H, W, 2] 75 | 76 | return coords 77 | 78 | 79 | def compute_out_of_boundary_mask(flow, downsample_factor=None): 80 | # flow: [B, 2, H, W] 81 | assert flow.dim() == 4 and flow.size(1) == 2 82 | b, _, h, w = flow.shape 83 | init_coords = coords_grid(b, h, w).to(flow.device) 84 | corres = init_coords + flow # [B, 2, H, W] 85 | 86 | if downsample_factor is not None: 87 | assert w % downsample_factor == 0 and h % downsample_factor == 0 88 | # the actual max disp can predict is in the downsampled feature resolution, then upsample 89 | max_w = (w // downsample_factor - 1) * downsample_factor 90 | max_h = (h // downsample_factor - 1) * downsample_factor 91 | # print('max_w: %d, max_h: %d' % (max_w, max_h)) 92 | else: 93 | max_w = w - 1 94 | max_h = h - 1 95 | 96 | valid_mask = (corres[:, 0] >= 0) & (corres[:, 0] <= max_w) & (corres[:, 1] >= 0) & (corres[:, 1] <= max_h) 97 | 98 | # in case very large flow 99 | flow_mask = (flow[:, 0].abs() <= max_w) & (flow[:, 1].abs() <= max_h) 100 | 101 | valid_mask = valid_mask & flow_mask 102 | 103 | return valid_mask # [B, H, W] 104 | 105 | 106 | def normalize_coords(grid): 107 | """Normalize coordinates of image scale to [-1, 1] 108 | Args: 109 | grid: [B, 2, H, W] 110 | """ 111 | assert grid.size(1) == 2 112 | h, w = grid.size()[2:] 113 | grid[:, 0, :, :] = 2 * (grid[:, 0, :, :].clone() / (w - 1)) - 1 # x: [-1, 1] 114 | grid[:, 1, :, :] = 2 * (grid[:, 1, :, :].clone() / (h - 1)) - 1 # y: [-1, 1] 115 | # grid = grid.permute((0, 2, 3, 1)) # [B, H, W, 2] 116 | return grid 117 | 118 | 119 | def flow_warp(feature, flow, mask=False, padding_mode='zeros'): 120 | b, c, h, w = feature.size() 121 | assert flow.size(1) == 2 122 | 123 | grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] 124 | 125 | return bilinear_sampler(feature, grid, mask=mask, 
padding_mode=padding_mode) 126 | 127 | 128 | def upflow8(flow, mode='bilinear'): 129 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 130 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 131 | 132 | 133 | def bilinear_upflow(flow, scale_factor=8): 134 | assert flow.size(1) == 2 135 | flow = F.interpolate(flow, scale_factor=scale_factor, 136 | mode='bilinear', align_corners=True) * scale_factor 137 | 138 | return flow 139 | 140 | 141 | def upsample_flow(flow, img): 142 | if flow.size(-1) != img.size(-1): 143 | scale_factor = img.size(-1) / flow.size(-1) 144 | flow = F.interpolate(flow, size=img.size()[-2:], 145 | mode='bilinear', align_corners=True) * scale_factor 146 | return flow 147 | 148 | 149 | def count_parameters(model): 150 | num = sum(p.numel() for p in model.parameters() if p.requires_grad) 151 | return num 152 | 153 | 154 | def set_bn_eval(m): 155 | classname = m.__class__.__name__ 156 | if classname.find('BatchNorm') != -1: 157 | m.eval() 158 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import numpy as np 4 | import torchvision.utils as vutils 5 | import cv2 6 | from matplotlib.cm import get_cmap 7 | import matplotlib as mpl 8 | import matplotlib.cm as cm 9 | 10 | 11 | def vis_disparity(disp): 12 | disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0 13 | disp_vis = disp_vis.astype("uint8") 14 | disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) 15 | 16 | return disp_vis 17 | 18 | 19 | def gen_error_colormap(): 20 | cols = np.array( 21 | [[0 / 3.0, 0.1875 / 3.0, 49, 54, 149], 22 | [0.1875 / 3.0, 0.375 / 3.0, 69, 117, 180], 23 | [0.375 / 3.0, 0.75 / 3.0, 116, 173, 209], 24 | [0.75 / 3.0, 1.5 / 3.0, 171, 217, 233], 25 | [1.5 / 3.0, 3 / 3.0, 224, 243, 248], 26 | [3 / 3.0, 6 / 3.0, 254, 224, 144], 27 | [6 / 3.0, 12 / 3.0, 253, 174, 97], 28 | [12 / 3.0, 24 / 3.0, 244, 109, 67], 29 | [24 / 3.0, 48 / 3.0, 215, 48, 39], 30 | [48 / 3.0, np.inf, 165, 0, 38]], dtype=np.float32) 31 | cols[:, 2: 5] /= 255. 32 | return cols 33 | 34 | 35 | def disp_error_img(D_est_tensor, D_gt_tensor, abs_thres=3., rel_thres=0.05, dilate_radius=1): 36 | D_gt_np = D_gt_tensor.detach().cpu().numpy() 37 | D_est_np = D_est_tensor.detach().cpu().numpy() 38 | B, H, W = D_gt_np.shape 39 | # valid mask 40 | mask = D_gt_np > 0 41 | # error in percentage. When error <= 1, the pixel is valid since <= 3px & 5% 42 | error = np.abs(D_gt_np - D_est_np) 43 | error[np.logical_not(mask)] = 0 44 | error[mask] = np.minimum(error[mask] / abs_thres, (error[mask] / D_gt_np[mask]) / rel_thres) 45 | # get colormap 46 | cols = gen_error_colormap() 47 | # create error image 48 | error_image = np.zeros([B, H, W, 3], dtype=np.float32) 49 | for i in range(cols.shape[0]): 50 | error_image[np.logical_and(error >= cols[i][0], error < cols[i][1])] = cols[i, 2:] 51 | # TODO: imdilate 52 | # error_image = cv2.imdilate(D_err, strel('disk', dilate_radius)); 53 | error_image[np.logical_not(mask)] = 0. 
54 | # show color tag in the top-left corner of the image 55 | for i in range(cols.shape[0]): 56 | distance = 20 57 | error_image[:, :10, i * distance:(i + 1) * distance, :] = cols[i, 2:] 58 | 59 | return torch.from_numpy(np.ascontiguousarray(error_image.transpose([0, 3, 1, 2]))) 60 | 61 | 62 | def save_images(logger, mode_tag, images_dict, global_step): 63 | images_dict = tensor2numpy(images_dict) 64 | for tag, values in images_dict.items(): 65 | if not isinstance(values, list) and not isinstance(values, tuple): 66 | values = [values] 67 | for idx, value in enumerate(values): 68 | if len(value.shape) == 3: 69 | value = value[:, np.newaxis, :, :] 70 | value = value[:1] 71 | value = torch.from_numpy(value) 72 | 73 | image_name = '{}/{}'.format(mode_tag, tag) 74 | if len(values) > 1: 75 | image_name = image_name + "_" + str(idx) 76 | logger.add_image(image_name, vutils.make_grid(value, padding=0, nrow=1, normalize=True, scale_each=True), 77 | global_step) 78 | 79 | 80 | def tensor2numpy(var_dict): 81 | for key, vars in var_dict.items(): 82 | if isinstance(vars, np.ndarray): 83 | var_dict[key] = vars 84 | elif isinstance(vars, torch.Tensor): 85 | var_dict[key] = vars.data.cpu().numpy() 86 | else: 87 | raise NotImplementedError("invalid input type for tensor2numpy") 88 | 89 | return var_dict 90 | 91 | 92 | def viz_depth_tensor(disp, return_numpy=False, colormap='plasma'): 93 | # visualize inverse depth 94 | assert isinstance(disp, torch.Tensor) 95 | 96 | disp = disp.numpy() 97 | vmax = np.percentile(disp, 95) 98 | normalizer = mpl.colors.Normalize(vmin=disp.min(), vmax=vmax) 99 | mapper = cm.ScalarMappable(norm=normalizer, cmap=colormap) 100 | colormapped_im = (mapper.to_rgba(disp)[:, :, :3] * 255).astype(np.uint8) # [H, W, 3] 101 | 102 | if return_numpy: 103 | return colormapped_im 104 | 105 | viz = torch.from_numpy(colormapped_im).permute(2, 0, 1) # [3, H, W] 106 | 107 | return viz 108 | --------------------------------------------------------------------------------
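# Editor's note (illustrative addendum for utils/visualization.py, not part of the
# original file): viz_depth_tensor() colorizes an inverse-depth / disparity tensor for
# tensorboard, and vis_disparity() colorizes a numpy disparity map with OpenCV. The
# random input below is a placeholder for a real prediction.
#
#   import torch
#   from utils.visualization import viz_depth_tensor, vis_disparity
#
#   disp = torch.rand(240, 320)                       # [H, W] inverse depth
#   viz = viz_depth_tensor(disp)                      # uint8 tensor of shape [3, H, W]
#   rgb = viz_depth_tensor(disp, return_numpy=True)   # uint8 array of shape [H, W, 3]
#   bgr = vis_disparity(disp.numpy())                 # uint8 BGR array of shape [H, W, 3]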