├── .idea
│   └── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── breakdance.gif
│   ├── flowlens.png
│   ├── in_beyond.gif
│   └── out_beyond.gif
├── configs
│   ├── KITTI360EX-I_FlowLens.json
│   ├── KITTI360EX-I_FlowLens_re.json
│   ├── KITTI360EX-I_FlowLens_small.json
│   ├── KITTI360EX-I_FlowLens_small_re.json
│   ├── KITTI360EX-O_FlowLens.json
│   ├── KITTI360EX-O_FlowLens_re.json
│   ├── KITTI360EX-O_FlowLens_small.json
│   └── KITTI360EX-O_FlowLens_small_re.json
├── core
│   ├── dataset.py
│   ├── dist.py
│   ├── loss.py
│   ├── lr_scheduler.py
│   ├── metrics.py
│   ├── trainer.py
│   └── utils.py
├── datasets
│   └── KITTI-360EX
│       ├── InnerSphere
│       │   ├── test.json
│       │   └── train.json
│       └── OuterPinhole
│           ├── test.json
│           └── train.json
├── evaluate.py
├── model
│   ├── flowlens.py
│   └── modules
│       ├── feat_prop.py
│       ├── flow_comp.py
│       ├── maskflownets.py
│       ├── mix_focal_transformer.py
│       └── spectral_norm.py
├── release_model
│   ├── README.md
│   └── maskflownets_8x1_sfine_flyingthings3d_subset_384x768.py
└── train.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Hao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ###
FlowLens: Seeing Beyond the FoV via Flow-guided Clip-Recurrent Transformer
2 |
3 |
4 |
Hao Shi ·
5 |
Qi Jiang ·
6 |
Kailun Yang ·
7 |
Xiaoting Yin ·
8 |
Kaiwei Wang
9 |
10 |
Paper
11 |
12 | ####
13 | [](https://paperswithcode.com/sota/seeing-beyond-the-visible-on-kitti360-ex?p=flowlens-seeing-beyond-the-fov-via-flow)
14 |
15 | [comment]: <> (Paper)
16 |
17 | [comment]: <> (Demo Video (Youtube))
18 |
19 | [comment]: <> (Demo Video (Bilibili))
20 |
21 |
22 | [comment]: <> ()
23 |
24 | [comment]: <> (:hammer_and_wrench: :construction_worker: :rocket:)
25 |
26 | [comment]: <> (:fire: We will release code and checkpoints in the future. :fire:)
27 |
28 | [comment]: <> ()
29 |
30 |
31 |
32 | ### Update
33 | - 2022.11.19 Init repository.
34 | - 2022.11.21 Release the [arXiv](https://arxiv.org/abs/2211.11293) version with supplementary materials.
35 | - 2023.04.04 :fire: Our code is publicly available.
36 | - 2023.04.04 :fire: Release pretrained models.
37 | - 2023.04.04 :fire: Release KITTI360-EX dataset.
38 |
39 | ### TODO List
40 |
41 | - [x] Code release.
42 | - [x] KITTI360-EX release.
43 | - [x] Towards higher performance at a small extra cost.
44 |
45 |
46 | ### Abstract
47 | Limited by hardware cost and system size, a camera's Field-of-View (FoV) is not always satisfactory.
48 | However, from a spatio-temporal perspective, information beyond the camera's physical FoV is off-the-shelf and can actually be obtained ''for free'' from past video streams.
49 | In this paper, we propose a novel task termed Beyond-FoV Estimation, aiming to exploit past visual cues and bidirectionally break through the physical FoV of a camera.
50 | We put forward a FlowLens architecture to expand the FoV by achieving feature propagation explicitly by optical flow and implicitly by a novel clip-recurrent transformer,
51 | which has two appealing features: 1) FlowLens comprises a newly proposed Clip-Recurrent Hub with 3D-Decoupled Cross Attention (DDCA) to progressively process global information accumulated in the temporal dimension. 2) A multi-branch Mix Fusion Feed Forward Network (MixF3N) is integrated to enhance the spatially-precise flow of local features. To foster training and evaluation, we establish KITTI360-EX, a dataset for outer- and inner-FoV expansion.
52 | Extensive experiments on both video inpainting and beyond-FoV estimation tasks show that FlowLens achieves state-of-the-art performance.
53 |
54 | ### Demos
55 |
56 | (Outer Beyond-FoV)
57 | [demo GIF: assets/out_beyond.gif]
58 |
59 | (Inner Beyond-FoV)
60 | [demo GIF: assets/in_beyond.gif]
61 |
62 | (Object Removal)
63 | [demo GIF: assets/breakdance.gif]
64 |
80 | ### Dependencies
81 | This repo has been tested in the following environment:
82 | ```
83 | torch == 1.10.2
84 | cuda == 11.3
85 | mmflow == 0.5.2
86 | ```
87 |
88 | ### Usage
89 | To train FlowLens(-S), use:
90 | ```bash
91 | python train.py --config configs/KITTI360EX-I_FlowLens_small_re.json
92 | ```
93 |
94 | To evaluate on KITTI360-EX, run:
95 | ```bash
96 | python evaluate.py \
97 | --model flowlens \
98 | --cfg_path configs/KITTI360EX-I_FlowLens_small_re.json \
99 | --ckpt release_model/FlowLens-S_re_Out_500000.pth --fov fov5
100 | ```
101 |
102 | Turn on ```--reverse``` for test-time augmentation (TTA).
103 |
104 | Turn on ```--save_results``` to save your outputs.
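
For example, a full evaluation command with TTA and saved outputs, combining the flags above, might look like:

```bash
python evaluate.py \
    --model flowlens \
    --cfg_path configs/KITTI360EX-I_FlowLens_small_re.json \
    --ckpt release_model/FlowLens-S_re_Out_500000.pth \
    --fov fov5 --reverse --save_results
```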
105 |
106 | ### Pretrained Models
107 | The pretrained models can be found here:
108 | ```
109 | https://share.weiyun.com/6G6QEdaa
110 | ```
111 |
112 | ### KITTI360-EX for Beyond-FoV Estimation
113 | The preprocessed KITTI360-EX can be downloaded from here:
114 | ```
115 | https://share.weiyun.com/BReRdDiP
116 | ```
117 |
118 | ### Results
119 | #### KITTI360EX-InnerSphere
120 | | Method | Test Logic | TTA | PSNR | SSIM | VFID | Runtime (s/frame) |
121 | | :--------- | :----------: | :----------: | :----------: | :--------: | :---------: | :------------: |
122 | | _FlowLens-S (Paper)_ |_Beyond-FoV_|_wo_| _36.17_ | _0.9916_ | _0.030_ | _0.023_ |
123 | | FlowLens-S (This Repo) |Beyond-FoV|wo| 37.31 | 0.9926 | 0.025 | **0.015** |
124 | | FlowLens-S+ (This Repo) |Beyond-FoV|with| 38.36 | 0.9938 | 0.017 | 0.050 |
125 | | FlowLens-S (This Repo) |Video Inpainting|wo| 38.01 | 0.9938 | 0.022 | 0.042 |
126 | | FlowLens-S+ (This Repo) |Video Inpainting|with| **38.97** | **0.9947** | **0.015** | 0.142 |
127 |
128 | | Method | Test Logic | TTA | PSNR | SSIM | VFID | Runtime (s/frame) |
129 | | :--------- | :----------: | :----------: | :----------: | :--------: | :---------: | :------------: |
130 | | _FlowLens (Paper)_ |_Beyond-FoV_|_wo_| _36.69_ | _0.9916_ | _0.027_ | _0.049_ |
131 | | FlowLens (This Repo) |Beyond-FoV|wo| 37.65 | 0.9927 | 0.024 | **0.033** |
132 | | FlowLens+ (This Repo) |Beyond-FoV|with| 38.74 | 0.9941 | 0.017 | 0.095 |
133 | | FlowLens (This Repo) |Video Inpainting|wo| 38.38 | 0.9939 | 0.018 | 0.086 |
134 | | FlowLens+ (This Repo) |Video Inpainting|with| **39.40** | **0.9950** | **0.015** | 0.265 |
135 | ###
136 |
137 | #### KITTI360EX-OuterPinhole
138 | | Method | Test Logic | TTA | PSNR | SSIM | VFID | Runtime (s/frame) |
139 | | :--------- | :----------: | :----------: | :----------: | :--------: | :---------: | :------------: |
140 | | _FlowLens-S (Paper)_ |_Beyond-FoV_|_wo_| _19.68_ | _0.9247_ | _0.300_ | _0.023_ |
141 | | FlowLens-S (This Repo) |Beyond-FoV|wo| 20.41 | 0.9332 | 0.285 | **0.021** |
142 | | FlowLens-S+ (This Repo) |Beyond-FoV|with| 21.30 | 0.9397 | 0.302 | 0.056 |
143 | | FlowLens-S (This Repo) |Video Inpainting|wo| 21.69 | 0.9453 | **0.245** | 0.048 |
144 | | FlowLens-S+ (This Repo) |Video Inpainting|with| **22.40** | **0.9503** | 0.271 | 0.146 |
145 |
146 | | Method | Test Logic | TTA | PSNR | SSIM | VFID | Runtime (s/frame) |
147 | | :--------- | :----------: | :----------: | :----------: | :--------: | :---------: | :------------: |
148 | | _FlowLens (Paper)_ |_Beyond-FoV_|_wo_| _20.13_ | _0.9314_ | _0.281_ | _0.049_ |
149 | | FlowLens (This Repo) |Beyond-FoV|wo| 20.85 | 0.9381 | 0.259 | **0.035** |
150 | | FlowLens+ (This Repo) |Beyond-FoV|with| 21.65 | 0.9432 | 0.276 | 0.097 |
151 | | FlowLens (This Repo) |Video Inpainting|wo| 22.23 | 0.9507 | **0.231** | 0.085 |
152 | | FlowLens+ (This Repo) |Video Inpainting|with| **22.86** | **0.9543** | 0.253 | 0.260 |
153 |
154 | Note that when using the ''Video Inpainting'' logic for output,
155 | the model is allowed to use more reference frames from the future,
156 | and each local frame is estimated at least twice.
157 | Higher accuracy can thus be obtained at the cost of slower inference,
158 | which makes this setting unrealistic for real-world deployment.
159 |
160 | ### Citation
161 |
162 | If you find our paper or repo useful, please consider citing our paper:
163 |
164 | ```bibtex
165 | @article{shi2022flowlens,
166 | title={FlowLens: Seeing Beyond the FoV via Flow-guided Clip-Recurrent Transformer},
167 | author={Shi, Hao and Jiang, Qi and Yang, Kailun and Yin, Xiaoting and Wang, Kaiwei},
168 | journal={arXiv preprint arXiv:2211.11293},
169 | year={2022}
170 | }
171 | ```
172 | ### Acknowledgement
173 | This project would not have been possible without the following outstanding repositories:
174 |
175 | [STTN](https://github.com/researchmm/STTN), [MMFlow](https://github.com/open-mmlab/mmflow)
176 |
177 |
178 | ### Devs
179 | Hao Shi
180 |
181 | ### Contact
182 | Feel free to contact me if you have additional questions or are interested in collaboration. Please drop me an email at haoshi@zju.edu.cn. =)
183 |
--------------------------------------------------------------------------------
/assets/breakdance.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasterHow/FlowLens/252568a423c89bb83d188cfad2e962a3d3423ee3/assets/breakdance.gif
--------------------------------------------------------------------------------
/assets/flowlens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasterHow/FlowLens/252568a423c89bb83d188cfad2e962a3d3423ee3/assets/flowlens.png
--------------------------------------------------------------------------------
/assets/in_beyond.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasterHow/FlowLens/252568a423c89bb83d188cfad2e962a3d3423ee3/assets/in_beyond.gif
--------------------------------------------------------------------------------
/assets/out_beyond.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasterHow/FlowLens/252568a423c89bb83d188cfad2e962a3d3423ee3/assets/out_beyond.gif
--------------------------------------------------------------------------------
/configs/KITTI360EX-I_FlowLens.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 2023,
3 | "save_dir": "release_model/",
4 | "train_data_loader": {
5 | "name": "KITTI360-EX",
6 | "data_root": "datasets//KITTI-360EX//InnerSphere",
7 | "w": 336,
8 | "h": 336,
9 | "num_local_frames": 5,
10 | "num_ref_frames": 3
11 | },
12 | "losses": {
13 | "hole_weight": 1,
14 | "valid_weight": 1,
15 | "flow_weight": 1,
16 | "adversarial_weight": 0.01,
17 | "GAN_LOSS": "hinge"
18 | },
19 | "model": {
20 | "net": "flowlens",
21 | "no_dis": 0,
22 | "depths": 9,
23 | "window_size": [7, 7],
24 | "output_size": [84, 84],
25 | "small_model": 0
26 | },
27 | "trainer": {
28 | "type": "Adam",
29 | "beta1": 0,
30 | "beta2": 0.99,
31 | "lr": 0.25e-4,
32 | "batch_size": 2,
33 | "num_workers": 8,
34 | "log_freq": 100,
35 | "save_freq": 5e3,
36 | "iterations": 50e4,
37 | "scheduler": {
38 | "type": "MultiStepLR",
39 | "milestones": [
40 | 40e4
41 | ],
42 | "gamma": 0.1
43 | }
44 | }
45 | }
--------------------------------------------------------------------------------
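
A note on these configs: values such as `"iterations": 50e4` are valid JSON numbers but parse as Python floats. Below is a minimal sketch (not repo code; the actual consumption happens in train.py, which is not shown here) of loading a config and casting step counts back to int:

```python
import json

# Load a FlowLens config and read the trainer settings. json parses
# scientific notation such as 50e4 as float, so cast any value used as a
# step count to int before handing it to a scheduler.
with open('configs/KITTI360EX-I_FlowLens.json', 'r') as f:
    config = json.load(f)

trainer_cfg = config['trainer']
iterations = int(trainer_cfg['iterations'])  # 500000
milestones = [int(m) for m in trainer_cfg['scheduler']['milestones']]  # [400000]
print(trainer_cfg['type'], trainer_cfg['lr'], iterations, milestones)
```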
/configs/KITTI360EX-I_FlowLens_re.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 2023,
3 | "save_dir": "release_model/",
4 | "train_data_loader": {
5 | "name": "KITTI360-EX",
6 | "data_root": "/workspace/mnt/storage/shihao/VI/KITTI-360EX/InnerSphere",
7 | "w": 336,
8 | "h": 336,
9 | "num_local_frames": 10,
10 | "num_ref_frames": 3
11 | },
12 | "losses": {
13 | "hole_weight": 1,
14 | "valid_weight": 1,
15 | "flow_weight": 1,
16 | "adversarial_weight": 0.01,
17 | "GAN_LOSS": "hinge"
18 | },
19 | "model": {
20 | "net": "flowlens",
21 | "no_dis": 0,
22 | "depths": 9,
23 | "window_size": [7, 7],
24 | "output_size": [84, 84],
25 | "small_model": 0
26 | },
27 | "trainer": {
28 | "type": "Adam",
29 | "beta1": 0,
30 | "beta2": 0.99,
31 | "lr": 1e-4,
32 | "batch_size": 8,
33 | "num_workers": 4,
34 | "log_freq": 100,
35 | "save_freq": 5e3,
36 | "iterations": 50e4,
37 | "scheduler": {
38 | "type": "MultiStepLR",
39 | "milestones": [
40 | 40e4
41 | ],
42 | "gamma": 0.1
43 | }
44 | }
45 | }
--------------------------------------------------------------------------------
/configs/KITTI360EX-I_FlowLens_small.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 2023,
3 | "save_dir": "release_model/",
4 | "train_data_loader": {
5 | "name": "KITTI360-EX",
6 | "data_root": "/workspace/mnt/storage/shihao/VI/KITTI-360EX/InnerSphere",
7 | "w": 336,
8 | "h": 336,
9 | "num_local_frames": 5,
10 | "num_ref_frames": 3,
11 | "random_mask": 1
12 | },
13 | "losses": {
14 | "hole_weight": 1,
15 | "valid_weight": 1,
16 | "flow_weight": 1,
17 | "adversarial_weight": 0.01,
18 | "GAN_LOSS": "hinge"
19 | },
20 | "model": {
21 | "net": "flowlens",
22 | "no_dis": 0,
23 | "spy_net": 1,
24 | "mfn_teach": 1,
25 | "freeze_dcn": 0,
26 | "depths": 5,
27 | "window_size": [7, 7],
28 | "output_size": [84, 84],
29 | "small_model": 1
30 | },
31 | "trainer": {
32 | "type": "Adam",
33 | "beta1": 0,
34 | "beta2": 0.99,
35 | "lr": 0.25e-4,
36 | "batch_size": 2,
37 | "num_workers": 2,
38 | "log_freq": 100,
39 | "save_freq": 5e3,
40 | "iterations": 50e4,
41 | "scheduler": {
42 | "type": "MultiStepLR",
43 | "milestones": [
44 | 40e4
45 | ],
46 | "gamma": 0.1
47 | }
48 | }
49 | }
--------------------------------------------------------------------------------
/configs/KITTI360EX-I_FlowLens_small_re.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 2023,
3 | "save_dir": "release_model/",
4 | "train_data_loader": {
5 | "name": "KITTI360-EX",
6 | "data_root": "/workspace/mnt/storage/shihao/VI/KITTI-360EX/InnerSphere",
7 | "w": 336,
8 | "h": 336,
9 | "num_local_frames": 10,
10 | "num_ref_frames": 3,
11 | "random_mask": 0
12 | },
13 | "losses": {
14 | "hole_weight": 1,
15 | "valid_weight": 1,
16 | "flow_weight": 1,
17 | "adversarial_weight": 0.01,
18 | "GAN_LOSS": "hinge"
19 | },
20 | "model": {
21 | "net": "flowlens",
22 | "no_dis": 0,
23 | "spy_net": 1,
24 | "mfn_teach": 1,
25 | "freeze_dcn": 0,
26 | "depths": 5,
27 | "window_size": [7, 7],
28 | "output_size": [84, 84],
29 | "small_model": 1
30 | },
31 | "trainer": {
32 | "type": "Adam",
33 | "beta1": 0,
34 | "beta2": 0.99,
35 | "lr": 1e-4,
36 | "batch_size": 1,
37 | "num_workers": 4,
38 | "log_freq": 100,
39 | "save_freq": 5e3,
40 | "iterations": 50e4,
41 | "scheduler": {
42 | "type": "MultiStepLR",
43 | "milestones": [
44 | 40e4
45 | ],
46 | "gamma": 0.1
47 | }
48 | }
49 | }
--------------------------------------------------------------------------------
/configs/KITTI360EX-O_FlowLens.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 2023,
3 | "save_dir": "release_model/",
4 | "train_data_loader": {
5 | "name": "KITTI360-EX",
6 | "data_root": "datasets//KITTI-360EX//OuterPinhole",
7 | "w": 432,
8 | "h": 240,
9 | "num_local_frames": 5,
10 | "num_ref_frames": 3
11 | },
12 | "losses": {
13 | "hole_weight": 1,
14 | "valid_weight": 1,
15 | "flow_weight": 1,
16 | "adversarial_weight": 0.01,
17 | "GAN_LOSS": "hinge"
18 | },
19 | "model": {
20 | "net": "flowlens",
21 | "no_dis": 0,
22 | "depths": 9,
23 | "window_size": 0,
24 | "output_size": 0,
25 | "small_model": 0
26 | },
27 | "trainer": {
28 | "type": "Adam",
29 | "beta1": 0,
30 | "beta2": 0.99,
31 | "lr": 0.25e-4,
32 | "batch_size": 2,
33 | "num_workers": 8,
34 | "log_freq": 100,
35 | "save_freq": 5e3,
36 | "iterations": 50e4,
37 | "scheduler": {
38 | "type": "MultiStepLR",
39 | "milestones": [
40 | 40e4
41 | ],
42 | "gamma": 0.1
43 | }
44 | }
45 | }
--------------------------------------------------------------------------------
/configs/KITTI360EX-O_FlowLens_re.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 2023,
3 | "save_dir": "release_model/",
4 | "train_data_loader": {
5 | "name": "KITTI360-EX",
6 | "data_root": "/workspace/mnt/storage/shihao/VI/KITTI-360EX/OuterPinhole",
7 | "w": 432,
8 | "h": 240,
9 | "num_local_frames": 11,
10 | "num_ref_frames": 3
11 | },
12 | "losses": {
13 | "hole_weight": 1,
14 | "valid_weight": 1,
15 | "flow_weight": 1,
16 | "adversarial_weight": 0.01,
17 | "GAN_LOSS": "hinge"
18 | },
19 | "model": {
20 | "net": "flowlens",
21 | "no_dis": 0,
22 | "depths": 9,
23 | "window_size": 0,
24 | "output_size": 0,
25 | "small_model": 0
26 | },
27 | "trainer": {
28 | "type": "Adam",
29 | "beta1": 0,
30 | "beta2": 0.99,
31 | "lr": 1e-4,
32 | "batch_size": 8,
33 | "num_workers": 4,
34 | "log_freq": 100,
35 | "save_freq": 5e3,
36 | "iterations": 50e4,
37 | "scheduler": {
38 | "type": "MultiStepLR",
39 | "milestones": [
40 | 40e4
41 | ],
42 | "gamma": 0.1
43 | }
44 | }
45 | }
--------------------------------------------------------------------------------
/configs/KITTI360EX-O_FlowLens_small.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 2023,
3 | "save_dir": "release_model/",
4 | "train_data_loader": {
5 | "name": "KITTI360-EX",
6 | "data_root": "/workspace/mnt/storage/shihao/VI/KITTI-360EX/OuterPinhole",
7 | "w": 432,
8 | "h": 240,
9 | "num_local_frames": 5,
10 | "num_ref_frames": 3
11 | },
12 | "losses": {
13 | "hole_weight": 1,
14 | "valid_weight": 1,
15 | "flow_weight": 1,
16 | "adversarial_weight": 0.01,
17 | "GAN_LOSS": "hinge"
18 | },
19 | "model": {
20 | "net": "flowlens",
21 | "no_dis": 0,
22 | "spy_net": 1,
23 | "mfn_teach": 1,
24 | "depths": 5,
25 | "window_size": 0,
26 | "output_size": 0,
27 | "small_model": 1
28 | },
29 | "trainer": {
30 | "type": "Adam",
31 | "beta1": 0,
32 | "beta2": 0.99,
33 | "lr": 0.25e-4,
34 | "batch_size": 2,
35 | "num_workers": 2,
36 | "log_freq": 100,
37 | "save_freq": 5e3,
38 | "iterations": 50e4,
39 | "scheduler": {
40 | "type": "MultiStepLR",
41 | "milestones": [
42 | 40e4
43 | ],
44 | "gamma": 0.1
45 | }
46 | }
47 | }
--------------------------------------------------------------------------------
/configs/KITTI360EX-O_FlowLens_small_re.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 2023,
3 | "save_dir": "release_model/",
4 | "train_data_loader": {
5 | "name": "KITTI360-EX",
6 | "data_root": "/workspace/mnt/storage/shihao/VI/KITTI-360EX/OuterPinhole",
7 | "w": 432,
8 | "h": 240,
9 | "num_local_frames": 11,
10 | "num_ref_frames": 3
11 | },
12 | "losses": {
13 | "hole_weight": 1,
14 | "valid_weight": 1,
15 | "flow_weight": 1,
16 | "adversarial_weight": 0.01,
17 | "GAN_LOSS": "hinge"
18 | },
19 | "model": {
20 | "net": "flowlens",
21 | "no_dis": 0,
22 | "spy_net": 1,
23 | "mfn_teach": 1,
24 | "depths": 5,
25 | "window_size": null,
26 | "output_size": null,
27 | "small_model": 1
28 | },
29 | "trainer": {
30 | "type": "Adam",
31 | "beta1": 0,
32 | "beta2": 0.99,
33 | "lr": 1e-4,
34 | "batch_size": 8,
35 | "num_workers": 4,
36 | "log_freq": 100,
37 | "save_freq": 5e3,
38 | "iterations": 50e4,
39 | "scheduler": {
40 | "type": "MultiStepLR",
41 | "milestones": [
42 | 40e4
43 | ],
44 | "gamma": 0.1
45 | }
46 | }
47 | }
--------------------------------------------------------------------------------
/core/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 |
5 | import cv2
6 | from PIL import Image
7 | import numpy as np
8 |
9 | import torch
10 | import torchvision.transforms as transforms
11 |
12 | from core.utils import (TrainZipReader, TestZipReader,
13 | create_random_shape_with_random_motion, Stack,
14 | ToTorchFormatTensor, GroupRandomHorizontalFlip,
15 | create_random_shape_with_random_motion_seq)
16 | from core.dist import get_world_size
17 |
18 |
19 | class TrainDataset(torch.utils.data.Dataset):
20 | """
21 | Sequence Video Train Dataloader by Hao, based on E2FGVI train loader.
22 | same_mask(bool): If True, use same mask until video changes to the next video.
23 | random_mask(bool): If True, use random mask rather than FoV mask for KITTI360.
24 | """
25 | def __init__(self, args: dict, debug=False, start=0, end=1, batch_size=1, same_mask=False):
26 | self.args = args
27 | self.num_local_frames = args['num_local_frames']
28 | self.num_ref_frames = args['num_ref_frames']
29 | self.size = self.w, self.h = (args['w'], args['h'])
30 |
31 | # Whether to use random masks for KITTI360
32 | args['random_mask'] = args.get('random_mask', 0)
33 | if args['random_mask'] != 0:
34 | self.random_mask = True
35 | else:
36 | # default
37 | self.random_mask = False
38 |
39 | if args['name'] != 'KITTI360-EX':
40 | json_path = os.path.join(args['data_root'], args['name'], 'train.json')
41 | self.dataset_name = args['name']
42 | else:
43 | json_path = os.path.join(args['data_root'], 'train.json')
44 | self.dataset_name = 'KITTI360-EX'
45 |
46 | with open(json_path, 'r') as f:
47 | self.video_dict = json.load(f)
48 | self.video_names = list(self.video_dict.keys())
49 | if args['name'] == 'KITTI360-EX':
50 | # Shuffle the video order to prevent overfitting
51 | from random import shuffle
52 | shuffle(self.video_names)
53 | if debug:
54 | self.video_names = self.video_names[:100]
55 |
56 | self._to_tensors = transforms.Compose([
57 | Stack(),
58 | ToTorchFormatTensor(),
59 | ])
60 |
61 | # self.neighbor_stride = 5  # sample one group of LF and NLF every 5 steps; this may be a sampling bug
62 | self.neighbor_stride = self.num_local_frames # sample one group of LF and NLF every window-size steps
63 | self.start_index = 0 # sampling start point
64 | self.start = start # start of dataset iteration
65 | self.end = len(self.video_names) # end of dataset iteration
66 | self.batch_size = batch_size # used to compute the dataset iterator
67 |
68 | # Multi-GPU DDP training needs to know the data's local rank
69 | self.world_size = get_world_size()
70 |
71 | # With multi-GPU DDP each GPU runs an independent process, so the per-GPU batch size is the total divided by the number of GPUs
72 | if self.world_size > 1:
73 | self.batch_size = self.batch_size // self.world_size
74 |
75 | self.index = 0 # custom video index
76 | self.batch_buffer = 0 # advances the video index along with the batch size
77 | self.new_video_flag = False # flags whether a new video has started
78 | self.worker_group = 0 # updated with each group of workers
79 |
80 | self.same_mask = same_mask # if True, reuse the same mask until switching to a new video
81 | if self.same_mask:
82 | self.random_dict_list = []
83 | self.new_mask_list = []
84 | for i in range(0, self.batch_size):
85 | # stores the parameter dict of the random mask; differs across batches
86 | self.random_dict_list.append(None)
87 | # when set to True, the mask is randomly regenerated
88 | self.new_mask_list.append(False)
89 |
90 | # Create an independent video index, start_index, and worker_group for each batch
91 | self.video_index_list = []
92 | for i in range(0, self.batch_size):
93 | # Stagger the initial video ids across batches to avoid duplicated data and overfitting
94 | if self.world_size == 1:
95 | # single-GPU logic
96 | self.video_index_list.append(i * len(self.video_names)//self.batch_size)
97 | else:
98 | # multi-GPU logic
99 | # randomly initialize the video index
100 | self.video_index_list.append(random.randint(0, len(self.video_names)))
101 |
102 | self.start_index_list = []
103 | for i in range(0, self.batch_size):
104 | self.start_index_list.append(0)
105 |
106 | self.worker_group_list = []
107 | for i in range(0, self.batch_size):
108 | self.worker_group_list.append(0)
109 |
110 | def __len__(self):
111 | # Video switching is handled manually; define the length as a large number to avoid clashing with the custom index
112 | return len(self.video_names)*1000
113 |
114 | # Consecutive iterations are 5 frames apart; different batch channels hold different videos; the video index and start-frame index are independent across batches to avoid wasting data
115 | def __getitem__(self, index):
116 | worker_info = torch.utils.data.get_worker_info()
117 | if worker_info is None:
118 | # single-process data loading, skip idx rearrange
119 | print('Warning: only one data worker was used; this mode has not been tested!')
120 | pass
121 | else:
122 | if self.start_index_list[self.batch_buffer] == 0 and self.worker_group_list[self.batch_buffer] == 0:
123 | # New video: initialize the start index
124 |
125 | if self.same_mask:
126 | # In this mode, regenerate the mask when switching to a new video
127 | # Only the first worker should generate the new mask; the other workers should reuse the same one
128 | if worker_info.id == 0:
129 | self.random_dict_list[self.batch_buffer] = None
130 | self.new_mask_list[self.batch_buffer] = True
131 |
132 | # Check whether the start index exceeds the video length
133 | if (self.neighbor_stride * worker_info.id) <= self.video_dict[self.video_names[self.video_index_list[self.batch_buffer]]]:
134 | self.start_index_list[self.batch_buffer] = self.neighbor_stride * worker_info.id
135 | else:
136 | # If it does, wait until the first worker also exceeds it, and set the start index to the start id of the nearest earlier worker that did not exceed it
137 | # Only advance to the next video after the first worker exceeds it, to prevent the start ids of different workers from getting misaligned
138 | # Check whether the earlier workers exceeded it, and substitute the value of the nearest worker that did not
139 | out_list = []
140 | final_worker_idx = 0
141 | for previous_worker in range(0, worker_info.num_workers):
142 | out_list.append(
143 | ((self.neighbor_stride * previous_worker) <= self.video_dict[self.video_names[self.video_index_list[self.batch_buffer]]])
144 | )
145 | out_list.reverse()
146 | final_worker_idx = worker_info.num_workers - 1 - out_list.index(True)
147 | self.start_index_list[self.batch_buffer] = self.neighbor_stride * final_worker_idx
148 |
149 | else:
150 | # Not a new video
151 | if self.same_mask:
152 | # In this mode it is not a new video, so reuse the previous mask parameters
153 | self.new_mask_list[self.batch_buffer] = False
154 |
155 | # Check whether the start index exceeds the video length
156 | if (self.start_index_list[self.batch_buffer] + self.neighbor_stride * worker_info.num_workers) <= self.video_dict[self.video_names[self.video_index_list[self.batch_buffer]]]:
157 | self.start_index_list[self.batch_buffer] += self.neighbor_stride * worker_info.num_workers
158 | else:
159 | # If exceeded, switch to the next video and reset the start index to 0 (each worker's start still differs)
160 | # If the first worker has not exceeded it, copy the start id of the previous worker and wait for the first worker to exceed it before moving to the next video,
161 | # to prevent the start ids of different workers from getting misaligned
162 | if (self.start_index_list[self.batch_buffer] + self.neighbor_stride * (worker_info.num_workers - worker_info.id)) <= self.video_dict[self.video_names[self.video_index_list[self.batch_buffer]]]:
163 | # Check whether the earlier workers exceeded it, and substitute the value of the nearest worker that did not
164 | out_list = []
165 | final_worker_idx = 0
166 | for previous_worker in range(0, worker_info.num_workers):
167 | out_list.append(
168 | ((self.start_index_list[self.batch_buffer] + self.neighbor_stride * (worker_info.num_workers - worker_info.id + previous_worker)) <= self.video_dict[self.video_names[self.video_index_list[self.batch_buffer]]])
169 | )
170 | out_list.reverse()
171 | final_worker_idx = worker_info.num_workers - 1 - out_list.index(True)
172 | self.start_index_list[self.batch_buffer] = self.start_index_list[self.batch_buffer] + self.neighbor_stride * (worker_info.num_workers - worker_info.id + final_worker_idx)
173 | else:
174 | # If the first worker exceeded it, switch to the next video
175 | self.new_video_flag = True
176 |
177 | if self.same_mask:
178 | # In this mode, regenerate the mask when switching to a new video
179 | # Only the first worker should generate the new mask; the other workers should reuse the same one
180 | if worker_info.id == 0:
181 | self.random_dict_list[self.batch_buffer] = None
182 | self.new_mask_list[self.batch_buffer] = True
183 |
184 | # self.worker_group = 0
185 | self.worker_group_list[self.batch_buffer] = 0
186 | self.start_index_list[self.batch_buffer] = self.neighbor_stride * worker_info.id
187 | # Check whether the video index exceeds the number of videos
188 | if (self.video_index_list[self.batch_buffer] + 1) < len(self.video_names):
189 | self.video_index_list[self.batch_buffer] += 1
190 | else:
191 | # 超出则切换回第一个视频
192 | self.video_index_list[self.batch_buffer] = 0
193 |
194 | # Read frames according to the index and start index
195 | self.index = self.video_index_list[self.batch_buffer]
196 | self.start_index = self.start_index_list[self.batch_buffer]
197 | item = self.load_item_v4()
198 |
199 | # Update the worker-group index
200 | self.worker_group_list[self.batch_buffer] += 1
201 |
202 | self.batch_buffer += 1
203 | if self.batch_buffer == self.batch_size:
204 | # reset the batch buffer
205 | self.batch_buffer = 0
206 |
207 | return item
208 |
209 | def _sample_index_seq(self, length, sample_length, num_ref_frame=3, pivot=0, before_nlf=False):
210 | """
211 |
212 | Args:
213 | length:
214 | sample_length:
215 | num_ref_frame:
216 | pivot:
217 | before_nlf: If True, the non local frames will be sampled only from previous frames, not from future.
218 |
219 | Returns:
220 |
221 | """
222 | complete_idx_set = list(range(length))
223 | local_idx = complete_idx_set[pivot:pivot + sample_length]
224 |
225 | # Ensure the last few frames also return a consistent number of local frames (5, i.e. the stride) so that batch stacking does not fail:
226 | if len(local_idx) < self.neighbor_stride:
227 | for i in range(0, self.neighbor_stride-len(local_idx)):
228 | if local_idx:
229 | local_idx.append(local_idx[-1])
230 | else:
231 | # The video length is an exact multiple of the local-frame stride; append the last frame 5 times
232 | local_idx.append(complete_idx_set[-1])
233 |
234 | if before_nlf:
235 | # Non-local frames are selected only from past video frames; no future information is used
236 | complete_idx_set = complete_idx_set[:pivot + sample_length]
237 |
238 | remain_idx = list(set(complete_idx_set) - set(local_idx))
239 |
240 | # When only past frames serve as non-local frames, there may be fewer past frames than required, e.g. at the start of a video
241 | if before_nlf:
242 | if len(remain_idx) < num_ref_frame:
243 | # In that case, allow sampling non-local frames from the local frames; converting to a set removes duplicates
244 | remain_idx = list(set(remain_idx + local_idx))
245 |
246 | ref_index = sorted(random.sample(remain_idx, num_ref_frame))
247 |
248 | return local_idx + ref_index
249 |
250 | def load_item_v4(self):
251 | """避免dataloader的index和worker的index冲突"""
252 | video_name = self.video_names[self.index]
253 |
254 | # create masks
255 | if self.dataset_name != 'KITTI360-EX':
256 | # 对于非KITTI360-EX数据集,随机创建mask
257 | if not self.same_mask:
258 | # 每次迭代都会生成新形状的随机mask
259 | all_masks = create_random_shape_with_random_motion(
260 | self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w)
261 |
262 | else:
263 | # Use the same mask parameters until switching to a new video
264 | all_masks, random_dict = create_random_shape_with_random_motion_seq(
265 | self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w,
266 | new_mask=self.new_mask_list[self.batch_buffer],
267 | random_dict=self.random_dict_list[self.batch_buffer])
268 | # Update the random-mask parameters
269 | self.random_dict_list[self.batch_buffer] = random_dict
270 |
271 | elif self.random_mask:
272 | # If random_mask is set, create random masks for KITTI360 as well
273 | all_masks = create_random_shape_with_random_motion(
274 | self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w)
275 |
276 | # create sample index
277 | # For FoV-expansion scenarios like KITTI360-EX, non-local frames can only come from past information
278 | if self.dataset_name == 'KITTI360-EX':
279 | before_nlf = True
280 | else:
281 | # By default, video completion may use future information
282 | before_nlf = False
283 | selected_index = self._sample_index_seq(self.video_dict[video_name],
284 | self.num_local_frames,
285 | self.num_ref_frames,
286 | pivot=self.start_index,
287 | before_nlf=before_nlf)
288 |
289 | # read video frames
290 | frames = []
291 | masks = []
292 | for idx in selected_index:
293 | if self.dataset_name != 'KITTI360-EX':
294 | video_path = os.path.join(self.args['data_root'],
295 | self.args['name'], 'JPEGImages',
296 | f'{video_name}.zip')
297 | else:
298 | video_path = os.path.join(self.args['data_root'],
299 | 'JPEGImages',
300 | f'{video_name}.zip')
301 | img = TrainZipReader.imread(video_path, idx).convert('RGB')
302 | img = img.resize(self.size)
303 | frames.append(img)
304 | if self.dataset_name != 'KITTI360-EX':
305 | masks.append(all_masks[idx])
306 | elif self.random_mask:
307 | # Random masks can also be used for the KITTI360-EX dataset
308 | masks.append(all_masks[idx])
309 | else:
310 | # For the KITTI360-EX dataset, read the masks stored in the zip
311 | mask_path = os.path.join(self.args['data_root'],
312 | 'test_masks',
313 | f'{video_name}.zip')
314 | mask = TrainZipReader.imread(mask_path, idx)
315 | mask = mask.resize(self.size).convert('L')
316 | mask = np.asarray(mask)
317 | m = np.array(mask > 0).astype(np.uint8)
318 | mask = Image.fromarray(m * 255)
319 | masks.append(mask)
320 |
321 | # normalize, convert to tensors
322 | frames = GroupRandomHorizontalFlip()(frames)
323 | if self.dataset_name == 'KITTI360-EX':
324 | # Masks read from disk must be flipped together with the frames
325 | if not self.random_mask:
326 | masks = GroupRandomHorizontalFlip()(masks)
327 | frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
328 | mask_tensors = self._to_tensors(masks)
329 |
330 | if not self.same_mask:
331 | # A new random mask is generated each time; no dict needs to be returned
332 | return frame_tensors, mask_tensors, video_name, self.index, self.start_index
333 | else:
334 | # To keep the mask behavior consistent, the dict must be returned
335 | return frame_tensors, mask_tensors, video_name, self.index, self.start_index,\
336 | self.new_mask_list[self.batch_buffer], self.random_dict_list
337 |
338 |
339 | class TestDataset(torch.utils.data.Dataset):
340 | def __init__(self, args):
341 | self.args = args
342 | self.size = self.w, self.h = args.size
343 | try:
344 | # whether to use the object mask
345 | self.object_mask = args.object
346 | except AttributeError:
347 | self.object_mask = False
348 |
349 | if args.dataset != 'KITTI360-EX':
350 | # default manner
351 | json_path = os.path.join(args.data_root, args.dataset, 'test.json')
352 | else:
353 | # for KITTI360-EX
354 | json_path = os.path.join(args.data_root, 'test.json')
355 |
356 | with open(json_path, 'r') as f:
357 | self.video_dict = json.load(f)
358 | self.video_names = list(self.video_dict.keys())
359 |
360 | self._to_tensors = transforms.Compose([
361 | Stack(),
362 | ToTorchFormatTensor(),
363 | ])
364 |
365 | def __len__(self):
366 | return len(self.video_names)
367 |
368 | def __getitem__(self, index):
369 | item = self.load_item(index)
370 | return item
371 |
372 | def load_item(self, index):
373 | video_name = self.video_names[index]
374 | ref_index = list(range(self.video_dict[video_name]))
375 |
376 | # read video frames
377 | frames = []
378 | masks = []
379 | for idx in ref_index:
380 |
381 | # read img from zip
382 | if self.args.dataset != 'KITTI360-EX':
383 | # default manner
384 | video_path = os.path.join(self.args.data_root, self.args.dataset, 'JPEGImages', f'{video_name}.zip')
385 | else:
386 | # for KITTI360-EX
387 | video_path = os.path.join(self.args.data_root, 'JPEGImages', f'{video_name}.zip')
388 |
389 | img = TestZipReader.imread(video_path, idx).convert('RGB')
390 | img = img.resize(self.size)
391 | frames.append(img)
392 |
393 | # read mask from folder
394 | if self.args.dataset != 'KITTI360-EX':
395 | if not self.object_mask:
396 | # default manner for video completion
397 | mask_path = os.path.join(self.args.data_root, self.args.dataset,
398 | 'test_masks', video_name,
399 | str(idx).zfill(5) + '.png')
400 | else:
401 | # default manner for object removal
402 | mask_path = os.path.join(self.args.data_root, self.args.dataset,
403 | 'test_masks_object', video_name,
404 | str(idx).zfill(5) + '.png')
405 | else:
406 | # for KITTI360-EX: use seq 10 for testing
407 | mask_path = os.path.join(self.args.data_root, 'test_masks', 'seq10', self.args.fov, video_name,
408 | str(idx).zfill(6) + '.png')
409 |
410 | mask = Image.open(mask_path).resize(self.size,
411 | Image.NEAREST).convert('L')
412 | # origin: 0 indicates missing. now: 1 indicates missing
413 | mask = np.asarray(mask)
414 | m = np.array(mask > 0).astype(np.uint8)
415 | m = cv2.dilate(m,
416 | cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)),
417 | iterations=4)
418 | mask = Image.fromarray(m * 255)
419 | masks.append(mask)
420 |
421 | # to tensors
422 | frames_PIL = [np.array(f).astype(np.uint8) for f in frames]
423 | frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
424 | mask_tensors = self._to_tensors(masks)
425 | return frame_tensors, mask_tensors, video_name, frames_PIL
426 |
--------------------------------------------------------------------------------
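
To make the sampling logic of _sample_index_seq concrete, here is a toy, standalone rendering of the index math (illustrative only, mirroring the function above, not calling it): a 20-frame video, a 5-frame local window starting at pivot=10, and 3 non-local reference frames drawn only from the past (the before_nlf=True case used for KITTI360-EX).

```python
import random

length, sample_length, num_ref, pivot = 20, 5, 3, 10

complete = list(range(length))
local_idx = complete[pivot:pivot + sample_length]  # [10, 11, 12, 13, 14]
past_only = complete[:pivot + sample_length]       # frames 0..14 only
remain = sorted(set(past_only) - set(local_idx))   # candidate refs: 0..9
ref_idx = sorted(random.sample(remain, num_ref))   # e.g. [2, 5, 9]

print(local_idx + ref_idx)  # local frames first, then sorted references
```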
/core/dist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 |
5 | def get_world_size():
6 | """Find OMPI world size without calling mpi functions
7 | :rtype: int
8 | """
9 | if os.environ.get('PMI_SIZE') is not None:
10 | return int(os.environ.get('PMI_SIZE') or 1)
11 | elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None:
12 | return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
13 | else:
14 | return torch.cuda.device_count()
15 |
16 |
17 | def get_global_rank():
18 | """Find OMPI world rank without calling mpi functions
19 | :rtype: int
20 | """
21 | if os.environ.get('PMI_RANK') is not None:
22 | return int(os.environ.get('PMI_RANK') or 0)
23 | elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None:
24 | return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
25 | else:
26 | return 0
27 |
28 |
29 | def get_local_rank():
30 | """Find OMPI local rank without calling mpi functions
31 | :rtype: int
32 | """
33 | if os.environ.get('MPI_LOCALRANKID') is not None:
34 | return int(os.environ.get('MPI_LOCALRANKID') or 0)
35 | elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None:
36 | return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)
37 | else:
38 | return 0
39 |
40 |
41 | def get_master_ip():
42 | if os.environ.get('AZ_BATCH_MASTER_NODE') is not None:
43 | return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0]
44 | elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None:
45 | return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
46 | else:
47 | return "127.0.0.1"
48 |
--------------------------------------------------------------------------------
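
For context, these helpers are typically consumed when bringing up torch.distributed. A minimal sketch of that pattern, assuming the NCCL backend and an arbitrary port (both assumptions; the repo's actual wiring lives in train.py, which is not shown here):

```python
import torch
import torch.distributed as dist

from core.dist import (get_world_size, get_global_rank, get_local_rank,
                       get_master_ip)

world_size = get_world_size()
if world_size > 1:
    # Each process joins the group using the rank and world size recovered
    # from the PMI/OMPI environment variables parsed above.
    dist.init_process_group(backend='nccl',
                            init_method=f'tcp://{get_master_ip()}:23456',
                            rank=get_global_rank(),
                            world_size=world_size)
    torch.cuda.set_device(get_local_rank())
```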
/core/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class AdversarialLoss(nn.Module):
6 | r"""
7 | Adversarial loss
8 | https://arxiv.org/abs/1711.10337
9 | """
10 | def __init__(self,
11 | type='nsgan',
12 | target_real_label=1.0,
13 | target_fake_label=0.0):
14 | r"""
15 | type = nsgan | lsgan | hinge
16 | """
17 | super(AdversarialLoss, self).__init__()
18 | self.type = type
19 | self.register_buffer('real_label', torch.tensor(target_real_label))
20 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
21 |
22 | if type == 'nsgan':
23 | self.criterion = nn.BCELoss()
24 | elif type == 'lsgan':
25 | self.criterion = nn.MSELoss()
26 | elif type == 'hinge':
27 | self.criterion = nn.ReLU()
28 |
29 | def __call__(self, outputs, is_real, is_disc=None):
30 | if self.type == 'hinge':
31 | if is_disc:
32 | if is_real:
33 | outputs = -outputs
34 | return self.criterion(1 + outputs).mean()
35 | else:
36 | return (-outputs).mean()
37 | else:
38 | labels = (self.real_label
39 | if is_real else self.fake_label).expand_as(outputs)
40 | loss = self.criterion(outputs, labels)
41 | return loss
42 |
--------------------------------------------------------------------------------
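
For reference, a minimal sketch of how the hinge variant of AdversarialLoss is typically applied in a GAN update (illustrative only; the repo's real training step lives in core/trainer.py, which is not shown here, and the configs above weight this term with adversarial_weight = 0.01):

```python
import torch

from core.loss import AdversarialLoss

adv_loss = AdversarialLoss(type='hinge')

# Stand-in discriminator logits for illustration.
real_logits = torch.randn(4, 1)
fake_logits = torch.randn(4, 1)

# Discriminator update: ReLU(1 - D(real)) + ReLU(1 + D(fake)).
d_loss = (adv_loss(real_logits, is_real=True, is_disc=True) +
          adv_loss(fake_logits, is_real=False, is_disc=True)) / 2

# Generator update: -D(fake), i.e. raise the score on generated samples.
g_loss = adv_loss(fake_logits, is_real=True, is_disc=False)
```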
/core/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | """
2 | LR scheduler from BasicSR https://github.com/xinntao/BasicSR
3 | """
4 | import math
5 | from collections import Counter
6 | from torch.optim.lr_scheduler import _LRScheduler
7 |
8 |
9 | class MultiStepRestartLR(_LRScheduler):
10 | """ MultiStep with restarts learning rate scheme.
11 | Args:
12 | optimizer (torch.nn.optimizer): Torch optimizer.
13 | milestones (list): Iterations that will decrease learning rate.
14 | gamma (float): Decrease ratio. Default: 0.1.
15 | restarts (list): Restart iterations. Default: [0].
16 | restart_weights (list): Restart weights at each restart iteration.
17 | Default: [1].
18 | last_epoch (int): Used in _LRScheduler. Default: -1.
19 | """
20 | def __init__(self,
21 | optimizer,
22 | milestones,
23 | gamma=0.1,
24 | restarts=(0, ),
25 | restart_weights=(1, ),
26 | last_epoch=-1):
27 | self.milestones = Counter(milestones)
28 | self.gamma = gamma
29 | self.restarts = restarts
30 | self.restart_weights = restart_weights
31 | assert len(self.restarts) == len(
32 | self.restart_weights), 'restarts and their weights do not match.'
33 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
34 |
35 | def get_lr(self):
36 | if self.last_epoch in self.restarts:
37 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
38 | return [
39 | group['initial_lr'] * weight
40 | for group in self.optimizer.param_groups
41 | ]
42 | if self.last_epoch not in self.milestones:
43 | return [group['lr'] for group in self.optimizer.param_groups]
44 | return [
45 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
46 | for group in self.optimizer.param_groups
47 | ]
48 |
49 |
50 | def get_position_from_periods(iteration, cumulative_period):
51 | """Get the position from a period list.
52 | It will return the index of the right-closest number in the period list.
53 | For example, the cumulative_period = [100, 200, 300, 400],
54 | if iteration == 50, return 0;
55 | if iteration == 210, return 2;
56 | if iteration == 300, return 2.
57 | Args:
58 | iteration (int): Current iteration.
59 | cumulative_period (list[int]): Cumulative period list.
60 | Returns:
61 | int: The position of the right-closest number in the period list.
62 | """
63 | for i, period in enumerate(cumulative_period):
64 | if iteration <= period:
65 | return i
66 |
67 |
68 | class CosineAnnealingRestartLR(_LRScheduler):
69 | """ Cosine annealing with restarts learning rate scheme.
70 | An example of config:
71 | periods = [10, 10, 10, 10]
72 | restart_weights = [1, 0.5, 0.5, 0.5]
73 | eta_min=1e-7
74 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
75 | scheduler will restart with the weights in restart_weights.
76 | Args:
77 | optimizer (torch.nn.optimizer): Torch optimizer.
78 | periods (list): Period for each cosine annealing cycle.
79 | restart_weights (list): Restart weights at each restart iteration.
80 | Default: [1].
81 | eta_min (float): The minimum lr. Default: 1e-7.
82 | last_epoch (int): Used in _LRScheduler. Default: -1.
83 | """
84 | def __init__(self,
85 | optimizer,
86 | periods,
87 | restart_weights=(1, ),
88 | eta_min=1e-7,
89 | last_epoch=-1):
90 | self.periods = periods
91 | self.restart_weights = restart_weights
92 | self.eta_min = eta_min
93 | assert (len(self.periods) == len(self.restart_weights)
94 | ), 'periods and restart_weights should have the same length.'
95 | self.cumulative_period = [
96 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
97 | ]
98 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
99 |
100 | def get_lr(self):
101 | idx = get_position_from_periods(self.last_epoch,
102 | self.cumulative_period)
103 | current_weight = self.restart_weights[idx]
104 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
105 | current_period = self.periods[idx]
106 |
107 | return [
108 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
109 | (1 + math.cos(math.pi * (
110 | (self.last_epoch - nearest_restart) / current_period)))
111 | for base_lr in self.base_lrs
112 | ]
113 |
--------------------------------------------------------------------------------
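
A minimal sketch of CosineAnnealingRestartLR using the example values from its own docstring (illustrative; the FlowLens configs above actually use MultiStepLR):

```python
import torch

from core.lr_scheduler import CosineAnnealingRestartLR

# A dummy parameter and optimizer just to drive the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-4)

# Four cosine cycles of 10 iterations each; the lr restarts at iterations
# 10, 20, 30 with the given weights and anneals toward eta_min in between.
scheduler = CosineAnnealingRestartLR(optimizer,
                                     periods=[10, 10, 10, 10],
                                     restart_weights=[1, 0.5, 0.5, 0.5],
                                     eta_min=1e-7)

for step in range(40):
    optimizer.step()
    scheduler.step()
```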
/core/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import skimage
3 | import skimage.metrics
4 | from skimage import measure
5 | from scipy import linalg
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from core.utils import to_tensors
12 |
13 |
14 | def calculate_epe(flow1, flow2):
15 | """Calculate End point errors."""
16 |
17 | epe = torch.sum((flow1 - flow2)**2, dim=1).sqrt()
18 | epe = epe.view(-1)
19 | return epe.mean().item()
20 |
21 |
22 | def calculate_psnr(img1, img2):
23 | """Calculate PSNR (Peak Signal-to-Noise Ratio).
24 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
25 | Args:
26 | img1 (ndarray): Images with range [0, 255].
27 | img2 (ndarray): Images with range [0, 255].
28 | Returns:
29 | float: psnr result.
30 | """
31 |
32 | assert img1.shape == img2.shape, \
33 | (f'Image shapes are different: {img1.shape}, {img2.shape}.')
34 |
35 | mse = np.mean((img1 - img2)**2)
36 | if mse == 0:
37 | return float('inf')
38 | return 20. * np.log10(255. / np.sqrt(mse))
39 |
40 |
41 | def calc_psnr_and_ssim(img1, img2):
42 | """Calculate PSNR and SSIM for images.
43 | img1: ndarray, range [0, 255]
44 | img2: ndarray, range [0, 255]
45 | """
46 | img1 = img1.astype(np.float64)
47 | img2 = img2.astype(np.float64)
48 |
49 | psnr = calculate_psnr(img1, img2)
50 | if hasattr(skimage.metrics, 'structural_similarity'):
51 | # new version skimage
52 | ssim = skimage.metrics.structural_similarity(img1,
53 | img2,
54 | data_range=255,
55 | multichannel=True,
56 | win_size=65)
57 | else:
58 | # old version skimage
59 | ssim = measure.compare_ssim(img1,
60 | img2,
61 | data_range=255,
62 | multichannel=True,
63 | win_size=65)
65 | return psnr, ssim
66 |
67 |
68 | ###########################
69 | # I3D models
70 | ###########################
71 |
72 |
73 | def init_i3d_model():
74 | i3d_model_path = './release_model/i3d_rgb_imagenet.pt'
75 | print(f"[Loading I3D model from {i3d_model_path} for FID score ..]")
76 | i3d_model = InceptionI3d(400, in_channels=3, final_endpoint='Logits')
77 | i3d_model.load_state_dict(torch.load(i3d_model_path))
78 | i3d_model.to(torch.device('cuda:0'))
79 | return i3d_model
80 |
81 |
82 | def calculate_i3d_activations(video1, video2, i3d_model, device):
83 | """Calculate VFID metric.
84 | video1: list[PIL.Image]
85 | video2: list[PIL.Image]
86 | """
87 | video1 = to_tensors()(video1).unsqueeze(0).to(device)
88 | video2 = to_tensors()(video2).unsqueeze(0).to(device)
89 | video1_activations = get_i3d_activations(
90 | video1, i3d_model).cpu().numpy().flatten()
91 | video2_activations = get_i3d_activations(
92 | video2, i3d_model).cpu().numpy().flatten()
93 |
94 | return video1_activations, video2_activations
95 |
96 |
97 | def calculate_vfid(real_activations, fake_activations):
98 | """
99 | Given two distribution of features, compute the FID score between them
100 | Params:
101 | real_activations: list[ndarray]
102 | fake_activations: list[ndarray]
103 | """
104 | m1 = np.mean(real_activations, axis=0)
105 | m2 = np.mean(fake_activations, axis=0)
106 | s1 = np.cov(real_activations, rowvar=False)
107 | s2 = np.cov(fake_activations, rowvar=False)
108 | return calculate_frechet_distance(m1, s1, m2, s2)
109 |
110 |
111 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
112 | """Numpy implementation of the Frechet Distance.
113 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
114 | and X_2 ~ N(mu_2, C_2) is
115 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
116 | Stable version by Dougal J. Sutherland.
117 | Params:
118 | -- mu1 : Numpy array containing the activations of a layer of the
119 | inception net (like returned by the function 'get_predictions')
120 | for generated samples.
121 | -- mu2 : The sample mean over activations, precalculated on an
122 | representative data set.
123 | -- sigma1: The covariance matrix over activations for generated samples.
124 | -- sigma2: The covariance matrix over activations, precalculated on an
125 | representative data set.
126 | Returns:
127 | -- : The Frechet Distance.
128 | """
129 |
130 | mu1 = np.atleast_1d(mu1)
131 | mu2 = np.atleast_1d(mu2)
132 |
133 | sigma1 = np.atleast_2d(sigma1)
134 | sigma2 = np.atleast_2d(sigma2)
135 |
136 | assert mu1.shape == mu2.shape, \
137 | 'Training and test mean vectors have different lengths'
138 | assert sigma1.shape == sigma2.shape, \
139 | 'Training and test covariances have different dimensions'
140 |
141 | diff = mu1 - mu2
142 |
143 | # Product might be almost singular
144 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
145 | if not np.isfinite(covmean).all():
146 | msg = ('fid calculation produces singular product; '
147 | 'adding %s to diagonal of cov estimates') % eps
148 | print(msg)
149 | offset = np.eye(sigma1.shape[0]) * eps
150 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
151 |
152 | # Numerical error might give slight imaginary component
153 | if np.iscomplexobj(covmean):
154 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
155 | m = np.max(np.abs(covmean.imag))
156 | raise ValueError('Imaginary component {}'.format(m))
157 | covmean = covmean.real
158 |
159 | tr_covmean = np.trace(covmean)
160 |
161 | return (diff.dot(diff) + np.trace(sigma1) + # NOQA
162 | np.trace(sigma2) - 2 * tr_covmean)
163 |
164 |
165 | def get_i3d_activations(batched_video,
166 | i3d_model,
167 | target_endpoint='Logits',
168 | flatten=True,
169 | grad_enabled=False):
170 | """
171 | Get features from i3d model and flatten them to 1d feature,
172 | valid target endpoints are defined in InceptionI3d.VALID_ENDPOINTS
173 | VALID_ENDPOINTS = (
174 | 'Conv3d_1a_7x7',
175 | 'MaxPool3d_2a_3x3',
176 | 'Conv3d_2b_1x1',
177 | 'Conv3d_2c_3x3',
178 | 'MaxPool3d_3a_3x3',
179 | 'Mixed_3b',
180 | 'Mixed_3c',
181 | 'MaxPool3d_4a_3x3',
182 | 'Mixed_4b',
183 | 'Mixed_4c',
184 | 'Mixed_4d',
185 | 'Mixed_4e',
186 | 'Mixed_4f',
187 | 'MaxPool3d_5a_2x2',
188 | 'Mixed_5b',
189 | 'Mixed_5c',
190 | 'Logits',
191 | 'Predictions',
192 | )
193 | """
194 | with torch.set_grad_enabled(grad_enabled):
195 | feat = i3d_model.extract_features(batched_video.transpose(1, 2),
196 | target_endpoint)
197 | if flatten:
198 | feat = feat.view(feat.size(0), -1)
199 |
200 | return feat
201 |
202 |
203 | # This code is from https://github.com/piergiaj/pytorch-i3d/blob/master/pytorch_i3d.py
204 | # I only fix flake8 errors and do some cleaning here
205 |
206 |
207 | class MaxPool3dSamePadding(nn.MaxPool3d):
208 | def compute_pad(self, dim, s):
209 | if s % self.stride[dim] == 0:
210 | return max(self.kernel_size[dim] - self.stride[dim], 0)
211 | else:
212 | return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)
213 |
214 | def forward(self, x):
215 | # compute 'same' padding
216 | (batch, channel, t, h, w) = x.size()
217 | pad_t = self.compute_pad(0, t)
218 | pad_h = self.compute_pad(1, h)
219 | pad_w = self.compute_pad(2, w)
220 |
221 | pad_t_f = pad_t // 2
222 | pad_t_b = pad_t - pad_t_f
223 | pad_h_f = pad_h // 2
224 | pad_h_b = pad_h - pad_h_f
225 | pad_w_f = pad_w // 2
226 | pad_w_b = pad_w - pad_w_f
227 |
228 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
229 | x = F.pad(x, pad)
230 | return super(MaxPool3dSamePadding, self).forward(x)
231 |
232 |
233 | class Unit3D(nn.Module):
234 | def __init__(self,
235 | in_channels,
236 | output_channels,
237 | kernel_shape=(1, 1, 1),
238 | stride=(1, 1, 1),
239 | padding=0,
240 | activation_fn=F.relu,
241 | use_batch_norm=True,
242 | use_bias=False,
243 | name='unit_3d'):
244 | """Initializes Unit3D module."""
245 | super(Unit3D, self).__init__()
246 |
247 | self._output_channels = output_channels
248 | self._kernel_shape = kernel_shape
249 | self._stride = stride
250 | self._use_batch_norm = use_batch_norm
251 | self._activation_fn = activation_fn
252 | self._use_bias = use_bias
253 | self.name = name
254 | self.padding = padding
255 |
256 | self.conv3d = nn.Conv3d(
257 | in_channels=in_channels,
258 | out_channels=self._output_channels,
259 | kernel_size=self._kernel_shape,
260 | stride=self._stride,
261 | padding=0, # we always want padding to be 0 here. We will
262 | # dynamically pad based on input size in forward function
263 | bias=self._use_bias)
264 |
265 | if self._use_batch_norm:
266 | self.bn = nn.BatchNorm3d(self._output_channels,
267 | eps=0.001,
268 | momentum=0.01)
269 |
270 | def compute_pad(self, dim, s):
271 | if s % self._stride[dim] == 0:
272 | return max(self._kernel_shape[dim] - self._stride[dim], 0)
273 | else:
274 | return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)
275 |
276 | def forward(self, x):
277 | # compute 'same' padding
278 | (batch, channel, t, h, w) = x.size()
279 | pad_t = self.compute_pad(0, t)
280 | pad_h = self.compute_pad(1, h)
281 | pad_w = self.compute_pad(2, w)
282 |
283 | pad_t_f = pad_t // 2
284 | pad_t_b = pad_t - pad_t_f
285 | pad_h_f = pad_h // 2
286 | pad_h_b = pad_h - pad_h_f
287 | pad_w_f = pad_w // 2
288 | pad_w_b = pad_w - pad_w_f
289 |
290 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
291 | x = F.pad(x, pad)
292 |
293 | x = self.conv3d(x)
294 | if self._use_batch_norm:
295 | x = self.bn(x)
296 | if self._activation_fn is not None:
297 | x = self._activation_fn(x)
298 | return x
299 |
300 |
301 | class InceptionModule(nn.Module):
302 | def __init__(self, in_channels, out_channels, name):
303 | super(InceptionModule, self).__init__()
304 |
305 | self.b0 = Unit3D(in_channels=in_channels,
306 | output_channels=out_channels[0],
307 | kernel_shape=[1, 1, 1],
308 | padding=0,
309 | name=name + '/Branch_0/Conv3d_0a_1x1')
310 | self.b1a = Unit3D(in_channels=in_channels,
311 | output_channels=out_channels[1],
312 | kernel_shape=[1, 1, 1],
313 | padding=0,
314 | name=name + '/Branch_1/Conv3d_0a_1x1')
315 | self.b1b = Unit3D(in_channels=out_channels[1],
316 | output_channels=out_channels[2],
317 | kernel_shape=[3, 3, 3],
318 | name=name + '/Branch_1/Conv3d_0b_3x3')
319 | self.b2a = Unit3D(in_channels=in_channels,
320 | output_channels=out_channels[3],
321 | kernel_shape=[1, 1, 1],
322 | padding=0,
323 | name=name + '/Branch_2/Conv3d_0a_1x1')
324 | self.b2b = Unit3D(in_channels=out_channels[3],
325 | output_channels=out_channels[4],
326 | kernel_shape=[3, 3, 3],
327 | name=name + '/Branch_2/Conv3d_0b_3x3')
328 | self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
329 | stride=(1, 1, 1),
330 | padding=0)
331 | self.b3b = Unit3D(in_channels=in_channels,
332 | output_channels=out_channels[5],
333 | kernel_shape=[1, 1, 1],
334 | padding=0,
335 | name=name + '/Branch_3/Conv3d_0b_1x1')
336 | self.name = name
337 |
338 | def forward(self, x):
339 | b0 = self.b0(x)
340 | b1 = self.b1b(self.b1a(x))
341 | b2 = self.b2b(self.b2a(x))
342 | b3 = self.b3b(self.b3a(x))
343 | return torch.cat([b0, b1, b2, b3], dim=1)
344 |
345 |
346 | class InceptionI3d(nn.Module):
347 | """Inception-v1 I3D architecture.
348 | The model is introduced in:
349 | Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
350 | Joao Carreira, Andrew Zisserman
351 | https://arxiv.org/pdf/1705.07750v1.pdf.
352 | See also the Inception architecture, introduced in:
353 | Going deeper with convolutions
354 | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
355 | Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
356 | http://arxiv.org/pdf/1409.4842v1.pdf.
357 | """
358 |
359 | # Endpoints of the model in order. During construction, all the endpoints up
360 | # to a designated `final_endpoint` are returned in a dictionary as the
361 | # second return value.
362 | VALID_ENDPOINTS = (
363 | 'Conv3d_1a_7x7',
364 | 'MaxPool3d_2a_3x3',
365 | 'Conv3d_2b_1x1',
366 | 'Conv3d_2c_3x3',
367 | 'MaxPool3d_3a_3x3',
368 | 'Mixed_3b',
369 | 'Mixed_3c',
370 | 'MaxPool3d_4a_3x3',
371 | 'Mixed_4b',
372 | 'Mixed_4c',
373 | 'Mixed_4d',
374 | 'Mixed_4e',
375 | 'Mixed_4f',
376 | 'MaxPool3d_5a_2x2',
377 | 'Mixed_5b',
378 | 'Mixed_5c',
379 | 'Logits',
380 | 'Predictions',
381 | )
382 |
383 | def __init__(self,
384 | num_classes=400,
385 | spatial_squeeze=True,
386 | final_endpoint='Logits',
387 | name='inception_i3d',
388 | in_channels=3,
389 | dropout_keep_prob=0.5):
390 | """Initializes I3D model instance.
391 | Args:
392 | num_classes: The number of outputs in the logit layer (default 400, which
393 | matches the Kinetics dataset).
394 | spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
395 | before returning (default True).
396 | final_endpoint: The model contains many possible endpoints.
397 | `final_endpoint` specifies the last endpoint for the model to be built
398 | up to. In addition to the output at `final_endpoint`, all the outputs
399 | at endpoints up to `final_endpoint` will also be returned, in a
400 | dictionary. `final_endpoint` must be one of
401 | InceptionI3d.VALID_ENDPOINTS (default 'Logits').
402 | name: A string (optional). The name of this module.
403 | Raises:
404 | ValueError: if `final_endpoint` is not recognized.
405 | """
406 |
407 | if final_endpoint not in self.VALID_ENDPOINTS:
408 | raise ValueError('Unknown final endpoint %s' % final_endpoint)
409 |
410 | super(InceptionI3d, self).__init__()
411 | self._num_classes = num_classes
412 | self._spatial_squeeze = spatial_squeeze
413 | self._final_endpoint = final_endpoint
414 | self.logits = None
415 |
416 | if self._final_endpoint not in self.VALID_ENDPOINTS:
417 | raise ValueError('Unknown final endpoint %s' %
418 | self._final_endpoint)
419 |
420 | self.end_points = {}
421 | end_point = 'Conv3d_1a_7x7'
422 | self.end_points[end_point] = Unit3D(in_channels=in_channels,
423 | output_channels=64,
424 | kernel_shape=[7, 7, 7],
425 | stride=(2, 2, 2),
426 | padding=(3, 3, 3),
427 | name=name + end_point)
428 | if self._final_endpoint == end_point:
429 | return
430 |
431 | end_point = 'MaxPool3d_2a_3x3'
432 | self.end_points[end_point] = MaxPool3dSamePadding(
433 | kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0)
434 | if self._final_endpoint == end_point:
435 | return
436 |
437 | end_point = 'Conv3d_2b_1x1'
438 | self.end_points[end_point] = Unit3D(in_channels=64,
439 | output_channels=64,
440 | kernel_shape=[1, 1, 1],
441 | padding=0,
442 | name=name + end_point)
443 | if self._final_endpoint == end_point:
444 | return
445 |
446 | end_point = 'Conv3d_2c_3x3'
447 | self.end_points[end_point] = Unit3D(in_channels=64,
448 | output_channels=192,
449 | kernel_shape=[3, 3, 3],
450 | padding=1,
451 | name=name + end_point)
452 | if self._final_endpoint == end_point:
453 | return
454 |
455 | end_point = 'MaxPool3d_3a_3x3'
456 | self.end_points[end_point] = MaxPool3dSamePadding(
457 | kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0)
458 | if self._final_endpoint == end_point:
459 | return
460 |
461 | end_point = 'Mixed_3b'
462 | self.end_points[end_point] = InceptionModule(192,
463 | [64, 96, 128, 16, 32, 32],
464 | name + end_point)
465 | if self._final_endpoint == end_point:
466 | return
467 |
468 | end_point = 'Mixed_3c'
469 | self.end_points[end_point] = InceptionModule(
470 | 256, [128, 128, 192, 32, 96, 64], name + end_point)
471 | if self._final_endpoint == end_point:
472 | return
473 |
474 | end_point = 'MaxPool3d_4a_3x3'
475 | self.end_points[end_point] = MaxPool3dSamePadding(
476 | kernel_size=[3, 3, 3], stride=(2, 2, 2), padding=0)
477 | if self._final_endpoint == end_point:
478 | return
479 |
480 | end_point = 'Mixed_4b'
481 | self.end_points[end_point] = InceptionModule(
482 | 128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point)
483 | if self._final_endpoint == end_point:
484 | return
485 |
486 | end_point = 'Mixed_4c'
487 | self.end_points[end_point] = InceptionModule(
488 | 192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point)
489 | if self._final_endpoint == end_point:
490 | return
491 |
492 | end_point = 'Mixed_4d'
493 | self.end_points[end_point] = InceptionModule(
494 | 160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point)
495 | if self._final_endpoint == end_point:
496 | return
497 |
498 | end_point = 'Mixed_4e'
499 | self.end_points[end_point] = InceptionModule(
500 | 128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point)
501 | if self._final_endpoint == end_point:
502 | return
503 |
504 | end_point = 'Mixed_4f'
505 | self.end_points[end_point] = InceptionModule(
506 | 112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128],
507 | name + end_point)
508 | if self._final_endpoint == end_point:
509 | return
510 |
511 | end_point = 'MaxPool3d_5a_2x2'
512 | self.end_points[end_point] = MaxPool3dSamePadding(
513 | kernel_size=[2, 2, 2], stride=(2, 2, 2), padding=0)
514 | if self._final_endpoint == end_point:
515 | return
516 |
517 | end_point = 'Mixed_5b'
518 | self.end_points[end_point] = InceptionModule(
519 | 256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128],
520 | name + end_point)
521 | if self._final_endpoint == end_point:
522 | return
523 |
524 | end_point = 'Mixed_5c'
525 | self.end_points[end_point] = InceptionModule(
526 | 256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128],
527 | name + end_point)
528 | if self._final_endpoint == end_point:
529 | return
530 |
531 | end_point = 'Logits'
532 | self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], stride=(1, 1, 1))
533 | self.dropout = nn.Dropout(dropout_keep_prob)
534 | self.logits = Unit3D(in_channels=384 + 384 + 128 + 128,
535 | output_channels=self._num_classes,
536 | kernel_shape=[1, 1, 1],
537 | padding=0,
538 | activation_fn=None,
539 | use_batch_norm=False,
540 | use_bias=True,
541 | name='logits')
542 |
543 | self.build()
544 |
545 | def replace_logits(self, num_classes):
546 | self._num_classes = num_classes
547 | self.logits = Unit3D(in_channels=384 + 384 + 128 + 128,
548 | output_channels=self._num_classes,
549 | kernel_shape=[1, 1, 1],
550 | padding=0,
551 | activation_fn=None,
552 | use_batch_norm=False,
553 | use_bias=True,
554 | name='logits')
555 |
556 | def build(self):
557 | for k in self.end_points.keys():
558 | self.add_module(k, self.end_points[k])
559 |
560 | def forward(self, x):
561 | for end_point in self.VALID_ENDPOINTS:
562 | if end_point in self.end_points:
563 | x = self._modules[end_point](
564 | x) # use _modules to work with dataparallel
565 |
566 |         x = self.logits(self.dropout(self.avg_pool(x)))
567 |         if self._spatial_squeeze:
568 |             x = x.squeeze(3).squeeze(3)
569 |         # x is batch X time X classes, which is what we want to work with
570 |         return x
571 |
572 | def extract_features(self, x, target_endpoint='Logits'):
573 | for end_point in self.VALID_ENDPOINTS:
574 | if end_point in self.end_points:
575 | x = self._modules[end_point](x)
576 | if end_point == target_endpoint:
577 | break
578 | if target_endpoint == 'Logits':
579 | return x.mean(4).mean(3).mean(2)
580 | else:
581 | return x
582 |
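For VFID-style evaluation (see `calculate_i3d_activations` and `calculate_vfid` imported in `evaluate.py`), only the pooled `extract_features` output is needed, not the classification head. A minimal sketch of driving the class directly, assuming pretrained weights are loaded elsewhere; the random 16×224×224 clip is an illustrative choice, and any input compatible with the `[2, 7, 7]` average pool works:

```python
import torch

i3d = InceptionI3d(num_classes=400, in_channels=3)  # weights normally come from an I3D checkpoint
i3d.eval()

clip = torch.randn(1, 3, 16, 224, 224)  # (batch, channels, time, height, width)
with torch.no_grad():
    feats = i3d.extract_features(clip)  # runs up to Mixed_5c, then pools T/H/W away
print(feats.shape)  # torch.Size([1, 1024])
```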
--------------------------------------------------------------------------------
/core/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | import cv2
4 | import random
5 | import numpy as np
6 | from PIL import Image, ImageOps
7 | import zipfile
8 |
9 | import torch
10 | import matplotlib
11 | import matplotlib.patches as patches
12 | from matplotlib.path import Path
13 | from matplotlib import pyplot as plt
14 | from torchvision import transforms
15 |
16 | # Work around the OMP Error #15 crash triggered by random mask generation
17 | os.environ['KMP_DUPLICATE_LIB_OK']='True'
18 |
19 | # ###########################################################################
20 | # Directory IO
21 | # ###########################################################################
22 |
23 |
24 | def read_dirnames_under_root(root_dir):
25 | dirnames = [
26 | name for i, name in enumerate(sorted(os.listdir(root_dir)))
27 | if os.path.isdir(os.path.join(root_dir, name))
28 | ]
29 | print(f'Reading directories under {root_dir}, num: {len(dirnames)}')
30 | return dirnames
31 |
32 |
33 | class TrainZipReader(object):
34 | file_dict = dict()
35 |
36 | def __init__(self):
37 | super(TrainZipReader, self).__init__()
38 |
39 | @staticmethod
40 | def build_file_dict(path):
41 | file_dict = TrainZipReader.file_dict
42 | if path in file_dict:
43 | return file_dict[path]
44 | else:
45 | file_handle = zipfile.ZipFile(path, 'r')
46 | file_dict[path] = file_handle
47 | return file_dict[path]
48 |
49 | @staticmethod
50 | def imread(path, idx):
51 | zfile = TrainZipReader.build_file_dict(path)
52 | filelist = zfile.namelist()
53 | filelist.sort()
54 | data = zfile.read(filelist[idx])
55 | #
56 | im = Image.open(io.BytesIO(data))
57 | return im
58 |
59 |
60 | class TestZipReader(object):
61 | file_dict = dict()
62 |
63 | def __init__(self):
64 | super(TestZipReader, self).__init__()
65 |
66 | @staticmethod
67 | def build_file_dict(path):
68 | file_dict = TestZipReader.file_dict
69 | if path in file_dict:
70 | return file_dict[path]
71 | else:
72 | file_handle = zipfile.ZipFile(path, 'r')
73 | file_dict[path] = file_handle
74 | return file_dict[path]
75 |
76 | @staticmethod
77 | def imread(path, idx):
78 | zfile = TestZipReader.build_file_dict(path)
79 | filelist = zfile.namelist()
80 | filelist.sort()
81 | data = zfile.read(filelist[idx])
82 | file_bytes = np.asarray(bytearray(data), dtype=np.uint8)
83 | im = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
84 | im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
85 | # im = Image.open(io.BytesIO(data))
86 | return im
87 |
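Both readers keep one open `zipfile.ZipFile` handle per archive path in a class-level dict, so repeated `imread` calls on the same video avoid reopening the archive; `TestZipReader` additionally decodes through OpenCV and converts to RGB. A minimal sketch of the access pattern, with a hypothetical archive path (the assumed layout is one zip of name-sorted frame images per video):

```python
# Hypothetical archive path; frames inside the zip are sorted by filename.
first_frame = TrainZipReader.imread('datasets/some_video.zip', 0)  # PIL.Image
tenth_frame = TestZipReader.imread('datasets/some_video.zip', 9)   # RGB PIL.Image decoded via OpenCV
```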
88 |
89 | # ###########################################################################
90 | # Data augmentation
91 | # ###########################################################################
92 |
93 |
94 | def to_tensors():
95 | return transforms.Compose([Stack(), ToTorchFormatTensor()])
96 |
97 |
98 | class GroupRandomHorizontalFlowFlip(object):
99 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5
100 | """
101 | def __init__(self, is_flow=True):
102 | self.is_flow = is_flow
103 |
104 | def __call__(self, img_group, mask_group, flowF_group, flowB_group):
105 | v = random.random()
106 | if v < 0.5:
107 | ret_img = [
108 | img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group
109 | ]
110 | ret_mask = [
111 | mask.transpose(Image.FLIP_LEFT_RIGHT) for mask in mask_group
112 | ]
113 | ret_flowF = [ff[:, ::-1] * [-1.0, 1.0] for ff in flowF_group]
114 | ret_flowB = [fb[:, ::-1] * [-1.0, 1.0] for fb in flowB_group]
115 | return ret_img, ret_mask, ret_flowF, ret_flowB
116 | else:
117 | return img_group, mask_group, flowF_group, flowB_group
118 |
119 |
120 | class GroupRandomHorizontalFlip(object):
121 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5
122 | """
123 | def __init__(self, is_flow=False):
124 | self.is_flow = is_flow
125 |
126 | def __call__(self, img_group, is_flow=False):
127 | v = random.random()
128 | if v < 0.5:
129 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
130 | if self.is_flow:
131 | for i in range(0, len(ret), 2):
132 | # invert flow pixel values when flipping
133 | ret[i] = ImageOps.invert(ret[i])
134 | return ret
135 | else:
136 | return img_group
137 |
138 |
139 | class Stack(object):
140 | def __init__(self, roll=False):
141 | self.roll = roll
142 |
143 | def __call__(self, img_group):
144 | mode = img_group[0].mode
145 | if mode == '1':
146 | img_group = [img.convert('L') for img in img_group]
147 | mode = 'L'
148 | if mode == 'L':
149 | return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
150 | elif mode == 'RGB':
151 | if self.roll:
152 | return np.stack([np.array(x)[:, :, ::-1] for x in img_group],
153 | axis=2)
154 | else:
155 | return np.stack(img_group, axis=2)
156 | else:
157 | raise NotImplementedError(f"Image mode {mode}")
158 |
159 |
160 | class ToTorchFormatTensor(object):
161 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
162 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
163 | def __init__(self, div=True):
164 | self.div = div
165 |
166 | def __call__(self, pic):
167 | if isinstance(pic, np.ndarray):
168 |             # numpy img of shape (H, W, L, C) -> tensor of shape (L, C, H, W)
169 | img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
170 | else:
171 | # handle PIL Image
172 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(
173 | pic.tobytes()))
174 | img = img.view(pic.size[1], pic.size[0], len(pic.mode))
175 | # put it from HWC to CHW format
176 | # yikes, this transpose takes 80% of the loading time/CPU
177 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
178 | img = img.float().div(255) if self.div else img.float()
179 | return img
180 |
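`Stack` arranges a list of PIL frames into an `(H, W, T, C)` array and `ToTorchFormatTensor` permutes it to `(T, C, H, W)` scaled to `[0, 1]`, so the composed `to_tensors()` turns a frame list into a video tensor. A small sketch with arbitrary frame size and count:

```python
from PIL import Image

frames = [Image.new('RGB', (432, 240)) for _ in range(5)]  # 5 blank RGB frames
video = to_tensors()(frames)
print(video.shape)  # torch.Size([5, 3, 240, 432]) -> (T, C, H, W)
```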
181 |
182 | # ###########################################################################
183 | # Create masks with random shape
184 | # ###########################################################################
185 |
186 |
187 | def create_random_shape_with_random_motion(video_length,
188 | imageHeight=240,
189 | imageWidth=432):
190 | # get a random shape
191 | height = random.randint(imageHeight // 3, imageHeight - 1)
192 | width = random.randint(imageWidth // 3, imageWidth - 1)
193 | edge_num = random.randint(6, 8)
194 | ratio = random.randint(6, 8) / 10
195 | region = get_random_shape(edge_num=edge_num,
196 | ratio=ratio,
197 | height=height,
198 | width=width)
199 | region_width, region_height = region.size
200 | # get random position
201 | x, y = random.randint(0, imageHeight - region_height), random.randint(
202 | 0, imageWidth - region_width)
203 | velocity = get_random_velocity(max_speed=3)
204 | m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
205 | m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
206 | masks = [m.convert('L')]
207 | # return fixed masks
208 | if random.uniform(0, 1) > 0.5:
209 | return masks * video_length
210 | # return moving masks
211 | for _ in range(video_length - 1):
212 | x, y, velocity = random_move_control_points(x,
213 | y,
214 | imageHeight,
215 | imageWidth,
216 | velocity,
217 | region.size,
218 | maxLineAcceleration=(3,
219 | 0.5),
220 | maxInitSpeed=3)
221 | m = Image.fromarray(
222 | np.zeros((imageHeight, imageWidth)).astype(np.uint8))
223 | m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
224 | masks.append(m.convert('L'))
225 | return masks
226 |
227 |
228 | def create_random_shape_with_random_motion_seq(video_length,
229 | imageHeight=240,
230 | imageWidth=432,
231 | new_mask=True,
232 | random_dict=None
233 | ):
234 | """
235 | Sequence Mask Generator by Hao.
236 |     new_mask (bool): If True, generate a new batch of masks;
237 |         if False, reuse the previous params to generate identical masks.
238 | """
239 |
240 | if new_mask:
241 |         # Draw the full set of random parameters for a new mask
242 | # get a random shape
243 | height = random.randint(imageHeight // 3, imageHeight - 1)
244 | width = random.randint(imageWidth // 3, imageWidth - 1)
245 | edge_num = random.randint(6, 8)
246 | ratio = random.randint(6, 8) / 10
247 | region, random_point = get_random_shape_seq(edge_num=edge_num,
248 | ratio=ratio,
249 | height=height,
250 | width=width)
251 | region_width, region_height = region.size
252 |
253 | # get random position
254 | x, y = random.randint(0, imageHeight - region_height), random.randint(
255 | 0, imageWidth - region_width)
256 | velocity = get_random_velocity(max_speed=3)
257 |
258 | # get random probability
259 | prob = random.uniform(0, 1)
260 |
261 |         # Store the newly drawn random parameters in a dict
262 |         random_dict = {}
263 |         # random_dict['region'] = region  # PyTorch dataloaders cannot return PIL.Image data
264 | random_dict['edge_num'] = edge_num
265 | random_dict['ratio'] = ratio
266 | random_dict['height'] = height
267 | random_dict['width'] = width
268 | random_dict['random_point'] = random_point
269 | random_dict['x'] = x
270 | random_dict['y'] = y
271 | random_dict['velocity'] = velocity
272 | random_dict['prob'] = prob
273 |
274 | else:
275 |         # Reuse the old parameters instead of drawing new ones
276 | # region = random_dict['region']
277 | edge_num = random_dict['edge_num']
278 | ratio = random_dict['ratio']
279 | height = random_dict['height']
280 | width = random_dict['width']
281 | random_point = random_dict['random_point']
282 |         # Once random_point is fixed, the region shape is fully determined
283 | region, random_point = get_random_shape_seq(edge_num=edge_num,
284 | ratio=ratio,
285 | height=height,
286 | width=width,
287 | random_point=random_point)
288 |
289 | x = random_dict['x']
290 | y = random_dict['y']
291 | velocity = random_dict['velocity']
292 | prob = random_dict['prob']
293 |
294 |     # Create the static mask
295 | m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
296 | m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
297 | masks = [m.convert('L')]
298 |
299 | # return fixed masks
300 | if prob > 0.5:
301 | return masks * video_length, random_dict
302 |
303 | # return moving masks
304 |     # The motion is random too, so positions and velocities are returned to keep replays consistent
305 | if new_mask:
306 | x_list = []
307 | y_list = []
308 | velocity_list = []
309 | else:
310 |         # When reusing previous motion parameters, read them back
311 | x_list = random_dict['x_list']
312 | y_list = random_dict['y_list']
313 | velocity_list = random_dict['velocity_list']
314 |
315 | for idx in range(video_length - 1):
316 | if new_mask:
317 |             # Draw fresh randomness for every frame
318 | x, y, velocity = random_move_control_points(x,
319 | y,
320 | imageHeight,
321 | imageWidth,
322 | velocity,
323 | region.size,
324 | maxLineAcceleration=(3,
325 | 0.5),
326 | maxInitSpeed=3)
327 |             # Record position and velocity
328 | x_list.append(x)
329 | y_list.append(y)
330 | velocity_list.append(velocity)
331 | else:
332 |             # Read directly from the lists stored in the dict
333 | x = x_list[idx]
334 | y = y_list[idx]
335 | velocity = velocity_list[idx]
336 |
337 |         # Generate the moving mask for the current frame
338 | m = Image.fromarray(
339 | np.zeros((imageHeight, imageWidth)).astype(np.uint8))
340 | m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
341 | masks.append(m.convert('L'))
342 |
343 |     # Save the newly generated mask positions and velocities into the dict
344 | if new_mask:
345 | random_dict['x_list'] = x_list
346 | random_dict['y_list'] = y_list
347 | random_dict['velocity_list'] = velocity_list
348 |
349 | return masks, random_dict
350 |
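The purpose of `random_dict` is replayability: a first call with `new_mask=True` draws and returns every random quantity, and a second call fed the same dict reproduces an identical mask sequence (useful when the same corruption must be applied to paired sequences). A minimal sketch using the function above:

```python
import numpy as np

# First call: draw new random parameters and keep them.
masks_a, params = create_random_shape_with_random_motion_seq(
    video_length=5, imageHeight=240, imageWidth=432, new_mask=True)

# Second call: replay the stored parameters -> identical masks.
masks_b, _ = create_random_shape_with_random_motion_seq(
    video_length=5, imageHeight=240, imageWidth=432,
    new_mask=False, random_dict=params)

assert all((np.array(a) == np.array(b)).all() for a, b in zip(masks_a, masks_b))
```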
351 |
352 | def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
353 | '''
354 | There is the initial point and 3 points per cubic bezier curve.
355 |     Thus, the curve will only pass through n points, which will be the sharp edges.
356 |     The other 2 modify the shape of the bezier curve.
357 |     edge_num, Number of possibly sharp edges
358 |     points_num, number of points in the Path
359 |     ratio, (0, 1) magnitude of the perturbation from the unit circle.
360 | '''
361 | points_num = edge_num * 3 + 1
362 | angles = np.linspace(0, 2 * np.pi, points_num)
363 | codes = np.full(points_num, Path.CURVE4)
364 | codes[0] = Path.MOVETO
365 |     # Using this instead of Path.CLOSEPOLY avoids an unnecessary straight line
366 | verts = np.stack((np.cos(angles), np.sin(angles))).T * \
367 | (2*ratio*np.random.random(points_num)+1-ratio)[:, None]
368 | verts[-1, :] = verts[0, :]
369 | path = Path(verts, codes)
370 | # draw paths into images
371 | fig = plt.figure()
372 | ax = fig.add_subplot(111)
373 | patch = patches.PathPatch(path, facecolor='black', lw=2)
374 | ax.add_patch(patch)
375 | ax.set_xlim(np.min(verts) * 1.1, np.max(verts) * 1.1)
376 | ax.set_ylim(np.min(verts) * 1.1, np.max(verts) * 1.1)
377 | ax.axis('off') # removes the axis to leave only the shape
378 | fig.canvas.draw()
379 | # convert plt images into numpy images
380 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
381 | data = data.reshape((fig.canvas.get_width_height()[::-1] + (3, )))
382 | plt.close(fig)
383 | # postprocess
384 | data = cv2.resize(data, (width, height))[:, :, 0]
385 | data = (1 - np.array(data > 0).astype(np.uint8)) * 255
386 |     coordinates = np.where(data > 0)
387 |     xmin, xmax, ymin, ymax = np.min(coordinates[0]), np.max(
388 |         coordinates[0]), np.min(coordinates[1]), np.max(coordinates[1])
389 | region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
390 | return region
391 |
392 |
393 | def get_random_shape_seq(edge_num=9, ratio=0.7, width=432, height=240, random_point=None):
394 | '''
395 | There is the initial point and 3 points per cubic bezier curve.
396 |     Thus, the curve will only pass through n points, which will be the sharp edges.
397 |     The other 2 modify the shape of the bezier curve.
398 |     edge_num, Number of possibly sharp edges
399 |     points_num, number of points in the Path
400 |     ratio, (0, 1) magnitude of the perturbation from the unit circle.
401 | Revised by Hao:
402 | random_point: if given, the shape is known and fixed.
403 | '''
404 | points_num = edge_num * 3 + 1
405 |
406 | if random_point is None:
407 | random_point = np.random.random(points_num)
408 |     # else: the caller-provided random_point is reused, so the generated
409 |     # shape is deterministic across calls
410 |
411 | angles = np.linspace(0, 2 * np.pi, points_num)
412 | codes = np.full(points_num, Path.CURVE4)
413 | codes[0] = Path.MOVETO
414 |     # Using this instead of Path.CLOSEPOLY avoids an unnecessary straight line
415 | verts = np.stack((np.cos(angles), np.sin(angles))).T * \
416 | (2*ratio*random_point+1-ratio)[:, None]
417 | verts[-1, :] = verts[0, :]
418 | path = Path(verts, codes)
419 | # draw paths into images
420 | fig = plt.figure()
421 | ax = fig.add_subplot(111)
422 | patch = patches.PathPatch(path, facecolor='black', lw=2)
423 | ax.add_patch(patch)
424 | ax.set_xlim(np.min(verts) * 1.1, np.max(verts) * 1.1)
425 | ax.set_ylim(np.min(verts) * 1.1, np.max(verts) * 1.1)
426 | ax.axis('off') # removes the axis to leave only the shape
427 | fig.canvas.draw()
428 | # convert plt images into numpy images
429 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
430 | data = data.reshape((fig.canvas.get_width_height()[::-1] + (3, )))
431 | plt.close(fig)
432 | # postprocess
433 | data = cv2.resize(data, (width, height))[:, :, 0]
434 | data = (1 - np.array(data > 0).astype(np.uint8)) * 255
435 |     coordinates = np.where(data > 0)
436 |     xmin, xmax, ymin, ymax = np.min(coordinates[0]), np.max(
437 |         coordinates[0]), np.min(coordinates[1]), np.max(coordinates[1])
438 | region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
439 | return region, random_point
440 |
441 |
442 | def random_accelerate(velocity, maxAcceleration, dist='uniform'):
443 | speed, angle = velocity
444 | d_speed, d_angle = maxAcceleration
445 | if dist == 'uniform':
446 | speed += np.random.uniform(-d_speed, d_speed)
447 | angle += np.random.uniform(-d_angle, d_angle)
448 |     elif dist == 'gaussian':
449 | speed += np.random.normal(0, d_speed / 2)
450 | angle += np.random.normal(0, d_angle / 2)
451 | else:
452 | raise NotImplementedError(
453 | f'Distribution type {dist} is not supported.')
454 | return (speed, angle)
455 |
456 |
457 | def get_random_velocity(max_speed=3, dist='uniform'):
458 | if dist == 'uniform':
459 |         speed = np.random.uniform(0, max_speed)
460 |     elif dist == 'gaussian':
461 | speed = np.abs(np.random.normal(0, max_speed / 2))
462 | else:
463 | raise NotImplementedError(
464 | f'Distribution type {dist} is not supported.')
465 | angle = np.random.uniform(0, 2 * np.pi)
466 | return (speed, angle)
467 |
468 |
469 | def random_move_control_points(X,
470 | Y,
471 | imageHeight,
472 | imageWidth,
473 | lineVelocity,
474 | region_size,
475 | maxLineAcceleration=(3, 0.5),
476 | maxInitSpeed=3):
477 | region_width, region_height = region_size
478 | speed, angle = lineVelocity
479 | X += int(speed * np.cos(angle))
480 | Y += int(speed * np.sin(angle))
481 | lineVelocity = random_accelerate(lineVelocity,
482 | maxLineAcceleration,
483 |                                      dist='gaussian')
484 | if ((X > imageHeight - region_height) or (X < 0)
485 | or (Y > imageWidth - region_width) or (Y < 0)):
486 |         lineVelocity = get_random_velocity(maxInitSpeed, dist='gaussian')
487 | new_X = np.clip(X, 0, imageHeight - region_height)
488 | new_Y = np.clip(Y, 0, imageWidth - region_width)
489 | return new_X, new_Y, lineVelocity
490 |
491 |
492 | if __name__ == '__main__':
493 |
494 | trials = 10
495 | for _ in range(trials):
496 | video_length = 10
497 | # The returned masks are either stationary (50%) or moving (50%)
498 | masks = create_random_shape_with_random_motion(video_length,
499 | imageHeight=240,
500 | imageWidth=432)
501 |
502 | for m in masks:
503 | cv2.imshow('mask', np.array(m))
504 | cv2.waitKey(500)
505 |
--------------------------------------------------------------------------------
/datasets/KITTI-360EX/InnerSphere/test.json:
--------------------------------------------------------------------------------
1 | {"00000719": 100, "00000720": 100, "00000721": 100, "00000722": 100, "00000723": 100, "00000724": 100, "00000725": 100, "00000726": 100, "00000727": 100, "00000728": 100, "00000729": 100, "00000730": 100, "00000731": 100, "00000732": 100, "00000733": 100, "00000734": 100, "00000735": 100, "00000736": 100, "00000737": 100, "00000738": 100, "00000739": 100, "00000740": 100, "00000741": 100, "00000742": 100, "00000743": 100, "00000744": 100, "00000745": 100, "00000746": 100, "00000747": 100, "00000748": 100, "00000749": 100, "00000750": 100, "00000751": 100, "00000752": 100, "00000753": 100, "00000754": 100, "00000755": 100, "00000756": 100}
--------------------------------------------------------------------------------
/datasets/KITTI-360EX/InnerSphere/train.json:
--------------------------------------------------------------------------------
1 | {"00000000": 100, "00000001": 100, "00000002": 100, "00000003": 100, "00000004": 100, "00000005": 100, "00000006": 100, "00000007": 100, "00000008": 100, "00000009": 100, "00000010": 100, "00000011": 100, "00000012": 100, "00000013": 100, "00000014": 100, "00000015": 100, "00000016": 100, "00000017": 100, "00000018": 100, "00000019": 100, "00000020": 100, "00000021": 100, "00000022": 100, "00000023": 100, "00000024": 100, "00000025": 100, "00000026": 100, "00000027": 100, "00000028": 100, "00000029": 100, "00000030": 100, "00000031": 100, "00000032": 100, "00000033": 100, "00000034": 100, "00000035": 100, "00000036": 100, "00000037": 100, "00000038": 100, "00000039": 100, "00000040": 100, "00000041": 100, "00000042": 100, "00000043": 100, "00000044": 100, "00000045": 100, "00000046": 100, "00000047": 100, "00000048": 100, "00000049": 100, "00000050": 100, "00000051": 100, "00000052": 100, "00000053": 100, "00000054": 100, "00000055": 100, "00000056": 100, "00000057": 100, "00000058": 100, "00000059": 100, "00000060": 100, "00000061": 100, "00000062": 100, "00000063": 100, "00000064": 100, "00000065": 100, "00000066": 100, "00000067": 100, "00000068": 100, "00000069": 100, "00000070": 100, "00000071": 100, "00000072": 100, "00000073": 100, "00000074": 100, "00000075": 100, "00000076": 100, "00000077": 100, "00000078": 100, "00000079": 100, "00000080": 100, "00000081": 100, "00000082": 100, "00000083": 100, "00000084": 100, "00000085": 100, "00000086": 100, "00000087": 100, "00000088": 100, "00000089": 100, "00000090": 100, "00000091": 100, "00000092": 100, "00000093": 100, "00000094": 100, "00000095": 100, "00000096": 100, "00000097": 100, "00000098": 100, "00000099": 100, "00000100": 100, "00000101": 100, "00000102": 100, "00000103": 100, "00000104": 100, "00000105": 100, "00000106": 100, "00000107": 100, "00000108": 100, "00000109": 100, "00000110": 100, "00000111": 100, "00000112": 100, "00000113": 100, "00000114": 100, "00000115": 100, "00000116": 100, "00000117": 100, "00000118": 100, "00000119": 100, "00000120": 100, "00000121": 100, "00000122": 100, "00000123": 100, "00000124": 100, "00000125": 100, "00000126": 100, "00000127": 100, "00000128": 100, "00000129": 100, "00000130": 100, "00000131": 100, "00000132": 100, "00000133": 100, "00000134": 100, "00000135": 100, "00000136": 100, "00000137": 100, "00000138": 100, "00000139": 100, "00000140": 100, "00000141": 100, "00000142": 100, "00000143": 100, "00000144": 100, "00000145": 100, "00000146": 100, "00000147": 100, "00000148": 100, "00000149": 100, "00000150": 100, "00000151": 100, "00000152": 100, "00000153": 100, "00000154": 100, "00000155": 100, "00000156": 100, "00000157": 100, "00000158": 100, "00000159": 100, "00000160": 100, "00000161": 100, "00000162": 100, "00000163": 100, "00000164": 100, "00000165": 100, "00000166": 100, "00000167": 100, "00000168": 100, "00000169": 100, "00000170": 100, "00000171": 100, "00000172": 100, "00000173": 100, "00000174": 100, "00000175": 100, "00000176": 100, "00000177": 100, "00000178": 100, "00000179": 100, "00000180": 100, "00000181": 100, "00000182": 100, "00000183": 100, "00000184": 100, "00000185": 100, "00000186": 100, "00000187": 100, "00000188": 100, "00000189": 100, "00000190": 100, "00000191": 100, "00000192": 100, "00000193": 100, "00000194": 100, "00000195": 100, "00000196": 100, "00000197": 100, "00000198": 100, "00000199": 100, "00000200": 100, "00000201": 100, "00000202": 100, "00000203": 100, "00000204": 100, "00000205": 100, "00000206": 100, "00000207": 100, "00000208": 
100, "00000209": 100, "00000210": 100, "00000211": 100, "00000212": 100, "00000213": 100, "00000214": 100, "00000215": 100, "00000216": 100, "00000217": 100, "00000218": 100, "00000219": 100, "00000220": 100, "00000221": 100, "00000222": 100, "00000223": 100, "00000224": 100, "00000225": 100, "00000226": 100, "00000227": 100, "00000228": 100, "00000229": 100, "00000230": 100, "00000231": 100, "00000232": 100, "00000233": 100, "00000234": 100, "00000235": 100, "00000236": 100, "00000237": 100, "00000238": 100, "00000239": 100, "00000240": 100, "00000241": 100, "00000242": 100, "00000243": 100, "00000244": 100, "00000245": 100, "00000246": 100, "00000247": 100, "00000248": 100, "00000249": 100, "00000250": 100, "00000251": 100, "00000252": 100, "00000253": 100, "00000254": 100, "00000255": 100, "00000256": 100, "00000257": 100, "00000258": 100, "00000259": 100, "00000260": 100, "00000261": 100, "00000262": 100, "00000263": 100, "00000264": 100, "00000265": 100, "00000266": 100, "00000267": 100, "00000268": 100, "00000269": 100, "00000270": 100, "00000271": 100, "00000272": 100, "00000273": 100, "00000274": 100, "00000275": 100, "00000276": 100, "00000277": 100, "00000278": 100, "00000279": 100, "00000280": 100, "00000281": 100, "00000282": 100, "00000283": 100, "00000284": 100, "00000285": 100, "00000286": 100, "00000287": 100, "00000288": 100, "00000289": 100, "00000290": 100, "00000291": 100, "00000292": 100, "00000293": 100, "00000294": 100, "00000295": 100, "00000296": 100, "00000297": 100, "00000298": 100, "00000299": 100, "00000300": 100, "00000301": 100, "00000302": 100, "00000303": 100, "00000304": 100, "00000305": 100, "00000306": 100, "00000307": 100, "00000308": 100, "00000309": 100, "00000310": 100, "00000311": 100, "00000312": 100, "00000313": 100, "00000314": 100, "00000315": 100, "00000316": 100, "00000317": 100, "00000318": 100, "00000319": 100, "00000320": 100, "00000321": 100, "00000322": 100, "00000323": 100, "00000324": 100, "00000325": 100, "00000326": 100, "00000327": 100, "00000328": 100, "00000329": 100, "00000330": 100, "00000331": 100, "00000332": 100, "00000333": 100, "00000334": 100, "00000335": 100, "00000336": 100, "00000337": 100, "00000338": 100, "00000339": 100, "00000340": 100, "00000341": 100, "00000342": 100, "00000343": 100, "00000344": 100, "00000345": 100, "00000346": 100, "00000347": 100, "00000348": 100, "00000349": 100, "00000350": 100, "00000351": 100, "00000352": 100, "00000353": 100, "00000354": 100, "00000355": 100, "00000356": 100, "00000357": 100, "00000358": 100, "00000359": 100, "00000360": 100, "00000361": 100, "00000362": 100, "00000363": 100, "00000364": 100, "00000365": 100, "00000366": 100, "00000367": 100, "00000368": 100, "00000369": 100, "00000370": 100, "00000371": 100, "00000372": 100, "00000373": 100, "00000374": 100, "00000375": 100, "00000376": 100, "00000377": 100, "00000378": 100, "00000379": 100, "00000380": 100, "00000381": 100, "00000382": 100, "00000383": 100, "00000384": 100, "00000385": 100, "00000386": 100, "00000387": 100, "00000388": 100, "00000389": 100, "00000390": 100, "00000391": 100, "00000392": 100, "00000393": 100, "00000394": 100, "00000395": 100, "00000396": 100, "00000397": 100, "00000398": 100, "00000399": 100, "00000400": 100, "00000401": 100, "00000402": 100, "00000403": 100, "00000404": 100, "00000405": 100, "00000406": 100, "00000407": 100, "00000408": 100, "00000409": 100, "00000410": 100, "00000411": 100, "00000412": 100, "00000413": 100, "00000414": 100, "00000415": 100, "00000416": 100, "00000417": 
100, "00000418": 100, "00000419": 100, "00000420": 100, "00000421": 100, "00000422": 100, "00000423": 100, "00000424": 100, "00000425": 100, "00000426": 100, "00000427": 100, "00000428": 100, "00000429": 100, "00000430": 100, "00000431": 100, "00000432": 100, "00000433": 100, "00000434": 100, "00000435": 100, "00000436": 100, "00000437": 100, "00000438": 100, "00000439": 100, "00000440": 100, "00000441": 100, "00000442": 100, "00000443": 100, "00000444": 100, "00000445": 100, "00000446": 100, "00000447": 100, "00000448": 100, "00000449": 100, "00000450": 100, "00000451": 100, "00000452": 100, "00000453": 100, "00000454": 100, "00000455": 100, "00000456": 100, "00000457": 100, "00000458": 100, "00000459": 100, "00000460": 100, "00000461": 100, "00000462": 100, "00000463": 100, "00000464": 100, "00000465": 100, "00000466": 100, "00000467": 100, "00000468": 100, "00000469": 100, "00000470": 100, "00000471": 100, "00000472": 100, "00000473": 100, "00000474": 100, "00000475": 100, "00000476": 100, "00000477": 100, "00000478": 100, "00000479": 100, "00000480": 100, "00000481": 100, "00000482": 100, "00000483": 100, "00000484": 100, "00000485": 100, "00000486": 100, "00000487": 100, "00000488": 100, "00000489": 100, "00000490": 100, "00000491": 100, "00000492": 100, "00000493": 100, "00000494": 100, "00000495": 100, "00000496": 100, "00000497": 100, "00000498": 100, "00000499": 100, "00000500": 100, "00000501": 100, "00000502": 100, "00000503": 100, "00000504": 100, "00000505": 100, "00000506": 100, "00000507": 100, "00000508": 100, "00000509": 100, "00000510": 100, "00000511": 100, "00000512": 100, "00000513": 100, "00000514": 100, "00000515": 100, "00000516": 100, "00000517": 100, "00000518": 100, "00000519": 100, "00000520": 100, "00000521": 100, "00000522": 100, "00000523": 100, "00000524": 100, "00000525": 100, "00000526": 100, "00000527": 100, "00000528": 100, "00000529": 100, "00000530": 100, "00000531": 100, "00000532": 100, "00000533": 100, "00000534": 100, "00000535": 100, "00000536": 100, "00000537": 100, "00000538": 100, "00000539": 100, "00000540": 100, "00000541": 100, "00000542": 100, "00000543": 100, "00000544": 100, "00000545": 100, "00000546": 100, "00000547": 100, "00000548": 100, "00000549": 100, "00000550": 100, "00000551": 100, "00000552": 100, "00000553": 100, "00000554": 100, "00000555": 100, "00000556": 100, "00000557": 100, "00000558": 100, "00000559": 100, "00000560": 100, "00000561": 100, "00000562": 100, "00000563": 100, "00000564": 100, "00000565": 100, "00000566": 100, "00000567": 100, "00000568": 100, "00000569": 100, "00000570": 100, "00000571": 100, "00000572": 100, "00000573": 100, "00000574": 100, "00000575": 100, "00000576": 100, "00000577": 100, "00000578": 100, "00000579": 100, "00000580": 100, "00000581": 100, "00000582": 100, "00000583": 100, "00000584": 100, "00000585": 100, "00000586": 100, "00000587": 100, "00000588": 100, "00000589": 100, "00000590": 100, "00000591": 100, "00000592": 100, "00000593": 100, "00000594": 100, "00000595": 100, "00000596": 100, "00000597": 100, "00000598": 100, "00000599": 100, "00000600": 100, "00000601": 100, "00000602": 100, "00000603": 100, "00000604": 100, "00000605": 100, "00000606": 100, "00000607": 100, "00000608": 100, "00000609": 100, "00000610": 100, "00000611": 100, "00000612": 100, "00000613": 100, "00000614": 100, "00000615": 100, "00000616": 100, "00000617": 100, "00000618": 100, "00000619": 100, "00000620": 100, "00000621": 100, "00000622": 100, "00000623": 100, "00000624": 100, "00000625": 100, "00000626": 
100, "00000627": 100, "00000628": 100, "00000629": 100, "00000630": 100, "00000631": 100, "00000632": 100, "00000633": 100, "00000634": 100, "00000635": 100, "00000636": 100, "00000637": 100, "00000638": 100, "00000639": 100, "00000640": 100, "00000641": 100, "00000642": 100, "00000643": 100, "00000644": 100, "00000645": 100, "00000646": 100, "00000647": 100, "00000648": 100, "00000649": 100, "00000650": 100, "00000651": 100, "00000652": 100, "00000653": 100, "00000654": 100, "00000655": 100, "00000656": 100, "00000657": 100, "00000658": 100, "00000659": 100, "00000660": 100, "00000661": 100, "00000662": 100, "00000663": 100, "00000664": 100, "00000665": 100, "00000666": 100, "00000667": 100, "00000668": 100, "00000669": 100, "00000670": 100, "00000671": 100, "00000672": 100, "00000673": 100, "00000674": 100, "00000675": 100, "00000676": 100, "00000677": 100, "00000678": 100, "00000679": 100, "00000680": 100, "00000681": 100, "00000682": 100, "00000683": 100, "00000684": 100, "00000685": 100, "00000686": 100, "00000687": 100, "00000688": 100, "00000689": 100, "00000690": 100, "00000691": 100, "00000692": 100, "00000693": 100, "00000694": 100, "00000695": 100, "00000696": 100, "00000697": 100, "00000698": 100, "00000699": 100, "00000700": 100, "00000701": 100, "00000702": 100, "00000703": 100, "00000704": 100, "00000705": 100, "00000706": 100, "00000707": 100, "00000708": 100, "00000709": 100, "00000710": 100, "00000711": 100, "00000712": 100, "00000713": 100, "00000714": 100, "00000715": 100, "00000716": 100, "00000717": 100, "00000718": 100}
--------------------------------------------------------------------------------
/datasets/KITTI-360EX/OuterPinhole/test.json:
--------------------------------------------------------------------------------
1 | {"00000722": 100, "00000723": 100, "00000724": 100, "00000725": 100, "00000726": 100, "00000727": 100, "00000728": 100, "00000729": 100, "00000730": 100, "00000731": 100, "00000732": 100, "00000733": 100, "00000734": 100, "00000735": 100, "00000736": 100, "00000737": 100, "00000738": 100, "00000739": 100, "00000740": 100, "00000741": 100, "00000742": 100, "00000743": 100, "00000744": 100, "00000745": 100, "00000746": 100, "00000747": 100, "00000748": 100, "00000749": 100, "00000750": 100, "00000751": 100, "00000752": 100, "00000753": 100, "00000754": 100, "00000755": 100, "00000756": 100, "00000757": 100, "00000758": 100, "00000759": 100}
--------------------------------------------------------------------------------
/datasets/KITTI-360EX/OuterPinhole/train.json:
--------------------------------------------------------------------------------
1 | {"00000000": 100, "00000001": 100, "00000002": 100, "00000003": 100, "00000004": 100, "00000005": 100, "00000006": 100, "00000007": 100, "00000008": 100, "00000009": 100, "00000010": 100, "00000011": 100, "00000012": 100, "00000013": 100, "00000014": 100, "00000015": 100, "00000016": 100, "00000017": 100, "00000018": 100, "00000019": 100, "00000020": 100, "00000021": 100, "00000022": 100, "00000023": 100, "00000024": 100, "00000025": 100, "00000026": 100, "00000027": 100, "00000028": 100, "00000029": 100, "00000030": 100, "00000031": 100, "00000032": 100, "00000033": 100, "00000034": 100, "00000035": 100, "00000036": 100, "00000037": 100, "00000038": 100, "00000039": 100, "00000040": 100, "00000041": 100, "00000042": 100, "00000043": 100, "00000044": 100, "00000045": 100, "00000046": 100, "00000047": 100, "00000048": 100, "00000049": 100, "00000050": 100, "00000051": 100, "00000052": 100, "00000053": 100, "00000054": 100, "00000055": 100, "00000056": 100, "00000057": 100, "00000058": 100, "00000059": 100, "00000060": 100, "00000061": 100, "00000062": 100, "00000063": 100, "00000064": 100, "00000065": 100, "00000066": 100, "00000067": 100, "00000068": 100, "00000069": 100, "00000070": 100, "00000071": 100, "00000072": 100, "00000073": 100, "00000074": 100, "00000075": 100, "00000076": 100, "00000077": 100, "00000078": 100, "00000079": 100, "00000080": 100, "00000081": 100, "00000082": 100, "00000083": 100, "00000084": 100, "00000085": 100, "00000086": 100, "00000087": 100, "00000088": 100, "00000089": 100, "00000090": 100, "00000091": 100, "00000092": 100, "00000093": 100, "00000094": 100, "00000095": 100, "00000096": 100, "00000097": 100, "00000098": 100, "00000099": 100, "00000100": 100, "00000101": 100, "00000102": 100, "00000103": 100, "00000104": 100, "00000105": 100, "00000106": 100, "00000107": 100, "00000108": 100, "00000109": 100, "00000110": 100, "00000111": 100, "00000112": 100, "00000113": 100, "00000114": 100, "00000115": 100, "00000116": 100, "00000117": 100, "00000118": 100, "00000119": 100, "00000120": 100, "00000121": 100, "00000122": 100, "00000123": 100, "00000124": 100, "00000125": 100, "00000126": 100, "00000127": 100, "00000128": 100, "00000129": 100, "00000130": 100, "00000131": 100, "00000132": 100, "00000133": 100, "00000134": 100, "00000135": 100, "00000136": 100, "00000137": 100, "00000138": 100, "00000139": 100, "00000140": 100, "00000141": 100, "00000142": 100, "00000143": 100, "00000144": 100, "00000145": 100, "00000146": 100, "00000147": 100, "00000148": 100, "00000149": 100, "00000150": 100, "00000151": 100, "00000152": 100, "00000153": 100, "00000154": 100, "00000155": 100, "00000156": 100, "00000157": 100, "00000158": 100, "00000159": 100, "00000160": 100, "00000161": 100, "00000162": 100, "00000163": 100, "00000164": 100, "00000165": 100, "00000166": 100, "00000167": 100, "00000168": 100, "00000169": 100, "00000170": 100, "00000171": 100, "00000172": 100, "00000173": 100, "00000174": 100, "00000175": 100, "00000176": 100, "00000177": 100, "00000178": 100, "00000179": 100, "00000180": 100, "00000181": 100, "00000182": 100, "00000183": 100, "00000184": 100, "00000185": 100, "00000186": 100, "00000187": 100, "00000188": 100, "00000189": 100, "00000190": 100, "00000191": 100, "00000192": 100, "00000193": 100, "00000194": 100, "00000195": 100, "00000196": 100, "00000197": 100, "00000198": 100, "00000199": 100, "00000200": 100, "00000201": 100, "00000202": 100, "00000203": 100, "00000204": 100, "00000205": 100, "00000206": 100, "00000207": 100, "00000208": 
100, "00000209": 100, "00000210": 100, "00000211": 100, "00000212": 100, "00000213": 100, "00000214": 100, "00000215": 100, "00000216": 100, "00000217": 100, "00000218": 100, "00000219": 100, "00000220": 100, "00000221": 100, "00000222": 100, "00000223": 100, "00000224": 100, "00000225": 100, "00000226": 100, "00000227": 100, "00000228": 100, "00000229": 100, "00000230": 100, "00000231": 100, "00000232": 100, "00000233": 100, "00000234": 100, "00000235": 100, "00000236": 100, "00000237": 100, "00000238": 100, "00000239": 100, "00000240": 100, "00000241": 100, "00000242": 100, "00000243": 100, "00000244": 100, "00000245": 100, "00000246": 100, "00000247": 100, "00000248": 100, "00000249": 100, "00000250": 100, "00000251": 100, "00000252": 100, "00000253": 100, "00000254": 100, "00000255": 100, "00000256": 100, "00000257": 100, "00000258": 100, "00000259": 100, "00000260": 100, "00000261": 100, "00000262": 100, "00000263": 100, "00000264": 100, "00000265": 100, "00000266": 100, "00000267": 100, "00000268": 100, "00000269": 100, "00000270": 100, "00000271": 100, "00000272": 100, "00000273": 100, "00000274": 100, "00000275": 100, "00000276": 100, "00000277": 100, "00000278": 100, "00000279": 100, "00000280": 100, "00000281": 100, "00000282": 100, "00000283": 100, "00000284": 100, "00000285": 100, "00000286": 100, "00000287": 100, "00000288": 100, "00000289": 100, "00000290": 100, "00000291": 100, "00000292": 100, "00000293": 100, "00000294": 100, "00000295": 100, "00000296": 100, "00000297": 100, "00000298": 100, "00000299": 100, "00000300": 100, "00000301": 100, "00000302": 100, "00000303": 100, "00000304": 100, "00000305": 100, "00000306": 100, "00000307": 100, "00000308": 100, "00000309": 100, "00000310": 100, "00000311": 100, "00000312": 100, "00000313": 100, "00000314": 100, "00000315": 100, "00000316": 100, "00000317": 100, "00000318": 100, "00000319": 100, "00000320": 100, "00000321": 100, "00000322": 100, "00000323": 100, "00000324": 100, "00000325": 100, "00000326": 100, "00000327": 100, "00000328": 100, "00000329": 100, "00000330": 100, "00000331": 100, "00000332": 100, "00000333": 100, "00000334": 100, "00000335": 100, "00000336": 100, "00000337": 100, "00000338": 100, "00000339": 100, "00000340": 100, "00000341": 100, "00000342": 100, "00000343": 100, "00000344": 100, "00000345": 100, "00000346": 100, "00000347": 100, "00000348": 100, "00000349": 100, "00000350": 100, "00000351": 100, "00000352": 100, "00000353": 100, "00000354": 100, "00000355": 100, "00000356": 100, "00000357": 100, "00000358": 100, "00000359": 100, "00000360": 100, "00000361": 100, "00000362": 100, "00000363": 100, "00000364": 100, "00000365": 100, "00000366": 100, "00000367": 100, "00000368": 100, "00000369": 100, "00000370": 100, "00000371": 100, "00000372": 100, "00000373": 100, "00000374": 100, "00000375": 100, "00000376": 100, "00000377": 100, "00000378": 100, "00000379": 100, "00000380": 100, "00000381": 100, "00000382": 100, "00000383": 100, "00000384": 100, "00000385": 100, "00000386": 100, "00000387": 100, "00000388": 100, "00000389": 100, "00000390": 100, "00000391": 100, "00000392": 100, "00000393": 100, "00000394": 100, "00000395": 100, "00000396": 100, "00000397": 100, "00000398": 100, "00000399": 100, "00000400": 100, "00000401": 100, "00000402": 100, "00000403": 100, "00000404": 100, "00000405": 100, "00000406": 100, "00000407": 100, "00000408": 100, "00000409": 100, "00000410": 100, "00000411": 100, "00000412": 100, "00000413": 100, "00000414": 100, "00000415": 100, "00000416": 100, "00000417": 
100, "00000418": 100, "00000419": 100, "00000420": 100, "00000421": 100, "00000422": 100, "00000423": 100, "00000424": 100, "00000425": 100, "00000426": 100, "00000427": 100, "00000428": 100, "00000429": 100, "00000430": 100, "00000431": 100, "00000432": 100, "00000433": 100, "00000434": 100, "00000435": 100, "00000436": 100, "00000437": 100, "00000438": 100, "00000439": 100, "00000440": 100, "00000441": 100, "00000442": 100, "00000443": 100, "00000444": 100, "00000445": 100, "00000446": 100, "00000447": 100, "00000448": 100, "00000449": 100, "00000450": 100, "00000451": 100, "00000452": 100, "00000453": 100, "00000454": 100, "00000455": 100, "00000456": 100, "00000457": 100, "00000458": 100, "00000459": 100, "00000460": 100, "00000461": 100, "00000462": 100, "00000463": 100, "00000464": 100, "00000465": 100, "00000466": 100, "00000467": 100, "00000468": 100, "00000469": 100, "00000470": 100, "00000471": 100, "00000472": 100, "00000473": 100, "00000474": 100, "00000475": 100, "00000476": 100, "00000477": 100, "00000478": 100, "00000479": 100, "00000480": 100, "00000481": 100, "00000482": 100, "00000483": 100, "00000484": 100, "00000485": 100, "00000486": 100, "00000487": 100, "00000488": 100, "00000489": 100, "00000490": 100, "00000491": 100, "00000492": 100, "00000493": 100, "00000494": 100, "00000495": 100, "00000496": 100, "00000497": 100, "00000498": 100, "00000499": 100, "00000500": 100, "00000501": 100, "00000502": 100, "00000503": 100, "00000504": 100, "00000505": 100, "00000506": 100, "00000507": 100, "00000508": 100, "00000509": 100, "00000510": 100, "00000511": 100, "00000512": 100, "00000513": 100, "00000514": 100, "00000515": 100, "00000516": 100, "00000517": 100, "00000518": 100, "00000519": 100, "00000520": 100, "00000521": 100, "00000522": 100, "00000523": 100, "00000524": 100, "00000525": 100, "00000526": 100, "00000527": 100, "00000528": 100, "00000529": 100, "00000530": 100, "00000531": 100, "00000532": 100, "00000533": 100, "00000534": 100, "00000535": 100, "00000536": 100, "00000537": 100, "00000538": 100, "00000539": 100, "00000540": 100, "00000541": 100, "00000542": 100, "00000543": 100, "00000544": 100, "00000545": 100, "00000546": 100, "00000547": 100, "00000548": 100, "00000549": 100, "00000550": 100, "00000551": 100, "00000552": 100, "00000553": 100, "00000554": 100, "00000555": 100, "00000556": 100, "00000557": 100, "00000558": 100, "00000559": 100, "00000560": 100, "00000561": 100, "00000562": 100, "00000563": 100, "00000564": 100, "00000565": 100, "00000566": 100, "00000567": 100, "00000568": 100, "00000569": 100, "00000570": 100, "00000571": 100, "00000572": 100, "00000573": 100, "00000574": 100, "00000575": 100, "00000576": 100, "00000577": 100, "00000578": 100, "00000579": 100, "00000580": 100, "00000581": 100, "00000582": 100, "00000583": 100, "00000584": 100, "00000585": 100, "00000586": 100, "00000587": 100, "00000588": 100, "00000589": 100, "00000590": 100, "00000591": 100, "00000592": 100, "00000593": 100, "00000594": 100, "00000595": 100, "00000596": 100, "00000597": 100, "00000598": 100, "00000599": 100, "00000600": 100, "00000601": 100, "00000602": 100, "00000603": 100, "00000604": 100, "00000605": 100, "00000606": 100, "00000607": 100, "00000608": 100, "00000609": 100, "00000610": 100, "00000611": 100, "00000612": 100, "00000613": 100, "00000614": 100, "00000615": 100, "00000616": 100, "00000617": 100, "00000618": 100, "00000619": 100, "00000620": 100, "00000621": 100, "00000622": 100, "00000623": 100, "00000624": 100, "00000625": 100, "00000626": 
100, "00000627": 100, "00000628": 100, "00000629": 100, "00000630": 100, "00000631": 100, "00000632": 100, "00000633": 100, "00000634": 100, "00000635": 100, "00000636": 100, "00000637": 100, "00000638": 100, "00000639": 100, "00000640": 100, "00000641": 100, "00000642": 100, "00000643": 100, "00000644": 100, "00000645": 100, "00000646": 100, "00000647": 100, "00000648": 100, "00000649": 100, "00000650": 100, "00000651": 100, "00000652": 100, "00000653": 100, "00000654": 100, "00000655": 100, "00000656": 100, "00000657": 100, "00000658": 100, "00000659": 100, "00000660": 100, "00000661": 100, "00000662": 100, "00000663": 100, "00000664": 100, "00000665": 100, "00000666": 100, "00000667": 100, "00000668": 100, "00000669": 100, "00000670": 100, "00000671": 100, "00000672": 100, "00000673": 100, "00000674": 100, "00000675": 100, "00000676": 100, "00000677": 100, "00000678": 100, "00000679": 100, "00000680": 100, "00000681": 100, "00000682": 100, "00000683": 100, "00000684": 100, "00000685": 100, "00000686": 100, "00000687": 100, "00000688": 100, "00000689": 100, "00000690": 100, "00000691": 100, "00000692": 100, "00000693": 100, "00000694": 100, "00000695": 100, "00000696": 100, "00000697": 100, "00000698": 100, "00000699": 100, "00000700": 100, "00000701": 100, "00000702": 100, "00000703": 100, "00000704": 100, "00000705": 100, "00000706": 100, "00000707": 100, "00000708": 100, "00000709": 100, "00000710": 100, "00000711": 100, "00000712": 100, "00000713": 100, "00000714": 100, "00000715": 100, "00000716": 100, "00000717": 100, "00000718": 100, "00000719": 100, "00000720": 100, "00000721": 100}
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import cv2
3 | import numpy as np
4 | import importlib
5 | import os
6 | import time
7 | import json
8 | import random
9 | import argparse
10 | from PIL import Image
11 |
12 | import torch
13 | from torch.utils.data import DataLoader
14 |
15 | from core.dataset import TestDataset
16 | from core.metrics import calc_psnr_and_ssim, calculate_i3d_activations, calculate_vfid, init_i3d_model
17 |
18 | # global variables
19 | # w h can be changed by args.output_size
20 | w, h = 432, 240  # default accuracy-test setting in E2FGVI for the DAVIS dataset and KITTI360EX-O
21 | # w, h = 336, 336  # default accuracy-test setting for KITTI360EX-I
22 | ref_length = 10  # sampling stride for non-local frames: take one NLF every 10 frames
23 |
24 |
25 | def read_cfg(args):
26 | """read flowlens cfg from config file"""
27 | # loading configs
28 | config = json.load(open(args.cfg_path))
29 |
30 | # # # # pass config to args # # # #
31 | args.dataset = config['train_data_loader']['name']
32 | args.data_root = config['train_data_loader']['data_root']
33 | args.output_size = [432, 240]
34 | args.output_size[0], args.output_size[1] = (config['train_data_loader']['w'], config['train_data_loader']['h'])
35 | args.model_win_size = config['model'].get('window_size', None)
36 | args.model_output_size = config['model'].get('output_size', None)
37 | args.neighbor_stride = config['train_data_loader'].get('num_local_frames', 10)
38 |
39 |     # Whether to use SPyNet as the flow-completion network (FlowLens-S)
40 | config['model']['spy_net'] = config['model'].get('spy_net', 0)
41 | if config['model']['spy_net'] != 0:
42 | # default for FlowLens-S
43 | args.spy_net = True
44 | else:
45 | # default for FlowLens
46 | args.spy_net = False
47 |
48 | if config['model']['net'] == 'flowlens':
49 |
50 |         # Transformer depth
51 | if config['model']['depths'] != 0:
52 | args.depths = config['model']['depths']
53 | else:
54 |             # Fall back to the network's default depth
55 | args.depths = None
56 |
57 |         # Number of windows per transformer block (tokens divided by the window partition size)
58 | config['model']['window_size'] = config['model'].get('window_size', 0)
59 | if config['model']['window_size'] != 0:
60 | args.window_size = config['model']['window_size']
61 | else:
62 |             # Fall back to the network's default window size
63 | args.window_size = None
64 |
65 |         # Large model or small model
66 | if config['model']['small_model'] != 0:
67 | args.small_model = True
68 | else:
69 | args.small_model = False
70 |
71 |         # Whether to freeze the DCN parameters
72 | config['model']['freeze_dcn'] = config['model'].get('freeze_dcn', 0)
73 | if config['model']['freeze_dcn'] != 0:
74 | args.freeze_dcn = True
75 | else:
76 | # default
77 | args.freeze_dcn = False
78 |
79 | # # # # pass config to args # # # #
80 |
81 | return args
82 |
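`read_cfg` touches only a handful of keys, so the shape of a config file is easy to see. A hedged sketch of the fields it reads; the values below are illustrative placeholders, not the shipped settings under `configs/` (`0` means "use the model default" for the optional entries):

```python
# Hypothetical minimal config covering every key read_cfg accesses.
cfg = {
    "train_data_loader": {"name": "KITTI360EX-I", "data_root": "datasets/",
                          "w": 432, "h": 240, "num_local_frames": 5},
    "model": {"net": "flowlens", "spy_net": 0, "depths": 0, "window_size": 0,
              "output_size": 0, "small_model": 0, "freeze_dcn": 0},
}
```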
83 |
84 | # sample reference frames from the whole video with mem support
85 | def get_ref_index_mem(length, neighbor_ids, same_id=False):
86 | """smae_id(bool): If True, allow same ref and local id as input."""
87 | ref_index = []
88 | for i in range(0, length, ref_length):
89 | if same_id:
90 |             # duplicate ids allowed
91 | ref_index.append(i)
92 | else:
93 |             # duplicate ids not allowed; on a clash, find the nearest distinct i
94 | if i not in neighbor_ids:
95 | ref_index.append(i)
96 | else:
97 |                 lf_id_avg = sum(neighbor_ids)/len(neighbor_ids)  # mean local frame id
98 | for _iter in range(0, 100):
99 | if i < (length - 1):
100 |                         # must stay within the video length
101 |                         if i == 0:
102 |                             # clash at frame 0: jump ahead to the next NLF plus the stride, so the adjusted id cannot collide with the next clashing NLF id
103 | i = ref_length + args.neighbor_stride
104 | ref_index.append(i)
105 | break
106 | elif i < lf_id_avg:
107 |                             # search towards earlier frames for a distinct reference, so not all clashes resolve in one direction
108 | i -= 1
109 | else:
110 |                             # search towards later frames for a distinct reference
111 | i += 1
112 | else:
113 |                             # out of range: use the last frame and stop
114 | ref_index.append(i)
115 | break
116 |
117 | if i not in neighbor_ids:
118 | ref_index.append(i)
119 | break
120 |
121 | return ref_index
122 |
123 |
124 | # sample reference frames from the remaining frames, with the same random behavior as training
125 | def get_ref_index_mem_random(neighbor_ids, video_length, num_ref_frame=3, before_nlf=False):
126 | if not before_nlf:
127 |         # sample non-local frames from both past and future
128 | complete_idx_set = list(range(video_length))
129 | else:
130 |         # non-local frames come only from past frames; no future information is used
131 | complete_idx_set = list(range(neighbor_ids[-1]))
132 | # complete_idx_set = list(range(video_length))
133 |
134 | remain_idx = list(set(complete_idx_set) - set(neighbor_ids))
135 |
136 |     # with only past frames as non-local frames, there may be too few of them, e.g. at the start of a video
137 | if before_nlf:
138 | if len(remain_idx) < num_ref_frame:
139 |             # in that case, allow sampling non-local frames from the local frames; the set() removes duplicates
140 | remain_idx = list(set(remain_idx + neighbor_ids))
141 |
142 | ref_index = sorted(random.sample(remain_idx, num_ref_frame))
143 | return ref_index
144 |
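A quick worked example of the random sampler: with `before_nlf=True` the candidate set shrinks to indices before the local window, mirroring an online setting where future frames are unavailable. Sketch assuming the function above (outputs vary per run):

```python
random.seed(0)  # only to make the illustration repeatable
local = [20, 21, 22, 23, 24]
print(get_ref_index_mem_random(local, video_length=50, num_ref_frame=3))
# e.g. three ids drawn from past and future, local ids excluded
print(get_ref_index_mem_random(local, video_length=50, num_ref_frame=3, before_nlf=True))
# three ids drawn only from past frames (0..19 here)
```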
145 |
146 | def main_worker(args):
147 |     args = read_cfg(args=args)  # read all network settings
148 | w = args.output_size[0]
149 | h = args.output_size[1]
150 | args.size = (w, h)
151 |
152 | # set up datasets and data loader
153 | # default result
154 | test_dataset = TestDataset(args)
155 |
156 | test_loader = DataLoader(test_dataset,
157 | batch_size=1,
158 | shuffle=False,
159 | num_workers=args.num_workers)
160 |
161 | # set up models
162 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
163 | net = importlib.import_module('model.' + args.model)
164 |
165 | if args.model == 'flowlens':
166 | model = net.InpaintGenerator(freeze_dcn=args.freeze_dcn, spy_net=args.spy_net, depths=args.depths,
167 | window_size=args.model_win_size, output_size=args.model_output_size,
168 | small_model=args.small_model).to(device)
169 | else:
170 |         # load the size / window settings
171 | model = net.InpaintGenerator(window_size=args.model_win_size, output_size=args.model_output_size).to(device)
172 |
173 | if args.ckpt is not None:
174 | data = torch.load(args.ckpt, map_location=device)
175 | model.load_state_dict(data)
176 | print(f'Loading from: {args.ckpt}')
177 |
178 | # # half
179 | # model = model.half()
180 |
181 | model.eval()
182 |
183 | total_frame_psnr = []
184 | total_frame_ssim = []
185 |
186 | output_i3d_activations = []
187 | real_i3d_activations = []
188 |
189 | print('Start evaluation...')
190 |
191 | time_all = 0
192 | len_all = 0
193 |
194 | # create results directory
195 | if args.ckpt is not None:
196 | ckpt = args.ckpt.split('/')[-1]
197 | else:
198 | ckpt = 'random'
199 |
200 | if args.fov is not None:
201 | if args.reverse:
202 | result_path = os.path.join('results', f'{args.model}+_{ckpt}_{args.fov}_{args.dataset}')
203 | else:
204 | result_path = os.path.join('results', f'{args.model}_{ckpt}_{args.fov}_{args.dataset}')
205 | else:
206 | if args.reverse:
207 | result_path = os.path.join('results', f'{args.model}+_{ckpt}_{args.dataset}')
208 | else:
209 | result_path = os.path.join('results', f'{args.model}_{ckpt}_{args.dataset}')
210 |
221 |
222 | if not os.path.exists(result_path):
223 | os.makedirs(result_path)
224 | eval_summary = open(
225 | os.path.join(result_path, f"{args.model}_{args.dataset}_metrics.txt"),
226 | "w")
227 |
228 | i3d_model = init_i3d_model()
229 |
230 | for index, items in enumerate(test_loader):
231 |
232 |         for blk in model.transformer:  # reset the per-video memory cache
233 |             try:
234 |                 blk.attn.m_k = []
235 |                 blk.attn.m_v = []
236 |             except AttributeError:  # blocks without memory keep no k/v cache
237 |                 pass
238 |
239 | frames, masks, video_name, frames_PIL = items
240 |
241 | # # half
242 | # frames = frames.half()
243 | # masks = masks.half()
244 |
245 | video_length = frames.size(1)
246 | frames, masks = frames.to(device), masks.to(device)
247 |         ori_frames = frames_PIL  # original frames, treated as ground truth
248 | ori_frames = [
249 | ori_frames[i].squeeze().cpu().numpy() for i in range(video_length)
250 | ]
251 |         comp_frames = [None] * video_length  # completed (inpainted) frames
252 |
253 | len_all += video_length
254 |
255 | # complete holes by our model
256 |         # once this loop has run, the whole video has been completed
257 | for f in range(0, video_length, args.neighbor_stride):
258 | if args.same_memory:
259 |                 # match the video in-painting test logic as closely as possible
260 |                 # by keeping the input temporal dimension T fixed
261 | if (f - args.neighbor_stride > 0) and (f + args.neighbor_stride + 1 < video_length):
262 |                     # neither end of the video is exceeded; no extra frames are needed
263 | neighbor_ids = [
264 | i for i in range(max(0, f - args.neighbor_stride),
265 | min(video_length, f + args.neighbor_stride + 1))
266 |                     ] # neighbor_ids are the Local Frames
267 | else:
268 |                     # the window exceeds the video bounds: pad with extra frames so the temporal dimension of the memory cache stays fixed; replicating the feature's temporal dimension inside the transformer could be tried later
269 | neighbor_ids = [
270 | i for i in range(max(0, f - args.neighbor_stride),
271 | min(video_length, f + args.neighbor_stride + 1))
272 |                     ] # neighbor_ids are the Local Frames
273 | repeat_num = (args.neighbor_stride * 2 + 1) - len(neighbor_ids)
274 |
275 |                     lf_id_avg = sum(neighbor_ids) / len(neighbor_ids)  # mean of the local frame ids
276 | first_id = neighbor_ids[0]
277 | for ii in range(0, repeat_num):
278 |                         # keep the local window size fixed so the cache channel count does not change
279 |                         if lf_id_avg < (video_length // 2):
280 |                             # for windows in the first half of the video, borrow ids from the tail so the padded input differs from the next window's input
281 |                             new_id = video_length - 1 - ii
282 |                         else:
283 |                             # for windows in the second half, borrow ids from just before the window
284 |                             new_id = first_id - 1 - ii
285 | neighbor_ids.append(new_id)
286 |
287 |                     neighbor_ids = sorted(neighbor_ids)  # re-sort
288 |
289 | else:
290 |                 # consistent with the training logic of the memory model
291 | if not args.recurrent:
292 | if video_length < (f + args.neighbor_stride):
293 | neighbor_ids = [
294 | i for i in range(f, video_length)
295 |                         ] # temporally non-overlapping windows: each local frame is computed exactly once; the video tail may hold fewer than the required local frames (e.g. 5), so the last frame is duplicated to pad the count
296 | for repeat_idx in range(0, args.neighbor_stride - len(neighbor_ids)):
297 | neighbor_ids.append(neighbor_ids[-1])
298 | else:
299 | neighbor_ids = [
300 | i for i in range(f, f + args.neighbor_stride)
301 |                         ] # temporally non-overlapping windows: each local frame is computed exactly once
302 | else:
303 |                     # in recurrent mode, the local window is always 1
304 | neighbor_ids = [f]
305 |
306 |             # to keep the temporal dimension fixed, frames with duplicate ids are allowed as input
307 | if args.same_memory:
308 |                 ref_ids = get_ref_index_mem(video_length, neighbor_ids, same_id=False)  # ref_ids are the Non-Local Frames
309 | elif args.past_ref:
310 |                 ref_ids = get_ref_index_mem_random(neighbor_ids, video_length, num_ref_frame=3, before_nlf=True)  # only past reference frames are allowed
311 | else:
312 |                 ref_ids = get_ref_index_mem_random(neighbor_ids, video_length, num_ref_frame=3)  # the same non-local frame input logic as sequence training
313 |
314 |             ref_ids = sorted(ref_ids)  # re-sort
315 | selected_imgs_lf = frames[:1, neighbor_ids, :, :, :]
316 | selected_imgs_nlf = frames[:1, ref_ids, :, :, :]
317 | selected_imgs = torch.cat((selected_imgs_lf, selected_imgs_nlf), dim=1)
318 | selected_masks_lf = masks[:1, neighbor_ids, :, :, :]
319 | selected_masks_nlf = masks[:1, ref_ids, :, :, :]
320 | selected_masks = torch.cat((selected_masks_lf, selected_masks_nlf), dim=1)
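            # Editor's note (hypothetical ids): with neighbor_ids = [10..14] and
            # ref_ids = [2, 5, 8], selected_imgs stacks T = 8 frames along dim=1:
            # the 5 local frames first, followed by the 3 non-local frames.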
321 |
322 | with torch.no_grad():
323 | masked_frames = selected_imgs * (1 - selected_masks)
324 |
325 | torch.cuda.synchronize()
326 | time_start = time.time()
327 |
328 |                 pred_img, _ = model(masked_frames, len(neighbor_ids))  # forward takes the number of local frames so the two frame types can be handled separately
329 |
330 |                 # horizontal and vertical flip augmentation
331 | if args.reverse:
332 | masked_frames_horizontal_aug = torch.from_numpy(masked_frames.cpu().numpy()[:, :, :, :, ::-1].copy()).cuda()
333 | pred_img_horizontal_aug, _ = model(masked_frames_horizontal_aug, len(neighbor_ids))
334 | pred_img_horizontal_aug = torch.from_numpy(pred_img_horizontal_aug.cpu().numpy()[:, :, :, ::-1].copy()).cuda()
335 | masked_frames_vertical_aug = torch.from_numpy(masked_frames.cpu().numpy()[:, :, :, ::-1, :].copy()).cuda()
336 | pred_img_vertical_aug, _ = model(masked_frames_vertical_aug, len(neighbor_ids))
337 | pred_img_vertical_aug = torch.from_numpy(pred_img_vertical_aug.cpu().numpy()[:, :, ::-1, :].copy()).cuda()
338 |
339 | pred_img = 1 / 3 * (pred_img + pred_img_horizontal_aug + pred_img_vertical_aug)
340 |
341 | torch.cuda.synchronize()
342 | time_end = time.time()
343 | time_sum = time_end - time_start
344 | time_all += time_sum
345 |
346 | pred_img = (pred_img + 1) / 2
347 | pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
348 | binary_masks = masks[0, neighbor_ids, :, :, :].cpu().permute(
349 | 0, 2, 3, 1).numpy().astype(np.uint8)
350 | for i in range(len(neighbor_ids)):
351 | idx = neighbor_ids[i]
352 | img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
353 | + ori_frames[idx] * (1 - binary_masks[i])
354 |
355 | if comp_frames[idx] is None:
356 |                         # the first time a frame of the local window is completed, record it directly in the completion list (comp_frames)
357 |                         # with good_fusion, every img carries an extra leading 'pass' axis that accumulates all results
358 | comp_frames[idx] = img[np.newaxis, :, :, :]
359 |
360 |                     # record every result, then average along this axis at the end
361 | else:
362 | comp_frames[idx] = np.concatenate((comp_frames[idx], img[np.newaxis, :, :, :]), axis=0)
363 | ########################################################################################
364 |
365 |         # for good_fusion, average along axis=0 after the full inference pass
366 | for idx, comp_frame in zip(range(0, video_length), comp_frames):
367 | comp_frame = comp_frame.astype(np.float32).sum(axis=0)/comp_frame.shape[0]
368 | comp_frames[idx] = comp_frame
369 |
370 | # calculate metrics
371 | cur_video_psnr = []
372 | cur_video_ssim = []
373 | comp_PIL = [] # to calculate VFID
374 | frames_PIL = []
375 | for ori, comp in zip(ori_frames, comp_frames):
376 | psnr, ssim = calc_psnr_and_ssim(ori, comp)
377 |
378 | cur_video_psnr.append(psnr)
379 | cur_video_ssim.append(ssim)
380 |
381 | total_frame_psnr.append(psnr)
382 | total_frame_ssim.append(ssim)
383 |
384 | frames_PIL.append(Image.fromarray(ori.astype(np.uint8)))
385 | comp_PIL.append(Image.fromarray(comp.astype(np.uint8)))
386 | cur_psnr = sum(cur_video_psnr) / len(cur_video_psnr)
387 | cur_ssim = sum(cur_video_ssim) / len(cur_video_ssim)
388 |
389 | # saving i3d activations
390 | frames_i3d, comp_i3d = calculate_i3d_activations(frames_PIL,
391 | comp_PIL,
392 | i3d_model,
393 | device=device)
394 | real_i3d_activations.append(frames_i3d)
395 | output_i3d_activations.append(comp_i3d)
396 |
397 | print(
398 | f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | PSNR/SSIM: {cur_psnr:.4f}/{cur_ssim:.4f}'
399 | )
400 | eval_summary.write(
401 | f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | PSNR/SSIM: {cur_psnr:.4f}/{cur_ssim:.4f}\n'
402 | )
403 |
404 | print('Average run time: (%f) per frame' % (time_all/len_all))
405 |
406 |         # saving images for evaluating warping errors
407 | if args.save_results:
408 | save_frame_path = os.path.join(result_path, video_name[0])
409 | os.makedirs(save_frame_path, exist_ok=False)
410 |
411 | for i, frame in enumerate(comp_frames):
412 | cv2.imwrite(
413 | os.path.join(save_frame_path,
414 | str(i).zfill(5) + '.png'),
415 | cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR))
416 |
417 | avg_frame_psnr = sum(total_frame_psnr) / len(total_frame_psnr)
418 | avg_frame_ssim = sum(total_frame_ssim) / len(total_frame_ssim)
419 |
420 | fid_score = calculate_vfid(real_i3d_activations, output_i3d_activations)
421 | print('Finish evaluation... Average Frame PSNR/SSIM/VFID: '
422 | f'{avg_frame_psnr:.2f}/{avg_frame_ssim:.4f}/{fid_score:.3f}')
423 | eval_summary.write(
424 | 'Finish evaluation... Average Frame PSNR/SSIM/VFID: '
425 | f'{avg_frame_psnr:.2f}/{avg_frame_ssim:.4f}/{fid_score:.3f}')
426 | eval_summary.close()
427 |
428 | print('All average forward run time: (%f) per frame' % (time_all / len_all))
429 |
430 | return len(total_frame_psnr)
431 |
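# Editor's sketch of the "good_fusion" averaging used in main_worker: every
# sliding window stacks its prediction for a frame along a new leading axis,
# and the final frame is the mean over all passes (hypothetical shapes).
def demo_good_fusion():
    h, w = 4, 4
    passes = [np.full((h, w, 3), v, dtype=np.float32) for v in (90.0, 110.0)]
    stacked = np.stack(passes, axis=0)  # (num_passes, h, w, 3)
    fused = stacked.sum(axis=0) / stacked.shape[0]
    assert np.allclose(fused, 100.0)  # average of the two passes
    return fused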
432 |
433 | if __name__ == '__main__':
434 | parser = argparse.ArgumentParser(description='FlowLens')
435 | parser.add_argument('--cfg_path', default='configs/KITTI360EX-I_FlowLens_small_re.json')
436 |     parser.add_argument('--dataset', choices=['KITTI360-EX'], type=str)  # equivalent to the 'name' used in training
437 | parser.add_argument('--data_root', type=str)
438 | parser.add_argument('--output_size', type=int, nargs='+', default=[432, 240])
439 | parser.add_argument('--object', action='store_true', default=False) # if true, use object removal mask
440 |     parser.add_argument('--fov', choices=['fov5', 'fov10', 'fov20'], type=str)  # for KITTI360-EX, the fov must be given at test time
441 |     parser.add_argument('--past_ref', action='store_true', default=True)  # for KITTI360-EX, only past reference frames are allowed at test time
442 | parser.add_argument('--model', choices=['flowlens'], type=str)
443 | parser.add_argument('--ckpt', type=str, default=None)
444 | parser.add_argument('--save_results', action='store_true', default=False)
445 | parser.add_argument('--num_workers', default=4, type=int)
446 | parser.add_argument('--same_memory', action='store_true', default=False,
447 | help='test with memory ability in video in-painting style')
448 | parser.add_argument('--reverse', action='store_true', default=False,
449 | help='test with horizontal and vertical reverse augmentation')
450 | parser.add_argument('--model_win_size', type=int, nargs='+', default=[5, 9])
451 | parser.add_argument('--model_output_size', type=int, nargs='+', default=[60, 108])
452 | parser.add_argument('--recurrent', action='store_true', default=False,
453 | help='keep window = 1, stride = 1 to not use any local future info')
454 | args = parser.parse_args()
455 |
456 | if args.dataset == 'KITTI360-EX':
457 |         # for KITTI360-EX, only past reference frames are allowed at test time
458 | args.past_ref = True
459 |
460 | frame_num = main_worker(args)
461 |
--------------------------------------------------------------------------------
/model/flowlens.py:
--------------------------------------------------------------------------------
1 | ''' Towards An End-to-End Framework for Flow-Guided Video Inpainting
2 | '''
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from model.modules.flow_comp import SPyNet
9 | from model.modules.maskflownets import MaskFlowNetS
10 | from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
11 | from model.modules.mix_focal_transformer import MixFocalTransformerBlock, SoftSplit, SoftComp
12 | from model.modules.spectral_norm import spectral_norm as _spectral_norm
13 |
14 |
15 | class BaseNetwork(nn.Module):
16 | def __init__(self):
17 | super(BaseNetwork, self).__init__()
18 |
19 | def print_network(self):
20 | if isinstance(self, list):
21 | self = self[0]
22 | num_params = 0
23 | for param in self.parameters():
24 | num_params += param.numel()
25 | print(
26 | 'Network [%s] was created. Total number of parameters: %.1f million. '
27 | 'To see the architecture, do print(network).' %
28 | (type(self).__name__, num_params / 1000000))
29 |
30 | def init_weights(self, init_type='normal', gain=0.02):
31 | '''
32 | initialize network's weights
33 | init_type: normal | xavier | kaiming | orthogonal
34 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
35 | '''
36 | def init_func(m):
37 | classname = m.__class__.__name__
38 | if classname.find('InstanceNorm2d') != -1:
39 | if hasattr(m, 'weight') and m.weight is not None:
40 | nn.init.constant_(m.weight.data, 1.0)
41 | if hasattr(m, 'bias') and m.bias is not None:
42 | nn.init.constant_(m.bias.data, 0.0)
43 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1
44 | or classname.find('Linear') != -1):
45 | if init_type == 'normal':
46 | nn.init.normal_(m.weight.data, 0.0, gain)
47 | elif init_type == 'xavier':
48 | nn.init.xavier_normal_(m.weight.data, gain=gain)
49 | elif init_type == 'xavier_uniform':
50 | nn.init.xavier_uniform_(m.weight.data, gain=1.0)
51 | elif init_type == 'kaiming':
52 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
53 | elif init_type == 'orthogonal':
54 | nn.init.orthogonal_(m.weight.data, gain=gain)
55 | elif init_type == 'none': # uses pytorch's default init method
56 | m.reset_parameters()
57 | else:
58 | raise NotImplementedError(
59 | 'initialization method [%s] is not implemented' %
60 | init_type)
61 | if hasattr(m, 'bias') and m.bias is not None:
62 | nn.init.constant_(m.bias.data, 0.0)
63 |
64 | self.apply(init_func)
65 |
66 | # propagate to children
67 | for m in self.children():
68 | if hasattr(m, 'init_weights'):
69 | m.init_weights(init_type, gain)
70 |
71 |
72 | class Encoder(nn.Module):
73 | def __init__(self, out_channel=128, reduction=1):
74 | super(Encoder, self).__init__()
75 | self.group = [1, 2, 4, 8, 1]
76 | self.layers = nn.ModuleList([
77 | nn.Conv2d(3, 64//reduction, kernel_size=3, stride=2, padding=1),
78 | nn.LeakyReLU(0.2, inplace=True),
79 | nn.Conv2d(64//reduction, 64//reduction, kernel_size=3, stride=1, padding=1),
80 | nn.LeakyReLU(0.2, inplace=True),
81 | nn.Conv2d(64//reduction, 128//reduction, kernel_size=3, stride=2, padding=1),
82 | nn.LeakyReLU(0.2, inplace=True),
83 | nn.Conv2d(128//reduction, 256//reduction, kernel_size=3, stride=1, padding=1),
84 | nn.LeakyReLU(0.2, inplace=True),
85 | nn.Conv2d(256//reduction, 384//reduction, kernel_size=3, stride=1, padding=1, groups=1),
86 | nn.LeakyReLU(0.2, inplace=True),
87 | nn.Conv2d(640//reduction, 512//reduction, kernel_size=3, stride=1, padding=1, groups=2),
88 | nn.LeakyReLU(0.2, inplace=True),
89 | nn.Conv2d(768//reduction, 384//reduction, kernel_size=3, stride=1, padding=1, groups=4),
90 | nn.LeakyReLU(0.2, inplace=True),
91 | nn.Conv2d(640//reduction, 256//reduction, kernel_size=3, stride=1, padding=1, groups=8),
92 | nn.LeakyReLU(0.2, inplace=True),
93 | nn.Conv2d(512//reduction, out_channel, kernel_size=3, stride=1, padding=1, groups=1),
94 | nn.LeakyReLU(0.2, inplace=True)
95 | ])
96 |
97 | def forward(self, x):
98 | bt, c, _, _ = x.size()
99 | # h, w = h//4, w//4
100 | out = x
101 | for i, layer in enumerate(self.layers):
102 | if i == 8:
103 | x0 = out
104 | _, _, h, w = x0.size()
105 | if i > 8 and i % 2 == 0:
106 | g = self.group[(i - 8) // 2]
107 | x = x0.view(bt, g, -1, h, w)
108 | o = out.view(bt, g, -1, h, w)
109 | out = torch.cat([x, o], 2).view(bt, -1, h, w)
110 | out = layer(out)
111 | return out
112 |
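# Editor's note on the grouped skip connection in Encoder.forward (channel
# counts for reduction=1): x0 is the 256-channel feature entering layer 8.
# For each later conv (i = 10, 12, 14, 16) those 256 skip channels are split
# into g groups and interleaved with the current feature, giving in_channels
# of 256+384=640, 256+512=768, 256+384=640 and 256+256=512, matching the
# nn.Conv2d definitions in __init__.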
113 |
114 | class deconv(nn.Module):
115 | def __init__(self,
116 | input_channel,
117 | output_channel,
118 | kernel_size=3,
119 | padding=0):
120 | super().__init__()
121 | self.conv = nn.Conv2d(input_channel,
122 | output_channel,
123 | kernel_size=kernel_size,
124 | stride=1,
125 | padding=padding)
126 |
127 | def forward(self, x):
128 | x = F.interpolate(x,
129 | scale_factor=2,
130 | mode='bilinear',
131 | align_corners=True)
132 | return self.conv(x)
133 |
134 |
135 | class InpaintGenerator(BaseNetwork):
136 | """
137 |     window_size  # window size, equal to the number of patches divided by 4
138 |     output_size  # output size, the training resolution // 4
139 |     spy_net      # if True, use SPyNet instead of MaskFlowNetS to estimate optical flow
140 |     freeze_dcn   # if True, freeze the DCN parameters
141 | """
142 | def __init__(self, init_weights=True, freeze_dcn=False, spy_net=False,
143 | flow_res=0.25,
144 | depths=9,
145 | window_size=None, output_size=None, small_model=False):
146 | super(InpaintGenerator, self).__init__()
147 |
148 | if not small_model:
149 | # large model:
150 | channel = 256 # default
151 | hidden = 512 # default
152 | reduction = 1 # default
153 | else:
154 | # v2
155 | channel = 128
156 | hidden = 256
157 | reduction = 2
158 |
159 |         # transformer settings
160 |         # number of transformer blocks
161 | if depths is None:
162 | depths = 2 # 0.08s/frame, 0.07s/frame with hidden = 128,
163 |         # otherwise keep the depths passed in by the caller
165 |
166 |         # there is only one stage
167 |         # set the number of heads for each layer
168 |         # by default every layer uses 4 heads, i.e. 2 along height and 2 along width
169 | num_heads = [4] * depths
170 |
171 |         # resolution at which optical flow is inferred
172 | self.flow_res = flow_res
173 |
174 | # encoder
175 | self.encoder = Encoder(out_channel=channel // 2, reduction=reduction)
176 |
177 | # decoder-default
178 | self.decoder = nn.Sequential(
179 | deconv(channel // 2, 128, kernel_size=3, padding=1),
180 | nn.LeakyReLU(0.2, inplace=True),
181 | nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
182 | nn.LeakyReLU(0.2, inplace=True),
183 | deconv(64, 64, kernel_size=3, padding=1),
184 | nn.LeakyReLU(0.2, inplace=True),
185 | nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
186 |
187 | # feature propagation module
188 | self.feat_prop_module = BidirectionalPropagation(channel // 2, freeze_dcn=freeze_dcn)
189 |
190 | # soft split and soft composition
191 |         kernel_size = (7, 7)  # sliding-window (patch) size
192 |         padding = (3, 3)  # implicit zero padding in both directions
193 |         stride = (3, 3)  # sliding-window stride
194 | if output_size is None:
195 |             # default output size
196 | output_size = (60, 108)
197 | else:
198 | output_size = (output_size[0], output_size[1])
199 | t2t_params = {
200 | 'kernel_size': kernel_size,
201 | 'stride': stride,
202 | 'padding': padding
203 | }
204 |
205 | self.ss = SoftSplit(channel // 2,
206 | hidden,
207 | kernel_size,
208 | stride,
209 | padding,
210 | t2t_param=t2t_params)
211 |
212 | self.sc = SoftComp(channel // 2, hidden, kernel_size, stride, padding)
213 |
214 |         n_vecs = 1  # number of tokens
215 | for i, d in enumerate(kernel_size):
216 | n_vecs *= int((output_size[i] + 2 * padding[i] -
217 | (d - 1) - 1) / stride[i] + 1)
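        # Editor's worked example with the defaults above: for output_size=(60, 108),
        # kernel_size=(7, 7), stride=(3, 3), padding=(3, 3):
        #   n_h = int((60 + 6 - 6 - 1) / 3 + 1) = 20
        #   n_w = int((108 + 6 - 6 - 1) / 3 + 1) = 36
        # so n_vecs = 20 * 36 = 720 tokens per frame.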
218 |
219 | blocks = []
220 | if window_size is None:
221 | window_size = [(5, 9)] * depths
222 | focal_windows = [(5, 9)] * depths
223 | else:
224 | window_size_h = window_size[0]
225 | window_size_w = window_size[1]
226 | window_size = [(window_size_h, window_size_w)] * depths
227 | focal_windows = [(window_size_h, window_size_w)] * depths
228 | focal_levels = [2] * depths
229 | pool_method = "fc"
230 |
231 | # default temporal focal transformer
232 | for i in range(depths):
233 |             # only the first layer keeps a memory
234 |             if (i + 1) == 1:
235 |                 # the first block caches keys/values as memory
236 | blocks.append(
237 | MixFocalTransformerBlock(dim=hidden,
238 | num_heads=num_heads[i],
239 | window_size=window_size[i],
240 | focal_level=focal_levels[i],
241 | focal_window=focal_windows[i],
242 | n_vecs=n_vecs,
243 | t2t_params=t2t_params,
244 | pool_method=pool_method,
245 | memory=True,
246 | cs_win_strip=1), )
247 |
248 | else:
249 |                 # the remaining layers have no memory
250 | blocks.append(
251 | MixFocalTransformerBlock(dim=hidden,
252 | num_heads=num_heads[i],
253 | window_size=window_size[i],
254 | focal_level=focal_levels[i],
255 | focal_window=focal_windows[i],
256 | n_vecs=n_vecs,
257 | t2t_params=t2t_params,
258 | pool_method=pool_method,
259 | memory=False))
260 |
261 | self.transformer = nn.Sequential(*blocks)
262 |
263 | if init_weights:
264 | self.init_weights()
265 | # Need to initial the weights of MSDeformAttn specifically
266 | for m in self.modules():
267 | if isinstance(m, SecondOrderDeformableAlignment):
268 | m.init_offset()
269 |
270 | # flow completion network
271 | if spy_net:
272 |             # use SPyNet
273 | self.update_MFN = SPyNet()
274 | else:
275 |             # use MaskFlowNetS
276 | self.update_MFN = MaskFlowNetS()
277 |
278 | def forward_bidirect_flow(self, masked_local_frames):
279 | b, l_t, c, h, w = masked_local_frames.size()
280 |         scale_factor = int(1 / self.flow_res)  # 1/4 -> 4, used to restore the scale
281 |
282 | # compute forward and backward flows of masked frames
283 | masked_local_frames = F.interpolate(masked_local_frames.view(
284 | -1, c, h, w),
285 | scale_factor=self.flow_res, # 1/4 for default
286 | mode='bilinear',
287 | align_corners=True,
288 | recompute_scale_factor=True)
289 | masked_local_frames = masked_local_frames.view(b, l_t, c, h // scale_factor,
290 | w // scale_factor)
291 | mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
292 | -1, c, h // scale_factor, w // scale_factor)
293 | mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
294 | -1, c, h // scale_factor, w // scale_factor)
295 | pred_flows_forward = self.update_MFN(mlf_1, mlf_2)
296 | pred_flows_backward = self.update_MFN(mlf_2, mlf_1)
297 |
298 | pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // scale_factor,
299 | w // scale_factor)
300 | pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // scale_factor,
301 | w // scale_factor)
302 |
303 |         # the flows must end up at 1/4 resolution for feature warping
304 | if scale_factor != 4:
305 | pred_flows_forward = F.interpolate(pred_flows_forward.view(-1, 2, h // scale_factor, w // scale_factor),
306 | scale_factor=scale_factor / 4,
307 | mode='bilinear',
308 | align_corners=True,
309 | recompute_scale_factor=True).view(b, l_t - 1, 2, h // 4,
310 | w // 4)
311 | pred_flows_backward = F.interpolate(pred_flows_backward.view(-1, 2, h // scale_factor, w // scale_factor),
312 | scale_factor=scale_factor / 4,
313 | mode='bilinear',
314 | align_corners=True,
315 | recompute_scale_factor=True).view(b, l_t - 1, 2, h // 4,
316 | w // 4)
317 |
318 | return pred_flows_forward, pred_flows_backward
319 |
320 | def forward(self, masked_frames, num_local_frames=5):
321 | l_t = num_local_frames
322 | b, t, ori_c, ori_h, ori_w = masked_frames.size()
323 |
324 | # normalization before feeding into the flow completion module
325 | masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
326 | pred_flows = self.forward_bidirect_flow(masked_local_frames)
327 |
328 | # extracting features and performing the feature propagation on local features
329 | enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
330 | _, c, h, w = enc_feat.size()
331 | fold_output_size = (h, w)
332 | local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
333 | ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
334 |
335 | local_feat = self.feat_prop_module(local_feat, pred_flows[1],
336 | pred_flows[0])
337 |
338 | enc_feat = torch.cat((local_feat, ref_feat), dim=1)
339 |
340 | # content hallucination through stacking multiple temporal focal transformer blocks
341 | trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_output_size) # [B, t, f_h, f_w, hidden]
342 |
343 |         trans_feat = self.transformer([trans_feat, fold_output_size, l_t])  # passes l_t in addition to the default inputs
344 |
345 |         # soft composition
346 | trans_feat = self.sc(trans_feat[0], t, fold_output_size)
347 |
348 | trans_feat = trans_feat.view(b, t, -1, h, w)
349 |         enc_feat = enc_feat + trans_feat  # residual connection
350 |
351 | # decode frames from features
352 | output = self.decoder(enc_feat.view(b * t, c, h, w))
353 | output = torch.tanh(output)
354 | return output, pred_flows
355 |
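# Editor's shape sketch of the forward contract (hypothetical sizes). Given
# masked_frames of shape [B, T, 3, H, W] in [-1, 1], where the first
# num_local_frames entries along T form the local window:
#   output, pred_flows = model(masked_frames, num_local_frames=5)
#   output:     [B*T, 3, H, W], tanh-bounded to [-1, 1]
#   pred_flows: (forward, backward) flows, each [B, l_t - 1, 2, H // 4, W // 4]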
356 |
357 | # ######################################################################
358 | # Discriminator for Temporal Patch GAN
359 | # ######################################################################
360 |
361 |
362 | class Discriminator(BaseNetwork):
363 | def __init__(self,
364 | in_channels=3,
365 | use_sigmoid=False,
366 | use_spectral_norm=True,
367 | init_weights=True):
368 | super(Discriminator, self).__init__()
369 | self.use_sigmoid = use_sigmoid
370 | nf = 32
371 |
372 | self.conv = nn.Sequential(
373 | spectral_norm(
374 | nn.Conv3d(in_channels=in_channels,
375 | out_channels=nf * 1,
376 | kernel_size=(3, 5, 5),
377 | stride=(1, 2, 2),
378 | padding=1,
379 | bias=not use_spectral_norm), use_spectral_norm),
380 | # nn.InstanceNorm2d(64, track_running_stats=False),
381 | nn.LeakyReLU(0.2, inplace=True),
382 | spectral_norm(
383 | nn.Conv3d(nf * 1,
384 | nf * 2,
385 | kernel_size=(3, 5, 5),
386 | stride=(1, 2, 2),
387 | padding=(1, 2, 2),
388 | bias=not use_spectral_norm), use_spectral_norm),
389 | # nn.InstanceNorm2d(128, track_running_stats=False),
390 | nn.LeakyReLU(0.2, inplace=True),
391 | spectral_norm(
392 | nn.Conv3d(nf * 2,
393 | nf * 4,
394 | kernel_size=(3, 5, 5),
395 | stride=(1, 2, 2),
396 | padding=(1, 2, 2),
397 | bias=not use_spectral_norm), use_spectral_norm),
398 | # nn.InstanceNorm2d(256, track_running_stats=False),
399 | nn.LeakyReLU(0.2, inplace=True),
400 | spectral_norm(
401 | nn.Conv3d(nf * 4,
402 | nf * 4,
403 | kernel_size=(3, 5, 5),
404 | stride=(1, 2, 2),
405 | padding=(1, 2, 2),
406 | bias=not use_spectral_norm), use_spectral_norm),
407 | # nn.InstanceNorm2d(256, track_running_stats=False),
408 | nn.LeakyReLU(0.2, inplace=True),
409 | spectral_norm(
410 | nn.Conv3d(nf * 4,
411 | nf * 4,
412 | kernel_size=(3, 5, 5),
413 | stride=(1, 2, 2),
414 | padding=(1, 2, 2),
415 | bias=not use_spectral_norm), use_spectral_norm),
416 | # nn.InstanceNorm2d(256, track_running_stats=False),
417 | nn.LeakyReLU(0.2, inplace=True),
418 | nn.Conv3d(nf * 4,
419 | nf * 4,
420 | kernel_size=(3, 5, 5),
421 | stride=(1, 2, 2),
422 | padding=(1, 2, 2)))
423 |
424 | if init_weights:
425 | self.init_weights()
426 |
427 | def forward(self, xs):
428 | # T, C, H, W = xs.shape (old)
429 | # B, T, C, H, W (new)
430 | xs_t = torch.transpose(xs, 1, 2)
431 | feat = self.conv(xs_t)
432 | if self.use_sigmoid:
433 | feat = torch.sigmoid(feat)
434 | out = torch.transpose(feat, 1, 2) # B, T, C, H, W
435 | return out
436 |
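# Editor's shape sketch (hypothetical sizes): for xs of shape [B, T, 3, H, W],
# the six stride-(1, 2, 2) 3D convolutions halve only the spatial dimensions,
# so the returned temporal patch map is roughly [B, T, 128, H / 64, W / 64],
# i.e. one real/fake score per spatio-temporal patch.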
437 |
438 | def spectral_norm(module, mode=True):
439 | if mode:
440 | return _spectral_norm(module)
441 | return module
442 |
--------------------------------------------------------------------------------
/model/modules/feat_prop.py:
--------------------------------------------------------------------------------
1 | """
2 | BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment, CVPR 2022
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 | from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d
8 | from mmcv.cnn import constant_init
9 |
10 | from model.modules.flow_comp import flow_warp
11 |
12 |
13 | class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
14 | """Second-order deformable alignment module."""
15 | def __init__(self, *args, **kwargs):
16 | self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
17 |
18 | super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
19 |
20 | self.conv_offset = nn.Sequential(
21 | nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
22 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
23 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
24 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
25 | nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
26 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
27 | nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
28 | )
29 |
30 | self.init_offset()
31 |
32 | def init_offset(self):
33 | constant_init(self.conv_offset[-1], val=0, bias=0)
34 |
35 | def forward(self, x, extra_feat, flow_1, flow_2):
36 | extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
37 | out = self.conv_offset(extra_feat)
38 | o1, o2, mask = torch.chunk(out, 3, dim=1)
39 |
40 | # offset
41 | offset = self.max_residue_magnitude * torch.tanh(
42 | torch.cat((o1, o2), dim=1))
43 | offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
44 | offset_1 = offset_1 + flow_1.flip(1).repeat(1,
45 | offset_1.size(1) // 2, 1,
46 | 1)
47 | offset_2 = offset_2 + flow_2.flip(1).repeat(1,
48 | offset_2.size(1) // 2, 1,
49 | 1)
50 | offset = torch.cat([offset_1, offset_2], dim=1)
51 |
52 | # mask
53 | mask = torch.sigmoid(mask)
54 |
55 |         # the default second-order alignment is DCNv2, i.e. modulated deformable convolution; the 'modulation' is an extra image-sized mask with values in [0, 1]
56 | return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
57 | self.stride, self.padding,
58 | self.dilation, self.groups,
59 | self.deform_groups)
60 |
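# Editor's note on the offset construction above: the predicted residual is
# bounded to [-max_residue_magnitude, max_residue_magnitude] by tanh and added
# to the optical flow (flipped from (dx, dy) into the (dy, dx) order expected
# by the deformable conv), so alignment starts from the flow-warped position
# and the DCN only refines it; the sigmoid mask modulates each sampling point.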
61 |
62 | class ConvBNReLU(nn.Module):
63 | """Conv with BN and ReLU, used for Simple Second Fusion"""
64 |
65 | def __init__(self,
66 | in_chan,
67 | out_chan,
68 | ks=3,
69 | stride=1,
70 | padding=1,
71 | *args,
72 | **kwargs):
73 | super(ConvBNReLU, self).__init__()
74 | self.conv = nn.Conv2d(
75 | in_chan,
76 | out_chan,
77 | kernel_size=ks,
78 | stride=stride,
79 | padding=padding,
80 | bias=False)
81 | self.bn = torch.nn.BatchNorm2d(out_chan)
82 | self.relu = nn.ReLU(inplace=True)
83 |
84 | def forward(self, x):
85 | x = self.conv(x)
86 | x = self.bn(x)
87 | x = self.relu(x)
88 | return x
89 |
90 | def init_weight(self):
91 | for ly in self.children():
92 | if isinstance(ly, nn.Conv2d):
93 | nn.init.kaiming_normal_(ly.weight, a=1)
94 |                 if ly.bias is not None: nn.init.constant_(ly.bias, 0)
95 |
96 |
97 | class BidirectionalPropagation(nn.Module):
98 | def __init__(self, channel, freeze_dcn=False):
99 | super(BidirectionalPropagation, self).__init__()
100 | modules = ['backward_', 'forward_']
101 |
102 | self.backbone = nn.ModuleDict()
103 | self.channel = channel
104 |
105 | self.deform_align = nn.ModuleDict()
106 |
107 | for i, module in enumerate(modules):
108 | self.deform_align[module] = SecondOrderDeformableAlignment(
109 | 2 * channel, channel, 3, padding=1, deform_groups=16)
110 |
111 | self.backbone[module] = nn.Sequential(
112 | nn.Conv2d((2 + i) * channel, channel, 3, 1, 1),
113 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
114 | nn.Conv2d(channel, channel, 3, 1, 1),
115 | )
116 |
117 | self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0)
118 |
119 | if freeze_dcn:
120 |             # freeze the DCN
121 | self.freeze(self.deform_align)
122 |
123 | @staticmethod
124 | def freeze(layer):
125 | """For freezing some layers."""
126 | for child in layer.children():
127 | for param in child.parameters():
128 | param.requires_grad = False
129 |
130 | def forward(self, x, flows_backward, flows_forward):
131 | """
132 | x shape : [b, t, c, h, w]
133 | return [b, t, c, h, w]
134 | """
135 | b, t, c, h, w = x.shape
136 | feats = {}
137 | feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)]
138 |
139 | for module_name in ['backward_', 'forward_']:
140 |
141 | feats[module_name] = []
142 |
143 | frame_idx = range(0, t)
144 | flow_idx = range(-1, t - 1)
145 | mapping_idx = list(range(0, len(feats['spatial'])))
146 | mapping_idx += mapping_idx[::-1]
147 |
148 | if 'backward' in module_name:
149 | frame_idx = frame_idx[::-1]
150 | flows = flows_backward
151 | else:
152 | flows = flows_forward
153 |
154 | feat_prop = x.new_zeros(b, self.channel, h, w)
155 |
156 |         # fixes a bug where i and idx were mismatched during backward propagation
157 | for i, idx in enumerate(frame_idx):
158 | feat_current = feats['spatial'][mapping_idx[idx]]
159 |
160 | if i > 0:
161 | if 'backward' in module_name:
162 | flow_n1 = flows[:, flow_idx[idx]+1, :, :, :]
163 | else:
164 | flow_n1 = flows[:, flow_idx[i], :, :, :]
165 | cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
166 |
167 | # initialize second-order features
168 | feat_n2 = torch.zeros_like(feat_prop)
169 | flow_n2 = torch.zeros_like(flow_n1)
170 | cond_n2 = torch.zeros_like(cond_n1)
171 | if i > 1:
172 | feat_n2 = feats[module_name][-2]
173 | if 'backward' in module_name:
174 | flow_n2 = flows[:, flow_idx[idx]+2, :, :, :]
175 | else:
176 | flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
177 | flow_n2 = flow_n1 + flow_warp(
178 | flow_n2, flow_n1.permute(0, 2, 3, 1))
179 | cond_n2 = flow_warp(feat_n2,
180 | flow_n2.permute(0, 2, 3, 1))
181 |
182 | cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
183 |
184 | # default
185 | feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
186 | feat_prop = self.deform_align[module_name](feat_prop, cond,
187 | flow_n1,
188 | flow_n2)
189 |
190 | feat = [feat_current] + [
191 | feats[k][idx]
192 | for k in feats if k not in ['spatial', module_name]
193 | ] + [feat_prop]
194 |
195 | feat = torch.cat(feat, dim=1)
196 | feat_prop = feat_prop + self.backbone[module_name](feat)
197 | feats[module_name].append(feat_prop)
198 | ##################################################################
199 |
200 | if 'backward' in module_name:
201 | feats[module_name] = feats[module_name][::-1]
202 |
203 | outputs = []
204 | for i in range(0, t):
205 | align_feats = [feats[k].pop(0) for k in feats if k != 'spatial']
206 | align_feats = torch.cat(align_feats, dim=1)
207 | outputs.append(self.fusion(align_feats))
208 |
209 | return torch.stack(outputs, dim=1) + x
210 |
--------------------------------------------------------------------------------
/model/modules/flow_comp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch
6 |
7 | from mmcv.cnn import ConvModule
8 | from mmcv.runner import load_checkpoint
9 | from model.modules.maskflownets import MaskFlowNetS
10 |
11 |
12 | class FlowCompletionLoss(nn.Module):
13 | """Flow completion loss"""
14 | def __init__(self, estimator='spy', device='cuda:0', flow_res=0.25):
15 | super().__init__()
16 | if estimator == 'spy':
17 | # default flow compute with spynet:
18 | self.fix_spynet = SPyNet()
19 | elif estimator == 'mfn':
20 | self.fix_spynet = MaskFlowNetS(device=device)
21 | else:
22 | raise TypeError('[estimator] should be spy or mfn, '
23 | f'but got {estimator}.')
24 |
25 |         # the SPyNet computing GT flow has its weights frozen; the flow-completion network is not frozen
26 | for p in self.fix_spynet.parameters():
27 | p.requires_grad = False
28 |
29 | self.l1_criterion = nn.L1Loss()
30 |
31 |         self.flow_res = flow_res  # resolution at which optical flow is computed
32 |
33 | def forward(self, pred_flows, gt_local_frames):
34 | b, l_t, c, h, w = gt_local_frames.size()
35 |         scale_factor = int(1 / self.flow_res)  # 1/4 -> 4, used to restore the scale
36 |
37 | with torch.no_grad():
38 | # compute gt forward and backward flows
39 | gt_local_frames = F.interpolate(gt_local_frames.view(-1, c, h, w),
40 | scale_factor=self.flow_res,
41 | mode='bilinear',
42 | align_corners=True,
43 | recompute_scale_factor=True)
44 | gt_local_frames = gt_local_frames.view(b, l_t, c, h // scale_factor, w // scale_factor)
45 | gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(
46 | -1, c, h // scale_factor, w // scale_factor)
47 | gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(
48 | -1, c, h // scale_factor, w // scale_factor)
49 | gt_flows_forward = self.fix_spynet(gtlf_1, gtlf_2)
50 | gt_flows_backward = self.fix_spynet(gtlf_2, gtlf_1)
51 |
52 |             # the GT flow must end up at 1/4 resolution to supervise the flow used for feature warping
53 | if scale_factor != 4:
54 | gt_flows_forward = F.interpolate(gt_flows_forward.view(-1, 2, h // scale_factor, w // scale_factor),
55 | scale_factor=scale_factor / 4,
56 | mode='bilinear',
57 | align_corners=True,
58 | recompute_scale_factor=True).view(b, l_t - 1, 2, h // 4,
59 | w // 4)
60 | gt_flows_backward = F.interpolate(gt_flows_backward.view(-1, 2, h // scale_factor, w // scale_factor),
61 | scale_factor=scale_factor / 4,
62 | mode='bilinear',
63 | align_corners=True,
64 | recompute_scale_factor=True).view(b, l_t - 1, 2, h // 4,
65 | w // 4)
66 |
67 | # calculate loss for flow completion
68 | forward_flow_loss = self.l1_criterion(
69 | pred_flows[0].view(-1, 2, h // 4, w // 4), gt_flows_forward)
70 | backward_flow_loss = self.l1_criterion(
71 | pred_flows[1].view(-1, 2, h // 4, w // 4), gt_flows_backward)
72 | flow_loss = forward_flow_loss + backward_flow_loss
73 |
74 | return flow_loss
75 |
76 |
77 | class SPyNet(nn.Module):
78 | """SPyNet network structure.
79 | The difference to the SPyNet in [tof.py] is that
80 | 1. more SPyNetBasicModule is used in this version, and
81 | 2. no batch normalization is used in this version.
82 | Paper:
83 | Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
84 | Args:
85 | pretrained (str): path for pre-trained SPyNet. Default: None.
86 | """
87 | def __init__(
88 | self,
89 | use_pretrain=True,
90 | pretrained='https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth',
91 | module_level=6
92 | ):
93 | super().__init__()
94 |
95 | self.basic_module = nn.ModuleList(
96 | [SPyNetBasicModule() for _ in range(module_level)])
97 |
98 | if use_pretrain:
99 | if isinstance(pretrained, str):
100 | print("load pretrained SPyNet...")
101 | load_checkpoint(self, pretrained, strict=True)
102 | elif pretrained is not None:
103 | raise TypeError('[pretrained] should be str or None, '
104 | f'but got {type(pretrained)}.')
105 |
106 | self.register_buffer(
107 | 'mean',
108 | torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
109 | self.register_buffer(
110 | 'std',
111 | torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
112 |
113 | def compute_flow(self, ref, supp):
114 | """Compute flow from ref to supp.
115 | Note that in this function, the images are already resized to a
116 | multiple of 32.
117 | Args:
118 | ref (Tensor): Reference image with shape of (n, 3, h, w).
119 | supp (Tensor): Supporting image with shape of (n, 3, h, w).
120 | Returns:
121 | Tensor: Estimated optical flow: (n, 2, h, w).
122 | """
123 | n, _, h, w = ref.size()
124 |
125 | # normalize the input images
126 | ref = [(ref - self.mean) / self.std]
127 | supp = [(supp - self.mean) / self.std]
128 |
129 | # generate downsampled frames
130 | # for level in range(5): # default
131 | for level in range(len(self.basic_module)-1):
132 | ref.append(
133 | F.avg_pool2d(input=ref[-1],
134 | kernel_size=2,
135 | stride=2,
136 | count_include_pad=False))
137 | supp.append(
138 | F.avg_pool2d(input=supp[-1],
139 | kernel_size=2,
140 | stride=2,
141 | count_include_pad=False))
142 | ref = ref[::-1]
143 | supp = supp[::-1]
144 |
145 | # flow computation
146 | # flow = ref[0].new_zeros(n, 2, h // 32, w // 32) # default
147 | reduce = 2**(len(self.basic_module)-1)
148 | flow = ref[0].new_zeros(n, 2, h // reduce, w // reduce)
149 |
150 | # for level in range(len(ref)): # default
151 | for level in range(len(self.basic_module)):
152 | if level == 0:
153 | flow_up = flow
154 | else:
155 | flow_up = F.interpolate(input=flow,
156 | scale_factor=2,
157 | mode='bilinear',
158 | align_corners=True) * 2.0
159 |
160 | # add the residue to the upsampled flow
161 | flow = flow_up + self.basic_module[level](torch.cat([
162 | ref[level],
163 | flow_warp(supp[level],
164 | flow_up.permute(0, 2, 3, 1).contiguous(),
165 | padding_mode='border'), flow_up
166 | ], 1))
167 |
168 | return flow
169 |
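    # Editor's worked example (hypothetical size): with module_level=6 the
    # pyramid has 6 levels, so reduce = 2**5 = 32; a 64x128 input starts from
    # a 2x4 zero flow that is refined by one basic module per level and
    # x2-upsampled between levels, with the flow values doubled at each
    # upsampling so they stay in pixel units at the new resolution.
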
170 | def forward(self, ref, supp):
171 | """Forward function of SPyNet.
172 | This function computes the optical flow from ref to supp.
173 | Args:
174 | ref (Tensor): Reference image with shape of (n, 3, h, w).
175 | supp (Tensor): Supporting image with shape of (n, 3, h, w).
176 | Returns:
177 | Tensor: Estimated optical flow: (n, 2, h, w).
178 | """
179 |
180 | # upsize to a multiple of 32
181 | h, w = ref.shape[2:4]
182 | w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
183 | h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
184 | ref = F.interpolate(input=ref,
185 | size=(h_up, w_up),
186 | mode='bilinear',
187 | align_corners=False)
188 | supp = F.interpolate(input=supp,
189 | size=(h_up, w_up),
190 | mode='bilinear',
191 | align_corners=False)
192 |
193 | # compute flow, and resize back to the original resolution
194 | flow = F.interpolate(input=self.compute_flow(ref, supp),
195 | size=(h, w),
196 | mode='bilinear',
197 | align_corners=False)
198 |
199 | # adjust the flow values
200 | flow[:, 0, :, :] *= float(w) / float(w_up)
201 | flow[:, 1, :, :] *= float(h) / float(h_up)
202 |
203 | return flow
204 |
205 |
206 | class SPyNetBasicModule(nn.Module):
207 | """Basic Module for SPyNet.
208 | Paper:
209 | Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
210 | """
211 | def __init__(self):
212 | super().__init__()
213 |
214 | self.basic_module = nn.Sequential(
215 | ConvModule(in_channels=8,
216 | out_channels=32,
217 | kernel_size=7,
218 | stride=1,
219 | padding=3,
220 | norm_cfg=None,
221 | act_cfg=dict(type='ReLU')),
222 | ConvModule(in_channels=32,
223 | out_channels=64,
224 | kernel_size=7,
225 | stride=1,
226 | padding=3,
227 | norm_cfg=None,
228 | act_cfg=dict(type='ReLU')),
229 | ConvModule(in_channels=64,
230 | out_channels=32,
231 | kernel_size=7,
232 | stride=1,
233 | padding=3,
234 | norm_cfg=None,
235 | act_cfg=dict(type='ReLU')),
236 | ConvModule(in_channels=32,
237 | out_channels=16,
238 | kernel_size=7,
239 | stride=1,
240 | padding=3,
241 | norm_cfg=None,
242 | act_cfg=dict(type='ReLU')),
243 | ConvModule(in_channels=16,
244 | out_channels=2,
245 | kernel_size=7,
246 | stride=1,
247 | padding=3,
248 | norm_cfg=None,
249 | act_cfg=None))
250 |
251 | def forward(self, tensor_input):
252 | """
253 | Args:
254 | tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
255 | 8 channels contain:
256 | [reference image (3), neighbor image (3), initial flow (2)].
257 | Returns:
258 | Tensor: Refined flow with shape (b, 2, h, w)
259 | """
260 | return self.basic_module(tensor_input)
261 |
262 |
263 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
264 | def make_colorwheel():
265 | """
266 | Generates a color wheel for optical flow visualization as presented in:
267 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
268 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
269 |
270 | Code follows the original C++ source code of Daniel Scharstein.
271 |     Code follows the Matlab source code of Deqing Sun.
272 |
273 | Returns:
274 | np.ndarray: Color wheel
275 | """
276 |
277 | RY = 15
278 | YG = 6
279 | GC = 4
280 | CB = 11
281 | BM = 13
282 | MR = 6
283 |
284 | ncols = RY + YG + GC + CB + BM + MR
285 | colorwheel = np.zeros((ncols, 3))
286 | col = 0
287 |
288 | # RY
289 | colorwheel[0:RY, 0] = 255
290 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
291 | col = col + RY
292 | # YG
293 | colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
294 | colorwheel[col:col + YG, 1] = 255
295 | col = col + YG
296 | # GC
297 | colorwheel[col:col + GC, 1] = 255
298 | colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
299 | col = col + GC
300 | # CB
301 | colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
302 | colorwheel[col:col + CB, 2] = 255
303 | col = col + CB
304 | # BM
305 | colorwheel[col:col + BM, 2] = 255
306 | colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
307 | col = col + BM
308 | # MR
309 | colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
310 | colorwheel[col:col + MR, 0] = 255
311 | return colorwheel
312 |
313 |
314 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
315 | """
316 | Applies the flow color wheel to (possibly clipped) flow components u and v.
317 |
318 | According to the C++ source code of Daniel Scharstein
319 | According to the Matlab source code of Deqing Sun
320 |
321 | Args:
322 | u (np.ndarray): Input horizontal flow of shape [H,W]
323 | v (np.ndarray): Input vertical flow of shape [H,W]
324 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
325 |
326 | Returns:
327 | np.ndarray: Flow visualization image of shape [H,W,3]
328 | """
329 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
330 | colorwheel = make_colorwheel() # shape [55x3]
331 | ncols = colorwheel.shape[0]
332 | rad = np.sqrt(np.square(u) + np.square(v))
333 | a = np.arctan2(-v, -u) / np.pi
334 | fk = (a + 1) / 2 * (ncols - 1)
335 | k0 = np.floor(fk).astype(np.int32)
336 | k1 = k0 + 1
337 | k1[k1 == ncols] = 0
338 | f = fk - k0
339 | for i in range(colorwheel.shape[1]):
340 | tmp = colorwheel[:, i]
341 | col0 = tmp[k0] / 255.0
342 | col1 = tmp[k1] / 255.0
343 | col = (1 - f) * col0 + f * col1
344 | idx = (rad <= 1)
345 | col[idx] = 1 - rad[idx] * (1 - col[idx])
346 | col[~idx] = col[~idx] * 0.75 # out of range
347 | # Note the 2-i => BGR instead of RGB
348 | ch_idx = 2 - i if convert_to_bgr else i
349 | flow_image[:, :, ch_idx] = np.floor(255 * col)
350 | return flow_image
351 |
352 |
353 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
354 | """
355 | Expects a two dimensional flow image of shape.
356 |
357 | Args:
358 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
359 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
360 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
361 |
362 | Returns:
363 | np.ndarray: Flow visualization image of shape [H,W,3]
364 | """
365 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
366 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
367 | if clip_flow is not None:
368 | flow_uv = np.clip(flow_uv, 0, clip_flow)
369 | u = flow_uv[:, :, 0]
370 | v = flow_uv[:, :, 1]
371 | rad = np.sqrt(np.square(u) + np.square(v))
372 | rad_max = np.max(rad)
373 | epsilon = 1e-5
374 | u = u / (rad_max + epsilon)
375 | v = v / (rad_max + epsilon)
376 | return flow_uv_to_colors(u, v, convert_to_bgr)
377 |
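# Editor's sketch: a minimal, hypothetical usage of the visualization helpers.
def demo_flow_to_image():
    flow = np.zeros((4, 4, 2), dtype=np.float32)
    flow[..., 0] = 1.0  # uniform rightward motion
    img = flow_to_image(flow)
    assert img.shape == (4, 4, 3) and img.dtype == np.uint8
    return img
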
378 |
379 | def flow_warp(x,
380 | flow,
381 | interpolation='bilinear',
382 | padding_mode='zeros',
383 | align_corners=True):
384 | """Warp an image or a feature map with optical flow.
385 | Args:
386 | x (Tensor): Tensor with size (n, c, h, w).
387 | flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
388 | a two-channel, denoting the width and height relative offsets.
389 | Note that the values are not normalized to [-1, 1].
390 | interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
391 | Default: 'bilinear'.
392 | padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
393 | Default: 'zeros'.
394 | align_corners (bool): Whether align corners. Default: True.
395 | Returns:
396 | Tensor: Warped image or feature map.
397 | """
398 | if x.size()[-2:] != flow.size()[1:3]:
399 | raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
400 | f'flow ({flow.size()[1:3]}) are not the same.')
401 | _, _, h, w = x.size()
402 | # create mesh grid
403 | grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
404 | grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2)
405 | grid.requires_grad = False
406 |
407 | grid_flow = grid + flow
408 | # scale grid_flow to [-1,1]
409 | grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
410 | grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
411 | grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
412 | output = F.grid_sample(x,
413 | grid_flow,
414 | mode=interpolation,
415 | padding_mode=padding_mode,
416 | align_corners=align_corners)
417 | return output
418 |
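# Editor's sketch: a minimal, hypothetical usage of flow_warp. A constant flow
# of +1 pixel along x resamples every pixel from its right-hand neighbour.
def demo_flow_warp():
    x = torch.arange(16.).view(1, 1, 4, 4)
    flow = torch.zeros(1, 4, 4, 2)
    flow[..., 0] = 1.0  # dx = +1: sample one pixel to the right
    warped = flow_warp(x, flow, padding_mode='border')
    assert torch.allclose(warped[..., :3], x[..., 1:])
    return warped
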
419 |
420 | def initial_mask_flow(mask):
421 | """
422 | mask 1 indicates valid pixel 0 indicates unknown pixel
423 | """
424 | B, T, C, H, W = mask.shape
425 |
426 | # calculate relative position
427 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
428 |
429 | grid_y, grid_x = grid_y.type_as(mask), grid_x.type_as(mask)
430 | abs_relative_pos_y = H - torch.abs(grid_y[None, :, :] - grid_y[:, None, :])
431 | relative_pos_y = H - (grid_y[None, :, :] - grid_y[:, None, :])
432 |
433 | abs_relative_pos_x = W - torch.abs(grid_x[:, None, :] - grid_x[:, :, None])
434 | relative_pos_x = W - (grid_x[:, None, :] - grid_x[:, :, None])
435 |
436 | # calculate the nearest indices
437 | pos_up = mask.unsqueeze(3).repeat(
438 | 1, 1, 1, H, 1, 1).flip(4) * abs_relative_pos_y[None, None, None] * (
439 | relative_pos_y <= H)[None, None, None]
440 | nearest_indice_up = pos_up.max(dim=4)[1]
441 |
442 | pos_down = mask.unsqueeze(3).repeat(1, 1, 1, H, 1, 1) * abs_relative_pos_y[
443 | None, None, None] * (relative_pos_y <= H)[None, None, None]
444 | nearest_indice_down = (pos_down).max(dim=4)[1]
445 |
446 | pos_left = mask.unsqueeze(4).repeat(
447 | 1, 1, 1, 1, W, 1).flip(5) * abs_relative_pos_x[None, None, None] * (
448 | relative_pos_x <= W)[None, None, None]
449 | nearest_indice_left = (pos_left).max(dim=5)[1]
450 |
451 | pos_right = mask.unsqueeze(4).repeat(
452 | 1, 1, 1, 1, W, 1) * abs_relative_pos_x[None, None, None] * (
453 | relative_pos_x <= W)[None, None, None]
454 | nearest_indice_right = (pos_right).max(dim=5)[1]
455 |
456 | # NOTE: IMPORTANT !!! depending on how to use this offset
457 | initial_offset_up = -(nearest_indice_up - grid_y[None, None, None]).flip(3)
458 | initial_offset_down = nearest_indice_down - grid_y[None, None, None]
459 |
460 | initial_offset_left = -(nearest_indice_left -
461 | grid_x[None, None, None]).flip(4)
462 | initial_offset_right = nearest_indice_right - grid_x[None, None, None]
463 |
464 | # nearest_indice_x = (mask.unsqueeze(1).repeat(1, img_width, 1) * relative_pos_x).max(dim=2)[1]
465 | # initial_offset_x = nearest_indice_x - grid_x
466 |
467 | # handle the boundary cases
468 | final_offset_down = (initial_offset_down < 0) * initial_offset_up + (
469 | initial_offset_down > 0) * initial_offset_down
470 | final_offset_up = (initial_offset_up > 0) * initial_offset_down + (
471 | initial_offset_up < 0) * initial_offset_up
472 | final_offset_right = (initial_offset_right < 0) * initial_offset_left + (
473 | initial_offset_right > 0) * initial_offset_right
474 | final_offset_left = (initial_offset_left > 0) * initial_offset_right + (
475 | initial_offset_left < 0) * initial_offset_left
476 | zero_offset = torch.zeros_like(final_offset_down)
477 | # out = torch.cat([final_offset_left, zero_offset, final_offset_right, zero_offset, zero_offset, final_offset_up, zero_offset, final_offset_down], dim=2)
478 | out = torch.cat([
479 | zero_offset, final_offset_left, zero_offset, final_offset_right,
480 | final_offset_up, zero_offset, final_offset_down, zero_offset
481 | ],
482 | dim=2)
483 |
484 | return out
485 |
--------------------------------------------------------------------------------
/model/modules/maskflownets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from mmflow.apis import init_model
6 |
7 |
8 | class MaskFlowNetS(nn.Module):
9 | """MaskFlowNetS network structure.
10 | Paper:
11 | MaskFlownet: Asymmetric Feature Matching With Learnable Occlusion Mask, CVPR, 2020
12 | Args:
13 | pretrained (str): path for pre-trained MaskFlowNetS. Default: None.
14 | """
15 | def __init__(
16 | self,
17 | use_pretrain=True,
18 | # pretrained='https://download.openmmlab.com/mmflow/maskflownet/maskflownets_8x1_slong_flyingchairs_384x448.pth',
19 | pretrained='./release_model/maskflownets_8x1_sfine_flyingthings3d_subset_384x768.pth',
20 | config_file='../mmflow/configs/maskflownet/maskflownets_8x1_sfine_flyingthings3d_subset_384x768.py',
21 | device='cuda:0',
22 | module_level=6
23 | ):
24 | super().__init__()
25 |
26 | if use_pretrain:
27 | if isinstance(pretrained, str):
28 | print("load pretrained MaskFlowNetS...")
29 | # self.maskflownetS = init_model(config_file, pretrained, device=device)
30 | self.maskflownetS = init_model(config_file, pretrained, device='cpu')
31 | # load_checkpoint(self, pretrained, strict=True)
32 | elif pretrained is not None:
33 | raise TypeError('[pretrained] should be str or None, '
34 | f'but got {type(pretrained)}.')
35 |
36 | self.register_buffer(
37 | 'mean',
38 | torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
39 | self.register_buffer(
40 | 'std',
41 | torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
42 |
43 | @staticmethod
44 | def centralize(img1, img2):
45 | """Centralize input images.
46 | Args:
47 | img1 (Tensor): The first input image.
48 | img2 (Tensor): The second input image.
49 | Returns:
50 | Tuple[Tensor, Tensor]: The first centralized image and the second
51 | centralized image.
52 | """
53 | rgb_mean = torch.cat((img1, img2), 2)
54 | rgb_mean = rgb_mean.view(rgb_mean.shape[0], 3, -1).mean(2)
55 | rgb_mean = rgb_mean.view(rgb_mean.shape[0], 3, 1, 1)
56 | return img1 - rgb_mean, img2 - rgb_mean, rgb_mean
57 |
58 | def compute_flow(self, ref, supp):
59 | """Compute flow from ref to supp.
60 | Note that in this function, the images are already resized to a
61 | multiple of 64.
62 | Args:
63 | ref (Tensor): Reference image with shape of (n, 3, h, w).
64 | supp (Tensor): Supporting image with shape of (n, 3, h, w).
65 | Returns:
66 | Tensor: Estimated optical flow: (n, 2, h, w).
67 | """
68 | n, _, h, w = ref.size()
69 |
70 | feat1, feat2 = self.maskflownetS.extract_feat(torch.cat((ref, supp), dim=1))
71 | flows_stage1, mask_stage1 = self.maskflownetS.decoder(
72 | feat1, feat2, return_mask=True)
73 |
74 | return flows_stage1
75 |
76 | def forward(self, ref, supp):
77 | """Forward function of MaskFlowNetS.
78 | This function computes the optical flow from ref to supp.
79 | Args:
80 | ref (Tensor): Reference image with shape of (n, 3, h, w).
81 | supp (Tensor): Supporting image with shape of (n, 3, h, w).
82 | Returns:
83 | Tensor: Estimated optical flow: (n, 2, h, w).
84 | """
85 |
86 | # upsize to a multiple of 64
87 | h, w = ref.shape[2:4]
88 | w_up = w if (w % 64) == 0 else 64 * (w // 64 + 1)
89 | h_up = h if (h % 64) == 0 else 64 * (h // 64 + 1)
90 | ref = F.interpolate(input=ref,
91 | size=(h_up, w_up),
92 | mode='bilinear',
93 | align_corners=False)
94 | supp = F.interpolate(input=supp,
95 | size=(h_up, w_up),
96 | mode='bilinear',
97 | align_corners=False)
98 |
99 | # compute flow, and resize back to the original resolution
100 | flow = F.interpolate(input=self.compute_flow(ref, supp)['level2'],
101 | size=(h, w),
102 | mode='bilinear',
103 | align_corners=False)
104 |
105 | # adjust the flow values
106 | flow[:, 0, :, :] *= float(w) / float(w_up)
107 | flow[:, 1, :, :] *= float(h) / float(h_up)
108 |
109 | return flow
110 |
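    # Editor's worked example (hypothetical input size): 240x432 frames are
    # resized to 256x448, the next multiples of 64; the estimated flow is
    # resized back to 240x432 and its u/v components are scaled by 432/448
    # and 240/256 so the vectors stay in the original pixel units.
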
111 |
112 | def test_MFN():
113 | # Specify the path to model config and checkpoint file
114 | config_file = '../../../mmflow/configs/maskflownets_8x1_sfine_flyingthings3d_subset_384x768.py'
115 | checkpoint_file = '../../release_model/maskflownets_8x1_sfine_flyingthings3d_subset_384x768.pth'
116 |
117 | # build the model from a config file and a checkpoint file
118 | model = init_model(config_file, checkpoint_file, device='cuda:0')
119 |     return model
120 |
121 |
122 | if __name__ == "__main__":
123 | test_MFN()
124 |
--------------------------------------------------------------------------------
/model/modules/spectral_norm.py:
--------------------------------------------------------------------------------
1 | """
2 | Spectral Normalization from https://arxiv.org/abs/1802.05957
3 | """
4 | import torch
5 | from torch.nn.functional import normalize
6 |
7 |
8 | class SpectralNorm(object):
9 | # Invariant before and after each forward call:
10 | # u = normalize(W @ v)
11 | # NB: At initialization, this invariant is not enforced
12 |
13 | _version = 1
14 |
15 | # At version 1:
16 | # made `W` not a buffer,
17 | # added `v` as a buffer, and
18 | # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
19 |
20 | def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
21 | self.name = name
22 | self.dim = dim
23 | if n_power_iterations <= 0:
24 | raise ValueError(
25 | 'Expected n_power_iterations to be positive, but '
26 | 'got n_power_iterations={}'.format(n_power_iterations))
27 | self.n_power_iterations = n_power_iterations
28 | self.eps = eps
29 |
30 | def reshape_weight_to_matrix(self, weight):
31 | weight_mat = weight
32 | if self.dim != 0:
33 | # permute dim to front
34 | weight_mat = weight_mat.permute(
35 | self.dim,
36 | *[d for d in range(weight_mat.dim()) if d != self.dim])
37 | height = weight_mat.size(0)
38 | return weight_mat.reshape(height, -1)
39 |
40 | def compute_weight(self, module, do_power_iteration):
41 | # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
42 | # updated in power iteration **in-place**. This is very important
43 | # because in `DataParallel` forward, the vectors (being buffers) are
44 | # broadcast from the parallelized module to each module replica,
45 | # which is a new module object created on the fly. And each replica
46 | # runs its own spectral norm power iteration. So simply assigning
47 | # the updated vectors to the module this function runs on will cause
48 | # the update to be lost forever. And the next time the parallelized
49 | # module is replicated, the same randomly initialized vectors are
50 | # broadcast and used!
51 | #
52 | # Therefore, to make the change propagate back, we rely on two
53 | # important behaviors (also enforced via tests):
54 | # 1. `DataParallel` doesn't clone storage if the broadcast tensor
55 | # is already on correct device; and it makes sure that the
56 | # parallelized module is already on `device[0]`.
57 | # 2. If the out tensor in `out=` kwarg has correct shape, it will
58 | # just fill in the values.
59 | # Therefore, since the same power iteration is performed on all
60 | # devices, simply updating the tensors in-place will make sure that
61 | # the module replica on `device[0]` will update the _u vector on the
62 | # parallized module (by shared storage).
63 | #
64 | # However, after we update `u` and `v` in-place, we need to **clone**
65 | # them before using them to normalize the weight. This is to support
66 | # backproping through two forward passes, e.g., the common pattern in
67 | # GAN training: loss = D(real) - D(fake). Otherwise, engine will
68 | # complain that variables needed to do backward for the first forward
69 | # (i.e., the `u` and `v` vectors) are changed in the second forward.
70 | weight = getattr(module, self.name + '_orig')
71 | u = getattr(module, self.name + '_u')
72 | v = getattr(module, self.name + '_v')
73 | weight_mat = self.reshape_weight_to_matrix(weight)
74 |
75 | if do_power_iteration:
76 | with torch.no_grad():
77 | for _ in range(self.n_power_iterations):
78 | # The spectral norm of the weight equals `u^T W v`, where `u` and `v`
79 | # are the first left and right singular vectors.
80 | # This power iteration produces approximations of `u` and `v`.
81 | v = normalize(torch.mv(weight_mat.t(), u),
82 | dim=0,
83 | eps=self.eps,
84 | out=v)
85 | u = normalize(torch.mv(weight_mat, v),
86 | dim=0,
87 | eps=self.eps,
88 | out=u)
89 | if self.n_power_iterations > 0:
90 | # See above on why we need to clone
91 | u = u.clone()
92 | v = v.clone()
93 |
94 | sigma = torch.dot(u, torch.mv(weight_mat, v))
95 | weight = weight / sigma
96 | return weight
97 |
98 | def remove(self, module):
99 | with torch.no_grad():
100 | weight = self.compute_weight(module, do_power_iteration=False)
101 | delattr(module, self.name)
102 | delattr(module, self.name + '_u')
103 | delattr(module, self.name + '_v')
104 | delattr(module, self.name + '_orig')
105 | module.register_parameter(self.name,
106 | torch.nn.Parameter(weight.detach()))
107 |
108 | def __call__(self, module, inputs):
109 | setattr(
110 | module, self.name,
111 | self.compute_weight(module, do_power_iteration=module.training))
112 |
113 | def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
114 | # Tries to return a vector `v` s.t. `u = normalize(W @ v)`
115 | # (the invariant at the top of this class) and `u @ W @ v = sigma`.
116 | # This uses pinverse in case W^T W is not invertible.
117 | v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(),
118 | weight_mat.t(), u.unsqueeze(1)).squeeze(1)
119 | return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
120 |
121 | @staticmethod
122 | def apply(module, name, n_power_iterations, dim, eps):
123 | for k, hook in module._forward_pre_hooks.items():
124 | if isinstance(hook, SpectralNorm) and hook.name == name:
125 | raise RuntimeError(
126 | "Cannot register two spectral_norm hooks on "
127 | "the same parameter {}".format(name))
128 |
129 | fn = SpectralNorm(name, n_power_iterations, dim, eps)
130 | weight = module._parameters[name]
131 |
132 | with torch.no_grad():
133 | weight_mat = fn.reshape_weight_to_matrix(weight)
134 |
135 | h, w = weight_mat.size()
136 | # randomly initialize `u` and `v`
137 | u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
138 | v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)
139 |
140 | delattr(module, fn.name)
141 | module.register_parameter(fn.name + "_orig", weight)
142 | # We still need to assign weight back as fn.name because all sorts of
143 | # things may assume that it exists, e.g., when initializing weights.
144 | # However, we can't directly assign as it could be an nn.Parameter and
145 | # gets added as a parameter. Instead, we register weight.data as a plain
146 | # attribute.
147 | setattr(module, fn.name, weight.data)
148 | module.register_buffer(fn.name + "_u", u)
149 | module.register_buffer(fn.name + "_v", v)
150 |
151 | module.register_forward_pre_hook(fn)
152 |
153 | module._register_state_dict_hook(SpectralNormStateDictHook(fn))
154 | module._register_load_state_dict_pre_hook(
155 | SpectralNormLoadStateDictPreHook(fn))
156 | return fn
157 |
158 |
159 | # This is a top level class because Py2 pickle doesn't like inner class nor an
160 | # instancemethod.
161 | class SpectralNormLoadStateDictPreHook(object):
162 | # See docstring of SpectralNorm._version on the changes to spectral_norm.
163 | def __init__(self, fn):
164 | self.fn = fn
165 |
166 | # For a state_dict with version None (assuming it has gone through at
167 | # least one training forward), we have
168 | #
169 | # u = normalize(W_orig @ v)
170 | # W = W_orig / sigma, where sigma = u @ W_orig @ v
171 | #
172 | # To compute `v`, we solve `W_orig @ x = u`, and let
173 | # v = x / (u @ W_orig @ x) * (W / W_orig).
174 | def __call__(self, state_dict, prefix, local_metadata, strict,
175 | missing_keys, unexpected_keys, error_msgs):
176 | fn = self.fn
177 | version = local_metadata.get('spectral_norm',
178 | {}).get(fn.name + '.version', None)
179 | if version is None or version < 1:
180 | with torch.no_grad():
181 | weight_orig = state_dict[prefix + fn.name + '_orig']
182 | weight_mat = fn.reshape_weight_to_matrix(weight_orig)
183 | u = state_dict[prefix + fn.name + '_u']
184 | # NOTE: the reference PyTorch hook would pop the stored weight here,
185 | # derive sigma from it, recompute `v` via
186 | # `fn._solve_v_and_rescale(weight_mat, u, sigma)`, and write it back
187 | # as `fn.name + '_v'`; that conversion is disabled in this copy.
188 |
189 |
190 | # This is a top level class because Py2 pickle doesn't like inner class nor an
191 | # instancemethod.
192 | class SpectralNormStateDictHook(object):
193 | # See docstring of SpectralNorm._version on the changes to spectral_norm.
194 | def __init__(self, fn):
195 | self.fn = fn
196 |
197 | def __call__(self, module, state_dict, prefix, local_metadata):
198 | if 'spectral_norm' not in local_metadata:
199 | local_metadata['spectral_norm'] = {}
200 | key = self.fn.name + '.version'
201 | if key in local_metadata['spectral_norm']:
202 | raise RuntimeError(
203 | "Unexpected key in metadata['spectral_norm']: {}".format(key))
204 | local_metadata['spectral_norm'][key] = self.fn._version
205 |
206 |
207 | def spectral_norm(module,
208 | name='weight',
209 | n_power_iterations=1,
210 | eps=1e-12,
211 | dim=None):
212 | r"""Applies spectral normalization to a parameter in the given module.
213 |
214 | .. math::
215 | \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
216 | \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
217 |
218 | Spectral normalization stabilizes the training of discriminators (critics)
219 | in Generative Adversarial Networks (GANs) by rescaling the weight tensor
220 | by the spectral norm :math:`\sigma` of the weight matrix, computed with
221 | the power iteration method. If the weight tensor has more than two
222 | dimensions, it is reshaped to 2D for the power iteration. This is
223 | implemented via a hook that computes the spectral norm and rescales the
224 | weight before every :meth:`~Module.forward` call.
225 |
226 | See `Spectral Normalization for Generative Adversarial Networks`_ .
227 |
228 | .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
229 |
230 | Args:
231 | module (nn.Module): containing module
232 | name (str, optional): name of weight parameter
233 | n_power_iterations (int, optional): number of power iterations to
234 | calculate spectral norm
235 | eps (float, optional): epsilon for numerical stability in
236 | calculating norms
237 | dim (int, optional): dimension corresponding to number of outputs,
238 | the default is ``0``, except for modules that are instances of
239 | ConvTranspose{1,2,3}d, when it is ``1``
240 |
241 | Returns:
242 | The original module with the spectral norm hook
243 |
244 | Example::
245 |
246 | >>> m = spectral_norm(nn.Linear(20, 40))
247 | >>> m
248 | Linear(in_features=20, out_features=40, bias=True)
249 | >>> m.weight_u.size()
250 | torch.Size([40])
251 |
252 | """
253 | if dim is None:
254 | if isinstance(module,
255 | (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
256 | torch.nn.ConvTranspose3d)):
257 | dim = 1
258 | else:
259 | dim = 0
260 | SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
261 | return module
262 |
263 |
264 | def remove_spectral_norm(module, name='weight'):
265 | r"""Removes the spectral normalization reparameterization from a module.
266 |
267 | Args:
268 | module (Module): containing module
269 | name (str, optional): name of weight parameter
270 |
271 | Example:
272 | >>> m = spectral_norm(nn.Linear(40, 10))
273 | >>> remove_spectral_norm(m)
274 | """
275 | for k, hook in module._forward_pre_hooks.items():
276 | if isinstance(hook, SpectralNorm) and hook.name == name:
277 | hook.remove(module)
278 | del module._forward_pre_hooks[k]
279 | return module
280 |
281 | raise ValueError("spectral_norm of '{}' not found in {}".format(
282 | name, module))
283 |
284 |
285 | def use_spectral_norm(module, use_sn=False):
286 | if use_sn:
287 | return spectral_norm(module)
288 | return module
--------------------------------------------------------------------------------
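The hook above runs a single power-iteration step per forward pass; because `u` and `v` persist as module buffers, the estimate keeps refining over the course of training. A standalone sketch of the same estimator, detached from the module machinery and checked against an exact SVD (assumes a recent PyTorch with `torch.linalg.svdvals`; the weight matrix is a random stand-in):

```python
# Power iteration for the largest singular value, as in compute_weight.
import torch
from torch.nn.functional import normalize

torch.manual_seed(0)
W = torch.randn(40, 20)              # stand-in for a reshaped weight
u = normalize(torch.randn(40), dim=0)
v = normalize(torch.randn(20), dim=0)

for _ in range(50):                  # the hook does one step per forward
    v = normalize(W.t() @ u, dim=0)  # approx. first right singular vector
    u = normalize(W @ v, dim=0)      # approx. first left singular vector

sigma = torch.dot(u, W @ v)          # approx. spectral norm of W
print(sigma.item(), torch.linalg.svdvals(W)[0].item())  # should agree
```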
/release_model/README.md:
--------------------------------------------------------------------------------
1 | Place the downloaded models here.
2 |
3 | [comment]: <> (:link: **Download Links:** [[Google Drive](https://drive.google.com/file/d/1tNJMTJ2gmWdIXJoHVi5-H504uImUiJW9/view?usp=sharing)] [[Baidu Disk](https://pan.baidu.com/s/1qXAErbilY_n_Fh9KB8UF7w?pwd=lsjw)])
4 |
5 | The directory structure should be arranged as:
6 | ```
7 | release_model
8 | |- FlowLens.pth
9 | |- i3d_rgb_imagenet.pt (for evaluating VFID metric)
10 | |- maskflownets_8x1_sfine_flyingthings3d_subset_384x768.pth
11 | |- README.md
12 | ```
13 |
--------------------------------------------------------------------------------
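A small sanity check for the layout above; run it from the repository root (the file list mirrors the tree shown in this README):

```python
# Verify the expected checkpoints are in place before training/evaluation.
import os

expected = [
    'release_model/FlowLens.pth',
    'release_model/i3d_rgb_imagenet.pt',
    'release_model/maskflownets_8x1_sfine_flyingthings3d_subset_384x768.pth',
]
for path in expected:
    status = 'OK' if os.path.isfile(path) else 'MISSING'
    print('{:8s} {}'.format(status, path))
```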
/release_model/maskflownets_8x1_sfine_flyingthings3d_subset_384x768.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '/mnt/WORKSPACE/mmflow/configs/_base_/models/maskflownets.py',
3 | '/mnt/WORKSPACE/mmflow/configs/_base_/datasets/flyingthings3d_subset_384x768.py',
4 | '/mnt/WORKSPACE/mmflow/configs/_base_/schedules/schedule_s_fine.py', '/mnt/WORKSPACE/mmflow/configs/_base_/default_runtime.py'
5 | ]
6 |
7 | optimizer = dict(type='Adam', lr=0.0001, weight_decay=0., betas=(0.9, 0.999))
8 |
9 | # Train on FlyingChairs and finetune on FlyingThings3D_subset
10 | load_from = 'https://download.openmmlab.com/mmflow/maskflownet/maskflownets_8x1_slong_flyingchairs_384x448.pth' # noqa
--------------------------------------------------------------------------------
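Since this config only overrides the optimizer and `load_from`, everything else is inherited from the `_base_` files. A quick way to inspect the fully merged result, assuming `mmcv` is installed and the absolute `_base_` paths resolve on your machine:

```python
# Print the merged config (model + data + schedule + runtime).
from mmcv import Config

cfg = Config.fromfile(
    'release_model/maskflownets_8x1_sfine_flyingthings3d_subset_384x768.py')
print(cfg.pretty_text)
```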
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from shutil import copyfile
5 |
6 | import torch
7 | import torch.multiprocessing as mp
8 |
9 | from core.trainer import Trainer
10 | from core.dist import (
11 | get_world_size,
12 | get_local_rank,
13 | get_global_rank,
14 | get_master_ip,
15 | )
16 |
17 | parser = argparse.ArgumentParser(description='FlowLens')
18 | parser.add_argument('-c',
19 | '--config',
20 | required=True,
21 | type=str)
22 | parser.add_argument('-p', '--port', default='23455', type=str)
23 | args = parser.parse_args()
24 |
25 |
26 | def main_worker(rank, config):
27 | if 'local_rank' not in config:
28 | config['local_rank'] = config['global_rank'] = rank
29 | if config['distributed']:
30 | torch.cuda.set_device(int(config['local_rank']))
31 | torch.distributed.init_process_group(backend='nccl', # NCCL for Linux DDP, Gloo for Windows
32 | init_method=config['init_method'],
33 | world_size=config['world_size'],
34 | rank=config['global_rank'],
35 | group_name='mtorch')
36 | print('using GPU {}-{} for training'.format(int(config['global_rank']),
37 | int(config['local_rank'])))
38 |
39 | config['save_dir'] = os.path.join(
40 | config['save_dir'],
41 | '{}_{}'.format(config['model']['net'],
42 | os.path.basename(args.config).split('.')[0]))
43 |
44 | config['save_metric_dir'] = os.path.join(
45 | './scores',
46 | '{}_{}'.format(config['model']['net'],
47 | os.path.basename(args.config).split('.')[0]))
48 |
49 | if torch.cuda.is_available():
50 | config['device'] = torch.device("cuda:{}".format(config['local_rank']))
51 | else:
52 | config['device'] = 'cpu'
53 |
54 | if (not config['distributed']) or config['global_rank'] == 0:
55 | os.makedirs(config['save_dir'], exist_ok=True)
56 | os.makedirs(config['save_metric_dir'], exist_ok=True)
57 | config_path = os.path.join(config['save_dir'],
58 | os.path.basename(args.config))
59 | if not os.path.isfile(config_path):
60 | copyfile(args.config, config_path)
61 | print('[**] create folder {}'.format(config['save_dir']))
62 |
63 | trainer = Trainer(config)
64 | trainer.train()
65 |
66 |
67 | if __name__ == "__main__":
68 |
69 | torch.backends.cudnn.benchmark = True
70 |
71 | mp.set_sharing_strategy('file_system')
72 |
73 | # loading configs
74 | config = json.load(open(args.config))
75 |
76 | # setting distributed configurations
77 | config['world_size'] = get_world_size()
78 | config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"
79 | config['distributed'] = config['world_size'] > 1
80 | print('world_size: {}'.format(config['world_size']))
81 | # setup distributed parallel training environments
82 | if get_master_ip() == "127.0.0.1":
83 | # manually launch distributed processes
84 | mp.spawn(main_worker, nprocs=config['world_size'], args=(config, ))
85 | else:
86 | # multiple processes have been launched by openmpi
87 | config['local_rank'] = get_local_rank()
88 | config['global_rank'] = get_global_rank()
89 | main_worker(-1, config)
90 |
--------------------------------------------------------------------------------
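On a single machine, `get_master_ip()` resolves to `127.0.0.1`, so `mp.spawn` launches `world_size` workers (presumably one per visible GPU) and each joins a TCP process group. Stripped of the project specifics, the bootstrap pattern looks like the following minimal sketch; it is not the project's exact code, and the rendezvous port matches the script's default:

```python
# Minimal single-node DDP bootstrap, mirroring the spawn branch above.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    # Each spawned process joins the group under its own rank.
    dist.init_process_group(backend='nccl',
                            init_method='tcp://127.0.0.1:23455',
                            world_size=world_size,
                            rank=rank)
    torch.cuda.set_device(rank)
    # ... build the model, wrap it in DistributedDataParallel, train ...
    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(worker, nprocs=world_size, args=(world_size,))
```

Training itself is then launched as `python train.py -c configs/KITTI360EX-I_FlowLens.json`, with `-p` to pick a different rendezvous port if 23455 is taken.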