├── .gitignore ├── LICENSE ├── README.md ├── demo.py ├── get_data.sh ├── get_zhang_colorization.sh ├── images ├── divcolor_figure.png └── divcolor_imagenet.png ├── mdn ├── __init__.py ├── arch │ ├── __init__.py │ └── layer_factory.py ├── data_loaders │ ├── __init__.py │ └── zhangfeats_loader.py ├── mdn.py └── save_mdn_gmm.py ├── requirements.txt ├── run_demo.sh ├── run_lfw.sh ├── third_party ├── __init__.py └── save_zhang_feats.py └── vae ├── __init__.py ├── arch ├── __init__.py ├── layer_factory.py ├── network.py ├── vae_skipconn.py └── vae_wo_skipconn.py ├── data_loaders ├── __init__.py └── lab_imageloader.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.txt 3 | *.jpg 4 | *.JPEG 5 | *meta 6 | *ckpt 7 | *npy 8 | *npz 9 | *mat 10 | *jpg 11 | *webp 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2017 Aditya Deshpande 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Tensorflow implementation of Deshpande et al. "[Learning Diverse Image Colorization](https://arxiv.org/abs/1612.01958)" 2 | 3 | The code is tested for Tensorflow-v1.0.1 and python-2.7. The code additionally needs numpy, scipy, 4 | scikit-learn and caffe-r1.0 (caffe only for Zhang et al. colorization network). 5 | 6 | Fetch data by 7 | 8 | ``` 9 | bash get_data.sh 10 | ``` 11 | 12 | Fetch Zhang et al. 
colorization network for MDN features by 13 | 14 | ``` 15 | bash get_zhang_colorization.sh 16 | ``` 17 | 18 | Execute run_lfw.sh to first train the VAE and MDN, and then generate results for LFW 19 | 20 | ``` 21 | bash run_lfw.sh 22 | ``` 23 | 24 | Execute run_demo.sh to get diverse colorizations for any image; the model is trained on ImageNet 25 | 26 | ``` 27 | bash run_demo.sh 28 | ``` 29 | 30 | If you use this code, please cite 31 | 32 | ``` 33 | @inproceedings{DeshpandeLDColor17, 34 | author = {Aditya Deshpande and Jiajun Lu and Mao-Chuang Yeh and Min Jin Chong and David Forsyth}, 35 | title = {Learning Diverse Image Colorization}, 36 | booktitle={Computer Vision and Pattern Recognition}, 37 | url={https://arxiv.org/abs/1612.01958}, 38 | year={2017} 39 | } 40 | ``` 41 | 42 | Some examples of diverse colorizations on the LFW, LSUN Church and ImageNet-Val datasets 43 | 44 | ![Diverse colorizations on LFW, LSUN Church and ImageNet-Val](images/divcolor_figure.png) 45 | 46 | 
47 | 48 | Some examples of diverse colorizations for images in the wild; the model is trained on ImageNet 49 | 50 | ![Diverse colorizations for images in the wild](images/divcolor_imagenet.png) 51 | 52 | 
53 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 3 | import socket 4 | import sys 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | from vae.data_loaders.lab_imageloader import lab_imageloader 9 | from vae.arch.vae_skipconn import vae_skipconn as vae 10 | from vae.arch.network import network 11 | from third_party.save_zhang_feats import save_zhang_feats 12 | 13 | flags = tf.flags 14 | 15 | #Directory params 16 | flags.DEFINE_string("out_dir", "", "") 17 | flags.DEFINE_string("in_dir", "", "") 18 | flags.DEFINE_string("list_dir", "", "") 19 | 20 | #Dataset Params 21 | flags.DEFINE_integer("batch_size", 32, "batch size") 22 | flags.DEFINE_integer("updates_per_epoch", 1, "number of updates per epoch") 23 | flags.DEFINE_integer("log_interval", 1, "logging interval") 24 | flags.DEFINE_integer("img_width", 64, "input image width") 25 | flags.DEFINE_integer("img_height", 64, "input image height") 26 | 27 | #Network Params 28 | flags.DEFINE_boolean("is_only_data", False, "Only extract features and write data lists") 29 | flags.DEFINE_boolean("is_train", False, "Is training flag") 30 | flags.DEFINE_boolean("is_run_cvae", False, "Run the CVAE flag") 31 | flags.DEFINE_integer("hidden_size", 64, "size of the hidden VAE unit") 32 | flags.DEFINE_float("lr_vae", 1e-6, "learning rate for vae") 33 | flags.DEFINE_integer("max_epoch_vae", 10, "max epoch") 34 | flags.DEFINE_integer("pc_comp", 20, "number of principal components") 35 | 36 | FLAGS = flags.FLAGS 37 | 38 | def main(): 39 | 40 | FLAGS.log_interval = 1 41 | FLAGS.list_dir = None 42 | FLAGS.in_dir = 'data/testimgs/' 43 | FLAGS.ext = 'JPEG' 44 | data_loader = lab_imageloader(FLAGS.in_dir, \ 45 | 'data/output/testimgs', listdir=None, ext=FLAGS.ext) 46 | img_fns = data_loader.test_img_fns 47 | 48 | if(FLAGS.is_only_data == True): 49 | feats_fns = save_zhang_feats(img_fns, ext=FLAGS.ext) 50 | 51 | with open('%s/list.train.txt' % FLAGS.in_dir, 'w') as fp: 52 | for feats_fn in feats_fns: 53 | fp.write('%s\n' % feats_fn) 54 | 55 | with open('%s/list.test.txt' % FLAGS.in_dir, 'w') as fp: 56 | for feats_fn in feats_fns: 57 | fp.write('%s\n' % feats_fn) 58 | 59 | np.save('%s/lv_color_train.mat.npy' % FLAGS.in_dir, \ 60 | np.zeros((len(img_fns), 2*FLAGS.hidden_size))) 61 | np.save('%s/lv_color_test.mat.npy' % FLAGS.in_dir, \ 62 | np.zeros((len(img_fns), 2*FLAGS.hidden_size))) 63 | else: 64 | nmix = 8 65 | lv_mdn_test = np.load(os.path.join(FLAGS.in_dir, 'lv_color_mdn_test.mat.npy')) 66 | num_batches = np.int_(np.ceil((lv_mdn_test.shape[0]*1.)/FLAGS.batch_size)) 67 | 68 | graph_divcolor = tf.Graph() 69 | with graph_divcolor.as_default(): 70 | model_colorfield = vae(FLAGS, nch=2, condinference_flag=True) 71 | dnn = network(model_colorfield, data_loader, 2, FLAGS) 72 | dnn.run_divcolor('data/imagenet_models/', \ 73 | lv_mdn_test, num_batches=num_batches) 74 | 75 | if __name__ == "__main__": 76 | main() 77 | -------------------------------------------------------------------------------- /get_data.sh: -------------------------------------------------------------------------------- 1 | wget http://vision.cs.illinois.edu/projects/divcolor/data.zip 2 | unzip data.zip 3 | rm data.zip 4 | wget http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz 5 | tar -xvzf lfw-deepfunneled.tgz 6 | mv lfw-deepfunneled data/lfw_images 7 | rm lfw-deepfunneled.tgz 8 | 
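# run_lfw.sh expects the LFW images at data/lfw_images (moved above); data.zip
# presumably also carries the data/ file lists and the pretrained ImageNet models
# that demo.py and mdn/save_mdn_gmm.py load from data/imagenet_models*.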
-------------------------------------------------------------------------------- /get_zhang_colorization.sh: -------------------------------------------------------------------------------- 1 | cd third_party 2 | git clone https://github.com/richzhang/colorization.git 3 | rm -rf colorization/.git 4 | cd colorization/ 5 | bash models/fetch_release_models.sh 6 | -------------------------------------------------------------------------------- /images/divcolor_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/images/divcolor_figure.png -------------------------------------------------------------------------------- /images/divcolor_imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/images/divcolor_imagenet.png -------------------------------------------------------------------------------- /mdn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/mdn/__init__.py -------------------------------------------------------------------------------- /mdn/arch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/mdn/arch/__init__.py -------------------------------------------------------------------------------- /mdn/arch/layer_factory.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from tensorflow.python.framework import tensor_shape 5 | 6 | class layer_factory: 7 | 8 | def __init__(self): 9 | pass 10 | 11 | def weight_variable(self, name, shape=None, mean=0., stddev=.001, gain=np.sqrt(2)): 12 | if(shape == None): 13 | return tf.get_variable(name) 14 | # #Adaptive initialize based on variable shape 15 | # if(len(shape) == 4): 16 | # stddev = (1.0 * gain) / np.sqrt(shape[0] * shape[1] * shape[3]) 17 | # else: 18 | # stddev = (1.0 * gain) / np.sqrt(shape[0]) 19 | return tf.get_variable(name, shape=shape, initializer=tf.random_normal_initializer(mean=mean, stddev=stddev)) 20 | 21 | def bias_variable(self, name, shape=None, constval=.001): 22 | if(shape == None): 23 | return tf.get_variable(name) 24 | return tf.get_variable(name, shape=shape, initializer=tf.constant_initializer(constval)) 25 | 26 | def conv2d(self, x, W, stride=1, padding='SAME'): 27 | return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding=padding) 28 | 29 | def batch_norm_aiuiuc_wrapper(self, x, train_phase, name, reuse_vars): 30 | output = tf.contrib.layers.batch_norm(x, \ 31 | decay=.99, \ 32 | is_training=train_phase, \ 33 | scale=True, \ 34 | epsilon=1e-4, \ 35 | updates_collections=None,\ 36 | scope=name,\ 37 | reuse=reuse_vars) 38 | return output 39 | -------------------------------------------------------------------------------- /mdn/data_loaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/mdn/data_loaders/__init__.py -------------------------------------------------------------------------------- 
/mdn/data_loaders/zhangfeats_loader.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import glob 3 | import math 4 | import numpy as np 5 | 6 | class zhangfeats_loader: 7 | 8 | def __init__(self, list_train_fn, list_test_fn, lv_train_fn, lv_test_fn, \ 9 | featshape=(512, 28, 28)): 10 | 11 | self.train_img_fns = [] 12 | self.test_img_fns = [] 13 | 14 | with open(list_train_fn, 'r') as ftr: 15 | for img_fn in ftr: 16 | self.train_img_fns.append(img_fn.strip('\n')) 17 | 18 | with open(list_test_fn, 'r') as fte: 19 | for img_fn in fte: 20 | self.test_img_fns.append(img_fn.strip('\n')) 21 | 22 | self.lv_train = np.load(lv_train_fn) 23 | self.lv_test = np.load(lv_test_fn) 24 | self.hidden_size = np.int_((self.lv_train.shape[1]*1.)/2.) 25 | 26 | self.train_img_fns = self.train_img_fns[:self.lv_train.shape[0]] 27 | self.test_img_fns = self.test_img_fns[:self.lv_test.shape[0]] 28 | self.featshape = featshape 29 | 30 | self.train_img_num = len(self.train_img_fns) 31 | self.test_img_num = len(self.test_img_fns) 32 | self.train_batch_head = 0 33 | self.test_batch_head = 0 34 | self.train_shuff_ids = range(self.train_img_num) 35 | self.test_shuff_ids = range(self.test_img_num) 36 | 37 | def reset(self): 38 | self.train_batch_head = 0 39 | self.test_batch_head = 0 40 | self.train_shuff_ids = range(self.train_img_num) 41 | self.test_shuff_ids = range(self.test_img_num) 42 | 43 | def random_reset(self): 44 | self.train_batch_head = 0 45 | self.test_batch_head = 0 46 | self.train_shuff_ids = np.random.permutation(self.train_img_num) 47 | self.test_shuff_ids = range(self.test_img_num) 48 | 49 | def train_next_batch(self, batch_size): 50 | batch = np.zeros((batch_size, self.featshape[2], \ 51 | self.featshape[1], self.featshape[0]), dtype='f') 52 | batch_gt = np.zeros((batch_size, self.hidden_size), dtype='f') 53 | 54 | if(self.train_batch_head + batch_size >= self.train_img_num): 55 | self.train_shuff_ids = np.random.permutation(self.train_img_num) 56 | self.train_batch_head = 0 57 | 58 | for i_n, i in enumerate(range(self.train_batch_head, self.train_batch_head+batch_size)): 59 | currid = self.train_shuff_ids[i] 60 | featobj = np.load(self.train_img_fns[currid]) 61 | feats = featobj['arr_0'].reshape(self.featshape[0], self.featshape[1], \ 62 | self.featshape[2]) 63 | feats2d = feats.reshape(self.featshape[0], -1).T 64 | feats3d = feats2d.reshape(self.featshape[1], self.featshape[2], \ 65 | self.featshape[0]) 66 | batch[i_n, ...] = feats3d 67 | eps = np.random.normal(loc=0., scale=1., size=(self.hidden_size)) 68 | batch_gt[i_n, ...] 
= self.lv_train[currid, :self.hidden_size] \ 69 | + eps*self.lv_train[currid, self.hidden_size:] 70 | 71 | self.train_batch_head = self.train_batch_head + batch_size 72 | 73 | return batch, batch_gt 74 | 75 | def test_next_batch(self, batch_size): 76 | batch = np.zeros((batch_size, self.featshape[2], \ 77 | self.featshape[1], self.featshape[0]), dtype='f') 78 | batch_gt = np.zeros((batch_size, self.hidden_size), dtype='f') 79 | 80 | if(self.test_batch_head + batch_size > self.test_img_num): 81 | self.test_shuff_ids = range(self.test_img_num) 82 | self.test_batch_head = 0 83 | 84 | for i_n, i in enumerate(range(self.test_batch_head, self.test_batch_head+batch_size)): 85 | currid = self.test_shuff_ids[i] 86 | featobj = np.load(self.test_img_fns[currid]) 87 | feats = featobj['arr_0'].reshape(self.featshape[0], self.featshape[1], \ 88 | self.featshape[2]) 89 | feats2d = feats.reshape(self.featshape[0], -1).T 90 | feats3d = feats2d.reshape(self.featshape[1], self.featshape[2], \ 91 | self.featshape[0]) 92 | batch[i_n, ...] = feats3d 93 | eps = np.random.normal(loc=0., scale=1., size=(self.hidden_size)) 94 | batch_gt[i_n, ...] = self.lv_test[currid, :self.hidden_size] \ 95 | +eps*self.lv_test[currid, self.hidden_size:] 96 | 97 | self.test_batch_head = self.test_batch_head + batch_size 98 | return batch, batch_gt 99 | -------------------------------------------------------------------------------- /mdn/mdn.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 3 | 4 | import socket 5 | import sys 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 | from data_loaders.zhangfeats_loader import zhangfeats_loader 10 | from arch.layer_factory import layer_factory 11 | 12 | flags = tf.flags 13 | FLAGS = flags.FLAGS 14 | 15 | 16 | flags.DEFINE_integer("feats_height", 28, "") 17 | flags.DEFINE_integer("feats_width", 28, "") 18 | flags.DEFINE_integer("feats_nch", 512, "") 19 | flags.DEFINE_integer("nmix", 8, "GMM components") 20 | flags.DEFINE_boolean("is_train", True, "") 21 | flags.DEFINE_integer("max_epoch", 5, "Max epoch") 22 | flags.DEFINE_float("lr", 1e-5, "Learning rate") 23 | 24 | #Dynamical assigned params 25 | flags.DEFINE_integer("batch_size", 1, "Batch size") 26 | flags.DEFINE_integer("hidden_size", 64, "VAE latent variable dimension") 27 | flags.DEFINE_integer("updates_per_epoch", 0, "Number of updates per epoch") 28 | 29 | def cnn_feedforward(lf, input_tensor, bn_is_training, keep_prob, reuse=False): 30 | 31 | nout = (FLAGS.hidden_size+1)*FLAGS.nmix 32 | 33 | if(reuse == False): 34 | W_conv1 = lf.weight_variable(name='W_conv1', shape=[5, 5, FLAGS.feats_nch, 384]) 35 | W_conv1_1 = lf.weight_variable(name='W_conv1_1', shape=[5, 5, 384, 320]) 36 | W_conv1_2 = lf.weight_variable(name='W_conv1_2', shape=[5, 5, 320, 288]) 37 | W_conv1_3 = lf.weight_variable(name='W_conv1_3', shape=[5, 5, 288, 256]) 38 | 39 | W_conv2 = lf.weight_variable(name='W_conv2', shape=[5, 5, 256, 128]) 40 | W_fc1 = lf.weight_variable(name='W_fc1', shape=[14*14*128, 4096]) 41 | W_fc2 = lf.weight_variable(name='W_fc2', shape=[4096, nout]) 42 | 43 | 44 | b_fc1 = lf.bias_variable(name='b_fc1', shape=[4096]) 45 | b_fc2 = lf.bias_variable(name='b_fc2', shape=[nout]) 46 | 47 | else: 48 | W_conv1 = lf.weight_variable(name='W_conv1') 49 | W_conv1_1 = lf.weight_variable(name='W_conv1_1') 50 | W_conv1_2 = lf.weight_variable(name='W_conv1_2') 51 | W_conv1_3 = lf.weight_variable(name='W_conv1_3') 52 | 53 | W_conv2 = 
lf.weight_variable(name='W_conv2') 54 | W_fc1 = lf.weight_variable(name='W_fc1') 55 | W_fc2 = lf.weight_variable(name='W_fc2') 56 | 57 | b_fc1 = lf.bias_variable(name='b_fc1') 58 | b_fc2 = lf.bias_variable(name='b_fc2') 59 | 60 | conv1 = tf.nn.relu(lf.conv2d(input_tensor, W_conv1, stride=1)) 61 | conv1_norm = lf.batch_norm_aiuiuc_wrapper(conv1, bn_is_training, \ 62 | 'BN1', reuse_vars=reuse) 63 | 64 | conv1_1 = tf.nn.relu(lf.conv2d(conv1_norm, W_conv1_1, stride=1)) 65 | conv1_1_norm = lf.batch_norm_aiuiuc_wrapper(conv1_1, bn_is_training, \ 66 | 'BN1_1', reuse_vars=reuse) 67 | 68 | conv1_2 = tf.nn.relu(lf.conv2d(conv1_1_norm, W_conv1_2, stride=1)) 69 | conv1_2_norm = lf.batch_norm_aiuiuc_wrapper(conv1_2, bn_is_training, \ 70 | 'BN1_2', reuse_vars=reuse) 71 | 72 | conv1_3 = tf.nn.relu(lf.conv2d(conv1_2_norm, W_conv1_3, stride=2)) 73 | conv1_3_norm = lf.batch_norm_aiuiuc_wrapper(conv1_3, bn_is_training, \ 74 | 'BN1_3', reuse_vars=reuse) 75 | 76 | conv2 = tf.nn.relu(lf.conv2d(conv1_3_norm, W_conv2, stride=1)) 77 | conv2_norm = lf.batch_norm_aiuiuc_wrapper(conv2, bn_is_training, \ 78 | 'BN2', reuse_vars=reuse) 79 | 80 | dropout1 = tf.nn.dropout(conv2_norm, keep_prob) 81 | flatten1 = tf.reshape(dropout1, [-1, 14*14*128]) 82 | fc1 = tf.tanh(tf.matmul(flatten1, W_fc1)+b_fc1) 83 | 84 | dropout2 = tf.nn.dropout(fc1, keep_prob) 85 | fc2 = tf.matmul(dropout2, W_fc2)+b_fc2 86 | 87 | return fc2 88 | 89 | def get_mixture_coeff(out_fc): 90 | out_mu = out_fc[..., :FLAGS.hidden_size*FLAGS.nmix] 91 | out_pi = tf.nn.softmax(out_fc[..., FLAGS.hidden_size*FLAGS.nmix:]) 92 | out_sigma = tf.constant(.1, shape=[FLAGS.batch_size, FLAGS.nmix]) 93 | return out_pi, out_mu, out_sigma 94 | 95 | def compute_gmm_loss(gt_tensor, op_tensor_activ, summ=False): 96 | 97 | #Replicate ground-truth tensor per mixture component 98 | gt_tensor_flat = tf.tile(gt_tensor, [FLAGS.nmix, 1]) 99 | 100 | #Pi, mu, sigma 101 | op_tensor_pi, op_tensor_mu, op_tensor_sigma = get_mixture_coeff(op_tensor_activ) 102 | 103 | #Flatten means, sigma, pi aligned to gt above 104 | op_tensor_mu_flat = tf.reshape(op_tensor_mu, [FLAGS.nmix*FLAGS.batch_size, FLAGS.hidden_size]) 105 | op_tensor_sigma_flat = tf.reshape(op_tensor_sigma, [FLAGS.nmix*FLAGS.batch_size]) 106 | 107 | #N(t|x, mu, sigma): batch_size x nmix 108 | op_norm_dist = tf.reshape(tf.div((.5*tf.reduce_sum(tf.square(gt_tensor_flat-op_tensor_mu_flat), \ 109 | reduction_indices=1)), op_tensor_sigma_flat), [FLAGS.batch_size, FLAGS.nmix]) 110 | op_norm_dist_min = tf.reduce_min(op_norm_dist, reduction_indices=1) 111 | op_norm_dist_minind = tf.to_int32(tf.argmin(op_norm_dist, 1)) 112 | op_pi_minind_flattened = tf.range(0, FLAGS.batch_size)*FLAGS.nmix + op_norm_dist_minind 113 | op_pi_min = tf.gather(tf.reshape(op_tensor_pi, [-1]), op_pi_minind_flattened) 114 | 115 | if(summ == True): 116 | gmm_loss = tf.reduce_mean(-tf.log(op_pi_min+1e-30) + op_norm_dist_min, reduction_indices=0) 117 | else: 118 | gmm_loss = tf.reduce_mean(op_norm_dist_min, reduction_indices=0) 119 | 120 | if(summ == True): 121 | tf.summary.scalar('gmm_loss', gmm_loss) 122 | tf.summary.scalar('op_norm_dist_min', tf.reduce_min(op_norm_dist)) 123 | tf.summary.scalar('op_norm_dist_max', tf.reduce_max(op_norm_dist)) 124 | tf.summary.scalar('op_pi_min', tf.reduce_mean(op_pi_min)) 125 | 126 | return gmm_loss, op_tensor_pi, op_tensor_mu, op_tensor_sigma 127 | 128 | def optimize(loss, lr): 129 | optimizer = tf.train.GradientDescentOptimizer(lr) 130 | return optimizer.minimize(loss) 131 | 132 | def save_chkpt(saver, epoch, sess, chkptdir, 
prefix='model'): 133 | if not os.path.exists(chkptdir): 134 | os.makedirs(chkptdir) 135 | save_path = saver.save(sess, "%s/%s_%06d.ckpt" % (chkptdir, prefix, epoch)) 136 | print("[DEBUG] ############ Model saved in file: %s ################" % save_path) 137 | 138 | def load_chkpt(saver, sess, chkptdir): 139 | ckpt = tf.train.get_checkpoint_state(chkptdir) 140 | if ckpt and ckpt.model_checkpoint_path: 141 | ckpt_fn = ckpt.model_checkpoint_path.replace('//', '/') 142 | print('[DEBUG] Loading checkpoint from %s' % ckpt_fn) 143 | saver.restore(sess, ckpt_fn) 144 | else: 145 | raise NameError('[ERROR] No checkpoint found at: %s' % chkptdir) 146 | 147 | def main(): 148 | 149 | if(len(sys.argv) == 1): 150 | raise NameError('[ERROR] No dataset key') 151 | if(sys.argv[1] == 'imagenetval'): 152 | FLAGS.updates_per_epoch = 49000 153 | FLAGS.num_test_batches = 1000 154 | FLAGS.in_featdir = 'data/featslist/imagenetval/' 155 | FLAGS.in_lvdir = 'data/output/imagenetval/' 156 | elif(sys.argv[1] == 'lfw'): 157 | FLAGS.updates_per_epoch = 12233 158 | FLAGS.num_test_batches = 1000 159 | FLAGS.in_featdir = 'data/featslist/lfw/' 160 | FLAGS.in_lvdir = 'data/output/lfw/' 161 | elif(sys.argv[1] == 'church'): 162 | FLAGS.updates_per_epoch = 125227 163 | FLAGS.num_test_batches = 1000 164 | FLAGS.in_featdir = 'data/featslist/church/' 165 | FLAGS.in_lvdir = 'data/output/church/' 166 | else: 167 | raise NameError('[ERROR] Incorrect dataset key') 168 | 169 | data_loader = zhangfeats_loader(os.path.join(FLAGS.in_featdir, 'list.train.txt'), \ 170 | os.path.join(FLAGS.in_featdir, 'list.test.txt'),\ 171 | os.path.join(FLAGS.in_lvdir, 'lv_color_train.mat.npy'),\ 172 | os.path.join(FLAGS.in_lvdir, 'lv_color_test.mat.npy')) 173 | 174 | #Inputs 175 | lf = layer_factory() 176 | input_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.feats_height, \ 177 | FLAGS.feats_width, FLAGS.feats_nch]) 178 | output_gt_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.hidden_size]) 179 | is_training = tf.placeholder(tf.bool) 180 | keep_prob = tf.placeholder(tf.float32) 181 | 182 | #Inference 183 | with tf.variable_scope('Inference', reuse=False): 184 | output_activ = cnn_feedforward(lf, input_tensor, is_training, keep_prob, reuse=False) 185 | 186 | with tf.variable_scope('Inference', reuse=True): 187 | output_test_activ = cnn_feedforward(lf, input_tensor, is_training, keep_prob, reuse=True) 188 | 189 | #Loss and gradient descent step 190 | loss, _, _, _ = compute_gmm_loss(output_gt_tensor, output_activ, summ=True) 191 | loss_test, pi_test, mu_test, sigma_test = compute_gmm_loss(output_gt_tensor, output_test_activ) 192 | 193 | train_step = optimize(loss, FLAGS.lr) 194 | 195 | #Standard steps 196 | check_nan_op = tf.add_check_numerics_ops() 197 | init = tf.global_variables_initializer() 198 | saver = tf.train.Saver(max_to_keep=0) 199 | summary_op = tf.summary.merge_all() 200 | 201 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9) 202 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 203 | train_writer = tf.summary.FileWriter(os.path.join(FLAGS.in_lvdir, 'logs_mdn'), sess.graph) 204 | 205 | sess.run(init) 206 | 207 | if(FLAGS.is_train): 208 | for epoch in range(FLAGS.max_epoch): 209 | training_loss = 0. 
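# Each update draws one batch of Zhang et al. conv7_3 features; the regression
# target produced by the loader is a sample from the stored VAE posterior
# (mu + eps*sigma), and a single gradient-descent step is taken on the GMM loss.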
210 | 211 | data_loader.random_reset() 212 | for i in range(FLAGS.updates_per_epoch): 213 | batch, batch_gt = data_loader.train_next_batch(FLAGS.batch_size) 214 | feed_dict = {input_tensor:batch, output_gt_tensor:batch_gt, \ 215 | is_training:True, keep_prob:.75} 216 | _, _, loss_value, summary_str = sess.run(\ 217 | [check_nan_op, train_step, loss, summary_op], \ 218 | feed_dict) 219 | train_writer.add_summary(summary_str, epoch*FLAGS.updates_per_epoch+i) 220 | training_loss = training_loss + loss_value 221 | 222 | print('[DEBUG] Epoch# %d, Loss: %f' % (epoch, \ 223 | (training_loss*1.)/FLAGS.updates_per_epoch)) 224 | 225 | save_chkpt(saver, epoch, sess, os.path.join(FLAGS.in_lvdir, 'models_mdn'), \ 226 | prefix='model_%d_exp' % FLAGS.nmix) 227 | else: 228 | load_chkpt(saver, sess, os.path.join(FLAGS.in_lvdir, 'models_mdn')) 229 | 230 | test_loss = 0. 231 | data_loader.reset() 232 | lv_test_codes = np.zeros((0, (FLAGS.hidden_size+1+1)*FLAGS.nmix), dtype='f') 233 | for i in range(FLAGS.num_test_batches): 234 | batch, batch_gt = data_loader.test_next_batch(FLAGS.batch_size) 235 | feed_dict = {input_tensor:batch, output_gt_tensor:batch_gt, \ 236 | is_training:False, keep_prob:1.} 237 | _, loss_value, output_pi, output_mu, output_sigma = \ 238 | sess.run([check_nan_op, loss_test, pi_test, mu_test, sigma_test], feed_dict) 239 | 240 | test_loss = test_loss + loss_value 241 | output = np.concatenate((output_mu, output_sigma, output_pi), axis=1) 242 | lv_test_codes = np.concatenate((lv_test_codes, output), axis=0) 243 | 244 | print('[DEBUG] Test Loss: %f' % ((test_loss*1.)/FLAGS.num_test_batches)) 245 | np.save(os.path.join(FLAGS.in_lvdir, 'lv_color_mdn_test.mat'), lv_test_codes) 246 | print(lv_test_codes.shape) 247 | 248 | sess.close() 249 | 250 | if __name__ == "__main__": 251 | main() 252 | -------------------------------------------------------------------------------- /mdn/save_mdn_gmm.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 3 | 4 | import socket 5 | import sys 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 | from data_loaders.zhangfeats_loader import zhangfeats_loader 10 | from arch.layer_factory import layer_factory 11 | 12 | flags = tf.flags 13 | 14 | #MDN Params 15 | flags.DEFINE_integer("feats_height", 28, "") 16 | flags.DEFINE_integer("feats_width", 28, "") 17 | flags.DEFINE_integer("feats_nch", 512, "") 18 | flags.DEFINE_integer("nmix", 8, "GMM components") 19 | flags.DEFINE_integer("max_epoch", 5, "Max epoch") 20 | flags.DEFINE_float("lr", 1e-5, "Learning rate") 21 | flags.DEFINE_integer("batch_size_mdn", 1, "Batch size") 22 | 23 | flags.DEFINE_integer("hidden_size", 64, "VAE latent variable dimension") 24 | 25 | FLAGS = flags.FLAGS 26 | 27 | def cnn_feedforward(lf, input_tensor, bn_is_training, keep_prob, reuse=False): 28 | 29 | nout = (FLAGS.hidden_size+1)*FLAGS.nmix 30 | 31 | if(reuse == False): 32 | W_conv1 = lf.weight_variable(name='W_conv1', shape=[5, 5, FLAGS.feats_nch, 384]) 33 | W_conv1_1 = lf.weight_variable(name='W_conv1_1', shape=[5, 5, 384, 320]) 34 | W_conv1_2 = lf.weight_variable(name='W_conv1_2', shape=[5, 5, 320, 288]) 35 | W_conv1_3 = lf.weight_variable(name='W_conv1_3', shape=[5, 5, 288, 256]) 36 | 37 | W_conv2 = lf.weight_variable(name='W_conv2', shape=[5, 5, 256, 128]) 38 | W_fc1 = lf.weight_variable(name='W_fc1', shape=[14*14*128, 4096]) 39 | W_fc2 = lf.weight_variable(name='W_fc2', shape=[4096, nout]) 40 | 41 | 42 | b_fc1 = 
lf.bias_variable(name='b_fc1', shape=[4096]) 43 | b_fc2 = lf.bias_variable(name='b_fc2', shape=[nout]) 44 | 45 | else: 46 | W_conv1 = lf.weight_variable(name='W_conv1') 47 | W_conv1_1 = lf.weight_variable(name='W_conv1_1') 48 | W_conv1_2 = lf.weight_variable(name='W_conv1_2') 49 | W_conv1_3 = lf.weight_variable(name='W_conv1_3') 50 | 51 | W_conv2 = lf.weight_variable(name='W_conv2') 52 | W_fc1 = lf.weight_variable(name='W_fc1') 53 | W_fc2 = lf.weight_variable(name='W_fc2') 54 | 55 | b_fc1 = lf.bias_variable(name='b_fc1') 56 | b_fc2 = lf.bias_variable(name='b_fc2') 57 | 58 | conv1 = tf.nn.relu(lf.conv2d(input_tensor, W_conv1, stride=1)) 59 | conv1_norm = lf.batch_norm_aiuiuc_wrapper(conv1, bn_is_training, \ 60 | 'BN1', reuse_vars=reuse) 61 | 62 | conv1_1 = tf.nn.relu(lf.conv2d(conv1_norm, W_conv1_1, stride=1)) 63 | conv1_1_norm = lf.batch_norm_aiuiuc_wrapper(conv1_1, bn_is_training, \ 64 | 'BN1_1', reuse_vars=reuse) 65 | 66 | conv1_2 = tf.nn.relu(lf.conv2d(conv1_1_norm, W_conv1_2, stride=1)) 67 | conv1_2_norm = lf.batch_norm_aiuiuc_wrapper(conv1_2, bn_is_training, \ 68 | 'BN1_2', reuse_vars=reuse) 69 | 70 | conv1_3 = tf.nn.relu(lf.conv2d(conv1_2_norm, W_conv1_3, stride=2)) 71 | conv1_3_norm = lf.batch_norm_aiuiuc_wrapper(conv1_3, bn_is_training, \ 72 | 'BN1_3', reuse_vars=reuse) 73 | 74 | conv2 = tf.nn.relu(lf.conv2d(conv1_3_norm, W_conv2, stride=1)) 75 | conv2_norm = lf.batch_norm_aiuiuc_wrapper(conv2, bn_is_training, \ 76 | 'BN2', reuse_vars=reuse) 77 | 78 | dropout1 = tf.nn.dropout(conv2_norm, keep_prob) 79 | flatten1 = tf.reshape(dropout1, [-1, 14*14*128]) 80 | fc1 = tf.tanh(tf.matmul(flatten1, W_fc1)+b_fc1) 81 | 82 | dropout2 = tf.nn.dropout(fc1, keep_prob) 83 | fc2 = tf.matmul(dropout2, W_fc2)+b_fc2 84 | 85 | return fc2 86 | 87 | def get_mixture_coeff(out_fc): 88 | out_mu = out_fc[..., :FLAGS.hidden_size*FLAGS.nmix] 89 | out_pi = tf.nn.softmax(out_fc[..., FLAGS.hidden_size*FLAGS.nmix:]) 90 | out_sigma = tf.constant(.1, shape=[FLAGS.batch_size_mdn, FLAGS.nmix]) 91 | return out_pi, out_mu, out_sigma 92 | 93 | def compute_gmm_loss(gt_tensor, op_tensor_activ, summ=False): 94 | 95 | #Replicate ground-truth tensor per mixture component 96 | gt_tensor_flat = tf.tile(gt_tensor, [FLAGS.nmix, 1]) 97 | 98 | #Pi, mu, sigma 99 | op_tensor_pi, op_tensor_mu, op_tensor_sigma = get_mixture_coeff(op_tensor_activ) 100 | 101 | #Flatten means, sigma, pi aligned to gt above 102 | op_tensor_mu_flat = tf.reshape(op_tensor_mu, [FLAGS.nmix*FLAGS.batch_size_mdn, FLAGS.hidden_size]) 103 | op_tensor_sigma_flat = tf.reshape(op_tensor_sigma, [FLAGS.nmix*FLAGS.batch_size_mdn]) 104 | 105 | #N(t|x, mu, sigma): batch_size_mdn x nmix 106 | op_norm_dist = tf.reshape(tf.div((.5*tf.reduce_sum(tf.square(gt_tensor_flat-op_tensor_mu_flat), \ 107 | reduction_indices=1)), op_tensor_sigma_flat), [FLAGS.batch_size_mdn, FLAGS.nmix]) 108 | op_norm_dist_min = tf.reduce_min(op_norm_dist, reduction_indices=1) 109 | op_norm_dist_minind = tf.to_int32(tf.argmin(op_norm_dist, 1)) 110 | op_pi_minind_flattened = tf.range(0, FLAGS.batch_size_mdn)*FLAGS.nmix + op_norm_dist_minind 111 | op_pi_min = tf.gather(tf.reshape(op_tensor_pi, [-1]), op_pi_minind_flattened) 112 | 113 | if(summ == True): 114 | gmm_loss = tf.reduce_mean(-tf.log(op_pi_min+1e-30) + op_norm_dist_min, reduction_indices=0) 115 | else: 116 | gmm_loss = tf.reduce_mean(op_norm_dist_min, reduction_indices=0) 117 | 118 | if(summ == True): 119 | tf.summary.scalar('gmm_loss', gmm_loss) 120 | tf.summary.scalar('op_norm_dist_min', tf.reduce_min(op_norm_dist)) 121 | 
tf.summary.scalar('op_norm_dist_max', tf.reduce_max(op_norm_dist)) 122 | tf.summary.scalar('op_pi_min', tf.reduce_mean(op_pi_min)) 123 | 124 | return gmm_loss, op_tensor_pi, op_tensor_mu, op_tensor_sigma 125 | 126 | def optimize(loss, lr): 127 | optimizer = tf.train.GradientDescentOptimizer(lr) 128 | return optimizer.minimize(loss) 129 | 130 | def save_chkpt(saver, epoch, sess, chkptdir, prefix='model'): 131 | if not os.path.exists(chkptdir): 132 | os.makedirs(chkptdir) 133 | save_path = saver.save(sess, "%s/%s_%06d.ckpt" % (chkptdir, prefix, epoch)) 134 | print("[DEBUG] ############ Model saved in file: %s ################" % save_path) 135 | 136 | def load_chkpt(saver, sess, chkptdir): 137 | ckpt = tf.train.get_checkpoint_state(chkptdir) 138 | #ckpt may be None, so check it before dereferencing model_checkpoint_path 139 | if ckpt and ckpt.model_checkpoint_path: 140 | ckpt_fn = ckpt.model_checkpoint_path.replace('//', '/') 141 | print('[DEBUG] Loading checkpoint from %s' % ckpt_fn) 142 | saver.restore(sess, ckpt_fn) 143 | else: 144 | raise NameError('[ERROR] No checkpoint found at: %s' % chkptdir) 145 | 146 | def save_mdn_gmm(data_dir): 147 | 148 | FLAGS.in_featdir = data_dir 149 | FLAGS.in_lvdir = data_dir 150 | 151 | data_loader = zhangfeats_loader(os.path.join(FLAGS.in_featdir, 'list.train.txt'), \ 152 | os.path.join(FLAGS.in_featdir, 'list.test.txt'),\ 153 | os.path.join(FLAGS.in_lvdir, 'lv_color_train.mat.npy'),\ 154 | os.path.join(FLAGS.in_lvdir, 'lv_color_test.mat.npy')) 155 | 156 | FLAGS.num_test_batches = data_loader.test_img_num 157 | 158 | #Inputs 159 | lf = layer_factory() 160 | input_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size_mdn, FLAGS.feats_height, \ 161 | FLAGS.feats_width, FLAGS.feats_nch]) 162 | output_gt_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size_mdn, FLAGS.hidden_size]) 163 | is_training = tf.placeholder(tf.bool) 164 | keep_prob = tf.placeholder(tf.float32) 165 | 166 | #Inference 167 | with tf.variable_scope('Inference', reuse=False): 168 | output_activ = cnn_feedforward(lf, input_tensor, is_training, keep_prob, reuse=False) 169 | 170 | with tf.variable_scope('Inference', reuse=True): 171 | output_test_activ = cnn_feedforward(lf, input_tensor, is_training, keep_prob, reuse=True) 172 | 173 | #Loss and gradient descent step 174 | loss, _, _, _ = compute_gmm_loss(output_gt_tensor, output_activ, summ=True) 175 | loss_test, pi_test, mu_test, sigma_test = compute_gmm_loss(output_gt_tensor, output_test_activ) 176 | 177 | train_step = optimize(loss, FLAGS.lr) 178 | 179 | #Standard steps 180 | check_nan_op = tf.add_check_numerics_ops() 181 | init = tf.global_variables_initializer() 182 | saver = tf.train.Saver(max_to_keep=0) 183 | summary_op = tf.summary.merge_all() 184 | 185 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9) 186 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 187 | 188 | sess.run(init) 189 | 190 | load_chkpt(saver, sess, 'data/imagenet_models_mdn/') 191 | 192 | data_loader.reset() 193 | lv_test_codes = np.zeros((0, (FLAGS.hidden_size+1+1)*FLAGS.nmix), dtype='f') 194 | for i in range(FLAGS.num_test_batches): 195 | batch, batch_gt = data_loader.test_next_batch(FLAGS.batch_size_mdn) 196 | feed_dict = {input_tensor:batch, output_gt_tensor:batch_gt, \ 197 | is_training:False, keep_prob:1.} 198 | _, output_pi, output_mu, output_sigma = \ 199 | sess.run([check_nan_op, pi_test, mu_test, sigma_test], feed_dict) 200 | output = np.concatenate((output_mu, output_sigma, output_pi), axis=1) 201 | lv_test_codes = np.concatenate((lv_test_codes, output), 
axis=0) 202 | 203 | np.save(os.path.join(FLAGS.in_lvdir, 'lv_color_mdn_test.mat'), lv_test_codes) 204 | print(lv_test_codes.shape) 205 | 206 | sess.close() 207 | 208 | return lv_test_codes 209 | 210 | if __name__=='__main__': 211 | save_mdn_gmm(sys.argv[1]) 212 | 213 | 214 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow_gpu==1.1.0 2 | scipy==0.19.0 3 | scikit_image==0.13.0 4 | matplotlib==2.0.1 5 | numpy==1.12.1 6 | Pillow==4.1.1 7 | protobuf==3.3.0 8 | scikit_learn==0.18.1 9 | tensorflow==1.1.0 10 | -------------------------------------------------------------------------------- /run_demo.sh: -------------------------------------------------------------------------------- 1 | python demo.py --is_only_data=True 2 | python mdn/save_mdn_gmm.py data/testimgs/ 3 | python demo.py --is_only_data=False 4 | -------------------------------------------------------------------------------- /run_lfw.sh: -------------------------------------------------------------------------------- 1 | python vae/train.py lfw 2 | python mdn/mdn.py lfw 3 | python vae/test.py lfw 4 | -------------------------------------------------------------------------------- /third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/third_party/__init__.py -------------------------------------------------------------------------------- /third_party/save_zhang_feats.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"]="1" 3 | import sys 4 | 5 | 6 | import numpy as np 7 | import caffe 8 | import skimage.color as color 9 | import skimage.io 10 | import scipy.ndimage.interpolation as sni 11 | import cv2 12 | 13 | def save_zhang_feats(img_fns, ext='JPEG'): 14 | 15 | gpu_id = 0 16 | caffe.set_mode_gpu() 17 | caffe.set_device(gpu_id) 18 | net = caffe.Net('third_party/colorization/models/colorization_deploy_v1.prototxt', \ 19 | 'third_party/colorization/models/colorization_release_v1.caffemodel', caffe.TEST) 20 | 21 | (H_in,W_in) = net.blobs['data_l'].data.shape[2:] # get input shape 22 | (H_out,W_out) = net.blobs['class8_ab'].data.shape[2:] # get output shape 23 | net.blobs['Trecip'].data[...] 
= 6/np.log(10) # 1/T, set annealing temperature 24 | 25 | feats_fns = [] 26 | for img_fn_i, img_fn in enumerate(img_fns): 27 | 28 | # load the original image 29 | img_rgb = caffe.io.load_image(img_fn) 30 | img_lab = color.rgb2lab(img_rgb) # convert image to lab color space 31 | img_l = img_lab[:,:,0] # pull out L channel 32 | (H_orig,W_orig) = img_rgb.shape[:2] # original image size 33 | 34 | # create grayscale version of image (just for displaying) 35 | img_lab_bw = img_lab.copy() 36 | img_lab_bw[:,:,1:] = 0 37 | img_rgb_bw = color.lab2rgb(img_lab_bw) 38 | 39 | # resize image to network input size 40 | img_rs = caffe.io.resize_image(img_rgb,(H_in,W_in)) # resize image to network input size 41 | img_lab_rs = color.rgb2lab(img_rs) 42 | img_l_rs = img_lab_rs[:,:,0] 43 | 44 | net.blobs['data_l'].data[0,0,:,:] = img_l_rs-50 # subtract 50 for mean-centering 45 | net.forward() # run network 46 | 47 | npz_fn = img_fn.replace(ext, 'npz') 48 | np.savez_compressed(npz_fn, net.blobs['conv7_3'].data) 49 | feats_fns.append(npz_fn) 50 | 51 | return feats_fns 52 | -------------------------------------------------------------------------------- /vae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/vae/__init__.py -------------------------------------------------------------------------------- /vae/arch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/vae/arch/__init__.py -------------------------------------------------------------------------------- /vae/arch/layer_factory.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tensorflow.python.framework import tensor_shape 4 | 5 | class layer_factory: 6 | 7 | def __init__(self): 8 | pass 9 | 10 | def weight_variable(self, name, shape=None, mean=0., stddev=.001, gain=np.sqrt(2)): 11 | if(shape == None): 12 | return tf.get_variable(name) 13 | # #Adaptive initialize based on variable shape 14 | # if(len(shape) == 4): 15 | # stddev = (1.0 * gain) / np.sqrt(shape[0] * shape[1] * shape[3]) 16 | # else: 17 | # stddev = (1.0 * gain) / np.sqrt(shape[0]) 18 | return tf.get_variable(name, shape=shape, initializer=tf.random_normal_initializer(mean=mean, stddev=stddev)) 19 | 20 | def bias_variable(self, name, shape=None, constval=.001): 21 | if(shape == None): 22 | return tf.get_variable(name) 23 | return tf.get_variable(name, shape=shape, initializer=tf.constant_initializer(constval)) 24 | 25 | def conv2d(self, x, W, stride=1, padding='SAME'): 26 | return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding=padding) 27 | 28 | def lrelu(self, x, leak=.2): 29 | return tf.maximum(x, leak*x) 30 | 31 | def batch_norm_aiuiuc_wrapper(self, x, train_phase, name, reuse_vars): 32 | output = tf.contrib.layers.batch_norm(x, \ 33 | decay=.99, \ 34 | is_training=train_phase, \ 35 | scale=True, \ 36 | epsilon=1e-4, \ 37 | updates_collections=None,\ 38 | scope=name,\ 39 | reuse=reuse_vars) 40 | return output 41 | -------------------------------------------------------------------------------- /vae/arch/network.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import numpy as np 4 | from sklearn.cluster import KMeans 5 | 6 | 
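# network wraps a vae model and a data loader: __build_graph wires the
# placeholders, inference ops and loss; train_vae runs the training epochs and
# returns per-image latent (mean, stddev) codes; run_cvae decodes random latents
# and clusters them with KMeans; run_divcolor decodes the top-k MDN mixture means.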
class network: 7 | 8 | def __init__(self, model, data_loader, nch, flags): 9 | self.model = model 10 | self.data_loader = data_loader 11 | self.nch = nch 12 | self.flags = flags 13 | self.__build_graph() 14 | 15 | def __build_graph(self): 16 | self.plhold_img, self.plhold_greylevel, self.plhold_latent, self.plhold_is_training, \ 17 | self.plhold_keep_prob, self.plhold_kl_weight, self.plhold_lossweights \ 18 | = self.model.inputs() 19 | 20 | #inference graph 21 | self.op_mean, self.op_stddev, self.op_vae, \ 22 | self.op_mean_test, self.op_stddev_test, self.op_vae_test, \ 23 | self.op_vae_condinference \ 24 | = self.model.inference(self.plhold_img, self.plhold_greylevel, \ 25 | self.plhold_latent, self.plhold_is_training, self.plhold_keep_prob) 26 | 27 | #loss function and gd step for vae 28 | self.loss = self.model.loss(self.plhold_img, self.op_vae, self.op_mean, \ 29 | self.op_stddev, self.plhold_kl_weight, self.plhold_lossweights) 30 | self.train_step = self.model.optimize(self.loss, epsilon=1e-6) 31 | 32 | #standard steps 33 | self.check_nan_op = tf.add_check_numerics_ops() 34 | self.init = tf.global_variables_initializer() 35 | self.saver = tf.train.Saver(max_to_keep=0) 36 | self.summary_op = tf.summary.merge_all() 37 | 38 | def train_vae(self, chkptdir, is_train=True): 39 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95) 40 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 41 | num_train_batches = np.int_(np.floor((self.data_loader.train_img_num*1.)/\ 42 | self.flags.batch_size)) 43 | if(is_train == True): 44 | sess.run(self.init) 45 | print('[DEBUG] Saving TensorBoard summaries to: %s/logs' % self.flags.out_dir) 46 | self.train_writer = tf.summary.FileWriter(os.path.join(self.flags.out_dir, 'logs'), sess.graph) 47 | 48 | #Train vae 49 | for epoch in range(self.flags.max_epoch_vae): 50 | epoch_loss = self.run_vae_epoch_train(epoch, sess) 51 | epoch_loss = (epoch_loss*1.) / (self.flags.updates_per_epoch) 52 | print('[DEBUG] ####### Train VAE Epoch#%d, Loss %f #######' % (epoch, epoch_loss)) 53 | self.__save_chkpt(epoch, sess, chkptdir, prefix='model_vae') 54 | else: 55 | self.__load_chkpt(sess, chkptdir) 56 | 57 | epoch_latentvars_train = self.run_vae_epoch_test(self.flags.max_epoch_vae,\ 58 | sess, num_batches=num_train_batches, is_train=True) 59 | 60 | epoch_latentvars_test = self.run_vae_epoch_test(2, sess, num_batches=3, is_train=False) 61 | 62 | sess.close() 63 | return epoch_latentvars_train, epoch_latentvars_test 64 | 65 | def run_vae_epoch_train(self, epoch, sess): 66 | epoch_loss = 0. 67 | self.data_loader.random_reset() 68 | delta_kl_weight = (1e-2*1.)/(self.flags.max_epoch_vae*1.) 
#CHANGED WEIGHT SCHEDULE 69 | latent_feed = np.zeros((self.flags.batch_size, self.flags.hidden_size), dtype='f') 70 | for i in range(self.flags.updates_per_epoch): 71 | kl_weight = delta_kl_weight*(epoch) 72 | batch, batch_recon_const, batch_lossweights, batch_recon_const_outres = \ 73 | self.data_loader.train_next_batch(self.flags.batch_size, self.nch) 74 | feed_dict = {self.plhold_img: batch, self.plhold_is_training:True, \ 75 | self.plhold_keep_prob:.7, \ 76 | self.plhold_kl_weight:kl_weight, \ 77 | self.plhold_latent:latent_feed, \ 78 | self.plhold_lossweights:batch_lossweights, \ 79 | self.plhold_greylevel:batch_recon_const} 80 | try: 81 | _, _, loss_value, output, summary_str = sess.run(\ 82 | [self.check_nan_op, self.train_step, self.loss, \ 83 | self.op_vae, self.summary_op], feed_dict) 84 | except: 85 | raise NameError('[ERROR] Found nan values in run_vae_epoch_train') 86 | self.train_writer.add_summary(summary_str, epoch*self.flags.updates_per_epoch+i) 87 | if(i % self.flags.log_interval == 0): 88 | self.data_loader.save_output_with_gt(output, batch, epoch, i, \ 89 | '%02d_train_vae' % self.nch, self.flags.batch_size, \ 90 | num_cols=8, net_recon_const=batch_recon_const_outres) 91 | epoch_loss += loss_value 92 | return epoch_loss 93 | 94 | def run_vae_epoch_test(self, epoch, sess, num_batches=3, is_train=False): 95 | self.data_loader.reset() 96 | kl_weight = 0. 97 | latentvars_epoch = np.zeros((0, 2*self.flags.hidden_size), dtype='f') 98 | latent_feed = np.zeros((self.flags.batch_size, self.flags.hidden_size), dtype='f') 99 | for i in range(num_batches): 100 | if(is_train == False): 101 | batch, batch_recon_const, batch_recon_const_outres, _ = \ 102 | self.data_loader.test_next_batch(self.flags.batch_size, self.nch) 103 | else: 104 | batch, batch_recon_const, _, batch_recon_const_outres = \ 105 | self.data_loader.train_next_batch(self.flags.batch_size, self.nch) 106 | 107 | batch_lossweights = np.ones((self.flags.batch_size, \ 108 | self.nch*self.flags.img_height*self.flags.img_width), dtype='f') 109 | 110 | feed_dict = {self.plhold_img: batch, self.plhold_is_training:False, \ 111 | self.plhold_keep_prob:1., \ 112 | self.plhold_kl_weight:kl_weight, \ 113 | self.plhold_latent:latent_feed, \ 114 | self.plhold_lossweights: batch_lossweights, \ 115 | self.plhold_greylevel:batch_recon_const} 116 | try: 117 | _, means_batch, stddevs_batch, output = sess.run(\ 118 | [self.check_nan_op, self.op_mean_test, self.op_stddev_test, \ 119 | self.op_vae_test], feed_dict) 120 | except: 121 | raise NameError('[ERROR] Found nan values in run_vae_epoch_test') 122 | if(is_train == False): 123 | self.data_loader.save_output_with_gt(output, batch, epoch, i, \ 124 | '%02d_test_vae' % self.nch, self.flags.batch_size, num_cols=8, \ 125 | net_recon_const=batch_recon_const_outres) 126 | else: 127 | if(i % self.flags.log_interval == 0): 128 | self.data_loader.save_output_with_gt(output, batch, epoch, i, \ 129 | '%02d_latentvar' % self.nch, self.flags.batch_size, num_cols=8, \ 130 | net_recon_const=batch_recon_const_outres) 131 | 132 | latentvars_epoch = np.concatenate((latentvars_epoch, \ 133 | np.concatenate((means_batch, stddevs_batch), axis=1)), axis=0) 134 | return latentvars_epoch 135 | 136 | def run_cvae(self, chkptdir, latentvars, num_batches=3, num_repeat=8, num_cluster=5): 137 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95) 138 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 139 | self.__load_chkpt(sess, chkptdir) 140 | self.data_loader.reset() 141 | kl_weight 
= 0. 142 | for i in range(num_batches): 143 | print ('[DEBUG] Batch %d (/ %d)' % (i, num_batches)) 144 | batch, batch_recon_const, batch_recon_const_outres, batch_imgnames = \ 145 | self.data_loader.test_next_batch(self.flags.batch_size, self.nch) 146 | batch_lossweights = np.ones((self.flags.batch_size, \ 147 | self.nch*self.flags.img_height*self.flags.img_width), dtype='f') 148 | for j in range(self.flags.batch_size): 149 | output_all = np.zeros((0, self.nch*self.flags.img_height*self.flags.img_width), dtype='f') 150 | print ('[DEBUG] Batch %d (/ %d), Img %d' % (i, num_batches, j)) 151 | for k in range(num_repeat): 152 | imgid = i*self.flags.batch_size+j 153 | batch_1 = np.tile(batch[j, ...], (self.flags.batch_size, 1)) 154 | batch_recon_const_1 = np.tile(batch_recon_const[j, ...], (self.flags.batch_size, 1)) 155 | batch_recon_const_outres_1 = np.tile(batch_recon_const_outres[j, ...], (self.flags.batch_size, 1)) 156 | 157 | latent_feed = np.random.normal(loc=0., scale=1., \ 158 | size=(self.flags.batch_size, self.flags.hidden_size)) 159 | #latent_feed = latentvars[imgid*self.flags.batch_size:(imgid+1)*self.flags.batch_size, ...] 160 | feed_dict = {self.plhold_img:batch_1, self.plhold_is_training:False, \ 161 | self.plhold_keep_prob:1., \ 162 | self.plhold_kl_weight:kl_weight, \ 163 | self.plhold_latent:latent_feed, \ 164 | self.plhold_lossweights: batch_lossweights, 165 | self.plhold_greylevel:batch_recon_const_1} 166 | try: 167 | _, output = sess.run(\ 168 | [self.check_nan_op, self.op_vae_condinference], \ 169 | feed_dict) 170 | except: 171 | raise NameError('[ERROR] Found nan values in condinference_vae') 172 | output_all = np.concatenate((output_all, output), axis=0) 173 | 174 | print ('[DEBUG] Clustering %d predictions' % output_all.shape[0]) 175 | kmeans = KMeans(n_clusters=num_cluster, random_state=0).fit(output_all) 176 | output_clust = kmeans.cluster_centers_ 177 | self.data_loader.save_divcolor(output_clust, batch_1[:num_cluster], i, j, \ 178 | 'cvae', num_cluster, batch_imgnames[j], num_cols=8, \ 179 | net_recon_const=batch_recon_const_outres_1[:num_cluster]) 180 | 181 | sess.close() 182 | 183 | def run_divcolor(self, chkptdir, latentvars, num_batches=3, topk=8): 184 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95) 185 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 186 | self.__load_chkpt(sess, chkptdir) 187 | self.data_loader.reset() 188 | kl_weight = 0. 
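# latentvars holds the saved MDN output per image, laid out as
# [nmix*hidden_size means | nmix sigmas | nmix mixture weights]; the means are
# sorted by decreasing mixture weight and fed to the decoder as latent codes.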
189 | nmix = topk 190 | for i in range(num_batches): 191 | batch, batch_recon_const, batch_recon_const_outres, batch_imgnames = \ 192 | self.data_loader.test_next_batch(self.flags.batch_size, self.nch) 193 | batch_lossweights = np.ones((self.flags.batch_size, \ 194 | self.nch*self.flags.img_height*self.flags.img_width), dtype='f') 195 | output_all = np.zeros((0, \ 196 | self.nch*self.flags.img_height*self.flags.img_width), dtype='f') 197 | for j in range(self.flags.batch_size): 198 | imgid = i*self.flags.batch_size+j 199 | print ('[DEBUG] Running divcolor on Image %d (/%d)' % \ 200 | (imgid, self.data_loader.test_img_num)) 201 | if(imgid >= self.data_loader.test_img_num): 202 | break 203 | batch_1 = np.tile(batch[j, ...], (self.flags.batch_size, 1)) 204 | batch_recon_const_1 = np.tile(batch_recon_const[j, ...], (self.flags.batch_size, 1)) 205 | batch_recon_const_outres_1 = np.tile(batch_recon_const_outres[j, ...], (self.flags.batch_size, 1)) 206 | curr_means = latentvars[imgid, :self.flags.hidden_size*nmix].reshape(nmix, self.flags.hidden_size) 207 | curr_sigma = latentvars[imgid, self.flags.hidden_size*nmix:(self.flags.hidden_size+1)*nmix].reshape(-1) 208 | curr_pi = latentvars[imgid, (self.flags.hidden_size+1)*nmix:].reshape(-1) 209 | selectid = np.argsort(-1*curr_pi) 210 | latent_feed = np.tile(curr_means[selectid, ...], (np.int_(np.round((self.flags.batch_size*1.)/nmix)), 1)) 211 | 212 | feed_dict = {self.plhold_img:batch_1, self.plhold_is_training:False, \ 213 | self.plhold_keep_prob:1., \ 214 | self.plhold_kl_weight:kl_weight, \ 215 | self.plhold_latent:latent_feed, \ 216 | self.plhold_lossweights: batch_lossweights, 217 | self.plhold_greylevel:batch_recon_const_1} 218 | 219 | _, output = sess.run(\ 220 | [self.check_nan_op, self.op_vae_condinference], \ 221 | feed_dict) 222 | 223 | self.data_loader.save_divcolor(output[:topk], batch_1[:topk], i, j, \ 224 | 'divcolor', topk, batch_imgnames[j], num_cols=8, \ 225 | net_recon_const=batch_recon_const_outres_1[:topk, ...]) 226 | 227 | sess.close() 228 | 229 | 230 | def __save_chkpt(self, epoch, sess, chkptdir, prefix='model'): 231 | if not os.path.exists(chkptdir): 232 | os.makedirs(chkptdir) 233 | save_path = self.saver.save(sess, "%s/%s_%06d.ckpt" % (chkptdir, prefix, epoch)) 234 | print("[DEBUG] ############ Model saved in file: %s ################" % save_path) 235 | 236 | def __load_chkpt(self, sess, chkptdir): 237 | ckpt = tf.train.get_checkpoint_state(chkptdir) 238 | if ckpt and ckpt.model_checkpoint_path: 239 | ckpt_fn = ckpt.model_checkpoint_path.replace('//', '/') 240 | print('[DEBUG] Loading checkpoint from %s' % ckpt_fn) 241 | self.saver.restore(sess, ckpt_fn) 242 | else: 243 | raise NameError('[ERROR] No checkpoint found at: %s' % chkptdir) 244 | -------------------------------------------------------------------------------- /vae/arch/vae_skipconn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from layer_factory import layer_factory 3 | from tensorflow.python.framework import tensor_shape 4 | 5 | class vae_skipconn: 6 | 7 | def __init__(self, flags, nch=2, condinference_flag=False): 8 | self.flags = flags 9 | self.nch = nch 10 | self.layer_factory = layer_factory() 11 | self.condinference_flag = condinference_flag 12 | 13 | #Returns handles to input placeholders 14 | def inputs(self): 15 | inp_img = tf.placeholder(tf.float32, [self.flags.batch_size, \ 16 | self.nch * self.flags.img_height * self.flags.img_width]) 17 | inp_greylevel = 
tf.placeholder(tf.float32, [self.flags.batch_size, \ 18 | self.flags.img_height * self.flags.img_width]) 19 | inp_latent = tf.placeholder(tf.float32, [self.flags.batch_size, \ 20 | self.flags.hidden_size]) 21 | is_training = tf.placeholder(tf.bool) 22 | keep_prob = tf.placeholder(tf.float32) 23 | kl_weight = tf.placeholder(tf.float32) 24 | lossweights = tf.placeholder(tf.float32, [self.flags.batch_size, \ 25 | self.nch * self.flags.img_height * self.flags.img_width]) 26 | return inp_img, inp_greylevel, inp_latent, is_training, keep_prob, kl_weight, lossweights 27 | 28 | #Takes input placeholders, builds inference graph and returns net. outputs 29 | def inference(self, inp_img, inp_greylevel, inp_latent, is_training, keep_prob): 30 | 31 | with tf.variable_scope('Inference', reuse=False) as sc: 32 | gfeat32, gfeat16, gfeat8, gfeat4 = \ 33 | self.__cond_encoder(sc, inp_greylevel, is_training, keep_prob, \ 34 | in_nch=1, reuse=False) 35 | z1_train = self.__encoder(sc, inp_img, is_training, keep_prob, \ 36 | in_nch=self.nch, reuse=False) 37 | epsilon_train = tf.truncated_normal([self.flags.batch_size, self.flags.hidden_size]) 38 | mean_train = z1_train[:, :self.flags.hidden_size] 39 | stddev_train = tf.sqrt(tf.exp(z1_train[:, self.flags.hidden_size:])) 40 | z1_sample = mean_train + epsilon_train * stddev_train 41 | output_train = self.__decoder(sc, gfeat32, gfeat16, gfeat8, gfeat4, \ 42 | is_training, inp_greylevel, z1_sample, reuse=False) 43 | 44 | with tf.variable_scope('Inference', reuse=True) as sc: 45 | gfeat32, gfeat16, gfeat8, gfeat4 = \ 46 | self.__cond_encoder(sc, inp_greylevel, is_training, keep_prob, \ 47 | in_nch=1, reuse=True) 48 | 49 | if(self.condinference_flag == False): 50 | z1_test = self.__encoder(sc, inp_img, is_training, keep_prob, \ 51 | in_nch=self.nch, reuse=True) 52 | epsilon_test = tf.truncated_normal([self.flags.batch_size, self.flags.hidden_size]) 53 | mean_test = z1_test[:, :self.flags.hidden_size] 54 | stddev_test = tf.sqrt(tf.exp(z1_test[:, self.flags.hidden_size:])) 55 | z1_sample = mean_test + epsilon_test * stddev_test 56 | tf.stop_gradient(z1_sample) #Fix the encoder 57 | output_test = self.__decoder(sc, gfeat32, gfeat16, gfeat8, gfeat4, \ 58 | is_training, inp_greylevel, z1_sample, reuse=True) 59 | output_condinference = None 60 | else: 61 | mean_test = None 62 | stddev_test = None 63 | output_test = None 64 | output_condinference = self.__decoder(sc, gfeat32, gfeat16, gfeat8,\ 65 | gfeat4, is_training, inp_greylevel, inp_latent, reuse=True) 66 | 67 | return mean_train, stddev_train, output_train, mean_test, stddev_test, \ 68 | output_test, output_condinference 69 | 70 | #Takes net. 
outputs and computes loss for vae(enc+dec) 71 | def loss(self, target_tensor, op_tensor, mean, stddev, kl_weight, lossweights, epsilon=1e-6): 72 | 73 | kl_loss = tf.reduce_sum(0.5 * (tf.square(mean) + tf.square(stddev) \ 74 | - tf.log(tf.maximum(tf.square(stddev), epsilon)) - 1.0)) 75 | 76 | recon_loss = tf.reduce_mean(tf.sqrt(tf.reduce_sum( \ 77 | lossweights*tf.square(target_tensor-op_tensor), 1)), 0) 78 | 79 | recon_loss_l2 = tf.reduce_mean(tf.sqrt(tf.reduce_sum( \ 80 | tf.square(target_tensor-op_tensor), 1)), 0) 81 | 82 | if(self.nch == 2): 83 | target_tensor2d = tf.reshape(target_tensor, [self.flags.batch_size, \ 84 | self.flags.img_height, self.flags.img_width, self.nch]) 85 | op_tensor2d = tf.reshape(op_tensor, [self.flags.batch_size, \ 86 | self.flags.img_height, self.flags.img_width, self.nch]) 87 | [n,w,h,c] = target_tensor2d.get_shape().as_list() 88 | dv = tf.square((target_tensor2d[:,1:,:h-1,:] - target_tensor2d[:,:w-1,:h-1,:]) 89 | - (op_tensor2d[:,1:,:h-1,:] - op_tensor2d[:,:w-1,:h-1,:])) 90 | dh = tf.square((target_tensor2d[:,:w-1,1:,:] - target_tensor2d[:,:w-1,:h-1,:]) 91 | - (op_tensor2d[:,:w-1,1:,:] - op_tensor2d[:,:w-1,:h-1,:])) 92 | grad_loss = tf.reduce_mean(tf.sqrt(tf.reduce_sum(dv+dh,[1,2,3]))) 93 | recon_loss = recon_loss + (1e-3)*grad_loss 94 | 95 | loss = kl_weight*kl_loss + recon_loss 96 | tf.summary.scalar('kl_loss', kl_loss) 97 | tf.summary.scalar('grad_loss', grad_loss) 98 | tf.summary.scalar('recon_loss', recon_loss) 99 | tf.summary.scalar('recon_loss_l2', recon_loss_l2) 100 | tf.summary.scalar('loss', loss) 101 | return loss 102 | 103 | #Takes loss and returns GD train step 104 | def optimize(self, loss, epsilon): 105 | train_step = tf.train.AdamOptimizer(self.flags.lr_vae, epsilon=epsilon).minimize(loss) 106 | return train_step 107 | 108 | def __cond_encoder(self, scope, input_tensor, bn_is_training, keep_prob, in_nch=1, reuse=False): 109 | 110 | lf = self.layer_factory 111 | input_tensor2d = tf.reshape(input_tensor, [self.flags.batch_size, \ 112 | self.flags.img_height, self.flags.img_width, 1]) 113 | nch = tensor_shape.as_dimension(input_tensor2d.get_shape()[3]).value 114 | nout = self.flags.hidden_size 115 | 116 | if(reuse == False): 117 | W_conv1 = lf.weight_variable(name='W_conv1_cond', shape=[5, 5, nch, 128]) 118 | W_conv2 = lf.weight_variable(name='W_conv2_cond', shape=[5, 5, 128, 256]) 119 | W_conv3 = lf.weight_variable(name='W_conv3_cond', shape=[5, 5, 256, 512]) 120 | W_conv4 = lf.weight_variable(name='W_conv4_cond', shape=[4, 4, 512, self.flags.hidden_size]) 121 | 122 | b_conv1 = lf.bias_variable(name='b_conv1_cond', shape=[128]) 123 | b_conv2 = lf.bias_variable(name='b_conv2_cond', shape=[256]) 124 | b_conv3 = lf.bias_variable(name='b_conv3_cond', shape=[512]) 125 | b_conv4 = lf.bias_variable(name='b_conv4_cond', shape=[self.flags.hidden_size]) 126 | else: 127 | W_conv1 = lf.weight_variable(name='W_conv1_cond') 128 | W_conv2 = lf.weight_variable(name='W_conv2_cond') 129 | W_conv3 = lf.weight_variable(name='W_conv3_cond') 130 | W_conv4 = lf.weight_variable(name='W_conv4_cond') 131 | 132 | b_conv1 = lf.bias_variable(name='b_conv1_cond') 133 | b_conv2 = lf.bias_variable(name='b_conv2_cond') 134 | b_conv3 = lf.bias_variable(name='b_conv3_cond') 135 | b_conv4 = lf.bias_variable(name='b_conv4_cond') 136 | 137 | conv1 = tf.nn.relu(lf.conv2d(input_tensor2d, W_conv1, stride=2) + b_conv1) 138 | conv1_norm = lf.batch_norm_aiuiuc_wrapper(conv1, bn_is_training, \ 139 | 'BN1_cond', reuse_vars=reuse) 140 | 141 | conv2 = tf.nn.relu(lf.conv2d(conv1_norm, 
W_conv2, stride=2) + b_conv2) 142 | conv2_norm = lf.batch_norm_aiuiuc_wrapper(conv2, bn_is_training, \ 143 | 'BN2_cond', reuse_vars=reuse) 144 | 145 | conv3 = tf.nn.relu(lf.conv2d(conv2_norm, W_conv3, stride=2) + b_conv3) 146 | conv3_norm = lf.batch_norm_aiuiuc_wrapper(conv3, bn_is_training, \ 147 | 'BN3_cond', reuse_vars=reuse) 148 | 149 | conv4 = tf.nn.relu(lf.conv2d(conv3_norm, W_conv4, stride=2) + b_conv4) 150 | conv4_norm = lf.batch_norm_aiuiuc_wrapper(conv4, bn_is_training, \ 151 | 'BN4_cond', reuse_vars=reuse) 152 | 153 | return conv1_norm, conv2_norm, conv3_norm, conv4_norm 154 | 155 | def __encoder(self, scope, input_tensor, bn_is_training, keep_prob, in_nch=2, reuse=False): 156 | 157 | lf = self.layer_factory 158 | 159 | input_tensor2d = tf.reshape(input_tensor, [self.flags.batch_size, \ 160 | self.flags.img_height, self.flags.img_width, in_nch]) 161 | 162 | nch = tensor_shape.as_dimension(input_tensor2d.get_shape()[3]).value 163 | 164 | if(reuse==False): 165 | W_conv1 = lf.weight_variable(name='W_conv1', shape=[5, 5, nch, 128]) 166 | W_conv2 = lf.weight_variable(name='W_conv2', shape=[5, 5, 128, 256]) 167 | W_conv3 = lf.weight_variable(name='W_conv3', shape=[5, 5, 256, 512]) 168 | W_conv4 = lf.weight_variable(name='W_conv4', shape=[4, 4, 512, 1024]) 169 | W_fc1 = lf.weight_variable(name='W_fc1', shape=[4*4*1024, self.flags.hidden_size * 2]) 170 | 171 | b_conv1 = lf.bias_variable(name='b_conv1', shape=[128]) 172 | b_conv2 = lf.bias_variable(name='b_conv2', shape=[256]) 173 | b_conv3 = lf.bias_variable(name='b_conv3', shape=[512]) 174 | b_conv4 = lf.bias_variable(name='b_conv4', shape=[1024]) 175 | b_fc1 = lf.bias_variable(name='b_fc1', shape=[self.flags.hidden_size * 2]) 176 | else: 177 | W_conv1 = lf.weight_variable(name='W_conv1') 178 | W_conv2 = lf.weight_variable(name='W_conv2') 179 | W_conv3 = lf.weight_variable(name='W_conv3') 180 | W_conv4 = lf.weight_variable(name='W_conv4') 181 | W_fc1 = lf.weight_variable(name='W_fc1') 182 | 183 | b_conv1 = lf.bias_variable(name='b_conv1') 184 | b_conv2 = lf.bias_variable(name='b_conv2') 185 | b_conv3 = lf.bias_variable(name='b_conv3') 186 | b_conv4 = lf.bias_variable(name='b_conv4') 187 | b_fc1 = lf.bias_variable(name='b_fc1') 188 | 189 | conv1 = tf.nn.relu(lf.conv2d(input_tensor2d, W_conv1, stride=2) + b_conv1) 190 | conv1_norm = lf.batch_norm_aiuiuc_wrapper(conv1, bn_is_training, \ 191 | 'BN1', reuse_vars=reuse) 192 | 193 | conv2 = tf.nn.relu(lf.conv2d(conv1_norm, W_conv2, stride=2) + b_conv2) 194 | conv2_norm = lf.batch_norm_aiuiuc_wrapper(conv2, bn_is_training, \ 195 | 'BN2', reuse_vars=reuse) 196 | 197 | conv3 = tf.nn.relu(lf.conv2d(conv2_norm, W_conv3, stride=2) + b_conv3) 198 | conv3_norm = lf.batch_norm_aiuiuc_wrapper(conv3, bn_is_training, \ 199 | 'BN3', reuse_vars=reuse) 200 | 201 | conv4 = tf.nn.relu(lf.conv2d(conv3_norm, W_conv4, stride=2) + b_conv4) 202 | conv4_norm = lf.batch_norm_aiuiuc_wrapper(conv4, bn_is_training, \ 203 | 'BN4', reuse_vars=reuse) 204 | 205 | dropout1 = tf.nn.dropout(conv4_norm, keep_prob) 206 | flatten1 = tf.reshape(dropout1, [-1, 4*4*1024]) 207 | 208 | fc1 = tf.matmul(flatten1, W_fc1)+b_fc1 209 | 210 | return fc1 211 | 212 | def __decoder(self, scope, gfeat32, gfeat16, gfeat8, gfeat4, bn_is_training, inp_greylevel,\ 213 | z1_sample, reuse=False): 214 | 215 | lf = self.layer_factory 216 | 217 | if(reuse == False): 218 | W_deconv1 = lf.weight_variable(name='W_deconv1', shape=[4, 4, self.flags.hidden_size, 1024]) 219 | W_deconv2 = lf.weight_variable(name='W_deconv2', shape=[5, 5, 1024+512, 512]) 220 | 
W_deconv3 = lf.weight_variable(name='W_deconv3', shape=[5, 5, 512+256, 256]) 221 | W_deconv4 = lf.weight_variable(name='W_deconv4', shape=[5, 5, 256+128, 128]) 222 | W_deconv5 = lf.weight_variable(name='W_deconv5', shape=[5, 5, 128, self.nch]) 223 | 224 | b_deconv1 = lf.bias_variable(name='b_deconv1', shape=[1024]) 225 | b_deconv2 = lf.bias_variable(name='b_deconv2', shape=[512]) 226 | b_deconv3 = lf.bias_variable(name='b_deconv3', shape=[256]) 227 | b_deconv4 = lf.bias_variable(name='b_deconv4', shape=[128]) 228 | b_deconv5 = lf.bias_variable(name='b_deconv5', shape=[self.nch]) 229 | else: 230 | W_deconv1 = lf.weight_variable(name='W_deconv1') 231 | W_deconv2 = lf.weight_variable(name='W_deconv2') 232 | W_deconv3 = lf.weight_variable(name='W_deconv3') 233 | W_deconv4 = lf.weight_variable(name='W_deconv4') 234 | W_deconv5 = lf.weight_variable(name='W_deconv5') 235 | 236 | b_deconv1 = lf.bias_variable(name='b_deconv1') 237 | b_deconv2 = lf.bias_variable(name='b_deconv2') 238 | b_deconv3 = lf.bias_variable(name='b_deconv3') 239 | b_deconv4 = lf.bias_variable(name='b_deconv4') 240 | b_deconv5 = lf.bias_variable(name='b_deconv5') 241 | 242 | inp_greylevel2d = tf.reshape(inp_greylevel, [self.flags.batch_size, \ 243 | self.flags.img_height, self.flags.img_width, 1]) 244 | input2d = tf.reshape(z1_sample, [self.flags.batch_size, 1, 1, self.flags.hidden_size]) 245 | deconv1_upsamp = tf.image.resize_images(input2d, [4, 4]) 246 | 247 | deconv1_upsamp_sc = tf.multiply(deconv1_upsamp, gfeat4) 248 | 249 | deconv1 = tf.nn.relu(lf.conv2d(deconv1_upsamp_sc, W_deconv1, stride=1) + b_deconv1) 250 | deconv1_norm = lf.batch_norm_aiuiuc_wrapper(deconv1, bn_is_training, \ 251 | 'BN_deconv1', reuse_vars=reuse) 252 | 253 | deconv2_upsamp = tf.image.resize_images(deconv1_norm, [8, 8]) 254 | deconv2_upsamp_sc = tf.concat([deconv2_upsamp, gfeat8], 3) 255 | deconv2 = tf.nn.relu(lf.conv2d(deconv2_upsamp_sc, W_deconv2, stride=1) + b_deconv2) 256 | deconv2_norm = lf.batch_norm_aiuiuc_wrapper(deconv2, bn_is_training, \ 257 | 'BN_deconv2', reuse_vars=reuse) 258 | 259 | deconv3_upsamp = tf.image.resize_images(deconv2_norm, [16, 16]) 260 | deconv3_upsamp_sc = tf.concat([deconv3_upsamp, gfeat16], 3) 261 | deconv3 = tf.nn.relu(lf.conv2d(deconv3_upsamp_sc, W_deconv3, stride=1) + b_deconv3) 262 | deconv3_norm = lf.batch_norm_aiuiuc_wrapper(deconv3, bn_is_training, \ 263 | 'BN_deconv3', reuse_vars=reuse) 264 | 265 | deconv4_upsamp = tf.image.resize_images(deconv3_norm, [32, 32]) 266 | deconv4_upsamp_sc = tf.concat([deconv4_upsamp, gfeat32], 3) 267 | deconv4 = tf.nn.relu(lf.conv2d(deconv4_upsamp_sc, W_deconv4, stride=1) + b_deconv4) 268 | deconv4_norm = lf.batch_norm_aiuiuc_wrapper(deconv4, bn_is_training, \ 269 | 'BN_deconv4', reuse_vars=reuse) 270 | 271 | deconv5_upsamp = tf.image.resize_images(deconv4_norm, [64, 64]) 272 | deconv5 = lf.conv2d(deconv5_upsamp, W_deconv5, stride=1) + b_deconv5 273 | deconv5_norm = lf.batch_norm_aiuiuc_wrapper(deconv5, bn_is_training, \ 274 | 'BN_deconv5', reuse_vars=reuse) 275 | 276 | decoded_ch = tf.reshape(tf.tanh(deconv5_norm), \ 277 | [self.flags.batch_size, self.flags.img_height*self.flags.img_width*self.nch]) 278 | 279 | return decoded_ch 280 | -------------------------------------------------------------------------------- /vae/arch/vae_wo_skipconn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | from layer_factory import layer_factory 5 | from tensorflow.python.framework import 
tensor_shape 6 | 7 | class vae_wo_skipconn: 8 | 9 | def __init__(self, flags, nch=2, condinference_flag=False): 10 | self.flags = flags 11 | self.nch = nch 12 | self.layer_factory = layer_factory() 13 | self.condinference_flag = condinference_flag 14 | 15 | #Returns handles to input placeholders 16 | def inputs(self): 17 | inp_img = tf.placeholder(tf.float32, [self.flags.batch_size, \ 18 | self.nch * self.flags.img_height * self.flags.img_width]) 19 | inp_greylevel = tf.placeholder(tf.float32, [self.flags.batch_size, \ 20 | self.flags.img_height * self.flags.img_width]) 21 | inp_latent = tf.placeholder(tf.float32, [self.flags.batch_size, \ 22 | self.flags.hidden_size]) 23 | is_training = tf.placeholder(tf.bool) 24 | is_training_dec = tf.placeholder(tf.bool) 25 | keep_prob = tf.placeholder(tf.float32) 26 | kl_weight = tf.placeholder(tf.float32) 27 | lossweights = tf.placeholder(tf.float32, [self.flags.batch_size, \ 28 | self.nch * self.flags.img_height * self.flags.img_width]) 29 | 30 | return inp_img, inp_greylevel, inp_latent, is_training, keep_prob, kl_weight, lossweights 31 | 32 | #Takes input placeholders, builds inference graph and returns net. outputs 33 | def inference(self, inp_img, inp_greylevel, inp_latent, is_training, keep_prob): 34 | 35 | with tf.variable_scope('Inference', reuse=False) as sc: 36 | z1_train = self.__encoder(sc, inp_img, is_training, keep_prob, \ 37 | in_nch=self.nch, reuse=False) 38 | epsilon_train = tf.truncated_normal([self.flags.batch_size, self.flags.hidden_size]) 39 | mean_train = z1_train[:, :self.flags.hidden_size] 40 | stddev_train = tf.sqrt(tf.exp(z1_train[:, self.flags.hidden_size:])) 41 | z1_sample = mean_train + epsilon_train * stddev_train 42 | output_train = self.__decoder(sc, is_training, inp_greylevel, z1_sample, reuse=False) 43 | 44 | with tf.variable_scope('Inference', reuse=True) as sc: 45 | if(self.condinference_flag == False): 46 | z1_test = self.__encoder(sc, inp_img, is_training, keep_prob, \ 47 | in_nch=self.nch, reuse=True) 48 | epsilon_test = tf.truncated_normal([self.flags.batch_size, self.flags.hidden_size]) 49 | mean_test = z1_test[:, :self.flags.hidden_size] 50 | stddev_test = tf.sqrt(tf.exp(z1_test[:, self.flags.hidden_size:])) 51 | z1_sample = mean_test + epsilon_test * stddev_test 52 | tf.stop_gradient(z1_sample) #Fix the encoder 53 | output_test = self.__decoder(sc, is_training, inp_greylevel, z1_sample, reuse=True) 54 | output_condinference = None 55 | else: 56 | mean_test = None 57 | stddev_test = None 58 | output_test = None 59 | output_condinference = self.__decoder(sc, is_training, inp_greylevel, inp_latent,\ 60 | reuse=True) 61 | 62 | return mean_train, stddev_train, output_train, mean_test, stddev_test, \ 63 | output_test, output_condinference 64 | 65 | #Takes net. 
outputs and computes loss for vae(enc+dec) 66 | def loss(self, target_tensor, op_tensor, mean, stddev, kl_weight, lossweights, epsilon=1e-6, \ 67 | is_regression=True): 68 | 69 | kl_loss = tf.reduce_sum(0.5 * (tf.square(mean) + tf.square(stddev) \ 70 | - tf.log(tf.maximum(tf.square(stddev), epsilon)) - 1.0)) 71 | 72 | recon_loss_chi = tf.reduce_mean(tf.sqrt(tf.reduce_sum( \ 73 | lossweights*tf.square(target_tensor-op_tensor), 1)), 0) 74 | 75 | #Load Principle components 76 | np_pcvec = np.transpose(np.load(os.path.join(self.flags.pc_dir, 'components.mat.npy'))) 77 | np_pcvar = 1./np.load(os.path.join(self.flags.pc_dir, 'exp_variance.mat.npy')) 78 | np_pcvec = np_pcvec[:, :self.flags.pc_comp] 79 | np_pcvar = np_pcvar[:self.flags.pc_comp] 80 | pcvec = tf.constant(np_pcvec) 81 | pcvar = tf.constant(np_pcvar) 82 | 83 | projmat_op = tf.matmul(op_tensor, pcvec) 84 | projmat_target = tf.matmul(target_tensor, pcvec) 85 | weightmat = tf.tile(tf.reshape(pcvar, [1, self.flags.pc_comp]), [self.flags.batch_size, 1]) 86 | loss_topk_pc = tf.reduce_mean(tf.reduce_sum(\ 87 | tf.multiply(tf.square(projmat_op-projmat_target), weightmat), 1), 0) 88 | 89 | res_op = op_tensor 90 | res_target = target_tensor 91 | for npc in range(self.flags.pc_comp): 92 | pcvec_curr = tf.tile(tf.reshape(tf.transpose(pcvec[:, npc]), [1, -1]), \ 93 | [self.flags.batch_size, 1]) 94 | projop_curr = tf.tile(tf.reshape(projmat_op[:, npc], [self.flags.batch_size, 1]), \ 95 | [1, self.nch * self.flags.img_height * self.flags.img_width]) 96 | 97 | projtarget_curr = tf.tile(tf.reshape(projmat_target[:, npc], [self.flags.batch_size, 1]), \ 98 | [1, self.nch * self.flags.img_height * self.flags.img_width]) 99 | 100 | res_op = tf.subtract(res_op, tf.multiply(projop_curr, pcvec_curr)) 101 | res_target = tf.subtract(res_target, tf.multiply(projtarget_curr, pcvec_curr)) 102 | 103 | res_error = tf.reduce_sum(tf.square(res_op-res_target), 1) 104 | res_error_weight = tf.tile(tf.reshape(pcvar[self.flags.pc_comp-1], [1, 1]), [self.flags.batch_size, 1]) 105 | loss_res_pc = tf.reduce_mean(tf.multiply(\ 106 | tf.reshape(res_error, [self.flags.batch_size, 1]), res_error_weight)) 107 | 108 | recon_loss = recon_loss_chi + (1e-1)*(loss_topk_pc + loss_res_pc) 109 | 110 | if(self.nch == 2): 111 | target_tensor2d = tf.reshape(target_tensor, [self.flags.batch_size, \ 112 | self.flags.img_height, self.flags.img_width, self.nch]) 113 | op_tensor2d = tf.reshape(op_tensor, [self.flags.batch_size, \ 114 | self.flags.img_height, self.flags.img_width, self.nch]) 115 | [n,w,h,c] = target_tensor2d.get_shape().as_list() 116 | dv = tf.square((target_tensor2d[:,1:,:h-1,:] - target_tensor2d[:,:w-1,:h-1,:]) 117 | - (op_tensor2d[:,1:,:h-1,:] - op_tensor2d[:,:w-1,:h-1,:])) 118 | dh = tf.square((target_tensor2d[:,:w-1,1:,:] - target_tensor2d[:,:w-1,:h-1,:]) 119 | - (op_tensor2d[:,:w-1,1:,:] - op_tensor2d[:,:w-1,:h-1,:])) 120 | grad_loss = tf.reduce_mean(tf.sqrt(tf.reduce_sum(dv+dh,[1,2,3]))) 121 | recon_loss = recon_loss + (1e-3)*grad_loss 122 | 123 | 124 | loss = kl_weight*kl_loss + recon_loss 125 | 126 | tf.summary.scalar('grad_loss', grad_loss) 127 | tf.summary.scalar('kl_loss', kl_loss) 128 | tf.summary.scalar('recon_loss_chi', recon_loss_chi) 129 | tf.summary.scalar('recon_loss', recon_loss) 130 | return loss 131 | 132 | #Takes loss and returns GD train step 133 | def optimize(self, loss, epsilon): 134 | train_step = tf.train.AdamOptimizer(self.flags.lr_vae, epsilon=epsilon).minimize(loss) 135 | return train_step 136 | 137 | def __encoder(self, scope, input_tensor, 
bn_is_training, keep_prob, in_nch=1, reuse=False):
138 |
139 | lf = self.layer_factory
140 |
141 | input_tensor2d = tf.reshape(input_tensor, [self.flags.batch_size, \
142 | self.flags.img_height, self.flags.img_width, in_nch])
143 |
144 | if(self.nch == 1 and reuse==False):
145 | tf.summary.image('summ_input_tensor2d', input_tensor2d, max_outputs=10)
146 |
147 | nch = tensor_shape.as_dimension(input_tensor2d.get_shape()[3]).value
148 |
149 | if(reuse==False):
150 | W_conv1 = lf.weight_variable(name='W_conv1', shape=[5, 5, nch, 128])
151 | W_conv2 = lf.weight_variable(name='W_conv2', shape=[5, 5, 128, 256])
152 | W_conv3 = lf.weight_variable(name='W_conv3', shape=[5, 5, 256, 512])
153 | W_conv4 = lf.weight_variable(name='W_conv4', shape=[4, 4, 512, 1024])
154 | W_fc1 = lf.weight_variable(name='W_fc1', shape=[4*4*1024, self.flags.hidden_size * 2])
155 |
156 | b_conv1 = lf.bias_variable(name='b_conv1', shape=[128])
157 | b_conv2 = lf.bias_variable(name='b_conv2', shape=[256])
158 | b_conv3 = lf.bias_variable(name='b_conv3', shape=[512])
159 | b_conv4 = lf.bias_variable(name='b_conv4', shape=[1024])
160 | b_fc1 = lf.bias_variable(name='b_fc1', shape=[self.flags.hidden_size * 2])
161 | else:
162 | W_conv1 = lf.weight_variable(name='W_conv1')
163 | W_conv2 = lf.weight_variable(name='W_conv2')
164 | W_conv3 = lf.weight_variable(name='W_conv3')
165 | W_conv4 = lf.weight_variable(name='W_conv4')
166 | W_fc1 = lf.weight_variable(name='W_fc1')
167 |
168 | b_conv1 = lf.bias_variable(name='b_conv1')
169 | b_conv2 = lf.bias_variable(name='b_conv2')
170 | b_conv3 = lf.bias_variable(name='b_conv3')
171 | b_conv4 = lf.bias_variable(name='b_conv4')
172 | b_fc1 = lf.bias_variable(name='b_fc1')
173 |
174 | conv1 = tf.nn.relu(lf.conv2d(input_tensor2d, W_conv1, stride=2) + b_conv1)
175 | conv1_norm = lf.batch_norm_aiuiuc_wrapper(conv1, bn_is_training, \
176 | 'BN1', reuse_vars=reuse)
177 |
178 | conv2 = tf.nn.relu(lf.conv2d(conv1_norm, W_conv2, stride=2) + b_conv2)
179 | conv2_norm = lf.batch_norm_aiuiuc_wrapper(conv2, bn_is_training, \
180 | 'BN2', reuse_vars=reuse)
181 |
182 | conv3 = tf.nn.relu(lf.conv2d(conv2_norm, W_conv3, stride=2) + b_conv3)
183 | conv3_norm = lf.batch_norm_aiuiuc_wrapper(conv3, bn_is_training, \
184 | 'BN3', reuse_vars=reuse)
185 |
186 | conv4 = tf.nn.relu(lf.conv2d(conv3_norm, W_conv4, stride=2) + b_conv4)
187 | conv4_norm = lf.batch_norm_aiuiuc_wrapper(conv4, bn_is_training, \
188 | 'BN4', reuse_vars=reuse)
189 |
190 | dropout1 = tf.nn.dropout(conv4_norm, keep_prob)
191 | flatten1 = tf.reshape(dropout1, [-1, 4*4*1024])
192 |
193 | fc1 = tf.matmul(flatten1, W_fc1)+b_fc1
194 |
195 | return fc1
196 |
197 | def __decoder(self, scope, bn_is_training, inp_greylevel, z1_sample, reuse=False):
198 |
199 | lf = self.layer_factory
200 |
201 | if(reuse == False):
202 | W_deconv1 = lf.weight_variable(name='W_deconv1', shape=[4, 4, self.flags.hidden_size, 1024])
203 | W_deconv2 = lf.weight_variable(name='W_deconv2', shape=[5, 5, 1024, 512])
204 | W_deconv3 = lf.weight_variable(name='W_deconv3', shape=[5, 5, 514, 256]) #512 feat. + 2 grey-gradient channels
205 | W_deconv4 = lf.weight_variable(name='W_deconv4', shape=[5, 5, 258, 128]) #256 feat. + 2 grey-gradient channels
206 | W_deconv5 = lf.weight_variable(name='W_deconv5', shape=[5, 5, 128, self.nch])
207 |
208 | b_deconv1 = lf.bias_variable(name='b_deconv1', shape=[1024])
209 | b_deconv2 = lf.bias_variable(name='b_deconv2', shape=[512])
210 | b_deconv3 = lf.bias_variable(name='b_deconv3', shape=[256])
211 | b_deconv4 = lf.bias_variable(name='b_deconv4', shape=[128])
212 | b_deconv5 = lf.bias_variable(name='b_deconv5',
shape=[self.nch]) 213 | else: 214 | W_deconv1 = lf.weight_variable(name='W_deconv1') 215 | W_deconv2 = lf.weight_variable(name='W_deconv2') 216 | W_deconv3 = lf.weight_variable(name='W_deconv3') 217 | W_deconv4 = lf.weight_variable(name='W_deconv4') 218 | W_deconv5 = lf.weight_variable(name='W_deconv5') 219 | 220 | b_deconv1 = lf.bias_variable(name='b_deconv1') 221 | b_deconv2 = lf.bias_variable(name='b_deconv2') 222 | b_deconv3 = lf.bias_variable(name='b_deconv3') 223 | b_deconv4 = lf.bias_variable(name='b_deconv4') 224 | b_deconv5 = lf.bias_variable(name='b_deconv5') 225 | 226 | inp_greylevel2d = tf.reshape(inp_greylevel, [self.flags.batch_size, \ 227 | self.flags.img_height, self.flags.img_width, 1]) 228 | input_concat2d = tf.reshape(z1_sample, [self.flags.batch_size, 1, 1, self.flags.hidden_size]) 229 | 230 | deconv1_upsamp = tf.image.resize_images(input_concat2d, [4, 4]) 231 | deconv1 = tf.nn.relu(lf.conv2d(deconv1_upsamp, W_deconv1, stride=1) + b_deconv1) 232 | deconv1_norm = lf.batch_norm_aiuiuc_wrapper(deconv1, bn_is_training, \ 233 | 'BN_deconv1', reuse_vars=reuse) 234 | 235 | deconv2_upsamp = tf.image.resize_images(deconv1_norm, [8, 8]) 236 | deconv2 = tf.nn.relu(lf.conv2d(deconv2_upsamp, W_deconv2, stride=1) + b_deconv2) 237 | deconv2_norm = lf.batch_norm_aiuiuc_wrapper(deconv2, bn_is_training, \ 238 | 'BN_deconv2', reuse_vars=reuse) 239 | 240 | deconv3_upsamp = tf.image.resize_images(deconv2_norm, [16, 16]) 241 | grey_deconv3_dv, grey_deconv3_dh = self.__get_gradients(inp_greylevel2d, \ 242 | shape=[16, 16]) 243 | deconv3_upsamp_edge = tf.concat([deconv3_upsamp, grey_deconv3_dv, grey_deconv3_dh], 3) 244 | deconv3 = tf.nn.relu(lf.conv2d(deconv3_upsamp_edge, W_deconv3, stride=1) + b_deconv3) 245 | deconv3_norm = lf.batch_norm_aiuiuc_wrapper(deconv3, bn_is_training, \ 246 | 'BN_deconv3', reuse_vars=reuse) 247 | 248 | deconv4_upsamp = tf.image.resize_images(deconv3_norm, [32, 32]) 249 | grey_deconv4_dv, grey_deconv4_dh = self.__get_gradients(inp_greylevel2d, \ 250 | shape=[32, 32]) 251 | deconv4_upsamp_edge = tf.concat([deconv4_upsamp, grey_deconv4_dv, grey_deconv4_dh], 3) 252 | deconv4 = tf.nn.relu(lf.conv2d(deconv4_upsamp_edge, W_deconv4, stride=1) + b_deconv4) 253 | deconv4_norm = lf.batch_norm_aiuiuc_wrapper(deconv4, bn_is_training, \ 254 | 'BN_deconv4', reuse_vars=reuse) 255 | 256 | deconv5_upsamp = tf.image.resize_images(deconv4_norm, [64, 64]) 257 | deconv5 = lf.conv2d(deconv5_upsamp, W_deconv5, stride=1) + b_deconv5 258 | deconv5_norm = lf.batch_norm_aiuiuc_wrapper(deconv5, bn_is_training, \ 259 | 'BN_deconv5', reuse_vars=reuse) 260 | 261 | decoded_ch = tf.reshape(tf.tanh(deconv5_norm), \ 262 | [self.flags.batch_size, self.flags.img_height*self.flags.img_width*self.nch]) 263 | 264 | return decoded_ch 265 | 266 | def __get_gradients(self, in_tensor2d, shape=None): 267 | if(shape is not None): 268 | in_tensor = tf.image.resize_images(in_tensor2d, [shape[0], shape[1]]) 269 | else: 270 | in_tensor = in_tensor2d 271 | [n,w,h,c] = in_tensor.get_shape().as_list() 272 | dvert = in_tensor[:,1:,:h,:] - in_tensor[:,:w-1,:h,:] 273 | dvert_padded = tf.concat([tf.constant(0., shape=[n, 1, h, c]), dvert], 1) 274 | dhorz = in_tensor[:,:w,1:,:] - in_tensor[:,:w,:h-1,:] 275 | dhorz_padded = tf.concat([tf.constant(0., shape=[n, w, 1, c]), dhorz], 2) 276 | return dvert_padded, dhorz_padded 277 | -------------------------------------------------------------------------------- /vae/data_loaders/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/vae/data_loaders/__init__.py -------------------------------------------------------------------------------- /vae/data_loaders/lab_imageloader.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import glob 3 | import math 4 | import numpy as np 5 | import os 6 | 7 | class lab_imageloader: 8 | 9 | def __init__(self, data_directory, out_directory, listdir=None, shape=(64, 64), \ 10 | subdir=False, ext='JPEG', outshape=(256, 256)): 11 | 12 | if(listdir == None): 13 | if(subdir==False): 14 | self.test_img_fns = glob.glob('%s/*.%s' % (data_directory, ext)) 15 | else: 16 | self.test_img_fns = glob.glob('%s/*/*.%s' % (data_directory, ext)) 17 | self.train_img_fns = [] 18 | else: 19 | self.train_img_fns = [] 20 | self.test_img_fns = [] 21 | with open('%s/list.train.vae.txt' % listdir, 'r') as ftr: 22 | for img_fn in ftr: 23 | self.train_img_fns.append(img_fn.strip('\n')) 24 | 25 | with open('%s/list.test.vae.txt' % listdir, 'r') as fte: 26 | for img_fn in fte: 27 | self.test_img_fns.append(img_fn.strip('\n')) 28 | 29 | self.train_img_num = len(self.train_img_fns) 30 | self.test_img_num = len(self.test_img_fns) 31 | self.train_batch_head = 0 32 | self.test_batch_head = 0 33 | self.train_shuff_ids = np.random.permutation(len(self.train_img_fns)) 34 | self.test_shuff_ids = range(len(self.test_img_fns)) 35 | self.shape = shape 36 | self.outshape = outshape 37 | self.out_directory = out_directory 38 | self.lossweights = None 39 | 40 | countbins = 1./np.load('data/zhang_weights/prior_probs.npy') 41 | binedges = np.load('data/zhang_weights/ab_quantize.npy').reshape(2, 313) 42 | lossweights = {} 43 | for i in range(313): 44 | if binedges[0, i] not in lossweights: 45 | lossweights[binedges[0, i]] = {} 46 | lossweights[binedges[0,i]][binedges[1,i]] = countbins[i] 47 | self.binedges = binedges 48 | self.lossweights = lossweights 49 | 50 | def reset(self): 51 | self.train_batch_head = 0 52 | self.test_batch_head = 0 53 | self.train_shuff_ids = range(len(self.train_img_fns)) 54 | self.test_shuff_ids = range(len(self.test_img_fns)) 55 | 56 | def random_reset(self): 57 | self.train_batch_head = 0 58 | self.test_batch_head = 0 59 | self.train_shuff_ids = np.random.permutation(len(self.train_img_fns)) 60 | self.test_shuff_ids = range(len(self.test_img_fns)) 61 | 62 | def train_next_batch(self, batch_size, nch): 63 | batch = np.zeros((batch_size, nch*np.prod(self.shape)), dtype='f') 64 | batch_lossweights = np.ones((batch_size, nch*np.prod(self.shape)), dtype='f') 65 | batch_recon_const = np.zeros((batch_size, np.prod(self.shape)), dtype='f') 66 | batch_recon_const_outres = np.zeros((batch_size, np.prod(self.outshape)), dtype='f') 67 | 68 | if(self.train_batch_head + batch_size >= len(self.train_img_fns)): 69 | self.train_shuff_ids = np.random.permutation(len(self.train_img_fns)) 70 | self.train_batch_head = 0 71 | 72 | for i_n, i in enumerate(range(self.train_batch_head, self.train_batch_head+batch_size)): 73 | currid = self.train_shuff_ids[i] 74 | img_large = cv2.imread(self.train_img_fns[currid]) 75 | 76 | if(self.shape is not None): 77 | img = cv2.resize(img_large, (self.shape[0], self.shape[1])) 78 | img_outres = cv2.resize(img_large, (self.outshape[0], self.outshape[1])) 79 | 80 | img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) 81 | img_lab_outres = cv2.cvtColor(img_outres, cv2.COLOR_BGR2LAB) 82 | batch_recon_const[i_n, ...] 
= ((img_lab[..., 0].reshape(-1)*1.)-128.)/128. 83 | batch_recon_const_outres[i_n, ...] = ((img_lab_outres[..., 0].reshape(-1)*1.)-128.)/128. 84 | batch[i_n, ...] = np.concatenate((((img_lab[..., 1].reshape(-1)*1.)-128.)/128., 85 | ((img_lab[..., 2].reshape(-1)*1.)-128.)/128.), axis=0) 86 | 87 | if(self.lossweights is not None): 88 | batch_lossweights[i_n, ...] = self.__get_lossweights(batch[i_n, ...]) 89 | 90 | self.train_batch_head = self.train_batch_head + batch_size 91 | 92 | return batch, batch_recon_const, batch_lossweights, batch_recon_const_outres 93 | 94 | def test_next_batch(self, batch_size, nch): 95 | batch = np.zeros((batch_size, nch*np.prod(self.shape)), dtype='f') 96 | batch_recon_const = np.zeros((batch_size, np.prod(self.shape)), dtype='f') 97 | batch_recon_const_outres = np.zeros((batch_size, np.prod(self.outshape)), dtype='f') 98 | batch_imgnames = [] 99 | if(self.test_batch_head + batch_size > len(self.test_img_fns)): 100 | self.test_batch_head = 0 101 | 102 | for i_n, i in enumerate(range(self.test_batch_head, self.test_batch_head+batch_size)): 103 | if(i >= self.test_img_num): 104 | #Repeat first image to make up for incomplete last batch 105 | i = 0 106 | currid = self.test_shuff_ids[i] 107 | img_large = cv2.imread(self.test_img_fns[currid]) 108 | batch_imgnames.append(self.test_img_fns[currid].split('/')[-1]) 109 | if(self.shape is not None): 110 | img = cv2.resize(img_large, (self.shape[1], self.shape[0])) 111 | img_outres = cv2.resize(img_large, (self.outshape[0], self.outshape[1])) 112 | 113 | img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) 114 | img_lab_outres = cv2.cvtColor(img_outres, cv2.COLOR_BGR2LAB) 115 | 116 | batch_recon_const[i_n, ...] = ((img_lab[..., 0].reshape(-1)*1.)-128.)/128. 117 | batch_recon_const_outres[i_n, ...] = ((img_lab_outres[..., 0].reshape(-1)*1.)-128.)/128. 118 | batch[i_n, ...] 
= np.concatenate((((img_lab[..., 1].reshape(-1)*1.)-128.)/128., 119 | ((img_lab[..., 2].reshape(-1)*1.)-128.)/128.), axis=0) 120 | 121 | self.test_batch_head = self.test_batch_head + batch_size 122 | 123 | return batch, batch_recon_const, batch_recon_const_outres, batch_imgnames 124 | 125 | def save_output_with_gt(self, net_op, gt, epoch, itr_id, prefix, batch_size, num_cols=8, net_recon_const=None): 126 | net_out_img = self.save_output(net_op, batch_size, num_cols=num_cols, net_recon_const=net_recon_const) 127 | gt_out_img = self.save_output(gt, batch_size, num_cols=num_cols, net_recon_const=net_recon_const) 128 | num_rows = np.int_(np.ceil((batch_size*1.)/num_cols)) 129 | border_img = 255*np.ones((num_rows*self.outshape[0], 128, 3), dtype='uint8') 130 | out_fn_pred = '%s/%s_pred_%06d_%06d.png' % (self.out_directory, prefix, epoch, itr_id) 131 | print('[DEBUG] Writing output image: %s' % out_fn_pred) 132 | cv2.imwrite(out_fn_pred, np.concatenate((net_out_img, border_img, gt_out_img), axis=1)) 133 | 134 | def save_output(self, net_op, batch_size, num_cols=8, net_recon_const=None): 135 | num_rows = np.int_(np.ceil((batch_size*1.)/num_cols)) 136 | out_img = np.zeros((num_rows*self.outshape[0], num_cols*self.outshape[1], 3), dtype='uint8') 137 | img_lab = np.zeros((self.outshape[0], self.outshape[1], 3), dtype='uint8') 138 | c = 0 139 | r = 0 140 | for i in range(batch_size): 141 | if(i % num_cols == 0 and i > 0): 142 | r = r + 1 143 | c = 0 144 | img_lab[..., 0] = self.__get_decoded_img(net_recon_const[i, ...].reshape(self.outshape[0], self.outshape[1])) 145 | img_lab[..., 1] = self.__get_decoded_img(net_op[i, :np.prod(self.shape)].reshape(self.shape[0], self.shape[1])) 146 | img_lab[..., 2] = self.__get_decoded_img(net_op[i, np.prod(self.shape):].reshape(self.shape[0], self.shape[1])) 147 | img_rgb = cv2.cvtColor(img_lab, cv2.COLOR_LAB2BGR) 148 | out_img[r*self.outshape[0]:(r+1)*self.outshape[0], c*self.outshape[1]:(c+1)*self.outshape[1], ...] = img_rgb 149 | c = c+1 150 | return out_img 151 | 152 | def save_divcolor(self, net_op, gt, epoch, itr_id, prefix, batch_size, imgname, num_cols=8, net_recon_const=None): 153 | img_lab = np.zeros((self.outshape[0], self.outshape[1], 3), dtype='uint8') 154 | img_lab_mat = np.zeros((self.shape[0], self.shape[1], 2), dtype='uint8') 155 | if not os.path.exists('%s/%s' % (self.out_directory, imgname)): 156 | os.makedirs('%s/%s' % (self.out_directory, imgname)) 157 | for i in range(batch_size): 158 | img_lab[..., 0] = self.__get_decoded_img(net_recon_const[i, ...].reshape(self.outshape[0], self.outshape[1])) 159 | img_lab[..., 1] = self.__get_decoded_img(net_op[i, :np.prod(self.shape)].reshape(self.shape[0], self.shape[1])) 160 | img_lab[..., 2] = self.__get_decoded_img(net_op[i, np.prod(self.shape):].reshape(self.shape[0], self.shape[1])) 161 | img_lab_mat[..., 0] = 128.*net_op[i, :np.prod(self.shape)].reshape(self.shape[0], self.shape[1])+128. 162 | img_lab_mat[..., 1] = 128.*net_op[i, np.prod(self.shape):].reshape(self.shape[0], self.shape[1])+128. 
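      #net_op values are a/b chroma scaled to [-1, 1]; 128*x + 128 maps them
      #back to OpenCV's 8-bit Lab range before the LAB2BGR conversion below.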
163 | img_rgb = cv2.cvtColor(img_lab, cv2.COLOR_LAB2BGR)
164 | out_fn_pred = '%s/%s/%s_%03d.png' % (self.out_directory, imgname, prefix, i)
165 | cv2.imwrite(out_fn_pred, img_rgb)
166 | # out_fn_mat = '%s/%s/%s_%03d.mat' % (self.out_directory, imgname, prefix, i)
167 | # np.save(out_fn_mat, img_lab_mat)
168 | img_lab[..., 0] = self.__get_decoded_img(net_recon_const[i, ...].reshape(self.outshape[0], self.outshape[1]))
169 | img_lab[..., 1] = self.__get_decoded_img(gt[0, :np.prod(self.shape)].reshape(self.shape[0], self.shape[1]))
170 | img_lab[..., 2] = self.__get_decoded_img(gt[0, np.prod(self.shape):].reshape(self.shape[0], self.shape[1]))
171 | img_lab_mat[..., 0] = 128.*gt[0, :np.prod(self.shape)].reshape(self.shape[0], self.shape[1])+128.
172 | img_lab_mat[..., 1] = 128.*gt[0, np.prod(self.shape):].reshape(self.shape[0], self.shape[1])+128.
173 | out_fn_pred = '%s/%s/gt.png' % (self.out_directory, imgname)
174 | img_rgb = cv2.cvtColor(img_lab, cv2.COLOR_LAB2BGR)
175 | cv2.imwrite(out_fn_pred, img_rgb)
176 | # out_fn_mat = '%s/%s/gt.mat' % (self.out_directory, imgname)
177 | # np.save(out_fn_mat, img_lab_mat)
178 |
179 | def __get_decoded_img(self, img_enc):
180 | img_dec = 128.*img_enc + 128
181 | img_dec[img_dec < 0.] = 0.
182 | img_dec[img_dec > 255.] = 255.
183 | return cv2.resize(np.uint8(img_dec), (self.outshape[0], self.outshape[1]))
184 |
185 | def __get_lossweights(self, img_vec):
186 | img_vec = img_vec*128.
187 | img_lossweights = np.zeros(img_vec.shape, dtype='f')
188 | img_vec_a = img_vec[:np.prod(self.shape)]
189 | binedges_a = self.binedges[0,...].reshape(-1)
190 | binid_a = [binedges_a.flat[np.abs(binedges_a-v).argmin()] for v in img_vec_a]
191 | img_vec_b = img_vec[np.prod(self.shape):]
192 | binedges_b = self.binedges[1,...].reshape(-1)
193 | binid_b = [binedges_b.flat[np.abs(binedges_b-v).argmin()] for v in img_vec_b]
194 | binweights = np.array([self.lossweights[v1][v2] for v1,v2 in zip(binid_a, binid_b)])
195 | img_lossweights[:np.prod(self.shape)] = binweights
196 | img_lossweights[np.prod(self.shape):] = binweights
197 | return img_lossweights
198 | --------------------------------------------------------------------------------
/vae/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"]="0"
3 | import socket
4 | import sys
5 |
6 | import tensorflow as tf
7 | import numpy as np
8 | from data_loaders.lab_imageloader import lab_imageloader
9 | from arch.vae_skipconn import vae_skipconn as vae
10 | #from arch.vae_wo_skipconn import vae_wo_skipconn as vae
11 |
12 | from arch.network import network
13 |
14 | flags = tf.flags
15 |
16 | #Directory params
17 | flags.DEFINE_string("out_dir", "", "")
18 | flags.DEFINE_string("in_dir", "", "")
19 | flags.DEFINE_string("list_dir", "", "")
20 |
21 | #Dataset Params
22 | flags.DEFINE_integer("batch_size", 32, "batch size")
23 | flags.DEFINE_integer("updates_per_epoch", 1, "number of updates per epoch")
24 | flags.DEFINE_integer("log_interval", 1, "logging interval (in updates)")
25 | flags.DEFINE_integer("img_width", 64, "input image width")
26 | flags.DEFINE_integer("img_height", 64, "input image height")
27 |
28 | #Network Params
29 | flags.DEFINE_boolean("is_train", True, "Is training flag")
30 | flags.DEFINE_boolean("is_run_cvae", False, "Also run CVAE sampling flag")
31 | flags.DEFINE_integer("hidden_size", 64, "size of the hidden VAE unit")
32 | flags.DEFINE_float("lr_vae", 1e-6, "learning rate for vae")
33 | flags.DEFINE_integer("max_epoch_vae", 10, "max epoch")
34 | flags.DEFINE_integer("pc_comp", 20, "number of principal components")
35 |
36 |
37 | FLAGS = flags.FLAGS
38 |
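#NOTE: batch_size, img_width/img_height and hidden_size must match the values
#used when train.py saved the checkpoint; the placeholders and the variables
#restored below are all built from these FLAGS.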
epoch") 34 | flags.DEFINE_integer("pc_comp", 20, "number of principle components") 35 | 36 | 37 | FLAGS = flags.FLAGS 38 | 39 | def main(): 40 | if(len(sys.argv) == 1): 41 | raise NameError('[ERROR] No dataset key') 42 | elif(sys.argv[1] == 'lfw'): 43 | FLAGS.updates_per_epoch = 380 44 | FLAGS.log_interval = 120 45 | FLAGS.out_dir = 'data/output/lfw/' 46 | FLAGS.list_dir = 'data/imglist/lfw/' 47 | FLAGS.pc_dir = 'data/pcomp/lfw/' 48 | else: 49 | raise NameError('[ERROR] Incorrect dataset key') 50 | data_loader = lab_imageloader(FLAGS.in_dir, \ 51 | os.path.join(FLAGS.out_dir, 'images'), \ 52 | listdir=FLAGS.list_dir) 53 | 54 | #Diverse Colorization 55 | nmix = 8 56 | num_batches = 31 57 | lv_mdn_test = np.load(os.path.join(FLAGS.out_dir, 'lv_color_mdn_test.mat.npy')) 58 | 59 | graph_divcolor = tf.Graph() 60 | with graph_divcolor.as_default(): 61 | model_colorfield = vae(FLAGS, nch=2, condinference_flag=True) 62 | dnn = network(model_colorfield, data_loader, 2, FLAGS) 63 | dnn.run_divcolor(os.path.join(FLAGS.out_dir, 'models') , \ 64 | latent_vars_colorfield_test, num_batches=num_batches) 65 | if(FLAGS.is_run_cvae == True): 66 | dnn.run_cvae(os.path.join(FLAGS.out_dir, 'models') , \ 67 | lv_mdn_test, num_batches=num_batches, num_repeat=8, num_cluster=5) 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /vae/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 3 | import socket 4 | import sys 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | from data_loaders.lab_imageloader import lab_imageloader 9 | from arch.vae_skipconn import vae_skipconn as vae 10 | #from arch.vae_wo_skipconn import vae_wo_skipconn as vae 11 | 12 | from arch.network import network as network 13 | 14 | flags = tf.flags 15 | 16 | #Directory params 17 | flags.DEFINE_string("out_dir", "", "") 18 | flags.DEFINE_string("in_dir", "", "") 19 | flags.DEFINE_string("list_dir", "", "") 20 | 21 | #Dataset Params 22 | flags.DEFINE_integer("batch_size", 32, "batch size") 23 | flags.DEFINE_integer("updates_per_epoch", 1, "number of updates per epoch") 24 | flags.DEFINE_integer("log_interval", 1, "input image height") 25 | flags.DEFINE_integer("img_width", 64, "input image width") 26 | flags.DEFINE_integer("img_height", 64, "input image height") 27 | 28 | #Network Params 29 | flags.DEFINE_boolean("is_train", True, "Is training flag") 30 | flags.DEFINE_integer("hidden_size", 64, "size of the hidden VAE unit") 31 | flags.DEFINE_float("lr_vae", 1e-6, "learning rate for vae") 32 | flags.DEFINE_integer("max_epoch_vae", 10, "max epoch") 33 | flags.DEFINE_integer("pc_comp", 20, "number of principle components") 34 | 35 | 36 | FLAGS = flags.FLAGS 37 | 38 | def main(): 39 | if(len(sys.argv) == 1): 40 | raise NameError('[ERROR] No dataset key') 41 | elif(sys.argv[1] == 'lfw'): 42 | FLAGS.updates_per_epoch = 380 43 | FLAGS.log_interval = 120 44 | FLAGS.out_dir = 'data/output/lfw/' 45 | FLAGS.list_dir = 'data/imglist/lfw/' 46 | FLAGS.pc_dir = 'data/pcomp/lfw/' 47 | #add other datasets here 48 | else: 49 | raise NameError('[ERROR] Incorrect dataset key') 50 | 51 | data_loader = lab_imageloader(FLAGS.in_dir, \ 52 | os.path.join(FLAGS.out_dir, 'images'), \ 53 | listdir=FLAGS.list_dir) 54 | 55 | #Train colorfield VAE 56 | graph_vae = tf.Graph() 57 | with graph_vae.as_default(): 58 | model_colorfield = vae(FLAGS, nch=2) 59 | dnn = network(model_colorfield, 
data_loader, 2, FLAGS) 60 | latent_vars_colorfield, latent_vars_colorfield_musigma_test = \ 61 | dnn.train_vae(os.path.join(FLAGS.out_dir, 'models'), FLAGS.is_train) 62 | 63 | np.save(os.path.join(FLAGS.out_dir, 'lv_color_train.mat'), latent_vars_colorfield) 64 | np.save(os.path.join(FLAGS.out_dir, 'lv_color_test.mat'), latent_vars_colorfield_musigma_test) 65 | 66 | if __name__ == "__main__": 67 | main() 68 | --------------------------------------------------------------------------------
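For reference, a sketch of the pipeline implied by the two entry points above (assumptions: the MDN stage under mdn/ is what writes the lv_color_mdn_test.mat.npy file that vae/test.py loads, and run_lfw.sh in the repository root scripts the same sequence):

    python vae/train.py lfw    #trains the colorfield VAE; writes lv_color_train.mat.npy / lv_color_test.mat.npy
    #...fit the MDN (mdn/ stage) on those latents; writes lv_color_mdn_test.mat.npy...
    python vae/test.py lfw     #decodes the top-k GMM means into diverse colorizations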