├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── data.py
├── dataset_utils.py
├── evaluate_quantitative_metrics.py
├── imgs
│   ├── app_interpolation.jpg
│   ├── app_variation.jpg
│   └── teaser_with_caption.jpg
├── layers.py
├── losses.py
├── networks.py
├── neural_rerendering.py
├── options.py
├── pretrain_appearance.py
├── segment_dataset.py
├── staged_model.py
├── style_loss.py
└── utils.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Rerendering in the Wild
2 | Moustafa Meshry<sup>1</sup>,
3 | [Dan B Goldman](http://www.danbgoldman.com/)<sup>2</sup>,
4 | [Sameh Khamis](http://www.samehkhamis.com/)<sup>2</sup>,
5 | [Hugues Hoppe](http://hhoppe.com/)<sup>2</sup>,
6 | Rohit Pandey<sup>2</sup>,
7 | [Noah Snavely](http://www.cs.cornell.edu/~snavely/)<sup>2</sup>,
8 | [Ricardo Martin-Brualla](http://www.ricardomartinbrualla.com/)<sup>2</sup>.
9 |
10 | <sup>1</sup>University of Maryland, College Park &nbsp;&nbsp; <sup>2</sup>Google Inc.
11 |
12 | To appear at CVPR 2019 (Oral).
13 |
14 | ![Teaser](imgs/teaser_with_caption.jpg)
15 |
21 | We will provide a TensorFlow implementation and pretrained models for our paper soon.
22 |
23 | [**Paper**](https://arxiv.org/abs/1904.04290) | [**Video**](https://www.youtube.com/watch?v=E1crWQn_kmY) | [**Code**](https://github.com/MoustafaMeshry/neural_rerendering_in_the_wild) | [**Project page**](https://moustafameshry.github.io/neural_rerendering_in_the_wild/)
24 |
25 | ### Abstract
26 |
27 | We explore total scene capture — recording, modeling, and rerendering a scene under varying appearance such as season and time of day.
28 | Starting from internet photos of a tourist landmark, we apply traditional 3D reconstruction to register the photos and approximate the scene as a point cloud.
29 | For each photo, we render the scene points into a deep framebuffer,
30 | and train a neural network to learn the mapping of these initial renderings to the actual photos.
31 | This rerendering network also takes as input a latent appearance vector and a semantic mask indicating the location of transient objects like pedestrians.
32 | The model is evaluated on several datasets of publicly available images spanning a broad range of illumination conditions.
33 | We create short videos demonstrating realistic manipulation of the image viewpoint, appearance, and semantic labeling.
34 | We also compare results with prior work on scene reconstruction from internet photos.
35 |
36 | ### Video
37 | [Watch the video](https://www.youtube.com/watch?v=E1crWQn_kmY)
38 |
39 |
40 | ### Appearance variation
41 |
42 | We capture the appearance of the original images in the left column and rerender several viewpoints under each of those appearances. The last column is a detail of the previous one. The top row shows the renderings that are part of the input to the rerenderer; they exhibit artifacts such as incomplete features in the statue and an inconsistent mix of day and night appearances. Note the hallucinated twilight sky produced by the last appearance. Image credits: Flickr users William Warby, Neil Rickards, Rafael Jimenez, acme401 (Creative Commons).
43 |
44 | ![Appearance variation](imgs/app_variation.jpg)
45 |
48 | ### Appearance interpolation
49 | Frames from a synthesized camera path that transitions from the photo on the left to the photo on the right by smoothly interpolating both the viewpoint and the latent appearance vectors. Please see the supplementary video. Photo credits: Allie Caulfield, Tahbepet, Till Westermayer, Elliott Brown (Creative Commons).
50 |
51 | ![Appearance interpolation](imgs/app_interpolation.jpg)
52 |
53 |
54 | ### Acknowledgements
55 | We thank Gregory Blascovich for his help in conducting the user study, and Johannes Schönberger and True Price for their help generating datasets.
56 |
57 | ### Run and train instructions
58 |
59 | Staged training consists of three stages:
60 |
61 | - Pretraining the appearance network.
62 | - Training the rendering network while fixing the weights for the appearance
63 | network.
64 | - Finetuning both the appearance and the rendering networks.
65 |
66 | ### Aligned dataset preprocessing
67 |
68 | #### Manual preparation
69 |
70 | * Set a path to a base_dir that contains the source code:
71 |
72 | ```
73 | base_dir=/path/to/neural_rendering
74 | mkdir $base_dir
75 | cd $base_dir
76 | ```
77 |
78 | * We assume the following format for an aligned dataset (a small validation sketch follows this list):
79 |   * Each training example consists of 3 files with the following naming format:
80 |     * real image: %04d_reference.png
81 |     * render color: %04d_color.png
82 |     * render depth: %04d_depth.png
83 | * Set dataset name: e.g.
84 | ```
85 | dataset_name='trevi3k' # set to any name
86 | ```
87 | * Split the dataset into train and validation sets in two subdirectories:
88 | * $base_dir/datasets/$dataset_name/train
89 | * $base_dir/datasets/$dataset_name/val
90 | * Download the DeepLab semantic segmentation model trained on the ADE20K
91 | dataset from this link:
92 | http://download.tensorflow.org/models/deeplabv3_xception_ade20k_train_2018_05_29.tar.gz
93 | * Unzip the downloaded file to: $base_dir/deeplabv3_xception_ade20k_train
94 | * Download this [file](https://github.com/MoustafaMeshry/vgg_loss/blob/master/vgg16.py), which implements a VGG-based perceptual loss, and save it alongside the source code (it is imported as `vgg16` by losses.py).
95 | * Download trained weights for the VGG network as instructed at this link: https://github.com/machrisaa/tensorflow-vgg
96 | * Save the VGG weights to $base_dir/vgg16_weights/vgg16.npy
97 |
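The following is a minimal sketch, not part of the original pipeline, that checks the naming convention above: it verifies that every `%04d_color.png` render has matching `_reference.png` and `_depth.png` files. The `check_aligned_dataset` helper and the example path are illustrative only.

```
import glob
import os.path as osp


def check_aligned_dataset(dataset_dir):
  """Warns about frames that are missing their reference or depth image."""
  color_paths = sorted(glob.glob(osp.join(dataset_dir, '*_color.png')))
  print('Found %d rendered frames in %s' % (len(color_paths), dataset_dir))
  for color_path in color_paths:
    prefix = color_path[:-len('color.png')]  # keeps the '%04d_' prefix
    for suffix in ('reference.png', 'depth.png'):
      if not osp.exists(prefix + suffix):
        print('Missing %s for %s' % (suffix, color_path))


# Example usage (hypothetical path):
# check_aligned_dataset('/path/to/datasets/trevi3k/train')
```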
98 |
99 | #### Data preprocessing
100 |
101 | * Run the preprocessing pipeline, which consists of:
102 | * Filtering out sparse renders.
103 | * Semantic segmentation of ground truth images.
104 | * Exporting the dataset to tfrecord format.
105 |
106 | ```
107 | # Run locally
108 | python dataset_utils.py \
109 | --dataset_name=$dataset_name \
110 | --dataset_parent_dir=$base_dir/datasets/$dataset_name \
111 | --output_dir=$base_dir/datasets/$dataset_name \
112 | --xception_frozen_graph_path=$base_dir/deeplabv3_xception_ade20k_train/frozen_inference_graph.pb \
113 | --alsologtostderr
114 | ```
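
The pipeline exports one TFRecord per split to `--output_dir`, named `${dataset_name}_train.tfrecord` and `${dataset_name}_val.tfrecord` (see `dataset_utils.py`). As an optional sanity check, not part of the original code, the sketch below counts the serialized examples in each exported split using the TF 1.x record iterator; the example paths are illustrative.

```
import glob
import os.path as osp
import tensorflow as tf


def count_tfrecord_examples(output_dir, dataset_name):
  """Prints the number of serialized examples in each exported split."""
  pattern = osp.join(output_dir, dataset_name + '*.tfrecord')
  for path in sorted(glob.glob(pattern)):
    num_examples = sum(1 for _ in tf.python_io.tf_record_iterator(path))
    print('%s: %d examples' % (osp.basename(path), num_examples))


# Example usage (hypothetical paths):
# count_tfrecord_examples('/path/to/datasets/trevi3k', 'trevi3k')
```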
115 |
116 | ### Pretraining the appearance encoder network
117 |
118 | ```
119 | # Run locally
120 | python pretrain_appearance.py \
121 | --dataset_name=$dataset_name \
122 | --train_dir=$base_dir/train_models/$dataset_name-app_pretrain \
123 | --imageset_dir=$base_dir/datasets/$dataset_name/train \
124 | --train_resolution=512 \
125 | --metadata_output_dir=$base_dir/datasets/$dataset_name
126 | ```
127 |
128 | ### Training the rerendering network with a fixed appearance encoder
129 |
130 | Set the dataset_parent_dir variable below to point to the directory containing
131 | the generated TFRecords.
132 |
133 | ```
134 | # Run locally:
135 | dataset_parent_dir=$base_dir/datasets/$dataset_name
136 | train_dir=$base_dir/train_models/$dataset_name-staged-fixed_appearance
137 | load_pretrained_app_encoder=true
138 | appearance_pretrain_dir=$base_dir/train_models/$dataset_name-app_pretrain
139 | load_from_another_ckpt=false
140 | fixed_appearance_train_dir=''
141 | train_app_encoder=false
142 |
143 | python neural_rerendering.py \
144 | --dataset_name=$dataset_name \
145 | --dataset_parent_dir=$dataset_parent_dir \
146 | --train_dir=$train_dir \
147 | --load_pretrained_app_encoder=$load_pretrained_app_encoder \
148 | --appearance_pretrain_dir=$appearance_pretrain_dir \
149 | --train_app_encoder=$train_app_encoder \
150 | --load_from_another_ckpt=$load_from_another_ckpt \
151 | --fixed_appearance_train_dir=$fixed_appearance_train_dir \
152 | --total_kimg=4000
153 | ```
154 |
155 | ### Finetuning the rerendering network and the appearance encoder
156 |
157 | Set the fixed_appearance_train_dir to the train directory from the previous
158 | step.
159 |
160 | ```
161 | # Run locally:
162 | dataset_parent_dir=$base_dir/datasets/$dataset_name
163 | train_dir=$base_dir/train_models/$dataset_name-staged-finetune_appearance
164 | load_pretrained_app_encoder=false
165 | appearance_pretrain_dir=''
166 | load_from_another_ckpt=true
167 | fixed_appearance_train_dir=$base_dir/train_models/$dataset_name-staged-fixed_appearance
168 | train_app_encoder=true
169 |
170 | python neural_rerendering.py \
171 | --dataset_name=$dataset_name \
172 | --dataset_parent_dir=$dataset_parent_dir \
173 | --train_dir=$train_dir \
174 | --load_pretrained_app_encoder=$load_pretrained_app_encoder \
175 | --appearance_pretrain_dir=$appearance_pretrain_dir \
176 | --train_app_encoder=$train_app_encoder \
177 | --load_from_another_ckpt=$load_from_another_ckpt \
178 | --fixed_appearance_train_dir=$fixed_appearance_train_dir \
179 | --total_kimg=4000
180 | ```
181 |
182 |
183 | ### Evaluate model on validation set
184 |
185 | ```
186 | experiment_title=$dataset_name-staged-finetune_appearance
187 | local_train_dir=$base_dir/train_models/$experiment_title
188 | dataset_parent_dir=$base_dir/datasets/$dataset_name
189 | val_set_out_dir=$local_train_dir/val_set_output
190 |
191 | # Run the model on validation set
192 | echo "Evaluating the validation set"
193 | python neural_rerendering.py \
194 | --train_dir=$local_train_dir \
195 | --dataset_name=$dataset_name \
196 | --dataset_parent_dir=$dataset_parent_dir \
197 | --run_mode='eval_subset' \
198 | --virtual_seq_name='val' \
199 | --output_validation_dir=$val_set_out_dir \
200 | --logtostderr
201 | # Evaluate quantitative metrics
202 | python evaluate_quantitative_metrics.py \
203 | --val_set_out_dir=$val_set_out_dir \
204 | --experiment_title=$experiment_title \
205 | --logtostderr
206 | ```
207 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from options import FLAGS as opts
16 | import functools
17 | import glob
18 | import numpy as np
19 | import os.path as osp
20 | import random
21 | import tensorflow as tf
22 |
23 |
24 | def provide_data(dataset_name='', parent_dir='', batch_size=8, subset=None,
25 | max_examples=None, crop_flag=False, crop_size=256, seeds=None,
26 | use_appearance=True, shuffle=128):
27 | # Parsing function for each tfrecord example.
28 | record_parse_fn = functools.partial(
29 | _parser_rendered_dataset, crop_flag=crop_flag, crop_size=crop_size,
30 | use_alpha=opts.use_alpha, use_depth=opts.use_depth,
31 | use_semantics=opts.use_semantic, seeds=seeds,
32 | use_appearance=use_appearance)
33 |
34 | input_dict_var = multi_input_fn_record(
35 | record_parse_fn, parent_dir, dataset_name, batch_size,
36 | subset=subset, max_examples=max_examples, shuffle=shuffle)
37 | return input_dict_var
38 |
39 |
40 | def _parser_rendered_dataset(
41 | serialized_example, crop_flag, crop_size, seeds, use_alpha, use_depth,
42 | use_semantics, use_appearance):
43 | """
44 | Parses a single tf.Example into a features dictionary with input tensors.
45 | """
46 | # Structure of features_dict need to match the dictionary structure that was
47 | # serialized to a tf.Example
48 | features_dict = {'height': tf.FixedLenFeature([], tf.int64),
49 | 'width': tf.FixedLenFeature([], tf.int64),
50 | 'rendered': tf.FixedLenFeature([], tf.string),
51 | 'depth': tf.FixedLenFeature([], tf.string),
52 | 'real': tf.FixedLenFeature([], tf.string),
53 | 'seg': tf.FixedLenFeature([], tf.string)}
54 | features = tf.parse_single_example(serialized_example, features=features_dict)
55 | height = tf.cast(features['height'], tf.int32)
56 | width = tf.cast(features['width'], tf.int32)
57 |
58 | # Parse the rendered image.
59 | rendered = tf.decode_raw(features['rendered'], tf.uint8)
60 | rendered = tf.cast(rendered, tf.float32) * (2.0 / 255) - 1.0
61 | rendered = tf.reshape(rendered, [height, width, 4])
62 | if not use_alpha:
63 | rendered = tf.slice(rendered, [0, 0, 0], [height, width, 3])
64 | conditional_input = rendered
65 |
66 | # Parse the depth image.
67 | if use_depth:
68 | depth = tf.decode_raw(features['depth'], tf.uint16)
69 | depth = tf.reshape(depth, [height, width, 1])
70 | depth = tf.cast(depth, tf.float32) * (2.0 / 255) - 1.0
71 | conditional_input = tf.concat([conditional_input, depth], axis=-1)
72 |
73 | # Parse the semantic map.
74 | if use_semantics:
75 | seg_img = tf.decode_raw(features['seg'], tf.uint8)
76 | seg_img = tf.reshape(seg_img, [height, width, 3])
77 | seg_img = tf.cast(seg_img, tf.float32) * (2.0 / 255) - 1
78 | conditional_input = tf.concat([conditional_input, seg_img], axis=-1)
79 |
80 | # Verify that the parsed input has the correct number of channels.
81 | assert conditional_input.shape[-1] == opts.deep_buffer_nc, ('num channels '
82 | 'in the parsed input doesn\'t match num input channels specified in '
83 | 'opts.deep_buffer_nc!')
84 |
85 | # Parse the ground truth image.
86 | real = tf.decode_raw(features['real'], tf.uint8)
87 | real = tf.cast(real, tf.float32) * (2.0 / 255) - 1.0
88 | real = tf.reshape(real, [height, width, 3])
89 |
90 | # Parse the appearance image (if any).
91 | appearance_input = []
92 | if use_appearance:
93 | # Concatenate the deep buffer to the real image.
94 | appearance_input = tf.concat([real, conditional_input], axis=-1)
95 | # Verify that the parsed input has the correct number of channels.
96 | assert appearance_input.shape[-1] == opts.appearance_nc, ('num channels '
97 | 'in the parsed appearance input doesn\'t match num input channels '
98 | 'specified in opts.appearance_nc!')
99 |
100 | # Crop conditional_input and real images, but keep the appearance input
101 | # uncropped (learn a one-to-many mapping from appearance to output)
102 | if crop_flag:
103 | assert crop_size is not None, 'crop_size is not provided!'
104 | if isinstance(crop_size, int):
105 | crop_size = [crop_size, crop_size]
106 | assert len(crop_size) == 2, 'crop_size is either an int or a 2-tuple!'
107 |
108 | # Central crop
109 | if seeds is not None and len(seeds) <= 1:
110 | conditional_input = tf.image.resize_image_with_crop_or_pad(
111 | conditional_input, crop_size[0], crop_size[1])
112 | real = tf.image.resize_image_with_crop_or_pad(real, crop_size[0],
113 | crop_size[1])
114 | else:
115 | if not seeds: # random crops
116 | seed = random.randint(0, (1 << 31) - 1)
117 | else: # fixed crops
118 | seed_idx = random.randint(0, len(seeds) - 1)
119 | seed = seeds[seed_idx]
120 | conditional_input = tf.random_crop(
121 | conditional_input, crop_size + [opts.deep_buffer_nc], seed=seed)
122 | real = tf.random_crop(real, crop_size + [3], seed=seed)
123 |
124 | features = {'conditional_input': conditional_input,
125 | 'expected_output': real,
126 | 'peek_input': appearance_input}
127 | return features
128 |
129 |
130 | def multi_input_fn_record(
131 | record_parse_fn, parent_dir, tfrecord_basename, batch_size, subset=None,
132 | max_examples=None, shuffle=128):
133 | """Creates a Dataset pipeline for tfrecord files.
134 |
135 | Returns:
136 | Dataset iterator.
137 | """
138 | subset_suffix = '*_%s.tfrecord' % subset if subset else '*.tfrecord'
139 | input_pattern = osp.join(parent_dir, tfrecord_basename + subset_suffix)
140 | filenames = sorted(glob.glob(input_pattern))
141 | assert len(filenames) > 0, ('Error! input pattern "%s" didn\'t match any '
142 | 'files' % input_pattern)
143 | dataset = tf.data.TFRecordDataset(filenames)
144 | if shuffle == 0: # keep input deterministic
145 | # use one thread to get deterministic results
146 | dataset = dataset.map(record_parse_fn, num_parallel_calls=None)
147 | else:
148 | dataset = dataset.repeat() # Repeat indefinitely.
149 | dataset = dataset.map(record_parse_fn,
150 | num_parallel_calls=max(4, batch_size // 4))
151 | if opts.training_pipeline == 'drit':
152 | dataset1 = dataset.shuffle(shuffle)
153 | dataset2 = dataset.shuffle(shuffle)
154 | paired_dataset = tf.data.Dataset.zip((dataset1, dataset2))
155 |
156 | def _join_paired_dataset(features_a, features_b):
157 | features_a['conditional_input_2'] = features_b['conditional_input']
158 | features_a['expected_output_2'] = features_b['expected_output']
159 | return features_a
160 |
161 | joined_dataset = paired_dataset.map(_join_paired_dataset)
162 | dataset = joined_dataset
163 | else:
164 | dataset = dataset.shuffle(shuffle)
165 | if max_examples is not None:
166 | dataset = dataset.take(max_examples)
167 | dataset = dataset.batch(batch_size)
168 | if shuffle > 0: # input is not deterministic
169 | dataset = dataset.prefetch(4) # Prefetch a few batches.
170 | return dataset.make_one_shot_iterator().get_next()
171 |
--------------------------------------------------------------------------------
/dataset_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from PIL import Image
16 | from absl import app
17 | from absl import flags
18 | from options import FLAGS as opts
19 | import cv2
20 | import data
21 | import functools
22 | import glob
23 | import numpy as np
24 | import os
25 | import os.path as osp
26 | import shutil
27 | import six
28 | import tensorflow as tf
29 | import segment_dataset as segment_utils
30 | import utils
31 |
32 | FLAGS = flags.FLAGS
33 | flags.DEFINE_string('output_dir', None, 'Directory to save exported tfrecords.')
34 | flags.DEFINE_string('xception_frozen_graph_path', None,
35 | 'Path to the deeplab xception model frozen graph')
36 |
37 |
38 | class AlignedRenderedDataset(object):
39 | def __init__(self, rendered_filepattern, use_semantic_map=True):
40 | """
41 | Args:
42 | rendered_filepattern: string, path filepattern to 3D rendered images (
43 | assumes filenames are '/path/to/dataset/%d_color.png')
44 |       use_semantic_map: bool, whether to include semantic maps in the TFRecord.
45 | """
46 | self.filenames = sorted(glob.glob(rendered_filepattern))
47 | assert len(self.filenames) > 0, ('input %s didn\'t match any files!' %
48 | rendered_filepattern)
49 | self.iter_idx = 0
50 | self.use_semantic_map = use_semantic_map
51 |
52 | def __iter__(self):
53 | return self
54 |
55 | def __next__(self):
56 | return self.next()
57 |
58 | def next(self):
59 | if self.iter_idx < len(self.filenames):
60 | rendered_img_name = self.filenames[self.iter_idx]
61 | basename = rendered_img_name[:-9] # remove the 'color.png' suffix
62 | ref_img_name = basename + 'reference.png'
63 | depth_img_name = basename + 'depth.png'
64 | # Read the 3D rendered image
65 | img_rendered = cv2.imread(rendered_img_name, cv2.IMREAD_UNCHANGED)
66 | # Change BGR (default cv2 format) to RGB
67 | img_rendered = img_rendered[:, :, [2,1,0,3]] # it has a 4th alpha channel
68 | # Read the depth image
69 | img_depth = cv2.imread(depth_img_name, cv2.IMREAD_UNCHANGED)
70 | # Workaround as some depth images are read with a different data type!
71 | img_depth = img_depth.astype(np.uint16)
72 | # Read reference image if exists, otherwise replace with a zero image.
73 | if osp.exists(ref_img_name):
74 | img_ref = cv2.imread(ref_img_name)
75 | img_ref = img_ref[:, :, ::-1] # Change BGR to RGB format.
76 | else: # use a dummy 3-channel zero image as a placeholder
77 | print('Warning: no reference image found! Using a dummy placeholder!')
78 | img_height, img_width = img_depth.shape
79 | img_ref = np.zeros((img_height, img_width, 3), dtype=np.uint8)
80 |
81 | if self.use_semantic_map:
82 | semantic_seg_img_name = basename + 'seg_rgb.png'
83 | img_seg = cv2.imread(semantic_seg_img_name)
84 | img_seg = img_seg[:, :, ::-1] # Change from BGR to RGB
85 | if img_seg.shape[0] == 512 and img_seg.shape[1] == 512:
86 | img_ref = utils.get_central_crop(img_ref)
87 | img_rendered = utils.get_central_crop(img_rendered)
88 | img_depth = utils.get_central_crop(img_depth)
89 |
90 | img_shape = img_depth.shape
91 | assert img_seg.shape == (img_shape + (3,)), 'error in seg image %s %s' % (
92 | basename, str(img_seg.shape))
93 | assert img_ref.shape == (img_shape + (3,)), 'error in ref image %s %s' % (
94 | basename, str(img_ref.shape))
95 | assert img_rendered.shape == (img_shape + (4,)), ('error in rendered '
96 | 'image %s %s' % (basename, str(img_rendered.shape)))
97 | assert len(img_depth.shape) == 2, 'error in depth image %s %s' % (
98 | basename, str(img_depth.shape))
99 |
100 | raw_example = dict()
101 | raw_example['height'] = img_ref.shape[0]
102 | raw_example['width'] = img_ref.shape[1]
103 | raw_example['rendered'] = img_rendered.tostring()
104 | raw_example['depth'] = img_depth.tostring()
105 | raw_example['real'] = img_ref.tostring()
106 | if self.use_semantic_map:
107 | raw_example['seg'] = img_seg.tostring()
108 | self.iter_idx += 1
109 | return raw_example
110 | else:
111 | raise StopIteration()
112 |
113 |
114 | def filter_out_sparse_renders(dataset_dir, splits, ratio_threshold=0.15):
115 | print('Filtering %s' % dataset_dir)
116 | if splits is None:
117 | imgs_dirs = [dataset_dir]
118 | else:
119 | imgs_dirs = [osp.join(dataset_dir, split) for split in splits]
120 |
121 | filtered_images = []
122 | total_images = 0
123 | sum_density = 0
124 | for cur_dir in imgs_dirs:
125 | filtered_dir = osp.join(cur_dir, 'sparse_renders')
126 | if not osp.exists(filtered_dir):
127 | os.makedirs(filtered_dir)
128 | imgs_file_pattern = osp.join(cur_dir, '*_color.png')
129 | images_path = sorted(glob.glob(imgs_file_pattern))
130 | print('Processing %d files' % len(images_path))
131 | total_images += len(images_path)
132 | for ii, img_path in enumerate(images_path):
133 | img = np.array(Image.open(img_path))
134 | aggregate = np.squeeze(np.sum(img, axis=2))
135 | height, width = aggregate.shape
136 | mask = aggregate > 0
137 | density = np.sum(mask) * 1. / (height * width)
138 | sum_density += density
139 | if density <= ratio_threshold:
140 | parent, basename = osp.split(img_path)
141 | basename = basename[:-10] # remove the '_color.png' suffix
142 | srcs = sorted(glob.glob(osp.join(parent, basename + '_*')))
143 |         dest = filtered_dir  # shutil.move moves each file into this directory.
144 | for src in srcs:
145 | shutil.move(src, dest)
146 | filtered_images.append(basename)
147 |         print('Filtered file %d: %s with a density of %.3f' %
148 |               (ii, basename, density))
149 | print('Filtered %d/%d images' % (len(filtered_images), total_images))
150 |   print('Mean density = %.4f' % (sum_density / total_images))
151 |
152 |
153 | def _to_example(dictionary):
154 | """Helper: build tf.Example from (string -> int/float/str list) dictionary."""
155 | features = {}
156 | for (k, v) in six.iteritems(dictionary):
157 | if isinstance(v, six.integer_types):
158 | features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=[v]))
159 | elif isinstance(v, float):
160 | features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=[v]))
161 | elif isinstance(v, six.string_types):
162 | features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[v]))
163 | elif isinstance(v, bytes):
164 | features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[v]))
165 | else:
166 | raise ValueError("Value for %s is not a recognized type; v: %s type: %s" %
167 | (k, str(v[0]), str(type(v[0]))))
168 |
169 | return tf.train.Example(features=tf.train.Features(feature=features))
170 |
171 |
172 | def _generate_tfrecord_dataset(generator,
173 | output_name,
174 | output_dir):
175 | """Convert a dataset into TFRecord format."""
176 |   output_file = os.path.join(output_dir, output_name)
178 | tf.logging.info("Writing TFRecords to file %s", output_file)
179 | writer = tf.python_io.TFRecordWriter(output_file)
180 |
181 | counter = 0
182 | for case in generator:
183 | if counter % 100 == 0:
184 | print('Generating case %d for %s.' % (counter, output_name))
185 | counter += 1
186 | example = _to_example(case)
187 | writer.write(example.SerializeToString())
188 |
189 | writer.close()
190 | return output_file
191 |
192 |
193 | def export_aligned_dataset_to_tfrecord(
194 | dataset_dir, output_dir, output_basename, splits,
195 | xception_frozen_graph_path):
196 |
197 | # Step 1: filter out sparse renders
198 | filter_out_sparse_renders(dataset_dir, splits, 0.15)
199 |
200 | # Step 2: generate semantic segmentation masks
201 | segment_utils.segment_and_color_dataset(
202 | dataset_dir, xception_frozen_graph_path, splits)
203 |
204 | # Step 3: export dataset to TFRecord
205 | if splits is None:
206 | input_filepattern = osp.join(dataset_dir, '*_color.png')
207 | dataset_iter = AlignedRenderedDataset(input_filepattern)
208 | output_name = output_basename + '.tfrecord'
209 | _generate_tfrecord_dataset(dataset_iter, output_name, output_dir)
210 | else:
211 | for split in splits:
212 | input_filepattern = osp.join(dataset_dir, split, '*_color.png')
213 | dataset_iter = AlignedRenderedDataset(input_filepattern)
214 | output_name = '%s_%s.tfrecord' % (output_basename, split)
215 | _generate_tfrecord_dataset(dataset_iter, output_name, output_dir)
216 |
217 |
218 | def main(argv):
219 | # Read input flags
220 | dataset_name = opts.dataset_name
221 | dataset_parent_dir = opts.dataset_parent_dir
222 | output_dir = FLAGS.output_dir
223 | xception_frozen_graph_path = FLAGS.xception_frozen_graph_path
224 | splits = ['train', 'val']
225 | # Run the preprocessing pipeline
226 | export_aligned_dataset_to_tfrecord(
227 | dataset_parent_dir, output_dir, dataset_name, splits,
228 | xception_frozen_graph_path)
229 |
230 |
231 | if __name__ == '__main__':
232 | app.run(main)
233 |
--------------------------------------------------------------------------------
/evaluate_quantitative_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from PIL import Image
16 | from absl import app
17 | from absl import flags
18 | import functools
19 | import glob
20 | import numpy as np
21 | import os
22 | import os.path as osp
23 | import skimage.measure
24 | import tensorflow as tf
25 | import utils
26 |
27 | FLAGS = flags.FLAGS
28 | flags.DEFINE_string('val_set_out_dir', None,
29 | 'Output directory with concatenated fake and real images.')
30 | flags.DEFINE_string('experiment_title', 'experiment',
31 | 'Name for the experiment to evaluate')
32 |
33 |
34 | def _extract_real_and_fake_from_concatenated_output(val_set_out_dir):
35 | out_dir = osp.join(val_set_out_dir, 'fake')
36 | gt_dir = osp.join(val_set_out_dir, 'real')
37 | if not osp.exists(out_dir):
38 | os.makedirs(out_dir)
39 | if not osp.exists(gt_dir):
40 | os.makedirs(gt_dir)
41 | imgs_pattern = osp.join(val_set_out_dir, '*.png')
42 | imgs_paths = sorted(glob.glob(imgs_pattern))
43 | print('Separating %d images in %s.' % (len(imgs_paths), val_set_out_dir))
44 | for img_path in imgs_paths:
45 | basename = osp.basename(img_path)[:-4] # remove the '.png' suffix
46 | img = np.array(Image.open(img_path))
47 | img_res = 512
48 | fake = img[:, -2*img_res:-img_res, :]
49 | real = img[:, -img_res:, :]
50 | fake_path = osp.join(out_dir, '%s_fake.png' % basename)
51 | real_path = osp.join(gt_dir, '%s_real.png' % basename)
52 | Image.fromarray(fake).save(fake_path)
53 | Image.fromarray(real).save(real_path)
54 |
55 |
56 | def compute_l1_loss_metric(image_set1_paths, image_set2_paths):
57 | assert len(image_set1_paths) == len(image_set2_paths)
58 | assert len(image_set1_paths) > 0
59 | print('Evaluating L1 loss for %d pairs' % len(image_set1_paths))
60 |
61 | total_loss = 0.
62 | for ii, (img1_path, img2_path) in enumerate(zip(image_set1_paths,
63 | image_set2_paths)):
64 | img1_in_ar = np.array(Image.open(img1_path), dtype=np.float32)
65 | img1_in_ar = utils.crop_to_multiple(img1_in_ar)
66 |
67 | img2_in_ar = np.array(Image.open(img2_path), dtype=np.float32)
68 | img2_in_ar = utils.crop_to_multiple(img2_in_ar)
69 |
70 | loss_l1 = np.mean(np.abs(img1_in_ar - img2_in_ar))
71 | total_loss += loss_l1
72 |
73 | return total_loss / len(image_set1_paths)
74 |
75 |
76 | def compute_psnr_loss_metric(image_set1_paths, image_set2_paths):
77 | assert len(image_set1_paths) == len(image_set2_paths)
78 | assert len(image_set1_paths) > 0
79 | print('Evaluating PSNR loss for %d pairs' % len(image_set1_paths))
80 |
81 | total_loss = 0.
82 | for ii, (img1_path, img2_path) in enumerate(zip(image_set1_paths,
83 | image_set2_paths)):
84 | img1_in_ar = np.array(Image.open(img1_path))
85 | img1_in_ar = utils.crop_to_multiple(img1_in_ar)
86 |
87 | img2_in_ar = np.array(Image.open(img2_path))
88 | img2_in_ar = utils.crop_to_multiple(img2_in_ar)
89 |
90 | loss_psnr = skimage.measure.compare_psnr(img1_in_ar, img2_in_ar)
91 | total_loss += loss_psnr
92 |
93 | return total_loss / len(image_set1_paths)
94 |
95 |
96 | def evaluate_experiment(val_set_out_dir, title='experiment',
97 | metrics=['psnr', 'l1']):
98 |
99 | out_dir = osp.join(val_set_out_dir, 'fake')
100 | gt_dir = osp.join(val_set_out_dir, 'real')
101 | _extract_real_and_fake_from_concatenated_output(val_set_out_dir)
102 | input_pattern1 = osp.join(gt_dir, '*.png')
103 | input_pattern2 = osp.join(out_dir, '*.png')
104 | set1 = sorted(glob.glob(input_pattern1))
105 | set2 = sorted(glob.glob(input_pattern2))
106 | for metric in metrics:
107 | if metric == 'l1':
108 | mean_loss = compute_l1_loss_metric(set1, set2)
109 | elif metric == 'psnr':
110 | mean_loss = compute_psnr_loss_metric(set1, set2)
111 | print('*** mean %s loss for %s = %f' % (metric, title, mean_loss))
112 |
113 |
114 | def main(argv):
115 | evaluate_experiment(FLAGS.val_set_out_dir, title=FLAGS.experiment_title,
116 | metrics=['psnr', 'l1'])
117 |
118 |
119 | if __name__ == '__main__':
120 | app.run(main)
121 |
--------------------------------------------------------------------------------
/imgs/app_interpolation.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/neural_rerendering_in_the_wild/5f5226f0083559adb0a21b567104dfc075b6f6e5/imgs/app_interpolation.jpg
--------------------------------------------------------------------------------
/imgs/app_variation.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/neural_rerendering_in_the_wild/5f5226f0083559adb0a21b567104dfc075b6f6e5/imgs/app_variation.jpg
--------------------------------------------------------------------------------
/imgs/teaser_with_caption.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/neural_rerendering_in_the_wild/5f5226f0083559adb0a21b567104dfc075b6f6e5/imgs/teaser_with_caption.jpg
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import functools
16 | import numpy as np
17 | import tensorflow as tf
18 |
19 |
20 | class LayerInstanceNorm(object):
21 |
22 | def __init__(self, scope_suffix='instance_norm'):
23 | curr_scope = tf.get_variable_scope().name
24 | self._scope = curr_scope + '/' + scope_suffix
25 |
26 | def __call__(self, x):
27 | with tf.variable_scope(self._scope, reuse=tf.AUTO_REUSE):
28 | return tf.contrib.layers.instance_norm(
29 | x, epsilon=1e-05, center=True, scale=True)
30 |
31 |
32 | def layer_norm(x, scope='layer_norm'):
33 | return tf.contrib.layers.layer_norm(x, center=True, scale=True)
34 |
35 |
36 | def pixel_norm(x):
37 | """Pixel normalization.
38 |
39 | Args:
40 | x: 4D image tensor in B01C format.
41 |
42 | Returns:
43 | 4D tensor with pixel normalized channels.
44 | """
45 | return x * tf.rsqrt(tf.reduce_mean(tf.square(x), [-1], keepdims=True) + 1e-8)
46 |
47 |
48 | def global_avg_pooling(x):
49 | return tf.reduce_mean(x, axis=[1, 2], keepdims=True)
50 |
51 |
52 | class FullyConnected(object):
53 |
54 | def __init__(self, n_out_units, scope_suffix='FC'):
55 | weight_init = tf.random_normal_initializer(mean=0., stddev=0.02)
56 | weight_regularizer = tf.contrib.layers.l2_regularizer(scale=0.0001)
57 |
58 | curr_scope = tf.get_variable_scope().name
59 | self._scope = curr_scope + '/' + scope_suffix
60 | self.fc_layer = functools.partial(
61 | tf.layers.dense, units=n_out_units, kernel_initializer=weight_init,
62 | kernel_regularizer=weight_regularizer, use_bias=True)
63 |
64 | def __call__(self, x):
65 | with tf.variable_scope(self._scope, reuse=tf.AUTO_REUSE):
66 | return self.fc_layer(x)
67 |
68 |
69 | def init_he_scale(shape, slope=1.0):
70 | """He neural network random normal scaling for initialization.
71 |
72 | Args:
73 | shape: list of the dimensions of the tensor.
74 | slope: float, slope of the ReLu following the layer.
75 |
76 | Returns:
77 | a float, He's standard deviation.
78 | """
79 | fan_in = np.prod(shape[:-1])
80 | return np.sqrt(2. / ((1. + slope**2) * fan_in))
81 |
82 |
83 | class LayerConv(object):
84 | """Convolution layer with support for equalized learning."""
85 |
86 | def __init__(self,
87 | name,
88 | w,
89 | n,
90 | stride,
91 | padding='SAME',
92 | use_scaling=False,
93 | relu_slope=1.):
94 | """Layer constructor.
95 |
96 | Args:
97 | name: string, layer name.
98 | w: int or 2-tuple, width of the convolution kernel.
99 | n: 2-tuple of ints, input and output channel depths.
100 | stride: int or 2-tuple, stride for the convolution kernel.
101 | padding: string, the padding method. {SAME, VALID, REFLECT}.
102 | use_scaling: bool, whether to use weight norm and scaling.
103 | relu_slope: float, the slope of the ReLu following the layer.
104 | """
105 | assert padding in ['SAME', 'VALID', 'REFLECT'], 'Error: unsupported padding'
106 | self._padding = padding
107 | with tf.variable_scope(name):
108 | if isinstance(stride, int):
109 | stride = [1, stride, stride, 1]
110 | else:
111 |         assert len(stride) == 2, "stride is either an int or a 2-tuple"
112 | stride = [1, stride[0], stride[1], 1]
113 | if isinstance(w, int):
114 | w = [w, w]
115 | self.w = w
116 | shape = [w[0], w[1], n[0], n[1]]
117 | init_scale, pre_scale = init_he_scale(shape, relu_slope), 1.
118 | if use_scaling:
119 | init_scale, pre_scale = pre_scale, init_scale
120 | self._stride = stride
121 | self._pre_scale = pre_scale
122 | self._weight = tf.get_variable(
123 | 'weight',
124 | shape=shape,
125 | initializer=tf.random_normal_initializer(stddev=init_scale))
126 | self._bias = tf.get_variable(
127 | 'bias', shape=[n[1]], initializer=tf.zeros_initializer)
128 |
129 | def __call__(self, x):
130 | """Apply layer to tensor x."""
131 | if self._padding != 'REFLECT':
132 | padding = self._padding
133 | else:
134 | padding = 'VALID'
135 | pad_top = self.w[0] // 2
136 | pad_left = self.w[1] // 2
137 | if (self.w[0] - self._stride[1]) % 2 == 0:
138 | pad_bottom = pad_top
139 | else:
140 | pad_bottom = self.w[0] - self._stride[1] - pad_top
141 | if (self.w[1] - self._stride[2]) % 2 == 0:
142 | pad_right = pad_left
143 | else:
144 | pad_right = self.w[1] - self._stride[2] - pad_left
145 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right],
146 | [0, 0]], mode='REFLECT')
147 | y = tf.nn.conv2d(x, self._weight, strides=self._stride, padding=padding)
148 | return self._pre_scale * y + self._bias
149 |
150 |
151 | class LayerTransposedConv(object):
152 | """Convolution layer with support for equalized learning."""
153 |
154 | def __init__(self,
155 | name,
156 | w,
157 | n,
158 | stride,
159 | padding='SAME',
160 | use_scaling=False,
161 | relu_slope=1.):
162 | """Layer constructor.
163 |
164 | Args:
165 | name: string, layer name.
166 | w: int or 2-tuple, width of the convolution kernel.
167 | n: 2-tuple int, [n_in_channels, n_out_channels]
168 | stride: int or 2-tuple, stride for the convolution kernel.
169 | padding: string, the padding method {SAME, VALID, REFLECT}.
170 | use_scaling: bool, whether to use weight norm and scaling.
171 | relu_slope: float, the slope of the ReLu following the layer.
172 | """
173 | assert padding in ['SAME'], 'Error: unsupported padding for transposed conv'
174 | if isinstance(stride, int):
175 | stride = [1, stride, stride, 1]
176 | else:
177 | assert len(stride) == 2, "stride is either an int or a 2-tuple"
178 | stride = [1, stride[0], stride[1], 1]
179 | if isinstance(w, int):
180 | w = [w, w]
181 | self.padding = padding
182 | self.nc_in, self.nc_out = n
183 | self.stride = stride
184 | with tf.variable_scope(name):
185 | kernel_shape = [w[0], w[1], self.nc_out, self.nc_in]
186 | init_scale, pre_scale = init_he_scale(kernel_shape, relu_slope), 1.
187 | if use_scaling:
188 | init_scale, pre_scale = pre_scale, init_scale
189 | self._pre_scale = pre_scale
190 | self._weight = tf.get_variable(
191 | 'weight',
192 | shape=kernel_shape,
193 | initializer=tf.random_normal_initializer(stddev=init_scale))
194 | self._bias = tf.get_variable(
195 | 'bias', shape=[self.nc_out], initializer=tf.zeros_initializer)
196 |
197 | def __call__(self, x):
198 | """Apply layer to tensor x."""
199 | x_shape = x.get_shape().as_list()
200 | batch_size = tf.shape(x)[0]
201 | stride_x, stride_y = self.stride[1], self.stride[2]
202 | output_shape = tf.stack([
203 | batch_size, x_shape[1] * stride_x, x_shape[2] * stride_y, self.nc_out])
204 | y = tf.nn.conv2d_transpose(
205 | x, filter=self._weight, output_shape=output_shape, strides=self.stride,
206 | padding=self.padding)
207 | return self._pre_scale * y + self._bias
208 |
209 |
210 | class ResBlock(object):
211 | def __init__(self,
212 | name,
213 | nc,
214 | norm_layer_constructor,
215 | activation,
216 | padding='SAME',
217 | use_scaling=False,
218 | relu_slope=1.):
219 | """Layer constructor."""
220 | self.name = name
221 | conv2d = functools.partial(
222 | LayerConv, w=3, n=[nc, nc], stride=1, padding=padding,
223 | use_scaling=use_scaling, relu_slope=relu_slope)
224 | self.blocks = []
225 | with tf.variable_scope(self.name):
226 | with tf.variable_scope('res0'):
227 | self.blocks.append(
228 | LayerPipe([
229 | conv2d('res0_conv'),
230 | norm_layer_constructor('res0_norm'),
231 | activation
232 | ])
233 | )
234 | with tf.variable_scope('res1'):
235 | self.blocks.append(
236 | LayerPipe([
237 | conv2d('res1_conv'),
238 | norm_layer_constructor('res1_norm')
239 | ])
240 | )
241 |
242 | def __call__(self, x_init):
243 | """Apply layer to tensor x."""
244 | x = x_init
245 | for f in self.blocks:
246 | x = f(x)
247 | return x + x_init
248 |
249 |
250 | class BasicBlock(object):
251 | def __init__(self,
252 | name,
253 | n,
254 | activation=functools.partial(tf.nn.leaky_relu, alpha=0.2),
255 | padding='SAME',
256 | use_scaling=True,
257 | relu_slope=1.):
258 | """Layer constructor."""
259 | self.name = name
260 | conv2d = functools.partial(
261 | LayerConv, stride=1, padding=padding,
262 | use_scaling=use_scaling, relu_slope=relu_slope)
263 | avg_pool = functools.partial(downscale, n=2)
264 | nc_in, nc_out = n # n is a 2-tuple
265 | with tf.variable_scope(self.name):
266 | self.path1_blocks = []
267 | with tf.variable_scope('bb_path1'):
268 | self.path1_blocks.append(
269 | LayerPipe([
270 | activation,
271 | conv2d('bb_conv0', w=3, n=[nc_in, nc_out]),
272 | activation,
273 | conv2d('bb_conv1', w=3, n=[nc_out, nc_out]),
274 | downscale
275 | ])
276 | )
277 |
278 | self.path2_blocks = []
279 | with tf.variable_scope('bb_path2'):
280 | self.path2_blocks.append(
281 | LayerPipe([
282 | downscale,
283 | conv2d('path2_conv', w=1, n=[nc_in, nc_out])
284 | ])
285 | )
286 |
287 | def __call__(self, x_init):
288 | """Apply layer to tensor x."""
289 | x1 = x_init
290 | x2 = x_init
291 | for f in self.path1_blocks:
292 | x1 = f(x1)
293 | for f in self.path2_blocks:
294 | x2 = f(x2)
295 | return x1 + x2
296 |
297 |
298 | class LayerDense(object):
299 | """Dense layer with a non-linearity."""
300 |
301 | def __init__(self, name, n, use_scaling=False, relu_slope=1.):
302 | """Layer constructor.
303 |
304 | Args:
305 | name: string, layer name.
306 | n: 2-tuple of ints, input and output widths.
307 | use_scaling: bool, whether to use weight norm and scaling.
308 | relu_slope: float, the slope of the ReLu following the layer.
309 | """
310 | with tf.variable_scope(name):
311 | init_scale, pre_scale = init_he_scale(n, relu_slope), 1.
312 | if use_scaling:
313 | init_scale, pre_scale = pre_scale, init_scale
314 | self._pre_scale = pre_scale
315 | self._weight = tf.get_variable(
316 | 'weight',
317 | shape=n,
318 | initializer=tf.random_normal_initializer(stddev=init_scale))
319 | self._bias = tf.get_variable(
320 | 'bias', shape=[n[1]], initializer=tf.zeros_initializer)
321 |
322 | def __call__(self, x):
323 | """Apply layer to tensor x."""
324 | return self._pre_scale * tf.matmul(x, self._weight) + self._bias
325 |
326 |
327 | class LayerPipe(object):
328 | """Pipe a sequence of functions."""
329 |
330 | def __init__(self, functions):
331 | """Layer constructor.
332 |
333 | Args:
334 | functions: list, functions to pipe.
335 | """
336 | self._functions = tuple(functions)
337 |
338 | def __call__(self, x, **kwargs):
339 | """Apply pipe to tensor x and return result."""
340 | del kwargs
341 | for f in self._functions:
342 | x = f(x)
343 | return x
344 |
345 |
346 | def downscale(x, n=2):
347 | """Box downscaling.
348 |
349 | Args:
350 | x: 4D image tensor.
351 | n: integer scale (must be a power of 2).
352 |
353 | Returns:
354 | 4D tensor of images down scaled by a factor n.
355 | """
356 | if n == 1:
357 | return x
358 | return tf.nn.avg_pool(x, [1, n, n, 1], [1, n, n, 1], 'VALID')
359 |
360 |
361 | def upscale(x, n):
362 | """Box upscaling (also called nearest neighbors).
363 |
364 | Args:
365 | x: 4D image tensor in B01C format.
366 | n: integer scale (must be a power of 2).
367 |
368 | Returns:
369 | 4D tensor of images up scaled by a factor n.
370 | """
371 | if n == 1:
372 | return x
373 | x_shape = tf.shape(x)
374 | height, width = x_shape[1], x_shape[2]
375 | return tf.image.resize_nearest_neighbor(x, [n * height, n * width])
376 |
377 |
378 | def tile_and_concatenate(x, z, n_z):
379 | z = tf.reshape(z, shape=[-1, 1, 1, n_z])
380 | z = tf.tile(z, [1, tf.shape(x)[1], tf.shape(x)[2], 1])
381 | x = tf.concat([x, z], axis=-1)
382 | return x
383 |
384 |
385 | def minibatch_mean_variance(x):
386 |   """Computes the mean standard deviation over a minibatch.
387 |
388 | This is used by the discriminator as a form of batch discrimination.
389 |
390 | Args:
391 | x: nD tensor for which to compute variance average.
392 |
393 | Returns:
394 |     a scalar, the mean standard deviation of x across the minibatch.
395 | """
396 | mean = tf.reduce_mean(x, 0, keepdims=True)
397 | vals = tf.sqrt(tf.reduce_mean(tf.squared_difference(x, mean), 0) + 1e-8)
398 | vals = tf.reduce_mean(vals)
399 | return vals
400 |
401 |
402 | def scalar_concat(x, scalar):
403 | """Concatenate a scalar to a 4D tensor as an extra channel.
404 |
405 | Args:
406 | x: 4D image tensor in B01C format.
407 | scalar: a scalar to concatenate to the tensor.
408 |
409 | Returns:
410 | a 4D tensor with one extra channel containing the value scalar at
411 | every position.
412 | """
413 | s = tf.shape(x)
414 | return tf.concat([x, tf.ones([s[0], s[1], s[2], 1]) * scalar], axis=3)
415 |
--------------------------------------------------------------------------------
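A quick illustration of the box scaling helpers in layers.py above: downscale() and upscale() reduce to block averaging and nearest-neighbor repetition. The following is a minimal NumPy sketch (illustrative only, not part of the repo; it assumes NHWC tensors with H and W divisible by n):

import numpy as np

def box_downscale(x, n=2):
  # Average each non-overlapping n x n block (what the avg_pool in downscale() computes).
  b, h, w, c = x.shape
  return x.reshape(b, h // n, n, w // n, n, c).mean(axis=(2, 4))

def box_upscale(x, n=2):
  # Repeat each pixel n times along height and width (nearest neighbor).
  return x.repeat(n, axis=1).repeat(n, axis=2)

x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
assert box_upscale(box_downscale(x, 2), 2).shape == x.shape  # round trip keeps (1, 4, 4, 1)
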
/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from options import FLAGS as opts
16 | import layers
17 | import os.path as osp
18 | import tensorflow as tf
19 | import vgg16
20 |
21 |
22 | def gradient_penalty_loss(y_xy, xy, iwass_target=1, iwass_lambda=10):
23 | grad = tf.gradients(tf.reduce_sum(y_xy), [xy])[0]
24 | grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]) + 1e-8)
25 | loss_gp = tf.reduce_mean(
26 | tf.square(grad_norm - iwass_target)) * iwass_lambda / iwass_target**2
27 | return loss_gp
28 |
29 |
30 | def KL_loss(mean, logvar):
31 | loss = 0.5 * tf.reduce_sum(tf.square(mean) + tf.exp(logvar) - 1. - logvar,
32 | axis=-1)
33 | return tf.reduce_sum(loss) # just to match DRIT implementation
34 |
35 |
36 | def l2_regularize(x):
37 | return tf.reduce_mean(tf.square(x))
38 |
39 |
40 | def L1_loss(x, y):
41 | return tf.reduce_mean(tf.abs(x - y))
42 |
43 |
44 | class PerceptualLoss:
45 | def __init__(self, x, y, image_shape, layers, w_layers, w_act=0.1):
46 |     """Builds a vgg16 network and computes the perceptual loss: a weighted
47 |     sum of mean squared differences between VGG activations of x and y.
48 |     """
49 | assert len(image_shape) == 3 and image_shape[-1] == 3
50 | assert osp.exists(opts.vgg16_path), 'Cannot find %s' % opts.vgg16_path
51 |
52 | self.w_act = w_act
53 | self.vgg_layers = layers
54 | self.w_layers = w_layers
55 | batch_shape = [None] + image_shape # [None, H, W, 3]
56 |
57 | vgg_net = vgg16.Vgg16(opts.vgg16_path)
58 | self.x_acts = vgg_net.get_vgg_activations(x, layers)
59 | self.y_acts = vgg_net.get_vgg_activations(y, layers)
60 | loss = 0
61 | for w, act1, act2 in zip(self.w_layers, self.x_acts, self.y_acts):
62 | loss += w * tf.reduce_mean(tf.square(self.w_act * (act1 - act2)))
63 | self.loss = loss
64 |
65 | def __call__(self):
66 | return self.loss
67 |
68 |
69 | def lsgan_appearance_E_loss(disc_response):
70 | disc_response = tf.squeeze(disc_response)
71 | gt_label = 0.5
72 | loss = tf.reduce_mean(tf.square(disc_response - gt_label))
73 | return loss
74 |
75 |
76 | def lsgan_loss(disc_response, is_real):
77 | gt_label = 1 if is_real else 0
78 | disc_response = tf.squeeze(disc_response)
79 | # The following works for both regular and patchGAN discriminators
80 | loss = tf.reduce_mean(tf.square(disc_response - gt_label))
81 | return loss
82 |
83 |
84 | def multiscale_discriminator_loss(Ds_responses, is_real):
85 | num_D = len(Ds_responses)
86 | loss = 0
87 | for i in range(num_D):
88 | curr_response = Ds_responses[i][-1][-1]
89 | loss += lsgan_loss(curr_response, is_real)
90 | return loss
91 |
--------------------------------------------------------------------------------
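The closed-form losses in losses.py above are short enough to check by hand. Below is a small NumPy sketch (illustrative only, not repo code) of the KL term and the least-squares GAN objective, with target 1 for real samples, 0 for fake samples, and 0.5 for the appearance encoder's adversarial loss:

import numpy as np

def kl_term(mean, logvar):
  # 0.5 * sum(mu^2 + exp(logvar) - 1 - logvar), as in KL_loss() above (per example).
  return 0.5 * np.sum(np.square(mean) + np.exp(logvar) - 1.0 - logvar, axis=-1)

def lsgan_term(disc_response, target):
  # Least-squares GAN objective: mean squared distance to the target label.
  return np.mean(np.square(disc_response - target))

mu, logvar = np.zeros(8), np.zeros(8)
print(kl_term(mu, logvar))          # 0.0 for a standard normal posterior
print(lsgan_term(np.ones(4), 1.0))  # 0.0 when D outputs the real label exactly
print(lsgan_term(np.ones(4), 0.5))  # 0.25, the appearance-encoder target case
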
/networks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from options import FLAGS as opts
16 | import functools
17 | import layers
18 | import tensorflow as tf
19 |
20 |
21 | class RenderingModel(object):
22 |
23 | def __init__(self, model_name, use_appearance=True):
24 |
25 | if model_name == 'pggan':
26 | self._model = ModelPGGAN(use_appearance)
27 | else:
28 | raise ValueError('Model %s not implemented!' % model_name)
29 |
30 | def __call__(self, x_in, z_app=None):
31 | return self._model(x_in, z_app)
32 |
33 | def get_appearance_encoder(self):
34 | return self._model._appearance_encoder
35 |
36 | def get_generator(self):
37 | return self._model._generator
38 |
39 | def get_content_encoder(self):
40 | return self._model._content_encoder
41 |
42 |
43 | # "Progressive Growing of GANs (PGGAN)"-inspired architecture. The
44 | # implementation follows the details described in the paper; no code is taken
45 | # from the authors' released code.
46 | # Main changes are:
47 | # - conditional GAN setup by introducing an encoder + skip connections.
48 | # - no progressive growing during training.
49 | class ModelPGGAN(RenderingModel):
50 |
51 | def __init__(self, use_appearance=True):
52 | self._use_appearance = use_appearance
53 | self._content_encoder = None
54 | self._generator = GeneratorPGGAN(appearance_vec_size=opts.app_vector_size)
55 | if use_appearance:
56 | self._appearance_encoder = DRITAppearanceEncoderConcat(
57 | 'appearance_net', opts.appearance_nc, opts.normalize_drit_Ez)
58 | else:
59 | self._appearance_encoder = None
60 |
61 | def __call__(self, x_in, z_app=None):
62 | y = self._generator(x_in, z_app)
63 | return y
64 |
65 | def get_appearance_encoder(self):
66 | return self._appearance_encoder
67 |
68 | def get_generator(self):
69 | return self._generator
70 |
71 | def get_content_encoder(self):
72 | return self._content_encoder
73 |
74 |
75 | class PatchGANDiscriminator(object):
76 |
77 | def __init__(self, name_scope, input_nc, nf=64, n_layers=3, get_fmaps=False):
78 |     """Constructor for a patchGAN discriminator.
79 |
80 | Args:
81 | name_scope: str - tf name scope.
82 | input_nc: int - number of input channels.
83 | nf: int - starting number of discriminator filters.
84 | n_layers: int - number of layers in the discriminator.
85 | get_fmaps: bool - return intermediate feature maps for FeatLoss.
86 | """
87 | self.get_fmaps = get_fmaps
88 | self.n_layers = n_layers
89 | kw = 4 # kernel width for convolution
90 |
91 | activation = functools.partial(tf.nn.leaky_relu, alpha=0.2)
92 | norm_layer = functools.partial(layers.LayerInstanceNorm)
93 | conv2d = functools.partial(layers.LayerConv, use_scaling=opts.use_scaling,
94 | relu_slope=0.2)
95 |
96 | def minibatch_stats(x):
97 | return layers.scalar_concat(x, layers.minibatch_mean_variance(x))
98 |
99 | # Create layers.
100 | self.blocks = []
101 |     with tf.variable_scope(name_scope, reuse=tf.AUTO_REUSE):
102 | with tf.variable_scope('block_0'):
103 | self.blocks.append([
104 | conv2d('conv0', w=kw, n=[input_nc, nf], stride=2),
105 | activation
106 | ])
107 | for ii_block in range(1, n_layers):
108 | nf_prev = nf
109 | nf = min(nf * 2, 512)
110 | with tf.variable_scope('block_%d' % ii_block):
111 | self.blocks.append([
112 | conv2d('conv%d' % ii_block, w=kw, n=[nf_prev, nf], stride=2),
113 | norm_layer(),
114 | activation
115 | ])
116 | # Add minibatch_stats (from PGGAN) and do a stride1 convolution.
117 | nf_prev = nf
118 | nf = min(nf * 2, 512)
119 | with tf.variable_scope('block_%d' % (n_layers + 1)):
120 | self.blocks.append([
121 | minibatch_stats, # this is improvised by @meshry
122 | conv2d('conv%d' % (n_layers + 1), w=kw, n=[nf_prev + 1, nf],
123 | stride=1),
124 | norm_layer(),
125 | activation
126 | ])
127 | # Get 1-channel patchGAN logits
128 | with tf.variable_scope('patchGAN_logits'):
129 | self.blocks.append([
130 | conv2d('conv%d' % (n_layers + 2), w=kw, n=[nf, 1], stride=1)
131 | ])
132 |
133 | def __call__(self, x, x_cond=None):
134 | # Concatenate extra conditioning input, if any.
135 | if x_cond is not None:
136 | x = tf.concat([x, x_cond], axis=3)
137 |
138 | if self.get_fmaps:
139 |       # Dummy addition of x to D_fmaps, which will be removed before returning
140 | D_fmaps = [[x]]
141 | for i_block in range(len(self.blocks)):
142 | # Apply layer #0 in the current block
143 | block_fmaps = [self.blocks[i_block][0](D_fmaps[-1][-1])]
144 | # Apply the remaining layers of this block
145 | for i_layer in range(1, len(self.blocks[i_block])):
146 | block_fmaps.append(self.blocks[i_block][i_layer](block_fmaps[-1]))
147 | # Append the feature maps of this block to D_fmaps
148 | D_fmaps.append(block_fmaps)
149 | return D_fmaps[1:] # exclude the input x which we added initially
150 | else:
151 | y = x
152 | for i_block in range(len(self.blocks)):
153 | for i_layer in range(len(self.blocks[i_block])):
154 | y = self.blocks[i_block][i_layer](y)
155 | return [[y]]
156 |
157 |
158 | class MultiScaleDiscriminator(object):
159 |
160 | def __init__(self, name_scope, input_nc, num_scales=3, nf=64, n_layers=3,
161 | get_fmaps=False):
162 | self.get_fmaps = get_fmaps
163 | discs = []
164 | with tf.variable_scope(name_scope):
165 | for i in range(num_scales):
166 | discs.append(PatchGANDiscriminator(
167 | 'D_scale%d' % i, input_nc, nf=nf, n_layers=n_layers,
168 | get_fmaps=get_fmaps))
169 | self.discriminators = discs
170 |
171 | def __call__(self, x, x_cond=None, params=None):
172 | del params
173 | if x_cond is not None:
174 | x = tf.concat([x, x_cond], axis=3)
175 |
176 | responses = []
177 | for ii, D in enumerate(self.discriminators):
178 | responses.append(D(x, x_cond=None)) # x_cond is already concatenated
179 | if ii != len(self.discriminators) - 1:
180 | x = layers.downscale(x, n=2)
181 | return responses
182 |
183 |
184 | class GeneratorPGGAN(object):
185 | def __init__(self, appearance_vec_size=8, use_scaling=True,
186 | num_blocks=5, input_nc=7,
187 | fmap_base=8192, fmap_decay=1.0, fmap_max=512):
188 |     """Generator model.
189 | 
190 |     Args:
191 |       appearance_vec_size: int, size of the latent appearance vector.
192 |       use_scaling: bool, whether to use weight scaling.
193 |       num_blocks: int, controls the encoder/decoder depth; the network uses
194 |         num_blocks + 1 stages on each side.
195 |       input_nc: int, number of input channels.
196 |       fmap_base: int, base number of channels.
197 |       fmap_decay: float, decay rate of channels with respect to depth.
198 |       fmap_max: int, max number of channels (supersedes fmap_base).
199 | 
200 |     The constructed object is callable and returns a generated image batch.
201 |     """
202 | def _num_filters(fmap_base, fmap_decay, fmap_max, stage):
203 | if opts.g_nf == 32:
204 | return min(int(2**(10 - stage)), fmap_max) # nf32
205 | elif opts.g_nf == 64:
206 | return min(int(2**(11 - stage)), fmap_max) # nf64
207 | else:
208 | raise ValueError('Currently unsupported num filters')
209 |
210 | nf = functools.partial(_num_filters, fmap_base, fmap_decay, fmap_max)
211 | self.num_blocks = num_blocks
212 | activation = functools.partial(tf.nn.leaky_relu, alpha=0.2)
213 | conv2d_stride1 = functools.partial(
214 | layers.LayerConv, stride=1, use_scaling=use_scaling, relu_slope=0.2)
215 | conv2d_rgb = functools.partial(layers.LayerConv, w=1, stride=1,
216 | use_scaling=use_scaling)
217 |
218 | # Create encoder layers.
219 |     with tf.variable_scope('g_model_enc', reuse=tf.AUTO_REUSE):
220 | self.enc_stage = []
221 | self.from_rgb = []
222 |
223 | if opts.use_appearance and opts.inject_z == 'to_encoder':
224 | input_nc += appearance_vec_size
225 |
226 | for i in range(num_blocks, -1, -1):
227 | with tf.variable_scope('res_%d' % i):
228 | self.from_rgb.append(
229 | layers.LayerPipe([
230 | conv2d_rgb('from_rgb', n=[input_nc, nf(i + 1)]),
231 | activation,
232 | ])
233 | )
234 | self.enc_stage.append(
235 | layers.LayerPipe([
236 | functools.partial(layers.downscale, n=2),
237 | conv2d_stride1('conv0', w=3, n=[nf(i + 1), nf(i)]),
238 | activation,
239 | layers.pixel_norm,
240 | conv2d_stride1('conv1', w=3, n=[nf(i), nf(i)]),
241 | activation,
242 | layers.pixel_norm
243 | ])
244 | )
245 |
246 | # Create decoder layers.
247 |     with tf.variable_scope('g_model_dec', reuse=tf.AUTO_REUSE):
248 | self.dec_stage = []
249 | self.to_rgb = []
250 |
251 | nf_bottleneck = nf(0) # num input filters at the bottleneck
252 | if opts.use_appearance and opts.inject_z == 'to_bottleneck':
253 | nf_bottleneck += appearance_vec_size
254 |
255 | with tf.variable_scope('res_0'):
256 | self.dec_stage.append(
257 | layers.LayerPipe([
258 | functools.partial(layers.upscale, n=2),
259 | conv2d_stride1('conv0', w=3, n=[nf_bottleneck, nf(1)]),
260 | activation,
261 | layers.pixel_norm,
262 | conv2d_stride1('conv1', w=3, n=[nf(1), nf(1)]),
263 | activation,
264 | layers.pixel_norm
265 | ])
266 | )
267 | self.to_rgb.append(conv2d_rgb('to_rgb', n=[nf(1), opts.output_nc]))
268 |
269 | multiply_factor = 2 if opts.concatenate_skip_layers else 1
270 | for i in range(1, num_blocks + 1):
271 | with tf.variable_scope('res_%d' % i):
272 | self.dec_stage.append(
273 | layers.LayerPipe([
274 | functools.partial(layers.upscale, n=2),
275 | conv2d_stride1('conv0', w=3,
276 | n=[multiply_factor * nf(i), nf(i + 1)]),
277 | activation,
278 | layers.pixel_norm,
279 | conv2d_stride1('conv1', w=3, n=[nf(i + 1), nf(i + 1)]),
280 | activation,
281 | layers.pixel_norm
282 | ])
283 | )
284 | self.to_rgb.append(conv2d_rgb('to_rgb',
285 | n=[nf(i + 1), opts.output_nc]))
286 |
287 | def __call__(self, x, appearance_embedding=None, encoder_fmaps=None):
288 | """Generator function.
289 |
290 | Args:
291 |       x: 4D image tensor (NHWC), the conditioning input batch (deep buffer).
292 |       appearance_embedding: float tensor, the latent appearance vector.
293 | Returns:
294 | 4D tensor of images (NHWC), the generated images.
295 | """
296 | del encoder_fmaps
297 | enc_st_idx = 0
298 | if opts.use_appearance and opts.inject_z == 'to_encoder':
299 | x = layers.tile_and_concatenate(x, appearance_embedding,
300 | opts.app_vector_size)
301 | y = self.from_rgb[enc_st_idx](x)
302 |
303 | enc_responses = []
304 | for i in range(enc_st_idx, len(self.enc_stage)):
305 | y = self.enc_stage[i](y)
306 | enc_responses.insert(0, y)
307 |
308 | # Concatenate appearance vector to y
309 | if opts.use_appearance and opts.inject_z == 'to_bottleneck':
310 | appearance_tensor = tf.tile(appearance_embedding,
311 | [1, tf.shape(y)[1], tf.shape(y)[2], 1])
312 | y = tf.concat([y, appearance_tensor], axis=3)
313 |
314 | y_list = []
315 | for i in range(self.num_blocks + 1):
316 | if i > 0:
317 | y_skip = enc_responses[i] # skip layer
318 | if opts.concatenate_skip_layers:
319 | y = tf.concat([y, y_skip], axis=3)
320 | else:
321 | y = y + y_skip
322 | y = self.dec_stage[i](y)
323 | y_list.append(y)
324 |
325 | return self.to_rgb[self.num_blocks](y_list[-1])
326 |
327 |
328 | class DRITAppearanceEncoderConcat(object):
329 |
330 | def __init__(self, name_scope, input_nc, normalize_encoder):
331 | self.blocks = []
332 | activation = functools.partial(tf.nn.leaky_relu, alpha=0.2)
333 | conv2d = functools.partial(layers.LayerConv, use_scaling=opts.use_scaling,
334 | relu_slope=0.2, padding='SAME')
335 |     with tf.variable_scope(name_scope, reuse=tf.AUTO_REUSE):
336 | if normalize_encoder:
337 | self.blocks.append(layers.LayerPipe([
338 | conv2d('conv0', w=4, n=[input_nc, 64], stride=2),
339 | layers.BasicBlock('BB0', n=[64, 128], use_scaling=opts.use_scaling),
340 | layers.pixel_norm,
341 | layers.BasicBlock('BB1', n=[128, 192], use_scaling=opts.use_scaling),
342 | layers.pixel_norm,
343 | layers.BasicBlock('BB2', n=[192, 256], use_scaling=opts.use_scaling),
344 | layers.pixel_norm,
345 | activation,
346 | layers.global_avg_pooling
347 | ]))
348 | else:
349 | self.blocks.append(layers.LayerPipe([
350 | conv2d('conv0', w=4, n=[input_nc, 64], stride=2),
351 | layers.BasicBlock('BB0', n=[64, 128], use_scaling=opts.use_scaling),
352 | layers.BasicBlock('BB1', n=[128, 192], use_scaling=opts.use_scaling),
353 | layers.BasicBlock('BB2', n=[192, 256], use_scaling=opts.use_scaling),
354 | activation,
355 | layers.global_avg_pooling
356 | ]))
357 | # FC layers to get the mean and logvar
358 | self.fc_mean = layers.FullyConnected(opts.app_vector_size, 'FC_mean')
359 | self.fc_logvar = layers.FullyConnected(opts.app_vector_size, 'FC_logvar')
360 |
361 | def __call__(self, x):
362 | for f in self.blocks:
363 | x = f(x)
364 |
365 | mean = self.fc_mean(x)
366 | logvar = self.fc_logvar(x)
367 | # The following is an arbitrarily chosen *deterministic* latent vector
368 | # computation. Another option is to let z = mean, but gradients from logvar
369 | # will be None and will need to be removed.
370 | z = mean + tf.exp(0.5 * logvar)
371 | return z, mean, logvar
372 |
--------------------------------------------------------------------------------
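GeneratorPGGAN above injects the latent appearance vector by tiling it over the spatial grid and concatenating it as extra channels (the 'to_bottleneck' branch, mirroring layers.tile_and_concatenate). A minimal NumPy sketch of that injection (illustrative only; the channel count 512 and vector size 8 are just example values):

import numpy as np

def inject_appearance(features, z):
  # features: (B, H, W, C) bottleneck activations; z: (B, n_z) appearance code.
  b, h, w, _ = features.shape
  z_map = np.broadcast_to(z[:, None, None, :], (b, h, w, z.shape[-1]))
  return np.concatenate([features, z_map], axis=-1)

feats = np.zeros((2, 8, 8, 512), np.float32)
z_app = np.zeros((2, 8), np.float32)
print(inject_appearance(feats, z_app).shape)  # (2, 8, 8, 520)
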
/neural_rerendering.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from PIL import Image
16 | from absl import app
17 | from options import FLAGS as opts
18 | import data
19 | import datetime
20 | import functools
21 | import glob
22 | import losses
23 | import networks
24 | import numpy as np
25 | import options
26 | import os.path as osp
27 | import random
28 | import skimage.measure
29 | import staged_model
30 | import tensorflow as tf
31 | import time
32 | import utils
33 |
34 |
35 | def build_model_fn(use_exponential_moving_average=True):
36 | """Builds and returns the model function for an estimator.
37 |
38 | Args:
39 |     use_exponential_moving_average: bool. If true, variables are restored
40 |       from their exponential moving averages at inference time.
41 |
42 | Returns:
43 | function, the model_fn function typically required by an estimator.
44 | """
45 | arch_type = opts.arch_type
46 | use_appearance = opts.use_appearance
47 | def model_fn(features, labels, mode, params):
48 | """An estimator build_fn."""
49 | del labels, params
50 | if mode == tf.estimator.ModeKeys.TRAIN:
51 | step = tf.train.get_global_step()
52 |
53 | x_in = features['conditional_input']
54 | x_gt = features['expected_output'] # ground truth output
55 | x_app = features['peek_input']
56 |
57 | if opts.training_pipeline == 'staged':
58 | ops = staged_model.create_computation_graph(x_in, x_gt, x_app=x_app,
59 | arch_type=opts.arch_type)
60 | op_increment_step = tf.assign_add(step, 1)
61 | train_disc_op = ops['train_disc_op']
62 | train_renderer_op = ops['train_renderer_op']
63 | train_op = tf.group(train_disc_op, train_renderer_op, op_increment_step)
64 |
65 | utils.HookReport.log_tensor(ops['total_loss_d'], 'total_loss_d')
66 | utils.HookReport.log_tensor(ops['loss_d_real'], 'loss_d_real')
67 | utils.HookReport.log_tensor(ops['loss_d_fake'], 'loss_d_fake')
68 | utils.HookReport.log_tensor(ops['total_loss_g'], 'total_loss_g')
69 | utils.HookReport.log_tensor(ops['loss_g_gan'], 'loss_g_gan')
70 | utils.HookReport.log_tensor(ops['loss_g_recon'], 'loss_g_recon')
71 | utils.HookReport.log_tensor(step, 'global_step')
72 |
73 | return tf.estimator.EstimatorSpec(
74 | mode=mode, loss=ops['total_loss_d'] + ops['total_loss_g'],
75 | train_op=train_op)
76 | else:
77 | raise NotImplementedError('%s training is not implemented.' %
78 | opts.training_pipeline)
79 | elif mode == tf.estimator.ModeKeys.EVAL:
80 | raise NotImplementedError('Eval is not implemented.')
81 |     else:  # all modes below are for different inference tasks.
82 | # Build network and initialize inference variables.
83 | g_func = networks.RenderingModel(arch_type, use_appearance)
84 | if use_appearance:
85 | app_func = g_func.get_appearance_encoder()
86 | if use_exponential_moving_average:
87 | ema = tf.train.ExponentialMovingAverage(decay=0.999)
88 | var_dict = ema.variables_to_restore()
89 | tf.train.init_from_checkpoint(osp.join(opts.train_dir), var_dict)
90 |
91 | if mode == tf.estimator.ModeKeys.PREDICT:
92 | x_in = features['conditional_input']
93 | if use_appearance:
94 | x_app = features['peek_input']
95 | x_app_embedding, _, _ = app_func(x_app)
96 | else:
97 | x_app_embedding = None
98 | y = g_func(x_in, x_app_embedding)
99 | tf.logging.info('DBG: shape of y during prediction %s.' % str(y.shape))
100 | return tf.estimator.EstimatorSpec(mode=mode, predictions=y)
101 |
102 |       # 'eval_subset' mode is the same as PREDICT, but it concatenates the output
103 |       # with the input render, semantic map and ground truth for easy comparison.
104 | elif mode == 'eval_subset':
105 | x_in = features['conditional_input']
106 | x_gt = features['expected_output']
107 | if use_appearance:
108 | x_app = features['peek_input']
109 | x_app_embedding, _, _ = app_func(x_app)
110 | else:
111 | x_app_embedding = None
112 | y = g_func(x_in, x_app_embedding)
113 | tf.logging.info('DBG: shape of y during prediction %s.' % str(y.shape))
114 | x_in_rgb = tf.slice(x_in, [0, 0, 0, 0], [-1, -1, -1, 3])
115 | if opts.use_semantic:
116 | x_in_semantic = tf.slice(x_in, [0, 0, 0, 4], [-1, -1, -1, 3])
117 | output_tuple = tf.concat([x_in_rgb, x_in_semantic, y, x_gt], axis=2)
118 | else:
119 | output_tuple = tf.concat([x_in_rgb, y, x_gt], axis=2)
120 | return tf.estimator.EstimatorSpec(mode=mode, predictions=output_tuple)
121 |
122 | # 'compute_appearance' mode computes and returns the latent z vector.
123 | elif mode == 'compute_appearance':
124 | assert use_appearance, 'use_appearance is set to False!'
125 | x_app_in = features['peek_input']
126 | # NOTE the following line is a temporary hack (which is
127 |         # especially bad for inputs smaller than 512x512).
128 | x_app_in = tf.image.resize_image_with_crop_or_pad(x_app_in, 512, 512)
129 | app_embedding, _, _ = app_func(x_app_in)
130 | return tf.estimator.EstimatorSpec(mode=mode, predictions=app_embedding)
131 |
132 | # 'interpolate_appearance' mode expects an already computed latent z
133 |       # vector, passed as the value of the dict key 'appearance_embedding'.
134 | elif mode == 'interpolate_appearance':
135 | assert use_appearance, 'use_appearance is set to False!'
136 | x_in = features['conditional_input']
137 | x_app_embedding = features['appearance_embedding']
138 | y = g_func(x_in, x_app_embedding)
139 | tf.logging.info('DBG: shape of y during prediction %s.' % str(y.shape))
140 | return tf.estimator.EstimatorSpec(mode=mode, predictions=y)
141 | else:
142 | raise ValueError('Unsupported mode: ' + mode)
143 |
144 | return model_fn
145 |
146 |
147 | def make_sample_grid_and_save(est, dataset_name, dataset_parent_dir, grid_dims,
148 | output_dir, cur_nimg):
149 | """Evaluate a fixed set of validation images and save output.
150 |
151 | Args:
152 |     est: tf.estimator.Estimator, TF estimator to run the predictions.
153 | dataset_name: basename for the validation tfrecord from which to load
154 | validation images.
155 | dataset_parent_dir: path to a directory containing the validation tfrecord.
156 | grid_dims: 2-tuple int for the grid size (1 unit = 1 image).
157 | output_dir: string, where to save image samples.
158 | cur_nimg: int, current number of images seen by training.
159 |
160 | Returns:
161 | None.
162 | """
163 | num_examples = grid_dims[0] * grid_dims[1]
164 | def input_val_fn():
165 | dict_inp = data.provide_data(
166 | dataset_name=dataset_name, parent_dir=dataset_parent_dir, subset='val',
167 | batch_size=1, crop_flag=True, crop_size=opts.train_resolution,
168 | seeds=[0], max_examples=num_examples,
169 | use_appearance=opts.use_appearance, shuffle=0)
170 | x_in = dict_inp['conditional_input']
171 | x_gt = dict_inp['expected_output'] # ground truth output
172 | x_app = dict_inp['peek_input']
173 | return x_in, x_gt, x_app
174 |
175 | def est_input_val_fn():
176 | x_in, _, x_app = input_val_fn()
177 | features = {'conditional_input': x_in, 'peek_input': x_app}
178 | return features
179 |
180 | images = [x for x in est.predict(est_input_val_fn)]
181 | images = np.array(images, 'f')
182 | images = images.reshape(grid_dims + images.shape[1:])
183 | utils.save_images(utils.to_png(utils.images_to_grid(images)), output_dir,
184 | cur_nimg)
185 |
186 |
187 | def visualize_image_sequence(est, dataset_name, dataset_parent_dir,
188 | input_sequence_name, app_base_path, output_dir):
189 | """Generates an image sequence as a video and stores it to disk."""
190 | batch_sz = opts.batch_size
191 | def input_seq_fn():
192 | dict_inp = data.provide_data(
193 | dataset_name=dataset_name, parent_dir=dataset_parent_dir,
194 | subset=input_sequence_name, batch_size=batch_sz, crop_flag=False,
195 | seeds=None, use_appearance=False, shuffle=0)
196 | x_in = dict_inp['conditional_input']
197 | return x_in
198 |
199 | # Compute appearance embedding only once and use it for all input frames.
200 | app_rgb_path = app_base_path + '_reference.png'
201 | app_rendered_path = app_base_path + '_color.png'
202 | app_depth_path = app_base_path + '_depth.png'
203 | app_sem_path = app_base_path + '_seg_rgb.png'
204 | x_app = _load_and_concatenate_image_channels(
205 | app_rgb_path, app_rendered_path, app_depth_path, app_sem_path)
206 | def seq_with_single_appearance_inp_fn():
207 | """input frames with a fixed latent appearance vector."""
208 | x_in_op = input_seq_fn()
209 | x_app_op = tf.convert_to_tensor(x_app)
210 | x_app_tiled_op = tf.tile(x_app_op, [tf.shape(x_in_op)[0], 1, 1, 1])
211 | return {'conditional_input': x_in_op,
212 | 'peek_input': x_app_tiled_op}
213 |
214 | images = [x for x in est.predict(seq_with_single_appearance_inp_fn)]
215 | for i, gen_img in enumerate(images):
216 | output_file_path = osp.join(output_dir, 'out_%04d.png' % i)
217 | print('Saving frame #%d to %s' % (i, output_file_path))
218 | with tf.gfile.Open(output_file_path, 'wb') as f:
219 | f.write(utils.to_png(gen_img))
220 |
221 |
222 | def train(dataset_name, dataset_parent_dir, load_pretrained_app_encoder,
223 | load_trained_fixed_app, save_samples_kimg=50):
224 |   """Main training procedure.
225 | 
226 |   The trained model is saved in opts.train_dir; nothing is returned.
227 | 
228 |   Args:
229 |     dataset_name: str, name ID of the tfrecord dataset to train on.
230 |     dataset_parent_dir: str, directory containing the tfrecord dataset.
231 |     load_pretrained_app_encoder: bool, warm-start the appearance encoder.
232 |     load_trained_fixed_app: bool, warm-start from a fixed-appearance checkpoint.
233 |     save_samples_kimg: int, period (in kimg) to save sample validation images.
234 |   """
235 | image_dir = osp.join(opts.train_dir, 'images') # to save validation images.
236 | tf.gfile.MakeDirs(image_dir)
237 | config = tf.estimator.RunConfig(
238 | save_summary_steps=(1 << 10) // opts.batch_size,
239 | save_checkpoints_steps=(save_samples_kimg << 10) // opts.batch_size,
240 | keep_checkpoint_max=5,
241 | log_step_count_steps=1 << 30)
242 | model_dir = opts.train_dir
243 | if (opts.use_appearance and load_trained_fixed_app and
244 | not tf.train.latest_checkpoint(model_dir)):
245 | tf.logging.warning('***** Loading resume_step from %s!' %
246 | opts.fixed_appearance_train_dir)
247 | resume_step = utils.load_global_step_from_checkpoint_dir(
248 | opts.fixed_appearance_train_dir)
249 | else:
250 | tf.logging.warning('***** Loading resume_step (if any) from %s!' %
251 | model_dir)
252 | resume_step = utils.load_global_step_from_checkpoint_dir(model_dir)
253 | if resume_step != 0:
254 | tf.logging.warning('****** Resuming training at %d!' % resume_step)
255 |
256 | model_fn = build_model_fn() # model function for TFEstimator.
257 |
258 | hooks = [utils.HookReport(1 << 12, opts.batch_size)]
259 |
260 | if opts.use_appearance and load_pretrained_app_encoder:
261 | tf.logging.warning('***** will warm-start from %s!' %
262 | opts.appearance_pretrain_dir)
263 | ws = tf.estimator.WarmStartSettings(
264 | ckpt_to_initialize_from=opts.appearance_pretrain_dir,
265 | vars_to_warm_start='appearance_net/.*')
266 | elif opts.use_appearance and load_trained_fixed_app:
267 | tf.logging.warning('****** finetuning will warm-start from %s!' %
268 | opts.fixed_appearance_train_dir)
269 | ws = tf.estimator.WarmStartSettings(
270 | ckpt_to_initialize_from=opts.fixed_appearance_train_dir,
271 | vars_to_warm_start='.*')
272 | else:
273 | ws = None
274 | tf.logging.warning('****** No warm-starting; using random initialization!')
275 |
276 | est = tf.estimator.Estimator(model_fn, model_dir, config, params={},
277 | warm_start_from=ws)
278 |
279 | for next_kimg in range(opts.save_samples_kimg, opts.total_kimg + 1,
280 | opts.save_samples_kimg):
281 | next_step = (next_kimg << 10) // opts.batch_size
282 | if opts.num_crops == -1: # use random crops
283 | crop_seeds = None
284 | else:
285 | crop_seeds = list(100 * np.arange(opts.num_crops))
286 | input_train_fn = functools.partial(
287 | data.provide_data, dataset_name=dataset_name,
288 | parent_dir=dataset_parent_dir, subset='train',
289 | batch_size=opts.batch_size, crop_flag=True,
290 | crop_size=opts.train_resolution, seeds=crop_seeds,
291 | use_appearance=opts.use_appearance)
292 | est.train(input_train_fn, max_steps=next_step, hooks=hooks)
293 | tf.logging.info('DBG: kimg=%d, cur_step=%d' % (next_kimg, next_step))
294 | tf.logging.info('DBG: Saving a validation grid image %06d to %s' % (
295 | next_kimg, image_dir))
296 | make_sample_grid_and_save(est, dataset_name, dataset_parent_dir, (3, 3),
297 | image_dir, next_kimg << 10)
298 |
299 |
300 | def _build_inference_estimator(model_dir):
301 | model_fn = build_model_fn()
302 | est = tf.estimator.Estimator(model_fn, model_dir)
303 | return est
304 |
305 |
306 | def evaluate_sequence(dataset_name, dataset_parent_dir, virtual_seq_name,
307 | app_base_path):
308 | output_dir = osp.join(opts.train_dir, 'seq_output_%s' % virtual_seq_name)
309 | tf.gfile.MakeDirs(output_dir)
310 | est = _build_inference_estimator(opts.train_dir)
311 | visualize_image_sequence(est, dataset_name, dataset_parent_dir,
312 | virtual_seq_name, app_base_path, output_dir)
313 |
314 |
315 | def evaluate_image_set(dataset_name, dataset_parent_dir, subset_suffix,
316 | output_dir=None, batch_size=6):
317 | if output_dir is None:
318 | output_dir = osp.join(opts.train_dir, 'validation_output_%s' % subset_suffix)
319 | tf.gfile.MakeDirs(output_dir)
320 | model_fn_old = build_model_fn()
321 | def model_fn_wrapper(features, labels, mode, params):
322 | del mode
323 | return model_fn_old(features, labels, 'eval_subset', params)
324 | model_dir = opts.train_dir
325 | est = tf.estimator.Estimator(model_fn_wrapper, model_dir)
326 | est_inp_fn = functools.partial(
327 | data.provide_data, dataset_name=dataset_name,
328 | parent_dir=dataset_parent_dir, subset=subset_suffix,
329 | batch_size=batch_size, use_appearance=opts.use_appearance, shuffle=0)
330 |
331 | print('Evaluating images for subset %s' % subset_suffix)
332 | images = [x for x in est.predict(est_inp_fn)]
333 | print('Evaluated %d images' % len(images))
334 | for i, img in enumerate(images):
335 | output_file_path = osp.join(output_dir, 'out_%04d.png' % i)
336 | print('Saving file #%d: %s' % (i, output_file_path))
337 | with tf.gfile.Open(output_file_path, 'wb') as f:
338 | f.write(utils.to_png(img))
339 |
340 |
341 | def _load_and_concatenate_image_channels(rgb_path=None, rendered_path=None,
342 | depth_path=None, seg_path=None,
343 | size_multiple=64):
344 | """Prepares a single input for the network."""
345 | if (rgb_path is None and rendered_path is None and depth_path is None and
346 | seg_path is None):
347 | raise ValueError('At least one of the inputs has to be not None')
348 |
349 | channels = ()
350 | if rgb_path is not None:
351 | rgb_img = np.array(Image.open(rgb_path)).astype(np.float32)
352 | rgb_img = utils.crop_to_multiple(rgb_img, size_multiple)
353 | channels = channels + (rgb_img,)
354 | if rendered_path is not None:
355 | rendered_img = np.array(Image.open(rendered_path)).astype(np.float32)
356 | if not opts.use_alpha:
357 | rendered_img = rendered_img[:, :, :3] # drop the alpha channel
358 | rendered_img = utils.crop_to_multiple(rendered_img, size_multiple)
359 | channels = channels + (rendered_img,)
360 | if depth_path is not None:
361 | depth_img = np.array(Image.open(depth_path))
362 | depth_img = depth_img.astype(np.float32)
363 | depth_img = utils.crop_to_multiple(depth_img[:, :, np.newaxis],
364 | size_multiple)
365 | channels = channels + (depth_img,)
366 | # depth_img = depth_img * (2.0 / 255) - 1.0
367 | if seg_path is not None:
368 | seg_img = np.array(Image.open(seg_path)).astype(np.float32)
369 | seg_img = utils.crop_to_multiple(seg_img, size_multiple)
370 | channels = channels + (seg_img,)
371 | # Concatenate and normalize channels
372 | img = np.dstack(channels)
373 | img = np.expand_dims(img, axis=0)
374 | img = img * (2.0 / 255) - 1.0
375 | return img
376 |
377 |
378 | def infer_dir(model_dir, input_dir, output_dir):
379 | tf.gfile.MakeDirs(output_dir)
380 |   est = _build_inference_estimator(model_dir)
381 |
382 | def read_image(base_path, is_appearance=False):
383 | if is_appearance:
384 | ref_img_path = base_path + '_reference.png'
385 | else:
386 | ref_img_path = None
387 | rendered_img_path = base_path + '_color.png'
388 | depth_img_path = base_path + '_depth.png'
389 | seg_img_path = base_path + '_seg_rgb.png'
390 | img = _load_and_concatenate_image_channels(
391 | rgb_path=ref_img_path, rendered_path=rendered_img_path,
392 | depth_path=depth_img_path, seg_path=seg_img_path)
393 | return img
394 |
395 | def get_inference_input_fn(base_path, app_base_path):
396 | x_in = read_image(base_path, False)
397 | x_app_in = read_image(app_base_path, True)
398 | def infer_input_fn():
399 | return {'conditional_input': x_in, 'peek_input': x_app_in}
400 | return infer_input_fn
401 |
402 | file_paths = sorted(glob.glob(osp.join(input_dir, '*_depth.png')))
403 | base_paths = [x[:-10] for x in file_paths] # remove the '_depth.png' suffix
404 | for inp_base_path in base_paths:
405 | est_inp_fn = get_inference_input_fn(inp_base_path, inp_base_path)
406 | img = next(est.predict(est_inp_fn))
407 | basename = osp.basename(inp_base_path)
408 | output_img_path = osp.join(output_dir, basename + '_out.png')
409 | print('Saving generated image to %s' % output_img_path)
410 | with tf.gfile.Open(output_img_path, 'wb') as f:
411 | f.write(utils.to_png(img))
412 |
413 |
414 | def joint_interpolation(model_dir, app_input_dir, st_app_basename,
415 | end_app_basename, camera_path_dir):
416 | """
417 | Interpolates both viewpoint and appearance between two input images.
418 | """
419 |   # Create output directory
420 | output_dir = osp.join(model_dir, 'joint_interpolation_out')
421 | tf.gfile.MakeDirs(output_dir)
422 |
423 | # Build estimator
424 | model_fn_old = build_model_fn()
425 | def model_fn_wrapper(features, labels, mode, params):
426 | del mode
427 | return model_fn_old(features, labels, 'interpolate_appearance', params)
428 | def appearance_model_fn(features, labels, mode, params):
429 | del mode
430 | return model_fn_old(features, labels, 'compute_appearance', params)
431 | config = tf.estimator.RunConfig(
432 | save_summary_steps=1000, save_checkpoints_steps=50000,
433 | keep_checkpoint_max=50, log_step_count_steps=1 << 30)
434 | model_dir = model_dir
435 | est = tf.estimator.Estimator(model_fn_wrapper, model_dir, config, params={})
436 | est_app = tf.estimator.Estimator(appearance_model_fn, model_dir, config,
437 | params={})
438 |
439 | # Compute appearance embeddings for the two input appearance images.
440 | app_inputs = []
441 | for app_basename in [st_app_basename, end_app_basename]:
442 | app_rgb_path = osp.join(app_input_dir, app_basename + '_reference.png')
443 | app_rendered_path = osp.join(app_input_dir, app_basename + '_color.png')
444 | app_depth_path = osp.join(app_input_dir, app_basename + '_depth.png')
445 | app_seg_path = osp.join(app_input_dir, app_basename + '_seg_rgb.png')
446 | app_in = _load_and_concatenate_image_channels(
447 | rgb_path=app_rgb_path, rendered_path=app_rendered_path,
448 | depth_path=app_depth_path, seg_path=app_seg_path)
449 | # app_inputs.append(tf.convert_to_tensor(app_in))
450 | app_inputs.append(app_in)
451 |
452 | embedding1 = next(est_app.predict(
453 | lambda: {'peek_input': app_inputs[0]}))
454 | embedding1 = np.expand_dims(embedding1, axis=0)
455 | embedding2 = next(est_app.predict(
456 | lambda: {'peek_input': app_inputs[1]}))
457 | embedding2 = np.expand_dims(embedding2, axis=0)
458 |
459 | file_paths = sorted(glob.glob(osp.join(camera_path_dir, '*_depth.png')))
460 | base_paths = [x[:-10] for x in file_paths] # remove the '_depth.png' suffix
461 |
462 | # Compute interpolated appearance embeddings
463 | num_interpolations = len(base_paths)
464 | interpolated_embeddings = []
465 | delta_vec = (embedding2 - embedding1) / (num_interpolations - 1)
466 | for delta_iter in range(num_interpolations):
467 | x_app_embedding = embedding1 + delta_iter * delta_vec
468 | interpolated_embeddings.append(x_app_embedding)
469 |
470 | # Generate and save interpolated images
471 | for frame_idx, embedding in enumerate(interpolated_embeddings):
472 | # Read in input frame
473 | frame_render_path = osp.join(base_paths[frame_idx] + '_color.png')
474 | frame_depth_path = osp.join(base_paths[frame_idx] + '_depth.png')
475 | frame_seg_path = osp.join(base_paths[frame_idx] + '_seg_rgb.png')
476 | x_in = _load_and_concatenate_image_channels(
477 | rgb_path=None, rendered_path=frame_render_path,
478 | depth_path=frame_depth_path, seg_path=frame_seg_path)
479 |
480 | img = next(est.predict(
481 | lambda: {'conditional_input': tf.convert_to_tensor(x_in),
482 | 'appearance_embedding': tf.convert_to_tensor(embedding)}))
483 | output_img_name = '%s_%s_%03d.png' % (st_app_basename, end_app_basename,
484 | frame_idx)
485 | output_img_path = osp.join(output_dir, output_img_name)
486 | print('Saving interpolated image to %s' % output_img_path)
487 | with tf.gfile.Open(output_img_path, 'wb') as f:
488 | f.write(utils.to_png(img))
489 |
490 |
491 | def interpolate_appearance(model_dir, input_dir, target_img_basename,
492 | appearance_img1_basename, appearance_img2_basename):
493 |   # Create output directory
494 | output_dir = osp.join(model_dir, 'interpolate_appearance_out')
495 | tf.gfile.MakeDirs(output_dir)
496 |
497 | # Build estimator
498 | model_fn_old = build_model_fn()
499 | def model_fn_wrapper(features, labels, mode, params):
500 | del mode
501 | return model_fn_old(features, labels, 'interpolate_appearance', params)
502 | def appearance_model_fn(features, labels, mode, params):
503 | del mode
504 | return model_fn_old(features, labels, 'compute_appearance', params)
505 | config = tf.estimator.RunConfig(
506 | save_summary_steps=1000, save_checkpoints_steps=50000,
507 | keep_checkpoint_max=50, log_step_count_steps=1 << 30)
508 | model_dir = model_dir
509 | est = tf.estimator.Estimator(model_fn_wrapper, model_dir, config, params={})
510 | est_app = tf.estimator.Estimator(appearance_model_fn, model_dir, config,
511 | params={})
512 |
513 | # Compute appearance embeddings for the two input appearance images.
514 | app_inputs = []
515 | for app_basename in [appearance_img1_basename, appearance_img2_basename]:
516 | app_rgb_path = osp.join(input_dir, app_basename + '_reference.png')
517 | app_rendered_path = osp.join(input_dir, app_basename + '_color.png')
518 | app_depth_path = osp.join(input_dir, app_basename + '_depth.png')
519 | app_seg_path = osp.join(input_dir, app_basename + '_seg_rgb.png')
520 | app_in = _load_and_concatenate_image_channels(
521 | rgb_path=app_rgb_path, rendered_path=app_rendered_path,
522 | depth_path=app_depth_path, seg_path=app_seg_path)
523 | # app_inputs.append(tf.convert_to_tensor(app_in))
524 | app_inputs.append(app_in)
525 |
526 | embedding1 = next(est_app.predict(
527 | lambda: {'peek_input': app_inputs[0]}))
528 | embedding2 = next(est_app.predict(
529 | lambda: {'peek_input': app_inputs[1]}))
530 | embedding1 = np.expand_dims(embedding1, axis=0)
531 | embedding2 = np.expand_dims(embedding2, axis=0)
532 |
533 | # Compute interpolated appearance embeddings
534 | num_interpolations = 10
535 | interpolated_embeddings = []
536 | delta_vec = (embedding2 - embedding1) / num_interpolations
537 | for delta_iter in range(num_interpolations + 1):
538 | x_app_embedding = embedding1 + delta_iter * delta_vec
539 | interpolated_embeddings.append(x_app_embedding)
540 |
541 | # Read in the generator input for the target image to render
542 | rendered_img_path = osp.join(input_dir, target_img_basename + '_color.png')
543 | depth_img_path = osp.join(input_dir, target_img_basename + '_depth.png')
544 | seg_img_path = osp.join(input_dir, target_img_basename + '_seg_rgb.png')
545 | x_in = _load_and_concatenate_image_channels(
546 | rgb_path=None, rendered_path=rendered_img_path,
547 | depth_path=depth_img_path, seg_path=seg_img_path)
548 |
549 | # Generate and save interpolated images
550 | for interpolate_iter, embedding in enumerate(interpolated_embeddings):
551 | img = next(est.predict(
552 | lambda: {'conditional_input': tf.convert_to_tensor(x_in),
553 | 'appearance_embedding': tf.convert_to_tensor(embedding)}))
554 | output_img_name = 'interpolate_%s_%s_%s_%03d.png' % (
555 | target_img_basename, appearance_img1_basename, appearance_img2_basename,
556 | interpolate_iter)
557 | output_img_path = osp.join(output_dir, output_img_name)
558 | print('Saving interpolated image to %s' % output_img_path)
559 | with tf.gfile.Open(output_img_path, 'wb') as f:
560 | f.write(utils.to_png(img))
561 |
562 |
563 | def main(argv):
564 | del argv
565 | configs_str = options.list_options()
566 | tf.gfile.MakeDirs(opts.train_dir)
567 | with tf.gfile.Open(osp.join(opts.train_dir, 'configs.txt'), 'wb') as f:
568 | f.write(configs_str)
569 | tf.logging.info('Local configs\n%s' % configs_str)
570 |
571 | if opts.run_mode == 'train':
572 | dataset_name = opts.dataset_name
573 | dataset_parent_dir = opts.dataset_parent_dir
574 | load_pretrained_app_encoder = opts.load_pretrained_app_encoder
575 | load_trained_fixed_app = opts.load_from_another_ckpt
576 | batch_size = opts.batch_size
577 | train(dataset_name, dataset_parent_dir, load_pretrained_app_encoder,
578 | load_trained_fixed_app)
579 | elif opts.run_mode == 'eval': # generate a camera path output sequence from TFRecord inputs.
580 | dataset_name = opts.dataset_name
581 | dataset_parent_dir = opts.dataset_parent_dir
582 | virtual_seq_name = opts.virtual_seq_name
583 | inp_app_img_base_path = opts.inp_app_img_base_path
584 | evaluate_sequence(dataset_name, dataset_parent_dir, virtual_seq_name,
585 | inp_app_img_base_path)
586 | elif opts.run_mode == 'eval_subset': # generate output for validation set (encoded as TFRecords)
587 | dataset_name = opts.dataset_name
588 | dataset_parent_dir = opts.dataset_parent_dir
589 | virtual_seq_name = opts.virtual_seq_name
590 | evaluate_image_set(dataset_name, dataset_parent_dir, virtual_seq_name,
591 | opts.output_validation_dir, opts.batch_size)
592 | elif opts.run_mode == 'eval_dir': # evaluate output for a directory with input images
593 | input_dir = opts.inference_input_path
594 | output_dir = opts.inference_output_dir
595 | model_dir = opts.train_dir
596 | infer_dir(model_dir, input_dir, output_dir)
597 | elif opts.run_mode == 'interpolate_appearance': # interpolate appearance only between two images.
598 | model_dir = opts.train_dir
599 | input_dir = opts.inference_input_path
600 | target_img_basename = opts.target_img_basename
601 | app_img1_basename = opts.appearance_img1_basename
602 | app_img2_basename = opts.appearance_img2_basename
603 | interpolate_appearance(model_dir, input_dir, target_img_basename,
604 | app_img1_basename, app_img2_basename)
605 | elif opts.run_mode == 'joint_interpolation': # interpolate viewpoint and appearance between two images
606 | model_dir = opts.train_dir
607 | app_input_dir = opts.inference_input_path
608 | st_app_basename = opts.appearance_img1_basename
609 | end_app_basename = opts.appearance_img2_basename
610 | frames_dir = opts.frames_dir
611 | joint_interpolation(model_dir, app_input_dir, st_app_basename,
612 | end_app_basename, frames_dir)
613 | else:
614 | raise ValueError('Unsupported --run_mode %s' % opts.run_mode)
615 |
616 |
617 | if __name__ == '__main__':
618 | app.run(main)
619 |
--------------------------------------------------------------------------------
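interpolate_appearance() and joint_interpolation() above both walk linearly from one appearance embedding to another and feed each intermediate vector back to the generator. A minimal NumPy sketch of that interpolation schedule (illustrative only; shapes are example values matching the default app_vector_size of 8):

import numpy as np

def interpolate_embeddings(z1, z2, num_steps=10):
  # Equal-sized steps from z1 to z2 inclusive, as in interpolate_appearance().
  delta = (z2 - z1) / num_steps
  return [z1 + i * delta for i in range(num_steps + 1)]

z1 = np.zeros((1, 8), np.float32)
z2 = np.ones((1, 8), np.float32)
steps = interpolate_embeddings(z1, z2)
print(len(steps), steps[0][0, 0], steps[-1][0, 0])  # 11 0.0 1.0
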
/options.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from absl import flags
17 | import numpy as np
18 |
19 | FLAGS = flags.FLAGS
20 |
21 | # ------------------------------------------------------------------------------
22 | # Train flags
23 | # ------------------------------------------------------------------------------
24 |
25 | # Dataset, model directory and run mode
26 | flags.DEFINE_string('train_dir', '/tmp/neural_rendering',
27 | 'Directory for model training.')
28 | flags.DEFINE_string('dataset_name', 'sanmarco9k', 'name ID for a dataset.')
29 | flags.DEFINE_string(
30 | 'dataset_parent_dir', '',
31 | 'Directory containing generated tfrecord dataset.')
32 | flags.DEFINE_string('run_mode', 'train', "{'train', 'eval', 'eval_subset', 'eval_dir', 'interpolate_appearance', 'joint_interpolation'}")
33 | flags.DEFINE_string('imageset_dir', None, 'Directory containing trainset '
34 | 'images for appearance pretraining.')
35 | flags.DEFINE_string('metadata_output_dir', None, 'Directory to save pickled '
36 | 'pairwise distance matrix for appearance pretraining.')
37 | flags.DEFINE_integer('save_samples_kimg', 50, 'kimg cycle to save sample '
38 |                      'validation output during training.')
39 |
40 | # Network inputs/outputs
41 | flags.DEFINE_boolean('use_depth', True, 'Add depth image to the deep buffer.')
42 | flags.DEFINE_boolean('use_alpha', False,
43 | 'Add alpha channel to the deep buffer.')
44 | flags.DEFINE_boolean('use_semantic', True,
45 | 'Add semantic map to the deep buffer.')
46 | flags.DEFINE_boolean('use_appearance', True,
47 | 'Capture appearance from an input real image.')
48 | flags.DEFINE_integer('deep_buffer_nc', 7,
49 | 'Number of input channels in the deep buffer.')
50 | flags.DEFINE_integer('appearance_nc', 10,
51 | 'Number of input channels to the appearance encoder.')
52 | flags.DEFINE_integer('output_nc', 3,
53 | 'Number of channels for the generated image.')
54 |
55 | # Staged training flags
56 | flags.DEFINE_string(
57 | 'vgg16_path', './vgg16_weights/vgg16.npy',
58 | 'path to a *.npy file with vgg16 pretrained weights')
59 | flags.DEFINE_boolean('load_pretrained_app_encoder', False,
60 | 'Warmstart appearance encoder with pretrained weights.')
61 | flags.DEFINE_string('appearance_pretrain_dir', '',
62 | 'Model dir for the pretrained appearance encoder.')
63 | flags.DEFINE_boolean('train_app_encoder', False, 'Whether to make the weights '
64 | 'for the appearance encoder trainable or not.')
65 | flags.DEFINE_boolean(
66 | 'load_from_another_ckpt', False, 'Load weights from another trained model, '
67 | 'e.g load model trained with a fixed appearance encoder.')
68 | flags.DEFINE_string('fixed_appearance_train_dir', '',
69 | 'Model dir for training G with a fixed appearance net.')
70 |
71 | # -----------------------------------------------------------------------------
72 |
73 | # More hparams
74 | flags.DEFINE_integer('train_resolution', 256,
75 | 'Crop train images to this resolution.')
76 | flags.DEFINE_float('d_lr', 0.001, 'Learning rate for the discriminator.')
77 | flags.DEFINE_float('g_lr', 0.001, 'Learning rate for the generator.')
78 | flags.DEFINE_float('ez_lr', 0.0001, 'Learning rate for appearance encoder.')
79 | flags.DEFINE_integer('batch_size', 8, 'Batch size for training.')
80 | flags.DEFINE_boolean('use_scaling', True, "use He's scaling.")
81 | flags.DEFINE_integer('num_crops', 30, 'num crops from train images '
82 | '(use -1 for random crops).')
83 | flags.DEFINE_integer('app_vector_size', 8, 'Size of latent appearance vector.')
84 | flags.DEFINE_integer('total_kimg', 20000,
85 | 'Max number (in kilo) of training images for training.')
86 | flags.DEFINE_float('adam_beta1', 0.0, 'beta1 for adam optimizer.')
87 | flags.DEFINE_float('adam_beta2', 0.99, 'beta2 for adam optimizer.')
88 |
89 | # Loss weights
90 | flags.DEFINE_float('w_loss_vgg', 0.3, 'VGG loss weight.')
91 | flags.DEFINE_float('w_loss_feat', 10., 'Feature loss weight (from pix2pixHD).')
92 | flags.DEFINE_float('w_loss_l1', 50., 'L1 loss weight.')
93 | flags.DEFINE_float('w_loss_z_recon', 10., 'Z reconstruction loss weight.')
94 | flags.DEFINE_float('w_loss_gan', 1., 'Adversarial loss weight.')
95 | flags.DEFINE_float('w_loss_z_gan', 1., 'Z adversarial loss weight.')
96 | flags.DEFINE_float('w_loss_kl', 0.01, 'KL divergence weight.')
97 | flags.DEFINE_float('w_loss_l2_reg', 0.01, 'Weight for L2 regression on Z.')
98 |
99 | # -----------------------------------------------------------------------------
100 |
101 | # Architecture and training setup
102 | flags.DEFINE_string('arch_type', 'pggan',
103 | 'Architecture type: {pggan, pix2pixhd}.')
104 | flags.DEFINE_string('training_pipeline', 'staged',
105 |                     'Training pipeline type: {staged, bicycle_gan, drit}.')
106 | flags.DEFINE_integer('g_nf', 64,
107 | 'num filters in the first/last layers of U-net.')
108 | flags.DEFINE_boolean('concatenate_skip_layers', True,
109 | 'Use concatenation for skip connections.')
110 |
111 | ## if arch_type == 'pggan':
112 | flags.DEFINE_integer('pggan_n_blocks', 5,
113 | 'Num blocks for the pggan architecture.')
114 | ## if arch_type == 'pix2pixhd':
115 | flags.DEFINE_integer('p2p_n_downsamples', 3,
116 | 'Num downsamples for the pix2pixHD architecture.')
117 | flags.DEFINE_integer('p2p_n_resblocks', 4, 'Num residual blocks at the '
118 | 'end/start of the pix2pixHD encoder/decoder.')
119 | ## if training_pipeline == 'drit':
120 | flags.DEFINE_boolean('use_concat', True, '"concat" mode from DRIT.')
121 | flags.DEFINE_boolean('normalize_drit_Ez', True, 'Add pixelnorm layers to the '
122 | 'appearance encoder.')
123 | flags.DEFINE_boolean('concat_z_in_all_layers', True, 'Inject z at each '
124 | 'upsampling layer in the decoder (only for DRIT baseline)')
125 | flags.DEFINE_string('inject_z', 'to_bottleneck', 'Method for injecting z; '
126 | 'one of {to_encoder, to_bottleneck}.')
127 | flags.DEFINE_boolean('use_vgg_loss', True, 'Use VGG loss vs. L1 reconstruction loss.')
128 |
129 | # ------------------------------------------------------------------------------
130 | # Inference flags
131 | # ------------------------------------------------------------------------------
132 |
133 | flags.DEFINE_string('inference_input_path', '',
134 | 'Parent directory for input images at inference time.')
135 | flags.DEFINE_string('inference_output_dir', '', 'Output path for inference')
136 | flags.DEFINE_string('target_img_basename', '',
137 | 'basename of target image to render for interpolation')
138 | flags.DEFINE_string('virtual_seq_name', 'full_camera_path',
139 | 'name for the virtual camera path suffix for the TFRecord.')
140 | flags.DEFINE_string('inp_app_img_base_path', '',
141 | 'base path for the input appearance image for camera paths')
142 |
143 | flags.DEFINE_string('appearance_img1_basename', '',
144 | 'basename of the first appearance image for interpolation')
145 | flags.DEFINE_string('appearance_img2_basename', '',
146 |                     'basename of the second appearance image for interpolation')
147 | flags.DEFINE_list('input_basenames', [], 'input basenames for inference')
148 | flags.DEFINE_list('input_app_basenames', [], 'input appearance basenames for '
149 | 'inference')
150 | flags.DEFINE_string('frames_dir', '',
151 | 'Folder with input frames to a camera path')
152 | flags.DEFINE_string('output_validation_dir', '',
153 |                     'Output directory for storing validation results.')
154 | flags.DEFINE_string('input_rendered', '',
155 | 'input rendered image name for inference')
156 | flags.DEFINE_string('input_depth', '', 'input depth image name for inference')
157 | flags.DEFINE_string('input_seg', '',
158 | 'input segmentation mask image name for inference')
159 | flags.DEFINE_string('input_app_rgb', '',
160 | 'input appearance rgb image name for inference')
161 | flags.DEFINE_string('input_app_rendered', '',
162 | 'input appearance rendered image name for inference')
163 | flags.DEFINE_string('input_app_depth', '',
164 | 'input appearance depth image name for inference')
165 | flags.DEFINE_string('input_app_seg', '',
166 |                     'input appearance segmentation mask image name for '
167 | 'inference')
168 | flags.DEFINE_string('output_img_name', '',
169 | '[OPTIONAL] output image name for inference')
170 |
171 | # -----------------------------------------------------------------------------
172 | # Some validation and assertions
173 | # -----------------------------------------------------------------------------
174 |
175 | def validate_options():
176 |   if FLAGS.training_pipeline == 'drit':
177 | assert FLAGS.use_appearance, 'DRIT pipeline requires --use_appearance'
178 | assert not (
179 |       FLAGS.load_pretrained_app_encoder and FLAGS.load_from_another_ckpt), (
180 | 'You cannot load weights for the appearance encoder from two different '
181 | 'checkpoints!')
182 | if not FLAGS.use_appearance:
183 | print('**Warning: setting --app_vector_size to 0 since '
184 | '--use_appearance=False!')
185 | FLAGS.set_default('app_vector_size', 0)
186 |
187 | # -----------------------------------------------------------------------------
188 | # Print all options
189 | # -----------------------------------------------------------------------------
190 |
191 | def list_options():
192 | configs = ('# Run flags/options from options.py:\n'
193 | '# ----------------------------------\n')
194 | configs += ('## Train flags:\n'
195 | '## ------------\n')
196 | configs += 'train_dir = %s\n' % FLAGS.train_dir
197 | configs += 'dataset_name = %s\n' % FLAGS.dataset_name
198 | configs += 'dataset_parent_dir = %s\n' % FLAGS.dataset_parent_dir
199 | configs += 'run_mode = %s\n' % FLAGS.run_mode
200 | configs += 'save_samples_kimg = %d\n' % FLAGS.save_samples_kimg
201 | configs += '\n# --------------------------------------------------------\n\n'
202 |
203 | configs += ('## Network inputs and outputs:\n'
204 | '## ---------------------------\n')
205 | configs += 'use_depth = %s\n' % str(FLAGS.use_depth)
206 | configs += 'use_alpha = %s\n' % str(FLAGS.use_alpha)
207 | configs += 'use_semantic = %s\n' % str(FLAGS.use_semantic)
208 | configs += 'use_appearance = %s\n' % str(FLAGS.use_appearance)
209 | configs += 'deep_buffer_nc = %d\n' % FLAGS.deep_buffer_nc
210 | configs += 'appearance_nc = %d\n' % FLAGS.appearance_nc
211 | configs += 'output_nc = %d\n' % FLAGS.output_nc
212 | configs += 'train_resolution = %d\n' % FLAGS.train_resolution
213 | configs += '\n# --------------------------------------------------------\n\n'
214 |
215 | configs += ('## Staged training flags:\n'
216 | '## ----------------------\n')
217 | configs += 'load_pretrained_app_encoder = %s\n' % str(
218 | FLAGS.load_pretrained_app_encoder)
219 | configs += 'appearance_pretrain_dir = %s\n' % FLAGS.appearance_pretrain_dir
220 | configs += 'train_app_encoder = %s\n' % str(FLAGS.train_app_encoder)
221 | configs += 'load_from_another_ckpt = %s\n' % str(FLAGS.load_from_another_ckpt)
222 | configs += 'fixed_appearance_train_dir = %s\n' % str(
223 | FLAGS.fixed_appearance_train_dir)
224 | configs += '\n# --------------------------------------------------------\n\n'
225 |
226 | configs += ('## More hyper-parameters:\n'
227 | '## ----------------------\n')
228 | configs += 'd_lr = %f\n' % FLAGS.d_lr
229 | configs += 'g_lr = %f\n' % FLAGS.g_lr
230 | configs += 'ez_lr = %f\n' % FLAGS.ez_lr
231 | configs += 'batch_size = %d\n' % FLAGS.batch_size
232 | configs += 'use_scaling = %s\n' % str(FLAGS.use_scaling)
233 | configs += 'num_crops = %d\n' % FLAGS.num_crops
234 | configs += 'app_vector_size = %d\n' % FLAGS.app_vector_size
235 | configs += 'total_kimg = %d\n' % FLAGS.total_kimg
236 | configs += 'adam_beta1 = %f\n' % FLAGS.adam_beta1
237 | configs += 'adam_beta2 = %f\n' % FLAGS.adam_beta2
238 | configs += '\n# --------------------------------------------------------\n\n'
239 |
240 | configs += ('## Loss weights:\n'
241 | '## -------------\n')
242 | configs += 'w_loss_vgg = %f\n' % FLAGS.w_loss_vgg
243 | configs += 'w_loss_feat = %f\n' % FLAGS.w_loss_feat
244 | configs += 'w_loss_l1 = %f\n' % FLAGS.w_loss_l1
245 | configs += 'w_loss_z_recon = %f\n' % FLAGS.w_loss_z_recon
246 | configs += 'w_loss_gan = %f\n' % FLAGS.w_loss_gan
247 | configs += 'w_loss_z_gan = %f\n' % FLAGS.w_loss_z_gan
248 | configs += 'w_loss_kl = %f\n' % FLAGS.w_loss_kl
249 | configs += 'w_loss_l2_reg = %f\n' % FLAGS.w_loss_l2_reg
250 | configs += '\n# --------------------------------------------------------\n\n'
251 |
252 | configs += ('## Architecture and training setup:\n'
253 | '## --------------------------------\n')
254 | configs += 'arch_type = %s\n' % FLAGS.arch_type
255 | configs += 'training_pipeline = %s\n' % FLAGS.training_pipeline
256 | configs += 'g_nf = %d\n' % FLAGS.g_nf
257 | configs += 'concatenate_skip_layers = %s\n' % str(
258 | FLAGS.concatenate_skip_layers)
259 | configs += 'p2p_n_downsamples = %d\n' % FLAGS.p2p_n_downsamples
260 | configs += 'p2p_n_resblocks = %d\n' % FLAGS.p2p_n_resblocks
261 | configs += 'use_concat = %s\n' % str(FLAGS.use_concat)
262 | configs += 'normalize_drit_Ez = %s\n' % str(FLAGS.normalize_drit_Ez)
263 | configs += 'inject_z = %s\n' % FLAGS.inject_z
264 | configs += 'concat_z_in_all_layers = %s\n' % str(FLAGS.concat_z_in_all_layers)
265 | configs += 'use_vgg_loss = %s\n' % str(FLAGS.use_vgg_loss)
266 | configs += '\n# --------------------------------------------------------\n\n'
267 |
268 | return configs
269 |
--------------------------------------------------------------------------------
/pretrain_appearance.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from PIL import Image
16 | from absl import app
17 | from absl import flags
18 | from options import FLAGS as opts
19 | import glob
20 | import networks
21 | import numpy as np
22 | import os
23 | import os.path as osp
24 | import pickle
25 | import style_loss
26 | import tensorflow as tf
27 | import utils
28 |
29 |
30 | def _load_and_concatenate_image_channels(
31 | rgb_path=None, rendered_path=None, depth_path=None, seg_path=None,
32 | crop_size=512):
33 | if (rgb_path is None and rendered_path is None and depth_path is None and
34 | seg_path is None):
35 | raise ValueError('At least one of the inputs has to be not None')
36 |
37 | channels = ()
38 | if rgb_path is not None:
39 | rgb_img = np.array(Image.open(rgb_path)).astype(np.float32)
40 | rgb_img = utils.get_central_crop(rgb_img, crop_size, crop_size)
41 | channels = channels + (rgb_img,)
42 | if rendered_path is not None:
43 | rendered_img = np.array(Image.open(rendered_path)).astype(np.float32)
44 | rendered_img = utils.get_central_crop(rendered_img, crop_size, crop_size)
45 | if not opts.use_alpha:
46 |       rendered_img = rendered_img[:, :, :3]  # drop the alpha channel
47 | channels = channels + (rendered_img,)
48 | if depth_path is not None:
49 | depth_img = np.array(Image.open(depth_path))
50 | depth_img = depth_img.astype(np.float32)
51 | depth_img = utils.get_central_crop(depth_img, crop_size, crop_size)
52 | channels = channels + (depth_img,)
53 | if seg_path is not None:
54 | seg_img = np.array(Image.open(seg_path)).astype(np.float32)
55 | channels = channels + (seg_img,)
56 | # Concatenate and normalize channels
57 | img = np.dstack(channels)
58 | img = img * (2.0 / 255) - 1.0
59 | return img
60 |
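As a quick check of the normalization above, the affine map x * (2.0 / 255) - 1.0 takes uint8 pixel values in [0, 255] to [-1, 1]; a minimal numpy sketch (illustrative, not part of the file):

    import numpy as np

    pixels = np.array([0.0, 127.5, 255.0], dtype=np.float32)
    normalized = pixels * (2.0 / 255) - 1.0
    print(normalized)  # -> [-1.  0.  1.]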
61 |
62 | def read_single_appearance_input(rgb_img_path):
63 | base_path = rgb_img_path[:-14] # remove the '_reference.png' suffix
64 | rendered_img_path = base_path + '_color.png'
65 | depth_img_path = base_path + '_depth.png'
66 | semantic_img_path = base_path + '_seg_rgb.png'
67 | network_input_img = _load_and_concatenate_image_channels(
68 | rgb_img_path, rendered_img_path, depth_img_path, semantic_img_path,
69 | crop_size=opts.train_resolution)
70 | return network_input_img
71 |
72 |
73 | def get_triplet_input_fn(dataset_path, dist_file_path=None, k_max_nearest=5,
74 | k_max_farthest=13):
75 | input_images_pattern = osp.join(dataset_path, '*_reference.png')
76 | filenames = sorted(glob.glob(input_images_pattern))
77 | print('DBG: obtained %d input filenames for triplet inputs' % len(filenames))
78 | print('DBG: Computing pairwise style distances:')
79 | if dist_file_path is not None and osp.exists(dist_file_path):
80 | print('*** Loading distance matrix from %s' % dist_file_path)
81 | with open(dist_file_path, 'rb') as f:
82 | dist_matrix = pickle.load(f)['dist_matrix']
83 | print('loaded a dist_matrix of shape: %s' % str(dist_matrix.shape))
84 | else:
85 | dist_matrix = style_loss.compute_pairwise_style_loss_v2(filenames)
86 | dist_dict = {'dist_matrix': dist_matrix}
87 | print('Saving distance matrix to %s' % dist_file_path)
88 | with open(dist_file_path, 'wb') as f:
89 | pickle.dump(dist_dict, f)
90 |
91 | # Sort neighbors for each anchor image
92 | num_imgs = len(dist_matrix)
93 | sorted_neighbors = [np.argsort(dist_matrix[ii, :]) for ii in range(num_imgs)]
94 |
95 | def triplet_input_fn(anchor_idx):
96 | # start from 1 to avoid getting the same image as its own neighbor
97 | positive_neighbor_idx = np.random.randint(1, k_max_nearest + 1)
98 | negative_neighbor_idx = num_imgs - 1 - np.random.randint(0, k_max_farthest)
99 | positive_img_idx = sorted_neighbors[anchor_idx][positive_neighbor_idx]
100 | negative_img_idx = sorted_neighbors[anchor_idx][negative_neighbor_idx]
101 | # Read anchor image
102 |     anchor_rgb_path = filenames[anchor_idx]  # glob already returns full paths
103 | anchor_input = read_single_appearance_input(anchor_rgb_path)
104 | # Read positive image
105 |     positive_rgb_path = filenames[positive_img_idx]
106 | positive_input = read_single_appearance_input(positive_rgb_path)
107 | # Read negative image
108 |     negative_rgb_path = filenames[negative_img_idx]
109 | negative_input = read_single_appearance_input(negative_rgb_path)
110 | # Return triplet
111 | return anchor_input, positive_input, negative_input
112 |
113 | return triplet_input_fn
114 |
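To make the sampling rule above concrete: positives are drawn from the k_max_nearest closest neighbors (skipping index 0, which is the anchor itself), and negatives from the k_max_farthest most distant ones. A small self-contained sketch with a hypothetical 6x6 style-distance matrix:

    import numpy as np

    np.random.seed(0)
    num_imgs, k_max_nearest, k_max_farthest = 6, 2, 2
    # Hypothetical symmetric style-distance matrix with a zero diagonal.
    dist_matrix = np.random.rand(num_imgs, num_imgs)
    dist_matrix = (dist_matrix + dist_matrix.T) / 2
    np.fill_diagonal(dist_matrix, 0.)

    sorted_neighbors = [np.argsort(dist_matrix[ii, :]) for ii in range(num_imgs)]
    anchor_idx = 2
    positive_idx = sorted_neighbors[anchor_idx][np.random.randint(1, k_max_nearest + 1)]
    negative_idx = sorted_neighbors[anchor_idx][
        num_imgs - 1 - np.random.randint(0, k_max_farthest)]
    print(anchor_idx, positive_idx, negative_idx)  # anchor, a near image, a far image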
115 |
116 | def get_tf_triplet_dataset_iter(
117 | dataset_path, trainset_size, dist_file_path, batch_size=4,
118 | deterministic_flag=False, shuffle_buf_size=128, repeat_flag=True):
119 | # Create a dataset of anchor image indices.
120 | idx_dataset = tf.data.Dataset.range(trainset_size)
121 | # Create a mapper function from anchor idx to triplet images.
122 | triplet_mapper = lambda idx: tuple(tf.py_func(
123 | get_triplet_input_fn(dataset_path, dist_file_path), [idx],
124 | [tf.float32, tf.float32, tf.float32]))
125 | # Convert triplet to a dictionary for the estimator input format.
126 | triplet_to_dict_mapper = lambda anchor, pos, neg: {
127 | 'anchor_img': anchor, 'positive_img': pos, 'negative_img': neg}
128 | if repeat_flag:
129 | idx_dataset = idx_dataset.repeat() # Repeat indefinitely.
130 | if not deterministic_flag:
131 | idx_dataset = idx_dataset.shuffle(shuffle_buf_size)
132 | triplet_dataset = idx_dataset.map(
133 | triplet_mapper, num_parallel_calls=max(4, batch_size // 4))
134 | triplet_dataset = triplet_dataset.map(
135 | triplet_to_dict_mapper, num_parallel_calls=max(4, batch_size // 4))
136 | else:
137 | triplet_dataset = idx_dataset.map(triplet_mapper, num_parallel_calls=None)
138 | triplet_dataset = triplet_dataset.map(triplet_to_dict_mapper,
139 | num_parallel_calls=None)
140 | triplet_dataset = triplet_dataset.batch(batch_size)
141 | if not deterministic_flag:
142 | triplet_dataset = triplet_dataset.prefetch(4) # Prefetch a few batches.
143 | return triplet_dataset.make_one_shot_iterator()
144 |
145 |
146 | def build_model_fn(batch_size, lr_app_pretrain=0.0001, adam_beta1=0.0,
147 | adam_beta2=0.99):
148 | def model_fn(features, labels, mode, params):
149 | del labels, params
150 |
151 | step = tf.train.get_global_step()
152 | app_func = networks.DRITAppearanceEncoderConcat(
153 | 'appearance_net', opts.appearance_nc, opts.normalize_drit_Ez)
154 |
155 | if mode == tf.estimator.ModeKeys.TRAIN:
156 | op_increment_step = tf.assign_add(step, 1)
157 | with tf.name_scope('Appearance_Loss'):
158 | anchor_img = features['anchor_img']
159 | positive_img = features['positive_img']
160 | negative_img = features['negative_img']
161 | # Compute embeddings (each of shape [batch_sz, 1, 1, app_vector_sz])
162 | z_anchor, _, _ = app_func(anchor_img)
163 | z_pos, _, _ = app_func(positive_img)
164 | z_neg, _, _ = app_func(negative_img)
165 | # Squeeze into shape of [batch_sz x vec_sz]
166 | anchor_embedding = tf.squeeze(z_anchor, axis=[1, 2], name='z_anchor')
167 | positive_embedding = tf.squeeze(z_pos, axis=[1, 2])
168 | negative_embedding = tf.squeeze(z_neg, axis=[1, 2])
169 | # Compute triplet loss
170 | margin = 0.1
171 | anchor_positive_dist = tf.reduce_sum(
172 | tf.square(anchor_embedding - positive_embedding), axis=1)
173 | anchor_negative_dist = tf.reduce_sum(
174 | tf.square(anchor_embedding - negative_embedding), axis=1)
175 | triplet_loss = anchor_positive_dist - anchor_negative_dist + margin
176 | triplet_loss = tf.maximum(triplet_loss, 0.)
177 | triplet_loss = tf.reduce_sum(triplet_loss) / batch_size
178 | tf.summary.scalar('appearance_triplet_loss', triplet_loss)
179 |
180 | # Image summaries
181 | anchor_rgb = tf.slice(anchor_img, [0, 0, 0, 0], [-1, -1, -1, 3])
182 | positive_rgb = tf.slice(positive_img, [0, 0, 0, 0], [-1, -1, -1, 3])
183 | negative_rgb = tf.slice(negative_img, [0, 0, 0, 0], [-1, -1, -1, 3])
184 | tb_vis = tf.concat([anchor_rgb, positive_rgb, negative_rgb], axis=2)
185 | with tf.name_scope('triplet_vis'):
186 | tf.summary.image('anchor-pos-neg', tb_vis)
187 |
188 | optimizer = tf.train.AdamOptimizer(lr_app_pretrain, adam_beta1,
189 | adam_beta2)
190 | optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
191 | app_vars = utils.model_vars('appearance_net')[0]
192 | print('\n\n***************************************************')
193 | print('DBG: len(app_vars) = %d' % len(app_vars))
194 | for ii, v in enumerate(app_vars):
195 | print('%03d) %s' % (ii, str(v)))
196 | print('***************************************************\n\n')
197 | app_train_op = optimizer.minimize(triplet_loss, var_list=app_vars)
198 | return tf.estimator.EstimatorSpec(
199 | mode=mode, loss=triplet_loss,
200 | train_op=tf.group(app_train_op, op_increment_step))
201 | elif mode == tf.estimator.ModeKeys.PREDICT:
202 |       imgs = features['anchor_img']
203 |       z_app, _, _ = app_func(imgs)  # keep only the appearance embedding
204 |       embeddings = tf.squeeze(z_app, axis=[1, 2])
205 |       tf.train.init_from_checkpoint(opts.train_dir,
206 |                                     {'appearance_net/': 'appearance_net/'})
207 | return tf.estimator.EstimatorSpec(mode=mode, predictions=embeddings)
208 | else:
209 | raise ValueError('Unsupported mode for the appearance model: ' + mode)
210 |
211 | return model_fn
212 |
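The triplet loss inside model_fn is the standard hinge form max(||a - p||^2 - ||a - n||^2 + margin, 0), averaged over the batch. A minimal numpy check (illustrative only):

    import numpy as np

    margin = 0.1
    anchor = np.array([[0.0, 1.0]])
    positive = np.array([[0.0, 0.9]])  # close to the anchor
    negative = np.array([[1.0, 0.0]])  # far from the anchor
    d_ap = np.sum((anchor - positive) ** 2, axis=1)  # [0.01]
    d_an = np.sum((anchor - negative) ** 2, axis=1)  # [2.0]
    loss = np.maximum(d_ap - d_an + margin, 0.).mean()
    print(loss)  # 0.0: the negative is already farther than the positive by > margin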
213 |
214 | def compute_dist_matrix(imageset_dir, dist_file_path, recompute_dist=False):
215 | if not recompute_dist and osp.exists(dist_file_path):
216 | print('*** Loading distance matrix from %s' % dist_file_path)
217 | with open(dist_file_path, 'rb') as f:
218 | dist_matrix = pickle.load(f)['dist_matrix']
219 | print('loaded a dist_matrix of shape: %s' % str(dist_matrix.shape))
220 | return dist_matrix
221 | else:
222 | images_paths = sorted(glob.glob(osp.join(imageset_dir, '*_reference.png')))
223 | dist_matrix = style_loss.compute_pairwise_style_loss_v2(images_paths)
224 | dist_dict = {'dist_matrix': dist_matrix}
225 | print('Saving distance matrix to %s' % dist_file_path)
226 | with open(dist_file_path, 'wb') as f:
227 | pickle.dump(dist_dict, f)
228 | return dist_matrix
229 |
230 |
231 | def train_appearance(train_dir, imageset_dir, dist_file_path):
232 | batch_size = 8
233 | lr_app_pretrain = 0.001
234 |
235 | trainset_size = len(glob.glob(osp.join(imageset_dir, '*_reference.png')))
236 | resume_step = utils.load_global_step_from_checkpoint_dir(train_dir)
237 | if resume_step != 0:
238 |     tf.logging.warning('DBG: resuming appearance pretraining at %d!' %
239 | resume_step)
240 | model_fn = build_model_fn(batch_size, lr_app_pretrain)
241 | config = tf.estimator.RunConfig(
242 | save_summary_steps=50,
243 | save_checkpoints_steps=500,
244 | keep_checkpoint_max=5,
245 | log_step_count_steps=100)
246 | est = tf.estimator.Estimator(
247 | tf.contrib.estimator.replicate_model_fn(model_fn), train_dir,
248 | config, params={})
249 | # Get input function
250 | input_train_fn = lambda: get_tf_triplet_dataset_iter(
251 | imageset_dir, trainset_size, dist_file_path,
252 | batch_size=batch_size).get_next()
253 | print('Starting pretraining steps...')
254 | est.train(input_train_fn, steps=None, hooks=None) # train indefinitely
255 |
256 |
257 | def main(argv):
258 | if len(argv) > 1:
259 | raise app.UsageError('Too many command-line arguments.')
260 |
261 | train_dir = opts.train_dir
262 | dataset_name = opts.dataset_name
263 | imageset_dir = opts.imageset_dir
264 | output_dir = opts.metadata_output_dir
265 | if not osp.exists(output_dir):
266 | os.makedirs(output_dir)
267 | dist_file_path = osp.join(output_dir, 'dist_%s.pckl' % dataset_name)
268 | compute_dist_matrix(imageset_dir, dist_file_path)
269 | train_appearance(train_dir, imageset_dir, dist_file_path)
270 |
271 | if __name__ == '__main__':
272 | app.run(main)
273 |
--------------------------------------------------------------------------------
/segment_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Generates semantic segmentation masks.
16 | 
17 | Uses an Xception model trained on ADE20K to segment any set of images.
18 | """
19 |
20 | from absl import app
21 | from absl import flags
22 | from PIL import Image
23 | import glob
24 | import matplotlib.pyplot as plt
25 | import numpy as np
26 | import os
27 | import os.path as osp
28 | import shutil
29 | import tensorflow as tf
30 | import utils
31 |
32 |
33 | def get_semantic_color_coding():
34 |   """
35 |   Assigns the 30 (actually 29 distinct) Cityscapes semantic colors to selected
36 |   classes from the 150 ADE20K semantic classes.
37 |   """
38 |   # Below are the 30 Cityscapes colors (one is a duplicate, so only 29 are distinct).
39 | colors = [
40 | (111, 74, 0),
41 | ( 81, 0, 81),
42 | (128, 64,128),
43 | (244, 35,232),
44 | (250,170,160),
45 | (230,150,140),
46 | ( 70, 70, 70),
47 | (102,102,156),
48 | (190,153,153),
49 | (180,165,180),
50 | (150,100,100),
51 | (150,120, 90),
52 | (153,153,153),
53 | # (153,153,153),
54 | (250,170, 30),
55 | (220,220, 0),
56 | (107,142, 35),
57 | (152,251,152),
58 | ( 70,130,180),
59 | (220, 20, 60),
60 | (255, 0, 0),
61 | ( 0, 0,142),
62 | ( 0, 0, 70),
63 | ( 0, 60,100),
64 | ( 0, 0, 90),
65 | ( 0, 0,110),
66 | ( 0, 80,100),
67 | ( 0, 0,230),
68 | (119, 11, 32),
69 | ( 0, 0,142)]
70 | k_num_ade20k_classes = 150
71 |   # Initially, all 150 classes are mapped to a single color (last color idx: -1).
72 |   # Selected classes are then assigned their own colors below.
73 |   # Semantic class labels are 1-based (1 through 150).
74 | semantic_to_color_idx = -1 * np.ones(k_num_ade20k_classes + 1, dtype=int)
75 | semantic_to_color_idx [1] = 0 # wall
76 | semantic_to_color_idx [2] = 1 # building;edifice
77 | semantic_to_color_idx [3] = 2 # sky
78 | semantic_to_color_idx [105] = 3 # fountain
79 | semantic_to_color_idx [27] = 4 # sea
80 | semantic_to_color_idx [60] = 5 # stairway;staircase
81 | semantic_to_color_idx [5] = 6 # tree
82 | semantic_to_color_idx [12] = 7 # sidewalk;pavement
83 | semantic_to_color_idx [4] = 7 # floor;flooring
84 | semantic_to_color_idx [7] = 7 # road;route
85 | semantic_to_color_idx [13] = 8 # people
86 | semantic_to_color_idx [18] = 9 # plant;flora;plant;life
87 | semantic_to_color_idx [17] = 10 # mountain;mount
88 | semantic_to_color_idx [20] = 11 # chair
89 | semantic_to_color_idx [6] = 12 # ceiling
90 | semantic_to_color_idx [22] = 13 # water
91 | semantic_to_color_idx [35] = 14 # rock;stone
92 | semantic_to_color_idx [14] = 15 # earth;ground
93 | semantic_to_color_idx [10] = 16 # grass
94 | semantic_to_color_idx [70] = 17 # bench
95 | semantic_to_color_idx [54] = 18 # stairs;steps
96 | semantic_to_color_idx [101] = 19 # poster
97 | semantic_to_color_idx [77] = 20 # boat
98 | semantic_to_color_idx [85] = 21 # tower
99 | semantic_to_color_idx [23] = 22 # painting;picture
100 |   semantic_to_color_idx [88] = 23 # streetlight;street;lamp
101 | semantic_to_color_idx [43] = 24 # column;pillar
102 | semantic_to_color_idx [9] = 25 # window;windowpane
103 | semantic_to_color_idx [15] = 26 # door;
104 | semantic_to_color_idx [133] = 27 # sculpture
105 |
106 | semantic_to_rgb = np.array(
107 | [colors[col_idx][:] for col_idx in semantic_to_color_idx])
108 | return semantic_to_rgb
109 |
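The returned table acts as a vectorized lookup from ADE20K class indices to RGB colors. A small sketch with a hypothetical 2x2 label mask, assuming numpy and the function above are in scope (_apply_colors below does the equivalent lookup with an explicit per-class loop):

    import numpy as np

    idx_to_color = get_semantic_color_coding()        # shape (151, 3)
    seg = np.array([[2, 3], [5, 1]], dtype=np.uint8)  # hypothetical ADE20K label mask
    seg_rgb = idx_to_color[seg].astype(np.uint8)      # fancy indexing -> shape (2, 2, 3)
    print(seg_rgb[0, 0])  # RGB color assigned to class 2 (building)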
110 |
111 | def _apply_colors(seg_images_path, save_dir, idx_to_color):
112 | for i, img_path in enumerate(seg_images_path):
113 | print('processing img #%05d / %05d: %s' % (i, len(seg_images_path),
114 | osp.split(img_path)[1]))
115 | seg = np.array(Image.open(img_path))
116 | seg_rgb = np.zeros(seg.shape + (3,), dtype=np.uint8)
117 | for col_idx in range(len(idx_to_color)):
118 | if idx_to_color[col_idx][0] != -1:
119 | mask = seg == col_idx
120 | seg_rgb[mask, :] = idx_to_color[col_idx][:]
121 |
122 | parent_dir, filename = osp.split(img_path)
123 | basename, ext = osp.splitext(filename)
124 | out_filename = basename + "_rgb.png"
125 | out_filepath = osp.join(save_dir, out_filename)
126 | # Save rescaled segmentation image
127 | Image.fromarray(seg_rgb).save(out_filepath)
128 |
129 |
130 | # The frozen Xception model only segments 512x512 crops; ideally the full
131 | # image would be segmented instead.
132 | def segment_images(images_path, xception_frozen_graph_path, save_dir,
133 | crop_height=512, crop_width=512):
134 | if not osp.exists(xception_frozen_graph_path):
135 | raise OSError('Xception frozen graph not found at %s' %
136 | xception_frozen_graph_path)
137 | with tf.gfile.GFile(xception_frozen_graph_path, "rb") as f:
138 | graph_def = tf.GraphDef()
139 | graph_def.ParseFromString(f.read())
140 |
141 | with tf.Graph().as_default() as graph:
142 | new_input = tf.placeholder(tf.uint8, [1, crop_height, crop_width, 3],
143 | name="new_input")
144 | tf.import_graph_def(
145 | graph_def,
146 | input_map={"ImageTensor:0": new_input},
147 | return_elements=None,
148 | name="sem_seg",
149 | op_dict=None,
150 | producer_op_list=None
151 | )
152 |
153 | corrupted_dir = osp.join(save_dir, 'corrupted')
154 | if not osp.exists(corrupted_dir):
155 | os.makedirs(corrupted_dir)
156 | with tf.Session(graph=graph) as sess:
157 | for i, img_path in enumerate(images_path):
158 | print('Segmenting image %05d / %05d: %s' % (i + 1, len(images_path),
159 | img_path))
160 | img = np.array(Image.open(img_path))
161 | if len(img.shape) == 2 or img.shape[2] != 3:
162 | print('Warning! corrupted image %s' % img_path)
163 | img_base_path = img_path[:-14] # remove the '_reference.png' suffix
164 | srcs = sorted(glob.glob(img_base_path + '_*'))
165 |         dest = corrupted_dir + '/.'
166 | for src in srcs:
167 | shutil.move(src, dest)
168 | continue
169 | img = utils.get_central_crop(img, crop_height=crop_height,
170 | crop_width=crop_width)
171 | img = np.expand_dims(img, 0) # convert to NHWC format
172 | seg = sess.run("sem_seg/SemanticPredictions:0", feed_dict={
173 | new_input: img})
174 | assert np.max(seg[:]) <= 255, 'segmentation image is not of type uint8!'
175 | seg = np.squeeze(np.uint8(seg)) # convert to uint8 and squeeze to WxH.
176 | parent_dir, filename = osp.split(img_path)
177 | basename, ext = osp.splitext(filename)
178 | basename = basename[:-10] # remove the '_reference' suffix
179 | seg_filename = basename + "_seg.png"
180 | seg_filepath = osp.join(save_dir, seg_filename)
181 | # Save segmentation image
182 | Image.fromarray(seg).save(seg_filepath)
183 |
184 | def segment_and_color_dataset(dataset_dir, xception_frozen_graph_path,
185 | splits=None, resegment_images=True):
186 | if splits is None:
187 | imgs_dirs = [dataset_dir]
188 | else:
189 | imgs_dirs = [osp.join(dataset_dir, split) for split in splits]
190 |
191 | for cur_dir in imgs_dirs:
192 | imgs_file_pattern = osp.join(cur_dir, '*_reference.png')
193 | images_path = sorted(glob.glob(imgs_file_pattern))
194 | if resegment_images:
195 | segment_images(images_path, xception_frozen_graph_path, cur_dir,
196 | crop_height=512, crop_width=512)
197 |
198 | idx_to_col = get_semantic_color_coding()
199 |
200 | for cur_dir in imgs_dirs:
201 | save_dir = cur_dir
202 | seg_file_pattern = osp.join(cur_dir, '*_seg.png')
203 | seg_imgs_paths = sorted(glob.glob(seg_file_pattern))
204 | _apply_colors(seg_imgs_paths, save_dir, idx_to_col)
205 |
--------------------------------------------------------------------------------
/staged_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Neural re-rendering in the wild.
16 |
17 | Implementation of the staged training pipeline.
18 | """
19 |
20 | from options import FLAGS as opts
21 | import losses
22 | import networks
23 | import tensorflow as tf
24 | import utils
25 |
26 |
27 | def create_computation_graph(x_in, x_gt, x_app=None, arch_type='pggan',
28 | use_appearance=True):
29 | """Create the models and the losses.
30 |
31 | Args:
32 | x_in: 4D tensor, batch of conditional input images in NHWC format.
33 |     x_gt: 4D tensor, batch of ground-truth images in NHWC format.
34 | x_app: 4D tensor, batch of input appearance images.
35 |
36 | Returns:
37 |     Dictionary of training ops and loss tensors.
38 | """
39 | # ---------------------------------------------------------------------------
40 | # Build models/networks
41 | # ---------------------------------------------------------------------------
42 |
43 | rerenderer = networks.RenderingModel(arch_type, use_appearance)
44 | app_enc = rerenderer.get_appearance_encoder()
45 | discriminator = networks.MultiScaleDiscriminator(
46 | 'd_model', opts.appearance_nc, num_scales=3, nf=64, n_layers=3,
47 | get_fmaps=False)
48 |
49 | # ---------------------------------------------------------------------------
50 | # Forward pass
51 | # ---------------------------------------------------------------------------
52 |
53 | if opts.use_appearance:
54 | z_app, _, _ = app_enc(x_app)
55 | else:
56 | z_app = None
57 |
58 | y = rerenderer(x_in, z_app)
59 |
60 | # ---------------------------------------------------------------------------
61 | # Losses
62 | # ---------------------------------------------------------------------------
63 |
64 | w_loss_gan = opts.w_loss_gan
65 | w_loss_recon = opts.w_loss_vgg if opts.use_vgg_loss else opts.w_loss_l1
66 |
67 | # compute discriminator logits
68 | disc_real_featmaps = discriminator(x_gt, x_in)
69 | disc_fake_featmaps = discriminator(y, x_in)
70 |
71 | # discriminator loss
72 | loss_d_real = losses.multiscale_discriminator_loss(disc_real_featmaps, True)
73 | loss_d_fake = losses.multiscale_discriminator_loss(disc_fake_featmaps, False)
74 | loss_d = loss_d_real + loss_d_fake
75 |
76 | # generator loss
77 | loss_g_gan = losses.multiscale_discriminator_loss(disc_fake_featmaps, True)
78 | if opts.use_vgg_loss:
79 | vgg_layers = ['conv%d_2' % i for i in range(1, 6)] # conv1 through conv5
80 | vgg_layer_weights = [1./32, 1./16, 1./8, 1./4, 1.]
81 | vgg_loss = losses.PerceptualLoss(y, x_gt, [256, 256, 3], vgg_layers,
82 | vgg_layer_weights) # NOTE: shouldn't hardcode image size!
83 | loss_g_recon = vgg_loss()
84 | else:
85 | loss_g_recon = losses.L1_loss(y, x_gt)
86 | loss_g = w_loss_gan * loss_g_gan + w_loss_recon * loss_g_recon
87 |
88 | # ---------------------------------------------------------------------------
89 | # Tensorboard visualizations
90 | # ---------------------------------------------------------------------------
91 |
92 | x_in_render = tf.slice(x_in, [0, 0, 0, 0], [-1, -1, -1, 3])
93 | if opts.use_semantic:
94 | x_in_semantic = tf.slice(x_in, [0, 0, 0, 4], [-1, -1, -1, 3])
95 | tb_visualization = tf.concat([x_in_render, x_in_semantic, y, x_gt], axis=2)
96 | else:
97 | tb_visualization = tf.concat([x_in_render, y, x_gt], axis=2)
98 | tf.summary.image('rendered-semantic-generated-gt tuple', tb_visualization)
99 |
100 | # Show input appearance images
101 | if opts.use_appearance:
102 | x_app_rgb = tf.slice(x_app, [0, 0, 0, 0], [-1, -1, -1, 3])
103 | x_app_sem = tf.slice(x_app, [0, 0, 0, 7], [-1, -1, -1, -1])
104 | tb_app_visualization = tf.concat([x_app_rgb, x_app_sem], axis=2)
105 | tf.summary.image('input appearance image', tb_app_visualization)
106 |
107 | # Loss summaries
108 | with tf.name_scope('Discriminator_Loss'):
109 | tf.summary.scalar('D_real_loss', loss_d_real)
110 | tf.summary.scalar('D_fake_loss', loss_d_fake)
111 | tf.summary.scalar('D_total_loss', loss_d)
112 | with tf.name_scope('Generator_Loss'):
113 | tf.summary.scalar('G_GAN_loss', w_loss_gan * loss_g_gan)
114 | tf.summary.scalar('G_reconstruction_loss', w_loss_recon * loss_g_recon)
115 | tf.summary.scalar('G_total_loss', loss_g)
116 |
117 | # ---------------------------------------------------------------------------
118 | # Optimizers
119 | # ---------------------------------------------------------------------------
120 |
121 | def get_optimizer(lr, loss, var_list):
122 | optimizer = tf.train.AdamOptimizer(lr, opts.adam_beta1, opts.adam_beta2)
123 | # optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
124 | return optimizer.minimize(loss, var_list=var_list)
125 |
126 | # Training ops.
127 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
128 | with tf.control_dependencies(update_ops):
129 | with tf.variable_scope('optimizers'):
130 | d_vars = utils.model_vars('d_model')[0]
131 | g_vars_all = utils.model_vars('g_model')[0]
132 | train_d = [get_optimizer(opts.d_lr, loss_d, d_vars)]
133 | train_g = [get_optimizer(opts.g_lr, loss_g, g_vars_all)]
134 |
135 | train_app_encoder = []
136 | if opts.train_app_encoder:
137 | lr_app = opts.ez_lr
138 | app_enc_vars = utils.model_vars('appearance_net')[0]
139 | train_app_encoder.append(get_optimizer(lr_app, loss_g, app_enc_vars))
140 |
141 | ema = tf.train.ExponentialMovingAverage(decay=0.999)
142 | with tf.control_dependencies(train_g + train_app_encoder):
143 | inference_vars_all = g_vars_all
144 | if opts.use_appearance:
145 | app_enc_vars = utils.model_vars('appearance_net')[0]
146 | inference_vars_all += app_enc_vars
147 | ema_op = ema.apply(inference_vars_all)
148 |
149 | print('***************************************************')
150 | print('len(g_vars_all) = %d' % len(g_vars_all))
151 | for ii, v in enumerate(g_vars_all):
152 | print('%03d) %s' % (ii, str(v)))
153 | print('-------------------------------------------------------')
154 | print('len(d_vars) = %d' % len(d_vars))
155 | for ii, v in enumerate(d_vars):
156 | print('%03d) %s' % (ii, str(v)))
157 | if opts.train_app_encoder:
158 | print('-------------------------------------------------------')
159 | print('len(app_enc_vars) = %d' % len(app_enc_vars))
160 | for ii, v in enumerate(app_enc_vars):
161 | print('%03d) %s' % (ii, str(v)))
162 | print('***************************************************\n\n')
163 |
164 | return {
165 | 'train_disc_op': tf.group(train_d),
166 | 'train_renderer_op': ema_op,
167 | 'total_loss_d': loss_d,
168 | 'loss_d_real': loss_d_real,
169 | 'loss_d_fake': loss_d_fake,
170 | 'loss_g_gan': w_loss_gan * loss_g_gan,
171 | 'loss_g_recon': w_loss_recon * loss_g_recon,
172 | 'total_loss_g': loss_g}
173 |
--------------------------------------------------------------------------------
/style_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from PIL import Image
16 | from options import FLAGS as opts
17 | import data
18 | import layers
19 | import numpy as np
20 | import tensorflow as tf
21 | import utils
22 | import vgg16
23 |
24 |
25 | def gram_matrix(layer):
26 |   """Computes the Gram matrix for a batch of activations of a single VGG layer.
27 |   Input:
28 |     layer: a batch of VGG activations for a single conv layer.
29 |   Returns:
30 |     gram: [batch_sz x num_channels x num_channels], a batch of Gram matrices.
31 |   """
32 | batch_size, height, width, num_channels = layer.get_shape().as_list()
33 | features = tf.reshape(layer, [batch_size, height * width, num_channels])
34 | num_elements = tf.constant(num_channels * height * width, tf.float32)
35 | gram = tf.matmul(features, features, adjoint_a=True) / num_elements
36 | return gram
37 |
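A numpy analogue of the computation above, for intuition about the shapes (an illustrative sketch; the TF version operates on batches of VGG activations):

    import numpy as np

    batch, h, w, c = 1, 2, 2, 3
    layer = np.arange(batch * h * w * c, dtype=np.float32).reshape(batch, h, w, c)
    features = layer.reshape(batch, h * w, c)                         # (1, 4, 3)
    gram = np.matmul(features.transpose(0, 2, 1), features) / (c * h * w)
    print(gram.shape)  # (1, 3, 3): [batch, num_channels, num_channels]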
38 |
39 | def compute_gram_matrices(
40 | images, vgg_layers=['conv1_2', 'conv2_2', 'conv3_2', 'conv4_2', 'conv5_2']):
41 | """Computes the gram matrix representation of a batch of images"""
42 | vgg_net = vgg16.Vgg16(opts.vgg16_path)
43 | vgg_acts = vgg_net.get_vgg_activations(images, vgg_layers)
44 | grams = [gram_matrix(layer) for layer in vgg_acts]
45 | return grams
46 |
47 |
48 | def compute_pairwise_style_loss_v2(image_paths_list):
49 | grams_all = [None] * len(image_paths_list)
50 | crop_height, crop_width = opts.train_resolution, opts.train_resolution
51 | img_var = tf.placeholder(tf.float32, shape=[1, crop_height, crop_width, 3])
52 | vgg_layers = ['conv%d_2' % i for i in range(1, 6)] # conv1 through conv5
53 | grams_ops = compute_gram_matrices(img_var, vgg_layers)
54 | with tf.Session() as sess:
55 | for ii, img_path in enumerate(image_paths_list):
56 | print('Computing gram matrices for image #%d' % (ii + 1))
57 | img = np.array(Image.open(img_path), dtype=np.float32)
58 | img = img * 2. / 255. - 1 # normalize image
59 | img = utils.get_central_crop(img, crop_height, crop_width)
60 | img = np.expand_dims(img, axis=0)
61 | grams_all[ii] = sess.run(grams_ops, feed_dict={img_var: img})
62 | print('Number of images = %d' % len(grams_all))
63 | print('Gram matrices per image:')
64 | for i in range(len(grams_all[0])):
65 | print('gram_matrix[%d].shape = %s' % (i, grams_all[0][i].shape))
66 | n_imgs = len(grams_all)
67 | dist_matrix = np.zeros((n_imgs, n_imgs))
68 | for i in range(n_imgs):
69 | print('Computing distances for image #%d' % i)
70 | for j in range(i + 1, n_imgs):
71 | loss_style = 0
72 | # Compute loss using all gram matrices from all layers
73 | for gram_i, gram_j in zip(grams_all[i], grams_all[j]):
74 | loss_style += np.mean((gram_i - gram_j) ** 2, axis=(1, 2))
75 | dist_matrix[i][j] = dist_matrix[j][i] = loss_style
76 |
77 | return dist_matrix
78 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for GANs.
16 |
17 | Basic functions such as generating sample grid, exporting to PNG, etc...
18 | """
19 |
20 | import functools
21 | import numpy as np
22 | import os.path
23 | import tensorflow as tf
24 | import time
25 |
26 |
27 | def crop_to_multiple(img, size_multiple=64):
28 | """ Crops the image so that its dimensions are multiples of size_multiple."""
29 | new_width = (img.shape[1] // size_multiple) * size_multiple
30 | new_height = (img.shape[0] // size_multiple) * size_multiple
31 | offset_x = (img.shape[1] - new_width) // 2
32 | offset_y = (img.shape[0] - new_height) // 2
33 | return img[offset_y:offset_y + new_height, offset_x:offset_x + new_width, :]
34 |
35 |
36 | def get_central_crop(img, crop_height=512, crop_width=512):
37 | if len(img.shape) == 2:
38 | img = np.expand_dims(img, axis=2)
39 | assert len(img.shape) == 3, ('input image should be either a 2D or 3D matrix,'
40 | ' but input was of shape %s' % str(img.shape))
41 | height, width, _ = img.shape
42 | assert height >= crop_height and width >= crop_width, ('input image cannot '
43 | 'be smaller than the requested crop size')
44 | st_y = (height - crop_height) // 2
45 | st_x = (width - crop_width) // 2
46 | return np.squeeze(img[st_y : st_y + crop_height, st_x : st_x + crop_width, :])
47 |
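Quick shape checks for the two crop helpers above (an illustrative sketch, assuming the functions above are in scope):

    import numpy as np

    img = np.zeros((600, 800, 3), dtype=np.float32)
    print(crop_to_multiple(img, 64).shape)        # (576, 768, 3)
    print(get_central_crop(img, 512, 512).shape)  # (512, 512, 3)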
48 |
49 | def load_global_step_from_checkpoint_dir(checkpoint_dir):
50 | """Loads the global step from the checkpoint directory.
51 |
52 | Args:
53 | checkpoint_dir: string, path to the checkpoint directory.
54 |
55 | Returns:
56 | int, the global step of the latest checkpoint or 0 if none was found.
57 | """
58 | try:
59 | checkpoint_reader = tf.train.NewCheckpointReader(
60 | tf.train.latest_checkpoint(checkpoint_dir))
61 | return checkpoint_reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
62 | except:
63 | return 0
64 |
65 |
66 | def model_vars(prefix):
67 | """Return trainable variables matching a prefix.
68 |
69 | Args:
70 | prefix: string, the prefix variable names must match.
71 |
72 | Returns:
73 | a tuple (match, others) of TF variables, 'match' contains the matched
74 | variables and 'others' contains the remaining variables.
75 | """
76 | match, no_match = [], []
77 | for x in tf.trainable_variables():
78 | if x.name.startswith(prefix):
79 | match.append(x)
80 | else:
81 | no_match.append(x)
82 | return match, no_match
83 |
84 |
85 | def to_png(x):
86 | """Convert a 3D tensor to png.
87 |
88 | Args:
89 | x: Tensor, 01C formatted input image.
90 |
91 | Returns:
92 | Tensor, 1D string representing the image in png format.
93 | """
94 | with tf.Graph().as_default():
95 | with tf.Session() as sess_temp:
96 | x = tf.constant(x)
97 | y = tf.image.encode_png(
98 | tf.cast(
99 | tf.clip_by_value(tf.round(127.5 + 127.5 * x), 0, 255), tf.uint8),
100 | compression=9)
101 | return sess_temp.run(y)
102 |
103 |
104 | def images_to_grid(images):
105 | """Converts a grid of images (5D tensor) to a single image.
106 |
107 | Args:
108 | images: 5D tensor (count_y, count_x, height, width, colors), grid of images.
109 |
110 | Returns:
111 | a 3D tensor image of shape (count_y * height, count_x * width, colors).
112 | """
113 | ny, nx, h, w, c = images.shape
114 | images = images.transpose(0, 2, 1, 3, 4)
115 | images = images.reshape([ny * h, nx * w, c])
116 | return images
117 |
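A shape walk-through for images_to_grid (illustrative): a 2x3 grid of 4x5 RGB images becomes a single 8x15 image.

    import numpy as np

    grid = np.zeros((2, 3, 4, 5, 3))   # (count_y, count_x, height, width, colors)
    print(images_to_grid(grid).shape)  # (8, 15, 3) == (2*4, 3*5, 3)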
118 |
119 | def save_images(image, output_dir, cur_nimg):
120 | """Saves images to disk.
121 |
122 |   Saves a file called 'name.png' containing the latest samples from the
123 |   generator, and a file called 'name_123.png' where 123 is cur_nimg >> 10,
124 |   i.e. the number of kimg (1024-image units) seen by training.
125 |
126 | Args:
127 | image: 3D numpy array (height, width, colors), the image to save.
128 | output_dir: string, the directory where to save the image.
129 | cur_nimg: int, current number of images seen by training.
130 |
131 | Returns:
132 | None
133 | """
134 | for name in ('name.png', 'name_%06d.png' % (cur_nimg >> 10)):
135 | with tf.gfile.Open(os.path.join(output_dir, name), 'wb') as f:
136 | f.write(image)
137 |
138 |
139 | class HookReport(tf.train.SessionRunHook):
140 | """Custom reporting hook.
141 |
142 | Register your tensor scalars with HookReport.log_tensor(my_tensor, 'my_name').
143 |   This hook will report their average values over the report period argument
144 |   provided to the constructor. The values are printed in the order the tensors
145 |   were registered.
146 |
147 | Attributes:
148 | step: int, the current global step.
149 | active: bool, whether logging is active or disabled.
150 | """
151 | _REPORT_KEY = 'report'
152 | _TENSOR_NAMES = {}
153 |
154 | def __init__(self, period, batch_size):
155 | self.step = 0
156 | self.active = True
157 | self._period = period // batch_size
158 | self._batch_size = batch_size
159 | self._sums = np.array([])
160 | self._count = 0
161 | self._nimgs_per_cycle = 0
162 | self._step_ratio = 0
163 | self._start = time.time()
164 | self._nimgs = 0
165 | self._batch_size = batch_size
166 |
167 | def disable(self):
168 | parent = self
169 |
170 | class Disabler(object):
171 |
172 | def __enter__(self):
173 | parent.active = False
174 | return parent
175 |
176 | def __exit__(self, exc_type, exc_val, exc_tb):
177 | parent.active = True
178 |
179 | return Disabler()
180 |
181 | def begin(self):
182 | self.active = True
183 | self._count = 0
184 | self._nimgs_per_cycle = 0
185 | self._start = time.time()
186 |
187 | def before_run(self, run_context):
188 | if not self.active:
189 | return
190 | del run_context
191 | fetches = tf.get_collection(self._REPORT_KEY)
192 | return tf.train.SessionRunArgs(fetches)
193 |
194 | def after_run(self, run_context, run_values):
195 | if not self.active:
196 | return
197 | del run_context
198 | results = run_values.results
199 | # Note: sometimes the returned step is incorrect (off by one) for some
200 | # unknown reason.
201 | self.step = results[-1] + 1
202 | self._count += 1
203 | self._nimgs_per_cycle += self._batch_size
204 | self._nimgs += self._batch_size
205 |
206 | if not self._sums.size:
207 | self._sums = np.array(results[:-1], 'd')
208 | else:
209 | self._sums += np.array(results[:-1], 'd')
210 |
211 | if self.step // self._period != self._step_ratio:
212 | fetches = tf.get_collection(self._REPORT_KEY)[:-1]
213 | stats = ' '.join('%s=% .2f' % (self._TENSOR_NAMES[tensor],
214 | value / self._count)
215 | for tensor, value in zip(fetches, self._sums))
216 | stop = time.time()
217 | tf.logging.info('step=%d, kimg=%d %s [%.2f img/s]' %
218 | (self.step, ((self.step * self._batch_size) >> 10),
219 | stats, self._nimgs_per_cycle / (stop - self._start)))
220 | self._step_ratio = self.step // self._period
221 | self._start = stop
222 | self._sums *= 0
223 | self._count = 0
224 | self._nimgs_per_cycle = 0
225 |
226 | def end(self, session=None):
227 | del session
228 |
229 | @classmethod
230 | def log_tensor(cls, tensor, name):
231 | """Adds a tensor to be reported by the hook.
232 |
233 | Args:
234 | tensor: `tensor scalar`, a value to report.
235 | name: string, the name to give the value in the report.
236 |
237 | Returns:
238 | None.
239 | """
240 | cls._TENSOR_NAMES[tensor] = name
241 | tf.add_to_collection(cls._REPORT_KEY, tensor)
242 |
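A hedged usage sketch for HookReport, based only on the constructor and log_tensor API shown above; the my_loss tensor and the estimator wiring are hypothetical:

    import tensorflow as tf

    my_loss = tf.constant(0.5)                 # stand-in for a real scalar loss tensor
    HookReport.log_tensor(my_loss, 'my_loss')  # register it for periodic reporting
    report_hook = HookReport(period=1000, batch_size=8)
    # The hook would then be passed to an Estimator or MonitoredSession, e.g.:
    #   estimator.train(input_fn, hooks=[report_hook])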
--------------------------------------------------------------------------------