├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── WINDOWS_INSTALLATION.md
├── cog.yaml
├── datasets
│   ├── create_middlebury_tfrecord.py
│   ├── create_ucf101_tfrecord.py
│   ├── create_vimeo90K_tfrecord.py
│   ├── create_xiph_tfrecord.py
│   └── util.py
├── eval
│   ├── config
│   │   ├── middlebury.gin
│   │   ├── ucf101.gin
│   │   ├── vimeo_90K.gin
│   │   ├── xiph_2K.gin
│   │   └── xiph_4K.gin
│   ├── eval_cli.py
│   ├── interpolator.py
│   ├── interpolator_cli.py
│   ├── interpolator_test.py
│   └── util.py
├── losses
│   ├── losses.py
│   └── vgg19_loss.py
├── models
│   └── film_net
│       ├── feature_extractor.py
│       ├── fusion.py
│       ├── interpolator.py
│       ├── options.py
│       ├── pyramid_flow_estimator.py
│       └── util.py
├── moment.gif
├── photos
│   ├── one.png
│   └── two.png
├── predict.py
├── requirements.txt
└── training
    ├── augmentation_lib.py
    ├── build_saved_model_cli.py
    ├── config
    │   ├── film_net-L1.gin
    │   ├── film_net-Style.gin
    │   └── film_net-VGG.gin
    ├── data_lib.py
    ├── eval_lib.py
    ├── metrics_lib.py
    ├── model_lib.py
    ├── train.py
    └── train_lib.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code Reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FILM: Frame Interpolation for Large Motion
2 |
3 | ### [Website](https://film-net.github.io/) | [Paper](https://arxiv.org/pdf/2202.04901.pdf) | [Google AI Blog](https://ai.googleblog.com/2022/10/large-motion-frame-interpolation.html) | [Tensorflow Hub Colab](https://www.tensorflow.org/hub/tutorials/tf_hub_film_example) | [YouTube](https://www.youtube.com/watch?v=OAD-BieIjH4)
4 |
5 | The official Tensorflow 2 implementation of our high-quality frame interpolation neural network. We present a unified single-network approach that doesn't use additional pre-trained networks, like optical flow or depth, and yet achieves state-of-the-art results. We use a multi-scale feature extractor that shares the same convolution weights across the scales. Our model is trainable from frame triplets alone.
6 |
7 | [FILM: Frame Interpolation for Large Motion](https://arxiv.org/abs/2202.04901)
8 | [Fitsum Reda](https://fitsumreda.github.io/)<sup>1</sup>, [Janne Kontkanen](https://scholar.google.com/citations?user=MnXc4JQAAAAJ&hl=en)<sup>1</sup>, [Eric Tabellion](http://www.tabellion.org/et/)<sup>1</sup>, [Deqing Sun](https://deqings.github.io/)<sup>1</sup>, [Caroline Pantofaru](https://scholar.google.com/citations?user=vKAKE1gAAAAJ&hl=en)<sup>1</sup>, [Brian Curless](https://homes.cs.washington.edu/~curless/)<sup>1,2</sup>
9 | <sup>1</sup> Google Research, <sup>2</sup> University of Washington
10 | In ECCV 2022.
11 |
12 | ![FILM in action](moment.gif)
13 | FILM transforms near-duplicate photos into slow-motion footage that looks like it was shot with a video camera.
14 |
15 | ## Web Demo
16 |
17 | Integrated into [Hugging Face Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the [Web Demo](https://huggingface.co/spaces/johngoad/frame-interpolation).
18 |
19 | Try the interpolation model with the replicate web demo at
20 | [replicate.com/google-research/frame-interpolation](https://replicate.com/google-research/frame-interpolation)
21 |
22 | Try FILM to interpolate between two or more images with the PyTTI-Tools at [this Colab](https://colab.sandbox.google.com/github/pytti-tools/frame-interpolation/blob/main/PyTTI_Tools_FiLM-colab.ipynb#scrollTo=-7TD7YZJbsy_).
23 |
24 | An alternative Colab for running FILM on arbitrarily many input images, not just two, is available [here](https://colab.research.google.com/drive/1NuaPPSvUhYafymUf2mEkvhnEtpD5oihs).
25 |
26 | ## Change Log
27 | * **Nov 28, 2022**: Upgrade `eval.interpolator_cli` for **high resolution frame interpolation**. `--block_height` and `--block_width` determine the total number of patches (`block_height*block_width`) to subdivide the input images. By default, both arguments are set to 1, and so no subdivision will be done.
28 | * **Mar 12, 2022**: Support for Windows, see [WINDOWS_INSTALLATION.md](https://github.com/google-research/frame-interpolation/blob/main/WINDOWS_INSTALLATION.md).
29 | * **Mar 09, 2022**: Support for **high resolution frame interpolation**. Set `--block_height` and `--block_width` in `eval.interpolator_test` to extract patches from the inputs, and reconstruct the interpolated frame from the iteratively interpolated patches.
30 |
31 | ## Installation
32 |
33 | * Get Frame Interpolation source codes
34 |
35 | ```
36 | git clone https://github.com/google-research/frame-interpolation
37 | cd frame-interpolation
38 | ```
39 |
40 | * Optionally, pull the recommended Docker base image
41 |
42 | ```
43 | docker pull gcr.io/deeplearning-platform-release/tf2-gpu.2-6:latest
44 | ```
45 |
46 | * If you do not use Docker, set up your NVIDIA GPU environment with:
47 | * [Anaconda Python 3.9](https://www.anaconda.com/products/individual)
48 | * [CUDA Toolkit 11.2.1](https://developer.nvidia.com/cuda-11.2.1-download-archive)
49 | * [cuDNN 8.1.0](https://developer.nvidia.com/rdp/cudnn-download)
50 |
51 | * Install frame interpolation dependencies
52 |
53 | ```
54 | pip3 install -r requirements.txt
55 | sudo apt-get install -y ffmpeg
56 | ```
57 |
58 | ### See [WINDOWS_INSTALLATION](https://github.com/google-research/frame-interpolation/blob/main/WINDOWS_INSTALLATION.md) for Windows Support
59 |
60 | ## Pre-trained Models
61 |
62 | * Create a directory where you can keep large files. Ideally, not in this
63 | directory.
64 |
65 | ```
66 | mkdir -p <pretrained_models>
67 | ```
68 |
69 | * Download pre-trained TF2 Saved Models from
70 | [google drive](https://drive.google.com/drive/folders/1q8110-qp225asX3DQvZnfLfJPkCHmDpy?usp=sharing)
71 | and put them into `<pretrained_models>`.
72 |
73 | The downloaded folder should have the following structure:
74 |
75 | ```
76 | <pretrained_models>/
77 | ├── film_net/
78 | │ ├── L1/
79 | │ ├── Style/
80 | │ ├── VGG/
81 | ├── vgg/
82 | │ ├── imagenet-vgg-verydeep-19.mat
83 | ```
84 |
85 | ## Running the Codes
86 |
87 | The following instructions run the interpolator on the photos provided in
88 | 'frame-interpolation/photos'.
89 |
90 | ### One mid-frame interpolation
91 |
92 | To generate an intermediate photo from the input near-duplicate photos, simply run:
93 |
94 | ```
95 | python3 -m eval.interpolator_test \
96 | --frame1 photos/one.png \
97 | --frame2 photos/two.png \
98 | --model_path <pretrained_models>/film_net/Style/saved_model \
99 | --output_frame photos/output_middle.png
100 | ```
101 |
102 | This will produce the sub-frame at `t=0.5` and save it as 'photos/output_middle.png'.
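
If you prefer to call the interpolator from Python instead of the CLI, the sketch below is one way to do it. It is a minimal example, and it assumes the exported SavedModel accepts a dict with `x0`, `x1` and `time` and returns a dict with an `image` key, as in the TF Hub Colab linked above; the image reading and writing helpers are illustrative only.

```
import numpy as np
import tensorflow as tf

def load_image(path):
  """Reads a PNG/JPEG file into a float32 RGB array in [0, 1]."""
  image = tf.io.decode_image(tf.io.read_file(path), channels=3)
  return tf.image.convert_image_dtype(image, tf.float32).numpy()

# Load the exported model and the two input photos (batch dimension added).
model = tf.saved_model.load('<pretrained_models>/film_net/Style/saved_model')
x0 = load_image('photos/one.png')[np.newaxis, ...]
x1 = load_image('photos/two.png')[np.newaxis, ...]
time = np.array([[0.5]], dtype=np.float32)  # t=0.5 -> the middle frame

# Assumed signature: a dict with 'x0', 'x1', 'time' in, a dict with 'image' out.
mid_frame = model({'x0': x0, 'x1': x1, 'time': time})['image'][0]
png = tf.io.encode_png(
    tf.image.convert_image_dtype(tf.clip_by_value(mid_frame, 0.0, 1.0), tf.uint8))
tf.io.write_file('photos/output_middle.png', png)
```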
103 |
104 | ### Many in-between frames interpolation
105 |
106 | The `eval.interpolator_cli` script takes in a set of directories identified by a glob (`--pattern`). Each directory
107 | is expected to contain at least two input frames, with each contiguous frame
108 | pair treated as an input to generate in-between frames. Frames should be named such that their natural sort order (as produced by `natsort`) matches their intended temporal order.
109 |
110 | ```
111 | python3 -m eval.interpolator_cli \
112 | --pattern "photos" \
113 | --model_path <pretrained_models>/film_net/Style/saved_model \
114 | --times_to_interpolate 6 \
115 | --output_video
116 | ```
117 |
118 | You will find the interpolated frames (including the input frames) in
119 | 'photos/interpolated_frames/', and the interpolated video at
120 | 'photos/interpolated.mp4'.
121 |
122 | The number of frames is determined by `--times_to_interpolate`, which controls
123 | the number of times the frame interpolator is invoked. When the number of frames
124 | in a directory is `num_frames`, the number of output frames will be
125 | `(2^times_to_interpolate+1)*(num_frames-1)`.
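
As a quick sanity check of the formula above (plain arithmetic, not part of the repository):

```
def num_output_frames(num_frames, times_to_interpolate):
  """Evaluates the frame-count formula stated above."""
  return (2**times_to_interpolate + 1) * (num_frames - 1)

# Two input frames with --times_to_interpolate 6, as in the example command:
print(num_output_frames(num_frames=2, times_to_interpolate=6))  # -> 65
```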
126 |
127 | ## Datasets
128 |
129 | We use [Vimeo-90K](http://data.csail.mit.edu/tofu/dataset/vimeo_triplet.zip) as
130 | our main training dataset. For quantitative evaluations, we rely on commonly
131 | used benchmark datasets, specifically:
132 |
133 | * [Vimeo-90K](http://data.csail.mit.edu/tofu/testset/vimeo_interp_test.zip)
134 | * [Middlebury-Other](https://vision.middlebury.edu/flow/data)
135 | * [UCF101](https://people.cs.umass.edu/~hzjiang/projects/superslomo/UCF101_results.zip)
136 | * [Xiph](https://github.com/sniklaus/softmax-splatting/blob/master/benchmark.py)
137 |
138 | ### Creating a TFRecord
139 |
140 | The training and benchmark evaluation scripts expect the frame triplets in the
141 | [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) storage format.
142 |
143 | We have included scripts that encode the relevant frame triplets into a
144 | [tf.train.Example](https://www.tensorflow.org/api_docs/python/tf/train/Example)
145 | data format, and export to a TFRecord file.
146 |
147 | You can use the command `python3 -m
148 | datasets.create_<dataset_name>_tfrecord --help` for more information.
149 |
150 | For example, run the command below to create a TFRecord for the Middlebury-other
151 | dataset. Download the [images](https://vision.middlebury.edu/flow/data) and point `--input_dir` to the unzipped folder path.
152 |
153 | ```
154 | python3 -m datasets.create_middlebury_tfrecord \
155 | --input_dir=<root folder of middlebury-other> \
156 | --output_tfrecord_filepath=<output TFRecord filepath> \
157 | --num_shards=3
158 | ```
159 |
160 | The above command will output a TFRecord file with 3 shards as `<output TFRecord filepath>@3`.
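
Once written, the triplets can be read back with standard `tf.data` utilities. The snippet below is an illustrative reader, not part of the training pipeline: the feature map mirrors the one documented in the `datasets/create_*_tfrecord.py` scripts, and the shard filename pattern is an assumed example.

```
import tensorflow as tf

# Feature map as documented in the dataset creation scripts.
triplet_features = {'path': tf.io.FixedLenFeature((), tf.string, default_value='')}
for frame in ('frame_0', 'frame_1', 'frame_2'):
  triplet_features.update({
      f'{frame}/encoded': tf.io.FixedLenFeature((), tf.string, default_value=''),
      f'{frame}/format': tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
      f'{frame}/height': tf.io.FixedLenFeature((), tf.int64, default_value=0),
      f'{frame}/width': tf.io.FixedLenFeature((), tf.int64, default_value=0),
  })

def parse_triplet(serialized_example):
  """Decodes the three frames of one triplet to float32 RGB in [0, 1]."""
  example = tf.io.parse_single_example(serialized_example, triplet_features)
  return [
      tf.image.convert_image_dtype(
          tf.io.decode_image(example[f'{frame}/encoded'], channels=3), tf.float32)
      for frame in ('frame_0', 'frame_1', 'frame_2')
  ]

# Example shard pattern matching the command above (filename is an assumption).
files = tf.io.gfile.glob('middlebury_other.tfrecord-*-of-00003')
dataset = tf.data.TFRecordDataset(files).map(parse_triplet)
```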
161 |
162 | ## Training
163 |
164 | Below are our training gin configuration files for the different loss functions:
165 |
166 | ```
167 | training/
168 | ├── config/
169 | │ ├── film_net-L1.gin
170 | │ ├── film_net-VGG.gin
171 | │ ├── film_net-Style.gin
172 | ```
173 |
174 | To launch training, simply pass the configuration filepath of the desired
175 | experiment.
176 | By default, training uses all visible GPUs. To debug or train
177 | on a CPU, append `--mode cpu`.
178 |
179 | ```
180 | python3 -m training.train \
181 | --gin_config training/config/<config filename>.gin \
182 | --base_folder <base folder for all training runs> \
183 | --label <descriptive label for the run>
184 | ```
185 |
186 | * When training finishes, the folder structure will look like this:
187 |
188 | ```
189 | <base_folder>/
190 | ├── <label>/
191 | │ ├── config.gin
192 | │ ├── eval/
193 | │ ├── train/
194 | │ ├── saved_model/
195 | ```
196 |
197 | ### Build a SavedModel
198 |
199 | Optionally, to build a
200 | [SavedModel](https://www.tensorflow.org/guide/saved_model) format from a trained
201 | checkpoints folder, you can use this command:
202 |
203 | ```
204 | python3 -m training.build_saved_model_cli \
205 | --base_folder <base folder for all training runs> \
206 | --label <descriptive label for the run>
207 | ```
208 |
209 | * By default, a SavedModel is created when the training loop ends, and it will be saved at
210 | `<base_folder>/<label>/saved_model`.
211 |
212 | ## Evaluation on Benchmarks
213 |
214 | Below, we provide the evaluation gin configuration files for the benchmarks we
215 | have considered:
216 |
217 | ```
218 | eval/
219 | ├── config/
220 | │ ├── middlebury.gin
221 | │ ├── ucf101.gin
222 | │ ├── vimeo_90K.gin
223 | │ ├── xiph_2K.gin
224 | │ ├── xiph_4K.gin
225 | ```
226 |
227 | To run an evaluation, simply pass the configuration file of the desired evaluation dataset.
228 | If a GPU is visible, the evaluation runs on it.
229 |
230 | ```
231 | python3 -m eval.eval_cli \
232 | --gin_config eval/config/<eval_dataset>.gin \
233 | --model_path <pretrained_models>/film_net/L1/saved_model
234 | ```
235 |
236 | The above command will produce the PSNR and SSIM scores presented in the paper.
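
For reference, both metrics can also be computed directly with TensorFlow's built-in ops. The snippet below is a standalone illustration on a single frame pair, not the repository's `eval`/`losses` implementation:

```
import tensorflow as tf

def psnr_and_ssim(predicted, golden):
  """predicted, golden: float32 [H, W, 3] images with values in [0, 1]."""
  psnr = tf.image.psnr(predicted, golden, max_val=1.0)
  ssim = tf.image.ssim(predicted, golden, max_val=1.0)
  return float(psnr), float(ssim)

# Example with random images; replace with an interpolated frame and its golden mid-frame.
pred = tf.random.uniform((256, 448, 3))
gold = tf.random.uniform((256, 448, 3))
print(psnr_and_ssim(pred, gold))
```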
237 |
238 | ## Citation
239 |
240 | If you find this implementation useful in your work, please acknowledge it
241 | appropriately by citing:
242 |
243 | ```
244 | @inproceedings{reda2022film,
245 | title = {FILM: Frame Interpolation for Large Motion},
246 | author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless},
247 | booktitle = {European Conference on Computer Vision (ECCV)},
248 | year = {2022}
249 | }
250 | ```
251 |
252 | ```
253 | @misc{film-tf,
254 | title = {Tensorflow 2 Implementation of "FILM: Frame Interpolation for Large Motion"},
255 | author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless},
256 | year = {2022},
257 | publisher = {GitHub},
258 | journal = {GitHub repository},
259 | howpublished = {\url{https://github.com/google-research/frame-interpolation}}
260 | }
261 | ```
262 |
263 | ## Acknowledgments
264 |
265 | We would like to thank Richard Tucker, Jason Lai and David Minnen. We would also
266 | like to thank Jamie Aspinall for the imagery included in this repository.
267 |
268 | ## Coding style
269 |
270 | * 2 spaces for indentation
271 | * 80 character line length
272 | * PEP8 formatting
273 |
274 | ## Disclaimer
275 |
276 | This is not an officially supported Google product.
277 |
--------------------------------------------------------------------------------
/WINDOWS_INSTALLATION.md:
--------------------------------------------------------------------------------
1 | # [FILM](https://github.com/google-research/frame-interpolation): Windows Installation Instructions
2 |
3 | ## Anaconda Python 3.9 (Optional)
4 |
5 | #### Install Anaconda3 Python3.9
6 | * Go to [https://www.anaconda.com/products/individual](https://www.anaconda.com/products/individual) and click the "Download" button.
7 | * Download the Windows [64-Bit](https://repo.anaconda.com/archive/Anaconda3-2021.11-Windows-x86_64.exe) or [32-bit](https://repo.anaconda.com/archive/Anaconda3-2021.11-Windows-x86.exe) Graphical Installer, depending on your system needs.
8 | * Run the downloaded (`.exe`) file to begin the installation.
9 | * (Optional) Check "Add Anaconda3 to my PATH environment variable". You may get a 'red text' warning about its implications; you may ignore it for this setup.
10 |
11 | #### Create a new Anaconda virtual environment
12 | * Open a new Terminal
13 | * Type the following command:
14 | ```
15 | conda create -n frame_interpolation pip python=3.9
16 | ```
17 | * The above command will create a new virtual environment with the name `frame_interpolation`
18 |
19 | #### Activate the Anaconda virtual environment
20 | * Activate the newly created virtual environment by typing in your terminal (Command Prompt or PowerShell)
21 | ```
22 | conda activate frame_interpolation
23 | ```
24 | * Once activated, your terminal should look like:
25 | ```
26 | (frame_interpolation) >
27 | ```
28 |
29 | ## NVIDIA GPU Support
30 | #### Install CUDA Toolkit
31 | * Go to [https://developer.nvidia.com/cuda-11.2.1-download-archive](https://developer.nvidia.com/cuda-11.2.1-download-archive) and select your `Windows` version.
32 | * Download and install `CUDA Toolkit 11.2.1`.
33 | * Additional CUDA installation information available [here](https://docs.nvidia.com/cuda/archive/11.2.2/cuda-installation-guide-microsoft-windows/index.html).
34 |
35 | #### Install cuDNN
36 | * Go to [https://developer.nvidia.com/rdp/cudnn-download](https://developer.nvidia.com/rdp/cudnn-download).
37 | * Create a user profile (if needed) and login.
38 | * Select `cuDNN v8.1.0 (January 26th, 2021), for CUDA 11.0,11.1 and 11.2`.
39 | * Download [cuDNN Library for Windows (x86)](https://developer.nvidia.com/compute/machine-learning/cudnn/secure/8.1.0.77/11.2_20210127/cudnn-11.2-windows-x64-v8.1.0.77.zip).
40 | * Extract the contents of the zipped folder (it contains a folder named `cuda`) into `<INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\`. `<INSTALL_PATH>` points to the installation directory specified during CUDA Toolkit installation. By default, `<INSTALL_PATH> = C:\Program Files`.
41 |
42 | #### Environment Setup
43 | * Go to 'Advanced System Settings' > 'Environment Variables ...' > Edit 'Path', and add the following paths:
44 | * <INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\bin
45 | * <INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\libnvvp
46 | * <INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\include
47 | * <INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\extras\CUPTI\lib64
48 | * <INSTALL_PATH>\NVIDIA GPU Computing Toolkit\CUDA\v11.2\cuda\bin
49 |
50 | #### Verify Installation
51 | * Open a **new** terminal and type `conda activate frame_interpolation`.
52 | * Install (temporarily) tensorflow and run a simple operation, by typing:
53 | ```
54 | pip install --ignore-installed --upgrade tensorflow==2.6.0
55 | python -c "import tensorflow as tf;print(tf.reduce_sum(tf.random.normal([1000, 1000])))"
56 | ```
57 | * You should see a success message like: 'Created device /job:localhost/replica:0/task:0/device:GPU:0'.
58 |
59 | ## FILM Installation
60 | * Get Frame Interpolation source codes
61 | ```
62 | git clone https://github.com/google-research/frame-interpolation
63 | cd frame-interpolation
64 | ```
65 | * Install dependencies
66 | ```
67 | pip install -r requirements.txt
68 | conda install -c conda-forge ffmpeg
69 | ```
70 | * Download pre-trained models, as detailed [here](https://github.com/google-research/frame-interpolation#pre-trained-models).
71 |
72 | ## Running the Codes
73 | * One mid-frame interpolation. Note: `python3` may not be recognized in Windows, so simply drop `3` as below.
74 | ```
75 | python -m eval.interpolator_test --frame1 photos\one.png --frame2 photos\two.png --model_path <pretrained_models>\film_net\Style\saved_model --output_frame photos\output_middle.png
76 | ```
77 |
78 | * Large resolution mid-frame interpolation: Set `--block_height` and `--block_width` to subdivide the inputs along the height and width into patches, on which the interpolator is run iteratively; the resulting interpolated mid-patches are then reconstructed into the final mid-frame. The example below creates and runs on 4 patches (2*2); see the patch-subdivision sketch after this list.
79 | ```
80 | python -m eval.interpolator_test --frame1 photos\one.png --frame2 photos\two.png --block_height 2 --block_width 2 --model_path <pretrained_models>\film_net\Style\saved_model --output_frame photos\output_middle.png
81 | ```
82 | * Many in-between frames interpolation
83 | ```
84 | python -m eval.interpolator_cli --pattern "photos" --model_path <pretrained_models>\film_net\Style\saved_model --times_to_interpolate 6 --output_video
85 | ```
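
As referenced in the large-resolution item above, here is a conceptual sketch (plain NumPy, not the repository's implementation) of what `--block_height`/`--block_width` do: the frame is subdivided into patches, the interpolator runs per patch pair, and the interpolated mid-patches are stitched back together. It assumes the height and width divide evenly by the block counts.

```
import numpy as np

def split_into_patches(frame, block_height, block_width):
  """Splits an [H, W, C] frame into block_height * block_width patches."""
  rows = np.split(frame, block_height, axis=0)  # requires H % block_height == 0
  return [patch for row in rows for patch in np.split(row, block_width, axis=1)]

def stitch_patches(patches, block_height, block_width):
  """Reassembles patches (in row-major order) into a single frame."""
  rows = [np.concatenate(patches[r * block_width:(r + 1) * block_width], axis=1)
          for r in range(block_height)]
  return np.concatenate(rows, axis=0)

# With --block_height 2 --block_width 2, the interpolator runs on 4 patch pairs,
# and the 4 interpolated mid-patches are stitched back into the final mid-frame.
frame = np.zeros((1080, 1920, 3), dtype=np.float32)
patches = split_into_patches(frame, 2, 2)
assert stitch_patches(patches, 2, 2).shape == frame.shape
```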
86 |
87 | ## Acknowledgments
88 |
89 | This Windows installation guide is heavily based on [tensorflow-object-detection-api-tutorial](https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/install.html).
90 |
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | build:
2 | gpu: true
3 | cuda: "11.2"
4 | python_version: "3.8"
5 | system_packages:
6 | - "libgl1-mesa-glx"
7 | - "libglib2.0-0"
8 | python_packages:
9 | - "ipython==7.30.1"
10 | - "tensorflow-gpu==2.8.0"
11 | - "tensorflow-datasets==4.4.0"
12 | - "tensorflow-addons==0.15.0"
13 | - "absl-py==0.12.0"
14 | - "gin-config==0.5.0"
15 | - "parameterized==0.8.1"
16 | - "mediapy==1.0.3"
17 | - "scikit-image==0.19.1"
18 | - "apache-beam==2.34.0"
19 | run:
20 | - apt-get update && apt-get install -y software-properties-common
21 | - apt-get install ffmpeg -y
22 |
23 | predict: "predict.py:Predictor"
24 |
--------------------------------------------------------------------------------
/datasets/create_middlebury_tfrecord.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Beam pipeline that generates Middlebury `Other Datasets` triplet TFRecords.
16 |
17 | Middlebury interpolation evaluation dataset consists of two subsets.
18 |
19 | (1) Two frames only, without the intermediate golden frame. A total of 12 such
20 | pairs, with folder names (Army, Backyard, Basketball, Dumptruck,
21 | Evergreen, Grove, Mequon, Schefflera, Teddy, Urban, Wooden, Yosemite)
22 |
23 | (2) Two frames together with the intermediate golden frame. A total of 12 such
24 | triplets, with folder names (Beanbags, Dimetrodon, DogDance, Grove2,
25 | Grove3, Hydrangea, MiniCooper, RubberWhale, Urban2, Urban3, Venus, Walking)
26 |
27 | This script runs on (2), i.e. the dataset with the golden frames. For more
28 | information, visit https://vision.middlebury.edu/flow/data.
29 |
30 | Input to the script is the root-folder that contains the unzipped folders
31 | of input pairs (other-data) and golden frames (other-gt-interp).
32 |
33 | Output TFRecord is a tf.train.Example proto of each image triplet.
34 | The feature_map takes the form:
35 | feature_map {
36 | 'frame_0/encoded':
37 | tf.io.FixedLenFeature((), tf.string, default_value=''),
38 | 'frame_0/format':
39 | tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
40 | 'frame_0/height':
41 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
42 | 'frame_0/width':
43 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
44 | 'frame_1/encoded':
45 | tf.io.FixedLenFeature((), tf.string, default_value=''),
46 | 'frame_1/format':
47 | tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
48 | 'frame_1/height':
49 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
50 | 'frame_1/width':
51 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
52 | 'frame_2/encoded':
53 | tf.io.FixedLenFeature((), tf.string, default_value=''),
54 | 'frame_2/format':
55 | tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
56 | 'frame_2/height':
57 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
58 | 'frame_2/width':
59 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
60 | 'path':
61 | tf.io.FixedLenFeature((), tf.string, default_value=''),
62 | }
63 |
64 | Usage example:
65 | python3 -m frame_interpolation.datasets.create_middlebury_tfrecord \
66 | --input_dir=<root folder of middlebury-other> \
67 | --output_tfrecord_filepath=<output TFRecord filepath>
68 | """
69 |
70 | import os
71 |
72 | from . import util
73 | from absl import app
74 | from absl import flags
75 | from absl import logging
76 | import apache_beam as beam
77 | import tensorflow as tf
78 |
79 | _INPUT_DIR = flags.DEFINE_string(
80 | 'input_dir',
81 | default='/root/path/to/middlebury-other',
82 | help='Path to the root directory of the `Other Datasets` of the Middlebury '
83 | 'interpolation evaluation data. '
84 | 'We expect the data to have been downloaded and unzipped. \n'
85 | 'Folder structures:\n'
86 | '| raw_middlebury_other_dataset/\n'
87 | '| other-data/\n'
88 | '| | Beanbags\n'
89 | '| | | frame10.png\n'
90 | '| | | frame11.png\n'
91 | '| | Dimetrodon\n'
92 | '| | | frame10.png\n'
93 | '| | | frame11.png\n'
94 | '| | ...\n'
95 | '| other-gt-interp/\n'
96 | '| | Beanbags\n'
97 | '| | | frame10i11.png\n'
98 | '| | Dimetrodon\n'
99 | '| | | frame10i11.png\n'
100 | '| | ...\n')
101 |
102 | _INPUT_PAIRS_FOLDERNAME = flags.DEFINE_string(
103 | 'input_pairs_foldername',
104 | default='other-data',
105 | help='Foldername containing the folders of the input frame pairs.')
106 |
107 | _GOLDEN_FOLDERNAME = flags.DEFINE_string(
108 | 'golden_foldername',
109 | default='other-gt-interp',
110 | help='Foldername containing the folders of the golden frame.')
111 |
112 | _OUTPUT_TFRECORD_FILEPATH = flags.DEFINE_string(
113 | 'output_tfrecord_filepath',
114 | default=None,
115 | required=True,
116 | help='Filepath to the output TFRecord file.')
117 |
118 | _NUM_SHARDS = flags.DEFINE_integer('num_shards',
119 | default=3,
120 | help='Number of shards used for the output.')
121 |
122 | # Image key -> basename for frame interpolator: start / middle / end frames.
123 | _INTERPOLATOR_IMAGES_MAP = {
124 | 'frame_0': 'frame10.png',
125 | 'frame_1': 'frame10i11.png',
126 | 'frame_2': 'frame11.png',
127 | }
128 |
129 |
130 | def main(unused_argv):
131 | """Creates and runs a Beam pipeline to write frame triplets as a TFRecord."""
132 | # Collect the list of folder paths containing the input and golden frames.
133 | pairs_list = tf.io.gfile.listdir(
134 | os.path.join(_INPUT_DIR.value, _INPUT_PAIRS_FOLDERNAME.value))
135 |
136 | folder_names = [
137 | _INPUT_PAIRS_FOLDERNAME.value, _GOLDEN_FOLDERNAME.value,
138 | _INPUT_PAIRS_FOLDERNAME.value
139 | ]
140 | triplet_dicts = []
141 | for pair in pairs_list:
142 | triplet_dict = {
143 | image_key: os.path.join(_INPUT_DIR.value, folder, pair, image_basename)
144 | for folder, (image_key, image_basename
145 | ) in zip(folder_names, _INTERPOLATOR_IMAGES_MAP.items())
146 | }
147 | triplet_dicts.append(triplet_dict)
148 |
149 | p = beam.Pipeline('DirectRunner')
150 | (p | 'ReadInputTripletDicts' >> beam.Create(triplet_dicts) # pylint: disable=expression-not-assigned
151 | | 'GenerateSingleExample' >> beam.ParDo(
152 | util.ExampleGenerator(_INTERPOLATOR_IMAGES_MAP))
153 | | 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
154 | file_path_prefix=_OUTPUT_TFRECORD_FILEPATH.value,
155 | num_shards=_NUM_SHARDS.value,
156 | coder=beam.coders.BytesCoder()))
157 | result = p.run()
158 | result.wait_until_finish()
159 |
160 | logging.info('Succeeded in creating the output TFRecord file: \'%s@%s\'.',
161 | _OUTPUT_TFRECORD_FILEPATH.value, str(_NUM_SHARDS.value))
162 |
163 | if __name__ == '__main__':
164 | app.run(main)
165 |
--------------------------------------------------------------------------------
/datasets/create_ucf101_tfrecord.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Beam pipeline that generates UCF101 `interp_test` triplet TFRecords.
16 |
17 | UCF101 interpolation evaluation dataset consists of 379 triplets, with the
18 | middle frame being the golden intermediate. The dataset is available here:
19 | https://people.cs.umass.edu/~hzjiang/projects/superslomo/UCF101_results.zip.
20 |
21 | Input to the script is the root folder that contains the unzipped
22 | `UCF101_results` folder.
23 |
24 | Output TFRecord is a tf.train.Example proto of each image triplet.
25 | The feature_map takes the form:
26 | feature_map {
27 | 'frame_0/encoded':
28 | tf.io.FixedLenFeature((), tf.string, default_value=''),
29 | 'frame_0/format':
30 | tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
31 | 'frame_0/height':
32 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
33 | 'frame_0/width':
34 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
35 | 'frame_1/encoded':
36 | tf.io.FixedLenFeature((), tf.string, default_value=''),
37 | 'frame_1/format':
38 | tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
39 | 'frame_1/height':
40 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
41 | 'frame_1/width':
42 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
43 | 'frame_2/encoded':
44 | tf.io.FixedLenFeature((), tf.string, default_value=''),
45 | 'frame_2/format':
46 | tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
47 | 'frame_2/height':
48 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
49 | 'frame_2/width':
50 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
51 | 'path':
52 | tf.io.FixedLenFeature((), tf.string, default_value=''),
53 | }
54 |
55 | Usage example:
56 | python3 -m frame_interpolation.datasets.create_ucf101_tfrecord \
57 | --input_dir=<root folder of ucf101_interp_ours> \
58 | --output_tfrecord_filepath=<output TFRecord filepath>
59 | """
60 |
61 | import os
62 |
63 | from . import util
64 | from absl import app
65 | from absl import flags
66 | from absl import logging
67 | import apache_beam as beam
68 | import tensorflow as tf
69 |
70 | _INPUT_DIR = flags.DEFINE_string(
71 | 'input_dir',
72 | default='/root/path/to/UCF101_results/ucf101_interp_ours',
73 | help='Path to the root directory of the `UCF101_results` of the UCF101 '
74 | 'interpolation evaluation data. '
75 | 'We expect the data to have been downloaded and unzipped. \n'
76 | 'Folder structures:\n'
77 | '| raw_UCF101_results/\n'
78 | '| ucf101_interp_ours/\n'
79 | '| | 1/\n'
80 | '| | | frame_00.png\n'
81 | '| | | frame_01_gt.png\n'
82 | '| | | frame_01_ours.png\n'
83 | '| | | frame_02.png\n'
84 | '| | 2/\n'
85 | '| | | frame_00.png\n'
86 | '| | | frame_01_gt.png\n'
87 | '| | | frame_01_ours.png\n'
88 | '| | | frame_02.png\n'
89 | '| | ...\n'
90 | '| ucf101_sepconv/\n'
91 | '| ...\n')
92 |
93 | _OUTPUT_TFRECORD_FILEPATH = flags.DEFINE_string(
94 | 'output_tfrecord_filepath',
95 | default=None,
96 | required=True,
97 | help='Filepath to the output TFRecord file.')
98 |
99 | _NUM_SHARDS = flags.DEFINE_integer('num_shards',
100 | default=2,
101 | help='Number of shards used for the output.')
102 |
103 | # Image key -> basename for frame interpolator: start / middle / end frames.
104 | _INTERPOLATOR_IMAGES_MAP = {
105 | 'frame_0': 'frame_00.png',
106 | 'frame_1': 'frame_01_gt.png',
107 | 'frame_2': 'frame_02.png',
108 | }
109 |
110 |
111 | def main(unused_argv):
112 | """Creates and runs a Beam pipeline to write frame triplets as a TFRecord."""
113 | # Collect the list of folder paths containing the input and golden frames.
114 | triplets_list = tf.io.gfile.listdir(_INPUT_DIR.value)
115 |
116 | triplet_dicts = []
117 | for triplet in triplets_list:
118 | triplet_dicts.append({
119 | image_key: os.path.join(_INPUT_DIR.value, triplet, image_basename)
120 | for image_key, image_basename in _INTERPOLATOR_IMAGES_MAP.items()
121 | })
122 |
123 | p = beam.Pipeline('DirectRunner')
124 | (p | 'ReadInputTripletDicts' >> beam.Create(triplet_dicts) # pylint: disable=expression-not-assigned
125 | | 'GenerateSingleExample' >> beam.ParDo(
126 | util.ExampleGenerator(_INTERPOLATOR_IMAGES_MAP))
127 | | 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
128 | file_path_prefix=_OUTPUT_TFRECORD_FILEPATH.value,
129 | num_shards=_NUM_SHARDS.value,
130 | coder=beam.coders.BytesCoder()))
131 | result = p.run()
132 | result.wait_until_finish()
133 |
134 | logging.info('Succeeded in creating the output TFRecord file: \'%s@%s\'.',
135 | _OUTPUT_TFRECORD_FILEPATH.value, str(_NUM_SHARDS.value))
136 |
137 | if __name__ == '__main__':
138 | app.run(main)
139 |
--------------------------------------------------------------------------------
/datasets/create_vimeo90K_tfrecord.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Beam pipeline that generates Vimeo-90K (train or test) triplet TFRecords.
16 |
17 | Vimeo-90K dataset is built upon 5,846 videos downloaded from vimeo.com. The list
18 | of the original video links are available here:
19 | https://github.com/anchen1011/toflow/blob/master/data/original_vimeo_links.txt.
20 | Each video is further cropped into a fixed spatial size of (448 x 256) to create
21 | 89,000 video clips.
22 |
23 | The Vimeo-90K dataset is designed for four video processing tasks. This script
24 | creates the TFRecords of frame triplets for frame interpolation task.
25 |
26 | Temporal frame interpolation triplet dataset:
27 | - 73,171 triplets of size (448x256) extracted from 15K subsets of Vimeo-90K.
28 | - The triplets are pre-split into (train,test) = (51313,3782)
29 | - Download links:
30 | Test-set: http://data.csail.mit.edu/tofu/testset/vimeo_interp_test.zip
31 | Train+test-set: http://data.csail.mit.edu/tofu/dataset/vimeo_triplet.zip
32 |
33 | For more information, see the arXiv paper, project page or the GitHub link.
34 | @article{xue17toflow,
35 | author = {Xue, Tianfan and
36 | Chen, Baian and
37 | Wu, Jiajun and
38 | Wei, Donglai and
39 | Freeman, William T},
40 | title = {Video Enhancement with Task-Oriented Flow},
41 | journal = {arXiv},
42 | year = {2017}
43 | }
44 | Project: http://toflow.csail.mit.edu/
45 | GitHub: https://github.com/anchen1011/toflow
46 |
47 | Inputs to the script are (1) the directory to the downloaded and unzipped folder
48 | (2) the filepath of the text-file that lists the subfolders of the triplets.
49 |
50 | Output TFRecord is a tf.train.Example proto of each image triplet.
51 | The feature_map takes the form:
52 | feature_map {
53 | 'frame_0/encoded':
54 | tf.io.FixedLenFeature((), tf.string, default_value=''),
55 | 'frame_0/format':
56 | tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
57 | 'frame_0/height':
58 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
59 | 'frame_0/width':
60 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
61 | 'frame_1/encoded':
62 | tf.io.FixedLenFeature((), tf.string, default_value=''),
63 | 'frame_1/format':
64 | tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
65 | 'frame_1/height':
66 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
67 | 'frame_1/width':
68 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
69 | 'frame_2/encoded':
70 | tf.io.FixedLenFeature((), tf.string, default_value=''),
71 | 'frame_2/format':
72 | tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
73 | 'frame_2/height':
74 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
75 | 'frame_2/width':
76 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
77 | 'path':
78 | tf.io.FixedLenFeature((), tf.string, default_value=''),
79 | }
80 |
81 | Usage example:
82 | python3 -m frame_interpolation.datasets.create_vimeo90K_tfrecord \
83 | --input_dir=<root folder of vimeo_triplet/sequences> \
84 | --input_triplet_list_filepath=<path to tri_trainlist.txt or tri_testlist.txt> \
85 | --output_tfrecord_filepath=<output TFRecord filepath>
86 | """
87 | import os
88 |
89 | from . import util
90 | from absl import app
91 | from absl import flags
92 | from absl import logging
93 | import apache_beam as beam
94 | import numpy as np
95 | import tensorflow as tf
96 |
97 |
98 | _INPUT_DIR = flags.DEFINE_string(
99 | 'input_dir',
100 | default='/path/to/raw_vimeo_interp/sequences',
101 | help='Path to the root directory of the vimeo frame interpolation dataset. '
102 | 'We expect the data to have been downloaded and unzipped.\n'
103 | 'Folder structures:\n'
104 | '| raw_vimeo_dataset/\n'
105 | '| sequences/\n'
106 | '| | 00001\n'
107 | '| | | 0389/\n'
108 | '| | | | im1.png\n'
109 | '| | | | im2.png\n'
110 | '| | | | im3.png\n'
111 | '| | | ...\n'
112 | '| | 00002/\n'
113 | '| | ...\n'
114 | '| readme.txt\n'
115 | '| tri_trainlist.txt\n'
116 | '| tri_testlist.txt \n')
117 |
118 | _INPUT_TRIPLET_LIST_FILEPATH = flags.DEFINE_string(
119 | 'input_triplet_list_filepath',
120 | default='/path/to/raw_vimeo_dataset/tri_{test|train}list.txt',
121 | help='Text file containing a list of sub-directories of input triplets.')
122 |
123 | _OUTPUT_TFRECORD_FILEPATH = flags.DEFINE_string(
124 | 'output_tfrecord_filepath',
125 | default=None,
126 | help='Filepath to the output TFRecord file.')
127 |
128 | _NUM_SHARDS = flags.DEFINE_integer('num_shards',
129 | default=200, # set to 3 for vimeo_test, and 200 for vimeo_train.
130 | help='Number of shards used for the output.')
131 |
132 | # Image key -> basename for frame interpolator: start / middle / end frames.
133 | _INTERPOLATOR_IMAGES_MAP = {
134 | 'frame_0': 'im1.png',
135 | 'frame_1': 'im2.png',
136 | 'frame_2': 'im3.png',
137 | }
138 |
139 |
140 | def main(unused_argv):
141 | """Creates and runs a Beam pipeline to write frame triplets as a TFRecord."""
142 | with tf.io.gfile.GFile(_INPUT_TRIPLET_LIST_FILEPATH.value, 'r') as fid:
143 | triplets_list = np.loadtxt(fid, dtype=str)
144 |
145 | triplet_dicts = []
146 | for triplet in triplets_list:
147 | triplet_dict = {
148 | image_key: os.path.join(_INPUT_DIR.value, triplet, image_basename)
149 | for image_key, image_basename in _INTERPOLATOR_IMAGES_MAP.items()
150 | }
151 | triplet_dicts.append(triplet_dict)
152 | p = beam.Pipeline('DirectRunner')
153 | (p | 'ReadInputTripletDicts' >> beam.Create(triplet_dicts) # pylint: disable=expression-not-assigned
154 | | 'GenerateSingleExample' >> beam.ParDo(
155 | util.ExampleGenerator(_INTERPOLATOR_IMAGES_MAP))
156 | | 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
157 | file_path_prefix=_OUTPUT_TFRECORD_FILEPATH.value,
158 | num_shards=_NUM_SHARDS.value,
159 | coder=beam.coders.BytesCoder()))
160 | result = p.run()
161 | result.wait_until_finish()
162 |
163 | logging.info('Succeeded in creating the output TFRecord file: \'%s@%s\'.',
164 | _OUTPUT_TFRECORD_FILEPATH.value, str(_NUM_SHARDS.value))
165 |
166 | if __name__ == '__main__':
167 | app.run(main)
168 |
--------------------------------------------------------------------------------
/datasets/create_xiph_tfrecord.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Beam pipeline that generates Xiph triplet TFRecords.
16 |
17 | Xiph is a frame sequence dataset commonly used to assess video compression. See
18 | here: https://media.xiph.org/video/derf/
19 |
20 | The SoftSplat paper selected the eight 4K clips with the most motion and
21 | extracted the first 100 frames from each clip. Each frame is then either resized
22 | from 4K to 2K, or center-cropped to 2K, before interpolating
23 | the even frames from the odd frames. These datasets are denoted as `Xiph-2K`
24 | and `Xiph-4K` respectively. For more information see the project page:
25 | https://github.com/sniklaus/softmax-splatting
26 |
27 | Input is the root folder that contains the 800 frames of the eight clips. Set
28 | center_crop_factor=2 and scale_factor=1 to generate `Xiph-4K`, and scale_factor=2,
29 | center_crop_factor=1 to generate `Xiph-2K`. The script defaults to `Xiph-2K`.
30 |
31 | Output TFRecord is a tf.train.Example proto of each image triplet.
32 | The feature_map takes the form:
33 | feature_map {
34 | 'frame_0/encoded':
35 | tf.io.FixedLenFeature((), tf.string, default_value=''),
36 | 'frame_0/format':
37 | tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
38 | 'frame_0/height':
39 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
40 | 'frame_0/width':
41 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
42 | 'frame_1/encoded':
43 | tf.io.FixedLenFeature((), tf.string, default_value=''),
44 | 'frame_1/format':
45 | tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
46 | 'frame_1/height':
47 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
48 | 'frame_1/width':
49 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
50 | 'frame_2/encoded':
51 | tf.io.FixedLenFeature((), tf.string, default_value=''),
52 | 'frame_2/format':
53 | tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
54 | 'frame_2/height':
55 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
56 | 'frame_2/width':
57 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
58 | 'path':
59 | tf.io.FixedLenFeature((), tf.string, default_value=''),
60 | }
61 |
62 | Usage example:
63 | python3 -m frame_interpolation.datasets.create_xiph_tfrecord \
64 | --input_dir=<root folder of the selected Xiph clips> \
65 | --scale_factor=<downsample factor, e.g. 2 for Xiph-2K> \
66 | --center_crop_factor=<center-crop factor, e.g. 2 for Xiph-4K> \
67 | --output_tfrecord_filepath=<output TFRecord filepath>
68 | """
69 | import os
70 |
71 | from . import util
72 | from absl import app
73 | from absl import flags
74 | from absl import logging
75 | import apache_beam as beam
76 | import tensorflow as tf
77 |
78 | _INPUT_DIR = flags.DEFINE_string(
79 | 'input_dir',
80 | default='/root/path/to/selected/xiph/clips',
81 | help='Path to the root directory of the `Xiph` interpolation evaluation '
82 | 'data. We expect the data to have been downloaded and unzipped.')
83 | _CENTER_CROP_FACTOR = flags.DEFINE_integer(
84 | 'center_crop_factor',
85 | default=1,
86 | help='Factor to center crop image. If set to 2, an image of the same '
87 | 'resolution as the inputs but half the size is created.')
88 | _SCALE_FACTOR = flags.DEFINE_integer(
89 | 'scale_factor',
90 | default=2,
91 | help='Factor to downsample frames.')
92 | _NUM_CLIPS = flags.DEFINE_integer(
93 | 'num_clips', default=8, help='Number of clips.')
94 | _NUM_FRAMES = flags.DEFINE_integer(
95 | 'num_frames', default=100, help='Number of frames per clip.')
96 | _OUTPUT_TFRECORD_FILEPATH = flags.DEFINE_string(
97 | 'output_tfrecord_filepath',
98 | default=None,
99 | required=True,
100 | help='Filepath to the output TFRecord file.')
101 | _NUM_SHARDS = flags.DEFINE_integer('num_shards',
102 | default=2,
103 | help='Number of shards used for the output.')
104 |
105 | # Image key -> offset for frame interpolator: start / middle / end frame offset.
106 | _INTERPOLATOR_IMAGES_MAP = {
107 | 'frame_0': -1,
108 | 'frame_1': 0,
109 | 'frame_2': 1,
110 | }
111 |
112 |
113 | def main(unused_argv):
114 | """Creates and runs a Beam pipeline to write frame triplets as a TFRecord."""
115 | # Collect the list of frame filenames.
116 | frames_list = sorted(tf.io.gfile.listdir(_INPUT_DIR.value))
117 |
118 | # Collect the triplets, even frames serving as golden to interpolate odds.
119 | triplets_dict = []
120 | for clip_index in range(_NUM_CLIPS.value):
121 | for frame_index in range(1, _NUM_FRAMES.value - 1, 2):
122 | index = clip_index * _NUM_FRAMES.value + frame_index
123 | triplet_dict = {
124 | image_key: os.path.join(_INPUT_DIR.value,
125 | frames_list[index + image_offset])
126 | for image_key, image_offset in _INTERPOLATOR_IMAGES_MAP.items()
127 | }
128 | triplets_dict.append(triplet_dict)
129 |
130 | p = beam.Pipeline('DirectRunner')
131 | (p | 'ReadInputTripletDicts' >> beam.Create(triplets_dict) # pylint: disable=expression-not-assigned
132 | | 'GenerateSingleExample' >> beam.ParDo(
133 | util.ExampleGenerator(_INTERPOLATOR_IMAGES_MAP, _SCALE_FACTOR.value,
134 | _CENTER_CROP_FACTOR.value))
135 | | 'WriteToTFRecord' >> beam.io.tfrecordio.WriteToTFRecord(
136 | file_path_prefix=_OUTPUT_TFRECORD_FILEPATH.value,
137 | num_shards=_NUM_SHARDS.value,
138 | coder=beam.coders.BytesCoder()))
139 | result = p.run()
140 | result.wait_until_finish()
141 |
142 | logging.info('Succeeded in creating the output TFRecord file: \'%s@%s\'.',
143 | _OUTPUT_TFRECORD_FILEPATH.value, str(_NUM_SHARDS.value))
144 |
145 | if __name__ == '__main__':
146 | app.run(main)
147 |
--------------------------------------------------------------------------------
/datasets/util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Utility functions for creating a tf.train.Example proto of image triplets."""
16 |
17 | import io
18 | import os
19 | from typing import Any, List, Mapping, Optional
20 |
21 | from absl import logging
22 | import apache_beam as beam
23 | import numpy as np
24 | import PIL.Image
25 | import six
26 | from skimage import transform
27 | import tensorflow as tf
28 |
29 | _UINT8_MAX_F = float(np.iinfo(np.uint8).max)
30 | _GAMMA = 2.2
31 |
32 |
33 | def _resample_image(image: np.ndarray, resample_image_width: int,
34 | resample_image_height: int) -> np.ndarray:
35 | """Re-samples and returns an `image` to be `resample_image_size`."""
36 | # Convert image from uint8 gamma [0..255] to float linear [0..1].
37 | image = image.astype(np.float32) / _UINT8_MAX_F
38 | image = np.power(np.clip(image, 0, 1), _GAMMA)
39 |
40 | # Re-size the image
41 | resample_image_size = (resample_image_height, resample_image_width)
42 | image = transform.resize_local_mean(image, resample_image_size)
43 |
44 | # Convert back from float linear [0..1] to uint8 gamma [0..255].
45 | image = np.power(np.clip(image, 0, 1), 1.0 / _GAMMA)
46 | image = np.clip(image * _UINT8_MAX_F + 0.5, 0.0,
47 | _UINT8_MAX_F).astype(np.uint8)
48 | return image
49 |
50 |
51 | def generate_image_triplet_example(
52 | triplet_dict: Mapping[str, str],
53 | scale_factor: int = 1,
54 | center_crop_factor: int = 1) -> Optional[tf.train.Example]:
55 | """Generates and serializes a tf.train.Example proto from an image triplet.
56 |
57 | Default setting creates a triplet Example with the input images unchanged.
58 | Images are processed in the order of center-crop then downscale.
59 |
60 | Args:
61 | triplet_dict: A dict of image key to filepath of the triplet images.
62 | scale_factor: An integer scale factor to isotropically downsample images.
63 | center_crop_factor: An integer cropping factor to center crop images with
64 | the original resolution but isotropically downsized by the factor.
65 |
66 | Returns:
67 | tf.train.Example proto, or None upon error.
68 |
69 | Raises:
70 | ValueError if triplet_dict length is different from three or the scale input
71 | arguments are non-positive.
72 | """
73 | if len(triplet_dict) != 3:
74 | raise ValueError(
75 | f'Length of triplet_dict must be exactly 3, not {len(triplet_dict)}.')
76 |
77 | if scale_factor <= 0 or center_crop_factor <= 0:
78 | raise ValueError(f'(scale_factor, center_crop_factor) must be positive, '
79 | f'Not ({scale_factor}, {center_crop_factor}).')
80 |
81 | feature = {}
82 |
83 | # Keep track of the path where the images came from for debugging purposes.
84 | mid_frame_path = os.path.dirname(triplet_dict['frame_1'])
85 | feature['path'] = tf.train.Feature(
86 | bytes_list=tf.train.BytesList(value=[six.ensure_binary(mid_frame_path)]))
87 |
88 | for image_key, image_path in triplet_dict.items():
89 | if not tf.io.gfile.exists(image_path):
90 | logging.error('File not found: %s', image_path)
91 | return None
92 |
93 | # Note: we need both the raw bytes and the image size.
94 | # PIL.Image does not expose a method to grab the original bytes.
95 | # (Also it is not aware of non-local file systems.)
96 | # So we read with tf.io.gfile.GFile to get the bytes, and then wrap the
97 | # bytes in BytesIO to let PIL.Image open the image.
98 | try:
99 | byte_array = tf.io.gfile.GFile(image_path, 'rb').read()
100 | except tf.errors.InvalidArgumentError:
101 | logging.exception('Cannot read image file: %s', image_path)
102 | return None
103 | try:
104 | pil_image = PIL.Image.open(io.BytesIO(byte_array))
105 | except PIL.UnidentifiedImageError:
106 | logging.exception('Cannot decode image file: %s', image_path)
107 | return None
108 | width, height = pil_image.size
109 | pil_image_format = pil_image.format
110 |
111 | # Optionally center-crop images and downsize images
112 | # by `center_crop_factor`.
113 | if center_crop_factor > 1:
114 | image = np.array(pil_image)
115 | quarter_height = image.shape[0] // (2 * center_crop_factor)
116 | quarter_width = image.shape[1] // (2 * center_crop_factor)
117 | image = image[quarter_height:-quarter_height,
118 | quarter_width:-quarter_width, :]
119 | pil_image = PIL.Image.fromarray(image)
120 |
121 | # Update image properties.
122 | height, width, _ = image.shape
123 | buffer = io.BytesIO()
124 | try:
125 | pil_image.save(buffer, format='PNG')
126 | except OSError:
127 | logging.exception('Cannot encode image file: %s', image_path)
128 | return None
129 | byte_array = buffer.getvalue()
130 |
131 | # Optionally downsample images by `scale_factor`.
132 | if scale_factor > 1:
133 | image = np.array(pil_image)
134 | image = _resample_image(image, image.shape[1] // scale_factor,
135 | image.shape[0] // scale_factor)
136 | pil_image = PIL.Image.fromarray(image)
137 |
138 | # Update image properties.
139 | height, width, _ = image.shape
140 | buffer = io.BytesIO()
141 | try:
142 | pil_image.save(buffer, format='PNG')
143 | except OSError:
144 | logging.exception('Cannot encode image file: %s', image_path)
145 | return None
146 | byte_array = buffer.getvalue()
147 |
148 | # Create tf Features.
149 | image_feature = tf.train.Feature(
150 | bytes_list=tf.train.BytesList(value=[byte_array]))
151 | height_feature = tf.train.Feature(
152 | int64_list=tf.train.Int64List(value=[height]))
153 | width_feature = tf.train.Feature(
154 | int64_list=tf.train.Int64List(value=[width]))
155 | encoding = tf.train.Feature(
156 | bytes_list=tf.train.BytesList(
157 | value=[six.ensure_binary(pil_image_format.lower())]))
158 |
159 | # Update feature map.
160 | feature[f'{image_key}/encoded'] = image_feature
161 | feature[f'{image_key}/format'] = encoding
162 | feature[f'{image_key}/height'] = height_feature
163 | feature[f'{image_key}/width'] = width_feature
164 |
165 | # Create tf Example.
166 | features = tf.train.Features(feature=feature)
167 | example = tf.train.Example(features=features)
168 | return example
169 |
170 |
171 | class ExampleGenerator(beam.DoFn):
172 | """Generate a tf.train.Example per input image triplet filepaths."""
173 |
174 | def __init__(self,
175 | images_map: Mapping[str, Any],
176 | scale_factor: int = 1,
177 | center_crop_factor: int = 1):
178 | """Initializes the map of 3 images to add to each tf.train.Example.
179 |
180 | Args:
181 | images_map: Map from image key to image filepath.
182 | scale_factor: A scale factor to downsample frames.
183 | center_crop_factor: A factor to centercrop and downsize frames.
184 | """
185 | super().__init__()
186 | self._images_map = images_map
187 | self._scale_factor = scale_factor
188 | self._center_crop_factor = center_crop_factor
189 |
190 | def process(self, triplet_dict: Mapping[str, str]) -> List[bytes]:
191 | """Generates a serialized tf.train.Example for a triplet of images.
192 |
193 | Args:
194 | triplet_dict: A dict of image key to filepath of the triplet images.
195 |
196 | Returns:
197 |       A single-element list with the serialized tf.train.Example proto, or an empty list on error. No shuffling is applied.
198 | """
199 | example = generate_image_triplet_example(triplet_dict, self._scale_factor,
200 | self._center_crop_factor)
201 | if example:
202 | return [example.SerializeToString()]
203 | else:
204 | return []
205 |
--------------------------------------------------------------------------------
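The helpers above can also be exercised outside of Beam. A minimal sketch, assuming the repository is importable as the frame_interpolation package (as in the module usage examples elsewhere in this repo) and using placeholder frame paths:

import tensorflow as tf
from frame_interpolation.datasets import util

# Placeholder paths; any local image triplet works.
triplet = {
    'frame_0': '/tmp/clip/frame_001.png',
    'frame_1': '/tmp/clip/frame_002.png',
    'frame_2': '/tmp/clip/frame_003.png',
}
# Downsample by 2x, no center crop (mirrors the Xiph 2K setting above).
example = util.generate_image_triplet_example(
    triplet, scale_factor=2, center_crop_factor=1)
if example is not None:
  # Writing one serialized proto per element is what beam WriteToTFRecord does.
  with tf.io.TFRecordWriter('/tmp/single_triplet.tfrecord') as writer:
    writer.write(example.SerializeToString())
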
/eval/config/middlebury.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | experiment.name = 'middlebury'
16 | evaluation.max_examples = -1
17 | evaluation.metrics = ['l1', 'l2', 'ssim', 'psnr']
18 | evaluation.tfrecord = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/middlebury_other.tfrecord@3'
19 |
--------------------------------------------------------------------------------
/eval/config/ucf101.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | experiment.name = 'ucf101'
16 | evaluation.max_examples = -1
17 | evaluation.metrics = ['l1', 'l2', 'ssim', 'psnr']
18 | evaluation.tfrecord = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/UCF101_interp_test.tfrecord@2'
19 |
--------------------------------------------------------------------------------
/eval/config/vimeo_90K.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | experiment.name = 'vimeo_90K'
16 | evaluation.max_examples = -1
17 | evaluation.metrics = ['l1', 'l2', 'ssim', 'psnr']
18 | evaluation.tfrecord = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_test.tfrecord@3'
19 |
--------------------------------------------------------------------------------
/eval/config/xiph_2K.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | experiment.name = 'xiph_2K'
16 | evaluation.max_examples = -1
17 | evaluation.metrics = ['l1', 'l2', 'ssim', 'psnr']
18 | evaluation.tfrecord = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_2K.tfrecord@2'
19 |
--------------------------------------------------------------------------------
/eval/config/xiph_4K.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | experiment.name = 'xiph_4K'
16 | evaluation.max_examples = -1
17 | evaluation.metrics = ['l1', 'l2', 'ssim', 'psnr']
18 | evaluation.tfrecord = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_4K.tfrecord@2'
19 |
--------------------------------------------------------------------------------
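In these configs, evaluation.tfrecord uses a 'path@num_shards' spec for a sharded TFRecord. A hedged sketch of expanding that spec into concrete filenames, assuming the default Beam shard name template ('-SSSSS-of-NNNNN') produced by the dataset scripts; the local path below is a placeholder, and how data_lib actually resolves the spec is not shown here:

import tensorflow as tf


def expand_sharded_path(sharded_path: str):
  """Expands 'prefix@N' into the N shard filenames, or returns [path] as-is."""
  if '@' not in sharded_path:
    return [sharded_path]
  prefix, num_shards = sharded_path.rsplit('@', 1)
  count = int(num_shards)
  return [f'{prefix}-{i:05d}-of-{count:05d}' for i in range(count)]


# Placeholder local copy of one of the eval datasets referenced above.
files = expand_sharded_path('/tmp/datasets/xiph_2K.tfrecord@2')
dataset = tf.data.TFRecordDataset(files)
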
/eval/eval_cli.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Evaluate the frame interpolation model from a tfrecord and store results.
16 |
17 | This script runs the inference on examples in a tfrecord and generates images
18 | and numeric results according to the gin config. For details, see the
19 | run_evaluation() function below.
20 |
21 | Usage example:
22 | python3 -m frame_interpolation.eval.eval_cli -- \
23 |     --gin_config <filepath of the evaluation gin config> \
24 |     --base_folder <base folder of training sessions> \
25 |     --label <the foldername of the training session>
26 |
27 | or
28 |
29 | python3 -m frame_interpolation.eval.eval_cli -- \
30 |     --gin_config <filepath of the evaluation gin config> \
31 |     --model_path <filepath of the TF2 saved model>
32 |
33 | The output is saved at the parent directory of the `model_path`:
34 | <parent directory of the model_path>/batch_eval/<experiment name>.
35 |
36 | The evaluation is run on a GPU by default. Add the `--mode` argument for others.
37 | """
38 | import collections
39 | import os
40 | from typing import Any, Dict
41 |
42 | from . import util
43 | from absl import app
44 | from absl import flags
45 | from absl import logging
46 | import gin.tf
47 | from ..losses import losses
48 | import numpy as np
49 | import tensorflow as tf
50 | from ..training import data_lib
51 |
52 |
53 | _GIN_CONFIG = flags.DEFINE_string('gin_config', None, 'Gin config file.')
54 | _LABEL = flags.DEFINE_string(
55 | 'label', None, 'Descriptive label for the training session to eval.')
56 | _BASE_FOLDER = flags.DEFINE_string('base_folder', None,
57 | 'Root folder of training sessions.')
58 | _MODEL_PATH = flags.DEFINE_string(
59 | name='model_path',
60 | default=None,
61 | help='The path of the TF2 saved model to use. If _MODEL_PATH argument is '
62 | 'directly specified, _LABEL and _BASE_FOLDER arguments will be ignored.')
63 | _OUTPUT_FRAMES = flags.DEFINE_boolean(
64 | name='output_frames',
65 | default=False,
66 |     help='If true, saves the inputs, ground-truth and interpolated frames.')
67 | _MODE = flags.DEFINE_enum('mode', 'gpu', ['cpu', 'gpu'],
68 | 'Device to run evaluations.')
69 |
70 |
71 | @gin.configurable('experiment')
72 | def _get_experiment_config(name) -> Dict[str, Any]:
73 | """Fetches the gin config."""
74 | return {
75 | 'name': name,
76 | }
77 |
78 |
79 | def _set_visible_devices():
80 | """Set the visible devices according to running mode."""
81 | mode_devices = tf.config.list_physical_devices(_MODE.value.upper())
82 | tf.config.set_visible_devices([], 'GPU')
83 | tf.config.set_visible_devices([], 'TPU')
84 | tf.config.set_visible_devices(mode_devices, _MODE.value.upper())
85 | return
86 |
87 |
88 | @gin.configurable('evaluation')
89 | def run_evaluation(model_path, tfrecord, output_dir, max_examples, metrics):
90 | """Runs the eval loop for examples in the tfrecord.
91 |
92 | The evaluation is run for the first 'max_examples' number of examples, and
93 | resulting images are stored into the given output_dir. Any tensor that
94 | appears like an image is stored with its name -- this may include intermediate
95 | results, depending on what the model outputs.
96 |
97 | Additionally, numeric results are stored into results.csv file within the same
98 | directory. This includes per-example metrics and the mean across the whole
99 | dataset.
100 |
101 | Args:
102 |     model_path: Directory of the TF2 saved model.
103 |     tfrecord: Path to the tfrecord eval data.
104 | output_dir: Directory to store the results into.
105 | max_examples: Maximum examples to evaluate.
106 | metrics: The names of loss functions to use.
107 | """
108 | model = tf.saved_model.load(model_path)
109 |
110 | # Store a 'readme.txt' that contains information on where the data came from.
111 | with tf.io.gfile.GFile(os.path.join(output_dir, 'readme.txt'), mode='w') as f:
112 | print('Results for:', file=f)
113 | print(f' model: {model_path}', file=f)
114 | print(f' tfrecord: {tfrecord}', file=f)
115 |
116 | with tf.io.gfile.GFile(
117 | os.path.join(output_dir, 'results.csv'), mode='w') as csv_file:
118 | test_losses = losses.test_losses(metrics, [
119 | 1.0,
120 | ] * len(metrics))
121 | title_row = ['key'] + list(test_losses)
122 | print(', '.join(title_row), file=csv_file)
123 |
124 | datasets = data_lib.create_eval_datasets(
125 | batch_size=1,
126 | files=[tfrecord],
127 | names=[os.path.basename(output_dir)],
128 | max_examples=max_examples)
129 | dataset = datasets[os.path.basename(output_dir)]
130 |
131 | all_losses = collections.defaultdict(list)
132 | for example in dataset:
133 | inputs = {
134 | 'x0': example['x0'],
135 | 'x1': example['x1'],
136 | 'time': example['time'][..., tf.newaxis],
137 | }
138 | prediction = model(inputs, training=False)
139 |
140 | # Get the key from encoded mid-frame path.
141 | path = example['path'][0].numpy().decode('utf-8')
142 | key = path.rsplit('.', 1)[0].rsplit(os.sep)[-1]
143 |
144 | # Combines both inputs and outputs into a single dictionary:
145 | combined = {**prediction, **example} if _OUTPUT_FRAMES.value else {}
146 | for name in combined:
147 | image = combined[name]
148 | if isinstance(image, tf.Tensor):
149 | # This saves any tensor that has a shape that can be interpreted
150 | # as an image, e.g. (1, H, W, C), where the batch dimension is always
151 | # 1, H and W are the image height and width, and C is either 1 or 3
152 | # (grayscale or color image).
153 | if len(image.shape) == 4 and (image.shape[-1] == 1 or
154 | image.shape[-1] == 3):
155 | util.write_image(
156 | os.path.join(output_dir, f'{key}_{name}.png'), image[0].numpy())
157 |
158 | # Evaluate losses if the dataset has ground truth 'y', otherwise just do
159 | # a visual eval.
160 | if 'y' in example:
161 | loss_values = []
162 | # Clip interpolator output to the range [0,1]. Clipping is done only
163 | # on the eval loop to get better metrics, but not on the training loop
164 | # so gradients are not killed.
165 | prediction['image'] = tf.clip_by_value(prediction['image'], 0., 1.)
166 | for loss_name, (loss_value_fn, loss_weight_fn) in test_losses.items():
167 | loss_value = loss_value_fn(example, prediction) * loss_weight_fn(0)
168 | loss_values.append(loss_value.numpy())
169 | all_losses[loss_name].append(loss_value.numpy())
170 | print(f'{key}, {str(loss_values)[1:-1]}', file=csv_file)
171 |
172 | if all_losses:
173 | totals = [np.mean(all_losses[loss_name]) for loss_name in test_losses]
174 | print(f'mean, {str(totals)[1:-1]}', file=csv_file)
175 | totals_dict = {
176 | loss_name: np.mean(all_losses[loss_name]) for loss_name in test_losses
177 | }
178 | logging.info('mean, %s', totals_dict)
179 |
180 |
181 | def main(argv):
182 | if len(argv) > 1:
183 | raise app.UsageError('Too many command-line arguments.')
184 |
185 | if _MODEL_PATH.value is not None:
186 | model_path = _MODEL_PATH.value
187 | else:
188 | model_path = os.path.join(_BASE_FOLDER.value, _LABEL.value, 'saved_model')
189 |
190 | gin.parse_config_files_and_bindings(
191 | config_files=[_GIN_CONFIG.value],
192 | bindings=None,
193 | skip_unknown=True)
194 |
195 | config = _get_experiment_config() # pylint: disable=no-value-for-parameter
196 | eval_name = config['name']
197 | output_dir = os.path.join(
198 | os.path.dirname(model_path), 'batch_eval', eval_name)
199 | logging.info('Creating output_dir @ %s ...', output_dir)
200 |
201 |   # Copy the gin config file to <output_dir>/config.gin.
202 | tf.io.gfile.makedirs(output_dir)
203 | tf.io.gfile.copy(
204 | _GIN_CONFIG.value, os.path.join(output_dir, 'config.gin'), overwrite=True)
205 |
206 | _set_visible_devices()
207 | logging.info('Evaluating %s on %s ...', eval_name, [
208 | el.name.split('/physical_device:')[-1]
209 | for el in tf.config.get_visible_devices()
210 | ])
211 | run_evaluation(model_path=model_path, output_dir=output_dir) # pylint: disable=no-value-for-parameter
212 |
213 | logging.info('Done. Evaluations saved @ %s.', output_dir)
214 |
215 | if __name__ == '__main__':
216 | app.run(main)
217 |
--------------------------------------------------------------------------------
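run_evaluation() above reports each configured metric as loss_fn(example, prediction) * weight_fn(0), using the (loss, weight) pairs returned by losses.test_losses. A small sketch of that step on toy tensors, assuming the frame_interpolation package is importable:

import tensorflow as tf
from frame_interpolation.losses import losses

metrics = ['l1', 'psnr', 'ssim']
test_losses = losses.test_losses(metrics, [1.0] * len(metrics))

# Stand-ins for a decoded eval example and a model prediction.
example = {'y': tf.random.uniform((1, 64, 64, 3))}
prediction = {'image': tf.random.uniform((1, 64, 64, 3))}

for name, (loss_fn, weight_fn) in test_losses.items():
  value = loss_fn(example, prediction) * weight_fn(0)
  print(name, float(value))
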
/eval/interpolator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A wrapper class for running a frame interpolation TF2 saved model.
16 |
17 | Usage:
18 | model_path='/tmp/saved_model/'
19 | it = Interpolator(model_path)
20 | result_batch = it.interpolate(image_batch_0, image_batch_1, batch_dt)
21 |
22 | Where image_batch_0 and image_batch_1 are numpy tensors with TF standard
23 | (B,H,W,C) layout, batch_dt is the sub-frame time in range [0,1], (B,) layout.
24 | """
25 | from typing import List, Optional
26 | import numpy as np
27 | import tensorflow as tf
28 |
29 |
30 | def _pad_to_align(x, align):
31 | """Pad image batch x so width and height divide by align.
32 |
33 | Args:
34 | x: Image batch to align.
35 | align: Number to align to.
36 |
37 | Returns:
38 | 1) An image padded so width % align == 0 and height % align == 0.
39 | 2) A bounding box that can be fed readily to tf.image.crop_to_bounding_box
40 | to undo the padding.
41 | """
42 | # Input checking.
43 | assert np.ndim(x) == 4
44 | assert align > 0, 'align must be a positive number.'
45 |
46 | height, width = x.shape[-3:-1]
47 | height_to_pad = (align - height % align) if height % align != 0 else 0
48 | width_to_pad = (align - width % align) if width % align != 0 else 0
49 |
50 | bbox_to_pad = {
51 | 'offset_height': height_to_pad // 2,
52 | 'offset_width': width_to_pad // 2,
53 | 'target_height': height + height_to_pad,
54 | 'target_width': width + width_to_pad
55 | }
56 | padded_x = tf.image.pad_to_bounding_box(x, **bbox_to_pad)
57 | bbox_to_crop = {
58 | 'offset_height': height_to_pad // 2,
59 | 'offset_width': width_to_pad // 2,
60 | 'target_height': height,
61 | 'target_width': width
62 | }
63 | return padded_x, bbox_to_crop
64 |
65 |
66 | def image_to_patches(image: np.ndarray, block_shape: List[int]) -> np.ndarray:
67 | """Folds an image into patches and stacks along the batch dimension.
68 |
69 | Args:
70 | image: The input image of shape [B, H, W, C].
71 | block_shape: The number of patches along the height and width to extract.
72 | Each patch is shaped (H/block_shape[0], W/block_shape[1])
73 |
74 | Returns:
75 | The extracted patches shaped [num_blocks, patch_height, patch_width,...],
76 | with num_blocks = block_shape[0] * block_shape[1].
77 | """
78 | block_height, block_width = block_shape
79 | num_blocks = block_height * block_width
80 |
81 | height, width, channel = image.shape[-3:]
82 | patch_height, patch_width = height//block_height, width//block_width
83 |
84 | assert height == (
85 | patch_height * block_height
86 | ), 'block_height=%d should evenly divide height=%d.'%(block_height, height)
87 | assert width == (
88 | patch_width * block_width
89 | ), 'block_width=%d should evenly divide width=%d.'%(block_width, width)
90 |
91 | patch_size = patch_height * patch_width
92 | paddings = 2*[[0, 0]]
93 |
94 | patches = tf.space_to_batch(image, [patch_height, patch_width], paddings)
95 | patches = tf.split(patches, patch_size, 0)
96 | patches = tf.stack(patches, axis=3)
97 | patches = tf.reshape(patches,
98 | [num_blocks, patch_height, patch_width, channel])
99 | return patches.numpy()
100 |
101 |
102 | def patches_to_image(patches: np.ndarray, block_shape: List[int]) -> np.ndarray:
103 | """Unfolds patches (stacked along batch) into an image.
104 |
105 | Args:
106 | patches: The input patches, shaped [num_patches, patch_H, patch_W, C].
107 | block_shape: The number of patches along the height and width to unfold.
108 | Each patch assumed to be shaped (H/block_shape[0], W/block_shape[1]).
109 |
110 | Returns:
111 | The unfolded image shaped [B, H, W, C].
112 | """
113 | block_height, block_width = block_shape
114 | paddings = 2 * [[0, 0]]
115 |
116 | patch_height, patch_width, channel = patches.shape[-3:]
117 | patch_size = patch_height * patch_width
118 |
119 | patches = tf.reshape(patches,
120 | [1, block_height, block_width, patch_size, channel])
121 | patches = tf.split(patches, patch_size, axis=3)
122 | patches = tf.stack(patches, axis=0)
123 | patches = tf.reshape(patches,
124 | [patch_size, block_height, block_width, channel])
125 | image = tf.batch_to_space(patches, [patch_height, patch_width], paddings)
126 | return image.numpy()
127 |
128 |
129 | class Interpolator:
130 | """A class for generating interpolated frames between two input frames.
131 |
132 | Uses TF2 saved model format.
133 | """
134 |
135 | def __init__(self, model_path: str,
136 | align: Optional[int] = None,
137 | block_shape: Optional[List[int]] = None) -> None:
138 | """Loads a saved model.
139 |
140 | Args:
141 |       model_path: Path to the TF2 saved model. If None is provided, uses
142 |         the default model.
143 |       align: If >1, pad the input size so it is evenly divisible by this
144 |         value before inference.
145 |       block_shape: Number of patches along the (height, width) to sub-divide
146 |         input images.
147 | """
148 | self._model = tf.compat.v2.saved_model.load(model_path)
149 | self._align = align or None
150 | self._block_shape = block_shape or None
151 |
152 | def interpolate(self, x0: np.ndarray, x1: np.ndarray,
153 | dt: np.ndarray) -> np.ndarray:
154 | """Generates an interpolated frame between given two batches of frames.
155 |
156 | All input tensors should be np.float32 datatype.
157 |
158 | Args:
159 | x0: First image batch. Dimensions: (batch_size, height, width, channels)
160 | x1: Second image batch. Dimensions: (batch_size, height, width, channels)
161 | dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)
162 |
163 | Returns:
164 | The result with dimensions (batch_size, height, width, channels).
165 | """
166 | if self._align is not None:
167 | x0, bbox_to_crop = _pad_to_align(x0, self._align)
168 | x1, _ = _pad_to_align(x1, self._align)
169 |
170 | inputs = {'x0': x0, 'x1': x1, 'time': dt[..., np.newaxis]}
171 | result = self._model(inputs, training=False)
172 | image = result['image']
173 |
174 | if self._align is not None:
175 | image = tf.image.crop_to_bounding_box(image, **bbox_to_crop)
176 | return image.numpy()
177 |
178 | def __call__(self, x0: np.ndarray, x1: np.ndarray,
179 | dt: np.ndarray) -> np.ndarray:
180 | """Generates an interpolated frame between given two batches of frames.
181 |
182 | All input tensors should be np.float32 datatype.
183 |
184 | Args:
185 | x0: First image batch. Dimensions: (batch_size, height, width, channels)
186 | x1: Second image batch. Dimensions: (batch_size, height, width, channels)
187 | dt: Sub-frame time. Range [0,1]. Dimensions: (batch_size,)
188 |
189 | Returns:
190 | The result with dimensions (batch_size, height, width, channels).
191 | """
192 | if self._block_shape is not None and np.prod(self._block_shape) > 1:
193 |       # Subdivide high-res images into manageable non-overlapping patches.
194 | x0_patches = image_to_patches(x0, self._block_shape)
195 | x1_patches = image_to_patches(x1, self._block_shape)
196 |
197 | # Run the interpolator on each patch pair.
198 | output_patches = []
199 | for image_0, image_1 in zip(x0_patches, x1_patches):
200 | mid_patch = self.interpolate(image_0[np.newaxis, ...],
201 | image_1[np.newaxis, ...], dt)
202 | output_patches.append(mid_patch)
203 |
204 | # Reconstruct interpolated image by stitching interpolated patches.
205 | output_patches = np.concatenate(output_patches, axis=0)
206 | return patches_to_image(output_patches, self._block_shape)
207 | else:
208 | # Invoke the interpolator once.
209 | return self.interpolate(x0, x1, dt)
210 |
--------------------------------------------------------------------------------
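A short usage sketch of the Interpolator wrapper above, padding to a multiple of 64 and splitting each frame into a 2x2 grid of patches; the saved-model path is a placeholder and the inputs are random stand-in frames:

import numpy as np
from frame_interpolation.eval import interpolator as interpolator_lib

interpolator = interpolator_lib.Interpolator(
    '/tmp/film_saved_model', align=64, block_shape=[2, 2])

# Random stand-ins for two (batch, height, width, channel) float32 frames.
x0 = np.random.rand(1, 720, 1280, 3).astype(np.float32)
x1 = np.random.rand(1, 720, 1280, 3).astype(np.float32)
dt = np.full((1,), 0.5, dtype=np.float32)

mid_frame = interpolator(x0, x1, dt)  # (1, 720, 1280, 3)
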
/eval/interpolator_cli.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Runs the FILM frame interpolator on a pair of frames on beam.
16 |
17 | This script is used to evaluate the output quality of the FILM TensorFlow frame
18 | interpolator. Optionally, it outputs a video of the interpolated frames.
19 |
20 | A beam pipeline for invoking the frame interpolator on a set of directories
21 | identified by a glob (--pattern). Each directory is expected to contain two
22 | input frames that are the inputs to the frame interpolator. If a directory has
23 | more than two frames, then each contiguous frame pair is treated as input to
24 | generate in-between frames.
25 |
26 | The output video is stored as interpolated.mp4 in each directory. The number of
27 | frames is determined by --times_to_interpolate, which controls the number of
28 | times the frame interpolator is invoked. When the number of input frames is 2,
29 | the number of output frames is 2^times_to_interpolate+1.
30 |
31 | This expects a directory structure such as:
32 | /01/frame1.png
33 | frame2.png
34 | /02/frame1.png
35 | frame2.png
36 | /03/frame1.png
37 | frame2.png
38 | ...
39 |
40 | And will produce:
41 | /01/interpolated_frames/frame0.png
42 | frame1.png
43 | frame2.png
44 | /02/interpolated_frames/frame0.png
45 | frame1.png
46 | frame2.png
47 | /03/interpolated_frames/frame0.png
48 | frame1.png
49 | frame2.png
50 | ...
51 |
52 | And optionally will produce:
53 | /01/interpolated.mp4
54 | /02/interpolated.mp4
55 | /03/interpolated.mp4
56 | ...
57 |
58 | Usage example:
59 | python3 -m frame_interpolation.eval.interpolator_cli \
60 |    --model_path <filepath of the TF2 saved model> \
61 |    --pattern "<root directory of input frame directories>/*" \
62 |    --times_to_interpolate <number of recursive midpoint interpolations>
63 | """
64 |
65 | import functools
66 | import os
67 | from typing import List, Sequence
68 |
69 | from . import interpolator as interpolator_lib
70 | from . import util
71 | from absl import app
72 | from absl import flags
73 | from absl import logging
74 | import apache_beam as beam
75 | import mediapy as media
76 | import natsort
77 | import numpy as np
78 | import tensorflow as tf
79 | from tqdm.auto import tqdm
80 |
81 | # Controls TF_CPP log level.
82 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
83 |
84 |
85 | _PATTERN = flags.DEFINE_string(
86 | name='pattern',
87 | default=None,
88 | help='The pattern to determine the directories with the input frames.',
89 | required=True)
90 | _MODEL_PATH = flags.DEFINE_string(
91 | name='model_path',
92 | default=None,
93 | help='The path of the TF2 saved model to use.')
94 | _TIMES_TO_INTERPOLATE = flags.DEFINE_integer(
95 | name='times_to_interpolate',
96 | default=5,
97 | help='The number of times to run recursive midpoint interpolation. '
98 | 'The number of output frames will be 2^times_to_interpolate+1.')
99 | _FPS = flags.DEFINE_integer(
100 | name='fps',
101 | default=30,
102 | help='Frames per second to play interpolated videos in slow motion.')
103 | _ALIGN = flags.DEFINE_integer(
104 | name='align',
105 | default=64,
106 | help='If >1, pad the input size so it is evenly divisible by this value.')
107 | _BLOCK_HEIGHT = flags.DEFINE_integer(
108 | name='block_height',
109 | default=1,
110 | help='An int >= 1, number of patches along height, '
111 | 'patch_height = height//block_height, should be evenly divisible.')
112 | _BLOCK_WIDTH = flags.DEFINE_integer(
113 | name='block_width',
114 | default=1,
115 | help='An int >= 1, number of patches along width, '
116 | 'patch_width = width//block_width, should be evenly divisible.')
117 | _OUTPUT_VIDEO = flags.DEFINE_boolean(
118 | name='output_video',
119 | default=False,
120 | help='If true, creates a video of the frames in the interpolated_frames/ '
121 | 'subdirectory')
122 |
123 | # Add other extensions, if not either.
124 | _INPUT_EXT = ['png', 'jpg', 'jpeg']
125 |
126 |
127 | def _output_frames(frames: List[np.ndarray], frames_dir: str):
128 | """Writes PNG-images to a directory.
129 |
130 | If frames_dir doesn't exist, it is created. If frames_dir contains existing
131 | PNG-files, they are removed before saving the new ones.
132 |
133 | Args:
134 | frames: List of images to save.
135 | frames_dir: The output directory to save the images.
136 |
137 | """
138 | if tf.io.gfile.isdir(frames_dir):
139 | old_frames = tf.io.gfile.glob(f'{frames_dir}/frame_*.png')
140 | if old_frames:
141 | logging.info('Removing existing frames from %s.', frames_dir)
142 | for old_frame in old_frames:
143 | tf.io.gfile.remove(old_frame)
144 | else:
145 | tf.io.gfile.makedirs(frames_dir)
146 | for idx, frame in tqdm(
147 | enumerate(frames), total=len(frames), ncols=100, colour='green'):
148 | util.write_image(f'{frames_dir}/frame_{idx:03d}.png', frame)
149 | logging.info('Output frames saved in %s.', frames_dir)
150 |
151 |
152 | class ProcessDirectory(beam.DoFn):
153 |   """DoFn for running the interpolator on a single directory at a time."""
154 |
155 | def setup(self):
156 | self.interpolator = interpolator_lib.Interpolator(
157 | _MODEL_PATH.value, _ALIGN.value,
158 | [_BLOCK_HEIGHT.value, _BLOCK_WIDTH.value])
159 |
160 | if _OUTPUT_VIDEO.value:
161 | ffmpeg_path = util.get_ffmpeg_path()
162 | media.set_ffmpeg(ffmpeg_path)
163 |
164 | def process(self, directory: str):
165 | input_frames_list = [
166 | natsort.natsorted(tf.io.gfile.glob(f'{directory}/*.{ext}'))
167 | for ext in _INPUT_EXT
168 | ]
169 | input_frames = functools.reduce(lambda x, y: x + y, input_frames_list)
170 | logging.info('Generating in-between frames for %s.', directory)
171 | frames = list(
172 | util.interpolate_recursively_from_files(
173 | input_frames, _TIMES_TO_INTERPOLATE.value, self.interpolator))
174 | _output_frames(frames, f'{directory}/interpolated_frames')
175 | if _OUTPUT_VIDEO.value:
176 | media.write_video(f'{directory}/interpolated.mp4', frames, fps=_FPS.value)
177 | logging.info('Output video saved at %s/interpolated.mp4.', directory)
178 |
179 |
180 | def _run_pipeline() -> None:
181 | directories = tf.io.gfile.glob(_PATTERN.value)
182 | pipeline = beam.Pipeline('DirectRunner')
183 | (pipeline | 'Create directory names' >> beam.Create(directories) # pylint: disable=expression-not-assigned
184 | | 'Process directories' >> beam.ParDo(ProcessDirectory()))
185 |
186 | result = pipeline.run()
187 | result.wait_until_finish()
188 |
189 |
190 | def main(argv: Sequence[str]) -> None:
191 | if len(argv) > 1:
192 | raise app.UsageError('Too many command-line arguments.')
193 | _run_pipeline()
194 |
195 |
196 | if __name__ == '__main__':
197 | app.run(main)
198 |
--------------------------------------------------------------------------------
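The recursion driven by --times_to_interpolate inserts (2^t - 1) frames between every consecutive input pair and keeps the originals, so a directory with n input frames yields (n - 1) * 2^t + 1 output frames, which reduces to the 2^t + 1 mentioned in the docstring for the two-frame case. A tiny sketch of that count:

def num_output_frames(num_input_frames: int, times_to_interpolate: int) -> int:
  """Frames written to interpolated_frames/ for one input directory."""
  return (num_input_frames - 1) * 2**times_to_interpolate + 1


assert num_output_frames(2, 5) == 33  # two inputs, default recursion depth
assert num_output_frames(3, 2) == 9   # three inputs, two recursions
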
/eval/interpolator_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""A test script for mid frame interpolation from two input frames.
16 |
17 | Usage example:
18 | python3 -m frame_interpolation.eval.interpolator_test \
19 |     --frame1 <filepath of the first frame> \
20 |     --frame2 <filepath of the second frame> \
21 |     --model_path <filepath of the TF2 saved model>
22 |
23 | The output is saved to <directory of the first frame>/output_frame.png. If
24 | `--output_frame` filepath is provided, it will be used instead.
25 | """
26 | import os
27 | from typing import Sequence
28 |
29 | from . import interpolator as interpolator_lib
30 | from . import util
31 | from absl import app
32 | from absl import flags
33 | import numpy as np
34 |
35 | # Controls TF_CPP log level.
36 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
37 |
38 |
39 | _FRAME1 = flags.DEFINE_string(
40 | name='frame1',
41 | default=None,
42 | help='The filepath of the first input frame.',
43 | required=True)
44 | _FRAME2 = flags.DEFINE_string(
45 | name='frame2',
46 | default=None,
47 | help='The filepath of the second input frame.',
48 | required=True)
49 | _MODEL_PATH = flags.DEFINE_string(
50 | name='model_path',
51 | default=None,
52 | help='The path of the TF2 saved model to use.')
53 | _OUTPUT_FRAME = flags.DEFINE_string(
54 | name='output_frame',
55 | default=None,
56 | help='The output filepath of the interpolated mid-frame.')
57 | _ALIGN = flags.DEFINE_integer(
58 | name='align',
59 | default=64,
60 | help='If >1, pad the input size so it is evenly divisible by this value.')
61 | _BLOCK_HEIGHT = flags.DEFINE_integer(
62 | name='block_height',
63 | default=1,
64 | help='An int >= 1, number of patches along height, '
65 | 'patch_height = height//block_height, should be evenly divisible.')
66 | _BLOCK_WIDTH = flags.DEFINE_integer(
67 | name='block_width',
68 | default=1,
69 | help='An int >= 1, number of patches along width, '
70 | 'patch_width = width//block_width, should be evenly divisible.')
71 |
72 |
73 | def _run_interpolator() -> None:
74 | """Writes interpolated mid frame from a given two input frame filepaths."""
75 |
76 | interpolator = interpolator_lib.Interpolator(
77 | model_path=_MODEL_PATH.value,
78 | align=_ALIGN.value,
79 | block_shape=[_BLOCK_HEIGHT.value, _BLOCK_WIDTH.value])
80 |
81 | # First batched image.
82 | image_1 = util.read_image(_FRAME1.value)
83 | image_batch_1 = np.expand_dims(image_1, axis=0)
84 |
85 | # Second batched image.
86 | image_2 = util.read_image(_FRAME2.value)
87 | image_batch_2 = np.expand_dims(image_2, axis=0)
88 |
89 | # Batched time.
90 | batch_dt = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
91 |
92 | # Invoke the model for one mid-frame interpolation.
93 | mid_frame = interpolator(image_batch_1, image_batch_2, batch_dt)[0]
94 |
95 | # Write interpolated mid-frame.
96 | mid_frame_filepath = _OUTPUT_FRAME.value
97 | if not mid_frame_filepath:
98 | mid_frame_filepath = f'{os.path.dirname(_FRAME1.value)}/output_frame.png'
99 | util.write_image(mid_frame_filepath, mid_frame)
100 |
101 |
102 | def main(argv: Sequence[str]) -> None:
103 | if len(argv) > 1:
104 | raise app.UsageError('Too many command-line arguments.')
105 | _run_interpolator()
106 |
107 |
108 | if __name__ == '__main__':
109 | app.run(main)
110 |
--------------------------------------------------------------------------------
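interpolator_test.py requests only the midpoint (dt = 0.5). The wrapper accepts any sub-frame time in [0, 1], although eval/util.py below notes the model is trained for midpoint prediction and may degrade at other times. A hedged sketch with placeholder model and frame paths:

import numpy as np
from frame_interpolation.eval import interpolator as interpolator_lib
from frame_interpolation.eval import util

interpolator = interpolator_lib.Interpolator('/tmp/film_saved_model', align=64)
frame1 = util.read_image('/tmp/pair/frame1.png')[np.newaxis, ...]
frame2 = util.read_image('/tmp/pair/frame2.png')[np.newaxis, ...]

for dt in (0.25, 0.5, 0.75):
  batch_dt = np.full((1,), dt, dtype=np.float32)
  in_between = interpolator.interpolate(frame1, frame2, batch_dt)[0]
  util.write_image(f'/tmp/pair/output_frame_{dt:.2f}.png', in_between)
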
/eval/util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Utility functions for frame interpolation on a set of video frames."""
16 | import os
17 | import shutil
18 | from typing import Generator, Iterable, List, Optional
19 |
20 | from . import interpolator as interpolator_lib
21 | import numpy as np
22 | import tensorflow as tf
23 | from tqdm import tqdm
24 |
25 | _UINT8_MAX_F = float(np.iinfo(np.uint8).max)
26 | _CONFIG_FFMPEG_NAME_OR_PATH = 'ffmpeg'
27 |
28 |
29 | def read_image(filename: str) -> np.ndarray:
30 |   """Reads an sRGB 8-bit image.
31 |
32 | Args:
33 | filename: The input filename to read.
34 |
35 | Returns:
36 | A float32 3-channel (RGB) ndarray with colors in the [0..1] range.
37 | """
38 | image_data = tf.io.read_file(filename)
39 | image = tf.io.decode_image(image_data, channels=3)
40 | image_numpy = tf.cast(image, dtype=tf.float32).numpy()
41 | return image_numpy / _UINT8_MAX_F
42 |
43 |
44 | def write_image(filename: str, image: np.ndarray) -> None:
45 | """Writes a float32 3-channel RGB ndarray image, with colors in range [0..1].
46 |
47 | Args:
48 | filename: The output filename to save.
49 | image: A float32 3-channel (RGB) ndarray with colors in the [0..1] range.
50 | """
51 | image_in_uint8_range = np.clip(image * _UINT8_MAX_F, 0.0, _UINT8_MAX_F)
52 | image_in_uint8 = (image_in_uint8_range + 0.5).astype(np.uint8)
53 |
54 | extension = os.path.splitext(filename)[1]
55 | if extension == '.jpg':
56 | image_data = tf.io.encode_jpeg(image_in_uint8)
57 | else:
58 | image_data = tf.io.encode_png(image_in_uint8)
59 | tf.io.write_file(filename, image_data)
60 |
61 |
62 | def _recursive_generator(
63 | frame1: np.ndarray, frame2: np.ndarray, num_recursions: int,
64 | interpolator: interpolator_lib.Interpolator,
65 | bar: Optional[tqdm] = None
66 | ) -> Generator[np.ndarray, None, None]:
67 | """Splits halfway to repeatedly generate more frames.
68 |
69 | Args:
70 | frame1: Input image 1.
71 | frame2: Input image 2.
72 | num_recursions: How many times to interpolate the consecutive image pairs.
73 |     interpolator: The frame interpolator instance.
74 |     bar: An optional tqdm progress bar updated once per interpolation.
75 | Yields:
76 | The interpolated frames, including the first frame (frame1), but excluding
77 | the final frame2.
78 | """
79 | if num_recursions == 0:
80 | yield frame1
81 | else:
82 | # Adds the batch dimension to all inputs before calling the interpolator,
83 | # and remove it afterwards.
84 | time = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
85 | mid_frame = interpolator(frame1[np.newaxis, ...], frame2[np.newaxis, ...],
86 | time)[0]
87 |     if bar is not None: bar.update(1)
88 | yield from _recursive_generator(frame1, mid_frame, num_recursions - 1,
89 | interpolator, bar)
90 | yield from _recursive_generator(mid_frame, frame2, num_recursions - 1,
91 | interpolator, bar)
92 |
93 |
94 | def interpolate_recursively_from_files(
95 | frames: List[str], times_to_interpolate: int,
96 | interpolator: interpolator_lib.Interpolator) -> Iterable[np.ndarray]:
97 | """Generates interpolated frames by repeatedly interpolating the midpoint.
98 |
99 | Loads the files on demand and uses the yield paradigm to return the frames
100 | to allow streamed processing of longer videos.
101 |
102 | Recursive interpolation is useful if the interpolator is trained to predict
103 | frames at midpoint only and is thus expected to perform poorly elsewhere.
104 |
105 | Args:
106 |     frames: List of input frame filepaths. Images are expected to have shape
107 |       (H, W, 3), with colors in the range [0, 1] and in gamma space.
108 | times_to_interpolate: Number of times to do recursive midpoint
109 | interpolation.
110 | interpolator: The frame interpolation model to use.
111 |
112 | Yields:
113 | The interpolated frames (including the inputs).
114 | """
115 | n = len(frames)
116 | num_frames = (n - 1) * (2**(times_to_interpolate) - 1)
117 | bar = tqdm(total=num_frames, ncols=100, colour='green')
118 | for i in range(1, n):
119 | yield from _recursive_generator(
120 | read_image(frames[i - 1]), read_image(frames[i]), times_to_interpolate,
121 | interpolator, bar)
122 | # Separately yield the final frame.
123 | yield read_image(frames[-1])
124 |
125 | def interpolate_recursively_from_memory(
126 | frames: List[np.ndarray], times_to_interpolate: int,
127 | interpolator: interpolator_lib.Interpolator) -> Iterable[np.ndarray]:
128 | """Generates interpolated frames by repeatedly interpolating the midpoint.
129 |
130 | This is functionally equivalent to interpolate_recursively_from_files(), but
131 | expects the inputs frames in memory, instead of loading them on demand.
132 |
133 | Recursive interpolation is useful if the interpolator is trained to predict
134 | frames at midpoint only and is thus expected to perform poorly elsewhere.
135 |
136 | Args:
137 | frames: List of input frames. Expected shape (H, W, 3). The colors should be
138 |       in the range [0, 1] and in gamma space.
139 | times_to_interpolate: Number of times to do recursive midpoint
140 | interpolation.
141 | interpolator: The frame interpolation model to use.
142 |
143 | Yields:
144 | The interpolated frames (including the inputs).
145 | """
146 | n = len(frames)
147 | num_frames = (n - 1) * (2**(times_to_interpolate) - 1)
148 | bar = tqdm(total=num_frames, ncols=100, colour='green')
149 | for i in range(1, n):
150 | yield from _recursive_generator(frames[i - 1], frames[i],
151 | times_to_interpolate, interpolator, bar)
152 | # Separately yield the final frame.
153 | yield frames[-1]
154 |
155 |
156 | def get_ffmpeg_path() -> str:
157 | path = shutil.which(_CONFIG_FFMPEG_NAME_OR_PATH)
158 | if not path:
159 | raise RuntimeError(
160 | f"Program '{_CONFIG_FFMPEG_NAME_OR_PATH}' is not found;"
161 | " perhaps install ffmpeg using 'apt-get install ffmpeg'.")
162 | return path
163 |
--------------------------------------------------------------------------------
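A minimal sketch of interpolate_recursively_from_memory() above on tiny in-memory frames, with a placeholder saved-model path; two inputs and two recursions yield 2^2 + 1 = 5 frames:

import numpy as np
from frame_interpolation.eval import interpolator as interpolator_lib
from frame_interpolation.eval import util

interpolator = interpolator_lib.Interpolator('/tmp/film_saved_model', align=64)
frames = [np.random.rand(256, 256, 3).astype(np.float32) for _ in range(2)]
result = list(util.interpolate_recursively_from_memory(frames, 2, interpolator))
assert len(result) == 5  # the two inputs plus three interpolated frames
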
/losses/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Loss functions used to train the FILM interpolation model.
16 |
17 | The losses for training and test loops are configurable via gin. Training can
18 | use more than one loss function. The test loop can also evaluate one or more loss
19 | functions, each of which can be summarized separately.
20 | """
21 | from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
22 |
23 | from . import vgg19_loss as vgg19
24 | import gin.tf
25 | import numpy as np
26 | import tensorflow as tf
27 |
28 |
29 | @gin.configurable('vgg', denylist=['example', 'prediction'])
30 | def vgg_loss(example: Mapping[str, tf.Tensor],
31 | prediction: Mapping[str, tf.Tensor],
32 | vgg_model_file: str,
33 | weights: Optional[List[float]] = None) -> tf.Tensor:
34 | """Perceptual loss for images in [0,1] color range.
35 |
36 | Args:
37 | example: A dictionary with the ground truth image as 'y'.
38 | prediction: The prediction dictionary with the image as 'image'.
39 | vgg_model_file: The path containing the vgg19 weights in MATLAB format.
40 | weights: An optional array of weights for different VGG layers. If None, the
41 | default weights are used (see vgg19.vgg_loss documentation).
42 |
43 | Returns:
44 | The perceptual loss.
45 | """
46 | return vgg19.vgg_loss(prediction['image'], example['y'], vgg_model_file,
47 | weights)
48 |
49 |
50 | @gin.configurable('style', denylist=['example', 'prediction'])
51 | def style_loss(example: Mapping[str, tf.Tensor],
52 | prediction: Mapping[str, tf.Tensor],
53 | vgg_model_file: str,
54 | weights: Optional[List[float]] = None) -> tf.Tensor:
55 | """Computes style loss from images in [0..1] color range.
56 |
57 | Args:
58 | example: A dictionary with the ground truth image as 'y'.
59 | prediction: The prediction dictionary with the image as 'image'.
60 | vgg_model_file: The path containing the vgg19 weights in MATLAB format.
61 | weights: An optional array of weights for different VGG layers. If None, the
62 | default weights are used (see vgg19.vgg_loss documentation).
63 |
64 | Returns:
65 | A tf.Tensor of a scalar representing the style loss computed over multiple
66 | vgg layer features.
67 | """
68 | return vgg19.style_loss(prediction['image'], example['y'], vgg_model_file,
69 | weights)
70 |
71 |
72 | def l1_loss(example: Mapping[str, tf.Tensor],
73 | prediction: Mapping[str, tf.Tensor]) -> tf.Tensor:
74 | return tf.reduce_mean(tf.abs(prediction['image'] - example['y']))
75 |
76 |
77 | def l1_warped_loss(example: Mapping[str, tf.Tensor],
78 | prediction: Mapping[str, tf.Tensor]) -> tf.Tensor:
79 | """Computes an l1 loss using only warped images.
80 |
81 | Args:
82 | example: A dictionary with the ground truth image as 'y'.
83 | prediction: The prediction dictionary with the image(s) as 'x0_warped'
84 | and/or 'x1_warped'.
85 |
86 | Returns:
87 | A tf.Tensor of a scalar representing the linear combination of l1 losses
88 | between prediction images and y.
89 | """
90 | loss = tf.constant(0.0, dtype=tf.float32)
91 | if 'x0_warped' in prediction:
92 | loss += tf.reduce_mean(tf.abs(prediction['x0_warped'] - example['y']))
93 | if 'x1_warped' in prediction:
94 | loss += tf.reduce_mean(tf.abs(prediction['x1_warped'] - example['y']))
95 | return loss
96 |
97 |
98 | def l2_loss(example: Mapping[str, tf.Tensor],
99 | prediction: Mapping[str, tf.Tensor]) -> tf.Tensor:
100 | return tf.reduce_mean(tf.square(prediction['image'] - example['y']))
101 |
102 |
103 | def ssim_loss(example: Mapping[str, tf.Tensor],
104 | prediction: Mapping[str, tf.Tensor]) -> tf.Tensor:
105 | image = prediction['image']
106 | y = example['y']
107 | return tf.reduce_mean(tf.image.ssim(image, y, max_val=1.0))
108 |
109 |
110 | def psnr_loss(example: Mapping[str, tf.Tensor],
111 | prediction: Mapping[str, tf.Tensor]) -> tf.Tensor:
112 | return tf.reduce_mean(
113 | tf.image.psnr(prediction['image'], example['y'], max_val=1.0))
114 |
115 |
116 | def get_loss(loss_name: str) -> Callable[[Any, Any], tf.Tensor]:
117 | """Returns the loss function corresponding to the given name."""
118 | if loss_name == 'l1':
119 | return l1_loss
120 | elif loss_name == 'l2':
121 | return l2_loss
122 | elif loss_name == 'ssim':
123 | return ssim_loss
124 | elif loss_name == 'vgg':
125 | return vgg_loss
126 | elif loss_name == 'style':
127 | return style_loss
128 | elif loss_name == 'psnr':
129 | return psnr_loss
130 | elif loss_name == 'l1_warped':
131 | return l1_warped_loss
132 | else:
133 | raise ValueError('Invalid loss function %s' % loss_name)
134 |
135 |
136 | # pylint: disable=unnecessary-lambda
137 | def get_loss_op(loss_name):
138 | """Returns a function for creating a loss calculation op."""
139 | loss = get_loss(loss_name)
140 | return lambda example, prediction: loss(example, prediction)
141 |
142 |
143 | def get_weight_op(weight_schedule):
144 | """Returns a function for creating an iteration dependent loss weight op."""
145 | return lambda iterations: weight_schedule(iterations)
146 |
147 |
148 | def create_losses(
149 | loss_names: List[str], loss_weight_schedules: List[
150 | tf.keras.optimizers.schedules.LearningRateSchedule]
151 | ) -> Dict[str, Tuple[Callable[[Any, Any], tf.Tensor], Callable[[Any],
152 | tf.Tensor]]]:
153 | """Returns a dictionary of functions for creating loss and loss_weight ops.
154 |
155 | As an example, create_losses(['l1', 'l2'], [PiecewiseConstantDecay(),
156 | PiecewiseConstantDecay()]) returns a dictionary with two keys, and each value
157 | being a tuple of ops for loss calculation and loss_weight sampling.
158 |
159 | Args:
160 | loss_names: Names of the losses.
161 | loss_weight_schedules: Instances of loss weight schedules.
162 |
163 | Returns:
164 | A dictionary that contains the loss and weight schedule ops keyed by the
165 | names.
166 | """
167 | losses = dict()
168 | for name, weight_schedule in zip(loss_names, loss_weight_schedules):
169 | unique_values = np.unique(weight_schedule.values)
170 | if len(unique_values) == 1 and unique_values[0] == 1.0:
171 | # Special case 'no weight' for prettier TensorBoard summaries.
172 | weighted_name = name
173 | else:
174 | # Weights are variable/scheduled, a constant "k" is used to
175 | # indicate weights are iteration dependent.
176 | weighted_name = 'k*' + name
177 | losses[weighted_name] = (get_loss_op(name), get_weight_op(weight_schedule))
178 | return losses
179 |
180 |
181 | @gin.configurable
182 | def training_losses(
183 | loss_names: List[str],
184 | loss_weights: Optional[List[float]] = None,
185 | loss_weight_schedules: Optional[List[
186 | tf.keras.optimizers.schedules.LearningRateSchedule]] = None,
187 | loss_weight_parameters: Optional[List[Mapping[str, List[Any]]]] = None
188 | ) -> Mapping[str, Tuple[Callable[[Any, Any], tf.Tensor], Callable[[Any],
189 | tf.Tensor]]]:
190 | """Creates the training loss functions and loss weight schedules."""
191 | weight_schedules = []
192 | if not loss_weights:
193 | for weight_schedule, weight_parameters in zip(loss_weight_schedules,
194 | loss_weight_parameters):
195 | weight_schedules.append(weight_schedule(**weight_parameters))
196 | else:
197 | for loss_weight in loss_weights:
198 | weight_parameters = {
199 | 'boundaries': [0],
200 | 'values': 2 * [
201 | loss_weight,
202 | ]
203 | }
204 | weight_schedules.append(
205 | tf.keras.optimizers.schedules.PiecewiseConstantDecay(
206 | **weight_parameters))
207 |
208 | return create_losses(loss_names, weight_schedules)
209 |
210 |
211 | @gin.configurable
212 | def test_losses(
213 | loss_names: List[str],
214 | loss_weights: Optional[List[float]] = None,
215 | loss_weight_schedules: Optional[List[
216 | tf.keras.optimizers.schedules.LearningRateSchedule]] = None,
217 | loss_weight_parameters: Optional[List[Mapping[str, List[Any]]]] = None
218 | ) -> Mapping[str, Tuple[Callable[[Any, Any], tf.Tensor], Callable[[Any],
219 | tf.Tensor]]]:
220 | """Creates the test loss functions and loss weight schedules."""
221 | weight_schedules = []
222 | if not loss_weights:
223 | for weight_schedule, weight_parameters in zip(loss_weight_schedules,
224 | loss_weight_parameters):
225 | weight_schedules.append(weight_schedule(**weight_parameters))
226 | else:
227 | for loss_weight in loss_weights:
228 | weight_parameters = {
229 | 'boundaries': [0],
230 | 'values': 2 * [
231 | loss_weight,
232 | ]
233 | }
234 | weight_schedules.append(
235 | tf.keras.optimizers.schedules.PiecewiseConstantDecay(
236 | **weight_parameters))
237 |
238 | return create_losses(loss_names, weight_schedules)
239 |
240 |
241 | def aggregate_batch_losses(
242 | batch_losses: List[Mapping[str, float]]) -> Mapping[str, float]:
243 | """Averages per batch losses into single dictionary for the whole epoch.
244 |
245 | As an example, if the batch_losses contained per batch losses:
246 |     batch_losses = [{'l1': 0.2, 'ssim': 0.9}, {'l1': 0.3, 'ssim': 0.8}]
247 |   The returned dictionary would look like: { 'l1': 0.25, 'ssim': 0.85 }
248 |
249 | Args:
250 |     batch_losses: A list of dictionaries, one per batch, mapping loss names to values.
251 |
252 | Returns:
253 | Single dictionary with the losses aggregated.
254 | """
255 | transp_losses = {}
256 | # Loop through all losses
257 | for batch_loss in batch_losses:
258 | # Loop through per batch losses of a single type:
259 | for loss_name, loss in batch_loss.items():
260 | if loss_name not in transp_losses:
261 | transp_losses[loss_name] = []
262 | transp_losses[loss_name].append(loss)
263 | aggregate_losses = {}
264 | for loss_name in transp_losses:
265 | aggregate_losses[loss_name] = np.mean(transp_losses[loss_name])
266 | return aggregate_losses
267 |
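# NOTE: The following is an illustrative sketch added for exposition; it is not
# part of the original source. It numerically checks the docstring example of
# `aggregate_batch_losses` above.
def _sketch_aggregate_batch_losses():
  """Averages the per-batch loss dictionaries from the docstring example."""
  batch_losses = [{'l1': 0.2, 'ssim': 0.9}, {'l1': 0.3, 'ssim': 0.8}]
  # Each key is averaged independently over the batches:
  #   l1   -> (0.2 + 0.3) / 2 = 0.25
  #   ssim -> (0.9 + 0.8) / 2 = 0.85
  return aggregate_batch_losses(batch_losses)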
--------------------------------------------------------------------------------
/losses/vgg19_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Feature loss based on 19 layer VGG network.
16 |
17 |
18 | The network layers in the feature loss are weighted as described in
19 | 'Stereo Magnification: Learning View Synthesis using Multiplane Images',
20 | Tinghui Zhou, Richard Tucker, John Flynn, Graham Fyffe, Noah Snavely, SIGGRAPH 2018.
21 | """
22 |
23 | from typing import Any, Callable, Dict, Optional, Sequence, Tuple
24 |
25 | import numpy as np
26 | import scipy.io as sio
27 | import tensorflow.compat.v1 as tf
28 |
29 |
30 | def _build_net(layer_type: str,
31 | input_tensor: tf.Tensor,
32 | weight_bias: Optional[Tuple[tf.Tensor, tf.Tensor]] = None,
33 |                name: Optional[str] = None) -> tf.Tensor:
34 | """Build a layer of the VGG network.
35 |
36 | Args:
37 | layer_type: A string, type of this layer.
38 | input_tensor: A tensor.
39 | weight_bias: A tuple of weight and bias.
40 | name: A string, name of this layer.
41 |
42 | Returns:
43 |     A tensor, the output of the constructed TensorFlow layer.
44 |
45 | Raises:
46 | ValueError: If layer_type is not conv or pool.
47 | """
48 |
49 | if layer_type == 'conv':
50 | return tf.nn.relu(
51 | tf.nn.conv2d(
52 | input_tensor,
53 | weight_bias[0],
54 | strides=[1, 1, 1, 1],
55 | padding='SAME',
56 | name=name) + weight_bias[1])
57 | elif layer_type == 'pool':
58 | return tf.nn.avg_pool(
59 | input_tensor, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
60 | else:
61 | raise ValueError('Unsupported layer %s' % layer_type)
62 |
63 |
64 | def _get_weight_and_bias(vgg_layers: np.ndarray,
65 | index: int) -> Tuple[tf.Tensor, tf.Tensor]:
66 | """Get the weight and bias of a specific layer from the VGG pretrained model.
67 |
68 | Args:
69 | vgg_layers: An array, the VGG pretrained model.
70 | index: An integer, index of the layer.
71 |
72 | Returns:
73 | weights: A tensor.
74 | bias: A tensor.
75 | """
76 |
77 | weights = vgg_layers[index][0][0][2][0][0]
78 | weights = tf.constant(weights)
79 | bias = vgg_layers[index][0][0][2][0][1]
80 | bias = tf.constant(np.reshape(bias, (bias.size)))
81 |
82 | return weights, bias
83 |
84 |
85 | def _build_vgg19(image: tf.Tensor, model_filepath: str) -> Dict[str, tf.Tensor]:
86 | """Builds the VGG network given the model weights.
87 |
88 | The weights are loaded only the first time this code is invoked.
89 |
90 | Args:
91 | image: A tensor, input image.
92 | model_filepath: A string, path to the VGG pretrained model.
93 |
94 | Returns:
95 | net: A dict mapping a layer name to a tensor.
96 | """
97 |
98 | with tf.variable_scope('vgg', reuse=True):
99 | net = {}
100 | if not hasattr(_build_vgg19, 'vgg_rawnet'):
101 | with tf.io.gfile.GFile(model_filepath, 'rb') as f:
102 | _build_vgg19.vgg_rawnet = sio.loadmat(f)
103 | vgg_layers = _build_vgg19.vgg_rawnet['layers'][0]
104 | imagenet_mean = tf.constant([123.6800, 116.7790, 103.9390],
105 | shape=[1, 1, 1, 3])
106 | net['input'] = image - imagenet_mean
107 | net['conv1_1'] = _build_net(
108 | 'conv',
109 | net['input'],
110 | _get_weight_and_bias(vgg_layers, 0),
111 | name='vgg_conv1_1')
112 | net['conv1_2'] = _build_net(
113 | 'conv',
114 | net['conv1_1'],
115 | _get_weight_and_bias(vgg_layers, 2),
116 | name='vgg_conv1_2')
117 | net['pool1'] = _build_net('pool', net['conv1_2'])
118 | net['conv2_1'] = _build_net(
119 | 'conv',
120 | net['pool1'],
121 | _get_weight_and_bias(vgg_layers, 5),
122 | name='vgg_conv2_1')
123 | net['conv2_2'] = _build_net(
124 | 'conv',
125 | net['conv2_1'],
126 | _get_weight_and_bias(vgg_layers, 7),
127 | name='vgg_conv2_2')
128 | net['pool2'] = _build_net('pool', net['conv2_2'])
129 | net['conv3_1'] = _build_net(
130 | 'conv',
131 | net['pool2'],
132 | _get_weight_and_bias(vgg_layers, 10),
133 | name='vgg_conv3_1')
134 | net['conv3_2'] = _build_net(
135 | 'conv',
136 | net['conv3_1'],
137 | _get_weight_and_bias(vgg_layers, 12),
138 | name='vgg_conv3_2')
139 | net['conv3_3'] = _build_net(
140 | 'conv',
141 | net['conv3_2'],
142 | _get_weight_and_bias(vgg_layers, 14),
143 | name='vgg_conv3_3')
144 | net['conv3_4'] = _build_net(
145 | 'conv',
146 | net['conv3_3'],
147 | _get_weight_and_bias(vgg_layers, 16),
148 | name='vgg_conv3_4')
149 | net['pool3'] = _build_net('pool', net['conv3_4'])
150 | net['conv4_1'] = _build_net(
151 | 'conv',
152 | net['pool3'],
153 | _get_weight_and_bias(vgg_layers, 19),
154 | name='vgg_conv4_1')
155 | net['conv4_2'] = _build_net(
156 | 'conv',
157 | net['conv4_1'],
158 | _get_weight_and_bias(vgg_layers, 21),
159 | name='vgg_conv4_2')
160 | net['conv4_3'] = _build_net(
161 | 'conv',
162 | net['conv4_2'],
163 | _get_weight_and_bias(vgg_layers, 23),
164 | name='vgg_conv4_3')
165 | net['conv4_4'] = _build_net(
166 | 'conv',
167 | net['conv4_3'],
168 | _get_weight_and_bias(vgg_layers, 25),
169 | name='vgg_conv4_4')
170 | net['pool4'] = _build_net('pool', net['conv4_4'])
171 | net['conv5_1'] = _build_net(
172 | 'conv',
173 | net['pool4'],
174 | _get_weight_and_bias(vgg_layers, 28),
175 | name='vgg_conv5_1')
176 | net['conv5_2'] = _build_net(
177 | 'conv',
178 | net['conv5_1'],
179 | _get_weight_and_bias(vgg_layers, 30),
180 | name='vgg_conv5_2')
181 |
182 | return net
183 |
184 |
185 | def _compute_error(fake: tf.Tensor,
186 | real: tf.Tensor,
187 | mask: Optional[tf.Tensor] = None) -> tf.Tensor:
188 | """Computes the L1 loss and reweights by the mask."""
189 | if mask is None:
190 | return tf.reduce_mean(tf.abs(fake - real))
191 | else:
192 | # Resizes mask to the same size as the input.
193 | size = (tf.shape(fake)[1], tf.shape(fake)[2])
194 | resized_mask = tf.image.resize(
195 | mask, size, method=tf.image.ResizeMethod.BILINEAR)
196 | return tf.reduce_mean(tf.abs(fake - real) * resized_mask)
197 |
198 |
199 | # Normalized VGG loss (from
200 | # https://github.com/CQFIO/PhotographicImageSynthesis)
201 | def vgg_loss(image: tf.Tensor,
202 | reference: tf.Tensor,
203 | vgg_model_file: str,
204 | weights: Optional[Sequence[float]] = None,
205 | mask: Optional[tf.Tensor] = None) -> tf.Tensor:
206 | """Computes the VGG loss for an image pair.
207 |
208 | The VGG loss is the average feature vector difference between the two images.
209 |
210 | The input images must be in [0, 1] range in (B, H, W, 3) RGB format and
211 | the recommendation seems to be to have them in gamma space.
212 |
213 | The pretrained weights are publicly available in
214 | http://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat
215 |
216 | Args:
217 | image: A tensor, typically the prediction from a network.
218 | reference: A tensor, the image to compare against, i.e. the golden image.
219 | vgg_model_file: A string, filename for the VGG 19 network weights in MATLAB
220 | format.
221 | weights: A list of float, optional weights for the layers. The defaults are
222 | from Qifeng Chen and Vladlen Koltun, "Photographic image synthesis with
223 | cascaded refinement networks," ICCV 2017.
224 |   mask: An optional single-channel, image-shaped tensor whose values are
225 |     per-pixel weights to be applied on the losses. The mask will be resized
226 |     to the same spatial resolution as the feature maps before being applied
227 |     to the losses. Even where the mask value is zero, pixels near the
228 |     boundary of the mask can still influence the loss if they fall into the
229 |     receptive field of the VGG convolutional layers.
230 |
231 | Returns:
232 | vgg_loss: The linear combination of losses from five VGG layers.
233 | """
234 |
235 | if not weights:
236 | weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5]
237 |
238 | vgg_ref = _build_vgg19(reference * 255.0, vgg_model_file)
239 | vgg_img = _build_vgg19(image * 255.0, vgg_model_file)
240 | p1 = _compute_error(vgg_ref['conv1_2'], vgg_img['conv1_2'], mask) * weights[0]
241 | p2 = _compute_error(vgg_ref['conv2_2'], vgg_img['conv2_2'], mask) * weights[1]
242 | p3 = _compute_error(vgg_ref['conv3_2'], vgg_img['conv3_2'], mask) * weights[2]
243 | p4 = _compute_error(vgg_ref['conv4_2'], vgg_img['conv4_2'], mask) * weights[3]
244 | p5 = _compute_error(vgg_ref['conv5_2'], vgg_img['conv5_2'], mask) * weights[4]
245 |
246 | final_loss = p1 + p2 + p3 + p4 + p5
247 |
248 | # Scale to range [0..1].
249 | final_loss /= 255.0
250 |
251 | return final_loss
252 |
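# NOTE: The following is an illustrative sketch added for exposition; it is not
# part of the original source. The model path is a placeholder for the
# MatConvNet weights referenced in the docstring above, and depending on the
# TensorFlow version the loss may need to be built inside a tf.compat.v1 graph
# and evaluated in a session.
def _sketch_vgg_loss_usage():
  """Builds the VGG loss for a random image pair in the [0, 1] range."""
  image = tf.random.uniform([1, 256, 256, 3])      # e.g. a network prediction
  reference = tf.random.uniform([1, 256, 256, 3])  # e.g. the ground truth
  # Placeholder path; download imagenet-vgg-verydeep-19.mat separately.
  return vgg_loss(image, reference,
                  vgg_model_file='/path/to/imagenet-vgg-verydeep-19.mat')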
253 |
254 | def _compute_gram_matrix(input_features: tf.Tensor,
255 | mask: tf.Tensor) -> tf.Tensor:
256 | """Computes Gram matrix of `input_features`.
257 |
258 | Gram matrix described in https://en.wikipedia.org/wiki/Gramian_matrix.
259 |
260 | Args:
261 | input_features: A tf.Tensor of shape (B, H, W, C) representing a feature map
262 | obtained by a convolutional layer of a VGG network.
263 | mask: A tf.Tensor of shape (B, H, W, 1) representing the per-pixel weights
264 | to be applied on the `input_features`. The mask will be resized to the
265 |     same spatial resolution as the `input_features`. When the mask value is
266 | zero, pixels near the boundary of the mask can still influence the loss if
267 | they fall into the receptive field of the VGG convolutional layers.
268 |
269 | Returns:
270 | A tf.Tensor of shape (B, C, C) representing the gram matrix of the masked
271 | `input_features`.
272 | """
273 | _, h, w, c = tuple([
274 | i if (isinstance(i, int) or i is None) else i.value
275 | for i in input_features.shape
276 | ])
277 | if mask is None:
278 | reshaped_features = tf.reshape(input_features, (-1, h * w, c))
279 | else:
280 | # Resize mask to match the shape of `input_features`
281 | resized_mask = tf.image.resize(
282 | mask, (h, w), method=tf.image.ResizeMethod.BILINEAR)
283 | reshaped_features = tf.reshape(input_features * resized_mask,
284 | (-1, h * w, c))
285 | return tf.matmul(
286 | reshaped_features, reshaped_features, transpose_a=True) / float(h * w)
287 |
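# NOTE: The following is an illustrative sketch added for exposition; it is not
# part of the original source. It spells out the unmasked Gram matrix
# computation, G = F^T F / (H*W), for a feature map reshaped to (B, H*W, C).
def _sketch_gram_matrix_identity():
  """Compares the helper above against the explicit Gram matrix formula."""
  features = tf.random.uniform([2, 8, 8, 16])   # (B, H, W, C)
  flat = tf.reshape(features, [2, 8 * 8, 16])   # (B, H*W, C)
  explicit = tf.matmul(flat, flat, transpose_a=True) / float(8 * 8)
  via_helper = _compute_gram_matrix(features, mask=None)
  # Both tensors have shape (B, C, C) and are numerically identical.
  return explicit, via_helper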
288 |
289 | def style_loss(image: tf.Tensor,
290 | reference: tf.Tensor,
291 | vgg_model_file: str,
292 | weights: Optional[Sequence[float]] = None,
293 | mask: Optional[tf.Tensor] = None) -> tf.Tensor:
294 | """Computes style loss as used in `A Neural Algorithm of Artistic Style`.
295 |
296 | Based on the work in https://github.com/cysmith/neural-style-tf. Weights are
297 | first initialized to the inverse of the number of elements in each VGG layer
298 | considered. After 1.5M iterations, they are rescaled to normalize the
299 | contribution of the Style loss to be equal to other losses (L1/VGG). This is
300 | based on the works of image inpainting (https://arxiv.org/abs/1804.07723)
301 | and frame prediction (https://arxiv.org/abs/1811.00684).
302 |
303 | The style loss is the average gram matrix difference between `image` and
304 | `reference`. The gram matrix is the inner product of a feature map of shape
305 | (B, H*W, C) with itself, resulting in a symmetric gram matrix of shape (B, C, C).
306 |
307 | The input images must be in [0, 1] range in (B, H, W, 3) RGB format and
308 | the recommendation seems to be to have them in gamma space.
309 |
310 | The pretrained weights are publicly available in
311 | http://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat
312 |
313 | Args:
314 | image: A tensor, typically the prediction from a network.
315 | reference: A tensor, the image to compare against, i.e. the golden image.
316 | vgg_model_file: A string, filename for the VGG 19 network weights in MATLAB
317 | format.
318 | weights: A list of float, optional weights for the layers. The defaults are
319 | from Qifeng Chen and Vladlen Koltun, "Photographic image synthesis with
320 | cascaded refinement networks," ICCV 2017.
321 |   mask: An optional single-channel, image-shaped tensor whose values are
322 |     per-pixel weights to be applied on the losses. The mask will be resized
323 |     to the same spatial resolution as the feature maps before being applied
324 |     to the losses. Even where the mask value is zero, pixels near the
325 |     boundary of the mask can still influence the loss if they fall into the
326 |     receptive field of the VGG convolutional layers.
327 |
328 | Returns:
329 |     Style loss, a linear combination of gram matrix L2 differences from five
330 | VGG layer features.
331 | """
332 |
333 | if not weights:
334 | weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5]
335 |
336 | vgg_ref = _build_vgg19(reference * 255.0, vgg_model_file)
337 | vgg_img = _build_vgg19(image * 255.0, vgg_model_file)
338 |
339 | p1 = tf.reduce_mean(
340 | tf.squared_difference(
341 | _compute_gram_matrix(vgg_ref['conv1_2'] / 255.0, mask),
342 | _compute_gram_matrix(vgg_img['conv1_2'] / 255.0, mask))) * weights[0]
343 | p2 = tf.reduce_mean(
344 | tf.squared_difference(
345 | _compute_gram_matrix(vgg_ref['conv2_2'] / 255.0, mask),
346 | _compute_gram_matrix(vgg_img['conv2_2'] / 255.0, mask))) * weights[1]
347 | p3 = tf.reduce_mean(
348 | tf.squared_difference(
349 | _compute_gram_matrix(vgg_ref['conv3_2'] / 255.0, mask),
350 | _compute_gram_matrix(vgg_img['conv3_2'] / 255.0, mask))) * weights[2]
351 | p4 = tf.reduce_mean(
352 | tf.squared_difference(
353 | _compute_gram_matrix(vgg_ref['conv4_2'] / 255.0, mask),
354 | _compute_gram_matrix(vgg_img['conv4_2'] / 255.0, mask))) * weights[3]
355 | p5 = tf.reduce_mean(
356 | tf.squared_difference(
357 | _compute_gram_matrix(vgg_ref['conv5_2'] / 255.0, mask),
358 | _compute_gram_matrix(vgg_img['conv5_2'] / 255.0, mask))) * weights[4]
359 |
360 | final_loss = p1 + p2 + p3 + p4 + p5
361 |
362 | return final_loss
363 |
--------------------------------------------------------------------------------
/models/film_net/feature_extractor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """TF2 layer for extracting image features for the film_net interpolator.
16 |
17 | The feature extractor implemented here converts an image pyramid into a pyramid
18 | of deep features. The feature pyramid serves a similar purpose as U-Net
19 | architecture's encoder, but we use a special cascaded architecture described in
20 | Multi-view Image Fusion [1].
21 |
22 | For comprehensiveness, below is a short description of the idea. While the
23 | description is a bit involved, the cascaded feature pyramid can be used just
24 | like any image feature pyramid.
25 |
26 | Why cascaded architecture?
27 | =========================
28 | To understand the concept it is worth reviewing a traditional feature pyramid
29 | first: *A traditional feature pyramid* as in U-net or in many optical flow
30 | networks is built by alternating between convolutions and pooling, starting
31 | from the input image.
32 |
33 | It is well known that early features of such architecture correspond to low
34 | level concepts such as edges in the image whereas later layers extract
35 | semantically higher level concepts such as object classes etc. In other words,
36 | the meaning of the filters in each resolution level is different. For problems
37 | such as semantic segmentation and many others this is a desirable property.
38 |
39 | However, the asymmetric features preclude sharing weights across resolution
40 | levels in the feature extractor itself and in any subsequent neural networks
41 | that follow. This can be a downside, since optical flow prediction, for
42 | instance is symmetric across resolution levels. The cascaded feature
43 | architecture addresses this shortcoming.
44 |
45 | How is it built?
46 | ================
47 | The *cascaded* feature pyramid contains feature vectors that have constant
48 | length and meaning on each resolution level, except a few of the finest ones. The
49 | advantage of this is that the subsequent optical flow layer can learn
50 | synergistically from many resolutions. This means that coarse level prediction can
51 | benefit from finer resolution training examples, which can be useful with
52 | moderately sized datasets to avoid overfitting.
53 |
54 | The cascaded feature pyramid is built by extracting shallower subtree pyramids,
55 | each one of them similar to the traditional architecture. Each subtree
56 | pyramid S_i is extracted starting from each resolution level:
57 |
58 | image resolution 0 -> S_0
59 | image resolution 1 -> S_1
60 | image resolution 2 -> S_2
61 | ...
62 |
63 | If we denote the features at level j of subtree i as S_i_j, the cascaded pyramid
64 | is constructed by concatenating features as follows (assuming subtree depth=3):
65 |
66 | lvl
67 | feat_0 = concat( S_0_0 )
68 | feat_1 = concat( S_1_0 S_0_1 )
69 | feat_2 = concat( S_2_0 S_1_1 S_0_2 )
70 | feat_3 = concat( S_3_0 S_2_1 S_1_2 )
71 | feat_4 = concat( S_4_0 S_3_1 S_2_2 )
72 | feat_5 = concat( S_5_0 S_4_1 S_3_2 )
73 | ....
74 |
75 | In the above, all levels except feat_0 and feat_1 have the same number of features
76 | with similar semantic meaning. This enables training a single optical flow
77 | predictor module shared by levels 2,3,4,5... . For more details and evaluation
78 | see [1].
79 |
80 | [1] Multi-view Image Fusion, Trinidad et al. 2019
81 | """
82 |
83 | from typing import List
84 |
85 | from . import options
86 | import tensorflow as tf
87 |
88 |
89 | def _relu(x: tf.Tensor) -> tf.Tensor:
90 | return tf.nn.leaky_relu(x, alpha=0.2)
91 |
92 |
93 | def _conv(filters: int, name: str):
94 | return tf.keras.layers.Conv2D(
95 | name=name,
96 | filters=filters,
97 | kernel_size=3,
98 | padding='same',
99 | activation=_relu)
100 |
101 |
102 | class SubTreeExtractor(tf.keras.layers.Layer):
103 | """Extracts a hierarchical set of features from an image.
104 |
105 | This is a conventional, hierarchical image feature extractor, that extracts
106 |   [k, k*2, k*4... ] filters for the image pyramid where k=options.filters.
107 | Each level is followed by average pooling.
108 |
109 | Attributes:
110 | name: Name for the layer
111 | config: Options for the fusion_net frame interpolator
112 | """
113 |
114 | def __init__(self, name: str, config: options.Options):
115 | super().__init__(name=name)
116 | k = config.filters
117 | n = config.sub_levels
118 | self.convs = []
119 | for i in range(n):
120 | self.convs.append(
121 | _conv(filters=(k << i), name='cfeat_conv_{}'.format(2 * i)))
122 | self.convs.append(
123 | _conv(filters=(k << i), name='cfeat_conv_{}'.format(2 * i + 1)))
124 |
125 | def call(self, image: tf.Tensor, n: int) -> List[tf.Tensor]:
126 | """Extracts a pyramid of features from the image.
127 |
128 | Args:
129 | image: tf.Tensor with shape BATCH_SIZE x HEIGHT x WIDTH x CHANNELS.
130 |       n: number of pyramid levels to extract. This can be less than or equal to
131 | options.sub_levels given in the __init__.
132 | Returns:
133 | The pyramid of features, starting from the finest level. Each element
134 | contains the output after the last convolution on the corresponding
135 | pyramid level.
136 | """
137 | head = image
138 | pool = tf.keras.layers.AveragePooling2D(
139 | pool_size=2, strides=2, padding='valid')
140 | pyramid = []
141 | for i in range(n):
142 | head = self.convs[2*i](head)
143 | head = self.convs[2*i+1](head)
144 | pyramid.append(head)
145 | if i < n-1:
146 | head = pool(head)
147 | return pyramid
148 |
149 |
150 | class FeatureExtractor(tf.keras.layers.Layer):
151 | """Extracts features from an image pyramid using a cascaded architecture.
152 |
153 | Attributes:
154 | name: Name of the layer
155 | config: Options for the fusion_net frame interpolator
156 | """
157 |
158 | def __init__(self, name: str, config: options.Options):
159 | super().__init__(name=name)
160 | self.extract_sublevels = SubTreeExtractor('sub_extractor', config)
161 | self.options = config
162 |
163 | def call(self, image_pyramid: List[tf.Tensor]) -> List[tf.Tensor]:
164 | """Extracts a cascaded feature pyramid.
165 |
166 | Args:
167 | image_pyramid: Image pyramid as a list, starting from the finest level.
168 | Returns:
169 | A pyramid of cascaded features.
170 | """
171 | sub_pyramids = []
172 | for i in range(len(image_pyramid)):
173 | # At each level of the image pyramid, creates a sub_pyramid of features
174 | # with 'sub_levels' pyramid levels, re-using the same SubTreeExtractor.
175 | # We use the same instance since we want to share the weights.
176 | #
177 | # However, we cap the depth of the sub_pyramid so we don't create features
178 | # that are beyond the coarsest level of the cascaded feature pyramid we
179 | # want to generate.
180 | capped_sub_levels = min(len(image_pyramid) - i, self.options.sub_levels)
181 | sub_pyramids.append(
182 | self.extract_sublevels(image_pyramid[i], capped_sub_levels))
183 | # Below we generate the cascades of features on each level of the feature
184 | # pyramid. Assuming sub_levels=3, The layout of the features will be
185 | # as shown in the example on file documentation above.
186 | feature_pyramid = []
187 | for i in range(len(image_pyramid)):
188 | features = sub_pyramids[i][0]
189 | for j in range(1, self.options.sub_levels):
190 | if j <= i:
191 | features = tf.concat([features, sub_pyramids[i - j][j]], axis=-1)
192 | feature_pyramid.append(features)
193 | return feature_pyramid
194 |
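# NOTE: The following is an illustrative sketch added for exposition; it is not
# part of the original source. It tabulates the channel counts of the cascaded
# pyramid under an assumed config with filters=k and sub_levels=3, matching the
# concatenation layout in the module docstring.
def _sketch_cascaded_channel_counts(k: int = 16,
                                    sub_levels: int = 3,
                                    pyramid_levels: int = 6) -> List[int]:
  """Returns the number of channels at each cascaded feature level."""
  channels = []
  for level in range(pyramid_levels):
    # Level i concatenates sub-levels j = 0..min(i, sub_levels - 1), and
    # sub-level j of any subtree carries (k << j) channels.
    depth = min(level + 1, sub_levels)
    channels.append(sum(k << j for j in range(depth)))
  return channels  # e.g. [16, 48, 112, 112, 112, 112] for k=16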
--------------------------------------------------------------------------------
/models/film_net/fusion.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The final fusion stage for the film_net frame interpolator.
16 |
17 | The inputs to this module are the warped input images, image features and
18 | flow fields, all aligned to the target frame (often midway point between the
19 | two original inputs). The output is the final image. FILM has no explicit
20 | occlusion handling -- instead, using the aforementioned information, this module
21 | automatically decides how to best blend the inputs together to produce content
22 | in areas where the pixels can only be borrowed from one of the inputs.
23 |
24 | Similarly, this module also decides how much to blend in each input in the case
25 | of a fractional timestep that is not at the halfway point. For example, if the two
26 | input images are at t=0 and t=1, and we were to synthesize a frame at t=0.1,
27 | it often makes most sense to favor the first input. However, this is not
28 | always the case -- in particular in occluded pixels.
29 |
30 | The architecture of the Fusion module follows U-net [1] architecture's decoder
31 | side, i.e. each pyramid level consists of concatenation with the upsampled coarser
32 | level output, and two 3x3 convolutions.
33 |
34 | The upsampling is implemented as a 'resize convolution', i.e. nearest neighbor
35 | upsampling followed by a 2x2 convolution, as explained in [2]. The classic U-net
36 | uses transposed convolutions, which tend to create checkerboard artifacts.
37 |
38 | [1] Ronneberger et al. U-Net: Convolutional Networks for Biomedical Image
39 | Segmentation, 2015, https://arxiv.org/pdf/1505.04597.pdf
40 | [2] https://distill.pub/2016/deconv-checkerboard/
41 | """
42 |
43 | from typing import List
44 |
45 | from . import options
46 | import tensorflow as tf
47 |
48 |
49 | def _relu(x: tf.Tensor) -> tf.Tensor:
50 | return tf.nn.leaky_relu(x, alpha=0.2)
51 |
52 |
53 | _NUMBER_OF_COLOR_CHANNELS = 3
54 |
55 |
56 | class Fusion(tf.keras.layers.Layer):
57 | """The decoder."""
58 |
59 | def __init__(self, name: str, config: options.Options):
60 | super().__init__(name=name)
61 |
62 | # Each item 'convs[i]' will contain the list of convolutions to be applied
63 | # for pyramid level 'i'.
64 | self.convs: List[List[tf.keras.layers.Layer]] = []
65 |
66 | # Store the levels, so we can verify right number of levels in call().
67 | self.levels = config.fusion_pyramid_levels
68 |
69 | # Create the convolutions. Roughly following the feature extractor, we
70 | # double the number of filters when the resolution halves, but only up to
71 | # the specialized_levels, after which we use the same number of filters on
72 | # all levels.
73 | #
74 | # We create the convs in fine-to-coarse order, so that the array index
75 | # for the convs will correspond to our normal indexing (0=finest level).
76 | for i in range(config.fusion_pyramid_levels - 1):
77 | m = config.specialized_levels
78 | k = config.filters
79 | num_filters = (k << i) if i < m else (k << m)
80 |
81 | convs: List[tf.keras.layers.Layer] = []
82 | convs.append(
83 | tf.keras.layers.Conv2D(
84 | filters=num_filters, kernel_size=[2, 2], padding='same'))
85 | convs.append(
86 | tf.keras.layers.Conv2D(
87 | filters=num_filters,
88 | kernel_size=[3, 3],
89 | padding='same',
90 | activation=_relu))
91 | convs.append(
92 | tf.keras.layers.Conv2D(
93 | filters=num_filters,
94 | kernel_size=[3, 3],
95 | padding='same',
96 | activation=_relu))
97 | self.convs.append(convs)
98 |
99 | # The final convolution that outputs RGB:
100 | self.output_conv = tf.keras.layers.Conv2D(
101 | filters=_NUMBER_OF_COLOR_CHANNELS, kernel_size=1)
102 |
103 | def call(self, pyramid: List[tf.Tensor]) -> tf.Tensor:
104 | """Runs the fusion module.
105 |
106 | Args:
107 | pyramid: The input feature pyramid as list of tensors. Each tensor being
108 | in (B x H x W x C) format, with finest level tensor first.
109 |
110 | Returns:
111 | A batch of RGB images.
112 | Raises:
113 | ValueError, if len(pyramid) != config.fusion_pyramid_levels as provided in
114 | the constructor.
115 | """
116 | if len(pyramid) != self.levels:
117 | raise ValueError(
118 | 'Fusion called with different number of pyramid levels '
119 | f'{len(pyramid)} than it was configured for, {self.levels}.')
120 |
121 | # As a slight difference to a conventional decoder (e.g. U-net), we don't
122 | # apply any extra convolutions to the coarsest level, but just pass it
123 | # to finer levels for concatenation. This choice has not been thoroughly
124 | # evaluated, but is motivated by the educated guess that the fusion part
125 | # probably does not need large spatial context, because at this point the
126 | # features are spatially aligned by the preceding warp.
127 | net = pyramid[-1]
128 |
129 | # Loop starting from the 2nd coarsest level:
130 | for i in reversed(range(0, self.levels - 1)):
131 | # Resize the tensor from coarser level to match for concatenation.
132 | level_size = tf.shape(pyramid[i])[1:3]
133 | net = tf.image.resize(net, level_size,
134 | tf.image.ResizeMethod.NEAREST_NEIGHBOR)
135 | net = self.convs[i][0](net)
136 | net = tf.concat([pyramid[i], net], axis=-1)
137 | net = self.convs[i][1](net)
138 | net = self.convs[i][2](net)
139 | net = self.output_conv(net)
140 | return net
141 |
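# NOTE: The following is an illustrative sketch added for exposition; it is not
# part of the original source. It isolates the 'resize convolution' upsampling
# step used by the decoder above: nearest-neighbor resize followed by a 2x2
# convolution, rather than a transposed convolution.
def _sketch_resize_convolution(coarse: tf.Tensor,
                               target_hw: tf.Tensor,
                               filters: int = 64) -> tf.Tensor:
  """Upsamples `coarse` to spatial size `target_hw` without checkerboarding."""
  upsampled = tf.image.resize(coarse, target_hw,
                              tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  conv = tf.keras.layers.Conv2D(
      filters=filters, kernel_size=[2, 2], padding='same')
  return conv(upsampled)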
--------------------------------------------------------------------------------
/models/film_net/interpolator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """The film_net frame interpolator main model code.
16 |
17 | Basics
18 | ======
19 | The film_net is an end-to-end learned neural frame interpolator implemented as
20 | a TF2 model. It has the following inputs and outputs:
21 |
22 | Inputs:
23 | x0: image A.
24 | x1: image B.
25 | time: desired sub-frame time.
26 |
27 | Outputs:
28 | image: the predicted in-between image at the chosen time in range [0, 1].
29 |
30 | Additional outputs include forward and backward warped image pyramids, flow
31 | pyramids, etc., that can be visualized for debugging and analysis.
32 |
33 | Note that many training sets only contain triplets with ground truth at
34 | time=0.5. If a model has been trained with such training set, it will only work
35 | well for synthesizing frames at time=0.5. Such models can only generate more
36 | in-between frames using recursion.
37 |
38 | Architecture
39 | ============
40 | The inference consists of three main stages: 1) feature extraction 2) warping
41 | 3) fusion. On high-level, the architecture has similarities to Context-aware
42 | Synthesis for Video Frame Interpolation [1], but the exact architecture is
43 | closer to Multi-view Image Fusion [2] with some modifications for the frame
44 | interpolation use-case.
45 |
46 | Feature extraction stage employs the cascaded multi-scale architecture described
47 | in [2]. The advantage of this architecture is that coarse level flow prediction
48 | can be learned from finer resolution image samples. This is especially useful
49 | to avoid overfitting with moderately sized datasets.
50 |
51 | The warping stage uses a residual flow prediction idea that is similar to
52 | PWC-Net [3], Multi-view Image Fusion [2] and many others.
53 |
54 | The fusion stage is similar to U-Net's decoder where the skip connections are
55 | connected to warped image and feature pyramids. This is described in [2].
56 |
57 | Implementation Conventions
58 | ==========================
59 | Pyramids
60 | --------
61 | Throughout the model, all image and feature pyramids are stored as python lists
62 | with finest level first followed by downscaled versions obtained by successively
63 | halving the resolution. The depths of all pyramids are determined by
64 | options.pyramid_levels. The only exception to this is internal to the feature
65 | extractor, where smaller feature pyramids are temporarily constructed with depth
66 | options.sub_levels.
67 |
68 | Color ranges & gamma
69 | --------------------
70 | The model code makes no assumptions on whether the images are in gamma or
71 | linearized space or what is the range of RGB color values. So a model can be
72 | trained with different choices. This does not mean that all the choices lead to
73 | similar results. In practice the model has been proven to work well with RGB
74 | scale = [0,1] with gamma-space images (i.e. not linearized).
75 |
76 | [1] Context-aware Synthesis for Video Frame Interpolation, Niklaus and Liu, 2018
77 | [2] Multi-view Image Fusion, Trinidad et al, 2019
78 | [3] PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume
79 | """
80 |
81 | from . import feature_extractor
82 | from . import fusion
83 | from . import options
84 | from . import pyramid_flow_estimator
85 | from . import util
86 | import tensorflow as tf
87 |
88 |
89 | def create_model(x0: tf.Tensor, x1: tf.Tensor, time: tf.Tensor,
90 | config: options.Options) -> tf.keras.Model:
91 | """Creates a frame interpolator model.
92 |
93 | The frame interpolator is used to warp the two images to the in-between frame
94 | at given time. Note that training data is often restricted such that
95 | supervision only exists at 'time'=0.5. If trained with such data, the model
96 | will overfit to predicting images that are halfway between the two inputs and
97 | will not be as accurate elsewhere.
98 |
99 | Args:
100 | x0: first input image as BxHxWxC tensor.
101 | x1: second input image as BxHxWxC tensor.
102 | time: ignored by film_net. We always infer a frame at t = 0.5.
103 | config: FilmNetOptions object.
104 |
105 | Returns:
106 | A tf.Model that takes 'x0', 'x1', and 'time' as input and returns a
107 | dictionary with the interpolated result in 'image'. For additional
108 | diagnostics or supervision, the following intermediate results are
109 | also stored in the dictionary:
110 | 'x0_warped': an intermediate result obtained by warping from x0
111 | 'x1_warped': an intermediate result obtained by warping from x1
112 | 'forward_residual_flow_pyramid': pyramid with forward residual flows
113 | 'backward_residual_flow_pyramid': pyramid with backward residual flows
114 | 'forward_flow_pyramid': pyramid with forward flows
115 | 'backward_flow_pyramid': pyramid with backward flows
116 |
117 | Raises:
118 | ValueError, if config.pyramid_levels < config.fusion_pyramid_levels.
119 | """
120 | if config.pyramid_levels < config.fusion_pyramid_levels:
121 | raise ValueError('config.pyramid_levels must be greater than or equal to '
122 | 'config.fusion_pyramid_levels.')
123 |
124 | x0_decoded = x0
125 | x1_decoded = x1
126 |
127 |   # Build an image pyramid for each of the two input images.
128 | image_pyramids = [
129 | util.build_image_pyramid(x0_decoded, config),
130 | util.build_image_pyramid(x1_decoded, config)
131 | ]
132 |
133 | # Siamese feature pyramids:
134 | extract = feature_extractor.FeatureExtractor('feat_net', config)
135 | feature_pyramids = [extract(image_pyramids[0]), extract(image_pyramids[1])]
136 |
137 | predict_flow = pyramid_flow_estimator.PyramidFlowEstimator(
138 | 'predict_flow', config)
139 |
140 | # Predict forward flow.
141 | forward_residual_flow_pyramid = predict_flow(feature_pyramids[0],
142 | feature_pyramids[1])
143 | # Predict backward flow.
144 | backward_residual_flow_pyramid = predict_flow(feature_pyramids[1],
145 | feature_pyramids[0])
146 |
147 | # Concatenate features and images:
148 |
149 | # Note that we keep up to 'fusion_pyramid_levels' levels as only those
150 | # are used by the fusion module.
151 | fusion_pyramid_levels = config.fusion_pyramid_levels
152 |
153 | forward_flow_pyramid = util.flow_pyramid_synthesis(
154 | forward_residual_flow_pyramid)[:fusion_pyramid_levels]
155 | backward_flow_pyramid = util.flow_pyramid_synthesis(
156 | backward_residual_flow_pyramid)[:fusion_pyramid_levels]
157 |
158 | # We multiply the flows with t and 1-t to warp to the desired fractional time.
159 | #
160 | # Note: In film_net we fix time to be 0.5, and recursively invoke the interpo-
161 | # lator for multi-frame interpolation. Below, we create a constant tensor of
162 | # shape [B]. We use the `time` tensor to infer the batch size.
163 | mid_time = tf.keras.layers.Lambda(lambda x: tf.ones_like(x) * 0.5)(time)
164 | backward_flow = util.multiply_pyramid(backward_flow_pyramid, mid_time[:, 0])
165 | forward_flow = util.multiply_pyramid(forward_flow_pyramid, 1 - mid_time[:, 0])
166 |
167 | pyramids_to_warp = [
168 | util.concatenate_pyramids(image_pyramids[0][:fusion_pyramid_levels],
169 | feature_pyramids[0][:fusion_pyramid_levels]),
170 | util.concatenate_pyramids(image_pyramids[1][:fusion_pyramid_levels],
171 | feature_pyramids[1][:fusion_pyramid_levels])
172 | ]
173 |
174 | # Warp features and images using the flow. Note that we use backward warping
175 | # and backward flow is used to read from image 0 and forward flow from
176 | # image 1.
177 | forward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[0], backward_flow)
178 | backward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[1], forward_flow)
179 |
180 | aligned_pyramid = util.concatenate_pyramids(forward_warped_pyramid,
181 | backward_warped_pyramid)
182 | aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, backward_flow)
183 | aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, forward_flow)
184 |
185 | fuse = fusion.Fusion('fusion', config)
186 | prediction = fuse(aligned_pyramid)
187 |
188 | output_color = prediction[..., :3]
189 | outputs = {'image': output_color}
190 |
191 | if config.use_aux_outputs:
192 | outputs.update({
193 | 'x0_warped': forward_warped_pyramid[0][..., 0:3],
194 | 'x1_warped': backward_warped_pyramid[0][..., 0:3],
195 | 'forward_residual_flow_pyramid': forward_residual_flow_pyramid,
196 | 'backward_residual_flow_pyramid': backward_residual_flow_pyramid,
197 | 'forward_flow_pyramid': forward_flow_pyramid,
198 | 'backward_flow_pyramid': backward_flow_pyramid,
199 | })
200 |
201 | model = tf.keras.Model(
202 | inputs={
203 | 'x0': x0,
204 | 'x1': x1,
205 | 'time': time
206 | }, outputs=outputs)
207 | return model
208 |
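# NOTE: The following is an illustrative sketch added for exposition; it is not
# part of the original source. It wires symbolic Keras inputs into
# `create_model`; the 256x256 resolution is an arbitrary example that is
# divisible by 2^(pyramid_levels - 1) as required by the default options.
def _sketch_create_model_usage() -> tf.keras.Model:
  """Builds the interpolator with symbolic inputs and default options."""
  config = options.Options()
  x0 = tf.keras.Input(shape=(256, 256, 3), name='x0')
  x1 = tf.keras.Input(shape=(256, 256, 3), name='x1')
  time = tf.keras.Input(shape=(1,), name='time')
  model = create_model(x0, x1, time, config)
  # The interpolated frame is model({'x0': ..., 'x1': ..., 'time': ...})['image'].
  return model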
--------------------------------------------------------------------------------
/models/film_net/options.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Options for the film_net video frame interpolator."""
16 |
17 | import gin.tf
18 |
19 |
20 | @gin.configurable('film_net')
21 | class Options(object):
22 | """Options for the film_net video frame interpolator.
23 |
24 | To further understand these options, see the paper here:
25 | https://augmentedperception.github.io/pixelfusion/.
26 |
27 | The default values are suitable for up to 64 pixel motions. For larger motions
28 | the number of flow convolutions and/or pyramid levels can be increased, but
29 | usually with the cost of accuracy on solving the smaller motions.
30 |
31 | The maximum motion in pixels that the system can resolve is equivalent to
32 | 2^(pyramid_levels-1) * flow_convs[-1]. I.e. the downsampling factor times
33 | the receptive field radius on the coarsest pyramid level. This, of course,
34 | assumes that the training data contains such motions.
35 |
36 | Note that to avoid a run-time error, the input image width and height have to
37 | be divisible by 2^(pyramid_levels-1).
38 |
39 | Attributes:
40 | pyramid_levels: How many pyramid levels to use for the feature pyramid and
41 | the flow prediction.
42 |     fusion_pyramid_levels: How many pyramid levels to use for the fusion
43 |       module; this must be less than or equal to 'pyramid_levels'.
44 | specialized_levels: How many fine levels of the pyramid shouldn't share the
45 |       weights. If specialized_levels = 3, the three finest levels are
46 |       independently learned, whereas the fourth and coarser levels are learned
47 |       together with shared weights. Valid range [1, pyramid_levels].
48 | flow_convs: Convolutions per residual flow predictor. This array should have
49 | specialized_levels+1 items on it, the last item representing the number of
50 | convs used by any pyramid level that uses shared weights.
51 | flow_filters: Base number of filters in residual flow predictors. This array
52 | should have specialized_levels+1 items on it, the last item representing
53 | the number of filters used by any pyramid level that uses shared weights.
54 | sub_levels: The depth of the cascaded feature tree each pyramid level
55 | concatenates together to compute the flow. This must be within range [1,
56 |       specialized_levels+1]. It is recommended to set this to specialized_levels
57 |       + 1.
58 | filters: Base number of features to extract. On each pyramid level the
59 | number doubles. This is used by both feature extraction and fusion stages.
60 | use_aux_outputs: Set to True to include auxiliary outputs along with the
61 | predicted image.
62 | """
63 |
64 | def __init__(self,
65 | pyramid_levels=5,
66 | fusion_pyramid_levels=5,
67 | specialized_levels=3,
68 | flow_convs=None,
69 | flow_filters=None,
70 | sub_levels=4,
71 | filters=16,
72 | use_aux_outputs=True):
73 | self.pyramid_levels = pyramid_levels
74 | self.fusion_pyramid_levels = fusion_pyramid_levels
75 | self.specialized_levels = specialized_levels
76 | self.flow_convs = flow_convs or [4, 4, 4, 4]
77 | self.flow_filters = flow_filters or [64, 128, 256, 256]
78 | self.sub_levels = sub_levels
79 | self.filters = filters
80 | self.use_aux_outputs = use_aux_outputs
81 |
82 |
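# NOTE: The following is an illustrative sketch added for exposition; it is not
# part of the original source. It evaluates the maximum resolvable motion
# formula from the class docstring.
def _sketch_max_resolvable_motion(config=None):
  """Returns 2^(pyramid_levels - 1) * flow_convs[-1] in pixels."""
  config = config or Options()
  # With the defaults (pyramid_levels=5, flow_convs=[4, 4, 4, 4]) this is
  # 2**4 * 4 = 64 pixels, matching the 'up to 64 pixel motions' note above.
  return (2 ** (config.pyramid_levels - 1)) * config.flow_convs[-1]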
--------------------------------------------------------------------------------
/models/film_net/pyramid_flow_estimator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """TF2 layer for estimating optical flow by a residual flow pyramid.
16 |
17 | This approach of estimating optical flow between two images can be traced back
18 | to [1], but is also used by later neural optical flow computation methods such
19 | as SpyNet [2] and PWC-Net [3].
20 |
21 | The basic idea is that the optical flow is first estimated in a coarse
22 | resolution, then the flow is upsampled to warp the higher resolution image and
23 | then a residual correction is computed and added to the estimated flow. This
24 | process is repeated in a pyramid on coarse to fine order to successively
25 | increase the resolution of both optical flow and the warped image.
26 |
27 | In here, the optical flow predictor is used as an internal component for the
28 | film_net frame interpolator, to warp the two input images into the inbetween,
29 | target frame.
30 |
31 | [1] F. Glazer, Hierarchical motion detection. PhD thesis, 1987.
32 | [2] A. Ranjan and M. J. Black, Optical Flow Estimation using a Spatial Pyramid
33 | Network. 2016
34 | [3] D. Sun, X. Yang, M.-Y. Liu and J. Kautz, PWC-Net: CNNs for Optical Flow Using
35 | Pyramid, Warping, and Cost Volume, 2017
36 | """
37 |
38 | from typing import List
39 |
40 | from . import options
41 | from . import util
42 | import tensorflow as tf
43 |
44 |
45 | def _relu(x: tf.Tensor) -> tf.Tensor:
46 | return tf.nn.leaky_relu(x, alpha=0.2)
47 |
48 |
49 | class FlowEstimator(tf.keras.layers.Layer):
50 | """Small-receptive field predictor for computing the flow between two images.
51 |
52 | This is used to compute the residual flow fields in PyramidFlowEstimator.
53 |
54 | Note that while the number of 3x3 convolutions & filters to apply is
55 | configurable, two extra 1x1 convolutions are appended to extract the flow in
56 | the end.
57 |
58 | Attributes:
59 | name: The name of the layer
60 | num_convs: Number of 3x3 convolutions to apply
61 | num_filters: Number of filters in each 3x3 convolution
62 | """
63 |
64 | def __init__(self, name: str, num_convs: int, num_filters: int):
65 | super(FlowEstimator, self).__init__(name=name)
66 | def conv(filters, size, name, activation=_relu):
67 | return tf.keras.layers.Conv2D(
68 | name=name,
69 | filters=filters,
70 | kernel_size=size,
71 | padding='same',
72 | activation=activation)
73 |
74 | self._convs = []
75 | for i in range(num_convs):
76 | self._convs.append(conv(filters=num_filters, size=3, name=f'conv_{i}'))
77 |     self._convs.append(conv(filters=num_filters // 2, size=1, name=f'conv_{i+1}'))
78 | # For the final convolution, we want no activation at all to predict the
79 | # optical flow vector values. We have done extensive testing on explicitly
80 | # bounding these values using sigmoid, but it turned out that having no
81 | # activation gives better results.
82 | self._convs.append(
83 | conv(filters=2, size=1, name=f'conv_{i+2}', activation=None))
84 |
85 | def call(self, features_a: tf.Tensor, features_b: tf.Tensor) -> tf.Tensor:
86 | """Estimates optical flow between two images.
87 |
88 | Args:
89 | features_a: per pixel feature vectors for image A (B x H x W x C)
90 | features_b: per pixel feature vectors for image B (B x H x W x C)
91 |
92 | Returns:
93 | A tensor with optical flow from A to B
94 | """
95 | net = tf.concat([features_a, features_b], axis=-1)
96 | for conv in self._convs:
97 | net = conv(net)
98 | return net
99 |
100 |
101 | class PyramidFlowEstimator(tf.keras.layers.Layer):
102 | """Predicts optical flow by coarse-to-fine refinement.
103 |
104 | Attributes:
105 | name: The name of the layer
106 | config: Options for the film_net frame interpolator
107 | """
108 |
109 | def __init__(self, name: str, config: options.Options):
110 | super(PyramidFlowEstimator, self).__init__(name=name)
111 | self._predictors = []
112 | for i in range(config.specialized_levels):
113 | self._predictors.append(
114 | FlowEstimator(
115 | name=f'flow_predictor_{i}',
116 | num_convs=config.flow_convs[i],
117 | num_filters=config.flow_filters[i]))
118 | shared_predictor = FlowEstimator(
119 | name='flow_predictor_shared',
120 | num_convs=config.flow_convs[-1],
121 | num_filters=config.flow_filters[-1])
122 | for i in range(config.specialized_levels, config.pyramid_levels):
123 | self._predictors.append(shared_predictor)
124 |
125 | def call(self, feature_pyramid_a: List[tf.Tensor],
126 | feature_pyramid_b: List[tf.Tensor]) -> List[tf.Tensor]:
127 | """Estimates residual flow pyramids between two image pyramids.
128 |
129 | Each image pyramid is represented as a list of tensors in fine-to-coarse
130 | order. Each individual image is represented as a tensor where each pixel is
131 | a vector of image features.
132 |
133 | util.flow_pyramid_synthesis can be used to convert the residual flow
134 | pyramid returned by this method into a flow pyramid, where each level
135 | encodes the flow instead of a residual correction.
136 |
137 | Args:
138 |       feature_pyramid_a: feature pyramid as a list in fine-to-coarse order
139 |       feature_pyramid_b: feature pyramid as a list in fine-to-coarse order
140 |
141 | Returns:
142 | List of flow tensors, in fine-to-coarse order, each level encoding the
143 | difference against the bilinearly upsampled version from the coarser
144 |     level. The coarsest flow tensor, i.e. the last element in the array, is the
145 |     'DC-term', i.e. not a residual (alternatively you can think of it being a
146 | residual against zero).
147 | """
148 | levels = len(feature_pyramid_a)
149 | v = self._predictors[-1](feature_pyramid_a[-1], feature_pyramid_b[-1])
150 | residuals = [v]
151 | for i in reversed(range(0, levels-1)):
152 | # Upsamples the flow to match the current pyramid level. Also, scales the
153 | # magnitude by two to reflect the new size.
154 | level_size = tf.shape(feature_pyramid_a[i])[1:3]
155 | v = tf.image.resize(images=2*v, size=level_size)
156 | # Warp feature_pyramid_b[i] image based on the current flow estimate.
157 | warped = util.warp(feature_pyramid_b[i], v)
158 | # Estimate the residual flow between pyramid_a[i] and warped image:
159 | v_residual = self._predictors[i](feature_pyramid_a[i], warped)
160 | residuals.append(v_residual)
161 | v = v_residual + v
162 | # Use reversed() to return in the 'standard' finest-first-order:
163 | return list(reversed(residuals))
164 |
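# NOTE: The following is an illustrative sketch added for exposition; it is not
# part of the original source. It shows how the residual pyramid returned by
# PyramidFlowEstimator.call can be accumulated into absolute flows, mirroring
# the util.flow_pyramid_synthesis helper referenced in the docstring.
def _sketch_residuals_to_flow(residuals: List[tf.Tensor]) -> List[tf.Tensor]:
  """Accumulates fine-to-coarse residuals into a flow pyramid."""
  flow = residuals[-1]  # The coarsest level is already an absolute flow.
  flows = [flow]
  for residual in reversed(residuals[:-1]):
    level_size = tf.shape(residual)[1:3]
    # Upsample the coarser flow and double its magnitude to account for the
    # doubled resolution, then add the residual predicted at this level.
    flow = residual + tf.image.resize(images=2 * flow, size=level_size)
    flows.append(flow)
  return list(reversed(flows))  # Finest level first.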
--------------------------------------------------------------------------------
/models/film_net/util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Various utilities used in the film_net frame interpolator model."""
16 | from typing import List
17 |
18 | from .options import Options
19 | import tensorflow as tf
20 | import tensorflow_addons.image as tfa_image
21 |
22 |
23 | def build_image_pyramid(image: tf.Tensor,
24 | options: Options) -> List[tf.Tensor]:
25 | """Builds an image pyramid from a given image.
26 |
27 | The original image is included in the pyramid and the rest are generated by
28 | successively halving the resolution.
29 |
30 | Args:
31 | image: the input image.
32 | options: film_net options object
33 |
34 | Returns:
35 | A list of images starting from the finest with options.pyramid_levels items
36 | """
37 | levels = options.pyramid_levels
38 | pyramid = []
39 | pool = tf.keras.layers.AveragePooling2D(
40 | pool_size=2, strides=2, padding='valid')
41 | for i in range(0, levels):
42 | pyramid.append(image)
43 | if i < levels-1:
44 | image = pool(image)
45 | return pyramid
46 |
47 |
48 | def warp(image: tf.Tensor, flow: tf.Tensor) -> tf.Tensor:
49 | """Backward warps the image using the given flow.
50 |
51 | Specifically, the output pixel in batch b, at position x, y will be computed
52 | as follows:
53 | (flowed_y, flowed_x) = (y+flow[b, y, x, 1], x+flow[b, y, x, 0])
54 | output[b, y, x] = bilinear_lookup(image, b, flowed_y, flowed_x)
55 |
56 | Note that the flow vectors are expected as [x, y], e.g. x in position 0 and
57 | y in position 1.
58 |
59 | Args:
60 | image: An image with shape BxHxWxC.
61 | flow: A flow with shape BxHxWx2, with the two channels denoting the relative
62 | offset in order: (dx, dy).
63 | Returns:
64 | A warped image.
65 | """
66 | # tfa_image.dense_image_warp expects unconventional negated optical flow, so
67 | # negate the flow here. Also reverse x and y for compatibility with older saved
68 | # models trained with a custom warp op that stored (x, y) instead of (y, x) flow
69 | # vectors.
70 | flow = -flow[..., ::-1]
71 |
72 | # Note: we have to wrap tfa_image.dense_image_warp into a Keras Lambda,
73 | # because it is not compatible with Keras symbolic tensors and we want to use
74 | # this code as part of a Keras model. Wrapping it into a lambda has the
75 | # consequence that tfa_image.dense_image_warp is only called once the tensors
76 | # are concrete, i.e. actually contain data. The inner lambda is a workaround
77 | # for passing two parameters, i.e. you would really want to write:
78 | # tf.keras.layers.Lambda(tfa_image.dense_image_warp)(image, flow), but this is
79 | # not supported by the Keras Lambda.
80 | warped = tf.keras.layers.Lambda(
81 | lambda x: tfa_image.dense_image_warp(*x))((image, flow))
82 | return tf.reshape(warped, shape=tf.shape(image))
83 |
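# NOTE: The following is an illustrative sketch added for exposition; it is not
# part of the original source. Warping with a constant integer flow reduces to
# a plain pixel shift, which makes the (dx, dy) channel convention explicit.
def _sketch_constant_flow_warp() -> tf.Tensor:
  """Warps a random image by a uniform (dx=3, dy=0) flow."""
  image = tf.random.uniform([1, 16, 16, 3])
  # flow[..., 0] is dx and flow[..., 1] is dy, so output[b, y, x] samples
  # image[b, y, x + 3] wherever that source location lies inside the image.
  flow = tf.concat([
      tf.fill([1, 16, 16, 1], 3.0),  # dx
      tf.zeros([1, 16, 16, 1]),      # dy
  ], axis=-1)
  return warp(image, flow)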
84 |
85 | def multiply_pyramid(pyramid: List[tf.Tensor],
86 | scalar: tf.Tensor) -> List[tf.Tensor]:
87 | """Multiplies all image batches in the pyramid by a batch of scalars.
88 |
89 | Args:
90 | pyramid: Pyramid of image batches.
91 | scalar: Batch of scalars.
92 |
93 | Returns:
94 | An image pyramid with all images multiplied by the scalar.
95 | """
96 | # To multiply each image with its corresponding scalar, we first transpose
97 | # the batch of images from BxHxWxC-format to CxHxWxB. This can then be
98 | # multiplied with a batch of scalars, then we transpose back to the standard
99 | # BxHxWxC form.
100 | return [
101 | tf.transpose(tf.transpose(image, [3, 1, 2, 0]) * scalar, [3, 1, 2, 0])
102 | for image in pyramid
103 | ]
104 |
105 |
106 | def flow_pyramid_synthesis(
107 | residual_pyramid: List[tf.Tensor]) -> List[tf.Tensor]:
108 | """Converts a residual flow pyramid into a flow pyramid."""
109 | flow = residual_pyramid[-1]
110 | flow_pyramid = [flow]
111 | for residual_flow in reversed(residual_pyramid[:-1]):
112 | level_size = tf.shape(residual_flow)[1:3]
113 | flow = tf.image.resize(images=2*flow, size=level_size)
114 | flow = residual_flow + flow
115 | flow_pyramid.append(flow)
116 | # Use reversed() to return in the 'standard' finest-first-order:
117 | return list(reversed(flow_pyramid))
118 |
119 |
120 | def pyramid_warp(feature_pyramid: List[tf.Tensor],
121 | flow_pyramid: List[tf.Tensor]) -> List[tf.Tensor]:
122 | """Warps the feature pyramid using the flow pyramid.
123 |
124 | Args:
125 | feature_pyramid: feature pyramid starting from the finest level.
126 | flow_pyramid: flow fields, starting from the finest level.
127 |
128 | Returns:
129 | Reverse warped feature pyramid.
130 | """
131 | warped_feature_pyramid = []
132 | for features, flow in zip(feature_pyramid, flow_pyramid):
133 | warped_feature_pyramid.append(warp(features, flow))
134 | return warped_feature_pyramid
135 |
136 |
137 | def concatenate_pyramids(pyramid1: List[tf.Tensor],
138 | pyramid2: List[tf.Tensor]) -> List[tf.Tensor]:
139 | """Concatenates each pyramid level together in the channel dimension."""
140 | result = []
141 | for features1, features2 in zip(pyramid1, pyramid2):
142 | result.append(tf.concat([features1, features2], axis=-1))
143 | return result
144 |
--------------------------------------------------------------------------------
/moment.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/frame-interpolation/69f8708f08e62c2edf46a27616a4bfcf083e2076/moment.gif
--------------------------------------------------------------------------------
/photos/one.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/frame-interpolation/69f8708f08e62c2edf46a27616a4bfcf083e2076/photos/one.png
--------------------------------------------------------------------------------
/photos/two.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/frame-interpolation/69f8708f08e62c2edf46a27616a4bfcf083e2076/photos/two.png
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import numpy as np
4 | import tempfile
5 | import tensorflow as tf
6 | import mediapy
7 | from PIL import Image
8 | import cog
9 |
10 | from eval import interpolator, util
11 |
12 | _UINT8_MAX_F = float(np.iinfo(np.uint8).max)
13 |
14 |
15 | class Predictor(cog.Predictor):
16 | def setup(self):
17 | import tensorflow as tf
18 | print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
19 | self.interpolator = interpolator.Interpolator("pretrained_models/film_net/Style/saved_model", None)
20 |
21 | # Batched time.
22 | self.batch_dt = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
23 |
24 | @cog.input(
25 | "frame1",
26 | type=Path,
27 | help="The first input frame",
28 | )
29 | @cog.input(
30 | "frame2",
31 | type=Path,
32 | help="The second input frame",
33 | )
34 | @cog.input(
35 | "times_to_interpolate",
36 | type=int,
37 | default=1,
38 | min=1,
39 | max=8,
40 | help="Controls the number of times the frame interpolator is invoked. If set to 1, the output will be the "
41 | "sub-frame at t=0.5; when set to > 1, the output will be the interpolation video with "
42 | "(2^times_to_interpolate + 1) frames at 30 fps.",
43 | )
44 | def predict(self, frame1, frame2, times_to_interpolate):
45 | INPUT_EXT = ['.png', '.jpg', '.jpeg']
46 | assert os.path.splitext(str(frame1))[-1] in INPUT_EXT and os.path.splitext(str(frame2))[-1] in INPUT_EXT, \
47 | "Please provide png, jpg or jpeg images."
48 |
49 | # make sure 2 images are the same size
50 | img1 = Image.open(str(frame1))
51 | img2 = Image.open(str(frame2))
52 | if not img1.size == img2.size:
53 | img1 = img1.crop((0, 0, min(img1.size[0], img2.size[0]), min(img1.size[1], img2.size[1])))
54 | img2 = img2.crop((0, 0, min(img1.size[0], img2.size[0]), min(img1.size[1], img2.size[1])))
55 | frame1 = 'new_frame1.png'
56 | frame2 = 'new_frame2.png'
57 | img1.save(frame1)
58 | img2.save(frame2)
59 |
60 | if times_to_interpolate == 1:
61 | # First batched image.
62 | image_1 = util.read_image(str(frame1))
63 | image_batch_1 = np.expand_dims(image_1, axis=0)
64 |
65 | # Second batched image.
66 | image_2 = util.read_image(str(frame2))
67 | image_batch_2 = np.expand_dims(image_2, axis=0)
68 |
69 | # Invoke the model once.
70 |
71 | mid_frame = self.interpolator.interpolate(image_batch_1, image_batch_2, self.batch_dt)[0]
72 | out_path = Path(tempfile.mkdtemp()) / "out.png"
73 | util.write_image(str(out_path), mid_frame)
74 | return out_path
75 |
76 |
77 | input_frames = [str(frame1), str(frame2)]
78 |
79 | frames = list(
80 | util.interpolate_recursively_from_files(
81 | input_frames, times_to_interpolate, self.interpolator))
82 | print('Interpolated frames generated, saving now as output video.')
83 |
84 | ffmpeg_path = util.get_ffmpeg_path()
85 | mediapy.set_ffmpeg(ffmpeg_path)
86 | out_path = Path(tempfile.mkdtemp()) / "out.mp4"
87 | mediapy.write_video(str(out_path), frames, fps=30)
88 | return out_path
89 |
--------------------------------------------------------------------------------
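The times_to_interpolate input above controls the recursion depth of util.interpolate_recursively_from_files: each level inserts a midpoint between every adjacent pair of frames, doubling the number of intervals. A small sketch of the frame-count arithmetic only (plain Python, no model involved; num_output_frames is a stand-in name):

def num_output_frames(times_to_interpolate: int, num_input_frames: int = 2) -> int:
    # Each recursion level doubles the number of intervals between frames.
    intervals = (num_input_frames - 1) * (2 ** times_to_interpolate)
    return intervals + 1

for t in range(1, 5):
    print(t, num_output_frames(t))  # 1 -> 3, 2 -> 5, 3 -> 9, 4 -> 17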
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Docker base image: `gcr.io/deeplearning-platform-release/tf2-gpu.2-6:latest`
2 | tensorflow==2.6.2 # The latest should include tensorflow-gpu
3 | tensorflow-datasets==4.4.0
4 | tensorflow-addons==0.15.0
5 | absl-py==0.12.0
6 | gin-config==0.5.0
7 | parameterized==0.8.1
8 | mediapy==1.0.3
9 | scikit-image==0.19.1
10 | apache-beam==2.34.0
11 | google-cloud-bigquery-storage==1.1.0 # Suppresses a harmless error from beam
12 | natsort==8.1.0
13 | gdown==4.5.4
14 | tqdm==4.64.1
--------------------------------------------------------------------------------
/training/augmentation_lib.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Dataset augmentation for frame interpolation."""
16 | from typing import Callable, Dict, List
17 |
18 | import gin.tf
19 | import numpy as np
20 | import tensorflow as tf
21 | import tensorflow.math as tfm
22 | import tensorflow_addons.image as tfa_image
23 |
24 | _PI = 3.141592653589793
25 |
26 |
27 | def _rotate_flow_vectors(flow: tf.Tensor, angle_rad: float) -> tf.Tensor:
28 | r"""Rotate the (u,v) vector of each pixel with angle in radians.
29 |
30 | Flow matrix system of coordinates.
31 | . . . . u (x)
32 | .
33 | .
34 | . v (-y)
35 |
36 | Rotation system of coordinates.
37 | . y
38 | .
39 | .
40 | . . . . x
41 | Args:
42 | flow: Flow map which has been image-rotated.
43 | angle_rad: The rotation angle in radians.
44 |
45 | Returns:
46 | A flow with the same map but each (u,v) vector rotated by angle_rad.
47 | """
48 | u, v = tf.split(flow, 2, axis=-1)
49 | # rotu = u * cos(angle) - (-v) * sin(angle)
50 | rot_u = tfm.cos(angle_rad) * u + tfm.sin(angle_rad) * v
51 | # rotv = -(u * sin(theta) + (-v) * cos(theta))
52 | rot_v = -tfm.sin(angle_rad) * u + tfm.cos(angle_rad) * v
53 | return tf.concat((rot_u, rot_v), axis=-1)
54 |
55 |
56 | def flow_rot90(flow: tf.Tensor, k: int) -> tf.Tensor:
57 | """Rotates a flow by a multiple of 90 degrees.
58 |
59 | Args:
60 | flow: The flow image shaped (H, W, 2) to rotate by multiples of 90 degrees.
61 | k: The multiplier factor.
62 |
63 | Returns:
64 | A flow image of the same shape as the input rotated by multiples of 90
65 | degrees.
66 | """
67 | angle_rad = tf.cast(k, dtype=tf.float32) * 90. * (_PI/180.)
68 | flow = tf.image.rot90(flow, k)
69 | return _rotate_flow_vectors(flow, angle_rad)
70 |
71 |
72 | def rotate_flow(flow: tf.Tensor, angle_rad: float) -> tf.Tensor:
73 | """Rotates a flow by the provided angle in radians.
74 |
75 | Args:
76 | flow: The flow image shaped (H, W, 2) to rotate by the provided angle.
77 | angle_rad: The angle to rotate the flow in radians.
78 |
79 | Returns:
80 | A flow image of the same shape as the input rotated by the provided angle in
81 | radians.
82 | """
83 | flow = tfa_image.rotate(
84 | flow,
85 | angles=angle_rad,
86 | interpolation='bilinear',
87 | fill_mode='reflect')
88 | return _rotate_flow_vectors(flow, angle_rad)
89 |
90 |
91 | def flow_flip(flow: tf.Tensor) -> tf.Tensor:
92 | """Flips a flow left to right.
93 |
94 | Args:
95 | flow: The flow image shaped (H, W, 2) to flip left to right.
96 |
97 | Returns:
98 | A flow image of the same shape as the input flipped left to right.
99 | """
100 | flow = tf.image.flip_left_right(tf.identity(flow))
101 | flow_u, flow_v = tf.split(flow, 2, axis=-1)
102 | return tf.concat([-1 * flow_u, flow_v], axis=-1)  # concat keeps the (H, W, 2) shape
103 |
104 |
105 | def random_image_rot90(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
106 | """Rotates a stack of images by a random multiple of 90 degrees.
107 |
108 | Args:
109 | images: A dictionary of tf.Tensors, each shaped (H, W, num_channels) with
110 | images stacked along the channel axis.
111 | Returns:
112 | A dictionary of tf.Tensors of the same shapes as `images` after a random
113 | rotation by a multiple of 90 degrees applied counter-clockwise.
114 | """
115 | random_k = tf.random.uniform((), minval=0, maxval=4, dtype=tf.int32)
116 | for key in images:
117 | images[key] = tf.image.rot90(images[key], k=random_k)
118 | return images
119 |
120 |
121 | def random_flip(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
122 | """Flips a stack of images randomly.
123 |
124 | Args:
125 | images: A dictionary of tf.Tensors, each shaped (H, W, num_channels) with
126 | images stacked along the channel axis.
127 |
128 | Returns:
129 | A dictionary of tf.Tensors of the images after a random left-to-right flip.
130 | """
131 | prob = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
132 | prob = tf.cast(prob, tf.bool)
133 |
134 | def _identity(image):
135 | return image
136 |
137 | def _flip_left_right(image):
138 | return tf.image.flip_left_right(image)
139 |
140 | # pylint: disable=cell-var-from-loop
141 | for key in images:
142 | images[key] = tf.cond(prob, lambda: _flip_left_right(images[key]),
143 | lambda: _identity(images[key]))
144 | return images
145 |
146 |
147 | def random_reverse(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
148 | """Reverses a stack of images randomly.
149 |
150 | Args:
151 | images: A dictionary of tf.Tensors, each shaped (H, W, num_channels), with
152 | each tensor being a stack of images along the last channel axis.
153 |
154 | Returns:
155 | A dictionary of tf.Tensors, each shaped the same as the input images dict.
156 | """
157 | prob = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
158 | prob = tf.cast(prob, tf.bool)
159 |
160 | def _identity(images):
161 | return images
162 |
163 | def _reverse(images):
164 | images['x0'], images['x1'] = images['x1'], images['x0']
165 | return images
166 |
167 | return tf.cond(prob, lambda: _reverse(images), lambda: _identity(images))
168 |
169 |
170 | def random_rotate(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
171 | """Rotates images randomly within [-45, 45] degrees.
172 |
173 | Args:
174 | images: A dictionary of tf.Tensors, each shaped (H, W, num_channels) with
175 | images stacked along the channel axis.
176 |
177 | Returns:
178 | A dictionary of tf.Tensors of the images after a random rotation bounded by
179 | -45 to 45 degrees.
180 | """
181 | prob = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
182 | prob = tf.cast(prob, tf.float32)
183 | random_angle = tf.random.uniform((),
184 | minval=-0.25 * np.pi,
185 | maxval=0.25 * np.pi,
186 | dtype=tf.float32)
187 |
188 | for key in images:
189 | images[key] = tfa_image.rotate(
190 | images[key],
191 | angles=random_angle * prob,
192 | interpolation='bilinear',
193 | fill_mode='constant')
194 | return images
195 |
196 |
197 | @gin.configurable('data_augmentation')
198 | def data_augmentations(
199 | names: List[str]) -> Dict[str, Callable[..., tf.Tensor]]:
200 | """Creates the data augmentation functions.
201 |
202 | Args:
203 | names: The list of augmentation function names.
204 | Returns:
205 | A dictionary of Callables to the augmentation functions, keyed by their
206 | names.
207 | """
208 | augmentations = dict()
209 | for name in names:
210 | if name == 'random_image_rot90':
211 | augmentations[name] = random_image_rot90
212 | elif name == 'random_rotate':
213 | augmentations[name] = random_rotate
214 | elif name == 'random_flip':
215 | augmentations[name] = random_flip
216 | elif name == 'random_reverse':
217 | augmentations[name] = random_reverse
218 | else:
219 | raise AttributeError('Invalid augmentation function %s' % name)
220 | return augmentations
221 |
--------------------------------------------------------------------------------
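The (u, v) rotation in _rotate_flow_vectors is a standard 2-D rotation adapted to the flow sign convention sketched in its docstring. A self-contained check of that formula (it mirrors the code above rather than importing the module): rotating by four quarter turns, i.e. 2*pi radians, should give back the original vectors up to float error.

import numpy as np
import tensorflow as tf

flow = tf.random.uniform((6, 6, 2))
u, v = tf.split(flow, 2, axis=-1)
angle = tf.constant(2.0 * np.pi, dtype=tf.float32)  # four 90-degree turns

rot_u = tf.math.cos(angle) * u + tf.math.sin(angle) * v
rot_v = -tf.math.sin(angle) * u + tf.math.cos(angle) * v
restored = tf.concat((rot_u, rot_v), axis=-1)

print(float(tf.reduce_max(tf.abs(restored - flow))))  # ~1e-7, i.e. unchanged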
/training/build_saved_model_cli.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Converts TF2 training checkpoint to a saved model.
16 |
17 | The model must match the checkpoint, so the gin config must be given.
18 |
19 | Usage example:
20 | python3 -m frame_interpolation.training.build_saved_model_cli \
21 | --gin_config <config filename> \
22 | --base_folder <base folder> \
23 | --label <descriptive label>
24 |
25 | This will produce a saved model into: <base_folder>/<label>/saved_model
26 | """
27 | import os
28 | from typing import Sequence
29 |
30 | from . import model_lib
31 | from absl import app
32 | from absl import flags
33 | from absl import logging
34 | import gin.tf
35 | import tensorflow as tf
36 | tf.get_logger().setLevel('ERROR')
37 |
38 | _GIN_CONFIG = flags.DEFINE_string(
39 | name='gin_config',
40 | default='config.gin',
41 | help='Gin config file, saved in the training session folder.')
42 | _LABEL = flags.DEFINE_string(
43 | name='label',
44 | default=None,
45 | required=True,
46 | help='Descriptive label for the training session.')
47 | _BASE_FOLDER = flags.DEFINE_string(
48 | name='base_folder',
49 | default=None,
50 | help='Path to all training sessions.')
51 | _MODE = flags.DEFINE_enum(
52 | name='mode',
53 | default=None,
54 | enum_values=['cpu', 'gpu', 'tpu'],
55 | help='Distributed strategy approach.')
56 |
57 |
58 | def _build_saved_model(checkpoint_path: str, config_files: Sequence[str],
59 | output_model_path: str):
60 | """Builds a saved model based on the checkpoint directory."""
61 | gin.parse_config_files_and_bindings(
62 | config_files=config_files,
63 | bindings=None,
64 | skip_unknown=True)
65 | model = model_lib.create_model()
66 | checkpoint = tf.train.Checkpoint(model=model)
67 | checkpoint_file = tf.train.latest_checkpoint(checkpoint_path)
68 | try:
69 | logging.info('Restoring from %s', checkpoint_file)
70 | status = checkpoint.restore(checkpoint_file)
71 | status.assert_existing_objects_matched()
72 | status.expect_partial()
73 | model.save(output_model_path)
74 | except (tf.errors.NotFoundError, AssertionError) as err:
75 | logging.info('Failed to restore checkpoint from %s. Error:\n%s',
76 | checkpoint_file, err)
77 |
78 |
79 | def main(argv):
80 | if len(argv) > 1:
81 | raise app.UsageError('Too many command-line arguments.')
82 |
83 | checkpoint_path = os.path.join(_BASE_FOLDER.value, _LABEL.value, 'train')
84 | if not tf.io.gfile.exists(_GIN_CONFIG.value):
85 | config_file = os.path.join(_BASE_FOLDER.value, _LABEL.value,
86 | _GIN_CONFIG.value)
87 | else:
88 | config_file = _GIN_CONFIG.value
89 | output_model_path = os.path.join(_BASE_FOLDER.value, _LABEL.value,
90 | 'saved_model')
91 | _build_saved_model(
92 | checkpoint_path=checkpoint_path,
93 | config_files=[config_file],
94 | output_model_path=output_model_path)
95 | logging.info('The saved model is stored in %s/.', output_model_path)
96 |
97 | if __name__ == '__main__':
98 | app.run(main)
99 |
--------------------------------------------------------------------------------
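Once the SavedModel is written, it can be driven directly without the training code. A minimal sketch, assuming the export keeps the training model's dictionary interface ('x0', 'x1', 'time' in, 'image' out) and that the placeholder path is replaced with a real one; inputs are float32 images in [0, 1]:

import numpy as np
import tensorflow as tf

# Placeholder path: substitute the actual <base_folder>/<label>/saved_model.
model = tf.saved_model.load('<base_folder>/<label>/saved_model')

x0 = np.zeros((1, 256, 256, 3), dtype=np.float32)  # first frame batch
x1 = np.ones((1, 256, 256, 3), dtype=np.float32)   # second frame batch
time = np.full((1, 1), 0.5, dtype=np.float32)      # midpoint in time

out = model({'x0': x0, 'x1': x1, 'time': time}, training=False)
print(out['image'].shape)  # (1, 256, 256, 3)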
/training/config/film_net-L1.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | model.name = 'film_net'
16 |
17 | film_net.pyramid_levels = 7
18 | film_net.fusion_pyramid_levels = 5
19 | film_net.specialized_levels = 3
20 | film_net.sub_levels = 4
21 | film_net.flow_convs = [3, 3, 3, 3]
22 | film_net.flow_filters = [32, 64, 128, 256]
23 | film_net.filters = 64
24 |
25 | training.learning_rate = 0.0001
26 | training.learning_rate_decay_steps = 750000
27 | training.learning_rate_decay_rate = 0.464158
28 | training.learning_rate_staircase = True
29 | training.num_steps = 3000000
30 |
31 | # in the sweep
32 | training_dataset.file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_train.tfrecord@200'
33 | training_dataset.batch_size = 8
34 | training_dataset.crop_size = 256
35 |
36 | eval_datasets.batch_size = 1
37 | eval_datasets.max_examples = -1
38 | # eval_datasets.files = ['gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_test.tfrecord@3',
39 | # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/middlebury_other.tfrecord@3',
40 | # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/UCF101_interp_test.tfrecord@2',
41 | # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_2K.tfrecord@2',
42 | # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_4K.tfrecord@2']
43 | # eval_datasets.names = ['vimeo90K', 'middlebury', 'ucf101', 'xiph2K', 'xiph4K']
44 | eval_datasets.files = []
45 | eval_datasets.names = []
46 |
47 | # Training augmentation (in addition to random crop)
48 | data_augmentation.names = ['random_image_rot90', 'random_flip', 'random_rotate', 'random_reverse']
49 |
50 | # Loss functions
51 | training_losses.loss_names = ['l1']
52 | training_losses.loss_weights = [1.0]
53 |
54 | test_losses.loss_names = ['l1', 'psnr', 'ssim']
55 | test_losses.loss_weights = [1.0, 1.0, 1.0]
56 |
--------------------------------------------------------------------------------
/training/config/film_net-Style.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | model.name = 'film_net'
16 |
17 | film_net.pyramid_levels = 7
18 | film_net.fusion_pyramid_levels = 5
19 | film_net.specialized_levels = 3
20 | film_net.sub_levels = 4
21 | film_net.flow_convs = [3, 3, 3, 3]
22 | film_net.flow_filters = [32, 64, 128, 256]
23 | film_net.filters = 64
24 |
25 | training.learning_rate = 0.0001
26 | training.learning_rate_decay_steps = 750000
27 | training.learning_rate_decay_rate = 0.464158
28 | training.learning_rate_staircase = True
29 | training.num_steps = 3000000
30 |
31 | # in the sweep
32 | training_dataset.file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_train.tfrecord@200'
33 | training_dataset.batch_size = 8
34 | training_dataset.crop_size = 256
35 |
36 | eval_datasets.batch_size = 1
37 | eval_datasets.max_examples = -1
38 | # eval_datasets.files = ['gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_test.tfrecord@3',
39 | # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/middlebury_other.tfrecord@3',
40 | # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/UCF101_interp_test.tfrecord@2',
41 | # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_2K.tfrecord@2',
42 | # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_4K.tfrecord@2']
43 | # eval_datasets.names = ['vimeo90K', 'middlebury', 'ucf101', 'xiph2K', 'xiph4K']
44 | eval_datasets.files = []
45 | eval_datasets.names = []
46 |
47 | # Training augmentation (in addition to random crop)
48 | data_augmentation.names = ['random_image_rot90', 'random_flip', 'random_rotate', 'random_reverse']
49 |
50 | # Loss functions
51 | training_losses.loss_names = ['l1', 'vgg', 'style']
52 | training_losses.loss_weight_schedules = [
53 | @tf.keras.optimizers.schedules.PiecewiseConstantDecay,
54 | @tf.keras.optimizers.schedules.PiecewiseConstantDecay,
55 | @tf.keras.optimizers.schedules.PiecewiseConstantDecay]
56 | # Increase the weight of style loss at 1.5M steps.
57 | training_losses.loss_weight_parameters = [
58 | {'boundaries':[0], 'values':[1.0, 1.0]},
59 | {'boundaries':[1500000], 'values':[1.0, 0.25]},
60 | {'boundaries':[1500000], 'values':[0.0, 40.0]}]
61 |
62 | test_losses.loss_names = ['l1', 'psnr', 'ssim']
63 | test_losses.loss_weights = [1.0, 1.0, 1.0]
64 |
65 | vgg.vgg_model_file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/pretrained_models/vgg/imagenet-vgg-verydeep-19.mat'
66 | style.vgg_model_file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/pretrained_models/vgg/imagenet-vgg-verydeep-19.mat'
67 |
--------------------------------------------------------------------------------
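The loss_weight_parameters above parameterize the PiecewiseConstantDecay schedules listed in loss_weight_schedules, so each loss weight becomes a step function of the training iteration: with boundaries [1500000] and values [0.0, 40.0], the style term is switched off until step 1.5M and weighted 40.0 afterwards, while the VGG weight drops from 1.0 to 0.25 at the same step. A small sketch of that behaviour:

import tensorflow as tf

style_weight = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[1500000], values=[0.0, 40.0])
vgg_weight = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[1500000], values=[1.0, 0.25])

for step in [0, 1500000, 1500001]:
    print(step, float(style_weight(step)), float(vgg_weight(step)))
# 0       -> style 0.0,  vgg 1.0
# 1500000 -> style 0.0,  vgg 1.0   (the boundary step still uses the first value)
# 1500001 -> style 40.0, vgg 0.25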
/training/config/film_net-VGG.gin:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | model.name = 'film_net'
16 |
17 | film_net.pyramid_levels = 7
18 | film_net.fusion_pyramid_levels = 5
19 | film_net.specialized_levels = 3
20 | film_net.sub_levels = 4
21 | film_net.flow_convs = [3, 3, 3, 3]
22 | film_net.flow_filters = [32, 64, 128, 256]
23 | film_net.filters = 64
24 |
25 | training.learning_rate = 0.0001
26 | training.learning_rate_decay_steps = 750000
27 | training.learning_rate_decay_rate = 0.464158
28 | training.learning_rate_staircase = True
29 | training.num_steps = 3000000
30 |
31 | # in the sweep
32 | training_dataset.file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_train.tfrecord@200'
33 | training_dataset.batch_size = 8
34 | training_dataset.crop_size = 256
35 |
36 | eval_datasets.batch_size = 1
37 | eval_datasets.max_examples = -1
38 | # eval_datasets.files = ['gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/vimeo_interp_test.tfrecord@3',
39 | # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/middlebury_other.tfrecord@3',
40 | # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/UCF101_interp_test.tfrecord@2',
41 | # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_2K.tfrecord@2',
42 | # 'gs://xcloud-shared/fitsumreda/frame_interpolation/datasets/xiph_4K.tfrecord@2']
43 | # eval_datasets.names = ['vimeo90K', 'middlebury', 'ucf101', 'xiph2K', 'xiph4K']
44 | eval_datasets.files = []
45 | eval_datasets.names = []
46 |
47 | # Training augmentation (in addition to random crop)
48 | data_augmentation.names = ['random_image_rot90', 'random_flip', 'random_rotate', 'random_reverse']
49 |
50 | # Loss functions
51 | training_losses.loss_names = ['l1', 'vgg']
52 | training_losses.loss_weight_schedules = [
53 | @tf.keras.optimizers.schedules.PiecewiseConstantDecay,
54 | @tf.keras.optimizers.schedules.PiecewiseConstantDecay]
55 |
56 | # Decrease the weight of VGG loss at 1.5M steps.
57 | training_losses.loss_weight_parameters = [
58 | {'boundaries':[0], 'values':[1.0, 1.0]},
59 | {'boundaries':[1500000], 'values':[1.0, 0.25]}]
60 |
61 | test_losses.loss_names = ['l1', 'psnr', 'ssim']
62 | test_losses.loss_weights = [1.0, 1.0, 1.0]
63 |
64 | vgg.vgg_model_file = 'gs://xcloud-shared/fitsumreda/frame_interpolation/pretrained_models/vgg/imagenet-vgg-verydeep-19.mat'
65 |
--------------------------------------------------------------------------------
/training/data_lib.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Dataset creation for frame interpolation."""
16 | from typing import Callable, Dict, List, Optional
17 |
18 | from absl import logging
19 | import gin.tf
20 | import tensorflow as tf
21 |
22 |
23 | def _create_feature_map() -> Dict[str, tf.io.FixedLenFeature]:
24 | """Creates the feature map for extracting the frame triplet."""
25 | feature_map = {
26 | 'frame_0/encoded':
27 | tf.io.FixedLenFeature((), tf.string, default_value=''),
28 | 'frame_0/format':
29 | tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
30 | 'frame_0/height':
31 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
32 | 'frame_0/width':
33 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
34 | 'frame_1/encoded':
35 | tf.io.FixedLenFeature((), tf.string, default_value=''),
36 | 'frame_1/format':
37 | tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
38 | 'frame_1/height':
39 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
40 | 'frame_1/width':
41 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
42 | 'frame_2/encoded':
43 | tf.io.FixedLenFeature((), tf.string, default_value=''),
44 | 'frame_2/format':
45 | tf.io.FixedLenFeature((), tf.string, default_value='jpg'),
46 | 'frame_2/height':
47 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
48 | 'frame_2/width':
49 | tf.io.FixedLenFeature((), tf.int64, default_value=0),
50 | 'path':
51 | tf.io.FixedLenFeature((), tf.string, default_value=''),
52 | }
53 | return feature_map
54 |
55 |
56 | def _parse_example(sample):
57 | """Parses a serialized sample.
58 |
59 | Args:
60 | sample: A serialized tf.Example to be parsed.
61 |
62 | Returns:
63 | dictionary containing the following:
64 | 'x0', 'x1': the decoded input frame pair
65 | 'y': the decoded ground-truth middle frame
66 | 'time': the fractional time of 'y' (always 0.5), 'path': the mid-frame filepath
67 | """
68 | feature_map = _create_feature_map()
69 | features = tf.io.parse_single_example(sample, feature_map)
70 | output_dict = {
71 | 'x0': tf.io.decode_image(features['frame_0/encoded'], dtype=tf.float32),
72 | 'x1': tf.io.decode_image(features['frame_2/encoded'], dtype=tf.float32),
73 | 'y': tf.io.decode_image(features['frame_1/encoded'], dtype=tf.float32),
74 | # The fractional time value of frame_1 is not included in our tfrecords,
75 | # but is always at 0.5. The model will expect this to be specified, so
76 | # we insert it here.
77 | 'time': 0.5,
78 | # Store the original mid frame filepath for identifying examples.
79 | 'path': features['path'],
80 | }
81 |
82 | return output_dict
83 |
84 |
85 | def _random_crop_images(crop_size: int, images: tf.Tensor,
86 | total_channel_size: int) -> tf.Tensor:
87 | """Crops the tensor with random offset to the given size."""
88 | if crop_size > 0:
89 | crop_shape = tf.constant([crop_size, crop_size, total_channel_size])
90 | images = tf.image.random_crop(images, crop_shape)
91 | return images
92 |
93 |
94 | def crop_example(example: tf.Tensor, crop_size: int,
95 | crop_keys: Optional[List[str]] = None):
96 | """Random crops selected images in the example to given size and keys.
97 |
98 | Args:
99 | example: Input tensor representing images to be cropped.
100 | crop_size: The size to crop images to. This value is used for both
101 | height and width.
102 | crop_keys: The images in the input example to crop.
103 |
104 | Returns:
105 | Example with cropping applied to selected images.
106 | """
107 | if crop_keys is None:
108 | crop_keys = ['x0', 'x1', 'y']
109 | channels = [3, 3, 3]
110 |
111 | # Stack images along channel axis, and perform a random crop once.
112 | image_to_crop = [example[key] for key in crop_keys]
113 | stacked_images = tf.concat(image_to_crop, axis=-1)
114 | cropped_images = _random_crop_images(crop_size, stacked_images, sum(channels))
115 | cropped_images = tf.split(
116 | cropped_images, num_or_size_splits=channels, axis=-1)
117 | for key, cropped_image in zip(crop_keys, cropped_images):
118 | example[key] = cropped_image
119 | return example
120 |
121 |
122 | def apply_data_augmentation(
123 | augmentation_fns: Dict[str, Callable[..., tf.Tensor]],
124 | example: tf.Tensor,
125 | augmentation_keys: Optional[List[str]] = None) -> tf.Tensor:
126 | """Applies random augmentation in succession to selected image keys.
127 |
128 | Args:
129 | augmentation_fns: A Dict of Callables to data augmentation functions.
130 | example: Input tensor representing images to be augmented.
131 | augmentation_keys: The images in the input example to augment.
132 |
133 | Returns:
134 | Example with augmentation applied to selected images.
135 | """
136 | if augmentation_keys is None:
137 | augmentation_keys = ['x0', 'x1', 'y']
138 |
139 | # Apply each augmentation in sequence
140 | augmented_images = {key: example[key] for key in augmentation_keys}
141 | for augmentation_function in augmentation_fns.values():
142 | augmented_images = augmentation_function(augmented_images)
143 |
144 | for key in augmentation_keys:
145 | example[key] = augmented_images[key]
146 | return example
147 |
148 |
149 | def _create_from_tfrecord(batch_size, file, augmentation_fns,
150 | crop_size) -> tf.data.Dataset:
151 | """Creates a dataset from TFRecord."""
152 | dataset = tf.data.TFRecordDataset(file)
153 | dataset = dataset.map(
154 | _parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
155 |
156 | # Perform data_augmentation before cropping and batching
157 | if augmentation_fns is not None:
158 | dataset = dataset.map(
159 | lambda x: apply_data_augmentation(augmentation_fns, x),
160 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
161 |
162 | if crop_size > 0:
163 | dataset = dataset.map(
164 | lambda x: crop_example(x, crop_size=crop_size),
165 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
166 | dataset = dataset.batch(batch_size, drop_remainder=True)
167 | return dataset
168 |
169 |
170 | def _generate_sharded_filenames(filename: str) -> List[str]:
171 | """Generates filenames of each file in the sharded filepath.
172 |
173 | Based on github.com/google/revisiting-self-supervised/blob/master/datasets.py.
174 |
175 | Args:
176 | filename: The sharded filepath.
177 |
178 | Returns:
179 | A list of filepaths for each file in the shard.
180 | """
181 | base, count = filename.split('@')
182 | count = int(count)
183 | return ['{}-{:05d}-of-{:05d}'.format(base, i, count) for i in range(count)]
184 |
185 |
186 | def _create_from_sharded_tfrecord(batch_size,
187 | train_mode,
188 | file,
189 | augmentation_fns,
190 | crop_size,
191 | max_examples=-1) -> tf.data.Dataset:
192 | """Creates a dataset from a sharded tfrecord."""
193 | dataset = tf.data.Dataset.from_tensor_slices(
194 | _generate_sharded_filenames(file))
195 |
196 | # pylint: disable=g-long-lambda
197 | dataset = dataset.interleave(
198 | lambda x: _create_from_tfrecord(
199 | batch_size,
200 | file=x,
201 | augmentation_fns=augmentation_fns,
202 | crop_size=crop_size),
203 | num_parallel_calls=tf.data.AUTOTUNE,
204 | deterministic=not train_mode)
205 | # pylint: enable=g-long-lambda
206 | dataset = dataset.prefetch(buffer_size=2)
207 | if max_examples > 0:
208 | return dataset.take(max_examples)
209 | return dataset
210 |
211 |
212 | @gin.configurable('training_dataset')
213 | def create_training_dataset(
214 | batch_size: int,
215 | file: Optional[str] = None,
216 | files: Optional[List[str]] = None,
217 | crop_size: int = -1,
218 | crop_sizes: Optional[List[int]] = None,
219 | augmentation_fns: Optional[Dict[str, Callable[..., tf.Tensor]]] = None
220 | ) -> tf.data.Dataset:
221 | """Creates the training dataset.
222 |
223 | The given tfrecord should contain data in a format produced by
224 | frame_interpolation/datasets/create_*_tfrecord.py
225 |
226 | Args:
227 | batch_size: The number of images to batch per example.
228 | file: (deprecated) A path to a sharded tfrecord in @N format.
229 | Deprecated. Use 'files' instead.
230 | files: A list of paths to sharded tfrecords in @N format.
231 | crop_size: (deprecated) If > 0, images are cropped to crop_size x crop_size
232 | using tensorflow's random cropping. Deprecated: use 'files' and
233 | 'crop_sizes' instead.
234 | crop_sizes: List of crop sizes. If > 0, images are cropped to
235 | crop_size x crop_size using tensorflow's random cropping.
236 | augmentation_fns: A Dict of Callables to data augmentation functions.
237 | Returns:
238 | A tensorflow dataset for accessing examples that contain the input images
239 | 'x0', 'x1', ground truth 'y' and time of the ground truth 'time'=[0,1] in a
240 | dictionary of tensors.
241 | """
242 | if file:
243 | logging.warning('gin-configurable training_dataset.file is deprecated. '
244 | 'Use training_dataset.files instead.')
245 | return _create_from_sharded_tfrecord(batch_size, True, file,
246 | augmentation_fns, crop_size)
247 | else:
248 | if not crop_sizes or len(crop_sizes) != len(files):
249 | raise ValueError('Please pass crop_sizes[] with training_dataset.files.')
250 | if crop_size > 0:
251 | raise ValueError(
252 | 'crop_size should not be used with files[], use crop_sizes[] instead.'
253 | )
254 | tables = []
255 | for file, crop_size in zip(files, crop_sizes):
256 | tables.append(
257 | _create_from_sharded_tfrecord(batch_size, True, file,
258 | augmentation_fns, crop_size))
259 | return tf.data.experimental.sample_from_datasets(tables)
260 |
261 |
262 | @gin.configurable('eval_datasets')
263 | def create_eval_datasets(batch_size: int,
264 | files: List[str],
265 | names: List[str],
266 | crop_size: int = -1,
267 | max_examples: int = -1) -> Dict[str, tf.data.Dataset]:
268 | """Creates the evaluation datasets.
269 |
270 | As opposed to create_training_dataset, this function makes sure that the
271 | examples for each dataset are always read in a deterministic (same) order.
272 |
273 | Each given tfrecord should contain data in a format produced by
274 | frame_interpolation/datasets/create_*_tfrecord.py
275 |
276 | The (batch_size, crop_size, max_examples) are specified for all eval datasets.
277 |
278 | Args:
279 | batch_size: The number of images to batch per example.
280 | files: List of paths to a sharded tfrecord in @N format.
281 | names: List of names of eval datasets.
282 | crop_size: If > 0, images are cropped to crop_size x crop_size using
283 | tensorflow's random cropping.
284 | max_examples: If > 0, truncate the dataset to 'max_examples' in length. This
285 | can be useful for speeding up evaluation loop in case the tfrecord for the
286 | evaluation set is very large.
287 | Returns:
288 | A dict of name to tensorflow dataset for accessing examples that contain the
289 | input images 'x0', 'x1', ground truth 'y' and time of the ground truth
290 | 'time'=[0,1] in a dictionary of tensors.
291 | """
292 | return {
293 | name: _create_from_sharded_tfrecord(batch_size, False, file, None,
294 | crop_size, max_examples)
295 | for name, file in zip(names, files)
296 | }
297 |
--------------------------------------------------------------------------------
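The training and eval dataset paths use a 'name@N' sharding convention that _generate_sharded_filenames expands into N explicit shard paths. A standalone sketch of that expansion (same logic, no repo imports; the filename and expand_sharded name are stand-ins):

def expand_sharded(filename: str):
    # 'base@N' -> ['base-00000-of-0000N', ..., 'base-0000(N-1)-of-0000N']
    base, count = filename.split('@')
    count = int(count)
    return ['{}-{:05d}-of-{:05d}'.format(base, i, count) for i in range(count)]

print(expand_sharded('vimeo_interp_train.tfrecord@3'))
# ['vimeo_interp_train.tfrecord-00000-of-00003',
#  'vimeo_interp_train.tfrecord-00001-of-00003',
#  'vimeo_interp_train.tfrecord-00002-of-00003']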
/training/eval_lib.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Evaluation library for frame interpolation."""
16 | from typing import Dict, Mapping, Text
17 |
18 | from absl import logging
19 | import tensorflow as tf
20 |
21 |
22 | def _collect_tensors(tensors: tf.Tensor) -> tf.Tensor:
23 | """Collect tensors of the different replicas into a list."""
24 | return tf.nest.flatten(tensors, expand_composites=True)
25 |
26 |
27 | @tf.function
28 | def _distributed_eval_step(strategy: tf.distribute.Strategy,
29 | batch: Dict[Text, tf.Tensor], model: tf.keras.Model,
30 | metrics: Dict[Text, tf.keras.metrics.Metric],
31 | checkpoint_step: int) -> Dict[Text, tf.Tensor]:
32 | """Distributed eval step.
33 |
34 | Args:
35 | strategy: A Tensorflow distribution strategy.
36 | batch: A batch of training examples.
37 | model: The Keras model to evaluate.
38 | metrics: The Keras metrics used for evaluation (a dictionary).
39 | checkpoint_step: The iteration number at which the checkpoint is restored.
40 |
41 | Returns:
42 | list of predictions from each replica.
43 | """
44 |
45 | def _eval_step(
46 | batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
47 | """Eval for one step."""
48 | predictions = model(batch, training=False)
49 | # Note: these metrics expect batch and prediction dictionaries rather than
50 | # tensors like standard TF metrics do. This allows our losses and metrics to
51 | # use a richer set of inputs than just the predicted final image.
52 | for metric in metrics.values():
53 | metric.update_state(batch, predictions, checkpoint_step=checkpoint_step)
54 | return predictions
55 |
56 | return strategy.run(_eval_step, args=(batch,))
57 |
58 |
59 | def _summarize_image_tensors(combined, prefix, step):
60 | for name in combined:
61 | image = combined[name]
62 | if isinstance(image, tf.Tensor):
63 | if len(image.shape) == 4 and (image.shape[-1] == 1 or
64 | image.shape[-1] == 3):
65 | tf.summary.image(prefix + '/' + name, image, step=step)
66 |
67 |
68 | def eval_loop(strategy: tf.distribute.Strategy,
69 | eval_base_folder: str,
70 | model: tf.keras.Model,
71 | metrics: Dict[str, tf.keras.metrics.Metric],
72 | datasets: Mapping[str, tf.data.Dataset],
73 | summary_writer: tf.summary.SummaryWriter,
74 | checkpoint_step: int):
75 | """Eval function that is strategy agnostic.
76 |
77 | Args:
78 | strategy: A Tensorflow distributed strategy.
79 | eval_base_folder: A path to where the summaries event files and
80 | checkpoints will be saved.
81 | model: The Keras model to evaluate.
82 | metrics: A dictionary of Keras metrics, keyed by name.
83 | datasets: A dict of tf.data.Dataset to evaluate on.
84 | summary_writer: Eval summary writer.
85 | checkpoint_step: The number of iterations completed.
86 | """
87 | logging.info('Saving eval summaries to: %s...', eval_base_folder)
88 | summary_writer.set_as_default()
89 |
90 | for dataset_name, dataset in datasets.items():
91 | for metric in metrics.values():
92 | metric.reset_states()
93 |
94 | logging.info('Loading %s testing data ...', dataset_name)
95 | dataset = strategy.experimental_distribute_dataset(dataset)
96 |
97 | logging.info('Evaluating %s ...', dataset_name)
98 | batch_idx = 0
99 | max_batches_to_summarize = 10
100 | for batch in dataset:
101 | predictions = _distributed_eval_step(strategy, batch, model, metrics,
102 | checkpoint_step)
103 | # Clip interpolator output to [0,1]. Clipping is done only
104 | # on the eval loop to get better metrics, but not on the training loop
105 | # so gradients are not killed.
106 | if strategy.num_replicas_in_sync > 1:
107 | predictions = {
108 | 'image': tf.concat(predictions['image'].values, axis=0)
109 | }
110 | predictions['image'] = tf.clip_by_value(predictions['image'], 0., 1.)
111 | if batch_idx % 10 == 0:
112 | logging.info('Evaluating batch %s', batch_idx)
113 | batch_idx = batch_idx + 1
114 | if batch_idx < max_batches_to_summarize:
115 | # Loop through the global batch:
116 | prefix = f'{dataset_name}/eval_{batch_idx}'
117 | # Find all tensors that look like images, and summarize:
118 | combined = {**batch, **predictions}
119 | _summarize_image_tensors(combined, prefix, step=checkpoint_step)
120 |
121 | elif batch_idx == max_batches_to_summarize:
122 | tf.summary.flush()
123 |
124 | for name, metric in metrics.items():
125 | tf.summary.scalar(
126 | f'{dataset_name}/{name}', metric.result(), step=checkpoint_step)
127 | tf.summary.flush()
128 | logging.info('Step {:2}, {} {}'.format(checkpoint_step,
129 | f'{dataset_name}/{name}',
130 | metric.result().numpy()))
131 | metric.reset_states()
132 |
--------------------------------------------------------------------------------
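_summarize_image_tensors only writes summaries for tensors that look like image batches: rank 4 with 1 or 3 channels, which keeps flow fields and scalars out of the image tab. A minimal sketch of that filter on stand-in tensors:

import tensorflow as tf

tensors = {
    'x0': tf.zeros((2, 64, 64, 3)),    # RGB batch: summarized
    'mask': tf.zeros((2, 64, 64, 1)),  # single-channel batch: summarized
    'flow': tf.zeros((2, 64, 64, 2)),  # two channels: skipped
    'time': tf.zeros((2, 1)),          # not rank 4: skipped
}
image_like = [name for name, t in tensors.items()
              if len(t.shape) == 4 and t.shape[-1] in (1, 3)]
print(image_like)  # ['x0', 'mask']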
/training/metrics_lib.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A library for instantiating frame interpolation evaluation metrics."""
16 |
17 | from typing import Callable, Dict, Text
18 |
19 | from ..losses import losses
20 | import tensorflow as tf
21 |
22 |
23 | class TrainLossMetric(tf.keras.metrics.Metric):
24 | """Compute training loss for our example and prediction format.
25 |
26 | The purpose of this is to ensure that the evaluation always includes a loss
27 | that is computed exactly like the training loss, in order to detect possible
28 | overfitting.
29 | """
30 |
31 | def __init__(self, name='eval_loss', **kwargs):
32 | super(TrainLossMetric, self).__init__(name=name, **kwargs)
33 | self.acc = self.add_weight(name='train_metric_acc', initializer='zeros')
34 | self.count = self.add_weight(name='train_metric_count', initializer='zeros')
35 |
36 | def update_state(self,
37 | batch,
38 | predictions,
39 | sample_weight=None,
40 | checkpoint_step=0):
41 | loss_functions = losses.training_losses()
42 | loss_list = []
43 | for (loss_value, loss_weight) in loss_functions.values():
44 | loss_list.append(
45 | loss_value(batch, predictions) * loss_weight(checkpoint_step))
46 | loss = tf.add_n(loss_list)
47 | self.acc.assign_add(loss)
48 | self.count.assign_add(1)
49 |
50 | def result(self):
51 | return self.acc / self.count
52 |
53 | def reset_states(self):
54 | self.acc.assign(0)
55 | self.count.assign(0)
56 |
57 |
58 | class L1Metric(tf.keras.metrics.Metric):
59 | """Compute L1 over our training example and prediction format.
60 |
61 | The purpose of this is to ensure that we have at least one metric that is
62 | compatible across all eval sessions and allows us to quickly compare models
63 | against each other.
64 | """
65 |
66 | def __init__(self, name='eval_loss', **kwargs):
67 | super(L1Metric, self).__init__(name=name, **kwargs)
68 | self.acc = self.add_weight(name='l1_metric_acc', initializer='zeros')
69 | self.count = self.add_weight(name='l1_metric_count', initializer='zeros')
70 |
71 | def update_state(self, batch, prediction, sample_weight=None,
72 | checkpoint_step=0):
73 | self.acc.assign_add(losses.l1_loss(batch, prediction))
74 | self.count.assign_add(1)
75 |
76 | def result(self):
77 | return self.acc / self.count
78 |
79 | def reset_states(self):
80 | self.acc.assign(0)
81 | self.count.assign(0)
82 |
83 |
84 | class GenericLossMetric(tf.keras.metrics.Metric):
85 | """Metric based on any loss function."""
86 |
87 | def __init__(self, name: str, loss: Callable[..., tf.Tensor],
88 | weight: Callable[..., tf.Tensor], **kwargs):
89 | """Initializes a metric based on a loss function and a weight schedule.
90 |
91 | Args:
92 | name: The name of the metric.
93 | loss: The callable loss that calculates a loss value for a (prediction,
94 | target) pair.
95 | weight: The callable weight scheduling function that samples a weight
96 | based on iteration.
97 | **kwargs: Any additional keyword arguments to be passed.
98 | """
99 | super(GenericLossMetric, self).__init__(name=name, **kwargs)
100 | self.acc = self.add_weight(name='loss_metric_acc', initializer='zeros')
101 | self.count = self.add_weight(name='loss_metric_count', initializer='zeros')
102 | self.loss = loss
103 | self.weight = weight
104 |
105 | def update_state(self,
106 | batch,
107 | predictions,
108 | sample_weight=None,
109 | checkpoint_step=0):
110 | self.acc.assign_add(
111 | self.loss(batch, predictions) * self.weight(checkpoint_step))
112 | self.count.assign_add(1)
113 |
114 | def result(self):
115 | return self.acc / self.count
116 |
117 | def reset_states(self):
118 | self.acc.assign(0)
119 | self.count.assign(0)
120 |
121 |
122 | def create_metrics_fn() -> Dict[Text, tf.keras.metrics.Metric]:
123 | """Create evaluation metrics.
124 |
125 | L1 and total training loss are added by default.
126 | The rest are the configured by the test_losses item via gin.
127 |
128 | Returns:
129 | A dictionary from metric name to Keras Metric object.
130 | """
131 | metrics = {}
132 | # L1 is explicitly added just so we always have some consistent numbers around
133 | # to compare across sessions.
134 | metrics['l1'] = L1Metric()
135 | # We also always include training loss for the eval set to detect overfitting:
136 | metrics['training_loss'] = TrainLossMetric()
137 |
138 | test_losses = losses.test_losses()
139 | for loss_name, (loss_value, loss_weight) in test_losses.items():
140 | metrics[loss_name] = GenericLossMetric(
141 | name=loss_name, loss=loss_value, weight=loss_weight)
142 | return metrics
143 |
--------------------------------------------------------------------------------
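All three metric classes follow the same accumulate-and-count pattern: each update adds one batch loss to an accumulator, and result() returns the running mean. A stripped-down sketch of that pattern with a plain scalar loss (no repo imports; MeanLoss is a stand-in name, not part of this library):

import tensorflow as tf

class MeanLoss(tf.keras.metrics.Metric):
    """Running mean of scalar loss values, mirroring the pattern above."""

    def __init__(self, name='mean_loss', **kwargs):
        super().__init__(name=name, **kwargs)
        self.acc = self.add_weight(name='acc', initializer='zeros')
        self.count = self.add_weight(name='count', initializer='zeros')

    def update_state(self, loss_value):
        self.acc.assign_add(loss_value)
        self.count.assign_add(1.0)

    def result(self):
        return self.acc / self.count

metric = MeanLoss()
for value in [0.2, 0.4, 0.6]:
    metric.update_state(tf.constant(value))
print(float(metric.result()))  # ~0.4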
/training/model_lib.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A library for instantiating the model for training frame interpolation.
16 |
17 | All models are expected to use three inputs: input image batches 'x0' and 'x1'
18 | and 'time', the fractional time where the output should be generated.
19 |
20 | The models are expected to output the prediction as a dictionary that contains
21 | at least the predicted image batch as 'image' plus optional data for debug,
22 | analysis or custom losses.
23 | """
24 |
25 | import gin.tf
26 | from ..models.film_net import interpolator as film_net_interpolator
27 | from ..models.film_net import options as film_net_options
28 |
29 | import tensorflow as tf
30 |
31 |
32 | @gin.configurable('model')
33 | def create_model(name: str) -> tf.keras.Model:
34 | """Creates the frame interpolation model based on given model name."""
35 | if name == 'film_net':
36 | return _create_film_net_model() # pylint: disable=no-value-for-parameter
37 | else:
38 | raise ValueError(f'Model {name} not implemented.')
39 |
40 |
41 | def _create_film_net_model() -> tf.keras.Model:
42 | """Creates the film_net interpolator."""
43 | # Options are gin-configured in the Options class directly.
44 | options = film_net_options.Options()
45 |
46 | x0 = tf.keras.Input(
47 | shape=(None, None, 3), batch_size=None, dtype=tf.float32, name='x0')
48 | x1 = tf.keras.Input(
49 | shape=(None, None, 3), batch_size=None, dtype=tf.float32, name='x1')
50 | time = tf.keras.Input(
51 | shape=(1,), batch_size=None, dtype=tf.float32, name='time')
52 |
53 | return film_net_interpolator.create_model(x0, x1, time, options)
54 |
--------------------------------------------------------------------------------
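The interpolator built here is a functional Keras model keyed by the input names 'x0', 'x1' and 'time', so it is called with a dictionary and returns a dictionary containing at least 'image'. A toy stand-in with the same interface (a simple linear blend instead of film_net, which needs its gin-configured options):

import tensorflow as tf

x0 = tf.keras.Input(shape=(None, None, 3), name='x0')
x1 = tf.keras.Input(shape=(None, None, 3), name='x1')
time = tf.keras.Input(shape=(1,), name='time')

def _blend(inputs):
    a, b, t = inputs
    t = tf.reshape(t, [-1, 1, 1, 1])  # broadcast time over HxWxC
    return a * (1.0 - t) + b * t

image = tf.keras.layers.Lambda(_blend)([x0, x1, time])
toy_model = tf.keras.Model(inputs={'x0': x0, 'x1': x1, 'time': time},
                           outputs={'image': image})

out = toy_model({'x0': tf.zeros((2, 8, 8, 3)),
                 'x1': tf.ones((2, 8, 8, 3)),
                 'time': tf.fill((2, 1), 0.5)})
print(out['image'].shape)  # (2, 8, 8, 3)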
/training/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""The training loop for frame interpolation.
16 |
17 | gin_config: The gin configuration file containing model, losses and datasets.
18 |
19 | To run on GPUs:
20 | python3 -m frame_interpolation.training.train \
21 | --gin_config <config filename> \
22 | --base_folder <base folder> \
23 | --label <descriptive label>
24 |
25 | To debug the training loop on CPU:
26 | python3 -m frame_interpolation.training.train \
27 | --gin_config <config filename> \
28 | --base_folder /tmp \
29 | --label test_run \
30 | --mode cpu
31 |
32 | The training output directory will be created at <base_folder>/<label>/.
33 | """
34 | import os
35 |
36 | from . import augmentation_lib
37 | from . import data_lib
38 | from . import eval_lib
39 | from . import metrics_lib
40 | from . import model_lib
41 | from . import train_lib
42 | from absl import app
43 | from absl import flags
44 | from absl import logging
45 | import gin.tf
46 | from ..losses import losses
47 |
48 | # Reduce tensorflow logs to ERRORs only.
49 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
50 | import tensorflow as tf # pylint: disable=g-import-not-at-top
51 | tf.get_logger().setLevel('ERROR')
52 |
53 |
54 | _GIN_CONFIG = flags.DEFINE_string('gin_config', None, 'Gin config file.')
55 | _LABEL = flags.DEFINE_string('label', 'run0',
56 | 'Descriptive label for this run.')
57 | _BASE_FOLDER = flags.DEFINE_string('base_folder', None,
58 | 'Path to checkpoints/summaries.')
59 | _MODE = flags.DEFINE_enum('mode', 'gpu', ['cpu', 'gpu'],
60 | 'Distributed strategy approach.')
61 |
62 |
63 | @gin.configurable('training')
64 | class TrainingOptions(object):
65 | """Training-related options."""
66 |
67 | def __init__(self, learning_rate: float, learning_rate_decay_steps: int,
68 | learning_rate_decay_rate: float, learning_rate_staircase: bool,
69 | num_steps: int):
70 | self.learning_rate = learning_rate
71 | self.learning_rate_decay_steps = learning_rate_decay_steps
72 | self.learning_rate_decay_rate = learning_rate_decay_rate
73 | self.learning_rate_staircase = learning_rate_staircase
74 | self.num_steps = num_steps
75 |
76 |
77 | def main(argv):
78 | if len(argv) > 1:
79 | raise app.UsageError('Too many command-line arguments.')
80 |
81 | output_dir = os.path.join(_BASE_FOLDER.value, _LABEL.value)
82 | logging.info('Creating output_dir @ %s ...', output_dir)
83 |
84 | # Copy config file to <base_folder>/<label>/config.gin.
85 | tf.io.gfile.makedirs(output_dir)
86 | tf.io.gfile.copy(
87 | _GIN_CONFIG.value, os.path.join(output_dir, 'config.gin'), overwrite=True)
88 |
89 | gin.external_configurable(
90 | tf.keras.optimizers.schedules.PiecewiseConstantDecay,
91 | module='tf.keras.optimizers.schedules')
92 |
93 | gin_configs = [_GIN_CONFIG.value]
94 | gin.parse_config_files_and_bindings(
95 | config_files=gin_configs, bindings=None, skip_unknown=True)
96 |
97 | training_options = TrainingOptions() # pylint: disable=no-value-for-parameter
98 |
99 | learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
100 | training_options.learning_rate,
101 | training_options.learning_rate_decay_steps,
102 | training_options.learning_rate_decay_rate,
103 | training_options.learning_rate_staircase,
104 | name='learning_rate')
105 |
106 | # Initialize data augmentation functions
107 | augmentation_fns = augmentation_lib.data_augmentations()
108 |
109 | saved_model_folder = os.path.join(_BASE_FOLDER.value, _LABEL.value,
110 | 'saved_model')
111 | train_folder = os.path.join(_BASE_FOLDER.value, _LABEL.value, 'train')
112 | eval_folder = os.path.join(_BASE_FOLDER.value, _LABEL.value, 'eval')
113 |
114 | train_lib.train(
115 | strategy=train_lib.get_strategy(_MODE.value),
116 | train_folder=train_folder,
117 | saved_model_folder=saved_model_folder,
118 | n_iterations=training_options.num_steps,
119 | create_model_fn=model_lib.create_model,
120 | create_losses_fn=losses.training_losses,
121 | create_metrics_fn=metrics_lib.create_metrics_fn,
122 | dataset=data_lib.create_training_dataset(
123 | augmentation_fns=augmentation_fns),
124 | learning_rate=learning_rate,
125 | eval_loop_fn=eval_lib.eval_loop,
126 | eval_folder=eval_folder,
127 | eval_datasets=data_lib.create_eval_datasets() or None)
128 |
129 |
130 | if __name__ == '__main__':
131 | app.run(main)
132 |
--------------------------------------------------------------------------------
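The learning-rate schedule assembled in main() uses the values from the film_net gin configs: start at 1e-4 and, with staircase decay, multiply by 0.464158 (roughly 10^(-1/3)) every 750k steps, so the rate falls by about 10x over each 2.25M steps. A short sketch of what the optimizer sees at a few steps:

import tensorflow as tf

learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-4,
    decay_steps=750000,
    decay_rate=0.464158,
    staircase=True)

for step in [0, 749999, 750000, 1500000, 3000000]:
    print(step, float(learning_rate(step)))
# 0 / 749999 -> 1.0e-4, 750000 -> ~4.6e-5, 1500000 -> ~2.2e-5, 3000000 -> ~4.6e-6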