├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
└── semisup
    ├── .gitignore
    ├── __init__.py
    ├── architectures.py
    ├── backend.py
    ├── eval.py
    ├── mnist_train_eval.py
    ├── tools
    │   ├── __init__.py
    │   ├── data_dirs.py
    │   ├── data_dirs.py.template
    │   ├── data_util
    │   │   ├── create_mnistm.py
    │   │   ├── create_mnistm.sh
    │   │   └── dataset_visualization.ipynb
    │   ├── gtsrb.py
    │   ├── mnist.py
    │   ├── mnist3.py
    │   ├── mnistm.py
    │   ├── stl10.py
    │   ├── svhn.py
    │   ├── synth.py
    │   ├── synth_signs.py
    │   └── usps.py
    └── train.py

/.gitignore:
--------------------------------------------------------------------------------
*.kate-swp
*.pyc
.idea
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
Want to contribute? Great! First, read this page (including the small print at the end).

### Before you contribute
Before we can use your code, you must sign the
[Google Individual Contributor License Agreement](https://cla.developers.google.com/about/google-individual)
(CLA), which you can do online. The CLA is necessary mainly because you own the
copyright to your changes, even after your contribution becomes part of our
codebase, so we need your permission to use and distribute your code. We also
need to be sure of various other things—for instance that you'll tell us if you
know that your code infringes on other people's patents. You don't have to sign
the CLA until after you've submitted your code for review and a member has
approved it, but you must do it before we can put your code into our codebase.
Before you start working on a larger contribution, you should get in touch with
us first through the issue tracker with your idea so that we can help out and
possibly guide you. Coordinating up front makes it much easier to avoid
frustration later on.

### Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose.

### The small print
Contributions made by corporations are covered by a different agreement than
the one above, the
[Software Grant and Corporate Contributor License Agreement](https://cla.developers.google.com/about/google-corporate).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted.
If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty.
Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright {yyyy} {name of copyright owner}

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
This repository contains code for the paper [Learning by Association - A versatile semi-supervised training method for neural networks (CVPR 2017)](https://vision.in.tum.de/_media/spezial/bib/haeusser_cvpr_17.pdf) and the follow-up work [Associative Domain Adaptation (ICCV 2017)](https://vision.in.tum.de/_media/spezial/bib/haeusser_iccv_17.pdf).

It is implemented with TensorFlow. Please refer to the [TensorFlow documentation](https://www.tensorflow.org/install/) for further information.

The core functions are implemented in `semisup/backend.py`.
The files `train.py` and `eval.py` demonstrate how to use them. A quick example is contained in `mnist_train_eval.py`.

In order to reproduce the results from the paper, please use the architectures and pipelines from `tools/{stl10,svhn,synth}.py`. They are loaded automatically by setting the flag `[target_]dataset` in `{train,eval}.py` accordingly.

Before you get started, please make sure to add the following to your `~/.bashrc`:
```
export PYTHONPATH=/path/to/learning_by_association:$PYTHONPATH
```

Copy the file `semisup/tools/data_dirs.py.template` to `semisup/tools/data_dirs.py`, adapt the paths, and add the new file to your `.gitignore`.

## Domain Adaptation Hyperparameters
### Synth. Signs -> GTSRB
```
"target_dataset": "gtsrb",
"walker_weight_envelope_delay": "0",
"max_checkpoints": 5,
"dataset": "synth_signs",
"visit_weight": "0.1",
"sup_per_batch": 24,
"walker_weight_envelope_steps": 1,
"eval_batch_size": 24,
"walker_weight_envelope": "linear",
"unsup_batch_size": 1032,
"visit_weight_envelope": "linear",
"decay_steps": 9000,
"sup_per_class": -1,
"max_steps": 12000,
"architecture": "svhn_model"
```

### MNIST -> MNIST-M
```
"target_dataset": "mnistm",
"walker_weight_envelope_delay": "500",
"max_checkpoints": 5,
"new_size": 32,
"dataset": "mnist3",
"visit_weight": "0.6",
"augmentation": true,
"walker_weight_envelope_steps": 1,
"walker_weight_envelope": "linear",
"unsup_batch_size": 1000,
"visit_weight_envelope": "linear",
"decay_steps": 9000,
"architecture": "svhn_model",
"sup_per_class": -1,
"sup_per_batch": 100,
"max_steps": "12000"
```

### SVHN -> MNIST
```
"target_dataset": "mnist3",
"walker_weight_envelope_delay": "500",
"max_checkpoints": 5,
"new_size": 32,
"dataset": "svhn",
"sup_per_batch": 100,
"decay_steps": 9000,
"unsup_batch_size": 1000,
"sup_per_class": -1,
"walker_weight_envelope_steps": 1,
"walker_weight_envelope": "linear",
"visit_weight_envelope": "linear",
"architecture": "svhn_model",
"visit_weight": 0.2,
"max_steps": "12000"
```
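The blocks above are sets of flag values. Assuming `train.py` exposes these keys as command-line flags (as `eval.py` does for its subset), the SVHN -> MNIST setting would translate into roughly the following invocation. This is a sketch, not a verbatim command; check the flag definitions in `train.py` for your checkout, and the `--logdir` path is just an example:
```
python semisup/train.py \
  --dataset=svhn \
  --target_dataset=mnist3 \
  --architecture=svhn_model \
  --new_size=32 \
  --sup_per_class=-1 \
  --sup_per_batch=100 \
  --unsup_batch_size=1000 \
  --visit_weight=0.2 \
  --walker_weight_envelope=linear \
  --walker_weight_envelope_delay=500 \
  --walker_weight_envelope_steps=1 \
  --visit_weight_envelope=linear \
  --decay_steps=9000 \
  --max_steps=12000 \
  --logdir=/tmp/semisup_svhn_mnist
```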
### Synth. Digits -> SVHN
```
"target_dataset": "svhn",
"walker_weight_envelope_delay": "2000",
"max_checkpoints": 5,
"dataset": "synth",
"sup_per_class": -1,
"sup_per_batch": 100,
"walker_weight_envelope_steps": 1,
"walker_weight_envelope": "linear",
"decay_steps": 9000,
"unsup_batch_size": 1000,
"visit_weight_envelope": "linear",
"architecture": "svhn_model",
"visit_weight": 0.2,
"max_steps": "20000"
```

If you use the code, please cite the paper "Learning by Association - A versatile semi-supervised training method for neural networks" or "Associative Domain Adaptation":
```
@string{cvpr="IEEE Conference on Computer Vision and Pattern Recognition (CVPR)"}
@InProceedings{haeusser-cvpr-17,
  author = "P. Haeusser and A. Mordvintsev and D. Cremers",
  title = "Learning by Association - A versatile semi-supervised training method for neural networks",
  booktitle = cvpr,
  year = "2017",
}

@string{iccv="IEEE International Conference on Computer Vision (ICCV)"}
@InProceedings{haeusser-iccv-17,
  author = "P. Haeusser and T. Frerix and A. Mordvintsev and D. Cremers",
  title = "Associative Domain Adaptation",
  booktitle = iccv,
  year = "2017",
}
```

For questions please contact Philip Haeusser (haeusser@cs.tum.edu).

--------------------------------------------------------------------------------
/semisup/.gitignore:
--------------------------------------------------------------------------------
/data_dirs.py
--------------------------------------------------------------------------------
/semisup/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import,wildcard-import
from . import architectures
from .backend import *
--------------------------------------------------------------------------------
/semisup/architectures.py:
--------------------------------------------------------------------------------
"""
Copyright 2016 Google Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


Definitions and utilities for the svhn model.

This file contains functions that are needed for semisup training and
evaluation on the SVHN dataset.
They are used in train.py and eval.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow.contrib.slim as slim


def svhn_model(inputs,
               is_training=True,
               augmentation_function=None,
               emb_size=128,
               l2_weight=1e-4,
               img_shape=None,
               new_shape=None,
               image_summary=False,
               batch_norm_decay=0.99):  # pylint: disable=unused-argument
  """Construct the image-to-embedding vector model."""
  inputs = tf.cast(inputs, tf.float32)
  if new_shape is not None:
    shape = new_shape
    inputs = tf.image.resize_images(
        inputs,
        tf.constant(new_shape[:2]),
        method=tf.image.ResizeMethod.BILINEAR)
  else:
    shape = img_shape
  if is_training and augmentation_function is not None:
    inputs = augmentation_function(inputs, shape)
  if image_summary:
    tf.summary.image('Inputs', inputs, max_outputs=3)

  net = inputs
  mean = tf.reduce_mean(net, [1, 2], True)
  std = tf.reduce_mean(tf.square(net - mean), [1, 2], True)
  net = (net - mean) / (std + 1e-5)
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      activation_fn=tf.nn.elu,
      weights_regularizer=slim.l2_regularizer(l2_weight)):
    with slim.arg_scope([slim.dropout], is_training=is_training):
      net = slim.conv2d(net, 32, [3, 3], scope='conv1')
      net = slim.conv2d(net, 32, [3, 3], scope='conv1_2')
      net = slim.conv2d(net, 32, [3, 3], scope='conv1_3')
      net = slim.max_pool2d(net, [2, 2], scope='pool1')  # 14
      net = slim.conv2d(net, 64, [3, 3], scope='conv2_1')
      net = slim.conv2d(net, 64, [3, 3], scope='conv2_2')
      net = slim.conv2d(net, 64, [3, 3], scope='conv2_3')
      net = slim.max_pool2d(net, [2, 2], scope='pool2')  # 7
      net = slim.conv2d(net, 128, [3, 3], scope='conv3')
      net = slim.conv2d(net, 128, [3, 3], scope='conv3_2')
      net = slim.conv2d(net, 128, [3, 3], scope='conv3_3')
      net = slim.max_pool2d(net, [2, 2], scope='pool3')  # 3
      net = slim.flatten(net, scope='flatten')

      with slim.arg_scope([slim.fully_connected], normalizer_fn=None):
        emb = slim.fully_connected(net, emb_size, scope='fc1')

  return emb


def dann_model(inputs,
               is_training=True,
               augmentation_function=None,
               emb_size=2048,
               l2_weight=1e-4,
               img_shape=None,
               new_shape=None,
               image_summary=False,
               batch_norm_decay=0.99):  # pylint: disable=unused-argument
  """Construct the image-to-embedding vector model."""
  inputs = tf.cast(inputs, tf.float32)
  if new_shape is not None:
    shape = new_shape
    inputs = tf.image.resize_images(
        inputs,
        tf.constant(new_shape[:2]),
        method=tf.image.ResizeMethod.BILINEAR)
  else:
    shape = img_shape
  if is_training and augmentation_function is not None:
    inputs = augmentation_function(inputs, shape)
  if image_summary:
    tf.summary.image('Inputs', inputs, max_outputs=3)

  net = inputs
  mean = tf.reduce_mean(net, [1, 2], True)
  std = tf.reduce_mean(tf.square(net - mean), [1, 2], True)
  net = (net - mean) / (std + 1e-5)
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      activation_fn=tf.nn.relu,
      weights_regularizer=slim.l2_regularizer(l2_weight)):
    with slim.arg_scope([slim.dropout], is_training=is_training):
      # TODO(tfrerix) from here on
      net = slim.conv2d(net, 32, [3, 3], scope='conv1')
      net = 
slim.conv2d(net, 32, [3, 3], scope='conv1_2')
      net = slim.conv2d(net, 32, [3, 3], scope='conv1_3')
      net = slim.max_pool2d(net, [2, 2], scope='pool1')  # 14
      net = slim.conv2d(net, 64, [3, 3], scope='conv2_1')
      net = slim.conv2d(net, 64, [3, 3], scope='conv2_2')
      net = slim.conv2d(net, 64, [3, 3], scope='conv2_3')
      net = slim.max_pool2d(net, [2, 2], scope='pool2')  # 7
      net = slim.conv2d(net, 128, [3, 3], scope='conv3')
      net = slim.conv2d(net, 128, [3, 3], scope='conv3_2')
      net = slim.conv2d(net, 128, [3, 3], scope='conv3_3')
      net = slim.max_pool2d(net, [2, 2], scope='pool3')  # 3
      net = slim.flatten(net, scope='flatten')

      with slim.arg_scope([slim.fully_connected], normalizer_fn=None):
        emb = slim.fully_connected(net, emb_size, scope='fc1')

  return emb


def stl10_model(inputs,
                is_training=True,
                augmentation_function=None,
                emb_size=128,
                img_shape=None,
                new_shape=None,
                image_summary=False,
                batch_norm_decay=0.99):
  """Construct the image-to-embedding model."""
  inputs = tf.cast(inputs, tf.float32)
  if new_shape is not None:
    shape = new_shape
    inputs = tf.image.resize_images(
        inputs,
        tf.constant(new_shape[:2]),
        method=tf.image.ResizeMethod.BILINEAR)
  else:
    shape = img_shape
  if is_training and augmentation_function is not None:
    inputs = augmentation_function(inputs, shape)
  if image_summary:
    tf.summary.image('Inputs', inputs, max_outputs=3)
  net = inputs
  net = (net - 128.0) / 128.0
  with slim.arg_scope([slim.dropout], is_training=is_training):
    with slim.arg_scope(
        [slim.conv2d, slim.fully_connected],
        normalizer_fn=slim.batch_norm,
        normalizer_params={
            'is_training': is_training,
            'decay': batch_norm_decay
        },
        activation_fn=tf.nn.elu,
        weights_regularizer=slim.l2_regularizer(5e-3), ):
      with slim.arg_scope([slim.conv2d], padding='SAME'):
        with slim.arg_scope([slim.dropout], is_training=is_training):
          net = slim.conv2d(net, 32, [3, 3], scope='conv_s2')  #
          net = slim.conv2d(net, 64, [3, 3], stride=2, scope='conv1')
          net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')  #
          net = slim.conv2d(net, 64, [3, 3], scope='conv2')
          net = slim.conv2d(net, 128, [3, 3], scope='conv2_2')
          net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool2')  #
          net = slim.conv2d(net, 128, [3, 3], scope='conv3_1')
          net = slim.conv2d(net, 256, [3, 3], scope='conv3_2')
          net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool3')  #
          net = slim.conv2d(net, 128, [3, 3], scope='conv4')
          net = slim.flatten(net, scope='flatten')

          with slim.arg_scope([slim.fully_connected], normalizer_fn=None):
            emb = slim.fully_connected(
                net, emb_size, activation_fn=None, scope='fc1')
  return emb


def mnist_model(inputs,
                is_training=True,
                emb_size=128,
                l2_weight=1e-3,
                batch_norm_decay=None,
                img_shape=None,
                new_shape=None,
                augmentation_function=None,
                image_summary=False):  # pylint: disable=unused-argument

  """Construct the image-to-embedding vector model."""

  inputs = tf.cast(inputs, tf.float32) / 255.0
  if new_shape is not None:
    shape = new_shape
    inputs = tf.image.resize_images(
        inputs,
        tf.constant(new_shape[:2]),
        method=tf.image.ResizeMethod.BILINEAR)
  else:
    shape = img_shape
  net = inputs
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      activation_fn=tf.nn.elu,
      weights_regularizer=slim.l2_regularizer(l2_weight)):
    net = slim.conv2d(net, 32, [3, 3], scope='conv1_1')
    net = slim.conv2d(net, 32, [3, 3], scope='conv1_2')
    net = slim.max_pool2d(net, [2, 2], scope='pool1')  # 14

    net = slim.conv2d(net, 64, [3, 3], scope='conv2_1')
    net = slim.conv2d(net, 64, [3, 3], scope='conv2_2')
    net = slim.max_pool2d(net, [2, 2], scope='pool2')  # 7

    net = slim.conv2d(net, 128, [3, 3], scope='conv3_1')
    net = slim.conv2d(net, 128, [3, 3], scope='conv3_2')
    net = slim.max_pool2d(net, [2, 2], scope='pool3')  # 3

    net = slim.flatten(net, scope='flatten')
    emb = slim.fully_connected(net, emb_size, scope='fc1')
  return emb


def inception_model(inputs,
                    emb_size=128,
                    is_training=True,
                    end_point='Mixed_7c',
                    augmentation_function=None,
                    img_shape=None,
                    new_shape=None,
                    batch_norm_decay=None,
                    dropout_keep_prob=0.8,
                    min_depth=16,
                    depth_multiplier=1.0,
                    spatial_squeeze=True,
                    reuse=None,
                    scope='InceptionV3',
                    num_classes=10,
                    **kwargs):
  from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_base
  from tensorflow.python.ops import variable_scope
  from tensorflow.contrib.framework.python.ops import arg_scope
  from tensorflow.contrib.layers.python.layers import layers as layers_lib

  inputs = tf.cast(inputs, tf.float32) / 255.0
  if new_shape is not None:
    shape = new_shape
    inputs = tf.image.resize_images(
        inputs,
        tf.constant(new_shape[:2]),
        method=tf.image.ResizeMethod.BILINEAR)
  else:
    shape = img_shape

  net = inputs
  mean = tf.reduce_mean(net, [1, 2], True)
  std = tf.reduce_mean(tf.square(net - mean), [1, 2], True)
  net = (net - mean) / (std + 1e-5)

  inputs = net

  if depth_multiplier <= 0:
    raise ValueError('depth_multiplier is not greater than zero.')

  with variable_scope.variable_scope(
      scope, 'InceptionV3', [inputs, num_classes], reuse=reuse) as scope:
    with arg_scope(
        [layers_lib.batch_norm, layers_lib.dropout], is_training=is_training):
      _, end_points = inception_v3_base(
          inputs,
          scope=scope,
          min_depth=min_depth,
          depth_multiplier=depth_multiplier,
          final_endpoint=end_point)

  net = end_points[end_point]
  net = slim.flatten(net, scope='flatten')
  with slim.arg_scope([slim.fully_connected], normalizer_fn=None):
    emb = slim.fully_connected(net, emb_size, scope='fc')
  return emb


def inception_model_small(inputs,
                          emb_size=128,
                          is_training=True,
                          **kwargs):
  return inception_model(inputs=inputs, emb_size=emb_size, is_training=is_training,
                         end_point='Mixed_5d', **kwargs)


def vgg16_model(inputs, emb_size=128, is_training=True, img_shape=None, new_shape=None, dropout_keep_prob=0.5,
                l2_weight=0.0005,
                end_point=None, **kwargs):
  inputs = tf.cast(inputs, tf.float32)
  if new_shape is not None:
    shape = new_shape
    inputs = tf.image.resize_images(
        inputs,
        tf.constant(new_shape[:2]),
        method=tf.image.ResizeMethod.BILINEAR)
  else:
    shape = img_shape

  net = inputs
  mean = 
tf.reduce_mean(net, [1, 2], True)
  std = tf.reduce_mean(tf.square(net - mean), [1, 2], True)
  net = (net - mean) / (std + 1e-5)
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(l2_weight)):
    with slim.arg_scope([slim.dropout], is_training=is_training):
      net = slim.repeat(net, 2, slim.conv2d, 64, [3, 3], scope='conv1')  # 100
      net = slim.max_pool2d(net, [2, 2], scope='pool1')  # 50
      net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
      net = slim.max_pool2d(net, [2, 2], scope='pool2')  # 25
      net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
      net = slim.max_pool2d(net, [2, 2], scope='pool3')  # 12
      net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
      net = slim.max_pool2d(net, [2, 2], scope='pool4')  # 6
      net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
      net = slim.max_pool2d(net, [2, 2], scope='pool5')  # 3
      net = slim.flatten(net, scope='flatten')

      with slim.arg_scope([slim.fully_connected], normalizer_fn=None):
        net = slim.fully_connected(net, 4096, activation_fn=tf.nn.relu, scope='fc6')
        if end_point == 'fc6':
          return net
        net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                           scope='dropout6')
        emb = slim.fully_connected(net, emb_size, activation_fn=None, scope='fc7')

  return emb


def vgg16_model_small(inputs, emb_size=128, is_training=True, img_shape=None, new_shape=None, dropout_keep_prob=0.5,
                      **kwargs):
  return vgg16_model(inputs, emb_size, is_training, img_shape, new_shape, dropout_keep_prob, end_point='fc6',
                     **kwargs)


def alexnet_model(inputs,
                  is_training=True,
                  augmentation_function=None,
                  emb_size=128,
                  l2_weight=1e-4,
                  img_shape=None,
                  new_shape=None,
                  image_summary=False,
                  batch_norm_decay=0.99):
  """Mostly identical to slim.nets.alexnet, except for the reverted fc layers"""

  from tensorflow.contrib import layers
  from tensorflow.contrib.framework.python.ops import arg_scope
  from tensorflow.contrib.layers.python.layers import layers as layers_lib
  from tensorflow.contrib.layers.python.layers import regularizers
  from tensorflow.python.ops import init_ops
  from tensorflow.python.ops import nn_ops
  from tensorflow.python.ops import variable_scope

  trunc_normal = lambda stddev: init_ops.truncated_normal_initializer(0.0, stddev)

  def alexnet_v2_arg_scope(weight_decay=0.0005):
    with arg_scope(
        [layers.conv2d, layers_lib.fully_connected],
        activation_fn=nn_ops.relu,
        biases_initializer=init_ops.constant_initializer(0.1),
        weights_regularizer=regularizers.l2_regularizer(weight_decay)):
      with arg_scope([layers.conv2d], padding='SAME'):
        with arg_scope([layers_lib.max_pool2d], padding='VALID') as arg_sc:
          return arg_sc

  def alexnet_v2(inputs,
                 is_training=True,
                 emb_size=4096,
                 dropout_keep_prob=0.5,
                 scope='alexnet_v2'):

    inputs = tf.cast(inputs, tf.float32)
    if new_shape is not None:
      shape = new_shape
      inputs = tf.image.resize_images(
          inputs,
          tf.constant(new_shape[:2]),
          method=tf.image.ResizeMethod.BILINEAR)
    else:
      shape = img_shape
    if is_training and augmentation_function is not None:
      inputs = 
augmentation_function(inputs, shape)
    if image_summary:
      tf.summary.image('Inputs', inputs, max_outputs=3)

    net = inputs
    mean = tf.reduce_mean(net, [1, 2], True)
    std = tf.reduce_mean(tf.square(net - mean), [1, 2], True)
    net = (net - mean) / (std + 1e-5)
    inputs = net

    with variable_scope.variable_scope(scope, 'alexnet_v2', [inputs]) as sc:
      end_points_collection = sc.original_name_scope + '_end_points'

      # Collect outputs for conv2d, fully_connected and max_pool2d.
      with arg_scope(
          [layers.conv2d, layers_lib.fully_connected, layers_lib.max_pool2d],
          outputs_collections=[end_points_collection]):
        net = layers.conv2d(
            inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
        net = layers_lib.max_pool2d(net, [3, 3], 2, scope='pool1')
        net = layers.conv2d(net, 192, [5, 5], scope='conv2')
        net = layers_lib.max_pool2d(net, [3, 3], 2, scope='pool2')
        net = layers.conv2d(net, 384, [3, 3], scope='conv3')
        net = layers.conv2d(net, 384, [3, 3], scope='conv4')
        net = layers.conv2d(net, 256, [3, 3], scope='conv5')
        net = layers_lib.max_pool2d(net, [3, 3], 2, scope='pool5')

        net = slim.flatten(net, scope='flatten')

        # Plain fully connected layers (reverted from the conv2d
        # equivalents used in slim's alexnet).
        with arg_scope(
            [slim.fully_connected],
            weights_initializer=trunc_normal(0.005),
            biases_initializer=init_ops.constant_initializer(0.1)):
          net = layers.fully_connected(net, 4096, scope='fc6')
          net = layers_lib.dropout(
              net, dropout_keep_prob, is_training=is_training, scope='dropout6')
          net = layers.fully_connected(net, emb_size, scope='fc7')

    return net

  with slim.arg_scope(alexnet_v2_arg_scope()):
    return alexnet_v2(inputs, is_training, emb_size)
--------------------------------------------------------------------------------
/semisup/backend.py:
--------------------------------------------------------------------------------
"""
Copyright 2016 Google Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Utility functions for Association-based semisupervised training.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import tensorflow as tf
import tensorflow.contrib.slim as slim


def create_input(input_images, input_labels, batch_size):
  """Create preloaded data batch inputs.

  Args:
    input_images: 4D numpy array of input images.
    input_labels: 2D numpy array of labels.
    batch_size: Size of batches that will be produced.

  Returns:
    A list containing the images and labels batches.
  """
  if input_labels is not None:
    image, label = tf.train.slice_input_producer([input_images, input_labels])
    return tf.train.batch([image, label], batch_size=batch_size)
  else:
    image = tf.train.slice_input_producer([input_images])
    return tf.train.batch(image, batch_size=batch_size)


def create_per_class_inputs(image_by_class, n_per_class, class_labels=None):
  """Create batch inputs with specified number of samples per class.

  Args:
    image_by_class: List of image arrays, where image_by_class[i] contains
        images sampled from the class class_labels[i].
    n_per_class: Number of samples per class in the output batch.
    class_labels: List of class labels. Defaults to range(len(image_by_class))
        if not provided.

  Returns:
    images: Tensor of n_per_class*len(image_by_class) images.
    labels: Tensor of same number of labels.
  """
  if class_labels is None:
    class_labels = np.arange(len(image_by_class))
  batch_images, batch_labels = [], []
  for images, label in zip(image_by_class, class_labels):
    labels = tf.fill([len(images)], label)
    images, labels = create_input(images, labels, n_per_class)
    batch_images.append(images)
    batch_labels.append(labels)
  return tf.concat(batch_images, 0), tf.concat(batch_labels, 0)


def sample_by_label(images, labels, n_per_label, num_labels, seed=None):
  """Extract an equal number of samples per class."""
  res = []
  rng = np.random.RandomState(seed=seed)
  for i in xrange(num_labels):
    a = images[labels == i]
    if n_per_label == -1:  # use all available labeled data
      res.append(a)
    else:  # use randomly chosen subset
      res.append(a[rng.choice(len(a), n_per_label, False)])
  return res


def create_virt_emb(n, size):
  """Create virtual embeddings."""
  emb = slim.variables.model_variable(
      name='virt_emb',
      shape=[n, size],
      dtype=tf.float32,
      trainable=True,
      initializer=tf.random_normal_initializer(stddev=0.01))
  return emb


def confusion_matrix(labels, predictions, num_labels):
  """Compute the confusion matrix."""
  rows = []
  for i in xrange(num_labels):
    row = np.bincount(predictions[labels == i], minlength=num_labels)
    rows.append(row)
  return np.vstack(rows)

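# A typical wiring of the helpers above, mirroring mnist_train_eval.py
# (NUM_LABELS and the image/label arrays come from the dataset tools):
#
#   sup_by_label = sample_by_label(train_images, train_labels,
#                                  n_per_label=10, num_labels=NUM_LABELS)
#   t_sup_images, t_sup_labels = create_per_class_inputs(sup_by_label,
#                                                        n_per_class=10)
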
class SemisupModel(object):
  """Helper class for setting up semi-supervised training."""

  def __init__(self, model_func, num_labels, input_shape, test_in=None):
    """Initialize SemisupModel class.

    Creates an evaluation graph for the provided model_func.

    Args:
      model_func: Model function. It should receive a tensor of images as
          the first argument, along with the 'is_training' flag.
      num_labels: Number of target classes.
      input_shape: List, containing input images shape in form
          [height, width, channel_num].
      test_in: None or a tensor holding test images. If None, a placeholder
          will be created.
    """

    self.num_labels = num_labels
    self.step = slim.get_or_create_global_step()
    self.ema = tf.train.ExponentialMovingAverage(0.99, self.step)

    self.test_batch_size = 100

    self.model_func = model_func

    if test_in is not None:
      self.test_in = test_in
    else:
      self.test_in = tf.placeholder(np.float32, [None] + input_shape, 'test_in')

    self.test_emb = self.image_to_embedding(self.test_in, is_training=False)
    self.test_logit = self.embedding_to_logit(self.test_emb, is_training=False)

  def image_to_embedding(self, images, is_training=True):
    """Create a graph, transforming images into embedding vectors."""
    with tf.variable_scope('net', reuse=is_training):
      return self.model_func(images, is_training=is_training)

  def embedding_to_logit(self, embedding, is_training=True):
    """Create a graph, transforming embedding vectors to logit class scores."""
    with tf.variable_scope('net', reuse=is_training):
      return slim.fully_connected(
          embedding,
          self.num_labels,
          activation_fn=None,
          weights_regularizer=slim.l2_regularizer(1e-4))

  def add_semisup_loss(self, a, b, labels, walker_weight=1.0, visit_weight=1.0):
    """Add semi-supervised classification loss to the model.

    The loss consists of two terms: "walker" and "visit".

    Args:
      a: [N, emb_size] tensor with supervised embedding vectors.
      b: [M, emb_size] tensor with unsupervised embedding vectors.
      labels: [N] tensor with labels for supervised embeddings.
      walker_weight: Weight coefficient of the "walker" loss.
      visit_weight: Weight coefficient of the "visit" loss.
    """

    equality_matrix = tf.equal(tf.reshape(labels, [-1, 1]), labels)
    equality_matrix = tf.cast(equality_matrix, tf.float32)
    p_target = (equality_matrix / tf.reduce_sum(
        equality_matrix, [1], keep_dims=True))

    match_ab = tf.matmul(a, b, transpose_b=True, name='match_ab')
    p_ab = tf.nn.softmax(match_ab, name='p_ab')
    p_ba = tf.nn.softmax(tf.transpose(match_ab), name='p_ba')
    p_aba = tf.matmul(p_ab, p_ba, name='p_aba')

    self.create_walk_statistics(p_aba, equality_matrix)

    loss_aba = tf.losses.softmax_cross_entropy(
        p_target,
        tf.log(1e-8 + p_aba),
        weights=walker_weight,
        scope='loss_aba')
    self.add_visit_loss(p_ab, visit_weight)

    tf.summary.scalar('Loss_aba', loss_aba)

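  # Numeric illustration of the walker loss (comment only): with two labeled
  # samples of the same class, p_target is [[0.5, 0.5], [0.5, 0.5]], and the
  # loss drives the round-trip distribution p_aba = tf.matmul(p_ab, p_ba)
  # toward it, i.e. a walk a -> b -> a should end at any sample of the
  # starting class with equal probability, without prescribing which
  # unlabeled sample is visited in between.
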
  def add_visit_loss(self, p, weight=1.0):
    """Add the "visit" loss to the model.

    Args:
      p: [N, M] tensor. Each row must be a valid probability distribution
          (i.e. sum to 1.0)
      weight: Loss weight.
    """
    visit_probability = tf.reduce_mean(
        p, [0], keep_dims=True, name='visit_prob')
    t_nb = tf.shape(p)[1]
    visit_loss = tf.losses.softmax_cross_entropy(
        tf.fill([1, t_nb], 1.0 / tf.cast(t_nb, tf.float32)),
        tf.log(1e-8 + visit_probability),
        weights=weight,
        scope='loss_visit')

    tf.summary.scalar('Loss_Visit', visit_loss)

  def add_logit_loss(self, logits, labels, weight=1.0, smoothing=0.0):
    """Add supervised classification loss to the model."""

    logit_loss = tf.losses.softmax_cross_entropy(
        tf.one_hot(labels, logits.get_shape()[-1]),
        logits,
        scope='loss_logit',
        weights=weight,
        label_smoothing=smoothing)

    tf.summary.scalar('Loss_Logit', logit_loss)

  def create_walk_statistics(self, p_aba, equality_matrix):
    """Adds "walker" loss statistics to the graph.

    Args:
      p_aba: [N, N] matrix, where element [i, j] corresponds to the
          probability of the round-trip between supervised samples i and j.
          Sum of each row of 'p_aba' must be equal to one.
      equality_matrix: [N, N] boolean matrix, [i,j] is True, when samples
          i and j belong to the same class.
    """
    # Using the square root of the correct round-trip probability as an
    # estimate of the current classifier accuracy.
    per_row_accuracy = 1.0 - tf.reduce_sum((equality_matrix * p_aba), 1)**0.5
    estimate_error = tf.reduce_mean(
        1.0 - per_row_accuracy, name=p_aba.name[:-2] + '_esterr')
    self.add_average(estimate_error)
    self.add_average(p_aba)

    tf.summary.scalar('Stats_EstError', estimate_error)

  def add_average(self, variable):
    """Add moving average variable to the model."""
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, self.ema.apply([variable]))
    average_variable = tf.identity(
        self.ema.average(variable), name=variable.name[:-2] + '_avg')
    return average_variable

  def create_train_op(self, learning_rate):
    """Create and return training operation."""

    slim.model_analyzer.analyze_vars(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), print_info=True)

    self.train_loss = tf.losses.get_total_loss()
    self.train_loss_average = self.add_average(self.train_loss)

    tf.summary.scalar('Learning_Rate', learning_rate)
    tf.summary.scalar('Loss_Total_Avg', self.train_loss_average)
    tf.summary.scalar('Loss_Total', self.train_loss)

    trainer = tf.train.AdamOptimizer(learning_rate)

    self.train_op = slim.learning.create_train_op(self.train_loss, trainer)
    return self.train_op

  def calc_embedding(self, images, endpoint):
    """Evaluate 'endpoint' tensor for all 'images' using batches."""
    batch_size = self.test_batch_size
    emb = []
    for i in xrange(0, len(images), batch_size):
      emb.append(endpoint.eval({self.test_in: images[i:i + batch_size]}))
    return np.concatenate(emb)

  def classify(self, images):
    """Compute logit scores for provided images."""
    return self.calc_embedding(images, self.test_logit)
--------------------------------------------------------------------------------
/semisup/eval.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
"""
Copyright 2016 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Association-based semi-supervised eval module.

This script defines the evaluation loop that works with the training loop
from train.py.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
from functools import partial
from importlib import import_module

import semisup
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python.platform import app
from tensorflow.python.platform import flags

FLAGS = flags.FLAGS

flags.DEFINE_string('dataset', 'svhn', 'Which dataset to work on.')

flags.DEFINE_string('architecture', 'svhn_model', 'Which architecture to use.')

flags.DEFINE_integer('eval_batch_size', 500, 'Batch size for eval loop.')

flags.DEFINE_integer('new_size', 0, 'If > 0, resize image to this width/height. '
                     'Needs to match size used for training.')

flags.DEFINE_integer('emb_size', 128,
                     'Size of the embeddings to learn.')

flags.DEFINE_integer('eval_interval_secs', 300,
                     'How many seconds between executions of the eval loop.')

flags.DEFINE_string('logdir', '/tmp/semisup',
                    'Where the checkpoints are stored '
                    'and eval events will be written to.')

flags.DEFINE_string('master', '',
                    'BNS name of the TensorFlow master to use.')

flags.DEFINE_integer('timeout', 1200,
                     'The maximum amount of time to wait between checkpoints. '
                     'If left as `None`, then the process will wait '
                     'indefinitely.')


def main(_):
  # Get dataset-related toolbox.
  dataset_tools = import_module('tools.' + FLAGS.dataset)
  architecture = getattr(semisup.architectures, FLAGS.architecture)

  num_labels = dataset_tools.NUM_LABELS
  image_shape = dataset_tools.IMAGE_SHAPE

  test_images, test_labels = dataset_tools.get_data('test')

  graph = tf.Graph()
  with graph.as_default():

    # Set up input pipeline.
    image, label = tf.train.slice_input_producer([test_images, test_labels])
    images, labels = tf.train.batch(
        [image, label], batch_size=FLAGS.eval_batch_size)
    images = tf.cast(images, tf.float32)
    labels = tf.cast(labels, tf.int64)

    # Reshape if necessary.
    if FLAGS.new_size > 0:
      new_shape = [FLAGS.new_size, FLAGS.new_size, 3]
    else:
      new_shape = None

    # Create function that defines the network.
    model_function = partial(
        architecture,
        is_training=False,
        new_shape=new_shape,
        img_shape=image_shape,
        augmentation_function=None,
        image_summary=False,
        emb_size=FLAGS.emb_size)

    # Set up semisup model.
    model = semisup.SemisupModel(
        model_function,
        num_labels,
        image_shape,
        test_in=images)

    # Add moving average variables.
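    # Adding the 'moving_vars' collection and all model variables to
    # MOVING_AVERAGE_VARIABLES lets the evaluation pick up the exponential
    # moving averages maintained during training (see
    # SemisupModel.add_average()) rather than the raw variable values.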
    for var in tf.get_collection('moving_vars'):
      tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
    for var in slim.get_model_variables():
      tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)

    # Get prediction tensor from semisup model.
    predictions = tf.argmax(model.test_logit, 1)

    # Accuracy metric for summaries.
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
    })
    for name, value in names_to_values.iteritems():
      tf.summary.scalar(name, value)

    # Run the actual evaluation loop.
    num_batches = math.ceil(len(test_labels) / float(FLAGS.eval_batch_size))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.logdir + '/train',
        logdir=FLAGS.logdir + '/eval',
        num_evals=num_batches,
        eval_op=names_to_updates.values(),
        eval_interval_secs=FLAGS.eval_interval_secs,
        session_config=config,
        timeout=FLAGS.timeout
    )


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  app.run()
--------------------------------------------------------------------------------
/semisup/mnist_train_eval.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
"""
Copyright 2016 Google Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Association-based semi-supervised training example on the MNIST dataset.
Training should reach an error rate of ~1% on the test set using 100 labeled
samples within 5000-10000 steps (a few minutes on a Titan X GPU).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import semisup

from tensorflow.python.platform import app
from tensorflow.python.platform import flags

FLAGS = flags.FLAGS

flags.DEFINE_integer('sup_per_class', 10,
                     'Number of labeled samples used per class.')

flags.DEFINE_integer('sup_seed', -1,
                     'Integer random seed used for labeled set selection.')

flags.DEFINE_integer('sup_per_batch', 10,
                     'Number of labeled samples per class per batch.')

flags.DEFINE_integer('unsup_batch_size', 100,
                     'Number of unlabeled samples per batch.')

flags.DEFINE_integer('eval_interval', 500,
                     'Number of steps between evaluations.')

flags.DEFINE_float('learning_rate', 1e-3, 'Initial learning rate.')

flags.DEFINE_float('decay_factor', 0.33, 'Learning rate decay factor.')

flags.DEFINE_float('decay_steps', 5000,
                   'Learning rate decay interval in steps.')

flags.DEFINE_float('visit_weight', 1.0, 'Weight for visit loss.')

flags.DEFINE_integer('max_steps', 20000, 'Number of training steps.')

flags.DEFINE_string('logdir', '/tmp/semisup_mnist', 'Training log path.')

from tools import mnist as mnist_tools

NUM_LABELS = mnist_tools.NUM_LABELS
IMAGE_SHAPE = mnist_tools.IMAGE_SHAPE


def main(_):
  train_images, train_labels = mnist_tools.get_data('train')
  test_images, test_labels = mnist_tools.get_data('test')

  # Sample labeled training subset.
  seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
  sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                         FLAGS.sup_per_class, NUM_LABELS, seed)

  graph = tf.Graph()
  with graph.as_default():
    model = semisup.SemisupModel(semisup.architectures.mnist_model, NUM_LABELS,
                                 IMAGE_SHAPE)

    # Set up inputs.
    t_unsup_images, _ = semisup.create_input(train_images, train_labels,
                                             FLAGS.unsup_batch_size)
    t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
        sup_by_label, FLAGS.sup_per_batch)

    # Compute embeddings and logits.
    t_sup_emb = model.image_to_embedding(t_sup_images)
    t_unsup_emb = model.image_to_embedding(t_unsup_images)
    t_sup_logit = model.embedding_to_logit(t_sup_emb)

    # Add losses.
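    # add_semisup_loss() contributes the two association terms: the "walker"
    # loss pushes round-trip probabilities (labeled -> unlabeled -> labeled)
    # toward a uniform distribution over samples of the same starting class,
    # and the "visit" loss encourages all unlabeled samples to be visited.
    # add_logit_loss() adds the standard supervised cross-entropy on the
    # labeled batch.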
    model.add_semisup_loss(
        t_sup_emb, t_unsup_emb, t_sup_labels, visit_weight=FLAGS.visit_weight)
    model.add_logit_loss(t_sup_logit, t_sup_labels)

    t_learning_rate = tf.train.exponential_decay(
        FLAGS.learning_rate,
        model.step,
        FLAGS.decay_steps,
        FLAGS.decay_factor,
        staircase=True)
    train_op = model.create_train_op(t_learning_rate)
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.logdir, graph)

    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    tf.global_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for step in xrange(FLAGS.max_steps):
      _, summaries = sess.run([train_op, summary_op])
      if (step + 1) % FLAGS.eval_interval == 0 or step == 99:
        print('Step: %d' % step)
        test_pred = model.classify(test_images).argmax(-1)
        conf_mtx = semisup.confusion_matrix(test_labels, test_pred, NUM_LABELS)
        test_err = (test_labels != test_pred).mean() * 100
        print(conf_mtx)
        print('Test error: %.2f %%' % test_err)
        print()

        test_summary = tf.Summary(
            value=[tf.Summary.Value(
                tag='Test Err', simple_value=test_err)])

        summary_writer.add_summary(summaries, step)
        summary_writer.add_summary(test_summary, step)

        saver.save(sess, FLAGS.logdir, model.step)

    coord.request_stop()
    coord.join(threads)


if __name__ == '__main__':
  app.run()
--------------------------------------------------------------------------------
/semisup/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haeusser/learning_by_association/7cdef811af671b60eaedc495600e384fcc1f1aff/semisup/tools/__init__.py
--------------------------------------------------------------------------------
/semisup/tools/data_dirs.py:
--------------------------------------------------------------------------------
"""
This file contains the directory paths to data sets on your hard drive.
Change the directories and .gitignore this file.
"""

synth = '/work/haeusser/data/synth/'
stl10 = '/work/haeusser/data/stl10_binary/'
svhn = '/work/haeusser/data/svhn/'
mnist = '/work/haeusser/data/mnist/'
imagenet = '/work/haeusser/data/imagenet/raw-data/'
imagenet_labels = '/usr/wiss/haeusser/libs/tfmodels/inception/inception/data/imagenet_lsvrc_2015_synsets.txt'
gtsrb = '/work/haeusser/data/gtsrb/'
usps = '/work/haeusser/data/usps/'
office = '/work/haeusser/data/office/'
mnistm = ''
--------------------------------------------------------------------------------
/semisup/tools/data_dirs.py.template:
--------------------------------------------------------------------------------
"""
This file contains the directory paths to data sets on your hard drive.
Change the directories, rename the file to data_dirs.py and .gitignore the file.
4 | """
5 | 
6 | synth = '/path/to/synth/'
7 | stl10 = '/path/to/stl10_binary/'
8 | svhn = '/path/to/svhn/'
9 | mnist = '/path/to/mnist/'
10 | usps = '/path/to/usps/'
11 | imagenet = '/path/to/imagenet/raw-data/'
12 | imagenet_labels = '/path/to/imagenet_lsvrc_2015_synsets.txt'
13 | gtsrb = '/path/to/gtsrb/'
14 | synth_signs = '/path/to/synth_signs/'
15 | office = '/path/to/office/'
16 | mnistm = '/path/to/mnistm/'
17 | 
--------------------------------------------------------------------------------
/semisup/tools/data_util/create_mnistm.py:
--------------------------------------------------------------------------------
1 | import tarfile
2 | import pickle as pkl
3 | import numpy as np
4 | import skimage.io
5 | from tensorflow.examples.tutorials.mnist import input_data
6 | 
7 | print('Retrieving MNIST data...')
8 | mnist = input_data.read_data_sets('MNIST_data')
9 | 
10 | BST_PATH = 'BSR_bsds500.tgz'
11 | 
12 | rand = np.random.RandomState(42)
13 | 
14 | f = tarfile.open(BST_PATH)
15 | train_files = []
16 | for name in f.getnames():
17 |     if name.startswith('BSR/BSDS500/data/images/train/'):
18 |         train_files.append(name)
19 | 
20 | print('Loading BSR training images...')
21 | background_data = []
22 | for name in train_files:
23 |     try:
24 |         fp = f.extractfile(name)
25 |         bg_img = skimage.io.imread(fp)
26 |         background_data.append(bg_img)
27 |     except Exception:
28 |         continue
29 | 
30 | 
31 | def compose_image(digit, background):
32 |     """Difference-blend a digit and a random patch from a background image."""
33 |     w, h, _ = background.shape
34 |     dw, dh, _ = digit.shape
35 |     x = np.random.randint(0, w - dw)
36 |     y = np.random.randint(0, h - dh)
37 | 
38 |     bg = background[x:x + dw, y:y + dh]
39 |     return np.abs(bg - digit).astype(np.uint8)
40 | 
41 | 
42 | def mnist_to_img(x):
43 |     """Binarize MNIST digit and convert to RGB."""
44 |     x = (x > 0).astype(np.float32)
45 |     d = x.reshape([28, 28, 1]) * 255
46 |     return np.concatenate([d, d, d], 2)
47 | 
48 | 
49 | def create_mnistm(X):
50 |     """
51 |     Given an array of MNIST digits, blend random background patches to
52 |     build the MNIST-M dataset as described in
53 |     http://jmlr.org/papers/volume17/15-239/15-239.pdf
54 |     """
55 |     X_ = np.zeros([X.shape[0], 28, 28, 3], np.uint8)
56 |     for i in range(X.shape[0]):
57 | 
58 |         if i % 1000 == 0:
59 |             print(i)
60 | 
61 |         bg_img = rand.choice(background_data)
62 | 
63 |         d = mnist_to_img(X[i])
64 |         d = compose_image(d, bg_img)
65 |         X_[i] = d
66 | 
67 |     return X_
68 | 
69 | 
70 | print('Building train set...')
71 | train = create_mnistm(mnist.train.images)
72 | print('Building test set...')
73 | test = create_mnistm(mnist.test.images)
74 | print('Building validation set...')
75 | valid = create_mnistm(mnist.validation.images)
76 | 
77 | # Save dataset as pickle.
78 | print('Saving data dictionary as pickle file...')
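| # Protocol -1 selects the highest available (binary) pickle protocol, so the
| # file below must be opened in binary mode ('wb').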
79 | with open('mnistm_data.pkl', 'wb') as f:
80 |     pkl.dump({'train_images': train,
81 |               'test_images': test,
82 |               'valid_images': valid,
83 |               'train_labels': mnist.train.labels,
84 |               'test_labels': mnist.test.labels,
85 |               'valid_labels': mnist.validation.labels}, f, -1)
86 | 
--------------------------------------------------------------------------------
/semisup/tools/data_util/create_mnistm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # fetch BSDS500 data
4 | wget https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz
5 | 
6 | # call conversion script via python
7 | python create_mnistm.py
8 | 
9 | # delete temporary files
10 | rm -r MNIST_data/
11 | rm BSR_bsds500.tgz
12 | 
13 | 
14 | 
--------------------------------------------------------------------------------
/semisup/tools/gtsrb.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | 
4 | import csv
5 | import pickle
6 | 
7 | import matplotlib.pyplot as plt
8 | from PIL import Image
9 | import numpy as np
10 | 
11 | import data_dirs
12 | 
13 | DATADIR = data_dirs.gtsrb
14 | 
15 | NUM_LABELS = 43
16 | IMAGE_SHAPE = [40, 40, 3]
17 | 
18 | 
19 | def get_data(name):
20 |     """Utility for convenient data loading."""
21 |     if name in ['train', 'unlabeled']:
22 |         return read_gtsrb_pickle(DATADIR + '/gtsrb_train.p')
23 |     elif name == 'test':
24 |         return read_gtsrb_pickle(DATADIR + '/gtsrb_test.p')
25 | 
26 | 
27 | def read_gtsrb_pickle(filename):
28 |     """
29 |     Extract images and labels from a pickle file.
30 |     :param filename: path to the pickle file.
31 |     :return: tuple (images, labels) of np.arrays.
32 |     """
33 |     with open(filename, mode='rb') as f:
34 |         data = pickle.load(f)
35 |     if not isinstance(data['labels'][0], int):
36 |         labels = [int(x) for x in data['labels']]
37 |     else:
38 |         labels = data['labels']
39 |     return np.array(data['images']), np.array(labels)
40 | 
41 | 
42 | def preprocess_gtsrb(images, roi_boxes, resize_to):
43 |     """
44 |     Crops images to region-of-interest boxes and applies resizing with bilinear
45 |     interpolation.
46 |     :param images: np.array of images
47 |     :param roi_boxes: np.array of region-of-interest boxes of the form
48 |         (left, upper, right, lower)
49 |     :return: np.array of cropped images resized to resize_to.
50 |     """
51 |     preprocessed_images = []
52 |     for idx, img in enumerate(images):
53 |         pil_img = Image.fromarray(img)
54 |         cropped_pil_img = pil_img.crop(roi_boxes[idx])
55 |         resized_pil_img = cropped_pil_img.resize(resize_to, Image.BILINEAR)
56 |         preprocessed_images.append(np.asarray(resized_pil_img))
57 | 
58 |     return np.asarray(preprocessed_images)
59 | 
60 | 
61 | def load_and_append_image_class(prefix, gtFile, images, labels, roi_boxes):
62 |     gtReader = csv.reader(gtFile,
63 |                           delimiter=';')  # csv parser for annotations file
64 |     next(gtReader)  # skip header
65 |     # loop over all images in current annotations file
66 |     for row in gtReader:
67 |         images.append(
68 |             plt.imread(prefix + row[0]))  # the 1st column is the filename
69 |         roi_boxes.append(
70 |             (float(row[3]), float(row[4]), float(row[5]), float(row[6])))
71 |         labels.append(row[7])  # the 8th column is the label
72 |     gtFile.close()
73 | 
74 | 
75 | def preprocess_and_convert_gtsrb_to_pickle(rootpath, pickle_filename,
76 |                                            partition='train'):
77 |     """
78 |     Reads traffic sign data for the German Traffic Sign Recognition Benchmark.
79 |     When loading the test dataset, make sure to have downloaded the EXTENDED
80 |     annotations including the class ids.
81 |     :param rootpath: path to the traffic sign data,
82 |         for example './GTSRB/Training'
83 |     :return: None. Saves the preprocessed images and labels to pickle_filename.
84 |     """
85 |     images = []  # images
86 |     labels = []  # corresponding labels
87 |     roi_boxes = []  # box coordinates for ROI (left, upper, right, lower)
88 | 
89 |     if partition == 'train':
90 |         # loop over all 43 classes
91 |         for c in range(0, NUM_LABELS):
92 |             prefix = rootpath + '/' + format(c, '05d') + '/'  # subdir for class
93 |             gtFile = open(
94 |                 prefix + 'GT-' + format(c, '05d') + '.csv')  # annotations file
95 |             load_and_append_image_class(prefix, gtFile, images, labels,
96 |                                         roi_boxes)
97 |     elif partition == 'test':
98 |         prefix = rootpath + '/'
99 |         gtFile = open(prefix + 'GT-final_test' + '.csv')  # annotations file
100 |         load_and_append_image_class(prefix, gtFile, images, labels, roi_boxes)
101 |     else:
102 |         raise ValueError(
103 |             'The data partition type you have provided is not valid.')
104 | 
105 |     images = np.asarray(images)
106 |     labels = np.asarray(labels)
107 |     roi_boxes = np.asarray(roi_boxes)
108 | 
109 |     preprocessed_images = preprocess_gtsrb(images, roi_boxes,
110 |                                            resize_to=IMAGE_SHAPE[:-1])
111 | 
112 |     pickle.dump({'images': preprocessed_images, 'labels': labels},
113 |                 open(pickle_filename, "wb"))
--------------------------------------------------------------------------------
/semisup/tools/mnist.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2016 Google Inc.
3 | 
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | 
8 |     http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | 
16 | Definitions and utilities for the MNIST model.
17 | 
18 | This file contains functions that are needed for semisup training and evaluation
19 | on the MNIST dataset.
20 | They are used in mnist_train_eval.py.
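| 
| Typical usage (a sketch):
|   images, labels = get_data('train')  # uint8 arrays: [N, 28, 28, 1] and [N]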
21 | 
22 | """
23 | 
24 | from __future__ import division
25 | from __future__ import print_function
26 | 
27 | import gzip
28 | import numpy as np
29 | import data_dirs
30 | 
31 | DATADIR = data_dirs.mnist
32 | 
33 | 
34 | NUM_LABELS = 10
35 | IMAGE_SHAPE = [28, 28, 1]
36 | 
37 | 
38 | def get_data(name):
39 |   """Utility for convenient data loading."""
40 | 
41 |   if name == 'train' or name == 'unlabeled':
42 |     return extract_images(DATADIR +
43 |                           '/train-images-idx3-ubyte.gz'), extract_labels(
44 |                               DATADIR + '/train-labels-idx1-ubyte.gz')
45 |   elif name == 'test':
46 |     return extract_images(DATADIR +
47 |                           '/t10k-images-idx3-ubyte.gz'), extract_labels(
48 |                               DATADIR + '/t10k-labels-idx1-ubyte.gz')
49 | 
50 | 
51 | def _read32(bytestream):
52 |   dt = np.dtype(np.uint32).newbyteorder('>')
53 |   return np.frombuffer(bytestream.read(4), dtype=dt)[0]
54 | 
55 | 
56 | def extract_images(filename):
57 |   """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
58 |   print('Extracting', filename)
59 |   with open(filename, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream:
60 |     magic = _read32(bytestream)
61 |     if magic != 2051:
62 |       raise ValueError('Invalid magic number %d in MNIST image file: %s' %
63 |                        (magic, filename))
64 |     num_images = _read32(bytestream)
65 |     rows = _read32(bytestream)
66 |     cols = _read32(bytestream)
67 |     buf = bytestream.read(rows * cols * num_images)
68 |     data = np.frombuffer(buf, dtype=np.uint8)
69 |     data = data.reshape(num_images, rows, cols, 1)
70 |     return data
71 | 
72 | 
73 | def extract_labels(filename):
74 |   """Extract the labels into a 1D uint8 numpy array [index]."""
75 |   print('Extracting', filename)
76 |   with open(filename, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream:
77 |     magic = _read32(bytestream)
78 |     if magic != 2049:
79 |       raise ValueError('Invalid magic number %d in MNIST label file: %s' %
80 |                        (magic, filename))
81 |     num_items = _read32(bytestream)
82 |     buf = bytestream.read(num_items)
83 |     labels = np.frombuffer(buf, dtype=np.uint8)
84 |     return labels
85 | 
86 | 
--------------------------------------------------------------------------------
/semisup/tools/mnist3.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2016 Google Inc.
3 | 
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | 
8 |     http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | 
16 | Definitions and utilities for the MNIST3 (three-channel MNIST) model.
17 | 
18 | This file contains functions that are needed for semisup training and evaluation
19 | on the MNIST3 dataset.
20 | They are used in mnist_train_eval.py.
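| 
| Typical usage (a sketch):
|   images, labels = get_data('train')  # grayscale replicated to [N, 28, 28, 3]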
21 | 
22 | """
23 | 
24 | import numpy as np
25 | 
26 | import mnist
27 | 
28 | NUM_LABELS = mnist.NUM_LABELS
29 | IMAGE_SHAPE = mnist.IMAGE_SHAPE[:2] + [3]
30 | 
31 | 
32 | def get_data(name):
33 |   """Utility for convenient data loading."""
34 |   images, labels = mnist.get_data(name)
35 |   images = np.concatenate([images] * 3, 3)
36 |   return images, labels
37 | 
--------------------------------------------------------------------------------
/semisup/tools/mnistm.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | 
4 | import pickle
5 | import numpy as np
6 | import data_dirs
7 | 
8 | DATADIR = data_dirs.mnistm
9 | 
10 | 
11 | NUM_LABELS = 10
12 | IMAGE_SHAPE = [28, 28, 3]
13 | 
14 | 
15 | def get_data(name):
16 |     """Utility for convenient data loading."""
17 |     if name in ['train', 'unlabeled']:
18 |         return load_mnistm(DATADIR, 'train')
19 |     elif name == 'test':
20 |         return load_mnistm(DATADIR, 'test')
21 | 
22 | 
23 | def load_mnistm(fileroot, partition):
24 |     with open(fileroot + 'mnistm_data.pkl', 'rb') as f:
25 |         data = pickle.load(f)
26 | 
27 |     if partition == 'train':
28 |         images = np.concatenate((data['train_images'],
29 |                                  data['valid_images']), axis=0)
30 |         labels = np.concatenate((data['train_labels'],
31 |                                  data['valid_labels']), axis=0)
32 |     elif partition == 'test':
33 |         images = data['test_images']
34 |         labels = data['test_labels']
35 |     else:
36 |         raise ValueError('The provided data partition name is not valid. '
37 |                          'Use "train" or "test".')
38 | 
39 |     return images, labels
40 | 
41 | 
--------------------------------------------------------------------------------
/semisup/tools/stl10.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2016 Google Inc.
3 | 
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | 
8 |     http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | 
16 | Definitions and utilities for the STL10 model.
17 | 
18 | This file contains functions that are needed for semisup training and
19 | evaluation on the STL10 dataset.
20 | They are used in train.py and eval.py.
21 | 
22 | """
23 | from __future__ import division
24 | from __future__ import print_function
25 | 
26 | import numpy as np
27 | from tensorflow.python.platform import tf_logging as logging
28 | from tensorflow.python.platform import gfile
29 | import data_dirs
30 | 
31 | DATADIR = data_dirs.stl10
32 | NUM_LABELS = 10
33 | IMAGE_SHAPE = [96, 96, 3]
34 | 
35 | 
36 | def get_data(name, max_num=20000):
37 |   """Utility for convenient data loading.
38 | 
39 |   Args:
40 |     name: Name of the split. Can be 'test', 'train' or 'unlabeled'.
41 |     max_num: maximum number of unlabeled samples.
42 |   Returns:
43 |     A tuple containing (images, labels) where labels=None for the unlabeled
44 |     split.
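|     For example (a sketch; STL-10 has 5000 labeled training images):
|       images, labels = get_data('train')
|       unsup_images, _ = get_data('unlabeled', max_num=10000)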
45 |   """
46 |   if name == 'train':
47 |     return extract_images(DATADIR + 'train_X.bin',
48 |                           IMAGE_SHAPE), extract_labels(DATADIR + 'train_y.bin')
49 |   elif name == 'test':
50 |     return extract_images(DATADIR + 'test_X.bin',
51 |                           IMAGE_SHAPE), extract_labels(DATADIR + 'test_y.bin')
52 | 
53 |   elif name == 'unlabeled':
54 |     res = extract_images(DATADIR + 'unlabeled_X.bin', IMAGE_SHAPE)
55 |     num_images = len(res)
56 |     if num_images > max_num:
57 |       rng = np.random.RandomState()
58 |       return res[rng.choice(len(res), max_num, False)], None
59 |     else:
60 |       return res, None
61 | 
62 | 
63 | def extract_images(filename, shape):
64 |   """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
65 |   logging.info('Extracting %s', filename)
66 |   with gfile.Open(filename, 'rb') as f:
67 |     imgs = np.frombuffer(f.read(), np.uint8)
68 |     imgs = imgs.reshape(-1, *shape[::-1])
69 |     imgs = np.transpose(imgs, [0, 3, 2, 1])
70 |     return imgs
71 | 
72 | 
73 | def extract_labels(filename):
74 |   """Extract the labels into a 1D uint8 numpy array [index]."""
75 |   logging.info('Extracting %s', filename)
76 |   with gfile.Open(filename, 'rb') as f:
77 |     lbls = np.frombuffer(f.read(), np.uint8)
78 |     lbls = lbls - 1  # STL-10 labels are not zero-indexed; avoid an in-place op on a read-only buffer.
79 |     return lbls
80 | 
81 | 
82 | def pick_fold(images, labels, fold=-1):
83 |   """Choose subset of labeled training data.
84 | 
85 |   According to the training protocol suggested by the creators of the dataset
86 |   https://cs.stanford.edu/~acoates/stl10/
87 | 
88 |   Args:
89 |     images: A 4D numpy array containing the images.
90 |     labels: A 1D numpy array containing the corresponding labels.
91 |     fold: The fold index in [0, 9]. Default: -1 = use all data.
92 |   Returns:
93 |     A tuple (images, labels)
94 |   """
95 |   assert -1 <= fold <= 9, 'Fold index needs to be in [0, 9] or -1 for all data.'
96 |   if fold > -1:
97 |     logging.info('Selecting fold %d', fold)
98 |     fold_indices = []
99 |     with gfile.Open(
100 |         DATADIR + 'fold_indices.txt',
101 |         'r') as f:
102 |       for line in f:
103 |         fold_indices.append(line.split(' ')[:-1])
104 |     fold_indices = np.array(fold_indices).astype(np.uint16)
105 | 
106 |     images = images[fold_indices[fold]]
107 |     labels = labels[fold_indices[fold]]
108 |   else:
109 |     logging.info('Using all folds.')
110 |   return images, labels
111 | 
112 | 
113 | # Dataset specific augmentation parameters.
114 | augmentation_params = dict()
115 | augmentation_params['max_crop_percentage'] = 0.2
116 | augmentation_params['brightness_max_delta'] = 1.3
117 | augmentation_params['saturation_lower'] = 0.5
118 | augmentation_params['saturation_upper'] = 1.2
119 | augmentation_params['hue_max_delta'] = 0.1
120 | augmentation_params['gray_prob'] = 0.5
121 | augmentation_params['max_rotate_angle'] = 10
--------------------------------------------------------------------------------
/semisup/tools/svhn.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2016 Google Inc.
3 | 
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | 
8 |     http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | 
16 | 
17 | Definitions and utilities for the SVHN model.
18 | 
19 | This file contains functions that are needed for semisup training and
20 | evaluation on the SVHN dataset.
21 | They are used in train.py and eval.py.
22 | """
23 | from __future__ import division
24 | from __future__ import print_function
25 | import numpy as np
26 | import scipy.io
27 | import data_dirs
28 | 
29 | 
30 | DATADIR = data_dirs.svhn
31 | NUM_LABELS = 10
32 | IMAGE_SHAPE = [32, 32, 3]
33 | 
34 | 
35 | def get_data(name):
36 |   """Get a split from the dataset.
37 | 
38 |   Args:
39 |     name: 'train' or 'test'
40 | 
41 |   Returns:
42 |     images, labels
43 |   """
44 | 
45 |   if name == 'train' or name == 'unlabeled':
46 |     data = scipy.io.loadmat(DATADIR + 'train_32x32.mat')
47 |   elif name == 'test':
48 |     data = scipy.io.loadmat(DATADIR + 'test_32x32.mat')
49 | 
50 |   images = np.rollaxis(data['X'], -1)
51 |   labels = data['y'].ravel() % 10  # SVHN uses label 10 for the digit '0'.
52 | 
53 |   if name == 'unlabeled':
54 |     return images, None
55 |   else:
56 |     return images, labels
--------------------------------------------------------------------------------
/semisup/tools/synth.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright 2016 Google Inc.
3 | 
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | 
8 |     http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | 
16 | Definitions and utilities for the synthetic digit (Ganin) model.
17 | 
18 | This file contains functions that are needed for semisup training and
19 | evaluation on the synthetic digits dataset.
20 | They are used in train.py and eval.py.
21 | 
22 | """
23 | from __future__ import division
24 | from __future__ import print_function
25 | import numpy as np
26 | import scipy.io
27 | import data_dirs
28 | 
29 | DATADIR = data_dirs.synth
30 | NUM_LABELS = 10
31 | IMAGE_SHAPE = [32, 32, 3]
32 | 
33 | 
34 | def get_data(name, num=70000):
35 |   """Get a split from the synth dataset.
36 | 
37 |   Args:
38 |     name: 'train' or 'test'
39 |     num: How many samples to read (randomly) from the data set
40 | 
41 |   Returns:
42 |     images, labels
43 |   """
44 | 
45 |   if name == 'train' or name == 'unlabeled':
46 |     fn = 'synth_train_32x32.mat'
47 |   elif name == 'test':
48 |     fn = 'synth_test_32x32.mat'
49 | 
50 |   data = scipy.io.loadmat(DATADIR + fn)
51 | 
52 |   images = np.rollaxis(data['X'], -1)
53 |   labels = data['y'].ravel() % 10
54 | 
55 |   num_samples = len(images)
56 |   indices = np.random.choice(num_samples, min(num, num_samples), False)
57 | 
58 |   images = images[indices]
59 |   labels = labels[indices]
60 | 
61 |   if name == 'unlabeled':
62 |     return images, None
63 |   else:
64 |     return images, labels
--------------------------------------------------------------------------------
/semisup/tools/synth_signs.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | 
4 | import pickle
5 | import numpy as np
6 | from PIL import Image
7 | import csv
8 | 
9 | import data_dirs
10 | 
11 | DATADIR = data_dirs.synth_signs
12 | 
13 | NUM_LABELS = 43
14 | IMAGE_SHAPE = [40, 40, 3]
15 | 
16 | 
17 | def get_data(name):
18 |     """Utility for convenient data loading."""
19 |     if name in ['train', 'unlabeled']:
20 |         return read_synth_signs_pickle(DATADIR + '/synth_signs_train.p')
21 |     elif name == 'test':
22 |         return read_synth_signs_pickle(DATADIR + '/synth_signs_test.p')
23 | 
24 | 
25 | def read_synth_signs_pickle(filename):
26 |     """
27 |     Extract images and labels from a pickle file.
28 |     :param filename: path to the pickle file.
29 |     :return: tuple (images, labels) of np.arrays.
30 |     """
31 |     with open(filename, mode='rb') as f:
32 |         data = pickle.load(f)
33 |     if not isinstance(data['labels'][0], int):
34 |         labels = [int(x) for x in data['labels']]
35 |     else:
36 |         labels = data['labels']
37 |     return np.array(data['images']), np.array(labels)
38 | 
39 | 
40 | def preprocess_and_convert_synth_signs_to_pickle(rootpath):
41 |     # Take a randomly shuffled train/test split, but always the same one.
42 |     np.random.seed(314)
43 |     train_fraction = 0.9
44 | 
45 |     images = []  # images
46 |     labels = []  # corresponding labels
47 | 
48 |     with open(rootpath + 'train_labelling.txt', 'rt') as f:
49 |         reader = csv.reader(f, delimiter=' ')
50 |         for row in reader:
51 |             filepath = rootpath + row[0]
52 |             img = Image.open(filepath)
53 |             img = img.resize(IMAGE_SHAPE[:-1], Image.BILINEAR)
54 |             images.append(np.asarray(img))
55 |             labels.append(int(row[1]))
56 | 
57 |     rand_idx = np.arange(len(images))
58 |     np.random.shuffle(rand_idx)
59 |     split = int(len(images) * train_fraction)
60 |     images = np.asarray(images)
61 |     images = images[rand_idx]
62 |     train_img = images[:split]
63 |     test_img = images[split:]
64 | 
65 |     labels = np.asarray(labels)
66 |     labels = labels[rand_idx]
67 |     train_labels = labels[:split]
68 |     test_labels = labels[split:]
69 | 
70 |     pickle.dump({'images': train_img, 'labels': train_labels},
71 |                 open('synth_signs_train.p', "wb"))
72 |     pickle.dump({'images': test_img, 'labels': test_labels},
73 |                 open('synth_signs_test.p', "wb"))
74 | 
--------------------------------------------------------------------------------
/semisup/tools/usps.py:
--------------------------------------------------------------------------------
1 | """
2 | Download USPS dataset from
3 | http://statweb.stanford.edu/~tibs/ElemStatLearn/data.html
4 | 
5 | Explicit links:
6 | Training: http://statweb.stanford.edu/~tibs/ElemStatLearn/datasets/zip.train.gz
7 | Test: http://statweb.stanford.edu/~tibs/ElemStatLearn/datasets/zip.test.gz
8 
| """ 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import gzip 13 | import numpy as np 14 | import data_dirs 15 | 16 | DATADIR = data_dirs.usps 17 | 18 | NUM_LABELS = 10 19 | IMAGE_SHAPE = [16, 16, 1] 20 | 21 | 22 | def get_data(name): 23 | """Utility for convenient data loading.""" 24 | if name in ['train', 'unlabeled']: 25 | return extract_images_labels(DATADIR + '/zip.train.gz') 26 | elif name == 'test': 27 | return extract_images_labels(DATADIR + '/zip.test.gz') 28 | 29 | 30 | def extract_images_labels(filename): 31 | print('Extracting', filename) 32 | with gzip.open(filename, 'rb') as f: 33 | raw_data = f.read().split() 34 | data = np.asarray([raw_data[start:start + 257] 35 | for start in range(0, len(raw_data), 257)], 36 | dtype=np.float32) 37 | images_vec = data[:, 1:] 38 | images = np.expand_dims( 39 | np.reshape(images_vec, (images_vec.shape[0], 16, 16)), axis=3) 40 | labels = data[:, 0].astype(int) 41 | return images, labels 42 | -------------------------------------------------------------------------------- /semisup/train.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | """ 3 | Copyright 2016 Google Inc. 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | 17 | Association-based semi-supervised training module. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import sys 25 | from functools import partial 26 | from importlib import import_module 27 | 28 | import numpy as np 29 | import semisup 30 | import tensorflow as tf 31 | import tensorflow.contrib.slim as slim 32 | from tensorflow.python.platform import app 33 | from tensorflow.python.platform import flags 34 | from tensorflow.python.training import saver as tf_saver 35 | 36 | FLAGS = flags.FLAGS 37 | 38 | flags.DEFINE_string('dataset', 'svhn', 'Which dataset to work on.') 39 | 40 | flags.DEFINE_string('target_dataset', None, 41 | 'If specified, perform domain adaptation using dataset as ' 42 | 'source domain and target_dataset as target domain.') 43 | 44 | flags.DEFINE_string('target_dataset_split', 'unlabeled', 45 | 'Which split of the target dataset to use for domain ' 46 | 'adaptation.') 47 | 48 | flags.DEFINE_string('architecture', 'svhn_model', 'Which network architecture ' 49 | 'from architectures.py to use.') 50 | 51 | flags.DEFINE_integer('sup_per_class', 100, 52 | 'Number of labeled samples used per class in total.' 53 | ' -1 = all') 54 | 55 | flags.DEFINE_integer('unsup_samples', -1, 56 | 'Number of unlabeled samples used in total. 
-1 = all.')
57 | 
58 | flags.DEFINE_integer('sup_seed', -1,
59 |                      'Integer random seed used for labeled set selection.')
60 | 
61 | flags.DEFINE_integer('sup_per_batch', 10,
62 |                      'Number of labeled samples per class per batch.')
63 | 
64 | flags.DEFINE_integer('unsup_batch_size', 100,
65 |                      'Number of unlabeled samples per batch.')
66 | 
67 | flags.DEFINE_integer('emb_size', 128,
68 |                      'Size of the embeddings to learn.')
69 | 
70 | flags.DEFINE_float('learning_rate', 1e-4, 'Initial learning rate.')
71 | 
72 | flags.DEFINE_float('minimum_learning_rate', 1e-6,
73 |                    'Lower bound for learning rate.')
74 | 
75 | flags.DEFINE_float('decay_factor', 0.33, 'Learning rate decay factor.')
76 | 
77 | flags.DEFINE_float('decay_steps', 60000,
78 |                    'Learning rate decay interval in steps.')
79 | 
80 | flags.DEFINE_float('visit_weight', 0.0, 'Weight for visit loss.')
81 | 
82 | flags.DEFINE_string('visit_weight_envelope', None,
83 |                     'Increase visit weight with an envelope: [None, sigmoid, linear]')
84 | 
85 | flags.DEFINE_integer('visit_weight_envelope_steps', -1,
86 |                      'Number of steps (after delay) at which envelope '
87 |                      'saturates. -1 = follow walker loss env.')
88 | 
89 | flags.DEFINE_integer('visit_weight_envelope_delay', -1,
90 |                      'Number of steps at which envelope starts. -1 = follow '
91 |                      'walker loss env.')
92 | 
93 | flags.DEFINE_float('walker_weight', 1.0, 'Weight for walker loss.')
94 | 
95 | flags.DEFINE_string('walker_weight_envelope', None,
96 |                     'Increase walker weight with an envelope: [None, sigmoid, linear]')
97 | 
98 | flags.DEFINE_integer('walker_weight_envelope_steps', 100,
99 |                      'Number of steps (after delay) at which envelope '
100 |                      'saturates.')
101 | 
102 | flags.DEFINE_integer('walker_weight_envelope_delay', 3000,
103 |                      'Number of steps at which envelope starts.')
104 | 
105 | flags.DEFINE_float('logit_weight', 1.0, 'Weight for logit loss.')
106 | 
107 | flags.DEFINE_integer('max_steps', 100000, 'Number of training steps.')
108 | 
109 | flags.DEFINE_bool('augmentation', False,
110 |                   'Apply data augmentation during training.')
111 | 
112 | flags.DEFINE_integer('new_size', 0,
113 |                      'If > 0, resize image to this width/height.')
114 | 
115 | flags.DEFINE_integer('virtual_embeddings', 0,
116 |                      'How many virtual embeddings to add.')
117 | 
118 | flags.DEFINE_string('logdir', '/tmp/semisup', 'Training log path.')
119 | 
120 | flags.DEFINE_integer('save_summaries_secs', 150,
121 |                      'How often should summaries be saved (in seconds).')
122 | 
123 | flags.DEFINE_integer('save_interval_secs', 300,
124 |                      'How often should checkpoints be saved (in seconds).')
125 | 
126 | flags.DEFINE_integer('log_every_n_steps', 100,
127 |                      'Logging interval for slim training loop.')
128 | 
129 | flags.DEFINE_integer('max_checkpoints', 5,
130 |                      'Maximum number of recent checkpoints to keep.')
131 | 
132 | flags.DEFINE_float('keep_checkpoint_every_n_hours', 5.0,
133 |                    'Keep one checkpoint permanently every this many hours.')
134 | 
135 | flags.DEFINE_float('batch_norm_decay', 0.99,
136 |                    'Batch norm decay factor '
137 |                    '(only used for STL-10 at the moment).')
138 | 
139 | flags.DEFINE_integer('remove_classes', 0,
140 |                      'Remove this number of classes from the labeled set, '
141 |                      'starting with highest label number.')
142 | 
143 | flags.DEFINE_string('master', '',
144 |                     'BNS name of the TensorFlow master to use.')
145 | 
146 | flags.DEFINE_integer('ps_tasks', 0,
147 |                      'The number of parameter servers. If the value is 0, '
148 |                      'then the parameters '
149 |                      'are handled locally by the worker.')
150 | 
151 | flags.DEFINE_integer('task', 0,
152 |                      'The Task ID. This value is used when training with '
153 |                      'multiple workers to identify each worker.')
154 | 
155 | 
156 | def logistic_growth(current_step, target, steps):
157 |   """Logistic envelope from zero to target value.
158 | 
159 |   This can be used to slowly increase parameters or weights over the course of
160 |   training.
161 | 
162 |   Args:
163 |     current_step: Current step (e.g. tf.train.get_global_step()).
164 |     target: Target value > 0.
165 |     steps: Number of steps to saturation; target/2 is reached after steps/2.
166 |   Returns:
167 |     TF tensor holding the target value modulated by a logistic function.
168 | 
169 |   """
170 |   assert target > 0., 'Target value must be positive.'
171 |   alpha = 5. / steps
172 |   current_step = tf.cast(current_step, tf.float32)
173 |   steps = tf.cast(steps, tf.float32)
174 |   return target * (tf.tanh(alpha * (current_step - steps / 2.)) + 1.) / 2.
175 | 
176 | 
177 | def apply_envelope(type, step, final_weight, growing_steps, delay):
178 |   assert growing_steps > 0, "Growing steps for envelope must be > 0."
179 |   step = tf.cast(step - delay, tf.float32)
180 |   final_step = growing_steps + delay
181 | 
182 |   if type is None:
183 |     value = final_weight
184 | 
185 |   elif type in ['sigmoid', 'sigmoidal', 'logistic', 'log']:
186 |     value = logistic_growth(step, final_weight, final_step)
187 | 
188 |   elif type in ['linear', 'lin']:
189 |     # growing_steps > 0 is guaranteed by the assert above.
190 |     m = float(final_weight) / growing_steps
191 |     value = m * step
192 | 
193 |   else:
194 |     raise ValueError('Invalid envelope type: ' + str(type))
195 | 
196 |   return tf.clip_by_value(value, 0., final_weight)
197 | 
198 | 
199 | def main(argv):
200 |   del argv
201 | 
202 |   # Load data.
203 |   dataset_tools = import_module('tools.' + FLAGS.dataset)
204 |   train_images, train_labels = dataset_tools.get_data('train')
205 |   if FLAGS.target_dataset is not None:
206 |     target_dataset_tools = import_module('tools.' + FLAGS.target_dataset)
207 |     train_images_unlabeled, _ = target_dataset_tools.get_data(
208 |         FLAGS.target_dataset_split)
209 |   else:
210 |     train_images_unlabeled, _ = dataset_tools.get_data('unlabeled')
211 | 
212 |   architecture = getattr(semisup.architectures, FLAGS.architecture)
213 | 
214 |   num_labels = dataset_tools.NUM_LABELS
215 |   image_shape = dataset_tools.IMAGE_SHAPE
216 | 
217 |   # Sample labeled training subset.
218 |   seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
219 |   sup_by_label = semisup.sample_by_label(train_images, train_labels,
220 |                                          FLAGS.sup_per_class, num_labels,
221 |                                          seed)
222 | 
223 |   # Sample unlabeled training subset.
224 |   if FLAGS.unsup_samples > -1:
225 |     num_unlabeled = len(train_images_unlabeled)
226 |     assert FLAGS.unsup_samples <= num_unlabeled, (
227 |         'Chose more unlabeled samples ({})'
228 |         ' than there are in the '
229 |         'unlabeled set ({}).'.format(FLAGS.unsup_samples, num_unlabeled))
230 | 
231 |     rng = np.random.RandomState(seed=seed)
232 |     train_images_unlabeled = train_images_unlabeled[rng.choice(
233 |         num_unlabeled, FLAGS.unsup_samples, False)]
234 | 
235 |   graph = tf.Graph()
236 |   with graph.as_default():
237 |     with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks,
238 |                                                   merge_devices=True)):
239 | 
240 |       # Set up inputs.
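|       # Note on batch layout (an assumption based on the tf.slice call for
|       # --remove_classes below): the labeled batch appears to be grouped by
|       # class, i.e. sup_per_batch samples of class 0 first, then class 1, and
|       # so on, for a total of sup_per_batch * num_labels labeled images per
|       # step, with removed classes sitting at the end.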
241 |       t_unsup_images = semisup.create_input(train_images_unlabeled, None,
242 |                                             FLAGS.unsup_batch_size)
243 |       t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
244 |           sup_by_label, FLAGS.sup_per_batch)
245 | 
246 |       if FLAGS.remove_classes:
247 |         t_sup_images = tf.slice(
248 |             t_sup_images, [0, 0, 0, 0],
249 |             [FLAGS.sup_per_batch * (
250 |                 num_labels - FLAGS.remove_classes)] +
251 |             image_shape)
252 | 
253 |       # Resize if necessary.
254 |       if FLAGS.new_size > 0:
255 |         new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
256 |       else:
257 |         new_shape = None
258 | 
259 |       # Apply augmentation.
260 |       if FLAGS.augmentation:
261 |         # TODO(haeusser) generalize augmentation
262 |         def _random_invert(inputs, _):
263 |           randu = tf.random_uniform(
264 |               shape=[FLAGS.sup_per_batch * num_labels], minval=0.,
265 |               maxval=1.,
266 |               dtype=tf.float32)
267 |           randu = tf.cast(tf.less(randu, 0.5), tf.float32)
268 |           randu = tf.expand_dims(randu, 1)
269 |           randu = tf.expand_dims(randu, 1)
270 |           randu = tf.expand_dims(randu, 1)
271 |           inputs = tf.cast(inputs, tf.float32)
272 |           return tf.abs(inputs - 255 * randu)
273 | 
274 |         augmentation_function = _random_invert
275 |       else:
276 |         augmentation_function = None
277 | 
278 |       # Create function that defines the network.
279 |       model_function = partial(
280 |           architecture,
281 |           new_shape=new_shape,
282 |           img_shape=image_shape,
283 |           augmentation_function=augmentation_function,
284 |           batch_norm_decay=FLAGS.batch_norm_decay,
285 |           emb_size=FLAGS.emb_size)
286 | 
287 |       # Set up semisup model.
288 |       model = semisup.SemisupModel(model_function, num_labels,
289 |                                    image_shape)
290 | 
291 |       # Compute embeddings and logits.
292 |       t_sup_emb = model.image_to_embedding(t_sup_images)
293 |       t_unsup_emb = model.image_to_embedding(t_unsup_images)
294 | 
295 |       # Add virtual embeddings.
296 |       if FLAGS.virtual_embeddings:
297 |         t_sup_emb = tf.concat([
298 |             t_sup_emb, semisup.create_virt_emb(FLAGS.virtual_embeddings,
299 |                                                FLAGS.emb_size)
300 |         ], 0)
301 | 
302 |         if not FLAGS.remove_classes:
303 |           # Need to add additional labels for the virtual embeddings.
304 |           t_sup_labels = tf.concat([
305 |               t_sup_labels,
306 |               (num_labels + tf.range(1, FLAGS.virtual_embeddings + 1,
307 |                                      dtype=tf.int64))
308 |               * tf.ones([FLAGS.virtual_embeddings], tf.int64)
309 |           ], 0)
310 | 
311 |       t_sup_logit = model.embedding_to_logit(t_sup_emb)
312 | 
313 |       # Add losses.
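|       # Loss weights need not be constant; apply_envelope (defined above) can
|       # ramp them up during training. Worked example for the linear envelope
|       # with the walker defaults (final_weight=1.0, growing_steps=100,
|       # delay=3000): the weight is 0 until step 3000, then grows as
|       # (step - 3000) / 100, and is clipped to 1.0 from step 3100 onwards.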
314 | visit_weight_envelope_steps = ( 315 | FLAGS.walker_weight_envelope_steps 316 | if FLAGS.visit_weight_envelope_steps == -1 317 | else FLAGS.visit_weight_envelope_steps) 318 | visit_weight_envelope_delay = ( 319 | FLAGS.walker_weight_envelope_delay 320 | if FLAGS.visit_weight_envelope_delay == -1 321 | else FLAGS.visit_weight_envelope_delay) 322 | visit_weight = apply_envelope( 323 | type=FLAGS.visit_weight_envelope, 324 | step=model.step, 325 | final_weight=FLAGS.visit_weight, 326 | growing_steps=visit_weight_envelope_steps, 327 | delay=visit_weight_envelope_delay) 328 | walker_weight = apply_envelope( 329 | type=FLAGS.walker_weight_envelope, 330 | step=model.step, 331 | final_weight=FLAGS.walker_weight, 332 | growing_steps=FLAGS.walker_weight_envelope_steps, # pylint:disable=line-too-long 333 | delay=FLAGS.walker_weight_envelope_delay) 334 | tf.summary.scalar('Weights_Visit', visit_weight) 335 | tf.summary.scalar('Weights_Walker', walker_weight) 336 | 337 | if FLAGS.unsup_samples != 0: 338 | model.add_semisup_loss(t_sup_emb, 339 | t_unsup_emb, 340 | t_sup_labels, 341 | visit_weight=visit_weight, 342 | walker_weight=walker_weight) 343 | 344 | model.add_logit_loss(t_sup_logit, 345 | t_sup_labels, 346 | weight=FLAGS.logit_weight) 347 | 348 | # Set up learning rate 349 | t_learning_rate = tf.maximum( 350 | tf.train.exponential_decay( 351 | FLAGS.learning_rate, 352 | model.step, 353 | FLAGS.decay_steps, 354 | FLAGS.decay_factor, 355 | staircase=True), 356 | FLAGS.minimum_learning_rate) 357 | 358 | # Create training operation and start the actual training loop. 359 | train_op = model.create_train_op(t_learning_rate) 360 | 361 | config = tf.ConfigProto() 362 | config.gpu_options.allow_growth = True 363 | # config.log_device_placement = True 364 | 365 | saver = tf_saver.Saver(max_to_keep=FLAGS.max_checkpoints, 366 | keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours) # pylint:disable=line-too-long 367 | 368 | slim.learning.train( 369 | train_op, 370 | logdir=FLAGS.logdir + '/train', 371 | save_summaries_secs=FLAGS.save_summaries_secs, 372 | save_interval_secs=FLAGS.save_interval_secs, 373 | master=FLAGS.master, 374 | is_chief=(FLAGS.task == 0), 375 | startup_delay_steps=(FLAGS.task * 20), 376 | log_every_n_steps=FLAGS.log_every_n_steps, 377 | session_config=config, 378 | trace_every_n_steps=1000, 379 | saver=saver, 380 | number_of_steps=FLAGS.max_steps, 381 | ) 382 | 383 | 384 | if __name__ == '__main__': 385 | tf.logging.set_verbosity(tf.logging.INFO) 386 | app.run() 387 | --------------------------------------------------------------------------------