├── .gitignore ├── AUTHORS ├── LICENSE ├── README.md ├── act.py ├── act_test.py ├── cifar_data_provider.py ├── cifar_data_provider_test.py ├── cifar_main.py ├── cifar_model.py ├── cifar_model_test.py ├── draw_ponder_maps.py ├── external ├── __init__.py ├── dataset_utils.py ├── datasets_cifar10.py ├── datasets_imagenet.py ├── download_and_convert_cifar10.py └── inception_preprocessing.py ├── fake_cifar10.py ├── fake_imagenet.py ├── flopsometer.py ├── flopsometer_test.py ├── imagenet_data_provider.py ├── imagenet_data_provider_test.py ├── imagenet_eval.py ├── imagenet_export.py ├── imagenet_model.py ├── imagenet_model_test.py ├── imagenet_ponder_map.py ├── imagenet_train.py ├── pics ├── 20.92_93_im.jpg ├── 20.92_93_ponder.png ├── 22.28_95_im.jpg ├── 22.28_95_ponder.png ├── 26.75_36_im.jpg ├── 26.75_36_ponder.png ├── cat.jpg ├── cat_colorbar.jpg ├── cat_ponder.jpg ├── export-image-442041-ponder.jpg ├── export-image-442041.jpg ├── gasworks.jpg ├── gasworks_colorbar.jpg └── gasworks_ponder.jpg ├── requirements-gpu.txt ├── requirements.txt ├── resnet_act.py ├── squeeze_model.py ├── summary_utils.py ├── summary_utils_test.py ├── testdata ├── cifar10 │ ├── cifar10_test.tfrecord │ └── cifar10_train.tfrecord └── imagenet │ ├── train-00000-of-00001 │ └── validation-00000-of-00001 ├── training_utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the official list of authors for copyright purposes. 2 | # This file is distinct from the CONTRIBUTORS files. 3 | # See the latter for an explanation. 4 | 5 | # Names should be added to this file as: 6 | # Name or Organization 7 | # The email address is not required for organizations. 8 | 9 | Google Inc. 10 | Michael Figurnov 11 | Maxwell D Collins 12 | Yukun Zhu 13 | Li Zhang 14 | Jonathan Huang 15 | 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2016 Google Inc. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spatially Adaptive Computation Time for Residual Networks
2 |
3 | This code implements a deep learning architecture based on Residual Networks that dynamically adjusts the number of executed layers for different regions of the image.
4 | The architecture is end-to-end trainable, deterministic and problem-agnostic.
5 | The included code applies this to the CIFAR-10 and ImageNet image classification problems.
6 | It is implemented using TensorFlow and TF-Slim.
7 |
8 | Paper describing the project:
9 |
10 | Michael Figurnov, Maxwell D. Collins, Yukun Zhu, Li Zhang, Jonathan Huang, Dmitry Vetrov, Ruslan Salakhutdinov. Spatially Adaptive Computation Time for Residual Networks. *CVPR 2017* [[arxiv]](https://arxiv.org/abs/1612.02297).
11 |
12 | Image (with detections) | Ponder cost map
13 | :-------------------------------:|:--------------------------------------:
14 | ![](pics/export-image-442041.jpg)|![](pics/export-image-442041-ponder.jpg)
15 |
16 | ## Setup
17 |
18 | Install the prerequisites:
19 |
20 | ``` bash
21 | pip install -r requirements.txt  # CPU
22 | pip install -r requirements-gpu.txt  # GPU
23 | ```
24 |
25 | Prerequisite packages:
26 | - Python 2.x/3.x (mostly tested with Python 2.7)
27 | - TensorFlow 1.0
28 | - NumPy
29 | - (Optional) nose
30 | - (Optional) h5py
31 | - (Optional) matplotlib
32 |
33 | Run the tests (this takes a couple of minutes):
34 |
35 | ``` bash
36 | nosetests --logging-level=WARNING
37 | ```
38 |
39 | ## CIFAR-10
40 |
41 | Download and convert the CIFAR-10 dataset:
42 |
43 | ``` bash
44 | PYTHONPATH=external python external/download_and_convert_cifar10.py --dataset_dir="${HOME}/tensorflow/data/cifar10"
45 | ```
46 |
47 | Let's train and continuously evaluate a CIFAR-10 Adaptive Computation Time model with five residual units per block (ResNet-32):
48 |
49 | ``` bash
50 | export ACT_LOGDIR='/tmp/cifar10_resnet_5_act_1e-2'
51 | python cifar_main.py --model_type=act --model=5 --tau=0.01 --train_log_dir="${ACT_LOGDIR}/train" --save_summaries_secs=300 &
52 | python cifar_main.py --model_type=act --model=5 --tau=0.01 --checkpoint_dir="${ACT_LOGDIR}/train" --eval_dir="${ACT_LOGDIR}/eval" --mode=eval
53 | ```
54 |
55 | Or, for _spatially_ adaptive computation time (SACT):
56 |
57 | ``` bash
58 | export SACT_LOGDIR='/tmp/cifar10_resnet_5_sact_1e-2'
59 | python cifar_main.py --model_type=sact --model=5 --tau=0.01 --train_log_dir="${SACT_LOGDIR}/train" --save_summaries_secs=300 &
60 | python cifar_main.py --model_type=sact --model=5 --tau=0.01 --checkpoint_dir="${SACT_LOGDIR}/train" --eval_dir="${SACT_LOGDIR}/eval" --mode=eval
61 | ```
62 |
63 | To download and evaluate a [pretrained ResNet-32 SACT model](https://s3.us-east-2.amazonaws.com/sact-models/cifar10_resnet_5_sact_1e-2.tar.gz) (1.8 MB file):
64 |
65 | ``` bash
66 | mkdir -p models && curl https://s3.us-east-2.amazonaws.com/sact-models/cifar10_resnet_5_sact_1e-2.tar.gz | tar xv -C models
67 | python cifar_main.py --model_type=sact --model=5 --tau=0.01 --checkpoint_dir='models/cifar10_resnet_5_sact_1e-2' --mode=eval --eval_dir='/tmp' --evaluate_once
68 | ```
69 |
70 | This model is expected to achieve an accuracy of 91.82%, with output like the following:
71 |
72 | ```
73 | eval/Accuracy[0.9182]
74 | eval/Mean Loss[0.59591407]
75 | Total Flops/mean[82393168]
76 | Total Flops/std[7588926]
77 | ...
78 | ```
79 |
80 | ## ImageNet
81 |
82 | Follow the [instructions](https://github.com/tensorflow/models/tree/master/research/inception#getting-started) to prepare the ImageNet dataset in TF-Slim format.
83 | The default directory for the dataset is `~/tensorflow/imagenet`.
84 | You can change it with the `--dataset_dir` flag.
85 |
86 | We initialized all ACT/SACT models with a [pretrained ResNet-101 model](https://s3.us-east-2.amazonaws.com/sact-models/imagenet_101.tar.gz) (159 MB file).
87 |
88 | Download the [pretrained ResNet-101 SACT model](https://s3.us-east-2.amazonaws.com/sact-models/imagenet_101_sact_5e-3.tar.gz), trained with tau=0.005 (160 MB file):
89 | ``` bash
90 | mkdir -p models && curl https://s3.us-east-2.amazonaws.com/sact-models/imagenet_101_sact_5e-3.tar.gz | tar xv -C models
91 | ```
92 |
93 | Evaluate the pretrained model:
94 | ``` bash
95 | python imagenet_eval.py --model_type=sact --model=101 --tau=0.005 --checkpoint_dir=models/imagenet_101_sact_5e-3 --eval_dir=/tmp --evaluate_once
96 | ```
97 |
98 | Expected output:
99 | ```
100 | eval/Accuracy[0.75609803]
101 | eval/Recall@5[0.9274632117722329]
102 | Total Flops/mean[1.1100941e+10]
103 | Total Flops/std[4.5691142e+08]
104 | ...
105 | ```
106 |
107 | Note that evaluation on the full validation dataset will take a long time when running on CPU only.
108 | Add the arguments `--num_examples=10 --batch_size=10` for a quicker test.
109 |
110 | Draw some images from the ImageNet validation set and the corresponding ponder cost maps:
111 |
112 | ``` bash
113 | python imagenet_export.py --model_type=sact --model=101 --tau=0.005 --checkpoint_dir=models/imagenet_101_sact_5e-3 --export_path=/tmp/maps.h5 --batch_size=1 --num_examples=200
114 |
115 | mkdir /tmp/maps
116 | python draw_ponder_maps.py --input_file=/tmp/maps.h5 --output_dir=/tmp/maps
117 | ```
118 |
119 | Example visualizations. See Figure 9 of the paper for more.
120 |
121 | Image | Ponder cost map
122 | :-------------------------:|:----------------------------:
123 | ![](pics/20.92_93_im.jpg)  | ![](pics/20.92_93_ponder.png)
124 | ![](pics/22.28_95_im.jpg)  | ![](pics/22.28_95_ponder.png)
125 | ![](pics/26.75_36_im.jpg)  | ![](pics/26.75_36_ponder.png)
126 |
127 | Apply the pretrained model to your own JPEG images.
128 | For best results, first resize them to somewhere between 320x240 and 640x480.
129 |
130 | ``` bash
131 | python2 imagenet_ponder_map.py --model=101 --checkpoint_dir=models/imagenet_101_sact_5e-3 --images_pattern=pics/gasworks.jpg --output_dir output/
132 | ```
133 |
134 | Image | Ponder cost map | Colorbar
135 | :--------------------:|:------------------------------:|:-------------------------------:
136 | ![](pics/gasworks.jpg)| ![](pics/gasworks_ponder.jpg)  | ![](pics/gasworks_colorbar.jpg)
137 | ![](pics/cat.jpg)     | ![](pics/cat_ponder.jpg)       | ![](pics/cat_colorbar.jpg)
138 |
139 | Note that an ImageNet-pretrained model tends to ignore people: there is no "person" class in ImageNet!
140 |
141 | ## Disclaimer
142 |
143 | This is not an official Google product.
--------------------------------------------------------------------------------
/act.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Functions for adaptive computation time."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow as tf
23 |
24 |
25 | def adaptive_computation_time(halting_proba, eps=1e-2):
26 |   """Gets the cost, number of steps and halting distribution for adaptive computation time.
27 |
28 |   See Alex Graves "Adaptive Computation Time for Recurrent Neural Networks"
29 |   https://arxiv.org/pdf/1603.08983v4.pdf
30 |
31 |   Also see notes by Hugo Larochelle:
32 |   https://www.evernote.com/shard/s189/sh/fd165646-b630-48b7-844c-86ad2f07fcda/c9ab960af967ef847097f21d94b0bff7
33 |
34 |   This module makes several assumptions:
35 |   1) The maximum number of units is `max_units`.
36 |   2) We run all the units for each object during training and inference.
37 |      The unused units are simply "masked".
38 |
39 |   Args:
40 |     halting_proba: A 2-D `Tensor` of type `float32`. Probabilities
41 |       of halting the computation at a given unit for the object.
42 |       Shape is `[batch, max_units - 1]`.
43 |       The values need to be in the range [0, 1].
44 |     eps: A `float` in the range [0, 1]. Small number to ensure that
45 |       the computation can halt after the first unit.
46 |
47 |   Returns:
48 |     ponder_cost: A 1-D `Tensor` of type `float32`.
49 |       A differentiable upper bound on the number of units.
50 |     num_units: A 1-D `Tensor` of type `int32`.
51 |       The number of units that were actually executed.
52 |       num_units < ponder_cost.
53 |     halting_distribution: A 2-D `Tensor` of type `float32`.
54 |       Shape is `[batch, max_units]`. Halting probability distribution.
55 |       halting_distribution[i, j] is the probability of the computation for the
56 |       i-th object to halt at the j-th unit. The sum of every row should be close to one.
57 |   """
58 |   sh = halting_proba.get_shape().as_list()
59 |   batch = sh[0]
60 |   max_units = sh[1] + 1
61 |
62 |   zero_col = tf.zeros((batch, 1))
63 |
64 |   halting_padded = tf.concat([halting_proba, zero_col], 1)
65 |
66 |   halting_cumsum = tf.cumsum(halting_proba, axis=1)
67 |   halting_cumsum_padded = tf.concat([zero_col, halting_cumsum], 1)
68 |
69 |   # Does computation halt at this unit?
70 |   halt_flag = (halting_cumsum >= 1 - eps)
71 |   # Always halt at the final unit.
72 |   halt_flag_final = tf.concat([halt_flag, tf.fill([batch, 1], True)], 1)
73 |
74 |   # Halting iteration (zero-based), eqn. (7).
75 |   # Add a decaying value that ensures that the first true value is selected.
76 |   # The decay value is always less than one.
77 |   decay = 1. / (2. + tf.to_float(tf.range(max_units)))
78 |   halt_flag_final_with_decay = tf.to_float(halt_flag_final) + decay[None, :]
79 |   N = tf.to_int32(tf.argmax(halt_flag_final_with_decay, dimension=1))
80 |
81 |   N = tf.stop_gradient(N)
82 |
83 |   # Fancy indexing to obtain the value of the remainder. Eqn. (8).
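  # Worked example with illustrative numbers (chosen for exposition, not
  # taken from the paper): for max_units = 3 and halting_proba = [[0.3, 0.9]],
  # we get halting_cumsum_padded = [[0.0, 0.3, 1.2]] and N = [1], so the
  # gathered cumulative sum is 0.3, remainder = 1 - 0.3 = 0.7,
  # num_units = N + 1 = 2, and ponder_cost = 2 + 0.7 = 2.7.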
84 |   N_indices = tf.range(batch) * max_units + N
85 |   remainder = 1 - tf.gather(tf.reshape(halting_cumsum_padded, [-1]), N_indices)
86 |
87 |   # Switch to one-based indexing here for num_units.
88 |   num_units = N + 1
89 |   ponder_cost = tf.to_float(num_units) + remainder
90 |
91 |   unit_index = tf.range(max_units)[None, :]
92 |   # Calculate the halting distribution, eqn. (6).
93 |   # Fill the first N steps with the halting probabilities.
94 |   # Next values are zero.
95 |   p = tf.where(tf.less(unit_index, N[:, None]),
96 |                halting_padded,
97 |                tf.zeros((batch, max_units)))
98 |   # Fill the (N+1)-st step with the remainder value.
99 |   p = tf.where(tf.equal(unit_index, N[:, None]),
100 |                tf.tile(remainder[:, None], tf.stack([1, max_units])),
101 |                p)
102 |   halting_distribution = p
103 |
104 |   return (ponder_cost, num_units, halting_distribution)
105 |
106 |
107 | def run_units(inputs, unit, max_units, scope, reuse=False):
108 |   """Helper function for running the units of the network."""
109 |   states = []
110 |   halting_probas = []
111 |   all_flops = []
112 |   with tf.variable_scope(scope, reuse=reuse):
113 |     state = inputs
114 |     for unit_idx in range(max_units):
115 |       state, halting_proba, flops = unit(state, unit_idx)
116 |       states.append(state)
117 |       halting_probas.append(halting_proba)
118 |       all_flops.append(flops)
119 |   return states, halting_probas, all_flops
120 |
121 |
122 | def adaptive_computation_time_wrapper(inputs, unit, max_units,
123 |                                       eps=1e-2, scope='act'):
124 |   """A wrapper of `adaptive_computation_time`.
125 |
126 |   Wraps `adaptive_computation_time` with an interface compatible with
127 |   `adaptive_computation_early_stopping`. It does the same thing as
128 |   `adaptive_computation_early_stopping`, but also works in cases where tf.cond
129 |   fails.
130 |   """
131 |   states, halting_probas, all_flops = run_units(inputs, unit,
132 |                                                 max_units, scope)
133 |
134 |   (ponder_cost, num_units, halting_distribution) = \
135 |       adaptive_computation_time(tf.concat(halting_probas[:-1], 1), eps=eps)
136 |
137 |   if states[0].get_shape().is_fully_defined():
138 |     sh = states[0].get_shape().as_list()
139 |   else:
140 |     sh = tf.shape(states[0])
141 |   batch = sh[0]
142 |   h = tf.reshape(halting_distribution, [batch, 1, max_units])
143 |   s = tf.reshape(tf.stack(states, axis=1), [batch, max_units, -1])
144 |   outputs = tf.matmul(h, s)
145 |   outputs = tf.reshape(outputs, sh)
146 |
147 |   flops_per_iteration = [
148 |       f * tf.to_int64(num_units > i) for (i, f) in enumerate(all_flops)
149 |   ]
150 |   flops = tf.add_n(flops_per_iteration)
151 |
152 |   return (ponder_cost, num_units, flops, halting_distribution, outputs)
153 |
154 |
155 | def adaptive_computation_early_stopping(inputs, unit, max_units,
156 |                                         eps=1e-2, scope='act'):
157 |   """Builds an adaptive computation module with early stopping of computation.
158 |
159 |   `adaptive_computation_time` requires all units to always be
160 |   computed. This function stops the computation as soon as all objects in the
161 |   batch halt. However, if any object still needs calculation, the
162 |   unit is executed for all objects.
163 |
164 |   See the `adaptive_computation_time` description for more information.
165 |
166 |   Args:
167 |     inputs: Input state at the first unit. Can have a different shape from
168 |       the state and output. Should have a fully defined shape.
169 |     unit: A function which is called as follows:
170 |       `new_state, halting_proba, flops = unit(old_state, unit_idx)`
171 |       If `unit_idx == 0`, then `old_state` is `inputs`.
172 |       Flops should be a 1-D `Tensor` of length batch_size of type `int64`.
173 |       It can perform a different computation depending on `unit_idx`.
174 |       The function should not have any Python side-effects (due to the
175 |       `tf.cond` implementation).
176 |
177 |       The function is called two times for each `unit_idx`:
178 |       1) Outside `tf.cond` to create the necessary variables with reuse=False.
179 |       2) Inside `tf.cond` with reuse=True.
180 |       For this reason, all variables should have static names.
181 |       Good: `w = tf.get_variable('weights', [5, 3])`
182 |       Bad: `w = tf.Variable(tf.zeros([5, 3]))  # The name is auto-generated`
183 |     max_units: Maximum number of units.
184 |     eps: A `float` in the range [0, 1]. Small number to ensure that
185 |       the computation can halt after the first unit.
186 |     scope: variable scope or scope name in which the layers are created.
187 |       Defaults to 'act'.
188 |
189 |   Returns:
190 |     ponder_cost: A 1-D `Tensor` of type `float32`.
191 |       A differentiable upper bound on the number of units.
192 |     num_units: A 1-D `Tensor` of type `int32`.
193 |       The number of units that were actually executed. num_units < ponder_cost.
194 |     flops: A 1-D `Tensor` of type `int64`.
195 |       Number of floating point operations that were performed.
196 |     halting_distribution: A 2-D `Tensor` of type `float32`.
197 |       Shape is `[batch, max_units]`. Halting probability distribution.
198 |       halting_distribution[i, j] is the probability of the computation for the
199 |       i-th object to halt at the j-th unit. The sum of every row should be close to one.
200 |     outputs: A `Tensor` of shape [batch, ...]. Has the same shape as the states.
201 |       Outputs of the ACT module: intermediate states weighted
202 |       by the halting distribution over the units.
203 |   """
204 |   if inputs.get_shape().is_fully_defined():
205 |     sh = inputs.get_shape().as_list()
206 |   else:
207 |     sh = tf.shape(inputs)
208 |   batch = sh[0]
209 |   inputs_rank = len(sh)
210 |
211 |   def _body(unit_idx, state, halting_cumsum, elements_finished, remainder,
212 |             ponder_cost, num_units, flops, outputs):
213 |
214 |     (new_state, halting_proba, cur_flops) = unit(state, unit_idx)
215 |
216 |     # We always halt at the last unit.
217 |     if unit_idx < max_units - 1:
218 |       halting_proba = tf.reshape(halting_proba, [batch])
219 |     else:
220 |       halting_proba = tf.ones([batch])
221 |
222 |     halting_cumsum += halting_proba
223 |     cur_elements_finished = (halting_cumsum >= 1 - eps)
224 |     # Zero out halting_proba for the previously finished objects.
225 |     halting_proba = tf.where(cur_elements_finished,
226 |                              tf.zeros([batch]),
227 |                              halting_proba)
228 |     # Find objects which have halted at the current unit.
229 |     just_finished = tf.logical_and(tf.logical_not(elements_finished),
230 |                                    cur_elements_finished)
231 |     # For such objects, the halting distribution value is the remainder.
232 |     # For others, it is the halting_proba.
233 |     cur_halting_distrib = tf.where(just_finished,
234 |                                    remainder,
235 |                                    halting_proba)
236 |
237 |     # Update ponder_cost. Add 1 to objects which are still being computed,
238 |     # the remainder to the objects which have just halted, and
239 |     # 0 to the previously halted objects.
240 |     ponder_cost += tf.where(
241 |         cur_elements_finished,
242 |         tf.where(just_finished, remainder, tf.zeros([batch])),
243 |         tf.ones([batch]))
244 |
245 |     # Add a unit to the objects that were active during this unit
246 |     # (not the ones that will be active at the next unit).
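    # Note that `elements_finished` still holds its value from before this
    # unit's update, so an object halting exactly at this unit is still
    # counted as evaluated here.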
247 |     evaluated_objects = tf.logical_not(elements_finished)
248 |     num_units += tf.to_int32(evaluated_objects)
249 |
250 |     # Update the FLOPS counters for the same objects.
251 |     flops += cur_flops * tf.to_int64(evaluated_objects)
252 |
253 |     # Add the new state to the outputs weighted by the halting distribution.
254 |     outputs += new_state * tf.reshape(cur_halting_distrib,
255 |                                       [-1] + [1] * (inputs_rank - 1))
256 |
257 |     remainder -= halting_proba
258 |
259 |     return (new_state, halting_cumsum, cur_elements_finished, remainder,
260 |             ponder_cost, num_units, flops, cur_halting_distrib, outputs)
261 |
262 |   def _identity(unit_idx, state, halting_cumsum, elements_finished,
263 |                 remainder, ponder_cost, num_units, flops, outputs):
264 |     return (state, halting_cumsum, elements_finished, remainder, ponder_cost,
265 |             num_units, flops, tf.zeros([batch]), outputs)
266 |
267 |   # Create all the variables and losses outside of tf.cond.
268 |   # Without this, regularization losses would not work correctly.
269 |   run_units(inputs, unit, max_units, scope)
270 |
271 |   state = inputs
272 |   halting_cumsum = tf.zeros([batch])
273 |   elements_finished = tf.fill([batch], False)
274 |   remainder = tf.ones([batch])
275 |   # Initialize ponder_cost with one to fix an off-by-one error.
276 |   ponder_cost = tf.ones([batch])
277 |   num_units = tf.zeros([batch], dtype=tf.int32)
278 |   flops = tf.zeros([batch], dtype=tf.int64)
279 |
280 |   # We don't know the shape of the outputs. Initialize it to a scalar and
281 |   # run the first iteration outside tf.cond (it wants the outputs of both
282 |   # branches to have the same shapes).
283 |   outputs = 0.
284 |
285 |   # Reuse the variables created above.
286 |   with tf.variable_scope(scope, reuse=True):
287 |     halting_distribs = []
288 |     for unit_idx in range(max_units):
289 |       finished = tf.reduce_all(elements_finished)
290 |       args = (unit_idx, state, halting_cumsum, elements_finished, remainder,
291 |               ponder_cost, num_units, flops, outputs)
292 |       if unit_idx == 0:
293 |         return_values = _body(*args)
294 |       else:
295 |         return_values = tf.cond(finished,
296 |                                 lambda: _identity(*args),
297 |                                 lambda: _body(*args))
298 |       (state, halting_cumsum, elements_finished, remainder, ponder_cost,
299 |        num_units, flops, cur_halting_distrib, outputs) = return_values
300 |
301 |       halting_distribs.append(tf.reshape(cur_halting_distrib, [batch, 1]))
302 |
303 |   halting_distribution = tf.concat(halting_distribs, 1)
304 |
305 |   return (ponder_cost, num_units, flops, halting_distribution, outputs)
306 |
307 |
308 | def spatially_adaptive_computation_time(inputs, unit, max_units,
309 |                                         eps=1e-2, scope='act'):
310 |   """Spatially adaptive computation time.
311 |
312 |   Each spatial position in the states tensor has its own halting distribution.
313 |   This allows different parts of an image to be processed for different
314 |   numbers of units.
315 |
316 |   The code is similar to `adaptive_computation_early_stopping`. The differences
317 |   are:
318 |   1) The states are expected to be 4-D tensors (Batch-Height-Width-Channels).
319 |      ACT is applied over the first three dimensions.
320 |   2) unit should have a `residual_mask` argument. It is a `float32` mask
321 |      with 1's corresponding to the positions which need to be updated;
322 |      positions with 0's are frozen. For ResNets this can be achieved by
323 |      multiplying the residual branch responses by `residual_mask`.
324 |   3) There is no tf.cond part, so no computation is actually saved.
325 |
326 |   Args:
327 |     inputs: Input states at the first unit, 4-D `Tensor` of type `float32`.
328 |     unit: A function. See `adaptive_computation_early_stopping` for a
329 |       detailed explanation.
330 |     max_units: Maximum number of units.
331 |     eps: A `float` in the range [0, 1]. Small number to ensure that
332 |       the computation can halt after the first unit.
333 |     scope: variable scope or scope name in which the layers are created.
334 |       Defaults to 'act'.
335 |
336 |   Returns:
337 |     ponder_cost: A 3-D `Tensor` of type `float32`.
338 |       Shape is [batch, height, width].
339 |       A differentiable upper bound on the number of units per spatial position.
340 |     num_units: A 3-D `Tensor` of type `int32`.
341 |       Shape is [batch, height, width].
342 |       The number of units per spatial position that were actually executed.
343 |       num_units < ponder_cost.
344 |     flops: A 1-D `Tensor` of type `int64`.
345 |       Number of floating point operations that were performed.
346 |     halting_distribution: A 4-D `Tensor` of type `float32`.
347 |       Shape is `[batch, height, width, max_units]`.
348 |       Halting probability distribution.
349 |       halting_distribution[i, h, w, j] is the probability of the computation
350 |       for the i-th object at the spatial position (h, w) to halt at the j-th unit.
351 |       The sum over the last dimension should be close to one.
352 |     outputs: A 4-D `Tensor` of shape [batch, height, width, depth]. Outputs of
353 |       the ACT module, intermediate states weighted by the halting distribution
354 |       tensor.
355 |   """
356 |   with tf.variable_scope(scope):
357 |     halting_distribs = []
358 |     for unit_idx in range(max_units):
359 |
360 |       if not unit_idx:
361 |         (state, halting_proba, flops) = unit(
362 |             inputs, unit_idx, residual_mask=None)
363 |
364 |         # Initialize the variables which depend on the state shape.
365 |         state_shape_fully_defined = state.get_shape().is_fully_defined()
366 |         if state_shape_fully_defined:
367 |           sh = state.get_shape().as_list()
368 |           assert len(sh) == 4
369 |         else:
370 |           sh = tf.shape(state)
371 |         halting_cumsum = tf.zeros(sh[:3])
372 |         elements_finished = tf.fill(sh[:3], False)
373 |         remainder = tf.ones(sh[:3])
374 |         # Initialize ponder_cost with one to fix an off-by-one error.
375 |         ponder_cost = tf.ones(sh[:3])
376 |         num_units = tf.zeros(sh[:3], dtype=tf.int32)
377 |       else:
378 |         # Mask out the residual values for the not calculated outputs.
379 |         residual_mask = tf.to_float(tf.logical_not(elements_finished))
380 |         residual_mask = tf.expand_dims(residual_mask, 3)
381 |         (state, halting_proba, current_flops) = unit(
382 |             state, unit_idx, residual_mask=residual_mask)
383 |         flops += current_flops
384 |
385 |       # We always halt at the last unit.
386 |       if unit_idx < max_units - 1:
387 |         halting_proba = tf.reshape(halting_proba, sh[:3])
388 |       else:
389 |         halting_proba = tf.ones(sh[:3])
390 |
391 |       halting_cumsum += halting_proba
392 |       # Which objects are no longer calculated after this unit?
393 |       cur_elements_finished = (halting_cumsum >= 1 - eps)
394 |       # Zero out halting_proba for the previously finished positions.
395 |       halting_proba = tf.where(cur_elements_finished,
396 |                                tf.zeros(sh[:3]),
397 |                                halting_proba)
398 |       # Find positions which have halted at the current unit.
399 |       just_finished = tf.logical_and(tf.logical_not(elements_finished),
400 |                                      cur_elements_finished)
401 |       # For such positions, the halting distribution value is the remainder.
402 |       # For others, it is the halting_proba.
403 |       cur_halting_distrib = tf.where(just_finished,
404 |                                      remainder,
405 |                                      halting_proba)
406 |
407 |       # Update ponder_cost. Add 1 to positions which are still being computed,
408 |       # the remainder to the positions which have just halted, and
409 |       # 0 to the previously halted positions.
410 |       ponder_cost += tf.where(
411 |           cur_elements_finished,
412 |           tf.where(just_finished, remainder, tf.zeros(sh[:3])),
413 |           tf.ones(sh[:3]))
414 |
415 |       # Add a unit to the positions that were active during this unit
416 |       # (not the ones that will be active at the next unit).
417 |       num_units += tf.to_int32(tf.logical_not(elements_finished))
418 |
419 |       # Add the new state to the outputs weighted by the halting distribution.
420 |       update = state * tf.expand_dims(cur_halting_distrib, 3)
421 |       if unit_idx:
422 |         outputs += update
423 |       else:
424 |         outputs = update
425 |
426 |       remainder -= halting_proba
427 |
428 |       elements_finished = cur_elements_finished
429 |
430 |       halting_distribs.append(cur_halting_distrib)
431 |
432 |   halting_distribution = tf.stack(halting_distribs, axis=3)
433 |
434 |   if not state_shape_fully_defined:
435 |     # Update static shape info. The Faster R-CNN code wants to know the batch
436 |     # dimension statically.
437 |     outputs.set_shape(inputs.get_shape().as_list()[:1] + [None] * 3)
438 |
439 |   return (ponder_cost, num_units, flops, halting_distribution, outputs)
440 |
--------------------------------------------------------------------------------
/act_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Tests for adaptive computation time.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import numpy as np 23 | import tensorflow as tf 24 | 25 | import act 26 | 27 | 28 | class ActTest(tf.test.TestCase): 29 | 30 | def testOutputSize(self): 31 | batch_size = 5 32 | max_units = 8 33 | h = tf.sigmoid(tf.random_normal(shape=[batch_size, max_units - 1])) 34 | (cost, num_units, distrib) = act.adaptive_computation_time(h) 35 | with self.test_session() as sess: 36 | (cost_out, num_units_out, distrib_out) = sess.run( 37 | (cost, num_units, distrib)) 38 | self.assertEqual(cost_out.shape, (batch_size,)) 39 | self.assertEqual(num_units_out.shape, (batch_size,)) 40 | self.assertEqual(distrib_out.shape, (batch_size, max_units)) 41 | 42 | def testEqualValuesInBatch(self): 43 | batch_size = 2 44 | max_units = 8 45 | h = tf.sigmoid(tf.random_normal(shape=[1, max_units - 1])) 46 | h = tf.tile(h, tf.stack([batch_size, 1])) 47 | (cost, num_units, distrib) = act.adaptive_computation_time(h) 48 | with self.test_session() as sess: 49 | (cost_out, num_units_out, distrib_out) = sess.run( 50 | (cost, num_units, distrib)) 51 | self.assertAlmostEqual(cost_out[0], cost_out[1]) 52 | self.assertEqual(num_units_out[0], num_units_out[1]) 53 | self.assertAllEqual(distrib_out[0], distrib_out[1]) 54 | 55 | def testStopsAtFirstUnit(self): 56 | h = tf.constant([[0.999] * 4]) 57 | (cost, num_units, distrib) = act.adaptive_computation_time(h, eps=1e-2) 58 | with self.test_session() as sess: 59 | (cost_out, num_units_out, distrib_out) = sess.run( 60 | (cost, num_units, distrib)) 61 | self.assertAllClose(cost_out, np.array([2.0])) 62 | self.assertAllEqual(num_units_out, np.array([1])) 63 | self.assertAllClose(distrib_out, np.array([[1.] + [0.] * 4])) 64 | 65 | def testStopsAtMiddleUnit(self): 66 | h = tf.constant([[0.01, 0.50, 0.60, 0.70]]) 67 | (cost, num_units, distrib) = act.adaptive_computation_time(h) 68 | with self.test_session() as sess: 69 | (cost_out, num_units_out, distrib_out) = sess.run( 70 | (cost, num_units, distrib)) 71 | self.assertAllClose(cost_out, np.array([3.49])) 72 | self.assertAllEqual(num_units_out, np.array([3])) 73 | self.assertAllClose(distrib_out, np.array([[0.01, 0.50, 0.49, 0., 0.]])) 74 | 75 | def testStopsAtLastUnit(self): 76 | h = tf.constant([[0.01] * 4]) 77 | (cost, num_units, distrib) = act.adaptive_computation_time(h) 78 | with self.test_session() as sess: 79 | (cost_out, num_units_out, distrib_out) = sess.run( 80 | (cost, num_units, distrib)) 81 | self.assertAllClose(cost_out, np.array([5.96])) 82 | self.assertAllEqual(num_units_out, np.array([5])) 83 | self.assertAllClose(distrib_out, np.array([[0.01] * 4 + [0.96]])) 84 | 85 | def testCostGradientsStopsAtFirstUnit(self): 86 | h = tf.constant([[0.999] * 4]) 87 | (cost, num_units, distrib) = act.adaptive_computation_time(h) 88 | cost_grad = tf.gradients(cost, h) 89 | with self.test_session() as sess: 90 | cost_grad_out = sess.run(cost_grad) 91 | self.assertAllClose(cost_grad_out, np.array([[[0.] 
* 4]])) 92 | 93 | def testCostGradientsStopsAtMiddleUnit(self): 94 | h = tf.constant([[0.01, 0.50, 0.60, 0.70]]) 95 | (cost, num_units, distrib) = act.adaptive_computation_time(h) 96 | cost_grad = tf.gradients(cost, h) 97 | with self.test_session() as sess: 98 | cost_grad_out = sess.run(cost_grad) 99 | self.assertAllClose(cost_grad_out, np.array([[[-1., -1., 0., 0.]]])) 100 | 101 | def testCostGradientsStopsAtLastUnit(self): 102 | h = tf.constant([[0.01] * 4]) 103 | (cost, num_units, distrib) = act.adaptive_computation_time(h) 104 | cost_grad = tf.gradients(cost, h) 105 | with self.test_session() as sess: 106 | cost_grad_out = sess.run(cost_grad) 107 | self.assertAllClose(cost_grad_out, np.array([[[-1.] * 4]])) 108 | 109 | 110 | class ActWrapperTest(tf.test.TestCase): 111 | 112 | def _runAct(self, unit_outputs, halting_probas): 113 | self.assertEqual(len(unit_outputs), len(halting_probas)) 114 | batch = len(unit_outputs) 115 | 116 | # halting_proba[i][-1] should not be used, but we still pass it here 117 | # to be able to check that it does not affect anything. 118 | for (l, h) in zip(unit_outputs, halting_probas): 119 | self.assertEqual(len(l), len(h)) 120 | max_units = len(unit_outputs[0]) 121 | 122 | unit_outputs_tf = tf.constant( 123 | unit_outputs, shape=[batch, max_units], dtype=tf.float32) 124 | halting_probas_tf = tf.constant( 125 | halting_probas, shape=[batch, max_units], dtype=tf.float32) 126 | # Every unit for each object is two FLOPS. 127 | flops_tf = tf.constant(2, shape=[batch, max_units], dtype=tf.int64) 128 | 129 | def unit(x, unit_idx): 130 | return ( 131 | tf.reshape(unit_outputs_tf[:, unit_idx], tf.stack([-1, 1])), 132 | tf.reshape(halting_probas_tf[:, unit_idx], tf.stack([-1, 1])), 133 | flops_tf[:, unit_idx]) 134 | 135 | inputs = tf.random_normal(shape=[batch, 1]) 136 | (cost, num_units, flops, distrib, outputs 137 | ) = act.adaptive_computation_time_wrapper(inputs, unit, max_units) 138 | cost_grad = tf.gradients(cost, halting_probas_tf) 139 | with self.test_session() as sess: 140 | sess.run(tf.global_variables_initializer()) 141 | return sess.run((cost, num_units, flops, distrib, outputs, cost_grad)) 142 | 143 | def testEqualValuesInBatch(self): 144 | (cost, num_units, flops, distrib, outputs, cost_grad) = self._runAct( 145 | [list(range(5))] * 2, [[0.999] * 5] * 2) 146 | self.assertAlmostEqual(cost[0], cost[1]) 147 | self.assertEqual(num_units[0], num_units[1]) 148 | self.assertEqual(flops[0], flops[1]) 149 | self.assertAllClose(distrib[0], distrib[1]) 150 | self.assertAllClose(outputs[0], outputs[1]) 151 | self.assertAllClose(cost_grad[0][0], cost_grad[0][1]) 152 | 153 | def testStopsAtFirstUnit(self): 154 | (cost, num_units, flops, distrib, outputs, cost_grad) = self._runAct( 155 | [list(range(5))], [[0.999] + [0.5] * 4]) 156 | self.assertAllClose(cost, [2.0]) 157 | self.assertAllEqual(num_units, [1]) 158 | self.assertAllEqual(flops, [2]) 159 | self.assertAllClose(distrib, [[1.0] + [0.0] * 4]) 160 | self.assertAllClose(outputs, [[0.0]]) 161 | self.assertAllClose(cost_grad, [[[0.0] * 5]]) 162 | 163 | def testStopsAtMiddleUnit(self): 164 | (cost, num_units, flops, distrib, outputs, cost_grad) = self._runAct( 165 | [list(range(5))], [[0.01, 0.5, 0.6, 0.7, 0.8]]) 166 | self.assertAllClose(cost, [3.49]) 167 | self.assertAllEqual(num_units, [3]) 168 | self.assertAllEqual(flops, [6]) 169 | self.assertAllClose(distrib, [[0.01, 0.50, 0.49, 0., 0.]]) 170 | self.assertAllClose(outputs, [[1.48]]) 171 | self.assertAllClose(cost_grad, [[[-1., -1., 0., 0., 0.]]]) 172 | 173 | def 
testStopsAtLastUnit(self): 174 | (cost, num_units, flops, distrib, outputs, cost_grad) = self._runAct( 175 | [list(range(5))], [[0.01] * 5]) 176 | self.assertAllClose(cost, [5.96]) 177 | self.assertAllEqual(num_units, [5]) 178 | self.assertAllEqual(flops, [10]) 179 | self.assertAllClose(distrib, [[0.01] * 4 + [0.96]]) 180 | self.assertAllClose(outputs, [[3.9]]) 181 | self.assertAllClose(cost_grad, [[[-1.] * 4 + [0.]]]) 182 | 183 | def testInputs(self): 184 | inputs = tf.random_normal(shape=[2, 3]) 185 | 186 | def unit(x, unit_idx): 187 | # First object runs for two units, second object for four units. 188 | return (x, tf.constant( 189 | [0.7, 0.3], shape=[2, 1]), tf.constant( 190 | 0, shape=[2], dtype=tf.int64)) 191 | 192 | (_, _, _, _, outputs) = act.adaptive_computation_time_wrapper(inputs, 193 | unit, 5) 194 | with self.test_session() as sess: 195 | (inputs_out, outputs_out) = sess.run((inputs, outputs)) 196 | self.assertAllClose(inputs_out, outputs_out) 197 | 198 | def testRegularization(self): 199 | inputs = tf.random_normal(shape=[1, 3]) 200 | 201 | def unit(x, unit_idx): 202 | with tf.variable_scope('{}'.format(unit_idx)): 203 | w = tf.get_variable( 204 | 'test_variable', [1, 1], 205 | initializer=tf.constant_initializer(1.0), 206 | regularizer=lambda _: 2.0 * tf.nn.l2_loss(_)) 207 | return (w, tf.constant( 208 | 1.0, shape=[1, 1]), tf.constant( 209 | 0, shape=[1], dtype=tf.int64)) 210 | 211 | (_, _, _, _, outputs) = act.adaptive_computation_time_wrapper(inputs, 212 | unit, 5) 213 | decay_cost = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) 214 | with self.test_session() as sess: 215 | sess.run(tf.global_variables_initializer()) 216 | (_, decay_cost_out) = sess.run((outputs, decay_cost)) 217 | self.assertEqual(decay_cost_out, 5.0) 218 | 219 | 220 | class ActEarlyStoppingTest(tf.test.TestCase): 221 | 222 | def _runAct(self, unit_outputs, halting_probas): 223 | self.assertEqual(len(unit_outputs), len(halting_probas)) 224 | batch = len(unit_outputs) 225 | 226 | # halting_proba[i][-1] should not be used, but we still pass it here 227 | # to be able to check that it does not affect anything. 228 | for (l, h) in zip(unit_outputs, halting_probas): 229 | self.assertEqual(len(l), len(h)) 230 | max_units = len(unit_outputs[0]) 231 | 232 | unit_outputs_tf = tf.constant( 233 | unit_outputs, shape=[batch, max_units], dtype=tf.float32) 234 | halting_probas_tf = tf.constant( 235 | halting_probas, shape=[batch, max_units], dtype=tf.float32) 236 | # Every unit for each object is two FLOPS. 
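    # Consequently, the per-object `flops` returned by the module should
    # equal 2 * num_units (see the assertions in the tests below).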
237 | flops_tf = tf.constant(2, shape=[batch, max_units], dtype=tf.int64) 238 | unit_counter = tf.Variable(0, trainable=False) 239 | 240 | def unit(x, unit_idx): 241 | assign_op = unit_counter.assign_add(1) 242 | with tf.control_dependencies([assign_op]): 243 | return ( 244 | tf.reshape(unit_outputs_tf[:, unit_idx], tf.stack([-1, 1])), 245 | tf.reshape(halting_probas_tf[:, unit_idx], tf.stack([-1, 1])), 246 | flops_tf[:, unit_idx]) 247 | 248 | inputs = tf.random_normal(shape=[batch, 1]) 249 | (cost, num_units, flops, distrib, outputs 250 | ) = act.adaptive_computation_early_stopping(inputs, unit, max_units) 251 | cost_grad = tf.gradients(cost, halting_probas_tf) 252 | with self.test_session() as sess: 253 | sess.run(tf.global_variables_initializer()) 254 | return sess.run((cost, num_units, flops, distrib, outputs, cost_grad, 255 | unit_counter)) 256 | 257 | def testEqualValuesInBatch(self): 258 | (cost, num_units, flops, distrib, outputs, cost_grad, 259 | unit_counter) = self._runAct([list(range(5))] * 2, [[0.999] * 5] * 2) 260 | self.assertAlmostEqual(cost[0], cost[1]) 261 | self.assertEqual(num_units[0], num_units[1]) 262 | self.assertEqual(flops[0], flops[1]) 263 | self.assertAllClose(distrib[0], distrib[1]) 264 | self.assertAllClose(outputs[0], outputs[1]) 265 | self.assertAllClose(cost_grad[0][0], cost_grad[0][1]) 266 | self.assertEqual(unit_counter, 1) 267 | 268 | def testStopsAtFirstUnit(self): 269 | (cost, num_units, flops, distrib, outputs, cost_grad, 270 | unit_counter) = self._runAct([list(range(5))], [[0.999] + [0.5] * 4]) 271 | self.assertAllClose(cost, [2.0]) 272 | self.assertAllEqual(num_units, [1]) 273 | self.assertAllEqual(flops, [2]) 274 | self.assertAllClose(distrib, [[1.0] + [0.0] * 4]) 275 | self.assertAllClose(outputs, [[0.0]]) 276 | self.assertAllClose(cost_grad, [[[0.0] * 5]]) 277 | self.assertEqual(unit_counter, 1) 278 | 279 | def testStopsAtMiddleUnit(self): 280 | (cost, num_units, flops, distrib, outputs, cost_grad, 281 | unit_counter) = self._runAct([list(range(5))], [[0.01, 0.5, 0.6, 0.7, 0.8]]) 282 | self.assertAllClose(cost, [3.49]) 283 | self.assertAllEqual(num_units, [3]) 284 | self.assertAllEqual(flops, [6]) 285 | self.assertAllClose(distrib, [[0.01, 0.50, 0.49, 0., 0.]]) 286 | self.assertAllClose(outputs, [[1.48]]) 287 | self.assertAllClose(cost_grad, [[[-1., -1., 0., 0., 0.]]]) 288 | self.assertEqual(unit_counter, 3) 289 | 290 | def testStopsAtLastUnit(self): 291 | (cost, num_units, flops, distrib, outputs, cost_grad, 292 | unit_counter) = self._runAct([list(range(5))], [[0.01] * 5]) 293 | self.assertAllClose(cost, [5.96]) 294 | self.assertAllEqual(num_units, [5]) 295 | self.assertAllEqual(flops, [10]) 296 | self.assertAllClose(distrib, [[0.01] * 4 + [0.96]]) 297 | self.assertAllClose(outputs, [[3.9]]) 298 | self.assertAllClose(cost_grad, [[[-1.] * 4 + [0.]]]) 299 | self.assertEqual(unit_counter, 5) 300 | 301 | def testInputs(self): 302 | inputs = tf.random_normal(shape=[2, 3]) 303 | 304 | def unit(x, unit_idx): 305 | # First object runs for two units, second object for four units. 
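      # (With the default eps=1e-2 the halting threshold is 0.99: cumulative
      # sums of 0.7 reach 1.4 >= 0.99 at unit 2, while cumulative sums of 0.3
      # reach 1.2 >= 0.99 only at unit 4.)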
306 | return (x, tf.constant( 307 | [0.7, 0.3], shape=[2, 1]), tf.constant( 308 | 0, shape=[2], dtype=tf.int64)) 309 | 310 | (_, _, _, _, outputs) = act.adaptive_computation_early_stopping(inputs, 311 | unit, 5) 312 | with self.test_session() as sess: 313 | (inputs_out, outputs_out) = sess.run((inputs, outputs)) 314 | self.assertAllClose(inputs_out, outputs_out) 315 | 316 | def testRegularization(self): 317 | inputs = tf.random_normal(shape=[1, 3]) 318 | 319 | def unit(x, unit_idx): 320 | with tf.variable_scope('{}'.format(unit_idx)): 321 | w = tf.get_variable( 322 | 'test_variable', [1, 1], 323 | initializer=tf.constant_initializer(1.0), 324 | regularizer=lambda _: 2.0 * tf.nn.l2_loss(_)) 325 | return (w, tf.constant( 326 | 1.0, shape=[1, 1]), tf.constant( 327 | 0, shape=[1], dtype=tf.int64)) 328 | 329 | (_, _, _, _, outputs) = act.adaptive_computation_early_stopping(inputs, 330 | unit, 5) 331 | decay_cost = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) 332 | with self.test_session() as sess: 333 | sess.run(tf.global_variables_initializer()) 334 | (outputs_out, decay_cost_out) = sess.run((outputs, decay_cost)) 335 | self.assertEqual(decay_cost_out, 5.0) 336 | 337 | 338 | class SactTest(tf.test.TestCase): 339 | 340 | def testSimple(self): 341 | # Batch x Height x Width x Channels 342 | sh = [1, 1, 2, 1] 343 | unit_outputs = [ 344 | np.array([1.0, 2.0]).reshape(sh), 345 | np.array([3.0, 4.0]).reshape(sh), 346 | np.array([5.0, 6.0]).reshape(sh), 347 | ] 348 | halting_probas = [ 349 | np.array([0.9, 0.1]).reshape(sh), 350 | np.array([0.5, 0.1]).reshape(sh), 351 | np.array([0.8, 0.1]).reshape(sh), # unused 352 | ] 353 | flops = [2, 2, 2] 354 | max_units = 3 355 | residual_masks = [] 356 | 357 | def unit(_, unit_idx, residual_mask): 358 | residual_masks.append(residual_mask) 359 | return (tf.constant( 360 | unit_outputs[unit_idx], dtype=tf.float32), tf.constant( 361 | halting_probas[unit_idx], dtype=tf.float32), tf.constant( 362 | flops[unit_idx], shape=[1], dtype=tf.int64)) 363 | 364 | inputs = tf.random_normal(shape=sh) 365 | (cost, num_units, flops, distrib, outputs 366 | ) = act.spatially_adaptive_computation_time(inputs, unit, max_units) 367 | with self.test_session() as sess: 368 | (cost_out, num_units_out, flops_out, distrib_out, outputs_out, 369 | residual_masks_out) = sess.run( 370 | (cost, num_units, flops, distrib, outputs, residual_masks[1:])) 371 | # Batch x Height x Width 372 | sh = [1, 1, 2] 373 | self.assertAllClose(cost_out, np.array([2.1, 3.8]).reshape(sh)) 374 | self.assertAllEqual(num_units_out, np.array([2, 3]).reshape(sh)) 375 | self.assertAllEqual(flops_out, [6]) 376 | distrib_expected = np.array([[0.9, 0.1, 0.0], [0.1, 0.1, 0.8]]) 377 | self.assertAllClose(distrib_out, distrib_expected.reshape(sh + [3])) 378 | outputs_expected = np.array([1.2, 5.4]) 379 | self.assertAllClose(outputs_out, outputs_expected.reshape(sh + [1])) 380 | # Residual mask for the second unit 381 | self.assertAllClose(residual_masks_out[0], 382 | np.array([1., 1.]).reshape(sh + [1])) 383 | # Residual mask for the third unit 384 | self.assertAllClose(residual_masks_out[1], 385 | np.array([0., 1.]).reshape(sh + [1])) 386 | 387 | def testInputs(self): 388 | max_units = 5 389 | inputs = tf.random_normal(shape=[2, 5, 3, 3]) 390 | # Generate random probabilities for first four units that sum up to one. 391 | # Fill in last unit with zeros. 
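    # Since each unit returns its input unchanged and the halting distribution
    # sums to one over the units, the weighted outputs must equal the inputs,
    # which is what this test checks.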
392 | probas = tf.random_normal(shape=[max_units - 1, 2, 5, 3]) 393 | probas = tf.reshape(probas, [max_units - 1, 2 * 5 * 3]) 394 | probas = tf.nn.softmax(probas) 395 | probas = tf.reshape(probas, [max_units - 1, 2, 5, 3]) 396 | probas = tf.concat([probas, tf.zeros([1, 2, 5, 3])], 0) 397 | 398 | def unit(x, unit_idx, residual_mask): 399 | return (x, tf.reshape(probas[unit_idx, :, :, :], [2, 5, 3, 1]), 400 | tf.zeros( 401 | [2], dtype=tf.int64)) 402 | 403 | (_, _, _, _, outputs) = act.spatially_adaptive_computation_time( 404 | inputs, unit, max_units) 405 | with self.test_session() as sess: 406 | (inputs_out, outputs_out) = sess.run((inputs, outputs)) 407 | self.assertAllClose(inputs_out, outputs_out) 408 | 409 | def testResidualMask(self): 410 | # Batch x Height x Width x Channels 411 | sh = [1, 1, 2, 1] 412 | halting_probas = [ 413 | np.array([0.9, 0.1]).reshape(sh), 414 | np.array([0.5, 0.1]).reshape(sh), 415 | np.array([0.8, 0.1]).reshape(sh), # unused 416 | ] 417 | max_units = 3 418 | 419 | unit_outputs = [] 420 | 421 | def unit(x, unit_idx, residual_mask): 422 | residual = tf.ones(sh) 423 | if residual_mask is not None: 424 | residual *= residual_mask 425 | outputs = x + residual 426 | unit_outputs.append(outputs) 427 | return (outputs, tf.constant( 428 | halting_probas[unit_idx], dtype=tf.float32), tf.zeros( 429 | [2], dtype=tf.int64)) 430 | 431 | inputs = tf.zeros(sh) 432 | (_, _, _, _, outputs) = act.spatially_adaptive_computation_time( 433 | inputs, unit, max_units) 434 | with self.test_session() as sess: 435 | unit_outputs_out, final_outputs_out = sess.run( 436 | (unit_outputs, outputs)) 437 | 438 | # First position runs for two iterations, 439 | # second position for three iterations 440 | self.assertAllClose(unit_outputs_out[0], 441 | np.array([1.0, 1.0]).reshape(sh)) 442 | self.assertAllClose(unit_outputs_out[1], 443 | np.array([2.0, 2.0]).reshape(sh)) 444 | self.assertAllClose(unit_outputs_out[2], 445 | np.array([2.0, 3.0]).reshape(sh)) 446 | 447 | self.assertAllClose(final_outputs_out, np.array([1.1, 2.7]).reshape(sh)) 448 | 449 | 450 | if __name__ == '__main__': 451 | tf.test.main() 452 | -------------------------------------------------------------------------------- /cifar_data_provider.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Contains code for loading and preprocessing the CIFAR-10 data.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | import tensorflow as tf 25 | 26 | from tensorflow.contrib import slim 27 | from tensorflow.contrib.slim import dataset_data_provider 28 | 29 | from external import datasets_cifar10 30 | 31 | 32 | def provide_data(split_name, batch_size, dataset_dir=None): 33 | """Provides batches of CIFAR data. 34 | 35 | Args: 36 | split_name: Either 'train' or 'test'. 37 | batch_size: The number of images in each batch. 38 | dataset_dir: Directory where the CIFAR-10 TFRecord files live. 39 | Defaults to "~/tensorflow/data/cifar10". 40 | 41 | Returns: 42 | images: A `Tensor` of size [batch_size, 32, 32, 3]. 43 | images_not_whiten: A `Tensor` of the same size as `images`, containing the 44 | unwhitened images. 45 | one_hot_labels: A `Tensor` of size [batch_size, num_classes], where 46 | each row has a single element set to one and the rest set to zeros. 47 | dataset.num_samples: The number of total samples in the dataset. 48 | dataset.num_classes: The number of object classes in the dataset. 49 | 50 | Raises: 51 | ValueError: if the split_name is not either 'train' or 'test'. 52 | """ 53 | with tf.device('/cpu:0'): 54 | is_train = split_name == 'train' 55 | 56 | if dataset_dir is None: 57 | dataset_dir = os.path.expanduser('~/tensorflow/data/cifar10') 58 | 59 | dataset = datasets_cifar10.get_split(split_name, dataset_dir) 60 | provider = dataset_data_provider.DatasetDataProvider( 61 | dataset, 62 | common_queue_capacity=5 * batch_size, 63 | common_queue_min=batch_size, 64 | shuffle=is_train) 65 | [image, label] = provider.get(['image', 'label']) 66 | image = tf.to_float(image) 67 | 68 | image_size = 32 69 | if is_train: 70 | num_threads = 4 71 | 72 | image = tf.image.resize_image_with_crop_or_pad(image, image_size + 4, 73 | image_size + 4) 74 | image = tf.random_crop(image, [image_size, image_size, 3]) 75 | image = tf.image.random_flip_left_right(image) 76 | # Brightness/saturation/contrast jitter provides small gains of 0.2%-0.5% on CIFAR. 77 | # image = tf.image.random_brightness(image, max_delta=63. / 255.) 78 | # image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 79 | # image = tf.image.random_contrast(image, lower=0.2, upper=1.8) 80 | else: 81 | num_threads = 1 82 | 83 | image = tf.image.resize_image_with_crop_or_pad(image, image_size, 84 | image_size) 85 | 86 | image_not_whiten = image 87 | image = tf.image.per_image_standardization(image) 88 | 89 | # Creates a QueueRunner for the pre-fetching operation. 90 | images, images_not_whiten, labels = tf.train.batch( 91 | [image, image_not_whiten, label], 92 | batch_size=batch_size, 93 | num_threads=num_threads, 94 | capacity=5 * batch_size) 95 | 96 | labels = tf.reshape(labels, [-1]) 97 | one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes) 98 | 99 | return (images, images_not_whiten, one_hot_labels, dataset.num_samples, 100 | dataset.num_classes) 101 | -------------------------------------------------------------------------------- /cifar_data_provider_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for cifar_data_provider.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | from tensorflow.contrib import slim 24 | 25 | import cifar_data_provider 26 | 27 | 28 | class CifarDataProviderTest(tf.test.TestCase): 29 | 30 | def _testCifar10(self, split_name, expected_num_samples): 31 | data_tup = cifar_data_provider.provide_data( 32 | split_name, 4, dataset_dir='testdata/cifar10') 33 | images, _, one_hot_labels, num_samples, num_classes = data_tup 34 | 35 | self.assertEqual(num_samples, expected_num_samples) 36 | self.assertEqual(num_classes, 10) 37 | with self.test_session() as sess: 38 | with slim.queues.QueueRunners(sess): 39 | images_out, one_hot_labels_out = sess.run([images, one_hot_labels]) 40 | self.assertEqual(images_out.shape, (4, 32, 32, 3)) 41 | self.assertEqual(one_hot_labels_out.shape, (4, 10)) 42 | 43 | def testCifar10TrainSet(self): 44 | self._testCifar10('train', 50000) 45 | 46 | def testCifar10TestSet(self): 47 | self._testCifar10('test', 10000) 48 | 49 | 50 | if __name__ == '__main__': 51 | tf.test.main() 52 | -------------------------------------------------------------------------------- /cifar_main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Trains or evaluates a CIFAR ResNet ACT model.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import math 23 | 24 | import tensorflow as tf 25 | from tensorflow.contrib import slim 26 | 27 | import cifar_data_provider 28 | import cifar_model 29 | import summary_utils 30 | import training_utils 31 | import utils 32 | 33 | 34 | FLAGS = tf.app.flags.FLAGS 35 | 36 | # General settings 37 | tf.app.flags.DEFINE_string('mode', 'train', 'One of "train" or "eval".') 38 | 39 | # Training settings 40 | tf.app.flags.DEFINE_integer('batch_size', 128, 41 | 'The number of images in each batch.') 42 | 43 | tf.app.flags.DEFINE_string('master', '', 44 | 'Name of the TensorFlow master to use.') 45 | 46 | tf.app.flags.DEFINE_string('train_log_dir', '/tmp/resnet_act_cifar/', 47 | 'Directory where to write event logs.') 48 | 49 | tf.app.flags.DEFINE_integer( 50 | 'save_summaries_secs', 30, 51 | 'The frequency with which summaries are saved, in seconds.') 52 | 53 | tf.app.flags.DEFINE_integer( 54 | 'save_interval_secs', 60, 55 | 'The frequency with which the model is saved, in seconds.') 56 | 57 | tf.app.flags.DEFINE_integer('max_number_of_steps', 100000, 58 | 'The maximum number of gradient steps.') 59 | 60 | tf.app.flags.DEFINE_integer( 61 | 'ps_tasks', 0, 62 | 'The number of parameter servers. If the value is 0, then the parameters ' 63 | 'are handled locally by the worker.') 64 | 65 | tf.app.flags.DEFINE_integer( 66 | 'task', 0, 67 | 'The Task ID. This value is used when training with multiple workers to ' 68 | 'identify each worker.') 69 | 70 | tf.app.flags.DEFINE_string( 71 | 'dataset_dir', None, 72 | 'Directory with CIFAR-10 data, should contain files ' 73 | '"cifar10_train.tfrecord" and "cifar10_test.tfrecord".') 74 | 75 | # Evaluation settings 76 | tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/resnet_act_cifar/', 77 | 'Directory where the model was written to.') 78 | 79 | tf.app.flags.DEFINE_string('eval_dir', '/tmp/resnet_act_cifar/', 80 | 'Directory where the results are saved to.') 81 | 82 | tf.app.flags.DEFINE_integer('eval_batch_size', 100, 83 | 'The number of images in each batch for evaluation.') 84 | 85 | tf.app.flags.DEFINE_integer( 86 | 'eval_interval_secs', 60, 87 | 'The frequency, in seconds, with which evaluation is run.') 88 | 89 | tf.app.flags.DEFINE_string('split_name', 'test', """Either 'train' or 'test'.""") 90 | 91 | tf.app.flags.DEFINE_bool('evaluate_once', False, 'Evaluate the model just once?') 92 | 93 | # Model settings 94 | tf.app.flags.DEFINE_string( 95 | 'model_type', 'vanilla', 96 | 'Options: vanilla (basic ResNet model), act (Adaptive Computation Time), ' 97 | 'act_early_stopping (act implementation which actually saves time), ' 98 | 'sact (Spatially Adaptive Computation Time)') 99 | 100 | tf.app.flags.DEFINE_float('tau', 1.0, 'The value of tau (ponder relative cost).') 101 | 102 | tf.app.flags.DEFINE_string( 103 | 'model', 104 | '5', 105 | 'An underscore separated string, number of residual units per block. 
'If only one number is provided, uses the same number of units in all blocks.') 107 | 108 | tf.app.flags.DEFINE_string('finetune_path', '', 109 | 'Path for the initial checkpoint for finetuning.') 110 | 111 | 112 | def train(): 113 | if not tf.gfile.Exists(FLAGS.train_log_dir): 114 | tf.gfile.MakeDirs(FLAGS.train_log_dir) 115 | 116 | g = tf.Graph() 117 | with g.as_default(): 118 | # If ps_tasks is zero, the local device is used. When using multiple 119 | # (non-local) replicas, the ReplicaDeviceSetter distributes the variables 120 | # across the different devices. 121 | with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)): 122 | data_tuple = cifar_data_provider.provide_data( 123 | 'train', FLAGS.batch_size, dataset_dir=FLAGS.dataset_dir) 124 | images, _, one_hot_labels, _, num_classes = data_tuple 125 | 126 | # Define the model: 127 | with slim.arg_scope(cifar_model.resnet_arg_scope(is_training=True)): 128 | model = utils.split_and_int(FLAGS.model) 129 | logits, end_points = cifar_model.resnet( 130 | images, 131 | model=model, 132 | num_classes=num_classes, 133 | model_type=FLAGS.model_type) 134 | 135 | # Specify the loss function: 136 | tf.losses.softmax_cross_entropy( 137 | onehot_labels=one_hot_labels, logits=logits) 138 | if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'): 139 | training_utils.add_all_ponder_costs(end_points, weights=FLAGS.tau) 140 | total_loss = tf.losses.get_total_loss() 141 | tf.summary.scalar('Total Loss', total_loss) 142 | 143 | metric_map = {} # summary_utils.flops_metric_map(end_points, False) 144 | if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'): 145 | metric_map.update(summary_utils.act_metric_map(end_points, False)) 146 | for name, value in metric_map.items(): 147 | tf.summary.scalar(name, value) 148 | 149 | if FLAGS.model_type == 'sact': 150 | summary_utils.add_heatmaps_image_summary(end_points) 151 | 152 | init_fn = training_utils.finetuning_init_fn(FLAGS.finetune_path) 153 | 154 | # Specify the optimization scheme: 155 | global_step = slim.get_or_create_global_step() 156 | # Original LR schedule 157 | # boundaries = [40000, 60000, 80000] 158 | # "Longer" LR schedule 159 | boundaries = [60000, 75000, 90000] 160 | boundaries = [tf.constant(x, dtype=tf.int64) for x in boundaries] 161 | values = [0.1, 0.01, 0.001, 0.0001] 162 | learning_rate = tf.train.piecewise_constant(global_step, boundaries, 163 | values) 164 | tf.summary.scalar('Learning Rate', learning_rate) 165 | optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9) 166 | 167 | # Set up training. 168 | train_op = slim.learning.create_train_op(total_loss, optimizer) 169 | 170 | if FLAGS.train_log_dir: 171 | logdir = FLAGS.train_log_dir 172 | else: 173 | logdir = None 174 | 175 | config = tf.ConfigProto() 176 | config.gpu_options.allow_growth = True 177 | 178 | # Run training.
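    # slim.learning.train drives the training loop: it repeatedly runs
    # train_op, writes summaries every save_summaries_secs seconds,
    # checkpoints every save_interval_secs seconds, and stops after
    # max_number_of_steps gradient steps. For finetuning, init_fn restores
    # the starting weights from FLAGS.finetune_path before the loop begins.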
slim.learning.train( 180 | train_op=train_op, 181 | init_fn=init_fn, 182 | logdir=logdir, 183 | master=FLAGS.master, 184 | number_of_steps=FLAGS.max_number_of_steps, 185 | save_summaries_secs=FLAGS.save_summaries_secs, 186 | save_interval_secs=FLAGS.save_interval_secs, 187 | session_config=config) 188 | 189 | 190 | def evaluate(): 191 | g = tf.Graph() 192 | with g.as_default(): 193 | data_tuple = cifar_data_provider.provide_data(FLAGS.split_name, 194 | FLAGS.eval_batch_size, 195 | dataset_dir=FLAGS.dataset_dir) 196 | images, _, one_hot_labels, num_samples, num_classes = data_tuple 197 | 198 | # Define the model: 199 | with slim.arg_scope(cifar_model.resnet_arg_scope(is_training=False)): 200 | model = utils.split_and_int(FLAGS.model) 201 | logits, end_points = cifar_model.resnet( 202 | images, 203 | model=model, 204 | num_classes=num_classes, 205 | model_type=FLAGS.model_type) 206 | 207 | predictions = tf.argmax(logits, 1) 208 | 209 | tf.losses.softmax_cross_entropy( 210 | onehot_labels=one_hot_labels, logits=logits) 211 | if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'): 212 | training_utils.add_all_ponder_costs(end_points, weights=FLAGS.tau) 213 | 214 | loss = tf.losses.get_total_loss() 215 | 216 | # Define the metrics: 217 | labels = tf.argmax(one_hot_labels, 1) 218 | metric_map = { 219 | 'eval/Accuracy': 220 | tf.contrib.metrics.streaming_accuracy(predictions, labels), 221 | 'eval/Mean Loss': 222 | tf.contrib.metrics.streaming_mean(loss), 223 | } 224 | metric_map.update(summary_utils.flops_metric_map(end_points, True)) 225 | if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'): 226 | metric_map.update(summary_utils.act_metric_map(end_points, True)) 227 | names_to_values, names_to_updates = tf.contrib.metrics.aggregate_metric_map( 228 | metric_map) 229 | 230 | for name, value in names_to_values.items(): 231 | summ = tf.summary.scalar(name, value, collections=[]) 232 | summ = tf.Print(summ, [value], name) 233 | tf.add_to_collection(tf.GraphKeys.SUMMARIES, summ) 234 | 235 | if FLAGS.model_type == 'sact': 236 | summary_utils.add_heatmaps_image_summary(end_points) 237 | 238 | # This ensures that we make a single pass over all of the data. 239 | num_batches = math.ceil(num_samples / float(FLAGS.eval_batch_size)) 240 | 241 | if not FLAGS.evaluate_once: 242 | eval_function = slim.evaluation.evaluation_loop 243 | checkpoint_path = FLAGS.checkpoint_dir 244 | eval_kwargs = {'eval_interval_secs': FLAGS.eval_interval_secs} 245 | else: 246 | eval_function = slim.evaluation.evaluate_once 247 | checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) 248 | assert checkpoint_path is not None 249 | eval_kwargs = {} 250 | 251 | config = tf.ConfigProto() 252 | config.gpu_options.allow_growth = True 253 | 254 | eval_function( 255 | FLAGS.master, 256 | checkpoint_path, 257 | logdir=FLAGS.eval_dir, 258 | num_evals=num_batches, 259 | eval_op=list(names_to_updates.values()), 260 | session_config=config, 261 | **eval_kwargs) 262 | 263 | 264 | def main(_): 265 | if FLAGS.mode == 'train': 266 | train() 267 | elif FLAGS.mode == 'eval': 268 | evaluate() 269 | 270 | 271 | if __name__ == '__main__': 272 | tf.app.run() 273 | -------------------------------------------------------------------------------- /cifar_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """ Adaptive computation time residual network for CIFAR-10. 17 | 18 | The code is based on https://github.com/tensorflow/models/blob/master/resnet/resnet_model.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import tensorflow as tf 26 | 27 | from tensorflow.contrib import slim 28 | from tensorflow.contrib.slim.nets import resnet_utils 29 | 30 | import flopsometer 31 | import resnet_act 32 | 33 | 34 | def lrelu(x, leakiness=0.1): 35 | return tf.maximum(x, x * leakiness) 36 | 37 | 38 | def residual(inputs, 39 | depth, 40 | stride, 41 | activate_before_residual, 42 | residual_mask=None, 43 | scope=None): 44 | with tf.variable_scope(scope, 'residual', [inputs]): 45 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 46 | preact = slim.batch_norm(inputs, scope='preact') 47 | if activate_before_residual: 48 | shortcut = preact 49 | else: 50 | shortcut = inputs 51 | 52 | if residual_mask is not None: 53 | # Max-pooling trick only works correctly when stride is 1. 54 | # We assume that stride=2 happens in the first layer where 55 | # residual_mask is None. 
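      # conv2 below is evaluated only where residual_mask is nonzero, but its
      # 3x3 kernel reads a one-pixel neighborhood of conv1's output around
      # every active position. Dilating the mask with a 3x3 max-pool therefore
      # makes conv1 compute every value that a kept position of conv2 can
      # touch.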
56 | assert stride == 1 57 | diluted_residual_mask = slim.max_pool2d( 58 | residual_mask, [3, 3], stride=1, padding='SAME') 59 | else: 60 | diluted_residual_mask = None 61 | 62 | flops = 0 63 | conv_output, current_flops = flopsometer.conv2d( 64 | preact, 65 | depth, 66 | 3, 67 | stride=stride, 68 | padding='SAME', 69 | output_mask=diluted_residual_mask, 70 | scope='conv1') 71 | flops += current_flops 72 | 73 | conv_output, current_flops = flopsometer.conv2d( 74 | conv_output, 75 | depth, 76 | 3, 77 | stride=1, 78 | padding='SAME', 79 | activation_fn=None, 80 | normalizer_fn=None, 81 | output_mask=residual_mask, 82 | scope='conv2') 83 | flops += current_flops 84 | 85 | if depth_in != depth: 86 | shortcut = slim.avg_pool2d(shortcut, stride, stride, padding='VALID') 87 | value = (depth - depth_in) // 2 88 | shortcut = tf.pad(shortcut, [[0, 0], [0, 0], [0, 0], [value, value]]) 89 | 90 | if residual_mask is not None: 91 | conv_output *= residual_mask 92 | 93 | outputs = shortcut + conv_output 94 | 95 | return outputs, flops 96 | 97 | 98 | def resnet(inputs, 99 | model, 100 | num_classes, 101 | model_type='vanilla', 102 | base_channels=16, 103 | scope='resnet_residual'): 104 | """Builds a CIFAR-10 resnet model.""" 105 | num_blocks = 3 106 | num_units = model 107 | if len(num_units) == 1: 108 | num_units *= num_blocks 109 | assert len(num_units) == num_blocks 110 | 111 | b = resnet_utils.Block 112 | bc = base_channels 113 | blocks = [ 114 | b('block_1', residual, 115 | [(bc, 1, True)] + [(bc, 1, False)] * (num_units[0] - 1)), 116 | b('block_2', residual, 117 | [(2 * bc, 2, False)] + [(2 * bc, 1, False)] * (num_units[1] - 1)), 118 | b('block_3', residual, 119 | [(4 * bc, 2, False)] + [(4 * bc, 1, False)] * (num_units[2] - 1)) 120 | ] 121 | 122 | with tf.variable_scope(scope, [inputs]): 123 | end_points = {'inputs': inputs} 124 | end_points['flops'] = 0 125 | net = inputs 126 | net, current_flops = flopsometer.conv2d( 127 | net, bc, 3, activation_fn=None, normalizer_fn=None) 128 | end_points['flops'] += current_flops 129 | net, end_points = resnet_act.stack_blocks( 130 | net, 131 | blocks, 132 | model_type=model_type, 133 | end_points=end_points) 134 | net = tf.reduce_mean(net, [1, 2], keep_dims=True) 135 | net = slim.batch_norm(net) 136 | net, current_flops = flopsometer.conv2d( 137 | net, 138 | num_classes, [1, 1], 139 | activation_fn=None, 140 | normalizer_fn=None, 141 | scope='logits') 142 | end_points['flops'] += current_flops 143 | net = tf.squeeze(net, [1, 2], name='SpatialSqueeze') 144 | 145 | return net, end_points 146 | 147 | 148 | def resnet_arg_scope(is_training=True): 149 | """Sets up the default arguments for the CIFAR-10 resnet model.""" 150 | batch_norm_params = { 151 | 'is_training': is_training, 152 | 'decay': 0.9, 153 | 'epsilon': 0.001, 154 | 'scale': True, 155 | # This forces batch_norm to compute the moving averages in-place 156 | # instead of using a global collection which does not work with tf.cond. 
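      # The line below is left commented out: with the default
      # updates_collections, the moving-average update ops are placed in
      # tf.GraphKeys.UPDATE_OPS and run alongside the train op created by
      # slim.learning.create_train_op. Setting 'updates_collections': None
      # would instead fold the updates into the forward pass, which is slower
      # but works inside tf.cond branches.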
157 | # 'updates_collections': None, 158 | } 159 | 160 | with slim.arg_scope([slim.conv2d, slim.batch_norm], activation_fn=lrelu): 161 | with slim.arg_scope( 162 | [slim.conv2d], 163 | weights_regularizer=slim.l2_regularizer(0.0002), 164 | weights_initializer=slim.variance_scaling_initializer(), 165 | normalizer_fn=slim.batch_norm, 166 | normalizer_params=batch_norm_params): 167 | with slim.arg_scope([slim.batch_norm], **batch_norm_params) as arg_sc: 168 | return arg_sc 169 | -------------------------------------------------------------------------------- /cifar_model_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for resnet_model.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import numpy as np 23 | import tensorflow as tf 24 | from tensorflow.contrib import slim 25 | 26 | import cifar_model 27 | import summary_utils 28 | import training_utils 29 | 30 | 31 | class CifarModelTest(tf.test.TestCase): 32 | 33 | def _runBatch(self, is_training, model_type, model=[2]): 34 | batch_size = 2 35 | height, width = 32, 32 36 | num_classes = 10 37 | 38 | with slim.arg_scope( 39 | cifar_model.resnet_arg_scope(is_training=is_training)): 40 | with self.test_session() as sess: 41 | images = tf.random_uniform((batch_size, height, width, 3)) 42 | logits, end_points = cifar_model.resnet( 43 | images, 44 | model=model, 45 | num_classes=num_classes, 46 | model_type=model_type, 47 | base_channels=1) 48 | if model_type in ('act', 'act_early_stopping', 'sact'): 49 | metrics = summary_utils.act_metric_map(end_points, 50 | not is_training) 51 | metrics.update(summary_utils.flops_metric_map(end_points, 52 | not is_training)) 53 | else: 54 | metrics = {} 55 | 56 | if is_training: 57 | labels = tf.random_uniform( 58 | (batch_size,), maxval=num_classes, dtype=tf.int32) 59 | one_hot_labels = slim.one_hot_encoding(labels, num_classes) 60 | tf.losses.softmax_cross_entropy( 61 | onehot_labels=one_hot_labels, logits=logits) 62 | if model_type in ('act', 'act_early_stopping', 'sact'): 63 | training_utils.add_all_ponder_costs(end_points, weights=1.0) 64 | total_loss = tf.losses.get_total_loss() 65 | optimizer = tf.train.MomentumOptimizer(0.1, 0.9) 66 | train_op = slim.learning.create_train_op(total_loss, optimizer) 67 | sess.run(tf.global_variables_initializer()) 68 | sess.run((train_op, metrics)) 69 | else: 70 | sess.run([tf.local_variables_initializer(), 71 | tf.global_variables_initializer()]) 72 | logits_out, metrics_out = sess.run((logits, metrics)) 73 | self.assertEqual(logits_out.shape, (batch_size, num_classes)) 74 | 75 | def testTrainVanilla(self): 76 | self._runBatch(is_training=True, model_type='vanilla') 77 | 78 | def testTrainAct(self): 
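    # 'act' halts computation per image within each block, while 'sact'
    # (exercised in testTrainSact below) lets every spatial position halt
    # independently.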
79 | self._runBatch(is_training=True, model_type='act') 80 | 81 | def testTrainSact(self): 82 | self._runBatch(is_training=True, model_type='sact') 83 | 84 | def testTestVanilla(self): 85 | self._runBatch(is_training=False, model_type='vanilla') 86 | 87 | def testTestVanillaResidualUnits(self): 88 | self._runBatch( 89 | is_training=False, model_type='vanilla', model=[1, 2, 3]) 90 | 91 | def testTestAct(self): 92 | self._runBatch(is_training=False, model_type='act') 93 | 94 | def testTestSact(self): 95 | self._runBatch(is_training=False, model_type='sact') 96 | 97 | def testFlopsVanilla(self): 98 | batch_size = 3 99 | height, width = 32, 32 100 | num_classes = 10 101 | 102 | with slim.arg_scope(cifar_model.resnet_arg_scope(is_training=False)): 103 | with self.test_session() as sess: 104 | images = tf.random_uniform((batch_size, height, width, 3)) 105 | _, end_points = cifar_model.resnet( 106 | images, 107 | model=[18], 108 | num_classes=num_classes, 109 | model_type='vanilla') 110 | flops = sess.run(end_points['flops']) 111 | # TF graph_metrics value: 506307850 (0.1% difference) 112 | expected_flops = 505775360 113 | self.assertAllEqual(flops, [expected_flops] * 3) 114 | 115 | def testVisualizationBasic(self): 116 | batch_size = 3 117 | height, width = 32, 32 118 | num_classes = 10 119 | is_training = False 120 | num_images = 2 121 | border = 5 122 | 123 | with slim.arg_scope(cifar_model.resnet_arg_scope(is_training=is_training)): 124 | with self.test_session() as sess: 125 | images = tf.random_uniform((batch_size, height, width, 3)) 126 | logits, end_points = cifar_model.resnet( 127 | images, 128 | model=[2], 129 | num_classes=num_classes, 130 | model_type='sact', 131 | base_channels=1) 132 | 133 | vis_ponder = summary_utils.sact_image_heatmap( 134 | end_points, 135 | 'ponder_cost', 136 | num_images=num_images, 137 | alpha=0.75, 138 | border=border) 139 | vis_units = summary_utils.sact_image_heatmap( 140 | end_points, 141 | 'num_units', 142 | num_images=num_images, 143 | alpha=0.75, 144 | border=border) 145 | 146 | sess.run(tf.global_variables_initializer()) 147 | vis_ponder_out, vis_units_out = sess.run( 148 | [vis_ponder, vis_units]) 149 | self.assertEqual(vis_ponder_out.shape, 150 | (num_images, height, width * 2 + border, 3)) 151 | self.assertEqual(vis_units_out.shape, 152 | (num_images, height, width * 2 + border, 3)) 153 | 154 | 155 | if __name__ == '__main__': 156 | tf.test.main() 157 | -------------------------------------------------------------------------------- /draw_ponder_maps.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | 17 | """Draws example ponder cost maps""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import os 24 | 25 | import h5py 26 | import matplotlib 27 | matplotlib.use('agg') # disables drawing to X 28 | import matplotlib.pyplot as plt 29 | import numpy as np 30 | import tensorflow as tf 31 | 32 | FLAGS = tf.app.flags.FLAGS 33 | 34 | tf.app.flags.DEFINE_string('input_file', None, 35 | 'An HDF5 file produced by imagenet_export.') 36 | 37 | tf.app.flags.DEFINE_string('output_dir', None, 38 | 'The directory to output the plotted ponder maps to.') 39 | 40 | 41 | def main(_): 42 | f = h5py.File(FLAGS.input_file, 'r') 43 | 44 | num_images = f['images'].shape[0] 45 | ponder_cost = np.array(f['ponder_cost_map']) 46 | min_ponder = np.percentile(ponder_cost.ravel(), 0.1) 47 | max_ponder = np.percentile(ponder_cost.ravel(), 99.9) 48 | print('0.1st percentile of ponder cost {:.2f} '.format(min_ponder)) 49 | print('99.9th percentile of ponder cost {:.2f} '.format(max_ponder)) 50 | 51 | fig = plt.figure(figsize=(0.2, 2)) 52 | ax = fig.add_axes([0.0, 0.0, 1.0, 1.0]) 53 | cb = matplotlib.colorbar.ColorbarBase( 54 | ax, cmap='viridis', 55 | norm=matplotlib.colors.Normalize(vmin=min_ponder, vmax=max_ponder)) 56 | ax.tick_params(labelsize=12) 57 | filename = os.path.join(FLAGS.output_dir, 'colorbar.pdf') 58 | plt.savefig(filename, bbox_inches='tight') 59 | 60 | for i in range(num_images): 61 | current_map = np.squeeze(f['ponder_cost_map'][i]) 62 | mean_ponder = np.mean(current_map) 63 | filename = '{}/{:.2f}_{}_ponder.png'.format(FLAGS.output_dir, mean_ponder, i) 64 | matplotlib.image.imsave( 65 | filename, current_map, cmap='viridis', vmin=min_ponder, vmax=max_ponder) 66 | 67 | im = f['images'][i] 68 | im = (im + 1.0) / 2.0 69 | filename = '{}/{:.2f}_{}_im.jpg'.format(FLAGS.output_dir, mean_ponder, i) 70 | matplotlib.image.imsave(filename, im) 71 | 72 | 73 | if __name__ == '__main__': 74 | tf.app.run() 75 | -------------------------------------------------------------------------------- /external/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/external/__init__.py -------------------------------------------------------------------------------- /external/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains utilities for downloading and converting datasets. 
16 | 17 | Copied from https://github.com/tensorflow/models/blob/master/slim/datasets/dataset_utils.py 18 | """ 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import os 24 | import sys 25 | import tarfile 26 | 27 | from six.moves import urllib 28 | import tensorflow as tf 29 | 30 | LABELS_FILENAME = 'labels.txt' 31 | 32 | 33 | def int64_feature(values): 34 | """Returns a TF-Feature of int64s. 35 | 36 | Args: 37 | values: A scalar or list of values. 38 | 39 | Returns: 40 | a TF-Feature. 41 | """ 42 | if not isinstance(values, (tuple, list)): 43 | values = [values] 44 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) 45 | 46 | 47 | def bytes_feature(values): 48 | """Returns a TF-Feature of bytes. 49 | 50 | Args: 51 | values: A string. 52 | 53 | Returns: 54 | a TF-Feature. 55 | """ 56 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) 57 | 58 | 59 | def image_to_tfexample(image_data, image_format, height, width, class_id): 60 | return tf.train.Example(features=tf.train.Features(feature={ 61 | 'image/encoded': bytes_feature(image_data), 62 | 'image/format': bytes_feature(image_format), 63 | 'image/class/label': int64_feature(class_id), 64 | 'image/height': int64_feature(height), 65 | 'image/width': int64_feature(width), 66 | })) 67 | 68 | 69 | def download_and_uncompress_tarball(tarball_url, dataset_dir): 70 | """Downloads the `tarball_url` and uncompresses it locally. 71 | 72 | Args: 73 | tarball_url: The URL of a tarball file. 74 | dataset_dir: The directory where the temporary files are stored. 75 | """ 76 | filename = tarball_url.split('/')[-1] 77 | filepath = os.path.join(dataset_dir, filename) 78 | 79 | def _progress(count, block_size, total_size): 80 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 81 | filename, float(count * block_size) / float(total_size) * 100.0)) 82 | sys.stdout.flush() 83 | filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress) 84 | print() 85 | statinfo = os.stat(filepath) 86 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 87 | tarfile.open(filepath, 'r:gz').extractall(dataset_dir) 88 | 89 | 90 | def write_label_file(labels_to_class_names, dataset_dir, 91 | filename=LABELS_FILENAME): 92 | """Writes a file with the list of class names. 93 | 94 | Args: 95 | labels_to_class_names: A map of (integer) labels to class names. 96 | dataset_dir: The directory in which the labels file should be written. 97 | filename: The filename where the class names are written. 98 | """ 99 | labels_filename = os.path.join(dataset_dir, filename) 100 | with tf.gfile.Open(labels_filename, 'w') as f: 101 | for label in labels_to_class_names: 102 | class_name = labels_to_class_names[label] 103 | f.write('%d:%s\n' % (label, class_name)) 104 | 105 | 106 | def has_labels(dataset_dir, filename=LABELS_FILENAME): 107 | """Specifies whether or not the dataset directory contains a label map file. 108 | 109 | Args: 110 | dataset_dir: The directory in which the labels file is found. 111 | filename: The filename where the class names are written. 112 | 113 | Returns: 114 | `True` if the labels file exists and `False` otherwise. 115 | """ 116 | return tf.gfile.Exists(os.path.join(dataset_dir, filename)) 117 | 118 | 119 | def read_label_file(dataset_dir, filename=LABELS_FILENAME): 120 | """Reads the labels file and returns a mapping from ID to class name. 
121 | 122 | Args: 123 | dataset_dir: The directory in which the labels file is found. 124 | filename: The filename where the class names are written. 125 | 126 | Returns: 127 | A map from a label (integer) to class name. 128 | """ 129 | labels_filename = os.path.join(dataset_dir, filename) 130 | with tf.gfile.Open(labels_filename, 'r') as f: 131 | lines = f.read().decode() 132 | lines = lines.split('\n') 133 | lines = filter(None, lines) 134 | 135 | labels_to_class_names = {} 136 | for line in lines: 137 | index = line.index(':') 138 | labels_to_class_names[int(line[:index])] = line[index+1:] 139 | return labels_to_class_names 140 | -------------------------------------------------------------------------------- /external/datasets_cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the Cifar10 dataset. 16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | tensorflow/models/slim/datasets/download_and_convert_cifar10.py 19 | 20 | Copied from https://github.com/tensorflow/models/blob/master/slim/datasets/cifar10.py 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import os 28 | import tensorflow as tf 29 | 30 | from . import dataset_utils 31 | 32 | slim = tf.contrib.slim 33 | 34 | _FILE_PATTERN = 'cifar10_%s.tfrecord' 35 | 36 | SPLITS_TO_SIZES = {'train': 50000, 'test': 10000} 37 | 38 | _NUM_CLASSES = 10 39 | 40 | _ITEMS_TO_DESCRIPTIONS = { 41 | 'image': 'A [32 x 32 x 3] color image.', 42 | 'label': 'A single integer between 0 and 9', 43 | } 44 | 45 | 46 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 47 | """Gets a dataset tuple with instructions for reading cifar10. 48 | 49 | Args: 50 | split_name: A train/test split name. 51 | dataset_dir: The base directory of the dataset sources. 52 | file_pattern: The file pattern to use when matching the dataset sources. 53 | It is assumed that the pattern contains a '%s' string so that the split 54 | name can be inserted. 55 | reader: The TensorFlow reader type. 56 | 57 | Returns: 58 | A `Dataset` namedtuple. 59 | 60 | Raises: 61 | ValueError: if `split_name` is not a valid train/test split. 62 | """ 63 | if split_name not in SPLITS_TO_SIZES: 64 | raise ValueError('split name %s was not recognized.' % split_name) 65 | 66 | if not file_pattern: 67 | file_pattern = _FILE_PATTERN 68 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 69 | 70 | # Allowing None in the signature so that dataset_factory can use the default. 
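  # keys_to_features below tells the parser how to deserialize each field of
  # a serialized TF-Example, and items_to_handlers maps the parsed fields to
  # the 'image' and 'label' items that DatasetDataProvider exposes to callers
  # such as cifar_data_provider.provide_data.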
71 | if not reader: 72 | reader = tf.TFRecordReader 73 | 74 | keys_to_features = { 75 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 76 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 77 | 'image/class/label': tf.FixedLenFeature( 78 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 79 | } 80 | 81 | items_to_handlers = { 82 | 'image': slim.tfexample_decoder.Image(shape=[32, 32, 3]), 83 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 84 | } 85 | 86 | decoder = slim.tfexample_decoder.TFExampleDecoder( 87 | keys_to_features, items_to_handlers) 88 | 89 | labels_to_names = None 90 | if dataset_utils.has_labels(dataset_dir): 91 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 92 | 93 | return slim.dataset.Dataset( 94 | data_sources=file_pattern, 95 | reader=reader, 96 | decoder=decoder, 97 | num_samples=SPLITS_TO_SIZES[split_name], 98 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 99 | num_classes=_NUM_CLASSES, 100 | labels_to_names=labels_to_names) 101 | -------------------------------------------------------------------------------- /external/datasets_imagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the ImageNet ILSVRC 2012 Dataset plus some bounding boxes. 16 | 17 | Some images have one or more bounding boxes associated with the label of the 18 | image. See details here: http://image-net.org/download-bboxes 19 | 20 | ImageNet is based upon WordNet 3.0. To uniquely identify a synset, we use 21 | "WordNet ID" (wnid), which is a concatenation of POS ( i.e. part of speech ) 22 | and SYNSET OFFSET of WordNet. For more information, please refer to the 23 | WordNet documentation[http://wordnet.princeton.edu/wordnet/documentation/]. 24 | 25 | "There are bounding boxes for over 3000 popular synsets available. 26 | For each synset, there are on average 150 images with bounding boxes." 27 | 28 | WARNING: Don't use for object detection, in this case all the bounding boxes 29 | of the image belong to just one class. 30 | 31 | Copied from https://github.com/tensorflow/models/blob/master/slim/datasets/imagenet.py 32 | """ 33 | from __future__ import absolute_import 34 | from __future__ import division 35 | from __future__ import print_function 36 | 37 | import os 38 | from six.moves import urllib 39 | import tensorflow as tf 40 | 41 | from . import dataset_utils 42 | 43 | slim = tf.contrib.slim 44 | 45 | # TODO(nsilberman): Add tfrecord file type once the script is updated. 
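# With _FILE_PATTERN = '%s-*', get_split('train', ...) globs for shard files
# named like 'train-00000-of-00001', matching the layout of testdata/imagenet.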
46 | _FILE_PATTERN = '%s-*' 47 | 48 | _SPLITS_TO_SIZES = { 49 | 'train': 1281167, 50 | 'validation': 50000, 51 | } 52 | 53 | _ITEMS_TO_DESCRIPTIONS = { 54 | 'image': 'A color image of varying height and width.', 55 | 'label': 'The label id of the image, integer between 0 and 999', 56 | 'label_text': 'The text of the label.', 57 | 'object/bbox': 'A list of bounding boxes.', 58 | 'object/label': 'A list of labels, one per each object.', 59 | } 60 | 61 | _NUM_CLASSES = 1001 62 | 63 | 64 | def create_readable_names_for_imagenet_labels(): 65 | """Create a dict mapping label id to human readable string. 66 | 67 | Returns: 68 | labels_to_names: dictionary where keys are integers from 0 to 1000 69 | and values are human-readable names. 70 | 71 | We retrieve a synset file, which contains a list of valid synset labels used 72 | by the ILSVRC competition. There is one synset per line, e.g. 73 | # n01440764 74 | # n01443537 75 | We also retrieve a synset_to_human_file, which contains a mapping from synsets 76 | to human-readable names for every synset in ImageNet. These are stored in a 77 | TSV format, as follows: 78 | # n02119247 black fox 79 | # n02119359 silver fox 80 | We assign each synset (in alphabetical order) an integer, starting from 1 81 | (since 0 is reserved for the background class). 82 | 83 | Code is based on 84 | https://github.com/tensorflow/models/blob/master/inception/inception/data/build_imagenet_data.py#L463 85 | """ 86 | 87 | # pylint: disable=g-line-too-long 88 | base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/research/inception/inception/data/' 89 | synset_url = '{}/imagenet_lsvrc_2015_synsets.txt'.format(base_url) 90 | synset_to_human_url = '{}/imagenet_metadata.txt'.format(base_url) 91 | 92 | filename, _ = urllib.request.urlretrieve(synset_url) 93 | synset_list = [s.strip() for s in open(filename).readlines()] 94 | num_synsets_in_ilsvrc = len(synset_list) 95 | assert num_synsets_in_ilsvrc == 1000 96 | 97 | filename, _ = urllib.request.urlretrieve(synset_to_human_url) 98 | synset_to_human_list = open(filename).readlines() 99 | num_synsets_in_all_imagenet = len(synset_to_human_list) 100 | assert num_synsets_in_all_imagenet == 21842 101 | 102 | synset_to_human = {} 103 | for s in synset_to_human_list: 104 | parts = s.strip().split('\t') 105 | assert len(parts) == 2 106 | synset = parts[0] 107 | human = parts[1] 108 | synset_to_human[synset] = human 109 | 110 | label_index = 1 111 | labels_to_names = {0: 'background'} 112 | for synset in synset_list: 113 | name = synset_to_human[synset] 114 | labels_to_names[label_index] = name 115 | label_index += 1 116 | 117 | return labels_to_names 118 | 119 | 120 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 121 | """Gets a dataset tuple with instructions for reading ImageNet. 122 | 123 | Args: 124 | split_name: A train/test split name. 125 | dataset_dir: The base directory of the dataset sources. 126 | file_pattern: The file pattern to use when matching the dataset sources. 127 | It is assumed that the pattern contains a '%s' string so that the split 128 | name can be inserted. 129 | reader: The TensorFlow reader type. 130 | 131 | Returns: 132 | A `Dataset` namedtuple. 133 | 134 | Raises: 135 | ValueError: if `split_name` is not a valid train/test split. 136 | """ 137 | if split_name not in _SPLITS_TO_SIZES: 138 | raise ValueError('split name %s was not recognized.'
% split_name) 139 | 140 | if not file_pattern: 141 | file_pattern = _FILE_PATTERN 142 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 143 | 144 | # Allowing None in the signature so that dataset_factory can use the default. 145 | if reader is None: 146 | reader = tf.TFRecordReader 147 | 148 | keys_to_features = { 149 | 'image/encoded': tf.FixedLenFeature( 150 | (), tf.string, default_value=''), 151 | 'image/format': tf.FixedLenFeature( 152 | (), tf.string, default_value='jpeg'), 153 | 'image/class/label': tf.FixedLenFeature( 154 | [], dtype=tf.int64, default_value=-1), 155 | 'image/class/text': tf.FixedLenFeature( 156 | [], dtype=tf.string, default_value=''), 157 | 'image/object/bbox/xmin': tf.VarLenFeature( 158 | dtype=tf.float32), 159 | 'image/object/bbox/ymin': tf.VarLenFeature( 160 | dtype=tf.float32), 161 | 'image/object/bbox/xmax': tf.VarLenFeature( 162 | dtype=tf.float32), 163 | 'image/object/bbox/ymax': tf.VarLenFeature( 164 | dtype=tf.float32), 165 | 'image/object/class/label': tf.VarLenFeature( 166 | dtype=tf.int64), 167 | } 168 | 169 | items_to_handlers = { 170 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'), 171 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 172 | 'label_text': slim.tfexample_decoder.Tensor('image/class/text'), 173 | 'object/bbox': slim.tfexample_decoder.BoundingBox( 174 | ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'), 175 | 'object/label': slim.tfexample_decoder.Tensor('image/object/class/label'), 176 | } 177 | 178 | decoder = slim.tfexample_decoder.TFExampleDecoder( 179 | keys_to_features, items_to_handlers) 180 | 181 | labels_to_names = None 182 | if dataset_utils.has_labels(dataset_dir): 183 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 184 | else: 185 | labels_to_names = create_readable_names_for_imagenet_labels() 186 | dataset_utils.write_label_file(labels_to_names, dataset_dir) 187 | 188 | return slim.dataset.Dataset( 189 | data_sources=file_pattern, 190 | reader=reader, 191 | decoder=decoder, 192 | num_samples=_SPLITS_TO_SIZES[split_name], 193 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 194 | num_classes=_NUM_CLASSES, 195 | labels_to_names=labels_to_names) 196 | -------------------------------------------------------------------------------- /external/download_and_convert_cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Downloads and converts cifar10 data to TFRecords of TF-Example protos. 16 | 17 | This module downloads the cifar10 data, uncompresses it, reads the files 18 | that make up the cifar10 data and creates two TFRecord datasets: one for train 19 | and one for test. 
Each TFRecord dataset is comprised of a set of TF-Example 20 | protocol buffers, each of which contain a single image and label. 21 | 22 | The script should take several minutes to run. 23 | 24 | Copied from https://github.com/tensorflow/models/blob/master/slim/datasets/download_and_convert_cifar10.py 25 | """ 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import cPickle 31 | import os 32 | import sys 33 | import tarfile 34 | 35 | import numpy as np 36 | from six.moves import urllib 37 | import tensorflow as tf 38 | 39 | import dataset_utils 40 | 41 | FLAGS = tf.app.flags.FLAGS 42 | 43 | tf.app.flags.DEFINE_string( 44 | 'dataset_dir', 45 | None, 46 | 'The directory where the output TFRecords and temporary files are saved.') 47 | 48 | # The URL where the CIFAR data can be downloaded. 49 | _DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' 50 | 51 | # The number of training files. 52 | _NUM_TRAIN_FILES = 5 53 | 54 | # The height and width of each image. 55 | _IMAGE_SIZE = 32 56 | 57 | # The names of the classes. 58 | _CLASS_NAMES = [ 59 | 'airplane', 60 | 'automobile', 61 | 'bird', 62 | 'cat', 63 | 'deer', 64 | 'dog', 65 | 'frog', 66 | 'horse', 67 | 'ship', 68 | 'truck', 69 | ] 70 | 71 | 72 | def _add_to_tfrecord(filename, tfrecord_writer, offset=0): 73 | """Loads data from the cifar10 pickle files and writes files to a TFRecord. 74 | 75 | Args: 76 | filename: The filename of the cifar10 pickle file. 77 | tfrecord_writer: The TFRecord writer to use for writing. 78 | offset: An offset into the absolute number of images previously written. 79 | 80 | Returns: 81 | The new offset. 82 | """ 83 | with tf.gfile.Open(filename, 'r') as f: 84 | data = cPickle.load(f) 85 | 86 | images = data['data'] 87 | num_images = images.shape[0] 88 | 89 | images = images.reshape((num_images, 3, 32, 32)) 90 | labels = data['labels'] 91 | 92 | with tf.Graph().as_default(): 93 | image_placeholder = tf.placeholder(dtype=tf.uint8) 94 | encoded_image = tf.image.encode_png(image_placeholder) 95 | 96 | with tf.Session('') as sess: 97 | 98 | for j in range(num_images): 99 | sys.stdout.write('\r>> Reading file [%s] image %d/%d' % ( 100 | filename, offset + j + 1, offset + num_images)) 101 | sys.stdout.flush() 102 | 103 | image = np.squeeze(images[j]).transpose((1, 2, 0)) 104 | label = labels[j] 105 | 106 | png_string = sess.run(encoded_image, 107 | feed_dict={image_placeholder: image}) 108 | 109 | example = dataset_utils.image_to_tfexample( 110 | png_string, 'png', _IMAGE_SIZE, _IMAGE_SIZE, label) 111 | tfrecord_writer.write(example.SerializeToString()) 112 | 113 | return offset + num_images 114 | 115 | 116 | def _get_output_filename(dataset_dir, split_name): 117 | """Creates the output filename. 118 | 119 | Args: 120 | dataset_dir: The dataset directory where the dataset is stored. 121 | split_name: The name of the train/test split. 122 | 123 | Returns: 124 | An absolute file path. 125 | """ 126 | return '%s/cifar10_%s.tfrecord' % (dataset_dir, split_name) 127 | 128 | 129 | def _download_and_uncompress_dataset(dataset_dir): 130 | """Downloads cifar10 and uncompresses it locally. 131 | 132 | Args: 133 | dataset_dir: The directory where the temporary files are stored. 
134 | """ 135 | filename = _DATA_URL.split('/')[-1] 136 | filepath = os.path.join(dataset_dir, filename) 137 | 138 | if not os.path.exists(filepath): 139 | def _progress(count, block_size, total_size): 140 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 141 | filename, float(count * block_size) / float(total_size) * 100.0)) 142 | sys.stdout.flush() 143 | filepath, _ = urllib.request.urlretrieve(_DATA_URL, filepath, _progress) 144 | print() 145 | statinfo = os.stat(filepath) 146 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 147 | tarfile.open(filepath, 'r:gz').extractall(dataset_dir) 148 | 149 | 150 | def _clean_up_temporary_files(dataset_dir): 151 | """Removes temporary files used to create the dataset. 152 | 153 | Args: 154 | dataset_dir: The directory where the temporary files are stored. 155 | """ 156 | filename = _DATA_URL.split('/')[-1] 157 | filepath = os.path.join(dataset_dir, filename) 158 | tf.gfile.Remove(filepath) 159 | 160 | tmp_dir = os.path.join(dataset_dir, 'cifar-10-batches-py') 161 | tf.gfile.DeleteRecursively(tmp_dir) 162 | 163 | 164 | def main(_): 165 | """Runs the download and conversion operation. 166 | 167 | Args: 168 | dataset_dir: The dataset directory where the dataset is stored. 169 | """ 170 | dataset_dir = FLAGS.dataset_dir 171 | 172 | if not tf.gfile.Exists(dataset_dir): 173 | tf.gfile.MakeDirs(dataset_dir) 174 | 175 | training_filename = _get_output_filename(dataset_dir, 'train') 176 | testing_filename = _get_output_filename(dataset_dir, 'test') 177 | 178 | if tf.gfile.Exists(training_filename) and tf.gfile.Exists(testing_filename): 179 | print('Dataset files already exist. Exiting without re-creating them.') 180 | return 181 | 182 | dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 183 | 184 | # First, process the training data: 185 | with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer: 186 | offset = 0 187 | for i in range(_NUM_TRAIN_FILES): 188 | filename = os.path.join(dataset_dir, 189 | 'cifar-10-batches-py', 190 | 'data_batch_%d' % (i + 1)) # 1-indexed. 191 | offset = _add_to_tfrecord(filename, tfrecord_writer, offset) 192 | 193 | # Next, process the testing data: 194 | with tf.python_io.TFRecordWriter(testing_filename) as tfrecord_writer: 195 | filename = os.path.join(dataset_dir, 196 | 'cifar-10-batches-py', 197 | 'test_batch') 198 | _add_to_tfrecord(filename, tfrecord_writer) 199 | 200 | # Finally, write the labels file: 201 | labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES)) 202 | dataset_utils.write_label_file(labels_to_class_names, dataset_dir) 203 | 204 | _clean_up_temporary_files(dataset_dir) 205 | print('\nFinished converting the Cifar10 dataset!') 206 | 207 | 208 | if __name__ == '__main__': 209 | tf.app.run() 210 | -------------------------------------------------------------------------------- /external/inception_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images for the Inception networks. 16 | 17 | Copied from https://github.com/tensorflow/models/blob/master/slim/preprocessing/inception_preprocessing.py 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import tensorflow as tf 25 | 26 | from tensorflow.python.ops import control_flow_ops 27 | 28 | 29 | def apply_with_random_selector(x, func, num_cases): 30 | """Computes func(x, sel), with sel sampled from [0...num_cases-1]. 31 | 32 | Args: 33 | x: input Tensor. 34 | func: Python function to apply. 35 | num_cases: Python int32, number of cases to sample sel from. 36 | 37 | Returns: 38 | The result of func(x, sel), where func receives the value of the 39 | selector as a python integer, but sel is sampled dynamically. 40 | """ 41 | sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32) 42 | # Pass the real x only to one of the func calls. 43 | return control_flow_ops.merge([ 44 | func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case) 45 | for case in range(num_cases)])[0] 46 | 47 | 48 | def distort_color(image, color_ordering=0, fast_mode=True, scope=None): 49 | """Distorts the color of a Tensor image. 50 | 51 | Each color distortion is non-commutative, and thus the ordering of the color 52 | ops matters. Ideally we would randomly permute the ordering of the color ops. 53 | Rather than adding that level of complication, we select a distinct ordering 54 | of color ops for each preprocessing thread. 55 | 56 | Args: 57 | image: 3-D Tensor containing a single image in [0, 1]. 58 | color_ordering: Python int, a type of distortion (valid values: 0-3). 59 | fast_mode: Avoids slower ops (random_hue and random_contrast) 60 | scope: Optional scope for name_scope. 61 | Returns: 62 | 3-D Tensor of the color-distorted image in range [0, 1] 63 | Raises: 64 | ValueError: if color_ordering not in [0, 3] 65 | """ 66 | with tf.name_scope(scope, 'distort_color', [image]): 67 | if fast_mode: 68 | if color_ordering == 0: 69 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 70 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 71 | else: 72 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 73 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 74 | else: 75 | if color_ordering == 0: 76 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 77 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 78 | image = tf.image.random_hue(image, max_delta=0.2) 79 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 80 | elif color_ordering == 1: 81 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 82 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 83 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 84 | image = tf.image.random_hue(image, max_delta=0.2) 85 | elif color_ordering == 2: 86 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 87 | image = tf.image.random_hue(image, max_delta=0.2) 88 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
89 |         image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
90 |       elif color_ordering == 3:
91 |         image = tf.image.random_hue(image, max_delta=0.2)
92 |         image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
93 |         image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
94 |         image = tf.image.random_brightness(image, max_delta=32. / 255.)
95 |       else:
96 |         raise ValueError('color_ordering must be in [0, 3]')
97 | 
98 |     # The random_* ops do not necessarily clamp.
99 |     return tf.clip_by_value(image, 0.0, 1.0)
100 | 
101 | 
102 | def distorted_bounding_box_crop(image,
103 |                                 bbox,
104 |                                 min_object_covered=0.1,
105 |                                 aspect_ratio_range=(0.75, 1.33),
106 |                                 area_range=(0.05, 1.0),
107 |                                 max_attempts=100,
108 |                                 scope=None):
109 |   """Generates cropped_image using one of the bboxes randomly distorted.
110 | 
111 |   See `tf.image.sample_distorted_bounding_box` for more documentation.
112 | 
113 |   Args:
114 |     image: 3-D Tensor of image (it will be converted to floats in [0, 1]).
115 |     bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
116 |       where each coordinate is [0, 1) and the coordinates are arranged
117 |       as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then the whole image
118 |       is used.
119 |     min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
120 |       area of the image must contain at least this fraction of any bounding box
121 |       supplied.
122 |     aspect_ratio_range: An optional list of `floats`. The cropped area of the
123 |       image must have an aspect ratio = width / height within this range.
124 |     area_range: An optional list of `floats`. The cropped area of the image
125 |       must contain a fraction of the supplied image within this range.
126 |     max_attempts: An optional `int`. Number of attempts at generating a cropped
127 |       region of the image of the specified constraints. After `max_attempts`
128 |       failures, return the entire image.
129 |     scope: Optional scope for name_scope.
130 |   Returns:
131 |     A tuple, a 3-D Tensor cropped_image and the distorted bbox
132 |   """
133 |   with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
134 |     # Each bounding box has shape [1, num_boxes, box coords] and
135 |     # the coordinates are ordered [ymin, xmin, ymax, xmax].
136 | 
137 |     # A large fraction of image datasets contain a human-annotated bounding
138 |     # box delineating the region of the image containing the object of interest.
139 |     # We choose to create a new bounding box for the object which is a randomly
140 |     # distorted version of the human-annotated bounding box that obeys an
141 |     # allowed range of aspect ratios, sizes and overlap with the human-annotated
142 |     # bounding box. If no box is supplied, then we assume the bounding box is
143 |     # the entire image.
144 |     sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
145 |         tf.shape(image),
146 |         bounding_boxes=bbox,
147 |         min_object_covered=min_object_covered,
148 |         aspect_ratio_range=aspect_ratio_range,
149 |         area_range=area_range,
150 |         max_attempts=max_attempts,
151 |         use_image_if_no_bounding_boxes=True)
152 |     bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box
153 | 
154 |     # Crop the image to the specified bounding box.
155 |     cropped_image = tf.slice(image, bbox_begin, bbox_size)
156 |     return cropped_image, distort_bbox
157 | 
158 | 
159 | def preprocess_for_train(image, height, width, bbox,
160 |                          fast_mode=True,
161 |                          scope=None):
162 |   """Distort one image for training a network.
163 | 
164 |   Distorting images provides a useful technique for augmenting the data
165 |   set during training in order to make the network invariant to aspects
166 |   of the image that do not affect the label.
167 | 
168 |   Additionally it creates image_summaries to display the different
169 |   transformations applied to the image.
170 | 
171 |   Args:
172 |     image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
173 |       [0, 1], otherwise it is converted to tf.float32 assuming that the range
174 |       is [0, MAX], where MAX is the largest positive representable number for
175 |       int(8/16/32) data type (see `tf.image.convert_image_dtype` for details).
176 |     height: integer
177 |     width: integer
178 |     bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
179 |       where each coordinate is [0, 1) and the coordinates are arranged
180 |       as [ymin, xmin, ymax, xmax].
181 |     fast_mode: Optional boolean, if True avoids slower transformations (i.e.
182 |       bi-cubic resizing, random_hue or random_contrast).
183 |     scope: Optional scope for name_scope.
184 |   Returns:
185 |     3-D float Tensor of distorted image used for training with range [-1, 1].
186 |   """
187 |   with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
188 |     if bbox is None:
189 |       bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
190 |                          dtype=tf.float32,
191 |                          shape=[1, 1, 4])
192 |     if image.dtype != tf.float32:
193 |       image = tf.image.convert_image_dtype(image, dtype=tf.float32)
194 |     # Each bounding box has shape [1, num_boxes, box coords] and
195 |     # the coordinates are ordered [ymin, xmin, ymax, xmax].
196 |     image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
197 |                                                   bbox)
198 |     tf.summary.image('image_with_bounding_boxes', image_with_box)
199 | 
200 |     distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
201 |     # Restore the shape since the dynamic slice based upon the bbox_size loses
202 |     # the third dimension.
203 |     distorted_image.set_shape([None, None, 3])
204 |     image_with_distorted_box = tf.image.draw_bounding_boxes(
205 |         tf.expand_dims(image, 0), distorted_bbox)
206 |     tf.summary.image('images_with_distorted_bounding_box',
207 |                      image_with_distorted_box)
208 | 
209 |     # This resizing operation may distort the images because the aspect
210 |     # ratio is not respected. We select a resize method in a round robin
211 |     # fashion based on the thread number.
212 |     # Note that ResizeMethod contains 4 enumerated resizing methods.
213 | 
214 |     # We select only 1 case for fast_mode bilinear.
215 |     num_resize_cases = 1 if fast_mode else 4
216 |     distorted_image = apply_with_random_selector(
217 |         distorted_image,
218 |         lambda x, method: tf.image.resize_images(x, [height, width], method=method),
219 |         num_cases=num_resize_cases)
220 | 
221 |     tf.summary.image('cropped_resized_image',
222 |                      tf.expand_dims(distorted_image, 0))
223 | 
224 |     # Randomly flip the image horizontally.
225 |     distorted_image = tf.image.random_flip_left_right(distorted_image)
226 | 
227 |     # Randomly distort the colors. There are 4 ways to do it.
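    # Added note: apply_with_random_selector (defined above) builds all four
    # distort_color orderings as separate graph branches and routes the image
    # through exactly one of them at run time via control_flow_ops.switch/merge,
    # so only the sampled ordering is actually evaluated for a given image.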
228 |     distorted_image = apply_with_random_selector(
229 |         distorted_image,
230 |         lambda x, ordering: distort_color(x, ordering, fast_mode),
231 |         num_cases=4)
232 | 
233 |     tf.summary.image('final_distorted_image',
234 |                      tf.expand_dims(distorted_image, 0))
235 |     distorted_image = tf.subtract(distorted_image, 0.5)
236 |     distorted_image = tf.multiply(distorted_image, 2.0)
237 |     return distorted_image
238 | 
239 | 
240 | def preprocess_for_eval(image, height, width,
241 |                         central_fraction=0.875, scope=None):
242 |   """Prepare one image for evaluation.
243 | 
244 |   If height and width are specified it outputs an image of that size by
245 |   applying resize_bilinear.
246 | 
247 |   If central_fraction is specified it crops the central fraction of the
248 |   input image.
249 | 
250 |   Args:
251 |     image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
252 |       [0, 1], otherwise it is converted to tf.float32 assuming that the range
253 |       is [0, MAX], where MAX is the largest positive representable number for
254 |       int(8/16/32) data type (see `tf.image.convert_image_dtype` for details)
255 |     height: integer
256 |     width: integer
257 |     central_fraction: Optional Float, fraction of the image to crop.
258 |     scope: Optional scope for name_scope.
259 |   Returns:
260 |     3-D float Tensor of prepared image.
261 |   """
262 |   with tf.name_scope(scope, 'eval_image', [image, height, width]):
263 |     if image.dtype != tf.float32:
264 |       image = tf.image.convert_image_dtype(image, dtype=tf.float32)
265 |     # Crop the central region of the image with an area containing 87.5% of
266 |     # the original image.
267 |     if central_fraction:
268 |       image = tf.image.central_crop(image, central_fraction=central_fraction)
269 | 
270 |     if height and width:
271 |       # Resize the image to the specified height and width.
272 |       image = tf.expand_dims(image, 0)
273 |       image = tf.image.resize_bilinear(image, [height, width],
274 |                                        align_corners=False)
275 |       image = tf.squeeze(image, [0])
276 |     image = tf.subtract(image, 0.5)
277 |     image = tf.multiply(image, 2.0)
278 |     return image
279 | 
280 | 
281 | def preprocess_image(image, height, width,
282 |                      is_training=False,
283 |                      bbox=None,
284 |                      fast_mode=True):
285 |   """Pre-process one image for training or evaluation.
286 | 
287 |   Args:
288 |     image: 3-D Tensor [height, width, channels] with the image.
289 |     height: integer, image expected height.
290 |     width: integer, image expected width.
291 |     is_training: Boolean. If true, the image is transformed for training;
292 |       otherwise it is transformed for evaluation.
293 |     bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
294 |       where each coordinate is [0, 1) and the coordinates are arranged as
295 |       [ymin, xmin, ymax, xmax].
296 |     fast_mode: Optional boolean, if True avoids slower transformations.
297 | 
298 |   Returns:
299 |     3-D float Tensor containing an appropriately scaled image
300 | 
301 |   Raises:
302 |     ValueError: if user does not provide bounding box
303 |   """
304 |   if is_training:
305 |     return preprocess_for_train(image, height, width, bbox, fast_mode)
306 |   else:
307 |     return preprocess_for_eval(image, height, width)
308 | 
--------------------------------------------------------------------------------
/fake_cifar10.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Script to generate random data of the same format as CIFAR-10.
17 | 
18 | Creates TFRecord files with the same fields as
19 | tensorflow/models/slim/datasets/download_and_convert_cifar10.py
20 | for use in unit tests of the code that handles this data.
21 | """
22 | 
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 | 
27 | import os
28 | import StringIO
29 | 
30 | import numpy as np
31 | from PIL import Image
32 | import tensorflow as tf
33 | 
34 | from tensorflow_models.slim.datasets import dataset_utils
35 | 
36 | tf.app.flags.DEFINE_string('out_directory', 'testdata/cifar10',
37 |                            'Output directory for the test data.')
38 | 
39 | FLAGS = tf.app.flags.FLAGS
40 | 
41 | 
42 | _IMAGE_SIZE = 32
43 | 
44 | 
45 | def create_fake_data(split_name, num_examples=4):
46 |   """Writes the fake TFRecords for one split of the dataset.
47 | 
48 |   Args:
49 |     split_name: One of 'train' or 'test'.
50 |     num_examples: The number of random examples to generate and write to the
51 |       output TFRecord file.
52 |   """
53 |   output_file = os.path.join(FLAGS.out_directory,
54 |                              'cifar10_%s.tfrecord' % split_name)
55 |   writer = tf.python_io.TFRecordWriter(output_file)
56 |   for _ in range(num_examples):
57 |     image = np.random.randint(256, size=(_IMAGE_SIZE, _IMAGE_SIZE, 3),
58 |                               dtype=np.uint8)
59 |     image = Image.fromarray(image)
60 |     image_buffer = StringIO.StringIO()
61 |     image.save(image_buffer, format='png')
62 |     image_buffer = image_buffer.getvalue()
63 | 
64 |     label = 0
65 |     example = dataset_utils.image_to_tfexample(
66 |         image_buffer, 'png', _IMAGE_SIZE, _IMAGE_SIZE, label)
67 |     writer.write(example.SerializeToString())
68 |   writer.close()
69 | 
70 | 
71 | def main(_):
72 |   create_fake_data('train')
73 |   create_fake_data('test')
74 | 
75 | 
76 | if __name__ == '__main__':
77 |   tf.app.run()
78 | 
--------------------------------------------------------------------------------
/fake_imagenet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Script to generate random data of the same format as ImageNet.
17 | 
18 | Creates TFRecord files with the same fields as
19 | tensorflow/models/inception/inception/build_imagenet_data
20 | for use in unit tests of the code that handles this data.
21 | """
22 | 
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 | 
27 | import os
28 | import StringIO
29 | 
30 | import numpy as np
31 | from PIL import Image
32 | import tensorflow as tf
33 | 
34 | from inception.inception.data import build_imagenet_data
35 | 
36 | 
37 | tf.app.flags.DEFINE_string('out_directory', 'testdata/imagenet',
38 |                            'Output directory for the test data.')
39 | 
40 | FLAGS = tf.app.flags.FLAGS
41 | 
42 | 
43 | def _random_bounds(n):
44 |   x1, x2 = tuple(np.random.randint(n + 1, size=(2,)) / n)
45 |   return min(x1, x2), max(x1, x2)
46 | 
47 | 
48 | def _random_bbox(image_width, image_height):
49 |   xmin, xmax = _random_bounds(image_width)
50 |   ymin, ymax = _random_bounds(image_height)
51 |   return [xmin, ymin, xmax, ymax]
52 | 
53 | 
54 | def create_fake_data(split_name, image_width=640, image_height=480):
55 |   """Generates the fake data for a given ImageNet split.
56 | 
57 |   Args:
58 |     split_name: One of 'train' or 'validation'.
59 |     image_width: The width of the random image to generate and write as an
60 |       integer.
61 |     image_height: Integer height of the random image.
62 |   """
63 |   filename = '/tmp/fake_%s.jpg' % split_name
64 | 
65 |   image = np.random.randint(256, size=(image_height, image_width, 3),
66 |                             dtype=np.uint8)
67 |   image = Image.fromarray(image)
68 |   image_buffer = StringIO.StringIO()
69 |   image.save(image_buffer, format='jpeg')
70 |   image_buffer = image_buffer.getvalue()
71 | 
72 |   bboxes = [_random_bbox(image_width, image_height)]
73 | 
74 |   output_file = os.path.join(FLAGS.out_directory,
75 |                              '%s-00000-of-00001' % split_name)
76 |   writer = tf.python_io.TFRecordWriter(output_file)
77 |   # pylint: disable=protected-access
78 |   example = build_imagenet_data._convert_to_example(
79 |       filename, image_buffer, 0, 'n02110341', 'dalmation', bboxes,
80 |       image_height, image_width)
81 |   # pylint: enable=protected-access
82 |   writer.write(example.SerializeToString())
83 |   writer.close()
84 | 
85 | 
86 | def main(_):
87 |   create_fake_data('train')
88 |   create_fake_data('validation')
89 | 
90 | 
91 | if __name__ == '__main__':
92 |   tf.app.run()
93 | 
--------------------------------------------------------------------------------
/flopsometer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Measures FLOPS in convolution layers."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import tensorflow as tf
23 | 
24 | from tensorflow.contrib import slim
25 | from tensorflow.contrib.layers.python.layers import utils
26 | 
27 | 
28 | def conv2d(inputs, num_outputs, kernel_size, *args, **kwargs):
29 |   """A wrapper/substitute for conv2d that counts the flops.
30 | 
31 |   This counts the number of floating-point operations (flops) for a conv2d
32 |   layer, including one with a "mask." The optional keyword argument
33 |   `output_mask` specifies which of the positions in the output response map
34 |   actually need to be calculated; the rest can be discarded and are not
35 |   counted in the result.
36 | 
37 |   Since this is a wrapper around slim.conv2d, see that function for details on
38 |   the inputs/outputs.
39 | 
40 |   Args:
41 |     inputs: The input response map to the convolution.
42 |     num_outputs: The number of output channels for the convolution.
43 |     kernel_size: Spatial size of the convolution kernel.
44 |     *args: Additional positional arguments forwarded to slim.conv2d.
45 |     **kwargs: Additional keyword args forwarded to slim.conv2d.
46 |   Returns:
47 |     outputs: The result of the convolution from slim.conv2d.
48 |     flops: The operation count, a 1-D integer tensor with one entry per image.
49 |   """
50 |   output_mask = kwargs.pop('output_mask', None)
51 | 
52 |   outputs = slim.conv2d(inputs, num_outputs, kernel_size, *args, **kwargs)
53 | 
54 |   if inputs.get_shape().is_fully_defined():
55 |     inputs_shape = inputs.get_shape().as_list()
56 |     outputs_shape = outputs.get_shape().as_list()
57 |   else:
58 |     inputs_shape = tf.to_int64(tf.shape(inputs))
59 |     outputs_shape = tf.to_int64(tf.shape(outputs))
60 |   batch_size = outputs_shape[0]
61 | 
62 |   num_filters_in = inputs_shape[3]
63 |   kernel_h, kernel_w = utils.two_element_tuple(kernel_size)
64 |   if output_mask is None:
65 |     num_spatial_positions = tf.fill(
66 |         # tf.fill does not support int64 dims :-|
67 |         dims=tf.to_int32(tf.stack([batch_size])),
68 |         value=outputs_shape[1] * outputs_shape[2])
69 |   else:
70 |     num_spatial_positions = tf.reduce_sum(output_mask, [1, 2])
71 |   num_spatial_positions = tf.to_int64(num_spatial_positions)
72 | 
73 |   num_output_positions = num_spatial_positions * num_outputs
74 |   flops = 2 * num_output_positions * (kernel_h * kernel_w * num_filters_in)
75 | 
76 |   # The numbers are slightly different than TensorFlow graph_metrics since we
77 |   # ignore biases. We do not try to mimic graph_metrics because it is
78 |   # inconsistent in the treatment of biases (batch_norm makes biases "free").
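  # Added worked example (values mirror flopsometer_test.testConv2d): a 3x3
  # convolution with 4 input channels, 8 output channels and a fully-computed
  # 16x16 output map costs
  #   2 * (16 * 16 * 8) * (3 * 3 * 4) = 147,456
  # flops per image, counting each multiply-add as two operations and
  # ignoring biases, as noted above.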
79 | return outputs, flops 80 | 81 | 82 | def conv2d_same(inputs, 83 | num_outputs, 84 | kernel_size, 85 | stride, 86 | rate=1, 87 | output_mask=None, 88 | scope=None): 89 | """Version of TF-Slim resnet_utils.conv2d_same that uses the flopsometer.""" 90 | if stride == 1: 91 | return conv2d( 92 | inputs, 93 | num_outputs, 94 | kernel_size, 95 | stride=1, 96 | rate=rate, 97 | padding='SAME', 98 | output_mask=output_mask, 99 | scope=scope) 100 | else: 101 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 102 | pad_total = kernel_size_effective - 1 103 | pad_beg = pad_total // 2 104 | pad_end = pad_total - pad_beg 105 | inputs = tf.pad(inputs, 106 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 107 | return conv2d( 108 | inputs, 109 | num_outputs, 110 | kernel_size, 111 | stride=stride, 112 | rate=rate, 113 | padding='VALID', 114 | output_mask=output_mask, 115 | scope=scope) 116 | -------------------------------------------------------------------------------- /flopsometer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for flopsometer.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import numpy as np 23 | import tensorflow as tf 24 | 25 | import flopsometer 26 | 27 | 28 | class FlopsometerTest(tf.test.TestCase): 29 | 30 | def testConv2d(self): 31 | inputs = tf.zeros([2, 16, 16, 4]) 32 | _, flops = flopsometer.conv2d( 33 | inputs, 8, [3, 3], stride=1, padding='SAME', output_mask=None) 34 | expected_flops = 2 * 16 * 16 * 3 * 3 * 8 * 4 35 | with self.test_session() as sess: 36 | sess.run(tf.global_variables_initializer()) 37 | flops_out = sess.run(flops) 38 | self.assertAllEqual(flops_out, [expected_flops, expected_flops]) 39 | 40 | def testConv2dUnknownSize(self): 41 | inputs = np.zeros([2, 16, 16, 4], dtype=np.float32) 42 | inputs_tf = tf.placeholder(tf.float32, shape=(2, None, None, 4)) 43 | _, flops = flopsometer.conv2d( 44 | inputs_tf, 8, [3, 3], stride=1, padding='SAME', output_mask=None) 45 | expected_flops = 2 * 16 * 16 * 3 * 3 * 8 * 4 46 | with self.test_session() as sess: 47 | sess.run(tf.global_variables_initializer()) 48 | flops_out = sess.run(flops, feed_dict={inputs_tf: inputs}) 49 | self.assertAllEqual(flops_out, [expected_flops, expected_flops]) 50 | 51 | def testConv2dStride(self): 52 | inputs = tf.zeros([2, 16, 16, 4]) 53 | _, flops = flopsometer.conv2d( 54 | inputs, 8, [3, 3], stride=2, padding='SAME', output_mask=None) 55 | output_positions = 8 * 8 56 | expected_flops = 2 * output_positions * 3 * 3 * 8 * 4 57 | with self.test_session() as sess: 58 | sess.run(tf.global_variables_initializer()) 59 | flops_out = sess.run(flops) 60 | self.assertAllEqual(flops_out, 
                        [expected_flops, expected_flops])
61 | 
62 |   def testConv2dOutputMask(self):
63 |     inputs = tf.zeros([2, 16, 16, 4])
64 |     mask = np.random.random([2, 16, 16]) <= 0.6
65 |     mask_tf = tf.constant(np.float32(mask))
66 |     _, flops = flopsometer.conv2d(
67 |         inputs, 8, [3, 3], stride=1, padding='SAME', output_mask=mask_tf)
68 | 
69 |     per_position_flops = 2 * 3 * 3 * 8 * 4
70 |     num_positions = np.sum(np.sum(np.int32(mask), 2), 1)
71 |     expected_flops = [
72 |         per_position_flops * num_positions[0],
73 |         per_position_flops * num_positions[1]
74 |     ]
75 | 
76 |     with self.test_session() as sess:
77 |       sess.run(tf.global_variables_initializer())
78 |       flops_out = sess.run(flops)
79 |       self.assertAllEqual(flops_out, expected_flops)
80 | 
81 | 
82 | if __name__ == '__main__':
83 |   tf.test.main()
84 | 
--------------------------------------------------------------------------------
/imagenet_data_provider.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Contains code for loading and preprocessing the ImageNet data."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import os
23 | 
24 | import tensorflow as tf
25 | 
26 | from tensorflow.contrib import slim
27 | from tensorflow.contrib.slim import dataset_data_provider
28 | 
29 | from external import inception_preprocessing
30 | from external import datasets_imagenet
31 | 
32 | 
33 | def provide_data(split_name, batch_size, dataset_dir=None, is_training=False,
34 |                  num_readers=4, num_preprocessing_threads=4, image_size=224):
35 |   """Provides batches of Imagenet data.
36 | 
37 |   Applies the processing in external/inception_preprocessing
38 |   to the TF-Slim ImageNet dataset class.
39 | 
40 |   Args:
41 |     split_name: Either 'train' or 'validation'.
42 |     batch_size: The number of images in each batch.
43 |     dataset_dir: Directory where the ImageNet TFRecord files live.
44 |       Defaults to "~/tensorflow/data/imagenet"
45 |     is_training: Whether to apply data augmentation and shuffling.
46 |     num_readers: Number of parallel readers. Always set to one for evaluation.
47 |     num_preprocessing_threads: Number of preprocessing threads.
48 |     image_size: The height and width to which the images are resized.
49 |   Returns:
50 |     images: A `Tensor` of size [batch_size, image_size, image_size, 3]
51 |     one_hot_labels: A `Tensor` of size [batch_size, num_classes], where
52 |       each row has a single element set to one and the rest set to zeros.
53 |     dataset.num_samples: The number of total samples in the dataset.
54 |     dataset.num_classes: The number of object classes in the dataset.
55 | 
56 |   Raises:
57 |     ValueError: if the split_name is not either 'train' or 'validation'.
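
  Note (added for clarity): when is_training is False, num_readers is forced
  to 1 and shuffling is disabled below, so evaluation reads the data in a
  fixed, deterministic order.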
58 | """ 59 | 60 | with tf.device('/cpu:0'): 61 | if dataset_dir is None: 62 | dataset_dir = os.path.expanduser('~/tensorflow/data/imagenet') 63 | 64 | if not is_training: 65 | num_readers = 1 66 | 67 | dataset = datasets_imagenet.get_split(split_name, dataset_dir) 68 | provider = dataset_data_provider.DatasetDataProvider( 69 | dataset, 70 | num_readers=num_readers, 71 | shuffle=is_training, 72 | common_queue_capacity=5 * batch_size, 73 | common_queue_min=batch_size) 74 | 75 | [image, bbox, label] = provider.get(['image', 'object/bbox', 'label']) 76 | bbox = tf.expand_dims(bbox, 0) 77 | 78 | image = inception_preprocessing.preprocess_image( 79 | image, image_size, image_size, is_training, bbox, fast_mode=False) 80 | 81 | images, labels = tf.train.batch( 82 | [image, label], 83 | batch_size=batch_size, 84 | num_threads=num_preprocessing_threads, 85 | capacity=5 * batch_size) 86 | 87 | one_hot_labels = tf.one_hot(labels, dataset.num_classes) 88 | 89 | return images, one_hot_labels, dataset.num_samples, dataset.num_classes 90 | -------------------------------------------------------------------------------- /imagenet_data_provider_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for imagenet_data_provider.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | from tensorflow.contrib import slim 24 | 25 | import imagenet_data_provider 26 | 27 | 28 | class ImagenetDataProviderTest(tf.test.TestCase): 29 | 30 | def _testImageNet(self, split_name, is_training, expected_num_samples): 31 | images, one_hot_labels, num_samples, num_classes = \ 32 | imagenet_data_provider.provide_data(split_name, 1, 33 | dataset_dir='testdata/imagenet', 34 | is_training=is_training) 35 | self.assertEqual(num_samples, expected_num_samples) 36 | self.assertEqual(num_classes, 1001) 37 | with self.test_session() as sess: 38 | with slim.queues.QueueRunners(sess): 39 | images_out, one_hot_labels_out = sess.run([images, one_hot_labels]) 40 | self.assertEqual(images_out.shape, (1, 224, 224, 3)) 41 | self.assertEqual(one_hot_labels_out.shape, (1, 1001)) 42 | 43 | def testImageNetTrainSet(self): 44 | self._testImageNet('train', True, 1281167) 45 | 46 | def testImageNetValidationSet(self): 47 | self._testImageNet('validation', False, 50000) 48 | 49 | 50 | if __name__ == '__main__': 51 | tf.test.main() 52 | -------------------------------------------------------------------------------- /imagenet_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Evaluates a trained ResNet model.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import math 23 | 24 | import tensorflow as tf 25 | from tensorflow.contrib import slim 26 | 27 | import imagenet_data_provider 28 | import imagenet_model 29 | import summary_utils 30 | import utils 31 | 32 | FLAGS = tf.app.flags.FLAGS 33 | 34 | tf.app.flags.DEFINE_string('master', '', 35 | 'Name of the TensorFlow master to use.') 36 | 37 | tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/resnet/', 38 | 'Directory where the model was written to.') 39 | 40 | tf.app.flags.DEFINE_string('eval_dir', '/tmp/resnet/', 41 | 'Directory where the results are saved to.') 42 | 43 | tf.app.flags.DEFINE_string('dataset_dir', None, 'Directory with Imagenet data.') 44 | 45 | tf.app.flags.DEFINE_integer('eval_interval_secs', 600, 46 | 'The frequency, in seconds, with which evaluation is run.') 47 | 48 | tf.app.flags.DEFINE_integer('num_examples', 50000, 49 | 'The number of examples to evaluate') 50 | 51 | tf.app.flags.DEFINE_integer( 52 | 'batch_size', 32, 53 | 'The number of examples to evaluate per evaluation iteration.') 54 | 55 | tf.app.flags.DEFINE_string( 56 | 'split_name', 'validation', 57 | 'The name of the train/test split, either \'train\' or \'validation\'.') 58 | 59 | tf.app.flags.DEFINE_float('moving_average_decay', 0.9999, 60 | 'The decay to use for the moving average.') 61 | 62 | tf.app.flags.DEFINE_integer('image_size', 224, 63 | 'Image resolution for resize.') 64 | 65 | tf.app.flags.DEFINE_string( 66 | 'model', '101', 67 | 'Depth of the network to train (50, 101, 152, 200), or number of layers' 68 | ' in each block (e.g. 
3_4_23_3).') 69 | 70 | tf.app.flags.DEFINE_string( 71 | 'model_type', 'vanilla', 72 | 'Options: vanilla (basic ResNet model), act (Adaptive Computation Time), ' 73 | 'act_early_stopping (act implementation which actually saves time), ' 74 | 'sact (Spatially Adaptive Computation Time)') 75 | 76 | tf.app.flags.DEFINE_float('tau', 1.0, 'The value of tau (ponder relative cost).') 77 | 78 | tf.app.flags.DEFINE_bool('evaluate_once', False, 'Evaluate the model just once?') 79 | 80 | 81 | def main(_): 82 | g = tf.Graph() 83 | with g.as_default(): 84 | data_tuple = imagenet_data_provider.provide_data( 85 | FLAGS.split_name, 86 | FLAGS.batch_size, 87 | dataset_dir=FLAGS.dataset_dir, 88 | is_training=False, 89 | image_size=FLAGS.image_size) 90 | images, one_hot_labels, examples_per_epoch, num_classes = data_tuple 91 | 92 | # Define the model: 93 | with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)): 94 | model = utils.split_and_int(FLAGS.model) 95 | logits, end_points = imagenet_model.get_network( 96 | images, 97 | model, 98 | num_classes, 99 | model_type=FLAGS.model_type) 100 | 101 | predictions = tf.argmax(end_points['predictions'], 1) 102 | 103 | # Define the metrics: 104 | labels = tf.argmax(one_hot_labels, 1) 105 | metric_map = { 106 | 'eval/Accuracy': 107 | tf.contrib.metrics.streaming_accuracy(predictions, labels), 108 | 'eval/Recall@5': 109 | tf.contrib.metrics.streaming_sparse_recall_at_k( 110 | end_points['predictions'], tf.expand_dims(labels, 1), 5), 111 | } 112 | metric_map.update(summary_utils.flops_metric_map(end_points, True)) 113 | if FLAGS.model_type in ['act', 'act_early_stopping', 'sact']: 114 | metric_map.update(summary_utils.act_metric_map(end_points, True)) 115 | 116 | names_to_values, names_to_updates = tf.contrib.metrics.aggregate_metric_map( 117 | metric_map) 118 | 119 | for name, value in names_to_values.iteritems(): 120 | summ = tf.summary.scalar(name, value, collections=[]) 121 | summ = tf.Print(summ, [value], name) 122 | tf.add_to_collection(tf.GraphKeys.SUMMARIES, summ) 123 | 124 | if FLAGS.model_type == 'sact': 125 | summary_utils.add_heatmaps_image_summary(end_points, border=10) 126 | 127 | # This ensures that we make a single pass over all of the data. 128 | num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size)) 129 | 130 | if not FLAGS.evaluate_once: 131 | eval_function = slim.evaluation.evaluation_loop 132 | checkpoint_path = FLAGS.checkpoint_dir 133 | kwargs = {'eval_interval_secs': FLAGS.eval_interval_secs} 134 | else: 135 | eval_function = slim.evaluation.evaluate_once 136 | checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) 137 | assert checkpoint_path is not None 138 | kwargs = {} 139 | 140 | eval_function( 141 | FLAGS.master, 142 | checkpoint_path, 143 | logdir=FLAGS.eval_dir, 144 | num_evals=num_batches, 145 | eval_op=names_to_updates.values(), 146 | **kwargs) 147 | 148 | 149 | if __name__ == '__main__': 150 | tf.app.run() 151 | -------------------------------------------------------------------------------- /imagenet_export.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Exports data about a trained ResNet-ACT/SACT model into an HDF5 file."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import math
23 | 
24 | import tensorflow as tf
25 | from tensorflow.contrib import slim
26 | 
27 | import imagenet_data_provider
28 | import imagenet_model
29 | import summary_utils
30 | import utils
31 | 
32 | FLAGS = tf.app.flags.FLAGS
33 | 
34 | tf.app.flags.DEFINE_integer(
35 |     'num_examples', 1000,
36 |     'The number of examples to evaluate')
37 | 
38 | tf.app.flags.DEFINE_integer(
39 |     'batch_size', 32,
40 |     'The number of examples to evaluate per evaluation iteration.')
41 | 
42 | tf.app.flags.DEFINE_string(
43 |     'split_name', 'validation',
44 |     'The name of the train/test split, either \'train\' or \'validation\'.')
45 | 
46 | tf.app.flags.DEFINE_string(
47 |     'model', '101',
48 |     'Depth of the network to train (50, 101, 152, 200), or number of layers'
49 |     ' in each block (e.g. 3_4_23_3).')
50 | 
51 | tf.app.flags.DEFINE_string(
52 |     'model_type', 'vanilla',
53 |     'Options: act (Adaptive Computation Time), '
54 |     'act_early_stopping (act implementation which actually saves time), '
55 |     'sact (Spatially Adaptive Computation Time)')
56 | 
57 | tf.app.flags.DEFINE_string('checkpoint_dir', '',
58 |                            'Directory with the checkpoints.')
59 | 
60 | tf.app.flags.DEFINE_string('export_path', '',
61 |                            'Path to write the hdf5 file with exported data.')
62 | 
63 | tf.app.flags.DEFINE_string('dataset_dir', None, 'Directory with Imagenet data.')
64 | 
65 | 
66 | def main(_):
67 |   assert FLAGS.model_type in ('act', 'act_early_stopping', 'sact')
68 | 
69 |   g = tf.Graph()
70 |   with g.as_default():
71 |     data_tuple = imagenet_data_provider.provide_data(
72 |         FLAGS.split_name,
73 |         FLAGS.batch_size,
74 |         dataset_dir=FLAGS.dataset_dir,
75 |         is_training=False)
76 |     images, labels, _, num_classes = data_tuple
77 | 
78 |     # Define the model:
79 |     with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)):
80 |       model = utils.split_and_int(FLAGS.model)
81 |       logits, end_points = imagenet_model.get_network(
82 |           images,
83 |           model,
84 |           num_classes,
85 |           model_type=FLAGS.model_type)
86 | 
87 |     summary_utils.export_to_h5(FLAGS.checkpoint_dir, FLAGS.export_path,
88 |                                images, end_points, FLAGS.num_examples,
89 |                                FLAGS.batch_size, FLAGS.model_type=='sact')
90 | 
91 | 
92 | if __name__ == '__main__':
93 |   tf.app.run()
94 | 
--------------------------------------------------------------------------------
/imagenet_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Definition of Resnet-ACT model used for imagenet classification."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import tensorflow as tf
23 | 
24 | from tensorflow.contrib import slim
25 | from tensorflow.contrib.slim.nets import resnet_utils
26 | 
27 | import act
28 | import flopsometer
29 | import resnet_act
30 | 
31 | 
32 | def bottleneck(inputs,
33 |                depth,
34 |                depth_bottleneck,
35 |                stride,
36 |                rate=1,
37 |                residual_mask=None,
38 |                scope=None):
39 |   with tf.variable_scope(scope, 'bottleneck_v2', [inputs]) as sc:
40 |     flops = 0
41 | 
42 |     depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
43 |     preact = slim.batch_norm(inputs, activation_fn=tf.nn.relu, scope='preact')
44 |     if depth == depth_in:
45 |       shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
46 |     else:
47 |       shortcut, current_flops = flopsometer.conv2d(
48 |           preact,
49 |           depth, [1, 1],
50 |           stride=stride,
51 |           normalizer_fn=None,
52 |           activation_fn=None,
53 |           scope='shortcut')
54 |       flops += current_flops
55 | 
56 |     if residual_mask is not None:
57 |       # Max-pooling trick only works correctly when stride is 1.
58 |       # We assume that stride=2 happens in the first layer where
59 |       # residual_mask is None.
60 |       assert stride == 1
61 |       dilated_residual_mask = slim.max_pool2d(
62 |           residual_mask, [3, 3], stride=1, padding='SAME')
63 |     else:
64 |       dilated_residual_mask = None
65 | 
66 |     residual, current_flops = flopsometer.conv2d(
67 |         preact,
68 |         depth_bottleneck, [1, 1],
69 |         stride=1,
70 |         output_mask=dilated_residual_mask,
71 |         scope='conv1')
72 |     flops += current_flops
73 | 
74 |     residual, current_flops = flopsometer.conv2d_same(
75 |         residual,
76 |         depth_bottleneck,
77 |         3,
78 |         stride,
79 |         rate=rate,
80 |         output_mask=residual_mask,
81 |         scope='conv2')
82 |     flops += current_flops
83 | 
84 |     residual, current_flops = flopsometer.conv2d(
85 |         residual,
86 |         depth, [1, 1],
87 |         stride=1,
88 |         normalizer_fn=None,
89 |         activation_fn=None,
90 |         output_mask=residual_mask,
91 |         scope='conv3')
92 |     flops += current_flops
93 | 
94 |     if residual_mask is not None:
95 |       residual *= residual_mask
96 | 
97 |     outputs = shortcut + residual
98 | 
99 |     return outputs, flops
100 | 
101 | 
102 | def resnet_v2(inputs,
103 |               blocks,
104 |               num_classes=None,
105 |               global_pool=True,
106 |               model_type='vanilla',
107 |               scope=None,
108 |               reuse=None,
109 |               end_points=None):
110 |   with tf.variable_scope(scope, 'resnet_v2', [inputs], reuse=reuse) as sc:
111 |     if end_points is None:
112 |       end_points = {}
113 |     end_points['inputs'] = inputs
114 |     end_points['flops'] = end_points.get('flops', 0)
115 |     net = inputs
116 |     # We do not include batch normalization or activation functions in conv1
117 |     # because the first ResNet unit will perform these. Cf. Appendix of [2].
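    # Added note: end_points['flops'] accumulates the per-image counts
    # returned by each flopsometer call below, so its final value is a 1-D
    # tensor with one entry per batch element.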
118 | with slim.arg_scope([slim.conv2d], activation_fn=None, normalizer_fn=None): 119 | net, current_flops = flopsometer.conv2d_same( 120 | net, 64, 7, stride=2, scope='conv1') 121 | end_points['flops'] += current_flops 122 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1') 123 | # Early stopping is broken in distributed training. 124 | net, end_points = resnet_act.stack_blocks( 125 | net, 126 | blocks, 127 | model_type=model_type, 128 | end_points=end_points) 129 | 130 | if global_pool or num_classes is not None: 131 | # This is needed because the pre-activation variant does not have batch 132 | # normalization or activation functions in the residual unit output. See 133 | # Appendix of [2]. 134 | net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='postnorm') 135 | 136 | if global_pool: 137 | # Global average pooling. 138 | net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) 139 | 140 | if num_classes is not None: 141 | net, current_flops = flopsometer.conv2d( 142 | net, 143 | num_classes, [1, 1], 144 | activation_fn=None, 145 | normalizer_fn=None, 146 | scope='logits') 147 | end_points['flops'] += current_flops 148 | end_points['predictions'] = slim.softmax(net, scope='predictions') 149 | return net, end_points 150 | 151 | 152 | def resnet_arg_scope(is_training=True): 153 | return resnet_utils.resnet_arg_scope(is_training) 154 | # with slim.arg_scope(resnet_utils.resnet_arg_scope(is_training)): 155 | # # This forces batch_norm to compute the moving averages in-place 156 | # # instead of using a global collection which does not work with tf.cond. 157 | # with slim.arg_scope([slim.batch_norm], updates_collections=None) as arg_sc: 158 | # return arg_sc 159 | 160 | 161 | def get_network(images, 162 | model, 163 | num_classes, 164 | model_type='vanilla', 165 | global_pool=True, 166 | base_channels=64, 167 | scope=None, 168 | reuse=None, 169 | end_points=None): 170 | # These settings are *not* compatible with Slim's ResNet v2. 171 | # In ResNet Slim the downsampling is performed by the last layer of the 172 | # current block. Here we perform downsampling in the first layer of the next 173 | # block. This is consistent with the ResNet paper. 
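  # Added note: each entry in the block definitions below is a
  # (depth, depth_bottleneck, stride) tuple for one bottleneck unit, so with
  # the default base_channels=64, model=[3, 4, 23, 3] yields ResNet-101 with
  # block1 units of (256, 64, 1).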
174 | num_blocks = 4 175 | if len(model) == 1: 176 | standard_networks = { 177 | 50: [3, 4, 6, 3], 178 | 101: [3, 4, 23, 3], 179 | 152: [3, 8, 36, 3], 180 | 200: [3, 24, 36, 3], 181 | } 182 | num_units = standard_networks[model[0]] 183 | else: 184 | num_units = model 185 | assert len(num_units) == num_blocks 186 | 187 | b = resnet_utils.Block 188 | bc = base_channels 189 | blocks = [ 190 | b('block1', bottleneck, [(4 * bc, bc, 1)] * num_units[0]), 191 | b('block2', bottleneck, 192 | [(8 * bc, 2 * bc, 2)] + [(8 * bc, 2 * bc, 1)] * (num_units[1] - 1)), 193 | b('block3', bottleneck, 194 | [(16 * bc, 4 * bc, 2)] + [(16 * bc, 4 * bc, 1)] * (num_units[2] - 1)), 195 | b('block4', bottleneck, 196 | [(32 * bc, 8 * bc, 2)] + [(32 * bc, 8 * bc, 1)] * (num_units[3] - 1)), 197 | ] 198 | 199 | logits, end_points = resnet_v2( 200 | images, 201 | blocks, 202 | num_classes, 203 | global_pool=global_pool, 204 | model_type=model_type, 205 | scope=scope, 206 | reuse=reuse, 207 | end_points=end_points) 208 | 209 | if num_classes is not None and global_pool: 210 | logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze') 211 | end_points['predictions'] = tf.squeeze( 212 | end_points['predictions'], [1, 2], name='SpatialSqueeze') 213 | return logits, end_points 214 | -------------------------------------------------------------------------------- /imagenet_model_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | 
16 | """Tests for imagenet_model."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import numpy as np
23 | import tensorflow as tf
24 | from tensorflow.contrib import slim
25 | 
26 | import imagenet_model
27 | import summary_utils
28 | import training_utils
29 | 
30 | 
31 | class ImagenetModelTest(tf.test.TestCase):
32 | 
33 |   def _runBatch(self,
34 |                 is_training,
35 |                 model_type,
36 |                 model=[2, 2, 2, 2]):
37 |     batch_size = 2
38 |     height, width = 128, 128
39 |     num_classes = 10
40 | 
41 |     with self.test_session() as sess:
42 |       images = tf.random_uniform((batch_size, height, width, 3))
43 |       with slim.arg_scope(
44 |           imagenet_model.resnet_arg_scope(is_training=is_training)):
45 |         logits, end_points = imagenet_model.get_network(
46 |             images, model, num_classes, model_type=model_type, base_channels=1)
47 |         if model_type in ('act', 'act_early_stopping', 'sact'):
48 |           metrics = summary_utils.act_metric_map(end_points,
49 |                                                  not is_training)
50 |           metrics.update(summary_utils.flops_metric_map(end_points,
51 |                                                         not is_training))
52 |         else:
53 |           metrics = {}
54 | 
55 |         if is_training:
56 |           labels = tf.random_uniform(
57 |               (batch_size,), maxval=num_classes, dtype=tf.int32)
58 |           one_hot_labels = slim.one_hot_encoding(labels, num_classes)
59 |           tf.losses.softmax_cross_entropy(
60 |               onehot_labels=one_hot_labels, logits=logits,
61 |               label_smoothing=0.1, weights=1.0)
62 |           if model_type in ('act', 'act_early_stopping', 'sact'):
63 |             training_utils.add_all_ponder_costs(end_points, weights=1.0)
64 |           total_loss = tf.losses.get_total_loss()
65 |           optimizer = tf.train.MomentumOptimizer(0.1, 0.9)
66 |           train_op = slim.learning.create_train_op(total_loss, optimizer)
67 |           sess.run(tf.global_variables_initializer())
68 |           sess.run((train_op, metrics))
69 |         else:
70 |           sess.run([tf.local_variables_initializer(),
71 |                     tf.global_variables_initializer()])
72 |           logits_out, metrics_out = sess.run((logits, metrics))
73 |           self.assertEqual(logits_out.shape, (batch_size, num_classes))
74 | 
75 |   def testTrainVanilla(self):
76 |     self._runBatch(is_training=True, model_type='vanilla')
77 | 
78 |   def testTrainAct(self):
79 |     self._runBatch(is_training=True, model_type='act')
80 | 
81 |   def testTrainSact(self):
82 |     self._runBatch(is_training=True, model_type='sact')
83 | 
84 |   def testTestVanilla(self):
85 |     self._runBatch(is_training=False, model_type='vanilla')
86 | 
87 |   def testTestAct(self):
88 |     self._runBatch(is_training=False, model_type='act')
89 | 
90 |   def testTestSact(self):
91 |     self._runBatch(is_training=False, model_type='sact')
92 | 
93 |   def testTestResNet50Model(self):
94 |     self._runBatch(is_training=False, model_type='vanilla', model=[50])
95 | 
96 |   def testFlopsVanilla(self):
97 |     batch_size = 3
98 |     height, width = 224, 224
99 |     num_classes = 1001
100 | 
101 |     with self.test_session() as sess:
102 |       images = tf.random_uniform((batch_size, height, width, 3))
103 |       with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)):
104 |         _, end_points = imagenet_model.get_network(
105 |             images, [101], num_classes, 'vanilla')
106 |       flops = sess.run(end_points['flops'])
107 |       # TF graph_metrics value: 15614055401 (0.1% difference)
108 |       expected_flops = 15602814976
109 |       self.assertAllEqual(flops, [expected_flops] * 3)
110 | 
111 |   def testVisualizationBasic(self):
112 |     batch_size = 5
113 |     height, width = 128, 128
114 |     num_classes = 10
115 | 
is_training = False 116 | num_images = 3 117 | border = 5 118 | 119 | with self.test_session() as sess: 120 | images = tf.random_uniform((batch_size, height, width, 3)) 121 | with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=is_training)): 122 | logits, end_points = imagenet_model.get_network( 123 | images, [2, 2, 2, 2], num_classes, model_type='sact', 124 | base_channels=1) 125 | 126 | vis_ponder = summary_utils.sact_image_heatmap( 127 | end_points, 128 | 'ponder_cost', 129 | num_images=num_images, 130 | alpha=0.75, 131 | border=border) 132 | vis_units = summary_utils.sact_image_heatmap( 133 | end_points, 134 | 'num_units', 135 | num_images=num_images, 136 | alpha=0.75, 137 | border=border) 138 | 139 | sess.run(tf.global_variables_initializer()) 140 | vis_ponder_out, vis_units_out = sess.run( 141 | [vis_ponder, vis_units]) 142 | self.assertEqual(vis_ponder_out.shape, 143 | (num_images, height, width * 2 + border, 3)) 144 | self.assertEqual(vis_units_out.shape, 145 | (num_images, height, width * 2 + border, 3)) 146 | 147 | 148 | if __name__ == '__main__': 149 | tf.test.main() 150 | -------------------------------------------------------------------------------- /imagenet_ponder_map.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Exports ponder cost maps for input images.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import glob 23 | import math 24 | import os 25 | 26 | import matplotlib 27 | import matplotlib.image 28 | matplotlib.use('agg') # disables drawing to X 29 | import matplotlib.pyplot as plt 30 | import tensorflow as tf 31 | from tensorflow.contrib import slim 32 | 33 | import imagenet_model 34 | import summary_utils 35 | import utils 36 | 37 | FLAGS = tf.app.flags.FLAGS 38 | 39 | tf.app.flags.DEFINE_string( 40 | 'model', '101', 41 | 'Depth of the network to train (50, 101, 152, 200), or number of layers' 42 | ' in each block (e.g. 3_4_23_3).') 43 | 44 | tf.app.flags.DEFINE_string('checkpoint_dir', '', 45 | 'Directory with the checkpoints.') 46 | 47 | tf.app.flags.DEFINE_string('images_pattern', '', 48 | 'Pattern of the JPEG images to process.') 49 | 50 | tf.app.flags.DEFINE_string('output_dir', '', 51 | 'Directory to write the results to.') 52 | 53 | tf.app.flags.DEFINE_integer( 54 | 'image_size', 0, 55 | 'Resize the input image so that the longer edge has this many pixels.' 
56 |     ' No resizing is performed if set to zero (the default).')
57 | 
58 | def preprocessing(image):
59 |   image = tf.subtract(image, 0.5)
60 |   image = tf.multiply(image, 2.0)
61 |   return image
62 | 
63 | 
64 | def reverse_preprocessing(image):
65 |   image = tf.multiply(image, 0.5)
66 |   image = tf.add(image, 0.5)
67 |   return image
68 | 
69 | 
70 | def main(_):
71 |   if not tf.gfile.Exists(FLAGS.output_dir):
72 |     tf.gfile.MakeDirs(FLAGS.output_dir)
73 | 
74 |   num_classes = 1001
75 | 
76 |   path = tf.placeholder(tf.string)
77 |   contents = tf.read_file(path)
78 |   image = tf.image.decode_jpeg(contents, channels=3)
79 |   image = tf.image.convert_image_dtype(image, dtype=tf.float32)
80 |   images = tf.expand_dims(image, 0)
81 |   images.set_shape([1, None, None, 3])
82 | 
83 |   if FLAGS.image_size:
84 |     sh = tf.shape(image)
85 |     height, width = tf.to_float(sh[0]), tf.to_float(sh[1])
86 |     longer_size = tf.constant(FLAGS.image_size, dtype=tf.float32)
87 | 
88 |     new_size = tf.cond(
89 |         height >= width,
90 |         lambda: (longer_size, (width / height) * longer_size),
91 |         lambda: ((height / width) * longer_size, longer_size))
92 |     images_resized = tf.image.resize_images(images,
93 |                                             size=tf.to_int32(tf.stack(new_size)),
94 |                                             method=tf.image.ResizeMethod.BICUBIC)
95 |   else:
96 |     images_resized = images
97 | 
98 |   images_resized = preprocessing(images_resized)
99 | 
100 |   # Define the model:
101 |   with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)):
102 |     model = utils.split_and_int(FLAGS.model)
103 |     logits, end_points = imagenet_model.get_network(
104 |         images_resized,
105 |         model,
106 |         num_classes,
107 |         model_type='sact')
108 |     ponder_cost_map = summary_utils.sact_map(end_points, 'ponder_cost')
109 | 
110 |   checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
111 |   assert checkpoint_path is not None
112 | 
113 |   saver = tf.train.Saver()
114 |   sess = tf.Session()
115 | 
116 |   saver.restore(sess, checkpoint_path)
117 | 
118 |   for current_path in glob.glob(FLAGS.images_pattern):
119 |     print('Processing {}'.format(current_path))
120 | 
121 |     [image_resized_out, ponder_cost_map_out] = sess.run(
122 |         [tf.squeeze(reverse_preprocessing(images_resized), 0),
123 |          tf.squeeze(ponder_cost_map, [0, 3])],
124 |         feed_dict={path: current_path})
125 | 
126 |     basename = os.path.splitext(os.path.basename(current_path))[0]
127 |     if FLAGS.image_size:
128 |       matplotlib.image.imsave(
129 |           os.path.join(FLAGS.output_dir, '{}_im.jpg'.format(basename)),
130 |           image_resized_out)
131 |     matplotlib.image.imsave(
132 |         os.path.join(FLAGS.output_dir, '{}_ponder.jpg'.format(basename)),
133 |         ponder_cost_map_out,
134 |         cmap='viridis')
135 | 
136 |     min_ponder = ponder_cost_map_out.min()
137 |     max_ponder = ponder_cost_map_out.max()
138 |     print('Minimum/maximum ponder cost {:.2f}/{:.2f}'.format(
139 |         min_ponder, max_ponder))
140 | 
141 |     fig = plt.figure(figsize=(0.2, 2))
142 |     ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
143 |     cb = matplotlib.colorbar.ColorbarBase(
144 |         ax, cmap='viridis',
145 |         norm=matplotlib.colors.Normalize(vmin=min_ponder, vmax=max_ponder))
146 |     ax.tick_params(labelsize=12)
147 |     filename = os.path.join(FLAGS.output_dir, '{}_colorbar.pdf'.format(basename))
148 |     plt.savefig(filename, bbox_inches='tight')
149 | 
150 | 
151 | if __name__ == '__main__':
152 |   tf.app.run()
153 | 
--------------------------------------------------------------------------------
/imagenet_train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Trains a ResNet-ACT model.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | from tensorflow.contrib import slim 24 | 25 | import imagenet_data_provider 26 | import imagenet_model 27 | import summary_utils 28 | import training_utils 29 | import utils 30 | 31 | FLAGS = tf.app.flags.FLAGS 32 | 33 | tf.app.flags.DEFINE_string('master', '', 34 | 'Name of the TensorFlow master to use.') 35 | 36 | tf.app.flags.DEFINE_string('train_log_dir', '/tmp/resnet/', 37 | 'Directory where to write event logs.') 38 | 39 | tf.app.flags.DEFINE_string( 40 | 'split_name', 'train', 41 | """The name of the train/test split, either 'train' or 'validation'.""") 42 | 43 | tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.') 44 | 45 | tf.app.flags.DEFINE_integer( 46 | 'ps_tasks', 0, 47 | 'The number of parameter servers. If the value is 0, then the parameters ' 48 | 'are handled locally by the worker.') 49 | 50 | tf.app.flags.DEFINE_integer( 51 | 'save_summaries_secs', 600, 52 | 'The frequency with which summaries are saved, in seconds.') 53 | 54 | tf.app.flags.DEFINE_integer('save_interval_secs', 600, 55 | 'The frequency with which the model is saved, in seconds.') 56 | 57 | tf.app.flags.DEFINE_integer('startup_delay_steps', 15, 58 | 'Number of training steps between the startup of consecutive replicas.') 59 | 60 | tf.app.flags.DEFINE_integer('task', 0, 'Task id of the replica running the training.') 61 | 62 | tf.app.flags.DEFINE_string('dataset_dir', None, 'Directory with ImageNet data.') 63 | 64 | # Training parameters. 65 | tf.app.flags.DEFINE_integer('batch_size', 32, 66 | 'The number of images in each batch.') 67 | 68 | tf.app.flags.DEFINE_float('learning_rate', 0.05, """Initial learning rate.""") 69 | 70 | tf.app.flags.DEFINE_float('momentum', 0.9, """Momentum.""") 71 | 72 | tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.1, 73 | 'Learning rate decay factor.') 74 | 75 | tf.app.flags.DEFINE_float('num_epochs_per_decay', 30.0, 76 | 'Number of epochs after which learning rate decays.') 77 | 78 | tf.app.flags.DEFINE_integer( 79 | 'replicas_to_aggregate', 1, 80 | 'The number of gradients to collect before updating params.') 81 | 82 | tf.app.flags.DEFINE_float('moving_average_decay', 0.9999, 83 | 'The decay to use for the moving average.') 84 | 85 | tf.app.flags.DEFINE_integer('image_size', 224, 86 | 'Resolution to which input images are resized.') 87 | 88 | tf.app.flags.DEFINE_string( 89 | 'model', '101', 90 | 'Depth of the network to train (50, 101, 152, 200), or number of layers' 91 | ' in each block (e.g. 
3_4_23_3).') 92 | 93 | tf.app.flags.DEFINE_string( 94 | 'model_type', 'vanilla', 95 | 'Options: vanilla (basic ResNet model), act (Adaptive Computation Time), ' 96 | 'act_early_stopping (act implementation which actually saves time), ' 97 | 'sact (Spatially Adaptive Computation Time)') 98 | 99 | tf.app.flags.DEFINE_float('tau', 1.0, 'Target value of tau (ponder relative cost).') 100 | 101 | tf.app.flags.DEFINE_string('finetune_path', '', 102 | 'Path for the initial checkpoint for finetuning.') 103 | 104 | 105 | def main(_): 106 | g = tf.Graph() 107 | with g.as_default(): 108 | # If ps_tasks is zero, the local device is used. When using multiple 109 | # (non-local) replicas, the ReplicaDeviceSetter distributes the variables 110 | # across the different devices. 111 | with tf.device(tf.train.replica_device_setter( 112 | FLAGS.ps_tasks, merge_devices=True)): 113 | data_tuple = imagenet_data_provider.provide_data( 114 | FLAGS.split_name, 115 | FLAGS.batch_size, 116 | dataset_dir=FLAGS.dataset_dir, 117 | is_training=True, 118 | image_size=FLAGS.image_size) 119 | images, labels, examples_per_epoch, num_classes = data_tuple 120 | 121 | # Define the model: 122 | with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=True)): 123 | model = utils.split_and_int(FLAGS.model) 124 | logits, end_points = imagenet_model.get_network( 125 | images, 126 | model, 127 | num_classes, 128 | model_type=FLAGS.model_type) 129 | 130 | # Specify the loss function: 131 | tf.losses.softmax_cross_entropy( 132 | onehot_labels=labels, logits=logits, label_smoothing=0.1, weights=1.0) 133 | if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'): 134 | training_utils.add_all_ponder_costs(end_points, weights=FLAGS.tau) 135 | total_loss = tf.losses.get_total_loss() 136 | 137 | # Configure the learning rate using an exponential decay.
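# With staircase=True below, the schedule is a step function:
# lr(step) = learning_rate * learning_rate_decay_factor ** floor(step / decay_steps),
# so with the default flags the learning rate drops 10x roughly every 30 epochs.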
138 | decay_steps = int(examples_per_epoch / FLAGS.batch_size * 139 | FLAGS.num_epochs_per_decay) 140 | 141 | learning_rate = tf.train.exponential_decay( 142 | FLAGS.learning_rate, 143 | slim.get_or_create_global_step(), 144 | decay_steps, 145 | FLAGS.learning_rate_decay_factor, 146 | staircase=True) 147 | 148 | opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum) 149 | 150 | init_fn = training_utils.finetuning_init_fn(FLAGS.finetune_path) 151 | 152 | train_tensor = slim.learning.create_train_op( 153 | total_loss, 154 | optimizer=opt, 155 | update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS)) 156 | 157 | # Summaries: 158 | tf.summary.scalar('losses/Total Loss', total_loss) 159 | tf.summary.scalar('training/Learning Rate', learning_rate) 160 | 161 | metric_map = {} # summary_utils.flops_metric_map(end_points, False) 162 | if FLAGS.model_type in ('act', 'act_early_stopping', 'sact'): 163 | metric_map.update(summary_utils.act_metric_map(end_points, False)) 164 | for name, value in metric_map.iteritems(): 165 | tf.summary.scalar(name, value) 166 | 167 | if FLAGS.model_type == 'sact': 168 | summary_utils.add_heatmaps_image_summary(end_points, border=10) 169 | 170 | startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps 171 | 172 | slim.learning.train( 173 | train_tensor, 174 | init_fn=init_fn, 175 | logdir=FLAGS.train_log_dir, 176 | master=FLAGS.master, 177 | is_chief=(FLAGS.task == 0), 178 | startup_delay_steps=startup_delay_steps, 179 | save_summaries_secs=FLAGS.save_summaries_secs, 180 | save_interval_secs=FLAGS.save_interval_secs) 181 | 182 | 183 | if __name__ == '__main__': 184 | tf.app.run() 185 | -------------------------------------------------------------------------------- /pics/20.92_93_im.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/pics/20.92_93_im.jpg -------------------------------------------------------------------------------- /pics/20.92_93_ponder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/pics/20.92_93_ponder.png -------------------------------------------------------------------------------- /pics/22.28_95_im.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/pics/22.28_95_im.jpg -------------------------------------------------------------------------------- /pics/22.28_95_ponder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/pics/22.28_95_ponder.png -------------------------------------------------------------------------------- /pics/26.75_36_im.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/pics/26.75_36_im.jpg -------------------------------------------------------------------------------- /pics/26.75_36_ponder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/pics/26.75_36_ponder.png -------------------------------------------------------------------------------- /pics/cat.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/pics/cat.jpg -------------------------------------------------------------------------------- /pics/cat_colorbar.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/pics/cat_colorbar.jpg -------------------------------------------------------------------------------- /pics/cat_ponder.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/pics/cat_ponder.jpg -------------------------------------------------------------------------------- /pics/export-image-442041-ponder.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/pics/export-image-442041-ponder.jpg -------------------------------------------------------------------------------- /pics/export-image-442041.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/pics/export-image-442041.jpg -------------------------------------------------------------------------------- /pics/gasworks.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/pics/gasworks.jpg -------------------------------------------------------------------------------- /pics/gasworks_colorbar.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/pics/gasworks_colorbar.jpg -------------------------------------------------------------------------------- /pics/gasworks_ponder.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/pics/gasworks_ponder.jpg -------------------------------------------------------------------------------- /requirements-gpu.txt: -------------------------------------------------------------------------------- 1 | h5py==2.7.0 2 | matplotlib==2.0.0 3 | nose==1.3.7 4 | six==1.10.0 5 | tensorflow-gpu==1.0.1 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==2.7.0 2 | matplotlib==2.0.0 3 | nose==1.3.7 4 | six==1.10.0 5 | tensorflow==1.0.0 6 | -------------------------------------------------------------------------------- /resnet_act.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Function for building ResNet with adaptive computation time (ACT).""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from functools import partial 23 | 24 | import h5py 25 | import tensorflow as tf 26 | 27 | from tensorflow.contrib import slim 28 | 29 | import act 30 | import flopsometer 31 | 32 | 33 | SACT_KERNEL_SIZE = 3 34 | INIT_BIAS = -3. 35 | 36 | 37 | def get_halting_proba(outputs): 38 | with tf.variable_scope('halting_proba'): 39 | x = outputs 40 | x = tf.reduce_mean(x, [1, 2], keep_dims=True) 41 | 42 | x = slim.batch_norm(x, scope='global_bn') 43 | halting_proba, flops = flopsometer.conv2d( 44 | x, 45 | 1, 46 | 1, 47 | activation_fn=tf.nn.sigmoid, 48 | normalizer_fn=None, 49 | biases_initializer=tf.constant_initializer(INIT_BIAS), 50 | scope='global_conv') 51 | halting_proba = tf.squeeze(halting_proba, [1, 2]) 52 | 53 | return halting_proba, flops 54 | 55 | 56 | def get_halting_proba_conv(outputs, residual_mask=None): 57 | with tf.variable_scope('halting_proba'): 58 | flops = 0 59 | 60 | x = outputs 61 | 62 | local_feature = slim.batch_norm(x, scope='local_bn') 63 | halting_logit, current_flops = flopsometer.conv2d( 64 | local_feature, 65 | 1, 66 | SACT_KERNEL_SIZE, 67 | activation_fn=None, 68 | normalizer_fn=None, 69 | biases_initializer=tf.constant_initializer(INIT_BIAS), 70 | output_mask=residual_mask, 71 | scope='local_conv') 72 | flops += current_flops 73 | 74 | # Add global halting logit. 75 | global_feature = tf.reduce_mean(x, [1, 2], keep_dims=True) 76 | global_feature = slim.batch_norm(global_feature, scope='global_bn') 77 | halting_logit_global, current_flops = flopsometer.conv2d( 78 | global_feature, 79 | 1, 80 | 1, 81 | activation_fn=None, 82 | normalizer_fn=None, 83 | biases_initializer=None, # biases are already present in local logits 84 | scope='global_conv') 85 | flops += current_flops 86 | 87 | # Addition with broadcasting over spatial dimensions. 
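# (halting_logit has shape [batch, height, width, 1] while halting_logit_global
# has shape [batch, 1, 1, 1], so the global logit acts as a per-image bias
# added at every spatial position.)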
88 | halting_logit += halting_logit_global 89 | 90 | halting_proba = tf.sigmoid(halting_logit) 91 | 92 | return halting_proba, flops 93 | 94 | 95 | def unit_act(block, 96 | inputs, 97 | unit_idx, 98 | skip_halting_proba=False, 99 | sact=False, 100 | residual_mask=None): 101 | with tf.variable_scope('unit_%d' % (unit_idx + 1), [inputs]): 102 | outputs, flops = block.unit_fn( 103 | inputs, *block.args[unit_idx], residual_mask=residual_mask) 104 | 105 | if not skip_halting_proba and unit_idx < len(block.args) - 1: 106 | if sact: 107 | halting_proba, current_flops = get_halting_proba_conv( 108 | outputs, residual_mask) 109 | flops += current_flops 110 | else: 111 | halting_proba, current_flops = get_halting_proba(outputs) 112 | flops += current_flops 113 | else: 114 | halting_proba = None 115 | 116 | return outputs, halting_proba, flops 117 | 118 | 119 | def stack_blocks(net, blocks, model_type, end_points=None): 120 | """Utility function for assembling SACT models consisting of 'blocks'.""" 121 | if end_points is None: 122 | end_points = {} 123 | end_points['flops'] = end_points.get('flops', 0) 124 | end_points['block_scopes'] = [block.scope for block in blocks] 125 | end_points['block_num_units'] = [len(block.args) for block in blocks] 126 | 127 | assert model_type in ('vanilla', 'act', 'act_early_stopping', 'sact') 128 | model_type_to_func = { 129 | 'act': act.adaptive_computation_time_wrapper, 130 | 'act_early_stopping': act.adaptive_computation_early_stopping, 131 | 'sact': act.spatially_adaptive_computation_time, 132 | } 133 | act_func = model_type_to_func.get(model_type, None) 134 | 135 | for block in blocks: 136 | if act_func: 137 | (ponder_cost, num_units, flops, halting_distribution, net) = act_func( 138 | net, 139 | partial(unit_act, block, sact=(model_type == 'sact')), 140 | len(block.args), 141 | scope=block.scope) 142 | 143 | end_points['{}/ponder_cost'.format(block.scope)] = ponder_cost 144 | end_points['{}/num_units'.format(block.scope)] = num_units 145 | end_points['{}/halting_distribution'.format( 146 | block.scope)] = halting_distribution 147 | else: 148 | with tf.variable_scope(block.scope, 'block', [net]): 149 | flops = 0 150 | for unit_idx in range(len(block.args)): 151 | net, _, current_flops = unit_act( 152 | block, net, unit_idx, skip_halting_proba=True) 153 | flops += current_flops 154 | 155 | end_points['{}/flops'.format(block.scope)] = flops 156 | end_points['flops'] += flops 157 | end_points[block.scope] = net 158 | 159 | return net, end_points 160 | -------------------------------------------------------------------------------- /squeeze_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Removes the Momentum and Moving Average variables, 17 | reducing the model size by a factor of 2-3.
18 | The provided pretrained models are squeezed. 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import math 26 | 27 | import tensorflow as tf 28 | from tensorflow.contrib import slim 29 | 30 | import cifar_model 31 | import imagenet_model 32 | import utils 33 | 34 | FLAGS = tf.app.flags.FLAGS 35 | 36 | tf.app.flags.DEFINE_string('input_dir', '/tmp/resnet/', 37 | 'Directory where the model was written.') 38 | 39 | tf.app.flags.DEFINE_string('output_dir', '/tmp/resnet2/', 40 | 'Directory where the squeezed model will be written.') 41 | 42 | tf.app.flags.DEFINE_string( 43 | 'model', 44 | None, 45 | 'Depth of the network (50, 101, 152, 200), or number of layers in each block (e.g. 3_4_23_3).') 46 | 47 | tf.app.flags.DEFINE_string( 48 | 'model_type', None, 49 | 'Options: vanilla (basic ResNet model), act (Adaptive Computation Time), ' 50 | 'act_early_stopping (act implementation which actually saves time), ' 51 | 'sact (Spatially Adaptive Computation Time)') 52 | 53 | tf.app.flags.DEFINE_string( 54 | 'dataset', None, 55 | 'Options: imagenet, cifar' 56 | ) 57 | 58 | 59 | def main(_): 60 | if not tf.gfile.Exists(FLAGS.output_dir): 61 | tf.gfile.MakeDirs(FLAGS.output_dir) 62 | 63 | assert FLAGS.model is not None 64 | assert FLAGS.model_type in ('vanilla', 'act', 'act_early_stopping', 'sact') 65 | assert FLAGS.dataset in ('imagenet', 'cifar') 66 | 67 | batch_size = 1 68 | 69 | if FLAGS.dataset == 'imagenet': 70 | height, width = 224, 224 71 | num_classes = 1001 72 | elif FLAGS.dataset == 'cifar': 73 | height, width = 32, 32 74 | num_classes = 10 75 | 76 | images = tf.random_uniform((batch_size, height, width, 3)) 77 | model = utils.split_and_int(FLAGS.model) 78 | 79 | # Define the model 80 | if FLAGS.dataset == 'imagenet': 81 | with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)): 82 | logits, end_points = imagenet_model.get_network( 83 | images, 84 | model, 85 | num_classes, 86 | model_type=FLAGS.model_type) 87 | elif FLAGS.dataset == 'cifar': 88 | # Define the model: 89 | with slim.arg_scope(cifar_model.resnet_arg_scope(is_training=False)): 90 | logits, end_points = cifar_model.resnet( 91 | images, 92 | model=model, 93 | num_classes=num_classes, 94 | model_type=FLAGS.model_type) 95 | 96 | tf_global_step = slim.get_or_create_global_step() 97 | 98 | checkpoint_path = tf.train.latest_checkpoint(FLAGS.input_dir) 99 | assert checkpoint_path is not None 100 | 101 | saver = tf.train.Saver(write_version=2) 102 | 103 | with tf.Session() as sess: 104 | saver.restore(sess, checkpoint_path) 105 | saver.save(sess, FLAGS.output_dir + '/model', global_step=tf_global_step) 106 | 107 | 108 | if __name__ == '__main__': 109 | tf.app.run() 110 | -------------------------------------------------------------------------------- /summary_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Summary utility functions.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import h5py 23 | import tensorflow as tf 24 | 25 | from tensorflow.contrib import slim 26 | 27 | 28 | def moments_metric_map(x, name, mean_metric, delimiter='_', do_shift=False): 29 | if not mean_metric: 30 | if do_shift: 31 | shift = tf.reduce_mean(x) # Seems to help numerical issues, but slower 32 | else: 33 | shift = None 34 | 35 | mean, var = tf.nn.moments(x, list(range(len(x.get_shape().as_list()))), 36 | shift=shift) 37 | std = tf.sqrt(tf.maximum(0.0, var)) 38 | else: 39 | mean = tf.contrib.metrics.streaming_mean(x) 40 | # Variance is estimated over the whole dataset, therefore it will 41 | # generally be higher than during training. 42 | var_value, var_update = tf.contrib.metrics.streaming_covariance(x, x) 43 | std_value = tf.sqrt(tf.maximum(0.0, var_value)) 44 | std = std_value, var_update 45 | 46 | metric_map = { 47 | '{}{}mean'.format(name, delimiter): mean, 48 | '{}{}std'.format(name, delimiter): std, 49 | } 50 | 51 | return metric_map 52 | 53 | 54 | def act_metric_map(end_points, mean_metric): 55 | """Assembles ACT-specific metrics into a map for use in tf.contrib.metrics.""" 56 | metric_map = {} 57 | 58 | for block_scope in end_points['block_scopes']: 59 | name = '{}/ponder_cost'.format(block_scope) 60 | ponder_cost = end_points[name] 61 | ponder_map = moments_metric_map(ponder_cost, name, mean_metric) 62 | metric_map.update(ponder_map) 63 | 64 | name = '{}/num_units'.format(block_scope) 65 | num_units = tf.to_float(end_points[name]) 66 | num_units_map = moments_metric_map(num_units, name, mean_metric) 67 | metric_map.update(num_units_map) 68 | 69 | if not mean_metric: 70 | # Not sure how to make a streaming version of this metric, 71 | # so tracking it only during training. 
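# (A streaming maximum could presumably be maintained in a local variable
# updated with tf.maximum across batches, but that is not implemented here.)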
72 | name = '{}/num_units_max'.format(block_scope) 73 | metric_map[name] = tf.reduce_max(num_units) 74 | 75 | return metric_map 76 | 77 | 78 | def flops_metric_map(end_points, mean_metric, total_name='Total Flops'): 79 | """Assembles flops-count metrics into a map for use in tf.contrib.metrics.""" 80 | metric_map = {} 81 | total_flops = tf.to_float(end_points['flops']) 82 | flops_map = moments_metric_map(total_flops, total_name, mean_metric, 83 | delimiter='/', do_shift=True) 84 | metric_map.update(flops_map) 85 | 86 | for block_scope in end_points['block_scopes']: 87 | name = '{}/flops'.format(block_scope) 88 | flops = tf.to_float(end_points[name]) 89 | flops_map = moments_metric_map(flops, name, mean_metric, do_shift=True) 90 | metric_map.update(flops_map) 91 | 92 | return metric_map 93 | 94 | 95 | def sact_image_heatmap(end_points, 96 | metric_name, 97 | num_images=5, 98 | alpha=0.75, 99 | border=5, 100 | normalize_images=True): 101 | """Overlays a heatmap of ponder cost or number of units onto the input images.""" 102 | assert metric_name in ('ponder_cost', 'num_units') 103 | 104 | images = end_points['inputs'] 105 | if num_images is not None: 106 | images = images[:num_images, :, :, :] 107 | else: 108 | num_images = tf.shape(images)[0] 109 | 110 | # Normalize the images 111 | if normalize_images: 112 | images -= tf.reduce_min(images, [1, 2, 3], True) 113 | images /= tf.reduce_max(images, [1, 2, 3], True) 114 | 115 | resolution = tf.shape(images)[1:3] 116 | 117 | max_value = sum(end_points['block_num_units']) 118 | if metric_name == 'ponder_cost': 119 | max_value += len(end_points['block_num_units']) 120 | 121 | heatmaps = [] 122 | for scope in end_points['block_scopes']: 123 | h = end_points['{}/{}'.format(scope, metric_name)] 124 | h = tf.to_float(h) 125 | h = h[:num_images, :, :] 126 | h = tf.expand_dims(h, 3) 127 | # The metric maps can be lower resolution than the image. 128 | # We simply resize the map to the image size. 129 | h = tf.image.resize_nearest_neighbor(h, resolution, align_corners=False) 130 | # Heatmap is in Red channel. Fill Blue and Green channels with zeros. 
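# (After the resize, h has shape [num_images, height, width, 1]; concatenating
# two zero channels below yields an RGB tensor whose red channel carries the
# metric values.)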
131 | dimensions = tf.stack([num_images, resolution[0], resolution[1], 2]) 132 | h = tf.concat([h, tf.zeros(dimensions)], 3) 133 | heatmaps.append(h) 134 | 135 | im_heatmap = images * (1 - alpha) + tf.add_n(heatmaps) * (alpha / max_value) 136 | 137 | # image, black border, image with overlaid heatmap 138 | dimensions = tf.stack([num_images, resolution[0], border, 3]) 139 | ret = tf.concat([images, tf.zeros(dimensions), im_heatmap], 2) 140 | 141 | return ret 142 | 143 | 144 | def add_heatmaps_image_summary(end_points, num_images=3, alpha=0.75, border=5): 145 | tf.summary.image( 146 | 'heatmaps/ponder_cost', 147 | sact_image_heatmap( 148 | end_points, 149 | 'ponder_cost', 150 | num_images=num_images, 151 | alpha=alpha, 152 | border=border)) 153 | tf.summary.image( 154 | 'heatmaps/num_units', 155 | sact_image_heatmap( 156 | end_points, 157 | 'num_units', 158 | num_images=num_images, 159 | alpha=alpha, 160 | border=border)) 161 | 162 | 163 | def sact_map(end_points, metric_name): 164 | """Generates a heatmap of ponder cost or number of units for visualization.""" 165 | assert metric_name in ('ponder_cost', 'num_units') 166 | 167 | inputs = end_points['inputs'] 168 | if inputs.get_shape().is_fully_defined(): 169 | sh = inputs.get_shape().as_list() 170 | else: 171 | sh = tf.shape(inputs) 172 | resolution = sh[1:3] 173 | 174 | heatmaps = [] 175 | for scope in end_points['block_scopes']: 176 | h = end_points['{}/{}'.format(scope, metric_name)] 177 | h = tf.to_float(h) 178 | h = tf.expand_dims(h, 3) 179 | # The metric maps can be lower resolution than the image. 180 | # We simply resize the map to the image size. 181 | h = tf.image.resize_nearest_neighbor(h, resolution, align_corners=False) 182 | heatmaps.append(h) 183 | 184 | return tf.add_n(heatmaps) 185 | 186 | 187 | def export_to_h5(checkpoint_dir, export_path, images, end_points, num_samples, 188 | batch_size, sact): 189 | """Exports ponder cost maps and other useful info to an HDF5 file.""" 190 | output_file = h5py.File(export_path, 'w') 191 | 192 | output_file.attrs['block_scopes'] = end_points['block_scopes'] 193 | keys_to_tensors = {} 194 | for block_scope in end_points['block_scopes']: 195 | for k in ('{}/ponder_cost'.format(block_scope), 196 | '{}/num_units'.format(block_scope), 197 | '{}/halting_distribution'.format(block_scope), 198 | '{}/flops'.format(block_scope)): 199 | keys_to_tensors[k] = end_points[k] 200 | keys_to_tensors['images'] = images 201 | keys_to_tensors['flops'] = end_points['flops'] 202 | 203 | if sact: 204 | keys_to_tensors['ponder_cost_map'] = sact_map(end_points, 'ponder_cost') 205 | keys_to_tensors['num_units_map'] = sact_map(end_points, 'num_units') 206 | 207 | keys_to_datasets = {} 208 | for key, tensor in keys_to_tensors.iteritems(): 209 | sh = tensor.get_shape().as_list() 210 | sh[0] = num_samples 211 | print(key, sh) 212 | keys_to_datasets[key] = output_file.create_dataset( 213 | key, sh, compression='lzf') 214 | 215 | variables_to_restore = slim.get_model_variables() 216 | checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir) 217 | assert checkpoint_path is not None 218 | init_fn = slim.assign_from_checkpoint_fn(checkpoint_path, 219 | variables_to_restore) 220 | 221 | sv = tf.train.Supervisor( 222 | graph=tf.get_default_graph(), 223 | logdir=None, 224 | summary_op=None, 225 | summary_writer=None, 226 | global_step=None, 227 | saver=None) 228 | 229 | assert num_samples % batch_size == 0 230 | num_batches = num_samples // batch_size 231 | 232 | with sv.managed_session('', start_standard_services=False) as sess: 
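# Standard services are not started: the Supervisor would otherwise launch
# summary and checkpoint threads, which are unnecessary here (saver and
# summary_op are None); variables are restored via init_fn and only the
# input queue runners are started.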
init_fn(sess) 234 | sv.start_queue_runners(sess) 235 | 236 | for i in range(num_batches): 237 | tf.logging.info('Evaluating batch %d/%d', i + 1, num_batches) 238 | end_points_out = sess.run(keys_to_tensors) 239 | for key, dataset in keys_to_datasets.iteritems(): 240 | dataset[i * batch_size:(i + 1) * batch_size, ...] = end_points_out[key] 241 | -------------------------------------------------------------------------------- /summary_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for summary_utils.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import numpy as np 23 | import tensorflow as tf 24 | 25 | import summary_utils 26 | 27 | 28 | class SummaryUtilsTest(tf.test.TestCase): 29 | 30 | def testSactImageHeatmap(self): 31 | batch = 9 32 | num_images = 5 33 | height, width, channels = 32, 32, 3 34 | border = 4 35 | alpha = 0.75 36 | 37 | end_points = { 38 | 'inputs': tf.ones([batch, height, width, channels]), 39 | 'block_num_units': [10], 40 | 'block_scopes': ['block_1'], 41 | 'block_1/ponder_cost': 5 * tf.ones([batch, height / 2, width / 2]), 42 | } 43 | 44 | heatmap = summary_utils.sact_image_heatmap( 45 | end_points, 46 | 'ponder_cost', 47 | num_images=num_images, 48 | alpha=alpha, 49 | border=border, 50 | normalize_images=False) 51 | 52 | with self.test_session() as sess: 53 | inputs_out, heatmap_out = sess.run([end_points['inputs'], heatmap]) 54 | 55 | self.assertEqual(heatmap_out.shape, 56 | (num_images, height, width * 2 + border, channels)) 57 | self.assertAllClose(heatmap_out[:, :, :width, :], 58 | inputs_out[:num_images, :, :, :]) 59 | 60 | expected_heatmap = 0.25 * inputs_out[:num_images, :, :, :] 61 | expected_heatmap[:, :, :, 0] += 0.75 * (5.0 / 11.0) 62 | self.assertAllClose(heatmap_out[:, :, width + border:, :], expected_heatmap) 63 | 64 | 65 | if __name__ == '__main__': 66 | tf.test.main() 67 | -------------------------------------------------------------------------------- /testdata/cifar10/cifar10_test.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/testdata/cifar10/cifar10_test.tfrecord -------------------------------------------------------------------------------- /testdata/cifar10/cifar10_train.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/testdata/cifar10/cifar10_train.tfrecord -------------------------------------------------------------------------------- /testdata/imagenet/train-00000-of-00001: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/testdata/imagenet/train-00000-of-00001 -------------------------------------------------------------------------------- /testdata/imagenet/validation-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mfigurnov/sact/1c81cbaaa9219a57c03ac3bdaeed30f13beb98e7/testdata/imagenet/validation-00000-of-00001 -------------------------------------------------------------------------------- /training_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Training utility functions.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | 25 | def add_all_ponder_costs(end_points, weights): 26 | total_ponder_cost = 0. 27 | for scope in end_points['block_scopes']: 28 | ponder_cost = end_points['{}/ponder_cost'.format(scope)] 29 | total_ponder_cost += tf.reduce_mean(ponder_cost) 30 | tf.losses.add_loss(total_ponder_cost * weights) 31 | 32 | 33 | def variables_to_str(variables): 34 | return ', '.join([var.op.name for var in variables]) 35 | 36 | 37 | def finetuning_init_fn(finetune_path): 38 | """Sets up fine-tuning of a SACT model.""" 39 | if not finetune_path: 40 | return None 41 | 42 | tf.logging.warning('Finetuning from {}'.format(finetune_path)) 43 | variables = tf.contrib.framework.get_model_variables() 44 | variables_to_restore = [ 45 | var for var in variables if '/halting_proba/' not in var.op.name 46 | ] 47 | tf.logging.info('Restoring variables: {}'.format( 48 | variables_to_str(variables_to_restore))) 49 | init_fn = tf.contrib.framework.assign_from_checkpoint_fn( 50 | finetune_path, variables_to_restore) 51 | 52 | return init_fn 53 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Miscellaneous utility functions.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | def split_and_int(s): 24 | return [int(x) for x in s.split('_')] 25 | --------------------------------------------------------------------------------