├── LICENSE ├── README.md ├── bilinear_sampler.py ├── blendingnetwork.py ├── make_data_txt.py ├── test.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Blending for Free-Viewpoint Image-Based Rendering 2 | [Peter Hedman](http://www.phogzone.com), 3 | [Julien Philip](https://www-sop.inria.fr/members/Julien.Philip), 4 | [True Price](https://www.cs.unc.edu/~jtprice), 5 | [Jan-Michael Frahm](http://frahm.web.unc.edu), 6 | [George Drettakis](https://www-sop.inria.fr/members/George.Drettakis), and 7 | [Gabriel Brostow](http://www0.cs.ucl.ac.uk/staff/G.Brostow). *SIGGRAPH Asia 2018*. 8 | 9 | http://visual.cs.ucl.ac.uk/pubs/deepblending 10 | 11 | Teaser video 14 | 15 | 16 | This repository contains the training and test code for our blending network. 17 | 18 | ## New! Source code for rendering and geometry refinement. 19 | 20 | You can find a full source code release [here](https://gitlab.inria.fr/sibr/projects/inside_out_deep_blending). 21 | 22 | ## Usage 23 | 24 | Only tested on [Ubuntu 16.04](http://releases.ubuntu.com/16.04/) and [NVIDIA](http://www.nvidia.com) GPUs. 25 | 26 | ### Prerequisites 27 | 28 | Download and install [Python 3](https://www.python.org/download/releases/3.0/), [Pip](https://pip.pypa.io/en/stable/installing/), and [Tensorflow-GPU](https://www.tensorflow.org/install/gpu). 29 | 30 | Required Python packages: *Numpy*, *Scipy*, and *PIL*. 31 | 32 | Download the [training data](https://repo-sam.inria.fr/fungraph/deep-blending/data/DeepBlendingTrainingData.zip), and the [test data](https://repo-sam.inria.fr/fungraph/deep-blending/data/DeepBlendingTestDataAndResults.zip). Unzip these datasets. 33 | 34 | ### Training 35 | 36 | For training, we need to set a few more things up. 37 | 38 | 1) Download pretrained VGG weights 39 | ``` 40 | cd [DEEP_BLENDING_CODE_FOLDER] 41 | wget http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz 42 | tar -xvf vgg_16_2016_08_28.tar.gz 43 | rm vgg_16_2016_08_28.tar.gz 44 | ``` 45 | 46 | 2) Create an input txt file listing all the training images. 
47 | ``` 48 | cd [DEEP_BLENDING_CODE_FOLDER] 49 | DATASET_DIR=[PATH_TO_UNZIPPED_TRAINING_DATA] 50 | for scene in `ls $DATASET_DIR | grep -v txt` ; 51 | do 52 | python3 make_data_txt.py $DATASET_DIR/$scene > ${scene}_training.txt ; 53 | cat ${scene}_training.txt >> training_all_ordered.txt ; 54 | rm ${scene}_training.txt ; 55 | done ; 56 | shuf training_all_ordered.txt > $DATASET_DIR/training.txt ; 57 | rm training_all_ordered.txt ; 58 | ``` 59 | 60 | 3) Create an input txt file listing all the validation images. 61 | ``` 62 | cd [DEEP_BLENDING_CODE_FOLDER] 63 | DATASET_DIR=[PATH_TO_UNZIPPED_TRAINING_DATA] 64 | for scene in `ls $DATASET_DIR | grep -v txt` ; 65 | do 66 | python3 make_data_txt.py $DATASET_DIR/$scene validation > ${scene}_validation.txt ; 67 | cat ${scene}_validation.txt >> validation_all_ordered.txt ; 68 | rm ${scene}_validation.txt ; 69 | done ; 70 | shuf validation_all_ordered.txt > $DATASET_DIR/validation.txt ; 71 | rm validation_all_ordered.txt ; 72 | ``` 73 | 74 | Now we're ready to train the network. This command trains the network with the parameters used in the paper: 75 | ``` 76 | cd [DEEP_BLENDING_CODE_FOLDER] 77 | python3 train.py [PATH_TO_UNZIPPED_TRAINING_DATA] [NETWORK_OUTPUT_FOLDER] --training_file=[PATH_TO_UNZIPPED_TRAINING_DATA]/training.txt --validation_file=[PATH_TO_UNZIPPED_TRAINING_DATA]/validation.txt 78 | ``` 79 | 80 | However, you can also use the following command-line parameters to train a different version of the network: 81 | 82 | `--loss_function`, 83 | determines the image loss to be used for training (*L1*, *VGG*, or *VGG_AND_L1*). Defaults to *VGG_AND_L1*. 84 | 85 | `--direct_regression`, 86 | directly regresses the output image instead of predicting blend weights. Off by default. 87 | 88 | `--no_temporal_loss`, 89 | trains the network without a temporal loss. Off by default. 90 | 91 | `--no_textured_mesh`, 92 | disables the input layer from the textured mesh. Off by default. 93 | 94 | `--num_input_mosaics`, 95 | number of input mosaic layers to use. Defaults to 4. 96 | 97 | `--temporal_alpha`, 98 | relative strength of the temporal loss. Defaults to 0.33. 99 | 100 | `--debug`, 101 | debug mode for training, shows more intermediate outputs in tensorboard. Off by default. 102 | 103 | `--num_batches`, 104 | training duration (in terms of number of minibatches). Defaults to 256000. 105 | 106 | `--batch_size`, 107 | batch size to be used for training. Defaults to 8. 108 | 109 | `--crop`, 110 | crop size for data augmentation. Defaults to 256. 111 | 112 | 113 | ### Testing 114 | 115 | Run the following command: 116 | ``` 117 | cd [DEEP_BLENDING_CODE_FOLDER] 118 | python3 test.py [PATH_TO_TEST_SCENE] [OUTPUT_DIRECTORY] --model_path=[NETWORK_OUTPUT_FOLDER] 119 | ``` 120 | 121 | **IMPORTANT** If you trained a network using custom command-line parameters, make sure that they match when you run the network in test mode! 122 | 123 | For testing on new scenes, you need to create a txt file which lists all the inputs: 124 | ``` 125 | cd [DEEP_BLENDING_CODE_FOLDER] 126 | TEST_SCENE_DIR=[PATH_TO_TEST_SCENE] 127 | python3 make_data_txt.py $TEST_SCENE_DIR/testdump test > $TEST_SCENE_DIR/test.txt 128 | ``` 129 | -------------------------------------------------------------------------------- /bilinear_sampler.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | # Copyright 2017 Modifications Clement Godard. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | 18 | import tensorflow as tf 19 | 20 | def bilinear_sampler(input_images, y_offset, x_offset, wrap_mode='border', name='bilinear_sampler', **kwargs): 21 | def _repeat(x, n_repeats): 22 | with tf.variable_scope('_repeat'): 23 | rep = tf.tile(tf.expand_dims(x, 1), [1, n_repeats]) 24 | return tf.reshape(rep, [-1]) 25 | 26 | def _interpolate(im, x, y): 27 | with tf.variable_scope('_interpolate'): 28 | 29 | # handle both texture border types 30 | _edge_size = 0 31 | if _wrap_mode == 'border': 32 | _edge_size = 1 33 | im = tf.pad(im, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='CONSTANT') 34 | x = x + _edge_size 35 | y = y + _edge_size 36 | elif _wrap_mode == 'edge': 37 | _edge_size = 0 38 | else: 39 | return None 40 | 41 | x = tf.clip_by_value(x, 0.0, _width_f - 1 + 2 * _edge_size) 42 | y = tf.clip_by_value(y, 0.0, _height_f - 1 + 2 * _edge_size) 43 | 44 | x0_f = tf.floor(x) 45 | y0_f = tf.floor(y) 46 | x1_f = x0_f + 1 47 | y1_f = y0_f + 1 48 | x0 = tf.cast(x0_f, tf.int32) 49 | y0 = tf.cast(y0_f, tf.int32) 50 | x1 = tf.cast(tf.minimum(x1_f, _width_f - 1 + 2 * _edge_size), tf.int32) 51 | y1 = tf.cast(tf.minimum(y1_f, _height_f - 1 + 2 * _edge_size), tf.int32) 52 | 53 | dim2 = (_width + 2 * _edge_size) 54 | dim1 = (_width + 2 * _edge_size) * (_height + 2 * _edge_size) 55 | base = _repeat(tf.range(_num_batch) * dim1, _height * _width) 56 | base_y0 = base + y0 * dim2 57 | base_y1 = base + y1 * dim2 58 | idx_a = base_y0 + x0 59 | idx_b = base_y1 + x0 60 | idx_c = base_y0 + x1 61 | idx_d = base_y1 + x1 62 | 63 | im_flat = tf.reshape(im, tf.stack([-1, _num_channels])) 64 | 65 | Ia = tf.gather(im_flat, idx_a) 66 | Ib = tf.gather(im_flat, idx_b) 67 | Ic = tf.gather(im_flat, idx_c) 68 | Id = tf.gather(im_flat, idx_d) 69 | 70 | wa = tf.expand_dims(((x1_f-x) * (y1_f-y)), 1) 71 | wb = tf.expand_dims(((x1_f-x) * (y-y0_f)), 1) 72 | wc = tf.expand_dims(((x-x0_f) * (y1_f-y)), 1) 73 | wd = tf.expand_dims(((x-x0_f) * (y-y0_f)), 1) 74 | output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id]) 75 | return output 76 | 77 | def _transform(input_images, y_offset, x_offset): 78 | with tf.variable_scope('transform'): 79 | # grid of (x_t, y_t, 1), eq (1) in ref [1] 80 | x_t, y_t = tf.meshgrid(tf.linspace(0.0, _width_f - 1.0, _width), 81 | tf.linspace(0.0 , _height_f - 1.0 , _height)) 82 | 83 | x_t_flat = tf.reshape(x_t, (1, -1)) 84 | y_t_flat = tf.reshape(y_t, (1, -1)) 85 | 86 | x_t_flat = tf.tile(x_t_flat, tf.stack([_num_batch, 1])) 87 | y_t_flat = tf.tile(y_t_flat, tf.stack([_num_batch, 1])) 88 | 89 | x_t_flat = tf.reshape(x_t_flat, [-1]) 90 | y_t_flat = tf.reshape(y_t_flat, [-1]) 91 | 92 | if y_offset != None: 93 | y_t_flat = y_t_flat + tf.reshape(y_offset, [-1]) 94 | if x_offset != None: 95 | x_t_flat = x_t_flat + tf.reshape(x_offset, [-1]) 96 | 97 | input_transformed = _interpolate(input_images, x_t_flat, y_t_flat) 98 | 99 | output = tf.reshape( 100 | 
input_transformed, tf.stack([_num_batch, _height, _width, _num_channels])) 101 | return output 102 | 103 | with tf.variable_scope(name): 104 | _num_batch = tf.shape(input_images)[0] 105 | _height = tf.shape(input_images)[1] 106 | _width = tf.shape(input_images)[2] 107 | _num_channels = tf.shape(input_images)[3] 108 | 109 | _height_f = tf.cast(_height, tf.float32) 110 | _width_f = tf.cast(_width, tf.float32) 111 | 112 | _wrap_mode = wrap_mode 113 | 114 | output = _transform(input_images, y_offset, x_offset) 115 | return output 116 | 117 | def bilinear_sampler_2d(input_images, offsets, wrap_mode = 'border'): 118 | return bilinear_sampler(input_images, -offsets[:,:,:,1], offsets[:,:,:,0], wrap_mode) 119 | -------------------------------------------------------------------------------- /blendingnetwork.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Peter Hedman. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | ###################################### 17 | # Building our Deep Blending network # 18 | ###################################### 19 | 20 | import tensorflow as tf 21 | 22 | def blending_network(all_inputs, images, num_input_layers, use_global_mesh, direct_regression, batch_size): 23 | # We're using the NCHW data layout -- tensors should be concatenated 24 | # along the 1st (channels) dimension. 
25 | concat_dimension = 1 26 | 27 | # Helper functions for the convolutional layers 28 | def conv(x, num_out_layers, kernel_size, stride, activation_fn=tf.nn.relu): 29 | c = tf.contrib.layers.conv2d( 30 | x, num_out_layers, kernel_size, stride, 'SAME', data_format='NCHW', 31 | activation_fn=activation_fn) 32 | return c 33 | def upsample_nn(x, num_in_layers): 34 | x_up = tf.reshape( 35 | tf.stack([x, x], axis=3), 36 | [tf.shape(x)[0], num_in_layers, 2 * tf.shape(x)[2], tf.shape(x)[3]]) 37 | return tf.reshape( 38 | tf.concat([x_up[...,tf.newaxis], x_up[...,tf.newaxis]], axis=-1), 39 | [tf.shape(x)[0], num_in_layers, 2 * tf.shape(x)[2], 2 * tf.shape(x)[3]]) 40 | 41 | def upconv(x, num_out_layers, kernel_size): 42 | upsample = upsample_nn(x, x.shape[1]) 43 | conv1 = conv(upsample, num_out_layers, kernel_size, 1) 44 | return conv1 45 | 46 | def downconv(x, num_out_layers, kernel_size, stride = 2): 47 | conv1 = conv(x, num_out_layers, kernel_size, 1) 48 | conv2 = conv(conv1, num_out_layers, kernel_size, stride) 49 | return conv2 50 | 51 | # Convert the inputs to NCHW 52 | all_inputs = tf.transpose(all_inputs, [0, 3, 1, 2]) 53 | 54 | # The deep blending U-net 55 | conv1 = conv(all_inputs, 32, 3, 1) 56 | conv2 = downconv(conv1, 48, 3, 2) 57 | conv3 = downconv(conv2, 64, 3, 2) 58 | conv4 = downconv(conv3, 96, 3, 2) 59 | conv5 = downconv(conv4, 128, 3, 2) 60 | conv6 = upconv(conv5, 96, 3) 61 | concat1 = tf.concat([conv6, conv4], concat_dimension) 62 | conv7 = upconv(concat1, 64, 3) 63 | concat2 = tf.concat([conv7, conv3], concat_dimension) 64 | conv8 = upconv(concat2, 48, 3) 65 | concat3 = tf.concat([conv8, conv2], concat_dimension) 66 | conv9 = upconv(concat3, 32, 3) 67 | features = tf.concat([conv9, conv1], concat_dimension) 68 | 69 | # obtain the final output 70 | num_images = num_input_layers 71 | if use_global_mesh: 72 | num_images = num_images + 1 73 | 74 | img_list = tf.split(images, num_images, 3) 75 | if direct_regression: 76 | out_image = 0.5 * (1.0 + conv(features, 3, 3, 1, tf.nn.tanh)) 77 | else: 78 | out = conv(features, num_images, 3, 1, None) 79 | softmax = tf.nn.softmax(out, 1) 80 | 81 | out = tf.reshape( 82 | all_inputs, [batch_size, num_images, 3, tf.shape(out)[2], tf.shape(out)[3]]) 83 | out *= softmax[:,:,tf.newaxis] 84 | out_image = tf.reduce_sum(out, axis=1) 85 | 86 | # Convert back to NHWC 87 | out_image = tf.transpose(out_image, [0, 2, 3, 1]) 88 | 89 | return out_image, img_list -------------------------------------------------------------------------------- /make_data_txt.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright 2019 Peter Hedman. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | import sys 18 | import os 19 | from PIL import Image 20 | 21 | if len(sys.argv) != 2 and len(sys.argv) != 3: 22 | print("Usage: ./" + sys.argv[0] + " [PATH_TO_DATASET_FOLDER] (validation/test)") 23 | sys.exit(1) 24 | 25 | is_validation = len(sys.argv) > 2 and sys.argv[2] == "validation" 26 | is_test = len(sys.argv) > 2 and sys.argv[2] == "test" 27 | 28 | dataset_path = sys.argv[1] 29 | dataset_dirname = os.path.split(dataset_path)[-1] 30 | 31 | files = os.listdir(dataset_path) 32 | if len(files) <= 0: 33 | print("ERROR: No files found in " + dataset_path) 34 | sys.exit(1) 35 | 36 | indices = sorted(list(set([f.split("_")[0] for f in files]))) 37 | 38 | first_image_files = sorted([f for f in files if str(indices[0]) == f[0:len(indices[0])]]) 39 | first_image_index = first_image_files[0].split("_")[0] 40 | 41 | if not is_test: 42 | reference_file = [f for f in first_image_files if "reference" in f][0] 43 | reference_suffix = "_reference" 44 | reference_extension = reference_file[-4:] 45 | 46 | probe_filename = [f for f in files if "global" in f][0] 47 | has_path_samples = probe_filename.find("path") >= 0 48 | random_suffixes = [""] 49 | path_suffixes = [""] 50 | if has_path_samples: 51 | num_filename_chunks = len(probe_filename.split("_")) 52 | if num_filename_chunks != 7: 53 | print("ERROR: Expected file name to contain seven chunks, found: " + str(num_filename_chunks)) 54 | print("(" + probe_filename + ")") 55 | sys.exit(0) 56 | 57 | flow_file = [f for f in first_image_files if "temporal_flow" in f][0] 58 | flow_suffix = "_temporal_flow" 59 | flow_extension = flow_file[-4:] 60 | 61 | # Build suffixes of the form _sample_000N 62 | random_sample_files = [f for f in first_image_files if "global" in f and "path_0000" in f] 63 | random_suffixes = sorted(list(set(["_sample_" + f.split("_")[-3] for f in random_sample_files]))) 64 | 65 | # Build suffixes of the form _path_000N 66 | path_sample_files = [f for f in first_image_files if "global" in f and "sample_0000" in f] 67 | path_suffixes = sorted(list(set(["_path_" + f.split("_")[-1][0:-4] for f in path_sample_files]))) 68 | 69 | # Find the remaining suffixes: _local_layer_N and _global_colors 70 | first_sample_files = [f for f in first_image_files if random_suffixes[0] in f and path_suffixes[0] in f] 71 | type_files = [f for f in first_sample_files if "reference" not in f and "temporal_flow" not in f] 72 | type_suffixes = ["_" + "_".join(f.split("_")[1:-4]) for f in type_files] 73 | type_extensions = [f.split("_")[-1][-4:] for f in type_files] 74 | else: 75 | type_files = [f for f in first_image_files if "reference" not in f] 76 | type_suffixes = ["_" + "_".join(f.split("_")[1:])[:-4] for f in type_files] 77 | type_extensions = [f.split("_")[-1][-4:] for f in type_files] 78 | 79 | def make_path(i, s): 80 | return dataset_dirname + "/" + i + s 81 | 82 | begin = 0 83 | end = int(len(indices) * 0.9) 84 | if is_validation: 85 | begin = end 86 | end = len(indices) 87 | elif is_test: 88 | begin = 0 89 | end = len(indices) 90 | 91 | for i in range(begin, end): 92 | for rs in random_suffixes: 93 | ii = indices[i] 94 | 95 | probe_image_path = dataset_path + "/" + ii + type_suffixes[0] + rs + path_suffixes[0] + type_extensions[0] 96 | probe_image = Image.open(probe_image_path) 97 | 98 | # Ignore images that are too low res, or can't be loaded 99 | if probe_image.size[0] < 256 or probe_image.size[1] < 256: 100 | continue 101 | 102 | line = 
str(probe_image.size[0]) + " " + str(probe_image.size[1]) 103 | if not is_test: 104 | reference_path = make_path(ii, reference_suffix) + reference_extension 105 | line += " " + reference_path 106 | 107 | if not is_test and has_path_samples: 108 | flow_path = make_path(ii, flow_suffix) + rs + flow_extension 109 | line += " " + flow_path 110 | 111 | for pi in range(len(path_suffixes)): 112 | for ti in range(len(type_suffixes)): 113 | line += " " + make_path(ii, type_suffixes[ti]) + rs + path_suffixes[pi] + type_extensions[ti] 114 | 115 | print(line) 116 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright 2019 Peter Hedman. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | import argparse 18 | import tensorflow as tf 19 | import numpy as np 20 | import scipy 21 | import os 22 | import sys 23 | import math 24 | 25 | from blendingnetwork import * 26 | 27 | 28 | 29 | 30 | #################################### 31 | # Parameters and argument parsing # 32 | ################################### 33 | 34 | parser = argparse.ArgumentParser( 35 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 36 | 37 | parser.add_argument("data_path", type=str, 38 | help="relative directory for all data") 39 | parser.add_argument("output_path", type=str, 40 | help="base-level output directory for results") 41 | 42 | parser.add_argument("--model_path", type=str, default="model", 43 | help="folder containing the tf network snapshots") 44 | 45 | # Network training parameters 46 | parser.add_argument("--direct_regression", default=False, action="store_true", 47 | help="Directly regress the output image instead of predicting blend weights") 48 | parser.add_argument("--no_textured_mesh", default=False, action="store_true", 49 | help="Disable the input layer from the textured mesh") 50 | parser.add_argument("--num_input_mosaics", type=int, default=4, 51 | help="Number of input mosaic layers to use") 52 | 53 | args = parser.parse_args() 54 | 55 | direct_regression = args.direct_regression 56 | use_global_mesh = not args.no_textured_mesh 57 | num_input_layers = args.num_input_mosaics 58 | batch_size = 1 59 | 60 | print("================================================") 61 | print("Running deep blending network in TEST mode.") 62 | if direct_regression: 63 | print("Directly predicting the output image.") 64 | else: 65 | print("Predicting blend weights.") 66 | print("Inputs: " + str(num_input_layers) + str(" mosaics") + \ 67 | (str(" and a render of the textured mesh.") if use_global_mesh else str("."))) 68 | print("================================================") 69 | 70 | 71 | 72 | 73 | ########################## 74 | # Data loader definition # 75 | ########################## 76 | 77 | class TestDataLoader(object): 78 | """Test 
mode data loader""" 79 | def __init__(self, filenames_file, data_folder): 80 | f = open(filenames_file, 'r') 81 | line = f.readline() 82 | f.close() 83 | self.init_width = int(line.split()[0]) 84 | self.init_height = int(line.split()[1]) 85 | self.current_images = None 86 | self.data_folder = None 87 | 88 | with tf.variable_scope("inputloader"): 89 | self.data_folder = tf.constant(data_folder) 90 | 91 | input_queue = tf.train.string_input_producer([filenames_file], shuffle=False) 92 | line_reader = tf.TextLineReader() 93 | _, line = line_reader.read(input_queue) 94 | split_line = tf.string_split([line]).values 95 | 96 | offset = 0 97 | current_width = tf.string_to_number(split_line[offset], tf.int32) 98 | 99 | offset += 1 100 | current_height = tf.string_to_number(split_line[offset], tf.int32) 101 | 102 | offset += 1 103 | frames_color = [] 104 | 105 | # First populate the CNN inputs with data from the global mesh 106 | global_color = self.read_jpg(split_line[offset]) 107 | global_color.set_shape([self.init_height, self.init_width, 3]) 108 | 109 | if use_global_mesh: 110 | frames_color.append(global_color) 111 | 112 | offset += 1 113 | # Then incorporate information from each per-view mosaic 114 | for i in range(num_input_layers): 115 | colors = self.read_jpg(split_line[offset + i]) 116 | colors.set_shape([self.init_height, self.init_width, 3]) 117 | frames_color.append(colors) 118 | 119 | images = tf.concat(frames_color, axis=2) 120 | 121 | # In test mode, we only use 1 thread for dataloading --- this prevents race conditions and loads the images in sequential order. 122 | min_after_dequeue = 16 123 | capacity = min_after_dequeue + 4 124 | dataloader_threads = 1 125 | self.current_images = tf.train.batch([images], batch_size, dataloader_threads, capacity) 126 | 127 | def read_jpg(self, image_path): 128 | image = tf.image.decode_jpeg(tf.read_file(tf.string_join([self.data_folder, image_path], "/"))) 129 | image = tf.image.convert_image_dtype(image, tf.float32) 130 | return image 131 | 132 | 133 | 134 | 135 | ######################### 136 | # Main code starts here # 137 | ######################### 138 | 139 | data_file = args.data_path + "/test.txt" 140 | dataloader = TestDataLoader(data_file, args.data_path) 141 | 142 | with tf.variable_scope("standard_inputs"): 143 | current_images = dataloader.current_images 144 | current_input = tf.concat(current_images, 3) 145 | 146 | with tf.variable_scope("model", reuse=False): 147 | current_out, current_img_list = \ 148 | blending_network(current_input, current_images, num_input_layers, use_global_mesh, direct_regression, batch_size) 149 | 150 | # Save intermediate models 151 | saver = tf.train.Saver(max_to_keep=1) 152 | 153 | # Creating a config to prevent GPU use at all 154 | config = tf.ConfigProto() 155 | 156 | # Start GPU memory small and allow to grow 157 | config.gpu_options.allow_growth=True 158 | 159 | sess = tf.Session(config=config) 160 | init_global = tf.global_variables_initializer() 161 | init_local = tf.local_variables_initializer() 162 | sess.run(init_local) 163 | sess.run(init_global) 164 | 165 | img_path = args.output_path 166 | model_path = args.model_path 167 | if not os.path.isdir(img_path): 168 | os.makedirs(img_path, 0o755) 169 | 170 | # Load pretrained weights 171 | last_checkpoint = tf.train.latest_checkpoint(model_path) 172 | if last_checkpoint != None: 173 | test_variables = [] 174 | for var in tf.global_variables(): 175 | if var.name.startswith("model") and var.name.find("Adam") == -1: 176 | test_variables.append(var) 177 
| test_restorer = tf.train.Saver(test_variables) 178 | test_restorer.restore(sess, last_checkpoint) 179 | else: 180 | print("Could not load model weights from: " + model_path) 181 | sys.exit(0) 182 | 183 | # Start the data loader threads 184 | coordinator = tf.train.Coordinator() 185 | threads = tf.train.start_queue_runners(sess=sess, coord=coordinator) 186 | 187 | with sess.as_default(): 188 | try: 189 | test_file = open(data_file, 'r') 190 | test_lines = test_file.readlines() 191 | test_file.close() 192 | 193 | for i, line in enumerate(test_lines): 194 | if coordinator.should_stop(): 195 | break 196 | 197 | output = sess.run([current_out]) 198 | raw_image = np.clip(output[0] * 255, 0, 255).astype(np.uint8) 199 | raw_image = np.reshape(raw_image, (dataloader.init_height, dataloader.init_width, 3)) 200 | 201 | chunks = line.split(" ") 202 | image_index = chunks[-1].split("/")[1].split("_")[0] 203 | if not os.path.isdir(img_path): 204 | os.makedirs(img_path, 0o755) 205 | scipy.misc.imsave(img_path + "/" + image_index + ".jpg", raw_image) 206 | 207 | print("Test frame: %s" % image_index) 208 | except Exception as e: 209 | # Report exceptions to the coordinator. 210 | coordinator.request_stop(e) 211 | finally: 212 | coordinator.request_stop() 213 | coordinator.join(threads) 214 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright 2019 Peter Hedman. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | import argparse 18 | import tensorflow as tf 19 | import os 20 | import sys 21 | import math 22 | import tensorflow.contrib.slim as slim 23 | from tensorflow.contrib.slim.nets import vgg 24 | 25 | from bilinear_sampler import * 26 | from blendingnetwork import * 27 | 28 | 29 | 30 | #################################### 31 | # Parameters and argument parsing # 32 | ################################### 33 | 34 | parser = argparse.ArgumentParser( 35 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 36 | 37 | parser.add_argument("data_path", type=str, 38 | help="relative directory for all data") 39 | parser.add_argument("output_path", type=str, 40 | help="base-level output directory for results") 41 | 42 | parser.add_argument("--training_file", type=str, default=None, 43 | help="full path to the training file") 44 | parser.add_argument("--validation_file", type=str, default=None, 45 | help="full path to the validation file") 46 | 47 | # Define the input and output folders 48 | parser.add_argument("--log_path", type=str, default="log") 49 | parser.add_argument("--model_path", type=str, default="model") 50 | 51 | # Network training parameters 52 | parser.add_argument("--loss_function", type=str, default="VGG_AND_L1", 53 | help="Image loss to be used for training (L1, VGG, or VGG_AND_L1)") 54 | parser.add_argument("--direct_regression", default=False, action="store_true", 55 | help="Directly regress the output image instead of predicting blend weights") 56 | parser.add_argument("--no_temporal_loss", default=False, action="store_true", 57 | help="Train the network without a temporal loss") 58 | parser.add_argument("--no_textured_mesh", default=False, action="store_true", 59 | help="Disable the input layer from the textured mesh") 60 | parser.add_argument("--num_input_mosaics", type=int, default=4, 61 | help="Number of input mosaic layers to use") 62 | parser.add_argument("--temporal_alpha", type=float, default=0.33, 63 | help="Relative strength of the temporal loss") 64 | parser.add_argument("--debug", default=False, action="store_true", 65 | help="Enable debug mode for training") 66 | parser.add_argument("--num_batches", type=int, default=256000, 67 | help="Training duration (in terms of number of minibatches)") 68 | parser.add_argument("--batch_size", type=int, default=8, 69 | help="Batch size to be used for training") 70 | parser.add_argument("--crop", type=int, default=256, 71 | help="Crop size for data augmentation") 72 | 73 | args = parser.parse_args() 74 | 75 | image_loss = args.loss_function 76 | direct_regression = args.direct_regression 77 | use_temporal_loss = not args.no_temporal_loss 78 | use_global_mesh = not args.no_textured_mesh 79 | num_input_layers = args.num_input_mosaics 80 | debug_mode = args.debug 81 | temporal_alpha = args.temporal_alpha 82 | crop = args.crop 83 | num_batches = args.num_batches 84 | batch_size = args.batch_size 85 | 86 | print("================================================") 87 | print("Running deep blending network in TRAINING mode.") 88 | print("Loss function: " + image_loss + str(".")) 89 | if direct_regression: 90 | print("Directly predicting the output image.") 91 | else: 92 | print("Predicting blend weights.") 93 | print("Inputs: " + str(num_input_layers) + str(" mosaics") + \ 94 | (str(" and a render of the textured mesh.") if use_global_mesh else str("."))) 95 | print("Training " + (str("WITH") if use_temporal_loss else str("WITHOUT")) + str(" a 
temporal loss.")) 96 | if use_temporal_loss: 97 | print("Temporal alpha=" + str(temporal_alpha)) 98 | if debug_mode: 99 | print("Running in debug mode.") 100 | print("Random crop size: " + str(crop) + "x" + str(crop)) 101 | print("Batch size: " + str(batch_size)) 102 | print("Training until " + str(num_batches) + " minibatches") 103 | print("================================================") 104 | 105 | 106 | 107 | 108 | ########################## 109 | # Data loader definition # 110 | ########################## 111 | 112 | def augment_image(img, flip_vert, flip_horiz, rotate_img) : 113 | img = tf.cond(flip_vert > 0.5, lambda: tf.image.flip_up_down(img), lambda: img) 114 | img = tf.cond(flip_horiz > 0.5, lambda: tf.image.flip_left_right(img), lambda: img) 115 | img = tf.cond(rotate_img > 0.5, lambda: tf.image.rot90(img), lambda: img) 116 | return img 117 | 118 | def augment_flow(img, flip_vert, flip_horiz, rotate_img): 119 | x_flow = tf.expand_dims(img[:,:,0], 2) 120 | y_flow = tf.expand_dims(img[:,:,1], 2) 121 | z_flow = tf.expand_dims(img[:,:,2], 2) 122 | 123 | # Vertical flipping 124 | y_flow = tf.cond(flip_vert > 0.5, lambda: y_flow * -1.0, lambda: y_flow) 125 | 126 | # Horizontal flipping 127 | x_flow = tf.cond(flip_horiz > 0.5, lambda: x_flow * -1.0, lambda: x_flow) 128 | 129 | # 90 degree rotation 130 | new_x_flow = tf.cond(rotate_img > 0.5, lambda: y_flow * -1.0, lambda: x_flow) 131 | new_y_flow = tf.cond(rotate_img > 0.5, lambda: x_flow, lambda: y_flow) 132 | 133 | return augment_image(tf.concat([new_x_flow, new_y_flow, z_flow], axis=2), 134 | flip_vert, flip_horiz, rotate_img) 135 | 136 | class TrainingDataLoader(object): 137 | """Training mode data loader""" 138 | def __init__(self, filenames_file, data_folder): 139 | self.ref_images = None 140 | self.current_images = None 141 | self.old_images = None 142 | self.old_to_current = None 143 | self.init_height = None 144 | self.init_width = None 145 | self.data_folder = None 146 | 147 | with tf.variable_scope("inputloader"): 148 | self.data_folder = tf.constant(data_folder) 149 | 150 | input_queue = tf.train.string_input_producer([filenames_file], shuffle = True) 151 | line_reader = tf.TextLineReader() 152 | _, line = line_reader.read(input_queue) 153 | split_line = tf.string_split([line]).values 154 | 155 | offset = 0 156 | current_width = tf.string_to_number(split_line[offset], tf.int32) 157 | 158 | offset += 1 159 | current_height = tf.string_to_number(split_line[offset], tf.int32) 160 | 161 | # Load the reference image 162 | offset += 1 163 | ref_img = self.read_jpg(split_line[offset]) 164 | ref_img.set_shape([self.init_height, self.init_width, 3]) 165 | 166 | # Augment the reference image 167 | flip_vert = tf.random_uniform([], 0, 1) 168 | flip_horiz = tf.random_uniform([], 0, 1) 169 | rotate_img = tf.random_uniform([], 0, 1) 170 | 171 | crop_x = tf.random_uniform([], 0, current_width - crop, dtype=tf.int32) 172 | crop_y = tf.random_uniform([], 0, current_height - crop, dtype=tf.int32) 173 | 174 | asserts_ref = [ 175 | tf.assert_greater_equal(current_width, crop, message="Current width smaller than crop size"), 176 | tf.assert_greater_equal(current_height, crop, message="Current height smaller than crop size"), 177 | tf.assert_greater_equal(tf.shape(ref_img)[0], crop_y + crop, message="Reference height smaller than crop size"), 178 | tf.assert_greater_equal(tf.shape(ref_img)[1], crop_x + crop, message="Reference width smaller than crop size")] 179 | 180 | with tf.control_dependencies(asserts_ref): 181 | ref_img = 
tf.image.crop_to_bounding_box(ref_img, crop_y, crop_x, crop, crop) 182 | ref_img = augment_image(ref_img, flip_vert, flip_horiz, rotate_img) 183 | 184 | # Optionally: Load the optical flow between the two temporal frames 185 | offset += 1 186 | flow_img = tf.zeros_like(ref_img) 187 | if use_temporal_loss: 188 | flow_img = self.read_flow(split_line[offset]) 189 | 190 | # Also augment the current-to-previous optical flow 191 | asserts_flow = [ 192 | tf.assert_greater_equal(tf.shape(flow_img)[0], crop_y + crop, message="Temporal flow height smaller than crop size"), 193 | tf.assert_greater_equal(tf.shape(flow_img)[1], crop_x + crop, message="Temporal flow width smaller than crop size")] 194 | with tf.control_dependencies(asserts_flow): 195 | flow_img = tf.image.crop_to_bounding_box(flow_img, crop_y, crop_x, crop, crop) 196 | flow_img = augment_flow(flow_img, flip_vert, flip_horiz, rotate_img) 197 | 198 | # Always load two temporal frames for path samples -- we'll ignore these 199 | # later during training if use_temporal_loss is False. 200 | num_path_samples = 2 201 | frames_color = [] 202 | offset += 1 203 | 204 | for j in range(num_path_samples): 205 | # Our training data has 5 input layers per path sample: 206 | # global_colors and local_layer_N_colors (where N=0...4) 207 | path_sample_offset = 5 * j 208 | 209 | # First populate the CNN inputs with data from the global mesh 210 | if use_global_mesh: 211 | global_color = self.read_jpg(split_line[offset + path_sample_offset + 0]) 212 | global_color.set_shape([self.init_height, self.init_width, 3]) 213 | 214 | # Make sure to apply the same data augmentation to the textured global mesh layer 215 | asserts_global = [ 216 | tf.assert_greater_equal(tf.shape(global_color)[0], crop_y + crop, message="Textured mesh height smaller than crop size"), 217 | tf.assert_greater_equal(tf.shape(global_color)[1], crop_x + crop, message="Textured mesh width smaller than crop size")] 218 | with tf.control_dependencies(asserts_global): 219 | global_color = tf.image.crop_to_bounding_box(global_color, crop_y, crop_x, crop, crop) 220 | global_color = augment_image(global_color, flip_vert, flip_horiz, rotate_img) 221 | 222 | frames_color.append([global_color]) 223 | else: 224 | frames_color.append([]) 225 | 226 | # Index where the image mosaic layers begin (+1 is for the textured mesh layer above) 227 | mosaic_layers_begin = offset + path_sample_offset + 1 228 | for i in range(num_input_layers): 229 | colors = self.read_jpg(split_line[mosaic_layers_begin + i]) 230 | colors.set_shape([self.init_height, self.init_width, 3]) 231 | 232 | # Augment the image mosaics 233 | asserts_local = [ 234 | tf.assert_greater_equal(tf.shape(colors)[0], crop_y + crop, message="Mosaic height smaller than crop size"), 235 | tf.assert_greater_equal(tf.shape(colors)[1], crop_x + crop, message="Mosaic width smaller than crop size")] 236 | with tf.control_dependencies(asserts_local): 237 | colors = tf.image.crop_to_bounding_box(colors, crop_y, crop_x, crop, crop) 238 | colors = augment_image(colors, flip_vert, flip_horiz, rotate_img) 239 | frames_color[-1].append(colors) 240 | 241 | images_current = tf.concat(frames_color[0], axis=2) 242 | images_previous = tf.concat(frames_color[1], axis=2) 243 | 244 | dataloader_threads = 8 245 | min_after_dequeue = 16 246 | capacity = min_after_dequeue + 4 * batch_size 247 | if use_temporal_loss: 248 | self.ref_images, self.current_images, self.old_images, self.old_to_current = \ 249 | tf.train.shuffle_batch([ref_img, images_current, images_previous, 
flow_img], batch_size, capacity, min_after_dequeue, dataloader_threads) 250 | else: 251 | self.ref_images, self.current_images, self.old_images = \ 252 | tf.train.shuffle_batch([ref_img, images_current, images_previous], batch_size, capacity, min_after_dequeue, dataloader_threads) 253 | 254 | def read_flow(self, image_path): 255 | image = tf.image.decode_png(tf.read_file(tf.string_join([self.data_folder, image_path], "/")), channels=3, dtype=tf.uint16) 256 | image = tf.image.convert_image_dtype(image, tf.float32) 257 | image.set_shape([self.init_height, self.init_width, 3]) 258 | image = 512.0 * (image * 2.0 - 1.0) 259 | return image 260 | 261 | def read_jpg(self, image_path): 262 | image = tf.image.decode_jpeg(tf.read_file(tf.string_join([self.data_folder, image_path], "/"))) 263 | image = tf.image.convert_image_dtype(image, tf.float32) 264 | return image 265 | 266 | 267 | 268 | 269 | ############################################ 270 | # Helper functions for our training losses # 271 | ############################################ 272 | 273 | def compute_standard_error(output, reference): 274 | l1_image = tf.abs(output - reference) 275 | return l1_image 276 | 277 | vgg_mean = tf.reshape(tf.constant([123.68, 116.78, 103.94]), [1, 1, 3]) 278 | def vgg_16(inputs, scope='vgg_16'): 279 | """Computes deep image features as the first two maxpooling layers of a VGG16 network""" 280 | with tf.variable_scope('vgg_16', 'vgg_16', [inputs], reuse=tf.AUTO_REUSE) as sc: 281 | end_points_collection = sc.original_name_scope + '_end_points' 282 | 283 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 284 | outputs_collections=end_points_collection): 285 | net_a = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 286 | net_b = slim.max_pool2d(net_a, [2, 2], scope='pool1') 287 | net_c = slim.repeat(net_b, 2, slim.conv2d, 128, [3, 3], scope='conv2') 288 | return net_a, net_c 289 | 290 | def compute_vgg_error(output, reference, layer): 291 | scaled_output = output * 255 - vgg_mean 292 | scaled_reference = reference * 255 - vgg_mean 293 | with slim.arg_scope(vgg.vgg_arg_scope()): 294 | output_a, output_b = vgg_16(scaled_output) 295 | reference_a, reference_b = vgg_16(scaled_reference) 296 | if layer == 0: 297 | return tf.abs(output_a - reference_a) 298 | return tf.abs(output_b - reference_b) 299 | 300 | if image_loss == 'L1': # Standard L1 301 | def compute_error(output, reference, kernel_size): 302 | return compute_standard_error(output, reference) 303 | 304 | def compute_loss(output, reference, kernel_size): 305 | return tf.reduce_mean(compute_standard_error(output, reference)) 306 | elif image_loss == 'VGG': # Use a VGG loss 307 | def compute_error(output, reference, kernel_size): 308 | return compute_vgg_error(output, reference, 0) 309 | 310 | def compute_loss(output, reference, kernel_size): 311 | return tf.reduce_mean(compute_vgg_error(output, reference, 0)) + \ 312 | tf.reduce_mean(compute_vgg_error(output, reference, 1)) 313 | elif image_loss == 'VGG_AND_L1': # Mix L1 and VGG to perserve high frequency content 314 | def compute_error(output, reference, kernel_size): 315 | return 255.0 * compute_standard_error(output, reference) 316 | 317 | def compute_loss(output, reference, kernel_size): 318 | return tf.reduce_mean(compute_vgg_error(output, reference, 0)) + \ 319 | tf.reduce_mean(compute_vgg_error(output, reference, 1)) + \ 320 | 255.0 * tf.reduce_mean(compute_standard_error(output, reference)) 321 | else: 322 | print("Unexpected image loss: " + image_loss) 323 | 
sys.exit(0) 324 | 325 | 326 | 327 | 328 | ######################### 329 | # Main code starts here # 330 | ######################### 331 | 332 | data_file = args.training_file 333 | dataloader = TrainingDataLoader(data_file, args.data_path) 334 | 335 | dataloader_validation = None 336 | if args.validation_file is not None: 337 | validation_file = args.validation_file 338 | dataloader_validation = TrainingDataLoader(validation_file, args.data_path) 339 | 340 | # Collect inputs from the data loaders 341 | with tf.variable_scope("standard_inputs"): 342 | ref_image = dataloader.ref_images 343 | current_images = dataloader.current_images 344 | old_images = dataloader.old_images 345 | current_input = tf.concat(current_images, 3) 346 | old_input = tf.concat(old_images, 3) 347 | if use_temporal_loss: 348 | old_to_current = dataloader.old_to_current 349 | 350 | if args.validation_file is not None: 351 | with tf.variable_scope("validation_inputs"): 352 | v_ref_image = dataloader_validation.ref_images 353 | v_current_images = dataloader_validation.current_images 354 | v_old_images = dataloader_validation.old_images 355 | v_current_input = tf.concat(v_current_images, 3) 356 | v_old_input = tf.concat(v_old_images, 3) 357 | if use_temporal_loss: 358 | v_old_to_current = dataloader_validation.old_to_current 359 | 360 | # Produce outputs from the blending network 361 | with tf.variable_scope("model", reuse=False): 362 | current_out, current_img_list = \ 363 | blending_network(current_input, current_images, num_input_layers, use_global_mesh, direct_regression, batch_size) 364 | warped_current = current_out 365 | 366 | with tf.variable_scope("model", reuse=True): 367 | old_out, old_img_list = \ 368 | blending_network(old_input, old_images, num_input_layers, use_global_mesh, direct_regression, batch_size) 369 | 370 | if use_temporal_loss: 371 | warped_current = bilinear_sampler_2d(current_out, old_to_current, 'edge') 372 | 373 | if args.validation_file is not None: 374 | with tf.variable_scope("model", reuse=True): 375 | v_current_out, v_current_img_list = \ 376 | blending_network(v_current_input, v_current_images, num_input_layers, use_global_mesh, direct_regression, batch_size) 377 | v_warped_current = v_current_out 378 | 379 | with tf.variable_scope("model", reuse=True): 380 | v_old_out, v_old_img_list = \ 381 | blending_network(v_old_input, v_old_images, num_input_layers, use_global_mesh, direct_regression, batch_size) 382 | 383 | if use_temporal_loss: 384 | v_warped_current = bilinear_sampler_2d(v_current_out, v_old_to_current, 'edge') 385 | 386 | # Define losses for training 387 | standard_loss = compute_loss(current_out, ref_image, 2) 388 | temporal_loss = compute_loss(warped_current, old_out, 2) 389 | minimize_loss = standard_loss 390 | if use_temporal_loss: 391 | minimize_loss += temporal_alpha * temporal_loss 392 | 393 | if args.validation_file is not None: 394 | v_standard_loss = compute_loss(v_current_out, v_ref_image, 2) 395 | v_temporal_loss = compute_loss(v_warped_current, v_old_out, 2) 396 | 397 | # Train the network using Adam 398 | global_step = tf.get_variable("global_step", initializer=tf.constant(0, dtype=tf.int32), trainable=False) 399 | 400 | ibr_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "model") 401 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 402 | with tf.control_dependencies(update_ops): 403 | train_step = tf.train.AdamOptimizer(3e-4).minimize(minimize_loss, global_step=global_step, var_list=ibr_train_vars) 404 | 405 | # Save intermediate models 
406 | saver = tf.train.Saver(max_to_keep=1) 407 | 408 | # Create the TensorFlow session config 409 | config = tf.ConfigProto() 410 | 411 | # Start GPU memory small and allow to grow 412 | config.gpu_options.allow_growth=True 413 | 414 | sess = tf.Session(config=config) 415 | init_global = tf.global_variables_initializer() 416 | init_local = tf.local_variables_initializer() 417 | sess.run(init_local) 418 | sess.run(init_global) 419 | 420 | # Surround an image with a completely white (1, 1, 1) 421 | # and black (0, 0, 0) line, to make sure that tensorboard 422 | # doesn't normalize each image independently. 423 | def tensorboard_normalize_hack(img): 424 | return tf.concat([tf.zeros_like(img[:,0:1,:,:]), 425 | img, 426 | tf.ones_like(img[:,0:1,:,:])], axis=1) 427 | 428 | # Visualize results in tensorboard 429 | vis_output = tf.summary.image("result", tensorboard_normalize_hack(current_out), max_outputs=batch_size) 430 | vis_reference = tf.summary.image("reference", tensorboard_normalize_hack(ref_image), max_outputs=batch_size) 431 | vis_previous = tf.summary.image("previous", tensorboard_normalize_hack(old_out), max_outputs=batch_size) 432 | 433 | if use_temporal_loss: 434 | vis_warped_current = tf.summary.image("current_warped", tensorboard_normalize_hack(warped_current), max_outputs=batch_size) 435 | if args.validation_file is not None: 436 | vis_v_warped_current = tf.summary.image("validation_current_warped", tensorboard_normalize_hack(v_warped_current), max_outputs=batch_size) 437 | 438 | if args.validation_file is not None: 439 | vis_v_output = tf.summary.image("validation_result", tensorboard_normalize_hack(v_current_out), max_outputs=batch_size) 440 | vis_v_reference = tf.summary.image("validation_reference", tensorboard_normalize_hack(v_ref_image), max_outputs=batch_size) 441 | vis_v_previous = tf.summary.image("validation_previous", tensorboard_normalize_hack(v_old_out), max_outputs=batch_size) 442 | 443 | if debug_mode: 444 | image_index = 0 445 | if use_global_mesh: 446 | vis_g_c = tf.summary.image("input_textured_mesh", 447 | tensorboard_normalize_hack(current_img_list[image_index]), 448 | max_outputs=batch_size) 449 | image_index = image_index + 1 450 | if num_input_layers > 0: 451 | vis_a_c = tf.summary.image("input_00_mosaic", 452 | tensorboard_normalize_hack(current_img_list[image_index]), 453 | max_outputs=batch_size) 454 | image_index = image_index + 1 455 | if num_input_layers > 1: 456 | vis_b_c = tf.summary.image("input_01_mosaic", 457 | tensorboard_normalize_hack(current_img_list[image_index]), 458 | max_outputs=batch_size) 459 | image_index = image_index + 1 460 | if num_input_layers > 2: 461 | vis_c_c = tf.summary.image("input_02_mosaic", 462 | tensorboard_normalize_hack(current_img_list[image_index]), 463 | max_outputs=batch_size) 464 | image_index = image_index + 1 465 | if num_input_layers > 3: 466 | vis_d_c = tf.summary.image("input_03_mosaic", 467 | tensorboard_normalize_hack(current_img_list[image_index]), 468 | max_outputs=batch_size) 469 | 470 | vis_loss = tf.summary.image("image_error", compute_error(current_out, ref_image, 2), max_outputs=batch_size) 471 | 472 | train_summary = tf.summary.merge([tf.summary.scalar("batch_loss", standard_loss), 473 | tf.summary.scalar("temporal_loss", temporal_loss)]) 474 | 475 | if args.validation_file is not None: 476 | validation_summary = tf.summary.merge([tf.summary.scalar("validation_batch_loss", v_standard_loss), 477 | tf.summary.scalar("validation_temporal_loss", v_temporal_loss)]) 478 | 479 | log_path = 
os.path.join(args.output_path, args.log_path) 480 | model_path = os.path.join(args.output_path, args.model_path) 481 | if not os.path.isdir(log_path): 482 | os.makedirs(log_path, 0o755) 483 | if not os.path.isdir(model_path): 484 | os.makedirs(model_path, 0o755) 485 | summary_writer = tf.summary.FileWriter(log_path, sess.graph) 486 | 487 | # Load pretrained weights 488 | first_batch = 0 489 | last_checkpoint = tf.train.latest_checkpoint(model_path) 490 | if last_checkpoint != None: 491 | train_variables = [] 492 | for var in tf.global_variables(): 493 | train_variables.append(var) 494 | train_restorer = tf.train.Saver(train_variables) 495 | train_restorer.restore(sess, last_checkpoint) 496 | first_batch = int(last_checkpoint.split("-")[-1]) 497 | print("Restarting from batch: " + str(first_batch)) 498 | elif image_loss == 'VGG' or image_loss == "VGG_AND_L1": # New fresh start, we only need to load the VGG weights 499 | vgg_variables = [] 500 | for var in tf.global_variables(): 501 | if var.name.startswith("vgg_16") and var.name.find("Adam") == -1: 502 | vgg_variables.append(var) 503 | vgg_restorer = tf.train.Saver(vgg_variables) 504 | vgg_restorer.restore(sess, "vgg_16.ckpt") 505 | 506 | # Start the data loader threads 507 | coordinator = tf.train.Coordinator() 508 | threads = tf.train.start_queue_runners(sess=sess, coord=coordinator) 509 | 510 | # Main training loop 511 | with sess.as_default(): 512 | try: 513 | loss_dump_interval = 8 514 | image_dump_interval = 256 515 | for i in range(first_batch, num_batches): 516 | if coordinator.should_stop(): 517 | break 518 | 519 | run_list = [train_step] 520 | if i == 0 or i % loss_dump_interval == (loss_dump_interval - 1): 521 | run_list = run_list + [standard_loss, train_summary] 522 | if args.validation_file is not None: 523 | run_list += [validation_summary] 524 | 525 | if i == 0 or i % image_dump_interval == (image_dump_interval - 1): 526 | run_list = run_list + [vis_output, vis_reference, vis_previous] 527 | if args.validation_file is not None: 528 | run_list += [vis_v_output, vis_v_reference, vis_v_previous] 529 | 530 | if use_temporal_loss: 531 | run_list = run_list + [vis_warped_current] 532 | if args.validation_file is not None: 533 | run_list += [vis_v_warped_current] 534 | 535 | if debug_mode: 536 | run_list = run_list + [vis_loss] 537 | if use_global_mesh: 538 | run_list = run_list + [vis_g_c] 539 | if num_input_layers > 0: 540 | run_list = run_list + [vis_a_c] 541 | if num_input_layers > 1: 542 | run_list = run_list + [vis_b_c] 543 | if num_input_layers > 2: 544 | run_list = run_list + [vis_c_c] 545 | if num_input_layers > 3: 546 | run_list = run_list + [vis_d_c] 547 | 548 | output = sess.run(run_list) 549 | 550 | if i == 0 or i % loss_dump_interval == (loss_dump_interval - 1): 551 | print("Step: %d, batch loss: %f" % (i , output[1])) 552 | 553 | for j in range(2, len(output)): 554 | summary_writer.add_summary(output[j], global_step=i) 555 | 556 | if i % image_dump_interval == (image_dump_interval - 1): 557 | print("saving model") 558 | saver.save(sess, os.path.join(model_path, "model"), global_step=global_step) 559 | except Exception as e: 560 | # Report exceptions to the coordinator. 561 | coordinator.request_stop(e) 562 | finally: 563 | coordinator.request_stop() 564 | coordinator.join(threads) 565 | --------------------------------------------------------------------------------
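Note (added for clarity; not part of the original repository): when `--direct_regression` is off, `blendingnetwork.py` produces the output image by predicting one score map per input layer, taking a per-pixel softmax over the layer axis (`tf.nn.softmax(out, 1)`), and summing the softmax-weighted colour layers. The sketch below mirrors that blending step in plain NumPy for a single image; the function and variable names are illustrative only.

```python
import numpy as np

def softmax_blend(layers, logits):
    """Blend candidate colour layers with per-pixel softmax weights.

    layers: [num_layers, 3, H, W] candidate images (textured-mesh render
            plus input mosaics), channels-first like the NCHW graph.
    logits: [num_layers, H, W] unnormalised per-pixel scores, one per layer.
    Returns the blended [3, H, W] image.
    """
    # Per-pixel softmax over the layer axis (cf. tf.nn.softmax(out, 1)).
    logits = logits - logits.max(axis=0, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=0, keepdims=True)
    # Weight each layer's colours and sum over layers
    # (cf. `out *= softmax[:, :, tf.newaxis]` followed by reduce_sum(axis=1)).
    return (layers * weights[:, np.newaxis]).sum(axis=0)

# Toy example: 5 layers (1 textured-mesh render + 4 mosaics) on a 4x4 image.
layers = np.random.rand(5, 3, 4, 4).astype(np.float32)
logits = np.random.randn(5, 4, 4).astype(np.float32)
print(softmax_blend(layers, logits).shape)  # (3, 4, 4)
```

With `--direct_regression` on, this weighting is skipped and the network's final 3-channel convolution (passed through a tanh and remapped to [0, 1]) is used as the output image directly.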