├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── data.py ├── dataset_utils.py ├── evaluate_quantitative_metrics.py ├── imgs ├── app_interpolation.jpg ├── app_variation.jpg └── teaser_with_caption.jpg ├── layers.py ├── losses.py ├── networks.py ├── neural_rerendering.py ├── options.py ├── pretrain_appearance.py ├── segment_dataset.py ├── staged_model.py ├── style_loss.py └── utils.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Rerendering in the Wild 2 | Moustafa Meshry1, 3 | [Dan B Goldman](http://www.danbgoldman.com/)2, 4 | [Sameh Khamis](http://www.samehkhamis.com/)2, 5 | [Hugues Hoppe](http://hhoppe.com/)2, 6 | Rohit Pandey2, 7 | [Noah Snavely](http://www.cs.cornell.edu/~snavely/)2, 8 | [Ricardo Martin-Brualla](http://www.ricardomartinbrualla.com/)2. 
9 | 10 | 1University of Maryland, College Park      2Google Inc. 11 | 12 | To appear at CVPR 2019 (Oral).

13 | 14 | 15 |
16 | 17 |
18 | 19 | 20 | 21 | We will provide a TensorFlow implementation and pretrained models for our paper soon. 22 | 23 | [**Paper**](https://arxiv.org/abs/1904.04290) | [**Video**](https://www.youtube.com/watch?v=E1crWQn_kmY) | [**Code**](https://github.com/MoustafaMeshry/neural_rerendering_in_the_wild) | [**Project page**](https://moustafameshry.github.io/neural_rerendering_in_the_wild/) 24 | 25 | ### Abstract 26 | 27 | We explore total scene capture — recording, modeling, and rerendering a scene under varying appearance such as season and time of day. 28 | Starting from internet photos of a tourist landmark, we apply traditional 3D reconstruction to register the photos and approximate the scene as a point cloud. 29 | For each photo, we render the scene points into a deep framebuffer, 30 | and train a neural network to learn the mapping of these initial renderings to the actual photos. 31 | This rerendering network also takes as input a latent appearance vector and a semantic mask indicating the location of transient objects like pedestrians. 32 | The model is evaluated on several datasets of publicly available images spanning a broad range of illumination conditions. 33 | We create short videos demonstrating realistic manipulation of the image viewpoint, appearance, and semantic labeling. 34 | We also compare results with prior work on scene reconstruction from internet photos. 35 | 36 | ### Video 37 | [![Supplementary material video](https://img.youtube.com/vi/E1crWQn_kmY/0.jpg)](https://www.youtube.com/watch?v=E1crWQn_kmY) 38 | 39 | 40 | ### Appearance variation 41 | 42 | We capture the appearance of the original images in the left column and rerender several viewpoints under each of those appearances. The last column is a detail of the previous one. The top row shows the renderings that form part of the input to the rerenderer; they exhibit artifacts such as incomplete features in the statue and an inconsistent mix of day and night appearances. Note the twilight scene hallucinated in the sky under the last appearance. Image credits: Flickr users William Warby, Neil Rickards, Rafael Jimenez, acme401 (Creative Commons). 43 | 44 |
45 | 46 |
47 | 48 | ### Appearance interpolation 49 | Frames from a synthesized camera path that smoothly transitions from the photo on the left to the photo on the right by interpolating both the viewpoint and the latent appearance vectors. Please see the supplementary video. Photo credits: Allie Caulfield, Tahbepet, Till Westermayer, Elliott Brown (Creative Commons). 50 |
51 | 52 |
53 | 54 | ### Acknowledgements 55 | We thank Gregory Blascovich for his help in conducting the user study, and Johannes Schönberger and True Price for their help generating datasets. 56 | 57 | ### Run and train instructions 58 | 59 | Staged training consists of three stages: 60 | 61 | - Pretraining the appearance network. 62 | - Training the rendering network while fixing the weights for the appearance 63 | network. 64 | - Finetuning both the appearance and the rendering networks. 65 | 66 | ### Aligned dataset preprocessing 67 | 68 | #### Manual preparation 69 | 70 | * Set a path to a base_dir that contains the source code: 71 | 72 | ``` 73 | base_dir=/path/to/neural_rendering 74 | mkdir $base_dir 75 | cd $base_dir 76 | ``` 77 | 78 | * We assume the following format for an aligned dataset: 79 | * Each training example consists of 3 files with the following naming format: 80 | * real image: %04d_reference.png 81 | * render color: %04d_color.png 82 | * render depth: %04d_depth.png 83 | * Set the dataset name, e.g. 84 | ``` 85 | dataset_name='trevi3k' # set to any name 86 | ``` 87 | * Split the dataset into train and validation sets in two subdirectories: 88 | * $base_dir/datasets/$dataset_name/train 89 | * $base_dir/datasets/$dataset_name/val 90 | * Download the DeepLab semantic segmentation model trained on the ADE20K 91 | dataset from this link: 92 | http://download.tensorflow.org/models/deeplabv3_xception_ade20k_train_2018_05_29.tar.gz 93 | * Unzip the downloaded file to: $base_dir/deeplabv3_xception_ade20k_train 94 | * Download this [file](https://github.com/MoustafaMeshry/vgg_loss/blob/master/vgg16.py) for an implementation of a VGG-based perceptual loss. 95 | * Download trained weights for the VGG network as instructed in this link: https://github.com/machrisaa/tensorflow-vgg 96 | * Save the VGG weights to $base_dir/vgg16_weights/vgg16.npy 97 | 98 | 99 | #### Data preprocessing 100 | 101 | * Run the preprocessing pipeline, which consists of: 102 | * Filtering out sparse renders. 103 | * Semantic segmentation of ground truth images. 104 | * Exporting the dataset to TFRecord format. 105 | 106 | ``` 107 | # Run locally 108 | python dataset_utils.py \ 109 | --dataset_name=$dataset_name \ 110 | --dataset_parent_dir=$base_dir/datasets/$dataset_name \ 111 | --output_dir=$base_dir/datasets/$dataset_name \ 112 | --xception_frozen_graph_path=$base_dir/deeplabv3_xception_ade20k_train/frozen_inference_graph.pb \ 113 | --alsologtostderr 114 | ``` 115 | 116 | ### Pretraining the appearance encoder network 117 | 118 | ``` 119 | # Run locally 120 | python pretrain_appearance.py \ 121 | --dataset_name=$dataset_name \ 122 | --train_dir=$base_dir/train_models/$dataset_name-app_pretrain \ 123 | --imageset_dir=$base_dir/datasets/$dataset_name/train \ 124 | --train_resolution=512 \ 125 | --metadata_output_dir=$base_dir/datasets/$dataset_name 126 | ``` 127 | 128 | ### Training the rerendering network with a fixed appearance encoder 129 | 130 | Set the dataset_parent_dir variable below to point to the directory containing 131 | the generated TFRecords.
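(Optional) Before moving on, you can quickly verify that the preprocessing step produced the expected TFRecord shards. The snippet below is only an illustrative sketch; it assumes the `<dataset_name>_train.tfrecord` / `<dataset_name>_val.tfrecord` naming written by dataset_utils.py and the directory layout used in the commands above.

```
# Illustrative sanity check only (not part of the codebase).
import glob
import os.path as osp

base_dir = '/path/to/neural_rendering'  # same base_dir as above
dataset_name = 'trevi3k'                # same dataset_name as above
pattern = osp.join(base_dir, 'datasets', dataset_name, dataset_name + '_*.tfrecord')
shards = sorted(glob.glob(pattern))
assert shards, 'No TFRecords found; re-run the data preprocessing step.'
print('\n'.join(shards))  # expect the _train and _val shards
```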
132 | 133 | ``` 134 | # Run locally: 135 | dataset_parent_dir=$base_dir/datasets/$dataset_name 136 | train_dir=$base_dir/train_models/$dataset_name-staged-fixed_appearance 137 | load_pretrained_app_encoder=true 138 | appearance_pretrain_dir=$base_dir/train_models/$dataset_name-app_pretrain 139 | load_from_another_ckpt=false 140 | fixed_appearance_train_dir='' 141 | train_app_encoder=false 142 | 143 | python neural_rerendering.py \ 144 | --dataset_name=$dataset_name \ 145 | --dataset_parent_dir=$dataset_parent_dir \ 146 | --train_dir=$train_dir \ 147 | --load_pretrained_app_encoder=$load_pretrained_app_encoder \ 148 | --appearance_pretrain_dir=$appearance_pretrain_dir \ 149 | --train_app_encoder=$train_app_encoder \ 150 | --load_from_another_ckpt=$load_from_another_ckpt \ 151 | --fixed_appearance_train_dir=$fixed_appearance_train_dir \ 152 | --total_kimg=4000 153 | ``` 154 | 155 | ### Finetuning the rerendering network and the appearance encoder 156 | 157 | Set the fixed_appearance_train_dir to the train directory from the previous 158 | step. 159 | 160 | ``` 161 | # Run locally: 162 | dataset_parent_dir=$base_dir/datasets/$dataset_name 163 | train_dir=$base_dir/train_models/$dataset_name-staged-finetune_appearance 164 | load_pretrained_app_encoder=false 165 | appearance_pretrain_dir='' 166 | load_from_another_ckpt=true 167 | fixed_appearance_train_dir=$base_dir/train_models/$dataset_name-staged-fixed_appearance 168 | train_app_encoder=true 169 | 170 | python neural_rerendering.py \ 171 | --dataset_name=$dataset_name \ 172 | --dataset_parent_dir=$dataset_parent_dir \ 173 | --train_dir=$train_dir \ 174 | --load_pretrained_app_encoder=$load_pretrained_app_encoder \ 175 | --appearance_pretrain_dir=$appearance_pretrain_dir \ 176 | --train_app_encoder=$train_app_encoder \ 177 | --load_from_another_ckpt=$load_from_another_ckpt \ 178 | --fixed_appearance_train_dir=$fixed_appearance_train_dir \ 179 | --total_kimg=4000 180 | ``` 181 | 182 | 183 | ### Evaluate model on validation set 184 | 185 | ``` 186 | experiment_title=$dataset_name-staged-finetune_appearance 187 | local_train_dir=$base_dir/train_models/$experiment_title 188 | dataset_parent_dir=$base_dir/datasets/$dataset_name 189 | val_set_out_dir=$local_train_dir/val_set_output 190 | 191 | # Run the model on validation set 192 | echo "Evaluating the validation set" 193 | python neural_rerendering.py \ 194 | --train_dir=$local_train_dir \ 195 | --dataset_name=$dataset_name \ 196 | --dataset_parent_dir=$dataset_parent_dir \ 197 | --run_mode='eval_subset' \ 198 | --virtual_seq_name='val' \ 199 | --output_validation_dir=$val_set_out_dir \ 200 | --logtostderr 201 | # Evaluate quantitative metrics 202 | python evaluate_quantitative_metrics.py \ 203 | --val_set_out_dir=$val_set_out_dir \ 204 | --experiment_title=$experiment_title \ 205 | --logtostderr 206 | ``` 207 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from options import FLAGS as opts 16 | import functools 17 | import glob 18 | import numpy as np 19 | import os.path as osp 20 | import random 21 | import tensorflow as tf 22 | 23 | 24 | def provide_data(dataset_name='', parent_dir='', batch_size=8, subset=None, 25 | max_examples=None, crop_flag=False, crop_size=256, seeds=None, 26 | use_appearance=True, shuffle=128): 27 | # Parsing function for each tfrecord example. 28 | record_parse_fn = functools.partial( 29 | _parser_rendered_dataset, crop_flag=crop_flag, crop_size=crop_size, 30 | use_alpha=opts.use_alpha, use_depth=opts.use_depth, 31 | use_semantics=opts.use_semantic, seeds=seeds, 32 | use_appearance=use_appearance) 33 | 34 | input_dict_var = multi_input_fn_record( 35 | record_parse_fn, parent_dir, dataset_name, batch_size, 36 | subset=subset, max_examples=max_examples, shuffle=shuffle) 37 | return input_dict_var 38 | 39 | 40 | def _parser_rendered_dataset( 41 | serialized_example, crop_flag, crop_size, seeds, use_alpha, use_depth, 42 | use_semantics, use_appearance): 43 | """ 44 | Parses a single tf.Example into a features dictionary with input tensors. 45 | """ 46 | # Structure of features_dict need to match the dictionary structure that was 47 | # serialized to a tf.Example 48 | features_dict = {'height': tf.FixedLenFeature([], tf.int64), 49 | 'width': tf.FixedLenFeature([], tf.int64), 50 | 'rendered': tf.FixedLenFeature([], tf.string), 51 | 'depth': tf.FixedLenFeature([], tf.string), 52 | 'real': tf.FixedLenFeature([], tf.string), 53 | 'seg': tf.FixedLenFeature([], tf.string)} 54 | features = tf.parse_single_example(serialized_example, features=features_dict) 55 | height = tf.cast(features['height'], tf.int32) 56 | width = tf.cast(features['width'], tf.int32) 57 | 58 | # Parse the rendered image. 59 | rendered = tf.decode_raw(features['rendered'], tf.uint8) 60 | rendered = tf.cast(rendered, tf.float32) * (2.0 / 255) - 1.0 61 | rendered = tf.reshape(rendered, [height, width, 4]) 62 | if not use_alpha: 63 | rendered = tf.slice(rendered, [0, 0, 0], [height, width, 3]) 64 | conditional_input = rendered 65 | 66 | # Parse the depth image. 67 | if use_depth: 68 | depth = tf.decode_raw(features['depth'], tf.uint16) 69 | depth = tf.reshape(depth, [height, width, 1]) 70 | depth = tf.cast(depth, tf.float32) * (2.0 / 255) - 1.0 71 | conditional_input = tf.concat([conditional_input, depth], axis=-1) 72 | 73 | # Parse the semantic map. 74 | if use_semantics: 75 | seg_img = tf.decode_raw(features['seg'], tf.uint8) 76 | seg_img = tf.reshape(seg_img, [height, width, 3]) 77 | seg_img = tf.cast(seg_img, tf.float32) * (2.0 / 255) - 1 78 | conditional_input = tf.concat([conditional_input, seg_img], axis=-1) 79 | 80 | # Verify that the parsed input has the correct number of channels. 81 | assert conditional_input.shape[-1] == opts.deep_buffer_nc, ('num channels ' 82 | 'in the parsed input doesn\'t match num input channels specified in ' 83 | 'opts.deep_buffer_nc!') 84 | 85 | # Parse the ground truth image. 86 | real = tf.decode_raw(features['real'], tf.uint8) 87 | real = tf.cast(real, tf.float32) * (2.0 / 255) - 1.0 88 | real = tf.reshape(real, [height, width, 3]) 89 | 90 | # Parse the appearance image (if any). 91 | appearance_input = [] 92 | if use_appearance: 93 | # Concatenate the deep buffer to the real image. 
94 | appearance_input = tf.concat([real, conditional_input], axis=-1) 95 | # Verify that the parsed input has the correct number of channels. 96 | assert appearance_input.shape[-1] == opts.appearance_nc, ('num channels ' 97 | 'in the parsed appearance input doesn\'t match num input channels ' 98 | 'specified in opts.appearance_nc!') 99 | 100 | # Crop conditional_input and real images, but keep the appearance input 101 | # uncropped (learn a one-to-many mapping from appearance to output) 102 | if crop_flag: 103 | assert crop_size is not None, 'crop_size is not provided!' 104 | if isinstance(crop_size, int): 105 | crop_size = [crop_size, crop_size] 106 | assert len(crop_size) == 2, 'crop_size is either an int or a 2-tuple!' 107 | 108 | # Central crop 109 | if seeds is not None and len(seeds) <= 1: 110 | conditional_input = tf.image.resize_image_with_crop_or_pad( 111 | conditional_input, crop_size[0], crop_size[1]) 112 | real = tf.image.resize_image_with_crop_or_pad(real, crop_size[0], 113 | crop_size[1]) 114 | else: 115 | if not seeds: # random crops 116 | seed = random.randint(0, (1 << 31) - 1) 117 | else: # fixed crops 118 | seed_idx = random.randint(0, len(seeds) - 1) 119 | seed = seeds[seed_idx] 120 | conditional_input = tf.random_crop( 121 | conditional_input, crop_size + [opts.deep_buffer_nc], seed=seed) 122 | real = tf.random_crop(real, crop_size + [3], seed=seed) 123 | 124 | features = {'conditional_input': conditional_input, 125 | 'expected_output': real, 126 | 'peek_input': appearance_input} 127 | return features 128 | 129 | 130 | def multi_input_fn_record( 131 | record_parse_fn, parent_dir, tfrecord_basename, batch_size, subset=None, 132 | max_examples=None, shuffle=128): 133 | """Creates a Dataset pipeline for tfrecord files. 134 | 135 | Returns: 136 | Dataset iterator. 137 | """ 138 | subset_suffix = '*_%s.tfrecord' % subset if subset else '*.tfrecord' 139 | input_pattern = osp.join(parent_dir, tfrecord_basename + subset_suffix) 140 | filenames = sorted(glob.glob(input_pattern)) 141 | assert len(filenames) > 0, ('Error! input pattern "%s" didn\'t match any ' 142 | 'files' % input_pattern) 143 | dataset = tf.data.TFRecordDataset(filenames) 144 | if shuffle == 0: # keep input deterministic 145 | # use one thread to get deterministic results 146 | dataset = dataset.map(record_parse_fn, num_parallel_calls=None) 147 | else: 148 | dataset = dataset.repeat() # Repeat indefinitely. 149 | dataset = dataset.map(record_parse_fn, 150 | num_parallel_calls=max(4, batch_size // 4)) 151 | if opts.training_pipeline == 'drit': 152 | dataset1 = dataset.shuffle(shuffle) 153 | dataset2 = dataset.shuffle(shuffle) 154 | paired_dataset = tf.data.Dataset.zip((dataset1, dataset2)) 155 | 156 | def _join_paired_dataset(features_a, features_b): 157 | features_a['conditional_input_2'] = features_b['conditional_input'] 158 | features_a['expected_output_2'] = features_b['expected_output'] 159 | return features_a 160 | 161 | joined_dataset = paired_dataset.map(_join_paired_dataset) 162 | dataset = joined_dataset 163 | else: 164 | dataset = dataset.shuffle(shuffle) 165 | if max_examples is not None: 166 | dataset = dataset.take(max_examples) 167 | dataset = dataset.batch(batch_size) 168 | if shuffle > 0: # input is not deterministic 169 | dataset = dataset.prefetch(4) # Prefetch a few batches. 
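  # The value returned below is not the tf.data.Dataset itself, but the `get_next()`
  # op of a one-shot iterator: a dictionary of tensors (the features built in
  # `_parser_rendered_dataset`) that yields a new batch each time it is evaluated.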
170 | return dataset.make_one_shot_iterator().get_next() 171 | -------------------------------------------------------------------------------- /dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from PIL import Image 16 | from absl import app 17 | from absl import flags 18 | from options import FLAGS as opts 19 | import cv2 20 | import data 21 | import functools 22 | import glob 23 | import numpy as np 24 | import os 25 | import os.path as osp 26 | import shutil 27 | import six 28 | import tensorflow as tf 29 | import segment_dataset as segment_utils 30 | import utils 31 | 32 | FLAGS = flags.FLAGS 33 | flags.DEFINE_string('output_dir', None, 'Directory to save exported tfrecords.') 34 | flags.DEFINE_string('xception_frozen_graph_path', None, 35 | 'Path to the deeplab xception model frozen graph') 36 | 37 | 38 | class AlignedRenderedDataset(object): 39 | def __init__(self, rendered_filepattern, use_semantic_map=True): 40 | """ 41 | Args: 42 | rendered_filepattern: string, path filepattern to 3D rendered images ( 43 | assumes filenames are '/path/to/dataset/%d_color.png') 44 | use_semantic_map: bool, include semantic maps. in the TFRecord 45 | """ 46 | self.filenames = sorted(glob.glob(rendered_filepattern)) 47 | assert len(self.filenames) > 0, ('input %s didn\'t match any files!' % 48 | rendered_filepattern) 49 | self.iter_idx = 0 50 | self.use_semantic_map = use_semantic_map 51 | 52 | def __iter__(self): 53 | return self 54 | 55 | def __next__(self): 56 | return self.next() 57 | 58 | def next(self): 59 | if self.iter_idx < len(self.filenames): 60 | rendered_img_name = self.filenames[self.iter_idx] 61 | basename = rendered_img_name[:-9] # remove the 'color.png' suffix 62 | ref_img_name = basename + 'reference.png' 63 | depth_img_name = basename + 'depth.png' 64 | # Read the 3D rendered image 65 | img_rendered = cv2.imread(rendered_img_name, cv2.IMREAD_UNCHANGED) 66 | # Change BGR (default cv2 format) to RGB 67 | img_rendered = img_rendered[:, :, [2,1,0,3]] # it has a 4th alpha channel 68 | # Read the depth image 69 | img_depth = cv2.imread(depth_img_name, cv2.IMREAD_UNCHANGED) 70 | # Workaround as some depth images are read with a different data type! 71 | img_depth = img_depth.astype(np.uint16) 72 | # Read reference image if exists, otherwise replace with a zero image. 73 | if osp.exists(ref_img_name): 74 | img_ref = cv2.imread(ref_img_name) 75 | img_ref = img_ref[:, :, ::-1] # Change BGR to RGB format. 76 | else: # use a dummy 3-channel zero image as a placeholder 77 | print('Warning: no reference image found! 
Using a dummy placeholder!') 78 | img_height, img_width = img_depth.shape 79 | img_ref = np.zeros((img_height, img_width, 3), dtype=np.uint8) 80 | 81 | if self.use_semantic_map: 82 | semantic_seg_img_name = basename + 'seg_rgb.png' 83 | img_seg = cv2.imread(semantic_seg_img_name) 84 | img_seg = img_seg[:, :, ::-1] # Change from BGR to RGB 85 | if img_seg.shape[0] == 512 and img_seg.shape[1] == 512: 86 | img_ref = utils.get_central_crop(img_ref) 87 | img_rendered = utils.get_central_crop(img_rendered) 88 | img_depth = utils.get_central_crop(img_depth) 89 | 90 | img_shape = img_depth.shape 91 | assert img_seg.shape == (img_shape + (3,)), 'error in seg image %s %s' % ( 92 | basename, str(img_seg.shape)) 93 | assert img_ref.shape == (img_shape + (3,)), 'error in ref image %s %s' % ( 94 | basename, str(img_ref.shape)) 95 | assert img_rendered.shape == (img_shape + (4,)), ('error in rendered ' 96 | 'image %s %s' % (basename, str(img_rendered.shape))) 97 | assert len(img_depth.shape) == 2, 'error in depth image %s %s' % ( 98 | basename, str(img_depth.shape)) 99 | 100 | raw_example = dict() 101 | raw_example['height'] = img_ref.shape[0] 102 | raw_example['width'] = img_ref.shape[1] 103 | raw_example['rendered'] = img_rendered.tostring() 104 | raw_example['depth'] = img_depth.tostring() 105 | raw_example['real'] = img_ref.tostring() 106 | if self.use_semantic_map: 107 | raw_example['seg'] = img_seg.tostring() 108 | self.iter_idx += 1 109 | return raw_example 110 | else: 111 | raise StopIteration() 112 | 113 | 114 | def filter_out_sparse_renders(dataset_dir, splits, ratio_threshold=0.15): 115 | print('Filtering %s' % dataset_dir) 116 | if splits is None: 117 | imgs_dirs = [dataset_dir] 118 | else: 119 | imgs_dirs = [osp.join(dataset_dir, split) for split in splits] 120 | 121 | filtered_images = [] 122 | total_images = 0 123 | sum_density = 0 124 | for cur_dir in imgs_dirs: 125 | filtered_dir = osp.join(cur_dir, 'sparse_renders') 126 | if not osp.exists(filtered_dir): 127 | os.makedirs(filtered_dir) 128 | imgs_file_pattern = osp.join(cur_dir, '*_color.png') 129 | images_path = sorted(glob.glob(imgs_file_pattern)) 130 | print('Processing %d files' % len(images_path)) 131 | total_images += len(images_path) 132 | for ii, img_path in enumerate(images_path): 133 | img = np.array(Image.open(img_path)) 134 | aggregate = np.squeeze(np.sum(img, axis=2)) 135 | height, width = aggregate.shape 136 | mask = aggregate > 0 137 | density = np.sum(mask) * 1. 
/ (height * width) 138 | sum_density += density 139 | if density <= ratio_threshold: 140 | parent, basename = osp.split(img_path) 141 | basename = basename[:-10] # remove the '_color.png' suffix 142 | srcs = sorted(glob.glob(osp.join(parent, basename + '_*'))) 143 | dest = filtered_dir + '/.' 144 | for src in srcs: 145 | shutil.move(src, dest) 146 | filtered_images.append(basename) 147 | print('filtered file %d: %s with a density of %.3f' % (ii, basename, 148 | density)) 149 | print('Filtered %d/%d images' % (len(filtered_images), total_images)) 150 | print('Mean density = %.4f' % (sum_density / total_images)) 151 | 152 | 153 | def _to_example(dictionary): 154 | """Helper: build tf.Example from (string -> int/float/str list) dictionary.""" 155 | features = {} 156 | for (k, v) in six.iteritems(dictionary): 157 | if isinstance(v, six.integer_types): 158 | features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=[v])) 159 | elif isinstance(v, float): 160 | features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=[v])) 161 | elif isinstance(v, six.string_types): 162 | features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[v])) 163 | elif isinstance(v, bytes): 164 | features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[v])) 165 | else: 166 | raise ValueError("Value for %s is not a recognized type; v: %s type: %s" % 167 | (k, str(v[0]), str(type(v[0])))) 168 | 169 | return tf.train.Example(features=tf.train.Features(feature=features)) 170 | 171 | 172 | def _generate_tfrecord_dataset(generator, 173 | output_name, 174 | output_dir): 175 | """Convert a dataset into TFRecord format.""" 176 | output_filename = os.path.join(output_dir, output_name) 177 | output_file = os.path.join(output_dir, output_filename) 178 | tf.logging.info("Writing TFRecords to file %s", output_file) 179 | writer = tf.python_io.TFRecordWriter(output_file) 180 | 181 | counter = 0 182 | for case in generator: 183 | if counter % 100 == 0: 184 | print('Generating case %d for %s.'
% (counter, output_name)) 185 | counter += 1 186 | example = _to_example(case) 187 | writer.write(example.SerializeToString()) 188 | 189 | writer.close() 190 | return output_file 191 | 192 | 193 | def export_aligned_dataset_to_tfrecord( 194 | dataset_dir, output_dir, output_basename, splits, 195 | xception_frozen_graph_path): 196 | 197 | # Step 1: filter out sparse renders 198 | filter_out_sparse_renders(dataset_dir, splits, 0.15) 199 | 200 | # Step 2: generate semantic segmentation masks 201 | segment_utils.segment_and_color_dataset( 202 | dataset_dir, xception_frozen_graph_path, splits) 203 | 204 | # Step 3: export dataset to TFRecord 205 | if splits is None: 206 | input_filepattern = osp.join(dataset_dir, '*_color.png') 207 | dataset_iter = AlignedRenderedDataset(input_filepattern) 208 | output_name = output_basename + '.tfrecord' 209 | _generate_tfrecord_dataset(dataset_iter, output_name, output_dir) 210 | else: 211 | for split in splits: 212 | input_filepattern = osp.join(dataset_dir, split, '*_color.png') 213 | dataset_iter = AlignedRenderedDataset(input_filepattern) 214 | output_name = '%s_%s.tfrecord' % (output_basename, split) 215 | _generate_tfrecord_dataset(dataset_iter, output_name, output_dir) 216 | 217 | 218 | def main(argv): 219 | # Read input flags 220 | dataset_name = opts.dataset_name 221 | dataset_parent_dir = opts.dataset_parent_dir 222 | output_dir = FLAGS.output_dir 223 | xception_frozen_graph_path = FLAGS.xception_frozen_graph_path 224 | splits = ['train', 'val'] 225 | # Run the preprocessing pipeline 226 | export_aligned_dataset_to_tfrecord( 227 | dataset_parent_dir, output_dir, dataset_name, splits, 228 | xception_frozen_graph_path) 229 | 230 | 231 | if __name__ == '__main__': 232 | app.run(main) 233 | -------------------------------------------------------------------------------- /evaluate_quantitative_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from PIL import Image 16 | from absl import app 17 | from absl import flags 18 | import functools 19 | import glob 20 | import numpy as np 21 | import os 22 | import os.path as osp 23 | import skimage.measure 24 | import tensorflow as tf 25 | import utils 26 | 27 | FLAGS = flags.FLAGS 28 | flags.DEFINE_string('val_set_out_dir', None, 29 | 'Output directory with concatenated fake and real images.') 30 | flags.DEFINE_string('experiment_title', 'experiment', 31 | 'Name for the experiment to evaluate') 32 | 33 | 34 | def _extract_real_and_fake_from_concatenated_output(val_set_out_dir): 35 | out_dir = osp.join(val_set_out_dir, 'fake') 36 | gt_dir = osp.join(val_set_out_dir, 'real') 37 | if not osp.exists(out_dir): 38 | os.makedirs(out_dir) 39 | if not osp.exists(gt_dir): 40 | os.makedirs(gt_dir) 41 | imgs_pattern = osp.join(val_set_out_dir, '*.png') 42 | imgs_paths = sorted(glob.glob(imgs_pattern)) 43 | print('Separating %d images in %s.' 
% (len(imgs_paths), val_set_out_dir)) 44 | for img_path in imgs_paths: 45 | basename = osp.basename(img_path)[:-4] # remove the '.png' suffix 46 | img = np.array(Image.open(img_path)) 47 | img_res = 512 48 | fake = img[:, -2*img_res:-img_res, :] 49 | real = img[:, -img_res:, :] 50 | fake_path = osp.join(out_dir, '%s_fake.png' % basename) 51 | real_path = osp.join(gt_dir, '%s_real.png' % basename) 52 | Image.fromarray(fake).save(fake_path) 53 | Image.fromarray(real).save(real_path) 54 | 55 | 56 | def compute_l1_loss_metric(image_set1_paths, image_set2_paths): 57 | assert len(image_set1_paths) == len(image_set2_paths) 58 | assert len(image_set1_paths) > 0 59 | print('Evaluating L1 loss for %d pairs' % len(image_set1_paths)) 60 | 61 | total_loss = 0. 62 | for ii, (img1_path, img2_path) in enumerate(zip(image_set1_paths, 63 | image_set2_paths)): 64 | img1_in_ar = np.array(Image.open(img1_path), dtype=np.float32) 65 | img1_in_ar = utils.crop_to_multiple(img1_in_ar) 66 | 67 | img2_in_ar = np.array(Image.open(img2_path), dtype=np.float32) 68 | img2_in_ar = utils.crop_to_multiple(img2_in_ar) 69 | 70 | loss_l1 = np.mean(np.abs(img1_in_ar - img2_in_ar)) 71 | total_loss += loss_l1 72 | 73 | return total_loss / len(image_set1_paths) 74 | 75 | 76 | def compute_psnr_loss_metric(image_set1_paths, image_set2_paths): 77 | assert len(image_set1_paths) == len(image_set2_paths) 78 | assert len(image_set1_paths) > 0 79 | print('Evaluating PSNR loss for %d pairs' % len(image_set1_paths)) 80 | 81 | total_loss = 0. 82 | for ii, (img1_path, img2_path) in enumerate(zip(image_set1_paths, 83 | image_set2_paths)): 84 | img1_in_ar = np.array(Image.open(img1_path)) 85 | img1_in_ar = utils.crop_to_multiple(img1_in_ar) 86 | 87 | img2_in_ar = np.array(Image.open(img2_path)) 88 | img2_in_ar = utils.crop_to_multiple(img2_in_ar) 89 | 90 | loss_psnr = skimage.measure.compare_psnr(img1_in_ar, img2_in_ar) 91 | total_loss += loss_psnr 92 | 93 | return total_loss / len(image_set1_paths) 94 | 95 | 96 | def evaluate_experiment(val_set_out_dir, title='experiment', 97 | metrics=['psnr', 'l1']): 98 | 99 | out_dir = osp.join(val_set_out_dir, 'fake') 100 | gt_dir = osp.join(val_set_out_dir, 'real') 101 | _extract_real_and_fake_from_concatenated_output(val_set_out_dir) 102 | input_pattern1 = osp.join(gt_dir, '*.png') 103 | input_pattern2 = osp.join(out_dir, '*.png') 104 | set1 = sorted(glob.glob(input_pattern1)) 105 | set2 = sorted(glob.glob(input_pattern2)) 106 | for metric in metrics: 107 | if metric == 'l1': 108 | mean_loss = compute_l1_loss_metric(set1, set2) 109 | elif metric == 'psnr': 110 | mean_loss = compute_psnr_loss_metric(set1, set2) 111 | print('*** mean %s loss for %s = %f' % (metric, title, mean_loss)) 112 | 113 | 114 | def main(argv): 115 | evaluate_experiment(FLAGS.val_set_out_dir, title=FLAGS.experiment_title, 116 | metrics=['psnr', 'l1']) 117 | 118 | 119 | if __name__ == '__main__': 120 | app.run(main) 121 | -------------------------------------------------------------------------------- /imgs/app_interpolation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/neural_rerendering_in_the_wild/5f5226f0083559adb0a21b567104dfc075b6f6e5/imgs/app_interpolation.jpg -------------------------------------------------------------------------------- /imgs/app_variation.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google/neural_rerendering_in_the_wild/5f5226f0083559adb0a21b567104dfc075b6f6e5/imgs/app_variation.jpg -------------------------------------------------------------------------------- /imgs/teaser_with_caption.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/neural_rerendering_in_the_wild/5f5226f0083559adb0a21b567104dfc075b6f6e5/imgs/teaser_with_caption.jpg -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import functools 16 | import numpy as np 17 | import tensorflow as tf 18 | 19 | 20 | class LayerInstanceNorm(object): 21 | 22 | def __init__(self, scope_suffix='instance_norm'): 23 | curr_scope = tf.get_variable_scope().name 24 | self._scope = curr_scope + '/' + scope_suffix 25 | 26 | def __call__(self, x): 27 | with tf.variable_scope(self._scope, reuse=tf.AUTO_REUSE): 28 | return tf.contrib.layers.instance_norm( 29 | x, epsilon=1e-05, center=True, scale=True) 30 | 31 | 32 | def layer_norm(x, scope='layer_norm'): 33 | return tf.contrib.layers.layer_norm(x, center=True, scale=True) 34 | 35 | 36 | def pixel_norm(x): 37 | """Pixel normalization. 38 | 39 | Args: 40 | x: 4D image tensor in B01C format. 41 | 42 | Returns: 43 | 4D tensor with pixel normalized channels. 44 | """ 45 | return x * tf.rsqrt(tf.reduce_mean(tf.square(x), [-1], keepdims=True) + 1e-8) 46 | 47 | 48 | def global_avg_pooling(x): 49 | return tf.reduce_mean(x, axis=[1, 2], keepdims=True) 50 | 51 | 52 | class FullyConnected(object): 53 | 54 | def __init__(self, n_out_units, scope_suffix='FC'): 55 | weight_init = tf.random_normal_initializer(mean=0., stddev=0.02) 56 | weight_regularizer = tf.contrib.layers.l2_regularizer(scale=0.0001) 57 | 58 | curr_scope = tf.get_variable_scope().name 59 | self._scope = curr_scope + '/' + scope_suffix 60 | self.fc_layer = functools.partial( 61 | tf.layers.dense, units=n_out_units, kernel_initializer=weight_init, 62 | kernel_regularizer=weight_regularizer, use_bias=True) 63 | 64 | def __call__(self, x): 65 | with tf.variable_scope(self._scope, reuse=tf.AUTO_REUSE): 66 | return self.fc_layer(x) 67 | 68 | 69 | def init_he_scale(shape, slope=1.0): 70 | """He neural network random normal scaling for initialization. 71 | 72 | Args: 73 | shape: list of the dimensions of the tensor. 74 | slope: float, slope of the ReLu following the layer. 75 | 76 | Returns: 77 | a float, He's standard deviation. 78 | """ 79 | fan_in = np.prod(shape[:-1]) 80 | return np.sqrt(2. / ((1. 
+ slope**2) * fan_in)) 81 | 82 | 83 | class LayerConv(object): 84 | """Convolution layer with support for equalized learning.""" 85 | 86 | def __init__(self, 87 | name, 88 | w, 89 | n, 90 | stride, 91 | padding='SAME', 92 | use_scaling=False, 93 | relu_slope=1.): 94 | """Layer constructor. 95 | 96 | Args: 97 | name: string, layer name. 98 | w: int or 2-tuple, width of the convolution kernel. 99 | n: 2-tuple of ints, input and output channel depths. 100 | stride: int or 2-tuple, stride for the convolution kernel. 101 | padding: string, the padding method {SAME, VALID, REFLECT}. 102 | use_scaling: bool, whether to use weight norm and scaling. 103 | relu_slope: float, the slope of the ReLu following the layer. 104 | """ 105 | assert padding in ['SAME', 'VALID', 'REFLECT'], 'Error: unsupported padding' 106 | self._padding = padding 107 | with tf.variable_scope(name): 108 | if isinstance(stride, int): 109 | stride = [1, stride, stride, 1] 110 | else: 111 | assert len(stride) == 2, "stride is either an int or a 2-tuple" 112 | stride = [1, stride[0], stride[1], 1] 113 | if isinstance(w, int): 114 | w = [w, w] 115 | self.w = w 116 | shape = [w[0], w[1], n[0], n[1]] 117 | init_scale, pre_scale = init_he_scale(shape, relu_slope), 1. 118 | if use_scaling: 119 | init_scale, pre_scale = pre_scale, init_scale 120 | self._stride = stride 121 | self._pre_scale = pre_scale 122 | self._weight = tf.get_variable( 123 | 'weight', 124 | shape=shape, 125 | initializer=tf.random_normal_initializer(stddev=init_scale)) 126 | self._bias = tf.get_variable( 127 | 'bias', shape=[n[1]], initializer=tf.zeros_initializer) 128 | 129 | def __call__(self, x): 130 | """Apply layer to tensor x.""" 131 | if self._padding != 'REFLECT': 132 | padding = self._padding 133 | else: 134 | padding = 'VALID' 135 | pad_top = self.w[0] // 2 136 | pad_left = self.w[1] // 2 137 | if (self.w[0] - self._stride[1]) % 2 == 0: 138 | pad_bottom = pad_top 139 | else: 140 | pad_bottom = self.w[0] - self._stride[1] - pad_top 141 | if (self.w[1] - self._stride[2]) % 2 == 0: 142 | pad_right = pad_left 143 | else: 144 | pad_right = self.w[1] - self._stride[2] - pad_left 145 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], 146 | [0, 0]], mode='REFLECT') 147 | y = tf.nn.conv2d(x, self._weight, strides=self._stride, padding=padding) 148 | return self._pre_scale * y + self._bias 149 | 150 | 151 | class LayerTransposedConv(object): 152 | """Transposed convolution layer with support for equalized learning.""" 153 | 154 | def __init__(self, 155 | name, 156 | w, 157 | n, 158 | stride, 159 | padding='SAME', 160 | use_scaling=False, 161 | relu_slope=1.): 162 | """Layer constructor. 163 | 164 | Args: 165 | name: string, layer name. 166 | w: int or 2-tuple, width of the convolution kernel. 167 | n: 2-tuple int, [n_in_channels, n_out_channels] 168 | stride: int or 2-tuple, stride for the convolution kernel. 169 | padding: string, the padding method {SAME, VALID, REFLECT}. 170 | use_scaling: bool, whether to use weight norm and scaling. 171 | relu_slope: float, the slope of the ReLu following the layer.
172 | """ 173 | assert padding in ['SAME'], 'Error: unsupported padding for transposed conv' 174 | if isinstance(stride, int): 175 | stride = [1, stride, stride, 1] 176 | else: 177 | assert len(stride) == 2, "stride is either an int or a 2-tuple" 178 | stride = [1, stride[0], stride[1], 1] 179 | if isinstance(w, int): 180 | w = [w, w] 181 | self.padding = padding 182 | self.nc_in, self.nc_out = n 183 | self.stride = stride 184 | with tf.variable_scope(name): 185 | kernel_shape = [w[0], w[1], self.nc_out, self.nc_in] 186 | init_scale, pre_scale = init_he_scale(kernel_shape, relu_slope), 1. 187 | if use_scaling: 188 | init_scale, pre_scale = pre_scale, init_scale 189 | self._pre_scale = pre_scale 190 | self._weight = tf.get_variable( 191 | 'weight', 192 | shape=kernel_shape, 193 | initializer=tf.random_normal_initializer(stddev=init_scale)) 194 | self._bias = tf.get_variable( 195 | 'bias', shape=[self.nc_out], initializer=tf.zeros_initializer) 196 | 197 | def __call__(self, x): 198 | """Apply layer to tensor x.""" 199 | x_shape = x.get_shape().as_list() 200 | batch_size = tf.shape(x)[0] 201 | stride_x, stride_y = self.stride[1], self.stride[2] 202 | output_shape = tf.stack([ 203 | batch_size, x_shape[1] * stride_x, x_shape[2] * stride_y, self.nc_out]) 204 | y = tf.nn.conv2d_transpose( 205 | x, filter=self._weight, output_shape=output_shape, strides=self.stride, 206 | padding=self.padding) 207 | return self._pre_scale * y + self._bias 208 | 209 | 210 | class ResBlock(object): 211 | def __init__(self, 212 | name, 213 | nc, 214 | norm_layer_constructor, 215 | activation, 216 | padding='SAME', 217 | use_scaling=False, 218 | relu_slope=1.): 219 | """Layer constructor.""" 220 | self.name = name 221 | conv2d = functools.partial( 222 | LayerConv, w=3, n=[nc, nc], stride=1, padding=padding, 223 | use_scaling=use_scaling, relu_slope=relu_slope) 224 | self.blocks = [] 225 | with tf.variable_scope(self.name): 226 | with tf.variable_scope('res0'): 227 | self.blocks.append( 228 | LayerPipe([ 229 | conv2d('res0_conv'), 230 | norm_layer_constructor('res0_norm'), 231 | activation 232 | ]) 233 | ) 234 | with tf.variable_scope('res1'): 235 | self.blocks.append( 236 | LayerPipe([ 237 | conv2d('res1_conv'), 238 | norm_layer_constructor('res1_norm') 239 | ]) 240 | ) 241 | 242 | def __call__(self, x_init): 243 | """Apply layer to tensor x.""" 244 | x = x_init 245 | for f in self.blocks: 246 | x = f(x) 247 | return x + x_init 248 | 249 | 250 | class BasicBlock(object): 251 | def __init__(self, 252 | name, 253 | n, 254 | activation=functools.partial(tf.nn.leaky_relu, alpha=0.2), 255 | padding='SAME', 256 | use_scaling=True, 257 | relu_slope=1.): 258 | """Layer constructor.""" 259 | self.name = name 260 | conv2d = functools.partial( 261 | LayerConv, stride=1, padding=padding, 262 | use_scaling=use_scaling, relu_slope=relu_slope) 263 | avg_pool = functools.partial(downscale, n=2) 264 | nc_in, nc_out = n # n is a 2-tuple 265 | with tf.variable_scope(self.name): 266 | self.path1_blocks = [] 267 | with tf.variable_scope('bb_path1'): 268 | self.path1_blocks.append( 269 | LayerPipe([ 270 | activation, 271 | conv2d('bb_conv0', w=3, n=[nc_in, nc_out]), 272 | activation, 273 | conv2d('bb_conv1', w=3, n=[nc_out, nc_out]), 274 | downscale 275 | ]) 276 | ) 277 | 278 | self.path2_blocks = [] 279 | with tf.variable_scope('bb_path2'): 280 | self.path2_blocks.append( 281 | LayerPipe([ 282 | downscale, 283 | conv2d('path2_conv', w=1, n=[nc_in, nc_out]) 284 | ]) 285 | ) 286 | 287 | def __call__(self, x_init): 288 | """Apply layer to 
tensor x.""" 289 | x1 = x_init 290 | x2 = x_init 291 | for f in self.path1_blocks: 292 | x1 = f(x1) 293 | for f in self.path2_blocks: 294 | x2 = f(x2) 295 | return x1 + x2 296 | 297 | 298 | class LayerDense(object): 299 | """Dense layer with a non-linearity.""" 300 | 301 | def __init__(self, name, n, use_scaling=False, relu_slope=1.): 302 | """Layer constructor. 303 | 304 | Args: 305 | name: string, layer name. 306 | n: 2-tuple of ints, input and output widths. 307 | use_scaling: bool, whether to use weight norm and scaling. 308 | relu_slope: float, the slope of the ReLu following the layer. 309 | """ 310 | with tf.variable_scope(name): 311 | init_scale, pre_scale = init_he_scale(n, relu_slope), 1. 312 | if use_scaling: 313 | init_scale, pre_scale = pre_scale, init_scale 314 | self._pre_scale = pre_scale 315 | self._weight = tf.get_variable( 316 | 'weight', 317 | shape=n, 318 | initializer=tf.random_normal_initializer(stddev=init_scale)) 319 | self._bias = tf.get_variable( 320 | 'bias', shape=[n[1]], initializer=tf.zeros_initializer) 321 | 322 | def __call__(self, x): 323 | """Apply layer to tensor x.""" 324 | return self._pre_scale * tf.matmul(x, self._weight) + self._bias 325 | 326 | 327 | class LayerPipe(object): 328 | """Pipe a sequence of functions.""" 329 | 330 | def __init__(self, functions): 331 | """Layer constructor. 332 | 333 | Args: 334 | functions: list, functions to pipe. 335 | """ 336 | self._functions = tuple(functions) 337 | 338 | def __call__(self, x, **kwargs): 339 | """Apply pipe to tensor x and return result.""" 340 | del kwargs 341 | for f in self._functions: 342 | x = f(x) 343 | return x 344 | 345 | 346 | def downscale(x, n=2): 347 | """Box downscaling. 348 | 349 | Args: 350 | x: 4D image tensor. 351 | n: integer scale (must be a power of 2). 352 | 353 | Returns: 354 | 4D tensor of images down scaled by a factor n. 355 | """ 356 | if n == 1: 357 | return x 358 | return tf.nn.avg_pool(x, [1, n, n, 1], [1, n, n, 1], 'VALID') 359 | 360 | 361 | def upscale(x, n): 362 | """Box upscaling (also called nearest neighbors). 363 | 364 | Args: 365 | x: 4D image tensor in B01C format. 366 | n: integer scale (must be a power of 2). 367 | 368 | Returns: 369 | 4D tensor of images up scaled by a factor n. 370 | """ 371 | if n == 1: 372 | return x 373 | x_shape = tf.shape(x) 374 | height, width = x_shape[1], x_shape[2] 375 | return tf.image.resize_nearest_neighbor(x, [n * height, n * width]) 376 | 377 | 378 | def tile_and_concatenate(x, z, n_z): 379 | z = tf.reshape(z, shape=[-1, 1, 1, n_z]) 380 | z = tf.tile(z, [1, tf.shape(x)[1], tf.shape(x)[2], 1]) 381 | x = tf.concat([x, z], axis=-1) 382 | return x 383 | 384 | 385 | def minibatch_mean_variance(x): 386 | """Computes the variance average. 387 | 388 | This is used by the discriminator as a form of batch discrimination. 389 | 390 | Args: 391 | x: nD tensor for which to compute variance average. 392 | 393 | Returns: 394 | a scalar, the mean variance of variable x. 395 | """ 396 | mean = tf.reduce_mean(x, 0, keepdims=True) 397 | vals = tf.sqrt(tf.reduce_mean(tf.squared_difference(x, mean), 0) + 1e-8) 398 | vals = tf.reduce_mean(vals) 399 | return vals 400 | 401 | 402 | def scalar_concat(x, scalar): 403 | """Concatenate a scalar to a 4D tensor as an extra channel. 404 | 405 | Args: 406 | x: 4D image tensor in B01C format. 407 | scalar: a scalar to concatenate to the tensor. 408 | 409 | Returns: 410 | a 4D tensor with one extra channel containing the value scalar at 411 | every position. 
412 | """ 413 | s = tf.shape(x) 414 | return tf.concat([x, tf.ones([s[0], s[1], s[2], 1]) * scalar], axis=3) 415 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from options import FLAGS as opts 16 | import layers 17 | import os.path as osp 18 | import tensorflow as tf 19 | import vgg16 20 | 21 | 22 | def gradient_penalty_loss(y_xy, xy, iwass_target=1, iwass_lambda=10): 23 | grad = tf.gradients(tf.reduce_sum(y_xy), [xy])[0] 24 | grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]) + 1e-8) 25 | loss_gp = tf.reduce_mean( 26 | tf.square(grad_norm - iwass_target)) * iwass_lambda / iwass_target**2 27 | return loss_gp 28 | 29 | 30 | def KL_loss(mean, logvar): 31 | loss = 0.5 * tf.reduce_sum(tf.square(mean) + tf.exp(logvar) - 1. - logvar, 32 | axis=-1) 33 | return tf.reduce_sum(loss) # just to match DRIT implementation 34 | 35 | 36 | def l2_regularize(x): 37 | return tf.reduce_mean(tf.square(x)) 38 | 39 | 40 | def L1_loss(x, y): 41 | return tf.reduce_mean(tf.abs(x - y)) 42 | 43 | 44 | class PerceptualLoss: 45 | def __init__(self, x, y, image_shape, layers, w_layers, w_act=0.1): 46 | """ 47 | Builds vgg16 network and computes the perceptual loss. 
48 | """ 49 | assert len(image_shape) == 3 and image_shape[-1] == 3 50 | assert osp.exists(opts.vgg16_path), 'Cannot find %s' % opts.vgg16_path 51 | 52 | self.w_act = w_act 53 | self.vgg_layers = layers 54 | self.w_layers = w_layers 55 | batch_shape = [None] + image_shape # [None, H, W, 3] 56 | 57 | vgg_net = vgg16.Vgg16(opts.vgg16_path) 58 | self.x_acts = vgg_net.get_vgg_activations(x, layers) 59 | self.y_acts = vgg_net.get_vgg_activations(y, layers) 60 | loss = 0 61 | for w, act1, act2 in zip(self.w_layers, self.x_acts, self.y_acts): 62 | loss += w * tf.reduce_mean(tf.square(self.w_act * (act1 - act2))) 63 | self.loss = loss 64 | 65 | def __call__(self): 66 | return self.loss 67 | 68 | 69 | def lsgan_appearance_E_loss(disc_response): 70 | disc_response = tf.squeeze(disc_response) 71 | gt_label = 0.5 72 | loss = tf.reduce_mean(tf.square(disc_response - gt_label)) 73 | return loss 74 | 75 | 76 | def lsgan_loss(disc_response, is_real): 77 | gt_label = 1 if is_real else 0 78 | disc_response = tf.squeeze(disc_response) 79 | # The following works for both regular and patchGAN discriminators 80 | loss = tf.reduce_mean(tf.square(disc_response - gt_label)) 81 | return loss 82 | 83 | 84 | def multiscale_discriminator_loss(Ds_responses, is_real): 85 | num_D = len(Ds_responses) 86 | loss = 0 87 | for i in range(num_D): 88 | curr_response = Ds_responses[i][-1][-1] 89 | loss += lsgan_loss(curr_response, is_real) 90 | return loss 91 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from options import FLAGS as opts 16 | import functools 17 | import layers 18 | import tensorflow as tf 19 | 20 | 21 | class RenderingModel(object): 22 | 23 | def __init__(self, model_name, use_appearance=True): 24 | 25 | if model_name == 'pggan': 26 | self._model = ModelPGGAN(use_appearance) 27 | else: 28 | raise ValueError('Model %s not implemented!' % model_name) 29 | 30 | def __call__(self, x_in, z_app=None): 31 | return self._model(x_in, z_app) 32 | 33 | def get_appearance_encoder(self): 34 | return self._model._appearance_encoder 35 | 36 | def get_generator(self): 37 | return self._model._generator 38 | 39 | def get_content_encoder(self): 40 | return self._model._content_encoder 41 | 42 | 43 | # "Progressive Growing of GANs (PGGAN)"-inspired architecture. Implementation is 44 | # based on the implementation details in their paper, but code is not taken from 45 | # the authors' released code. 46 | # Main changes are: 47 | # - conditional GAN setup by introducting an encoder + skip connections. 48 | # - no progressive growing during training. 
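# The adversarial objective for this conditional setup is assembled in
# staged_model.py, which is not part of this excerpt. The function below is an
# illustrative sketch only (not the repo's exact wiring) of how the pieces in
# losses.py above are typically combined: `disc` is assumed to be a
# MultiScaleDiscriminator (defined further down), `y_fake` a generator output,
# and `x_in` / `x_gt` the conditioning buffer and ground-truth photo. Details
# such as variable reuse and optimizer setup are omitted.
def _sketch_conditional_lsgan_losses(disc, x_in, x_gt, y_fake):
  import losses  # local import: this sketch is annotation, not module code
  d_real = disc(x_gt, x_cond=x_in)   # nested per-scale discriminator responses
  d_fake = disc(y_fake, x_cond=x_in)
  # Discriminator: push real responses towards 1 and fake responses towards 0.
  loss_d = (losses.multiscale_discriminator_loss(d_real, is_real=True) +
            losses.multiscale_discriminator_loss(d_fake, is_real=False))
  # Generator: fool the discriminator and stay close to the ground truth.
  loss_g = (opts.w_loss_gan *
            losses.multiscale_discriminator_loss(d_fake, is_real=True) +
            opts.w_loss_l1 * losses.L1_loss(y_fake, x_gt))
  return loss_d, loss_g
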
49 | class ModelPGGAN(RenderingModel): 50 | 51 | def __init__(self, use_appearance=True): 52 | self._use_appearance = use_appearance 53 | self._content_encoder = None 54 | self._generator = GeneratorPGGAN(appearance_vec_size=opts.app_vector_size) 55 | if use_appearance: 56 | self._appearance_encoder = DRITAppearanceEncoderConcat( 57 | 'appearance_net', opts.appearance_nc, opts.normalize_drit_Ez) 58 | else: 59 | self._appearance_encoder = None 60 | 61 | def __call__(self, x_in, z_app=None): 62 | y = self._generator(x_in, z_app) 63 | return y 64 | 65 | def get_appearance_encoder(self): 66 | return self._appearance_encoder 67 | 68 | def get_generator(self): 69 | return self._generator 70 | 71 | def get_content_encoder(self): 72 | return self._content_encoder 73 | 74 | 75 | class PatchGANDiscriminator(object): 76 | 77 | def __init__(self, name_scope, input_nc, nf=64, n_layers=3, get_fmaps=False): 78 | """Constructor for a patchGAN discriminators. 79 | 80 | Args: 81 | name_scope: str - tf name scope. 82 | input_nc: int - number of input channels. 83 | nf: int - starting number of discriminator filters. 84 | n_layers: int - number of layers in the discriminator. 85 | get_fmaps: bool - return intermediate feature maps for FeatLoss. 86 | """ 87 | self.get_fmaps = get_fmaps 88 | self.n_layers = n_layers 89 | kw = 4 # kernel width for convolution 90 | 91 | activation = functools.partial(tf.nn.leaky_relu, alpha=0.2) 92 | norm_layer = functools.partial(layers.LayerInstanceNorm) 93 | conv2d = functools.partial(layers.LayerConv, use_scaling=opts.use_scaling, 94 | relu_slope=0.2) 95 | 96 | def minibatch_stats(x): 97 | return layers.scalar_concat(x, layers.minibatch_mean_variance(x)) 98 | 99 | # Create layers. 100 | self.blocks = [] 101 | with tf.variable_scope(name_scope, tf.AUTO_REUSE): 102 | with tf.variable_scope('block_0'): 103 | self.blocks.append([ 104 | conv2d('conv0', w=kw, n=[input_nc, nf], stride=2), 105 | activation 106 | ]) 107 | for ii_block in range(1, n_layers): 108 | nf_prev = nf 109 | nf = min(nf * 2, 512) 110 | with tf.variable_scope('block_%d' % ii_block): 111 | self.blocks.append([ 112 | conv2d('conv%d' % ii_block, w=kw, n=[nf_prev, nf], stride=2), 113 | norm_layer(), 114 | activation 115 | ]) 116 | # Add minibatch_stats (from PGGAN) and do a stride1 convolution. 117 | nf_prev = nf 118 | nf = min(nf * 2, 512) 119 | with tf.variable_scope('block_%d' % (n_layers + 1)): 120 | self.blocks.append([ 121 | minibatch_stats, # this is improvised by @meshry 122 | conv2d('conv%d' % (n_layers + 1), w=kw, n=[nf_prev + 1, nf], 123 | stride=1), 124 | norm_layer(), 125 | activation 126 | ]) 127 | # Get 1-channel patchGAN logits 128 | with tf.variable_scope('patchGAN_logits'): 129 | self.blocks.append([ 130 | conv2d('conv%d' % (n_layers + 2), w=kw, n=[nf, 1], stride=1) 131 | ]) 132 | 133 | def __call__(self, x, x_cond=None): 134 | # Concatenate extra conditioning input, if any. 
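    # The discriminator is made conditional by stacking the conditioning buffer
    # x_cond (e.g. the rendered deep buffer) with the image x along the channel
    # axis. When get_fmaps=True, the activations of every block are collected
    # and returned so a feature-matching loss can compare real and generated
    # activations; otherwise only the final patchGAN logits are returned. Note
    # that the last block before the logits also appends the minibatch
    # mean-variance statistic (layers.minibatch_mean_variance) as an extra
    # channel, a PGGAN-style form of batch discrimination.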
135 | if x_cond is not None: 136 | x = tf.concat([x, x_cond], axis=3) 137 | 138 | if self.get_fmaps: 139 | # Dummy addition of x to D_fmaps, which will be removed before returing 140 | D_fmaps = [[x]] 141 | for i_block in range(len(self.blocks)): 142 | # Apply layer #0 in the current block 143 | block_fmaps = [self.blocks[i_block][0](D_fmaps[-1][-1])] 144 | # Apply the remaining layers of this block 145 | for i_layer in range(1, len(self.blocks[i_block])): 146 | block_fmaps.append(self.blocks[i_block][i_layer](block_fmaps[-1])) 147 | # Append the feature maps of this block to D_fmaps 148 | D_fmaps.append(block_fmaps) 149 | return D_fmaps[1:] # exclude the input x which we added initially 150 | else: 151 | y = x 152 | for i_block in range(len(self.blocks)): 153 | for i_layer in range(len(self.blocks[i_block])): 154 | y = self.blocks[i_block][i_layer](y) 155 | return [[y]] 156 | 157 | 158 | class MultiScaleDiscriminator(object): 159 | 160 | def __init__(self, name_scope, input_nc, num_scales=3, nf=64, n_layers=3, 161 | get_fmaps=False): 162 | self.get_fmaps = get_fmaps 163 | discs = [] 164 | with tf.variable_scope(name_scope): 165 | for i in range(num_scales): 166 | discs.append(PatchGANDiscriminator( 167 | 'D_scale%d' % i, input_nc, nf=nf, n_layers=n_layers, 168 | get_fmaps=get_fmaps)) 169 | self.discriminators = discs 170 | 171 | def __call__(self, x, x_cond=None, params=None): 172 | del params 173 | if x_cond is not None: 174 | x = tf.concat([x, x_cond], axis=3) 175 | 176 | responses = [] 177 | for ii, D in enumerate(self.discriminators): 178 | responses.append(D(x, x_cond=None)) # x_cond is already concatenated 179 | if ii != len(self.discriminators) - 1: 180 | x = layers.downscale(x, n=2) 181 | return responses 182 | 183 | 184 | class GeneratorPGGAN(object): 185 | def __init__(self, appearance_vec_size=8, use_scaling=True, 186 | num_blocks=5, input_nc=7, 187 | fmap_base=8192, fmap_decay=1.0, fmap_max=512): 188 | """Generator model. 189 | 190 | Args: 191 | appearance_vec_size: int, size of the latent appearance vector. 192 | use_scaling: bool, whether to use weight scaling. 193 | resolution: int, width of the images (assumed to be square). 194 | input_nc: int, number of input channles. 195 | fmap_base: int, base number of channels. 196 | fmap_decay: float, decay rate of channels with respect to depth. 197 | fmap_max: int, max number of channels (supersedes fmap_base). 198 | 199 | Returns: 200 | function of the model. 201 | """ 202 | def _num_filters(fmap_base, fmap_decay, fmap_max, stage): 203 | if opts.g_nf == 32: 204 | return min(int(2**(10 - stage)), fmap_max) # nf32 205 | elif opts.g_nf == 64: 206 | return min(int(2**(11 - stage)), fmap_max) # nf64 207 | else: 208 | raise ValueError('Currently unsupported num filters') 209 | 210 | nf = functools.partial(_num_filters, fmap_base, fmap_decay, fmap_max) 211 | self.num_blocks = num_blocks 212 | activation = functools.partial(tf.nn.leaky_relu, alpha=0.2) 213 | conv2d_stride1 = functools.partial( 214 | layers.LayerConv, stride=1, use_scaling=use_scaling, relu_slope=0.2) 215 | conv2d_rgb = functools.partial(layers.LayerConv, w=1, stride=1, 216 | use_scaling=use_scaling) 217 | 218 | # Create encoder layers. 
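    # The encoder downsamples its input (num_blocks + 1) times and the decoder
    # below mirrors it with skip connections. The per-stage filter count
    # follows _num_filters; for example, with opts.g_nf = 64 it is
    # nf(stage) = min(2**(11 - stage), 512), i.e. nf(0..6) =
    # 512, 512, 512, 256, 128, 64, 32.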
219 | with tf.variable_scope('g_model_enc', tf.AUTO_REUSE): 220 | self.enc_stage = [] 221 | self.from_rgb = [] 222 | 223 | if opts.use_appearance and opts.inject_z == 'to_encoder': 224 | input_nc += appearance_vec_size 225 | 226 | for i in range(num_blocks, -1, -1): 227 | with tf.variable_scope('res_%d' % i): 228 | self.from_rgb.append( 229 | layers.LayerPipe([ 230 | conv2d_rgb('from_rgb', n=[input_nc, nf(i + 1)]), 231 | activation, 232 | ]) 233 | ) 234 | self.enc_stage.append( 235 | layers.LayerPipe([ 236 | functools.partial(layers.downscale, n=2), 237 | conv2d_stride1('conv0', w=3, n=[nf(i + 1), nf(i)]), 238 | activation, 239 | layers.pixel_norm, 240 | conv2d_stride1('conv1', w=3, n=[nf(i), nf(i)]), 241 | activation, 242 | layers.pixel_norm 243 | ]) 244 | ) 245 | 246 | # Create decoder layers. 247 | with tf.variable_scope('g_model_dec', tf.AUTO_REUSE): 248 | self.dec_stage = [] 249 | self.to_rgb = [] 250 | 251 | nf_bottleneck = nf(0) # num input filters at the bottleneck 252 | if opts.use_appearance and opts.inject_z == 'to_bottleneck': 253 | nf_bottleneck += appearance_vec_size 254 | 255 | with tf.variable_scope('res_0'): 256 | self.dec_stage.append( 257 | layers.LayerPipe([ 258 | functools.partial(layers.upscale, n=2), 259 | conv2d_stride1('conv0', w=3, n=[nf_bottleneck, nf(1)]), 260 | activation, 261 | layers.pixel_norm, 262 | conv2d_stride1('conv1', w=3, n=[nf(1), nf(1)]), 263 | activation, 264 | layers.pixel_norm 265 | ]) 266 | ) 267 | self.to_rgb.append(conv2d_rgb('to_rgb', n=[nf(1), opts.output_nc])) 268 | 269 | multiply_factor = 2 if opts.concatenate_skip_layers else 1 270 | for i in range(1, num_blocks + 1): 271 | with tf.variable_scope('res_%d' % i): 272 | self.dec_stage.append( 273 | layers.LayerPipe([ 274 | functools.partial(layers.upscale, n=2), 275 | conv2d_stride1('conv0', w=3, 276 | n=[multiply_factor * nf(i), nf(i + 1)]), 277 | activation, 278 | layers.pixel_norm, 279 | conv2d_stride1('conv1', w=3, n=[nf(i + 1), nf(i + 1)]), 280 | activation, 281 | layers.pixel_norm 282 | ]) 283 | ) 284 | self.to_rgb.append(conv2d_rgb('to_rgb', 285 | n=[nf(i + 1), opts.output_nc])) 286 | 287 | def __call__(self, x, appearance_embedding=None, encoder_fmaps=None): 288 | """Generator function. 289 | 290 | Args: 291 | x: 2D tensor (batch, latents), the conditioning input batch of images. 292 | appearance_embedding: float tensor: latent appearance vector. 293 | Returns: 294 | 4D tensor of images (NHWC), the generated images. 
295 | """ 296 | del encoder_fmaps 297 | enc_st_idx = 0 298 | if opts.use_appearance and opts.inject_z == 'to_encoder': 299 | x = layers.tile_and_concatenate(x, appearance_embedding, 300 | opts.app_vector_size) 301 | y = self.from_rgb[enc_st_idx](x) 302 | 303 | enc_responses = [] 304 | for i in range(enc_st_idx, len(self.enc_stage)): 305 | y = self.enc_stage[i](y) 306 | enc_responses.insert(0, y) 307 | 308 | # Concatenate appearance vector to y 309 | if opts.use_appearance and opts.inject_z == 'to_bottleneck': 310 | appearance_tensor = tf.tile(appearance_embedding, 311 | [1, tf.shape(y)[1], tf.shape(y)[2], 1]) 312 | y = tf.concat([y, appearance_tensor], axis=3) 313 | 314 | y_list = [] 315 | for i in range(self.num_blocks + 1): 316 | if i > 0: 317 | y_skip = enc_responses[i] # skip layer 318 | if opts.concatenate_skip_layers: 319 | y = tf.concat([y, y_skip], axis=3) 320 | else: 321 | y = y + y_skip 322 | y = self.dec_stage[i](y) 323 | y_list.append(y) 324 | 325 | return self.to_rgb[self.num_blocks](y_list[-1]) 326 | 327 | 328 | class DRITAppearanceEncoderConcat(object): 329 | 330 | def __init__(self, name_scope, input_nc, normalize_encoder): 331 | self.blocks = [] 332 | activation = functools.partial(tf.nn.leaky_relu, alpha=0.2) 333 | conv2d = functools.partial(layers.LayerConv, use_scaling=opts.use_scaling, 334 | relu_slope=0.2, padding='SAME') 335 | with tf.variable_scope(name_scope, tf.AUTO_REUSE): 336 | if normalize_encoder: 337 | self.blocks.append(layers.LayerPipe([ 338 | conv2d('conv0', w=4, n=[input_nc, 64], stride=2), 339 | layers.BasicBlock('BB0', n=[64, 128], use_scaling=opts.use_scaling), 340 | layers.pixel_norm, 341 | layers.BasicBlock('BB1', n=[128, 192], use_scaling=opts.use_scaling), 342 | layers.pixel_norm, 343 | layers.BasicBlock('BB2', n=[192, 256], use_scaling=opts.use_scaling), 344 | layers.pixel_norm, 345 | activation, 346 | layers.global_avg_pooling 347 | ])) 348 | else: 349 | self.blocks.append(layers.LayerPipe([ 350 | conv2d('conv0', w=4, n=[input_nc, 64], stride=2), 351 | layers.BasicBlock('BB0', n=[64, 128], use_scaling=opts.use_scaling), 352 | layers.BasicBlock('BB1', n=[128, 192], use_scaling=opts.use_scaling), 353 | layers.BasicBlock('BB2', n=[192, 256], use_scaling=opts.use_scaling), 354 | activation, 355 | layers.global_avg_pooling 356 | ])) 357 | # FC layers to get the mean and logvar 358 | self.fc_mean = layers.FullyConnected(opts.app_vector_size, 'FC_mean') 359 | self.fc_logvar = layers.FullyConnected(opts.app_vector_size, 'FC_logvar') 360 | 361 | def __call__(self, x): 362 | for f in self.blocks: 363 | x = f(x) 364 | 365 | mean = self.fc_mean(x) 366 | logvar = self.fc_logvar(x) 367 | # The following is an arbitrarily chosen *deterministic* latent vector 368 | # computation. Another option is to let z = mean, but gradients from logvar 369 | # will be None and will need to be removed. 370 | z = mean + tf.exp(0.5 * logvar) 371 | return z, mean, logvar 372 | -------------------------------------------------------------------------------- /neural_rerendering.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | #     https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from PIL import Image 16 | from absl import app 17 | from options import FLAGS as opts 18 | import data 19 | import datetime 20 | import functools 21 | import glob 22 | import losses 23 | import networks 24 | import numpy as np 25 | import options 26 | import os.path as osp 27 | import random 28 | import skimage.measure 29 | import staged_model 30 | import tensorflow as tf 31 | import time 32 | import utils 33 | 34 | 35 | def build_model_fn(use_exponential_moving_average=True): 36 | """Builds and returns the model function for an estimator. 37 | 38 | Args: 39 | use_exponential_moving_average: bool. If true, the exponential moving 40 | average will be used. 41 | 42 | Returns: 43 | function, the model_fn function typically required by an estimator. 44 | """ 45 | arch_type = opts.arch_type 46 | use_appearance = opts.use_appearance 47 | def model_fn(features, labels, mode, params): 48 | """An estimator build_fn.""" 49 | del labels, params 50 | if mode == tf.estimator.ModeKeys.TRAIN: 51 | step = tf.train.get_global_step() 52 | 53 | x_in = features['conditional_input'] 54 | x_gt = features['expected_output'] # ground truth output 55 | x_app = features['peek_input'] 56 | 57 | if opts.training_pipeline == 'staged': 58 | ops = staged_model.create_computation_graph(x_in, x_gt, x_app=x_app, 59 | arch_type=opts.arch_type) 60 | op_increment_step = tf.assign_add(step, 1) 61 | train_disc_op = ops['train_disc_op'] 62 | train_renderer_op = ops['train_renderer_op'] 63 | train_op = tf.group(train_disc_op, train_renderer_op, op_increment_step) 64 | 65 | utils.HookReport.log_tensor(ops['total_loss_d'], 'total_loss_d') 66 | utils.HookReport.log_tensor(ops['loss_d_real'], 'loss_d_real') 67 | utils.HookReport.log_tensor(ops['loss_d_fake'], 'loss_d_fake') 68 | utils.HookReport.log_tensor(ops['total_loss_g'], 'total_loss_g') 69 | utils.HookReport.log_tensor(ops['loss_g_gan'], 'loss_g_gan') 70 | utils.HookReport.log_tensor(ops['loss_g_recon'], 'loss_g_recon') 71 | utils.HookReport.log_tensor(step, 'global_step') 72 | 73 | return tf.estimator.EstimatorSpec( 74 | mode=mode, loss=ops['total_loss_d'] + ops['total_loss_g'], 75 | train_op=train_op) 76 | else: 77 | raise NotImplementedError('%s training is not implemented.' % 78 | opts.training_pipeline) 79 | elif mode == tf.estimator.ModeKeys.EVAL: 80 | raise NotImplementedError('Eval is not implemented.') 81 | else: # all below modes are for difference inference tasks. 82 | # Build network and initialize inference variables. 
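      # When use_exponential_moving_average is True, inference restores the
      # exponential-moving-average (shadow) weights instead of the raw trained
      # variables: ema.variables_to_restore() maps the checkpoint's
      # moving-average names to the corresponding variables of the freshly
      # built graph, and tf.train.init_from_checkpoint() initializes the graph
      # from the latest checkpoint in opts.train_dir using that mapping.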
83 | g_func = networks.RenderingModel(arch_type, use_appearance) 84 | if use_appearance: 85 | app_func = g_func.get_appearance_encoder() 86 | if use_exponential_moving_average: 87 | ema = tf.train.ExponentialMovingAverage(decay=0.999) 88 | var_dict = ema.variables_to_restore() 89 | tf.train.init_from_checkpoint(osp.join(opts.train_dir), var_dict) 90 | 91 | if mode == tf.estimator.ModeKeys.PREDICT: 92 | x_in = features['conditional_input'] 93 | if use_appearance: 94 | x_app = features['peek_input'] 95 | x_app_embedding, _, _ = app_func(x_app) 96 | else: 97 | x_app_embedding = None 98 | y = g_func(x_in, x_app_embedding) 99 | tf.logging.info('DBG: shape of y during prediction %s.' % str(y.shape)) 100 | return tf.estimator.EstimatorSpec(mode=mode, predictions=y) 101 | 102 | # 'eval_subset' mode is same as PREDICT but it concatenates the output to 103 | # the input render, semantic map and ground truth for easy comparison. 104 | elif mode == 'eval_subset': 105 | x_in = features['conditional_input'] 106 | x_gt = features['expected_output'] 107 | if use_appearance: 108 | x_app = features['peek_input'] 109 | x_app_embedding, _, _ = app_func(x_app) 110 | else: 111 | x_app_embedding = None 112 | y = g_func(x_in, x_app_embedding) 113 | tf.logging.info('DBG: shape of y during prediction %s.' % str(y.shape)) 114 | x_in_rgb = tf.slice(x_in, [0, 0, 0, 0], [-1, -1, -1, 3]) 115 | if opts.use_semantic: 116 | x_in_semantic = tf.slice(x_in, [0, 0, 0, 4], [-1, -1, -1, 3]) 117 | output_tuple = tf.concat([x_in_rgb, x_in_semantic, y, x_gt], axis=2) 118 | else: 119 | output_tuple = tf.concat([x_in_rgb, y, x_gt], axis=2) 120 | return tf.estimator.EstimatorSpec(mode=mode, predictions=output_tuple) 121 | 122 | # 'compute_appearance' mode computes and returns the latent z vector. 123 | elif mode == 'compute_appearance': 124 | assert use_appearance, 'use_appearance is set to False!' 125 | x_app_in = features['peek_input'] 126 | # NOTE the following line is a temporary hack (which is 127 | # specially bad for inputs smaller than 512x512). 128 | x_app_in = tf.image.resize_image_with_crop_or_pad(x_app_in, 512, 512) 129 | app_embedding, _, _ = app_func(x_app_in) 130 | return tf.estimator.EstimatorSpec(mode=mode, predictions=app_embedding) 131 | 132 | # 'interpolate_appearance' mode expects an already computed latent z 133 | # vector as input passed a value to the dict key 'appearance_embedding'. 134 | elif mode == 'interpolate_appearance': 135 | assert use_appearance, 'use_appearance is set to False!' 136 | x_in = features['conditional_input'] 137 | x_app_embedding = features['appearance_embedding'] 138 | y = g_func(x_in, x_app_embedding) 139 | tf.logging.info('DBG: shape of y during prediction %s.' % str(y.shape)) 140 | return tf.estimator.EstimatorSpec(mode=mode, predictions=y) 141 | else: 142 | raise ValueError('Unsupported mode: ' + mode) 143 | 144 | return model_fn 145 | 146 | 147 | def make_sample_grid_and_save(est, dataset_name, dataset_parent_dir, grid_dims, 148 | output_dir, cur_nimg): 149 | """Evaluate a fixed set of validation images and save output. 150 | 151 | Args: 152 | est: tf,estimator.Estimator, TF estimator to run the predictions. 153 | dataset_name: basename for the validation tfrecord from which to load 154 | validation images. 155 | dataset_parent_dir: path to a directory containing the validation tfrecord. 156 | grid_dims: 2-tuple int for the grid size (1 unit = 1 image). 157 | output_dir: string, where to save image samples. 158 | cur_nimg: int, current number of images seen by training. 
159 | 160 | Returns: 161 | None. 162 | """ 163 | num_examples = grid_dims[0] * grid_dims[1] 164 | def input_val_fn(): 165 | dict_inp = data.provide_data( 166 | dataset_name=dataset_name, parent_dir=dataset_parent_dir, subset='val', 167 | batch_size=1, crop_flag=True, crop_size=opts.train_resolution, 168 | seeds=[0], max_examples=num_examples, 169 | use_appearance=opts.use_appearance, shuffle=0) 170 | x_in = dict_inp['conditional_input'] 171 | x_gt = dict_inp['expected_output'] # ground truth output 172 | x_app = dict_inp['peek_input'] 173 | return x_in, x_gt, x_app 174 | 175 | def est_input_val_fn(): 176 | x_in, _, x_app = input_val_fn() 177 | features = {'conditional_input': x_in, 'peek_input': x_app} 178 | return features 179 | 180 | images = [x for x in est.predict(est_input_val_fn)] 181 | images = np.array(images, 'f') 182 | images = images.reshape(grid_dims + images.shape[1:]) 183 | utils.save_images(utils.to_png(utils.images_to_grid(images)), output_dir, 184 | cur_nimg) 185 | 186 | 187 | def visualize_image_sequence(est, dataset_name, dataset_parent_dir, 188 | input_sequence_name, app_base_path, output_dir): 189 | """Generates an image sequence as a video and stores it to disk.""" 190 | batch_sz = opts.batch_size 191 | def input_seq_fn(): 192 | dict_inp = data.provide_data( 193 | dataset_name=dataset_name, parent_dir=dataset_parent_dir, 194 | subset=input_sequence_name, batch_size=batch_sz, crop_flag=False, 195 | seeds=None, use_appearance=False, shuffle=0) 196 | x_in = dict_inp['conditional_input'] 197 | return x_in 198 | 199 | # Compute appearance embedding only once and use it for all input frames. 200 | app_rgb_path = app_base_path + '_reference.png' 201 | app_rendered_path = app_base_path + '_color.png' 202 | app_depth_path = app_base_path + '_depth.png' 203 | app_sem_path = app_base_path + '_seg_rgb.png' 204 | x_app = _load_and_concatenate_image_channels( 205 | app_rgb_path, app_rendered_path, app_depth_path, app_sem_path) 206 | def seq_with_single_appearance_inp_fn(): 207 | """input frames with a fixed latent appearance vector.""" 208 | x_in_op = input_seq_fn() 209 | x_app_op = tf.convert_to_tensor(x_app) 210 | x_app_tiled_op = tf.tile(x_app_op, [tf.shape(x_in_op)[0], 1, 1, 1]) 211 | return {'conditional_input': x_in_op, 212 | 'peek_input': x_app_tiled_op} 213 | 214 | images = [x for x in est.predict(seq_with_single_appearance_inp_fn)] 215 | for i, gen_img in enumerate(images): 216 | output_file_path = osp.join(output_dir, 'out_%04d.png' % i) 217 | print('Saving frame #%d to %s' % (i, output_file_path)) 218 | with tf.gfile.Open(output_file_path, 'wb') as f: 219 | f.write(utils.to_png(gen_img)) 220 | 221 | 222 | def train(dataset_name, dataset_parent_dir, load_pretrained_app_encoder, 223 | load_trained_fixed_app, save_samples_kimg=50): 224 | """Main training procedure. 225 | 226 | The trained model is saved in opts.train_dir, the function itself does not 227 | return anything. 228 | 229 | Args: 230 | save_samples_kimg: int, period (in KiB) to save sample images. 231 | 232 | Returns: 233 | None. 234 | """ 235 | image_dir = osp.join(opts.train_dir, 'images') # to save validation images. 
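  # Training progress is tracked in "kimg" (multiples of 1024 images seen), so
  # step counts are derived as (kimg << 10) // batch_size. For example, with
  # the default save_samples_kimg=50 and batch_size=8, checkpoints are written
  # every (50 << 10) // 8 = 6400 steps, and summaries every
  # (1 << 10) // 8 = 128 steps.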
236 | tf.gfile.MakeDirs(image_dir) 237 | config = tf.estimator.RunConfig( 238 | save_summary_steps=(1 << 10) // opts.batch_size, 239 | save_checkpoints_steps=(save_samples_kimg << 10) // opts.batch_size, 240 | keep_checkpoint_max=5, 241 | log_step_count_steps=1 << 30) 242 | model_dir = opts.train_dir 243 | if (opts.use_appearance and load_trained_fixed_app and 244 | not tf.train.latest_checkpoint(model_dir)): 245 | tf.logging.warning('***** Loading resume_step from %s!' % 246 | opts.fixed_appearance_train_dir) 247 | resume_step = utils.load_global_step_from_checkpoint_dir( 248 | opts.fixed_appearance_train_dir) 249 | else: 250 | tf.logging.warning('***** Loading resume_step (if any) from %s!' % 251 | model_dir) 252 | resume_step = utils.load_global_step_from_checkpoint_dir(model_dir) 253 | if resume_step != 0: 254 | tf.logging.warning('****** Resuming training at %d!' % resume_step) 255 | 256 | model_fn = build_model_fn() # model function for TFEstimator. 257 | 258 | hooks = [utils.HookReport(1 << 12, opts.batch_size)] 259 | 260 | if opts.use_appearance and load_pretrained_app_encoder: 261 | tf.logging.warning('***** will warm-start from %s!' % 262 | opts.appearance_pretrain_dir) 263 | ws = tf.estimator.WarmStartSettings( 264 | ckpt_to_initialize_from=opts.appearance_pretrain_dir, 265 | vars_to_warm_start='appearance_net/.*') 266 | elif opts.use_appearance and load_trained_fixed_app: 267 | tf.logging.warning('****** finetuning will warm-start from %s!' % 268 | opts.fixed_appearance_train_dir) 269 | ws = tf.estimator.WarmStartSettings( 270 | ckpt_to_initialize_from=opts.fixed_appearance_train_dir, 271 | vars_to_warm_start='.*') 272 | else: 273 | ws = None 274 | tf.logging.warning('****** No warm-starting; using random initialization!') 275 | 276 | est = tf.estimator.Estimator(model_fn, model_dir, config, params={}, 277 | warm_start_from=ws) 278 | 279 | for next_kimg in range(opts.save_samples_kimg, opts.total_kimg + 1, 280 | opts.save_samples_kimg): 281 | next_step = (next_kimg << 10) // opts.batch_size 282 | if opts.num_crops == -1: # use random crops 283 | crop_seeds = None 284 | else: 285 | crop_seeds = list(100 * np.arange(opts.num_crops)) 286 | input_train_fn = functools.partial( 287 | data.provide_data, dataset_name=dataset_name, 288 | parent_dir=dataset_parent_dir, subset='train', 289 | batch_size=opts.batch_size, crop_flag=True, 290 | crop_size=opts.train_resolution, seeds=crop_seeds, 291 | use_appearance=opts.use_appearance) 292 | est.train(input_train_fn, max_steps=next_step, hooks=hooks) 293 | tf.logging.info('DBG: kimg=%d, cur_step=%d' % (next_kimg, next_step)) 294 | tf.logging.info('DBG: Saving a validation grid image %06d to %s' % ( 295 | next_kimg, image_dir)) 296 | make_sample_grid_and_save(est, dataset_name, dataset_parent_dir, (3, 3), 297 | image_dir, next_kimg << 10) 298 | 299 | 300 | def _build_inference_estimator(model_dir): 301 | model_fn = build_model_fn() 302 | est = tf.estimator.Estimator(model_fn, model_dir) 303 | return est 304 | 305 | 306 | def evaluate_sequence(dataset_name, dataset_parent_dir, virtual_seq_name, 307 | app_base_path): 308 | output_dir = osp.join(opts.train_dir, 'seq_output_%s' % virtual_seq_name) 309 | tf.gfile.MakeDirs(output_dir) 310 | est = _build_inference_estimator(opts.train_dir) 311 | visualize_image_sequence(est, dataset_name, dataset_parent_dir, 312 | virtual_seq_name, app_base_path, output_dir) 313 | 314 | 315 | def evaluate_image_set(dataset_name, dataset_parent_dir, subset_suffix, 316 | output_dir=None, batch_size=6): 317 | if 
output_dir is None: 318 | output_dir = osp.join(opts.train_dir, 'validation_output_%s' % subset_suffix) 319 | tf.gfile.MakeDirs(output_dir) 320 | model_fn_old = build_model_fn() 321 | def model_fn_wrapper(features, labels, mode, params): 322 | del mode 323 | return model_fn_old(features, labels, 'eval_subset', params) 324 | model_dir = opts.train_dir 325 | est = tf.estimator.Estimator(model_fn_wrapper, model_dir) 326 | est_inp_fn = functools.partial( 327 | data.provide_data, dataset_name=dataset_name, 328 | parent_dir=dataset_parent_dir, subset=subset_suffix, 329 | batch_size=batch_size, use_appearance=opts.use_appearance, shuffle=0) 330 | 331 | print('Evaluating images for subset %s' % subset_suffix) 332 | images = [x for x in est.predict(est_inp_fn)] 333 | print('Evaluated %d images' % len(images)) 334 | for i, img in enumerate(images): 335 | output_file_path = osp.join(output_dir, 'out_%04d.png' % i) 336 | print('Saving file #%d: %s' % (i, output_file_path)) 337 | with tf.gfile.Open(output_file_path, 'wb') as f: 338 | f.write(utils.to_png(img)) 339 | 340 | 341 | def _load_and_concatenate_image_channels(rgb_path=None, rendered_path=None, 342 | depth_path=None, seg_path=None, 343 | size_multiple=64): 344 | """Prepares a single input for the network.""" 345 | if (rgb_path is None and rendered_path is None and depth_path is None and 346 | seg_path is None): 347 | raise ValueError('At least one of the inputs has to be not None') 348 | 349 | channels = () 350 | if rgb_path is not None: 351 | rgb_img = np.array(Image.open(rgb_path)).astype(np.float32) 352 | rgb_img = utils.crop_to_multiple(rgb_img, size_multiple) 353 | channels = channels + (rgb_img,) 354 | if rendered_path is not None: 355 | rendered_img = np.array(Image.open(rendered_path)).astype(np.float32) 356 | if not opts.use_alpha: 357 | rendered_img = rendered_img[:, :, :3] # drop the alpha channel 358 | rendered_img = utils.crop_to_multiple(rendered_img, size_multiple) 359 | channels = channels + (rendered_img,) 360 | if depth_path is not None: 361 | depth_img = np.array(Image.open(depth_path)) 362 | depth_img = depth_img.astype(np.float32) 363 | depth_img = utils.crop_to_multiple(depth_img[:, :, np.newaxis], 364 | size_multiple) 365 | channels = channels + (depth_img,) 366 | # depth_img = depth_img * (2.0 / 255) - 1.0 367 | if seg_path is not None: 368 | seg_img = np.array(Image.open(seg_path)).astype(np.float32) 369 | seg_img = utils.crop_to_multiple(seg_img, size_multiple) 370 | channels = channels + (seg_img,) 371 | # Concatenate and normalize channels 372 | img = np.dstack(channels) 373 | img = np.expand_dims(img, axis=0) 374 | img = img * (2.0 / 255) - 1.0 375 | return img 376 | 377 | 378 | def infer_dir(model_dir, input_dir, output_dir): 379 | tf.gfile.MakeDirs(output_dir) 380 | est = _build_inference_estimator(opts.train_dir) 381 | 382 | def read_image(base_path, is_appearance=False): 383 | if is_appearance: 384 | ref_img_path = base_path + '_reference.png' 385 | else: 386 | ref_img_path = None 387 | rendered_img_path = base_path + '_color.png' 388 | depth_img_path = base_path + '_depth.png' 389 | seg_img_path = base_path + '_seg_rgb.png' 390 | img = _load_and_concatenate_image_channels( 391 | rgb_path=ref_img_path, rendered_path=rendered_img_path, 392 | depth_path=depth_img_path, seg_path=seg_img_path) 393 | return img 394 | 395 | def get_inference_input_fn(base_path, app_base_path): 396 | x_in = read_image(base_path, False) 397 | x_app_in = read_image(app_base_path, True) 398 | def infer_input_fn(): 399 | return 
{'conditional_input': x_in, 'peek_input': x_app_in} 400 | return infer_input_fn 401 | 402 | file_paths = sorted(glob.glob(osp.join(input_dir, '*_depth.png'))) 403 | base_paths = [x[:-10] for x in file_paths] # remove the '_depth.png' suffix 404 | for inp_base_path in base_paths: 405 | est_inp_fn = get_inference_input_fn(inp_base_path, inp_base_path) 406 | img = next(est.predict(est_inp_fn)) 407 | basename = osp.basename(inp_base_path) 408 | output_img_path = osp.join(output_dir, basename + '_out.png') 409 | print('Saving generated image to %s' % output_img_path) 410 | with tf.gfile.Open(output_img_path, 'wb') as f: 411 | f.write(utils.to_png(img)) 412 | 413 | 414 | def joint_interpolation(model_dir, app_input_dir, st_app_basename, 415 | end_app_basename, camera_path_dir): 416 | """ 417 | Interpolates both viewpoint and appearance between two input images. 418 | """ 419 | # Create output direcotry 420 | output_dir = osp.join(model_dir, 'joint_interpolation_out') 421 | tf.gfile.MakeDirs(output_dir) 422 | 423 | # Build estimator 424 | model_fn_old = build_model_fn() 425 | def model_fn_wrapper(features, labels, mode, params): 426 | del mode 427 | return model_fn_old(features, labels, 'interpolate_appearance', params) 428 | def appearance_model_fn(features, labels, mode, params): 429 | del mode 430 | return model_fn_old(features, labels, 'compute_appearance', params) 431 | config = tf.estimator.RunConfig( 432 | save_summary_steps=1000, save_checkpoints_steps=50000, 433 | keep_checkpoint_max=50, log_step_count_steps=1 << 30) 434 | model_dir = model_dir 435 | est = tf.estimator.Estimator(model_fn_wrapper, model_dir, config, params={}) 436 | est_app = tf.estimator.Estimator(appearance_model_fn, model_dir, config, 437 | params={}) 438 | 439 | # Compute appearance embeddings for the two input appearance images. 
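  # Each appearance image is first assembled into the appearance-encoder input
  # (by default 10 channels: reference RGB + rendered color + depth + semantic
  # map) and then embedded via est_app.predict(). Estimator.predict() yields
  # per-example outputs with the batch dimension stripped, which is why
  # np.expand_dims(..., axis=0) is applied below before the two embeddings are
  # linearly interpolated along the camera path.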
440 | app_inputs = [] 441 | for app_basename in [st_app_basename, end_app_basename]: 442 | app_rgb_path = osp.join(app_input_dir, app_basename + '_reference.png') 443 | app_rendered_path = osp.join(app_input_dir, app_basename + '_color.png') 444 | app_depth_path = osp.join(app_input_dir, app_basename + '_depth.png') 445 | app_seg_path = osp.join(app_input_dir, app_basename + '_seg_rgb.png') 446 | app_in = _load_and_concatenate_image_channels( 447 | rgb_path=app_rgb_path, rendered_path=app_rendered_path, 448 | depth_path=app_depth_path, seg_path=app_seg_path) 449 | # app_inputs.append(tf.convert_to_tensor(app_in)) 450 | app_inputs.append(app_in) 451 | 452 | embedding1 = next(est_app.predict( 453 | lambda: {'peek_input': app_inputs[0]})) 454 | embedding1 = np.expand_dims(embedding1, axis=0) 455 | embedding2 = next(est_app.predict( 456 | lambda: {'peek_input': app_inputs[1]})) 457 | embedding2 = np.expand_dims(embedding2, axis=0) 458 | 459 | file_paths = sorted(glob.glob(osp.join(camera_path_dir, '*_depth.png'))) 460 | base_paths = [x[:-10] for x in file_paths] # remove the '_depth.png' suffix 461 | 462 | # Compute interpolated appearance embeddings 463 | num_interpolations = len(base_paths) 464 | interpolated_embeddings = [] 465 | delta_vec = (embedding2 - embedding1) / (num_interpolations - 1) 466 | for delta_iter in range(num_interpolations): 467 | x_app_embedding = embedding1 + delta_iter * delta_vec 468 | interpolated_embeddings.append(x_app_embedding) 469 | 470 | # Generate and save interpolated images 471 | for frame_idx, embedding in enumerate(interpolated_embeddings): 472 | # Read in input frame 473 | frame_render_path = osp.join(base_paths[frame_idx] + '_color.png') 474 | frame_depth_path = osp.join(base_paths[frame_idx] + '_depth.png') 475 | frame_seg_path = osp.join(base_paths[frame_idx] + '_seg_rgb.png') 476 | x_in = _load_and_concatenate_image_channels( 477 | rgb_path=None, rendered_path=frame_render_path, 478 | depth_path=frame_depth_path, seg_path=frame_seg_path) 479 | 480 | img = next(est.predict( 481 | lambda: {'conditional_input': tf.convert_to_tensor(x_in), 482 | 'appearance_embedding': tf.convert_to_tensor(embedding)})) 483 | output_img_name = '%s_%s_%03d.png' % (st_app_basename, end_app_basename, 484 | frame_idx) 485 | output_img_path = osp.join(output_dir, output_img_name) 486 | print('Saving interpolated image to %s' % output_img_path) 487 | with tf.gfile.Open(output_img_path, 'wb') as f: 488 | f.write(utils.to_png(img)) 489 | 490 | 491 | def interpolate_appearance(model_dir, input_dir, target_img_basename, 492 | appearance_img1_basename, appearance_img2_basename): 493 | # Create output direcotry 494 | output_dir = osp.join(model_dir, 'interpolate_appearance_out') 495 | tf.gfile.MakeDirs(output_dir) 496 | 497 | # Build estimator 498 | model_fn_old = build_model_fn() 499 | def model_fn_wrapper(features, labels, mode, params): 500 | del mode 501 | return model_fn_old(features, labels, 'interpolate_appearance', params) 502 | def appearance_model_fn(features, labels, mode, params): 503 | del mode 504 | return model_fn_old(features, labels, 'compute_appearance', params) 505 | config = tf.estimator.RunConfig( 506 | save_summary_steps=1000, save_checkpoints_steps=50000, 507 | keep_checkpoint_max=50, log_step_count_steps=1 << 30) 508 | model_dir = model_dir 509 | est = tf.estimator.Estimator(model_fn_wrapper, model_dir, config, params={}) 510 | est_app = tf.estimator.Estimator(appearance_model_fn, model_dir, config, 511 | params={}) 512 | 513 | # Compute appearance embeddings 
for the two input appearance images. 514 | app_inputs = [] 515 | for app_basename in [appearance_img1_basename, appearance_img2_basename]: 516 | app_rgb_path = osp.join(input_dir, app_basename + '_reference.png') 517 | app_rendered_path = osp.join(input_dir, app_basename + '_color.png') 518 | app_depth_path = osp.join(input_dir, app_basename + '_depth.png') 519 | app_seg_path = osp.join(input_dir, app_basename + '_seg_rgb.png') 520 | app_in = _load_and_concatenate_image_channels( 521 | rgb_path=app_rgb_path, rendered_path=app_rendered_path, 522 | depth_path=app_depth_path, seg_path=app_seg_path) 523 | # app_inputs.append(tf.convert_to_tensor(app_in)) 524 | app_inputs.append(app_in) 525 | 526 | embedding1 = next(est_app.predict( 527 | lambda: {'peek_input': app_inputs[0]})) 528 | embedding2 = next(est_app.predict( 529 | lambda: {'peek_input': app_inputs[1]})) 530 | embedding1 = np.expand_dims(embedding1, axis=0) 531 | embedding2 = np.expand_dims(embedding2, axis=0) 532 | 533 | # Compute interpolated appearance embeddings 534 | num_interpolations = 10 535 | interpolated_embeddings = [] 536 | delta_vec = (embedding2 - embedding1) / num_interpolations 537 | for delta_iter in range(num_interpolations + 1): 538 | x_app_embedding = embedding1 + delta_iter * delta_vec 539 | interpolated_embeddings.append(x_app_embedding) 540 | 541 | # Read in the generator input for the target image to render 542 | rendered_img_path = osp.join(input_dir, target_img_basename + '_color.png') 543 | depth_img_path = osp.join(input_dir, target_img_basename + '_depth.png') 544 | seg_img_path = osp.join(input_dir, target_img_basename + '_seg_rgb.png') 545 | x_in = _load_and_concatenate_image_channels( 546 | rgb_path=None, rendered_path=rendered_img_path, 547 | depth_path=depth_img_path, seg_path=seg_img_path) 548 | 549 | # Generate and save interpolated images 550 | for interpolate_iter, embedding in enumerate(interpolated_embeddings): 551 | img = next(est.predict( 552 | lambda: {'conditional_input': tf.convert_to_tensor(x_in), 553 | 'appearance_embedding': tf.convert_to_tensor(embedding)})) 554 | output_img_name = 'interpolate_%s_%s_%s_%03d.png' % ( 555 | target_img_basename, appearance_img1_basename, appearance_img2_basename, 556 | interpolate_iter) 557 | output_img_path = osp.join(output_dir, output_img_name) 558 | print('Saving interpolated image to %s' % output_img_path) 559 | with tf.gfile.Open(output_img_path, 'wb') as f: 560 | f.write(utils.to_png(img)) 561 | 562 | 563 | def main(argv): 564 | del argv 565 | configs_str = options.list_options() 566 | tf.gfile.MakeDirs(opts.train_dir) 567 | with tf.gfile.Open(osp.join(opts.train_dir, 'configs.txt'), 'wb') as f: 568 | f.write(configs_str) 569 | tf.logging.info('Local configs\n%s' % configs_str) 570 | 571 | if opts.run_mode == 'train': 572 | dataset_name = opts.dataset_name 573 | dataset_parent_dir = opts.dataset_parent_dir 574 | load_pretrained_app_encoder = opts.load_pretrained_app_encoder 575 | load_trained_fixed_app = opts.load_from_another_ckpt 576 | batch_size = opts.batch_size 577 | train(dataset_name, dataset_parent_dir, load_pretrained_app_encoder, 578 | load_trained_fixed_app) 579 | elif opts.run_mode == 'eval': # generate a camera path output sequence from TFRecord inputs. 
580 | dataset_name = opts.dataset_name 581 | dataset_parent_dir = opts.dataset_parent_dir 582 | virtual_seq_name = opts.virtual_seq_name 583 | inp_app_img_base_path = opts.inp_app_img_base_path 584 | evaluate_sequence(dataset_name, dataset_parent_dir, virtual_seq_name, 585 | inp_app_img_base_path) 586 | elif opts.run_mode == 'eval_subset': # generate output for validation set (encoded as TFRecords) 587 | dataset_name = opts.dataset_name 588 | dataset_parent_dir = opts.dataset_parent_dir 589 | virtual_seq_name = opts.virtual_seq_name 590 | evaluate_image_set(dataset_name, dataset_parent_dir, virtual_seq_name, 591 | opts.output_validation_dir, opts.batch_size) 592 | elif opts.run_mode == 'eval_dir': # evaluate output for a directory with input images 593 | input_dir = opts.inference_input_path 594 | output_dir = opts.inference_output_dir 595 | model_dir = opts.train_dir 596 | infer_dir(model_dir, input_dir, output_dir) 597 | elif opts.run_mode == 'interpolate_appearance': # interpolate appearance only between two images. 598 | model_dir = opts.train_dir 599 | input_dir = opts.inference_input_path 600 | target_img_basename = opts.target_img_basename 601 | app_img1_basename = opts.appearance_img1_basename 602 | app_img2_basename = opts.appearance_img2_basename 603 | interpolate_appearance(model_dir, input_dir, target_img_basename, 604 | app_img1_basename, app_img2_basename) 605 | elif opts.run_mode == 'joint_interpolation': # interpolate viewpoint and appearance between two images 606 | model_dir = opts.train_dir 607 | app_input_dir = opts.inference_input_path 608 | st_app_basename = opts.appearance_img1_basename 609 | end_app_basename = opts.appearance_img2_basename 610 | frames_dir = opts.frames_dir 611 | joint_interpolation(model_dir, app_input_dir, st_app_basename, 612 | end_app_basename, frames_dir) 613 | else: 614 | raise ValueError('Unsupported --run_mode %s' % opts.run_mode) 615 | 616 | 617 | if __name__ == '__main__': 618 | app.run(main) 619 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
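# Example invocations of neural_rerendering.py using the flags defined in this
# file (all paths and basenames below are placeholders, not values shipped
# with the repo):
#
#   # Train the re-rendering model:
#   python neural_rerendering.py --run_mode=train \
#       --dataset_name=my_scene --dataset_parent_dir=/path/to/tfrecords \
#       --train_dir=/path/to/train_dir --batch_size=8
#
#   # Interpolate appearance between two captured images:
#   python neural_rerendering.py --run_mode=interpolate_appearance \
#       --train_dir=/path/to/train_dir \
#       --inference_input_path=/path/to/aligned_inputs \
#       --target_img_basename=img_0000 \
#       --appearance_img1_basename=img_0001 \
#       --appearance_img2_basename=img_0002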
14 | 15 | 16 | from absl import flags 17 | import numpy as np 18 | 19 | FLAGS = flags.FLAGS 20 | 21 | # ------------------------------------------------------------------------------ 22 | # Train flags 23 | # ------------------------------------------------------------------------------ 24 | 25 | # Dataset, model directory and run mode 26 | flags.DEFINE_string('train_dir', '/tmp/nerual_rendering', 27 | 'Directory for model training.') 28 | flags.DEFINE_string('dataset_name', 'sanmarco9k', 'name ID for a dataset.') 29 | flags.DEFINE_string( 30 | 'dataset_parent_dir', '', 31 | 'Directory containing generated tfrecord dataset.') 32 | flags.DEFINE_string('run_mode', 'train', "{'train', 'eval', 'infer'}") 33 | flags.DEFINE_string('imageset_dir', None, 'Directory containing trainset ' 34 | 'images for appearance pretraining.') 35 | flags.DEFINE_string('metadata_output_dir', None, 'Directory to save pickled ' 36 | 'pairwise distance matrix for appearance pretraining.') 37 | flags.DEFINE_integer('save_samples_kimg', 50, 'kimg cycle to save sample' 38 | 'validation ouptut during training.') 39 | 40 | # Network inputs/outputs 41 | flags.DEFINE_boolean('use_depth', True, 'Add depth image to the deep buffer.') 42 | flags.DEFINE_boolean('use_alpha', False, 43 | 'Add alpha channel to the deep buffer.') 44 | flags.DEFINE_boolean('use_semantic', True, 45 | 'Add semantic map to the deep buffer.') 46 | flags.DEFINE_boolean('use_appearance', True, 47 | 'Capture appearance from an input real image.') 48 | flags.DEFINE_integer('deep_buffer_nc', 7, 49 | 'Number of input channels in the deep buffer.') 50 | flags.DEFINE_integer('appearance_nc', 10, 51 | 'Number of input channels to the appearance encoder.') 52 | flags.DEFINE_integer('output_nc', 3, 53 | 'Number of channels for the generated image.') 54 | 55 | # Staged training flags 56 | flags.DEFINE_string( 57 | 'vgg16_path', './vgg16_weights/vgg16.npy', 58 | 'path to a *.npy file with vgg16 pretrained weights') 59 | flags.DEFINE_boolean('load_pretrained_app_encoder', False, 60 | 'Warmstart appearance encoder with pretrained weights.') 61 | flags.DEFINE_string('appearance_pretrain_dir', '', 62 | 'Model dir for the pretrained appearance encoder.') 63 | flags.DEFINE_boolean('train_app_encoder', False, 'Whether to make the weights ' 64 | 'for the appearance encoder trainable or not.') 65 | flags.DEFINE_boolean( 66 | 'load_from_another_ckpt', False, 'Load weights from another trained model, ' 67 | 'e.g load model trained with a fixed appearance encoder.') 68 | flags.DEFINE_string('fixed_appearance_train_dir', '', 69 | 'Model dir for training G with a fixed appearance net.') 70 | 71 | # ----------------------------------------------------------------------------- 72 | 73 | # More hparams 74 | flags.DEFINE_integer('train_resolution', 256, 75 | 'Crop train images to this resolution.') 76 | flags.DEFINE_float('d_lr', 0.001, 'Learning rate for the discriminator.') 77 | flags.DEFINE_float('g_lr', 0.001, 'Learning rate for the generator.') 78 | flags.DEFINE_float('ez_lr', 0.0001, 'Learning rate for appearance encoder.') 79 | flags.DEFINE_integer('batch_size', 8, 'Batch size for training.') 80 | flags.DEFINE_boolean('use_scaling', True, "use He's scaling.") 81 | flags.DEFINE_integer('num_crops', 30, 'num crops from train images' 82 | '(use -1 for random crops).') 83 | flags.DEFINE_integer('app_vector_size', 8, 'Size of latent appearance vector.') 84 | flags.DEFINE_integer('total_kimg', 20000, 85 | 'Max number (in kilo) of training images for training.') 86 | 
flags.DEFINE_float('adam_beta1', 0.0, 'beta1 for adam optimizer.') 87 | flags.DEFINE_float('adam_beta2', 0.99, 'beta2 for adam optimizer.') 88 | 89 | # Loss weights 90 | flags.DEFINE_float('w_loss_vgg', 0.3, 'VGG loss weight.') 91 | flags.DEFINE_float('w_loss_feat', 10., 'Feature loss weight (from pix2pixHD).') 92 | flags.DEFINE_float('w_loss_l1', 50., 'L1 loss weight.') 93 | flags.DEFINE_float('w_loss_z_recon', 10., 'Z reconstruction loss weight.') 94 | flags.DEFINE_float('w_loss_gan', 1., 'Adversarial loss weight.') 95 | flags.DEFINE_float('w_loss_z_gan', 1., 'Z adversarial loss weight.') 96 | flags.DEFINE_float('w_loss_kl', 0.01, 'KL divergence weight.') 97 | flags.DEFINE_float('w_loss_l2_reg', 0.01, 'Weight for L2 regression on Z.') 98 | 99 | # ----------------------------------------------------------------------------- 100 | 101 | # Architecture and training setup 102 | flags.DEFINE_string('arch_type', 'pggan', 103 | 'Architecture type: {pggan, pix2pixhd}.') 104 | flags.DEFINE_string('training_pipeline', 'staged', 105 | 'Training type type: {staged, bicycle_gan, drit}.') 106 | flags.DEFINE_integer('g_nf', 64, 107 | 'num filters in the first/last layers of U-net.') 108 | flags.DEFINE_boolean('concatenate_skip_layers', True, 109 | 'Use concatenation for skip connections.') 110 | 111 | ## if arch_type == 'pggan': 112 | flags.DEFINE_integer('pggan_n_blocks', 5, 113 | 'Num blocks for the pggan architecture.') 114 | ## if arch_type == 'pix2pixhd': 115 | flags.DEFINE_integer('p2p_n_downsamples', 3, 116 | 'Num downsamples for the pix2pixHD architecture.') 117 | flags.DEFINE_integer('p2p_n_resblocks', 4, 'Num residual blocks at the ' 118 | 'end/start of the pix2pixHD encoder/decoder.') 119 | ## if use_drit_pipeline: 120 | flags.DEFINE_boolean('use_concat', True, '"concat" mode from DRIT.') 121 | flags.DEFINE_boolean('normalize_drit_Ez', True, 'Add pixelnorm layers to the ' 122 | 'appearance encoder.') 123 | flags.DEFINE_boolean('concat_z_in_all_layers', True, 'Inject z at each ' 124 | 'upsampling layer in the decoder (only for DRIT baseline)') 125 | flags.DEFINE_string('inject_z', 'to_bottleneck', 'Method for injecting z; ' 126 | 'one of {to_encoder, to_bottleneck}.') 127 | flags.DEFINE_boolean('use_vgg_loss', True, 'vgg v L1 reconstruction loss.') 128 | 129 | # ------------------------------------------------------------------------------ 130 | # Inference flags 131 | # ------------------------------------------------------------------------------ 132 | 133 | flags.DEFINE_string('inference_input_path', '', 134 | 'Parent directory for input images at inference time.') 135 | flags.DEFINE_string('inference_output_dir', '', 'Output path for inference') 136 | flags.DEFINE_string('target_img_basename', '', 137 | 'basename of target image to render for interpolation') 138 | flags.DEFINE_string('virtual_seq_name', 'full_camera_path', 139 | 'name for the virtual camera path suffix for the TFRecord.') 140 | flags.DEFINE_string('inp_app_img_base_path', '', 141 | 'base path for the input appearance image for camera paths') 142 | 143 | flags.DEFINE_string('appearance_img1_basename', '', 144 | 'basename of the first appearance image for interpolation') 145 | flags.DEFINE_string('appearance_img2_basename', '', 146 | 'basename of the first appearance image for interpolation') 147 | flags.DEFINE_list('input_basenames', [], 'input basenames for inference') 148 | flags.DEFINE_list('input_app_basenames', [], 'input appearance basenames for ' 149 | 'inference') 150 | flags.DEFINE_string('frames_dir', '', 151 | 
'Folder with input frames to a camera path') 152 | flags.DEFINE_string('output_validation_dir', '', 153 | 'dataset_name for storing results in a structured folder') 154 | flags.DEFINE_string('input_rendered', '', 155 | 'input rendered image name for inference') 156 | flags.DEFINE_string('input_depth', '', 'input depth image name for inference') 157 | flags.DEFINE_string('input_seg', '', 158 | 'input segmentation mask image name for inference') 159 | flags.DEFINE_string('input_app_rgb', '', 160 | 'input appearance rgb image name for inference') 161 | flags.DEFINE_string('input_app_rendered', '', 162 | 'input appearance rendered image name for inference') 163 | flags.DEFINE_string('input_app_depth', '', 164 | 'input appearance depth image name for inference') 165 | flags.DEFINE_string('input_app_seg', '', 166 | 'input appearance segmentation mask image name for' 167 | 'inference') 168 | flags.DEFINE_string('output_img_name', '', 169 | '[OPTIONAL] output image name for inference') 170 | 171 | # ----------------------------------------------------------------------------- 172 | # Some validation and assertions 173 | # ----------------------------------------------------------------------------- 174 | 175 | def validate_options(): 176 | if FLAGS.use_drit_training: 177 | assert FLAGS.use_appearance, 'DRIT pipeline requires --use_appearance' 178 | assert not ( 179 | FLAGS.load_pretrained_appearance_encoder and FLAGS.load_from_another_ckpt), ( 180 | 'You cannot load weights for the appearance encoder from two different ' 181 | 'checkpoints!') 182 | if not FLAGS.use_appearance: 183 | print('**Warning: setting --app_vector_size to 0 since ' 184 | '--use_appearance=False!') 185 | FLAGS.set_default('app_vector_size', 0) 186 | 187 | # ----------------------------------------------------------------------------- 188 | # Print all options 189 | # ----------------------------------------------------------------------------- 190 | 191 | def list_options(): 192 | configs = ('# Run flags/options from options.py:\n' 193 | '# ----------------------------------\n') 194 | configs += ('## Train flags:\n' 195 | '## ------------\n') 196 | configs += 'train_dir = %s\n' % FLAGS.train_dir 197 | configs += 'dataset_name = %s\n' % FLAGS.dataset_name 198 | configs += 'dataset_parent_dir = %s\n' % FLAGS.dataset_parent_dir 199 | configs += 'run_mode = %s\n' % FLAGS.run_mode 200 | configs += 'save_samples_kimg = %d\n' % FLAGS.save_samples_kimg 201 | configs += '\n# --------------------------------------------------------\n\n' 202 | 203 | configs += ('## Network inputs and outputs:\n' 204 | '## ---------------------------\n') 205 | configs += 'use_depth = %s\n' % str(FLAGS.use_depth) 206 | configs += 'use_alpha = %s\n' % str(FLAGS.use_alpha) 207 | configs += 'use_semantic = %s\n' % str(FLAGS.use_semantic) 208 | configs += 'use_appearance = %s\n' % str(FLAGS.use_appearance) 209 | configs += 'deep_buffer_nc = %d\n' % FLAGS.deep_buffer_nc 210 | configs += 'appearance_nc = %d\n' % FLAGS.appearance_nc 211 | configs += 'output_nc = %d\n' % FLAGS.output_nc 212 | configs += 'train_resolution = %d\n' % FLAGS.train_resolution 213 | configs += '\n# --------------------------------------------------------\n\n' 214 | 215 | configs += ('## Staged training flags:\n' 216 | '## ----------------------\n') 217 | configs += 'load_pretrained_app_encoder = %s\n' % str( 218 | FLAGS.load_pretrained_app_encoder) 219 | configs += 'appearance_pretrain_dir = %s\n' % FLAGS.appearance_pretrain_dir 220 | configs += 'train_app_encoder = %s\n' % 
str(FLAGS.train_app_encoder) 221 | configs += 'load_from_another_ckpt = %s\n' % str(FLAGS.load_from_another_ckpt) 222 | configs += 'fixed_appearance_train_dir = %s\n' % str( 223 | FLAGS.fixed_appearance_train_dir) 224 | configs += '\n# --------------------------------------------------------\n\n' 225 | 226 | configs += ('## More hyper-parameters:\n' 227 | '## ----------------------\n') 228 | configs += 'd_lr = %f\n' % FLAGS.d_lr 229 | configs += 'g_lr = %f\n' % FLAGS.g_lr 230 | configs += 'ez_lr = %f\n' % FLAGS.ez_lr 231 | configs += 'batch_size = %d\n' % FLAGS.batch_size 232 | configs += 'use_scaling = %s\n' % str(FLAGS.use_scaling) 233 | configs += 'num_crops = %d\n' % FLAGS.num_crops 234 | configs += 'app_vector_size = %d\n' % FLAGS.app_vector_size 235 | configs += 'total_kimg = %d\n' % FLAGS.total_kimg 236 | configs += 'adam_beta1 = %f\n' % FLAGS.adam_beta1 237 | configs += 'adam_beta2 = %f\n' % FLAGS.adam_beta2 238 | configs += '\n# --------------------------------------------------------\n\n' 239 | 240 | configs += ('## Loss weights:\n' 241 | '## -------------\n') 242 | configs += 'w_loss_vgg = %f\n' % FLAGS.w_loss_vgg 243 | configs += 'w_loss_feat = %f\n' % FLAGS.w_loss_feat 244 | configs += 'w_loss_l1 = %f\n' % FLAGS.w_loss_l1 245 | configs += 'w_loss_z_recon = %f\n' % FLAGS.w_loss_z_recon 246 | configs += 'w_loss_gan = %f\n' % FLAGS.w_loss_gan 247 | configs += 'w_loss_z_gan = %f\n' % FLAGS.w_loss_z_gan 248 | configs += 'w_loss_kl = %f\n' % FLAGS.w_loss_kl 249 | configs += 'w_loss_l2_reg = %f\n' % FLAGS.w_loss_l2_reg 250 | configs += '\n# --------------------------------------------------------\n\n' 251 | 252 | configs += ('## Architecture and training setup:\n' 253 | '## --------------------------------\n') 254 | configs += 'arch_type = %s\n' % FLAGS.arch_type 255 | configs += 'training_pipeline = %s\n' % FLAGS.training_pipeline 256 | configs += 'g_nf = %d\n' % FLAGS.g_nf 257 | configs += 'concatenate_skip_layers = %s\n' % str( 258 | FLAGS.concatenate_skip_layers) 259 | configs += 'p2p_n_downsamples = %d\n' % FLAGS.p2p_n_downsamples 260 | configs += 'p2p_n_resblocks = %d\n' % FLAGS.p2p_n_resblocks 261 | configs += 'use_concat = %s\n' % str(FLAGS.use_concat) 262 | configs += 'normalize_drit_Ez = %s\n' % str(FLAGS.normalize_drit_Ez) 263 | configs += 'inject_z = %s\n' % FLAGS.inject_z 264 | configs += 'concat_z_in_all_layers = %s\n' % str(FLAGS.concat_z_in_all_layers) 265 | configs += 'use_vgg_loss = %s\n' % str(FLAGS.use_vgg_loss) 266 | configs += '\n# --------------------------------------------------------\n\n' 267 | 268 | return configs 269 | -------------------------------------------------------------------------------- /pretrain_appearance.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
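# This module pretrains the appearance encoder (see --load_pretrained_app_encoder
# and --appearance_pretrain_dir in options.py) on image triplets: for each
# anchor training image, a "positive" is sampled from its nearest neighbors and
# a "negative" from its farthest ones, where proximity is measured by the
# pairwise style distances computed in style_loss.compute_pairwise_style_loss_v2()
# and cached to disk as a pickled distance matrix.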
14 | 15 | from PIL import Image 16 | from absl import app 17 | from absl import flags 18 | from options import FLAGS as opts 19 | import glob 20 | import networks 21 | import numpy as np 22 | import os 23 | import os.path as osp 24 | import pickle 25 | import style_loss 26 | import tensorflow as tf 27 | import utils 28 | 29 | 30 | def _load_and_concatenate_image_channels( 31 | rgb_path=None, rendered_path=None, depth_path=None, seg_path=None, 32 | crop_size=512): 33 | if (rgb_path is None and rendered_path is None and depth_path is None and 34 | seg_path is None): 35 | raise ValueError('At least one of the inputs has to be not None') 36 | 37 | channels = () 38 | if rgb_path is not None: 39 | rgb_img = np.array(Image.open(rgb_path)).astype(np.float32) 40 | rgb_img = utils.get_central_crop(rgb_img, crop_size, crop_size) 41 | channels = channels + (rgb_img,) 42 | if rendered_path is not None: 43 | rendered_img = np.array(Image.open(rendered_path)).astype(np.float32) 44 | rendered_img = utils.get_central_crop(rendered_img, crop_size, crop_size) 45 | if not opts.use_alpha: 46 | rendered_img = rendered_img[:,:, :3] # drop the alpha channel 47 | channels = channels + (rendered_img,) 48 | if depth_path is not None: 49 | depth_img = np.array(Image.open(depth_path)) 50 | depth_img = depth_img.astype(np.float32) 51 | depth_img = utils.get_central_crop(depth_img, crop_size, crop_size) 52 | channels = channels + (depth_img,) 53 | if seg_path is not None: 54 | seg_img = np.array(Image.open(seg_path)).astype(np.float32) 55 | channels = channels + (seg_img,) 56 | # Concatenate and normalize channels 57 | img = np.dstack(channels) 58 | img = img * (2.0 / 255) - 1.0 59 | return img 60 | 61 | 62 | def read_single_appearance_input(rgb_img_path): 63 | base_path = rgb_img_path[:-14] # remove the '_reference.png' suffix 64 | rendered_img_path = base_path + '_color.png' 65 | depth_img_path = base_path + '_depth.png' 66 | semantic_img_path = base_path + '_seg_rgb.png' 67 | network_input_img = _load_and_concatenate_image_channels( 68 | rgb_img_path, rendered_img_path, depth_img_path, semantic_img_path, 69 | crop_size=opts.train_resolution) 70 | return network_input_img 71 | 72 | 73 | def get_triplet_input_fn(dataset_path, dist_file_path=None, k_max_nearest=5, 74 | k_max_farthest=13): 75 | input_images_pattern = osp.join(dataset_path, '*_reference.png') 76 | filenames = sorted(glob.glob(input_images_pattern)) 77 | print('DBG: obtained %d input filenames for triplet inputs' % len(filenames)) 78 | print('DBG: Computing pairwise style distances:') 79 | if dist_file_path is not None and osp.exists(dist_file_path): 80 | print('*** Loading distance matrix from %s' % dist_file_path) 81 | with open(dist_file_path, 'rb') as f: 82 | dist_matrix = pickle.load(f)['dist_matrix'] 83 | print('loaded a dist_matrix of shape: %s' % str(dist_matrix.shape)) 84 | else: 85 | dist_matrix = style_loss.compute_pairwise_style_loss_v2(filenames) 86 | dist_dict = {'dist_matrix': dist_matrix} 87 | print('Saving distance matrix to %s' % dist_file_path) 88 | with open(dist_file_path, 'wb') as f: 89 | pickle.dump(dist_dict, f) 90 | 91 | # Sort neighbors for each anchor image 92 | num_imgs = len(dist_matrix) 93 | sorted_neighbors = [np.argsort(dist_matrix[ii, :]) for ii in range(num_imgs)] 94 | 95 | def triplet_input_fn(anchor_idx): 96 | # start from 1 to avoid getting the same image as its own neighbor 97 | positive_neighbor_idx = np.random.randint(1, k_max_nearest + 1) 98 | negative_neighbor_idx = num_imgs - 1 - np.random.randint(0, 
k_max_farthest) 99 | positive_img_idx = sorted_neighbors[anchor_idx][positive_neighbor_idx] 100 | negative_img_idx = sorted_neighbors[anchor_idx][negative_neighbor_idx] 101 | # Read anchor image 102 | anchor_rgb_path = osp.join(dataset_path, filenames[anchor_idx]) 103 | anchor_input = read_single_appearance_input(anchor_rgb_path) 104 | # Read positive image 105 | positive_rgb_path = osp.join(dataset_path, filenames[positive_img_idx]) 106 | positive_input = read_single_appearance_input(positive_rgb_path) 107 | # Read negative image 108 | negative_rgb_path = osp.join(dataset_path, filenames[negative_img_idx]) 109 | negative_input = read_single_appearance_input(negative_rgb_path) 110 | # Return triplet 111 | return anchor_input, positive_input, negative_input 112 | 113 | return triplet_input_fn 114 | 115 | 116 | def get_tf_triplet_dataset_iter( 117 | dataset_path, trainset_size, dist_file_path, batch_size=4, 118 | deterministic_flag=False, shuffle_buf_size=128, repeat_flag=True): 119 | # Create a dataset of anchor image indices. 120 | idx_dataset = tf.data.Dataset.range(trainset_size) 121 | # Create a mapper function from anchor idx to triplet images. 122 | triplet_mapper = lambda idx: tuple(tf.py_func( 123 | get_triplet_input_fn(dataset_path, dist_file_path), [idx], 124 | [tf.float32, tf.float32, tf.float32])) 125 | # Convert triplet to a dictionary for the estimator input format. 126 | triplet_to_dict_mapper = lambda anchor, pos, neg: { 127 | 'anchor_img': anchor, 'positive_img': pos, 'negative_img': neg} 128 | if repeat_flag: 129 | idx_dataset = idx_dataset.repeat() # Repeat indefinitely. 130 | if not deterministic_flag: 131 | idx_dataset = idx_dataset.shuffle(shuffle_buf_size) 132 | triplet_dataset = idx_dataset.map( 133 | triplet_mapper, num_parallel_calls=max(4, batch_size // 4)) 134 | triplet_dataset = triplet_dataset.map( 135 | triplet_to_dict_mapper, num_parallel_calls=max(4, batch_size // 4)) 136 | else: 137 | triplet_dataset = idx_dataset.map(triplet_mapper, num_parallel_calls=None) 138 | triplet_dataset = triplet_dataset.map(triplet_to_dict_mapper, 139 | num_parallel_calls=None) 140 | triplet_dataset = triplet_dataset.batch(batch_size) 141 | if not deterministic_flag: 142 | triplet_dataset = triplet_dataset.prefetch(4) # Prefetch a few batches. 
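  # The one-shot iterator returned below feeds the Estimator input_fn in
  # train_appearance(); each get_next() call yields a dict with keys
  # 'anchor_img', 'positive_img' and 'negative_img', each of shape
  # [batch_size, H, W, C].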
143 | return triplet_dataset.make_one_shot_iterator() 144 | 145 | 146 | def build_model_fn(batch_size, lr_app_pretrain=0.0001, adam_beta1=0.0, 147 | adam_beta2=0.99): 148 | def model_fn(features, labels, mode, params): 149 | del labels, params 150 | 151 | step = tf.train.get_global_step() 152 | app_func = networks.DRITAppearanceEncoderConcat( 153 | 'appearance_net', opts.appearance_nc, opts.normalize_drit_Ez) 154 | 155 | if mode == tf.estimator.ModeKeys.TRAIN: 156 | op_increment_step = tf.assign_add(step, 1) 157 | with tf.name_scope('Appearance_Loss'): 158 | anchor_img = features['anchor_img'] 159 | positive_img = features['positive_img'] 160 | negative_img = features['negative_img'] 161 | # Compute embeddings (each of shape [batch_sz, 1, 1, app_vector_sz]) 162 | z_anchor, _, _ = app_func(anchor_img) 163 | z_pos, _, _ = app_func(positive_img) 164 | z_neg, _, _ = app_func(negative_img) 165 | # Squeeze into shape of [batch_sz x vec_sz] 166 | anchor_embedding = tf.squeeze(z_anchor, axis=[1, 2], name='z_anchor') 167 | positive_embedding = tf.squeeze(z_pos, axis=[1, 2]) 168 | negative_embedding = tf.squeeze(z_neg, axis=[1, 2]) 169 | # Compute triplet loss 170 | margin = 0.1 171 | anchor_positive_dist = tf.reduce_sum( 172 | tf.square(anchor_embedding - positive_embedding), axis=1) 173 | anchor_negative_dist = tf.reduce_sum( 174 | tf.square(anchor_embedding - negative_embedding), axis=1) 175 | triplet_loss = anchor_positive_dist - anchor_negative_dist + margin 176 | triplet_loss = tf.maximum(triplet_loss, 0.) 177 | triplet_loss = tf.reduce_sum(triplet_loss) / batch_size 178 | tf.summary.scalar('appearance_triplet_loss', triplet_loss) 179 | 180 | # Image summaries 181 | anchor_rgb = tf.slice(anchor_img, [0, 0, 0, 0], [-1, -1, -1, 3]) 182 | positive_rgb = tf.slice(positive_img, [0, 0, 0, 0], [-1, -1, -1, 3]) 183 | negative_rgb = tf.slice(negative_img, [0, 0, 0, 0], [-1, -1, -1, 3]) 184 | tb_vis = tf.concat([anchor_rgb, positive_rgb, negative_rgb], axis=2) 185 | with tf.name_scope('triplet_vis'): 186 | tf.summary.image('anchor-pos-neg', tb_vis) 187 | 188 | optimizer = tf.train.AdamOptimizer(lr_app_pretrain, adam_beta1, 189 | adam_beta2) 190 | optimizer = tf.contrib.estimator.TowerOptimizer(optimizer) 191 | app_vars = utils.model_vars('appearance_net')[0] 192 | print('\n\n***************************************************') 193 | print('DBG: len(app_vars) = %d' % len(app_vars)) 194 | for ii, v in enumerate(app_vars): 195 | print('%03d) %s' % (ii, str(v))) 196 | print('***************************************************\n\n') 197 | app_train_op = optimizer.minimize(triplet_loss, var_list=app_vars) 198 | return tf.estimator.EstimatorSpec( 199 | mode=mode, loss=triplet_loss, 200 | train_op=tf.group(app_train_op, op_increment_step)) 201 | elif mode == tf.estimator.ModeKeys.PREDICT: 202 | imgs = features['anchor_img'] 203 | embeddings = tf.squeeze(app_func(imgs), axis=[1, 2]) 204 | app_vars = utils.model_vars('appearance_net')[0] 205 | tf.train.init_from_checkpoint(osp.join(opts.train_dir), 206 | {'appearance_net/': 'appearance_net/'}) 207 | return tf.estimator.EstimatorSpec(mode=mode, predictions=embeddings) 208 | else: 209 | raise ValueError('Unsupported mode for the appearance model: ' + mode) 210 | 211 | return model_fn 212 | 213 | 214 | def compute_dist_matrix(imageset_dir, dist_file_path, recompute_dist=False): 215 | if not recompute_dist and osp.exists(dist_file_path): 216 | print('*** Loading distance matrix from %s' % dist_file_path) 217 | with open(dist_file_path, 'rb') as f: 218 | dist_matrix 
= pickle.load(f)['dist_matrix'] 219 | print('loaded a dist_matrix of shape: %s' % str(dist_matrix.shape)) 220 | return dist_matrix 221 | else: 222 | images_paths = sorted(glob.glob(osp.join(imageset_dir, '*_reference.png'))) 223 | dist_matrix = style_loss.compute_pairwise_style_loss_v2(images_paths) 224 | dist_dict = {'dist_matrix': dist_matrix} 225 | print('Saving distance matrix to %s' % dist_file_path) 226 | with open(dist_file_path, 'wb') as f: 227 | pickle.dump(dist_dict, f) 228 | return dist_matrix 229 | 230 | 231 | def train_appearance(train_dir, imageset_dir, dist_file_path): 232 | batch_size = 8 233 | lr_app_pretrain = 0.001 234 | 235 | trainset_size = len(glob.glob(osp.join(imageset_dir, '*_reference.png'))) 236 | resume_step = utils.load_global_step_from_checkpoint_dir(train_dir) 237 | if resume_step != 0: 238 | tf.logging.warning('DBG: resuming apperance pretraining at %d!' % 239 | resume_step) 240 | model_fn = build_model_fn(batch_size, lr_app_pretrain) 241 | config = tf.estimator.RunConfig( 242 | save_summary_steps=50, 243 | save_checkpoints_steps=500, 244 | keep_checkpoint_max=5, 245 | log_step_count_steps=100) 246 | est = tf.estimator.Estimator( 247 | tf.contrib.estimator.replicate_model_fn(model_fn), train_dir, 248 | config, params={}) 249 | # Get input function 250 | input_train_fn = lambda: get_tf_triplet_dataset_iter( 251 | imageset_dir, trainset_size, dist_file_path, 252 | batch_size=batch_size).get_next() 253 | print('Starting pretraining steps...') 254 | est.train(input_train_fn, steps=None, hooks=None) # train indefinitely 255 | 256 | 257 | def main(argv): 258 | if len(argv) > 1: 259 | raise app.UsageError('Too many command-line arguments.') 260 | 261 | train_dir = opts.train_dir 262 | dataset_name = opts.dataset_name 263 | imageset_dir = opts.imageset_dir 264 | output_dir = opts.metadata_output_dir 265 | if not osp.exists(output_dir): 266 | os.makedirs(output_dir) 267 | dist_file_path = osp.join(output_dir, 'dist_%s.pckl' % dataset_name) 268 | compute_dist_matrix(imageset_dir, dist_file_path) 269 | train_appearance(train_dir, imageset_dir, dist_file_path) 270 | 271 | if __name__ == '__main__': 272 | app.run(main) 273 | -------------------------------------------------------------------------------- /segment_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Generate semantic segmentations 16 | This module uses Xception model trained on ADE20K dataset to generate semantic 17 | segmentation mask to any set of images. 
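Inputs are photos matching the '*_reference.png' pattern. For each photo, the
script writes a single-channel label map ('*_seg.png') and a colorized mask
('*_seg_rgb.png') next to it; images with an unexpected channel layout are
moved to a 'corrupted' subfolder.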
18 | """ 19 | 20 | from absl import app 21 | from absl import flags 22 | from PIL import Image 23 | import glob 24 | import matplotlib.pyplot as plt 25 | import numpy as np 26 | import os 27 | import os.path as osp 28 | import shutil 29 | import tensorflow as tf 30 | import utils 31 | 32 | 33 | def get_semantic_color_coding(): 34 | """ 35 | assigns the 30 (actually 29) semantic colors from cityscapes semantic mapping 36 | to selected classes from the ADE20K150 semantic classes. 37 | """ 38 | # Below are the 30 cityscape colors (one is duplicate. so total is 29 not 30) 39 | colors = [ 40 | (111, 74, 0), 41 | ( 81, 0, 81), 42 | (128, 64,128), 43 | (244, 35,232), 44 | (250,170,160), 45 | (230,150,140), 46 | ( 70, 70, 70), 47 | (102,102,156), 48 | (190,153,153), 49 | (180,165,180), 50 | (150,100,100), 51 | (150,120, 90), 52 | (153,153,153), 53 | # (153,153,153), 54 | (250,170, 30), 55 | (220,220, 0), 56 | (107,142, 35), 57 | (152,251,152), 58 | ( 70,130,180), 59 | (220, 20, 60), 60 | (255, 0, 0), 61 | ( 0, 0,142), 62 | ( 0, 0, 70), 63 | ( 0, 60,100), 64 | ( 0, 0, 90), 65 | ( 0, 0,110), 66 | ( 0, 80,100), 67 | ( 0, 0,230), 68 | (119, 11, 32), 69 | ( 0, 0,142)] 70 | k_num_ade20k_classes = 150 71 | # initially all 150 classes are mapped to a single color (last color idx: -1) 72 | # Some classes are to be assigned independent colors 73 | # semantic classes are 1-based (1 thru 150) 74 | semantic_to_color_idx = -1 * np.ones(k_num_ade20k_classes + 1, dtype=int) 75 | semantic_to_color_idx [1] = 0 # wall 76 | semantic_to_color_idx [2] = 1 # building;edifice 77 | semantic_to_color_idx [3] = 2 # sky 78 | semantic_to_color_idx [105] = 3 # fountain 79 | semantic_to_color_idx [27] = 4 # sea 80 | semantic_to_color_idx [60] = 5 # stairway;staircase 81 | semantic_to_color_idx [5] = 6 # tree 82 | semantic_to_color_idx [12] = 7 # sidewalk;pavement 83 | semantic_to_color_idx [4] = 7 # floor;flooring 84 | semantic_to_color_idx [7] = 7 # road;route 85 | semantic_to_color_idx [13] = 8 # people 86 | semantic_to_color_idx [18] = 9 # plant;flora;plant;life 87 | semantic_to_color_idx [17] = 10 # mountain;mount 88 | semantic_to_color_idx [20] = 11 # chair 89 | semantic_to_color_idx [6] = 12 # ceiling 90 | semantic_to_color_idx [22] = 13 # water 91 | semantic_to_color_idx [35] = 14 # rock;stone 92 | semantic_to_color_idx [14] = 15 # earth;ground 93 | semantic_to_color_idx [10] = 16 # grass 94 | semantic_to_color_idx [70] = 17 # bench 95 | semantic_to_color_idx [54] = 18 # stairs;steps 96 | semantic_to_color_idx [101] = 19 # poster 97 | semantic_to_color_idx [77] = 20 # boat 98 | semantic_to_color_idx [85] = 21 # tower 99 | semantic_to_color_idx [23] = 22 # painting;picture 100 | semantic_to_color_idx [88] = 23 # streetlight;stree;lamp 101 | semantic_to_color_idx [43] = 24 # column;pillar 102 | semantic_to_color_idx [9] = 25 # window;windowpane 103 | semantic_to_color_idx [15] = 26 # door; 104 | semantic_to_color_idx [133] = 27 # sculpture 105 | 106 | semantic_to_rgb = np.array( 107 | [colors[col_idx][:] for col_idx in semantic_to_color_idx]) 108 | return semantic_to_rgb 109 | 110 | 111 | def _apply_colors(seg_images_path, save_dir, idx_to_color): 112 | for i, img_path in enumerate(seg_images_path): 113 | print('processing img #%05d / %05d: %s' % (i, len(seg_images_path), 114 | osp.split(img_path)[1])) 115 | seg = np.array(Image.open(img_path)) 116 | seg_rgb = np.zeros(seg.shape + (3,), dtype=np.uint8) 117 | for col_idx in range(len(idx_to_color)): 118 | if idx_to_color[col_idx][0] != -1: 119 | mask = seg == col_idx 120 | 
seg_rgb[mask, :] = idx_to_color[col_idx][:] 121 | 122 | parent_dir, filename = osp.split(img_path) 123 | basename, ext = osp.splitext(filename) 124 | out_filename = basename + "_rgb.png" 125 | out_filepath = osp.join(save_dir, out_filename) 126 | # Save colorized segmentation image 127 | Image.fromarray(seg_rgb).save(out_filepath) 128 | 129 | 130 | # The frozen xception model only segments 512x512 images. But it would be better 131 | # to segment the full image instead! 132 | def segment_images(images_path, xception_frozen_graph_path, save_dir, 133 | crop_height=512, crop_width=512): 134 | if not osp.exists(xception_frozen_graph_path): 135 | raise OSError('Xception frozen graph not found at %s' % 136 | xception_frozen_graph_path) 137 | with tf.gfile.GFile(xception_frozen_graph_path, "rb") as f: 138 | graph_def = tf.GraphDef() 139 | graph_def.ParseFromString(f.read()) 140 | 141 | with tf.Graph().as_default() as graph: 142 | new_input = tf.placeholder(tf.uint8, [1, crop_height, crop_width, 3], 143 | name="new_input") 144 | tf.import_graph_def( 145 | graph_def, 146 | input_map={"ImageTensor:0": new_input}, 147 | return_elements=None, 148 | name="sem_seg", 149 | op_dict=None, 150 | producer_op_list=None 151 | ) 152 | 153 | corrupted_dir = osp.join(save_dir, 'corrupted') 154 | if not osp.exists(corrupted_dir): 155 | os.makedirs(corrupted_dir) 156 | with tf.Session(graph=graph) as sess: 157 | for i, img_path in enumerate(images_path): 158 | print('Segmenting image %05d / %05d: %s' % (i + 1, len(images_path), 159 | img_path)) 160 | img = np.array(Image.open(img_path)) 161 | if len(img.shape) == 2 or img.shape[2] != 3: 162 | print('Warning! corrupted image %s' % img_path) 163 | img_base_path = img_path[:-14] # remove the '_reference.png' suffix 164 | srcs = sorted(glob.glob(img_base_path + '_*')) 165 | dest = corrupted_dir + '/.' 166 | for src in srcs: 167 | shutil.move(src, dest) 168 | continue 169 | img = utils.get_central_crop(img, crop_height=crop_height, 170 | crop_width=crop_width) 171 | img = np.expand_dims(img, 0) # convert to NHWC format 172 | seg = sess.run("sem_seg/SemanticPredictions:0", feed_dict={ 173 | new_input: img}) 174 | assert np.max(seg[:]) <= 255, 'segmentation image is not of type uint8!' 175 | seg = np.squeeze(np.uint8(seg)) # convert to uint8 and squeeze to WxH.
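      # The raw label map is saved below as a single-channel uint8 PNG;
      # _apply_colors() later maps these ADE20K class indices to the
      # Cityscapes-style palette defined in get_semantic_color_coding().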
176 | parent_dir, filename = osp.split(img_path) 177 | basename, ext = osp.splitext(filename) 178 | basename = basename[:-10] # remove the '_reference' suffix 179 | seg_filename = basename + "_seg.png" 180 | seg_filepath = osp.join(save_dir, seg_filename) 181 | # Save segmentation image 182 | Image.fromarray(seg).save(seg_filepath) 183 | 184 | def segment_and_color_dataset(dataset_dir, xception_frozen_graph_path, 185 | splits=None, resegment_images=True): 186 | if splits is None: 187 | imgs_dirs = [dataset_dir] 188 | else: 189 | imgs_dirs = [osp.join(dataset_dir, split) for split in splits] 190 | 191 | for cur_dir in imgs_dirs: 192 | imgs_file_pattern = osp.join(cur_dir, '*_reference.png') 193 | images_path = sorted(glob.glob(imgs_file_pattern)) 194 | if resegment_images: 195 | segment_images(images_path, xception_frozen_graph_path, cur_dir, 196 | crop_height=512, crop_width=512) 197 | 198 | idx_to_col = get_semantic_color_coding() 199 | 200 | for cur_dir in imgs_dirs: 201 | save_dir = cur_dir 202 | seg_file_pattern = osp.join(cur_dir, '*_seg.png') 203 | seg_imgs_paths = sorted(glob.glob(seg_file_pattern)) 204 | _apply_colors(seg_imgs_paths, save_dir, idx_to_col) 205 | -------------------------------------------------------------------------------- /staged_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Neural re-rendering in the wild. 16 | 17 | Implementation of the staged training pipeline. 18 | """ 19 | 20 | from options import FLAGS as opts 21 | import losses 22 | import networks 23 | import tensorflow as tf 24 | import utils 25 | 26 | 27 | def create_computation_graph(x_in, x_gt, x_app=None, arch_type='pggan', 28 | use_appearance=True): 29 | """Create the models and the losses. 30 | 31 | Args: 32 | x_in: 4D tensor, batch of conditional input images in NHWC format. 33 | x_gt: 4D tensor, batch of ground-truth images in NHWC format. 34 | x_app: 4D tensor, batch of input appearance images. 35 | 36 | Returns: 37 | Dictionary of placeholders and TF graph functions.
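
  Note: arch_type and use_appearance are forwarded to networks.RenderingModel
    to select the generator architecture (default 'pggan') and whether an
    appearance encoder is attached to it.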
38 | """ 39 | # --------------------------------------------------------------------------- 40 | # Build models/networks 41 | # --------------------------------------------------------------------------- 42 | 43 | rerenderer = networks.RenderingModel(arch_type, use_appearance) 44 | app_enc = rerenderer.get_appearance_encoder() 45 | discriminator = networks.MultiScaleDiscriminator( 46 | 'd_model', opts.appearance_nc, num_scales=3, nf=64, n_layers=3, 47 | get_fmaps=False) 48 | 49 | # --------------------------------------------------------------------------- 50 | # Forward pass 51 | # --------------------------------------------------------------------------- 52 | 53 | if opts.use_appearance: 54 | z_app, _, _ = app_enc(x_app) 55 | else: 56 | z_app = None 57 | 58 | y = rerenderer(x_in, z_app) 59 | 60 | # --------------------------------------------------------------------------- 61 | # Losses 62 | # --------------------------------------------------------------------------- 63 | 64 | w_loss_gan = opts.w_loss_gan 65 | w_loss_recon = opts.w_loss_vgg if opts.use_vgg_loss else opts.w_loss_l1 66 | 67 | # compute discriminator logits 68 | disc_real_featmaps = discriminator(x_gt, x_in) 69 | disc_fake_featmaps = discriminator(y, x_in) 70 | 71 | # discriminator loss 72 | loss_d_real = losses.multiscale_discriminator_loss(disc_real_featmaps, True) 73 | loss_d_fake = losses.multiscale_discriminator_loss(disc_fake_featmaps, False) 74 | loss_d = loss_d_real + loss_d_fake 75 | 76 | # generator loss 77 | loss_g_gan = losses.multiscale_discriminator_loss(disc_fake_featmaps, True) 78 | if opts.use_vgg_loss: 79 | vgg_layers = ['conv%d_2' % i for i in range(1, 6)] # conv1 through conv5 80 | vgg_layer_weights = [1./32, 1./16, 1./8, 1./4, 1.] 81 | vgg_loss = losses.PerceptualLoss(y, x_gt, [256, 256, 3], vgg_layers, 82 | vgg_layer_weights) # NOTE: shouldn't hardcode image size! 
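    # The perceptual loss compares VGG16 activations of the output y and the
    # ground truth x_gt at conv1_2 through conv5_2, weighted by
    # vgg_layer_weights (1/32, 1/16, 1/8, 1/4 and 1, respectively).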
83 | loss_g_recon = vgg_loss() 84 | else: 85 | loss_g_recon = losses.L1_loss(y, x_gt) 86 | loss_g = w_loss_gan * loss_g_gan + w_loss_recon * loss_g_recon 87 | 88 | # --------------------------------------------------------------------------- 89 | # Tensorboard visualizations 90 | # --------------------------------------------------------------------------- 91 | 92 | x_in_render = tf.slice(x_in, [0, 0, 0, 0], [-1, -1, -1, 3]) 93 | if opts.use_semantic: 94 | x_in_semantic = tf.slice(x_in, [0, 0, 0, 4], [-1, -1, -1, 3]) 95 | tb_visualization = tf.concat([x_in_render, x_in_semantic, y, x_gt], axis=2) 96 | else: 97 | tb_visualization = tf.concat([x_in_render, y, x_gt], axis=2) 98 | tf.summary.image('rendered-semantic-generated-gt tuple', tb_visualization) 99 | 100 | # Show input appearance images 101 | if opts.use_appearance: 102 | x_app_rgb = tf.slice(x_app, [0, 0, 0, 0], [-1, -1, -1, 3]) 103 | x_app_sem = tf.slice(x_app, [0, 0, 0, 7], [-1, -1, -1, -1]) 104 | tb_app_visualization = tf.concat([x_app_rgb, x_app_sem], axis=2) 105 | tf.summary.image('input appearance image', tb_app_visualization) 106 | 107 | # Loss summaries 108 | with tf.name_scope('Discriminator_Loss'): 109 | tf.summary.scalar('D_real_loss', loss_d_real) 110 | tf.summary.scalar('D_fake_loss', loss_d_fake) 111 | tf.summary.scalar('D_total_loss', loss_d) 112 | with tf.name_scope('Generator_Loss'): 113 | tf.summary.scalar('G_GAN_loss', w_loss_gan * loss_g_gan) 114 | tf.summary.scalar('G_reconstruction_loss', w_loss_recon * loss_g_recon) 115 | tf.summary.scalar('G_total_loss', loss_g) 116 | 117 | # --------------------------------------------------------------------------- 118 | # Optimizers 119 | # --------------------------------------------------------------------------- 120 | 121 | def get_optimizer(lr, loss, var_list): 122 | optimizer = tf.train.AdamOptimizer(lr, opts.adam_beta1, opts.adam_beta2) 123 | # optimizer = tf.contrib.estimator.TowerOptimizer(optimizer) 124 | return optimizer.minimize(loss, var_list=var_list) 125 | 126 | # Training ops. 
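  # The discriminator and generator get separate Adam optimizers (d_lr and
  # g_lr), and an optional third optimizer (ez_lr) fine-tunes the appearance
  # encoder when --train_app_encoder is set. An exponential moving average of
  # the generator (and, if used, appearance encoder) variables is maintained
  # for inference.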
127 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 128 | with tf.control_dependencies(update_ops): 129 | with tf.variable_scope('optimizers'): 130 | d_vars = utils.model_vars('d_model')[0] 131 | g_vars_all = utils.model_vars('g_model')[0] 132 | train_d = [get_optimizer(opts.d_lr, loss_d, d_vars)] 133 | train_g = [get_optimizer(opts.g_lr, loss_g, g_vars_all)] 134 | 135 | train_app_encoder = [] 136 | if opts.train_app_encoder: 137 | lr_app = opts.ez_lr 138 | app_enc_vars = utils.model_vars('appearance_net')[0] 139 | train_app_encoder.append(get_optimizer(lr_app, loss_g, app_enc_vars)) 140 | 141 | ema = tf.train.ExponentialMovingAverage(decay=0.999) 142 | with tf.control_dependencies(train_g + train_app_encoder): 143 | inference_vars_all = g_vars_all 144 | if opts.use_appearance: 145 | app_enc_vars = utils.model_vars('appearance_net')[0] 146 | inference_vars_all += app_enc_vars 147 | ema_op = ema.apply(inference_vars_all) 148 | 149 | print('***************************************************') 150 | print('len(g_vars_all) = %d' % len(g_vars_all)) 151 | for ii, v in enumerate(g_vars_all): 152 | print('%03d) %s' % (ii, str(v))) 153 | print('-------------------------------------------------------') 154 | print('len(d_vars) = %d' % len(d_vars)) 155 | for ii, v in enumerate(d_vars): 156 | print('%03d) %s' % (ii, str(v))) 157 | if opts.train_app_encoder: 158 | print('-------------------------------------------------------') 159 | print('len(app_enc_vars) = %d' % len(app_enc_vars)) 160 | for ii, v in enumerate(app_enc_vars): 161 | print('%03d) %s' % (ii, str(v))) 162 | print('***************************************************\n\n') 163 | 164 | return { 165 | 'train_disc_op': tf.group(train_d), 166 | 'train_renderer_op': ema_op, 167 | 'total_loss_d': loss_d, 168 | 'loss_d_real': loss_d_real, 169 | 'loss_d_fake': loss_d_fake, 170 | 'loss_g_gan': w_loss_gan * loss_g_gan, 171 | 'loss_g_recon': w_loss_recon * loss_g_recon, 172 | 'total_loss_g': loss_g} 173 | -------------------------------------------------------------------------------- /style_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
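# Computes a pairwise "style distance" matrix over a set of images; it is used
# by pretrain_appearance.py to mine triplets for the appearance encoder. The
# distance between two images is the mean squared difference between their
# VGG16 Gram matrices: for a layer with activations reshaped to F of shape
# [H*W, C], the Gram matrix is roughly G = F^T F / (C*H*W) (see gram_matrix()
# below).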
14 | 15 | from PIL import Image 16 | from options import FLAGS as opts 17 | import data 18 | import layers 19 | import numpy as np 20 | import tensorflow as tf 21 | import utils 22 | import vgg16 23 | 24 | 25 | def gram_matrix(layer): 26 | """Computes the gram_matrix for a batch of single vgg layer 27 | Input: 28 | layer: a batch of vgg activations for a single conv layer 29 | Returns: 30 | gram: [batch_sz x num_channels x num_channels]: a batch of gram matrices 31 | """ 32 | batch_size, height, width, num_channels = layer.get_shape().as_list() 33 | features = tf.reshape(layer, [batch_size, height * width, num_channels]) 34 | num_elements = tf.constant(num_channels * height * width, tf.float32) 35 | gram = tf.matmul(features, features, adjoint_a=True) / num_elements 36 | return gram 37 | 38 | 39 | def compute_gram_matrices( 40 | images, vgg_layers=['conv1_2', 'conv2_2', 'conv3_2', 'conv4_2', 'conv5_2']): 41 | """Computes the gram matrix representation of a batch of images""" 42 | vgg_net = vgg16.Vgg16(opts.vgg16_path) 43 | vgg_acts = vgg_net.get_vgg_activations(images, vgg_layers) 44 | grams = [gram_matrix(layer) for layer in vgg_acts] 45 | return grams 46 | 47 | 48 | def compute_pairwise_style_loss_v2(image_paths_list): 49 | grams_all = [None] * len(image_paths_list) 50 | crop_height, crop_width = opts.train_resolution, opts.train_resolution 51 | img_var = tf.placeholder(tf.float32, shape=[1, crop_height, crop_width, 3]) 52 | vgg_layers = ['conv%d_2' % i for i in range(1, 6)] # conv1 through conv5 53 | grams_ops = compute_gram_matrices(img_var, vgg_layers) 54 | with tf.Session() as sess: 55 | for ii, img_path in enumerate(image_paths_list): 56 | print('Computing gram matrices for image #%d' % (ii + 1)) 57 | img = np.array(Image.open(img_path), dtype=np.float32) 58 | img = img * 2. / 255. - 1 # normalize image 59 | img = utils.get_central_crop(img, crop_height, crop_width) 60 | img = np.expand_dims(img, axis=0) 61 | grams_all[ii] = sess.run(grams_ops, feed_dict={img_var: img}) 62 | print('Number of images = %d' % len(grams_all)) 63 | print('Gram matrices per image:') 64 | for i in range(len(grams_all[0])): 65 | print('gram_matrix[%d].shape = %s' % (i, grams_all[0][i].shape)) 66 | n_imgs = len(grams_all) 67 | dist_matrix = np.zeros((n_imgs, n_imgs)) 68 | for i in range(n_imgs): 69 | print('Computing distances for image #%d' % i) 70 | for j in range(i + 1, n_imgs): 71 | loss_style = 0 72 | # Compute loss using all gram matrices from all layers 73 | for gram_i, gram_j in zip(grams_all[i], grams_all[j]): 74 | loss_style += np.mean((gram_i - gram_j) ** 2, axis=(1, 2)) 75 | dist_matrix[i][j] = dist_matrix[j][i] = loss_style 76 | 77 | return dist_matrix 78 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #     https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for GANs. 
16 | 17 | Basic functions such as generating sample grid, exporting to PNG, etc... 18 | """ 19 | 20 | import functools 21 | import numpy as np 22 | import os.path 23 | import tensorflow as tf 24 | import time 25 | 26 | 27 | def crop_to_multiple(img, size_multiple=64): 28 | """ Crops the image so that its dimensions are multiples of size_multiple.""" 29 | new_width = (img.shape[1] // size_multiple) * size_multiple 30 | new_height = (img.shape[0] // size_multiple) * size_multiple 31 | offset_x = (img.shape[1] - new_width) // 2 32 | offset_y = (img.shape[0] - new_height) // 2 33 | return img[offset_y:offset_y + new_height, offset_x:offset_x + new_width, :] 34 | 35 | 36 | def get_central_crop(img, crop_height=512, crop_width=512): 37 | if len(img.shape) == 2: 38 | img = np.expand_dims(img, axis=2) 39 | assert len(img.shape) == 3, ('input image should be either a 2D or 3D matrix,' 40 | ' but input was of shape %s' % str(img.shape)) 41 | height, width, _ = img.shape 42 | assert height >= crop_height and width >= crop_width, ('input image cannot ' 43 | 'be smaller than the requested crop size') 44 | st_y = (height - crop_height) // 2 45 | st_x = (width - crop_width) // 2 46 | return np.squeeze(img[st_y : st_y + crop_height, st_x : st_x + crop_width, :]) 47 | 48 | 49 | def load_global_step_from_checkpoint_dir(checkpoint_dir): 50 | """Loads the global step from the checkpoint directory. 51 | 52 | Args: 53 | checkpoint_dir: string, path to the checkpoint directory. 54 | 55 | Returns: 56 | int, the global step of the latest checkpoint or 0 if none was found. 57 | """ 58 | try: 59 | checkpoint_reader = tf.train.NewCheckpointReader( 60 | tf.train.latest_checkpoint(checkpoint_dir)) 61 | return checkpoint_reader.get_tensor(tf.GraphKeys.GLOBAL_STEP) 62 | except: 63 | return 0 64 | 65 | 66 | def model_vars(prefix): 67 | """Return trainable variables matching a prefix. 68 | 69 | Args: 70 | prefix: string, the prefix variable names must match. 71 | 72 | Returns: 73 | a tuple (match, others) of TF variables, 'match' contains the matched 74 | variables and 'others' contains the remaining variables. 75 | """ 76 | match, no_match = [], [] 77 | for x in tf.trainable_variables(): 78 | if x.name.startswith(prefix): 79 | match.append(x) 80 | else: 81 | no_match.append(x) 82 | return match, no_match 83 | 84 | 85 | def to_png(x): 86 | """Convert a 3D tensor to png. 87 | 88 | Args: 89 | x: Tensor, 01C formatted input image. 90 | 91 | Returns: 92 | Tensor, 1D string representing the image in png format. 93 | """ 94 | with tf.Graph().as_default(): 95 | with tf.Session() as sess_temp: 96 | x = tf.constant(x) 97 | y = tf.image.encode_png( 98 | tf.cast( 99 | tf.clip_by_value(tf.round(127.5 + 127.5 * x), 0, 255), tf.uint8), 100 | compression=9) 101 | return sess_temp.run(y) 102 | 103 | 104 | def images_to_grid(images): 105 | """Converts a grid of images (5D tensor) to a single image. 106 | 107 | Args: 108 | images: 5D tensor (count_y, count_x, height, width, colors), grid of images. 109 | 110 | Returns: 111 | a 3D tensor image of shape (count_y * height, count_x * width, colors). 112 | """ 113 | ny, nx, h, w, c = images.shape 114 | images = images.transpose(0, 2, 1, 3, 4) 115 | images = images.reshape([ny * h, nx * w, c]) 116 | return images 117 | 118 | 119 | def save_images(image, output_dir, cur_nimg): 120 | """Saves images to disk. 121 | 122 | Saves a file called 'name.png' containing the latest samples from the 123 | generator and a file called 'name_123.png' where 123 is the KiB of trained 124 | images. 
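  (The numeric suffix is cur_nimg >> 10, i.e. the number of training images
  seen divided by 1024.)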
125 | 126 | Args: 127 | image: 3D numpy array (height, width, colors), the image to save. 128 | output_dir: string, the directory where to save the image. 129 | cur_nimg: int, current number of images seen by training. 130 | 131 | Returns: 132 | None 133 | """ 134 | for name in ('name.png', 'name_%06d.png' % (cur_nimg >> 10)): 135 | with tf.gfile.Open(os.path.join(output_dir, name), 'wb') as f: 136 | f.write(image) 137 | 138 | 139 | class HookReport(tf.train.SessionRunHook): 140 | """Custom reporting hook. 141 | 142 | Register your tensor scalars with HookReport.log_tensor(my_tensor, 'my_name'). 143 | This hook will report their average values over report period argument 144 | provided to the constructed. The values are printed in the order the tensors 145 | were registered. 146 | 147 | Attributes: 148 | step: int, the current global step. 149 | active: bool, whether logging is active or disabled. 150 | """ 151 | _REPORT_KEY = 'report' 152 | _TENSOR_NAMES = {} 153 | 154 | def __init__(self, period, batch_size): 155 | self.step = 0 156 | self.active = True 157 | self._period = period // batch_size 158 | self._batch_size = batch_size 159 | self._sums = np.array([]) 160 | self._count = 0 161 | self._nimgs_per_cycle = 0 162 | self._step_ratio = 0 163 | self._start = time.time() 164 | self._nimgs = 0 165 | self._batch_size = batch_size 166 | 167 | def disable(self): 168 | parent = self 169 | 170 | class Disabler(object): 171 | 172 | def __enter__(self): 173 | parent.active = False 174 | return parent 175 | 176 | def __exit__(self, exc_type, exc_val, exc_tb): 177 | parent.active = True 178 | 179 | return Disabler() 180 | 181 | def begin(self): 182 | self.active = True 183 | self._count = 0 184 | self._nimgs_per_cycle = 0 185 | self._start = time.time() 186 | 187 | def before_run(self, run_context): 188 | if not self.active: 189 | return 190 | del run_context 191 | fetches = tf.get_collection(self._REPORT_KEY) 192 | return tf.train.SessionRunArgs(fetches) 193 | 194 | def after_run(self, run_context, run_values): 195 | if not self.active: 196 | return 197 | del run_context 198 | results = run_values.results 199 | # Note: sometimes the returned step is incorrect (off by one) for some 200 | # unknown reason. 201 | self.step = results[-1] + 1 202 | self._count += 1 203 | self._nimgs_per_cycle += self._batch_size 204 | self._nimgs += self._batch_size 205 | 206 | if not self._sums.size: 207 | self._sums = np.array(results[:-1], 'd') 208 | else: 209 | self._sums += np.array(results[:-1], 'd') 210 | 211 | if self.step // self._period != self._step_ratio: 212 | fetches = tf.get_collection(self._REPORT_KEY)[:-1] 213 | stats = ' '.join('%s=% .2f' % (self._TENSOR_NAMES[tensor], 214 | value / self._count) 215 | for tensor, value in zip(fetches, self._sums)) 216 | stop = time.time() 217 | tf.logging.info('step=%d, kimg=%d %s [%.2f img/s]' % 218 | (self.step, ((self.step * self._batch_size) >> 10), 219 | stats, self._nimgs_per_cycle / (stop - self._start))) 220 | self._step_ratio = self.step // self._period 221 | self._start = stop 222 | self._sums *= 0 223 | self._count = 0 224 | self._nimgs_per_cycle = 0 225 | 226 | def end(self, session=None): 227 | del session 228 | 229 | @classmethod 230 | def log_tensor(cls, tensor, name): 231 | """Adds a tensor to be reported by the hook. 232 | 233 | Args: 234 | tensor: `tensor scalar`, a value to report. 235 | name: string, the name to give the value in the report. 236 | 237 | Returns: 238 | None. 
239 | """ 240 | cls._TENSOR_NAMES[tensor] = name 241 | tf.add_to_collection(cls._REPORT_KEY, tensor) 242 | --------------------------------------------------------------------------------