├── LICENSE
├── PI.py
├── README.md
├── main.py
└── util
├── .gitkeep
├── HandleIIDDataTFRecord.py
├── dataset_utils.py
├── layers.py
├── losses.py
└── svhn.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/PI.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from layers import Layers
4 | from losses import LossFunctions
5 |
class PI(object):
    """Pi-model classifier for TF1 graphs.

    Reference: Laine & Aila, "Temporal Ensembling for Semi-Supervised
    Learning", https://arxiv.org/pdf/1610.02242.pdf

    Args:
        d: dataset manager; this class reads ``d.h``, ``d.w``, ``d.c`` (image
            height/width/channels) and ``d.l`` (number of classes).
        lr: learning-rate tensor (e.g. a placeholder fed by the caller).
        lambda_pi_usl: weight tensor for the unsupervised Pi consistency loss.
        use_pi: if False only the supervised cross-entropy is optimised.
    """

    def __init__(self, d, lr, lambda_pi_usl, use_pi):

        # flags for each regularizer
        self.use_pi = use_pi

        # data and external toolkits
        self.d = d  # dataset manager
        self.ls = Layers()
        self.lf = LossFunctions(self.ls, d, self.encoder)

        # tensors defined outside (placeholders in main.py)
        self.lr = lr
        self.lambda_pi_usl = lambda_pi_usl

    def encoder(self, x, is_train=True, do_update_bn=True):
        """Conv-net from the paper; returns unnormalised class logits.

        When ``is_train`` is True the input is randomly cropped/resized
        (``distort``), then perturbed with additive noise of stddev 0.15
        (``Layers.get_corrupted``), and dropout is applied after each of the
        first two pooling stages.

        ``do_update_bn`` is accepted for interface compatibility but is not
        used (the network contains no batch normalisation).
        """
        if is_train:
            h = self.distort(x)
            # Corrupt the *distorted* image. Previously the noise was added to
            # ``x``, so the result of ``distort`` was silently thrown away.
            h = self.ls.get_corrupted(h, 0.15)
        else:
            h = x

        scope = '1'
        h = self.ls.conv2d(scope + '_1', h, 128, activation=self.ls.lrelu)
        h = self.ls.conv2d(scope + '_2', h, 128, activation=self.ls.lrelu)
        h = self.ls.conv2d(scope + '_3', h, 128, activation=self.ls.lrelu)
        h = self.ls.max_pool(h)
        if is_train:
            h = tf.nn.dropout(h, 0.5)

        scope = '2'
        h = self.ls.conv2d(scope + '_1', h, 256, activation=self.ls.lrelu)
        h = self.ls.conv2d(scope + '_2', h, 256, activation=self.ls.lrelu)
        h = self.ls.conv2d(scope + '_3', h, 256, activation=self.ls.lrelu)
        h = self.ls.max_pool(h)
        if is_train:
            h = tf.nn.dropout(h, 0.5)

        scope = '3'
        h = self.ls.conv2d(scope + '_1', h, 512, activation=self.ls.lrelu)
        h = self.ls.conv2d(scope + '_2', h, 256, activation=self.ls.lrelu, filter_size=(1, 1))
        h = self.ls.conv2d(scope + '_3', h, 128, activation=self.ls.lrelu, filter_size=(1, 1))
        h = tf.reduce_mean(h, axis=[1, 2])  # global average pooling
        h = self.ls.dense(scope, h, self.d.l)

        return h

    def build_graph_train(self, x_l, y_l, x, is_supervised=True):
        """Build the training graph.

        Args:
            x_l, y_l: labelled image batch and its labels.
            x: image batch used by the Pi consistency loss.
            is_supervised: unused; kept for interface compatibility.

        Sets ``self.o_train`` (dict with 'Ly', 'accur', 'Lp', 'loss') and
        ``self.op`` (Adam update with each gradient clipped to norm 5).
        """
        o = dict()  # output
        loss = 0

        # First encoder call creates the variables; later calls reuse them.
        logit = self.encoder(x)

        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            logit_l = self.encoder(x_l, is_train=True, do_update_bn=False)  # for the supervised loss

        # classification loss
        o['Ly'], o['accur'] = self.lf.get_loss_pyx(logit_l, y_l)
        loss += o['Ly']

        # Pi-model consistency loss
        if self.use_pi:
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                _, _, o['Lp'] = self.lf.get_loss_pi(x, logit, is_train=True)
            loss += self.lambda_pi_usl * o['Lp']
        else:
            # Float so 'Lp' has the same dtype family as the other losses.
            o['Lp'] = tf.constant(0.0)

        # set losses
        o['loss'] = loss
        self.o_train = o

        # set optimizer
        optimizer = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.5)
        grads_and_vars = []
        for g, v in optimizer.compute_gradients(loss):
            if g is None:
                # Variable not reachable from the loss; the (None, v) pair is
                # passed through unchanged, exactly as before.
                print('g is None:', v)
            else:
                g = tf.clip_by_norm(g, 5)  # clip gradients
            grads_and_vars.append((g, v))
        self.op = optimizer.apply_gradients(grads_and_vars)  # train op

    def build_graph_test(self, x_l, y_l):
        """Build the evaluation graph (no augmentation, no dropout).

        Sets ``self.o_test`` (dict with 'Ly', 'accur', 'loss').
        """
        o = dict()  # output
        loss = 0

        logit_l = self.encoder(x_l, is_train=False, do_update_bn=False)  # for the supervised loss

        # classification loss
        o['Ly'], o['accur'] = self.lf.get_loss_pyx(logit_l, y_l)
        loss += o['Ly']

        # set losses
        o['loss'] = loss
        self.o_test = o

    def distort(self, x):
        """Randomly crop every image of the batch ``x`` and resize it back to
        ``(d.h, d.w)``."""

        _d = self.d

        def _distort(a_image):
            # bounding_boxes: float32 tensor of shape [batch, N, 4] holding boxes
            # as [y_min, x_min, y_max, x_max]; here a single box, shape [1, 1, 4].
            # Float literals (not 1/10) so the box is not floored to 0 on Python 2.
            bounding_boxes = tf.constant([[[0.1, 0.1, 0.9, 0.9]]], dtype=tf.float32)

            begin, size, _ = tf.image.sample_distorted_bounding_box(
                (_d.h, _d.w, _d.c), bounding_boxes,
                min_object_covered=(8.5 / 10.0),
                aspect_ratio_range=[7.0 / 10.0, 10.0 / 7.0])

            a_image = tf.slice(a_image, begin, size)
            # Resize (rather than tf.image.resize_image_with_crop_or_pad) so the
            # crop is actually rescaled, i.e. the image is distorted.
            a_image = tf.image.resize_images(a_image, [_d.h, _d.w])
            # The channel size is not set on the resize_images result; fix it.
            a_image = tf.reshape(a_image, [_d.h, _d.w, _d.c])
            return a_image

        # apply per image over the batch dimension
        return tf.map_fn(_distort, x)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PI
2 | Very simple TensorFlow implementation of NVIDIA’s Π Model from [“Temporal Ensembling for Semi-Supervised Learning”](https://arxiv.org/pdf/1610.02242.pdf) (ICLR 2017) on the SVHN classification task.
3 |
4 |
5 |

6 |
7 |
8 | ## Usage
9 |
10 | ```python main.py```
11 |
12 |
13 | ## Useful Resources
14 |
15 | - [Original authors’ implementation with Theano and Lasagne](https://github.com/smlaine2/tempens)
16 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import sys, os, time
4 |
5 | sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/util')
6 | from HandleIIDDataTFRecord import HandleIIDDataTFRecord
7 | from PI import PI
8 |
# Run-mode switches.
DO_TRAIN = True
DO_TEST = True
USE_PI = True

tf.flags.DEFINE_string("dataset", "SVHN", "MNIST / CIFAR10 / SVHN / CharImages")
tf.flags.DEFINE_boolean("restore", False, "restore from the last check point")
tf.flags.DEFINE_string("dir_logs", "./out/", "")
FLAGS = tf.flags.FLAGS

# When the model is not trained it has to come from a checkpoint.
if not (DO_TRAIN or FLAGS.restore):
    print('[WARN] FLAGS.restore is set to True compulsorily')
    FLAGS.restore = True

N_EPOCHS = 100

FILE_OF_CKPT = os.path.join(FLAGS.dir_logs, "drawmodel.ckpt")

# Learning-rate schedule: after DECAY_AFTER epochs, multiply the starting rate
# by DECAY_FACTOR once every DECAY_INTERVAL epochs.
STARTER_LEARNING_RATE = 1e-3
DECAY_AFTER, DECAY_INTERVAL, DECAY_FACTOR = 2, 2, 0.97
31 |
def get_lambda_pi_usl(epoch):
    """Return the weight of the unsupervised Pi-model loss for ``epoch``.

    Uses the ramp-up of https://github.com/smlaine2/tempens/blob/master/train.py:
    ``exp(-5 * (1 - epoch/80)**2)`` for epochs below 80, then 1.0. The value is
    scaled by ``PI_W_MAX * n_labeled / n_train``. Returns 0.0 when ``USE_PI`` is
    False.

    Reads the module-level dataset manager ``d`` (``d.n_labeled``, ``d.n_train``).
    """
    if not USE_PI:
        return 0.0

    import math

    PI_RAMPUP_LENGTH = 80  # there seems to be no other option than 80, according to the paper.
    PI_W_MAX = 100

    if epoch < PI_RAMPUP_LENGTH:
        p = 1.0 - (max(0.0, float(epoch)) / float(PI_RAMPUP_LENGTH))
        rampup = math.exp(-p * p * 5.0)
    else:
        rampup = 1.0

    # float(): with Python 2 integer division this ratio would floor to 0 and
    # silently disable the Pi loss.
    pi_m_n = float(d.n_labeled) / float(d.n_train)
    return rampup * PI_W_MAX * pi_m_n
49 |
50 |
def test():
    """Return the mean of ``m.o_test['accur']`` over ``d.n_batches_test``
    evaluation batches (uses the module-level ``sess``, ``m`` and ``d``)."""
    batch_accuracies = [sess.run(m.o_test)['accur']
                        for _ in range(d.n_batches_test)]
    return np.mean(batch_accuracies, axis=0)
57 |
with tf.Graph().as_default() as g:

    # -------------------------------------------------------------------
    # Load data
    # -------------------------------------------------------------------
    BATCH_SIZE = 100
    d = HandleIIDDataTFRecord(FLAGS.dataset, BATCH_SIZE)
    (x_train, y_train), x, (x_test, y_test) = d.get_tfrecords()

    # -------------------------------------------------------------------
    # Build model graphs
    # -------------------------------------------------------------------
    lr = tf.placeholder(tf.float32, shape=[], name="learning_rate")
    lambda_pi_usl = tf.placeholder(tf.float32, shape=(), name="lambda_pi_usl")

    with tf.variable_scope("watashinomodel") as scope:

        m = PI(d, lr, lambda_pi_usl, use_pi=USE_PI)

        print('... now building the graph for training.')
        # x_train/y_train: labelled batch; x: batch fed to the Pi consistency loss.
        m.build_graph_train(x_train, y_train, x)
        scope.reuse_variables()
        if DO_TEST:
            print('... now building the graph for test.')
            m.build_graph_test(x_test, y_test)

    # -------------------------------------------------------------------
    # Init
    # -------------------------------------------------------------------
    init_op = tf.global_variables_initializer()
    # tf.global_variables() replaces the deprecated tf.all_variables().
    for v in tf.global_variables():
        print("[DEBUG] %s : %s" % (v.name, v.get_shape()))

    saver = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allocator_type = 'BFC'
    sess = tf.Session(config=config)

    _lr, ratio = STARTER_LEARNING_RATE, 1.0

    if FLAGS.restore:
        print("... restore from the last check point.")
        saver.restore(sess, FILE_OF_CKPT)
    else:
        sess.run(init_op)

    # Summary ops are not fetched: nothing in this script writes them out.
    tf.get_default_graph().finalize()

    # The input queues are needed for evaluation as well as for training.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        # ---------------------------------------------------------------
        # Training loop
        # ---------------------------------------------------------------
        if DO_TRAIN:
            print('... start training')
            for epoch in range(1, N_EPOCHS + 1):

                # Both fed values are constant within an epoch.
                feed_dict = {lr: _lr, lambda_pi_usl: get_lambda_pi_usl(epoch)}

                loss, accur = [], []
                for i in range(d.n_batches_train):

                    # do update
                    time_start = time.time()
                    r, _, current_lr = sess.run([m.o_train, m.op, m.lr], feed_dict=feed_dict)
                    elapsed_time = time.time() - time_start

                    loss.append(r['loss'])
                    accur.append(r['accur'])

                    if i % 100 == 0 and i != 0:
                        print(" iter:%2d, loss: %.5f, accr: %.5f, Ly: %s, Lp: %s, time:%.3f" %
                              (i, np.mean(np.array(loss)), np.mean(np.array(accur)),
                               r['Ly'], r['Lp'], elapsed_time))

                # test
                if DO_TEST:
                    time_start = time.time()
                    accur = test()
                    elapsed_time = time.time() - time_start
                    print("epoch:%d, accur: %s, time:%.3f" % (epoch, accur, elapsed_time))

                # save
                print("Model saved in file: %s" % saver.save(sess, FILE_OF_CKPT))

                # learning rate decay
                if (epoch % DECAY_INTERVAL == 0) and (epoch > DECAY_AFTER):
                    ratio *= DECAY_FACTOR
                    _lr = STARTER_LEARNING_RATE * ratio
                    print('lr decaying is scheduled. epoch:%d, lr:%f <= %f' % (epoch, _lr, current_lr))

        elif DO_TEST:
            # Evaluation-only run on the restored checkpoint (previously this
            # configuration did nothing because the queues were never started).
            time_start = time.time()
            accur = test()
            elapsed_time = time.time() - time_start
            print("accur: %s, time:%.3f" % (accur, elapsed_time))
    finally:
        coord.request_stop()
        coord.join(threads)
        sess.close()
153 |
--------------------------------------------------------------------------------
/util/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/util/HandleIIDDataTFRecord.py:
--------------------------------------------------------------------------------
1 | import sys, os, time
2 |
class HandleIIDDataTFRecord(object):
    """Describe a dataset and build its TFRecord input pipelines.

    Only ``'SVHN'`` is implemented. Any other name raises
    ``NotImplementedError`` instead of terminating the interpreter.

    Attributes set by ``__init__``:
        h, w, c: image height, width and channels.
        l: number of classes.
        img_size: h * w * c.
        is_3d: whether images have a channel axis.
        n_train, n_test, n_labeled: example counts.
        n_batches_train, n_batches_test: full batches per pass (a trailing
            partial batch is not counted).
    """

    def __init__(self, dataset, batch_size, is_debug=False):

        self.dataset = dataset
        self.batch_size = batch_size
        self.is_debug = is_debug

        if self.dataset == 'SVHN':
            from svhn import N_LABELED
            n_train, n_test, n_labeled = 73257, 26032, N_LABELED
            _h, _w, _c = 32, 32, 3
            _l = 10
            _is_3d = True
        else:
            raise self._not_implemented()

        self.h = _h
        self.w = _w
        self.c = _c
        self.l = _l
        self.is_3d = _is_3d
        self.img_size = _h * _w * _c
        self.n_train = n_train
        self.n_test = n_test
        self.n_labeled = n_labeled
        self.n_batches_train = n_train // batch_size
        self.n_batches_test = n_test // batch_size

    def _not_implemented(self):
        """Build the error raised for a dataset that has no loader."""
        return NotImplementedError(
            '[ERROR] not implemented yet: dataset %r' % (self.dataset,))

    ########################################
    # inputs
    ########################################
    def get_tfrecords(self):
        """Return ``(xtrain_l, ytrain_l), xtrain, (xtest, ytest)``.

        xtrain_l, ytrain_l: labelled (partial) training batches.
        xtrain: batches drawn from all training records (without labels).
        xtest, ytest: test batches.
        """
        if self.dataset != 'SVHN':
            raise self._not_implemented()

        from svhn import inputs, unlabeled_inputs
        xtrain_l, ytrain_l = inputs(batch_size=self.batch_size, train=True, validation=False, shuffle=True)
        xtrain = unlabeled_inputs(batch_size=self.batch_size, validation=False, shuffle=True)
        # NOTE(review): the test split is also read with shuffle=True, so a pass
        # of n_batches_test batches may not visit each test example exactly
        # once; confirm against svhn.inputs.
        xtest, ytest = inputs(batch_size=self.batch_size, train=False, validation=False, shuffle=True)
        return (xtrain_l, ytrain_l), xtrain, (xtest, ytest)
50 |
51 |
if __name__ == '__main__':

    # Smoke test: build the SVHN pipelines and print the resulting tensors.
    BATCH_SIZE = 20

    d = HandleIIDDataTFRecord('SVHN', BATCH_SIZE, is_debug=True)
    print(d.get_tfrecords())
    # No trailing sys.exit('<message>'): exiting with a string prints it to
    # stderr and returns status 1, which reports a successful run as failed.
60 |
--------------------------------------------------------------------------------
/util/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os, sys, pickle
3 | import numpy as np
4 | from scipy import linalg
5 |
# Data-augmentation switches read by ``transform``. They replace the former
# tf.app.flags definitions ``aug_trans`` / ``aug_flip``.
AUG_TRANS = False  # pad by 2 pixels, then randomly crop back to 32x32
AUG_FLIP = False   # random left-right flip
12 |
def unpickle(file):
    """Load and return the pickled object stored in ``file``.

    On Python 3 the stream is read with ``encoding='latin-1'``, presumably so
    that pickles produced by Python 2 can be loaded.

    Security: unpickling can execute arbitrary code; only use on trusted files.
    """
    # ``with`` guarantees the handle is closed even if unpickling raises.
    with open(file, 'rb') as fp:
        if sys.version_info.major == 2:
            return pickle.load(fp)
        return pickle.load(fp, encoding='latin-1')
21 |
22 |
def ZCA(data, reg=1e-6):
    """Compute a ZCA whitening transform for ``data`` and apply it.

    Args:
        data: 2-D array, rows are samples and columns are features.
        reg: constant added to the covariance eigenvalues before the inverse
            square root, so (near-)zero eigenvalues do not blow up.

    Returns:
        (components, mean, whiten): the (n_features, n_features) whitening
        matrix, the per-feature mean, and ``(data - mean) @ components.T``.
    """
    mean = np.mean(data, axis=0)
    mdata = data - mean
    sigma = np.dot(mdata.T, mdata) / mdata.shape[0]
    U, S, _ = linalg.svd(sigma)
    # ``reg`` belongs inside the square root. Adding it after the reciprocal
    # regularises nothing and yields inf for zero eigenvalues.
    components = np.dot(np.dot(U, np.diag(1.0 / np.sqrt(S + reg))), U.T)
    whiten = np.dot(mdata, components.T)
    return components, mean, whiten
31 |
32 |
def _int64_feature(value):
    """Wrap the single integer ``value`` in a ``tf.train.Feature``."""
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)
35 |
36 |
def _bytes_feature(value):
    """Wrap the single bytes ``value`` in a ``tf.train.Feature``."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
39 |
40 |
def convert_images_and_labels(images, labels, filepath):
    """Write ``images`` / ``labels`` to a TFRecord file at ``filepath``.

    Each row of ``images`` is stored as a float list under ``'image'``,
    together with its integer ``'label'``. ``height``/``width``/``depth`` are
    always written as 32/32/3; they are not derived from ``images``.

    Raises:
        ValueError: if ``images`` and ``labels`` differ in length.
    """
    print('[DEBUG] inputs shape:', images.shape, labels.shape)  # e.g. (4000, 3072) (4000,)
    num_examples = labels.shape[0]
    if images.shape[0] != num_examples:
        raise ValueError("Images size %d does not match label size %d." %
                         (images.shape[0], num_examples))
    print('Writing', filepath)
    writer = tf.python_io.TFRecordWriter(filepath)
    try:
        for image, label in zip(images, labels):
            image_feature = tf.train.Feature(
                float_list=tf.train.FloatList(value=image.tolist()))
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(32),
                'width': _int64_feature(32),
                'depth': _int64_feature(3),
                'label': _int64_feature(int(label)),
                'image': image_feature}))
            writer.write(example.SerializeToString())
    finally:
        # Release the file even if serialization fails part-way.
        writer.close()
62 |
63 |
def read(filename_queue):
    """Read one serialized example from ``filename_queue`` and decode it.

    Returns:
        (image, label): the 3072 floats reshaped to [32, 32, 3], and the
        stored int64 class index converted to a 10-way one-hot vector.
    """
    reader = tf.TFRecordReader()
    print('filename_queue', filename_queue)
    _, serialized = reader.read(filename_queue)
    # Both keys are required, so no defaults are supplied.
    feature_spec = {
        'image': tf.FixedLenFeature([3072], tf.float32),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized, features=feature_spec)

    image = tf.reshape(parsed['image'], [32, 32, 3])
    # int64 label -> int32 index -> one-hot over 10 classes.
    label = tf.one_hot(tf.cast(parsed['label'], tf.int32), 10)
    return image, label
81 |
82 |
def generate_batch(
        example,
        min_queue_examples,
        batch_size, shuffle):
    """Batch the list of tensors ``example`` through a single-threaded queue.

    With ``shuffle`` a ``tf.train.shuffle_batch`` keeping at least
    ``min_queue_examples`` elements is used. Otherwise a ``tf.train.batch``
    that allows a smaller final batch is used. Capacity is
    ``min_queue_examples + 5 * batch_size`` in both cases.
    """
    shared = dict(
        batch_size=batch_size,
        num_threads=1,
        capacity=min_queue_examples + 5 * batch_size,
    )
    if not shuffle:
        return tf.train.batch(example, allow_smaller_final_batch=True, **shared)
    return tf.train.shuffle_batch(example, min_after_dequeue=min_queue_examples, **shared)
109 |
110 |
def transform(image):
    """Reshape ``image`` to [32, 32, 3] and apply the augmentations enabled by
    ``AUG_TRANS`` (pad 2 px + random crop) and ``AUG_FLIP`` (random flip)."""
    image = tf.reshape(image, [32, 32, 3])
    if not (AUG_TRANS or AUG_FLIP):
        return image

    print("augmentation")
    if AUG_TRANS:
        padded = tf.pad(image, [[2, 2], [2, 2], [0, 0]])
        image = tf.random_crop(padded, [32, 32, 3])
    if AUG_FLIP:
        image = tf.image.random_flip_left_right(image)
    return image
121 |
122 |
def generate_filename_queue(filenames, data_dir, num_epochs=None):
    """Create a ``tf.train.string_input_producer`` over files in ``data_dir``.

    Args:
        filenames: file names relative to ``data_dir``. The caller's sequence
            is no longer overwritten in place with the joined paths.
        data_dir: directory prepended to every file name.
        num_epochs: forwarded to ``tf.train.string_input_producer``.
    """
    print("filenames in queue:", filenames)
    paths = [os.path.join(data_dir, name) for name in filenames]
    return tf.train.string_input_producer(paths, num_epochs=num_epochs)
128 |
129 |
130 |
--------------------------------------------------------------------------------
/util/layers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import tensorflow as tf
4 | import numpy as np
5 | import sys
6 |
7 | class Layers(object):
8 |
9 | def __init__(self):
10 | self.do_share = False
11 |
12 | def set_do_share(self, flag):
13 | self.do_share = flag
14 |
15 | def W( self, W_shape, W_name='W', W_init=None):
16 | if W_init is None:
17 | W_initializer = tf.contrib.layers.xavier_initializer()
18 | else:
19 | W_initializer = tf.constant_initializer(W_init)
20 |
21 | return tf.get_variable(W_name, W_shape, initializer=W_initializer)
22 |
23 | def Wb( self, W_shape, b_shape, W_name='W', b_name='b', W_init=None, b_init=0.1):
24 |
25 | W = self.W(W_shape, W_name=W_name, W_init=None)
26 | b = tf.get_variable(b_name, b_shape, initializer=tf.constant_initializer(b_init))
27 |
28 | def _summaries(var):
29 | """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
30 | with tf.name_scope('summaries'):
31 | mean = tf.reduce_mean(var)
32 | tf.summary.scalar('mean', mean)
33 | with tf.name_scope('stddev'):
34 | stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
35 | tf.summary.scalar('stddev', stddev)
36 | tf.summary.scalar('max', tf.reduce_max(var))
37 | tf.summary.scalar('min', tf.reduce_min(var))
38 | tf.summary.histogram('histogram', var)
39 | _summaries(W)
40 | _summaries(b)
41 |
42 | return W, b
43 |
44 |
45 | def denseV2( self, scope, x, output_dim, activation=None):
46 | return tf.contrib.layers.fully_connected( x, output_dim, activation_fn=activation, reuse=self.do_share, scope=scope)
47 |
48 | def dense( self, scope, x, output_dim, activation=None):
49 | if len(x.get_shape()) == 2: # 1d
50 | pass
51 | elif len(x.get_shape()) == 4: # cnn as NHWC
52 | #x = tf.reshape(x, [tf.shape(x)[0], -1]) # flatten
53 | x = tf.reshape(x, [x.get_shape().as_list()[0], -1]) # flatten
54 | #x = tf.reshape(x, [tf.cast(x.get_shape()[0], tf.int32), -1]) # flatten
55 | with tf.variable_scope(scope,reuse=self.do_share): W, b = self.Wb([x.get_shape()[1], output_dim], [output_dim])
56 | #with tf.variable_scope(scope,reuse=self.do_share): W, b = self.Wb([x.get_shape()[1], output_dim], [output_dim])
57 | o = tf.matmul(x, W) + b
58 | return o if activation is None else activation(o)
59 |
60 | def lrelu(self, x, a=0.1):
61 | if a < 1e-16:
62 | return tf.nn.relu(x)
63 | else:
64 | return tf.maximum(x, a * x)
65 |
66 | def avg_pool(self, x, ksize=2, stride=2):
67 | return tf.nn.avg_pool(x, ksize=[1, ksize, ksize, 1], strides=[1, stride, stride, 1], padding='SAME')
68 |
69 | def max_pool(self, x, ksize=2, stride=2):
70 | return tf.nn.max_pool(x, ksize=[1, ksize, ksize, 1], strides=[1, stride, stride, 1], padding='SAME')
71 |
72 | def conv2d( self, scope, x, out_c, filter_size=(3,3), strides=(1,1,1,1), padding="SAME", activation=None):
73 | """
74 | x: [BATCH_SIZE, in_height, in_width, in_channels]
75 | filter : [filter_height, filter_width, in_channels, out_channels]
76 | """
77 | filter = [filter_size[0], filter_size[1], int(x.get_shape()[3]), out_c]
78 | with tf.variable_scope(scope,reuse=self.do_share): W, b = self.Wb(filter, [out_c])
79 | o = tf.nn.conv2d(x, W, strides, padding) + b
80 | return o if activation is None else activation(o)
81 |
82 | ###########################################
83 | """ Softmax """
84 | ###########################################
def softmax(self, scope, input, size):
    """Softmax output, adding a dense projection when the width doesn't match.

    When ``input.get_shape()[1] != size`` is truthy, ``self.dense`` projects the
    input to ``size`` units under ``scope`` with ``tf.nn.softmax`` as the
    activation. Otherwise ``tf.nn.softmax`` is applied directly. The chosen
    path is printed.

    NOTE(review): ``!=`` is kept on purpose. With TF1 ``Dimension`` objects an
    unknown width compares as ``None`` (falsy), which selects the direct path.
    """
    width = input.get_shape()[1]
    needs_projection = width != size
    if needs_projection:
        print("softmax w/ fc:", width, '->', size)
        return self.dense(scope, input, size, tf.nn.softmax)
    print("softmax w/o fc")
    return tf.nn.softmax(input)
92 |
93 | ###########################################
94 | """ Noise/Denoise Function """
95 | ###########################################
def get_corrupted(self, x, noise_std=0.1):
    """Return ``x`` perturbed with noise of scale ``noise_std``.

    This is a thin wrapper that delegates to ``self.sampler(x, noise_std)``.
    """
    noisy = self.sampler(x, noise_std)
    return noisy
98 |
def epsilon(self, _shape, _stddev=1.):
    """Draw zero-mean truncated-normal noise of shape ``_shape``.

    ``tf.truncated_normal`` re-draws samples that fall more than two standard
    deviations from the mean.
    """
    noise = tf.truncated_normal(_shape, mean=0, stddev=_stddev)
    return noise
101 |
def sampler(self, mu, sigma):
    """Return ``mu + sigma * e`` where ``e = self.epsilon(tf.shape(mu))``.

    mu, sigma : (BATCH_SIZE, z_size). ``sigma`` may also be a scalar, since it
    is only combined by multiplication.
    """
    e = self.epsilon(tf.shape(mu))
    return mu + sigma * e
107 |
--------------------------------------------------------------------------------
/util/losses.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import sys
4 |
5 |
# Small constant added inside tf.sqrt in LossFunctions.get_loss_pi.
eps = 1.0e-8
7 |
class LossFunctions(object):
    """Loss and metric helpers for the supervised and Pi-model objectives."""

    def __init__(self, layers, dataset, encoder):
        """Store collaborators and create the ``reconst_pixel_log_stdv`` variable.

        Args:
          layers: layer helper object (kept as ``self.ls``).
          dataset: dataset object (kept as ``self.d``).
          encoder: callable invoked as ``encoder(x, is_train=..., do_update_bn=...)``
            by ``get_loss_pi``; its result is compared against ``logit_real``.
        """
        self.ls, self.d, self.encoder = layers, dataset, encoder
        # NOTE(review): no method in this class reads this variable; presumably
        # it is kept so graphs/checkpoints stay compatible. Confirm before removing.
        self.reconst_pixel_log_stdv = tf.get_variable(
            "reconst_pixel_log_stdv", initializer=tf.constant(0.0))

    def get_loss_pyx(self, logit, y):
        """Return ``(mean softmax cross-entropy, mean argmax accuracy)``."""
        return self._ce(logit, y), self._accuracy(logit, y)

    def get_loss_pi(self, x, logit_real, is_train):
        """Pi-model consistency loss between two forward passes.

        ``logit_real`` is wrapped in ``tf.stop_gradient`` so it acts as a fixed
        target. A second pass ``self.encoder(x, is_train=is_train,
        do_update_bn=False)`` produces ``logit_virtual``. The loss is
        ``sqrt(mean((target - logit_virtual)**2) + eps)``. ``eps`` presumably
        keeps the square root's gradient finite when the difference is zero.

        Returns:
          ``(gradient-stopped logit_real, logit_virtual, scalar loss)``.
        """
        target = tf.stop_gradient(logit_real)
        logit_virtual = self.encoder(x, is_train=is_train, do_update_bn=False)
        mean_sq = tf.reduce_mean(tf.square(tf.subtract(target, logit_virtual)))
        loss = tf.sqrt(mean_sq + eps)
        return target, logit_virtual, loss

    # Reference: https://github.com/takerum/vat_tf/blob/master/layers.py
    def _ce(self, logit, y):
        """Mean softmax cross-entropy between ``logit`` and label tensor ``y``."""
        per_example = tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y)
        return tf.reduce_mean(per_example)

    def _accuracy(self, logit, y):
        """Fraction of rows where ``argmax(logit, 1) == argmax(y, 1)``."""
        hits = tf.equal(tf.argmax(logit, 1), tf.argmax(y, 1))
        return tf.reduce_mean(tf.to_float(hits))
38 |
--------------------------------------------------------------------------------
/util/svhn.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import sys
7 | from scipy.io import loadmat
8 |
9 | import numpy as np
10 | from scipy import linalg
11 | import glob
12 | import pickle
13 |
14 | from six.moves import xrange # pylint: disable=redefined-builtin
15 | from six.moves import urllib
16 |
17 | import tensorflow as tf
18 | from dataset_utils import *
19 |
# Remote locations of the SVHN train / test splits (MATLAB .mat files).
DATA_URL_TRAIN = 'http://ufldl.stanford.edu/housenumbers/train_32x32.mat'
DATA_URL_TEST = 'http://ufldl.stanford.edu/housenumbers/test_32x32.mat'

# Number of examples in the full train / test splits.
NUM_EXAMPLES_TRAIN = 73257
NUM_EXAMPLES_TEST = 26032

# Semi-supervised split configuration: how many training examples keep their
# labels, and the seed used to shuffle the training set before splitting.
N_LABELED = 4000
DATASET_SEED = 1

# Local directory for downloaded and converted data; set this before use.
DATA_DIR = 'PATH_TO_DIR_OF_SVHN_IN_YOUR_ENVIRONMENT'

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('num_valid_examples', 1000, "The number of validation examples")
32 |
33 |
def _download_if_missing(url, filepath):
    """Download ``url`` to ``filepath`` unless ``filepath`` already exists.

    The data is written to ``filepath + '.part'`` and renamed only after
    ``urlretrieve`` returns. An interrupted or short download therefore never
    leaves a truncated file at ``filepath`` for later runs to treat as complete.
    """
    if os.path.exists(filepath):
        return

    def _progress(count, block_size, total_size):
        # urlretrieve reports total_size <= 0 when the size is unknown.
        if total_size > 0:
            percent = min(float(count * block_size) / float(total_size) * 100.0, 100.0)
            sys.stdout.write('\r>> Downloading %.1f%%' % percent)
            sys.stdout.flush()

    partial_path = filepath + '.part'
    urllib.request.urlretrieve(url, partial_path, _progress)
    sys.stdout.write('\n')  # terminate the '\r' progress line
    os.rename(partial_path, filepath)


def _load_svhn_mat(mat_path):
    """Load an SVHN ``*_32x32.mat`` file as flat float images and int32 labels.

    Returns:
      x: float array ``[N, H*W*C]``. Pixels are divided by 255, and the .mat
         array (example index on its last axis) is moved to example-first order
         before flattening.
      y: int32 array ``[N]``; label 10 is remapped to 0.
    """
    data = loadmat(mat_path)
    x = data['X'] / 255.
    x = x.transpose((3, 0, 1, 2))
    x = x.reshape((x.shape[0], -1))
    y = data['y'].flatten().astype(np.int32)
    y[y == 10] = 0
    return x, y


def maybe_download_and_extract():
    """Ensure the SVHN .mat files exist in DATA_DIR and (re)build the .npy caches.

    Each of ``train_32x32.mat`` / ``test_32x32.mat`` is downloaded only if it
    is missing. Both are then converted into ``train_images.npy``,
    ``train_labels.npy``, ``test_images.npy`` and ``test_labels.npy`` in
    ``DATA_DIR``.
    """
    if not os.path.exists(DATA_DIR):
        os.makedirs(DATA_DIR)
    filepath_train_mat = os.path.join(DATA_DIR, 'train_32x32.mat')
    filepath_test_mat = os.path.join(DATA_DIR, 'test_32x32.mat')
    # Fetch each file independently so an already-present one isn't re-downloaded.
    _download_if_missing(DATA_URL_TRAIN, filepath_train_mat)
    _download_if_missing(DATA_URL_TEST, filepath_test_mat)

    # Training set
    print("Loading training data...")
    print("Preprocessing training data...")
    train_x, train_y = _load_svhn_mat(filepath_train_mat)

    # Test set
    print("Loading test data...")
    test_x, test_y = _load_svhn_mat(filepath_test_mat)

    np.save(os.path.join(DATA_DIR, 'train_images'), train_x)
    np.save(os.path.join(DATA_DIR, 'train_labels'), train_y)
    np.save(os.path.join(DATA_DIR, 'test_images'), test_x)
    np.save(os.path.join(DATA_DIR, 'test_labels'), test_y)
74 |
75 |
def load_svhn():
    """Return ``(train_images, train_labels), (test_images, test_labels)``.

    Calls ``maybe_download_and_extract`` first, then loads the cached .npy
    arrays from ``DATA_DIR``. Every array, labels included, is cast to float32.
    """
    maybe_download_and_extract()

    def _load(stem):
        return np.load('{}/{}.npy'.format(DATA_DIR, stem)).astype(np.float32)

    train = (_load('train_images'), _load('train_labels'))
    test = (_load('test_images'), _load('test_labels'))
    return train, test
83 |
84 |
def prepare_dataset():
    """Write the SVHN TFRecord splits into ``DATA_DIR/seed<DATASET_SEED>``.

    The training set is shuffled with ``DATASET_SEED``, and the first
    ``N_LABELED`` shuffled examples form the labeled set. Two families of
    files are written:

    * evaluation: labeled_train, unlabeled_train (the full, unshuffled
      training set), test;
    * validation (``*_val``): the same labeled set; ``test_val``, the first
      ``FLAGS.num_valid_examples`` of the remaining shuffled examples; and
      ``unlabeled_train_val``, the labeled set plus the remaining examples
      not held out for validation.
    """
    (train_images, train_labels), (test_images, test_labels) = load_svhn()
    out_dir = os.path.join(DATA_DIR, 'seed' + str(DATASET_SEED))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    def _write(images, labels, filename):
        convert_images_and_labels(images, labels, os.path.join(out_dir, filename))

    shuffle_ix = np.random.RandomState(DATASET_SEED).permutation(NUM_EXAMPLES_TRAIN)
    print(shuffle_ix)
    shuffled_images = train_images[shuffle_ix]
    shuffled_labels = train_labels[shuffle_ix]

    # First N_LABELED shuffled examples are labeled; the rest stay in shuffled order.
    head = np.arange(N_LABELED)
    labeled_images, labeled_labels = shuffled_images[head], shuffled_labels[head]
    rest_images = np.delete(shuffled_images, head, 0)
    rest_labels = np.delete(shuffled_labels, head, 0)

    # Evaluation splits.
    _write(labeled_images, labeled_labels, 'labeled_train.tfrecords')
    _write(train_images, train_labels, 'unlabeled_train.tfrecords')
    _write(test_images, test_labels, 'test.tfrecords')

    # Validation splits.
    n_valid = FLAGS.num_valid_examples
    valid_images, valid_labels = rest_images[:n_valid], rest_labels[:n_valid]
    unlabeled_val_images = np.concatenate((labeled_images, rest_images[n_valid:]), axis=0)
    unlabeled_val_labels = np.concatenate((labeled_labels, rest_labels[n_valid:]), axis=0)
    _write(labeled_images, labeled_labels, 'labeled_train_val.tfrecords')
    _write(unlabeled_val_images, unlabeled_val_labels, 'unlabeled_train_val.tfrecords')
    _write(valid_images, valid_labels, 'test_val.tfrecords')
126 |
127 |
def inputs(batch_size=100,
           train=True, validation=False,
           shuffle=True, num_epochs=None):
    """Build a batched input pipeline over one labeled TFRecord split.

    ``train`` selects the labeled-train split, whose images are passed through
    ``transform``, versus the test split. ``validation`` selects the ``*_val``
    variant of that split.

    Returns:
      The result of ``generate_batch([image, label], ...)``.
    """
    if train:
        filename = 'labeled_train_val.tfrecords' if validation else 'labeled_train.tfrecords'
        num_examples = N_LABELED
    elif validation:
        filename = 'test_val.tfrecords'
        num_examples = FLAGS.num_valid_examples
    else:
        filename = 'test.tfrecords'
        num_examples = NUM_EXAMPLES_TEST

    filenames = [os.path.join('seed' + str(DATASET_SEED), filename)]
    filename_queue = generate_filename_queue(filenames, DATA_DIR, num_epochs)
    image, label = read(filename_queue)
    if train:
        image = transform(tf.cast(image, tf.float32))
    return generate_batch([image, label], num_examples, batch_size, shuffle)
151 |
152 |
def unlabeled_inputs(batch_size=100,
                     validation=False,
                     shuffle=True):
    """Build a batched pipeline of transformed images from the unlabeled split.

    ``validation`` selects ``unlabeled_train_val`` instead of ``unlabeled_train``.
    The labels stored in the records are read but discarded.
    """
    if validation:
        stem = 'unlabeled_train_val'
        num_examples = NUM_EXAMPLES_TRAIN - FLAGS.num_valid_examples
    else:
        stem = 'unlabeled_train'
        num_examples = NUM_EXAMPLES_TRAIN

    filenames = [os.path.join('seed' + str(DATASET_SEED), stem + '.tfrecords')]
    filename_queue = generate_filename_queue(filenames, data_dir=DATA_DIR)
    image, _ = read(filename_queue)
    augmented = transform(tf.cast(image, tf.float32))
    return generate_batch([augmented], num_examples, batch_size, shuffle)
168 |
169 |
def main(argv):
    """Entry point for ``tf.app.run``: build all TFRecord splits."""
    del argv  # unused; tf.app.run passes the remaining command-line args
    prepare_dataset()
172 |
173 |
if __name__ == "__main__":
    # tf.app.run parses command-line flags, then calls main(argv).
    tf.app.run(main=main)
176 |
--------------------------------------------------------------------------------