├── .idea
│   └── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── breakdance.gif
│   ├── flowlens.png
│   ├── in_beyond.gif
│   └── out_beyond.gif
├── configs
│   ├── KITTI360EX-I_FlowLens.json
│   ├── KITTI360EX-I_FlowLens_re.json
│   ├── KITTI360EX-I_FlowLens_small.json
│   ├── KITTI360EX-I_FlowLens_small_re.json
│   ├── KITTI360EX-O_FlowLens.json
│   ├── KITTI360EX-O_FlowLens_re.json
│   ├── KITTI360EX-O_FlowLens_small.json
│   └── KITTI360EX-O_FlowLens_small_re.json
├── core
│   ├── dataset.py
│   ├── dist.py
│   ├── loss.py
│   ├── lr_scheduler.py
│   ├── metrics.py
│   ├── trainer.py
│   └── utils.py
├── datasets
│   └── KITTI-360EX
│       ├── InnerSphere
│       │   ├── test.json
│       │   └── train.json
│       └── OuterPinhole
│           ├── test.json
│           └── train.json
├── evaluate.py
├── model
│   ├── flowlens.py
│   └── modules
│       ├── feat_prop.py
│       ├── flow_comp.py
│       ├── maskflownets.py
│       ├── mix_focal_transformer.py
│       └── spectral_norm.py
├── release_model
│   ├── README.md
│   └── maskflownets_8x1_sfine_flyingthings3d_subset_384x768.py
└── train.py

/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Hao
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ###

FlowLens: Seeing Beyond the FoV via Flow-guided Clip-Recurrent Transformer 2 |
3 |

4 | Hao Shi·   5 | Qi Jiang·   6 | Kailun Yang·   7 | Xiaoting Yin·   8 | Kaiwei Wang 9 |

10 | Paper 11 | 12 | #### 13 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/flowlens-seeing-beyond-the-fov-via-flow/seeing-beyond-the-visible-on-kitti360-ex)](https://paperswithcode.com/sota/seeing-beyond-the-visible-on-kitti360-ex?p=flowlens-seeing-beyond-the-fov-via-flow) 14 | 15 | [comment]: <> (Paper  ) 16 | 17 | [comment]: <> ( Demo Video (Youtube)  ) 18 | 19 | [comment]: <> ( 演示视频 (B站)  ) 20 |
21 | 22 | [comment]: <> (
) 23 | 24 | [comment]: <> (

:hammer_and_wrench: :construction_worker: :rocket:

) 25 | 26 | [comment]: <> (

:fire: We will release code and checkpoints in the future. :fire:

) 27 | 28 | [comment]: <> (
) 29 | 30 |
31 | 
32 | ### Update
33 | - 2022.11.19 Init repository.
34 | - 2022.11.21 Release the [arXiv](https://arxiv.org/abs/2211.11293) version with supplementary materials.
35 | - 2023.04.04 :fire: Our code is publicly available.
36 | - 2023.04.04 :fire: Release pretrained models.
37 | - 2023.04.04 :fire: Release KITTI360-EX dataset.
38 | 
39 | ### TODO List
40 | 
41 | - [x] Code release.
42 | - [x] KITTI360-EX release.
43 | - [x] Towards higher performance at small extra cost.
44 | 
45 | 
46 | ### Abstract
47 | Limited by hardware cost and system size, a camera's Field-of-View (FoV) is not always satisfactory.
48 | However, from a spatio-temporal perspective, information beyond the camera's physical FoV is off-the-shelf and can actually be obtained ''for free'' from past video streams.
49 | In this paper, we propose a novel task termed Beyond-FoV Estimation, aiming to exploit past visual cues and bidirectionally break through the physical FoV of a camera.
50 | We put forward a FlowLens architecture that expands the FoV by propagating features explicitly via optical flow and implicitly via a novel clip-recurrent transformer,
51 | which has two appealing features: 1) FlowLens comprises a newly proposed Clip-Recurrent Hub with 3D-Decoupled Cross Attention (DDCA) to progressively process global information accumulated in the temporal dimension. 2) A multi-branch Mix Fusion Feed Forward Network (MixF3N) is integrated to enhance the spatially-precise flow of local features. To foster training and evaluation, we establish KITTI360-EX, a dataset for outer- and inner-FoV expansion.
52 | Extensive experiments on both video inpainting and beyond-FoV estimation tasks show that FlowLens achieves state-of-the-art performance.
53 | 
54 | ### Demos
55 | 
56 | 

57 | (Outer Beyond-FoV) 58 |

59 |

60 | Animation 61 |

62 |

63 | 64 |

65 | (Inner Beyond-FoV) 66 |

67 |

68 | Animation 69 |

70 |

71 | 72 |

73 | (Object Removal) 74 |

75 |

76 | Animation 77 |

78 |

79 | 
80 | ### Dependencies
81 | This repo has been tested in the following environment:
82 | ```
83 | torch == 1.10.2
84 | cuda == 11.3
85 | mmflow == 0.5.2
86 | ```
87 | 
88 | ### Usage
89 | To train FlowLens(-S), use:
90 | ```bash
91 | python train.py --config configs/KITTI360EX-I_FlowLens_small_re.json
92 | ```
93 | 
94 | To evaluate on KITTI360-EX, run:
95 | ```bash
96 | python evaluate.py \
97 | --model flowlens \
98 | --cfg_path configs/KITTI360EX-I_FlowLens_small_re.json \
99 | --ckpt release_model/FlowLens-S_re_Out_500000.pth --fov fov5
100 | ```
101 | 
102 | Turn on ```--reverse``` for test-time augmentation (TTA).
103 | 
104 | Turn on ```--save_results``` to save your output.
105 | 
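For intuition, ```--reverse``` style TTA amounts to averaging a forward pass with a temporally flipped pass. A minimal sketch is shown below; the `model(frames, masks)` call signature here is a hypothetical stand-in, not the exact repo API:

```python
import torch

def reverse_tta(model, frames, masks):
    # frames/masks: (B, T, C, H, W). Run the clip forward and time-reversed,
    # flip the reversed prediction back, then average the two estimates.
    out_fwd = model(frames, masks)
    out_bwd = model(frames.flip(1), masks.flip(1)).flip(1)
    return (out_fwd + out_bwd) / 2
```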
106 | ### Pretrained Models
107 | The pretrained models can be found here:
108 | ```
109 | https://share.weiyun.com/6G6QEdaa
110 | ```
111 | 
112 | ### KITTI360-EX for Beyond-FoV Estimation
113 | The preprocessed KITTI360-EX can be downloaded from here:
114 | ```
115 | https://share.weiyun.com/BReRdDiP
116 | ```
117 | 
118 | ### Results
119 | #### KITTI360EX-InnerSphere
120 | | Method | Test Logic | TTA | PSNR | SSIM | VFID | Runtime (s/frame) |
121 | | :--------- | :----------: | :----------: | :----------: | :--------: | :---------: | :------------: |
122 | | _FlowLens-S (Paper)_ |_Beyond-FoV_|_wo_| _36.17_ | _0.9916_ | _0.030_ | _0.023_ |
123 | | FlowLens-S (This Repo) |Beyond-FoV|wo| 37.31 | 0.9926 | 0.025 | **0.015** |
124 | | FlowLens-S+ (This Repo) |Beyond-FoV|with| 38.36 | 0.9938 | 0.017 | 0.050 |
125 | | FlowLens-S (This Repo) |Video Inpainting|wo| 38.01 | 0.9938 | 0.022 | 0.042 |
126 | | FlowLens-S+ (This Repo) |Video Inpainting|with| **38.97** | **0.9947** | **0.015** | 0.142 |
127 | 
128 | | Method | Test Logic | TTA | PSNR | SSIM | VFID | Runtime (s/frame) |
129 | | :--------- | :----------: | :----------: | :----------: | :--------: | :---------: | :------------: |
130 | | _FlowLens (Paper)_ |_Beyond-FoV_|_wo_| _36.69_ | _0.9916_ | _0.027_ | _0.049_ |
131 | | FlowLens (This Repo) |Beyond-FoV|wo| 37.65 | 0.9927 | 0.024 | **0.033** |
132 | | FlowLens+ (This Repo) |Beyond-FoV|with| 38.74 | 0.9941 | 0.017 | 0.095 |
133 | | FlowLens (This Repo) |Video Inpainting|wo| 38.38 | 0.9939 | 0.018 | 0.086 |
134 | | FlowLens+ (This Repo) |Video Inpainting|with| **39.40** | **0.9950** | **0.015** | 0.265 |
135 | 
136 | 
137 | #### KITTI360EX-OuterPinhole
138 | | Method | Test Logic | TTA | PSNR | SSIM | VFID | Runtime (s/frame) |
139 | | :--------- | :----------: | :----------: | :----------: | :--------: | :---------: | :------------: |
140 | | _FlowLens-S (Paper)_ |_Beyond-FoV_|_wo_| _19.68_ | _0.9247_ | _0.300_ | _0.023_ |
141 | | FlowLens-S (This Repo) |Beyond-FoV|wo| 20.41 | 0.9332 | 0.285 | **0.021** |
142 | | FlowLens-S+ (This Repo) |Beyond-FoV|with| 21.30 | 0.9397 | 0.302 | 0.056 |
143 | | FlowLens-S (This Repo) |Video Inpainting|wo| 21.69 | 0.9453 | **0.245** | 0.048 |
144 | | FlowLens-S+ (This Repo) |Video Inpainting|with| **22.40** | **0.9503** | 0.271 | 0.146 |
145 | 
146 | | Method | Test Logic | TTA | PSNR | SSIM | VFID | Runtime (s/frame) |
147 | | :--------- | :----------: | :----------: | :----------: | :--------: | :---------: | :------------: |
148 | | _FlowLens (Paper)_ |_Beyond-FoV_|_wo_| _20.13_ | _0.9314_ | _0.281_ | _0.049_ |
149 | | FlowLens (This Repo) |Beyond-FoV|wo| 20.85 | 0.9381 | 0.259 | **0.035** |
150 | | FlowLens+ (This Repo) |Beyond-FoV|with| 21.65 | 0.9432 | 0.276 | 0.097 |
151 | | FlowLens (This Repo) |Video Inpainting|wo| 22.23 | 0.9507 | **0.231** | 0.085 |
152 | | FlowLens+ (This Repo) |Video Inpainting|with| **22.86** | **0.9543** | 0.253 | 0.260 |
153 | 
154 | Note that when using the ''Video Inpainting'' test logic,
155 | the model is allowed to use more reference frames from the future,
156 | and each local frame is estimated at least twice.
157 | This yields higher accuracy at the cost of slower inference,
158 | which makes it unrealistic for real-world deployment.
159 | 
160 | ### Citation
161 | 
162 | If you find our paper or repo useful, please consider citing our paper:
163 | 
164 | ```bibtex
165 | @article{shi2022flowlens,
166 |   title={FlowLens: Seeing Beyond the FoV via Flow-guided Clip-Recurrent Transformer},
167 |   author={Shi, Hao and Jiang, Qi and Yang, Kailun and Yin, Xiaoting and Wang, Kaiwei},
168 |   journal={arXiv preprint arXiv:2211.11293},
169 |   year={2022}
170 | }
171 | ```
172 | ### Acknowledgement
173 | This project would not have been possible without the following outstanding repositories:
174 | 
175 | [STTN](https://github.com/researchmm/STTN), [MMFlow](https://github.com/open-mmlab/mmflow)
176 | 
177 | 
178 | ### Devs
179 | Hao Shi
180 | 
181 | ### Contact
182 | Feel free to contact me if you have additional questions or are interested in collaboration. Please drop me an email at haoshi@zju.edu.cn. =)
183 | 
--------------------------------------------------------------------------------
/assets/breakdance.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasterHow/FlowLens/252568a423c89bb83d188cfad2e962a3d3423ee3/assets/breakdance.gif
--------------------------------------------------------------------------------
/assets/flowlens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasterHow/FlowLens/252568a423c89bb83d188cfad2e962a3d3423ee3/assets/flowlens.png
--------------------------------------------------------------------------------
/assets/in_beyond.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasterHow/FlowLens/252568a423c89bb83d188cfad2e962a3d3423ee3/assets/in_beyond.gif
--------------------------------------------------------------------------------
/assets/out_beyond.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasterHow/FlowLens/252568a423c89bb83d188cfad2e962a3d3423ee3/assets/out_beyond.gif
--------------------------------------------------------------------------------
/configs/KITTI360EX-I_FlowLens.json:
--------------------------------------------------------------------------------
1 | {
2 |     "seed": 2023,
3 |     "save_dir": "release_model/",
4 |     "train_data_loader": {
5 |         "name": "KITTI360-EX",
6 |         "data_root": "datasets//KITTI-360EX//InnerSphere",
7 |         "w": 336,
8 |         "h": 336,
9 |         "num_local_frames": 5,
10 |         "num_ref_frames": 3
11 |     },
12 |     "losses": {
13 |         "hole_weight": 1,
14 |         "valid_weight": 1,
15 |         "flow_weight": 1,
16 |         "adversarial_weight": 0.01,
17 |         "GAN_LOSS": "hinge"
18 |     },
19 |     "model": {
20 |         "net": "flowlens",
21 |         "no_dis": 0,
22 |         "depths": 9,
23 |         "window_size": [7, 7],
24 |         "output_size": [84, 84],
25 |         "small_model": 0
26 |     },
27 |     "trainer": {
28 |         "type": "Adam",
29 |         "beta1": 0,
30 |         "beta2": 0.99,
31 |         "lr": 0.25e-4,
32 |         "batch_size": 2,
33 |         "num_workers": 8,
34 |         "log_freq": 100,
35 |         "save_freq": 5e3,
36 |         "iterations": 50e4,
37 |         "scheduler": {
38 | "type": "MultiStepLR", 39 | "milestones": [ 40 | 40e4 41 | ], 42 | "gamma": 0.1 43 | } 44 | } 45 | } -------------------------------------------------------------------------------- /configs/KITTI360EX-I_FlowLens_re.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 2023, 3 | "save_dir": "release_model/", 4 | "train_data_loader": { 5 | "name": "KITTI360-EX", 6 | "data_root": "/workspace/mnt/storage/shihao/VI/KITTI-360EX/InnerSphere", 7 | "w": 336, 8 | "h": 336, 9 | "num_local_frames": 10, 10 | "num_ref_frames": 3 11 | }, 12 | "losses": { 13 | "hole_weight": 1, 14 | "valid_weight": 1, 15 | "flow_weight": 1, 16 | "adversarial_weight": 0.01, 17 | "GAN_LOSS": "hinge" 18 | }, 19 | "model": { 20 | "net": "flowlens", 21 | "no_dis": 0, 22 | "depths": 9, 23 | "window_size": [7, 7], 24 | "output_size": [84, 84], 25 | "small_model": 0 26 | }, 27 | "trainer": { 28 | "type": "Adam", 29 | "beta1": 0, 30 | "beta2": 0.99, 31 | "lr": 1e-4, 32 | "batch_size": 8, 33 | "num_workers": 4, 34 | "log_freq": 100, 35 | "save_freq": 5e3, 36 | "iterations": 50e4, 37 | "scheduler": { 38 | "type": "MultiStepLR", 39 | "milestones": [ 40 | 40e4 41 | ], 42 | "gamma": 0.1 43 | } 44 | } 45 | } -------------------------------------------------------------------------------- /configs/KITTI360EX-I_FlowLens_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 2023, 3 | "save_dir": "release_model/", 4 | "train_data_loader": { 5 | "name": "KITTI360-EX", 6 | "data_root": "/workspace/mnt/storage/shihao/VI/KITTI-360EX/InnerSphere", 7 | "w": 336, 8 | "h": 336, 9 | "num_local_frames": 5, 10 | "num_ref_frames": 3, 11 | "random_mask": 1 12 | }, 13 | "losses": { 14 | "hole_weight": 1, 15 | "valid_weight": 1, 16 | "flow_weight": 1, 17 | "adversarial_weight": 0.01, 18 | "GAN_LOSS": "hinge" 19 | }, 20 | "model": { 21 | "net": "flowlens", 22 | "no_dis": 0, 23 | "spy_net": 1, 24 | "mfn_teach": 1, 25 | "freeze_dcn": 0, 26 | "depths": 5, 27 | "window_size": [7, 7], 28 | "output_size": [84, 84], 29 | "small_model": 1 30 | }, 31 | "trainer": { 32 | "type": "Adam", 33 | "beta1": 0, 34 | "beta2": 0.99, 35 | "lr": 0.25e-4, 36 | "batch_size": 2, 37 | "num_workers": 2, 38 | "log_freq": 100, 39 | "save_freq": 5e3, 40 | "iterations": 50e4, 41 | "scheduler": { 42 | "type": "MultiStepLR", 43 | "milestones": [ 44 | 40e4 45 | ], 46 | "gamma": 0.1 47 | } 48 | } 49 | } -------------------------------------------------------------------------------- /configs/KITTI360EX-I_FlowLens_small_re.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 2023, 3 | "save_dir": "release_model/", 4 | "train_data_loader": { 5 | "name": "KITTI360-EX", 6 | "data_root": "/workspace/mnt/storage/shihao/VI/KITTI-360EX/InnerSphere", 7 | "w": 336, 8 | "h": 336, 9 | "num_local_frames": 10, 10 | "num_ref_frames": 3, 11 | "random_mask": 0 12 | }, 13 | "losses": { 14 | "hole_weight": 1, 15 | "valid_weight": 1, 16 | "flow_weight": 1, 17 | "adversarial_weight": 0.01, 18 | "GAN_LOSS": "hinge" 19 | }, 20 | "model": { 21 | "net": "flowlens", 22 | "no_dis": 0, 23 | "spy_net": 1, 24 | "mfn_teach": 1, 25 | "freeze_dcn": 0, 26 | "depths": 5, 27 | "window_size": [7, 7], 28 | "output_size": [84, 84], 29 | "small_model": 1 30 | }, 31 | "trainer": { 32 | "type": "Adam", 33 | "beta1": 0, 34 | "beta2": 0.99, 35 | "lr": 1e-4, 36 | "batch_size": 1, 37 | "num_workers": 4, 38 | "log_freq": 100, 39 | "save_freq": 5e3, 40 | "iterations": 50e4, 
41 | "scheduler": { 42 | "type": "MultiStepLR", 43 | "milestones": [ 44 | 40e4 45 | ], 46 | "gamma": 0.1 47 | } 48 | } 49 | } -------------------------------------------------------------------------------- /configs/KITTI360EX-O_FlowLens.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 2023, 3 | "save_dir": "release_model/", 4 | "train_data_loader": { 5 | "name": "KITTI360-EX", 6 | "data_root": "datasets//KITTI-360EX//OuterPinhole", 7 | "w": 432, 8 | "h": 240, 9 | "num_local_frames": 5, 10 | "num_ref_frames": 3 11 | }, 12 | "losses": { 13 | "hole_weight": 1, 14 | "valid_weight": 1, 15 | "flow_weight": 1, 16 | "adversarial_weight": 0.01, 17 | "GAN_LOSS": "hinge" 18 | }, 19 | "model": { 20 | "net": "flowlens", 21 | "no_dis": 0, 22 | "depths": 9, 23 | "window_size": 0, 24 | "output_size": 0, 25 | "small_model": 0 26 | }, 27 | "trainer": { 28 | "type": "Adam", 29 | "beta1": 0, 30 | "beta2": 0.99, 31 | "lr": 0.25e-4, 32 | "batch_size": 2, 33 | "num_workers": 8, 34 | "log_freq": 100, 35 | "save_freq": 5e3, 36 | "iterations": 50e4, 37 | "scheduler": { 38 | "type": "MultiStepLR", 39 | "milestones": [ 40 | 40e4 41 | ], 42 | "gamma": 0.1 43 | } 44 | } 45 | } -------------------------------------------------------------------------------- /configs/KITTI360EX-O_FlowLens_re.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 2023, 3 | "save_dir": "release_model/", 4 | "train_data_loader": { 5 | "name": "KITTI360-EX", 6 | "data_root": "/workspace/mnt/storage/shihao/VI/KITTI-360EX/OuterPinhole", 7 | "w": 432, 8 | "h": 240, 9 | "num_local_frames": 11, 10 | "num_ref_frames": 3 11 | }, 12 | "losses": { 13 | "hole_weight": 1, 14 | "valid_weight": 1, 15 | "flow_weight": 1, 16 | "adversarial_weight": 0.01, 17 | "GAN_LOSS": "hinge" 18 | }, 19 | "model": { 20 | "net": "flowlens", 21 | "no_dis": 0, 22 | "depths": 9, 23 | "window_size": 0, 24 | "output_size": 0, 25 | "small_model": 0 26 | }, 27 | "trainer": { 28 | "type": "Adam", 29 | "beta1": 0, 30 | "beta2": 0.99, 31 | "lr": 1e-4, 32 | "batch_size": 8, 33 | "num_workers": 4, 34 | "log_freq": 100, 35 | "save_freq": 5e3, 36 | "iterations": 50e4, 37 | "scheduler": { 38 | "type": "MultiStepLR", 39 | "milestones": [ 40 | 40e4 41 | ], 42 | "gamma": 0.1 43 | } 44 | } 45 | } -------------------------------------------------------------------------------- /configs/KITTI360EX-O_FlowLens_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 2023, 3 | "save_dir": "release_model/", 4 | "train_data_loader": { 5 | "name": "KITTI360-EX", 6 | "data_root": "/workspace/mnt/storage/shihao/VI/KITTI-360EX/OuterPinhole", 7 | "w": 432, 8 | "h": 240, 9 | "num_local_frames": 5, 10 | "num_ref_frames": 3 11 | }, 12 | "losses": { 13 | "hole_weight": 1, 14 | "valid_weight": 1, 15 | "flow_weight": 1, 16 | "adversarial_weight": 0.01, 17 | "GAN_LOSS": "hinge" 18 | }, 19 | "model": { 20 | "net": "flowlens", 21 | "no_dis": 0, 22 | "spy_net": 1, 23 | "mfn_teach": 1, 24 | "depths": 5, 25 | "window_size": 0, 26 | "output_size": 0, 27 | "small_model": 1 28 | }, 29 | "trainer": { 30 | "type": "Adam", 31 | "beta1": 0, 32 | "beta2": 0.99, 33 | "lr": 0.25e-4, 34 | "batch_size": 2, 35 | "num_workers": 2, 36 | "log_freq": 100, 37 | "save_freq": 5e3, 38 | "iterations": 50e4, 39 | "scheduler": { 40 | "type": "MultiStepLR", 41 | "milestones": [ 42 | 40e4 43 | ], 44 | "gamma": 0.1 45 | } 46 | } 47 | } 
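A note on these configs: JSON numbers written in scientific notation (e.g. `5e3`, `50e4`, `40e4`) are parsed by Python's `json` module as floats, so consumers of `iterations`, `save_freq`, and `milestones` may need integer casts. A minimal loading sketch (the key layout follows the files above; the cast is standard `json` behavior):

```python
import json

with open('configs/KITTI360EX-O_FlowLens_small.json') as f:
    config = json.load(f)

trainer = config['trainer']
total_iters = int(trainer['iterations'])  # 50e4 is parsed as 500000.0
milestones = [int(m) for m in trainer['scheduler']['milestones']]  # [400000]
```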
-------------------------------------------------------------------------------- /configs/KITTI360EX-O_FlowLens_small_re.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 2023, 3 | "save_dir": "release_model/", 4 | "train_data_loader": { 5 | "name": "KITTI360-EX", 6 | "data_root": "/workspace/mnt/storage/shihao/VI/KITTI-360EX/OuterPinhole", 7 | "w": 432, 8 | "h": 240, 9 | "num_local_frames": 11, 10 | "num_ref_frames": 3 11 | }, 12 | "losses": { 13 | "hole_weight": 1, 14 | "valid_weight": 1, 15 | "flow_weight": 1, 16 | "adversarial_weight": 0.01, 17 | "GAN_LOSS": "hinge" 18 | }, 19 | "model": { 20 | "net": "flowlens", 21 | "no_dis": 0, 22 | "spy_net": 1, 23 | "mfn_teach": 1, 24 | "depths": 5, 25 | "window_size": null, 26 | "output_size": null, 27 | "small_model": 1 28 | }, 29 | "trainer": { 30 | "type": "Adam", 31 | "beta1": 0, 32 | "beta2": 0.99, 33 | "lr": 1e-4, 34 | "batch_size": 8, 35 | "num_workers": 4, 36 | "log_freq": 100, 37 | "save_freq": 5e3, 38 | "iterations": 50e4, 39 | "scheduler": { 40 | "type": "MultiStepLR", 41 | "milestones": [ 42 | 40e4 43 | ], 44 | "gamma": 0.1 45 | } 46 | } 47 | } -------------------------------------------------------------------------------- /core/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | import cv2 6 | from PIL import Image 7 | import numpy as np 8 | 9 | import torch 10 | import torchvision.transforms as transforms 11 | 12 | from core.utils import (TrainZipReader, TestZipReader, 13 | create_random_shape_with_random_motion, Stack, 14 | ToTorchFormatTensor, GroupRandomHorizontalFlip, 15 | create_random_shape_with_random_motion_seq) 16 | from core.dist import get_world_size 17 | 18 | 19 | class TrainDataset(torch.utils.data.Dataset): 20 | """ 21 | Sequence Video Train Dataloader by Hao, based on E2FGVI train loader. 22 | same_mask(bool): If True, use same mask until video changes to the next video. 23 | random_mask(bool): If True, use random mask rather than FoV mask for KITTI360. 
24 |     """
25 |     def __init__(self, args: dict, debug=False, start=0, end=1, batch_size=1, same_mask=False):
26 |         self.args = args
27 |         self.num_local_frames = args['num_local_frames']
28 |         self.num_ref_frames = args['num_ref_frames']
29 |         self.size = self.w, self.h = (args['w'], args['h'])
30 | 
31 |         # Whether to use random masks for KITTI360
32 |         args['random_mask'] = args.get('random_mask', 0)
33 |         if args['random_mask'] != 0:
34 |             self.random_mask = True
35 |         else:
36 |             # default
37 |             self.random_mask = False
38 | 
39 |         if args['name'] != 'KITTI360-EX':
40 |             json_path = os.path.join(args['data_root'], args['name'], 'train.json')
41 |             self.dataset_name = args['name']
42 |         else:
43 |             json_path = os.path.join(args['data_root'], 'train.json')
44 |             self.dataset_name = 'KITTI360-EX'
45 | 
46 |         with open(json_path, 'r') as f:
47 |             self.video_dict = json.load(f)
48 |         self.video_names = list(self.video_dict.keys())
49 |         if args['name'] == 'KITTI360-EX':
50 |             # Shuffle the video order to prevent overfitting
51 |             from random import shuffle
52 |             shuffle(self.video_names)
53 |         if debug:
54 |             self.video_names = self.video_names[:100]
55 | 
56 |         self._to_tensors = transforms.Compose([
57 |             Stack(),
58 |             ToTorchFormatTensor(),
59 |         ])
60 | 
61 |         # self.neighbor_stride = 5  # sample a group of LF and NLF every 5 steps; this may be a sampling bug
62 |         self.neighbor_stride = self.num_local_frames  # sample a group of LF and NLF every window-size steps
63 |         self.start_index = 0  # sampling start point
64 |         self.start = start  # dataset iteration start
65 |         self.end = len(self.video_names)  # dataset iteration end
66 |         self.batch_size = batch_size  # used to compute the dataset iterator
67 | 
68 |         # Multi-GPU DDP training needs the local rank of the data
69 |         self.world_size = get_world_size()
70 | 
71 |         # With multi-GPU DDP each card runs an independent process, so the per-card batch size is the total divided by the number of cards
72 |         if self.world_size > 1:
73 |             self.batch_size = self.batch_size // self.world_size
74 | 
75 |         self.index = 0  # custom video index
76 |         self.batch_buffer = 0  # used to advance the video index along with the batch size
77 |         self.new_video_flag = False  # used to check whether a new video has started
78 |         self.worker_group = 0  # updated with each group of workers
79 | 
80 |         self.same_mask = same_mask  # if True, keep using the same mask until the video switches
81 |         if self.same_mask:
82 |             self.random_dict_list = []
83 |             self.new_mask_list = []
84 |             for i in range(0, self.batch_size):
85 |                 # Stores the random-mask parameter dicts; they differ across batch channels
86 |                 self.random_dict_list.append(None)
87 |                 # When set to True, the mask is randomly regenerated
88 |                 self.new_mask_list.append(False)
89 | 
90 |         # Create an independent video index, start_index and worker_group for each batch channel
91 |         self.video_index_list = []
92 |         for i in range(0, self.batch_size):
93 |             # Stagger the initial video ids across batch channels to avoid duplicated data and overfitting
94 |             if self.world_size == 1:
95 |                 # single-GPU logic
96 |                 self.video_index_list.append(i * len(self.video_names)//self.batch_size)
97 |             else:
98 |                 # multi-GPU logic
99 |                 # randomly initialize the video index (upper bound is exclusive of len - 1 to stay in range)
100 |                 self.video_index_list.append(random.randint(0, len(self.video_names) - 1))
101 | 
102 |         self.start_index_list = []
103 |         for i in range(0, self.batch_size):
104 |             self.start_index_list.append(0)
105 | 
106 |         self.worker_group_list = []
107 |         for i in range(0, self.batch_size):
108 |             self.worker_group_list.append(0)
109 | 
110 |     def __len__(self):
111 |         # Video switching is handled manually, so __len__ returns a large number to avoid clashing with the custom indices
112 |         return len(self.video_names)*1000
113 | 
114 |     # Consecutive iterations are one stride apart; different batch channels hold different videos; the video index and start-frame index are independent across the batch to avoid wasting data
115 |     def __getitem__(self, index):
116 |         worker_info = torch.utils.data.get_worker_info()
117 |         if worker_info is None:
118 |             # single-process data loading, skip idx rearrange
119 |             print('Warning: Only one data worker was used, this code path has not been tested!')
120 |             pass
121 |         else:
122 |             if self.start_index_list[self.batch_buffer] == 0 and self.worker_group_list[self.batch_buffer] == 0:
123 |                 # New video: initialize the start index
124 | 
125 |                 if self.same_mask:
126 |                     # In this mode, regenerate the mask when switching to a new video
127 |                     # Only worker 0 should generate the new mask; the other workers should reuse the same mask
128 |                     if worker_info.id == 0:
129 |                         self.random_dict_list[self.batch_buffer] = None
130 |                         self.new_mask_list[self.batch_buffer] = True
131 | 
132 |                 # Check whether the start index exceeds the video length
133 |                 if (self.neighbor_stride * worker_info.id) <= self.video_dict[self.video_names[self.video_index_list[self.batch_buffer]]]:
134 |                     self.start_index_list[self.batch_buffer] = self.neighbor_stride * worker_info.id
135 |                 else:
136 |                     # If it does, wait until the first worker also runs past the end, and set the start index to the start id of the nearest earlier worker that is still in range
137 |                     # Only move on to the next video after the first worker runs past the end, so that the workers' start ids stay aligned
138 |                     # Check which earlier workers are still in range and substitute the value of the nearest one
139 |                     out_list = []
140 |                     final_worker_idx = 0
141 |                     for previous_worker in range(0, worker_info.num_workers):
142 |                         out_list.append(
143 |                             ((self.neighbor_stride * previous_worker) <= self.video_dict[self.video_names[self.video_index_list[self.batch_buffer]]])
144 |                         )
145 |                     out_list.reverse()
146 |                     final_worker_idx = worker_info.num_workers - 1 - out_list.index(True)
147 |                     self.start_index_list[self.batch_buffer] = self.neighbor_stride * final_worker_idx
148 | 
149 |             else:
150 |                 # Not a new video
151 |                 if self.same_mask:
152 |                     # In this mode, an ongoing video keeps the previous mask parameters
153 |                     self.new_mask_list[self.batch_buffer] = False
154 | 
155 |                 # Check whether the start index exceeds the video length
156 |                 if (self.start_index_list[self.batch_buffer] + self.neighbor_stride * worker_info.num_workers) <= self.video_dict[self.video_names[self.video_index_list[self.batch_buffer]]]:
157 |                     self.start_index_list[self.batch_buffer] += self.neighbor_stride * worker_info.num_workers
158 |                 else:
159 |                     # If it does, switch to the next video and reset the start index to 0 (each worker still starts at a different offset)
160 |                     # If the first worker is still in range, copy the start id of an earlier worker and only move to the next video once the first worker runs past the end,
161 |                     # so that the workers' start ids stay aligned
162 |                     if (self.start_index_list[self.batch_buffer] + self.neighbor_stride * (worker_info.num_workers - worker_info.id)) <= self.video_dict[self.video_names[self.video_index_list[self.batch_buffer]]]:
163 |                         # Check which earlier workers are still in range and substitute the value of the nearest one
164 |                         out_list = []
165 |                         final_worker_idx = 0
166 |                         for previous_worker in range(0, worker_info.num_workers):
167 |                             out_list.append(
168 |                                 ((self.start_index_list[self.batch_buffer] + self.neighbor_stride * (worker_info.num_workers - worker_info.id + previous_worker)) <= self.video_dict[self.video_names[self.video_index_list[self.batch_buffer]]])
169 |                             )
170 |                         out_list.reverse()
171 |                         final_worker_idx = worker_info.num_workers - 1 - out_list.index(True)
172 |                         self.start_index_list[self.batch_buffer] = self.start_index_list[self.batch_buffer] + self.neighbor_stride * (worker_info.num_workers - worker_info.id + final_worker_idx)
173 |                     else:
174 |                         # If the first worker has run past the end, switch to the next video
175 |                         self.new_video_flag = True
176 | 
177 |                         if self.same_mask:
178 |                             # In this mode, regenerate the mask when switching to a new video
179 |                             # Only worker 0 should generate the new mask; the other workers should reuse the same mask
180 |                             if worker_info.id == 0:
181 |                                 self.random_dict_list[self.batch_buffer] = None
182 |                                 self.new_mask_list[self.batch_buffer] = True
183 | 
184 |                         # self.worker_group = 0
185 |                         self.worker_group_list[self.batch_buffer] = 0
186 |                         self.start_index_list[self.batch_buffer] = self.neighbor_stride * worker_info.id
187 |                         # Check whether the video index exceeds the number of videos
188 |                         if (self.video_index_list[self.batch_buffer] + 1) < len(self.video_names):
189 |                             self.video_index_list[self.batch_buffer] += 1
190 |                         else:
191 |                             # If it does, wrap around to the first video
192 |                             self.video_index_list[self.batch_buffer] = 0
193 | 
194 |         # Read frames according to the video index and start index
195 |         self.index = self.video_index_list[self.batch_buffer]
196 |         self.start_index = self.start_index_list[self.batch_buffer]
197 |         item = self.load_item_v4()
198 | 
199 |         # Update the worker group index
200 |         self.worker_group_list[self.batch_buffer] += 1
201 | 
202 |         self.batch_buffer += 1
203 |         if self.batch_buffer == self.batch_size:
204 |             # Reset the batch buffer
205 |             self.batch_buffer = 0
206 | 
207 |         return item
208 | 
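# --- illustrative note (added; not part of the original file) ---
# How the sharding above plays out: a worker first starts a new video at
# frame neighbor_stride * worker_id and then advances by
# neighbor_stride * num_workers frames per round. With num_workers=4 and
# neighbor_stride=5, the workers jointly sweep the video in contiguous
# 5-frame windows:
#   worker 0: 0, 20, 40, ...    worker 1: 5, 25, 45, ...
#   worker 2: 10, 30, 50, ...   worker 3: 15, 35, 55, ...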
209 |     def _sample_index_seq(self, length, sample_length, num_ref_frame=3, pivot=0, before_nlf=False):
210 |         """
211 | 
212 |         Args:
213 |             length: total number of frames in the video.
214 |             sample_length: number of local frames to sample.
215 |             num_ref_frame: number of non-local (reference) frames to sample.
216 |             pivot: start index of the local-frame window.
217 |             before_nlf: If True, the non-local frames will be sampled only from previous frames, not from the future.
218 | 
219 |         Returns:
220 | 
221 |         """
222 |         complete_idx_set = list(range(length))
223 |         local_idx = complete_idx_set[pivot:pivot + sample_length]
224 | 
225 |         # Make sure the last few iterations still return the same number of local frames (one stride, e.g. 5), so that batch stacking does not fail:
226 |         if len(local_idx) < self.neighbor_stride:
227 |             for i in range(0, self.neighbor_stride-len(local_idx)):
228 |                 if local_idx:
229 |                     local_idx.append(local_idx[-1])
230 |                 else:
231 |                     # The video length is an exact multiple of the local-frame stride; repeat the last frame as local_idx
232 |                     local_idx.append(complete_idx_set[-1])
233 | 
234 |         if before_nlf:
235 |             # Non-local frames are drawn only from past frames; no future information is used
236 |             complete_idx_set = complete_idx_set[:pivot + sample_length]
237 | 
238 |         remain_idx = list(set(complete_idx_set) - set(local_idx))
239 | 
240 |         # When only past frames serve as non-local frames, there may be fewer past frames than required, e.g. at the very start of a video
241 |         if before_nlf:
242 |             if len(remain_idx) < num_ref_frame:
243 |                 # In that case we allow sampling non-local frames from the local frames; converting to a set removes duplicates
244 |                 remain_idx = list(set(remain_idx + local_idx))
245 | 
246 |         ref_index = sorted(random.sample(remain_idx, num_ref_frame))
247 | 
248 |         return local_idx + ref_index
249 | 
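# --- illustrative example (added; not part of the original file) ---
# With length=20, sample_length=5, num_ref_frame=3, pivot=5 and
# before_nlf=True, the local frames are [5, 6, 7, 8, 9] and the reference
# frames are drawn only from the past indices [0..4], so a possible return
# value is:
#   [5, 6, 7, 8, 9, 0, 2, 4]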
250 |     def load_item_v4(self):
251 |         """Avoid conflicts between the dataloader index and the worker index."""
252 |         video_name = self.video_names[self.index]
253 | 
254 |         # create masks
255 |         if self.dataset_name != 'KITTI360-EX':
256 |             # For datasets other than KITTI360-EX, create random masks
257 |             if not self.same_mask:
258 |                 # Generate a random mask with a new shape at every iteration
259 |                 all_masks = create_random_shape_with_random_motion(
260 |                     self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w)
261 | 
262 |             else:
263 |                 # Reuse the same mask parameters until the video switches
264 |                 all_masks, random_dict = create_random_shape_with_random_motion_seq(
265 |                     self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w,
266 |                     new_mask=self.new_mask_list[self.batch_buffer],
267 |                     random_dict=self.random_dict_list[self.batch_buffer])
268 |                 # Update the random-mask parameters
269 |                 self.random_dict_list[self.batch_buffer] = random_dict
270 | 
271 |         elif self.random_mask:
272 |             # If random_mask is enabled, create random masks for KITTI360 as well
273 |             all_masks = create_random_shape_with_random_motion(
274 |                 self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w)
275 | 
276 |         # create sample index
277 |         # For FoV-expansion settings such as KITTI360-EX, non-local frames may only come from past information
278 |         if self.dataset_name == 'KITTI360-EX':
279 |             before_nlf = True
280 |         else:
281 |             # By default, video completion may use future information
282 |             before_nlf = False
283 |         selected_index = self._sample_index_seq(self.video_dict[video_name],
284 |                                                 self.num_local_frames,
285 |                                                 self.num_ref_frames,
286 |                                                 pivot=self.start_index,
287 |                                                 before_nlf=before_nlf)
288 | 
289 |         # read video frames
290 |         frames = []
291 |         masks = []
292 |         for idx in selected_index:
293 |             if self.dataset_name != 'KITTI360-EX':
294 |                 video_path = os.path.join(self.args['data_root'],
295 |                                           self.args['name'], 'JPEGImages',
296 |                                           f'{video_name}.zip')
297 |             else:
298 |                 video_path = os.path.join(self.args['data_root'],
299 |                                           'JPEGImages',
300 |                                           f'{video_name}.zip')
301 |             img = TrainZipReader.imread(video_path, idx).convert('RGB')
302 |             img = img.resize(self.size)
303 |             frames.append(img)
304 |             if self.dataset_name != 'KITTI360-EX':
305 |                 masks.append(all_masks[idx])
306 |             elif self.random_mask:
307 |                 # Random masks can also be used for the KITTI360-EX dataset
308 |                 masks.append(all_masks[idx])
309 |             else:
310 |                 # For the KITTI360-EX dataset, read the masks stored in the zip archive
311 |                 mask_path = os.path.join(self.args['data_root'],
312 |                                          'test_masks',
313 |                                          f'{video_name}.zip')
314 |                 mask = TrainZipReader.imread(mask_path, idx)
315 |                 mask = mask.resize(self.size).convert('L')
316 |                 mask = np.asarray(mask)
317 |                 m = np.array(mask > 0).astype(np.uint8)
318 |                 mask = Image.fromarray(m * 255)
319 |                 masks.append(mask)
320 | 
321 |         # normalize, to tensors
322 |         frames = GroupRandomHorizontalFlip()(frames)
323 |         if self.dataset_name == 'KITTI360-EX':
324 |             # Masks read from disk must be flipped together with the frames
325 |             if not self.random_mask:
326 |                 masks = GroupRandomHorizontalFlip()(masks)
327 |         frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
328 |         mask_tensors = self._to_tensors(masks)
329 | 
330 |         if not self.same_mask:
331 |             # A new random mask is generated every time; no need to return the parameter dict
332 |             return frame_tensors, mask_tensors, video_name, self.index, self.start_index
333 |         else:
334 |             # To keep the mask behavior consistent, the parameter dict must be returned
335 |             return frame_tensors, mask_tensors, video_name, self.index, self.start_index,\
336 |                 self.new_mask_list[self.batch_buffer], self.random_dict_list
337 | 
338 | 
339 | class TestDataset(torch.utils.data.Dataset):
340 |     def __init__(self, args):
341 |         self.args = args
342 |         self.size = self.w, self.h = args.size
343 |         try:
344 |             # Whether to use object masks
345 |             self.object_mask = args.object
346 |         except AttributeError:
347 |             self.object_mask = False
348 | 
349 |         if args.dataset != 'KITTI360-EX':
350 |             # default manner
351 |             json_path = os.path.join(args.data_root, args.dataset, 'test.json')
352 |         else:
353 |             # for KITTI360-EX
354 |             json_path = os.path.join(args.data_root, 'test.json')
355 | 
356 |         with open(json_path, 'r') as f:
357 |             self.video_dict = json.load(f)
358 |         self.video_names = list(self.video_dict.keys())
359 | 
360 |         self._to_tensors = transforms.Compose([
361 |             Stack(),
362 |             ToTorchFormatTensor(),
363 |         ])
364 | 
365 |     def __len__(self):
366 |         return len(self.video_names)
367 | 
368 |     def __getitem__(self, index):
369 |         item = self.load_item(index)
370 |         return item
371 | 
372 |     def load_item(self, index):
373 |         video_name = self.video_names[index]
374 |         ref_index = list(range(self.video_dict[video_name]))
375 | 
376 |         # read video frames
377 |         frames = []
378 |         masks = []
379 |         for idx in ref_index:
380 | 
381 |             # read img from zip
382 |             if self.args.dataset != 'KITTI360-EX':
383 |                 # default manner
384 |                 video_path = os.path.join(self.args.data_root, self.args.dataset, 'JPEGImages', f'{video_name}.zip')
385 |             else:
386 |                 # for KITTI360-EX
387 |                 video_path = os.path.join(self.args.data_root, 'JPEGImages', f'{video_name}.zip')
388 | 
389 |             img = TestZipReader.imread(video_path, idx).convert('RGB')
390 |             img = img.resize(self.size)
391 |             frames.append(img)
392 | 
393 |             # read mask from folder
394 |             if self.args.dataset != 'KITTI360-EX':
395 |                 if not self.object_mask:
396 |                     # default manner for video completion
397 |                     mask_path = os.path.join(self.args.data_root, self.args.dataset,
398 |                                              'test_masks', video_name,
399 |                                              str(idx).zfill(5) + '.png')
400 |                 else:
401 |                     # default manner for object removal
402 |                     mask_path = os.path.join(self.args.data_root, self.args.dataset,
403 |                                              'test_masks_object', video_name,
404 |                                              str(idx).zfill(5) + '.png')
405 |             else:
406 |                 # for KITTI360-EX: use seq 10 for testing
407 |                 mask_path = os.path.join(self.args.data_root, 'test_masks', 'seq10', self.args.fov,
video_name, 408 | str(idx).zfill(6) + '.png') 409 | 410 | mask = Image.open(mask_path).resize(self.size, 411 | Image.NEAREST).convert('L') 412 | # origin: 0 indicates missing. now: 1 indicates missing 413 | mask = np.asarray(mask) 414 | m = np.array(mask > 0).astype(np.uint8) 415 | m = cv2.dilate(m, 416 | cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)), 417 | iterations=4) 418 | mask = Image.fromarray(m * 255) 419 | masks.append(mask) 420 | 421 | # to tensors 422 | frames_PIL = [np.array(f).astype(np.uint8) for f in frames] 423 | frame_tensors = self._to_tensors(frames) * 2.0 - 1.0 424 | mask_tensors = self._to_tensors(masks) 425 | return frame_tensors, mask_tensors, video_name, frames_PIL 426 | -------------------------------------------------------------------------------- /core/dist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def get_world_size(): 6 | """Find OMPI world size without calling mpi functions 7 | :rtype: int 8 | """ 9 | if os.environ.get('PMI_SIZE') is not None: 10 | return int(os.environ.get('PMI_SIZE') or 1) 11 | elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: 12 | return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) 13 | else: 14 | return torch.cuda.device_count() 15 | 16 | 17 | def get_global_rank(): 18 | """Find OMPI world rank without calling mpi functions 19 | :rtype: int 20 | """ 21 | if os.environ.get('PMI_RANK') is not None: 22 | return int(os.environ.get('PMI_RANK') or 0) 23 | elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: 24 | return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) 25 | else: 26 | return 0 27 | 28 | 29 | def get_local_rank(): 30 | """Find OMPI local rank without calling mpi functions 31 | :rtype: int 32 | """ 33 | if os.environ.get('MPI_LOCALRANKID') is not None: 34 | return int(os.environ.get('MPI_LOCALRANKID') or 0) 35 | elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: 36 | return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) 37 | else: 38 | return 0 39 | 40 | 41 | def get_master_ip(): 42 | if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: 43 | return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] 44 | elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: 45 | return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') 46 | else: 47 | return "127.0.0.1" 48 | -------------------------------------------------------------------------------- /core/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AdversarialLoss(nn.Module): 6 | r""" 7 | Adversarial loss 8 | https://arxiv.org/abs/1711.10337 9 | """ 10 | def __init__(self, 11 | type='nsgan', 12 | target_real_label=1.0, 13 | target_fake_label=0.0): 14 | r""" 15 | type = nsgan | lsgan | hinge 16 | """ 17 | super(AdversarialLoss, self).__init__() 18 | self.type = type 19 | self.register_buffer('real_label', torch.tensor(target_real_label)) 20 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 21 | 22 | if type == 'nsgan': 23 | self.criterion = nn.BCELoss() 24 | elif type == 'lsgan': 25 | self.criterion = nn.MSELoss() 26 | elif type == 'hinge': 27 | self.criterion = nn.ReLU() 28 | 29 | def __call__(self, outputs, is_real, is_disc=None): 30 | if self.type == 'hinge': 31 | if is_disc: 32 | if is_real: 33 | outputs = -outputs 34 | return self.criterion(1 + outputs).mean() 35 | else: 36 | return (-outputs).mean() 37 | else: 38 | labels = 
(self.real_label
39 |                       if is_real else self.fake_label).expand_as(outputs)
40 |             loss = self.criterion(outputs, labels)
41 |             return loss
42 | 
--------------------------------------------------------------------------------
/core/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | """
2 |     LR scheduler from BasicSR https://github.com/xinntao/BasicSR
3 | """
4 | import math
5 | from collections import Counter
6 | from torch.optim.lr_scheduler import _LRScheduler
7 | 
8 | 
9 | class MultiStepRestartLR(_LRScheduler):
10 |     """ MultiStep with restarts learning rate scheme.
11 |     Args:
12 |         optimizer (torch.nn.optimizer): Torch optimizer.
13 |         milestones (list): Iterations that will decrease learning rate.
14 |         gamma (float): Decrease ratio. Default: 0.1.
15 |         restarts (list): Restart iterations. Default: [0].
16 |         restart_weights (list): Restart weights at each restart iteration.
17 |             Default: [1].
18 |         last_epoch (int): Used in _LRScheduler. Default: -1.
19 |     """
20 |     def __init__(self,
21 |                  optimizer,
22 |                  milestones,
23 |                  gamma=0.1,
24 |                  restarts=(0, ),
25 |                  restart_weights=(1, ),
26 |                  last_epoch=-1):
27 |         self.milestones = Counter(milestones)
28 |         self.gamma = gamma
29 |         self.restarts = restarts
30 |         self.restart_weights = restart_weights
31 |         assert len(self.restarts) == len(
32 |             self.restart_weights), 'restarts and their weights do not match.'
33 |         super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
34 | 
35 |     def get_lr(self):
36 |         if self.last_epoch in self.restarts:
37 |             weight = self.restart_weights[self.restarts.index(self.last_epoch)]
38 |             return [
39 |                 group['initial_lr'] * weight
40 |                 for group in self.optimizer.param_groups
41 |             ]
42 |         if self.last_epoch not in self.milestones:
43 |             return [group['lr'] for group in self.optimizer.param_groups]
44 |         return [
45 |             group['lr'] * self.gamma**self.milestones[self.last_epoch]
46 |             for group in self.optimizer.param_groups
47 |         ]
48 | 
49 | 
50 | def get_position_from_periods(iteration, cumulative_period):
51 |     """Get the position from a period list.
52 |     It will return the index of the right-closest number in the period list.
53 |     For example, the cumulative_period = [100, 200, 300, 400],
54 |     if iteration == 50, return 0;
55 |     if iteration == 210, return 2;
56 |     if iteration == 300, return 2.
57 |     Args:
58 |         iteration (int): Current iteration.
59 |         cumulative_period (list[int]): Cumulative period list.
60 |     Returns:
61 |         int: The position of the right-closest number in the period list.
62 |     """
63 |     for i, period in enumerate(cumulative_period):
64 |         if iteration <= period:
65 |             return i
66 | 
67 | 
68 | class CosineAnnealingRestartLR(_LRScheduler):
69 |     """ Cosine annealing with restarts learning rate scheme.
70 |     An example of config:
71 |     periods = [10, 10, 10, 10]
72 |     restart_weights = [1, 0.5, 0.5, 0.5]
73 |     eta_min=1e-7
74 |     It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
75 |     scheduler will restart with the weights in restart_weights.
76 |     Args:
77 |         optimizer (torch.nn.optimizer): Torch optimizer.
78 |         periods (list): Period for each cosine annealing cycle.
79 |         restart_weights (list): Restart weights at each restart iteration.
80 |             Default: [1].
81 |         eta_min (float): The minimum lr. Default: 0.
82 |         last_epoch (int): Used in _LRScheduler. Default: -1.
83 |     """
84 |     def __init__(self,
85 |                  optimizer,
86 |                  periods,
87 |                  restart_weights=(1, ),
88 |                  eta_min=1e-7,
89 |                  last_epoch=-1):
90 |         self.periods = periods
91 |         self.restart_weights = restart_weights
92 |         self.eta_min = eta_min
93 |         assert (len(self.periods) == len(self.restart_weights)
94 |                 ), 'periods and restart_weights should have the same length.'
95 |         self.cumulative_period = [
96 |             sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
97 |         ]
98 |         super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
99 | 
100 |     def get_lr(self):
101 |         idx = get_position_from_periods(self.last_epoch,
102 |                                         self.cumulative_period)
103 |         current_weight = self.restart_weights[idx]
104 |         nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
105 |         current_period = self.periods[idx]
106 | 
107 |         return [
108 |             self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
109 |             (1 + math.cos(math.pi * (
110 |                 (self.last_epoch - nearest_restart) / current_period)))
111 |             for base_lr in self.base_lrs
112 |         ]
113 | 
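# --- usage sketch (added; not part of the original file) ---
# CosineAnnealingRestartLR with two 10-iteration cycles, where the second
# cycle restarts at half the weight. The optimizer and values below are
# illustrative only, not taken from the repo's configs.
#
#   import torch
#   from core.lr_scheduler import CosineAnnealingRestartLR
#
#   params = [torch.zeros(1, requires_grad=True)]
#   optimizer = torch.optim.Adam(params, lr=1e-4)
#   scheduler = CosineAnnealingRestartLR(
#       optimizer, periods=[10, 10], restart_weights=[1, 0.5], eta_min=1e-7)
#   for _ in range(20):
#       optimizer.step()
#       scheduler.step()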
--------------------------------------------------------------------------------
/core/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import skimage
3 | import skimage.metrics
4 | from skimage import measure
5 | from scipy import linalg
6 | 
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | 
11 | from core.utils import to_tensors
12 | 
13 | 
14 | def calculate_epe(flow1, flow2):
15 |     """Calculate End point errors."""
16 | 
17 |     epe = torch.sum((flow1 - flow2)**2, dim=1).sqrt()
18 |     epe = epe.view(-1)
19 |     return epe.mean().item()
20 | 
21 | 
22 | def calculate_psnr(img1, img2):
23 |     """Calculate PSNR (Peak Signal-to-Noise Ratio).
24 |     Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
25 |     Args:
26 |         img1 (ndarray): Images with range [0, 255].
27 |         img2 (ndarray): Images with range [0, 255].
28 |     Returns:
29 |         float: psnr result.
30 |     """
31 | 
32 |     assert img1.shape == img2.shape, \
33 |         (f'Image shapes are different: {img1.shape}, {img2.shape}.')
34 | 
35 |     mse = np.mean((img1 - img2)**2)
36 |     if mse == 0:
37 |         return float('inf')
38 |     return 20. * np.log10(255. / np.sqrt(mse))
39 | 
40 | 
41 | def calc_psnr_and_ssim(img1, img2):
42 |     """Calculate PSNR and SSIM for images.
43 |         img1: ndarray, range [0, 255]
44 |         img2: ndarray, range [0, 255]
45 |     """
46 |     img1 = img1.astype(np.float64)
47 |     img2 = img2.astype(np.float64)
48 | 
49 |     psnr = calculate_psnr(img1, img2)
50 |     if skimage.__version__ != '0.19.3':
51 |         # old version skimage
52 |         ssim = measure.compare_ssim(img1,
53 |                                     img2,
54 |                                     data_range=255,
55 |                                     multichannel=True,
56 |                                     win_size=65)
57 |     else:
58 |         # new version skimage
59 |         ssim = skimage.metrics.structural_similarity(img1,
60 |                                                      img2,
61 |                                                      data_range=255,
62 |                                                      multichannel=True,
63 |                                                      win_size=65)
64 | 
65 |     return psnr, ssim
66 | 
67 | 
68 | ###########################
69 | # I3D models
70 | ###########################
71 | 
72 | 
73 | def init_i3d_model():
74 |     i3d_model_path = './release_model/i3d_rgb_imagenet.pt'
75 |     print(f"[Loading I3D model from {i3d_model_path} for FID score ..]")
76 |     i3d_model = InceptionI3d(400, in_channels=3, final_endpoint='Logits')
77 |     i3d_model.load_state_dict(torch.load(i3d_model_path))
78 |     i3d_model.to(torch.device('cuda:0'))
79 |     return i3d_model
80 | 
81 | 
82 | def calculate_i3d_activations(video1, video2, i3d_model, device):
83 |     """Calculate VFID metric.
84 |     video1: list[PIL.Image]
85 |     video2: list[PIL.Image]
86 |     """
87 |     video1 = to_tensors()(video1).unsqueeze(0).to(device)
88 |     video2 = to_tensors()(video2).unsqueeze(0).to(device)
89 |     video1_activations = get_i3d_activations(
90 |         video1, i3d_model).cpu().numpy().flatten()
91 |     video2_activations = get_i3d_activations(
92 |         video2, i3d_model).cpu().numpy().flatten()
93 | 
94 |     return video1_activations, video2_activations
95 | 
96 | 
97 | def calculate_vfid(real_activations, fake_activations):
98 |     """
99 |     Given two distributions of features, compute the FID score between them
100 |     Params:
101 |         real_activations: list[ndarray]
102 |         fake_activations: list[ndarray]
103 |     """
104 |     m1 = np.mean(real_activations, axis=0)
105 |     m2 = np.mean(fake_activations, axis=0)
106 |     s1 = np.cov(real_activations, rowvar=False)
107 |     s2 = np.cov(fake_activations, rowvar=False)
108 |     return calculate_frechet_distance(m1, s1, m2, s2)
109 | 
110 | 
111 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
112 |     """Numpy implementation of the Frechet Distance.
113 |     The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
114 |     and X_2 ~ N(mu_2, C_2) is
115 |         d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
116 |     Stable version by Dougal J. Sutherland.
117 |     Params:
118 |     -- mu1   : Numpy array containing the activations of a layer of the
119 |                inception net (like returned by the function 'get_predictions')
120 |                for generated samples.
121 |     -- mu2   : The sample mean over activations, precalculated on a
122 |                representative data set.
123 |     -- sigma1: The covariance matrix over activations for generated samples.
124 |     -- sigma2: The covariance matrix over activations, precalculated on a
125 |                representative data set.
126 |     Returns:
127 |     --   : The Frechet Distance.
128 | """ 129 | 130 | mu1 = np.atleast_1d(mu1) 131 | mu2 = np.atleast_1d(mu2) 132 | 133 | sigma1 = np.atleast_2d(sigma1) 134 | sigma2 = np.atleast_2d(sigma2) 135 | 136 | assert mu1.shape == mu2.shape, \ 137 | 'Training and test mean vectors have different lengths' 138 | assert sigma1.shape == sigma2.shape, \ 139 | 'Training and test covariances have different dimensions' 140 | 141 | diff = mu1 - mu2 142 | 143 | # Product might be almost singular 144 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 145 | if not np.isfinite(covmean).all(): 146 | msg = ('fid calculation produces singular product; ' 147 | 'adding %s to diagonal of cov estimates') % eps 148 | print(msg) 149 | offset = np.eye(sigma1.shape[0]) * eps 150 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 151 | 152 | # Numerical error might give slight imaginary component 153 | if np.iscomplexobj(covmean): 154 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 155 | m = np.max(np.abs(covmean.imag)) 156 | raise ValueError('Imaginary component {}'.format(m)) 157 | covmean = covmean.real 158 | 159 | tr_covmean = np.trace(covmean) 160 | 161 | return (diff.dot(diff) + np.trace(sigma1) + # NOQA 162 | np.trace(sigma2) - 2 * tr_covmean) 163 | 164 | 165 | def get_i3d_activations(batched_video, 166 | i3d_model, 167 | target_endpoint='Logits', 168 | flatten=True, 169 | grad_enabled=False): 170 | """ 171 | Get features from i3d model and flatten them to 1d feature, 172 | valid target endpoints are defined in InceptionI3d.VALID_ENDPOINTS 173 | VALID_ENDPOINTS = ( 174 | 'Conv3d_1a_7x7', 175 | 'MaxPool3d_2a_3x3', 176 | 'Conv3d_2b_1x1', 177 | 'Conv3d_2c_3x3', 178 | 'MaxPool3d_3a_3x3', 179 | 'Mixed_3b', 180 | 'Mixed_3c', 181 | 'MaxPool3d_4a_3x3', 182 | 'Mixed_4b', 183 | 'Mixed_4c', 184 | 'Mixed_4d', 185 | 'Mixed_4e', 186 | 'Mixed_4f', 187 | 'MaxPool3d_5a_2x2', 188 | 'Mixed_5b', 189 | 'Mixed_5c', 190 | 'Logits', 191 | 'Predictions', 192 | ) 193 | """ 194 | with torch.set_grad_enabled(grad_enabled): 195 | feat = i3d_model.extract_features(batched_video.transpose(1, 2), 196 | target_endpoint) 197 | if flatten: 198 | feat = feat.view(feat.size(0), -1) 199 | 200 | return feat 201 | 202 | 203 | # This code is from https://github.com/piergiaj/pytorch-i3d/blob/master/pytorch_i3d.py 204 | # I only fix flake8 errors and do some cleaning here 205 | 206 | 207 | class MaxPool3dSamePadding(nn.MaxPool3d): 208 | def compute_pad(self, dim, s): 209 | if s % self.stride[dim] == 0: 210 | return max(self.kernel_size[dim] - self.stride[dim], 0) 211 | else: 212 | return max(self.kernel_size[dim] - (s % self.stride[dim]), 0) 213 | 214 | def forward(self, x): 215 | # compute 'same' padding 216 | (batch, channel, t, h, w) = x.size() 217 | pad_t = self.compute_pad(0, t) 218 | pad_h = self.compute_pad(1, h) 219 | pad_w = self.compute_pad(2, w) 220 | 221 | pad_t_f = pad_t // 2 222 | pad_t_b = pad_t - pad_t_f 223 | pad_h_f = pad_h // 2 224 | pad_h_b = pad_h - pad_h_f 225 | pad_w_f = pad_w // 2 226 | pad_w_b = pad_w - pad_w_f 227 | 228 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 229 | x = F.pad(x, pad) 230 | return super(MaxPool3dSamePadding, self).forward(x) 231 | 232 | 233 | class Unit3D(nn.Module): 234 | def __init__(self, 235 | in_channels, 236 | output_channels, 237 | kernel_shape=(1, 1, 1), 238 | stride=(1, 1, 1), 239 | padding=0, 240 | activation_fn=F.relu, 241 | use_batch_norm=True, 242 | use_bias=False, 243 | name='unit_3d'): 244 | """Initializes Unit3D module.""" 245 | super(Unit3D, self).__init__() 246 | 247 | 
self._output_channels = output_channels 248 | self._kernel_shape = kernel_shape 249 | self._stride = stride 250 | self._use_batch_norm = use_batch_norm 251 | self._activation_fn = activation_fn 252 | self._use_bias = use_bias 253 | self.name = name 254 | self.padding = padding 255 | 256 | self.conv3d = nn.Conv3d( 257 | in_channels=in_channels, 258 | out_channels=self._output_channels, 259 | kernel_size=self._kernel_shape, 260 | stride=self._stride, 261 | padding=0, # we always want padding to be 0 here. We will 262 | # dynamically pad based on input size in forward function 263 | bias=self._use_bias) 264 | 265 | if self._use_batch_norm: 266 | self.bn = nn.BatchNorm3d(self._output_channels, 267 | eps=0.001, 268 | momentum=0.01) 269 | 270 | def compute_pad(self, dim, s): 271 | if s % self._stride[dim] == 0: 272 | return max(self._kernel_shape[dim] - self._stride[dim], 0) 273 | else: 274 | return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0) 275 | 276 | def forward(self, x): 277 | # compute 'same' padding 278 | (batch, channel, t, h, w) = x.size() 279 | pad_t = self.compute_pad(0, t) 280 | pad_h = self.compute_pad(1, h) 281 | pad_w = self.compute_pad(2, w) 282 | 283 | pad_t_f = pad_t // 2 284 | pad_t_b = pad_t - pad_t_f 285 | pad_h_f = pad_h // 2 286 | pad_h_b = pad_h - pad_h_f 287 | pad_w_f = pad_w // 2 288 | pad_w_b = pad_w - pad_w_f 289 | 290 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 291 | x = F.pad(x, pad) 292 | 293 | x = self.conv3d(x) 294 | if self._use_batch_norm: 295 | x = self.bn(x) 296 | if self._activation_fn is not None: 297 | x = self._activation_fn(x) 298 | return x 299 | 300 | 301 | class InceptionModule(nn.Module): 302 | def __init__(self, in_channels, out_channels, name): 303 | super(InceptionModule, self).__init__() 304 | 305 | self.b0 = Unit3D(in_channels=in_channels, 306 | output_channels=out_channels[0], 307 | kernel_shape=[1, 1, 1], 308 | padding=0, 309 | name=name + '/Branch_0/Conv3d_0a_1x1') 310 | self.b1a = Unit3D(in_channels=in_channels, 311 | output_channels=out_channels[1], 312 | kernel_shape=[1, 1, 1], 313 | padding=0, 314 | name=name + '/Branch_1/Conv3d_0a_1x1') 315 | self.b1b = Unit3D(in_channels=out_channels[1], 316 | output_channels=out_channels[2], 317 | kernel_shape=[3, 3, 3], 318 | name=name + '/Branch_1/Conv3d_0b_3x3') 319 | self.b2a = Unit3D(in_channels=in_channels, 320 | output_channels=out_channels[3], 321 | kernel_shape=[1, 1, 1], 322 | padding=0, 323 | name=name + '/Branch_2/Conv3d_0a_1x1') 324 | self.b2b = Unit3D(in_channels=out_channels[3], 325 | output_channels=out_channels[4], 326 | kernel_shape=[3, 3, 3], 327 | name=name + '/Branch_2/Conv3d_0b_3x3') 328 | self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], 329 | stride=(1, 1, 1), 330 | padding=0) 331 | self.b3b = Unit3D(in_channels=in_channels, 332 | output_channels=out_channels[5], 333 | kernel_shape=[1, 1, 1], 334 | padding=0, 335 | name=name + '/Branch_3/Conv3d_0b_1x1') 336 | self.name = name 337 | 338 | def forward(self, x): 339 | b0 = self.b0(x) 340 | b1 = self.b1b(self.b1a(x)) 341 | b2 = self.b2b(self.b2a(x)) 342 | b3 = self.b3b(self.b3a(x)) 343 | return torch.cat([b0, b1, b2, b3], dim=1) 344 | 345 | 346 | class InceptionI3d(nn.Module): 347 | """Inception-v1 I3D architecture. 348 | The model is introduced in: 349 | Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset 350 | Joao Carreira, Andrew Zisserman 351 | https://arxiv.org/pdf/1705.07750v1.pdf. 
352 | See also the Inception architecture, introduced in: 353 | Going deeper with convolutions 354 | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, 355 | Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. 356 | http://arxiv.org/pdf/1409.4842v1.pdf. 357 | """ 358 | 359 | # Endpoints of the model in order. During construction, all the endpoints up 360 | # to a designated `final_endpoint` are returned in a dictionary as the 361 | # second return value. 362 | VALID_ENDPOINTS = ( 363 | 'Conv3d_1a_7x7', 364 | 'MaxPool3d_2a_3x3', 365 | 'Conv3d_2b_1x1', 366 | 'Conv3d_2c_3x3', 367 | 'MaxPool3d_3a_3x3', 368 | 'Mixed_3b', 369 | 'Mixed_3c', 370 | 'MaxPool3d_4a_3x3', 371 | 'Mixed_4b', 372 | 'Mixed_4c', 373 | 'Mixed_4d', 374 | 'Mixed_4e', 375 | 'Mixed_4f', 376 | 'MaxPool3d_5a_2x2', 377 | 'Mixed_5b', 378 | 'Mixed_5c', 379 | 'Logits', 380 | 'Predictions', 381 | ) 382 | 383 | def __init__(self, 384 | num_classes=400, 385 | spatial_squeeze=True, 386 | final_endpoint='Logits', 387 | name='inception_i3d', 388 | in_channels=3, 389 | dropout_keep_prob=0.5): 390 | """Initializes I3D model instance. 391 | Args: 392 | num_classes: The number of outputs in the logit layer (default 400, which 393 | matches the Kinetics dataset). 394 | spatial_squeeze: Whether to squeeze the spatial dimensions for the logits 395 | before returning (default True). 396 | final_endpoint: The model contains many possible endpoints. 397 | `final_endpoint` specifies the last endpoint for the model to be built 398 | up to. In addition to the output at `final_endpoint`, all the outputs 399 | at endpoints up to `final_endpoint` will also be returned, in a 400 | dictionary. `final_endpoint` must be one of 401 | InceptionI3d.VALID_ENDPOINTS (default 'Logits'). 402 | name: A string (optional). The name of this module. 403 | Raises: 404 | ValueError: if `final_endpoint` is not recognized. 
405 | """ 406 | 407 | if final_endpoint not in self.VALID_ENDPOINTS: 408 | raise ValueError('Unknown final endpoint %s' % final_endpoint) 409 | 410 | super(InceptionI3d, self).__init__() 411 | self._num_classes = num_classes 412 | self._spatial_squeeze = spatial_squeeze 413 | self._final_endpoint = final_endpoint 414 | self.logits = None 415 | 416 | if self._final_endpoint not in self.VALID_ENDPOINTS: 417 | raise ValueError('Unknown final endpoint %s' % 418 | self._final_endpoint) 419 | 420 | self.end_points = {} 421 | end_point = 'Conv3d_1a_7x7' 422 | self.end_points[end_point] = Unit3D(in_channels=in_channels, 423 | output_channels=64, 424 | kernel_shape=[7, 7, 7], 425 | stride=(2, 2, 2), 426 | padding=(3, 3, 3), 427 | name=name + end_point) 428 | if self._final_endpoint == end_point: 429 | return 430 | 431 | end_point = 'MaxPool3d_2a_3x3' 432 | self.end_points[end_point] = MaxPool3dSamePadding( 433 | kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0) 434 | if self._final_endpoint == end_point: 435 | return 436 | 437 | end_point = 'Conv3d_2b_1x1' 438 | self.end_points[end_point] = Unit3D(in_channels=64, 439 | output_channels=64, 440 | kernel_shape=[1, 1, 1], 441 | padding=0, 442 | name=name + end_point) 443 | if self._final_endpoint == end_point: 444 | return 445 | 446 | end_point = 'Conv3d_2c_3x3' 447 | self.end_points[end_point] = Unit3D(in_channels=64, 448 | output_channels=192, 449 | kernel_shape=[3, 3, 3], 450 | padding=1, 451 | name=name + end_point) 452 | if self._final_endpoint == end_point: 453 | return 454 | 455 | end_point = 'MaxPool3d_3a_3x3' 456 | self.end_points[end_point] = MaxPool3dSamePadding( 457 | kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0) 458 | if self._final_endpoint == end_point: 459 | return 460 | 461 | end_point = 'Mixed_3b' 462 | self.end_points[end_point] = InceptionModule(192, 463 | [64, 96, 128, 16, 32, 32], 464 | name + end_point) 465 | if self._final_endpoint == end_point: 466 | return 467 | 468 | end_point = 'Mixed_3c' 469 | self.end_points[end_point] = InceptionModule( 470 | 256, [128, 128, 192, 32, 96, 64], name + end_point) 471 | if self._final_endpoint == end_point: 472 | return 473 | 474 | end_point = 'MaxPool3d_4a_3x3' 475 | self.end_points[end_point] = MaxPool3dSamePadding( 476 | kernel_size=[3, 3, 3], stride=(2, 2, 2), padding=0) 477 | if self._final_endpoint == end_point: 478 | return 479 | 480 | end_point = 'Mixed_4b' 481 | self.end_points[end_point] = InceptionModule( 482 | 128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point) 483 | if self._final_endpoint == end_point: 484 | return 485 | 486 | end_point = 'Mixed_4c' 487 | self.end_points[end_point] = InceptionModule( 488 | 192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point) 489 | if self._final_endpoint == end_point: 490 | return 491 | 492 | end_point = 'Mixed_4d' 493 | self.end_points[end_point] = InceptionModule( 494 | 160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point) 495 | if self._final_endpoint == end_point: 496 | return 497 | 498 | end_point = 'Mixed_4e' 499 | self.end_points[end_point] = InceptionModule( 500 | 128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point) 501 | if self._final_endpoint == end_point: 502 | return 503 | 504 | end_point = 'Mixed_4f' 505 | self.end_points[end_point] = InceptionModule( 506 | 112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128], 507 | name + end_point) 508 | if self._final_endpoint == end_point: 509 | return 510 | 511 | end_point = 'MaxPool3d_5a_2x2' 512 | self.end_points[end_point] = 
MaxPool3dSamePadding(
513 |             kernel_size=[2, 2, 2], stride=(2, 2, 2), padding=0)
514 |         if self._final_endpoint == end_point:
515 |             return
516 | 
517 |         end_point = 'Mixed_5b'
518 |         self.end_points[end_point] = InceptionModule(
519 |             256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128],
520 |             name + end_point)
521 |         if self._final_endpoint == end_point:
522 |             return
523 | 
524 |         end_point = 'Mixed_5c'
525 |         self.end_points[end_point] = InceptionModule(
526 |             256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128],
527 |             name + end_point)
528 |         if self._final_endpoint == end_point:
529 |             return
530 | 
531 |         end_point = 'Logits'
532 |         self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], stride=(1, 1, 1))
533 |         self.dropout = nn.Dropout(dropout_keep_prob)
534 |         self.logits = Unit3D(in_channels=384 + 384 + 128 + 128,
535 |                              output_channels=self._num_classes,
536 |                              kernel_shape=[1, 1, 1],
537 |                              padding=0,
538 |                              activation_fn=None,
539 |                              use_batch_norm=False,
540 |                              use_bias=True,
541 |                              name='logits')
542 | 
543 |         self.build()
544 | 
545 |     def replace_logits(self, num_classes):
546 |         self._num_classes = num_classes
547 |         self.logits = Unit3D(in_channels=384 + 384 + 128 + 128,
548 |                              output_channels=self._num_classes,
549 |                              kernel_shape=[1, 1, 1],
550 |                              padding=0,
551 |                              activation_fn=None,
552 |                              use_batch_norm=False,
553 |                              use_bias=True,
554 |                              name='logits')
555 | 
556 |     def build(self):
557 |         for k in self.end_points.keys():
558 |             self.add_module(k, self.end_points[k])
559 | 
560 |     def forward(self, x):
561 |         for end_point in self.VALID_ENDPOINTS:
562 |             if end_point in self.end_points:
563 |                 x = self._modules[end_point](
564 |                     x)  # use _modules to work with dataparallel
565 | 
566 |         x = self.logits(self.dropout(self.avg_pool(x)))
567 |         if self._spatial_squeeze:
568 |             logits = x.squeeze(3).squeeze(3)
        else:
            # keep the spatial dimensions when squeezing is disabled
            logits = x
569 |         # logits is batch X time X classes, which is what we want to work with
570 |         return logits
571 | 
572 |     def extract_features(self, x, target_endpoint='Logits'):
573 |         for end_point in self.VALID_ENDPOINTS:
574 |             if end_point in self.end_points:
575 |                 x = self._modules[end_point](x)
576 |                 if end_point == target_endpoint:
577 |                     break
578 |         if target_endpoint == 'Logits':
579 |             return x.mean(4).mean(3).mean(2)
580 |         else:
581 |             return x
582 | 
--------------------------------------------------------------------------------
/core/utils.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import io
  3 | import cv2
  4 | import random
  5 | import numpy as np
  6 | from PIL import Image, ImageOps
  7 | import zipfile
  8 | 
  9 | import torch
 10 | import matplotlib
 11 | import matplotlib.patches as patches
 12 | from matplotlib.path import Path
 13 | from matplotlib import pyplot as plt
 14 | from torchvision import transforms
 15 | 
 16 | # Work around the OMP Error #15 that random mask generation can trigger
 17 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
 18 | 
 19 | # ###########################################################################
 20 | # Directory IO
 21 | # ###########################################################################
 22 | 
 23 | 
 24 | def read_dirnames_under_root(root_dir):
 25 |     dirnames = [
 26 |         name for i, name in enumerate(sorted(os.listdir(root_dir)))
 27 |         if os.path.isdir(os.path.join(root_dir, name))
 28 |     ]
 29 |     print(f'Reading directories under {root_dir}, num: {len(dirnames)}')
 30 |     return dirnames
 31 | 
 32 | 
 33 | class TrainZipReader(object):
 34 |     file_dict = dict()
 35 | 
 36 |     def __init__(self):
 37 |         super(TrainZipReader, self).__init__()
 38 | 
 39 |     @staticmethod
 40 |     def build_file_dict(path):
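        # Cache one open ZipFile handle per archive path, so that repeated
        # frame reads do not reopen the zip and rescan its directory each call.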
file_dict = TrainZipReader.file_dict 42 | if path in file_dict: 43 | return file_dict[path] 44 | else: 45 | file_handle = zipfile.ZipFile(path, 'r') 46 | file_dict[path] = file_handle 47 | return file_dict[path] 48 | 49 | @staticmethod 50 | def imread(path, idx): 51 | zfile = TrainZipReader.build_file_dict(path) 52 | filelist = zfile.namelist() 53 | filelist.sort() 54 | data = zfile.read(filelist[idx]) 55 | # 56 | im = Image.open(io.BytesIO(data)) 57 | return im 58 | 59 | 60 | class TestZipReader(object): 61 | file_dict = dict() 62 | 63 | def __init__(self): 64 | super(TestZipReader, self).__init__() 65 | 66 | @staticmethod 67 | def build_file_dict(path): 68 | file_dict = TestZipReader.file_dict 69 | if path in file_dict: 70 | return file_dict[path] 71 | else: 72 | file_handle = zipfile.ZipFile(path, 'r') 73 | file_dict[path] = file_handle 74 | return file_dict[path] 75 | 76 | @staticmethod 77 | def imread(path, idx): 78 | zfile = TestZipReader.build_file_dict(path) 79 | filelist = zfile.namelist() 80 | filelist.sort() 81 | data = zfile.read(filelist[idx]) 82 | file_bytes = np.asarray(bytearray(data), dtype=np.uint8) 83 | im = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) 84 | im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB)) 85 | # im = Image.open(io.BytesIO(data)) 86 | return im 87 | 88 | 89 | # ########################################################################### 90 | # Data augmentation 91 | # ########################################################################### 92 | 93 | 94 | def to_tensors(): 95 | return transforms.Compose([Stack(), ToTorchFormatTensor()]) 96 | 97 | 98 | class GroupRandomHorizontalFlowFlip(object): 99 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 100 | """ 101 | def __init__(self, is_flow=True): 102 | self.is_flow = is_flow 103 | 104 | def __call__(self, img_group, mask_group, flowF_group, flowB_group): 105 | v = random.random() 106 | if v < 0.5: 107 | ret_img = [ 108 | img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group 109 | ] 110 | ret_mask = [ 111 | mask.transpose(Image.FLIP_LEFT_RIGHT) for mask in mask_group 112 | ] 113 | ret_flowF = [ff[:, ::-1] * [-1.0, 1.0] for ff in flowF_group] 114 | ret_flowB = [fb[:, ::-1] * [-1.0, 1.0] for fb in flowB_group] 115 | return ret_img, ret_mask, ret_flowF, ret_flowB 116 | else: 117 | return img_group, mask_group, flowF_group, flowB_group 118 | 119 | 120 | class GroupRandomHorizontalFlip(object): 121 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 122 | """ 123 | def __init__(self, is_flow=False): 124 | self.is_flow = is_flow 125 | 126 | def __call__(self, img_group, is_flow=False): 127 | v = random.random() 128 | if v < 0.5: 129 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 130 | if self.is_flow: 131 | for i in range(0, len(ret), 2): 132 | # invert flow pixel values when flipping 133 | ret[i] = ImageOps.invert(ret[i]) 134 | return ret 135 | else: 136 | return img_group 137 | 138 | 139 | class Stack(object): 140 | def __init__(self, roll=False): 141 | self.roll = roll 142 | 143 | def __call__(self, img_group): 144 | mode = img_group[0].mode 145 | if mode == '1': 146 | img_group = [img.convert('L') for img in img_group] 147 | mode = 'L' 148 | if mode == 'L': 149 | return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2) 150 | elif mode == 'RGB': 151 | if self.roll: 152 | return np.stack([np.array(x)[:, :, ::-1] for x in img_group], 153 | axis=2) 154 | else: 155 | return np.stack(img_group, axis=2) 156 | else: 
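            # image modes other than '1', 'L' and 'RGB' (e.g. 'RGBA') are not handled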
157 | raise NotImplementedError(f"Image mode {mode}") 158 | 159 | 160 | class ToTorchFormatTensor(object): 161 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 162 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 163 | def __init__(self, div=True): 164 | self.div = div 165 | 166 | def __call__(self, pic): 167 | if isinstance(pic, np.ndarray): 168 | # numpy img: [L, C, H, W] 169 | img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous() 170 | else: 171 | # handle PIL Image 172 | img = torch.ByteTensor(torch.ByteStorage.from_buffer( 173 | pic.tobytes())) 174 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 175 | # put it from HWC to CHW format 176 | # yikes, this transpose takes 80% of the loading time/CPU 177 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 178 | img = img.float().div(255) if self.div else img.float() 179 | return img 180 | 181 | 182 | # ########################################################################### 183 | # Create masks with random shape 184 | # ########################################################################### 185 | 186 | 187 | def create_random_shape_with_random_motion(video_length, 188 | imageHeight=240, 189 | imageWidth=432): 190 | # get a random shape 191 | height = random.randint(imageHeight // 3, imageHeight - 1) 192 | width = random.randint(imageWidth // 3, imageWidth - 1) 193 | edge_num = random.randint(6, 8) 194 | ratio = random.randint(6, 8) / 10 195 | region = get_random_shape(edge_num=edge_num, 196 | ratio=ratio, 197 | height=height, 198 | width=width) 199 | region_width, region_height = region.size 200 | # get random position 201 | x, y = random.randint(0, imageHeight - region_height), random.randint( 202 | 0, imageWidth - region_width) 203 | velocity = get_random_velocity(max_speed=3) 204 | m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8)) 205 | m.paste(region, (y, x, y + region.size[0], x + region.size[1])) 206 | masks = [m.convert('L')] 207 | # return fixed masks 208 | if random.uniform(0, 1) > 0.5: 209 | return masks * video_length 210 | # return moving masks 211 | for _ in range(video_length - 1): 212 | x, y, velocity = random_move_control_points(x, 213 | y, 214 | imageHeight, 215 | imageWidth, 216 | velocity, 217 | region.size, 218 | maxLineAcceleration=(3, 219 | 0.5), 220 | maxInitSpeed=3) 221 | m = Image.fromarray( 222 | np.zeros((imageHeight, imageWidth)).astype(np.uint8)) 223 | m.paste(region, (y, x, y + region.size[0], x + region.size[1])) 224 | masks.append(m.convert('L')) 225 | return masks 226 | 227 | 228 | def create_random_shape_with_random_motion_seq(video_length, 229 | imageHeight=240, 230 | imageWidth=432, 231 | new_mask=True, 232 | random_dict=None 233 | ): 234 | """ 235 | Sequence Mask Generator by Hao. 236 | new_mask (bool): If True, generate a new batch of mask; 237 | If False, use the previous params to generate the same mask. 
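    random_dict (dict): Parameters cached from a previous call; required when
                        new_mask is False so the identical mask sequence is rebuilt.
    Returns:
        (masks, random_dict): the mask sequence and the parameter dict needed
        to reproduce it.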
238 |     """
239 | 
240 |     if new_mask:
241 |         # generate a fresh set of random parameters for a new mask
242 |         # get a random shape
243 |         height = random.randint(imageHeight // 3, imageHeight - 1)
244 |         width = random.randint(imageWidth // 3, imageWidth - 1)
245 |         edge_num = random.randint(6, 8)
246 |         ratio = random.randint(6, 8) / 10
247 |         region, random_point = get_random_shape_seq(edge_num=edge_num,
248 |                                                     ratio=ratio,
249 |                                                     height=height,
250 |                                                     width=width)
251 |         region_width, region_height = region.size
252 | 
253 |         # get random position
254 |         x, y = random.randint(0, imageHeight - region_height), random.randint(
255 |             0, imageWidth - region_width)
256 |         velocity = get_random_velocity(max_speed=3)
257 | 
258 |         # get random probability
259 |         prob = random.uniform(0, 1)
260 | 
261 |         # store the newly generated random parameters in a dict
262 |         random_dict = {}
263 |         # random_dict['region'] = region  # the PyTorch DataLoader cannot return PIL-Image data
264 |         random_dict['edge_num'] = edge_num
265 |         random_dict['ratio'] = ratio
266 |         random_dict['height'] = height
267 |         random_dict['width'] = width
268 |         random_dict['random_point'] = random_point
269 |         random_dict['x'] = x
270 |         random_dict['y'] = y
271 |         random_dict['velocity'] = velocity
272 |         random_dict['prob'] = prob
273 | 
274 |     else:
275 |         # reuse the previous parameters instead of generating new ones
276 |         # region = random_dict['region']
277 |         edge_num = random_dict['edge_num']
278 |         ratio = random_dict['ratio']
279 |         height = random_dict['height']
280 |         width = random_dict['width']
281 |         random_point = random_dict['random_point']
282 |         # once random_point is fixed, the shape region is fully determined
283 |         region, random_point = get_random_shape_seq(edge_num=edge_num,
284 |                                                     ratio=ratio,
285 |                                                     height=height,
286 |                                                     width=width,
287 |                                                     random_point=random_point)
288 | 
289 |         x = random_dict['x']
290 |         y = random_dict['y']
291 |         velocity = random_dict['velocity']
292 |         prob = random_dict['prob']
293 | 
294 |     # create the static mask
295 |     m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
296 |     m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
297 |     masks = [m.convert('L')]
298 | 
299 |     # return fixed masks
300 |     if prob > 0.5:
301 |         return masks * video_length, random_dict
302 | 
303 |     # return moving masks
304 |     # the motion is random too, so its positions and velocities must be returned to keep it reproducible
305 |     if new_mask:
306 |         x_list = []
307 |         y_list = []
308 |         velocity_list = []
309 |     else:
310 |         # if reusing previous motion parameters, load them
311 |         x_list = random_dict['x_list']
312 |         y_list = random_dict['y_list']
313 |         velocity_list = random_dict['velocity_list']
314 | 
315 |     for idx in range(video_length - 1):
316 |         if new_mask:
317 |             # generate new random motion for this frame
318 |             x, y, velocity = random_move_control_points(x,
319 |                                                         y,
320 |                                                         imageHeight,
321 |                                                         imageWidth,
322 |                                                         velocity,
323 |                                                         region.size,
324 |                                                         maxLineAcceleration=(3,
325 |                                                                              0.5),
326 |                                                         maxInitSpeed=3)
327 |             # record position and velocity
328 |             x_list.append(x)
329 |             y_list.append(y)
330 |             velocity_list.append(velocity)
331 |         else:
332 |             # read them directly from the cached lists
333 |             x = x_list[idx]
334 |             y = y_list[idx]
335 |             velocity = velocity_list[idx]
336 | 
337 |         # render the moving mask for the current frame
338 |         m = Image.fromarray(
339 |             np.zeros((imageHeight, imageWidth)).astype(np.uint8))
340 |         m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
341 |         masks.append(m.convert('L'))
342 | 
343 |     # save the newly generated positions and velocities into the dict
344 |     if new_mask:
345 |         random_dict['x_list'] = x_list
346 |         random_dict['y_list'] = y_list
347 |         random_dict['velocity_list'] = velocity_list
348 | 
349 |     return masks, random_dict
350 | 
351 | 
352 | def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
353 |     '''
354 |     There is the initial point and 3 points per cubic bezier curve.
355 |     Thus, the curve will only pass through n points, which will be the sharp edges.
356 |     The other 2 modify the shape of the bezier curve.
357 |     edge_num, Number of possibly sharp edges
358 |     points_num, number of points in the Path
359 |     ratio, (0, 1) magnitude of the perturbation from the unit circle,
360 |     '''
361 |     points_num = edge_num * 3 + 1
362 |     angles = np.linspace(0, 2 * np.pi, points_num)
363 |     codes = np.full(points_num, Path.CURVE4)
364 |     codes[0] = Path.MOVETO
365 |     # Using this instead of Path.CLOSEPOLY avoids an unnecessary straight line
366 |     verts = np.stack((np.cos(angles), np.sin(angles))).T * \
367 |         (2*ratio*np.random.random(points_num)+1-ratio)[:, None]
368 |     verts[-1, :] = verts[0, :]
369 |     path = Path(verts, codes)
370 |     # draw paths into images
371 |     fig = plt.figure()
372 |     ax = fig.add_subplot(111)
373 |     patch = patches.PathPatch(path, facecolor='black', lw=2)
374 |     ax.add_patch(patch)
375 |     ax.set_xlim(np.min(verts) * 1.1, np.max(verts) * 1.1)
376 |     ax.set_ylim(np.min(verts) * 1.1, np.max(verts) * 1.1)
377 |     ax.axis('off')  # removes the axis to leave only the shape
378 |     fig.canvas.draw()
379 |     # convert plt images into numpy images
380 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
381 |     data = data.reshape((fig.canvas.get_width_height()[::-1] + (3, )))
382 |     plt.close(fig)
383 |     # postprocess
384 |     data = cv2.resize(data, (width, height))[:, :, 0]
385 |     data = (1 - np.array(data > 0).astype(np.uint8)) * 255
386 |     coordinates = np.where(data > 0)
387 |     xmin, xmax, ymin, ymax = np.min(coordinates[0]), np.max(
388 |         coordinates[0]), np.min(coordinates[1]), np.max(coordinates[1])
389 |     region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
390 |     return region
391 | 
392 | 
393 | def get_random_shape_seq(edge_num=9, ratio=0.7, width=432, height=240, random_point=None):
394 |     '''
395 |     There is the initial point and 3 points per cubic bezier curve.
396 |     Thus, the curve will only pass through n points, which will be the sharp edges.
397 |     The other 2 modify the shape of the bezier curve.
398 |     edge_num, Number of possibly sharp edges
399 |     points_num, number of points in the Path
400 |     ratio, (0, 1) magnitude of the perturbation from the unit circle,
401 |     Revised by Hao:
402 |     random_point: if given, the shape is known and fixed.
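    Returns (region, random_point); passing random_point back in regenerates
    the identical shape.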
403 |     '''
404 |     points_num = edge_num * 3 + 1
405 | 
406 |     if random_point is None:
407 |         random_point = np.random.random(points_num)
410 | 
411 |     angles = np.linspace(0, 2 * np.pi, points_num)
412 |     codes = np.full(points_num, Path.CURVE4)
413 |     codes[0] = Path.MOVETO
414 |     # Using this instead of Path.CLOSEPOLY avoids an unnecessary straight line
415 |     verts = np.stack((np.cos(angles), np.sin(angles))).T * \
416 |         (2*ratio*random_point+1-ratio)[:, None]
417 |     verts[-1, :] = verts[0, :]
418 |     path = Path(verts, codes)
419 |     # draw paths into images
420 |     fig = plt.figure()
421 |     ax = fig.add_subplot(111)
422 |     patch = patches.PathPatch(path, facecolor='black', lw=2)
423 |     ax.add_patch(patch)
424 |     ax.set_xlim(np.min(verts) * 1.1, np.max(verts) * 1.1)
425 |     ax.set_ylim(np.min(verts) * 1.1, np.max(verts) * 1.1)
426 |     ax.axis('off')  # removes the axis to leave only the shape
427 |     fig.canvas.draw()
428 |     # convert plt images into numpy images
429 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
430 |     data = data.reshape((fig.canvas.get_width_height()[::-1] + (3, )))
431 |     plt.close(fig)
432 |     # postprocess
433 |     data = cv2.resize(data, (width, height))[:, :, 0]
434 |     data = (1 - np.array(data > 0).astype(np.uint8)) * 255
435 |     coordinates = np.where(data > 0)
436 |     xmin, xmax, ymin, ymax = np.min(coordinates[0]), np.max(
437 |         coordinates[0]), np.min(coordinates[1]), np.max(coordinates[1])
438 |     region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
439 |     return region, random_point
440 | 
441 | 
442 | def random_accelerate(velocity, maxAcceleration, dist='uniform'):
443 |     speed, angle = velocity
444 |     d_speed, d_angle = maxAcceleration
445 |     if dist == 'uniform':
446 |         speed += np.random.uniform(-d_speed, d_speed)
447 |         angle += np.random.uniform(-d_angle, d_angle)
448 |     elif dist == 'gaussian':
449 |         speed += np.random.normal(0, d_speed / 2)
450 |         angle += np.random.normal(0, d_angle / 2)
451 |     else:
452 |         raise NotImplementedError(
453 |             f'Distribution type {dist} is not supported.')
454 |     return (speed, angle)
455 | 
456 | 
457 | def get_random_velocity(max_speed=3, dist='uniform'):
458 |     if dist == 'uniform':
459 |         speed = np.random.uniform(0, max_speed)  # sample a speed uniformly from [0, max_speed)
460 |     elif dist == 'gaussian':
461 |         speed = np.abs(np.random.normal(0, max_speed / 2))
462 |     else:
463 |         raise NotImplementedError(
464 |             f'Distribution type {dist} is not supported.')
465 |     angle = np.random.uniform(0, 2 * np.pi)
466 |     return (speed, angle)
467 | 
468 | 
469 | def random_move_control_points(X,
470 |                                Y,
471 |                                imageHeight,
472 |                                imageWidth,
473 |                                lineVelocity,
474 |                                region_size,
475 |                                maxLineAcceleration=(3, 0.5),
476 |                                maxInitSpeed=3):
477 |     region_width, region_height = region_size
478 |     speed, angle = lineVelocity
479 |     X += int(speed * np.cos(angle))
480 |     Y += int(speed * np.sin(angle))
481 |     lineVelocity = random_accelerate(lineVelocity,
482 |                                      maxLineAcceleration,
483 |                                      dist='gaussian')
484 |     if ((X > imageHeight - region_height) or (X < 0)
485 |             or (Y > imageWidth - region_width) or (Y < 0)):
486 |         lineVelocity = get_random_velocity(maxInitSpeed, dist='gaussian')
487 |     new_X = np.clip(X, 0, imageHeight - region_height)
488 |     new_Y = np.clip(Y, 0, imageWidth - region_width)
489 |     return new_X, new_Y, lineVelocity
490 | 
491 | 
492 | if __name__ == '__main__':
493 | 
494 |     trials = 10
495 |     for _ in range(trials):
496 |         video_length = 10
497 |         # The returned masks are either stationary (50%) or moving (50%)
498 |         masks = 
create_random_shape_with_random_motion(video_length, 499 | imageHeight=240, 500 | imageWidth=432) 501 | 502 | for m in masks: 503 | cv2.imshow('mask', np.array(m)) 504 | cv2.waitKey(500) 505 | -------------------------------------------------------------------------------- /datasets/KITTI-360EX/InnerSphere/test.json: -------------------------------------------------------------------------------- 1 | {"00000719": 100, "00000720": 100, "00000721": 100, "00000722": 100, "00000723": 100, "00000724": 100, "00000725": 100, "00000726": 100, "00000727": 100, "00000728": 100, "00000729": 100, "00000730": 100, "00000731": 100, "00000732": 100, "00000733": 100, "00000734": 100, "00000735": 100, "00000736": 100, "00000737": 100, "00000738": 100, "00000739": 100, "00000740": 100, "00000741": 100, "00000742": 100, "00000743": 100, "00000744": 100, "00000745": 100, "00000746": 100, "00000747": 100, "00000748": 100, "00000749": 100, "00000750": 100, "00000751": 100, "00000752": 100, "00000753": 100, "00000754": 100, "00000755": 100, "00000756": 100} -------------------------------------------------------------------------------- /datasets/KITTI-360EX/InnerSphere/train.json: -------------------------------------------------------------------------------- 1 | {"00000000": 100, "00000001": 100, "00000002": 100, "00000003": 100, "00000004": 100, "00000005": 100, "00000006": 100, "00000007": 100, "00000008": 100, "00000009": 100, "00000010": 100, "00000011": 100, "00000012": 100, "00000013": 100, "00000014": 100, "00000015": 100, "00000016": 100, "00000017": 100, "00000018": 100, "00000019": 100, "00000020": 100, "00000021": 100, "00000022": 100, "00000023": 100, "00000024": 100, "00000025": 100, "00000026": 100, "00000027": 100, "00000028": 100, "00000029": 100, "00000030": 100, "00000031": 100, "00000032": 100, "00000033": 100, "00000034": 100, "00000035": 100, "00000036": 100, "00000037": 100, "00000038": 100, "00000039": 100, "00000040": 100, "00000041": 100, "00000042": 100, "00000043": 100, "00000044": 100, "00000045": 100, "00000046": 100, "00000047": 100, "00000048": 100, "00000049": 100, "00000050": 100, "00000051": 100, "00000052": 100, "00000053": 100, "00000054": 100, "00000055": 100, "00000056": 100, "00000057": 100, "00000058": 100, "00000059": 100, "00000060": 100, "00000061": 100, "00000062": 100, "00000063": 100, "00000064": 100, "00000065": 100, "00000066": 100, "00000067": 100, "00000068": 100, "00000069": 100, "00000070": 100, "00000071": 100, "00000072": 100, "00000073": 100, "00000074": 100, "00000075": 100, "00000076": 100, "00000077": 100, "00000078": 100, "00000079": 100, "00000080": 100, "00000081": 100, "00000082": 100, "00000083": 100, "00000084": 100, "00000085": 100, "00000086": 100, "00000087": 100, "00000088": 100, "00000089": 100, "00000090": 100, "00000091": 100, "00000092": 100, "00000093": 100, "00000094": 100, "00000095": 100, "00000096": 100, "00000097": 100, "00000098": 100, "00000099": 100, "00000100": 100, "00000101": 100, "00000102": 100, "00000103": 100, "00000104": 100, "00000105": 100, "00000106": 100, "00000107": 100, "00000108": 100, "00000109": 100, "00000110": 100, "00000111": 100, "00000112": 100, "00000113": 100, "00000114": 100, "00000115": 100, "00000116": 100, "00000117": 100, "00000118": 100, "00000119": 100, "00000120": 100, "00000121": 100, "00000122": 100, "00000123": 100, "00000124": 100, "00000125": 100, "00000126": 100, "00000127": 100, "00000128": 100, "00000129": 100, "00000130": 100, "00000131": 100, "00000132": 100, "00000133": 100, "00000134": 
100, "00000135": 100, "00000136": 100, "00000137": 100, "00000138": 100, "00000139": 100, "00000140": 100, "00000141": 100, "00000142": 100, "00000143": 100, "00000144": 100, "00000145": 100, "00000146": 100, "00000147": 100, "00000148": 100, "00000149": 100, "00000150": 100, "00000151": 100, "00000152": 100, "00000153": 100, "00000154": 100, "00000155": 100, "00000156": 100, "00000157": 100, "00000158": 100, "00000159": 100, "00000160": 100, "00000161": 100, "00000162": 100, "00000163": 100, "00000164": 100, "00000165": 100, "00000166": 100, "00000167": 100, "00000168": 100, "00000169": 100, "00000170": 100, "00000171": 100, "00000172": 100, "00000173": 100, "00000174": 100, "00000175": 100, "00000176": 100, "00000177": 100, "00000178": 100, "00000179": 100, "00000180": 100, "00000181": 100, "00000182": 100, "00000183": 100, "00000184": 100, "00000185": 100, "00000186": 100, "00000187": 100, "00000188": 100, "00000189": 100, "00000190": 100, "00000191": 100, "00000192": 100, "00000193": 100, "00000194": 100, "00000195": 100, "00000196": 100, "00000197": 100, "00000198": 100, "00000199": 100, "00000200": 100, "00000201": 100, "00000202": 100, "00000203": 100, "00000204": 100, "00000205": 100, "00000206": 100, "00000207": 100, "00000208": 100, "00000209": 100, "00000210": 100, "00000211": 100, "00000212": 100, "00000213": 100, "00000214": 100, "00000215": 100, "00000216": 100, "00000217": 100, "00000218": 100, "00000219": 100, "00000220": 100, "00000221": 100, "00000222": 100, "00000223": 100, "00000224": 100, "00000225": 100, "00000226": 100, "00000227": 100, "00000228": 100, "00000229": 100, "00000230": 100, "00000231": 100, "00000232": 100, "00000233": 100, "00000234": 100, "00000235": 100, "00000236": 100, "00000237": 100, "00000238": 100, "00000239": 100, "00000240": 100, "00000241": 100, "00000242": 100, "00000243": 100, "00000244": 100, "00000245": 100, "00000246": 100, "00000247": 100, "00000248": 100, "00000249": 100, "00000250": 100, "00000251": 100, "00000252": 100, "00000253": 100, "00000254": 100, "00000255": 100, "00000256": 100, "00000257": 100, "00000258": 100, "00000259": 100, "00000260": 100, "00000261": 100, "00000262": 100, "00000263": 100, "00000264": 100, "00000265": 100, "00000266": 100, "00000267": 100, "00000268": 100, "00000269": 100, "00000270": 100, "00000271": 100, "00000272": 100, "00000273": 100, "00000274": 100, "00000275": 100, "00000276": 100, "00000277": 100, "00000278": 100, "00000279": 100, "00000280": 100, "00000281": 100, "00000282": 100, "00000283": 100, "00000284": 100, "00000285": 100, "00000286": 100, "00000287": 100, "00000288": 100, "00000289": 100, "00000290": 100, "00000291": 100, "00000292": 100, "00000293": 100, "00000294": 100, "00000295": 100, "00000296": 100, "00000297": 100, "00000298": 100, "00000299": 100, "00000300": 100, "00000301": 100, "00000302": 100, "00000303": 100, "00000304": 100, "00000305": 100, "00000306": 100, "00000307": 100, "00000308": 100, "00000309": 100, "00000310": 100, "00000311": 100, "00000312": 100, "00000313": 100, "00000314": 100, "00000315": 100, "00000316": 100, "00000317": 100, "00000318": 100, "00000319": 100, "00000320": 100, "00000321": 100, "00000322": 100, "00000323": 100, "00000324": 100, "00000325": 100, "00000326": 100, "00000327": 100, "00000328": 100, "00000329": 100, "00000330": 100, "00000331": 100, "00000332": 100, "00000333": 100, "00000334": 100, "00000335": 100, "00000336": 100, "00000337": 100, "00000338": 100, "00000339": 100, "00000340": 100, "00000341": 100, "00000342": 100, "00000343": 
100, "00000344": 100, "00000345": 100, "00000346": 100, "00000347": 100, "00000348": 100, "00000349": 100, "00000350": 100, "00000351": 100, "00000352": 100, "00000353": 100, "00000354": 100, "00000355": 100, "00000356": 100, "00000357": 100, "00000358": 100, "00000359": 100, "00000360": 100, "00000361": 100, "00000362": 100, "00000363": 100, "00000364": 100, "00000365": 100, "00000366": 100, "00000367": 100, "00000368": 100, "00000369": 100, "00000370": 100, "00000371": 100, "00000372": 100, "00000373": 100, "00000374": 100, "00000375": 100, "00000376": 100, "00000377": 100, "00000378": 100, "00000379": 100, "00000380": 100, "00000381": 100, "00000382": 100, "00000383": 100, "00000384": 100, "00000385": 100, "00000386": 100, "00000387": 100, "00000388": 100, "00000389": 100, "00000390": 100, "00000391": 100, "00000392": 100, "00000393": 100, "00000394": 100, "00000395": 100, "00000396": 100, "00000397": 100, "00000398": 100, "00000399": 100, "00000400": 100, "00000401": 100, "00000402": 100, "00000403": 100, "00000404": 100, "00000405": 100, "00000406": 100, "00000407": 100, "00000408": 100, "00000409": 100, "00000410": 100, "00000411": 100, "00000412": 100, "00000413": 100, "00000414": 100, "00000415": 100, "00000416": 100, "00000417": 100, "00000418": 100, "00000419": 100, "00000420": 100, "00000421": 100, "00000422": 100, "00000423": 100, "00000424": 100, "00000425": 100, "00000426": 100, "00000427": 100, "00000428": 100, "00000429": 100, "00000430": 100, "00000431": 100, "00000432": 100, "00000433": 100, "00000434": 100, "00000435": 100, "00000436": 100, "00000437": 100, "00000438": 100, "00000439": 100, "00000440": 100, "00000441": 100, "00000442": 100, "00000443": 100, "00000444": 100, "00000445": 100, "00000446": 100, "00000447": 100, "00000448": 100, "00000449": 100, "00000450": 100, "00000451": 100, "00000452": 100, "00000453": 100, "00000454": 100, "00000455": 100, "00000456": 100, "00000457": 100, "00000458": 100, "00000459": 100, "00000460": 100, "00000461": 100, "00000462": 100, "00000463": 100, "00000464": 100, "00000465": 100, "00000466": 100, "00000467": 100, "00000468": 100, "00000469": 100, "00000470": 100, "00000471": 100, "00000472": 100, "00000473": 100, "00000474": 100, "00000475": 100, "00000476": 100, "00000477": 100, "00000478": 100, "00000479": 100, "00000480": 100, "00000481": 100, "00000482": 100, "00000483": 100, "00000484": 100, "00000485": 100, "00000486": 100, "00000487": 100, "00000488": 100, "00000489": 100, "00000490": 100, "00000491": 100, "00000492": 100, "00000493": 100, "00000494": 100, "00000495": 100, "00000496": 100, "00000497": 100, "00000498": 100, "00000499": 100, "00000500": 100, "00000501": 100, "00000502": 100, "00000503": 100, "00000504": 100, "00000505": 100, "00000506": 100, "00000507": 100, "00000508": 100, "00000509": 100, "00000510": 100, "00000511": 100, "00000512": 100, "00000513": 100, "00000514": 100, "00000515": 100, "00000516": 100, "00000517": 100, "00000518": 100, "00000519": 100, "00000520": 100, "00000521": 100, "00000522": 100, "00000523": 100, "00000524": 100, "00000525": 100, "00000526": 100, "00000527": 100, "00000528": 100, "00000529": 100, "00000530": 100, "00000531": 100, "00000532": 100, "00000533": 100, "00000534": 100, "00000535": 100, "00000536": 100, "00000537": 100, "00000538": 100, "00000539": 100, "00000540": 100, "00000541": 100, "00000542": 100, "00000543": 100, "00000544": 100, "00000545": 100, "00000546": 100, "00000547": 100, "00000548": 100, "00000549": 100, "00000550": 100, "00000551": 100, "00000552": 
100, "00000553": 100, "00000554": 100, "00000555": 100, "00000556": 100, "00000557": 100, "00000558": 100, "00000559": 100, "00000560": 100, "00000561": 100, "00000562": 100, "00000563": 100, "00000564": 100, "00000565": 100, "00000566": 100, "00000567": 100, "00000568": 100, "00000569": 100, "00000570": 100, "00000571": 100, "00000572": 100, "00000573": 100, "00000574": 100, "00000575": 100, "00000576": 100, "00000577": 100, "00000578": 100, "00000579": 100, "00000580": 100, "00000581": 100, "00000582": 100, "00000583": 100, "00000584": 100, "00000585": 100, "00000586": 100, "00000587": 100, "00000588": 100, "00000589": 100, "00000590": 100, "00000591": 100, "00000592": 100, "00000593": 100, "00000594": 100, "00000595": 100, "00000596": 100, "00000597": 100, "00000598": 100, "00000599": 100, "00000600": 100, "00000601": 100, "00000602": 100, "00000603": 100, "00000604": 100, "00000605": 100, "00000606": 100, "00000607": 100, "00000608": 100, "00000609": 100, "00000610": 100, "00000611": 100, "00000612": 100, "00000613": 100, "00000614": 100, "00000615": 100, "00000616": 100, "00000617": 100, "00000618": 100, "00000619": 100, "00000620": 100, "00000621": 100, "00000622": 100, "00000623": 100, "00000624": 100, "00000625": 100, "00000626": 100, "00000627": 100, "00000628": 100, "00000629": 100, "00000630": 100, "00000631": 100, "00000632": 100, "00000633": 100, "00000634": 100, "00000635": 100, "00000636": 100, "00000637": 100, "00000638": 100, "00000639": 100, "00000640": 100, "00000641": 100, "00000642": 100, "00000643": 100, "00000644": 100, "00000645": 100, "00000646": 100, "00000647": 100, "00000648": 100, "00000649": 100, "00000650": 100, "00000651": 100, "00000652": 100, "00000653": 100, "00000654": 100, "00000655": 100, "00000656": 100, "00000657": 100, "00000658": 100, "00000659": 100, "00000660": 100, "00000661": 100, "00000662": 100, "00000663": 100, "00000664": 100, "00000665": 100, "00000666": 100, "00000667": 100, "00000668": 100, "00000669": 100, "00000670": 100, "00000671": 100, "00000672": 100, "00000673": 100, "00000674": 100, "00000675": 100, "00000676": 100, "00000677": 100, "00000678": 100, "00000679": 100, "00000680": 100, "00000681": 100, "00000682": 100, "00000683": 100, "00000684": 100, "00000685": 100, "00000686": 100, "00000687": 100, "00000688": 100, "00000689": 100, "00000690": 100, "00000691": 100, "00000692": 100, "00000693": 100, "00000694": 100, "00000695": 100, "00000696": 100, "00000697": 100, "00000698": 100, "00000699": 100, "00000700": 100, "00000701": 100, "00000702": 100, "00000703": 100, "00000704": 100, "00000705": 100, "00000706": 100, "00000707": 100, "00000708": 100, "00000709": 100, "00000710": 100, "00000711": 100, "00000712": 100, "00000713": 100, "00000714": 100, "00000715": 100, "00000716": 100, "00000717": 100, "00000718": 100} -------------------------------------------------------------------------------- /datasets/KITTI-360EX/OuterPinhole/test.json: -------------------------------------------------------------------------------- 1 | {"00000722": 100, "00000723": 100, "00000724": 100, "00000725": 100, "00000726": 100, "00000727": 100, "00000728": 100, "00000729": 100, "00000730": 100, "00000731": 100, "00000732": 100, "00000733": 100, "00000734": 100, "00000735": 100, "00000736": 100, "00000737": 100, "00000738": 100, "00000739": 100, "00000740": 100, "00000741": 100, "00000742": 100, "00000743": 100, "00000744": 100, "00000745": 100, "00000746": 100, "00000747": 100, "00000748": 100, "00000749": 100, "00000750": 100, "00000751": 100, 
"00000752": 100, "00000753": 100, "00000754": 100, "00000755": 100, "00000756": 100, "00000757": 100, "00000758": 100, "00000759": 100} -------------------------------------------------------------------------------- /datasets/KITTI-360EX/OuterPinhole/train.json: -------------------------------------------------------------------------------- 1 | {"00000000": 100, "00000001": 100, "00000002": 100, "00000003": 100, "00000004": 100, "00000005": 100, "00000006": 100, "00000007": 100, "00000008": 100, "00000009": 100, "00000010": 100, "00000011": 100, "00000012": 100, "00000013": 100, "00000014": 100, "00000015": 100, "00000016": 100, "00000017": 100, "00000018": 100, "00000019": 100, "00000020": 100, "00000021": 100, "00000022": 100, "00000023": 100, "00000024": 100, "00000025": 100, "00000026": 100, "00000027": 100, "00000028": 100, "00000029": 100, "00000030": 100, "00000031": 100, "00000032": 100, "00000033": 100, "00000034": 100, "00000035": 100, "00000036": 100, "00000037": 100, "00000038": 100, "00000039": 100, "00000040": 100, "00000041": 100, "00000042": 100, "00000043": 100, "00000044": 100, "00000045": 100, "00000046": 100, "00000047": 100, "00000048": 100, "00000049": 100, "00000050": 100, "00000051": 100, "00000052": 100, "00000053": 100, "00000054": 100, "00000055": 100, "00000056": 100, "00000057": 100, "00000058": 100, "00000059": 100, "00000060": 100, "00000061": 100, "00000062": 100, "00000063": 100, "00000064": 100, "00000065": 100, "00000066": 100, "00000067": 100, "00000068": 100, "00000069": 100, "00000070": 100, "00000071": 100, "00000072": 100, "00000073": 100, "00000074": 100, "00000075": 100, "00000076": 100, "00000077": 100, "00000078": 100, "00000079": 100, "00000080": 100, "00000081": 100, "00000082": 100, "00000083": 100, "00000084": 100, "00000085": 100, "00000086": 100, "00000087": 100, "00000088": 100, "00000089": 100, "00000090": 100, "00000091": 100, "00000092": 100, "00000093": 100, "00000094": 100, "00000095": 100, "00000096": 100, "00000097": 100, "00000098": 100, "00000099": 100, "00000100": 100, "00000101": 100, "00000102": 100, "00000103": 100, "00000104": 100, "00000105": 100, "00000106": 100, "00000107": 100, "00000108": 100, "00000109": 100, "00000110": 100, "00000111": 100, "00000112": 100, "00000113": 100, "00000114": 100, "00000115": 100, "00000116": 100, "00000117": 100, "00000118": 100, "00000119": 100, "00000120": 100, "00000121": 100, "00000122": 100, "00000123": 100, "00000124": 100, "00000125": 100, "00000126": 100, "00000127": 100, "00000128": 100, "00000129": 100, "00000130": 100, "00000131": 100, "00000132": 100, "00000133": 100, "00000134": 100, "00000135": 100, "00000136": 100, "00000137": 100, "00000138": 100, "00000139": 100, "00000140": 100, "00000141": 100, "00000142": 100, "00000143": 100, "00000144": 100, "00000145": 100, "00000146": 100, "00000147": 100, "00000148": 100, "00000149": 100, "00000150": 100, "00000151": 100, "00000152": 100, "00000153": 100, "00000154": 100, "00000155": 100, "00000156": 100, "00000157": 100, "00000158": 100, "00000159": 100, "00000160": 100, "00000161": 100, "00000162": 100, "00000163": 100, "00000164": 100, "00000165": 100, "00000166": 100, "00000167": 100, "00000168": 100, "00000169": 100, "00000170": 100, "00000171": 100, "00000172": 100, "00000173": 100, "00000174": 100, "00000175": 100, "00000176": 100, "00000177": 100, "00000178": 100, "00000179": 100, "00000180": 100, "00000181": 100, "00000182": 100, "00000183": 100, "00000184": 100, "00000185": 100, "00000186": 100, "00000187": 100, 
"00000188": 100, "00000189": 100, "00000190": 100, "00000191": 100, "00000192": 100, "00000193": 100, "00000194": 100, "00000195": 100, "00000196": 100, "00000197": 100, "00000198": 100, "00000199": 100, "00000200": 100, "00000201": 100, "00000202": 100, "00000203": 100, "00000204": 100, "00000205": 100, "00000206": 100, "00000207": 100, "00000208": 100, "00000209": 100, "00000210": 100, "00000211": 100, "00000212": 100, "00000213": 100, "00000214": 100, "00000215": 100, "00000216": 100, "00000217": 100, "00000218": 100, "00000219": 100, "00000220": 100, "00000221": 100, "00000222": 100, "00000223": 100, "00000224": 100, "00000225": 100, "00000226": 100, "00000227": 100, "00000228": 100, "00000229": 100, "00000230": 100, "00000231": 100, "00000232": 100, "00000233": 100, "00000234": 100, "00000235": 100, "00000236": 100, "00000237": 100, "00000238": 100, "00000239": 100, "00000240": 100, "00000241": 100, "00000242": 100, "00000243": 100, "00000244": 100, "00000245": 100, "00000246": 100, "00000247": 100, "00000248": 100, "00000249": 100, "00000250": 100, "00000251": 100, "00000252": 100, "00000253": 100, "00000254": 100, "00000255": 100, "00000256": 100, "00000257": 100, "00000258": 100, "00000259": 100, "00000260": 100, "00000261": 100, "00000262": 100, "00000263": 100, "00000264": 100, "00000265": 100, "00000266": 100, "00000267": 100, "00000268": 100, "00000269": 100, "00000270": 100, "00000271": 100, "00000272": 100, "00000273": 100, "00000274": 100, "00000275": 100, "00000276": 100, "00000277": 100, "00000278": 100, "00000279": 100, "00000280": 100, "00000281": 100, "00000282": 100, "00000283": 100, "00000284": 100, "00000285": 100, "00000286": 100, "00000287": 100, "00000288": 100, "00000289": 100, "00000290": 100, "00000291": 100, "00000292": 100, "00000293": 100, "00000294": 100, "00000295": 100, "00000296": 100, "00000297": 100, "00000298": 100, "00000299": 100, "00000300": 100, "00000301": 100, "00000302": 100, "00000303": 100, "00000304": 100, "00000305": 100, "00000306": 100, "00000307": 100, "00000308": 100, "00000309": 100, "00000310": 100, "00000311": 100, "00000312": 100, "00000313": 100, "00000314": 100, "00000315": 100, "00000316": 100, "00000317": 100, "00000318": 100, "00000319": 100, "00000320": 100, "00000321": 100, "00000322": 100, "00000323": 100, "00000324": 100, "00000325": 100, "00000326": 100, "00000327": 100, "00000328": 100, "00000329": 100, "00000330": 100, "00000331": 100, "00000332": 100, "00000333": 100, "00000334": 100, "00000335": 100, "00000336": 100, "00000337": 100, "00000338": 100, "00000339": 100, "00000340": 100, "00000341": 100, "00000342": 100, "00000343": 100, "00000344": 100, "00000345": 100, "00000346": 100, "00000347": 100, "00000348": 100, "00000349": 100, "00000350": 100, "00000351": 100, "00000352": 100, "00000353": 100, "00000354": 100, "00000355": 100, "00000356": 100, "00000357": 100, "00000358": 100, "00000359": 100, "00000360": 100, "00000361": 100, "00000362": 100, "00000363": 100, "00000364": 100, "00000365": 100, "00000366": 100, "00000367": 100, "00000368": 100, "00000369": 100, "00000370": 100, "00000371": 100, "00000372": 100, "00000373": 100, "00000374": 100, "00000375": 100, "00000376": 100, "00000377": 100, "00000378": 100, "00000379": 100, "00000380": 100, "00000381": 100, "00000382": 100, "00000383": 100, "00000384": 100, "00000385": 100, "00000386": 100, "00000387": 100, "00000388": 100, "00000389": 100, "00000390": 100, "00000391": 100, "00000392": 100, "00000393": 100, "00000394": 100, "00000395": 100, "00000396": 100, 
"00000397": 100, "00000398": 100, "00000399": 100, "00000400": 100, "00000401": 100, "00000402": 100, "00000403": 100, "00000404": 100, "00000405": 100, "00000406": 100, "00000407": 100, "00000408": 100, "00000409": 100, "00000410": 100, "00000411": 100, "00000412": 100, "00000413": 100, "00000414": 100, "00000415": 100, "00000416": 100, "00000417": 100, "00000418": 100, "00000419": 100, "00000420": 100, "00000421": 100, "00000422": 100, "00000423": 100, "00000424": 100, "00000425": 100, "00000426": 100, "00000427": 100, "00000428": 100, "00000429": 100, "00000430": 100, "00000431": 100, "00000432": 100, "00000433": 100, "00000434": 100, "00000435": 100, "00000436": 100, "00000437": 100, "00000438": 100, "00000439": 100, "00000440": 100, "00000441": 100, "00000442": 100, "00000443": 100, "00000444": 100, "00000445": 100, "00000446": 100, "00000447": 100, "00000448": 100, "00000449": 100, "00000450": 100, "00000451": 100, "00000452": 100, "00000453": 100, "00000454": 100, "00000455": 100, "00000456": 100, "00000457": 100, "00000458": 100, "00000459": 100, "00000460": 100, "00000461": 100, "00000462": 100, "00000463": 100, "00000464": 100, "00000465": 100, "00000466": 100, "00000467": 100, "00000468": 100, "00000469": 100, "00000470": 100, "00000471": 100, "00000472": 100, "00000473": 100, "00000474": 100, "00000475": 100, "00000476": 100, "00000477": 100, "00000478": 100, "00000479": 100, "00000480": 100, "00000481": 100, "00000482": 100, "00000483": 100, "00000484": 100, "00000485": 100, "00000486": 100, "00000487": 100, "00000488": 100, "00000489": 100, "00000490": 100, "00000491": 100, "00000492": 100, "00000493": 100, "00000494": 100, "00000495": 100, "00000496": 100, "00000497": 100, "00000498": 100, "00000499": 100, "00000500": 100, "00000501": 100, "00000502": 100, "00000503": 100, "00000504": 100, "00000505": 100, "00000506": 100, "00000507": 100, "00000508": 100, "00000509": 100, "00000510": 100, "00000511": 100, "00000512": 100, "00000513": 100, "00000514": 100, "00000515": 100, "00000516": 100, "00000517": 100, "00000518": 100, "00000519": 100, "00000520": 100, "00000521": 100, "00000522": 100, "00000523": 100, "00000524": 100, "00000525": 100, "00000526": 100, "00000527": 100, "00000528": 100, "00000529": 100, "00000530": 100, "00000531": 100, "00000532": 100, "00000533": 100, "00000534": 100, "00000535": 100, "00000536": 100, "00000537": 100, "00000538": 100, "00000539": 100, "00000540": 100, "00000541": 100, "00000542": 100, "00000543": 100, "00000544": 100, "00000545": 100, "00000546": 100, "00000547": 100, "00000548": 100, "00000549": 100, "00000550": 100, "00000551": 100, "00000552": 100, "00000553": 100, "00000554": 100, "00000555": 100, "00000556": 100, "00000557": 100, "00000558": 100, "00000559": 100, "00000560": 100, "00000561": 100, "00000562": 100, "00000563": 100, "00000564": 100, "00000565": 100, "00000566": 100, "00000567": 100, "00000568": 100, "00000569": 100, "00000570": 100, "00000571": 100, "00000572": 100, "00000573": 100, "00000574": 100, "00000575": 100, "00000576": 100, "00000577": 100, "00000578": 100, "00000579": 100, "00000580": 100, "00000581": 100, "00000582": 100, "00000583": 100, "00000584": 100, "00000585": 100, "00000586": 100, "00000587": 100, "00000588": 100, "00000589": 100, "00000590": 100, "00000591": 100, "00000592": 100, "00000593": 100, "00000594": 100, "00000595": 100, "00000596": 100, "00000597": 100, "00000598": 100, "00000599": 100, "00000600": 100, "00000601": 100, "00000602": 100, "00000603": 100, "00000604": 100, "00000605": 100, 
"00000606": 100, "00000607": 100, "00000608": 100, "00000609": 100, "00000610": 100, "00000611": 100, "00000612": 100, "00000613": 100, "00000614": 100, "00000615": 100, "00000616": 100, "00000617": 100, "00000618": 100, "00000619": 100, "00000620": 100, "00000621": 100, "00000622": 100, "00000623": 100, "00000624": 100, "00000625": 100, "00000626": 100, "00000627": 100, "00000628": 100, "00000629": 100, "00000630": 100, "00000631": 100, "00000632": 100, "00000633": 100, "00000634": 100, "00000635": 100, "00000636": 100, "00000637": 100, "00000638": 100, "00000639": 100, "00000640": 100, "00000641": 100, "00000642": 100, "00000643": 100, "00000644": 100, "00000645": 100, "00000646": 100, "00000647": 100, "00000648": 100, "00000649": 100, "00000650": 100, "00000651": 100, "00000652": 100, "00000653": 100, "00000654": 100, "00000655": 100, "00000656": 100, "00000657": 100, "00000658": 100, "00000659": 100, "00000660": 100, "00000661": 100, "00000662": 100, "00000663": 100, "00000664": 100, "00000665": 100, "00000666": 100, "00000667": 100, "00000668": 100, "00000669": 100, "00000670": 100, "00000671": 100, "00000672": 100, "00000673": 100, "00000674": 100, "00000675": 100, "00000676": 100, "00000677": 100, "00000678": 100, "00000679": 100, "00000680": 100, "00000681": 100, "00000682": 100, "00000683": 100, "00000684": 100, "00000685": 100, "00000686": 100, "00000687": 100, "00000688": 100, "00000689": 100, "00000690": 100, "00000691": 100, "00000692": 100, "00000693": 100, "00000694": 100, "00000695": 100, "00000696": 100, "00000697": 100, "00000698": 100, "00000699": 100, "00000700": 100, "00000701": 100, "00000702": 100, "00000703": 100, "00000704": 100, "00000705": 100, "00000706": 100, "00000707": 100, "00000708": 100, "00000709": 100, "00000710": 100, "00000711": 100, "00000712": 100, "00000713": 100, "00000714": 100, "00000715": 100, "00000716": 100, "00000717": 100, "00000718": 100, "00000719": 100, "00000720": 100, "00000721": 100} -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import cv2 3 | import numpy as np 4 | import importlib 5 | import os 6 | import time 7 | import json 8 | import random 9 | import argparse 10 | from PIL import Image 11 | 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | from core.dataset import TestDataset 16 | from core.metrics import calc_psnr_and_ssim, calculate_i3d_activations, calculate_vfid, init_i3d_model 17 | 18 | # global variables 19 | # w h can be changed by args.output_size 20 | w, h = 432, 240 # default acc. test setting in e2fgvi for davis dataset and KITTI-EXO 21 | # w, h = 336, 336 # default acc. 
test setting for KITTI-EXI
 22 | ref_length = 10  # stride between non-local frames: here, one NLF is taken every 10 frames
 23 | 
 24 | 
 25 | def read_cfg(args):
 26 |     """read flowlens cfg from config file"""
 27 |     # loading configs
 28 |     config = json.load(open(args.cfg_path))
 29 | 
 30 |     # # # # pass config to args # # # #
 31 |     args.dataset = config['train_data_loader']['name']
 32 |     args.data_root = config['train_data_loader']['data_root']
 33 |     args.output_size = [432, 240]
 34 |     args.output_size[0], args.output_size[1] = (config['train_data_loader']['w'], config['train_data_loader']['h'])
 35 |     args.model_win_size = config['model'].get('window_size', None)
 36 |     args.model_output_size = config['model'].get('output_size', None)
 37 |     args.neighbor_stride = config['train_data_loader'].get('num_local_frames', 10)
 38 | 
 39 |     # whether to use SpyNet as the flow-completion network (FlowLens-S)
 40 |     config['model']['spy_net'] = config['model'].get('spy_net', 0)
 41 |     if config['model']['spy_net'] != 0:
 42 |         # default for FlowLens-S
 43 |         args.spy_net = True
 44 |     else:
 45 |         # default for FlowLens
 46 |         args.spy_net = False
 47 | 
 48 |     if config['model']['net'] == 'flowlens':
 49 | 
 50 |         # depth of the transformer
 51 |         if config['model']['depths'] != 0:
 52 |             args.depths = config['model']['depths']
 53 |         else:
 54 |             # use the network's default depth
 55 |             args.depths = None
 56 | 
 57 |         # number of windows per transformer block (tokens divided by the window partition size)
 58 |         config['model']['window_size'] = config['model'].get('window_size', 0)
 59 |         if config['model']['window_size'] != 0:
 60 |             args.window_size = config['model']['window_size']
 61 |         else:
 62 |             # use the network's default window size
 63 |             args.window_size = None
 64 | 
 65 |         # large or small model variant
 66 |         if config['model']['small_model'] != 0:
 67 |             args.small_model = True
 68 |         else:
 69 |             args.small_model = False
 70 | 
 71 |     # whether to freeze the DCN parameters
 72 |     config['model']['freeze_dcn'] = config['model'].get('freeze_dcn', 0)
 73 |     if config['model']['freeze_dcn'] != 0:
 74 |         args.freeze_dcn = True
 75 |     else:
 76 |         # default
 77 |         args.freeze_dcn = False
 78 | 
 79 |     # # # # pass config to args # # # #
 80 | 
 81 |     return args
 82 | 
 83 | 
 84 | # sample reference frames from the whole video with mem support
 85 | def get_ref_index_mem(length, neighbor_ids, same_id=False):
 86 |     """same_id (bool): If True, allow the same ref and local ids as input."""
 87 |     ref_index = []
 88 |     for i in range(0, length, ref_length):
 89 |         if same_id:
 90 |             # duplicate ids are allowed
 91 |             ref_index.append(i)
 92 |         else:
 93 |             # duplicate ids are not allowed; on a collision, find the nearest unused index i
 94 |             if i not in neighbor_ids:
 95 |                 ref_index.append(i)
 96 |             else:
 97 |                 lf_id_avg = sum(neighbor_ids)/len(neighbor_ids)  # mean of the local frame ids
 98 |                 for _iter in range(0, 100):
 99 |                     if i < (length - 1):
100 |                         # must not run past the end of the video
101 |                         if i == 0:
102 |                             # frame 0 collides: jump straight to the next NLF; the extra neighbor_stride offset avoids colliding with the next adjusted NLF id
103 |                             i = ref_length + args.neighbor_stride
104 |                             ref_index.append(i)
105 |                             break
106 |                         elif i < lf_id_avg:
107 |                             # search backwards for a non-duplicate reference frame, so that searches do not all move in one direction and collide
108 |                             i -= 1
109 |                         else:
110 |                             # search forwards for a non-duplicate reference frame
111 |                             i += 1
112 |                     else:
113 |                         # out of range: just take the last frame and stop
114 |                         ref_index.append(i)
115 |                         break
116 | 
117 |                     if i not in neighbor_ids:
118 |                         ref_index.append(i)
119 |                         break
120 | 
121 |     return ref_index
122 | 
123 | 
124 | # sample reference frames from the remaining frames, with the same random behavior as training
125 | def get_ref_index_mem_random(neighbor_ids, video_length, num_ref_frame=3, before_nlf=False):
126 |     if not before_nlf:
127 |         # sample non-local frames from both the past and the future
128 |         complete_idx_set = list(range(video_length))
129 |     else:
130 |         # non-local frames are drawn only from past frames; no future information is used
131 |         complete_idx_set = list(range(neighbor_ids[-1]))
132 |         # complete_idx_set = list(range(video_length))
133 | 
134 |     remain_idx = list(set(complete_idx_set) - 
set(neighbor_ids))
135 | 
136 |     # when only past frames serve as non-local frames, there may be fewer past frames than required, e.g. at the very start of a video
137 |     if before_nlf:
138 |         if len(remain_idx) < num_ref_frame:
139 |             # in that case we allow sampling non-local frames from the local frames; converting to a set removes duplicates
140 |             remain_idx = list(set(remain_idx + neighbor_ids))
141 | 
142 |     ref_index = sorted(random.sample(remain_idx, num_ref_frame))
143 |     return ref_index
144 | 
145 | 
146 | def main_worker(args):
147 |     args = read_cfg(args=args)  # read all network settings
148 |     w = args.output_size[0]
149 |     h = args.output_size[1]
150 |     args.size = (w, h)
151 | 
152 |     # set up datasets and data loader
153 |     # default result
154 |     test_dataset = TestDataset(args)
155 | 
156 |     test_loader = DataLoader(test_dataset,
157 |                              batch_size=1,
158 |                              shuffle=False,
159 |                              num_workers=args.num_workers)
160 | 
161 |     # set up models
162 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
163 |     net = importlib.import_module('model.' + args.model)
164 | 
165 |     if args.model == 'flowlens':
166 |         model = net.InpaintGenerator(freeze_dcn=args.freeze_dcn, spy_net=args.spy_net, depths=args.depths,
167 |                                      window_size=args.model_win_size, output_size=args.model_output_size,
168 |                                      small_model=args.small_model).to(device)
169 |     else:
170 |         # load the size / window settings
171 |         model = net.InpaintGenerator(window_size=args.model_win_size, output_size=args.model_output_size).to(device)
172 | 
173 |     if args.ckpt is not None:
174 |         data = torch.load(args.ckpt, map_location=device)
175 |         model.load_state_dict(data)
176 |         print(f'Loading from: {args.ckpt}')
177 | 
178 |     # # half
179 |     # model = model.half()
180 | 
181 |     model.eval()
182 | 
183 |     total_frame_psnr = []
184 |     total_frame_ssim = []
185 | 
186 |     output_i3d_activations = []
187 |     real_i3d_activations = []
188 | 
189 |     print('Start evaluation...')
190 | 
191 |     time_all = 0
192 |     len_all = 0
193 | 
194 |     # create results directory
195 |     if args.ckpt is not None:
196 |         ckpt = args.ckpt.split('/')[-1]
197 |     else:
198 |         ckpt = 'random'
199 | 
200 |     if args.fov is not None:
201 |         if args.reverse:
202 |             result_path = os.path.join('results', f'{args.model}+_{ckpt}_{args.fov}_{args.dataset}')
203 |         else:
204 |             result_path = os.path.join('results', f'{args.model}_{ckpt}_{args.fov}_{args.dataset}')
205 |     else:
206 |         if args.reverse:
207 |             result_path = os.path.join('results', f'{args.model}+_{ckpt}_{args.dataset}')
208 |         else:
209 |             result_path = os.path.join('results', f'{args.model}_{ckpt}_{args.dataset}')
210 | 
222 |     if not os.path.exists(result_path):
223 |         os.makedirs(result_path)
224 |     eval_summary = open(
225 |         os.path.join(result_path, f"{args.model}_{args.dataset}_metrics.txt"),
226 |         "w")
227 | 
228 |     i3d_model = init_i3d_model()
229 | 
230 |     for index, items in enumerate(test_loader):
231 | 
232 |         for blk in model.transformer:
233 |             try:
234 |                 blk.attn.m_k = []
235 |                 blk.attn.m_v = []
236 |             except AttributeError:
                    # not every block carries the memory caches m_k / m_v
237 |                 pass
238 | 
239 |         frames, masks, video_name, frames_PIL = items
240 | 
241 |         # # half
242 |         # frames = frames.half()
239 |         frames, masks, video_name, frames_PIL = items
240 | 
241 |         # # half
242 |         # frames = frames.half()
243 |         # masks = masks.half()
244 | 
245 |         video_length = frames.size(1)
246 |         frames, masks = frames.to(device), masks.to(device)
247 |         ori_frames = frames_PIL  # the original frames, treated as ground truth
248 |         ori_frames = [
249 |             ori_frames[i].squeeze().cpu().numpy() for i in range(video_length)
250 |         ]
251 |         comp_frames = [None] * video_length  # completed frames
252 | 
253 |         len_all += video_length
254 | 
255 |         # complete holes by our model
256 |         # once this loop finishes, the whole video has been completed
257 |         for f in range(0, video_length, args.neighbor_stride):
258 |             if args.same_memory:
259 |                 # stay as close as possible to the standard video inpainting test protocol
260 |                 # keep the temporal dimension T of the input constant
261 |                 if (f - args.neighbor_stride > 0) and (f + args.neighbor_stride + 1 < video_length):
262 |                     # neither end of the window goes out of bounds, so no extra frames are needed
263 |                     neighbor_ids = [
264 |                         i for i in range(max(0, f - args.neighbor_stride),
265 |                                          min(video_length, f + args.neighbor_stride + 1))
266 |                     ]  # neighbor_ids are the Local Frames
267 |                 else:
268 |                     # the window goes out of bounds: pad with extra frames so the temporal dimension of the memory cache stays constant (copying the features' temporal dimension inside the transformer could be tried later instead)
269 |                     neighbor_ids = [
270 |                         i for i in range(max(0, f - args.neighbor_stride),
271 |                                          min(video_length, f + args.neighbor_stride + 1))
272 |                     ]  # neighbor_ids are the Local Frames
273 |                     repeat_num = (args.neighbor_stride * 2 + 1) - len(neighbor_ids)
274 | 
275 |                     lf_id_avg = sum(neighbor_ids) / len(neighbor_ids)  # mean local frame id
276 |                     first_id = neighbor_ids[0]
277 |                     for ii in range(0, repeat_num):
278 |                         # keep the local window size constant so the number of cache channels does not change
279 |                         if lf_id_avg < (video_length // 2):
280 |                             # in the first half of the video, take ids from the tail so the input is not identical to the next window's
281 |                             new_id = video_length - 1 - ii
282 |                         else:
283 |                             # in the second half, take ids just before the window
284 |                             new_id = first_id - 1 - ii
285 |                         neighbor_ids.append(new_id)
286 | 
287 |                     neighbor_ids = sorted(neighbor_ids)  # re-sort
288 | 
289 |             else:
290 |                 # consistent with the training logic of the memory model
291 |                 if not args.recurrent:
292 |                     if video_length < (f + args.neighbor_stride):
293 |                         neighbor_ids = [
294 |                             i for i in range(f, video_length)
295 |                         ]  # temporally non-overlapping windows, so each local frame is processed exactly once; the tail of the video may hold fewer local frames than required, so the last frame is duplicated to make up the number
296 |                         for repeat_idx in range(0, args.neighbor_stride - len(neighbor_ids)):
297 |                             neighbor_ids.append(neighbor_ids[-1])
298 |                     else:
299 |                         neighbor_ids = [
300 |                             i for i in range(f, f + args.neighbor_stride)
301 |                         ]  # temporally non-overlapping windows, so each local frame is processed exactly once
302 |                 else:
303 |                     # in recurrent mode the local window is always 1
304 |                     neighbor_ids = [f]
305 | 
306 |             # to keep the temporal dimension constant, frames with duplicated ids are allowed as input
307 |             if args.same_memory:
308 |                 ref_ids = get_ref_index_mem(video_length, neighbor_ids, same_id=False)  # ref_ids are the Non-Local Frames
309 |             elif args.past_ref:
310 |                 ref_ids = get_ref_index_mem_random(neighbor_ids, video_length, num_ref_frame=3, before_nlf=True)  # only past reference frames are allowed
311 |             else:
312 |                 ref_ids = get_ref_index_mem_random(neighbor_ids, video_length, num_ref_frame=3)  # the same non-local frame sampling logic as in sequence training
313 | 
314 |             ref_ids = sorted(ref_ids)  # re-sort
315 |             selected_imgs_lf = frames[:1, neighbor_ids, :, :, :]
316 |             selected_imgs_nlf = frames[:1, ref_ids, :, :, :]
317 |             selected_imgs = torch.cat((selected_imgs_lf, selected_imgs_nlf), dim=1)
318 |             selected_masks_lf = masks[:1, neighbor_ids, :, :, :]
319 |             selected_masks_nlf = masks[:1, ref_ids, :, :, :]
320 |             selected_masks = torch.cat((selected_masks_lf, selected_masks_nlf), dim=1)
321 | 
322 |             with torch.no_grad():
323 |                 masked_frames = selected_imgs * (1 - selected_masks)
324 | 
325 |                 torch.cuda.synchronize()
326 |                 time_start = time.time()
327 | 
328 |                 pred_img, _ = model(masked_frames, len(neighbor_ids))  # forward receives the number of local frames so the two frame types can be handled separately
329 | 
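One detail worth calling out in the timing code bracketing the forward pass: CUDA kernels launch asynchronously, so `time.time()` alone would mostly measure launch overhead. The `torch.cuda.synchronize()` calls on both sides of the timed region make the wall-clock difference reflect actual GPU work; distilled into a minimal sketch:

```python
import time
import torch

def timed_forward(model, inp, *extra):
    torch.cuda.synchronize()          # drain pending kernels first
    t0 = time.time()
    with torch.no_grad():
        out = model(inp, *extra)      # kernels are launched asynchronously
    torch.cuda.synchronize()          # wait for the GPU to actually finish
    return out, time.time() - t0
```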
330 |                 # horizontal and vertical flip augmentation
331 |                 if args.reverse:
332 |                     masked_frames_horizontal_aug = torch.from_numpy(masked_frames.cpu().numpy()[:, :, :, :, ::-1].copy()).cuda()
333 |                     pred_img_horizontal_aug, _ = model(masked_frames_horizontal_aug, len(neighbor_ids))
334 |                     pred_img_horizontal_aug = torch.from_numpy(pred_img_horizontal_aug.cpu().numpy()[:, :, :, ::-1].copy()).cuda()
335 |                     masked_frames_vertical_aug = torch.from_numpy(masked_frames.cpu().numpy()[:, :, :, ::-1, :].copy()).cuda()
336 |                     pred_img_vertical_aug, _ = model(masked_frames_vertical_aug, len(neighbor_ids))
337 |                     pred_img_vertical_aug = torch.from_numpy(pred_img_vertical_aug.cpu().numpy()[:, :, ::-1, :].copy()).cuda()
338 | 
339 |                     pred_img = 1 / 3 * (pred_img + pred_img_horizontal_aug + pred_img_vertical_aug)
340 | 
341 |                 torch.cuda.synchronize()
342 |                 time_end = time.time()
343 |                 time_sum = time_end - time_start
344 |                 time_all += time_sum
345 | 
346 |                 pred_img = (pred_img + 1) / 2
347 |                 pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
348 |                 binary_masks = masks[0, neighbor_ids, :, :, :].cpu().permute(
349 |                     0, 2, 3, 1).numpy().astype(np.uint8)
350 |                 for i in range(len(neighbor_ids)):
351 |                     idx = neighbor_ids[i]
352 |                     img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
353 |                         + ori_frames[idx] * (1 - binary_masks[i])
354 | 
355 |                     if comp_frames[idx] is None:
356 |                         # the first time a frame among the Local Frames is completed, store it directly in the completed-frame list (comp_frames)
357 |                         # with good_fusion every img carries one extra 'count' axis that accumulates all results
358 |                         comp_frames[idx] = img[np.newaxis, :, :, :]
359 | 
360 |                     # record every result, then average along that axis at the end
361 |                     else:
362 |                         comp_frames[idx] = np.concatenate((comp_frames[idx], img[np.newaxis, :, :, :]), axis=0)
363 |         ########################################################################################
364 | 
365 |         # for good_fusion, average along axis=0 after the inference pass
366 |         for idx, comp_frame in zip(range(0, video_length), comp_frames):
367 |             comp_frame = comp_frame.astype(np.float32).sum(axis=0) / comp_frame.shape[0]
368 |             comp_frames[idx] = comp_frame
369 | 
370 |         # calculate metrics
371 |         cur_video_psnr = []
372 |         cur_video_ssim = []
373 |         comp_PIL = []  # to calculate VFID
374 |         frames_PIL = []
375 |         for ori, comp in zip(ori_frames, comp_frames):
376 |             psnr, ssim = calc_psnr_and_ssim(ori, comp)
377 | 
378 |             cur_video_psnr.append(psnr)
379 |             cur_video_ssim.append(ssim)
380 | 
381 |             total_frame_psnr.append(psnr)
382 |             total_frame_ssim.append(ssim)
383 | 
384 |             frames_PIL.append(Image.fromarray(ori.astype(np.uint8)))
385 |             comp_PIL.append(Image.fromarray(comp.astype(np.uint8)))
386 |         cur_psnr = sum(cur_video_psnr) / len(cur_video_psnr)
387 |         cur_ssim = sum(cur_video_ssim) / len(cur_video_ssim)
388 | 
389 |         # saving i3d activations
390 |         frames_i3d, comp_i3d = calculate_i3d_activations(frames_PIL,
391 |                                                          comp_PIL,
392 |                                                          i3d_model,
393 |                                                          device=device)
394 |         real_i3d_activations.append(frames_i3d)
395 |         output_i3d_activations.append(comp_i3d)
396 | 
397 |         print(
398 |             f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | PSNR/SSIM: {cur_psnr:.4f}/{cur_ssim:.4f}'
399 |         )
400 |         eval_summary.write(
401 |             f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | PSNR/SSIM: {cur_psnr:.4f}/{cur_ssim:.4f}\n'
402 |         )
403 | 
404 |         print('Average run time: (%f) per frame' % (time_all / len_all))
405 | 
406 |         # saving images for evaluating warping errors
407 |         if args.save_results:
408 |             save_frame_path = os.path.join(result_path, video_name[0])
409 |             os.makedirs(save_frame_path, exist_ok=False)
410 | 
411 |             for i, frame in enumerate(comp_frames):
412 |                 cv2.imwrite(
413 |                     os.path.join(save_frame_path,
414 |                                  str(i).zfill(5) + '.png'),
415 |                     cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR))
416 | 
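`calc_psnr_and_ssim` is imported from `core/metrics.py`, which this dump does not include; for orientation, the PSNR half of such a helper usually reduces to the following (a sketch, not the repo's exact implementation):

```python
import numpy as np

def psnr_uint8(gt, pred, peak=255.0):
    """PSNR between two uint8 frames; inf when the frames are identical."""
    mse = np.mean((gt.astype(np.float64) - pred.astype(np.float64)) ** 2)
    return float('inf') if mse == 0 else 10.0 * np.log10(peak ** 2 / mse)
```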
417 |     avg_frame_psnr = sum(total_frame_psnr) / len(total_frame_psnr)
418 |     avg_frame_ssim = sum(total_frame_ssim) / len(total_frame_ssim)
419 | 
420 |     fid_score = calculate_vfid(real_i3d_activations, output_i3d_activations)
421 |     print('Finish evaluation... Average Frame PSNR/SSIM/VFID: '
422 |           f'{avg_frame_psnr:.2f}/{avg_frame_ssim:.4f}/{fid_score:.3f}')
423 |     eval_summary.write(
424 |         'Finish evaluation... Average Frame PSNR/SSIM/VFID: '
425 |         f'{avg_frame_psnr:.2f}/{avg_frame_ssim:.4f}/{fid_score:.3f}')
426 |     eval_summary.close()
427 | 
428 |     print('All average forward run time: (%f) per frame' % (time_all / len_all))
429 | 
430 |     return len(total_frame_psnr)
431 | 
432 | 
433 | if __name__ == '__main__':
434 |     parser = argparse.ArgumentParser(description='FlowLens')
435 |     parser.add_argument('--cfg_path', default='configs/KITTI360EX-I_FlowLens_small_re.json')
436 |     parser.add_argument('--dataset', choices=['KITTI360-EX'], type=str)  # equivalent to 'name' in the training config
437 |     parser.add_argument('--data_root', type=str)
438 |     parser.add_argument('--output_size', type=int, nargs='+', default=[432, 240])
439 |     parser.add_argument('--object', action='store_true', default=False)  # if true, use object removal mask
440 |     parser.add_argument('--fov', choices=['fov5', 'fov10', 'fov20'], type=str)  # for KITTI360-EX, the FoV must be given at test time
441 |     parser.add_argument('--past_ref', action='store_true', default=True)  # for KITTI360-EX, only past reference frames are allowed at test time
442 |     parser.add_argument('--model', choices=['flowlens'], type=str)
443 |     parser.add_argument('--ckpt', type=str, default=None)
444 |     parser.add_argument('--save_results', action='store_true', default=False)
445 |     parser.add_argument('--num_workers', default=4, type=int)
446 |     parser.add_argument('--same_memory', action='store_true', default=False,
447 |                         help='test with memory ability in video in-painting style')
448 |     parser.add_argument('--reverse', action='store_true', default=False,
449 |                         help='test with horizontal and vertical reverse augmentation')
450 |     parser.add_argument('--model_win_size', type=int, nargs='+', default=[5, 9])
451 |     parser.add_argument('--model_output_size', type=int, nargs='+', default=[60, 108])
452 |     parser.add_argument('--recurrent', action='store_true', default=False,
453 |                         help='keep window = 1, stride = 1 to not use any local future info')
454 |     args = parser.parse_args()
455 | 
456 |     if args.dataset == 'KITTI360-EX':
457 |         # for KITTI360-EX, only past reference frames are allowed at test time
458 |         args.past_ref = True
459 | 
460 |     frame_num = main_worker(args)
461 | 
--------------------------------------------------------------------------------
/model/flowlens.py: 
--------------------------------------------------------------------------------
 1 | ''' Towards An End-to-End Framework for Video Inpainting
 2 | '''
 3 | 
 4 | import torch
 5 | import torch.nn as nn
 6 | import torch.nn.functional as F
 7 | 
 8 | from model.modules.flow_comp import SPyNet
 9 | from model.modules.maskflownets import MaskFlowNetS
10 | from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
11 | from model.modules.mix_focal_transformer import MixFocalTransformerBlock, SoftSplit, SoftComp
12 | from model.modules.spectral_norm import spectral_norm as _spectral_norm
13 | 
14 | 
15 | class BaseNetwork(nn.Module):
16 |     def __init__(self):
17 |         super(BaseNetwork, self).__init__()
18 | 
19 |     def print_network(self):
20 |         if isinstance(self, list):
21 |             self = self[0]
22 |         num_params = 0
23 |         for param in self.parameters():
24 |             num_params += param.numel()
25 |         print(
26 |             'Network [%s] was created. Total number of parameters: %.1f million. '
27 |             'To see the architecture, do print(network).' 
% 28 | (type(self).__name__, num_params / 1000000)) 29 | 30 | def init_weights(self, init_type='normal', gain=0.02): 31 | ''' 32 | initialize network's weights 33 | init_type: normal | xavier | kaiming | orthogonal 34 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 35 | ''' 36 | def init_func(m): 37 | classname = m.__class__.__name__ 38 | if classname.find('InstanceNorm2d') != -1: 39 | if hasattr(m, 'weight') and m.weight is not None: 40 | nn.init.constant_(m.weight.data, 1.0) 41 | if hasattr(m, 'bias') and m.bias is not None: 42 | nn.init.constant_(m.bias.data, 0.0) 43 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 44 | or classname.find('Linear') != -1): 45 | if init_type == 'normal': 46 | nn.init.normal_(m.weight.data, 0.0, gain) 47 | elif init_type == 'xavier': 48 | nn.init.xavier_normal_(m.weight.data, gain=gain) 49 | elif init_type == 'xavier_uniform': 50 | nn.init.xavier_uniform_(m.weight.data, gain=1.0) 51 | elif init_type == 'kaiming': 52 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 53 | elif init_type == 'orthogonal': 54 | nn.init.orthogonal_(m.weight.data, gain=gain) 55 | elif init_type == 'none': # uses pytorch's default init method 56 | m.reset_parameters() 57 | else: 58 | raise NotImplementedError( 59 | 'initialization method [%s] is not implemented' % 60 | init_type) 61 | if hasattr(m, 'bias') and m.bias is not None: 62 | nn.init.constant_(m.bias.data, 0.0) 63 | 64 | self.apply(init_func) 65 | 66 | # propagate to children 67 | for m in self.children(): 68 | if hasattr(m, 'init_weights'): 69 | m.init_weights(init_type, gain) 70 | 71 | 72 | class Encoder(nn.Module): 73 | def __init__(self, out_channel=128, reduction=1): 74 | super(Encoder, self).__init__() 75 | self.group = [1, 2, 4, 8, 1] 76 | self.layers = nn.ModuleList([ 77 | nn.Conv2d(3, 64//reduction, kernel_size=3, stride=2, padding=1), 78 | nn.LeakyReLU(0.2, inplace=True), 79 | nn.Conv2d(64//reduction, 64//reduction, kernel_size=3, stride=1, padding=1), 80 | nn.LeakyReLU(0.2, inplace=True), 81 | nn.Conv2d(64//reduction, 128//reduction, kernel_size=3, stride=2, padding=1), 82 | nn.LeakyReLU(0.2, inplace=True), 83 | nn.Conv2d(128//reduction, 256//reduction, kernel_size=3, stride=1, padding=1), 84 | nn.LeakyReLU(0.2, inplace=True), 85 | nn.Conv2d(256//reduction, 384//reduction, kernel_size=3, stride=1, padding=1, groups=1), 86 | nn.LeakyReLU(0.2, inplace=True), 87 | nn.Conv2d(640//reduction, 512//reduction, kernel_size=3, stride=1, padding=1, groups=2), 88 | nn.LeakyReLU(0.2, inplace=True), 89 | nn.Conv2d(768//reduction, 384//reduction, kernel_size=3, stride=1, padding=1, groups=4), 90 | nn.LeakyReLU(0.2, inplace=True), 91 | nn.Conv2d(640//reduction, 256//reduction, kernel_size=3, stride=1, padding=1, groups=8), 92 | nn.LeakyReLU(0.2, inplace=True), 93 | nn.Conv2d(512//reduction, out_channel, kernel_size=3, stride=1, padding=1, groups=1), 94 | nn.LeakyReLU(0.2, inplace=True) 95 | ]) 96 | 97 | def forward(self, x): 98 | bt, c, _, _ = x.size() 99 | # h, w = h//4, w//4 100 | out = x 101 | for i, layer in enumerate(self.layers): 102 | if i == 8: 103 | x0 = out 104 | _, _, h, w = x0.size() 105 | if i > 8 and i % 2 == 0: 106 | g = self.group[(i - 8) // 2] 107 | x = x0.view(bt, g, -1, h, w) 108 | o = out.view(bt, g, -1, h, w) 109 | out = torch.cat([x, o], 2).view(bt, -1, h, w) 110 | out = layer(out) 111 | return out 112 | 113 | 114 | class deconv(nn.Module): 115 | def __init__(self, 116 | input_channel, 117 | 
output_channel,
118 |                  kernel_size=3,
119 |                  padding=0):
120 |         super().__init__()
121 |         self.conv = nn.Conv2d(input_channel,
122 |                               output_channel,
123 |                               kernel_size=kernel_size,
124 |                               stride=1,
125 |                               padding=padding)
126 | 
127 |     def forward(self, x):
128 |         x = F.interpolate(x,
129 |                           scale_factor=2,
130 |                           mode='bilinear',
131 |                           align_corners=True)
132 |         return self.conv(x)
133 | 
134 | 
135 | class InpaintGenerator(BaseNetwork):
136 |     """
137 |     window_size  # window size, equal to the token resolution divided by 4
138 |     output_size  # output size, i.e. the training resolution // 4
139 |     spy_net      # if True, use SPyNet instead of MaskFlowNetS to compute optical flow
140 |     freeze_dcn   # if True, freeze the DCN parameters
141 |     """
142 |     def __init__(self, init_weights=True, freeze_dcn=False, spy_net=False,
143 |                  flow_res=0.25,
144 |                  depths=9,
145 |                  window_size=None, output_size=None, small_model=False):
146 |         super(InpaintGenerator, self).__init__()
147 | 
148 |         if not small_model:
149 |             # large model:
150 |             channel = 256  # default
151 |             hidden = 512  # default
152 |             reduction = 1  # default
153 |         else:
154 |             # v2
155 |             channel = 128
156 |             hidden = 256
157 |             reduction = 2
158 | 
159 |         # transformer settings
160 |         # number of transformer blocks
161 |         if depths is None:
162 |             depths = 2  # 0.08s/frame, 0.07s/frame with hidden = 128,
163 |         else:
164 |             depths = depths
165 | 
166 |         # a single stage only
167 |         # number of heads per layer
168 |         # by default every layer uses 4 heads, i.e. 2 along width and 2 along height
169 |         num_heads = [4] * depths
170 | 
171 |         # resolution at which optical flow is inferred
172 |         self.flow_res = flow_res
173 | 
174 |         # encoder
175 |         self.encoder = Encoder(out_channel=channel // 2, reduction=reduction)
176 | 
177 |         # decoder-default
178 |         self.decoder = nn.Sequential(
179 |             deconv(channel // 2, 128, kernel_size=3, padding=1),
180 |             nn.LeakyReLU(0.2, inplace=True),
181 |             nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
182 |             nn.LeakyReLU(0.2, inplace=True),
183 |             deconv(64, 64, kernel_size=3, padding=1),
184 |             nn.LeakyReLU(0.2, inplace=True),
185 |             nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
186 | 
187 |         # feature propagation module
188 |         self.feat_prop_module = BidirectionalPropagation(channel // 2, freeze_dcn=freeze_dcn)
189 | 
190 |         # soft split and soft composition
191 |         kernel_size = (7, 7)  # sliding-window (patch) size
192 |         padding = (3, 3)  # implicit zero padding in both directions
193 |         stride = (3, 3)  # sliding-window stride
194 |         if output_size is None:
195 |             # default output size
196 |             output_size = (60, 108)
197 |         else:
198 |             output_size = (output_size[0], output_size[1])
199 |         t2t_params = {
200 |             'kernel_size': kernel_size,
201 |             'stride': stride,
202 |             'padding': padding
203 |         }
204 | 
205 |         self.ss = SoftSplit(channel // 2,
206 |                             hidden,
207 |                             kernel_size,
208 |                             stride,
209 |                             padding,
210 |                             t2t_param=t2t_params)
211 | 
212 |         self.sc = SoftComp(channel // 2, hidden, kernel_size, stride, padding)
213 | 
214 |         n_vecs = 1  # count the number of tokens
215 |         for i, d in enumerate(kernel_size):
216 |             n_vecs *= int((output_size[i] + 2 * padding[i] -
217 |                            (d - 1) - 1) / stride[i] + 1)
218 | 
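As a worked check of the token-count loop above, plugging in this file's defaults (`output_size=(60, 108)`, `kernel_size=(7, 7)`, `stride=(3, 3)`, `padding=(3, 3)`):

```python
# Same arithmetic as the loop above, written out with the default values:
n_vecs = 1
for size, k, p, s in zip((60, 108), (7, 7), (3, 3), (3, 3)):
    n_vecs *= (size + 2 * p - (k - 1) - 1) // s + 1   # 20 along h, 36 along w
print(n_vecs)  # 20 * 36 = 720 tokens per frame
```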
219 |         blocks = []
220 |         if window_size is None:
221 |             window_size = [(5, 9)] * depths
222 |             focal_windows = [(5, 9)] * depths
223 |         else:
224 |             window_size_h = window_size[0]
225 |             window_size_w = window_size[1]
226 |             window_size = [(window_size_h, window_size_w)] * depths
227 |             focal_windows = [(window_size_h, window_size_w)] * depths
228 |         focal_levels = [2] * depths
229 |         pool_method = "fc"
230 | 
231 |         # default temporal focal transformer
232 |         for i in range(depths):
233 |             # only the first block carries memory
234 |             if (i + 1) == 1:
235 |                 # first block: with memory
236 |                 blocks.append(
237 |                     MixFocalTransformerBlock(dim=hidden,
238 |                                              num_heads=num_heads[i],
239 |                                              window_size=window_size[i],
240 |                                              focal_level=focal_levels[i],
241 |                                              focal_window=focal_windows[i],
242 |                                              n_vecs=n_vecs,
243 |                                              t2t_params=t2t_params,
244 |                                              pool_method=pool_method,
245 |                                              memory=True,
246 |                                              cs_win_strip=1), )
247 | 
248 |             else:
249 |                 # later blocks: without memory
250 |                 blocks.append(
251 |                     MixFocalTransformerBlock(dim=hidden,
252 |                                              num_heads=num_heads[i],
253 |                                              window_size=window_size[i],
254 |                                              focal_level=focal_levels[i],
255 |                                              focal_window=focal_windows[i],
256 |                                              n_vecs=n_vecs,
257 |                                              t2t_params=t2t_params,
258 |                                              pool_method=pool_method,
259 |                                              memory=False))
260 | 
261 |         self.transformer = nn.Sequential(*blocks)
262 | 
263 |         if init_weights:
264 |             self.init_weights()
265 |             # Need to initialize the weights of MSDeformAttn specifically
266 |             for m in self.modules():
267 |                 if isinstance(m, SecondOrderDeformableAlignment):
268 |                     m.init_offset()
269 | 
270 |         # flow completion network
271 |         if spy_net:
272 |             # use SPyNet
273 |             self.update_MFN = SPyNet()
274 |         else:
275 |             # use MaskFlowNetS
276 |             self.update_MFN = MaskFlowNetS()
277 | 
278 |     def forward_bidirect_flow(self, masked_local_frames):
279 |         b, l_t, c, h, w = masked_local_frames.size()
280 |         scale_factor = int(1 / self.flow_res)  # 1/4 -> 4, used to restore the scale
281 | 
282 |         # compute forward and backward flows of masked frames
283 |         masked_local_frames = F.interpolate(masked_local_frames.view(
284 |             -1, c, h, w),
285 |                                             scale_factor=self.flow_res,  # 1/4 for default
286 |                                             mode='bilinear',
287 |                                             align_corners=True,
288 |                                             recompute_scale_factor=True)
289 |         masked_local_frames = masked_local_frames.view(b, l_t, c, h // scale_factor,
290 |                                                        w // scale_factor)
291 |         mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
292 |             -1, c, h // scale_factor, w // scale_factor)
293 |         mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
294 |             -1, c, h // scale_factor, w // scale_factor)
295 |         pred_flows_forward = self.update_MFN(mlf_1, mlf_2)
296 |         pred_flows_backward = self.update_MFN(mlf_2, mlf_1)
297 | 
298 |         pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // scale_factor,
299 |                                                      w // scale_factor)
300 |         pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // scale_factor,
301 |                                                        w // scale_factor)
302 | 
303 |         # in the end the flows must be resized to 1/4 resolution for feature warping
304 |         if scale_factor != 4:
305 |             pred_flows_forward = F.interpolate(pred_flows_forward.view(-1, 2, h // scale_factor, w // scale_factor),
306 |                                                scale_factor=scale_factor / 4,
307 |                                                mode='bilinear',
308 |                                                align_corners=True,
309 |                                                recompute_scale_factor=True).view(b, l_t - 1, 2, h // 4,
310 |                                                                                  w // 4)
311 |             pred_flows_backward = F.interpolate(pred_flows_backward.view(-1, 2, h // scale_factor, w // scale_factor),
312 |                                                 scale_factor=scale_factor / 4,
313 |                                                 mode='bilinear',
314 |                                                 align_corners=True,
315 |                                                 recompute_scale_factor=True).view(b, l_t - 1, 2, h // 4,
316 |                                                                                   w // 4)
317 | 
318 |         return pred_flows_forward, pred_flows_backward
319 | 
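For orientation, here is a shape walk-through of `forward_bidirect_flow` under assumed settings (batch 1, five local frames, 240x432 input, the default `flow_res=0.25`); the values are illustrative, not from the repo, and the call assumes the pretrained flow weights are reachable:

```python
# Hypothetical shape check:
#   masked_local_frames : (1, 5, 3, 240, 432)
#   after interpolation : (1, 5, 3, 60, 108)   # downsampled by flow_res
#   mlf_1 / mlf_2       : (4, 3, 60, 108)      # consecutive frame pairs
#   returned flows      : (1, 4, 2, 60, 108)   # l_t - 1 maps per direction
net = InpaintGenerator(spy_net=True)
fwd, bwd = net.forward_bidirect_flow(torch.rand(1, 5, 3, 240, 432))
```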
320 |     def forward(self, masked_frames, num_local_frames=5):
321 |         l_t = num_local_frames
322 |         b, t, ori_c, ori_h, ori_w = masked_frames.size()
323 | 
324 |         # normalization before feeding into the flow completion module
325 |         masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
326 |         pred_flows = self.forward_bidirect_flow(masked_local_frames)
327 | 
328 |         # extracting features and performing the feature propagation on local features
329 |         enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
330 |         _, c, h, w = enc_feat.size()
331 |         fold_output_size = (h, w)
332 |         local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
333 |         ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
334 | 
335 |         local_feat = self.feat_prop_module(local_feat, pred_flows[1],
336 |                                            pred_flows[0])
337 | 
338 |         enc_feat = torch.cat((local_feat, ref_feat), dim=1)
339 | 
340 |         # content hallucination through stacking multiple temporal focal transformer blocks
341 |         trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_output_size)  # [B, t, f_h, f_w, hidden]
342 | 
343 |         trans_feat = self.transformer([trans_feat, fold_output_size, l_t])  # passes one extra argument (l_t) compared with the default behavior
344 | 
345 |         # soft composition
346 |         trans_feat = self.sc(trans_feat[0], t, fold_output_size)
347 | 
348 |         trans_feat = trans_feat.view(b, t, -1, h, w)
349 |         enc_feat = enc_feat + trans_feat  # residual connection
350 | 
351 |         # decode frames from features
352 |         output = self.decoder(enc_feat.view(b * t, c, h, w))
353 |         output = torch.tanh(output)
354 |         return output, pred_flows
355 | 
356 | 
357 | # ######################################################################
358 | #  Discriminator for Temporal Patch GAN
359 | # ######################################################################
360 | 
361 | 
362 | class Discriminator(BaseNetwork):
363 |     def __init__(self,
364 |                  in_channels=3,
365 |                  use_sigmoid=False,
366 |                  use_spectral_norm=True,
367 |                  init_weights=True):
368 |         super(Discriminator, self).__init__()
369 |         self.use_sigmoid = use_sigmoid
370 |         nf = 32
371 | 
372 |         self.conv = nn.Sequential(
373 |             spectral_norm(
374 |                 nn.Conv3d(in_channels=in_channels,
375 |                           out_channels=nf * 1,
376 |                           kernel_size=(3, 5, 5),
377 |                           stride=(1, 2, 2),
378 |                           padding=1,
379 |                           bias=not use_spectral_norm), use_spectral_norm),
380 |             # nn.InstanceNorm2d(64, track_running_stats=False),
381 |             nn.LeakyReLU(0.2, inplace=True),
382 |             spectral_norm(
383 |                 nn.Conv3d(nf * 1,
384 |                           nf * 2,
385 |                           kernel_size=(3, 5, 5),
386 |                           stride=(1, 2, 2),
387 |                           padding=(1, 2, 2),
388 |                           bias=not use_spectral_norm), use_spectral_norm),
389 |             # nn.InstanceNorm2d(128, track_running_stats=False),
390 |             nn.LeakyReLU(0.2, inplace=True),
391 |             spectral_norm(
392 |                 nn.Conv3d(nf * 2,
393 |                           nf * 4,
394 |                           kernel_size=(3, 5, 5),
395 |                           stride=(1, 2, 2),
396 |                           padding=(1, 2, 2),
397 |                           bias=not use_spectral_norm), use_spectral_norm),
398 |             # nn.InstanceNorm2d(256, track_running_stats=False),
399 |             nn.LeakyReLU(0.2, inplace=True),
400 |             spectral_norm(
401 |                 nn.Conv3d(nf * 4,
402 |                           nf * 4,
403 |                           kernel_size=(3, 5, 5),
404 |                           stride=(1, 2, 2),
405 |                           padding=(1, 2, 2),
406 |                           bias=not use_spectral_norm), use_spectral_norm),
407 |             # nn.InstanceNorm2d(256, track_running_stats=False),
408 |             nn.LeakyReLU(0.2, inplace=True),
409 |             spectral_norm(
410 |                 nn.Conv3d(nf * 4,
411 |                           nf * 4,
412 |                           kernel_size=(3, 5, 5),
413 |                           stride=(1, 2, 2),
414 |                           padding=(1, 2, 2),
415 |                           bias=not use_spectral_norm), use_spectral_norm),
416 |             # nn.InstanceNorm2d(256, track_running_stats=False),
417 |             nn.LeakyReLU(0.2, inplace=True),
418 |             nn.Conv3d(nf * 4,
419 |                       nf * 4,
420 |                       kernel_size=(3, 5, 5),
421 |                       stride=(1, 2, 2),
422 |                       padding=(1, 2, 2)))
423 | 
424 |         if init_weights:
425 |             self.init_weights()
426 | 
427 |     def forward(self, xs):
428 |         # T, C, H, W = xs.shape (old)
429 |         # B, T, C, H, W (new)
430 |         xs_t = torch.transpose(xs, 1, 2)
431 |         feat = self.conv(xs_t)
432 |         if self.use_sigmoid:
433 |             feat = torch.sigmoid(feat)
434 |         out = torch.transpose(feat, 1, 2)  # B, T, C, H, W
435 |         return out
436 | 
437 | 
438 | def spectral_norm(module, mode=True):
439 |     if mode:
440 |         return _spectral_norm(module)
441 |     return module
442 | 
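A hypothetical smoke test for the T-PatchGAN critic above (shapes assumed for illustration; the critic emits per-spatio-temporal-patch logits rather than one scalar per video, which is what "patch GAN" refers to):

```python
D = Discriminator(in_channels=3, use_sigmoid=False, use_spectral_norm=True)
with torch.no_grad():
    logits = D(torch.randn(1, 5, 3, 240, 432))  # (B, T, C, H, W) in, patch logits out
```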
--------------------------------------------------------------------------------
/model/modules/feat_prop.py: 
--------------------------------------------------------------------------------
 1 | """
 2 | BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment, CVPR 2022
 3 | """
 4 | import torch
 5 | import torch.nn as nn
 6 | 
 7 | from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d
 8 | from mmcv.cnn import constant_init
 9 | 
10 | from model.modules.flow_comp import flow_warp
11 | 
12 | 
13 | class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
14 |     """Second-order deformable alignment module."""
15 |     def __init__(self, *args, **kwargs):
16 |         self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
17 | 
18 |         super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
19 | 
20 |         self.conv_offset = nn.Sequential(
21 |             nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
22 |             nn.LeakyReLU(negative_slope=0.1, inplace=True),
23 |             nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
24 |             nn.LeakyReLU(negative_slope=0.1, inplace=True),
25 |             nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
26 |             nn.LeakyReLU(negative_slope=0.1, inplace=True),
27 |             nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
28 |         )
29 | 
30 |         self.init_offset()
31 | 
32 |     def init_offset(self):
33 |         constant_init(self.conv_offset[-1], val=0, bias=0)
34 | 
35 |     def forward(self, x, extra_feat, flow_1, flow_2):
36 |         extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
37 |         out = self.conv_offset(extra_feat)
38 |         o1, o2, mask = torch.chunk(out, 3, dim=1)
39 | 
40 |         # offset
41 |         offset = self.max_residue_magnitude * torch.tanh(
42 |             torch.cat((o1, o2), dim=1))
43 |         offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
44 |         offset_1 = offset_1 + flow_1.flip(1).repeat(1,
45 |                                                     offset_1.size(1) // 2, 1,
46 |                                                     1)
47 |         offset_2 = offset_2 + flow_2.flip(1).repeat(1,
48 |                                                     offset_2.size(1) // 2, 1,
49 |                                                     1)
50 |         offset = torch.cat([offset_1, offset_2], dim=1)
51 | 
52 |         # mask
53 |         mask = torch.sigmoid(mask)
54 | 
55 |         # the default second-order alignment uses DCNv2, i.e. modulated deformable convolution; the 'modulation' is an extra image-sized mask with values in [0, 1]
56 |         return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
57 |                                        self.stride, self.padding,
58 |                                        self.dilation, self.groups,
59 |                                        self.deform_groups)
60 | 
61 | 
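For readers unfamiliar with DCNv2's channel packing, the budget of `conv_offset` above works out as follows (inferred from the code, offered as a reading aid rather than a normative spec):

```python
# conv_offset emits 27 * deform_groups channels, chunked into o1, o2, mask.
# o1 and o2 are tanh-bounded offset residues (scaled by max_residue_magnitude)
# for the two propagation branches; flow_1 / flow_2, flipped to (dy, dx)
# order, are added on top, so the DCN only has to learn residual motion
# around the optical-flow initialization. mask is a sigmoid modulation term
# in (0, 1), which is what makes this DCNv2 ("modulated") rather than DCNv1.
```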
 62 | class ConvBNReLU(nn.Module):
 63 |     """Conv with BN and ReLU, used for Simple Second Fusion"""
 64 | 
 65 |     def __init__(self,
 66 |                  in_chan,
 67 |                  out_chan,
 68 |                  ks=3,
 69 |                  stride=1,
 70 |                  padding=1,
 71 |                  *args,
 72 |                  **kwargs):
 73 |         super(ConvBNReLU, self).__init__()
 74 |         self.conv = nn.Conv2d(
 75 |             in_chan,
 76 |             out_chan,
 77 |             kernel_size=ks,
 78 |             stride=stride,
 79 |             padding=padding,
 80 |             bias=False)
 81 |         self.bn = torch.nn.BatchNorm2d(out_chan)
 82 |         self.relu = nn.ReLU(inplace=True)
 83 | 
 84 |     def forward(self, x):
 85 |         x = self.conv(x)
 86 |         x = self.bn(x)
 87 |         x = self.relu(x)
 88 |         return x
 89 | 
 90 |     def init_weight(self):
 91 |         for ly in self.children():
 92 |             if isinstance(ly, nn.Conv2d):
 93 |                 nn.init.kaiming_normal_(ly.weight, a=1)
 94 |                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
 95 | 
 96 | 
 97 | class BidirectionalPropagation(nn.Module):
 98 |     def __init__(self, channel, freeze_dcn=False):
 99 |         super(BidirectionalPropagation, self).__init__()
100 |         modules = ['backward_', 'forward_']
101 | 
102 |         self.backbone = nn.ModuleDict()
103 |         self.channel = channel
104 | 
105 |         self.deform_align = nn.ModuleDict()
106 | 
107 |         for i, module in enumerate(modules):
108 |             self.deform_align[module] = SecondOrderDeformableAlignment(
109 |                 2 * channel, channel, 3, padding=1, deform_groups=16)
110 | 
111 |             self.backbone[module] = nn.Sequential(
112 |                 nn.Conv2d((2 + i) * channel, channel, 3, 1, 1),
113 |                 nn.LeakyReLU(negative_slope=0.1, inplace=True),
114 |                 nn.Conv2d(channel, channel, 3, 1, 1),
115 |             )
116 | 
117 |         self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0)
118 | 
119 |         if freeze_dcn:
120 |             # freeze the DCN
121 |             self.freeze(self.deform_align)
122 | 
123 |     @staticmethod
124 |     def freeze(layer):
125 |         """For freezing some layers."""
126 |         for child in layer.children():
127 |             for param in child.parameters():
128 |                 param.requires_grad = False
129 | 
130 |     def forward(self, x, flows_backward, flows_forward):
131 |         """
132 |         x shape : [b, t, c, h, w]
133 |         return [b, t, c, h, w]
134 |         """
135 |         b, t, c, h, w = x.shape
136 |         feats = {}
137 |         feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)]
138 | 
139 |         for module_name in ['backward_', 'forward_']:
140 | 
141 |             feats[module_name] = []
142 | 
143 |             frame_idx = range(0, t)
144 |             flow_idx = range(-1, t - 1)
145 |             mapping_idx = list(range(0, len(feats['spatial'])))
146 |             mapping_idx += mapping_idx[::-1]
147 | 
148 |             if 'backward' in module_name:
149 |                 frame_idx = frame_idx[::-1]
150 |                 flows = flows_backward
151 |             else:
152 |                 flows = flows_forward
153 | 
154 |             feat_prop = x.new_zeros(b, self.channel, h, w)
155 | 
156 |             # fixes the bug where i and idx were misaligned during the backward pass
157 |             for i, idx in enumerate(frame_idx):
158 |                 feat_current = feats['spatial'][mapping_idx[idx]]
159 | 
160 |                 if i > 0:
161 |                     if 'backward' in module_name:
162 |                         flow_n1 = flows[:, flow_idx[idx] + 1, :, :, :]
163 |                     else:
164 |                         flow_n1 = flows[:, flow_idx[i], :, :, :]
165 |                     cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
166 | 
167 |                     # initialize second-order features
168 |                     feat_n2 = torch.zeros_like(feat_prop)
169 |                     flow_n2 = torch.zeros_like(flow_n1)
170 |                     cond_n2 = torch.zeros_like(cond_n1)
171 |                     if i > 1:
172 |                         feat_n2 = feats[module_name][-2]
173 |                         if 'backward' in module_name:
174 |                             flow_n2 = flows[:, flow_idx[idx] + 2, :, :, :]
175 |                         else:
176 |                             flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
177 |                         flow_n2 = flow_n1 + flow_warp(
178 |                             flow_n2, flow_n1.permute(0, 2, 3, 1))
179 |                         cond_n2 = flow_warp(feat_n2,
180 |                                             flow_n2.permute(0, 2, 3, 1))
181 | 
182 |                     cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
183 | 
184 |                     # default
185 |                     feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
186 |                     feat_prop = self.deform_align[module_name](feat_prop, cond,
187 |                                                                flow_n1,
188 |                                                                flow_n2)
189 | 
190 |                 feat = [feat_current] + [
191 |                     feats[k][idx]
192 |                     for k in feats if k not in ['spatial', module_name]
193 |                 ] + [feat_prop]
194 | 
195 |                 feat = torch.cat(feat, dim=1)
196 |                 feat_prop = feat_prop + self.backbone[module_name](feat)
197 |                 feats[module_name].append(feat_prop)
198 |             ##################################################################
199 | 
200 |             if 'backward' in module_name:
201 |                 feats[module_name] = feats[module_name][::-1]
202 | 
203 |         outputs = []
204 |         for i in range(0, t):
205 |             align_feats = [feats[k].pop(0) for k in feats if k != 'spatial']
206 |             align_feats = torch.cat(align_feats, dim=1)
207 |             outputs.append(self.fusion(align_feats))
208 | 
209 |         return torch.stack(outputs, dim=1) + x
210 | 
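`BidirectionalPropagation` leans on `flow_warp` from flow_comp.py (reproduced below); the one convention worth keeping in mind is that features stay channel-first while flows must be channel-last:

```python
# flow_warp expects flow as (n, h, w, 2) while features stay (n, c, h, w),
# hence the recurring permute(0, 2, 3, 1) in the propagation loop above.
cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
```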
--------------------------------------------------------------------------------
/model/modules/flow_comp.py: 
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | 
 3 | import torch.nn as nn
 4 | import torch.nn.functional as F
 5 | import torch
 6 | 
 7 | from mmcv.cnn import ConvModule
 8 | from mmcv.runner import load_checkpoint
 9 | from model.modules.maskflownets import MaskFlowNetS
10 | 
11 | 
12 | class FlowCompletionLoss(nn.Module):
13 |     """Flow completion loss"""
14 |     def __init__(self, estimator='spy', device='cuda:0', flow_res=0.25):
15 |         super().__init__()
16 |         if estimator == 'spy':
17 |             # default flow compute with spynet:
18 |             self.fix_spynet = SPyNet()
19 |         elif estimator == 'mfn':
20 |             self.fix_spynet = MaskFlowNetS(device=device)
21 |         else:
22 |             raise TypeError('[estimator] should be spy or mfn, '
23 |                             f'but got {estimator}.')
24 | 
25 |         # the SPyNet that computes the GT flow has its weights frozen; the flow-completion SPyNet does not
26 |         for p in self.fix_spynet.parameters():
27 |             p.requires_grad = False
28 | 
29 |         self.l1_criterion = nn.L1Loss()
30 | 
31 |         self.flow_res = flow_res  # the resolution at which optical flow is computed
32 | 
33 |     def forward(self, pred_flows, gt_local_frames):
34 |         b, l_t, c, h, w = gt_local_frames.size()
35 |         scale_factor = int(1 / self.flow_res)  # 1/4 -> 4, used to restore the scale
36 | 
37 |         with torch.no_grad():
38 |             # compute gt forward and backward flows
39 |             gt_local_frames = F.interpolate(gt_local_frames.view(-1, c, h, w),
40 |                                             scale_factor=self.flow_res,
41 |                                             mode='bilinear',
42 |                                             align_corners=True,
43 |                                             recompute_scale_factor=True)
44 |             gt_local_frames = gt_local_frames.view(b, l_t, c, h // scale_factor, w // scale_factor)
45 |             gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(
46 |                 -1, c, h // scale_factor, w // scale_factor)
47 |             gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(
48 |                 -1, c, h // scale_factor, w // scale_factor)
49 |             gt_flows_forward = self.fix_spynet(gtlf_1, gtlf_2)
50 |             gt_flows_backward = self.fix_spynet(gtlf_2, gtlf_1)
51 | 
52 |         # in the end the GT flow must be resized to 1/4 resolution to supervise the flow used for feature warping
53 |         if scale_factor != 4:
54 |             gt_flows_forward = F.interpolate(gt_flows_forward.view(-1, 2, h // scale_factor, w // scale_factor),
55 |                                              scale_factor=scale_factor / 4,
56 |                                              mode='bilinear',
57 |                                              align_corners=True,
58 |                                              recompute_scale_factor=True).view(b, l_t - 1, 2, h // 4,
59 |                                                                                w // 4)
60 |             gt_flows_backward = F.interpolate(gt_flows_backward.view(-1, 2, h // scale_factor, w // scale_factor),
61 |                                               scale_factor=scale_factor / 4,
62 |                                               mode='bilinear',
63 |                                               align_corners=True,
64 |                                               recompute_scale_factor=True).view(b, l_t - 1, 2, h // 4,
65 |                                                                                 w // 4)
66 | 
67 |         # calculate loss for flow completion
68 |         forward_flow_loss = self.l1_criterion(
69 |             pred_flows[0].view(-1, 2, h // 4, w // 4), gt_flows_forward)
70 |         backward_flow_loss = self.l1_criterion(
71 |             pred_flows[1].view(-1, 2, h // 4, w // 4), gt_flows_backward)
72 |         flow_loss = forward_flow_loss + backward_flow_loss
73 | 
74 |         return flow_loss
75 | 
76 | 
77 | class SPyNet(nn.Module):
78 |     """SPyNet network structure.
79 |     The difference to the SPyNet in [tof.py] is that
80 |     1. more SPyNetBasicModules are used in this version, and
81 |     2. no batch normalization is used in this version.
82 |     Paper:
83 |         Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
84 |     Args:
85 |         pretrained (str): path for pre-trained SPyNet. Default: None. 
86 | """ 87 | def __init__( 88 | self, 89 | use_pretrain=True, 90 | pretrained='https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth', 91 | module_level=6 92 | ): 93 | super().__init__() 94 | 95 | self.basic_module = nn.ModuleList( 96 | [SPyNetBasicModule() for _ in range(module_level)]) 97 | 98 | if use_pretrain: 99 | if isinstance(pretrained, str): 100 | print("load pretrained SPyNet...") 101 | load_checkpoint(self, pretrained, strict=True) 102 | elif pretrained is not None: 103 | raise TypeError('[pretrained] should be str or None, ' 104 | f'but got {type(pretrained)}.') 105 | 106 | self.register_buffer( 107 | 'mean', 108 | torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 109 | self.register_buffer( 110 | 'std', 111 | torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 112 | 113 | def compute_flow(self, ref, supp): 114 | """Compute flow from ref to supp. 115 | Note that in this function, the images are already resized to a 116 | multiple of 32. 117 | Args: 118 | ref (Tensor): Reference image with shape of (n, 3, h, w). 119 | supp (Tensor): Supporting image with shape of (n, 3, h, w). 120 | Returns: 121 | Tensor: Estimated optical flow: (n, 2, h, w). 122 | """ 123 | n, _, h, w = ref.size() 124 | 125 | # normalize the input images 126 | ref = [(ref - self.mean) / self.std] 127 | supp = [(supp - self.mean) / self.std] 128 | 129 | # generate downsampled frames 130 | # for level in range(5): # default 131 | for level in range(len(self.basic_module)-1): 132 | ref.append( 133 | F.avg_pool2d(input=ref[-1], 134 | kernel_size=2, 135 | stride=2, 136 | count_include_pad=False)) 137 | supp.append( 138 | F.avg_pool2d(input=supp[-1], 139 | kernel_size=2, 140 | stride=2, 141 | count_include_pad=False)) 142 | ref = ref[::-1] 143 | supp = supp[::-1] 144 | 145 | # flow computation 146 | # flow = ref[0].new_zeros(n, 2, h // 32, w // 32) # default 147 | reduce = 2**(len(self.basic_module)-1) 148 | flow = ref[0].new_zeros(n, 2, h // reduce, w // reduce) 149 | 150 | # for level in range(len(ref)): # default 151 | for level in range(len(self.basic_module)): 152 | if level == 0: 153 | flow_up = flow 154 | else: 155 | flow_up = F.interpolate(input=flow, 156 | scale_factor=2, 157 | mode='bilinear', 158 | align_corners=True) * 2.0 159 | 160 | # add the residue to the upsampled flow 161 | flow = flow_up + self.basic_module[level](torch.cat([ 162 | ref[level], 163 | flow_warp(supp[level], 164 | flow_up.permute(0, 2, 3, 1).contiguous(), 165 | padding_mode='border'), flow_up 166 | ], 1)) 167 | 168 | return flow 169 | 170 | def forward(self, ref, supp): 171 | """Forward function of SPyNet. 172 | This function computes the optical flow from ref to supp. 173 | Args: 174 | ref (Tensor): Reference image with shape of (n, 3, h, w). 175 | supp (Tensor): Supporting image with shape of (n, 3, h, w). 176 | Returns: 177 | Tensor: Estimated optical flow: (n, 2, h, w). 
178 | """ 179 | 180 | # upsize to a multiple of 32 181 | h, w = ref.shape[2:4] 182 | w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1) 183 | h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1) 184 | ref = F.interpolate(input=ref, 185 | size=(h_up, w_up), 186 | mode='bilinear', 187 | align_corners=False) 188 | supp = F.interpolate(input=supp, 189 | size=(h_up, w_up), 190 | mode='bilinear', 191 | align_corners=False) 192 | 193 | # compute flow, and resize back to the original resolution 194 | flow = F.interpolate(input=self.compute_flow(ref, supp), 195 | size=(h, w), 196 | mode='bilinear', 197 | align_corners=False) 198 | 199 | # adjust the flow values 200 | flow[:, 0, :, :] *= float(w) / float(w_up) 201 | flow[:, 1, :, :] *= float(h) / float(h_up) 202 | 203 | return flow 204 | 205 | 206 | class SPyNetBasicModule(nn.Module): 207 | """Basic Module for SPyNet. 208 | Paper: 209 | Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017 210 | """ 211 | def __init__(self): 212 | super().__init__() 213 | 214 | self.basic_module = nn.Sequential( 215 | ConvModule(in_channels=8, 216 | out_channels=32, 217 | kernel_size=7, 218 | stride=1, 219 | padding=3, 220 | norm_cfg=None, 221 | act_cfg=dict(type='ReLU')), 222 | ConvModule(in_channels=32, 223 | out_channels=64, 224 | kernel_size=7, 225 | stride=1, 226 | padding=3, 227 | norm_cfg=None, 228 | act_cfg=dict(type='ReLU')), 229 | ConvModule(in_channels=64, 230 | out_channels=32, 231 | kernel_size=7, 232 | stride=1, 233 | padding=3, 234 | norm_cfg=None, 235 | act_cfg=dict(type='ReLU')), 236 | ConvModule(in_channels=32, 237 | out_channels=16, 238 | kernel_size=7, 239 | stride=1, 240 | padding=3, 241 | norm_cfg=None, 242 | act_cfg=dict(type='ReLU')), 243 | ConvModule(in_channels=16, 244 | out_channels=2, 245 | kernel_size=7, 246 | stride=1, 247 | padding=3, 248 | norm_cfg=None, 249 | act_cfg=None)) 250 | 251 | def forward(self, tensor_input): 252 | """ 253 | Args: 254 | tensor_input (Tensor): Input tensor with shape (b, 8, h, w). 255 | 8 channels contain: 256 | [reference image (3), neighbor image (3), initial flow (2)]. 257 | Returns: 258 | Tensor: Refined flow with shape (b, 2, h, w) 259 | """ 260 | return self.basic_module(tensor_input) 261 | 262 | 263 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 264 | def make_colorwheel(): 265 | """ 266 | Generates a color wheel for optical flow visualization as presented in: 267 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 268 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 269 | 270 | Code follows the original C++ source code of Daniel Scharstein. 271 | Code follows the the Matlab source code of Deqing Sun. 
272 | 273 | Returns: 274 | np.ndarray: Color wheel 275 | """ 276 | 277 | RY = 15 278 | YG = 6 279 | GC = 4 280 | CB = 11 281 | BM = 13 282 | MR = 6 283 | 284 | ncols = RY + YG + GC + CB + BM + MR 285 | colorwheel = np.zeros((ncols, 3)) 286 | col = 0 287 | 288 | # RY 289 | colorwheel[0:RY, 0] = 255 290 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 291 | col = col + RY 292 | # YG 293 | colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 294 | colorwheel[col:col + YG, 1] = 255 295 | col = col + YG 296 | # GC 297 | colorwheel[col:col + GC, 1] = 255 298 | colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 299 | col = col + GC 300 | # CB 301 | colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 302 | colorwheel[col:col + CB, 2] = 255 303 | col = col + CB 304 | # BM 305 | colorwheel[col:col + BM, 2] = 255 306 | colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 307 | col = col + BM 308 | # MR 309 | colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 310 | colorwheel[col:col + MR, 0] = 255 311 | return colorwheel 312 | 313 | 314 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 315 | """ 316 | Applies the flow color wheel to (possibly clipped) flow components u and v. 317 | 318 | According to the C++ source code of Daniel Scharstein 319 | According to the Matlab source code of Deqing Sun 320 | 321 | Args: 322 | u (np.ndarray): Input horizontal flow of shape [H,W] 323 | v (np.ndarray): Input vertical flow of shape [H,W] 324 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 325 | 326 | Returns: 327 | np.ndarray: Flow visualization image of shape [H,W,3] 328 | """ 329 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 330 | colorwheel = make_colorwheel() # shape [55x3] 331 | ncols = colorwheel.shape[0] 332 | rad = np.sqrt(np.square(u) + np.square(v)) 333 | a = np.arctan2(-v, -u) / np.pi 334 | fk = (a + 1) / 2 * (ncols - 1) 335 | k0 = np.floor(fk).astype(np.int32) 336 | k1 = k0 + 1 337 | k1[k1 == ncols] = 0 338 | f = fk - k0 339 | for i in range(colorwheel.shape[1]): 340 | tmp = colorwheel[:, i] 341 | col0 = tmp[k0] / 255.0 342 | col1 = tmp[k1] / 255.0 343 | col = (1 - f) * col0 + f * col1 344 | idx = (rad <= 1) 345 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 346 | col[~idx] = col[~idx] * 0.75 # out of range 347 | # Note the 2-i => BGR instead of RGB 348 | ch_idx = 2 - i if convert_to_bgr else i 349 | flow_image[:, :, ch_idx] = np.floor(255 * col) 350 | return flow_image 351 | 352 | 353 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 354 | """ 355 | Expects a two dimensional flow image of shape. 356 | 357 | Args: 358 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 359 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 360 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
361 | 362 | Returns: 363 | np.ndarray: Flow visualization image of shape [H,W,3] 364 | """ 365 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 366 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 367 | if clip_flow is not None: 368 | flow_uv = np.clip(flow_uv, 0, clip_flow) 369 | u = flow_uv[:, :, 0] 370 | v = flow_uv[:, :, 1] 371 | rad = np.sqrt(np.square(u) + np.square(v)) 372 | rad_max = np.max(rad) 373 | epsilon = 1e-5 374 | u = u / (rad_max + epsilon) 375 | v = v / (rad_max + epsilon) 376 | return flow_uv_to_colors(u, v, convert_to_bgr) 377 | 378 | 379 | def flow_warp(x, 380 | flow, 381 | interpolation='bilinear', 382 | padding_mode='zeros', 383 | align_corners=True): 384 | """Warp an image or a feature map with optical flow. 385 | Args: 386 | x (Tensor): Tensor with size (n, c, h, w). 387 | flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is 388 | a two-channel, denoting the width and height relative offsets. 389 | Note that the values are not normalized to [-1, 1]. 390 | interpolation (str): Interpolation mode: 'nearest' or 'bilinear'. 391 | Default: 'bilinear'. 392 | padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'. 393 | Default: 'zeros'. 394 | align_corners (bool): Whether align corners. Default: True. 395 | Returns: 396 | Tensor: Warped image or feature map. 397 | """ 398 | if x.size()[-2:] != flow.size()[1:3]: 399 | raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and ' 400 | f'flow ({flow.size()[1:3]}) are not the same.') 401 | _, _, h, w = x.size() 402 | # create mesh grid 403 | grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w)) 404 | grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2) 405 | grid.requires_grad = False 406 | 407 | grid_flow = grid + flow 408 | # scale grid_flow to [-1,1] 409 | grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0 410 | grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0 411 | grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3) 412 | output = F.grid_sample(x, 413 | grid_flow, 414 | mode=interpolation, 415 | padding_mode=padding_mode, 416 | align_corners=align_corners) 417 | return output 418 | 419 | 420 | def initial_mask_flow(mask): 421 | """ 422 | mask 1 indicates valid pixel 0 indicates unknown pixel 423 | """ 424 | B, T, C, H, W = mask.shape 425 | 426 | # calculate relative position 427 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 428 | 429 | grid_y, grid_x = grid_y.type_as(mask), grid_x.type_as(mask) 430 | abs_relative_pos_y = H - torch.abs(grid_y[None, :, :] - grid_y[:, None, :]) 431 | relative_pos_y = H - (grid_y[None, :, :] - grid_y[:, None, :]) 432 | 433 | abs_relative_pos_x = W - torch.abs(grid_x[:, None, :] - grid_x[:, :, None]) 434 | relative_pos_x = W - (grid_x[:, None, :] - grid_x[:, :, None]) 435 | 436 | # calculate the nearest indices 437 | pos_up = mask.unsqueeze(3).repeat( 438 | 1, 1, 1, H, 1, 1).flip(4) * abs_relative_pos_y[None, None, None] * ( 439 | relative_pos_y <= H)[None, None, None] 440 | nearest_indice_up = pos_up.max(dim=4)[1] 441 | 442 | pos_down = mask.unsqueeze(3).repeat(1, 1, 1, H, 1, 1) * abs_relative_pos_y[ 443 | None, None, None] * (relative_pos_y <= H)[None, None, None] 444 | nearest_indice_down = (pos_down).max(dim=4)[1] 445 | 446 | pos_left = mask.unsqueeze(4).repeat( 447 | 1, 1, 1, 1, W, 1).flip(5) * abs_relative_pos_x[None, None, None] * ( 448 | relative_pos_x <= W)[None, None, None] 449 | nearest_indice_left = 
(pos_left).max(dim=5)[1] 450 | 451 | pos_right = mask.unsqueeze(4).repeat( 452 | 1, 1, 1, 1, W, 1) * abs_relative_pos_x[None, None, None] * ( 453 | relative_pos_x <= W)[None, None, None] 454 | nearest_indice_right = (pos_right).max(dim=5)[1] 455 | 456 | # NOTE: IMPORTANT !!! depending on how to use this offset 457 | initial_offset_up = -(nearest_indice_up - grid_y[None, None, None]).flip(3) 458 | initial_offset_down = nearest_indice_down - grid_y[None, None, None] 459 | 460 | initial_offset_left = -(nearest_indice_left - 461 | grid_x[None, None, None]).flip(4) 462 | initial_offset_right = nearest_indice_right - grid_x[None, None, None] 463 | 464 | # nearest_indice_x = (mask.unsqueeze(1).repeat(1, img_width, 1) * relative_pos_x).max(dim=2)[1] 465 | # initial_offset_x = nearest_indice_x - grid_x 466 | 467 | # handle the boundary cases 468 | final_offset_down = (initial_offset_down < 0) * initial_offset_up + ( 469 | initial_offset_down > 0) * initial_offset_down 470 | final_offset_up = (initial_offset_up > 0) * initial_offset_down + ( 471 | initial_offset_up < 0) * initial_offset_up 472 | final_offset_right = (initial_offset_right < 0) * initial_offset_left + ( 473 | initial_offset_right > 0) * initial_offset_right 474 | final_offset_left = (initial_offset_left > 0) * initial_offset_right + ( 475 | initial_offset_left < 0) * initial_offset_left 476 | zero_offset = torch.zeros_like(final_offset_down) 477 | # out = torch.cat([final_offset_left, zero_offset, final_offset_right, zero_offset, zero_offset, final_offset_up, zero_offset, final_offset_down], dim=2) 478 | out = torch.cat([ 479 | zero_offset, final_offset_left, zero_offset, final_offset_right, 480 | final_offset_up, zero_offset, final_offset_down, zero_offset 481 | ], 482 | dim=2) 483 | 484 | return out 485 | -------------------------------------------------------------------------------- /model/modules/maskflownets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from mmflow.apis import init_model 6 | 7 | 8 | class MaskFlowNetS(nn.Module): 9 | """MaskFlowNetS network structure. 10 | Paper: 11 | MaskFlownet: Asymmetric Feature Matching With Learnable Occlusion Mask, CVPR, 2020 12 | Args: 13 | pretrained (str): path for pre-trained MaskFlowNetS. Default: None. 
14 | """ 15 | def __init__( 16 | self, 17 | use_pretrain=True, 18 | # pretrained='https://download.openmmlab.com/mmflow/maskflownet/maskflownets_8x1_slong_flyingchairs_384x448.pth', 19 | pretrained='./release_model/maskflownets_8x1_sfine_flyingthings3d_subset_384x768.pth', 20 | config_file='../mmflow/configs/maskflownet/maskflownets_8x1_sfine_flyingthings3d_subset_384x768.py', 21 | device='cuda:0', 22 | module_level=6 23 | ): 24 | super().__init__() 25 | 26 | if use_pretrain: 27 | if isinstance(pretrained, str): 28 | print("load pretrained MaskFlowNetS...") 29 | # self.maskflownetS = init_model(config_file, pretrained, device=device) 30 | self.maskflownetS = init_model(config_file, pretrained, device='cpu') 31 | # load_checkpoint(self, pretrained, strict=True) 32 | elif pretrained is not None: 33 | raise TypeError('[pretrained] should be str or None, ' 34 | f'but got {type(pretrained)}.') 35 | 36 | self.register_buffer( 37 | 'mean', 38 | torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 39 | self.register_buffer( 40 | 'std', 41 | torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 42 | 43 | @staticmethod 44 | def centralize(img1, img2): 45 | """Centralize input images. 46 | Args: 47 | img1 (Tensor): The first input image. 48 | img2 (Tensor): The second input image. 49 | Returns: 50 | Tuple[Tensor, Tensor]: The first centralized image and the second 51 | centralized image. 52 | """ 53 | rgb_mean = torch.cat((img1, img2), 2) 54 | rgb_mean = rgb_mean.view(rgb_mean.shape[0], 3, -1).mean(2) 55 | rgb_mean = rgb_mean.view(rgb_mean.shape[0], 3, 1, 1) 56 | return img1 - rgb_mean, img2 - rgb_mean, rgb_mean 57 | 58 | def compute_flow(self, ref, supp): 59 | """Compute flow from ref to supp. 60 | Note that in this function, the images are already resized to a 61 | multiple of 64. 62 | Args: 63 | ref (Tensor): Reference image with shape of (n, 3, h, w). 64 | supp (Tensor): Supporting image with shape of (n, 3, h, w). 65 | Returns: 66 | Tensor: Estimated optical flow: (n, 2, h, w). 67 | """ 68 | n, _, h, w = ref.size() 69 | 70 | feat1, feat2 = self.maskflownetS.extract_feat(torch.cat((ref, supp), dim=1)) 71 | flows_stage1, mask_stage1 = self.maskflownetS.decoder( 72 | feat1, feat2, return_mask=True) 73 | 74 | return flows_stage1 75 | 76 | def forward(self, ref, supp): 77 | """Forward function of MaskFlowNetS. 78 | This function computes the optical flow from ref to supp. 79 | Args: 80 | ref (Tensor): Reference image with shape of (n, 3, h, w). 81 | supp (Tensor): Supporting image with shape of (n, 3, h, w). 82 | Returns: 83 | Tensor: Estimated optical flow: (n, 2, h, w). 
84 | """ 85 | 86 | # upsize to a multiple of 64 87 | h, w = ref.shape[2:4] 88 | w_up = w if (w % 64) == 0 else 64 * (w // 64 + 1) 89 | h_up = h if (h % 64) == 0 else 64 * (h // 64 + 1) 90 | ref = F.interpolate(input=ref, 91 | size=(h_up, w_up), 92 | mode='bilinear', 93 | align_corners=False) 94 | supp = F.interpolate(input=supp, 95 | size=(h_up, w_up), 96 | mode='bilinear', 97 | align_corners=False) 98 | 99 | # compute flow, and resize back to the original resolution 100 | flow = F.interpolate(input=self.compute_flow(ref, supp)['level2'], 101 | size=(h, w), 102 | mode='bilinear', 103 | align_corners=False) 104 | 105 | # adjust the flow values 106 | flow[:, 0, :, :] *= float(w) / float(w_up) 107 | flow[:, 1, :, :] *= float(h) / float(h_up) 108 | 109 | return flow 110 | 111 | 112 | def test_MFN(): 113 | # Specify the path to model config and checkpoint file 114 | config_file = '../../../mmflow/configs/maskflownets_8x1_sfine_flyingthings3d_subset_384x768.py' 115 | checkpoint_file = '../../release_model/maskflownets_8x1_sfine_flyingthings3d_subset_384x768.pth' 116 | 117 | # build the model from a config file and a checkpoint file 118 | model = init_model(config_file, checkpoint_file, device='cuda:0') 119 | pass 120 | 121 | 122 | if __name__ == "__main__": 123 | test_MFN() 124 | -------------------------------------------------------------------------------- /model/modules/spectral_norm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Spectral Normalization from https://arxiv.org/abs/1802.05957 3 | """ 4 | import torch 5 | from torch.nn.functional import normalize 6 | 7 | 8 | class SpectralNorm(object): 9 | # Invariant before and after each forward call: 10 | # u = normalize(W @ v) 11 | # NB: At initialization, this invariant is not enforced 12 | 13 | _version = 1 14 | 15 | # At version 1: 16 | # made `W` not a buffer, 17 | # added `v` as a buffer, and 18 | # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`. 19 | 20 | def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): 21 | self.name = name 22 | self.dim = dim 23 | if n_power_iterations <= 0: 24 | raise ValueError( 25 | 'Expected n_power_iterations to be positive, but ' 26 | 'got n_power_iterations={}'.format(n_power_iterations)) 27 | self.n_power_iterations = n_power_iterations 28 | self.eps = eps 29 | 30 | def reshape_weight_to_matrix(self, weight): 31 | weight_mat = weight 32 | if self.dim != 0: 33 | # permute dim to front 34 | weight_mat = weight_mat.permute( 35 | self.dim, 36 | *[d for d in range(weight_mat.dim()) if d != self.dim]) 37 | height = weight_mat.size(0) 38 | return weight_mat.reshape(height, -1) 39 | 40 | def compute_weight(self, module, do_power_iteration): 41 | # NB: If `do_power_iteration` is set, the `u` and `v` vectors are 42 | # updated in power iteration **in-place**. This is very important 43 | # because in `DataParallel` forward, the vectors (being buffers) are 44 | # broadcast from the parallelized module to each module replica, 45 | # which is a new module object created on the fly. And each replica 46 | # runs its own spectral norm power iteration. So simply assigning 47 | # the updated vectors to the module this function runs on will cause 48 | # the update to be lost forever. And the next time the parallelized 49 | # module is replicated, the same randomly initialized vectors are 50 | # broadcast and used! 
51 | # 52 | # Therefore, to make the change propagate back, we rely on two 53 | # important behaviors (also enforced via tests): 54 | # 1. `DataParallel` doesn't clone storage if the broadcast tensor 55 | # is already on correct device; and it makes sure that the 56 | # parallelized module is already on `device[0]`. 57 | # 2. If the out tensor in `out=` kwarg has correct shape, it will 58 | # just fill in the values. 59 | # Therefore, since the same power iteration is performed on all 60 | # devices, simply updating the tensors in-place will make sure that 61 | # the module replica on `device[0]` will update the _u vector on the 62 | # parallized module (by shared storage). 63 | # 64 | # However, after we update `u` and `v` in-place, we need to **clone** 65 | # them before using them to normalize the weight. This is to support 66 | # backproping through two forward passes, e.g., the common pattern in 67 | # GAN training: loss = D(real) - D(fake). Otherwise, engine will 68 | # complain that variables needed to do backward for the first forward 69 | # (i.e., the `u` and `v` vectors) are changed in the second forward. 70 | weight = getattr(module, self.name + '_orig') 71 | u = getattr(module, self.name + '_u') 72 | v = getattr(module, self.name + '_v') 73 | weight_mat = self.reshape_weight_to_matrix(weight) 74 | 75 | if do_power_iteration: 76 | with torch.no_grad(): 77 | for _ in range(self.n_power_iterations): 78 | # Spectral norm of weight equals to `u^T W v`, where `u` and `v` 79 | # are the first left and right singular vectors. 80 | # This power iteration produces approximations of `u` and `v`. 81 | v = normalize(torch.mv(weight_mat.t(), u), 82 | dim=0, 83 | eps=self.eps, 84 | out=v) 85 | u = normalize(torch.mv(weight_mat, v), 86 | dim=0, 87 | eps=self.eps, 88 | out=u) 89 | if self.n_power_iterations > 0: 90 | # See above on why we need to clone 91 | u = u.clone() 92 | v = v.clone() 93 | 94 | sigma = torch.dot(u, torch.mv(weight_mat, v)) 95 | weight = weight / sigma 96 | return weight 97 | 98 | def remove(self, module): 99 | with torch.no_grad(): 100 | weight = self.compute_weight(module, do_power_iteration=False) 101 | delattr(module, self.name) 102 | delattr(module, self.name + '_u') 103 | delattr(module, self.name + '_v') 104 | delattr(module, self.name + '_orig') 105 | module.register_parameter(self.name, 106 | torch.nn.Parameter(weight.detach())) 107 | 108 | def __call__(self, module, inputs): 109 | setattr( 110 | module, self.name, 111 | self.compute_weight(module, do_power_iteration=module.training)) 112 | 113 | def _solve_v_and_rescale(self, weight_mat, u, target_sigma): 114 | # Tries to returns a vector `v` s.t. `u = normalize(W @ v)` 115 | # (the invariant at top of this class) and `u @ W @ v = sigma`. 116 | # This uses pinverse in case W^T W is not invertible. 
117 |         v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(),
118 |                                weight_mat.t(), u.unsqueeze(1)).squeeze(1)
119 |         return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
120 | 
121 |     @staticmethod
122 |     def apply(module, name, n_power_iterations, dim, eps):
123 |         for k, hook in module._forward_pre_hooks.items():
124 |             if isinstance(hook, SpectralNorm) and hook.name == name:
125 |                 raise RuntimeError(
126 |                     "Cannot register two spectral_norm hooks on "
127 |                     "the same parameter {}".format(name))
128 | 
129 |         fn = SpectralNorm(name, n_power_iterations, dim, eps)
130 |         weight = module._parameters[name]
131 | 
132 |         with torch.no_grad():
133 |             weight_mat = fn.reshape_weight_to_matrix(weight)
134 | 
135 |             h, w = weight_mat.size()
136 |             # randomly initialize `u` and `v`
137 |             u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
138 |             v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)
139 | 
140 |         delattr(module, fn.name)
141 |         module.register_parameter(fn.name + "_orig", weight)
142 |         # We still need to assign the weight back as fn.name because all sorts
143 |         # of things may assume that it exists, e.g., when initializing weights.
144 |         # However, we can't directly assign it, as it could be an nn.Parameter
145 |         # and would get added as a parameter. Instead, we register weight.data
146 |         # as a plain attribute.
147 |         setattr(module, fn.name, weight.data)
148 |         module.register_buffer(fn.name + "_u", u)
149 |         module.register_buffer(fn.name + "_v", v)
150 | 
151 |         module.register_forward_pre_hook(fn)
152 | 
153 |         module._register_state_dict_hook(SpectralNormStateDictHook(fn))
154 |         module._register_load_state_dict_pre_hook(
155 |             SpectralNormLoadStateDictPreHook(fn))
156 |         return fn
157 | 
158 | 
159 | # This is a top-level class because Py2 pickle doesn't like inner classes or
160 | # instance methods.
161 | class SpectralNormLoadStateDictPreHook(object):
162 |     # See the docstring of SpectralNorm._version on the changes to spectral_norm.
163 |     def __init__(self, fn):
164 |         self.fn = fn
165 | 
166 |     # For a state_dict with version None (assuming that it has gone through at
167 |     # least one training forward), we have
168 |     #
169 |     #     u = normalize(W_orig @ v)
170 |     #     W = W_orig / sigma, where sigma = u @ W_orig @ v
171 |     #
172 |     # To compute `v`, we solve `W_orig @ x = u`, and let
173 |     #     v = x / (u @ W_orig @ x) * (W / W_orig).
174 |     def __call__(self, state_dict, prefix, local_metadata, strict,
175 |                  missing_keys, unexpected_keys, error_msgs):
176 |         fn = self.fn
177 |         version = local_metadata.get('spectral_norm',
178 |                                      {}).get(fn.name + '.version', None)
179 |         if version is None or version < 1:
180 |             with torch.no_grad():
181 |                 weight_orig = state_dict[prefix + fn.name + '_orig']
182 |                 # weight = state_dict.pop(prefix + fn.name)
183 |                 # sigma = (weight_orig / weight).mean()
184 |                 weight_mat = fn.reshape_weight_to_matrix(weight_orig)
185 |                 u = state_dict[prefix + fn.name + '_u']
186 |                 # v = fn._solve_v_and_rescale(weight_mat, u, sigma)
187 |                 # state_dict[prefix + fn.name + '_v'] = v
188 | 
189 | 
190 | # This is a top-level class because Py2 pickle doesn't like inner classes or
191 | # instance methods.
192 | class SpectralNormStateDictHook(object):
193 |     # See the docstring of SpectralNorm._version on the changes to spectral_norm.
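    # On save, `__call__` below records this implementation's version under
    # `local_metadata['spectral_norm']`, e.g. {'weight.version': 1} for the
    # default parameter name, so that the load-time pre-hook above can
    # recognize checkpoints written by older versions.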
194 |     def __init__(self, fn):
195 |         self.fn = fn
196 | 
197 |     def __call__(self, module, state_dict, prefix, local_metadata):
198 |         if 'spectral_norm' not in local_metadata:
199 |             local_metadata['spectral_norm'] = {}
200 |         key = self.fn.name + '.version'
201 |         if key in local_metadata['spectral_norm']:
202 |             raise RuntimeError(
203 |                 "Unexpected key in metadata['spectral_norm']: {}".format(key))
204 |         local_metadata['spectral_norm'][key] = self.fn._version
205 | 
206 | 
207 | def spectral_norm(module,
208 |                   name='weight',
209 |                   n_power_iterations=1,
210 |                   eps=1e-12,
211 |                   dim=None):
212 |     r"""Applies spectral normalization to a parameter in the given module.
213 | 
214 |     .. math::
215 |         \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
216 |         \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
217 | 
218 |     Spectral normalization stabilizes the training of discriminators (critics)
219 |     in Generative Adversarial Networks (GANs) by rescaling the weight tensor
220 |     by the spectral norm :math:`\sigma` of the weight matrix, calculated via
221 |     the power iteration method. If the dimension of the weight tensor is
222 |     greater than 2, it is reshaped to 2D for the power iteration to obtain
223 |     the spectral norm. This is implemented via a hook that calculates the
224 |     spectral norm and rescales the weight before every :meth:`~Module.forward` call.
225 | 
226 |     See `Spectral Normalization for Generative Adversarial Networks`_ .
227 | 
228 |     .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
229 | 
230 |     Args:
231 |         module (nn.Module): containing module
232 |         name (str, optional): name of the weight parameter
233 |         n_power_iterations (int, optional): number of power iterations used to
234 |             calculate the spectral norm
235 |         eps (float, optional): epsilon for numerical stability when
236 |             calculating norms
237 |         dim (int, optional): dimension corresponding to the number of outputs;
238 |             the default is ``0``, except for modules that are instances of
239 |             ConvTranspose{1,2,3}d, where it is ``1``
240 | 
241 |     Returns:
242 |         The original module with the spectral norm hook
243 | 
244 |     Example::
245 | 
246 |         >>> m = spectral_norm(nn.Linear(20, 40))
247 |         >>> m
248 |         Linear(in_features=20, out_features=40, bias=True)
249 |         >>> m.weight_u.size()
250 |         torch.Size([40])
251 | 
252 |     """
253 |     if dim is None:
254 |         if isinstance(module,
255 |                       (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
256 |                        torch.nn.ConvTranspose3d)):
257 |             dim = 1
258 |         else:
259 |             dim = 0
260 |     SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
261 |     return module
262 | 
263 | 
264 | def remove_spectral_norm(module, name='weight'):
265 |     r"""Removes the spectral normalization reparameterization from a module.
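
    After removal, ``module.weight`` is re-registered as a plain
    ``torch.nn.Parameter`` holding the most recently computed normalized
    weight, and the auxiliary ``_orig``/``_u``/``_v`` attributes are deleted.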
266 | 
267 |     Args:
268 |         module (Module): containing module
269 |         name (str, optional): name of the weight parameter
270 | 
271 |     Example::
272 |         >>> m = spectral_norm(nn.Linear(40, 10))
273 |         >>> remove_spectral_norm(m)
274 |     """
275 |     for k, hook in module._forward_pre_hooks.items():
276 |         if isinstance(hook, SpectralNorm) and hook.name == name:
277 |             hook.remove(module)
278 |             del module._forward_pre_hooks[k]
279 |             return module
280 | 
281 |     raise ValueError("spectral_norm of '{}' not found in {}".format(
282 |         name, module))
283 | 
284 | 
285 | def use_spectral_norm(module, use_sn=False):
286 |     # thin wrapper: apply spectral normalization only when `use_sn` is True
287 |     if use_sn:
288 |         return spectral_norm(module)
289 |     return module
--------------------------------------------------------------------------------
/release_model/README.md:
--------------------------------------------------------------------------------
1 | Place the downloaded models here.
2 | 
3 | [comment]: <> (:link: **Download Links:** [[Google Drive](https://drive.google.com/file/d/1tNJMTJ2gmWdIXJoHVi5-H504uImUiJW9/view?usp=sharing)] [[Baidu Disk](https://pan.baidu.com/s/1qXAErbilY_n_Fh9KB8UF7w?pwd=lsjw)])
4 | 
5 | The directory structure should be arranged as:
6 | ```
7 | release_model
8 |    |- FlowLens.pth
9 |    |- i3d_rgb_imagenet.pt (for evaluating the VFID metric)
10 |    |- maskflownets_8x1_sfine_flyingthings3d_subset_384x768.pth
11 |    |- README.md
12 | ```
13 | 
--------------------------------------------------------------------------------
/release_model/maskflownets_8x1_sfine_flyingthings3d_subset_384x768.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 |     '/mnt/WORKSPACE/mmflow/configs/_base_/models/maskflownets.py',
3 |     '/mnt/WORKSPACE/mmflow/configs/_base_/datasets/flyingthings3d_subset_384x768.py',
4 |     '/mnt/WORKSPACE/mmflow/configs/_base_/schedules/schedule_s_fine.py',
5 |     '/mnt/WORKSPACE/mmflow/configs/_base_/default_runtime.py'
6 | ]
7 | 
8 | optimizer = dict(type='Adam', lr=0.0001, weight_decay=0., betas=(0.9, 0.999))
9 | 
10 | # Trained on FlyingChairs, then fine-tuned on the FlyingThings3D subset
11 | load_from = 'https://download.openmmlab.com/mmflow/maskflownet/maskflownets_8x1_slong_flyingchairs_384x448.pth'  # noqa
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from shutil import copyfile
5 | 
6 | import torch
7 | import torch.multiprocessing as mp
8 | 
9 | from core.trainer import Trainer
10 | from core.dist import (
11 |     get_world_size,
12 |     get_local_rank,
13 |     get_global_rank,
14 |     get_master_ip,
15 | )
16 | 
17 | parser = argparse.ArgumentParser(description='FlowLens')
18 | parser.add_argument('-c',
19 |                     '--config',
20 |                     default=None,
21 |                     type=str, help='path to the training config (a JSON file)')
22 | parser.add_argument('-p', '--port', default='23455', type=str, help='TCP port for distributed init')
23 | args = parser.parse_args()
24 | 
25 | 
26 | def main_worker(rank, config):
27 |     if 'local_rank' not in config:
28 |         config['local_rank'] = config['global_rank'] = rank
29 |     if config['distributed']:
30 |         torch.cuda.set_device(int(config['local_rank']))
31 |         torch.distributed.init_process_group(backend='nccl',  # nccl for Linux DDP, gloo for Windows
32 |                                              init_method=config['init_method'],
33 |                                              world_size=config['world_size'],
34 |                                              rank=config['global_rank'],
35 |                                              group_name='mtorch')
36 |         print('using GPU {}-{} for training'.format(int(config['global_rank']),
37 |                                                     int(config['local_rank'])))
38 | 
39 |     config['save_dir'] = os.path.join(  # checkpoints go to <save_dir>/<net>_<config name>
40 |         config['save_dir'],
41 |         '{}_{}'.format(config['model']['net'],
42 |                        os.path.basename(args.config).split('.')[0]))
43 | 
44 |     config['save_metric_dir'] = os.path.join(
45 |         './scores',
46 |         '{}_{}'.format(config['model']['net'],
47 |                        os.path.basename(args.config).split('.')[0]))
48 | 
49 |     if torch.cuda.is_available():
50 |         config['device'] = torch.device("cuda:{}".format(config['local_rank']))
51 |     else:
52 |         config['device'] = 'cpu'
53 | 
54 |     if (not config['distributed']) or config['global_rank'] == 0:
55 |         os.makedirs(config['save_dir'], exist_ok=True)
56 |         os.makedirs(config['save_metric_dir'], exist_ok=True)
57 |         config_path = os.path.join(config['save_dir'],
58 |                                    args.config.split('/')[-1])
59 |         if not os.path.isfile(config_path):
60 |             copyfile(args.config, config_path)  # keep a copy of the config next to the checkpoints
61 |         print('[**] created folder {}'.format(config['save_dir']))
62 | 
63 |     trainer = Trainer(config)
64 |     trainer.train()
65 | 
66 | 
67 | if __name__ == "__main__":
68 | 
69 |     torch.backends.cudnn.benchmark = True
70 | 
71 |     mp.set_sharing_strategy('file_system')
72 | 
73 |     # load the config
74 |     config = json.load(open(args.config))
75 | 
76 |     # set the distributed configuration
77 |     config['world_size'] = get_world_size()
78 |     config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"
79 |     config['distributed'] = config['world_size'] > 1
80 |     print('world size: {}'.format(config['world_size']))
81 |     # set up the distributed parallel training environment
82 |     if get_master_ip() == "127.0.0.1":
83 |         # single machine: manually launch one worker process per GPU
84 |         mp.spawn(main_worker, nprocs=config['world_size'], args=(config, ))
85 |     else:
86 |         # multiple processes have already been launched by OpenMPI
87 |         config['local_rank'] = get_local_rank()
88 |         config['global_rank'] = get_global_rank()
89 |         main_worker(-1, config)
90 | 
--------------------------------------------------------------------------------