├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── configs ├── kitti_aug+hg+mf.json ├── kitti_aug+hg.json ├── kitti_aug.json ├── kitti_base.json ├── sintel_aug+hg+mf.json ├── sintel_aug+hg.json ├── sintel_aug.json └── sintel_base.json ├── datasets ├── flow_datasets.py └── get_dataset.py ├── intro.png ├── losses ├── flow_loss.py ├── get_loss.py └── loss_blocks.py ├── models ├── correlation_native.py ├── correlation_package │ ├── __init__.py │ ├── correlation.py │ ├── correlation_cuda.cc │ ├── correlation_cuda_kernel.cu │ ├── correlation_cuda_kernel.cuh │ └── setup.py ├── get_model.py └── pwclite.py ├── sam_inference.py ├── test.py ├── train.py ├── trainer ├── base_trainer.py ├── get_trainer.py ├── kitti_trainer_ar.py ├── object_cache.py └── sintel_trainer_ar.py ├── transforms ├── ar_transforms │ ├── ap_transforms.py │ ├── interpolation.py │ ├── oc_transforms.py │ └── sp_transforms.py ├── co_transforms.py └── input_transforms.py └── utils ├── config_parser.py ├── flow_utils.py ├── logger.py ├── misc_utils.py ├── torch_utils.py └── warp_utils.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 
50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to UnSAMFlow 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you haven't already, complete the Contributor License Agreement ("CLA"). 10 | 11 | ## Contributor License Agreement ("CLA") 12 | In order to accept your pull request, we need you to submit a CLA. You only need 13 | to do this once to work on any of Meta's open source projects. 14 | 15 | Complete your CLA here: 16 | 17 | ## Issues 18 | We use GitHub issues to track public bugs. Please ensure your description is 19 | clear and has sufficient instructions to be able to reproduce the issue. 20 | 21 | ## License 22 | By contributing to UnSAMFlow, you agree that your contributions will be licensed 23 | under the LICENSE file in the root directory of this source tree. 24 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | The majority of UnSAMFlow is licensed under CC-BY-NC, however portions of the project are available under separate license terms: SemARFlow and ARFlow are licensed under the MIT license. 
4 | 5 | --- 6 | 7 | MIT License 8 | 9 | Copyright (c) 2022 duke-vision 10 | 11 | Permission is hereby granted, free of charge, to any person obtaining a copy 12 | of this software and associated documentation files (the "Software"), to deal 13 | in the Software without restriction, including without limitation the rights 14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | copies of the Software, and to permit persons to whom the Software is 16 | furnished to do so, subject to the following conditions: 17 | 18 | The above copyright notice and this permission notice shall be included in all 19 | copies or substantial portions of the Software. 20 | 21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 27 | SOFTWARE. 28 | 29 | --- 30 | 31 | MIT License 32 | 33 | Copyright (c) 2020 Liang Liu 34 | 35 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 36 | 37 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 38 | 39 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UnSAMFlow: Unsupervised Optical Flow Guided by Segment Anything Model (CVPR 2024) 2 | 3 | ![Python 3.10.9](https://img.shields.io/badge/Python-3.10.9-brightgreen?style=plastic) ![PyTorch 2.2.0a0](https://img.shields.io/badge/PyTorch-2.2.0a0-brightgreen?style=plastic) ![CUDA 12.0](https://img.shields.io/badge/CUDA-12.0-brightgreen?style=plastic) 4 | 5 | This repository contains the PyTorch implementation of our paper titled *UnSAMFlow: Unsupervised Optical Flow Guided by Segment Anything Model*, accepted by CVPR 2024. 6 | 7 | __Authors__: Shuai Yuan, Lei Luo, Zhuo Hui, Can Pu, Xiaoyu Xiang, Rakesh Ranjan, Denis Demandolx. 8 | 9 | ![demo image](intro.png) 10 | 11 | ## Disclaimer 12 | Our code is developed on our internal AI platform and has not been tested on regular Linux systems. Some of the code depends on internal tools and packages that we cannot share here, but we describe workarounds below. 13 | 14 | - We use our own internal correlation module, which is not included in this repo. 
As a workaround, we provide two optional correlation modules in the code: 15 | 16 | 1. A correlation package (included in [models/correlation_package](./models/correlation_package/)) that you may need to install. Instructions can be found [here](https://github.com/duke-vision/semantic-unsup-flow-release#environment). After you install the package, uncomment [models/pwclite.py, Line 6](./models/pwclite.py#L6) and [Line 207-215](./models/pwclite.py#L207) to enable it. 17 | 2. A naive correlation implementation in PyTorch. This naive implementation can be very slow and should only be used if the correlation package above does not work for you. Simply uncomment [models/pwclite.py, Line 7](./models/pwclite.py#L7) and [Line 207-215](./models/pwclite.py#L207) to enable it. 18 | 19 | - We use our own file systems for all I/O operations. You may need to adjust the code to define your own directories and input/output streams based on your file system. 20 | 21 | ## Datasets 22 | 23 | Due to copyright issues, please download the datasets from the official websites. 24 | 25 | - **Sintel**: [Sintel clean+final](http://sintel.is.tue.mpg.de/downloads); [Sintel raw](https://github.com/lliuz/ARFlow#datasets-in-the-paper) (prepared by the ARFlow authors; please follow the instructions on that page to download). 26 | 27 | - **KITTI**: [KITTI 2015](http://www.cvlibs.net/download.php?file=data_scene_flow_multiview.zip); [KITTI 2012](http://www.cvlibs.net/download.php?file=data_stereo_flow_multiview.zip); [KITTI raw](http://www.cvlibs.net/datasets/kitti/raw_data.php). 28 | 29 | 30 | ## Segment Anything Model 31 | Please follow the [official repo](https://github.com/facebookresearch/segment-anything) to infer SAM masks for all samples. We use the default ViT-H SAM model. The code for generating full segmentation and finding key objects from SAM masks is also included in [sam_inference.py](./sam_inference.py); an illustrative sketch of the mask-generation step is given at the end of this README. 32 | 33 | 34 | ## Usage 35 | 36 | We provide scripts and code to run each of our experiments. Before running the experiments, please redefine the input/output directories in the scripts. For each row in our final result tables (Tabs. 1-2 in the paper), run the corresponding command below to reproduce those results. 
37 | 38 | **Ours (baseline)**: 39 | 40 | ```shell 41 | # KITTI 42 | python3 train.py -c configs/kitti_base.json --n_gpu=N_GPU --exp_folder=EXP_FOLDER 43 | 44 | # Sintel 45 | python3 train.py -c configs/sintel_base.json --n_gpu=N_GPU --exp_folder=EXP_FOLDER 46 | ``` 47 | 48 | **Ours (+aug)**: 49 | 50 | ```shell 51 | # KITTI 52 | python3 train.py -c configs/kitti_aug.json --n_gpu=N_GPU --exp_folder=EXP_FOLDER 53 | 54 | # Sintel 55 | python3 train.py -c configs/sintel_aug.json --n_gpu=N_GPU --exp_folder=EXP_FOLDER 56 | ``` 57 | 58 | **Ours (+aug +hg)**: 59 | 60 | ```shell 61 | # KITTI 62 | python3 train.py -c configs/kitti_aug+hg.json --n_gpu=N_GPU --exp_folder=EXP_FOLDER 63 | 64 | # Sintel 65 | python3 train.py -c configs/sintel_aug+hg.json --n_gpu=N_GPU --exp_folder=EXP_FOLDER 66 | ``` 67 | 68 | **Ours (+aug +hg +mf)**: 69 | 70 | ```shell 71 | # KITTI 72 | python3 train.py -c configs/kitti_aug+hg+mf.json --n_gpu=N_GPU --exp_folder=EXP_FOLDER 73 | 74 | # Sintel 75 | python3 train.py -c configs/sintel_aug+hg+mf.json --n_gpu=N_GPU --exp_folder=EXP_FOLDER 76 | ``` 77 | 78 | ## Code credits 79 | The overall structure of this code is adapted from the official [SemARFlow GitHub repo](https://github.com/duke-vision/semantic-unsup-flow-release), which appeared in their publication [SemARFlow: Injecting Semantics into Unsupervised Optical Flow Estimation for Autonomous Driving](https://openaccess.thecvf.com/content/ICCV2023/papers/Yuan_SemARFlow_Injecting_Semantics_into_Unsupervised_Optical_Flow_Estimation_for_Autonomous_ICCV_2023_paper.pdf). 80 | 81 | ## License 82 | 83 | The majority of UnSAMFlow is licensed under CC-BY-NC, however portions of the project are available under separate license terms: SemARFlow and ARFlow are licensed under the MIT license. 84 | 85 | Copyright (c) Meta Platforms, Inc. and affiliates. 
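For reference, here is an illustrative sketch of the SAM mask-generation step mentioned in the Segment Anything Model section above. It is only a minimal sketch using the official `segment-anything` API; the checkpoint filename, the example paths, and the label-map post-processing are assumptions and may differ from the actual pipeline in [sam_inference.py](./sam_inference.py).

```python
# Minimal sketch of SAM full-segmentation inference (assumptions: checkpoint file
# name, file paths, and label-map post-processing; see sam_inference.py for the
# code actually used in this project).
import imageio.v2 as imageio
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").to("cuda")
mask_generator = SamAutomaticMaskGenerator(sam)

image = imageio.imread("example_frame.png")  # H x W x 3, uint8, RGB
masks = mask_generator.generate(image)       # list of dicts with "segmentation", "area", ...

# Build a single-channel "full segmentation" label map (assuming < 256 masks):
# larger masks are written first so that smaller (finer) masks overwrite them;
# label 0 marks unmasked pixels.
full_seg = np.zeros(image.shape[:2], dtype=np.uint8)
for label, m in enumerate(sorted(masks, key=lambda m: m["area"], reverse=True), start=1):
    full_seg[m["segmentation"]] = label
imageio.imwrite("example_frame_full_seg.png", full_seg)
```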
86 | -------------------------------------------------------------------------------- /configs/kitti_aug+hg+mf.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_configs": "kitti_base.json", 3 | "model": { 4 | "add_mask_corr": true, 5 | "aggregation_type": "concat" 6 | }, 7 | "train": { 8 | "stage2": { 9 | "epoch": 150, 10 | "loss": { 11 | "ransac_threshold": 0.5, 12 | "smooth_type": "homography", 13 | "w_sm": 0.1 14 | }, 15 | "train": { 16 | "key_obj_aug": true, 17 | "key_obj_count": 3, 18 | "w_ar": 0.1 19 | } 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /configs/kitti_aug+hg.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_configs": "kitti_base.json", 3 | "train": { 4 | "stage2": { 5 | "epoch": 150, 6 | "loss": { 7 | "ransac_threshold": 0.5, 8 | "smooth_type": "homography", 9 | "w_sm": 0.1 10 | }, 11 | "train": { 12 | "key_obj_aug": true, 13 | "key_obj_count": 3, 14 | "w_ar": 0.1 15 | } 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /configs/kitti_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_configs": "kitti_base.json", 3 | "train": { 4 | "stage2": { 5 | "epoch": 150, 6 | "train": { 7 | "key_obj_aug": true, 8 | "key_obj_count": 3, 9 | "w_ar": 0.1 10 | } 11 | } 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /configs/kitti_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "at_cfg": { 4 | "cj": true, 5 | "cj_bri": 0.3, 6 | "cj_con": 0.3, 7 | "cj_hue": 0.1, 8 | "cj_sat": 0.3, 9 | "gamma": false, 10 | "gblur": true 11 | }, 12 | "data_aug": { 13 | "crop": false, 14 | "hflip": true, 15 | "swap": true 16 | }, 17 | "epoches_raw": 100, 18 | "epoches_mv": -1, 19 | "root_kitti12": "YOUR_DIR/KITTI-2012/training/", 20 | "root_kitti15": "YOUR_DIR/KITTI-2015/training/", 21 | "root_raw": "YOUR_DIR/KITTI-raw/", 22 | "full_seg_root_kitti12": "YOUR_DIR/full_seg/KITTI-2012/training/", 23 | "full_seg_root_kitti15": "YOUR_DIR/full_seg/KITTI-2015/training/", 24 | "full_seg_root_raw": "YOUR_DIR/full_seg/KITTI-raw/", 25 | "key_obj_root_kitti12": "YOUR_DIR/key_objects/KITTI-2012/training/", 26 | "key_obj_root_kitti15": "YOUR_DIR/key_objects/KITTI-2015/training/", 27 | "key_obj_root_raw": "YOUR_DIR/key_objects/KITTI-raw/", 28 | "run_at": true, 29 | "test_shape": [256, 832], 30 | "train_shape": [256, 832], 31 | "type": "KITTI_Raw+MV_2stage" 32 | }, 33 | "loss": { 34 | "edge_aware_alpha": 10, 35 | "occ_from_back": true, 36 | "smooth_type": "2nd", 37 | "smooth_edge": "image", 38 | "type": "unflow", 39 | "w_l1": 0.15, 40 | "w_ph_scales": [1.0, 1.0, 1.0, 1.0, 0.0], 41 | "w_sm": 0, 42 | "w_ssim": 0.85, 43 | "w_ternary": 0.0, 44 | "warp_pad": "border", 45 | "with_bk": true 46 | }, 47 | "model": { 48 | "learned_upsampler": true, 49 | "reduce_dense": true, 50 | "type": "pwclite" 51 | }, 52 | "seed": 42, 53 | "train": { 54 | "ar_eps": 0.0, 55 | "ar_q": 1.0, 56 | "batch_size": 8, 57 | "beta": 0.999, 58 | "bias_decay": 0, 59 | "epoch_num": 200, 60 | "epoch_size": 1000, 61 | "key_obj_aug": false, 62 | "lr": 0.0002, 63 | "lr_scheduler": { 64 | "module": "OneCycleLR", 65 | "params": { 66 | "max_lr": 0.0004, 67 | "pct_start": 0.05, 68 | "cycle_momentum": false, 69 | "anneal_strategy": "linear" 70 | } 71 | }, 72 | "mask_st": true, 
73 | "max_grad_norm": 10, 74 | "momentum": 0.9, 75 | "n_gpu": 8, 76 | "optim": "adam", 77 | "pretrained_model": null, 78 | "print_freq": 100, 79 | "record_freq": 500, 80 | "run_atst": false, 81 | "run_ot": false, 82 | "run_st": false, 83 | "save_iter": 10000, 84 | "st_cfg": { 85 | "add_noise": true, 86 | "hflip": true, 87 | "rotate": [-0.01, 0.01, -0.01, 0.01], 88 | "squeeze": [1.0, 1.0, 1.0, 1.0], 89 | "trans": [0.04, 0.005], 90 | "vflip": false, 91 | "zoom": [1.0, 1.4, 0.99, 1.01] 92 | }, 93 | "stage1": { 94 | "epoch": 50, 95 | "loss": { 96 | "occ_from_back": false, 97 | "w_l1": 0.0, 98 | "w_ssim": 0.0, 99 | "w_ternary": 1.0 100 | }, 101 | "train": { 102 | "key_obj_aug": false, 103 | "ot_size": [192, 640], 104 | "run_atst": true, 105 | "run_ot": true, 106 | "run_st": true 107 | } 108 | }, 109 | "val_epoch_size": 5, 110 | "valid_size": 0, 111 | "w_ar": 0.02, 112 | "weight_decay": 1e-06, 113 | "workers": 8 114 | }, 115 | "trainer": "KITTI_AR" 116 | } 117 | -------------------------------------------------------------------------------- /configs/sintel_aug+hg+mf.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_configs": "sintel_base.json", 3 | "model": { 4 | "add_mask_corr": true, 5 | "aggregation_type": "concat" 6 | }, 7 | "train": { 8 | "stage2": { 9 | "epoch": 150, 10 | "loss": { 11 | "smooth_type": "homography", 12 | "w_sm": 0.1 13 | }, 14 | "train": { 15 | "key_obj_aug": true, 16 | "key_obj_count": 3, 17 | "w_ar": 0.1 18 | } 19 | } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /configs/sintel_aug+hg.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_configs": "sintel_base.json", 3 | "train": { 4 | "stage2": { 5 | "epoch": 150, 6 | "loss": { 7 | "smooth_type": "homography", 8 | "w_sm": 0.1 9 | }, 10 | "train": { 11 | "key_obj_aug": true, 12 | "key_obj_count": 3, 13 | "w_ar": 0.1 14 | } 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /configs/sintel_aug.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_configs": "sintel_base.json", 3 | "train": { 4 | "stage2": { 5 | "epoch": 150, 6 | "train": { 7 | "key_obj_aug": true, 8 | "key_obj_count": 3, 9 | "w_ar": 0.1 10 | } 11 | } 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /configs/sintel_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "at_cfg": { 4 | "cj": true, 5 | "cj_bri": 0.3, 6 | "cj_con": 0.3, 7 | "cj_hue": 0.1, 8 | "cj_sat": 0.3, 9 | "gamma": false, 10 | "gblur": true 11 | }, 12 | "data_aug": { 13 | "crop": true, 14 | "hflip": true, 15 | "para_crop": [384, 832], 16 | "swap": true 17 | }, 18 | "epoches_raw": 100, 19 | "epoches_ft": -1, 20 | "root_sintel": "YOUR_DIR/Sintel/", 21 | "root_sintel_raw": "YOUR_DIR/Sintel-raw/", 22 | "full_seg_root_sintel": "YOUR_DIR/full_seg/Sintel", 23 | "full_seg_root_sintel_raw": "YOUR_DIR/full_seg/Sintel-raw", 24 | "key_obj_root_sintel": "YOUR_DIR/key_objects/Sintel", 25 | "key_obj_root_sintel_raw": "YOUR_DIR/key_objects/Sintel-raw", 26 | "run_at": true, 27 | "test_shape": [448, 1024], 28 | "train_subsplit": "trainval", 29 | "type": "Sintel_Raw+ft_2stage", 30 | "val_subsplit": "trainval" 31 | }, 32 | "loss": { 33 | "edge_aware_alpha": 10, 34 | "occ_from_back": true, 35 | "smooth_type": "2nd", 36 | "smooth_edge": 
"image", 37 | "type": "unflow", 38 | "w_l1": 0.15, 39 | "w_ph_scales": [1.0, 1.0, 1.0, 1.0, 0.0], 40 | "w_sm": 0, 41 | "w_ssim": 0.85, 42 | "w_ternary": 0.0, 43 | "warp_pad": "border", 44 | "with_bk": true 45 | }, 46 | "model": { 47 | "learned_upsampler": true, 48 | "reduce_dense": true, 49 | "type": "pwclite" 50 | }, 51 | "seed": 42, 52 | "train": { 53 | "ar_eps": 0.0, 54 | "ar_q": 1.0, 55 | "batch_size": 8, 56 | "beta": 0.999, 57 | "bias_decay": 0, 58 | "epoch_num": 200, 59 | "epoch_size": 1000, 60 | "key_obj_aug": false, 61 | "lr": 0.0002, 62 | "lr_scheduler": { 63 | "module": "OneCycleLR", 64 | "params": { 65 | "max_lr": 0.0004, 66 | "pct_start": 0.05, 67 | "cycle_momentum": false, 68 | "anneal_strategy": "linear" 69 | } 70 | }, 71 | "mask_st": true, 72 | "max_grad_norm": 10, 73 | "momentum": 0.9, 74 | "n_gpu": 8, 75 | "optim": "adam", 76 | "pretrained_model": null, 77 | "print_freq": 100, 78 | "record_freq": 500, 79 | "run_atst": false, 80 | "run_ot": false, 81 | "run_st": false, 82 | "save_iter": 10000, 83 | "st_cfg": { 84 | "add_noise": true, 85 | "hflip": true, 86 | "rotate": [-0.2, 0.2, -0.015, 0.015], 87 | "squeeze": [0.86, 1.16, 1.0, 1.0], 88 | "trans": [0.2, 0.015], 89 | "vflip": true, 90 | "zoom": [1.0, 1.5, 0.985, 1.015] 91 | }, 92 | "stage1": { 93 | "epoch": 50, 94 | "loss": { 95 | "occ_from_back": false, 96 | "w_l1": 0.0, 97 | "w_ssim": 0.0, 98 | "w_ternary": 1.0 99 | }, 100 | "train": { 101 | "ot_size": [320, 704], 102 | "run_atst": true, 103 | "run_ot": true, 104 | "run_st": true 105 | } 106 | }, 107 | "val_epoch_size": 5, 108 | "valid_size": 0, 109 | "w_ar": 0.02, 110 | "weight_decay": 1e-06, 111 | "workers": 8 112 | }, 113 | "trainer": "SINTEL_AR" 114 | } 115 | -------------------------------------------------------------------------------- /datasets/flow_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | """ 4 | 5 | import os 6 | 7 | from abc import ABCMeta, abstractmethod 8 | 9 | # from glob import glob 10 | 11 | import imageio 12 | import numpy as np 13 | import torch 14 | 15 | # from transforms.input_transforms import full_segs_to_adj_maps 16 | from utils.flow_utils import load_flow 17 | 18 | from utils.manifold_utils import pathmgr 19 | 20 | 21 | def local_path(path): 22 | if "manifold" in path: 23 | return pathmgr.get_local_path(path) 24 | else: 25 | return path 26 | 27 | 28 | class ImgSeqDataset(torch.utils.data.Dataset, metaclass=ABCMeta): 29 | def __init__( 30 | self, 31 | root, 32 | full_seg_root, 33 | key_obj_root=None, 34 | name="", 35 | input_transform=None, 36 | co_transform=None, 37 | ap_transform=None, 38 | ): 39 | self.root = root 40 | self.full_seg_root = full_seg_root 41 | self.key_obj_root = key_obj_root 42 | self.name = name 43 | self.input_transform = input_transform 44 | self.co_transform = co_transform 45 | self.ap_transform = ap_transform 46 | self.samples = self.collect_samples() 47 | 48 | @abstractmethod 49 | def collect_samples(self): 50 | pass 51 | 52 | def _load_sample(self, s): 53 | 54 | imgs = [] 55 | full_segs = [] 56 | key_objs = [] 57 | for p in s["imgs"]: 58 | 59 | image = ( 60 | imageio.imread(local_path(os.path.join(self.root, p))).astype( 61 | np.float32 62 | ) 63 | / 255.0 64 | ) 65 | imgs.append(image) 66 | 67 | full_seg = imageio.imread(local_path(os.path.join(self.full_seg_root, p)))[ 68 | :, :, None 69 | ] 70 | full_segs.append(full_seg) 71 | 72 | if self.key_obj_root is not None: 73 | key_obj = ( 74 | np.load( 75 | local_path(os.path.join(self.key_obj_root, p[:-4] + ".npy")) 76 | ) 77 | / 255.0 78 | ) 79 | key_objs.append(key_obj) 80 | 81 | return imgs, full_segs, key_objs 82 | 83 | def __len__(self): 84 | return len(self.samples) 85 | 86 | def __getitem__(self, idx): 87 | imgs, full_segs, key_objs = self._load_sample(self.samples[idx]) 88 | 89 | data = { 90 | "raw_size": imgs[0].shape[:2], 91 | "img1_path": os.path.join(self.root, self.samples[idx]["imgs"][0]), 92 | } 93 | 94 | if self.co_transform is not None: 95 | # In unsupervised learning, there is no need to change target with image 96 | imgs, full_segs, key_objs, _ = self.co_transform( 97 | imgs, full_segs, key_objs, {} 98 | ) 99 | 100 | if self.input_transform is not None: 101 | imgs, full_segs, key_objs = self.input_transform( 102 | (imgs, full_segs, key_objs) 103 | ) 104 | 105 | # adj_maps = full_segs_to_adj_maps(torch.stack(full_segs), win_size=9) 106 | 107 | data.update( 108 | { 109 | "img1": imgs[0], 110 | "img2": imgs[1], 111 | "full_seg1": full_segs[0], 112 | "full_seg2": full_segs[1], 113 | } 114 | ) 115 | 116 | # process key_objs to keep exactly three objects (to make sure the number of objects is fixed so that we can form batches) 117 | if self.key_obj_root is not None: 118 | place_holder = torch.full( 119 | (1, *key_objs[0].shape[1:]), np.nan, dtype=torch.float32 120 | ) 121 | 122 | if key_objs[0].shape[0] == 0: 123 | key_obj = place_holder 124 | else: 125 | valid_key_obj = ( 126 | key_objs[0].mean(axis=(1, 2)) >= 0.005 127 | ) ## some objects may be too small after cropping 128 | 129 | if valid_key_obj.sum() == 0: 130 | key_obj = place_holder 131 | else: 132 | idx = np.random.choice(np.where(valid_key_obj)[0]) 133 | key_obj = key_objs[0][idx : idx + 1] 134 | 135 | data["key_obj_mask"] = key_obj 136 | 137 | if self.ap_transform is not None: 138 | data["img1_ph"], data["img2_ph"] = self.ap_transform( 139 | [imgs[0].clone(), imgs[1].clone()] 140 | ) 141 | 142 | return data 
143 | 144 | 145 | class KITTIRawFile(ImgSeqDataset): 146 | def __init__( 147 | self, 148 | root, 149 | full_seg_root, 150 | key_obj_root, 151 | name="kitti-raw", 152 | ap_transform=None, 153 | input_transform=None, 154 | co_transform=None, 155 | ): 156 | super(KITTIRawFile, self).__init__( 157 | root, 158 | full_seg_root, 159 | key_obj_root, 160 | name, 161 | input_transform=input_transform, 162 | co_transform=co_transform, 163 | ap_transform=ap_transform, 164 | ) 165 | 166 | def collect_samples(self): 167 | sp_file = os.path.join(self.root, "kitti_train_2f_sv.txt") 168 | 169 | samples = [] 170 | with open(local_path(sp_file), "r") as f: 171 | for line in f.readlines(): 172 | sp = line.split() 173 | samples.append({"imgs": sp[0:2]}) 174 | samples.append({"imgs": sp[2:4]}) 175 | 176 | return samples 177 | 178 | 179 | class KITTIFlowMV(ImgSeqDataset): 180 | """ 181 | This dataset is used for unsupervised training only 182 | """ 183 | 184 | def __init__( 185 | self, 186 | root, 187 | full_seg_root, 188 | key_obj_root, 189 | name="", 190 | input_transform=None, 191 | co_transform=None, 192 | ap_transform=None, 193 | ): 194 | super(KITTIFlowMV, self).__init__( 195 | root, 196 | full_seg_root, 197 | key_obj_root, 198 | name, 199 | input_transform=input_transform, 200 | co_transform=co_transform, 201 | ap_transform=ap_transform, 202 | ) 203 | 204 | def collect_samples(self): 205 | 206 | sp_file = os.path.join(self.root, "sample_list_mv.txt") 207 | 208 | samples = [] 209 | with open(local_path(sp_file), "r") as f: 210 | for line in f.readlines(): 211 | samples.append({"imgs": line.split()}) 212 | 213 | return samples 214 | 215 | 216 | class KITTIFlowEval(ImgSeqDataset): 217 | """ 218 | This dataset is used for validation/test ONLY, so all files about target are stored as 219 | file filepath and there is no transform about target. 
220 | """ 221 | 222 | def __init__( 223 | self, 224 | root, 225 | full_seg_root, 226 | key_obj_root, 227 | name="", 228 | input_transform=None, 229 | test_mode=False, 230 | ): 231 | self.test_mode = test_mode 232 | super(KITTIFlowEval, self).__init__( 233 | root, full_seg_root, key_obj_root, name, input_transform=input_transform 234 | ) 235 | 236 | def __getitem__(self, idx): 237 | data = super(KITTIFlowEval, self).__getitem__(idx) 238 | if not self.test_mode: 239 | # for validation; we do not load here because different samples have different sizes 240 | data["flow_occ"] = os.path.join(self.root, self.samples[idx]["flow_occ"]) 241 | data["flow_noc"] = os.path.join(self.root, self.samples[idx]["flow_noc"]) 242 | 243 | return data 244 | 245 | def collect_samples(self): 246 | """Will search in training folder for folders 'flow_noc' or 'flow_occ' 247 | and 'colored_0' (KITTI 2012) or 'image_2' (KITTI 2015)""" 248 | 249 | sp_file = os.path.join(self.root, "sample_list.txt") 250 | 251 | samples = [] 252 | with open(local_path(sp_file), "r") as f: 253 | for line in f.readlines(): 254 | samples.append({"imgs": line.split()}) 255 | 256 | if self.test_mode: 257 | return samples 258 | else: 259 | for i, sample in enumerate(samples): 260 | filename = os.path.basename(sample["imgs"][0]) 261 | 262 | samples[i].update( 263 | { 264 | "flow_occ": os.path.join("flow_occ", filename), 265 | "flow_noc": os.path.join("flow_noc", filename), 266 | } 267 | ) 268 | 269 | return samples 270 | 271 | 272 | class SintelRaw(ImgSeqDataset): 273 | def __init__( 274 | self, 275 | root, 276 | full_seg_root, 277 | key_obj_root, 278 | name="", 279 | input_transform=None, 280 | ap_transform=None, 281 | co_transform=None, 282 | ): 283 | super(SintelRaw, self).__init__( 284 | root, 285 | full_seg_root, 286 | key_obj_root, 287 | name, 288 | input_transform=input_transform, 289 | ap_transform=ap_transform, 290 | co_transform=co_transform, 291 | ) 292 | 293 | def collect_samples(self): 294 | 295 | sp_file = os.path.join(self.root, "sample_list.txt") 296 | 297 | samples = [] 298 | with open(local_path(sp_file), "r") as f: 299 | for line in f.readlines(): 300 | samples.append({"imgs": line.split()}) 301 | 302 | return samples 303 | 304 | 305 | class Sintel(ImgSeqDataset): 306 | def __init__( 307 | self, 308 | root, 309 | full_seg_root, 310 | key_obj_root, 311 | name="", 312 | dataset_type="clean", 313 | split="train", 314 | subsplit="trainval", 315 | with_flow=False, 316 | input_transform=None, 317 | co_transform=None, 318 | ap_transform=None, 319 | ): 320 | self.dataset_type = dataset_type 321 | self.with_flow = with_flow 322 | 323 | self.split = split 324 | self.subsplit = subsplit 325 | self.training_scenes = [ 326 | "alley_1", 327 | "ambush_4", 328 | "ambush_6", 329 | "ambush_7", 330 | "bamboo_2", 331 | "bandage_2", 332 | "cave_2", 333 | "market_2", 334 | "market_5", 335 | "shaman_2", 336 | "sleeping_2", 337 | "temple_3", 338 | ] # Unofficial train-val split 339 | 340 | super(Sintel, self).__init__( 341 | root, 342 | full_seg_root, 343 | key_obj_root, 344 | name, 345 | input_transform=input_transform, 346 | co_transform=co_transform, 347 | ap_transform=ap_transform, 348 | ) 349 | 350 | def __getitem__(self, idx): 351 | data = super(Sintel, self).__getitem__(idx) 352 | if self.with_flow: 353 | data["flow_gt"] = load_flow( 354 | pathmgr.get_local_path(self.samples[idx]["flow"]) 355 | ).astype(np.float32) 356 | data["occ_mask"] = ( 357 | imageio.imread( 358 | pathmgr.get_local_path(self.samples[idx]["occ_mask"]) 359 | 
).astype(np.float32)[:, :, None] 360 | / 255.0 361 | ) 362 | 363 | return data 364 | 365 | def collect_samples(self): 366 | 367 | samples = [] 368 | filename = self.split + "_" + self.dataset_type + "_images.txt" 369 | sp_file = os.path.join(self.root, filename) 370 | 371 | with open(local_path(sp_file), "r") as f: 372 | for line in f.readlines(): 373 | img1, img2 = line[:-1].split(",") 374 | path_split = img1.split("/") 375 | scene = path_split[-2] 376 | sample = { 377 | "imgs": [ 378 | "/".join(img1.split("/")[-4:]), 379 | "/".join(img2.split("/")[-4:]), 380 | ] 381 | } 382 | if self.with_flow: 383 | sample["flow"] = os.path.join( 384 | "/".join(path_split[:-3]), 385 | "flow", 386 | scene, 387 | path_split[-1][:-4] + ".flo", 388 | ) 389 | sample["occ_mask"] = os.path.join( 390 | "/".join(path_split[:-3]), 391 | "occlusions", 392 | scene, 393 | path_split[-1], 394 | ) 395 | 396 | if self.subsplit == "trainval": 397 | samples.append(sample) 398 | elif self.subsplit == "train" and scene in self.training_scenes: 399 | samples.append(sample) 400 | elif self.subsplit == "val" and scene not in self.training_scenes: 401 | samples.append(sample) 402 | 403 | return samples 404 | -------------------------------------------------------------------------------- /datasets/get_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | """ 4 | 5 | from datasets.flow_datasets import ( 6 | KITTIFlowEval, 7 | KITTIFlowMV, 8 | KITTIRawFile, 9 | Sintel, 10 | SintelRaw, 11 | ) 12 | 13 | from torch.utils.data import ConcatDataset 14 | from torchvision import transforms 15 | from transforms import input_transforms 16 | from transforms.ar_transforms.ap_transforms import get_ap_transforms 17 | from transforms.co_transforms import get_co_transforms 18 | 19 | 20 | def get_dataset(cfg): 21 | 22 | co_transform = get_co_transforms(aug_args=cfg.data_aug) 23 | ap_transform = get_ap_transforms(cfg.at_cfg) if cfg.run_at else None 24 | 25 | if cfg.type == "KITTI_Raw+MV_2stage": 26 | 27 | train_input_transform = transforms.Compose( 28 | [input_transforms.Zoom(*cfg.train_shape), input_transforms.ArrayToTensor()] 29 | ) 30 | valid_input_transform = transforms.Compose( 31 | [input_transforms.Zoom(*cfg.test_shape), input_transforms.ArrayToTensor()] 32 | ) 33 | 34 | train_set_1 = KITTIRawFile( 35 | cfg.root_raw, 36 | cfg.full_seg_root_raw, 37 | cfg.key_obj_root_raw, 38 | name="kitti-raw", 39 | input_transform=train_input_transform, 40 | ap_transform=ap_transform, 41 | co_transform=co_transform, 42 | ) 43 | train_set_2_1 = KITTIFlowMV( 44 | cfg.root_kitti15, 45 | cfg.full_seg_root_kitti15, 46 | cfg.key_obj_root_kitti15, 47 | name="kitti2015-mv", 48 | input_transform=train_input_transform, 49 | ap_transform=ap_transform, 50 | co_transform=co_transform, 51 | ) 52 | train_set_2_2 = KITTIFlowMV( 53 | cfg.root_kitti12, 54 | cfg.full_seg_root_kitti12, 55 | cfg.key_obj_root_kitti12, 56 | name="kitti2012-mv", 57 | input_transform=train_input_transform, 58 | ap_transform=ap_transform, 59 | co_transform=co_transform, 60 | ) 61 | train_set_2 = ConcatDataset([train_set_2_1, train_set_2_2]) 62 | train_set_2.name = "kitti-mv" 63 | 64 | valid_set_1 = KITTIFlowEval( 65 | cfg.root_kitti15, 66 | cfg.full_seg_root_kitti15, 67 | None, 68 | name="kitti2015", 69 | input_transform=valid_input_transform, 70 | ) 71 | valid_set_2 = KITTIFlowEval( 72 | cfg.root_kitti12, 73 | cfg.full_seg_root_kitti12, 74 | None, 75 | name="kitti2012", 76 | 
input_transform=valid_input_transform, 77 | ) 78 | 79 | train_sets = [train_set_1, train_set_2] 80 | train_sets_epoches = [cfg.epoches_raw, cfg.epoches_mv] 81 | valid_sets = [valid_set_1, valid_set_2] 82 | 83 | elif cfg.type == "Sintel_Raw+ft_2stage": 84 | 85 | train_input_transform = transforms.Compose([input_transforms.ArrayToTensor()]) 86 | valid_input_transform = transforms.Compose( 87 | [input_transforms.Zoom(*cfg.test_shape), input_transforms.ArrayToTensor()] 88 | ) 89 | 90 | train_set_1 = SintelRaw( 91 | cfg.root_sintel_raw, 92 | cfg.full_seg_root_sintel_raw, 93 | cfg.key_obj_root_sintel_raw, 94 | name="sintel-raw", 95 | input_transform=train_input_transform, 96 | ap_transform=ap_transform, 97 | co_transform=co_transform, 98 | ) 99 | train_set_2_1 = Sintel( 100 | cfg.root_sintel, 101 | cfg.full_seg_root_sintel, 102 | cfg.key_obj_root_sintel, 103 | name="sintel-clean_" + cfg.train_subsplit, 104 | dataset_type="clean", 105 | split="train", 106 | subsplit=cfg.train_subsplit, 107 | input_transform=train_input_transform, 108 | ap_transform=ap_transform, 109 | co_transform=co_transform, 110 | ) 111 | train_set_2_2 = Sintel( 112 | cfg.root_sintel, 113 | cfg.full_seg_root_sintel, 114 | cfg.key_obj_root_sintel, 115 | name="sintel-final_" + cfg.train_subsplit, 116 | dataset_type="final", 117 | split="train", 118 | subsplit=cfg.train_subsplit, 119 | input_transform=train_input_transform, 120 | ap_transform=ap_transform, 121 | co_transform=co_transform, 122 | ) 123 | train_set_2 = ConcatDataset([train_set_2_1, train_set_2_2]) 124 | train_set_2.name = "sintel_clean+final_" + cfg.train_subsplit 125 | 126 | valid_set_1 = Sintel( 127 | cfg.root_sintel, 128 | cfg.full_seg_root_sintel, 129 | None, 130 | name="sintel-clean_" + cfg.val_subsplit, 131 | dataset_type="clean", 132 | split="train", 133 | subsplit=cfg.val_subsplit, 134 | with_flow=True, # for validation 135 | input_transform=valid_input_transform, 136 | ) 137 | valid_set_2 = Sintel( 138 | cfg.root_sintel, 139 | cfg.full_seg_root_sintel, 140 | None, 141 | name="sintel-final_" + cfg.val_subsplit, 142 | dataset_type="final", 143 | split="train", 144 | subsplit=cfg.val_subsplit, 145 | with_flow=True, # for validation 146 | input_transform=valid_input_transform, 147 | ) 148 | 149 | train_sets = [train_set_1, train_set_2] 150 | train_sets_epoches = [cfg.epoches_raw, cfg.epoches_ft] 151 | valid_sets = [valid_set_1, valid_set_2] 152 | 153 | return train_sets, valid_sets, train_sets_epoches 154 | -------------------------------------------------------------------------------- /intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/UnSAMFlow/0363aed7b258ad7e659d82cbdb1766709f7b9429/intro.png -------------------------------------------------------------------------------- /losses/flow_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | """ 4 | 5 | # import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | # import torchvision 11 | from utils.warp_utils import ( 12 | flow_warp, 13 | get_occu_mask_backward, 14 | get_occu_mask_bidirection, 15 | ) 16 | 17 | from .loss_blocks import ( 18 | smooth_grad_1st, 19 | smooth_grad_2nd, 20 | smooth_homography, 21 | SSIM, 22 | TernaryLoss, 23 | ) 24 | 25 | 26 | class unFlowLoss(nn.modules.Module): 27 | def __init__(self, cfg): 28 | super(unFlowLoss, self).__init__() 29 | self.cfg = cfg 30 | if "ransac_threshold" not in cfg: 31 | self.cfg.ransac_threshold = 3 32 | 33 | def loss_photomatric(self, im1_scaled, im1_recons, vis_mask1): 34 | loss = [] 35 | 36 | if self.cfg.w_l1 > 0: 37 | loss += [self.cfg.w_l1 * (im1_scaled - im1_recons).abs() * vis_mask1] 38 | 39 | if self.cfg.w_ssim > 0: 40 | loss += [ 41 | self.cfg.w_ssim * SSIM(im1_recons * vis_mask1, im1_scaled * vis_mask1) 42 | ] 43 | 44 | if self.cfg.w_ternary > 0: 45 | loss += [ 46 | self.cfg.w_ternary 47 | * TernaryLoss(im1_recons * vis_mask1, im1_scaled * vis_mask1) 48 | ] 49 | 50 | return sum([item.mean() for item in loss]) / (vis_mask1.mean() + 1e-6) 51 | 52 | def loss_smooth(self, flow, im1_scaled, **kwargs): 53 | 54 | loss = [] 55 | if self.cfg.smooth_type == "2nd": 56 | func_smooth = smooth_grad_2nd 57 | elif self.cfg.smooth_type == "1st": 58 | func_smooth = smooth_grad_1st 59 | 60 | if "smooth_edge" not in self.cfg or self.cfg.smooth_edge == "image": 61 | loss += [ 62 | func_smooth( 63 | flow, im1_scaled, edge="image", alpha=self.cfg.edge_aware_alpha 64 | ) 65 | ] 66 | else: 67 | loss += [ 68 | func_smooth( 69 | flow, im1_scaled, edge="full_seg", full_seg=kwargs["full_seg"] 70 | ) 71 | ] 72 | return sum([item.mean() for item in loss]) 73 | 74 | def loss_smooth_homography(self, flow, full_seg, occ_mask): 75 | loss = smooth_homography( 76 | flow, 77 | full_seg=full_seg, 78 | occ_mask=occ_mask, 79 | ransac_threshold=self.cfg.ransac_threshold, 80 | ) 81 | return loss 82 | 83 | def loss_one_pair( 84 | self, pyramid_flows, im1_origin, im2_origin, occ_aware=True, **kwargs 85 | ): 86 | """ 87 | 88 | :param output: Multi-scale forward/backward flows n * [B x 4 x h x w] 89 | :param target: image pairs Nx6xHxW 90 | :return: 91 | """ 92 | DEVICE = pyramid_flows[0].device 93 | 94 | # process data 95 | B, _, H, W = im1_origin.shape 96 | 97 | # generate visibility mask/occlusion estimation 98 | top_flow = pyramid_flows[0] 99 | scale = min(*top_flow.shape[-2:]) 100 | 101 | if self.cfg.occ_from_back: 102 | vis_mask1 = 1 - get_occu_mask_backward(top_flow[:, 2:], th=0.2) 103 | vis_mask2 = 1 - get_occu_mask_backward(top_flow[:, :2], th=0.2) 104 | else: 105 | vis_mask1 = 1 - get_occu_mask_bidirection(top_flow[:, :2], top_flow[:, 2:]) 106 | vis_mask2 = 1 - get_occu_mask_bidirection(top_flow[:, 2:], top_flow[:, :2]) 107 | 108 | pyramid_vis_mask1 = [vis_mask1] 109 | pyramid_vis_mask2 = [vis_mask2] 110 | for i in range(1, 5): 111 | _, _, h, w = pyramid_flows[i].size() 112 | pyramid_vis_mask1.append(F.interpolate(vis_mask1, (h, w), mode="nearest")) 113 | pyramid_vis_mask2.append(F.interpolate(vis_mask2, (h, w), mode="nearest")) 114 | 115 | # compute losses at each level 116 | pyramid_warp_losses = [] 117 | pyramid_smooth_losses = [] 118 | zero_loss = torch.tensor(0, dtype=torch.float32, device=DEVICE) 119 | 120 | for i, flow in enumerate(pyramid_flows): 121 | 122 | # resize images to match the size of layer 123 | b, _, h, w = flow.size() 124 | im1_scaled, im2_scaled = None, None 125 | 126 | # 
photometric loss 127 | if self.cfg.w_ph_scales[i] > 0: 128 | im1_scaled = F.interpolate(im1_origin, (h, w), mode="area") 129 | im2_scaled = F.interpolate(im2_origin, (h, w), mode="area") 130 | im1_recons = flow_warp(im2_scaled, flow[:, :2], pad=self.cfg.warp_pad) 131 | im2_recons = flow_warp(im1_scaled, flow[:, 2:], pad=self.cfg.warp_pad) 132 | 133 | if occ_aware: 134 | vis_mask1, vis_mask2 = pyramid_vis_mask1[i], pyramid_vis_mask2[i] 135 | else: 136 | vis_mask1 = torch.ones( 137 | (b, 1, h, w), dtype=torch.float32, device=DEVICE 138 | ) 139 | vis_mask2 = torch.ones( 140 | (b, 1, h, w), dtype=torch.float32, device=DEVICE 141 | ) 142 | 143 | loss_warp = self.loss_photomatric(im1_scaled, im1_recons, vis_mask1) 144 | if self.cfg.with_bk: 145 | loss_warp += self.loss_photomatric( 146 | im2_scaled, im2_recons, vis_mask2 147 | ) 148 | loss_warp /= 2.0 149 | pyramid_warp_losses.append(loss_warp) 150 | 151 | else: 152 | pyramid_warp_losses.append(zero_loss) 153 | 154 | # smoothness loss 155 | if i == 0 and self.cfg.w_sm > 0: 156 | if self.cfg.smooth_type == "homography": 157 | loss_smooth = self.loss_smooth_homography( 158 | flow[:, :2], 159 | full_seg=kwargs["full_seg1"], 160 | occ_mask=1 - vis_mask1, 161 | ) 162 | if self.cfg.with_bk: 163 | loss_smooth += self.loss_smooth_homography( 164 | flow[:, 2:], 165 | full_seg=kwargs["full_seg2"], 166 | occ_mask=1 - vis_mask2, 167 | ) 168 | loss_smooth /= 2.0 169 | else: 170 | if im1_scaled is None: 171 | im1_scaled = F.interpolate(im1_origin, (h, w), mode="area") 172 | im2_scaled = F.interpolate(im2_origin, (h, w), mode="area") 173 | 174 | loss_smooth = self.loss_smooth( 175 | flow[:, :2] / scale, im1_scaled, full_seg=kwargs["full_seg1"] 176 | ) 177 | if self.cfg.with_bk: 178 | loss_smooth += self.loss_smooth( 179 | flow[:, 2:] / scale, 180 | im2_scaled, 181 | full_seg=kwargs["full_seg2"], 182 | ) 183 | loss_smooth /= 2.0 184 | 185 | pyramid_smooth_losses.append(loss_smooth) 186 | 187 | else: 188 | pyramid_smooth_losses.append(zero_loss) 189 | 190 | # debug: print to see 191 | """ 192 | import numpy as np 193 | from utils.flow_utils import flow_to_image 194 | import matplotlib.pyplot as plt 195 | 196 | img1_show = (im1_scaled.cpu().numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8) 197 | img2_show = (im2_scaled.cpu().numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8) 198 | 199 | flow12_numpy = flow[:, :2].detach().cpu().numpy().transpose(0, 2, 3, 1) 200 | flow12_show = [] 201 | for f in flow12_numpy: 202 | flow12_show.append(flow_to_image(f)) 203 | flow12_show = np.stack(flow12_show) 204 | 205 | flow21_numpy = flow[:, 2:].detach().cpu().numpy().transpose(0, 2, 3, 1) 206 | flow21_show = [] 207 | for f in flow21_numpy: 208 | flow21_show.append(flow_to_image(f)) 209 | flow21_show = np.stack(flow21_show) 210 | 211 | vis1_show = (vis_mask1.cpu().numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8).repeat(3, axis=3) 212 | vis2_show = (vis_mask2.cpu().numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8).repeat(3, axis=3) 213 | img1_warp_show = (im1_recons.detach().cpu().numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8) 214 | img2_warp_show = (im2_recons.detach().cpu().numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8) 215 | 216 | ternary12, ternary21, sem12, sem21 = TEMP[-4:] 217 | ternary12_show = (ternary12.detach().cpu().numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8).repeat(3, axis=3) 218 | ternary21_show = (ternary21.detach().cpu().numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8).repeat(3, axis=3) 219 | sem12_show = 
((sem12/2).detach().cpu().numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8).repeat(3, axis=3) 220 | sem21_show = ((sem21/2).detach().cpu().numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8).repeat(3, axis=3) 221 | 222 | all_show = np.concatenate((np.concatenate((img1_show, img2_show, sem1_show, sem2_show), axis=1), 223 | np.concatenate((flow12_show, flow21_show, vis1_show, vis2_show), axis=1), 224 | np.concatenate((img1_warp_show, img2_warp_show, sem1_warp_show, sem2_warp_show), axis=1), 225 | np.concatenate((ternary12_show, ternary21_show, sem12_show, sem21_show), axis=1)), 226 | axis=2) 227 | b, h, w, c = all_show.shape 228 | all_show = np.concatenate((all_show[:, :, :w//2, :], all_show[:, :, w//2:, :]), axis=1) 229 | #all_show = all_show.reshape((b*h, w, c)) 230 | all_show = all_show[0] 231 | 232 | import IPython; IPython.embed(); exit() 233 | plt.imsave('_DEBUG_DEMO_{}.png'.format(i), all_show) 234 | """ 235 | 236 | """ 237 | if i == 0: # for analysis 238 | self.l_ph_0 = loss_warp 239 | self.l_ph_L1_map_0 = (im1_scaled - im1_recons).abs().mean(dim=1) 240 | """ 241 | 242 | # aggregate losses 243 | pyramid_warp_losses = [ 244 | item * w for item, w in zip(pyramid_warp_losses, self.cfg.w_ph_scales) 245 | ] 246 | 247 | l_ph = sum(pyramid_warp_losses) 248 | l_sm = sum(pyramid_smooth_losses) 249 | 250 | loss = l_ph + self.cfg.w_sm * l_sm 251 | 252 | return ( 253 | loss, 254 | l_ph, 255 | l_sm, 256 | pyramid_flows[0][:, :2].norm(dim=1).mean(), 257 | pyramid_vis_mask1[0], 258 | pyramid_vis_mask2[0], 259 | ) 260 | 261 | def forward(self, pyramid_flows, img1, img2, occ_aware=True, **kwargs): 262 | ( 263 | loss, 264 | l_ph, 265 | l_sm, 266 | flow_mean, 267 | flow_vis_mask12, 268 | flow_vis_mask21, 269 | ) = self.loss_one_pair(pyramid_flows, img1, img2, occ_aware=occ_aware, **kwargs) 270 | 271 | return ( 272 | loss[None], 273 | l_ph[None], 274 | l_sm[None], 275 | flow_mean[None], 276 | flow_vis_mask12, 277 | flow_vis_mask21, 278 | ) 279 | -------------------------------------------------------------------------------- /losses/get_loss.py: -------------------------------------------------------------------------------- 1 | from .flow_loss import unFlowLoss 2 | 3 | 4 | def get_loss(cfg): 5 | if cfg.type == "unflow": 6 | loss = unFlowLoss(cfg) 7 | else: 8 | raise NotImplementedError(cfg.type) 9 | return loss 10 | -------------------------------------------------------------------------------- /losses/loss_blocks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | """ 4 | 5 | import cv2 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | # Crecit: https://github.com/simonmeister/UnFlow/blob/master/src/e2eflow/core/losses.py 12 | def TernaryLoss(im, im_warp, max_distance=1): 13 | patch_size = 2 * max_distance + 1 14 | 15 | def _rgb_to_grayscale(image): 16 | grayscale = ( 17 | image[:, 0, :, :] * 0.2989 18 | + image[:, 1, :, :] * 0.5870 19 | + image[:, 2, :, :] * 0.1140 20 | ) 21 | return grayscale.unsqueeze(1) 22 | 23 | def _ternary_transform(image): 24 | intensities = _rgb_to_grayscale(image) * 255 25 | out_channels = patch_size * patch_size 26 | w = torch.eye(out_channels).view((out_channels, 1, patch_size, patch_size)) 27 | weights = w.type_as(im) 28 | patches = F.conv2d(intensities, weights, padding=max_distance) 29 | transf = patches - intensities 30 | transf_norm = transf / torch.sqrt(0.81 + torch.pow(transf, 2)) 31 | return transf_norm 32 | 33 | def _hamming_distance(t1, t2): 34 | dist = torch.pow(t1 - t2, 2) 35 | dist_norm = dist / (0.1 + dist) 36 | dist_mean = torch.mean(dist_norm, 1, keepdim=True) # instead of sum 37 | return dist_mean 38 | 39 | def _valid_mask(t, padding): 40 | n, _, h, w = t.size() 41 | inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t) 42 | mask = F.pad(inner, [padding] * 4) 43 | return mask 44 | 45 | t1 = _ternary_transform(im) 46 | t2 = _ternary_transform(im_warp) 47 | dist = _hamming_distance(t1, t2) 48 | mask = _valid_mask(im, max_distance) 49 | 50 | return dist * mask 51 | 52 | 53 | def SSIM(x, y, md=1): 54 | patch_size = 2 * md + 1 55 | C1 = 0.01**2 56 | C2 = 0.03**2 57 | 58 | mu_x = nn.AvgPool2d(patch_size, 1, 0)(x) 59 | mu_y = nn.AvgPool2d(patch_size, 1, 0)(y) 60 | mu_x_mu_y = mu_x * mu_y 61 | mu_x_sq = mu_x.pow(2) 62 | mu_y_sq = mu_y.pow(2) 63 | 64 | sigma_x = nn.AvgPool2d(patch_size, 1, 0)(x * x) - mu_x_sq 65 | sigma_y = nn.AvgPool2d(patch_size, 1, 0)(y * y) - mu_y_sq 66 | sigma_xy = nn.AvgPool2d(patch_size, 1, 0)(x * y) - mu_x_mu_y 67 | 68 | SSIM_n = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2) 69 | SSIM_d = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2) 70 | SSIM = SSIM_n / SSIM_d 71 | dist = torch.clamp((1 - SSIM) / 2, 0, 1) 72 | return dist 73 | 74 | 75 | def gradient(data): 76 | D_dy = data[..., 1:, :] - data[..., :-1, :] 77 | D_dx = data[..., :, 1:] - data[..., :, :-1] 78 | return D_dx, D_dy 79 | 80 | 81 | def get_image_edge_weights(image, alpha=10): 82 | img_dx, img_dy = gradient(image) 83 | weights_x = torch.exp(-torch.mean(torch.abs(img_dx), 1, keepdim=True) * alpha) 84 | weights_y = torch.exp(-torch.mean(torch.abs(img_dy), 1, keepdim=True) * alpha) 85 | return weights_x, weights_y 86 | 87 | 88 | def get_full_seg_edge_weights(full_seg): 89 | weights_y = (full_seg[..., 1:, :] - full_seg[..., :-1, :] == 0).float() 90 | weights_x = (full_seg[..., :, 1:] - full_seg[..., :, :-1] == 0).float() 91 | return weights_x, weights_y 92 | 93 | 94 | def smooth_grad_1st(flo, image, edge="image", **kwargs): 95 | if edge == "image": 96 | weights_x, weights_y = get_image_edge_weights(image, kwargs["alpha"]) 97 | elif edge == "full_seg": 98 | weights_x, weights_y = get_full_seg_edge_weights(kwargs["full_seg"]) 99 | 100 | dx, dy = gradient(flo) 101 | loss_x = weights_x * dx.abs() 102 | loss_y = weights_y * dy.abs() 103 | 104 | return loss_x.mean() / 2.0 + loss_y.mean() / 2.0 105 | 106 | 107 | def smooth_grad_2nd(flo, image, edge="image", **kwargs): 108 | if edge == "image": 109 | weights_x, weights_y = get_image_edge_weights(image, kwargs["alpha"]) 110 | 
elif edge == "full_seg": 111 | weights_x, weights_y = get_full_seg_edge_weights(kwargs["full_seg"]) 112 | 113 | dx, dy = gradient(flo) 114 | dx2, dxdy = gradient(dx) 115 | dydx, dy2 = gradient(dy) 116 | 117 | loss_x = weights_x[:, :, :, 1:] * dx2.abs() 118 | loss_y = weights_y[:, :, 1:, :] * dy2.abs() 119 | # loss_x = weights_x[:, :, :, 1:] * (torch.exp(dx2.abs() * 100) - 1) / 100. 120 | # loss_y = weights_y[:, :, 1:, :] * (torch.exp(dy2.abs() * 100) - 1) / 100. 121 | 122 | return loss_x.mean() / 2.0 + loss_y.mean() / 2.0 123 | 124 | 125 | def smooth_homography(flo, full_seg, occ_mask, ransac_threshold=3): 126 | 127 | DEVICE = flo.device 128 | B, _, h, w = flo.shape 129 | 130 | loss = torch.tensor(0, dtype=torch.float32, device=DEVICE) 131 | for i in range(B): 132 | 133 | ## find regions to refine 134 | n = int(full_seg[i].max().item() + 1) 135 | occ_mask_ids = full_seg[i, occ_mask[i].to(bool)].to(int) 136 | occ_mask_id_count = torch.eye(n, dtype=bool, device=DEVICE)[occ_mask_ids].sum( 137 | axis=0 138 | ) 139 | 140 | id_order = occ_mask_id_count.argsort(descending=True) 141 | refine_id = id_order[id_order > 0][ 142 | :6 143 | ] # we disregard the `0` mask id because it is just the non-masked region, not one object 144 | refine_id = refine_id.tolist() 145 | 146 | ## start refining 147 | coords1 = ( 148 | torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w))[::-1], axis=2) 149 | .float() 150 | .to(DEVICE) 151 | ) 152 | coords2 = coords1 + flo[i].permute(1, 2, 0) 153 | 154 | # effective_mask_ids = [] 155 | for id in refine_id: 156 | reliable_mask = ( 157 | (1 - occ_mask[i, full_seg[i] == id]).bool().detach().cpu().numpy() 158 | ) 159 | if reliable_mask.sum() < 4 or reliable_mask.mean() < 0.2: 160 | # print("Mask #{0:} dropped due to low non-occcluded ratio ({1:4.2f}).".format(id, reliable_mask.mean())) 161 | continue 162 | 163 | pts1 = coords1[full_seg[i, 0] == id] 164 | pts2 = coords2[full_seg[i, 0] == id] 165 | 166 | H, mask = cv2.findHomography( 167 | pts1[reliable_mask].detach().cpu().numpy(), 168 | pts2[reliable_mask].detach().cpu().numpy(), 169 | cv2.RANSAC, 170 | ransac_threshold, 171 | ) 172 | 173 | if ( 174 | mask.mean() < 0.5 175 | ): # do not refine if the estimated homography's inlier rate < 0.5 176 | # print("Mask #{0:} dropped due to low inlier rate ({1:4.2f}).".format(id, mask.mean())) 177 | continue 178 | 179 | H = torch.FloatTensor(H).to(DEVICE) 180 | pts1_homo = torch.concat( 181 | (pts1, torch.ones((pts1.shape[0], 1)).to(DEVICE)), dim=1 182 | ) 183 | new_pts2_homo = torch.matmul(H, pts1_homo.T).T 184 | new_pts2 = new_pts2_homo[:, :2] / new_pts2_homo[:, 2:3] 185 | diff = (new_pts2 - pts2)[:, :2] 186 | # flow_refined[i, :, full_seg[i, 0] == id] = (new_pts2 - pts1)[:, :2].T 187 | 188 | loss += diff.abs().sum() / (h * w) 189 | # effective_mask_ids.append(id) 190 | 191 | ## DEBUG: 192 | # import IPython; IPython.embed(); exit() 193 | # import matplotlib.pyplot as plt 194 | # plt.imsave("_DEBUG_flow.png", flow_to_image(flo[i].detach().cpu().numpy().transpose(1, 2, 0))) 195 | # plt.imsave("_DEBUG_flow_refined.png", flow_to_image(flow_refined[i].detach().cpu().numpy().transpose(1, 2, 0))) 196 | # from skimage.segmentation import mark_boundaries 197 | # plt.imsave("_DEBUG_full_seg.png", mark_boundaries(occ_mask[i].detach().cpu().numpy().transpose(1, 2, 0).squeeze(), full_seg[i].detach().cpu().numpy().transpose(1, 2, 0).squeeze().astype(int))) 198 | 199 | loss /= B 200 | return loss 201 | -------------------------------------------------------------------------------- 
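The blocks above are combined into the full objective by unFlowLoss in flow_loss.py. As a rough, self-contained illustration of how they can be called on their own (a minimal sketch; the `losses.loss_blocks` import path and the 0.85/0.15/alpha=10 weights taken from the base configs are assumptions, not the exact training code):

import torch
from losses.loss_blocks import smooth_grad_2nd, SSIM, TernaryLoss

B, H, W = 2, 64, 96
im1 = torch.rand(B, 3, H, W)         # frame 1, values in [0, 1]
im1_recons = torch.rand(B, 3, H, W)  # frame 2 warped back to frame 1 with the predicted flow
flow12 = torch.randn(B, 2, H, W)     # forward flow (u, v)

# per-pixel photometric terms, weighted as in the base configs (w_ssim=0.85, w_l1=0.15)
photometric = 0.85 * SSIM(im1_recons, im1).mean() + 0.15 * (im1 - im1_recons).abs().mean()
census = TernaryLoss(im1, im1_recons).mean()                       # ternary/census term (stage 1)
smoothness = smooth_grad_2nd(flow12, im1, edge="image", alpha=10)  # edge-aware 2nd-order smoothness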
/models/correlation_native.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Correlation(nn.Module): 7 | def __init__(self, max_displacement=4, *args, **kwargs): 8 | super(Correlation, self).__init__() 9 | self.max_displacement = max_displacement 10 | self.output_dim = 2 * self.max_displacement + 1 11 | self.pad_size = self.max_displacement 12 | 13 | def forward(self, x1, x2): 14 | B, C, H, W = x1.size() 15 | 16 | x2 = F.pad(x2, [self.pad_size] * 4) 17 | cv = [] 18 | for i in range(self.output_dim): 19 | for j in range(self.output_dim): 20 | cost = x1 * x2[:, :, i : (i + H), j : (j + W)] 21 | cost = torch.mean(cost, 1, keepdim=True) 22 | cv.append(cost) 23 | return torch.cat(cv, 1) 24 | 25 | 26 | if __name__ == "__main__": 27 | import random 28 | import time 29 | 30 | from correlation_package.correlation import Correlation as Correlation_cuda 31 | 32 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 33 | corr1 = Correlation( 34 | max_displacement=4, kernel_size=1, stride1=1, stride2=1, corr_multiply=1 35 | ).to(device) 36 | 37 | corr2 = Correlation_cuda( 38 | pad_size=4, 39 | kernel_size=1, 40 | max_displacement=4, 41 | stride1=1, 42 | stride2=1, 43 | corr_multiply=1, 44 | ) 45 | 46 | t1_sum = 0 47 | t2_sum = 0 48 | 49 | for i in range(50): 50 | C = random.choice([128, 256]) 51 | H = random.choice([128, 256]) # , 512 52 | W = random.choice([64, 128]) # , 256 53 | x1 = torch.randn(4, C, H, W, requires_grad=True).to(device) 54 | x2 = torch.randn(4, C, H, W).to(device) 55 | 56 | end = time.time() 57 | y2 = corr2(x1, x2) 58 | t2_f = time.time() - end 59 | 60 | end = time.time() 61 | y2.sum().backward() 62 | t2_b = time.time() - end 63 | 64 | end = time.time() 65 | y1 = corr1(x1, x2) 66 | t1_f = time.time() - end 67 | 68 | end = time.time() 69 | y1.sum().backward() 70 | t1_b = time.time() - end 71 | 72 | assert torch.allclose(y1, y2, atol=1e-7) 73 | 74 | print( 75 | "Forward: cuda: {:.3f}ms, pytorch: {:.3f}ms".format(t1_f * 100, t2_f * 100) 76 | ) 77 | print( 78 | "Backward: cuda: {:.3f}ms, pytorch: {:.3f}ms".format(t1_b * 100, t2_b * 100) 79 | ) 80 | 81 | if i < 3: 82 | continue 83 | t1_sum += t1_b + t1_f 84 | t2_sum += t2_b + t2_f 85 | 86 | print("cuda: {:.3f}s, pytorch: {:.3f}s".format(t1_sum, t2_sum)) 87 | ... 
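# Usage sketch (illustration only; the models.correlation_native import path is an
# assumption): the naive module above is the drop-in replacement mentioned in the
# README for the CUDA correlation package when the latter cannot be built.
#
#     from models.correlation_native import Correlation
#     corr = Correlation(max_displacement=4)
#     feat1 = torch.randn(2, 128, 32, 96)   # B x C x H x W encoder features
#     feat2 = torch.randn(2, 128, 32, 96)
#     cost_volume = corr(feat1, feat2)      # B x (2*4+1)**2 = 81 channels x H x W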
88 | -------------------------------------------------------------------------------- /models/correlation_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/UnSAMFlow/0363aed7b258ad7e659d82cbdb1766709f7b9429/models/correlation_package/__init__.py -------------------------------------------------------------------------------- /models/correlation_package/correlation.py: -------------------------------------------------------------------------------- 1 | import correlation_cuda 2 | import torch 3 | from torch.autograd import Function 4 | from torch.nn.modules.module import Module 5 | 6 | 7 | class CorrelationFunction(Function): 8 | def __init__( 9 | self, 10 | pad_size=3, 11 | kernel_size=3, 12 | max_displacement=20, 13 | stride1=1, 14 | stride2=2, 15 | corr_multiply=1, 16 | ): 17 | super(CorrelationFunction, self).__init__() 18 | self.pad_size = pad_size 19 | self.kernel_size = kernel_size 20 | self.max_displacement = max_displacement 21 | self.stride1 = stride1 22 | self.stride2 = stride2 23 | self.corr_multiply = corr_multiply 24 | # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1) 25 | 26 | def forward(self, input1, input2): 27 | self.save_for_backward(input1, input2) 28 | 29 | with torch.cuda.device_of(input1): 30 | rbot1 = input1.new() 31 | rbot2 = input2.new() 32 | output = input1.new() 33 | 34 | correlation_cuda.forward( 35 | input1, 36 | input2, 37 | rbot1, 38 | rbot2, 39 | output, 40 | self.pad_size, 41 | self.kernel_size, 42 | self.max_displacement, 43 | self.stride1, 44 | self.stride2, 45 | self.corr_multiply, 46 | ) 47 | 48 | return output 49 | 50 | def backward(self, grad_output): 51 | input1, input2 = self.saved_tensors 52 | 53 | with torch.cuda.device_of(input1): 54 | rbot1 = input1.new() 55 | rbot2 = input2.new() 56 | 57 | grad_input1 = input1.new() 58 | grad_input2 = input2.new() 59 | 60 | correlation_cuda.backward( 61 | input1, 62 | input2, 63 | rbot1, 64 | rbot2, 65 | grad_output, 66 | grad_input1, 67 | grad_input2, 68 | self.pad_size, 69 | self.kernel_size, 70 | self.max_displacement, 71 | self.stride1, 72 | self.stride2, 73 | self.corr_multiply, 74 | ) 75 | 76 | return grad_input1, grad_input2 77 | 78 | 79 | class Correlation(Module): 80 | def __init__( 81 | self, 82 | pad_size=0, 83 | kernel_size=0, 84 | max_displacement=0, 85 | stride1=1, 86 | stride2=2, 87 | corr_multiply=1, 88 | ): 89 | super(Correlation, self).__init__() 90 | self.pad_size = pad_size 91 | self.kernel_size = kernel_size 92 | self.max_displacement = max_displacement 93 | self.stride1 = stride1 94 | self.stride2 = stride2 95 | self.corr_multiply = corr_multiply 96 | 97 | def forward(self, input1, input2): 98 | 99 | result = CorrelationFunction( 100 | self.pad_size, 101 | self.kernel_size, 102 | self.max_displacement, 103 | self.stride1, 104 | self.stride2, 105 | self.corr_multiply, 106 | )(input1, input2) 107 | 108 | return result 109 | -------------------------------------------------------------------------------- /models/correlation_package/correlation_cuda.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "correlation_cuda_kernel.cuh" 9 | 10 | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output, 11 | int pad_size, 12 | int kernel_size, 13 | int 
max_displacement, 14 | int stride1, 15 | int stride2, 16 | int corr_type_multiply) 17 | { 18 | 19 | int batchSize = input1.size(0); 20 | 21 | int nInputChannels = input1.size(1); 22 | int inputHeight = input1.size(2); 23 | int inputWidth = input1.size(3); 24 | 25 | int kernel_radius = (kernel_size - 1) / 2; 26 | int border_radius = kernel_radius + max_displacement; 27 | 28 | int paddedInputHeight = inputHeight + 2 * pad_size; 29 | int paddedInputWidth = inputWidth + 2 * pad_size; 30 | 31 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); 32 | 33 | int outputHeight = ceil(static_cast(paddedInputHeight - 2 * border_radius) / static_cast(stride1)); 34 | int outputwidth = ceil(static_cast(paddedInputWidth - 2 * border_radius) / static_cast(stride1)); 35 | 36 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 37 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 38 | output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth}); 39 | 40 | rInput1.fill_(0); 41 | rInput2.fill_(0); 42 | output.fill_(0); 43 | 44 | int success = correlation_forward_cuda_kernel( 45 | output, 46 | output.size(0), 47 | output.size(1), 48 | output.size(2), 49 | output.size(3), 50 | output.stride(0), 51 | output.stride(1), 52 | output.stride(2), 53 | output.stride(3), 54 | input1, 55 | input1.size(1), 56 | input1.size(2), 57 | input1.size(3), 58 | input1.stride(0), 59 | input1.stride(1), 60 | input1.stride(2), 61 | input1.stride(3), 62 | input2, 63 | input2.size(1), 64 | input2.stride(0), 65 | input2.stride(1), 66 | input2.stride(2), 67 | input2.stride(3), 68 | rInput1, 69 | rInput2, 70 | pad_size, 71 | kernel_size, 72 | max_displacement, 73 | stride1, 74 | stride2, 75 | corr_type_multiply, 76 | at::cuda::getCurrentCUDAStream() 77 | //at::globalContext().getCurrentCUDAStream() 78 | ); 79 | 80 | //check for errors 81 | if (!success) { 82 | AT_ERROR("CUDA call failed"); 83 | } 84 | 85 | return 1; 86 | 87 | } 88 | 89 | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, 90 | at::Tensor& gradInput1, at::Tensor& gradInput2, 91 | int pad_size, 92 | int kernel_size, 93 | int max_displacement, 94 | int stride1, 95 | int stride2, 96 | int corr_type_multiply) 97 | { 98 | 99 | int batchSize = input1.size(0); 100 | int nInputChannels = input1.size(1); 101 | int paddedInputHeight = input1.size(2)+ 2 * pad_size; 102 | int paddedInputWidth = input1.size(3)+ 2 * pad_size; 103 | 104 | int height = input1.size(2); 105 | int width = input1.size(3); 106 | 107 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 108 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 109 | gradInput1.resize_({batchSize, nInputChannels, height, width}); 110 | gradInput2.resize_({batchSize, nInputChannels, height, width}); 111 | 112 | rInput1.fill_(0); 113 | rInput2.fill_(0); 114 | gradInput1.fill_(0); 115 | gradInput2.fill_(0); 116 | 117 | int success = correlation_backward_cuda_kernel(gradOutput, 118 | gradOutput.size(0), 119 | gradOutput.size(1), 120 | gradOutput.size(2), 121 | gradOutput.size(3), 122 | gradOutput.stride(0), 123 | gradOutput.stride(1), 124 | gradOutput.stride(2), 125 | gradOutput.stride(3), 126 | input1, 127 | input1.size(1), 128 | input1.size(2), 129 | input1.size(3), 130 | input1.stride(0), 131 | input1.stride(1), 132 | input1.stride(2), 133 | input1.stride(3), 134 | 
input2, 135 | input2.stride(0), 136 | input2.stride(1), 137 | input2.stride(2), 138 | input2.stride(3), 139 | gradInput1, 140 | gradInput1.stride(0), 141 | gradInput1.stride(1), 142 | gradInput1.stride(2), 143 | gradInput1.stride(3), 144 | gradInput2, 145 | gradInput2.size(1), 146 | gradInput2.stride(0), 147 | gradInput2.stride(1), 148 | gradInput2.stride(2), 149 | gradInput2.stride(3), 150 | rInput1, 151 | rInput2, 152 | pad_size, 153 | kernel_size, 154 | max_displacement, 155 | stride1, 156 | stride2, 157 | corr_type_multiply, 158 | at::cuda::getCurrentCUDAStream() 159 | //at::globalContext().getCurrentCUDAStream() 160 | ); 161 | 162 | if (!success) { 163 | AT_ERROR("CUDA call failed"); 164 | } 165 | 166 | return 1; 167 | } 168 | 169 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 170 | m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)"); 171 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)"); 172 | } 173 | 174 | -------------------------------------------------------------------------------- /models/correlation_package/correlation_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "correlation_cuda_kernel.cuh" 4 | 5 | #define CUDA_NUM_THREADS 1024 6 | #define THREADS_PER_BLOCK 32 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | using at::Half; 14 | 15 | template 16 | __global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size) 17 | { 18 | 19 | // n (batch size), c (num of channels), y (height), x (width) 20 | int n = blockIdx.x; 21 | int y = blockIdx.y; 22 | int x = blockIdx.z; 23 | 24 | int ch_off = threadIdx.x; 25 | scalar_t value; 26 | 27 | int dimcyx = channels * height * width; 28 | int dimyx = height * width; 29 | 30 | int p_dimx = (width + 2 * pad_size); 31 | int p_dimy = (height + 2 * pad_size); 32 | int p_dimyxc = channels * p_dimy * p_dimx; 33 | int p_dimxc = p_dimx * channels; 34 | 35 | for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) { 36 | value = input[n * dimcyx + c * dimyx + y * width + x]; 37 | rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value; 38 | } 39 | } 40 | 41 | template 42 | __global__ void correlation_forward(scalar_t* output, int nOutputChannels, int outputHeight, int outputWidth, 43 | const scalar_t* __restrict__ rInput1, int nInputChannels, int inputHeight, int inputWidth, 44 | const scalar_t* __restrict__ rInput2, 45 | int pad_size, 46 | int kernel_size, 47 | int max_displacement, 48 | int stride1, 49 | int stride2) 50 | { 51 | // n (batch size), c (num of channels), y (height), x (width) 52 | 53 | int pInputWidth = inputWidth + 2 * pad_size; 54 | int pInputHeight = inputHeight + 2 * pad_size; 55 | 56 | int kernel_rad = (kernel_size - 1) / 2; 57 | int displacement_rad = max_displacement / stride2; 58 | int displacement_size = 2 * displacement_rad + 1; 59 | 60 | int n = blockIdx.x; 61 | int y1 = blockIdx.y * stride1 + max_displacement; 62 | int x1 = blockIdx.z * stride1 + max_displacement; 63 | int c = threadIdx.x; 64 | 65 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 66 | int pdimxc = pInputWidth * nInputChannels; 67 | int pdimc = nInputChannels; 68 | 69 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 70 | int tdimyx = outputHeight * outputWidth; 71 | int tdimx = outputWidth; 72 | 73 | scalar_t nelems = kernel_size * kernel_size * pdimc; 74 | 75 | __shared__ scalar_t 
prod_sum[THREADS_PER_BLOCK]; 76 | 77 | // no significant speed-up in using chip memory for input1 sub-data, 78 | // not enough chip memory size to accomodate memory per block for input2 sub-data 79 | // instead i've used device memory for both 80 | 81 | // element-wise product along channel axis 82 | for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) { 83 | for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) { 84 | prod_sum[c] = 0; 85 | int x2 = x1 + ti*stride2; 86 | int y2 = y1 + tj*stride2; 87 | 88 | for (int j = -kernel_rad; j <= kernel_rad; ++j) { 89 | for (int i = -kernel_rad; i <= kernel_rad; ++i) { 90 | for (int ch = c; ch < pdimc; ch += THREADS_PER_BLOCK) { 91 | int indx1 = n * pdimyxc + (y1 + j) * pdimxc + (x1 + i) * pdimc + ch; 92 | int indx2 = n * pdimyxc + (y2 + j) * pdimxc + (x2 + i) * pdimc + ch; 93 | 94 | prod_sum[c] += rInput1[indx1] * rInput2[indx2]; 95 | } 96 | } 97 | } 98 | 99 | // accumulate 100 | __syncthreads(); 101 | if (c == 0) { 102 | scalar_t reduce_sum = 0; 103 | for (int index = 0; index < THREADS_PER_BLOCK; ++index) { 104 | reduce_sum += prod_sum[index]; 105 | } 106 | int tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad); 107 | const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + blockIdx.z; 108 | output[tindx] = reduce_sum / nelems; 109 | } 110 | 111 | } 112 | } 113 | 114 | } 115 | 116 | template 117 | __global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth, 118 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 119 | const scalar_t* __restrict__ rInput2, 120 | int pad_size, 121 | int kernel_size, 122 | int max_displacement, 123 | int stride1, 124 | int stride2) 125 | { 126 | // n (batch size), c (num of channels), y (height), x (width) 127 | 128 | int n = item; 129 | int y = blockIdx.x * stride1 + pad_size; 130 | int x = blockIdx.y * stride1 + pad_size; 131 | int c = blockIdx.z; 132 | int tch_off = threadIdx.x; 133 | 134 | int kernel_rad = (kernel_size - 1) / 2; 135 | int displacement_rad = max_displacement / stride2; 136 | int displacement_size = 2 * displacement_rad + 1; 137 | 138 | int xmin = (x - kernel_rad - max_displacement) / stride1; 139 | int ymin = (y - kernel_rad - max_displacement) / stride1; 140 | 141 | int xmax = (x + kernel_rad - max_displacement) / stride1; 142 | int ymax = (y + kernel_rad - max_displacement) / stride1; 143 | 144 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { 145 | // assumes gradInput1 is pre-allocated and zero filled 146 | return; 147 | } 148 | 149 | if (xmin > xmax || ymin > ymax) { 150 | // assumes gradInput1 is pre-allocated and zero filled 151 | return; 152 | } 153 | 154 | xmin = max(0, xmin); 155 | xmax = min(outputWidth - 1, xmax); 156 | 157 | ymin = max(0, ymin); 158 | ymax = min(outputHeight - 1, ymax); 159 | 160 | int pInputWidth = inputWidth + 2 * pad_size; 161 | int pInputHeight = inputHeight + 2 * pad_size; 162 | 163 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 164 | int pdimxc = pInputWidth * nInputChannels; 165 | int pdimc = nInputChannels; 166 | 167 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 168 | int tdimyx = outputHeight * outputWidth; 169 | int tdimx = outputWidth; 170 | 171 | int odimcyx = nInputChannels * inputHeight* inputWidth; 172 | int odimyx = inputHeight * inputWidth; 173 | int odimx = inputWidth; 174 | 175 | scalar_t nelems = kernel_size * kernel_size * 
nInputChannels; 176 | 177 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; 178 | prod_sum[tch_off] = 0; 179 | 180 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { 181 | 182 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 183 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 184 | 185 | int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c; 186 | 187 | scalar_t val2 = rInput2[indx2]; 188 | 189 | for (int j = ymin; j <= ymax; ++j) { 190 | for (int i = xmin; i <= xmax; ++i) { 191 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; 192 | prod_sum[tch_off] += gradOutput[tindx] * val2; 193 | } 194 | } 195 | } 196 | __syncthreads(); 197 | 198 | if (tch_off == 0) { 199 | scalar_t reduce_sum = 0; 200 | for (int idx = 0; idx < THREADS_PER_BLOCK; idx++) { 201 | reduce_sum += prod_sum[idx]; 202 | } 203 | const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); 204 | gradInput1[indx1] = reduce_sum / nelems; 205 | } 206 | 207 | } 208 | 209 | template 210 | __global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth, 211 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 212 | const scalar_t* __restrict__ rInput1, 213 | int pad_size, 214 | int kernel_size, 215 | int max_displacement, 216 | int stride1, 217 | int stride2) 218 | { 219 | // n (batch size), c (num of channels), y (height), x (width) 220 | 221 | int n = item; 222 | int y = blockIdx.x * stride1 + pad_size; 223 | int x = blockIdx.y * stride1 + pad_size; 224 | int c = blockIdx.z; 225 | 226 | int tch_off = threadIdx.x; 227 | 228 | int kernel_rad = (kernel_size - 1) / 2; 229 | int displacement_rad = max_displacement / stride2; 230 | int displacement_size = 2 * displacement_rad + 1; 231 | 232 | int pInputWidth = inputWidth + 2 * pad_size; 233 | int pInputHeight = inputHeight + 2 * pad_size; 234 | 235 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 236 | int pdimxc = pInputWidth * nInputChannels; 237 | int pdimc = nInputChannels; 238 | 239 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 240 | int tdimyx = outputHeight * outputWidth; 241 | int tdimx = outputWidth; 242 | 243 | int odimcyx = nInputChannels * inputHeight* inputWidth; 244 | int odimyx = inputHeight * inputWidth; 245 | int odimx = inputWidth; 246 | 247 | scalar_t nelems = kernel_size * kernel_size * nInputChannels; 248 | 249 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; 250 | prod_sum[tch_off] = 0; 251 | 252 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { 253 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 254 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 255 | 256 | int xmin = (x - kernel_rad - max_displacement - i2) / stride1; 257 | int ymin = (y - kernel_rad - max_displacement - j2) / stride1; 258 | 259 | int xmax = (x + kernel_rad - max_displacement - i2) / stride1; 260 | int ymax = (y + kernel_rad - max_displacement - j2) / stride1; 261 | 262 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { 263 | // assumes gradInput2 is pre-allocated and zero filled 264 | continue; 265 | } 266 | 267 | if (xmin > xmax || ymin > ymax) { 268 | // assumes gradInput2 is pre-allocated and zero filled 269 | continue; 270 | } 271 | 272 | xmin = max(0, xmin); 273 | xmax = min(outputWidth - 1, xmax); 274 | 275 | ymin = max(0, ymin); 276 | ymax = min(outputHeight - 1, 
ymax); 277 | 278 | int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c; 279 | scalar_t val1 = rInput1[indx1]; 280 | 281 | for (int j = ymin; j <= ymax; ++j) { 282 | for (int i = xmin; i <= xmax; ++i) { 283 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; 284 | prod_sum[tch_off] += gradOutput[tindx] * val1; 285 | } 286 | } 287 | } 288 | 289 | __syncthreads(); 290 | 291 | if (tch_off == 0) { 292 | scalar_t reduce_sum = 0; 293 | for (int idx = 0; idx < THREADS_PER_BLOCK; idx++) { 294 | reduce_sum += prod_sum[idx]; 295 | } 296 | const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); 297 | gradInput2[indx2] = reduce_sum / nelems; 298 | } 299 | 300 | } 301 | 302 | int correlation_forward_cuda_kernel(at::Tensor& output, 303 | int ob, 304 | int oc, 305 | int oh, 306 | int ow, 307 | int osb, 308 | int osc, 309 | int osh, 310 | int osw, 311 | 312 | at::Tensor& input1, 313 | int ic, 314 | int ih, 315 | int iw, 316 | int isb, 317 | int isc, 318 | int ish, 319 | int isw, 320 | 321 | at::Tensor& input2, 322 | int gc, 323 | int gsb, 324 | int gsc, 325 | int gsh, 326 | int gsw, 327 | 328 | at::Tensor& rInput1, 329 | at::Tensor& rInput2, 330 | int pad_size, 331 | int kernel_size, 332 | int max_displacement, 333 | int stride1, 334 | int stride2, 335 | int corr_type_multiply, 336 | cudaStream_t stream) 337 | { 338 | 339 | int batchSize = ob; 340 | 341 | int nInputChannels = ic; 342 | int inputWidth = iw; 343 | int inputHeight = ih; 344 | 345 | int nOutputChannels = oc; 346 | int outputWidth = ow; 347 | int outputHeight = oh; 348 | 349 | dim3 blocks_grid(batchSize, inputHeight, inputWidth); 350 | dim3 threads_block(THREADS_PER_BLOCK); 351 | 352 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] { 353 | 354 | channels_first << > >( 355 | input1.data(), rInput1.data(), nInputChannels, inputHeight, inputWidth, pad_size); 356 | 357 | })); 358 | 359 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] { 360 | 361 | channels_first << > > ( 362 | input2.data(), rInput2.data(), nInputChannels, inputHeight, inputWidth, pad_size); 363 | 364 | })); 365 | 366 | dim3 threadsPerBlock(THREADS_PER_BLOCK); 367 | dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth); 368 | 369 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] { 370 | 371 | correlation_forward << > > 372 | (output.data(), nOutputChannels, outputHeight, outputWidth, 373 | rInput1.data(), nInputChannels, inputHeight, inputWidth, 374 | rInput2.data(), 375 | pad_size, 376 | kernel_size, 377 | max_displacement, 378 | stride1, 379 | stride2); 380 | 381 | })); 382 | 383 | cudaError_t err = cudaGetLastError(); 384 | 385 | 386 | // check for errors 387 | if (err != cudaSuccess) { 388 | printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err)); 389 | return 0; 390 | } 391 | 392 | return 1; 393 | } 394 | 395 | 396 | int correlation_backward_cuda_kernel( 397 | at::Tensor& gradOutput, 398 | int gob, 399 | int goc, 400 | int goh, 401 | int gow, 402 | int gosb, 403 | int gosc, 404 | int gosh, 405 | int gosw, 406 | 407 | at::Tensor& input1, 408 | int ic, 409 | int ih, 410 | int iw, 411 | int isb, 412 | int isc, 413 | int ish, 414 | int isw, 415 | 416 | at::Tensor& input2, 417 | int gsb, 418 | int gsc, 419 | int gsh, 420 | int gsw, 421 | 422 | at::Tensor& gradInput1, 423 | int gisb, 424 | int gisc, 425 | int gish, 426 | int gisw, 427 | 428 | at::Tensor& gradInput2, 429 | int ggc, 430 | int ggsb, 431 | int 
ggsc, 432 | int ggsh, 433 | int ggsw, 434 | 435 | at::Tensor& rInput1, 436 | at::Tensor& rInput2, 437 | int pad_size, 438 | int kernel_size, 439 | int max_displacement, 440 | int stride1, 441 | int stride2, 442 | int corr_type_multiply, 443 | cudaStream_t stream) 444 | { 445 | 446 | int batchSize = gob; 447 | int num = batchSize; 448 | 449 | int nInputChannels = ic; 450 | int inputWidth = iw; 451 | int inputHeight = ih; 452 | 453 | int nOutputChannels = goc; 454 | int outputWidth = gow; 455 | int outputHeight = goh; 456 | 457 | dim3 blocks_grid(batchSize, inputHeight, inputWidth); 458 | dim3 threads_block(THREADS_PER_BLOCK); 459 | 460 | 461 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "lltm_forward_cuda", ([&] { 462 | 463 | channels_first << > >( 464 | input1.data(), 465 | rInput1.data(), 466 | nInputChannels, 467 | inputHeight, 468 | inputWidth, 469 | pad_size 470 | ); 471 | })); 472 | 473 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { 474 | 475 | channels_first << > >( 476 | input2.data(), 477 | rInput2.data(), 478 | nInputChannels, 479 | inputHeight, 480 | inputWidth, 481 | pad_size 482 | ); 483 | })); 484 | 485 | dim3 threadsPerBlock(THREADS_PER_BLOCK); 486 | dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels); 487 | 488 | for (int n = 0; n < num; ++n) { 489 | 490 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { 491 | 492 | 493 | correlation_backward_input1 << > > ( 494 | n, gradInput1.data(), nInputChannels, inputHeight, inputWidth, 495 | gradOutput.data(), nOutputChannels, outputHeight, outputWidth, 496 | rInput2.data(), 497 | pad_size, 498 | kernel_size, 499 | max_displacement, 500 | stride1, 501 | stride2); 502 | })); 503 | } 504 | 505 | for (int n = 0; n < batchSize; n++) { 506 | 507 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "lltm_forward_cuda", ([&] { 508 | 509 | correlation_backward_input2 << > >( 510 | n, gradInput2.data(), nInputChannels, inputHeight, inputWidth, 511 | gradOutput.data(), nOutputChannels, outputHeight, outputWidth, 512 | rInput1.data(), 513 | pad_size, 514 | kernel_size, 515 | max_displacement, 516 | stride1, 517 | stride2); 518 | 519 | })); 520 | } 521 | 522 | // check for errors 523 | cudaError_t err = cudaGetLastError(); 524 | if (err != cudaSuccess) { 525 | printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err)); 526 | return 0; 527 | } 528 | 529 | return 1; 530 | } 531 | -------------------------------------------------------------------------------- /models/correlation_package/correlation_cuda_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | int correlation_forward_cuda_kernel(at::Tensor& output, 8 | int ob, 9 | int oc, 10 | int oh, 11 | int ow, 12 | int osb, 13 | int osc, 14 | int osh, 15 | int osw, 16 | 17 | at::Tensor& input1, 18 | int ic, 19 | int ih, 20 | int iw, 21 | int isb, 22 | int isc, 23 | int ish, 24 | int isw, 25 | 26 | at::Tensor& input2, 27 | int gc, 28 | int gsb, 29 | int gsc, 30 | int gsh, 31 | int gsw, 32 | 33 | at::Tensor& rInput1, 34 | at::Tensor& rInput2, 35 | int pad_size, 36 | int kernel_size, 37 | int max_displacement, 38 | int stride1, 39 | int stride2, 40 | int corr_type_multiply, 41 | cudaStream_t stream); 42 | 43 | 44 | int correlation_backward_cuda_kernel( 45 | at::Tensor& gradOutput, 46 | int gob, 47 | int goc, 48 | int goh, 49 | int gow, 50 | int gosb, 51 | int gosc, 52 | int gosh, 53 | int 
gosw, 54 | 55 | at::Tensor& input1, 56 | int ic, 57 | int ih, 58 | int iw, 59 | int isb, 60 | int isc, 61 | int ish, 62 | int isw, 63 | 64 | at::Tensor& input2, 65 | int gsb, 66 | int gsc, 67 | int gsh, 68 | int gsw, 69 | 70 | at::Tensor& gradInput1, 71 | int gisb, 72 | int gisc, 73 | int gish, 74 | int gisw, 75 | 76 | at::Tensor& gradInput2, 77 | int ggc, 78 | int ggsb, 79 | int ggsc, 80 | int ggsh, 81 | int ggsw, 82 | 83 | at::Tensor& rInput1, 84 | at::Tensor& rInput2, 85 | int pad_size, 86 | int kernel_size, 87 | int max_displacement, 88 | int stride1, 89 | int stride2, 90 | int corr_type_multiply, 91 | cudaStream_t stream); 92 | -------------------------------------------------------------------------------- /models/correlation_package/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | cxx_args = ["-std=c++11"] 6 | 7 | nvcc_args = [ 8 | "-gencode", 9 | "arch=compute_50,code=sm_50", 10 | "-gencode", 11 | "arch=compute_52,code=sm_52", 12 | "-gencode", 13 | "arch=compute_60,code=sm_60", 14 | "-gencode", 15 | "arch=compute_61,code=sm_61", 16 | "-gencode", 17 | "arch=compute_61,code=compute_61", 18 | "-ccbin", 19 | "/usr/bin/gcc-5", 20 | ] 21 | 22 | setup( 23 | name="correlation_cuda", 24 | ext_modules=[ 25 | CUDAExtension( 26 | "correlation_cuda", 27 | ["correlation_cuda.cc", "correlation_cuda_kernel.cu"], 28 | extra_compile_args={ 29 | "cxx": cxx_args, 30 | "nvcc": nvcc_args, 31 | "cuda-path": ["/usr/local/cuda-9.0"], 32 | }, 33 | ) 34 | ], 35 | cmdclass={"build_ext": BuildExtension}, 36 | ) 37 | -------------------------------------------------------------------------------- /models/get_model.py: -------------------------------------------------------------------------------- 1 | from .pwclite import PWCLite 2 | 3 | 4 | def get_model(cfg): 5 | if cfg.type == "pwclite": 6 | model = PWCLite(cfg) 7 | else: 8 | raise NotImplementedError(cfg.type) 9 | return model 10 | -------------------------------------------------------------------------------- /models/pwclite.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.correlation_internal import Correlation_ours 5 | 6 | # from .correlation_package.correlation import Correlation 7 | # from .correlation_native import Correlation 8 | 9 | from transforms.input_transforms import full_segs_to_adj_maps 10 | 11 | from utils.warp_utils import flow_warp 12 | 13 | 14 | def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True): 15 | if isReLU: 16 | return nn.Sequential( 17 | nn.Conv2d( 18 | in_planes, 19 | out_planes, 20 | kernel_size=kernel_size, 21 | stride=stride, 22 | dilation=dilation, 23 | padding=((kernel_size - 1) * dilation) // 2, 24 | bias=True, 25 | ), 26 | nn.LeakyReLU(0.1, inplace=True), 27 | ) 28 | else: 29 | return nn.Sequential( 30 | nn.Conv2d( 31 | in_planes, 32 | out_planes, 33 | kernel_size=kernel_size, 34 | stride=stride, 35 | dilation=dilation, 36 | padding=((kernel_size - 1) * dilation) // 2, 37 | bias=True, 38 | ) 39 | ) 40 | 41 | 42 | class FeatureExtractor(nn.Module): 43 | def __init__(self, num_chs, input_adj_map=False): 44 | super(FeatureExtractor, self).__init__() 45 | self.num_chs = num_chs 46 | self.convs = nn.ModuleList() 47 | 48 | if input_adj_map: 49 | self.adj_map_net = nn.Sequential( 50 | conv(81, 
32, kernel_size=1), 51 | conv(32, 32, kernel_size=3, stride=2), 52 | conv(32, 32, kernel_size=3), 53 | conv(32, 32, kernel_size=3, stride=2), 54 | conv(32, 32, kernel_size=3), 55 | ) 56 | else: 57 | self.adj_map_net = None 58 | 59 | for level, (ch_in, ch_out) in enumerate(zip(num_chs[:-1], num_chs[1:])): 60 | if input_adj_map and level == 2: 61 | ch_in += 32 62 | layer = nn.Sequential(conv(ch_in, ch_out, stride=2), conv(ch_out, ch_out)) 63 | self.convs.append(layer) 64 | 65 | def forward(self, x, adj_map=None): 66 | feature_pyramid = [x] 67 | if self.adj_map_net is not None: 68 | adj_map_feat = self.adj_map_net(adj_map) 69 | 70 | for i, conv in enumerate(self.convs): 71 | if self.adj_map_net is not None and i == 2: 72 | x = torch.concat((x, adj_map_feat), dim=1) 73 | x = conv(x) 74 | feature_pyramid.append(x) 75 | 76 | return feature_pyramid[::-1] 77 | 78 | 79 | class FlowEstimatorDense(nn.Module): 80 | def __init__(self, ch_in): 81 | super(FlowEstimatorDense, self).__init__() 82 | self.conv1 = conv(ch_in, 128) 83 | self.conv2 = conv(ch_in + 128, 128) 84 | self.conv3 = conv(ch_in + 256, 96) 85 | self.conv4 = conv(ch_in + 352, 64) 86 | self.conv5 = conv(ch_in + 416, 32) 87 | self.feat_dim = ch_in + 448 88 | self.conv_last = conv(ch_in + 448, 2, isReLU=False) 89 | 90 | def forward(self, x): 91 | x1 = torch.cat([self.conv1(x), x], dim=1) 92 | x2 = torch.cat([self.conv2(x1), x1], dim=1) 93 | x3 = torch.cat([self.conv3(x2), x2], dim=1) 94 | x4 = torch.cat([self.conv4(x3), x3], dim=1) 95 | x5 = torch.cat([self.conv5(x4), x4], dim=1) 96 | x_out = self.conv_last(x5) 97 | return x5, x_out 98 | 99 | 100 | class FlowEstimatorReduce(nn.Module): 101 | # can reduce 25% of training time. 102 | def __init__(self, ch_in): 103 | super(FlowEstimatorReduce, self).__init__() 104 | self.conv1 = conv(ch_in, 128) 105 | self.conv2 = conv(128, 128) 106 | self.conv3 = conv(128 + 128, 96) 107 | self.conv4 = conv(128 + 96, 64) 108 | self.conv5 = conv(96 + 64, 32) 109 | self.feat_dim = 32 110 | self.predict_flow = conv(64 + 32, 2, isReLU=False) 111 | 112 | def forward(self, x): 113 | x1 = self.conv1(x) 114 | x2 = self.conv2(x1) 115 | x3 = self.conv3(torch.cat([x1, x2], dim=1)) 116 | x4 = self.conv4(torch.cat([x2, x3], dim=1)) 117 | x5 = self.conv5(torch.cat([x3, x4], dim=1)) 118 | flow = self.predict_flow(torch.cat([x4, x5], dim=1)) 119 | return x5, flow 120 | 121 | 122 | class ContextNetwork(nn.Module): 123 | def __init__(self, ch_in): 124 | super(ContextNetwork, self).__init__() 125 | 126 | self.convs = nn.Sequential( 127 | conv(ch_in, 128, 3, 1, 1), 128 | conv(128, 128, 3, 1, 2), 129 | conv(128, 128, 3, 1, 4), 130 | conv(128, 96, 3, 1, 8), 131 | ) 132 | self.flow_head = nn.Sequential( 133 | conv(96, 64, 3, 1, 16), conv(64, 32, 3, 1, 1), conv(32, 2, isReLU=False) 134 | ) 135 | 136 | def forward(self, x): 137 | feat = self.convs(x) 138 | flow = self.flow_head(feat) 139 | return flow, feat 140 | 141 | 142 | class UpFlowNetwork(nn.Module): 143 | def __init__(self, ch_in=96, scale_factor=4): 144 | super(UpFlowNetwork, self).__init__() 145 | self.convs = nn.Sequential( 146 | conv(ch_in, 128, 3, 1, 1), conv(128, scale_factor**2 * 9, 3, 1, 1) 147 | ) 148 | 149 | # adapted from https://github.com/princeton-vl/RAFT/blob/aac9dd54726caf2cf81d8661b07663e220c5586d/core/raft.py#L72 150 | def upsample_flow(self, flow, mask): 151 | """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" 152 | N, _, H, W = flow.shape 153 | mask = mask.view(N, 1, 9, 4, 4, H, W) 154 | mask = torch.softmax(mask, dim=2) 155 | 156 | 
up_flow = F.unfold(4 * flow, [3, 3], padding=1) 157 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 158 | 159 | up_flow = torch.sum(mask * up_flow, dim=2) 160 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 161 | return up_flow.reshape(N, 2, 4 * H, 4 * W) 162 | 163 | def forward(self, flow, feat): 164 | # scale mask to balence gradients 165 | up_mask = 0.25 * self.convs(feat) 166 | return self.upsample_flow(flow, up_mask) 167 | 168 | 169 | class PWCLite(nn.Module): 170 | def __init__(self, cfg): 171 | super(PWCLite, self).__init__() 172 | if "input_adj_map" not in cfg: 173 | cfg.input_adj_map = False 174 | 175 | if "input_boundary" not in cfg: 176 | cfg.input_boundary = False 177 | 178 | if "add_mask_corr" not in cfg: 179 | cfg.add_mask_corr = False 180 | 181 | self.cfg = cfg 182 | self.search_range = 4 183 | self.num_chs = [3, 16, 32, 64, 96, 128, 192] 184 | if cfg.input_boundary: 185 | self.num_chs[0] += 2 186 | 187 | self.output_level = 4 188 | self.num_levels = 7 189 | self.leakyRELU = nn.LeakyReLU(0.1, inplace=True) 190 | 191 | # encoder 192 | self.feature_pyramid_extractor = FeatureExtractor( 193 | self.num_chs, input_adj_map=cfg.input_adj_map 194 | ) 195 | 196 | # decoder 197 | ## Our correlation implementation 198 | self.corr = Correlation_ours( 199 | kernel_size=1, 200 | patch_size=(2 * self.search_range + 1), 201 | stride=1, 202 | padding=0, 203 | dilation_patch=1, 204 | normalize=True, 205 | ) 206 | 207 | ## Correlation modeuld in the original code 208 | # self.corr = Correlation( 209 | # pad_size=self.search_range, 210 | # kernel_size=1, 211 | # max_displacement=self.search_range, 212 | # stride1=1, 213 | # stride2=1, 214 | # corr_multiply=1, 215 | # ) 216 | 217 | self.dim_corr = (self.search_range * 2 + 1) ** 2 218 | 219 | if cfg.add_mask_corr: 220 | self.num_ch_in = 32 + 2 * self.dim_corr + 2 221 | else: 222 | self.num_ch_in = 32 + self.dim_corr + 2 223 | 224 | if cfg.reduce_dense: 225 | self.flow_estimators = FlowEstimatorReduce(self.num_ch_in) 226 | else: 227 | self.flow_estimators = FlowEstimatorDense(self.num_ch_in) 228 | 229 | self.context_networks = ContextNetwork(self.flow_estimators.feat_dim + 2) 230 | 231 | if cfg.learned_upsampler: 232 | self.output_flow_upsampler = UpFlowNetwork(ch_in=96, scale_factor=4) 233 | else: 234 | self.output_flow_upsampler = None 235 | 236 | self.conv_1x1 = nn.ModuleList( 237 | [ 238 | conv(self.num_chs[-1], 32, kernel_size=1, stride=1, dilation=1), 239 | conv(self.num_chs[-2], 32, kernel_size=1, stride=1, dilation=1), 240 | conv(self.num_chs[-3], 32, kernel_size=1, stride=1, dilation=1), 241 | conv(self.num_chs[-4], 32, kernel_size=1, stride=1, dilation=1), 242 | conv(self.num_chs[-5], 32, kernel_size=1, stride=1, dilation=1), 243 | ] 244 | ) 245 | 246 | if cfg.add_mask_corr: 247 | self.conv_1x1_mask = nn.ModuleList( 248 | [ 249 | conv(self.num_chs[-1], 32, kernel_size=1, stride=1, dilation=1), 250 | conv(self.num_chs[-2], 32, kernel_size=1, stride=1, dilation=1), 251 | conv(self.num_chs[-3], 32, kernel_size=1, stride=1, dilation=1), 252 | conv(self.num_chs[-4], 32, kernel_size=1, stride=1, dilation=1), 253 | conv(self.num_chs[-5], 32, kernel_size=1, stride=1, dilation=1), 254 | ] 255 | ) 256 | 257 | if self.cfg.aggregation_type == "residual": 258 | self.mask_aggregation = conv( 259 | 32, 32, kernel_size=1, stride=1, dilation=1 260 | ) 261 | elif self.cfg.aggregation_type == "concat": 262 | self.mask_aggregation = conv( 263 | 64, 32, kernel_size=1, stride=1, dilation=1 264 | ) 265 | 266 | def num_parameters(self): 267 | return sum( 268 | 
[p.data.nelement() if p.requires_grad else 0 for p in self.parameters()] 269 | ) 270 | 271 | def init_weights(self): 272 | for layer in self.named_modules(): 273 | if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.ConvTranspose2d): 274 | nn.init.kaiming_normal_(layer.weight) 275 | if layer.bias is not None: 276 | nn.init.constant_(layer.bias, 0) 277 | 278 | def decoder(self, x1_pyramid, x2_pyramid, full_seg1=None, full_seg2=None): 279 | # outputs 280 | flows = [] 281 | 282 | # init 283 | ( 284 | b_size, 285 | _, 286 | h_x1, 287 | w_x1, 288 | ) = x1_pyramid[0].size() 289 | init_dtype = x1_pyramid[0].dtype 290 | init_device = x1_pyramid[0].device 291 | flow = torch.zeros( 292 | b_size, 2, h_x1, w_x1, dtype=init_dtype, device=init_device 293 | ).float() 294 | 295 | for level, (x1, x2) in enumerate(zip(x1_pyramid, x2_pyramid)): 296 | 297 | # warping 298 | if level > 0: 299 | flow = F.interpolate( 300 | flow * 2, scale_factor=2, mode="bilinear", align_corners=True 301 | ) 302 | x2_warp = flow_warp(x2, flow) 303 | else: 304 | x2_warp = x2 305 | 306 | # correlation 307 | out_corr = self.corr(x1, x2_warp) 308 | out_corr_relu = self.leakyRELU(out_corr) 309 | 310 | # import IPython 311 | 312 | # IPython.embed() 313 | # exit() 314 | 315 | x1_1by1 = self.conv_1x1[level](x1) 316 | 317 | if self.cfg.add_mask_corr: 318 | x1_1by1_mask = self.conv_1x1_mask[level](x1) 319 | full_seg1_down = F.interpolate( 320 | full_seg1, x1.shape[-2:], mode="nearest" 321 | ).long() 322 | full_seg1_down_oh = F.one_hot(full_seg1_down) 323 | mask_pooled_value1 = torch.amax( 324 | full_seg1_down_oh * x1_1by1_mask[..., None], dim=(2, 3) 325 | ) 326 | mask_feat1 = ( 327 | full_seg1_down_oh * mask_pooled_value1[:, :, None, None, :] 328 | ).sum(dim=-1) 329 | 330 | x2_1by1_mask = self.conv_1x1_mask[level](x2) 331 | full_seg2_down = F.interpolate( 332 | full_seg2, x2.shape[-2:], mode="nearest" 333 | ).long() 334 | full_seg2_down_oh = F.one_hot(full_seg2_down) 335 | mask_pooled_value2 = torch.amax( 336 | full_seg2_down_oh * x2_1by1_mask[..., None], dim=(2, 3) 337 | ) 338 | mask_feat2 = ( 339 | full_seg2_down_oh * mask_pooled_value2[:, :, None, None, :] 340 | ).sum(dim=-1) 341 | 342 | if self.cfg.aggregation_type == "residual": 343 | x_mask_feat1 = x1_1by1_mask + self.mask_aggregation(mask_feat1) 344 | x_mask_feat2 = x2_1by1_mask + self.mask_aggregation(mask_feat2) 345 | elif self.cfg.aggregation_type == "concat": 346 | x_mask_feat1 = self.mask_aggregation( 347 | torch.concat((x1_1by1_mask, mask_feat1), axis=1) 348 | ) 349 | x_mask_feat2 = self.mask_aggregation( 350 | torch.concat((x2_1by1_mask, mask_feat2), axis=1) 351 | ) 352 | else: 353 | raise NotImplementedError 354 | 355 | x_mask_feat2_warp = flow_warp(x_mask_feat2, flow) 356 | out_mask_corr = self.corr(x_mask_feat1, x_mask_feat2_warp) 357 | out_mask_corr_relu = self.leakyRELU(out_mask_corr) 358 | 359 | x_intm, flow_res = self.flow_estimators( 360 | torch.cat([out_corr_relu, out_mask_corr_relu, x1_1by1, flow], dim=1) 361 | ) 362 | 363 | else: 364 | x_intm, flow_res = self.flow_estimators( 365 | torch.cat([out_corr_relu, x1_1by1, flow], dim=1) 366 | ) 367 | 368 | flow = flow + flow_res 369 | 370 | flow_fine, up_feat = self.context_networks(torch.cat([x_intm, flow], dim=1)) 371 | flow = flow + flow_fine 372 | 373 | if self.output_flow_upsampler is not None: 374 | flow_up = self.output_flow_upsampler(flow, up_feat) 375 | else: 376 | flow_up = F.interpolate( 377 | flow * 4, scale_factor=4, mode="bilinear", align_corners=True 378 | ) 379 | flows.append(flow_up) 380 | 381 | # 
upsampling or post-processing 382 | if level == self.output_level: 383 | break 384 | 385 | return flows[::-1] 386 | 387 | def forward(self, img1, img2, full_seg1=None, full_seg2=None, with_bk=False): 388 | 389 | batch_size, _, h, w = img1.shape 390 | 391 | if self.cfg.input_adj_map: 392 | adj_maps = full_segs_to_adj_maps( 393 | torch.concat((full_seg1, full_seg2), axis=0) 394 | ) 395 | adj_map1 = adj_maps[:batch_size] 396 | adj_map2 = adj_maps[batch_size:] 397 | else: 398 | adj_map1, adj_map2 = None, None 399 | 400 | if self.cfg.input_boundary: 401 | 402 | def compute_seg_edge(full_seg): 403 | batch_size, _, h, w = full_seg.shape 404 | seg_edge_x = (full_seg[..., :, 1:] != full_seg[..., :, :-1]).float() 405 | seg_edge_x = torch.concat( 406 | ( 407 | seg_edge_x, 408 | torch.zeros((batch_size, 1, h, 1)).to(seg_edge_x.device), 409 | ), 410 | axis=-1, 411 | ) 412 | seg_edge_y = (full_seg[..., 1:, :] != full_seg[..., :-1, :]).float() 413 | seg_edge_y = torch.concat( 414 | ( 415 | seg_edge_y, 416 | torch.zeros((batch_size, 1, 1, w)).to(seg_edge_x.device), 417 | ), 418 | axis=-2, 419 | ) 420 | return seg_edge_x, seg_edge_y 421 | 422 | img1 = torch.concat((img1, *compute_seg_edge(full_seg1)), axis=1) 423 | img2 = torch.concat((img2, *compute_seg_edge(full_seg2)), axis=1) 424 | 425 | feat1 = self.feature_pyramid_extractor(img1, adj_map1) 426 | feat2 = self.feature_pyramid_extractor(img2, adj_map2) 427 | 428 | # decode outputs 429 | res_dict = {} 430 | res_dict["flows_12"] = self.decoder(feat1, feat2, full_seg1, full_seg2) 431 | if with_bk: 432 | res_dict["flows_21"] = self.decoder(feat2, feat1, full_seg2, full_seg1) 433 | 434 | return res_dict 435 | -------------------------------------------------------------------------------- /sam_inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | Adapted from https://github.com/facebookresearch/segment-anything/blob/main/scripts/amg.py 5 | """ 6 | 7 | import argparse 8 | import json 9 | import os 10 | from typing import Any, Dict, List 11 | 12 | import cv2 # type: ignore 13 | 14 | # from utils.manifold_utils import pathmgr 15 | import numpy as np 16 | 17 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator 18 | from tqdm import tqdm 19 | 20 | parser = argparse.ArgumentParser( 21 | description=( 22 | "Runs automatic mask generation on an input image or directory of images, " 23 | "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, " 24 | "as well as pycocotools if saving in RLE format." 25 | ) 26 | ) 27 | 28 | # parser.add_argument( 29 | # "--input", 30 | # type=str, 31 | # required=True, 32 | # help="Path to either a single input image or folder of images.", 33 | # ) 34 | 35 | parser.add_argument( 36 | "--dataset", 37 | type=str, 38 | required=True, 39 | help="The dataset for inference.", 40 | ) 41 | 42 | parser.add_argument( 43 | "--output", 44 | type=str, 45 | required=True, 46 | help=( 47 | "Path to the directory where masks will be output. Output will be either a folder " 48 | "of PNGs per image or a single json with COCO-style masks." 
49 | ), 50 | ) 51 | 52 | parser.add_argument( 53 | "--model-type", 54 | type=str, 55 | required=True, 56 | help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']", 57 | ) 58 | 59 | parser.add_argument( 60 | "--checkpoint", 61 | type=str, 62 | required=True, 63 | help="The path to the SAM checkpoint to use for mask generation.", 64 | ) 65 | 66 | parser.add_argument( 67 | "--device", type=str, default="cuda", help="The device to run generation on." 68 | ) 69 | 70 | parser.add_argument( 71 | "--convert-to-rle", 72 | action="store_true", 73 | help=( 74 | "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. " 75 | "Requires pycocotools." 76 | ), 77 | ) 78 | 79 | amg_settings = parser.add_argument_group("AMG Settings") 80 | 81 | amg_settings.add_argument( 82 | "--points-per-side", 83 | type=int, 84 | default=None, 85 | help="Generate masks by sampling a grid over the image with this many points to a side.", 86 | ) 87 | 88 | amg_settings.add_argument( 89 | "--points-per-batch", 90 | type=int, 91 | default=None, 92 | help="How many input points to process simultaneously in one batch.", 93 | ) 94 | 95 | amg_settings.add_argument( 96 | "--pred-iou-thresh", 97 | type=float, 98 | default=None, 99 | help="Exclude masks with a predicted score from the model that is lower than this threshold.", 100 | ) 101 | 102 | amg_settings.add_argument( 103 | "--stability-score-thresh", 104 | type=float, 105 | default=None, 106 | help="Exclude masks with a stability score lower than this threshold.", 107 | ) 108 | 109 | amg_settings.add_argument( 110 | "--stability-score-offset", 111 | type=float, 112 | default=None, 113 | help="Larger values perturb the mask more when measuring stability score.", 114 | ) 115 | 116 | amg_settings.add_argument( 117 | "--box-nms-thresh", 118 | type=float, 119 | default=None, 120 | help="The overlap threshold for excluding a duplicate mask.", 121 | ) 122 | 123 | amg_settings.add_argument( 124 | "--crop-n-layers", 125 | type=int, 126 | default=None, 127 | help=( 128 | "If >0, mask generation is run on smaller crops of the image to generate more masks. " 129 | "The value sets how many different scales to crop at." 130 | ), 131 | ) 132 | 133 | amg_settings.add_argument( 134 | "--crop-nms-thresh", 135 | type=float, 136 | default=None, 137 | help="The overlap threshold for excluding duplicate masks across different crops.", 138 | ) 139 | 140 | amg_settings.add_argument( 141 | "--crop-overlap-ratio", 142 | type=int, 143 | default=None, 144 | help="Larger numbers mean image crops will overlap more.", 145 | ) 146 | 147 | amg_settings.add_argument( 148 | "--crop-n-points-downscale-factor", 149 | type=int, 150 | default=None, 151 | help="The number of points-per-side in each layer of crop is reduced by this factor.", 152 | ) 153 | 154 | amg_settings.add_argument( 155 | "--min-mask-region-area", 156 | type=int, 157 | default=None, 158 | help=( 159 | "Disconnected mask regions or holes with area smaller than this value " 160 | "in pixels are removed by postprocessing." 
161 | ), 162 | ) 163 | 164 | 165 | def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None: 166 | header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa 167 | metadata = [header] 168 | for i, mask_data in enumerate(masks): 169 | mask = mask_data["segmentation"] 170 | filename = f"{i}.png" 171 | cv2.imwrite(os.path.join(path, filename), mask * 255) 172 | mask_metadata = [ 173 | str(i), 174 | str(mask_data["area"]), 175 | *[str(x) for x in mask_data["bbox"]], 176 | *[str(x) for x in mask_data["point_coords"][0]], 177 | str(mask_data["predicted_iou"]), 178 | str(mask_data["stability_score"]), 179 | *[str(x) for x in mask_data["crop_box"]], 180 | ] 181 | row = ",".join(mask_metadata) 182 | metadata.append(row) 183 | metadata_path = os.path.join(path, "metadata.csv") 184 | with open(metadata_path, "w") as f: 185 | f.write("\n".join(metadata)) 186 | 187 | return 188 | 189 | 190 | def get_amg_kwargs(args): 191 | amg_kwargs = { 192 | "points_per_side": args.points_per_side, 193 | "points_per_batch": args.points_per_batch, 194 | "pred_iou_thresh": args.pred_iou_thresh, 195 | "stability_score_thresh": args.stability_score_thresh, 196 | "stability_score_offset": args.stability_score_offset, 197 | "box_nms_thresh": args.box_nms_thresh, 198 | "crop_n_layers": args.crop_n_layers, 199 | "crop_nms_thresh": args.crop_nms_thresh, 200 | "crop_overlap_ratio": args.crop_overlap_ratio, 201 | "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor, 202 | "min_mask_region_area": args.min_mask_region_area, 203 | } 204 | amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None} 205 | return amg_kwargs 206 | 207 | 208 | def main(args: argparse.Namespace) -> None: 209 | 210 | print("Loading model...") 211 | sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint) 212 | _ = sam.to(device=args.device) 213 | output_mode = "coco_rle" if args.convert_to_rle else "binary_mask" 214 | amg_kwargs = get_amg_kwargs(args) 215 | generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs) 216 | 217 | # if not os.path.isdir(args.input): 218 | # targets = [args.input] 219 | # else: 220 | # targets = [ 221 | # f 222 | # for f in os.listdir(args.input) 223 | # if not os.path.isdir(os.path.join(args.input, f)) 224 | # ] 225 | # targets = [os.path.join(args.input, f) for f in targets] 226 | 227 | if args.dataset == "KITTI-2015" or args.dataset == "KITTI-2012": 228 | dataset_root = ( 229 | YOUR_DIR + args.dataset 230 | ) 231 | 232 | targets = [] 233 | for split in ["training", "testing"]: 234 | with open(os.path.join(dataset_root, split, "image_list.txt"), "r") as f: 235 | line = f.readlines()[0] 236 | line = line.split(" ") 237 | targets += [os.path.join(split, t) for t in line] 238 | 239 | elif args.dataset == "KITTI-raw": 240 | dataset_root = YOUR_DIR 241 | 242 | targets = [] 243 | with open(os.path.join(dataset_root, "kitti_train_2f_sv.txt"), "r") as f: 244 | lines = f.readlines() 245 | 246 | for line in lines: 247 | targets += line.split() 248 | targets = np.unique(targets).tolist() 249 | 250 | elif args.dataset == "Sintel": 251 | dataset_root = YOUR_DIR 252 | 253 | targets = [] 254 | for split in ["training", "test"]: 255 | with open(os.path.join(dataset_root, split, "image_list.txt"), "r") as f: 256 | line = f.readlines()[0] 257 | line = line.split(" ") 258 | targets += [os.path.join(split, t) for t in line] 259 | 260 | elif args.dataset == 
"Sintel-raw": 261 | dataset_root = YOUR_DIR 262 | 263 | targets = [] 264 | with open(os.path.join(dataset_root, "sample_list.txt"), "r") as f: 265 | lines = f.readlines() 266 | 267 | for line in lines: 268 | targets += line.split() 269 | targets = np.unique(targets).tolist() 270 | 271 | else: 272 | raise ValueError(f"Unknown dataset: {args.dataset}") 273 | 274 | os.makedirs(os.path.join(args.output, args.dataset), exist_ok=True) 275 | 276 | for t in tqdm(targets): 277 | print(f"Processing '{t}'...") 278 | image = cv2.imread(os.path.join(dataset_root, t)) 279 | if image is None: 280 | print(f"Could not load '{t}' as an image, skipping...") 281 | continue 282 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 283 | 284 | masks = generator.generate(image) 285 | 286 | base = os.path.splitext(t)[0] 287 | save_base = os.path.join(args.output, args.dataset, base) 288 | if not os.path.exists(os.path.dirname(save_base)): 289 | os.makedirs(os.path.dirname(save_base), exist_ok=True) 290 | if output_mode == "binary_mask": 291 | os.makedirs(save_base, exist_ok=False) 292 | write_masks_to_folder(masks, save_base) 293 | else: 294 | save_file = save_base + ".json" 295 | with open(save_file, "w") as f: 296 | json.dump(masks, f) 297 | print("Done!") 298 | 299 | 300 | def main_mask_to_full_seg(): 301 | home_dir = YOUR_DIR 302 | 303 | import imageio 304 | from pycocotools import mask as mask_utils 305 | 306 | # ds = "KITTI-raw" 307 | # ds = "Sintel-raw" 308 | ds = "KITTI-2012/training" 309 | 310 | with open("{}/data/{}/image_list_mv.txt".format(home_dir, ds), "r") as f: 311 | lines = f.readlines() 312 | img_list = [line.strip() for line in lines] 313 | 314 | for img_name in tqdm(img_list): 315 | with open( 316 | "{}/results/sam_results/raw/{}/{}.json".format(home_dir, ds, img_name[:-4]), 317 | "r", 318 | ) as f: 319 | masks = json.load(f) 320 | 321 | masks_map = np.array( 322 | mask_utils.decode([mask["segmentation"] for mask in masks]), 323 | dtype=np.float32, 324 | ) 325 | 326 | H, W = masks_map.shape[:2] 327 | masks_area = np.array([mask["area"] for mask in masks]) 328 | 329 | # drop mask if it equals the full frame 330 | masks_map = masks_map[:, :, masks_area < H * W] 331 | masks_area = masks_area[masks_area < H * W] 332 | 333 | # sort the class ids by area, largest to smallest 334 | area_order = np.argsort(masks_area)[::-1] 335 | masks_area = masks_area[area_order] 336 | masks_map = masks_map[:, :, area_order] 337 | 338 | # add a "background mask" for pixels that are not included in any masks 339 | masks_map_aug = np.concatenate((np.ones((H, W, 1)), masks_map), axis=-1) 340 | masks_area_aug = np.array([H * W] + masks_area.tolist()) 341 | masks_area_aug = np.array(masks_area_aug, dtype=np.float32) 342 | 343 | unified_mask = np.argmin( 344 | masks_map_aug * masks_area_aug[None, None, :] 345 | + (1 - masks_map_aug) * (H * W + 1), 346 | axis=-1, 347 | ) 348 | 349 | unique_classes = np.unique(unified_mask) 350 | mapping = np.zeros((unique_classes.max() + 1)) 351 | for i, cl in enumerate(unique_classes): 352 | mapping[cl] = i 353 | new_mask = mapping[unified_mask] 354 | 355 | if new_mask.max() > 255: # almost not existent 356 | print("More than 256 masks detect for image {}".format(img_name)) 357 | new_mask[new_mask > 255] = 0 358 | new_mask = new_mask.astype(np.uint8) 359 | 360 | save_path = "{}/results/sam_results/full_seg/{}/{}".format( 361 | home_dir, ds, img_name 362 | ) 363 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 364 | imageio.imwrite(save_path, new_mask) 365 | 366 | 367 | def 
main_mask_to_key_objects(): 368 | home_dir = YOUR_DIR 369 | 370 | from pycocotools import mask as mask_utils 371 | 372 | # ds = "KITTI-raw" 373 | # ds = "Sintel-raw" 374 | # ds = "KITTI-2015/training" 375 | ds = "KITTI-2012/training" 376 | # ds = "Sintel/training" 377 | 378 | with open("{}/data/{}/image_list_mv.txt".format(home_dir, ds), "r") as f: 379 | lines = f.readlines() 380 | img_list = [line.strip() for line in lines] 381 | 382 | for img_name in tqdm(img_list): 383 | with open( 384 | "{}/results/sam_results/raw/{}/{}.json".format(home_dir, ds, img_name[:-4]), 385 | "r", 386 | ) as f: 387 | masks = json.load(f) 388 | 389 | masks_map = np.array( 390 | mask_utils.decode([mask["segmentation"] for mask in masks]), 391 | dtype=np.float32, 392 | ) 393 | H, W = masks_map.shape[:2] 394 | obj_masks = np.zeros((H, W, 0), dtype=np.uint8) 395 | 396 | for mask_id in range(len(masks)): 397 | mask = masks_map[:, :, mask_id] 398 | w, h = masks[mask_id]["bbox"][2:4] 399 | area = masks[mask_id]["area"] 400 | 401 | if not (50 <= h <= 200 and 50 <= w <= 300): 402 | continue 403 | 404 | if area / (h * w) < 0.5: 405 | continue 406 | 407 | num_unique_masks = ((masks_map * mask[:, :, None]).sum((0, 1)) > 0).sum() 408 | if num_unique_masks >= 6: 409 | obj_masks = np.concatenate( 410 | (obj_masks, (mask[:, :, None] * 255).astype(np.uint8)), axis=-1 411 | ) 412 | 413 | save_path = "{}/results/sam_results/key_objects/{}/{}.npy".format( 414 | home_dir, ds, img_name[:-4] 415 | ) 416 | if not os.path.exists(os.path.dirname(save_path)): 417 | os.makedirs(os.path.dirname(save_path)) 418 | np.save(save_path, obj_masks) 419 | 420 | 421 | def invoke_main() -> None: 422 | args = parser.parse_args() 423 | main(args) 424 | 425 | # main_mask_to_full_seg() 426 | 427 | # main_mask_to_key_objects() 428 | 429 | 430 | if __name__ == "__main__": 431 | invoke_main() # pragma: no cover 432 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | """ 4 | 5 | import argparse 6 | 7 | import os 8 | 9 | import torch 10 | from datasets.flow_datasets import KITTIFlowEval, Sintel 11 | 12 | from models.get_model import get_model 13 | 14 | from torchvision import transforms 15 | from tqdm import tqdm 16 | from transforms import input_transforms 17 | from utils.config_parser import init_config 18 | from utils.flow_utils import resize_flow, writeFlowKITTI, writeFlowSintel 19 | from utils.manifold_utils import MANIFOLD_BUCKET, MANIFOLD_PATH, pathmgr 20 | from utils.torch_utils import restore_model 21 | 22 | parser = argparse.ArgumentParser( 23 | description="create_submission", 24 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 25 | ) 26 | parser.add_argument( 27 | "--model-folder", 28 | required=True, 29 | type=str, 30 | help="the model folder (that contains the configuration file)", 31 | ) 32 | parser.add_argument( 33 | "--output-dir", 34 | default=None, 35 | type=str, 36 | help="Output directory; default is test_flow under the folder of the model", 37 | ) 38 | parser.add_argument( 39 | "--trained-model", 40 | required=True, 41 | default="model_ckpt.pth.tar", 42 | type=str, 43 | help="trained model path in the model folder", 44 | ) 45 | parser.add_argument( 46 | "--dataset", type=str, choices=["sintel", "kitti"], help="sintel/kitti" 47 | ) 48 | parser.add_argument( 49 | "--subset", type=str, default="test", choices=["train", "test"], help="train/test" 50 | ) 51 | 52 | 53 | def tensor2array(tensor): 54 | return tensor.detach().cpu().numpy().transpose([0, 2, 3, 1]) 55 | 56 | 57 | @torch.no_grad() 58 | def create_sintel_submission(model, args): 59 | """Create submission for the Sintel leaderboard""" 60 | 61 | input_transform = transforms.Compose( 62 | [ 63 | input_transforms.Zoom(args.img_height, args.img_width), 64 | input_transforms.ArrayToTensor(), 65 | ] 66 | ) 67 | 68 | # start inference 69 | model.eval() 70 | for dstype in ["final", "clean"]: 71 | # ds_dir = os.path.join(args.output_dir, dstype) 72 | ds_dir_local = os.path.join(args.output_local_dir, dstype) 73 | ds_dir_bw_local = os.path.join(args.output_local_dir + "_bw", dstype) 74 | # pathmgr.mkdirs(ds_dir) 75 | os.makedirs(ds_dir_local, exist_ok=True) 76 | os.makedirs(ds_dir_bw_local, exist_ok=True) 77 | 78 | dataset = Sintel( 79 | args.root_sintel, 80 | args.full_seg_root_sintel, 81 | None, 82 | name="sintel-" + dstype, 83 | dataset_type=dstype, 84 | split=args.subset, 85 | with_flow=False, 86 | input_transform=input_transform, 87 | ) 88 | data_loader = torch.utils.data.DataLoader( 89 | dataset, batch_size=4, pin_memory=True, shuffle=False 90 | ) 91 | 92 | for data in tqdm(data_loader): 93 | img1, img2 = data["img1"].cuda(), data["img2"].cuda() 94 | full_seg1, full_seg2 = data["full_seg1"].cuda(), data["full_seg2"].cuda() 95 | 96 | # compute output 97 | output = model(img1, img2, full_seg1, full_seg2, with_bk=True) 98 | flow_pred = output["flows_12"][0] 99 | flow_pred_bw = output["flows_21"][0] 100 | 101 | for i in range(flow_pred.shape[0]): 102 | 103 | h, w = data["raw_size"][0][i], data["raw_size"][1][i] 104 | h, w = h.item(), w.item() 105 | flow_pred_up = resize_flow(flow_pred[i : (i + 1)], (h, w)) 106 | 107 | scene, frame_id = data["img1_path"][i].split("/")[-2:] 108 | filename = frame_id[:5] + frame_id[6:10] + ".flo" 109 | # output_file = os.path.join(ds_dir, scene, filename) 110 | output_file_local = os.path.join(ds_dir_local, scene, filename) 111 | 112 | # wrtie to local and then move to manifold 113 | writeFlowSintel(output_file_local, 
tensor2array(flow_pred_up)[0]) 114 | 115 | ## also compute backward flow 116 | flow_pred_bw_up = resize_flow(flow_pred_bw[i : (i + 1)], (h, w)) 117 | output_file_local = os.path.join(ds_dir_bw_local, scene, filename) 118 | writeFlowSintel(output_file_local, tensor2array(flow_pred_bw_up)[0]) 119 | 120 | # if not pathmgr.exists(os.path.dirname(output_file)): 121 | # pathmgr.mkdirs(os.path.dirname(output_file)) 122 | # pathmgr.copy_from_local(output_file_local, output_file) 123 | 124 | print("Completed!") 125 | return 126 | 127 | 128 | @torch.no_grad() 129 | def create_kitti_submission(model, args): 130 | """Create submission for the KITTI leaderboard""" 131 | 132 | input_transform = transforms.Compose( 133 | [ 134 | input_transforms.Zoom(args.img_height, args.img_width), 135 | input_transforms.ArrayToTensor(), 136 | ] 137 | ) 138 | 139 | dataset_2012 = KITTIFlowEval( 140 | os.path.join(args.root_kitti12, args.subset + "ing"), 141 | os.path.join(args.full_seg_root_kitti12, args.subset + "ing"), 142 | None, 143 | name="kitti2012", 144 | input_transform=input_transform, 145 | test_mode=True, 146 | ) 147 | dataset_2015 = KITTIFlowEval( 148 | os.path.join(args.root_kitti15, args.subset + "ing"), 149 | os.path.join(args.full_seg_root_kitti15, args.subset + "ing"), 150 | None, 151 | name="kitti2015", 152 | input_transform=input_transform, 153 | test_mode=True, 154 | ) 155 | 156 | # start inference 157 | model.eval() 158 | for ds in [dataset_2015, dataset_2012]: 159 | # ds_dir = os.path.join(args.output_dir, ds.name) 160 | ds_dir_local = os.path.join(args.output_local_dir, ds.name) 161 | ds_dir_bw_local = os.path.join(args.output_local_dir + "_bw", ds.name) 162 | # pathmgr.mkdirs(os.path.join(ds_dir, "flow")) 163 | os.makedirs(os.path.join(ds_dir_local, "flow"), exist_ok=True) 164 | os.makedirs(os.path.join(ds_dir_bw_local, "flow"), exist_ok=True) 165 | 166 | data_loader = torch.utils.data.DataLoader( 167 | ds, batch_size=4, pin_memory=True, shuffle=False 168 | ) 169 | for data in tqdm(data_loader): 170 | 171 | img1, img2 = data["img1"].cuda(), data["img2"].cuda() 172 | full_seg1, full_seg2 = data["full_seg1"].cuda(), data["full_seg2"].cuda() 173 | 174 | # compute output 175 | output = model(img1, img2, full_seg1, full_seg2, with_bk=True) 176 | flow_pred = output["flows_12"][0] 177 | flow_pred_bw = output["flows_21"][0] 178 | 179 | for i in range(flow_pred.shape[0]): 180 | h, w = data["raw_size"][0][i], data["raw_size"][1][i] 181 | h, w = h.item(), w.item() 182 | flow_pred_up = resize_flow(flow_pred[i : (i + 1)], (h, w)) 183 | 184 | filename = os.path.basename(data["img1_path"][i]) 185 | # output_file = os.path.join(ds_dir, "flow", filename) 186 | output_file_local = os.path.join(ds_dir_local, "flow", filename) 187 | 188 | # write to local and then move to manifold 189 | writeFlowKITTI(output_file_local, tensor2array(flow_pred_up)[0]) 190 | # pathmgr.copy_from_local(output_file_local, output_file) 191 | 192 | ## also compute backward flow 193 | flow_pred_bw_up = resize_flow(flow_pred_bw[i : (i + 1)], (h, w)) 194 | output_file_local = os.path.join(ds_dir_bw_local, "flow", filename) 195 | writeFlowKITTI(output_file_local, tensor2array(flow_pred_bw_up)[0]) 196 | 197 | print("Completed!") 198 | return 199 | 200 | 201 | @torch.no_grad() 202 | def main(): 203 | args = parser.parse_args() 204 | 205 | args.full_model_folder = os.path.join( 206 | "memcache_manifold://", MANIFOLD_BUCKET, MANIFOLD_PATH, args.model_folder 207 | ) 208 | 209 | if args.output_dir is None: 210 | args.output_dir = os.path.join( 211 
| args.full_model_folder, args.subset + "_flow_" + args.dataset 212 | ) 213 | args.output_local_dir = os.path.join( 214 | YOUR_DIR, 215 | args.model_folder, 216 | args.subset + "_flow_" + args.dataset, 217 | ) 218 | 219 | # pathmgr.mkdirs(args.output_dir) 220 | os.makedirs(args.output_local_dir, exist_ok=True) 221 | 222 | ## set up the model 223 | config_file = os.path.join(args.full_model_folder, "config.json") 224 | model_file = os.path.join(args.full_model_folder, args.trained_model) 225 | cfg = init_config(config_file) 226 | 227 | model = get_model(cfg.model).cuda() 228 | 229 | model = restore_model(model, model_file) 230 | model.eval() 231 | 232 | if args.dataset == "sintel": 233 | args.img_height, args.img_width = 448, 1024 234 | 235 | # Use local data to save time 236 | args.root_sintel = YOUR_DIR 237 | args.full_seg_root_sintel = YOUR_DIR 238 | 239 | create_sintel_submission(model, args) 240 | elif args.dataset == "kitti": 241 | args.img_height, args.img_width = 256, 832 242 | 243 | # Use local data to save time 244 | args.root_kitti12 = YOUR_DIR 245 | args.root_kitti15 = YOUR_DIR 246 | args.full_seg_root_kitti12 = YOUR_DIR 247 | args.full_seg_root_kitti15 = YOUR_DIR 248 | 249 | create_kitti_submission(model, args) 250 | 251 | 252 | if __name__ == "__main__": 253 | main() 254 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | """ 4 | 5 | import datetime 6 | 7 | curr_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 8 | 9 | import argparse 10 | import os 11 | import pprint 12 | 13 | import torch 14 | 15 | from utils.config_parser import init_config 16 | 17 | # from utils.logger import init_logger 18 | 19 | torch.backends.cudnn.benchmark = True 20 | import numpy as np 21 | 22 | import pkg_resources 23 | 24 | from datasets.get_dataset import get_dataset 25 | 26 | from fblearner.flow.util.visualization_utils import summary_writer 27 | 28 | from losses.get_loss import get_loss 29 | 30 | from models.get_model import get_model 31 | 32 | from trainer.get_trainer import get_trainer 33 | 34 | # our internal file system; please comment out this line and change I/O to your own file system 35 | from utils.manifold_utils import MANIFOLD_BUCKET, MANIFOLD_PATH, pathmgr 36 | 37 | from utils.torch_utils import init_seed 38 | 39 | 40 | def main_ddp(rank, world_size, cfg): 41 | init_seed(cfg.seed) 42 | 43 | # set up distributed process groups 44 | os.environ["MASTER_ADDR"] = "localhost" 45 | os.environ["MASTER_PORT"] = "12356" 46 | 47 | torch.distributed.init_process_group( 48 | backend="nccl", rank=rank, world_size=world_size 49 | ) 50 | 51 | device = torch.device("cuda:%d" % rank) 52 | torch.cuda.set_device(device) 53 | print(f"Use GPU {rank} ({torch.cuda.get_device_name(rank)}) for training") 54 | 55 | # prepare data 56 | train_sets, valid_sets, train_sets_epoches = get_dataset(cfg.data) 57 | if rank == 0: 58 | print( 59 | "train sets: " 60 | + ", ".join( 61 | ["{} ({} samples)".format(ds.name, len(ds)) for ds in train_sets] 62 | ) 63 | ) 64 | print( 65 | "val sets: " 66 | + ", ".join( 67 | ["{} ({} samples)".format(ds.name, len(ds)) for ds in valid_sets] 68 | ) 69 | ) 70 | 71 | train_sets_epoches = [np.inf if e == -1 else e for e in train_sets_epoches] 72 | 73 | train_loaders, valid_loaders = [], [] 74 | for ds in train_sets: 75 | sampler = torch.utils.data.DistributedSampler( 76 | ds, 
num_replicas=world_size, rank=rank, shuffle=True, drop_last=True 77 | ) 78 | train_loader = torch.utils.data.DataLoader( 79 | ds, 80 | batch_size=cfg.train.batch_size // world_size, 81 | num_workers=cfg.train.workers // world_size, 82 | pin_memory=True, 83 | sampler=sampler, 84 | ) 85 | train_loaders.append(train_loader) 86 | 87 | if rank == 0: 88 | # prepare tensorboard 89 | writer = summary_writer(log_dir=cfg.save_root) 90 | 91 | # prepare validation dataset 92 | for ds in valid_sets: 93 | valid_loader = torch.utils.data.DataLoader( 94 | ds, 95 | batch_size=4, 96 | num_workers=4, 97 | pin_memory=True, 98 | shuffle=False, 99 | drop_last=False, 100 | ) 101 | valid_loaders.append(valid_loader) 102 | valid_size = sum([len(loader) for loader in valid_loaders]) 103 | if cfg.train.valid_size == 0: 104 | cfg.train.valid_size = valid_size 105 | cfg.train.valid_size = min(cfg.train.valid_size, valid_size) 106 | 107 | else: 108 | writer = None 109 | valid_loaders = [] 110 | 111 | # prepare model 112 | model = get_model(cfg.model).to(device) 113 | model = torch.nn.parallel.DistributedDataParallel( 114 | model, 115 | device_ids=[rank], 116 | output_device=rank, 117 | ) 118 | 119 | # prepare loss 120 | loss = get_loss(cfg.loss) 121 | 122 | # prepare training script 123 | trainer = get_trainer(cfg.trainer)( 124 | train_loaders, 125 | valid_loaders, 126 | model, 127 | loss, 128 | cfg.save_root, 129 | cfg.train, 130 | resume=cfg.resume, 131 | train_sets_epoches=train_sets_epoches, 132 | summary_writer=writer, 133 | rank=rank, 134 | world_size=world_size, 135 | ) 136 | 137 | trainer.train() 138 | 139 | torch.distributed.destroy_process_group() 140 | 141 | 142 | def main(args, run_id=None): 143 | 144 | # resuming 145 | if args.resume is not None: 146 | args.config = os.path.join( 147 | "manifold://", MANIFOLD_BUCKET, MANIFOLD_PATH, args.resume, "config.json" 148 | ) 149 | else: 150 | args.config = pkg_resources.resource_filename(__name__, args.config) 151 | 152 | # load config 153 | cfg = init_config(args.config) 154 | cfg.train.n_gpu = args.n_gpu 155 | 156 | # DEBUG options 157 | cfg.train.DEBUG = args.DEBUG 158 | if args.DEBUG: 159 | cfg.data.update( 160 | { 161 | "epoches_raw": 3, 162 | } 163 | ) 164 | cfg.train.update( 165 | { 166 | "batch_size": 4, 167 | "epoch_num": 5, 168 | "epoch_size": 20, 169 | "print_freq": 1, 170 | "record_freq": 1, 171 | "val_epoch_size": 2, 172 | "valid_size": 4, 173 | "save_iter": 2, 174 | } 175 | ) 176 | if "stage1" in cfg.train: 177 | cfg.train.stage1.update({"epoch": 5}) 178 | if "stage2" in cfg.train: 179 | cfg.train.stage2.update({"epoch": 5}) 180 | 181 | # pretrained model 182 | if args.model is not None: 183 | cfg.train.pretrained_model = args.model 184 | 185 | # init save_root: store files by curr_time 186 | if args.resume is not None: 187 | cfg.resume = True 188 | cfg.save_root = os.path.join( 189 | "manifold://", MANIFOLD_BUCKET, MANIFOLD_PATH, args.resume 190 | ) 191 | else: 192 | cfg.resume = False 193 | args.name = os.path.basename(args.config)[:-5] 194 | 195 | dirname = curr_time + "_" + args.name 196 | if run_id is not None: 197 | dirname = dirname + "_f" + str(run_id) 198 | if args.DEBUG: 199 | dirname = "_DEBUG_" + dirname 200 | 201 | cfg.save_root = os.path.join( 202 | "manifold://", 203 | MANIFOLD_BUCKET, 204 | MANIFOLD_PATH, 205 | args.exp_folder, 206 | dirname, 207 | ) 208 | 209 | ## for the manifold file system 210 | 211 | if not pathmgr.exists(cfg.save_root): 212 | pathmgr.mkdirs(cfg.save_root) 213 | 214 | pathmgr.copy_from_local( 215 | args.config, 
os.path.join(cfg.save_root, "config.json") 216 | ) 217 | 218 | if "base_configs" in cfg: 219 | pathmgr.copy_from_local( 220 | os.path.join(os.path.dirname(args.config), cfg.base_configs), 221 | os.path.join(cfg.save_root, cfg.base_configs), 222 | ) 223 | 224 | """ 225 | ## for the linux file system 226 | os.makedirs(cfg.save_root) 227 | os.system( 228 | "cp {} {}".format(args.config, os.path.join(cfg.save_root, "config.json")) 229 | ) 230 | if "base_configs" in cfg: 231 | os.system( 232 | "cp {} {}".format( 233 | os.path.join(os.path.dirname(args.config), cfg.base_configs), 234 | os.path.join(cfg.save_root, cfg.base_configs), 235 | ) 236 | ) 237 | """ 238 | 239 | print("=> will save everything to {}".format(cfg.save_root)) 240 | 241 | # show configurations 242 | cfg_str = pprint.pformat(cfg) 243 | print("=> configurations \n " + cfg_str) 244 | 245 | # spawn ddp 246 | world_size = args.n_gpu 247 | torch.multiprocessing.spawn( 248 | main_ddp, 249 | args=(world_size, cfg), 250 | nprocs=world_size, 251 | ) 252 | 253 | print("Completed!") 254 | return 255 | 256 | 257 | def invoke_main() -> None: 258 | parser = argparse.ArgumentParser() 259 | parser.add_argument("-c", "--config", default="configs/base_kitti.json") 260 | parser.add_argument("-m", "--model", default=None) 261 | parser.add_argument("--exp_folder", default="other") 262 | parser.add_argument("-n", "--name", default=None) 263 | parser.add_argument("-r", "--resume", default=None) 264 | parser.add_argument("--n_gpu", type=int, default=2) 265 | parser.add_argument("--DEBUG", action="store_true") 266 | args = parser.parse_args() 267 | 268 | main(args) 269 | 270 | 271 | if __name__ == "__main__": 272 | invoke_main() # pragma: no cover 273 | -------------------------------------------------------------------------------- /trainer/base_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | """ 4 | 5 | import os 6 | 7 | from abc import abstractmethod 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from utils.manifold_utils import pathmgr 13 | from utils.torch_utils import ( 14 | AdamW, 15 | bias_parameters, 16 | load_checkpoint, 17 | other_parameters, 18 | save_checkpoint, 19 | weight_parameters, 20 | ) 21 | 22 | from .object_cache import ObjectCache 23 | 24 | 25 | class BaseTrainer: 26 | """ 27 | Base class for all trainers 28 | """ 29 | 30 | def __init__( 31 | self, 32 | train_loaders, 33 | valid_loaders, 34 | model, 35 | loss_func, 36 | save_root, 37 | config, 38 | resume=False, 39 | train_sets_epoches=None, 40 | summary_writer=None, 41 | rank=0, 42 | world_size=1, 43 | ): 44 | self.cfg = config 45 | self.save_root = save_root 46 | self.summary_writer = summary_writer 47 | self.train_loaders, self.valid_loaders = train_loaders, valid_loaders 48 | self.train_sets_epoches = train_sets_epoches 49 | 50 | self.rank, self.world_size = rank, world_size 51 | self.device = model.device 52 | self.loss_func = loss_func 53 | 54 | if resume: # load all states 55 | self._load_resume_ckpt(model) 56 | else: 57 | self.model = self._init_model(model) 58 | self.i_epoch, self.i_iter = 0, 0 59 | self.i_train_set = 0 60 | while ( 61 | self.train_sets_epoches[self.i_train_set] == 0 62 | ): # skip the datasets of 0 epoches 63 | self.i_train_set += 1 64 | 65 | self.optimizer = self._create_optimizer() 66 | self.scheduler = self._create_scheduler( 67 | self.optimizer, self.train_sets_epoches[self.i_train_set] 68 | ) 69 | 70 | self.best_error = np.inf 71 | 72 | @abstractmethod 73 | def _run_one_epoch(self): 74 | ... 75 | 76 | @abstractmethod 77 | def _validate_with_gt(self): 78 | ... 79 | 80 | def log(self, s): 81 | if self.rank == 0: 82 | print(s) 83 | 84 | def set_up_obj_cache(self, cache_size=500): 85 | self.obj_cache = ObjectCache(cache_size=cache_size) 86 | 87 | def train(self): 88 | 89 | if ( 90 | self.cfg.pretrained_model is not None 91 | ): # if using a pretrained model, evaluate that first to compare 92 | if self.rank == 0: 93 | self._validate_with_gt() 94 | torch.distributed.barrier() 95 | 96 | for _epoch in range(self.i_epoch, self.cfg.epoch_num): 97 | self._run_one_epoch() 98 | 99 | if self.i_epoch >= sum(self.train_sets_epoches[: (self.i_train_set + 1)]): 100 | self.i_train_set += 1 101 | self.optimizer = ( 102 | self._create_optimizer() 103 | ) # reset the states of optimizer as well 104 | self.scheduler = self._create_scheduler( 105 | self.optimizer, self.train_sets_epoches[self.i_train_set] 106 | ) 107 | 108 | if self.rank == 0: 109 | if self.i_epoch % self.cfg.val_epoch_size == 0: 110 | self._validate_with_gt() 111 | self.log(" * Epoch {} validation complete.".format(self.i_epoch)) 112 | 113 | torch.distributed.barrier() 114 | 115 | # def zero_grad(self): 116 | # # One Pytorch tutorial suggests clearing the gradients this way for faster speed 117 | # # https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html 118 | # for param in self.model.parameters(): 119 | # param.grad = None 120 | 121 | def _init_model(self, model): 122 | model = model.to(self.device) 123 | 124 | if self.cfg.pretrained_model: 125 | self.log( 126 | "=> using pre-trained weights {}.".format(self.cfg.pretrained_model) 127 | ) 128 | epoch, weights = load_checkpoint(self.cfg.pretrained_model) 129 | model.module.load_state_dict(weights) 130 | else: 131 | self.log("=> Train from scratch.") 132 | model.module.init_weights() 133 | 134 | self.log("number of parameters: 
{}".format(self.count_parameters(model))) 135 | self.log( 136 | "gpu memory allocated (model parameters only): {} Bytes".format( 137 | torch.cuda.memory_allocated() 138 | ) 139 | ) 140 | return model 141 | 142 | def _create_optimizer(self): 143 | self.log("=> setting {} optimizer".format(self.cfg.optim)) 144 | param_groups = [ 145 | { 146 | "params": bias_parameters(self.model.module), 147 | "weight_decay": self.cfg.bias_decay, 148 | }, 149 | { 150 | "params": weight_parameters(self.model.module), 151 | "weight_decay": self.cfg.weight_decay, 152 | }, 153 | {"params": other_parameters(self.model.module), "weight_decay": 0}, 154 | ] 155 | 156 | if self.cfg.optim == "adamw": 157 | optimizer = AdamW( 158 | param_groups, self.cfg.lr, betas=(self.cfg.momentum, self.cfg.beta) 159 | ) 160 | elif self.cfg.optim == "adam": 161 | optimizer = torch.optim.Adam( 162 | param_groups, 163 | self.cfg.lr, 164 | betas=(self.cfg.momentum, self.cfg.beta), 165 | eps=1e-7, 166 | ) 167 | else: 168 | raise NotImplementedError(self.cfg.optim) 169 | 170 | return optimizer 171 | 172 | def _create_scheduler(self, optimizer, epoches=np.inf): 173 | 174 | if ( 175 | self.i_train_set < len(self.train_sets_epoches) - 1 176 | ): # try only the last loader uses onecyclelr 177 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=1) 178 | return scheduler 179 | 180 | if "lr_scheduler" in self.cfg.keys(): 181 | self.log("=> setting {} scheduler".format(self.cfg.lr_scheduler.module)) 182 | 183 | params = self.cfg.lr_scheduler.params 184 | 185 | if self.cfg.lr_scheduler.module == "OneCycleLR": 186 | params["epochs"] = min(epoches, self.cfg.epoch_num - self.i_epoch) 187 | params["steps_per_epoch"] = self.cfg.epoch_size 188 | 189 | scheduler = getattr(torch.optim.lr_scheduler, self.cfg.lr_scheduler.module)( 190 | optimizer, **params 191 | ) 192 | else: # a dummy scheduler by default 193 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=1) 194 | 195 | return scheduler 196 | 197 | def _load_resume_ckpt(self, model): 198 | self.log("==> resuming") 199 | 200 | with pathmgr.open( 201 | os.path.join(self.save_root, "model_ckpt.pth.tar"), "rb" 202 | ) as f: 203 | ckpt_dict = torch.load(f) 204 | 205 | if "iter" not in ckpt_dict.keys(): 206 | ckpt_dict["iter"] = ckpt_dict["epoch"] * self.cfg.epoch_size 207 | if "best_error" not in ckpt_dict.keys(): 208 | ckpt_dict["best_error"] = np.inf 209 | self.i_epoch, self.i_iter, self.best_error = ( 210 | ckpt_dict["epoch"], 211 | ckpt_dict["iter"], 212 | ckpt_dict["best_error"], 213 | ) 214 | self.i_train_set = np.where(self.i_epoch < np.cumsum(self.train_sets_epoches))[ 215 | 0 216 | ][0] 217 | 218 | model = model.to(self.device) 219 | model.module.load_state_dict(ckpt_dict["state_dict"]) 220 | # self.model = torch.nn.DataParallel(model, device_ids=self.device_ids) 221 | 222 | self.optimizer = self._create_optimizer() 223 | self.scheduler = self._create_scheduler( 224 | self.optimizer, self.train_sets_epoches[self.i_train_set] 225 | ) 226 | 227 | if "optimizer_dict" in ckpt_dict.keys(): 228 | self.optimizer.load_state_dict(ckpt_dict["optimizer_dict"]) 229 | if "scheduler_dict" in ckpt_dict.keys(): 230 | self.scheduler.load_state_dict(ckpt_dict["scheduler_dict"]) 231 | 232 | return 233 | 234 | # def _prepare_device(self, n_gpu_use): 235 | # """ 236 | # setup GPU device if available, move model into configured device 237 | # """ 238 | # n_gpu = torch.cuda.device_count() 239 | # if n_gpu_use > 0 and n_gpu == 0: 240 | # self.log( 241 | # "Warning: There's no GPU 
available on this machine," 242 | # "training will be performed on CPU." 243 | # ) 244 | # n_gpu_use = 0 245 | # if n_gpu_use > n_gpu: 246 | # self.log( 247 | # "Warning: The number of GPU's configured to use is {}, " 248 | # "but only {} are available.".format(n_gpu_use, n_gpu) 249 | # ) 250 | # n_gpu_use = n_gpu 251 | # device = torch.device("cuda:0" if n_gpu_use > 0 else "cpu") 252 | # list_ids = list(range(n_gpu_use)) 253 | # self.log("=> gpu in use: {} gpu(s)".format(n_gpu_use)) 254 | # self.log( 255 | # "device names: {}".format([torch.cuda.get_device_name(i) for i in list_ids]) 256 | # ) 257 | # return device, list_ids 258 | 259 | def count_parameters(self, model): 260 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 261 | 262 | def save_model(self, name, save_with_runtime=True): 263 | if save_with_runtime: 264 | models = { 265 | "epoch": self.i_epoch, 266 | "iter": self.i_iter, 267 | "best_error": self.best_error, 268 | "state_dict": self.model.module.state_dict(), 269 | "optimizer_dict": self.optimizer.state_dict(), 270 | "scheduler_dict": self.scheduler.state_dict(), 271 | } 272 | else: 273 | models = {"state_dict": self.model.module.state_dict()} 274 | 275 | save_checkpoint(self.save_root, models, name, is_best=False) 276 | -------------------------------------------------------------------------------- /trainer/get_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | """ 4 | 5 | from . import kitti_trainer_ar, sintel_trainer_ar 6 | 7 | 8 | def get_trainer(name): 9 | if name == "KITTI_AR": 10 | TrainFramework = kitti_trainer_ar.TrainFramework 11 | elif name == "SINTEL_AR": 12 | TrainFramework = sintel_trainer_ar.TrainFramework 13 | else: 14 | raise NotImplementedError(name) 15 | 16 | return TrainFramework 17 | -------------------------------------------------------------------------------- /trainer/object_cache.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | # import torch.nn.functional as F 5 | 6 | 7 | class ObjectCache: 8 | def __init__(self, cache_size=500): 9 | self.cache_size = cache_size 10 | self._obj_mask_cache = None 11 | self._img_cache = None 12 | self._motion_cache = None 13 | self.count = 0 14 | 15 | # We initialize the cache when we receive the first push sample so that it adapts to the current img size automatically 16 | def init_cache(self, img_size): 17 | self._obj_mask_cache = torch.zeros( 18 | (self.cache_size, 1, *img_size), dtype=torch.float32 19 | ) 20 | self._img_cache = torch.zeros( 21 | (self.cache_size, 3, *img_size), dtype=torch.float32 22 | ) 23 | self._motion_cache = torch.zeros((self.cache_size, 2), dtype=torch.float32) 24 | return 25 | 26 | def pop(self, B=8, with_aug=True): # we do not remove objects after popping 27 | if self.count < self.cache_size: # do not use it before it is full 28 | return None 29 | 30 | idx = np.random.choice(self.cache_size, B, replace=False) 31 | obj_mask = self._obj_mask_cache[idx] 32 | img = self._img_cache[idx] 33 | motion = self._motion_cache[idx] 34 | 35 | if with_aug: 36 | rand_scale = ( 37 | torch.rand(B) * 0.7 + 0.8 38 | ) # randomly rescale motion by 0.8-1.5 times 39 | rand_scale *= (-1) ** ( 40 | torch.rand(B) > 0.5 41 | ).float() # randomly reverse motion 42 | motion = motion * rand_scale[:, None] 43 | 44 | flip_flag = torch.rand(B) > 0.5 # randomly horizontally flip obj mask 45 | img[flip_flag] = 
img[flip_flag].flip(dims=[3]) 46 | obj_mask[flip_flag] = obj_mask[flip_flag].flip(dims=[3]) 47 | motion[flip_flag, 0] *= -1 48 | 49 | return obj_mask, img, motion 50 | 51 | def push(self, obj_mask, img, motion): 52 | """ 53 | obj_mask: [B, 1, H, W] 54 | img: [B, 3, H, W] 55 | motion: [B, 2] 56 | """ 57 | 58 | if self._obj_mask_cache is None: 59 | self.init_cache(img_size=img.shape[-2:]) 60 | 61 | B = obj_mask.shape[0] 62 | 63 | if self.count <= self.cache_size - B: # many spaces 64 | self._obj_mask_cache[self.count : (self.count + B)] = obj_mask 65 | self._img_cache[self.count : (self.count + B)] = img 66 | self._motion_cache[self.count : (self.count + B)] = motion 67 | self.count += B 68 | return 69 | 70 | elif self.count < self.cache_size: # partial space 71 | space = self.cache_size - self.count 72 | self._obj_mask_cache[self.count :] = obj_mask[:space] 73 | self._img_cache[self.count :] = img[:space] 74 | self._motion_cache[self.count :] = motion[:space] 75 | 76 | overwrite_idx = np.random.choice(self.count, B - space, replace=False) 77 | self._obj_mask_cache[overwrite_idx] = obj_mask[space:] 78 | self._img_cache[overwrite_idx] = img[space:] 79 | self._motion_cache[overwrite_idx] = motion[space:] 80 | self.count += space 81 | return 82 | 83 | else: # no spaces; random overwrite 84 | overwrite_idx = np.random.choice(self.cache_size, B, replace=False) 85 | self._obj_mask_cache[overwrite_idx] = obj_mask 86 | self._img_cache[overwrite_idx] = img 87 | self._motion_cache[overwrite_idx] = motion 88 | return 89 | -------------------------------------------------------------------------------- /transforms/ar_transforms/ap_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import ImageFilter 4 | from torchvision import transforms as tf 5 | 6 | 7 | def get_ap_transforms(cfg): 8 | transforms = [] 9 | if cfg.cj: 10 | transforms.append( 11 | ColorJitter( 12 | brightness=cfg.cj_bri, 13 | contrast=cfg.cj_con, 14 | saturation=cfg.cj_sat, 15 | hue=cfg.cj_hue, 16 | ) 17 | ) 18 | transforms.append(ToPILImage()) 19 | if cfg.gblur: 20 | transforms.append(RandomGaussianBlur(0.5, 3)) 21 | transforms.append(ToTensor()) 22 | if cfg.gamma: 23 | transforms.append(RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True)) 24 | return tf.Compose(transforms) 25 | 26 | 27 | # from https://github.com/visinf/irr/blob/master/datasets/transforms.py 28 | class ToPILImage(tf.ToPILImage): 29 | def __call__(self, imgs): 30 | return [super(ToPILImage, self).__call__(im) for im in imgs] 31 | 32 | 33 | class ColorJitter(tf.ColorJitter): 34 | def __call__(self, imgs): 35 | _, h, w = imgs[0].shape 36 | new_big_img = self.forward(torch.concatenate(imgs, dim=1)) 37 | return list(torch.split(new_big_img, h, dim=1)) 38 | # return [self.foward(im) for im in imgs] 39 | 40 | 41 | class ToTensor(tf.ToTensor): 42 | def __call__(self, imgs): 43 | return [super(ToTensor, self).__call__(im) for im in imgs] 44 | 45 | 46 | class RandomGamma: 47 | def __init__(self, min_gamma=0.7, max_gamma=1.5, clip_image=False): 48 | self._min_gamma = min_gamma 49 | self._max_gamma = max_gamma 50 | self._clip_image = clip_image 51 | 52 | @staticmethod 53 | def get_params(min_gamma, max_gamma): 54 | return np.random.uniform(min_gamma, max_gamma) 55 | 56 | @staticmethod 57 | def adjust_gamma(image, gamma, clip_image): 58 | adjusted = torch.pow(image, gamma) 59 | if clip_image: 60 | adjusted.clamp_(0.0, 1.0) 61 | return adjusted 62 | 63 | def __call__(self, imgs): 
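# note: a single gamma value is sampled per call and applied to every image in the list below, so both frames receive the same photometric change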
64 | gamma = self.get_params(self._min_gamma, self._max_gamma) 65 | return [self.adjust_gamma(im, gamma, self._clip_image) for im in imgs] 66 | 67 | 68 | class RandomGaussianBlur: 69 | def __init__(self, p, max_k_sz): 70 | self.p = p 71 | self.max_k_sz = max_k_sz 72 | 73 | def __call__(self, imgs): 74 | if np.random.random() < self.p: 75 | radius = np.random.uniform(0, self.max_k_sz) 76 | imgs = [im.filter(ImageFilter.GaussianBlur(radius)) for im in imgs] 77 | return imgs 78 | -------------------------------------------------------------------------------- /transforms/ar_transforms/interpolation.py: -------------------------------------------------------------------------------- 1 | ## Portions of Code from, copyright 2018 Jochen Gast 2 | 3 | from __future__ import absolute_import, division, print_function 4 | 5 | import torch 6 | import torch.nn.functional as tf 7 | from torch import nn 8 | 9 | 10 | def _bchw2bhwc(tensor): 11 | return tensor.transpose(1, 2).transpose(2, 3) 12 | 13 | 14 | def _bhwc2bchw(tensor): 15 | return tensor.transpose(2, 3).transpose(1, 2) 16 | 17 | 18 | class Meshgrid(nn.Module): 19 | def __init__(self): 20 | super(Meshgrid, self).__init__() 21 | self.width = 0 22 | self.height = 0 23 | self.register_buffer("xx", torch.zeros(1, 1)) 24 | self.register_buffer("yy", torch.zeros(1, 1)) 25 | self.register_buffer("rangex", torch.zeros(1, 1)) 26 | self.register_buffer("rangey", torch.zeros(1, 1)) 27 | 28 | def _compute_meshgrid(self, width, height): 29 | torch.arange(0, width, out=self.rangex) 30 | torch.arange(0, height, out=self.rangey) 31 | self.xx = self.rangex.repeat(height, 1).contiguous() 32 | self.yy = self.rangey.repeat(width, 1).t().contiguous() 33 | 34 | def forward(self, width, height): 35 | if self.width != width or self.height != height: 36 | self._compute_meshgrid(width=width, height=height) 37 | self.width = width 38 | self.height = height 39 | return self.xx, self.yy 40 | 41 | 42 | class BatchSub2Ind(nn.Module): 43 | def __init__(self): 44 | super(BatchSub2Ind, self).__init__() 45 | self.register_buffer("_offsets", torch.LongTensor()) 46 | 47 | def forward(self, shape, row_sub, col_sub, out=None): 48 | batch_size = row_sub.size(0) 49 | height, width = shape 50 | ind = row_sub * width + col_sub 51 | torch.arange(batch_size, out=self._offsets) 52 | self._offsets *= height * width 53 | 54 | if out is None: 55 | return torch.add(ind, self._offsets.view(-1, 1, 1)) 56 | else: 57 | torch.add(ind, self._offsets.view(-1, 1, 1), out=out) 58 | 59 | 60 | class Interp2(nn.Module): 61 | def __init__(self, clamp=False): 62 | super(Interp2, self).__init__() 63 | self._clamp = clamp 64 | self._batch_sub2ind = BatchSub2Ind() 65 | self.register_buffer("_x0", torch.LongTensor()) 66 | self.register_buffer("_x1", torch.LongTensor()) 67 | self.register_buffer("_y0", torch.LongTensor()) 68 | self.register_buffer("_y1", torch.LongTensor()) 69 | self.register_buffer("_i00", torch.LongTensor()) 70 | self.register_buffer("_i01", torch.LongTensor()) 71 | self.register_buffer("_i10", torch.LongTensor()) 72 | self.register_buffer("_i11", torch.LongTensor()) 73 | self.register_buffer("_v00", torch.FloatTensor()) 74 | self.register_buffer("_v01", torch.FloatTensor()) 75 | self.register_buffer("_v10", torch.FloatTensor()) 76 | self.register_buffer("_v11", torch.FloatTensor()) 77 | self.register_buffer("_x", torch.FloatTensor()) 78 | self.register_buffer("_y", torch.FloatTensor()) 79 | 80 | def forward(self, v, xq, yq): 81 | batch_size, channels, height, width = v.size() 82 | 83 | # clamp 
if wanted 84 | if self._clamp: 85 | xq.clamp_(0, width - 1) 86 | yq.clamp_(0, height - 1) 87 | 88 | # ------------------------------------------------------------------ 89 | # Find neighbors 90 | # 91 | # x0 = torch.floor(xq).long(), x0.clamp_(0, width - 1) 92 | # x1 = x0 + 1, x1.clamp_(0, width - 1) 93 | # y0 = torch.floor(yq).long(), y0.clamp_(0, height - 1) 94 | # y1 = y0 + 1, y1.clamp_(0, height - 1) 95 | # 96 | # ------------------------------------------------------------------ 97 | self._x0 = torch.floor(xq).long().clamp(0, width - 1) 98 | self._y0 = torch.floor(yq).long().clamp(0, height - 1) 99 | 100 | self._x1 = torch.add(self._x0, 1).clamp(0, width - 1) 101 | self._y1 = torch.add(self._y0, 1).clamp(0, height - 1) 102 | 103 | # batch_sub2ind 104 | self._batch_sub2ind([height, width], self._y0, self._x0, out=self._i00) 105 | self._batch_sub2ind([height, width], self._y0, self._x1, out=self._i01) 106 | self._batch_sub2ind([height, width], self._y1, self._x0, out=self._i10) 107 | self._batch_sub2ind([height, width], self._y1, self._x1, out=self._i11) 108 | 109 | # reshape 110 | v_flat = _bchw2bhwc(v).contiguous().view(-1, channels) 111 | torch.index_select(v_flat, dim=0, index=self._i00.view(-1), out=self._v00) 112 | torch.index_select(v_flat, dim=0, index=self._i01.view(-1), out=self._v01) 113 | torch.index_select(v_flat, dim=0, index=self._i10.view(-1), out=self._v10) 114 | torch.index_select(v_flat, dim=0, index=self._i11.view(-1), out=self._v11) 115 | 116 | # local_coords 117 | torch.add(xq, -self._x0.float(), out=self._x) 118 | torch.add(yq, -self._y0.float(), out=self._y) 119 | 120 | # weights 121 | w00 = torch.unsqueeze((1.0 - self._y) * (1.0 - self._x), dim=1) 122 | w01 = torch.unsqueeze((1.0 - self._y) * self._x, dim=1) 123 | w10 = torch.unsqueeze(self._y * (1.0 - self._x), dim=1) 124 | w11 = torch.unsqueeze(self._y * self._x, dim=1) 125 | 126 | def _reshape(u): 127 | return _bhwc2bchw(u.view(batch_size, height, width, channels)) 128 | 129 | # values 130 | values = ( 131 | _reshape(self._v00) * w00 132 | + _reshape(self._v01) * w01 133 | + _reshape(self._v10) * w10 134 | + _reshape(self._v11) * w11 135 | ) 136 | 137 | if self._clamp: 138 | return values 139 | else: 140 | # find_invalid 141 | invalid = ( 142 | ((xq < 0) | (xq >= width) | (yq < 0) | (yq >= height)) 143 | .unsqueeze(dim=1) 144 | .float() 145 | ) 146 | # maskout invalid 147 | transformed = invalid * torch.zeros_like(values) + (1.0 - invalid) * values 148 | 149 | return transformed 150 | 151 | 152 | def resize2D(inputs, size_targets, mode="bilinear"): 153 | size_inputs = [inputs.size(2), inputs.size(3)] 154 | 155 | if all([size_inputs == size_targets]): 156 | return inputs # nothing to do 157 | elif any([size_targets < size_inputs]): 158 | resized = tf.adaptive_avg_pool2d(inputs, size_targets) # downscaling 159 | else: 160 | resized = tf.upsample(inputs, size=size_targets, mode=mode) # upsampling 161 | 162 | # correct scaling 163 | return resized 164 | 165 | 166 | def resize2D_as(inputs, output_as, mode="bilinear"): 167 | size_targets = [output_as.size(2), output_as.size(3)] 168 | return resize2D(inputs, size_targets, mode=mode) 169 | -------------------------------------------------------------------------------- /transforms/ar_transforms/oc_transforms.py: -------------------------------------------------------------------------------- 1 | # from skimage.color import rgb2yuv 2 | # import cv2 3 | import numpy as np 4 | import torch 5 | 6 | # from fast_slic.avx2 import SlicAvx2 as Slic 7 | # from 
skimage.segmentation import slic as sk_slic 8 | 9 | from utils.warp_utils import flow_warp 10 | 11 | 12 | # def run_slic_pt(img_batch, n_seg=200, compact=10, rd_select=(8, 16), fast=True): # Nx1xHxW 13 | # """ 14 | 15 | # :param img: Nx3xHxW 0~1 float32 16 | # :param n_seg: 17 | # :param compact: 18 | # :return: Nx1xHxW float32 19 | # """ 20 | # B = img_batch.size(0) 21 | # dtype = img_batch.type() 22 | # img_batch = np.split( 23 | # img_batch.detach().cpu().numpy().transpose([0, 2, 3, 1]), B, axis=0) 24 | # out = [] 25 | # if fast: 26 | # fast_slic = Slic(num_components=n_seg, compactness=compact, min_size_factor=0.8) 27 | # for img in img_batch: 28 | # img = np.copy((img * 255).squeeze(0).astype(np.uint8), order='C') 29 | # if fast: 30 | # img = cv2.cvtColor(img, cv2.COLOR_RGB2LAB) 31 | # seg = fast_slic.iterate(img) 32 | # else: 33 | # seg = sk_slic(img, n_segments=200, compactness=10) 34 | 35 | # if rd_select is not None: 36 | # n_select = np.random.randint(rd_select[0], rd_select[1]) 37 | # select_list = np.random.choice(range(0, np.max(seg) + 1), n_select, 38 | # replace=False) 39 | 40 | # seg = np.bitwise_or.reduce([seg == seg_id for seg_id in select_list]) 41 | # out.append(seg) 42 | # x_out = torch.tensor(np.stack(out)).type(dtype).unsqueeze(1) 43 | # return x_out 44 | 45 | 46 | def random_crop(img, full_segs, flow, occ_mask, crop_sz): 47 | """ 48 | :param img: Nx6xHxW 49 | :param full_segs: 50 | :param flow: Nx2xHxW 51 | :param occ_mask: Nx1xHxW 52 | :param crop_sz: (crop_height, crop_width) 53 | :return: cropped img, full_segs, flow, occ_mask 54 | """ 55 | _, _, h, w = img.size() 56 | c_h, c_w = crop_sz 57 | 58 | if c_h == h and c_w == w: 59 | return img, full_segs, flow, occ_mask 60 | 61 | x1 = np.random.randint(0, w - c_w) 62 | y1 = np.random.randint(0, h - c_h) 63 | img = img[:, :, y1 : y1 + c_h, x1 : x1 + c_w] 64 | full_segs = full_segs[:, :, y1 : y1 + c_h, x1 : x1 + c_w] 65 | flow = flow[:, :, y1 : y1 + c_h, x1 : x1 + c_w] 66 | occ_mask = occ_mask[:, :, y1 : y1 + c_h, x1 : x1 + c_w] 67 | 68 | return img, full_segs, flow, occ_mask 69 | 70 | 71 | # def semantic_connected_components( 72 | # semseg, class_indices, width_range=(30, 200), height_range=(30, 100) 73 | # ): 74 | # """ 75 | # Input: 76 | # semsegs: Onehot semantic segmentations of size [c, H, W] 77 | # class_indices: A list of the indices of the classes of interest. 78 | # For example, [car_idx] or [sign_idx, pole_idx, traffic_light_idx] 79 | # width_range, height_range: The width and height ranges for the objects of interest. 80 | # Output: 81 | # list of masks for cars in the size range 82 | # """ 83 | 84 | # curr_sem = semseg[class_indices].sum(dim=0) 85 | # curr_sem = (curr_sem[:, :, None] * 255).numpy().astype(np.uint8) 86 | # num_labels, labels = cv2.connectedComponents(curr_sem) 87 | 88 | # sem_list = [] 89 | # # 0 is background, so ignore them. 
90 | # for i in range(1, num_labels): 91 | # curr_obj = labels == i 92 | # hs, ws = np.where(curr_obj) 93 | # h_len, w_len = np.max(hs) - np.min(hs), np.max(ws) - np.min(ws) 94 | # if (height_range[0] < h_len < height_range[1]) and ( 95 | # width_range[0] < w_len < width_range[1] 96 | # ): 97 | # if ( 98 | # curr_obj.sum() / ((h_len + 1) * (w_len + 1)) > 0.6 99 | # ): # filter some wrong car estimates or largely occluded cars 100 | # sem_list.append(curr_obj) 101 | 102 | # return sem_list 103 | 104 | 105 | # def find_semantic_group(semseg, class_indices, win_width=200): 106 | # curr_sem = semseg[class_indices].sum(dim=0).numpy() 107 | # freq = curr_sem.mean(axis=0) # 1d frequency 108 | # freq_win = np.convolve( 109 | # freq, np.ones(win_width) / win_width, mode="valid" 110 | # ) # find the most frequent window 111 | 112 | # if max(freq_win) > 0.1: 113 | # # optimal window: [left, right) 114 | # left = np.argmax(freq_win) 115 | # right = left + win_width 116 | # curr_sem[:, :left] = 0 117 | # curr_sem[:, right:] = 0 118 | # return curr_sem 119 | 120 | # else: 121 | # return None 122 | 123 | 124 | def add_fake_object(input_dict): 125 | 126 | # prepare input 127 | img1_ot = input_dict["img1_tgt"] 128 | img2_ot = input_dict["img2_tgt"] 129 | full_seg1_ot = input_dict["full_seg1_st"] 130 | full_seg2_ot = input_dict["full_seg2_st"] 131 | flow_ot = input_dict["flow_tgt"] 132 | noc_ot = input_dict["noc_tgt"] 133 | 134 | img = input_dict["img_src"] 135 | obj_mask = input_dict["obj_mask"] 136 | motion = input_dict["motion"][:, :, None, None] 137 | 138 | b, _, h, w = img1_ot.shape 139 | N1 = full_seg1_ot.max() 140 | N2 = full_seg2_ot.max() 141 | 142 | # add object to frame 1 143 | img1_ot = obj_mask * img + (1 - obj_mask) * img1_ot 144 | full_seg1_ot = obj_mask * (N1 + 1) + (1 - obj_mask) * full_seg1_ot 145 | 146 | # add object to frame 2 147 | new_obj_mask = flow_warp(obj_mask, -motion.repeat(1, 1, h, w), pad="zeros") 148 | new_img = flow_warp(img, -motion.repeat(1, 1, h, w), pad="border") 149 | img2_ot = new_obj_mask * new_img + (1 - new_obj_mask) * img2_ot 150 | full_seg2_ot = new_obj_mask * (N2 + 1) + (1 - new_obj_mask) * full_seg2_ot 151 | 152 | # change flow 153 | flow_ot = obj_mask * motion + (1 - obj_mask) * flow_ot 154 | noc_ot = torch.max(noc_ot, obj_mask) # where we are confident about flow_ot 155 | 156 | return img1_ot, img2_ot, full_seg1_ot, full_seg2_ot, flow_ot, noc_ot, new_obj_mask 157 | -------------------------------------------------------------------------------- /transforms/ar_transforms/sp_transforms.py: -------------------------------------------------------------------------------- 1 | # Part of the code from https://github.com/visinf/irr/blob/master/augmentations.py 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from transforms.ar_transforms.interpolation import Interp2, Meshgrid 7 | 8 | 9 | def denormalize_coords(xx, yy, width, height): 10 | """scale indices from [-1, 1] to [0, width/height]""" 11 | xx = 0.5 * (width - 1.0) * (xx.float() + 1.0) 12 | yy = 0.5 * (height - 1.0) * (yy.float() + 1.0) 13 | return xx, yy 14 | 15 | 16 | def normalize_coords(xx, yy, width, height): 17 | """scale indices from [0, width/height] to [-1, 1]""" 18 | xx = (2.0 / (width - 1.0)) * xx.float() - 1.0 19 | yy = (2.0 / (height - 1.0)) * yy.float() - 1.0 20 | return xx, yy 21 | 22 | 23 | def apply_transform_to_params(theta0, theta_transform): 24 | a1 = theta0[:, 0] 25 | a2 = theta0[:, 1] 26 | a3 = theta0[:, 2] 27 | a4 = theta0[:, 3] 28 | a5 = theta0[:, 4] 29 | a6 = 
theta0[:, 5] 30 | # 31 | b1 = theta_transform[:, 0] 32 | b2 = theta_transform[:, 1] 33 | b3 = theta_transform[:, 2] 34 | b4 = theta_transform[:, 3] 35 | b5 = theta_transform[:, 4] 36 | b6 = theta_transform[:, 5] 37 | # 38 | c1 = a1 * b1 + a4 * b2 39 | c2 = a2 * b1 + a5 * b2 40 | c3 = b3 + a3 * b1 + a6 * b2 41 | c4 = a1 * b4 + a4 * b5 42 | c5 = a2 * b4 + a5 * b5 43 | c6 = b6 + a3 * b4 + a6 * b5 44 | # 45 | new_theta = torch.stack([c1, c2, c3, c4, c5, c6], dim=1) 46 | return new_theta 47 | 48 | 49 | class _IdentityParams(nn.Module): 50 | def __init__(self): 51 | super(_IdentityParams, self).__init__() 52 | self._batch_size = 0 53 | self.register_buffer("_o", torch.FloatTensor()) 54 | self.register_buffer("_i", torch.FloatTensor()) 55 | 56 | def _update(self, batch_size): 57 | torch.zeros([batch_size, 1], out=self._o) 58 | torch.ones([batch_size, 1], out=self._i) 59 | return torch.cat([self._i, self._o, self._o, self._o, self._i, self._o], dim=1) 60 | 61 | def forward(self, batch_size): 62 | if self._batch_size != batch_size: 63 | self._identity_params = self._update(batch_size) 64 | self._batch_size = batch_size 65 | return self._identity_params 66 | 67 | 68 | class RandomMirror(nn.Module): 69 | def __init__(self, vertical=True, p=0.5): 70 | super(RandomMirror, self).__init__() 71 | self._batch_size = 0 72 | self._p = p 73 | self._vertical = vertical 74 | self.register_buffer("_mirror_probs", torch.FloatTensor()) 75 | 76 | def update_probs(self, batch_size): 77 | torch.ones([batch_size, 1], out=self._mirror_probs) 78 | self._mirror_probs *= self._p 79 | 80 | def forward(self, theta_list): 81 | batch_size = theta_list[0].size(0) 82 | if batch_size != self._batch_size: 83 | self.update_probs(batch_size) 84 | self._batch_size = batch_size 85 | 86 | # apply random sign to a1 a2 a3 (these are the guys responsible for x) 87 | sign = torch.sign(2.0 * torch.bernoulli(self._mirror_probs) - 1.0) 88 | i = torch.ones_like(sign) 89 | horizontal_mirror = torch.cat([sign, sign, sign, i, i, i], dim=1) 90 | theta_list = [theta * horizontal_mirror for theta in theta_list] 91 | 92 | # apply random sign to a4 a5 a6 (these are the guys responsible for y) 93 | if self._vertical: 94 | sign = torch.sign(2.0 * torch.bernoulli(self._mirror_probs) - 1.0) 95 | vertical_mirror = torch.cat([i, i, i, sign, sign, sign], dim=1) 96 | theta_list = [theta * vertical_mirror for theta in theta_list] 97 | 98 | return theta_list 99 | 100 | 101 | class RandomAffineFlow(nn.Module): 102 | def __init__(self, cfg, addnoise=True): 103 | super(RandomAffineFlow, self).__init__() 104 | self.cfg = cfg 105 | self._interp2 = Interp2(clamp=False) 106 | self._flow_interp2 = Interp2(clamp=False) 107 | self._meshgrid = Meshgrid() 108 | self._identity = _IdentityParams() 109 | self._random_mirror = ( 110 | RandomMirror(cfg.vflip) if cfg.hflip else RandomMirror(p=1) 111 | ) 112 | self._addnoise = addnoise 113 | 114 | self.register_buffer("_noise1", torch.FloatTensor()) 115 | self.register_buffer("_noise2", torch.FloatTensor()) 116 | self.register_buffer("_xbounds", torch.FloatTensor([-1, -1, 1, 1])) 117 | self.register_buffer("_ybounds", torch.FloatTensor([-1, 1, -1, 1])) 118 | self.register_buffer("_x", torch.IntTensor(1)) 119 | self.register_buffer("_y", torch.IntTensor(1)) 120 | 121 | def inverse_transform_coords( 122 | self, width, height, thetas, offset_x=None, offset_y=None 123 | ): 124 | xx, yy = self._meshgrid(width=width, height=height) 125 | 126 | xx = torch.unsqueeze(xx, dim=0).float() 127 | yy = torch.unsqueeze(yy, dim=0).float() 128 | 
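# the optional offsets (e.g., the flow components u and v passed in by transform_flow) shift the sampling grid before the affine parameters are applied below: xq = a1*x + a2*y + a3, yq = a4*x + a5*y + a6 in normalized coordinates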
129 | if offset_x is not None: 130 | xx = xx + offset_x 131 | if offset_y is not None: 132 | yy = yy + offset_y 133 | 134 | a1 = thetas[:, 0].contiguous().view(-1, 1, 1) 135 | a2 = thetas[:, 1].contiguous().view(-1, 1, 1) 136 | a3 = thetas[:, 2].contiguous().view(-1, 1, 1) 137 | a4 = thetas[:, 3].contiguous().view(-1, 1, 1) 138 | a5 = thetas[:, 4].contiguous().view(-1, 1, 1) 139 | a6 = thetas[:, 5].contiguous().view(-1, 1, 1) 140 | 141 | xx, yy = normalize_coords(xx, yy, width=width, height=height) 142 | xq = a1 * xx + a2 * yy + a3 143 | yq = a4 * xx + a5 * yy + a6 144 | xq, yq = denormalize_coords(xq, yq, width=width, height=height) 145 | return xq, yq 146 | 147 | def transform_coords(self, width, height, thetas): 148 | xx1, yy1 = self._meshgrid(width=width, height=height) 149 | xx, yy = normalize_coords(xx1, yy1, width=width, height=height) 150 | 151 | def _unsqueeze12(u): 152 | return torch.unsqueeze(torch.unsqueeze(u, dim=1), dim=1) 153 | 154 | a1 = _unsqueeze12(thetas[:, 0]) 155 | a2 = _unsqueeze12(thetas[:, 1]) 156 | a3 = _unsqueeze12(thetas[:, 2]) 157 | a4 = _unsqueeze12(thetas[:, 3]) 158 | a5 = _unsqueeze12(thetas[:, 4]) 159 | a6 = _unsqueeze12(thetas[:, 5]) 160 | # 161 | z = a1 * a5 - a2 * a4 162 | b1 = a5 / z 163 | b2 = -a2 / z 164 | b4 = -a4 / z 165 | b5 = a1 / z 166 | # 167 | xhat = xx - a3 168 | yhat = yy - a6 169 | xq = b1 * xhat + b2 * yhat 170 | yq = b4 * xhat + b5 * yhat 171 | 172 | xq, yq = denormalize_coords(xq, yq, width=width, height=height) 173 | return xq, yq 174 | 175 | def find_invalid(self, width, height, thetas): 176 | x = self._xbounds 177 | y = self._ybounds 178 | # 179 | a1 = torch.unsqueeze(thetas[:, 0], dim=1) 180 | a2 = torch.unsqueeze(thetas[:, 1], dim=1) 181 | a3 = torch.unsqueeze(thetas[:, 2], dim=1) 182 | a4 = torch.unsqueeze(thetas[:, 3], dim=1) 183 | a5 = torch.unsqueeze(thetas[:, 4], dim=1) 184 | a6 = torch.unsqueeze(thetas[:, 5], dim=1) 185 | # 186 | z = a1 * a5 - a2 * a4 187 | b1 = a5 / z 188 | b2 = -a2 / z 189 | b4 = -a4 / z 190 | b5 = a1 / z 191 | # 192 | xhat = x - a3 193 | yhat = y - a6 194 | xq = b1 * xhat + b2 * yhat 195 | yq = b4 * xhat + b5 * yhat 196 | xq, yq = denormalize_coords(xq, yq, width=width, height=height) 197 | # 198 | invalid = ((xq < 0) | (yq < 0) | (xq >= width) | (yq >= height)).sum( 199 | dim=1, keepdim=True 200 | ) > 0 201 | 202 | return invalid 203 | 204 | def apply_random_transforms_to_params( 205 | self, 206 | theta0, 207 | max_translate, 208 | min_zoom, 209 | max_zoom, 210 | min_squeeze, 211 | max_squeeze, 212 | min_rotate, 213 | max_rotate, 214 | validate_size=None, 215 | ): 216 | max_translate *= 0.5 217 | batch_size = theta0.size(0) 218 | height, width = validate_size 219 | 220 | # collect valid params here 221 | thetas = torch.zeros_like(theta0) 222 | 223 | zoom = theta0.new(batch_size, 1).zero_() 224 | squeeze = torch.zeros_like(zoom) 225 | tx = torch.zeros_like(zoom) 226 | ty = torch.zeros_like(zoom) 227 | phi = torch.zeros_like(zoom) 228 | invalid = torch.ones_like(zoom).byte() 229 | 230 | while invalid.sum() > 0: 231 | # random sampling 232 | zoom.uniform_(min_zoom, max_zoom) 233 | squeeze.uniform_(min_squeeze, max_squeeze) 234 | tx.uniform_(-max_translate, max_translate) 235 | ty.uniform_(-max_translate, max_translate) 236 | phi.uniform_(min_rotate, max_rotate) 237 | 238 | # construct affine parameters 239 | sx = zoom * squeeze 240 | sy = zoom / squeeze 241 | sin_phi = torch.sin(phi) 242 | cos_phi = torch.cos(phi) 243 | b1 = cos_phi * sx 244 | b2 = sin_phi * sy 245 | b3 = tx 246 | b4 = -sin_phi * sx 247 | b5 = 
cos_phi * sy 248 | b6 = ty 249 | 250 | theta_transform = torch.cat([b1, b2, b3, b4, b5, b6], dim=1) 251 | theta_try = apply_transform_to_params(theta0, theta_transform) 252 | thetas = invalid.float() * theta_try + (1 - invalid.float()) * thetas 253 | 254 | # compute new invalid ones 255 | invalid = self.find_invalid(width=width, height=height, thetas=thetas) 256 | 257 | # here we should have good thetas within borders 258 | return thetas 259 | 260 | def transform_image(self, images, thetas): 261 | batch_size, channels, height, width = images.size() 262 | xq, yq = self.transform_coords(width=width, height=height, thetas=thetas) 263 | transformed = self._interp2(images, xq, yq) 264 | return transformed 265 | 266 | def transform_flow(self, flow, theta1, theta2): 267 | batch_size, channels, height, width = flow.size() 268 | u = flow[:, 0, :, :] 269 | v = flow[:, 1, :, :] 270 | 271 | # inverse transform coords 272 | x0, y0 = self.inverse_transform_coords( 273 | width=width, height=height, thetas=theta1 274 | ) 275 | 276 | x1, y1 = self.inverse_transform_coords( 277 | width=width, height=height, thetas=theta2, offset_x=u, offset_y=v 278 | ) 279 | 280 | # subtract and create new flow 281 | u = x1 - x0 282 | v = y1 - y0 283 | new_flow = torch.stack([u, v], dim=1) 284 | 285 | # transform coords 286 | xq, yq = self.transform_coords(width=width, height=height, thetas=theta1) 287 | 288 | # interp2 289 | transformed = self._flow_interp2(new_flow, xq, yq) 290 | return transformed 291 | 292 | def forward(self, data): 293 | # 01234 flow 12 21 23 32 294 | imgs = data["imgs"] 295 | full_segs = data["full_segs"] 296 | flows_f = data["flows_f"] 297 | masks_f = data["masks_f"] 298 | 299 | batch_size, _, height, width = imgs[0].size() 300 | 301 | # identity = no transform 302 | theta0 = self._identity(batch_size) 303 | 304 | # global transform 305 | theta_list = [ 306 | self.apply_random_transforms_to_params( 307 | theta0, 308 | max_translate=self.cfg.trans[0], 309 | min_zoom=self.cfg.zoom[0], 310 | max_zoom=self.cfg.zoom[1], 311 | min_squeeze=self.cfg.squeeze[0], 312 | max_squeeze=self.cfg.squeeze[1], 313 | min_rotate=self.cfg.rotate[0], 314 | max_rotate=self.cfg.rotate[1], 315 | validate_size=[height, width], 316 | ) 317 | ] 318 | 319 | # relative transform 320 | for _i in range(len(imgs) - 1): 321 | theta_list.append( 322 | self.apply_random_transforms_to_params( 323 | theta_list[-1], 324 | max_translate=self.cfg.trans[1], 325 | min_zoom=self.cfg.zoom[2], 326 | max_zoom=self.cfg.zoom[3], 327 | min_squeeze=self.cfg.squeeze[2], 328 | max_squeeze=self.cfg.squeeze[3], 329 | min_rotate=self.cfg.rotate[2], 330 | max_rotate=self.cfg.rotate[3], 331 | validate_size=[height, width], 332 | ) 333 | ) 334 | 335 | # random flip images 336 | theta_list = self._random_mirror(theta_list) 337 | 338 | # 01234 339 | imgs = [self.transform_image(im, theta) for im, theta in zip(imgs, theta_list)] 340 | full_segs = [ 341 | self.transform_image(full_seg, theta) 342 | for full_seg, theta in zip(full_segs, theta_list) 343 | ] 344 | 345 | if len(imgs) > 2: 346 | theta_list = theta_list[1:-1] 347 | # 12 23 348 | flows_f = [ 349 | self.transform_flow(flo, theta1, theta2) 350 | for flo, theta1, theta2 in zip(flows_f, theta_list[:-1], theta_list[1:]) 351 | ] 352 | 353 | masks_f = [ 354 | self.transform_image(mask, theta) 355 | for mask, theta in zip(masks_f, theta_list) 356 | ] 357 | 358 | if self._addnoise: 359 | stddev = np.random.uniform(0.0, 0.04) 360 | for im in imgs: 361 | noise = torch.zeros_like(im) 362 | noise.normal_(std=stddev) 
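# add zero-mean Gaussian noise (stddev drawn from [0, 0.04]) to each image in place, then clamp back to the valid [0, 1] range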
363 | im.add_(noise) 364 | im.clamp_(0.0, 1.0) 365 | 366 | data["imgs"] = imgs 367 | data["full_segs"] = full_segs 368 | data["flows_f"] = flows_f 369 | data["masks_f"] = masks_f 370 | return data 371 | -------------------------------------------------------------------------------- /transforms/co_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | """ 4 | 5 | import numbers 6 | import random 7 | 8 | import numpy as np 9 | 10 | 11 | def get_co_transforms(aug_args): 12 | transforms = [] 13 | if aug_args.swap: # swap first and second frame 14 | transforms.append(RandomTemporalSwap()) 15 | if aug_args.hflip: 16 | transforms.append(RandomHorizontalFlip()) 17 | if aug_args.crop: 18 | transforms.append(RandomCrop(aug_args.para_crop)) 19 | return Compose(transforms) 20 | 21 | 22 | class Compose: 23 | def __init__(self, co_transforms): 24 | self.co_transforms = co_transforms 25 | 26 | def __call__(self, imgs, full_segs, key_objs, target): 27 | for t in self.co_transforms: 28 | imgs, full_segs, key_objs, target = t(imgs, full_segs, key_objs, target) 29 | return imgs, full_segs, key_objs, target 30 | 31 | 32 | class RandomCrop: 33 | """Crops the given PIL.Image at a random location to have a region of 34 | the given size. size can be a tuple (target_height, target_width) 35 | or an integer, in which case the target will be of a square shape (size, size) 36 | """ 37 | 38 | def __init__(self, size): 39 | if isinstance(size, numbers.Number): 40 | self.size = (int(size), int(size)) 41 | else: 42 | self.size = size 43 | 44 | def __call__(self, imgs, full_segs, key_objs, target): 45 | h, w, _ = imgs[0].shape 46 | th, tw = self.size 47 | if w == tw and h == th: 48 | return imgs, full_segs, key_objs, target 49 | 50 | x1 = random.randint(0, w - tw) 51 | y1 = random.randint(0, h - th) 52 | imgs = [img[y1 : y1 + th, x1 : x1 + tw] for img in imgs] 53 | full_segs = [full_seg[y1 : y1 + th, x1 : x1 + tw] for full_seg in full_segs] 54 | key_objs = [key_obj[y1 : y1 + th, x1 : x1 + tw] for key_obj in key_objs] 55 | 56 | if target != {}: 57 | raise NotImplementedError( 58 | "RandomCrop currently does not take ground-truth labels" 59 | ) 60 | 61 | return imgs, full_segs, key_objs, target 62 | 63 | 64 | class RandomTemporalSwap: 65 | """Randomly swap first and second frames""" 66 | 67 | def __call__(self, imgs, full_segs, key_objs, target): 68 | 69 | if random.random() < 0.5: 70 | imgs = imgs[::-1] 71 | full_segs = full_segs[::-1] 72 | key_objs = key_objs[::-1] 73 | 74 | if target != {}: 75 | raise NotImplementedError( 76 | "RandomTemporalSwap currently does not take ground-truth labels" 77 | ) 78 | 79 | return imgs, full_segs, key_objs, target 80 | 81 | 82 | class RandomHorizontalFlip: 83 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5""" 84 | 85 | def __call__(self, imgs, full_segs, key_objs, target): 86 | if random.random() < 0.5: 87 | imgs = [np.copy(np.fliplr(im)) for im in imgs] 88 | full_segs = [np.copy(np.fliplr(full_seg)) for full_seg in full_segs] 89 | key_objs = [np.copy(np.fliplr(key_obj)) for key_obj in key_objs] 90 | 91 | if target != {}: 92 | raise NotImplementedError( 93 | "RandomHorizontalFlip currently does not take ground-truth labels" 94 | ) 95 | return imgs, full_segs, key_objs, target 96 | -------------------------------------------------------------------------------- /transforms/input_transforms.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | """ 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | import torch 9 | from torch.nn import functional as F 10 | 11 | 12 | class ArrayToTensor: 13 | """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).""" 14 | 15 | def __call__(self, all_data): 16 | imgs, full_segs, key_objs = all_data 17 | imgs = [torch.from_numpy(img.transpose((2, 0, 1))).float() for img in imgs] 18 | full_segs = [ 19 | torch.from_numpy(full_seg.transpose((2, 0, 1))).float() 20 | for full_seg in full_segs 21 | ] 22 | 23 | if key_objs is not None: 24 | key_objs = [ 25 | torch.from_numpy(key_obj.transpose((2, 0, 1))).float() 26 | for key_obj in key_objs 27 | ] 28 | 29 | return imgs, full_segs, key_objs 30 | 31 | 32 | class Zoom: 33 | def __init__(self, new_h, new_w): 34 | self.new_h = new_h 35 | self.new_w = new_w 36 | 37 | def __call__(self, all_data): 38 | imgs, full_segs, key_objs = all_data 39 | imgs = [cv2.resize(img, (self.new_w, self.new_h)) for img in imgs] 40 | full_segs = [ 41 | cv2.resize( 42 | full_seg, (self.new_w, self.new_h), interpolation=cv2.INTER_NEAREST 43 | )[:, :, None] 44 | for full_seg in full_segs 45 | ] 46 | 47 | if key_objs is not None: 48 | new_key_objs = [] 49 | for key_obj in key_objs: 50 | if key_obj.shape[-1] == 0: ## no key obj found 51 | new_key_obj = np.zeros((self.new_h, self.new_w, 0), dtype=np.uint8) 52 | elif key_obj.shape[-1] == 1: 53 | new_key_obj = cv2.resize( 54 | key_obj, 55 | (self.new_w, self.new_h), 56 | interpolation=cv2.INTER_NEAREST, 57 | )[:, :, None] 58 | else: 59 | new_key_obj = cv2.resize( 60 | key_obj, 61 | (self.new_w, self.new_h), 62 | interpolation=cv2.INTER_NEAREST, 63 | ) 64 | 65 | new_key_objs.append(new_key_obj) 66 | else: 67 | new_key_objs = None 68 | 69 | return imgs, full_segs, new_key_objs 70 | 71 | 72 | def full_segs_to_adj_maps(full_segs, win_size=9, pad_mode="replicate"): 73 | """ 74 | Input: full_segs: [B, 1, H, W] 75 | Output: adj_maps: [B, win_size * win_size, H, W] 76 | """ 77 | 78 | r = (win_size - 1) // 2 79 | b, _, h, w = full_segs.shape 80 | full_segs_padded = F.pad(full_segs, (r, r, r, r), mode=pad_mode) 81 | 82 | nb = F.unfold(full_segs_padded, [win_size, win_size]) 83 | nb = nb.reshape((b, win_size * win_size, h, w)) 84 | 85 | adj_maps = (full_segs == nb).float() 86 | return adj_maps 87 | -------------------------------------------------------------------------------- /utils/config_parser.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | """ 4 | 5 | import json 6 | import os 7 | 8 | from easydict import EasyDict 9 | from utils.manifold_utils import pathmgr 10 | 11 | 12 | def update_config(base_cfg, new_cfg): 13 | for key in new_cfg: 14 | if key not in base_cfg: 15 | base_cfg[key] = new_cfg[key] 16 | elif type(base_cfg[key]) == EasyDict and type(new_cfg[key]) == EasyDict: 17 | update_config(base_cfg[key], new_cfg[key]) 18 | else: 19 | base_cfg[key] = new_cfg[key] 20 | 21 | return base_cfg 22 | 23 | 24 | def init_config(cfg_file): 25 | with open(pathmgr.get_local_path(cfg_file)) as f: 26 | cfg = EasyDict(json.load(f)) 27 | 28 | if "base_configs" in cfg: 29 | base_cfg_file = os.path.join(os.path.dirname(cfg_file), cfg.base_configs) 30 | with open(pathmgr.get_local_path(base_cfg_file)) as f: 31 | base_cfg = EasyDict(json.load(f)) 32 | cfg = update_config(base_cfg, cfg) 33 | 34 | return cfg 35 | -------------------------------------------------------------------------------- /utils/flow_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import imageio 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def load_flow(path): 10 | if path.endswith(".png"): 11 | # for KITTI which uses 16bit PNG images 12 | # see 'https://github.com/ClementPinard/FlowNetPytorch/blob/master/datasets/KITTI.py' 13 | # The -1 is here to specify not to change the image depth (16bit), and is compatible 14 | # with both OpenCV2 and OpenCV3 15 | flo_file = cv2.imread(path, -1) 16 | flo_img = flo_file[:, :, 2:0:-1].astype(np.float32) 17 | invalid = flo_file[:, :, 0] == 0 # mask 18 | flo_img = flo_img - 32768 19 | flo_img = flo_img / 64 20 | flo_img[np.abs(flo_img) < 1e-10] = 1e-10 21 | flo_img[invalid, :] = 0 22 | return flo_img, np.expand_dims(flo_file[:, :, 0], 2) 23 | else: 24 | with open(path, "rb") as f: 25 | magic = np.fromfile(f, np.float32, count=1) 26 | assert 202021.25 == magic, "Magic number incorrect. 
Invalid .flo file" 27 | h = np.fromfile(f, np.int32, count=1)[0] 28 | w = np.fromfile(f, np.int32, count=1)[0] 29 | data = np.fromfile(f, np.float32, count=2 * w * h) 30 | # Reshape data into 3D array (columns, rows, bands) 31 | data2D = np.resize(data, (w, h, 2)) 32 | return data2D 33 | 34 | 35 | def load_mask(path): 36 | # 0~255 HxWx1 37 | mask = imageio.imread(path).astype(np.float32) / 255.0 38 | if len(mask.shape) == 3: 39 | mask = mask[:, :, 0] 40 | return np.expand_dims(mask, -1) 41 | 42 | 43 | # def flow_to_image(flow, max_flow=256): 44 | # import numpy as np 45 | # from matplotlib.colors import hsv_to_rgb 46 | # if max_flow is not None: 47 | # max_flow = max(max_flow, 1.0) 48 | # else: 49 | # max_flow = np.max(flow) 50 | 51 | # n = 8 52 | # u, v = flow[:, :, 0], flow[:, :, 1] 53 | # mag = np.sqrt(np.square(u) + np.square(v)) 54 | # angle = np.arctan2(v, u) 55 | # im_h = np.mod(angle / (2 * np.pi) + 1, 1) 56 | # im_s = np.clip(mag * n / max_flow, a_min=0, a_max=1) 57 | # im_v = np.clip(n - im_s, a_min=0, a_max=1) 58 | # im = hsv_to_rgb(np.stack([im_h, im_s, im_v], 2)) 59 | # return (im * 255).astype(np.uint8) 60 | 61 | 62 | def resize_flow(flow, new_shape): 63 | _, _, h, w = flow.shape 64 | new_h, new_w = new_shape 65 | flow = torch.nn.functional.interpolate( 66 | flow, (new_h, new_w), mode="bilinear", align_corners=True 67 | ) 68 | scale_h, scale_w = h / float(new_h), w / float(new_w) 69 | flow[:, 0] /= scale_w 70 | flow[:, 1] /= scale_h 71 | return flow 72 | 73 | 74 | # credit: https://github.com/princeton-vl/RAFT/blob/master/core/utils/frame_utils.py 75 | def writeFlowSintel(filename, uv, v=None): 76 | """Write optical flow to file. 77 | 78 | If v is None, uv is assumed to contain both u and v channels, 79 | stacked in depth. 80 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
81 | """ 82 | nBands = 2 83 | TAG_CHAR = np.array([202021.25], np.float32) 84 | 85 | if v is None: 86 | assert uv.ndim == 3 87 | assert uv.shape[2] == 2 88 | u = uv[:, :, 0] 89 | v = uv[:, :, 1] 90 | else: 91 | u = uv 92 | 93 | assert u.shape == v.shape 94 | height, width = u.shape 95 | 96 | os.makedirs(os.path.dirname(filename), exist_ok=True) 97 | with open(filename, "wb") as f: 98 | # write the header 99 | f.write(TAG_CHAR) 100 | np.array(width).astype(np.int32).tofile(f) 101 | np.array(height).astype(np.int32).tofile(f) 102 | # arrange into matrix form 103 | tmp = np.zeros((height, width * nBands)) 104 | tmp[:, np.arange(width) * 2] = u 105 | tmp[:, np.arange(width) * 2 + 1] = v 106 | tmp.astype(np.float32).tofile(f) 107 | 108 | 109 | # credit: https://github.com/princeton-vl/RAFT/blob/master/core/utils/frame_utils.py 110 | def writeFlowKITTI(filename, uv): 111 | uv = 64.0 * uv + 2**15 112 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 113 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 114 | cv2.imwrite(filename, uv[..., ::-1]) 115 | 116 | 117 | def evaluate_flow(gt_flows, pred_flows, moving_masks=None): 118 | # credit "undepthflow/eval/evaluate_flow.py" 119 | def calculate_error_rate(epe_map, gt_flow, mask): 120 | bad_pixels = np.logical_and( 121 | epe_map * mask > 3, 122 | epe_map * mask > 0.05 * np.sqrt(np.sum(np.square(gt_flow), axis=2)), 123 | ) 124 | return bad_pixels.sum() / mask.sum() * 100.0 125 | 126 | ( 127 | error, 128 | error_noc, 129 | error_occ, 130 | error_move, 131 | error_static, 132 | error_rate, 133 | error_rate_noc, 134 | ) = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) 135 | error_move_rate, error_static_rate = 0.0, 0.0 136 | B = len(gt_flows) 137 | for gt_flow, pred_flow, i in zip(gt_flows, pred_flows, range(B)): 138 | H, W = gt_flow.shape[:2] 139 | h, w = pred_flow.shape[:2] 140 | 141 | # pred_flow = np.copy(pred_flow) 142 | # pred_flow[:, :, 0] = pred_flow[:, :, 0] / w * W 143 | # pred_flow[:, :, 1] = pred_flow[:, :, 1] / h * H 144 | 145 | # flo_pred = cv2.resize(pred_flow, (W, H), interpolation=cv2.INTER_LINEAR) 146 | 147 | pred_flow = torch.from_numpy(pred_flow)[None].permute(0, 3, 1, 2) 148 | flo_pred = resize_flow(pred_flow, (H, W)) 149 | flo_pred = flo_pred[0].numpy().transpose(1, 2, 0) 150 | 151 | epe_map = np.sqrt( 152 | np.sum(np.square(flo_pred[:, :, :2] - gt_flow[:, :, :2]), axis=2) 153 | ) 154 | if gt_flow.shape[-1] == 2: 155 | error += np.mean(epe_map) 156 | 157 | elif gt_flow.shape[-1] == 4: # with occ and noc mask 158 | error += np.sum(epe_map * gt_flow[:, :, 2]) / np.sum(gt_flow[:, :, 2]) 159 | noc_mask = gt_flow[:, :, -1] 160 | error_noc += np.sum(epe_map * noc_mask) / np.sum(noc_mask) 161 | 162 | error_occ += np.sum(epe_map * (gt_flow[:, :, 2] - noc_mask)) / max( 163 | np.sum(gt_flow[:, :, 2] - noc_mask), 1.0 164 | ) 165 | 166 | error_rate += calculate_error_rate( 167 | epe_map, gt_flow[:, :, 0:2], gt_flow[:, :, 2] 168 | ) 169 | error_rate_noc += calculate_error_rate( 170 | epe_map, gt_flow[:, :, 0:2], noc_mask 171 | ) 172 | if moving_masks is not None: 173 | move_mask = moving_masks[i] 174 | 175 | error_move_rate += calculate_error_rate( 176 | epe_map, gt_flow[:, :, 0:2], gt_flow[:, :, 2] * move_mask 177 | ) 178 | error_static_rate += calculate_error_rate( 179 | epe_map, gt_flow[:, :, 0:2], gt_flow[:, :, 2] * (1.0 - move_mask) 180 | ) 181 | 182 | error_move += np.sum(epe_map * gt_flow[:, :, 2] * move_mask) / np.sum( 183 | gt_flow[:, :, 2] * move_mask 184 | ) 185 | error_static += np.sum( 186 | epe_map * gt_flow[:, :, 2] * (1.0 - 
move_mask) 187 | ) / np.sum(gt_flow[:, :, 2] * (1.0 - move_mask)) 188 | 189 | if gt_flows[0].shape[-1] == 4: 190 | res = [ 191 | error / B, 192 | error_noc / B, 193 | error_occ / B, 194 | error_rate / B, 195 | error_rate_noc / B, 196 | ] 197 | if moving_masks is not None: 198 | res += [error_move / B, error_static / B] 199 | return res 200 | else: 201 | return [error / B] 202 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | import logging.handlers 4 | 5 | # from pathlib import Path 6 | 7 | 8 | def init_logger(level="INFO", log_name="main_logger"): 9 | 10 | logger = logging.getLogger(log_name) 11 | logger.setLevel(level) 12 | 13 | fh = logging.StreamHandler() 14 | formatter = logging.Formatter( 15 | "[%(levelname)s] %(message)s", 16 | ) 17 | fh.setFormatter(formatter) 18 | logger.addHandler(fh) 19 | 20 | logger.info("Start logging!") 21 | return logger 22 | 23 | 24 | # def init_logger( 25 | # level="INFO", log_dir="./", log_name="main_logger", filename="main.log" 26 | # ): 27 | 28 | # logger = logging.getLogger(log_name) 29 | 30 | # fh = logging.handlers.RotatingFileHandler( 31 | # Path(log_dir) / filename, "w", 20 * 1024 * 1024, 5 32 | # ) 33 | # formatter = logging.Formatter( 34 | # "%(asctime)s %(levelname)5s - %(name)s " 35 | # "[%(filename)s line %(lineno)d] - %(message)s", 36 | # datefmt="%m-%d %H:%M:%S", 37 | # ) 38 | # fh.setFormatter(formatter) 39 | # logger.addHandler(fh) 40 | 41 | # # logging to screen 42 | # if "DEBUG" in log_dir: 43 | # fh = logging.StreamHandler() 44 | # formatter = logging.Formatter( 45 | # "[%(levelname)s] %(message)s", 46 | # ) 47 | # fh.setFormatter(formatter) 48 | # logger.addHandler(fh) 49 | 50 | # logger.setLevel(level) 51 | # logger.info("Start training") 52 | # return logger 53 | -------------------------------------------------------------------------------- /utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | 4 | def update_dict(orig_dict, new_dict): 5 | for key, val in new_dict.items(): 6 | if isinstance(val, collections.Mapping): 7 | tmp = update_dict(orig_dict.get(key, {}), val) 8 | orig_dict[key] = tmp 9 | else: 10 | orig_dict[key] = val 11 | return orig_dict 12 | 13 | 14 | class AverageMeter: 15 | """Computes and stores the average and current value""" 16 | 17 | def __init__(self, i=1, precision=3, names=None): 18 | self.meters = i 19 | self.precision = precision 20 | self.reset(self.meters) 21 | self.names = names 22 | if names is not None: 23 | assert self.meters == len(self.names) 24 | else: 25 | self.names = [""] * self.meters 26 | 27 | def reset(self, i): 28 | self.val = [0] * i 29 | self.avg = [0] * i 30 | self.sum = [0] * i 31 | self.count = [0] * i 32 | 33 | def update(self, val, n=1): 34 | if not isinstance(val, list): 35 | val = [val] 36 | if not isinstance(n, list): 37 | n = [n] * self.meters 38 | assert len(val) == self.meters and len(n) == self.meters 39 | for i in range(self.meters): 40 | self.count[i] += n[i] 41 | for i, v in enumerate(val): 42 | self.val[i] = v 43 | self.sum[i] += v * n[i] 44 | self.avg[i] = self.sum[i] / self.count[i] 45 | 46 | def __repr__(self): 47 | val = " ".join( 48 | [ 49 | "{} {:.{}f}".format(n, v, self.precision) 50 | for n, v in zip(self.names, self.val) 51 | ] 52 | ) 53 | avg = " ".join( 54 | [ 55 | "{} {:.{}f}".format(n, a, self.precision) 56 
| for n, a in zip(self.names, self.avg) 57 | ] 58 | ) 59 | return "{} ({})".format(val, avg) 60 | -------------------------------------------------------------------------------- /utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | """ 4 | 5 | import math 6 | import os 7 | import random 8 | 9 | # import shutil 10 | 11 | import numpy as np 12 | import torch 13 | 14 | # import torch.nn as nn 15 | # import torch.nn.functional as F 16 | from torch.optim import Optimizer 17 | from utils.manifold_utils import MANIFOLD_BUCKET, pathmgr 18 | 19 | 20 | def init_seed(seed): 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | np.random.seed(seed) 24 | random.seed(seed) 25 | 26 | 27 | def weight_parameters(module): 28 | return [param for name, param in module.named_parameters() if ".weight" in name] 29 | 30 | 31 | def bias_parameters(module): 32 | return [param for name, param in module.named_parameters() if ".bias" in name] 33 | 34 | 35 | def other_parameters(module): 36 | return [ 37 | param 38 | for name, param in module.named_parameters() 39 | if ".bias" not in name and ".weight" not in name 40 | ] 41 | 42 | 43 | def load_checkpoint(model_path): 44 | # weights = torch.load(model_path) 45 | 46 | if "manifold" not in model_path: 47 | model_path = os.path.join("manifold://" + MANIFOLD_BUCKET, model_path) 48 | with pathmgr.open(model_path, "rb") as f: 49 | for i in range(3): 50 | try: 51 | weights = torch.load(f) 52 | break 53 | except Exception: 54 | if i == 2: 55 | raise Exception 56 | 57 | epoch = None 58 | if "epoch" in weights: 59 | epoch = weights.pop("epoch") 60 | if "state_dict" in weights: 61 | state_dict = weights["state_dict"] 62 | else: 63 | state_dict = weights 64 | return epoch, state_dict 65 | 66 | 67 | def save_checkpoint(save_path, states, file_prefixes, is_best, filename="ckpt.pth.tar"): 68 | def run_one_sample(save_path, state, prefix, is_best, filename): 69 | # torch.save(state, os.path.join(save_path, "{}_{}".format(prefix, filename))) 70 | 71 | if "manifold" not in save_path: 72 | save_path = os.path.join("manifold://" + MANIFOLD_BUCKET, save_path) 73 | save_path = os.path.join(save_path, "{}_{}".format(prefix, filename)) 74 | with pathmgr.open(save_path, "wb") as f: 75 | for i in range(3): 76 | try: 77 | torch.save(state, f) 78 | return 79 | except Exception: 80 | if i == 2: 81 | raise Exception 82 | 83 | if not isinstance(file_prefixes, str): 84 | for (prefix, state) in zip(file_prefixes, states): 85 | run_one_sample(save_path, state, prefix, is_best, filename) 86 | 87 | else: 88 | run_one_sample(save_path, states, file_prefixes, is_best, filename) 89 | 90 | 91 | def restore_model(model, pretrained_file): 92 | epoch, weights = load_checkpoint(pretrained_file) 93 | 94 | model_keys = set(model.state_dict().keys()) 95 | weight_keys = set(weights.keys()) 96 | 97 | # load weights by name 98 | weights_not_in_model = sorted(weight_keys - model_keys) 99 | model_not_in_weights = sorted(model_keys - weight_keys) 100 | if len(model_not_in_weights): 101 | print("Warning: There are weights in model but not in pre-trained.") 102 | for key in model_not_in_weights: 103 | print(key) 104 | weights[key] = model.state_dict()[key] 105 | if len(weights_not_in_model): 106 | print("Warning: There are pre-trained weights not in model.") 107 | for key in weights_not_in_model: 108 | print(key) 109 | from collections import OrderedDict 110 | 111 | new_weights = 
OrderedDict() 112 | for key in model_keys: 113 | new_weights[key] = weights[key] 114 | weights = new_weights 115 | 116 | model.load_state_dict(weights) 117 | return model 118 | 119 | 120 | class AdamW(Optimizer): 121 | """Implements AdamW algorithm. 122 | 123 | It has been proposed in `Fixing Weight Decay Regularization in Adam`_. 124 | 125 | Arguments: 126 | params (iterable): iterable of parameters to optimize or dicts defining 127 | parameter groups 128 | lr (float, optional): learning rate (default: 1e-3) 129 | betas (Tuple[float, float], optional): coefficients used for computing 130 | running averages of gradient and its square (default: (0.9, 0.999)) 131 | eps (float, optional): term added to the denominator to improve 132 | numerical stability (default: 1e-8) 133 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 134 | 135 | .. Fixing Weight Decay Regularization in Adam: 136 | https://arxiv.org/abs/1711.05101 137 | """ 138 | 139 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 140 | defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay} 141 | super(AdamW, self).__init__(params, defaults) 142 | 143 | def step(self, closure=None): 144 | """Performs a single optimization step. 145 | 146 | Arguments: 147 | closure (callable, optional): A closure that reevaluates the model 148 | and returns the loss. 149 | """ 150 | loss = None 151 | if closure is not None: 152 | loss = closure() 153 | 154 | for group in self.param_groups: 155 | for p in group["params"]: 156 | if p.grad is None: 157 | continue 158 | grad = p.grad.data 159 | if grad.is_sparse: 160 | raise RuntimeError( 161 | "AdamW does not support sparse gradients, please consider SparseAdam instead" 162 | ) 163 | 164 | state = self.state[p] 165 | 166 | # State initialization 167 | if len(state) == 0: 168 | state["step"] = 0 169 | # Exponential moving average of gradient values 170 | state["exp_avg"] = torch.zeros_like(p.data) 171 | # Exponential moving average of squared gradient values 172 | state["exp_avg_sq"] = torch.zeros_like(p.data) 173 | 174 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 175 | beta1, beta2 = group["betas"] 176 | 177 | state["step"] += 1 178 | 179 | # according to the paper, this penalty should come after the bias correction 180 | # if group['weight_decay'] != 0: 181 | # grad = grad.add(group['weight_decay'], p.data) 182 | 183 | # Decay the first and second moment running average coefficient 184 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 185 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 186 | 187 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 188 | 189 | bias_correction1 = 1 - beta1 ** state["step"] 190 | bias_correction2 = 1 - beta2 ** state["step"] 191 | step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 192 | 193 | p.data.addcdiv_(-step_size, exp_avg, denom) 194 | 195 | if group["weight_decay"] != 0: 196 | p.data.add_(-group["weight_decay"], p.data) 197 | 198 | return loss 199 | -------------------------------------------------------------------------------- /utils/warp_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # import torch.nn.functional as F 5 | 6 | 7 | def mesh_grid(B, H, W): 8 | # mesh grid 9 | x_base = torch.arange(0, W).repeat(B, H, 1) # BHW 10 | y_base = torch.arange(0, H).repeat(B, W, 1).transpose(1, 2) # BHW 11 | 12 | base_grid = torch.stack([x_base, y_base], 1) # B2HW 13 
| return base_grid 14 | 15 | 16 | def norm_grid(v_grid): 17 | _, _, H, W = v_grid.size() 18 | 19 | # scale grid to [-1,1] 20 | v_grid_norm = torch.zeros_like(v_grid) 21 | v_grid_norm[:, 0, :, :] = 2.0 * v_grid[:, 0, :, :] / (W - 1) - 1.0 22 | v_grid_norm[:, 1, :, :] = 2.0 * v_grid[:, 1, :, :] / (H - 1) - 1.0 23 | return v_grid_norm.permute(0, 2, 3, 1) # BHW2 24 | 25 | 26 | def get_corresponding_map(data): 27 | """ 28 | 29 | :param data: unnormalized coordinates Bx2xHxW 30 | :return: Bx1xHxW 31 | """ 32 | B, _, H, W = data.size() 33 | 34 | # x = data[:, 0, :, :].view(B, -1).clamp(0, W - 1) # BxN (N=H*W) 35 | # y = data[:, 1, :, :].view(B, -1).clamp(0, H - 1) 36 | 37 | x = data[:, 0, :, :].view(B, -1) # BxN (N=H*W) 38 | y = data[:, 1, :, :].view(B, -1) 39 | 40 | # invalid = (x < 0) | (x > W - 1) | (y < 0) | (y > H - 1) # BxN 41 | # invalid = invalid.repeat([1, 4]) 42 | 43 | x1 = torch.floor(x) 44 | x_floor = x1.clamp(0, W - 1) 45 | y1 = torch.floor(y) 46 | y_floor = y1.clamp(0, H - 1) 47 | x0 = x1 + 1 48 | x_ceil = x0.clamp(0, W - 1) 49 | y0 = y1 + 1 50 | y_ceil = y0.clamp(0, H - 1) 51 | 52 | x_ceil_out = x0 != x_ceil 53 | y_ceil_out = y0 != y_ceil 54 | x_floor_out = x1 != x_floor 55 | y_floor_out = y1 != y_floor 56 | invalid = torch.cat( 57 | [ 58 | x_ceil_out | y_ceil_out, 59 | x_ceil_out | y_floor_out, 60 | x_floor_out | y_ceil_out, 61 | x_floor_out | y_floor_out, 62 | ], 63 | dim=1, 64 | ) 65 | 66 | # encode coordinates, since the scatter function can only index along one axis 67 | corresponding_map = torch.zeros(B, H * W).type_as(data) 68 | indices = torch.cat( 69 | [ 70 | x_ceil + y_ceil * W, 71 | x_ceil + y_floor * W, 72 | x_floor + y_ceil * W, 73 | x_floor + y_floor * W, 74 | ], 75 | 1, 76 | ).long() # BxN (N=4*H*W) 77 | values = torch.cat( 78 | [ 79 | (1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_ceil)), 80 | (1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_floor)), 81 | (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_ceil)), 82 | (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_floor)), 83 | ], 84 | 1, 85 | ) 86 | # values = torch.ones_like(values) 87 | 88 | values[invalid] = 0 89 | 90 | corresponding_map.scatter_add_(1, indices, values) 91 | # decode coordinates 92 | corresponding_map = corresponding_map.view(B, H, W) 93 | 94 | return corresponding_map.unsqueeze(1) 95 | 96 | 97 | def flow_warp(x, flow12, pad="border", mode="bilinear"): 98 | B, _, H, W = x.size() 99 | 100 | base_grid = mesh_grid(B, H, W).type_as(x) # B2HW 101 | 102 | v_grid = norm_grid(base_grid + flow12) # BHW2 103 | im1_recons = nn.functional.grid_sample( 104 | x, v_grid, mode=mode, padding_mode=pad, align_corners=True 105 | ) 106 | return im1_recons 107 | 108 | 109 | def get_occu_mask_bidirection(flow12, flow21, scale=0.01, bias=0.5): 110 | flow21_warped = flow_warp(flow21, flow12, pad="zeros") 111 | flow12_diff = flow12 + flow21_warped 112 | mag = (flow12 * flow12).sum(1, keepdim=True) + (flow21_warped * flow21_warped).sum( 113 | 1, keepdim=True 114 | ) 115 | occ_thresh = scale * mag + bias 116 | occ = (flow12_diff * flow12_diff).sum(1, keepdim=True) > occ_thresh 117 | return occ.float() 118 | 119 | 120 | def get_occu_mask_backward(flow21, th=0.2): 121 | B, _, H, W = flow21.size() 122 | base_grid = mesh_grid(B, H, W).type_as(flow21) # B2HW 123 | 124 | corr_map = get_corresponding_map(base_grid + flow21) # BHW 125 | occu_mask = corr_map.clamp(min=0.0, max=1.0) < th 126 | return occu_mask.float() 127 | --------------------------------------------------------------------------------
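Editor's note: a minimal usage sketch of the warping utilities defined in utils/warp_utils.py above. This snippet is not part of the repository source; the batch size, image size, and the zero/random inputs are illustrative assumptions only.

import torch

from utils.warp_utils import flow_warp, get_occu_mask_backward, get_occu_mask_bidirection

# Hypothetical shapes: B frames pairs of C-channel H x W images.
B, C, H, W = 2, 3, 64, 128
img2 = torch.rand(B, C, H, W)      # second frame
flow12 = torch.zeros(B, 2, H, W)   # forward flow, frame 1 -> frame 2
flow21 = torch.zeros(B, 2, H, W)   # backward flow, frame 2 -> frame 1

# Reconstruct frame 1 by sampling frame 2 along the forward flow
# (with zero flow this is an identity warp).
img1_recons = flow_warp(img2, flow12)               # [B, C, H, W]

# Occlusion estimates: forward-backward consistency check and
# backward-flow range map; both return binary (0/1) float masks.
occ_fb = get_occu_mask_bidirection(flow12, flow21)  # [B, 1, H, W], 1 = occluded
occ_bw = get_occu_mask_backward(flow21, th=0.2)     # [B, 1, H, W]

print(img1_recons.shape, occ_fb.mean().item(), occ_bw.mean().item())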