├── .gitignore
├── LICENSE
├── README.md
├── demo.py
├── get_data.sh
├── get_zhang_colorization.sh
├── images
│   ├── divcolor_figure.png
│   └── divcolor_imagenet.png
├── mdn
│   ├── __init__.py
│   ├── arch
│   │   ├── __init__.py
│   │   └── layer_factory.py
│   ├── data_loaders
│   │   ├── __init__.py
│   │   └── zhangfeats_loader.py
│   ├── mdn.py
│   └── save_mdn_gmm.py
├── requirements.txt
├── run_demo.sh
├── run_lfw.sh
├── third_party
│   ├── __init__.py
│   └── save_zhang_feats.py
└── vae
    ├── __init__.py
    ├── arch
    │   ├── __init__.py
    │   ├── layer_factory.py
    │   ├── network.py
    │   ├── vae_skipconn.py
    │   └── vae_wo_skipconn.py
    ├── data_loaders
    │   ├── __init__.py
    │   └── lab_imageloader.py
    ├── test.py
    └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.txt
3 | *.jpg
4 | *.JPEG
5 | *meta
6 | *ckpt
7 | *npy
8 | *npz
9 | *mat
10 | *jpg
11 | *webp
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2017 Aditya Deshpande
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | TensorFlow implementation of Deshpande et al. "[Learning Diverse Image Colorization](https://arxiv.org/abs/1612.01958)"
2 |
3 | The code is tested with TensorFlow v1.0.1 and Python 2.7. It additionally needs numpy, scipy,
4 | scikit-learn and caffe-r1.0 (Caffe is required only for the Zhang et al. colorization network).
5 |
6 | Fetch data by
7 |
8 | ```
9 | bash get_data.sh
10 | ```
11 |
12 | Fetch the Zhang et al. colorization network (used to extract MDN input features) by
13 |
14 | ```
15 | bash get_zhang_colorization.sh
16 | ```
17 |
18 | Execute run_lfw.sh to first train the VAE and MDN, and then generate results for LFW
19 |
20 | ```
21 | bash run_lfw.sh
22 | ```
23 |
24 | Execute run_demo.sh to get diverse colorizations for any image; the model is trained on ImageNet
25 |
26 | ```
27 | bash run_demo.sh
28 | ```
29 |
30 | If you use this code, please cite
31 |
32 | ```
33 | @inproceedings{DeshpandeLDColor17,
34 | author = {Aditya Deshpande and Jiajun Lu and Mao-Chuang Yeh and Min Jin Chong and David Forsyth},
35 | title = {Learning Diverse Image Colorization},
36 | booktitle={Computer Vision and Pattern Recognition},
37 | url={https://arxiv.org/abs/1612.01958},
38 | year={2017}
39 | }
40 | ```
41 |
42 | Some examples of diverse colorizations on the LFW, LSUN Church and ImageNet-Val datasets
43 |
44 | ![](images/divcolor_figure.png)
45 |
48 | Some examples of diverse colorizations for images in the wild; the model is trained on ImageNet
49 |
50 | ![](images/divcolor_imagenet.png)
51 |
--------------------------------------------------------------------------------
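Beyond the data fetched above, run_demo.sh needs the pretrained ImageNet models that demo.py and mdn/save_mdn_gmm.py load from hard-coded paths (presumably unpacked by get_data.sh from data.zip). A minimal sanity-check sketch, assuming those paths:

```
import os

# Paths hard-coded in demo.py and mdn/save_mdn_gmm.py; data.zip is assumed to provide them.
for d in ('data/imagenet_models', 'data/imagenet_models_mdn'):
    print('%s: %s' % (d, 'found' if os.path.isdir(d) else 'missing'))
```
--------------------------------------------------------------------------------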
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"]="0"
3 | import socket
4 | import sys
5 |
6 | import tensorflow as tf
7 | import numpy as np
8 | from vae.data_loaders.lab_imageloader import lab_imageloader
9 | from vae.arch.vae_skipconn import vae_skipconn as vae
10 | from vae.arch.network import network
11 | from third_party.save_zhang_feats import save_zhang_feats
12 |
13 | flags = tf.flags
14 |
15 | #Directory params
16 | flags.DEFINE_string("out_dir", "", "")
17 | flags.DEFINE_string("in_dir", "", "")
18 | flags.DEFINE_string("list_dir", "", "")
19 |
20 | #Dataset Params
21 | flags.DEFINE_integer("batch_size", 32, "batch size")
22 | flags.DEFINE_integer("updates_per_epoch", 1, "number of updates per epoch")
23 | flags.DEFINE_integer("log_interval", 1, "input image height")
24 | flags.DEFINE_integer("img_width", 64, "input image width")
25 | flags.DEFINE_integer("img_height", 64, "input image height")
26 |
27 | #Network Params
28 | flags.DEFINE_boolean("is_only_data", False, "Is training flag")
29 | flags.DEFINE_boolean("is_train", False, "Is training flag")
30 | flags.DEFINE_boolean("is_run_cvae", False, "Is training flag")
31 | flags.DEFINE_integer("hidden_size", 64, "size of the hidden VAE unit")
32 | flags.DEFINE_float("lr_vae", 1e-6, "learning rate for vae")
33 | flags.DEFINE_integer("max_epoch_vae", 10, "max epoch")
34 | flags.DEFINE_integer("pc_comp", 20, "number of principle components")
35 |
36 | FLAGS = flags.FLAGS
37 |
38 | def main():
39 |
40 | FLAGS.log_interval = 1
41 | FLAGS.list_dir = None
42 | FLAGS.in_dir = 'data/testimgs/'
43 | FLAGS.ext = 'JPEG'
44 | data_loader = lab_imageloader(FLAGS.in_dir, \
45 | 'data/output/testimgs', listdir=None, ext=FLAGS.ext)
46 | img_fns = data_loader.test_img_fns
47 |
48 | if(FLAGS.is_only_data == True):
49 | feats_fns = save_zhang_feats(img_fns, ext=FLAGS.ext)
50 |
51 | with open('%s/list.train.txt' % FLAGS.in_dir, 'w') as fp:
52 | for feats_fn in feats_fns:
53 | fp.write('%s\n' % feats_fn)
54 |
55 | with open('%s/list.test.txt' % FLAGS.in_dir, 'w') as fp:
56 | for feats_fn in feats_fns:
57 | fp.write('%s\n' % feats_fn)
58 |
59 | np.save('%s/lv_color_train.mat.npy' % FLAGS.in_dir, \
60 | np.zeros((len(img_fns), 2*FLAGS.hidden_size)))
61 | np.save('%s/lv_color_test.mat.npy' % FLAGS.in_dir, \
62 | np.zeros((len(img_fns), 2*FLAGS.hidden_size)))
63 | else:
64 | nmix = 8
65 | lv_mdn_test = np.load(os.path.join(FLAGS.in_dir, 'lv_color_mdn_test.mat.npy'))
66 | num_batches = np.int_(np.ceil((lv_mdn_test.shape[0]*1.)/FLAGS.batch_size))
67 |
68 | graph_divcolor = tf.Graph()
69 | with graph_divcolor.as_default():
70 | model_colorfield = vae(FLAGS, nch=2, condinference_flag=True)
71 | dnn = network(model_colorfield, data_loader, 2, FLAGS)
72 | dnn.run_divcolor('data/imagenet_models/' , \
73 | lv_mdn_test, num_batches=num_batches)
74 |
75 | if __name__ == "__main__":
76 | main()
77 |
--------------------------------------------------------------------------------
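With --is_only_data=True, demo.py extracts Zhang features and writes list.train.txt, list.test.txt and two zero-filled latent arrays into data/testimgs/; mdn/save_mdn_gmm.py then consumes these and writes lv_color_mdn_test.mat.npy, which the second demo.py invocation loads. A small sketch (paths as hard-coded above) to check the intermediates between stages:

```
import os

in_dir = 'data/testimgs'  # FLAGS.in_dir hard-coded in demo.py
for fn in ('list.train.txt', 'list.test.txt', 'lv_color_train.mat.npy',
           'lv_color_test.mat.npy', 'lv_color_mdn_test.mat.npy'):
    path = os.path.join(in_dir, fn)
    print('%s: %s' % (fn, 'found' if os.path.exists(path) else 'missing'))
```
--------------------------------------------------------------------------------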
/get_data.sh:
--------------------------------------------------------------------------------
1 | wget http://vision.cs.illinois.edu/projects/divcolor/data.zip
2 | unzip data.zip
3 | rm data.zip
4 | wget http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz
5 | tar -xvzf lfw-deepfunneled.tgz
6 | mv lfw-deepfunneled data/lfw_images
7 | rm lfw-deepfunneled.tgz
8 |
--------------------------------------------------------------------------------
/get_zhang_colorization.sh:
--------------------------------------------------------------------------------
1 | cd third_party
2 | git clone https://github.com/richzhang/colorization.git
3 | rm -rf colorization/.git
4 | cd colorization/
5 | bash models/fetch_release_models.sh
6 |
--------------------------------------------------------------------------------
/images/divcolor_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/images/divcolor_figure.png
--------------------------------------------------------------------------------
/images/divcolor_imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/images/divcolor_imagenet.png
--------------------------------------------------------------------------------
/mdn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/mdn/__init__.py
--------------------------------------------------------------------------------
/mdn/arch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/mdn/arch/__init__.py
--------------------------------------------------------------------------------
/mdn/arch/layer_factory.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | from tensorflow.python.framework import tensor_shape
5 |
6 | class layer_factory:
7 |
8 | def __init__(self):
9 | pass
10 |
11 | def weight_variable(self, name, shape=None, mean=0., stddev=.001, gain=np.sqrt(2)):
12 | if(shape == None):
13 | return tf.get_variable(name)
14 | # #Adaptive initialize based on variable shape
15 | # if(len(shape) == 4):
16 | # stddev = (1.0 * gain) / np.sqrt(shape[0] * shape[1] * shape[3])
17 | # else:
18 | # stddev = (1.0 * gain) / np.sqrt(shape[0])
19 | return tf.get_variable(name, shape=shape, initializer=tf.random_normal_initializer(mean=mean, stddev=stddev))
20 |
21 | def bias_variable(self, name, shape=None, constval=.001):
22 | if(shape == None):
23 | return tf.get_variable(name)
24 | return tf.get_variable(name, shape=shape, initializer=tf.constant_initializer(constval))
25 |
26 | def conv2d(self, x, W, stride=1, padding='SAME'):
27 | return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding=padding)
28 |
29 | def batch_norm_aiuiuc_wrapper(self, x, train_phase, name, reuse_vars):
30 | output = tf.contrib.layers.batch_norm(x, \
31 | decay=.99, \
32 | is_training=train_phase, \
33 | scale=True, \
34 | epsilon=1e-4, \
35 | updates_collections=None,\
36 | scope=name,\
37 | reuse=reuse_vars)
38 | return output
39 |
--------------------------------------------------------------------------------
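A minimal usage sketch for the factory above (not repository code, run from the repository root so the mdn package imports): the first call passes shape and creates the variable; a later call with shape=None under a reusing variable scope fetches the same variable via tf.get_variable. This is the pattern mdn.py relies on to share weights between its train and test graphs; TF 1.x API assumed.

```
import tensorflow as tf
from mdn.arch.layer_factory import layer_factory

lf = layer_factory()
with tf.variable_scope('Inference', reuse=False):
    W = lf.weight_variable('W_conv1', shape=[5, 5, 512, 384])  # created here
with tf.variable_scope('Inference', reuse=True):
    W_again = lf.weight_variable('W_conv1')  # shape=None: fetch the existing variable
assert W is W_again
```
--------------------------------------------------------------------------------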
/mdn/data_loaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/mdn/data_loaders/__init__.py
--------------------------------------------------------------------------------
/mdn/data_loaders/zhangfeats_loader.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import glob
3 | import math
4 | import numpy as np
5 |
6 | class zhangfeats_loader:
7 |
8 | def __init__(self, list_train_fn, list_test_fn, lv_train_fn, lv_test_fn, \
9 | featshape=(512, 28, 28)):
10 |
11 | self.train_img_fns = []
12 | self.test_img_fns = []
13 |
14 | with open(list_train_fn, 'r') as ftr:
15 | for img_fn in ftr:
16 | self.train_img_fns.append(img_fn.strip('\n'))
17 |
18 | with open(list_test_fn, 'r') as fte:
19 | for img_fn in fte:
20 | self.test_img_fns.append(img_fn.strip('\n'))
21 |
22 | self.lv_train = np.load(lv_train_fn)
23 | self.lv_test = np.load(lv_test_fn)
24 | self.hidden_size = np.int_((self.lv_train.shape[1]*1.)/2.)
25 |
26 | self.train_img_fns = self.train_img_fns[:self.lv_train.shape[0]]
27 | self.test_img_fns = self.test_img_fns[:self.lv_test.shape[0]]
28 | self.featshape = featshape
29 |
30 | self.train_img_num = len(self.train_img_fns)
31 | self.test_img_num = len(self.test_img_fns)
32 | self.train_batch_head = 0
33 | self.test_batch_head = 0
34 | self.train_shuff_ids = range(self.train_img_num)
35 | self.test_shuff_ids = range(self.test_img_num)
36 |
37 | def reset(self):
38 | self.train_batch_head = 0
39 | self.test_batch_head = 0
40 | self.train_shuff_ids = range(self.train_img_num)
41 | self.test_shuff_ids = range(self.test_img_num)
42 |
43 | def random_reset(self):
44 | self.train_batch_head = 0
45 | self.test_batch_head = 0
46 | self.train_shuff_ids = np.random.permutation(self.train_img_num)
47 | self.test_shuff_ids = range(self.test_img_num)
48 |
49 | def train_next_batch(self, batch_size):
50 | batch = np.zeros((batch_size, self.featshape[2], \
51 | self.featshape[1], self.featshape[0]), dtype='f')
52 | batch_gt = np.zeros((batch_size, self.hidden_size), dtype='f')
53 |
54 | if(self.train_batch_head + batch_size >= self.train_img_num):
55 | self.train_shuff_ids = np.random.permutation(self.train_img_num)
56 | self.train_batch_head = 0
57 |
58 | for i_n, i in enumerate(range(self.train_batch_head, self.train_batch_head+batch_size)):
59 | currid = self.train_shuff_ids[i]
60 | featobj = np.load(self.train_img_fns[currid])
61 | feats = featobj['arr_0'].reshape(self.featshape[0], self.featshape[1], \
62 | self.featshape[2])
63 | feats2d = feats.reshape(self.featshape[0], -1).T
64 | feats3d = feats2d.reshape(self.featshape[1], self.featshape[2], \
65 | self.featshape[0])
66 | batch[i_n, ...] = feats3d
67 | eps = np.random.normal(loc=0., scale=1., size=(self.hidden_size))
68 | batch_gt[i_n, ...] = self.lv_train[currid, :self.hidden_size] \
69 | + eps*self.lv_train[currid, self.hidden_size:]
70 |
71 | self.train_batch_head = self.train_batch_head + batch_size
72 |
73 | return batch, batch_gt
74 |
75 | def test_next_batch(self, batch_size):
76 | batch = np.zeros((batch_size, self.featshape[2], \
77 | self.featshape[1], self.featshape[0]), dtype='f')
78 | batch_gt = np.zeros((batch_size, self.hidden_size), dtype='f')
79 |
80 | if(self.test_batch_head + batch_size > self.test_img_num):
81 | self.test_shuff_ids = range(self.test_img_num)
82 | self.test_batch_head = 0
83 |
84 | for i_n, i in enumerate(range(self.test_batch_head, self.test_batch_head+batch_size)):
85 | currid = self.test_shuff_ids[i]
86 | featobj = np.load(self.test_img_fns[currid])
87 | feats = featobj['arr_0'].reshape(self.featshape[0], self.featshape[1], \
88 | self.featshape[2])
89 | feats2d = feats.reshape(self.featshape[0], -1).T
90 | feats3d = feats2d.reshape(self.featshape[1], self.featshape[2], \
91 | self.featshape[0])
92 | batch[i_n, ...] = feats3d
93 | eps = np.random.normal(loc=0., scale=1., size=(self.hidden_size))
94 | batch_gt[i_n, ...] = self.lv_test[currid, :self.hidden_size] \
95 | +eps*self.lv_test[currid, self.hidden_size:]
96 |
97 | self.test_batch_head = self.test_batch_head + batch_size
98 | return batch, batch_gt
99 |
--------------------------------------------------------------------------------
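Two details of the loader above, verified in a standalone numpy sketch (not repository code): the reshape/transpose/reshape chain converts the stored channel-first (512, 28, 28) features to the channel-last layout TensorFlow expects, and is equivalent to a plain transpose; and the per-image "ground truth" is a reparameterized sample mu + eps*sigma from the Gaussian latent code saved by the VAE (each row of the lv arrays is [mu | sigma]).

```
import numpy as np

feats = np.random.rand(512, 28, 28).astype('f')          # channel-first, as stored
feats3d = feats.reshape(512, -1).T.reshape(28, 28, 512)  # the loader's conversion
assert np.array_equal(feats3d, np.transpose(feats, (1, 2, 0)))

lv = np.random.rand(128).astype('f')            # one row: [mu | sigma], hidden_size=64
hidden_size = lv.shape[0] // 2
eps = np.random.normal(size=hidden_size)
gt = lv[:hidden_size] + eps * lv[hidden_size:]  # reparameterized latent sample
```
--------------------------------------------------------------------------------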
/mdn/mdn.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"]="0"
3 |
4 | import socket
5 | import sys
6 |
7 | import tensorflow as tf
8 | import numpy as np
9 | from data_loaders.zhangfeats_loader import zhangfeats_loader
10 | from arch.layer_factory import layer_factory
11 |
12 | flags = tf.flags
13 | FLAGS = flags.FLAGS
14 |
15 |
16 | flags.DEFINE_integer("feats_height", 28, "")
17 | flags.DEFINE_integer("feats_width", 28, "")
18 | flags.DEFINE_integer("feats_nch", 512, "")
19 | flags.DEFINE_integer("nmix", 8, "GMM components")
20 | flags.DEFINE_boolean("is_train", True, "")
21 | flags.DEFINE_integer("max_epoch", 5, "Max epoch")
22 | flags.DEFINE_float("lr", 1e-5, "Learning rate")
23 |
24 | #Dynamically assigned params
25 | flags.DEFINE_integer("batch_size", 1, "Batch size")
26 | flags.DEFINE_integer("hidden_size", 64, "VAE latent variable dimension")
27 | flags.DEFINE_integer("updates_per_epoch", 0, "Number of updates per epoch")
28 |
29 | def cnn_feedforward(lf, input_tensor, bn_is_training, keep_prob, reuse=False):
30 |
31 | nout = (FLAGS.hidden_size+1)*FLAGS.nmix
32 |
33 | if(reuse == False):
34 | W_conv1 = lf.weight_variable(name='W_conv1', shape=[5, 5, FLAGS.feats_nch, 384])
35 | W_conv1_1 = lf.weight_variable(name='W_conv1_1', shape=[5, 5, 384, 320])
36 | W_conv1_2 = lf.weight_variable(name='W_conv1_2', shape=[5, 5, 320, 288])
37 | W_conv1_3 = lf.weight_variable(name='W_conv1_3', shape=[5, 5, 288, 256])
38 |
39 | W_conv2 = lf.weight_variable(name='W_conv2', shape=[5, 5, 256, 128])
40 | W_fc1 = lf.weight_variable(name='W_fc1', shape=[14*14*128, 4096])
41 | W_fc2 = lf.weight_variable(name='W_fc2', shape=[4096, nout])
42 |
43 |
44 | b_fc1 = lf.bias_variable(name='b_fc1', shape=[4096])
45 | b_fc2 = lf.bias_variable(name='b_fc2', shape=[nout])
46 |
47 | else:
48 | W_conv1 = lf.weight_variable(name='W_conv1')
49 | W_conv1_1 = lf.weight_variable(name='W_conv1_1')
50 | W_conv1_2 = lf.weight_variable(name='W_conv1_2')
51 | W_conv1_3 = lf.weight_variable(name='W_conv1_3')
52 |
53 | W_conv2 = lf.weight_variable(name='W_conv2')
54 | W_fc1 = lf.weight_variable(name='W_fc1')
55 | W_fc2 = lf.weight_variable(name='W_fc2')
56 |
57 | b_fc1 = lf.bias_variable(name='b_fc1')
58 | b_fc2 = lf.bias_variable(name='b_fc2')
59 |
60 | conv1 = tf.nn.relu(lf.conv2d(input_tensor, W_conv1, stride=1))
61 | conv1_norm = lf.batch_norm_aiuiuc_wrapper(conv1, bn_is_training, \
62 | 'BN1', reuse_vars=reuse)
63 |
64 | conv1_1 = tf.nn.relu(lf.conv2d(conv1_norm, W_conv1_1, stride=1))
65 | conv1_1_norm = lf.batch_norm_aiuiuc_wrapper(conv1_1, bn_is_training, \
66 | 'BN1_1', reuse_vars=reuse)
67 |
68 | conv1_2 = tf.nn.relu(lf.conv2d(conv1_1_norm, W_conv1_2, stride=1))
69 | conv1_2_norm = lf.batch_norm_aiuiuc_wrapper(conv1_2, bn_is_training, \
70 | 'BN1_2', reuse_vars=reuse)
71 |
72 | conv1_3 = tf.nn.relu(lf.conv2d(conv1_2_norm, W_conv1_3, stride=2))
73 | conv1_3_norm = lf.batch_norm_aiuiuc_wrapper(conv1_3, bn_is_training, \
74 | 'BN1_3', reuse_vars=reuse)
75 |
76 | conv2 = tf.nn.relu(lf.conv2d(conv1_3_norm, W_conv2, stride=1))
77 | conv2_norm = lf.batch_norm_aiuiuc_wrapper(conv2, bn_is_training, \
78 | 'BN2', reuse_vars=reuse)
79 |
80 | dropout1 = tf.nn.dropout(conv2_norm, keep_prob)
81 | flatten1 = tf.reshape(dropout1, [-1, 14*14*128])
82 | fc1 = tf.tanh(tf.matmul(flatten1, W_fc1)+b_fc1)
83 |
84 | dropout2 = tf.nn.dropout(fc1, keep_prob)
85 | fc2 = tf.matmul(dropout2, W_fc2)+b_fc2
86 |
87 | return fc2
88 |
89 | def get_mixture_coeff(out_fc):
90 | out_mu = out_fc[..., :FLAGS.hidden_size*FLAGS.nmix]
91 | out_pi = tf.nn.softmax(out_fc[..., FLAGS.hidden_size*FLAGS.nmix:])
92 | out_sigma = tf.constant(.1, shape=[FLAGS.batch_size, FLAGS.nmix])
93 | return out_pi, out_mu, out_sigma
94 |
95 | def compute_gmm_loss(gt_tensor, op_tensor_activ, summ=False):
96 |
97 | #Replicate ground-truth tensor per mixture component
98 | gt_tensor_flat = tf.tile(gt_tensor, [FLAGS.nmix, 1])
99 |
100 | #Pi, mu, sigma
101 | op_tensor_pi, op_tensor_mu, op_tensor_sigma = get_mixture_coeff(op_tensor_activ)
102 |
103 | #Flatten means, sigma, pi aligned to gt above
104 | op_tensor_mu_flat = tf.reshape(op_tensor_mu, [FLAGS.nmix*FLAGS.batch_size, FLAGS.hidden_size])
105 | op_tensor_sigma_flat = tf.reshape(op_tensor_sigma, [FLAGS.nmix*FLAGS.batch_size])
106 |
107 | #N(t|x, mu, sigma): batch_size x nmix
108 | op_norm_dist = tf.reshape(tf.div((.5*tf.reduce_sum(tf.square(gt_tensor_flat-op_tensor_mu_flat), \
109 | reduction_indices=1)), op_tensor_sigma_flat), [FLAGS.batch_size, FLAGS.nmix])
110 | op_norm_dist_min = tf.reduce_min(op_norm_dist, reduction_indices=1)
111 | op_norm_dist_minind = tf.to_int32(tf.argmin(op_norm_dist, 1))
112 | op_pi_minind_flattened = tf.range(0, FLAGS.batch_size)*FLAGS.nmix + op_norm_dist_minind
113 | op_pi_min = tf.gather(tf.reshape(op_tensor_pi, [-1]), op_pi_minind_flattened)
114 |
115 | if(summ == True):
116 | gmm_loss = tf.reduce_mean(-tf.log(op_pi_min+1e-30) + op_norm_dist_min, reduction_indices=0)
117 | else:
118 | gmm_loss = tf.reduce_mean(op_norm_dist_min, reduction_indices=0)
119 |
120 | if(summ == True):
121 | tf.summary.scalar('gmm_loss', gmm_loss)
122 | tf.summary.scalar('op_norm_dist_min', tf.reduce_min(op_norm_dist))
123 | tf.summary.scalar('op_norm_dist_max', tf.reduce_max(op_norm_dist))
124 | tf.summary.scalar('op_pi_min', tf.reduce_mean(op_pi_min))
125 |
126 | return gmm_loss, op_tensor_pi, op_tensor_mu, op_tensor_sigma
127 |
128 | def optimize(loss, lr):
129 | optimizer = tf.train.GradientDescentOptimizer(lr)
130 | return optimizer.minimize(loss)
131 |
132 | def save_chkpt(saver, epoch, sess, chkptdir, prefix='model'):
133 | if not os.path.exists(chkptdir):
134 | os.makedirs(chkptdir)
135 | save_path = saver.save(sess, "%s/%s_%06d.ckpt" % (chkptdir, prefix, epoch))
136 | print("[DEBUG] ############ Model saved in file: %s ################" % save_path)
137 |
138 | def load_chkpt(saver, sess, chkptdir):
139 | ckpt = tf.train.get_checkpoint_state(chkptdir)
140 | if ckpt and ckpt.model_checkpoint_path:
141 | ckpt_fn = ckpt.model_checkpoint_path.replace('//', '/')
142 | print('[DEBUG] Loading checkpoint from %s' % ckpt_fn)
143 | saver.restore(sess, ckpt_fn)
144 | else:
145 | raise NameError('[ERROR] No checkpoint found at: %s' % chkptdir)
146 |
147 | def main():
148 |
149 | if(len(sys.argv) == 1):
150 | raise NameError('[ERROR] No dataset key')
151 | if(sys.argv[1] == 'imagenetval'):
152 | FLAGS.updates_per_epoch = 49000
153 | FLAGS.num_test_batches = 1000
154 | FLAGS.in_featdir = 'data/featslist/imagenetval/'
155 | FLAGS.in_lvdir = 'data/output/imagenetval/'
156 | elif(sys.argv[1] == 'lfw'):
157 | FLAGS.updates_per_epoch = 12233
158 | FLAGS.num_test_batches = 1000
159 | FLAGS.in_featdir = 'data/featslist/lfw/'
160 | FLAGS.in_lvdir = 'data/output/lfw/'
161 | elif(sys.argv[1] == 'church'):
162 | FLAGS.updates_per_epoch = 125227
163 | FLAGS.num_test_batches = 1000
164 | FLAGS.in_featdir = 'data/featslist/church/'
165 | FLAGS.in_lvdir = 'data/output/church/'
166 | else:
167 | raise NameError('[ERROR] Incorrect dataset key')
168 |
169 | data_loader = zhangfeats_loader(os.path.join(FLAGS.in_featdir, 'list.train.txt'), \
170 | os.path.join(FLAGS.in_featdir, 'list.test.txt'),\
171 | os.path.join(FLAGS.in_lvdir, 'lv_color_train.mat.npy'),\
172 | os.path.join(FLAGS.in_lvdir, 'lv_color_test.mat.npy'))
173 |
174 | #Inputs
175 | lf = layer_factory()
176 | input_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.feats_height, \
177 | FLAGS.feats_width, FLAGS.feats_nch])
178 | output_gt_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.hidden_size])
179 | is_training = tf.placeholder(tf.bool)
180 | keep_prob = tf.placeholder(tf.float32)
181 |
182 | #Inference
183 | with tf.variable_scope('Inference', reuse=False):
184 | output_activ = cnn_feedforward(lf, input_tensor, is_training, keep_prob, reuse=False)
185 |
186 | with tf.variable_scope('Inference', reuse=True):
187 | output_test_activ = cnn_feedforward(lf, input_tensor, is_training, keep_prob, reuse=True)
188 |
189 | #Loss and gradient descent step
190 | loss, _, _, _ = compute_gmm_loss(output_gt_tensor, output_activ, summ=True)
191 | loss_test, pi_test, mu_test, sigma_test = compute_gmm_loss(output_gt_tensor, output_test_activ)
192 |
193 | train_step = optimize(loss, FLAGS.lr)
194 |
195 | #Standard steps
196 | check_nan_op = tf.add_check_numerics_ops()
197 | init = tf.global_variables_initializer()
198 | saver = tf.train.Saver(max_to_keep=0)
199 | summary_op = tf.summary.merge_all()
200 |
201 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
202 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
203 | train_writer = tf.summary.FileWriter(os.path.join(FLAGS.in_lvdir, 'logs_mdn'), sess.graph)
204 |
205 | sess.run(init)
206 |
207 | if(FLAGS.is_train):
208 | for epoch in range(FLAGS.max_epoch):
209 | training_loss = 0.
210 |
211 | data_loader.random_reset()
212 | for i in range(FLAGS.updates_per_epoch):
213 | batch, batch_gt = data_loader.train_next_batch(FLAGS.batch_size)
214 | feed_dict = {input_tensor:batch, output_gt_tensor:batch_gt, \
215 | is_training:True, keep_prob:.75}
216 | _, _, loss_value, summary_str = sess.run(\
217 | [check_nan_op, train_step, loss, summary_op], \
218 | feed_dict)
219 | train_writer.add_summary(summary_str, epoch*FLAGS.updates_per_epoch+i)
220 | training_loss = training_loss + loss_value
221 |
222 | print('[DEBUG] Epoch# %d, Loss: %f' % (epoch, \
223 | (training_loss*1.)/FLAGS.updates_per_epoch))
224 |
225 | save_chkpt(saver, epoch, sess, os.path.join(FLAGS.in_lvdir, 'models_mdn'), \
226 | prefix='model_%d_exp' % FLAGS.nmix)
227 | else:
228 | load_chkpt(saver, sess, os.path.join(FLAGS.in_lvdir, 'models_mdn'))
229 |
230 | test_loss = 0.
231 | data_loader.reset()
232 | lv_test_codes = np.zeros((0, (FLAGS.hidden_size+1+1)*FLAGS.nmix), dtype='f')
233 | for i in range(FLAGS.num_test_batches):
234 | batch, batch_gt = data_loader.test_next_batch(FLAGS.batch_size)
235 | feed_dict = {input_tensor:batch, output_gt_tensor:batch_gt, \
236 | is_training:False, keep_prob:1.}
237 | _, loss_value, output_pi, output_mu, output_sigma = \
238 | sess.run([check_nan_op, loss_test, pi_test, mu_test, sigma_test], feed_dict)
239 |
240 | test_loss = test_loss + loss_value
241 | output = np.concatenate((output_mu, output_sigma, output_pi), axis=1)
242 | lv_test_codes = np.concatenate((lv_test_codes, output), axis=0)
243 |
244 | print('[DEBUG] Test Loss: %f' % ((test_loss*1.)/FLAGS.num_test_batches))
245 | np.save(os.path.join(FLAGS.in_lvdir, 'lv_color_mdn_test.mat'), lv_test_codes)
246 | print(lv_test_codes.shape)
247 |
248 | sess.close()
249 |
250 | if __name__ == "__main__":
251 | main()
252 |
--------------------------------------------------------------------------------
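mdn.py stores each test image's GMM as one flat row of length (hidden_size+1+1)*nmix, laid out as [means | sigmas | mixture weights]; sigma is the constant 0.1 from get_mixture_coeff, and the training loss is min-over-components: -log(pi) of the closest component plus 0.5*||gt - mu||^2 / sigma. A sketch of unpacking one saved row, mirroring the slicing run_divcolor does in vae/arch/network.py (the lfw output path and the default sizes above are assumed):

```
import numpy as np

hidden_size, nmix = 64, 8
codes = np.load('data/output/lfw/lv_color_mdn_test.mat.npy')  # (num_test, (64+2)*8)
row = codes[0]
mu = row[:hidden_size * nmix].reshape(nmix, hidden_size)      # component means
sigma = row[hidden_size * nmix:(hidden_size + 1) * nmix]      # fixed at .1 each
pi = row[(hidden_size + 1) * nmix:]                           # mixture weights
order = np.argsort(-pi)   # most probable components first, as in run_divcolor
z = mu[order[0]]          # a latent code to feed the VAE decoder
```
--------------------------------------------------------------------------------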
/mdn/save_mdn_gmm.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"]="0"
3 |
4 | import socket
5 | import sys
6 |
7 | import tensorflow as tf
8 | import numpy as np
9 | from data_loaders.zhangfeats_loader import zhangfeats_loader
10 | from arch.layer_factory import layer_factory
11 |
12 | flags = tf.flags
13 |
14 | #MDN Params
15 | flags.DEFINE_integer("feats_height", 28, "")
16 | flags.DEFINE_integer("feats_width", 28, "")
17 | flags.DEFINE_integer("feats_nch", 512, "")
18 | flags.DEFINE_integer("nmix", 8, "GMM components")
19 | flags.DEFINE_integer("max_epoch", 5, "Max epoch")
20 | flags.DEFINE_float("lr", 1e-5, "Learning rate")
21 | flags.DEFINE_integer("batch_size_mdn", 1, "Batch size")
22 |
23 | flags.DEFINE_integer("hidden_size", 64, "VAE latent variable dimension")
24 |
25 | FLAGS = flags.FLAGS
26 |
27 | def cnn_feedforward(lf, input_tensor, bn_is_training, keep_prob, reuse=False):
28 |
29 | nout = (FLAGS.hidden_size+1)*FLAGS.nmix
30 |
31 | if(reuse == False):
32 | W_conv1 = lf.weight_variable(name='W_conv1', shape=[5, 5, FLAGS.feats_nch, 384])
33 | W_conv1_1 = lf.weight_variable(name='W_conv1_1', shape=[5, 5, 384, 320])
34 | W_conv1_2 = lf.weight_variable(name='W_conv1_2', shape=[5, 5, 320, 288])
35 | W_conv1_3 = lf.weight_variable(name='W_conv1_3', shape=[5, 5, 288, 256])
36 |
37 | W_conv2 = lf.weight_variable(name='W_conv2', shape=[5, 5, 256, 128])
38 | W_fc1 = lf.weight_variable(name='W_fc1', shape=[14*14*128, 4096])
39 | W_fc2 = lf.weight_variable(name='W_fc2', shape=[4096, nout])
40 |
41 |
42 | b_fc1 = lf.bias_variable(name='b_fc1', shape=[4096])
43 | b_fc2 = lf.bias_variable(name='b_fc2', shape=[nout])
44 |
45 | else:
46 | W_conv1 = lf.weight_variable(name='W_conv1')
47 | W_conv1_1 = lf.weight_variable(name='W_conv1_1')
48 | W_conv1_2 = lf.weight_variable(name='W_conv1_2')
49 | W_conv1_3 = lf.weight_variable(name='W_conv1_3')
50 |
51 | W_conv2 = lf.weight_variable(name='W_conv2')
52 | W_fc1 = lf.weight_variable(name='W_fc1')
53 | W_fc2 = lf.weight_variable(name='W_fc2')
54 |
55 | b_fc1 = lf.bias_variable(name='b_fc1')
56 | b_fc2 = lf.bias_variable(name='b_fc2')
57 |
58 | conv1 = tf.nn.relu(lf.conv2d(input_tensor, W_conv1, stride=1))
59 | conv1_norm = lf.batch_norm_aiuiuc_wrapper(conv1, bn_is_training, \
60 | 'BN1', reuse_vars=reuse)
61 |
62 | conv1_1 = tf.nn.relu(lf.conv2d(conv1_norm, W_conv1_1, stride=1))
63 | conv1_1_norm = lf.batch_norm_aiuiuc_wrapper(conv1_1, bn_is_training, \
64 | 'BN1_1', reuse_vars=reuse)
65 |
66 | conv1_2 = tf.nn.relu(lf.conv2d(conv1_1_norm, W_conv1_2, stride=1))
67 | conv1_2_norm = lf.batch_norm_aiuiuc_wrapper(conv1_2, bn_is_training, \
68 | 'BN1_2', reuse_vars=reuse)
69 |
70 | conv1_3 = tf.nn.relu(lf.conv2d(conv1_2_norm, W_conv1_3, stride=2))
71 | conv1_3_norm = lf.batch_norm_aiuiuc_wrapper(conv1_3, bn_is_training, \
72 | 'BN1_3', reuse_vars=reuse)
73 |
74 | conv2 = tf.nn.relu(lf.conv2d(conv1_3_norm, W_conv2, stride=1))
75 | conv2_norm = lf.batch_norm_aiuiuc_wrapper(conv2, bn_is_training, \
76 | 'BN2', reuse_vars=reuse)
77 |
78 | dropout1 = tf.nn.dropout(conv2_norm, keep_prob)
79 | flatten1 = tf.reshape(dropout1, [-1, 14*14*128])
80 | fc1 = tf.tanh(tf.matmul(flatten1, W_fc1)+b_fc1)
81 |
82 | dropout2 = tf.nn.dropout(fc1, keep_prob)
83 | fc2 = tf.matmul(dropout2, W_fc2)+b_fc2
84 |
85 | return fc2
86 |
87 | def get_mixture_coeff(out_fc):
88 | out_mu = out_fc[..., :FLAGS.hidden_size*FLAGS.nmix]
89 | out_pi = tf.nn.softmax(out_fc[..., FLAGS.hidden_size*FLAGS.nmix:])
90 | out_sigma = tf.constant(.1, shape=[FLAGS.batch_size_mdn, FLAGS.nmix])
91 | return out_pi, out_mu, out_sigma
92 |
93 | def compute_gmm_loss(gt_tensor, op_tensor_activ, summ=False):
94 |
95 | #Replicate ground-truth tensor per mixture component
96 | gt_tensor_flat = tf.tile(gt_tensor, [FLAGS.nmix, 1])
97 |
98 | #Pi, mu, sigma
99 | op_tensor_pi, op_tensor_mu, op_tensor_sigma = get_mixture_coeff(op_tensor_activ)
100 |
101 | #Flatten means, sigma, pi aligned to gt above
102 | op_tensor_mu_flat = tf.reshape(op_tensor_mu, [FLAGS.nmix*FLAGS.batch_size_mdn, FLAGS.hidden_size])
103 | op_tensor_sigma_flat = tf.reshape(op_tensor_sigma, [FLAGS.nmix*FLAGS.batch_size_mdn])
104 |
105 | #N(t|x, mu, sigma): batch_size_mdn x nmix
106 | op_norm_dist = tf.reshape(tf.div((.5*tf.reduce_sum(tf.square(gt_tensor_flat-op_tensor_mu_flat), \
107 | reduction_indices=1)), op_tensor_sigma_flat), [FLAGS.batch_size_mdn, FLAGS.nmix])
108 | op_norm_dist_min = tf.reduce_min(op_norm_dist, reduction_indices=1)
109 | op_norm_dist_minind = tf.to_int32(tf.argmin(op_norm_dist, 1))
110 | op_pi_minind_flattened = tf.range(0, FLAGS.batch_size_mdn)*FLAGS.nmix + op_norm_dist_minind
111 | op_pi_min = tf.gather(tf.reshape(op_tensor_pi, [-1]), op_pi_minind_flattened)
112 |
113 | if(summ == True):
114 | gmm_loss = tf.reduce_mean(-tf.log(op_pi_min+1e-30) + op_norm_dist_min, reduction_indices=0)
115 | else:
116 | gmm_loss = tf.reduce_mean(op_norm_dist_min, reduction_indices=0)
117 |
118 | if(summ == True):
119 | tf.summary.scalar('gmm_loss', gmm_loss)
120 | tf.summary.scalar('op_norm_dist_min', tf.reduce_min(op_norm_dist))
121 | tf.summary.scalar('op_norm_dist_max', tf.reduce_max(op_norm_dist))
122 | tf.summary.scalar('op_pi_min', tf.reduce_mean(op_pi_min))
123 |
124 | return gmm_loss, op_tensor_pi, op_tensor_mu, op_tensor_sigma
125 |
126 | def optimize(loss, lr):
127 | optimizer = tf.train.GradientDescentOptimizer(lr)
128 | return optimizer.minimize(loss)
129 |
130 | def save_chkpt(saver, epoch, sess, chkptdir, prefix='model'):
131 | if not os.path.exists(chkptdir):
132 | os.makedirs(chkptdir)
133 | save_path = saver.save(sess, "%s/%s_%06d.ckpt" % (chkptdir, prefix, epoch))
134 | print("[DEBUG] ############ Model saved in file: %s ################" % save_path)
135 |
136 | def load_chkpt(saver, sess, chkptdir):
137 | ckpt = tf.train.get_checkpoint_state(chkptdir)
138 | if ckpt: print(ckpt.model_checkpoint_path)
139 | if ckpt and ckpt.model_checkpoint_path:
140 | ckpt_fn = ckpt.model_checkpoint_path.replace('//', '/')
141 | print('[DEBUG] Loading checkpoint from %s' % ckpt_fn)
142 | saver.restore(sess, ckpt_fn)
143 | else:
144 | raise NameError('[ERROR] No checkpoint found at: %s' % chkptdir)
145 |
146 | def save_mdn_gmm(data_dir):
147 |
148 | FLAGS.in_featdir = data_dir
149 | FLAGS.in_lvdir = data_dir
150 |
151 | data_loader = zhangfeats_loader(os.path.join(FLAGS.in_featdir, 'list.train.txt'), \
152 | os.path.join(FLAGS.in_featdir, 'list.test.txt'),\
153 | os.path.join(FLAGS.in_lvdir, 'lv_color_train.mat.npy'),\
154 | os.path.join(FLAGS.in_lvdir, 'lv_color_test.mat.npy'))
155 |
156 | FLAGS.num_test_batches = data_loader.test_img_num
157 |
158 | #Inputs
159 | lf = layer_factory()
160 | input_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size_mdn, FLAGS.feats_height, \
161 | FLAGS.feats_width, FLAGS.feats_nch])
162 | output_gt_tensor = tf.placeholder(tf.float32, [FLAGS.batch_size_mdn, FLAGS.hidden_size])
163 | is_training = tf.placeholder(tf.bool)
164 | keep_prob = tf.placeholder(tf.float32)
165 |
166 | #Inference
167 | with tf.variable_scope('Inference', reuse=False):
168 | output_activ = cnn_feedforward(lf, input_tensor, is_training, keep_prob, reuse=False)
169 |
170 | with tf.variable_scope('Inference', reuse=True):
171 | output_test_activ = cnn_feedforward(lf, input_tensor, is_training, keep_prob, reuse=True)
172 |
173 | #Loss and gradient descent step
174 | loss, _, _, _ = compute_gmm_loss(output_gt_tensor, output_activ, summ=True)
175 | loss_test, pi_test, mu_test, sigma_test = compute_gmm_loss(output_gt_tensor, output_test_activ)
176 |
177 | train_step = optimize(loss, FLAGS.lr)
178 |
179 | #Standard steps
180 | check_nan_op = tf.add_check_numerics_ops()
181 | init = tf.global_variables_initializer()
182 | saver = tf.train.Saver(max_to_keep=0)
183 | summary_op = tf.summary.merge_all()
184 |
185 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
186 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
187 |
188 | sess.run(init)
189 |
190 | load_chkpt(saver, sess, 'data/imagenet_models_mdn/')
191 |
192 | data_loader.reset()
193 | lv_test_codes = np.zeros((0, (FLAGS.hidden_size+1+1)*FLAGS.nmix), dtype='f')
194 | for i in range(FLAGS.num_test_batches):
195 | batch, batch_gt = data_loader.test_next_batch(FLAGS.batch_size_mdn)
196 | feed_dict = {input_tensor:batch, output_gt_tensor:batch_gt, \
197 | is_training:False, keep_prob:1.}
198 | _, output_pi, output_mu, output_sigma = \
199 | sess.run([check_nan_op, pi_test, mu_test, sigma_test], feed_dict)
200 | output = np.concatenate((output_mu, output_sigma, output_pi), axis=1)
201 | lv_test_codes = np.concatenate((lv_test_codes, output), axis=0)
202 |
203 | np.save(os.path.join(FLAGS.in_lvdir, 'lv_color_mdn_test.mat'), lv_test_codes)
204 | print(lv_test_codes.shape)
205 |
206 | sess.close()
207 |
208 | return lv_test_codes
209 |
210 | if __name__=='__main__':
211 | save_mdn_gmm(sys.argv[1])
212 |
213 |
214 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow_gpu==1.1.0
2 | scipy==0.19.0
3 | scikit_image==0.13.0
4 | matplotlib==2.0.1
5 | numpy==1.12.1
6 | Pillow==4.1.1
7 | protobuf==3.3.0
8 | scikit_learn==0.18.1
9 | tensorflow==1.1.0
10 |
--------------------------------------------------------------------------------
/run_demo.sh:
--------------------------------------------------------------------------------
1 | python demo.py --is_only_data=True
2 | python mdn/save_mdn_gmm.py data/testimgs/
3 | python demo.py --is_only_data=False
4 |
--------------------------------------------------------------------------------
/run_lfw.sh:
--------------------------------------------------------------------------------
1 | python vae/train.py lfw
2 | python mdn/mdn.py lfw
3 | python vae/test.py lfw
4 |
--------------------------------------------------------------------------------
/third_party/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/third_party/__init__.py
--------------------------------------------------------------------------------
/third_party/save_zhang_feats.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"]="1"
3 | import sys
4 |
5 |
6 | import numpy as np
7 | import caffe
8 | import skimage.color as color
9 | import skimage.io
10 | import scipy.ndimage.interpolation as sni
11 | import cv2
12 |
13 | def save_zhang_feats(img_fns, ext='JPEG'):
14 |
15 | gpu_id = 0
16 | caffe.set_mode_gpu()
17 | caffe.set_device(gpu_id)
18 | net = caffe.Net('third_party/colorization/models/colorization_deploy_v1.prototxt', \
19 | 'third_party/colorization/models/colorization_release_v1.caffemodel', caffe.TEST)
20 |
21 | (H_in,W_in) = net.blobs['data_l'].data.shape[2:] # get input shape
22 | (H_out,W_out) = net.blobs['class8_ab'].data.shape[2:] # get output shape
23 | net.blobs['Trecip'].data[...] = 6/np.log(10) # 1/T, set annealing temperature
24 |
25 | feats_fns = []
26 | for img_fn_i, img_fn in enumerate(img_fns):
27 |
28 | # load the original image
29 | img_rgb = caffe.io.load_image(img_fn)
30 | img_lab = color.rgb2lab(img_rgb) # convert image to lab color space
31 | img_l = img_lab[:,:,0] # pull out L channel
32 | (H_orig,W_orig) = img_rgb.shape[:2] # original image size
33 |
34 | # create grayscale version of image (computed here but not used further)
35 | img_lab_bw = img_lab.copy()
36 | img_lab_bw[:,:,1:] = 0
37 | img_rgb_bw = color.lab2rgb(img_lab_bw)
38 |
39 | # resize image to network input size
40 | img_rs = caffe.io.resize_image(img_rgb,(H_in,W_in)) # resize image to network input size
41 | img_lab_rs = color.rgb2lab(img_rs)
42 | img_l_rs = img_lab_rs[:,:,0]
43 |
44 | net.blobs['data_l'].data[0,0,:,:] = img_l_rs-50 # subtract 50 for mean-centering
45 | net.forward() # run network
46 |
47 | npz_fn = img_fn.replace(ext, 'npz')
48 | np.savez_compressed(npz_fn, net.blobs['conv7_3'].data)
49 | feats_fns.append(npz_fn)
50 |
51 | return feats_fns
52 |
--------------------------------------------------------------------------------
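save_zhang_feats writes each image's conv7_3 blob with np.savez_compressed, so the array lives under numpy's default key 'arr_0' and keeps the network's leading batch axis, which is why zhangfeats_loader reshapes it through (512, 28, 28). A read-back sketch (the file name is hypothetical):

```
import numpy as np

feats = np.load('data/testimgs/example.npz')['arr_0'].reshape(512, 28, 28)
```
--------------------------------------------------------------------------------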
/vae/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/vae/__init__.py
--------------------------------------------------------------------------------
/vae/arch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/vae/arch/__init__.py
--------------------------------------------------------------------------------
/vae/arch/layer_factory.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from tensorflow.python.framework import tensor_shape
4 |
5 | class layer_factory:
6 |
7 | def __init__(self):
8 | pass
9 |
10 | def weight_variable(self, name, shape=None, mean=0., stddev=.001, gain=np.sqrt(2)):
11 | if(shape == None):
12 | return tf.get_variable(name)
13 | # #Adaptive initialize based on variable shape
14 | # if(len(shape) == 4):
15 | # stddev = (1.0 * gain) / np.sqrt(shape[0] * shape[1] * shape[3])
16 | # else:
17 | # stddev = (1.0 * gain) / np.sqrt(shape[0])
18 | return tf.get_variable(name, shape=shape, initializer=tf.random_normal_initializer(mean=mean, stddev=stddev))
19 |
20 | def bias_variable(self, name, shape=None, constval=.001):
21 | if(shape == None):
22 | return tf.get_variable(name)
23 | return tf.get_variable(name, shape=shape, initializer=tf.constant_initializer(constval))
24 |
25 | def conv2d(self, x, W, stride=1, padding='SAME'):
26 | return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding=padding)
27 |
28 | def lrelu(self, x, leak=.2):
29 | return tf.maximum(x, leak*x)
30 |
31 | def batch_norm_aiuiuc_wrapper(self, x, train_phase, name, reuse_vars):
32 | output = tf.contrib.layers.batch_norm(x, \
33 | decay=.99, \
34 | is_training=train_phase, \
35 | scale=True, \
36 | epsilon=1e-4, \
37 | updates_collections=None,\
38 | scope=name,\
39 | reuse=reuse_vars)
40 | return output
41 |
--------------------------------------------------------------------------------
/vae/arch/network.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | import numpy as np
4 | from sklearn.cluster import KMeans
5 |
6 | class network:
7 |
8 | def __init__(self, model, data_loader, nch, flags):
9 | self.model = model
10 | self.data_loader = data_loader
11 | self.nch = nch
12 | self.flags = flags
13 | self.__build_graph()
14 |
15 | def __build_graph(self):
16 | self.plhold_img, self.plhold_greylevel, self.plhold_latent, self.plhold_is_training, \
17 | self.plhold_keep_prob, self.plhold_kl_weight, self.plhold_lossweights \
18 | = self.model.inputs()
19 |
20 | #inference graph
21 | self.op_mean, self.op_stddev, self.op_vae, \
22 | self.op_mean_test, self.op_stddev_test, self.op_vae_test, \
23 | self.op_vae_condinference \
24 | = self.model.inference(self.plhold_img, self.plhold_greylevel, \
25 | self.plhold_latent, self.plhold_is_training, self.plhold_keep_prob)
26 |
27 | #loss function and gd step for vae
28 | self.loss = self.model.loss(self.plhold_img, self.op_vae, self.op_mean, \
29 | self.op_stddev, self.plhold_kl_weight, self.plhold_lossweights)
30 | self.train_step = self.model.optimize(self.loss, epsilon=1e-6)
31 |
32 | #standard steps
33 | self.check_nan_op = tf.add_check_numerics_ops()
34 | self.init = tf.global_variables_initializer()
35 | self.saver = tf.train.Saver(max_to_keep=0)
36 | self.summary_op = tf.summary.merge_all()
37 |
38 | def train_vae(self, chkptdir, is_train=True):
39 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95)
40 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
41 | num_train_batches = np.int_(np.floor((self.data_loader.train_img_num*1.)/\
42 | self.flags.batch_size))
43 | if(is_train == True):
44 | sess.run(self.init)
45 | print('[DEBUG] Saving TensorBoard summaries to: %s/logs' % self.flags.out_dir)
46 | self.train_writer = tf.summary.FileWriter(os.path.join(self.flags.out_dir, 'logs'), sess.graph)
47 |
48 | #Train vae
49 | for epoch in range(self.flags.max_epoch_vae):
50 | epoch_loss = self.run_vae_epoch_train(epoch, sess)
51 | epoch_loss = (epoch_loss*1.) / (self.flags.updates_per_epoch)
52 | print('[DEBUG] ####### Train VAE Epoch#%d, Loss %f #######' % (epoch, epoch_loss))
53 | self.__save_chkpt(epoch, sess, chkptdir, prefix='model_vae')
54 | else:
55 | self.__load_chkpt(sess, chkptdir)
56 |
57 | epoch_latentvars_train = self.run_vae_epoch_test(self.flags.max_epoch_vae,\
58 | sess, num_batches=num_train_batches, is_train=True)
59 |
60 | epoch_latentvars_test = self.run_vae_epoch_test(2, sess, num_batches=3, is_train=False)
61 |
62 | sess.close()
63 | return epoch_latentvars_train, epoch_latentvars_test
64 |
65 | def run_vae_epoch_train(self, epoch, sess):
66 | epoch_loss = 0.
67 | self.data_loader.random_reset()
68 | delta_kl_weight = (1e-2*1.)/(self.flags.max_epoch_vae*1.) #CHANGED WEIGHT SCHEDULE
69 | latent_feed = np.zeros((self.flags.batch_size, self.flags.hidden_size), dtype='f')
70 | for i in range(self.flags.updates_per_epoch):
71 | kl_weight = delta_kl_weight*(epoch)
72 | batch, batch_recon_const, batch_lossweights, batch_recon_const_outres = \
73 | self.data_loader.train_next_batch(self.flags.batch_size, self.nch)
74 | feed_dict = {self.plhold_img: batch, self.plhold_is_training:True, \
75 | self.plhold_keep_prob:.7, \
76 | self.plhold_kl_weight:kl_weight, \
77 | self.plhold_latent:latent_feed, \
78 | self.plhold_lossweights:batch_lossweights, \
79 | self.plhold_greylevel:batch_recon_const}
80 | try:
81 | _, _, loss_value, output, summary_str = sess.run(\
82 | [self.check_nan_op, self.train_step, self.loss, \
83 | self.op_vae, self.summary_op], feed_dict)
84 | except:
85 | raise NameError('[ERROR] Found nan values in run_vae_epoch_train')
86 | self.train_writer.add_summary(summary_str, epoch*self.flags.updates_per_epoch+i)
87 | if(i % self.flags.log_interval == 0):
88 | self.data_loader.save_output_with_gt(output, batch, epoch, i, \
89 | '%02d_train_vae' % self.nch, self.flags.batch_size, \
90 | num_cols=8, net_recon_const=batch_recon_const_outres)
91 | epoch_loss += loss_value
92 | return epoch_loss
93 |
94 | def run_vae_epoch_test(self, epoch, sess, num_batches=3, is_train=False):
95 | self.data_loader.reset()
96 | kl_weight = 0.
97 | latentvars_epoch = np.zeros((0, 2*self.flags.hidden_size), dtype='f')
98 | latent_feed = np.zeros((self.flags.batch_size, self.flags.hidden_size), dtype='f')
99 | for i in range(num_batches):
100 | if(is_train == False):
101 | batch, batch_recon_const, batch_recon_const_outres, _ = \
102 | self.data_loader.test_next_batch(self.flags.batch_size, self.nch)
103 | else:
104 | batch, batch_recon_const, _, batch_recon_const_outres = \
105 | self.data_loader.train_next_batch(self.flags.batch_size, self.nch)
106 |
107 | batch_lossweights = np.ones((self.flags.batch_size, \
108 | self.nch*self.flags.img_height*self.flags.img_width), dtype='f')
109 |
110 | feed_dict = {self.plhold_img: batch, self.plhold_is_training:False, \
111 | self.plhold_keep_prob:1., \
112 | self.plhold_kl_weight:kl_weight, \
113 | self.plhold_latent:latent_feed, \
114 | self.plhold_lossweights: batch_lossweights, \
115 | self.plhold_greylevel:batch_recon_const}
116 | try:
117 | _, means_batch, stddevs_batch, output = sess.run(\
118 | [self.check_nan_op, self.op_mean_test, self.op_stddev_test, \
119 | self.op_vae_test], feed_dict)
120 | except:
121 | raise NameError('[ERROR] Found nan values in run_vae_epoch_test')
122 | if(is_train == False):
123 | self.data_loader.save_output_with_gt(output, batch, epoch, i, \
124 | '%02d_test_vae' % self.nch, self.flags.batch_size, num_cols=8, \
125 | net_recon_const=batch_recon_const_outres)
126 | else:
127 | if(i % self.flags.log_interval == 0):
128 | self.data_loader.save_output_with_gt(output, batch, epoch, i, \
129 | '%02d_latentvar' % self.nch, self.flags.batch_size, num_cols=8, \
130 | net_recon_const=batch_recon_const_outres)
131 |
132 | latentvars_epoch = np.concatenate((latentvars_epoch, \
133 | np.concatenate((means_batch, stddevs_batch), axis=1)), axis=0)
134 | return latentvars_epoch
135 |
136 | def run_cvae(self, chkptdir, latentvars, num_batches=3, num_repeat=8, num_cluster=5):
137 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95)
138 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
139 | self.__load_chkpt(sess, chkptdir)
140 | self.data_loader.reset()
141 | kl_weight = 0.
142 | for i in range(num_batches):
143 | print ('[DEBUG] Batch %d (/ %d)' % (i, num_batches))
144 | batch, batch_recon_const, batch_recon_const_outres, batch_imgnames = \
145 | self.data_loader.test_next_batch(self.flags.batch_size, self.nch)
146 | batch_lossweights = np.ones((self.flags.batch_size, \
147 | self.nch*self.flags.img_height*self.flags.img_width), dtype='f')
148 | for j in range(self.flags.batch_size):
149 | output_all = np.zeros((0, self.nch*self.flags.img_height*self.flags.img_width), dtype='f')
150 | print ('[DEBUG] Batch %d (/ %d), Img %d' % (i, num_batches, j))
151 | for k in range(num_repeat):
152 | imgid = i*self.flags.batch_size+j
153 | batch_1 = np.tile(batch[j, ...], (self.flags.batch_size, 1))
154 | batch_recon_const_1 = np.tile(batch_recon_const[j, ...], (self.flags.batch_size, 1))
155 | batch_recon_const_outres_1 = np.tile(batch_recon_const_outres[j, ...], (self.flags.batch_size, 1))
156 |
157 | latent_feed = np.random.normal(loc=0., scale=1., \
158 | size=(self.flags.batch_size, self.flags.hidden_size))
159 | #latent_feed = latentvars[imgid*self.flags.batch_size:(imgid+1)*self.flags.batch_size, ...]
160 | feed_dict = {self.plhold_img:batch_1, self.plhold_is_training:False, \
161 | self.plhold_keep_prob:1., \
162 | self.plhold_kl_weight:kl_weight, \
163 | self.plhold_latent:latent_feed, \
164 | self.plhold_lossweights: batch_lossweights,
165 | self.plhold_greylevel:batch_recon_const_1}
166 | try:
167 | _, output = sess.run(\
168 | [self.check_nan_op, self.op_vae_condinference], \
169 | feed_dict)
170 |           except tf.errors.InvalidArgumentError:
171 | raise NameError('[ERROR] Found nan values in condinference_vae')
172 | output_all = np.concatenate((output_all, output), axis=0)
173 |
174 | print ('[DEBUG] Clustering %d predictions' % output_all.shape[0])
175 | kmeans = KMeans(n_clusters=num_cluster, random_state=0).fit(output_all)
176 | output_clust = kmeans.cluster_centers_
177 | self.data_loader.save_divcolor(output_clust, batch_1[:num_cluster], i, j, \
178 | 'cvae', num_cluster, batch_imgnames[j], num_cols=8, \
179 | net_recon_const=batch_recon_const_outres_1[:num_cluster])
180 |
181 | sess.close()
182 |
183 | def run_divcolor(self, chkptdir, latentvars, num_batches=3, topk=8):
184 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95)
185 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
186 | self.__load_chkpt(sess, chkptdir)
187 | self.data_loader.reset()
188 | kl_weight = 0.
189 | nmix = topk
190 | for i in range(num_batches):
191 | batch, batch_recon_const, batch_recon_const_outres, batch_imgnames = \
192 | self.data_loader.test_next_batch(self.flags.batch_size, self.nch)
193 | batch_lossweights = np.ones((self.flags.batch_size, \
194 | self.nch*self.flags.img_height*self.flags.img_width), dtype='f')
195 | output_all = np.zeros((0, \
196 | self.nch*self.flags.img_height*self.flags.img_width), dtype='f')
197 | for j in range(self.flags.batch_size):
198 | imgid = i*self.flags.batch_size+j
199 | print ('[DEBUG] Running divcolor on Image %d (/%d)' % \
200 | (imgid, self.data_loader.test_img_num))
201 | if(imgid >= self.data_loader.test_img_num):
202 | break
203 | batch_1 = np.tile(batch[j, ...], (self.flags.batch_size, 1))
204 | batch_recon_const_1 = np.tile(batch_recon_const[j, ...], (self.flags.batch_size, 1))
205 | batch_recon_const_outres_1 = np.tile(batch_recon_const_outres[j, ...], (self.flags.batch_size, 1))
206 | curr_means = latentvars[imgid, :self.flags.hidden_size*nmix].reshape(nmix, self.flags.hidden_size)
207 | curr_sigma = latentvars[imgid, self.flags.hidden_size*nmix:(self.flags.hidden_size+1)*nmix].reshape(-1)
208 | curr_pi = latentvars[imgid, (self.flags.hidden_size+1)*nmix:].reshape(-1)
209 | selectid = np.argsort(-1*curr_pi)
210 | latent_feed = np.tile(curr_means[selectid, ...], (np.int_(np.round((self.flags.batch_size*1.)/nmix)), 1))
211 |
212 | feed_dict = {self.plhold_img:batch_1, self.plhold_is_training:False, \
213 | self.plhold_keep_prob:1., \
214 | self.plhold_kl_weight:kl_weight, \
215 | self.plhold_latent:latent_feed, \
216 | self.plhold_lossweights: batch_lossweights,
217 | self.plhold_greylevel:batch_recon_const_1}
218 |
219 | _, output = sess.run(\
220 | [self.check_nan_op, self.op_vae_condinference], \
221 | feed_dict)
222 |
223 | self.data_loader.save_divcolor(output[:topk], batch_1[:topk], i, j, \
224 | 'divcolor', topk, batch_imgnames[j], num_cols=8, \
225 | net_recon_const=batch_recon_const_outres_1[:topk, ...])
226 |
227 | sess.close()
228 |
229 |
230 | def __save_chkpt(self, epoch, sess, chkptdir, prefix='model'):
231 | if not os.path.exists(chkptdir):
232 | os.makedirs(chkptdir)
233 | save_path = self.saver.save(sess, "%s/%s_%06d.ckpt" % (chkptdir, prefix, epoch))
234 | print("[DEBUG] ############ Model saved in file: %s ################" % save_path)
235 |
236 | def __load_chkpt(self, sess, chkptdir):
237 | ckpt = tf.train.get_checkpoint_state(chkptdir)
238 | if ckpt and ckpt.model_checkpoint_path:
239 | ckpt_fn = ckpt.model_checkpoint_path.replace('//', '/')
240 | print('[DEBUG] Loading checkpoint from %s' % ckpt_fn)
241 | self.saver.restore(sess, ckpt_fn)
242 | else:
243 | raise NameError('[ERROR] No checkpoint found at: %s' % chkptdir)
244 |
--------------------------------------------------------------------------------
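A minimal numpy sketch (not part of the repo) of how run_divcolor above unpacks one row of the MDN output, assuming the layout used there: nmix mean vectors followed by nmix stddevs and nmix mixture weights. The stand-in row and the hidden_size/nmix values are illustrative, taken from the flags and defaults in this codebase.

import numpy as np

hidden_size, nmix = 64, 8                                    # assumed defaults (flags / topk)
row = np.random.rand(hidden_size*nmix + 2*nmix).astype('f')  # stand-in for latentvars[imgid, :]

means = row[:hidden_size*nmix].reshape(nmix, hidden_size)    # curr_means
sigma = row[hidden_size*nmix:(hidden_size+1)*nmix]           # curr_sigma
pi = row[(hidden_size+1)*nmix:]                              # curr_pi

selectid = np.argsort(-pi)                                   # order components by decreasing weight
latent_feed = means[selectid, :]                             # tiled to batch_size and fed as plhold_latent
print(latent_feed.shape)                                     # (8, 64)
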
/vae/arch/vae_skipconn.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from layer_factory import layer_factory
3 | from tensorflow.python.framework import tensor_shape
4 |
5 | class vae_skipconn:
6 |
7 | def __init__(self, flags, nch=2, condinference_flag=False):
8 | self.flags = flags
9 | self.nch = nch
10 | self.layer_factory = layer_factory()
11 | self.condinference_flag = condinference_flag
12 |
13 | #Returns handles to input placeholders
14 | def inputs(self):
15 | inp_img = tf.placeholder(tf.float32, [self.flags.batch_size, \
16 | self.nch * self.flags.img_height * self.flags.img_width])
17 | inp_greylevel = tf.placeholder(tf.float32, [self.flags.batch_size, \
18 | self.flags.img_height * self.flags.img_width])
19 | inp_latent = tf.placeholder(tf.float32, [self.flags.batch_size, \
20 | self.flags.hidden_size])
21 | is_training = tf.placeholder(tf.bool)
22 | keep_prob = tf.placeholder(tf.float32)
23 | kl_weight = tf.placeholder(tf.float32)
24 | lossweights = tf.placeholder(tf.float32, [self.flags.batch_size, \
25 | self.nch * self.flags.img_height * self.flags.img_width])
26 | return inp_img, inp_greylevel, inp_latent, is_training, keep_prob, kl_weight, lossweights
27 |
28 |   #Takes input placeholders, builds the inference graph and returns the network outputs
29 | def inference(self, inp_img, inp_greylevel, inp_latent, is_training, keep_prob):
30 |
31 | with tf.variable_scope('Inference', reuse=False) as sc:
32 | gfeat32, gfeat16, gfeat8, gfeat4 = \
33 | self.__cond_encoder(sc, inp_greylevel, is_training, keep_prob, \
34 | in_nch=1, reuse=False)
35 | z1_train = self.__encoder(sc, inp_img, is_training, keep_prob, \
36 | in_nch=self.nch, reuse=False)
37 | epsilon_train = tf.truncated_normal([self.flags.batch_size, self.flags.hidden_size])
38 | mean_train = z1_train[:, :self.flags.hidden_size]
39 | stddev_train = tf.sqrt(tf.exp(z1_train[:, self.flags.hidden_size:]))
40 | z1_sample = mean_train + epsilon_train * stddev_train
41 | output_train = self.__decoder(sc, gfeat32, gfeat16, gfeat8, gfeat4, \
42 | is_training, inp_greylevel, z1_sample, reuse=False)
43 |
44 | with tf.variable_scope('Inference', reuse=True) as sc:
45 | gfeat32, gfeat16, gfeat8, gfeat4 = \
46 | self.__cond_encoder(sc, inp_greylevel, is_training, keep_prob, \
47 | in_nch=1, reuse=True)
48 |
49 | if(self.condinference_flag == False):
50 | z1_test = self.__encoder(sc, inp_img, is_training, keep_prob, \
51 | in_nch=self.nch, reuse=True)
52 | epsilon_test = tf.truncated_normal([self.flags.batch_size, self.flags.hidden_size])
53 | mean_test = z1_test[:, :self.flags.hidden_size]
54 | stddev_test = tf.sqrt(tf.exp(z1_test[:, self.flags.hidden_size:]))
55 | z1_sample = mean_test + epsilon_test * stddev_test
56 |         z1_sample = tf.stop_gradient(z1_sample) #Fix the encoder; stop_gradient returns a new tensor, so the result must be kept
57 | output_test = self.__decoder(sc, gfeat32, gfeat16, gfeat8, gfeat4, \
58 | is_training, inp_greylevel, z1_sample, reuse=True)
59 | output_condinference = None
60 | else:
61 | mean_test = None
62 | stddev_test = None
63 | output_test = None
64 | output_condinference = self.__decoder(sc, gfeat32, gfeat16, gfeat8,\
65 | gfeat4, is_training, inp_greylevel, inp_latent, reuse=True)
66 |
67 | return mean_train, stddev_train, output_train, mean_test, stddev_test, \
68 | output_test, output_condinference
69 |
70 |   #Takes network outputs and computes the VAE (encoder+decoder) loss; see the sampling/KL sketch after this file
71 | def loss(self, target_tensor, op_tensor, mean, stddev, kl_weight, lossweights, epsilon=1e-6):
72 |
73 | kl_loss = tf.reduce_sum(0.5 * (tf.square(mean) + tf.square(stddev) \
74 | - tf.log(tf.maximum(tf.square(stddev), epsilon)) - 1.0))
75 |
76 | recon_loss = tf.reduce_mean(tf.sqrt(tf.reduce_sum( \
77 | lossweights*tf.square(target_tensor-op_tensor), 1)), 0)
78 |
79 | recon_loss_l2 = tf.reduce_mean(tf.sqrt(tf.reduce_sum( \
80 | tf.square(target_tensor-op_tensor), 1)), 0)
81 |
82 | if(self.nch == 2):
83 | target_tensor2d = tf.reshape(target_tensor, [self.flags.batch_size, \
84 | self.flags.img_height, self.flags.img_width, self.nch])
85 | op_tensor2d = tf.reshape(op_tensor, [self.flags.batch_size, \
86 | self.flags.img_height, self.flags.img_width, self.nch])
87 | [n,w,h,c] = target_tensor2d.get_shape().as_list()
88 | dv = tf.square((target_tensor2d[:,1:,:h-1,:] - target_tensor2d[:,:w-1,:h-1,:])
89 | - (op_tensor2d[:,1:,:h-1,:] - op_tensor2d[:,:w-1,:h-1,:]))
90 | dh = tf.square((target_tensor2d[:,:w-1,1:,:] - target_tensor2d[:,:w-1,:h-1,:])
91 | - (op_tensor2d[:,:w-1,1:,:] - op_tensor2d[:,:w-1,:h-1,:]))
92 | grad_loss = tf.reduce_mean(tf.sqrt(tf.reduce_sum(dv+dh,[1,2,3])))
93 | recon_loss = recon_loss + (1e-3)*grad_loss
94 |
95 | loss = kl_weight*kl_loss + recon_loss
96 | tf.summary.scalar('kl_loss', kl_loss)
97 |     if(self.nch == 2): tf.summary.scalar('grad_loss', grad_loss) #grad_loss is only defined for 2-channel outputs
98 | tf.summary.scalar('recon_loss', recon_loss)
99 | tf.summary.scalar('recon_loss_l2', recon_loss_l2)
100 | tf.summary.scalar('loss', loss)
101 | return loss
102 |
103 |   #Takes loss and returns the Adam train step
104 | def optimize(self, loss, epsilon):
105 | train_step = tf.train.AdamOptimizer(self.flags.lr_vae, epsilon=epsilon).minimize(loss)
106 | return train_step
107 |
108 | def __cond_encoder(self, scope, input_tensor, bn_is_training, keep_prob, in_nch=1, reuse=False):
109 |
110 | lf = self.layer_factory
111 | input_tensor2d = tf.reshape(input_tensor, [self.flags.batch_size, \
112 | self.flags.img_height, self.flags.img_width, 1])
113 | nch = tensor_shape.as_dimension(input_tensor2d.get_shape()[3]).value
114 | nout = self.flags.hidden_size
115 |
116 | if(reuse == False):
117 | W_conv1 = lf.weight_variable(name='W_conv1_cond', shape=[5, 5, nch, 128])
118 | W_conv2 = lf.weight_variable(name='W_conv2_cond', shape=[5, 5, 128, 256])
119 | W_conv3 = lf.weight_variable(name='W_conv3_cond', shape=[5, 5, 256, 512])
120 | W_conv4 = lf.weight_variable(name='W_conv4_cond', shape=[4, 4, 512, self.flags.hidden_size])
121 |
122 | b_conv1 = lf.bias_variable(name='b_conv1_cond', shape=[128])
123 | b_conv2 = lf.bias_variable(name='b_conv2_cond', shape=[256])
124 | b_conv3 = lf.bias_variable(name='b_conv3_cond', shape=[512])
125 | b_conv4 = lf.bias_variable(name='b_conv4_cond', shape=[self.flags.hidden_size])
126 | else:
127 | W_conv1 = lf.weight_variable(name='W_conv1_cond')
128 | W_conv2 = lf.weight_variable(name='W_conv2_cond')
129 | W_conv3 = lf.weight_variable(name='W_conv3_cond')
130 | W_conv4 = lf.weight_variable(name='W_conv4_cond')
131 |
132 | b_conv1 = lf.bias_variable(name='b_conv1_cond')
133 | b_conv2 = lf.bias_variable(name='b_conv2_cond')
134 | b_conv3 = lf.bias_variable(name='b_conv3_cond')
135 | b_conv4 = lf.bias_variable(name='b_conv4_cond')
136 |
137 | conv1 = tf.nn.relu(lf.conv2d(input_tensor2d, W_conv1, stride=2) + b_conv1)
138 | conv1_norm = lf.batch_norm_aiuiuc_wrapper(conv1, bn_is_training, \
139 | 'BN1_cond', reuse_vars=reuse)
140 |
141 | conv2 = tf.nn.relu(lf.conv2d(conv1_norm, W_conv2, stride=2) + b_conv2)
142 | conv2_norm = lf.batch_norm_aiuiuc_wrapper(conv2, bn_is_training, \
143 | 'BN2_cond', reuse_vars=reuse)
144 |
145 | conv3 = tf.nn.relu(lf.conv2d(conv2_norm, W_conv3, stride=2) + b_conv3)
146 | conv3_norm = lf.batch_norm_aiuiuc_wrapper(conv3, bn_is_training, \
147 | 'BN3_cond', reuse_vars=reuse)
148 |
149 | conv4 = tf.nn.relu(lf.conv2d(conv3_norm, W_conv4, stride=2) + b_conv4)
150 | conv4_norm = lf.batch_norm_aiuiuc_wrapper(conv4, bn_is_training, \
151 | 'BN4_cond', reuse_vars=reuse)
152 |
153 | return conv1_norm, conv2_norm, conv3_norm, conv4_norm
154 |
155 | def __encoder(self, scope, input_tensor, bn_is_training, keep_prob, in_nch=2, reuse=False):
156 |
157 | lf = self.layer_factory
158 |
159 | input_tensor2d = tf.reshape(input_tensor, [self.flags.batch_size, \
160 | self.flags.img_height, self.flags.img_width, in_nch])
161 |
162 | nch = tensor_shape.as_dimension(input_tensor2d.get_shape()[3]).value
163 |
164 | if(reuse==False):
165 | W_conv1 = lf.weight_variable(name='W_conv1', shape=[5, 5, nch, 128])
166 | W_conv2 = lf.weight_variable(name='W_conv2', shape=[5, 5, 128, 256])
167 | W_conv3 = lf.weight_variable(name='W_conv3', shape=[5, 5, 256, 512])
168 | W_conv4 = lf.weight_variable(name='W_conv4', shape=[4, 4, 512, 1024])
169 | W_fc1 = lf.weight_variable(name='W_fc1', shape=[4*4*1024, self.flags.hidden_size * 2])
170 |
171 | b_conv1 = lf.bias_variable(name='b_conv1', shape=[128])
172 | b_conv2 = lf.bias_variable(name='b_conv2', shape=[256])
173 | b_conv3 = lf.bias_variable(name='b_conv3', shape=[512])
174 | b_conv4 = lf.bias_variable(name='b_conv4', shape=[1024])
175 | b_fc1 = lf.bias_variable(name='b_fc1', shape=[self.flags.hidden_size * 2])
176 | else:
177 | W_conv1 = lf.weight_variable(name='W_conv1')
178 | W_conv2 = lf.weight_variable(name='W_conv2')
179 | W_conv3 = lf.weight_variable(name='W_conv3')
180 | W_conv4 = lf.weight_variable(name='W_conv4')
181 | W_fc1 = lf.weight_variable(name='W_fc1')
182 |
183 | b_conv1 = lf.bias_variable(name='b_conv1')
184 | b_conv2 = lf.bias_variable(name='b_conv2')
185 | b_conv3 = lf.bias_variable(name='b_conv3')
186 | b_conv4 = lf.bias_variable(name='b_conv4')
187 | b_fc1 = lf.bias_variable(name='b_fc1')
188 |
189 | conv1 = tf.nn.relu(lf.conv2d(input_tensor2d, W_conv1, stride=2) + b_conv1)
190 | conv1_norm = lf.batch_norm_aiuiuc_wrapper(conv1, bn_is_training, \
191 | 'BN1', reuse_vars=reuse)
192 |
193 | conv2 = tf.nn.relu(lf.conv2d(conv1_norm, W_conv2, stride=2) + b_conv2)
194 | conv2_norm = lf.batch_norm_aiuiuc_wrapper(conv2, bn_is_training, \
195 | 'BN2', reuse_vars=reuse)
196 |
197 | conv3 = tf.nn.relu(lf.conv2d(conv2_norm, W_conv3, stride=2) + b_conv3)
198 | conv3_norm = lf.batch_norm_aiuiuc_wrapper(conv3, bn_is_training, \
199 | 'BN3', reuse_vars=reuse)
200 |
201 | conv4 = tf.nn.relu(lf.conv2d(conv3_norm, W_conv4, stride=2) + b_conv4)
202 | conv4_norm = lf.batch_norm_aiuiuc_wrapper(conv4, bn_is_training, \
203 | 'BN4', reuse_vars=reuse)
204 |
205 | dropout1 = tf.nn.dropout(conv4_norm, keep_prob)
206 | flatten1 = tf.reshape(dropout1, [-1, 4*4*1024])
207 |
208 | fc1 = tf.matmul(flatten1, W_fc1)+b_fc1
209 |
210 | return fc1
211 |
212 | def __decoder(self, scope, gfeat32, gfeat16, gfeat8, gfeat4, bn_is_training, inp_greylevel,\
213 | z1_sample, reuse=False):
214 |
215 | lf = self.layer_factory
216 |
217 | if(reuse == False):
218 | W_deconv1 = lf.weight_variable(name='W_deconv1', shape=[4, 4, self.flags.hidden_size, 1024])
219 | W_deconv2 = lf.weight_variable(name='W_deconv2', shape=[5, 5, 1024+512, 512])
220 | W_deconv3 = lf.weight_variable(name='W_deconv3', shape=[5, 5, 512+256, 256])
221 | W_deconv4 = lf.weight_variable(name='W_deconv4', shape=[5, 5, 256+128, 128])
222 | W_deconv5 = lf.weight_variable(name='W_deconv5', shape=[5, 5, 128, self.nch])
223 |
224 | b_deconv1 = lf.bias_variable(name='b_deconv1', shape=[1024])
225 | b_deconv2 = lf.bias_variable(name='b_deconv2', shape=[512])
226 | b_deconv3 = lf.bias_variable(name='b_deconv3', shape=[256])
227 | b_deconv4 = lf.bias_variable(name='b_deconv4', shape=[128])
228 | b_deconv5 = lf.bias_variable(name='b_deconv5', shape=[self.nch])
229 | else:
230 | W_deconv1 = lf.weight_variable(name='W_deconv1')
231 | W_deconv2 = lf.weight_variable(name='W_deconv2')
232 | W_deconv3 = lf.weight_variable(name='W_deconv3')
233 | W_deconv4 = lf.weight_variable(name='W_deconv4')
234 | W_deconv5 = lf.weight_variable(name='W_deconv5')
235 |
236 | b_deconv1 = lf.bias_variable(name='b_deconv1')
237 | b_deconv2 = lf.bias_variable(name='b_deconv2')
238 | b_deconv3 = lf.bias_variable(name='b_deconv3')
239 | b_deconv4 = lf.bias_variable(name='b_deconv4')
240 | b_deconv5 = lf.bias_variable(name='b_deconv5')
241 |
242 | inp_greylevel2d = tf.reshape(inp_greylevel, [self.flags.batch_size, \
243 | self.flags.img_height, self.flags.img_width, 1])
244 | input2d = tf.reshape(z1_sample, [self.flags.batch_size, 1, 1, self.flags.hidden_size])
245 | deconv1_upsamp = tf.image.resize_images(input2d, [4, 4])
246 |
247 | deconv1_upsamp_sc = tf.multiply(deconv1_upsamp, gfeat4)
248 |
249 | deconv1 = tf.nn.relu(lf.conv2d(deconv1_upsamp_sc, W_deconv1, stride=1) + b_deconv1)
250 | deconv1_norm = lf.batch_norm_aiuiuc_wrapper(deconv1, bn_is_training, \
251 | 'BN_deconv1', reuse_vars=reuse)
252 |
253 | deconv2_upsamp = tf.image.resize_images(deconv1_norm, [8, 8])
254 | deconv2_upsamp_sc = tf.concat([deconv2_upsamp, gfeat8], 3)
255 | deconv2 = tf.nn.relu(lf.conv2d(deconv2_upsamp_sc, W_deconv2, stride=1) + b_deconv2)
256 | deconv2_norm = lf.batch_norm_aiuiuc_wrapper(deconv2, bn_is_training, \
257 | 'BN_deconv2', reuse_vars=reuse)
258 |
259 | deconv3_upsamp = tf.image.resize_images(deconv2_norm, [16, 16])
260 | deconv3_upsamp_sc = tf.concat([deconv3_upsamp, gfeat16], 3)
261 | deconv3 = tf.nn.relu(lf.conv2d(deconv3_upsamp_sc, W_deconv3, stride=1) + b_deconv3)
262 | deconv3_norm = lf.batch_norm_aiuiuc_wrapper(deconv3, bn_is_training, \
263 | 'BN_deconv3', reuse_vars=reuse)
264 |
265 | deconv4_upsamp = tf.image.resize_images(deconv3_norm, [32, 32])
266 | deconv4_upsamp_sc = tf.concat([deconv4_upsamp, gfeat32], 3)
267 | deconv4 = tf.nn.relu(lf.conv2d(deconv4_upsamp_sc, W_deconv4, stride=1) + b_deconv4)
268 | deconv4_norm = lf.batch_norm_aiuiuc_wrapper(deconv4, bn_is_training, \
269 | 'BN_deconv4', reuse_vars=reuse)
270 |
271 | deconv5_upsamp = tf.image.resize_images(deconv4_norm, [64, 64])
272 | deconv5 = lf.conv2d(deconv5_upsamp, W_deconv5, stride=1) + b_deconv5
273 | deconv5_norm = lf.batch_norm_aiuiuc_wrapper(deconv5, bn_is_training, \
274 | 'BN_deconv5', reuse_vars=reuse)
275 |
276 | decoded_ch = tf.reshape(tf.tanh(deconv5_norm), \
277 | [self.flags.batch_size, self.flags.img_height*self.flags.img_width*self.nch])
278 |
279 | return decoded_ch
280 |
--------------------------------------------------------------------------------
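A minimal numpy sketch (not repo code) of the sampling and KL penalty used in inference() and loss() above: the encoder emits [mu, log(sigma^2)], and a latent sample is drawn with the reparameterization trick z = mu + eps*sigma. The batch size and values are stand-ins.

import numpy as np

hidden_size, batch = 64, 4
z1 = np.random.randn(batch, 2*hidden_size).astype('f')  # stand-in for the encoder output fc1

mu = z1[:, :hidden_size]                                # mean_train
sigma = np.sqrt(np.exp(z1[:, hidden_size:]))            # stddev_train
eps = np.random.randn(batch, hidden_size)               # epsilon_train
z_sample = mu + eps*sigma                               # z1_sample

epsilon = 1e-6                                          # numerical floor, as in loss()
kl = np.sum(0.5*(mu**2 + sigma**2 - np.log(np.maximum(sigma**2, epsilon)) - 1.0))
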
/vae/arch/vae_wo_skipconn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 | from layer_factory import layer_factory
5 | from tensorflow.python.framework import tensor_shape
6 |
7 | class vae_wo_skipconn:
8 |
9 | def __init__(self, flags, nch=2, condinference_flag=False):
10 | self.flags = flags
11 | self.nch = nch
12 | self.layer_factory = layer_factory()
13 | self.condinference_flag = condinference_flag
14 |
15 | #Returns handles to input placeholders
16 | def inputs(self):
17 | inp_img = tf.placeholder(tf.float32, [self.flags.batch_size, \
18 | self.nch * self.flags.img_height * self.flags.img_width])
19 | inp_greylevel = tf.placeholder(tf.float32, [self.flags.batch_size, \
20 | self.flags.img_height * self.flags.img_width])
21 | inp_latent = tf.placeholder(tf.float32, [self.flags.batch_size, \
22 | self.flags.hidden_size])
23 | is_training = tf.placeholder(tf.bool)
24 |     is_training_dec = tf.placeholder(tf.bool) #Defined but unused; not returned below
25 | keep_prob = tf.placeholder(tf.float32)
26 | kl_weight = tf.placeholder(tf.float32)
27 | lossweights = tf.placeholder(tf.float32, [self.flags.batch_size, \
28 | self.nch * self.flags.img_height * self.flags.img_width])
29 |
30 | return inp_img, inp_greylevel, inp_latent, is_training, keep_prob, kl_weight, lossweights
31 |
32 |   #Takes input placeholders, builds the inference graph and returns the network outputs
33 | def inference(self, inp_img, inp_greylevel, inp_latent, is_training, keep_prob):
34 |
35 | with tf.variable_scope('Inference', reuse=False) as sc:
36 | z1_train = self.__encoder(sc, inp_img, is_training, keep_prob, \
37 | in_nch=self.nch, reuse=False)
38 | epsilon_train = tf.truncated_normal([self.flags.batch_size, self.flags.hidden_size])
39 | mean_train = z1_train[:, :self.flags.hidden_size]
40 | stddev_train = tf.sqrt(tf.exp(z1_train[:, self.flags.hidden_size:]))
41 | z1_sample = mean_train + epsilon_train * stddev_train
42 | output_train = self.__decoder(sc, is_training, inp_greylevel, z1_sample, reuse=False)
43 |
44 | with tf.variable_scope('Inference', reuse=True) as sc:
45 | if(self.condinference_flag == False):
46 | z1_test = self.__encoder(sc, inp_img, is_training, keep_prob, \
47 | in_nch=self.nch, reuse=True)
48 | epsilon_test = tf.truncated_normal([self.flags.batch_size, self.flags.hidden_size])
49 | mean_test = z1_test[:, :self.flags.hidden_size]
50 | stddev_test = tf.sqrt(tf.exp(z1_test[:, self.flags.hidden_size:]))
51 | z1_sample = mean_test + epsilon_test * stddev_test
52 |         z1_sample = tf.stop_gradient(z1_sample) #Fix the encoder; stop_gradient returns a new tensor, so the result must be kept
53 | output_test = self.__decoder(sc, is_training, inp_greylevel, z1_sample, reuse=True)
54 | output_condinference = None
55 | else:
56 | mean_test = None
57 | stddev_test = None
58 | output_test = None
59 | output_condinference = self.__decoder(sc, is_training, inp_greylevel, inp_latent,\
60 | reuse=True)
61 |
62 | return mean_train, stddev_train, output_train, mean_test, stddev_test, \
63 | output_test, output_condinference
64 |
65 |   #Takes network outputs and computes the VAE (encoder+decoder) loss; see the principal-component sketch after this file
66 | def loss(self, target_tensor, op_tensor, mean, stddev, kl_weight, lossweights, epsilon=1e-6, \
67 | is_regression=True):
68 |
69 | kl_loss = tf.reduce_sum(0.5 * (tf.square(mean) + tf.square(stddev) \
70 | - tf.log(tf.maximum(tf.square(stddev), epsilon)) - 1.0))
71 |
72 | recon_loss_chi = tf.reduce_mean(tf.sqrt(tf.reduce_sum( \
73 | lossweights*tf.square(target_tensor-op_tensor), 1)), 0)
74 |
75 |     #Load principal components
76 | np_pcvec = np.transpose(np.load(os.path.join(self.flags.pc_dir, 'components.mat.npy')))
77 | np_pcvar = 1./np.load(os.path.join(self.flags.pc_dir, 'exp_variance.mat.npy'))
78 | np_pcvec = np_pcvec[:, :self.flags.pc_comp]
79 | np_pcvar = np_pcvar[:self.flags.pc_comp]
80 | pcvec = tf.constant(np_pcvec)
81 | pcvar = tf.constant(np_pcvar)
82 |
83 | projmat_op = tf.matmul(op_tensor, pcvec)
84 | projmat_target = tf.matmul(target_tensor, pcvec)
85 | weightmat = tf.tile(tf.reshape(pcvar, [1, self.flags.pc_comp]), [self.flags.batch_size, 1])
86 | loss_topk_pc = tf.reduce_mean(tf.reduce_sum(\
87 | tf.multiply(tf.square(projmat_op-projmat_target), weightmat), 1), 0)
88 |
89 | res_op = op_tensor
90 | res_target = target_tensor
91 | for npc in range(self.flags.pc_comp):
92 | pcvec_curr = tf.tile(tf.reshape(tf.transpose(pcvec[:, npc]), [1, -1]), \
93 | [self.flags.batch_size, 1])
94 | projop_curr = tf.tile(tf.reshape(projmat_op[:, npc], [self.flags.batch_size, 1]), \
95 | [1, self.nch * self.flags.img_height * self.flags.img_width])
96 |
97 | projtarget_curr = tf.tile(tf.reshape(projmat_target[:, npc], [self.flags.batch_size, 1]), \
98 | [1, self.nch * self.flags.img_height * self.flags.img_width])
99 |
100 | res_op = tf.subtract(res_op, tf.multiply(projop_curr, pcvec_curr))
101 | res_target = tf.subtract(res_target, tf.multiply(projtarget_curr, pcvec_curr))
102 |
103 | res_error = tf.reduce_sum(tf.square(res_op-res_target), 1)
104 | res_error_weight = tf.tile(tf.reshape(pcvar[self.flags.pc_comp-1], [1, 1]), [self.flags.batch_size, 1])
105 | loss_res_pc = tf.reduce_mean(tf.multiply(\
106 | tf.reshape(res_error, [self.flags.batch_size, 1]), res_error_weight))
107 |
108 | recon_loss = recon_loss_chi + (1e-1)*(loss_topk_pc + loss_res_pc)
109 |
110 | if(self.nch == 2):
111 | target_tensor2d = tf.reshape(target_tensor, [self.flags.batch_size, \
112 | self.flags.img_height, self.flags.img_width, self.nch])
113 | op_tensor2d = tf.reshape(op_tensor, [self.flags.batch_size, \
114 | self.flags.img_height, self.flags.img_width, self.nch])
115 | [n,w,h,c] = target_tensor2d.get_shape().as_list()
116 | dv = tf.square((target_tensor2d[:,1:,:h-1,:] - target_tensor2d[:,:w-1,:h-1,:])
117 | - (op_tensor2d[:,1:,:h-1,:] - op_tensor2d[:,:w-1,:h-1,:]))
118 | dh = tf.square((target_tensor2d[:,:w-1,1:,:] - target_tensor2d[:,:w-1,:h-1,:])
119 | - (op_tensor2d[:,:w-1,1:,:] - op_tensor2d[:,:w-1,:h-1,:]))
120 | grad_loss = tf.reduce_mean(tf.sqrt(tf.reduce_sum(dv+dh,[1,2,3])))
121 | recon_loss = recon_loss + (1e-3)*grad_loss
122 |
123 |
124 | loss = kl_weight*kl_loss + recon_loss
125 |
126 |     if(self.nch == 2): tf.summary.scalar('grad_loss', grad_loss) #grad_loss is only defined for 2-channel outputs
127 | tf.summary.scalar('kl_loss', kl_loss)
128 | tf.summary.scalar('recon_loss_chi', recon_loss_chi)
129 | tf.summary.scalar('recon_loss', recon_loss)
130 | return loss
131 |
132 |   #Takes loss and returns the Adam train step
133 | def optimize(self, loss, epsilon):
134 | train_step = tf.train.AdamOptimizer(self.flags.lr_vae, epsilon=epsilon).minimize(loss)
135 | return train_step
136 |
137 | def __encoder(self, scope, input_tensor, bn_is_training, keep_prob, in_nch=1, reuse=False):
138 |
139 | lf = self.layer_factory
140 |
141 | input_tensor2d = tf.reshape(input_tensor, [self.flags.batch_size, \
142 | self.flags.img_height, self.flags.img_width, in_nch])
143 |
144 | if(self.nch == 1 and reuse==False):
145 |       tf.summary.image('summ_input_tensor2d', input_tensor2d, max_outputs=10)
146 |
147 | nch = tensor_shape.as_dimension(input_tensor2d.get_shape()[3]).value
148 |
149 | if(reuse==False):
150 | W_conv1 = lf.weight_variable(name='W_conv1', shape=[5, 5, nch, 128])
151 | W_conv2 = lf.weight_variable(name='W_conv2', shape=[5, 5, 128, 256])
152 | W_conv3 = lf.weight_variable(name='W_conv3', shape=[5, 5, 256, 512])
153 | W_conv4 = lf.weight_variable(name='W_conv4', shape=[4, 4, 512, 1024])
154 | W_fc1 = lf.weight_variable(name='W_fc1', shape=[4*4*1024, self.flags.hidden_size * 2])
155 |
156 | b_conv1 = lf.bias_variable(name='b_conv1', shape=[128])
157 | b_conv2 = lf.bias_variable(name='b_conv2', shape=[256])
158 | b_conv3 = lf.bias_variable(name='b_conv3', shape=[512])
159 | b_conv4 = lf.bias_variable(name='b_conv4', shape=[1024])
160 | b_fc1 = lf.bias_variable(name='b_fc1', shape=[self.flags.hidden_size * 2])
161 | else:
162 | W_conv1 = lf.weight_variable(name='W_conv1')
163 | W_conv2 = lf.weight_variable(name='W_conv2')
164 | W_conv3 = lf.weight_variable(name='W_conv3')
165 | W_conv4 = lf.weight_variable(name='W_conv4')
166 | W_fc1 = lf.weight_variable(name='W_fc1')
167 |
168 | b_conv1 = lf.bias_variable(name='b_conv1')
169 | b_conv2 = lf.bias_variable(name='b_conv2')
170 | b_conv3 = lf.bias_variable(name='b_conv3')
171 | b_conv4 = lf.bias_variable(name='b_conv4')
172 | b_fc1 = lf.bias_variable(name='b_fc1')
173 |
174 | conv1 = tf.nn.relu(lf.conv2d(input_tensor2d, W_conv1, stride=2) + b_conv1)
175 | conv1_norm = lf.batch_norm_aiuiuc_wrapper(conv1, bn_is_training, \
176 | 'BN1', reuse_vars=reuse)
177 |
178 | conv2 = tf.nn.relu(lf.conv2d(conv1_norm, W_conv2, stride=2) + b_conv2)
179 | conv2_norm = lf.batch_norm_aiuiuc_wrapper(conv2, bn_is_training, \
180 | 'BN2', reuse_vars=reuse)
181 |
182 | conv3 = tf.nn.relu(lf.conv2d(conv2_norm, W_conv3, stride=2) + b_conv3)
183 | conv3_norm = lf.batch_norm_aiuiuc_wrapper(conv3, bn_is_training, \
184 | 'BN3', reuse_vars=reuse)
185 |
186 | conv4 = tf.nn.relu(lf.conv2d(conv3_norm, W_conv4, stride=2) + b_conv4)
187 | conv4_norm = lf.batch_norm_aiuiuc_wrapper(conv4, bn_is_training, \
188 | 'BN4', reuse_vars=reuse)
189 |
190 | dropout1 = tf.nn.dropout(conv4_norm, keep_prob)
191 | flatten1 = tf.reshape(dropout1, [-1, 4*4*1024])
192 |
193 | fc1 = tf.matmul(flatten1, W_fc1)+b_fc1
194 |
195 | return fc1
196 |
197 | def __decoder(self, scope, bn_is_training, inp_greylevel, z1_sample, reuse=False):
198 |
199 | lf = self.layer_factory
200 |
201 | if(reuse == False):
202 | W_deconv1 = lf.weight_variable(name='W_deconv1', shape=[4, 4, self.flags.hidden_size, 1024])
203 | W_deconv2 = lf.weight_variable(name='W_deconv2', shape=[5, 5, 1024, 512])
204 | W_deconv3 = lf.weight_variable(name='W_deconv3', shape=[5, 5, 514, 256])
205 | W_deconv4 = lf.weight_variable(name='W_deconv4', shape=[5, 5, 258, 128])
206 | W_deconv5 = lf.weight_variable(name='W_deconv5', shape=[5, 5, 128, self.nch])
207 |
208 | b_deconv1 = lf.bias_variable(name='b_deconv1', shape=[1024])
209 | b_deconv2 = lf.bias_variable(name='b_deconv2', shape=[512])
210 | b_deconv3 = lf.bias_variable(name='b_deconv3', shape=[256])
211 | b_deconv4 = lf.bias_variable(name='b_deconv4', shape=[128])
212 | b_deconv5 = lf.bias_variable(name='b_deconv5', shape=[self.nch])
213 | else:
214 | W_deconv1 = lf.weight_variable(name='W_deconv1')
215 | W_deconv2 = lf.weight_variable(name='W_deconv2')
216 | W_deconv3 = lf.weight_variable(name='W_deconv3')
217 | W_deconv4 = lf.weight_variable(name='W_deconv4')
218 | W_deconv5 = lf.weight_variable(name='W_deconv5')
219 |
220 | b_deconv1 = lf.bias_variable(name='b_deconv1')
221 | b_deconv2 = lf.bias_variable(name='b_deconv2')
222 | b_deconv3 = lf.bias_variable(name='b_deconv3')
223 | b_deconv4 = lf.bias_variable(name='b_deconv4')
224 | b_deconv5 = lf.bias_variable(name='b_deconv5')
225 |
226 | inp_greylevel2d = tf.reshape(inp_greylevel, [self.flags.batch_size, \
227 | self.flags.img_height, self.flags.img_width, 1])
228 | input_concat2d = tf.reshape(z1_sample, [self.flags.batch_size, 1, 1, self.flags.hidden_size])
229 |
230 | deconv1_upsamp = tf.image.resize_images(input_concat2d, [4, 4])
231 | deconv1 = tf.nn.relu(lf.conv2d(deconv1_upsamp, W_deconv1, stride=1) + b_deconv1)
232 | deconv1_norm = lf.batch_norm_aiuiuc_wrapper(deconv1, bn_is_training, \
233 | 'BN_deconv1', reuse_vars=reuse)
234 |
235 | deconv2_upsamp = tf.image.resize_images(deconv1_norm, [8, 8])
236 | deconv2 = tf.nn.relu(lf.conv2d(deconv2_upsamp, W_deconv2, stride=1) + b_deconv2)
237 | deconv2_norm = lf.batch_norm_aiuiuc_wrapper(deconv2, bn_is_training, \
238 | 'BN_deconv2', reuse_vars=reuse)
239 |
240 | deconv3_upsamp = tf.image.resize_images(deconv2_norm, [16, 16])
241 | grey_deconv3_dv, grey_deconv3_dh = self.__get_gradients(inp_greylevel2d, \
242 | shape=[16, 16])
243 | deconv3_upsamp_edge = tf.concat([deconv3_upsamp, grey_deconv3_dv, grey_deconv3_dh], 3)
244 | deconv3 = tf.nn.relu(lf.conv2d(deconv3_upsamp_edge, W_deconv3, stride=1) + b_deconv3)
245 | deconv3_norm = lf.batch_norm_aiuiuc_wrapper(deconv3, bn_is_training, \
246 | 'BN_deconv3', reuse_vars=reuse)
247 |
248 | deconv4_upsamp = tf.image.resize_images(deconv3_norm, [32, 32])
249 | grey_deconv4_dv, grey_deconv4_dh = self.__get_gradients(inp_greylevel2d, \
250 | shape=[32, 32])
251 | deconv4_upsamp_edge = tf.concat([deconv4_upsamp, grey_deconv4_dv, grey_deconv4_dh], 3)
252 | deconv4 = tf.nn.relu(lf.conv2d(deconv4_upsamp_edge, W_deconv4, stride=1) + b_deconv4)
253 | deconv4_norm = lf.batch_norm_aiuiuc_wrapper(deconv4, bn_is_training, \
254 | 'BN_deconv4', reuse_vars=reuse)
255 |
256 | deconv5_upsamp = tf.image.resize_images(deconv4_norm, [64, 64])
257 | deconv5 = lf.conv2d(deconv5_upsamp, W_deconv5, stride=1) + b_deconv5
258 | deconv5_norm = lf.batch_norm_aiuiuc_wrapper(deconv5, bn_is_training, \
259 | 'BN_deconv5', reuse_vars=reuse)
260 |
261 | decoded_ch = tf.reshape(tf.tanh(deconv5_norm), \
262 | [self.flags.batch_size, self.flags.img_height*self.flags.img_width*self.nch])
263 |
264 | return decoded_ch
265 |
266 | def __get_gradients(self, in_tensor2d, shape=None):
267 | if(shape is not None):
268 | in_tensor = tf.image.resize_images(in_tensor2d, [shape[0], shape[1]])
269 | else:
270 | in_tensor = in_tensor2d
271 | [n,w,h,c] = in_tensor.get_shape().as_list()
272 | dvert = in_tensor[:,1:,:h,:] - in_tensor[:,:w-1,:h,:]
273 | dvert_padded = tf.concat([tf.constant(0., shape=[n, 1, h, c]), dvert], 1)
274 | dhorz = in_tensor[:,:w,1:,:] - in_tensor[:,:w,:h-1,:]
275 | dhorz_padded = tf.concat([tf.constant(0., shape=[n, w, 1, c]), dhorz], 2)
276 | return dvert_padded, dhorz_padded
277 |
--------------------------------------------------------------------------------
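A minimal numpy sketch (not repo code) of the principal-component terms in vae_wo_skipconn.loss() above, assuming pcvec holds unit-norm components in its columns and pcvar their inverse explained variances, as loaded from pc_dir. Random stand-ins replace the saved arrays.

import numpy as np

dim, pc_comp, batch = 2*64*64, 20, 4
pcvec = np.linalg.qr(np.random.randn(dim, pc_comp))[0].astype('f')  # stand-in components
pcvar = np.linspace(1., 20., pc_comp).astype('f')                   # stand-in 1/variance weights
op = np.random.randn(batch, dim).astype('f')
target = np.random.randn(batch, dim).astype('f')

proj_op, proj_target = op @ pcvec, target @ pcvec
loss_topk_pc = np.mean(np.sum((proj_op - proj_target)**2 * pcvar, axis=1))

# residual outside the top-k subspace, weighted by the last component's weight (loss_res_pc)
res = (op - proj_op @ pcvec.T) - (target - proj_target @ pcvec.T)
loss_res_pc = np.mean(np.sum(res**2, axis=1) * pcvar[-1])
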
/vae/data_loaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aditya12agd5/divcolor/570155a062bcb6428353ec0e80343badbc290caf/vae/data_loaders/__init__.py
--------------------------------------------------------------------------------
/vae/data_loaders/lab_imageloader.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import glob
3 | import math
4 | import numpy as np
5 | import os
6 |
7 | class lab_imageloader:
8 |
9 | def __init__(self, data_directory, out_directory, listdir=None, shape=(64, 64), \
10 | subdir=False, ext='JPEG', outshape=(256, 256)):
11 |
12 | if(listdir == None):
13 | if(subdir==False):
14 | self.test_img_fns = glob.glob('%s/*.%s' % (data_directory, ext))
15 | else:
16 | self.test_img_fns = glob.glob('%s/*/*.%s' % (data_directory, ext))
17 | self.train_img_fns = []
18 | else:
19 | self.train_img_fns = []
20 | self.test_img_fns = []
21 | with open('%s/list.train.vae.txt' % listdir, 'r') as ftr:
22 | for img_fn in ftr:
23 | self.train_img_fns.append(img_fn.strip('\n'))
24 |
25 | with open('%s/list.test.vae.txt' % listdir, 'r') as fte:
26 | for img_fn in fte:
27 | self.test_img_fns.append(img_fn.strip('\n'))
28 |
29 | self.train_img_num = len(self.train_img_fns)
30 | self.test_img_num = len(self.test_img_fns)
31 | self.train_batch_head = 0
32 | self.test_batch_head = 0
33 | self.train_shuff_ids = np.random.permutation(len(self.train_img_fns))
34 | self.test_shuff_ids = range(len(self.test_img_fns))
35 | self.shape = shape
36 | self.outshape = outshape
37 | self.out_directory = out_directory
38 | self.lossweights = None
39 |
40 | countbins = 1./np.load('data/zhang_weights/prior_probs.npy')
41 | binedges = np.load('data/zhang_weights/ab_quantize.npy').reshape(2, 313)
42 | lossweights = {}
43 | for i in range(313):
44 | if binedges[0, i] not in lossweights:
45 | lossweights[binedges[0, i]] = {}
46 | lossweights[binedges[0,i]][binedges[1,i]] = countbins[i]
47 | self.binedges = binedges
48 | self.lossweights = lossweights
49 |
50 | def reset(self):
51 | self.train_batch_head = 0
52 | self.test_batch_head = 0
53 | self.train_shuff_ids = range(len(self.train_img_fns))
54 | self.test_shuff_ids = range(len(self.test_img_fns))
55 |
56 | def random_reset(self):
57 | self.train_batch_head = 0
58 | self.test_batch_head = 0
59 | self.train_shuff_ids = np.random.permutation(len(self.train_img_fns))
60 | self.test_shuff_ids = range(len(self.test_img_fns))
61 |
62 | def train_next_batch(self, batch_size, nch):
63 | batch = np.zeros((batch_size, nch*np.prod(self.shape)), dtype='f')
64 | batch_lossweights = np.ones((batch_size, nch*np.prod(self.shape)), dtype='f')
65 | batch_recon_const = np.zeros((batch_size, np.prod(self.shape)), dtype='f')
66 | batch_recon_const_outres = np.zeros((batch_size, np.prod(self.outshape)), dtype='f')
67 |
68 | if(self.train_batch_head + batch_size >= len(self.train_img_fns)):
69 | self.train_shuff_ids = np.random.permutation(len(self.train_img_fns))
70 | self.train_batch_head = 0
71 |
72 | for i_n, i in enumerate(range(self.train_batch_head, self.train_batch_head+batch_size)):
73 | currid = self.train_shuff_ids[i]
74 | img_large = cv2.imread(self.train_img_fns[currid])
75 |
76 | if(self.shape is not None):
77 | img = cv2.resize(img_large, (self.shape[0], self.shape[1]))
78 | img_outres = cv2.resize(img_large, (self.outshape[0], self.outshape[1]))
79 |
80 | img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
81 | img_lab_outres = cv2.cvtColor(img_outres, cv2.COLOR_BGR2LAB)
82 | batch_recon_const[i_n, ...] = ((img_lab[..., 0].reshape(-1)*1.)-128.)/128.
83 | batch_recon_const_outres[i_n, ...] = ((img_lab_outres[..., 0].reshape(-1)*1.)-128.)/128.
84 | batch[i_n, ...] = np.concatenate((((img_lab[..., 1].reshape(-1)*1.)-128.)/128.,
85 | ((img_lab[..., 2].reshape(-1)*1.)-128.)/128.), axis=0)
86 |
87 | if(self.lossweights is not None):
88 | batch_lossweights[i_n, ...] = self.__get_lossweights(batch[i_n, ...])
89 |
90 | self.train_batch_head = self.train_batch_head + batch_size
91 |
92 | return batch, batch_recon_const, batch_lossweights, batch_recon_const_outres
93 |
94 | def test_next_batch(self, batch_size, nch):
95 | batch = np.zeros((batch_size, nch*np.prod(self.shape)), dtype='f')
96 | batch_recon_const = np.zeros((batch_size, np.prod(self.shape)), dtype='f')
97 | batch_recon_const_outres = np.zeros((batch_size, np.prod(self.outshape)), dtype='f')
98 | batch_imgnames = []
99 | if(self.test_batch_head + batch_size > len(self.test_img_fns)):
100 | self.test_batch_head = 0
101 |
102 | for i_n, i in enumerate(range(self.test_batch_head, self.test_batch_head+batch_size)):
103 | if(i >= self.test_img_num):
104 | #Repeat first image to make up for incomplete last batch
105 | i = 0
106 | currid = self.test_shuff_ids[i]
107 | img_large = cv2.imread(self.test_img_fns[currid])
108 | batch_imgnames.append(self.test_img_fns[currid].split('/')[-1])
109 | if(self.shape is not None):
110 |         img = cv2.resize(img_large, (self.shape[0], self.shape[1]))
111 | img_outres = cv2.resize(img_large, (self.outshape[0], self.outshape[1]))
112 |
113 | img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
114 | img_lab_outres = cv2.cvtColor(img_outres, cv2.COLOR_BGR2LAB)
115 |
116 | batch_recon_const[i_n, ...] = ((img_lab[..., 0].reshape(-1)*1.)-128.)/128.
117 | batch_recon_const_outres[i_n, ...] = ((img_lab_outres[..., 0].reshape(-1)*1.)-128.)/128.
118 | batch[i_n, ...] = np.concatenate((((img_lab[..., 1].reshape(-1)*1.)-128.)/128.,
119 | ((img_lab[..., 2].reshape(-1)*1.)-128.)/128.), axis=0)
120 |
121 | self.test_batch_head = self.test_batch_head + batch_size
122 |
123 | return batch, batch_recon_const, batch_recon_const_outres, batch_imgnames
124 |
125 | def save_output_with_gt(self, net_op, gt, epoch, itr_id, prefix, batch_size, num_cols=8, net_recon_const=None):
126 | net_out_img = self.save_output(net_op, batch_size, num_cols=num_cols, net_recon_const=net_recon_const)
127 | gt_out_img = self.save_output(gt, batch_size, num_cols=num_cols, net_recon_const=net_recon_const)
128 | num_rows = np.int_(np.ceil((batch_size*1.)/num_cols))
129 | border_img = 255*np.ones((num_rows*self.outshape[0], 128, 3), dtype='uint8')
130 | out_fn_pred = '%s/%s_pred_%06d_%06d.png' % (self.out_directory, prefix, epoch, itr_id)
131 | print('[DEBUG] Writing output image: %s' % out_fn_pred)
132 | cv2.imwrite(out_fn_pred, np.concatenate((net_out_img, border_img, gt_out_img), axis=1))
133 |
134 | def save_output(self, net_op, batch_size, num_cols=8, net_recon_const=None):
135 | num_rows = np.int_(np.ceil((batch_size*1.)/num_cols))
136 | out_img = np.zeros((num_rows*self.outshape[0], num_cols*self.outshape[1], 3), dtype='uint8')
137 | img_lab = np.zeros((self.outshape[0], self.outshape[1], 3), dtype='uint8')
138 | c = 0
139 | r = 0
140 | for i in range(batch_size):
141 | if(i % num_cols == 0 and i > 0):
142 | r = r + 1
143 | c = 0
144 | img_lab[..., 0] = self.__get_decoded_img(net_recon_const[i, ...].reshape(self.outshape[0], self.outshape[1]))
145 | img_lab[..., 1] = self.__get_decoded_img(net_op[i, :np.prod(self.shape)].reshape(self.shape[0], self.shape[1]))
146 | img_lab[..., 2] = self.__get_decoded_img(net_op[i, np.prod(self.shape):].reshape(self.shape[0], self.shape[1]))
147 | img_rgb = cv2.cvtColor(img_lab, cv2.COLOR_LAB2BGR)
148 | out_img[r*self.outshape[0]:(r+1)*self.outshape[0], c*self.outshape[1]:(c+1)*self.outshape[1], ...] = img_rgb
149 | c = c+1
150 | return out_img
151 |
152 | def save_divcolor(self, net_op, gt, epoch, itr_id, prefix, batch_size, imgname, num_cols=8, net_recon_const=None):
153 | img_lab = np.zeros((self.outshape[0], self.outshape[1], 3), dtype='uint8')
154 | img_lab_mat = np.zeros((self.shape[0], self.shape[1], 2), dtype='uint8')
155 | if not os.path.exists('%s/%s' % (self.out_directory, imgname)):
156 | os.makedirs('%s/%s' % (self.out_directory, imgname))
157 | for i in range(batch_size):
158 | img_lab[..., 0] = self.__get_decoded_img(net_recon_const[i, ...].reshape(self.outshape[0], self.outshape[1]))
159 | img_lab[..., 1] = self.__get_decoded_img(net_op[i, :np.prod(self.shape)].reshape(self.shape[0], self.shape[1]))
160 | img_lab[..., 2] = self.__get_decoded_img(net_op[i, np.prod(self.shape):].reshape(self.shape[0], self.shape[1]))
161 | img_lab_mat[..., 0] = 128.*net_op[i, :np.prod(self.shape)].reshape(self.shape[0], self.shape[1])+128.
162 | img_lab_mat[..., 1] = 128.*net_op[i, np.prod(self.shape):].reshape(self.shape[0], self.shape[1])+128.
163 | img_rgb = cv2.cvtColor(img_lab, cv2.COLOR_LAB2BGR)
164 | out_fn_pred = '%s/%s/%s_%03d.png' % (self.out_directory, imgname, prefix, i)
165 | cv2.imwrite(out_fn_pred, img_rgb)
166 | # out_fn_mat = '%s/%s/%s_%03d.mat' % (self.out_directory, imgname, prefix, i)
167 | # np.save(out_fn_mat, img_lab_mat)
168 | img_lab[..., 0] = self.__get_decoded_img(net_recon_const[i, ...].reshape(self.outshape[0], self.outshape[1]))
169 | img_lab[..., 1] = self.__get_decoded_img(gt[0, :np.prod(self.shape)].reshape(self.shape[0], self.shape[1]))
170 | img_lab[..., 2] = self.__get_decoded_img(gt[0, np.prod(self.shape):].reshape(self.shape[0], self.shape[1]))
171 | img_lab_mat[..., 0] = 128.*gt[0, :np.prod(self.shape)].reshape(self.shape[0], self.shape[1])+128.
172 | img_lab_mat[..., 1] = 128.*gt[0, np.prod(self.shape):].reshape(self.shape[0], self.shape[1])+128.
173 | out_fn_pred = '%s/%s/gt.png' % (self.out_directory, imgname)
174 | img_rgb = cv2.cvtColor(img_lab, cv2.COLOR_LAB2BGR)
175 | cv2.imwrite(out_fn_pred, img_rgb)
176 | # out_fn_mat = '%s/%s/gt.mat' % (self.out_directory, imgname)
177 | # np.save(out_fn_mat, img_lab_mat)
178 |
179 | def __get_decoded_img(self, img_enc):
180 | img_dec = 128.*img_enc + 128
181 | img_dec[img_dec < 0.] = 0.
182 | img_dec[img_dec > 255.] = 255.
183 | return cv2.resize(np.uint8(img_dec), (self.outshape[0], self.outshape[1]))
184 |
185 | def __get_lossweights(self, img_vec):
186 | img_vec = img_vec*128.
187 | img_lossweights = np.zeros(img_vec.shape, dtype='f')
188 | img_vec_a = img_vec[:np.prod(self.shape)]
189 | binedges_a = self.binedges[0,...].reshape(-1)
190 | binid_a = [binedges_a.flat[np.abs(binedges_a-v).argmin()] for v in img_vec_a]
191 | img_vec_b = img_vec[np.prod(self.shape):]
192 | binedges_b = self.binedges[1,...].reshape(-1)
193 | binid_b = [binedges_b.flat[np.abs(binedges_b-v).argmin()] for v in img_vec_b]
194 | binweights = np.array([self.lossweights[v1][v2] for v1,v2 in zip(binid_a, binid_b)])
195 | img_lossweights[:np.prod(self.shape)] = binweights
196 | img_lossweights[np.prod(self.shape):] = binweights
197 | return img_lossweights
198 |
--------------------------------------------------------------------------------
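A minimal sketch (not repo code) of the Lab normalization used by lab_imageloader above: channels are mapped to roughly [-1, 1] with (x - 128)/128 and decoded back with 128*x + 128 plus clipping. A random image stands in for a file read with cv2.imread.

import cv2
import numpy as np

img = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)  # stand-in input image
img = cv2.resize(img, (64, 64))
lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)

grey = (lab[..., 0].reshape(-1).astype('f') - 128.)/128.        # -> batch_recon_const
ab = np.concatenate(((lab[..., 1].reshape(-1).astype('f') - 128.)/128.,
                     (lab[..., 2].reshape(-1).astype('f') - 128.)/128.))  # -> batch

dec = np.clip(128.*ab + 128., 0., 255.).astype('uint8')         # mirrors __get_decoded_img
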
/vae/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"]="0"
3 | import socket
4 | import sys
5 |
6 | import tensorflow as tf
7 | import numpy as np
8 | from data_loaders.lab_imageloader import lab_imageloader
9 | from arch.vae_skipconn import vae_skipconn as vae
10 | #from arch.vae_wo_skipconn import vae_wo_skipconn as vae
11 |
12 | from arch.network import network
13 |
14 | flags = tf.flags
15 |
16 | #Directory params
17 | flags.DEFINE_string("out_dir", "", "")
18 | flags.DEFINE_string("in_dir", "", "")
19 | flags.DEFINE_string("list_dir", "", "")
20 |
21 | #Dataset Params
22 | flags.DEFINE_integer("batch_size", 32, "batch size")
23 | flags.DEFINE_integer("updates_per_epoch", 1, "number of updates per epoch")
24 | flags.DEFINE_integer("log_interval", 1, "iterations between image logging")
25 | flags.DEFINE_integer("img_width", 64, "input image width")
26 | flags.DEFINE_integer("img_height", 64, "input image height")
27 |
28 | #Network Params
29 | flags.DEFINE_boolean("is_train", True, "Is training flag")
30 | flags.DEFINE_boolean("is_run_cvae", False, "Run CVAE sampling flag")
31 | flags.DEFINE_integer("hidden_size", 64, "size of the hidden VAE unit")
32 | flags.DEFINE_float("lr_vae", 1e-6, "learning rate for vae")
33 | flags.DEFINE_integer("max_epoch_vae", 10, "max epoch")
34 | flags.DEFINE_integer("pc_comp", 20, "number of principal components")
35 |
36 |
37 | FLAGS = flags.FLAGS
38 |
39 | def main():
40 | if(len(sys.argv) == 1):
41 | raise NameError('[ERROR] No dataset key')
42 | elif(sys.argv[1] == 'lfw'):
43 | FLAGS.updates_per_epoch = 380
44 | FLAGS.log_interval = 120
45 | FLAGS.out_dir = 'data/output/lfw/'
46 | FLAGS.list_dir = 'data/imglist/lfw/'
47 | FLAGS.pc_dir = 'data/pcomp/lfw/'
48 | else:
49 | raise NameError('[ERROR] Incorrect dataset key')
50 | data_loader = lab_imageloader(FLAGS.in_dir, \
51 | os.path.join(FLAGS.out_dir, 'images'), \
52 | listdir=FLAGS.list_dir)
53 |
54 | #Diverse Colorization
55 | nmix = 8
56 | num_batches = 31
57 | lv_mdn_test = np.load(os.path.join(FLAGS.out_dir, 'lv_color_mdn_test.mat.npy'))
58 |
59 | graph_divcolor = tf.Graph()
60 | with graph_divcolor.as_default():
61 | model_colorfield = vae(FLAGS, nch=2, condinference_flag=True)
62 | dnn = network(model_colorfield, data_loader, 2, FLAGS)
63 | dnn.run_divcolor(os.path.join(FLAGS.out_dir, 'models') , \
64 |       lv_mdn_test, num_batches=num_batches)
65 | if(FLAGS.is_run_cvae == True):
66 | dnn.run_cvae(os.path.join(FLAGS.out_dir, 'models') , \
67 | lv_mdn_test, num_batches=num_batches, num_repeat=8, num_cluster=5)
68 |
69 | if __name__ == "__main__":
70 | main()
71 |
--------------------------------------------------------------------------------
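When test.py is run with is_run_cvae set, run_cvae (in network.py above) reduces the num_repeat*batch_size sampled colorizations per image to num_cluster representatives with k-means. A minimal sketch (not repo code) of that reduction, with a random stand-in for the sample matrix:

import numpy as np
from sklearn.cluster import KMeans

num_cluster = 5
samples = np.random.rand(8*32, 2*64*64).astype('f')   # stand-in for output_all

kmeans = KMeans(n_clusters=num_cluster, random_state=0).fit(samples)
diverse = kmeans.cluster_centers_                     # one color field per cluster
print(diverse.shape)                                  # (5, 8192)
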
/vae/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"]="0"
3 | import socket
4 | import sys
5 |
6 | import tensorflow as tf
7 | import numpy as np
8 | from data_loaders.lab_imageloader import lab_imageloader
9 | from arch.vae_skipconn import vae_skipconn as vae
10 | #from arch.vae_wo_skipconn import vae_wo_skipconn as vae
11 |
12 | from arch.network import network as network
13 |
14 | flags = tf.flags
15 |
16 | #Directory params
17 | flags.DEFINE_string("out_dir", "", "")
18 | flags.DEFINE_string("in_dir", "", "")
19 | flags.DEFINE_string("list_dir", "", "")
20 |
21 | #Dataset Params
22 | flags.DEFINE_integer("batch_size", 32, "batch size")
23 | flags.DEFINE_integer("updates_per_epoch", 1, "number of updates per epoch")
24 | flags.DEFINE_integer("log_interval", 1, "iterations between image logging")
25 | flags.DEFINE_integer("img_width", 64, "input image width")
26 | flags.DEFINE_integer("img_height", 64, "input image height")
27 |
28 | #Network Params
29 | flags.DEFINE_boolean("is_train", True, "Is training flag")
30 | flags.DEFINE_integer("hidden_size", 64, "size of the hidden VAE unit")
31 | flags.DEFINE_float("lr_vae", 1e-6, "learning rate for vae")
32 | flags.DEFINE_integer("max_epoch_vae", 10, "max epoch")
33 | flags.DEFINE_integer("pc_comp", 20, "number of principal components")
34 |
35 |
36 | FLAGS = flags.FLAGS
37 |
38 | def main():
39 | if(len(sys.argv) == 1):
40 | raise NameError('[ERROR] No dataset key')
41 | elif(sys.argv[1] == 'lfw'):
42 | FLAGS.updates_per_epoch = 380
43 | FLAGS.log_interval = 120
44 | FLAGS.out_dir = 'data/output/lfw/'
45 | FLAGS.list_dir = 'data/imglist/lfw/'
46 | FLAGS.pc_dir = 'data/pcomp/lfw/'
47 | #add other datasets here
48 | else:
49 | raise NameError('[ERROR] Incorrect dataset key')
50 |
51 | data_loader = lab_imageloader(FLAGS.in_dir, \
52 | os.path.join(FLAGS.out_dir, 'images'), \
53 | listdir=FLAGS.list_dir)
54 |
55 | #Train colorfield VAE
56 | graph_vae = tf.Graph()
57 | with graph_vae.as_default():
58 | model_colorfield = vae(FLAGS, nch=2)
59 | dnn = network(model_colorfield, data_loader, 2, FLAGS)
60 | latent_vars_colorfield, latent_vars_colorfield_musigma_test = \
61 | dnn.train_vae(os.path.join(FLAGS.out_dir, 'models'), FLAGS.is_train)
62 |
63 | np.save(os.path.join(FLAGS.out_dir, 'lv_color_train.mat'), latent_vars_colorfield)
64 | np.save(os.path.join(FLAGS.out_dir, 'lv_color_test.mat'), latent_vars_colorfield_musigma_test)
65 |
66 | if __name__ == "__main__":
67 | main()
68 |
--------------------------------------------------------------------------------
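A rough sketch (assumed workflow; cf. run_lfw.sh) of the handoff between the stages above: train.py writes the VAE posterior statistics, the MDN stage under mdn/ is presumably what produces lv_color_mdn_test.mat.npy, and test.py reads that file for divcolor inference. Note that np.save appends '.npy' to each name.

import os
import numpy as np

out_dir = 'data/output/lfw'                                  # lfw default from the flags above
fn = os.path.join(out_dir, 'lv_color_mdn_test.mat.npy')      # loaded by test.py
if os.path.exists(fn):                                       # present only after the MDN stage
    print(np.load(fn).shape)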