├── .gitmodules ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets ├── depth_train.gif ├── depth_train_init.gif ├── dog_depth.gif ├── dog_depth_init.gif ├── rgb_dog.gif ├── rgb_train.gif └── teaser.gif ├── configs └── __init__.py ├── datasets ├── __init__.py ├── base_dataset.py ├── davis_sequence.py └── shutterstock.py ├── dependencies ├── conda_packages.txt └── requirements.txt ├── experiments ├── davis │ ├── test_cmd.txt │ └── train_sequence.sh └── shutterstock │ ├── test_cmd.txt │ └── train_sequence.sh ├── loggers ├── Progbar.py ├── __init__.py ├── html_template.py └── loggers.py ├── losses ├── __init__.py └── scene_flow_projection.py ├── models ├── __init__.py ├── netinterface.py ├── scene_flow_motion_field.py └── video_base.py ├── networks ├── FCNUnet.py ├── MLP.py ├── __init__.py ├── blocks.py └── sceneflow_field.py ├── options ├── __init__.py ├── options_test.py └── options_train.py ├── scripts ├── download_data_and_depth_ckpt.sh ├── download_triangulation_files.sh └── preprocess │ ├── davis │ ├── generate_flows.py │ ├── generate_frame_midas.py │ └── generate_sequence_midas.py │ └── shutterstock │ ├── generate_flows.py │ ├── generate_frame_midas.py │ └── generate_sequence_midas.py ├── test.py ├── third_party ├── MiDaS.py ├── __init__.py ├── hourglass.py ├── midas_blocks.py └── util_colormap.py ├── train.py ├── util ├── __init__.py ├── util_config.py ├── util_flow.py ├── util_html.py ├── util_imageIO.py ├── util_loadlib.py ├── util_plot.py ├── util_print.py └── util_visualize.py └── visualize ├── base_visualizer.py └── html_visualizer.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/RAFT"] 2 | path = third_party/RAFT 3 | url = https://github.com/princeton-vl/RAFT.git 4 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code Reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Consistent Depth of Moving Objects in Video 2 | 3 | ![teaser](./assets/teaser.gif) 4 | 5 | This repository contains training code for the SIGGRAPH 2021 paper 6 | "[Consistent Depth of Moving Objects in 7 | Video](https://dynamic-video-depth.github.io/)". 8 | 9 | This is not an officially supported Google product. 10 | 11 | ## Installing Dependencies 12 | 13 | We provide both conda and pip installations for dependencies. 14 | 15 | - To install with conda, run 16 | 17 | ``` 18 | conda create --name dynamic-video-depth --file ./dependencies/conda_packages.txt 19 | ``` 20 | 21 | - To install with pip, run 22 | 23 | ``` 24 | pip install -r ./dependencies/requirements.txt 25 | ``` 26 | 27 | 28 | 29 | ## Training 30 | We provide two preprocessed video tracks from the DAVIS dataset. To download the pre-trained single-image depth prediction checkpoints, as well as the example data, run: 31 | 32 | 33 | ``` 34 | bash ./scripts/download_data_and_depth_ckpt.sh 35 | ``` 36 | 37 | This script will automatically download and unzip the checkpoints and data. To download mannually, use [this link](https://drive.google.com/drive/folders/19_hbgJ9mettcbMQBYYnH0seiUaREZD1D?usp=sharing). 38 | 39 | To train using the example data, run: 40 | 41 | ``` 42 | bash ./experiments/davis/train_sequence.sh 0 --track_id dog 43 | ``` 44 | 45 | The first argument indicates the GPU id for training, and `--track_id` indicates the name of the track. ('dog' and 'train' are provided.) 46 | 47 | After training, the results should look like: 48 | 49 | | Video | Our Depth | Single Image Depth | 50 | :----:| :----:| :----: 51 | ![](assets/rgb_dog.gif) | ![](assets/dog_depth.gif) | ![](assets/dog_depth_init.gif) | 52 | ![](assets/rgb_train.gif) | ![](assets/depth_train.gif) | ![](assets/depth_train_init.gif) | 53 | 54 | 55 | ## Dataset Preparation: 56 | 57 | To help with generating custom datasets for training, We provide examples of preparing the dataset from DAVIS, and two sequences from ShutterStock, which are showcased in our paper. 58 | 59 | The general work flow for preprocessing the dataset is: 60 | 61 | 1. 
Calibrate the scale of camera translation, transform the camera matrices into camera-to-world convention, and save as individual files. 62 | 63 | 2. Calculate flow between pairs of frames, as well as occlusion estimates. 64 | 65 | 3. Pack flow and per-frame data into training batches. 66 | 67 | Example code for each of these steps is provided in `./scripts/preprocess`. 68 | 69 | We provide the triangulation results [here](https://drive.google.com/file/d/1U07e9xtwYbBZPpJ2vfsLaXYMWATt4XyB/view?usp=sharing) and [here](https://drive.google.com/file/d/1om58tVKujaq1Jo_ShpKc4sWVAWBoKY6U/view?usp=sharing). You can download them in a single script by running: 70 | 71 | ``` 72 | bash ./scripts/download_triangulation_files.sh 73 | ``` 74 | 75 | ### DAVIS data preparation 76 | 77 | 1. Download the DAVIS dataset here, and unzip it under `./datafiles`. 78 | 79 | 2. Run `python ./scripts/preprocess/davis/generate_frame_midas.py`. This requires `trimesh` to be installed (`pip install trimesh` should do the trick). This script projects the triangulated 3D points to calibrate the camera translation scales. 80 | 81 | 3. Run `python ./scripts/preprocess/davis/generate_flows.py` to generate optical flows between pairs of images. This stage requires `RAFT`, which is included as a submodule in this repo. 82 | 83 | 84 | 4. Run `python ./scripts/preprocess/davis/generate_sequence_midas.py` to pack camera calibrations and images into training batches. 85 | 86 | ### ShutterStock Videos 87 | 88 | 89 | 1. Download the ShutterStock videos [here](https://www.shutterstock.com/video/clip-1058262031-loyal-golden-retriever-dog-running-across-green) and [here](https://www.shutterstock.com/nb/video/clip-1058781907-handsome-pedigree-cute-white-labrador-walking-on). 90 | 91 | 92 | 93 | 2. Extract the video frames as images, put them under `./datafiles/shutterstock/images`, and rename them to match the file names in `./datafiles/shutterstock/triangulation`. Note that not all frames are triangulated; the time stamps of valid frames are recorded in the triangulation file names. 94 | 95 | 3. Run `python ./scripts/preprocess/shutterstock/generate_frame_midas.py` to pack per-frame data. 96 | 97 | 4. Run `python ./scripts/preprocess/shutterstock/generate_flows.py` to generate optical flows between pairs of images. 98 | 99 | 5. Run `python ./scripts/preprocess/shutterstock/generate_sequence_midas.py` to pack flows and per-frame data into training batches. 100 | 101 | 6.
Example training script is located at `./experiments/shutterstock/train_sequence.sh` 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /assets/depth_train.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/dynamic-video-depth/79177ef5941b15b0aa2395b626f922fcb7b4c179/assets/depth_train.gif -------------------------------------------------------------------------------- /assets/depth_train_init.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/dynamic-video-depth/79177ef5941b15b0aa2395b626f922fcb7b4c179/assets/depth_train_init.gif -------------------------------------------------------------------------------- /assets/dog_depth.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/dynamic-video-depth/79177ef5941b15b0aa2395b626f922fcb7b4c179/assets/dog_depth.gif -------------------------------------------------------------------------------- /assets/dog_depth_init.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/dynamic-video-depth/79177ef5941b15b0aa2395b626f922fcb7b4c179/assets/dog_depth_init.gif -------------------------------------------------------------------------------- /assets/rgb_dog.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/dynamic-video-depth/79177ef5941b15b0aa2395b626f922fcb7b4c179/assets/rgb_dog.gif -------------------------------------------------------------------------------- /assets/rgb_train.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/dynamic-video-depth/79177ef5941b15b0aa2395b626f922fcb7b4c179/assets/rgb_train.gif -------------------------------------------------------------------------------- /assets/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/dynamic-video-depth/79177ef5941b15b0aa2395b626f922fcb7b4c179/assets/teaser.gif -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
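# Note (added comment, not part of the original header): the two paths below
# point to the pretrained single-image depth checkpoints (the "Ours_Bilinear"
# model and the MiDaS checkpoint) used to initialize the depth network.
# Presumably they are placed under ./pretrained_depth_ckpt by
# `bash ./scripts/download_data_and_depth_ckpt.sh` (see the README "Training"
# section); adjust these paths if your checkpoints live elsewhere.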
14 | 15 | depth_pretrain_path = './pretrained_depth_ckpt/best_depth_Ours_Bilinear_inc_3_net_G.pth' 16 | midas_pretrain_path = './pretrained_depth_ckpt/midas_cpkt.pt' 17 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import importlib 16 | 17 | 18 | def get_dataset(alias): 19 | dataset_module = importlib.import_module('datasets.' + alias.lower()) 20 | return dataset_module.Dataset 21 | -------------------------------------------------------------------------------- /datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import torch.utils.data as data 17 | import torch 18 | 19 | 20 | class Dataset(data.Dataset): 21 | @classmethod 22 | def add_arguments(cls, parser): 23 | return parser, set() 24 | 25 | def __init__(self, opt, mode='train', model=None): 26 | super().__init__() 27 | self.opt = opt 28 | self.mode = mode 29 | 30 | @staticmethod 31 | def convert_to_float32(sample_loaded): 32 | for k, v in sample_loaded.items(): 33 | if isinstance(v, np.ndarray): 34 | sample_loaded[k] = torch.from_numpy(v).float() 35 | -------------------------------------------------------------------------------- /datasets/davis_sequence.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
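# Summary of the on-disk layout this loader expects (inferred from the paths
# used below; the directories are produced by the DAVIS preprocessing scripts
# described in the README):
#   ./datafiles/davis_processed/frames_midas/<track_id>/*.npz
#       per-frame data (img, depth_pred, depth_mvs, pose_c2w, intrinsics),
#       loaded in validation/test mode;
#   ./datafiles/davis_processed/sequences_select_pairs_midas/<track_id>/001/
#       (or .../subsample when --subsample is set)
#       shuffle_False_gap_XX_*.pt pair batches (img_1/img_2, fid_1/fid_2, ...),
#       loaded in training mode for every gap listed in --gaps.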
14 | 15 | import numpy as np 16 | from .base_dataset import Dataset as base_dataset 17 | from glob import glob 18 | from os.path import join 19 | import torch 20 | 21 | 22 | class Dataset(base_dataset): 23 | 24 | @classmethod 25 | def add_arguments(cls, parser): 26 | parser.add_argument('--cache', action='store_true', help='cache the data into ram') 27 | parser.add_argument('--subsample', action='store_true', help='subsample the video in time') 28 | parser.add_argument('--track_id', default='train', type=str, help='the track id to load') 29 | parser.add_argument('--overfit', action='store_true', help='overfit and see if things works') 30 | parser.add_argument('--gaps', type=str, default='1,2,3,4', help='gaps for sequences') 31 | parser.add_argument('--repeat', type=int, default=1, help='number of repeatition') 32 | parser.add_argument('--select', action='store_true', help='pred') 33 | return parser, set() 34 | 35 | def __init__(self, opt, mode='train', model=None): 36 | super().__init__(opt, mode, model) 37 | self.mode = mode 38 | assert mode in ('train', 'vali') 39 | 40 | data_root = './datafiles/davis_processed' 41 | # tracks = sorted(glob(join(data_root, 'frames_midas', '*'))) 42 | # tracks = [x.split('/')[-1] for x in tracks] 43 | track_name = opt.track_id # tracks[opt.track_id] 44 | if model is None: 45 | self.required = ['img', 'flow'] 46 | self.preproc = None 47 | elif mode == 'train': 48 | self.required = model.requires 49 | self.preproc = model.preprocess 50 | else: 51 | self.required = ['img'] 52 | self.preproc = model.preprocess 53 | 54 | frame_prefix = 'frames_midas' 55 | seq_prefix = 'sequences_select_pairs_midas' 56 | 57 | if mode == 'train': 58 | 59 | if self.opt.subsample: 60 | data_path = join(data_root, seq_prefix, track_name, 'subsample') 61 | else: 62 | data_path = join(data_root, seq_prefix, track_name, '%03d' % 1) 63 | 64 | gaps = opt.gaps.split(',') 65 | gaps = [int(x) for x in gaps] 66 | self.file_list = [] 67 | for g in gaps: 68 | 69 | file_list = sorted(glob(join(data_path, f'shuffle_False_gap_{g:02d}_*.pt'))) 70 | self.file_list += file_list 71 | 72 | frame_data_path = join(data_root, frame_prefix, track_name) 73 | self.n_frames = len(sorted(glob(join(frame_data_path, '*.npz')))) + 0.0 74 | 75 | else: 76 | data_path = join(data_root, frame_prefix, track_name) 77 | self.file_list = sorted(glob(join(data_path, '*.npz'))) 78 | self.n_frames = len(self.file_list) + 0.0 79 | 80 | def __len__(self): 81 | if self.mode != 'train': 82 | return len(self.file_list) 83 | else: 84 | return len(self.file_list) * self.opt.repeat 85 | 86 | def __getitem__(self, idx): 87 | sample_loaded = {} 88 | if self.opt.overfit: 89 | idx = idx % self.opt.capat 90 | else: 91 | idx = idx % len(self.file_list) 92 | 93 | if self.opt.subsample: 94 | unit = 2.0 95 | else: 96 | unit = 1.0 97 | 98 | if self.mode == 'train': 99 | 100 | dataset = torch.load(self.file_list[idx]) 101 | 102 | _, H, W, _ = dataset['img_1'].shape 103 | dataset['img_1'] = dataset['img_1'].permute([0, 3, 1, 2]) 104 | dataset['img_2'] = dataset['img_2'].permute([0, 3, 1, 2]) 105 | ts = dataset['fid_1'].reshape([-1, 1, 1, 1]).expand(-1, -1, H, W) / self.n_frames 106 | ts2 = dataset['fid_2'].reshape([-1, 1, 1, 1]).expand(-1, -1, H, W) / self.n_frames 107 | for k in dataset: 108 | if type(dataset[k]) == list: 109 | continue 110 | sample_loaded[k] = dataset[k].float() 111 | sample_loaded['time_step'] = unit / self.n_frames 112 | sample_loaded['time_stamp_1'] = ts.float() 113 | sample_loaded['time_stamp_2'] = ts2.float() 114 | 
sample_loaded['frame_id_1'] = np.asarray(dataset['fid_1']) 115 | sample_loaded['frame_id_2'] = np.asarray(dataset['fid_2']) 116 | 117 | else: 118 | dataset = np.load(self.file_list[idx]) 119 | H, W, _ = dataset['img'].shape 120 | sample_loaded['time_stamp_1'] = np.ones([1, H, W]) * idx / self.n_frames 121 | sample_loaded['img'] = np.transpose(dataset['img'], [2, 0, 1]) 122 | sample_loaded['frame_id_1'] = idx 123 | 124 | sample_loaded['time_step'] = unit / self.n_frames 125 | sample_loaded['depth_pred'] = dataset['depth_pred'][None, ...] 126 | sample_loaded['cam_c2w'] = dataset['pose_c2w'] 127 | sample_loaded['K'] = dataset['intrinsics'] 128 | sample_loaded['depth_mvs'] = dataset['depth_mvs'][None, ...] 129 | # add decomposed cam mat 130 | cam_pose_c2w_1 = dataset['pose_c2w'] 131 | R_1 = cam_pose_c2w_1[:3, :3] 132 | t_1 = cam_pose_c2w_1[:3, 3] 133 | K = dataset['intrinsics'] 134 | 135 | # for network use: 136 | R_1_tensor = np.zeros([1, 1, 3, 3]) 137 | R_1_T_tensor = np.zeros([1, 1, 3, 3]) 138 | t_1_tensor = np.zeros([1, 1, 1, 3]) 139 | K_tensor = np.zeros([1, 1, 3, 3]) 140 | K_inv_tensor = np.zeros([1, 1, 3, 3]) 141 | R_1_tensor[..., :, :] = R_1.T 142 | R_1_T_tensor[..., :, :] = R_1 143 | t_1_tensor[..., :] = t_1 144 | K_tensor[..., :, :] = K.T 145 | K_inv_tensor[..., :, :] = np.linalg.inv(K).T 146 | 147 | sample_loaded['R_1'] = R_1_tensor 148 | sample_loaded['R_1_T'] = R_1_T_tensor 149 | sample_loaded['t_1'] = t_1_tensor 150 | sample_loaded['K'] = K_tensor 151 | sample_loaded['K_inv'] = K_inv_tensor 152 | sample_loaded['pair_path'] = self.file_list[idx] 153 | self.convert_to_float32(sample_loaded) 154 | return sample_loaded 155 | -------------------------------------------------------------------------------- /datasets/shutterstock.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
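# This loader mirrors datasets/davis_sequence.py but reads from
# ./datafiles/shutterstock instead, and --track_id is an integer index into
# the sorted sub-directories of ./datafiles/shutterstock/frames_midas rather
# than a track name (see __init__ below).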
14 | 15 | import numpy as np 16 | from .base_dataset import Dataset as base_dataset 17 | from glob import glob 18 | from os.path import join 19 | import torch 20 | 21 | 22 | class Dataset(base_dataset): 23 | 24 | @classmethod 25 | def add_arguments(cls, parser): 26 | parser.add_argument('--cache', action='store_true', help='cache the data into ram') 27 | parser.add_argument('--subsample', action='store_true', help='subsample the video in time') 28 | parser.add_argument('--track_id', default=0, type=int, help='the track id to load') 29 | parser.add_argument('--overfit', action='store_true', help='overfit and see if things works') 30 | parser.add_argument('--gaps', type=str, default='1,2,3,4', help='gaps for sequences') 31 | parser.add_argument('--repeat', type=int, default=1, help='number of repeatition') 32 | parser.add_argument('--select', action='store_true', help='pred') 33 | return parser, set() 34 | 35 | def __init__(self, opt, mode='train', model=None): 36 | super().__init__(opt, mode, model) 37 | self.mode = mode 38 | assert mode in ('train', 'vali') 39 | 40 | data_root = './datafiles/shutterstock' 41 | tracks = sorted(glob(join(data_root, 'frames_midas', '*'))) 42 | tracks = [x.split('/')[-1] for x in tracks] 43 | track_name = tracks[opt.track_id] 44 | if model is None: 45 | self.required = ['img', 'flow'] 46 | self.preproc = None 47 | elif mode == 'train': 48 | self.required = model.requires 49 | self.preproc = model.preprocess 50 | else: 51 | self.required = ['img'] 52 | self.preproc = model.preprocess 53 | 54 | frame_prefix = 'frames_midas' 55 | seq_prefix = 'sequences_select_pairs_midas' 56 | 57 | if mode == 'train': 58 | 59 | if self.opt.subsample: 60 | data_path = join(data_root, seq_prefix, track_name, 'subsample') 61 | else: 62 | data_path = join(data_root, seq_prefix, track_name, '%03d' % 1) 63 | 64 | gaps = opt.gaps.split(',') 65 | gaps = [int(x) for x in gaps] 66 | self.file_list = [] 67 | for g in gaps: 68 | 69 | file_list = sorted(glob(join(data_path, f'shuffle_False_gap_{g:02d}_*.pt'))) 70 | self.file_list += file_list 71 | 72 | frame_data_path = join(data_root, frame_prefix, track_name) 73 | self.n_frames = len(sorted(glob(join(frame_data_path, '*.npz')))) + 0.0 74 | 75 | else: 76 | data_path = join(data_root, frame_prefix, track_name) 77 | self.file_list = sorted(glob(join(data_path, '*.npz'))) 78 | self.n_frames = len(self.file_list) + 0.0 79 | 80 | def __len__(self): 81 | if self.mode != 'train': 82 | return len(self.file_list) 83 | else: 84 | return len(self.file_list) * self.opt.repeat 85 | 86 | def __getitem__(self, idx): 87 | sample_loaded = {} 88 | if self.opt.overfit: 89 | idx = idx % self.opt.capat 90 | else: 91 | idx = idx % len(self.file_list) 92 | 93 | if self.opt.subsample: 94 | unit = 2.0 95 | else: 96 | unit = 1.0 97 | 98 | if self.mode == 'train': 99 | 100 | dataset = torch.load(self.file_list[idx]) 101 | 102 | _, H, W, _ = dataset['img_1'].shape 103 | dataset['img_1'] = dataset['img_1'].permute([0, 3, 1, 2]) 104 | dataset['img_2'] = dataset['img_2'].permute([0, 3, 1, 2]) 105 | ts = dataset['fid_1'].reshape([-1, 1, 1, 1]).expand(-1, -1, H, W) / self.n_frames 106 | ts2 = dataset['fid_2'].reshape([-1, 1, 1, 1]).expand(-1, -1, H, W) / self.n_frames 107 | for k in dataset: 108 | if type(dataset[k]) == list: 109 | continue 110 | sample_loaded[k] = dataset[k].float() 111 | sample_loaded['time_step'] = unit / self.n_frames 112 | sample_loaded['time_stamp_1'] = ts.float() 113 | sample_loaded['time_stamp_2'] = ts2.float() 114 | sample_loaded['frame_id_1'] = 
np.asarray(dataset['fid_1']) 115 | sample_loaded['frame_id_2'] = np.asarray(dataset['fid_2']) 116 | 117 | else: 118 | dataset = np.load(self.file_list[idx]) 119 | H, W, _ = dataset['img'].shape 120 | sample_loaded['time_stamp_1'] = np.ones([1, H, W]) * idx / self.n_frames 121 | sample_loaded['img'] = np.transpose(dataset['img'], [2, 0, 1]) 122 | sample_loaded['frame_id_1'] = idx 123 | 124 | sample_loaded['time_step'] = unit / self.n_frames 125 | sample_loaded['depth_pred'] = dataset['depth_pred'][None, ...] 126 | sample_loaded['cam_c2w'] = dataset['pose_c2w'] 127 | sample_loaded['K'] = dataset['intrinsics'] 128 | sample_loaded['depth_mvs'] = dataset['depth_mvs'][None, ...] 129 | # add decomposed cam mat 130 | cam_pose_c2w_1 = dataset['pose_c2w'] 131 | R_1 = cam_pose_c2w_1[:3, :3] 132 | t_1 = cam_pose_c2w_1[:3, 3] 133 | K = dataset['intrinsics'] 134 | 135 | # for network use: 136 | R_1_tensor = np.zeros([1, 1, 3, 3]) 137 | R_1_T_tensor = np.zeros([1, 1, 3, 3]) 138 | t_1_tensor = np.zeros([1, 1, 1, 3]) 139 | K_tensor = np.zeros([1, 1, 3, 3]) 140 | K_inv_tensor = np.zeros([1, 1, 3, 3]) 141 | R_1_tensor[..., :, :] = R_1.T 142 | R_1_T_tensor[..., :, :] = R_1 143 | t_1_tensor[..., :] = t_1 144 | K_tensor[..., :, :] = K.T 145 | K_inv_tensor[..., :, :] = np.linalg.inv(K).T 146 | 147 | sample_loaded['R_1'] = R_1_tensor 148 | sample_loaded['R_1_T'] = R_1_T_tensor 149 | sample_loaded['t_1'] = t_1_tensor 150 | sample_loaded['K'] = K_tensor 151 | sample_loaded['K_inv'] = K_inv_tensor 152 | sample_loaded['pair_path'] = self.file_list[idx] 153 | self.convert_to_float32(sample_loaded) 154 | return sample_loaded 155 | -------------------------------------------------------------------------------- /dependencies/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | backcall==0.2.0 3 | cachetools==4.2.2 4 | certifi==2021.5.30 5 | charset-normalizer==2.0.4 6 | cycler==0.10.0 7 | decorator==4.4.2 8 | filelock==3.0.12 9 | gdown==3.13.0 10 | google-auth==1.34.0 11 | google-auth-oauthlib==0.4.5 12 | grpcio==1.39.0 13 | idna==3.2 14 | imageio==2.9.0 15 | ipykernel==5.3.4 16 | ipython==7.18.1 17 | ipython-genutils==0.2.0 18 | jedi==0.17.2 19 | jupyter-client==6.1.7 20 | jupyter-core==4.6.3 21 | kiwisolver==1.3.1 22 | Markdown==3.3.4 23 | matplotlib==3.4.2 24 | networkx==2.6.2 25 | numpy==1.21.1 26 | oauthlib==3.1.1 27 | opencv-python==4.5.3.56 28 | pandas==1.3.1 29 | parso==0.7.1 30 | pexpect==4.8.0 31 | pickleshare==0.7.5 32 | Pillow==8.3.1 33 | prompt-toolkit==3.0.7 34 | protobuf==3.17.3 35 | ptyprocess==0.6.0 36 | pyasn1==0.4.8 37 | pyasn1-modules==0.2.8 38 | Pygments==2.6.1 39 | pyparsing==2.4.7 40 | PySocks==1.7.1 41 | python-dateutil==2.8.2 42 | pytz==2021.1 43 | PyWavelets==1.1.1 44 | pyzmq==19.0.2 45 | requests==2.26.0 46 | requests-oauthlib==1.3.0 47 | rsa==4.7.2 48 | scikit-image==0.18.2 49 | scipy==1.7.1 50 | six==1.16.0 51 | tensorboard==2.6.0 52 | tensorboard-data-server==0.6.1 53 | tensorboard-plugin-wit==1.8.0 54 | tifffile==2021.8.8 55 | torch==1.9.0 56 | torchvision==0.10.0 57 | tornado==6.1 58 | tqdm==4.62.0 59 | traitlets==5.0.4 60 | trimesh==3.9.27 61 | typing-extensions==3.10.0.0 62 | urllib3==1.26.6 63 | wcwidth==0.2.5 64 | Werkzeug==2.0.1 65 | -------------------------------------------------------------------------------- /experiments/davis/test_cmd.txt: -------------------------------------------------------------------------------- 1 | python test.py --net {net} --dataset {dataset} --workers 4 --output_dir 
'./test_results/{dataset}/' --epoch {epoch} --html_logger --batch_size 1 --gpu {gpu} --track_id {track_id} --suffix {suffix_expand} --checkpoint_path {full_logdir} -------------------------------------------------------------------------------- /experiments/davis/train_sequence.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | if [ $# -lt 1 ]; then 17 | echo "Usage: $0 gpu " 18 | exit 1 19 | fi 20 | gpu="$1" 21 | shift 22 | set -e 23 | cmd=" 24 | python train.py \ 25 | --net scene_flow_motion_field \ 26 | --dataset davis_sequence \ 27 | --track_id train \ 28 | --log_time \ 29 | --epoch_batches 2000 \ 30 | --epoch 20 \ 31 | --lr 1e-6 \ 32 | --html_logger \ 33 | --vali_batches 150 \ 34 | --batch_size 1 \ 35 | --optim adam \ 36 | --vis_batches_vali 4 \ 37 | --vis_every_vali 1 \ 38 | --vis_every_train 1 \ 39 | --vis_batches_train 5 \ 40 | --vis_at_start \ 41 | --tensorboard \ 42 | --gpu "$gpu" \ 43 | --save_net 1 \ 44 | --workers 4 \ 45 | --one_way \ 46 | --loss_type l1 \ 47 | --l1_mul 0 \ 48 | --acc_mul 1 \ 49 | --disp_mul 1 \ 50 | --warm_sf 5 \ 51 | --scene_lr_mul 1000 \ 52 | --repeat 1 \ 53 | --flow_mul 1\ 54 | --sf_mag_div 100 \ 55 | --time_dependent \ 56 | --gaps 1,2,4,6,8 \ 57 | --midas \ 58 | --use_disp \ 59 | --logdir './checkpoints/davis/sequence/' \ 60 | --suffix 'track_{track_id}_{loss_type}_wreg_{warm_reg}_acc_{acc_mul}_disp_{disp_mul}_flowmul_{flow_mul}_time_{time_dependent}_CNN_{use_cnn}_gap_{gaps}_Midas_{midas}_ud_{use_disp}' \ 61 | --test_template './experiments/davis/test_cmd.txt' \ 62 | --force_overwrite \ 63 | $*" 64 | echo $cmd 65 | eval $cmd 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /experiments/shutterstock/test_cmd.txt: -------------------------------------------------------------------------------- 1 | python test.py --net {net} --dataset {dataset} --workers 4 --output_dir './test_results/{dataset}/' --epoch {epoch} --html_logger --batch_size 1 --gpu {gpu} --track_id {track_id} --suffix {suffix_expand} --checkpoint_path {full_logdir} -------------------------------------------------------------------------------- /experiments/shutterstock/train_sequence.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | if [ $# -lt 1 ]; then 17 | echo "Usage: $0 gpu " 18 | exit 1 19 | fi 20 | gpu="$1" 21 | shift 22 | set -e 23 | cmd=" 24 | python train.py \ 25 | --net scene_flow_motion_field \ 26 | --dataset shutterstock \ 27 | --track_id 0 \ 28 | --log_time \ 29 | --epoch_batches 2000 \ 30 | --epoch 20 \ 31 | --lr 1e-6 \ 32 | --html_logger \ 33 | --vali_batches 150 \ 34 | --batch_size 1 \ 35 | --optim adam \ 36 | --vis_batches_vali 4 \ 37 | --vis_every_vali 1 \ 38 | --vis_every_train 1 \ 39 | --vis_batches_train 5 \ 40 | --vis_at_start \ 41 | --tensorboard \ 42 | --gpu "$gpu" \ 43 | --save_net 1 \ 44 | --workers 4 \ 45 | --one_way \ 46 | --loss_type l1 \ 47 | --l1_mul 0 \ 48 | --acc_mul 1 \ 49 | --disp_mul 1 \ 50 | --warm_sf 5 \ 51 | --scene_lr_mul 1000 \ 52 | --repeat 1 \ 53 | --flow_mul 1\ 54 | --sf_mag_div 100 \ 55 | --time_dependent \ 56 | --gaps 1,2,4,6,8 \ 57 | --midas \ 58 | --use_disp \ 59 | --logdir './checkpoints/shutterstock/sequence/' \ 60 | --suffix 'track_{track_id}_{loss_type}_wreg_{warm_reg}_acc_{acc_mul}_disp_{disp_mul}_flowmul_{flow_mul}_time_{time_dependent}_CNN_{use_cnn}_gap_{gaps}_Midas_{midas}_ud_{use_disp}' \ 61 | --test_template './experiments/shutterstock/test_cmd.txt' \ 62 | --force_overwrite \ 63 | $*" 64 | echo $cmd 65 | eval $cmd 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /loggers/Progbar.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import sys 17 | import time 18 | import numpy as np 19 | 20 | 21 | class Progbar(object): 22 | """Displays a progress bar. 23 | # Arguments 24 | target: Total number of steps expected, None if unknown. 25 | interval: Minimum visual progress update interval (in seconds). 26 | """ 27 | 28 | def __init__(self, target, width=30, verbose=1, interval=0.05, no_accum=False): 29 | self.width = width 30 | if target is None: 31 | target = -1 32 | self.target = target 33 | self.sum_values = {} 34 | self.step_values = {} 35 | self.unique_values = [] 36 | self.start = time.time() 37 | self.last_update = 0 38 | self.interval = interval 39 | self.total_width = 0 40 | self.seen_so_far = 0 41 | self.verbose = verbose 42 | self.no_accume = no_accum 43 | 44 | def update(self, current, values=None, force=False): 45 | """Updates the progress bar. 46 | # Arguments 47 | current: Index of current step. 48 | values: List of tuples (name, value_for_last_step). 49 | The progress bar will display averages for these values. 50 | force: Whether to force visual progress update. 
51 | """ 52 | values = values or [] 53 | for k, v in values: 54 | if k not in self.sum_values: 55 | self.sum_values[k] = [v * (current - self.seen_so_far), 56 | current - self.seen_so_far] 57 | self.unique_values.append(k) 58 | self.step_values[k] = v 59 | else: 60 | self.sum_values[k][0] += v * (current - self.seen_so_far) 61 | self.sum_values[k][1] += (current - self.seen_so_far) 62 | self.step_values[k] = v 63 | 64 | self.seen_so_far = current 65 | 66 | now = time.time() 67 | if self.verbose == 1: 68 | if not force and (now - self.last_update) < self.interval: 69 | return 70 | 71 | prev_total_width = self.total_width 72 | sys.stdout.write('\b' * prev_total_width) 73 | sys.stdout.write('\r') 74 | 75 | if self.target is not -1: 76 | numdigits = int(np.floor(np.log10(self.target))) + 1 77 | barstr = '%%%dd/%%%dd [' % (numdigits, numdigits) 78 | bar = barstr % (current, self.target) 79 | prog = float(current) / self.target 80 | prog_width = int(self.width * prog) 81 | if prog_width > 0: 82 | bar += ('=' * (prog_width - 1)) 83 | if current < self.target: 84 | bar += '>' 85 | else: 86 | bar += '=' 87 | bar += ('.' * (self.width - prog_width)) 88 | bar += ']' 89 | sys.stdout.write(bar) 90 | self.total_width = len(bar) 91 | 92 | if current: 93 | time_per_unit = (now - self.start) / current 94 | else: 95 | time_per_unit = 0 96 | eta = time_per_unit * (self.target - current) 97 | info = '' 98 | if current < self.target and self.target is not -1: 99 | info += ' - ETA: %ds' % eta 100 | else: 101 | info += ' - %ds' % (now - self.start) 102 | for k in self.unique_values: 103 | info += ' - %s:' % k 104 | if isinstance(self.sum_values[k], list): 105 | if self.no_accume: 106 | avg = self.step_values[k] 107 | else: 108 | avg = np.mean( 109 | self.sum_values[k][0] / max(1, self.sum_values[k][1])) 110 | if abs(avg) > 1e-3: 111 | info += ' %.4f' % avg 112 | else: 113 | info += ' %.4e' % avg 114 | else: 115 | info += ' %s' % self.sum_values[k] 116 | 117 | self.total_width += len(info) 118 | if prev_total_width > self.total_width: 119 | info += ((prev_total_width - self.total_width) * ' ') 120 | 121 | sys.stdout.write(info) 122 | sys.stdout.flush() 123 | 124 | if current >= self.target: 125 | sys.stdout.write('\n') 126 | 127 | if self.verbose == 2: 128 | if current >= self.target: 129 | info = '%ds' % (now - self.start) 130 | for k in self.unique_values: 131 | info += ' - %s:' % k 132 | avg = np.mean( 133 | self.sum_values[k][0] / max(1, self.sum_values[k][1])) 134 | if avg > 1e-3: 135 | info += ' %.4f' % avg 136 | else: 137 | info += ' %.4e' % avg 138 | sys.stdout.write(info + "\n") 139 | 140 | self.last_update = now 141 | 142 | def add(self, n, values=None): 143 | self.update(self.seen_so_far + n, values) 144 | -------------------------------------------------------------------------------- /loggers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /loggers/html_template.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | TABLE_HEADER = """ 16 | 17 | 18 | 19 | 20 | 22 | 24 | 26 | 27 | 28 | 30 | 31 | 32 | 34 | 35 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | {table_header} 62 | 63 | 64 | 65 | 66 | {table_body} 67 | 68 |
69 | 70 | 71 | """ 72 | image_tag_template = "" 73 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /losses/scene_flow_projection.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | 20 | 21 | class project_ptcld(nn.Module): 22 | 23 | def __init__(self, is_one_way=True): 24 | super().__init__() 25 | self.coord = None 26 | 27 | def forward(self, global_p1, R_1_T, t_1, K): 28 | 29 | B, H, W, _, _ = global_p1.shape 30 | if self.coord is None: 31 | yy, xx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float()) 32 | self.coord = torch.ones([1, H, W, 1, 2]) 33 | self.coord[0, ..., 0, 0] = xx 34 | self.coord[0, ..., 0, 1] = yy 35 | 36 | coord = self.coord.expand([B, H, W, 1, 2]) 37 | 38 | p1_camera_1 = torch.matmul(global_p1 - t_1, R_1_T) 39 | 40 | p1_image_1 = torch.matmul(p1_camera_1, K) 41 | coord_image_sf = (p1_image_1 / (p1_image_1[..., -1:] + 1e-8))[..., : -1] 42 | displace_field = coord_image_sf - coord 43 | sf_proj = displace_field.squeeze() 44 | return sf_proj 45 | 46 | 47 | # class unproject_ptcld() 48 | class unproject_ptcld(nn.Module): 49 | # tested 50 | def __init__(self, is_one_way=True): 51 | super().__init__() 52 | self.coord = None 53 | 54 | def forward(self, depth_1, R_1, t_1, K_inv): 55 | B, _, H, W = depth_1.shape 56 | if self.coord is None: 57 | yy, xx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float()) 58 | self.coord = torch.ones([1, H, W, 1, 3]) 59 | self.coord[0, ..., 0, 0] = xx 60 | self.coord[0, ..., 0, 1] = yy 61 | self.coord = self.coord.to(depth_1.device) 62 | 63 | depth_1 = depth_1.view([B, H, W, 1, 1]) 64 | p1_camera_1 = depth_1 * torch.matmul(self.coord, K_inv) 65 | global_p1 = torch.matmul(p1_camera_1, R_1) + t_1 66 | 67 | return global_p1 68 | 69 | 70 | class unproject_ptcld_single(nn.Module): 71 | # tested 72 | def __init__(self, is_one_way=True): 73 | super().__init__() 74 | self.coord = None 75 | 76 | def forward(self, depth, pose, K): 77 | B, _, H, W = 
depth.shape 78 | assert B == 1 79 | if self.coord is None: 80 | yy, xx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float()) 81 | self.coord = torch.ones([H, W, 3]) 82 | self.coord[..., 0] = xx 83 | self.coord[..., 1] = yy 84 | self.coord = self.coord.to(depth.device) 85 | 86 | depth = depth.view([H, W, 1]) 87 | p1_camera_1 = depth * torch.matmul(self.coord, torch.inverse(K).transpose(0, 1)) 88 | R = pose[: 3, : 3].T 89 | t = pose[: 3, 3: 4].T 90 | global_p1 = torch.matmul(p1_camera_1, R) + t 91 | 92 | return global_p1 93 | 94 | 95 | class flow_by_depth(nn.Module): 96 | # tested 97 | def __init__(self, is_one_way=True): 98 | super().__init__() 99 | self.coord = None 100 | self.sample_grid = None 101 | self.one_way = is_one_way 102 | 103 | def backward_warp(self, depth_2, flow_1_2): 104 | # flow[...,0]: dh 105 | # flow[...,0]: dw 106 | B, _, H, W = depth_2.shape 107 | coord = self.coord[..., :2].view(1, H, W, 2).expand([B, H, W, 2]) 108 | sample_grids = coord + flow_1_2 109 | sample_grids[..., 0] /= (W - 1) / 2 110 | sample_grids[..., 1] /= (H - 1) / 2 111 | sample_grids -= 1 112 | return F.grid_sample(depth_2, sample_grids, align_corners=True, padding_mode='border') 113 | 114 | def forward(self, depth_1, depth_2, flow_1_2, R_1, R_2, R_1_T, R_2_T, t_1, t_2, K, K_inv): 115 | B, _, H, W = depth_1.shape 116 | if self.coord is None: 117 | yy, xx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float()) 118 | self.coord = torch.ones([1, H, W, 1, 3]) 119 | self.coord[0, ..., 0, 0] = xx 120 | self.coord[0, ..., 0, 1] = yy 121 | self.coord = self.coord.to(depth_1.device) 122 | 123 | coord = self.coord.expand([B, H, W, 1, 3]) 124 | depth_1 = depth_1.view([B, H, W, 1, 1]) 125 | depth_2 = depth_2.view([B, H, W, 1, 1]) 126 | 127 | p1_camera_1 = depth_1 * torch.matmul(self.coord, K_inv) 128 | p2_camera_2 = depth_2 * torch.matmul(self.coord, K_inv) 129 | 130 | global_p1 = torch.matmul(p1_camera_1, R_1) + t_1 131 | global_p2 = torch.matmul(p2_camera_2, R_2) + t_2 # BHW13 132 | 133 | global_p2 = global_p2.squeeze(3).permute([0, 3, 1, 2]) # B3HW 134 | warped_global_p2 = self.backward_warp(global_p2, flow_1_2) 135 | warped_global_p2 = warped_global_p2.permute([0, 2, 3, 1])[..., None, :] # BHW13 136 | sf_by_depth = warped_global_p2 - global_p1 137 | 138 | p1_camera_2 = torch.matmul(global_p1 - t_2, R_2_T) 139 | 140 | p1_image_2 = torch.matmul(p1_camera_2, K) 141 | 142 | coord_image_2 = (p1_image_2 / (p1_image_2[..., -1:] + 1e-8))[..., :-1] 143 | 144 | idB, idH, idW, idC, idF = torch.where(p1_image_2[..., -1:] < 1e-3) 145 | tr_coord = coord[..., :-1] 146 | coord_image_2[idB, idH, idW, idC, idF] = tr_coord[idB, idH, idW, idC, idF] 147 | coord_image_2[idB, idH, idW, idC, idF + 1] = tr_coord[idB, idH, idW, idC, idF + 1] 148 | 149 | depth_flow_1_2 = (coord_image_2 - coord[..., :-1])[..., 0, :] # p_{1 -> 2} 150 | 151 | # warp by flow 152 | 153 | return {'dflow_1_2': depth_flow_1_2, 'sf_by_depth': sf_by_depth, 'warped_global_p2': warped_global_p2, 'global_p1': global_p1} 154 | 155 | 156 | def calc_rigidity_loss(global_p1, sf, depth_1, s=1): 157 | mp = torch.nn.MaxPool2d(3, stride=1, padding=1, dilation=1) 158 | p_u = global_p1[:, None, 0:-2, 1:-1, 0, :] 159 | p_d = global_p1[:, None, 2:, 1:-1, 0, :] 160 | p_l = global_p1[:, None, 1:-1, 0:-2, 0, :] 161 | p_r = global_p1[:, None, 1:-1, 2:, 0, :] 162 | p_c = global_p1[:, None, 1:-1, 1:-1, 0, :] 163 | p_concat = torch.cat([p_u, p_d, p_c, p_l, p_r], axis=1) 164 | d_u = depth_1[:, :, 0:-2, 1:-1] 165 | d_d = depth_1[:, :, 2:, 1:-1] 166 | d_l = 
depth_1[:, :, 1:-1, 0:-2] 167 | d_r = depth_1[:, :, 1:-1, 2:] 168 | d_c = depth_1[:, :, 1:-1, 1:-1] 169 | d_concat = torch.cat([d_u, d_d, d_c, d_l, d_r], axis=1) 170 | s_u = sf[:, None, 0:-2, 1:-1, 0, :] 171 | s_d = sf[:, None, 2:, 1:-1, 0, :] 172 | s_l = sf[:, None, 1:-1, 0:-2, 0, :] 173 | s_r = sf[:, None, 1:-1, 2:, 0, :] 174 | s_c = sf[:, None, 1:-1, 1:-1, 0, :] 175 | s_concat = torch.cat([s_u, s_d, s_c, s_l, s_r], axis=1) 176 | 177 | prev_u = p_concat[:, 0, ...] - p_concat[:, 2, ...] 178 | prev_d = p_concat[:, 1, ...] - p_concat[:, 2, ...] 179 | prev_l = p_concat[:, 3, ...] - p_concat[:, 2, ...] 180 | prev_r = p_concat[:, 4, ...] - p_concat[:, 2, ...] 181 | after_u = s_concat[:, 0, ...] - s_concat[:, 2, ...] 182 | after_d = s_concat[:, 1, ...] - s_concat[:, 2, ...] 183 | after_l = s_concat[:, 3, ...] - s_concat[:, 2, ...] 184 | after_r = s_concat[:, 4, ...] - s_concat[:, 2, ...] 185 | gradd_u = d_concat[:, 0, ...] - d_concat[:, 2, ...] 186 | gradd_d = d_concat[:, 1, ...] - d_concat[:, 2, ...] 187 | gradd_l = d_concat[:, 3, ...] - d_concat[:, 2, ...] 188 | gradd_r = d_concat[:, 4, ...] - d_concat[:, 2, ...] 189 | 190 | lu = torch.abs(torch.norm(prev_u, dim=-1) - torch.norm(after_u, dim=-1)) 191 | ld = torch.abs(torch.norm(prev_d, dim=-1) - torch.norm(after_d, dim=-1)) 192 | lr = torch.abs(torch.norm(prev_r, dim=-1) - torch.norm(after_r, dim=-1)) 193 | ll = torch.abs(torch.norm(prev_l, dim=-1) - torch.norm(after_l, dim=-1)) 194 | 195 | weight_u = torch.exp(-s * mp(torch.abs(gradd_u))) 196 | weight_d = torch.exp(-s * mp(torch.abs(gradd_d))) 197 | weight_l = torch.exp(-s * mp(torch.abs(gradd_l))) 198 | weight_r = torch.exp(-s * mp(torch.abs(gradd_r))) 199 | total_loss = weight_u * lu + weight_r * lr + weight_d * ld + weight_l * ll 200 | loss_items = {'lu': lu, 'lr': lr, 'ld': ld, 'll': ll, 'weight_u': weight_u, 'weight_d': weight_d, 'weight_r': weight_r, 'weight_l': weight_l} 201 | return total_loss, loss_items 202 | 203 | 204 | class scene_flow_projection_slack(nn.Module): 205 | # tested 206 | def __init__(self, is_one_way=False): 207 | super().__init__() 208 | self.coord = None 209 | self.sample_grid = None 210 | self.is_one_way = is_one_way 211 | 212 | def backward_warp(self, depth_2, flow_1_2): 213 | 214 | B, _, H, W = depth_2.shape 215 | coord = self.coord[..., :2].view(1, H, W, 2).expand([B, H, W, 2]) 216 | sample_grids = coord + flow_1_2 217 | sample_grids[..., 0] /= (W - 1) / 2 218 | sample_grids[..., 1] /= (H - 1) / 2 219 | sample_grids -= 1 220 | return F.grid_sample(depth_2, sample_grids, align_corners=True, padding_mode='border') 221 | 222 | def forward(self, depth_1, depth_2, flow_1_2, flow_2_1, R_1, R_2, R_1_T, R_2_T, t_1, t_2, K, K_inv, sflow_1_2, sflow_2_1): 223 | B, _, H, W = depth_1.shape 224 | if self.coord is None: 225 | yy, xx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float()) 226 | self.coord = torch.ones([1, H, W, 1, 3]) 227 | self.coord[0, ..., 0, 0] = xx 228 | self.coord[0, ..., 0, 1] = yy 229 | self.coord = self.coord.to(depth_1.device) 230 | 231 | coord = self.coord.expand([B, H, W, 1, 3]) 232 | depth_1 = depth_1.view([B, H, W, 1, 1]) 233 | depth_2 = depth_2.view([B, H, W, 1, 1]) 234 | 235 | p1_camera_1 = depth_1 * torch.matmul(self.coord, K_inv) 236 | p2_camera_2 = depth_2 * torch.matmul(self.coord, K_inv) 237 | global_p1 = torch.matmul(p1_camera_1, R_1) + t_1 238 | global_p2 = torch.matmul(p2_camera_2, R_2) + t_2 239 | 240 | p2_camera_2_w = p2_camera_2.squeeze(3).permute([0, 3, 1, 2]) # B3HW 241 | warped_p2_camera_2 = 
self.backward_warp(p2_camera_2_w, flow_1_2) 242 | warped_p2_camera_2 = warped_p2_camera_2.permute([0, 2, 3, 1])[..., None, :] # BHW13 243 | 244 | p1_camera_2 = torch.matmul(global_p1 + sflow_1_2 - t_2, R_2_T) 245 | p1_camera_2_static = torch.matmul(global_p1 - t_2, R_2_T) 246 | p2_camera_1 = torch.matmul(global_p2 + sflow_2_1 - t_1, R_1_T) 247 | p1_image_2 = torch.matmul(p1_camera_2, K) 248 | p2_image_1 = torch.matmul(p2_camera_1, K) 249 | p1_image_2_static = torch.matmul(p1_camera_2_static, K) 250 | coord_image_2_static = (p1_image_2_static / (p1_image_2_static[..., -1:] + 1e-8))[..., : -1] 251 | coord_image_2 = (p1_image_2 / (p1_image_2[..., -1:] + 1e-8))[..., : -1] 252 | coord_image_1 = (p2_image_1 / (p2_image_1[..., -1:] + 1e-8))[..., : -1] 253 | idB, idH, idW, idC, idF = torch.where(p1_image_2[..., -1:] < 1e-3) 254 | tr_coord = coord[..., :-1] 255 | coord_image_2[idB, idH, idW, idC, idF] = tr_coord[idB, idH, idW, idC, idF] 256 | coord_image_2[idB, idH, idW, idC, idF + 1] = tr_coord[idB, idH, idW, idC, idF + 1] 257 | idB, idH, idW, idC, idF = torch.where(p2_image_1[..., -1:] < 1e-3) 258 | coord_image_1[idB, idH, idW, idC, idF] = tr_coord[idB, idH, idW, idC, idF] 259 | coord_image_1[idB, idH, idW, idC, idF + 1] = tr_coord[idB, idH, idW, idC, idF + 1] 260 | 261 | idB, idH, idW, idC, idF = torch.where(p1_image_2_static[..., -1:] < 1e-3) 262 | coord_image_2_static[idB, idH, idW, idC, idF] = tr_coord[idB, idH, idW, idC, idF] 263 | coord_image_2_static[idB, idH, idW, idC, idF + 1] = tr_coord[idB, idH, idW, idC, idF + 1] 264 | 265 | depth_flow_1_2 = (coord_image_2 - coord[..., :-1])[..., 0, :] # p_{1 -> 2} 266 | 267 | depth_flow_1_2_static = (coord_image_2_static - coord[..., :-1])[..., 0, :] 268 | depth_image_1_2 = p1_image_2[..., -1].permute(0, 3, 1, 2) # z_{1 -> 2} 269 | 270 | # forward warping depth 271 | depth_1 = depth_1.view(B, 1, H, W) 272 | depth_2 = depth_2.view(B, 1, H, W) 273 | 274 | depth_warp_1_2 = self.backward_warp(depth_2, flow_1_2) 275 | 276 | depth_warp_1_2 = depth_warp_1_2.view([B, 1, H, W]) 277 | 278 | return {'dflow_1_2': depth_flow_1_2, 'depth_image_1_2': depth_image_1_2, 'depth_warp_1_2': depth_warp_1_2, 'depth_1': depth_1, 'depth_2': depth_2, 'scenef_1_2': sflow_1_2, 'global_p1': global_p1, 'staticflow_1_2': depth_flow_1_2_static, 'p1_camera_2': p1_camera_2, 'warped_p2_camera_2': warped_p2_camera_2} 279 | 280 | 281 | class BackwardWarp(nn.Module): 282 | 283 | def __init__(self, is_one_way=False): 284 | super().__init__() 285 | self.coord = None 286 | self.sample_grid = None 287 | self.is_one_way = is_one_way 288 | 289 | def backward_warp(self, buffer, flow_1_2): 290 | 291 | B, _, H, W = buffer.shape 292 | coord = self.coord[..., :2].view(1, H, W, 2).expand([B, H, W, 2]) 293 | sample_grids = coord + flow_1_2 294 | sample_grids[..., 0] /= (W - 1) / 2 295 | sample_grids[..., 1] /= (H - 1) / 2 296 | sample_grids -= 1 297 | return F.grid_sample(buffer, sample_grids, align_corners=True, padding_mode='border') 298 | 299 | def forward(self, buffer, flow_1_2): 300 | B, _, H, W = buffer.shape 301 | if self.coord is None: 302 | yy, xx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float()) 303 | self.coord = torch.ones([1, H, W, 1, 3]) 304 | self.coord[0, ..., 0, 0] = xx 305 | self.coord[0, ..., 0, 1] = yy 306 | self.coord = self.coord.to(buffer.device) 307 | return self.backward_warp(buffer, flow_1_2) 308 | -------------------------------------------------------------------------------- /models/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import importlib 16 | 17 | 18 | def get_model(alias, test=False): 19 | module = importlib.import_module('models.' + alias) 20 | return module.Model 21 | -------------------------------------------------------------------------------- /models/video_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from os.path import join, dirname 16 | import numpy as np 17 | import torch 18 | from models.netinterface import NetInterface 19 | from os import makedirs 20 | import matplotlib.pyplot as plt 21 | import matplotlib.cm 22 | from matplotlib.colors import ListedColormap 23 | from third_party.util_colormap import turbo_colormap_data 24 | # matplotlib.cm.register_cmap('turbo', cmap=ListedColormap(turbo_colormap_data)) 25 | import matplotlib 26 | import shutil 27 | 28 | 29 | class VideoBaseModel(NetInterface): 30 | def disp_loss(self, d1, d2): 31 | if self.opt.use_disp: 32 | t1 = torch.clamp(d1, min=1e-3) 33 | t2 = torch.clamp(d2, min=1e-3) 34 | return 300 * torch.abs((1 / t1) - (1 / t2)) 35 | else: 36 | return torch.abs(d1 - d2) 37 | 38 | def _train_on_batch(self, epoch, batch_ind, batch): 39 | for n in self._nets: 40 | n.zero_grad() 41 | # self.net_depth.eval() # freeze bn to check 42 | 43 | self.load_batch(batch) 44 | batch_size = batch['img_1'].shape[0] 45 | pred = self._predict_on_batch() 46 | loss, loss_data = self._calc_loss(pred) 47 | loss.backward() 48 | for optimizer in self._optimizers: 49 | optimizer.step() 50 | 51 | if np.mod(epoch, self.opt.vis_every_train) == 0: 52 | indx = batch_ind if self.opt.vis_at_start else self.opt.epoch_batches - batch_ind 53 | if indx <= self.opt.vis_batches_train: 54 | for k, v in pred.items(): 55 | pred[k] = v.data.cpu().numpy() 56 | outdir = join(self.full_logdir, 'visualize', 'epoch%04d_train' % epoch) 57 | makedirs(outdir, exist_ok=True) 58 | output = self.pack_output(pred, batch) 59 | if self.global_rank == 0: 60 | if self.visualizer is not None: 61 | self.visualizer.visualize(output, indx + (1000 * epoch), outdir) 62 | np.savez(join(outdir, 'rank%04d_batch%04d' % (self.global_rank, batch_ind)), **output) 63 | batch_log = {'size': batch_size, 'loss': loss.item(), **loss_data} 64 | return 
batch_log 65 | 66 | @staticmethod 67 | def depth2disp(depth): 68 | valid = depth > 1e-2 69 | valid = valid.float() 70 | return (1 / (depth + (1 - valid) * 1e-8)) * valid 71 | 72 | def disp_vali(self, d1, d2): 73 | vali = d2 > 1e-2 74 | return torch.nn.functional.mse_loss(self.depth2disp(d1) * vali, self.depth2disp(d2) * vali) 75 | 76 | def _vali_on_batch(self, epoch, batch_idx, batch): 77 | for n in self._nets: 78 | n.eval() 79 | self.load_batch(batch) 80 | with torch.no_grad(): 81 | pred = self._predict_on_batch(is_train=False) 82 | gt_depth = batch['depth_mvs'].to(pred['depth'].device) 83 | # try: 84 | loss = self.disp_vali(pred['depth'], gt_depth).item() 85 | # except: 86 | # print('error when eval losses, might be in test mode') 87 | # pass 88 | 89 | if np.mod(epoch, self.opt.vis_every_vali) == 0: 90 | if batch_idx < self.opt.vis_batches_vali: 91 | for k, v in pred.items(): 92 | pred[k] = v.cpu().numpy() 93 | outdir = join(self.full_logdir, 'visualize', 'epoch%04d_vali' % epoch) 94 | makedirs(outdir, exist_ok=True) 95 | output = self.pack_output(pred, batch) 96 | if self.global_rank == 0: 97 | if self.visualizer is not None: 98 | self.visualizer.visualize(output, batch_idx + (1000 * epoch), outdir) 99 | np.savez(join(outdir, 'rank%04d_batch%04d' % (self.global_rank, batch_idx)), **output) 100 | batch_size = batch['img'].shape[0] 101 | 102 | batch_log = {'size': batch_size, 'loss': loss} 103 | return batch_log 104 | 105 | def pack_output(self, pred_all, batch): 106 | batch_size = len(batch['pair_path']) 107 | if 'img' not in batch.keys(): 108 | img_1 = batch['img_1'].cpu().numpy() 109 | img_2 = batch['img_2'].cpu().numpy() 110 | else: 111 | img_1 = batch['img'] 112 | img_2 = batch['img'] 113 | output = {'batch_size': batch_size, 'img_1': img_1, 'img_2': img_2, **pred_all} 114 | 115 | if 'img' not in batch.keys(): 116 | output['flow_1_2'] = self._input.flow_1_2.cpu().numpy() 117 | output['flow_2_1'] = self._input.flow_2_1.cpu().numpy() 118 | output['depth_nn_1'] = batch['depth_pred_1'].cpu().numpy() 119 | 120 | else: 121 | output['depth_nn'] = batch['depth_pred'].cpu().numpy() 122 | output['depth_gt'] = batch['depth_mvs'].cpu().numpy() 123 | output['cam_c2w'] = batch['cam_c2w'].cpu().numpy() 124 | output['K'] = batch['K'].cpu().numpy() 125 | output['pair_path'] = batch['pair_path'] 126 | return output 127 | 128 | def test_on_batch(self, batch_idx, batch): 129 | if not hasattr(self, 'test_cache'): 130 | self.test_cache = [] 131 | for n in self._nets: 132 | n.eval() 133 | self.load_batch(batch) 134 | with torch.no_grad(): 135 | pred = self._predict_on_batch(is_train=False) 136 | 137 | if not hasattr(self, 'test_loss'): 138 | self.test_loss = 0 139 | 140 | for k, v in pred.items(): 141 | pred[k] = v.cpu().numpy() 142 | epoch_string = 'best' if self.opt.epoch < 0 else '%04d' % self.opt.epoch 143 | outdir = join(self.opt.output_dir, 'epoch%s_test' % epoch_string) 144 | if not hasattr(self, 'outdir'): 145 | self.outdir = outdir 146 | makedirs(outdir, exist_ok=True) 147 | output = self.pack_output(pred, batch) 148 | if batch_idx == 223: 149 | output['depth'][0, 0, 0, :] = output['depth'][0, 0, 2, :] 150 | output['depth'][0, 0, 1, :] = output['depth'][0, 0, 2, :] 151 | self.test_cache.append(output.copy()) 152 | if self.global_rank == 0: 153 | if self.visualizer is not None: 154 | self.visualizer.visualize(output, batch_idx, outdir) 155 | np.savez(join(outdir, 'batch%04d' % (batch_idx)), **output) 156 | 157 | def on_test_end(self): 158 | 159 | # make test video: 160 | from subprocess import call 
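# --- Editor's note (comment added for readability; not part of the original file) ---
# on_test_end gathers the outputs cached by test_on_batch, renders per-frame
# disparity comparisons ('Refined' vs. 'Initial' vs. 'GT') and temporal-slice images
# with matplotlib, then calls ffmpeg to encode them into .mp4 files and writes an
# HTML summary page via util.util_html.Webpage.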
161 | from util.util_html import Webpage 162 | from tqdm import tqdm 163 | depth_pred = [] 164 | depth_nn = [] 165 | depth_gt = [] 166 | imgs = [] 167 | c2ws = [] 168 | Ks = [] 169 | for pack in self.test_cache: 170 | depth_pred.append(pack['depth']) 171 | depth_nn.append(pack['depth_nn']) 172 | imgs.append(pack['img_1']) 173 | c2ws.append(pack['cam_c2w']) 174 | Ks.append(pack['K']) 175 | depth_gt.append(pack['depth_gt']) 176 | 177 | depth_pred = np.concatenate(depth_pred, axis=0) 178 | depth_nn = np.concatenate(depth_nn, axis=0) 179 | imgs = np.concatenate(imgs, axis=0) 180 | c2ws = np.concatenate(c2ws, axis=0) 181 | Ks = np.concatenate(Ks, axis=0) 182 | depth_gt = np.concatenate(depth_gt, axis=0) 183 | 184 | pred_max = depth_pred.max() 185 | 186 | pred_min = depth_pred.min() 187 | 188 | print(pred_max, pred_min) 189 | depth_cmap = 'turbo' 190 | mask_valid = np.where(depth_gt > 1e-8, 1, 0) 191 | 192 | for i in tqdm(range(depth_pred.shape[0])): 193 | plt.figure(figsize=[60, 20], dpi=40, facecolor='black') 194 | 195 | plt.subplot(1, 3, 1) 196 | plt.title('Refined', fontsize=100, color='w') 197 | plt.imshow(1 / depth_pred[i, 0, ...], cmap=depth_cmap, vmax=1 / pred_min, vmin=1 / pred_max) 198 | cbar = plt.colorbar(fraction=0.048 * 0.5, pad=0.01) 199 | plt.axis('off') 200 | cbar.ax.yaxis.set_tick_params(color='w', labelsize=40) 201 | plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='w') 202 | 203 | plt.subplot(1, 3, 2) 204 | plt.title('Initial', fontsize=100, color='w') 205 | plt.imshow(1 / depth_nn[i, 0, ...], cmap=depth_cmap, vmax=1 / pred_min, vmin=1 / pred_max) 206 | plt.axis('off') 207 | cbar = plt.colorbar(fraction=0.048 * 0.5, pad=0.01) 208 | cbar.ax.yaxis.set_tick_params(color='w', labelsize=40) 209 | plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='w') 210 | 211 | plt.subplot(1, 3, 3) 212 | plt.title('GT', fontsize=100, color='w') 213 | 214 | plt.imshow(mask_valid[i, 0, ...] / (depth_gt[i, 0, ...] 
+ 1e-8), cmap=depth_cmap, vmax=1 / pred_min, vmin=1 / pred_max) 215 | plt.axis('off') 216 | cbar = plt.colorbar(fraction=0.048 * 0.5, pad=0.01) 217 | cbar.ax.yaxis.set_tick_params(color='w', labelsize=40) 218 | plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='w') 219 | plt.savefig(join(self.outdir, 'compare_%04d.png' % i), bbox_inches='tight', facecolor='black', dpi='figure') 220 | plt.close() 221 | 222 | plt.imshow(imgs[i, ...].transpose(1, 2, 0)) 223 | plt.axis('off') 224 | plt.savefig(join(self.outdir, 'rgb_%04d.png' % i), bbox_inches='tight', facecolor='black', dpi='figure') 225 | plt.close() 226 | 227 | epoch_string = self.outdir.split('/')[-1] 228 | 229 | gen_vid_command = 'ffmpeg -y -r 30 -i {img_template} -vcodec libx264 -crf 25 -pix_fmt yuv420p -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2" {video_path} > /dev/null' 230 | gen_vid_command_slow = 'ffmpeg -y -r 2 -i {img_template} -vcodec libx264 -crf 25 -pix_fmt yuv420p -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2" {video_path} > /dev/null' 231 | for r_number in range(120, 140): 232 | plt.figure(figsize=[60, 20], dpi=20, facecolor='black') 233 | 234 | plt.subplot(1, 2, 1) 235 | plt.title('Refined', fontsize=100, color='w') 236 | plt.imshow(1 / depth_pred[:, 0, r_number, :], cmap=depth_cmap) 237 | cbar = plt.colorbar(fraction=0.048 * 0.5, pad=0.01) 238 | plt.axis('off') 239 | cbar.ax.yaxis.set_tick_params(color='w', labelsize=40) 240 | plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='w') 241 | 242 | plt.subplot(1, 2, 2) 243 | plt.title('Initial', fontsize=100, color='w') 244 | plt.imshow(1 / depth_nn[:, 0, r_number, :], cmap=depth_cmap) 245 | 246 | plt.axis('off') 247 | cbar = plt.colorbar(fraction=0.048 * 0.5, pad=0.01) 248 | cbar.ax.yaxis.set_tick_params(color='w', labelsize=40) 249 | plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='w') 250 | plt.savefig(join(self.outdir, 'temporal_slice_%04d.png' % (r_number - 120)), bbox_inches='tight', facecolor='black', dpi='figure') 251 | plt.close() 252 | 253 | img_template = join(self.outdir, 'compare_%04d.png') 254 | 255 | img_template_t = join(self.outdir, 'temporal_slice_%04d.png') 256 | 257 | video_path = join(dirname(self.outdir), epoch_string + '.mp4') 258 | 259 | video_path_t = join(dirname(self.outdir), epoch_string + '_temporal.mp4') 260 | 261 | gen_vid_command_c = gen_vid_command.format(img_template=img_template, video_path=video_path) 262 | call(gen_vid_command_c, shell=True) 263 | gen_vid_command_t = gen_vid_command_slow.format(img_template=img_template_t, video_path=video_path_t) 264 | 265 | call(gen_vid_command_t, shell=True) 266 | 267 | web = Webpage() 268 | 269 | web.add_video(epoch_string + '_rgb.mp4', title='original video') 270 | web.add_video(epoch_string + '.mp4', title=f'Disparity loss {self.test_loss}') 271 | 272 | web.save(join(dirname(self.outdir), epoch_string + '.html')) 273 | 274 | @staticmethod 275 | def copy_and_make_dir(src, target): 276 | fname = dirname(target) 277 | makedirs(fname, exist_ok=True) 278 | shutil.copy(src, target) 279 | 280 | @staticmethod 281 | def scale_tesnor(t): 282 | t = (t - t.min()) / (t.max() - t.min() + 1e-9) 283 | return t 284 | 285 | 286 | # %% 287 | -------------------------------------------------------------------------------- /networks/FCNUnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | from torch import nn 18 | from .blocks import Conv2dBlock, DoubleConv2dBlock 19 | 20 | 21 | class FCNUnet(nn.Module): 22 | # 23 | def __init__(self, conv_setup, n_down=4, feat=32, block_type='conv', 24 | down_sample_type='avgpool', in_channel=2, out_channel=64, dialated_pool=False, output_activation=None): 25 | super().__init__() 26 | assert down_sample_type in ['avgpool', 'maxpool', 'none'] 27 | if block_type == 'conv': 28 | Block = Conv2dBlock 29 | elif block_type == 'double_conv': 30 | Block = DoubleConv2dBlock 31 | else: 32 | raise NotImplementedError(f'block type {block_type} not supported') 33 | # 2x downsampling using avgpool 34 | if down_sample_type == 'avgpool': 35 | self.down_sample = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) 36 | elif down_sample_type == 'maxpool': 37 | self.down_sample = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 38 | else: 39 | self.down_sample = None 40 | 41 | self.upsample = nn.Upsample( 42 | scale_factor=2, mode='bilinear', align_corners=True) 43 | # this is not Implemented yet 44 | 45 | self.n_down = n_down 46 | self.down_conv = [] 47 | self.up_conv = [] 48 | 49 | ch_in = in_channel 50 | ch_out = feat 51 | for k in range(n_down): 52 | self.down_conv += [Block(ch_in, ch_out, kernel_size=3, padding=1, **conv_setup)] 53 | self.add_module('down_%02d' % k, self.down_conv[-1]) 54 | ch_in = ch_out 55 | ch_out = ch_out * 2 56 | self.mid_conv = Block(ch_in, ch_in, kernel_size=3, padding=1, **conv_setup) 57 | 58 | for k in range(n_down - 1): 59 | self.up_conv += [Block(ch_in * 2, ch_in // 2, 60 | padding=1, kernel_size=3, **conv_setup)] 61 | self.add_module('up_%04d' % k, self.up_conv[-1]) 62 | ch_in = ch_in // 2 63 | # This is for matching original unet implementation. 64 | self.up_conv += [Block(ch_in * 2, ch_in, padding=1, 65 | kernel_size=3, **conv_setup)] 66 | self.add_module('up_%04d' % (k + 1), self.up_conv[-1]) 67 | 68 | conv_setup['activation'] = 'none' 69 | conv_setup['norm'] = 'none' 70 | self.output_conv = Conv2dBlock( 71 | ch_in, out_channel, kernel_size=1, **conv_setup) 72 | # self.add_module('output', self.up_conv[-1]) 73 | if output_activation == 'tanh': 74 | self.final_act = nn.Tanh() 75 | elif output_activation == 'sigmoid': 76 | self.final_act = nn.Sigmoid() 77 | else: 78 | self.final_act = nn.Identity() 79 | 80 | def forward(self, x): 81 | feat = [] 82 | for module in self.down_conv: 83 | x = module(x) 84 | feat.append(x) 85 | x = self.down_sample(x) 86 | x = self.mid_conv(x) 87 | for idm, module in enumerate(self.up_conv): 88 | up_x = self.upsample(x) 89 | f = feat[-(idm + 1)] 90 | x = module(torch.cat([f, up_x], 1)) 91 | x = self.output_conv(x) 92 | return self.final_act(x) 93 | -------------------------------------------------------------------------------- /networks/MLP.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from torch import nn 17 | from .blocks import PeriodicEmbed 18 | 19 | 20 | class EmbededMLP(nn.Module): 21 | def __init__(self, in_ch=3, out_ch=3, depth=3, width=64, N_freq=8, skip=3, act_fn=nn.functional.leaky_relu, output_act=None, norm=None, init_val=None): 22 | super().__init__() 23 | self.embed = PeriodicEmbed(N_freq=N_freq, linspace=False) 24 | N_input_channel = in_ch + in_ch * 2 * N_freq 25 | self.layers = [] 26 | self.skip = skip 27 | self.layers.append(DenseLayer(N_input_channel, width, act_fn, norm)) 28 | for d in range(depth - 1): 29 | if (d + 1) % skip == 0 and d > 0: 30 | self.layers.append(DenseLayer(width + N_input_channel, width, act_fn, norm)) 31 | else: 32 | self.layers.append(DenseLayer(width, width, act_fn, norm)) 33 | self.layers.append(DenseLayer(width, out_ch, output_act, norm=None)) 34 | for idl, l in enumerate(self.layers): 35 | self.add_module(f'layer_{idl:03d}', l) 36 | 37 | if init_val is not None: 38 | self.layers[-1].linear.bias.data.fill_(init_val) 39 | 40 | def forward(self, x): 41 | x = self.embed(x) 42 | embed = x 43 | 44 | for idl, l in enumerate(self.layers): 45 | if idl % self.skip == 0 and idl > 0 and idl < len(self.layers) - 1: 46 | x = torch.cat([x, embed], -1) 47 | x = l(x) 48 | return x 49 | 50 | 51 | class MLP(nn.Module): 52 | def __init__(self, in_ch=64, out_ch=3, depth=3, width=64, act_fn=nn.functional.relu, output_act=None, norm=None): 53 | super().__init__() 54 | layers = [] 55 | layers.append(DenseLayer(in_ch, width, act_fn, norm)) 56 | for d in range(depth - 1): 57 | layers.append(DenseLayer(width, width, act_fn, norm)) 58 | layers.append(DenseLayer(width, out_ch, output_act, norm=None)) 59 | self.model = nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | return self.model(x) 63 | 64 | 65 | class DenseLayer(nn.Module): 66 | def __init__(self, in_ch, out_ch, act_fn=None, norm=None): 67 | super().__init__() 68 | self.linear = nn.Linear(in_ch, out_ch) 69 | if act_fn is None: 70 | self.act_fn = nn.Identity() 71 | else: 72 | self.act_fn = act_fn 73 | if norm is None: 74 | self.norm = nn.Identity() 75 | else: 76 | self.norm = norm 77 | 78 | def forward(self, x): 79 | x = self.linear(x) 80 | x = self.norm(x) 81 | x = self.act_fn(x) 82 | return x 83 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
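# --- Editor's note (not part of the original file) ---
# The networks package groups the building blocks used by the models: FCNUnet and
# the MLP variants above, plus Conv2dBlock/PeriodicEmbed in blocks.py and
# SceneFlowFieldNet in sceneflow_field.py below. A minimal, hypothetical sketch of
# the coordinate MLP defined in MLP.py (illustrative shapes only):
#
#   mlp = EmbededMLP(in_ch=3, out_ch=3, depth=3, width=64, N_freq=8)
#   y = mlp(torch.randn(4, 3))   # inputs are positionally encoded; output shape [4, 3]
# --- end of editor's note ---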
14 | -------------------------------------------------------------------------------- /networks/blocks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from torch import nn 16 | import torch 17 | 18 | 19 | class PeriodicEmbed(nn.Module): 20 | def __init__(self, max_freq=5, N_freq=4, linspace=True): 21 | super().__init__() 22 | self.embed_functions = [torch.cos, torch.sin] 23 | if linspace: 24 | self.freqs = torch.linspace(1, max_freq + 1, steps=N_freq) 25 | else: 26 | exps = torch.linspace(0, N_freq - 1, steps=N_freq) 27 | self.freqs = 2**exps 28 | 29 | def forward(self, x): 30 | output = [x] 31 | for f in self.embed_functions: 32 | for freq in self.freqs: 33 | output.append(f(freq * x)) 34 | return torch.cat(output, 1) 35 | 36 | 37 | class DoubleConv2dBlock(nn.Module): 38 | def __init__(self, input_dim, output_dim, kernel_size, stride=1, 39 | padding=0, dilation=1, norm='weight', activation='relu', pad_type='zero', use_bias=True, **kargs): 40 | super().__init__() 41 | self.model = nn.Sequential(Conv2dBlock(input_dim, output_dim, kernel_size, stride, 42 | padding, dilation, norm, activation, pad_type, use_bias), 43 | Conv2dBlock(output_dim, output_dim, kernel_size, stride, 44 | padding, dilation, norm, activation, pad_type, use_bias)) 45 | 46 | def forward(self, x): 47 | return self.model(x) 48 | 49 | 50 | class Conv2dBlock(nn.Module): 51 | def __init__(self, input_dim, output_dim, kernel_size, stride, 52 | padding=0, dilation=1, norm='weight', activation='relu', pad_type='zero', use_bias=True, *args, **karg): 53 | super(Conv2dBlock, self).__init__() 54 | 55 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, 56 | padding=0, dilation=dilation, bias=use_bias) 57 | 58 | # initialize padding 59 | if pad_type == 'reflect': 60 | self.pad = nn.ReflectionPad2d(padding) 61 | elif pad_type == 'zero': 62 | self.pad = nn.ZeroPad2d(padding) 63 | else: 64 | assert 0, "Unsupported padding type: {}".format(pad_type) 65 | 66 | # initialize normalization 67 | norm_dim = output_dim 68 | if norm == 'batch': 69 | self.norm = nn.BatchNorm2d(norm_dim) 70 | elif norm == 'inst': 71 | self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=False) 72 | elif norm == 'ln': 73 | self.norm = nn.LayerNorm(norm_dim) 74 | elif norm == 'none': 75 | self.norm = nn.Identity() 76 | elif norm == 'weight': 77 | self.conv = nn.utils.weight_norm(self.conv) 78 | self.norm = nn.Identity() 79 | else: 80 | assert 0, "Unsupported normalization: {}".format(norm) 81 | 82 | # initialize activation 83 | if activation == 'relu': 84 | self.activation = nn.ReLU(inplace=True) 85 | elif activation == 'lrelu': 86 | self.activation = nn.LeakyReLU(0.2, inplace=True) 87 | elif activation == 'prelu': 88 | self.activation = nn.PReLU() 89 | elif activation == 'selu': 90 | self.activation = nn.SELU(inplace=True) 91 | elif activation == 'tanh': 92 | self.activation = 
nn.Tanh() 93 | elif activation == 'none': 94 | self.activation = nn.Identity() 95 | else: 96 | assert 0, "Unsupported activation: {}".format(activation) 97 | 98 | def forward(self, x): 99 | x = self.conv(self.pad(x)) 100 | x = self.norm(x) 101 | x = self.activation(x) 102 | return x 103 | 104 | 105 | class ResConv2DBlock(nn.Module): 106 | def __init__(self, dim_in, dim_out, kernel_size=3, stride=1, padding=0, dilation=1, 107 | norm='weight', activation='relu', pad_type='zero', use_bias=True): 108 | model = [] 109 | model += [Conv2dBlock(dim_in, dim_out, kernel_size, stride, padding, dilation=dilation, norm=norm, 110 | activation=activation, pad_type=pad_type, use_bias=use_bias)] 111 | model += [Conv2dBlock((dim_in + dim_out) // 2, dim_out, kernel_size, stride, padding, dilation=dilation, norm=norm, 112 | activation=activation, pad_type=pad_type, use_bias=use_bias)] 113 | if dim_in != dim_out: 114 | self.skip = Conv2dBlock(dim_in, dim_out, 1, stride, padding, dilation=dilation, norm=norm, 115 | activation=activation, pad_type=pad_type, use_bias=use_bias) 116 | else: 117 | self.skip = nn.Indentity() 118 | self.model = nn.Sequential(*model) 119 | 120 | def forward(self, x): 121 | res = self.skip(x) 122 | out = self.model(x) 123 | return out + res 124 | -------------------------------------------------------------------------------- /networks/sceneflow_field.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
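# --- Editor's note (not part of the original file) ---
# SceneFlowFieldNet below is a per-pixel network built from 1x1 Conv2dBlocks: it maps
# a [B, 3, H, W] grid of (optionally positionally embedded) 3D points, plus an optional
# [B, 1, H, W] time channel, to a [B, output_dim, H, W] scene-flow field. A minimal,
# hypothetical usage sketch with the default constructor arguments:
#
#   net = SceneFlowFieldNet(time_dependent=True, net_width=32, n_layers=3)
#   xyz = torch.randn(1, 3, 288, 512)   # global 3D point map (illustrative values)
#   t = torch.zeros(1, 1, 288, 512)     # per-pixel time stamps (illustrative values)
#   flow = net(xyz, t)                  # -> [1, 3, 288, 512]
# --- end of editor's note ---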
14 | 15 | import torch 16 | from torch import nn 17 | from .blocks import PeriodicEmbed, Conv2dBlock 18 | 19 | 20 | class SceneFlowFieldNet(nn.Module): 21 | def __init__(self, time_dependent=True, N_freq_xyz=0, N_freq_t=0, output_dim=3, net_width=32, n_layers=3, activation='lrelu', norm='none'): 22 | super().__init__() 23 | N_input_channel_xyz = 3 + 3 * 2 * N_freq_xyz 24 | N_input_channel_t = 1 + 1 * 2 * N_freq_t 25 | N_input_channel = N_input_channel_xyz + N_input_channel_t if time_dependent else N_input_channel_xyz 26 | if N_freq_xyz == 0: 27 | xyz_embed = nn.Identity() 28 | else: 29 | xyz_embed = PeriodicEmbed(max_freq=N_freq_xyz, N_freq=N_freq_xyz) 30 | if N_freq_t == 0: 31 | t_embed = nn.Identity() 32 | else: 33 | t_embed = PeriodicEmbed(max_freq=N_freq_t, N_freq=N_freq_t) 34 | convs = [Conv2dBlock(N_input_channel, net_width, 1, 1, norm=norm, activation=activation)] 35 | for i in range(n_layers): 36 | convs.append(Conv2dBlock(net_width, net_width, 1, 1, norm=norm, activation=activation)) 37 | convs.append(Conv2dBlock(net_width, output_dim, 1, 1, norm='none', activation='none')) 38 | self.convs = nn.Sequential(*convs) 39 | self.t_embed = t_embed 40 | self.xyz_embed = xyz_embed 41 | self.time_dependent = time_dependent 42 | 43 | def forward(self, x, t=None): 44 | x = x.contiguous() 45 | if t is None and self.time_dependent: 46 | raise ValueError 47 | xyz_embedded = self.xyz_embed(x) 48 | if self.time_dependent: 49 | t_embedded = self.t_embed(t) 50 | input_feat = torch.cat([t_embedded, xyz_embedded], 1) 51 | else: 52 | input_feat = xyz_embedded 53 | return self.convs(input_feat) 54 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /options/options_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
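# --- Editor's note (not part of the original file) ---
# parse() below assembles the test-time argument parser and then lets the selected
# dataset and model register their own flags via add_arguments(). A hypothetical
# invocation (flag values are illustrative, not taken from the released test_cmd.txt):
#
#   python test.py --net scene_flow_motion_field --dataset davis_sequence \
#       --checkpoint_path <path/to/checkpoint_dir> --gpu 0 \
#       --output_dir ./test_results/davis --epoch -1 --html_logger
# --- end of editor's note ---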
14 | 15 | import sys 16 | import argparse 17 | from datasets import get_dataset 18 | from models import get_model 19 | 20 | 21 | def add_general_arguments(parser): 22 | 23 | # GPU 24 | parser.add_argument('--gpu', type=str, required=True, help='gpu idx') 25 | 26 | # dataset 27 | parser.add_argument('--dataset', type=str, required=True, help='name of the dataset') 28 | 29 | # dataloader 30 | parser.add_argument('--workers', type=int, default=4, 31 | help='number of data loading workers') 32 | parser.add_argument('--batch_size', type=int, default=16, 33 | help='training batch size') 34 | 35 | # Network 36 | parser.add_argument('--net', type=str, required=True, help='name of the model') 37 | parser.add_argument('--checkpoint_path', type=str, required=True, help='checkpoint path') 38 | parser.add_argument('--epoch', type=int, default=-1, help='epoch id for testing') 39 | 40 | # Output 41 | parser.add_argument('--output_dir', type=str, required=True, 42 | help="Output directory") 43 | parser.add_argument('--overwrite', action='store_true', 44 | help="Whether to overwrite the output folder if it exists") 45 | parser.add_argument('--suffix', default='epoch_{epoch}', type=str, 46 | help="Suffix for `logdir` that will be formatted with `opt`, e.g., '{classes}_lr{lr}'") 47 | 48 | # visualizer 49 | parser.add_argument('--html_logger', action='store_true', 50 | help="use html_logger for visualization") 51 | 52 | # Misc 53 | parser.add_argument('--manual_seed', type=int, default=None, 54 | help='manual seed for randomness') 55 | 56 | return parser 57 | 58 | 59 | def parse(add_additional_arguments=None): 60 | parser = argparse.ArgumentParser() 61 | parser = add_general_arguments(parser) 62 | if add_additional_arguments: 63 | parser, _ = add_additional_arguments(parser) 64 | opt_general, _ = parser.parse_known_args() 65 | net_name = opt_general.net 66 | 67 | dataset_name = opt_general.dataset 68 | # Add parsers depending on dataset and models 69 | parser, _ = get_dataset(dataset_name).add_arguments(parser) 70 | parser, _ = get_model(net_name).add_arguments(parser) 71 | 72 | # Manually add '-h' after adding all parser arguments 73 | if '--printhelp' in sys.argv: 74 | sys.argv.append('-h') 75 | 76 | opt = parser.parse_args() 77 | return opt 78 | -------------------------------------------------------------------------------- /options/options_train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
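# --- Editor's note (not part of the original file) ---
# The training options mirror the test options but add optimizer, distributed-training,
# logging/visualization, and resume-related flags; unique_params below lists the options
# that are NOT overwritten when resuming from a previously saved options file. A
# hypothetical training command (values illustrative only):
#
#   python train.py --net scene_flow_motion_field --dataset davis_sequence \
#       --gpu 0 --logdir ./checkpoints --epoch 20 --batch_size 1 --lr 1e-6 \
#       --html_logger --suffix '{dataset}_lr{lr}'
# --- end of editor's note ---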
14 | 15 | import sys 16 | import argparse 17 | import torch 18 | from util.util_print import str_warning 19 | from datasets import get_dataset 20 | from models import get_model 21 | 22 | 23 | def add_general_arguments(parser): 24 | # Parameters that will NOT be overwritten when resuming 25 | unique_params = {'gpu', 'resume', 'epoch', 'workers', 26 | 'batch_size', 'save_net', 'epoch_batches', 'logdir', 'pt_no_overwrite', 'full_logdir', 'vis_batches_vali', 'vali_batches', 'vali_at_start', 'vis_every_vali'} 27 | 28 | parser.add_argument('--gpu', default='none', type=str, 29 | help='gpu to use') 30 | parser.add_argument('--manual_seed', type=int, default=None, 31 | help='manual seed for randomness') 32 | parser.add_argument('--resume', type=int, default=0, 33 | help='resume training by loading checkpoint.pt or best.pt. Use 0 for training from scratch, -1 for last and -2 for previous best. Use positive number for a specific epoch. \ 34 | Most options will be overwritten to resume training with exactly same environment') 35 | parser.add_argument('--suffix', default='', type=str, 36 | help="Suffix for `logdir` that will be formatted with `opt`, e.g., '{classes}_lr{lr}'") 37 | parser.add_argument('--epoch', type=int, default=0, 38 | help='number of epochs to train') 39 | parser.add_argument('--force_overwrite', action='store_true', 40 | help='force to overwrite previous experiments, without keyborad confirmations') 41 | 42 | # Dataset IO 43 | parser.add_argument('--dataset', type=str, default=None, 44 | help='dataset to use') 45 | parser.add_argument('--workers', type=int, default=4, 46 | help='number of data loading workers') 47 | parser.add_argument('--batch_size', type=int, default=16, 48 | help='training batch size') 49 | parser.add_argument('--no_batching', action='store_true', help='do not use batching.') 50 | parser.add_argument('--epoch_batches', default=None, 51 | type=int, help='number of batches used per epoch') 52 | parser.add_argument('--vali_batches', default=None, 53 | type=int, help='max number of batches used for validation per epoch') 54 | parser.add_argument('--vali_at_start', action='store_true', 55 | help='run validation before starting to train') 56 | parser.add_argument('--log_time', action='store_true', 57 | help='adding time log') 58 | parser.add_argument('--print_net', action='store_true', 59 | help="print network") 60 | 61 | # Distributed Training 62 | parser.add_argument('--multiprocess_distributed', action='store_true', 63 | help='set this flag to enable multiprocess distributed training. \ 64 | This would spawn an individual process for each GPU') 65 | parser.add_argument('--world_size', type=int, default=1, 66 | help='Number of nodes in multiprocess distributed training. 
Number of processes in plain distributed training.') 67 | parser.add_argument('--node_rank', type=int, default=0, 68 | help='Specify the node_rank for distributed multiprocess training.') 69 | parser.add_argument('--dist_backend', type=str, default='nccl', choices=['nccl', 'gloo', 'mpi'], 70 | help='the backend for distributed training.') 71 | parser.add_argument('--init_url', type=str, default='tcp://127.0.0.1:60504', 72 | help='init url for process group initialziation') 73 | 74 | 75 | # Network name 76 | parser.add_argument('--net', type=str, required=True, 77 | help='network type to use') 78 | 79 | # Optimizer 80 | parser.add_argument('--optim', type=str, default='adam', 81 | help='optimizer to use') 82 | parser.add_argument('--lr', type=float, default=1e-4, 83 | help='learning rate') 84 | parser.add_argument('--adam_beta1', type=float, default=0.5, 85 | help='beta1 of adam') 86 | parser.add_argument('--adam_beta2', type=float, default=0.9, 87 | help='beta2 of adam') 88 | parser.add_argument('--sgd_momentum', type=float, default=0.9, 89 | help="momentum factor of SGD") 90 | parser.add_argument('--sgd_dampening', type=float, default=0, 91 | help="dampening for momentum of SGD") 92 | parser.add_argument('--wdecay', type=float, default=0.0, 93 | help='weight decay') 94 | 95 | # initialization 96 | parser.add_argument('--init_type', type=str, default='normal', help='type of initialziation to use') 97 | 98 | 99 | # Mixed precision training 100 | parser.add_argument('--mixed_precision_training', action='store_true', 101 | help='use mixed precision for training.') 102 | parser.add_argument('--loss_scaling', type=float, default=255, 103 | help='the loss scale factor for mixed precision training. Set to -1 for dynamic scaling.') 104 | 105 | 106 | # Logging and visualization 107 | parser.add_argument('--logdir', type=str, default=None, 108 | help='Root directory for logging. Actual dir is [logdir]/[net_classes_dataset]/[expr_id]') 109 | parser.add_argument('--full_logdir', type=str, default=None, 110 | help='having the option to override this. this is useful for resuming to another dataset for finetuning purposes.') 111 | parser.add_argument('--exprdir_no_prefix', action='store_true', 112 | help='do not append the prefix to logdir. without this, expr_dir is set to net_classes_dataset') 113 | parser.add_argument('--pt_no_overwrite', action='store_true', 114 | help='having the option to not overwrite previous pt for on the fly eval purposes.') 115 | parser.add_argument('--log_batch', action='store_true', 116 | help='Log batch loss') 117 | parser.add_argument('--progbar_interval', type=float, default=0.05, 118 | help='time interval for updating the progbar') 119 | parser.add_argument('--no_accum', action='store_true', 120 | help='progbar show batch level loss values instead of mean.') 121 | parser.add_argument('--expr_id', type=int, default=0, 122 | help='Experiment index. non-positive ones are overwritten by default. Use 0 for code test. 
') 123 | parser.add_argument('--save_net', type=int, default=1, 124 | help='Period of saving network weights') 125 | parser.add_argument('--save_net_opt', action='store_true', 126 | help='Save optimizer state in regular network saving') 127 | parser.add_argument('--vis_every_vali', default=1, type=int, 128 | help="Visualize every N epochs during validation") 129 | parser.add_argument('--vis_every_train', default=1, type=int, 130 | help="Visualize every N epochs during training") 131 | parser.add_argument('--vis_batches_vali', type=int, default=10, 132 | help="# batches to visualize during validation") 133 | parser.add_argument('--vis_batches_train', type=int, default=10, 134 | help="# batches to visualize during training") 135 | parser.add_argument('--tensorboard', action='store_true', 136 | help='Use tensorboard for logging. If enabled, the output log will be at [logdir]/[tensorboard]/[net_classes_dataset]/[expr_id]') 137 | parser.add_argument('--tensorboard_keyword', type=str, default='checkpoints', 138 | help='this is used to search for keywords in logdir and split according to this keyword. all tensorboard is logged at this dir. For example, if the logdir is /parent_dir/keyword/child_dir, then tensorboard would log to /parent_dir/keyword/tensorboard/child_dir/opt.expr_dir/expr_id/. Use \'none\' to disable this feature.') 139 | parser.add_argument('--html_logger', action='store_true', help='use html logger for images visualization') 140 | parser.add_argument('--vis_workers', default=2, type=int, 141 | help="# workers for the visualizer") 142 | parser.add_argument('--vis_param_f', default=None, type=str, 143 | help="Parameter file read by the visualizer on every batch; defaults to 'visualize/config.json'") 144 | parser.add_argument('--vis_at_start', action='store_true', help='visualize the first batches in an epoch instead of the last ones.') 145 | parser.add_argument('--test_template', type=str, default=None, help='test command template path') 146 | 147 | return parser, unique_params 148 | 149 | 150 | def overwrite(opt, opt_f_old, unique_params): 151 | opt_dict = vars(opt) 152 | opt_dict_old = torch.load(opt_f_old) 153 | for k, v in opt_dict_old.items(): 154 | if k in opt_dict: 155 | if (k not in unique_params) and (opt_dict[k] != v): 156 | print(str_warning, "Overwriting %s for resuming training: %s -> %s" 157 | % (k, str(opt_dict[k]), str(v))) 158 | opt_dict[k] = v 159 | else: 160 | print(str_warning, "Ignoring %s, an old option that no longer exists" % k) 161 | opt = argparse.Namespace(**opt_dict) 162 | return opt 163 | 164 | 165 | def parse(add_additional_arguments=None): 166 | parser = argparse.ArgumentParser() 167 | parser, unique_params = add_general_arguments(parser) 168 | if add_additional_arguments is not None: 169 | parser, unique_params_additional = add_additional_arguments(parser) 170 | unique_params = unique_params.union(unique_params_additional) 171 | opt_general, _ = parser.parse_known_args() 172 | dataset_name, net_name = opt_general.dataset, opt_general.net 173 | del opt_general 174 | 175 | # Add parsers depending on dataset and models 176 | parser, unique_params_dataset = get_dataset( 177 | dataset_name).add_arguments(parser) 178 | parser, unique_params_model = get_model(net_name).add_arguments(parser) 179 | 180 | # Manually add '-h' after adding all parser arguments 181 | if '--printhelp' in sys.argv: 182 | sys.argv.append('-h') 183 | 184 | opt, unknown = parser.parse_known_args() 185 | if len(unknown) > 0: 186 | print(str_warning, f'ignoring unknown argument 
{unknown}') 187 | unique_params = unique_params.union(unique_params_dataset) 188 | unique_params = unique_params.union(unique_params_model) 189 | return opt, unique_params 190 | -------------------------------------------------------------------------------- /scripts/download_data_and_depth_ckpt.sh: -------------------------------------------------------------------------------- 1 | echo -e "\e[91m Downloading depth checkpoints\e[39m" 2 | gdown https://drive.google.com/uc?id=167YnhuCbWe51lnCAFY7lu_bxD2wx9EKb -O - --quiet | tar xvf - 3 | 4 | echo -e "\e[91m Downloading example data\e[39m" 5 | gdown https://drive.google.com/uc?id=1Y7-Q2nBIuVmkFSQZkZjHJHpk3KjbFwaa -O - --quiet | tar xvf - 6 | 7 | 8 | -------------------------------------------------------------------------------- /scripts/download_triangulation_files.sh: -------------------------------------------------------------------------------- 1 | echo -e "\e[91m Downloading DAVIS triangulation data\e[39m" 2 | gdown https://drive.google.com/uc?id=1U07e9xtwYbBZPpJ2vfsLaXYMWATt4XyB -O - --quiet | tar xvf - 3 | 4 | 5 | echo -e "\e[91m Downloading shutterstock triangulation data\e[39m" 6 | gdown https://drive.google.com/uc?id=1om58tVKujaq1Jo_ShpKc4sWVAWBoKY6U -O - --quiet | tar xvf - -------------------------------------------------------------------------------- /scripts/preprocess/davis/generate_flows.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
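# --- Editor's note (not part of the original file) ---
# This script runs RAFT in both directions for each frame pair (frame gaps 1 through 8)
# and derives per-pixel reliability masks from a forward-backward consistency check:
# flow_1_2 is backward-warped by flow_2_1, pixels where
#     || warp(flow_1_2, flow_2_1) + flow_2_1 || > 1 (pixel)
# or which map outside the image are marked invalid (mask = 1), and the flows plus masks
# are saved to flow_pairs/<track>/flowpair_XXXXX_YYYYY.npz for downstream use.
# --- end of editor's note ---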
14 | 15 | import torch 16 | import sys 17 | import os 18 | from os.path import join 19 | sys.path.append('./third_party/RAFT') 20 | sys.path.append('./third_party/RAFT/core') 21 | from raft import RAFT 22 | import numpy as np 23 | import torch.nn.functional as F 24 | from functools import lru_cache 25 | from glob import glob 26 | import argparse 27 | from tqdm import tqdm 28 | import subprocess 29 | 30 | try: 31 | import cv2 32 | except ImportError: 33 | subprocess.check_call([sys.executable, "-m", "pip", "install", 'opencv-python']) 34 | finally: 35 | import cv2 36 | 37 | from skimage.transform import resize as imresize 38 | 39 | 40 | data_list_root = "./datafiles/davis_processed/frames_midas" 41 | outpath = './datafiles/davis_processed/flow_pairs' 42 | 43 | 44 | def resize_flow(flow, size): 45 | resized_width, resized_height = size 46 | H, W = flow.shape[:2] 47 | scale = np.array((resized_width / float(W), resized_height / float(H))).reshape( 48 | 1, 1, -1 49 | ) 50 | resized = cv2.resize( 51 | flow, dsize=(resized_width, resized_height), interpolation=cv2.INTER_CUBIC 52 | ) 53 | resized *= scale 54 | return resized 55 | 56 | 57 | def get_oob_mask(flow_1_2): 58 | H, W, _ = flow_1_2.shape 59 | hh, ww = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float()) 60 | coord = torch.zeros([H, W, 2]) 61 | coord[..., 0] = ww 62 | coord[..., 1] = hh 63 | target_range = coord + flow_1_2 64 | m1 = (target_range[..., 0] < 0) + (target_range[..., 0] > W - 1) 65 | m2 = (target_range[..., 1] < 0) + (target_range[..., 1] > H - 1) 66 | return (m1 + m2).float().numpy() 67 | 68 | 69 | def backward_flow_warp(im2, flow_1_2): 70 | H, W, _ = im2.shape 71 | hh, ww = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float()) 72 | coord = torch.zeros([1, H, W, 2]) 73 | coord[0, ..., 0] = ww 74 | coord[0, ..., 1] = hh 75 | sample_grids = coord + flow_1_2[None, ...] 76 | sample_grids[..., 0] /= (W - 1) / 2 77 | sample_grids[..., 1] /= (H - 1) / 2 78 | sample_grids -= 1 79 | im = torch.from_numpy(im2).float().permute(2, 0, 1)[None, ...] 
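# (Editor's note) sample_grids has just been rescaled from pixel coordinates to the
# [-1, 1] range expected by F.grid_sample (x divided by (W - 1) / 2, y by (H - 1) / 2,
# then shifted by -1), so the call below samples im2 at the positions coord + flow_1_2.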
80 | out = F.grid_sample(im, sample_grids, align_corners=True) 81 | o = out[0, ...].permute(1, 2, 0).numpy() 82 | return o 83 | 84 | 85 | def get_L2_error_map(v1, v2): 86 | return np.linalg.norm(v1 - v2, axis=-1) 87 | 88 | 89 | def load_RAFT(): 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument('--model', help="restore checkpoint") 92 | parser.add_argument('--path', help="dataset for evaluation") 93 | parser.add_argument('--small', action='store_true', help='use small model') 94 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 95 | parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') 96 | args = parser.parse_args(['--model', './third_party/RAFT/models/raft-sintel.pth', '--path', './']) 97 | net = torch.nn.DataParallel(RAFT(args).cuda()) 98 | net.load_state_dict(torch.load(args.model)) 99 | return net 100 | 101 | 102 | @lru_cache(maxsize=200) 103 | def read_frame_data(key, frame_id): 104 | data = np.load(join(data_list_root, key, 'frame_%05d.npz' % frame_id)) 105 | data_dict = {} 106 | for k in data.keys(): 107 | data_dict[k] = data[k] 108 | return data_dict 109 | 110 | 111 | net = load_RAFT() 112 | 113 | 114 | def generate_pair_data(key, frame_id_1, frame_id_2, save=True): 115 | im1_data = read_frame_data(key, frame_id_1) 116 | im2_data = read_frame_data(key, frame_id_2) 117 | 118 | im1 = im1_data['img_orig'] * 255 119 | im2 = im2_data['img_orig'] * 255 120 | im1 = imresize(im1, [288, 512], anti_aliasing=True) 121 | im2 = imresize(im2, [288, 512], anti_aliasing=True) 122 | 123 | images = [im1, im2] 124 | images = np.array(images).transpose(0, 3, 1, 2) 125 | im = torch.from_numpy(images.astype(np.float32)).cuda() 126 | with torch.no_grad(): 127 | flow_low, flow_up = net(image1=im[0:1, ...], image2=im[1:2, ...], iters=20, test_mode=True) 128 | flow_1_2 = flow_up.squeeze().permute(1, 2, 0).cpu().numpy() 129 | 130 | H, W, _ = im1_data['img'].shape 131 | flow_1_2 = resize_flow(flow_1_2, [W, H]) 132 | 133 | with torch.no_grad(): 134 | flow_low, flow_up = net(image1=im[1:2, ...], image2=im[0:1, ...], iters=20, test_mode=True) 135 | flow_2_1 = flow_up.squeeze().permute(1, 2, 0).cpu().numpy() 136 | 137 | flow_2_1 = resize_flow(flow_2_1, [W, H]) 138 | 139 | warp_flow_1_2 = backward_flow_warp(flow_1_2, flow_2_1) # using latter to sample former 140 | err_1 = np.linalg.norm(warp_flow_1_2 + flow_2_1, axis=-1) 141 | mask_1 = np.where(err_1 > 1, 1, 0) 142 | oob_mask_1 = get_oob_mask(flow_2_1) 143 | mask_1 = np.clip(mask_1 + oob_mask_1, a_min=0, a_max=1) 144 | warp_flow_2_1 = backward_flow_warp(flow_2_1, flow_1_2) 145 | err_2 = np.linalg.norm(warp_flow_2_1 + flow_1_2, axis=-1) 146 | mask_2 = np.where(err_2 > 1, 1, 0) 147 | oob_mask_2 = get_oob_mask(flow_1_2) 148 | mask_2 = np.clip(mask_2 + oob_mask_2, a_min=0, a_max=1) 149 | save_dict = {} 150 | save_dict['flow_1_2'] = flow_1_2.astype(np.float32) 151 | save_dict['flow_2_1'] = flow_2_1.astype(np.float32) 152 | save_dict['mask_1'] = mask_1.astype(np.uint8) 153 | save_dict['mask_2'] = mask_2.astype(np.uint8) 154 | save_dict['frame_id_1'] = frame_id_1 155 | save_dict['frame_id_2'] = frame_id_2 156 | if save: 157 | np.savez(join(outpath, key, f'flowpair_{frame_id_1:05d}_{frame_id_2:05d}.npz'), **save_dict) 158 | return 1 159 | else: 160 | return save_dict 161 | 162 | # %% 163 | 164 | 165 | track_names = sorted(glob(join(data_list_root, '*'))) 166 | track_names = ['dog', 'train'] 167 | track_ids = np.arange(len(track_names)) 168 | 169 | # %% 170 | for 
track_id in tqdm(track_ids): 171 | key = track_names[track_id] 172 | print(key) 173 | l = len(sorted(glob(join(data_list_root, key, 'frame_*.npz')))) 174 | os.makedirs(join(outpath, track_names[track_id]), exist_ok=True) 175 | gaps = [1, 2, 3, 4, 5, 6, 7, 8] 176 | for g in gaps: 177 | for k in tqdm(range(l - g)): 178 | generate_pair_data(key, k, k + g) 179 | -------------------------------------------------------------------------------- /scripts/preprocess/davis/generate_frame_midas.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import torch 17 | from os.path import join 18 | from os import makedirs 19 | from glob import glob 20 | from skimage.transform import resize as imresize 21 | from tqdm import tqdm 22 | from PIL import Image 23 | from scipy.ndimage import map_coordinates 24 | import trimesh 25 | import sys 26 | sys.path.insert(0, '') 27 | from configs import midas_pretrain_path 28 | from third_party.MiDaS import MidasNet 29 | 30 | model = MidasNet(midas_pretrain_path, non_negative=True, resize=[256, 512], normalize_input=True) 31 | 32 | model = model.eval().cuda() 33 | 34 | 35 | data_list_root = "./datafiles/DAVIS/JPEGImages/1080p" 36 | camera_path = "./datafiles/DAVIS/triangulation" 37 | mask_path = './datafiles/DAVIS/Annotations/1080p' 38 | outpath = './datafiles/davis_processed/frames_midas' 39 | 40 | track_names = ['train', 'dog'] 41 | track_ids = [0, 1] 42 | 43 | 44 | for track_id in track_ids: 45 | print(track_names[track_id]) 46 | frames = sorted(glob(join(data_list_root, f'{track_names[track_id]}', '*.jpg'))) 47 | mask_paths = sorted(glob(join(mask_path, f'{track_names[track_id]}', '*.png'))) 48 | makedirs(join(outpath, f'{track_names[track_id]}'), exist_ok=True) 49 | intrinsics_path = join(camera_path, f'{track_names[track_id]}.intrinsics.txt') 50 | extrinsics_path = join(camera_path, f'{track_names[track_id]}.matrices.txt') 51 | obj_path = join(camera_path, f'{track_names[track_id]}.obj') 52 | fx, fy, cx, cy = np.loadtxt(intrinsics_path)[0][1:] 53 | extrinsics = np.loadtxt(extrinsics_path) 54 | extrinsics = extrinsics[:, 1:] 55 | extrinsics = np.asarray([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])[None, ...]@np.linalg.inv(np.reshape(extrinsics, [-1, 4, 4])) 56 | mesh = trimesh.load(obj_path) 57 | points_3d = mesh.vertices 58 | h_pt = np.ones([points_3d.shape[0], 4]) 59 | h_pt[:, :3] = points_3d 60 | h_pt = h_pt.T 61 | intrinsics = np.zeros([3, 3]) 62 | intrinsics[[0, 0, 1, 1, 2], [0, 2, 1, 2, 2]] = [fx, cx, fy, cy, 1] 63 | 64 | print('calculating NN_depth') 65 | full_pred_depths = [] 66 | pts_list = [] 67 | 68 | mvs_depths = [] 69 | pred_depths = [] 70 | masks = [] 71 | for x in tqdm(range(len(frames))): 72 | img = np.asarray(Image.open(frames[x])).astype(np.float32) / 255 73 | 74 | img_batch = torch.from_numpy(img).permute(2, 0, 1)[None, ...].float().cuda() 75 | 76 | with 
torch.no_grad(): 77 | pred_d = model(img_batch) 78 | pred_d = pred_d.squeeze().cpu().numpy() 79 | full_pred_depths.append(pred_d) 80 | 81 | out = extrinsics[x, :]@h_pt 82 | im_pt = intrinsics @ out[:3, :] 83 | depth = im_pt[2, :].copy() 84 | im_pt = im_pt / im_pt[2:, :] 85 | 86 | mask = np.asarray(Image.open(mask_paths[x]).convert('RGB')).astype(np.float32)[:, :, 0] / 255 87 | masks.append(mask) 88 | H, W, _ = img.shape 89 | select_idx = np.where((im_pt[0, :] >= 0) * (im_pt[0, :] < W) * (im_pt[1, :] >= 0) * (im_pt[1, :] < H))[0] 90 | pts = im_pt[:, select_idx] 91 | depth = depth[select_idx] 92 | out = map_coordinates(mask, [pts[1, :], pts[0, :]]) 93 | select_idx = np.where(out < 0.1)[0] 94 | pts = pts[:, select_idx] 95 | depth = depth[select_idx] 96 | select_idx = np.where(depth > 1e-3)[0] 97 | pts = pts[:, select_idx] 98 | depth = depth[select_idx] 99 | 100 | pred_depth = map_coordinates(pred_d, [pts[1, :], pts[0, :]]) 101 | mvs_depths.append(depth) 102 | pred_depths.append(pred_depth) 103 | pts_list.append(pts) 104 | print(img.shape) 105 | 106 | print('calculating scale') 107 | scales = [] 108 | for x in tqdm(range(len(frames))): 109 | nn_depth = pred_depths[x] 110 | mvs_depth = mvs_depths[x] 111 | scales.append(np.median(nn_depth / mvs_depth)) 112 | s = np.mean(scales) 113 | 114 | print('saving per frame output') 115 | 116 | for idf, frame_path in tqdm(enumerate(frames)): 117 | img_orig = np.asarray(Image.open(frames[idf])).astype(np.float32) / 255 118 | max_W = 384 119 | multiple = 64 120 | H, W, _ = img_orig.shape 121 | if W > max_W: 122 | sc = max_W / W 123 | target_W = max_W 124 | else: 125 | target_W = W 126 | target_H = int(np.round((H * sc) / multiple) * multiple) 127 | 128 | img = imresize(img_orig, ([target_H, target_W]), preserve_range=True).astype(np.float32) 129 | 130 | T_G_1 = extrinsics[idf, ...] # world2cam 131 | T_G_1[:3, 3] *= s 132 | T_G_1 = np.linalg.inv(T_G_1) # cam2world 133 | T_G_1 = T_G_1.astype(np.float32) 134 | depth_mvs = imresize(full_pred_depths[idf].astype(np.float32), ([target_H, target_W]), preserve_range=True).astype(np.float32) 135 | in_1 = intrinsics.copy() 136 | in_1[0, 0] /= W / target_W 137 | in_1[1, 1] /= H / target_H 138 | in_1[0, 2] = (target_W - 1) / 2 139 | in_1[1, 2] = (target_H - 1) / 2 140 | in_1 = in_1.astype(np.float32) 141 | depth = full_pred_depths[idf].astype(np.float32) 142 | depth = imresize(depth, ([target_H, target_W]), preserve_range=True).astype(np.float32) 143 | resized_mask = imresize(masks[idf], [target_H, target_W], preserve_range=True) 144 | resized_mask = np.where(resized_mask > 1e-3, 1, 0) 145 | 146 | np.savez(join(outpath, track_names[track_id], 'frame_%05d.npz' % idf), img=img, pose_c2w=T_G_1, 147 | depth_mvs=depth_mvs, intrinsics=in_1, depth_pred=depth, img_orig=img_orig, motion_seg=resized_mask) 148 | -------------------------------------------------------------------------------- /scripts/preprocess/davis/generate_sequence_midas.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from os.path import join 17 | from os import makedirs 18 | import numpy as np 19 | from functools import lru_cache 20 | from glob import glob 21 | from tqdm import tqdm 22 | 23 | 24 | data_list_root = "./datafiles/davis_processed/frames_midas/" 25 | 26 | flow_path = './datafiles/davis_processed/flow_pairs/' 27 | 28 | save_path_root = './datafiles/davis_processed/sequences_select_pairs_midas/' 29 | 30 | 31 | @lru_cache(maxsize=1024) 32 | def read_frame_data(key, frame_id): 33 | data = np.load(join(data_list_root, key, 'frame_%05d.npz' % frame_id)) 34 | data_dict = {} 35 | for k in data.keys(): 36 | data_dict[k] = data[k] 37 | return data_dict 38 | 39 | 40 | @lru_cache(maxsize=1024) 41 | def read_flow_data(key, frame_id_1, frame_id_2): 42 | data = np.load(join(flow_path, key, f'flowpair_{frame_id_1:05d}_{frame_id_2:05d}.npz'), allow_pickle=True) 43 | data_dict = {} 44 | for k in data.keys(): 45 | data_dict[k] = data[k] 46 | return data_dict 47 | 48 | 49 | def prepare_pose_dict_one_way(im1_data, im2_data): 50 | # return R_1 R_2 t_1 t_2 51 | cam_pose_c2w_1 = im1_data['pose_c2w'] 52 | R_1 = cam_pose_c2w_1[:3, :3] 53 | t_1 = cam_pose_c2w_1[:3, 3] 54 | 55 | cam_pose_c2w_2 = im2_data['pose_c2w'] 56 | R_2 = cam_pose_c2w_2[:3, :3] 57 | t_2 = cam_pose_c2w_2[:3, 3] 58 | K = im1_data['intrinsics'] 59 | 60 | # for network use: 61 | R_1_tensor = torch.zeros([1, 1, 1, 3, 3]) 62 | R_1_T_tensor = torch.zeros([1, 1, 1, 3, 3]) 63 | R_2_tensor = torch.zeros([1, 1, 1, 3, 3]) 64 | R_2_T_tensor = torch.zeros([1, 1, 1, 3, 3]) 65 | t_1_tensor = torch.zeros([1, 1, 1, 1, 3]) 66 | t_2_tensor = torch.zeros([1, 1, 1, 1, 3]) 67 | K_tensor = torch.zeros([1, 1, 1, 3, 3]) 68 | K_inv_tensor = torch.zeros([1, 1, 1, 3, 3]) 69 | R_1_tensor[0, ..., :, :] = torch.from_numpy(R_1.T) 70 | R_2_tensor[0, ..., :, :] = torch.from_numpy(R_2.T) 71 | R_1_T_tensor[0, ..., :, :] = torch.from_numpy(R_1) 72 | R_2_T_tensor[0, ..., :, :] = torch.from_numpy(R_2) 73 | t_1_tensor[0, ..., :] = torch.from_numpy(t_1) 74 | t_2_tensor[0, ..., :] = torch.from_numpy(t_2) 75 | K_tensor[..., :, :] = torch.from_numpy(K.T) 76 | K_inv_tensor[..., :, :] = torch.from_numpy(np.linalg.inv(K).T) 77 | 78 | pose_dict = {} 79 | pose_dict['R_1'] = R_1_tensor 80 | pose_dict['R_2'] = R_2_tensor 81 | pose_dict['R_1_T'] = R_1_T_tensor 82 | pose_dict['R_2_T'] = R_2_T_tensor 83 | pose_dict['t_1'] = t_1_tensor 84 | pose_dict['t_2'] = t_2_tensor 85 | pose_dict['K'] = K_tensor 86 | pose_dict['K_inv'] = K_inv_tensor 87 | return pose_dict 88 | 89 | 90 | def collate_sequence_fix_gap(key, seq_list, gap=1): 91 | sequential_pairs_start = seq_list 92 | sequential_pairs_end = seq_list + gap 93 | list_of_pairs_select = [(x, y) for x, y in zip(sequential_pairs_start, sequential_pairs_end)] 94 | sequential_data = collate_pairs(key, list_of_pairs_select) 95 | flow_1_2_batch = sequential_data['flow_1_2'] 96 | flow_1_2_batch = flow_1_2_batch.permute([0, 3, 1, 2]) 97 | 98 | return sequential_data 99 | 100 | 101 | def collate_pairs(key, list_of_pairs): 102 | 103 | dict_of_list = {} 104 | for idp, pair in enumerate(list_of_pairs): 105 | 106 | dd = datadict_from_pair(key, pair) 107 | for k, v in dd.items(): 108 | if k not in dict_of_list.keys(): 109 | dict_of_list[k] = [] 110 | dict_of_list[k].append(v) 111 | 112 | for k in dict_of_list.keys(): 113 | dict_of_list[k] = torch.cat(dict_of_list[k], dim=0) 114 | return dict_of_list 115 | 116 | 117 | def 
datadict_from_pair(key, pair): 118 | frame_id_1, frame_id_2 = pair 119 | im1_data = read_frame_data(key, pair[0]) 120 | im2_data = read_frame_data(key, pair[1]) 121 | fid_1, fid_2 = sorted([frame_id_1, frame_id_2]) 122 | flow_data_dict = read_flow_data(key, fid_1, fid_2) 123 | if fid_1 == frame_id_1: 124 | flow_1_2 = flow_data_dict['flow_1_2'] 125 | flow_2_1 = flow_data_dict['flow_2_1'] 126 | mask_1 = flow_data_dict['mask_1'] 127 | mask_2 = flow_data_dict['mask_2'] 128 | else: 129 | flow_1_2 = flow_data_dict['flow_2_1'] 130 | flow_2_1 = flow_data_dict['flow_1_2'] 131 | mask_1 = flow_data_dict['mask_1'] 132 | mask_2 = flow_data_dict['mask_2'] 133 | pose_dict = prepare_pose_dict_one_way(im1_data, im2_data) 134 | gt_depth_1 = torch.from_numpy(im1_data['depth_mvs']).float() 135 | pred_depth_1 = torch.from_numpy(im1_data['depth_pred']).float() 136 | H, W = gt_depth_1.shape 137 | depth_1_tensor = torch.zeros([1, 1, H, W]) 138 | depth_1_tensor[0, 0, ...] = gt_depth_1 139 | depth_1_tensor_p = torch.zeros([1, 1, H, W]) 140 | depth_1_tensor_p[0, 0, ...] = pred_depth_1 141 | 142 | flow_1_2 = torch.from_numpy(flow_1_2).float()[None, ...] 143 | flow_2_1 = torch.from_numpy(flow_2_1).float()[None, ...] 144 | mask_1 = torch.from_numpy(mask_1).float() 145 | mask_2 = torch.from_numpy(mask_2).float() 146 | mask_1 = 1 - torch.ceil(mask_1)[None, ..., None, None] 147 | mask_2 = 1 - torch.ceil(mask_2)[None, ..., None, None] 148 | img_1 = torch.from_numpy(im1_data['img']).float()[None, ...] 149 | img_2 = torch.from_numpy(im2_data['img']).float()[None, ...] 150 | fid_1 = pair[0] 151 | fid_2 = pair[1] 152 | if 'motion_seg' in im1_data.keys(): 153 | motion_seg = torch.from_numpy(im1_data['motion_seg'])[None, ..., None, None].float() 154 | else: 155 | motion_seg = mask_2 156 | samples = {} 157 | for k in pose_dict: 158 | samples[k] = pose_dict[k] 159 | samples['img_1'] = img_1 160 | samples['img_2'] = img_2 161 | samples['depth_1'] = depth_1_tensor 162 | samples['flow_1_2'] = flow_1_2 163 | samples['flow_2_1'] = flow_2_1 164 | samples['mask_1'] = mask_1 165 | samples['mask_2'] = mask_2 166 | samples['motion_seg_1'] = motion_seg 167 | samples['depth_pred_1'] = depth_1_tensor_p 168 | samples['fid_1'] = torch.FloatTensor([fid_1]) 169 | samples['fid_2'] = torch.FloatTensor([fid_2]) 170 | return samples 171 | 172 | 173 | if __name__ == '__main__': 174 | track_names = sorted(glob(join(data_list_root, '*'))) 175 | track_names = ['train', 'dog'] 176 | for key in track_names: 177 | all_frames = sorted(glob(join(data_list_root, key, '*.npz'))) 178 | gaps = [1, 2, 3, 4, 5, 6, 7, 8] 179 | bs = 1 180 | save_path = join(save_path_root, key, '001') 181 | print(key) 182 | makedirs(save_path, exist_ok=True) 183 | 184 | print('saving...') 185 | 186 | for gap in tqdm(gaps): 187 | fids = np.arange(len(all_frames) - bs - gap) 188 | cnt = 0 189 | for f in fids: 190 | seq_list_forward = np.arange(f, f + bs) 191 | sequence = collate_sequence_fix_gap(key, seq_list_forward, gap=gap,) 192 | torch.save(sequence, join(save_path, f'shuffle_False_gap_{gap:02d}_sequence_{cnt:05d}.pt')) 193 | cnt += 1 194 | -------------------------------------------------------------------------------- /scripts/preprocess/shutterstock/generate_flows.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import sys 17 | import os 18 | from os.path import join, basename 19 | sys.path.append('./third_party/RAFT') 20 | sys.path.append('./third_party/RAFT/core') 21 | from raft import RAFT 22 | import numpy as np 23 | import torch.nn.functional as F 24 | from functools import lru_cache 25 | from glob import glob 26 | import argparse 27 | from tqdm import tqdm 28 | import subprocess 29 | 30 | try: 31 | import cv2 32 | except ImportError: 33 | subprocess.check_call([sys.executable, "-m", "pip", "install", 'opencv-python']) 34 | finally: 35 | import cv2 36 | 37 | from skimage.transform import resize as imresize 38 | 39 | 40 | data_list_root = "./datafiles/shutterstock/frames_midas" 41 | outpath = './datafiles/shutterstock/flow_pairs' 42 | 43 | 44 | def resize_flow(flow, size): 45 | resized_width, resized_height = size 46 | H, W = flow.shape[:2] 47 | scale = np.array((resized_width / float(W), resized_height / float(H))).reshape( 48 | 1, 1, -1 49 | ) 50 | resized = cv2.resize( 51 | flow, dsize=(resized_width, resized_height), interpolation=cv2.INTER_CUBIC 52 | ) 53 | resized *= scale 54 | return resized 55 | 56 | 57 | def get_oob_mask(flow_1_2): 58 | H, W, _ = flow_1_2.shape 59 | hh, ww = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float()) 60 | coord = torch.zeros([H, W, 2]) 61 | coord[..., 0] = ww 62 | coord[..., 1] = hh 63 | target_range = coord + flow_1_2 64 | m1 = (target_range[..., 0] < 0) + (target_range[..., 0] > W - 1) 65 | m2 = (target_range[..., 1] < 0) + (target_range[..., 1] > H - 1) 66 | return (m1 + m2).float().numpy() 67 | 68 | 69 | def backward_flow_warp(im2, flow_1_2): 70 | H, W, _ = im2.shape 71 | hh, ww = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float()) 72 | coord = torch.zeros([1, H, W, 2]) 73 | coord[0, ..., 0] = ww 74 | coord[0, ..., 1] = hh 75 | sample_grids = coord + flow_1_2[None, ...] 76 | sample_grids[..., 0] /= (W - 1) / 2 77 | sample_grids[..., 1] /= (H - 1) / 2 78 | sample_grids -= 1 79 | im = torch.from_numpy(im2).float().permute(2, 0, 1)[None, ...] 
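    # The absolute pixel locations (base grid + flow) are normalized to the
    # [-1, 1] range expected by F.grid_sample: dividing x by (W - 1) / 2 and
    # y by (H - 1) / 2 and then subtracting 1 maps pixel 0 to -1 and pixel
    # W - 1 (or H - 1) to +1, which matches align_corners=True in the call below.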
80 | out = F.grid_sample(im, sample_grids, align_corners=True) 81 | o = out[0, ...].permute(1, 2, 0).numpy() 82 | return o 83 | 84 | 85 | def get_L2_error_map(v1, v2): 86 | return np.linalg.norm(v1 - v2, axis=-1) 87 | 88 | 89 | def load_RAFT(): 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument('--model', help="restore checkpoint") 92 | parser.add_argument('--path', help="dataset for evaluation") 93 | parser.add_argument('--small', action='store_true', help='use small model') 94 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 95 | parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') 96 | args = parser.parse_args(['--model', './third_party/RAFT/models/raft-sintel.pth', '--path', './']) 97 | net = torch.nn.DataParallel(RAFT(args).cuda()) 98 | net.load_state_dict(torch.load(args.model)) 99 | return net 100 | 101 | 102 | @lru_cache(maxsize=200) 103 | def read_frame_data(key, frame_id): 104 | data = np.load(join(data_list_root, key, 'frame_%05d.npz' % frame_id), allow_pickle=True) 105 | data_dict = {} 106 | for k in data.keys(): 107 | data_dict[k] = data[k] 108 | return data_dict 109 | 110 | net = load_RAFT() 111 | 112 | def generate_pair_data(key, frame_id_1, frame_id_2_list): 113 | im1_data = read_frame_data(key, frame_id_1) 114 | im2_data_list = [read_frame_data(key, f) for f in frame_id_2_list] 115 | 116 | im1 = im1_data['img_orig'] * 255 117 | im2_list = [i['img_orig'] * 255 for i in im2_data_list] 118 | im1 = imresize(im1, (288, 512), anti_aliasing=True) 119 | im2_list = [imresize(i, (288, 512), anti_aliasing=True) for i in im2_list] 120 | 121 | im1_list = [im1] * len(im2_list) 122 | 123 | def image_list_to_cuda_batch(im_list): 124 | reorder = np.array(im_list).transpose(0, 3, 1, 2) 125 | to_cuda = torch.from_numpy(reorder.astype(np.float32)).cuda() 126 | return to_cuda 127 | 128 | def cuda_batch_to_numpy_batch(cuda_batch): 129 | return cuda_batch.permute(0, 2, 3, 1).cpu().numpy() 130 | 131 | im1_batch = image_list_to_cuda_batch(im1_list) 132 | im2_batch = image_list_to_cuda_batch(im2_list) 133 | 134 | with torch.no_grad(): 135 | flow_low, flow_up = net(image1=im1_batch, image2=im2_batch, iters=20, test_mode=True) 136 | flow_1_2_batch = cuda_batch_to_numpy_batch(flow_up) 137 | 138 | flow_low, flow_up = net(image1=im2_batch, image2=im1_batch, iters=20, test_mode=True) 139 | flow_2_1_batch = cuda_batch_to_numpy_batch(flow_up) 140 | 141 | H, W, _ = im1_data['img'].shape 142 | for j, frame_id_2 in enumerate(frame_id_2_list): 143 | flow_1_2 = resize_flow(flow_1_2_batch[j,...], [W, H]) 144 | flow_2_1 = resize_flow(flow_2_1_batch[j,...], [W, H]) 145 | 146 | warp_flow_1_2 = backward_flow_warp(flow_1_2, flow_2_1) # using latter to sample former 147 | err_1 = np.linalg.norm(warp_flow_1_2 + flow_2_1, axis=-1) 148 | mask_1 = np.where(err_1 > 1, 1, 0) 149 | oob_mask_1 = get_oob_mask(flow_2_1) 150 | mask_1 = np.clip(mask_1 + oob_mask_1, a_min=0, a_max=1) 151 | warp_flow_2_1 = backward_flow_warp(flow_2_1, flow_1_2) 152 | err_2 = np.linalg.norm(warp_flow_2_1 + flow_1_2, axis=-1) 153 | mask_2 = np.where(err_2 > 1, 1, 0) 154 | oob_mask_2 = get_oob_mask(flow_1_2) 155 | mask_2 = np.clip(mask_2 + oob_mask_2, a_min=0, a_max=1) 156 | save_dict = {} 157 | save_dict['flow_1_2'] = flow_1_2.astype(np.float32) 158 | save_dict['flow_2_1'] = flow_2_1.astype(np.float32) 159 | save_dict['mask_1'] = mask_1.astype(np.uint8) 160 | save_dict['mask_2'] = mask_2.astype(np.uint8) 161 | save_dict['frame_id_1'] = 
frame_id_1 162 | save_dict['frame_id_2'] = frame_id_2 163 | np.savez(join(outpath, key, f'flowpair_{frame_id_1:05d}_{frame_id_2:05d}.npz'), **save_dict) 164 | 165 | # %% 166 | 167 | 168 | track_names = sorted(glob(join(data_list_root, '*'))) 169 | track_names = [basename(x) for x in track_names] 170 | track_ids = np.arange(len(track_names)) 171 | 172 | # %% 173 | for track_id in tqdm(track_ids): 174 | key = track_names[track_id] 175 | print(key) 176 | l = len(sorted(glob(join(data_list_root, key, 'frame_*.npz')))) 177 | os.makedirs(join(outpath, track_names[track_id]), exist_ok=True) 178 | MAX_GAP = 8 179 | for k in tqdm(range(l-1)): 180 | gaps = range(min(l-k-1, MAX_GAP)) 181 | end_frames = [k + g + 1 for g in gaps] 182 | generate_pair_data(key, k, end_frames) 183 | -------------------------------------------------------------------------------- /scripts/preprocess/shutterstock/generate_frame_midas.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import sys 17 | import numpy as np 18 | import h5py 19 | import torch 20 | from os.path import join, basename, dirname 21 | from os import makedirs 22 | from glob import glob 23 | from skimage.transform import resize as imresize 24 | from tqdm import tqdm 25 | sys.path.insert(0, '') 26 | from third_party.MiDaS import MidasNet 27 | from configs import midas_pretrain_path 28 | from PIL import Image 29 | 30 | model = MidasNet(midas_pretrain_path, non_negative=True, resize=None, normalize_input=True) 31 | model = model.eval().cuda() 32 | 33 | 34 | data_list_root = "./datafiles/shutterstock/triangulation" 35 | image_list_root = './datafiles/shutterstock/images' 36 | outpath = './datafiles/shutterstock/frames_midas' 37 | TRIM_BAD_FRAMES = True 38 | 39 | track_paths = sorted(glob(join(data_list_root, '*'))) 40 | 41 | track_names = [basename(x) for x in track_paths] 42 | print('Track names: ', track_names) 43 | track_ids = np.arange(len(track_names)) 44 | # %% filter out valid sequences 45 | track_lut = {} 46 | track_ts = {} 47 | tracks_grad = {} 48 | for tr in track_ids: 49 | file_list = [] 50 | all_files = sorted(glob(join(track_paths[tr], '*.h5'))) 51 | ts = [] 52 | for f in all_files: 53 | ts_str = f.split('/')[-1].split('_')[-1].split('.')[0] 54 | ts.append(int(ts_str)) 55 | 56 | idx = np.argsort(ts) 57 | sorted_path = [all_files[x] for x in idx] 58 | track_lut[tr] = sorted_path 59 | track_ts[tr] = sorted(ts) 60 | 61 | sorted_ts = np.array(sorted(ts)) 62 | grad = sorted_ts[1:] - sorted_ts[:-1] 63 | tracks_grad[tr] = grad 64 | 65 | for tr in track_ids: 66 | valid_tracks = [] 67 | th = 40000 68 | g = tracks_grad[tr] 69 | idx = np.where(g > th)[0] 70 | print('Valid indices: ', idx) 71 | 72 | if TRIM_BAD_FRAMES: 73 | valid_track_lut = {} 74 | for tr in track_ids: 75 | valid_tracks = [] 76 | if tr == 0: 77 | valid_tracks = track_lut[tr][14:] 78 | elif tr == 3: 79 | valid_tracks = 
track_lut[tr][:134] 80 | else: 81 | valid_tracks = track_lut[tr] 82 | valid_track_lut[tr] = valid_tracks 83 | else: 84 | valid_track_lut = track_lut 85 | 86 | 87 | def get_im_size(im, dim_max=384, multiple=32): 88 | H, W, _ = im.shape 89 | if W > H: 90 | if W > dim_max: 91 | sc = dim_max / W 92 | target_W = dim_max 93 | else: 94 | target_W = np.floor(W / multiple) * multiple 95 | sc = target_W / W 96 | target_H = int(np.round((H * sc) / multiple) * multiple) 97 | return [target_H, target_W] 98 | else: 99 | if H > dim_max: 100 | sc = dim_max / H 101 | target_H = dim_max 102 | else: 103 | target_H = np.floor(H / multiple) * multiple 104 | sc = target_H / H 105 | target_W = int(np.round((W * sc) / multiple) * multiple) 106 | return [target_H, target_W] 107 | 108 | 109 | for track_id in track_ids: 110 | frames = valid_track_lut[track_id] 111 | 112 | hdf5_file_handles = [] 113 | 114 | for idf, f in enumerate(frames): 115 | hdf5_file_handles.append(h5py.File(f, 'r')) 116 | test_in = np.array(hdf5_file_handles[0]['prediction/K']) 117 | if len(test_in) < 3: 118 | for f in hdf5_file_handles: 119 | test_f = np.array(hdf5_file_handles[0]['prediction/K']) 120 | if np.any(np.isnan(test_f)): 121 | continue 122 | else: 123 | print('found!!') 124 | print('corrupted!') 125 | continue 126 | makedirs(join(outpath, f'{track_names[track_id]}'), exist_ok=True) 127 | 128 | print(track_names[track_id]) 129 | 130 | print('calculating NN_depth') 131 | depths = [] 132 | conf = [] 133 | mvs_depths = [] 134 | for x in tqdm(range(len(hdf5_file_handles))): 135 | img = hdf5_file_handles[x]['prediction/img'] 136 | if not img: 137 | img = Image.open(hdf5_file_handles[x]['prediction'].attrs['image_path']) 138 | stored_shape = hdf5_file_handles[x]['prediction'].attrs['image_shape'] 139 | if not np.all(img.shape == stored_shape): 140 | img = imresize(np.asarray(img), stored_shape[:2], preserve_range=True) 141 | img = np.asarray(img).astype(float) / 255 142 | else: 143 | img = np.asarray(img) # Already a float32 array in range 0,1 144 | 145 | img_batch = torch.from_numpy(img).permute(2, 0, 1)[None, ...].float().cuda() 146 | with torch.no_grad(): 147 | pred_d = model(img_batch) 148 | depths.append(pred_d.squeeze().cpu().numpy()) 149 | mvs_depth = np.array(hdf5_file_handles[x]['prediction/mvs_depth']) 150 | mvs_depths.append(mvs_depth) 151 | print(img.shape) 152 | 153 | print('calculating scale') 154 | scales = [] 155 | for x in tqdm(range(len(hdf5_file_handles))): 156 | nn_depth = depths[x] 157 | mvs_depth = mvs_depths[x] 158 | idx, idy = np.where(mvs_depth > 1e-3) 159 | scales.append(np.median(nn_depth[idx, idy] / mvs_depth[idx, idy])) 160 | s = np.mean(scales) 161 | print(s) 162 | 163 | print('saving per frame output') 164 | 165 | for idf, h5file in tqdm(enumerate(hdf5_file_handles)): 166 | img_orig = h5file['prediction/img'] 167 | if not img_orig: 168 | img_orig = Image.open(h5file['prediction'].attrs['image_path']) 169 | img_orig = np.asarray(img_orig).astype(float) / 255 170 | else: 171 | img_orig = np.asarray(img_orig) # Already a float32 array in range 0,1 172 | 173 | max_dim = 384 174 | multiple = 32 175 | H, W, _ = img_orig.shape 176 | target_H, target_W = get_im_size(img_orig) 177 | 178 | img = imresize(img_orig, ((target_H, target_W)), preserve_range=True).astype(np.float32) 179 | 180 | T_G_1 = np.array(h5file['prediction/T_1_G']) 181 | T_G_1[:3, 3] *= s 182 | T_G_1 = np.linalg.inv(T_G_1) 183 | T_G_1 = T_G_1.astype(np.float32) 184 | depth_mvs = mvs_depths[idf] * s 185 | depth_mvs = depth_mvs.astype(np.float32) 186 | 
depth_mvs = imresize(depth_mvs, ([target_H, target_W]), preserve_range=True).astype(np.float32) 187 | in_1 = np.array(h5file['prediction/K']) 188 | in_1[0, 0] /= W / target_W 189 | in_1[1, 1] /= H / target_H 190 | in_1[0, 2] = (target_W - 1) / 2 191 | in_1[1, 2] = (target_H - 1) / 2 192 | in_1 = in_1.astype(np.float32) 193 | depth = depths[idf].astype(np.float32) 194 | depth = imresize(depth, ([target_H, target_W]), preserve_range=True).astype(np.float32) 195 | np.savez(join(outpath, track_names[track_id], 'frame_%05d.npz' % idf), img=img, pose_c2w=T_G_1, 196 | depth_mvs=depth_mvs, intrinsics=in_1, depth_pred=depth, img_orig=img_orig, motion_seg=None) 197 | 198 | 199 | # %% 200 | -------------------------------------------------------------------------------- /scripts/preprocess/shutterstock/generate_sequence_midas.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | from os.path import join, basename 18 | from os import makedirs 19 | import numpy as np 20 | from functools import lru_cache 21 | from glob import glob 22 | from tqdm import tqdm 23 | 24 | 25 | data_list_root = "./datafiles/shutterstock/frames_midas/" 26 | 27 | flow_path = './datafiles/shutterstock/flow_pairs/' 28 | 29 | save_path_root = './datafiles/shutterstock/sequences_select_pairs_midas/' 30 | 31 | 32 | @lru_cache(maxsize=1024) 33 | def read_frame_data(key, frame_id): 34 | data = np.load(join(data_list_root, key, 'frame_%05d.npz' % frame_id), allow_pickle=True) 35 | data_dict = {} 36 | for k in data.keys(): 37 | data_dict[k] = data[k] 38 | return data_dict 39 | 40 | 41 | @lru_cache(maxsize=1024) 42 | def read_flow_data(key, frame_id_1, frame_id_2): 43 | data = np.load(join(flow_path, key, f'flowpair_{frame_id_1:05d}_{frame_id_2:05d}.npz')) 44 | data_dict = {} 45 | for k in data.keys(): 46 | data_dict[k] = data[k] 47 | return data_dict 48 | 49 | 50 | def prepare_pose_dict_one_way(im1_data, im2_data): 51 | # return R_1 R_2 t_1 t_2 52 | cam_pose_c2w_1 = im1_data['pose_c2w'] 53 | R_1 = cam_pose_c2w_1[:3, :3] 54 | t_1 = cam_pose_c2w_1[:3, 3] 55 | 56 | cam_pose_c2w_2 = im2_data['pose_c2w'] 57 | R_2 = cam_pose_c2w_2[:3, :3] 58 | t_2 = cam_pose_c2w_2[:3, 3] 59 | K = im1_data['intrinsics'] 60 | 61 | # for network use: 62 | R_1_tensor = torch.zeros([1, 1, 1, 3, 3]) 63 | R_1_T_tensor = torch.zeros([1, 1, 1, 3, 3]) 64 | R_2_tensor = torch.zeros([1, 1, 1, 3, 3]) 65 | R_2_T_tensor = torch.zeros([1, 1, 1, 3, 3]) 66 | t_1_tensor = torch.zeros([1, 1, 1, 1, 3]) 67 | t_2_tensor = torch.zeros([1, 1, 1, 1, 3]) 68 | K_tensor = torch.zeros([1, 1, 1, 3, 3]) 69 | K_inv_tensor = torch.zeros([1, 1, 1, 3, 3]) 70 | R_1_tensor[0, ..., :, :] = torch.from_numpy(R_1.T) 71 | R_2_tensor[0, ..., :, :] = torch.from_numpy(R_2.T) 72 | R_1_T_tensor[0, ..., :, :] = torch.from_numpy(R_1) 73 | R_2_T_tensor[0, ..., :, :] = torch.from_numpy(R_2) 74 | t_1_tensor[0, ..., :] = 
torch.from_numpy(t_1) 75 | t_2_tensor[0, ..., :] = torch.from_numpy(t_2) 76 | K_tensor[..., :, :] = torch.from_numpy(K.T) 77 | K_inv_tensor[..., :, :] = torch.from_numpy(np.linalg.inv(K).T) 78 | 79 | pose_dict = {} 80 | pose_dict['R_1'] = R_1_tensor 81 | pose_dict['R_2'] = R_2_tensor 82 | pose_dict['R_1_T'] = R_1_T_tensor 83 | pose_dict['R_2_T'] = R_2_T_tensor 84 | pose_dict['t_1'] = t_1_tensor 85 | pose_dict['t_2'] = t_2_tensor 86 | pose_dict['K'] = K_tensor 87 | pose_dict['K_inv'] = K_inv_tensor 88 | return pose_dict 89 | 90 | 91 | def collate_sequence_fix_gap(key, seq_list, gap=1): 92 | sequential_pairs_start = seq_list 93 | sequential_pairs_end = seq_list + gap 94 | list_of_pairs_select = [(x, y) for x, y in zip(sequential_pairs_start, sequential_pairs_end)] 95 | sequential_data = collate_pairs(key, list_of_pairs_select) 96 | flow_1_2_batch = sequential_data['flow_1_2'] 97 | flow_1_2_batch = flow_1_2_batch.permute([0, 3, 1, 2]) 98 | 99 | return sequential_data 100 | 101 | 102 | def collate_pairs(key, list_of_pairs): 103 | 104 | dict_of_list = {} 105 | for idp, pair in enumerate(list_of_pairs): 106 | 107 | dd = datadict_from_pair(key, pair) 108 | for k, v in dd.items(): 109 | if k not in dict_of_list.keys(): 110 | dict_of_list[k] = [] 111 | dict_of_list[k].append(v) 112 | 113 | for k in dict_of_list.keys(): 114 | dict_of_list[k] = torch.cat(dict_of_list[k], dim=0) 115 | return dict_of_list 116 | 117 | 118 | def datadict_from_pair(key, pair): 119 | frame_id_1, frame_id_2 = pair 120 | im1_data = read_frame_data(key, pair[0]) 121 | im2_data = read_frame_data(key, pair[1]) 122 | fid_1, fid_2 = sorted([frame_id_1, frame_id_2]) 123 | flow_data_dict = read_flow_data(key, fid_1, fid_2) 124 | if fid_1 == frame_id_1: 125 | flow_1_2 = flow_data_dict['flow_1_2'] 126 | flow_2_1 = flow_data_dict['flow_2_1'] 127 | mask_1 = flow_data_dict['mask_1'] 128 | mask_2 = flow_data_dict['mask_2'] 129 | else: 130 | flow_1_2 = flow_data_dict['flow_2_1'] 131 | flow_2_1 = flow_data_dict['flow_1_2'] 132 | mask_1 = flow_data_dict['mask_1'] 133 | mask_2 = flow_data_dict['mask_2'] 134 | pose_dict = prepare_pose_dict_one_way(im1_data, im2_data) 135 | gt_depth_1 = torch.from_numpy(im1_data['depth_mvs']).float() 136 | pred_depth_1 = torch.from_numpy(im1_data['depth_pred']).float() 137 | H, W = gt_depth_1.shape 138 | depth_1_tensor = torch.zeros([1, 1, H, W]) 139 | depth_1_tensor[0, 0, ...] = gt_depth_1 140 | depth_1_tensor_p = torch.zeros([1, 1, H, W]) 141 | depth_1_tensor_p[0, 0, ...] = pred_depth_1 142 | 143 | flow_1_2 = torch.from_numpy(flow_1_2).float()[None, ...] 144 | flow_2_1 = torch.from_numpy(flow_2_1).float()[None, ...] 145 | mask_1 = torch.from_numpy(mask_1).float() 146 | mask_2 = torch.from_numpy(mask_2).float() 147 | mask_1 = 1 - torch.ceil(mask_1)[None, ..., None, None] 148 | mask_2 = 1 - torch.ceil(mask_2)[None, ..., None, None] 149 | img_1 = torch.from_numpy(im1_data['img']).float()[None, ...] 150 | img_2 = torch.from_numpy(im2_data['img']).float()[None, ...] 
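    # Note on the masks built above: generate_flows.py stores mask_1 / mask_2 as
    # uint8 maps in which 1 marks pixels whose forward-backward flow error exceeds
    # 1 px or that fall out of bounds. Taking 1 - ceil(mask) flips the convention
    # so that 1 now means "flow is consistent", and the [None, ..., None, None]
    # indexing reshapes the [H, W] map to [1, H, W, 1, 1], presumably so it can
    # broadcast against the [1, 1, 1, 3, 3] pose tensors downstream.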
151 | fid_1 = pair[0] 152 | fid_2 = pair[1] 153 | if 'motion_seg' in im1_data.keys(): 154 | if im1_data['motion_seg'].item() is not None: 155 | motion_seg = torch.from_numpy(im1_data['motion_seg'])[None, ..., None, None].float() 156 | else: 157 | motion_seg = mask_2 158 | else: 159 | motion_seg = mask_2 160 | samples = {} 161 | for k in pose_dict: 162 | samples[k] = pose_dict[k] 163 | samples['img_1'] = img_1 164 | samples['img_2'] = img_2 165 | samples['depth_1'] = depth_1_tensor 166 | samples['flow_1_2'] = flow_1_2 167 | samples['flow_2_1'] = flow_2_1 168 | samples['mask_1'] = mask_1 169 | samples['mask_2'] = mask_2 170 | samples['motion_seg_1'] = motion_seg 171 | samples['depth_pred_1'] = depth_1_tensor_p 172 | samples['fid_1'] = torch.FloatTensor([fid_1]) 173 | samples['fid_2'] = torch.FloatTensor([fid_2]) 174 | return samples 175 | 176 | 177 | if __name__ == '__main__': 178 | track_names = sorted(glob(join(data_list_root, '*'))) 179 | track_names = [basename(x) for x in track_names] 180 | for key in track_names: 181 | all_frames = sorted(glob(join(data_list_root, key, '*.npz'))) 182 | gaps = [1, 2, 3, 4, 5, 6, 7, 8] 183 | bs = 1 184 | save_path = join(save_path_root, key, '001') 185 | print(key) 186 | makedirs(save_path, exist_ok=True) 187 | 188 | print('saving...') 189 | 190 | for gap in tqdm(gaps): 191 | fids = np.arange(len(all_frames) - bs - gap) 192 | cnt = 0 193 | for f in fids: 194 | seq_list_forward = np.arange(f, f + bs) 195 | sequence = collate_sequence_fix_gap(key, seq_list_forward, gap=gap,) 196 | torch.save(sequence, join(save_path, f'shuffle_False_gap_{gap:02d}_sequence_{cnt:05d}.pt')) 197 | cnt += 1 198 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
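# High-level flow of this script (summarizing the code below): parse the test
# options, pick CPU or GPU via util_loadlib, derive the output directory from
# net/dataset/suffix, set up the TerminateOnNaN logger and the optional HTML
# logger, rebuild the model from the training options stored in opt.pt under
# checkpoint_path, select best.pt or nets/%04d.pt depending on opt.epoch, build
# a dataloader over the validation split, and run model.test_on_batch on every
# batch (followed by on_test_end if the model defines it).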
14 | 15 | import os 16 | from os.path import join 17 | import time 18 | from shutil import rmtree 19 | from tqdm import tqdm 20 | import torch 21 | from options import options_test 22 | import datasets 23 | import models 24 | from util.util_print import str_error, str_stage, str_verbose 25 | import util.util_loadlib as loadlib 26 | from loggers import loggers 27 | from argparse import Namespace 28 | print("Testing Pipeline") 29 | 30 | ################################################### 31 | 32 | print(str_stage, "Parsing arguments") 33 | opt = options_test.parse() 34 | opt.full_logdir = None 35 | print(opt) 36 | 37 | ################################################### 38 | 39 | print(str_stage, "Setting device") 40 | if opt.gpu == '-1': 41 | device = torch.device('cpu') 42 | else: 43 | loadlib.set_gpu(opt.gpu) 44 | device = torch.device('cuda') 45 | if opt.manual_seed is not None: 46 | loadlib.set_manual_seed(opt.manual_seed) 47 | 48 | ################################################### 49 | 50 | print(str_stage, "Setting up output directory") 51 | output_dir = opt.output_dir 52 | output_dir += (opt.net + '_' + opt.dataset + '_' + opt.suffix.format(**vars(opt))) \ 53 | if opt.suffix != '' else (opt.net + '_' + opt.dataset) 54 | opt.output_dir = output_dir 55 | if os.path.isdir(join(output_dir, 'epoch_%04d' % opt.epoch)): 56 | if opt.overwrite: 57 | rmtree(join(output_dir, 'epoch_%04d' % opt.epoch)) 58 | else: 59 | raise ValueError(str_error + " %s already exists, but no overwrite flag" 60 | % output_dir) 61 | os.makedirs(output_dir, exist_ok=True) 62 | opt.output_dir = output_dir 63 | 64 | ################################################### 65 | 66 | print(str_stage, "Setting up loggers") 67 | logger_list = [ 68 | loggers.TerminateOnNaN(), 69 | ] 70 | if opt.html_logger: 71 | html_summary_filepath = os.path.join(opt.output_dir, 'summary') 72 | html_logger = loggers.HtmlLogger(html_summary_filepath) 73 | logger_list.append(html_logger) 74 | logger = loggers.ComposeLogger(logger_list) 75 | 76 | ################################################### 77 | 78 | print(str_stage, "Setting up models") 79 | Model = models.get_model(opt.net, test=True) 80 | # load opt_original 81 | opt_dict = torch.load(join(opt.checkpoint_path, 'opt.pt')) 82 | opt_train = Namespace(**opt_dict) 83 | opt_train.global_rank = 0 84 | opt_train.output_dir = opt.output_dir 85 | model = Model(opt_train, logger) 86 | if hasattr(opt_train, 'midas'): 87 | opt.midas = opt_train.midas 88 | 89 | # checkpoint_path 90 | if opt.epoch < 0: 91 | net_file = join(opt.checkpoint_path, 'best.pt') 92 | else: 93 | net_file = join(opt.checkpoint_path, 'nets', '%04d.pt' % opt.epoch) 94 | 95 | 96 | ################################################### 97 | print(str_stage, "Setting up data loaders") 98 | start_time = time.time() 99 | Dataset = datasets.get_dataset(opt_train.dataset) 100 | dataset = Dataset(opt_train, mode='vali', model=model) 101 | dataloader = torch.utils.data.DataLoader( 102 | dataset, 103 | batch_size=opt.batch_size, 104 | num_workers=opt.workers, 105 | pin_memory=True, 106 | drop_last=False, 107 | shuffle=False 108 | ) 109 | n_batches = len(dataloader) 110 | dataiter = iter(dataloader) 111 | print(str_verbose, "Time spent in data IO initialization: %.2fs" % 112 | (time.time() - start_time)) 113 | print(str_verbose, "# test points: " + str(len(dataset))) 114 | print(str_verbose, "# test batches: " + str(n_batches)) 115 | 116 | if hasattr(model, 'update_opt'): 117 | model.update_opt(opt_train, is_train=False) 118 | 119 | 
model.load_state_dict(net_file) 120 | model.to(device) 121 | model.eval() 122 | print(model) 123 | print("# model parameters: {:,d}".format(model.num_parameters())) 124 | ################################################### 125 | 126 | print(str_stage, "Testing") 127 | if opt.html_logger: 128 | html_logger.on_train_begin() 129 | html_logger.training = False 130 | html_logger.on_epoch_begin(0) 131 | 132 | model.opt.epoch = opt.epoch 133 | for i in tqdm(range(n_batches)): 134 | batch = next(dataiter) 135 | model.test_on_batch(i, batch) 136 | 137 | if hasattr(model, 'on_test_end'): 138 | model.on_test_end() 139 | -------------------------------------------------------------------------------- /third_party/MiDaS.py: -------------------------------------------------------------------------------- 1 | 2 | """ MIT License 3 | 4 | Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. """ 23 | 24 | 25 | import torch 26 | import torch.nn as nn 27 | from .midas_blocks import FeatureFusionBlock, Interpolate, _make_encoder 28 | 29 | 30 | class BaseModel(torch.nn.Module): 31 | def load(self, path): 32 | """Load model from file. 33 | Args: 34 | path (str): file path 35 | """ 36 | parameters = torch.load(path) 37 | 38 | if "optimizer" in parameters: 39 | parameters = parameters["model"] 40 | 41 | self.load_state_dict(parameters) 42 | 43 | 44 | class MidasNet_mod(BaseModel): 45 | def __init__(self, path=None, features=256, non_negative=True, normalize_input=False, resize=None, freeze_backbone=False, mask_branch=False): 46 | """Init. 47 | Args: 48 | path (str, optional): Path to saved model. Defaults to None. 49 | features (int, optional): Number of features. Defaults to 256. 50 | backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 51 | """ 52 | print("Loading weights: ", path) 53 | 54 | super(MidasNet_mod, self).__init__() 55 | 56 | use_pretrained = False if path is None else True 57 | 58 | self.pretrained, self.scratch = _make_encoder(features, use_pretrained) 59 | 60 | self.scratch.refinenet4 = FeatureFusionBlock(features) 61 | self.scratch.refinenet3 = FeatureFusionBlock(features) 62 | self.scratch.refinenet2 = FeatureFusionBlock(features) 63 | self.scratch.refinenet1 = FeatureFusionBlock(features) 64 | 65 | self.scratch.output_conv = nn.Sequential( 66 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 67 | Interpolate(scale_factor=2, mode="bilinear"), 68 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 69 | nn.ReLU(True), 70 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 71 | nn.ReLU(True) if non_negative else nn.Identity(), 72 | ) 73 | 74 | if path: 75 | self.load(path) 76 | 77 | if mask_branch: 78 | self.scratch.output_conv_mask = nn.Sequential( 79 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 80 | Interpolate(scale_factor=2, mode="bilinear"), 81 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 82 | nn.ReLU(True), 83 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 84 | nn.Sigmoid() 85 | ) 86 | self.mask_branch = mask_branch 87 | 88 | if normalize_input: 89 | self.mean = torch.FloatTensor([0.485, 0.456, 0.406]) 90 | self.std = torch.FloatTensor([0.229, 0.224, 0.225]) 91 | self.normalize_input = normalize_input 92 | self.resize = resize 93 | self.freeze_backbone = freeze_backbone 94 | 95 | def freeze(self): 96 | for p in self.parameters(): 97 | p.requires_grad = False 98 | 99 | def defrost(self): 100 | if self.freeze_backbone: 101 | for p in self.scratch.parameters(): 102 | p.requires_grad = True 103 | else: 104 | for p in self.parameters(): 105 | p.requires_grad = True 106 | 107 | def forward(self, x): 108 | """Forward pass. 
109 | Args: 110 | x (tensor): input data (image) 111 | Returns: 112 | tensor: depth 113 | """ 114 | if self.normalize_input: 115 | self.mean = self.mean.to(x.device) 116 | self.std = self.std.to(x.device) 117 | x = x.permute([0, 2, 3, 1]) 118 | x = (x - self.mean) / self.std 119 | x = x.permute([0, 3, 1, 2]).contiguous() 120 | 121 | orig_shape = x.shape[-2:] 122 | if self.resize is not None: 123 | x = torch.nn.functional.interpolate(x, size=self.resize, mode='bicubic', align_corners=True) 124 | 125 | if self.freeze_backbone: 126 | with torch.no_grad(): 127 | layer_1 = self.pretrained.layer1(x) 128 | layer_2 = self.pretrained.layer2(layer_1) 129 | layer_3 = self.pretrained.layer3(layer_2) 130 | layer_4 = self.pretrained.layer4(layer_3) 131 | else: 132 | layer_1 = self.pretrained.layer1(x) 133 | layer_2 = self.pretrained.layer2(layer_1) 134 | layer_3 = self.pretrained.layer3(layer_2) 135 | layer_4 = self.pretrained.layer4(layer_3) 136 | 137 | layer_1_rn = self.scratch.layer1_rn(layer_1) 138 | layer_2_rn = self.scratch.layer2_rn(layer_2) 139 | layer_3_rn = self.scratch.layer3_rn(layer_3) 140 | layer_4_rn = self.scratch.layer4_rn(layer_4) 141 | 142 | path_4 = self.scratch.refinenet4(layer_4_rn) 143 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 144 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 145 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 146 | 147 | out = self.scratch.output_conv(path_1) 148 | out = torch.clamp(out, min=1e-2) 149 | 150 | out = 10000 / (out) 151 | 152 | if self.mask_branch: 153 | mask = self.scratch.output_conv_mask(path_1) 154 | 155 | else: 156 | mask = torch.zeros_like(out) 157 | 158 | if self.resize is not None: 159 | out = torch.nn.functional.interpolate(out, size=orig_shape, mode='bicubic', align_corners=True) 160 | mask = torch.nn.functional.interpolate(mask, size=orig_shape, mode='bicubic', align_corners=True) 161 | return out, mask 162 | 163 | 164 | class MidasNet(BaseModel): 165 | """Network for monocular depth estimation. 166 | """ 167 | 168 | def __init__(self, path=None, features=256, non_negative=True, normalize_input=False, resize=None): 169 | """Init. 170 | Args: 171 | path (str, optional): Path to saved model. Defaults to None. 172 | features (int, optional): Number of features. Defaults to 256. 173 | backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 174 | """ 175 | print("Loading weights: ", path) 176 | 177 | super(MidasNet, self).__init__() 178 | 179 | use_pretrained = False if path is None else True 180 | 181 | self.pretrained, self.scratch = _make_encoder(features, use_pretrained) 182 | 183 | self.scratch.refinenet4 = FeatureFusionBlock(features) 184 | self.scratch.refinenet3 = FeatureFusionBlock(features) 185 | self.scratch.refinenet2 = FeatureFusionBlock(features) 186 | self.scratch.refinenet1 = FeatureFusionBlock(features) 187 | 188 | self.scratch.output_conv = nn.Sequential( 189 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 190 | Interpolate(scale_factor=2, mode="bilinear"), 191 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 192 | nn.ReLU(True), 193 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 194 | nn.ReLU(True) if non_negative else nn.Identity(), 195 | ) 196 | 197 | if path: 198 | self.load(path) 199 | 200 | if normalize_input: 201 | self.mean = torch.FloatTensor([0.485, 0.456, 0.406]) 202 | self.std = torch.FloatTensor([0.229, 0.224, 0.225]) 203 | self.normalize_input = normalize_input 204 | self.resize = resize 205 | 206 | def forward(self, x): 207 | """Forward pass. 208 | Args: 209 | x (tensor): input data (image) 210 | Returns: 211 | tensor: depth 212 | """ 213 | if self.normalize_input: 214 | self.mean = self.mean.to(x.device) 215 | self.std = self.std.to(x.device) 216 | x = x.permute([0, 2, 3, 1]) 217 | x = (x - self.mean) / self.std 218 | x = x.permute([0, 3, 1, 2]).contiguous() 219 | 220 | orig_shape = x.shape[-2:] 221 | if self.resize is not None: 222 | x = torch.nn.functional.interpolate(x, size=self.resize, mode='bicubic', align_corners=True) 223 | 224 | layer_1 = self.pretrained.layer1(x) 225 | layer_2 = self.pretrained.layer2(layer_1) 226 | layer_3 = self.pretrained.layer3(layer_2) 227 | layer_4 = self.pretrained.layer4(layer_3) 228 | 229 | layer_1_rn = self.scratch.layer1_rn(layer_1) 230 | layer_2_rn = self.scratch.layer2_rn(layer_2) 231 | layer_3_rn = self.scratch.layer3_rn(layer_3) 232 | layer_4_rn = self.scratch.layer4_rn(layer_4) 233 | 234 | path_4 = self.scratch.refinenet4(layer_4_rn) 235 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 236 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 237 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 238 | 239 | out = self.scratch.output_conv(path_1) 240 | out = torch.clamp(out, min=1e-2) 241 | 242 | out = 10000 / (out) 243 | 244 | if self.resize is not None: 245 | out = torch.nn.functional.interpolate(out, size=orig_shape, mode='bicubic', align_corners=True) 246 | return out 247 | -------------------------------------------------------------------------------- /third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/dynamic-video-depth/79177ef5941b15b0aa2395b626f922fcb7b4c179/third_party/__init__.py -------------------------------------------------------------------------------- /third_party/hourglass.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # from https://github.com/google/mannequinchallenge/blob/3448d9d49dc130db7ed18053b70f66bc157d238f/models/hourglass.py#L19 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | class inception(nn.Module): 22 | def __init__(self, input_size, config): 23 | self.config = config 24 | super(inception, self).__init__() 25 | self.convs = nn.ModuleList() 26 | 27 | # Base 1*1 conv layer 28 | self.convs.append(nn.Sequential( 29 | nn.Conv2d(input_size, config[0][0], 1), 30 | nn.BatchNorm2d(config[0][0], affine=False), 31 | nn.ReLU(True), 32 | )) 33 | 34 | # Additional layers 35 | for i in range(1, len(config)): 36 | filt = config[i][0] 37 | pad = int((filt - 1) / 2) 38 | out_a = config[i][1] 39 | out_b = config[i][2] 40 | conv = nn.Sequential( 41 | nn.Conv2d(input_size, out_a, 1), 42 | nn.BatchNorm2d(out_a, affine=False), 43 | nn.ReLU(True), 44 | nn.Conv2d(out_a, out_b, filt, padding=pad), 45 | nn.BatchNorm2d(out_b, affine=False), 46 | nn.ReLU(True) 47 | ) 48 | self.convs.append(conv) 49 | 50 | def __repr__(self): 51 | return "inception" + str(self.config) 52 | 53 | def forward(self, x): 54 | ret = [] 55 | for conv in (self.convs): 56 | ret.append(conv(x)) 57 | return torch.cat(ret, dim=1) 58 | 59 | 60 | class Channels1(nn.Module): 61 | def __init__(self): 62 | super(Channels1, self).__init__() 63 | self.list = nn.ModuleList() 64 | self.list.append( 65 | nn.Sequential( 66 | inception(256, [[64], [3, 32, 64], [5, 32, 64], [7, 32, 64]]), 67 | inception(256, [[64], [3, 32, 64], [5, 32, 64], [7, 32, 64]]) 68 | ) 69 | ) # EE 70 | self.list.append( 71 | nn.Sequential( 72 | nn.AvgPool2d(2), 73 | inception(256, [[64], [3, 32, 64], [5, 32, 64], [7, 32, 64]]), 74 | inception(256, [[64], [3, 32, 64], [5, 32, 64], [7, 32, 64]]), 75 | inception(256, [[64], [3, 32, 64], [5, 32, 64], [7, 32, 64]]), 76 | nn.UpsamplingBilinear2d(scale_factor=2) 77 | ) 78 | ) # EEE 79 | 80 | def forward(self, x): 81 | return self.list[0](x) + self.list[1](x) 82 | 83 | 84 | class Channels2(nn.Module): 85 | def __init__(self): 86 | super(Channels2, self).__init__() 87 | self.list = nn.ModuleList() 88 | self.list.append( 89 | nn.Sequential( 90 | inception(256, [[64], [3, 32, 64], [5, 32, 64], [7, 32, 64]]), 91 | inception(256, [[64], [3, 64, 64], [7, 64, 64], [11, 64, 64]]) 92 | ) 93 | ) # EF 94 | self.list.append( 95 | nn.Sequential( 96 | nn.AvgPool2d(2), 97 | inception(256, [[64], [3, 32, 64], [5, 32, 64], [7, 32, 64]]), 98 | inception(256, [[64], [3, 32, 64], [5, 32, 64], [7, 32, 64]]), 99 | Channels1(), 100 | inception(256, [[64], [3, 32, 64], [5, 32, 64], [7, 32, 64]]), 101 | inception(256, [[64], [3, 64, 64], [7, 64, 64], [11, 64, 64]]), 102 | nn.UpsamplingBilinear2d(scale_factor=2) 103 | ) 104 | ) # EE1EF 105 | 106 | def forward(self, x): 107 | return self.list[0](x) + self.list[1](x) 108 | 109 | 110 | class Channels3(nn.Module): 111 | def __init__(self): 112 | super(Channels3, self).__init__() 113 | self.list = nn.ModuleList() 114 | self.list.append( 115 | nn.Sequential( 116 | nn.AvgPool2d(2), 117 | inception(128, [[32], [3, 32, 32], [5, 32, 32], [7, 32, 32]]), 118 | 
inception(128, [[64], [3, 32, 64], [5, 32, 64], [7, 32, 64]]), 119 | Channels2(), 120 | inception(256, [[64], [3, 32, 64], [5, 32, 64], [7, 32, 64]]), 121 | inception(256, [[32], [3, 32, 32], [5, 32, 32], [7, 32, 32]]), 122 | nn.UpsamplingBilinear2d(scale_factor=2) 123 | ) 124 | ) # BD2EG 125 | self.list.append( 126 | nn.Sequential( 127 | inception(128, [[32], [3, 32, 32], [5, 32, 32], [7, 32, 32]]), 128 | inception(128, [[32], [3, 64, 32], [7, 64, 32], [11, 64, 32]]) 129 | ) 130 | ) # BC 131 | 132 | def forward(self, x): 133 | return self.list[0](x) + self.list[1](x) 134 | 135 | 136 | class Channels4(nn.Module): 137 | def __init__(self): 138 | super(Channels4, self).__init__() 139 | self.list = nn.ModuleList() 140 | self.list.append( 141 | nn.Sequential( 142 | nn.AvgPool2d(2), 143 | inception(128, [[32], [3, 32, 32], [5, 32, 32], [7, 32, 32]]), 144 | inception(128, [[32], [3, 32, 32], [5, 32, 32], [7, 32, 32]]), 145 | Channels3(), 146 | inception(128, [[32], [3, 64, 32], [5, 64, 32], [7, 64, 32]]), 147 | inception(128, [[16], [3, 32, 16], [7, 32, 16], [11, 32, 16]]), 148 | nn.UpsamplingBilinear2d(scale_factor=2) 149 | ) 150 | ) # BB3BA 151 | self.list.append( 152 | nn.Sequential( 153 | inception(128, [[16], [3, 64, 16], [7, 64, 16], [11, 64, 16]]) 154 | ) 155 | ) # A 156 | 157 | def forward(self, x): 158 | return self.list[0](x) + self.list[1](x) 159 | 160 | 161 | class HourglassModel(nn.Module): 162 | def __init__(self, num_input=3, noexp=False): 163 | super(HourglassModel, self).__init__() 164 | 165 | self.seq = nn.Sequential( 166 | nn.Conv2d(num_input, 128, 7, padding=3), 167 | nn.BatchNorm2d(128), 168 | nn.ReLU(True), 169 | Channels4(), 170 | ) 171 | 172 | uncertainty_layer = [ 173 | nn.Conv2d(64, 1, 3, padding=1), torch.nn.Sigmoid()] 174 | self.uncertainty_layer = torch.nn.Sequential(*uncertainty_layer) 175 | self.pred_layer = nn.Conv2d(64, 1, 3, padding=1) 176 | self.noexp = noexp 177 | 178 | def forward(self, input_): 179 | pred_feature = self.seq(input_) 180 | 181 | pred_d = self.pred_layer(pred_feature) 182 | 183 | if self.noexp: 184 | depth = pred_d 185 | else: 186 | depth = torch.exp(pred_d) 187 | 188 | return depth 189 | 190 | 191 | class HourglassModel_Embed(nn.Module): 192 | def __init__(self, num_input=3, noexp=False, use_embedding=False, n_embedding=100): 193 | super(HourglassModel_Embed, self).__init__() 194 | self.net_depth = HourglassModel(num_input, noexp) 195 | 196 | self.use_embedding = use_embedding 197 | if use_embedding: 198 | self.embedding = nn.Embedding(n_embedding, 1, _weight=torch.ones([n_embedding, 1])) 199 | 200 | def freeze(self): 201 | self.net_depth.eval() 202 | for param in self.net_depth.parameters(): 203 | param.requires_grad = False 204 | 205 | def defrost(self): 206 | self.net_depth.eval() 207 | for param in self.net_depth.parameters(): 208 | param.requires_grad = True 209 | 210 | def forward(self, input_, embed_index=None): 211 | depth = self.net_depth(input_) 212 | return depth 213 | -------------------------------------------------------------------------------- /third_party/midas_blocks.py: -------------------------------------------------------------------------------- 1 | 2 | """ MIT License 3 | 4 | Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, 
distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. """ 23 | 24 | import torch 25 | import torch.nn as nn 26 | 27 | 28 | def _make_encoder(features, use_pretrained): 29 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 30 | scratch = _make_scratch([256, 512, 1024, 2048], features) 31 | 32 | return pretrained, scratch 33 | 34 | 35 | def _make_resnet_backbone(resnet): 36 | pretrained = nn.Module() 37 | pretrained.layer1 = nn.Sequential( 38 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 39 | ) 40 | 41 | pretrained.layer2 = resnet.layer2 42 | pretrained.layer3 = resnet.layer3 43 | pretrained.layer4 = resnet.layer4 44 | 45 | return pretrained 46 | 47 | 48 | def _make_pretrained_resnext101_wsl(use_pretrained): 49 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 50 | return _make_resnet_backbone(resnet) 51 | 52 | 53 | def _make_scratch(in_shape, out_shape): 54 | scratch = nn.Module() 55 | 56 | scratch.layer1_rn = nn.Conv2d( 57 | in_shape[0], out_shape, kernel_size=3, stride=1, padding=1, bias=False 58 | ) 59 | scratch.layer2_rn = nn.Conv2d( 60 | in_shape[1], out_shape, kernel_size=3, stride=1, padding=1, bias=False 61 | ) 62 | scratch.layer3_rn = nn.Conv2d( 63 | in_shape[2], out_shape, kernel_size=3, stride=1, padding=1, bias=False 64 | ) 65 | scratch.layer4_rn = nn.Conv2d( 66 | in_shape[3], out_shape, kernel_size=3, stride=1, padding=1, bias=False 67 | ) 68 | return scratch 69 | 70 | 71 | class Interpolate(nn.Module): 72 | """Interpolation module. 73 | """ 74 | 75 | def __init__(self, scale_factor, mode): 76 | """Init. 77 | Args: 78 | scale_factor (float): scaling 79 | mode (str): interpolation mode 80 | """ 81 | super(Interpolate, self).__init__() 82 | 83 | self.interp = nn.functional.interpolate 84 | self.scale_factor = scale_factor 85 | self.mode = mode 86 | 87 | def forward(self, x): 88 | """Forward pass. 89 | Args: 90 | x (tensor): input 91 | Returns: 92 | tensor: interpolated data 93 | """ 94 | 95 | x = self.interp( 96 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False 97 | ) 98 | 99 | return x 100 | 101 | 102 | class ResidualConvUnit(nn.Module): 103 | """Residual convolution module. 104 | """ 105 | 106 | def __init__(self, features): 107 | """Init. 108 | Args: 109 | features (int): number of features 110 | """ 111 | super().__init__() 112 | 113 | self.conv1 = nn.Conv2d( 114 | features, features, kernel_size=3, stride=1, padding=1, bias=True 115 | ) 116 | 117 | self.conv2 = nn.Conv2d( 118 | features, features, kernel_size=3, stride=1, padding=1, bias=True 119 | ) 120 | 121 | self.relu = nn.ReLU(inplace=True) 122 | 123 | def forward(self, x): 124 | """Forward pass. 
125 | Args: 126 | x (tensor): input 127 | Returns: 128 | tensor: output 129 | """ 130 | out = self.relu(x) 131 | out = self.conv1(out) 132 | out = self.relu(out) 133 | out = self.conv2(out) 134 | 135 | return out + x 136 | 137 | 138 | class FeatureFusionBlock(nn.Module): 139 | """Feature fusion block. 140 | """ 141 | 142 | def __init__(self, features): 143 | """Init. 144 | Args: 145 | features (int): number of features 146 | """ 147 | super(FeatureFusionBlock, self).__init__() 148 | 149 | self.resConfUnit1 = ResidualConvUnit(features) 150 | self.resConfUnit2 = ResidualConvUnit(features) 151 | 152 | def forward(self, *xs): 153 | """Forward pass. 154 | Returns: 155 | tensor: output 156 | """ 157 | output = xs[0] 158 | 159 | if len(xs) == 2: 160 | output += self.resConfUnit1(xs[1]) 161 | 162 | output = self.resConfUnit2(output) 163 | 164 | output = nn.functional.interpolate( 165 | output, scale_factor=2, mode="bilinear", align_corners=True 166 | ) 167 | 168 | return output 169 | -------------------------------------------------------------------------------- /third_party/util_colormap.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # Author: Anton Mikhailov 5 | import numpy as np 6 | turbo_colormap_data = [[0.18995, 0.07176, 0.23217], [0.19483, 0.08339, 0.26149], [0.19956, 0.09498, 0.29024], [0.20415, 0.10652, 0.31844], [0.20860, 0.11802, 0.34607], [0.21291, 0.12947, 0.37314], [0.21708, 0.14087, 0.39964], [0.22111, 0.15223, 0.42558], [0.22500, 0.16354, 0.45096], [0.22875, 0.17481, 0.47578], [0.23236, 0.18603, 0.50004], [0.23582, 0.19720, 0.52373], [0.23915, 0.20833, 0.54686], [0.24234, 0.21941, 0.56942], [0.24539, 0.23044, 0.59142], [0.24830, 0.24143, 0.61286], [0.25107, 0.25237, 0.63374], [0.25369, 0.26327, 0.65406], [0.25618, 0.27412, 0.67381], [0.25853, 0.28492, 0.69300], [0.26074, 0.29568, 0.71162], [0.26280, 0.30639, 0.72968], [0.26473, 0.31706, 0.74718], [0.26652, 0.32768, 0.76412], [0.26816, 0.33825, 0.78050], [0.26967, 0.34878, 0.79631], [0.27103, 0.35926, 0.81156], [0.27226, 0.36970, 0.82624], [0.27334, 0.38008, 0.84037], [0.27429, 0.39043, 0.85393], [0.27509, 0.40072, 0.86692], [0.27576, 0.41097, 0.87936], [0.27628, 0.42118, 0.89123], [0.27667, 0.43134, 0.90254], [0.27691, 0.44145, 0.91328], [0.27701, 0.45152, 0.92347], [0.27698, 0.46153, 0.93309], [0.27680, 0.47151, 0.94214], [0.27648, 0.48144, 0.95064], [0.27603, 0.49132, 0.95857], [0.27543, 0.50115, 0.96594], [0.27469, 0.51094, 0.97275], [0.27381, 0.52069, 0.97899], [0.27273, 0.53040, 0.98461], [0.27106, 0.54015, 0.98930], [0.26878, 0.54995, 0.99303], [0.26592, 0.55979, 0.99583], [0.26252, 0.56967, 0.99773], [0.25862, 0.57958, 0.99876], [0.25425, 0.58950, 0.99896], [0.24946, 0.59943, 0.99835], [0.24427, 0.60937, 0.99697], [0.23874, 0.61931, 0.99485], [0.23288, 0.62923, 0.99202], [0.22676, 0.63913, 0.98851], [0.22039, 0.64901, 0.98436], [0.21382, 0.65886, 0.97959], [0.20708, 0.66866, 0.97423], [0.20021, 0.67842, 0.96833], [0.19326, 0.68812, 0.96190], [0.18625, 0.69775, 0.95498], [0.17923, 0.70732, 0.94761], [0.17223, 0.71680, 0.93981], [0.16529, 0.72620, 0.93161], [0.15844, 0.73551, 0.92305], [0.15173, 0.74472, 0.91416], [0.14519, 0.75381, 0.90496], [0.13886, 0.76279, 0.89550], [0.13278, 0.77165, 0.88580], [0.12698, 0.78037, 0.87590], [0.12151, 0.78896, 0.86581], [0.11639, 0.79740, 0.85559], [0.11167, 0.80569, 0.84525], [0.10738, 0.81381, 0.83484], [0.10357, 0.82177, 0.82437], [0.10026, 0.82955, 0.81389], 
[0.09750, 0.83714, 0.80342], [0.09532, 0.84455, 0.79299], [0.09377, 0.85175, 0.78264], [0.09287, 0.85875, 0.77240], [0.09267, 0.86554, 0.76230], [0.09320, 0.87211, 0.75237], [0.09451, 0.87844, 0.74265], [0.09662, 0.88454, 0.73316], [0.09958, 0.89040, 0.72393], [0.10342, 0.89600, 0.71500], [0.10815, 0.90142, 0.70599], [0.11374, 0.90673, 0.69651], [0.12014, 0.91193, 0.68660], [0.12733, 0.91701, 0.67627], [0.13526, 0.92197, 0.66556], [0.14391, 0.92680, 0.65448], [0.15323, 0.93151, 0.64308], [0.16319, 0.93609, 0.63137], [0.17377, 0.94053, 0.61938], [0.18491, 0.94484, 0.60713], [0.19659, 0.94901, 0.59466], [0.20877, 0.95304, 0.58199], [0.22142, 0.95692, 0.56914], [0.23449, 0.96065, 0.55614], [0.24797, 0.96423, 0.54303], [0.26180, 0.96765, 0.52981], [0.27597, 0.97092, 0.51653], [0.29042, 0.97403, 0.50321], [0.30513, 0.97697, 0.48987], [0.32006, 0.97974, 0.47654], [0.33517, 0.98234, 0.46325], [0.35043, 0.98477, 0.45002], [0.36581, 0.98702, 0.43688], [0.38127, 0.98909, 0.42386], [0.39678, 0.99098, 0.41098], [0.41229, 0.99268, 0.39826], [0.42778, 0.99419, 0.38575], [0.44321, 0.99551, 0.37345], [0.45854, 0.99663, 0.36140], [0.47375, 0.99755, 0.34963], [0.48879, 0.99828, 0.33816], [0.50362, 0.99879, 0.32701], [0.51822, 0.99910, 0.31622], [0.53255, 0.99919, 0.30581], [0.54658, 0.99907, 0.29581], [0.56026, 0.99873, 0.28623], [0.57357, 0.99817, 0.27712], [0.58646, 0.99739, 0.26849], [0.59891, 0.99638, 0.26038], [0.61088, 0.99514, 0.25280], [0.62233, 0.99366, 0.24579], [0.63323, 0.99195, 0.23937], [ 7 | 0.64362, 0.98999, 0.23356], [0.65394, 0.98775, 0.22835], [0.66428, 0.98524, 0.22370], [0.67462, 0.98246, 0.21960], [0.68494, 0.97941, 0.21602], [0.69525, 0.97610, 0.21294], [0.70553, 0.97255, 0.21032], [0.71577, 0.96875, 0.20815], [0.72596, 0.96470, 0.20640], [0.73610, 0.96043, 0.20504], [0.74617, 0.95593, 0.20406], [0.75617, 0.95121, 0.20343], [0.76608, 0.94627, 0.20311], [0.77591, 0.94113, 0.20310], [0.78563, 0.93579, 0.20336], [0.79524, 0.93025, 0.20386], [0.80473, 0.92452, 0.20459], [0.81410, 0.91861, 0.20552], [0.82333, 0.91253, 0.20663], [0.83241, 0.90627, 0.20788], [0.84133, 0.89986, 0.20926], [0.85010, 0.89328, 0.21074], [0.85868, 0.88655, 0.21230], [0.86709, 0.87968, 0.21391], [0.87530, 0.87267, 0.21555], [0.88331, 0.86553, 0.21719], [0.89112, 0.85826, 0.21880], [0.89870, 0.85087, 0.22038], [0.90605, 0.84337, 0.22188], [0.91317, 0.83576, 0.22328], [0.92004, 0.82806, 0.22456], [0.92666, 0.82025, 0.22570], [0.93301, 0.81236, 0.22667], [0.93909, 0.80439, 0.22744], [0.94489, 0.79634, 0.22800], [0.95039, 0.78823, 0.22831], [0.95560, 0.78005, 0.22836], [0.96049, 0.77181, 0.22811], [0.96507, 0.76352, 0.22754], [0.96931, 0.75519, 0.22663], [0.97323, 0.74682, 0.22536], [0.97679, 0.73842, 0.22369], [0.98000, 0.73000, 0.22161], [0.98289, 0.72140, 0.21918], [0.98549, 0.71250, 0.21650], [0.98781, 0.70330, 0.21358], [0.98986, 0.69382, 0.21043], [0.99163, 0.68408, 0.20706], [0.99314, 0.67408, 0.20348], [0.99438, 0.66386, 0.19971], [0.99535, 0.65341, 0.19577], [0.99607, 0.64277, 0.19165], [0.99654, 0.63193, 0.18738], [0.99675, 0.62093, 0.18297], [0.99672, 0.60977, 0.17842], [0.99644, 0.59846, 0.17376], [0.99593, 0.58703, 0.16899], [0.99517, 0.57549, 0.16412], [0.99419, 0.56386, 0.15918], [0.99297, 0.55214, 0.15417], [0.99153, 0.54036, 0.14910], [0.98987, 0.52854, 0.14398], [0.98799, 0.51667, 0.13883], [0.98590, 0.50479, 0.13367], [0.98360, 0.49291, 0.12849], [0.98108, 0.48104, 0.12332], [0.97837, 0.46920, 0.11817], [0.97545, 0.45740, 0.11305], [0.97234, 0.44565, 0.10797], [0.96904, 0.43399, 0.10294], [0.96555, 
0.42241, 0.09798], [0.96187, 0.41093, 0.09310], [0.95801, 0.39958, 0.08831], [0.95398, 0.38836, 0.08362], [0.94977, 0.37729, 0.07905], [0.94538, 0.36638, 0.07461], [0.94084, 0.35566, 0.07031], [0.93612, 0.34513, 0.06616], [0.93125, 0.33482, 0.06218], [0.92623, 0.32473, 0.05837], [0.92105, 0.31489, 0.05475], [0.91572, 0.30530, 0.05134], [0.91024, 0.29599, 0.04814], [0.90463, 0.28696, 0.04516], [0.89888, 0.27824, 0.04243], [0.89298, 0.26981, 0.03993], [0.88691, 0.26152, 0.03753], [0.88066, 0.25334, 0.03521], [0.87422, 0.24526, 0.03297], [0.86760, 0.23730, 0.03082], [0.86079, 0.22945, 0.02875], [0.85380, 0.22170, 0.02677], [0.84662, 0.21407, 0.02487], [0.83926, 0.20654, 0.02305], [0.83172, 0.19912, 0.02131], [0.82399, 0.19182, 0.01966], [0.81608, 0.18462, 0.01809], [0.80799, 0.17753, 0.01660], [0.79971, 0.17055, 0.01520], [0.79125, 0.16368, 0.01387], [0.78260, 0.15693, 0.01264], [0.77377, 0.15028, 0.01148], [0.76476, 0.14374, 0.01041], [0.75556, 0.13731, 0.00942], [0.74617, 0.13098, 0.00851], [0.73661, 0.12477, 0.00769], [0.72686, 0.11867, 0.00695], [0.71692, 0.11268, 0.00629], [0.70680, 0.10680, 0.00571], [0.69650, 0.10102, 0.00522], [0.68602, 0.09536, 0.00481], [0.67535, 0.08980, 0.00449], [0.66449, 0.08436, 0.00424], [0.65345, 0.07902, 0.00408], [0.64223, 0.07380, 0.00401], [0.63082, 0.06868, 0.00401], [0.61923, 0.06367, 0.00410], [0.60746, 0.05878, 0.00427], [0.59550, 0.05399, 0.00453], [0.58336, 0.04931, 0.00486], [0.57103, 0.04474, 0.00529], [0.55852, 0.04028, 0.00579], [0.54583, 0.03593, 0.00638], [0.53295, 0.03169, 0.00705], [0.51989, 0.02756, 0.00780], [0.50664, 0.02354, 0.00863], [0.49321, 0.01963, 0.00955], [0.47960, 0.01583, 0.01055]] 8 | turbo_colormap_data_np = np.array(turbo_colormap_data) 9 | 10 | 11 | def heatmap_to_pseudo_color(heatmap): 12 | x = heatmap 13 | x = x.clip(0, 1) 14 | a = (x * 255).astype(int) 15 | b = (a + 1).clip(max=255) 16 | f = x * 255.0 - a 17 | pseudo_color = ( 18 | turbo_colormap_data_np[a] + (turbo_colormap_data_np[b] - turbo_colormap_data_np[a]) * f[..., None] 19 | ) 20 | pseudo_color[heatmap < 0.0] = 0.0 21 | pseudo_color[heatmap > 1.0] = 1.0 22 | return pseudo_color 23 | 24 | # The look-up table contains 256 entries. Each entry is a floating point sRGB triplet. 25 | # To use it with matplotlib, pass cmap=ListedColormap(turbo_colormap_data) as an arg to imshow() (don't forget "from matplotlib.colors import ListedColormap"). 26 | # If you have a typical 8-bit greyscale image, you can use the 8-bit value to index into this LUT directly. 27 | # The floating point color values can be converted to 8-bit sRGB via multiplying by 255 and casting/flooring to an integer. Saturation should not be required for IEEE-754 compliant arithmetic. 28 | # If you have a floating point value in the range [0,1], you can use interpolate() to linearly interpolate between the entries. 29 | # If you have 16-bit or 32-bit integer values, convert them to floating point values on the [0,1] range and then use interpolate(). Doing the interpolation in floating point will reduce banding. 30 | # If some of your values may lie outside the [0,1] range, use interpolate_or_clip() to highlight them. 
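# A minimal usage sketch of the helpers above (illustrative only, not part of the
# original file), following the matplotlib recipe described in the comments: it
# assumes matplotlib is installed and can be deleted without affecting anything.
def _turbo_demo(out_path='turbo_demo.png'):
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    heat = np.tile(np.linspace(0.0, 1.0, 256), (32, 1))                   # toy heatmap in [0, 1]
    plt.imsave(out_path, heat, cmap=ListedColormap(turbo_colormap_data))  # LUT used as a colormap
    return heatmap_to_pseudo_color(heat)                                  # or get the RGB array directly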
31 | 32 | 33 | def interpolate(colormap, x): 34 | x = max(0.0, min(1.0, x)) 35 | a = int(x * 255.0) 36 | b = min(255, a + 1) 37 | f = x * 255.0 - a 38 | return [colormap[a][0] + (colormap[b][0] - colormap[a][0]) * f, 39 | colormap[a][1] + (colormap[b][1] - colormap[a][1]) * f, 40 | colormap[a][2] + (colormap[b][2] - colormap[a][2]) * f] 41 | 42 | 43 | def interpolate_or_clip(colormap, x): 44 | if x < 0.0: 45 | return [0.0, 0.0, 0.0] 46 | elif x > 1.0: 47 | return [1.0, 1.0, 1.0] 48 | else: 49 | return interpolate(colormap, x) 50 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from os.path import join, dirname, exists 16 | project_path = dirname(dirname(__file__)) 17 | -------------------------------------------------------------------------------- /util/util_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import configparser 16 | from os.path import join, abspath, dirname 17 | 18 | 19 | def get_project_config(file_path=None): 20 | if file_path is None: 21 | file_path = join(dirname(abspath(__file__)), '../configs/project_config.cfg') 22 | config = configparser.ConfigParser() 23 | config.read(file_path) 24 | assert 'Paths' in config 25 | config_dict = {} 26 | for k, v in config['Paths'].items(): 27 | config_dict[k] = v 28 | return config_dict 29 | -------------------------------------------------------------------------------- /util/util_flow.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
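# Note on the .flo container handled by readFlow/writeFlow below (an illustrative
# sketch, not part of the original file): the format is the float32 magic 202021.25,
# then int32 width and height, then h*w*2 float32 values with u and v interleaved.
def _flo_roundtrip_demo(path='/tmp/demo.flo'):
    import numpy as np                                  # local import; the module imports follow below
    flow = np.random.rand(4, 6, 2).astype(np.float32)   # toy (h, w, 2) flow field
    writeFlow(path, flow)                               # u, v taken from the last axis
    back = readFlow(path)                               # little-endian machines only (see the warning below)
    assert back.shape == flow.shape and np.allclose(back, flow)
    return back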
14 | 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | import os.path 18 | 19 | TAG_CHAR = np.array([202021.25], np.float32) 20 | 21 | 22 | def readFlow(fn): 23 | """ Read .flo file in Middlebury format""" 24 | # Code adapted from: 25 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 26 | 27 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 28 | # print 'fn = %s'%(fn) 29 | with open(fn, 'rb') as f: 30 | magic = np.fromfile(f, np.float32, count=1) 31 | if 202021.25 != magic: 32 | print('Magic number incorrect. Invalid .flo file') 33 | return None 34 | else: 35 | w = np.fromfile(f, np.int32, count=1) 36 | h = np.fromfile(f, np.int32, count=1) 37 | # print 'Reading %d x %d flo file\n' % (w, h) 38 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 39 | # Reshape data into 3D array (columns, rows, bands) 40 | # The reshape here is for visualization, the original code is (w,h,2) 41 | return np.resize(data, (int(h), int(w), 2)) 42 | 43 | 44 | def writeFlow(filename, uv, v=None): 45 | """ Write optical flow to file. 46 | 47 | If v is None, uv is assumed to contain both u and v channels, 48 | stacked in depth. 49 | Original code by Deqing Sun, adapted from Daniel Scharstein. 50 | """ 51 | nBands = 2 52 | 53 | if v is None: 54 | assert(uv.ndim == 3) 55 | assert(uv.shape[2] == 2) 56 | u = uv[:, :, 0] 57 | v = uv[:, :, 1] 58 | else: 59 | u = uv 60 | 61 | assert(u.shape == v.shape) 62 | height, width = u.shape 63 | f = open(filename, 'wb') 64 | # write the header 65 | f.write(TAG_CHAR) 66 | np.array(width).astype(np.int32).tofile(f) 67 | np.array(height).astype(np.int32).tofile(f) 68 | # arrange into matrix form 69 | tmp = np.zeros((height, width * nBands)) 70 | tmp[:, np.arange(width) * 2] = u 71 | tmp[:, np.arange(width) * 2 + 1] = v 72 | tmp.astype(np.float32).tofile(f) 73 | f.close() 74 | 75 | 76 | # ref: https://github.com/sampepose/flownet2-tf/ 77 | # blob/18f87081db44939414fc4a48834f9e0da3e69f4c/src/flowlib.py#L240 78 | def visulize_flow_file(flow_filename, save_dir=None): 79 | flow_data = readFlow(flow_filename) 80 | img = flow2img(flow_data) 81 | # plt.imshow(img) 82 | # plt.show() 83 | if save_dir: 84 | idx = flow_filename.rfind("/") + 1 85 | plt.imsave(os.path.join(save_dir, "%s-vis.png" % flow_filename[idx:-4]), img) 86 | 87 | 88 | def get_maxrad(flow_data): 89 | u = flow_data[:, :, 0] 90 | v = flow_data[:, :, 1] 91 | UNKNOW_FLOW_THRESHOLD = 1e7 92 | pr1 = abs(u) > UNKNOW_FLOW_THRESHOLD 93 | pr2 = abs(v) > UNKNOW_FLOW_THRESHOLD 94 | idx_unknown = (pr1 | pr2) 95 | u[idx_unknown] = v[idx_unknown] = 0 96 | rad = np.sqrt(u ** 2 + v ** 2) 97 | 98 | maxrad = max(-1, np.max(rad)) 99 | return maxrad 100 | 101 | 102 | def flow2img(flow_data, maxrad=None): 103 | """ 104 | convert optical flow into color image 105 | :param flow_data: 106 | :return: color image 107 | """ 108 | # print(flow_data.shape) 109 | # print(type(flow_data)) 110 | u = flow_data[:, :, 0] 111 | v = flow_data[:, :, 1] 112 | 113 | UNKNOW_FLOW_THRESHOLD = 1e7 114 | pr1 = abs(u) > UNKNOW_FLOW_THRESHOLD 115 | pr2 = abs(v) > UNKNOW_FLOW_THRESHOLD 116 | idx_unknown = (pr1 | pr2) 117 | u[idx_unknown] = v[idx_unknown] = 0 118 | 119 | # get max value in each direction 120 | ''' 121 | maxu = -999. 122 | maxv = -999. 123 | minu = 999. 124 | minv = 999. 
125 | maxu = max(maxu, np.max(u)) 126 | maxv = max(maxv, np.max(v)) 127 | minu = min(minu, np.min(u)) 128 | minv = min(minv, np.min(v)) 129 | ''' 130 | 131 | rad = np.sqrt(u ** 2 + v ** 2) 132 | if maxrad is None: 133 | maxrad = max(-1, np.max(rad)) 134 | u = u / maxrad + np.finfo(float).eps 135 | v = v / maxrad + np.finfo(float).eps 136 | 137 | img = compute_color(u, v) 138 | 139 | idx = np.repeat(idx_unknown[:, :, np.newaxis], 3, axis=2) 140 | img[idx] = 0 141 | 142 | return img 143 | 144 | 145 | def compute_color(u, v): 146 | """ 147 | compute optical flow color map 148 | :param u: horizontal optical flow 149 | :param v: vertical optical flow 150 | :return: 151 | """ 152 | 153 | height, width = u.shape 154 | img = np.zeros((height, width, 3)) 155 | 156 | NAN_idx = np.isnan(u) | np.isnan(v) 157 | u[NAN_idx] = v[NAN_idx] = 0 158 | 159 | colorwheel = make_color_wheel() 160 | ncols = np.size(colorwheel, 0) 161 | 162 | rad = np.sqrt(u ** 2 + v ** 2) 163 | rad = np.where(rad > 1, 1, rad) 164 | 165 | a = np.arctan2(-v, -u) / np.pi 166 | 167 | fk = (a + 1) / 2 * (ncols - 1) + 1 168 | 169 | k0 = np.floor(fk).astype(int) 170 | 171 | k1 = k0 + 1 172 | k1[k1 == ncols + 1] = 1 173 | f = fk - k0 174 | 175 | for i in range(0, np.size(colorwheel, 1)): 176 | tmp = colorwheel[:, i] 177 | col0 = tmp[k0 - 1] / 255 178 | col1 = tmp[k1 - 1] / 255 179 | col = (1 - f) * col0 + f * col1 180 | 181 | idx = rad <= 1 182 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 183 | notidx = np.logical_not(idx) 184 | 185 | col[notidx] *= 0.75 186 | img[:, :, i] = np.uint8(np.floor(255 * col * (1 - NAN_idx))) 187 | 188 | return img 189 | 190 | 191 | def make_color_wheel(): 192 | """ 193 | Generate color wheel according Middlebury color code 194 | :return: Color wheel 195 | """ 196 | RY = 15 197 | YG = 6 198 | GC = 4 199 | CB = 11 200 | BM = 13 201 | MR = 6 202 | 203 | ncols = RY + YG + GC + CB + BM + MR 204 | 205 | colorwheel = np.zeros([ncols, 3]) 206 | 207 | col = 0 208 | 209 | # RY 210 | colorwheel[0:RY, 0] = 255 211 | colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY)) 212 | col += RY 213 | 214 | # YG 215 | colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG)) 216 | colorwheel[col:col + YG, 1] = 255 217 | col += YG 218 | 219 | # GC 220 | colorwheel[col:col + GC, 1] = 255 221 | colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC)) 222 | col += GC 223 | 224 | # CB 225 | colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB)) 226 | colorwheel[col:col + CB, 2] = 255 227 | col += CB 228 | 229 | # BM 230 | colorwheel[col:col + BM, 2] = 255 231 | colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM)) 232 | col += + BM 233 | 234 | # MR 235 | colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) 236 | colorwheel[col:col + MR, 0] = 255 237 | 238 | return colorwheel 239 | -------------------------------------------------------------------------------- /util/util_html.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from os.path import join, dirname, basename 16 | from glob import glob 17 | from os import makedirs 18 | 19 | 20 | class Webpage(): 21 | WEB_TEMPLATE = """ 22 | 23 | 24 | 25 | 30 | 32 | 34 | 36 | 37 | 38 | 40 | 41 | 42 | 44 | 45 | 61 | 62 | 63 | 64 | {body_content} 65 | 66 | 67 | """ 68 | image_tag_template = "" 69 | table_template = """ 70 | 71 | 72 | 73 | {table_header} 74 | 75 | 76 | 77 | 78 | {table_body} 79 | 80 |
81 | """ 82 | 83 | def __init__(self, notable=False): 84 | self.content = self.WEB_TEMPLATE 85 | self.table_content = self.table_template 86 | self.video_content = '' 87 | if not notable: 88 | self.devider = f'
data table

' 89 | else: 90 | self.devider = '' 91 | 92 | def add_image_table_from_folder(self, path, img_prefixes, keys=None, rel_path='./'): 93 | if keys is None: 94 | keys = img_prefixes 95 | header = '' 96 | for k in keys: 97 | header += f"{k}\n" 98 | content = "" 99 | file_lists = {} 100 | 101 | for prefix in img_prefixes: 102 | file_lists[prefix] = sorted(glob(join(path, prefix + '*'))) 103 | l = len(file_lists[prefix]) 104 | for i in range(l): 105 | content += "\n" 106 | for k in file_lists.keys(): 107 | link = join(rel_path, basename(file_lists[k][i])) 108 | content += f"\n" 109 | content += "\n" 110 | self.table_content = self.table_content.format(table_header=header, table_body=content) 111 | 112 | def add_video(self, rel_path_to_video, title=''): 113 | video_tag = f'
{title}


' 114 | self.video_content += video_tag 115 | 116 | def add_div(self, div_string): 117 | self.video_content += div_string 118 | 119 | def save(self, path): 120 | content = self.content.format(body_content=self.video_content + self.devider + self.table_content) 121 | d = dirname(path) 122 | makedirs(d, exist_ok=True) 123 | with open(path, 'w') as f: 124 | f.write(content) 125 | return 126 | -------------------------------------------------------------------------------- /util/util_imageIO.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from PIL import Image 16 | import numpy as np 17 | from skimage.transform import resize as imresize 18 | 19 | 20 | def read_image(path, load_alpha=False): 21 | im = np.asarray(Image.open(path)) 22 | dims = len(im.shape) 23 | if dims == 2: 24 | return im 25 | elif dims == 3: 26 | if im.shape[-1] == 3: 27 | return im 28 | elif load_alpha: 29 | return im 30 | else: 31 | return im[..., :3] 32 | else: 33 | raise ValueError(f'invalid dimensions encoutered. Only except dims 2,3 but encoutered {dims}') 34 | 35 | 36 | def resize_image(im, size=None, scale=None): 37 | H, W = im.shape[:2] 38 | if scale: 39 | th = H // scale 40 | tw = W // scale 41 | s = (th, tw) 42 | else: 43 | s = size 44 | im = imresize(im, s) 45 | return im 46 | 47 | 48 | def hwc2chw(im): 49 | dims = len(im.shape) 50 | if dims == 2: 51 | return im[None, ...] 52 | elif dims == 3: 53 | return np.transpose(im, (2, 0, 1)) 54 | else: 55 | raise ValueError(f'invalid dimensions encoutered. Only except dims 2,3 but encoutered {dims}') 56 | -------------------------------------------------------------------------------- /util/util_loadlib.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
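# Illustrative usage sketch for this module (not part of the original file): the
# helpers defined below are typically called once at program start-up.
def _loadlib_demo(gpu_id='0', seed=0):
    set_gpu(gpu_id)        # pins torch.cuda to the requested device and disables cudnn benchmarking
    set_manual_seed(seed)  # seeds the python, numpy and torch RNGs for reproducibility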
14 | 15 | import subprocess 16 | from .util_print import str_warning, str_verbose 17 | import sys 18 | 19 | 20 | def set_gpu(gpu, check=True): 21 | import os 22 | import torch 23 | import torch.backends.cudnn as cudnn 24 | cudnn.benchmark = False 25 | torch.cuda.set_device(int(gpu)) 26 | 27 | 28 | def _check_gpu_setting_in_use(gpu): 29 | ''' 30 | check that CUDA_VISIBLE_DEVICES is actually working 31 | by starting a clean thread with the same CUDA_VISIBLE_DEVICES 32 | ''' 33 | import subprocess 34 | try: 35 | which_python = sys.executable 36 | output = subprocess.check_output( 37 | 'CUDA_VISIBLE_DEVICES=%s %s -c "import torch; print(torch.cuda.device_count())"' % (gpu, which_python), shell=True) 38 | except subprocess.CalledProcessError as e: 39 | print(str(e)) 40 | output = output.decode().strip() 41 | import torch 42 | return torch.cuda.device_count() == int(output) 43 | 44 | 45 | def _check_gpu(gpu): 46 | msg = subprocess.check_output( 47 | 'nvidia-smi --query-gpu=index,utilization.gpu,memory.used --format=csv,nounits,noheader -i %s' % (gpu,), shell=True) 48 | msg = msg.decode('utf-8') 49 | all_ok = True 50 | for line in msg.split('\n'): 51 | if line == '': 52 | break 53 | stats = [x.strip() for x in line.split(',')] 54 | gpu = stats[0] 55 | util = int(stats[1]) 56 | mem_used = int(stats[2]) 57 | if util > 10 or mem_used > 1000: # util in percentage and mem_used in MiB 58 | print(str_warning, 'Designated GPU in use: id=%s, util=%d%%, memory in use: %d MiB' % ( 59 | gpu, util, mem_used)) 60 | all_ok = False 61 | if all_ok: 62 | print(str_verbose, 'All designated GPU(s) free to use. ') 63 | 64 | 65 | def set_manual_seed(seed): 66 | import random 67 | random.seed(seed) 68 | try: 69 | import numpy as np 70 | np.random.seed(seed) 71 | except ImportError as err: 72 | print('Numpy not found. Random seed for numpy not set. ') 73 | try: 74 | import torch 75 | torch.manual_seed(seed) 76 | torch.cuda.manual_seed_all(seed) 77 | except ImportError as err: 78 | print('Pytorch not found. Random seed for pytorch not set. ') 79 | -------------------------------------------------------------------------------- /util/util_plot.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
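# Illustrative usage sketch for this module (not part of the original file): save a
# short RGB clip with tracked 2D points overlaid. It assumes a matplotlib animation
# writer (e.g. pillow for .gif or ffmpeg for .mp4) is available.
def _anim_demo(path='/tmp/clip.gif'):
    import torch
    frames = torch.rand(10, 48, 64, 3)    # [T, H, W, 3] image tensor with values in [0, 1]
    markers = torch.rand(5, 2, 10) * 48   # [N, (x, y), T] tracked point trajectories
    save_img_tensor_anim(frames, path, markers=markers, fps=5)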
14 | 15 | import matplotlib.pyplot as plt 16 | import matplotlib.animation as animation 17 | from third_party.util_colormap import turbo_colormap_data 18 | import matplotlib 19 | from matplotlib.colors import ListedColormap 20 | 21 | matplotlib.cm.register_cmap('turbo', cmap=ListedColormap(turbo_colormap_data)) 22 | 23 | 24 | def save_img_tensor_anim(img_tensor, path, markers=None, dpi=80, fps=30): 25 | fig = plt.figure(dpi=dpi) 26 | f = plt.imshow(img_tensor[0, ...].numpy()) 27 | if markers is not None: 28 | s = plt.plot(markers[:, 0, 0], markers[:, 1, 0], 'w+', markersize=5) 29 | plt.axis('off') 30 | plt.tight_layout(h_pad=0, w_pad=0) 31 | plt.margins(0.0) 32 | fig.subplots_adjust(0, 0, 1, 1) 33 | 34 | def anim(step): 35 | f.set_data(img_tensor[step, ...]) 36 | if markers is not None: 37 | s[0].set_data([markers[:, 0, step], markers[:, 1, step]]) 38 | all_steps = min(img_tensor.shape[0], markers.shape[-1]) 39 | line_ani = animation.FuncAnimation(fig, anim, all_steps, 40 | interval=1000 / fps, blit=False) 41 | line_ani.save(path) 42 | plt.close() 43 | 44 | 45 | def save_depth_tensor_anim(depth_tensor, path, minv=None, maxv=None, markers=None, dpi=80, fps=30): 46 | 47 | fig = plt.figure() 48 | f = plt.imshow(1 / depth_tensor[0, 0, ...].numpy(), cmap='turbo', vmax=1 / minv, vmin=1 / maxv) 49 | s = plt.plot(markers[:, 0, 0], markers[:, 1, 0], 'w+', markersize=5) 50 | 51 | def anim(step): 52 | f.set_data(1 / depth_tensor[step, 0, ...].numpy()) 53 | if markers is not None: 54 | s[0].set_data([markers[:, 0, step], markers[:, 1, step]]) 55 | 56 | all_steps = markers.shape[-1] 57 | 58 | line_ani = animation.FuncAnimation(fig, anim, all_steps, 59 | interval=1000 / fps, blit=False) 60 | line_ani.save(path) 61 | plt.close() 62 | -------------------------------------------------------------------------------- /util/util_print.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | class bcolors: 17 | HEADER = '\033[95m' 18 | OKBLUE = '\033[94m' 19 | OKGREEN = '\033[92m' 20 | WARNING = '\033[93m' 21 | FAIL = '\033[91m' 22 | ENDC = '\033[0m' 23 | BOLD = '\033[1m' 24 | UNDERLINE = '\033[4m' 25 | 26 | 27 | str_stage = bcolors.OKBLUE + '==>' + bcolors.ENDC 28 | str_verbose = bcolors.OKGREEN + '[Verbose]' + bcolors.ENDC 29 | str_warning = bcolors.WARNING + '[Warning]' + bcolors.ENDC 30 | str_error = bcolors.FAIL + '[Error]' + bcolors.ENDC 31 | -------------------------------------------------------------------------------- /util/util_visualize.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | from third_party.util_colormap import heatmap_to_pseudo_color 17 | KEYWORDS = ['depth', 'edgegradient', 'flow', 'img', 'rgb', 'image', 'edge', 'contour', 'softmask'] 18 | 19 | 20 | def detach_to_cpu(tensor): 21 | if type(tensor) == np.ndarray: 22 | return tensor 23 | else: 24 | if tensor.requires_grad: 25 | tensor.requires_grad = False 26 | tensor = tensor.cpu() 27 | return tensor.numpy() 28 | 29 | 30 | class Converter: 31 | def __init__(self): 32 | pass 33 | 34 | @staticmethod 35 | def depth2img(tensor, normalize=True, disparity=True, eps=1e-6, **kargs): 36 | t = detach_to_cpu(tensor) 37 | assert len(t.shape) == 4 38 | assert t.shape[1] == 1 39 | t = 1 / (t + eps) 40 | # if normalize: 41 | max_v = np.max(t, axis=(2, 3), keepdims=True) 42 | min_v = np.min(t, axis=(2, 3), keepdims=True) 43 | t = (t - min_v) / (max_v - min_v + eps) 44 | # return t 45 | # else: 46 | # return t 47 | cs = [] 48 | for b in range(t.shape[0]): 49 | c = heatmap_to_pseudo_color(t[b, 0, ...]) 50 | cs.append(c[None, ...]) 51 | cs = np.concatenate(cs, axis=0) 52 | cs = np.transpose(cs, [0, 3, 1, 2]) 53 | return cs 54 | 55 | @staticmethod 56 | def edge2img(tensor, normalize=True, eps=1e-6, **kargs): 57 | t = detach_to_cpu(tensor) 58 | if np.max(t) > 1 or np.min(t) < 0: 59 | t = 1 / (1 + np.exp(-t)) 60 | assert len(t.shape) == 4 61 | assert t.shape[1] == 1 62 | return t 63 | 64 | @staticmethod 65 | def image2img(tensor, **kargs): 66 | return Converter.img2img(tensor) 67 | 68 | @staticmethod 69 | def softmask2img(tensor, **kargs): 70 | t = detach_to_cpu(tensor) # [:, None, ...] 
71 | # t = #detach_to_cpu(tensor) 72 | return t 73 | 74 | @staticmethod 75 | def scenef2img(tensor, **kargs): 76 | t = detach_to_cpu(tensor.squeeze(3)) 77 | assert len(t.shape) == 4 78 | return np.linalg.norm(t, ord=1, axis=-1, keepdims=True) 79 | 80 | @staticmethod 81 | def rgb2img(tensor, **kargs): 82 | return Converter.img2img(tensor) 83 | 84 | @staticmethod 85 | def img2img(tensor, **kargs): 86 | t = detach_to_cpu(tensor) 87 | if np.min(t) < -0.1: 88 | t = (t + 1) / 2 89 | elif np.max(t) > 1.5: 90 | t = t / 255 91 | return t 92 | 93 | @staticmethod 94 | def edgegradient2img(tensor, **kargs): 95 | t = detach_to_cpu(tensor) 96 | mag = np.max(abs(t)) 97 | positive = np.where(t > 0, t, 0) 98 | positive /= mag 99 | negative = np.where(t < 0, abs(t), 0) 100 | negative /= mag 101 | rgb = np.concatenate((positive, negative, np.zeros(negative.shape)), axis=1) 102 | return rgb 103 | 104 | @staticmethod 105 | def flow2img(tensor, **kargs): 106 | t = detach_to_cpu(tensor) 107 | return t 108 | 109 | 110 | def convert2rgb(tensor, key, **kargs): 111 | found = False 112 | for k in KEYWORDS: 113 | if k in key: 114 | convert = getattr(Converter, k + '2img') 115 | found = True 116 | break 117 | if not found: 118 | return None 119 | else: 120 | return convert(tensor, **kargs) 121 | 122 | 123 | def is_key_image(key): 124 | """check if the given key correspondes to images 125 | 126 | Arguments: 127 | key {str} -- key of a data pack 128 | 129 | Returns: 130 | bool -- [True if the given key correspondes to an image] 131 | """ 132 | 133 | for k in KEYWORDS: 134 | if k in key: 135 | return True 136 | return False 137 | 138 | 139 | def parse_key(key): 140 | rkey = None 141 | found = False 142 | mode = None 143 | for k in KEYWORDS: 144 | if k in key: 145 | rkey = k 146 | found = True 147 | break 148 | if 'pred' in key: 149 | mode = 'pred' 150 | elif 'gt' in key: 151 | mode = 'gt' 152 | if not found: 153 | return None, None 154 | else: 155 | return rkey, mode 156 | -------------------------------------------------------------------------------- /visualize/base_visualizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
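# Illustrative sketch (not part of the original file): a concrete visualizer only
# needs to override the static _visualize hook of BaseVisualizer defined below; the
# argument order matches the apply_async call in visualize(). Wrapped in a factory so
# the example class is not created before BaseVisualizer exists at import time.
def _make_print_visualizer():
    class PrintVisualizer(BaseVisualizer):
        @staticmethod
        def _visualize(pack, batch_idx, outdir, cfg):
            # stand-in body; a real subclass would render and save images under outdir
            print(batch_idx, sorted(pack.keys()), outdir)
    return PrintVisualizer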
14 | 15 | from torch.multiprocessing import Pool 16 | from os.path import join, dirname 17 | from os import makedirs 18 | import atexit 19 | from util.util_config import get_project_config 20 | 21 | 22 | class BaseVisualizer(): 23 | """ 24 | Async Visulization Worker 25 | """ 26 | 27 | def __init__(self, n_workers=4): 28 | # read global configs 29 | self.cfg = get_project_config() 30 | if n_workers == 0: 31 | pool = None 32 | elif n_workers > 0: 33 | pool = Pool(n_workers) 34 | else: 35 | raise ValueError(n_workers) 36 | self.pool = pool 37 | 38 | def cleanup(): 39 | if pool: 40 | pool.close() 41 | pool.join() 42 | atexit.register(cleanup) 43 | 44 | def visualize(self, pack, batch_idx, outdir): 45 | if self.pool: 46 | self.pool.apply_async( 47 | self._visualize, 48 | [pack, batch_idx, outdir, self.cfg], 49 | error_callback=self._error_callback 50 | ) 51 | else: 52 | self._visualize(pack, batch_idx, outdir, self.cfg) 53 | 54 | @staticmethod 55 | def _visualize(pack, batch_idx, param_f, outdir): 56 | # main visualiztion thread. 57 | raise NotImplementedError 58 | 59 | @staticmethod 60 | def _error_callback(e): 61 | print(str(e)) 62 | -------------------------------------------------------------------------------- /visualize/html_visualizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from util.util_visualize import convert2rgb, is_key_image 16 | from torch.multiprocessing import Pool 17 | from util.util_flow import flow2img 18 | from os.path import join 19 | import atexit 20 | import numpy as np 21 | from PIL import Image 22 | 23 | 24 | class HTMLVisualizer(): 25 | """ 26 | Async Visulization Worker 27 | """ 28 | 29 | def __init__(self, html_logger, n_workers=4): 30 | # read global configs 31 | if n_workers == 0: 32 | pool = None 33 | elif n_workers > 0: 34 | pool = Pool(n_workers) 35 | else: 36 | raise ValueError(n_workers) 37 | self.pool = pool 38 | self.html_logger = html_logger 39 | self.header_lut = None 40 | 41 | def cleanup(): 42 | if pool: 43 | pool.close() 44 | pool.join() 45 | atexit.register(cleanup) 46 | 47 | def visualize(self, pack, batch_idx, outdir, is_test=False): 48 | # first append to the shared content 49 | # then launch the subprocesses for dumping images. 
50 | # b_size = pack['batch_size'] 51 | epoch_folder = outdir.split('/')[-1] 52 | if is_test: 53 | epoch_folder = None 54 | self.prepare_HTML_string(pack, batch_idx, epoch_folder) 55 | 56 | if self.pool: 57 | self.pool.apply_async( 58 | self._visualize, 59 | [pack, batch_idx, outdir], 60 | error_callback=self._error_callback 61 | ) 62 | else: 63 | self._visualize(pack, batch_idx, outdir) 64 | # prepare HTML string 65 | 66 | def prepare_HTML_string(self, pack, batch_idx, epoch_folder): 67 | if self.html_logger.epoch_content is None: 68 | self.html_logger.epoch_content = {} 69 | #self.html_logger.epoch_content['header'] = '' 70 | header = '' 71 | for k in sorted(list(pack.keys())): 72 | # get the ones that are useful: 73 | if is_key_image(k): 74 | header += f"{k}\n" 75 | self.html_logger.epoch_content['header'] = header 76 | self.html_logger.epoch_content['content'] = '' 77 | content = '' 78 | batch_size = pack['batch_size'] 79 | tags = pack['tags'] if 'tags' in pack.keys() else None 80 | for b in range(batch_size): 81 | content += "\n" 82 | for k in sorted(list(pack.keys())): 83 | # get the ones that are useful: 84 | if is_key_image(k): 85 | if tags is not None: 86 | prefix = '%s_%s.png' % (k, tags[b]) 87 | else: 88 | prefix = '%s_%04d_%04d.png' % (k, batch_idx, b) 89 | if epoch_folder is not None: 90 | link = join(epoch_folder, prefix) 91 | else: 92 | link = prefix 93 | # html is at outdir/../, so link is epochXXX/ 94 | content += f"\n" 95 | content += "\n" 96 | self.html_logger.epoch_content['content'] += content 97 | 98 | @staticmethod 99 | def _visualize(pack, batch_idx, outdir): 100 | # this thread saves the packed tensor into individual images 101 | batch_size = pack['batch_size'] 102 | tags = pack['tags'] if 'tags' in pack.keys() else None 103 | for k, v in pack.items(): 104 | rgb_tensor = convert2rgb(v, k) 105 | if rgb_tensor is None: 106 | continue 107 | for b in range(batch_size): 108 | if 'flow' in k: 109 | img = flow2img(rgb_tensor[b, :, :, :]) / 255 110 | else: 111 | img = np.squeeze(np.transpose(rgb_tensor[b, :, :, :], (1, 2, 0))) 112 | if tags is not None: 113 | prefix = '%s_%s.png' % (k, tags[b]) 114 | else: 115 | prefix = '%s_%04d_%04d.png' % (k, batch_idx, b) 116 | Image.fromarray((img * 255).astype(np.uint8)).save(join(outdir, prefix), 'PNG') 117 | 118 | @staticmethod 119 | def _error_callback(e): 120 | print(str(e)) 121 | --------------------------------------------------------------------------------
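A minimal end-to-end sketch of how the visualization utilities above fit together (illustrative only, assuming it is run from the repository root; the logger object is a stand-in for the HTML logger in loggers/loggers.py, and all paths are hypothetical):

import types
from os import makedirs
import torch
from visualize.html_visualizer import HTMLVisualizer

outdir = '/tmp/viz/epoch0000'
makedirs(outdir, exist_ok=True)
logger = types.SimpleNamespace(epoch_content=None)   # stand-in for the real HTML logger
viz = HTMLVisualizer(logger, n_workers=0)            # n_workers=0 runs _visualize in-process
pack = {
    'batch_size': 1,
    'img_pred': torch.rand(1, 3, 32, 32),            # keys containing 'img' are saved as RGB
    'depth_gt': torch.rand(1, 1, 32, 32) + 0.5,      # keys containing 'depth' use the turbo LUT
}
viz.visualize(pack, batch_idx=0, outdir=outdir)      # writes PNGs and fills logger.epoch_content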