├── CHANGELIST.md
├── CONTRIBUTING.md
├── LICENSE-2.0.txt
├── README.md
├── docs
│   ├── Bookkeeper.md
│   ├── Loss.md
│   ├── PrettyTensor.md
│   ├── PrettyTensorTupleMixin.md
│   └── pretty_tensor_top_level.md
├── inception_module.png
├── prettytensor
│   ├── __init__.py
│   ├── bookkeeper.py
│   ├── bookkeeper_test.py
│   ├── chain_dict.py
│   ├── chain_dict_test.py
│   ├── funcs.py
│   ├── functions.py
│   ├── functions_test.py
│   ├── input_helpers.py
│   ├── layers.py
│   ├── local_trainer.py
│   ├── local_trainer_test.py
│   ├── parameters.py
│   ├── pretty_tensor_class.py
│   ├── pretty_tensor_image_methods.py
│   ├── pretty_tensor_loss_methods.py
│   ├── pretty_tensor_methods.py
│   ├── pretty_tensor_normalization_methods.py
│   ├── pretty_tensor_sparse_methods.py
│   ├── pretty_tensor_test.py
│   ├── pretty_tensor_testing.py
│   ├── recurrent_networks.py
│   ├── recurrent_networks_test.py
│   ├── recurrent_networks_testing_utils.py
│   ├── replay_queue.py
│   ├── replay_queue_test.py
│   ├── scopes.py
│   ├── scopes_test.py
│   ├── sequence_with_deltas.py
│   ├── templated_pretty_tensor_test.py
│   ├── train.py
│   └── tutorial
│       ├── README.md
│       ├── __init__.py
│       ├── baby_names.csv
│       ├── baby_names.py
│       ├── data_utils.py
│       ├── mnist.py
│       └── shakespeare.py
├── setup.py
├── test_pip_install.sh
└── unrolled_lstm.png
/CHANGELIST.md:
--------------------------------------------------------------------------------
1 | ## 0.7.3
2 |
3 | 1. Maintenance release w/ some deprecation notice fixes. Note: this may change the names of the summaries.
4 |
5 | ## 0.7.2
6 |
7 | 1. Maintenance release w/ change to project pip dependencies to better support GPU builds.
8 |
9 | ## 0.7.1
10 |
11 | ### General
12 | 1. Changed `weights` to `init` and `bias_init` to `bias` and made these support initialization functions or `Tensors`.
13 | 2. Added `parameter_modifier`. These are functions that are applied after creating a `Variable`, but before it is used in the graph. They allow you to apply a function like normalization or drop connect to the graph. See `pt.parameters` for details.
14 | 3. Added support for directly chaining many useful TensorFlow functions. See [pretty_tensor_methods.py](https://github.com/google/prettytensor/blob/master/prettytensor/pretty_tensor_methods.py#L700) for details. Note: when a function is removed from tf (e.g. complex_abs), it will be removed here.
15 | 4. Changed internal calls to TF to comply with API changes.
16 | 5. Internally changed the name of the first parameter to be more consistent. This should not be user visible since it is the variable to the left of the '.'.
17 |
18 | ### Losses
19 |
20 | 1. Added `per_output_weights` to `binary_cross_entropy_with_logits`, which allows you to weight the loss from classes and examples.
21 | 2. Added `sparse_cross_entropy` to efficiently calculate the loss when you have a vector of one-hot labels as indices (`tf.int32`/`tf.int64`). Also added `evaluate_classifier_sparse`.
22 | 3. Fixed `softmax_classifier_with_sampled_loss` to support specified parameters and parameter modification.
23 | 4. Standardized on `num_classes` and changed the parameter name in `softmax_classifier` accordingly.
24 |
25 | ### Optimizer
26 | 1. Added `clip_gradients_by_norm` to `apply_optimizer`.
27 |
28 | ### Images
29 |
30 | 1. Added a differentiable sampling method for images called `bilinear_sampling`.
31 |
32 |
33 | ## 0.6.2
34 |
35 | Add Depthwise Convolution
36 |
37 | ### Batch Normalization
38 | 1. Make Batch Normalization work with arbitrary dimensionality.
39 | 2. Allow passing through arguments to BN using a namedtuple.
40 | 3. Add BN default values.
41 | 4. Remove requirement to use with_update_ops to make BN accumulate values for
42 | inference.
43 |
44 |
45 |
46 | ## 0.6.0
47 |
48 | 1. Adding scoped control of summary creation.
49 | 2. Scoped variable collections.
50 | 3. Can initialize variables from literals.
51 | 4. Fixed operators -- Sequential's plus no longer has side effects.
52 | 5. Operators now work on Pretty Tensors that contain lists.
53 |
54 |
55 | Note: (4) may be breaking!
56 |
57 | ## 0.5.3
58 |
59 | 1. Fixing tutorials (thanks jkahn!)
60 | 2. Adding a precision and recall evaluation.
61 | 3. Various bug fixes.
62 |
63 | Tested on TF 0.7.1
64 |
65 | ## 0.5.2
66 |
67 | 1. Various bug fixes
68 | 2. Reordered the arguments to a better positional order.
69 | 3. Added a length argument to recurrent networks to support short circuiting.
70 | 4. Improvements to reshape.
71 | 5. Python 3 support.
72 |
73 | ## 0.5.0
74 |
75 | Initial Release
76 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | Want to contribute? Great! First, read this page (including the small print at the end).
2 |
3 | ### Before you contribute
4 | Before we can use your code, you must sign the
5 | [Google Individual Contributor License Agreement](https://cla.developers.google.com/about/google-individual)
7 | (CLA), which you can do online. The CLA is necessary mainly because you own the
8 | copyright to your changes, even after your contribution becomes part of our
9 | codebase, so we need your permission to use and distribute your code. We also
10 | need to be sure of various other things—for instance that you'll tell us if you
11 | know that your code infringes on other people's patents. You don't have to sign
12 | the CLA until after you've submitted your code for review and a member has
13 | approved it, but you must do it before we can put your code into our codebase.
14 | Before you start working on a larger contribution, you should get in touch with
15 | us first through the issue tracker with your idea so that we can help out and
16 | possibly guide you. Coordinating up front makes it much easier to avoid
17 | frustration later on.
18 |
19 | ### Code reviews
20 | All submissions, including submissions by project members, require review. We
21 | use Github pull requests for this purpose.
22 |
23 | ### The small print
24 | Contributions made by corporations are covered by a different agreement than
25 | the one above, the
26 | [Software Grant and Corporate Contributor License Agreement](https://cla.developers.google.com/about/google-corporate).
28 |
--------------------------------------------------------------------------------
/LICENSE-2.0.txt:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pretty Tensor - Fluent Neural Networks in TensorFlow
2 |
3 | Pretty Tensor provides a high level builder API for TensorFlow. It provides
4 | thin wrappers on Tensors so that you can easily build multi-layer neural
5 | networks.
6 |
7 | Pretty Tensor provides a set of objects that behave likes Tensors, but also
8 | support a chainable object syntax to quickly define neural networks
9 | and other layered architectures in TensorFlow.
10 |
11 | result = (pretty_tensor.wrap(input_data, m)
12 | .flatten()
13 | .fully_connected(200, activation_fn=tf.nn.relu)
14 | .fully_connected(10, activation_fn=None)
15 | .softmax(labels, name=softmax_name))
16 |
17 | For full documentation of the PrettyTensor object and all available
18 | operations, see
19 | [Available Operations](docs/PrettyTensor.md), or check out the [complete
20 | documentation](docs/pretty_tensor_top_level.md).
21 |
22 | See the tutorial directory for samples:
23 | [tutorial/](prettytensor/tutorial/)
24 |
25 | ## Installation
26 |
27 | The easiest installation is just to use pip:
28 |
29 | 1. Follow the instructions at
30 | [tensorflow.org](https://www.tensorflow.org/versions/master/get_started/os_setup.html#pip_install)
31 | 2. `pip install prettytensor`
32 |
33 |
34 | **Note:** Head is tested against the TensorFlow nightly builds and the pip package is tested against the TensorFlow release.
35 |
36 | ## Quick start
37 |
38 | ### Imports
39 | import prettytensor as pt
40 | import tensorflow as tf
41 |
42 | ### Setup your input
43 | my_inputs = ...  # numpy array of shape (BATCHES, BATCH_SIZE, DATA_SIZE)
44 | my_labels = ...  # numpy array of shape (BATCHES, BATCH_SIZE, CLASSES)
45 | input_tensor = tf.placeholder(tf.float32, shape=(BATCH_SIZE, DATA_SIZE))
46 | label_tensor = tf.placeholder(tf.float32, shape=(BATCH_SIZE, CLASSES))
47 | pretty_input = pt.wrap(input_tensor)
48 |
49 | ### Define your model
50 | softmax, loss = (pretty_input.
51 | fully_connected(100).
52 | softmax_classifier(CLASSES, labels=label_tensor))
53 |
54 | ### Train and evaluate
55 | accuracy = softmax.evaluate_classifier(label_tensor)
56 |
57 | optimizer = tf.train.GradientDescentOptimizer(0.1) # learning rate
58 | train_op = pt.apply_optimizer(optimizer, losses=[loss])
59 |
60 | init_op = tf.initialize_all_variables()
61 |
62 | with tf.Session() as sess:
63 | sess.run(init_op)
64 | for inp, label in zip(my_inputs, my_labels):
65 | unused_loss_value, accuracy_value = sess.run([loss, accuracy],
66 | {input_tensor: inp, label_tensor: label})
67 | print('Accuracy: %g' % accuracy_value)
68 |
69 | ## Features
70 |
71 | ### Thin
72 |
73 | #### Full power of TensorFlow is easy to use
74 |
75 | Pretty Tensors can be used (almost) everywhere that a tensor can. Just call
76 | `pt.wrap` to make a tensor pretty.
77 |
78 | You can also add any existing TensorFlow function to the chain using `apply`.
79 | `apply` applies the current Tensor as the first argument and takes all the other
80 | arguments as normal.
81 |
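For instance, a hypothetical chain (using `tf.clip_by_value`, an ordinary TensorFlow function, purely for illustration):

    # The wrapped tensor becomes the first positional argument;
    # the remaining arguments pass through unchanged.
    clipped = pretty_input.apply(tf.clip_by_value, 0.0, 1.0)
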
82 | *Note:* because apply is so generic, Pretty Tensor doesn't try to wrap the
83 | world.
84 |
85 | #### Plays well with other libraries
86 |
87 | It also uses standard TensorFlow idioms so that it plays well with other
88 | libraries; this means that you can use it a little bit in a model or throughout.
89 | Just make sure to run the update_ops on each training step
90 | (see [with_update_ops](docs/pretty_tensor_top_level.md#with_update_ops)).
91 |
92 | ### Terse
93 |
94 | You've already seen how a Pretty Tensor is chainable and you may have noticed
95 | that it takes care of handling the input shape. One other feature worth noting
96 | is defaults. Using defaults you can specify reused values in a single place
97 | without having to repeat yourself.
98 |
99 | with pt.defaults_scope(activation_fn=tf.nn.relu):
100 | hidden_output2 = (pretty_images.flatten()
101 | .fully_connected(100)
102 | .fully_connected(100))
103 |
104 | Check out the documentation to see
105 | [all supported defaults](docs/pretty_tensor_top_level.md#defaults_scope).
106 |
107 | ### Code matches model
108 |
109 | Sequential mode lets you break model construction across lines and provides
110 | the subdivide syntactic sugar that makes it easy to define and understand
111 | complex structures like an [inception module](http://arxiv.org/abs/1409.4842):
112 |
113 |
114 | with pretty_tensor.defaults_scope(activation_fn=tf.nn.relu):
115 | seq = pretty_input.sequential()
116 | with seq.subdivide(4) as towers:
117 | towers[0].conv2d(1, 64)
118 | towers[1].conv2d(1, 112).conv2d(3, 224)
119 | towers[2].conv2d(1, 32).conv2d(5, 64)
120 | towers[3].max_pool(2, 3).conv2d(1, 32)
121 |
122 | 
123 |
124 | Templates provide guaranteed parameter reuse and make unrolling recurrent
125 | networks easy:
126 |
127 | output, s = [], tf.zeros([BATCH, 256 * 2])
128 |
129 | A = (pretty_tensor.template('x')
130 | .lstm_cell(num_units=256, state=UnboundVariable('state')))
131 |
132 | for x in pretty_input_array:
133 | h, s = A.construct(x=x, state=s)
134 | output.append(h)
135 |
136 | There are also some convenient shorthands for LSTMs and GRUs:
137 |
138 | pretty_input_array.sequence_lstm(num_units=256)
139 |
140 | 
141 |
142 | ### Extensible
143 |
144 | You can call any existing operation by using `apply` and it will simply
145 | substitute the current tensor for the first argument.
146 |
147 | pretty_input.apply(tf.mul, 5)
148 |
149 | You can also create a new operation. There are two supported registration
150 | mechanisms to add your own functions. `@Register()` allows you to create a
151 | method on PrettyTensor that operates on the Tensors and returns either a loss or
152 | a new value. Name scoping and variable scoping are handled by the framework.
153 |
154 | The following example adds a leaky_relu method to every Pretty Tensor:
155 |
156 | @pt.Register
157 | def leaky_relu(input_pt):
158 | return tf.select(tf.greater(input_pt, 0.0), input_pt, 0.01 * input_pt)
159 |
160 |
161 | `@RegisterCompoundOp()` is like adding a macro; it is designed to group together
162 | common sets of operations.
163 |
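A minimal sketch of a compound op (the name and body here are hypothetical, not part of the library):

    @pt.RegisterCompoundOp()
    def fully_connected_with_dropout(input_layer, size, keep_prob):
        # Groups two already-registered methods into a single chainable call.
        return input_layer.fully_connected(size).dropout(keep_prob)
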
164 | ### Safe variable reuse
165 |
166 | Within a graph, you can reuse variables by using templates. A template is
167 | just like a regular graph except that some variables are left unbound.
168 |
169 | See more details in [PrettyTensor class](docs/PrettyTensor.md).
170 |
171 | ### Accessing Variables
172 |
173 | Pretty Tensor uses the standard graph collections from TensorFlow to store variables. These can be accessed using `tf.get_collection(key)` with the following keys:
174 |
175 | * `tf.GraphKeys.VARIABLES`: all variables that should be saved (including some statistics).
176 | * `tf.GraphKeys.TRAINABLE_VARIABLES`: all variables that can be trained (including those before a `stop_gradients` call). These are what would typically be called *parameters* of the model in ML parlance.
177 | * `pt.GraphKeys.TEST_VARIABLES`: variables used to evaluate a model. These are typically not saved and are reset by the LocalRunner.evaluate method to get a fresh evaluation.
178 |
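For example, a minimal sketch of pulling these collections after a model has been built (the variable names here are illustrative):

    # All trainable parameters created by the layers above.
    params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    # Evaluation-only variables that LocalRunner.evaluate resets.
    test_vars = tf.get_collection(pt.GraphKeys.TEST_VARIABLES)
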
179 | ## Authors
180 |
181 | Eider Moore (eiderman)
182 |
183 | with key contributions from:
184 |
185 | * Hubert Eichner
186 | * Oliver Lange
187 | * Sagar Jain (sagarjn)
188 |
--------------------------------------------------------------------------------
/docs/Bookkeeper.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Bookkeeper
4 |
5 | Small class to gather needed pieces from a Graph being built.
6 |
7 | This class is mostly an implementation detail of Pretty Tensor and almost
8 | never needs to be used when building a model. Most of the useful methods
9 | are exposed in the `pt` namespace. The most common use case for directly
10 | calling Bookkeeper methods is to create summaries in the same way as
11 | Pretty Tensor does, controlled by the `pt.defaults_scope`.
12 |
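For example, a minimal sketch (assuming a loss tensor named `my_loss` already exists in the default graph) of adding summaries through the Bookkeeper so they respect the `pt.defaults_scope` summary settings:

    books = pt.bookkeeper_for_default_graph()
    books.add_scalar_summary(my_loss, tag='my_loss')
    books.add_average_summary(my_loss, tag='my_loss_avg')
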
13 | - - -
14 |
15 | [TOC]
16 |
17 |
18 | ## add_average_summary(var, tag=None, decay=0.999, ignore_nan=True)
19 |
20 |
21 |
22 | Add a summary with the moving average of var.
23 |
24 | Adds a variable to keep track of the exponential moving average and adds an
25 | update operation to the bookkeeper. The name of the variable is
26 | '%s_average' % name prefixed with the current variable scope.
27 |
28 | #### Args:
29 |
30 |
31 | * var: The variable for which a moving average should be computed.
32 | * tag: The tag of the summary. If None var.name[:-2] is used to strip off
33 | the ':0' that is added by TF.
34 | * decay: How much history to use in the moving average.
35 | Higher means more history; values in [0.9, 1) are accepted.
36 | * ignore_nan: If the value is NaN or Inf, skip it. Note that this default
37 | is different from the `exponential_moving_average` one.
38 |
39 | #### Returns:
40 |
41 | The averaged variable.
42 |
43 |
44 | #### Raises:
45 |
46 |
47 | * ValueError: if decay is not in [0.9, 1).
48 |
49 |
50 | - - -
51 |
52 | ## add_histogram_summary(x, tag=None)
53 |
54 |
55 |
56 | Add a summary operation to visualize the histogram of x's values.
57 |
58 |
59 |
60 |
61 |
62 | - - -
63 |
64 | ## add_loss(loss, name=None, regularization=False, add_summaries=True)
65 |
66 |
67 |
68 | Append a loss to the total loss for the network.
69 |
70 | #### Args:
71 |
72 |
73 | * loss: append this loss operation
74 | * name: The name for this loss, defaults to loss.op.name
75 | * regularization: Set to True if this is a regularization loss.
76 | * add_summaries: Set to True if you want to see scalar and average summaries.
77 |
78 |
79 |
80 |
81 |
82 | - - -
83 |
84 | ## add_losses(losses, regularization=False)
85 |
86 |
87 |
88 |
89 | - - -
90 |
91 | ## add_scalar_summary(x, tag=None)
92 |
93 |
94 |
95 | Adds a scalar summary for x.
96 |
97 |
98 |
99 |
100 |
101 | - - -
102 |
103 | ## create_composite_loss(losses, regularize=True, include_marked=True, name=cost)
104 |
105 |
106 |
107 | Creates a loss that is the sum of all specified losses.
108 |
109 | #### Args:
110 |
111 |
112 | * losses: A sequence of losses to include.
113 | * regularize: Whether or not to include regularization losses.
114 | * include_marked: Whether or not to use the marked losses.
115 | * name: The name for this variable.
116 |
117 | #### Returns:
118 |
119 | A single tensor that is the sum of all losses.
120 |
121 |
122 | #### Raises:
123 |
124 |
125 | * ValueError: if there are no losses.
126 |
127 |
128 | - - -
129 |
130 | ## exponential_moving_average(var, avg_var=None, decay=0.999, ignore_nan=False)
131 |
132 |
133 |
134 | Calculates the exponential moving average.
135 |
136 | TODO(): check if this implementation of moving average can now
137 | be replaced by TensorFlow's implementation.
138 |
139 | Adds a variable to keep track of the exponential moving average and adds an
140 | update operation to the bookkeeper. The name of the variable is
141 | '%s_average' % name prefixed with the current variable scope.
142 |
143 | #### Args:
144 |
145 |
146 | * var: The variable for which a moving average should be computed.
147 | * avg_var: The variable to set the average into, if None create a zero
148 | initialized one.
149 | * decay: How much history to use in the moving average.
150 | Higher means more history; values in [0, 1) are accepted.
151 | * ignore_nan: If the value is NaN or Inf, skip it.
152 |
153 | #### Returns:
154 |
155 | The averaged variable.
156 |
157 |
158 | #### Raises:
159 |
160 |
161 | * ValueError: if decay is not in [0, 1).
162 |
163 |
164 | - - -
165 |
166 | ## reset_summary_collections()
167 |
168 |
169 |
170 | Sets the summary collections to the default.
171 |
172 |
173 |
174 |
175 |
176 | - - -
177 |
178 | ## with_update_ops(train_op)
179 |
180 |
181 |
182 |
183 | - - -
184 | ## Properties
185 |
186 | * g
187 | * global_step
188 | * marked_losses
189 | * recurrent_state
190 | * regularization_losses
191 | * summaries
192 | * update_ops
193 |
--------------------------------------------------------------------------------
/docs/PrettyTensorTupleMixin.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # PrettyTensorTupleMixin
4 |
5 | Adds methods to any sequence type so that it can be used with binding.
6 |
7 | Generally this can be used with namedtuples to provide clean multi-value
8 | returns:
9 |
10 | class MyType(namedtuple(...), PrettyTensorTupleMixin):
11 | pass
12 |
13 | Subclasses with nested structure should note that this does not unpack
14 | nested structure by default. You must implement flatten and
15 | build_from_flattened.
16 |
17 | - - -
18 |
19 | [TOC]
20 |
21 |
22 | ## as_fn()
23 |
24 |
25 |
26 | Creates a function by binding the arguments in the given order.
27 |
28 | #### Args:
29 |
30 |
31 | * binding_order: The unbound variables. This must include all values.
32 |
33 | #### Returns:
34 |
35 | A function that takes the arguments of binding_order.
36 |
37 |
38 | #### Raises:
39 |
40 |
41 | * ValueError: If the bindings are missing values or include unknown values.
42 |
43 |
44 | - - -
45 |
46 | ## bind()
47 |
48 |
49 |
50 | Makes the bindings to each item in this and returns a new tuple.
51 |
52 |
53 |
54 |
55 |
56 | - - -
57 |
58 | ## build_from_flattened(flattened)
59 |
60 |
61 |
62 | Given a flattened structure from flatten, make a new version of this.
63 |
64 |
65 |
66 |
67 |
68 | - - -
69 |
70 | ## construct()
71 |
72 |
73 |
74 |
75 | - - -
76 |
77 | ## flatten()
78 |
79 |
80 |
81 | Subclasses with nested structure should implement this method.
82 |
83 |
84 | #### Returns:
85 |
86 | A list of data that should be bound and constructed, by default just self.
87 |
88 |
89 |
90 |
91 | - - -
92 |
93 | ## has_unbound_vars()
94 |
95 |
96 |
97 | Returns whether there are any unbound vars in this tuple.
98 |
99 |
100 |
101 |
102 |
103 | - - -
104 | ## Properties
105 |
106 | * unbound_vars
107 |
--------------------------------------------------------------------------------
/docs/pretty_tensor_top_level.md:
--------------------------------------------------------------------------------
1 |
2 | # Pretty Tensor Base Imports
3 |
4 |
5 |
6 | [TOC]
7 | - - -
8 |
9 | ## BatchNormalizationArguments
10 |
11 | BatchNormalizationArguments(learned_moments_update_rate, variance_epsilon, scale_after_normalization)
12 | - - -
13 |
14 | ### Properties
15 |
16 | * count
17 | * index
18 | * learned_moments_update_rate
19 | * scale_after_normalization
20 | * variance_epsilon
21 |
22 | - - -
23 | ## Class Bookkeeper
24 |
25 | [see details in Bookkeeper.md](Bookkeeper.md)
26 | - - -
27 |
28 | ## GraphKeys
29 |
30 | Graphs can store data in graph keys for constructing the graph.
31 | - - -
32 |
33 | ### Properties
34 |
35 | * LOSSES
36 | * MARKED_LOSSES
37 | * RECURRENT_STATE_VARIABLES
38 | * REGULARIZATION_LOSSES
39 | * TEST_VARIABLES
40 | * UPDATE_OPS
41 |
42 | - - -
43 | ## Class Loss
44 |
45 | [see details in Loss.md](Loss.md)
46 | - - -
47 |
48 | ## Phase
49 |
50 | Some nodes are different depending on the phase of the graph construction.
51 |
52 | The standard phases are train, test and infer.
53 |
54 | - - -
55 |
56 | ### Properties
57 |
58 | * infer
59 | * test
60 | * train
61 |
62 | - - -
63 | ## Class PrettyTensor
64 |
65 | [see details in PrettyTensor.md](PrettyTensor.md)
66 | - - -
67 | ## Class PrettyTensorTupleMixin
68 |
69 | [see details in PrettyTensorTupleMixin.md](PrettyTensorTupleMixin.md)
70 | - - -
71 |
72 | ## Register
73 |
74 | Decorator for registering a method in PrettyTensor.
75 |
76 | This is either used to decorate a bare function or an object that has a no-arg
77 | constructor and a __call__ method.
78 |
79 | The first argument to the function will be the PrettyTensor object. The
80 | registered method's return value must be one of the following:
81 |
82 | 1. A PrettyTensor, i.e. the result of calling `with_tensor` or
83 | `with_sequence`.
84 | 2. A Tensor.
85 | 3. A Loss result from calling `add_loss`.
86 |
87 | `RegisterCompoundOp` is provided for more direct manipulations with some
88 | caveats.
89 |
90 | - - -
91 |
92 |
93 | ### create_deferred(func, input_layer, deferred_args, deferred_kwargs, name)
94 |
95 |
96 |
97 | Creates a deferred node with captured scope.
98 |
99 | ##### Args:
100 |
101 |
102 | * func: The original function to call.
103 | * input_layer: The input_layer.
104 | * deferred_args: The arguments that will be used by the deferred function.
105 | * deferred_kwargs: The keyword args for the deferred function.
106 | * name: The name of this layer.
107 |
108 | ##### Returns:
109 |
110 | A _DeferredLayer that will execute func in the correct scopes.
111 |
112 |
113 |
114 |
115 |
116 | ### create_method(obj)
117 |
118 |
119 |
120 |
121 |
122 | ### fill_kwargs(input_layer, kwargs)
123 |
124 |
125 |
126 | Applies name_suffix and defaults to kwargs and returns the result.
127 |
128 |
129 |
130 |
131 |
132 |
133 | - - -
134 |
135 | ## RegisterCompoundOp
136 |
137 | This is used to register a compound operation.
138 |
139 | The operation is executed immediately on the base PrettyTensor type. This has
140 | the following implications:
141 |
142 | 1. `tensor` and `sequence` may not be available in the deferred case.
143 | 2. The object passed in might be sequential or a layer.
144 |
145 | Also because this is intended to provide convenience chaining of other
146 | registered methods, it does not add a name or id scope automatically, which
147 | makes it behave as if the raw methods were called (unless the op itself does
148 | scoping).
149 |
150 | - - -
151 |
152 |
153 | ### create_method(func)
154 |
155 |
156 |
157 | Creates the method.
158 |
159 |
160 |
161 |
162 |
163 |
164 | ### fill_kwargs(input_layer, kwargs)
165 |
166 |
167 |
168 | Applies name_suffix and defaults to kwargs and returns the result.
169 |
170 |
171 |
172 |
173 |
174 |
175 | - - -
176 |
177 | ## UnboundVariable
178 |
179 | An UnboundVariable is a variable with a value that is supplied using bind.
180 |
181 | UnboundVariables are typically used so that input layers can be specified at a
182 | later time or for hyperparameters. Supplying an UnboundVariable as an input
183 | variable automatically forces the graph to be a template.
184 |
185 | - - -
186 |
187 |
188 | ### has_default()
189 |
190 |
191 |
192 |
193 |
194 | - - -
195 |
196 | ## VarStoreMethod
197 |
198 | Convenience base class for registered methods that create variables.
199 |
200 | This tracks the variables and requires subclasses to provide a __call__
201 | method.
202 |
203 | - - -
204 |
205 |
206 | ### variable(var_name, shape, init, dt=, train=None)
207 |
208 |
209 |
210 | Adds a named variable to this bookkeeper or returns an existing one.
211 |
212 | Variables marked train are returned by the training_variables method. If
213 | the requested name already exists and it is compatible (same shape, dt and
214 | train) then it is returned. In case of an incompatible type, an exception is
215 | thrown.
216 |
217 | ##### Args:
218 |
219 |
220 | * var_name: The unique name of this variable. If a variable with the same
221 | name exists, then it is returned.
222 | * shape: The shape of the variable.
223 | * init: The init function to use or a Tensor to copy.
224 | * dt: The datatype, defaults to float. This will automatically extract the
225 | base dtype.
226 | * train: Whether or not the variable should be trained; defaults to
227 | True unless a default_scope has overridden it.
228 |
229 | ##### Returns:
230 |
231 | A TensorFlow tensor.
232 |
233 |
234 | ##### Raises:
235 |
236 |
237 | * ValueError: if reuse is False (or unspecified and allow_reuse is False)
238 | and the variable already exists or if the specification of a reused
239 | variable does not match the original.
240 |
241 |
242 |
243 | - - -
244 |
245 | ## apply_optimizer(optimizer, losses, regularize=True, include_marked=True, clip_gradients_by_norm=None)
246 |
247 |
248 |
249 | Apply an optimizer to the graph and returns a train_op.
250 |
251 | The resulting operation will minimize the specified losses, plus the
252 | regularization losses that have been collected during graph construction and
253 | the losses that were marked by calling `mark_as_required`.
254 |
255 | It will also apply any updates that have been collected (e.g. for moving
256 | average summaries).
257 |
258 | This is equivalent to:
259 |
260 | total_loss = prettytensor.create_composite_loss(
261 | losses=losses, regularize=regularize, include_marked=include_marked)
262 | train_op_without_updates = optimizer.minimize(total_loss)
263 | train_op = prettytensor.with_update_ops(train_op_without_updates)
264 |
265 | N.B. Pay special attention to the `gate_gradients` argument to the optimizer.
266 | If your graph is large, it will likely train unacceptably slowly if you don't
267 | specify it as GATE_NONE.
268 |
269 | #### Args:
270 |
271 |
272 | * optimizer: The optimizer to apply.
273 | * losses: A list of losses to apply.
274 | * regularize: Whether or not to include the regularization losses.
275 | * include_marked: Whether or not to use the marked losses.
276 | * clip_gradients_by_norm: If not None, clip gradients by the norm using
277 | `tf.clip_by_norm`.
278 | **kwargs: Additional arguments to pass into the optimizer.
279 |
280 | #### Returns:
281 |
282 | An operation to use for training that also updates any required ops such as
283 | moving averages.
284 |
285 |
286 |
287 |
288 | - - -
289 |
290 | ## for_default_graph()
291 |
292 |
293 |
294 | Creates a bookkeeper for the default graph.
295 |
296 | #### Args:
297 |
298 |
299 | * args: Arguments to pass into Bookkeeper's constructor.
300 | **kwargs: Arguments to pass into Bookkeeper's constructor.
301 |
302 | #### Returns:
303 |
304 | A new Bookkeeper.
305 |
306 |
307 | #### Raises:
308 |
309 |
310 | * ValueError: If args or kwargs are provided and the Bookkeeper already
311 | exists.
312 |
313 |
314 | - - -
315 |
316 | ## for_new_graph()
317 |
318 |
319 |
320 | Creates a Bookkeeper for a new graph.
321 |
322 | You must use `m.g.as_default()` to put the graph in scope:
323 |
324 | m = Bookkeeper.for_new_graph()
325 | with m.g.as_default():
326 | ...
327 |
328 | #### Args:
329 |
330 |
331 | * args: Arguments to pass into Bookkeeper's constructor.
332 | **kwargs: Arguments to pass into Bookkeeper's constructor.
333 |
334 | #### Returns:
335 |
336 | A new Bookkeeper.
337 |
338 |
339 |
340 |
341 | - - -
342 |
343 | ## construct_all(templates, **unbound_var_values)
344 |
345 |
346 |
347 | Constructs all the given templates in a single pass without redundancy.
348 |
349 | This is useful when the templates have a common substructure and you want the
350 | smallest possible graph.
351 |
352 | #### Args:
353 |
354 |
355 | * templates: A sequence of templates.
356 | **unbound_var_values: The unbound_var values to replace.
357 |
358 | #### Returns:
359 |
360 | A list of results corresponding to templates.
361 |
362 |
363 | #### Raises:
364 |
365 |
366 | * TypeError: If any value in templates is unsupported.
367 | * ValueError: If the unbound_var values specified are not complete or contain
368 | unknown values.
369 |
370 |
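A minimal sketch (the template and tensor names are illustrative) of building two templates that share a substructure and constructing them in one pass:

    shared = pt.template('x').fully_connected(100, name='shared_layer')
    head_a = shared.fully_connected(10, name='head_a')
    head_b = shared.fully_connected(2, name='head_b')
    # Both outputs reuse the single shared_layer construction.
    out_a, out_b = pt.construct_all([head_a, head_b], x=input_tensor)
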
371 | - - -
372 |
373 | ## create_composite_loss(losses, regularize=True, include_marked=True, name=cost)
374 |
375 |
376 |
377 | Creates a loss that is the sum of all specified losses.
378 |
379 | #### Args:
380 |
381 |
382 | * losses: A sequence of losses to include.
383 | * regularize: Whether or not to include regularization losses.
384 | * include_marked: Whether or not to use the marked losses.
385 | * name: The name for this variable.
386 |
387 | #### Returns:
388 |
389 | A single tensor that is the sum of all losses.
390 |
391 |
392 |
393 |
394 | - - -
395 |
396 |
397 | ## defaults_scope(...
398 |
399 | defaults_scope(activation_fn=None, batch_normalize=None, l2loss=None, learned_moments_update_rate=None, parameter_modifier=None, phase=None, scale_after_normalization=None, summary_collections=None, trainable_variables=None, unroll=None, variable_collections=None, variance_epsilon=None)
400 |
401 | Creates a scope for the defaults that are used in a `with` block.
402 |
403 | Note: `defaults_scope` supports nesting where later defaults can be
404 | overridden. Also, an explicitly given keyword argument on a method always
405 | takes precedence.
406 |
407 | In addition to setting defaults for some methods, this also can control:
408 |
409 | * `summary_collections`: Choose which collection to place summaries in or
410 | disable with `None`.
411 | * `trainable_variables`: Boolean indicating if variables are trainable.
412 | * `variable_collections`: Default collections in which to place variables;
413 | `tf.GraphKeys.GLOBAL_VARIABLES` is always included.
414 |
415 | The supported defaults and methods that use them are:
416 |
417 |
418 | * `activation_fn`:
419 | * [conv2d](PrettyTensor.md#conv2d)
420 | * [depthwise_conv2d](PrettyTensor.md#depthwise_conv2d)
421 | * [fully_connected](PrettyTensor.md#fully_connected)
422 |
423 | * `batch_normalize`:
424 | * [conv2d](PrettyTensor.md#conv2d)
425 | * [depthwise_conv2d](PrettyTensor.md#depthwise_conv2d)
426 |
427 | * `l2loss`:
428 | * [conv2d](PrettyTensor.md#conv2d)
429 | * [depthwise_conv2d](PrettyTensor.md#depthwise_conv2d)
430 | * [diagonal_matrix_mul](PrettyTensor.md#diagonal_matrix_mul)
431 | * [fully_connected](PrettyTensor.md#fully_connected)
432 |
433 | * `learned_moments_update_rate`:
434 | * [batch_normalize](PrettyTensor.md#batch_normalize)
435 |
436 | * `parameter_modifier`:
437 | * [conv2d](PrettyTensor.md#conv2d)
438 | * [depthwise_conv2d](PrettyTensor.md#depthwise_conv2d)
439 | * [softmax_classifier_with_sampled_loss](PrettyTensor.md#softmax_classifier_with_sampled_loss)
440 | * [softmax_classifier](PrettyTensor.md#softmax_classifier)
441 | * [diagonal_matrix_mul](PrettyTensor.md#diagonal_matrix_mul)
442 | * [fully_connected](PrettyTensor.md#fully_connected)
443 | * [lstm_cell](PrettyTensor.md#lstm_cell)
444 | * [sequence_lstm](PrettyTensor.md#sequence_lstm)
445 | * [gru_cell](PrettyTensor.md#gru_cell)
446 | * [sequence_gru](PrettyTensor.md#sequence_gru)
447 | * [embedding_lookup](PrettyTensor.md#embedding_lookup)
448 |
449 | * `phase`:
450 | * [batch_normalize](PrettyTensor.md#batch_normalize)
451 | * [conv2d](PrettyTensor.md#conv2d)
452 | * [depthwise_conv2d](PrettyTensor.md#depthwise_conv2d)
453 | * [evaluate_precision_recall](PrettyTensor.md#evaluate_precision_recall)
454 | * [evaluate_classifier_fraction](PrettyTensor.md#evaluate_classifier_fraction)
455 | * [evaluate_classifier](PrettyTensor.md#evaluate_classifier)
456 | * [evaluate_classifier_fraction_sparse](PrettyTensor.md#evaluate_classifier_fraction_sparse)
457 | * [evaluate_classifier_sparse](PrettyTensor.md#evaluate_classifier_sparse)
458 | * [dropout](PrettyTensor.md#dropout)
459 | * [diagonal_matrix_mul](PrettyTensor.md#diagonal_matrix_mul)
460 | * [fully_connected](PrettyTensor.md#fully_connected)
461 | * [lstm_cell](PrettyTensor.md#lstm_cell)
462 | * [sequence_lstm](PrettyTensor.md#sequence_lstm)
463 | * [gru_cell](PrettyTensor.md#gru_cell)
464 | * [sequence_gru](PrettyTensor.md#sequence_gru)
465 | * [embedding_lookup](PrettyTensor.md#embedding_lookup)
466 |
467 | * `scale_after_normalization`:
468 | * [batch_normalize](PrettyTensor.md#batch_normalize)
469 |
470 | * `unroll`:
471 | * [cleave_sequence](PrettyTensor.md#cleave_sequence)
472 |
473 | * `variance_epsilon`:
474 | * [batch_normalize](PrettyTensor.md#batch_normalize)
475 |
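A minimal sketch (layer sizes and values are illustrative) combining a few of the defaults and controls listed above:

    with pt.defaults_scope(activation_fn=tf.nn.relu,
                           l2loss=0.00001,
                           summary_collections=None):
        logits = (pt.wrap(input_tensor)
                  .fully_connected(100)
                  .fully_connected(10, activation_fn=None))
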
476 | - - -
477 |
478 | ## global_step()
479 |
480 |
481 |
482 | Returns the global step variable.
483 |
484 |
485 |
486 |
487 |
488 | - - -
489 |
490 | ## join_pretty_tensors(tensors, output, join_function=None, name=join)
491 |
492 |
493 |
494 | Joins the list of pretty_tensors and sets head of output_pretty_tensor.
495 |
496 | #### Args:
497 |
498 |
499 | * tensors: A sequence of Layers or SequentialLayerBuilders to join.
500 | * output: A pretty_tensor to set the head with the result.
501 | * join_function: A function to join the tensors, defaults to concat on the
502 | last dimension.
503 | * name: A name that is used for the name_scope
504 |
505 | #### Returns:
506 |
507 | The result of calling with_tensor on output
508 |
509 |
510 | #### Raises:
511 |
512 |
513 | * ValueError: if pretty_tensors is None or empty.
514 |
515 |
516 | - - -
517 |
518 | ## make_template(func)
519 |
520 |
521 |
522 | Given an arbitrary function, wrap it so that it does parameter sharing.
523 |
524 |
525 |
526 |
527 |
528 | - - -
529 |
530 | ## recurrent_state()
531 |
532 |
533 |
534 |
535 | - - -
536 |
537 | ## set_recurrent_state_saver()
538 |
539 |
540 |
541 | Sets the state saver used for recurrent sequences.
542 |
543 |
544 |
545 |
546 |
547 | - - -
548 |
549 | ## template(key, books=None, optional=False)
550 |
551 |
552 |
553 | Starts a Pretty Tensor graph template.
554 |
555 | ## Template Mode
556 |
557 | Templates allow you to define a graph with some unknown
558 | values. The most common use case is to leave the input undefined and then
559 | define a graph normally. The variables are only defined once the first time
560 | the graph is constructed. For example:
561 |
562 | template = (pretty_tensor.template('input')
563 | .fully_connected(200, name='l1')
564 | .fully_connected(200, name='l2'))
565 | train_output = template.construct(input=train_data)
566 |
567 | # All parameters are reused when the same template object is called again.
568 | test_output = template.construct(input=test_data)
569 |
570 | Any argument to a pretty tensor method can be substituted by using an
571 | `UnboundVariable`.
572 | This allows you to parameterize a graph in arbitrary ways. The most cannonical
573 | usage would be to substitute a phase variable.
574 |
575 | with pretty_tensor.defaults_scope(phase=UnboundVariable('train')):
576 | # dropout uses train to optionally disable itself.
577 |
578 | template = (pretty_tensor.template('input')
579 | .fully_connected(200, name='l1')
580 | .fully_connected(200, name='l2')
581 | .dropout(.8))
582 | train_output = template.construct(input=train_data, train=True)
583 | test_output = template.construct(input=test_data, train=False)
584 |
585 |
586 | You should use caution because if a template is called with incompatible
587 | values (e.g. train and test using different widths), then it will break. This
588 | is because we guarantee variable reuse across instantiations.
589 |
590 | template = (pretty_tensor.template('input')
591 | .fully_connected(200, name='l1')
592 | .fully_connected(
593 | pretty_tensor.UnboundVariable('width'), name='l2'))
594 | train_output = template.construct(input=train_data, width=200)
595 |
596 | # The following line will die because the shared parameter is the wrong
597 | # size.
598 | test_output = template.construct(input=test_data, width=100)
599 |
600 |
601 | A Layer in the resulting graph can be realized by calling
602 | `bind(key=value)` and then `construct`.
603 |
604 | #### Args:
605 |
606 |
607 | * key: A key for this template, used for assigning the correct substitution.
608 | * books: The bookkeeper.
609 | * optional: If this template is an optional value.
610 |
611 | #### Returns:
612 |
613 | A template that can be constructed or attached to other layers and that
614 | guarantees parameter reuse when constructed/attached multiple times.
615 |
616 |
617 |
618 |
619 | - - -
620 |
621 | ## with_update_ops(train_op)
622 |
623 |
624 |
625 | Creates a new op that runs all of the required updates when train_op runs.
626 |
627 | #### Args:
628 |
629 |
630 | * train_op: An operation that will run every step, usually the result of an
631 | optimizer.
632 |
633 | #### Returns:
634 |
635 | A new op that returns the same value as train_op, but also runs the
636 | updaters.
637 |
638 |
639 |
640 |
641 | - - -
642 |
643 | ## wrap(tensor, books=None, tensor_shape=None)
644 |
645 |
646 |
647 | Creates an input layer representing the given tensor.
648 |
649 | #### Args:
650 |
651 |
652 | * tensor: The tensor.
653 | * books: The bookkeeper; this is usually not required unless you are building
654 | multiple `tf.Graph`s.
655 | * tensor_shape: An optional shape that will be set on the Tensor or verified
656 | to match the tensor.
657 |
658 | #### Returns:
659 |
660 | A layer.
661 |
662 |
663 |
664 |
665 | - - -
666 |
667 | ## wrap_sequence(sequence, books=None, tensor_shape=None)
668 |
669 |
670 |
671 | Creates an input layer representing the given sequence of tensors.
672 |
673 | #### Args:
674 |
675 |
676 | * sequence: A sequence of tensors.
677 | * books: The bookkeeper.
678 | * tensor_shape: An optional shape that will be set on the Tensor or verified
679 | to match the tensor.
680 |
681 | #### Returns:
682 |
683 | A layer.
684 |
685 |
686 |
687 |
688 | - - -
689 |
690 |
691 | ## Extensions
692 |
693 |
694 |
695 | - - -
696 |
--------------------------------------------------------------------------------
/inception_module.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/prettytensor/75daa0b11252590f548da5647addc0ea610c4c45/inception_module.png
--------------------------------------------------------------------------------
/prettytensor/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """PrettyTensor nice syntax layer on top of TensorFlow.
12 |
13 | This will eventually be the preferred place to import prettytensor.
14 |
15 | For now, please use pretty_tensor.py since this is in a state of flux.
16 |
17 | see [README.md](https://github.com/google/prettytensor) for documentation.
18 | see pretty_tensor_samples/ for usage examples.
19 | """
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | # pylint: disable=unused-import
25 | from prettytensor import funcs
26 | from prettytensor import parameters
27 | from prettytensor import train
28 | from prettytensor.bookkeeper import apply_optimizer
29 | from prettytensor.bookkeeper import Bookkeeper
30 | from prettytensor.bookkeeper import create_composite_loss
31 | from prettytensor.bookkeeper import for_default_graph as bookkeeper_for_default_graph
32 | from prettytensor.bookkeeper import for_new_graph as bookkeeper_for_new_graph
33 | from prettytensor.bookkeeper import global_step
34 | from prettytensor.bookkeeper import GraphKeys
35 | from prettytensor.bookkeeper import recurrent_state
36 | from prettytensor.bookkeeper import set_recurrent_state_saver
37 | from prettytensor.bookkeeper import with_update_ops
38 |
39 | from prettytensor.pretty_tensor_class import construct_all
40 | from prettytensor.pretty_tensor_class import defaults_scope
41 | from prettytensor.pretty_tensor_class import DIM_REST
42 | from prettytensor.pretty_tensor_class import DIM_SAME
43 | from prettytensor.pretty_tensor_class import join_pretty_tensors
44 | from prettytensor.pretty_tensor_class import Loss
45 | from prettytensor.pretty_tensor_class import PAD_SAME
46 | from prettytensor.pretty_tensor_class import PAD_VALID
47 | from prettytensor.pretty_tensor_class import Phase
48 | from prettytensor.pretty_tensor_class import PrettyTensor
49 | from prettytensor.pretty_tensor_class import PrettyTensorTupleMixin
50 | from prettytensor.pretty_tensor_class import PROVIDED
51 | from prettytensor.pretty_tensor_class import Register
52 | from prettytensor.pretty_tensor_class import RegisterCompoundOp
53 | from prettytensor.pretty_tensor_class import template
54 | from prettytensor.pretty_tensor_class import UnboundVariable
55 | from prettytensor.pretty_tensor_class import VarStoreMethod
56 | from prettytensor.pretty_tensor_class import wrap
57 | from prettytensor.pretty_tensor_class import wrap_sequence
58 |
59 | from prettytensor.pretty_tensor_normalization_methods import BatchNormalizationArguments
60 | from prettytensor.scopes import make_template
61 |
62 | __version__ = '0.7.1'
63 |
--------------------------------------------------------------------------------
/prettytensor/bookkeeper_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Test class for bookkeepers."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import unittest
17 |
18 |
19 | import tensorflow as tf
20 |
21 | from prettytensor import bookkeeper
22 | from prettytensor import pretty_tensor_testing
23 |
24 |
25 | class BookkeeperTest(pretty_tensor_testing.PtTestCase):
26 |
27 | def setUp(self):
28 | super(self.__class__, self).setUp()
29 |
30 | def assertSameContents(self, list1, list2, msg):
31 | self.assertEqual(len(list1), len(list2), msg)
32 | self.assertEqual(set(list1), set(list2), msg)
33 |
34 | def testGraphIsReused(self):
35 | b1 = bookkeeper.for_default_graph()
36 | b2 = bookkeeper.for_default_graph()
37 | self.assertTrue(b1 is b2)
38 |
39 | def testPassingArgsCausesError(self):
40 | b1 = bookkeeper.for_new_graph()
41 | with b1.g.as_default(), self.assertRaises(ValueError):
42 | bookkeeper.for_default_graph(global_step=None)
43 |
44 | def testGlobalStep(self):
45 | v = tf.Variable(0)
46 | b1 = bookkeeper.for_new_graph(global_step=v)
47 | with b1.g.as_default():
48 | self.assertEqual(v, bookkeeper.global_step())
49 | with self.assertRaises(ValueError):
50 | bookkeeper.for_new_graph(global_step=tf.Variable(1.0))
51 |
52 | def testUniqueBookkeeperPerGraph(self):
53 | b1 = bookkeeper.for_default_graph()
54 | with tf.Graph().as_default():
55 | b2 = bookkeeper.for_default_graph()
56 | self.assertFalse(b1 is b2)
57 |
58 | def testBareVarName(self):
59 | name = 'hello'
60 | var = tf.Variable([1], name=name)
61 | self.assertEquals(name, bookkeeper._bare_var_name(var))
62 | self.assertEquals(name,
63 | bookkeeper._bare_var_name(var._as_graph_element()))
64 |
65 |
66 | if __name__ == '__main__':
67 | unittest.main()
68 |
--------------------------------------------------------------------------------
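A minimal usage sketch of the bookkeeper API exercised by the tests above; only calls that actually appear in the tests are assumed to exist:

```python
import tensorflow as tf
from prettytensor import bookkeeper

# Bookkeepers are keyed by graph: repeated calls for the same default graph
# return the same instance (see testGraphIsReused above).
books = bookkeeper.for_default_graph()
assert books is bookkeeper.for_default_graph()

# A new graph gets its own bookkeeper, and an integer Variable may be
# registered as that graph's global step (see testGlobalStep above).
step = tf.Variable(0, name='global_step', trainable=False)
other = bookkeeper.for_new_graph(global_step=step)
with other.g.as_default():
  print(bookkeeper.global_step())  # the `step` variable registered above
```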
/prettytensor/chain_dict.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Creates a dict with a parent so that missing values are sent up."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import collections
17 |
18 |
19 | class ChainDict(collections.MutableMapping):
20 |   """A mapping that looks up missing keys in a parent mapping."""
21 |
22 | def __init__(self, parent):
23 | self._map = {}
24 | self._parent = parent
25 | self._dead_count = 0
26 |
27 | def __getitem__(self, key):
28 | if key in self._map:
29 | return self._map[key]
30 | elif self._parent:
31 | return self._parent[key]
32 | else:
33 | raise KeyError('Key not found: %s' % key)
34 |
35 | def __setitem__(self, key, value):
36 | self._map[key] = value
37 |
38 | def __delitem__(self, key):
39 | raise Exception('Deleting items not supported.')
40 |
41 | def _full_map(self):
42 | """Creates a full mapping of this and all parent key, value pairs."""
43 | result = {}
44 | if self._parent:
45 | result.update(self._parent)
46 | result.update(self._map)
47 | return result
48 |
49 | def __iter__(self):
50 | return self._full_map().__iter__()
51 |
52 | def __len__(self):
53 | return len(self._full_map())
54 |
--------------------------------------------------------------------------------
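Because missing keys fall through to the parent, `ChainDict` behaves like a layered configuration map. A small illustrative sketch (the keys below are made up for the example; a plain dict is a valid parent, as the tests that follow confirm):

```python
from prettytensor import chain_dict

defaults = {'stddev': 0.1, 'activation': 'relu'}  # plain dict used as the parent
overrides = chain_dict.ChainDict(defaults)
overrides['stddev'] = 0.05  # shadows the parent's value locally

print(overrides['activation'])  # 'relu' -- falls through to the parent
print(overrides['stddev'])      # 0.05   -- the local value wins
print(len(overrides))           # 2      -- union of parent and local keys
```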
/prettytensor/chain_dict_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Test class for ChainDict."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import unittest
17 |
18 | from prettytensor import chain_dict
19 |
20 |
21 | class ChainDictTest(unittest.TestCase):
22 |
23 | def testSet(self):
24 | d = chain_dict.ChainDict(None)
25 | d['KEY'] = 'VALUE'
26 | self.assertEqual({'KEY': 'VALUE'}, d._map)
27 |
28 | def testGetNoParent(self):
29 | d = chain_dict.ChainDict(None)
30 | d['KEY'] = 'VALUE'
31 | self.assertEqual('VALUE', d['KEY'])
32 |
33 | def testGetAbsentNoParent(self):
34 | d = chain_dict.ChainDict(None)
35 | with self.assertRaises(KeyError):
36 | # pylint: disable=pointless-statement
37 | d['KEY']
38 |
39 | def testGetInParent(self):
40 | parent = chain_dict.ChainDict(None)
41 | d = chain_dict.ChainDict(parent)
42 | parent['KEY'] = 'VALUE'
43 | self.assertEqual('VALUE', parent['KEY'])
44 | self.assertEqual('VALUE', d['KEY'])
45 |
46 | def testGetNotInParent(self):
47 | parent = chain_dict.ChainDict(None)
48 | d = chain_dict.ChainDict(parent)
49 | with self.assertRaises(KeyError):
50 | # pylint: disable=pointless-statement
51 | parent['KEY']
52 | with self.assertRaises(KeyError):
53 | # pylint: disable=pointless-statement
54 | d['KEY']
55 |
56 | def testGetOverridden(self):
57 | parent = chain_dict.ChainDict(None)
58 | d = chain_dict.ChainDict(parent)
59 | parent['KEY'] = 'VALUE'
60 | d['KEY'] = 'OTHER_VALUE'
61 | self.assertEqual('VALUE', parent['KEY'])
62 | self.assertEqual('OTHER_VALUE', d['KEY'])
63 |
64 | def testLen(self):
65 | parent = chain_dict.ChainDict(None)
66 | d = chain_dict.ChainDict(parent)
67 | self.assertEqual(0, len(d))
68 | self.assertEqual(0, len(parent))
69 |
70 | parent['KEY'] = 'VALUE'
71 | self.assertEqual(1, len(d))
72 | self.assertEqual(1, len(parent))
73 |
74 | d['KEY'] = 'OTHER_VALUE'
75 | self.assertEqual(1, len(d))
76 | self.assertEqual(1, len(parent))
77 |
78 | d['OTHER_KEY'] = 'YAV'
79 | self.assertEqual(2, len(d))
80 | self.assertEqual(1, len(parent))
81 |
82 | def testIteration(self):
83 | parent = chain_dict.ChainDict(None)
84 | d = chain_dict.ChainDict(parent)
85 | parent['KEY'] = 'VALUE'
86 | d['KEY'] = 'OTHER_VALUE'
87 | d['OTHER_KEY'] = 'YAV'
88 |
89 | self.assertEqual([('KEY', 'OTHER_VALUE'), ('OTHER_KEY', 'YAV')],
90 | sorted(d.items()))
91 | self.assertEqual([('KEY', 'VALUE')], sorted(parent.items()))
92 |
93 | def testPlainDictParent(self):
94 | d = chain_dict.ChainDict({'KEY': 'VALUE'})
95 | self.assertEqual('VALUE', d['KEY'])
96 | self.assertEqual(len(d), 1)
97 | # In Python 3, items produces an iterator.
98 | self.assertEqual(list(d.items()), [('KEY', 'VALUE')])
99 |
100 | if __name__ == '__main__':
101 | unittest.main()
102 |
--------------------------------------------------------------------------------
/prettytensor/funcs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Experimental functional API for PrettyTensor.
12 |
13 | This exposes all of the standard PrettyTensor functions, but instead of
14 | chaining, they are invoked like regular functions.
15 | """
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import sys
21 |
22 | import six
23 |
24 | # pylint: disable=unused-import, wildcard-import
25 | from prettytensor.pretty_tensor_image_methods import *
26 | from prettytensor.pretty_tensor_loss_methods import *
27 | from prettytensor.pretty_tensor_methods import *
28 | from prettytensor.pretty_tensor_sparse_methods import *
29 | from prettytensor.recurrent_networks import *
30 |
31 |
32 | def _remove_non_methods():
33 |   """Removes anything from this module that is not a registered method."""
34 | cur_module = sys.modules[__name__]
35 | my_globals = dict(globals())
36 | # Import here so that it doesn't get added to the global namespace or deleted.
37 | # pylint: disable=g-import-not-at-top
38 | from prettytensor.pretty_tensor_class import PrettyTensor
39 | for name, _ in six.iteritems(my_globals):
40 | if not hasattr(PrettyTensor, name):
41 | delattr(cur_module, name)
42 | # Remove a couple of special ones....
43 | if hasattr(cur_module, 'bookkeeper'):
44 | delattr(cur_module, 'bookkeeper')
45 |
46 | _remove_non_methods()
47 |
--------------------------------------------------------------------------------
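The net effect is that `funcs` exposes the registered PrettyTensor operations as plain functions rather than chained methods. A hedged sketch of the two styles side by side (the specific methods `flatten` and `fully_connected` are assumed from the rest of the library, and the functional form takes the input as its first argument):

```python
import tensorflow as tf
import prettytensor as pt
from prettytensor import funcs

images = tf.placeholder(tf.float32, [None, 28, 28, 1])

# Chaining style: each call returns a new PrettyTensor.
chained = (pt.wrap(images)
           .flatten()
           .fully_connected(100, activation_fn=tf.nn.relu))

# Functional style: the same registered operations invoked as free functions.
flat = funcs.flatten(pt.wrap(images))
hidden = funcs.fully_connected(flat, 100, activation_fn=tf.nn.relu)
```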
/prettytensor/functions.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Reusable TensorFlow functions.
12 |
13 | This file is divided into roughly 4 parts:
14 |
15 | * _regression functions with signature: `fn(y, target, name=)` where y is
16 | the output of your DNN and target is the regression target as a tensor and
17 |   the result is a tensor with shape [N] where N is the batch dimension of y.
18 | * _distance functions with signature: `fn(t1, t2, name=)` where t1 and t2 are
19 | both tensors and the result is a tensor with shape [N] where N is the first
20 | dimension of t1 and t2.
21 | * Activation functions with signature `fn(x, name=)` where x is a tensor
22 | and the result is a tensor of the same shape.
23 | * Utility functions. These include normalizations that are used in embedding
24 | models as a non-linearity and a few others.
25 | """
26 | from __future__ import absolute_import
27 | from __future__ import division
28 | from __future__ import print_function
29 |
30 | from six.moves import xrange # pylint: disable=redefined-builtin
31 | import tensorflow as tf
32 |
33 | # To improve numerical stability, we want to not do an exponential above this
34 | # value.
35 | _SOFTPLUS_STABILITY_LIMIT = 20.0
36 |
37 |
38 | def l1_regression_loss(y, target, name=None):
39 | """Calculates the sum of absolute errors between y and target.
40 |
41 | Args:
42 | y: the calculated values.
43 | target: the desired values.
44 | name: the name for this op, defaults to l1_regression
45 | Returns:
46 | A tensorflow op.
47 | """
48 | with tf.name_scope(name, 'l1_regression', [y, target]) as scope:
49 | y = tf.convert_to_tensor(y, name='y')
50 | target = tf.convert_to_tensor(target, name='target')
51 | return reduce_batch_sum(tf.abs(y - target), name=scope)
52 |
53 |
54 | def l2_regression_sq_loss(y, target, name=None):
55 | """Calculates the sum of squared errors between y and target.
56 |
57 | Args:
58 | y: the calculated values.
59 | target: the desired values.
60 | name: the name for this op, defaults to l2_regression
61 | Returns:
62 | A tensorflow op.
63 | """
64 | with tf.name_scope(name, 'l2_regression_sq', [y, target]) as scope:
65 | y = tf.convert_to_tensor(y, name='y')
66 | target = tf.convert_to_tensor(target, name='target')
67 | return reduce_batch_sum(tf.square(y - target), name=scope)
68 |
69 |
70 | def reduce_batch_sum(x, name=None):
71 | with tf.name_scope(name, 'reduce_batch_sum', [x]) as scope:
72 | ndims = x.get_shape().ndims
73 | if ndims == 0:
74 | raise ValueError('Cannot reduce a scalar into batches.')
75 | elif ndims == 1:
76 | return x # Don't include a useless sum.
77 | elif ndims:
78 | reduction_indices = list(range(1, x.get_shape().ndims))
79 | shape = [x.get_shape().dims[0]]
80 | else:
81 | reduction_indices = tf.range(1, tf.size(tf.shape(x)))
82 | shape = [None] # We don't know much about the shape, but it is rank 1.
83 | result = tf.reduce_sum(x, reduction_indices=reduction_indices, name=scope)
84 |
85 | # Give a shape hint in case we have extra information.
86 | result.set_shape(shape)
87 | return result
88 |
89 |
90 | def l2_regression_loss(y, target, name=None):
91 | """Calculates the square root of the SSE between y and target.
92 |
93 | Args:
94 | y: the calculated values.
95 | target: the desired values.
96 | name: the name for this op, defaults to l2_regression
97 | Returns:
98 | A tensorflow op.
99 | """
100 | with tf.name_scope(name, 'l2_regression', [y, target]) as scope:
101 | y = tf.convert_to_tensor(y, name='y')
102 | target = tf.convert_to_tensor(target, name='target')
103 | return tf.sqrt(l2_regression_sq_loss(y, target, name=scope))
104 |
105 |
106 | def binary_cross_entropy_loss_with_logits(x, target, name=None):
107 | """Calculates the binary cross entropy between sigmoid(x) and target.
108 |
109 | Expects unscaled logits. Do not pass in results of sigmoid operation.
110 |
111 | Args:
112 | x: the calculated pre-sigmoid values
113 | target: the desired values.
114 | name: the name for this op, defaults to binary_cross_entropy_with_logits
115 | Returns:
116 | -(target * -softplus(-x) + (1-target) * (-x - softplus(-x)))
117 | Raises:
118 | ValueError: If shapes are incompatible.
119 | """
120 | with tf.name_scope(name, 'binary_cross_entropy_with_logits',
121 | [x, target]) as scope:
122 | x.get_shape().assert_is_compatible_with(target.get_shape())
123 | neg_softplus = -tf.nn.softplus(-x)
124 | return -tf.add(tf.multiply(target, neg_softplus),
125 | tf.multiply(1 - target, -x + neg_softplus),
126 | name=scope)
127 |
128 |
129 | def cos_distance(t1, t2, epsilon=1e-12, name=None):
130 |   """Cosine distance between t1 and t2 with a capped square-root gradient.
131 |
132 | Args:
133 | t1: A tensor
134 | t2: A tensor that can be multiplied by t1.
135 |     epsilon: A lower bound on the product of the squared norms; it keeps the
136 |         inverse-square-root normalizer (and its gradient) finite.
137 | name: Optional name for this op.
138 | Returns:
139 | The cos distance between t1 and t2.
140 | """
141 | with tf.name_scope(name, 'cos_distance', [t1, t2]) as scope:
142 | t1 = tf.convert_to_tensor(t1, name='t1')
143 | t2 = tf.convert_to_tensor(t2, name='t2')
144 | x_inv_norm = tf.rsqrt(tf.maximum(length_squared(t1) * length_squared(t2),
145 | epsilon))
146 | return tf.subtract(1.0, dot_product(t1, t2) * x_inv_norm, name=scope)
147 |
148 |
149 | def dot_distance(t1, t2, name=None):
150 | """dot "distance" between t1 and t2.
151 |
152 | Args:
153 | t1: A tensor.
154 | t2: A tensor that is the same size as t1.
155 | name: Optional name for this op.
156 | Returns:
157 | The dot distance between t1 and t2.
158 | """
159 | with tf.name_scope(name, 'dot_distance', [t1, t2]) as scope:
160 | return -dot_product(t1, t2, name=scope)
161 |
162 |
163 | def l2_distance_sq(t1, t2, name=None):
164 | """Square of l2 distance between t1 and t2.
165 |
166 | Args:
167 | t1: A tensor.
168 | t2: A tensor that is the same size as t1.
169 | name: Optional name for this op.
170 | Returns:
171 | The l2 distance between t1 and t2.
172 | """
173 | with tf.name_scope(name, 'l2_distance_sq', [t1, t2]) as scope:
174 | t1 = tf.convert_to_tensor(t1, name='t1')
175 | t2 = tf.convert_to_tensor(t2, name='t2')
176 | return length_squared(tf.subtract(t1, t2), name=scope)
177 |
178 |
179 | def l2_distance(t1, t2, epsilon=1e-12, name=None):
180 |   """l2 distance between t1 and t2 with a capped square-root gradient.
181 |
182 | Args:
183 | t1: A tensor.
184 | t2: A tensor that is the same size as t1.
185 | epsilon: A lower bound for distance, useful to avoid sqrt of very small
186 | values that can blow up gradients.
187 | name: Optional name for this op.
188 | Returns:
189 | The l2 distance between t1 and t2.
190 | """
191 | with tf.name_scope(name, 'l2_distance', [t1, t2]) as scope:
192 | t1 = tf.convert_to_tensor(t1, name='t1')
193 | t2 = tf.convert_to_tensor(t2, name='t2')
194 | return tf.sqrt(tf.maximum(l2_distance_sq(t1, t2, scope), epsilon))
195 |
196 |
197 | def l1_distance(t1, t2, name=None):
198 | """l1 distance between t1 and t2.
199 |
200 | Args:
201 | t1: A tensor.
202 | t2: A tensor that is the same size as t1.
203 | name: Optional name for this op.
204 | Returns:
205 | The l1 distance between t1 and t2.
206 | """
207 | with tf.name_scope(name, 'l1_distance', [t1, t2]) as scope:
208 | t1 = tf.convert_to_tensor(t1, name='t1')
209 | t2 = tf.convert_to_tensor(t2, name='t2')
210 | sub = tf.subtract(t1, t2)
211 | reduction_dim = _last_index(sub, 1)
212 | return tf.reduce_sum(tf.abs(sub), reduction_dim, name=scope)
213 |
214 |
215 | def leaky_relu(x, name=None):
216 | """Creates a leaky_relu.
217 |
218 | This is an alternate non-linearity to relu. The leaky part of the relu may
219 |   prevent dead neurons in a model since the gradient doesn't go completely to
220 | 0.
221 |
222 | Args:
223 | x: The input tensor.
224 | name: Optional name for this op.
225 | Returns:
226 | x if x > 0 otherwise 0.01 * x.
227 | """
228 | with tf.name_scope(name, 'leaky_relu', [x]) as scope:
229 | x = tf.convert_to_tensor(x, name='x')
230 | return tf.where(tf.less(x, 0.0), 0.01 * x, x, name=scope)
231 |
232 |
233 |   """Computes softplus with a scale factor to sharpen the hinge.
234 | """Computes softplus with a scale factor to sharpen of the hinge.
235 |
236 | This is an alternate non-linearity to relu. It has a similar shape, but
237 | it has a smooth transition from the linear part to 0.
238 |
239 | Args:
240 | x: A tensor.
241 | scale: A float that sharpens the curve.
242 | name: Optional name.
243 | Returns:
244 | y = log(1 + exp(scale * x)) / scale
245 |
246 | """
247 |   if scale == 1:
248 |     return tf.nn.softplus(x, name=name)
249 |   else:
250 |     with tf.name_scope(name, 'softplus', [x]) as scope:
251 |       scale = tf.convert_to_tensor(scale, dtype=x.dtype.base_dtype)
252 |       return tf.div(tf.nn.softplus(x * scale), scale, name=scope)
253 |
254 |
255 | # Copied to keep API consistency with other functions.
256 | l2_normalize = tf.nn.l2_normalize
257 |
258 |
259 | def l1_normalize(x, dim, epsilon=1e-12, name=None):
260 | """l1 normalizes x.
261 |
262 | Args:
263 | x: The tensor to normalize.
264 | dim: The dimension to normalize along.
265 | epsilon: Lower bound on the norm, used to avoid exploding gradients as the
266 | norm approaches 0.
267 | name: Optional name for this op.
268 | Returns:
269 | x normalized along dim.
270 | """
271 | with tf.name_scope(name, 'l1_normalize', [x]) as scope:
272 | x = tf.convert_to_tensor(x, name='x')
273 | x = tf.verify_tensor_all_finite(x, 'Error at input %s' % scope)
274 | x_norm = tf.maximum(tf.reduce_sum(tf.abs(x), [dim], keep_dims=True),
275 | epsilon)
276 | return tf.div(x, x_norm, name=scope)
277 |
278 |
279 | def every_other(x, name=None):
280 | """Drops every other value from the tensor and returns a 1D tensor.
281 |
282 | This is useful if you are running multiple inputs through a model tower
283 | before splitting them and you want to line it up with some other data.
284 |
285 | Args:
286 | x: the target tensor.
287 | name: the name for this op, defaults to every_other
288 | Returns:
289 | A tensorflow op.
290 | """
291 | with tf.name_scope(name, 'every_other', [x]) as scope:
292 | x = tf.convert_to_tensor(x, name='x')
293 | return tf.reshape(
294 | tf.slice(
295 | tf.reshape(x, [-1, 2]), [0, 0], [-1, 1]),
296 | [-1],
297 | name=scope)
298 |
299 |
300 | def dot_product(t1, t2, keep_dims=False, name=None, reduction_dim=None):
301 | """Computes the dot product of t1 and t2.
302 |
303 | Args:
304 | t1: A rank 2 tensor.
305 | t2: A tensor that is the same size as t1.
306 | keep_dims: If true, reduction does not change the rank of the input.
307 | name: Optional name for this op.
308 |     reduction_dim: The dimension to reduce; defaults to the last dimension,
309 |         or to 1 if the shape is not fully specified.
310 | Returns:
311 | The dot product.
312 | """
313 | with tf.name_scope(name, 'dot', [t1, t2]) as scope:
314 | t1 = tf.convert_to_tensor(t1, name='t1')
315 | t2 = tf.convert_to_tensor(t2, name='t2')
316 | mul = tf.multiply(t1, t2)
317 |     if reduction_dim is None:
318 | reduction_dim = _last_index(mul, 1)
319 | return tf.reduce_sum(mul, reduction_dim, name=scope, keep_dims=keep_dims)
320 |
321 |
322 | def length_squared(x, keep_dims=False, name=None, reduction_dim=None):
323 | """Computes the squared length of x.
324 |
325 | Args:
326 | x: A tensor.
327 | keep_dims: If true, reduction does not change the rank of the input.
328 | name: Optional name for this op.
329 |     reduction_dim: The dimension to reduce; defaults to the last dimension,
330 |         or to 1 if the shape is not fully specified.
331 | Returns:
332 | The squared length of x.
333 | """
334 | with tf.name_scope(name, 'length_squared', [x]) as scope:
335 | x = tf.convert_to_tensor(x, name='x')
336 |     if reduction_dim is None:
337 | reduction_dim = _last_index(x, 1)
338 | return tf.reduce_sum(
339 | tf.square(x),
340 | reduction_dim,
341 | keep_dims=keep_dims,
342 | name=scope)
343 |
344 |
345 | def unzip(x, split_dim, current_length, num_splits=2, name=None):
346 | """Splits a tensor by unzipping along the split_dim.
347 |
348 | For example the following array split into 2 would be:
349 | [1, 2, 3, 4, 5, 6] -> [1, 3, 5], [2, 4, 6]
350 | and by 3:
351 | [1, 2, 3, 4] -> [1, 4], [2], [3]
352 |
353 | Args:
354 | x: The tensor to split.
355 | split_dim: The dimension to split along.
356 | current_length: Current length along the split_dim.
357 | num_splits: The number of splits.
358 | name: Optional name for this op.
359 | Returns:
360 | A length num_splits sequence.
361 | """
362 | with tf.name_scope(name, 'unzip', [x]) as scope:
363 | x = tf.convert_to_tensor(x, name='x')
364 | # There is probably a more efficient way to do this.
365 | all_splits = tf.split(
366 | value=x, num_or_size_splits=current_length, axis=split_dim, name=scope)
367 | splits = [[] for _ in xrange(num_splits)]
368 | for i in xrange(current_length):
369 | splits[i % num_splits].append(all_splits[i])
370 | return [tf.concat(s, split_dim) for s in splits]
371 |
372 |
373 | def _last_index(x, default_dim):
374 | """Returns the last dimension's index or default_dim if x has no shape."""
375 | if x.get_shape().ndims is not None:
376 | return len(x.get_shape()) - 1
377 | else:
378 | return default_dim
379 |
380 |
381 | def _all_dims(x, default_dims=None):
382 | """Returns a list of dims in x or default_dims if the rank is unknown."""
383 | if x.get_shape().ndims is not None:
384 | return list(xrange(x.get_shape().ndims))
385 | else:
386 | return default_dims
387 |
--------------------------------------------------------------------------------
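To make the signatures above concrete, here is a small worked example of a regression loss and a distance function (input values chosen arbitrarily; shapes follow the conventions described in the module docstring):

```python
import tensorflow as tf
from prettytensor import functions

y = tf.constant([[1., 2.], [3., 4.]])
target = tf.constant([[1., 0.], [0., 4.]])

# Regression losses reduce every dimension except the batch dimension:
# row 0: |1-1| + |2-0| = 2, row 1: |3-0| + |4-4| = 3.
l1 = functions.l1_regression_loss(y, target)

# Distance functions return one value per row: 1 - <y_i, t_i> / (|y_i| |t_i|).
cos = functions.cos_distance(y, target)

with tf.Session() as sess:
  print(sess.run(l1))   # [2. 3.]
  print(sess.run(cos))
```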
/prettytensor/functions_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Test class for functions."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 |
17 |
18 | import numpy
19 | from numpy import linalg
20 | from numpy import testing
21 | import tensorflow as tf
22 |
23 | from prettytensor import functions
24 |
25 | TOLERANCE = 0.00001
26 |
27 |
28 | # Distance functions used in tests. These are defined here instead of using
29 | # scipy so the open source tests don't depend on such a huge module for
30 | # three one-line functions.
31 | def cosine(u, v): # pylint: disable=invalid-name
32 | return 1.0 - numpy.dot(u, v) / (linalg.norm(u, ord=2) * linalg.norm(v, ord=2))
33 |
34 |
35 | def cityblock(u, v): # pylint: disable=invalid-name
36 | return numpy.abs(u - v).sum()
37 |
38 |
39 | def euclidean(u, v): # pylint: disable=invalid-name
40 | return linalg.norm(u - v, ord=2)
41 |
42 |
43 | class TensorFlowOpTest(tf.test.TestCase):
44 |
45 | def eval_tensor(self, tensors):
46 | if isinstance(tensors, tf.Tensor):
47 | tensors = [tensors]
48 | with self.test_session() as sess:
49 | return sess.run(tensors)
50 |
51 | def test_every_other(self):
52 | tensor = tf.constant([[1, 2], [3, 4]])
53 | out = self.eval_tensor(functions.every_other(tensor))
54 | testing.assert_array_equal(out[0], numpy.array([1, 3], dtype=numpy.int32))
55 | tensor = tf.constant([[1, 2, 3, 4]])
56 | out = self.eval_tensor(functions.every_other(tensor))
57 | testing.assert_array_equal(out[0], numpy.array([1, 3], dtype=numpy.int32))
58 |
59 | def test_l1_regression_loss(self):
60 | ftensor1 = tf.constant([1., 2., 3., 4.])
61 | ftensor2 = tf.constant([5., 6., 7., -8.])
62 | out = self.eval_tensor(functions.l1_regression_loss(ftensor1, ftensor2))
63 | testing.assert_array_equal(out[0], numpy.array([4., 4., 4., 12.]))
64 |
65 | def test_l2_sq_regression_loss(self):
66 | ftensor1 = tf.constant([1., 2., 3., 4.])
67 | ftensor2 = tf.constant([5., 6., 7., -8.])
68 | out = self.eval_tensor(functions.l2_regression_sq_loss(ftensor1, ftensor2))
69 | testing.assert_array_equal(out[0], numpy.array([16., 16., 16, 144]))
70 |
71 | def test_l2_regression_loss(self):
72 | ftensor1 = tf.constant([1., 2., 3., 4.])
73 | ftensor2 = tf.constant([5., 6., 7., -8.])
74 | out = self.eval_tensor(functions.l2_regression_loss(ftensor1, ftensor2))
75 | testing.assert_allclose(
76 | out[0],
77 | numpy.array([4., 4., 4., 12.]),
78 | rtol=TOLERANCE, atol=TOLERANCE)
79 |
80 | def test_binary_cross_entropy_loss_with_logits(self):
81 | n1 = numpy.array([2., 3., 4., 5., -6., -7.], dtype=numpy.float32)
82 | n2 = numpy.array([1., 1., 0., 0., 0., 1.], dtype=numpy.float32)
83 | ftensor1 = tf.constant(n1)
84 | ftensor2 = tf.constant(n2)
85 | out = self.eval_tensor(functions.binary_cross_entropy_loss_with_logits(
86 | ftensor1, ftensor2))
87 | testing.assert_allclose(
88 | out[0],
89 | n1 * (1-n2) + numpy.log(1 + numpy.exp(-n1)),
90 | rtol=0.00001)
91 |
92 | def test_soft_plus(self):
93 | # 100 overflows naive implementations in float
94 | values = (
95 | numpy.array(
96 | [-100., -10., 1., 0, 1., 10., 100.],
97 | dtype=numpy.float32))
98 | out = self.eval_tensor(
99 | functions.softplus(
100 | tf.constant(
101 | values,
102 | dtype=tf.float32),
103 | 1.))
104 | np_values = numpy.log(1. + numpy.exp(values))
105 | np_values[6] = 100.
106 | testing.assert_allclose(out[0], np_values, rtol=TOLERANCE, atol=TOLERANCE)
107 |
108 | out = self.eval_tensor(functions.softplus(tf.constant(values), 2.))
109 | np_values = numpy.log(1. + numpy.exp(values * 2.)) / 2.
110 | np_values[6] = 100.
111 | testing.assert_allclose(out[0], np_values, rtol=TOLERANCE, atol=TOLERANCE)
112 |
113 | def test_cos_distance(self):
114 | n1 = numpy.array([[1., 2., 3., 4.], [1., 1., 1., 1.]], dtype=numpy.float32)
115 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32)
116 | out = self.eval_tensor(functions.cos_distance(n1, n2))
117 |
118 | testing.assert_allclose(
119 | out[0],
120 | numpy.array([cosine(n1[0], n2[0]), cosine(n1[1], n2[1])]),
121 | rtol=TOLERANCE, atol=TOLERANCE)
122 |
123 | def test_l1_distance(self):
124 | n1 = numpy.array([[1., 2., 3., 4.], [1., 1., 1., 1.]], dtype=numpy.float32)
125 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32)
126 | out = self.eval_tensor(functions.l1_distance(n1, n2))
127 | testing.assert_allclose(
128 | out[0],
129 | numpy.array(
130 | [cityblock(n1[0], n2[0]), cityblock(n1[1], n2[1])
131 | ]),
132 | rtol=TOLERANCE, atol=TOLERANCE)
133 |
134 | def test_l2_distance(self):
135 | n1 = numpy.array([[1., 2., 3., 4.], [1., 1., 1., 1.]], dtype=numpy.float32)
136 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32)
137 | out = self.eval_tensor(functions.l2_distance(n1, n2))
138 | testing.assert_allclose(
139 | out[0],
140 | numpy.array(
141 | [euclidean(n1[0], n2[0]),
142 | 1e-6 # Epsilon sets the minimum distance so use that instead of 0.
143 | ]),
144 | rtol=TOLERANCE, atol=TOLERANCE)
145 |
146 | def test_l2_distance_sq(self):
147 | n1 = numpy.array([[1., 2., 3., 4.], [1., 1., 1., 1.]], dtype=numpy.float32)
148 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32)
149 | out = self.eval_tensor(functions.l2_distance_sq(n1, n2))
150 | testing.assert_allclose(
151 | out[0],
152 | numpy.power(
153 | numpy.array(
154 | [euclidean(n1[0], n2[0]), euclidean(
155 | n1[1], n2[1])]), 2),
156 | rtol=TOLERANCE, atol=TOLERANCE)
157 |
158 | def test_dot_distance(self):
159 | n1 = numpy.array([[1., 2., 3., 4.], [1., 1., 1., 1.]], dtype=numpy.float32)
160 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32)
161 | out = self.eval_tensor(functions.dot_distance(n1, n2))
162 | testing.assert_allclose(
163 | out[0],
164 | numpy.array(-numpy.sum(n1 * n2,
165 | axis=1)),
166 | rtol=TOLERANCE, atol=TOLERANCE)
167 |
168 | def test_cos_distance_with_broadcast(self):
169 | n1 = numpy.array([[[1., 2., 3., 4.], [1., 1., 1., 1.]], [[5., 6., 7., 8.],
170 | [1., 1., 1., 2.]]],
171 | dtype=numpy.float32)
172 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32)
173 | out = self.eval_tensor(functions.cos_distance(n1, n2))
174 | expected = numpy.array(
175 | [[cosine(n1[0, 0], n2[0]), cosine(n1[0, 1], n2[1])],
176 | [cosine(n1[1, 0], n2[0]), cosine(n1[1, 1], n2[1])]])
177 | testing.assert_allclose(expected, out[0], atol=TOLERANCE)
178 |
179 | def test_l1_distance_with_broadcast(self):
180 | n1 = numpy.array([[[1., 2., 3., 4.], [1., 1., 1., 1.]], [[5., 6., 7., 8.],
181 | [1., 1., 1., 2.]]],
182 | dtype=numpy.float32)
183 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32)
184 | out = self.eval_tensor(functions.l1_distance(n1, n2))
185 | expected = numpy.array(
186 | [[cityblock(n1[0, 0], n2[0]), cityblock(
187 | n1[0, 1], n2[1])], [cityblock(n1[1, 0], n2[0]),
188 | cityblock(n1[1, 1], n2[1])]])
189 | testing.assert_allclose(expected, out[0], atol=TOLERANCE)
190 |
191 | def test_l2_distance_with_broadcast(self):
192 | n1 = numpy.array([[[1., 2., 3., 4.], [1., 1., 1., 1.]], [[5., 6., 7., 8.],
193 | [1., 1., 1., 2.]]],
194 | dtype=numpy.float32)
195 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32)
196 | out = self.eval_tensor(functions.l2_distance(n1, n2))
197 | expected = numpy.array(
198 | [[euclidean(n1[0, 0], n2[0]), euclidean(
199 | n1[0, 1], n2[1])], [euclidean(n1[1, 0], n2[0]),
200 | euclidean(n1[1, 1], n2[1])]])
201 | testing.assert_allclose(expected, out[0], atol=TOLERANCE)
202 |
203 | def test_l2_distance_sq_with_broadcast(self):
204 | n1 = numpy.array([[[1., 2., 3., 4.], [1., 1., 1., 1.]], [[5., 6., 7., 8.],
205 | [1., 1., 1., 2.]]],
206 | dtype=numpy.float32)
207 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32)
208 | out = self.eval_tensor(functions.l2_distance_sq(n1, n2))
209 | expected = numpy.array(
210 | [[euclidean(n1[0, 0], n2[0]), euclidean(
211 | n1[0, 1], n2[1])], [euclidean(n1[1, 0], n2[0]),
212 | euclidean(n1[1, 1], n2[1])]])
213 | expected = numpy.power(expected, 2)
214 | testing.assert_allclose(expected, out[0], atol=TOLERANCE)
215 |
216 | def test_dot_distance_with_broadcast(self):
217 | n1 = numpy.array([[[1., 2., 3., 4.], [1., 1., 1., 1.]], [[5., 6., 7., 8.],
218 | [1., 1., 1., 2.]]],
219 | dtype=numpy.float32)
220 | n2 = numpy.array([[5., 6., 7., -8.], [1., 1., 1., 1.]], dtype=numpy.float32)
221 | out = self.eval_tensor(functions.dot_distance(n1, n2))
222 | testing.assert_allclose(
223 | out[0],
224 | numpy.array(-numpy.sum(n1 * n2,
225 | axis=2)),
226 | rtol=TOLERANCE, atol=TOLERANCE)
227 |
228 | def test_l2_normalize(self):
229 | n1 = numpy.array([[1., 2., 3., 4.], [1., 1., 1., 1.]], dtype=numpy.float32)
230 | t1 = tf.constant(n1)
231 | out = self.eval_tensor(functions.l2_normalize(t1, 1))
232 | testing.assert_allclose(
233 | out[0],
234 | n1 / linalg.norm(n1, 2, axis=1).reshape((2, 1)),
235 | rtol=TOLERANCE, atol=TOLERANCE)
236 |
237 | def test_l1_normalize(self):
238 | n1 = numpy.array([[1., 2., 3., 4.], [1., 1., 1., 1.]], dtype=numpy.float32)
239 | t1 = tf.constant(n1)
240 | out = self.eval_tensor(functions.l1_normalize(t1, 1))
241 | testing.assert_allclose(
242 | out[0],
243 | n1 / linalg.norm(n1, 1, axis=1).reshape((2, 1)),
244 | rtol=TOLERANCE, atol=TOLERANCE)
245 |
246 | def test_leaky_relu(self):
247 | values = (
248 | numpy.array(
249 | [-100., -10., 1., 0, 1., 10., 100.],
250 | dtype=numpy.float32))
251 | tensor = tf.constant(values)
252 | out = self.eval_tensor(functions.leaky_relu(tensor))
253 | for i, value in enumerate(values):
254 | if value < 0:
255 | values[i] *= 0.01
256 | testing.assert_allclose(out[0], values, rtol=TOLERANCE, atol=TOLERANCE)
257 |
258 | def test_unzip(self):
259 | n1 = numpy.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.]],
260 | dtype=numpy.float32)
261 | t1 = tf.constant(n1)
262 | out = self.eval_tensor(functions.unzip(t1, 0, 4, 2))
263 |
264 | expected = numpy.array([[1., 2.], [5., 6.]], dtype=numpy.float32)
265 | testing.assert_allclose(expected, out[0], rtol=TOLERANCE, atol=TOLERANCE)
266 | expected = numpy.array([[3., 4.], [7., 8.]], dtype=numpy.float32)
267 | testing.assert_allclose(expected, out[1], rtol=TOLERANCE, atol=TOLERANCE)
268 |
269 | def test_split(self):
270 | """Testing TF functionality to highlight difference with Unzip."""
271 | n1 = numpy.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.]],
272 | dtype=numpy.float32)
273 | t1 = tf.constant(n1)
274 | out = self.eval_tensor(tf.split(value=t1, num_or_size_splits=2, axis=0))
275 | expected = numpy.array([[1., 2.], [3., 4.]], dtype=numpy.float32)
276 | testing.assert_allclose(expected, out[0], rtol=TOLERANCE, atol=TOLERANCE)
277 | expected = numpy.array([[5., 6.], [7., 8.]], dtype=numpy.float32)
278 | testing.assert_allclose(expected, out[1], rtol=TOLERANCE, atol=TOLERANCE)
279 |
280 |
281 | if __name__ == '__main__':
282 | tf.test.main()
283 |
--------------------------------------------------------------------------------
/prettytensor/input_helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Provides helpers for feeding in numpy data to a TF graph.
12 |
13 | These methods are intended to aid experimentation. For large datasets consider
14 | using readers and queues.
15 | """
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import itertools
21 |
22 |
23 |
24 | from six.moves import xrange # pylint: disable=redefined-builtin
25 | import tensorflow as tf
26 |
27 | from prettytensor import bookkeeper
28 |
29 |
30 | def feed_numpy(batch_size, *arrays):
31 | """Given a set of numpy arrays, produce slices of batch_size.
32 |
33 | Note: You can use itertools.cycle to have this repeat forever.
34 |
35 | Args:
36 | batch_size: The batch_size for each array.
37 | *arrays: A list of arrays.
38 | Yields:
39 |     A list with one slice per array, each of length batch_size; the final
40 |     yield contains whatever remainder is left.
41 | Raises:
42 | ValueError: If arrays aren't all the same length or no arrays are provided.
43 | """
44 | if not arrays:
45 | raise ValueError('Arrays cannot be empty.')
46 | size = len(arrays[0])
47 | for a in arrays:
48 | if size != len(a):
49 | raise ValueError('All arrays must be the same size.')
50 | count = int(size / batch_size)
51 |
52 | for i in xrange(count):
53 | start = i * batch_size
54 | end = start + batch_size
55 | yield [x[start:end] for x in arrays]
56 |   if count * batch_size < size:
57 |     yield [x[count * batch_size:] for x in arrays]
58 |
59 |
60 | def batch(input_iter, batch_size=32):
61 | """Batches data from an iterator that returns single items at a time."""
62 | input_iter = iter(input_iter)
63 | next_ = list(itertools.islice(input_iter, batch_size))
64 | while next_:
65 | yield next_
66 | next_ = list(itertools.islice(input_iter, batch_size))
67 |
68 |
69 | def slice_constant(data, batch_size=32, name='constant_data', global_step=None):
70 | """Provide a slice based on the global_step.
71 |
72 | This is useful when the entire data array can be stored in memory because it
73 | allows you to feed the data very efficiently.
74 |
75 | Args:
76 | data: A numpy array or tensor.
77 | batch_size: The batch size for the produced data.
78 | name: An optional name for this data.
79 | global_step: A global step variable that is used to read the data. If None
80 | then the default prettytensor global_step is used.
81 | Returns:
82 | A tensor that produces the given data.
83 | """
84 | with tf.name_scope(name):
85 | all_data = tf.convert_to_tensor(data)
86 | global_step = global_step or bookkeeper.global_step()
87 |
88 |     count = len(data) // batch_size
89 | extra = len(data) - count * batch_size
90 |
91 |     if extra:
92 |       offset = tf.mod(global_step, count + 1)
93 |       return tf.slice(all_data, offset * batch_size,
94 |                       tf.where(tf.equal(offset, count), extra, batch_size))
95 |     else:
96 |       offset = tf.mod(global_step, count)
97 |       return tf.slice(all_data, offset * batch_size, batch_size)
98 |
--------------------------------------------------------------------------------
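A short sketch of batching in-memory numpy data with `feed_numpy` (the arrays here are arbitrary placeholders):

```python
import itertools
import numpy as np
from prettytensor import input_helpers

features = np.arange(10, dtype=np.float32).reshape(10, 1)
labels = np.arange(10, dtype=np.int32)

# Yields slices of 4, 4 and finally 2 rows.
for x_batch, y_batch in input_helpers.feed_numpy(4, features, labels):
  print(x_batch.shape, y_batch.shape)

# For training loops that need the data forever, as the docstring suggests:
endless = itertools.cycle(input_helpers.feed_numpy(4, features, labels))
```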
/prettytensor/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Utility functions for adding layers to a Model.
12 |
13 | NB: This is used by PrettyTensor, but it will be deprecated. Please do not use!
14 | """
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import collections
20 | import math
21 |
22 | import tensorflow as tf
23 |
24 | from prettytensor import functions
25 |
26 | # Implementation note: this takes a tuple for an activation instead of
27 | # encouraging lambdas so that we can inspect the actual function and add
28 | # appropriate summaries.
29 |
30 |
31 | def apply_activation(
32 | books,
33 | x,
34 | activation,
35 | activation_args=(),
36 | activation_kwargs=None):
37 | """Returns activation(x, *activation_args, **activation_kwargs).
38 |
39 | This applies the given activation and adds useful summaries specific to the
40 | activation.
41 |
42 | Args:
43 | books: The bookkeeper.
44 | x: The tensor to apply activation to.
45 | activation: An activation function.
46 | activation_args: Optional additional arguments for the activation.
47 | activation_kwargs: Optional keyword args for activation.
48 | Returns:
49 | A tensor with activation applied to x.
50 | """
51 | if activation is None:
52 | return x
53 | if activation_kwargs is None:
54 | activation_kwargs = {}
55 | y = activation(x, *activation_args, **activation_kwargs)
56 | if activation in (tf.nn.relu, functions.leaky_relu, functions.softplus):
57 | books.add_scalar_summary(
58 | tf.reduce_mean(tf.cast(tf.less(x, 0.0), tf.float32)),
59 | '%s/zeros' % y.op.name)
60 | elif activation is tf.nn.relu6:
61 | books.add_scalar_summary(
62 | tf.reduce_mean(tf.cast(tf.less(x, 0.0), tf.float32)),
63 | '%s/zeros' % y.op.name)
64 | books.add_scalar_summary(
65 | tf.reduce_mean(tf.cast(tf.greater(x, 6.0), tf.float32)),
66 | '%s/sixes' % y.op.name)
67 | elif activation in (functions.l2_normalize, tf.nn.l2_normalize,
68 | functions.l1_normalize):
69 | books.add_scalar_summary(
70 | tf.reduce_mean(tf.sqrt(tf.reduce_sum(
71 | tf.square(x), 1))), '%s/length' % y.op.name)
72 | return y
73 |
74 |
75 | def add_l2loss(books, params, l2loss, name='weight_decay'):
76 | if l2loss:
77 | books.add_loss(
78 | tf.multiply(
79 | tf.nn.l2_loss(params), l2loss, name=name),
80 | regularization=True,
81 | add_summaries=False)
82 |
83 |
84 | def he_init(n_inputs, n_outputs, activation_fn, uniform=True):
85 | """Sets the parameter initialization using the method described.
86 |
87 | This method is designed to keep the scale of the gradients roughly the same
88 | in all layers with ReLU activations.
89 |
90 | He et al. (2015):
91 | Delving deep into rectifiers: surpassing human-level performance on
92 | imageNet classification. International Conference on Computer Vision.
93 |
94 | For activations other than ReLU and ReLU6, this method uses Xavier
95 | initialization as in xavier_init().
96 |
97 | Args:
98 | n_inputs: The number of input nodes into each output.
99 | n_outputs: The number of output nodes for each input.
100 | activation_fn: Activation function used in this layer.
101 | uniform: If uniform distribution will be used for Xavier initialization.
102 | Normal distribution will be used if False.
103 | Returns:
104 | An initializer.
105 | """
106 | def in_relu_family(activation_fn):
107 | if isinstance(activation_fn, collections.Sequence):
108 | activation_fn = activation_fn[0]
109 | return activation_fn in (tf.nn.relu, tf.nn.relu6)
110 |
111 | if in_relu_family(activation_fn):
112 | stddev = math.sqrt(2.0 / n_inputs)
113 | # TODO(): Evaluates truncated_normal_initializer.
114 | return tf.random_normal_initializer(stddev=stddev)
115 | else:
116 | return xavier_init(n_inputs, n_outputs, uniform)
117 |
118 |
119 | def xavier_init(n_inputs, n_outputs, uniform=True):
120 | """Set the parameter initialization using the method described.
121 |
122 | This method is designed to keep the scale of the gradients roughly the same
123 | in all layers.
124 |
125 | Xavier Glorot and Yoshua Bengio (2010):
126 | Understanding the difficulty of training deep feedforward neural
127 | networks. International conference on artificial intelligence and
128 | statistics.
129 | Args:
130 | n_inputs: The number of input nodes into each output.
131 | n_outputs: The number of output nodes for each input.
132 | uniform: If true use a uniform distribution, otherwise use a normal.
133 | Returns:
134 | An initializer.
135 | """
136 | if uniform:
137 | # 6 was used in the paper.
138 | init_range = math.sqrt(6.0 / (n_inputs + n_outputs))
139 | return tf.random_uniform_initializer(-init_range, init_range)
140 | else:
141 | # 3 gives us approximately the same limits as above since this repicks
142 | # values greater than 2 standard deviations from the mean.
143 | stddev = math.sqrt(3.0 / (n_inputs + n_outputs))
144 | return tf.truncated_normal_initializer(stddev=stddev)
145 |
146 |
147 | def spatial_slice_zeros(x):
148 | """Experimental summary that shows how many planes are unused for a batch."""
149 | return tf.cast(tf.reduce_all(tf.less_equal(x, 0.0), [0, 1, 2]),
150 | tf.float32)
151 |
--------------------------------------------------------------------------------
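The initializer choices above boil down to simple formulas; a quick worked check for a hypothetical 256-input, 128-output layer (kept as plain arithmetic, since the module header above discourages calling it directly):

```python
import math

n_inputs, n_outputs = 256, 128

# xavier_init with uniform=True draws from [-limit, limit]:
limit = math.sqrt(6.0 / (n_inputs + n_outputs))           # ~0.125

# xavier_init with uniform=False uses a truncated normal with:
xavier_stddev = math.sqrt(3.0 / (n_inputs + n_outputs))   # ~0.088

# he_init for ReLU-family activations uses a normal with:
he_stddev = math.sqrt(2.0 / n_inputs)                      # ~0.088

print(limit, xavier_stddev, he_stddev)
```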
/prettytensor/local_trainer_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Tests for local_trainer."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import itertools
17 | import os
18 | import shutil
19 | import tempfile
20 | import threading
21 | import unittest
22 |
23 |
24 |
25 | import numpy
26 | import tensorflow as tf
27 |
28 | import prettytensor as pt
29 | from prettytensor import input_helpers
30 | from prettytensor import local_trainer
31 |
32 |
33 | class LocalTrainerTest(unittest.TestCase):
34 |
35 | def random_numpy(self, shape, dtype, partition_info=None):
36 | _ = partition_info
37 | if tf.float32.is_compatible_with(dtype):
38 | size = 1
39 | for n in shape:
40 | size *= n
41 | return self.prng.normal(size=size).astype(numpy.float32).reshape(shape)
42 | else:
43 | raise ValueError('This method only supports float32: %s' % dtype)
44 |
45 | def setUp(self):
46 | tf.reset_default_graph()
47 | self.prng = numpy.random.RandomState(42)
48 |
49 | self.input = tf.placeholder(tf.float32, [4, 2])
50 | self.target = tf.placeholder(tf.float32)
51 | xor_inputs = numpy.array([[0., 0.], [1., 0.], [1., 1.], [0., 1.]])
52 | xor_outputs = numpy.array([[0., 1.], [1., 0.], [0., 1.], [1., 0.]])
53 |
54 | self.xor_data = itertools.cycle(
55 | input_helpers.feed_numpy(4, xor_inputs, xor_outputs))
56 |
57 | self.softmax_result = (
58 | pt.wrap(self.input).fully_connected(2,
59 | activation_fn=tf.sigmoid,
60 | weights=self.random_numpy)
61 | .fully_connected(2,
62 | activation_fn=None,
63 | weights=self.random_numpy).softmax(self.target))
64 | self.tmp_file = tempfile.mkdtemp()
65 |
66 | def tearDown(self):
67 | shutil.rmtree(self.tmp_file)
68 |
69 | def test_run(self):
70 | runner = local_trainer.Runner()
71 | with tf.Session():
72 | optimizer = tf.train.GradientDescentOptimizer(0.5)
73 | train_op = pt.apply_optimizer(optimizer,
74 | losses=[self.softmax_result.loss])
75 |
76 | runner.train_model(train_op,
77 | self.softmax_result.loss,
78 | 10,
79 | (self.input, self.target),
80 | self.xor_data,
81 | print_every=2)
82 |
83 | def test_checkpoint(self):
84 | f = os.path.join(self.tmp_file, 'checkpoint')
85 | runner = local_trainer.Runner(save_path=f)
86 | with tf.Session():
87 | optimizer = tf.train.GradientDescentOptimizer(0.1)
88 | train_op = pt.apply_optimizer(optimizer,
89 | losses=[self.softmax_result.loss])
90 |
91 | runner.train_model(train_op,
92 | self.softmax_result.loss,
93 | 10,
94 | (self.input, self.target),
95 | self.xor_data,
96 | print_every=2)
97 | assert runner._saver.last_checkpoints, 'Expected checkpoints.'
98 | for x in runner._saver.last_checkpoints:
99 | self.assertTrue(tf.train.checkpoint_exists(x),
100 | 'Promised file not saved: %s' % x)
101 | self.assertTrue(x.startswith(f), 'Name not as expected: %s' % x)
102 |
103 | def test_eval(self):
104 | f = os.path.join(self.tmp_file, 'checkpoint')
105 | runner = local_trainer.Runner(save_path=f)
106 | with tf.Session():
107 |       classification_accuracy = self.softmax_result.softmax.evaluate_classifier(
108 | self.target, phase=pt.Phase.test)
109 |
110 | optimizer = tf.train.GradientDescentOptimizer(0.2)
111 | train_op = pt.apply_optimizer(optimizer,
112 | losses=[self.softmax_result.loss])
113 |
114 | runner.train_model(train_op,
115 | self.softmax_result.loss,
116 | 100,
117 | (self.input, self.target),
118 | self.xor_data,
119 | print_every=50)
120 | self.assertTrue(runner._last_init)
121 | save_paths = list(runner._saver.last_checkpoints)
122 |
123 |       # The accuracy should be 50% right now since the model is consistently
124 | # generated.
125 |       accuracy = runner.evaluate_model(classification_accuracy,
126 | 1,
127 | (self.input, self.target),
128 | self.xor_data)
129 | self.assertEquals(runner._saver.last_checkpoints, save_paths,
130 | 'No additional paths should have been saved.')
131 | self.assertFalse(runner._last_init)
132 | self.assertEqual(accuracy, [0.5])
133 |
134 | # Train the model to 100% accuracy.
135 | runner.train_model(train_op,
136 | self.softmax_result.loss,
137 | 2000,
138 | (self.input, self.target),
139 | self.xor_data,
140 | print_every=1000)
141 |       accuracy = runner.evaluate_model(classification_accuracy, 1,
142 | (self.input, self.target), self.xor_data)
143 | self.assertFalse(runner._last_init)
144 |
145 | # Make sure that the previous computation didn't impact this eval.
146 | self.assertEqual(accuracy, [1.0])
147 |
148 | def restore_helper(self, runner):
149 | with tf.Session():
150 |       classification_accuracy = self.softmax_result.softmax.evaluate_classifier(
151 | self.target, phase=pt.Phase.test)
152 |
153 | optimizer = tf.train.GradientDescentOptimizer(0.5)
154 | train_op = pt.apply_optimizer(optimizer,
155 | losses=[self.softmax_result.loss])
156 |
157 | runner.train_model(train_op,
158 | self.softmax_result.loss,
159 | 10,
160 | (self.input, self.target),
161 | self.xor_data,
162 | print_every=2)
163 | self.assertTrue(runner._last_init)
164 | self.assertFalse(runner._last_restore)
165 | with tf.Session():
166 | save_paths = list(runner._saver.last_checkpoints)
167 |       runner.evaluate_model(classification_accuracy, 1,
168 | (self.input, self.target), self.xor_data)
169 | self.assertEquals(runner._saver.last_checkpoints, save_paths,
170 | 'No additional paths should have been saved.')
171 | self.assertFalse(runner._last_init)
172 |
173 | def test_manual_save_restore(self):
174 | runner = local_trainer.Runner()
175 | f = os.path.join(self.tmp_file, 'manual.chkpt')
176 |
177 | v = tf.Variable(tf.random_normal(shape=[100], dtype=tf.float32))
178 |
179 | # Save it.
180 | with runner.session() as sess:
181 | runner.prepare_model(sess) # Create variables
182 | value = v.eval() # Grab the variable
183 | runner.saver.save(sess, f)
184 |
185 | with runner.session() as sess:
186 | # Restore the model
187 | runner.saver.restore(sess, f)
188 | new_value = v.eval()
189 | numpy.testing.assert_array_equal(value, new_value)
190 |
191 | def test_restore(self):
192 | f = os.path.join(self.tmp_file, 'checkpoint')
193 | runner = local_trainer.Runner(save_path=f)
194 | self.restore_helper(runner)
195 | self.assertTrue(runner._last_restore)
196 |
197 | def test_not_restored(self):
198 | f = os.path.join(self.tmp_file, 'checkpoint')
199 | runner = local_trainer.Runner(save_path=f, restore=False)
200 | with self.assertRaises(tf.errors.FailedPreconditionError):
201 | self.restore_helper(runner)
202 |
203 | def test_evaluate_without_initialize_error(self):
204 | with tf.Graph().as_default():
205 | runner = local_trainer.Runner()
206 | tf.Variable(1) # Put a variable in the graph.
207 |
208 | with runner.session(), self.assertRaises(ValueError):
209 | runner.evaluate_model(
210 | self.softmax_result, 1, (self.input, self.target), self.xor_data)
211 |
212 | def test_evaluate_repeatedly_one_time(self):
213 | f = os.path.join(self.tmp_file, 'checkpoint')
214 | runner = local_trainer.Runner(save_path=f)
215 | self.restore_helper(runner)
216 | local_variable = tf.Variable(22, collections=[tf.GraphKeys.LOCAL_VARIABLES])
217 | accuracy = local_variable.assign_add(1)
218 |
219 | answer = runner.evaluate_repeatedly(accuracy, 20, evaluation_times=1)
220 | self.assertEqual([42], answer)
221 |
222 | def test_queues(self):
223 | qr = FakeQueueRunner()
224 | tf.train.add_queue_runner(qr)
225 | runner = local_trainer.Runner()
226 | with tf.Session():
227 | optimizer = tf.train.GradientDescentOptimizer(0.5)
228 | train_op = pt.apply_optimizer(optimizer,
229 | losses=[self.softmax_result.loss])
230 |
231 | runner.train_model(train_op,
232 | self.softmax_result.loss,
233 | 100,
234 | (self.input, self.target),
235 | self.xor_data,
236 | print_every=2)
237 | with tf.Session():
238 | with self.assertRaisesRegexp(ValueError, r'.*\bstop_queues\b.*'):
239 | runner.train_model(train_op,
240 | self.softmax_result.loss,
241 | 100,
242 | (self.input, self.target),
243 | self.xor_data,
244 | print_every=2)
245 |
246 | runner.stop_queues()
247 | qr.assert_worked(self)
248 |
249 | def test_queue_error(self):
250 | qr = FakeQueueRunner(RuntimeError('expected'))
251 | tf.train.add_queue_runner(qr)
252 | runner = local_trainer.Runner()
253 | with tf.Session():
254 | optimizer = tf.train.GradientDescentOptimizer(0.5)
255 | train_op = pt.apply_optimizer(optimizer,
256 | losses=[self.softmax_result.loss])
257 |
258 | with self.assertRaisesRegexp(RuntimeError, 'expected'):
259 | runner.train_model(train_op,
260 | self.softmax_result.loss,
261 | 100,
262 | (self.input, self.target),
263 | self.xor_data,
264 | print_every=2)
265 | qr.assert_worked(self)
266 |
267 |
268 | class FakeQueueRunner(object):
269 | called = 0
270 | stopped = False
271 |
272 | def __init__(self, error=None):
273 | self.error = error
274 |
275 | def create_threads(self, sess, coord=None, daemon=False, start=False): # pylint: disable=unused-argument
276 | self.called += 1
277 | threads = [threading.Thread(target=self.set_stopped, args=(coord,))]
278 | if self.error:
279 | threads.append(threading.Thread(target=self.die,
280 | args=(coord, self.error)))
281 | if start:
282 | for t in threads:
283 | t.start()
284 | return threads
285 |
286 | def die(self, coord, error):
287 | try:
288 | raise error
289 | except RuntimeError as e:
290 | coord.request_stop(e)
291 |
292 | def set_stopped(self, coord):
293 | coord.wait_for_stop()
294 | self.stopped = True
295 |
296 | def assert_worked(self, test):
297 | test.assertEqual(1, self.called)
298 | test.assertTrue(self.stopped)
299 |
300 | if __name__ == '__main__':
301 | tf.test.main()
302 |
--------------------------------------------------------------------------------
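The tests above double as usage documentation for `local_trainer.Runner`; a condensed, hedged sketch of the train-then-evaluate pattern they exercise (the tiny XOR model and checkpoint path are purely illustrative):

```python
import itertools
import numpy as np
import tensorflow as tf
import prettytensor as pt
from prettytensor import input_helpers
from prettytensor import local_trainer

inputs = tf.placeholder(tf.float32, [4, 2])
targets = tf.placeholder(tf.float32)
xor_x = np.array([[0., 0.], [1., 0.], [1., 1.], [0., 1.]])
xor_y = np.array([[0., 1.], [1., 0.], [0., 1.], [1., 0.]])
data = itertools.cycle(input_helpers.feed_numpy(4, xor_x, xor_y))

result = (pt.wrap(inputs)
          .fully_connected(2, activation_fn=tf.sigmoid)
          .fully_connected(2, activation_fn=None)
          .softmax(targets))

runner = local_trainer.Runner(save_path='/tmp/xor_checkpoint')  # illustrative path
with tf.Session():
  accuracy = result.softmax.evaluate_classifier(targets, phase=pt.Phase.test)
  train_op = pt.apply_optimizer(tf.train.GradientDescentOptimizer(0.5),
                                losses=[result.loss])
  runner.train_model(train_op, result.loss, 1000,
                     (inputs, targets), data, print_every=100)
  print(runner.evaluate_model(accuracy, 1, (inputs, targets), data))
```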
/prettytensor/parameters.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | r"""Provides some standard functions to modify parameter variables.
12 |
13 | These are applied before the variables are used within the graph; the basic
14 | function signature is:
15 |
16 | def my_func(var_name, variable, phase):
17 | \"""A function to apply to a model's parameters.
18 |
19 | Args:
20 | var_name: The short name for the variable.
21 | variable: A `Tensor` that can be used in the model. Note: this is often
22 | a tf.Variable, but may not be and the name only usually contains
23 | var_name (it doesn't in the case of reuse).
24 | phase: The phase of model construction.
25 |
26 | Returns:
27 | A `Variable` or `Tensor` with the same shape and type as `variable` to
28 | use.
29 | \"""
30 | return something_done_to_variable
31 | """
32 |
33 | import re
34 |
35 |
36 | import tensorflow as tf
37 |
38 | from prettytensor import pretty_tensor_class as pt
39 |
40 |
41 | def identity(unused_var_name, variable, unused_phase):
42 | return variable
43 |
44 |
45 | def regularizer(name, regularization_fn, name_filter='weights'):
46 | """Wraps a regularizer in a parameter-function.
47 |
48 | Args:
49 | name: The name scope for this regularizer.
50 | regularization_fn: A function with signature:
51 | fn(variable) -> loss `Tensor` or `None`.
52 | name_filter: A regex that will be used to filter variables by name.
53 |
54 | Returns:
55 | A parameter modification function that adds the loss to the
56 | REGULARIZATION_LOSSES graph key.
57 | """
58 | regex = re.compile(name_filter)
59 | def fn(var_name, variable, phase):
60 | if phase is pt.Phase.train and regex.search(var_name):
61 | with tf.name_scope(None, name, [variable]):
62 | loss = regularization_fn(variable)
63 | if loss is not None:
64 | tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, loss)
65 | return variable
66 | return fn
67 |
68 |
69 | def l2_regularizer(decay, name_filter='weights'):
70 | """Create an l2 regularizer."""
71 | return regularizer(
72 | 'l2_regularizer',
73 | lambda x: tf.nn.l2_loss(x) * decay,
74 | name_filter=name_filter)
75 |
76 |
77 | def l1_regularizer(decay, name_filter='weights'):
78 | """Create an l1 regularizer."""
79 | return regularizer(
80 | 'l1_regularizer',
81 | lambda x: tf.reduce_sum(tf.abs(x)) * decay,
82 | name_filter=name_filter)
83 |
84 |
85 | def compose(*parameter_functions):
86 | """Composes multiple modification functions in order.
87 |
88 | Args:
89 | *parameter_functions: The functions to compose.
90 |
91 | Returns:
92 |     A parameter modification function that applies all of the provided
93 |     functions in order.
94 | """
95 | def composed_fn(var_name, variable, phase):
96 | for fn in parameter_functions:
97 | variable = fn(var_name, variable, phase)
98 | return variable
99 | return composed_fn
100 |
101 |
102 | class Noise(object):
103 |   """Regularize the model by applying Gaussian noise to the variables."""
104 |
105 | def __init__(self, stddev):
106 | self._stddev = stddev
107 |
108 | def __call__(self, var_name, variable, phase):
109 | if phase is pt.Phase.train:
110 | return variable * tf.random_normal(
111 | tf.shape(variable), mean=1., stddev=self._stddev)
112 | else:
113 | return variable
114 |
115 |
116 | class DropConnect(object):
117 | """Drop out some connections.
118 |
119 | See the paper: http://www.matthewzeiler.com/pubs/icml2013/icml2013.pdf
120 | """
121 |
122 | def __init__(self, keep_prob):
123 | self._keep_prob = keep_prob
124 |
125 | def __call__(self, var_name, variable, phase):
126 |     if 'bias' in var_name or phase is not pt.Phase.train:
127 | return variable
128 | return tf.nn.dropout(variable, self._keep_prob)
129 |
--------------------------------------------------------------------------------
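A minimal sketch (not part of the source tree) of composing and applying the modifiers defined in `parameters.py` above. The variable name, shape and decay values are illustrative; `pt.Phase.train` is the same enum the module itself uses.

```python
import tensorflow as tf

from prettytensor import parameters
from prettytensor import pretty_tensor_class as pt

# Compose weight decay with multiplicative Gaussian noise; the result is a
# single parameter-modification function with the standard signature.
modifier = parameters.compose(
    parameters.l2_regularizer(1e-4),
    parameters.Noise(0.01))

# Pretty Tensor invokes such a function on each parameter it creates; calling
# it by hand shows the contract: it returns a tensor of the same shape.
weights = tf.get_variable('weights', [10, 10])
modified_weights = modifier('weights', weights, pt.Phase.train)
```

Because `regularizer` filters on variable names, the weight-decay term is only added for variables whose name matches the `name_filter` regex (here the default, `'weights'`).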
/prettytensor/pretty_tensor_normalization_methods.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Batch Normalization and eventually some friends for PrettyTensor."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import collections
17 |
18 | import tensorflow as tf
19 |
20 | from prettytensor import pretty_tensor_class as prettytensor
21 | from prettytensor.pretty_tensor_class import Phase
22 | from prettytensor.pretty_tensor_class import PROVIDED
23 |
24 |
25 | BatchNormalizationArguments = collections.namedtuple(
26 | 'BatchNormalizationArguments',
27 | ('learned_moments_update_rate', 'variance_epsilon',
28 | 'scale_after_normalization'))
29 |
30 | BatchNormalizationArguments.__new__.__defaults__ = (None, None, None)
31 |
32 |
33 | def batch_normalize_with_arguments(x, arguments):
34 | """Applies batch normalization to x as specified in arguments.
35 |
36 | Args:
37 | x: A Pretty Tensor.
38 | arguments: Either a boolean to batch_normalize or a
39 | BatchNormalizationArguments
40 |
41 | Returns:
42 | x with batch normalization applied.
43 | """
44 | x = prettytensor.wrap(x)
45 | # Backwards compatibility.
46 | if isinstance(arguments, bool):
47 | if arguments:
48 | return x.batch_normalize()
49 | else:
50 | return x
51 |
52 | # pylint: disable=protected-access
53 | kwargs = arguments._asdict()
54 | defaults = prettytensor._defaults
55 | # pylint: enable=protected-access
56 | for arg in ('learned_moments_update_rate', 'variance_epsilon',
57 | 'scale_after_normalization'):
58 | if kwargs.get(arg, None) is None:
59 | if arg in defaults:
60 | kwargs[arg] = defaults[arg]
61 | else:
62 | del kwargs[arg]
63 | return x.batch_normalize(**kwargs)
64 |
65 |
66 | # pylint: disable=invalid-name
67 | @prettytensor.Register(
68 | assign_defaults=('learned_moments_update_rate', 'variance_epsilon',
69 | 'scale_after_normalization', 'phase'))
70 | class batch_normalize(prettytensor.VarStoreMethod):
71 |
72 | def __call__(self,
73 | input_layer,
74 | name=PROVIDED,
75 | learned_moments_update_rate=0.0003,
76 | variance_epsilon=0.001,
77 | scale_after_normalization=False,
78 | phase=Phase.train):
79 | """Batch normalize this layer.
80 |
81 |     This only supports global batch normalization. It can be enabled for all
82 |     convolutional layers by setting the default 'batch_normalize' to True.
83 |     learned_moments_update_rate, variance_epsilon and scale_after_normalization
84 |     must then either be set here or also be set in the defaults.
85 |
86 | Args:
87 | input_layer: The chainable object, supplied.
88 | name: The name for this operation is also used to create/find the
89 | parameter variables.
90 | learned_moments_update_rate: Update rate for the learned moments.
91 | variance_epsilon: A float. A small float number to avoid dividing by 0.
92 |       scale_after_normalization: A bool indicating whether the resulting tensor
93 |         should be multiplied by gamma.
94 | phase: The phase of construction.
95 | Returns:
96 | Handle to the generated layer.
97 | """
98 | # Allocate variables to hold the moving averages of the moments.
99 | params_shape = [input_layer.shape[-1]]
100 |
101 | # Allocate parameters for the beta and gamma of the normalization.
102 | beta = self.variable('beta', params_shape, tf.constant_initializer(0.0))
103 | if scale_after_normalization:
104 | gamma = self.variable('gamma', params_shape, tf.constant_initializer(1.0))
105 | else:
106 | gamma = None
107 | moving_mean = self.variable('moving_mean',
108 | params_shape,
109 | tf.constant_initializer(0.0),
110 | train=False)
111 | moving_variance = self.variable('moving_variance',
112 | params_shape,
113 | tf.constant_initializer(1.0),
114 | train=False)
115 |
116 | if phase == Phase.train:
117 | # Calculate the moments based on the individual batch.
118 | mean, variance = tf.nn.moments(
119 | input_layer.tensor, list(range(len(input_layer.get_shape()) - 1)))
120 | input_layer.bookkeeper.add_histogram_summary(mean)
121 | input_layer.bookkeeper.add_histogram_summary(variance)
122 |
123 | avg_mean = input_layer.bookkeeper.exponential_moving_average(
124 | mean, moving_mean, 1.0 - learned_moments_update_rate)
125 | avg_variance = input_layer.bookkeeper.exponential_moving_average(
126 | variance, moving_variance, 1.0 - learned_moments_update_rate)
127 | with tf.control_dependencies([avg_variance, avg_mean]):
128 | y = tf.nn.batch_normalization(input_layer, mean, variance, beta, gamma,
129 | variance_epsilon)
130 | else:
131 | # Load the mean and variance as the 'moving average' moments
132 | # from the checkpoint.
133 | y = tf.nn.batch_normalization(input_layer, moving_mean, moving_variance,
134 | beta, gamma, variance_epsilon)
135 |
136 | return input_layer.with_tensor(y)
137 | # pylint: enable=invalid-name
138 |
--------------------------------------------------------------------------------
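As a sketch of how the namedtuple above is consumed, here is a hypothetical call to `batch_normalize_with_arguments`; the input shape is made up and the argument values mirror the method defaults.

```python
import tensorflow as tf
import prettytensor as pt

from prettytensor import pretty_tensor_normalization_methods as bn_methods

# A fake batch of 8 feature maps with 16 channels.
activations = pt.wrap(tf.random_normal([8, 32, 32, 16]))

args = bn_methods.BatchNormalizationArguments(
    learned_moments_update_rate=0.0003,
    variance_epsilon=0.001,
    scale_after_normalization=True)

# Equivalent to activations.batch_normalize(...) with the same keyword values.
# Fields left as None fall back to the assigned defaults, or are dropped so the
# method's own default values apply.
normalized = bn_methods.batch_normalize_with_arguments(activations, args)
```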
/prettytensor/pretty_tensor_sparse_methods.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Holds PrettyTensor methods related to sparse data types."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import tensorflow as tf
17 |
18 | from prettytensor import pretty_tensor_class as prettytensor
19 |
20 |
21 | @prettytensor.Register
22 | def to_dense_one_hot(labels, class_count):
23 |   """Converts a vector of class indices into a dense one-hot representation.
24 |
25 | Args:
26 | labels: The labels input.
27 | class_count: The number of classes as an int.
28 | Returns:
29 | One dense vector for each item in the batch.
30 | Raises:
31 | ValueError: If labels is not rank 1.
32 | TypeError: If class_count is not an integer or labels is not an integer
33 | Tensor.
34 | """
35 | if not isinstance(class_count, tf.compat.integral_types):
36 | raise TypeError('class_count must be an integer type.')
37 | if labels.dtype.base_dtype not in (tf.int32, tf.int64):
38 | raise TypeError('Labels must be an integer: %s' % labels.dtype)
39 | if labels.get_shape().ndims != 1:
40 | raise ValueError('Labels must be a rank 1 tensor: %s' % labels.get_shape())
41 |
42 | dtype = labels.dtype.base_dtype
43 | class_tensor = tf.convert_to_tensor(
44 | class_count, dtype=dtype, name='class_count')
45 |
46 | # Extract the batch from the shape so this is batch independent.
47 | batch = tf.gather(tf.shape(labels), 0)
48 | count = tf.expand_dims(tf.range(0, limit=batch), 1)
49 | labels = tf.expand_dims(labels, 1)
50 | batch = tf.gather(tf.shape(labels), 0)
51 |
52 | if dtype != tf.int32:
53 | count = tf.cast(count, dtype)
54 | batch = tf.cast(batch, dtype)
55 |
56 | result = tf.sparse_to_dense(
57 | tf.concat([count, labels], 1),
58 | tf.concat([tf.expand_dims(batch, 0), tf.expand_dims(class_tensor, 0)], 0),
59 | 1.0, 0.0)
60 | result.set_shape([labels.get_shape().dims[0], class_count])
61 | return result
62 |
--------------------------------------------------------------------------------
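A small illustrative example of `to_dense_one_hot` on a wrapped rank-1 label vector (the label values are chosen arbitrarily):

```python
import tensorflow as tf
import prettytensor as pt

labels = pt.wrap(tf.constant([0, 2, 1], dtype=tf.int32))

# Produces a [3, 3] float tensor:
# [[1, 0, 0],
#  [0, 0, 1],
#  [0, 1, 0]]
one_hot = labels.to_dense_one_hot(class_count=3)
```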
/prettytensor/pretty_tensor_testing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Test class for PrettyTensor."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import unittest
17 | import six
18 | import tensorflow as tf
19 |
20 | import prettytensor
21 |
22 |
23 | # Count of tests for unique root namespaces.
24 | _count = 0
25 |
26 |
27 | class PtTestCase(unittest.TestCase):
28 | """Contains shared setUp/tearDown and convenience methods.
29 |
30 | This adds the following attributes to self:
31 |
32 | self.bookkeeper
33 | self.sess
34 | """
35 |
36 | def __init__(self, *args):
37 | super(PtTestCase, self).__init__(*args)
38 | self.bookkeeper = None
39 | self.sess = None
40 | self._graph_with = None
41 |
42 | def RunTensor(self, layer, init=True):
43 | """Convenience method to run a tensor."""
44 | if init:
45 | self.sess.run(tf.global_variables_initializer())
46 | if isinstance(layer, (tf.Tensor, six.string_types)):
47 | return self.sess.run(layer)
48 | elif layer.is_sequence():
49 | return self.sess.run(layer.sequence)
50 | else:
51 | return self.sess.run(layer)
52 |
53 | def Wrap(self, tensor, tensor_shape=None):
54 | """Convenience for prettytensor.wrap(tensor, self.bookkeeper)."""
55 | return prettytensor.wrap(tensor, self.bookkeeper, tensor_shape)
56 |
57 | def setUp(self):
58 | unittest.TestCase.setUp(self)
59 | self.SetBookkeeper(prettytensor.bookkeeper_for_new_graph())
60 |
61 | def tearDown(self):
62 | self.TearDownBookkeeper()
63 | unittest.TestCase.tearDown(self)
64 |
65 | def SetBookkeeper(self, m):
66 | """Used to set custom bookkeeper code."""
67 | self.TearDownBookkeeper()
68 | self.bookkeeper = m
69 | self._graph_with = self.bookkeeper.g.as_default()
70 | self._graph_with.__enter__()
71 | global _count
72 | _count += 1
73 |
74 | self.sess = tf.Session('')
75 |
76 | def TearDownBookkeeper(self):
77 | if self._graph_with:
78 | self._graph_with.__exit__(None, None, None)
79 | self._graph_with = None
80 | if self.sess:
81 | self.sess.close()
82 | self.sess = None
83 | self.bookkeeper = None
84 |
--------------------------------------------------------------------------------
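A sketch of a test built on `PtTestCase`; the test itself is hypothetical, but the `Wrap`/`RunTensor` helpers and the `self.sess`/`self.bookkeeper` attributes come from the base class above.

```python
import unittest

import tensorflow as tf

from prettytensor import pretty_tensor_testing


class FlattenTest(pretty_tensor_testing.PtTestCase):

  def testFlatten(self):
    # Wrap a constant with the test's bookkeeper and chain a method onto it.
    layer = self.Wrap(tf.ones([2, 3, 5])).flatten()
    result = self.RunTensor(layer)
    self.assertEqual((2, 15), result.shape)


if __name__ == '__main__':
  unittest.main()
```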
/prettytensor/recurrent_networks_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Test class for the recurrent networks module."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import unittest
17 |
18 |
19 |
20 | import numpy
21 | from numpy import testing
22 | import six
23 | from six.moves import xrange # pylint: disable=redefined-builtin
24 | from six.moves import zip # pylint: disable=redefined-builtin
25 | import tensorflow as tf
26 |
27 | import prettytensor
28 | from prettytensor import pretty_tensor_testing
29 | from prettytensor import recurrent_networks
30 | from prettytensor import recurrent_networks_testing_utils as testing_utils
31 |
32 |
33 | TOLERANCE = 0.00001
34 |
35 |
36 | class RecurrentNetworksTest(pretty_tensor_testing.PtTestCase):
37 |
38 | def setUp(self):
39 | super(self.__class__, self).setUp()
40 | self.input_data = numpy.array(
41 | [
42 | [[1.], [2.], [3.], [4.]], [[5.], [6.], [7.], [8.]],
43 | [[5.], [6.], [7.], [8.]], [[9.], [10.], [11.], [12.]]
44 | ],
45 | dtype=numpy.float)
46 | self.sequence = testing_utils.SequenceInputMock(
47 | self.bookkeeper, self.input_data, [[0.], [0.], [0.], [1.]], 13)
48 |
49 | self.input, self.output = recurrent_networks.create_sequence_pretty_tensor(
50 | self.sequence)
51 |
52 | def testSquashAndCleave(self):
53 | squashed = self.input.squash_sequence()
54 | result = self.RunTensor(squashed)
55 |
56 | testing.assert_allclose(
57 | self.input_data.reshape(16, 1),
58 | result,
59 | rtol=TOLERANCE)
60 |
61 | result = self.RunTensor(squashed.cleave_sequence())
62 |
63 | for i in xrange(len(self.input_data)):
64 | testing.assert_allclose(
65 | self.input_data[i], result[i],
66 | rtol=TOLERANCE)
67 |
68 | def testSquashAndCleaveLength1(self):
69 | input_data = numpy.array(
70 | [[[1.], [2.], [3.], [4.]]], dtype=numpy.float)
71 | sequence = testing_utils.SequenceInputMock(
72 | self.bookkeeper, input_data, [[0.]], 13)
73 |
74 | inp, _ = recurrent_networks.create_sequence_pretty_tensor(sequence)
75 | squashed = inp.squash_sequence()
76 | result = self.RunTensor(squashed)
77 |
78 | testing.assert_allclose(
79 | input_data.reshape(4, 1), result,
80 | rtol=TOLERANCE)
81 |
82 | result = self.RunTensor(squashed.cleave_sequence())
83 |
84 | testing.assert_allclose(input_data[0], result[0], rtol=TOLERANCE)
85 | self.assertEquals(1, len(result))
86 |
87 | def testSequenceLstm(self):
88 | lstm = self.input.sequence_lstm(13)
89 | result = self.RunTensor(lstm)
90 |
91 | self.assertEquals([4, 13], lstm.shape)
92 | self.assertEquals(4, len(result))
93 | for i in xrange(4):
94 | self.assertSequenceEqual(lstm.shape, result[i].shape)
95 |
96 | def testSequenceGru(self):
97 | gru = self.input.sequence_gru(13)
98 | result = self.RunTensor(gru)
99 |
100 | self.assertEquals([4, 13], gru.shape)
101 | self.assertEquals(4, len(result))
102 | for i in xrange(4):
103 | self.assertSequenceEqual(gru.shape, result[i].shape)
104 |
105 | def performTestArbitraryBatchSizeRnn(self, cell_type):
106 | # Tests whether LSTM / GRU / Bookkeeper function when batch_size is not
107 | # specified at graph creation time (i.e., None).
108 | self.assertTrue(cell_type == 'lstm' or cell_type == 'gru')
109 | super(self.__class__, self).SetBookkeeper(
110 | prettytensor.bookkeeper_for_new_graph())
111 |
112 | # Build a graph. Specify None for the batch_size dimension.
113 | placeholder = tf.placeholder(tf.float32, [None, 1])
114 | input_pt = prettytensor.wrap_sequence([placeholder])
115 | if cell_type == 'lstm':
116 | output, _ = (input_pt
117 | .sequence_lstm(4)
118 | .squash_sequence()
119 | .softmax_classifier(2))
120 | elif cell_type == 'gru':
121 | output, _ = (input_pt
122 | .sequence_gru(4)
123 | .squash_sequence()
124 | .softmax_classifier(2))
125 |
126 | self.sess.run(tf.global_variables_initializer())
127 |
128 | # Use RecurrentRunner for state saving and managing feeds.
129 | recurrent_runner = recurrent_networks.RecurrentRunner(batch_size=1)
130 |
131 | # Run with a batch size of 1 for 10 steps, save output for reference.
132 | out_orig = []
133 | for t in xrange(10):
134 | outs = recurrent_runner.run(
135 | [output.name],
136 | {placeholder.name: numpy.array([[1.2]])},
137 | sess=self.sess)
138 | out = outs[0]
139 | self.assertEqual(1, len(out))
140 | self.assertEqual(2, len(out[0]))
141 | out_orig.append(out[0])
142 |
143 | # Test the reset functionality - after a reset, the results must be
144 | # identical to what we just got above.
145 | recurrent_runner.reset()
146 | for t in xrange(10):
147 | outs = recurrent_runner.run(
148 | [output.name],
149 | {placeholder.name: numpy.array([[1.2]])},
150 | sess=self.sess)
151 | out = outs[0]
152 | self.assertEqual(1, len(out))
153 | self.assertEqual(2, len(out[0]))
154 | testing.assert_allclose(out[0], out_orig[t])
155 |
156 | # Test whether the recurrent runner detects changes to the default graph.
157 |     # It should raise a ValueError because RecurrentRunner's state saver
158 | # information (collected during __init__) is not valid anymore.
159 | with tf.Graph().as_default():
160 | placeholder2 = tf.placeholder(tf.float32, [None, 1])
161 | input_pt2 = prettytensor.wrap_sequence([placeholder2])
162 | if cell_type == 'lstm':
163 | output2, _ = (input_pt2
164 | .sequence_lstm(4)
165 | .squash_sequence()
166 | .softmax_classifier(2))
167 | elif cell_type == 'gru':
168 | output2, _ = (input_pt2
169 | .sequence_gru(4)
170 | .squash_sequence()
171 | .softmax_classifier(2))
172 | self.assertRaises(ValueError,
173 | recurrent_runner.run,
174 | [output2.name], None, self.sess)
175 |
176 | # Run with a batch size of 3; first and third input are identical and must
177 | # yield identical output, and the same output as in the single batch run
178 | # above (up to floating point rounding errors).
179 | recurrent_runner = recurrent_networks.RecurrentRunner(batch_size=3)
180 | for t in xrange(10):
181 | outs = recurrent_runner.run(
182 | [output.name],
183 | {placeholder.name: numpy.array([[1.2], [3.4], [1.2]])},
184 | sess=self.sess)
185 | out = outs[0]
186 | self.assertEqual(3, len(out))
187 | self.assertEqual(2, len(out[0]))
188 | testing.assert_allclose(out[0], out[2], rtol=TOLERANCE)
189 | testing.assert_allclose(out[0], out_orig[t], rtol=TOLERANCE)
190 | # Sanity check to protect against trivial outputs that might hide errors.
191 | # Need to avoid checking after t = 2 since untrained GRUs have a
192 | # tendency to converge to large state values, leading to outputs like
193 | # 1.0, 0.0.
194 | if cell_type == 'gru' and t > 2:
195 | continue
196 | self.assertFalse((out[0] == out[1]).all())
197 |
198 | def testArbitraryBatchSizeLstm(self):
199 | self.performTestArbitraryBatchSizeRnn('lstm')
200 |
201 | def testArbitraryBatchSizeGru(self):
202 | self.performTestArbitraryBatchSizeRnn('gru')
203 |
204 | def testSequence(self):
205 | result = self.RunTensor(self.input[-1])
206 | testing.assert_allclose(
207 | self.input_data[-1], result,
208 | rtol=TOLERANCE)
209 |
210 | def testEmbeddingNameWorkaround(self):
211 |     """Just ensures this runs; it verifies that a naming workaround works."""
212 | input_data = self.Wrap(self.input_data.astype(numpy.int32).reshape([16, 1]))
213 | result = input_data.embedding_lookup(
214 | 13, [1], name='params')
215 | self.RunTensor(result)
216 |
217 | def testEmbeddingLookupRequiresRank2(self):
218 |     """Ensures embedding_lookup raises an error when the input is not rank 2."""
219 | input_data = self.Wrap(self.input_data.astype(numpy.int32))
220 | with self.assertRaises(ValueError):
221 | input_data.embedding_lookup(13, [1], name='params')
222 |
223 | def testLstmStateTuples(self):
224 | self.states = recurrent_networks.lstm_state_tuples(13, 'blah')
225 | self.RunTensor(
226 | self.input.sequence_lstm(13, name='blah')[-1])
227 |
228 | for state in self.states:
229 | self.assertTrue(
230 | state[0] in self.sequence.requested_tensors, '%s missing: %s' %
231 | (state[0], list(six.iterkeys(self.sequence.requested_tensors))))
232 | self.assertEqual(
233 | len(self.states), len(self.sequence.requested_tensors),
234 | 'Wrong number of Tensor states.')
235 |
236 | def testGruStateTuples(self):
237 | self.states = recurrent_networks.gru_state_tuples(13, 'blah')
238 | self.RunTensor(
239 | self.input.sequence_gru(13, name='blah')[-1])
240 |
241 | for state in self.states:
242 | self.assertTrue(
243 | state[0] in self.sequence.requested_tensors, '%s missing: %s' %
244 | (state[0], list(six.iterkeys(self.sequence.requested_tensors))))
245 | self.assertEqual(
246 | len(self.states), len(self.sequence.requested_tensors),
247 | 'Wrong number of Tensor states.')
248 |
249 | def testLength(self):
250 | tf.set_random_seed(4321)
251 | with tf.variable_scope('test') as vs:
252 | base_lstm = self.input.sequence_lstm(13)
253 | lengths = tf.placeholder(dtype=tf.int32, shape=[4])
254 |
255 | # Use the same parameters.
256 | with tf.variable_scope(vs, reuse=True):
257 | lstm_truncated = self.input.sequence_lstm(13, lengths=lengths)
258 |
259 | with tf.Session() as sess:
260 | tf.global_variables_initializer().run()
261 |
262 | result = sess.run(base_lstm.sequence + lstm_truncated.sequence,
263 | {lengths: [10, 4, 1, 1]})
264 | base_result = result[:len(base_lstm.sequence)]
265 | full_result = result[len(base_lstm.sequence):]
266 | truncated_result = sess.run(lstm_truncated.sequence,
267 | {lengths: [1, 2, 1, 1]})
268 |
269 | for i, (x, y) in enumerate(zip(base_result, truncated_result)):
270 | if i < 2:
271 | testing.assert_allclose(x, y, rtol=TOLERANCE)
272 | else:
273 | # After the specified output, we check to make sure the same values are
274 | # propagated forward.
275 | self.assertFalse(numpy.allclose(x, y, rtol=TOLERANCE))
276 | testing.assert_allclose(y, truncated_result[i - 1], rtol=TOLERANCE)
277 | for x, y in zip(base_result, full_result):
278 | # The later results tend to diverge. This is something that requires
279 | # investigation.
280 | testing.assert_allclose(x, y, atol=0.1)
281 |
282 |
283 | if __name__ == '__main__':
284 | unittest.main()
285 |
--------------------------------------------------------------------------------
/prettytensor/recurrent_networks_testing_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Utility class for testing recurrent networks."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import tensorflow as tf
17 |
18 |
19 | class SequenceInputMock(object):
20 | """A sequence input mock for testing recurrent networks."""
21 |
22 | def __init__(self, bookkeeper, input_list, label_list, node_depth):
23 | self.inputs = self.for_constants(input_list)
24 | self.targets = self.for_constants(label_list)
25 | self.node_depth = node_depth
26 | self.batch_size = input_list[0].shape[0]
27 | self.requested_tensors = {}
28 | self.bookkeeper = bookkeeper
29 | self.num_timesteps = len(input_list)
30 |
31 | def for_constants(self, ls):
32 | return [tf.constant(x, dtype=tf.float32) for x in ls]
33 |
34 | def state(self, state_name):
35 | """Returns, creating if necessary, a state variable with the given name."""
36 | if state_name not in self.requested_tensors:
37 | count = tf.get_variable('count_%s' % state_name,
38 | [],
39 | tf.int32,
40 | tf.zeros_initializer(),
41 | trainable=False)
42 | value = tf.get_variable(state_name, [self.batch_size, self.node_depth],
43 | tf.float32, tf.zeros_initializer())
44 | self.requested_tensors[state_name] = (count, value)
45 |
46 | return self.requested_tensors[state_name][1]
47 |
48 | def save_state(self, state_name, unused_value, name='SaveState'):
49 | return tf.assign_add(self.requested_tensors[state_name][0], 1, name=name)
50 |
51 |
--------------------------------------------------------------------------------
/prettytensor/replay_queue.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Creates a replayable queue."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import collections
17 | import contextlib
18 |
19 |
20 | import six
21 | import tensorflow as tf
22 |
23 | from prettytensor import pretty_tensor_class as prettytensor
24 |
25 |
26 | def _make_tuple(x):
27 | """TF has an obnoxious habit of being lenient with single vs tuple."""
28 | if isinstance(x, prettytensor.PrettyTensor):
29 | if x.is_sequence():
30 | return tuple(x.sequence)
31 | else:
32 | return (x.tensor,)
33 | elif isinstance(x, tuple):
34 | return x
35 | elif (isinstance(x, collections.Sequence) and
36 | not isinstance(x, six.string_types)):
37 | return tuple(x)
38 | else:
39 | return (x,)
40 |
41 |
42 | class ReplayableQueue(object):
43 | """A switchable queue between the original data and a replayed subset.
44 |
45 | This queue combines 2 concepts:
46 |
47 | 1. A replay queue that re-enqueues the data every time it is dequeued so that
48 |    multiple passes over the same data can be made (make sure to iterate
49 | `replay_steps` times).
50 | 2. The ability to switch data sources between the original source and the
51 | replayable queue. Embedding the switch makes it easy to construct a single
52 | graph on top of both input sources.
53 |
54 | Note: The queue requires manual filling by calling `refill`!
55 |
56 | The typical use case for replaying data is when you want to run an experiment
57 | on a subset of the training data and need to reuse the data multiple times.
58 |
59 | Here is an example that uses the replay queue to monitor loss on a dynamically
60 | selected validation set:
61 |
62 | ```
63 | replay = pt.train.ReplayableQueue(lambda: MY_Q.dequeue_many(BATCH_SIZE),
64 | REPLAY_SIZE)
65 | # Build a graph with replay.output
66 | my_train_op, my_loss = build_graph(replay.output)
67 |
68 | with tf.Session() as sess:
69 | # Capture some data
70 | replay.refill(sess)
71 |
72 | for epoch in xrange(EPOCHS):
73 | # Train for a while
74 | for _ in xrange(1000):
75 | sess.run(my_train_op)
76 | loss = 0
77 |     with replay.replay_scope(sess):
78 | for _ in xrange(replay.replay_steps):
79 | loss += sess.run(my_loss)
80 | loss /= replay.replay_steps
81 | print('Loss at epoch %d: %g' % (epoch, loss))
82 | ```
83 | """
84 |
85 | def __init__(self, input_fn, replay_size, batch_size=None):
86 | """Creates a ReplayableQueue that takes data from `input_fn`.
87 |
88 | See also: `pt.train.ReplayableQueue.build_from_queue`.
89 |
90 | Note: the shapes of the inputs must be fully defined.
91 |
92 |     Note: `input_fn` is a function instead of an input `Tensor`. If the input
93 |     came directly from a queue, the dependencies wouldn't be set up properly
94 |     and the data would always be dequeued. If you are providing data from a
95 |     queue, then pass in `lambda: q.dequeue_many(batch_size)`.
96 |
97 | Args:
98 | input_fn: A function of no arguments that returns the input as a tuple of
99 | `Tensors`.
100 | replay_size: The size of the replay queue.
101 | batch_size: If provided, use this as the batch size otherwise infer it.
102 |
103 | Raises:
104 | ValueError: if `replay_size` is not divisible by `batch_size` or if the
105 | shapes on the input are wrong.
106 | """
107 | inputs = _make_tuple(input_fn())
108 |
109 | for x in inputs:
110 | x.get_shape().assert_is_fully_defined()
111 | if batch_size is not None:
112 | x.get_shape()[0].assert_is_compatible_with(batch_size)
113 | else:
114 | batch_size = x.get_shape()[0].value
115 |
116 | dtypes = [x.dtype for x in inputs]
117 | shapes = [x.get_shape()[1:] if x.get_shape() else () for x in inputs]
118 |
119 | if replay_size % batch_size != 0:
120 |       raise ValueError('replay_size (%d) must be a multiple of batch size '
121 |                        '(%d)' % (replay_size, batch_size))
122 |     self._replay_size, self._batch_size = replay_size, batch_size
123 | # Setup the flag that controls replay.
124 | self._replay_var = tf.get_variable(
125 | 'replay',
126 | dtype=tf.bool,
127 | shape=[],
128 | initializer=tf.constant_initializer(False),
129 | trainable=False)
130 | self._set_replay_ph = tf.placeholder(dtype=tf.bool)
131 | self._set_replay = self._replay_var.assign(self._set_replay_ph)
132 |
133 | self._replay_queue = tf.FIFOQueue(replay_size, dtypes, shapes)
134 |
135 | # _fill_queue adds data to the queue and then returns whether it is full.
136 | with tf.control_dependencies([self._replay_queue.enqueue_many(inputs)]):
137 | self._fill_queue = tf.less(self._replay_queue.size(), replay_size)
138 |
139 | # Dequeue all the things!
140 | self._clear_queue = self._replay_queue.dequeue_many(
141 | self._replay_queue.size())
142 |
143 | def _pull_from_replay():
144 | data_tuple = _make_tuple(self._replay_queue.dequeue_many(batch_size))
145 | with tf.control_dependencies([self._replay_queue.enqueue_many(data_tuple)
146 | ]):
147 | return (tf.identity(data_tuple[0]),) + data_tuple[1:]
148 |
149 | def _pull_from_original():
150 | return _make_tuple(input_fn())
151 |
152 | self._output = prettytensor.wrap(
153 | tf.cond(self._replay_var, _pull_from_replay, _pull_from_original))
154 |
155 | @classmethod
156 | def build_from_queue(cls, input_queue, replay_size, batch_size):
157 | """Builds a `ReplayableQueue` that draws from a regular `input_queue`.
158 |
159 | Args:
160 | input_queue: The queue to draw from.
161 | replay_size: The size of the replay buffer.
162 | batch_size: The size of each batch.
163 |
164 | Returns:
165 | A ReplayableQueue.
166 | """
167 | return cls(
168 | lambda: input_queue.dequeue_many(batch_size),
169 | replay_size,
170 | batch_size=batch_size)
171 |
172 | @property
173 | def output(self):
174 | """Returns the output Tensor for this queue.
175 |
176 | The output is selected between the original data and the replay data
177 | depending on the replay value.
178 |
179 | Returns:
180 | The output tensors as a tuple.
181 | """
182 | return self._output
183 |
184 | @property
185 | def replay_steps(self):
186 | """Returns the number of steps to replay."""
187 | return self._replay_size // self._batch_size
188 |
189 | @property
190 | def replay_size(self):
191 | """Returns the total number of examples this queue holds."""
192 | return self._replay_size
193 |
194 | @contextlib.contextmanager
195 | def replay_scope(self, sess):
196 | """Enters a replay scope that unsets it at the end."""
197 | current_replay = self.replay(sess)
198 | try:
199 | self.set_replay(sess, True)
200 | yield
201 | finally:
202 | self.set_replay(sess, current_replay)
203 |
204 | def replay(self, sess):
205 | """Gets the current value of replay from the graph.
206 |
207 | Note: this runs the graph, but it is just a var read so it is fairly cheap.
208 |
209 | Args:
210 | sess: The session in which to run.
211 | Returns:
212 | The value of the replay variable.
213 | """
214 | return sess.run(self._replay_var)
215 |
216 | def set_replay(self, sess, replay):
217 | """Changes the current replay setting on the graph."""
218 | sess.run(self._set_replay, {self._set_replay_ph: replay})
219 |
220 | def refill(self, sess):
221 | """Clears the current queue and then refills it with new data."""
222 | sess.run(self._clear_queue)
223 | # Run until full.
224 | while sess.run(self._fill_queue):
225 | pass
226 |
--------------------------------------------------------------------------------
/prettytensor/replay_queue_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Tests for replay_queue."""
12 |
13 |
14 | import numpy as np
15 | import tensorflow as tf
16 |
17 | import prettytensor as pt
18 |
19 |
20 | class ReplayableQueueTest(tf.test.TestCase):
21 |
22 | def test_replay_queue_with_queue_input(self):
23 |     # Put a lot of data on the input queue.
24 | q = tf.FIFOQueue(1000, tf.float32, [])
25 | enqueue = q.enqueue_many(tf.to_float(tf.range(0, 1000)))
26 | replay = pt.train.ReplayableQueue.build_from_queue(q, 100, 10)
27 |
28 | with self.test_session() as sess:
29 | sess.run(tf.global_variables_initializer())
30 |
31 | sess.run(enqueue)
32 |
33 | d = sess.run(replay.output)
34 | self.assertAllClose(np.arange(10).astype(np.float), d)
35 |
36 | # Now fill the queue
37 | replay.refill(sess)
38 |
39 | self.assertEqual(100, replay._replay_queue.size().eval())
40 |
41 | d = sess.run(replay.output)
42 |
43 | # Replay is still false, but the queue has advanced
44 | self.assertAllClose(np.arange(110, 120).astype(np.float), d)
45 |
46 | # Now set replay.
47 | replay.set_replay(sess, True)
48 | for i in range(10):
49 | d = sess.run(replay.output)
50 | range_start = 10 + i * 10
51 | self.assertAllClose(
52 | np.arange(range_start, range_start + 10).astype(np.float), d)
53 |
54 | # And again
55 | for i in range(10):
56 | d = sess.run(replay.output)
57 | range_start = 10 + i * 10
58 | self.assertAllClose(
59 | np.arange(range_start, range_start + 10).astype(np.float), d)
60 |
61 | replay.set_replay(sess, False)
62 |
63 |       # Back to the original data stream.
64 | d = sess.run(replay.output)
65 | self.assertAllClose(np.arange(120, 130).astype(np.float), d)
66 |
67 | # And refill the queue
68 | replay.refill(sess)
69 | replay.set_replay(sess, True)
70 | d = sess.run(replay.output)
71 | self.assertAllClose(np.arange(130, 140).astype(np.float), d)
72 |
73 | replay.set_replay(sess, False)
74 | d = sess.run(replay.output)
75 | self.assertAllClose(np.arange(230, 240).astype(np.float), d)
76 |
77 | if __name__ == '__main__':
78 | tf.test.main()
79 |
--------------------------------------------------------------------------------
/prettytensor/scopes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Contains methods related to making templates."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import contextlib
17 | import functools
18 | import traceback
19 |
20 | from six.moves import zip # pylint: disable=redefined-builtin
21 | import tensorflow as tf
22 |
23 | from tensorflow.python.ops import variable_scope
24 |
25 |
26 | @contextlib.contextmanager
27 | def var_and_name_scope(names):
28 | """Creates a variable scope and a name scope.
29 |
30 | If a variable_scope is provided, this will reenter that variable scope.
31 | However, if none is provided then the variable scope will match the generated
32 | part of the name scope.
33 |
34 | Args:
35 | names: A tuple of name_scope, variable_scope or None.
36 | Yields:
37 | The result of name_scope and variable_scope as a tuple.
38 | """
39 | # pylint: disable=protected-access
40 | if not names:
41 | yield None, None
42 | else:
43 | name, var_scope = names
44 | with tf.name_scope(name) as scope:
45 | # TODO(eiderman): This is a workaround until the variable_scope updates land
46 | # in a TF release.
47 | old_vs = tf.get_variable_scope()
48 | if var_scope is None:
49 | count = len(name.split('/'))
50 | scoped_name = '/'.join(scope.split('/')[-count - 1:-1])
51 | full_name = (old_vs.name + '/' + scoped_name).lstrip('/')
52 | else:
53 | full_name = var_scope.name
54 |
55 | vs_key = tf.get_collection_ref(variable_scope._VARSCOPE_KEY)
56 | try:
57 | # TODO(eiderman): Remove this hack or fix the full file.
58 | try:
59 | vs_key[0] = tf.VariableScope(
60 | old_vs.reuse,
61 | name=full_name,
62 | initializer=old_vs.initializer,
63 | regularizer=old_vs.regularizer,
64 | caching_device=old_vs.caching_device)
65 | except AttributeError:
66 | vs_key[0] = variable_scope._VariableScope(
67 | old_vs.reuse,
68 | name=full_name,
69 | initializer=old_vs.initializer)
70 |
71 | vs_key[0].name_scope = scope
72 | yield scope, vs_key[0]
73 | finally:
74 | vs_key[0] = old_vs
75 |
76 |
77 | def get_current_name_scope():
78 | """Gets the current name scope."""
79 | # pylint: disable=protected-access
80 | g = tf.get_default_graph()
81 | # TODO(eiderman): Remove this hack once TF update is released.
82 | if isinstance(g._name_stack, tuple):
83 | return g._name_stack[0] + '/'
84 | else:
85 | return g._name_stack + '/'
86 |
87 |
88 | def _get_last_part_of_name_scope(scope):
89 | splits = scope.split('/')
90 | return splits[-2]
91 |
92 |
93 | def make_template(name, func, *args, **kwargs):
94 | """Given an arbitrary function, wrap it so that it does parameter sharing."""
95 | if args or kwargs:
96 | func = functools.partial(func, *args, **kwargs)
97 | return Template(name, func)
98 |
99 |
100 | def skip_common_stack_elements(stacktrace, base_case):
101 | """Skips items that the target stacktrace shares with the base stacktrace."""
102 | for i, (trace, base) in enumerate(zip(stacktrace, base_case)):
103 | if trace != base:
104 | return stacktrace[i:]
105 | return stacktrace[-1:]
106 |
107 |
108 | class Template(object):
109 | """A Template captures variable and namescopes to help variable sharing."""
110 |
111 | def __init__(self, name, func):
112 | """Creates a template for the given function.
113 |
114 | Args:
115 | name: The variable_scope to use, if None the current scope is captured.
116 | func: The function to apply each time.
117 | """
118 | self._func = func
119 | if name:
120 | self._var_scope = None
121 | self._name = name
122 | else:
123 | self._var_scope = tf.get_variable_scope()
124 | self._name = None
125 | self._reuse = None
126 | self._stacktrace = traceback.format_stack()[:-3]
127 |
128 | def _call_func(self, args, kwargs):
129 | try:
130 | self._reuse = True
131 | return self._func(*args, **kwargs)
132 | except Exception as exc:
133 | # Reraise the exception, but append the original definition to the
134 | # trace.
135 | args = exc.args
136 | if not args:
137 | arg0 = ''
138 | else:
139 | arg0 = args[0]
140 | trace = ''.join(skip_common_stack_elements(self._stacktrace,
141 | traceback.format_stack()))
142 | arg0 = '%s\n\noriginally defined at:\n%s' % (arg0, trace)
143 | new_args = [arg0]
144 | new_args.extend(args[1:])
145 | exc.args = tuple(new_args)
146 | raise
147 |
148 | def __call__(self, *args, **kwargs):
149 | if self._name:
150 | with var_and_name_scope((self._name, self._var_scope)) as (_, vs):
151 | if self._reuse:
152 | vs.reuse_variables()
153 | else:
154 | self._var_scope = vs
155 | return self._call_func(args, kwargs)
156 | else:
157 | with tf.variable_scope(self._var_scope, reuse=self._reuse) as vs:
158 | return self._call_func(args, kwargs)
159 |
--------------------------------------------------------------------------------
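A short sketch of variable sharing through `make_template`; the `linear` function and shapes are made up, but the sharing behavior matches the tests that follow.

```python
import tensorflow as tf

from prettytensor import scopes


def linear(x):
  w = tf.get_variable('w', [3, 4], initializer=tf.zeros_initializer())
  return tf.matmul(x, w)


# Both calls run in the 's1' variable scope and share the single variable
# 's1/w'; the second call implicitly reuses the captured scope.
tmpl = scopes.make_template('s1', linear)
y1 = tmpl(tf.ones([2, 3]))
y2 = tmpl(tf.ones([2, 3]))
```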
/prettytensor/scopes_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Tests for scopes."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import traceback
17 | import unittest
18 |
19 |
20 | import tensorflow as tf
21 |
22 | from prettytensor import scopes
23 |
24 |
25 | def var_scoped_function():
26 | return tf.get_variable('dummy', shape=[1], initializer=tf.zeros_initializer())
27 |
28 |
29 | class ScopesTest(unittest.TestCase):
30 |
31 | def test_skip_stack_frames(self):
32 | first = traceback.format_stack()
33 | second = traceback.format_stack()
34 | result = scopes.skip_common_stack_elements(first, second)
35 | self.assertEqual(1, len(result))
36 | self.assertNotEqual(len(first), len(result))
37 |
38 | def test_get_current_name_scope(self):
39 | self.assertEquals('/', scopes.get_current_name_scope())
40 | self.assertEquals('', scopes._get_last_part_of_name_scope('/'))
41 | with tf.name_scope('one') as scope:
42 | self.assertEquals(scope, scopes.get_current_name_scope())
43 | self.assertEquals('one', scopes._get_last_part_of_name_scope(scope))
44 |
45 | with tf.name_scope('one') as scope:
46 | self.assertEquals(scope, scopes.get_current_name_scope())
47 | self.assertEquals('one_1', scopes._get_last_part_of_name_scope(scope))
48 | with tf.name_scope('two') as nested_scope:
49 | self.assertEquals(nested_scope, scopes.get_current_name_scope())
50 | self.assertEquals('two',
51 | scopes._get_last_part_of_name_scope(nested_scope))
52 |
53 | def test_template_without_name(self):
54 | tmpl1 = scopes.Template(None, var_scoped_function)
55 |
56 | v1 = tmpl1()
57 | v2 = tmpl1()
58 | self.assertEqual(v1, v2)
59 | self.assertEqual('dummy:0', v1.name)
60 |
61 | def test_template_with_name(self):
62 | tmpl1 = scopes.Template('s1', var_scoped_function)
63 | tmpl2 = scopes.Template('s1', var_scoped_function)
64 |
65 | v1 = tmpl1()
66 | v2 = tmpl1()
67 | v3 = tmpl2()
68 | self.assertEqual(v1, v2)
69 | self.assertNotEqual(v1, v3)
70 | self.assertEqual('s1/dummy:0', v1.name)
71 | self.assertEqual('s1_2/dummy:0', v3.name)
72 |
73 | def test_var_and_name_scope(self):
74 | with tf.Graph().as_default():
75 | with scopes.var_and_name_scope(('one', None)) as (ns, vs):
76 | self.assertEqual('one/', ns)
77 | self.assertEqual('one', vs.name)
78 | with scopes.var_and_name_scope(('one', None)) as (ns, vs):
79 | self.assertEqual('one_1/', ns)
80 | self.assertEqual('one_1', vs.name)
81 | with scopes.var_and_name_scope(('one/two', None)) as (ns, vs):
82 | self.assertEqual('one/two/', ns)
83 | self.assertEqual('one/two', vs.name)
84 |
85 |
86 | if __name__ == '__main__':
87 | unittest.main()
88 |
--------------------------------------------------------------------------------
/prettytensor/sequence_with_deltas.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Provides a class that just implements a sequence with a delta count."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import collections
17 |
18 |
19 | class SequenceWithDeltas(collections.MutableSequence):
20 | """Provides a sequence with a count of modifications."""
21 |
22 | def __init__(self, other_seq=None):
23 | if other_seq is None:
24 | self._seq = []
25 | else:
26 | self._seq = list(other_seq)
27 | self._mods = len(self._seq)
28 | self._mark = 0
29 |
30 | def __getitem__(self, key):
31 | return self._seq[key]
32 |
33 | def __setitem__(self, key, value):
34 | self._mods += 1
35 | self._seq[key] = value
36 |
37 | def __delitem__(self, key):
38 | self._mods += 1
39 | del self._seq[key]
40 |
41 | def __len__(self):
42 | return len(self._seq)
43 |
44 | def insert(self, key, value):
45 | self._mods += 1
46 | self._seq.insert(key, value)
47 |
48 | @property
49 | def deltas(self):
50 | return self._mods
51 |
52 | def mark(self):
53 | """Marks this sequence at the current number of deltas."""
54 | self._mark = self._mods
55 |
56 | def has_changed(self):
57 |     """Returns whether the sequence has changed since the last mark."""
58 |     return self._mark != self._mods
59 |
--------------------------------------------------------------------------------
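A tiny illustrative use of the mark/has_changed bookkeeping (values are arbitrary; `append` is inherited from `MutableSequence` and routes through `insert`):

```python
from prettytensor.sequence_with_deltas import SequenceWithDeltas

seq = SequenceWithDeltas([1, 2, 3])
seq.mark()                     # Snapshot the current delta count.
assert not seq.has_changed()

seq.append(4)                  # Counts as one more delta.
assert seq.has_changed()
```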
/prettytensor/templated_pretty_tensor_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Tests for templating in PrettyTensor."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import unittest
17 |
18 |
19 |
20 | import numpy
21 | from numpy import testing
22 | import tensorflow as tf
23 |
24 | import prettytensor
25 | from prettytensor import pretty_tensor_class
26 | from prettytensor import pretty_tensor_testing
27 |
28 | KEY = 'random_key'
29 | TOLERANCE = 0.000001
30 |
31 |
32 | @prettytensor.Register(assign_defaults='value')
33 | def ValidateMethod(input_tensor, test_class, value):
34 | test_class.assertEqual(KEY, value)
35 | return input_tensor
36 |
37 |
38 | class TemplatedPrettyTensorTest(pretty_tensor_testing.PtTestCase):
39 |
40 | def setUp(self):
41 | super(self.__class__, self).setUp()
42 | # Input is 2x3x5, which isn't a natural size for any op.
43 | self.input_data = numpy.array(
44 | [[[1, 2, 3, 4, 5],
45 | [6, 7, 8, 9, 10],
46 | [10, 12, 13, 14, 15],], [[-1, 2, -3, 4, -5], [6, -7, 8, -9, 10],
47 | [-10, 12, -13, 14, -15]]],
48 | dtype=numpy.float)
49 | self.input = tf.constant(self.input_data, dtype=tf.float32)
50 |
51 | def Template(self, key):
52 | return prettytensor.template(key, self.bookkeeper)
53 |
54 | def testSimpleTemplate(self):
55 | template = self.Template(KEY)
56 |
57 | x = template.construct(random_key=self.input)
58 | out = self.RunTensor(x)
59 | testing.assert_allclose(self.input_data, out, rtol=TOLERANCE)
60 |
61 | def testSingleMethod(self):
62 | template = self.Template(KEY).flatten()
63 |
64 | x = template.construct(random_key=self.input)
65 | out = self.RunTensor(x)
66 | testing.assert_allclose(
67 | self.input_data.reshape([2, 15]),
68 | out,
69 | rtol=TOLERANCE)
70 |
71 | def testSequential(self):
72 | seq = self.Template(KEY).sequential()
73 | seq.flatten()
74 | seq.fully_connected(100)
75 | out = self.RunTensor(seq.as_layer().construct(random_key=self.input))
76 | self.assertSequenceEqual([2, 100], out.shape)
77 |
78 | def testAttach(self):
79 | input_pt = self.Wrap(self.input)
80 | template = self.Template('input').flatten().fully_connected(100)
81 | out = self.RunTensor(input_pt.attach_template(template, 'input'))
82 | self.assertSequenceEqual([2, 100], out.shape)
83 |
84 | def testUnboundVariableForParameter(self):
85 | input_pt = self.Wrap(self.input)
86 | template = input_pt.flatten().fully_connected(prettytensor.UnboundVariable(
87 | 'width'))
88 | self.assertTrue(isinstance(template, pretty_tensor_class._DeferredLayer))
89 | out = self.RunTensor(template.construct(width=200))
90 | self.assertSequenceEqual([2, 200], out.shape)
91 |
92 | def testMissingUnboundVariable(self):
93 | input_pt = self.Wrap(self.input)
94 | template = input_pt.flatten().fully_connected(prettytensor.UnboundVariable(
95 | 'width'))
96 | with self.assertRaises(ValueError):
97 | template.construct()
98 |
99 | def testUnboundVariableReused(self):
100 | """The same unbound_var can be used multiple times in a graph."""
101 | input_pt = self.Wrap(self.input)
102 | unbound_var = prettytensor.UnboundVariable('width')
103 | template = (input_pt.flatten().fully_connected(unbound_var)
104 | .fully_connected(unbound_var))
105 | out = self.RunTensor(template.construct(width=200))
106 | self.assertSequenceEqual([2, 200], out.shape)
107 |
108 | def testAttachToTemplate(self):
109 | input_pt = self.Wrap(self.input)
110 | template1 = self.Template('input').flatten()
111 | template2 = self.Template('input').fully_connected(100)
112 |
113 | joined = template1.attach_template(template2, 'input')
114 | out = self.RunTensor(input_pt.attach_template(joined, 'input'))
115 | self.assertSequenceEqual([2, 100], out.shape)
116 |
117 | def testUnboundVariableAsDefault(self):
118 |     """An unbound_var can also be supplied as a default value."""
119 | input_pt = self.Wrap(self.input)
120 | with prettytensor.defaults_scope(
121 | value=prettytensor.UnboundVariable('key')):
122 | x = input_pt.ValidateMethod(self)
123 | self.assertTrue(isinstance(x, pretty_tensor_class._DeferredLayer))
124 | x.construct(key=KEY)
125 |
126 | def testConflictingUnboundVariables(self):
127 | """Two unbound_vars with the same name are considered conflicting."""
128 | input_pt = self.Wrap(self.input)
129 | with self.assertRaises(ValueError):
130 | (input_pt.flatten()
131 | .fully_connected(prettytensor.UnboundVariable('width'))
132 | .fully_connected(prettytensor.UnboundVariable('width')))
133 |
134 | def testMultipleUnboundVariables(self):
135 | input_pt = self.Wrap(self.input)
136 | template = (input_pt.flatten()
137 | .fully_connected(prettytensor.UnboundVariable('width'))
138 | .fully_connected(prettytensor.UnboundVariable('width2')))
139 | out = self.RunTensor(template.construct(width=200, width2=100))
140 | self.assertSequenceEqual([2, 100], out.shape)
141 |
142 | def testExtraValues(self):
143 | input_pt = self.Wrap(self.input)
144 | template = (input_pt.flatten()
145 | .fully_connected(prettytensor.UnboundVariable('width')))
146 | with self.assertRaises(ValueError):
147 | template.construct(width=200, width2=100)
148 |
149 | def testIncompatibleUnboundVariableValues(self):
150 | """Ensures that an error is thrown if a var is given incompatible values.
151 |
152 | Since the primary use case of templates is parameter sharing, it is
153 | important that substitutions don't conflict.
154 | """
155 | input_pt = self.Wrap(self.input)
156 | full = input_pt.flatten().fully_connected(prettytensor.UnboundVariable(
157 | 'width'))
158 | full.construct(width=100)
159 | with self.assertRaises(ValueError):
160 | full.construct(width=200)
161 |
162 | def BuildLargishGraph(self, input_pt):
163 | seq = input_pt.sequential()
164 | seq.reshape('___1')
165 | seq.conv2d(1, 10)
166 | with seq.subdivide(2) as [a, b]:
167 | a.with_name('a').conv2d(1, 5)
168 | b.with_name('b').conv2d(1, 15)
169 | seq.with_name('wow')
170 | seq.flatten()
171 | seq.fully_connected(100, name='a_funny_name')
172 | return seq.as_layer()
173 |
174 | def testGraphMatchesImmediate(self):
175 | """Ensures that the vars line up between the two modes."""
176 | with tf.Graph().as_default():
177 | input_pt = prettytensor.wrap(
178 | tf.constant(self.input_data, dtype=tf.float32))
179 | self.BuildLargishGraph(input_pt)
180 | normal_names = sorted([v.name for v in tf.global_variables()])
181 |
182 | with tf.Graph().as_default():
183 | template = prettytensor.template('input')
184 | self.BuildLargishGraph(template).construct(input=prettytensor.wrap(
185 | tf.constant(self.input_data, dtype=tf.float32)))
186 | template_names = sorted([v.name for v in tf.global_variables()])
187 |
188 | self.assertSequenceEqual(normal_names, template_names)
189 |
190 | def testVariablesAreShared(self):
191 | """Ensures that adding the graph twice shares variables."""
192 | input_pt = self.Wrap(self.input)
193 | template = self.Template('input').flatten().fully_connected(10)
194 |
195 | l1 = template.construct(input=input_pt)
196 | l2 = template.construct(input=input_pt)
197 | self.assertNotEqual(l1.tensor, l2.tensor)
198 |
199 | v1 = self.RunTensor(l1, init=True)
200 | v2 = self.RunTensor(l2, init=False)
201 | testing.assert_allclose(v1, v2, rtol=TOLERANCE)
202 |
203 | def testBind(self):
204 | input_pt = self.Wrap(self.input)
205 | template = self.Template('input').flatten().fully_connected(10)
206 |
207 | l1 = template.bind(input=input_pt).construct()
208 | l2 = template.construct(input=input_pt)
209 | v1 = self.RunTensor(l1, init=True)
210 | v2 = self.RunTensor(l2, init=False)
211 | testing.assert_allclose(v1, v2, rtol=TOLERANCE)
212 |
213 | def testBindTuple(self):
214 | labels = numpy.array([[0., 1.], [1., 0.]], dtype=numpy.float32)
215 | template = self.Template('input').flatten().softmax_classifier(2, labels)
216 | bound = template.bind(input=self.input)
217 |
218 | tuple1 = bound.construct()
219 | tuple2 = template.construct(input=self.input)
220 |
221 | self.assertNotEqual(tuple1.softmax.tensor, tuple2.softmax.tensor)
222 | softmax1 = self.RunTensor(tuple1.softmax, init=True)
223 | loss1 = self.RunTensor(tuple1.loss, init=False)
224 | softmax2 = self.RunTensor(tuple2.softmax, init=False)
225 | loss2 = self.RunTensor(tuple2.loss, init=False)
226 | testing.assert_allclose(softmax1, softmax2, rtol=TOLERANCE)
227 | testing.assert_allclose(loss1, loss2, rtol=TOLERANCE)
228 |
229 | def testConstructAllWithConflictingValues(self):
230 | labels = numpy.array([[0., 1.], [1., 0.]], dtype=numpy.float32)
231 | template = self.Template('input').flatten().softmax_classifier(2, labels)
232 |
233 | softmax = template.softmax.bind(input=self.input)
234 | loss = template.loss.bind(input=labels)
235 | with self.assertRaises(ValueError):
236 | prettytensor.construct_all([softmax, loss])
237 |
238 |
239 | if __name__ == '__main__':
240 | unittest.main()
241 |
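The template, bind, and construct calls exercised by these tests follow one small pattern. A minimal sketch of that pattern (the shapes and sizes here are made up for illustration):

```python
import numpy as np
import tensorflow as tf
import prettytensor as pt

data = pt.wrap(tf.constant(np.random.rand(4, 8), dtype=tf.float32))

# A deferred graph: 'input' is filled in at construct time.
template = pt.template('input').flatten().fully_connected(10)
layer_a = template.construct(input=data)         # First construction.
layer_b = template.bind(input=data).construct()  # Shares the same variables.

# UnboundVariable defers a parameter value as well. Once a template has been
# constructed with one value, constructing it again with a conflicting value
# raises a ValueError.
sized = pt.template('input').flatten().fully_connected(
    pt.UnboundVariable('width'))
layer_c = sized.construct(input=data, width=20)
```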
--------------------------------------------------------------------------------
/prettytensor/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Imports some utilities for training models."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | # pylint: disable=unused-import, wildcard-import
17 | from prettytensor.input_helpers import batch
18 | from prettytensor.input_helpers import feed_numpy
19 | from prettytensor.local_trainer import create_checkpointing_runner
20 | from prettytensor.local_trainer import create_follower_runner
21 | from prettytensor.local_trainer import Runner
22 | from prettytensor.recurrent_networks import RecurrentRunner
23 | from prettytensor.replay_queue import ReplayableQueue
24 |
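These re-exports cover the training loop used throughout the tutorials below. A compressed sketch of how they fit together (the model and data here are stand-ins; the real examples are in the tutorial files that follow):

```python
import numpy as np
import tensorflow as tf
import prettytensor as pt

BATCH_SIZE = 32
# Stand-in data and placeholders; the tutorials feed real datasets this way.
train_x = np.random.rand(3200, 784).astype(np.float32)
train_y = np.eye(10)[np.random.randint(10, size=3200)].astype(np.float32)
x = tf.placeholder(tf.float32, [BATCH_SIZE, 784])
y = tf.placeholder(tf.float32, [BATCH_SIZE, 10])

result = pt.wrap(x).fully_connected(100).softmax_classifier(10, y)
train_op = pt.apply_optimizer(tf.train.GradientDescentOptimizer(0.01),
                              losses=[result.loss])

# Runner drives the session; feed_numpy pairs numpy batches with feed_vars.
runner = pt.train.Runner(save_path=None)  # Set save_path to checkpoint.
with tf.Session():
  runner.train_model(
      train_op,
      result.loss,
      len(train_x) // BATCH_SIZE,
      feed_vars=(x, y),
      feed_data=pt.train.feed_numpy(BATCH_SIZE, train_x, train_y))
```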
--------------------------------------------------------------------------------
/prettytensor/tutorial/README.md:
--------------------------------------------------------------------------------
1 | # Tutorial
2 |
3 | These are simple models that highlight some of Pretty Tensor's features and
4 | that hopefully will be useful branching points for your own experiments.
5 |
6 | While each tutorial is intended to be standalone, the recommended order is:
7 |
8 | 1. `mnist.py`
9 | 2. `baby_names.py`
10 | 3. `shakespeare.py`
11 |
12 | All of the tutorials show you how to build a model and run training/evaluation.
13 |
14 | ## MNIST
15 |
16 | `mnist.py` shows a simple image classification model using the
17 | [MNIST dataset](http://yann.lecun.com/exdb/mnist/).
18 |
19 | ## Baby Names
20 |
21 | `baby_names.py` is a recurrent network that uses data from the Social Security
22 | Administration's Office of Retirement and Disability Policy on children born in
23 | the US over the past century. The model uses a
24 | [Long Short Term Memory](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
25 | (LSTM) to predict the boy/girl ratio for each name.
26 |
27 | ## Shakespeare
28 |
29 | `shakespeare.py` uses stacked LSTMs to read each character of Shakespeare and
30 | predict the next character. It can be used to sample your own Shakespearian
31 | comedy.
32 |
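The tutorials are also installed as console scripts by `setup.py`
(`prettytensor_model_mnist`, `prettytensor_model_baby_names`,
`prettytensor_model_shakespeare`). As a small, hedged example (assuming the pip
package and TensorFlow are installed), a tutorial can be launched
programmatically as well, since each module exposes a `main` that `tf.app.run`
can invoke:

```python
import tensorflow as tf
from prettytensor.tutorial import mnist

# Flags such as --model and --save_path are defined inside the tutorial module.
if __name__ == '__main__':
  tf.app.run(main=mnist.main)
```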
--------------------------------------------------------------------------------
/prettytensor/tutorial/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 |
--------------------------------------------------------------------------------
/prettytensor/tutorial/baby_names.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Tutorial to predict the sex of a baby from the name.
12 |
13 | This model takes a dataset of baby names to ratio of sexes for that name and
14 | then trains an LSTM to predict the ratio given the characters of the name.
15 |
16 | The inputs are the characters of the name as ASCII codes (0-127), and the
17 | network is unrolled for 15 steps, the length of the longest name in the
18 | corpus. The characters are fed through a recurrent network and then to a 2 way
19 | classifier that predicts the boy/girl ratio.
20 |
21 | This demonstrates how to train a classifier on the last output of an LSTM,
22 | which can fall at any timestep because names vary in length, by setting a
23 | weight on each example. It also demonstrates how to efficiently reshape the
24 | network for the classifier and how to use dropout in both a training and eval
25 | graph.
26 | """
27 | from __future__ import absolute_import
28 | from __future__ import division
29 | from __future__ import print_function
30 |
31 |
32 | import numpy
33 | from six.moves import xrange # pylint: disable=redefined-builtin
34 | import tensorflow as tf
35 |
36 | import prettytensor as pt
37 | from prettytensor.tutorial import data_utils
38 |
39 | tf.app.flags.DEFINE_string(
40 | 'save_path', None, 'Where to save the model checkpoints on local disk. '
41 | 'Checkpoints are in LevelDb.')
42 | FLAGS = tf.app.flags.FLAGS
43 |
44 |
45 | BATCH_SIZE = 32
46 | CHARS = 128
47 | TIMESTEPS = 15
48 | SEXES = 2
49 |
50 | EMBEDDING_SIZE = 16
51 |
52 |
53 | def create_model(text_in,
54 | labels,
55 | timesteps,
56 | per_example_weights,
57 | phase=pt.Phase.train):
58 | """Creates a model for running baby names."""
59 | with pt.defaults_scope(phase=phase, l2loss=0.00001):
60 | # The embedding lookup must be placed on a cpu.
61 | with tf.device('/cpu:0'):
62 | embedded = text_in.embedding_lookup(CHARS, [EMBEDDING_SIZE])
63 |     # We need to cleave the sequence because the sequence LSTM expects each
64 | # timestep to be in its own Tensor.
65 | lstm = (embedded.cleave_sequence(timesteps).sequence_lstm(CHARS))
66 |
67 | # The classifier is much more efficient if it runs across the entire
68 | # batch at once, so we want to squash (i.e. uncleave).
69 | #
70 |     # The hidden layer size is set to 32 because it seems to work well.
71 | return (lstm.squash_sequence().fully_connected(32,
72 | activation_fn=tf.nn.relu)
73 | .dropout(0.7)
74 | .softmax_classifier(SEXES,
75 | labels,
76 | per_example_weights=per_example_weights))
77 |
78 |
79 | def main(_=None):
80 | print('Starting Baby Names')
81 |
82 | # Since we are feeding our data as numpy arrays, we need to create
83 | # placeholders in the graph.
84 | # These must then be fed using the feed dict.
85 | input_placeholder = tf.placeholder(tf.int32, [BATCH_SIZE, TIMESTEPS])
86 | output_placeholder = tf.placeholder(tf.float32, [BATCH_SIZE, SEXES])
87 |
88 | inp = data_utils.reshape_data(input_placeholder)
89 |
90 | # Create a label for each timestep.
91 | labels = data_utils.reshape_data(
92 | tf.reshape(
93 | tf.tile(output_placeholder, [1, TIMESTEPS]), [BATCH_SIZE, TIMESTEPS,
94 | SEXES]),
95 | per_example_length=2)
96 |
97 |   # We also need to set per-example weights so that the loss ignores the
98 |   # predictions made at intermediate timesteps.
99 | length_placeholder = tf.placeholder(tf.int32, [BATCH_SIZE, 1])
100 |
101 |   # We need a dense multiplier for the per-example weights. The only timestep
102 |   # with a non-zero loss is the first EOS after the last character of the
103 |   # name; the characters of the name and the trailing EOS characters are given
104 |   # a weight of 0.0, so in the end exactly one timestep in each example has a
105 |   # weight of 1.0.
106 |   # sparse_to_dense scatters values at the indices given by the first Tensor.
107 |   # Because we are filling in a 2D array, the indices need to be 2 dimensional.
108 |   # Since we want to assign one value per row, the first dimension can just
109 |   # be the sequence 0..BATCH_SIZE-1.
110 | t = tf.concat(
111 | [
112 | tf.constant(
113 | numpy.arange(BATCH_SIZE).reshape((BATCH_SIZE, 1)),
114 | dtype=tf.int32), length_placeholder
115 | ],
116 | 1)
117 |
118 | # Squeeze removes dimensions that are equal to 1. per_example_weights must
119 | # end up as 1 dimensional.
120 | per_example_weights = data_utils.reshape_data(tf.sparse_to_dense(
121 | t, [BATCH_SIZE, TIMESTEPS], 1.0, default_value=0.0)).squeeze()
122 |
123 | # We need 2 copies of the graph that share variables. The first copy runs
124 | # training and will do dropout if specified and the second will not include
125 | # dropout. Dropout is controlled by the phase argument, which sets the mode
126 | # consistently throughout a graph.
127 | with tf.variable_scope('baby_names'):
128 | result = create_model(inp, labels, TIMESTEPS, per_example_weights)
129 |
130 | # Call variable scope by name so we also create a name scope. This ensures
131 | # that we share variables and our names are properly organized.
132 | with tf.variable_scope('baby_names', reuse=True):
133 | # Some ops have different behaviors in test vs train and these take a phase
134 | # argument.
135 | test_result = create_model(inp,
136 | labels,
137 | TIMESTEPS,
138 | per_example_weights,
139 | phase=pt.Phase.test)
140 |
141 | # For tracking accuracy in evaluation, we need to add an evaluation node.
142 | # We only run this when testing, so we need to specify that in the phase.
143 | # Some ops have different behaviors in test vs train and these take a phase
144 | # argument.
145 | accuracy = test_result.softmax.evaluate_classifier(
146 | labels,
147 | phase=pt.Phase.test,
148 | per_example_weights=per_example_weights)
149 |
150 | # We can also compute a batch accuracy to monitor progress.
151 | batch_accuracy = result.softmax.evaluate_classifier(
152 | labels,
153 | phase=pt.Phase.train,
154 | per_example_weights=per_example_weights)
155 |
156 | # Grab the inputs, outputs and lengths as numpy arrays.
157 | # Lengths could have been calculated from names, but it was easier to
158 | # calculate inside the utility function.
159 | names, sex, lengths = data_utils.baby_names(TIMESTEPS)
160 |
161 | epoch_size = len(names) // BATCH_SIZE
162 | # Create the gradient optimizer and apply it to the graph.
163 | # pt.apply_optimizer adds regularization losses and sets up a step counter
164 | # (pt.global_step()) for you.
165 |   # This sequence model does very well with an initially high learning rate.
166 | optimizer = tf.train.AdagradOptimizer(
167 | tf.train.exponential_decay(1.0,
168 | pt.global_step(),
169 | epoch_size,
170 | 0.95,
171 | staircase=True))
172 | train_op = pt.apply_optimizer(optimizer, losses=[result.loss])
173 |
174 | # We can set a save_path in the runner to automatically checkpoint every so
175 | # often. Otherwise at the end of the session, the model will be lost.
176 | runner = pt.train.Runner(save_path=FLAGS.save_path)
177 | with tf.Session():
178 | for epoch in xrange(100):
179 | # Shuffle the training data.
180 | names, sex, lengths = data_utils.permute_data((names, sex, lengths))
181 |
182 | runner.train_model(
183 | train_op,
184 | [result.loss, batch_accuracy],
185 | epoch_size,
186 | feed_vars=(input_placeholder, output_placeholder, length_placeholder),
187 | feed_data=pt.train.feed_numpy(BATCH_SIZE, names, sex, lengths),
188 | print_every=100)
189 | classification_accuracy = runner.evaluate_model(
190 | accuracy,
191 | epoch_size,
192 | print_every=0,
193 | feed_vars=(input_placeholder, output_placeholder, length_placeholder),
194 | feed_data=pt.train.feed_numpy(BATCH_SIZE, names, sex, lengths))
195 |
196 | print('Accuracy after epoch %d: %g%%' % (
197 | epoch + 1, classification_accuracy * 100))
198 |
199 |
200 | if __name__ == '__main__':
201 | tf.app.run()
202 |
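The per-example weight construction in `main` above is easier to see with concrete numbers. A numpy-only sketch (a made-up batch of 3 names, 5 timesteps) of the dense mask that `tf.sparse_to_dense` builds here:

```python
import numpy as np

BATCH_SIZE, TIMESTEPS = 3, 5
lengths = np.array([[2], [4], [1]])  # Index of the first EOS for each name.

# Same index construction as the tf.concat above: (row, first-EOS column) pairs.
indices = np.concatenate(
    [np.arange(BATCH_SIZE).reshape((BATCH_SIZE, 1)), lengths], axis=1)

# Dense mask: 1.0 at the first EOS after each name, 0.0 everywhere else.
weights = np.zeros((BATCH_SIZE, TIMESTEPS), dtype=np.float32)
weights[indices[:, 0], indices[:, 1]] = 1.0
print(weights)
# [[0. 0. 1. 0. 0.]
#  [0. 0. 0. 0. 1.]
#  [0. 1. 0. 0. 0.]]
```

In the real graph this mask is then passed through `data_utils.reshape_data` and squeezed down to one weight per (timestep, example) row.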
--------------------------------------------------------------------------------
/prettytensor/tutorial/data_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Data utils bundles the utilities to download and munge data in numpy."""
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import csv
17 | import gzip
18 | import os.path
19 | import sys
20 |
21 |
22 |
23 | import numpy as np
24 | from six.moves import xrange # pylint: disable=redefined-builtin
25 | from six.moves.urllib import request
26 | import tensorflow as tf
27 |
28 | import prettytensor as pt
29 |
30 |
31 | WORK_DIRECTORY = '/tmp/data'
32 | MNIST_URL = 'http://yann.lecun.com/exdb/mnist/'
33 | UNK = 0
34 | EOS = 1
35 |
36 |
37 | def maybe_download(url, filename):
38 | """Download the data from Yann's website, unless it's already here."""
39 | if not os.path.exists(WORK_DIRECTORY):
40 | os.mkdir(WORK_DIRECTORY)
41 | filepath = os.path.join(WORK_DIRECTORY, filename)
42 | if not os.path.exists(filepath):
43 | filepath, _ = request.urlretrieve(url + filename, filepath)
44 | statinfo = os.stat(filepath)
45 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
46 | return filepath
47 |
48 |
49 | def mnist_extract_data(filename, num_images):
50 | """Extract the images into a 4D tensor [image index, y, x, channels].
51 |
52 | Values are rescaled from [0, 255] down to [-0.5, 0.5].
53 |
54 | Args:
55 | filename: The local filename.
56 | num_images: The number of images in this file.
57 | Returns:
58 | The data as a numpy array with the values centered.
59 | """
60 | print('Extracting', filename)
61 | with gzip.open(filename) as bytestream:
62 | bytestream.read(16)
63 | buf = bytestream.read(28 * 28 * num_images)
64 | data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
65 | data -= 255 / 2.0
66 | data /= 255.0
67 | data = data.reshape(num_images, 28, 28, 1)
68 | return data
69 |
70 |
71 | def mnist_extract_labels(filename, num_images):
72 | """Extract the labels into a 1-hot matrix [image index, label index]."""
73 | print('Extracting', filename)
74 | with gzip.open(filename) as bytestream:
75 | bytestream.read(8)
76 | buf = bytestream.read(1 * num_images)
77 | labels = np.frombuffer(buf, dtype=np.uint8)
78 | # Convert to dense 1-hot representation.
79 | return (np.arange(10) == labels[:, None]).astype(np.float32)
80 |
81 |
82 | def permute_data(arrays, random_state=None):
83 | """Permute multiple numpy arrays with the same order."""
84 | if any(len(a) != len(arrays[0]) for a in arrays):
85 | raise ValueError('All arrays must be the same length.')
86 | if not random_state:
87 | random_state = np.random
88 | order = random_state.permutation(len(arrays[0]))
89 | return [a[order] for a in arrays]
90 |
91 |
92 | def mnist(training):
93 | """Downloads MNIST and loads it into numpy arrays."""
94 | if training:
95 | data_filename = 'train-images-idx3-ubyte.gz'
96 | labels_filename = 'train-labels-idx1-ubyte.gz'
97 | count = 60000
98 | else:
99 | data_filename = 't10k-images-idx3-ubyte.gz'
100 | labels_filename = 't10k-labels-idx1-ubyte.gz'
101 | count = 10000
102 | data_filename = maybe_download(MNIST_URL, data_filename)
103 | labels_filename = maybe_download(MNIST_URL, labels_filename)
104 |
105 | return (mnist_extract_data(data_filename, count),
106 | mnist_extract_labels(labels_filename, count))
107 |
108 |
109 | def convert_to_int(char):
110 | i = ord(char)
111 | if i >= 128:
112 | return UNK
113 | return i
114 |
115 |
116 | def shakespeare(chunk_size):
117 | """Downloads Shakespeare, converts it into ASCII codes and chunks it.
118 |
119 | Args:
120 | chunk_size: The dataset is broken down so that it is shaped into batches x
121 | chunk_size.
122 | Returns:
123 | A numpy array of ASCII codes shaped into batches x chunk_size.
124 | """
125 | file_name = maybe_download('http://cs.stanford.edu/people/karpathy/char-rnn/',
126 | 'shakespear.txt')
127 | with open(file_name) as f:
128 | shakespeare_full = f.read()
129 |
130 | # Truncate the data.
131 | length = (len(shakespeare_full) // chunk_size) * chunk_size
132 | if length < len(shakespeare_full):
133 | shakespeare_full = shakespeare_full[:length]
134 |   arr = np.array([convert_to_int(c) for c in shakespeare_full])[
135 |       0:len(shakespeare_full) // chunk_size * chunk_size]
136 |   return arr.reshape((len(arr) // chunk_size, chunk_size))
137 |
138 |
139 | def baby_names(max_length=15):
140 | """Opens the baby_names csv file and produces numpy array.
141 |
142 | Args:
143 | max_length: The maximum length, 15 was the longest name when this was
144 | written. Short entries will be padded with the EOS marker.
145 | Returns:
146 | A numpy array of the names converted to ascii codes, the labels and an
147 | array of lengths.
148 | Raises:
149 | ValueError: if max_length is too small.
150 | """
151 | names = []
152 | lengths = []
153 | targets = []
154 | with open(os.path.join(os.path.dirname(sys.modules[__name__].__file__),
155 | 'baby_names.csv'), 'rb') as f:
156 | first = True
157 | for l in csv.reader(f, delimiter=','):
158 | if first:
159 | first = False
160 | continue
161 | assert len(l) == 4, l
162 | name = l[0]
163 | if max_length < len(name):
164 |         raise ValueError('Max length is too small: %d < %d' %
165 | (max_length, len(name)))
166 | chars = [convert_to_int(c) for c in name]
167 | names.append(chars + ([EOS] * (max_length - len(chars))))
168 | lengths.append([len(name)])
169 | values = [float(l[2]), float(l[3])]
170 | if abs(sum(values) - 1) > 0.001:
171 | raise ValueError('Each row must sum to 1: %s' % l)
172 | targets.append(values)
173 | return np.array(names), np.array(targets), np.array(lengths)
174 |
175 |
176 | def reshape_data(tensor, per_example_length=1):
177 |   """Reshapes the input so that it is appropriate for sequence_lstm.
178 |
179 | The expected format for sequence lstms is
180 | [timesteps * batch, per_example_length] and the data produced by the utilities
181 | is [batch, timestep, *optional* expected_length]. The result can be cleaved
182 | so that there is a Tensor per timestep.
183 |
184 | Args:
185 | tensor: The tensor to reshape.
186 |     per_example_length: The number of values per example at each timestep.
187 | Returns:
188 | A Pretty Tensor that is compatible with cleave and then sequence_lstm.
189 |
190 | """
191 | # We can put the data into a format that can be easily cleaved by
192 | # transposing it (so that it varies fastest in batch) and then making each
193 | # component have a single value.
194 | # This will make it compatible with the Pretty Tensor function
195 | # cleave_sequence.
196 | dims = [1, 0]
197 | for i in xrange(2, tensor.get_shape().ndims):
198 | dims.append(i)
199 | return pt.wrap(tf.transpose(tensor, dims)).reshape([-1, per_example_length])
200 |
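The transpose-then-reshape in `reshape_data` is easier to follow with the shapes written out. A small numpy sketch of the same operation (sizes chosen arbitrarily):

```python
import numpy as np

batch, timesteps, per_example_length = 4, 3, 2
data = np.zeros((batch, timesteps, per_example_length))

# The same dims list the function builds: swap batch and time, keep the rest.
dims = [1, 0] + list(range(2, data.ndim))
reshaped = np.transpose(data, dims).reshape((-1, per_example_length))

# [batch, timesteps, per_example_length] -> [timesteps * batch, per_example_length]
print(data.shape, '->', reshaped.shape)  # (4, 3, 2) -> (12, 2)
```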
--------------------------------------------------------------------------------
/prettytensor/tutorial/mnist.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """MNIST tutorial.
12 |
13 | This uses Pretty Tensor to define and train either a 2 layer model or a
14 | convolutional model in the style of LeNet 5.
15 | See: http://yann.lecun.com/exdb/lenet/
16 | """
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 |
22 | from six.moves import xrange # pylint: disable=redefined-builtin
23 | import tensorflow as tf
24 |
25 | import prettytensor as pt
26 | from prettytensor.tutorial import data_utils
27 |
28 | tf.app.flags.DEFINE_string(
29 | 'save_path', None, 'Where to save the model checkpoints.')
30 | FLAGS = tf.app.flags.FLAGS
31 |
32 | BATCH_SIZE = 50
33 | EPOCH_SIZE = 60000 // BATCH_SIZE
34 | TEST_SIZE = 10000 // BATCH_SIZE
35 |
36 | tf.app.flags.DEFINE_string('model', 'full',
37 | 'Choose one of the models, either full or conv')
38 | FLAGS = tf.app.flags.FLAGS
39 |
40 |
41 | def multilayer_fully_connected(images, labels):
42 | """Creates a multi layer network of fully_connected layers.
43 |
44 | Each layer is 100 neurons. Please change this to experiment with
45 | architectures.
46 |
47 | Args:
48 | images: The input images.
49 | labels: The labels as dense one-hot vectors.
50 | Returns:
51 | A softmax result.
52 | """
53 | # Pretty Tensor is a thin wrapper on Tensors.
54 | # Change this method to experiment with other architectures
55 | images = pt.wrap(images)
56 | with pt.defaults_scope(activation_fn=tf.nn.relu, l2loss=0.00001):
57 | return (images.flatten().fully_connected(100).fully_connected(100)
58 | .softmax_classifier(10, labels))
59 |
60 |
61 | def lenet5(images, labels):
62 | """Creates a multi layer convolutional network.
63 |
64 | The architecture is similar to that defined in LeNet 5.
65 | Please change this to experiment with architectures.
66 |
67 | Args:
68 | images: The input images.
69 | labels: The labels as dense one-hot vectors.
70 | Returns:
71 | A softmax result.
72 | """
73 | images = pt.wrap(images)
74 | with pt.defaults_scope(activation_fn=tf.nn.relu, l2loss=0.00001):
75 | return (images.conv2d(5, 20).max_pool(2, 2).conv2d(5, 50).max_pool(2, 2)
76 | .flatten().fully_connected(500).softmax_classifier(10, labels))
77 |
78 |
79 | def main(_=None):
80 | # Since we are feeding our data as numpy arrays, we need to create
81 | # placeholders in the graph.
82 | # These must then be fed using the feed dict.
83 | image_placeholder = tf.placeholder(tf.float32, [BATCH_SIZE, 28, 28, 1])
84 | labels_placeholder = tf.placeholder(tf.float32, [BATCH_SIZE, 10])
85 |
86 | # Create our model. The result of softmax_classifier is a namedtuple
87 | # that has members result.loss and result.softmax.
88 | if FLAGS.model == 'full':
89 | result = multilayer_fully_connected(image_placeholder, labels_placeholder)
90 | elif FLAGS.model == 'conv':
91 | result = lenet5(image_placeholder, labels_placeholder)
92 | else:
93 | raise ValueError('model must be full or conv: %s' % FLAGS.model)
94 |
95 | # For tracking accuracy in evaluation, we need to add an evaluation node.
96 | # We only include this part of the graph when testing, so we need to specify
97 | # that in the phase.
98 | # Some ops have different behaviors in test vs train and these take a phase
99 | # argument.
100 | accuracy = result.softmax.evaluate_classifier(labels_placeholder,
101 | phase=pt.Phase.test)
102 |
103 | # Grab the data as numpy arrays.
104 | train_images, train_labels = data_utils.mnist(training=True)
105 | test_images, test_labels = data_utils.mnist(training=False)
106 |
107 | # Create the gradient optimizer and apply it to the graph.
108 | # pt.apply_optimizer adds regularization losses and sets up a step counter
109 | # (pt.global_step()) for you.
110 | optimizer = tf.train.GradientDescentOptimizer(0.01)
111 | train_op = pt.apply_optimizer(optimizer, losses=[result.loss])
112 |
113 | # We can set a save_path in the runner to automatically checkpoint every so
114 | # often. Otherwise at the end of the session, the model will be lost.
115 | runner = pt.train.Runner(save_path=FLAGS.save_path)
116 | with tf.Session():
117 | for epoch in xrange(10):
118 | # Shuffle the training data.
119 | train_images, train_labels = data_utils.permute_data(
120 | (train_images, train_labels))
121 |
122 | runner.train_model(
123 | train_op,
124 | result.loss,
125 | EPOCH_SIZE,
126 | feed_vars=(image_placeholder, labels_placeholder),
127 | feed_data=pt.train.feed_numpy(BATCH_SIZE, train_images, train_labels),
128 | print_every=100)
129 | classification_accuracy = runner.evaluate_model(
130 | accuracy,
131 | TEST_SIZE,
132 | feed_vars=(image_placeholder, labels_placeholder),
133 | feed_data=pt.train.feed_numpy(BATCH_SIZE, test_images, test_labels))
134 |       print('Accuracy after epoch %d: %g%%' % (
135 | epoch + 1, classification_accuracy * 100))
136 |
137 |
138 | if __name__ == '__main__':
139 | tf.app.run()
140 |
--------------------------------------------------------------------------------
/prettytensor/tutorial/shakespeare.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 | """Shakespeare tutorial.
12 |
13 | The Shakespeare tutorial downloads a snippet of Shakespeare, munges the data
14 | into the correct format and then creates a 2 layer LSTM to predict the next
15 | character given the current character.
16 | """
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import random
22 |
23 |
24 | import numpy
25 | from six.moves import xrange # pylint: disable=redefined-builtin
26 | import tensorflow as tf
27 |
28 | import prettytensor as pt
29 | from prettytensor.tutorial import data_utils
30 |
31 | tf.app.flags.DEFINE_string(
32 | 'save_path', None, 'Where to save the model checkpoints.')
33 | tf.app.flags.DEFINE_integer(
34 | 'epochs', 10, 'The number of epochs to run training on this model.')
35 | FLAGS = tf.app.flags.FLAGS
36 |
37 | BATCH_SIZE = 8
38 | CHARS = 128
39 | TIMESTEPS = 100
40 |
41 | # The size of the embedding for each character that will be learned.
42 | EMBEDDING_SIZE = 16
43 |
44 | # The number of cells in the lower and upper LSTM layers.
45 | LOWER = 128
46 | UPPER = 256
47 |
48 |
49 | def create_model(text_in, timesteps, phase):
50 | """Creates a 2 layer LSTM model with dropout.
51 |
52 | Args:
53 | text_in: The input text as ASCII ordinals in a Tensor.
54 | timesteps: The number of timesteps in the sequence.
55 | phase: Phase controls whether or not dropout is active. In training mode
56 | we want to perform dropout, but in test we want to disable it.
57 | Returns:
58 | The logits.
59 | """
60 | with pt.defaults_scope(activation_fn=tf.nn.relu, l2loss=0.00001):
61 | # The embedding lookup must be placed on a cpu.
62 | with tf.device('/cpu:0'):
63 | embedded = text_in.embedding_lookup(CHARS, [EMBEDDING_SIZE])
64 | # Because the sequence LSTM expects each timestep to be its own Tensor,
65 | # we need to cleave the sequence.
66 | # Below we can build a stacked 2 layer LSTM by just chaining them together.
67 | # You can stack as many layers as you want.
68 | lstm = (embedded
69 | .cleave_sequence(timesteps)
70 | .sequence_lstm(LOWER)
71 | .sequence_lstm(UPPER))
72 |
73 | # The classifier is much more efficient if it runs across the entire
74 | # dataset at once, so we want to squash (i.e. uncleave).
75 | # Note: if phase is test, dropout is a noop.
76 | return (lstm.squash_sequence()
77 | .dropout(keep_prob=0.8, phase=phase)
78 | .fully_connected(CHARS, activation_fn=None))
79 |
80 |
81 | def sample(
82 | input_placeholder, logits, seed=None, max_length=1024, temperature=1.0):
83 | """Samples from the LSTM model.
84 |
85 | Sampling is done by first running either the seed or an arbitrary character
86 | through the model and then drawing the next character from the probability
87 |   distribution defined by `softmax`.
88 |
89 | Args:
90 | input_placeholder: A placeholder that expects a scalar feed.
91 | logits: The logits. This works with the logits so that it can apply the
92 | temperature.
93 | seed: Either a string of characters to prime the network or None.
94 | max_length: The maximum length to draw in case EOS is not reached.
95 |     temperature: A value used to rescale the logits before the softmax. A
96 |       higher temperature selects less likely choices more often.
97 | Returns:
98 | A string that was sampled from the model.
99 | """
100 | assert temperature > 0, 'Temperature must be greater than 0.'
101 | if not seed:
102 | # The model expects an input to do inference, so seed with a single letter.
103 | seed = chr(ord('A') + random.randint(0, 25))
104 | result = ''
105 |
106 | # The recurrent runner takes care of tracking the model's state at each step
107 | # and provides a reset call to zero it out for each query.
108 | recurrent_runner = pt.train.RecurrentRunner()
109 |
110 | # We need to reset the hidden state for each query.
111 | recurrent_runner.reset()
112 | # Initialize the system
113 | for c in seed[:-1]:
114 | recurrent_runner.run([logits],
115 | {input_placeholder: data_utils.convert_to_int(c)})
116 | result += c
117 |
118 | # Start sampling!
119 | ci = ord(seed[-1])
120 | while len(result) < max_length and ci != data_utils.EOS:
121 | result += chr(ci)
122 |     # The softmax output is already normalized to probabilities and would have
123 |     # been appropriate here if we weren't applying the temperature (the
124 |     # temperature could also be applied in TensorFlow).
125 | logit_result = recurrent_runner.run([logits],
126 | {input_placeholder: ci})[0][0]
127 | logit_result /= temperature
128 |
129 | # Apply the softmax in numpy to convert from logits to probabilities.
130 | # Subtract off the max for numerical stability -- logits are invariant to
131 | # additive scaling and this eliminates overflows.
132 | logit_result -= logit_result.max()
133 |
134 | distribution = numpy.exp(logit_result)
135 | distribution /= distribution.sum()
136 |
137 |     # numpy.random.multinomial needs the probabilities to sum to strictly < 1.
138 | distribution -= .00000001
139 | ci = numpy.argmax(numpy.random.multinomial(1, distribution))
140 | result += chr(ci) # Add the last letter.
141 | return result
142 |
143 |
144 | def main(_=None):
145 | print('Starting Shakespeare')
146 |
147 | # Since we are feeding our data as numpy arrays, we need to create
148 | # placeholders in the graph.
149 | # These must then be fed using the feed dict.
150 | input_placeholder = tf.placeholder(tf.int32, [BATCH_SIZE, TIMESTEPS])
151 | output_placeholder = tf.placeholder(tf.int32, [BATCH_SIZE, TIMESTEPS])
152 |
153 | merged_size = BATCH_SIZE * TIMESTEPS
154 |
155 | inp = data_utils.reshape_data(input_placeholder)
156 |
157 | # We need a dense output to calculate loss and accuracy.
158 | # sparse_to_dense does a lookup using the indices from the first Tensor.
159 | # Because we are filling in a 2D array, the indices need to be 2 dimensional.
160 | t = tf.concat(
161 | [
162 | tf.constant(
163 | numpy.arange(merged_size).reshape((merged_size, 1)),
164 | dtype=tf.int32), data_utils.reshape_data(output_placeholder)
165 | ],
166 | 1)
167 |
168 | labels = tf.sparse_to_dense(t, [merged_size, CHARS], 1.0, 0.0)
169 |
170 | # Some ops have different behaviors in test vs train and these take a phase
171 | # argument.
172 | with tf.variable_scope('shakespeare'):
173 | training_logits = create_model(inp, TIMESTEPS, pt.Phase.train)
174 | # Create the result. Softmax applies softmax and creates a cross entropy
175 | # loss. The result is a namedtuple.
176 | training_result = training_logits.softmax(labels)
177 |
178 | # Create the gradient optimizer and apply it to the graph.
179 | # pt.apply_optimizer adds regularization losses and sets up a step counter
180 | # (pt.global_step()) for you.
181 | optimizer = tf.train.AdagradOptimizer(0.5)
182 | train_op = pt.apply_optimizer(optimizer, losses=[training_result.loss])
183 |
184 | # For tracking accuracy in evaluation, we need to add an evaluation node.
185 | # We only run this when testing, so we need to specify that in the phase.
186 | # We also want to disable dropout, so we pass the phase to create_model.
187 |
188 | # Call variable scope by name so we also create a name scope. This ensures
189 | # that we share variables and our names are properly organized.
190 | with tf.variable_scope('shakespeare', reuse=True):
191 | test_logits = create_model(inp, TIMESTEPS, pt.Phase.test)
192 | test_result = test_logits.softmax(labels)
193 |
194 | # Accuracy creates variables, so make it outside of the above scope.
195 | accuracy = test_result.softmax.evaluate_classifier(labels,
196 | phase=pt.Phase.test)
197 |
198 | # Create an inference model so that we can sample. The big difference is
199 | # that the input is a single character and it requires reset nodes.
200 | # Also place summaries in a different collection. The default summaries have
201 | # dependencies on running the graph and would introduce a dependence on the
202 | # inference placeholder.
203 | with tf.variable_scope('shakespeare', reuse=True), pt.defaults_scope(
204 | summary_collections=['INFERENCE_SUMMARIES']):
205 | inference_input = tf.placeholder(tf.int32, [])
206 | # Needs to be 2 dimensional so that it matches the dims of the other models.
207 | reshaped = pt.wrap(inference_input).reshape([1, 1])
208 | inference_logits = create_model(reshaped, 1, pt.Phase.infer)
209 |
210 | # Grab the data as numpy arrays.
211 | shakespeare = data_utils.shakespeare(TIMESTEPS + 1)
212 | shakespeare_in = shakespeare[:, :-1]
213 | shakespeare_out = shakespeare[:, 1:]
214 |
215 | # We can set a save_path in the runner to automatically checkpoint every so
216 | # often. Otherwise at the end of the session, the model will be lost.
217 | runner = pt.train.Runner(save_path=FLAGS.save_path)
218 | with tf.Session():
219 | for epoch in xrange(FLAGS.epochs):
220 | # Shuffle the training data.
221 | shakespeare_in, shakespeare_out = data_utils.permute_data(
222 | (shakespeare_in, shakespeare_out))
223 |
224 | runner.train_model(train_op,
225 | training_result.loss,
226 | len(shakespeare_in) // BATCH_SIZE,
227 | feed_vars=(input_placeholder, output_placeholder),
228 | feed_data=pt.train.feed_numpy(
229 | BATCH_SIZE, shakespeare_in, shakespeare_out),
230 | print_every=10)
231 | classification_accuracy = runner.evaluate_model(
232 | accuracy,
233 | len(shakespeare_in) // BATCH_SIZE,
234 | feed_vars=(input_placeholder, output_placeholder),
235 | feed_data=pt.train.feed_numpy(BATCH_SIZE, shakespeare_in,
236 | shakespeare_out))
237 |
238 | print('Next character accuracy after epoch %d: %g%%' % (
239 | epoch + 1, classification_accuracy * 100))
240 |
241 | # Use a temperature smaller than 1 because the early stages of the model
242 | # don't assign much confidence.
243 | print(sample(inference_input,
244 | inference_logits,
245 | max_length=128,
246 | temperature=0.5))
247 |
248 | # Print a sampling from the model.
249 | print(sample(inference_input, inference_logits))
250 |
251 |
252 | if __name__ == '__main__':
253 | tf.app.run()
254 |
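The temperature sampling inside `sample` can be tried in isolation. A minimal numpy sketch of the same logits-to-draw pipeline (the logits are made up):

```python
import numpy as np

logits = np.array([2.0, 1.0, 0.1])  # Made-up logits over three characters.
temperature = 0.5

# Divide by the temperature, subtract the max for numerical stability, then
# exponentiate and normalize: a softmax computed in numpy.
scaled = logits / temperature
scaled -= scaled.max()
probs = np.exp(scaled)
probs /= probs.sum()

# numpy.random.multinomial needs the probabilities to sum to strictly < 1.
probs -= 1e-8
next_char_index = int(np.argmax(np.random.multinomial(1, probs)))
```

A lower temperature sharpens the distribution toward the most likely character, while a higher one flattens it, which is why the tutorial samples with temperature 0.5 while the model is still early in training.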
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | import fnmatch
13 | import os
14 | from setuptools import find_packages, setup, Extension
15 |
16 | _VERSION = '0.7.2'
17 |
18 | REQUIRED_PACKAGES = [
19 | 'enum34 >= 1.0.0',
20 | 'six >= 1.10.0',
21 | # Having an explicit dependency here breaks GPU
22 | # tensorflow >= 0.12.0rc0',
23 | ]
24 |
25 | # pylint: disable=line-too-long
26 | CONSOLE_SCRIPTS = [
27 | 'prettytensor_model_mnist = prettytensor.tutorial.mnist:main',
28 | 'prettytensor_model_shakespeare = prettytensor.tutorial.shakespeare:main',
29 | 'prettytensor_model_baby_names = prettytensor.tutorial.baby_names:main',
30 | ]
31 | # pylint: enable=line-too-long
32 |
33 | TEST_PACKAGES = [
34 | 'nose >= 1.3.7',
35 | ]
36 |
37 | setup(
38 | name='prettytensor',
39 | version=_VERSION,
40 | description='Pretty Tensor makes learning beautiful',
41 | long_description='',
42 | url='https://github.com/google/prettytensor',
43 | author='Eider Moore',
44 | author_email='opensource@google.com',
45 | # Contained modules and scripts.
46 | packages=find_packages(),
47 | include_package_data=True,
48 | package_data={
49 | 'prettytensor': ['tutorial/baby_names.csv']
50 | },
51 | entry_points={
52 | 'console_scripts': CONSOLE_SCRIPTS
53 | },
54 | install_requires=REQUIRED_PACKAGES,
55 | tests_require=REQUIRED_PACKAGES + TEST_PACKAGES,
56 | test_suite = 'nose.collector',
57 | # PyPI package information.
58 | classifiers=[
59 | 'Development Status :: 4 - Beta',
60 | 'Intended Audience :: Developers',
61 | 'Intended Audience :: Education',
62 | 'Intended Audience :: Science/Research',
63 | 'License :: OSI Approved :: Apache Software License',
64 | 'Programming Language :: Python :: 2.7',
65 | 'Topic :: Scientific/Engineering :: Mathematics',
66 | 'Topic :: Software Development :: Libraries :: Python Modules',
67 | 'Topic :: Software Development :: Libraries',
68 | ],
69 | license='Apache 2.0',
70 | keywords='tensorflow tensor machine learning',
71 | )
72 |
--------------------------------------------------------------------------------
/test_pip_install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | cur_dir=$(pwd)
5 |
6 | # Python 2
7 |
8 | rm -rf /tmp/clean-venv
9 | virtualenv /tmp/clean-venv
10 | cd /tmp/clean-venv
11 | source bin/activate
12 | pip install --upgrade pip
13 | pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0rc0-cp27-none-linux_x86_64.whl
14 | pip install prettytensor
15 | pip install nose
16 | nosetests prettytensor
17 |
18 | deactivate
19 |
20 | cd "$cur_dir"
21 | # Python 3
22 |
23 | rm -rf /tmp/clean-venv
24 | virtualenv -p python3 /tmp/clean-venv
25 | cd /tmp/clean-venv
26 | source bin/activate
27 | pip install --upgrade pip
28 | pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0rc0-cp34-cp34m-linux_x86_64.whl
29 | pip install prettytensor
30 | pip install nose
31 | nosetests prettytensor
32 |
33 | deactivate
34 | rm -rf /tmp/clean-venv
35 |
36 | cd "$cur_dir"
37 |
--------------------------------------------------------------------------------
/unrolled_lstm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/prettytensor/75daa0b11252590f548da5647addc0ea610c4c45/unrolled_lstm.png
--------------------------------------------------------------------------------