├── .gitignore ├── .travis.yml ├── CHANGELOG.md ├── LICENSE ├── README.md ├── doc └── intro.md ├── misc ├── addconst.pb ├── constant.pb ├── gengraph.py ├── linreg.pb ├── mnist │ ├── checkpoint │ ├── mnist_simple.data-00000-of-00001 │ ├── mnist_simple.index │ ├── mnist_simple.meta │ └── mnist_simple.pbtxt ├── mnist_much.py ├── mul2vars.pb └── mulbymat.pb ├── project.clj ├── resources └── ops.pbtxt ├── src └── tensorflow_clj │ ├── core.clj │ ├── experimental.clj │ ├── graph │ ├── attributes.clj │ ├── gradients.clj │ ├── node_defs.clj │ ├── proto_much.clj │ ├── transform.clj │ └── variables.clj │ ├── graph_ops.clj │ ├── graph_playground.clj │ └── util.clj └── test └── tensorflow_clj └── core_test.clj /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /classes 3 | /checkouts 4 | pom.xml 5 | pom.xml.asc 6 | *.jar 7 | *.class 8 | /.lein-* 9 | /.nrepl-port 10 | .hgignore 11 | .hg/ 12 | /.idea 13 | *.iml 14 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | dist: trusty 3 | language: clojure 4 | #python: # Only two versions for now 5 | # - "2.7" 6 | # - "3.4" 7 | # command to install dependencies 8 | #install: 9 | # - pip install tensorflow 10 | #- pip install numpy 11 | # install TensorFlow from https://storage.googleapis.com/tensorflow/ 12 | #- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then 13 | # pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0-cp27-none-linux_x86_64.whl; 14 | # elif [[ "$TRAVIS_PYTHON_VERSION" == "3.4" ]]; then 15 | # pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0-cp34-cp34m-linux_x86_64.whl; 16 | # fi 17 | jdk: 18 | - oraclejdk8 19 | -------------------------------------------------------------------------------- /CHANGELOG.md: 
-------------------------------------------------------------------------------- 1 | # Change Log 2 | All notable changes to this project will be documented in this file. This change log follows the conventions of [keepachangelog.com](http://keepachangelog.com/). 3 | 4 | ## [Unreleased][unreleased] 5 | ### Changed 6 | - Add a new arity to `make-widget-async` to provide a different widget shape. 7 | 8 | ## [0.1.1] - 2016-01-09 9 | ### Changed 10 | - Documentation on how to make the widgets. 11 | 12 | ### Removed 13 | - `make-widget-sync` - we're all async, all the time. 14 | 15 | ### Fixed 16 | - Fixed widget maker to keep working when daylight savings switches over. 17 | 18 | ## 0.1.0 - 2016-01-09 19 | ### Added 20 | - Files from the new template. 21 | - Widget maker public API - `make-widget-sync`. 22 | 23 | [unreleased]: https://github.com/your-name/tensorflow-clj/compare/0.1.1...HEAD 24 | [0.1.1]: https://github.com/your-name/tensorflow-clj/compare/0.1.0...0.1.1 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | THE ACCOMPANYING PROGRAM IS PROVIDED UNDER THE TERMS OF THIS ECLIPSE PUBLIC 2 | LICENSE ("AGREEMENT"). ANY USE, REPRODUCTION OR DISTRIBUTION OF THE PROGRAM 3 | CONSTITUTES RECIPIENT'S ACCEPTANCE OF THIS AGREEMENT. 4 | 5 | 1. DEFINITIONS 6 | 7 | "Contribution" means: 8 | 9 | a) in the case of the initial Contributor, the initial code and 10 | documentation distributed under this Agreement, and 11 | 12 | b) in the case of each subsequent Contributor: 13 | 14 | i) changes to the Program, and 15 | 16 | ii) additions to the Program; 17 | 18 | where such changes and/or additions to the Program originate from and are 19 | distributed by that particular Contributor. A Contribution 'originates' from 20 | a Contributor if it was added to the Program by such Contributor itself or 21 | anyone acting on such Contributor's behalf. 
Contributions do not include 22 | additions to the Program which: (i) are separate modules of software 23 | distributed in conjunction with the Program under their own license 24 | agreement, and (ii) are not derivative works of the Program. 25 | 26 | "Contributor" means any person or entity that distributes the Program. 27 | 28 | "Licensed Patents" mean patent claims licensable by a Contributor which are 29 | necessarily infringed by the use or sale of its Contribution alone or when 30 | combined with the Program. 31 | 32 | "Program" means the Contributions distributed in accordance with this 33 | Agreement. 34 | 35 | "Recipient" means anyone who receives the Program under this Agreement, 36 | including all Contributors. 37 | 38 | 2. GRANT OF RIGHTS 39 | 40 | a) Subject to the terms of this Agreement, each Contributor hereby grants 41 | Recipient a non-exclusive, worldwide, royalty-free copyright license to 42 | reproduce, prepare derivative works of, publicly display, publicly perform, 43 | distribute and sublicense the Contribution of such Contributor, if any, and 44 | such derivative works, in source code and object code form. 45 | 46 | b) Subject to the terms of this Agreement, each Contributor hereby grants 47 | Recipient a non-exclusive, worldwide, royalty-free patent license under 48 | Licensed Patents to make, use, sell, offer to sell, import and otherwise 49 | transfer the Contribution of such Contributor, if any, in source code and 50 | object code form. This patent license shall apply to the combination of the 51 | Contribution and the Program if, at the time the Contribution is added by the 52 | Contributor, such addition of the Contribution causes such combination to be 53 | covered by the Licensed Patents. The patent license shall not apply to any 54 | other combinations which include the Contribution. No hardware per se is 55 | licensed hereunder. 
56 | 57 | c) Recipient understands that although each Contributor grants the licenses 58 | to its Contributions set forth herein, no assurances are provided by any 59 | Contributor that the Program does not infringe the patent or other 60 | intellectual property rights of any other entity. Each Contributor disclaims 61 | any liability to Recipient for claims brought by any other entity based on 62 | infringement of intellectual property rights or otherwise. As a condition to 63 | exercising the rights and licenses granted hereunder, each Recipient hereby 64 | assumes sole responsibility to secure any other intellectual property rights 65 | needed, if any. For example, if a third party patent license is required to 66 | allow Recipient to distribute the Program, it is Recipient's responsibility 67 | to acquire that license before distributing the Program. 68 | 69 | d) Each Contributor represents that to its knowledge it has sufficient 70 | copyright rights in its Contribution, if any, to grant the copyright license 71 | set forth in this Agreement. 72 | 73 | 3. 
REQUIREMENTS 74 | 75 | A Contributor may choose to distribute the Program in object code form under 76 | its own license agreement, provided that: 77 | 78 | a) it complies with the terms and conditions of this Agreement; and 79 | 80 | b) its license agreement: 81 | 82 | i) effectively disclaims on behalf of all Contributors all warranties and 83 | conditions, express and implied, including warranties or conditions of title 84 | and non-infringement, and implied warranties or conditions of merchantability 85 | and fitness for a particular purpose; 86 | 87 | ii) effectively excludes on behalf of all Contributors all liability for 88 | damages, including direct, indirect, special, incidental and consequential 89 | damages, such as lost profits; 90 | 91 | iii) states that any provisions which differ from this Agreement are offered 92 | by that Contributor alone and not by any other party; and 93 | 94 | iv) states that source code for the Program is available from such 95 | Contributor, and informs licensees how to obtain it in a reasonable manner on 96 | or through a medium customarily used for software exchange. 97 | 98 | When the Program is made available in source code form: 99 | 100 | a) it must be made available under this Agreement; and 101 | 102 | b) a copy of this Agreement must be included with each copy of the Program. 103 | 104 | Contributors may not remove or alter any copyright notices contained within 105 | the Program. 106 | 107 | Each Contributor must identify itself as the originator of its Contribution, 108 | if any, in a manner that reasonably allows subsequent Recipients to identify 109 | the originator of the Contribution. 110 | 111 | 4. COMMERCIAL DISTRIBUTION 112 | 113 | Commercial distributors of software may accept certain responsibilities with 114 | respect to end users, business partners and the like. 
While this license is 115 | intended to facilitate the commercial use of the Program, the Contributor who 116 | includes the Program in a commercial product offering should do so in a 117 | manner which does not create potential liability for other Contributors. 118 | Therefore, if a Contributor includes the Program in a commercial product 119 | offering, such Contributor ("Commercial Contributor") hereby agrees to defend 120 | and indemnify every other Contributor ("Indemnified Contributor") against any 121 | losses, damages and costs (collectively "Losses") arising from claims, 122 | lawsuits and other legal actions brought by a third party against the 123 | Indemnified Contributor to the extent caused by the acts or omissions of such 124 | Commercial Contributor in connection with its distribution of the Program in 125 | a commercial product offering. The obligations in this section do not apply 126 | to any claims or Losses relating to any actual or alleged intellectual 127 | property infringement. In order to qualify, an Indemnified Contributor must: 128 | a) promptly notify the Commercial Contributor in writing of such claim, and 129 | b) allow the Commercial Contributor to control, and cooperate with the 130 | Commercial Contributor in, the defense and any related settlement 131 | negotiations. The Indemnified Contributor may participate in any such claim 132 | at its own expense. 133 | 134 | For example, a Contributor might include the Program in a commercial product 135 | offering, Product X. That Contributor is then a Commercial Contributor. If 136 | that Commercial Contributor then makes performance claims, or offers 137 | warranties related to Product X, those performance claims and warranties are 138 | such Commercial Contributor's responsibility alone.
Under this section, the 139 | Commercial Contributor would have to defend claims against the other 140 | Contributors related to those performance claims and warranties, and if a 141 | court requires any other Contributor to pay any damages as a result, the 142 | Commercial Contributor must pay those damages. 143 | 144 | 5. NO WARRANTY 145 | 146 | EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, THE PROGRAM IS PROVIDED ON 147 | AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER 148 | EXPRESS OR IMPLIED INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OR 149 | CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY OR FITNESS FOR A 150 | PARTICULAR PURPOSE. Each Recipient is solely responsible for determining the 151 | appropriateness of using and distributing the Program and assumes all risks 152 | associated with its exercise of rights under this Agreement , including but 153 | not limited to the risks and costs of program errors, compliance with 154 | applicable laws, damage to or loss of data, programs or equipment, and 155 | unavailability or interruption of operations. 156 | 157 | 6. DISCLAIMER OF LIABILITY 158 | 159 | EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, NEITHER RECIPIENT NOR ANY 160 | CONTRIBUTORS SHALL HAVE ANY LIABILITY FOR ANY DIRECT, INDIRECT, INCIDENTAL, 161 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING WITHOUT LIMITATION 162 | LOST PROFITS), HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 163 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 164 | ARISING IN ANY WAY OUT OF THE USE OR DISTRIBUTION OF THE PROGRAM OR THE 165 | EXERCISE OF ANY RIGHTS GRANTED HEREUNDER, EVEN IF ADVISED OF THE POSSIBILITY 166 | OF SUCH DAMAGES. 167 | 168 | 7. 
GENERAL 169 | 170 | If any provision of this Agreement is invalid or unenforceable under 171 | applicable law, it shall not affect the validity or enforceability of the 172 | remainder of the terms of this Agreement, and without further action by the 173 | parties hereto, such provision shall be reformed to the minimum extent 174 | necessary to make such provision valid and enforceable. 175 | 176 | If Recipient institutes patent litigation against any entity (including a 177 | cross-claim or counterclaim in a lawsuit) alleging that the Program itself 178 | (excluding combinations of the Program with other software or hardware) 179 | infringes such Recipient's patent(s), then such Recipient's rights granted 180 | under Section 2(b) shall terminate as of the date such litigation is filed. 181 | 182 | All Recipient's rights under this Agreement shall terminate if it fails to 183 | comply with any of the material terms or conditions of this Agreement and 184 | does not cure such failure in a reasonable period of time after becoming 185 | aware of such noncompliance. If all Recipient's rights under this Agreement 186 | terminate, Recipient agrees to cease use and distribution of the Program as 187 | soon as reasonably practicable. However, Recipient's obligations under this 188 | Agreement and any licenses granted by Recipient relating to the Program shall 189 | continue and survive. 190 | 191 | Everyone is permitted to copy and distribute copies of this Agreement, but in 192 | order to avoid inconsistency the Agreement is copyrighted and may only be 193 | modified in the following manner. The Agreement Steward reserves the right to 194 | publish new versions (including revisions) of this Agreement from time to 195 | time. No one other than the Agreement Steward has the right to modify this 196 | Agreement. The Eclipse Foundation is the initial Agreement Steward. 
The 197 | Eclipse Foundation may assign the responsibility to serve as the Agreement 198 | Steward to a suitable separate entity. Each new version of the Agreement will 199 | be given a distinguishing version number. The Program (including 200 | Contributions) may always be distributed subject to the version of the 201 | Agreement under which it was received. In addition, after a new version of 202 | the Agreement is published, Contributor may elect to distribute the Program 203 | (including its Contributions) under the new version. Except as expressly 204 | stated in Sections 2(a) and 2(b) above, Recipient receives no rights or 205 | licenses to the intellectual property of any Contributor under this 206 | Agreement, whether expressly, by implication, estoppel or otherwise. All 207 | rights in the Program not expressly granted under this Agreement are 208 | reserved. 209 | 210 | This Agreement is governed by the laws of the State of New York and the 211 | intellectual property laws of the United States of America. No party to this 212 | Agreement will bring a legal action under this Agreement more than one year 213 | after the cause of action arose. Each party waives its rights to a jury trial 214 | in any resulting litigation. 215 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tensorflow-clj 2 | 3 | [![Build Status](https://travis-ci.org/enragedginger/tensorflow-clj.svg?branch=master)](https://travis-ci.org/enragedginger/tensorflow-clj) 4 | 5 | This project is under active development. Expect the APIs to change frequently for the next several months. 6 | 7 | ## Summary 8 | Clojure API for building and running computations on Google's TensorFlow framework. 
9 | 10 | ## Goal(s) 11 | * Create a mechanism for solving hard problems through machine learning that doesn't require a deep understanding of machine learning and utilizes idiomatic Clojure practices, guidelines, and ideas. 12 | 13 | ## Rationale (because it wouldn't be a Clojure library without this section) 14 | There exist numerous hard problems that can be solved through machine learning. However, most machine learning frameworks 15 | and libraries require thorough knowledge of the complex foundational topics of the field. In some cases, these libraries 16 | assume the user has sufficient breadth of understanding to know which particular algorithm / approach should be selected 17 | at a given decision point in their venture. 18 | 19 | Furthermore, most machine learning libraries are architected in such a way that code re-use across projects for data 20 | scientists is either a copy-paste extravaganza or simply impossible. Oftentimes, this is the result of object-oriented 21 | or procedural design principles and practices. 22 | 23 | Therefore, `tensorflow-clj` focuses on empowering machine learning plebeians to solve hard problems 24 | through the utilization of high-level constructs and automated tooling. However, this library will also allow 25 | machine learning gurus to compose basic machine learning building blocks into constructs they require to solve hard 26 | problems. This will be achieved by taking a data-first, functional approach to automate the building of TensorFlow 27 | graphs.
28 | 29 | ## Milestones 30 | * Load, run, update, and save TensorFlow graphs which are already available in the pre-defined TF Protobuf format (done) 31 | * Build a representation for arbitrary TensorFlow operation nodes (done) 32 | * Convert collections of Clojure TF op nodes to TF Protobuf format (done) 33 | (At this point, we can build, load, run, update, and save any TF computation graph) 34 | * Mimic convenience functionality present only in the Python client (looking at you, GradientDescent) 35 | * Generate `clojure.spec` node schemas based on `op` definitions found in TF Protobuf exports 36 | * Create `clojure.spec` schemas for governing entire collections of nodes 37 | * Mimic convenience functionality present in Keras and similar libraries for building neural nets 38 | * Add functionality for building ML / NN graphs for the plebeians 39 | 40 | 41 | ## How to Get Involved 42 | * You can find and chat with us on **#tensorflow** @ Clojurians Slack. [Get an invite here.](http://clojurians.net/) 43 | * If you find bugs / have feature requests, feel free to open an issue here on GitHub. For bugs, please provide sample 44 | code where possible for reproducing the issue. Also, be sure to let us know what environment (OS, Java version, CLJ version, 45 | etc.) you're using. 46 | 47 | ## Usage 48 | 49 | This library is available via Clojars: `[tensorflow-clj "0.1"]` 50 | 51 | This project should run out-of-the-box using Leiningen. It's developed locally on macOS and also built on Travis CI Linux. 52 | 53 | More usage instructions coming soon! 54 | 55 | Run some basic tests: 56 | 57 | $ lein test 58 | 59 | ## License 60 | 61 | Copyright © 2016 Stephen M. Hopper 62 | 63 | Distributed under the Eclipse Public License, either version 1.0 or (at 64 | your option) any later version.
65 | -------------------------------------------------------------------------------- /doc/intro.md: -------------------------------------------------------------------------------- 1 | # Introduction to tensorflow-clj 2 | 3 | TODO: write [great documentation](http://jacobian.org/writing/what-to-write/) 4 | -------------------------------------------------------------------------------- /misc/addconst.pb: -------------------------------------------------------------------------------- 1 | 2 | 2 3 | ConstConst* 4 | value B 5 | *@@* 6 | dtype0 7 | 4 8 | Placeholder Placeholder* 9 | dtype0* 10 | shape: 11 | ' 12 | mulMulConst Placeholder* 13 | T0" -------------------------------------------------------------------------------- /misc/constant.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enragedginger/tensorflow-clj/11d7478bd43a9cbe838fb79fde2f2cb1f1004da3/misc/constant.pb -------------------------------------------------------------------------------- /misc/gengraph.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python3 2 | 3 | import os 4 | from contextlib import contextmanager 5 | 6 | import tensorflow as tf 7 | 8 | 9 | @contextmanager 10 | def gen(name): 11 | name = os.path.join(os.getcwd(), "{}.pb".format(name)) 12 | with open(name, "wb") as out: 13 | g = tf.Graph() 14 | with g.as_default(): 15 | yield 16 | out.write(g.as_graph_def().SerializeToString()) 17 | 18 | 19 | with gen("constant"): 20 | tf.constant(123.0) 21 | 22 | with gen("addconst"): 23 | tf.constant(3.0) * tf.placeholder(tf.float32) 24 | 25 | with gen("mulbymat"): 26 | tf.placeholder(tf.float32) * tf.constant([[1., 2.], [3., 4.]]) 27 | 28 | with gen("mul2vars"): 29 | tf.placeholder(tf.float32, name="a") * \ 30 | tf.placeholder(tf.float32, name="b") 31 | 32 | with gen("linreg"): 33 | W = tf.Variable([.3], tf.float32, name="W") 34 | b = tf.Variable([-.3], tf.float32, name="b") 35 | x = tf.placeholder(tf.float32, name="x") 36 | linear_model = W * x + b 37 | tf.identity(linear_model, name="linear_model") 38 | 39 | y = tf.placeholder(tf.float32, name="y") 40 | squared_deltas = tf.square(linear_model - y, name="squared_deltas") 41 | loss = tf.reduce_sum(squared_deltas, name="loss") 42 | 43 | optimizer = tf.train.GradientDescentOptimizer(0.01) 44 | train = optimizer.minimize(loss, name="train") 45 | 46 | fixW = tf.assign(W, [-1.], name="fixW") 47 | fixb = tf.assign(b, [1.], name="fixb") 48 | 49 | init = tf.variables_initializer(tf.global_variables(), name="init") 50 | -------------------------------------------------------------------------------- /misc/linreg.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enragedginger/tensorflow-clj/11d7478bd43a9cbe838fb79fde2f2cb1f1004da3/misc/linreg.pb -------------------------------------------------------------------------------- /misc/mnist/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "mnist_simple" 2 | 
all_model_checkpoint_paths: "mnist_simple" 3 | -------------------------------------------------------------------------------- /misc/mnist/mnist_simple.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enragedginger/tensorflow-clj/11d7478bd43a9cbe838fb79fde2f2cb1f1004da3/misc/mnist/mnist_simple.data-00000-of-00001 -------------------------------------------------------------------------------- /misc/mnist/mnist_simple.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enragedginger/tensorflow-clj/11d7478bd43a9cbe838fb79fde2f2cb1f1004da3/misc/mnist/mnist_simple.index -------------------------------------------------------------------------------- /misc/mnist/mnist_simple.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enragedginger/tensorflow-clj/11d7478bd43a9cbe838fb79fde2f2cb1f1004da3/misc/mnist/mnist_simple.meta -------------------------------------------------------------------------------- /misc/mnist/mnist_simple.pbtxt: -------------------------------------------------------------------------------- 1 | node { 2 | name: "x" 3 | op: "Placeholder" 4 | attr { 5 | key: "dtype" 6 | value { 7 | type: DT_FLOAT 8 | } 9 | } 10 | attr { 11 | key: "shape" 12 | value { 13 | shape { 14 | } 15 | } 16 | } 17 | } 18 | node { 19 | name: "zeros" 20 | op: "Const" 21 | attr { 22 | key: "dtype" 23 | value { 24 | type: DT_FLOAT 25 | } 26 | } 27 | attr { 28 | key: "value" 29 | value { 30 | tensor { 31 | dtype: DT_FLOAT 32 | tensor_shape { 33 | dim { 34 | size: 784 35 | } 36 | dim { 37 | size: 10 38 | } 39 | } 40 | float_val: 0.0 41 | } 42 | } 43 | } 44 | } 45 | node { 46 | name: "W" 47 | op: "VariableV2" 48 | attr { 49 | key: "container" 50 | value { 51 | s: "" 52 | } 53 | } 54 | attr { 55 | key: "dtype" 56 | value { 57 | type: DT_FLOAT 58 | } 59 | } 60 | 
attr { 61 | key: "shape" 62 | value { 63 | shape { 64 | dim { 65 | size: 784 66 | } 67 | dim { 68 | size: 10 69 | } 70 | } 71 | } 72 | } 73 | attr { 74 | key: "shared_name" 75 | value { 76 | s: "" 77 | } 78 | } 79 | } 80 | node { 81 | name: "W/Assign" 82 | op: "Assign" 83 | input: "W" 84 | input: "zeros" 85 | attr { 86 | key: "T" 87 | value { 88 | type: DT_FLOAT 89 | } 90 | } 91 | attr { 92 | key: "_class" 93 | value { 94 | list { 95 | s: "loc:@W" 96 | } 97 | } 98 | } 99 | attr { 100 | key: "use_locking" 101 | value { 102 | b: true 103 | } 104 | } 105 | attr { 106 | key: "validate_shape" 107 | value { 108 | b: true 109 | } 110 | } 111 | } 112 | node { 113 | name: "W/read" 114 | op: "Identity" 115 | input: "W" 116 | attr { 117 | key: "T" 118 | value { 119 | type: DT_FLOAT 120 | } 121 | } 122 | attr { 123 | key: "_class" 124 | value { 125 | list { 126 | s: "loc:@W" 127 | } 128 | } 129 | } 130 | } 131 | node { 132 | name: "zeros_1" 133 | op: "Const" 134 | attr { 135 | key: "dtype" 136 | value { 137 | type: DT_FLOAT 138 | } 139 | } 140 | attr { 141 | key: "value" 142 | value { 143 | tensor { 144 | dtype: DT_FLOAT 145 | tensor_shape { 146 | dim { 147 | size: 10 148 | } 149 | } 150 | float_val: 0.0 151 | } 152 | } 153 | } 154 | } 155 | node { 156 | name: "b" 157 | op: "VariableV2" 158 | attr { 159 | key: "container" 160 | value { 161 | s: "" 162 | } 163 | } 164 | attr { 165 | key: "dtype" 166 | value { 167 | type: DT_FLOAT 168 | } 169 | } 170 | attr { 171 | key: "shape" 172 | value { 173 | shape { 174 | dim { 175 | size: 10 176 | } 177 | } 178 | } 179 | } 180 | attr { 181 | key: "shared_name" 182 | value { 183 | s: "" 184 | } 185 | } 186 | } 187 | node { 188 | name: "b/Assign" 189 | op: "Assign" 190 | input: "b" 191 | input: "zeros_1" 192 | attr { 193 | key: "T" 194 | value { 195 | type: DT_FLOAT 196 | } 197 | } 198 | attr { 199 | key: "_class" 200 | value { 201 | list { 202 | s: "loc:@b" 203 | } 204 | } 205 | } 206 | attr { 207 | key: "use_locking" 208 | value { 209 | 
b: true 210 | } 211 | } 212 | attr { 213 | key: "validate_shape" 214 | value { 215 | b: true 216 | } 217 | } 218 | } 219 | node { 220 | name: "b/read" 221 | op: "Identity" 222 | input: "b" 223 | attr { 224 | key: "T" 225 | value { 226 | type: DT_FLOAT 227 | } 228 | } 229 | attr { 230 | key: "_class" 231 | value { 232 | list { 233 | s: "loc:@b" 234 | } 235 | } 236 | } 237 | } 238 | node { 239 | name: "MatMul" 240 | op: "MatMul" 241 | input: "x" 242 | input: "W/read" 243 | attr { 244 | key: "T" 245 | value { 246 | type: DT_FLOAT 247 | } 248 | } 249 | attr { 250 | key: "transpose_a" 251 | value { 252 | b: false 253 | } 254 | } 255 | attr { 256 | key: "transpose_b" 257 | value { 258 | b: false 259 | } 260 | } 261 | } 262 | node { 263 | name: "add" 264 | op: "Add" 265 | input: "MatMul" 266 | input: "b/read" 267 | attr { 268 | key: "T" 269 | value { 270 | type: DT_FLOAT 271 | } 272 | } 273 | } 274 | node { 275 | name: "Placeholder" 276 | op: "Placeholder" 277 | attr { 278 | key: "dtype" 279 | value { 280 | type: DT_FLOAT 281 | } 282 | } 283 | attr { 284 | key: "shape" 285 | value { 286 | shape { 287 | } 288 | } 289 | } 290 | } 291 | node { 292 | name: "Rank" 293 | op: "Const" 294 | attr { 295 | key: "dtype" 296 | value { 297 | type: DT_INT32 298 | } 299 | } 300 | attr { 301 | key: "value" 302 | value { 303 | tensor { 304 | dtype: DT_INT32 305 | tensor_shape { 306 | } 307 | int_val: 2 308 | } 309 | } 310 | } 311 | } 312 | node { 313 | name: "Shape" 314 | op: "Shape" 315 | input: "add" 316 | attr { 317 | key: "T" 318 | value { 319 | type: DT_FLOAT 320 | } 321 | } 322 | attr { 323 | key: "out_type" 324 | value { 325 | type: DT_INT32 326 | } 327 | } 328 | } 329 | node { 330 | name: "Rank_1" 331 | op: "Const" 332 | attr { 333 | key: "dtype" 334 | value { 335 | type: DT_INT32 336 | } 337 | } 338 | attr { 339 | key: "value" 340 | value { 341 | tensor { 342 | dtype: DT_INT32 343 | tensor_shape { 344 | } 345 | int_val: 2 346 | } 347 | } 348 | } 349 | } 350 | node { 351 | name: 
"Shape_1" 352 | op: "Shape" 353 | input: "add" 354 | attr { 355 | key: "T" 356 | value { 357 | type: DT_FLOAT 358 | } 359 | } 360 | attr { 361 | key: "out_type" 362 | value { 363 | type: DT_INT32 364 | } 365 | } 366 | } 367 | node { 368 | name: "Sub/y" 369 | op: "Const" 370 | attr { 371 | key: "dtype" 372 | value { 373 | type: DT_INT32 374 | } 375 | } 376 | attr { 377 | key: "value" 378 | value { 379 | tensor { 380 | dtype: DT_INT32 381 | tensor_shape { 382 | } 383 | int_val: 1 384 | } 385 | } 386 | } 387 | } 388 | node { 389 | name: "Sub" 390 | op: "Sub" 391 | input: "Rank_1" 392 | input: "Sub/y" 393 | attr { 394 | key: "T" 395 | value { 396 | type: DT_INT32 397 | } 398 | } 399 | } 400 | node { 401 | name: "Slice/begin" 402 | op: "Pack" 403 | input: "Sub" 404 | attr { 405 | key: "N" 406 | value { 407 | i: 1 408 | } 409 | } 410 | attr { 411 | key: "T" 412 | value { 413 | type: DT_INT32 414 | } 415 | } 416 | attr { 417 | key: "axis" 418 | value { 419 | i: 0 420 | } 421 | } 422 | } 423 | node { 424 | name: "Slice/size" 425 | op: "Const" 426 | attr { 427 | key: "dtype" 428 | value { 429 | type: DT_INT32 430 | } 431 | } 432 | attr { 433 | key: "value" 434 | value { 435 | tensor { 436 | dtype: DT_INT32 437 | tensor_shape { 438 | dim { 439 | size: 1 440 | } 441 | } 442 | int_val: 1 443 | } 444 | } 445 | } 446 | } 447 | node { 448 | name: "Slice" 449 | op: "Slice" 450 | input: "Shape_1" 451 | input: "Slice/begin" 452 | input: "Slice/size" 453 | attr { 454 | key: "Index" 455 | value { 456 | type: DT_INT32 457 | } 458 | } 459 | attr { 460 | key: "T" 461 | value { 462 | type: DT_INT32 463 | } 464 | } 465 | } 466 | node { 467 | name: "concat/values_0" 468 | op: "Const" 469 | attr { 470 | key: "dtype" 471 | value { 472 | type: DT_INT32 473 | } 474 | } 475 | attr { 476 | key: "value" 477 | value { 478 | tensor { 479 | dtype: DT_INT32 480 | tensor_shape { 481 | dim { 482 | size: 1 483 | } 484 | } 485 | int_val: -1 486 | } 487 | } 488 | } 489 | } 490 | node { 491 | name: 
"concat/axis" 492 | op: "Const" 493 | attr { 494 | key: "dtype" 495 | value { 496 | type: DT_INT32 497 | } 498 | } 499 | attr { 500 | key: "value" 501 | value { 502 | tensor { 503 | dtype: DT_INT32 504 | tensor_shape { 505 | } 506 | int_val: 0 507 | } 508 | } 509 | } 510 | } 511 | node { 512 | name: "concat" 513 | op: "ConcatV2" 514 | input: "concat/values_0" 515 | input: "Slice" 516 | input: "concat/axis" 517 | attr { 518 | key: "N" 519 | value { 520 | i: 2 521 | } 522 | } 523 | attr { 524 | key: "T" 525 | value { 526 | type: DT_INT32 527 | } 528 | } 529 | attr { 530 | key: "Tidx" 531 | value { 532 | type: DT_INT32 533 | } 534 | } 535 | } 536 | node { 537 | name: "Reshape" 538 | op: "Reshape" 539 | input: "add" 540 | input: "concat" 541 | attr { 542 | key: "T" 543 | value { 544 | type: DT_FLOAT 545 | } 546 | } 547 | attr { 548 | key: "Tshape" 549 | value { 550 | type: DT_INT32 551 | } 552 | } 553 | } 554 | node { 555 | name: "Rank_2" 556 | op: "Const" 557 | attr { 558 | key: "dtype" 559 | value { 560 | type: DT_INT32 561 | } 562 | } 563 | attr { 564 | key: "value" 565 | value { 566 | tensor { 567 | dtype: DT_INT32 568 | tensor_shape { 569 | } 570 | int_val: 2 571 | } 572 | } 573 | } 574 | } 575 | node { 576 | name: "Shape_2" 577 | op: "Shape" 578 | input: "Placeholder" 579 | attr { 580 | key: "T" 581 | value { 582 | type: DT_FLOAT 583 | } 584 | } 585 | attr { 586 | key: "out_type" 587 | value { 588 | type: DT_INT32 589 | } 590 | } 591 | } 592 | node { 593 | name: "Sub_1/y" 594 | op: "Const" 595 | attr { 596 | key: "dtype" 597 | value { 598 | type: DT_INT32 599 | } 600 | } 601 | attr { 602 | key: "value" 603 | value { 604 | tensor { 605 | dtype: DT_INT32 606 | tensor_shape { 607 | } 608 | int_val: 1 609 | } 610 | } 611 | } 612 | } 613 | node { 614 | name: "Sub_1" 615 | op: "Sub" 616 | input: "Rank_2" 617 | input: "Sub_1/y" 618 | attr { 619 | key: "T" 620 | value { 621 | type: DT_INT32 622 | } 623 | } 624 | } 625 | node { 626 | name: "Slice_1/begin" 627 | op: "Pack" 
628 | input: "Sub_1" 629 | attr { 630 | key: "N" 631 | value { 632 | i: 1 633 | } 634 | } 635 | attr { 636 | key: "T" 637 | value { 638 | type: DT_INT32 639 | } 640 | } 641 | attr { 642 | key: "axis" 643 | value { 644 | i: 0 645 | } 646 | } 647 | } 648 | node { 649 | name: "Slice_1/size" 650 | op: "Const" 651 | attr { 652 | key: "dtype" 653 | value { 654 | type: DT_INT32 655 | } 656 | } 657 | attr { 658 | key: "value" 659 | value { 660 | tensor { 661 | dtype: DT_INT32 662 | tensor_shape { 663 | dim { 664 | size: 1 665 | } 666 | } 667 | int_val: 1 668 | } 669 | } 670 | } 671 | } 672 | node { 673 | name: "Slice_1" 674 | op: "Slice" 675 | input: "Shape_2" 676 | input: "Slice_1/begin" 677 | input: "Slice_1/size" 678 | attr { 679 | key: "Index" 680 | value { 681 | type: DT_INT32 682 | } 683 | } 684 | attr { 685 | key: "T" 686 | value { 687 | type: DT_INT32 688 | } 689 | } 690 | } 691 | node { 692 | name: "concat_1/values_0" 693 | op: "Const" 694 | attr { 695 | key: "dtype" 696 | value { 697 | type: DT_INT32 698 | } 699 | } 700 | attr { 701 | key: "value" 702 | value { 703 | tensor { 704 | dtype: DT_INT32 705 | tensor_shape { 706 | dim { 707 | size: 1 708 | } 709 | } 710 | int_val: -1 711 | } 712 | } 713 | } 714 | } 715 | node { 716 | name: "concat_1/axis" 717 | op: "Const" 718 | attr { 719 | key: "dtype" 720 | value { 721 | type: DT_INT32 722 | } 723 | } 724 | attr { 725 | key: "value" 726 | value { 727 | tensor { 728 | dtype: DT_INT32 729 | tensor_shape { 730 | } 731 | int_val: 0 732 | } 733 | } 734 | } 735 | } 736 | node { 737 | name: "concat_1" 738 | op: "ConcatV2" 739 | input: "concat_1/values_0" 740 | input: "Slice_1" 741 | input: "concat_1/axis" 742 | attr { 743 | key: "N" 744 | value { 745 | i: 2 746 | } 747 | } 748 | attr { 749 | key: "T" 750 | value { 751 | type: DT_INT32 752 | } 753 | } 754 | attr { 755 | key: "Tidx" 756 | value { 757 | type: DT_INT32 758 | } 759 | } 760 | } 761 | node { 762 | name: "Reshape_1" 763 | op: "Reshape" 764 | input: "Placeholder" 
765 | input: "concat_1" 766 | attr { 767 | key: "T" 768 | value { 769 | type: DT_FLOAT 770 | } 771 | } 772 | attr { 773 | key: "Tshape" 774 | value { 775 | type: DT_INT32 776 | } 777 | } 778 | } 779 | node { 780 | name: "SoftmaxCrossEntropyWithLogits" 781 | op: "SoftmaxCrossEntropyWithLogits" 782 | input: "Reshape" 783 | input: "Reshape_1" 784 | attr { 785 | key: "T" 786 | value { 787 | type: DT_FLOAT 788 | } 789 | } 790 | } 791 | node { 792 | name: "Sub_2/y" 793 | op: "Const" 794 | attr { 795 | key: "dtype" 796 | value { 797 | type: DT_INT32 798 | } 799 | } 800 | attr { 801 | key: "value" 802 | value { 803 | tensor { 804 | dtype: DT_INT32 805 | tensor_shape { 806 | } 807 | int_val: 1 808 | } 809 | } 810 | } 811 | } 812 | node { 813 | name: "Sub_2" 814 | op: "Sub" 815 | input: "Rank" 816 | input: "Sub_2/y" 817 | attr { 818 | key: "T" 819 | value { 820 | type: DT_INT32 821 | } 822 | } 823 | } 824 | node { 825 | name: "Slice_2/begin" 826 | op: "Const" 827 | attr { 828 | key: "dtype" 829 | value { 830 | type: DT_INT32 831 | } 832 | } 833 | attr { 834 | key: "value" 835 | value { 836 | tensor { 837 | dtype: DT_INT32 838 | tensor_shape { 839 | dim { 840 | size: 1 841 | } 842 | } 843 | int_val: 0 844 | } 845 | } 846 | } 847 | } 848 | node { 849 | name: "Slice_2/size" 850 | op: "Pack" 851 | input: "Sub_2" 852 | attr { 853 | key: "N" 854 | value { 855 | i: 1 856 | } 857 | } 858 | attr { 859 | key: "T" 860 | value { 861 | type: DT_INT32 862 | } 863 | } 864 | attr { 865 | key: "axis" 866 | value { 867 | i: 0 868 | } 869 | } 870 | } 871 | node { 872 | name: "Slice_2" 873 | op: "Slice" 874 | input: "Shape" 875 | input: "Slice_2/begin" 876 | input: "Slice_2/size" 877 | attr { 878 | key: "Index" 879 | value { 880 | type: DT_INT32 881 | } 882 | } 883 | attr { 884 | key: "T" 885 | value { 886 | type: DT_INT32 887 | } 888 | } 889 | } 890 | node { 891 | name: "Reshape_2" 892 | op: "Reshape" 893 | input: "SoftmaxCrossEntropyWithLogits" 894 | input: "Slice_2" 895 | attr { 896 | key: 
"T" 897 | value { 898 | type: DT_FLOAT 899 | } 900 | } 901 | attr { 902 | key: "Tshape" 903 | value { 904 | type: DT_INT32 905 | } 906 | } 907 | } 908 | node { 909 | name: "Const" 910 | op: "Const" 911 | attr { 912 | key: "dtype" 913 | value { 914 | type: DT_INT32 915 | } 916 | } 917 | attr { 918 | key: "value" 919 | value { 920 | tensor { 921 | dtype: DT_INT32 922 | tensor_shape { 923 | dim { 924 | size: 1 925 | } 926 | } 927 | int_val: 0 928 | } 929 | } 930 | } 931 | } 932 | node { 933 | name: "Mean" 934 | op: "Mean" 935 | input: "Reshape_2" 936 | input: "Const" 937 | attr { 938 | key: "T" 939 | value { 940 | type: DT_FLOAT 941 | } 942 | } 943 | attr { 944 | key: "Tidx" 945 | value { 946 | type: DT_INT32 947 | } 948 | } 949 | attr { 950 | key: "keep_dims" 951 | value { 952 | b: false 953 | } 954 | } 955 | } 956 | node { 957 | name: "gradients/Shape" 958 | op: "Const" 959 | attr { 960 | key: "dtype" 961 | value { 962 | type: DT_INT32 963 | } 964 | } 965 | attr { 966 | key: "value" 967 | value { 968 | tensor { 969 | dtype: DT_INT32 970 | tensor_shape { 971 | dim { 972 | } 973 | } 974 | } 975 | } 976 | } 977 | } 978 | node { 979 | name: "gradients/Const" 980 | op: "Const" 981 | attr { 982 | key: "dtype" 983 | value { 984 | type: DT_FLOAT 985 | } 986 | } 987 | attr { 988 | key: "value" 989 | value { 990 | tensor { 991 | dtype: DT_FLOAT 992 | tensor_shape { 993 | } 994 | float_val: 1.0 995 | } 996 | } 997 | } 998 | } 999 | node { 1000 | name: "gradients/Fill" 1001 | op: "Fill" 1002 | input: "gradients/Shape" 1003 | input: "gradients/Const" 1004 | attr { 1005 | key: "T" 1006 | value { 1007 | type: DT_FLOAT 1008 | } 1009 | } 1010 | } 1011 | node { 1012 | name: "gradients/Mean_grad/Reshape/shape" 1013 | op: "Const" 1014 | attr { 1015 | key: "dtype" 1016 | value { 1017 | type: DT_INT32 1018 | } 1019 | } 1020 | attr { 1021 | key: "value" 1022 | value { 1023 | tensor { 1024 | dtype: DT_INT32 1025 | tensor_shape { 1026 | dim { 1027 | size: 1 1028 | } 1029 | } 1030 | int_val: 
1 1031 | } 1032 | } 1033 | } 1034 | } 1035 | node { 1036 | name: "gradients/Mean_grad/Reshape" 1037 | op: "Reshape" 1038 | input: "gradients/Fill" 1039 | input: "gradients/Mean_grad/Reshape/shape" 1040 | attr { 1041 | key: "T" 1042 | value { 1043 | type: DT_FLOAT 1044 | } 1045 | } 1046 | attr { 1047 | key: "Tshape" 1048 | value { 1049 | type: DT_INT32 1050 | } 1051 | } 1052 | } 1053 | node { 1054 | name: "gradients/Mean_grad/Shape" 1055 | op: "Shape" 1056 | input: "Reshape_2" 1057 | attr { 1058 | key: "T" 1059 | value { 1060 | type: DT_FLOAT 1061 | } 1062 | } 1063 | attr { 1064 | key: "out_type" 1065 | value { 1066 | type: DT_INT32 1067 | } 1068 | } 1069 | } 1070 | node { 1071 | name: "gradients/Mean_grad/Tile" 1072 | op: "Tile" 1073 | input: "gradients/Mean_grad/Reshape" 1074 | input: "gradients/Mean_grad/Shape" 1075 | attr { 1076 | key: "T" 1077 | value { 1078 | type: DT_FLOAT 1079 | } 1080 | } 1081 | attr { 1082 | key: "Tmultiples" 1083 | value { 1084 | type: DT_INT32 1085 | } 1086 | } 1087 | } 1088 | node { 1089 | name: "gradients/Mean_grad/Shape_1" 1090 | op: "Shape" 1091 | input: "Reshape_2" 1092 | attr { 1093 | key: "T" 1094 | value { 1095 | type: DT_FLOAT 1096 | } 1097 | } 1098 | attr { 1099 | key: "out_type" 1100 | value { 1101 | type: DT_INT32 1102 | } 1103 | } 1104 | } 1105 | node { 1106 | name: "gradients/Mean_grad/Shape_2" 1107 | op: "Const" 1108 | attr { 1109 | key: "dtype" 1110 | value { 1111 | type: DT_INT32 1112 | } 1113 | } 1114 | attr { 1115 | key: "value" 1116 | value { 1117 | tensor { 1118 | dtype: DT_INT32 1119 | tensor_shape { 1120 | dim { 1121 | } 1122 | } 1123 | } 1124 | } 1125 | } 1126 | } 1127 | node { 1128 | name: "gradients/Mean_grad/Const" 1129 | op: "Const" 1130 | attr { 1131 | key: "dtype" 1132 | value { 1133 | type: DT_INT32 1134 | } 1135 | } 1136 | attr { 1137 | key: "value" 1138 | value { 1139 | tensor { 1140 | dtype: DT_INT32 1141 | tensor_shape { 1142 | dim { 1143 | size: 1 1144 | } 1145 | } 1146 | int_val: 0 1147 | } 1148 | } 
1149 | } 1150 | } 1151 | node { 1152 | name: "gradients/Mean_grad/Prod" 1153 | op: "Prod" 1154 | input: "gradients/Mean_grad/Shape_1" 1155 | input: "gradients/Mean_grad/Const" 1156 | attr { 1157 | key: "T" 1158 | value { 1159 | type: DT_INT32 1160 | } 1161 | } 1162 | attr { 1163 | key: "Tidx" 1164 | value { 1165 | type: DT_INT32 1166 | } 1167 | } 1168 | attr { 1169 | key: "keep_dims" 1170 | value { 1171 | b: false 1172 | } 1173 | } 1174 | } 1175 | node { 1176 | name: "gradients/Mean_grad/Const_1" 1177 | op: "Const" 1178 | attr { 1179 | key: "dtype" 1180 | value { 1181 | type: DT_INT32 1182 | } 1183 | } 1184 | attr { 1185 | key: "value" 1186 | value { 1187 | tensor { 1188 | dtype: DT_INT32 1189 | tensor_shape { 1190 | dim { 1191 | size: 1 1192 | } 1193 | } 1194 | int_val: 0 1195 | } 1196 | } 1197 | } 1198 | } 1199 | node { 1200 | name: "gradients/Mean_grad/Prod_1" 1201 | op: "Prod" 1202 | input: "gradients/Mean_grad/Shape_2" 1203 | input: "gradients/Mean_grad/Const_1" 1204 | attr { 1205 | key: "T" 1206 | value { 1207 | type: DT_INT32 1208 | } 1209 | } 1210 | attr { 1211 | key: "Tidx" 1212 | value { 1213 | type: DT_INT32 1214 | } 1215 | } 1216 | attr { 1217 | key: "keep_dims" 1218 | value { 1219 | b: false 1220 | } 1221 | } 1222 | } 1223 | node { 1224 | name: "gradients/Mean_grad/Maximum/y" 1225 | op: "Const" 1226 | attr { 1227 | key: "dtype" 1228 | value { 1229 | type: DT_INT32 1230 | } 1231 | } 1232 | attr { 1233 | key: "value" 1234 | value { 1235 | tensor { 1236 | dtype: DT_INT32 1237 | tensor_shape { 1238 | } 1239 | int_val: 1 1240 | } 1241 | } 1242 | } 1243 | } 1244 | node { 1245 | name: "gradients/Mean_grad/Maximum" 1246 | op: "Maximum" 1247 | input: "gradients/Mean_grad/Prod_1" 1248 | input: "gradients/Mean_grad/Maximum/y" 1249 | attr { 1250 | key: "T" 1251 | value { 1252 | type: DT_INT32 1253 | } 1254 | } 1255 | } 1256 | node { 1257 | name: "gradients/Mean_grad/floordiv" 1258 | op: "FloorDiv" 1259 | input: "gradients/Mean_grad/Prod" 1260 | input: 
"gradients/Mean_grad/Maximum" 1261 | attr { 1262 | key: "T" 1263 | value { 1264 | type: DT_INT32 1265 | } 1266 | } 1267 | } 1268 | node { 1269 | name: "gradients/Mean_grad/Cast" 1270 | op: "Cast" 1271 | input: "gradients/Mean_grad/floordiv" 1272 | attr { 1273 | key: "DstT" 1274 | value { 1275 | type: DT_FLOAT 1276 | } 1277 | } 1278 | attr { 1279 | key: "SrcT" 1280 | value { 1281 | type: DT_INT32 1282 | } 1283 | } 1284 | } 1285 | node { 1286 | name: "gradients/Mean_grad/truediv" 1287 | op: "RealDiv" 1288 | input: "gradients/Mean_grad/Tile" 1289 | input: "gradients/Mean_grad/Cast" 1290 | attr { 1291 | key: "T" 1292 | value { 1293 | type: DT_FLOAT 1294 | } 1295 | } 1296 | } 1297 | node { 1298 | name: "gradients/Reshape_2_grad/Shape" 1299 | op: "Shape" 1300 | input: "SoftmaxCrossEntropyWithLogits" 1301 | attr { 1302 | key: "T" 1303 | value { 1304 | type: DT_FLOAT 1305 | } 1306 | } 1307 | attr { 1308 | key: "out_type" 1309 | value { 1310 | type: DT_INT32 1311 | } 1312 | } 1313 | } 1314 | node { 1315 | name: "gradients/Reshape_2_grad/Reshape" 1316 | op: "Reshape" 1317 | input: "gradients/Mean_grad/truediv" 1318 | input: "gradients/Reshape_2_grad/Shape" 1319 | attr { 1320 | key: "T" 1321 | value { 1322 | type: DT_FLOAT 1323 | } 1324 | } 1325 | attr { 1326 | key: "Tshape" 1327 | value { 1328 | type: DT_INT32 1329 | } 1330 | } 1331 | } 1332 | node { 1333 | name: "gradients/zeros_like" 1334 | op: "ZerosLike" 1335 | input: "SoftmaxCrossEntropyWithLogits:1" 1336 | attr { 1337 | key: "T" 1338 | value { 1339 | type: DT_FLOAT 1340 | } 1341 | } 1342 | } 1343 | node { 1344 | name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" 1345 | op: "Const" 1346 | attr { 1347 | key: "dtype" 1348 | value { 1349 | type: DT_INT32 1350 | } 1351 | } 1352 | attr { 1353 | key: "value" 1354 | value { 1355 | tensor { 1356 | dtype: DT_INT32 1357 | tensor_shape { 1358 | } 1359 | int_val: -1 1360 | } 1361 | } 1362 | } 1363 | } 1364 | node { 1365 | name: 
"gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" 1366 | op: "ExpandDims" 1367 | input: "gradients/Reshape_2_grad/Reshape" 1368 | input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" 1369 | attr { 1370 | key: "T" 1371 | value { 1372 | type: DT_FLOAT 1373 | } 1374 | } 1375 | attr { 1376 | key: "Tdim" 1377 | value { 1378 | type: DT_INT32 1379 | } 1380 | } 1381 | } 1382 | node { 1383 | name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" 1384 | op: "Mul" 1385 | input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" 1386 | input: "SoftmaxCrossEntropyWithLogits:1" 1387 | attr { 1388 | key: "T" 1389 | value { 1390 | type: DT_FLOAT 1391 | } 1392 | } 1393 | } 1394 | node { 1395 | name: "gradients/Reshape_grad/Shape" 1396 | op: "Shape" 1397 | input: "add" 1398 | attr { 1399 | key: "T" 1400 | value { 1401 | type: DT_FLOAT 1402 | } 1403 | } 1404 | attr { 1405 | key: "out_type" 1406 | value { 1407 | type: DT_INT32 1408 | } 1409 | } 1410 | } 1411 | node { 1412 | name: "gradients/Reshape_grad/Reshape" 1413 | op: "Reshape" 1414 | input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" 1415 | input: "gradients/Reshape_grad/Shape" 1416 | attr { 1417 | key: "T" 1418 | value { 1419 | type: DT_FLOAT 1420 | } 1421 | } 1422 | attr { 1423 | key: "Tshape" 1424 | value { 1425 | type: DT_INT32 1426 | } 1427 | } 1428 | } 1429 | node { 1430 | name: "gradients/add_grad/Shape" 1431 | op: "Shape" 1432 | input: "MatMul" 1433 | attr { 1434 | key: "T" 1435 | value { 1436 | type: DT_FLOAT 1437 | } 1438 | } 1439 | attr { 1440 | key: "out_type" 1441 | value { 1442 | type: DT_INT32 1443 | } 1444 | } 1445 | } 1446 | node { 1447 | name: "gradients/add_grad/Shape_1" 1448 | op: "Const" 1449 | attr { 1450 | key: "dtype" 1451 | value { 1452 | type: DT_INT32 1453 | } 1454 | } 1455 | attr { 1456 | key: "value" 1457 | value { 1458 | tensor { 1459 | dtype: DT_INT32 1460 | tensor_shape { 1461 | dim { 1462 | size: 1 1463 | } 1464 | } 1465 | int_val: 10 1466 | } 1467 | } 1468 
| } 1469 | } 1470 | node { 1471 | name: "gradients/add_grad/BroadcastGradientArgs" 1472 | op: "BroadcastGradientArgs" 1473 | input: "gradients/add_grad/Shape" 1474 | input: "gradients/add_grad/Shape_1" 1475 | attr { 1476 | key: "T" 1477 | value { 1478 | type: DT_INT32 1479 | } 1480 | } 1481 | } 1482 | node { 1483 | name: "gradients/add_grad/Sum" 1484 | op: "Sum" 1485 | input: "gradients/Reshape_grad/Reshape" 1486 | input: "gradients/add_grad/BroadcastGradientArgs" 1487 | attr { 1488 | key: "T" 1489 | value { 1490 | type: DT_FLOAT 1491 | } 1492 | } 1493 | attr { 1494 | key: "Tidx" 1495 | value { 1496 | type: DT_INT32 1497 | } 1498 | } 1499 | attr { 1500 | key: "keep_dims" 1501 | value { 1502 | b: false 1503 | } 1504 | } 1505 | } 1506 | node { 1507 | name: "gradients/add_grad/Reshape" 1508 | op: "Reshape" 1509 | input: "gradients/add_grad/Sum" 1510 | input: "gradients/add_grad/Shape" 1511 | attr { 1512 | key: "T" 1513 | value { 1514 | type: DT_FLOAT 1515 | } 1516 | } 1517 | attr { 1518 | key: "Tshape" 1519 | value { 1520 | type: DT_INT32 1521 | } 1522 | } 1523 | } 1524 | node { 1525 | name: "gradients/add_grad/Sum_1" 1526 | op: "Sum" 1527 | input: "gradients/Reshape_grad/Reshape" 1528 | input: "gradients/add_grad/BroadcastGradientArgs:1" 1529 | attr { 1530 | key: "T" 1531 | value { 1532 | type: DT_FLOAT 1533 | } 1534 | } 1535 | attr { 1536 | key: "Tidx" 1537 | value { 1538 | type: DT_INT32 1539 | } 1540 | } 1541 | attr { 1542 | key: "keep_dims" 1543 | value { 1544 | b: false 1545 | } 1546 | } 1547 | } 1548 | node { 1549 | name: "gradients/add_grad/Reshape_1" 1550 | op: "Reshape" 1551 | input: "gradients/add_grad/Sum_1" 1552 | input: "gradients/add_grad/Shape_1" 1553 | attr { 1554 | key: "T" 1555 | value { 1556 | type: DT_FLOAT 1557 | } 1558 | } 1559 | attr { 1560 | key: "Tshape" 1561 | value { 1562 | type: DT_INT32 1563 | } 1564 | } 1565 | } 1566 | node { 1567 | name: "gradients/add_grad/tuple/group_deps" 1568 | op: "NoOp" 1569 | input: "^gradients/add_grad/Reshape" 
1570 | input: "^gradients/add_grad/Reshape_1" 1571 | } 1572 | node { 1573 | name: "gradients/add_grad/tuple/control_dependency" 1574 | op: "Identity" 1575 | input: "gradients/add_grad/Reshape" 1576 | input: "^gradients/add_grad/tuple/group_deps" 1577 | attr { 1578 | key: "T" 1579 | value { 1580 | type: DT_FLOAT 1581 | } 1582 | } 1583 | attr { 1584 | key: "_class" 1585 | value { 1586 | list { 1587 | s: "loc:@gradients/add_grad/Reshape" 1588 | } 1589 | } 1590 | } 1591 | } 1592 | node { 1593 | name: "gradients/add_grad/tuple/control_dependency_1" 1594 | op: "Identity" 1595 | input: "gradients/add_grad/Reshape_1" 1596 | input: "^gradients/add_grad/tuple/group_deps" 1597 | attr { 1598 | key: "T" 1599 | value { 1600 | type: DT_FLOAT 1601 | } 1602 | } 1603 | attr { 1604 | key: "_class" 1605 | value { 1606 | list { 1607 | s: "loc:@gradients/add_grad/Reshape_1" 1608 | } 1609 | } 1610 | } 1611 | } 1612 | node { 1613 | name: "gradients/MatMul_grad/MatMul" 1614 | op: "MatMul" 1615 | input: "gradients/add_grad/tuple/control_dependency" 1616 | input: "W/read" 1617 | attr { 1618 | key: "T" 1619 | value { 1620 | type: DT_FLOAT 1621 | } 1622 | } 1623 | attr { 1624 | key: "transpose_a" 1625 | value { 1626 | b: false 1627 | } 1628 | } 1629 | attr { 1630 | key: "transpose_b" 1631 | value { 1632 | b: true 1633 | } 1634 | } 1635 | } 1636 | node { 1637 | name: "gradients/MatMul_grad/MatMul_1" 1638 | op: "MatMul" 1639 | input: "x" 1640 | input: "gradients/add_grad/tuple/control_dependency" 1641 | attr { 1642 | key: "T" 1643 | value { 1644 | type: DT_FLOAT 1645 | } 1646 | } 1647 | attr { 1648 | key: "transpose_a" 1649 | value { 1650 | b: true 1651 | } 1652 | } 1653 | attr { 1654 | key: "transpose_b" 1655 | value { 1656 | b: false 1657 | } 1658 | } 1659 | } 1660 | node { 1661 | name: "gradients/MatMul_grad/tuple/group_deps" 1662 | op: "NoOp" 1663 | input: "^gradients/MatMul_grad/MatMul" 1664 | input: "^gradients/MatMul_grad/MatMul_1" 1665 | } 1666 | node { 1667 | name: 
"gradients/MatMul_grad/tuple/control_dependency" 1668 | op: "Identity" 1669 | input: "gradients/MatMul_grad/MatMul" 1670 | input: "^gradients/MatMul_grad/tuple/group_deps" 1671 | attr { 1672 | key: "T" 1673 | value { 1674 | type: DT_FLOAT 1675 | } 1676 | } 1677 | attr { 1678 | key: "_class" 1679 | value { 1680 | list { 1681 | s: "loc:@gradients/MatMul_grad/MatMul" 1682 | } 1683 | } 1684 | } 1685 | } 1686 | node { 1687 | name: "gradients/MatMul_grad/tuple/control_dependency_1" 1688 | op: "Identity" 1689 | input: "gradients/MatMul_grad/MatMul_1" 1690 | input: "^gradients/MatMul_grad/tuple/group_deps" 1691 | attr { 1692 | key: "T" 1693 | value { 1694 | type: DT_FLOAT 1695 | } 1696 | } 1697 | attr { 1698 | key: "_class" 1699 | value { 1700 | list { 1701 | s: "loc:@gradients/MatMul_grad/MatMul_1" 1702 | } 1703 | } 1704 | } 1705 | } 1706 | node { 1707 | name: "GradientDescent/learning_rate" 1708 | op: "Const" 1709 | attr { 1710 | key: "dtype" 1711 | value { 1712 | type: DT_FLOAT 1713 | } 1714 | } 1715 | attr { 1716 | key: "value" 1717 | value { 1718 | tensor { 1719 | dtype: DT_FLOAT 1720 | tensor_shape { 1721 | } 1722 | float_val: 0.5 1723 | } 1724 | } 1725 | } 1726 | } 1727 | node { 1728 | name: "GradientDescent/update_W/ApplyGradientDescent" 1729 | op: "ApplyGradientDescent" 1730 | input: "W" 1731 | input: "GradientDescent/learning_rate" 1732 | input: "gradients/MatMul_grad/tuple/control_dependency_1" 1733 | attr { 1734 | key: "T" 1735 | value { 1736 | type: DT_FLOAT 1737 | } 1738 | } 1739 | attr { 1740 | key: "_class" 1741 | value { 1742 | list { 1743 | s: "loc:@W" 1744 | } 1745 | } 1746 | } 1747 | attr { 1748 | key: "use_locking" 1749 | value { 1750 | b: false 1751 | } 1752 | } 1753 | } 1754 | node { 1755 | name: "GradientDescent/update_b/ApplyGradientDescent" 1756 | op: "ApplyGradientDescent" 1757 | input: "b" 1758 | input: "GradientDescent/learning_rate" 1759 | input: "gradients/add_grad/tuple/control_dependency_1" 1760 | attr { 1761 | key: "T" 1762 | value { 1763 
| type: DT_FLOAT 1764 | } 1765 | } 1766 | attr { 1767 | key: "_class" 1768 | value { 1769 | list { 1770 | s: "loc:@b" 1771 | } 1772 | } 1773 | } 1774 | attr { 1775 | key: "use_locking" 1776 | value { 1777 | b: false 1778 | } 1779 | } 1780 | } 1781 | node { 1782 | name: "GradientDescent" 1783 | op: "NoOp" 1784 | input: "^GradientDescent/update_W/ApplyGradientDescent" 1785 | input: "^GradientDescent/update_b/ApplyGradientDescent" 1786 | } 1787 | node { 1788 | name: "init" 1789 | op: "NoOp" 1790 | input: "^W/Assign" 1791 | input: "^b/Assign" 1792 | } 1793 | node { 1794 | name: "ArgMax/dimension" 1795 | op: "Const" 1796 | attr { 1797 | key: "dtype" 1798 | value { 1799 | type: DT_INT32 1800 | } 1801 | } 1802 | attr { 1803 | key: "value" 1804 | value { 1805 | tensor { 1806 | dtype: DT_INT32 1807 | tensor_shape { 1808 | } 1809 | int_val: 1 1810 | } 1811 | } 1812 | } 1813 | } 1814 | node { 1815 | name: "ArgMax" 1816 | op: "ArgMax" 1817 | input: "add" 1818 | input: "ArgMax/dimension" 1819 | attr { 1820 | key: "T" 1821 | value { 1822 | type: DT_FLOAT 1823 | } 1824 | } 1825 | attr { 1826 | key: "Tidx" 1827 | value { 1828 | type: DT_INT32 1829 | } 1830 | } 1831 | } 1832 | node { 1833 | name: "ArgMax_1/dimension" 1834 | op: "Const" 1835 | attr { 1836 | key: "dtype" 1837 | value { 1838 | type: DT_INT32 1839 | } 1840 | } 1841 | attr { 1842 | key: "value" 1843 | value { 1844 | tensor { 1845 | dtype: DT_INT32 1846 | tensor_shape { 1847 | } 1848 | int_val: 1 1849 | } 1850 | } 1851 | } 1852 | } 1853 | node { 1854 | name: "ArgMax_1" 1855 | op: "ArgMax" 1856 | input: "Placeholder" 1857 | input: "ArgMax_1/dimension" 1858 | attr { 1859 | key: "T" 1860 | value { 1861 | type: DT_FLOAT 1862 | } 1863 | } 1864 | attr { 1865 | key: "Tidx" 1866 | value { 1867 | type: DT_INT32 1868 | } 1869 | } 1870 | } 1871 | node { 1872 | name: "Equal" 1873 | op: "Equal" 1874 | input: "ArgMax" 1875 | input: "ArgMax_1" 1876 | attr { 1877 | key: "T" 1878 | value { 1879 | type: DT_INT64 1880 | } 1881 | } 1882 | } 
1883 | node { 1884 | name: "Cast_1" 1885 | op: "Cast" 1886 | input: "Equal" 1887 | attr { 1888 | key: "DstT" 1889 | value { 1890 | type: DT_FLOAT 1891 | } 1892 | } 1893 | attr { 1894 | key: "SrcT" 1895 | value { 1896 | type: DT_BOOL 1897 | } 1898 | } 1899 | } 1900 | node { 1901 | name: "Const_1" 1902 | op: "Const" 1903 | attr { 1904 | key: "dtype" 1905 | value { 1906 | type: DT_INT32 1907 | } 1908 | } 1909 | attr { 1910 | key: "value" 1911 | value { 1912 | tensor { 1913 | dtype: DT_INT32 1914 | tensor_shape { 1915 | dim { 1916 | size: 1 1917 | } 1918 | } 1919 | int_val: 0 1920 | } 1921 | } 1922 | } 1923 | } 1924 | node { 1925 | name: "Mean_1" 1926 | op: "Mean" 1927 | input: "Cast_1" 1928 | input: "Const_1" 1929 | attr { 1930 | key: "T" 1931 | value { 1932 | type: DT_FLOAT 1933 | } 1934 | } 1935 | attr { 1936 | key: "Tidx" 1937 | value { 1938 | type: DT_INT32 1939 | } 1940 | } 1941 | attr { 1942 | key: "keep_dims" 1943 | value { 1944 | b: false 1945 | } 1946 | } 1947 | } 1948 | node { 1949 | name: "save/Const" 1950 | op: "Const" 1951 | attr { 1952 | key: "dtype" 1953 | value { 1954 | type: DT_STRING 1955 | } 1956 | } 1957 | attr { 1958 | key: "value" 1959 | value { 1960 | tensor { 1961 | dtype: DT_STRING 1962 | tensor_shape { 1963 | } 1964 | string_val: "model" 1965 | } 1966 | } 1967 | } 1968 | } 1969 | node { 1970 | name: "save/SaveV2/tensor_names" 1971 | op: "Const" 1972 | attr { 1973 | key: "dtype" 1974 | value { 1975 | type: DT_STRING 1976 | } 1977 | } 1978 | attr { 1979 | key: "value" 1980 | value { 1981 | tensor { 1982 | dtype: DT_STRING 1983 | tensor_shape { 1984 | dim { 1985 | size: 2 1986 | } 1987 | } 1988 | string_val: "W" 1989 | string_val: "b" 1990 | } 1991 | } 1992 | } 1993 | } 1994 | node { 1995 | name: "save/SaveV2/shape_and_slices" 1996 | op: "Const" 1997 | attr { 1998 | key: "dtype" 1999 | value { 2000 | type: DT_STRING 2001 | } 2002 | } 2003 | attr { 2004 | key: "value" 2005 | value { 2006 | tensor { 2007 | dtype: DT_STRING 2008 | tensor_shape { 
2009 | dim { 2010 | size: 2 2011 | } 2012 | } 2013 | string_val: "" 2014 | string_val: "" 2015 | } 2016 | } 2017 | } 2018 | } 2019 | node { 2020 | name: "save/SaveV2" 2021 | op: "SaveV2" 2022 | input: "save/Const" 2023 | input: "save/SaveV2/tensor_names" 2024 | input: "save/SaveV2/shape_and_slices" 2025 | input: "W" 2026 | input: "b" 2027 | attr { 2028 | key: "dtypes" 2029 | value { 2030 | list { 2031 | type: DT_FLOAT 2032 | type: DT_FLOAT 2033 | } 2034 | } 2035 | } 2036 | } 2037 | node { 2038 | name: "save/control_dependency" 2039 | op: "Identity" 2040 | input: "save/Const" 2041 | input: "^save/SaveV2" 2042 | attr { 2043 | key: "T" 2044 | value { 2045 | type: DT_STRING 2046 | } 2047 | } 2048 | attr { 2049 | key: "_class" 2050 | value { 2051 | list { 2052 | s: "loc:@save/Const" 2053 | } 2054 | } 2055 | } 2056 | } 2057 | node { 2058 | name: "save/RestoreV2/tensor_names" 2059 | op: "Const" 2060 | attr { 2061 | key: "dtype" 2062 | value { 2063 | type: DT_STRING 2064 | } 2065 | } 2066 | attr { 2067 | key: "value" 2068 | value { 2069 | tensor { 2070 | dtype: DT_STRING 2071 | tensor_shape { 2072 | dim { 2073 | size: 1 2074 | } 2075 | } 2076 | string_val: "W" 2077 | } 2078 | } 2079 | } 2080 | } 2081 | node { 2082 | name: "save/RestoreV2/shape_and_slices" 2083 | op: "Const" 2084 | attr { 2085 | key: "dtype" 2086 | value { 2087 | type: DT_STRING 2088 | } 2089 | } 2090 | attr { 2091 | key: "value" 2092 | value { 2093 | tensor { 2094 | dtype: DT_STRING 2095 | tensor_shape { 2096 | dim { 2097 | size: 1 2098 | } 2099 | } 2100 | string_val: "" 2101 | } 2102 | } 2103 | } 2104 | } 2105 | node { 2106 | name: "save/RestoreV2" 2107 | op: "RestoreV2" 2108 | input: "save/Const" 2109 | input: "save/RestoreV2/tensor_names" 2110 | input: "save/RestoreV2/shape_and_slices" 2111 | attr { 2112 | key: "dtypes" 2113 | value { 2114 | list { 2115 | type: DT_FLOAT 2116 | } 2117 | } 2118 | } 2119 | } 2120 | node { 2121 | name: "save/Assign" 2122 | op: "Assign" 2123 | input: "W" 2124 | input: 
"save/RestoreV2" 2125 | attr { 2126 | key: "T" 2127 | value { 2128 | type: DT_FLOAT 2129 | } 2130 | } 2131 | attr { 2132 | key: "_class" 2133 | value { 2134 | list { 2135 | s: "loc:@W" 2136 | } 2137 | } 2138 | } 2139 | attr { 2140 | key: "use_locking" 2141 | value { 2142 | b: true 2143 | } 2144 | } 2145 | attr { 2146 | key: "validate_shape" 2147 | value { 2148 | b: true 2149 | } 2150 | } 2151 | } 2152 | node { 2153 | name: "save/RestoreV2_1/tensor_names" 2154 | op: "Const" 2155 | attr { 2156 | key: "dtype" 2157 | value { 2158 | type: DT_STRING 2159 | } 2160 | } 2161 | attr { 2162 | key: "value" 2163 | value { 2164 | tensor { 2165 | dtype: DT_STRING 2166 | tensor_shape { 2167 | dim { 2168 | size: 1 2169 | } 2170 | } 2171 | string_val: "b" 2172 | } 2173 | } 2174 | } 2175 | } 2176 | node { 2177 | name: "save/RestoreV2_1/shape_and_slices" 2178 | op: "Const" 2179 | attr { 2180 | key: "dtype" 2181 | value { 2182 | type: DT_STRING 2183 | } 2184 | } 2185 | attr { 2186 | key: "value" 2187 | value { 2188 | tensor { 2189 | dtype: DT_STRING 2190 | tensor_shape { 2191 | dim { 2192 | size: 1 2193 | } 2194 | } 2195 | string_val: "" 2196 | } 2197 | } 2198 | } 2199 | } 2200 | node { 2201 | name: "save/RestoreV2_1" 2202 | op: "RestoreV2" 2203 | input: "save/Const" 2204 | input: "save/RestoreV2_1/tensor_names" 2205 | input: "save/RestoreV2_1/shape_and_slices" 2206 | attr { 2207 | key: "dtypes" 2208 | value { 2209 | list { 2210 | type: DT_FLOAT 2211 | } 2212 | } 2213 | } 2214 | } 2215 | node { 2216 | name: "save/Assign_1" 2217 | op: "Assign" 2218 | input: "b" 2219 | input: "save/RestoreV2_1" 2220 | attr { 2221 | key: "T" 2222 | value { 2223 | type: DT_FLOAT 2224 | } 2225 | } 2226 | attr { 2227 | key: "_class" 2228 | value { 2229 | list { 2230 | s: "loc:@b" 2231 | } 2232 | } 2233 | } 2234 | attr { 2235 | key: "use_locking" 2236 | value { 2237 | b: true 2238 | } 2239 | } 2240 | attr { 2241 | key: "validate_shape" 2242 | value { 2243 | b: true 2244 | } 2245 | } 2246 | } 2247 | node { 
2248 | name: "save/restore_all" 2249 | op: "NoOp" 2250 | input: "^save/Assign" 2251 | input: "^save/Assign_1" 2252 | } 2253 | versions { 2254 | producer: 21 2255 | } 2256 | -------------------------------------------------------------------------------- /misc/mnist_much.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import sys 7 | 8 | from tensorflow.examples.tutorials.mnist import input_data 9 | 10 | import tensorflow as tf 11 | 12 | FLAGS = None 13 | 14 | import os 15 | from contextlib import contextmanager 16 | 17 | @contextmanager 18 | def gen(name): 19 | name = os.path.join(os.getcwd(), "{}.pbtxt".format(name)) 20 | g = tf.Graph() 21 | with g.as_default(): 22 | yield 23 | #tf.train.export_meta_graph(graph_def=g.as_graph_def(), filename=name, as_text=True) 24 | tf.train.write_graph(g, '.', name, as_text=True) 25 | 26 | def main(_): 27 | with gen("mnist/mnist_simple"): 28 | # Import data 29 | mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) 30 | 31 | # Create the model 32 | x = tf.placeholder(tf.float32, [None, 784], name="x") 33 | W = tf.Variable(tf.zeros([784, 10]), name="W") 34 | b = tf.Variable(tf.zeros([10]), name="b") 35 | y = tf.matmul(x, W) + b 36 | 37 | # Define loss and optimizer 38 | y_ = tf.placeholder(tf.float32, [None, 10]) 39 | 40 | # The raw formulation of cross-entropy, 41 | # 42 | # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)), 43 | # reduction_indices=[1])) 44 | # 45 | # can be numerically unstable. 46 | # 47 | # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw 48 | # outputs of 'y', and then average across the batch. 
49 | cross_entropy = tf.reduce_mean( 50 | tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) 51 | train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 52 | 53 | sess = tf.InteractiveSession() 54 | tf.global_variables_initializer().run() 55 | # Train 56 | for _ in range(1000): 57 | batch_xs, batch_ys = mnist.train.next_batch(100) 58 | sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) 59 | 60 | # Test trained model 61 | correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 62 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 63 | print(sess.run(accuracy, feed_dict={x: mnist.test.images, 64 | y_: mnist.test.labels})) 65 | tf.train.Saver().save(sess, 'mnist/mnist_simple') 66 | 67 | if __name__ == '__main__': 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data', 70 | help='Directory for storing input data') 71 | FLAGS, unparsed = parser.parse_known_args() 72 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) -------------------------------------------------------------------------------- /misc/mul2vars.pb: -------------------------------------------------------------------------------- 1 | 2 | * 3 | a Placeholder* 4 | dtype0* 5 | shape: 6 | * 7 | b Placeholder* 8 | dtype0* 9 | shape: 10 |  11 | mulMulab* 12 | T0" -------------------------------------------------------------------------------- /misc/mulbymat.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enragedginger/tensorflow-clj/11d7478bd43a9cbe838fb79fde2f2cb1f1004da3/misc/mulbymat.pb -------------------------------------------------------------------------------- /project.clj: -------------------------------------------------------------------------------- 1 | (defproject tensorflow-clj "0.1" 2 | :description "Gateway from Clojure to Tensorflow" 3 | :url 
"https://github.com/enragedginger/tensorflow-clj" 4 | :license {:name "Eclipse Public License" 5 | :url "http://www.eclipse.org/legal/epl-v10.html"} 6 | :dependencies [[org.clojure/clojure "1.8.0"] 7 | [net.mikera/core.matrix "0.58.0"] 8 | [org.tensorflow/tensorflow "1.1.0"] 9 | [org.tensorflow/proto "1.1.0"] 10 | [org.clojars.ghaskins/protobuf "3.0.2-2"] 11 | [com.google.protobuf/protobuf-java "3.2.0"] 12 | [random-string "0.1.0"] 13 | [camel-snake-kebab "0.4.0"] 14 | [ubergraph "0.3.1"]] 15 | :signing {:gpg-key "enragedginger@gmail.com"} 16 | :main ^:skip-aot tensorflow-clj.core 17 | :target-path "target/%s" 18 | :profiles {:uberjar {:aot :all}}) 19 | -------------------------------------------------------------------------------- /src/tensorflow_clj/core.clj: -------------------------------------------------------------------------------- 1 | (ns tensorflow-clj.core 2 | (:require [clojure.core.matrix :as matrix] 3 | [tensorflow-clj.util :as util]) 4 | (:gen-class)) 5 | 6 | (def ^:dynamic graph nil) 7 | (def ^:dynamic session nil) 8 | 9 | (defmacro with-graph-and-session [& body] 10 | `(binding [graph (org.tensorflow.Graph.)] 11 | (try 12 | (binding [session (org.tensorflow.Session. graph)] 13 | (try 14 | ~@body 15 | (finally (.close session)))) 16 | (finally (.close graph))))) 17 | 18 | (defmacro with-graph-file [filename & body] 19 | `(with-graph-and-session 20 | (.importGraphDef graph (util/slurp-binary ~filename)) 21 | ~@body 22 | )) 23 | 24 | (defn tensor [value] 25 | (let [shp (matrix/shape value)] 26 | (if-not shp 27 | (org.tensorflow.Tensor/create (float value)) 28 | (org.tensorflow.Tensor/create 29 | (long-array shp) 30 | (java.nio.FloatBuffer/wrap 31 | (float-array (matrix/to-vector value))))))) 32 | 33 | (defn tensor->clj [t] 34 | (assert (instance? org.tensorflow.Tensor t)) 35 | (let [shp (vec (.shape t))] 36 | (if (empty? 
shp) 37 | (.floatValue t) 38 | (let [buf (java.nio.FloatBuffer/allocate (.numElements t))] 39 | (.writeTo t buf) 40 | (matrix/reshape (vec (.array buf)) 41 | shp))))) 42 | 43 | (defn run-graph [feed-ops & fetch-ops] 44 | (assert session) 45 | (let [runner (.runner session)] 46 | (doseq [[feed-op feed-value] feed-ops] 47 | (if feed-value 48 | (.feed runner (name feed-op) (tensor feed-value)) 49 | (.addTarget runner (name feed-op)))) 50 | (doseq [fetch-op fetch-ops] 51 | (.fetch runner (name fetch-op))) 52 | (vec (map tensor->clj (.run runner))))) 53 | -------------------------------------------------------------------------------- /src/tensorflow_clj/experimental.clj: -------------------------------------------------------------------------------- 1 | (ns tensorflow-clj.experimental 2 | (:require [clojure.core.matrix :as matrix] 3 | [tensorflow-clj.core :as core] 4 | [tensorflow-clj.util :as util])) 5 | 6 | (defn exec-graph-sess-fn [graph-sess-fn] 7 | (let [graph (org.tensorflow.Graph.)] 8 | (try 9 | (let [session (org.tensorflow.Session. graph)] 10 | (try 11 | (graph-sess-fn graph session) 12 | (finally (.close session)))) 13 | (finally (.close graph))))) 14 | 15 | (defn load-graph! 
[graph filename] 16 | (.importGraphDef graph (util/slurp-binary filename))) 17 | 18 | (defn run-graph-thing [session feed-ops & fetch-ops] 19 | (let [runner (.runner session)] 20 | (doseq [[feed-op feed-value] feed-ops] 21 | (if feed-value 22 | (.feed runner (name feed-op) (core/tensor feed-value)) 23 | (.addTarget runner (name feed-op)))) 24 | (doseq [fetch-op fetch-ops] 25 | (.fetch runner (name fetch-op))) 26 | (vec (map core/tensor->clj (.run runner))))) 27 | -------------------------------------------------------------------------------- /src/tensorflow_clj/graph/attributes.clj: -------------------------------------------------------------------------------- 1 | (ns tensorflow-clj.graph.attributes 2 | (require [tensorflow-clj.util :refer [assoc-not-empty assoc-in-not-empty]])) 3 | 4 | (def tf-data-types 5 | [ 6 | { 7 | :checker-fn float? 8 | :val-key :float-val 9 | :tf-enum-types #{"DT_FLOAT" "DT_DOUBLE" "DT_BFLOAT16" 10 | "DT_FLOAT_REF" "DT_DOUBLE_REF" "DT_BFLOAT16_REF"} 11 | } 12 | { 13 | :checker-fn integer? 14 | :val-key :int-val 15 | :tf-enum-types #{"DT_INT32" "DT_UINT8" "DT_INT16" "DT_INT8" "DT_INT64" "DT_QINT8" "DT_QUINT8" "DT_QINT32" "DT_QINT16" "DT_QUINT16" "DT_UINT16" 16 | "DT_INT32_REF" "DT_UINT8_REF" "DT_INT16_REF" "DT_INT8_REF" "DT_INT64_REF" "DT_QINT8_REF" "DT_QUINT8_REF" 17 | "DT_QINT32_REF" "DT_QINT16_REF" "DT_QUINT16_REF" "DT_UINT16_REF"} 18 | } 19 | { 20 | :checker-fn (fn [x] (or (= x true) (= x false))) 21 | :val-key :bool-val 22 | :tf-enum-types #{"DT_BOOL" "DT_BOOL_REF"} 23 | } 24 | { 25 | :checker-fn string? 26 | :val-key :string-val 27 | :tf-enum-types #{"DT_STRING_REF"} 28 | } 29 | ] 30 | ;DT_COMPLEX64(8), DT_COMPLEX128(18), DT_HALF(19), DT_RESOURCE(20), DT_COMPLEX64_REF(108), DT_COMPLEX128_REF(118), 31 | ;DT_HALF_REF(119), DT_RESOURCE_REF(120), UNRECOGNIZED(-1) 32 | ) 33 | 34 | (defn lookup-by-dtype [dtype] 35 | (let [matches (filter #(contains? 
(:tf-enum-types %) dtype) tf-data-types)] 36 | (first matches))) 37 | 38 | (defn find-dtype [attr] 39 | (or (-> attr :T :type) 40 | (-> attr :dtype :type) 41 | (-> attr :value :tensor :dtype))) 42 | 43 | (defn build-dims [dims] 44 | (mapv #(assoc {} :size %) dims)) 45 | 46 | (defn build-attr [k v] 47 | ;{:key k :value v} 48 | {k v}) 49 | 50 | ;;todo should value be "values"? do we need to add support for that? 51 | (defn build-attr-value [value value-dtype dims] 52 | (let [attr (build-attr :value { 53 | :tensor { 54 | :dtype value-dtype 55 | :tensor_shape {} 56 | } 57 | }) 58 | val-key (-> value-dtype lookup-by-dtype :val-key)] 59 | (-> attr 60 | (assoc-in-not-empty [:value :tensor :tensor_shape :dim] (build-dims dims)) 61 | ;;todo add support for nil values? 62 | (assoc-in [:value :tensor val-key] [value])))) 63 | 64 | (defn build-attr-dtype [dtype] 65 | (build-attr :dtype { :type dtype })) 66 | 67 | (defn build-attr-n [val] 68 | (build-attr :N { :i val })) 69 | 70 | (defn build-attr-t [dtype] 71 | (build-attr :T { :type dtype })) 72 | 73 | (defn build-attr-tidx [dtype] 74 | (build-attr :Tidx { :type dtype })) 75 | 76 | (defn build-attr-tshape [dtype] 77 | (build-attr :Tshape { :type dtype })) 78 | 79 | (defn build-attr-out-type [dtype] 80 | (build-attr :out_type { :type dtype })) 81 | 82 | (defn build-attr-axis [val] 83 | (build-attr :axis { :i val })) 84 | 85 | (defn build-attr-index [dtype] 86 | (build-attr :Index { :type dtype })) 87 | 88 | (defn build-attr-shape [dims] 89 | (build-attr :shape {:shape { 90 | :dim (build-dims dims) 91 | }})) -------------------------------------------------------------------------------- /src/tensorflow_clj/graph/gradients.clj: -------------------------------------------------------------------------------- 1 | (ns tensorflow-clj.graph.gradients 2 | (require [tensorflow-clj.graph.variables :as variables] 3 | [tensorflow-clj.graph.node_defs :refer :all] 4 | [ubergraph.core :as uber] 5 | [ubergraph.alg :as uber-alg])) 6 | 7 | 
(defn build-reverse-node-pointers
8 |   "Returns a seq of [node-name input-name] pairs, one per entry in (:inputs node),
9 |   i.e. edges pointing from `node` to each node it reads from."
10 |   [node]
11 |   (let [node-name (:name node)]
12 |     (map (fn [input-name] [node-name input-name]) (:inputs node))))
13 | 
14 | ;;The list of ops that we can gradientize
15 | (def gradientable-ops #{"Add", "MatMul", "Mean", "Reshape", "SoftmaxCrossEntropyWithLogits", "Sub"})
16 | 
17 | (defn find-gradientable-nodes
18 |   "Returns the names of the nodes whose :op is in `gradientable-ops` and from which a
19 |   variable node (as reported by variables/find-variable-nodes) can be reached by
20 |   following input edges."
21 |   [nodes]
22 |   (let [pointers (mapcat build-reverse-node-pointers nodes)
23 |         var-node-names (into #{} (map :variable) (variables/find-variable-nodes nodes))
24 |         node-names (map :name (filter #(contains? gradientable-ops (:op %)) nodes))
25 |         uber-graph (apply uber/digraph pointers)]
26 |     ;; keep a gradientable node only if a path from it to some variable exists
27 |     (filter #(uber-alg/shortest-path uber-graph {:start-node % :end-nodes var-node-names})
28 |             node-names)))
29 | 
30 | ;;todo either use defined gradient or use symbolic gradient
31 | ;def _SymGrad(op, out_grads):
32 | ;"""Backprop through a function call node op given its outputs' gradients."""
33 | ;f_in = [x for x in op.inputs] + out_grads
34 | ;f_types = [x.dtype for x in op.inputs]
35 | ;f = attr_value_pb2.NameAttrList()
36 | ;f.name = op.type
37 | ;for k in op.node_def.attr:
38 | ;f.attr[k].CopyFrom(op.node_def.attr[k])
39 | ;# pylint: disable=protected-access
40 | ;in_grads = functional_ops._symbolic_gradient(input=f_in, Tout=f_types, f=f)
41 | ;# pylint: enable=protected-access
42 | ;return in_grads
43 | 
44 | ;; Per-op gradient builders dispatching on (:op node); entries mirror functions
45 | ;; annotated with @ops.RegisterGradient(op). Named build-node-gradient so that the
46 | ;; build-nodes-gradient defn below no longer rebinds this var and silently replaces
47 | ;; the multimethod (and every defmethod registered on it).
48 | (defmulti build-node-gradient (fn [node] (:op node)))
49 | 
50 | ;build entries from sources files like tensorflow/python/ops/math_grad.py
51 | (defmethod build-node-gradient "MatMul" [node]
52 |   "something")
53 | 
54 | ;(build-node-gradient {:op "MatMul"})
55 | ;(ns-unmap *ns* 'build-node-gradient)
56 | 
57 | (defn build-nodes-gradient
58 |   "Not implemented yet; always returns nil."
59 |   [nodes target-node]
60 |   nil)
61 | 
62 | (defn build-nodes-gradient-descent-optimizer
63 |   "Work in progress: builds (but does not yet use) shape/const/fill nodes and
64 |   returns the names of the gradientable nodes found by find-gradientable-nodes."
65 |   [nodes input-node]
66 |   (let [var-refs (variables/find-variable-nodes nodes)
67 |         gradientable-nodes (find-gradientable-nodes nodes)
68 |         gradients-shape-node (build-node-const nil "DT_INT32" [])
69 |         gradients-const-node (build-node-const 1.0 "DT_FLOAT" [])
70 |         gradients-fill-node (build-node-fill gradients-shape-node gradients-const-node)
71 | 
72 |         ;;ExpandDims of input-node?!?!?!
73 |         ]
74 |     gradientable-nodes))
--------------------------------------------------------------------------------
/src/tensorflow_clj/graph/node_defs.clj:
--------------------------------------------------------------------------------
1 | (ns tensorflow-clj.graph.node_defs
2 |   (:require [clojure.string :as str]
3 |             [tensorflow-clj.graph.attributes :refer :all]
4 |             [random-string.core :as randy-str]))
5 | ;; gen-name returns prefix unchanged when preserve? is truthy, else prefix + "_" + (randy-str/string 16).
6 | (defn gen-name [prefix preserve?]
7 |   (if preserve?
8 |     prefix
9 |     (str prefix "_" (randy-str/string 16))))
10 | 
11 | (defn build-node-name
12 |   "Returns :name, or a random op-based name (via gen-name) when :name is absent; when :prefix is given the result is prefix/name in both cases."
13 |   [op & {:keys [name prefix]}]
14 |   (let [base-name (or name (gen-name op false))]
15 |     (if prefix
16 |       (str/join "/" [prefix base-name])
17 |       base-name)))
18 | 
19 | ;; Assembles a node map from the merged attr / meta-attr maps; entries whose value is nil or false are dropped.
20 | (defn build-node [op & {:keys [name inputs control-deps attr meta-attr]}]
21 |   (let [node {:op op
22 |               :name (or name (build-node-name op))
23 |               :inputs inputs
24 |               :control-deps control-deps
25 |               :attr (apply merge attr)
26 |               :meta-attr (apply merge meta-attr)}]
27 |     (into {} (filter second node))))
28 | 
29 | (defn build-node-placeholder [dtype & {:keys [name prefix]}]
30 |   (let [op "Placeholder"
31 |         attr-dtype (build-attr-dtype dtype)
32 |         ;;todo attr-shape?!?!?!
33 |         fullname (build-node-name op :name name :prefix prefix)]
34 |     (build-node op :name fullname :attr [attr-dtype])))
35 | 
36 | (defn build-node-const [value value-dtype dims & {:keys [name prefix]}]
37 |   (let [op "Const"
38 |         attr-dtype (build-attr-dtype value-dtype)
39 |         attr-value (build-attr-value value value-dtype dims)
40 |         fullname (build-node-name op :name name :prefix prefix)
41 |         base (build-node op :name fullname :attr [attr-dtype attr-value])]
42 |     base))
43 | 
44 | (defn build-node-variable [dims value-dtype & {:keys [name prefix]}]
45 |   (let [op "VariableV2"
46 |         fullname (build-node-name op :name name :prefix prefix)
47 |         attr-dtype (build-attr-dtype value-dtype)
48 |         attr-shape (build-attr-shape dims)
49 |         base (build-node op :name fullname :attr [attr-dtype attr-shape])]
50 |     base))
51 | ;; Assign / Identity nodes are named "<variable>/Assign" and "<variable>/read"; T comes from the variable's dtype.
52 | (defn build-node-assign [variable value]
53 |   (let [op "Assign"
54 |         fullname (str/join "/" [(-> variable :name) "Assign"])
55 |         inputs (mapv :name [variable value])
56 |         attr-t (build-attr-t (-> variable :attr find-dtype))
57 |         base (build-node op :name fullname :inputs inputs :attr [attr-t])]
58 |     base))
59 | 
60 | (defn build-node-identity [target]
61 |   (let [op "Identity"
62 |         fullname (str/join "/" [(-> target :name) "read"])
63 |         attr-t (-> target :attr find-dtype build-attr-t)
64 |         inputs [(-> target :name)]]
65 |     (build-node op :name fullname :inputs inputs :attr [attr-t])))
66 | ;; Binary ops below take their T attr from the first operand's dtype.
67 | (defn build-node-matmul [x y]
68 |   (let [op "MatMul"
69 |         inputs (mapv :name [x y])
70 |         attr-t (-> x :attr find-dtype build-attr-t)]
71 |     (build-node op :inputs inputs :attr [attr-t])))
72 | 
73 | (defn build-node-add [x y]
74 |   (let [op "Add"
75 |         inputs (mapv :name [x y])
76 |         attr-t (-> x :attr find-dtype build-attr-t)]
77 |     (build-node op :inputs inputs :attr [attr-t])))
78 | 
79 | (defn build-node-sub [x y]
80 |   (let [op "Sub"
81 |         inputs (mapv :name [x y])
82 |         attr-t (-> x :attr find-dtype build-attr-t)]
83 |     (build-node op :inputs inputs :attr [attr-t])))
84 | 
85 |
(defn build-node-slice
86 |   "Slice node with inputs [input-node begin-node size-node]. Index is DT_INT32;
87 |   T (input-node's element dtype) defaults to DT_INT32 and can be set with :dtype."
88 |   [input-node begin-node size-node & {:keys [dtype] :or {dtype "DT_INT32"}}]
89 |   (let [op "Slice"
90 |         inputs (mapv :name [input-node begin-node size-node])
91 |         attr-index (build-attr-index "DT_INT32")
92 |         attr-t (build-attr-t dtype)]
93 |     (build-node op :inputs inputs :attr [attr-index attr-t])))
94 | 
95 | (defn build-node-concat-v2
96 |   "ConcatV2 node over value-nodes followed by axis-node. N is (count value-nodes) and
97 |   Tidx is DT_INT32; T defaults to DT_INT32 and can be set with :dtype."
98 |   [value-nodes axis-node & {:keys [dtype] :or {dtype "DT_INT32"}}]
99 |   (let [op "ConcatV2"
100 |         inputs (mapv :name (concat value-nodes [axis-node]))
101 |         attr-n (build-attr-n (count value-nodes))
102 |         attr-t (build-attr-t dtype)
103 |         attr-tidx (build-attr-tidx "DT_INT32")]
104 |     (build-node op :inputs inputs :attr [attr-n attr-t attr-tidx])))
105 | 
106 | (defn build-node-reshape
107 |   "Reshape node; T follows tensor-node's dtype, Tshape is DT_INT32."
108 |   [tensor-node shape-node]
109 |   (let [op "Reshape"
110 |         inputs (mapv :name [tensor-node shape-node])
111 |         attr-t (-> tensor-node :attr find-dtype build-attr-t)
112 |         attr-tshape (build-attr-tshape "DT_INT32")]
113 |     (build-node op :inputs inputs :attr [attr-t attr-tshape])))
114 | 
115 | (defn build-node-shape
116 |   "Shape node; T follows input-node's dtype, out_type is DT_INT32."
117 |   [input-node]
118 |   (let [op "Shape"
119 |         inputs (mapv :name [input-node])
120 |         attr-t (-> input-node :attr find-dtype build-attr-t)
121 |         attr-out-type (build-attr-out-type "DT_INT32")]
122 |     (build-node op :inputs inputs :attr [attr-t attr-out-type])))
123 | 
124 | (defn build-node-pack
125 |   "Pack node stacking input-nodes along axis 0. T defaults to DT_INT32 and can be
126 |   set with :dtype."
127 |   [input-nodes & {:keys [dtype] :or {dtype "DT_INT32"}}]
128 |   (let [op "Pack"
129 |         inputs (mapv :name input-nodes)
130 |         attr-n (build-attr-n (count input-nodes))
131 |         attr-t (build-attr-t dtype)
132 |         attr-axis (build-attr-axis 0)]
133 |     (build-node op :inputs inputs :attr [attr-n attr-t attr-axis])))
134 | 
135 | (defn build-node-softmax-cross-entropy-with-logits
136 |   "SoftmaxCrossEntropyWithLogits node. The op's inputs are (features, labels), so the
137 |   logits node comes first and the labels node second; T follows the logits node."
138 |   [logits-node labels-node]
139 |   (let [op "SoftmaxCrossEntropyWithLogits"
140 |         inputs (mapv :name [logits-node labels-node])
141 |         attr-t (-> logits-node :attr find-dtype build-attr-t)]
142 |     (build-node op :inputs inputs :attr [attr-t])))
143 | 
144 | (defn build-node-reduce-mean
145 |   "Mean node reducing input-node over the indices held by reduction-indices-node;
146 |   T follows input-node, Tidx is DT_INT32."
147 |   [input-node reduction-indices-node]
148 |   (let [op "Mean"
149 |         inputs (mapv :name [input-node reduction-indices-node])
150 |         attr-t (-> input-node :attr find-dtype build-attr-t)
151 |         ;;todo build keep_dims attr?
152 |         attr-tidx (build-attr-tidx "DT_INT32")]
153 |     (build-node op :inputs inputs :attr [attr-t attr-tidx])))
154 | 
155 | (defn build-node-fill
156 |   "Fill node with inputs [dims-node value-node]; T follows value-node's dtype."
157 |   [dims-node value-node]
158 |   (let [op "Fill"
159 |         inputs (mapv :name [dims-node value-node])
160 |         attr-t (-> value-node :attr find-dtype build-attr-t)]
161 |     (build-node op :inputs inputs :attr [attr-t])))
162 | 
163 | (defn build-node-apply-gradient-descent
164 |   "ApplyGradientDescent node with inputs [input-node alpha-node delta-node]; T follows input-node."
165 |   [input-node alpha-node delta-node]
166 |   (let [op "ApplyGradientDescent"
167 |         inputs (mapv :name [input-node alpha-node delta-node])
168 |         attr-t (-> input-node :attr find-dtype build-attr-t)]
169 |     (build-node op :inputs inputs :attr [attr-t])))
--------------------------------------------------------------------------------
/src/tensorflow_clj/graph/proto_much.clj:
--------------------------------------------------------------------------------
1 | (ns tensorflow-clj.graph.proto-much
2 |   (:require [flatland.protobuf.core :as proto]
3 |             [tensorflow-clj.util :as util]
4 |             [random-string.core :as randy-str])
5 |   (:import (org.tensorflow.framework GraphDef)))
6 | 
7 | (def proto-meta-graph-def (proto/protodef org.tensorflow.framework.MetaGraphDef))
8 | (def proto-graph-def (proto/protodef org.tensorflow.framework.GraphDef))
9 | (def graph-node (proto/protodef org.tensorflow.framework.NodeDef))
10 | 
11 | (defn graph-to-bytes
12 |   "Serializes a GraphDef-shaped Clojure map (e.g. {:node [...] :versions {...}})
13 |   to protobuf bytes via flatland.protobuf."
14 |   [graph]
15 |   ;; proto/protobuf takes the field values as alternating key/value arguments
16 |   (proto/protobuf-dump (apply proto/protobuf proto-graph-def (apply concat graph))))
17 | 
18 | (defn byte-string-to-string
19 |   "Decodes a protobuf ByteString as UTF-8. Hinted with the public ByteString class
20 |   rather than an internal implementation subclass."
21 |   [^com.google.protobuf.ByteString byte-string-literal]
22 |   (.toStringUtf8 byte-string-literal))
23 | 
--------------------------------------------------------------------------------
/src/tensorflow_clj/graph/transform.clj:
--------------------------------------------------------------------------------
1 | (ns tensorflow-clj.graph.transform
2 |   (:require [clojure.string :as str]
3 |             [tensorflow-clj.util :refer [assoc-not-empty assoc-in-not-empty]]))
4 | 
5 | ;;Control dependencies start with a caret, apparently
6 | (defn is-control-dep-name [name]
7 |   (str/starts-with? name "^"))
8 | 
9 | (defn drop-caret [name]
10 |   (subs name 1))
11 | 
12 | (defn add-caret [name]
13 |   (str "^" name))
14 | 
15 | ;;Transform node defs from TensorFlow into something suitable for us to play with (and back again) without
16 | ;;pretending to know everything about the structure / content of the map
17 | (defn tensorflow-node->clj-node
18 |   "Splits :input into :inputs and :control-deps (caret stripped) and turns the
19 |   [{:key k :value v} ...] :attr list into a {k v} map; other keys are kept as-is."
20 |   [node]
21 |   (let [raw-inputs (:input node)
22 |         inputs (remove is-control-dep-name raw-inputs)
23 |         control-deps (mapv drop-caret (filter is-control-dep-name raw-inputs))
24 |         converted-attrs (into {} (map (juxt :key :value)) (:attr node))]
25 |     (-> node
26 |         (dissoc :input)
27 |         (assoc :attr converted-attrs)
28 |         (assoc-not-empty :inputs inputs)
29 |         (assoc-not-empty :control-deps control-deps))))
30 | 
31 | (defn clj-node->tensorflow-node
32 |   "Rebuilds :input from :inputs plus caret-prefixed :control-deps and turns the :attr
33 |   map back into a vector of {:key (name k) :value v}; other keys are kept as-is."
34 |   [node]
35 |   (let [input (concat (:inputs node) (map add-caret (:control-deps node)))
36 |         converted-attrs (mapv (fn [[k v]] {:key (name k) :value v}) (:attr node))]
37 |     (-> node
38 |         (dissoc :inputs :control-deps)
39 |         (assoc :attr converted-attrs)
40 |         (assoc-not-empty :input input))))
--------------------------------------------------------------------------------
/src/tensorflow_clj/graph/variables.clj:
--------------------------------------------------------------------------------
1 | (ns tensorflow-clj.graph.variables)
2 | 
3 | (defn build-meta-attr-variable
4 |   "Bundles the node names that make up one variable."
5 |   [trainable? variable assign identity-name]
6 |   {:variable variable
7 |    :assign assign
8 |    :identity identity-name
9 |    :trainable? trainable?})
10 | 
11 | (defn find-node-ref
12 |   "Returns the first node whose :op equals `op` and whose :inputs contain
13 |   input-node-name, or nil when there is none."
14 |   [op nodes input-node-name]
15 |   (first (filter (fn [node]
16 |                    (and (= op (:op node))
17 |                         (some #(= input-node-name %) (:inputs node))))
18 |                  nodes)))
19 | 
20 | (defn build-meta-var-ref
21 |   "Finds the Identity and Assign nodes whose inputs include var-node-name and
22 |   returns their names via build-meta-attr-variable."
23 |   [nodes trainable? var-node-name]
24 |   (let [identity-node (find-node-ref "Identity" nodes var-node-name)
25 |         assign-node (find-node-ref "Assign" nodes var-node-name)]
26 |     (build-meta-attr-variable trainable? var-node-name (:name assign-node) (:name identity-node))))
27 | 
28 | (defn find-variable-nodes
29 |   "Returns a meta-var ref (marked trainable) for every VariableV2 node in nodes."
30 |   [nodes]
31 |   (map #(build-meta-var-ref nodes true (:name %))
32 |        (filter #(= "VariableV2" (:op %)) nodes)))
--------------------------------------------------------------------------------
/src/tensorflow_clj/graph_ops.clj:
--------------------------------------------------------------------------------
1 | (ns tensorflow-clj.graph-ops
2 |   (:require [flatland.protobuf.core :as proto]
3 |             [tensorflow-clj.experimental :as exp]
4 |             [tensorflow-clj.util :as util]
5 |             [camel-snake-kebab.core :as csk])
6 |   (:import
7 |     [org.tensorflow.framework OpList OpList$Builder GraphDef
8 |      OpDef OpDef$ArgDef OpDef$AttrDef AttrValue ConfigProto DataType AttrValue$ListValue]
9 |     [com.google.protobuf TextFormat]))
10 | 
11 | (defn build-op-def-map*
12 |   "Get all registered operation definitions (like TF_GetAllOpList in the C API)
13 |   as a map of op name -> OpDef, parsed from resources/ops.pbtxt.
14 |   Useful for auto generating operations.
15 |   Note: the path is resolved against the JVM working directory, not the classpath."
16 |   []
17 |   (let [op-list-protobuf-src (slurp "resources/ops.pbtxt")
18 |         op-list-builder (OpList/newBuilder)
19 |         _ (TextFormat/merge ^java.lang.CharSequence op-list-protobuf-src op-list-builder)
20 |         op-list (-> op-list-builder .build .getOpList)
21 |         name-keys (map #(.getName ^OpDef %) op-list)]
22 |     (zipmap name-keys op-list)))
23 | 
24 | ;; Memoized: ops.pbtxt is read and parsed only on the first call.
25 | (def op-def-map (memoize build-op-def-map*))
26 | 
27 | (defn get-op-def
28 |   "Get operation definition from ops.pbtxt"
29 |   [op-name]
30 |   (get (op-def-map) op-name))
31 | 
32 | (defn keywordize-name
33 |   [name]
34 |   (keyword (csk/->kebab-case name)))
35 | 
36 | (defn data-type->map
37 |   [^DataType dt-def]
38 |   {:name (.name dt-def)
39 |    :number (.getNumber dt-def)
40 |    :name-key (keywordize-name (.name dt-def))})
41 | 
42 | (defn attr-value->map
43 |   [^AttrValue attr-value]
44 |   {:type (data-type->map (.getType attr-value))
45 |    :list (mapv data-type->map (.getTypeList (.getList attr-value)))})
46 | 
47 | (defn attr-def->map
48 |   [^OpDef$AttrDef attr-def]
49 |   {:name (.getName attr-def)
50 |    :description (.getDescription attr-def)
51 |    ;; clj-tf style name
52 |    :name-key (keywordize-name (.getName attr-def))
53 |    :type (.getType attr-def)
54 |    :has-minimum (.getHasMinimum attr-def)
55 |    :minimum (.getMinimum attr-def)
56 |    :allowed-values (attr-value->map (.getAllowedValues attr-def))
57 |    :default-value (attr-value->map (.getDefaultValue attr-def))})
58 | 
59 | (defn arg-def->map
60 |   [^OpDef$ArgDef arg-def]
61 |   {:name (.getName arg-def)
62 |    :description (.getDescription arg-def)
63 |    ;; clj-tf style name
64 |    :keywordized-name (keywordize-name (.getName arg-def))
65 |    :number-attr (.getNumberAttr arg-def)
66 |    :type (data-type->map (.getType arg-def))
67 |    :type-attr (.getTypeAttr arg-def)
68 |    :type-list-attr (.getTypeListAttr arg-def)
69 |    :type-value (.getTypeValue arg-def)
70 |    :is-ref (.getIsRef arg-def)})
71 | 
72 | (defn op-def->map
73 |   "Get description map of a TensorFlow operation definition."
74 |   [^OpDef op-def]
75 |   {:name (.getName op-def)
76 |    :summary (.getSummary op-def)
77 |    :description (.getDescription op-def)
78 |    :attributes (mapv attr-def->map (.getAttrList op-def))
79 |    :inputs (mapv arg-def->map (.getInputArgList op-def))
80 |    :outputs (mapv arg-def->map (.getOutputArgList op-def))})
81 | 
82 | ;(op-def->map (get-op-def "Mul"))
83 | ;(op-def->map (get-op-def "ApplyAdagradDA"))
84 | ;(op-def->map (get-op-def "SparseApplyAdadelta"))
85 | 
--------------------------------------------------------------------------------
/src/tensorflow_clj/graph_playground.clj:
--------------------------------------------------------------------------------
1 | (ns tensorflow-clj.graph-playground
2 |   (:require [tensorflow-clj.graph.proto-much :as proto-much]
3 |             [tensorflow-clj.experimental :as exp]
4 |             [flatland.protobuf.core :as proto]
5 |             [tensorflow-clj.util :as util]
6 |             [tensorflow-clj.graph.node_defs :refer :all]
7 |             [tensorflow-clj.graph.transform :refer :all]
8 |             [tensorflow-clj.util :refer [assoc-not-empty assoc-in-not-empty]]
9 |             [clojure.string :as str]))
10 | ;TODO this is a playground namespace for now
11 | ;pieces will be hacked on here in isolation and then moved out to other namespaces
12 | ;where they can be re-used once they are considered semi-stable
13 | ;much of this code is commented out as I'm just running it once here or there
14 | ;during development.
anything that I think *might* be useful will be thrown into a function
15 | ;any functions that prove to have some utility will be grouped and moved to their own ns
16 | ;; Returns [variable const-init assign identity] nodes for a variable called `name`.
17 | (defn def-tensor-nodes [name value dtype dims]
18 |   (let [target (build-node-variable dims dtype :name name)
19 |         value-node (build-node-const value dtype dims :name "zeros" :prefix name)
20 |         assign-node (build-node-assign target value-node)
21 |         identity-node (build-node-identity target)]
22 |     [target value-node assign-node identity-node]))
23 | ;; y = xW + b with x [10 784], W [784 10] and b [10]: DT_FLOAT variables with 0.0 initial values.
24 | (defn build-nodes-y-mx-b []
25 |   (let [x-nodes (def-tensor-nodes "x" 0.0 "DT_FLOAT" [10 784])
26 |         W-nodes (def-tensor-nodes "W" 0.0 "DT_FLOAT" [784 10])
27 |         b-nodes (def-tensor-nodes "b" 0.0 "DT_FLOAT" [10])
28 |         matmul-node (build-node-matmul (last x-nodes) (last W-nodes))
29 |         add-node (build-node-add matmul-node (last b-nodes))]
30 |     (concat x-nodes W-nodes b-nodes [matmul-node add-node])))
31 | ;; Reshapes input-node and placeholder-node to [-1 last-dim] (rank hard-coded as 2), feeds both (input-node first) to SoftmaxCrossEntropyWithLogits and returns placeholder-node plus every node built.
32 | (defn build-nodes-softmax-cross-entropy-with-logits [input-node placeholder-node]
33 |   (let [rank-0-node (build-node-const 2 "DT_INT32" [])
34 |         shape-0-node (build-node-shape input-node)
35 |         rank-1-node (build-node-const 2 "DT_INT32" [])
36 |         shape-1-node (build-node-shape input-node)
37 |         sub-y-node (build-node-const 1 "DT_INT32" [])
38 |         sub-node (build-node-sub rank-0-node sub-y-node)
39 |         slice-begin-node (build-node-pack [sub-node])
40 |         slice-size-node (build-node-const 1 "DT_INT32" [1])
41 |         slice-node (build-node-slice shape-1-node slice-begin-node slice-size-node)
42 |         concat-values-0-node (build-node-const -1 "DT_INT32" [1])
43 |         concat-axis-node (build-node-const 0 "DT_INT32" [])
44 |         concat-node (build-node-concat-v2 [concat-values-0-node slice-node] concat-axis-node)
45 |         reshape-node (build-node-reshape input-node concat-node)
46 |         rank-2-node (build-node-const 2 "DT_INT32" [])
47 |         shape-2-node (build-node-shape placeholder-node)
48 |         sub-1-y-node (build-node-const 1 "DT_INT32" [])
49 |         sub-1-node (build-node-sub rank-2-node sub-1-y-node)
50 |         slice-1-begin-node (build-node-pack [sub-1-node])
51 |         slice-1-size-node (build-node-const 1 "DT_INT32" [1])
52 |         slice-1-node (build-node-slice shape-2-node slice-1-begin-node slice-1-size-node)
53 |         concat-1-values-0-node (build-node-const -1 "DT_INT32" [1])
54 |         concat-1-axis-node (build-node-const 0 "DT_INT32" [])
55 |         concat-1-node (build-node-concat-v2 [concat-1-values-0-node slice-1-node] concat-1-axis-node)
56 |         reshape-1-node (build-node-reshape placeholder-node concat-1-node)
57 |         cross-entropy-node (build-node-softmax-cross-entropy-with-logits reshape-node reshape-1-node)]
58 |     [placeholder-node rank-0-node shape-0-node rank-1-node shape-1-node sub-y-node
59 |      sub-node slice-begin-node slice-size-node slice-node concat-values-0-node
60 |      concat-axis-node concat-node reshape-node rank-2-node shape-2-node
61 |      sub-1-y-node sub-1-node slice-1-begin-node slice-1-size-node slice-1-node
62 |      concat-1-values-0-node concat-1-axis-node concat-1-node reshape-1-node cross-entropy-node
63 |      ]))
64 | ;; Mean of input-node over axis 0; returns [reduction-indices-node reduce-mean-node].
65 | (defn build-nodes-reduce-mean [input-node]
66 |   (let [reduction-indices-node (build-node-const 0 "DT_INT32" [1])
67 |         reduce-mean-node (build-node-reduce-mean input-node reduction-indices-node)]
68 |     [reduction-indices-node reduce-mean-node]))
69 | 
70 | ;;todo build these for reals
71 | (defn build-loop [loop-node times])
72 | (defn build-train-next-batch [])
73 | (defn build-prediction-check [])
74 | (defn build-equal-check [])
75 | (defn build-cast [node dtype])
76 | (defn build-ApplyGradientDescent [variable learning-rate gradient-control])
77 | 
78 | ;(def linreg-graph (proto/protobuf-load proto-much/proto-graph-def (util/slurp-binary "misc/linreg.pb")))
79 | ;(-> linreg-graph :node count)
80 | ;(mapv :name (-> linreg-graph :node))
81 | ;(filter #(= "Identity" (:op %)) (-> linreg-graph :node))
82 | ;(map #(str (:name %) " " (:op %) " " (:input %)) (-> linreg-graph :node))
83 | ;(def addconst-graph (proto/protobuf-load proto-much/proto-graph-def (util/slurp-binary "misc/addconst.pb")))
84 | ;(-> addconst-graph :node count)
85 | ;(def mnist-simple-graph (proto/protobuf-load proto-much/proto-graph-def (util/slurp-binary "misc/mnist/mnist_simple.pbtxt")))
86 | ;(def mnist-meta-graph (proto/protobuf-load proto-much/proto-meta-graph-def (util/slurp-binary "misc/mnist/mnist_simple.model.meta")))
87 | ;(-> mnist-meta-graph keys)
88 | ;; Parses "node:output" into {:name node :output output}, omitting empty parts.
89 | (defn parse-node-ref [node-ref]
90 |   (let [[node-name output] (str/split node-ref #":")]
91 |     (-> {}
92 |         (assoc-not-empty :name node-name)
93 |         (assoc-not-empty :output output))))
94 | ;; NOTE(review): heuristic - splits the raw entry bytes on control characters (presumably protobuf field tags of a serialized VariableDef) instead of decoding the message; confirm.
95 | (defn parse-variable [trainable? entry]
96 |   (let [split-entry (-> entry
97 |                         proto-much/byte-string-to-string
98 |                         (str/replace "\n" "")
99 |                         (str/replace "\b" "")
100 |                         (str/split #"[\u0003\u0012\u001A]"))
101 |         [variable-ref assign-ref identity-ref] (map parse-node-ref (remove empty? split-entry))]
102 |     {:variable variable-ref
103 |      :assign assign-ref
104 |      :identity identity-ref
105 |      :trainable? trainable?}))
106 | ;; Parses every entry of the graph's "trainable_variables" collection with parse-variable.
107 | (defn parse-trainable-vars [graph]
108 |   (->> graph
109 |        :collection-def
110 |        (filter #(= "trainable_variables" (:key %)))
111 |        first
112 |        :value
113 |        :bytes-list
114 |        :value
115 |        (map (partial parse-variable true))))
116 | 
117 | ;(parse-trainable-vars mnist-meta-graph)
118 | ;(def thing (-> nonsense-graph :node second :attr second :value :tensor :string-val first proto-much/byte-string-to-string))
119 | 
120 | ;(let [y-mx-b-nodes (build-nodes-y-mx-b)
121 | ; placeholder-node (build-node-placeholder "DT_FLOAT" :name "y_hat")
122 | ; softmax-nodes (build-nodes-softmax-cross-entropy-with-logits (last y-mx-b-nodes) placeholder-node)
123 | ; reduce-mean-nodes (build-nodes-reduce-mean (last softmax-nodes))
124 | ; nodes (concat y-mx-b-nodes softmax-nodes reduce-mean-nodes)
125 | ; tf-nodes (map clj-node->tensorflow-node nodes)
126 | ; graph {:node tf-nodes
127 | ; :versions {:producer 21}}]
128 | ; (proto/protobuf-load proto-much/proto-graph-def
129 | ; (proto/protobuf-dump proto-much/proto-graph-def graph))
130 | ; graph)
131 |
132 | ;(time 133 | ; (exp/exec-graph-sess-fn 134 | ; (fn [graph session] 135 | ; (let [nodes (build-nodes-y-mx-b) 136 | ; out-node-name (-> nodes last :name) 137 | ; tf-nodes (map clj-node->tensorflow-node nodes) 138 | ; graph-def {:node tf-nodes 139 | ; :versions {:producer 21}} 140 | ; graph-bytes (proto/protobuf-dump proto-much/proto-graph-def graph-def)] 141 | ; (.importGraphDef graph graph-bytes) 142 | ; (exp/run-graph-thing session {:x [[5.0 12.0] 143 | ; [2.5 3.4]] 144 | ; :W [[6.0 1.3] 145 | ; [0.0 0.0]] 146 | ; :b [[0.0]]} 147 | ; out-node-name) 148 | ; )))) 149 | ; 150 | ;(time 151 | ; (exp/exec-graph-sess-fn 152 | ; (fn [graph session] 153 | ; (let [y-mx-b-nodes (build-nodes-y-mx-b) 154 | ; out-node-1 (-> y-mx-b-nodes last) 155 | ; placeholder-node (build-node-placeholder "DT_FLOAT" :name "y_hat") 156 | ; softmax-nodes (build-nodes-softmax-cross-entropy-with-logits out-node-1 placeholder-node) 157 | ; out-node-2 (-> softmax-nodes last) 158 | ; nodes (concat y-mx-b-nodes softmax-nodes) 159 | ; tf-nodes (map clj-node->tensorflow-node nodes) 160 | ; graph-def {:node tf-nodes 161 | ; :versions {:producer 21}} 162 | ; graph-bytes (proto/protobuf-dump proto-much/proto-graph-def graph-def)] 163 | ; (.importGraphDef graph graph-bytes) 164 | ; (exp/run-graph-thing session {:x [[5.0 12.0] 165 | ; [2.5 3.4]] 166 | ; :W [[6.0 1.3] 167 | ; [0.0 0.0]] 168 | ; :b [[2.0]] 169 | ; :y_hat [[5.3 8.5] 170 | ; [900.24 9.94]]} 171 | ; (-> out-node-1 :name) 172 | ; (-> out-node-2 :name)) 173 | ; )))) 174 | ; 175 | ;(time 176 | ; (exp/exec-graph-sess-fn 177 | ; (fn [graph session] 178 | ; (let [y-mx-b-nodes (build-nodes-y-mx-b) 179 | ; out-node-1 (-> y-mx-b-nodes last) 180 | ; placeholder-node (build-node-placeholder "DT_FLOAT" :name "y_hat") 181 | ; softmax-nodes (build-nodes-softmax-cross-entropy-with-logits out-node-1 placeholder-node) 182 | ; out-node-2 (-> softmax-nodes last) 183 | ; reduce-mean-nodes (build-nodes-reduce-mean out-node-2) 184 | ; out-node-3 (-> 
reduce-mean-nodes last)
185 | ; nodes (concat y-mx-b-nodes softmax-nodes reduce-mean-nodes)
186 | ; tf-nodes (map clj-node->tensorflow-node nodes)
187 | ; graph-def {:node tf-nodes
188 | ; :versions {:producer 21}}
189 | ; graph-bytes (proto/protobuf-dump proto-much/proto-graph-def graph-def)]
190 | ; (.importGraphDef graph graph-bytes)
191 | ; (exp/run-graph-thing session {:x [[5.0 12.0]
192 | ; [2.5 3.4]]
193 | ; :W [[6.0 1.3]
194 | ; [0.0 0.0]]
195 | ; :b [[2.0]]
196 | ; :y_hat [[5.3 3.5]
197 | ; [0.24 3.94]]}
198 | ; (-> out-node-1 :name)
199 | ; (-> out-node-2 :name)
200 | ; (-> out-node-3 :name))
201 | ; ))))
202 | 
--------------------------------------------------------------------------------
/src/tensorflow_clj/util.clj:
--------------------------------------------------------------------------------
1 | (ns tensorflow-clj.util)
2 | 
3 | (defn slurp-binary
4 |   "Reads the whole file at `filename` (a path string) into a byte array."
5 |   [filename]
6 |   (-> (java.nio.file.FileSystems/getDefault)
7 |       (.getPath "" (into-array String [filename]))
8 |       (java.nio.file.Files/readAllBytes)))
9 | 
10 | (defn approx=
11 |   "True when a and b differ by less than epsilon (default 0.0001)."
12 |   ([a b] (approx= a b 0.0001))
13 |   ([a b epsilon]
14 |    (< (Math/abs (- a b)) epsilon)))
15 | 
16 | (defn round2
17 |   "Round a double to the given precision, i.e. number of decimal places
18 |   (e.g. (round2 2 3.14159) => 3.14).
19 |   Stolen from http://stackoverflow.com/questions/10751638/clojure-rounding-to-decimal-places"
20 |   [precision d]
21 |   (let [factor (Math/pow 10 precision)]
22 |     (/ (Math/round (* d factor)) factor)))
23 | 
24 | (defn assoc-not-empty
25 |   "Like assoc, but returns m unchanged when v is nil/false or an empty collection/string."
26 |   [m k v]
27 |   (if (and v (seq v))
28 |     (assoc m k v)
29 |     m))
30 | 
31 | (defn assoc-in-not-empty
32 |   "Like assoc-in, but returns m unchanged when v is nil/false or an empty collection/string."
33 |   [m ks v]
34 |   (if (and v (seq v))
35 |     (assoc-in m ks v)
36 |     m))
--------------------------------------------------------------------------------
/test/tensorflow_clj/core_test.clj:
--------------------------------------------------------------------------------
1 | (ns tensorflow-clj.core-test
2 |   (:require [clojure.test :refer :all]
3 |             [tensorflow-clj.core :refer :all]
4 |             [tensorflow-clj.experimental :refer :all]
5 |             [tensorflow-clj.util :refer :all]))
6 | 
7 | (defmacro test-both-apis [graph-file & body]
8 |   `(do
9 |      (with-graph-file ~graph-file
10 |        (letfn [(~'run-graph [& args#] (apply run-graph args#))]
11 |          ~@body))
12 |      (exec-graph-sess-fn
13 |        (fn [graph# session#]
14 |          (load-graph! graph# ~graph-file)
15 |          (letfn [(~'run-graph [& args#] (apply run-graph-thing session# args#))]
16 |            ~@body)))))
17 | 
18 | (deftest scalar-tensor
19 |   (testing "Scalar tensor"
20 |     (let [t (tensor 123.0)]
21 |       (is (= org.tensorflow.DataType/FLOAT (.dataType t)))
22 |       (is (= 0 (.numDimensions t)))
23 |       (is (= [] (vec (.shape t)))))))
24 | 
25 | (deftest vector-tensor
26 |   (testing "Vector tensor"
27 |     (let [t (tensor [1.0 2.0 3.0])]
28 |       (is (= org.tensorflow.DataType/FLOAT (.dataType t)))
29 |       (is (= 1 (.numDimensions t)))
30 |       (is (= [3] (vec (.shape t)))))))
31 | 
32 | (deftest matrix-tensor
33 |   (testing "Matrix tensor"
34 |     (let [t (tensor [[1.0 2.0 3.0]
35 |                      [4.0 5.0 6.0]])]
36 |       (is (= org.tensorflow.DataType/FLOAT (.dataType t)))
37 |       (is (= 2 (.numDimensions t)))
38 |       (is (= [2 3] (vec (.shape t)))))))
39 | 
40 | (deftest tensor-conversion
41 |   (testing "Converting between tensors and core.matrix"
42 |     (letfn [(test [x] (is (= x (tensor->clj (tensor x)))))]
43 |       (test 123.0)
44 |       (test [1.0 2.0 3.0])
45 |       (test [[1.0 -2.0 3.0]
46 |              [4.0 5.0 -6.0]])
47 |       (test [[[1., 2., 3.]], [[7., 8., 9.]]])
48 |       (test [[[[555.5]]]]))))
49 | 
50 | (deftest protobuf-session
51 |   (testing "Session from Protocol Buffers file"
52 |     (test-both-apis "misc/constant.pb"
53 |       (let [[v] (run-graph {} :Const)]
54 |         (is (= 123.0 v))))))
55 | 
56 | (deftest protobuf-feed
57 |   (testing "Variable feed to loaded graph"
58 |     (test-both-apis "misc/addconst.pb"
59 |       (let [[v] (run-graph {:Placeholder (float 123.0)} :mul)]
60 |         (is (= 369.0 v))))))
61 | 
62 | (deftest matrix-feed
63 |   (testing "Matrix fed to loaded graph"
64 |     (test-both-apis "misc/addconst.pb"
65 |       (let [[v] (run-graph {:Placeholder [[1 2] [3 4]]} :mul)]
66 |         (is (= v [[3.0 6.0] [9.0 12.0]]))))))
67 | 
68 | (deftest mulbymat-graph
69 |   (testing "Multiplying variable by constant matrix"
70 |     (test-both-apis "misc/mulbymat.pb"
71 |       (let [[v] (run-graph {:Placeholder 5} :mul)]
72 |         (is (= v [[5. 10.] [15. 20.]])))
73 |       (let [[v] (run-graph {:Placeholder [[1. -1.] [2. -2.]]} :mul)]
74 |         (is (= v [[1. -2.] [6. -8.]]))))))
75 | 
76 | (deftest mul2vars-graph
77 |   (testing "Multiplying two variables"
78 |     (test-both-apis "misc/mul2vars.pb"
79 |       (let [[v] (run-graph {:a 4. :b 10.5} :mul)]
80 |         (is (= v 42.0)))
81 |       (let [[v] (run-graph {:a [[1. -1.] [2. -2.]]
82 |                             :b [[1. 2.] [3. 4.]]}
83 |                            :mul)]
84 |         (is (= v [[1. -2.] [6. -8.]]))))))
85 | 
86 | (def x-train [1. 2. 3. 4.])
87 | (def y-train [0. -1. -2. -3.])
88 | 
89 | (deftest linreg-one-pass
90 |   (testing "Linear regression (one pass)"
91 |     (test-both-apis "misc/linreg.pb"
92 |       (run-graph {:init nil})
93 |       (let [[[a b c d]] (run-graph {:x x-train} :linear_model)]
94 |         (is (approx= 0.0 a))
95 |         (is (approx= 0.3 b))
96 |         (is (approx= 0.6 c))
97 |         (is (approx= 0.9 d))))))
98 | 
99 | (deftest linreg-one-loss
100 |   (testing "Linear regression (one loss)"
101 |     (test-both-apis "misc/linreg.pb"
102 |       (run-graph {:init nil})
103 |       (let [[loss] (run-graph {:x x-train :y y-train} :loss)]
104 |         (is (approx= 23.66 loss))))))
105 | 
106 | (deftest linreg-fixed-vars
107 |   (testing "Linear regression (fixed variables)"
108 |     (test-both-apis "misc/linreg.pb"
109 |       (run-graph {:fixW nil :fixb nil})
110 |       (let [[loss] (run-graph {:x x-train :y y-train} :loss)]
111 |         (is (approx= 0.0 loss))))))
112 | 
113 | (deftest linreg-graph-iterations
114 |   (testing "Linear regression (1000 iterations)"
115 |     (test-both-apis "misc/linreg.pb"
116 |       (run-graph {:init nil})
117 |       (dotimes [i 1000]
118 |         (run-graph {:x x-train :y y-train :train nil}))
119 |       (let [[[W] [b] loss] (run-graph {:x x-train :y y-train} :W :b :loss)]
120 |         (is (approx= -0.9999 W))
121 |         (is (approx= 0.99999 b))
122 |         (is (approx= 5.6999738e-11 loss))))))
123 | 
124 | (deftest util-helpers
125 |   (testing "approx= with default and explicit tolerance, round2 decimal places"
126 |     (is (approx= 1.0 1.00001))
127 |     (is (not (approx= 1.0 1.01)))
128 |     (is (approx= 1.0 1.01 0.1))
129 |     (is (= 3.14 (round2 2 3.14159)))))
130 | 
--------------------------------------------------------------------------------