├── LICENSE ├── PI.py ├── README.md ├── main.py └── util ├── .gitkeep ├── HandleIIDDataTFRecord.py ├── dataset_utils.py ├── layers.py ├── losses.py └── svhn.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /PI.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from layers import Layers 4 | from losses import LossFunctions 5 | 6 | class PI(object): 7 | 8 | def __init__(self, d, lr, lambda_pi_usl, use_pi): 9 | 10 | """ flags for each regularizor """ 11 | self.use_pi = use_pi 12 | 13 | """ data and external toolkits """ 14 | self.d = d # dataset manager 15 | self.ls = Layers() 16 | self.lf = LossFunctions(self.ls, d, self.encoder) 17 | 18 | """ placeholders defined outside""" 19 | self.lr = lr 20 | self.lambda_pi_usl = lambda_pi_usl 21 | 22 | def encoder(self, x, is_train=True, do_update_bn=True): 23 | 24 | """ https://arxiv.org/pdf/1610.02242.pdf """ 25 | 26 | if is_train: 27 | h = self.distort(x) 28 | h = self.ls.get_corrupted(x, 0.15) 29 | else: 30 | h = x 31 | 32 | scope = '1' 33 | h = self.ls.conv2d(scope+'_1', h, 128, activation=self.ls.lrelu) 34 | h = self.ls.conv2d(scope+'_2', h, 128, activation=self.ls.lrelu) 35 | h = self.ls.conv2d(scope+'_3', h, 128, activation=self.ls.lrelu) 36 | h = self.ls.max_pool(h) 37 | if is_train: h = tf.nn.dropout(h, 0.5) 38 | 39 | scope = '2' 40 | h = self.ls.conv2d(scope+'_1', h, 256, activation=self.ls.lrelu) 41 | h = self.ls.conv2d(scope+'_2', h, 256, activation=self.ls.lrelu) 42 | h = self.ls.conv2d(scope+'_3', h, 256, activation=self.ls.lrelu) 43 | h = self.ls.max_pool(h) 44 | if is_train: h = tf.nn.dropout(h, 0.5) 45 | 46 | scope = '3' 47 | h = self.ls.conv2d(scope+'_1', h, 512, activation=self.ls.lrelu) 48 | h = self.ls.conv2d(scope+'_2', h, 256, activation=self.ls.lrelu, filter_size=(1,1)) 49 | h = self.ls.conv2d(scope+'_3', h, 128, activation=self.ls.lrelu, filter_size=(1,1)) 50 | h = tf.reduce_mean(h, reduction_indices=[1, 2]) # Global average pooling 51 | h = self.ls.dense(scope, h, self.d.l) 52 | 53 | return h 54 | 55 | def build_graph_train(self, x_l, y_l, x, is_supervised=True): 56 | 57 | o = dict() # output 58 | loss = 0 59 | 60 | logit = self.encoder(x) 61 | 62 | with tf.variable_scope(tf.get_variable_scope(), reuse=True): 63 | logit_l = self.encoder(x_l, is_train=True, do_update_bn=False) # for pyx and vat loss computation 64 | 65 | """ Classification Loss """ 66 | o['Ly'], o['accur'] = self.lf.get_loss_pyx(logit_l, y_l) 67 | loss += o['Ly'] 68 | 69 | """ PI Model Loss """ 70 | if self.use_pi: 71 | with tf.variable_scope(tf.get_variable_scope(), reuse=True): 72 | _,_,o['Lp'] = self.lf.get_loss_pi(x, logit, is_train=True) 73 | loss += self.lambda_pi_usl * o['Lp'] 74 | else: 75 | o['Lp'] = tf.constant(0) 76 | 77 | """ set losses """ 78 | o['loss'] = loss 79 | self.o_train = o 80 | 81 | """ set optimizer """ 82 | optimizer = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5) 83 | #self.op = optimizer.minimize(loss) 84 | grads = optimizer.compute_gradients(loss) 85 | for i,(g,v) in enumerate(grads): 86 | if g is not None: 87 | #g = tf.Print(g, [g], "g %s = "%(v)) 88 | grads[i] = (tf.clip_by_norm(g,5),v) # clip gradients 89 | else: 90 | print('g is None:', v) 91 | v = tf.Print(v, [v], "v = ", summarize=10000) 92 | self.op = optimizer.apply_gradients(grads) # return train_op 93 | 94 | 95 | def build_graph_test(self, x_l, y_l ): 96 | 97 | o = dict() # output 98 | loss = 0 99 | 100 | logit_l = self.encoder(x_l, is_train=False, do_update_bn=False) # for pyx and vat loss computation 101 | 102 | """ classification loss """ 103 | o['Ly'], o['accur'] = self.lf.get_loss_pyx(logit_l, y_l) 104 | loss += o['Ly'] 105 | 106 | """ set losses """ 107 | o['loss'] = loss 108 | self.o_test = o 109 | 110 | def distort(self, x): 111 | 112 | _d = self.d 113 | 114 | def _distort(a_image): 115 | """ 116 | bounding_boxes: A Tensor of type float32. 117 | 3-D with shape [batch, N, 4] describing the N bounding boxes associated with the image. 118 | Bounding boxes are supplied and returned as [y_min, x_min, y_max, x_max] 119 | """ 120 | # shape: [1, 1, 4] 121 | bounding_boxes = tf.constant([[[1/10, 1/10, 9/10, 9/10]]], dtype=tf.float32) 122 | 123 | begin, size, _ = tf.image.sample_distorted_bounding_box( 124 | (_d.h, _d.w, _d.c), bounding_boxes, 125 | min_object_covered=(8.5/10.0), 126 | aspect_ratio_range=[7.0/10.0, 10.0/7.0]) 127 | 128 | a_image = tf.slice(a_image, begin, size) 129 | """ for the purpose of distorting not use tf.image.resize_image_with_crop_or_pad under """ 130 | a_image = tf.image.resize_images(a_image, [_d.h, _d.w]) 131 | """ due to the size of channel returned from tf.image.resize_images is not being given, 132 | specify it manually. """ 133 | a_image = tf.reshape(a_image, [_d.h, _d.w, _d.c]) 134 | return a_image 135 | 136 | """ process batch times in parallel """ 137 | return tf.map_fn( _distort, x) 138 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PI 2 | Very simple TensorFlow implementation of NVIDIA’s Π Model from [“Temporal Ensembling for Semi-Supervised Learning”](https://arxiv.org/pdf/1610.02242.pdf) (ICLR 2017) on the SVHN classification task. 3 | 4 |
5 |   6 |
7 | 8 | ## Usage 9 | 10 | ```python main.py``` 11 | 12 | 13 | ## Useful Resources 14 | 15 | - [Original Authors’ implementation with Theano and Lasagne](https://github.com/ericjang/draw) 16 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import sys, os, time 4 | 5 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/util') 6 | from HandleIIDDataTFRecord import HandleIIDDataTFRecord 7 | from PI import PI 8 | 9 | DO_TRAIN = True 10 | DO_TEST = True 11 | USE_PI = True 12 | 13 | tf.flags.DEFINE_string("dataset", "SVHN", "MNIST / CIFAR10 / SVHN / CharImages") 14 | tf.flags.DEFINE_boolean("restore", False, "restore from the last check point") 15 | tf.flags.DEFINE_string("dir_logs", "./out/", "") 16 | FLAGS = tf.flags.FLAGS 17 | 18 | if not DO_TRAIN and not FLAGS.restore: 19 | print('[WARN] FLAGS.restore is set to True compulsorily') 20 | FLAGS.restore = True 21 | 22 | N_EPOCHS = 100 23 | 24 | FILE_OF_CKPT = os.path.join(FLAGS.dir_logs,"drawmodel.ckpt") 25 | 26 | # learning rate decay 27 | STARTER_LEARNING_RATE = 1e-3 28 | DECAY_AFTER = 2 29 | DECAY_INTERVAL = 2 30 | DECAY_FACTOR = 0.97 31 | 32 | def get_lambda_pi_usl(epoch): 33 | if USE_PI: 34 | import math 35 | def _rampup(epoch): 36 | """ https://github.com/smlaine2/tempens/blob/master/train.py """ 37 | PI_RAMPUP_LENGTH = 80 # there seems to be no other option than 80, according to the paper. 38 | if epoch < PI_RAMPUP_LENGTH: 39 | p = 1.0 - (max(0.0, float(epoch)) / float(PI_RAMPUP_LENGTH)) 40 | return math.exp(-p*p*5.0) 41 | else: 42 | return 1.0 43 | 44 | PI_W_MAX = 100 45 | _pi_m_n = d.n_labeled / d.n_train 46 | return _rampup(epoch) * PI_W_MAX * _pi_m_n 47 | else: 48 | return 0.0 49 | 50 | 51 | def test(): 52 | accur = [] 53 | for i in range(d.n_batches_test): 54 | r = sess.run(m.o_test) 55 | accur.append( r['accur']) 56 | return np.mean(accur, axis=0) 57 | 58 | with tf.Graph().as_default() as g: 59 | 60 | ########################################### 61 | """ Load Data """ 62 | ########################################### 63 | BATCH_SIZE = 100 64 | d = HandleIIDDataTFRecord(FLAGS.dataset, BATCH_SIZE) 65 | (x_train, y_train), x, (x_test, y_test) = d.get_tfrecords() 66 | 67 | ########################################### 68 | """ Build Model Graphs """ 69 | ########################################### 70 | lr = tf.placeholder(tf.float32, shape=[], name="learning_rate") 71 | lambda_pi_usl = tf.placeholder(tf.float32, shape=(), name="lambda_pi_usl") 72 | 73 | with tf.variable_scope("watashinomodel") as scope: 74 | 75 | m = PI( d, lr, lambda_pi_usl, use_pi=USE_PI) 76 | 77 | print('... now building the graph for training.') 78 | m.build_graph_train(x_train,y_train,x) # the third one is a dummy for future 79 | scope.reuse_variables() 80 | if DO_TEST : 81 | print('... now building the graph for test.') 82 | m.build_graph_test(x_test,y_test) 83 | 84 | 85 | ########################################### 86 | """ Init """ 87 | ########################################### 88 | init_op = tf.global_variables_initializer() 89 | for v in tf.all_variables(): print("[DEBUG] %s : %s" % (v.name,v.get_shape())) 90 | 91 | saver = tf.train.Saver() 92 | config = tf.ConfigProto() 93 | config.gpu_options.allocator_type = 'BFC' 94 | sess = tf.Session(config = config) 95 | 96 | _lr, ratio = STARTER_LEARNING_RATE, 1.0 97 | 98 | if FLAGS.restore: 99 | print("... restore from the last check point.") 100 | saver.restore(sess, FILE_OF_CKPT) 101 | else: 102 | sess.run(init_op) 103 | 104 | merged = tf.summary.merge_all() 105 | tf.get_default_graph().finalize() 106 | 107 | ########################################### 108 | """ Training Loop """ 109 | ########################################### 110 | if DO_TRAIN: 111 | print('... start training') 112 | tf.train.start_queue_runners(sess=sess) 113 | for epoch in range(1, N_EPOCHS+1): 114 | 115 | loss, accur = [],[] 116 | for i in range(d.n_batches_train): 117 | 118 | feed_dict = {lr:_lr, lambda_pi_usl:get_lambda_pi_usl(epoch)} 119 | 120 | """ do update """ 121 | time_start = time.time() 122 | _, r, op, current_lr = sess.run([merged, m.o_train, m.op, m.lr], feed_dict=feed_dict) 123 | elapsed_time = time.time() - time_start 124 | 125 | loss.append(r['loss']) 126 | accur.append(r['accur']) 127 | 128 | if i % 100 == 0 and i != 0: 129 | 130 | print(" iter:%2d, loss: %.5f, accr: %.5f, Ly: %s, Lp: %s, time:%.3f" % \ 131 | (i, np.mean(np.array(loss)), np.mean(np.array(accur)), r['Ly'], r['Lp'], elapsed_time )) 132 | 133 | """ test """ 134 | if DO_TEST and epoch % 1 == 0: 135 | time_start = time.time() 136 | accur = test() 137 | elapsed_time = time.time() - time_start 138 | print("epoch:%d, accur: %s, time:%.3f" % (epoch, accur, elapsed_time )) 139 | 140 | """ save """ 141 | if epoch % 1 == 0: 142 | print("Model saved in file: %s" % saver.save(sess,FILE_OF_CKPT)) 143 | 144 | 145 | """ learning rate decay""" 146 | if (epoch % DECAY_INTERVAL == 0) and (epoch > DECAY_AFTER): 147 | ratio *= DECAY_FACTOR 148 | _lr = STARTER_LEARNING_RATE * ratio 149 | print('lr decaying is scheduled. epoch:%d, lr:%f <= %f' % ( epoch, _lr, current_lr)) 150 | 151 | 152 | sess.close() 153 | -------------------------------------------------------------------------------- /util/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /util/HandleIIDDataTFRecord.py: -------------------------------------------------------------------------------- 1 | import sys, os, time 2 | 3 | class HandleIIDDataTFRecord(object): 4 | 5 | def __init__(self, dataset, batch_size, is_debug=False): 6 | 7 | self.dataset = dataset 8 | self.batch_size = batch_size 9 | self.is_debug = is_debug 10 | 11 | if self.dataset == 'SVHN': 12 | from svhn import N_LABELED 13 | n_train, n_test, n_labeled = 73257, 26032, N_LABELED 14 | _h, _w, _c = 32,32,3 15 | _img_size = _h*_w*_c 16 | _l = 10 17 | _is_3d = True 18 | else: 19 | sys.exit('[ERROR] not implemented yet') 20 | 21 | self.h = _h 22 | self.w = _w 23 | self.c = _c 24 | self.l = _l 25 | self.is_3d = _is_3d 26 | self.img_size = _img_size 27 | self.n_train = n_train 28 | self.n_test = n_test 29 | self.n_labeled = n_labeled 30 | self.n_batches_train = int(n_train/batch_size) 31 | self.n_batches_test = int(n_test/batch_size) 32 | 33 | ######################################## 34 | """ inputs """ 35 | ######################################## 36 | def get_tfrecords(self): 37 | 38 | """ 39 | xtrain: all records 40 | *_l : partial records 41 | """ 42 | if self.dataset =='SVHN': 43 | from svhn import inputs, unlabeled_inputs 44 | xtrain_l, ytrain_l = inputs(batch_size=self.batch_size, train=True, validation=False, shuffle=True) 45 | xtrain = unlabeled_inputs(batch_size=self.batch_size, validation=False, shuffle=True) 46 | xtest , ytest = inputs(batch_size=self.batch_size, train=False, validation=False, shuffle=True) 47 | else: 48 | sys.exit('[ERROR] not implemented yet') 49 | return (xtrain_l, ytrain_l), xtrain, (xtest , ytest) 50 | 51 | 52 | if __name__ == '__main__': 53 | 54 | BATCH_SIZE = 20 55 | 56 | d = HandleIIDDataTFRecord( 'SVHN', BATCH_SIZE, is_debug=True) 57 | print(d.get_tfrecords()) 58 | 59 | sys.exit('saigo') 60 | -------------------------------------------------------------------------------- /util/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os, sys, pickle 3 | import numpy as np 4 | from scipy import linalg 5 | 6 | #FLAGS = tf.app.flags.FLAGS 7 | #tf.app.flags.DEFINE_bool('aug_trans', False, "") 8 | #tf.app.flags.DEFINE_bool('aug_flip', False, "") 9 | 10 | AUG_TRANS = False 11 | AUG_FLIP = False 12 | 13 | def unpickle(file): 14 | fp = open(file, 'rb') 15 | if sys.version_info.major == 2: 16 | data = pickle.load(fp) 17 | elif sys.version_info.major == 3: 18 | data = pickle.load(fp, encoding='latin-1') 19 | fp.close() 20 | return data 21 | 22 | 23 | def ZCA(data, reg=1e-6): 24 | mean = np.mean(data, axis=0) 25 | mdata = data - mean 26 | sigma = np.dot(mdata.T, mdata) / mdata.shape[0] 27 | U, S, V = linalg.svd(sigma) 28 | components = np.dot(np.dot(U, np.diag(1 / np.sqrt(S) + reg)), U.T) 29 | whiten = np.dot(data - mean, components.T) 30 | return components, mean, whiten 31 | 32 | 33 | def _int64_feature(value): 34 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 35 | 36 | 37 | def _bytes_feature(value): 38 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 39 | 40 | 41 | def convert_images_and_labels(images, labels, filepath): 42 | 43 | print('[DEBUG] inputs shape:', images.shape, labels.shape) # (4000, 3072) (4000,) 44 | num_examples = labels.shape[0] 45 | if images.shape[0] != num_examples: 46 | raise ValueError("Images size %d does not match label size %d." % 47 | (images.shape[0], num_examples)) 48 | print('Writing', filepath) 49 | writer = tf.python_io.TFRecordWriter(filepath) 50 | for index in range(num_examples): 51 | image = images[index].tolist() 52 | image_feature = tf.train.Feature(float_list=tf.train.FloatList(value=image)) 53 | #print('[DEBUG] image_feature:', image_feature) # float_list { value: xxx},...} 54 | example = tf.train.Example(features=tf.train.Features(feature={ 55 | 'height': _int64_feature(32), 56 | 'width': _int64_feature(32), 57 | 'depth': _int64_feature(3), 58 | 'label': _int64_feature(int(labels[index])), 59 | 'image': image_feature})) 60 | writer.write(example.SerializeToString()) 61 | writer.close() 62 | 63 | 64 | def read(filename_queue): 65 | reader = tf.TFRecordReader() 66 | print('filename_queue',filename_queue) 67 | _, serialized_example = reader.read(filename_queue) 68 | features = tf.parse_single_example( 69 | serialized_example, 70 | # Defaults are not specified since both keys are required. 71 | features={ 72 | 'image': tf.FixedLenFeature([3072], tf.float32), 73 | 'label': tf.FixedLenFeature([], tf.int64), 74 | }) 75 | 76 | # Convert label from a scalar uint8 tensor to an int32 scalar. 77 | image = features['image'] 78 | image = tf.reshape(image, [32, 32, 3]) 79 | label = tf.one_hot(tf.cast(features['label'], tf.int32), 10) 80 | return image, label 81 | 82 | 83 | def generate_batch( 84 | example, 85 | min_queue_examples, 86 | batch_size, shuffle): 87 | """ 88 | Arg: 89 | list of tensors. 90 | """ 91 | num_preprocess_threads = 1 92 | 93 | if shuffle: 94 | ret = tf.train.shuffle_batch( 95 | example, 96 | batch_size=batch_size, 97 | num_threads=num_preprocess_threads, 98 | capacity=min_queue_examples + 5 * batch_size, 99 | min_after_dequeue=min_queue_examples) 100 | else: 101 | ret = tf.train.batch( 102 | example, 103 | batch_size=batch_size, 104 | num_threads=num_preprocess_threads, 105 | allow_smaller_final_batch=True, 106 | capacity=min_queue_examples + 5 * batch_size) 107 | 108 | return ret 109 | 110 | 111 | def transform(image): 112 | image = tf.reshape(image, [32, 32, 3]) 113 | if AUG_TRANS or AUG_FLIP: 114 | print("augmentation") 115 | if AUG_TRANS: 116 | image = tf.pad(image, [[2, 2], [2, 2], [0, 0]]) 117 | image = tf.random_crop(image, [32, 32, 3]) 118 | if AUG_FLIP: 119 | image = tf.image.random_flip_left_right(image) 120 | return image 121 | 122 | 123 | def generate_filename_queue(filenames, data_dir, num_epochs=None): 124 | print("filenames in queue:", filenames) 125 | for i in range(len(filenames)): 126 | filenames[i] = os.path.join(data_dir, filenames[i]) 127 | return tf.train.string_input_producer(filenames, num_epochs=num_epochs) 128 | 129 | 130 | -------------------------------------------------------------------------------- /util/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import sys 6 | 7 | class Layers(object): 8 | 9 | def __init__(self): 10 | self.do_share = False 11 | 12 | def set_do_share(self, flag): 13 | self.do_share = flag 14 | 15 | def W( self, W_shape, W_name='W', W_init=None): 16 | if W_init is None: 17 | W_initializer = tf.contrib.layers.xavier_initializer() 18 | else: 19 | W_initializer = tf.constant_initializer(W_init) 20 | 21 | return tf.get_variable(W_name, W_shape, initializer=W_initializer) 22 | 23 | def Wb( self, W_shape, b_shape, W_name='W', b_name='b', W_init=None, b_init=0.1): 24 | 25 | W = self.W(W_shape, W_name=W_name, W_init=None) 26 | b = tf.get_variable(b_name, b_shape, initializer=tf.constant_initializer(b_init)) 27 | 28 | def _summaries(var): 29 | """Attach a lot of summaries to a Tensor (for TensorBoard visualization).""" 30 | with tf.name_scope('summaries'): 31 | mean = tf.reduce_mean(var) 32 | tf.summary.scalar('mean', mean) 33 | with tf.name_scope('stddev'): 34 | stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) 35 | tf.summary.scalar('stddev', stddev) 36 | tf.summary.scalar('max', tf.reduce_max(var)) 37 | tf.summary.scalar('min', tf.reduce_min(var)) 38 | tf.summary.histogram('histogram', var) 39 | _summaries(W) 40 | _summaries(b) 41 | 42 | return W, b 43 | 44 | 45 | def denseV2( self, scope, x, output_dim, activation=None): 46 | return tf.contrib.layers.fully_connected( x, output_dim, activation_fn=activation, reuse=self.do_share, scope=scope) 47 | 48 | def dense( self, scope, x, output_dim, activation=None): 49 | if len(x.get_shape()) == 2: # 1d 50 | pass 51 | elif len(x.get_shape()) == 4: # cnn as NHWC 52 | #x = tf.reshape(x, [tf.shape(x)[0], -1]) # flatten 53 | x = tf.reshape(x, [x.get_shape().as_list()[0], -1]) # flatten 54 | #x = tf.reshape(x, [tf.cast(x.get_shape()[0], tf.int32), -1]) # flatten 55 | with tf.variable_scope(scope,reuse=self.do_share): W, b = self.Wb([x.get_shape()[1], output_dim], [output_dim]) 56 | #with tf.variable_scope(scope,reuse=self.do_share): W, b = self.Wb([x.get_shape()[1], output_dim], [output_dim]) 57 | o = tf.matmul(x, W) + b 58 | return o if activation is None else activation(o) 59 | 60 | def lrelu(self, x, a=0.1): 61 | if a < 1e-16: 62 | return tf.nn.relu(x) 63 | else: 64 | return tf.maximum(x, a * x) 65 | 66 | def avg_pool(self, x, ksize=2, stride=2): 67 | return tf.nn.avg_pool(x, ksize=[1, ksize, ksize, 1], strides=[1, stride, stride, 1], padding='SAME') 68 | 69 | def max_pool(self, x, ksize=2, stride=2): 70 | return tf.nn.max_pool(x, ksize=[1, ksize, ksize, 1], strides=[1, stride, stride, 1], padding='SAME') 71 | 72 | def conv2d( self, scope, x, out_c, filter_size=(3,3), strides=(1,1,1,1), padding="SAME", activation=None): 73 | """ 74 | x: [BATCH_SIZE, in_height, in_width, in_channels] 75 | filter : [filter_height, filter_width, in_channels, out_channels] 76 | """ 77 | filter = [filter_size[0], filter_size[1], int(x.get_shape()[3]), out_c] 78 | with tf.variable_scope(scope,reuse=self.do_share): W, b = self.Wb(filter, [out_c]) 79 | o = tf.nn.conv2d(x, W, strides, padding) + b 80 | return o if activation is None else activation(o) 81 | 82 | ########################################### 83 | """ Softmax """ 84 | ########################################### 85 | def softmax( self, scope, input, size): 86 | if input.get_shape()[1] != size: 87 | print("softmax w/ fc:", input.get_shape()[1], '->', size) 88 | return self.dense(scope, input, size, tf.nn.softmax) 89 | else: 90 | print("softmax w/o fc") 91 | return tf.nn.softmax(input) 92 | 93 | ########################################### 94 | """ Noise/Denose Function """ 95 | ########################################### 96 | def get_corrupted(self, x, noise_std=.10): 97 | return self.sampler( x, noise_std) 98 | 99 | def epsilon( self, _shape, _stddev=1.): 100 | return tf.truncated_normal(_shape, mean=0, stddev=_stddev) 101 | 102 | def sampler( self, mu, sigma): 103 | """ 104 | mu,sigma : (BATCH_SIZE, z_size) 105 | """ 106 | return mu + sigma*self.epsilon( tf.shape(mu) ) 107 | -------------------------------------------------------------------------------- /util/losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import sys 4 | 5 | 6 | eps = 1e-8 7 | 8 | class LossFunctions(object): 9 | 10 | def __init__(self, layers, dataset, encoder): 11 | 12 | self.ls = layers 13 | self.d = dataset 14 | self.encoder = encoder 15 | self.reconst_pixel_log_stdv = tf.get_variable("reconst_pixel_log_stdv", initializer=tf.constant(0.0)) 16 | 17 | def get_loss_pyx(self, logit, y): 18 | 19 | loss = self._ce(logit, y) 20 | accur = self._accuracy(logit, y) 21 | return loss, accur 22 | 23 | def get_loss_pi(self, x, logit_real, is_train): 24 | logit_real = tf.stop_gradient(logit_real) 25 | logit_virtual = self.encoder(x, is_train=is_train, do_update_bn=False) 26 | loss = tf.sqrt(tf.reduce_mean(tf.square(tf.subtract(logit_real, logit_virtual))) + eps) 27 | return logit_real, logit_virtual, loss 28 | 29 | 30 | """ https://github.com/takerum/vat_tf/blob/master/layers.py """ 31 | def _ce(self, logit, y): 32 | return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y)) 33 | 34 | def _accuracy(self, logit, y): 35 | pred = tf.argmax(logit, 1) 36 | true = tf.argmax(y, 1) 37 | return tf.reduce_mean(tf.to_float(tf.equal(pred, true))) 38 | -------------------------------------------------------------------------------- /util/svhn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | from scipy.io import loadmat 8 | 9 | import numpy as np 10 | from scipy import linalg 11 | import glob 12 | import pickle 13 | 14 | from six.moves import xrange # pylint: disable=redefined-builtin 15 | from six.moves import urllib 16 | 17 | import tensorflow as tf 18 | from dataset_utils import * 19 | 20 | DATA_URL_TRAIN = 'http://ufldl.stanford.edu/housenumbers/train_32x32.mat' 21 | DATA_URL_TEST = 'http://ufldl.stanford.edu/housenumbers/test_32x32.mat' 22 | 23 | N_LABELED = 4000 24 | DATASET_SEED = 1 25 | DATA_DIR = 'PATH_TO_DIR_OF_SVHN_IN_YOUR_ENVIRONMENT' 26 | 27 | FLAGS = tf.app.flags.FLAGS 28 | tf.app.flags.DEFINE_integer('num_valid_examples', 1000, "The number of validation examples") 29 | 30 | NUM_EXAMPLES_TRAIN = 73257 31 | NUM_EXAMPLES_TEST = 26032 32 | 33 | 34 | def maybe_download_and_extract(): 35 | if not os.path.exists(DATA_DIR): 36 | os.makedirs(DATA_DIR) 37 | filepath_train_mat = os.path.join(DATA_DIR, 'train_32x32.mat') 38 | filepath_test_mat = os.path.join(DATA_DIR, 'test_32x32.mat') 39 | if not os.path.exists(filepath_train_mat) or not os.path.exists(filepath_test_mat): 40 | def _progress(count, block_size, total_size): 41 | sys.stdout.write('\r>> Downloading %.1f%%' % (float(count * block_size) / float(total_size) * 100.0)) 42 | sys.stdout.flush() 43 | 44 | urllib.request.urlretrieve(DATA_URL_TRAIN, filepath_train_mat, _progress) 45 | urllib.request.urlretrieve(DATA_URL_TEST, filepath_test_mat, _progress) 46 | 47 | # Training set 48 | print("Loading training data...") 49 | print("Preprocessing training data...") 50 | train_data = loadmat(DATA_DIR + '/train_32x32.mat') 51 | # geosada 170717 52 | #train_x = (-127.5 + train_data['X']) / 255. 53 | train_x = (train_data['X']) / 255. 54 | train_x = train_x.transpose((3, 0, 1, 2)) 55 | train_x = train_x.reshape([train_x.shape[0], -1]) 56 | train_y = train_data['y'].flatten().astype(np.int32) 57 | train_y[train_y == 10] = 0 58 | 59 | # Test set 60 | print("Loading test data...") 61 | test_data = loadmat(DATA_DIR + '/test_32x32.mat') 62 | # geosada 170717 63 | #test_x = (-127.5 + test_data['X']) / 255. 64 | test_x = (test_data['X']) / 255. 65 | test_x = test_x.transpose((3, 0, 1, 2)) 66 | test_x = test_x.reshape((test_x.shape[0], -1)) 67 | test_y = test_data['y'].flatten().astype(np.int32) 68 | test_y[test_y == 10] = 0 69 | 70 | np.save('{}/train_images'.format(DATA_DIR), train_x) 71 | np.save('{}/train_labels'.format(DATA_DIR), train_y) 72 | np.save('{}/test_images'.format(DATA_DIR), test_x) 73 | np.save('{}/test_labels'.format(DATA_DIR), test_y) 74 | 75 | 76 | def load_svhn(): 77 | maybe_download_and_extract() 78 | train_images = np.load('{}/train_images.npy'.format(DATA_DIR)).astype(np.float32) 79 | train_labels = np.load('{}/train_labels.npy'.format(DATA_DIR)).astype(np.float32) 80 | test_images = np.load('{}/test_images.npy'.format(DATA_DIR)).astype(np.float32) 81 | test_labels = np.load('{}/test_labels.npy'.format(DATA_DIR)).astype(np.float32) 82 | return (train_images, train_labels), (test_images, test_labels) 83 | 84 | 85 | def prepare_dataset(): 86 | (train_images, train_labels), (test_images, test_labels) = load_svhn() 87 | dirpath = os.path.join(DATA_DIR, 'seed' + str(DATASET_SEED)) 88 | if not os.path.exists(dirpath): 89 | os.makedirs(dirpath) 90 | 91 | rng = np.random.RandomState(DATASET_SEED) 92 | rand_ix = rng.permutation(NUM_EXAMPLES_TRAIN) 93 | print(rand_ix) 94 | _train_images, _train_labels = train_images[rand_ix], train_labels[rand_ix] 95 | 96 | labeled_ind = np.arange(N_LABELED) 97 | labeled_train_images, labeled_train_labels = _train_images[labeled_ind], _train_labels[labeled_ind] 98 | _train_images = np.delete(_train_images, labeled_ind, 0) 99 | _train_labels = np.delete(_train_labels, labeled_ind, 0) 100 | convert_images_and_labels(labeled_train_images, 101 | labeled_train_labels, 102 | os.path.join(dirpath, 'labeled_train.tfrecords')) 103 | convert_images_and_labels(train_images, train_labels, 104 | os.path.join(dirpath, 'unlabeled_train.tfrecords')) 105 | convert_images_and_labels(test_images, 106 | test_labels, 107 | os.path.join(dirpath, 'test.tfrecords')) 108 | 109 | # Construct dataset for validation 110 | train_images_valid, train_labels_valid = labeled_train_images, labeled_train_labels 111 | test_images_valid, test_labels_valid = \ 112 | _train_images[:FLAGS.num_valid_examples], _train_labels[:FLAGS.num_valid_examples] 113 | unlabeled_train_images_valid = np.concatenate( 114 | (train_images_valid, _train_images[FLAGS.num_valid_examples:]), axis=0) 115 | unlabeled_train_labels_valid = np.concatenate( 116 | (train_labels_valid, _train_labels[FLAGS.num_valid_examples:]), axis=0) 117 | convert_images_and_labels(train_images_valid, 118 | train_labels_valid, 119 | os.path.join(dirpath, 'labeled_train_val.tfrecords')) 120 | convert_images_and_labels(unlabeled_train_images_valid, 121 | unlabeled_train_labels_valid, 122 | os.path.join(dirpath, 'unlabeled_train_val.tfrecords')) 123 | convert_images_and_labels(test_images_valid, 124 | test_labels_valid, 125 | os.path.join(dirpath, 'test_val.tfrecords')) 126 | 127 | 128 | def inputs(batch_size=100, 129 | train=True, validation=False, 130 | shuffle=True, num_epochs=None): 131 | if validation: 132 | if train: 133 | filenames = ['labeled_train_val.tfrecords'] 134 | num_examples = N_LABELED 135 | else: 136 | filenames = ['test_val.tfrecords'] 137 | num_examples = FLAGS.num_valid_examples 138 | else: 139 | if train: 140 | filenames = ['labeled_train.tfrecords'] 141 | num_examples = N_LABELED 142 | else: 143 | filenames = ['test.tfrecords'] 144 | num_examples = NUM_EXAMPLES_TEST 145 | 146 | filenames = [os.path.join('seed' + str(DATASET_SEED), filename) for filename in filenames] 147 | filename_queue = generate_filename_queue(filenames, DATA_DIR, num_epochs) 148 | image, label = read(filename_queue) 149 | image = transform(tf.cast(image, tf.float32)) if train else image 150 | return generate_batch([image, label], num_examples, batch_size, shuffle) 151 | 152 | 153 | def unlabeled_inputs(batch_size=100, 154 | validation=False, 155 | shuffle=True): 156 | if validation: 157 | filenames = ['unlabeled_train_val.tfrecords'] 158 | num_examples = NUM_EXAMPLES_TRAIN - FLAGS.num_valid_examples 159 | else: 160 | filenames = ['unlabeled_train.tfrecords'] 161 | num_examples = NUM_EXAMPLES_TRAIN 162 | 163 | filenames = [os.path.join('seed' + str(DATASET_SEED), filename) for filename in filenames] 164 | filename_queue = generate_filename_queue(filenames, data_dir=DATA_DIR) 165 | image, label = read(filename_queue) 166 | image = transform(tf.cast(image, tf.float32)) 167 | return generate_batch([image], num_examples, batch_size, shuffle) 168 | 169 | 170 | def main(argv): 171 | prepare_dataset() 172 | 173 | 174 | if __name__ == "__main__": 175 | tf.app.run() 176 | --------------------------------------------------------------------------------