├── .gitmodules
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── assets
│   ├── depth_train.gif
│   ├── depth_train_init.gif
│   ├── dog_depth.gif
│   ├── dog_depth_init.gif
│   ├── rgb_dog.gif
│   ├── rgb_train.gif
│   └── teaser.gif
├── configs
│   └── __init__.py
├── datasets
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── davis_sequence.py
│   └── shutterstock.py
├── dependencies
│   ├── conda_packages.txt
│   └── requirements.txt
├── experiments
│   ├── davis
│   │   ├── test_cmd.txt
│   │   └── train_sequence.sh
│   └── shutterstock
│       ├── test_cmd.txt
│       └── train_sequence.sh
├── loggers
│   ├── Progbar.py
│   ├── __init__.py
│   ├── html_template.py
│   └── loggers.py
├── losses
│   ├── __init__.py
│   └── scene_flow_projection.py
├── models
│   ├── __init__.py
│   ├── netinterface.py
│   ├── scene_flow_motion_field.py
│   └── video_base.py
├── networks
│   ├── FCNUnet.py
│   ├── MLP.py
│   ├── __init__.py
│   ├── blocks.py
│   └── sceneflow_field.py
├── options
│   ├── __init__.py
│   ├── options_test.py
│   └── options_train.py
├── scripts
│   ├── download_data_and_depth_ckpt.sh
│   ├── download_triangulation_files.sh
│   └── preprocess
│       ├── davis
│       │   ├── generate_flows.py
│       │   ├── generate_frame_midas.py
│       │   └── generate_sequence_midas.py
│       └── shutterstock
│           ├── generate_flows.py
│           ├── generate_frame_midas.py
│           └── generate_sequence_midas.py
├── test.py
├── third_party
│   ├── MiDaS.py
│   ├── __init__.py
│   ├── hourglass.py
│   ├── midas_blocks.py
│   └── util_colormap.py
├── train.py
├── util
│   ├── __init__.py
│   ├── util_config.py
│   ├── util_flow.py
│   ├── util_html.py
│   ├── util_imageIO.py
│   ├── util_loadlib.py
│   ├── util_plot.py
│   ├── util_print.py
│   └── util_visualize.py
└── visualize
    ├── base_visualizer.py
    └── html_visualizer.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/RAFT"]
2 | path = third_party/RAFT
3 | url = https://github.com/princeton-vl/RAFT.git
4 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code Reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Consistent Depth of Moving Objects in Video
2 |
3 | ![teaser](./assets/teaser.gif)
4 |
5 | This repository contains training code for the SIGGRAPH 2021 paper
6 | "[Consistent Depth of Moving Objects in
7 | Video](https://dynamic-video-depth.github.io/)".
8 |
9 | This is not an officially supported Google product.
10 |
11 | ## Installing Dependencies
12 |
13 | We provide both conda and pip installations for dependencies.
14 |
15 | - To install with conda, run
16 |
17 | ```
18 | conda create --name dynamic-video-depth --file ./dependencies/conda_packages.txt
19 | ```
20 |
21 | - To install with pip, run
22 |
23 | ```
24 | pip install -r ./dependencies/requirements.txt
25 | ```
26 |
27 |
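As a quick sanity check (optional, not part of this repository), you can verify that the pinned PyTorch build from `requirements.txt` imports correctly and sees your GPU:

```
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
# expected output on a CUDA-capable machine: 1.9.0 True
```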
28 |
29 | ## Training
30 | We provide two preprocessed video tracks from the DAVIS dataset. To download the pre-trained single-image depth prediction checkpoints, as well as the example data, run:
31 |
32 |
33 | ```
34 | bash ./scripts/download_data_and_depth_ckpt.sh
35 | ```
36 |
37 | This script will automatically download and unzip the checkpoints and data. To download manually, use [this link](https://drive.google.com/drive/folders/19_hbgJ9mettcbMQBYYnH0seiUaREZD1D?usp=sharing).
38 |
39 | To train using the example data, run:
40 |
41 | ```
42 | bash ./experiments/davis/train_sequence.sh 0 --track_id dog
43 | ```
44 |
45 | The first argument indicates the GPU id for training, and `--track_id` indicates the name of the track (`dog` and `train` are provided).
46 |
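Flags placed after the GPU id are forwarded to `train.py` (the script appends `$*` to its command line), and a repeated flag overrides the default baked into the script. For example, to train the other provided track with a longer schedule:

```
bash ./experiments/davis/train_sequence.sh 0 --track_id train --epoch 30
```
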
47 | After training, the results should look like:
48 |
49 | | Video | Our Depth | Single Image Depth |
50 | | :----: | :----: | :----: |
51 | | ![rgb_dog](./assets/rgb_dog.gif) | ![dog_depth](./assets/dog_depth.gif) | ![dog_depth_init](./assets/dog_depth_init.gif) |
52 | | ![rgb_train](./assets/rgb_train.gif) | ![depth_train](./assets/depth_train.gif) | ![depth_train_init](./assets/depth_train_init.gif) |
53 |
54 |
55 | ## Dataset Preparation:
56 |
57 | To help with generating custom datasets for training, we provide examples of preparing datasets from DAVIS, as well as from two ShutterStock sequences showcased in our paper.
58 |
59 | The general workflow for preprocessing a dataset is:
60 |
61 | 1. Calibrate the scale of camera translation, transform the camera matrices into camera-to-world convention, and save as individual files.
62 |
63 | 2. Calculate flow between pairs of frames, as well as occlusion estimates.
64 |
65 | 3. Pack flow and per-frame data into training batches.
66 |
67 | More specifically, example code is provided in `./scripts/preprocess`; a sketch of the expected per-frame output follows below.
68 |
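As a sketch of what step 1 should produce, each frame ends up as one `.npz` file. The key names below are the ones read back by `datasets/davis_sequence.py` and `datasets/shutterstock.py`; the file name and resolution are placeholders:

```
import numpy as np

# Hypothetical per-frame record; keys match what the dataset loaders read.
np.savez(
    './datafiles/davis_processed/frames_midas/dog/frame_0000.npz',
    img=np.zeros([288, 512, 3], np.float32),       # RGB frame, H x W x 3
    depth_pred=np.zeros([288, 512], np.float32),   # single-image (MiDaS) depth initialization
    depth_mvs=np.zeros([288, 512], np.float32),    # triangulated depth used for calibration
    pose_c2w=np.eye(4, dtype=np.float32),          # camera-to-world extrinsics
    intrinsics=np.eye(3, dtype=np.float32),        # camera intrinsics K
)
```
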
69 | We provide the triangulation results [here](https://drive.google.com/file/d/1U07e9xtwYbBZPpJ2vfsLaXYMWATt4XyB/view?usp=sharing) and [here](https://drive.google.com/file/d/1om58tVKujaq1Jo_ShpKc4sWVAWBoKY6U/view?usp=sharing). You can download them in a single script by running:
70 |
71 | ```
72 | bash ./scripts/download_triangulation_files.sh
73 | ```
74 |
75 | ### DAVIS data preparation
76 |
77 | 1. Download the DAVIS dataset here, and unzip it under `./datafiles`.
78 |
79 | 2. Run `python ./scripts/preprocess/davis/generate_frame_midas.py`. This requires `trimesh` to be installed (`pip install trimesh` should do the trick). This script projects the triangulated 3D points to calibrate camera translation scales.
80 |
81 | 3. Run `python ./scripts/preprocess/davis/generate_flows.py` to generate optical flows between pairs of images. This stage requires `RAFT`, which is included as a submodule in this repo.
82 |
83 |
84 | 4. Run `python ./scripts/preprocess/davis/generate_sequence_midas.py` to pack camera calibrations and images into training batches.
85 |
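After step 4, each training batch is a `.pt` file named following `shuffle_False_gap_{gap:02d}_*.pt` (see `datasets/davis_sequence.py`). Here is a small sketch to inspect one batch; the exact file suffix below is hypothetical:

```
import torch

batch = torch.load('./datafiles/davis_processed/sequences_select_pairs_midas/dog/001/shuffle_False_gap_01_0000.pt')
for k, v in batch.items():
    print(k, getattr(v, 'shape', type(v)))
# Training reads 'img_1'/'img_2' (N x H x W x 3, permuted to N x 3 x H x W at load time)
# and 'fid_1'/'fid_2' (frame ids used to compute normalized time stamps).
```
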
86 | ### ShutterStock Videos
87 |
88 |
89 | 1. Download the ShutterStock videos [here](https://www.shutterstock.com/video/clip-1058262031-loyal-golden-retriever-dog-running-across-green) and [here](https://www.shutterstock.com/nb/video/clip-1058781907-handsome-pedigree-cute-white-labrador-walking-on).
90 |
91 |
92 |
93 | 2. Extract the videos into individual frames, put the frames under `./datafiles/shutterstock/images`, and rename them to match the file names in `./datafiles/shutterstock/triangulation`. Note that not all frames are triangulated; the time stamps of the valid frames are recorded in the triangulation file names.
94 |
95 | 3. Run `python ./scripts/preprocess/shutterstock/generate_frame_midas.py` to pack per-frame data.
96 |
97 | 4. Run `python ./scripts/preprocess/shutterstock/generate_flows.py` to generate optical flows between pairs of images.
98 |
99 | 5. Run `python ./scripts/preprocess/shutterstock/generate_sequence_midas.py` to pack flows and per-frame data into training batches.
100 |
101 | 6. An example training script is located at `./experiments/shutterstock/train_sequence.sh`.
102 |
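As with the DAVIS example, the first argument is the GPU id and any extra flags are forwarded to `train.py`. The script hard-codes `--track_id 0`; since a later occurrence of the flag wins, pass `--track_id 1` to train the second sequence:

```
bash ./experiments/shutterstock/train_sequence.sh 0 --track_id 1
```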
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/assets/depth_train.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/dynamic-video-depth/79177ef5941b15b0aa2395b626f922fcb7b4c179/assets/depth_train.gif
--------------------------------------------------------------------------------
/assets/depth_train_init.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/dynamic-video-depth/79177ef5941b15b0aa2395b626f922fcb7b4c179/assets/depth_train_init.gif
--------------------------------------------------------------------------------
/assets/dog_depth.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/dynamic-video-depth/79177ef5941b15b0aa2395b626f922fcb7b4c179/assets/dog_depth.gif
--------------------------------------------------------------------------------
/assets/dog_depth_init.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/dynamic-video-depth/79177ef5941b15b0aa2395b626f922fcb7b4c179/assets/dog_depth_init.gif
--------------------------------------------------------------------------------
/assets/rgb_dog.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/dynamic-video-depth/79177ef5941b15b0aa2395b626f922fcb7b4c179/assets/rgb_dog.gif
--------------------------------------------------------------------------------
/assets/rgb_train.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/dynamic-video-depth/79177ef5941b15b0aa2395b626f922fcb7b4c179/assets/rgb_train.gif
--------------------------------------------------------------------------------
/assets/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/dynamic-video-depth/79177ef5941b15b0aa2395b626f922fcb7b4c179/assets/teaser.gif
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | depth_pretrain_path = './pretrained_depth_ckpt/best_depth_Ours_Bilinear_inc_3_net_G.pth'
16 | midas_pretrain_path = './pretrained_depth_ckpt/midas_cpkt.pt'
17 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import importlib
16 |
17 |
18 | def get_dataset(alias):
19 |     dataset_module = importlib.import_module('datasets.' + alias.lower())
20 |     return dataset_module.Dataset
21 |
--------------------------------------------------------------------------------
/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import torch.utils.data as data
17 | import torch
18 |
19 |
20 | class Dataset(data.Dataset):
21 |     @classmethod
22 |     def add_arguments(cls, parser):
23 |         return parser, set()
24 |
25 |     def __init__(self, opt, mode='train', model=None):
26 |         super().__init__()
27 |         self.opt = opt
28 |         self.mode = mode
29 |
30 |     @staticmethod
31 |     def convert_to_float32(sample_loaded):
32 |         for k, v in sample_loaded.items():
33 |             if isinstance(v, np.ndarray):
34 |                 sample_loaded[k] = torch.from_numpy(v).float()
35 |
--------------------------------------------------------------------------------
/datasets/davis_sequence.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | from .base_dataset import Dataset as base_dataset
17 | from glob import glob
18 | from os.path import join
19 | import torch
20 |
21 |
22 | class Dataset(base_dataset):
23 |
24 |     @classmethod
25 |     def add_arguments(cls, parser):
26 |         parser.add_argument('--cache', action='store_true', help='cache the data into ram')
27 |         parser.add_argument('--subsample', action='store_true', help='subsample the video in time')
28 |         parser.add_argument('--track_id', default='train', type=str, help='the track id to load')
29 |         parser.add_argument('--overfit', action='store_true', help='overfit and see if things work')
30 |         parser.add_argument('--gaps', type=str, default='1,2,3,4', help='gaps for sequences')
31 |         parser.add_argument('--repeat', type=int, default=1, help='number of repetitions')
32 |         parser.add_argument('--select', action='store_true', help='pred')
33 |         return parser, set()
34 |
35 |     def __init__(self, opt, mode='train', model=None):
36 |         super().__init__(opt, mode, model)
37 |         self.mode = mode
38 |         assert mode in ('train', 'vali')
39 |
40 |         data_root = './datafiles/davis_processed'
41 |         # tracks = sorted(glob(join(data_root, 'frames_midas', '*')))
42 |         # tracks = [x.split('/')[-1] for x in tracks]
43 |         track_name = opt.track_id  # tracks[opt.track_id]
44 |         if model is None:
45 |             self.required = ['img', 'flow']
46 |             self.preproc = None
47 |         elif mode == 'train':
48 |             self.required = model.requires
49 |             self.preproc = model.preprocess
50 |         else:
51 |             self.required = ['img']
52 |             self.preproc = model.preprocess
53 |
54 |         frame_prefix = 'frames_midas'
55 |         seq_prefix = 'sequences_select_pairs_midas'
56 |
57 |         if mode == 'train':
58 |
59 |             if self.opt.subsample:
60 |                 data_path = join(data_root, seq_prefix, track_name, 'subsample')
61 |             else:
62 |                 data_path = join(data_root, seq_prefix, track_name, '%03d' % 1)
63 |
64 |             gaps = opt.gaps.split(',')
65 |             gaps = [int(x) for x in gaps]
66 |             self.file_list = []
67 |             for g in gaps:
68 |
69 |                 file_list = sorted(glob(join(data_path, f'shuffle_False_gap_{g:02d}_*.pt')))
70 |                 self.file_list += file_list
71 |
72 |             frame_data_path = join(data_root, frame_prefix, track_name)
73 |             self.n_frames = len(sorted(glob(join(frame_data_path, '*.npz')))) + 0.0
74 |
75 |         else:
76 |             data_path = join(data_root, frame_prefix, track_name)
77 |             self.file_list = sorted(glob(join(data_path, '*.npz')))
78 |             self.n_frames = len(self.file_list) + 0.0
79 |
80 |     def __len__(self):
81 |         if self.mode != 'train':
82 |             return len(self.file_list)
83 |         else:
84 |             return len(self.file_list) * self.opt.repeat
85 |
86 |     def __getitem__(self, idx):
87 |         sample_loaded = {}
88 |         if self.opt.overfit:
89 |             idx = idx % self.opt.capat
90 |         else:
91 |             idx = idx % len(self.file_list)
92 |
93 |         if self.opt.subsample:
94 |             unit = 2.0
95 |         else:
96 |             unit = 1.0
97 |
98 |         if self.mode == 'train':
99 |
100 |             dataset = torch.load(self.file_list[idx])
101 |
102 |             _, H, W, _ = dataset['img_1'].shape
103 |             dataset['img_1'] = dataset['img_1'].permute([0, 3, 1, 2])
104 |             dataset['img_2'] = dataset['img_2'].permute([0, 3, 1, 2])
105 |             ts = dataset['fid_1'].reshape([-1, 1, 1, 1]).expand(-1, -1, H, W) / self.n_frames
106 |             ts2 = dataset['fid_2'].reshape([-1, 1, 1, 1]).expand(-1, -1, H, W) / self.n_frames
107 |             for k in dataset:
108 |                 if type(dataset[k]) == list:
109 |                     continue
110 |                 sample_loaded[k] = dataset[k].float()
111 |             sample_loaded['time_step'] = unit / self.n_frames
112 |             sample_loaded['time_stamp_1'] = ts.float()
113 |             sample_loaded['time_stamp_2'] = ts2.float()
114 |             sample_loaded['frame_id_1'] = np.asarray(dataset['fid_1'])
115 |             sample_loaded['frame_id_2'] = np.asarray(dataset['fid_2'])
116 |
117 |         else:
118 |             dataset = np.load(self.file_list[idx])
119 |             H, W, _ = dataset['img'].shape
120 |             sample_loaded['time_stamp_1'] = np.ones([1, H, W]) * idx / self.n_frames
121 |             sample_loaded['img'] = np.transpose(dataset['img'], [2, 0, 1])
122 |             sample_loaded['frame_id_1'] = idx
123 |
124 |             sample_loaded['time_step'] = unit / self.n_frames
125 |             sample_loaded['depth_pred'] = dataset['depth_pred'][None, ...]
126 |             sample_loaded['cam_c2w'] = dataset['pose_c2w']
127 |             sample_loaded['K'] = dataset['intrinsics']
128 |             sample_loaded['depth_mvs'] = dataset['depth_mvs'][None, ...]
129 |             # add decomposed cam mat
130 |             cam_pose_c2w_1 = dataset['pose_c2w']
131 |             R_1 = cam_pose_c2w_1[:3, :3]
132 |             t_1 = cam_pose_c2w_1[:3, 3]
133 |             K = dataset['intrinsics']
134 |
135 |             # for network use:
136 |             R_1_tensor = np.zeros([1, 1, 3, 3])
137 |             R_1_T_tensor = np.zeros([1, 1, 3, 3])
138 |             t_1_tensor = np.zeros([1, 1, 1, 3])
139 |             K_tensor = np.zeros([1, 1, 3, 3])
140 |             K_inv_tensor = np.zeros([1, 1, 3, 3])
141 |             R_1_tensor[..., :, :] = R_1.T
142 |             R_1_T_tensor[..., :, :] = R_1
143 |             t_1_tensor[..., :] = t_1
144 |             K_tensor[..., :, :] = K.T
145 |             K_inv_tensor[..., :, :] = np.linalg.inv(K).T
146 |
147 |             sample_loaded['R_1'] = R_1_tensor
148 |             sample_loaded['R_1_T'] = R_1_T_tensor
149 |             sample_loaded['t_1'] = t_1_tensor
150 |             sample_loaded['K'] = K_tensor
151 |             sample_loaded['K_inv'] = K_inv_tensor
152 |         sample_loaded['pair_path'] = self.file_list[idx]
153 |         self.convert_to_float32(sample_loaded)
154 |         return sample_loaded
155 |
--------------------------------------------------------------------------------
/datasets/shutterstock.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | from .base_dataset import Dataset as base_dataset
17 | from glob import glob
18 | from os.path import join
19 | import torch
20 |
21 |
22 | class Dataset(base_dataset):
23 |
24 |     @classmethod
25 |     def add_arguments(cls, parser):
26 |         parser.add_argument('--cache', action='store_true', help='cache the data into ram')
27 |         parser.add_argument('--subsample', action='store_true', help='subsample the video in time')
28 |         parser.add_argument('--track_id', default=0, type=int, help='the track id to load')
29 |         parser.add_argument('--overfit', action='store_true', help='overfit and see if things work')
30 |         parser.add_argument('--gaps', type=str, default='1,2,3,4', help='gaps for sequences')
31 |         parser.add_argument('--repeat', type=int, default=1, help='number of repetitions')
32 |         parser.add_argument('--select', action='store_true', help='pred')
33 |         return parser, set()
34 |
35 |     def __init__(self, opt, mode='train', model=None):
36 |         super().__init__(opt, mode, model)
37 |         self.mode = mode
38 |         assert mode in ('train', 'vali')
39 |
40 |         data_root = './datafiles/shutterstock'
41 |         tracks = sorted(glob(join(data_root, 'frames_midas', '*')))
42 |         tracks = [x.split('/')[-1] for x in tracks]
43 |         track_name = tracks[opt.track_id]
44 |         if model is None:
45 |             self.required = ['img', 'flow']
46 |             self.preproc = None
47 |         elif mode == 'train':
48 |             self.required = model.requires
49 |             self.preproc = model.preprocess
50 |         else:
51 |             self.required = ['img']
52 |             self.preproc = model.preprocess
53 |
54 |         frame_prefix = 'frames_midas'
55 |         seq_prefix = 'sequences_select_pairs_midas'
56 |
57 |         if mode == 'train':
58 |
59 |             if self.opt.subsample:
60 |                 data_path = join(data_root, seq_prefix, track_name, 'subsample')
61 |             else:
62 |                 data_path = join(data_root, seq_prefix, track_name, '%03d' % 1)
63 |
64 |             gaps = opt.gaps.split(',')
65 |             gaps = [int(x) for x in gaps]
66 |             self.file_list = []
67 |             for g in gaps:
68 |
69 |                 file_list = sorted(glob(join(data_path, f'shuffle_False_gap_{g:02d}_*.pt')))
70 |                 self.file_list += file_list
71 |
72 |             frame_data_path = join(data_root, frame_prefix, track_name)
73 |             self.n_frames = len(sorted(glob(join(frame_data_path, '*.npz')))) + 0.0
74 |
75 |         else:
76 |             data_path = join(data_root, frame_prefix, track_name)
77 |             self.file_list = sorted(glob(join(data_path, '*.npz')))
78 |             self.n_frames = len(self.file_list) + 0.0
79 |
80 |     def __len__(self):
81 |         if self.mode != 'train':
82 |             return len(self.file_list)
83 |         else:
84 |             return len(self.file_list) * self.opt.repeat
85 |
86 |     def __getitem__(self, idx):
87 |         sample_loaded = {}
88 |         if self.opt.overfit:
89 |             idx = idx % self.opt.capat
90 |         else:
91 |             idx = idx % len(self.file_list)
92 |
93 |         if self.opt.subsample:
94 |             unit = 2.0
95 |         else:
96 |             unit = 1.0
97 |
98 |         if self.mode == 'train':
99 |
100 |             dataset = torch.load(self.file_list[idx])
101 |
102 |             _, H, W, _ = dataset['img_1'].shape
103 |             dataset['img_1'] = dataset['img_1'].permute([0, 3, 1, 2])
104 |             dataset['img_2'] = dataset['img_2'].permute([0, 3, 1, 2])
105 |             ts = dataset['fid_1'].reshape([-1, 1, 1, 1]).expand(-1, -1, H, W) / self.n_frames
106 |             ts2 = dataset['fid_2'].reshape([-1, 1, 1, 1]).expand(-1, -1, H, W) / self.n_frames
107 |             for k in dataset:
108 |                 if type(dataset[k]) == list:
109 |                     continue
110 |                 sample_loaded[k] = dataset[k].float()
111 |             sample_loaded['time_step'] = unit / self.n_frames
112 |             sample_loaded['time_stamp_1'] = ts.float()
113 |             sample_loaded['time_stamp_2'] = ts2.float()
114 |             sample_loaded['frame_id_1'] = np.asarray(dataset['fid_1'])
115 |             sample_loaded['frame_id_2'] = np.asarray(dataset['fid_2'])
116 |
117 |         else:
118 |             dataset = np.load(self.file_list[idx])
119 |             H, W, _ = dataset['img'].shape
120 |             sample_loaded['time_stamp_1'] = np.ones([1, H, W]) * idx / self.n_frames
121 |             sample_loaded['img'] = np.transpose(dataset['img'], [2, 0, 1])
122 |             sample_loaded['frame_id_1'] = idx
123 |
124 |             sample_loaded['time_step'] = unit / self.n_frames
125 |             sample_loaded['depth_pred'] = dataset['depth_pred'][None, ...]
126 |             sample_loaded['cam_c2w'] = dataset['pose_c2w']
127 |             sample_loaded['K'] = dataset['intrinsics']
128 |             sample_loaded['depth_mvs'] = dataset['depth_mvs'][None, ...]
129 |             # add decomposed cam mat
130 |             cam_pose_c2w_1 = dataset['pose_c2w']
131 |             R_1 = cam_pose_c2w_1[:3, :3]
132 |             t_1 = cam_pose_c2w_1[:3, 3]
133 |             K = dataset['intrinsics']
134 |
135 |             # for network use:
136 |             R_1_tensor = np.zeros([1, 1, 3, 3])
137 |             R_1_T_tensor = np.zeros([1, 1, 3, 3])
138 |             t_1_tensor = np.zeros([1, 1, 1, 3])
139 |             K_tensor = np.zeros([1, 1, 3, 3])
140 |             K_inv_tensor = np.zeros([1, 1, 3, 3])
141 |             R_1_tensor[..., :, :] = R_1.T
142 |             R_1_T_tensor[..., :, :] = R_1
143 |             t_1_tensor[..., :] = t_1
144 |             K_tensor[..., :, :] = K.T
145 |             K_inv_tensor[..., :, :] = np.linalg.inv(K).T
146 |
147 |             sample_loaded['R_1'] = R_1_tensor
148 |             sample_loaded['R_1_T'] = R_1_T_tensor
149 |             sample_loaded['t_1'] = t_1_tensor
150 |             sample_loaded['K'] = K_tensor
151 |             sample_loaded['K_inv'] = K_inv_tensor
152 |         sample_loaded['pair_path'] = self.file_list[idx]
153 |         self.convert_to_float32(sample_loaded)
154 |         return sample_loaded
155 |
--------------------------------------------------------------------------------
/dependencies/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.13.0
2 | backcall==0.2.0
3 | cachetools==4.2.2
4 | certifi==2021.5.30
5 | charset-normalizer==2.0.4
6 | cycler==0.10.0
7 | decorator==4.4.2
8 | filelock==3.0.12
9 | gdown==3.13.0
10 | google-auth==1.34.0
11 | google-auth-oauthlib==0.4.5
12 | grpcio==1.39.0
13 | idna==3.2
14 | imageio==2.9.0
15 | ipykernel==5.3.4
16 | ipython==7.18.1
17 | ipython-genutils==0.2.0
18 | jedi==0.17.2
19 | jupyter-client==6.1.7
20 | jupyter-core==4.6.3
21 | kiwisolver==1.3.1
22 | Markdown==3.3.4
23 | matplotlib==3.4.2
24 | networkx==2.6.2
25 | numpy==1.21.1
26 | oauthlib==3.1.1
27 | opencv-python==4.5.3.56
28 | pandas==1.3.1
29 | parso==0.7.1
30 | pexpect==4.8.0
31 | pickleshare==0.7.5
32 | Pillow==8.3.1
33 | prompt-toolkit==3.0.7
34 | protobuf==3.17.3
35 | ptyprocess==0.6.0
36 | pyasn1==0.4.8
37 | pyasn1-modules==0.2.8
38 | Pygments==2.6.1
39 | pyparsing==2.4.7
40 | PySocks==1.7.1
41 | python-dateutil==2.8.2
42 | pytz==2021.1
43 | PyWavelets==1.1.1
44 | pyzmq==19.0.2
45 | requests==2.26.0
46 | requests-oauthlib==1.3.0
47 | rsa==4.7.2
48 | scikit-image==0.18.2
49 | scipy==1.7.1
50 | six==1.16.0
51 | tensorboard==2.6.0
52 | tensorboard-data-server==0.6.1
53 | tensorboard-plugin-wit==1.8.0
54 | tifffile==2021.8.8
55 | torch==1.9.0
56 | torchvision==0.10.0
57 | tornado==6.1
58 | tqdm==4.62.0
59 | traitlets==5.0.4
60 | trimesh==3.9.27
61 | typing-extensions==3.10.0.0
62 | urllib3==1.26.6
63 | wcwidth==0.2.5
64 | Werkzeug==2.0.1
65 |
--------------------------------------------------------------------------------
/experiments/davis/test_cmd.txt:
--------------------------------------------------------------------------------
1 | python test.py --net {net} --dataset {dataset} --workers 4 --output_dir './test_results/{dataset}/' --epoch {epoch} --html_logger --batch_size 1 --gpu {gpu} --track_id {track_id} --suffix {suffix_expand} --checkpoint_path {full_logdir}
--------------------------------------------------------------------------------
/experiments/davis/train_sequence.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | if [ $# -lt 1 ]; then
17 | echo "Usage: $0 gpu "
18 | exit 1
19 | fi
20 | gpu="$1"
21 | shift
22 | set -e
23 | cmd="
24 | python train.py \
25 | --net scene_flow_motion_field \
26 | --dataset davis_sequence \
27 | --track_id train \
28 | --log_time \
29 | --epoch_batches 2000 \
30 | --epoch 20 \
31 | --lr 1e-6 \
32 | --html_logger \
33 | --vali_batches 150 \
34 | --batch_size 1 \
35 | --optim adam \
36 | --vis_batches_vali 4 \
37 | --vis_every_vali 1 \
38 | --vis_every_train 1 \
39 | --vis_batches_train 5 \
40 | --vis_at_start \
41 | --tensorboard \
42 | --gpu "$gpu" \
43 | --save_net 1 \
44 | --workers 4 \
45 | --one_way \
46 | --loss_type l1 \
47 | --l1_mul 0 \
48 | --acc_mul 1 \
49 | --disp_mul 1 \
50 | --warm_sf 5 \
51 | --scene_lr_mul 1000 \
52 | --repeat 1 \
53 | --flow_mul 1 \
54 | --sf_mag_div 100 \
55 | --time_dependent \
56 | --gaps 1,2,4,6,8 \
57 | --midas \
58 | --use_disp \
59 | --logdir './checkpoints/davis/sequence/' \
60 | --suffix 'track_{track_id}_{loss_type}_wreg_{warm_reg}_acc_{acc_mul}_disp_{disp_mul}_flowmul_{flow_mul}_time_{time_dependent}_CNN_{use_cnn}_gap_{gaps}_Midas_{midas}_ud_{use_disp}' \
61 | --test_template './experiments/davis/test_cmd.txt' \
62 | --force_overwrite \
63 | $*"
64 | echo $cmd
65 | eval $cmd
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/experiments/shutterstock/test_cmd.txt:
--------------------------------------------------------------------------------
1 | python test.py --net {net} --dataset {dataset} --workers 4 --output_dir './test_results/{dataset}/' --epoch {epoch} --html_logger --batch_size 1 --gpu {gpu} --track_id {track_id} --suffix {suffix_expand} --checkpoint_path {full_logdir}
--------------------------------------------------------------------------------
/experiments/shutterstock/train_sequence.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2021 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | if [ $# -lt 1 ]; then
17 | echo "Usage: $0 gpu "
18 | exit 1
19 | fi
20 | gpu="$1"
21 | shift
22 | set -e
23 | cmd="
24 | python train.py \
25 | --net scene_flow_motion_field \
26 | --dataset shutterstock \
27 | --track_id 0 \
28 | --log_time \
29 | --epoch_batches 2000 \
30 | --epoch 20 \
31 | --lr 1e-6 \
32 | --html_logger \
33 | --vali_batches 150 \
34 | --batch_size 1 \
35 | --optim adam \
36 | --vis_batches_vali 4 \
37 | --vis_every_vali 1 \
38 | --vis_every_train 1 \
39 | --vis_batches_train 5 \
40 | --vis_at_start \
41 | --tensorboard \
42 | --gpu "$gpu" \
43 | --save_net 1 \
44 | --workers 4 \
45 | --one_way \
46 | --loss_type l1 \
47 | --l1_mul 0 \
48 | --acc_mul 1 \
49 | --disp_mul 1 \
50 | --warm_sf 5 \
51 | --scene_lr_mul 1000 \
52 | --repeat 1 \
53 | --flow_mul 1 \
54 | --sf_mag_div 100 \
55 | --time_dependent \
56 | --gaps 1,2,4,6,8 \
57 | --midas \
58 | --use_disp \
59 | --logdir './checkpoints/shutterstock/sequence/' \
60 | --suffix 'track_{track_id}_{loss_type}_wreg_{warm_reg}_acc_{acc_mul}_disp_{disp_mul}_flowmul_{flow_mul}_time_{time_dependent}_CNN_{use_cnn}_gap_{gaps}_Midas_{midas}_ud_{use_disp}' \
61 | --test_template './experiments/shutterstock/test_cmd.txt' \
62 | --force_overwrite \
63 | $*"
64 | echo $cmd
65 | eval $cmd
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/loggers/Progbar.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import sys
17 | import time
18 | import numpy as np
19 |
20 |
21 | class Progbar(object):
22 |     """Displays a progress bar.
23 |     # Arguments
24 |         target: Total number of steps expected, None if unknown.
25 |         interval: Minimum visual progress update interval (in seconds).
26 |     """
27 |
28 |     def __init__(self, target, width=30, verbose=1, interval=0.05, no_accum=False):
29 |         self.width = width
30 |         if target is None:
31 |             target = -1
32 |         self.target = target
33 |         self.sum_values = {}
34 |         self.step_values = {}
35 |         self.unique_values = []
36 |         self.start = time.time()
37 |         self.last_update = 0
38 |         self.interval = interval
39 |         self.total_width = 0
40 |         self.seen_so_far = 0
41 |         self.verbose = verbose
42 |         self.no_accum = no_accum
43 |
44 |     def update(self, current, values=None, force=False):
45 |         """Updates the progress bar.
46 |         # Arguments
47 |             current: Index of current step.
48 |             values: List of tuples (name, value_for_last_step).
49 |                 The progress bar will display averages for these values.
50 |             force: Whether to force visual progress update.
51 |         """
52 |         values = values or []
53 |         for k, v in values:
54 |             if k not in self.sum_values:
55 |                 self.sum_values[k] = [v * (current - self.seen_so_far),
56 |                                       current - self.seen_so_far]
57 |                 self.unique_values.append(k)
58 |                 self.step_values[k] = v
59 |             else:
60 |                 self.sum_values[k][0] += v * (current - self.seen_so_far)
61 |                 self.sum_values[k][1] += (current - self.seen_so_far)
62 |                 self.step_values[k] = v
63 |
64 |         self.seen_so_far = current
65 |
66 |         now = time.time()
67 |         if self.verbose == 1:
68 |             if not force and (now - self.last_update) < self.interval:
69 |                 return
70 |
71 |             prev_total_width = self.total_width
72 |             sys.stdout.write('\b' * prev_total_width)
73 |             sys.stdout.write('\r')
74 |
75 |             if self.target != -1:  # compare by value; 'is not' on an int literal is unreliable
76 |                 numdigits = int(np.floor(np.log10(self.target))) + 1
77 |                 barstr = '%%%dd/%%%dd [' % (numdigits, numdigits)
78 |                 bar = barstr % (current, self.target)
79 |                 prog = float(current) / self.target
80 |                 prog_width = int(self.width * prog)
81 |                 if prog_width > 0:
82 |                     bar += ('=' * (prog_width - 1))
83 |                     if current < self.target:
84 |                         bar += '>'
85 |                     else:
86 |                         bar += '='
87 |                 bar += ('.' * (self.width - prog_width))
88 |                 bar += ']'
89 |                 sys.stdout.write(bar)
90 |                 self.total_width = len(bar)
91 |
92 |             if current:
93 |                 time_per_unit = (now - self.start) / current
94 |             else:
95 |                 time_per_unit = 0
96 |             eta = time_per_unit * (self.target - current)
97 |             info = ''
98 |             if current < self.target and self.target != -1:
99 |                 info += ' - ETA: %ds' % eta
100 |             else:
101 |                 info += ' - %ds' % (now - self.start)
102 |             for k in self.unique_values:
103 |                 info += ' - %s:' % k
104 |                 if isinstance(self.sum_values[k], list):
105 |                     if self.no_accum:
106 |                         avg = self.step_values[k]
107 |                     else:
108 |                         avg = np.mean(
109 |                             self.sum_values[k][0] / max(1, self.sum_values[k][1]))
110 |                     if abs(avg) > 1e-3:
111 |                         info += ' %.4f' % avg
112 |                     else:
113 |                         info += ' %.4e' % avg
114 |                 else:
115 |                     info += ' %s' % self.sum_values[k]
116 |
117 |             self.total_width += len(info)
118 |             if prev_total_width > self.total_width:
119 |                 info += ((prev_total_width - self.total_width) * ' ')
120 |
121 |             sys.stdout.write(info)
122 |             sys.stdout.flush()
123 |
124 |             if current >= self.target:
125 |                 sys.stdout.write('\n')
126 |
127 |         if self.verbose == 2:
128 |             if current >= self.target:
129 |                 info = '%ds' % (now - self.start)
130 |                 for k in self.unique_values:
131 |                     info += ' - %s:' % k
132 |                     avg = np.mean(
133 |                         self.sum_values[k][0] / max(1, self.sum_values[k][1]))
134 |                     if avg > 1e-3:
135 |                         info += ' %.4f' % avg
136 |                     else:
137 |                         info += ' %.4e' % avg
138 |                 sys.stdout.write(info + "\n")
139 |
140 |         self.last_update = now
141 |
142 |     def add(self, n, values=None):
143 |         self.update(self.seen_so_far + n, values)
144 |
--------------------------------------------------------------------------------
/loggers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/loggers/html_template.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | TABLE_HEADER = """
16 |
17 |
'
114 |         self.video_content += video_tag
115 |
116 |     def add_div(self, div_string):
117 |         self.video_content += div_string
118 |
119 |     def save(self, path):
120 |         content = self.content.format(body_content=self.video_content + self.devider + self.table_content)
121 |         d = dirname(path)
122 |         makedirs(d, exist_ok=True)
123 |         with open(path, 'w') as f:
124 |             f.write(content)
125 |         return
126 |
--------------------------------------------------------------------------------
/util/util_imageIO.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from PIL import Image
16 | import numpy as np
17 | from skimage.transform import resize as imresize
18 |
19 |
20 | def read_image(path, load_alpha=False):
21 |     im = np.asarray(Image.open(path))
22 |     dims = len(im.shape)
23 |     if dims == 2:
24 |         return im
25 |     elif dims == 3:
26 |         if im.shape[-1] == 3:
27 |             return im
28 |         elif load_alpha:
29 |             return im
30 |         else:
31 |             return im[..., :3]
32 |     else:
33 |         raise ValueError(f'invalid dimensions encountered. Only dims 2 or 3 are accepted, but got {dims}')
34 |
35 |
36 | def resize_image(im, size=None, scale=None):
37 |     H, W = im.shape[:2]
38 |     if scale:
39 |         th = H // scale
40 |         tw = W // scale
41 |         s = (th, tw)
42 |     else:
43 |         s = size
44 |     im = imresize(im, s)
45 |     return im
46 |
47 |
48 | def hwc2chw(im):
49 |     dims = len(im.shape)
50 |     if dims == 2:
51 |         return im[None, ...]
52 |     elif dims == 3:
53 |         return np.transpose(im, (2, 0, 1))
54 |     else:
55 |         raise ValueError(f'invalid dimensions encountered. Only dims 2 or 3 are accepted, but got {dims}')
56 |
--------------------------------------------------------------------------------
/util/util_loadlib.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import subprocess
16 | from .util_print import str_warning, str_verbose
17 | import sys
18 |
19 |
20 | def set_gpu(gpu, check=True):
21 |     import os
22 |     import torch
23 |     import torch.backends.cudnn as cudnn
24 |     cudnn.benchmark = False
25 |     torch.cuda.set_device(int(gpu))
26 |
27 |
28 | def _check_gpu_setting_in_use(gpu):
29 |     '''
30 |     check that CUDA_VISIBLE_DEVICES is actually working
31 |     by starting a clean thread with the same CUDA_VISIBLE_DEVICES
32 |     '''
33 |     import subprocess
34 |     try:
35 |         which_python = sys.executable
36 |         output = subprocess.check_output(
37 |             'CUDA_VISIBLE_DEVICES=%s %s -c "import torch; print(torch.cuda.device_count())"' % (gpu, which_python), shell=True)
38 |     except subprocess.CalledProcessError as e:
39 |         print(str(e))
40 |         return False  # the subprocess failed; 'output' would be unbound below
41 |     output = output.decode().strip()
42 |     import torch
43 |     return torch.cuda.device_count() == int(output)
44 |
45 |
46 | def _check_gpu(gpu):
47 |     msg = subprocess.check_output(
48 |         'nvidia-smi --query-gpu=index,utilization.gpu,memory.used --format=csv,nounits,noheader -i %s' % (gpu,), shell=True)
49 |     msg = msg.decode('utf-8')
50 |     all_ok = True
51 |     for line in msg.split('\n'):
52 |         if line == '':
53 |             break
54 |         stats = [x.strip() for x in line.split(',')]
55 |         gpu = stats[0]
56 |         util = int(stats[1])
57 |         mem_used = int(stats[2])
58 |         if util > 10 or mem_used > 1000:  # util in percentage and mem_used in MiB
59 |             print(str_warning, 'Designated GPU in use: id=%s, util=%d%%, memory in use: %d MiB' % (
60 |                 gpu, util, mem_used))
61 |             all_ok = False
62 |     if all_ok:
63 |         print(str_verbose, 'All designated GPU(s) free to use. ')
64 |
65 |
66 | def set_manual_seed(seed):
67 |     import random
68 |     random.seed(seed)
69 |     try:
70 |         import numpy as np
71 |         np.random.seed(seed)
72 |     except ImportError:
73 |         print('Numpy not found. Random seed for numpy not set. ')
74 |     try:
75 |         import torch
76 |         torch.manual_seed(seed)
77 |         torch.cuda.manual_seed_all(seed)
78 |     except ImportError:
79 |         print('Pytorch not found. Random seed for pytorch not set. ')
80 |
--------------------------------------------------------------------------------
/util/util_plot.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import matplotlib.pyplot as plt
16 | import matplotlib.animation as animation
17 | from third_party.util_colormap import turbo_colormap_data
18 | import matplotlib
19 | from matplotlib.colors import ListedColormap
20 |
21 | matplotlib.cm.register_cmap('turbo', cmap=ListedColormap(turbo_colormap_data))
22 |
23 |
24 | def save_img_tensor_anim(img_tensor, path, markers=None, dpi=80, fps=30):
25 |     fig = plt.figure(dpi=dpi)
26 |     f = plt.imshow(img_tensor[0, ...].numpy())
27 |     if markers is not None:
28 |         s = plt.plot(markers[:, 0, 0], markers[:, 1, 0], 'w+', markersize=5)
29 |     plt.axis('off')
30 |     plt.tight_layout(h_pad=0, w_pad=0)
31 |     plt.margins(0.0)
32 |     fig.subplots_adjust(0, 0, 1, 1)
33 |
34 |     def anim(step):
35 |         f.set_data(img_tensor[step, ...])
36 |         if markers is not None:
37 |             s[0].set_data([markers[:, 0, step], markers[:, 1, step]])
38 |     # guard against markers=None, which the original indexed unconditionally
39 |     all_steps = img_tensor.shape[0] if markers is None else min(img_tensor.shape[0], markers.shape[-1])
40 |     line_ani = animation.FuncAnimation(fig, anim, all_steps,
41 |                                        interval=1000 / fps, blit=False)
42 |     line_ani.save(path)
43 |     plt.close()
44 |
45 |
46 | def save_depth_tensor_anim(depth_tensor, path, minv=None, maxv=None, markers=None, dpi=80, fps=30):
47 |
48 |     fig = plt.figure()
49 |     f = plt.imshow(1 / depth_tensor[0, 0, ...].numpy(), cmap='turbo', vmax=1 / minv, vmin=1 / maxv)
50 |     if markers is not None:
51 |         s = plt.plot(markers[:, 0, 0], markers[:, 1, 0], 'w+', markersize=5)
52 |
53 |     def anim(step):
54 |         f.set_data(1 / depth_tensor[step, 0, ...].numpy())
55 |         if markers is not None:
56 |             s[0].set_data([markers[:, 0, step], markers[:, 1, step]])
57 |
58 |     all_steps = depth_tensor.shape[0] if markers is None else markers.shape[-1]
59 |
60 |     line_ani = animation.FuncAnimation(fig, anim, all_steps,
61 |                                        interval=1000 / fps, blit=False)
62 |     line_ani.save(path)
63 |     plt.close()
64 |
--------------------------------------------------------------------------------
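[Note] Usage sketch for util/util_plot.py (illustrative, not a repository file;
assumes an ffmpeg-backed matplotlib animation writer is installed, and tensor
shapes follow the indexing above: T x H x W x C for images, T x 1 x H x W for
depth):

    import torch
    from util.util_plot import save_img_tensor_anim, save_depth_tensor_anim

    frames = torch.rand(30, 64, 64, 3)       # 30-frame RGB video in [0, 1]
    save_img_tensor_anim(frames, 'rgb.mp4', fps=10)

    depth = torch.rand(30, 1, 64, 64) + 0.5  # keep depth away from zero
    save_depth_tensor_anim(depth, 'depth.mp4', minv=0.5, maxv=1.5, fps=10)
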
/util/util_print.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | class bcolors:
17 | HEADER = '\033[95m'
18 | OKBLUE = '\033[94m'
19 | OKGREEN = '\033[92m'
20 | WARNING = '\033[93m'
21 | FAIL = '\033[91m'
22 | ENDC = '\033[0m'
23 | BOLD = '\033[1m'
24 | UNDERLINE = '\033[4m'
25 |
26 |
27 | str_stage = bcolors.OKBLUE + '==>' + bcolors.ENDC
28 | str_verbose = bcolors.OKGREEN + '[Verbose]' + bcolors.ENDC
29 | str_warning = bcolors.WARNING + '[Warning]' + bcolors.ENDC
30 | str_error = bcolors.FAIL + '[Error]' + bcolors.ENDC
31 |
--------------------------------------------------------------------------------
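[Note] Usage sketch for util/util_print.py: the module exposes ready-made
ANSI-colored prefixes rather than a logging API, so call sites stay one-liners:

    from util.util_print import str_stage, str_verbose, str_warning, str_error

    print(str_stage, 'Setting up data loaders')     # blue "==>"
    print(str_verbose, 'Found 120 frames')          # green "[Verbose]"
    print(str_warning, 'No checkpoint found; training from scratch')
    print(str_error, 'Config file missing')         # red "[Error]"
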
/util/util_visualize.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | from third_party.util_colormap import heatmap_to_pseudo_color
17 | KEYWORDS = ['depth', 'edgegradient', 'flow', 'img', 'rgb', 'image', 'edge', 'contour', 'softmask']
18 |
19 |
20 | def detach_to_cpu(tensor):
21 |     if isinstance(tensor, np.ndarray):
22 |         return tensor
23 |     else:
24 |         # detach() works for both leaf and non-leaf tensors, unlike
25 |         # assigning requires_grad = False, which raises on non-leaf tensors
26 |         tensor = tensor.detach().cpu()
27 |         return tensor.numpy()
28 |
29 |
30 | class Converter:
31 | def __init__(self):
32 | pass
33 |
34 | @staticmethod
35 | def depth2img(tensor, normalize=True, disparity=True, eps=1e-6, **kargs):
36 | t = detach_to_cpu(tensor)
37 | assert len(t.shape) == 4
38 | assert t.shape[1] == 1
39 |         t = 1 / (t + eps)  # depth -> disparity
40 |         # normalize each frame to [0, 1] before colorizing
41 |         max_v = np.max(t, axis=(2, 3), keepdims=True)
42 |         min_v = np.min(t, axis=(2, 3), keepdims=True)
43 |         t = (t - min_v) / (max_v - min_v + eps)
44 |         # map each normalized disparity frame through the
45 |         # turbo-style pseudo-color map, then stack back into
46 |         # a batch laid out as (B, 3, H, W)
47 |         cs = []
48 | for b in range(t.shape[0]):
49 | c = heatmap_to_pseudo_color(t[b, 0, ...])
50 | cs.append(c[None, ...])
51 | cs = np.concatenate(cs, axis=0)
52 | cs = np.transpose(cs, [0, 3, 1, 2])
53 | return cs
54 |
55 | @staticmethod
56 | def edge2img(tensor, normalize=True, eps=1e-6, **kargs):
57 | t = detach_to_cpu(tensor)
58 | if np.max(t) > 1 or np.min(t) < 0:
59 | t = 1 / (1 + np.exp(-t))
60 | assert len(t.shape) == 4
61 | assert t.shape[1] == 1
62 | return t
63 |
64 | @staticmethod
65 | def image2img(tensor, **kargs):
66 | return Converter.img2img(tensor)
67 |
68 | @staticmethod
69 | def softmask2img(tensor, **kargs):
70 |         t = detach_to_cpu(tensor)
71 |         # soft masks are already in [0, 1]; nothing to rescale
72 |         return t
73 |
74 | @staticmethod
75 | def scenef2img(tensor, **kargs):
76 | t = detach_to_cpu(tensor.squeeze(3))
77 | assert len(t.shape) == 4
78 | return np.linalg.norm(t, ord=1, axis=-1, keepdims=True)
79 |
80 | @staticmethod
81 | def rgb2img(tensor, **kargs):
82 | return Converter.img2img(tensor)
83 |
84 | @staticmethod
85 | def img2img(tensor, **kargs):
86 | t = detach_to_cpu(tensor)
87 | if np.min(t) < -0.1:
88 | t = (t + 1) / 2
89 | elif np.max(t) > 1.5:
90 | t = t / 255
91 | return t
92 |
93 | @staticmethod
94 | def edgegradient2img(tensor, **kargs):
95 | t = detach_to_cpu(tensor)
96 |         mag = np.max(abs(t)) + 1e-12  # guard against all-zero input
97 | positive = np.where(t > 0, t, 0)
98 | positive /= mag
99 | negative = np.where(t < 0, abs(t), 0)
100 | negative /= mag
101 | rgb = np.concatenate((positive, negative, np.zeros(negative.shape)), axis=1)
102 | return rgb
103 |
104 | @staticmethod
105 | def flow2img(tensor, **kargs):
106 | t = detach_to_cpu(tensor)
107 | return t
108 |
109 |
110 | def convert2rgb(tensor, key, **kargs):
111 | found = False
112 | for k in KEYWORDS:
113 | if k in key:
114 | convert = getattr(Converter, k + '2img')
115 | found = True
116 | break
117 | if not found:
118 | return None
119 | else:
120 | return convert(tensor, **kargs)
121 |
122 |
123 | def is_key_image(key):
124 | """check if the given key correspondes to images
125 |
126 | Arguments:
127 | key {str} -- key of a data pack
128 |
129 | Returns:
130 |         bool -- True if the given key corresponds to an image
131 | """
132 |
133 | for k in KEYWORDS:
134 | if k in key:
135 | return True
136 | return False
137 |
138 |
139 | def parse_key(key):
140 | rkey = None
141 | found = False
142 | mode = None
143 | for k in KEYWORDS:
144 | if k in key:
145 | rkey = k
146 | found = True
147 | break
148 | if 'pred' in key:
149 | mode = 'pred'
150 | elif 'gt' in key:
151 | mode = 'gt'
152 | if not found:
153 | return None, None
154 | else:
155 | return rkey, mode
156 |
--------------------------------------------------------------------------------
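[Note] Usage sketch for util/util_visualize.py (illustrative tensors; assumes
heatmap_to_pseudo_color returns an H x W x 3 array, as the code above implies).
convert2rgb dispatches on substrings of the key, so 'depth_pred' is routed to
Converter.depth2img, 'rgb_gt' to Converter.rgb2img, and so on:

    import torch
    from util.util_visualize import convert2rgb, parse_key

    depth = torch.rand(2, 1, 32, 32) + 0.1    # B x 1 x H x W
    out = convert2rgb(depth, 'depth_pred')    # pseudo-colored
    print(out.shape)                          # (2, 3, 32, 32)

    print(parse_key('depth_pred'))            # ('depth', 'pred')
    print(parse_key('loss'))                  # (None, None): not an image key
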
/visualize/base_visualizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from torch.multiprocessing import Pool
16 | from os.path import join, dirname
17 | from os import makedirs
18 | import atexit
19 | from util.util_config import get_project_config
20 |
21 |
22 | class BaseVisualizer():
23 | """
24 |     Async Visualization Worker
25 | """
26 |
27 | def __init__(self, n_workers=4):
28 | # read global configs
29 | self.cfg = get_project_config()
30 | if n_workers == 0:
31 | pool = None
32 | elif n_workers > 0:
33 | pool = Pool(n_workers)
34 | else:
35 | raise ValueError(n_workers)
36 | self.pool = pool
37 |
38 | def cleanup():
39 | if pool:
40 | pool.close()
41 | pool.join()
42 | atexit.register(cleanup)
43 |
44 | def visualize(self, pack, batch_idx, outdir):
45 | if self.pool:
46 | self.pool.apply_async(
47 | self._visualize,
48 | [pack, batch_idx, outdir, self.cfg],
49 | error_callback=self._error_callback
50 | )
51 | else:
52 | self._visualize(pack, batch_idx, outdir, self.cfg)
53 |
54 | @staticmethod
55 |     def _visualize(pack, batch_idx, outdir, cfg):
56 |         # main visualization worker; argument order matches the apply_async call above
57 | raise NotImplementedError
58 |
59 | @staticmethod
60 | def _error_callback(e):
61 | print(str(e))
62 |
--------------------------------------------------------------------------------
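[Note] Subclassing sketch for visualize/base_visualizer.py (PNGVisualizer is
hypothetical, not a repository class; assumes get_project_config() resolves in
your environment): a concrete visualizer only implements the static
_visualize, which receives exactly the arguments passed to apply_async above
and may run inside a worker process:

    from os.path import join
    import numpy as np
    from PIL import Image
    from visualize.base_visualizer import BaseVisualizer

    class PNGVisualizer(BaseVisualizer):
        @staticmethod
        def _visualize(pack, batch_idx, outdir, cfg):
            # dump every H x W x 3 uint8 array in the pack as a PNG
            for key, img in pack.items():
                if isinstance(img, np.ndarray) and img.ndim == 3:
                    Image.fromarray(img).save(
                        join(outdir, '%s_%04d.png' % (key, batch_idx)))

    viz = PNGVisualizer(n_workers=0)  # 0 workers: run synchronously, no pool
    viz.visualize({'rgb': np.zeros((8, 8, 3), dtype=np.uint8)}, 0, '/tmp')
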
/visualize/html_visualizer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from util.util_visualize import convert2rgb, is_key_image
16 | from torch.multiprocessing import Pool
17 | from util.util_flow import flow2img
18 | from os.path import join
19 | import atexit
20 | import numpy as np
21 | from PIL import Image
22 |
23 |
24 | class HTMLVisualizer():
25 | """
26 |     Async Visualization Worker
27 | """
28 |
29 | def __init__(self, html_logger, n_workers=4):
30 | # read global configs
31 | if n_workers == 0:
32 | pool = None
33 | elif n_workers > 0:
34 | pool = Pool(n_workers)
35 | else:
36 | raise ValueError(n_workers)
37 | self.pool = pool
38 | self.html_logger = html_logger
39 | self.header_lut = None
40 |
41 | def cleanup():
42 | if pool:
43 | pool.close()
44 | pool.join()
45 | atexit.register(cleanup)
46 |
47 |     def visualize(self, pack, batch_idx, outdir, is_test=False):
48 |         # First append to the shared HTML content, then launch the
49 |         # subprocesses that dump images to disk. At test time there is
50 |         # no per-epoch subfolder, so links are written relative to
51 |         # outdir itself.
52 |         epoch_folder = outdir.split('/')[-1]
53 |         if is_test:
54 |             epoch_folder = None
55 |         self.prepare_HTML_string(pack, batch_idx, epoch_folder)
56 | 
57 |         if self.pool:
58 |             self.pool.apply_async(
59 |                 self._visualize,
60 |                 [pack, batch_idx, outdir],
61 |                 error_callback=self._error_callback
62 |             )
63 |         else:
64 |             self._visualize(pack, batch_idx, outdir)
65 |
66 |     def prepare_HTML_string(self, pack, batch_idx, epoch_folder):
67 |         if self.html_logger.epoch_content is None:
68 |             self.html_logger.epoch_content = {}
69 |             # build the table header once per epoch
70 |             header = ''
71 |             for k in sorted(list(pack.keys())):
72 |                 # keep only the keys that correspond to images
73 |                 if is_key_image(k):
74 |                     header += f"<td>{k}</td>\n"
75 |             self.html_logger.epoch_content['header'] = header
76 |             self.html_logger.epoch_content['content'] = ''
77 |         content = ''
78 |         batch_size = pack['batch_size']
79 |         tags = pack['tags'] if 'tags' in pack.keys() else None
80 |         for b in range(batch_size):
81 |             content += "<tr>\n"
82 |             for k in sorted(list(pack.keys())):
83 |                 # keep only the keys that correspond to images
84 |                 if is_key_image(k):
85 |                     if tags is not None:
86 |                         prefix = '%s_%s.png' % (k, tags[b])
87 |                     else:
88 |                         prefix = '%s_%04d_%04d.png' % (k, batch_idx, b)
89 |                     if epoch_folder is not None:
90 |                         link = join(epoch_folder, prefix)
91 |                     else:
92 |                         link = prefix
93 |                     # html is at outdir/../, so link is epochXXX/<prefix>
94 |                     content += f'<td><img src="{link}"></td>\n'
95 |             content += "</tr>\n"
96 |         self.html_logger.epoch_content['content'] += content
97 | 
98 | @staticmethod
99 | def _visualize(pack, batch_idx, outdir):
100 |         # this worker saves the packed tensors as individual image files
101 | batch_size = pack['batch_size']
102 | tags = pack['tags'] if 'tags' in pack.keys() else None
103 | for k, v in pack.items():
104 | rgb_tensor = convert2rgb(v, k)
105 | if rgb_tensor is None:
106 | continue
107 | for b in range(batch_size):
108 | if 'flow' in k:
109 | img = flow2img(rgb_tensor[b, :, :, :]) / 255
110 | else:
111 | img = np.squeeze(np.transpose(rgb_tensor[b, :, :, :], (1, 2, 0)))
112 | if tags is not None:
113 | prefix = '%s_%s.png' % (k, tags[b])
114 | else:
115 | prefix = '%s_%04d_%04d.png' % (k, batch_idx, b)
116 | Image.fromarray((img * 255).astype(np.uint8)).save(join(outdir, prefix), 'PNG')
117 |
118 | @staticmethod
119 | def _error_callback(e):
120 | print(str(e))
121 |
--------------------------------------------------------------------------------
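[Note] Usage sketch for visualize/html_visualizer.py (the logger stub and paths
are hypothetical): the pack must carry 'batch_size' plus tensors whose keys
contain one of the KEYWORDS from util.util_visualize; everything else is
skipped by convert2rgb:

    import os
    import torch
    from visualize.html_visualizer import HTMLVisualizer

    class DummyHTMLLogger:        # stands in for the project's HTML logger
        epoch_content = None

    outdir = '/tmp/epoch0000'
    os.makedirs(outdir, exist_ok=True)

    viz = HTMLVisualizer(DummyHTMLLogger(), n_workers=0)  # synchronous
    pack = {
        'batch_size': 1,
        'rgb_pred': torch.rand(1, 3, 32, 32),
        'depth_pred': torch.rand(1, 1, 32, 32) + 0.1,
    }
    viz.visualize(pack, batch_idx=0, outdir=outdir)  # writes PNGs + HTML rows
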