├── .gitignore ├── LICENSE ├── README.md ├── configs ├── beauty_0 │ └── base.yaml ├── beauty_1 │ └── base.yaml ├── hash.json ├── lemon_hit │ └── base.yaml ├── scene_0 │ └── base.yaml └── white_smoke │ └── base.yaml ├── data_preprocessing ├── RAFT │ ├── core │ │ ├── __init__.py │ │ ├── corr.py │ │ ├── datasets.py │ │ ├── extractor.py │ │ ├── raft.py │ │ ├── update.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── augmentor.py │ │ │ ├── flow_viz.py │ │ │ ├── frame_utils.py │ │ │ └── utils.py │ ├── demo.py │ ├── models │ │ └── download_models.sh │ └── run_raft.sh └── preproc_mask.py ├── datasets ├── __init__.py ├── distributed_weighted_sampler.py └── video_dataset.py ├── docs ├── index.html ├── static │ ├── css │ │ ├── bulma-carousel.min.css │ │ ├── bulma-slider.min.css │ │ ├── bulma.css.map.txt │ │ ├── bulma.min.css │ │ ├── fontawesome.all.min.css │ │ └── index.css │ ├── images │ │ └── framework.png │ ├── js │ │ ├── bulma-carousel.js │ │ ├── bulma-carousel.min.js │ │ ├── bulma-slider.js │ │ ├── bulma-slider.min.js │ │ ├── fontawesome.all.min.js │ │ ├── index.js │ │ └── video_comparison.js │ └── video_demos_compressed │ │ ├── 1080x540 │ │ ├── beauty_0_base_video_hash_transformed_dual(1).mp4 │ │ ├── beauty_0_base_video_hash_transformed_dual(2).mp4 │ │ ├── beauty_1_base_transformed_dual(1).mp4 │ │ ├── beauty_1_base_transformed_dual(2).mp4 │ │ ├── rainbow.mp4 │ │ └── tifa.mp4 │ │ ├── 1080x960 │ │ ├── dog.mp4 │ │ ├── scene_1_base_transformed_dual.mp4 │ │ ├── smoke_colorful.mp4 │ │ └── smoke_ink.mp4 │ │ ├── 1920x540 │ │ ├── cloud_atlas4_base_transformed_dual.mp4 │ │ ├── diamond_1_base_transformed_dual.mp4 │ │ ├── lemon_earth.mp4 │ │ ├── long_season.mp4 │ │ ├── scene_3_base_transformed_dual.mp4 │ │ └── titanic.mp4 │ │ ├── slider │ │ └── slider.mp4 │ │ └── teaser │ │ ├── cloud_atlas_SR_dual.mp4 │ │ ├── segtrack.mp4 │ │ ├── teaser.mp4 │ │ └── tracking_zoom.mp4 └── teaser.gif ├── losses.py ├── metrics.py ├── models └── implicit_model.py ├── opt.py ├── requirements.txt ├── scripts ├── test_canonical.sh ├── test_multi.sh └── train_multi.sh ├── train.py └── utils ├── __init__.py ├── image_utils.py ├── video_visualizer.py └── warmup_scheduler.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore compiled files. 2 | __pycache__/ 3 | *.py[cod] 4 | tmp_build/ 5 | *.pyc 6 | 7 | # Ignore files created by IDEs. 8 | /.vscode/ 9 | /.idea/ 10 | .ipynb_*/ 11 | *.ipynb 12 | .DS_Store 13 | *.sw[pon] 14 | 15 | # Ignore data files. 16 | data/ 17 | *.npy 18 | *.tar 19 | *.zip 20 | *.mdb 21 | *.ckpt 22 | 23 | # Ignore network files. 24 | ckpts/ 25 | *.pth 26 | *.pt 27 | *.pkl 28 | *.h5 29 | *.dat 30 | *.ckpt 31 | 32 | # Ignore log files. 33 | results/ 34 | resources/ 35 | events/ 36 | profile/ 37 | logs/ 38 | # *.json 39 | *.log 40 | events.* 41 | 42 | # Files that should not be ignored. 43 | !/requirements/* 44 | 45 | # Others shuld be ignored. 46 | all_sequences/ 47 | ckpts/ 48 | logs/ 49 | # configs/ 50 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | ------------------------------ LICENSE for CoDeF ------------------------------ 2 | 3 | Copyright (c) 2023 Ant Group. 
4 | 5 | MIT License 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CoDeF: Content Deformation Fields for Temporally Consistent Video Processing 2 | 3 | 4 | 5 | [Hao Ouyang](https://ken-ouyang.github.io/)\*, [Qiuyu Wang](https://github.com/qiuyu96/)\*, [Yuxi Xiao](https://henry123-boy.github.io/)\*, [Qingyan Bai](https://scholar.google.com/citations?user=xUMjxi4AAAAJ&hl=en), [Juntao Zhang](https://github.com/JordanZh), [Kecheng Zheng](https://scholar.google.com/citations?user=hMDQifQAAAAJ), [Xiaowei Zhou](https://xzhou.me/), 6 | [Qifeng Chen](https://cqf.io/)†, [Yujun Shen](https://shenyujun.github.io/)† (*equal contribution, †corresponding author) 7 | 8 | **CVPR 2024 Highlight** 9 | 10 | #### [Project Page](https://qiuyu96.github.io/CoDeF/) | [Paper](https://arxiv.org/abs/2308.07926) | [High-Res Translation Demo](https://ezioby.github.io/CoDeF_Demo/) | [Colab](https://colab.research.google.com/github/camenduru/CoDeF-colab/blob/main/CoDeF_colab.ipynb) 11 | 12 | 13 | 14 | ## Requirements 15 | 16 | The codebase is tested on 17 | 18 | * Ubuntu 20.04 19 | * Python 3.10 20 | * [PyTorch](https://pytorch.org/) 2.0.0 21 | * [PyTorch Lightning](https://www.pytorchlightning.ai/index.html) 2.0.2 22 | * 1 NVIDIA GPU (RTX A6000) with CUDA version 11.7. (Other GPUs are also suitable, and 10GB GPU memory is sufficient to run our code.) 23 | 24 | To use video visualizer, please install `ffmpeg` via 25 | 26 | ```shell 27 | sudo apt-get install ffmpeg 28 | ``` 29 | 30 | For additional Python libraries, please install with 31 | 32 | ```shell 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | Our code also depends on [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn). 37 | See [this repository](https://github.com/NVlabs/tiny-cuda-nn#pytorch-extension) 38 | for Pytorch extension install instructions. 39 | 40 | ## Data 41 | 42 | ### Provided data 43 | 44 | We have provided some videos [here](https://drive.google.com/file/d/1cKZF6ILeokCjsSAGBmummcQh0uRGaC_F/view?usp=sharing) for quick test. Please download and unzip the data and put them in the root directory. More videos can be downloaded [here](https://drive.google.com/file/d/10Msz37MpjZQFPXlDWCZqrcQjhxpQSvCI/view?usp=sharing). 
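
Before moving on to data preparation, a quick import check can confirm that the requirements listed above are in place. The sketch below is an optional, editor-added example rather than part of the repository; it assumes the standard import names `torch`, `pytorch_lightning`, and `tinycudann` (the PyTorch extension of tiny-cuda-nn).

```python
# Optional sanity check for the environment described in the Requirements section.
# Assumes torch, pytorch_lightning, and tiny-cuda-nn (import name: tinycudann) are installed.
import torch
import pytorch_lightning as pl
import tinycudann as tcnn  # noqa: F401  (import succeeds only if the CUDA extension is built)

print("PyTorch:", torch.__version__)
print("PyTorch Lightning:", pl.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
```
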
45 | 46 | ### Customize your own data 47 | 48 | We segement video sequences using [SAM-Track](https://github.com/z-x-yang/Segment-and-Track-Anything). Once you obtain the mask files, place them in the folder `all_sequences/{YOUR_SEQUENCE_NAME}/{YOUR_SEQUENCE_NAME}_masks`. Next, execute the following command: 49 | 50 | ```shell 51 | cd data_preprocessing 52 | python preproc_mask.py 53 | ``` 54 | 55 | We extract optical flows of video sequences using [RAFT](https://github.com/princeton-vl/RAFT). To get started, please follow the instructions provided [here](https://github.com/princeton-vl/RAFT#demos) to download their pretrained model. Once downloaded, place the model in the `data_preprocessing/RAFT/models` folder. After that, you can execute the following command: 56 | 57 | ```shell 58 | cd data_preprocessing/RAFT 59 | ./run_raft.sh 60 | ``` 61 | 62 | Remember to update the sequence name and root directory in both `data_preprocessing/preproc_mask.py` and `data_preprocessing/RAFT/run_raft.sh` accordingly. 63 | 64 | After obtaining the files, please organize your own data as follows: 65 | 66 | ``` 67 | CoDeF 68 | │ 69 | └─── all_sequences 70 | │ 71 | └─── NAME1 72 | └─ NAME1 73 | └─ NAME1_masks_0 (optional) 74 | └─ NAME1_masks_1 (optional) 75 | └─ NAME1_flow (optional) 76 | └─ NAME1_flow_confidence (optional) 77 | │ 78 | └─── NAME2 79 | └─ NAME2 80 | └─ NAME2_masks_0 (optional) 81 | └─ NAME2_masks_1 (optional) 82 | └─ NAME2_flow (optional) 83 | └─ NAME2_flow_confidence (optional) 84 | │ 85 | └─── ... 86 | ``` 87 | 88 | ## Pretrained checkpoints 89 | 90 | You can download checkpoints pre-trained on the provided videos via 91 | 92 | | Sequence Name | Config | Download | OpenXLab | 93 | | :-------- | :----: | :----------------------------------------------------------: | :---------:| 94 | | beauty_0 | configs/beauty_0/base.yaml | [Google drive link](https://drive.google.com/file/d/11SWfnfDct8bE16802PyqYJqsU4x6ACn8/view?usp=sharing) |[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/HaoOuyang/CoDeF)| 95 | | beauty_1 | configs/beauty_1/base.yaml | [Google drive link](https://drive.google.com/file/d/1bSK0ChbPdURWGLdtc9CPLkN4Tfnng51k/view?usp=sharing) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/HaoOuyang/CoDeF) | 96 | | white_smoke | configs/white_smoke/base.yaml | [Google drive link](https://drive.google.com/file/d/1QOBCDGV2hHwxq4eL1E_45z5zhZ-wTJR7/view?usp=sharing) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/HaoOuyang/CoDeF) | 97 | | lemon_hit | configs/lemon_hit/base.yaml | [Google drive link](https://drive.google.com/file/d/140ctcLbv7JTIiy53MuCYtI4_zpIvRXzq/view?usp=sharing) | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/HaoOuyang/CoDeF)| 98 | | scene_0 | configs/scene_0/base.yaml | [Google drive link](https://drive.google.com/file/d/1abOdREarfw1DGscahOJd2gZf1Xn_zN-F/view?usp=sharing) |[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/HaoOuyang/CoDeF)| 99 | 100 | And organize files as follows 101 | 102 | ``` 103 | CoDeF 104 | │ 105 | └─── ckpts/all_sequences 106 | │ 107 | └─── NAME1 108 | │ 109 | └─── EXP_NAME (base) 110 | │ 111 | └─── NAME1.ckpt 112 | │ 113 | └─── NAME2 114 | │ 115 | └─── EXP_NAME (base) 116 | │ 117 
| └─── NAME2.ckpt 118 | | 119 | └─── ... 120 | ``` 121 | 122 | ## Train a new model 123 | 124 | ```shell 125 | ./scripts/train_multi.sh 126 | ``` 127 | 128 | where 129 | * `GPU`: Decide which GPU to train on; 130 | * `NAME`: Name of the video sequence; 131 | * `EXP_NAME`: Name of the experiment; 132 | * `ROOT_DIRECTORY`: Directory of the input video sequence; 133 | * `MODEL_SAVE_PATH`: Path to save the checkpoints; 134 | * `LOG_SAVE_PATH`: Path to save the logs; 135 | * `MASK_DIRECTORY`: Directory of the preprocessed masks (optional); 136 | * `FLOW_DIRECTORY`: Directory of the preprocessed optical flows (optional); 137 | 138 | Please check configuration files in ``configs/``, and you can always add your own model config. 139 | 140 | ## Test reconstruction 141 | 142 | ```shell 143 | ./scripts/test_multi.sh 144 | ``` 145 | After running the script, the reconstructed videos can be found in `results/all_sequences/{NAME}/{EXP_NAME}`, along with the canonical image. 146 | 147 | ## Test video translation 148 | 149 | After obtaining the canonical image through [this step](#anchor), use your preferred text prompts to transfer it using [ControlNet](https://github.com/lllyasviel/ControlNet). 150 | Once you have the transferred canonical image, place it in `all_sequences/${NAME}/${EXP_NAME}_control` (i.e. `CANONICAL_DIR` in `scripts/test_canonical.sh`). 151 | 152 | Then run 153 | 154 | ```shell 155 | ./scripts/test_canonical.sh 156 | ``` 157 | 158 | The transferred results can be seen in `results/all_sequences/{NAME}/{EXP_NAME}_transformed`. 159 | 160 | *Note*: The `canonical_wh` option in the configuration file should be set with caution, usually a little larger than `img_wh`, as it determines the field of view of the canonical image. 161 | 162 | ### BibTeX 163 | 164 | ```bibtex 165 | @article{ouyang2023codef, 166 | title={CoDeF: Content Deformation Fields for Temporally Consistent Video Processing}, 167 | author={Hao Ouyang and Qiuyu Wang and Yuxi Xiao and Qingyan Bai and Juntao Zhang and Kecheng Zheng and Xiaowei Zhou and Qifeng Chen and Yujun Shen}, 168 | journal={arXiv preprint arXiv:2308.07926}, 169 | year={2023} 170 | } 171 | ``` 172 | 173 | ### Acknowledgements 174 | We thank [camenduru](https://github.com/camenduru) for providing the [colab demo](https://github.com/camenduru/CoDeF-colab). 
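
As a supplement to the note on `canonical_wh` above, the following editor-added sketch loads one of the provided configuration files and verifies that `canonical_wh` is no smaller than `img_wh`. It assumes PyYAML is available and uses `configs/beauty_0/base.yaml` purely as an example; it is not one of the repository scripts.

```python
# Hedged example: check the canonical_wh / img_wh relationship described in the note above.
import yaml

def check_canonical_size(config_path="configs/beauty_0/base.yaml"):
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    img_w, img_h = cfg["img_wh"]
    can_w, can_h = cfg["canonical_wh"]
    if can_w < img_w or can_h < img_h:
        raise ValueError(
            "canonical_wh should usually be a little larger than img_wh, "
            "since it sets the field of view of the canonical image."
        )
    return cfg

if __name__ == "__main__":
    cfg = check_canonical_size()
    print("img_wh:", cfg["img_wh"], "canonical_wh:", cfg["canonical_wh"])
```
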
175 | -------------------------------------------------------------------------------- /configs/beauty_0/base.yaml: -------------------------------------------------------------------------------- 1 | mask_dir: null 2 | 3 | img_wh: [540, 540] 4 | canonical_wh: [640, 640] 5 | 6 | lr: 0.001 7 | bg_loss: 0.003 8 | 9 | ref_idx: null # 0 10 | 11 | N_xyz_w: [8,] 12 | flow_loss: 1 13 | flow_step: -1 14 | self_bg: True 15 | 16 | deform_hash: True 17 | vid_hash: True 18 | 19 | num_steps: 10000 20 | decay_step: [2500, 5000, 7500] 21 | annealed_begin_step: 4000 22 | annealed_step: 4000 23 | save_model_iters: 2000 24 | 25 | fps: 15 26 | -------------------------------------------------------------------------------- /configs/beauty_1/base.yaml: -------------------------------------------------------------------------------- 1 | img_wh: [540, 540] 2 | canonical_wh: [540, 540] 3 | 4 | lr: 0.001 5 | bg_loss: 0.003 6 | 7 | ref_idx: null # 0 8 | 9 | N_xyz_w: [8, 8] 10 | flow_loss: 1 11 | flow_step: -1 12 | self_bg: True 13 | 14 | deform_hash: True 15 | vid_hash: True 16 | 17 | num_steps: 10000 18 | decay_step: [2500, 5000, 7500] 19 | annealed_begin_step: 4000 20 | annealed_step: 4000 21 | save_model_iters: 2000 22 | 23 | fps: 15 24 | -------------------------------------------------------------------------------- /configs/hash.json: -------------------------------------------------------------------------------- 1 | { 2 | "encoding_deform3d": { 3 | "otype": "HashGrid", 4 | "n_levels": 16, 5 | "n_features_per_level": 2, 6 | "log2_hashmap_size": 19, 7 | "base_resolution": 16, 8 | "per_level_scale": 1.38 9 | }, 10 | "encoding": { 11 | "otype": "HashGrid", 12 | "n_levels": 16, 13 | "n_features_per_level": 2, 14 | "log2_hashmap_size": 19, 15 | "base_resolution": 16, 16 | "per_level_scale": 1.44 17 | }, 18 | "encoding_deform2d": { 19 | "otype": "HashGrid", 20 | "n_levels": 16, 21 | "n_features_per_level": 2, 22 | "log2_hashmap_size": 19, 23 | "base_resolution": 16, 24 | "per_level_scale": 1.44 25 | }, 26 | "network": { 27 | "otype": "FullyFusedMLP", 28 | "activation": "ReLU", 29 | "output_activation": "None", 30 | "n_neurons": 64, 31 | "n_hidden_layers": 2 32 | }, 33 | "network_deform": { 34 | "otype": "FullyFusedMLP", 35 | "activation": "ReLU", 36 | "output_activation": "None", 37 | "n_neurons": 64, 38 | "n_hidden_layers": 8 39 | }, 40 | "network_occ": { 41 | "otype": "FullyFusedMLP", 42 | "activation": "ReLU", 43 | "output_activation": "Sigmoid", 44 | "n_neurons": 64, 45 | "n_hidden_layers": 2 46 | }, 47 | "time_code": 64 48 | } 49 | -------------------------------------------------------------------------------- /configs/lemon_hit/base.yaml: -------------------------------------------------------------------------------- 1 | mask_dir: null 2 | flow_dir: null 3 | 4 | img_wh: [960, 540] 5 | canonical_wh: [1280, 640] 6 | 7 | lr: 0.001 8 | bg_loss: 0.003 9 | 10 | ref_idx: null # 0 11 | 12 | N_xyz_w: [8,] 13 | flow_loss: 0 14 | flow_step: -1 15 | self_bg: True 16 | 17 | deform_hash: True 18 | vid_hash: True 19 | 20 | num_steps: 10000 21 | decay_step: [2500, 5000, 7500] 22 | annealed_begin_step: 4000 23 | annealed_step: 4000 24 | save_model_iters: 2000 25 | -------------------------------------------------------------------------------- /configs/scene_0/base.yaml: -------------------------------------------------------------------------------- 1 | mask_dir: null 2 | 3 | img_wh: [540, 960] 4 | canonical_wh: [720, 1280] 5 | 6 | lr: 0.001 7 | bg_loss: 0.003 8 | 9 | ref_idx: null # 0 10 | 11 | N_xyz_w: [8,] 12 | 
flow_loss: 1 13 | flow_step: -1 14 | self_bg: True 15 | 16 | deform_hash: True 17 | vid_hash: True 18 | 19 | num_steps: 10000 20 | decay_step: [2500, 5000, 7500] 21 | annealed_begin_step: 4000 22 | annealed_step: 4000 23 | save_model_iters: 2000 24 | -------------------------------------------------------------------------------- /configs/white_smoke/base.yaml: -------------------------------------------------------------------------------- 1 | mask_dir: null 2 | flow_dir: null 3 | 4 | img_wh: [960, 540] 5 | canonical_wh: [1280, 640] 6 | 7 | lr: 0.001 8 | bg_loss: 0.003 9 | 10 | ref_idx: null # 0 11 | 12 | N_xyz_w: [8,] 13 | flow_loss: 0 14 | flow_step: -1 15 | self_bg: True 16 | 17 | deform_hash: True 18 | vid_hash: True 19 | 20 | num_steps: 10000 21 | decay_step: [2500, 5000, 7500] 22 | annealed_begin_step: 4000 23 | annealed_step: 4000 24 | save_model_iters: 2000 25 | -------------------------------------------------------------------------------- /data_preprocessing/RAFT/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/data_preprocessing/RAFT/core/__init__.py -------------------------------------------------------------------------------- /data_preprocessing/RAFT/core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 38 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class AlternateCorrBlock: 64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 65 | self.num_levels = num_levels 66 | self.radius = radius 67 | 68 | self.pyramid = [(fmap1, fmap2)] 69 | for i in range(self.num_levels): 70 | fmap1 = 
F.avg_pool2d(fmap1, 2, stride=2) 71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 72 | self.pyramid.append((fmap1, fmap2)) 73 | 74 | def __call__(self, coords): 75 | coords = coords.permute(0, 2, 3, 1) 76 | B, H, W, _ = coords.shape 77 | dim = self.pyramid[0][0].shape[1] 78 | 79 | corr_list = [] 80 | for i in range(self.num_levels): 81 | r = self.radius 82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 84 | 85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 87 | corr_list.append(corr.squeeze(1)) 88 | 89 | corr = torch.stack(corr_list, dim=1) 90 | corr = corr.reshape(B, -1, H, W) 91 | return corr / torch.sqrt(torch.tensor(dim).float()) 92 | -------------------------------------------------------------------------------- /data_preprocessing/RAFT/core/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | 8 | import os 9 | import math 10 | import random 11 | from glob import glob 12 | import os.path as osp 13 | 14 | from utils import frame_utils 15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor 16 | 17 | 18 | class FlowDataset(data.Dataset): 19 | def __init__(self, aug_params=None, sparse=False): 20 | self.augmentor = None 21 | self.sparse = sparse 22 | if aug_params is not None: 23 | if sparse: 24 | self.augmentor = SparseFlowAugmentor(**aug_params) 25 | else: 26 | self.augmentor = FlowAugmentor(**aug_params) 27 | 28 | self.is_test = False 29 | self.init_seed = False 30 | self.flow_list = [] 31 | self.image_list = [] 32 | self.extra_info = [] 33 | 34 | def __getitem__(self, index): 35 | 36 | if self.is_test: 37 | img1 = frame_utils.read_gen(self.image_list[index][0]) 38 | img2 = frame_utils.read_gen(self.image_list[index][1]) 39 | img1 = np.array(img1).astype(np.uint8)[..., :3] 40 | img2 = np.array(img2).astype(np.uint8)[..., :3] 41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 43 | return img1, img2, self.extra_info[index] 44 | 45 | if not self.init_seed: 46 | worker_info = torch.utils.data.get_worker_info() 47 | if worker_info is not None: 48 | torch.manual_seed(worker_info.id) 49 | np.random.seed(worker_info.id) 50 | random.seed(worker_info.id) 51 | self.init_seed = True 52 | 53 | index = index % len(self.image_list) 54 | valid = None 55 | if self.sparse: 56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 57 | else: 58 | flow = frame_utils.read_gen(self.flow_list[index]) 59 | 60 | img1 = frame_utils.read_gen(self.image_list[index][0]) 61 | img2 = frame_utils.read_gen(self.image_list[index][1]) 62 | 63 | flow = np.array(flow).astype(np.float32) 64 | img1 = np.array(img1).astype(np.uint8) 65 | img2 = np.array(img2).astype(np.uint8) 66 | 67 | # grayscale images 68 | if len(img1.shape) == 2: 69 | img1 = np.tile(img1[...,None], (1, 1, 3)) 70 | img2 = np.tile(img2[...,None], (1, 1, 3)) 71 | else: 72 | img1 = img1[..., :3] 73 | img2 = img2[..., :3] 74 | 75 | if self.augmentor is not None: 76 | if self.sparse: 77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 78 | else: 79 | img1, img2, flow = self.augmentor(img1, img2, flow) 80 | 81 | img1 = torch.from_numpy(img1).permute(2, 0, 
1).float() 82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 84 | 85 | if valid is not None: 86 | valid = torch.from_numpy(valid) 87 | else: 88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 89 | 90 | return img1, img2, flow, valid.float() 91 | 92 | 93 | def __rmul__(self, v): 94 | self.flow_list = v * self.flow_list 95 | self.image_list = v * self.image_list 96 | return self 97 | 98 | def __len__(self): 99 | return len(self.image_list) 100 | 101 | 102 | class MpiSintel(FlowDataset): 103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): 104 | super(MpiSintel, self).__init__(aug_params) 105 | flow_root = osp.join(root, split, 'flow') 106 | image_root = osp.join(root, split, dstype) 107 | 108 | if split == 'test': 109 | self.is_test = True 110 | 111 | for scene in os.listdir(image_root): 112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 113 | for i in range(len(image_list)-1): 114 | self.image_list += [ [image_list[i], image_list[i+1]] ] 115 | self.extra_info += [ (scene, i) ] # scene and frame_id 116 | 117 | if split != 'test': 118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 119 | 120 | 121 | class FlyingChairs(FlowDataset): 122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): 123 | super(FlyingChairs, self).__init__(aug_params) 124 | 125 | images = sorted(glob(osp.join(root, '*.ppm'))) 126 | flows = sorted(glob(osp.join(root, '*.flo'))) 127 | assert (len(images)//2 == len(flows)) 128 | 129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) 130 | for i in range(len(flows)): 131 | xid = split_list[i] 132 | if (split=='training' and xid==1) or (split=='validation' and xid==2): 133 | self.flow_list += [ flows[i] ] 134 | self.image_list += [ [images[2*i], images[2*i+1]] ] 135 | 136 | 137 | class FlyingThings3D(FlowDataset): 138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): 139 | super(FlyingThings3D, self).__init__(aug_params) 140 | 141 | for cam in ['left']: 142 | for direction in ['into_future', 'into_past']: 143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) 144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 145 | 146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) 147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 148 | 149 | for idir, fdir in zip(image_dirs, flow_dirs): 150 | images = sorted(glob(osp.join(idir, '*.png')) ) 151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) ) 152 | for i in range(len(flows)-1): 153 | if direction == 'into_future': 154 | self.image_list += [ [images[i], images[i+1]] ] 155 | self.flow_list += [ flows[i] ] 156 | elif direction == 'into_past': 157 | self.image_list += [ [images[i+1], images[i]] ] 158 | self.flow_list += [ flows[i+1] ] 159 | 160 | 161 | class KITTI(FlowDataset): 162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): 163 | super(KITTI, self).__init__(aug_params, sparse=True) 164 | if split == 'testing': 165 | self.is_test = True 166 | 167 | root = osp.join(root, split) 168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) 169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 170 | 171 | for img1, img2 in zip(images1, images2): 172 | frame_id = img1.split('/')[-1] 173 | self.extra_info += [ [frame_id] ] 174 | self.image_list += [ [img1, 
img2] ] 175 | 176 | if split == 'training': 177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 178 | 179 | 180 | class HD1K(FlowDataset): 181 | def __init__(self, aug_params=None, root='datasets/HD1k'): 182 | super(HD1K, self).__init__(aug_params, sparse=True) 183 | 184 | seq_ix = 0 185 | while 1: 186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 188 | 189 | if len(flows) == 0: 190 | break 191 | 192 | for i in range(len(flows)-1): 193 | self.flow_list += [flows[i]] 194 | self.image_list += [ [images[i], images[i+1]] ] 195 | 196 | seq_ix += 1 197 | 198 | 199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): 200 | """ Create the data loader for the corresponding trainign set """ 201 | 202 | if args.stage == 'chairs': 203 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} 204 | train_dataset = FlyingChairs(aug_params, split='training') 205 | 206 | elif args.stage == 'things': 207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} 208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') 209 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') 210 | train_dataset = clean_dataset + final_dataset 211 | 212 | elif args.stage == 'sintel': 213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} 214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass') 215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') 216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final') 217 | 218 | if TRAIN_DS == 'C+T+K+S+H': 219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) 220 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) 221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things 222 | 223 | elif TRAIN_DS == 'C+T+K/S': 224 | train_dataset = 100*sintel_clean + 100*sintel_final + things 225 | 226 | elif args.stage == 'kitti': 227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} 228 | train_dataset = KITTI(aug_params, split='training') 229 | 230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True) 232 | 233 | print('Training with %d image pairs' % len(train_dataset)) 234 | return train_loader 235 | 236 | -------------------------------------------------------------------------------- /data_preprocessing/RAFT/core/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if 
not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | 48 | def forward(self, x): 49 | y = x 50 | y = self.relu(self.norm1(self.conv1(y))) 51 | y = self.relu(self.norm2(self.conv2(y))) 52 | 53 | if self.downsample is not None: 54 | x = self.downsample(x) 55 | 56 | return self.relu(x+y) 57 | 58 | 59 | 60 | class BottleneckBlock(nn.Module): 61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 62 | super(BottleneckBlock, self).__init__() 63 | 64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | num_groups = planes // 8 70 | 71 | if norm_fn == 'group': 72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 75 | if not stride == 1: 76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 77 | 78 | elif norm_fn == 'batch': 79 | self.norm1 = nn.BatchNorm2d(planes//4) 80 | self.norm2 = nn.BatchNorm2d(planes//4) 81 | self.norm3 = nn.BatchNorm2d(planes) 82 | if not stride == 1: 83 | self.norm4 = nn.BatchNorm2d(planes) 84 | 85 | elif norm_fn == 'instance': 86 | self.norm1 = nn.InstanceNorm2d(planes//4) 87 | self.norm2 = nn.InstanceNorm2d(planes//4) 88 | self.norm3 = nn.InstanceNorm2d(planes) 89 | if not stride == 1: 90 | self.norm4 = nn.InstanceNorm2d(planes) 91 | 92 | elif norm_fn == 'none': 93 | self.norm1 = nn.Sequential() 94 | self.norm2 = nn.Sequential() 95 | self.norm3 = nn.Sequential() 96 | if not stride == 1: 97 | self.norm4 = nn.Sequential() 98 | 99 | if stride == 1: 100 | self.downsample = None 101 | 102 | else: 103 | self.downsample = nn.Sequential( 104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 105 | 106 | 107 | def forward(self, x): 108 | y = x 109 | y = self.relu(self.norm1(self.conv1(y))) 110 | y = self.relu(self.norm2(self.conv2(y))) 111 | y = self.relu(self.norm3(self.conv3(y))) 112 | 113 | if self.downsample is not None: 114 | x = self.downsample(x) 115 | 116 | return self.relu(x+y) 117 | 118 | class BasicEncoder(nn.Module): 119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 120 | super(BasicEncoder, self).__init__() 121 | self.norm_fn = norm_fn 122 | 123 | if self.norm_fn == 'group': 124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 125 | 126 | elif self.norm_fn == 'batch': 127 | self.norm1 = nn.BatchNorm2d(64) 128 | 129 | elif self.norm_fn == 'instance': 130 | self.norm1 = nn.InstanceNorm2d(64) 131 | 132 | elif 
self.norm_fn == 'none': 133 | self.norm1 = nn.Sequential() 134 | 135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 136 | self.relu1 = nn.ReLU(inplace=True) 137 | 138 | self.in_planes = 64 139 | self.layer1 = self._make_layer(64, stride=1) 140 | self.layer2 = self._make_layer(96, stride=2) 141 | self.layer3 = self._make_layer(128, stride=2) 142 | 143 | # output convolution 144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 145 | 146 | self.dropout = None 147 | if dropout > 0: 148 | self.dropout = nn.Dropout2d(p=dropout) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 154 | if m.weight is not None: 155 | nn.init.constant_(m.weight, 1) 156 | if m.bias is not None: 157 | nn.init.constant_(m.bias, 0) 158 | 159 | def _make_layer(self, dim, stride=1): 160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 162 | layers = (layer1, layer2) 163 | 164 | self.in_planes = dim 165 | return nn.Sequential(*layers) 166 | 167 | 168 | def forward(self, x): 169 | 170 | # if input is list, combine batch dimension 171 | is_list = isinstance(x, tuple) or isinstance(x, list) 172 | if is_list: 173 | batch_dim = x[0].shape[0] 174 | x = torch.cat(x, dim=0) 175 | 176 | x = self.conv1(x) 177 | x = self.norm1(x) 178 | x = self.relu1(x) 179 | 180 | x = self.layer1(x) 181 | x = self.layer2(x) 182 | x = self.layer3(x) 183 | 184 | x = self.conv2(x) 185 | 186 | if self.training and self.dropout is not None: 187 | x = self.dropout(x) 188 | 189 | if is_list: 190 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 191 | 192 | return x 193 | 194 | 195 | class SmallEncoder(nn.Module): 196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 197 | super(SmallEncoder, self).__init__() 198 | self.norm_fn = norm_fn 199 | 200 | if self.norm_fn == 'group': 201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 202 | 203 | elif self.norm_fn == 'batch': 204 | self.norm1 = nn.BatchNorm2d(32) 205 | 206 | elif self.norm_fn == 'instance': 207 | self.norm1 = nn.InstanceNorm2d(32) 208 | 209 | elif self.norm_fn == 'none': 210 | self.norm1 = nn.Sequential() 211 | 212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 213 | self.relu1 = nn.ReLU(inplace=True) 214 | 215 | self.in_planes = 32 216 | self.layer1 = self._make_layer(32, stride=1) 217 | self.layer2 = self._make_layer(64, stride=2) 218 | self.layer3 = self._make_layer(96, stride=2) 219 | 220 | self.dropout = None 221 | if dropout > 0: 222 | self.dropout = nn.Dropout2d(p=dropout) 223 | 224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 225 | 226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 230 | if m.weight is not None: 231 | nn.init.constant_(m.weight, 1) 232 | if m.bias is not None: 233 | nn.init.constant_(m.bias, 0) 234 | 235 | def _make_layer(self, dim, stride=1): 236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 238 | layers = (layer1, layer2) 239 | 240 | self.in_planes = dim 241 | return nn.Sequential(*layers) 242 | 243 | 244 | def forward(self, x): 245 | 246 | # if input is list, 
combine batch dimension 247 | is_list = isinstance(x, tuple) or isinstance(x, list) 248 | if is_list: 249 | batch_dim = x[0].shape[0] 250 | x = torch.cat(x, dim=0) 251 | 252 | x = self.conv1(x) 253 | x = self.norm1(x) 254 | x = self.relu1(x) 255 | 256 | x = self.layer1(x) 257 | x = self.layer2(x) 258 | x = self.layer3(x) 259 | x = self.conv2(x) 260 | 261 | if self.training and self.dropout is not None: 262 | x = self.dropout(x) 263 | 264 | if is_list: 265 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 266 | 267 | return x 268 | -------------------------------------------------------------------------------- /data_preprocessing/RAFT/core/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from update import BasicUpdateBlock, SmallUpdateBlock 7 | from extractor import BasicEncoder, SmallEncoder 8 | from corr import CorrBlock, AlternateCorrBlock 9 | from utils.utils import bilinear_sampler, coords_grid, upflow8 10 | 11 | try: 12 | autocast = torch.cuda.amp.autocast 13 | except: 14 | # dummy autocast for PyTorch < 1.6 15 | class autocast: 16 | def __init__(self, enabled): 17 | pass 18 | def __enter__(self): 19 | pass 20 | def __exit__(self, *args): 21 | pass 22 | 23 | 24 | class RAFT(nn.Module): 25 | def __init__(self, args): 26 | super(RAFT, self).__init__() 27 | self.args = args 28 | 29 | if args.small: 30 | self.hidden_dim = hdim = 96 31 | self.context_dim = cdim = 64 32 | args.corr_levels = 4 33 | args.corr_radius = 3 34 | 35 | else: 36 | self.hidden_dim = hdim = 128 37 | self.context_dim = cdim = 128 38 | args.corr_levels = 4 39 | args.corr_radius = 4 40 | 41 | if 'dropout' not in self.args: 42 | self.args.dropout = 0 43 | 44 | if 'alternate_corr' not in self.args: 45 | self.args.alternate_corr = False 46 | 47 | # feature network, context network, and update block 48 | if args.small: 49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) 50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) 51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 52 | 53 | else: 54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) 56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 57 | 58 | def freeze_bn(self): 59 | for m in self.modules(): 60 | if isinstance(m, nn.BatchNorm2d): 61 | m.eval() 62 | 63 | def initialize_flow(self, img): 64 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 65 | N, C, H, W = img.shape 66 | coords0 = coords_grid(N, H//8, W//8, device=img.device) 67 | coords1 = coords_grid(N, H//8, W//8, device=img.device) 68 | 69 | # optical flow computed as difference: flow = coords1 - coords0 70 | return coords0, coords1 71 | 72 | def upsample_flow(self, flow, mask): 73 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 74 | N, _, H, W = flow.shape 75 | mask = mask.view(N, 1, 9, 8, 8, H, W) 76 | mask = torch.softmax(mask, dim=2) 77 | 78 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 79 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 80 | 81 | up_flow = torch.sum(mask * up_flow, dim=2) 82 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 83 | return up_flow.reshape(N, 2, 8*H, 8*W) 84 | 85 | 86 | def forward(self, image1, image2, iters=12, 
flow_init=None, upsample=True, test_mode=False): 87 | """ Estimate optical flow between pair of frames """ 88 | 89 | image1 = 2 * (image1 / 255.0) - 1.0 90 | image2 = 2 * (image2 / 255.0) - 1.0 91 | 92 | image1 = image1.contiguous() 93 | image2 = image2.contiguous() 94 | 95 | hdim = self.hidden_dim 96 | cdim = self.context_dim 97 | 98 | # run the feature network 99 | with autocast(enabled=self.args.mixed_precision): 100 | fmap1, fmap2 = self.fnet([image1, image2]) 101 | 102 | fmap1 = fmap1.float() 103 | fmap2 = fmap2.float() 104 | if self.args.alternate_corr: 105 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 106 | else: 107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 108 | 109 | # run the context network 110 | with autocast(enabled=self.args.mixed_precision): 111 | cnet = self.cnet(image1) 112 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 113 | net = torch.tanh(net) 114 | inp = torch.relu(inp) 115 | 116 | coords0, coords1 = self.initialize_flow(image1) 117 | 118 | if flow_init is not None: 119 | coords1 = coords1 + flow_init 120 | 121 | flow_predictions = [] 122 | for itr in range(iters): 123 | coords1 = coords1.detach() 124 | corr = corr_fn(coords1) # index correlation volume 125 | 126 | flow = coords1 - coords0 127 | with autocast(enabled=self.args.mixed_precision): 128 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 129 | 130 | # F(t+1) = F(t) + \Delta(t) 131 | coords1 = coords1 + delta_flow 132 | 133 | # upsample predictions 134 | if up_mask is None: 135 | flow_up = upflow8(coords1 - coords0) 136 | else: 137 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 138 | 139 | flow_predictions.append(flow_up) 140 | 141 | if test_mode: 142 | return coords1 - coords0, flow_up 143 | 144 | return flow_predictions 145 | -------------------------------------------------------------------------------- /data_preprocessing/RAFT/core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, 
hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, args, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = SmallMotionEncoder(args) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, args, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.args = args 118 | self.encoder = BasicMotionEncoder(args) 119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9, 1, padding=0)) 126 | 127 | def forward(self, net, inp, corr, flow, upsample=True): 128 | motion_features = self.encoder(flow, corr) 129 | inp = torch.cat([inp, motion_features], dim=1) 130 | 131 | net = self.gru(net, inp) 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale mask to balence gradients 135 | 
mask = .25 * self.mask(net) 136 | return net, mask, delta_flow 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /data_preprocessing/RAFT/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/data_preprocessing/RAFT/core/utils/__init__.py -------------------------------------------------------------------------------- /data_preprocessing/RAFT/core/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import math 4 | from PIL import Image 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | import torch 11 | from torchvision.transforms import ColorJitter 12 | import torch.nn.functional as F 13 | 14 | 15 | class FlowAugmentor: 16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 17 | 18 | # spatial augmentation params 19 | self.crop_size = crop_size 20 | self.min_scale = min_scale 21 | self.max_scale = max_scale 22 | self.spatial_aug_prob = 0.8 23 | self.stretch_prob = 0.8 24 | self.max_stretch = 0.2 25 | 26 | # flip augmentation params 27 | self.do_flip = do_flip 28 | self.h_flip_prob = 0.5 29 | self.v_flip_prob = 0.1 30 | 31 | # photometric augmentation params 32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 33 | self.asymmetric_color_aug_prob = 0.2 34 | self.eraser_aug_prob = 0.5 35 | 36 | def color_transform(self, img1, img2): 37 | """ Photometric augmentation """ 38 | 39 | # asymmetric 40 | if np.random.rand() < self.asymmetric_color_aug_prob: 41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 43 | 44 | # symmetric 45 | else: 46 | image_stack = np.concatenate([img1, img2], axis=0) 47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 48 | img1, img2 = np.split(image_stack, 2, axis=0) 49 | 50 | return img1, img2 51 | 52 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 53 | """ Occlusion augmentation """ 54 | 55 | ht, wd = img1.shape[:2] 56 | if np.random.rand() < self.eraser_aug_prob: 57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 58 | for _ in range(np.random.randint(1, 3)): 59 | x0 = np.random.randint(0, wd) 60 | y0 = np.random.randint(0, ht) 61 | dx = np.random.randint(bounds[0], bounds[1]) 62 | dy = np.random.randint(bounds[0], bounds[1]) 63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 64 | 65 | return img1, img2 66 | 67 | def spatial_transform(self, img1, img2, flow): 68 | # randomly sample scale 69 | ht, wd = img1.shape[:2] 70 | min_scale = np.maximum( 71 | (self.crop_size[0] + 8) / float(ht), 72 | (self.crop_size[1] + 8) / float(wd)) 73 | 74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 75 | scale_x = scale 76 | scale_y = scale 77 | if np.random.rand() < self.stretch_prob: 78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 80 | 81 | scale_x = np.clip(scale_x, min_scale, None) 82 | scale_y = np.clip(scale_y, min_scale, None) 83 | 84 | if np.random.rand() < self.spatial_aug_prob: 85 | # rescale the images 86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 87 | img2 = cv2.resize(img2, 
None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 89 | flow = flow * [scale_x, scale_y] 90 | 91 | if self.do_flip: 92 | if np.random.rand() < self.h_flip_prob: # h-flip 93 | img1 = img1[:, ::-1] 94 | img2 = img2[:, ::-1] 95 | flow = flow[:, ::-1] * [-1.0, 1.0] 96 | 97 | if np.random.rand() < self.v_flip_prob: # v-flip 98 | img1 = img1[::-1, :] 99 | img2 = img2[::-1, :] 100 | flow = flow[::-1, :] * [1.0, -1.0] 101 | 102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 104 | 105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 108 | 109 | return img1, img2, flow 110 | 111 | def __call__(self, img1, img2, flow): 112 | img1, img2 = self.color_transform(img1, img2) 113 | img1, img2 = self.eraser_transform(img1, img2) 114 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 115 | 116 | img1 = np.ascontiguousarray(img1) 117 | img2 = np.ascontiguousarray(img2) 118 | flow = np.ascontiguousarray(flow) 119 | 120 | return img1, img2, flow 121 | 122 | class SparseFlowAugmentor: 123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): 124 | # spatial augmentation params 125 | self.crop_size = crop_size 126 | self.min_scale = min_scale 127 | self.max_scale = max_scale 128 | self.spatial_aug_prob = 0.8 129 | self.stretch_prob = 0.8 130 | self.max_stretch = 0.2 131 | 132 | # flip augmentation params 133 | self.do_flip = do_flip 134 | self.h_flip_prob = 0.5 135 | self.v_flip_prob = 0.1 136 | 137 | # photometric augmentation params 138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 139 | self.asymmetric_color_aug_prob = 0.2 140 | self.eraser_aug_prob = 0.5 141 | 142 | def color_transform(self, img1, img2): 143 | image_stack = np.concatenate([img1, img2], axis=0) 144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 145 | img1, img2 = np.split(image_stack, 2, axis=0) 146 | return img1, img2 147 | 148 | def eraser_transform(self, img1, img2): 149 | ht, wd = img1.shape[:2] 150 | if np.random.rand() < self.eraser_aug_prob: 151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 152 | for _ in range(np.random.randint(1, 3)): 153 | x0 = np.random.randint(0, wd) 154 | y0 = np.random.randint(0, ht) 155 | dx = np.random.randint(50, 100) 156 | dy = np.random.randint(50, 100) 157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 158 | 159 | return img1, img2 160 | 161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 162 | ht, wd = flow.shape[:2] 163 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 164 | coords = np.stack(coords, axis=-1) 165 | 166 | coords = coords.reshape(-1, 2).astype(np.float32) 167 | flow = flow.reshape(-1, 2).astype(np.float32) 168 | valid = valid.reshape(-1).astype(np.float32) 169 | 170 | coords0 = coords[valid>=1] 171 | flow0 = flow[valid>=1] 172 | 173 | ht1 = int(round(ht * fy)) 174 | wd1 = int(round(wd * fx)) 175 | 176 | coords1 = coords0 * [fx, fy] 177 | flow1 = flow0 * [fx, fy] 178 | 179 | xx = np.round(coords1[:,0]).astype(np.int32) 180 | yy = np.round(coords1[:,1]).astype(np.int32) 181 | 182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 183 | xx = xx[v] 184 | yy = yy[v] 185 | flow1 = flow1[v] 186 | 187 | flow_img = 
np.zeros([ht1, wd1, 2], dtype=np.float32) 188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 189 | 190 | flow_img[yy, xx] = flow1 191 | valid_img[yy, xx] = 1 192 | 193 | return flow_img, valid_img 194 | 195 | def spatial_transform(self, img1, img2, flow, valid): 196 | # randomly sample scale 197 | 198 | ht, wd = img1.shape[:2] 199 | min_scale = np.maximum( 200 | (self.crop_size[0] + 1) / float(ht), 201 | (self.crop_size[1] + 1) / float(wd)) 202 | 203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 204 | scale_x = np.clip(scale, min_scale, None) 205 | scale_y = np.clip(scale, min_scale, None) 206 | 207 | if np.random.rand() < self.spatial_aug_prob: 208 | # rescale the images 209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 212 | 213 | if self.do_flip: 214 | if np.random.rand() < 0.5: # h-flip 215 | img1 = img1[:, ::-1] 216 | img2 = img2[:, ::-1] 217 | flow = flow[:, ::-1] * [-1.0, 1.0] 218 | valid = valid[:, ::-1] 219 | 220 | margin_y = 20 221 | margin_x = 50 222 | 223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 225 | 226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 228 | 229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 233 | return img1, img2, flow, valid 234 | 235 | 236 | def __call__(self, img1, img2, flow, valid): 237 | img1, img2 = self.color_transform(img1, img2) 238 | img1, img2 = self.eraser_transform(img1, img2) 239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 240 | 241 | img1 = np.ascontiguousarray(img1) 242 | img2 = np.ascontiguousarray(img2) 243 | flow = np.ascontiguousarray(flow) 244 | valid = np.ascontiguousarray(valid) 245 | 246 | return img1, img2, flow, valid 247 | -------------------------------------------------------------------------------- /data_preprocessing/RAFT/core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. 
"A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /data_preprocessing/RAFT/core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /data_preprocessing/RAFT/core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | 
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd, device): 75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /data_preprocessing/RAFT/demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | import argparse 5 | import os 6 | import cv2 7 | import glob 8 | import numpy as np 9 | import torch 10 | from PIL import Image 11 | import torch.nn.functional as F 12 | 13 | from raft import RAFT 14 | from utils import flow_viz 15 | from utils.utils import InputPadder 16 | 17 | DEVICE = 'cuda' 18 | 19 | 20 | def load_image(imfile): 21 | img = np.array(Image.open(imfile)).astype(np.uint8) 22 | img = torch.from_numpy(img).permute(2, 0, 1).float() 23 | return img[None].to(DEVICE) 24 | 25 | 26 | def viz(img, flo,img_name=None): 27 | img = img[0].permute(1,2,0).cpu().numpy() 28 | flo = flo[0].permute(1,2,0).cpu().numpy() 29 | 30 | # map flow to rgb image 31 | flo = flow_viz.flow_to_image(flo) 32 | img_flo = np.concatenate([img, flo], axis=0) 33 | 34 | cv2.imwrite(f'{img_name}', img_flo[:, :, [2,1,0]]) 35 | 36 | 37 | def demo(args): 38 | model = torch.nn.DataParallel(RAFT(args)) 39 | model.load_state_dict(torch.load(args.model)) 40 | 41 | model = model.module 42 | model.to(DEVICE) 43 | model.eval() 44 | os.makedirs(args.outdir, exist_ok=True) 45 | os.makedirs(args.outdir_conf, exist_ok=True) 46 | with torch.no_grad(): 47 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 48 | glob.glob(os.path.join(args.path, '*.jpg')) 49 | 50 | images = sorted(images) 51 | i=0 52 | for imfile1, imfile2 in zip(images[:-1], images[1:]): 53 | image1 = load_image(imfile1) 54 | image2 = load_image(imfile2) 55 | if args.if_mask: 56 | mk_file1=imfile1.split("/") 57 | mk_file1[-2]=f"{args.name}_masks" 58 | mk_file1='/'.join(mk_file1) 59 | mk_file2=imfile2.split("/") 60 | mk_file2[-2]=f"{args.name}_masks" 61 | mk_file2='/'.join(mk_file2) 62 | mask1=cv2.imread(mk_file1.replace('jpg','png') 63 | ,0) 64 | mask2=cv2.imread(mk_file2.replace('jpg','png'), 65 | 0) 66 | mask1=torch.from_numpy(mask1).to(DEVICE).float() 67 | mask2=torch.from_numpy(mask2).to(DEVICE).float() 68 | mask1[mask1>0]=1 69 | mask2[mask2>0]=1 70 | image1*=mask1 71 | image2*=mask2 72 | 73 | padder = InputPadder(image1.shape) 74 | image1, image2 = padder.pad(image1, image2) 75 | if args.if_mask: 76 | mask1,mask2=padder.pad(mask1.unsqueeze(0).unsqueeze(0), 77 | mask2.unsqueeze(0).unsqueeze(0)) 78 | 
mask1=mask1.squeeze() 79 | mask2=mask2.squeeze() 80 | 81 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 82 | flow_low_, flow_up_ = model(image2, image1, iters=20, test_mode=True) 83 | flow_1to2 = flow_up.clone() 84 | flow_2to1 = flow_up_.clone() 85 | 86 | _,_,H,W=image1.shape 87 | x = torch.linspace(0, 1, W) 88 | y = torch.linspace(0, 1, H) 89 | grid_x,grid_y=torch.meshgrid(x,y) 90 | grid=torch.stack([grid_x,grid_y],dim=0).to(DEVICE) 91 | grid=grid.permute(0,2,1) 92 | grid[0]*=W 93 | grid[1]*=H 94 | if args.if_mask: 95 | flow_up[:,:,mask1.long()==0]=10000 96 | grid_=grid+flow_up.squeeze() 97 | 98 | grid_norm=grid_.clone() 99 | grid_norm[0,...]=2*grid_norm[0,...]/(W-1)-1 100 | grid_norm[1,...]=2*grid_norm[1,...]/(H-1)-1 101 | 102 | flow_bilinear_=F.grid_sample(flow_up_,grid_norm.unsqueeze(0).permute(0,2,3,1),mode='bilinear',padding_mode='zeros') 103 | 104 | rgb_bilinear_=F.grid_sample(image2,grid_norm.unsqueeze(0).permute(0,2,3,1),mode='bilinear',padding_mode='zeros') 105 | rgb_np=rgb_bilinear_.squeeze().permute(1,2,0).cpu().numpy()[:, :, ::-1] 106 | cv2.imwrite(f'{args.outdir}/warped.png',rgb_np) 107 | 108 | if args.confidence: 109 | ### Calculate confidence map using cycle consistency. 110 | # 1). First calculate `warped_image2` by the following formula: 111 | # warped_image2 = F.grid_sample(image1, flow_2to1) 112 | # 2). Then calculate `warped_image1` by the following formula: 113 | # warped_image1 = F.grid_sample(warped_image2, flow_1to2) 114 | # 3) Finally calculate the confidence map: 115 | # confidence_map = metric_func(image1 - warped_image1) 116 | 117 | grid_2to1 = grid + flow_2to1.squeeze() 118 | norm_grid_2to1 = grid_2to1.clone() 119 | norm_grid_2to1[0, ...] = 2 * norm_grid_2to1[0, ...] / (W - 1) - 1 120 | norm_grid_2to1[1, ...] = 2 * norm_grid_2to1[1, ...] / (H - 1) - 1 121 | warped_image2 = F.grid_sample(image1, norm_grid_2to1.unsqueeze(0).permute(0,2,3,1), mode='bilinear', padding_mode='zeros') 122 | 123 | grid_1to2 = grid + flow_1to2.squeeze() 124 | norm_grid_1to2 = grid_1to2.clone() 125 | norm_grid_1to2[0, ...] = 2 * norm_grid_1to2[0, ...] / (W - 1) - 1 126 | norm_grid_1to2[1, ...] = 2 * norm_grid_1to2[1, ...] 
/ (H - 1) - 1 127 | warped_image1 = F.grid_sample(warped_image2, norm_grid_1to2.unsqueeze(0).permute(0,2,3,1), mode='bilinear', padding_mode='zeros') 128 | 129 | error = torch.abs(image1 - warped_image1) 130 | confidence_map = torch.mean(error, dim=1, keepdim=True) 131 | confidence_map[confidence_map < args.thres] = 1 132 | confidence_map[confidence_map >= args.thres] = 0 133 | 134 | grid_bck=grid+flow_up.squeeze()+flow_bilinear_.squeeze() 135 | res=grid-grid_bck 136 | res=torch.norm(res,dim=0) 137 | mk=(res<10)&(flow_up.norm(dim=1).squeeze()>5) 138 | 139 | pts_src=grid[:,mk] 140 | 141 | pts_dst=(grid[:,mk]+flow_up.squeeze()[:,mk]) 142 | 143 | pts_src=pts_src.permute(1,0).cpu().numpy() 144 | pts_dst=pts_dst.permute(1,0).cpu().numpy() 145 | indx=torch.randperm(pts_src.shape[0])[:30] 146 | # use cv2 to draw the matches in image1 and image2 147 | img_new=np.zeros((H,W*2,3),dtype=np.uint8) 148 | img_new[:,:W,:]=image1[0].permute(1,2,0).cpu().numpy() 149 | img_new[:,W:,:]=image2[0].permute(1,2,0).cpu().numpy() 150 | 151 | for j in indx: 152 | cv2.line(img_new,(int(pts_src[j,0]),int(pts_src[j,1])),(int(pts_dst[j,0])+W,int(pts_dst[j,1])),(0,255,0),1) 153 | 154 | cv2.imwrite(f'{args.outdir}/matches.png',img_new) 155 | 156 | np.save(f'{args.outdir}/{i:06d}.npy', flow_up.cpu().numpy()) 157 | if args.confidence: 158 | np.save(f'{args.outdir_conf}/{i:06d}_c.npy', confidence_map.cpu().numpy()) 159 | i += 1 160 | 161 | viz(image1, flow_up,f'{args.outdir}/flow_up{i:03d}.png') 162 | 163 | 164 | if __name__ == '__main__': 165 | parser = argparse.ArgumentParser() 166 | parser.add_argument('--model', help="restore checkpoint") 167 | parser.add_argument('--path', help="dataset for evaluation") 168 | parser.add_argument('--outdir',help="directory for the ouput the result") 169 | parser.add_argument('--small', action='store_true', help='use small model') 170 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 171 | parser.add_argument('--if_mask', action='store_true', help='if using the image mask to mask the color img') 172 | parser.add_argument('--confidence', action='store_true', help='if saving the confidence map') 173 | parser.add_argument('--discrete', action='store_true', help='if saving the confidence map in discrete') 174 | parser.add_argument('--thres', default=4, help='Threshold value for confidence map') 175 | parser.add_argument('--outdir_conf', help="directory to save flow confidence") 176 | parser.add_argument('--name', help="the name of a sequence") 177 | args = parser.parse_args() 178 | 179 | demo(args) 180 | -------------------------------------------------------------------------------- /data_preprocessing/RAFT/models/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip 3 | unzip models.zip 4 | -------------------------------------------------------------------------------- /data_preprocessing/RAFT/run_raft.sh: -------------------------------------------------------------------------------- 1 | NAME=beauty_1 2 | ROOT_DIR=/home/xxx/code/CoDeF/all_sequences 3 | CODE_DIR=/home/xxx/code/CoDeF/data_preprocessing/RAFT 4 | 5 | IMG_DIR=$ROOT_DIR/${NAME}/${NAME} 6 | FLOW_DIR=$ROOT_DIR/${NAME}/${NAME}_flow 7 | CONF_DIR=${FLOW_DIR}_confidence 8 | 9 | CUDA_VISIBLE_DEVICES=0 \ 10 | python ${CODE_DIR}/demo.py \ 11 | --model=${CODE_DIR}/models/raft-sintel.pth \ 12 | --path=$IMG_DIR \ 13 | --outdir=$FLOW_DIR \ 14 | --name=$NAME \ 15 | 
--confidence \ 16 | --outdir_conf=$CONF_DIR 17 | -------------------------------------------------------------------------------- /data_preprocessing/preproc_mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | from glob import glob 5 | from tqdm import tqdm 6 | 7 | root_dir = '/home/xxx/code/CoDeF/all_sequences' 8 | name = 'beauty_1' 9 | 10 | msk_folder = f'{root_dir}/{name}/{name}_masks' 11 | img_folder = f'{root_dir}/{name}/{name}' 12 | frg_mask_folder = f'{root_dir}/{name}/{name}_masks_0' 13 | bkg_mask_folder = f'{root_dir}/{name}/{name}_masks_1' 14 | os.makedirs(frg_mask_folder, exist_ok=True) 15 | os.makedirs(bkg_mask_folder, exist_ok=True) 16 | 17 | files = glob(msk_folder + '/*.png') 18 | num = len(files) 19 | 20 | for i in tqdm(range(num)): 21 | file_n = os.path.basename(files[i]) 22 | mask = cv2.imread(os.path.join(msk_folder, file_n), 0) 23 | mask[mask > 0] = 1 24 | cv2.imwrite(os.path.join(frg_mask_folder, file_n), mask * 255) 25 | 26 | bg_mask = mask.copy() 27 | bg_mask[bg_mask == 0] = 127 28 | bg_mask[bg_mask == 255] = 0 29 | bg_mask[bg_mask == 127] = 255 30 | cv2.imwrite(os.path.join(bkg_mask_folder, file_n), bg_mask) -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed_weighted_sampler import DistributedWeightedSampler 2 | from .video_dataset import VideoDataset 3 | 4 | dataset_dict = {'video': VideoDataset} 5 | 6 | custom_sampler_dict = {'weighted': DistributedWeightedSampler} -------------------------------------------------------------------------------- /datasets/distributed_weighted_sampler.py: -------------------------------------------------------------------------------- 1 | # Combine weighted sampler and distributed sampler 2 | import math 3 | from typing import TypeVar, Optional, Iterator 4 | 5 | import torch 6 | import numpy as np 7 | from torch.utils.data import Sampler, Dataset 8 | import torch.distributed as dist 9 | 10 | 11 | T_co = TypeVar('T_co', covariant=True) 12 | 13 | 14 | class DistributedWeightedSampler(Sampler[T_co]): 15 | r"""Sampler that restricts data loading to a subset of the dataset. 16 | 17 | It is especially useful in conjunction with 18 | :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each 19 | process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a 20 | :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the 21 | original dataset that is exclusive to it. 22 | 23 | .. note:: 24 | Dataset is assumed to be of constant size. 25 | 26 | Args: 27 | dataset: Dataset used for sampling. 28 | num_replicas (int, optional): Number of processes participating in 29 | distributed training. By default, :attr:`world_size` is retrieved from the 30 | current distributed group. 31 | rank (int, optional): Rank of the current process within :attr:`num_replicas`. 32 | By default, :attr:`rank` is retrieved from the current distributed 33 | group. 34 | shuffle (bool, optional): If ``True`` (default), sampler will shuffle the 35 | indices. 36 | seed (int, optional): random seed used to shuffle the sampler if 37 | :attr:`shuffle=True`. This number should be identical across all 38 | processes in the distributed group. Default: ``0``. 
39 | drop_last (bool, optional): if ``True``, then the sampler will drop the 40 | tail of the data to make it evenly divisible across the number of 41 | replicas. If ``False``, the sampler will add extra indices to make 42 | the data evenly divisible across the replicas. Default: ``False``. 43 | 44 | .. warning:: 45 | In distributed mode, calling the :meth:`set_epoch` method at 46 | the beginning of each epoch **before** creating the :class:`DataLoader` iterator 47 | is necessary to make shuffling work properly across multiple epochs. Otherwise, 48 | the same ordering will be always used. 49 | 50 | Example:: 51 | 52 | >>> sampler = DistributedSampler(dataset) if is_distributed else None 53 | >>> loader = DataLoader(dataset, shuffle=(sampler is None), 54 | ... sampler=sampler) 55 | >>> for epoch in range(start_epoch, n_epochs): 56 | ... if is_distributed: 57 | ... sampler.set_epoch(epoch) 58 | ... train(loader) 59 | """ 60 | 61 | def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, 62 | rank: Optional[int] = None, shuffle: bool = True, 63 | seed: int = 0, drop_last: bool = False, replacement: bool = True) -> None: 64 | if num_replicas is None: 65 | if not dist.is_available(): 66 | raise RuntimeError("Requires distributed package to be available") 67 | num_replicas = dist.get_world_size() 68 | if rank is None: 69 | if not dist.is_available(): 70 | raise RuntimeError("Requires distributed package to be available") 71 | rank = dist.get_rank() 72 | if rank >= num_replicas or rank < 0: 73 | raise ValueError( 74 | "Invalid rank {}, rank should be in the interval" 75 | " [0, {}]".format(rank, num_replicas - 1)) 76 | self.dataset = dataset 77 | self.num_replicas = num_replicas 78 | self.rank = rank 79 | self.epoch = 0 80 | self.drop_last = drop_last 81 | # If the dataset length is evenly divisible by # of replicas, then there 82 | # is no need to drop any data, since the dataset will be split equally. 83 | if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] 84 | # Split to nearest available length that is evenly divisible. 85 | # This is to ensure each rank receives the same amount of data when 86 | # using this Sampler. 87 | self.num_samples = math.ceil( 88 | # `type:ignore` is required because Dataset cannot provide a default __len__ 89 | # see NOTE in pytorch/torch/utils/data/sampler.py 90 | (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] 91 | ) 92 | else: 93 | self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] 94 | self.total_size = self.num_samples * self.num_replicas 95 | self.shuffle = shuffle 96 | self.seed = seed 97 | self.weights = self.dataset.weights 98 | self.replacement = replacement 99 | 100 | # def calculate_weights(self, targets): 101 | # class_sample_count = torch.tensor( 102 | # [(targets == t).sum() for t in torch.unique(targets, sorted=True)]) 103 | # weight = 1. 
/ class_sample_count.double() 104 | # samples_weight = torch.tensor([weight[t] for t in targets]) 105 | # return samples_weight 106 | 107 | def __iter__(self) -> Iterator[T_co]: 108 | if self.shuffle: 109 | # deterministically shuffle based on epoch and seed 110 | g = torch.Generator() 111 | g.manual_seed(self.seed + self.epoch) 112 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] 113 | else: 114 | indices = list(range(len(self.dataset))) # type: ignore[arg-type] 115 | 116 | if not self.drop_last: 117 | # add extra samples to make it evenly divisible 118 | padding_size = self.total_size - len(indices) 119 | if padding_size <= len(indices): 120 | indices += indices[:padding_size] 121 | else: 122 | indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] 123 | else: 124 | # remove tail of data to make it evenly divisible. 125 | indices = indices[:self.total_size] 126 | assert len(indices) == self.total_size 127 | 128 | # subsample 129 | indices = indices[self.rank:self.total_size:self.num_replicas] 130 | assert len(indices) == self.num_samples 131 | 132 | # subsample weights 133 | # targets = self.targets[indices] 134 | weights = self.weights[indices][:, 0] 135 | assert len(weights) == self.num_samples 136 | 137 | ########################################################################### 138 | # the upper bound category number of multinomial is 2^24, to handle this we can use chunk or using random choices 139 | # subsample_balanced_indicies = torch.multinomial(weights, self.num_samples, self.replacement) 140 | ########################################################################### 141 | # using random choices 142 | rand_tensor = np.random.choice(range(0, len(weights)), 143 | size=self.num_samples, 144 | p=weights.numpy() / torch.sum(weights).numpy(), 145 | replace=self.replacement) 146 | subsample_balanced_indicies = torch.from_numpy(rand_tensor) 147 | dataset_indices = torch.tensor(indices)[subsample_balanced_indicies] 148 | 149 | return iter(dataset_indices.tolist()) 150 | 151 | 152 | def __len__(self) -> int: 153 | return self.num_samples 154 | 155 | def set_epoch(self, epoch: int) -> None: 156 | r""" 157 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas 158 | use a different random ordering for each epoch. Otherwise, the next iteration of this 159 | sampler will yield the same ordering. 160 | 161 | Args: 162 | epoch (int): Epoch number. 
163 | """ 164 | self.epoch = epoch -------------------------------------------------------------------------------- /datasets/video_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import os 5 | from PIL import Image 6 | from einops import rearrange, reduce, repeat 7 | from torchvision import transforms as T 8 | import glob 9 | import cv2 10 | 11 | # The basic dataset of reading rays 12 | class VideoDataset(Dataset): 13 | 14 | def __init__(self, 15 | root_dir, 16 | split='train', 17 | img_wh=(504, 378), 18 | mask_dir=None, 19 | flow_dir=None, 20 | canonical_wh=None, 21 | ref_idx=None, 22 | canonical_dir=None, 23 | test=False): 24 | self.test = test 25 | self.root_dir = root_dir 26 | self.split = split 27 | self.img_wh = img_wh 28 | self.mask_dir = mask_dir 29 | self.flow_dir = flow_dir 30 | self.canonical_wh = canonical_wh 31 | self.ref_idx = ref_idx 32 | self.canonical_dir = canonical_dir 33 | self.read_meta() 34 | 35 | def read_meta(self): 36 | all_images_path = [] 37 | self.ts_w = [] 38 | self.all_images = [] 39 | h = self.img_wh[1] 40 | w = self.img_wh[0] 41 | # construct grid 42 | grid = np.indices((h, w)).astype(np.float32) 43 | # normalize 44 | grid[0,:,:] = grid[0,:,:] / h 45 | grid[1,:,:] = grid[1,:,:] / w 46 | self.grid = torch.from_numpy(rearrange(grid, 'c h w -> (h w) c')) 47 | warp_code = 1 48 | for input_image_path in sorted(glob.glob(f'{self.root_dir}/*')): 49 | print(input_image_path) 50 | all_images_path.append(input_image_path) 51 | self.ts_w.append(torch.Tensor([warp_code]).long()) 52 | warp_code += 1 53 | 54 | if self.canonical_wh: 55 | h_c = self.canonical_wh[1] 56 | w_c = self.canonical_wh[0] 57 | grid_c = np.indices((h_c, w_c)).astype(np.float32) 58 | grid_c[0,:,:] = (grid_c[0,:,:] - (h_c - h) / 2) / h 59 | grid_c[1,:,:] = (grid_c[1,:,:] - (w_c - w) / 2) / w 60 | self.grid_c = torch.from_numpy(rearrange(grid_c, 'c h w -> (h w) c')) 61 | else: 62 | self.grid_c = self.grid 63 | self.canonical_wh = self.img_wh 64 | 65 | if self.mask_dir: 66 | self.all_masks = [] 67 | if self.flow_dir: 68 | self.all_flows = [] 69 | else: 70 | self.all_flows = None 71 | 72 | if self.split == 'train' or self.split == 'val': 73 | if self.canonical_dir is not None: 74 | all_images_path_ = sorted(glob.glob(f'{self.canonical_dir}/*.png')) 75 | self.canonical_img = [] 76 | for input_image_path in all_images_path_: 77 | input_image = cv2.imread(input_image_path) 78 | input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) 79 | input_image = cv2.resize(input_image, (self.canonical_wh[0], self.canonical_wh[1]), interpolation = cv2.INTER_AREA) 80 | input_image_tensor = torch.from_numpy(input_image).float() / 256 81 | self.canonical_img.append(input_image_tensor) 82 | self.canonical_img = torch.stack(self.canonical_img, dim=0) 83 | 84 | for input_image_path in all_images_path: 85 | input_image = cv2.imread(input_image_path) 86 | input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) 87 | input_image = cv2.resize(input_image, (self.img_wh[0], self.img_wh[1]), interpolation = cv2.INTER_AREA) 88 | input_image_tensor = torch.from_numpy(input_image).float() / 256 89 | self.all_images.append(input_image_tensor) 90 | if self.mask_dir: 91 | input_image_name = input_image_path.split("/")[-1][:-4] 92 | for i in range(len(self.mask_dir)): 93 | input_mask = cv2.imread(f'{self.mask_dir[i]}/{input_image_name}.png') 94 | input_mask = cv2.resize(input_mask, (self.img_wh[0], 
self.img_wh[1]), interpolation = cv2.INTER_AREA) 95 | input_mask_tensor = torch.from_numpy(input_mask).float() / 256 96 | self.all_masks.append(input_mask_tensor) 97 | 98 | if self.split == 'val': 99 | input_image = cv2.imread(all_images_path[0]) 100 | input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) 101 | input_image = cv2.resize(input_image, (self.img_wh[0], self.img_wh[1]), interpolation = cv2.INTER_AREA) 102 | input_image_tensor = torch.from_numpy(input_image).float() / 256 103 | self.all_images.append(input_image_tensor) 104 | if self.mask_dir: 105 | input_image_name = all_images_path[0].split("/")[-1][:-4] 106 | for i in range(len(self.mask_dir)): 107 | input_mask = cv2.imread(f'{self.mask_dir[i]}/{input_image_name}.png') 108 | input_mask = cv2.resize(input_mask, (self.img_wh[0], self.img_wh[1]), interpolation = cv2.INTER_AREA) 109 | input_mask_tensor = torch.from_numpy(input_mask).float() / 256 110 | self.all_masks.append(input_mask_tensor) 111 | 112 | if self.flow_dir: 113 | for input_image_path in sorted(glob.glob(f'{self.flow_dir}/*npy')): 114 | flow_load=np.load(input_image_path) # (1, 2, h, w) 115 | flow_tensor=torch.from_numpy(flow_load).float()[:, [1, 0]] 116 | flow_tensor=torch.nn.functional.interpolate(flow_tensor,size=(self.img_wh[1],self.img_wh[0])) 117 | H_,W_=flow_load.shape[-2],flow_load.shape[-1] 118 | flow_tensor=flow_tensor.reshape(2,-1).transpose(1,0) 119 | flow_tensor[..., 0] /= W_ 120 | flow_tensor[..., 1] /= H_ 121 | self.all_flows.append(flow_tensor) 122 | 123 | i = 0 124 | for input_image_path in sorted(glob.glob(f'{self.flow_dir}_confidence/*npy')): 125 | flow_load=np.load(input_image_path) 126 | flow_tensor=torch.from_numpy(flow_load).float() 127 | flow_tensor=torch.nn.functional.interpolate(flow_tensor,size=(self.img_wh[1],self.img_wh[0])) 128 | flow_tensor=flow_tensor.reshape(1,-1).transpose(1,0) 129 | flow_tensor = flow_tensor.sum(dim=-1) < 0.05 130 | self.all_flows[i][flow_tensor] = 5 131 | i += 1 132 | 133 | if self.split == 'val': 134 | self.ref_idx = 0 135 | 136 | def __len__(self): 137 | if self.test: 138 | return len(self.all_images) 139 | return 200 * len(self.all_images) 140 | 141 | def __getitem__(self, idx): 142 | if self.split == 'train' or self.split == 'val': 143 | idx = idx % len(self.all_images) 144 | sample = {'rgbs': self.all_images[idx], 145 | 'canonical_img': self.all_images[idx] if self.canonical_dir is None else self.canonical_img, 146 | 'ts_w': self.ts_w[idx], 147 | 'grid': self.grid, 148 | 'canonical_wh': self.canonical_wh, 149 | 'img_wh': self.img_wh, 150 | 'masks': self.all_masks[len(self.mask_dir)*idx:len(self.mask_dir)*idx+len(self.mask_dir)] if self.mask_dir else [torch.ones((self.img_wh[1], self.img_wh[0], 1))], 151 | 'flows': self.all_flows[idx] if (idx 1 && arguments[1] !== undefined ? arguments[1] : {}; 107 | 108 | _classCallCheck(this, bulmaSlider); 109 | 110 | var _this = _possibleConstructorReturn(this, (bulmaSlider.__proto__ || Object.getPrototypeOf(bulmaSlider)).call(this)); 111 | 112 | _this.element = typeof selector === 'string' ? document.querySelector(selector) : selector; 113 | // An invalid selector or non-DOM node has been provided. 
114 | if (!_this.element) { 115 | throw new Error('An invalid selector or non-DOM node has been provided.'); 116 | } 117 | 118 | _this._clickEvents = ['click']; 119 | /// Set default options and merge with instance defined 120 | _this.options = _extends({}, options); 121 | 122 | _this.onSliderInput = _this.onSliderInput.bind(_this); 123 | 124 | _this.init(); 125 | return _this; 126 | } 127 | 128 | /** 129 | * Initiate all DOM element containing selector 130 | * @method 131 | * @return {Array} Array of all slider instances 132 | */ 133 | 134 | 135 | _createClass(bulmaSlider, [{ 136 | key: 'init', 137 | 138 | 139 | /** 140 | * Initiate plugin 141 | * @method init 142 | * @return {void} 143 | */ 144 | value: function init() { 145 | this._id = 'bulmaSlider' + new Date().getTime() + Math.floor(Math.random() * Math.floor(9999)); 146 | this.output = this._findOutputForSlider(); 147 | 148 | this._bindEvents(); 149 | 150 | if (this.output) { 151 | if (this.element.classList.contains('has-output-tooltip')) { 152 | // Get new output position 153 | var newPosition = this._getSliderOutputPosition(); 154 | 155 | // Set output position 156 | this.output.style['left'] = newPosition.position; 157 | } 158 | } 159 | 160 | this.emit('bulmaslider:ready', this.element.value); 161 | } 162 | }, { 163 | key: '_findOutputForSlider', 164 | value: function _findOutputForSlider() { 165 | var _this2 = this; 166 | 167 | var result = null; 168 | var outputs = document.getElementsByTagName('output') || []; 169 | 170 | Array.from(outputs).forEach(function (output) { 171 | if (output.htmlFor == _this2.element.getAttribute('id')) { 172 | result = output; 173 | return true; 174 | } 175 | }); 176 | return result; 177 | } 178 | }, { 179 | key: '_getSliderOutputPosition', 180 | value: function _getSliderOutputPosition() { 181 | // Update output position 182 | var newPlace, minValue; 183 | 184 | var style = window.getComputedStyle(this.element, null); 185 | // Measure width of range input 186 | var sliderWidth = parseInt(style.getPropertyValue('width'), 10); 187 | 188 | // Figure out placement percentage between left and right of input 189 | if (!this.element.getAttribute('min')) { 190 | minValue = 0; 191 | } else { 192 | minValue = this.element.getAttribute('min'); 193 | } 194 | var newPoint = (this.element.value - minValue) / (this.element.getAttribute('max') - minValue); 195 | 196 | // Prevent bubble from going beyond left or right (unsupported browsers) 197 | if (newPoint < 0) { 198 | newPlace = 0; 199 | } else if (newPoint > 1) { 200 | newPlace = sliderWidth; 201 | } else { 202 | newPlace = sliderWidth * newPoint; 203 | } 204 | 205 | return { 206 | 'position': newPlace + 'px' 207 | }; 208 | } 209 | 210 | /** 211 | * Bind all events 212 | * @method _bindEvents 213 | * @return {void} 214 | */ 215 | 216 | }, { 217 | key: '_bindEvents', 218 | value: function _bindEvents() { 219 | if (this.output) { 220 | // Add event listener to update output when slider value change 221 | this.element.addEventListener('input', this.onSliderInput, false); 222 | } 223 | } 224 | }, { 225 | key: 'onSliderInput', 226 | value: function onSliderInput(e) { 227 | e.preventDefault(); 228 | 229 | if (this.element.classList.contains('has-output-tooltip')) { 230 | // Get new output position 231 | var newPosition = this._getSliderOutputPosition(); 232 | 233 | // Set output position 234 | this.output.style['left'] = newPosition.position; 235 | } 236 | 237 | // Check for prefix and postfix 238 | var prefix = this.output.hasAttribute('data-prefix') ? 
this.output.getAttribute('data-prefix') : ''; 239 | var postfix = this.output.hasAttribute('data-postfix') ? this.output.getAttribute('data-postfix') : ''; 240 | 241 | // Update output with slider value 242 | this.output.value = prefix + this.element.value + postfix; 243 | 244 | this.emit('bulmaslider:ready', this.element.value); 245 | } 246 | }], [{ 247 | key: 'attach', 248 | value: function attach() { 249 | var _this3 = this; 250 | 251 | var selector = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : 'input[type="range"].slider'; 252 | var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {}; 253 | 254 | var instances = new Array(); 255 | 256 | var elements = isString(selector) ? document.querySelectorAll(selector) : Array.isArray(selector) ? selector : [selector]; 257 | elements.forEach(function (element) { 258 | if (typeof element[_this3.constructor.name] === 'undefined') { 259 | var instance = new bulmaSlider(element, options); 260 | element[_this3.constructor.name] = instance; 261 | instances.push(instance); 262 | } else { 263 | instances.push(element[_this3.constructor.name]); 264 | } 265 | }); 266 | 267 | return instances; 268 | } 269 | }]); 270 | 271 | return bulmaSlider; 272 | }(__WEBPACK_IMPORTED_MODULE_0__events__["a" /* default */]); 273 | 274 | /* harmony default export */ __webpack_exports__["default"] = (bulmaSlider); 275 | 276 | /***/ }), 277 | /* 1 */ 278 | /***/ (function(module, __webpack_exports__, __webpack_require__) { 279 | 280 | "use strict"; 281 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }(); 282 | 283 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 284 | 285 | var EventEmitter = function () { 286 | function EventEmitter() { 287 | var listeners = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : []; 288 | 289 | _classCallCheck(this, EventEmitter); 290 | 291 | this._listeners = new Map(listeners); 292 | this._middlewares = new Map(); 293 | } 294 | 295 | _createClass(EventEmitter, [{ 296 | key: "listenerCount", 297 | value: function listenerCount(eventName) { 298 | if (!this._listeners.has(eventName)) { 299 | return 0; 300 | } 301 | 302 | var eventListeners = this._listeners.get(eventName); 303 | return eventListeners.length; 304 | } 305 | }, { 306 | key: "removeListeners", 307 | value: function removeListeners() { 308 | var _this = this; 309 | 310 | var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 311 | var middleware = arguments.length > 1 && arguments[1] !== undefined ? 
arguments[1] : false; 312 | 313 | if (eventName !== null) { 314 | if (Array.isArray(eventName)) { 315 | name.forEach(function (e) { 316 | return _this.removeListeners(e, middleware); 317 | }); 318 | } else { 319 | this._listeners.delete(eventName); 320 | 321 | if (middleware) { 322 | this.removeMiddleware(eventName); 323 | } 324 | } 325 | } else { 326 | this._listeners = new Map(); 327 | } 328 | } 329 | }, { 330 | key: "middleware", 331 | value: function middleware(eventName, fn) { 332 | var _this2 = this; 333 | 334 | if (Array.isArray(eventName)) { 335 | name.forEach(function (e) { 336 | return _this2.middleware(e, fn); 337 | }); 338 | } else { 339 | if (!Array.isArray(this._middlewares.get(eventName))) { 340 | this._middlewares.set(eventName, []); 341 | } 342 | 343 | this._middlewares.get(eventName).push(fn); 344 | } 345 | } 346 | }, { 347 | key: "removeMiddleware", 348 | value: function removeMiddleware() { 349 | var _this3 = this; 350 | 351 | var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 352 | 353 | if (eventName !== null) { 354 | if (Array.isArray(eventName)) { 355 | name.forEach(function (e) { 356 | return _this3.removeMiddleware(e); 357 | }); 358 | } else { 359 | this._middlewares.delete(eventName); 360 | } 361 | } else { 362 | this._middlewares = new Map(); 363 | } 364 | } 365 | }, { 366 | key: "on", 367 | value: function on(name, callback) { 368 | var _this4 = this; 369 | 370 | var once = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false; 371 | 372 | if (Array.isArray(name)) { 373 | name.forEach(function (e) { 374 | return _this4.on(e, callback); 375 | }); 376 | } else { 377 | name = name.toString(); 378 | var split = name.split(/,|, | /); 379 | 380 | if (split.length > 1) { 381 | split.forEach(function (e) { 382 | return _this4.on(e, callback); 383 | }); 384 | } else { 385 | if (!Array.isArray(this._listeners.get(name))) { 386 | this._listeners.set(name, []); 387 | } 388 | 389 | this._listeners.get(name).push({ once: once, callback: callback }); 390 | } 391 | } 392 | } 393 | }, { 394 | key: "once", 395 | value: function once(name, callback) { 396 | this.on(name, callback, true); 397 | } 398 | }, { 399 | key: "emit", 400 | value: function emit(name, data) { 401 | var _this5 = this; 402 | 403 | var silent = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false; 404 | 405 | name = name.toString(); 406 | var listeners = this._listeners.get(name); 407 | var middlewares = null; 408 | var doneCount = 0; 409 | var execute = silent; 410 | 411 | if (Array.isArray(listeners)) { 412 | listeners.forEach(function (listener, index) { 413 | // Start Middleware checks unless we're doing a silent emit 414 | if (!silent) { 415 | middlewares = _this5._middlewares.get(name); 416 | // Check and execute Middleware 417 | if (Array.isArray(middlewares)) { 418 | middlewares.forEach(function (middleware) { 419 | middleware(data, function () { 420 | var newData = arguments.length > 0 && arguments[0] !== undefined ? 
arguments[0] : null; 421 | 422 | if (newData !== null) { 423 | data = newData; 424 | } 425 | doneCount++; 426 | }, name); 427 | }); 428 | 429 | if (doneCount >= middlewares.length) { 430 | execute = true; 431 | } 432 | } else { 433 | execute = true; 434 | } 435 | } 436 | 437 | // If Middleware checks have been passed, execute 438 | if (execute) { 439 | if (listener.once) { 440 | listeners[index] = null; 441 | } 442 | listener.callback(data); 443 | } 444 | }); 445 | 446 | // Dirty way of removing used Events 447 | while (listeners.indexOf(null) !== -1) { 448 | listeners.splice(listeners.indexOf(null), 1); 449 | } 450 | } 451 | } 452 | }]); 453 | 454 | return EventEmitter; 455 | }(); 456 | 457 | /* harmony default export */ __webpack_exports__["a"] = (EventEmitter); 458 | 459 | /***/ }) 460 | /******/ ])["default"]; 461 | }); -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); -------------------------------------------------------------------------------- /docs/static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | var INTERP_BASE = "./static/interpolation/stacked"; 4 | var NUM_INTERP_FRAMES = 240; 5 | 6 | var interp_images = []; 7 | function preloadInterpolationImages() { 8 | for (var i = 0; i < NUM_INTERP_FRAMES; i++) { 9 | var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg'; 10 | interp_images[i] = new Image(); 11 | interp_images[i].src = path; 12 | } 13 | } 14 | 15 | function setInterpolationImage(i) { 16 | var image = interp_images[i]; 17 | image.ondragstart = function() { return false; }; 18 | image.oncontextmenu = function() { return false; }; 19 | $('#interpolation-image-wrapper').empty().append(image); 20 | } 21 | 22 | 23 | $(document).ready(function() { 24 | // Check for click events on the navbar burger icon 25 | $(".navbar-burger").click(function() { 26 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu" 27 | $(".navbar-burger").toggleClass("is-active"); 28 | $(".navbar-menu").toggleClass("is-active"); 29 | 30 | }); 31 | 32 | var options = { 33 | slidesToScroll: 1, 34 | slidesToShow: 3, 35 | loop: true, 36 | infinite: true, 37 | autoplay: false, 38 | autoplaySpeed: 3000, 39 | } 40 | 41 | // Initialize all div with carousel class 42 | var carousels = 
bulmaCarousel.attach('.carousel', options); 43 | 44 | // Loop on each carousel initialized 45 | for(var i = 0; i < carousels.length; i++) { 46 | // Add listener to event 47 | carousels[i].on('before:show', state => { 48 | console.log(state); 49 | }); 50 | } 51 | 52 | // Access to bulmaCarousel instance of an element 53 | var element = document.querySelector('#my-element'); 54 | if (element && element.bulmaCarousel) { 55 | // bulmaCarousel instance is available as element.bulmaCarousel 56 | element.bulmaCarousel.on('before-show', function(state) { 57 | console.log(state); 58 | }); 59 | } 60 | 61 | /*var player = document.getElementById('interpolation-video'); 62 | player.addEventListener('loadedmetadata', function() { 63 | $('#interpolation-slider').on('input', function(event) { 64 | console.log(this.value, player.duration); 65 | player.currentTime = player.duration / 100 * this.value; 66 | }) 67 | }, false);*/ 68 | preloadInterpolationImages(); 69 | 70 | $('#interpolation-slider').on('input', function(event) { 71 | setInterpolationImage(this.value); 72 | }); 73 | setInterpolationImage(0); 74 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1); 75 | 76 | bulmaSlider.attach(); 77 | 78 | }) 79 | -------------------------------------------------------------------------------- /docs/static/js/video_comparison.js: -------------------------------------------------------------------------------- 1 | // This is based on: http://thenewcode.com/364/Interactive-Before-and-After-Video-Comparison-in-HTML5-Canvas 2 | // With additional modifications based on: https://jsfiddle.net/7sk5k4gp/13/ 3 | 4 | function playVids(videoId) { 5 | var videoMerge = document.getElementById(videoId + "Merge"); 6 | var vid = document.getElementById(videoId); 7 | 8 | var position = 0.5; 9 | var vidWidth = vid.videoWidth/2; 10 | var vidHeight = vid.videoHeight; 11 | 12 | var mergeContext = videoMerge.getContext("2d"); 13 | 14 | 15 | if (vid.readyState > 3) { 16 | vid.play(); 17 | 18 | function trackLocation(e) { 19 | // Normalize to [0, 1] 20 | bcr = videoMerge.getBoundingClientRect(); 21 | position = ((e.pageX - bcr.x) / bcr.width); 22 | } 23 | function trackLocationTouch(e) { 24 | // Normalize to [0, 1] 25 | bcr = videoMerge.getBoundingClientRect(); 26 | position = ((e.touches[0].pageX - bcr.x) / bcr.width); 27 | } 28 | 29 | videoMerge.addEventListener("mousemove", trackLocation, false); 30 | videoMerge.addEventListener("touchstart", trackLocationTouch, false); 31 | videoMerge.addEventListener("touchmove", trackLocationTouch, false); 32 | 33 | 34 | function drawLoop() { 35 | mergeContext.drawImage(vid, 0, 0, vidWidth, vidHeight, 0, 0, vidWidth, vidHeight); 36 | var colStart = (vidWidth * position).clamp(0.0, vidWidth); 37 | var colWidth = (vidWidth - (vidWidth * position)).clamp(0.0, vidWidth); 38 | mergeContext.drawImage(vid, colStart+vidWidth, 0, colWidth, vidHeight, colStart, 0, colWidth, vidHeight); 39 | requestAnimationFrame(drawLoop); 40 | 41 | 42 | var arrowLength = 0.07 * vidHeight; 43 | var arrowheadWidth = 0.020 * vidHeight; 44 | var arrowheadLength = 0.04 * vidHeight; 45 | var arrowPosY = vidHeight / 10; 46 | var arrowWidth = 0.007 * vidHeight; 47 | var currX = vidWidth * position; 48 | 49 | // Draw circle 50 | mergeContext.arc(currX, arrowPosY, arrowLength*0.7, 0, Math.PI * 2, false); 51 | mergeContext.fillStyle = "#FFD79340"; 52 | mergeContext.fill() 53 | //mergeContext.strokeStyle = "#444444"; 54 | //mergeContext.stroke() 55 | 56 | // Draw border 57 | mergeContext.beginPath(); 58 | 
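// Clarifying note (added): `position` is the pointer's x-coordinate normalized to [0, 1] by
// trackLocation/trackLocationTouch above, so the vertical divider below is drawn at x = vidWidth * position,
// matching the split point used when compositing the two half-width video frames.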
mergeContext.moveTo(vidWidth*position, 0); 59 | mergeContext.lineTo(vidWidth*position, vidHeight); 60 | mergeContext.closePath() 61 | mergeContext.strokeStyle = "#444444"; 62 | mergeContext.lineWidth = 3; 63 | mergeContext.stroke(); 64 | 65 | // Draw arrow 66 | mergeContext.beginPath(); 67 | mergeContext.moveTo(currX, arrowPosY - arrowWidth/2); 68 | 69 | // Move right until meeting arrow head 70 | mergeContext.lineTo(currX + arrowLength/2 - arrowheadLength/2, arrowPosY - arrowWidth/2); 71 | 72 | // Draw right arrow head 73 | mergeContext.lineTo(currX + arrowLength/2 - arrowheadLength/2, arrowPosY - arrowheadWidth/2); 74 | mergeContext.lineTo(currX + arrowLength/2, arrowPosY); 75 | mergeContext.lineTo(currX + arrowLength/2 - arrowheadLength/2, arrowPosY + arrowheadWidth/2); 76 | mergeContext.lineTo(currX + arrowLength/2 - arrowheadLength/2, arrowPosY + arrowWidth/2); 77 | 78 | // Go back to the left until meeting left arrow head 79 | mergeContext.lineTo(currX - arrowLength/2 + arrowheadLength/2, arrowPosY + arrowWidth/2); 80 | 81 | // Draw left arrow head 82 | mergeContext.lineTo(currX - arrowLength/2 + arrowheadLength/2, arrowPosY + arrowheadWidth/2); 83 | mergeContext.lineTo(currX - arrowLength/2, arrowPosY); 84 | mergeContext.lineTo(currX - arrowLength/2 + arrowheadLength/2, arrowPosY - arrowheadWidth/2); 85 | mergeContext.lineTo(currX - arrowLength/2 + arrowheadLength/2, arrowPosY); 86 | 87 | mergeContext.lineTo(currX - arrowLength/2 + arrowheadLength/2, arrowPosY - arrowWidth/2); 88 | mergeContext.lineTo(currX, arrowPosY - arrowWidth/2); 89 | 90 | mergeContext.closePath(); 91 | 92 | mergeContext.fillStyle = "#444444"; 93 | mergeContext.fill(); 94 | 95 | 96 | 97 | } 98 | requestAnimationFrame(drawLoop); 99 | } 100 | } 101 | 102 | Number.prototype.clamp = function(min, max) { 103 | return Math.min(Math.max(this, min), max); 104 | }; 105 | 106 | 107 | function resizeAndPlay(element) 108 | { 109 | var cv = document.getElementById(element.id + "Merge"); 110 | cv.width = element.videoWidth/2; 111 | cv.height = element.videoHeight; 112 | element.play(); 113 | element.style.height = "0px"; // Hide video without stopping it 114 | 115 | playVids(element.id); 116 | } 117 | -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1080x540/beauty_0_base_video_hash_transformed_dual(1).mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x540/beauty_0_base_video_hash_transformed_dual(1).mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1080x540/beauty_0_base_video_hash_transformed_dual(2).mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x540/beauty_0_base_video_hash_transformed_dual(2).mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1080x540/beauty_1_base_transformed_dual(1).mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x540/beauty_1_base_transformed_dual(1).mp4 
-------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1080x540/beauty_1_base_transformed_dual(2).mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x540/beauty_1_base_transformed_dual(2).mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1080x540/rainbow.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x540/rainbow.mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1080x540/tifa.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x540/tifa.mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1080x960/dog.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x960/dog.mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1080x960/scene_1_base_transformed_dual.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x960/scene_1_base_transformed_dual.mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1080x960/smoke_colorful.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x960/smoke_colorful.mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1080x960/smoke_ink.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1080x960/smoke_ink.mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1920x540/cloud_atlas4_base_transformed_dual.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1920x540/cloud_atlas4_base_transformed_dual.mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1920x540/diamond_1_base_transformed_dual.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1920x540/diamond_1_base_transformed_dual.mp4 
-------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1920x540/lemon_earth.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1920x540/lemon_earth.mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1920x540/long_season.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1920x540/long_season.mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1920x540/scene_3_base_transformed_dual.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1920x540/scene_3_base_transformed_dual.mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/1920x540/titanic.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/1920x540/titanic.mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/slider/slider.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/slider/slider.mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/teaser/cloud_atlas_SR_dual.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/teaser/cloud_atlas_SR_dual.mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/teaser/segtrack.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/teaser/segtrack.mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/teaser/teaser.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/teaser/teaser.mp4 -------------------------------------------------------------------------------- /docs/static/video_demos_compressed/teaser/tracking_zoom.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/static/video_demos_compressed/teaser/tracking_zoom.mp4 -------------------------------------------------------------------------------- /docs/teaser.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/CoDeF/b1d6f3677561675cf7672e9c60b865319847d254/docs/teaser.gif -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torchvision 4 | from einops import rearrange, reduce, repeat 5 | 6 | 7 | class MSELoss(nn.Module): 8 | def __init__(self, coef=1): 9 | super().__init__() 10 | self.coef = coef 11 | self.loss = nn.MSELoss(reduction='mean') 12 | 13 | def forward(self, inputs, targets): 14 | loss = self.loss(inputs, targets) 15 | return self.coef * loss 16 | 17 | 18 | def rgb_to_gray(image): 19 | gray_image = (0.299 * image[:, 0, :, :] + 0.587 * image[:, 1, :, :] + 20 | 0.114 * image[:, 2, :, :]) 21 | gray_image = gray_image.unsqueeze(1) 22 | 23 | return gray_image 24 | 25 | 26 | def compute_gradient_loss(pred, gt, mask): 27 | assert pred.shape == gt.shape, "a and b must have the same shape" 28 | 29 | pred = rgb_to_gray(pred) 30 | gt = rgb_to_gray(gt) 31 | 32 | sobel_kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=pred.dtype, device=pred.device) 33 | sobel_kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=pred.dtype, device=pred.device) 34 | 35 | gradient_a_x = torch.nn.functional.conv2d(pred.repeat(1,3,1,1), sobel_kernel_x.unsqueeze(0).unsqueeze(0).repeat(1,3,1,1), padding=1)/3 36 | gradient_a_y = torch.nn.functional.conv2d(pred.repeat(1,3,1,1), sobel_kernel_y.unsqueeze(0).unsqueeze(0).repeat(1,3,1,1), padding=1)/3 37 | # gradient_a_magnitude = torch.sqrt(gradient_a_x ** 2 + gradient_a_y ** 2) 38 | 39 | gradient_b_x = torch.nn.functional.conv2d(gt.repeat(1,3,1,1), sobel_kernel_x.unsqueeze(0).unsqueeze(0).repeat(1,3,1,1), padding=1)/3 40 | gradient_b_y = torch.nn.functional.conv2d(gt.repeat(1,3,1,1), sobel_kernel_y.unsqueeze(0).unsqueeze(0).repeat(1,3,1,1), padding=1)/3 41 | # gradient_b_magnitude = torch.sqrt(gradient_b_x ** 2 + gradient_b_y ** 2) 42 | 43 | pred_grad = torch.cat([gradient_a_x, gradient_a_y], dim=1) 44 | gt_grad = torch.cat([gradient_b_x, gradient_b_y], dim=1) 45 | 46 | gradient_difference = torch.abs(pred_grad - gt_grad).mean(dim=1,keepdim=True)[mask].sum()/(mask.sum()+1e-8) 47 | 48 | return gradient_difference 49 | 50 | 51 | loss_dict = {'mse': MSELoss} 52 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia.losses import ssim as dssim 3 | from skimage.metrics import structural_similarity 4 | from einops import rearrange 5 | import numpy as np 6 | 7 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'): 8 | value = (image_pred-image_gt)**2 9 | if valid_mask is not None: 10 | value = value[valid_mask] 11 | if reduction == 'mean': 12 | return torch.mean(value) 13 | return value 14 | 15 | def psnr(image_pred, image_gt, valid_mask=None, reduction='mean'): 16 | return -10*torch.log10(mse(image_pred, image_gt, valid_mask, reduction)) 17 | 18 | def ssim(image_pred, image_gt, reduction='mean'): 19 | return structural_similarity(image_pred.cpu().numpy(), image_gt, win_size=11, multichannel=True, gaussian_weights=True) 20 | 21 | def lpips(image_pred, image_gt, lpips_model): 22 | gt_lpips = image_gt * 2.0 - 1.0 23 | gt_lpips = rearrange(gt_lpips, '(b h) w c -> b c h w', b=1) 24 | 
gt_lpips = torch.from_numpy(gt_lpips) 25 | predict_image_lpips = image_pred.clone().detach().cpu() * 2.0 - 1.0 26 | predict_image_lpips = rearrange(predict_image_lpips, '(b h) w c -> b c h w', b=1) 27 | lpips_result = lpips_model.forward(predict_image_lpips, gt_lpips).cpu().detach().numpy() 28 | return np.squeeze(lpips_result) 29 | -------------------------------------------------------------------------------- /models/implicit_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | import tinycudann as tcnn 5 | 6 | 7 | def init_weights(m): 8 | if isinstance(m, nn.Linear): 9 | torch.nn.init.xavier_uniform_(m.weight) 10 | nn.init.zeros_(m.bias) 11 | 12 | 13 | class TranslationField(nn.Module): 14 | def __init__(self, D=6, W=128, 15 | in_channels_w=8, in_channels_xyz=34, 16 | skips=[4]): 17 | """ 18 | D: number of layers for density (sigma) encoder 19 | W: number of hidden units in each layer 20 | in_channels_xyz: number of input channels for xyz (2+2*8*2=34 by default) 21 | in_channels_w: number of channels for warping channels 22 | skips: add skip connection in the Dth layer 23 | """ 24 | super(TranslationField, self).__init__() 25 | self.D = D 26 | self.W = W 27 | self.skips = skips 28 | self.in_channels_xyz = in_channels_xyz 29 | self.in_channels_w = in_channels_w 30 | self.typ = "translation" 31 | 32 | # encoding layers 33 | for i in range(D): 34 | if i == 0: 35 | layer = nn.Linear(in_channels_xyz+self.in_channels_w, W) 36 | elif i in skips: 37 | layer = nn.Linear(W+in_channels_xyz+self.in_channels_w, W) 38 | else: 39 | layer = nn.Linear(W, W) 40 | init_weights(layer) 41 | layer = nn.Sequential(layer, nn.ReLU(True)) 42 | # init the models 43 | setattr(self, f"warping_field_xyz_encoding_{i+1}", layer) 44 | out_layer = nn.Linear(W, 2) 45 | nn.init.zeros_(out_layer.bias) 46 | nn.init.uniform_(out_layer.weight, -1e-4, 1e-4) 47 | self.output = nn.Sequential(out_layer) 48 | 49 | def forward(self, x): 50 | """ 51 | Encodes input xyz to warp field for points 52 | 53 | Inputs: 54 | x: (B, self.in_channels_xyz) 55 | the embedded vector of position and direction 56 | Outputs: 57 | t: warping field 58 | """ 59 | input_xyz = x 60 | 61 | xyz_ = input_xyz 62 | for i in range(self.D): 63 | if i in self.skips: 64 | xyz_ = torch.cat([input_xyz, xyz_], -1) 65 | xyz_ = getattr(self, f"warping_field_xyz_encoding_{i+1}")(xyz_) 66 | 67 | t = self.output(xyz_) 68 | 69 | return t 70 | 71 | 72 | class Embedding(nn.Module): 73 | def __init__(self, in_channels, N_freqs, logscale=True, identity=True): 74 | """ 75 | Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) 76 | in_channels: number of input channels (3 for both xyz and direction) 77 | """ 78 | super(Embedding, self).__init__() 79 | self.N_freqs = N_freqs 80 | self.annealed = False 81 | self.identity = identity 82 | self.in_channels = in_channels 83 | self.funcs = [torch.sin, torch.cos] 84 | self.out_channels = in_channels*(len(self.funcs)*N_freqs+1) 85 | 86 | if logscale: 87 | self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs) 88 | else: 89 | self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs) 90 | 91 | def forward(self, x): 92 | """ 93 | Embeds x to (x, sin(2^k x), cos(2^k x), ...) 
94 | Different from the paper, "x" is also in the output 95 | See https://github.com/bmild/nerf/issues/12 96 | 97 | Inputs: 98 | x: (B, self.in_channels) 99 | 100 | Outputs: 101 | out: (B, self.out_channels) 102 | """ 103 | if self.identity: 104 | out = [x] 105 | else: 106 | out = [] 107 | for freq in self.freq_bands: 108 | for func in self.funcs: 109 | out += [func(freq*x)] 110 | 111 | return torch.cat(out, -1) 112 | 113 | 114 | class AnnealedEmbedding(nn.Module): 115 | def __init__(self, in_channels, N_freqs, annealed_step, annealed_begin_step=0, logscale=True, identity=True): 116 | """ 117 | Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) 118 | in_channels: number of input channels (3 for both xyz and direction) 119 | """ 120 | super(AnnealedEmbedding, self).__init__() 121 | self.N_freqs = N_freqs 122 | self.in_channels = in_channels 123 | self.annealed = True 124 | self.annealed_step = annealed_step 125 | self.annealed_begin_step = annealed_begin_step 126 | self.funcs = [torch.sin, torch.cos] 127 | self.out_channels = in_channels*(len(self.funcs)*N_freqs+1) 128 | self.index = torch.linspace(0, N_freqs-1, N_freqs) 129 | self.identity = identity 130 | 131 | if logscale: 132 | self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs) 133 | else: 134 | self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs) 135 | 136 | def forward(self, x, step): 137 | """ 138 | Embeds x to (x, sin(2^k x), cos(2^k x), ...) 139 | Different from the paper, "x" is also in the output 140 | See https://github.com/bmild/nerf/issues/12 141 | 142 | Inputs: 143 | x: (B, self.in_channels) 144 | 145 | Outputs: 146 | out: (B, self.out_channels) 147 | """ 148 | if self.identity: 149 | out = [x] 150 | else: 151 | out = [] 152 | 153 | if self.annealed_begin_step == 0: 154 | # calculate the w for each freq bands 155 | alpha = self.N_freqs * step / float(self.annealed_step) 156 | else: 157 | if step <= self.annealed_begin_step: 158 | alpha = 0 159 | else: 160 | alpha = self.N_freqs * (step - self.annealed_begin_step) / float( 161 | self.annealed_step) 162 | 163 | for j, freq in enumerate(self.freq_bands): 164 | w = (1 - torch.cos( 165 | math.pi * torch.clamp(alpha - self.index[j], 0, 1))) / 2 166 | for func in self.funcs: 167 | out += [w * func(freq*x)] 168 | 169 | return torch.cat(out, -1) 170 | 171 | 172 | class AnnealedHash(nn.Module): 173 | def __init__(self, in_channels, annealed_step, annealed_begin_step=0, identity=True): 174 | """ 175 | Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) 176 | in_channels: number of input channels (3 for both xyz and direction) 177 | """ 178 | super(AnnealedHash, self).__init__() 179 | self.N_freqs = 16 180 | self.in_channels = in_channels 181 | self.annealed = True 182 | self.annealed_step = annealed_step 183 | self.annealed_begin_step = annealed_begin_step 184 | 185 | self.index = torch.linspace(0, self.N_freqs - 1, self.N_freqs) 186 | self.identity = identity 187 | 188 | self.index_2 = self.index.view(-1, 1).repeat(1, 2).view(-1) 189 | 190 | def forward(self, x_embed, step): 191 | """ 192 | Embeds x to (x, sin(2^k x), cos(2^k x), ...) 
193 | Different from the paper, "x" is also in the output 194 | See https://github.com/bmild/nerf/issues/12 195 | 196 | Inputs: 197 | x: (B, self.in_channels) 198 | 199 | Outputs: 200 | out: (B, self.out_channels) 201 | """ 202 | 203 | if self.annealed_begin_step == 0: 204 | # calculate the w for each freq bands 205 | alpha = self.N_freqs * step / float(self.annealed_step) 206 | else: 207 | if step <= self.annealed_begin_step: 208 | alpha = 0 209 | else: 210 | alpha = self.N_freqs * (step - self.annealed_begin_step) / float( 211 | self.annealed_step) 212 | 213 | w = (1 - torch.cos(math.pi * torch.clamp(alpha * torch.ones_like(self.index_2) - self.index_2, 0, 1))) / 2 214 | 215 | out = x_embed * w.to(x_embed.device) 216 | 217 | return out 218 | 219 | 220 | class ImplicitVideo(nn.Module): 221 | def __init__(self, 222 | D=8, W=256, 223 | in_channels_xyz=34, 224 | skips=[4], 225 | out_channels=3, 226 | sigmoid_offset=0): 227 | """ 228 | D: number of layers for density (sigma) encoder 229 | W: number of hidden units in each layer 230 | in_channels_xyz: number of input channels for xyz (3+3*8*2=51 by default) 231 | skips: add skip connection in the Dth layer 232 | 233 | ------ for nerfies ------ 234 | encode_warp: whether to encode warping 235 | in_channels_w: dimension of warping embeddings 236 | """ 237 | super(ImplicitVideo, self).__init__() 238 | self.D = D 239 | self.W = W 240 | self.in_channels_xyz = in_channels_xyz 241 | self.skips = skips 242 | self.in_channels_xyz = self.in_channels_xyz 243 | self.sigmoid_offset = sigmoid_offset 244 | 245 | # xyz encoding layers 246 | for i in range(D): 247 | if i == 0: 248 | layer = nn.Linear(self.in_channels_xyz, W) 249 | elif i in skips: 250 | layer = nn.Linear(W+self.in_channels_xyz, W) 251 | else: 252 | layer = nn.Linear(W, W) 253 | init_weights(layer) 254 | layer = nn.Sequential(layer, nn.ReLU(True)) 255 | setattr(self, f"xyz_encoding_{i+1}", layer) 256 | self.xyz_encoding_final = nn.Linear(W, W) 257 | init_weights(self.xyz_encoding_final) 258 | 259 | # output layers 260 | self.rgb = nn.Sequential(nn.Linear(W, out_channels)) 261 | 262 | self.rgb.apply(init_weights) 263 | 264 | def forward(self, x): 265 | """ 266 | Encodes input (xyz+dir) to rgb+sigma (not ready to render yet). 267 | For rendering this ray, please see rendering.py 268 | 269 | Inputs: 270 | x: (B, self.in_channels_xyz) 271 | the embedded vector of position and direction 272 | sigma_only: whether to infer sigma only. 
If True, 273 | x is of shape (B, self.in_channels_xyz) 274 | 275 | Outputs: 276 | if sigma_ony: 277 | sigma: (B, 1) sigma 278 | else: 279 | out: (B, 4), rgb and sigma 280 | """ 281 | input_xyz = x 282 | 283 | xyz_ = input_xyz 284 | for i in range(self.D): 285 | if i in self.skips: 286 | xyz_ = torch.cat([input_xyz, xyz_], -1) 287 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_) 288 | 289 | xyz_encoding_final = self.xyz_encoding_final(xyz_) 290 | out = self.rgb(xyz_encoding_final) 291 | 292 | out = torch.sigmoid(out) - self.sigmoid_offset 293 | 294 | return out 295 | 296 | 297 | class ImplicitVideo_Hash(nn.Module): 298 | def __init__(self, config): 299 | super().__init__() 300 | self.encoder = tcnn.Encoding(n_input_dims=2, 301 | encoding_config=config["encoding"]) 302 | self.decoder = tcnn.Network(n_input_dims=self.encoder.n_output_dims + 303 | 2, 304 | n_output_dims=3, 305 | network_config=config["network"]) 306 | 307 | def forward(self, x): 308 | input = x 309 | input = self.encoder(input) 310 | input = torch.cat([x, input], dim=-1) 311 | weight = torch.ones(input.shape[-1], device=input.device).cuda() 312 | x = self.decoder(weight * input) 313 | return x 314 | 315 | 316 | class Deform_Hash3d(nn.Module): 317 | def __init__(self, config): 318 | super().__init__() 319 | self.encoder = tcnn.Encoding(n_input_dims=3, 320 | encoding_config=config["encoding_deform3d"]) 321 | self.decoder = tcnn.Network(n_input_dims=self.encoder.n_output_dims + 3, 322 | n_output_dims=2, 323 | network_config=config["network_deform"]) 324 | 325 | def forward(self, x, step=0, aneal_func=None): 326 | input = x 327 | input = self.encoder(input) 328 | if aneal_func is not None: 329 | input = torch.cat([x, aneal_func(input,step)], dim=-1) 330 | else: 331 | input = torch.cat([x, input], dim=-1) 332 | 333 | weight = torch.ones(input.shape[-1], device=input.device).cuda() 334 | x = self.decoder(weight * input) / 5 335 | 336 | return x 337 | 338 | 339 | class Deform_Hash3d_Warp(nn.Module): 340 | def __init__(self, config): 341 | super().__init__() 342 | self.Deform_Hash3d = Deform_Hash3d(config) 343 | 344 | def forward(self, xyt_norm, step=0,aneal_func=None): 345 | x = self.Deform_Hash3d(xyt_norm,step=step, aneal_func=aneal_func) 346 | 347 | return x 348 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | 4 | 5 | def get_opts(): 6 | parser = argparse.ArgumentParser() 7 | 8 | # General Setttings 9 | parser.add_argument('--root_dir', type=str, default='Batman_masked_frames', 10 | help='root directory of dataset') 11 | parser.add_argument('--canonical_dir', type=str, default=None, 12 | help='directory of canonical dataset') 13 | 14 | # support multiple mask as input (each mask has different deformation fields) 15 | parser.add_argument('--mask_dir', nargs="+", type=str, default=None, 16 | help='mask of the dataset') 17 | parser.add_argument('--flow_dir', type=str, 18 | default=None, 19 | help='masks of dataset') 20 | parser.add_argument('--dataset_name', type=str, default='video', 21 | choices=['video'], 22 | help='which dataset to train/val') 23 | parser.add_argument('--img_wh', nargs="+", type=int, default=[842, 512], 24 | help='resolution (img_w, img_h) of the full image') 25 | parser.add_argument('--canonical_wh', nargs="+", type=int, default=None, 26 | help='default same as the img_wh, can be set to a larger range to include more content') 27 | 
parser.add_argument('--ref_idx', type=int, default=None, 28 | help='manually select a frame as reference (for rigid movement)') 29 | 30 | # Deformation Setting 31 | parser.add_argument('--encode_w', default=False, action="store_true", 32 | help='whether to apply warping') 33 | 34 | # Training Setttings 35 | 36 | parser.add_argument('--batch_size', type=int, default=1, 37 | help='batch size') 38 | parser.add_argument('--num_steps', type=int, default=10000, 39 | help='number of training epochs') 40 | parser.add_argument('--valid_iters', type=int, default=30, 41 | help='valid iters for each epoch') 42 | parser.add_argument('--valid_batches', type=int, default=0, 43 | help='valid batches for each valid process') 44 | parser.add_argument('--save_model_iters', type=int, default=5000, 45 | help='iterations to save the models') 46 | parser.add_argument('--gpus', nargs="+", type=int, default=[0], 47 | help='gpu devices') 48 | 49 | # Test Setttings 50 | parser.add_argument('--test', default=False, action="store_true", 51 | help='whether to disable identity') 52 | 53 | # Model Save and Load 54 | parser.add_argument('--ckpt_path', type=str, default=None, 55 | help='pretrained checkpoint to load (including optimizers, etc)') 56 | parser.add_argument('--prefixes_to_ignore', nargs='+', type=str, default=['loss'], 57 | help='the prefixes to ignore in the checkpoint state dict') 58 | parser.add_argument('--weight_path', type=str, default=None, 59 | help='pretrained model weight to load (do not load optimizers, etc)') 60 | parser.add_argument('--model_save_path', type=str, default='ckpts', 61 | help='save checkpoint to') 62 | parser.add_argument('--log_save_path', type=str, default='logs', 63 | help='save log to') 64 | parser.add_argument('--exp_name', type=str, default='exp', 65 | help='experiment name') 66 | 67 | # Optimize Settings 68 | parser.add_argument('--optimizer', type=str, default='adam', 69 | help='optimizer type', 70 | choices=['sgd', 'adam', 'radam', 'ranger']) 71 | parser.add_argument('--lr', type=float, default=5e-4, 72 | help='learning rate') 73 | parser.add_argument('--momentum', type=float, default=0.9, 74 | help='learning rate momentum') 75 | parser.add_argument('--weight_decay', type=float, default=0, 76 | help='weight decay') 77 | parser.add_argument('--lr_scheduler', type=str, default='steplr', 78 | help='scheduler type', 79 | choices=['steplr', 'cosine', 'poly', 'exponential']) 80 | 81 | #### params for steplr #### 82 | parser.add_argument('--decay_step', nargs='+', type=int, 83 | default=[2500, 5000, 7500], 84 | help='scheduler decay step') 85 | parser.add_argument('--decay_gamma', type=float, default=0.5, 86 | help='learning rate decay amount') 87 | 88 | #### params for warmup, only applied when optimizer == 'sgd' or 'adam' 89 | parser.add_argument('--warmup_multiplier', type=float, default=1.0, 90 | help='lr is multiplied by this factor after --warmup_epochs') 91 | parser.add_argument('--warmup_epochs', type=int, default=0, 92 | help='Gradually warm-up(increasing) learning rate in optimizer') 93 | 94 | ##### annealed positional encoding ###### 95 | parser.add_argument('--annealed', default=False, action="store_true", 96 | help='whether to apply annealed positional encoding (Only in the warping field)') 97 | parser.add_argument('--annealed_begin_step', type=int, default=0, 98 | help='annealed step to begin for positional encoding') 99 | parser.add_argument('--annealed_step', type=int, default=5000, 100 | help='maximum annealed step for positional encoding') 101 | 102 | ##### 
Additional losses ###### 103 | parser.add_argument('--flow_loss', type=float, default=None, 104 | help='optical flow loss weight') 105 | parser.add_argument('--bg_loss', type=float, default=None, 106 | help='regularize the rest part of each object ') 107 | parser.add_argument('--grad_loss', type=float, default=0.1, 108 | help='image gradient loss weight') 109 | parser.add_argument('--flow_step', type=int, default=-1, 110 | help='Step to begin to perform flow loss.') 111 | parser.add_argument('--ref_step', type=int, default=-1, 112 | help='Step to stop reference frame loss.') 113 | parser.add_argument('--self_bg', type=bool_parser, default=False, 114 | help='Whether to use self background as bg loss.') 115 | 116 | ##### Special cases: for black-dominated images 117 | parser.add_argument('--sigmoid_offset', type=float, default=0, 118 | help='whether to process balck-dominated images.') 119 | 120 | # Other miscellaneous settings. 121 | parser.add_argument('--save_deform', type=bool_parser, default=False, 122 | help='Whether to save deformation field or not.') 123 | parser.add_argument('--save_video', type=bool_parser, default=True, 124 | help='Whether to save video or not.') 125 | parser.add_argument('--fps', type=int, default=30, 126 | help='FPS of the saved video.') 127 | 128 | # Network settings for PE. 129 | parser.add_argument('--deform_D', type=int, default=6, 130 | help='The depth of deformation field MLP.') 131 | parser.add_argument('--deform_W', type=int, default=128, 132 | help='The width of deformation field MLP.') 133 | parser.add_argument('--vid_D', type=int, default=8, 134 | help='The depth of implicit video MLP.') 135 | parser.add_argument('--vid_W', type=int, default=256, 136 | help='The width of implicit video MLP.') 137 | parser.add_argument('--N_vocab_w', type=int, default=200, 138 | help='number of vocabulary for warp code in the dataset for nn.Embedding') 139 | parser.add_argument('--N_w', type=int, default=8, 140 | help='embeddings size for warping') 141 | parser.add_argument('--N_xyz_w', nargs="+", type=int, default=[8, 8], 142 | help='positional encoding frequency of deformation field') 143 | 144 | # Network settings for Hash, please see details in configs/hash.json 145 | parser.add_argument('--vid_hash', type=bool_parser, default=False, 146 | help='Whether to use hash encoding in implicit video system.') 147 | parser.add_argument('--deform_hash', type=bool_parser, default=False, 148 | help='Whether to use hash encoding in deformation field.') 149 | 150 | # Config files 151 | parser.add_argument('--config', type=str, default=None, 152 | help='path to the YAML config file.') 153 | 154 | args = parser.parse_args() 155 | 156 | if args.config is not None: 157 | with open(args.config, 'r') as f: 158 | config = yaml.safe_load(f) 159 | args_dict = vars(args) 160 | args_dict.update(config) 161 | args_new = argparse.Namespace(**args_dict) 162 | return args_new 163 | 164 | return args 165 | 166 | 167 | def bool_parser(arg): 168 | """Parses an argument to boolean.""" 169 | if isinstance(arg, bool): 170 | return arg 171 | if arg is None: 172 | return False 173 | if arg.lower() in ['1', 'true', 't', 'yes', 'y']: 174 | return True 175 | if arg.lower() in ['0', 'false', 'f', 'no', 'n']: 176 | return False 177 | raise ValueError(f'`{arg}` cannot be converted to boolean!') 178 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch_lightning==2.0.2 
2 | easydict==1.10 3 | einops==0.6.1 4 | ipdb==0.13.13 5 | numpy==1.24.3 6 | opencv-python-headless==4.5.5.62 7 | PyYAML==6.0 8 | scikit-image==0.21.0 9 | scipy==1.10.1 10 | sk-video==1.1.10 11 | tensorboard==2.13.0 12 | tqdm==4.65.0 13 | torch_optimizer==0.3.0 14 | kornia==0.5.10 15 | -------------------------------------------------------------------------------- /scripts/test_canonical.sh: -------------------------------------------------------------------------------- 1 | GPUS=0 2 | 3 | NAME=scene_0 4 | EXP_NAME=base 5 | 6 | ROOT_DIRECTORY="all_sequences/$NAME/$NAME" 7 | LOG_SAVE_PATH="logs/test_all_sequences/$NAME" 8 | 9 | MASK_DIRECTORY="all_sequences/$NAME/${NAME}_masks_0 all_sequences/$NAME/${NAME}_masks_1" 10 | 11 | CANONICAL_DIR="all_sequences/${NAME}/${EXP_NAME}_control" 12 | 13 | WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/${NAME}.ckpt 14 | # WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/step=10000.ckpt 15 | 16 | python train.py --test --encode_w \ 17 | --root_dir $ROOT_DIRECTORY \ 18 | --log_save_path $LOG_SAVE_PATH \ 19 | --mask_dir $MASK_DIRECTORY \ 20 | --weight_path $WEIGHT_PATH \ 21 | --gpus $GPUS \ 22 | --canonical_dir $CANONICAL_DIR \ 23 | --config configs/${NAME}/${EXP_NAME}.yaml \ 24 | --exp_name $EXP_NAME 25 | -------------------------------------------------------------------------------- /scripts/test_multi.sh: -------------------------------------------------------------------------------- 1 | GPUS=0 2 | 3 | NAME=scene_0 4 | EXP_NAME=base 5 | 6 | ROOT_DIRECTORY="all_sequences/$NAME/$NAME" 7 | LOG_SAVE_PATH="logs/test_all_sequences/$NAME" 8 | 9 | MASK_DIRECTORY="all_sequences/$NAME/${NAME}_masks_0 all_sequences/$NAME/${NAME}_masks_1" 10 | 11 | WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/${NAME}.ckpt 12 | # WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/step=10000.ckpt 13 | 14 | python train.py --test --encode_w \ 15 | --root_dir $ROOT_DIRECTORY \ 16 | --log_save_path $LOG_SAVE_PATH \ 17 | --mask_dir $MASK_DIRECTORY \ 18 | --weight_path $WEIGHT_PATH \ 19 | --gpus $GPUS \ 20 | --config configs/${NAME}/${EXP_NAME}.yaml \ 21 | --exp_name ${EXP_NAME} \ 22 | --save_deform False 23 | -------------------------------------------------------------------------------- /scripts/train_multi.sh: -------------------------------------------------------------------------------- 1 | GPUS=0 2 | 3 | NAME=scene_0 4 | EXP_NAME=base 5 | 6 | ROOT_DIRECTORY="all_sequences/$NAME/$NAME" 7 | MODEL_SAVE_PATH="ckpts/all_sequences/$NAME" 8 | LOG_SAVE_PATH="logs/all_sequences/$NAME" 9 | 10 | MASK_DIRECTORY="all_sequences/$NAME/${NAME}_masks_0 all_sequences/$NAME/${NAME}_masks_1" 11 | FLOW_DIRECTORY="all_sequences/$NAME/${NAME}_flow" 12 | 13 | python train.py --root_dir $ROOT_DIRECTORY \ 14 | --model_save_path $MODEL_SAVE_PATH \ 15 | --log_save_path $LOG_SAVE_PATH \ 16 | --mask_dir $MASK_DIRECTORY \ 17 | --flow_dir $FLOW_DIRECTORY \ 18 | --gpus $GPUS \ 19 | --encode_w --annealed \ 20 | --config configs/${NAME}/${EXP_NAME}.yaml \ 21 | --exp_name ${EXP_NAME} 22 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # optimizer 3 | from torch.optim import SGD, Adam 4 | import torch_optimizer as optim 5 | # scheduler 6 | from torch.optim.lr_scheduler import CosineAnnealingLR 7 | from torch.optim.lr_scheduler import MultiStepLR 8 | from torch.optim.lr_scheduler import LambdaLR 9 | from torch.optim.lr_scheduler import ExponentialLR 10 | 
from .warmup_scheduler import GradualWarmupScheduler 11 | 12 | from .video_visualizer import VideoVisualizer 13 | 14 | 15 | def get_parameters(models): 16 | """Get all model parameters recursively.""" 17 | parameters = [] 18 | if isinstance(models, list): 19 | # print("is list") 20 | for model in models: 21 | parameters += get_parameters(model) 22 | elif isinstance(models, dict): 23 | # print("is dict") 24 | for model in models.values(): 25 | parameters += get_parameters(model) 26 | else: # models is actually a single pytorch model 27 | parameters += list(models.parameters()) 28 | return parameters 29 | 30 | def get_optimizer(hparams, models): 31 | eps = 1e-8 32 | parameters = get_parameters(models) 33 | if hparams.optimizer == 'sgd': 34 | optimizer = SGD(parameters, lr=hparams.lr, 35 | momentum=hparams.momentum, weight_decay=hparams.weight_decay) 36 | elif hparams.optimizer == 'adam': 37 | optimizer = Adam(parameters, lr=hparams.lr, eps=eps, 38 | weight_decay=hparams.weight_decay) 39 | elif hparams.optimizer == 'radam': 40 | optimizer = optim.RAdam(parameters, lr=hparams.lr, eps=eps, 41 | weight_decay=hparams.weight_decay) 42 | elif hparams.optimizer == 'ranger': 43 | optimizer = optim.Ranger(parameters, lr=hparams.lr, eps=eps, 44 | weight_decay=hparams.weight_decay) 45 | else: 46 | raise ValueError('optimizer not recognized!') 47 | 48 | return optimizer 49 | 50 | def get_scheduler(hparams, optimizer): 51 | eps = 1e-8 52 | if hparams.lr_scheduler == 'steplr': 53 | scheduler = MultiStepLR(optimizer, milestones=hparams.decay_step, 54 | gamma=hparams.decay_gamma) 55 | elif hparams.lr_scheduler == 'cosine': 56 | scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_epochs, eta_min=eps) 57 | elif hparams.lr_scheduler == 'poly': 58 | scheduler = LambdaLR(optimizer, 59 | lambda epoch: (1-epoch/hparams.num_epochs)**hparams.poly_exp) 60 | elif hparams.lr_scheduler == 'exponential': 61 | # Adaptively adjust the schedule 62 | scheduler = LambdaLR(optimizer, 63 | lambda step: hparams.exponent_base**(step/(2 * hparams.num_steps))) 64 | else: 65 | raise ValueError('scheduler not recognized!') 66 | 67 | if hparams.warmup_epochs > 0 and hparams.optimizer not in ['radam', 'ranger']: 68 | scheduler = GradualWarmupScheduler(optimizer, multiplier=hparams.warmup_multiplier, 69 | total_epoch=hparams.warmup_epochs, after_scheduler=scheduler) 70 | 71 | return scheduler 72 | 73 | def get_learning_rate(optimizer): 74 | for param_group in optimizer.param_groups: 75 | return param_group['lr'] 76 | 77 | def extract_model_state_dict(ckpt_path, model_name='model', prefixes_to_ignore=[]): 78 | checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu')) 79 | checkpoint_ = {} 80 | if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint 81 | checkpoint = checkpoint['state_dict'] 82 | for k, v in checkpoint.items(): 83 | if not k.startswith(model_name): 84 | continue 85 | k = k[len(model_name)+1:] 86 | for prefix in prefixes_to_ignore: 87 | if k.startswith(prefix): 88 | print('ignore', k) 89 | break 90 | else: 91 | checkpoint_[k] = v 92 | return checkpoint_ 93 | 94 | def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]): 95 | if not ckpt_path: 96 | return 97 | model_dict = model.state_dict() 98 | checkpoint_ = extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore) 99 | model_dict.update(checkpoint_) 100 | model.load_state_dict(model_dict) 101 | 102 | -------------------------------------------------------------------------------- /utils/image_utils.py: 
-------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains utility functions for image processing. 3 | 4 | The module is primarily built on `cv2`. But, differently, we assume all colorful 5 | images are with `RGB` channel order by default. Also, we assume all gray-scale 6 | images to be with shape [height, width, 1]. 7 | """ 8 | 9 | import os 10 | import cv2 11 | import numpy as np 12 | 13 | # File extensions regarding images (not including GIFs). 14 | IMAGE_EXTENSIONS = ( 15 | '.bmp', '.ppm', '.pgm', '.jpeg', '.jpg', '.jpe', '.jp2', '.png', '.webp', 16 | '.tiff', '.tif' 17 | ) 18 | 19 | def check_file_ext(filename, *ext_list): 20 | """Checks whether the given filename is with target extension(s). 21 | 22 | NOTE: If `ext_list` is empty, this function will always return `False`. 23 | 24 | Args: 25 | filename: Filename to check. 26 | *ext_list: A list of extensions. 27 | 28 | Returns: 29 | `True` if the filename is with one of extensions in `ext_list`, 30 | otherwise `False`. 31 | """ 32 | if len(ext_list) == 0: 33 | return False 34 | ext_list = [ext if ext.startswith('.') else '.' + ext for ext in ext_list] 35 | ext_list = [ext.lower() for ext in ext_list] 36 | basename = os.path.basename(filename) 37 | ext = os.path.splitext(basename)[1].lower() 38 | return ext in ext_list 39 | 40 | 41 | def _check_2d_image(image): 42 | """Checks whether a given image is valid. 43 | 44 | A valid image is expected to be with dtype `uint8`. Also, it should have 45 | shape like: 46 | 47 | (1) (height, width, 1) # gray-scale image. 48 | (2) (height, width, 3) # colorful image. 49 | (3) (height, width, 4) # colorful image with transparency (RGBA) 50 | """ 51 | assert isinstance(image, np.ndarray) 52 | assert image.dtype == np.uint8 53 | assert image.ndim == 3 and image.shape[2] in [1, 3, 4] 54 | 55 | 56 | def get_blank_image(height, width, channels=3, use_black=True): 57 | """Gets a blank image, either white of black. 58 | 59 | NOTE: This function will always return an image with `RGB` channel order for 60 | color image and pixel range [0, 255]. 61 | 62 | Args: 63 | height: Height of the returned image. 64 | width: Width of the returned image. 65 | channels: Number of channels. (default: 3) 66 | use_black: Whether to return a black image. (default: True) 67 | """ 68 | shape = (height, width, channels) 69 | if use_black: 70 | return np.zeros(shape, dtype=np.uint8) 71 | return np.ones(shape, dtype=np.uint8) * 255 72 | 73 | 74 | def load_image(path): 75 | """Loads an image from disk. 76 | 77 | NOTE: This function will always return an image with `RGB` channel order for 78 | color image and pixel range [0, 255]. 79 | 80 | Args: 81 | path: Path to load the image from. 82 | 83 | Returns: 84 | An image with dtype `np.ndarray`, or `None` if `path` does not exist. 85 | """ 86 | image = cv2.imread(path, cv2.IMREAD_UNCHANGED) 87 | if image is None: 88 | return None 89 | 90 | if image.ndim == 2: 91 | image = image[:, :, np.newaxis] 92 | _check_2d_image(image) 93 | if image.shape[2] == 3: 94 | return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 95 | if image.shape[2] == 4: 96 | return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) 97 | return image 98 | 99 | 100 | def save_image(path, image): 101 | """Saves an image to disk. 102 | 103 | NOTE: The input image (if colorful) is assumed to be with `RGB` channel 104 | order and pixel range [0, 255]. 105 | 106 | Args: 107 | path: Path to save the image to. 108 | image: Image to save. 
109 | """ 110 | if image is None: 111 | return 112 | 113 | _check_2d_image(image) 114 | if image.shape[2] == 1: 115 | cv2.imwrite(path, image) 116 | elif image.shape[2] == 3: 117 | cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) 118 | elif image.shape[2] == 4: 119 | cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)) 120 | 121 | 122 | def resize_image(image, *args, **kwargs): 123 | """Resizes image. 124 | 125 | This is a wrap of `cv2.resize()`. 126 | 127 | NOTE: The channel order of the input image will not be changed. 128 | 129 | Args: 130 | image: Image to resize. 131 | *args: Additional positional arguments. 132 | **kwargs: Additional keyword arguments. 133 | 134 | Returns: 135 | An image with dtype `np.ndarray`, or `None` if `image` is empty. 136 | """ 137 | if image is None: 138 | return None 139 | 140 | _check_2d_image(image) 141 | if image.shape[2] == 1: # Re-expand the squeezed dim of gray-scale image. 142 | return cv2.resize(image, *args, **kwargs)[:, :, np.newaxis] 143 | return cv2.resize(image, *args, **kwargs) 144 | 145 | 146 | def add_text_to_image(image, 147 | text='', 148 | position=None, 149 | font=cv2.FONT_HERSHEY_TRIPLEX, 150 | font_size=1.0, 151 | line_type=cv2.LINE_8, 152 | line_width=1, 153 | color=(255, 255, 255)): 154 | """Overlays text on given image. 155 | 156 | NOTE: The input image is assumed to be with `RGB` channel order. 157 | 158 | Args: 159 | image: The image to overlay text on. 160 | text: Text content to overlay on the image. (default: empty) 161 | position: Target position (bottom-left corner) to add text. If not set, 162 | center of the image will be used by default. (default: None) 163 | font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX) 164 | font_size: Font size of the text added. (default: 1.0) 165 | line_type: Line type used to depict the text. (default: cv2.LINE_8) 166 | line_width: Line width used to depict the text. (default: 1) 167 | color: Color of the text added in `RGB` channel order. (default: 168 | (255, 255, 255)) 169 | 170 | Returns: 171 | An image with target text overlaid on. 172 | """ 173 | if image is None or not text: 174 | return image 175 | 176 | _check_2d_image(image) 177 | cv2.putText(img=image, 178 | text=text, 179 | org=position, 180 | fontFace=font, 181 | fontScale=font_size, 182 | color=color, 183 | thickness=line_width, 184 | lineType=line_type, 185 | bottomLeftOrigin=False) 186 | return image 187 | 188 | 189 | def preprocess_image(image, min_val=-1.0, max_val=1.0): 190 | """Pre-processes image by adjusting the pixel range and to dtype `float32`. 191 | 192 | This function is particularly used to convert an image or a batch of images 193 | to `NCHW` format, which matches the data type commonly used in deep models. 194 | 195 | NOTE: The input image is assumed to be with pixel range [0, 255] and with 196 | format `HWC` or `NHWC`. The returned image will be always be with format 197 | `NCHW`. 198 | 199 | Args: 200 | image: The input image for pre-processing. 201 | min_val: Minimum value of the output image. 202 | max_val: Maximum value of the output image. 203 | 204 | Returns: 205 | The pre-processed image. 
206 | """ 207 | assert isinstance(image, np.ndarray) 208 | 209 | image = image.astype(np.float64) 210 | image = image / 255.0 * (max_val - min_val) + min_val 211 | 212 | if image.ndim == 3: 213 | image = image[np.newaxis] 214 | assert image.ndim == 4 and image.shape[3] in [1, 3, 4] 215 | return image.transpose(0, 3, 1, 2) 216 | 217 | 218 | def postprocess_image(image, min_val=-1.0, max_val=1.0): 219 | """Post-processes image to pixel range [0, 255] with dtype `uint8`. 220 | 221 | This function is particularly used to handle the results produced by deep 222 | models. 223 | 224 | NOTE: The input image is assumed to be with format `NCHW`, and the returned 225 | image will always be with format `NHWC`. 226 | 227 | Args: 228 | image: The input image for post-processing. 229 | min_val: Expected minimum value of the input image. 230 | max_val: Expected maximum value of the input image. 231 | 232 | Returns: 233 | The post-processed image. 234 | """ 235 | assert isinstance(image, np.ndarray) 236 | 237 | image = image.astype(np.float64) 238 | image = (image - min_val) / (max_val - min_val) * 255 239 | image = np.clip(image + 0.5, 0, 255).astype(np.uint8) 240 | 241 | assert image.ndim == 4 and image.shape[1] in [1, 3, 4] 242 | return image.transpose(0, 2, 3, 1) 243 | 244 | 245 | def parse_image_size(obj): 246 | """Parses an object to a pair of image size, i.e., (height, width). 247 | 248 | Args: 249 | obj: The input object to parse image size from. 250 | 251 | Returns: 252 | A two-element tuple, indicating image height and width respectively. 253 | 254 | Raises: 255 | If the input is invalid, i.e., neither a list or tuple, nor a string. 256 | """ 257 | if obj is None or obj == '': 258 | height = 0 259 | width = 0 260 | elif isinstance(obj, int): 261 | height = obj 262 | width = obj 263 | elif isinstance(obj, (list, tuple, str, np.ndarray)): 264 | if isinstance(obj, str): 265 | splits = obj.replace(' ', '').split(',') 266 | numbers = tuple(map(int, splits)) 267 | else: 268 | numbers = tuple(obj) 269 | if len(numbers) == 0: 270 | height = 0 271 | width = 0 272 | elif len(numbers) == 1: 273 | height = int(numbers[0]) 274 | width = int(numbers[0]) 275 | elif len(numbers) == 2: 276 | height = int(numbers[0]) 277 | width = int(numbers[1]) 278 | else: 279 | raise ValueError('At most two elements for image size.') 280 | else: 281 | raise ValueError(f'Invalid type of input: `{type(obj)}`!') 282 | 283 | return (max(0, height), max(0, width)) 284 | 285 | 286 | def get_grid_shape(size, height=0, width=0, is_portrait=False): 287 | """Gets the shape of a grid based on the size. 288 | 289 | This function makes greatest effort on making the output grid square if 290 | neither `height` nor `width` is set. If `is_portrait` is set as `False`, the 291 | height will always be equal to or smaller than the width. For example, if 292 | input `size = 16`, output shape will be `(4, 4)`; if input `size = 15`, 293 | output shape will be (3, 5). Otherwise, the height will always be equal to 294 | or larger than the width. 295 | 296 | Args: 297 | size: Size (height * width) of the target grid. 298 | height: Expected height. If `size % height != 0`, this field will be 299 | ignored. (default: 0) 300 | width: Expected width. If `size % width != 0`, this field will be 301 | ignored. (default: 0) 302 | is_portrait: Whether to return a portrait size of a landscape size. 303 | (default: False) 304 | 305 | Returns: 306 | A two-element tuple, representing height and width respectively. 
307 | """ 308 | assert isinstance(size, int) 309 | assert isinstance(height, int) 310 | assert isinstance(width, int) 311 | if size <= 0: 312 | return (0, 0) 313 | 314 | if height > 0 and width > 0 and height * width != size: 315 | height = 0 316 | width = 0 317 | 318 | if height > 0 and width > 0 and height * width == size: 319 | return (height, width) 320 | if height > 0 and size % height == 0: 321 | return (height, size // height) 322 | if width > 0 and size % width == 0: 323 | return (size // width, width) 324 | 325 | height = int(np.sqrt(size)) 326 | while height > 0: 327 | if size % height == 0: 328 | width = size // height 329 | break 330 | height = height - 1 331 | 332 | return (width, height) if is_portrait else (height, width) 333 | 334 | 335 | def list_images_from_dir(directory): 336 | """Lists all images from the given directory. 337 | 338 | NOTE: Do NOT support finding images recursively. 339 | 340 | Args: 341 | directory: The directory to find images from. 342 | 343 | Returns: 344 | A list of sorted filenames, with the directory as prefix. 345 | """ 346 | image_list = [] 347 | for filename in os.listdir(directory): 348 | if check_file_ext(filename, *IMAGE_EXTENSIONS): 349 | image_list.append(os.path.join(directory, filename)) 350 | return sorted(image_list) 351 | -------------------------------------------------------------------------------- /utils/video_visualizer.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from skvideo.io import FFmpegWriter 3 | from .image_utils import parse_image_size 4 | from .image_utils import load_image 5 | from .image_utils import resize_image 6 | from .image_utils import list_images_from_dir 7 | 8 | 9 | class VideoVisualizer(object): 10 | """Defines the video visualizer that presents images as a video.""" 11 | 12 | def __init__(self, 13 | path=None, 14 | frame_size=None, 15 | fps=25.0, 16 | codec='libx264', 17 | pix_fmt='yuv420p', 18 | crf=1): 19 | """Initializes the video visualizer. 20 | 21 | Args: 22 | path: Path to write the video. (default: None) 23 | frame_size: Frame size, i.e., (height, width). (default: None) 24 | fps: Frames per second. (default: 24) 25 | codec: Codec. (default: `libx264`) 26 | pix_fmt: Pixel format. (default: `yuv420p`) 27 | crf: Constant rate factor, which controls the compression. The 28 | larger this field is, the higher compression and lower quality. 29 | `0` means no compression and consequently the highest quality. 30 | To enable QuickTime playing (requires YUV to be 4:2:0, but 31 | `crf = 0` results YUV to be 4:4:4), please set this field as 32 | at least 1. 
(default: 1) 33 | """ 34 | self.set_path(path) 35 | self.set_frame_size(frame_size) 36 | self.set_fps(fps) 37 | self.set_codec(codec) 38 | self.set_pix_fmt(pix_fmt) 39 | self.set_crf(crf) 40 | self.video = None 41 | 42 | def set_path(self, path=None): 43 | """Sets the path to save the video.""" 44 | self.path = path 45 | 46 | def set_frame_size(self, frame_size=None): 47 | """Sets the video frame size.""" 48 | height, width = parse_image_size(frame_size) 49 | self.frame_height = height 50 | self.frame_width = width 51 | 52 | def set_fps(self, fps=25.0): 53 | """Sets the FPS (frame per second) of the video.""" 54 | self.fps = fps 55 | 56 | def set_codec(self, codec='libx264'): 57 | """Sets the video codec.""" 58 | self.codec = codec 59 | 60 | def set_pix_fmt(self, pix_fmt='yuv420p'): 61 | """Sets the video pixel format.""" 62 | self.pix_fmt = pix_fmt 63 | 64 | def set_crf(self, crf=1): 65 | """Sets the CRF (constant rate factor) of the video.""" 66 | self.crf = crf 67 | 68 | def init_video(self): 69 | """Initializes an empty video with expected settings.""" 70 | assert self.frame_height > 0 71 | assert self.frame_width > 0 72 | 73 | video_setting = { 74 | '-r': f'{self.fps:.2f}', 75 | '-s': f'{self.frame_width}x{self.frame_height}', 76 | '-vcodec': f'{self.codec}', 77 | '-crf': f'{self.crf}', 78 | '-pix_fmt': f'{self.pix_fmt}', 79 | } 80 | self.video = FFmpegWriter(self.path, outputdict=video_setting) 81 | 82 | def add(self, frame): 83 | """Adds a frame into the video visualizer. 84 | 85 | NOTE: The input frame is assumed to be with `RGB` channel order. 86 | """ 87 | if self.video is None: 88 | height, width = frame.shape[0:2] 89 | height = self.frame_height or height 90 | width = self.frame_width or width 91 | self.set_frame_size((height, width)) 92 | self.init_video() 93 | if frame.shape[0:2] != (self.frame_height, self.frame_width): 94 | frame = resize_image(frame, (self.frame_width, self.frame_height)) 95 | self.video.writeFrame(frame) 96 | 97 | def visualize_collection(self, images, save_path=None): 98 | """Visualizes a collection of images one by one.""" 99 | if save_path is not None and save_path != self.path: 100 | self.save() 101 | self.set_path(save_path) 102 | for image in images: 103 | self.add(image) 104 | self.save() 105 | 106 | def visualize_list(self, image_list, save_path=None): 107 | """Visualizes a list of image files.""" 108 | if save_path is not None and save_path != self.path: 109 | self.save() 110 | self.set_path(save_path) 111 | for filename in image_list: 112 | image = load_image(filename) 113 | self.add(image) 114 | self.save() 115 | 116 | def visualize_directory(self, directory, save_path=None): 117 | """Visualizes all images under a directory.""" 118 | image_list = list_images_from_dir(directory) 119 | self.visualize_list(image_list, save_path) 120 | 121 | def save(self): 122 | """Saves the video by closing the file.""" 123 | if self.video is not None: 124 | self.video.close() 125 | self.video = None 126 | self.set_path(None) 127 | 128 | 129 | if __name__ == '__main__': 130 | from glob import glob 131 | import cv2 132 | video_visualizer = VideoVisualizer(path='output.mp4', 133 | frame_size=None, 134 | fps=25.0) 135 | img_folder = 'src_images/' 136 | imgs = sorted(glob(img_folder + '/*.png')) 137 | for img in imgs: 138 | image = cv2.imread(img) 139 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 140 | video_visualizer.add(image) 141 | video_visualizer.save() -------------------------------------------------------------------------------- 
/utils/warmup_scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | from torch.optim.lr_scheduler import ReduceLROnPlateau 3 | 4 | class GradualWarmupScheduler(_LRScheduler): 5 | """ Gradually warm-up(increasing) learning rate in optimizer. 6 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 7 | Args: 8 | optimizer (Optimizer): Wrapped optimizer. 9 | multiplier: target learning rate = base lr * multiplier 10 | total_epoch: target learning rate is reached at total_epoch, gradually 11 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 12 | """ 13 | 14 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 15 | self.multiplier = multiplier 16 | if self.multiplier < 1.: 17 | raise ValueError('multiplier should be greater thant or equal to 1.') 18 | self.total_epoch = total_epoch 19 | self.after_scheduler = after_scheduler 20 | self.finished = False 21 | super().__init__(optimizer) 22 | 23 | def get_lr(self): 24 | if self.last_epoch > self.total_epoch: 25 | if self.after_scheduler: 26 | if not self.finished: 27 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 28 | self.finished = True 29 | return self.after_scheduler.get_lr() 30 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 31 | 32 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 33 | 34 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 35 | if epoch is None: 36 | epoch = self.last_epoch + 1 37 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 38 | if self.last_epoch <= self.total_epoch: 39 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 40 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 41 | param_group['lr'] = lr 42 | else: 43 | if epoch is None: 44 | self.after_scheduler.step(metrics, None) 45 | else: 46 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 47 | 48 | def step(self, epoch=None, metrics=None): 49 | if type(self.after_scheduler) != ReduceLROnPlateau: 50 | if self.finished and self.after_scheduler: 51 | if epoch is None: 52 | self.after_scheduler.step(None) 53 | else: 54 | self.after_scheduler.step(epoch - self.total_epoch) 55 | else: 56 | return super(GradualWarmupScheduler, self).step(epoch) 57 | else: 58 | self.step_ReduceLROnPlateau(metrics, epoch) --------------------------------------------------------------------------------
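A minimal sketch (not a file in the repository) of how the `GradualWarmupScheduler` above is combined with a base scheduler, mirroring `get_scheduler()` in `utils/__init__.py`. The model, learning rate, milestones, and step counts are placeholders for illustration.

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR
from utils.warmup_scheduler import GradualWarmupScheduler

model = torch.nn.Linear(2, 2)  # stand-in for the actual networks
optimizer = Adam(model.parameters(), lr=5e-4)
base_scheduler = MultiStepLR(optimizer, milestones=[2500, 5000, 7500], gamma=0.5)

# The learning rate ramps from lr to multiplier * lr over the first
# `total_epoch` calls to step(), after which control is handed to
# `base_scheduler`.
scheduler = GradualWarmupScheduler(optimizer,
                                   multiplier=2.0,
                                   total_epoch=100,
                                   after_scheduler=base_scheduler)

for _ in range(200):
    optimizer.step()   # normally preceded by a forward/backward pass
    scheduler.step()
```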