├── .gitignore ├── AUTHORS ├── LICENSE ├── README.md ├── notebooks └── mnist_demo.ipynb ├── setup.cfg ├── setup.py ├── tdb ├── __init__.py ├── app.py ├── debug_session.py ├── examples │ ├── __init__.py │ ├── mnist.py │ └── viz.py ├── ht_op.py ├── interface.py ├── op_store.py ├── plot_op.py ├── python_op.py ├── tests │ ├── __init__.py │ ├── mnist_0.npz │ ├── run_tests.py │ ├── test_exe_order.py │ ├── test_mixed.py │ ├── test_mnist.py │ ├── test_pure_ht.py │ ├── test_pure_tf.py │ └── test_ui.py └── transitive_closure.py └── tdb_ext ├── JSXTransformer.js ├── activate_dev.sh ├── bower.json ├── components ├── plotlistview.jsx ├── plotview.jsx ├── textinputview.jsx ├── ui.jsx └── user_msg_view.jsx ├── config.yaml ├── dispatcher.js ├── htapp.js ├── keymirror.js ├── main.js ├── package.json └── stores └── plotstore.js /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | bower_components/ 3 | node_modules/ 4 | .ipynb_checkpoints/ 5 | build/ 6 | dist/ 7 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | 2 | Eric Jang -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2015 Eric Jang 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TDB 2 | 3 | *Note: This project is no longer actively being maintained. Please check out the official [tfdbg debugger](https://www.tensorflow.org/versions/master/how_tos/debugger/) 4 | 5 | TensorDebugger (TDB) is a visual debugger for deep learning. 
It extends [TensorFlow](https://www.tensorflow.org/) (Google's Deep Learning framework) with breakpoints + real-time visualization of the data flowing through the computational graph. 6 | 7 | [Video Demo](https://www.youtube.com/watch?v=VcoVEvGEmFM) 8 | 9 | 10 | 11 | Specifically, TDB is the combination of a Python library and a Jupyter notebook extension, built around Google's TensorFlow framework. Together, these extend TensorFlow with the following features: 12 | 13 | - **Breakpoints**: Set breakpoints on Ops and Tensors in the graph. Graph execution is paused on breakpoints and resumed by the user (via `tdb.c()`). Debugging features can be used with or without the visualization frontend. 14 | - **Arbitrary Summary Plots**: Real-time visualization of high-level information (e.g. histograms, gradient magnitudes, weight saturation) while the network is being trained. Supports arbitrary, user-defined plot functions. 15 | - **Flexible**: Mix user-defined Python and plotting functions with TensorFlow nodes. These functions take `tf.Tensor`s as inputs and return placeholder nodes that can be plugged into downstream TensorFlow nodes. The diagram below illustrates how TDB nodes can be mixed with the TensorFlow graph. 16 | 17 | ![heterogenous](http://i.imgur.com/7xfA6Pg.png?1) 18 | 19 | 20 | ## Motivations 21 | 22 | Modern machine learning models are parametrically complex and require considerable intuition to fine-tune properly. 23 | 24 | In particular, Deep Learning methods are especially powerful, but hard to interpret with regard to their capabilities and learned representations. 25 | 26 | Can we enable better understanding of how neural nets learn, without having to change model code or sacrifice performance? Can I finish my thesis on time? 27 | 28 | TDB addresses these challenges by providing run-time visualization tools for neural nets. Real-time visual debugging allows training bugs to be detected sooner, thereby reducing the iteration time needed to build the right model. 29 | 30 | ## Setup 31 | 32 | To install the Python library: 33 | 34 | ```bash 35 | pip install tfdebugger 36 | ``` 37 | 38 | To install the Jupyter Notebook extension, run the following in a Python terminal (you will need to have IPython or [Jupyter](https://jupyter.readthedocs.org/en/latest/install.html) installed): 39 | 40 | ```python 41 | import notebook.nbextensions 42 | import urllib 43 | import zipfile 44 | SOURCE_URL = 'https://github.com/ericjang/tdb/releases/download/tdb_ext_v0.1/tdb_ext.zip' 45 | urllib.urlretrieve(SOURCE_URL, 'tdb_ext.zip') 46 | with zipfile.ZipFile('tdb_ext.zip', "r") as z: 47 | z.extractall("") 48 | notebook.nbextensions.install_nbextension('tdb_ext',user=True) 49 | ``` 50 | 51 | ## Tutorial 52 | 53 | To get started, check out the [MNIST Visualization Demo](notebooks/mnist_demo.ipynb). More examples and visualizations to come soon. 54 | 55 | ## User Guide 56 | 57 | ### Debugging 58 | 59 | #### Start 60 | ```python 61 | status,result=tdb.debug(evals,feed_dict=None,breakpoints=None,break_immediately=False,session=None) 62 | ``` 63 | 64 | `debug()` behaves just like TensorFlow's `Session.run()`. If a breakpoint is hit, `status` is set to 'PAUSED' and `result` is set to `None`. Otherwise, `status` is set to 'FINISHED' and `result` is set to a list of evaluated values. 65 | 66 | #### Continue 67 | ```python 68 | status,result=tdb.c() 69 | ``` 70 | 71 | Continues execution of a paused session until the next breakpoint or the end of the execution queue. Return values behave like those of `debug()`. 
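Putting `debug()` and `c()` together, here is a minimal sketch of the breakpoint workflow (the toy graph and the expected values in the comments are illustrative, not taken from the demo notebook):

```python
import tensorflow as tf
import tdb

a = tf.constant(2)
b = tf.constant(3)
c = tf.add(a, b)

# Pause execution when `c` is about to be evaluated.
status, result = tdb.debug([c], feed_dict=None, breakpoints=[c], break_immediately=False)
# status should now be 'PAUSED'; result stays None until execution finishes.

while status == 'PAUSED':
    status, result = tdb.c()  # resume until the next breakpoint or the end

# Once execution finishes: status == 'FINISHED' and result is the list of evaluated values, e.g. [5].
```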
72 | 73 | 74 | #### Step 75 | ```python 76 | status,result=tdb.s() 77 | ``` 78 | 79 | Evaluates the next node, then pauses immediately to await user input. Unless we have reached the end of the execution queue, `status` will remain 'PAUSED'. `result` is set to the value of the node we just evaluated. 80 | 81 | #### Where 82 | 83 | ```python 84 | q=tdb.get_exe_queue() 85 | ``` 86 | 87 | Returns the list of remaining nodes to be evaluated, in order. 88 | 89 | #### print 90 | 91 | ```python 92 | val=tdb.get_value(node) 93 | ``` 94 | 95 | Returns the value of an evaluated node. `node` may be a string name or a `tf.Tensor`. 96 | 97 | ### Custom Nodes 98 | 99 | TDB supports two types of custom Ops: 100 | 101 | #### Python 102 | 103 | Here is an example of mixing tdb.PythonOps with TensorFlow. 104 | 105 | Define the following function: 106 | ```python 107 | def myadd(ctx,a,b): 108 | return a+b 109 | ``` 110 | 111 | ```python 112 | a=tf.constant(2) 113 | b=tf.constant(3) 114 | c=tdb.python_op(myadd,inputs=[a,b],outputs=[tf.placeholder(tf.int32)]) # a+b 115 | d=tf.neg(c) 116 | status,result=tdb.debug([d], feed_dict=None, breakpoints=None, break_immediately=False) 117 | ``` 118 | 119 | When `myadd` is evaluated, `ctx` is the instance of the PythonOp that it belongs to. You can use `ctx` to store state across evaluations (e.g. to accumulate a loss history). 120 | 121 | #### Plotting 122 | 123 | PlotOps are a special kind of PythonOp that sends graphical output to the frontend. 124 | 125 | This only works with Matplotlib at the moment, but other plotting backends (Seaborn, Bokeh, Plotly) are coming soon. 126 | 127 | ```python 128 | def watch_loss(ctx,loss): 129 | if not hasattr(ctx, 'loss_history'): 130 | ctx.loss_history=[] 131 | ctx.loss_history.append(loss) 132 | plt.plot(ctx.loss_history) 133 | plt.ylabel('loss') 134 | ``` 135 | 136 | ```python 137 | ploss=tdb.plot_op(viz.watch_loss,inputs=[loss]) 138 | ``` 139 | 140 | Refer to the [MNIST Visualization Demo](notebooks/mnist_demo.ipynb) for more examples. You can also find more examples in the [tests/](tdb/tests) directory. 141 | 142 | ## FAQ 143 | 144 | ### Is TDB affiliated with TensorFlow? 145 | 146 | No, but it is built on top of it. 147 | 148 | ### What is TDB good for? 149 | 150 | TDB is especially useful at the model prototyping stage, for verifying model correctness in an intuitive manner. It is also useful for high-level visualization of hidden layers during training. 151 | 152 | ### How is TDB different from TensorBoard? 153 | 154 | TensorBoard is a suite of visualization tools included with TensorFlow. Both TDB and TensorBoard attach auxiliary nodes to the TensorFlow graph in order to inspect data. 155 | 156 | TensorBoard cannot be used concurrently with running a TensorFlow graph; log files must be written first. TDB interfaces directly with the execution of a TensorFlow graph and allows stepping through execution one node at a time. 157 | 158 | Out of the box, TensorBoard currently only supports logging for a few predefined data formats. 159 | 160 | TDB is to TensorBoard as GDB is to printf. Both are useful in different contexts. 
161 | 162 | 163 | 164 | ## License 165 | 166 | Apache 2.0 167 | 168 | 169 | -------------------------------------------------------------------------------- /notebooks/mnist_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# MNIST Visualization Example\n", 8 | "\n", 9 | "Real-time visualization of MNIST training on a CNN, using TensorFlow and [TensorDebugger](https://github.com/ericjang/tdb)\n", 10 | "\n", 11 | "The visualizations in this notebook won't show up on http://nbviewer.ipython.org. To view the widgets and interact with them, you will need to download this notebook and run it with a Jupyter Notebook server." 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "## Step 1: Load TDB Notebook Extension" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": { 25 | "collapsed": false 26 | }, 27 | "outputs": [ 28 | { 29 | "data": { 30 | "application/javascript": [ 31 | "Jupyter.utils.load_extensions('tdb_ext/main')" 32 | ], 33 | "text/plain": [ 34 | "" 35 | ] 36 | }, 37 | "metadata": {}, 38 | "output_type": "display_data" 39 | } 40 | ], 41 | "source": [ 42 | "%%javascript\n", 43 | "Jupyter.utils.load_extensions('tdb_ext/main')" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 1, 49 | "metadata": { 50 | "collapsed": false 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "#import sys\n", 55 | "#sys.path.append('/home/evjang/thesis/tensor_debugger')\n", 56 | "import tdb\n", 57 | "from tdb.examples import mnist, viz\n", 58 | "import matplotlib.pyplot as plt\n", 59 | "import tensorflow as tf\n", 60 | "import urllib" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "## Step 2: Build TensorFlow Model" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "metadata": { 74 | "collapsed": true 75 | }, 76 | "outputs": [], 77 | "source": [ 78 | "(train_data_node,\n", 79 | " train_labels_node,\n", 80 | " validation_data_node,\n", 81 | " test_data_node,\n", 82 | " # predictions\n", 83 | " train_prediction,\n", 84 | " validation_prediction,\n", 85 | " test_prediction,\n", 86 | " # weights\n", 87 | " conv1_weights,\n", 88 | " conv2_weights,\n", 89 | " fc1_weights,\n", 90 | " fc2_weights,\n", 91 | " # training\n", 92 | " optimizer,\n", 93 | " loss,\n", 94 | " learning_rate,\n", 95 | " summaries) = mnist.build_model()" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "## Step 3: Attach Plotting Ops" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 4, 108 | "metadata": { 109 | "collapsed": true 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "def viz_activations(ctx, m):\n", 114 | " plt.matshow(m.T,cmap=plt.cm.gray)\n", 115 | " plt.title(\"LeNet Predictions\")\n", 116 | " plt.xlabel(\"Batch\")\n", 117 | " plt.ylabel(\"Digit Activation\")" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 5, 123 | "metadata": { 124 | "collapsed": false 125 | }, 126 | "outputs": [], 127 | "source": [ 128 | "# plotting a user-defined function 'viz_activations'\n", 129 | "p0=tdb.plot_op(viz_activations,inputs=[train_prediction])\n", 130 | "# weight variables are of type tf.Variable, so we need to find the corresponding tf.Tensor instead\n", 131 | "g=tf.get_default_graph()\n", 132 | 
"p1=tdb.plot_op(viz.viz_conv_weights,inputs=[g.as_graph_element(conv1_weights)])\n", 133 | "p2=tdb.plot_op(viz.viz_conv_weights,inputs=[g.as_graph_element(conv2_weights)])\n", 134 | "p3=tdb.plot_op(viz.viz_fc_weights,inputs=[g.as_graph_element(fc1_weights)])\n", 135 | "p4=tdb.plot_op(viz.viz_fc_weights,inputs=[g.as_graph_element(fc2_weights)])\n", 136 | "p2=tdb.plot_op(viz.viz_conv_hist,inputs=[g.as_graph_element(conv1_weights)])\n", 137 | "ploss=tdb.plot_op(viz.watch_loss,inputs=[loss])" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "## Step 4: Download the MNIST dataset\n" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 6, 150 | "metadata": { 151 | "collapsed": false 152 | }, 153 | "outputs": [ 154 | { 155 | "name": "stdout", 156 | "output_type": "stream", 157 | "text": [ 158 | "train-images-idx3-ubyte.gz\n", 159 | "train-labels-idx1-ubyte.gz\n", 160 | "t10k-images-idx3-ubyte.gz\n", 161 | "t10k-labels-idx1-ubyte.gz\n" 162 | ] 163 | } 164 | ], 165 | "source": [ 166 | "base_url='http://yann.lecun.com/exdb/mnist/'\n", 167 | "files=['train-images-idx3-ubyte.gz',\n", 168 | " 'train-labels-idx1-ubyte.gz',\n", 169 | " 't10k-images-idx3-ubyte.gz',\n", 170 | " 't10k-labels-idx1-ubyte.gz']\n", 171 | "download_dir='/tmp/'\n", 172 | "for f in files:\n", 173 | " print(f)\n", 174 | " urllib.urlretrieve(base_url+f, download_dir+f)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": {}, 180 | "source": [ 181 | "## Step 5: Debug + Visualize!\n", 182 | "\n", 183 | "Upon evaluating plot nodes p1,p2,p3,p4,ploss, plots will be generated in the Plot view on the right." 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 7, 189 | "metadata": { 190 | "collapsed": false 191 | }, 192 | "outputs": [ 193 | { 194 | "name": "stdout", 195 | "output_type": "stream", 196 | "text": [ 197 | "('Extracting', '/tmp/train-images-idx3-ubyte.gz')\n", 198 | "('Extracting', '/tmp/train-labels-idx1-ubyte.gz')\n", 199 | "('Extracting', '/tmp/t10k-images-idx3-ubyte.gz')\n", 200 | "('Extracting', '/tmp/t10k-labels-idx1-ubyte.gz')\n" 201 | ] 202 | } 203 | ], 204 | "source": [ 205 | "# return the TF nodes corresponding to graph input placeholders\n", 206 | "(train_data, \n", 207 | " train_labels, \n", 208 | " validation_data, \n", 209 | " validation_labels, \n", 210 | " test_data, \n", 211 | " test_labels) = mnist.get_data(download_dir)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 8, 217 | "metadata": { 218 | "collapsed": false 219 | }, 220 | "outputs": [], 221 | "source": [ 222 | "# start the TensorFlow session that will be used to evaluate the graph\n", 223 | "s=tf.InteractiveSession()\n", 224 | "tf.initialize_all_variables().run()" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 9, 230 | "metadata": { 231 | "collapsed": false 232 | }, 233 | "outputs": [ 234 | { 235 | "name": "stdout", 236 | "output_type": "stream", 237 | "text": [ 238 | "loss: 29.668428\n", 239 | "loss: 15.983353\n", 240 | "loss: 11.249242\n", 241 | "loss: 10.028939\n", 242 | "loss: 8.065391\n", 243 | "loss: 9.335689\n", 244 | "loss: 7.316875\n", 245 | "loss: 8.376289\n", 246 | "loss: 7.735221\n", 247 | "loss: 8.383675\n", 248 | "loss: 5.704120\n", 249 | "loss: 6.037778\n", 250 | "loss: 7.309663\n", 251 | "loss: 7.349874\n", 252 | "loss: 7.528041\n", 253 | "loss: 8.209503\n" 254 | ] 255 | }, 256 | { 257 | "ename": "KeyboardInterrupt", 258 | "evalue": "", 259 | "output_type": 
"error", 260 | "traceback": [ 261 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 262 | "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 263 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 12\u001b[0m }\n\u001b[0;32m 13\u001b[0m \u001b[1;31m# run training node and visualization node\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 14\u001b[1;33m \u001b[0mstatus\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mresult\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtdb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mp0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msession\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0ms\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 15\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mstep\u001b[0m \u001b[1;33m%\u001b[0m \u001b[1;36m10\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 16\u001b[0m \u001b[0mstatus\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mresult\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtdb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mloss\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mp1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mp2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mp3\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mp4\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mploss\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbreakpoints\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbreak_immediately\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mFalse\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msession\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0ms\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 264 | "\u001b[1;32m/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tdb/interface.pyc\u001b[0m in \u001b[0;36mdebug\u001b[1;34m(evals, feed_dict, breakpoints, break_immediately, session)\u001b[0m\n\u001b[0;32m 15\u001b[0m \u001b[1;32mglobal\u001b[0m \u001b[0m_dbsession\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 16\u001b[0m \u001b[0m_dbsession\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdebug_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mDebugSession\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msession\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 17\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0m_dbsession\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mevals\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mbreakpoints\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mbreak_immediately\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 18\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0ms\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 265 | 
"\u001b[1;32m/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tdb/debug_session.pyc\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, evals, feed_dict, breakpoints, break_immediately)\u001b[0m\n\u001b[0;32m 59\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_break\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 60\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 61\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 62\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 63\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0ms\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 266 | "\u001b[1;32m/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tdb/debug_session.pyc\u001b[0m in \u001b[0;36mc\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 85\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 86\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mRUNNING\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 87\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_eval\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 88\u001b[0m \u001b[1;31m# increment to next node\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 89\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 267 | "\u001b[1;32m/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tdb/debug_session.pyc\u001b[0m in \u001b[0;36m_eval\u001b[1;34m(self, node)\u001b[0m\n\u001b[0;32m 168\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;31m# is a TensorFlow node\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 169\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 170\u001b[1;33m \u001b[0mresult\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msession\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_cache\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 171\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_cache\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mnode\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mresult\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 172\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 268 | 
"\u001b[1;32m/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, fetches, feed_dict)\u001b[0m\n\u001b[0;32m 346\u001b[0m \u001b[1;31m# Run request and get response.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 347\u001b[0m \u001b[1;31m#pdb.set_trace()\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 348\u001b[1;33m \u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_do_run\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtarget_list\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0munique_fetch_targets\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict_string\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 349\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 350\u001b[0m \u001b[1;31m# User may have fetched the same tensor multiple times, but we\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 269 | "\u001b[1;32m/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_run\u001b[1;34m(self, target_list, fetch_list, feed_dict)\u001b[0m\n\u001b[0;32m 405\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 406\u001b[0m return tf_session.TF_Run(self._session, feed_dict, fetch_list,\n\u001b[1;32m--> 407\u001b[1;33m target_list)\n\u001b[0m\u001b[0;32m 408\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 409\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mStatusNotOK\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 270 | "\u001b[1;31mKeyboardInterrupt\u001b[0m: " 271 | ] 272 | } 273 | ], 274 | "source": [ 275 | "BATCH_SIZE = 64\n", 276 | "NUM_EPOCHS = 5\n", 277 | "TRAIN_SIZE=10000\n", 278 | "\n", 279 | "for step in xrange(NUM_EPOCHS * TRAIN_SIZE // BATCH_SIZE):\n", 280 | " offset = (step * BATCH_SIZE) % (TRAIN_SIZE - BATCH_SIZE)\n", 281 | " batch_data = train_data[offset:(offset + BATCH_SIZE), :, :, :]\n", 282 | " batch_labels = train_labels[offset:(offset + BATCH_SIZE)]\n", 283 | " feed_dict = {\n", 284 | " train_data_node: batch_data,\n", 285 | " train_labels_node: batch_labels\n", 286 | " }\n", 287 | " # run training node and visualization node\n", 288 | " status,result=tdb.debug([optimizer,p0], feed_dict=feed_dict, session=s)\n", 289 | " if step % 10 == 0: \n", 290 | " status,result=tdb.debug([loss,p1,p2,p3,p4,ploss], feed_dict=feed_dict, breakpoints=None, break_immediately=False, session=s)\n", 291 | " print('loss: %f' % (result[0]))" 292 | ] 293 | } 294 | ], 295 | "metadata": { 296 | "kernelspec": { 297 | "display_name": "Python 2", 298 | "language": "python", 299 | "name": "python2" 300 | }, 301 | "language_info": { 302 | "codemirror_mode": { 303 | "name": "ipython", 304 | "version": 2 305 | }, 306 | "file_extension": ".py", 307 | "mimetype": "text/x-python", 308 | "name": "python", 309 | "nbconvert_exporter": "python", 310 | "pygments_lexer": "ipython2", 311 | "version": "2.7.9" 312 | } 313 | }, 314 | "nbformat": 4, 315 | "nbformat_minor": 0 316 | } 317 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'tfdebugger', 5 | packages = find_packages(), 6 | version = '0.1.1', 7 | description = 'TensorFlow Debugger', 8 | author = 'Eric Jang', 9 | author_email = 'ericjang2004@gmail.com', 10 | url = 'https://github.com/ericjang/tdb', # use the URL to the github repo 11 | download_url = 'https://github.com/ericjang/tdb/archive/0.1.tar.gz', 12 | keywords = ['TDB', 'Deep Learning', 'TensorFlow', 'debugging', 'visualization'], 13 | classifiers = [ 14 | 'Intended Audience :: Developers', 15 | 'Intended Audience :: Science/Research', 16 | 'Programming Language :: Python' 17 | ], 18 | license='Apache 2.0', 19 | install_requires=['toposort>=1.4'] 20 | ) 21 | -------------------------------------------------------------------------------- /tdb/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Debugger for Tensorflow 3 | 4 | import tfdebugger as td 5 | 6 | """ 7 | 8 | from interface import debug, c, s, get_exe_queue, get_value 9 | import op_store 10 | from plot_op import plot_op 11 | from python_op import python_op 12 | from debug_session import INITIALIZED, RUNNING, PAUSED, FINISHED 13 | from app import is_notebook, connect 14 | import examples 15 | import tests 16 | 17 | connect() -------------------------------------------------------------------------------- /tdb/app.py: -------------------------------------------------------------------------------- 1 | from base64 import b64encode 2 | from ipykernel.comm import Comm 3 | from IPython import get_ipython 4 | import StringIO 5 | import urllib 6 | 7 | _comm=None 8 | 9 | def is_notebook(): 10 | iPython=get_ipython() 11 | if iPython is None or not iPython.config: 12 | return False 13 | return 'IPKernelApp' in iPython.config 14 | 15 | def connect(): 16 | """ 17 | establish connection to frontend notebook 18 | """ 19 | if not is_notebook(): 20 | print('Python session is not running in a Notebook Kernel') 21 | return 22 | 23 | global _comm 24 | 25 | kernel=get_ipython().kernel 26 | kernel.comm_manager.register_target('tdb',handle_comm_opened) 27 | # initiate connection to frontend. 
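# The Comm created below opens a Jupyter comm channel with target name 'tdb'; the tdb_ext
# notebook extension communicates over this target, and send_action()/send_fig() push
# actions and base64-encoded PNG figures to the frontend through it.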
28 | _comm=Comm(target_name='tdb',data={}) 29 | # bind recv handler 30 | _comm.on_msg(None) 31 | 32 | def send_action(action, params=None): 33 | """ 34 | helper method for sending actions 35 | """ 36 | data={"msg_type":"action", "action":action} 37 | if params is not None: 38 | data['params']=params 39 | _comm.send(data) 40 | 41 | def send_fig(fig,name): 42 | """ 43 | sends figure to frontend 44 | """ 45 | imgdata = StringIO.StringIO() 46 | fig.savefig(imgdata, format='png') 47 | imgdata.seek(0) # rewind the data 48 | uri = 'data:image/png;base64,' + urllib.quote(b64encode(imgdata.buf)) 49 | send_action("update_plot",params={"src":uri, "name":name}) 50 | 51 | # handler messages 52 | def handle_comm_opened(msg): 53 | # this won't appear in the notebook 54 | print('comm opened') 55 | print(msg) -------------------------------------------------------------------------------- /tdb/debug_session.py: -------------------------------------------------------------------------------- 1 | 2 | from ht_op import HTOp 3 | import op_store 4 | import tensorflow as tf 5 | 6 | # debug status codes 7 | INITIALIZED = 'INITIALIZED' 8 | RUNNING = 'RUNNING' 9 | PAUSED = 'PAUSED' 10 | FINISHED = 'FINISHED' 11 | 12 | class DebugSession(object): 13 | 14 | def __init__(self, session=None): 15 | super(DebugSession, self).__init__() 16 | 17 | if session is None: 18 | session=tf.InteractiveSession() 19 | _original_evals=None 20 | self.step=0 # index into execution order 21 | self.session=session 22 | self.state=INITIALIZED 23 | self._original_evals=[] # evals passed into self.debug, in order 24 | self._evalset=set() # string names to evaluate 25 | self._bpset=set() # breakpoint names 26 | self._cache={} # key: node names in evalset -> np.ndarray 27 | self._exe_order=[] # list of HTOps, tf.Tensors to be evaluated 28 | 29 | ### 30 | ### PUBLIC METHODS 31 | ### 32 | 33 | def run(self, evals, feed_dict=None, breakpoints=None, break_immediately=False): 34 | """ 35 | starts the debug session 36 | """ 37 | if not isinstance(evals,list): 38 | evals=[evals] 39 | if feed_dict is None: 40 | feed_dict={} 41 | if breakpoints is None: 42 | breakpoints=[] 43 | 44 | self.state=RUNNING 45 | self._original_evals=evals 46 | self._original_feed_dict=feed_dict 47 | self._exe_order=op_store.compute_exe_order(evals) 48 | self._init_evals_bps(evals, breakpoints) 49 | 50 | # convert cache keys to strings 51 | for k,v in feed_dict.items(): 52 | if not isinstance(k,str): 53 | k=k.name 54 | self._cache[k]=v 55 | 56 | op_store.register_dbsession(self) 57 | 58 | if break_immediately: 59 | return self._break() 60 | else: 61 | return self.c() 62 | 63 | def s(self): 64 | """ 65 | step to the next node in the execution order 66 | """ 67 | next_node=self._exe_order[self.step] 68 | self._eval(next_node) 69 | self.step+=1 70 | if self.step==len(self._exe_order): 71 | return self._finish() 72 | else: 73 | # if stepping, return the value of the node we just 74 | # evaled 75 | return self._break(value=self._cache.get(next_node.name)) 76 | 77 | def c(self): 78 | """ 79 | continue 80 | """ 81 | i,node=self._get_next_eval() 82 | if node.name in self._bpset: 83 | if self.state == RUNNING: 84 | return self._break() 85 | 86 | self.state = RUNNING 87 | self._eval(node) 88 | # increment to next node 89 | self.step=i+1 90 | if self.step < len(self._exe_order): 91 | return self.c() 92 | else: 93 | return self._finish() 94 | 95 | def get_values(self): 96 | """ 97 | returns final values (same result as tf.Session.run()) 98 | """ 99 | return [self._cache.get(i.name,None) 
for i in self._original_evals] 100 | 101 | def get_exe_queue(self): 102 | return self._exe_order[self.step:] 103 | 104 | def get_value(self, node): 105 | """ 106 | retrieve a node value from the cache 107 | """ 108 | if isinstance(node,tf.Tensor): 109 | return self._cache.get(node.name,None) 110 | elif isinstance(node,tf.Operation): 111 | return None 112 | else: # handle ascii, unicode strings 113 | return self._cache.get(node,None) 114 | 115 | ### 116 | ### PRIVATE METHODS 117 | ### 118 | 119 | def _cache_value(self, tensor, ndarray): 120 | """ 121 | store tensor ndarray value in cache. this is called by python ops 122 | """ 123 | self._cache[tensor.name]=ndarray 124 | 125 | def _init_evals_bps(self, evals, breakpoints): 126 | # If an eval or bp is the tf.Placeholder output of a tdb.PythonOp, replace it with its respective PythonOp node 127 | evals2=[op_store.get_op(t) if op_store.is_htop_out(t) else t for t in evals] 128 | breakpoints2=[op_store.get_op(t) if op_store.is_htop_out(t) else t for t in breakpoints] 129 | # compute execution order 130 | self._exe_order=op_store.compute_exe_order(evals2) # list of nodes 131 | # compute evaluation set 132 | """ 133 | HTOps may depend on tf.Tensors that are not in eval. We need to have all inputs to HTOps ready 134 | upon evaluation. 135 | 136 | 1. all evals that were originally specified are added 137 | 2. each HTOp in the execution closure needs to be in eval (they won't be eval'ed automatically by Session.run) 138 | 3. if an input to an HTOp is a tf.Tensor (not a HT placeholder tensor), it needs to be in eval as well (it's not 139 | tensorflow so we'll have to manually evaluate it). Remember, we don't track Placeholders because we instead 140 | run the HTOps that generate their values. 141 | """ 142 | self._evalset=set([e.name for e in evals2]) 143 | for e in self._exe_order: 144 | if isinstance(e,HTOp): 145 | self._evalset.add(e.name) 146 | for t in e.inputs: 147 | if not op_store.is_htop_out(t): 148 | self._evalset.add(t.name) 149 | 150 | # compute breakpoint set 151 | self._bpset=set([bp.name for bp in breakpoints2]) 152 | 153 | def _get_next_eval(self): 154 | n=len(self._exe_order) 155 | o=self._exe_order 156 | return next((i,o[i]) for i in range(self.step,n) if (o[i].name in self._evalset or o[i].name in self._bpset)) 157 | 158 | def _eval(self, node): 159 | """ 160 | node is a TensorFlow Op or Tensor from self._exe_order 161 | """ 162 | # if node.name == 'Momentum': 163 | # pdb.set_trace() 164 | if isinstance(node,HTOp): 165 | # All Tensors MUST be in the cache. 166 | feed_dict=dict((t,self._cache[t.name]) for t in node.inputs) 167 | node.run(feed_dict) # this will populate self._cache on its own 168 | else: # is a TensorFlow node 169 | if isinstance(node,tf.Tensor): 170 | result=self.session.run(node,self._cache) 171 | self._cache[node.name]=result 172 | else: 173 | # is an operation 174 | if node.type =='Assign' or node.type == 'AssignAdd' or node.type == 'AssignSub': 175 | # special operation that takes in a tensor ref and mutates it 176 | # unfortunately, we end up having to execute nearly the full graph? 177 | # alternatively, find a way to pass the tensor_ref thru the feed_dict 178 | # rather than the tensor values. 179 | self.session.run(node,self._original_feed_dict) 180 | 181 | def _break(self,value=None): 182 | self.state=PAUSED 183 | i,next_node=self._get_next_eval() 184 | print('Breakpoint triggered. 
Next Node: ', next_node.name) 185 | return (self.state,value) 186 | 187 | def _finish(self): 188 | self.state=FINISHED 189 | return (self.state, self.get_values()) 190 | 191 | -------------------------------------------------------------------------------- /tdb/examples/__init__.py: -------------------------------------------------------------------------------- 1 | import mnist 2 | import viz -------------------------------------------------------------------------------- /tdb/examples/mnist.py: -------------------------------------------------------------------------------- 1 | """ 2 | builds a simple mnist model 3 | """ 4 | 5 | import gzip 6 | import numpy as np 7 | import re 8 | import sys 9 | import tensorflow as tf 10 | 11 | FLAGS = tf.app.flags.FLAGS 12 | 13 | tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train', 14 | """Directory where to write event logs """ 15 | """and checkpoint.""") 16 | 17 | IMAGE_SIZE = 28 18 | NUM_CHANNELS = 1 19 | PIXEL_DEPTH = 255 20 | NUM_LABELS = 10 21 | VALIDATION_SIZE = 5000 # Size of the validation set. 22 | SEED = 66478 # Set to None for random seed. 23 | BATCH_SIZE = 64 24 | NUM_EPOCHS = 1 25 | 26 | TEST_SIZE=55000 27 | TRAIN_SIZE=10000 28 | 29 | # DATA PRE-PROCESSING 30 | def extract_data(filename, num_images): 31 | """ 32 | Extract the images into a 4D tensor [image index, y, x, channels]. 33 | Values are rescaled from [0, 255] down to [-0.5, 0.5]. 34 | """ 35 | print('Extracting', filename) 36 | with gzip.open(filename) as bytestream: 37 | bytestream.read(16) 38 | buf = bytestream.read(IMAGE_SIZE * IMAGE_SIZE * num_images) 39 | data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32) 40 | data = (data - (PIXEL_DEPTH / 2.0)) / PIXEL_DEPTH 41 | data = data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, 1) 42 | return data 43 | 44 | def extract_labels(filename, num_images): 45 | """ 46 | Extract the labels into a 1-hot matrix [image index, label index]. 47 | """ 48 | print('Extracting', filename) 49 | with gzip.open(filename) as bytestream: 50 | bytestream.read(8) 51 | buf = bytestream.read(1 * num_images) 52 | labels = np.frombuffer(buf, dtype=np.uint8) 53 | # Convert to dense 1-hot representation. 54 | return (np.arange(NUM_LABELS) == labels[:, None]).astype(np.float32) 55 | 56 | def get_data(data_root): 57 | train_data_filename = data_root+'train-images-idx3-ubyte.gz' 58 | train_labels_filename = data_root+'train-labels-idx1-ubyte.gz' 59 | test_data_filename = data_root+'t10k-images-idx3-ubyte.gz' 60 | test_labels_filename = data_root+'t10k-labels-idx1-ubyte.gz' 61 | 62 | # Extract it into numpy arrays. 63 | train_data = extract_data(train_data_filename, 60000) 64 | train_labels = extract_labels(train_labels_filename, 60000) 65 | test_data = extract_data(test_data_filename, 10000) 66 | test_labels = extract_labels(test_labels_filename, 10000) 67 | 68 | validation_data = train_data[:VALIDATION_SIZE, :, :, :] 69 | validation_labels = train_labels[:VALIDATION_SIZE] 70 | train_data = train_data[VALIDATION_SIZE:, :, :, :] 71 | train_labels = train_labels[VALIDATION_SIZE:] 72 | 73 | global TRAIN_SIZE, TEST_SIZE 74 | TRAIN_SIZE=train_labels.shape[0] 75 | TEST_SIZE=test_labels.shape[0] 76 | 77 | return train_data, train_labels, validation_data, validation_labels, test_data, test_labels 78 | 79 | def _activation_summary(x): 80 | """Helper to create summaries for activations. 81 | Creates a summary that provides a histogram of activations. 82 | Creates a summary that measure the sparsity of activations. 
83 | Args: 84 | x: Tensor 85 | Returns: 86 | nothing 87 | """ 88 | # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training 89 | # session. This helps the clarity of presentation on tensorboard. 90 | tf.histogram_summary(x.name + '/activations', x) 91 | tf.scalar_summary(x.name + '/sparsity', tf.nn.zero_fraction(x)) 92 | 93 | # MODEL BUILDING 94 | def build_model(): 95 | """ 96 | Builds the computation graph consisting of training/testing LeNet 97 | 98 | train data - used for learning 99 | validation data - used for printing progress (does not impact learning) 100 | test data - used for printing final test error 101 | """ 102 | # training data 103 | train_data_node = tf.placeholder(tf.float32,shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)) 104 | train_labels_node = tf.placeholder(tf.float32,shape=(BATCH_SIZE, NUM_LABELS)) 105 | 106 | validation_data_node= tf.placeholder(tf.float32,shape=(VALIDATION_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)) 107 | test_data_node=tf.placeholder(tf.float32,shape=(TEST_SIZE,IMAGE_SIZE,IMAGE_SIZE,NUM_CHANNELS)) 108 | # validation dataset held in a single constant node 109 | # validation_data_node = tf.constant(validation_data) 110 | # test_data_node = tf.constant(test_data) 111 | 112 | # LEARNABLE WEIGHT NODES SHARED BETWEEN 113 | conv1_weights = tf.Variable(tf.truncated_normal([5, 5, NUM_CHANNELS, 32],stddev=0.1,seed=SEED)) 114 | conv1_biases = tf.Variable(tf.zeros([32])) 115 | conv2_weights = tf.Variable(tf.truncated_normal([5, 5, 32, 64],stddev=0.1,seed=SEED)) 116 | conv2_biases = tf.Variable(tf.constant(0.1, shape=[64])) 117 | fc1_weights = tf.Variable(tf.truncated_normal([IMAGE_SIZE // 4 * IMAGE_SIZE // 4 * 64, 512],stddev=0.1,seed=SEED)) 118 | fc1_biases = tf.Variable(tf.constant(0.1, shape=[512])) 119 | fc2_weights = tf.Variable(tf.truncated_normal([512, NUM_LABELS],stddev=0.1,seed=SEED)) 120 | fc2_biases = tf.Variable(tf.constant(0.1, shape=[NUM_LABELS])) 121 | 122 | # LENET 123 | def build_lenet(data,train=False): 124 | # subroutine for wiring up nodes and weights to training and evaluation LeNets 125 | conv1 = tf.nn.conv2d(data,conv1_weights,strides=[1, 1, 1, 1],padding='SAME') 126 | relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases)) 127 | pool1 = tf.nn.max_pool(relu1,ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1],padding='SAME') 128 | conv2 = tf.nn.conv2d(pool1,conv2_weights,strides=[1, 1, 1, 1],padding='SAME') 129 | relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases)) 130 | pool2 = tf.nn.max_pool(relu2,ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1],padding='SAME') 131 | # Reshape the feature map cuboid into a 2D matrix to feed it to the 132 | # fully connected layers. 133 | pool_shape = pool2.get_shape().as_list() 134 | reshape = tf.reshape(pool2,[pool_shape[0], pool_shape[1] * pool_shape[2] * pool_shape[3]]) 135 | fc1 = tf.nn.relu(tf.matmul(reshape, fc1_weights) + fc1_biases) 136 | # Add a 50% dropout during training only. Dropout also scales 137 | # activations such that no rescaling is needed at evaluation time. 
138 | if train: 139 | fc1 = tf.nn.dropout(fc1, 0.5, seed=SEED) 140 | # append summary ops to train 141 | _activation_summary(conv1) 142 | _activation_summary(fc1) 143 | 144 | fc2 = tf.matmul(fc1, fc2_weights) + fc2_biases 145 | return fc2 146 | 147 | # TRAINING LOSS / REGULARIZATION NODES 148 | logits = build_lenet(train_data_node, True) 149 | loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits, train_labels_node)) 150 | 151 | tf.scalar_summary(loss.op.name,loss) 152 | 153 | regularizers = (tf.nn.l2_loss(fc1_weights) + tf.nn.l2_loss(fc1_biases) + tf.nn.l2_loss(fc2_weights) + tf.nn.l2_loss(fc2_biases)) 154 | # Add the regularization term to the loss. 155 | loss += 5e-4 * regularizers 156 | 157 | # OPTIMIZER NODES 158 | batch = tf.Variable(0) 159 | # Decay once per epoch, using an exponential schedule starting at 0.01. 160 | learning_rate = tf.train.exponential_decay( 161 | 0.01, # Base learning rate. 162 | batch * BATCH_SIZE, # Current index into the dataset. 163 | TRAIN_SIZE, # Decay step. 164 | 0.95, # Decay rate. 165 | staircase=True) 166 | # Use simple momentum for the optimization. 167 | optimizer = tf.train.MomentumOptimizer(learning_rate,0.9).minimize(loss,global_step=batch) 168 | 169 | # # Predictions for the minibatch, validation set and test set. 170 | train_prediction = tf.nn.softmax(logits) 171 | # # We'll compute them only once in a while by calling their {eval()} method. 172 | validation_prediction = tf.nn.softmax(build_lenet(validation_data_node)) 173 | test_prediction = tf.nn.softmax(build_lenet(test_data_node)) 174 | 175 | summaries=tf.merge_all_summaries() 176 | 177 | # return input nodes and output nodes 178 | return (train_data_node, 179 | train_labels_node, 180 | validation_data_node, 181 | test_data_node, 182 | train_prediction, 183 | validation_prediction, 184 | test_prediction, 185 | conv1_weights, 186 | conv2_weights, 187 | fc1_weights, 188 | fc2_weights, 189 | optimizer, 190 | loss, 191 | learning_rate, 192 | summaries) 193 | 194 | def error_rate(predictions, labels): 195 | """Return the error rate based on dense predictions and 1-hot labels.""" 196 | return 100.0 - ( 197 | 100.0 * 198 | np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1)) / 199 | predictions.shape[0]) 200 | 201 | def main(): 202 | # get dataset as numpy arrays 203 | train_data, train_labels, validation_data, validation_labels, test_data, test_labels = get_data() 204 | #pdb.set_trace() 205 | 206 | # build net (return inputs) 207 | (train_data_node, 208 | train_labels_node, 209 | validation_data_node, 210 | test_data_node, 211 | optimizer, 212 | loss, 213 | learning_rate, 214 | train_prediction, 215 | validation_prediction, 216 | test_prediction, 217 | summaries) = build_model() 218 | 219 | with tf.Session() as s: 220 | # Run all the initializers to prepare the trainable parameters. 221 | tf.initialize_all_variables().run() 222 | print('Initialized!') 223 | # Loop through training steps. 224 | summary_writer=tf.train.SummaryWriter(FLAGS.train_dir, graph_def=s.graph_def) 225 | for step in xrange(NUM_EPOCHS * TRAIN_SIZE // BATCH_SIZE): 226 | # Compute the offset of the current minibatch in the data. 227 | offset = (step * BATCH_SIZE) % (TRAIN_SIZE - BATCH_SIZE) 228 | batch_data = train_data[offset:(offset + BATCH_SIZE), :, :, :] 229 | batch_labels = train_labels[offset:(offset + BATCH_SIZE)] 230 | feed_dict = { 231 | train_data_node: batch_data, 232 | train_labels_node: batch_labels 233 | } 234 | # Run the graph and fetch some of the nodes. 
235 | _, l, lr, predictions = s.run([optimizer, loss, learning_rate, train_prediction],feed_dict=feed_dict) 236 | 237 | if step % 100 == 0: 238 | # re-run graph, save summaries 239 | summary_str = summaries.eval(feed_dict) 240 | summary_writer.add_summary(summary_str, step) 241 | 242 | if step % 100 == 0: 243 | print('Epoch %.2f' % (float(step) * BATCH_SIZE / TRAIN_SIZE)) 244 | print('Minibatch loss: %.3f, learning rate: %.6f' % (l, lr)) 245 | print('Minibatch error: %.1f%%' % error_rate(predictions, batch_labels)) 246 | val_predict=validation_prediction.eval(feed_dict={validation_data_node:validation_data}) 247 | print('Validation error: %.1f%%' %error_rate(val_predict, validation_labels)) 248 | sys.stdout.flush() 249 | # Done training - print the result! 250 | test_error = error_rate(test_prediction.eval(feed_dict={test_data_node:test_data}), test_labels) 251 | print('Test error: %.1f%%' % test_error) 252 | 253 | if __name__ == "__main__": 254 | main() 255 | -------------------------------------------------------------------------------- /tdb/examples/viz.py: -------------------------------------------------------------------------------- 1 | # a collection of sample visualization functions 2 | # for binding to plotnode 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | def viz_square(data, normalize=True, cmap=plt.cm.gray, padsize=1, padval=0): 8 | """ 9 | takes a np.ndarray of shape (n, height, width) or (n, height, width, channels) 10 | visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n) 11 | However, this only draws first input channel 12 | """ 13 | # normalize to 0-1 range 14 | if normalize: 15 | data -= data.min() 16 | data /= data.max() 17 | n = int(np.ceil(np.sqrt(data.shape[0]))) # force square 18 | padding = ((0, n ** 2 - data.shape[0]), (0, padsize), (0, padsize)) + ((0, 0),) * (data.ndim - 3) 19 | data = np.pad(data, padding, mode='constant', constant_values=(padval, padval)) 20 | # tile the filters into an image 21 | data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1))) 22 | data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:]) 23 | plt.matshow(data,cmap=cmap) 24 | 25 | def viz_conv_weights(ctx, weight): 26 | # visualize all output filters 27 | # for the first input channel 28 | viz_square(weight.transpose(3,0,1,2)[:,:,:,0]) 29 | 30 | def viz_activations(ctx, m): 31 | plt.matshow(m.T,cmap=plt.cm.gray) 32 | plt.title("LeNet Predictions") 33 | plt.xlabel("Batch") 34 | plt.ylabel("Digit Activation") 35 | 36 | def viz_weight_hist(ctx, w): 37 | plt.hist(w.flatten()) 38 | 39 | def viz_conv_hist(ctx, w): 40 | n = int(np.ceil(np.sqrt(w.shape[3]))) # force square 41 | f, axes = plt.subplots(n,n,sharex=True,sharey=True) 42 | for i in range(w.shape[3]): # for each output channel 43 | r,c=i//n,i%n 44 | axes[r,c].hist(w[:,:,:,i].flatten()) 45 | axes[r,c].get_xaxis().set_visible(False) 46 | axes[r,c].get_yaxis().set_visible(False) 47 | 48 | def viz_fc_weights(ctx, w): 49 | # visualize fully connected weights 50 | plt.matshow(w.T,cmap=plt.cm.gray) 51 | 52 | def watch_loss(ctx,loss): 53 | if not hasattr(ctx, 'loss_history'): 54 | ctx.loss_history=[] 55 | ctx.loss_history.append(loss) 56 | plt.plot(ctx.loss_history) 57 | plt.ylabel('loss') 58 | -------------------------------------------------------------------------------- /tdb/ht_op.py: -------------------------------------------------------------------------------- 1 | """ 2 | abstract base class for hypertree Op 3 | 
""" 4 | 5 | class HTOp(object): 6 | """ 7 | Abstract class for HyperTree Operation 8 | """ 9 | def __init__(self, node_type, i, inputs, outputs): 10 | """ 11 | Args: 12 | node_type: enum type of node 13 | i: count of specific node type (used to compute name), incremented by constructor functions 14 | inputs: tf.Tensors 15 | outputs: tf.Tensors 16 | """ 17 | super(HTOp, self).__init__() 18 | self.node_type=node_type 19 | self.name=node_type+"_"+repr(i) 20 | self.session = None 21 | self.inputs=inputs 22 | self.outputs=outputs 23 | 24 | def set_session(self, debugsession): 25 | """ 26 | once nodes compute their designated input and output values 27 | they need to be able to update the DebugSession cache with the 28 | feed_dict values for their placeholder tensors. 29 | 30 | However, a DebugSession might not exist upon creation. That is why 31 | the DebugSession registers itself with all its nodes. 32 | 33 | Args: 34 | debugsession: instance of DebugSession 35 | """ 36 | self.session=debugsession 37 | 38 | def run(self,feed_dict): 39 | """ 40 | run produces the output Tensors, if any, of a given Node. 41 | This is in contrast to TensorFlow 42 | """ 43 | raise NotImplementedError('Please implement Node.run() in a subclass') 44 | -------------------------------------------------------------------------------- /tdb/interface.py: -------------------------------------------------------------------------------- 1 | """ 2 | top-level interface methods so user doesn't need to directly construct 3 | a dbsession 4 | """ 5 | 6 | import debug_session 7 | 8 | # default session 9 | _dbsession=None 10 | 11 | def debug(evals,feed_dict=None,breakpoints=None,break_immediately=False,session=None): 12 | """ 13 | spawns a new debug session 14 | """ 15 | global _dbsession 16 | _dbsession=debug_session.DebugSession(session) 17 | return _dbsession.run(evals,feed_dict,breakpoints,break_immediately) 18 | 19 | def s(): 20 | """ 21 | step to the next node in the execution order 22 | """ 23 | global _dbsession 24 | return _dbsession.s() 25 | 26 | def c(): 27 | """ 28 | continue 29 | """ 30 | global _dbsession 31 | return _dbsession.c() 32 | 33 | def get_exe_queue(): 34 | global _dbsession 35 | return _dbsession.get_exe_queue() 36 | 37 | def get_value(node): 38 | global _dbsession 39 | return _dbsession.get_value(node) -------------------------------------------------------------------------------- /tdb/op_store.py: -------------------------------------------------------------------------------- 1 | from toposort import toposort, toposort_flatten 2 | from transitive_closure import transitive_closure 3 | import tensorflow as tf 4 | 5 | _ops={} # Map 6 | _placeholder_2_op={} # Map 7 | 8 | def add_op(op): 9 | _ops[op.name]=op 10 | for t in op.outputs: 11 | _placeholder_2_op[t]=op 12 | 13 | def get_op(placeholder): 14 | return _placeholder_2_op[placeholder] 15 | 16 | def is_htop_out(placeholder): 17 | # returns True if placeholder is the output of a PythonOp 18 | return placeholder in _placeholder_2_op 19 | 20 | def compute_exe_order(evals): 21 | deps=compute_node_deps() 22 | eval_names=[e.name for e in evals] 23 | tc_deps=transitive_closure(eval_names, deps) 24 | ordered_names = toposort_flatten(tc_deps) 25 | return [get_node(name) for name in ordered_names] 26 | 27 | def get_node(name): 28 | """ 29 | returns HTOp or tf graph element corresponding to requested node name 30 | """ 31 | if name in _ops: 32 | return _ops[name] 33 | else: 34 | g=tf.get_default_graph() 35 | return g.as_graph_element(name) 36 | 37 | def 
register_dbsession(dbsession): 38 | for op in _ops.values(): 39 | op.set_session(dbsession) 40 | 41 | def compute_node_deps(): 42 | """ 43 | - returns the full dependency graph of ALL ops and ALL tensors 44 | Map> where key=node name, values=list of dependency names 45 | 46 | If an Op takes in a placeholder tensor that is the ouput of a PythonOp, 47 | we need to replace that Placeholder with the PythonOp. 48 | """ 49 | deps={} 50 | g=tf.get_default_graph() 51 | for op in g.get_operations(): 52 | d=set([i.name for i in op.control_inputs]) 53 | for t in op.inputs: 54 | if is_htop_out(t): 55 | d.add(get_op(t).name) 56 | else: 57 | d.add(t.name) 58 | deps[op.name]=d 59 | for t in op.outputs: 60 | deps[t.name]=set([op.name]) 61 | # do the same thing with HTOps 62 | for op in _ops.values(): 63 | d=set() 64 | for t in op.inputs: 65 | if is_htop_out(t): 66 | d.add(get_op(t).name) 67 | else: 68 | d.add(t.name) 69 | deps[op.name]=d 70 | return deps 71 | -------------------------------------------------------------------------------- /tdb/plot_op.py: -------------------------------------------------------------------------------- 1 | 2 | COUNT=0 3 | 4 | from python_op import PythonOp 5 | import app 6 | import inspect 7 | import matplotlib.pyplot as plt 8 | import op_store 9 | 10 | def plot_op(fn, inputs=[], outputs=[]): 11 | """ 12 | User-exposed api method for constructing a python_node 13 | 14 | Args: 15 | fn: python function that computes some np.ndarrays given np.ndarrays as inputs. it can have arbitrary side effects. 16 | inputs: array of tf.Tensors (optional). These are where fn derives its values from 17 | outputs: tf.Placeholder nodes (optional). These are constructed by the user (which allows the user to 18 | plug them into other ht.Ops or tf.Ops). The outputs of fn are mapped to each of the output placeholders. 19 | 20 | raises an Error if fn cannot map 21 | """ 22 | global COUNT, ht 23 | # check outputs 24 | if not isinstance(outputs,list): 25 | outputs=[outputs] 26 | 27 | for tensor in outputs: 28 | if tensor.op.type is not 'Placeholder': 29 | raise Error('Output nodes must be Placeholders') 30 | 31 | op=PlotOp(fn, COUNT, inputs, outputs) 32 | 33 | op_store.add_op(op) 34 | COUNT+=1 35 | 36 | # if node has output, return value for python_op is the first output (placeholder) tensor 37 | # otherwise, return the op 38 | if outputs: 39 | return outputs[0] 40 | else: 41 | return op 42 | 43 | class PlotOp(PythonOp): 44 | def __init__(self, fn, i, inputs, outputs): 45 | super(PlotOp, self).__init__('Plot', fn, i, inputs, outputs) 46 | 47 | def run(self, feed_dict): 48 | results=super(PlotOp, self).run(feed_dict) 49 | # send the image over 50 | if app.is_notebook(): 51 | fig=plt.gcf() 52 | app.send_fig(plt.gcf(), self.name) 53 | # close the figure 54 | plt.close(fig) 55 | return results 56 | -------------------------------------------------------------------------------- /tdb/python_op.py: -------------------------------------------------------------------------------- 1 | 2 | COUNT=0 3 | 4 | from ht_op import HTOp 5 | import inspect 6 | import numpy as np 7 | import op_store 8 | 9 | def python_op(fn, inputs=None, outputs=None): 10 | """ 11 | User-exposed api method for constructing a python_node 12 | 13 | Args: 14 | fn: python function that computes some np.ndarrays given np.ndarrays as inputs. it can have arbitrary side effects. 15 | inputs: array of tf.Tensors (optional). These are where fn derives its values from 16 | outputs: tf.Placeholder nodes (optional). 
These are constructed by the user (which allows the user to 17 | plug them into other ht.Ops or tf.Ops). The outputs of fn are mapped to each of the output placeholders. 18 | 19 | raises an Error if fn cannot map 20 | """ 21 | 22 | # construct a PythonOp and return its TensorNode outputs, if it has one 23 | global COUNT 24 | # check outputs 25 | if not isinstance(outputs,list): 26 | outputs=[outputs] 27 | for tensor in outputs: 28 | if tensor.op.type != 'Placeholder': 29 | raise TypeError('Output nodes must be Placeholders') 30 | op=PythonOp('Python', fn, COUNT, inputs, outputs) 31 | op_store.add_op(op) 32 | COUNT+=1 33 | if outputs: 34 | return outputs[0] 35 | else: 36 | return op 37 | 38 | class PythonOp(HTOp): 39 | """docstring for PythonOp""" 40 | def __init__(self, node_type, fn, i, inputs, outputs): 41 | """ 42 | constructor. user does not call this. 43 | """ 44 | super(PythonOp, self).__init__(node_type, i, inputs, outputs) 45 | self.fn=fn 46 | 47 | def run(self, feed_dict): 48 | #pdb.set_trace() 49 | args=tuple(feed_dict[i] for i in self.inputs) 50 | results=self.fn(self, *args) 51 | self.cache_values(results) 52 | return results 53 | 54 | def cache_values(self, results): 55 | """ 56 | loads into DebugSession cache 57 | """ 58 | if results is None: 59 | # self.fn was probably only used to compute side effects. 60 | return 61 | elif isinstance(results,np.ndarray): 62 | # fn returns single np.ndarray. 63 | # re-format it into a list 64 | results=[results] 65 | # check validity of fn output 66 | elif isinstance(results,list): 67 | if len(results) is not len(self.outputs): 68 | raise ValueError('Number of output tensors does not match number of outputs produced by function') 69 | elif isinstance(results,np.number): 70 | if len(self.outputs) != 1: 71 | raise ValueError('Fn produces scalar but %d outputs expected' % (len(self.outputs))) 72 | results=[results] 73 | # assign each element in ndarrays to corresponding output tensor 74 | for i,ndarray in enumerate(results): 75 | self.session._cache_value(self.outputs[i], ndarray) -------------------------------------------------------------------------------- /tdb/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ericjang/tdb/5e78b5dbecf78b6d28eb2f5b67decf8d1f1eb17d/tdb/tests/__init__.py -------------------------------------------------------------------------------- /tdb/tests/mnist_0.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ericjang/tdb/5e78b5dbecf78b6d28eb2f5b67decf8d1f1eb17d/tdb/tests/mnist_0.npz -------------------------------------------------------------------------------- /tdb/tests/run_tests.py: -------------------------------------------------------------------------------- 1 | #!/bin/python 2 | 3 | import unittest 4 | 5 | def main(): 6 | test_modules=[ 7 | 'test_exe_order', 8 | 'test_pure_tf', 9 | 'test_pure_ht', 10 | 'test_mixed', 11 | 'test_ui', 12 | 'test_mnist' 13 | ] 14 | 15 | suite=unittest.TestSuite() 16 | 17 | for t in test_modules: 18 | try: 19 | # If the module defines a suite() function, call it to get the suite. 20 | mod = __import__(t, globals(), locals(), ['suite']) 21 | suitefn = getattr(mod, 'suite') 22 | suite.addTest(suitefn()) 23 | except (ImportError, AttributeError): 24 | # else, just load all the test cases from the module. 
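# loadTestsFromName imports the named module and collects every unittest.TestCase
# it defines, so test modules that do not provide a suite() function still run.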
25 | suite.addTest(unittest.defaultTestLoader.loadTestsFromName(t)) 26 | 27 | unittest.TextTestRunner().run(suite) 28 | 29 | if __name__ == '__main__': 30 | main() -------------------------------------------------------------------------------- /tdb/tests/test_exe_order.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import tdb 4 | from tdb.transitive_closure import transitive_closure 5 | from test_pure_tf import build_graph_tf 6 | 7 | 8 | import tensorflow as tf 9 | 10 | 11 | """ 12 | testing exe order stuff (topological sorting, transitive_closure, etc.) 13 | 14 | to run this test: python -m unittest test_exe_order 15 | 16 | """ 17 | 18 | class TestExeOrderMethods(unittest.TestCase): 19 | def test_closure(self): 20 | """ 21 | evaluating 8 in this graph does not depend on evaluating 6 22 | so 6 should be excluded from the closure. 23 | """ 24 | G={ 25 | 1:{}, 26 | 2:{}, 27 | 3:{}, 28 | 4:{1,2}, 29 | 5:{2,3}, 30 | 6:{3}, 31 | 7:{5}, 32 | 8:{4,7} 33 | } 34 | T=transitive_closure([8],G) 35 | for i in [1,2,3,4,5,7,8]: 36 | self.assertTrue(i in T) 37 | self.assertFalse(6 in T) 38 | 39 | def test_tf(self): 40 | """ 41 | ensures that execution ordering is correct 42 | """ 43 | build_graph_tf() 44 | g=tf.get_default_graph() 45 | for op in g.get_operations(): 46 | print(op.name) 47 | 48 | deps=tdb.op_store.compute_node_deps() 49 | unidict = {k.encode('ascii'): set([v.encode('ascii') for v in s]) for k, s in deps.items()} 50 | 51 | a="Const" 52 | b="Const_1" 53 | a0=a+":0" 54 | b0=b+":0" 55 | c="Add" 56 | c0=c+":0" 57 | d="Mul" 58 | d0=d+":0" 59 | e="Neg" 60 | e0=e+":0" 61 | target = { 62 | a:set(), 63 | b:set(), 64 | a0:set([a]), 65 | b0:set([b]), 66 | c:set([a0,b0]), 67 | c0:set([c]), 68 | d:set([a0,b0]), 69 | d0:set([d]), 70 | e:set([c0]), 71 | e0:set([e]) 72 | } 73 | print("deps") 74 | print(unidict) 75 | print("target") 76 | print(target) 77 | self.assertEqual(deps,target) 78 | -------------------------------------------------------------------------------- /tdb/tests/test_mixed.py: -------------------------------------------------------------------------------- 1 | """ 2 | test heterogenous graph consisting of hypertree and tensorflow nodes 3 | 4 | todo - include a graph that uses placeholder nodes for input AND 5 | has placeholder ops from PythonOps. 
6 | 7 | cases: 8 | 9 | tf op -> ht op 10 | """ 11 | 12 | 13 | 14 | import sys 15 | import tensorflow as tf 16 | import unittest 17 | import tdb 18 | 19 | from test_pure_ht import myadd,mymult,myneg 20 | 21 | class TestDebuggingMixed(unittest.TestCase): 22 | def test_1(self): 23 | """ 24 | ht->tf 25 | """ 26 | a=tf.constant(2) 27 | b=tf.constant(3) 28 | c=tdb.python_op(myadd,inputs=[a,b],outputs=[tf.placeholder(tf.int32)]) # a+b 29 | d=tf.neg(c) 30 | status,result=tdb.debug([d], feed_dict=None, breakpoints=None, break_immediately=False) 31 | self.assertEqual(status, tdb.FINISHED) 32 | self.assertEqual(result[0],-5) 33 | 34 | def test_2(self): 35 | """ 36 | tf -> ht 37 | """ 38 | a=tf.constant(2) 39 | b=tf.constant(3) 40 | c=tf.add(a,b) 41 | d=tdb.python_op(myneg,inputs=[c],outputs=[tf.placeholder(tf.int32)]) 42 | status,result=tdb.debug([d], feed_dict=None, breakpoints=None, break_immediately=False) 43 | self.assertEqual(status, tdb.FINISHED) 44 | self.assertEqual(result[0],-5) 45 | self.assertEqual(tdb.get_value(d),-5) 46 | self.assertEqual(tdb.get_value(d.name),-5) 47 | -------------------------------------------------------------------------------- /tdb/tests/test_mnist.py: -------------------------------------------------------------------------------- 1 | """ 2 | HT debugging on a simple LeNET-5 convolutional model 3 | """ 4 | 5 | import numpy as np 6 | import sys 7 | import tensorflow as tf 8 | import unittest 9 | import tdb 10 | from tdb.examples import mnist, viz 11 | 12 | 13 | class TestMNIST(unittest.TestCase): 14 | def test_1(self): 15 | # single passthrough 16 | (train_data_node, 17 | train_labels_node, 18 | validation_data_node, 19 | test_data_node, 20 | # predictions 21 | train_prediction, 22 | validation_prediction, 23 | test_prediction, 24 | # weights 25 | conv1_weights, 26 | conv2_weights, 27 | fc1_weights, 28 | fc2_weights, 29 | # training 30 | optimizer, 31 | loss, 32 | learning_rate, 33 | summaries) = mnist.build_model() 34 | 35 | with tf.Session() as s: 36 | tf.initialize_all_variables().run() 37 | print('Variables initialized') 38 | step=0 39 | with np.load("mnist_0.npz") as data: 40 | feed_dict = { 41 | train_data_node: data['batch_data'], 42 | train_labels_node: data['batch_labels'] 43 | } 44 | evals=[train_prediction] 45 | status,result=tdb.debug(evals, feed_dict=feed_dict, breakpoints=None, break_immediately=False, session=s) 46 | self.assertEqual(status,tdb.FINISHED) 47 | 48 | def test_2(self): 49 | """ 50 | mnist with plotting 51 | """ 52 | (train_data_node, 53 | train_labels_node, 54 | validation_data_node, 55 | test_data_node, 56 | # predictions 57 | train_prediction, 58 | validation_prediction, 59 | test_prediction, 60 | # weights 61 | conv1_weights, 62 | conv2_weights, 63 | fc1_weights, 64 | fc2_weights, 65 | # training 66 | optimizer, 67 | loss, 68 | learning_rate, 69 | summaries) = mnist.build_model() 70 | 71 | s=tf.InteractiveSession() 72 | tf.initialize_all_variables().run() 73 | 74 | # use the same input every time for this test 75 | with np.load("mnist_0.npz") as data: 76 | a=data['batch_data'] 77 | b=data['batch_labels'] 78 | feed_dict = { 79 | train_data_node: a, 80 | train_labels_node: b 81 | } 82 | 83 | # pdb.set_trace() 84 | # result=s.run(optimizer,feed_dict) 85 | # pdb.set_trace() 86 | # tmp 87 | # return 88 | 89 | evals=[optimizer,loss,train_prediction,conv1_weights,conv2_weights,fc1_weights,fc2_weights] 90 | 91 | # define some plotting functions 92 | 93 | # use one debugSession per run 94 | 95 | # attach plot nodes 96 | g=tf.get_default_graph() 
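# tdb.plot_op (tdb/plot_op.py) wraps a matplotlib function as a graph node: when it
# runs, the cached value of each input tensor is passed to the function as a
# np.ndarray (fn(ctx, *arrays)); with no output placeholders the PlotOp itself is
# returned, so p1..p4 below can be added directly to the evals list.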
97 | p1=tdb.plot_op(viz.viz_conv_weights,inputs=[g.as_graph_element(conv1_weights)]) 98 | p2=tdb.plot_op(viz.viz_conv_weights,inputs=[g.as_graph_element(conv2_weights)]) 99 | p3=tdb.plot_op(viz.viz_fc_weights,inputs=[g.as_graph_element(fc1_weights)]) 100 | p4=tdb.plot_op(viz.viz_fc_weights,inputs=[g.as_graph_element(fc2_weights)]) 101 | 102 | # get the plot op by name and 103 | evals=[optimizer, loss, learning_rate, train_prediction, p1,p2,p3,p4] 104 | status,result=tdb.debug(evals, feed_dict=feed_dict, session=s) 105 | -------------------------------------------------------------------------------- /tdb/tests/test_pure_ht.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test networks consisting only of hypertree nodes (no TensorFlow nodes) 3 | 4 | This replicates the pure-Tensorflow network from test_pure_tf, but implements 5 | the ops as HyperTree PythonOps. 6 | """ 7 | 8 | import tensorflow as tf 9 | import unittest 10 | import tdb 11 | 12 | def myadd(ctx,a,b): 13 | """ 14 | a,b are scalars 15 | """ 16 | return a+b 17 | 18 | def mymult(ctx, a,b): 19 | """ 20 | a,b are scalars 21 | """ 22 | return a*b 23 | 24 | def myneg(ctx, a): 25 | return -a 26 | 27 | def build_graph_ht(): 28 | a=tf.constant(2) 29 | b=tf.constant(3) 30 | c=tdb.python_op(myadd,inputs=[a,b],outputs=[tf.placeholder(tf.int32)]) 31 | c2=tdb.python_op(mymult,inputs=[a,b],outputs=[tf.placeholder(tf.int32)]) 32 | d=tdb.python_op(myneg,inputs=[c],outputs=[tf.placeholder(tf.int32)]) 33 | return a,b,c,c2,d 34 | 35 | class TestDebuggingHT(unittest.TestCase): 36 | def test_1(self): 37 | """ 38 | See TestDebuggingTF.test_1 39 | """ 40 | # construct TensorFlow graph as usual 41 | a,b,c,c2,d=build_graph_ht() 42 | evals=[a,b,c,c2,d] 43 | status,result=tdb.debug(evals, feed_dict=None, breakpoints=None, break_immediately=False) 44 | self.assertEqual(status, tdb.FINISHED) 45 | self.assertEqual(result[0],2) # a = 2 46 | self.assertEqual(result[1],3) # b = 3 47 | self.assertEqual(result[2],5) # c = 5 48 | self.assertEqual(result[4],-5) # c2 = -5 49 | self.assertEqual(result[3],6) # d = 6 50 | 51 | def test_2(self): 52 | """ 53 | See TestDebuggingTF.test_2 54 | """ 55 | a,b,c,c2,d=build_graph_ht() 56 | status,result=tdb.debug([c], feed_dict=None, breakpoints=None, break_immediately=True) 57 | self.assertEqual(status, tdb.PAUSED) 58 | status,result=tdb.c() # continue 59 | self.assertEqual(status, tdb.FINISHED) 60 | self.assertEqual(result[0],5) # check that c = 5 61 | -------------------------------------------------------------------------------- /tdb/tests/test_pure_tf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Headless debugging of HyperTree where nodes only consist of 3 | TensorFlow nodes 4 | """ 5 | 6 | import tensorflow as tf 7 | import unittest 8 | import tdb 9 | 10 | def build_graph_tf(): 11 | a=tf.constant(2) 12 | b=tf.constant(3) 13 | c=tf.add(a,b) 14 | c2=tf.mul(a,b) 15 | d=tf.neg(c) 16 | return a,b,c,c2,d 17 | 18 | class TestDebuggingTF(unittest.TestCase): 19 | def test_1(self): 20 | """ 21 | test debugging of a pure TensorFlow graph 22 | no breakpoints, all nodes evaluated 23 | this should automatically build an InteractiveSession for us and create a HyperTree 24 | """ 25 | # construct TensorFlow graph as usual 26 | a,b,c,c2,d=build_graph_tf() 27 | evals=[a,b,c,c2,d] 28 | status,result=tdb.debug(evals, feed_dict=None, breakpoints=None, break_immediately=False) 29 | self.assertEqual(status, tdb.FINISHED) 30 | 
self.assertEqual(result[0],2) # a = 2 31 | self.assertEqual(result[1],3) # b = 3 32 | self.assertEqual(result[2],5) # c = 5 33 | self.assertEqual(result[4],-5) # d = -5 34 | self.assertEqual(result[3],6) # c2 = 6 35 | 36 | def test_2(self): 37 | """ 38 | single eval of the penultimate node 39 | break immediately. 40 | verify that the execution order does NOT contain d or c2 41 | """ 42 | a,b,c,c2,d=build_graph_tf() 43 | status,result=tdb.debug([c], feed_dict=None, breakpoints=None, break_immediately=True) 44 | self.assertEqual(status, tdb.PAUSED) 45 | status,result=tdb.c() # continue 46 | self.assertEqual(status, tdb.FINISHED) 47 | self.assertEqual(result[0],5) # check that c = 5 48 | 49 | def test_3(self): 50 | """ 51 | with breakpoints 52 | """ 53 | # construct TensorFlow graph as usual 54 | a,b,c,c2,d=build_graph_tf() 55 | status,result=tdb.debug(d, feed_dict=None, breakpoints=[c], break_immediately=False) 56 | self.assertEqual(status, tdb.PAUSED) 57 | self.assertEqual(result, None) 58 | status,result=tdb.c() 59 | self.assertEqual(status, tdb.FINISHED) 60 | self.assertEqual(result[0],-5) -------------------------------------------------------------------------------- /tdb/tests/test_ui.py: -------------------------------------------------------------------------------- 1 | 2 | import unittest 3 | import tdb 4 | 5 | class TestUI(unittest.TestCase): 6 | def test_1(self): 7 | # verify that ui is indeed disabled 8 | self.assertFalse(tdb.is_notebook()) 9 | -------------------------------------------------------------------------------- /tdb/transitive_closure.py: -------------------------------------------------------------------------------- 1 | 2 | def _tchelper(tc_deps,evals,deps): 3 | """ 4 | modifies graph in place 5 | """ 6 | for e in evals: 7 | if e in tc_deps: # we've already included it 8 | continue 9 | else: 10 | if e in deps: # has additional dependencies 11 | tc_deps[e]=deps[e] 12 | # add to tc_deps the dependencies of the dependencies 13 | _tchelper(tc_deps,deps[e],deps) 14 | return tc_deps 15 | 16 | def transitive_closure(evals,deps): 17 | """ 18 | evals = node names we want values for (i.e.
we don't care about any other nodes after 19 | we've evaluated all the eval nodes) 20 | deps = full dependency graph of all nodes 21 | """ 22 | return _tchelper({},evals,deps) 23 | -------------------------------------------------------------------------------- /tdb_ext/activate_dev.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PATH=$PATH:node_modules/bower/bin 4 | -------------------------------------------------------------------------------- /tdb_ext/bower.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ht", 3 | "version": "0.0.0", 4 | "authors": [ 5 | "Eric Jang " 6 | ], 7 | "main": "main.js", 8 | "keywords": [ 9 | "deep learning", 10 | "machine learning", 11 | "tensorflow", 12 | "AI" 13 | ], 14 | "license": "Apache 2.0", 15 | "ignore": [ 16 | "**/.*", 17 | "node_modules", 18 | "bower_components" 19 | ], 20 | "dependencies": { 21 | "flux": "~2.1.1", 22 | "eventemitter2": "~0.4.14", 23 | "react": "~0.13.3", 24 | "requirejs-react-jsx": "~0.14.2" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /tdb_ext/components/plotlistview.jsx: -------------------------------------------------------------------------------- 1 | // plottable watch variable in HyperTree 2 | 3 | 4 | var COMPONENTS='/nbextensions/tdb_ext/components/' 5 | 6 | 7 | define([ 8 | 'react', 9 | 'jquery', 10 | 'dispatcher', 11 | 'jsx!' + COMPONENTS + 'plotview.jsx' 12 | ], function(React,$,dispatcher,PlotView){ 13 | 14 | var PlotListView = React.createClass({ 15 | render: function() { 16 | var plots=[] 17 | 18 | var allPlots = this.props.plots 19 | for (var key in allPlots) { 20 | plots.push(<PlotView key={key} name={key} src={allPlots[key]}/>) 21 | } 22 | 23 | return ( 24 |
<div> 25 | {plots} 26 | </div>
27 | ) 28 | } 29 | }) 30 | 31 | return PlotListView 32 | }); 33 | -------------------------------------------------------------------------------- /tdb_ext/components/plotview.jsx: -------------------------------------------------------------------------------- 1 | // plottable watch variable in HyperTree 2 | 3 | define([ 4 | 'react', 5 | 'jquery', 6 | 'dispatcher', 7 | ], function(React,$,dispatcher){ 8 | 9 | 10 | var PlotView = React.createClass({ 11 | propTypes : { 12 | name: React.PropTypes.string, 13 | src: React.PropTypes.string 14 | }, 15 | render: function() { 16 | // need an x close button 17 | var css ={ 18 | border:"1px solid", 19 | borderRadius:"5px", 20 | margin:"10px", 21 | position: "relative" 22 | } 23 | var imgcss = { 24 | maxWidth: "100%", 25 | padding:"8px" 26 | } 27 | var h3css={ 28 | textAlign: "center" 29 | } 30 | var closecss={ 31 | position: "absolute", 32 | top: "20px", 33 | right: "20px" 34 | } 35 | return ( 36 |
<div style={css}> 37 | <h3 style={h3css}>{this.props.name}</h3> 38 | <img style={imgcss} src={this.props.src}></img> 39 | <button style={closecss} onClick={this._onDestroyClick}> 40 | x 41 | </button> 42 | </div>
43 | ) 44 | }, 45 | _onDestroyClick: function() { 46 | dispatcher.dispatch({ 47 | actionType:'remove_plot', 48 | name:this.props.name 49 | }) 50 | } 51 | }) 52 | 53 | return PlotView 54 | }); 55 | -------------------------------------------------------------------------------- /tdb_ext/components/textinputview.jsx: -------------------------------------------------------------------------------- 1 | // plottable watch variable in HyperTree 2 | 3 | define([ 4 | 'react', 5 | 'jquery', 6 | 'dispatcher', 7 | ], function(React,$,dispatcher){ 8 | 9 | 10 | var TextInputView = React.createClass({ 11 | propTypes : { 12 | name: React.PropTypes.string, 13 | src: React.PropTypes.string 14 | }, 15 | render: function() { 16 | // need an x close button 17 | var css ={ 18 | border:"1px solid", 19 | borderRadius:"5px", 20 | margin:"10px" 21 | } 22 | var imgcss = { 23 | maxWidth: "100%", 24 | margin:"5px" 25 | } 26 | return ( 27 |
<div style={css}> 28 | <h3>{this.props.name}</h3> 29 | <input style={imgcss} type="text"></input> 30 | </div>
31 | ) 32 | } 33 | }) 34 | 35 | return TextInputView 36 | }); 37 | -------------------------------------------------------------------------------- /tdb_ext/components/ui.jsx: -------------------------------------------------------------------------------- 1 | // class that manages the top-level React component 2 | // handles UI-related methods 3 | 4 | var COMPONENTS='/nbextensions/tdb_ext/components/' 5 | 6 | define([ 7 | 'react', 8 | 'jquery', 9 | 'dispatcher', 10 | 'plotstore', 11 | 'jsx!' + COMPONENTS + 'plotlistview', 12 | 'jsx!' + COMPONENTS + 'user_msg_view' 13 | ], function(React,$,dispatcher,PlotStore, PlotListView, UserMsgView){ 14 | 15 | var UI=function(){ 16 | this.had_loaded=false 17 | // starting notebook width 18 | this._nbw = .65 // width of notebook (percentage) when HT pane is open 19 | 20 | this.UIView = React.createClass({ 21 | getInitialState: function() { 22 | return this._getState() 23 | }, 24 | render: function() { 25 | var css={ 26 | 'backgroundColor':'#FFFFFF', 27 | } 28 | return ( 29 |
<div style={css}> 30 | <UserMsgView/> 31 | <PlotListView plots={this.state.plots}/> 32 | </div>
33 | ) 34 | }, 35 | componentDidMount: function() { 36 | // listen to PlotStore 37 | PlotStore.addChangeListener(this._onChange) 38 | }, 39 | _onChange: function() { 40 | this.setState(this._getState()) 41 | }, 42 | _getState: function() { 43 | return { 44 | "plots":PlotStore.getPlots() 45 | } 46 | } 47 | }) 48 | 49 | } 50 | 51 | UI.prototype.load_ui = function() { 52 | // initializes the UI 53 | 54 | $('#site').after('
<div id="ht_separator"></div>') 55 | 56 | var w=$('#site').width() 57 | var h=$('#site').height() 58 | 59 | $('#ht_separator').css({ 60 | height: "100%", 61 | width: "1%", 62 | float: "left", 63 | backgroundColor: "grey" 64 | }) 65 | 66 | $('#ht_separator').after('<div id="ht_main"></div>
') 67 | 68 | 69 | $('#ht_main').width((1-this._nbw-.02)*w) 70 | 71 | $('#ht_main').css({ 72 | float:'left', 73 | height:'100%', 74 | overflow:'scroll' 75 | }) 76 | 77 | var self=this 78 | $('#ht_separator').draggable({ 79 | axis: 'x', 80 | containment: 'parent', 81 | helper: 'clone', 82 | start: function(event, ui) { 83 | self.ow=$(window).width() 84 | }, 85 | drag: function (event, ui) { 86 | //console.log(ui) 87 | var width=ui.offset.left 88 | $(this).prev().width(width) 89 | //console.log(self.ow) 90 | $(this).next().width(self.ow-width-0.01*self.ow) 91 | } 92 | }); 93 | 94 | React.render(,document.getElementById("ht_main")) 95 | this.has_loaded=true 96 | } 97 | 98 | UI.prototype.show_ui = function(){ 99 | // load the split view 100 | // slide notebook to left and inject ui 101 | 102 | // something weird going on - has show_ui and hide_ui in prototype 103 | // but missing load_ui and test() 104 | // for now, manually access the global object 105 | 106 | $('#site').width(this._nbw*100+'%') 107 | $('#notebook-container').width(this._nbw*100+'%') 108 | $('#site').css('float','left'); 109 | 110 | $('#ht_main').show() 111 | } 112 | 113 | UI.prototype.hide_ui = function() { 114 | // hide the split view 115 | $('#ht_main').hide() 116 | $('#site').width('') 117 | $('#notebook-container').width('') 118 | } 119 | 120 | return UI 121 | }) 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /tdb_ext/components/user_msg_view.jsx: -------------------------------------------------------------------------------- 1 | /* 2 | * system messages/errors appear here 3 | */ 4 | 5 | define([ 6 | 'react', 7 | 'jquery', 8 | 'dispatcher', 9 | ], function(React,$,dispatcher){ 10 | 11 | var UserMsgView = React.createClass({ 12 | getInitialState: function() { 13 | return {msg:"Waiting for TDB to connect..."} 14 | }, 15 | render: function() { 16 | return ( 17 |
<div> 18 | <h3>{this.state.msg}</h3> 19 | </div>
20 | ) 21 | }, 22 | componentDidMount: function() { 23 | var self=this 24 | dispatcher.register(function(action){ 25 | if (action.actionType == 'user_msg') 26 | self.setState({msg:action.data}) 27 | }) 28 | } 29 | }) 30 | 31 | return UserMsgView 32 | }); 33 | 34 | -------------------------------------------------------------------------------- /tdb_ext/config.yaml: -------------------------------------------------------------------------------- 1 | # Jupyter Notebook Extension Description 2 | Name: HyperTree 3 | Description: Deep Learning IDE 4 | Link: evjang.com/hypertree 5 | #Icon: icon.png 6 | Main: main.js 7 | Compatibility: 4.x 8 | -------------------------------------------------------------------------------- /tdb_ext/dispatcher.js: -------------------------------------------------------------------------------- 1 | // application-wide event dispatcher 2 | 3 | define([ 4 | "flux" 5 | ], function(flux){ 6 | return new flux.Dispatcher() 7 | }) 8 | -------------------------------------------------------------------------------- /tdb_ext/htapp.js: -------------------------------------------------------------------------------- 1 | // communicates with the corresponding htapp on the python side 2 | 3 | 4 | define([ 5 | 'base/js/namespace', 6 | 'jsx!/nbextensions/tdb_ext/components/ui', 7 | 'base/js/events', 8 | 'dispatcher' 9 | ], function(Jupyter,UI,events,dispatcher){ 10 | 11 | var HTApp=function(comm_manager){ 12 | // 13 | // MEMBER VARIABLES 14 | // 15 | 16 | this.ui = new UI() 17 | this._comm = null 18 | this.is_connected = true 19 | 20 | // 21 | // CONSTRUCTOR INITIALIZATION 22 | // 23 | var self=this 24 | comm_manager.register_target('tdb', function(comm,msg){ 25 | dispatcher.dispatch({ 26 | actionType:'user_msg', 27 | data:'TDB connected: success' 28 | }) 29 | 30 | self._comm=comm 31 | this.is_connected=true 32 | comm.on_msg(function(msg){ 33 | var data=msg['content']['data'] 34 | 35 | var msg_type=data['msg_type'] 36 | if (msg_type == 'action') { 37 | // if we receive an action, send it to dispatcher 38 | var action={ 39 | actionType: data['action'], 40 | data: data['params'] 41 | } 42 | dispatcher.dispatch(action) 43 | } else { 44 | console.log('Unrecognized msg_type ' + msg_type) 45 | console.log(data) 46 | } 47 | }) 48 | }) 49 | 50 | events.on('kernel_restarting.Kernel', function() { 51 | dispatcher.dispatch({ 52 | actionType: 'clear' 53 | }) 54 | }); 55 | 56 | } 57 | 58 | 59 | return HTApp 60 | }) 61 | 62 | 63 | -------------------------------------------------------------------------------- /tdb_ext/keymirror.js: -------------------------------------------------------------------------------- 1 | 2 | /** 3 | * AMD-version of Facebook's Keymirror 4 | * Constructs an enumeration with keys equal to their value. 5 | * 6 | * For example: 7 | * 8 | * var COLORS = keyMirror({blue: null, red: null}); 9 | * var myColor = COLORS.blue; 10 | * var isColorValid = !!COLORS[myColor]; 11 | * 12 | * The last line could not be performed if the values of the generated enum were 13 | * not equal to their keys. 
14 | * 15 | * Input: {key1: val1, key2: val2} 16 | * Output: {key1: key1, key2: key2} 17 | * 18 | * @param {object} obj 19 | * @return {object} 20 | */ 21 | 22 | define([],function(){ 23 | var keyMirror = function(obj) { 24 | var ret = {}; 25 | var key; 26 | if (!(obj instanceof Object && !Array.isArray(obj))) { 27 | throw new Error('keyMirror(...): Argument must be an object.'); 28 | } 29 | for (key in obj) { 30 | if (obj.hasOwnProperty(key)) { 31 | ret[key] = key; 32 | } 33 | } 34 | return ret; 35 | }; 36 | return keyMirror; 37 | }) 38 | -------------------------------------------------------------------------------- /tdb_ext/main.js: -------------------------------------------------------------------------------- 1 | /* 2 | This is the entry point into HyperTree nbextension 3 | after installing into ~/.jupyter/nbextensions/, do 4 | javascript:require(["nbextensions/tdb_ext"] function(ht){}); 5 | */ 6 | 7 | var ROOT = '/nbextensions/tdb_ext' 8 | var BOWER = ROOT+'/bower_components' 9 | var STORES = '/nbextensions/tdb_ext/stores' 10 | 11 | // easy access to vendor libs 12 | 13 | require.config({ 14 | paths: { 15 | "dispatcher": ROOT + "/dispatcher", 16 | "flux":BOWER+"/flux/dist/Flux.min", 17 | "eventemitter2":BOWER + '/eventemitter2/lib/eventemitter2', 18 | "keyMirror":ROOT+'/keymirror', 19 | "raphael":BOWER + "/raphael/raphael-min", 20 | "dispatcher": ROOT + "/dispatcher", 21 | "react": BOWER + "/react/react-with-addons", 22 | "JSXTransformer": BOWER + "/react/JSXTransformer", 23 | "jsx": BOWER + "/requirejs-react-jsx/jsx", 24 | "text": BOWER + "/requirejs-text/text", 25 | "plotstore":STORES + "/plotstore" 26 | }, 27 | shim : { 28 | "react": { 29 | "exports": "React" 30 | }, 31 | "JSXTransformer": "JSXTransformer" 32 | }, 33 | config: { 34 | jsx: { 35 | fileExtension: ".jsx", 36 | transformOptions: { 37 | harmony: true, 38 | stripTypes: false, 39 | inlineSourceMap: true 40 | }, 41 | usePragma: false 42 | } 43 | } 44 | }); 45 | 46 | define([ 47 | 'base/js/namespace', 48 | ROOT + '/htapp.js' 49 | ], function(Jupyter,HTApp){ 50 | var kernel = Jupyter.notebook.kernel 51 | var comm_manager=Jupyter.notebook.kernel.comm_manager 52 | // icky - binding as a global variable 53 | HT=new HTApp(comm_manager) 54 | HT.ui.load_ui() 55 | HT.ui.show_ui() 56 | console.log('HT nbextension loaded') 57 | }) 58 | -------------------------------------------------------------------------------- /tdb_ext/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ht", 3 | "version": "0.0.0", 4 | "description": "Deep Learning IDE", 5 | "main": "", 6 | "scripts": { 7 | "test": "echo \"Error: no test specified\" && exit 1" 8 | }, 9 | "author": "Eric Jang", 10 | "license": "Apache 2.0", 11 | "devDependencies": { 12 | "bower": "~1.6.7" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /tdb_ext/stores/plotstore.js: -------------------------------------------------------------------------------- 1 | 2 | define([ 3 | 'dispatcher', 4 | 'eventemitter2' 5 | ], function(dispatcher,EventEmitter){ 6 | 7 | var CHANGE_EVENT = 'change' 8 | 9 | // key = PlotOp node name 10 | // value = img src 11 | var _plots = {} 12 | 13 | 14 | var PlotStore = Object.assign({}, EventEmitter.prototype, { 15 | getPlots: function() { 16 | return _plots 17 | }, 18 | update: function(data) { 19 | _plots[data.name]=data.src 20 | }, 21 | clear: function() { 22 | _plots={} 23 | }, 24 | remove: function(name) { 25 | if (_plots.hasOwnProperty(name)) { 
26 | delete _plots[name] 27 | } 28 | }, 29 | emitChange: function() { 30 | this.emit(CHANGE_EVENT); 31 | }, 32 | addChangeListener: function(callback) { 33 | this.on(CHANGE_EVENT, callback) 34 | }, 35 | removeChangeListener: function(callback) { 36 | this.removeListener(CHANGE_EVENT, callback) 37 | } 38 | }) 39 | 40 | // register store with the dispatcher 41 | dispatcher.register(function(action){ 42 | switch(action.actionType){ 43 | case 'update_plot': 44 | // create new plot or update an old one 45 | PlotStore.update(action.data) 46 | PlotStore.emitChange() 47 | break; 48 | case 'remove_plot': 49 | PlotStore.remove(action.name) 50 | PlotStore.emitChange() 51 | break; 52 | case 'clear': 53 | PlotStore.clear() 54 | PlotStore.emitChange() 55 | break; 56 | default: 57 | // no op 58 | } 59 | }) 60 | 61 | return PlotStore 62 | }) 63 | --------------------------------------------------------------------------------
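A minimal end-to-end sketch of the API above, assembled from tdb/interface.py, tdb/plot_op.py and the tests (not a file in the repository; names follow those files):

import tensorflow as tf
import tdb
from tdb.examples import viz

a = tf.constant(2)
b = tf.constant(3)
c = tf.add(a, b)
p = tdb.plot_op(viz.watch_loss, inputs=[c])            # attach a plot node to c
status, result = tdb.debug([c, p], breakpoints=[c])    # pauses at the breakpoint on c
while status == tdb.PAUSED:
    status, result = tdb.c()                           # continue until FINISHED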