├── CHANGELIST.md ├── CONTRIBUTING.md ├── LICENSE-2.0.txt ├── README.md ├── docs ├── Bookkeeper.md ├── Loss.md ├── PrettyTensor.md ├── PrettyTensorTupleMixin.md └── pretty_tensor_top_level.md ├── inception_module.png ├── prettytensor ├── __init__.py ├── bookkeeper.py ├── bookkeeper_test.py ├── chain_dict.py ├── chain_dict_test.py ├── funcs.py ├── functions.py ├── functions_test.py ├── input_helpers.py ├── layers.py ├── local_trainer.py ├── local_trainer_test.py ├── parameters.py ├── pretty_tensor_class.py ├── pretty_tensor_image_methods.py ├── pretty_tensor_loss_methods.py ├── pretty_tensor_methods.py ├── pretty_tensor_normalization_methods.py ├── pretty_tensor_sparse_methods.py ├── pretty_tensor_test.py ├── pretty_tensor_testing.py ├── recurrent_networks.py ├── recurrent_networks_test.py ├── recurrent_networks_testing_utils.py ├── replay_queue.py ├── replay_queue_test.py ├── scopes.py ├── scopes_test.py ├── sequence_with_deltas.py ├── templated_pretty_tensor_test.py ├── train.py └── tutorial │ ├── README.md │ ├── __init__.py │ ├── baby_names.csv │ ├── baby_names.py │ ├── data_utils.py │ ├── mnist.py │ └── shakespeare.py ├── setup.py ├── test_pip_install.sh └── unrolled_lstm.png /CHANGELIST.md: -------------------------------------------------------------------------------- 1 | ## 0.7.3 2 | 3 | 1. Maintenance release w/ some deprecation notice fixes. Note: this may change the names of the summaries. 4 | 5 | ## 0.7.2 6 | 7 | 1. Maintenance release w/ change to project pip dependencies to better support GPU builds. 8 | 9 | ## 0.7.1 10 | 11 | ### General 12 | 1. Changed `weights` to `init` and `bias_init` to `bias` and made these support initialization functions or `Tensors`. 13 | 2. Added `parameter_modifier`. These are functions that are applied after creating a `Variable`, but before it is used in the graph. They allow you to apply a function like normalization or drop connect to the graph. See `pt.parameters` for details. 14 | 3. Added support for directly chaining many useful TensorFlow functions. See [pretty_tensor_methods.py](https://github.com/google/prettytensor/blob/master/prettytensor/pretty_tensor_methods.py#L700) for details. Note: when a function is removed from tf (e.g. complex_abs), it will be removed here. 15 | 1. Changed internal calls to TF to comply with API changes. 16 | 2. Internally changed the name of the first parameter to be more consistent. This should not be user visible since it is the variable to the left of the '.'. 17 | 18 | ### Losses 19 | 20 | 1. Added `per_output_weights` to `binary_cross_entropy_with_logits` and that allow you to weight the loss from classes and examples. 21 | 2. Added `sparse_cross_entropy` to efficiently calculate the loss of when you have a vector of 1 hot labels as indices (`tf.int32`/`tf.int64`). Also added `evaluate_classifier_sparse`. 22 | 3. Fixed `softmax_classifier_with_sampled_loss` to support specified parameters and parameter modification. 23 | 4. Standardized on `num_classes` and changed the parameter name in `softmax_classifier` accordingly. 24 | 25 | ### Optimizer 26 | 1. Added `clip_gradients_by_norm` to `apply_optimizer`. 27 | 28 | ### Images 29 | 30 | 1. Added a differentiable sampling method for images called `bilinear_sampling`. 31 | 32 | 33 | ## 0.6.2 34 | 35 | Add Depthwise Convolution 36 | 37 | ### Batch Normalization 38 | 1. Make Batch Normalization work with arbitrary dimensionality. 39 | 2. Allow passing through arguments to BN using a namedtuple. 40 | 3. Add BN default values. 41 | 4. 
Remove requirement to use with_update_ops to make BN accumulate values for 42 | inference. 43 | 44 | 45 | 46 | ## 0.6.0 47 | 48 | 1. Adding scoped control of summary creation. 49 | 2. Scoped variable collections. 50 | 3. Can initialize variables from literals. 51 | 4. Fixed operators -- Sequential's plus no longer has side effects. 52 | 5. Operators now work on Pretty Tensors that contain lists. 53 | 54 | 55 | Note: (4) may be breaking! 56 | 57 | ## 0.5.3 58 | 59 | 1. Fixing tutorials (thanks jkahn!) 60 | 2. Adding a precicion and recall evaluation. 61 | 3. Various bug fixes. 62 | 63 | Tested on TF 0.7.1 64 | 65 | ## 0.5.2 66 | 67 | 1. Various bug fixes 68 | 2. Reordered the arguments to a better positional order. 69 | 3. Added a length argument to recurrent networks to support short circuiting. 70 | 4. Improvements to reshape. 71 | 5. Python 3 support. 72 | 73 | ## 0.5.0 74 | 75 | Initial Release 76 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Want to contribute? Great! First, read this page (including the small print at the end). 2 | 3 | ### Before you contribute 4 | Before we can use your code, you must sign the 5 | [Google Individual Contributor License Agreement] 6 | (https://cla.developers.google.com/about/google-individual) 7 | (CLA), which you can do online. The CLA is necessary mainly because you own the 8 | copyright to your changes, even after your contribution becomes part of our 9 | codebase, so we need your permission to use and distribute your code. We also 10 | need to be sure of various other things—for instance that you'll tell us if you 11 | know that your code infringes on other people's patents. You don't have to sign 12 | the CLA until after you've submitted your code for review and a member has 13 | approved it, but you must do it before we can put your code into our codebase. 14 | Before you start working on a larger contribution, you should get in touch with 15 | us first through the issue tracker with your idea so that we can help out and 16 | possibly guide you. Coordinating up front makes it much easier to avoid 17 | frustration later on. 18 | 19 | ### Code reviews 20 | All submissions, including submissions by project members, require review. We 21 | use Github pull requests for this purpose. 22 | 23 | ### The small print 24 | Contributions made by corporations are covered by a different agreement than 25 | the one above, the 26 | [Software Grant and Corporate Contributor License Agreement] 27 | (https://cla.developers.google.com/about/google-corporate). 28 | -------------------------------------------------------------------------------- /LICENSE-2.0.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pretty Tensor - Fluent Neural Networks in TensorFlow 2 | 3 | Pretty Tensor provides a high level builder API for TensorFlow. It provides 4 | thin wrappers on Tensors so that you can easily build multi-layer neural 5 | networks. 6 | 7 | Pretty Tensor provides a set of objects that behave like Tensors, but also 8 | support a chainable object syntax to quickly define neural networks 9 | and other layered architectures in TensorFlow. 10 | 11 | result = (pretty_tensor.wrap(input_data, m) 12 | .flatten() 13 | .fully_connected(200, activation_fn=tf.nn.relu) 14 | .fully_connected(10, activation_fn=None) 15 | .softmax(labels, name=softmax_name)) 16 | 17 | Please look here for full documentation of the PrettyTensor object for all 18 | available operations: 19 | [Available Operations](docs/PrettyTensor.md), or you can check out the [complete 20 | documentation](docs/pretty_tensor_top_level.md). 21 | 22 | See the tutorial directory for samples: 23 | [tutorial/](prettytensor/tutorial/) 24 | 25 | ## Installation 26 | 27 | The easiest installation is just to use pip: 28 | 29 | 1. Follow the instructions at 30 | [tensorflow.org](https://www.tensorflow.org/versions/master/get_started/os_setup.html#pip_install) 31 | 2. `pip install prettytensor` 32 | 33 | 34 | **Note:** Head is tested against the TensorFlow nightly builds and the pip release is tested against the TensorFlow release. 35 | 36 | ## Quick start 37 | 38 | ### Imports 39 | import prettytensor as pt 40 | import tensorflow as tf 41 | 42 | ### Setup your input 43 | my_inputs = # numpy array of shape (BATCHES, BATCH_SIZE, DATA_SIZE) 44 | my_labels = # numpy array of shape (BATCHES, BATCH_SIZE, CLASSES) 45 | input_tensor = tf.placeholder(tf.float32, shape=(BATCH_SIZE, DATA_SIZE)) 46 | label_tensor = tf.placeholder(tf.float32, shape=(BATCH_SIZE, CLASSES)) 47 | pretty_input = pt.wrap(input_tensor) 48 | 49 | ### Define your model 50 | softmax, loss = (pretty_input. 51 | fully_connected(100). 52 | softmax_classifier(CLASSES, labels=label_tensor)) 53 | 54 | ### Train and evaluate 55 | accuracy = softmax.evaluate_classifier(label_tensor) 56 | 57 | optimizer = tf.train.GradientDescentOptimizer(0.1) # learning rate 58 | train_op = pt.apply_optimizer(optimizer, losses=[loss]) 59 | 60 | init_op = tf.initialize_all_variables() 61 | 62 | with tf.Session() as sess: 63 | sess.run(init_op) 64 | for inp, label in zip(my_inputs, my_labels): 65 | _, unused_loss_value, accuracy_value = sess.run([train_op, loss, accuracy], 66 | {input_tensor: inp, label_tensor: label}) 67 | print('Accuracy: %g' % accuracy_value) 68 | 69 | ## Features 70 | 71 | ### Thin 72 | 73 | #### Full power of TensorFlow is easy to use 74 | 75 | Pretty Tensors can be used (almost) everywhere that a tensor can. Just call 76 | `pt.wrap` to make a tensor pretty. 77 | 78 | You can also add any existing TensorFlow function to the chain using `apply`.
79 | `apply` applies the current Tensor as the first argument and takes all the other 80 | arguments as normal. 81 | 82 | *Note:* because apply is so generic, Pretty Tensor doesn't try to wrap the 83 | world. 84 | 85 | #### Plays well with other libraries 86 | 87 | It also uses standard TensorFlow idioms so that it plays well with other 88 | libraries; this means that you can use it a little bit in a model or throughout. 89 | Just make sure to run the update_ops on each training step 90 | (see [with_update_ops](docs/pretty_tensor_top_level.md#with_update_ops)). 91 | 92 | ### Terse 93 | 94 | You've already seen how a Pretty Tensor is chainable and you may have noticed 95 | that it takes care of handling the input shape. One other feature worth noting 96 | is defaults. Using defaults you can specify reused values in a single place 97 | without having to repeat yourself. 98 | 99 | with pt.defaults_scope(activation_fn=tf.nn.relu): 100 | hidden_output2 = (pretty_images.flatten() 101 | .fully_connected(100) 102 | .fully_connected(100)) 103 | 104 | Check out the documentation to see 105 | [all supported defaults](docs/pretty_tensor_top_level.md#defaults_scope). 106 | 107 | ### Code matches model 108 | 109 | Sequential mode lets you break model construction across lines and provides 110 | the subdivide syntactic sugar that makes it easy to define and understand 111 | complex structures like an [inception module](http://arxiv.org/abs/1409.4842): 112 | 113 | 114 | with pretty_tensor.defaults_scope(activation_fn=tf.nn.relu): 115 | seq = pretty_input.sequential() 116 | with seq.subdivide(4) as towers: 117 | towers[0].conv2d(1, 64) 118 | towers[1].conv2d(1, 112).conv2d(3, 224) 119 | towers[2].conv2d(1, 32).conv2d(5, 64) 120 | towers[3].max_pool(2, 3).conv2d(1, 32) 121 | 122 | ![Inception module showing branch and rejoin](inception_module.png) 123 | 124 | Templates provide guaranteed parameter reuse and make unrolling recurrent 125 | networks easy: 126 | 127 | output, s = [], tf.zeros([BATCH, 256 * 2]) 128 | 129 | A = (pretty_tensor.template('x') 130 | .lstm_cell(num_units=256, state=UnboundVariable('state'))) 131 | 132 | for x in pretty_input_array: 133 | h, s = A.construct(x=x, state=s) 134 | output.append(h) 135 | 136 | There are also some convenient shorthands for LSTMs and GRUs: 137 | 138 | pretty_input_array.sequence_lstm(num_units=256) 139 | 140 | ![Unrolled RNN](unrolled_lstm.png) 141 | 142 | ### Extensible 143 | 144 | You can call any existing operation by using `apply` and it will simply 145 | substitute the current tensor for the first argument. 146 | 147 | pretty_input.apply(tf.mul, 5) 148 | 149 | You can also create a new operation. There are two supported registration 150 | mechanisms to add your own functions. `@Register()` allows you to create a 151 | method on PrettyTensor that operates on the Tensors and returns either a loss or 152 | a new value. Name scoping and variable scoping are handled by the framework. 153 | 154 | The following method adds the leaky_relu method to every Pretty Tensor: 155 | 156 | @pt.Register 157 | def leaky_relu(input_pt): 158 | return tf.select(tf.greater(input_pt, 0.0), input_pt, 0.01 * input_pt) 159 | 160 | 161 | `@RegisterCompoundOp()` is like adding a macro; it is designed to group together 162 | common sets of operations. 163 | 164 | ### Safe variable reuse 165 | 166 | Within a graph, you can reuse variables by using templates. A template is 167 | just like a regular graph except that some variables are left unbound.
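For example, constructing the same template twice reuses the same underlying variables, so the two outputs below share their weights. This is a minimal sketch; `train_input` and `test_input` stand in for tensors you would supply yourself:

    shared = (pt.template('x')
              .fully_connected(200, name='shared_layer'))

    # Both constructions create/reuse the variables of 'shared_layer'.
    train_out = shared.construct(x=train_input)
    test_out = shared.construct(x=test_input)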
168 | 169 | See more details in [PrettyTensor class](docs/PrettyTensor.md). 170 | 171 | ### Accessing Variables 172 | 173 | Pretty Tensor uses the standard graph collections from TensorFlow to store variables. These can be accessed using `tf.get_collection(key)` with the following keys: 174 | 175 | * `tf.GraphKeys.VARIABLES`: all variables that should be saved (including some statistics). 176 | * `tf.GraphKeys.TRAINABLE_VARIABLES`: all variables that can be trained (including those before a `stop_gradients` call). These are what would typically be called *parameters* of the model in ML parlance. 177 | * `pt.GraphKeys.TEST_VARIABLES`: variables used to evaluate a model. These are typically not saved and are reset by the LocalRunner.evaluate method to get a fresh evaluation. 178 | 179 | ## Authors 180 | 181 | Eider Moore (eiderman) 182 | 183 | with key contributions from: 184 | 185 | * Hubert Eichner 186 | * Oliver Lange 187 | * Sagar Jain (sagarjn) 188 | -------------------------------------------------------------------------------- /docs/Bookkeeper.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Bookkeeper 4 | 5 | Small class to gather needed pieces from a Graph being built. 6 | 7 | This class is mostly an implementation detail of Pretty Tensor and almost 8 | never needs to be used when building a model. Most of the useful methods 9 | are exposed in the `pt` namespace. The most common use case for directly 10 | calling Bookkeeper methods is to create summaries in the same way as 11 | Pretty Tensor, controlled by the `pt.defaults_scope`. 12 | 13 | - - - 14 | 15 | [TOC] 16 | 17 | 18 | ## add_average_summary(var, tag=None, decay=0.999, ignore_nan=True) 19 | 20 | 21 | 22 | Add a summary with the moving average of var. 23 | 24 | Adds a variable to keep track of the exponential moving average and adds an 25 | update operation to the bookkeeper. The name of the variable is 26 | '%s_average' % name prefixed with the current variable scope. 27 | 28 | #### Args: 29 | 30 | 31 | * var: The variable for which a moving average should be computed. 32 | * tag: The tag of the summary. If None, var.name[:-2] is used to strip off 33 | the ':0' that is added by TF. 34 | * decay: How much history to use in the moving average. 35 | Higher means more history; values in [0.9, 1) are accepted. 36 | * ignore_nan: If the value is NaN or Inf, skip it. Note that this default 37 | is different than the exponential_moving_average one. 38 | 39 | #### Returns: 40 | 41 | The averaged variable. 42 | 43 | 44 | #### Raises: 45 | 46 | 47 | * ValueError: if decay is not in [0.9, 1). 48 | 49 | 50 | - - - 51 | 52 | ## add_histogram_summary(x, tag=None) 53 | 54 | 55 | 56 | Add a summary operation to visualize the histogram of x's values. 57 | 58 | 59 | 60 | 61 | 62 | - - - 63 | 64 | ## add_loss(loss, name=None, regularization=False, add_summaries=True) 65 | 66 | 67 | 68 | Append a loss to the total loss for the network. 69 | 70 | #### Args: 71 | 72 | 73 | * loss: The loss operation to append. 74 | * name: The name for this loss; defaults to loss.op.name. 75 | * regularization: Set to True if this is a regularization loss. 76 | * add_summaries: Set to True if you want to see scalar and average summaries. 77 | 78 | 79 | 80 | 81 | 82 | - - - 83 | 84 | ## add_losses(losses, regularization=False) 85 | 86 | 87 | 88 | 89 | - - - 90 | 91 | ## add_scalar_summary(x, tag=None) 92 | 93 | 94 | 95 | Adds a scalar summary for x.
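A minimal usage sketch (the loss tensor below is purely illustrative; the summary is placed in the collections selected by the active `pt.defaults_scope`):

    import prettytensor as pt
    import tensorflow as tf

    books = pt.bookkeeper_for_default_graph()
    my_loss = tf.reduce_mean(tf.square(prediction - target))  # illustrative loss
    books.add_scalar_summary(my_loss, tag='my_loss')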
96 | 97 | 98 | 99 | 100 | 101 | - - - 102 | 103 | ## create_composite_loss(losses, regularize=True, include_marked=True, name=cost) 104 | 105 | 106 | 107 | Creates a loss that is the sum of all specified losses. 108 | 109 | #### Args: 110 | 111 | 112 | * losses: A sequence of losses to include. 113 | * regularize: Whether or not to include regularization losses. 114 | * include_marked: Whether or not to use the marked losses. 115 | * name: The name for this variable. 116 | 117 | #### Returns: 118 | 119 | A single tensor that is the sum of all losses. 120 | 121 | 122 | #### Raises: 123 | 124 | 125 | * ValueError: if there are no losses. 126 | 127 | 128 | - - - 129 | 130 | ## exponential_moving_average(var, avg_var=None, decay=0.999, ignore_nan=False) 131 | 132 | 133 | 134 | Calculates the exponential moving average. 135 | 136 | TODO(): check if this implementation of moving average can now 137 | be replaced by tensorflows implementation. 138 | 139 | Adds a variable to keep track of the exponential moving average and adds an 140 | update operation to the bookkeeper. The name of the variable is 141 | '%s_average' % name prefixed with the current variable scope. 142 | 143 | #### Args: 144 | 145 | 146 | * var: The variable for which a moving average should be computed. 147 | * avg_var: The variable to set the average into, if None create a zero 148 | initialized one. 149 | * decay: How much history to use in the moving average. 150 | Higher, means more history values [0, 1) accepted. 151 | * ignore_nan: If the value is NaN or Inf, skip it. 152 | 153 | #### Returns: 154 | 155 | The averaged variable. 156 | 157 | 158 | #### Raises: 159 | 160 | 161 | * ValueError: if decay is not in [0, 1). 162 | 163 | 164 | - - - 165 | 166 | ## reset_summary_collections() 167 | 168 | 169 | 170 | Sets the summary collections to the default. 171 | 172 | 173 | 174 | 175 | 176 | - - - 177 | 178 | ## with_update_ops(train_op) 179 | 180 | 181 | 182 | 183 | - - - 184 | ## Properties 185 | 186 | * g 187 | * global_step 188 | * marked_losses 189 | * recurrent_state 190 | * regularization_losses 191 | * summaries 192 | * update_ops 193 | -------------------------------------------------------------------------------- /docs/PrettyTensorTupleMixin.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # PrettyTensorTupleMixin 4 | 5 | Adds methods to any sequence type so that it can be used with binding. 6 | 7 | Generally this can be used with namedtuples to provide clean multi-value 8 | returns: 9 | 10 | class MyType(namedtuple(...), PrettyTensorTupleMixin): 11 | pass 12 | 13 | Subclasses with nested structure should note that this does not unpack 14 | nested structure by default. You must implement flatten and 15 | build_from_flattened. 16 | 17 | - - - 18 | 19 | [TOC] 20 | 21 | 22 | ## as_fn() 23 | 24 | 25 | 26 | Creates a function by binding the arguments in the given order. 27 | 28 | #### Args: 29 | 30 | 31 | * binding_order: The unbound variables. This must include all values. 32 | 33 | #### Returns: 34 | 35 | A function that takes the arguments of binding_order. 36 | 37 | 38 | #### Raises: 39 | 40 | 41 | * ValueError: If the bindings are missing values or include unknown values. 42 | 43 | 44 | - - - 45 | 46 | ## bind() 47 | 48 | 49 | 50 | Makes the bindings to each item in this and returns a new tuple. 51 | 52 | 53 | 54 | 55 | 56 | - - - 57 | 58 | ## build_from_flattened(flattened) 59 | 60 | 61 | 62 | Given a flattened structure from flatten, make a new version of this. 
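As a sketch of the basic (non-nested) pattern from the class docstring, a namedtuple subclass picks up `bind`, `construct` and friends automatically; the type and field names below are illustrative, not part of the library:

    from collections import namedtuple
    import prettytensor as pt

    # A two-value return type; flatten()/build_from_flattened only need
    # overriding when the fields themselves contain nested structure.
    class OutputAndState(namedtuple('OutputAndState', ['output', 'state']),
                         pt.PrettyTensorTupleMixin):
      pass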
63 | 64 | 65 | 66 | 67 | 68 | - - - 69 | 70 | ## construct() 71 | 72 | 73 | 74 | 75 | - - - 76 | 77 | ## flatten() 78 | 79 | 80 | 81 | Subclasses with nested structure should implement this method. 82 | 83 | 84 | #### Returns: 85 | 86 | A list of data that should be bound and constructed, by default just self. 87 | 88 | 89 | 90 | 91 | - - - 92 | 93 | ## has_unbound_vars() 94 | 95 | 96 | 97 | Returns whether there are any unbound vars in this tuple. 98 | 99 | 100 | 101 | 102 | 103 | - - - 104 | ## Properties 105 | 106 | * unbound_vars 107 | -------------------------------------------------------------------------------- /docs/pretty_tensor_top_level.md: -------------------------------------------------------------------------------- 1 | 2 | # Pretty Tensor Base Imports 3 | 4 | 5 | 6 | [TOC] 7 | - - - 8 | 9 | ## BatchNormalizationArguments 10 | 11 | BatchNormalizationArguments(learned_moments_update_rate, variance_epsilon, scale_after_normalization) 12 | - - - 13 | 14 | ### Properties 15 | 16 | * count 17 | * index 18 | * learned_moments_update_rate 19 | * scale_after_normalization 20 | * variance_epsilon 21 | 22 | - - - 23 | ## Class Bookkeeper 24 | 25 | [see details in Bookkeeper.md](Bookkeeper.md) 26 | - - - 27 | 28 | ## GraphKeys 29 | 30 | Graphs can store data in graph keys for constructing the graph. 31 | - - - 32 | 33 | ### Properties 34 | 35 | * LOSSES 36 | * MARKED_LOSSES 37 | * RECURRENT_STATE_VARIABLES 38 | * REGULARIZATION_LOSSES 39 | * TEST_VARIABLES 40 | * UPDATE_OPS 41 | 42 | - - - 43 | ## Class Loss 44 | 45 | [see details in Loss.md](Loss.md) 46 | - - - 47 | 48 | ## Phase 49 | 50 | Some nodes are different depending on the phase of the graph construction. 51 | 52 | The standard phases are train, test and infer. 53 | 54 | - - - 55 | 56 | ### Properties 57 | 58 | * infer 59 | * test 60 | * train 61 | 62 | - - - 63 | ## Class PrettyTensor 64 | 65 | [see details in PrettyTensor.md](PrettyTensor.md) 66 | - - - 67 | ## Class PrettyTensorTupleMixin 68 | 69 | [see details in PrettyTensorTupleMixin.md](PrettyTensorTupleMixin.md) 70 | - - - 71 | 72 | ## Register 73 | 74 | Decorator for registering a method in PrettyTensor. 75 | 76 | This is either used to decorate a bare function or an object that has a no-arg 77 | constructor and a __call__ method. 78 | 79 | The first argument to the function will be the PrettyTensor object. The 80 | registered method's return value must be one of the following: 81 | 82 | 1. A PrettyTensor, i.e. the result of calling `with_tensor` or 83 | `with_sequence`. 84 | 2. A Tensor. 85 | 3. A Loss result from calling `add_loss`. 86 | 87 | `RegisterCompoundOp` is provided for more direct manipulations with some 88 | caveats. 89 | 90 | - - - 91 | 92 | 93 | ### create_deferred(func, input_layer, deferred_args, deferred_kwargs, name) 94 | 95 | 96 | 97 | Creates a deferred node with captured scope. 98 | 99 | ##### Args: 100 | 101 | 102 | * func: The original function to call. 103 | * input_layer: The input_layer. 104 | * deferred_args: The arguments that will be used bythe deferred function. 105 | * deferred_kwargs: The keyword args for the deferred function. 106 | * name: The name of this layer. 107 | 108 | ##### Returns: 109 | 110 | A _DeferredLayer that will execute func in the correct scopes. 111 | 112 | 113 | 114 | 115 | 116 | ### create_method(obj) 117 | 118 | 119 | 120 | 121 | 122 | ### fill_kwargs(input_layer, kwargs) 123 | 124 | 125 | 126 | Applies name_suffix and defaults to kwargs and returns the result. 
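As an illustration of the registration flow described above (the method name and the clipping constant are arbitrary choices, not part of the library):

    import prettytensor as pt
    import tensorflow as tf

    @pt.Register
    def clipped_relu(input_layer, cap=6.0):
      # input_layer is the PrettyTensor to the left of the '.'.
      return tf.minimum(tf.nn.relu(input_layer), cap)

    # After registration it chains like any built-in method:
    # pt.wrap(images).flatten().fully_connected(100).clipped_relu()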
127 | 128 | 129 | 130 | 131 | 132 | 133 | - - - 134 | 135 | ## RegisterCompoundOp 136 | 137 | This is used to register a compound operation. 138 | 139 | The operation is executed immediately on the base PrettyTensor type. This has 140 | the following implications: 141 | 142 | 1. `tensor` and `sequence` may not be available in the deferred case. 143 | 2. The object passed in might be sequential or a layer. 144 | 145 | Also because this is intended to provide convenience chaining of other 146 | registered methods, it does not add a name or id scope automatically, which 147 | makes it behave as if the raw methods were called (unless the op itself does 148 | scoping). 149 | 150 | - - - 151 | 152 | 153 | ### create_method(func) 154 | 155 | 156 | 157 | Creates the method. 158 | 159 | 160 | 161 | 162 | 163 | 164 | ### fill_kwargs(input_layer, kwargs) 165 | 166 | 167 | 168 | Applies name_suffix and defaults to kwargs and returns the result. 169 | 170 | 171 | 172 | 173 | 174 | 175 | - - - 176 | 177 | ## UnboundVariable 178 | 179 | An UnboundVariable is a variable with a value that is supplied using bind. 180 | 181 | UnboundVariables are typically used so that input layers can be specified at a 182 | later time or for hyper parameters. Supplying a UnboundVariable as an input 183 | variable automatically forces the graph to be a template. 184 | 185 | - - - 186 | 187 | 188 | ### has_default() 189 | 190 | 191 | 192 | 193 | 194 | - - - 195 | 196 | ## VarStoreMethod 197 | 198 | Convenience base class for registered methods that create variables. 199 | 200 | This tracks the variables and requries subclasses to provide a __call__ 201 | method. 202 | 203 | - - - 204 | 205 | 206 | ### variable(var_name, shape, init, dt=, train=None) 207 | 208 | 209 | 210 | Adds a named variable to this bookkeeper or returns an existing one. 211 | 212 | Variables marked train are returned by the training_variables method. If 213 | the requested name already exists and it is compatible (same shape, dt and 214 | train) then it is returned. In case of an incompatible type, an exception is 215 | thrown. 216 | 217 | ##### Args: 218 | 219 | 220 | * var_name: The unique name of this variable. If a variable with the same 221 | name exists, then it is returned. 222 | * shape: The shape of the variable. 223 | * init: The init function to use or a Tensor to copy. 224 | * dt: The datatype, defaults to float. This will automatically extract the 225 | base dtype. 226 | * train: Whether or not the variable should be trained; defaults to 227 | True unless a default_scope has overridden it. 228 | 229 | ##### Returns: 230 | 231 | A TensorFlow tensor. 232 | 233 | 234 | ##### Raises: 235 | 236 | 237 | * ValueError: if reuse is False (or unspecified and allow_reuse is False) 238 | and the variable already exists or if the specification of a reused 239 | variable does not match the original. 240 | 241 | 242 | 243 | - - - 244 | 245 | ## apply_optimizer(losses, regularize=True, include_marked=True, clip_gradients_by_norm=None) 246 | 247 | 248 | 249 | Apply an optimizer to the graph and returns a train_op. 250 | 251 | The resulting operation will minimize the specified losses, plus the 252 | regularization losses that have been collected during graph construction and 253 | the losses that were marked by calling `mark_as_required`. 254 | 255 | It will also apply any updates that have been collected (e.g. for moving 256 | average summaries). 
257 | 258 | This is equivalent to: 259 | 260 | total_loss = prettytensor.create_composite_loss( 261 | losses=losses, regularize=regularize, include_marked=include_marked) 262 | train_op_without_updates = optimizer.minimize(total_loss) 263 | train_op = prettytensor.with_update_ops(train_op_without_updates) 264 | 265 | N.B. Pay special attention to the `gate_gradients` argument to the optimizer. 266 | If your graph is large, it will likely train unacceptably slow if you don't 267 | specify it as GATE_NONE. 268 | 269 | #### Args: 270 | 271 | 272 | * optimizer: The optimizer the minimize. 273 | * losses: A list of losses to apply. 274 | * regularize: Whether or not to include the regularization losses. 275 | * include_marked: Whether or not to use the marked losses. 276 | * clip_gradients_by_norm: If not None, clip gradients by the norm using 277 | `tf.clip_by_norm`. 278 | **kwargs: Additional arguments to pass into the optimizer. 279 | 280 | #### Returns: 281 | 282 | An operation to use for training that also updates any required ops such as 283 | moving averages. 284 | 285 | 286 | 287 | 288 | - - - 289 | 290 | ## for_default_graph() 291 | 292 | 293 | 294 | Creates a bookkeeper for the default graph. 295 | 296 | #### Args: 297 | 298 | 299 | * args: Arguments to pass into Bookkeeper's constructor. 300 | **kwargs: Arguments to pass into Bookkeeper's constructor. 301 | 302 | #### Returns: 303 | 304 | A new Bookkeeper. 305 | 306 | 307 | #### Raises: 308 | 309 | 310 | * ValueError: If args or kwargs are provided and the Bookkeeper already 311 | exists. 312 | 313 | 314 | - - - 315 | 316 | ## for_new_graph() 317 | 318 | 319 | 320 | Creates a Bookkeeper for a new graph. 321 | 322 | You must use `m.g.as_default()` to put the graph in scope: 323 | 324 | m = Bookkeeper.for_new_graph() 325 | with m.g.as_default(): 326 | ... 327 | 328 | #### Args: 329 | 330 | 331 | * args: Arguments to pass into Bookkeeper's constructor. 332 | **kwargs: Arguments to pass into Bookkeeper's constructor. 333 | 334 | #### Returns: 335 | 336 | A new Bookkeeper. 337 | 338 | 339 | 340 | 341 | - - - 342 | 343 | ## construct_all() 344 | 345 | 346 | 347 | Constructs all the given templates in a single pass without redundancy. 348 | 349 | This is useful when the templates have a common substructure and you want the 350 | smallest possible graph. 351 | 352 | #### Args: 353 | 354 | 355 | * templates: A sequence of templates. 356 | **unbound_var_values: The unbound_var values to replace. 357 | 358 | #### Returns: 359 | 360 | A list of results corresponding to templates. 361 | 362 | 363 | #### Raises: 364 | 365 | 366 | * TypeError: If any value in templates is unsupported. 367 | * ValueError: If the unbound_var values specified are not complete or contain 368 | unknown values. 369 | 370 | 371 | - - - 372 | 373 | ## create_composite_loss(regularize=True, include_marked=True, name=cost) 374 | 375 | 376 | 377 | Creates a loss that is the sum of all specified losses. 378 | 379 | #### Args: 380 | 381 | 382 | * losses: A sequence of losses to include. 383 | * regularize: Whether or not to include regularization losses. 384 | * include_marked: Whether or not to use the marked losses. 385 | * name: The name for this variable. 386 | 387 | #### Returns: 388 | 389 | A single tensor that is the sum of all losses. 390 | 391 | 392 | 393 | 394 | - - - 395 | 396 | 397 | ## defaults_scope(... 
398 | 399 | defaults_scope(activation_fn=None, batch_normalize=None, l2loss=None, learned_moments_update_rate=None, parameter_modifier=None, phase=None, scale_after_normalization=None, summary_collections=None, trainable_variables=None, unroll=None, variable_collections=None, variance_epsilon=None) 400 | 401 | Creates a scope for the defaults that are used in a `with` block. 402 | 403 | Note: `defaults_scope` supports nesting where later defaults can be 404 | overridden. Also, an explicitly given keyword argument on a method always 405 | takes precedence. 406 | 407 | In addition to setting defaults for some methods, this also can control: 408 | 409 | * `summary_collections`: Choose which collection to place summaries in or 410 | disable with `None`. 411 | * `trainable_variables`: Boolean indicating if variables are trainable. 412 | * `variable_collections`: Default collections in which to place variables; 413 | `tf.GraphKeys.GLOBAL_VARIABLES` is always included. 414 | 415 | The supported defaults and methods that use them are: 416 | 417 | 418 | * `activation_fn`: 419 | * [conv2d](PrettyTensor.md#conv2d) 420 | * [depthwise_conv2d](PrettyTensor.md#depthwise_conv2d) 421 | * [fully_connected](PrettyTensor.md#fully_connected) 422 | 423 | * `batch_normalize`: 424 | * [conv2d](PrettyTensor.md#conv2d) 425 | * [depthwise_conv2d](PrettyTensor.md#depthwise_conv2d) 426 | 427 | * `l2loss`: 428 | * [conv2d](PrettyTensor.md#conv2d) 429 | * [depthwise_conv2d](PrettyTensor.md#depthwise_conv2d) 430 | * [diagonal_matrix_mul](PrettyTensor.md#diagonal_matrix_mul) 431 | * [fully_connected](PrettyTensor.md#fully_connected) 432 | 433 | * `learned_moments_update_rate`: 434 | * [batch_normalize](PrettyTensor.md#batch_normalize) 435 | 436 | * `parameter_modifier`: 437 | * [conv2d](PrettyTensor.md#conv2d) 438 | * [depthwise_conv2d](PrettyTensor.md#depthwise_conv2d) 439 | * [softmax_classifier_with_sampled_loss](PrettyTensor.md#softmax_classifier_with_sampled_loss) 440 | * [softmax_classifier](PrettyTensor.md#softmax_classifier) 441 | * [diagonal_matrix_mul](PrettyTensor.md#diagonal_matrix_mul) 442 | * [fully_connected](PrettyTensor.md#fully_connected) 443 | * [lstm_cell](PrettyTensor.md#lstm_cell) 444 | * [sequence_lstm](PrettyTensor.md#sequence_lstm) 445 | * [gru_cell](PrettyTensor.md#gru_cell) 446 | * [sequence_gru](PrettyTensor.md#sequence_gru) 447 | * [embedding_lookup](PrettyTensor.md#embedding_lookup) 448 | 449 | * `phase`: 450 | * [batch_normalize](PrettyTensor.md#batch_normalize) 451 | * [conv2d](PrettyTensor.md#conv2d) 452 | * [depthwise_conv2d](PrettyTensor.md#depthwise_conv2d) 453 | * [evaluate_precision_recall](PrettyTensor.md#evaluate_precision_recall) 454 | * [evaluate_classifier_fraction](PrettyTensor.md#evaluate_classifier_fraction) 455 | * [evaluate_classifier](PrettyTensor.md#evaluate_classifier) 456 | * [evaluate_classifier_fraction_sparse](PrettyTensor.md#evaluate_classifier_fraction_sparse) 457 | * [evaluate_classifier_sparse](PrettyTensor.md#evaluate_classifier_sparse) 458 | * [dropout](PrettyTensor.md#dropout) 459 | * [diagonal_matrix_mul](PrettyTensor.md#diagonal_matrix_mul) 460 | * [fully_connected](PrettyTensor.md#fully_connected) 461 | * [lstm_cell](PrettyTensor.md#lstm_cell) 462 | * [sequence_lstm](PrettyTensor.md#sequence_lstm) 463 | * [gru_cell](PrettyTensor.md#gru_cell) 464 | * [sequence_gru](PrettyTensor.md#sequence_gru) 465 | * [embedding_lookup](PrettyTensor.md#embedding_lookup) 466 | 467 | * `scale_after_normalization`: 468 | * [batch_normalize](PrettyTensor.md#batch_normalize) 469 
| 470 | * `unroll`: 471 | * [cleave_sequence](PrettyTensor.md#cleave_sequence) 472 | 473 | * `variance_epsilon`: 474 | * [batch_normalize](PrettyTensor.md#batch_normalize) 475 | 476 | - - - 477 | 478 | ## global_step() 479 | 480 | 481 | 482 | Returns the global step variable. 483 | 484 | 485 | 486 | 487 | 488 | - - - 489 | 490 | ## join_pretty_tensors(output, join_function=None, name=join) 491 | 492 | 493 | 494 | Joins the list of pretty_tensors and sets head of output_pretty_tensor. 495 | 496 | #### Args: 497 | 498 | 499 | * tensors: A sequence of Layers or SequentialLayerBuilders to join. 500 | * output: A pretty_tensor to set the head with the result. 501 | * join_function: A function to join the tensors, defaults to concat on the 502 | last dimension. 503 | * name: A name that is used for the name_scope 504 | 505 | #### Returns: 506 | 507 | The result of calling with_tensor on output 508 | 509 | 510 | #### Raises: 511 | 512 | 513 | * ValueError: if pretty_tensors is None or empty. 514 | 515 | 516 | - - - 517 | 518 | ## make_template(func) 519 | 520 | 521 | 522 | Given an arbitrary function, wrap it so that it does parameter sharing. 523 | 524 | 525 | 526 | 527 | 528 | - - - 529 | 530 | ## recurrent_state() 531 | 532 | 533 | 534 | 535 | - - - 536 | 537 | ## set_recurrent_state_saver() 538 | 539 | 540 | 541 | Sets the state saver used for recurrent sequences. 542 | 543 | 544 | 545 | 546 | 547 | - - - 548 | 549 | ## template(books=None, optional=False) 550 | 551 | 552 | 553 | Starts a Pretty Tensor graph template. 554 | 555 | ## Template Mode 556 | 557 | Templates allow you to define a graph with some unknown 558 | values. The most common use case is to leave the input undefined and then 559 | define a graph normally. The variables are only defined once the first time 560 | the graph is constructed. For example: 561 | 562 | template = (pretty_tensor.template('input') 563 | .fully_connected(200, name='l1') 564 | .fully_connected(200, name='l2')) 565 | train_output = template.construct(input=train_data) 566 | 567 | # All parameters are reused when the same template object is called again. 568 | test_output = template.construct(input=test_data) 569 | 570 | Any argument to a pretty tensor method can be substituted by using an 571 | `UnboundVariable`. 572 | This allows you to parameterize a graph in arbitrary ways. The most cannonical 573 | usage would be to substitute a phase variable. 574 | 575 | with pretty_tensor.defaults_scope(phase=UnboundVariable('train')): 576 | # dropout uses train to optionaly disable itself. 577 | 578 | template = (pretty_tensor.template('input') 579 | .fully_connected(200, name='l1') 580 | .fully_connected(200, name='l2') 581 | .dropout(.8)) 582 | train_output = template.construct(input=train_data, train=True) 583 | test_output = template.construct(input=test_data, train=False) 584 | 585 | 586 | You should use caution because if a template is called with incompatible 587 | values (e.g. train and test using different widths), then it will break. This 588 | is because we guarantee variable reuse across instantiations. 589 | 590 | template = (pretty_tensor.template('input') 591 | .fully_connected(200, name='l1') 592 | .fully_connected( 593 | pretty_tensor.UnboundVariable('width'), name='l2')) 594 | train_output = template.construct(input=train_data, width=200) 595 | 596 | # The following line will die because the shared parameter is the wrong 597 | # size. 
598 | test_output = template.construct(input=test_data, width=100) 599 | 600 | 601 | A Layer in the resulting graph can be realized by calling 602 | `bind(key=value)` and then `construct`. 603 | 604 | #### Args: 605 | 606 | 607 | * key: A key for this template, used for assigning the correct substitution. 608 | * books: The bookkeeper. 609 | * optional: If this template is an optional value. 610 | 611 | #### Returns: 612 | 613 | A template that can be constructed or attached to other layers and that 614 | guarantees parameter reuse when constructed/attached multiple times. 615 | 616 | 617 | 618 | 619 | - - - 620 | 621 | ## with_update_ops() 622 | 623 | 624 | 625 | Creates a new op that runs all of the required updates when train_op runs. 626 | 627 | #### Args: 628 | 629 | 630 | * train_op: An operation that will run every step, usually the result of an 631 | optimizer. 632 | 633 | #### Returns: 634 | 635 | A new op that returns the same value as train_op, but also runs the 636 | updaters. 637 | 638 | 639 | 640 | 641 | - - - 642 | 643 | ## wrap(books=None, tensor_shape=None) 644 | 645 | 646 | 647 | Creates an input layer representing the given tensor. 648 | 649 | #### Args: 650 | 651 | 652 | * tensor: The tensor. 653 | * books: The bookkeeper; this is usually not required unless you are building 654 | multiple `tf.Graphs.` 655 | * tensor_shape: An optional shape that will be set on the Tensor or verified 656 | to match the tensor. 657 | 658 | #### Returns: 659 | 660 | A layer. 661 | 662 | 663 | 664 | 665 | - - - 666 | 667 | ## wrap_sequence(books=None, tensor_shape=None) 668 | 669 | 670 | 671 | Creates an input layer representing the given sequence of tensors. 672 | 673 | #### Args: 674 | 675 | 676 | * sequence: A sequence of tensors. 677 | * books: The bookkeeper. 678 | * tensor_shape: An optional shape that will be set on the Tensor or verified 679 | to match the tensor. 680 | 681 | #### Returns: 682 | 683 | A layer. 684 | 685 | 686 | 687 | 688 | - - - 689 | 690 | 691 | ## Extensions 692 | 693 | 694 | 695 | - - - 696 | -------------------------------------------------------------------------------- /inception_module.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/prettytensor/75daa0b11252590f548da5647addc0ea610c4c45/inception_module.png -------------------------------------------------------------------------------- /prettytensor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """PrettyTensor nice syntax layer on top of TensorFlow. 12 | 13 | This will eventually be the preferred place to import prettytensor. 14 | 15 | For now, please use pretty_tensor.py since this is in a state of flux. 16 | 17 | see [README.md](https://github.com/google/prettytensor) for documentation. 18 | see pretty_tensor_samples/ for usage examples. 
19 | """ 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | # pylint: disable=unused-import 25 | from prettytensor import funcs 26 | from prettytensor import parameters 27 | from prettytensor import train 28 | from prettytensor.bookkeeper import apply_optimizer 29 | from prettytensor.bookkeeper import Bookkeeper 30 | from prettytensor.bookkeeper import create_composite_loss 31 | from prettytensor.bookkeeper import for_default_graph as bookkeeper_for_default_graph 32 | from prettytensor.bookkeeper import for_new_graph as bookkeeper_for_new_graph 33 | from prettytensor.bookkeeper import global_step 34 | from prettytensor.bookkeeper import GraphKeys 35 | from prettytensor.bookkeeper import recurrent_state 36 | from prettytensor.bookkeeper import set_recurrent_state_saver 37 | from prettytensor.bookkeeper import with_update_ops 38 | 39 | from prettytensor.pretty_tensor_class import construct_all 40 | from prettytensor.pretty_tensor_class import defaults_scope 41 | from prettytensor.pretty_tensor_class import DIM_REST 42 | from prettytensor.pretty_tensor_class import DIM_SAME 43 | from prettytensor.pretty_tensor_class import join_pretty_tensors 44 | from prettytensor.pretty_tensor_class import Loss 45 | from prettytensor.pretty_tensor_class import PAD_SAME 46 | from prettytensor.pretty_tensor_class import PAD_VALID 47 | from prettytensor.pretty_tensor_class import Phase 48 | from prettytensor.pretty_tensor_class import PrettyTensor 49 | from prettytensor.pretty_tensor_class import PrettyTensorTupleMixin 50 | from prettytensor.pretty_tensor_class import PROVIDED 51 | from prettytensor.pretty_tensor_class import Register 52 | from prettytensor.pretty_tensor_class import RegisterCompoundOp 53 | from prettytensor.pretty_tensor_class import template 54 | from prettytensor.pretty_tensor_class import UnboundVariable 55 | from prettytensor.pretty_tensor_class import VarStoreMethod 56 | from prettytensor.pretty_tensor_class import wrap 57 | from prettytensor.pretty_tensor_class import wrap_sequence 58 | 59 | from prettytensor.pretty_tensor_normalization_methods import BatchNormalizationArguments 60 | from prettytensor.scopes import make_template 61 | 62 | __version__ = '0.7.1' 63 | -------------------------------------------------------------------------------- /prettytensor/bookkeeper_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | """Test class for bookkeepers.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import unittest 17 | 18 | 19 | import tensorflow as tf 20 | 21 | from prettytensor import bookkeeper 22 | from prettytensor import pretty_tensor_testing 23 | 24 | 25 | class BookkeeperTest(pretty_tensor_testing.PtTestCase): 26 | 27 | def setUp(self): 28 | super(self.__class__, self).setUp() 29 | 30 | def assertSameContents(self, list1, list2, msg): 31 | self.assertEqual(len(list1), len(list2), msg) 32 | self.assertEqual(set(list1), set(list2), msg) 33 | 34 | def testGraphIsReused(self): 35 | b1 = bookkeeper.for_default_graph() 36 | b2 = bookkeeper.for_default_graph() 37 | self.assertTrue(b1 is b2) 38 | 39 | def testPassingArgsCausesError(self): 40 | b1 = bookkeeper.for_new_graph() 41 | with b1.g.as_default(), self.assertRaises(ValueError): 42 | bookkeeper.for_default_graph(global_step=None) 43 | 44 | def testGlobalStep(self): 45 | v = tf.Variable(0) 46 | b1 = bookkeeper.for_new_graph(global_step=v) 47 | with b1.g.as_default(): 48 | self.assertEqual(v, bookkeeper.global_step()) 49 | with self.assertRaises(ValueError): 50 | bookkeeper.for_new_graph(global_step=tf.Variable(1.0)) 51 | 52 | def testUniqueBookkeeperPerGraph(self): 53 | b1 = bookkeeper.for_default_graph() 54 | with tf.Graph().as_default(): 55 | b2 = bookkeeper.for_default_graph() 56 | self.assertFalse(b1 is b2) 57 | 58 | def testBareVarName(self): 59 | name = 'hello' 60 | var = tf.Variable([1], name=name) 61 | self.assertEquals(name, bookkeeper._bare_var_name(var)) 62 | self.assertEquals(name, 63 | bookkeeper._bare_var_name(var._as_graph_element())) 64 | 65 | 66 | if __name__ == '__main__': 67 | unittest.main() 68 | -------------------------------------------------------------------------------- /prettytensor/chain_dict.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | """Creates a dict with a parent so that missing values are sent up.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import collections 17 | 18 | 19 | class ChainDict(collections.MutableMapping): 20 | """The Name class.""" 21 | 22 | def __init__(self, parent): 23 | self._map = {} 24 | self._parent = parent 25 | self._dead_count = 0 26 | 27 | def __getitem__(self, key): 28 | if key in self._map: 29 | return self._map[key] 30 | elif self._parent: 31 | return self._parent[key] 32 | else: 33 | raise KeyError('Key not found: %s' % key) 34 | 35 | def __setitem__(self, key, value): 36 | self._map[key] = value 37 | 38 | def __delitem__(self, key): 39 | raise Exception('Deleting items not supported.') 40 | 41 | def _full_map(self): 42 | """Creates a full mapping of this and all parent key, value pairs.""" 43 | result = {} 44 | if self._parent: 45 | result.update(self._parent) 46 | result.update(self._map) 47 | return result 48 | 49 | def __iter__(self): 50 | return self._full_map().__iter__() 51 | 52 | def __len__(self): 53 | return len(self._full_map()) 54 | -------------------------------------------------------------------------------- /prettytensor/chain_dict_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | """Test class for ChainDict.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import unittest 17 | 18 | from prettytensor import chain_dict 19 | 20 | 21 | class ChainDictTest(unittest.TestCase): 22 | 23 | def testSet(self): 24 | d = chain_dict.ChainDict(None) 25 | d['KEY'] = 'VALUE' 26 | self.assertEqual({'KEY': 'VALUE'}, d._map) 27 | 28 | def testGetNoParent(self): 29 | d = chain_dict.ChainDict(None) 30 | d['KEY'] = 'VALUE' 31 | self.assertEqual('VALUE', d['KEY']) 32 | 33 | def testGetAbsentNoParent(self): 34 | d = chain_dict.ChainDict(None) 35 | with self.assertRaises(KeyError): 36 | # pylint: disable=pointless-statement 37 | d['KEY'] 38 | 39 | def testGetInParent(self): 40 | parent = chain_dict.ChainDict(None) 41 | d = chain_dict.ChainDict(parent) 42 | parent['KEY'] = 'VALUE' 43 | self.assertEqual('VALUE', parent['KEY']) 44 | self.assertEqual('VALUE', d['KEY']) 45 | 46 | def testGetNotInParent(self): 47 | parent = chain_dict.ChainDict(None) 48 | d = chain_dict.ChainDict(parent) 49 | with self.assertRaises(KeyError): 50 | # pylint: disable=pointless-statement 51 | parent['KEY'] 52 | with self.assertRaises(KeyError): 53 | # pylint: disable=pointless-statement 54 | d['KEY'] 55 | 56 | def testGetOverridden(self): 57 | parent = chain_dict.ChainDict(None) 58 | d = chain_dict.ChainDict(parent) 59 | parent['KEY'] = 'VALUE' 60 | d['KEY'] = 'OTHER_VALUE' 61 | self.assertEqual('VALUE', parent['KEY']) 62 | self.assertEqual('OTHER_VALUE', d['KEY']) 63 | 64 | def testLen(self): 65 | parent = chain_dict.ChainDict(None) 66 | d = chain_dict.ChainDict(parent) 67 | self.assertEqual(0, len(d)) 68 | self.assertEqual(0, len(parent)) 69 | 70 | parent['KEY'] = 'VALUE' 71 | self.assertEqual(1, len(d)) 72 | self.assertEqual(1, len(parent)) 73 | 74 | d['KEY'] = 'OTHER_VALUE' 75 | self.assertEqual(1, len(d)) 76 | self.assertEqual(1, len(parent)) 77 | 78 | d['OTHER_KEY'] = 'YAV' 79 | self.assertEqual(2, len(d)) 80 | self.assertEqual(1, len(parent)) 81 | 82 | def testIteration(self): 83 | parent = chain_dict.ChainDict(None) 84 | d = chain_dict.ChainDict(parent) 85 | parent['KEY'] = 'VALUE' 86 | d['KEY'] = 'OTHER_VALUE' 87 | d['OTHER_KEY'] = 'YAV' 88 | 89 | self.assertEqual([('KEY', 'OTHER_VALUE'), ('OTHER_KEY', 'YAV')], 90 | sorted(d.items())) 91 | self.assertEqual([('KEY', 'VALUE')], sorted(parent.items())) 92 | 93 | def testPlainDictParent(self): 94 | d = chain_dict.ChainDict({'KEY': 'VALUE'}) 95 | self.assertEqual('VALUE', d['KEY']) 96 | self.assertEqual(len(d), 1) 97 | # In Python 3, items produces an iterator. 98 | self.assertEqual(list(d.items()), [('KEY', 'VALUE')]) 99 | 100 | if __name__ == '__main__': 101 | unittest.main() 102 | -------------------------------------------------------------------------------- /prettytensor/funcs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Experimental functional API for PrettyTensor. 12 | 13 | This exposes all of the standard PrettyTensor functions, but instead of 14 | chaining, they are invoked like regular functions. 15 | """ 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import sys 21 | 22 | import six 23 | 24 | # pylint: disable=unused-import, wildcard-import 25 | from prettytensor.pretty_tensor_image_methods import * 26 | from prettytensor.pretty_tensor_loss_methods import * 27 | from prettytensor.pretty_tensor_methods import * 28 | from prettytensor.pretty_tensor_sparse_methods import * 29 | from prettytensor.recurrent_networks import * 30 | 31 | 32 | def _remove_non_methods(): 33 | """Removes any object in dict that is not a registered method.""" 34 | cur_module = sys.modules[__name__] 35 | my_globals = dict(globals()) 36 | # Import here so that it doesn't get added to the global namespace or deleted. 37 | # pylint: disable=g-import-not-at-top 38 | from prettytensor.pretty_tensor_class import PrettyTensor 39 | for name, _ in six.iteritems(my_globals): 40 | if not hasattr(PrettyTensor, name): 41 | delattr(cur_module, name) 42 | # Remove a couple of special ones.... 43 | if hasattr(cur_module, 'bookkeeper'): 44 | delattr(cur_module, 'bookkeeper') 45 | 46 | _remove_non_methods() 47 | -------------------------------------------------------------------------------- /prettytensor/functions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Reusable TensorFlow functions. 12 | 13 | This file is divided into roughly 4 parts: 14 | 15 | * _regression functions with signature: `fn(y, target, name=)` where y is 16 | the output of your DNN and target is the regression target as a tensor and 17 | the result is a tensor with shape [1]. 18 | * _distance functions with signature: `fn(t1, t2, name=) where t1 and t2 are 19 | both tensors and the result is a tensor with shape [N] where N is the first 20 | dimension of t1 and t2. 21 | * Activation functions with signature `fn(x, name=)` where x is a tensor 22 | and the result is a tensor of the same shape. 23 | * Utility functions. These include normalizations that are used in embedding 24 | models as a non-linearity and a few others. 25 | """ 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | from six.moves import xrange # pylint: disable=redefined-builtin 31 | import tensorflow as tf 32 | 33 | # To improve numerical stability, we want to not do an exponential above this 34 | # value. 
35 | _SOFTPLUS_STABILITY_LIMIT = 20.0 36 | 37 | 38 | def l1_regression_loss(y, target, name=None): 39 | """Calculates the sum of absolute errors between y and target. 40 | 41 | Args: 42 | y: the calculated values. 43 | target: the desired values. 44 | name: the name for this op, defaults to l1_regression 45 | Returns: 46 | A tensorflow op. 47 | """ 48 | with tf.name_scope(name, 'l1_regression', [y, target]) as scope: 49 | y = tf.convert_to_tensor(y, name='y') 50 | target = tf.convert_to_tensor(target, name='target') 51 | return reduce_batch_sum(tf.abs(y - target), name=scope) 52 | 53 | 54 | def l2_regression_sq_loss(y, target, name=None): 55 | """Calculates the sum of squared errors between y and target. 56 | 57 | Args: 58 | y: the calculated values. 59 | target: the desired values. 60 | name: the name for this op, defaults to l2_regression 61 | Returns: 62 | A tensorflow op. 63 | """ 64 | with tf.name_scope(name, 'l2_regression_sq', [y, target]) as scope: 65 | y = tf.convert_to_tensor(y, name='y') 66 | target = tf.convert_to_tensor(target, name='target') 67 | return reduce_batch_sum(tf.square(y - target), name=scope) 68 | 69 | 70 | def reduce_batch_sum(x, name=None): 71 | with tf.name_scope(name, 'reduce_batch_sum', [x]) as scope: 72 | ndims = x.get_shape().ndims 73 | if ndims == 0: 74 | raise ValueError('Cannot reduce a scalar into batches.') 75 | elif ndims == 1: 76 | return x # Don't include a useless sum. 77 | elif ndims: 78 | reduction_indices = list(range(1, x.get_shape().ndims)) 79 | shape = [x.get_shape().dims[0]] 80 | else: 81 | reduction_indices = tf.range(1, tf.size(tf.shape(x))) 82 | shape = [None] # We don't know much about the shape, but it is rank 1. 83 | result = tf.reduce_sum(x, reduction_indices=reduction_indices, name=scope) 84 | 85 | # Give a shape hint in case we have extra information. 86 | result.set_shape(shape) 87 | return result 88 | 89 | 90 | def l2_regression_loss(y, target, name=None): 91 | """Calculates the square root of the SSE between y and target. 92 | 93 | Args: 94 | y: the calculated values. 95 | target: the desired values. 96 | name: the name for this op, defaults to l2_regression 97 | Returns: 98 | A tensorflow op. 99 | """ 100 | with tf.name_scope(name, 'l2_regression', [y, target]) as scope: 101 | y = tf.convert_to_tensor(y, name='y') 102 | target = tf.convert_to_tensor(target, name='target') 103 | return tf.sqrt(l2_regression_sq_loss(y, target, name=scope)) 104 | 105 | 106 | def binary_cross_entropy_loss_with_logits(x, target, name=None): 107 | """Calculates the binary cross entropy between sigmoid(x) and target. 108 | 109 | Expects unscaled logits. Do not pass in results of sigmoid operation. 110 | 111 | Args: 112 | x: the calculated pre-sigmoid values 113 | target: the desired values. 114 | name: the name for this op, defaults to binary_cross_entropy_with_logits 115 | Returns: 116 | -(target * -softplus(-x) + (1-target) * (-x - softplus(-x))) 117 | Raises: 118 | ValueError: If shapes are incompatible. 119 | """ 120 | with tf.name_scope(name, 'binary_cross_entropy_with_logits', 121 | [x, target]) as scope: 122 | x.get_shape().assert_is_compatible_with(target.get_shape()) 123 | neg_softplus = -tf.nn.softplus(-x) 124 | return -tf.add(tf.multiply(target, neg_softplus), 125 | tf.multiply(1 - target, -x + neg_softplus), 126 | name=scope) 127 | 128 | 129 | def cos_distance(t1, t2, epsilon=1e-12, name=None): 130 | """Cos distance between t1 and t2 and caps the gradient of the Square Root. 
131 | 132 | Args: 133 | t1: A tensor 134 | t2: A tensor that can be multiplied by t1. 135 | epsilon: A lower bound value for the distance. The square root is used as 136 | the normalizer. 137 | name: Optional name for this op. 138 | Returns: 139 | The cos distance between t1 and t2. 140 | """ 141 | with tf.name_scope(name, 'cos_distance', [t1, t2]) as scope: 142 | t1 = tf.convert_to_tensor(t1, name='t1') 143 | t2 = tf.convert_to_tensor(t2, name='t2') 144 | x_inv_norm = tf.rsqrt(tf.maximum(length_squared(t1) * length_squared(t2), 145 | epsilon)) 146 | return tf.subtract(1.0, dot_product(t1, t2) * x_inv_norm, name=scope) 147 | 148 | 149 | def dot_distance(t1, t2, name=None): 150 | """dot "distance" between t1 and t2. 151 | 152 | Args: 153 | t1: A tensor. 154 | t2: A tensor that is the same size as t1. 155 | name: Optional name for this op. 156 | Returns: 157 | The dot distance between t1 and t2. 158 | """ 159 | with tf.name_scope(name, 'dot_distance', [t1, t2]) as scope: 160 | return -dot_product(t1, t2, name=scope) 161 | 162 | 163 | def l2_distance_sq(t1, t2, name=None): 164 | """Square of l2 distance between t1 and t2. 165 | 166 | Args: 167 | t1: A tensor. 168 | t2: A tensor that is the same size as t1. 169 | name: Optional name for this op. 170 | Returns: 171 | The l2 distance between t1 and t2. 172 | """ 173 | with tf.name_scope(name, 'l2_distance_sq', [t1, t2]) as scope: 174 | t1 = tf.convert_to_tensor(t1, name='t1') 175 | t2 = tf.convert_to_tensor(t2, name='t2') 176 | return length_squared(tf.subtract(t1, t2), name=scope) 177 | 178 | 179 | def l2_distance(t1, t2, epsilon=1e-12, name=None): 180 | """l2 distance between t1 and t2 and caps the gradient of the Square Root. 181 | 182 | Args: 183 | t1: A tensor. 184 | t2: A tensor that is the same size as t1. 185 | epsilon: A lower bound for distance, useful to avoid sqrt of very small 186 | values that can blow up gradients. 187 | name: Optional name for this op. 188 | Returns: 189 | The l2 distance between t1 and t2. 190 | """ 191 | with tf.name_scope(name, 'l2_distance', [t1, t2]) as scope: 192 | t1 = tf.convert_to_tensor(t1, name='t1') 193 | t2 = tf.convert_to_tensor(t2, name='t2') 194 | return tf.sqrt(tf.maximum(l2_distance_sq(t1, t2, scope), epsilon)) 195 | 196 | 197 | def l1_distance(t1, t2, name=None): 198 | """l1 distance between t1 and t2. 199 | 200 | Args: 201 | t1: A tensor. 202 | t2: A tensor that is the same size as t1. 203 | name: Optional name for this op. 204 | Returns: 205 | The l1 distance between t1 and t2. 206 | """ 207 | with tf.name_scope(name, 'l1_distance', [t1, t2]) as scope: 208 | t1 = tf.convert_to_tensor(t1, name='t1') 209 | t2 = tf.convert_to_tensor(t2, name='t2') 210 | sub = tf.subtract(t1, t2) 211 | reduction_dim = _last_index(sub, 1) 212 | return tf.reduce_sum(tf.abs(sub), reduction_dim, name=scope) 213 | 214 | 215 | def leaky_relu(x, name=None): 216 | """Creates a leaky_relu. 217 | 218 | This is an alternate non-linearity to relu. The leaky part of the relu may 219 | prevent dead Neurons in a model since the gradient doesn't go completely to 220 | 0. 221 | 222 | Args: 223 | x: The input tensor. 224 | name: Optional name for this op. 225 | Returns: 226 | x if x > 0 otherwise 0.01 * x. 227 | """ 228 | with tf.name_scope(name, 'leaky_relu', [x]) as scope: 229 | x = tf.convert_to_tensor(x, name='x') 230 | return tf.where(tf.less(x, 0.0), 0.01 * x, x, name=scope) 231 | 232 | 233 | def softplus(x, scale=1.0, name=None): 234 | """Computes softplus with a scale factor to sharpen of the hinge. 
235 | 236 | This is an alternate non-linearity to relu. It has a similar shape, but 237 | it has a smooth transition from the linear part to 0. 238 | 239 | Args: 240 | x: A tensor. 241 | scale: A float that sharpens the curve. 242 | name: Optional name. 243 | Returns: 244 | y = log(1 + exp(scale * x)) / scale 245 | 246 | """ 247 | if scale == 1: 248 | return tf.nn.softplus(x) 249 | else: 250 | with tf.name_scope(name, 'softplus', [x]): 251 | scale = tf.convert_to_tensor(scale, dtype=x.dtype.base_dtype) 252 | return tf.nn.softplus(x * scale) / scale 253 | 254 | 255 | # Copied to keep API consistency with other functions. 256 | l2_normalize = tf.nn.l2_normalize 257 | 258 | 259 | def l1_normalize(x, dim, epsilon=1e-12, name=None): 260 | """l1 normalizes x. 261 | 262 | Args: 263 | x: The tensor to normalize. 264 | dim: The dimension to normalize along. 265 | epsilon: Lower bound on the norm, used to avoid exploding gradients as the 266 | norm approaches 0. 267 | name: Optional name for this op. 268 | Returns: 269 | x normalized along dim. 270 | """ 271 | with tf.name_scope(name, 'l1_normalize', [x]) as scope: 272 | x = tf.convert_to_tensor(x, name='x') 273 | x = tf.verify_tensor_all_finite(x, 'Error at input %s' % scope) 274 | x_norm = tf.maximum(tf.reduce_sum(tf.abs(x), [dim], keep_dims=True), 275 | epsilon) 276 | return tf.div(x, x_norm, name=scope) 277 | 278 | 279 | def every_other(x, name=None): 280 | """Drops every other value from the tensor and returns a 1D tensor. 281 | 282 | This is useful if you are running multiple inputs through a model tower 283 | before splitting them and you want to line it up with some other data. 284 | 285 | Args: 286 | x: the target tensor. 287 | name: the name for this op, defaults to every_other 288 | Returns: 289 | A tensorflow op. 290 | """ 291 | with tf.name_scope(name, 'every_other', [x]) as scope: 292 | x = tf.convert_to_tensor(x, name='x') 293 | return tf.reshape( 294 | tf.slice( 295 | tf.reshape(x, [-1, 2]), [0, 0], [-1, 1]), 296 | [-1], 297 | name=scope) 298 | 299 | 300 | def dot_product(t1, t2, keep_dims=False, name=None, reduction_dim=None): 301 | """Computes the dot product of t1 and t2. 302 | 303 | Args: 304 | t1: A rank 2 tensor. 305 | t2: A tensor that is the same size as t1. 306 | keep_dims: If true, reduction does not change the rank of the input. 307 | name: Optional name for this op. 308 | reduction_dim: The dimension to reduce, by default choose the last one 309 | and if no shape is specified guess 1. 310 | Returns: 311 | The dot product. 312 | """ 313 | with tf.name_scope(name, 'dot', [t1, t2]) as scope: 314 | t1 = tf.convert_to_tensor(t1, name='t1') 315 | t2 = tf.convert_to_tensor(t2, name='t2') 316 | mul = tf.multiply(t1, t2) 317 | if not reduction_dim: 318 | reduction_dim = _last_index(mul, 1) 319 | return tf.reduce_sum(mul, reduction_dim, name=scope, keep_dims=keep_dims) 320 | 321 | 322 | def length_squared(x, keep_dims=False, name=None, reduction_dim=None): 323 | """Computes the squared length of x. 324 | 325 | Args: 326 | x: A tensor. 327 | keep_dims: If true, reduction does not change the rank of the input. 328 | name: Optional name for this op. 329 | reduction_dim: The dimension to reduce, by default choose the last one 330 | and if no shape is specified guess 1. 331 | Returns: 332 | The squared length of x. 
333 | """ 334 | with tf.name_scope(name, 'length_squared', [x]) as scope: 335 | x = tf.convert_to_tensor(x, name='x') 336 | if not reduction_dim: 337 | reduction_dim = _last_index(x, 1) 338 | return tf.reduce_sum( 339 | tf.square(x), 340 | reduction_dim, 341 | keep_dims=keep_dims, 342 | name=scope) 343 | 344 | 345 | def unzip(x, split_dim, current_length, num_splits=2, name=None): 346 | """Splits a tensor by unzipping along the split_dim. 347 | 348 | For example the following array split into 2 would be: 349 | [1, 2, 3, 4, 5, 6] -> [1, 3, 5], [2, 4, 6] 350 | and by 3: 351 | [1, 2, 3, 4] -> [1, 4], [2], [3] 352 | 353 | Args: 354 | x: The tensor to split. 355 | split_dim: The dimension to split along. 356 | current_length: Current length along the split_dim. 357 | num_splits: The number of splits. 358 | name: Optional name for this op. 359 | Returns: 360 | A length num_splits sequence. 361 | """ 362 | with tf.name_scope(name, 'unzip', [x]) as scope: 363 | x = tf.convert_to_tensor(x, name='x') 364 | # There is probably a more efficient way to do this. 365 | all_splits = tf.split( 366 | value=x, num_or_size_splits=current_length, axis=split_dim, name=scope) 367 | splits = [[] for _ in xrange(num_splits)] 368 | for i in xrange(current_length): 369 | splits[i % num_splits].append(all_splits[i]) 370 | return [tf.concat(s, split_dim) for s in splits] 371 | 372 | 373 | def _last_index(x, default_dim): 374 | """Returns the last dimension's index or default_dim if x has no shape.""" 375 | if x.get_shape().ndims is not None: 376 | return len(x.get_shape()) - 1 377 | else: 378 | return default_dim 379 | 380 | 381 | def _all_dims(x, default_dims=None): 382 | """Returns a list of dims in x or default_dims if the rank is unknown.""" 383 | if x.get_shape().ndims is not None: 384 | return list(xrange(x.get_shape().ndims)) 385 | else: 386 | return default_dims 387 | -------------------------------------------------------------------------------- /prettytensor/functions_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Test class for functions.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | 17 | 18 | import numpy 19 | from numpy import linalg 20 | from numpy import testing 21 | import tensorflow as tf 22 | 23 | from prettytensor import functions 24 | 25 | TOLERANCE = 0.00001 26 | 27 | 28 | # Distance functions used in tests. These are defined here instead of using 29 | # scipy so the open source tests don't depend on such a huge module for 3 30 | # 1 line functions. 
31 | def cosine(u, v): # pylint: disable=invalid-name 32 | return 1.0 - numpy.dot(u, v) / (linalg.norm(u, ord=2) * linalg.norm(v, ord=2)) 33 | 34 | 35 | def cityblock(u, v): # pylint: disable=invalid-name 36 | return numpy.abs(u - v).sum() 37 | 38 | 39 | def euclidean(u, v): # pylint: disable=invalid-name 40 | return linalg.norm(u - v, ord=2) 41 | 42 | 43 | class TensorFlowOpTest(tf.test.TestCase): 44 | 45 | def eval_tensor(self, tensors): 46 | if isinstance(tensors, tf.Tensor): 47 | tensors = [tensors] 48 | with self.test_session() as sess: 49 | return sess.run(tensors) 50 | 51 | def test_every_other(self): 52 | tensor = tf.constant([[1, 2], [3, 4]]) 53 | out = self.eval_tensor(functions.every_other(tensor)) 54 | testing.assert_array_equal(out[0], numpy.array([1, 3], dtype=numpy.int32)) 55 | tensor = tf.constant([[1, 2, 3, 4]]) 56 | out = self.eval_tensor(functions.every_other(tensor)) 57 | testing.assert_array_equal(out[0], numpy.array([1, 3], dtype=numpy.int32)) 58 | 59 | def test_l1_regression_loss(self): 60 | ftensor1 = tf.constant([1., 2., 3., 4.]) 61 | ftensor2 = tf.constant([5., 6., 7., -8.]) 62 | out = self.eval_tensor(functions.l1_regression_loss(ftensor1, ftensor2)) 63 | testing.assert_array_equal(out[0], numpy.array([4., 4., 4., 12.])) 64 | 65 | def test_l2_sq_regression_loss(self): 66 | ftensor1 = tf.constant([1., 2., 3., 4.]) 67 | ftensor2 = tf.constant([5., 6., 7., -8.]) 68 | out = self.eval_tensor(functions.l2_regression_sq_loss(ftensor1, ftensor2)) 69 | testing.assert_array_equal(out[0], numpy.array([16., 16., 16, 144])) 70 | 71 | def test_l2_regression_loss(self): 72 | ftensor1 = tf.constant([1., 2., 3., 4.]) 73 | ftensor2 = tf.constant([5., 6., 7., -8.]) 74 | out = self.eval_tensor(functions.l2_regression_loss(ftensor1, ftensor2)) 75 | testing.assert_allclose( 76 | out[0], 77 | numpy.array([4., 4., 4., 12.]), 78 | rtol=TOLERANCE, atol=TOLERANCE) 79 | 80 | def test_binary_cross_entropy_loss_with_logits(self): 81 | n1 = numpy.array([2., 3., 4., 5., -6., -7.], dtype=numpy.float32) 82 | n2 = numpy.array([1., 1., 0., 0., 0., 1.], dtype=numpy.float32) 83 | ftensor1 = tf.constant(n1) 84 | ftensor2 = tf.constant(n2) 85 | out = self.eval_tensor(functions.binary_cross_entropy_loss_with_logits( 86 | ftensor1, ftensor2)) 87 | testing.assert_allclose( 88 | out[0], 89 | n1 * (1-n2) + numpy.log(1 + numpy.exp(-n1)), 90 | rtol=0.00001) 91 | 92 | def test_soft_plus(self): 93 | # 100 overflows naive implementations in float 94 | values = ( 95 | numpy.array( 96 | [-100., -10., 1., 0, 1., 10., 100.], 97 | dtype=numpy.float32)) 98 | out = self.eval_tensor( 99 | functions.softplus( 100 | tf.constant( 101 | values, 102 | dtype=tf.float32), 103 | 1.)) 104 | np_values = numpy.log(1. + numpy.exp(values)) 105 | np_values[6] = 100. 106 | testing.assert_allclose(out[0], np_values, rtol=TOLERANCE, atol=TOLERANCE) 107 | 108 | out = self.eval_tensor(functions.softplus(tf.constant(values), 2.)) 109 | np_values = numpy.log(1. + numpy.exp(values * 2.)) / 2. 110 | np_values[6] = 100. 
111 | testing.assert_allclose(out[0], np_values, rtol=TOLERANCE, atol=TOLERANCE) 112 | 113 | def test_cos_distance(self): 114 | n1 = numpy.array([[1., 2., 3., 4.], [1., 1., 1., 1.]], dtype=numpy.float32) 115 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32) 116 | out = self.eval_tensor(functions.cos_distance(n1, n2)) 117 | 118 | testing.assert_allclose( 119 | out[0], 120 | numpy.array([cosine(n1[0], n2[0]), cosine(n1[1], n2[1])]), 121 | rtol=TOLERANCE, atol=TOLERANCE) 122 | 123 | def test_l1_distance(self): 124 | n1 = numpy.array([[1., 2., 3., 4.], [1., 1., 1., 1.]], dtype=numpy.float32) 125 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32) 126 | out = self.eval_tensor(functions.l1_distance(n1, n2)) 127 | testing.assert_allclose( 128 | out[0], 129 | numpy.array( 130 | [cityblock(n1[0], n2[0]), cityblock(n1[1], n2[1]) 131 | ]), 132 | rtol=TOLERANCE, atol=TOLERANCE) 133 | 134 | def test_l2_distance(self): 135 | n1 = numpy.array([[1., 2., 3., 4.], [1., 1., 1., 1.]], dtype=numpy.float32) 136 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32) 137 | out = self.eval_tensor(functions.l2_distance(n1, n2)) 138 | testing.assert_allclose( 139 | out[0], 140 | numpy.array( 141 | [euclidean(n1[0], n2[0]), 142 | 1e-6 # Epsilon sets the minimum distance so use that instead of 0. 143 | ]), 144 | rtol=TOLERANCE, atol=TOLERANCE) 145 | 146 | def test_l2_distance_sq(self): 147 | n1 = numpy.array([[1., 2., 3., 4.], [1., 1., 1., 1.]], dtype=numpy.float32) 148 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32) 149 | out = self.eval_tensor(functions.l2_distance_sq(n1, n2)) 150 | testing.assert_allclose( 151 | out[0], 152 | numpy.power( 153 | numpy.array( 154 | [euclidean(n1[0], n2[0]), euclidean( 155 | n1[1], n2[1])]), 2), 156 | rtol=TOLERANCE, atol=TOLERANCE) 157 | 158 | def test_dot_distance(self): 159 | n1 = numpy.array([[1., 2., 3., 4.], [1., 1., 1., 1.]], dtype=numpy.float32) 160 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32) 161 | out = self.eval_tensor(functions.dot_distance(n1, n2)) 162 | testing.assert_allclose( 163 | out[0], 164 | numpy.array(-numpy.sum(n1 * n2, 165 | axis=1)), 166 | rtol=TOLERANCE, atol=TOLERANCE) 167 | 168 | def test_cos_distance_with_broadcast(self): 169 | n1 = numpy.array([[[1., 2., 3., 4.], [1., 1., 1., 1.]], [[5., 6., 7., 8.], 170 | [1., 1., 1., 2.]]], 171 | dtype=numpy.float32) 172 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32) 173 | out = self.eval_tensor(functions.cos_distance(n1, n2)) 174 | expected = numpy.array( 175 | [[cosine(n1[0, 0], n2[0]), cosine(n1[0, 1], n2[1])], 176 | [cosine(n1[1, 0], n2[0]), cosine(n1[1, 1], n2[1])]]) 177 | testing.assert_allclose(expected, out[0], atol=TOLERANCE) 178 | 179 | def test_l1_distance_with_broadcast(self): 180 | n1 = numpy.array([[[1., 2., 3., 4.], [1., 1., 1., 1.]], [[5., 6., 7., 8.], 181 | [1., 1., 1., 2.]]], 182 | dtype=numpy.float32) 183 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32) 184 | out = self.eval_tensor(functions.l1_distance(n1, n2)) 185 | expected = numpy.array( 186 | [[cityblock(n1[0, 0], n2[0]), cityblock( 187 | n1[0, 1], n2[1])], [cityblock(n1[1, 0], n2[0]), 188 | cityblock(n1[1, 1], n2[1])]]) 189 | testing.assert_allclose(expected, out[0], atol=TOLERANCE) 190 | 191 | def test_l2_distance_with_broadcast(self): 192 | n1 = numpy.array([[[1., 2., 3., 4.], [1., 1., 1., 1.]], [[5., 6., 7., 8.], 193 | [1., 1., 1., 2.]]], 
194 | dtype=numpy.float32) 195 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32) 196 | out = self.eval_tensor(functions.l2_distance(n1, n2)) 197 | expected = numpy.array( 198 | [[euclidean(n1[0, 0], n2[0]), euclidean( 199 | n1[0, 1], n2[1])], [euclidean(n1[1, 0], n2[0]), 200 | euclidean(n1[1, 1], n2[1])]]) 201 | testing.assert_allclose(expected, out[0], atol=TOLERANCE) 202 | 203 | def test_l2_distance_sq_with_broadcast(self): 204 | n1 = numpy.array([[[1., 2., 3., 4.], [1., 1., 1., 1.]], [[5., 6., 7., 8.], 205 | [1., 1., 1., 2.]]], 206 | dtype=numpy.float32) 207 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32) 208 | out = self.eval_tensor(functions.l2_distance_sq(n1, n2)) 209 | expected = numpy.array( 210 | [[euclidean(n1[0, 0], n2[0]), euclidean( 211 | n1[0, 1], n2[1])], [euclidean(n1[1, 0], n2[0]), 212 | euclidean(n1[1, 1], n2[1])]]) 213 | expected = numpy.power(expected, 2) 214 | testing.assert_allclose(expected, out[0], atol=TOLERANCE) 215 | 216 | def test_dot_distance_with_broadcast(self): 217 | n1 = numpy.array([[[1., 2., 3., 4.], [1., 1., 1., 1.]], [[5., 6., 7., 8.], 218 | [1., 1., 1., 2.]]], 219 | dtype=numpy.float32) 220 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32) 221 | out = self.eval_tensor(functions.dot_distance(n1, n2)) 222 | testing.assert_allclose( 223 | out[0], 224 | numpy.array(-numpy.sum(n1 * n2, 225 | axis=2)), 226 | rtol=TOLERANCE, atol=TOLERANCE) 227 | 228 | def test_l2_normalize(self): 229 | n1 = numpy.array([[1., 2., 3., 4.], [1., 1., 1., 1.]], dtype=numpy.float32) 230 | t1 = tf.constant(n1) 231 | out = self.eval_tensor(functions.l2_normalize(t1, 1)) 232 | testing.assert_allclose( 233 | out[0], 234 | n1 / linalg.norm(n1, 2, axis=1).reshape((2, 1)), 235 | rtol=TOLERANCE, atol=TOLERANCE) 236 | 237 | def test_l1_normalize(self): 238 | n1 = numpy.array([[1., 2., 3., 4.], [1., 1., 1., 1.]], dtype=numpy.float32) 239 | t1 = tf.constant(n1) 240 | out = self.eval_tensor(functions.l1_normalize(t1, 1)) 241 | testing.assert_allclose( 242 | out[0], 243 | n1 / linalg.norm(n1, 1, axis=1).reshape((2, 1)), 244 | rtol=TOLERANCE, atol=TOLERANCE) 245 | 246 | def test_leaky_relu(self): 247 | values = ( 248 | numpy.array( 249 | [-100., -10., 1., 0, 1., 10., 100.], 250 | dtype=numpy.float32)) 251 | tensor = tf.constant(values) 252 | out = self.eval_tensor(functions.leaky_relu(tensor)) 253 | for i, value in enumerate(values): 254 | if value < 0: 255 | values[i] *= 0.01 256 | testing.assert_allclose(out[0], values, rtol=TOLERANCE, atol=TOLERANCE) 257 | 258 | def test_unzip(self): 259 | n1 = numpy.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.]], 260 | dtype=numpy.float32) 261 | t1 = tf.constant(n1) 262 | out = self.eval_tensor(functions.unzip(t1, 0, 4, 2)) 263 | 264 | expected = numpy.array([[1., 2.], [5., 6.]], dtype=numpy.float32) 265 | testing.assert_allclose(expected, out[0], rtol=TOLERANCE, atol=TOLERANCE) 266 | expected = numpy.array([[3., 4.], [7., 8.]], dtype=numpy.float32) 267 | testing.assert_allclose(expected, out[1], rtol=TOLERANCE, atol=TOLERANCE) 268 | 269 | def test_split(self): 270 | """Testing TF functionality to highlight difference with Unzip.""" 271 | n1 = numpy.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.]], 272 | dtype=numpy.float32) 273 | t1 = tf.constant(n1) 274 | out = self.eval_tensor(tf.split(value=t1, num_or_size_splits=2, axis=0)) 275 | expected = numpy.array([[1., 2.], [3., 4.]], dtype=numpy.float32) 276 | testing.assert_allclose(expected, out[0], rtol=TOLERANCE, 
atol=TOLERANCE) 277 | expected = numpy.array([[5., 6.], [7., 8.]], dtype=numpy.float32) 278 | testing.assert_allclose(expected, out[1], rtol=TOLERANCE, atol=TOLERANCE) 279 | 280 | 281 | if __name__ == '__main__': 282 | tf.test.main() 283 | -------------------------------------------------------------------------------- /prettytensor/input_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Provides helpers for feeding in numpy data to a TF graph. 12 | 13 | These methods are intended to aid experimentation. For large datasets consider 14 | using readers and queues. 15 | """ 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import itertools 21 | 22 | 23 | 24 | from six.moves import xrange # pylint: disable=redefined-builtin 25 | import tensorflow as tf 26 | 27 | from prettytensor import bookkeeper 28 | 29 | 30 | def feed_numpy(batch_size, *arrays): 31 | """Given a set of numpy arrays, produce slices of batch_size. 32 | 33 | Note: You can use itertools.cycle to have this repeat forever. 34 | 35 | Args: 36 | batch_size: The batch_size for each array. 37 | *arrays: A list of arrays. 38 | Yields: 39 | A list of slices from the arrays of length batch_size except the last one 40 | which will contain the rest. 41 | Raises: 42 | ValueError: If arrays aren't all the same length or no arrays are provided. 43 | """ 44 | if not arrays: 45 | raise ValueError('Arrays cannot be empty.') 46 | size = len(arrays[0]) 47 | for a in arrays: 48 | if size != len(a): 49 | raise ValueError('All arrays must be the same size.') 50 | count = int(size / batch_size) 51 | 52 | for i in xrange(count): 53 | start = i * batch_size 54 | end = start + batch_size 55 | yield [x[start:end] for x in arrays] 56 | if count * batch_size < size: 57 | yield [x[end:] for x in arrays] 58 | 59 | 60 | def batch(input_iter, batch_size=32): 61 | """Batches data from an iterator that returns single items at a time.""" 62 | input_iter = iter(input_iter) 63 | next_ = list(itertools.islice(input_iter, batch_size)) 64 | while next_: 65 | yield next_ 66 | next_ = list(itertools.islice(input_iter, batch_size)) 67 | 68 | 69 | def slice_constant(data, batch_size=32, name='constant_data', global_step=None): 70 | """Provide a slice based on the global_step. 71 | 72 | This is useful when the entire data array can be stored in memory because it 73 | allows you to feed the data very efficiently. 74 | 75 | Args: 76 | data: A numpy array or tensor. 77 | batch_size: The batch size for the produced data. 78 | name: An optional name for this data. 79 | global_step: A global step variable that is used to read the data. If None 80 | then the default prettytensor global_step is used. 81 | Returns: 82 | A tensor that produces the given data. 
83 | """ 84 | with tf.name_scope(name): 85 | all_data = tf.convert_to_tensor(data) 86 | global_step = global_step or bookkeeper.global_step() 87 | 88 | count = len(data) / batch_size 89 | extra = len(data) - count * batch_size 90 | 91 | if extra: 92 | offset = tf.mod(global_step, count) 93 | return tf.slice(all_data, offset * batch_size, batch_size) 94 | else: 95 | offset = tf.mod(global_step, count + 1) 96 | return tf.slice(all_data, offset * batch_size, 97 | tf.where(tf.equal(offset, count), extra, batch_size)) 98 | -------------------------------------------------------------------------------- /prettytensor/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Utility functions for adding layers to a Model. 12 | 13 | NB: This is used by PrettyTensor, but it will be deprecated. Please do not use! 14 | """ 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import collections 20 | import math 21 | 22 | import tensorflow as tf 23 | 24 | from prettytensor import functions 25 | 26 | # Implementation note: this takes a tuple for an activation instead of 27 | # encouraging lambdas so that we can inspect the actual function and add 28 | # appropriate summaries. 29 | 30 | 31 | def apply_activation( 32 | books, 33 | x, 34 | activation, 35 | activation_args=(), 36 | activation_kwargs=None): 37 | """Returns activation(x, *activation_args, **activation_kwargs). 38 | 39 | This applies the given activation and adds useful summaries specific to the 40 | activation. 41 | 42 | Args: 43 | books: The bookkeeper. 44 | x: The tensor to apply activation to. 45 | activation: An activation function. 46 | activation_args: Optional additional arguments for the activation. 47 | activation_kwargs: Optional keyword args for activation. 48 | Returns: 49 | A tensor with activation applied to x. 
50 | """ 51 | if activation is None: 52 | return x 53 | if activation_kwargs is None: 54 | activation_kwargs = {} 55 | y = activation(x, *activation_args, **activation_kwargs) 56 | if activation in (tf.nn.relu, functions.leaky_relu, functions.softplus): 57 | books.add_scalar_summary( 58 | tf.reduce_mean(tf.cast(tf.less(x, 0.0), tf.float32)), 59 | '%s/zeros' % y.op.name) 60 | elif activation is tf.nn.relu6: 61 | books.add_scalar_summary( 62 | tf.reduce_mean(tf.cast(tf.less(x, 0.0), tf.float32)), 63 | '%s/zeros' % y.op.name) 64 | books.add_scalar_summary( 65 | tf.reduce_mean(tf.cast(tf.greater(x, 6.0), tf.float32)), 66 | '%s/sixes' % y.op.name) 67 | elif activation in (functions.l2_normalize, tf.nn.l2_normalize, 68 | functions.l1_normalize): 69 | books.add_scalar_summary( 70 | tf.reduce_mean(tf.sqrt(tf.reduce_sum( 71 | tf.square(x), 1))), '%s/length' % y.op.name) 72 | return y 73 | 74 | 75 | def add_l2loss(books, params, l2loss, name='weight_decay'): 76 | if l2loss: 77 | books.add_loss( 78 | tf.multiply( 79 | tf.nn.l2_loss(params), l2loss, name=name), 80 | regularization=True, 81 | add_summaries=False) 82 | 83 | 84 | def he_init(n_inputs, n_outputs, activation_fn, uniform=True): 85 | """Sets the parameter initialization using the method described. 86 | 87 | This method is designed to keep the scale of the gradients roughly the same 88 | in all layers with ReLU activations. 89 | 90 | He et al. (2015): 91 | Delving deep into rectifiers: surpassing human-level performance on 92 | imageNet classification. International Conference on Computer Vision. 93 | 94 | For activations other than ReLU and ReLU6, this method uses Xavier 95 | initialization as in xavier_init(). 96 | 97 | Args: 98 | n_inputs: The number of input nodes into each output. 99 | n_outputs: The number of output nodes for each input. 100 | activation_fn: Activation function used in this layer. 101 | uniform: If uniform distribution will be used for Xavier initialization. 102 | Normal distribution will be used if False. 103 | Returns: 104 | An initializer. 105 | """ 106 | def in_relu_family(activation_fn): 107 | if isinstance(activation_fn, collections.Sequence): 108 | activation_fn = activation_fn[0] 109 | return activation_fn in (tf.nn.relu, tf.nn.relu6) 110 | 111 | if in_relu_family(activation_fn): 112 | stddev = math.sqrt(2.0 / n_inputs) 113 | # TODO(): Evaluates truncated_normal_initializer. 114 | return tf.random_normal_initializer(stddev=stddev) 115 | else: 116 | return xavier_init(n_inputs, n_outputs, uniform) 117 | 118 | 119 | def xavier_init(n_inputs, n_outputs, uniform=True): 120 | """Set the parameter initialization using the method described. 121 | 122 | This method is designed to keep the scale of the gradients roughly the same 123 | in all layers. 124 | 125 | Xavier Glorot and Yoshua Bengio (2010): 126 | Understanding the difficulty of training deep feedforward neural 127 | networks. International conference on artificial intelligence and 128 | statistics. 129 | Args: 130 | n_inputs: The number of input nodes into each output. 131 | n_outputs: The number of output nodes for each input. 132 | uniform: If true use a uniform distribution, otherwise use a normal. 133 | Returns: 134 | An initializer. 135 | """ 136 | if uniform: 137 | # 6 was used in the paper. 
138 | init_range = math.sqrt(6.0 / (n_inputs + n_outputs)) 139 | return tf.random_uniform_initializer(-init_range, init_range) 140 | else: 141 | # 3 gives us approximately the same limits as above since this repicks 142 | # values greater than 2 standard deviations from the mean. 143 | stddev = math.sqrt(3.0 / (n_inputs + n_outputs)) 144 | return tf.truncated_normal_initializer(stddev=stddev) 145 | 146 | 147 | def spatial_slice_zeros(x): 148 | """Experimental summary that shows how many planes are unused for a batch.""" 149 | return tf.cast(tf.reduce_all(tf.less_equal(x, 0.0), [0, 1, 2]), 150 | tf.float32) 151 | -------------------------------------------------------------------------------- /prettytensor/local_trainer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Tests for local_trainer.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import itertools 17 | import os 18 | import shutil 19 | import tempfile 20 | import threading 21 | import unittest 22 | 23 | 24 | 25 | import numpy 26 | import tensorflow as tf 27 | 28 | import prettytensor as pt 29 | from prettytensor import input_helpers 30 | from prettytensor import local_trainer 31 | 32 | 33 | class LocalTrainerTest(unittest.TestCase): 34 | 35 | def random_numpy(self, shape, dtype, partition_info=None): 36 | _ = partition_info 37 | if tf.float32.is_compatible_with(dtype): 38 | size = 1 39 | for n in shape: 40 | size *= n 41 | return self.prng.normal(size=size).astype(numpy.float32).reshape(shape) 42 | else: 43 | raise ValueError('This method only supports float32: %s' % dtype) 44 | 45 | def setUp(self): 46 | tf.reset_default_graph() 47 | self.prng = numpy.random.RandomState(42) 48 | 49 | self.input = tf.placeholder(tf.float32, [4, 2]) 50 | self.target = tf.placeholder(tf.float32) 51 | xor_inputs = numpy.array([[0., 0.], [1., 0.], [1., 1.], [0., 1.]]) 52 | xor_outputs = numpy.array([[0., 1.], [1., 0.], [0., 1.], [1., 0.]]) 53 | 54 | self.xor_data = itertools.cycle( 55 | input_helpers.feed_numpy(4, xor_inputs, xor_outputs)) 56 | 57 | self.softmax_result = ( 58 | pt.wrap(self.input).fully_connected(2, 59 | activation_fn=tf.sigmoid, 60 | weights=self.random_numpy) 61 | .fully_connected(2, 62 | activation_fn=None, 63 | weights=self.random_numpy).softmax(self.target)) 64 | self.tmp_file = tempfile.mkdtemp() 65 | 66 | def tearDown(self): 67 | shutil.rmtree(self.tmp_file) 68 | 69 | def test_run(self): 70 | runner = local_trainer.Runner() 71 | with tf.Session(): 72 | optimizer = tf.train.GradientDescentOptimizer(0.5) 73 | train_op = pt.apply_optimizer(optimizer, 74 | losses=[self.softmax_result.loss]) 75 | 76 | runner.train_model(train_op, 77 | self.softmax_result.loss, 78 | 10, 79 | (self.input, self.target), 80 | self.xor_data, 81 | print_every=2) 82 | 83 | def test_checkpoint(self): 84 | f 
= os.path.join(self.tmp_file, 'checkpoint') 85 | runner = local_trainer.Runner(save_path=f) 86 | with tf.Session(): 87 | optimizer = tf.train.GradientDescentOptimizer(0.1) 88 | train_op = pt.apply_optimizer(optimizer, 89 | losses=[self.softmax_result.loss]) 90 | 91 | runner.train_model(train_op, 92 | self.softmax_result.loss, 93 | 10, 94 | (self.input, self.target), 95 | self.xor_data, 96 | print_every=2) 97 | assert runner._saver.last_checkpoints, 'Expected checkpoints.' 98 | for x in runner._saver.last_checkpoints: 99 | self.assertTrue(tf.train.checkpoint_exists(x), 100 | 'Promised file not saved: %s' % x) 101 | self.assertTrue(x.startswith(f), 'Name not as expected: %s' % x) 102 | 103 | def test_eval(self): 104 | f = os.path.join(self.tmp_file, 'checkpoint') 105 | runner = local_trainer.Runner(save_path=f) 106 | with tf.Session(): 107 | classification_acuracy = self.softmax_result.softmax.evaluate_classifier( 108 | self.target, phase=pt.Phase.test) 109 | 110 | optimizer = tf.train.GradientDescentOptimizer(0.2) 111 | train_op = pt.apply_optimizer(optimizer, 112 | losses=[self.softmax_result.loss]) 113 | 114 | runner.train_model(train_op, 115 | self.softmax_result.loss, 116 | 100, 117 | (self.input, self.target), 118 | self.xor_data, 119 | print_every=50) 120 | self.assertTrue(runner._last_init) 121 | save_paths = list(runner._saver.last_checkpoints) 122 | 123 | # The accuracy should be 50% right now since model is consistently 124 | # generated. 125 | accuracy = runner.evaluate_model(classification_acuracy, 126 | 1, 127 | (self.input, self.target), 128 | self.xor_data) 129 | self.assertEquals(runner._saver.last_checkpoints, save_paths, 130 | 'No additional paths should have been saved.') 131 | self.assertFalse(runner._last_init) 132 | self.assertEqual(accuracy, [0.5]) 133 | 134 | # Train the model to 100% accuracy. 135 | runner.train_model(train_op, 136 | self.softmax_result.loss, 137 | 2000, 138 | (self.input, self.target), 139 | self.xor_data, 140 | print_every=1000) 141 | accuracy = runner.evaluate_model(classification_acuracy, 1, 142 | (self.input, self.target), self.xor_data) 143 | self.assertFalse(runner._last_init) 144 | 145 | # Make sure that the previous computation didn't impact this eval. 146 | self.assertEqual(accuracy, [1.0]) 147 | 148 | def restore_helper(self, runner): 149 | with tf.Session(): 150 | classification_acuracy = self.softmax_result.softmax.evaluate_classifier( 151 | self.target, phase=pt.Phase.test) 152 | 153 | optimizer = tf.train.GradientDescentOptimizer(0.5) 154 | train_op = pt.apply_optimizer(optimizer, 155 | losses=[self.softmax_result.loss]) 156 | 157 | runner.train_model(train_op, 158 | self.softmax_result.loss, 159 | 10, 160 | (self.input, self.target), 161 | self.xor_data, 162 | print_every=2) 163 | self.assertTrue(runner._last_init) 164 | self.assertFalse(runner._last_restore) 165 | with tf.Session(): 166 | save_paths = list(runner._saver.last_checkpoints) 167 | runner.evaluate_model(classification_acuracy, 1, 168 | (self.input, self.target), self.xor_data) 169 | self.assertEquals(runner._saver.last_checkpoints, save_paths, 170 | 'No additional paths should have been saved.') 171 | self.assertFalse(runner._last_init) 172 | 173 | def test_manual_save_restore(self): 174 | runner = local_trainer.Runner() 175 | f = os.path.join(self.tmp_file, 'manual.chkpt') 176 | 177 | v = tf.Variable(tf.random_normal(shape=[100], dtype=tf.float32)) 178 | 179 | # Save it. 
180 | with runner.session() as sess: 181 | runner.prepare_model(sess) # Create variables 182 | value = v.eval() # Grab the variable 183 | runner.saver.save(sess, f) 184 | 185 | with runner.session() as sess: 186 | # Restore the model 187 | runner.saver.restore(sess, f) 188 | new_value = v.eval() 189 | numpy.testing.assert_array_equal(value, new_value) 190 | 191 | def test_restore(self): 192 | f = os.path.join(self.tmp_file, 'checkpoint') 193 | runner = local_trainer.Runner(save_path=f) 194 | self.restore_helper(runner) 195 | self.assertTrue(runner._last_restore) 196 | 197 | def test_not_restored(self): 198 | f = os.path.join(self.tmp_file, 'checkpoint') 199 | runner = local_trainer.Runner(save_path=f, restore=False) 200 | with self.assertRaises(tf.errors.FailedPreconditionError): 201 | self.restore_helper(runner) 202 | 203 | def test_evaluate_without_initialize_error(self): 204 | with tf.Graph().as_default(): 205 | runner = local_trainer.Runner() 206 | tf.Variable(1) # Put a variable in the graph. 207 | 208 | with runner.session(), self.assertRaises(ValueError): 209 | runner.evaluate_model( 210 | self.softmax_result, 1, (self.input, self.target), self.xor_data) 211 | 212 | def test_evaluate_repeatedly_one_time(self): 213 | f = os.path.join(self.tmp_file, 'checkpoint') 214 | runner = local_trainer.Runner(save_path=f) 215 | self.restore_helper(runner) 216 | local_variable = tf.Variable(22, collections=[tf.GraphKeys.LOCAL_VARIABLES]) 217 | accuracy = local_variable.assign_add(1) 218 | 219 | answer = runner.evaluate_repeatedly(accuracy, 20, evaluation_times=1) 220 | self.assertEqual([42], answer) 221 | 222 | def test_queues(self): 223 | qr = FakeQueueRunner() 224 | tf.train.add_queue_runner(qr) 225 | runner = local_trainer.Runner() 226 | with tf.Session(): 227 | optimizer = tf.train.GradientDescentOptimizer(0.5) 228 | train_op = pt.apply_optimizer(optimizer, 229 | losses=[self.softmax_result.loss]) 230 | 231 | runner.train_model(train_op, 232 | self.softmax_result.loss, 233 | 100, 234 | (self.input, self.target), 235 | self.xor_data, 236 | print_every=2) 237 | with tf.Session(): 238 | with self.assertRaisesRegexp(ValueError, r'.*\bstop_queues\b.*'): 239 | runner.train_model(train_op, 240 | self.softmax_result.loss, 241 | 100, 242 | (self.input, self.target), 243 | self.xor_data, 244 | print_every=2) 245 | 246 | runner.stop_queues() 247 | qr.assert_worked(self) 248 | 249 | def test_queue_error(self): 250 | qr = FakeQueueRunner(RuntimeError('expected')) 251 | tf.train.add_queue_runner(qr) 252 | runner = local_trainer.Runner() 253 | with tf.Session(): 254 | optimizer = tf.train.GradientDescentOptimizer(0.5) 255 | train_op = pt.apply_optimizer(optimizer, 256 | losses=[self.softmax_result.loss]) 257 | 258 | with self.assertRaisesRegexp(RuntimeError, 'expected'): 259 | runner.train_model(train_op, 260 | self.softmax_result.loss, 261 | 100, 262 | (self.input, self.target), 263 | self.xor_data, 264 | print_every=2) 265 | qr.assert_worked(self) 266 | 267 | 268 | class FakeQueueRunner(object): 269 | called = 0 270 | stopped = False 271 | 272 | def __init__(self, error=None): 273 | self.error = error 274 | 275 | def create_threads(self, sess, coord=None, daemon=False, start=False): # pylint: disable=unused-argument 276 | self.called += 1 277 | threads = [threading.Thread(target=self.set_stopped, args=(coord,))] 278 | if self.error: 279 | threads.append(threading.Thread(target=self.die, 280 | args=(coord, self.error))) 281 | if start: 282 | for t in threads: 283 | t.start() 284 | return threads 285 | 286 
| def die(self, coord, error): 287 | try: 288 | raise error 289 | except RuntimeError as e: 290 | coord.request_stop(e) 291 | 292 | def set_stopped(self, coord): 293 | coord.wait_for_stop() 294 | self.stopped = True 295 | 296 | def assert_worked(self, test): 297 | test.assertEqual(1, self.called) 298 | test.assertTrue(self.stopped) 299 | 300 | if __name__ == '__main__': 301 | tf.test.main() 302 | -------------------------------------------------------------------------------- /prettytensor/parameters.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | r"""Provides some standard functions to modify parameter variables. 12 | 13 | These are applied before the variables are used within the graph; the basic 14 | function signature is: 15 | 16 | def my_func(var_name, variable, phase): 17 | \"""A function to apply to a model's parameters. 18 | 19 | Args: 20 | var_name: The short name for the variable. 21 | variable: A `Tensor` that can be used in the model. Note: this is often 22 | a tf.Variable, but may not be and the name only usually contains 23 | var_name (it doesn't in the case of reuse). 24 | phase: The phase of model construction. 25 | 26 | Returns: 27 | A `Variable` or `Tensor` with the same shape and type as `variable` to 28 | use. 29 | \""" 30 | return something_done_to_variable 31 | """ 32 | 33 | import re 34 | 35 | 36 | import tensorflow as tf 37 | 38 | from prettytensor import pretty_tensor_class as pt 39 | 40 | 41 | def identity(unused_var_name, variable, unused_phase): 42 | return variable 43 | 44 | 45 | def regularizer(name, regularization_fn, name_filter='weights'): 46 | """Wraps a regularizer in a parameter-function. 47 | 48 | Args: 49 | name: The name scope for this regularizer. 50 | regularization_fn: A function with signature: 51 | fn(variable) -> loss `Tensor` or `None`. 52 | name_filter: A regex that will be used to filter variables by name. 53 | 54 | Returns: 55 | A parameter modification function that adds the loss to the 56 | REGULARIZATION_LOSSES graph key. 
57 | """ 58 | regex = re.compile(name_filter) 59 | def fn(var_name, variable, phase): 60 | if phase is pt.Phase.train and regex.search(var_name): 61 | with tf.name_scope(None, name, [variable]): 62 | loss = regularization_fn(variable) 63 | if loss is not None: 64 | tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, loss) 65 | return variable 66 | return fn 67 | 68 | 69 | def l2_regularizer(decay, name_filter='weights'): 70 | """Create an l2 regularizer.""" 71 | return regularizer( 72 | 'l2_regularizer', 73 | lambda x: tf.nn.l2_loss(x) * decay, 74 | name_filter=name_filter) 75 | 76 | 77 | def l1_regularizer(decay, name_filter='weights'): 78 | """Create an l1 regularizer.""" 79 | return regularizer( 80 | 'l1_regularizer', 81 | lambda x: tf.reduce_sum(tf.abs(x)) * decay, 82 | name_filter=name_filter) 83 | 84 | 85 | def compose(*parameter_functions): 86 | """Composes multiple modification functions in order. 87 | 88 | Args: 89 | *parameter_functions: The functions to compose. 90 | 91 | Returns: 92 | A parameter modification function that consists of applying all the provided 93 | functions. 94 | """ 95 | def composed_fn(var_name, variable, phase): 96 | for fn in parameter_functions: 97 | variable = fn(var_name, variable, phase) 98 | return variable 99 | return composed_fn 100 | 101 | 102 | class Noise(object): 103 | """Regularize the model by applying gaussian noise to the variables.""" 104 | 105 | def __init__(self, stddev): 106 | self._stddev = stddev 107 | 108 | def __call__(self, var_name, variable, phase): 109 | if phase is pt.Phase.train: 110 | return variable * tf.random_normal( 111 | tf.shape(variable), mean=1., stddev=self._stddev) 112 | else: 113 | return variable 114 | 115 | 116 | class DropConnect(object): 117 | """Drop out some connections. 118 | 119 | See the paper: http://www.matthewzeiler.com/pubs/icml2013/icml2013.pdf 120 | """ 121 | 122 | def __init__(self, keep_prob): 123 | self._keep_prob = keep_prob 124 | 125 | def __call__(self, var_name, variable, phase): 126 | if 'bias' not in var_name or phase is not pt.Phase.train: 127 | return variable 128 | return tf.nn.dropout(variable, self._keep_prob) 129 | -------------------------------------------------------------------------------- /prettytensor/pretty_tensor_normalization_methods.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | """Batch Normalization and eventually some friends for PrettyTensor.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import collections 17 | 18 | import tensorflow as tf 19 | 20 | from prettytensor import pretty_tensor_class as prettytensor 21 | from prettytensor.pretty_tensor_class import Phase 22 | from prettytensor.pretty_tensor_class import PROVIDED 23 | 24 | 25 | BatchNormalizationArguments = collections.namedtuple( 26 | 'BatchNormalizationArguments', 27 | ('learned_moments_update_rate', 'variance_epsilon', 28 | 'scale_after_normalization')) 29 | 30 | BatchNormalizationArguments.__new__.__defaults__ = (None, None, None) 31 | 32 | 33 | def batch_normalize_with_arguments(x, arguments): 34 | """Applies batch normalization to x as specified in arguments. 35 | 36 | Args: 37 | x: A Pretty Tensor. 38 | arguments: Either a boolean to batch_normalize or a 39 | BatchNormalizationArguments 40 | 41 | Returns: 42 | x with batch normalization applied. 43 | """ 44 | x = prettytensor.wrap(x) 45 | # Backwards compatibility. 46 | if isinstance(arguments, bool): 47 | if arguments: 48 | return x.batch_normalize() 49 | else: 50 | return x 51 | 52 | # pylint: disable=protected-access 53 | kwargs = arguments._asdict() 54 | defaults = prettytensor._defaults 55 | # pylint: enable=protected-access 56 | for arg in ('learned_moments_update_rate', 'variance_epsilon', 57 | 'scale_after_normalization'): 58 | if kwargs.get(arg, None) is None: 59 | if arg in defaults: 60 | kwargs[arg] = defaults[arg] 61 | else: 62 | del kwargs[arg] 63 | return x.batch_normalize(**kwargs) 64 | 65 | 66 | # pylint: disable=invalid-name 67 | @prettytensor.Register( 68 | assign_defaults=('learned_moments_update_rate', 'variance_epsilon', 69 | 'scale_after_normalization', 'phase')) 70 | class batch_normalize(prettytensor.VarStoreMethod): 71 | 72 | def __call__(self, 73 | input_layer, 74 | name=PROVIDED, 75 | learned_moments_update_rate=0.0003, 76 | variance_epsilon=0.001, 77 | scale_after_normalization=False, 78 | phase=Phase.train): 79 | """Batch normalize this layer. 80 | 81 | This only supports global batch normalization and it can be enabled for all 82 | convolutional layers by setting the default 'batch_normalize' to True. 83 | learned_moments_update_rate, variance_epsilon and scale_after_normalization 84 | need to either be set here or be set in defaults as well. 85 | 86 | Args: 87 | input_layer: The chainable object, supplied. 88 | name: The name for this operation is also used to create/find the 89 | parameter variables. 90 | learned_moments_update_rate: Update rate for the learned moments. 91 | variance_epsilon: A float. A small float number to avoid dividing by 0. 92 | scale_after_normalization: A bool indicating whether the resulted tensor 93 | needs to be multiplied with gamma. 94 | phase: The phase of construction. 95 | Returns: 96 | Handle to the generated layer. 97 | """ 98 | # Allocate variables to hold the moving averages of the moments. 99 | params_shape = [input_layer.shape[-1]] 100 | 101 | # Allocate parameters for the beta and gamma of the normalization. 
102 | beta = self.variable('beta', params_shape, tf.constant_initializer(0.0)) 103 | if scale_after_normalization: 104 | gamma = self.variable('gamma', params_shape, tf.constant_initializer(1.0)) 105 | else: 106 | gamma = None 107 | moving_mean = self.variable('moving_mean', 108 | params_shape, 109 | tf.constant_initializer(0.0), 110 | train=False) 111 | moving_variance = self.variable('moving_variance', 112 | params_shape, 113 | tf.constant_initializer(1.0), 114 | train=False) 115 | 116 | if phase == Phase.train: 117 | # Calculate the moments based on the individual batch. 118 | mean, variance = tf.nn.moments( 119 | input_layer.tensor, list(range(len(input_layer.get_shape()) - 1))) 120 | input_layer.bookkeeper.add_histogram_summary(mean) 121 | input_layer.bookkeeper.add_histogram_summary(variance) 122 | 123 | avg_mean = input_layer.bookkeeper.exponential_moving_average( 124 | mean, moving_mean, 1.0 - learned_moments_update_rate) 125 | avg_variance = input_layer.bookkeeper.exponential_moving_average( 126 | variance, moving_variance, 1.0 - learned_moments_update_rate) 127 | with tf.control_dependencies([avg_variance, avg_mean]): 128 | y = tf.nn.batch_normalization(input_layer, mean, variance, beta, gamma, 129 | variance_epsilon) 130 | else: 131 | # Load the mean and variance as the 'moving average' moments 132 | # from the checkpoint. 133 | y = tf.nn.batch_normalization(input_layer, moving_mean, moving_variance, 134 | beta, gamma, variance_epsilon) 135 | 136 | return input_layer.with_tensor(y) 137 | # pylint: enable=invalid-name 138 | -------------------------------------------------------------------------------- /prettytensor/pretty_tensor_sparse_methods.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Holds PrettyTensor methods related to sparse data types.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import tensorflow as tf 17 | 18 | from prettytensor import pretty_tensor_class as prettytensor 19 | 20 | 21 | @prettytensor.Register 22 | def to_dense_one_hot(labels, class_count): 23 | """Converts a vector that specified one-hot per batch into a dense version. 24 | 25 | Args: 26 | labels: The labels input. 27 | class_count: The number of classes as an int. 28 | Returns: 29 | One dense vector for each item in the batch. 30 | Raises: 31 | ValueError: If labels is not rank 1. 32 | TypeError: If class_count is not an integer or labels is not an integer 33 | Tensor. 
34 | """ 35 | if not isinstance(class_count, tf.compat.integral_types): 36 | raise TypeError('class_count must be an integer type.') 37 | if labels.dtype.base_dtype not in (tf.int32, tf.int64): 38 | raise TypeError('Labels must be an integer: %s' % labels.dtype) 39 | if labels.get_shape().ndims != 1: 40 | raise ValueError('Labels must be a rank 1 tensor: %s' % labels.get_shape()) 41 | 42 | dtype = labels.dtype.base_dtype 43 | class_tensor = tf.convert_to_tensor( 44 | class_count, dtype=dtype, name='class_count') 45 | 46 | # Extract the batch from the shape so this is batch independent. 47 | batch = tf.gather(tf.shape(labels), 0) 48 | count = tf.expand_dims(tf.range(0, limit=batch), 1) 49 | labels = tf.expand_dims(labels, 1) 50 | batch = tf.gather(tf.shape(labels), 0) 51 | 52 | if dtype != tf.int32: 53 | count = tf.cast(count, dtype) 54 | batch = tf.cast(batch, dtype) 55 | 56 | result = tf.sparse_to_dense( 57 | tf.concat([count, labels], 1), 58 | tf.concat([tf.expand_dims(batch, 0), tf.expand_dims(class_tensor, 0)], 0), 59 | 1.0, 0.0) 60 | result.set_shape([labels.get_shape().dims[0], class_count]) 61 | return result 62 | -------------------------------------------------------------------------------- /prettytensor/pretty_tensor_testing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Test class for PrettyTensor.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import unittest 17 | import six 18 | import tensorflow as tf 19 | 20 | import prettytensor 21 | 22 | 23 | # Count of tests for unique root namespaces. 24 | _count = 0 25 | 26 | 27 | class PtTestCase(unittest.TestCase): 28 | """Contains shared setUp/tearDown and convenience methods. 
29 | 30 | This adds the following attributes to self: 31 | 32 | self.bookkeeper 33 | self.sess 34 | """ 35 | 36 | def __init__(self, *args): 37 | super(PtTestCase, self).__init__(*args) 38 | self.bookkeeper = None 39 | self.sess = None 40 | self._graph_with = None 41 | 42 | def RunTensor(self, layer, init=True): 43 | """Convenience method to run a tensor.""" 44 | if init: 45 | self.sess.run(tf.global_variables_initializer()) 46 | if isinstance(layer, (tf.Tensor, six.string_types)): 47 | return self.sess.run(layer) 48 | elif layer.is_sequence(): 49 | return self.sess.run(layer.sequence) 50 | else: 51 | return self.sess.run(layer) 52 | 53 | def Wrap(self, tensor, tensor_shape=None): 54 | """Convenience for prettytensor.wrap(tensor, self.bookkeeper).""" 55 | return prettytensor.wrap(tensor, self.bookkeeper, tensor_shape) 56 | 57 | def setUp(self): 58 | unittest.TestCase.setUp(self) 59 | self.SetBookkeeper(prettytensor.bookkeeper_for_new_graph()) 60 | 61 | def tearDown(self): 62 | self.TearDownBookkeeper() 63 | unittest.TestCase.tearDown(self) 64 | 65 | def SetBookkeeper(self, m): 66 | """Used to set custom bookkeeper code.""" 67 | self.TearDownBookkeeper() 68 | self.bookkeeper = m 69 | self._graph_with = self.bookkeeper.g.as_default() 70 | self._graph_with.__enter__() 71 | global _count 72 | _count += 1 73 | 74 | self.sess = tf.Session('') 75 | 76 | def TearDownBookkeeper(self): 77 | if self._graph_with: 78 | self._graph_with.__exit__(None, None, None) 79 | self._graph_with = None 80 | if self.sess: 81 | self.sess.close() 82 | self.sess = None 83 | self.bookkeeper = None 84 | -------------------------------------------------------------------------------- /prettytensor/recurrent_networks_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
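# --- Editor's note: the block below is an illustrative usage sketch, not part
# --- of the repository files above or below. It shows a minimal test written
# --- on top of pretty_tensor_testing.PtTestCase (defined just above): setUp()
# --- provides self.bookkeeper and self.sess, Wrap() attaches that bookkeeper
# --- to a Tensor, and RunTensor() initializes variables and evaluates a layer.
import unittest

import numpy
import tensorflow as tf

from prettytensor import pretty_tensor_testing


class FlattenTest(pretty_tensor_testing.PtTestCase):

  def testFlatten(self):
    data = numpy.ones([2, 3, 5], dtype=numpy.float32)
    # Wrap a constant and apply a chained PrettyTensor method.
    layer = self.Wrap(tf.constant(data)).flatten()
    out = self.RunTensor(layer)
    self.assertEqual((2, 15), out.shape)


if __name__ == '__main__':
  unittest.main()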
11 | """Test class for the recurrent networks module.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import unittest 17 | 18 | 19 | 20 | import numpy 21 | from numpy import testing 22 | import six 23 | from six.moves import xrange # pylint: disable=redefined-builtin 24 | from six.moves import zip # pylint: disable=redefined-builtin 25 | import tensorflow as tf 26 | 27 | import prettytensor 28 | from prettytensor import pretty_tensor_testing 29 | from prettytensor import recurrent_networks 30 | from prettytensor import recurrent_networks_testing_utils as testing_utils 31 | 32 | 33 | TOLERANCE = 0.00001 34 | 35 | 36 | class RecurrentNetworksTest(pretty_tensor_testing.PtTestCase): 37 | 38 | def setUp(self): 39 | super(self.__class__, self).setUp() 40 | self.input_data = numpy.array( 41 | [ 42 | [[1.], [2.], [3.], [4.]], [[5.], [6.], [7.], [8.]], 43 | [[5.], [6.], [7.], [8.]], [[9.], [10.], [11.], [12.]] 44 | ], 45 | dtype=numpy.float) 46 | self.sequence = testing_utils.SequenceInputMock( 47 | self.bookkeeper, self.input_data, [[0.], [0.], [0.], [1.]], 13) 48 | 49 | self.input, self.output = recurrent_networks.create_sequence_pretty_tensor( 50 | self.sequence) 51 | 52 | def testSquashAndCleave(self): 53 | squashed = self.input.squash_sequence() 54 | result = self.RunTensor(squashed) 55 | 56 | testing.assert_allclose( 57 | self.input_data.reshape(16, 1), 58 | result, 59 | rtol=TOLERANCE) 60 | 61 | result = self.RunTensor(squashed.cleave_sequence()) 62 | 63 | for i in xrange(len(self.input_data)): 64 | testing.assert_allclose( 65 | self.input_data[i], result[i], 66 | rtol=TOLERANCE) 67 | 68 | def testSquashAndCleaveLength1(self): 69 | input_data = numpy.array( 70 | [[[1.], [2.], [3.], [4.]]], dtype=numpy.float) 71 | sequence = testing_utils.SequenceInputMock( 72 | self.bookkeeper, input_data, [[0.]], 13) 73 | 74 | inp, _ = recurrent_networks.create_sequence_pretty_tensor(sequence) 75 | squashed = inp.squash_sequence() 76 | result = self.RunTensor(squashed) 77 | 78 | testing.assert_allclose( 79 | input_data.reshape(4, 1), result, 80 | rtol=TOLERANCE) 81 | 82 | result = self.RunTensor(squashed.cleave_sequence()) 83 | 84 | testing.assert_allclose(input_data[0], result[0], rtol=TOLERANCE) 85 | self.assertEquals(1, len(result)) 86 | 87 | def testSequenceLstm(self): 88 | lstm = self.input.sequence_lstm(13) 89 | result = self.RunTensor(lstm) 90 | 91 | self.assertEquals([4, 13], lstm.shape) 92 | self.assertEquals(4, len(result)) 93 | for i in xrange(4): 94 | self.assertSequenceEqual(lstm.shape, result[i].shape) 95 | 96 | def testSequenceGru(self): 97 | gru = self.input.sequence_gru(13) 98 | result = self.RunTensor(gru) 99 | 100 | self.assertEquals([4, 13], gru.shape) 101 | self.assertEquals(4, len(result)) 102 | for i in xrange(4): 103 | self.assertSequenceEqual(gru.shape, result[i].shape) 104 | 105 | def performTestArbitraryBatchSizeRnn(self, cell_type): 106 | # Tests whether LSTM / GRU / Bookkeeper function when batch_size is not 107 | # specified at graph creation time (i.e., None). 108 | self.assertTrue(cell_type == 'lstm' or cell_type == 'gru') 109 | super(self.__class__, self).SetBookkeeper( 110 | prettytensor.bookkeeper_for_new_graph()) 111 | 112 | # Build a graph. Specify None for the batch_size dimension. 
113 | placeholder = tf.placeholder(tf.float32, [None, 1]) 114 | input_pt = prettytensor.wrap_sequence([placeholder]) 115 | if cell_type == 'lstm': 116 | output, _ = (input_pt 117 | .sequence_lstm(4) 118 | .squash_sequence() 119 | .softmax_classifier(2)) 120 | elif cell_type == 'gru': 121 | output, _ = (input_pt 122 | .sequence_gru(4) 123 | .squash_sequence() 124 | .softmax_classifier(2)) 125 | 126 | self.sess.run(tf.global_variables_initializer()) 127 | 128 | # Use RecurrentRunner for state saving and managing feeds. 129 | recurrent_runner = recurrent_networks.RecurrentRunner(batch_size=1) 130 | 131 | # Run with a batch size of 1 for 10 steps, save output for reference. 132 | out_orig = [] 133 | for t in xrange(10): 134 | outs = recurrent_runner.run( 135 | [output.name], 136 | {placeholder.name: numpy.array([[1.2]])}, 137 | sess=self.sess) 138 | out = outs[0] 139 | self.assertEqual(1, len(out)) 140 | self.assertEqual(2, len(out[0])) 141 | out_orig.append(out[0]) 142 | 143 | # Test the reset functionality - after a reset, the results must be 144 | # identical to what we just got above. 145 | recurrent_runner.reset() 146 | for t in xrange(10): 147 | outs = recurrent_runner.run( 148 | [output.name], 149 | {placeholder.name: numpy.array([[1.2]])}, 150 | sess=self.sess) 151 | out = outs[0] 152 | self.assertEqual(1, len(out)) 153 | self.assertEqual(2, len(out[0])) 154 | testing.assert_allclose(out[0], out_orig[t]) 155 | 156 | # Test whether the recurrent runner detects changes to the default graph. 157 | # It should raise an Assertion because RecurrentRunner's state saver 158 | # information (collected during __init__) is not valid anymore. 159 | with tf.Graph().as_default(): 160 | placeholder2 = tf.placeholder(tf.float32, [None, 1]) 161 | input_pt2 = prettytensor.wrap_sequence([placeholder2]) 162 | if cell_type == 'lstm': 163 | output2, _ = (input_pt2 164 | .sequence_lstm(4) 165 | .squash_sequence() 166 | .softmax_classifier(2)) 167 | elif cell_type == 'gru': 168 | output2, _ = (input_pt2 169 | .sequence_gru(4) 170 | .squash_sequence() 171 | .softmax_classifier(2)) 172 | self.assertRaises(ValueError, 173 | recurrent_runner.run, 174 | [output2.name], None, self.sess) 175 | 176 | # Run with a batch size of 3; first and third input are identical and must 177 | # yield identical output, and the same output as in the single batch run 178 | # above (up to floating point rounding errors). 179 | recurrent_runner = recurrent_networks.RecurrentRunner(batch_size=3) 180 | for t in xrange(10): 181 | outs = recurrent_runner.run( 182 | [output.name], 183 | {placeholder.name: numpy.array([[1.2], [3.4], [1.2]])}, 184 | sess=self.sess) 185 | out = outs[0] 186 | self.assertEqual(3, len(out)) 187 | self.assertEqual(2, len(out[0])) 188 | testing.assert_allclose(out[0], out[2], rtol=TOLERANCE) 189 | testing.assert_allclose(out[0], out_orig[t], rtol=TOLERANCE) 190 | # Sanity check to protect against trivial outputs that might hide errors. 191 | # Need to avoid checking after t = 2 since untrained GRUs have a 192 | # tendency to converge to large state values, leading to outputs like 193 | # 1.0, 0.0. 
194 | if cell_type == 'gru' and t > 2: 195 | continue 196 | self.assertFalse((out[0] == out[1]).all()) 197 | 198 | def testArbitraryBatchSizeLstm(self): 199 | self.performTestArbitraryBatchSizeRnn('lstm') 200 | 201 | def testArbitraryBatchSizeGru(self): 202 | self.performTestArbitraryBatchSizeRnn('gru') 203 | 204 | def testSequence(self): 205 | result = self.RunTensor(self.input[-1]) 206 | testing.assert_allclose( 207 | self.input_data[-1], result, 208 | rtol=TOLERANCE) 209 | 210 | def testEmbeddingNameWorkaround(self): 211 | """Just make sure this runs since it is ensuring that a workaround works.""" 212 | input_data = self.Wrap(self.input_data.astype(numpy.int32).reshape([16, 1])) 213 | result = input_data.embedding_lookup( 214 | 13, [1], name='params') 215 | self.RunTensor(result) 216 | 217 | def testEmbeddingLookupRequiresRank2(self): 218 | """Just make sure this runs since it is ensuring that a workaround works.""" 219 | input_data = self.Wrap(self.input_data.astype(numpy.int32)) 220 | with self.assertRaises(ValueError): 221 | input_data.embedding_lookup(13, [1], name='params') 222 | 223 | def testLstmStateTuples(self): 224 | self.states = recurrent_networks.lstm_state_tuples(13, 'blah') 225 | self.RunTensor( 226 | self.input.sequence_lstm(13, name='blah')[-1]) 227 | 228 | for state in self.states: 229 | self.assertTrue( 230 | state[0] in self.sequence.requested_tensors, '%s missing: %s' % 231 | (state[0], list(six.iterkeys(self.sequence.requested_tensors)))) 232 | self.assertEqual( 233 | len(self.states), len(self.sequence.requested_tensors), 234 | 'Wrong number of Tensor states.') 235 | 236 | def testGruStateTuples(self): 237 | self.states = recurrent_networks.gru_state_tuples(13, 'blah') 238 | self.RunTensor( 239 | self.input.sequence_gru(13, name='blah')[-1]) 240 | 241 | for state in self.states: 242 | self.assertTrue( 243 | state[0] in self.sequence.requested_tensors, '%s missing: %s' % 244 | (state[0], list(six.iterkeys(self.sequence.requested_tensors)))) 245 | self.assertEqual( 246 | len(self.states), len(self.sequence.requested_tensors), 247 | 'Wrong number of Tensor states.') 248 | 249 | def testLength(self): 250 | tf.set_random_seed(4321) 251 | with tf.variable_scope('test') as vs: 252 | base_lstm = self.input.sequence_lstm(13) 253 | lengths = tf.placeholder(dtype=tf.int32, shape=[4]) 254 | 255 | # Use the same parameters. 256 | with tf.variable_scope(vs, reuse=True): 257 | lstm_truncated = self.input.sequence_lstm(13, lengths=lengths) 258 | 259 | with tf.Session() as sess: 260 | tf.global_variables_initializer().run() 261 | 262 | result = sess.run(base_lstm.sequence + lstm_truncated.sequence, 263 | {lengths: [10, 4, 1, 1]}) 264 | base_result = result[:len(base_lstm.sequence)] 265 | full_result = result[len(base_lstm.sequence):] 266 | truncated_result = sess.run(lstm_truncated.sequence, 267 | {lengths: [1, 2, 1, 1]}) 268 | 269 | for i, (x, y) in enumerate(zip(base_result, truncated_result)): 270 | if i < 2: 271 | testing.assert_allclose(x, y, rtol=TOLERANCE) 272 | else: 273 | # After the specified output, we check to make sure the same values are 274 | # propagated forward. 275 | self.assertFalse(numpy.allclose(x, y, rtol=TOLERANCE)) 276 | testing.assert_allclose(y, truncated_result[i - 1], rtol=TOLERANCE) 277 | for x, y in zip(base_result, full_result): 278 | # The later results tend to diverge. This is something that requires 279 | # investigation. 
280 | testing.assert_allclose(x, y, atol=0.1) 281 | 282 | 283 | if __name__ == '__main__': 284 | unittest.main() 285 | -------------------------------------------------------------------------------- /prettytensor/recurrent_networks_testing_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Utility class for testing recurrent networks.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import tensorflow as tf 17 | 18 | 19 | class SequenceInputMock(object): 20 | """A sequence input mock for testing recurrent networks.""" 21 | 22 | def __init__(self, bookkeeper, input_list, label_list, node_depth): 23 | self.inputs = self.for_constants(input_list) 24 | self.targets = self.for_constants(label_list) 25 | self.node_depth = node_depth 26 | self.batch_size = input_list[0].shape[0] 27 | self.requested_tensors = {} 28 | self.bookkeeper = bookkeeper 29 | self.num_timesteps = len(input_list) 30 | 31 | def for_constants(self, ls): 32 | return [tf.constant(x, dtype=tf.float32) for x in ls] 33 | 34 | def state(self, state_name): 35 | """Returns, creating if necessary, a state variable with the given name.""" 36 | if state_name not in self.requested_tensors: 37 | count = tf.get_variable('count_%s' % state_name, 38 | [], 39 | tf.int32, 40 | tf.zeros_initializer(), 41 | trainable=False) 42 | value = tf.get_variable(state_name, [self.batch_size, self.node_depth], 43 | tf.float32, tf.zeros_initializer()) 44 | self.requested_tensors[state_name] = (count, value) 45 | 46 | return self.requested_tensors[state_name][1] 47 | 48 | def save_state(self, state_name, unused_value, name='SaveState'): 49 | return tf.assign_add(self.requested_tensors[state_name][0], 1, name=name) 50 | 51 | -------------------------------------------------------------------------------- /prettytensor/replay_queue.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
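# --- Editor's note: the block below is an illustrative usage sketch, not part
# --- of the repository files above or below. It shows step-by-step inference
# --- with RecurrentRunner, following the pattern exercised in
# --- recurrent_networks_test.py above; the layer sizes and feed values are
# --- illustrative only.
import numpy
import tensorflow as tf
import prettytensor as pt
from prettytensor import recurrent_networks

# A single-timestep LSTM classifier whose batch size is left unspecified.
placeholder = tf.placeholder(tf.float32, [None, 1])
output, _ = (pt.wrap_sequence([placeholder])
             .sequence_lstm(4)
             .squash_sequence()
             .softmax_classifier(2))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # RecurrentRunner carries the LSTM state from one run() call to the next,
  # so consecutive calls behave like consecutive timesteps.
  runner = recurrent_networks.RecurrentRunner(batch_size=1)
  for _ in range(5):
    probs = runner.run([output.name],
                       {placeholder.name: numpy.array([[1.2]])},
                       sess=sess)[0]
  # reset() clears the saved state before starting a new sequence.
  runner.reset()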
11 | """Creates a replayable queue.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import collections 17 | import contextlib 18 | 19 | 20 | import six 21 | import tensorflow as tf 22 | 23 | from prettytensor import pretty_tensor_class as prettytensor 24 | 25 | 26 | def _make_tuple(x): 27 | """TF has an obnoxious habit of being lenient with single vs tuple.""" 28 | if isinstance(x, prettytensor.PrettyTensor): 29 | if x.is_sequence(): 30 | return tuple(x.sequence) 31 | else: 32 | return (x.tensor,) 33 | elif isinstance(x, tuple): 34 | return x 35 | elif (isinstance(x, collections.Sequence) and 36 | not isinstance(x, six.string_types)): 37 | return tuple(x) 38 | else: 39 | return (x,) 40 | 41 | 42 | class ReplayableQueue(object): 43 | """A switchable queue between the original data and a replayed subset. 44 | 45 | This queue combines 2 concepts: 46 | 47 | 1. A replay queue that re-enqueues the data every time it is dequeued so that 48 | multiple passes over the same data can be made (make sure to iterated 49 | `replay_steps` times). 50 | 2. The ability to switch data sources between the original source and the 51 | replayable queue. Embedding the switch makes it easy to construct a single 52 | graph on top of both input sources. 53 | 54 | Note: The queue requires manual filling by calling `refill`! 55 | 56 | The typical use case for replaying data is when you want to run an experiment 57 | on a subset of the training data and need to reuse the data multiple times. 58 | 59 | Here is an example that uses the replay queue to monitor loss on a dynamically 60 | selected validation set: 61 | 62 | ``` 63 | replay = pt.train.ReplayableQueue(lambda: MY_Q.dequeue_many(BATCH_SIZE), 64 | REPLAY_SIZE) 65 | # Build a graph with replay.output 66 | my_train_op, my_loss = build_graph(replay.output) 67 | 68 | with tf.Session() as sess: 69 | # Capture some data 70 | replay.refill(sess) 71 | 72 | for epoch in xrange(EPOCHS): 73 | # Train for a while 74 | for _ in xrange(1000): 75 | sess.run(my_train_op) 76 | loss = 0 77 | with replay.replay_scope(): 78 | for _ in xrange(replay.replay_steps): 79 | loss += sess.run(my_loss) 80 | loss /= replay.replay_steps 81 | print('Loss at epoch %d: %g' % (epoch, loss)) 82 | ``` 83 | """ 84 | 85 | def __init__(self, input_fn, replay_size, batch_size=None): 86 | """Creates a ReplayableQueue that takes data from `input_fn`. 87 | 88 | See also: `pt.train.ReplayableQueue.build_from_queue`. 89 | 90 | Note: the shapes of the inputs must be fully defined. 91 | 92 | Note: `input_fn` is a function instead of an input. This is because 93 | otherwise if the input came from a queue, dependencies wouldn't be set up 94 | properly and the data would always be dequeued. If you are providing data 95 | from a queue, then pass in `lambda: q.dequeue_many(batch_size)`. 96 | 97 | Args: 98 | input_fn: A function of no arguments that returns the input as a tuple of 99 | `Tensors`. 100 | replay_size: The size of the replay queue. 101 | batch_size: If provided, use this as the batch size otherwise infer it. 102 | 103 | Raises: 104 | ValueError: if `replay_size` is not divisible by `batch_size` or if the 105 | shapes on the input are wrong. 
106 | """ 107 | inputs = _make_tuple(input_fn()) 108 | 109 | for x in inputs: 110 | x.get_shape().assert_is_fully_defined() 111 | if batch_size is not None: 112 | x.get_shape()[0].assert_is_compatible_with(batch_size) 113 | else: 114 | batch_size = x.get_shape()[0].value 115 | 116 | dtypes = [x.dtype for x in inputs] 117 | shapes = [x.get_shape()[1:] if x.get_shape() else () for x in inputs] 118 | 119 | if replay_size % batch_size != 0: 120 | raise ValueError('replay_size size (%d) must be a multiple of batch size ' 121 | '(%d)' % (replay_size, batch_size)) 122 | 123 | # Setup the flag that controls replay. 124 | self._replay_var = tf.get_variable( 125 | 'replay', 126 | dtype=tf.bool, 127 | shape=[], 128 | initializer=tf.constant_initializer(False), 129 | trainable=False) 130 | self._set_replay_ph = tf.placeholder(dtype=tf.bool) 131 | self._set_replay = self._replay_var.assign(self._set_replay_ph) 132 | 133 | self._replay_queue = tf.FIFOQueue(replay_size, dtypes, shapes) 134 | 135 | # _fill_queue adds data to the queue and then returns whether it is full. 136 | with tf.control_dependencies([self._replay_queue.enqueue_many(inputs)]): 137 | self._fill_queue = tf.less(self._replay_queue.size(), replay_size) 138 | 139 | # Dequeue all the things! 140 | self._clear_queue = self._replay_queue.dequeue_many( 141 | self._replay_queue.size()) 142 | 143 | def _pull_from_replay(): 144 | data_tuple = _make_tuple(self._replay_queue.dequeue_many(batch_size)) 145 | with tf.control_dependencies([self._replay_queue.enqueue_many(data_tuple) 146 | ]): 147 | return (tf.identity(data_tuple[0]),) + data_tuple[1:] 148 | 149 | def _pull_from_original(): 150 | return _make_tuple(input_fn()) 151 | 152 | self._output = prettytensor.wrap( 153 | tf.cond(self._replay_var, _pull_from_replay, _pull_from_original)) 154 | 155 | @classmethod 156 | def build_from_queue(cls, input_queue, replay_size, batch_size): 157 | """Builds a `ReplayableQueue` that draws from a regular `input_queue`. 158 | 159 | Args: 160 | input_queue: The queue to draw from. 161 | replay_size: The size of the replay buffer. 162 | batch_size: The size of each batch. 163 | 164 | Returns: 165 | A ReplayableQueue. 166 | """ 167 | return cls( 168 | lambda: input_queue.dequeue_many(batch_size), 169 | replay_size, 170 | batch_size=batch_size) 171 | 172 | @property 173 | def output(self): 174 | """Returns the output Tensor for this queue. 175 | 176 | The output is selected between the original data and the replay data 177 | depending on the replay value. 178 | 179 | Returns: 180 | The output tensors as a tuple. 181 | """ 182 | return self._output 183 | 184 | @property 185 | def replay_steps(self): 186 | """Returns the number of steps to replay.""" 187 | return self._replay_size // self._batch_size 188 | 189 | @property 190 | def replay_size(self): 191 | """Returns the total number of examples this queue holds.""" 192 | return self._replay_size 193 | 194 | @contextlib.contextmanager 195 | def replay_scope(self, sess): 196 | """Enters a replay scope that unsets it at the end.""" 197 | current_replay = self.replay(sess) 198 | try: 199 | self.set_replay(sess, True) 200 | yield 201 | finally: 202 | self.set_replay(sess, current_replay) 203 | 204 | def replay(self, sess): 205 | """Gets the current value of replay from the graph. 206 | 207 | Note: this runs the graph, but it is just a var read so it is fairly cheap. 208 | 209 | Args: 210 | sess: The session in which to run. 211 | Returns: 212 | The value of the replay variable. 
213 | """ 214 | return sess.run(self._replay_var) 215 | 216 | def set_replay(self, sess, replay): 217 | """Changes the current replay setting on the graph.""" 218 | sess.run(self._set_replay, {self._set_replay_ph: replay}) 219 | 220 | def refill(self, sess): 221 | """Clears the current queue and then refills it with new data.""" 222 | sess.run(self._clear_queue) 223 | # Run until full. 224 | while sess.run(self._fill_queue): 225 | pass 226 | -------------------------------------------------------------------------------- /prettytensor/replay_queue_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Tests for replay_queue.""" 12 | 13 | 14 | import numpy as np 15 | import tensorflow as tf 16 | 17 | import prettytensor as pt 18 | 19 | 20 | class ReplayableQueueTest(tf.test.TestCase): 21 | 22 | def test_replay_queue_with_queue_input(self): 23 | # Put a lot of replay.output on the queue 24 | q = tf.FIFOQueue(1000, tf.float32, []) 25 | enqueue = q.enqueue_many(tf.to_float(tf.range(0, 1000))) 26 | replay = pt.train.ReplayableQueue.build_from_queue(q, 100, 10) 27 | 28 | with self.test_session() as sess: 29 | sess.run(tf.global_variables_initializer()) 30 | 31 | sess.run(enqueue) 32 | 33 | d = sess.run(replay.output) 34 | self.assertAllClose(np.arange(10).astype(np.float), d) 35 | 36 | # Now fill the queue 37 | replay.refill(sess) 38 | 39 | self.assertEqual(100, replay._replay_queue.size().eval()) 40 | 41 | d = sess.run(replay.output) 42 | 43 | # Replay is still false, but the queue has advanced 44 | self.assertAllClose(np.arange(110, 120).astype(np.float), d) 45 | 46 | # Now set replay. 47 | replay.set_replay(sess, True) 48 | for i in range(10): 49 | d = sess.run(replay.output) 50 | range_start = 10 + i * 10 51 | self.assertAllClose( 52 | np.arange(range_start, range_start + 10).astype(np.float), d) 53 | 54 | # And again 55 | for i in range(10): 56 | d = sess.run(replay.output) 57 | range_start = 10 + i * 10 58 | self.assertAllClose( 59 | np.arange(range_start, range_start + 10).astype(np.float), d) 60 | 61 | replay.set_replay(sess, False) 62 | 63 | # Back to the replay.output stream 64 | d = sess.run(replay.output) 65 | self.assertAllClose(np.arange(120, 130).astype(np.float), d) 66 | 67 | # And refill the queue 68 | replay.refill(sess) 69 | replay.set_replay(sess, True) 70 | d = sess.run(replay.output) 71 | self.assertAllClose(np.arange(130, 140).astype(np.float), d) 72 | 73 | replay.set_replay(sess, False) 74 | d = sess.run(replay.output) 75 | self.assertAllClose(np.arange(230, 240).astype(np.float), d) 76 | 77 | if __name__ == '__main__': 78 | tf.test.main() 79 | -------------------------------------------------------------------------------- /prettytensor/scopes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 
2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Contains methods related to making templates.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import contextlib 17 | import functools 18 | import traceback 19 | 20 | from six.moves import zip # pylint: disable=redefined-builtin 21 | import tensorflow as tf 22 | 23 | from tensorflow.python.ops import variable_scope 24 | 25 | 26 | @contextlib.contextmanager 27 | def var_and_name_scope(names): 28 | """Creates a variable scope and a name scope. 29 | 30 | If a variable_scope is provided, this will reenter that variable scope. 31 | However, if none is provided then the variable scope will match the generated 32 | part of the name scope. 33 | 34 | Args: 35 | names: A tuple of name_scope, variable_scope or None. 36 | Yields: 37 | The result of name_scope and variable_scope as a tuple. 38 | """ 39 | # pylint: disable=protected-access 40 | if not names: 41 | yield None, None 42 | else: 43 | name, var_scope = names 44 | with tf.name_scope(name) as scope: 45 | # TODO(eiderman): This is a workaround until the variable_scope updates land 46 | # in a TF release. 47 | old_vs = tf.get_variable_scope() 48 | if var_scope is None: 49 | count = len(name.split('/')) 50 | scoped_name = '/'.join(scope.split('/')[-count - 1:-1]) 51 | full_name = (old_vs.name + '/' + scoped_name).lstrip('/') 52 | else: 53 | full_name = var_scope.name 54 | 55 | vs_key = tf.get_collection_ref(variable_scope._VARSCOPE_KEY) 56 | try: 57 | # TODO(eiderman): Remove this hack or fix the full file. 58 | try: 59 | vs_key[0] = tf.VariableScope( 60 | old_vs.reuse, 61 | name=full_name, 62 | initializer=old_vs.initializer, 63 | regularizer=old_vs.regularizer, 64 | caching_device=old_vs.caching_device) 65 | except AttributeError: 66 | vs_key[0] = variable_scope._VariableScope( 67 | old_vs.reuse, 68 | name=full_name, 69 | initializer=old_vs.initializer) 70 | 71 | vs_key[0].name_scope = scope 72 | yield scope, vs_key[0] 73 | finally: 74 | vs_key[0] = old_vs 75 | 76 | 77 | def get_current_name_scope(): 78 | """Gets the current name scope.""" 79 | # pylint: disable=protected-access 80 | g = tf.get_default_graph() 81 | # TODO(eiderman): Remove this hack once TF update is released. 
82 | if isinstance(g._name_stack, tuple): 83 | return g._name_stack[0] + '/' 84 | else: 85 | return g._name_stack + '/' 86 | 87 | 88 | def _get_last_part_of_name_scope(scope): 89 | splits = scope.split('/') 90 | return splits[-2] 91 | 92 | 93 | def make_template(name, func, *args, **kwargs): 94 | """Given an arbitrary function, wrap it so that it does parameter sharing.""" 95 | if args or kwargs: 96 | func = functools.partial(func, *args, **kwargs) 97 | return Template(name, func) 98 | 99 | 100 | def skip_common_stack_elements(stacktrace, base_case): 101 | """Skips items that the target stacktrace shares with the base stacktrace.""" 102 | for i, (trace, base) in enumerate(zip(stacktrace, base_case)): 103 | if trace != base: 104 | return stacktrace[i:] 105 | return stacktrace[-1:] 106 | 107 | 108 | class Template(object): 109 | """A Template captures variable and namescopes to help variable sharing.""" 110 | 111 | def __init__(self, name, func): 112 | """Creates a template for the given function. 113 | 114 | Args: 115 | name: The variable_scope to use, if None the current scope is captured. 116 | func: The function to apply each time. 117 | """ 118 | self._func = func 119 | if name: 120 | self._var_scope = None 121 | self._name = name 122 | else: 123 | self._var_scope = tf.get_variable_scope() 124 | self._name = None 125 | self._reuse = None 126 | self._stacktrace = traceback.format_stack()[:-3] 127 | 128 | def _call_func(self, args, kwargs): 129 | try: 130 | self._reuse = True 131 | return self._func(*args, **kwargs) 132 | except Exception as exc: 133 | # Reraise the exception, but append the original definition to the 134 | # trace. 135 | args = exc.args 136 | if not args: 137 | arg0 = '' 138 | else: 139 | arg0 = args[0] 140 | trace = ''.join(skip_common_stack_elements(self._stacktrace, 141 | traceback.format_stack())) 142 | arg0 = '%s\n\noriginally defined at:\n%s' % (arg0, trace) 143 | new_args = [arg0] 144 | new_args.extend(args[1:]) 145 | exc.args = tuple(new_args) 146 | raise 147 | 148 | def __call__(self, *args, **kwargs): 149 | if self._name: 150 | with var_and_name_scope((self._name, self._var_scope)) as (_, vs): 151 | if self._reuse: 152 | vs.reuse_variables() 153 | else: 154 | self._var_scope = vs 155 | return self._call_func(args, kwargs) 156 | else: 157 | with tf.variable_scope(self._var_scope, reuse=self._reuse) as vs: 158 | return self._call_func(args, kwargs) 159 | -------------------------------------------------------------------------------- /prettytensor/scopes_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
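# --- Editor's note: the block below is an illustrative usage sketch, not part
# --- of the repository files above or below. It shows parameter sharing with
# --- scopes.make_template (defined just above): extra arguments are bound via
# --- functools.partial and every call of the returned Template re-enters the
# --- same variable scope, reusing the variables created on the first call.
# --- The function and sizes are illustrative only.
import tensorflow as tf

from prettytensor import scopes


def linear(x, width):
  w = tf.get_variable('w', [x.get_shape()[1].value, width],
                      initializer=tf.zeros_initializer())
  return tf.matmul(x, w)


shared_linear = scopes.make_template('shared_linear', linear, width=8)

x1 = tf.placeholder(tf.float32, [None, 4])
x2 = tf.placeholder(tf.float32, [None, 4])
y1 = shared_linear(x1)  # Creates the variable 'shared_linear/w'.
y2 = shared_linear(x2)  # Reuses the same variable.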
11 | """Tests for scopes.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import traceback 17 | import unittest 18 | 19 | 20 | import tensorflow as tf 21 | 22 | from prettytensor import scopes 23 | 24 | 25 | def var_scoped_function(): 26 | return tf.get_variable('dummy', shape=[1], initializer=tf.zeros_initializer()) 27 | 28 | 29 | class ScopesTest(unittest.TestCase): 30 | 31 | def test_skip_stack_frames(self): 32 | first = traceback.format_stack() 33 | second = traceback.format_stack() 34 | result = scopes.skip_common_stack_elements(first, second) 35 | self.assertEqual(1, len(result)) 36 | self.assertNotEqual(len(first), len(result)) 37 | 38 | def test_get_current_name_scope(self): 39 | self.assertEquals('/', scopes.get_current_name_scope()) 40 | self.assertEquals('', scopes._get_last_part_of_name_scope('/')) 41 | with tf.name_scope('one') as scope: 42 | self.assertEquals(scope, scopes.get_current_name_scope()) 43 | self.assertEquals('one', scopes._get_last_part_of_name_scope(scope)) 44 | 45 | with tf.name_scope('one') as scope: 46 | self.assertEquals(scope, scopes.get_current_name_scope()) 47 | self.assertEquals('one_1', scopes._get_last_part_of_name_scope(scope)) 48 | with tf.name_scope('two') as nested_scope: 49 | self.assertEquals(nested_scope, scopes.get_current_name_scope()) 50 | self.assertEquals('two', 51 | scopes._get_last_part_of_name_scope(nested_scope)) 52 | 53 | def test_template_without_name(self): 54 | tmpl1 = scopes.Template(None, var_scoped_function) 55 | 56 | v1 = tmpl1() 57 | v2 = tmpl1() 58 | self.assertEqual(v1, v2) 59 | self.assertEqual('dummy:0', v1.name) 60 | 61 | def test_template_with_name(self): 62 | tmpl1 = scopes.Template('s1', var_scoped_function) 63 | tmpl2 = scopes.Template('s1', var_scoped_function) 64 | 65 | v1 = tmpl1() 66 | v2 = tmpl1() 67 | v3 = tmpl2() 68 | self.assertEqual(v1, v2) 69 | self.assertNotEqual(v1, v3) 70 | self.assertEqual('s1/dummy:0', v1.name) 71 | self.assertEqual('s1_2/dummy:0', v3.name) 72 | 73 | def test_var_and_name_scope(self): 74 | with tf.Graph().as_default(): 75 | with scopes.var_and_name_scope(('one', None)) as (ns, vs): 76 | self.assertEqual('one/', ns) 77 | self.assertEqual('one', vs.name) 78 | with scopes.var_and_name_scope(('one', None)) as (ns, vs): 79 | self.assertEqual('one_1/', ns) 80 | self.assertEqual('one_1', vs.name) 81 | with scopes.var_and_name_scope(('one/two', None)) as (ns, vs): 82 | self.assertEqual('one/two/', ns) 83 | self.assertEqual('one/two', vs.name) 84 | 85 | 86 | if __name__ == '__main__': 87 | unittest.main() 88 | -------------------------------------------------------------------------------- /prettytensor/sequence_with_deltas.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | """Provides a class that just implements a sequence with a delta count.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import collections 17 | 18 | 19 | class SequenceWithDeltas(collections.MutableSequence): 20 | """Provides a sequence with a count of modifications.""" 21 | 22 | def __init__(self, other_seq=None): 23 | if other_seq is None: 24 | self._seq = [] 25 | else: 26 | self._seq = list(other_seq) 27 | self._mods = len(self._seq) 28 | self._mark = 0 29 | 30 | def __getitem__(self, key): 31 | return self._seq[key] 32 | 33 | def __setitem__(self, key, value): 34 | self._mods += 1 35 | self._seq[key] = value 36 | 37 | def __delitem__(self, key): 38 | self._mods += 1 39 | del self._seq[key] 40 | 41 | def __len__(self): 42 | return len(self._seq) 43 | 44 | def insert(self, key, value): 45 | self._mods += 1 46 | self._seq.insert(key, value) 47 | 48 | @property 49 | def deltas(self): 50 | return self._mods 51 | 52 | def mark(self): 53 | """Marks this sequence at the current number of deltas.""" 54 | self._mark = self._mods 55 | 56 | def has_changed(self): 57 | """Returns if it has changed since the last mark.""" 58 | return self._mark == self._mods 59 | -------------------------------------------------------------------------------- /prettytensor/templated_pretty_tensor_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Tests for templating in PrettyTensor.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import unittest 17 | 18 | 19 | 20 | import numpy 21 | from numpy import testing 22 | import tensorflow as tf 23 | 24 | import prettytensor 25 | from prettytensor import pretty_tensor_class 26 | from prettytensor import pretty_tensor_testing 27 | 28 | KEY = 'random_key' 29 | TOLERANCE = 0.000001 30 | 31 | 32 | @prettytensor.Register(assign_defaults='value') 33 | def ValidateMethod(input_tensor, test_class, value): 34 | test_class.assertEqual(KEY, value) 35 | return input_tensor 36 | 37 | 38 | class TemplatedPrettyTensorTest(pretty_tensor_testing.PtTestCase): 39 | 40 | def setUp(self): 41 | super(self.__class__, self).setUp() 42 | # Input is 2x3x5, which isn't a natural size for any op. 
43 | self.input_data = numpy.array( 44 | [[[1, 2, 3, 4, 5], 45 | [6, 7, 8, 9, 10], 46 | [10, 12, 13, 14, 15],], [[-1, 2, -3, 4, -5], [6, -7, 8, -9, 10], 47 | [-10, 12, -13, 14, -15]]], 48 | dtype=numpy.float) 49 | self.input = tf.constant(self.input_data, dtype=tf.float32) 50 | 51 | def Template(self, key): 52 | return prettytensor.template(key, self.bookkeeper) 53 | 54 | def testSimpleTemplate(self): 55 | template = self.Template(KEY) 56 | 57 | x = template.construct(random_key=self.input) 58 | out = self.RunTensor(x) 59 | testing.assert_allclose(self.input_data, out, rtol=TOLERANCE) 60 | 61 | def testSingleMethod(self): 62 | template = self.Template(KEY).flatten() 63 | 64 | x = template.construct(random_key=self.input) 65 | out = self.RunTensor(x) 66 | testing.assert_allclose( 67 | self.input_data.reshape([2, 15]), 68 | out, 69 | rtol=TOLERANCE) 70 | 71 | def testSequential(self): 72 | seq = self.Template(KEY).sequential() 73 | seq.flatten() 74 | seq.fully_connected(100) 75 | out = self.RunTensor(seq.as_layer().construct(random_key=self.input)) 76 | self.assertSequenceEqual([2, 100], out.shape) 77 | 78 | def testAttach(self): 79 | input_pt = self.Wrap(self.input) 80 | template = self.Template('input').flatten().fully_connected(100) 81 | out = self.RunTensor(input_pt.attach_template(template, 'input')) 82 | self.assertSequenceEqual([2, 100], out.shape) 83 | 84 | def testUnboundVariableForParameter(self): 85 | input_pt = self.Wrap(self.input) 86 | template = input_pt.flatten().fully_connected(prettytensor.UnboundVariable( 87 | 'width')) 88 | self.assertTrue(isinstance(template, pretty_tensor_class._DeferredLayer)) 89 | out = self.RunTensor(template.construct(width=200)) 90 | self.assertSequenceEqual([2, 200], out.shape) 91 | 92 | def testMissingUnboundVariable(self): 93 | input_pt = self.Wrap(self.input) 94 | template = input_pt.flatten().fully_connected(prettytensor.UnboundVariable( 95 | 'width')) 96 | with self.assertRaises(ValueError): 97 | template.construct() 98 | 99 | def testUnboundVariableReused(self): 100 | """The same unbound_var can be used multiple times in a graph.""" 101 | input_pt = self.Wrap(self.input) 102 | unbound_var = prettytensor.UnboundVariable('width') 103 | template = (input_pt.flatten().fully_connected(unbound_var) 104 | .fully_connected(unbound_var)) 105 | out = self.RunTensor(template.construct(width=200)) 106 | self.assertSequenceEqual([2, 200], out.shape) 107 | 108 | def testAttachToTemplate(self): 109 | input_pt = self.Wrap(self.input) 110 | template1 = self.Template('input').flatten() 111 | template2 = self.Template('input').fully_connected(100) 112 | 113 | joined = template1.attach_template(template2, 'input') 114 | out = self.RunTensor(input_pt.attach_template(joined, 'input')) 115 | self.assertSequenceEqual([2, 100], out.shape) 116 | 117 | def testUnboundVariableAsDefault(self): 118 | """The same unbound_var can be used multiple times in a graph.""" 119 | input_pt = self.Wrap(self.input) 120 | with prettytensor.defaults_scope( 121 | value=prettytensor.UnboundVariable('key')): 122 | x = input_pt.ValidateMethod(self) 123 | self.assertTrue(isinstance(x, pretty_tensor_class._DeferredLayer)) 124 | x.construct(key=KEY) 125 | 126 | def testConflictingUnboundVariables(self): 127 | """Two unbound_vars with the same name are considered conflicting.""" 128 | input_pt = self.Wrap(self.input) 129 | with self.assertRaises(ValueError): 130 | (input_pt.flatten() 131 | .fully_connected(prettytensor.UnboundVariable('width')) 132 | 
.fully_connected(prettytensor.UnboundVariable('width'))) 133 | 134 | def testMultipleUnboundVariables(self): 135 | input_pt = self.Wrap(self.input) 136 | template = (input_pt.flatten() 137 | .fully_connected(prettytensor.UnboundVariable('width')) 138 | .fully_connected(prettytensor.UnboundVariable('width2'))) 139 | out = self.RunTensor(template.construct(width=200, width2=100)) 140 | self.assertSequenceEqual([2, 100], out.shape) 141 | 142 | def testExtraValues(self): 143 | input_pt = self.Wrap(self.input) 144 | template = (input_pt.flatten() 145 | .fully_connected(prettytensor.UnboundVariable('width'))) 146 | with self.assertRaises(ValueError): 147 | template.construct(width=200, width2=100) 148 | 149 | def testIncompatibleUnboundVariableValues(self): 150 | """Ensures that an error is thrown if a var is given incompatible values. 151 | 152 | Since the primary use case of templates is parameter sharing, it is 153 | important that substitutions don't conflict. 154 | """ 155 | input_pt = self.Wrap(self.input) 156 | full = input_pt.flatten().fully_connected(prettytensor.UnboundVariable( 157 | 'width')) 158 | full.construct(width=100) 159 | with self.assertRaises(ValueError): 160 | full.construct(width=200) 161 | 162 | def BuildLargishGraph(self, input_pt): 163 | seq = input_pt.sequential() 164 | seq.reshape('___1') 165 | seq.conv2d(1, 10) 166 | with seq.subdivide(2) as [a, b]: 167 | a.with_name('a').conv2d(1, 5) 168 | b.with_name('b').conv2d(1, 15) 169 | seq.with_name('wow') 170 | seq.flatten() 171 | seq.fully_connected(100, name='a_funny_name') 172 | return seq.as_layer() 173 | 174 | def testGraphMatchesImmediate(self): 175 | """Ensures that the vars line up between the two modes.""" 176 | with tf.Graph().as_default(): 177 | input_pt = prettytensor.wrap( 178 | tf.constant(self.input_data, dtype=tf.float32)) 179 | self.BuildLargishGraph(input_pt) 180 | normal_names = sorted([v.name for v in tf.global_variables()]) 181 | 182 | with tf.Graph().as_default(): 183 | template = prettytensor.template('input') 184 | self.BuildLargishGraph(template).construct(input=prettytensor.wrap( 185 | tf.constant(self.input_data, dtype=tf.float32))) 186 | template_names = sorted([v.name for v in tf.global_variables()]) 187 | 188 | self.assertSequenceEqual(normal_names, template_names) 189 | 190 | def testVariablesAreShared(self): 191 | """Ensures that adding the graph twice shares variables.""" 192 | input_pt = self.Wrap(self.input) 193 | template = self.Template('input').flatten().fully_connected(10) 194 | 195 | l1 = template.construct(input=input_pt) 196 | l2 = template.construct(input=input_pt) 197 | self.assertNotEqual(l1.tensor, l2.tensor) 198 | 199 | v1 = self.RunTensor(l1, init=True) 200 | v2 = self.RunTensor(l2, init=False) 201 | testing.assert_allclose(v1, v2, rtol=TOLERANCE) 202 | 203 | def testBind(self): 204 | input_pt = self.Wrap(self.input) 205 | template = self.Template('input').flatten().fully_connected(10) 206 | 207 | l1 = template.bind(input=input_pt).construct() 208 | l2 = template.construct(input=input_pt) 209 | v1 = self.RunTensor(l1, init=True) 210 | v2 = self.RunTensor(l2, init=False) 211 | testing.assert_allclose(v1, v2, rtol=TOLERANCE) 212 | 213 | def testBindTuple(self): 214 | labels = numpy.array([[0., 1.], [1., 0.]], dtype=numpy.float32) 215 | template = self.Template('input').flatten().softmax_classifier(2, labels) 216 | bound = template.bind(input=self.input) 217 | 218 | tuple1 = bound.construct() 219 | tuple2 = template.construct(input=self.input) 220 | 221 | 
self.assertNotEqual(tuple1.softmax.tensor, tuple2.softmax.tensor) 222 | softmax1 = self.RunTensor(tuple1.softmax, init=True) 223 | loss1 = self.RunTensor(tuple1.loss, init=False) 224 | softmax2 = self.RunTensor(tuple2.softmax, init=False) 225 | loss2 = self.RunTensor(tuple2.loss, init=False) 226 | testing.assert_allclose(softmax1, softmax2, rtol=TOLERANCE) 227 | testing.assert_allclose(loss1, loss2, rtol=TOLERANCE) 228 | 229 | def testConstructAllWithConflictingValues(self): 230 | labels = numpy.array([[0., 1.], [1., 0.]], dtype=numpy.float32) 231 | template = self.Template('input').flatten().softmax_classifier(2, labels) 232 | 233 | softmax = template.softmax.bind(input=self.input) 234 | loss = template.loss.bind(input=labels) 235 | with self.assertRaises(ValueError): 236 | prettytensor.construct_all([softmax, loss]) 237 | 238 | 239 | if __name__ == '__main__': 240 | unittest.main() 241 | -------------------------------------------------------------------------------- /prettytensor/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Imports some utilities for training models.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | # pylint: disable=unused-import, wildcard-import 17 | from prettytensor.input_helpers import batch 18 | from prettytensor.input_helpers import feed_numpy 19 | from prettytensor.local_trainer import create_checkpointing_runner 20 | from prettytensor.local_trainer import create_follower_runner 21 | from prettytensor.local_trainer import Runner 22 | from prettytensor.recurrent_networks import RecurrentRunner 23 | from prettytensor.replay_queue import ReplayableQueue 24 | -------------------------------------------------------------------------------- /prettytensor/tutorial/README.md: -------------------------------------------------------------------------------- 1 | # Tutorial 2 | 3 | These are simple models that highlight some of Pretty Tensor's features and 4 | that hopefully will be useful branching points for your own experiments. 5 | 6 | While each tutorial is intended to be standalone, the recommended order is: 7 | 8 | 1. `mnist.py` 9 | 2. `baby_names.py` 10 | 3. `shakespeare.py` 11 | 12 | All of the tutorials show you how to build a model and run training/evaluation. 13 | 14 | ## MNIST 15 | 16 | `mnist.py` shows a simple image classification model using the 17 | [MNIST dataset](http://yann.lecun.com/exdb/mnist/). 18 | 19 | ## Baby Names 20 | 21 | `baby_names.py` is a recurrent network that uses data from the Office of 22 | Retirement and Disability Policy, Social Security Administration about all 23 | children born in the US for the past century. 
The model uses an 24 | [Long Short Term Memory](http://colah.github.io/posts/2015-08-Understanding-LSTMs/) 25 | (LSTM) to make a prediction on the boy/girl ratio for each name. 26 | 27 | ## Shakespeare 28 | 29 | `shakespeare.py` uses stacked LSTMs to read each character of Shakespeare and 30 | predict the next character. It can be used to sample your own Shakespearian 31 | comedy. 32 | -------------------------------------------------------------------------------- /prettytensor/tutorial/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | -------------------------------------------------------------------------------- /prettytensor/tutorial/baby_names.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Tutorial to predict the sex of a baby from the name. 12 | 13 | This model takes a dataset of baby names to ratio of sexes for that name and 14 | then trains an LSTM to predict the ratio given the characters of the name. 15 | 16 | The input are the characters of the name as ASCII codes (0-127) and it is 17 | unrolled for 15 steps, which is the longest name in the corpus. The results are 18 | fed through a recurrent network and then to a 2 way classifier that predicts the 19 | boy/girl ratio. 20 | 21 | This demonstrates how to train a classifier on the last output using an LSTM, 22 | which can be at any point (some names are short and some are long) by setting 23 | weights on each example. This also demonstrates how to efficiently reshape the 24 | network for the classifier and how to use dropout in both a training and eval 25 | graph. 26 | """ 27 | from __future__ import absolute_import 28 | from __future__ import division 29 | from __future__ import print_function 30 | 31 | 32 | import numpy 33 | from six.moves import xrange # pylint: disable=redefined-builtin 34 | import tensorflow as tf 35 | 36 | import prettytensor as pt 37 | from prettytensor.tutorial import data_utils 38 | 39 | tf.app.flags.DEFINE_string( 40 | 'save_path', None, 'Where to save the model checkpoints on local disk. 
' 41 | 'Checkpoints are in LevelDb.') 42 | FLAGS = tf.app.flags.FLAGS 43 | 44 | 45 | BATCH_SIZE = 32 46 | CHARS = 128 47 | TIMESTEPS = 15 48 | SEXES = 2 49 | 50 | EMBEDDING_SIZE = 16 51 | 52 | 53 | def create_model(text_in, 54 | labels, 55 | timesteps, 56 | per_example_weights, 57 | phase=pt.Phase.train): 58 | """Creates a model for running baby names.""" 59 | with pt.defaults_scope(phase=phase, l2loss=0.00001): 60 | # The embedding lookup must be placed on a cpu. 61 | with tf.device('/cpu:0'): 62 | embedded = text_in.embedding_lookup(CHARS, [EMBEDDING_SIZE]) 63 | # We need to cleave the sequence because sequence lstm expect each 64 | # timestep to be in its own Tensor. 65 | lstm = (embedded.cleave_sequence(timesteps).sequence_lstm(CHARS)) 66 | 67 | # The classifier is much more efficient if it runs across the entire 68 | # batch at once, so we want to squash (i.e. uncleave). 69 | # 70 | # Hidden nodes is set to 32 because it seems to work well. 71 | return (lstm.squash_sequence().fully_connected(32, 72 | activation_fn=tf.nn.relu) 73 | .dropout(0.7) 74 | .softmax_classifier(SEXES, 75 | labels, 76 | per_example_weights=per_example_weights)) 77 | 78 | 79 | def main(_=None): 80 | print('Starting Baby Names') 81 | 82 | # Since we are feeding our data as numpy arrays, we need to create 83 | # placeholders in the graph. 84 | # These must then be fed using the feed dict. 85 | input_placeholder = tf.placeholder(tf.int32, [BATCH_SIZE, TIMESTEPS]) 86 | output_placeholder = tf.placeholder(tf.float32, [BATCH_SIZE, SEXES]) 87 | 88 | inp = data_utils.reshape_data(input_placeholder) 89 | 90 | # Create a label for each timestep. 91 | labels = data_utils.reshape_data( 92 | tf.reshape( 93 | tf.tile(output_placeholder, [1, TIMESTEPS]), [BATCH_SIZE, TIMESTEPS, 94 | SEXES]), 95 | per_example_length=2) 96 | 97 | # We also need to set per example weights so that the softmax doesn't output a 98 | # prediction on intermediate nodes. 99 | length_placeholder = tf.placeholder(tf.int32, [BATCH_SIZE, 1]) 100 | 101 | # We need a dense multiplier for the per example weights. The only place 102 | # that has a non-zero loss is the first EOS after the last character of the 103 | # name; the characters in the name and the trailing EOS characters are given a 104 | # 0 loss by assigning the weight to 0.0 and in the end only one character in 105 | # each batch has a weight of 1.0. 106 | # sparse_to_dense does a lookup using the indices from the first Tensor. 107 | # Because we are filling in a 2D array, the indices need to be 2 dimensional. 108 | # Since we want to assign 1 value for each row, the first dimension can just 109 | # be a sequence. 110 | t = tf.concat( 111 | [ 112 | tf.constant( 113 | numpy.arange(BATCH_SIZE).reshape((BATCH_SIZE, 1)), 114 | dtype=tf.int32), length_placeholder 115 | ], 116 | 1) 117 | 118 | # Squeeze removes dimensions that are equal to 1. per_example_weights must 119 | # end up as 1 dimensional. 120 | per_example_weights = data_utils.reshape_data(tf.sparse_to_dense( 121 | t, [BATCH_SIZE, TIMESTEPS], 1.0, default_value=0.0)).squeeze() 122 | 123 | # We need 2 copies of the graph that share variables. The first copy runs 124 | # training and will do dropout if specified and the second will not include 125 | # dropout. Dropout is controlled by the phase argument, which sets the mode 126 | # consistently throughout a graph. 
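# (Editorial note, not part of the original tutorial.) The sharing works
# because both copies below are built inside the same variable scope, the
# second time with reuse=True: only the graph ops are duplicated, never the
# weights. And since dropout is a no-op when phase=pt.Phase.test, the test
# copy evaluates exactly the parameters being trained.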
127 | with tf.variable_scope('baby_names'): 128 | result = create_model(inp, labels, TIMESTEPS, per_example_weights) 129 | 130 | # Call variable scope by name so we also create a name scope. This ensures 131 | # that we share variables and our names are properly organized. 132 | with tf.variable_scope('baby_names', reuse=True): 133 | # Some ops have different behaviors in test vs train and these take a phase 134 | # argument. 135 | test_result = create_model(inp, 136 | labels, 137 | TIMESTEPS, 138 | per_example_weights, 139 | phase=pt.Phase.test) 140 | 141 | # For tracking accuracy in evaluation, we need to add an evaluation node. 142 | # We only run this when testing, so we need to specify that in the phase. 143 | # Some ops have different behaviors in test vs train and these take a phase 144 | # argument. 145 | accuracy = test_result.softmax.evaluate_classifier( 146 | labels, 147 | phase=pt.Phase.test, 148 | per_example_weights=per_example_weights) 149 | 150 | # We can also compute a batch accuracy to monitor progress. 151 | batch_accuracy = result.softmax.evaluate_classifier( 152 | labels, 153 | phase=pt.Phase.train, 154 | per_example_weights=per_example_weights) 155 | 156 | # Grab the inputs, outputs and lengths as numpy arrays. 157 | # Lengths could have been calculated from names, but it was easier to 158 | # calculate inside the utility function. 159 | names, sex, lengths = data_utils.baby_names(TIMESTEPS) 160 | 161 | epoch_size = len(names) // BATCH_SIZE 162 | # Create the gradient optimizer and apply it to the graph. 163 | # pt.apply_optimizer adds regularization losses and sets up a step counter 164 | # (pt.global_step()) for you. 165 | # This sequence model does very well with initially high rates. 166 | optimizer = tf.train.AdagradOptimizer( 167 | tf.train.exponential_decay(1.0, 168 | pt.global_step(), 169 | epoch_size, 170 | 0.95, 171 | staircase=True)) 172 | train_op = pt.apply_optimizer(optimizer, losses=[result.loss]) 173 | 174 | # We can set a save_path in the runner to automatically checkpoint every so 175 | # often. Otherwise at the end of the session, the model will be lost. 176 | runner = pt.train.Runner(save_path=FLAGS.save_path) 177 | with tf.Session(): 178 | for epoch in xrange(100): 179 | # Shuffle the training data. 180 | names, sex, lengths = data_utils.permute_data((names, sex, lengths)) 181 | 182 | runner.train_model( 183 | train_op, 184 | [result.loss, batch_accuracy], 185 | epoch_size, 186 | feed_vars=(input_placeholder, output_placeholder, length_placeholder), 187 | feed_data=pt.train.feed_numpy(BATCH_SIZE, names, sex, lengths), 188 | print_every=100) 189 | classification_accuracy = runner.evaluate_model( 190 | accuracy, 191 | epoch_size, 192 | print_every=0, 193 | feed_vars=(input_placeholder, output_placeholder, length_placeholder), 194 | feed_data=pt.train.feed_numpy(BATCH_SIZE, names, sex, lengths)) 195 | 196 | print('Accuracy after epoch %d: %g%%' % ( 197 | epoch + 1, classification_accuracy * 100)) 198 | 199 | 200 | if __name__ == '__main__': 201 | tf.app.run() 202 | -------------------------------------------------------------------------------- /prettytensor/tutorial/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Data utils bundles the utilties to download and munge data in numpy.""" 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import csv 17 | import gzip 18 | import os.path 19 | import sys 20 | 21 | 22 | 23 | import numpy as np 24 | from six.moves import xrange # pylint: disable=redefined-builtin 25 | from six.moves.urllib import request 26 | import tensorflow as tf 27 | 28 | import prettytensor as pt 29 | 30 | 31 | WORK_DIRECTORY = '/tmp/data' 32 | MNIST_URL = 'http://yann.lecun.com/exdb/mnist/' 33 | UNK = 0 34 | EOS = 1 35 | 36 | 37 | def maybe_download(url, filename): 38 | """Download the data from Yann's website, unless it's already here.""" 39 | if not os.path.exists(WORK_DIRECTORY): 40 | os.mkdir(WORK_DIRECTORY) 41 | filepath = os.path.join(WORK_DIRECTORY, filename) 42 | if not os.path.exists(filepath): 43 | filepath, _ = request.urlretrieve(url + filename, filepath) 44 | statinfo = os.stat(filepath) 45 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 46 | return filepath 47 | 48 | 49 | def mnist_extract_data(filename, num_images): 50 | """Extract the images into a 4D tensor [image index, y, x, channels]. 51 | 52 | Values are rescaled from [0, 255] down to [-0.5, 0.5]. 53 | 54 | Args: 55 | filename: The local filename. 56 | num_images: The number of images in this file. 57 | Returns: 58 | The data as a numpy array with the values centered. 59 | """ 60 | print('Extracting', filename) 61 | with gzip.open(filename) as bytestream: 62 | bytestream.read(16) 63 | buf = bytestream.read(28 * 28 * num_images) 64 | data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32) 65 | data -= 255 / 2.0 66 | data /= 255.0 67 | data = data.reshape(num_images, 28, 28, 1) 68 | return data 69 | 70 | 71 | def mnist_extract_labels(filename, num_images): 72 | """Extract the labels into a 1-hot matrix [image index, label index].""" 73 | print('Extracting', filename) 74 | with gzip.open(filename) as bytestream: 75 | bytestream.read(8) 76 | buf = bytestream.read(1 * num_images) 77 | labels = np.frombuffer(buf, dtype=np.uint8) 78 | # Convert to dense 1-hot representation. 
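# (Added explanation.) np.arange(10) has shape [10] and labels[:, None] has
# shape [num_images, 1], so broadcasting the == comparison yields a
# [num_images, 10] boolean matrix with exactly one True per row; astype then
# turns it into rows of 0.0/1.0, i.e. dense one-hot labels.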
79 | return (np.arange(10) == labels[:, None]).astype(np.float32) 80 | 81 | 82 | def permute_data(arrays, random_state=None): 83 | """Permute multiple numpy arrays with the same order.""" 84 | if any(len(a) != len(arrays[0]) for a in arrays): 85 | raise ValueError('All arrays must be the same length.') 86 | if not random_state: 87 | random_state = np.random 88 | order = random_state.permutation(len(arrays[0])) 89 | return [a[order] for a in arrays] 90 | 91 | 92 | def mnist(training): 93 | """Downloads MNIST and loads it into numpy arrays.""" 94 | if training: 95 | data_filename = 'train-images-idx3-ubyte.gz' 96 | labels_filename = 'train-labels-idx1-ubyte.gz' 97 | count = 60000 98 | else: 99 | data_filename = 't10k-images-idx3-ubyte.gz' 100 | labels_filename = 't10k-labels-idx1-ubyte.gz' 101 | count = 10000 102 | data_filename = maybe_download(MNIST_URL, data_filename) 103 | labels_filename = maybe_download(MNIST_URL, labels_filename) 104 | 105 | return (mnist_extract_data(data_filename, count), 106 | mnist_extract_labels(labels_filename, count)) 107 | 108 | 109 | def convert_to_int(char): 110 | i = ord(char) 111 | if i >= 128: 112 | return UNK 113 | return i 114 | 115 | 116 | def shakespeare(chunk_size): 117 | """Downloads Shakespeare, converts it into ASCII codes and chunks it. 118 | 119 | Args: 120 | chunk_size: The dataset is broken down so that it is shaped into batches x 121 | chunk_size. 122 | Returns: 123 | A numpy array of ASCII codes shaped into batches x chunk_size. 124 | """ 125 | file_name = maybe_download('http://cs.stanford.edu/people/karpathy/char-rnn/', 126 | 'shakespear.txt') 127 | with open(file_name) as f: 128 | shakespeare_full = f.read() 129 | 130 | # Truncate the data. 131 | length = (len(shakespeare_full) // chunk_size) * chunk_size 132 | if length < len(shakespeare_full): 133 | shakespeare_full = shakespeare_full[:length] 134 | arr = np.array([convert_to_int(c) for c in shakespeare_full])[ 135 | 0:len(shakespeare_full) / chunk_size * chunk_size] 136 | return arr.reshape((len(arr) / chunk_size, chunk_size)) 137 | 138 | 139 | def baby_names(max_length=15): 140 | """Opens the baby_names csv file and produces numpy array. 141 | 142 | Args: 143 | max_length: The maximum length, 15 was the longest name when this was 144 | written. Short entries will be padded with the EOS marker. 145 | Returns: 146 | A numpy array of the names converted to ascii codes, the labels and an 147 | array of lengths. 148 | Raises: 149 | ValueError: if max_length is too small. 150 | """ 151 | names = [] 152 | lengths = [] 153 | targets = [] 154 | with open(os.path.join(os.path.dirname(sys.modules[__name__].__file__), 155 | 'baby_names.csv'), 'rb') as f: 156 | first = True 157 | for l in csv.reader(f, delimiter=','): 158 | if first: 159 | first = False 160 | continue 161 | assert len(l) == 4, l 162 | name = l[0] 163 | if max_length < len(name): 164 | raise ValueError('Max length is too small: %d > %d' % 165 | (max_length, len(name))) 166 | chars = [convert_to_int(c) for c in name] 167 | names.append(chars + ([EOS] * (max_length - len(chars)))) 168 | lengths.append([len(name)]) 169 | values = [float(l[2]), float(l[3])] 170 | if abs(sum(values) - 1) > 0.001: 171 | raise ValueError('Each row must sum to 1: %s' % l) 172 | targets.append(values) 173 | return np.array(names), np.array(targets), np.array(lengths) 174 | 175 | 176 | def reshape_data(tensor, per_example_length=1): 177 | """Reshapes input so that it is appropriate for sequence_lstm.. 
178 | 179 | The expected format for sequence lstms is 180 | [timesteps * batch, per_example_length] and the data produced by the utilities 181 | is [batch, timestep, *optional* expected_length]. The result can be cleaved 182 | so that there is a Tensor per timestep. 183 | 184 | Args: 185 | tensor: The tensor to reshape. 186 | per_example_length: The number of examples at each timestep. 187 | Returns: 188 | A Pretty Tensor that is compatible with cleave and then sequence_lstm. 189 | 190 | """ 191 | # We can put the data into a format that can be easily cleaved by 192 | # transposing it (so that it varies fastest in batch) and then making each 193 | # component have a single value. 194 | # This will make it compatible with the Pretty Tensor function 195 | # cleave_sequence. 196 | dims = [1, 0] 197 | for i in xrange(2, tensor.get_shape().ndims): 198 | dims.append(i) 199 | return pt.wrap(tf.transpose(tensor, dims)).reshape([-1, per_example_length]) 200 | -------------------------------------------------------------------------------- /prettytensor/tutorial/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """MNIST tutorial. 12 | 13 | This uses Pretty Tensor to define and train either a 2 layer model or a 14 | convolutional model in the style of LeNet 5. 15 | See: http://yann.lecun.com/exdb/lenet/ 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | 22 | from six.moves import xrange # pylint: disable=redefined-builtin 23 | import tensorflow as tf 24 | 25 | import prettytensor as pt 26 | from prettytensor.tutorial import data_utils 27 | 28 | tf.app.flags.DEFINE_string( 29 | 'save_path', None, 'Where to save the model checkpoints.') 30 | FLAGS = tf.app.flags.FLAGS 31 | 32 | BATCH_SIZE = 50 33 | EPOCH_SIZE = 60000 // BATCH_SIZE 34 | TEST_SIZE = 10000 // BATCH_SIZE 35 | 36 | tf.app.flags.DEFINE_string('model', 'full', 37 | 'Choose one of the models, either full or conv') 38 | FLAGS = tf.app.flags.FLAGS 39 | 40 | 41 | def multilayer_fully_connected(images, labels): 42 | """Creates a multi layer network of fully_connected layers. 43 | 44 | Each layer is 100 neurons. Please change this to experiment with 45 | architectures. 46 | 47 | Args: 48 | images: The input images. 49 | labels: The labels as dense one-hot vectors. 50 | Returns: 51 | A softmax result. 52 | """ 53 | # Pretty Tensor is a thin wrapper on Tensors. 54 | # Change this method to experiment with other architectures 55 | images = pt.wrap(images) 56 | with pt.defaults_scope(activation_fn=tf.nn.relu, l2loss=0.00001): 57 | return (images.flatten().fully_connected(100).fully_connected(100) 58 | .softmax_classifier(10, labels)) 59 | 60 | 61 | def lenet5(images, labels): 62 | """Creates a multi layer convolutional network. 63 | 64 | The architecture is similar to that defined in LeNet 5. 
65 | Please change this to experiment with architectures. 66 | 67 | Args: 68 | images: The input images. 69 | labels: The labels as dense one-hot vectors. 70 | Returns: 71 | A softmax result. 72 | """ 73 | images = pt.wrap(images) 74 | with pt.defaults_scope(activation_fn=tf.nn.relu, l2loss=0.00001): 75 | return (images.conv2d(5, 20).max_pool(2, 2).conv2d(5, 50).max_pool(2, 2) 76 | .flatten().fully_connected(500).softmax_classifier(10, labels)) 77 | 78 | 79 | def main(_=None): 80 | # Since we are feeding our data as numpy arrays, we need to create 81 | # placeholders in the graph. 82 | # These must then be fed using the feed dict. 83 | image_placeholder = tf.placeholder(tf.float32, [BATCH_SIZE, 28, 28, 1]) 84 | labels_placeholder = tf.placeholder(tf.float32, [BATCH_SIZE, 10]) 85 | 86 | # Create our model. The result of softmax_classifier is a namedtuple 87 | # that has members result.loss and result.softmax. 88 | if FLAGS.model == 'full': 89 | result = multilayer_fully_connected(image_placeholder, labels_placeholder) 90 | elif FLAGS.model == 'conv': 91 | result = lenet5(image_placeholder, labels_placeholder) 92 | else: 93 | raise ValueError('model must be full or conv: %s' % FLAGS.model) 94 | 95 | # For tracking accuracy in evaluation, we need to add an evaluation node. 96 | # We only include this part of the graph when testing, so we need to specify 97 | # that in the phase. 98 | # Some ops have different behaviors in test vs train and these take a phase 99 | # argument. 100 | accuracy = result.softmax.evaluate_classifier(labels_placeholder, 101 | phase=pt.Phase.test) 102 | 103 | # Grab the data as numpy arrays. 104 | train_images, train_labels = data_utils.mnist(training=True) 105 | test_images, test_labels = data_utils.mnist(training=False) 106 | 107 | # Create the gradient optimizer and apply it to the graph. 108 | # pt.apply_optimizer adds regularization losses and sets up a step counter 109 | # (pt.global_step()) for you. 110 | optimizer = tf.train.GradientDescentOptimizer(0.01) 111 | train_op = pt.apply_optimizer(optimizer, losses=[result.loss]) 112 | 113 | # We can set a save_path in the runner to automatically checkpoint every so 114 | # often. Otherwise at the end of the session, the model will be lost. 115 | runner = pt.train.Runner(save_path=FLAGS.save_path) 116 | with tf.Session(): 117 | for epoch in xrange(10): 118 | # Shuffle the training data. 119 | train_images, train_labels = data_utils.permute_data( 120 | (train_images, train_labels)) 121 | 122 | runner.train_model( 123 | train_op, 124 | result.loss, 125 | EPOCH_SIZE, 126 | feed_vars=(image_placeholder, labels_placeholder), 127 | feed_data=pt.train.feed_numpy(BATCH_SIZE, train_images, train_labels), 128 | print_every=100) 129 | classification_accuracy = runner.evaluate_model( 130 | accuracy, 131 | TEST_SIZE, 132 | feed_vars=(image_placeholder, labels_placeholder), 133 | feed_data=pt.train.feed_numpy(BATCH_SIZE, test_images, test_labels)) 134 | print('Accuracy after %d epoch %g%%' % ( 135 | epoch + 1, classification_accuracy * 100)) 136 | 137 | 138 | if __name__ == '__main__': 139 | tf.app.run() 140 | -------------------------------------------------------------------------------- /prettytensor/tutorial/shakespeare.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """Shakespeare tutorial. 12 | 13 | The Shakespeare tutorial downloads a snippet of Shakespeare, munges the data 14 | into the correct format and then creates a 2 layer LSTM to predict the next 15 | character given the current character. 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import random 22 | 23 | 24 | import numpy 25 | from six.moves import xrange # pylint: disable=redefined-builtin 26 | import tensorflow as tf 27 | 28 | import prettytensor as pt 29 | from prettytensor.tutorial import data_utils 30 | 31 | tf.app.flags.DEFINE_string( 32 | 'save_path', None, 'Where to save the model checkpoints.') 33 | tf.app.flags.DEFINE_integer( 34 | 'epochs', 10, 'The number of epochs to run training on this model.') 35 | FLAGS = tf.app.flags.FLAGS 36 | 37 | BATCH_SIZE = 8 38 | CHARS = 128 39 | TIMESTEPS = 100 40 | 41 | # The size of the embedding for each character that will be learned. 42 | EMBEDDING_SIZE = 16 43 | 44 | # The number of cells in the lower and upper LSTM layers. 45 | LOWER = 128 46 | UPPER = 256 47 | 48 | 49 | def create_model(text_in, timesteps, phase): 50 | """Creates a 2 layer LSTM model with dropout. 51 | 52 | Args: 53 | text_in: The input text as ASCII ordinals in a Tensor. 54 | timesteps: The number of timesteps in the sequence. 55 | phase: Phase controls whether or not dropout is active. In training mode 56 | we want to perform dropout, but in test we want to disable it. 57 | Returns: 58 | The logits. 59 | """ 60 | with pt.defaults_scope(activation_fn=tf.nn.relu, l2loss=0.00001): 61 | # The embedding lookup must be placed on a cpu. 62 | with tf.device('/cpu:0'): 63 | embedded = text_in.embedding_lookup(CHARS, [EMBEDDING_SIZE]) 64 | # Because the sequence LSTM expects each timestep to be its own Tensor, 65 | # we need to cleave the sequence. 66 | # Below we can build a stacked 2 layer LSTM by just chaining them together. 67 | # You can stack as many layers as you want. 68 | lstm = (embedded 69 | .cleave_sequence(timesteps) 70 | .sequence_lstm(LOWER) 71 | .sequence_lstm(UPPER)) 72 | 73 | # The classifier is much more efficient if it runs across the entire 74 | # dataset at once, so we want to squash (i.e. uncleave). 75 | # Note: if phase is test, dropout is a noop. 76 | return (lstm.squash_sequence() 77 | .dropout(keep_prob=0.8, phase=phase) 78 | .fully_connected(CHARS, activation_fn=None)) 79 | 80 | 81 | def sample( 82 | input_placeholder, logits, seed=None, max_length=1024, temperature=1.0): 83 | """Samples from the LSTM model. 84 | 85 | Sampling is done by first running either the seed or an arbitrary character 86 | through the model and then drawing the next character from the probability 87 | distribution definted by `softmax`. 88 | 89 | Args: 90 | input_placeholder: A placeholder that expects a scalar feed. 91 | logits: The logits. This works with the logits so that it can apply the 92 | temperature. 93 | seed: Either a string of characters to prime the network or None. 94 | max_length: The maximum length to draw in case EOS is not reached. 
95 | temperature: A value that is used to renormalize the inputs. A higher value 96 | selects less likely choices. 97 | Returns: 98 | A string that was sampled from the model. 99 | """ 100 | assert temperature > 0, 'Temperature must be greater than 0.' 101 | if not seed: 102 | # The model expects an input to do inference, so seed with a single letter. 103 | seed = chr(ord('A') + random.randint(0, 25)) 104 | result = '' 105 | 106 | # The recurrent runner takes care of tracking the model's state at each step 107 | # and provides a reset call to zero it out for each query. 108 | recurrent_runner = pt.train.RecurrentRunner() 109 | 110 | # We need to reset the hidden state for each query. 111 | recurrent_runner.reset() 112 | # Initialize the system 113 | for c in seed[:-1]: 114 | recurrent_runner.run([logits], 115 | {input_placeholder: data_utils.convert_to_int(c)}) 116 | result += c 117 | 118 | # Start sampling! 119 | ci = ord(seed[-1]) 120 | while len(result) < max_length and ci != data_utils.EOS: 121 | result += chr(ci) 122 | # The softmax is probability normalized and would have been appropriate here 123 | # if we weren't applying the temperature (temperature could also be done in 124 | # TensorFlow). 125 | logit_result = recurrent_runner.run([logits], 126 | {input_placeholder: ci})[0][0] 127 | logit_result /= temperature 128 | 129 | # Apply the softmax in numpy to convert from logits to probabilities. 130 | # Subtract off the max for numerical stability -- logits are invariant to 131 | # additive scaling and this eliminates overflows. 132 | logit_result -= logit_result.max() 133 | 134 | distribution = numpy.exp(logit_result) 135 | distribution /= distribution.sum() 136 | 137 | # Numpy multinomial needs the value to be strictly < 1 138 | distribution -= .00000001 139 | ci = numpy.argmax(numpy.random.multinomial(1, distribution)) 140 | result += chr(ci) # Add the last letter. 141 | return result 142 | 143 | 144 | def main(_=None): 145 | print('Starting Shakespeare') 146 | 147 | # Since we are feeding our data as numpy arrays, we need to create 148 | # placeholders in the graph. 149 | # These must then be fed using the feed dict. 150 | input_placeholder = tf.placeholder(tf.int32, [BATCH_SIZE, TIMESTEPS]) 151 | output_placeholder = tf.placeholder(tf.int32, [BATCH_SIZE, TIMESTEPS]) 152 | 153 | merged_size = BATCH_SIZE * TIMESTEPS 154 | 155 | inp = data_utils.reshape_data(input_placeholder) 156 | 157 | # We need a dense output to calculate loss and accuracy. 158 | # sparse_to_dense does a lookup using the indices from the first Tensor. 159 | # Because we are filling in a 2D array, the indices need to be 2 dimensional. 160 | t = tf.concat( 161 | [ 162 | tf.constant( 163 | numpy.arange(merged_size).reshape((merged_size, 1)), 164 | dtype=tf.int32), data_utils.reshape_data(output_placeholder) 165 | ], 166 | 1) 167 | 168 | labels = tf.sparse_to_dense(t, [merged_size, CHARS], 1.0, 0.0) 169 | 170 | # Some ops have different behaviors in test vs train and these take a phase 171 | # argument. 172 | with tf.variable_scope('shakespeare'): 173 | training_logits = create_model(inp, TIMESTEPS, pt.Phase.train) 174 | # Create the result. Softmax applies softmax and creates a cross entropy 175 | # loss. The result is a namedtuple. 176 | training_result = training_logits.softmax(labels) 177 | 178 | # Create the gradient optimizer and apply it to the graph. 179 | # pt.apply_optimizer adds regularization losses and sets up a step counter 180 | # (pt.global_step()) for you. 
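# (Optional variation, not in the original tutorial.) Because apply_optimizer
# wires up pt.global_step(), the fixed rate below could be swapped for a
# per-epoch decay in the style of baby_names.py; `decay_steps` is a
# placeholder for the number of batches in one epoch:
#   optimizer = tf.train.AdagradOptimizer(
#       tf.train.exponential_decay(0.5, pt.global_step(), decay_steps, 0.95,
#                                  staircase=True))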
181 | optimizer = tf.train.AdagradOptimizer(0.5) 182 | train_op = pt.apply_optimizer(optimizer, losses=[training_result.loss]) 183 | 184 | # For tracking accuracy in evaluation, we need to add an evaluation node. 185 | # We only run this when testing, so we need to specify that in the phase. 186 | # We also want to disable dropout, so we pass the phase to create_model. 187 | 188 | # Call variable scope by name so we also create a name scope. This ensures 189 | # that we share variables and our names are properly organized. 190 | with tf.variable_scope('shakespeare', reuse=True): 191 | test_logits = create_model(inp, TIMESTEPS, pt.Phase.test) 192 | test_result = test_logits.softmax(labels) 193 | 194 | # Accuracy creates variables, so make it outside of the above scope. 195 | accuracy = test_result.softmax.evaluate_classifier(labels, 196 | phase=pt.Phase.test) 197 | 198 | # Create an inference model so that we can sample. The big difference is 199 | # that the input is a single character and it requires reset nodes. 200 | # Also place summaries in a different collection. The default summaries have 201 | # dependencies on running the graph and would introduce a dependence on the 202 | # inference placeholder. 203 | with tf.variable_scope('shakespeare', reuse=True), pt.defaults_scope( 204 | summary_collections=['INFERENCE_SUMMARIES']): 205 | inference_input = tf.placeholder(tf.int32, []) 206 | # Needs to be 2 dimensional so that it matches the dims of the other models. 207 | reshaped = pt.wrap(inference_input).reshape([1, 1]) 208 | inference_logits = create_model(reshaped, 1, pt.Phase.infer) 209 | 210 | # Grab the data as numpy arrays. 211 | shakespeare = data_utils.shakespeare(TIMESTEPS + 1) 212 | shakespeare_in = shakespeare[:, :-1] 213 | shakespeare_out = shakespeare[:, 1:] 214 | 215 | # We can set a save_path in the runner to automatically checkpoint every so 216 | # often. Otherwise at the end of the session, the model will be lost. 217 | runner = pt.train.Runner(save_path=FLAGS.save_path) 218 | with tf.Session(): 219 | for epoch in xrange(FLAGS.epochs): 220 | # Shuffle the training data. 221 | shakespeare_in, shakespeare_out = data_utils.permute_data( 222 | (shakespeare_in, shakespeare_out)) 223 | 224 | runner.train_model(train_op, 225 | training_result.loss, 226 | len(shakespeare_in) // BATCH_SIZE, 227 | feed_vars=(input_placeholder, output_placeholder), 228 | feed_data=pt.train.feed_numpy( 229 | BATCH_SIZE, shakespeare_in, shakespeare_out), 230 | print_every=10) 231 | classification_accuracy = runner.evaluate_model( 232 | accuracy, 233 | len(shakespeare_in) // BATCH_SIZE, 234 | feed_vars=(input_placeholder, output_placeholder), 235 | feed_data=pt.train.feed_numpy(BATCH_SIZE, shakespeare_in, 236 | shakespeare_out)) 237 | 238 | print('Next character accuracy after epoch %d: %g%%' % ( 239 | epoch + 1, classification_accuracy * 100)) 240 | 241 | # Use a temperature smaller than 1 because the early stages of the model 242 | # don't assign much confidence. 243 | print(sample(inference_input, 244 | inference_logits, 245 | max_length=128, 246 | temperature=0.5)) 247 | 248 | # Print a sampling from the model. 249 | print(sample(inference_input, inference_logits)) 250 | 251 | 252 | if __name__ == '__main__': 253 | tf.app.run() 254 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 
2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | import fnmatch 13 | import os 14 | from setuptools import find_packages, setup, Extension 15 | 16 | _VERSION = '0.7.2' 17 | 18 | REQUIRED_PACKAGES = [ 19 | 'enum34 >= 1.0.0', 20 | 'six >= 1.10.0', 21 | # Having an explicit dependency here breaks GPU 22 | # tensorflow >= 0.12.0rc0', 23 | ] 24 | 25 | # pylint: disable=line-too-long 26 | CONSOLE_SCRIPTS = [ 27 | 'prettytensor_model_mnist = prettytensor.tutorial.mnist:main', 28 | 'prettytensor_model_shakespeare = prettytensor.tutorial.shakespeare:main', 29 | 'prettytensor_model_baby_names = prettytensor.tutorial.baby_names:main', 30 | ] 31 | # pylint: enable=line-too-long 32 | 33 | TEST_PACKAGES = [ 34 | 'nose >= 1.3.7', 35 | ] 36 | 37 | setup( 38 | name='prettytensor', 39 | version=_VERSION, 40 | description='Pretty Tensor makes learning beautiful', 41 | long_description='', 42 | url='https://github.com/google/prettytensor', 43 | author='Eider Moore', 44 | author_email='opensource@google.com', 45 | # Contained modules and scripts. 46 | packages=find_packages(), 47 | include_package_data=True, 48 | package_data={ 49 | 'prettytensor': ['tutorial/baby_names.csv'] 50 | }, 51 | entry_points={ 52 | 'console_scripts': CONSOLE_SCRIPTS 53 | }, 54 | install_requires=REQUIRED_PACKAGES, 55 | tests_require=REQUIRED_PACKAGES + TEST_PACKAGES, 56 | test_suite = 'nose.collector', 57 | # PyPI package information. 
58 | classifiers=[ 59 | 'Development Status :: 4 - Beta', 60 | 'Intended Audience :: Developers', 61 | 'Intended Audience :: Education', 62 | 'Intended Audience :: Science/Research', 63 | 'License :: OSI Approved :: Apache Software License', 64 | 'Programming Language :: Python :: 2.7', 65 | 'Topic :: Scientific/Engineering :: Mathematics', 66 | 'Topic :: Software Development :: Libraries :: Python Modules', 67 | 'Topic :: Software Development :: Libraries', 68 | ], 69 | license='Apache 2.0', 70 | keywords='tensorflow tensor machine learning', 71 | ) 72 | -------------------------------------------------------------------------------- /test_pip_install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | cur_dir=$(pwd) 5 | 6 | # Python 2 7 | 8 | rm -rf /tmp/clean-venv 9 | virtualenv /tmp/clean-venv 10 | cd /tmp/clean-venv 11 | source bin/activate 12 | pip install --upgrade pip 13 | pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0rc0-cp27-none-linux_x86_64.whl 14 | pip install prettytensor 15 | pip install nose 16 | nosetests prettytensor 17 | 18 | deactivate 19 | 20 | cd "$cur_dir" 21 | # Python 3 22 | 23 | rm -rf /tmp/clean-venv 24 | virtualenv -p python3 /tmp/clean-venv 25 | cd /tmp/clean-venv 26 | source bin/activate 27 | pip install --upgrade pip 28 | pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0rc0-cp34-cp34m-linux_x86_64.whl 29 | pip install prettytensor 30 | pip install nose 31 | nosetests prettytensor 32 | 33 | deactivate 34 | rm -rf /tmp/clean-venv 35 | 36 | cd "$cur_dir" 37 | -------------------------------------------------------------------------------- /unrolled_lstm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/prettytensor/75daa0b11252590f548da5647addc0ea610c4c45/unrolled_lstm.png --------------------------------------------------------------------------------