├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── beauty_0
│   │   └── base.yaml
│   ├── beauty_1
│   │   └── base.yaml
│   ├── hash.json
│   ├── lemon_hit
│   │   └── base.yaml
│   ├── scene_0
│   │   └── base.yaml
│   └── white_smoke
│       └── base.yaml
├── data_preprocessing
│   ├── RAFT
│   │   ├── core
│   │   │   ├── __init__.py
│   │   │   ├── corr.py
│   │   │   ├── datasets.py
│   │   │   ├── extractor.py
│   │   │   ├── raft.py
│   │   │   ├── update.py
│   │   │   └── utils
│   │   │       ├── __init__.py
│   │   │       ├── augmentor.py
│   │   │       ├── flow_viz.py
│   │   │       ├── frame_utils.py
│   │   │       └── utils.py
│   │   ├── demo.py
│   │   ├── models
│   │   │   └── download_models.sh
│   │   └── run_raft.sh
│   └── preproc_mask.py
├── datasets
│   ├── __init__.py
│   ├── distributed_weighted_sampler.py
│   └── video_dataset.py
├── docs
│   ├── index.html
│   ├── static
│   │   ├── css
│   │   │   ├── bulma-carousel.min.css
│   │   │   ├── bulma-slider.min.css
│   │   │   ├── bulma.css.map.txt
│   │   │   ├── bulma.min.css
│   │   │   ├── fontawesome.all.min.css
│   │   │   └── index.css
│   │   ├── images
│   │   │   └── framework.png
│   │   ├── js
│   │   │   ├── bulma-carousel.js
│   │   │   ├── bulma-carousel.min.js
│   │   │   ├── bulma-slider.js
│   │   │   ├── bulma-slider.min.js
│   │   │   ├── fontawesome.all.min.js
│   │   │   ├── index.js
│   │   │   └── video_comparison.js
│   │   └── video_demos_compressed
│   │       ├── 1080x540
│   │       │   ├── beauty_0_base_video_hash_transformed_dual(1).mp4
│   │       │   ├── beauty_0_base_video_hash_transformed_dual(2).mp4
│   │       │   ├── beauty_1_base_transformed_dual(1).mp4
│   │       │   ├── beauty_1_base_transformed_dual(2).mp4
│   │       │   ├── rainbow.mp4
│   │       │   └── tifa.mp4
│   │       ├── 1080x960
│   │       │   ├── dog.mp4
│   │       │   ├── scene_1_base_transformed_dual.mp4
│   │       │   ├── smoke_colorful.mp4
│   │       │   └── smoke_ink.mp4
│   │       ├── 1920x540
│   │       │   ├── cloud_atlas4_base_transformed_dual.mp4
│   │       │   ├── diamond_1_base_transformed_dual.mp4
│   │       │   ├── lemon_earth.mp4
│   │       │   ├── long_season.mp4
│   │       │   ├── scene_3_base_transformed_dual.mp4
│   │       │   └── titanic.mp4
│   │       ├── slider
│   │       │   └── slider.mp4
│   │       └── teaser
│   │           ├── cloud_atlas_SR_dual.mp4
│   │           ├── segtrack.mp4
│   │           ├── teaser.mp4
│   │           └── tracking_zoom.mp4
│   └── teaser.gif
├── losses.py
├── metrics.py
├── models
│   └── implicit_model.py
├── opt.py
├── requirements.txt
├── scripts
│   ├── test_canonical.sh
│   ├── test_multi.sh
│   └── train_multi.sh
├── train.py
└── utils
    ├── __init__.py
    ├── image_utils.py
    ├── video_visualizer.py
    └── warmup_scheduler.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore compiled files.
2 | __pycache__/
3 | *.py[cod]
4 | tmp_build/
5 | *.pyc
6 |
7 | # Ignore files created by IDEs.
8 | /.vscode/
9 | /.idea/
10 | .ipynb_*/
11 | *.ipynb
12 | .DS_Store
13 | *.sw[pon]
14 |
15 | # Ignore data files.
16 | data/
17 | *.npy
18 | *.tar
19 | *.zip
20 | *.mdb
21 | *.ckpt
22 |
23 | # Ignore network files.
24 | ckpts/
25 | *.pth
26 | *.pt
27 | *.pkl
28 | *.h5
29 | *.dat
30 | *.ckpt
31 |
32 | # Ignore log files.
33 | results/
34 | resources/
35 | events/
36 | profile/
37 | logs/
38 | # *.json
39 | *.log
40 | events.*
41 |
42 | # Files that should not be ignored.
43 | !/requirements/*
44 |
45 | # Others should be ignored.
46 | all_sequences/
47 | ckpts/
48 | logs/
49 | # configs/
50 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | ------------------------------ LICENSE for CoDeF ------------------------------
2 |
3 | Copyright (c) 2023 Ant Group.
4 |
5 | MIT License
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CoDeF: Content Deformation Fields for Temporally Consistent Video Processing
2 |
3 |
4 |
5 | [Hao Ouyang](https://ken-ouyang.github.io/)\*, [Qiuyu Wang](https://github.com/qiuyu96/)\*, [Yuxi Xiao](https://henry123-boy.github.io/)\*, [Qingyan Bai](https://scholar.google.com/citations?user=xUMjxi4AAAAJ&hl=en), [Juntao Zhang](https://github.com/JordanZh), [Kecheng Zheng](https://scholar.google.com/citations?user=hMDQifQAAAAJ), [Xiaowei Zhou](https://xzhou.me/),
6 | [Qifeng Chen](https://cqf.io/)†, [Yujun Shen](https://shenyujun.github.io/)† (*equal contribution, †corresponding author)
7 |
8 | **CVPR 2024 Highlight**
9 |
10 | #### [Project Page](https://qiuyu96.github.io/CoDeF/) | [Paper](https://arxiv.org/abs/2308.07926) | [High-Res Translation Demo](https://ezioby.github.io/CoDeF_Demo/) | [Colab](https://colab.research.google.com/github/camenduru/CoDeF-colab/blob/main/CoDeF_colab.ipynb)
11 |
12 |
13 |
14 | ## Requirements
15 |
16 | The codebase is tested on
17 |
18 | * Ubuntu 20.04
19 | * Python 3.10
20 | * [PyTorch](https://pytorch.org/) 2.0.0
21 | * [PyTorch Lightning](https://www.pytorchlightning.ai/index.html) 2.0.2
22 | * 1 NVIDIA GPU (RTX A6000) with CUDA version 11.7. (Other GPUs are also suitable, and 10GB GPU memory is sufficient to run our code.)
23 |
24 | To use video visualizer, please install `ffmpeg` via
25 |
26 | ```shell
27 | sudo apt-get install ffmpeg
28 | ```
29 |
30 | For additional Python libraries, please install with
31 |
32 | ```shell
33 | pip install -r requirements.txt
34 | ```
35 |
36 | Our code also depends on [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn).
37 | See [this repository](https://github.com/NVlabs/tiny-cuda-nn#pytorch-extension)
38 | for PyTorch extension installation instructions.
39 |
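
A quick way to confirm the extension built correctly is to instantiate a small hash-grid encoding. This is a minimal sketch (not part of the training code) and assumes a CUDA-capable GPU; the shapes are illustrative only:

```python
# Sanity check for the tiny-cuda-nn PyTorch extension (illustrative only).
import torch
import tinycudann as tcnn

encoding = tcnn.Encoding(
    n_input_dims=2,
    encoding_config={"otype": "HashGrid"},  # default hash-grid hyperparameters
)
x = torch.rand(8, 2, device="cuda")          # tiny-cuda-nn expects CUDA tensors
print(encoding(x).shape)                     # (8, n_levels * n_features_per_level); (8, 32) with defaults
```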
40 | ## Data
41 |
42 | ### Provided data
43 |
44 | We provide some videos [here](https://drive.google.com/file/d/1cKZF6ILeokCjsSAGBmummcQh0uRGaC_F/view?usp=sharing) for a quick test. Please download and unzip the data and put it in the root directory. More videos can be downloaded [here](https://drive.google.com/file/d/10Msz37MpjZQFPXlDWCZqrcQjhxpQSvCI/view?usp=sharing).
45 |
46 | ### Customize your own data
47 |
48 | We segment video sequences using [SAM-Track](https://github.com/z-x-yang/Segment-and-Track-Anything). Once you obtain the mask files, place them in the folder `all_sequences/{YOUR_SEQUENCE_NAME}/{YOUR_SEQUENCE_NAME}_masks`. Next, execute the following command:
49 |
50 | ```shell
51 | cd data_preprocessing
52 | python preproc_mask.py
53 | ```
54 |
55 | We extract optical flows of video sequences using [RAFT](https://github.com/princeton-vl/RAFT). To get started, please follow the instructions provided [here](https://github.com/princeton-vl/RAFT#demos) to download their pretrained model. Once downloaded, place the model in the `data_preprocessing/RAFT/models` folder. After that, you can execute the following command:
56 |
57 | ```shell
58 | cd data_preprocessing/RAFT
59 | ./run_raft.sh
60 | ```
61 |
62 | Remember to update the sequence name and root directory in both `data_preprocessing/preproc_mask.py` and `data_preprocessing/RAFT/run_raft.sh` accordingly.
63 |
64 | After obtaining the files, please organize your own data as follows:
65 |
66 | ```
67 | CoDeF
68 | │
69 | └─── all_sequences
70 | │
71 | └─── NAME1
72 | └─ NAME1
73 | └─ NAME1_masks_0 (optional)
74 | └─ NAME1_masks_1 (optional)
75 | └─ NAME1_flow (optional)
76 | └─ NAME1_flow_confidence (optional)
77 | │
78 | └─── NAME2
79 | └─ NAME2
80 | └─ NAME2_masks_0 (optional)
81 | └─ NAME2_masks_1 (optional)
82 | └─ NAME2_flow (optional)
83 | └─ NAME2_flow_confidence (optional)
84 | │
85 | └─── ...
86 | ```
87 |
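As a sanity check before training, you can verify which of the optional folders exist for a given sequence. The snippet below is a small sketch that only mirrors the layout above; the helper name is ours:

```python
# Reports which of the optional per-sequence folders are present (illustrative helper).
from pathlib import Path

def check_sequence(root: str, name: str) -> dict:
    base = Path(root) / "all_sequences" / name
    folders = {
        "frames": base / name,
        "masks_0": base / f"{name}_masks_0",
        "masks_1": base / f"{name}_masks_1",
        "flow": base / f"{name}_flow",
        "flow_confidence": base / f"{name}_flow_confidence",
    }
    return {key: path.is_dir() for key, path in folders.items()}

print(check_sequence(".", "NAME1"))  # e.g. {'frames': True, 'masks_0': False, ...}
```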
88 | ## Pretrained checkpoints
89 |
90 | You can download checkpoints pre-trained on the provided videos via
91 |
92 | | Sequence Name | Config | Download | OpenXLab |
93 | | :-------- | :----: | :----------------------------------------------------------: | :---------:|
94 | | beauty_0 | configs/beauty_0/base.yaml | [Google drive link](https://drive.google.com/file/d/11SWfnfDct8bE16802PyqYJqsU4x6ACn8/view?usp=sharing) |[](https://openxlab.org.cn/models/detail/HaoOuyang/CoDeF)|
95 | | beauty_1 | configs/beauty_1/base.yaml | [Google drive link](https://drive.google.com/file/d/1bSK0ChbPdURWGLdtc9CPLkN4Tfnng51k/view?usp=sharing) | [](https://openxlab.org.cn/models/detail/HaoOuyang/CoDeF) |
96 | | white_smoke | configs/white_smoke/base.yaml | [Google drive link](https://drive.google.com/file/d/1QOBCDGV2hHwxq4eL1E_45z5zhZ-wTJR7/view?usp=sharing) | [](https://openxlab.org.cn/models/detail/HaoOuyang/CoDeF) |
97 | | lemon_hit | configs/lemon_hit/base.yaml | [Google drive link](https://drive.google.com/file/d/140ctcLbv7JTIiy53MuCYtI4_zpIvRXzq/view?usp=sharing) | [](https://openxlab.org.cn/models/detail/HaoOuyang/CoDeF)|
98 | | scene_0 | configs/scene_0/base.yaml | [Google drive link](https://drive.google.com/file/d/1abOdREarfw1DGscahOJd2gZf1Xn_zN-F/view?usp=sharing) |[](https://openxlab.org.cn/models/detail/HaoOuyang/CoDeF)|
99 |
100 | Then organize the files as follows:
101 |
102 | ```
103 | CoDeF
104 | │
105 | └─── ckpts/all_sequences
106 | │
107 | └─── NAME1
108 | │
109 | └─── EXP_NAME (base)
110 | │
111 | └─── NAME1.ckpt
112 | │
113 | └─── NAME2
114 | │
115 | └─── EXP_NAME (base)
116 | │
117 | └─── NAME2.ckpt
118 | |
119 | └─── ...
120 | ```
121 |
122 | ## Train a new model
123 |
124 | ```shell
125 | ./scripts/train_multi.sh
126 | ```
127 |
128 | where
129 | * `GPU`: Which GPU to train on;
130 | * `NAME`: Name of the video sequence;
131 | * `EXP_NAME`: Name of the experiment;
132 | * `ROOT_DIRECTORY`: Directory of the input video sequence;
133 | * `MODEL_SAVE_PATH`: Path to save the checkpoints;
134 | * `LOG_SAVE_PATH`: Path to save the logs;
135 | * `MASK_DIRECTORY`: Directory of the preprocessed masks (optional);
136 | * `FLOW_DIRECTORY`: Directory of the preprocessed optical flows (optional);
137 |
138 | Please check the configuration files in `configs/`; you can always add your own model config.
139 |
140 | ## Test reconstruction
141 |
142 | ```shell
143 | ./scripts/test_multi.sh
144 | ```
145 | After running the script, the reconstructed videos can be found in `results/all_sequences/{NAME}/{EXP_NAME}`, along with the canonical image.
146 |
147 | ## Test video translation
148 |
149 | After obtaining the canonical image through [this step](#anchor), use your preferred text prompts to transfer it using [ControlNet](https://github.com/lllyasviel/ControlNet).
150 | Once you have the transferred canonical image, place it in `all_sequences/${NAME}/${EXP_NAME}_control` (i.e. `CANONICAL_DIR` in `scripts/test_canonical.sh`).
151 |
152 | Then run
153 |
154 | ```shell
155 | ./scripts/test_canonical.sh
156 | ```
157 |
158 | The transferred results can be seen in `results/all_sequences/{NAME}/{EXP_NAME}_transformed`.
159 |
160 | *Note*: The `canonical_wh` option in the configuration file should be set with caution, usually a little larger than `img_wh`, as it determines the field of view of the canonical image.
161 |
162 | ### BibTeX
163 |
164 | ```bibtex
165 | @article{ouyang2023codef,
166 | title={CoDeF: Content Deformation Fields for Temporally Consistent Video Processing},
167 | author={Hao Ouyang and Qiuyu Wang and Yuxi Xiao and Qingyan Bai and Juntao Zhang and Kecheng Zheng and Xiaowei Zhou and Qifeng Chen and Yujun Shen},
168 | journal={arXiv preprint arXiv:2308.07926},
169 | year={2023}
170 | }
171 | ```
172 |
173 | ### Acknowledgements
174 | We thank [camenduru](https://github.com/camenduru) for providing the [Colab demo](https://github.com/camenduru/CoDeF-colab).
175 |
--------------------------------------------------------------------------------
/configs/beauty_0/base.yaml:
--------------------------------------------------------------------------------
1 | mask_dir: null
2 |
3 | img_wh: [540, 540]
4 | canonical_wh: [640, 640]
5 |
6 | lr: 0.001
7 | bg_loss: 0.003
8 |
9 | ref_idx: null # 0
10 |
11 | N_xyz_w: [8,]
12 | flow_loss: 1
13 | flow_step: -1
14 | self_bg: True
15 |
16 | deform_hash: True
17 | vid_hash: True
18 |
19 | num_steps: 10000
20 | decay_step: [2500, 5000, 7500]
21 | annealed_begin_step: 4000
22 | annealed_step: 4000
23 | save_model_iters: 2000
24 |
25 | fps: 15
26 |
--------------------------------------------------------------------------------
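For reference, here is a minimal sketch of how `lr`, `decay_step`, and `num_steps` in a config like the one above could drive a standard step-decay schedule. The optimizer choice and decay factor (`gamma`) are assumptions made for illustration; the actual warmup/decay logic used for training lives in `utils/warmup_scheduler.py` and `train.py`.

```python
# Illustrative only: mapping the YAML hyperparameters onto a plain PyTorch schedule.
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-3)             # lr: 0.001
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[2500, 5000, 7500],                        # decay_step
    gamma=0.5,                                            # assumed decay factor
)

for step in range(10_000):                                # num_steps: 10000
    optimizer.step()
    scheduler.step()
```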
/configs/beauty_1/base.yaml:
--------------------------------------------------------------------------------
1 | img_wh: [540, 540]
2 | canonical_wh: [540, 540]
3 |
4 | lr: 0.001
5 | bg_loss: 0.003
6 |
7 | ref_idx: null # 0
8 |
9 | N_xyz_w: [8, 8]
10 | flow_loss: 1
11 | flow_step: -1
12 | self_bg: True
13 |
14 | deform_hash: True
15 | vid_hash: True
16 |
17 | num_steps: 10000
18 | decay_step: [2500, 5000, 7500]
19 | annealed_begin_step: 4000
20 | annealed_step: 4000
21 | save_model_iters: 2000
22 |
23 | fps: 15
24 |
--------------------------------------------------------------------------------
/configs/hash.json:
--------------------------------------------------------------------------------
1 | {
2 | "encoding_deform3d": {
3 | "otype": "HashGrid",
4 | "n_levels": 16,
5 | "n_features_per_level": 2,
6 | "log2_hashmap_size": 19,
7 | "base_resolution": 16,
8 | "per_level_scale": 1.38
9 | },
10 | "encoding": {
11 | "otype": "HashGrid",
12 | "n_levels": 16,
13 | "n_features_per_level": 2,
14 | "log2_hashmap_size": 19,
15 | "base_resolution": 16,
16 | "per_level_scale": 1.44
17 | },
18 | "encoding_deform2d": {
19 | "otype": "HashGrid",
20 | "n_levels": 16,
21 | "n_features_per_level": 2,
22 | "log2_hashmap_size": 19,
23 | "base_resolution": 16,
24 | "per_level_scale": 1.44
25 | },
26 | "network": {
27 | "otype": "FullyFusedMLP",
28 | "activation": "ReLU",
29 | "output_activation": "None",
30 | "n_neurons": 64,
31 | "n_hidden_layers": 2
32 | },
33 | "network_deform": {
34 | "otype": "FullyFusedMLP",
35 | "activation": "ReLU",
36 | "output_activation": "None",
37 | "n_neurons": 64,
38 | "n_hidden_layers": 8
39 | },
40 | "network_occ": {
41 | "otype": "FullyFusedMLP",
42 | "activation": "ReLU",
43 | "output_activation": "Sigmoid",
44 | "n_neurons": 64,
45 | "n_hidden_layers": 2
46 | },
47 | "time_code": 64
48 | }
49 |
--------------------------------------------------------------------------------
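As a rough illustration of how the blocks in `configs/hash.json` map onto tiny-cuda-nn modules, the sketch below instantiates one encoding/network pair. The input and output dimensionalities are our own assumptions, and this is not the repository's model code (the actual definitions live elsewhere, e.g. `models/implicit_model.py`):

```python
# Builds a hash-grid-encoded MLP from the "encoding_deform3d" / "network_deform"
# entries of configs/hash.json (illustrative shapes only).
import json
import torch
import tinycudann as tcnn

with open("configs/hash.json") as f:
    cfg = json.load(f)

deform_field = tcnn.NetworkWithInputEncoding(
    n_input_dims=3,                          # e.g. (x, y, t); assumed
    n_output_dims=2,                         # e.g. a 2D offset; assumed
    encoding_config=cfg["encoding_deform3d"],
    network_config=cfg["network_deform"],
)

xyz = torch.rand(1024, 3, device="cuda")     # tiny-cuda-nn expects CUDA tensors
print(deform_field(xyz).shape)               # torch.Size([1024, 2])
```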
/configs/lemon_hit/base.yaml:
--------------------------------------------------------------------------------
1 | mask_dir: null
2 | flow_dir: null
3 |
4 | img_wh: [960, 540]
5 | canonical_wh: [1280, 640]
6 |
7 | lr: 0.001
8 | bg_loss: 0.003
9 |
10 | ref_idx: null # 0
11 |
12 | N_xyz_w: [8,]
13 | flow_loss: 0
14 | flow_step: -1
15 | self_bg: True
16 |
17 | deform_hash: True
18 | vid_hash: True
19 |
20 | num_steps: 10000
21 | decay_step: [2500, 5000, 7500]
22 | annealed_begin_step: 4000
23 | annealed_step: 4000
24 | save_model_iters: 2000
25 |
--------------------------------------------------------------------------------
/configs/scene_0/base.yaml:
--------------------------------------------------------------------------------
1 | mask_dir: null
2 |
3 | img_wh: [540, 960]
4 | canonical_wh: [720, 1280]
5 |
6 | lr: 0.001
7 | bg_loss: 0.003
8 |
9 | ref_idx: null # 0
10 |
11 | N_xyz_w: [8,]
12 | flow_loss: 1
13 | flow_step: -1
14 | self_bg: True
15 |
16 | deform_hash: True
17 | vid_hash: True
18 |
19 | num_steps: 10000
20 | decay_step: [2500, 5000, 7500]
21 | annealed_begin_step: 4000
22 | annealed_step: 4000
23 | save_model_iters: 2000
24 |
--------------------------------------------------------------------------------
/configs/white_smoke/base.yaml:
--------------------------------------------------------------------------------
1 | mask_dir: null
2 | flow_dir: null
3 |
4 | img_wh: [960, 540]
5 | canonical_wh: [1280, 640]
6 |
7 | lr: 0.001
8 | bg_loss: 0.003
9 |
10 | ref_idx: null # 0
11 |
12 | N_xyz_w: [8,]
13 | flow_loss: 0
14 | flow_step: -1
15 | self_bg: True
16 |
17 | deform_hash: True
18 | vid_hash: True
19 |
20 | num_steps: 10000
21 | decay_step: [2500, 5000, 7500]
22 | annealed_begin_step: 4000
23 | annealed_step: 4000
24 | save_model_iters: 2000
25 |
--------------------------------------------------------------------------------
/data_preprocessing/RAFT/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/data_preprocessing/RAFT/core/__init__.py
--------------------------------------------------------------------------------
/data_preprocessing/RAFT/core/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class AlternateCorrBlock:
64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
65 | self.num_levels = num_levels
66 | self.radius = radius
67 |
68 | self.pyramid = [(fmap1, fmap2)]
69 | for i in range(self.num_levels):
70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
72 | self.pyramid.append((fmap1, fmap2))
73 |
74 | def __call__(self, coords):
75 | coords = coords.permute(0, 2, 3, 1)
76 | B, H, W, _ = coords.shape
77 | dim = self.pyramid[0][0].shape[1]
78 |
79 | corr_list = []
80 | for i in range(self.num_levels):
81 | r = self.radius
82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
84 |
85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
87 | corr_list.append(corr.squeeze(1))
88 |
89 | corr = torch.stack(corr_list, dim=1)
90 | corr = corr.reshape(B, -1, H, W)
91 | return corr / torch.sqrt(torch.tensor(dim).float())
92 |
--------------------------------------------------------------------------------
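For orientation, a toy invocation of the `CorrBlock` defined above: it builds the 4-level all-pairs correlation pyramid and indexes it once. The shapes are ours (chosen at 1/8 resolution, as RAFT uses), and the import resolves when the snippet is run from `data_preprocessing/RAFT/core`:

```python
# Toy usage of CorrBlock: build the correlation pyramid and perform one lookup.
import torch
from corr import CorrBlock  # resolves when run from data_preprocessing/RAFT/core

fmap1 = torch.randn(1, 256, 46, 62)   # feature maps at 1/8 of a 368x496 input (assumed)
fmap2 = torch.randn(1, 256, 46, 62)
corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)

coords = torch.zeros(1, 2, 46, 62)    # in RAFT this would be coords_grid(...) plus the current flow
lookup = corr_fn(coords)              # (1, num_levels * (2*radius+1)**2, 46, 62)
print(lookup.shape)                   # torch.Size([1, 324, 46, 62])
```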
/data_preprocessing/RAFT/core/datasets.py:
--------------------------------------------------------------------------------
1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2 |
3 | import numpy as np
4 | import torch
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 |
8 | import os
9 | import math
10 | import random
11 | from glob import glob
12 | import os.path as osp
13 |
14 | from utils import frame_utils
15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16 |
17 |
18 | class FlowDataset(data.Dataset):
19 | def __init__(self, aug_params=None, sparse=False):
20 | self.augmentor = None
21 | self.sparse = sparse
22 | if aug_params is not None:
23 | if sparse:
24 | self.augmentor = SparseFlowAugmentor(**aug_params)
25 | else:
26 | self.augmentor = FlowAugmentor(**aug_params)
27 |
28 | self.is_test = False
29 | self.init_seed = False
30 | self.flow_list = []
31 | self.image_list = []
32 | self.extra_info = []
33 |
34 | def __getitem__(self, index):
35 |
36 | if self.is_test:
37 | img1 = frame_utils.read_gen(self.image_list[index][0])
38 | img2 = frame_utils.read_gen(self.image_list[index][1])
39 | img1 = np.array(img1).astype(np.uint8)[..., :3]
40 | img2 = np.array(img2).astype(np.uint8)[..., :3]
41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
43 | return img1, img2, self.extra_info[index]
44 |
45 | if not self.init_seed:
46 | worker_info = torch.utils.data.get_worker_info()
47 | if worker_info is not None:
48 | torch.manual_seed(worker_info.id)
49 | np.random.seed(worker_info.id)
50 | random.seed(worker_info.id)
51 | self.init_seed = True
52 |
53 | index = index % len(self.image_list)
54 | valid = None
55 | if self.sparse:
56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
57 | else:
58 | flow = frame_utils.read_gen(self.flow_list[index])
59 |
60 | img1 = frame_utils.read_gen(self.image_list[index][0])
61 | img2 = frame_utils.read_gen(self.image_list[index][1])
62 |
63 | flow = np.array(flow).astype(np.float32)
64 | img1 = np.array(img1).astype(np.uint8)
65 | img2 = np.array(img2).astype(np.uint8)
66 |
67 | # grayscale images
68 | if len(img1.shape) == 2:
69 | img1 = np.tile(img1[...,None], (1, 1, 3))
70 | img2 = np.tile(img2[...,None], (1, 1, 3))
71 | else:
72 | img1 = img1[..., :3]
73 | img2 = img2[..., :3]
74 |
75 | if self.augmentor is not None:
76 | if self.sparse:
77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
78 | else:
79 | img1, img2, flow = self.augmentor(img1, img2, flow)
80 |
81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float()
84 |
85 | if valid is not None:
86 | valid = torch.from_numpy(valid)
87 | else:
88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
89 |
90 | return img1, img2, flow, valid.float()
91 |
92 |
93 | def __rmul__(self, v):
94 | self.flow_list = v * self.flow_list
95 | self.image_list = v * self.image_list
96 | return self
97 |
98 | def __len__(self):
99 | return len(self.image_list)
100 |
101 |
102 | class MpiSintel(FlowDataset):
103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
104 | super(MpiSintel, self).__init__(aug_params)
105 | flow_root = osp.join(root, split, 'flow')
106 | image_root = osp.join(root, split, dstype)
107 |
108 | if split == 'test':
109 | self.is_test = True
110 |
111 | for scene in os.listdir(image_root):
112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
113 | for i in range(len(image_list)-1):
114 | self.image_list += [ [image_list[i], image_list[i+1]] ]
115 | self.extra_info += [ (scene, i) ] # scene and frame_id
116 |
117 | if split != 'test':
118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
119 |
120 |
121 | class FlyingChairs(FlowDataset):
122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
123 | super(FlyingChairs, self).__init__(aug_params)
124 |
125 | images = sorted(glob(osp.join(root, '*.ppm')))
126 | flows = sorted(glob(osp.join(root, '*.flo')))
127 | assert (len(images)//2 == len(flows))
128 |
129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
130 | for i in range(len(flows)):
131 | xid = split_list[i]
132 | if (split=='training' and xid==1) or (split=='validation' and xid==2):
133 | self.flow_list += [ flows[i] ]
134 | self.image_list += [ [images[2*i], images[2*i+1]] ]
135 |
136 |
137 | class FlyingThings3D(FlowDataset):
138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
139 | super(FlyingThings3D, self).__init__(aug_params)
140 |
141 | for cam in ['left']:
142 | for direction in ['into_future', 'into_past']:
143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
145 |
146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
148 |
149 | for idir, fdir in zip(image_dirs, flow_dirs):
150 | images = sorted(glob(osp.join(idir, '*.png')) )
151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) )
152 | for i in range(len(flows)-1):
153 | if direction == 'into_future':
154 | self.image_list += [ [images[i], images[i+1]] ]
155 | self.flow_list += [ flows[i] ]
156 | elif direction == 'into_past':
157 | self.image_list += [ [images[i+1], images[i]] ]
158 | self.flow_list += [ flows[i+1] ]
159 |
160 |
161 | class KITTI(FlowDataset):
162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
163 | super(KITTI, self).__init__(aug_params, sparse=True)
164 | if split == 'testing':
165 | self.is_test = True
166 |
167 | root = osp.join(root, split)
168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
170 |
171 | for img1, img2 in zip(images1, images2):
172 | frame_id = img1.split('/')[-1]
173 | self.extra_info += [ [frame_id] ]
174 | self.image_list += [ [img1, img2] ]
175 |
176 | if split == 'training':
177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
178 |
179 |
180 | class HD1K(FlowDataset):
181 | def __init__(self, aug_params=None, root='datasets/HD1k'):
182 | super(HD1K, self).__init__(aug_params, sparse=True)
183 |
184 | seq_ix = 0
185 | while 1:
186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
188 |
189 | if len(flows) == 0:
190 | break
191 |
192 | for i in range(len(flows)-1):
193 | self.flow_list += [flows[i]]
194 | self.image_list += [ [images[i], images[i+1]] ]
195 |
196 | seq_ix += 1
197 |
198 |
199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
200 | """ Create the data loader for the corresponding trainign set """
201 |
202 | if args.stage == 'chairs':
203 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
204 | train_dataset = FlyingChairs(aug_params, split='training')
205 |
206 | elif args.stage == 'things':
207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
209 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
210 | train_dataset = clean_dataset + final_dataset
211 |
212 | elif args.stage == 'sintel':
213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final')
217 |
218 | if TRAIN_DS == 'C+T+K+S+H':
219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
220 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
222 |
223 | elif TRAIN_DS == 'C+T+K/S':
224 | train_dataset = 100*sintel_clean + 100*sintel_final + things
225 |
226 | elif args.stage == 'kitti':
227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
228 | train_dataset = KITTI(aug_params, split='training')
229 |
230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
232 |
233 | print('Training with %d image pairs' % len(train_dataset))
234 | return train_loader
235 |
236 |
--------------------------------------------------------------------------------
/data_preprocessing/RAFT/core/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8 | super(ResidualBlock, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | num_groups = planes // 8
15 |
16 | if norm_fn == 'group':
17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | if not stride == 1:
20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 |
22 | elif norm_fn == 'batch':
23 | self.norm1 = nn.BatchNorm2d(planes)
24 | self.norm2 = nn.BatchNorm2d(planes)
25 | if not stride == 1:
26 | self.norm3 = nn.BatchNorm2d(planes)
27 |
28 | elif norm_fn == 'instance':
29 | self.norm1 = nn.InstanceNorm2d(planes)
30 | self.norm2 = nn.InstanceNorm2d(planes)
31 | if not stride == 1:
32 | self.norm3 = nn.InstanceNorm2d(planes)
33 |
34 | elif norm_fn == 'none':
35 | self.norm1 = nn.Sequential()
36 | self.norm2 = nn.Sequential()
37 | if not stride == 1:
38 | self.norm3 = nn.Sequential()
39 |
40 | if stride == 1:
41 | self.downsample = None
42 |
43 | else:
44 | self.downsample = nn.Sequential(
45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46 |
47 |
48 | def forward(self, x):
49 | y = x
50 | y = self.relu(self.norm1(self.conv1(y)))
51 | y = self.relu(self.norm2(self.conv2(y)))
52 |
53 | if self.downsample is not None:
54 | x = self.downsample(x)
55 |
56 | return self.relu(x+y)
57 |
58 |
59 |
60 | class BottleneckBlock(nn.Module):
61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
62 | super(BottleneckBlock, self).__init__()
63 |
64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | num_groups = planes // 8
70 |
71 | if norm_fn == 'group':
72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
75 | if not stride == 1:
76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
77 |
78 | elif norm_fn == 'batch':
79 | self.norm1 = nn.BatchNorm2d(planes//4)
80 | self.norm2 = nn.BatchNorm2d(planes//4)
81 | self.norm3 = nn.BatchNorm2d(planes)
82 | if not stride == 1:
83 | self.norm4 = nn.BatchNorm2d(planes)
84 |
85 | elif norm_fn == 'instance':
86 | self.norm1 = nn.InstanceNorm2d(planes//4)
87 | self.norm2 = nn.InstanceNorm2d(planes//4)
88 | self.norm3 = nn.InstanceNorm2d(planes)
89 | if not stride == 1:
90 | self.norm4 = nn.InstanceNorm2d(planes)
91 |
92 | elif norm_fn == 'none':
93 | self.norm1 = nn.Sequential()
94 | self.norm2 = nn.Sequential()
95 | self.norm3 = nn.Sequential()
96 | if not stride == 1:
97 | self.norm4 = nn.Sequential()
98 |
99 | if stride == 1:
100 | self.downsample = None
101 |
102 | else:
103 | self.downsample = nn.Sequential(
104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
105 |
106 |
107 | def forward(self, x):
108 | y = x
109 | y = self.relu(self.norm1(self.conv1(y)))
110 | y = self.relu(self.norm2(self.conv2(y)))
111 | y = self.relu(self.norm3(self.conv3(y)))
112 |
113 | if self.downsample is not None:
114 | x = self.downsample(x)
115 |
116 | return self.relu(x+y)
117 |
118 | class BasicEncoder(nn.Module):
119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
120 | super(BasicEncoder, self).__init__()
121 | self.norm_fn = norm_fn
122 |
123 | if self.norm_fn == 'group':
124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
125 |
126 | elif self.norm_fn == 'batch':
127 | self.norm1 = nn.BatchNorm2d(64)
128 |
129 | elif self.norm_fn == 'instance':
130 | self.norm1 = nn.InstanceNorm2d(64)
131 |
132 | elif self.norm_fn == 'none':
133 | self.norm1 = nn.Sequential()
134 |
135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
136 | self.relu1 = nn.ReLU(inplace=True)
137 |
138 | self.in_planes = 64
139 | self.layer1 = self._make_layer(64, stride=1)
140 | self.layer2 = self._make_layer(96, stride=2)
141 | self.layer3 = self._make_layer(128, stride=2)
142 |
143 | # output convolution
144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
145 |
146 | self.dropout = None
147 | if dropout > 0:
148 | self.dropout = nn.Dropout2d(p=dropout)
149 |
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
154 | if m.weight is not None:
155 | nn.init.constant_(m.weight, 1)
156 | if m.bias is not None:
157 | nn.init.constant_(m.bias, 0)
158 |
159 | def _make_layer(self, dim, stride=1):
160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
162 | layers = (layer1, layer2)
163 |
164 | self.in_planes = dim
165 | return nn.Sequential(*layers)
166 |
167 |
168 | def forward(self, x):
169 |
170 | # if input is list, combine batch dimension
171 | is_list = isinstance(x, tuple) or isinstance(x, list)
172 | if is_list:
173 | batch_dim = x[0].shape[0]
174 | x = torch.cat(x, dim=0)
175 |
176 | x = self.conv1(x)
177 | x = self.norm1(x)
178 | x = self.relu1(x)
179 |
180 | x = self.layer1(x)
181 | x = self.layer2(x)
182 | x = self.layer3(x)
183 |
184 | x = self.conv2(x)
185 |
186 | if self.training and self.dropout is not None:
187 | x = self.dropout(x)
188 |
189 | if is_list:
190 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
191 |
192 | return x
193 |
194 |
195 | class SmallEncoder(nn.Module):
196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
197 | super(SmallEncoder, self).__init__()
198 | self.norm_fn = norm_fn
199 |
200 | if self.norm_fn == 'group':
201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
202 |
203 | elif self.norm_fn == 'batch':
204 | self.norm1 = nn.BatchNorm2d(32)
205 |
206 | elif self.norm_fn == 'instance':
207 | self.norm1 = nn.InstanceNorm2d(32)
208 |
209 | elif self.norm_fn == 'none':
210 | self.norm1 = nn.Sequential()
211 |
212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
213 | self.relu1 = nn.ReLU(inplace=True)
214 |
215 | self.in_planes = 32
216 | self.layer1 = self._make_layer(32, stride=1)
217 | self.layer2 = self._make_layer(64, stride=2)
218 | self.layer3 = self._make_layer(96, stride=2)
219 |
220 | self.dropout = None
221 | if dropout > 0:
222 | self.dropout = nn.Dropout2d(p=dropout)
223 |
224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
225 |
226 | for m in self.modules():
227 | if isinstance(m, nn.Conv2d):
228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
230 | if m.weight is not None:
231 | nn.init.constant_(m.weight, 1)
232 | if m.bias is not None:
233 | nn.init.constant_(m.bias, 0)
234 |
235 | def _make_layer(self, dim, stride=1):
236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
238 | layers = (layer1, layer2)
239 |
240 | self.in_planes = dim
241 | return nn.Sequential(*layers)
242 |
243 |
244 | def forward(self, x):
245 |
246 | # if input is list, combine batch dimension
247 | is_list = isinstance(x, tuple) or isinstance(x, list)
248 | if is_list:
249 | batch_dim = x[0].shape[0]
250 | x = torch.cat(x, dim=0)
251 |
252 | x = self.conv1(x)
253 | x = self.norm1(x)
254 | x = self.relu1(x)
255 |
256 | x = self.layer1(x)
257 | x = self.layer2(x)
258 | x = self.layer3(x)
259 | x = self.conv2(x)
260 |
261 | if self.training and self.dropout is not None:
262 | x = self.dropout(x)
263 |
264 | if is_list:
265 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
266 |
267 | return x
268 |
--------------------------------------------------------------------------------
/data_preprocessing/RAFT/core/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from update import BasicUpdateBlock, SmallUpdateBlock
7 | from extractor import BasicEncoder, SmallEncoder
8 | from corr import CorrBlock, AlternateCorrBlock
9 | from utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 | try:
12 | autocast = torch.cuda.amp.autocast
13 | except:
14 | # dummy autocast for PyTorch < 1.6
15 | class autocast:
16 | def __init__(self, enabled):
17 | pass
18 | def __enter__(self):
19 | pass
20 | def __exit__(self, *args):
21 | pass
22 |
23 |
24 | class RAFT(nn.Module):
25 | def __init__(self, args):
26 | super(RAFT, self).__init__()
27 | self.args = args
28 |
29 | if args.small:
30 | self.hidden_dim = hdim = 96
31 | self.context_dim = cdim = 64
32 | args.corr_levels = 4
33 | args.corr_radius = 3
34 |
35 | else:
36 | self.hidden_dim = hdim = 128
37 | self.context_dim = cdim = 128
38 | args.corr_levels = 4
39 | args.corr_radius = 4
40 |
41 | if 'dropout' not in self.args:
42 | self.args.dropout = 0
43 |
44 | if 'alternate_corr' not in self.args:
45 | self.args.alternate_corr = False
46 |
47 | # feature network, context network, and update block
48 | if args.small:
49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
52 |
53 | else:
54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
57 |
58 | def freeze_bn(self):
59 | for m in self.modules():
60 | if isinstance(m, nn.BatchNorm2d):
61 | m.eval()
62 |
63 | def initialize_flow(self, img):
64 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
65 | N, C, H, W = img.shape
66 | coords0 = coords_grid(N, H//8, W//8, device=img.device)
67 | coords1 = coords_grid(N, H//8, W//8, device=img.device)
68 |
69 | # optical flow computed as difference: flow = coords1 - coords0
70 | return coords0, coords1
71 |
72 | def upsample_flow(self, flow, mask):
73 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
74 | N, _, H, W = flow.shape
75 | mask = mask.view(N, 1, 9, 8, 8, H, W)
76 | mask = torch.softmax(mask, dim=2)
77 |
78 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
79 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
80 |
81 | up_flow = torch.sum(mask * up_flow, dim=2)
82 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
83 | return up_flow.reshape(N, 2, 8*H, 8*W)
84 |
85 |
86 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
87 | """ Estimate optical flow between pair of frames """
88 |
89 | image1 = 2 * (image1 / 255.0) - 1.0
90 | image2 = 2 * (image2 / 255.0) - 1.0
91 |
92 | image1 = image1.contiguous()
93 | image2 = image2.contiguous()
94 |
95 | hdim = self.hidden_dim
96 | cdim = self.context_dim
97 |
98 | # run the feature network
99 | with autocast(enabled=self.args.mixed_precision):
100 | fmap1, fmap2 = self.fnet([image1, image2])
101 |
102 | fmap1 = fmap1.float()
103 | fmap2 = fmap2.float()
104 | if self.args.alternate_corr:
105 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
106 | else:
107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
108 |
109 | # run the context network
110 | with autocast(enabled=self.args.mixed_precision):
111 | cnet = self.cnet(image1)
112 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
113 | net = torch.tanh(net)
114 | inp = torch.relu(inp)
115 |
116 | coords0, coords1 = self.initialize_flow(image1)
117 |
118 | if flow_init is not None:
119 | coords1 = coords1 + flow_init
120 |
121 | flow_predictions = []
122 | for itr in range(iters):
123 | coords1 = coords1.detach()
124 | corr = corr_fn(coords1) # index correlation volume
125 |
126 | flow = coords1 - coords0
127 | with autocast(enabled=self.args.mixed_precision):
128 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
129 |
130 | # F(t+1) = F(t) + \Delta(t)
131 | coords1 = coords1 + delta_flow
132 |
133 | # upsample predictions
134 | if up_mask is None:
135 | flow_up = upflow8(coords1 - coords0)
136 | else:
137 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
138 |
139 | flow_predictions.append(flow_up)
140 |
141 | if test_mode:
142 | return coords1 - coords0, flow_up
143 |
144 | return flow_predictions
145 |
--------------------------------------------------------------------------------
/data_preprocessing/RAFT/core/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 | class SmallMotionEncoder(nn.Module):
63 | def __init__(self, args):
64 | super(SmallMotionEncoder, self).__init__()
65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
69 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
70 |
71 | def forward(self, flow, corr):
72 | cor = F.relu(self.convc1(corr))
73 | flo = F.relu(self.convf1(flow))
74 | flo = F.relu(self.convf2(flo))
75 | cor_flo = torch.cat([cor, flo], dim=1)
76 | out = F.relu(self.conv(cor_flo))
77 | return torch.cat([out, flow], dim=1)
78 |
79 | class BasicMotionEncoder(nn.Module):
80 | def __init__(self, args):
81 | super(BasicMotionEncoder, self).__init__()
82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
88 |
89 | def forward(self, flow, corr):
90 | cor = F.relu(self.convc1(corr))
91 | cor = F.relu(self.convc2(cor))
92 | flo = F.relu(self.convf1(flow))
93 | flo = F.relu(self.convf2(flo))
94 |
95 | cor_flo = torch.cat([cor, flo], dim=1)
96 | out = F.relu(self.conv(cor_flo))
97 | return torch.cat([out, flow], dim=1)
98 |
99 | class SmallUpdateBlock(nn.Module):
100 | def __init__(self, args, hidden_dim=96):
101 | super(SmallUpdateBlock, self).__init__()
102 | self.encoder = SmallMotionEncoder(args)
103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 | inp = torch.cat([inp, motion_features], dim=1)
109 | net = self.gru(net, inp)
110 | delta_flow = self.flow_head(net)
111 |
112 | return net, None, delta_flow
113 |
114 | class BasicUpdateBlock(nn.Module):
115 | def __init__(self, args, hidden_dim=128, input_dim=128):
116 | super(BasicUpdateBlock, self).__init__()
117 | self.args = args
118 | self.encoder = BasicMotionEncoder(args)
119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
121 |
122 | self.mask = nn.Sequential(
123 | nn.Conv2d(128, 256, 3, padding=1),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(256, 64*9, 1, padding=0))
126 |
127 | def forward(self, net, inp, corr, flow, upsample=True):
128 | motion_features = self.encoder(flow, corr)
129 | inp = torch.cat([inp, motion_features], dim=1)
130 |
131 | net = self.gru(net, inp)
132 | delta_flow = self.flow_head(net)
133 |
134 |         # scale mask to balance gradients
135 | mask = .25 * self.mask(net)
136 | return net, mask, delta_flow
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
/data_preprocessing/RAFT/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/data_preprocessing/RAFT/core/utils/__init__.py
--------------------------------------------------------------------------------
/data_preprocessing/RAFT/core/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import math
4 | from PIL import Image
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | import torch
11 | from torchvision.transforms import ColorJitter
12 | import torch.nn.functional as F
13 |
14 |
15 | class FlowAugmentor:
16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
17 |
18 | # spatial augmentation params
19 | self.crop_size = crop_size
20 | self.min_scale = min_scale
21 | self.max_scale = max_scale
22 | self.spatial_aug_prob = 0.8
23 | self.stretch_prob = 0.8
24 | self.max_stretch = 0.2
25 |
26 | # flip augmentation params
27 | self.do_flip = do_flip
28 | self.h_flip_prob = 0.5
29 | self.v_flip_prob = 0.1
30 |
31 | # photometric augmentation params
32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
33 | self.asymmetric_color_aug_prob = 0.2
34 | self.eraser_aug_prob = 0.5
35 |
36 | def color_transform(self, img1, img2):
37 | """ Photometric augmentation """
38 |
39 | # asymmetric
40 | if np.random.rand() < self.asymmetric_color_aug_prob:
41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
43 |
44 | # symmetric
45 | else:
46 | image_stack = np.concatenate([img1, img2], axis=0)
47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
48 | img1, img2 = np.split(image_stack, 2, axis=0)
49 |
50 | return img1, img2
51 |
52 | def eraser_transform(self, img1, img2, bounds=[50, 100]):
53 | """ Occlusion augmentation """
54 |
55 | ht, wd = img1.shape[:2]
56 | if np.random.rand() < self.eraser_aug_prob:
57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
58 | for _ in range(np.random.randint(1, 3)):
59 | x0 = np.random.randint(0, wd)
60 | y0 = np.random.randint(0, ht)
61 | dx = np.random.randint(bounds[0], bounds[1])
62 | dy = np.random.randint(bounds[0], bounds[1])
63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
64 |
65 | return img1, img2
66 |
67 | def spatial_transform(self, img1, img2, flow):
68 | # randomly sample scale
69 | ht, wd = img1.shape[:2]
70 | min_scale = np.maximum(
71 | (self.crop_size[0] + 8) / float(ht),
72 | (self.crop_size[1] + 8) / float(wd))
73 |
74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
75 | scale_x = scale
76 | scale_y = scale
77 | if np.random.rand() < self.stretch_prob:
78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
80 |
81 | scale_x = np.clip(scale_x, min_scale, None)
82 | scale_y = np.clip(scale_y, min_scale, None)
83 |
84 | if np.random.rand() < self.spatial_aug_prob:
85 | # rescale the images
86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
89 | flow = flow * [scale_x, scale_y]
90 |
91 | if self.do_flip:
92 | if np.random.rand() < self.h_flip_prob: # h-flip
93 | img1 = img1[:, ::-1]
94 | img2 = img2[:, ::-1]
95 | flow = flow[:, ::-1] * [-1.0, 1.0]
96 |
97 | if np.random.rand() < self.v_flip_prob: # v-flip
98 | img1 = img1[::-1, :]
99 | img2 = img2[::-1, :]
100 | flow = flow[::-1, :] * [1.0, -1.0]
101 |
102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
104 |
105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
108 |
109 | return img1, img2, flow
110 |
111 | def __call__(self, img1, img2, flow):
112 | img1, img2 = self.color_transform(img1, img2)
113 | img1, img2 = self.eraser_transform(img1, img2)
114 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
115 |
116 | img1 = np.ascontiguousarray(img1)
117 | img2 = np.ascontiguousarray(img2)
118 | flow = np.ascontiguousarray(flow)
119 |
120 | return img1, img2, flow
121 |
122 | class SparseFlowAugmentor:
123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
124 | # spatial augmentation params
125 | self.crop_size = crop_size
126 | self.min_scale = min_scale
127 | self.max_scale = max_scale
128 | self.spatial_aug_prob = 0.8
129 | self.stretch_prob = 0.8
130 | self.max_stretch = 0.2
131 |
132 | # flip augmentation params
133 | self.do_flip = do_flip
134 | self.h_flip_prob = 0.5
135 | self.v_flip_prob = 0.1
136 |
137 | # photometric augmentation params
138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
139 | self.asymmetric_color_aug_prob = 0.2
140 | self.eraser_aug_prob = 0.5
141 |
142 | def color_transform(self, img1, img2):
143 | image_stack = np.concatenate([img1, img2], axis=0)
144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
145 | img1, img2 = np.split(image_stack, 2, axis=0)
146 | return img1, img2
147 |
148 | def eraser_transform(self, img1, img2):
149 | ht, wd = img1.shape[:2]
150 | if np.random.rand() < self.eraser_aug_prob:
151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
152 | for _ in range(np.random.randint(1, 3)):
153 | x0 = np.random.randint(0, wd)
154 | y0 = np.random.randint(0, ht)
155 | dx = np.random.randint(50, 100)
156 | dy = np.random.randint(50, 100)
157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
158 |
159 | return img1, img2
160 |
161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
162 | ht, wd = flow.shape[:2]
163 | coords = np.meshgrid(np.arange(wd), np.arange(ht))
164 | coords = np.stack(coords, axis=-1)
165 |
166 | coords = coords.reshape(-1, 2).astype(np.float32)
167 | flow = flow.reshape(-1, 2).astype(np.float32)
168 | valid = valid.reshape(-1).astype(np.float32)
169 |
170 | coords0 = coords[valid>=1]
171 | flow0 = flow[valid>=1]
172 |
173 | ht1 = int(round(ht * fy))
174 | wd1 = int(round(wd * fx))
175 |
176 | coords1 = coords0 * [fx, fy]
177 | flow1 = flow0 * [fx, fy]
178 |
179 | xx = np.round(coords1[:,0]).astype(np.int32)
180 | yy = np.round(coords1[:,1]).astype(np.int32)
181 |
182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
183 | xx = xx[v]
184 | yy = yy[v]
185 | flow1 = flow1[v]
186 |
187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
189 |
190 | flow_img[yy, xx] = flow1
191 | valid_img[yy, xx] = 1
192 |
193 | return flow_img, valid_img
194 |
195 | def spatial_transform(self, img1, img2, flow, valid):
196 | # randomly sample scale
197 |
198 | ht, wd = img1.shape[:2]
199 | min_scale = np.maximum(
200 | (self.crop_size[0] + 1) / float(ht),
201 | (self.crop_size[1] + 1) / float(wd))
202 |
203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
204 | scale_x = np.clip(scale, min_scale, None)
205 | scale_y = np.clip(scale, min_scale, None)
206 |
207 | if np.random.rand() < self.spatial_aug_prob:
208 | # rescale the images
209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
212 |
213 | if self.do_flip:
214 | if np.random.rand() < 0.5: # h-flip
215 | img1 = img1[:, ::-1]
216 | img2 = img2[:, ::-1]
217 | flow = flow[:, ::-1] * [-1.0, 1.0]
218 | valid = valid[:, ::-1]
219 |
220 | margin_y = 20
221 | margin_x = 50
222 |
223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
225 |
226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
228 |
229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
233 | return img1, img2, flow, valid
234 |
235 |
236 | def __call__(self, img1, img2, flow, valid):
237 | img1, img2 = self.color_transform(img1, img2)
238 | img1, img2 = self.eraser_transform(img1, img2)
239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
240 |
241 | img1 = np.ascontiguousarray(img1)
242 | img2 = np.ascontiguousarray(img2)
243 | flow = np.ascontiguousarray(flow)
244 | valid = np.ascontiguousarray(valid)
245 |
246 | return img1, img2, flow, valid
247 |
--------------------------------------------------------------------------------
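A minimal usage sketch for the sparse augmentation pipeline above. The class name and constructor signature (SparseFlowAugmentor taking a crop_size argument) are assumptions based on the standard RAFT augmentor, since the constructor is not shown in this excerpt; the import path is assumed to be relative to RAFT/core.

import numpy as np
from utils.augmentor import SparseFlowAugmentor  # assumed path, relative to RAFT/core

aug = SparseFlowAugmentor(crop_size=(368, 496))

# Synthetic inputs: a pair of uint8 frames, a flow field and its validity mask.
img1 = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
img2 = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
flow = np.random.randn(480, 640, 2).astype(np.float32)
valid = (np.random.rand(480, 640) > 0.5).astype(np.float32)

# __call__ chains color_transform, eraser_transform and spatial_transform,
# so all four outputs come back cropped to crop_size.
img1, img2, flow, valid = aug(img1, img2, flow, valid)
print(img1.shape, flow.shape, valid.shape)
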
/data_preprocessing/RAFT/core/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 |     Expects a two dimensional flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
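A short sketch of visualizing a synthetic flow field with the helpers above. flow_to_image expects a float array of shape [H, W, 2]; the import line assumes flow_viz.py is on the Python path (adjust it to your layout).

import numpy as np
import cv2
from flow_viz import flow_to_image  # assumed import path

h, w = 120, 160
y, x = np.mgrid[0:h, 0:w].astype(np.float32)
flow = np.stack([x - w / 2, y - h / 2], axis=-1)  # radial flow away from the image center

rgb = flow_to_image(flow)                       # uint8 RGB, shape [H, W, 3]
bgr = flow_to_image(flow, convert_to_bgr=True)  # same colors, BGR channel order for OpenCV
cv2.imwrite('flow_vis.png', bgr)
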
/data_preprocessing/RAFT/core/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126 | return Image.open(file_name)
127 | elif ext == '.bin' or ext == '.raw':
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
--------------------------------------------------------------------------------
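A round-trip sketch for the Middlebury .flo helpers defined above: writeFlow interleaves the u/v channels per row after the magic tag, and readFlow reads them back as an [H, W, 2] float32 array. The import path is an assumption.

import numpy as np
from frame_utils import readFlow, writeFlow  # assumed import path

flow = np.random.randn(48, 64, 2).astype(np.float32)
writeFlow('tmp_flow.flo', flow)
flow_back = readFlow('tmp_flow.flo')

assert flow_back.shape == (48, 64, 2)
print('max abs diff:', np.abs(flow - flow_back).max())  # ~0 for a lossless round trip
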
/data_preprocessing/RAFT/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, *inputs):
19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20 |
21 | def unpad(self,x):
22 | ht, wd = x.shape[-2:]
23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24 | return x[..., c[0]:c[1], c[2]:c[3]]
25 |
26 | def forward_interpolate(flow):
27 | flow = flow.detach().cpu().numpy()
28 | dx, dy = flow[0], flow[1]
29 |
30 | ht, wd = dx.shape
31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32 |
33 | x1 = x0 + dx
34 | y1 = y0 + dy
35 |
36 | x1 = x1.reshape(-1)
37 | y1 = y1.reshape(-1)
38 | dx = dx.reshape(-1)
39 | dy = dy.reshape(-1)
40 |
41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42 | x1 = x1[valid]
43 | y1 = y1[valid]
44 | dx = dx[valid]
45 | dy = dy[valid]
46 |
47 | flow_x = interpolate.griddata(
48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49 |
50 | flow_y = interpolate.griddata(
51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow = np.stack([flow_x, flow_y], axis=0)
54 | return torch.from_numpy(flow).float()
55 |
56 |
57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58 | """ Wrapper for grid_sample, uses pixel coordinates """
59 | H, W = img.shape[-2:]
60 | xgrid, ygrid = coords.split([1,1], dim=-1)
61 | xgrid = 2*xgrid/(W-1) - 1
62 | ygrid = 2*ygrid/(H-1) - 1
63 |
64 | grid = torch.cat([xgrid, ygrid], dim=-1)
65 | img = F.grid_sample(img, grid, align_corners=True)
66 |
67 | if mask:
68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69 | return img, mask.float()
70 |
71 | return img
72 |
73 |
74 | def coords_grid(batch, ht, wd, device):
75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
76 | coords = torch.stack(coords[::-1], dim=0).float()
77 | return coords[None].repeat(batch, 1, 1, 1)
78 |
79 |
80 | def upflow8(flow, mode='bilinear'):
81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83 |
--------------------------------------------------------------------------------
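A small sketch of the padding utilities above: InputPadder grows the last two dimensions to multiples of 8 (as RAFT requires) and unpad() restores the original size, while coords_grid builds the pixel grid that flow deltas are added to. The import path is an assumption.

import torch
from utils.utils import InputPadder, coords_grid  # assumed import path

image = torch.rand(1, 3, 436, 1024)   # a Sintel-sized frame
padder = InputPadder(image.shape)     # default 'sintel' mode pads the height symmetrically
padded, = padder.pad(image)
print(padded.shape)                   # torch.Size([1, 3, 440, 1024])
print(padder.unpad(padded).shape)     # back to torch.Size([1, 3, 436, 1024])

grid = coords_grid(1, 436, 1024, device='cpu')
print(grid.shape)                     # torch.Size([1, 2, 436, 1024]), channels are (x, y)
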
/data_preprocessing/RAFT/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 |
4 | import argparse
5 | import os
6 | import cv2
7 | import glob
8 | import numpy as np
9 | import torch
10 | from PIL import Image
11 | import torch.nn.functional as F
12 |
13 | from raft import RAFT
14 | from utils import flow_viz
15 | from utils.utils import InputPadder
16 |
17 | DEVICE = 'cuda'
18 |
19 |
20 | def load_image(imfile):
21 | img = np.array(Image.open(imfile)).astype(np.uint8)
22 | img = torch.from_numpy(img).permute(2, 0, 1).float()
23 | return img[None].to(DEVICE)
24 |
25 |
26 | def viz(img, flo,img_name=None):
27 | img = img[0].permute(1,2,0).cpu().numpy()
28 | flo = flo[0].permute(1,2,0).cpu().numpy()
29 |
30 | # map flow to rgb image
31 | flo = flow_viz.flow_to_image(flo)
32 | img_flo = np.concatenate([img, flo], axis=0)
33 |
34 | cv2.imwrite(f'{img_name}', img_flo[:, :, [2,1,0]])
35 |
36 |
37 | def demo(args):
38 | model = torch.nn.DataParallel(RAFT(args))
39 | model.load_state_dict(torch.load(args.model))
40 |
41 | model = model.module
42 | model.to(DEVICE)
43 | model.eval()
44 | os.makedirs(args.outdir, exist_ok=True)
45 | os.makedirs(args.outdir_conf, exist_ok=True)
46 | with torch.no_grad():
47 | images = glob.glob(os.path.join(args.path, '*.png')) + \
48 | glob.glob(os.path.join(args.path, '*.jpg'))
49 |
50 | images = sorted(images)
51 | i=0
52 | for imfile1, imfile2 in zip(images[:-1], images[1:]):
53 | image1 = load_image(imfile1)
54 | image2 = load_image(imfile2)
55 | if args.if_mask:
56 | mk_file1=imfile1.split("/")
57 | mk_file1[-2]=f"{args.name}_masks"
58 | mk_file1='/'.join(mk_file1)
59 | mk_file2=imfile2.split("/")
60 | mk_file2[-2]=f"{args.name}_masks"
61 | mk_file2='/'.join(mk_file2)
62 | mask1=cv2.imread(mk_file1.replace('jpg','png')
63 | ,0)
64 | mask2=cv2.imread(mk_file2.replace('jpg','png'),
65 | 0)
66 | mask1=torch.from_numpy(mask1).to(DEVICE).float()
67 | mask2=torch.from_numpy(mask2).to(DEVICE).float()
68 | mask1[mask1>0]=1
69 | mask2[mask2>0]=1
70 | image1*=mask1
71 | image2*=mask2
72 |
73 | padder = InputPadder(image1.shape)
74 | image1, image2 = padder.pad(image1, image2)
75 | if args.if_mask:
76 | mask1,mask2=padder.pad(mask1.unsqueeze(0).unsqueeze(0),
77 | mask2.unsqueeze(0).unsqueeze(0))
78 | mask1=mask1.squeeze()
79 | mask2=mask2.squeeze()
80 |
81 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
82 | flow_low_, flow_up_ = model(image2, image1, iters=20, test_mode=True)
83 | flow_1to2 = flow_up.clone()
84 | flow_2to1 = flow_up_.clone()
85 |
86 | _,_,H,W=image1.shape
87 | x = torch.linspace(0, 1, W)
88 | y = torch.linspace(0, 1, H)
89 | grid_x,grid_y=torch.meshgrid(x,y)
90 | grid=torch.stack([grid_x,grid_y],dim=0).to(DEVICE)
91 | grid=grid.permute(0,2,1)
92 | grid[0]*=W
93 | grid[1]*=H
94 | if args.if_mask:
95 | flow_up[:,:,mask1.long()==0]=10000
96 | grid_=grid+flow_up.squeeze()
97 |
98 | grid_norm=grid_.clone()
99 | grid_norm[0,...]=2*grid_norm[0,...]/(W-1)-1
100 | grid_norm[1,...]=2*grid_norm[1,...]/(H-1)-1
101 |
102 | flow_bilinear_=F.grid_sample(flow_up_,grid_norm.unsqueeze(0).permute(0,2,3,1),mode='bilinear',padding_mode='zeros')
103 |
104 | rgb_bilinear_=F.grid_sample(image2,grid_norm.unsqueeze(0).permute(0,2,3,1),mode='bilinear',padding_mode='zeros')
105 | rgb_np=rgb_bilinear_.squeeze().permute(1,2,0).cpu().numpy()[:, :, ::-1]
106 | cv2.imwrite(f'{args.outdir}/warped.png',rgb_np)
107 |
108 | if args.confidence:
109 | ### Calculate confidence map using cycle consistency.
110 | # 1). First calculate `warped_image2` by the following formula:
111 | # warped_image2 = F.grid_sample(image1, flow_2to1)
112 | # 2). Then calculate `warped_image1` by the following formula:
113 | # warped_image1 = F.grid_sample(warped_image2, flow_1to2)
114 | # 3) Finally calculate the confidence map:
115 | # confidence_map = metric_func(image1 - warped_image1)
116 |
117 | grid_2to1 = grid + flow_2to1.squeeze()
118 | norm_grid_2to1 = grid_2to1.clone()
119 | norm_grid_2to1[0, ...] = 2 * norm_grid_2to1[0, ...] / (W - 1) - 1
120 | norm_grid_2to1[1, ...] = 2 * norm_grid_2to1[1, ...] / (H - 1) - 1
121 | warped_image2 = F.grid_sample(image1, norm_grid_2to1.unsqueeze(0).permute(0,2,3,1), mode='bilinear', padding_mode='zeros')
122 |
123 | grid_1to2 = grid + flow_1to2.squeeze()
124 | norm_grid_1to2 = grid_1to2.clone()
125 | norm_grid_1to2[0, ...] = 2 * norm_grid_1to2[0, ...] / (W - 1) - 1
126 | norm_grid_1to2[1, ...] = 2 * norm_grid_1to2[1, ...] / (H - 1) - 1
127 | warped_image1 = F.grid_sample(warped_image2, norm_grid_1to2.unsqueeze(0).permute(0,2,3,1), mode='bilinear', padding_mode='zeros')
128 |
129 | error = torch.abs(image1 - warped_image1)
130 | confidence_map = torch.mean(error, dim=1, keepdim=True)
131 | confidence_map[confidence_map < args.thres] = 1
132 | confidence_map[confidence_map >= args.thres] = 0
133 |
134 | grid_bck=grid+flow_up.squeeze()+flow_bilinear_.squeeze()
135 | res=grid-grid_bck
136 | res=torch.norm(res,dim=0)
137 | mk=(res<10)&(flow_up.norm(dim=1).squeeze()>5)
138 |
139 | pts_src=grid[:,mk]
140 |
141 | pts_dst=(grid[:,mk]+flow_up.squeeze()[:,mk])
142 |
143 | pts_src=pts_src.permute(1,0).cpu().numpy()
144 | pts_dst=pts_dst.permute(1,0).cpu().numpy()
145 | indx=torch.randperm(pts_src.shape[0])[:30]
146 | # use cv2 to draw the matches in image1 and image2
147 | img_new=np.zeros((H,W*2,3),dtype=np.uint8)
148 | img_new[:,:W,:]=image1[0].permute(1,2,0).cpu().numpy()
149 | img_new[:,W:,:]=image2[0].permute(1,2,0).cpu().numpy()
150 |
151 | for j in indx:
152 | cv2.line(img_new,(int(pts_src[j,0]),int(pts_src[j,1])),(int(pts_dst[j,0])+W,int(pts_dst[j,1])),(0,255,0),1)
153 |
154 | cv2.imwrite(f'{args.outdir}/matches.png',img_new)
155 |
156 | np.save(f'{args.outdir}/{i:06d}.npy', flow_up.cpu().numpy())
157 | if args.confidence:
158 | np.save(f'{args.outdir_conf}/{i:06d}_c.npy', confidence_map.cpu().numpy())
159 | i += 1
160 |
161 | viz(image1, flow_up,f'{args.outdir}/flow_up{i:03d}.png')
162 |
163 |
164 | if __name__ == '__main__':
165 | parser = argparse.ArgumentParser()
166 | parser.add_argument('--model', help="restore checkpoint")
167 | parser.add_argument('--path', help="dataset for evaluation")
168 |     parser.add_argument('--outdir', help="directory for the output results")
169 | parser.add_argument('--small', action='store_true', help='use small model')
170 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
171 |     parser.add_argument('--if_mask', action='store_true', help='whether to mask the color images with the provided masks')
172 |     parser.add_argument('--confidence', action='store_true', help='whether to save the confidence map')
173 |     parser.add_argument('--discrete', action='store_true', help='whether to discretize the saved confidence map')
174 |     parser.add_argument('--thres', default=4.0, type=float, help='threshold for the cycle-consistency confidence map')
175 | parser.add_argument('--outdir_conf', help="directory to save flow confidence")
176 | parser.add_argument('--name', help="the name of a sequence")
177 | args = parser.parse_args()
178 |
179 | demo(args)
180 |
--------------------------------------------------------------------------------
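The confidence branch in demo.py checks forward/backward (cycle) consistency: image1 is warped into frame 2 with the backward flow, warped back with the forward flow, and pixels whose round-trip error stays below --thres are kept as reliable. Below is a condensed, self-contained sketch of that computation; images are [1, 3, H, W] in the 0..255 range and flows are [1, 2, H, W], as in the demo, while align_corners=True is an assumption not spelled out in the original.

import torch
import torch.nn.functional as F

def cycle_confidence(image1, flow_1to2, flow_2to1, thres=4.0):
    """Cycle-consistency confidence, condensed from the demo above."""
    _, _, H, W = image1.shape
    ys, xs = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                            torch.arange(W, dtype=torch.float32), indexing='ij')
    grid = torch.stack([xs, ys], dim=0).to(image1.device)   # [2, H, W], (x, y) order

    def warp(src, flow):
        # grid_sample wants normalized coordinates in [-1, 1], layout [N, H, W, 2].
        coords = grid + flow.squeeze(0)
        coords[0] = 2 * coords[0] / (W - 1) - 1
        coords[1] = 2 * coords[1] / (H - 1) - 1
        return F.grid_sample(src, coords.permute(1, 2, 0)[None], mode='bilinear',
                             padding_mode='zeros', align_corners=True)

    warped2 = warp(image1, flow_2to1)   # image1 carried into frame 2
    warped1 = warp(warped2, flow_1to2)  # and carried back into frame 1
    error = torch.mean(torch.abs(image1 - warped1), dim=1, keepdim=True)
    return (error < thres).float()      # 1 = consistent, 0 = occluded / unreliable
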
/data_preprocessing/RAFT/models/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip
3 | unzip models.zip
4 |
--------------------------------------------------------------------------------
/data_preprocessing/RAFT/run_raft.sh:
--------------------------------------------------------------------------------
1 | NAME=beauty_1
2 | ROOT_DIR=/home/xxx/code/CoDeF/all_sequences
3 | CODE_DIR=/home/xxx/code/CoDeF/data_preprocessing/RAFT
4 |
5 | IMG_DIR=$ROOT_DIR/${NAME}/${NAME}
6 | FLOW_DIR=$ROOT_DIR/${NAME}/${NAME}_flow
7 | CONF_DIR=${FLOW_DIR}_confidence
8 |
9 | CUDA_VISIBLE_DEVICES=0 \
10 | python ${CODE_DIR}/demo.py \
11 | --model=${CODE_DIR}/models/raft-sintel.pth \
12 | --path=$IMG_DIR \
13 | --outdir=$FLOW_DIR \
14 | --name=$NAME \
15 | --confidence \
16 | --outdir_conf=$CONF_DIR
17 |
--------------------------------------------------------------------------------
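run_raft.sh leaves one forward-flow file per frame pair in FLOW_DIR ({i:06d}.npy, shape [1, 2, H, W]) and, because --confidence is set, a matching {i:06d}_c.npy map of shape [1, 1, H, W] in CONF_DIR. A sketch for inspecting the first pair; the concrete paths and the flow_viz import are assumptions about your layout.

import numpy as np
import cv2
from flow_viz import flow_to_image  # from RAFT/core/utils, assumed to be importable

flow = np.load('all_sequences/beauty_1/beauty_1_flow/000000.npy')               # [1, 2, H, W]
conf = np.load('all_sequences/beauty_1/beauty_1_flow_confidence/000000_c.npy')  # [1, 1, H, W]

flow_hw2 = np.transpose(flow[0], (1, 2, 0))   # -> [H, W, 2], as flow_to_image expects
cv2.imwrite('flow_000000.png', flow_to_image(flow_hw2, convert_to_bgr=True))
cv2.imwrite('conf_000000.png', (conf[0, 0] * 255).astype(np.uint8))
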
/data_preprocessing/preproc_mask.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | from glob import glob
5 | from tqdm import tqdm
6 |
7 | root_dir = '/home/xxx/code/CoDeF/all_sequences'
8 | name = 'beauty_1'
9 |
10 | msk_folder = f'{root_dir}/{name}/{name}_masks'
11 | img_folder = f'{root_dir}/{name}/{name}'
12 | frg_mask_folder = f'{root_dir}/{name}/{name}_masks_0'
13 | bkg_mask_folder = f'{root_dir}/{name}/{name}_masks_1'
14 | os.makedirs(frg_mask_folder, exist_ok=True)
15 | os.makedirs(bkg_mask_folder, exist_ok=True)
16 |
17 | files = glob(msk_folder + '/*.png')
18 | num = len(files)
19 |
20 | for i in tqdm(range(num)):
21 | file_n = os.path.basename(files[i])
22 | mask = cv2.imread(os.path.join(msk_folder, file_n), 0)
23 | mask[mask > 0] = 1
24 | cv2.imwrite(os.path.join(frg_mask_folder, file_n), mask * 255)
25 |
26 | bg_mask = mask.copy()
27 | bg_mask[bg_mask == 0] = 127
28 |     bg_mask[bg_mask == 1] = 0  # mask holds 0/1 at this point, not 0/255
29 | bg_mask[bg_mask == 127] = 255
30 | cv2.imwrite(os.path.join(bkg_mask_folder, file_n), bg_mask)
--------------------------------------------------------------------------------
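A quick sanity check that the two generated folders hold complementary masks (the sequence paths mirror the constants above and are assumptions about your layout):

import glob
import os
import cv2
import numpy as np

frg = sorted(glob.glob('all_sequences/beauty_1/beauty_1_masks_0/*.png'))
bkg = sorted(glob.glob('all_sequences/beauty_1/beauty_1_masks_1/*.png'))
for f, b in zip(frg, bkg):
    m0 = cv2.imread(f, 0) > 127   # foreground mask
    m1 = cv2.imread(b, 0) > 127   # background mask
    assert not np.any(m0 & m1), os.path.basename(f)  # no pixel claimed by both
    assert np.all(m0 | m1), os.path.basename(f)      # every pixel claimed by one
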
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .distributed_weighted_sampler import DistributedWeightedSampler
2 | from .video_dataset import VideoDataset
3 |
4 | dataset_dict = {'video': VideoDataset}
5 |
6 | custom_sampler_dict = {'weighted': DistributedWeightedSampler}
--------------------------------------------------------------------------------
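These two dictionaries act as name-to-class registries; a sketch of looking them up (the root_dir below is an assumed frame folder under all_sequences, and the exact way train.py wires the registries in is not shown here):

from datasets import dataset_dict, custom_sampler_dict

DatasetCls = dataset_dict['video']            # -> VideoDataset
SamplerCls = custom_sampler_dict['weighted']  # -> DistributedWeightedSampler

train_set = DatasetCls(root_dir='all_sequences/beauty_1/beauty_1',
                       img_wh=(504, 378), split='train')
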
/datasets/distributed_weighted_sampler.py:
--------------------------------------------------------------------------------
1 | # Combine weighted sampler and distributed sampler
2 | import math
3 | from typing import TypeVar, Optional, Iterator
4 |
5 | import torch
6 | import numpy as np
7 | from torch.utils.data import Sampler, Dataset
8 | import torch.distributed as dist
9 |
10 |
11 | T_co = TypeVar('T_co', covariant=True)
12 |
13 |
14 | class DistributedWeightedSampler(Sampler[T_co]):
15 | r"""Sampler that restricts data loading to a subset of the dataset.
16 |
17 | It is especially useful in conjunction with
18 | :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
19 | process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
20 | :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
21 | original dataset that is exclusive to it.
22 |
23 | .. note::
24 | Dataset is assumed to be of constant size.
25 |
26 | Args:
27 | dataset: Dataset used for sampling.
28 | num_replicas (int, optional): Number of processes participating in
29 | distributed training. By default, :attr:`world_size` is retrieved from the
30 | current distributed group.
31 | rank (int, optional): Rank of the current process within :attr:`num_replicas`.
32 | By default, :attr:`rank` is retrieved from the current distributed
33 | group.
34 | shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
35 | indices.
36 | seed (int, optional): random seed used to shuffle the sampler if
37 | :attr:`shuffle=True`. This number should be identical across all
38 | processes in the distributed group. Default: ``0``.
39 | drop_last (bool, optional): if ``True``, then the sampler will drop the
40 | tail of the data to make it evenly divisible across the number of
41 | replicas. If ``False``, the sampler will add extra indices to make
42 | the data evenly divisible across the replicas. Default: ``False``.
43 |
44 | .. warning::
45 | In distributed mode, calling the :meth:`set_epoch` method at
46 | the beginning of each epoch **before** creating the :class:`DataLoader` iterator
47 | is necessary to make shuffling work properly across multiple epochs. Otherwise,
48 | the same ordering will be always used.
49 |
50 | Example::
51 |
52 | >>> sampler = DistributedSampler(dataset) if is_distributed else None
53 | >>> loader = DataLoader(dataset, shuffle=(sampler is None),
54 | ... sampler=sampler)
55 | >>> for epoch in range(start_epoch, n_epochs):
56 | ... if is_distributed:
57 | ... sampler.set_epoch(epoch)
58 | ... train(loader)
59 | """
60 |
61 | def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
62 | rank: Optional[int] = None, shuffle: bool = True,
63 | seed: int = 0, drop_last: bool = False, replacement: bool = True) -> None:
64 | if num_replicas is None:
65 | if not dist.is_available():
66 | raise RuntimeError("Requires distributed package to be available")
67 | num_replicas = dist.get_world_size()
68 | if rank is None:
69 | if not dist.is_available():
70 | raise RuntimeError("Requires distributed package to be available")
71 | rank = dist.get_rank()
72 | if rank >= num_replicas or rank < 0:
73 | raise ValueError(
74 | "Invalid rank {}, rank should be in the interval"
75 | " [0, {}]".format(rank, num_replicas - 1))
76 | self.dataset = dataset
77 | self.num_replicas = num_replicas
78 | self.rank = rank
79 | self.epoch = 0
80 | self.drop_last = drop_last
81 | # If the dataset length is evenly divisible by # of replicas, then there
82 | # is no need to drop any data, since the dataset will be split equally.
83 | if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
84 | # Split to nearest available length that is evenly divisible.
85 | # This is to ensure each rank receives the same amount of data when
86 | # using this Sampler.
87 | self.num_samples = math.ceil(
88 | # `type:ignore` is required because Dataset cannot provide a default __len__
89 | # see NOTE in pytorch/torch/utils/data/sampler.py
90 | (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
91 | )
92 | else:
93 | self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
94 | self.total_size = self.num_samples * self.num_replicas
95 | self.shuffle = shuffle
96 | self.seed = seed
97 | self.weights = self.dataset.weights
98 | self.replacement = replacement
99 |
100 | # def calculate_weights(self, targets):
101 | # class_sample_count = torch.tensor(
102 | # [(targets == t).sum() for t in torch.unique(targets, sorted=True)])
103 | # weight = 1. / class_sample_count.double()
104 | # samples_weight = torch.tensor([weight[t] for t in targets])
105 | # return samples_weight
106 |
107 | def __iter__(self) -> Iterator[T_co]:
108 | if self.shuffle:
109 | # deterministically shuffle based on epoch and seed
110 | g = torch.Generator()
111 | g.manual_seed(self.seed + self.epoch)
112 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
113 | else:
114 | indices = list(range(len(self.dataset))) # type: ignore[arg-type]
115 |
116 | if not self.drop_last:
117 | # add extra samples to make it evenly divisible
118 | padding_size = self.total_size - len(indices)
119 | if padding_size <= len(indices):
120 | indices += indices[:padding_size]
121 | else:
122 | indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
123 | else:
124 | # remove tail of data to make it evenly divisible.
125 | indices = indices[:self.total_size]
126 | assert len(indices) == self.total_size
127 |
128 | # subsample
129 | indices = indices[self.rank:self.total_size:self.num_replicas]
130 | assert len(indices) == self.num_samples
131 |
132 | # subsample weights
133 | # targets = self.targets[indices]
134 | weights = self.weights[indices][:, 0]
135 | assert len(weights) == self.num_samples
136 |
137 | ###########################################################################
138 |         # torch.multinomial supports at most 2^24 categories; for larger sets we either chunk the weights or fall back to numpy's random choice (used below).
139 | # subsample_balanced_indicies = torch.multinomial(weights, self.num_samples, self.replacement)
140 | ###########################################################################
141 | # using random choices
142 | rand_tensor = np.random.choice(range(0, len(weights)),
143 | size=self.num_samples,
144 | p=weights.numpy() / torch.sum(weights).numpy(),
145 | replace=self.replacement)
146 | subsample_balanced_indicies = torch.from_numpy(rand_tensor)
147 | dataset_indices = torch.tensor(indices)[subsample_balanced_indicies]
148 |
149 | return iter(dataset_indices.tolist())
150 |
151 |
152 | def __len__(self) -> int:
153 | return self.num_samples
154 |
155 | def set_epoch(self, epoch: int) -> None:
156 | r"""
157 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
158 | use a different random ordering for each epoch. Otherwise, the next iteration of this
159 | sampler will yield the same ordering.
160 |
161 | Args:
162 | epoch (int): Epoch number.
163 | """
164 | self.epoch = epoch
--------------------------------------------------------------------------------
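A usage sketch mirroring the docstring example above. The sampler reads a 2-D per-sample weights tensor from the dataset (weights[indices][:, 0] in __iter__), so the toy dataset below exposes one; it also assumes torch.distributed has already been initialized (e.g. via torchrun) so rank and world size can be inferred.

import torch
from torch.utils.data import DataLoader, Dataset
from datasets import DistributedWeightedSampler

class ToyWeightedDataset(Dataset):
    """Stand-in dataset exposing the .weights attribute the sampler reads."""
    def __init__(self, n=1000):
        self.data = torch.arange(n)
        self.weights = torch.rand(n, 1)   # one weight row per sample
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

dataset = ToyWeightedDataset()
sampler = DistributedWeightedSampler(dataset)   # rank / world size from torch.distributed
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)   # re-seed the per-epoch shuffle
    for batch in loader:
        pass
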
/datasets/video_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import numpy as np
4 | import os
5 | from PIL import Image
6 | from einops import rearrange, reduce, repeat
7 | from torchvision import transforms as T
8 | import glob
9 | import cv2
10 |
11 | # Dataset that loads video frames together with optional masks and optical flow.
12 | class VideoDataset(Dataset):
13 |
14 | def __init__(self,
15 | root_dir,
16 | split='train',
17 | img_wh=(504, 378),
18 | mask_dir=None,
19 | flow_dir=None,
20 | canonical_wh=None,
21 | ref_idx=None,
22 | canonical_dir=None,
23 | test=False):
24 | self.test = test
25 | self.root_dir = root_dir
26 | self.split = split
27 | self.img_wh = img_wh
28 | self.mask_dir = mask_dir
29 | self.flow_dir = flow_dir
30 | self.canonical_wh = canonical_wh
31 | self.ref_idx = ref_idx
32 | self.canonical_dir = canonical_dir
33 | self.read_meta()
34 |
35 | def read_meta(self):
36 | all_images_path = []
37 | self.ts_w = []
38 | self.all_images = []
39 | h = self.img_wh[1]
40 | w = self.img_wh[0]
41 | # construct grid
42 | grid = np.indices((h, w)).astype(np.float32)
43 | # normalize
44 | grid[0,:,:] = grid[0,:,:] / h
45 | grid[1,:,:] = grid[1,:,:] / w
46 | self.grid = torch.from_numpy(rearrange(grid, 'c h w -> (h w) c'))
47 | warp_code = 1
48 | for input_image_path in sorted(glob.glob(f'{self.root_dir}/*')):
49 | print(input_image_path)
50 | all_images_path.append(input_image_path)
51 | self.ts_w.append(torch.Tensor([warp_code]).long())
52 | warp_code += 1
53 |
54 | if self.canonical_wh:
55 | h_c = self.canonical_wh[1]
56 | w_c = self.canonical_wh[0]
57 | grid_c = np.indices((h_c, w_c)).astype(np.float32)
58 | grid_c[0,:,:] = (grid_c[0,:,:] - (h_c - h) / 2) / h
59 | grid_c[1,:,:] = (grid_c[1,:,:] - (w_c - w) / 2) / w
60 | self.grid_c = torch.from_numpy(rearrange(grid_c, 'c h w -> (h w) c'))
61 | else:
62 | self.grid_c = self.grid
63 | self.canonical_wh = self.img_wh
64 |
65 | if self.mask_dir:
66 | self.all_masks = []
67 | if self.flow_dir:
68 | self.all_flows = []
69 | else:
70 | self.all_flows = None
71 |
72 | if self.split == 'train' or self.split == 'val':
73 | if self.canonical_dir is not None:
74 | all_images_path_ = sorted(glob.glob(f'{self.canonical_dir}/*.png'))
75 | self.canonical_img = []
76 | for input_image_path in all_images_path_:
77 | input_image = cv2.imread(input_image_path)
78 | input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
79 | input_image = cv2.resize(input_image, (self.canonical_wh[0], self.canonical_wh[1]), interpolation = cv2.INTER_AREA)
80 | input_image_tensor = torch.from_numpy(input_image).float() / 256
81 | self.canonical_img.append(input_image_tensor)
82 | self.canonical_img = torch.stack(self.canonical_img, dim=0)
83 |
84 | for input_image_path in all_images_path:
85 | input_image = cv2.imread(input_image_path)
86 | input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
87 | input_image = cv2.resize(input_image, (self.img_wh[0], self.img_wh[1]), interpolation = cv2.INTER_AREA)
88 | input_image_tensor = torch.from_numpy(input_image).float() / 256
89 | self.all_images.append(input_image_tensor)
90 | if self.mask_dir:
91 | input_image_name = input_image_path.split("/")[-1][:-4]
92 | for i in range(len(self.mask_dir)):
93 | input_mask = cv2.imread(f'{self.mask_dir[i]}/{input_image_name}.png')
94 | input_mask = cv2.resize(input_mask, (self.img_wh[0], self.img_wh[1]), interpolation = cv2.INTER_AREA)
95 | input_mask_tensor = torch.from_numpy(input_mask).float() / 256
96 | self.all_masks.append(input_mask_tensor)
97 |
98 | if self.split == 'val':
99 | input_image = cv2.imread(all_images_path[0])
100 | input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
101 | input_image = cv2.resize(input_image, (self.img_wh[0], self.img_wh[1]), interpolation = cv2.INTER_AREA)
102 | input_image_tensor = torch.from_numpy(input_image).float() / 256
103 | self.all_images.append(input_image_tensor)
104 | if self.mask_dir:
105 | input_image_name = all_images_path[0].split("/")[-1][:-4]
106 | for i in range(len(self.mask_dir)):
107 | input_mask = cv2.imread(f'{self.mask_dir[i]}/{input_image_name}.png')
108 | input_mask = cv2.resize(input_mask, (self.img_wh[0], self.img_wh[1]), interpolation = cv2.INTER_AREA)
109 | input_mask_tensor = torch.from_numpy(input_mask).float() / 256
110 | self.all_masks.append(input_mask_tensor)
111 |
112 | if self.flow_dir:
113 | for input_image_path in sorted(glob.glob(f'{self.flow_dir}/*npy')):
114 | flow_load=np.load(input_image_path) # (1, 2, h, w)
115 | flow_tensor=torch.from_numpy(flow_load).float()[:, [1, 0]]
116 | flow_tensor=torch.nn.functional.interpolate(flow_tensor,size=(self.img_wh[1],self.img_wh[0]))
117 | H_,W_=flow_load.shape[-2],flow_load.shape[-1]
118 | flow_tensor=flow_tensor.reshape(2,-1).transpose(1,0)
119 | flow_tensor[..., 0] /= W_
120 | flow_tensor[..., 1] /= H_
121 | self.all_flows.append(flow_tensor)
122 |
123 | i = 0
124 | for input_image_path in sorted(glob.glob(f'{self.flow_dir}_confidence/*npy')):
125 | flow_load=np.load(input_image_path)
126 | flow_tensor=torch.from_numpy(flow_load).float()
127 | flow_tensor=torch.nn.functional.interpolate(flow_tensor,size=(self.img_wh[1],self.img_wh[0]))
128 | flow_tensor=flow_tensor.reshape(1,-1).transpose(1,0)
129 | flow_tensor = flow_tensor.sum(dim=-1) < 0.05
130 | self.all_flows[i][flow_tensor] = 5
131 | i += 1
132 |
133 | if self.split == 'val':
134 | self.ref_idx = 0
135 |
136 | def __len__(self):
137 | if self.test:
138 | return len(self.all_images)
139 | return 200 * len(self.all_images)
140 |
141 | def __getitem__(self, idx):
142 | if self.split == 'train' or self.split == 'val':
143 | idx = idx % len(self.all_images)
144 | sample = {'rgbs': self.all_images[idx],
145 | 'canonical_img': self.all_images[idx] if self.canonical_dir is None else self.canonical_img,
146 | 'ts_w': self.ts_w[idx],
147 | 'grid': self.grid,
148 | 'canonical_wh': self.canonical_wh,
149 | 'img_wh': self.img_wh,
150 | 'masks': self.all_masks[len(self.mask_dir)*idx:len(self.mask_dir)*idx+len(self.mask_dir)] if self.mask_dir else [torch.ones((self.img_wh[1], self.img_wh[0], 1))],
151 |                   'flows': self.all_flows[idx] if (idx
--------------------------------------------------------------------------------
/docs/static/js/bulma-slider.js:
--------------------------------------------------------------------------------
106 |     var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
107 |
108 | _classCallCheck(this, bulmaSlider);
109 |
110 | var _this = _possibleConstructorReturn(this, (bulmaSlider.__proto__ || Object.getPrototypeOf(bulmaSlider)).call(this));
111 |
112 | _this.element = typeof selector === 'string' ? document.querySelector(selector) : selector;
113 | // An invalid selector or non-DOM node has been provided.
114 | if (!_this.element) {
115 | throw new Error('An invalid selector or non-DOM node has been provided.');
116 | }
117 |
118 | _this._clickEvents = ['click'];
119 | /// Set default options and merge with instance defined
120 | _this.options = _extends({}, options);
121 |
122 | _this.onSliderInput = _this.onSliderInput.bind(_this);
123 |
124 | _this.init();
125 | return _this;
126 | }
127 |
128 | /**
129 | * Initiate all DOM element containing selector
130 | * @method
131 | * @return {Array} Array of all slider instances
132 | */
133 |
134 |
135 | _createClass(bulmaSlider, [{
136 | key: 'init',
137 |
138 |
139 | /**
140 | * Initiate plugin
141 | * @method init
142 | * @return {void}
143 | */
144 | value: function init() {
145 | this._id = 'bulmaSlider' + new Date().getTime() + Math.floor(Math.random() * Math.floor(9999));
146 | this.output = this._findOutputForSlider();
147 |
148 | this._bindEvents();
149 |
150 | if (this.output) {
151 | if (this.element.classList.contains('has-output-tooltip')) {
152 | // Get new output position
153 | var newPosition = this._getSliderOutputPosition();
154 |
155 | // Set output position
156 | this.output.style['left'] = newPosition.position;
157 | }
158 | }
159 |
160 | this.emit('bulmaslider:ready', this.element.value);
161 | }
162 | }, {
163 | key: '_findOutputForSlider',
164 | value: function _findOutputForSlider() {
165 | var _this2 = this;
166 |
167 | var result = null;
168 | var outputs = document.getElementsByTagName('output') || [];
169 |
170 | Array.from(outputs).forEach(function (output) {
171 | if (output.htmlFor == _this2.element.getAttribute('id')) {
172 | result = output;
173 | return true;
174 | }
175 | });
176 | return result;
177 | }
178 | }, {
179 | key: '_getSliderOutputPosition',
180 | value: function _getSliderOutputPosition() {
181 | // Update output position
182 | var newPlace, minValue;
183 |
184 | var style = window.getComputedStyle(this.element, null);
185 | // Measure width of range input
186 | var sliderWidth = parseInt(style.getPropertyValue('width'), 10);
187 |
188 | // Figure out placement percentage between left and right of input
189 | if (!this.element.getAttribute('min')) {
190 | minValue = 0;
191 | } else {
192 | minValue = this.element.getAttribute('min');
193 | }
194 | var newPoint = (this.element.value - minValue) / (this.element.getAttribute('max') - minValue);
195 |
196 | // Prevent bubble from going beyond left or right (unsupported browsers)
197 | if (newPoint < 0) {
198 | newPlace = 0;
199 | } else if (newPoint > 1) {
200 | newPlace = sliderWidth;
201 | } else {
202 | newPlace = sliderWidth * newPoint;
203 | }
204 |
205 | return {
206 | 'position': newPlace + 'px'
207 | };
208 | }
209 |
210 | /**
211 | * Bind all events
212 | * @method _bindEvents
213 | * @return {void}
214 | */
215 |
216 | }, {
217 | key: '_bindEvents',
218 | value: function _bindEvents() {
219 | if (this.output) {
220 | // Add event listener to update output when slider value change
221 | this.element.addEventListener('input', this.onSliderInput, false);
222 | }
223 | }
224 | }, {
225 | key: 'onSliderInput',
226 | value: function onSliderInput(e) {
227 | e.preventDefault();
228 |
229 | if (this.element.classList.contains('has-output-tooltip')) {
230 | // Get new output position
231 | var newPosition = this._getSliderOutputPosition();
232 |
233 | // Set output position
234 | this.output.style['left'] = newPosition.position;
235 | }
236 |
237 | // Check for prefix and postfix
238 | var prefix = this.output.hasAttribute('data-prefix') ? this.output.getAttribute('data-prefix') : '';
239 | var postfix = this.output.hasAttribute('data-postfix') ? this.output.getAttribute('data-postfix') : '';
240 |
241 | // Update output with slider value
242 | this.output.value = prefix + this.element.value + postfix;
243 |
244 | this.emit('bulmaslider:ready', this.element.value);
245 | }
246 | }], [{
247 | key: 'attach',
248 | value: function attach() {
249 | var _this3 = this;
250 |
251 | var selector = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : 'input[type="range"].slider';
252 | var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
253 |
254 | var instances = new Array();
255 |
256 | var elements = isString(selector) ? document.querySelectorAll(selector) : Array.isArray(selector) ? selector : [selector];
257 | elements.forEach(function (element) {
258 | if (typeof element[_this3.constructor.name] === 'undefined') {
259 | var instance = new bulmaSlider(element, options);
260 | element[_this3.constructor.name] = instance;
261 | instances.push(instance);
262 | } else {
263 | instances.push(element[_this3.constructor.name]);
264 | }
265 | });
266 |
267 | return instances;
268 | }
269 | }]);
270 |
271 | return bulmaSlider;
272 | }(__WEBPACK_IMPORTED_MODULE_0__events__["a" /* default */]);
273 |
274 | /* harmony default export */ __webpack_exports__["default"] = (bulmaSlider);
275 |
276 | /***/ }),
277 | /* 1 */
278 | /***/ (function(module, __webpack_exports__, __webpack_require__) {
279 |
280 | "use strict";
281 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }();
282 |
283 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } }
284 |
285 | var EventEmitter = function () {
286 | function EventEmitter() {
287 | var listeners = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : [];
288 |
289 | _classCallCheck(this, EventEmitter);
290 |
291 | this._listeners = new Map(listeners);
292 | this._middlewares = new Map();
293 | }
294 |
295 | _createClass(EventEmitter, [{
296 | key: "listenerCount",
297 | value: function listenerCount(eventName) {
298 | if (!this._listeners.has(eventName)) {
299 | return 0;
300 | }
301 |
302 | var eventListeners = this._listeners.get(eventName);
303 | return eventListeners.length;
304 | }
305 | }, {
306 | key: "removeListeners",
307 | value: function removeListeners() {
308 | var _this = this;
309 |
310 | var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null;
311 | var middleware = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : false;
312 |
313 | if (eventName !== null) {
314 | if (Array.isArray(eventName)) {
315 | name.forEach(function (e) {
316 | return _this.removeListeners(e, middleware);
317 | });
318 | } else {
319 | this._listeners.delete(eventName);
320 |
321 | if (middleware) {
322 | this.removeMiddleware(eventName);
323 | }
324 | }
325 | } else {
326 | this._listeners = new Map();
327 | }
328 | }
329 | }, {
330 | key: "middleware",
331 | value: function middleware(eventName, fn) {
332 | var _this2 = this;
333 |
334 | if (Array.isArray(eventName)) {
335 | name.forEach(function (e) {
336 | return _this2.middleware(e, fn);
337 | });
338 | } else {
339 | if (!Array.isArray(this._middlewares.get(eventName))) {
340 | this._middlewares.set(eventName, []);
341 | }
342 |
343 | this._middlewares.get(eventName).push(fn);
344 | }
345 | }
346 | }, {
347 | key: "removeMiddleware",
348 | value: function removeMiddleware() {
349 | var _this3 = this;
350 |
351 | var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null;
352 |
353 | if (eventName !== null) {
354 | if (Array.isArray(eventName)) {
355 | name.forEach(function (e) {
356 | return _this3.removeMiddleware(e);
357 | });
358 | } else {
359 | this._middlewares.delete(eventName);
360 | }
361 | } else {
362 | this._middlewares = new Map();
363 | }
364 | }
365 | }, {
366 | key: "on",
367 | value: function on(name, callback) {
368 | var _this4 = this;
369 |
370 | var once = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
371 |
372 | if (Array.isArray(name)) {
373 | name.forEach(function (e) {
374 | return _this4.on(e, callback);
375 | });
376 | } else {
377 | name = name.toString();
378 | var split = name.split(/,|, | /);
379 |
380 | if (split.length > 1) {
381 | split.forEach(function (e) {
382 | return _this4.on(e, callback);
383 | });
384 | } else {
385 | if (!Array.isArray(this._listeners.get(name))) {
386 | this._listeners.set(name, []);
387 | }
388 |
389 | this._listeners.get(name).push({ once: once, callback: callback });
390 | }
391 | }
392 | }
393 | }, {
394 | key: "once",
395 | value: function once(name, callback) {
396 | this.on(name, callback, true);
397 | }
398 | }, {
399 | key: "emit",
400 | value: function emit(name, data) {
401 | var _this5 = this;
402 |
403 | var silent = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false;
404 |
405 | name = name.toString();
406 | var listeners = this._listeners.get(name);
407 | var middlewares = null;
408 | var doneCount = 0;
409 | var execute = silent;
410 |
411 | if (Array.isArray(listeners)) {
412 | listeners.forEach(function (listener, index) {
413 | // Start Middleware checks unless we're doing a silent emit
414 | if (!silent) {
415 | middlewares = _this5._middlewares.get(name);
416 | // Check and execute Middleware
417 | if (Array.isArray(middlewares)) {
418 | middlewares.forEach(function (middleware) {
419 | middleware(data, function () {
420 | var newData = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null;
421 |
422 | if (newData !== null) {
423 | data = newData;
424 | }
425 | doneCount++;
426 | }, name);
427 | });
428 |
429 | if (doneCount >= middlewares.length) {
430 | execute = true;
431 | }
432 | } else {
433 | execute = true;
434 | }
435 | }
436 |
437 | // If Middleware checks have been passed, execute
438 | if (execute) {
439 | if (listener.once) {
440 | listeners[index] = null;
441 | }
442 | listener.callback(data);
443 | }
444 | });
445 |
446 | // Dirty way of removing used Events
447 | while (listeners.indexOf(null) !== -1) {
448 | listeners.splice(listeners.indexOf(null), 1);
449 | }
450 | }
451 | }
452 | }]);
453 |
454 | return EventEmitter;
455 | }();
456 |
457 | /* harmony default export */ __webpack_exports__["a"] = (EventEmitter);
458 |
459 | /***/ })
460 | /******/ ])["default"];
461 | });
--------------------------------------------------------------------------------
/docs/static/js/bulma-slider.min.js:
--------------------------------------------------------------------------------
1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default});
--------------------------------------------------------------------------------
/docs/static/js/index.js:
--------------------------------------------------------------------------------
1 | window.HELP_IMPROVE_VIDEOJS = false;
2 |
3 | var INTERP_BASE = "./static/interpolation/stacked";
4 | var NUM_INTERP_FRAMES = 240;
5 |
6 | var interp_images = [];
7 | function preloadInterpolationImages() {
8 | for (var i = 0; i < NUM_INTERP_FRAMES; i++) {
9 | var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg';
10 | interp_images[i] = new Image();
11 | interp_images[i].src = path;
12 | }
13 | }
14 |
15 | function setInterpolationImage(i) {
16 | var image = interp_images[i];
17 | image.ondragstart = function() { return false; };
18 | image.oncontextmenu = function() { return false; };
19 | $('#interpolation-image-wrapper').empty().append(image);
20 | }
21 |
22 |
23 | $(document).ready(function() {
24 | // Check for click events on the navbar burger icon
25 | $(".navbar-burger").click(function() {
26 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu"
27 | $(".navbar-burger").toggleClass("is-active");
28 | $(".navbar-menu").toggleClass("is-active");
29 |
30 | });
31 |
32 | var options = {
33 | slidesToScroll: 1,
34 | slidesToShow: 3,
35 | loop: true,
36 | infinite: true,
37 | autoplay: false,
38 | autoplaySpeed: 3000,
39 | }
40 |
41 | // Initialize all div with carousel class
42 | var carousels = bulmaCarousel.attach('.carousel', options);
43 |
44 | // Loop on each carousel initialized
45 | for(var i = 0; i < carousels.length; i++) {
46 | // Add listener to event
47 | carousels[i].on('before:show', state => {
48 | console.log(state);
49 | });
50 | }
51 |
52 | // Access to bulmaCarousel instance of an element
53 | var element = document.querySelector('#my-element');
54 | if (element && element.bulmaCarousel) {
55 | // bulmaCarousel instance is available as element.bulmaCarousel
56 | element.bulmaCarousel.on('before-show', function(state) {
57 | console.log(state);
58 | });
59 | }
60 |
61 | /*var player = document.getElementById('interpolation-video');
62 | player.addEventListener('loadedmetadata', function() {
63 | $('#interpolation-slider').on('input', function(event) {
64 | console.log(this.value, player.duration);
65 | player.currentTime = player.duration / 100 * this.value;
66 | })
67 | }, false);*/
68 | preloadInterpolationImages();
69 |
70 | $('#interpolation-slider').on('input', function(event) {
71 | setInterpolationImage(this.value);
72 | });
73 | setInterpolationImage(0);
74 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1);
75 |
76 | bulmaSlider.attach();
77 |
78 | })
79 |
--------------------------------------------------------------------------------
/docs/static/js/video_comparison.js:
--------------------------------------------------------------------------------
1 | // This is based on: http://thenewcode.com/364/Interactive-Before-and-After-Video-Comparison-in-HTML5-Canvas
2 | // With additional modifications based on: https://jsfiddle.net/7sk5k4gp/13/
3 |
4 | function playVids(videoId) {
5 | var videoMerge = document.getElementById(videoId + "Merge");
6 | var vid = document.getElementById(videoId);
7 |
8 | var position = 0.5;
9 | var vidWidth = vid.videoWidth/2;
10 | var vidHeight = vid.videoHeight;
11 |
12 | var mergeContext = videoMerge.getContext("2d");
13 |
14 |
15 | if (vid.readyState > 3) {
16 | vid.play();
17 |
18 | function trackLocation(e) {
19 | // Normalize to [0, 1]
20 | bcr = videoMerge.getBoundingClientRect();
21 | position = ((e.pageX - bcr.x) / bcr.width);
22 | }
23 | function trackLocationTouch(e) {
24 | // Normalize to [0, 1]
25 | bcr = videoMerge.getBoundingClientRect();
26 | position = ((e.touches[0].pageX - bcr.x) / bcr.width);
27 | }
28 |
29 | videoMerge.addEventListener("mousemove", trackLocation, false);
30 | videoMerge.addEventListener("touchstart", trackLocationTouch, false);
31 | videoMerge.addEventListener("touchmove", trackLocationTouch, false);
32 |
33 |
34 | function drawLoop() {
35 | mergeContext.drawImage(vid, 0, 0, vidWidth, vidHeight, 0, 0, vidWidth, vidHeight);
36 | var colStart = (vidWidth * position).clamp(0.0, vidWidth);
37 | var colWidth = (vidWidth - (vidWidth * position)).clamp(0.0, vidWidth);
38 | mergeContext.drawImage(vid, colStart+vidWidth, 0, colWidth, vidHeight, colStart, 0, colWidth, vidHeight);
39 | requestAnimationFrame(drawLoop);
40 |
41 |
42 | var arrowLength = 0.07 * vidHeight;
43 | var arrowheadWidth = 0.020 * vidHeight;
44 | var arrowheadLength = 0.04 * vidHeight;
45 | var arrowPosY = vidHeight / 10;
46 | var arrowWidth = 0.007 * vidHeight;
47 | var currX = vidWidth * position;
48 |
49 | // Draw circle
50 | mergeContext.arc(currX, arrowPosY, arrowLength*0.7, 0, Math.PI * 2, false);
51 | mergeContext.fillStyle = "#FFD79340";
52 | mergeContext.fill()
53 | //mergeContext.strokeStyle = "#444444";
54 | //mergeContext.stroke()
55 |
56 | // Draw border
57 | mergeContext.beginPath();
58 | mergeContext.moveTo(vidWidth*position, 0);
59 | mergeContext.lineTo(vidWidth*position, vidHeight);
60 | mergeContext.closePath()
61 | mergeContext.strokeStyle = "#444444";
62 | mergeContext.lineWidth = 3;
63 | mergeContext.stroke();
64 |
65 | // Draw arrow
66 | mergeContext.beginPath();
67 | mergeContext.moveTo(currX, arrowPosY - arrowWidth/2);
68 |
69 | // Move right until meeting arrow head
70 | mergeContext.lineTo(currX + arrowLength/2 - arrowheadLength/2, arrowPosY - arrowWidth/2);
71 |
72 | // Draw right arrow head
73 | mergeContext.lineTo(currX + arrowLength/2 - arrowheadLength/2, arrowPosY - arrowheadWidth/2);
74 | mergeContext.lineTo(currX + arrowLength/2, arrowPosY);
75 | mergeContext.lineTo(currX + arrowLength/2 - arrowheadLength/2, arrowPosY + arrowheadWidth/2);
76 | mergeContext.lineTo(currX + arrowLength/2 - arrowheadLength/2, arrowPosY + arrowWidth/2);
77 |
78 | // Go back to the left until meeting left arrow head
79 | mergeContext.lineTo(currX - arrowLength/2 + arrowheadLength/2, arrowPosY + arrowWidth/2);
80 |
81 | // Draw left arrow head
82 | mergeContext.lineTo(currX - arrowLength/2 + arrowheadLength/2, arrowPosY + arrowheadWidth/2);
83 | mergeContext.lineTo(currX - arrowLength/2, arrowPosY);
84 | mergeContext.lineTo(currX - arrowLength/2 + arrowheadLength/2, arrowPosY - arrowheadWidth/2);
85 | mergeContext.lineTo(currX - arrowLength/2 + arrowheadLength/2, arrowPosY);
86 |
87 | mergeContext.lineTo(currX - arrowLength/2 + arrowheadLength/2, arrowPosY - arrowWidth/2);
88 | mergeContext.lineTo(currX, arrowPosY - arrowWidth/2);
89 |
90 | mergeContext.closePath();
91 |
92 | mergeContext.fillStyle = "#444444";
93 | mergeContext.fill();
94 |
95 |
96 |
97 | }
98 | requestAnimationFrame(drawLoop);
99 | }
100 | }
101 |
102 | Number.prototype.clamp = function(min, max) {
103 | return Math.min(Math.max(this, min), max);
104 | };
105 |
106 |
107 | function resizeAndPlay(element)
108 | {
109 | var cv = document.getElementById(element.id + "Merge");
110 | cv.width = element.videoWidth/2;
111 | cv.height = element.videoHeight;
112 | element.play();
113 | element.style.height = "0px"; // Hide video without stopping it
114 |
115 | playVids(element.id);
116 | }
117 |
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1080x540/beauty_0_base_video_hash_transformed_dual(1).mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x540/beauty_0_base_video_hash_transformed_dual(1).mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1080x540/beauty_0_base_video_hash_transformed_dual(2).mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x540/beauty_0_base_video_hash_transformed_dual(2).mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1080x540/beauty_1_base_transformed_dual(1).mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x540/beauty_1_base_transformed_dual(1).mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1080x540/beauty_1_base_transformed_dual(2).mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x540/beauty_1_base_transformed_dual(2).mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1080x540/rainbow.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x540/rainbow.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1080x540/tifa.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x540/tifa.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1080x960/dog.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x960/dog.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1080x960/scene_1_base_transformed_dual.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x960/scene_1_base_transformed_dual.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1080x960/smoke_colorful.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x960/smoke_colorful.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1080x960/smoke_ink.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x960/smoke_ink.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1920x540/cloud_atlas4_base_transformed_dual.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1920x540/cloud_atlas4_base_transformed_dual.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1920x540/diamond_1_base_transformed_dual.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1920x540/diamond_1_base_transformed_dual.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1920x540/lemon_earth.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1920x540/lemon_earth.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1920x540/long_season.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1920x540/long_season.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1920x540/scene_3_base_transformed_dual.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1920x540/scene_3_base_transformed_dual.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/1920x540/titanic.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1920x540/titanic.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/slider/slider.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/slider/slider.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/teaser/cloud_atlas_SR_dual.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/teaser/cloud_atlas_SR_dual.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/teaser/segtrack.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/teaser/segtrack.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/teaser/teaser.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/teaser/teaser.mp4
--------------------------------------------------------------------------------
/docs/static/video_demos_compressed/teaser/tracking_zoom.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/teaser/tracking_zoom.mp4
--------------------------------------------------------------------------------
/docs/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/teaser.gif
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import torchvision
4 | from einops import rearrange, reduce, repeat
5 |
6 |
7 | class MSELoss(nn.Module):
8 | def __init__(self, coef=1):
9 | super().__init__()
10 | self.coef = coef
11 | self.loss = nn.MSELoss(reduction='mean')
12 |
13 | def forward(self, inputs, targets):
14 | loss = self.loss(inputs, targets)
15 | return self.coef * loss
16 |
17 |
18 | def rgb_to_gray(image):
19 | gray_image = (0.299 * image[:, 0, :, :] + 0.587 * image[:, 1, :, :] +
20 | 0.114 * image[:, 2, :, :])
21 | gray_image = gray_image.unsqueeze(1)
22 |
23 | return gray_image
24 |
25 |
26 | def compute_gradient_loss(pred, gt, mask):
27 |     assert pred.shape == gt.shape, "pred and gt must have the same shape"
28 |
29 | pred = rgb_to_gray(pred)
30 | gt = rgb_to_gray(gt)
31 |
32 | sobel_kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=pred.dtype, device=pred.device)
33 | sobel_kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=pred.dtype, device=pred.device)
34 |
35 | gradient_a_x = torch.nn.functional.conv2d(pred.repeat(1,3,1,1), sobel_kernel_x.unsqueeze(0).unsqueeze(0).repeat(1,3,1,1), padding=1)/3
36 | gradient_a_y = torch.nn.functional.conv2d(pred.repeat(1,3,1,1), sobel_kernel_y.unsqueeze(0).unsqueeze(0).repeat(1,3,1,1), padding=1)/3
37 | # gradient_a_magnitude = torch.sqrt(gradient_a_x ** 2 + gradient_a_y ** 2)
38 |
39 | gradient_b_x = torch.nn.functional.conv2d(gt.repeat(1,3,1,1), sobel_kernel_x.unsqueeze(0).unsqueeze(0).repeat(1,3,1,1), padding=1)/3
40 | gradient_b_y = torch.nn.functional.conv2d(gt.repeat(1,3,1,1), sobel_kernel_y.unsqueeze(0).unsqueeze(0).repeat(1,3,1,1), padding=1)/3
41 | # gradient_b_magnitude = torch.sqrt(gradient_b_x ** 2 + gradient_b_y ** 2)
42 |
43 | pred_grad = torch.cat([gradient_a_x, gradient_a_y], dim=1)
44 | gt_grad = torch.cat([gradient_b_x, gradient_b_y], dim=1)
45 |
46 | gradient_difference = torch.abs(pred_grad - gt_grad).mean(dim=1,keepdim=True)[mask].sum()/(mask.sum()+1e-8)
47 |
48 | return gradient_difference
49 |
50 |
51 | loss_dict = {'mse': MSELoss}
52 |
--------------------------------------------------------------------------------
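A minimal sketch of how the losses above can be combined; the tensor shapes and the 0.1 weight are illustrative (0.1 matches the default --grad_loss in opt.py), and the real training wiring lives in train.py. Note that compute_gradient_loss expects a boolean mask with the same spatial size as the frames.

import torch
from losses import MSELoss, compute_gradient_loss

# Illustrative tensors: a batch of 4 RGB frames at 64x64 plus a boolean mask
# selecting the pixels the gradient loss should cover.
pred = torch.rand(4, 3, 64, 64, requires_grad=True)
gt = torch.rand(4, 3, 64, 64)
mask = torch.ones(4, 1, 64, 64, dtype=torch.bool)

mse = MSELoss(coef=1.0)                            # photometric term
rec_loss = mse(pred, gt)
grad_loss = compute_gradient_loss(pred, gt, mask)  # Sobel image-gradient term

total = rec_loss + 0.1 * grad_loss
total.backward()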
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from kornia.losses import ssim as dssim
3 | from skimage.metrics import structural_similarity
4 | from einops import rearrange
5 | import numpy as np
6 |
7 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'):
8 | value = (image_pred-image_gt)**2
9 | if valid_mask is not None:
10 | value = value[valid_mask]
11 | if reduction == 'mean':
12 | return torch.mean(value)
13 | return value
14 |
15 | def psnr(image_pred, image_gt, valid_mask=None, reduction='mean'):
16 | return -10*torch.log10(mse(image_pred, image_gt, valid_mask, reduction))
17 |
18 | def ssim(image_pred, image_gt, reduction='mean'):
19 | return structural_similarity(image_pred.cpu().numpy(), image_gt, win_size=11, multichannel=True, gaussian_weights=True)
20 |
21 | def lpips(image_pred, image_gt, lpips_model):
22 | gt_lpips = image_gt * 2.0 - 1.0
23 | gt_lpips = rearrange(gt_lpips, '(b h) w c -> b c h w', b=1)
24 | gt_lpips = torch.from_numpy(gt_lpips)
25 | predict_image_lpips = image_pred.clone().detach().cpu() * 2.0 - 1.0
26 | predict_image_lpips = rearrange(predict_image_lpips, '(b h) w c -> b c h w', b=1)
27 | lpips_result = lpips_model.forward(predict_image_lpips, gt_lpips).cpu().detach().numpy()
28 | return np.squeeze(lpips_result)
29 |
--------------------------------------------------------------------------------
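A small usage sketch for the metric helpers above; the frames are random stand-ins and the mask shows one assumed way of using valid_mask. The ssim and lpips helpers additionally expect a NumPy ground truth and an LPIPS model instance, respectively, and are omitted here.

import torch
from metrics import mse, psnr

pred = torch.rand(512, 512, 3)   # predicted frame, values in [0, 1]
gt = torch.rand(512, 512, 3)     # ground-truth frame, values in [0, 1]
valid = (gt.sum(-1) > 0.1).unsqueeze(-1).expand_as(gt)  # optional validity mask

print('MSE :', mse(pred, gt).item())
print('PSNR:', psnr(pred, gt, valid_mask=valid).item())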
/models/implicit_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import math
4 | import tinycudann as tcnn
5 |
6 |
7 | def init_weights(m):
8 | if isinstance(m, nn.Linear):
9 | torch.nn.init.xavier_uniform_(m.weight)
10 | nn.init.zeros_(m.bias)
11 |
12 |
13 | class TranslationField(nn.Module):
14 | def __init__(self, D=6, W=128,
15 | in_channels_w=8, in_channels_xyz=34,
16 | skips=[4]):
17 | """
18 |         D: number of layers in the warping (translation) field MLP
19 |         W: number of hidden units in each layer
20 |         in_channels_xyz: number of input channels for xyz (2+2*8*2=34 by default)
21 |         in_channels_w: number of channels of the per-frame warp code
22 |         skips: layer indices at which to add a skip connection
23 | """
24 | super(TranslationField, self).__init__()
25 | self.D = D
26 | self.W = W
27 | self.skips = skips
28 | self.in_channels_xyz = in_channels_xyz
29 | self.in_channels_w = in_channels_w
30 | self.typ = "translation"
31 |
32 | # encoding layers
33 | for i in range(D):
34 | if i == 0:
35 | layer = nn.Linear(in_channels_xyz+self.in_channels_w, W)
36 | elif i in skips:
37 | layer = nn.Linear(W+in_channels_xyz+self.in_channels_w, W)
38 | else:
39 | layer = nn.Linear(W, W)
40 | init_weights(layer)
41 | layer = nn.Sequential(layer, nn.ReLU(True))
42 | # init the models
43 | setattr(self, f"warping_field_xyz_encoding_{i+1}", layer)
44 | out_layer = nn.Linear(W, 2)
45 | nn.init.zeros_(out_layer.bias)
46 | nn.init.uniform_(out_layer.weight, -1e-4, 1e-4)
47 | self.output = nn.Sequential(out_layer)
48 |
49 | def forward(self, x):
50 | """
51 | Encodes input xyz to warp field for points
52 |
53 | Inputs:
54 | x: (B, self.in_channels_xyz)
55 | the embedded vector of position and direction
56 | Outputs:
57 | t: warping field
58 | """
59 | input_xyz = x
60 |
61 | xyz_ = input_xyz
62 | for i in range(self.D):
63 | if i in self.skips:
64 | xyz_ = torch.cat([input_xyz, xyz_], -1)
65 | xyz_ = getattr(self, f"warping_field_xyz_encoding_{i+1}")(xyz_)
66 |
67 | t = self.output(xyz_)
68 |
69 | return t
70 |
71 |
72 | class Embedding(nn.Module):
73 | def __init__(self, in_channels, N_freqs, logscale=True, identity=True):
74 | """
75 | Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
76 |         in_channels: number of input channels of the coordinates to embed (e.g., 2 for (x, y))
77 | """
78 | super(Embedding, self).__init__()
79 | self.N_freqs = N_freqs
80 | self.annealed = False
81 | self.identity = identity
82 | self.in_channels = in_channels
83 | self.funcs = [torch.sin, torch.cos]
84 | self.out_channels = in_channels*(len(self.funcs)*N_freqs+1)
85 |
86 | if logscale:
87 | self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs)
88 | else:
89 | self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs)
90 |
91 | def forward(self, x):
92 | """
93 | Embeds x to (x, sin(2^k x), cos(2^k x), ...)
94 | Different from the paper, "x" is also in the output
95 | See https://github.com/bmild/nerf/issues/12
96 |
97 | Inputs:
98 | x: (B, self.in_channels)
99 |
100 | Outputs:
101 | out: (B, self.out_channels)
102 | """
103 | if self.identity:
104 | out = [x]
105 | else:
106 | out = []
107 | for freq in self.freq_bands:
108 | for func in self.funcs:
109 | out += [func(freq*x)]
110 |
111 | return torch.cat(out, -1)
112 |
113 |
114 | class AnnealedEmbedding(nn.Module):
115 | def __init__(self, in_channels, N_freqs, annealed_step, annealed_begin_step=0, logscale=True, identity=True):
116 | """
117 | Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
118 |         in_channels: number of input channels of the coordinates to embed (e.g., 2 for (x, y))
119 | """
120 | super(AnnealedEmbedding, self).__init__()
121 | self.N_freqs = N_freqs
122 | self.in_channels = in_channels
123 | self.annealed = True
124 | self.annealed_step = annealed_step
125 | self.annealed_begin_step = annealed_begin_step
126 | self.funcs = [torch.sin, torch.cos]
127 | self.out_channels = in_channels*(len(self.funcs)*N_freqs+1)
128 | self.index = torch.linspace(0, N_freqs-1, N_freqs)
129 | self.identity = identity
130 |
131 | if logscale:
132 | self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs)
133 | else:
134 | self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs)
135 |
136 | def forward(self, x, step):
137 | """
138 | Embeds x to (x, sin(2^k x), cos(2^k x), ...)
139 | Different from the paper, "x" is also in the output
140 | See https://github.com/bmild/nerf/issues/12
141 |
142 | Inputs:
143 | x: (B, self.in_channels)
144 |
145 | Outputs:
146 | out: (B, self.out_channels)
147 | """
148 | if self.identity:
149 | out = [x]
150 | else:
151 | out = []
152 |
153 | if self.annealed_begin_step == 0:
154 | # calculate the w for each freq bands
155 | alpha = self.N_freqs * step / float(self.annealed_step)
156 | else:
157 | if step <= self.annealed_begin_step:
158 | alpha = 0
159 | else:
160 | alpha = self.N_freqs * (step - self.annealed_begin_step) / float(
161 | self.annealed_step)
162 |
163 | for j, freq in enumerate(self.freq_bands):
164 | w = (1 - torch.cos(
165 | math.pi * torch.clamp(alpha - self.index[j], 0, 1))) / 2
166 | for func in self.funcs:
167 | out += [w * func(freq*x)]
168 |
169 | return torch.cat(out, -1)
170 |
171 |
172 | class AnnealedHash(nn.Module):
173 | def __init__(self, in_channels, annealed_step, annealed_begin_step=0, identity=True):
174 | """
175 |         Defines an annealing schedule over the levels of a hash encoding,
176 |         gradually enabling higher-resolution levels during training
177 | """
178 | super(AnnealedHash, self).__init__()
179 | self.N_freqs = 16
180 | self.in_channels = in_channels
181 | self.annealed = True
182 | self.annealed_step = annealed_step
183 | self.annealed_begin_step = annealed_begin_step
184 |
185 | self.index = torch.linspace(0, self.N_freqs - 1, self.N_freqs)
186 | self.identity = identity
187 |
188 | self.index_2 = self.index.view(-1, 1).repeat(1, 2).view(-1)
189 |
190 | def forward(self, x_embed, step):
191 | """
192 |         Applies an annealing weight to each level of the hash-encoded features,
193 |         progressively enabling higher-resolution levels as training proceeds.
194 | 
195 |         Inputs:
196 |             x_embed: (B, C) hash-encoded features
197 |             step: current training step, used to compute the annealing weights
198 | 
199 |         Outputs:
200 |             out: (B, C) annealed (re-weighted) features
201 | """
202 |
203 | if self.annealed_begin_step == 0:
204 | # calculate the w for each freq bands
205 | alpha = self.N_freqs * step / float(self.annealed_step)
206 | else:
207 | if step <= self.annealed_begin_step:
208 | alpha = 0
209 | else:
210 | alpha = self.N_freqs * (step - self.annealed_begin_step) / float(
211 | self.annealed_step)
212 |
213 | w = (1 - torch.cos(math.pi * torch.clamp(alpha * torch.ones_like(self.index_2) - self.index_2, 0, 1))) / 2
214 |
215 | out = x_embed * w.to(x_embed.device)
216 |
217 | return out
218 |
219 |
220 | class ImplicitVideo(nn.Module):
221 | def __init__(self,
222 | D=8, W=256,
223 | in_channels_xyz=34,
224 | skips=[4],
225 | out_channels=3,
226 | sigmoid_offset=0):
227 | """
228 |         D: number of layers in the implicit video MLP
229 |         W: number of hidden units in each layer
230 |         in_channels_xyz: number of input channels for xyz (2+2*8*2=34 by default)
231 |         skips: layer indices at which to add a skip connection
232 | 
233 |         out_channels: number of output channels (3 for RGB)
234 |         sigmoid_offset: offset subtracted from the sigmoid output
235 |             (useful for black-dominated videos)
236 | """
237 | super(ImplicitVideo, self).__init__()
238 | self.D = D
239 | self.W = W
240 | self.in_channels_xyz = in_channels_xyz
241 | self.skips = skips
242 | self.in_channels_xyz = self.in_channels_xyz
243 | self.sigmoid_offset = sigmoid_offset
244 |
245 | # xyz encoding layers
246 | for i in range(D):
247 | if i == 0:
248 | layer = nn.Linear(self.in_channels_xyz, W)
249 | elif i in skips:
250 | layer = nn.Linear(W+self.in_channels_xyz, W)
251 | else:
252 | layer = nn.Linear(W, W)
253 | init_weights(layer)
254 | layer = nn.Sequential(layer, nn.ReLU(True))
255 | setattr(self, f"xyz_encoding_{i+1}", layer)
256 | self.xyz_encoding_final = nn.Linear(W, W)
257 | init_weights(self.xyz_encoding_final)
258 |
259 | # output layers
260 | self.rgb = nn.Sequential(nn.Linear(W, out_channels))
261 |
262 | self.rgb.apply(init_weights)
263 |
264 | def forward(self, x):
265 | """
266 |         Maps embedded 2D coordinates to the RGB color of the canonical image.
267 |         Unlike NeRF, no density (sigma) is predicted here.
268 | 
269 |         Inputs:
270 |             x: (B, self.in_channels_xyz)
271 |                the positionally encoded (x, y) coordinates
272 | 
273 |         Outputs:
274 |             out: (B, out_channels)
275 |                 RGB values obtained by applying a sigmoid and then
276 |                 subtracting sigmoid_offset
277 | 
278 |         The result is the color of the canonical image at the queried
279 |         (possibly deformed) coordinates.
280 | """
281 | input_xyz = x
282 |
283 | xyz_ = input_xyz
284 | for i in range(self.D):
285 | if i in self.skips:
286 | xyz_ = torch.cat([input_xyz, xyz_], -1)
287 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_)
288 |
289 | xyz_encoding_final = self.xyz_encoding_final(xyz_)
290 | out = self.rgb(xyz_encoding_final)
291 |
292 | out = torch.sigmoid(out) - self.sigmoid_offset
293 |
294 | return out
295 |
296 |
297 | class ImplicitVideo_Hash(nn.Module):
298 | def __init__(self, config):
299 | super().__init__()
300 | self.encoder = tcnn.Encoding(n_input_dims=2,
301 | encoding_config=config["encoding"])
302 | self.decoder = tcnn.Network(n_input_dims=self.encoder.n_output_dims +
303 | 2,
304 | n_output_dims=3,
305 | network_config=config["network"])
306 |
307 | def forward(self, x):
308 | input = x
309 | input = self.encoder(input)
310 | input = torch.cat([x, input], dim=-1)
311 | weight = torch.ones(input.shape[-1], device=input.device).cuda()
312 | x = self.decoder(weight * input)
313 | return x
314 |
315 |
316 | class Deform_Hash3d(nn.Module):
317 | def __init__(self, config):
318 | super().__init__()
319 | self.encoder = tcnn.Encoding(n_input_dims=3,
320 | encoding_config=config["encoding_deform3d"])
321 | self.decoder = tcnn.Network(n_input_dims=self.encoder.n_output_dims + 3,
322 | n_output_dims=2,
323 | network_config=config["network_deform"])
324 |
325 | def forward(self, x, step=0, aneal_func=None):
326 | input = x
327 | input = self.encoder(input)
328 | if aneal_func is not None:
329 | input = torch.cat([x, aneal_func(input,step)], dim=-1)
330 | else:
331 | input = torch.cat([x, input], dim=-1)
332 |
333 | weight = torch.ones(input.shape[-1], device=input.device).cuda()
334 | x = self.decoder(weight * input) / 5
335 |
336 | return x
337 |
338 |
339 | class Deform_Hash3d_Warp(nn.Module):
340 | def __init__(self, config):
341 | super().__init__()
342 | self.Deform_Hash3d = Deform_Hash3d(config)
343 |
344 | def forward(self, xyt_norm, step=0,aneal_func=None):
345 | x = self.Deform_Hash3d(xyt_norm,step=step, aneal_func=aneal_func)
346 |
347 | return x
348 |
--------------------------------------------------------------------------------
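A hedged sketch of how the positional-encoding variants above fit together: Embedding produces the 34-channel inputs (2 coordinates x (2*8+1)), TranslationField predicts a 2D offset from the embedded coordinates plus an 8-dim warp code, and ImplicitVideo decodes the canonical color. Chaining them as below is only illustrative; the actual composition (and the hash-based classes, which need tinycudann and configs/hash.json) is handled in train.py.

import torch
# NOTE: importing this module requires tinycudann, since it is imported at
# module level even when only the PE-based classes are used.
from models.implicit_model import Embedding, TranslationField, ImplicitVideo

# Positional encoding of 2D coordinates with 8 frequencies:
# 2 * (2 * 8 + 1) = 34 channels, matching in_channels_xyz above.
embed_xy = Embedding(in_channels=2, N_freqs=8)

xy = torch.rand(1024, 2) * 2 - 1            # sampled coordinates in [-1, 1]
xy_embedded = embed_xy(xy)                  # (1024, 34)
w_code = torch.zeros(1024, 8)               # per-frame warp code (cf. --N_w)

deform = TranslationField(D=6, W=128, in_channels_w=8, in_channels_xyz=34)
offset = deform(torch.cat([xy_embedded, w_code], dim=-1))   # (1024, 2)

canonical = ImplicitVideo(D=8, W=256, in_channels_xyz=34)
rgb = canonical(embed_xy(xy + offset))      # (1024, 3), values in [0, 1]
print(xy_embedded.shape, offset.shape, rgb.shape)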
/opt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 |
4 |
5 | def get_opts():
6 | parser = argparse.ArgumentParser()
7 |
8 |     # General Settings
9 | parser.add_argument('--root_dir', type=str, default='Batman_masked_frames',
10 | help='root directory of dataset')
11 | parser.add_argument('--canonical_dir', type=str, default=None,
12 | help='directory of canonical dataset')
13 |
14 |     # support multiple masks as input (each mask has its own deformation field)
15 |     parser.add_argument('--mask_dir', nargs="+", type=str, default=None,
16 |                         help='mask directories of the dataset')
17 | parser.add_argument('--flow_dir', type=str,
18 | default=None,
19 |                         help='optical flow directory of the dataset')
20 | parser.add_argument('--dataset_name', type=str, default='video',
21 | choices=['video'],
22 | help='which dataset to train/val')
23 | parser.add_argument('--img_wh', nargs="+", type=int, default=[842, 512],
24 | help='resolution (img_w, img_h) of the full image')
25 | parser.add_argument('--canonical_wh', nargs="+", type=int, default=None,
26 | help='default same as the img_wh, can be set to a larger range to include more content')
27 | parser.add_argument('--ref_idx', type=int, default=None,
28 | help='manually select a frame as reference (for rigid movement)')
29 |
30 | # Deformation Setting
31 | parser.add_argument('--encode_w', default=False, action="store_true",
32 | help='whether to apply warping')
33 |
34 |     # Training Settings
35 |
36 | parser.add_argument('--batch_size', type=int, default=1,
37 | help='batch size')
38 | parser.add_argument('--num_steps', type=int, default=10000,
39 |                         help='number of training steps')
40 | parser.add_argument('--valid_iters', type=int, default=30,
41 | help='valid iters for each epoch')
42 | parser.add_argument('--valid_batches', type=int, default=0,
43 | help='valid batches for each valid process')
44 | parser.add_argument('--save_model_iters', type=int, default=5000,
45 | help='iterations to save the models')
46 | parser.add_argument('--gpus', nargs="+", type=int, default=[0],
47 | help='gpu devices')
48 |
49 |     # Test Settings
50 | parser.add_argument('--test', default=False, action="store_true",
51 | help='whether to disable identity')
52 |
53 | # Model Save and Load
54 | parser.add_argument('--ckpt_path', type=str, default=None,
55 | help='pretrained checkpoint to load (including optimizers, etc)')
56 | parser.add_argument('--prefixes_to_ignore', nargs='+', type=str, default=['loss'],
57 | help='the prefixes to ignore in the checkpoint state dict')
58 | parser.add_argument('--weight_path', type=str, default=None,
59 | help='pretrained model weight to load (do not load optimizers, etc)')
60 | parser.add_argument('--model_save_path', type=str, default='ckpts',
61 | help='save checkpoint to')
62 | parser.add_argument('--log_save_path', type=str, default='logs',
63 | help='save log to')
64 | parser.add_argument('--exp_name', type=str, default='exp',
65 | help='experiment name')
66 |
67 | # Optimize Settings
68 | parser.add_argument('--optimizer', type=str, default='adam',
69 | help='optimizer type',
70 | choices=['sgd', 'adam', 'radam', 'ranger'])
71 | parser.add_argument('--lr', type=float, default=5e-4,
72 | help='learning rate')
73 | parser.add_argument('--momentum', type=float, default=0.9,
74 | help='learning rate momentum')
75 | parser.add_argument('--weight_decay', type=float, default=0,
76 | help='weight decay')
77 | parser.add_argument('--lr_scheduler', type=str, default='steplr',
78 | help='scheduler type',
79 | choices=['steplr', 'cosine', 'poly', 'exponential'])
80 |
81 | #### params for steplr ####
82 | parser.add_argument('--decay_step', nargs='+', type=int,
83 | default=[2500, 5000, 7500],
84 | help='scheduler decay step')
85 | parser.add_argument('--decay_gamma', type=float, default=0.5,
86 | help='learning rate decay amount')
87 |
88 | #### params for warmup, only applied when optimizer == 'sgd' or 'adam'
89 | parser.add_argument('--warmup_multiplier', type=float, default=1.0,
90 | help='lr is multiplied by this factor after --warmup_epochs')
91 | parser.add_argument('--warmup_epochs', type=int, default=0,
92 |                         help='number of epochs to gradually warm up (increase) the learning rate in the optimizer')
93 |
94 | ##### annealed positional encoding ######
95 | parser.add_argument('--annealed', default=False, action="store_true",
96 | help='whether to apply annealed positional encoding (Only in the warping field)')
97 | parser.add_argument('--annealed_begin_step', type=int, default=0,
98 | help='annealed step to begin for positional encoding')
99 | parser.add_argument('--annealed_step', type=int, default=5000,
100 | help='maximum annealed step for positional encoding')
101 |
102 | ##### Additional losses ######
103 | parser.add_argument('--flow_loss', type=float, default=None,
104 | help='optical flow loss weight')
105 | parser.add_argument('--bg_loss', type=float, default=None,
106 |                         help='weight for regularizing the region outside each object mask')
107 | parser.add_argument('--grad_loss', type=float, default=0.1,
108 | help='image gradient loss weight')
109 | parser.add_argument('--flow_step', type=int, default=-1,
110 | help='Step to begin to perform flow loss.')
111 | parser.add_argument('--ref_step', type=int, default=-1,
112 | help='Step to stop reference frame loss.')
113 | parser.add_argument('--self_bg', type=bool_parser, default=False,
114 | help='Whether to use self background as bg loss.')
115 |
116 | ##### Special cases: for black-dominated images
117 | parser.add_argument('--sigmoid_offset', type=float, default=0,
118 |                         help='sigmoid offset for handling black-dominated images.')
119 |
120 | # Other miscellaneous settings.
121 | parser.add_argument('--save_deform', type=bool_parser, default=False,
122 | help='Whether to save deformation field or not.')
123 | parser.add_argument('--save_video', type=bool_parser, default=True,
124 | help='Whether to save video or not.')
125 | parser.add_argument('--fps', type=int, default=30,
126 | help='FPS of the saved video.')
127 |
128 | # Network settings for PE.
129 | parser.add_argument('--deform_D', type=int, default=6,
130 | help='The depth of deformation field MLP.')
131 | parser.add_argument('--deform_W', type=int, default=128,
132 | help='The width of deformation field MLP.')
133 | parser.add_argument('--vid_D', type=int, default=8,
134 | help='The depth of implicit video MLP.')
135 | parser.add_argument('--vid_W', type=int, default=256,
136 | help='The width of implicit video MLP.')
137 | parser.add_argument('--N_vocab_w', type=int, default=200,
138 | help='number of vocabulary for warp code in the dataset for nn.Embedding')
139 | parser.add_argument('--N_w', type=int, default=8,
140 | help='embeddings size for warping')
141 | parser.add_argument('--N_xyz_w', nargs="+", type=int, default=[8, 8],
142 | help='positional encoding frequency of deformation field')
143 |
144 | # Network settings for Hash, please see details in configs/hash.json
145 | parser.add_argument('--vid_hash', type=bool_parser, default=False,
146 | help='Whether to use hash encoding in implicit video system.')
147 | parser.add_argument('--deform_hash', type=bool_parser, default=False,
148 | help='Whether to use hash encoding in deformation field.')
149 |
150 | # Config files
151 | parser.add_argument('--config', type=str, default=None,
152 | help='path to the YAML config file.')
153 |
154 | args = parser.parse_args()
155 |
156 | if args.config is not None:
157 | with open(args.config, 'r') as f:
158 | config = yaml.safe_load(f)
159 | args_dict = vars(args)
160 | args_dict.update(config)
161 | args_new = argparse.Namespace(**args_dict)
162 | return args_new
163 |
164 | return args
165 |
166 |
167 | def bool_parser(arg):
168 | """Parses an argument to boolean."""
169 | if isinstance(arg, bool):
170 | return arg
171 | if arg is None:
172 | return False
173 | if arg.lower() in ['1', 'true', 't', 'yes', 'y']:
174 | return True
175 | if arg.lower() in ['0', 'false', 'f', 'no', 'n']:
176 | return False
177 | raise ValueError(f'`{arg}` cannot be converted to boolean!')
178 |
--------------------------------------------------------------------------------
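To illustrate the config merge at the end of get_opts(): any key in the YAML file passed via --config overwrites the corresponding argparse value. A self-contained sketch (the file path and override values are made up; the keys are arguments defined above):

import sys
import yaml
from opt import get_opts

# Write a tiny config overriding a few of the arguments defined above.
cfg = {'img_wh': [960, 540], 'num_steps': 20000, 'grad_loss': 0.05}
with open('/tmp/demo_config.yaml', 'w') as f:
    yaml.safe_dump(cfg, f)

# get_opts() reads sys.argv, so emulate a command line here.
sys.argv = ['train.py',
            '--root_dir', 'all_sequences/scene_0/scene_0',
            '--config', '/tmp/demo_config.yaml']
args = get_opts()
print(args.img_wh, args.num_steps, args.grad_loss)  # taken from the YAML
print(args.root_dir)                                # taken from the CLI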
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch_lightning==2.0.2
2 | easydict==1.10
3 | einops==0.6.1
4 | ipdb==0.13.13
5 | numpy==1.24.3
6 | opencv-python-headless==4.5.5.62
7 | PyYAML==6.0
8 | scikit-image==0.21.0
9 | scipy==1.10.1
10 | sk-video==1.1.10
11 | tensorboard==2.13.0
12 | tqdm==4.65.0
13 | torch_optimizer==0.3.0
14 | kornia==0.5.10
15 |
--------------------------------------------------------------------------------
/scripts/test_canonical.sh:
--------------------------------------------------------------------------------
1 | GPUS=0
2 |
3 | NAME=scene_0
4 | EXP_NAME=base
5 |
6 | ROOT_DIRECTORY="all_sequences/$NAME/$NAME"
7 | LOG_SAVE_PATH="logs/test_all_sequences/$NAME"
8 |
9 | MASK_DIRECTORY="all_sequences/$NAME/${NAME}_masks_0 all_sequences/$NAME/${NAME}_masks_1"
10 |
11 | CANONICAL_DIR="all_sequences/${NAME}/${EXP_NAME}_control"
12 |
13 | WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/${NAME}.ckpt
14 | # WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/step=10000.ckpt
15 |
16 | python train.py --test --encode_w \
17 | --root_dir $ROOT_DIRECTORY \
18 | --log_save_path $LOG_SAVE_PATH \
19 | --mask_dir $MASK_DIRECTORY \
20 | --weight_path $WEIGHT_PATH \
21 | --gpus $GPUS \
22 | --canonical_dir $CANONICAL_DIR \
23 | --config configs/${NAME}/${EXP_NAME}.yaml \
24 | --exp_name $EXP_NAME
25 |
--------------------------------------------------------------------------------
/scripts/test_multi.sh:
--------------------------------------------------------------------------------
1 | GPUS=0
2 |
3 | NAME=scene_0
4 | EXP_NAME=base
5 |
6 | ROOT_DIRECTORY="all_sequences/$NAME/$NAME"
7 | LOG_SAVE_PATH="logs/test_all_sequences/$NAME"
8 |
9 | MASK_DIRECTORY="all_sequences/$NAME/${NAME}_masks_0 all_sequences/$NAME/${NAME}_masks_1"
10 |
11 | WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/${NAME}.ckpt
12 | # WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/step=10000.ckpt
13 |
14 | python train.py --test --encode_w \
15 | --root_dir $ROOT_DIRECTORY \
16 | --log_save_path $LOG_SAVE_PATH \
17 | --mask_dir $MASK_DIRECTORY \
18 | --weight_path $WEIGHT_PATH \
19 | --gpus $GPUS \
20 | --config configs/${NAME}/${EXP_NAME}.yaml \
21 | --exp_name ${EXP_NAME} \
22 | --save_deform False
23 |
--------------------------------------------------------------------------------
/scripts/train_multi.sh:
--------------------------------------------------------------------------------
1 | GPUS=0
2 |
3 | NAME=scene_0
4 | EXP_NAME=base
5 |
6 | ROOT_DIRECTORY="all_sequences/$NAME/$NAME"
7 | MODEL_SAVE_PATH="ckpts/all_sequences/$NAME"
8 | LOG_SAVE_PATH="logs/all_sequences/$NAME"
9 |
10 | MASK_DIRECTORY="all_sequences/$NAME/${NAME}_masks_0 all_sequences/$NAME/${NAME}_masks_1"
11 | FLOW_DIRECTORY="all_sequences/$NAME/${NAME}_flow"
12 |
13 | python train.py --root_dir $ROOT_DIRECTORY \
14 | --model_save_path $MODEL_SAVE_PATH \
15 | --log_save_path $LOG_SAVE_PATH \
16 | --mask_dir $MASK_DIRECTORY \
17 | --flow_dir $FLOW_DIRECTORY \
18 | --gpus $GPUS \
19 | --encode_w --annealed \
20 | --config configs/${NAME}/${EXP_NAME}.yaml \
21 | --exp_name ${EXP_NAME}
22 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # optimizer
3 | from torch.optim import SGD, Adam
4 | import torch_optimizer as optim
5 | # scheduler
6 | from torch.optim.lr_scheduler import CosineAnnealingLR
7 | from torch.optim.lr_scheduler import MultiStepLR
8 | from torch.optim.lr_scheduler import LambdaLR
9 | from torch.optim.lr_scheduler import ExponentialLR
10 | from .warmup_scheduler import GradualWarmupScheduler
11 |
12 | from .video_visualizer import VideoVisualizer
13 |
14 |
15 | def get_parameters(models):
16 | """Get all model parameters recursively."""
17 | parameters = []
18 | if isinstance(models, list):
19 | # print("is list")
20 | for model in models:
21 | parameters += get_parameters(model)
22 | elif isinstance(models, dict):
23 | # print("is dict")
24 | for model in models.values():
25 | parameters += get_parameters(model)
26 | else: # models is actually a single pytorch model
27 | parameters += list(models.parameters())
28 | return parameters
29 |
30 | def get_optimizer(hparams, models):
31 | eps = 1e-8
32 | parameters = get_parameters(models)
33 | if hparams.optimizer == 'sgd':
34 | optimizer = SGD(parameters, lr=hparams.lr,
35 | momentum=hparams.momentum, weight_decay=hparams.weight_decay)
36 | elif hparams.optimizer == 'adam':
37 | optimizer = Adam(parameters, lr=hparams.lr, eps=eps,
38 | weight_decay=hparams.weight_decay)
39 | elif hparams.optimizer == 'radam':
40 | optimizer = optim.RAdam(parameters, lr=hparams.lr, eps=eps,
41 | weight_decay=hparams.weight_decay)
42 | elif hparams.optimizer == 'ranger':
43 | optimizer = optim.Ranger(parameters, lr=hparams.lr, eps=eps,
44 | weight_decay=hparams.weight_decay)
45 | else:
46 | raise ValueError('optimizer not recognized!')
47 |
48 | return optimizer
49 |
50 | def get_scheduler(hparams, optimizer):
51 | eps = 1e-8
52 | if hparams.lr_scheduler == 'steplr':
53 | scheduler = MultiStepLR(optimizer, milestones=hparams.decay_step,
54 | gamma=hparams.decay_gamma)
55 | elif hparams.lr_scheduler == 'cosine':
56 | scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_epochs, eta_min=eps)
57 | elif hparams.lr_scheduler == 'poly':
58 | scheduler = LambdaLR(optimizer,
59 | lambda epoch: (1-epoch/hparams.num_epochs)**hparams.poly_exp)
60 | elif hparams.lr_scheduler == 'exponential':
61 | # Adaptively adjust the schedule
62 | scheduler = LambdaLR(optimizer,
63 | lambda step: hparams.exponent_base**(step/(2 * hparams.num_steps)))
64 | else:
65 | raise ValueError('scheduler not recognized!')
66 |
67 | if hparams.warmup_epochs > 0 and hparams.optimizer not in ['radam', 'ranger']:
68 | scheduler = GradualWarmupScheduler(optimizer, multiplier=hparams.warmup_multiplier,
69 | total_epoch=hparams.warmup_epochs, after_scheduler=scheduler)
70 |
71 | return scheduler
72 |
73 | def get_learning_rate(optimizer):
74 | for param_group in optimizer.param_groups:
75 | return param_group['lr']
76 |
77 | def extract_model_state_dict(ckpt_path, model_name='model', prefixes_to_ignore=[]):
78 | checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
79 | checkpoint_ = {}
80 | if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint
81 | checkpoint = checkpoint['state_dict']
82 | for k, v in checkpoint.items():
83 | if not k.startswith(model_name):
84 | continue
85 | k = k[len(model_name)+1:]
86 | for prefix in prefixes_to_ignore:
87 | if k.startswith(prefix):
88 | print('ignore', k)
89 | break
90 | else:
91 | checkpoint_[k] = v
92 | return checkpoint_
93 |
94 | def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]):
95 | if not ckpt_path:
96 | return
97 | model_dict = model.state_dict()
98 | checkpoint_ = extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore)
99 | model_dict.update(checkpoint_)
100 | model.load_state_dict(model_dict)
101 |
102 |
--------------------------------------------------------------------------------
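A minimal sketch of driving get_optimizer/get_scheduler outside of train.py; the namespace below only fills in the fields these helpers actually read, with illustrative values matching the defaults in opt.py.

import argparse
import torch.nn as nn
from utils import get_optimizer, get_scheduler, get_learning_rate

hparams = argparse.Namespace(
    optimizer='adam', lr=5e-4, weight_decay=0.0, momentum=0.9,
    lr_scheduler='steplr', decay_step=[2500, 5000, 7500], decay_gamma=0.5,
    warmup_epochs=0, warmup_multiplier=1.0,
)

model = nn.Linear(34, 3)                     # stand-in for the real networks
optimizer = get_optimizer(hparams, [model])  # accepts a model, list, or dict
scheduler = get_scheduler(hparams, optimizer)

for step in range(10):
    optimizer.step()
    scheduler.step()
print('lr after 10 steps:', get_learning_rate(optimizer))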
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains utility functions for image processing.
3 |
4 | The module is primarily built on `cv2`. But, differently, we assume all colorful
5 | images are with `RGB` channel order by default. Also, we assume all gray-scale
6 | images to be with shape [height, width, 1].
7 | """
8 |
9 | import os
10 | import cv2
11 | import numpy as np
12 |
13 | # File extensions regarding images (not including GIFs).
14 | IMAGE_EXTENSIONS = (
15 | '.bmp', '.ppm', '.pgm', '.jpeg', '.jpg', '.jpe', '.jp2', '.png', '.webp',
16 | '.tiff', '.tif'
17 | )
18 |
19 | def check_file_ext(filename, *ext_list):
20 | """Checks whether the given filename is with target extension(s).
21 |
22 | NOTE: If `ext_list` is empty, this function will always return `False`.
23 |
24 | Args:
25 | filename: Filename to check.
26 | *ext_list: A list of extensions.
27 |
28 | Returns:
29 | `True` if the filename is with one of extensions in `ext_list`,
30 | otherwise `False`.
31 | """
32 | if len(ext_list) == 0:
33 | return False
34 | ext_list = [ext if ext.startswith('.') else '.' + ext for ext in ext_list]
35 | ext_list = [ext.lower() for ext in ext_list]
36 | basename = os.path.basename(filename)
37 | ext = os.path.splitext(basename)[1].lower()
38 | return ext in ext_list
39 |
40 |
41 | def _check_2d_image(image):
42 | """Checks whether a given image is valid.
43 |
44 | A valid image is expected to be with dtype `uint8`. Also, it should have
45 | shape like:
46 |
47 | (1) (height, width, 1) # gray-scale image.
48 | (2) (height, width, 3) # colorful image.
49 | (3) (height, width, 4) # colorful image with transparency (RGBA)
50 | """
51 | assert isinstance(image, np.ndarray)
52 | assert image.dtype == np.uint8
53 | assert image.ndim == 3 and image.shape[2] in [1, 3, 4]
54 |
55 |
56 | def get_blank_image(height, width, channels=3, use_black=True):
57 | """Gets a blank image, either white of black.
58 |
59 | NOTE: This function will always return an image with `RGB` channel order for
60 | color image and pixel range [0, 255].
61 |
62 | Args:
63 | height: Height of the returned image.
64 | width: Width of the returned image.
65 | channels: Number of channels. (default: 3)
66 | use_black: Whether to return a black image. (default: True)
67 | """
68 | shape = (height, width, channels)
69 | if use_black:
70 | return np.zeros(shape, dtype=np.uint8)
71 | return np.ones(shape, dtype=np.uint8) * 255
72 |
73 |
74 | def load_image(path):
75 | """Loads an image from disk.
76 |
77 | NOTE: This function will always return an image with `RGB` channel order for
78 | color image and pixel range [0, 255].
79 |
80 | Args:
81 | path: Path to load the image from.
82 |
83 | Returns:
84 | An image with dtype `np.ndarray`, or `None` if `path` does not exist.
85 | """
86 | image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
87 | if image is None:
88 | return None
89 |
90 | if image.ndim == 2:
91 | image = image[:, :, np.newaxis]
92 | _check_2d_image(image)
93 | if image.shape[2] == 3:
94 | return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
95 | if image.shape[2] == 4:
96 | return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
97 | return image
98 |
99 |
100 | def save_image(path, image):
101 | """Saves an image to disk.
102 |
103 | NOTE: The input image (if colorful) is assumed to be with `RGB` channel
104 | order and pixel range [0, 255].
105 |
106 | Args:
107 | path: Path to save the image to.
108 | image: Image to save.
109 | """
110 | if image is None:
111 | return
112 |
113 | _check_2d_image(image)
114 | if image.shape[2] == 1:
115 | cv2.imwrite(path, image)
116 | elif image.shape[2] == 3:
117 | cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
118 | elif image.shape[2] == 4:
119 | cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA))
120 |
121 |
122 | def resize_image(image, *args, **kwargs):
123 | """Resizes image.
124 |
125 | This is a wrap of `cv2.resize()`.
126 |
127 | NOTE: The channel order of the input image will not be changed.
128 |
129 | Args:
130 | image: Image to resize.
131 | *args: Additional positional arguments.
132 | **kwargs: Additional keyword arguments.
133 |
134 | Returns:
135 | An image with dtype `np.ndarray`, or `None` if `image` is empty.
136 | """
137 | if image is None:
138 | return None
139 |
140 | _check_2d_image(image)
141 | if image.shape[2] == 1: # Re-expand the squeezed dim of gray-scale image.
142 | return cv2.resize(image, *args, **kwargs)[:, :, np.newaxis]
143 | return cv2.resize(image, *args, **kwargs)
144 |
145 |
146 | def add_text_to_image(image,
147 | text='',
148 | position=None,
149 | font=cv2.FONT_HERSHEY_TRIPLEX,
150 | font_size=1.0,
151 | line_type=cv2.LINE_8,
152 | line_width=1,
153 | color=(255, 255, 255)):
154 | """Overlays text on given image.
155 |
156 | NOTE: The input image is assumed to be with `RGB` channel order.
157 |
158 | Args:
159 | image: The image to overlay text on.
160 | text: Text content to overlay on the image. (default: empty)
161 | position: Target position (bottom-left corner) to add text. If not set,
162 | center of the image will be used by default. (default: None)
163 | font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX)
164 | font_size: Font size of the text added. (default: 1.0)
165 | line_type: Line type used to depict the text. (default: cv2.LINE_8)
166 | line_width: Line width used to depict the text. (default: 1)
167 | color: Color of the text added in `RGB` channel order. (default:
168 | (255, 255, 255))
169 |
170 | Returns:
171 | An image with target text overlaid on.
172 | """
173 | if image is None or not text:
174 | return image
175 |
176 | _check_2d_image(image)
177 | cv2.putText(img=image,
178 | text=text,
179 | org=position,
180 | fontFace=font,
181 | fontScale=font_size,
182 | color=color,
183 | thickness=line_width,
184 | lineType=line_type,
185 | bottomLeftOrigin=False)
186 | return image
187 |
188 |
189 | def preprocess_image(image, min_val=-1.0, max_val=1.0):
190 | """Pre-processes image by adjusting the pixel range and to dtype `float32`.
191 |
192 | This function is particularly used to convert an image or a batch of images
193 | to `NCHW` format, which matches the data type commonly used in deep models.
194 |
195 | NOTE: The input image is assumed to be with pixel range [0, 255] and with
196 | format `HWC` or `NHWC`. The returned image will be always be with format
197 | `NCHW`.
198 |
199 | Args:
200 | image: The input image for pre-processing.
201 | min_val: Minimum value of the output image.
202 | max_val: Maximum value of the output image.
203 |
204 | Returns:
205 | The pre-processed image.
206 | """
207 | assert isinstance(image, np.ndarray)
208 |
209 | image = image.astype(np.float64)
210 | image = image / 255.0 * (max_val - min_val) + min_val
211 |
212 | if image.ndim == 3:
213 | image = image[np.newaxis]
214 | assert image.ndim == 4 and image.shape[3] in [1, 3, 4]
215 | return image.transpose(0, 3, 1, 2)
216 |
217 |
218 | def postprocess_image(image, min_val=-1.0, max_val=1.0):
219 | """Post-processes image to pixel range [0, 255] with dtype `uint8`.
220 |
221 | This function is particularly used to handle the results produced by deep
222 | models.
223 |
224 | NOTE: The input image is assumed to be with format `NCHW`, and the returned
225 | image will always be with format `NHWC`.
226 |
227 | Args:
228 | image: The input image for post-processing.
229 | min_val: Expected minimum value of the input image.
230 | max_val: Expected maximum value of the input image.
231 |
232 | Returns:
233 | The post-processed image.
234 | """
235 | assert isinstance(image, np.ndarray)
236 |
237 | image = image.astype(np.float64)
238 | image = (image - min_val) / (max_val - min_val) * 255
239 | image = np.clip(image + 0.5, 0, 255).astype(np.uint8)
240 |
241 | assert image.ndim == 4 and image.shape[1] in [1, 3, 4]
242 | return image.transpose(0, 2, 3, 1)
243 |
244 |
245 | def parse_image_size(obj):
246 | """Parses an object to a pair of image size, i.e., (height, width).
247 |
248 | Args:
249 | obj: The input object to parse image size from.
250 |
251 | Returns:
252 | A two-element tuple, indicating image height and width respectively.
253 |
254 | Raises:
255 |         ValueError: If the input is of an unsupported type, i.e., not an integer, string, list, tuple, or numpy array.
256 | """
257 | if obj is None or obj == '':
258 | height = 0
259 | width = 0
260 | elif isinstance(obj, int):
261 | height = obj
262 | width = obj
263 | elif isinstance(obj, (list, tuple, str, np.ndarray)):
264 | if isinstance(obj, str):
265 | splits = obj.replace(' ', '').split(',')
266 | numbers = tuple(map(int, splits))
267 | else:
268 | numbers = tuple(obj)
269 | if len(numbers) == 0:
270 | height = 0
271 | width = 0
272 | elif len(numbers) == 1:
273 | height = int(numbers[0])
274 | width = int(numbers[0])
275 | elif len(numbers) == 2:
276 | height = int(numbers[0])
277 | width = int(numbers[1])
278 | else:
279 | raise ValueError('At most two elements for image size.')
280 | else:
281 | raise ValueError(f'Invalid type of input: `{type(obj)}`!')
282 |
283 | return (max(0, height), max(0, width))
284 |
285 |
286 | def get_grid_shape(size, height=0, width=0, is_portrait=False):
287 | """Gets the shape of a grid based on the size.
288 |
289 | This function makes greatest effort on making the output grid square if
290 | neither `height` nor `width` is set. If `is_portrait` is set as `False`, the
291 | height will always be equal to or smaller than the width. For example, if
292 | input `size = 16`, output shape will be `(4, 4)`; if input `size = 15`,
293 | output shape will be (3, 5). Otherwise, the height will always be equal to
294 | or larger than the width.
295 |
296 | Args:
297 | size: Size (height * width) of the target grid.
298 | height: Expected height. If `size % height != 0`, this field will be
299 | ignored. (default: 0)
300 | width: Expected width. If `size % width != 0`, this field will be
301 | ignored. (default: 0)
302 |         is_portrait: Whether to return a portrait size or a landscape size.
303 | (default: False)
304 |
305 | Returns:
306 | A two-element tuple, representing height and width respectively.
307 | """
308 | assert isinstance(size, int)
309 | assert isinstance(height, int)
310 | assert isinstance(width, int)
311 | if size <= 0:
312 | return (0, 0)
313 |
314 | if height > 0 and width > 0 and height * width != size:
315 | height = 0
316 | width = 0
317 |
318 | if height > 0 and width > 0 and height * width == size:
319 | return (height, width)
320 | if height > 0 and size % height == 0:
321 | return (height, size // height)
322 | if width > 0 and size % width == 0:
323 | return (size // width, width)
324 |
325 | height = int(np.sqrt(size))
326 | while height > 0:
327 | if size % height == 0:
328 | width = size // height
329 | break
330 | height = height - 1
331 |
332 | return (width, height) if is_portrait else (height, width)
333 |
334 |
335 | def list_images_from_dir(directory):
336 | """Lists all images from the given directory.
337 |
338 | NOTE: Do NOT support finding images recursively.
339 |
340 | Args:
341 | directory: The directory to find images from.
342 |
343 | Returns:
344 | A list of sorted filenames, with the directory as prefix.
345 | """
346 | image_list = []
347 | for filename in os.listdir(directory):
348 | if check_file_ext(filename, *IMAGE_EXTENSIONS):
349 | image_list.append(os.path.join(directory, filename))
350 | return sorted(image_list)
351 |
--------------------------------------------------------------------------------
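A short round trip through the helpers above; 'frame.png' is a placeholder path, and the example falls back to a synthetic white image if it does not exist. It shows the RGB/uint8 convention and the NCHW pre-/post-processing pair.

from utils.image_utils import (load_image, save_image, resize_image,
                               preprocess_image, postprocess_image,
                               get_blank_image)

image = load_image('frame.png')           # RGB, uint8, HWC (or None if missing)
if image is None:
    image = get_blank_image(256, 256, use_black=False)  # white placeholder

small = resize_image(image, (128, 128))   # cv2-style (width, height) size

x = preprocess_image(small)               # float, NCHW, range [-1, 1]
assert x.shape == (1, 3, 128, 128)

y = postprocess_image(x)                  # back to uint8, NHWC, range [0, 255]
assert y[0].shape == small.shape
save_image('frame_small.png', y[0])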
/utils/video_visualizer.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from skvideo.io import FFmpegWriter
3 | from .image_utils import parse_image_size
4 | from .image_utils import load_image
5 | from .image_utils import resize_image
6 | from .image_utils import list_images_from_dir
7 |
8 |
9 | class VideoVisualizer(object):
10 | """Defines the video visualizer that presents images as a video."""
11 |
12 | def __init__(self,
13 | path=None,
14 | frame_size=None,
15 | fps=25.0,
16 | codec='libx264',
17 | pix_fmt='yuv420p',
18 | crf=1):
19 | """Initializes the video visualizer.
20 |
21 | Args:
22 | path: Path to write the video. (default: None)
23 | frame_size: Frame size, i.e., (height, width). (default: None)
24 |             fps: Frames per second. (default: 25.0)
25 | codec: Codec. (default: `libx264`)
26 | pix_fmt: Pixel format. (default: `yuv420p`)
27 | crf: Constant rate factor, which controls the compression. The
28 | larger this field is, the higher compression and lower quality.
29 | `0` means no compression and consequently the highest quality.
30 |                 To enable QuickTime playback (which requires YUV 4:2:0,
31 |                 while `crf = 0` results in YUV 4:4:4), please set this
32 |                 field to at least 1. (default: 1)
33 | """
34 | self.set_path(path)
35 | self.set_frame_size(frame_size)
36 | self.set_fps(fps)
37 | self.set_codec(codec)
38 | self.set_pix_fmt(pix_fmt)
39 | self.set_crf(crf)
40 | self.video = None
41 |
42 | def set_path(self, path=None):
43 | """Sets the path to save the video."""
44 | self.path = path
45 |
46 | def set_frame_size(self, frame_size=None):
47 | """Sets the video frame size."""
48 | height, width = parse_image_size(frame_size)
49 | self.frame_height = height
50 | self.frame_width = width
51 |
52 | def set_fps(self, fps=25.0):
53 | """Sets the FPS (frame per second) of the video."""
54 | self.fps = fps
55 |
56 | def set_codec(self, codec='libx264'):
57 | """Sets the video codec."""
58 | self.codec = codec
59 |
60 | def set_pix_fmt(self, pix_fmt='yuv420p'):
61 | """Sets the video pixel format."""
62 | self.pix_fmt = pix_fmt
63 |
64 | def set_crf(self, crf=1):
65 | """Sets the CRF (constant rate factor) of the video."""
66 | self.crf = crf
67 |
68 | def init_video(self):
69 | """Initializes an empty video with expected settings."""
70 | assert self.frame_height > 0
71 | assert self.frame_width > 0
72 |
73 | video_setting = {
74 | '-r': f'{self.fps:.2f}',
75 | '-s': f'{self.frame_width}x{self.frame_height}',
76 | '-vcodec': f'{self.codec}',
77 | '-crf': f'{self.crf}',
78 | '-pix_fmt': f'{self.pix_fmt}',
79 | }
80 | self.video = FFmpegWriter(self.path, outputdict=video_setting)
81 |
82 | def add(self, frame):
83 | """Adds a frame into the video visualizer.
84 |
85 | NOTE: The input frame is assumed to be with `RGB` channel order.
86 | """
87 | if self.video is None:
88 | height, width = frame.shape[0:2]
89 | height = self.frame_height or height
90 | width = self.frame_width or width
91 | self.set_frame_size((height, width))
92 | self.init_video()
93 | if frame.shape[0:2] != (self.frame_height, self.frame_width):
94 | frame = resize_image(frame, (self.frame_width, self.frame_height))
95 | self.video.writeFrame(frame)
96 |
97 | def visualize_collection(self, images, save_path=None):
98 | """Visualizes a collection of images one by one."""
99 | if save_path is not None and save_path != self.path:
100 | self.save()
101 | self.set_path(save_path)
102 | for image in images:
103 | self.add(image)
104 | self.save()
105 |
106 | def visualize_list(self, image_list, save_path=None):
107 | """Visualizes a list of image files."""
108 | if save_path is not None and save_path != self.path:
109 | self.save()
110 | self.set_path(save_path)
111 | for filename in image_list:
112 | image = load_image(filename)
113 | self.add(image)
114 | self.save()
115 |
116 | def visualize_directory(self, directory, save_path=None):
117 | """Visualizes all images under a directory."""
118 | image_list = list_images_from_dir(directory)
119 | self.visualize_list(image_list, save_path)
120 |
121 | def save(self):
122 | """Saves the video by closing the file."""
123 | if self.video is not None:
124 | self.video.close()
125 | self.video = None
126 | self.set_path(None)
127 |
128 |
129 | if __name__ == '__main__':
130 | from glob import glob
131 | import cv2
132 | video_visualizer = VideoVisualizer(path='output.mp4',
133 | frame_size=None,
134 | fps=25.0)
135 | img_folder = 'src_images/'
136 | imgs = sorted(glob(img_folder + '/*.png'))
137 | for img in imgs:
138 | image = cv2.imread(img)
139 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
140 | video_visualizer.add(image)
141 | video_visualizer.save()
--------------------------------------------------------------------------------
/utils/warmup_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 | class GradualWarmupScheduler(_LRScheduler):
5 | """ Gradually warm-up(increasing) learning rate in optimizer.
6 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
7 | Args:
8 | optimizer (Optimizer): Wrapped optimizer.
9 | multiplier: target learning rate = base lr * multiplier
10 |         total_epoch: the target learning rate is reached gradually at total_epoch
11 |         after_scheduler: scheduler to use after total_epoch (e.g. ReduceLROnPlateau)
12 | """
13 |
14 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
15 | self.multiplier = multiplier
16 | if self.multiplier < 1.:
17 |             raise ValueError('multiplier should be greater than or equal to 1.')
18 | self.total_epoch = total_epoch
19 | self.after_scheduler = after_scheduler
20 | self.finished = False
21 | super().__init__(optimizer)
22 |
23 | def get_lr(self):
24 | if self.last_epoch > self.total_epoch:
25 | if self.after_scheduler:
26 | if not self.finished:
27 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
28 | self.finished = True
29 | return self.after_scheduler.get_lr()
30 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
31 |
32 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
33 |
34 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
35 | if epoch is None:
36 | epoch = self.last_epoch + 1
37 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
38 | if self.last_epoch <= self.total_epoch:
39 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
40 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
41 | param_group['lr'] = lr
42 | else:
43 | if epoch is None:
44 | self.after_scheduler.step(metrics, None)
45 | else:
46 | self.after_scheduler.step(metrics, epoch - self.total_epoch)
47 |
48 | def step(self, epoch=None, metrics=None):
49 | if type(self.after_scheduler) != ReduceLROnPlateau:
50 | if self.finished and self.after_scheduler:
51 | if epoch is None:
52 | self.after_scheduler.step(None)
53 | else:
54 | self.after_scheduler.step(epoch - self.total_epoch)
55 | else:
56 | return super(GradualWarmupScheduler, self).step(epoch)
57 | else:
58 | self.step_ReduceLROnPlateau(metrics, epoch)
--------------------------------------------------------------------------------
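A hedged sketch of GradualWarmupScheduler wrapped around a MultiStepLR, mirroring how get_scheduler() in utils/__init__.py chains them when --warmup_epochs > 0; the multiplier and step counts are illustrative.

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR
from utils.warmup_scheduler import GradualWarmupScheduler

model = torch.nn.Linear(8, 2)
optimizer = Adam(model.parameters(), lr=5e-4)

base = MultiStepLR(optimizer, milestones=[2500, 5000, 7500], gamma=0.5)
# Ramp the lr up to 2x its base value over the first 500 steps, then hand
# control to the MultiStepLR above.
scheduler = GradualWarmupScheduler(optimizer, multiplier=2.0,
                                   total_epoch=500, after_scheduler=base)

for step in range(1000):
    optimizer.step()
    scheduler.step()
    if step % 250 == 0:
        print(step, optimizer.param_groups[0]['lr'])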