├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── docs ├── Makefile ├── conf.py ├── index.rst ├── nbuild │ └── .gitkeep ├── nstatic │ └── .gitkeep └── ntemplates │ └── .gitkeep ├── examples ├── README.md ├── estimator-example.py ├── horovod-all-example.py ├── horovod-example.py ├── load_session-example.py ├── monitoredtrainingsession-example.py ├── monkey_patching-example.py ├── options-example.py └── timeline-example.py ├── gallery ├── README.md ├── global-tracing.png ├── horovod-timeline-details.png ├── horovod-timeline.png ├── main.png └── tracing.png ├── requirements.txt ├── setup.py └── tftracer ├── __init__.py ├── __main__.py ├── monkey_patching.py ├── resources ├── templates │ ├── on_change_callback.js │ ├── on_click_callback.js │ ├── on_hover_callback.js │ ├── timeline.html │ ├── tooltips.html │ └── update_ranges.js └── web │ ├── app.js │ ├── main.html │ └── tensorflow.png ├── timeline.py ├── timeline_visualizer.py ├── tracing_server.py └── version.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/nbuild/html 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # IPython 78 | profile_default/ 79 | ipython_config.py 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | .dmypy.json 112 | dmypy.json 113 | 114 | # Pyre type checker 115 | .pyre/ 116 | 117 | # PyCharm 118 | .idea/ 119 | 120 | # macOS 121 | .DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Runtime Tracer 2 | This project is a web application to monitor and trace TensorFlow scripts in the runtime on the `op` level. 3 | 4 | It starts a web server upon the execution of the script. The web interface keeps track of all the session runs and can trace the execution on demand. 
5 | 6 | The goal of this tool is to facilitate the process of performance tuning with minimal code changes and insignificant runtime overhead. Both Higher-level ([tf.estimator.Estimator](https://www.tensorflow.org/guide/estimators)) and Low-level ([tf.train.MonitoredTrainingSession](https://www.tensorflow.org/api_docs/python/tf/train/MonitoredTrainingSession) and co) APIs are supported. It also supports [horovod](https://github.com/uber/horovod) and [IBM Distributed Deep Learning (DDL)](https://dataplatform.cloud.ibm.com/docs/content/analyze-data/ml_dlaas_ibm_ddl.html). 7 | The tracing session can be saved, reloaded, and distributed effortlessly. 8 | 9 | Some screenshots are [here](https://github.com/xldrx/tensorflow-tracer/blob/master/gallery). 10 | 11 | ## Installation 12 | Use `pip` to install: 13 | ```bash 14 | pip install tensorflow-tracer 15 | ``` 16 | 17 | ## Quick Start 18 | 1. Install `tensorflow-tracer` and run an example: 19 | ```bash 20 | $ pip3 install tensorflow-tracer 21 | $ git clone https://github.com/xldrx/tensorflow-tracer.git 22 | $ python3 ./tensorflow-tracer/examples/estimator-example.py 23 | ``` 24 | 2. Browse to: `http://0.0.0.0:9999` 25 | 26 | ## How to Use 27 | 1. Add `tftracer` to your code: 28 | 29 | Estimator API: 30 | ```python 31 | from tftracer import TracingServer 32 | ... 33 | 34 | tracing_server = TracingServer() 35 | estimator.train(input_fn, hooks=[tracing_server.hook]) 36 | ``` 37 | 38 | Low-Level API: 39 | ```python 40 | from tftracer import TracingServer 41 | ... 42 | tracing_server = TracingServer() 43 | with tf.train.MonitoredTrainingSession(hooks=[tracing_server.hook]): 44 | ... 45 | ``` 46 | 47 | [[More examples here]](https://github.com/xldrx/tensorflow-tracer/blob/master/examples/) 48 | 49 | 2. Run your code and browse to: 50 | ```html 51 | http://0.0.0.0:9999 52 | ``` 53 | 54 | ## How to Trace an Existing Code 55 | 56 | If you want to trace an existing script without any modification, use `tftracer.hook_inject`. 57 | Please note that this is experimental and may cause unexpected errors: 58 | 59 | 1. Add the following to the beginning of the main script: 60 | ```python 61 | import tftracer 62 | tftracer.hook_inject() 63 | ... 64 | ``` 65 | 2. Run your code and browse to `http://0.0.0.0:9999` 66 | 67 | 68 | ## Command line 69 | Tracing sessions can be stored either through the web interface or by calling `tracing_server.save_session(filename)`. 70 | 71 | To reload a session, run this in the terminal: 72 | ```bash 73 | tftracer filename 74 | ``` 75 | 76 | Then browse to: 77 | ```html 78 | http://0.0.0.0:9999 79 | ``` 80 | 81 | ## API 82 | Full documentation is [here](https://tensorflow-tracer.readthedocs.io/en/latest/). 83 | 84 | ## Known Bugs/Limitations 85 | * Only Python 3 is supported. 86 | * The web interface loads javascript/css libraries remotely (e.g. `vue.js`, `ui-kit`, `jquery`, `jquery-ui`, `Google Roboto`, `awesome-icons`, ... ). Therefore an active internet connection is needed to properly render the interface. The tracing server does not require any remote connection. 87 | * All traces are kept in memory while the tracing server is running. 88 | * Tracing uses `tf.train.SessionRunHook` and is unable to trace auxiliary runs such as `init_op`. 89 | * The tracing capability is limited to what `tf.RunMetadata` offers. For example, CUPTI events are missing when tracing a distributed job. 90 | * HTTPS is not supported. 91 | 92 | ## Frequently Asked Questions 93 | 94 | ### How to trace/visualize just one session run?
95 | Use `tftracer.Timeline`. For example: 96 | ```python 97 | from tftracer import Timeline 98 | ... 99 | with tf.train.MonitoredTrainingSession() as sess: 100 | with Timeline() as tl: 101 | sess.run(fetches, **tl.kwargs) 102 | ... 103 | tl.visualize(filename) 104 | ``` 105 | 106 | ### Comparison to TensorBoard? 107 | The nature of this project is a short-lived, light-weight, interactive tracing interface to monitor and trace execution on the `op` level. In comparison, `TensorBoard` is a full-featured tool to inspect the application on many levels: 108 | * `tftracer` does not make any assumption about the dataflow DAG. There is no need to add any additional `op` to the dataflow DAG (i.e. `tf.summary`) or to have a `global step`. 109 | 110 | * `tftracer` runs as a thread and lives from the start of the execution until its end. `TensorBoard` runs as a separate process and can outlive the main script. 111 | 112 | ## Cite this tool 113 | ```latex 114 | @misc{hashemi-tftracer-2018, 115 | author = {Sayed Hadi Hashemi}, 116 | title = {TensorFlow Runtime Tracer}, 117 | year = {2018}, 118 | publisher = {GitHub}, 119 | journal = {GitHub repository}, 120 | howpublished = {\url{https://github.com/xldrx/tensorflow-tracer}}, 121 | } 122 | ``` -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = nbuild 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | import os 16 | import sys 17 | 18 | sys.path.insert(0, os.path.abspath('..')) 19 | 20 | from tftracer.version import __version__ 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = 'TensorFlow Runtime Tracer' 24 | copyright = '2018, Sayed Hadi Hashemi' 25 | author = 'Sayed Hadi Hashemi' 26 | 27 | # The short X.Y version 28 | version = __version__ 29 | # The full version, including alpha/beta/rc tags 30 | release = __version__ 31 | 32 | 33 | # -- General configuration --------------------------------------------------- 34 | 35 | # If your documentation needs a minimal Sphinx version, state it here.
36 | # 37 | # needs_sphinx = '1.0' 38 | 39 | # Add any Sphinx extension module names here, as strings. They can be 40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 41 | # ones. 42 | extensions = [ 43 | 'sphinx.ext.autodoc', 44 | # 'sphinx.ext.viewcode', 45 | # 'sphinx.ext.githubpages', 46 | 'sphinx.ext.napoleon' 47 | ] 48 | 49 | # Add any paths that contain templates here, relative to this directory. 50 | templates_path = ['ntemplates'] 51 | 52 | # The suffix(es) of source filenames. 53 | # You can specify multiple suffix as a list of string: 54 | # 55 | # source_suffix = ['.rst', '.md'] 56 | source_suffix = '.rst' 57 | 58 | # The master toctree document. 59 | master_doc = 'index' 60 | 61 | # The language for content autogenerated by Sphinx. Refer to documentation 62 | # for a list of supported languages. 63 | # 64 | # This is also used if you do content translation via gettext catalogs. 65 | # Usually you set "language" from the command line for these cases. 66 | language = None 67 | 68 | # List of patterns, relative to source directory, that match files and 69 | # directories to ignore when looking for source files. 70 | # This pattern also affects html_static_path and html_extra_path. 71 | exclude_patterns = ['nbuild', 'Thumbs.db', '.DS_Store'] 72 | 73 | # The name of the Pygments (syntax highlighting) style to use. 74 | pygments_style = None 75 | 76 | 77 | # -- Options for HTML output ------------------------------------------------- 78 | 79 | # The theme to use for HTML and HTML Help pages. See the documentation for 80 | # a list of builtin themes. 81 | # 82 | # html_theme = 'alabaster' 83 | html_theme = "sphinx_rtd_theme" 84 | html_theme_path = ["_themes", ] 85 | 86 | # Theme options are theme-specific and customize the look and feel of a theme 87 | # further. For a list of options available for each theme, see the 88 | # documentation. 89 | # 90 | html_theme_options = { 91 | 'analytics_id': '', 92 | 'collapse_navigation': False, 93 | } 94 | 95 | # Add any paths that contain custom static files (such as style sheets) here, 96 | # relative to this directory. They are copied after the builtin static files, 97 | # so a file named "default.css" will overwrite the builtin "default.css". 98 | html_static_path = ['nstatic'] 99 | 100 | # Custom sidebar templates, must be a dictionary that maps document names 101 | # to template names. 102 | # 103 | # The default sidebars (for documents that don't match any pattern) are 104 | # defined by theme itself. Builtin themes are using these templates by 105 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 106 | # 'searchbox.html']``. 107 | # 108 | # html_sidebars = {} 109 | 110 | # html_sidebars = { 111 | # '**': [ 112 | # 'about.html', 113 | # 'navigation.html', 114 | # 'relations.html', 115 | # 'searchbox.html', 116 | # 'donate.html', 117 | # ] 118 | # } 119 | 120 | 121 | # -- Options for HTMLHelp output --------------------------------------------- 122 | 123 | # Output file base name for HTML help builder. 124 | htmlhelp_basename = 'tensorflow-tracerdoc' 125 | 126 | 127 | # -- Options for LaTeX output ------------------------------------------------ 128 | 129 | latex_elements = { 130 | # The paper size ('letterpaper' or 'a4paper'). 131 | # 132 | # 'papersize': 'letterpaper', 133 | 134 | # The font size ('10pt', '11pt' or '12pt'). 135 | # 136 | # 'pointsize': '10pt', 137 | 138 | # Additional stuff for the LaTeX preamble. 
139 | # 140 | # 'preamble': '', 141 | 142 | # Latex figure (float) alignment 143 | # 144 | # 'figure_align': 'htbp', 145 | } 146 | 147 | # Grouping the document tree into LaTeX files. List of tuples 148 | # (source start file, target name, title, 149 | # author, documentclass [howto, manual, or own class]). 150 | latex_documents = [ 151 | (master_doc, 'tensorflow-tracer.tex', 'TensorFlow Runtime Tracer Documentation', 152 | 'Sayed Hadi Hashemi', 'manual'), 153 | ] 154 | 155 | 156 | # -- Options for manual page output ------------------------------------------ 157 | 158 | # One entry per manual page. List of tuples 159 | # (source start file, name, description, authors, manual section). 160 | man_pages = [ 161 | (master_doc, 'tensorflow-tracer', 'TensorFlow Runtime Tracer Documentation', 162 | [author], 1) 163 | ] 164 | 165 | 166 | # -- Options for Texinfo output ---------------------------------------------- 167 | 168 | # Grouping the document tree into Texinfo files. List of tuples 169 | # (source start file, target name, title, author, 170 | # dir menu entry, description, category) 171 | texinfo_documents = [ 172 | (master_doc, 'tensorflow-tracer', 'TensorFlow Runtime Tracer Documentation', 173 | author, 'tensorflow-tracer', 'One line description of project.', 174 | 'Miscellaneous'), 175 | ] 176 | 177 | 178 | # -- Options for Epub output ------------------------------------------------- 179 | 180 | # Bibliographic Dublin Core info. 181 | epub_title = project 182 | 183 | # The unique identifier of the text. This can be a ISBN number 184 | # or the project homepage. 185 | # 186 | # epub_identifier = '' 187 | 188 | # A unique identification for the text. 189 | # 190 | # epub_uid = '' 191 | 192 | # A list of files that should not be packed into the epub file. 193 | epub_exclude_files = ['search.html'] 194 | 195 | 196 | # -- Extension configuration ------------------------------------------------- -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. tensorflow-tracer documentation master file, created by 2 | sphinx-quickstart on Sun Nov 25 23:02:10 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to TensorFlow Runtime Tracer documentation! 7 | ==================================================== 8 | 9 | TensorFlow Runtime Tracer is a web application to monitor and trace TensorFlow scripts in the runtime on the ``op`` level. 10 | 11 | It starts a web server upon the execution of the script. The web interface keeps track of all the session runs and can trace the execution on demand. 12 | 13 | The goal of this tool is to facilitate the process of performance tuning with minimal code changes and insignificant runtime overhead. Both Higher-level (\ `tf.estimator.Estimator `_\ ) and Low-level (\ `tf.train.MonitoredTrainingSession `_ and co) APIs are supported. It also supports `horovod `_ and `IBM Distributed Deep Learning (DDL) `_. 14 | The tracing session can be saved, reloaded, and distributed effortlessly. 15 | 16 | Installation 17 | ============ 18 | 19 | Use ``pip`` to install: 20 | 21 | .. code-block:: bash 22 | 23 | pip install tensorflow-tracer 24 | 25 | Quick Start 26 | =========== 27 | 28 | #. 29 | Install `tensorflow-tracer` and run an example: 30 | 31 | .. 
code-block:: bash 32 | 33 | $ pip3 install tensorflow-tracer 34 | $ git clone https://github.com/xldrx/tensorflow-tracer.git 35 | $ python3 ./tensorflow-tracer/examples/estimator-example.py 36 | 37 | #. 38 | Browse to: ``http://0.0.0.0:9999`` 39 | 40 | How to Use 41 | ========== 42 | 43 | #. 44 | Add ``tftracer`` to your code: 45 | 46 | Estimator API: 47 | 48 | .. code-block:: python 49 | 50 | from tftracer import TracingServer 51 | ... 52 | 53 | tracing_server = TracingServer() 54 | estimator.train(input_fn, hooks=[tracing_server.hook]) 55 | 56 | Low-Level API: 57 | 58 | .. code-block:: python 59 | 60 | from tftracer import TracingServer 61 | ... 62 | tracing_server = TracingServer() 63 | with tf.train.MonitoredTrainingSession(hooks=[tracing_server.hook]): 64 | ... 65 | 66 | #. 67 | Run your code and browse to: 68 | 69 | .. code-block:: html 70 | 71 | http://0.0.0.0:9999 72 | 73 | How to Trace an Existing Code 74 | ============================= 75 | 76 | If you want to trace an existing script without any modification use :func:`tftracer.hook_inject`. Please note that 77 | this is experimental and may cause unexpected errors: 78 | 79 | #. 80 | Add the following to the beggining of the main script: 81 | 82 | .. code-block:: python 83 | 84 | import tftracer 85 | tftracer.hook_inject() 86 | ... 87 | 88 | 89 | #. 90 | Run your code and browse to: 91 | 92 | .. code-block:: html 93 | 94 | http://0.0.0.0:9999 95 | 96 | 97 | Command line 98 | ============ 99 | 100 | Tracing sessions can be stored either through the web interface or by calling :func:`tftracer.TracingServer.save_session`. 101 | 102 | To reload a session, run this in the terminal: 103 | 104 | .. code-block:: bash 105 | 106 | tftracer filename 107 | 108 | Then browse to: 109 | 110 | .. code-block:: html 111 | 112 | http://0.0.0.0:9999 113 | 114 | Full Usage 115 | ---------- 116 | .. code-block:: bash 117 | 118 | usage: tftracer [-h] [--port PORT] [--ip IP] session_file 119 | 120 | positional arguments: 121 | session_file Path to the trace session file 122 | 123 | optional arguments: 124 | -h, --help show this help message and exit 125 | --port PORT To what TCP port web server to listen 126 | --ip IP To what IP address web server to listen 127 | 128 | Examples 129 | ======== 130 | .. glossary:: 131 | 132 | Higher-Level API <`estimator-example.py `__> 133 | Example of using :class:`tftracer.TracingServer` with TensorFlow ``estimator`` API. 134 | 135 | Low-Level API <`monitoredtrainingsession-example.py `__> 136 | Example of using :class:`tftracer.TracingServer` with TensorFlow ``MonitoredTrainingSession`` API. 137 | 138 | Monkey Patching <`monkey_patching-example.py `__> 139 | Example of using :func:`tftracer.hook_inject` to trace a script without any modifications. 140 | 141 | Horovod: One Process <`horovod-example.py `__> 142 | Example of using :class:`tftracer.TracingServer` with ``horovod``. In this example only the one process is being traced. 143 | 144 | Horovod: All Processes <`horovod-all-example.py `__> 145 | Example of using :class:`tftracer.TracingServer` with ``horovod``. In this example all processes are being traced. 146 | 147 | Timeline <`timeline-example.py `__> 148 | Example of using :class:`tftracer.Timeline` to trace and visualize one ``session.run`` call without a tracing server. 149 | 150 | Load Session <`load_session-example.py `__> 151 | Example of saving and loading tracing sessions. 152 | 153 | TracingServer Options <`options-example.py `__> 154 | Example of setting tracing options. 
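For a quick sense of what these options look like in code, the sketch below is distilled from the bundled ``options-example.py``; the keyword arguments and values are the ones used there, while ``estimator`` and ``input_fn`` stand in for objects defined elsewhere in that example.

.. code-block:: python

    from tftracer import TracingServer

    # Values taken from options-example.py; adjust them to your environment.
    tracing_server = TracingServer(
        server_ip="0.0.0.0",              # IP address the web interface listens on
        server_port=9999,                 # TCP port of the web interface
        keep_traces=5,                    # number of traces the server keeps (assumed meaning)
        start_web_server_on_start=True,   # start the web server immediately (assumed meaning)
    )

    estimator.train(input_fn, hooks=[tracing_server.hook])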
155 | 156 | 157 | API Reference 158 | ============= 159 | 160 | tftracer.TracingServer 161 | ---------------------- 162 | .. autoclass:: tftracer.TracingServer 163 | :members: 164 | :inherited-members: 165 | :undoc-members: 166 | 167 | tftracer.Timeline 168 | ----------------- 169 | .. autoclass:: tftracer.Timeline 170 | :members: 171 | :inherited-members: 172 | :undoc-members: 173 | :exclude-members: communication_elapsed_time, communication_time, computation_time 174 | 175 | tftracer.hook_inject 176 | -------------------- 177 | .. automodule:: tftracer 178 | :members: hook_inject 179 | 180 | 181 | Known Bugs/Limitations 182 | ====================== 183 | 184 | * Only Python 3 is supported. 185 | * The web interface loads javascript/css libraries remotely (e.g. ``vue.js``\ , ``ui-kit``\ , ``jquery``\ , ``jquery-ui``\ , ``Google Roboto``\ , ``awesome-icons``\ , ... ). Therefore an active internet connection is needed to properly render the interface. The tracing server does not require any remote connection. 186 | * All traces are kept in memory while the tracing server is running. 187 | * Tracing uses ``tf.train.SessionRunHook`` and is unable to trace auxiliary runs such as ``init_op``. 188 | * The tracing capability is limited to what ``tf.RunMetadata`` offers. For example, CUPTI events are missing when tracing a distributed job. 189 | * HTTPS is not supported. 190 | 191 | 192 | Frequently Asked Questions 193 | ========================== 194 | 195 | How to trace/visualize just one session run? 196 | -------------------------------------------- 197 | 198 | Use ``tftracer.Timeline``. For example: 199 | 200 | .. code-block:: python 201 | 202 | from tftracer import Timeline 203 | ... 204 | with tf.train.MonitoredTrainingSession() as sess: 205 | with Timeline() as tl: 206 | sess.run(fetches, **tl.kwargs) 207 | ... 208 | tl.visualize(filename) 209 | 210 | Comparison to TensorBoard? 211 | --------------------------- 212 | 213 | The nature of this project is a short-lived, light-weight, interactive tracing interface to monitor and trace execution on the ``op``\ -level. In comparison, ``TensorBoard`` is a full-featured tool to inspect the application on many levels: 214 | 215 | 216 | * 217 | ``tftracer`` does not make any assumption about the dataflow DAG. There is no need to add any additional ``op`` to the dataflow DAG (i.e. ``tf.summary``\ ) or to have a ``global step``. 218 | 219 | * 220 | ``tftracer`` runs as a thread and lives from the start of the execution until its end. ``TensorBoard`` runs as a separate process and can outlive the main script.
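Since the tracing server lives and dies with the training script, the bundled examples keep the interface reachable after training by blocking on the server, and save the session so it can be reloaded later with the ``tftracer`` command. A minimal sketch of that pattern (``tracing_server`` is the :class:`tftracer.TracingServer` whose hook was registered above):

.. code-block:: python

    # Save the tracing session; reload it later with: tftracer session.pickle
    tracing_server.save_session("session.pickle")

    # Keep the tracing server (and its web interface) running beyond training.
    tracing_server.join()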
221 | -------------------------------------------------------------------------------- /docs/nbuild/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xldrx/tensorflow-tracer/0db65fe55c2b6acd55d37112effc17aa90886bcf/docs/nbuild/.gitkeep -------------------------------------------------------------------------------- /docs/nstatic/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xldrx/tensorflow-tracer/0db65fe55c2b6acd55d37112effc17aa90886bcf/docs/nstatic/.gitkeep -------------------------------------------------------------------------------- /docs/ntemplates/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xldrx/tensorflow-tracer/0db65fe55c2b6acd55d37112effc17aa90886bcf/docs/ntemplates/.gitkeep -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | ### Higher-Level API 4 | [estimator-example.py](https://github.com/xldrx/tensorflow-tracer/blob/master/examples/estimator-example.py) 5 |
6 | Example of using `tftracer.TracingServer` with TensorFlow ``estimator`` API. 7 | 8 | ### Low-Level API 9 | [monitoredtrainingsession-example.py](https://github.com/xldrx/tensorflow-tracer/blob/master/examples/monitoredtrainingsession-example.py) 10 |
11 | Example of using `tftracer.TracingServer` with TensorFlow `MonitoredTrainingSession` API. 12 | 13 | ### Monkey Patching 14 | [monkey_patching-example.py](https://github.com/xldrx/tensorflow-tracer/blob/master/examples/monkey_patching-example.py) 15 |
16 | Example of using `tftracer.hook_inject` to trace a script without any modifications. 17 | 18 | ### Horovod: One Process 19 | [horovod-example.py](https://github.com/xldrx/tensorflow-tracer/blob/master/examples/horovod-example.py) 20 |
21 | Example of using `tftracer.TracingServer` with `horovod`. In this example, only one process (rank 0) is traced. 22 | 23 | ### Horovod: All Processes 24 | [horovod-all-example.py](https://github.com/xldrx/tensorflow-tracer/blob/master/examples/horovod-all-example.py) 25 |
26 | Example of using `tftracer.TracingServer` with `horovod`. In this example all processes are being traced. 27 | 28 | ### Timeline 29 | [timeline-example.py](https://github.com/xldrx/tensorflow-tracer/blob/master/examples/timeline-example.py) 30 |
31 | Example of using `tftracer.Timeline` to trace and visualize one ``session.run`` call without a tracing server. 32 | 33 | ### Load Session 34 | [load_session-example.py](https://github.com/xldrx/tensorflow-tracer/blob/master/examples/load_session-example.py) 35 |
36 | Example of saving and loading tracing sessions. 37 | 38 | ### TracingServer Options 39 | [options-example.py](https://github.com/xldrx/tensorflow-tracer/blob/master/examples/options-example.py) 40 |
41 | Example of setting tracing options. 42 | -------------------------------------------------------------------------------- /examples/estimator-example.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python -u 2 | # coding=utf-8 3 | 4 | # Using tracing server with TesorFlow Estimator API 5 | 6 | __author__ = 'Sayed Hadi Hashemi' 7 | import tensorflow as tf 8 | import numpy as np 9 | from tftracer import TracingServer 10 | 11 | INPUT_SIZE = (299, 299, 3) 12 | MINIBATCH_SIZE = 128 13 | NUM_CLASSES = 1000 14 | NUM_STEPS = 500 15 | 16 | 17 | def input_fn(): 18 | dataset = tf.data.Dataset.from_tensor_slices([0]).repeat(MINIBATCH_SIZE) 19 | dataset = dataset.map( 20 | lambda _: 21 | ( 22 | {"x": np.random.uniform(size=INPUT_SIZE)}, 23 | [np.random.random_integers(0, NUM_CLASSES)] 24 | ) 25 | ) 26 | dataset = dataset.repeat(NUM_STEPS).batch(MINIBATCH_SIZE) 27 | return dataset 28 | 29 | 30 | def main(): 31 | estimator = tf.estimator.DNNClassifier( 32 | hidden_units=[10] * 150, 33 | feature_columns=[tf.feature_column.numeric_column("x", shape=INPUT_SIZE)], 34 | n_classes=NUM_CLASSES, 35 | ) 36 | tracing_server = TracingServer() 37 | estimator.train(input_fn, hooks=[tracing_server.hook]) 38 | estimator.evaluate(input_fn, hooks=[tracing_server.hook]) 39 | 40 | # Save the tracing session 41 | tracing_server.save_session("session.pickle") 42 | 43 | # Keep the tracing server running beyond training. Remove otherwise. 44 | tracing_server.join() 45 | 46 | 47 | if __name__ == '__main__': 48 | tf.logging.set_verbosity(tf.logging.INFO) 49 | main() 50 | -------------------------------------------------------------------------------- /examples/horovod-all-example.py: -------------------------------------------------------------------------------- 1 | # ! 
/usr/bin/env python -u 2 | # coding=utf-8 3 | 4 | __author__ = 'Sayed Hadi Hashemi' 5 | 6 | import tensorflow as tf 7 | import horovod.tensorflow as hvd 8 | from tensorflow.contrib.slim.nets import inception 9 | from tftracer import TracingServer 10 | 11 | INPUT_SIZE = [299, 299, 3] 12 | MINIBATCH_SIZE = 4 13 | NUM_CLASSES = 1000 14 | NUM_STEPS = 200 15 | 16 | 17 | def get_model(): 18 | input_data = tf.random_uniform([MINIBATCH_SIZE] + INPUT_SIZE) 19 | labels = tf.random_uniform([MINIBATCH_SIZE, NUM_CLASSES]) 20 | logit, _ = inception.inception_v3(input_data, num_classes=NUM_CLASSES) 21 | loss = tf.losses.softmax_cross_entropy(labels, logit) 22 | train_op = hvd.DistributedOptimizer(tf.train.MomentumOptimizer(0.01, 0.01)).minimize(loss) 23 | return train_op 24 | 25 | 26 | def get_config(): 27 | config = tf.ConfigProto() 28 | config.gpu_options.allow_growth = True 29 | config.gpu_options.visible_device_list = str(hvd.local_rank()) 30 | 31 | if hvd.rank() == 0: 32 | tf.logging.set_verbosity(tf.logging.INFO) 33 | else: 34 | tf.logging.set_verbosity(tf.logging.WARN) 35 | 36 | return dict(config=config) 37 | 38 | 39 | def main(_): 40 | hvd.init() 41 | train_op = get_model() 42 | 43 | hooks = [ 44 | hvd.BroadcastGlobalVariablesHook(0), 45 | ] 46 | 47 | # Assign a different TCP port to processes colocated on a same node 48 | server_port = 9999 + hvd.local_rank() 49 | tracing_server = TracingServer(server_port=server_port, is_horovod=True) 50 | hooks.append(tracing_server.hook) 51 | 52 | with tf.train.MonitoredTrainingSession(hooks=hooks, **get_config()) as sess: 53 | for _ in range(NUM_STEPS): 54 | sess.run(train_op) 55 | 56 | # Save the tracing session 57 | tracing_server.save_session("session-{}.pickle".format(hvd.rank())) 58 | 59 | # Keep the tracing server running beyond training. Remove otherwise. 60 | tracing_server.join() 61 | 62 | 63 | if __name__ == "__main__": 64 | tf.app.run() 65 | -------------------------------------------------------------------------------- /examples/horovod-example.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python -u 2 | # coding=utf-8 3 | 4 | __author__ = 'Sayed Hadi Hashemi' 5 | 6 | import tensorflow as tf 7 | import horovod.tensorflow as hvd 8 | from tensorflow.contrib.slim.nets import inception 9 | from tftracer import TracingServer 10 | 11 | INPUT_SIZE = [299, 299, 3] 12 | MINIBATCH_SIZE = 4 13 | NUM_CLASSES = 1000 14 | NUM_STEPS = 200 15 | 16 | 17 | def get_model(): 18 | input_data = tf.random_uniform([MINIBATCH_SIZE] + INPUT_SIZE) 19 | labels = tf.random_uniform([MINIBATCH_SIZE, NUM_CLASSES]) 20 | logit, _ = inception.inception_v3(input_data, num_classes=NUM_CLASSES) 21 | loss = tf.losses.softmax_cross_entropy(labels, logit) 22 | train_op = hvd.DistributedOptimizer(tf.train.MomentumOptimizer(0.01, 0.01)).minimize(loss) 23 | return train_op 24 | 25 | 26 | def get_config(): 27 | config = tf.ConfigProto() 28 | config.gpu_options.allow_growth = True 29 | config.gpu_options.visible_device_list = str(hvd.local_rank()) 30 | 31 | if hvd.rank() == 0: 32 | tf.logging.set_verbosity(tf.logging.INFO) 33 | else: 34 | tf.logging.set_verbosity(tf.logging.WARN) 35 | 36 | return dict(config=config) 37 | 38 | 39 | def main(_): 40 | print(_) 41 | hvd.init() 42 | train_op = get_model() 43 | 44 | hooks = [ 45 | hvd.BroadcastGlobalVariablesHook(0), 46 | ] 47 | 48 | if hvd.rank() == 0: 49 | tracing_server = TracingServer() 50 | hooks.append(tracing_server.hook) 51 | 52 | with tf.train.MonitoredTrainingSession(hooks=hooks, **get_config()) as sess: 53 | for _ in range(NUM_STEPS): 54 | sess.run(train_op) 55 | 56 | if hvd.rank() == 0: 57 | # Save the tracing session 58 | tracing_server.save_session("session.pickle") 59 | 60 | # Keep the tracing server running beyond training. Remove otherwise. 61 | tracing_server.join() 62 | 63 | 64 | if __name__ == "__main__": 65 | tf.app.run() 66 | -------------------------------------------------------------------------------- /examples/load_session-example.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python -u 2 | # coding=utf-8 3 | import time 4 | 5 | from tftracer import TracingServer 6 | 7 | __author__ = 'Sayed Hadi Hashemi' 8 | if __name__ == '__main__': 9 | server = TracingServer() 10 | server.load_session("session.pickle") 11 | 12 | # TODO(xldrx): Slow Server Workaround 13 | time.sleep(5) 14 | 15 | server.join() 16 | -------------------------------------------------------------------------------- /examples/monitoredtrainingsession-example.py: -------------------------------------------------------------------------------- 1 | # #! 
/usr/bin/env python -u 2 | # # coding=utf-8 3 | 4 | # Using tracing server with TesorFlow low level API 5 | 6 | __author__ = 'Sayed Hadi Hashemi' 7 | import tensorflow as tf 8 | from tensorflow.contrib.slim.nets import inception 9 | from tftracer import TracingServer 10 | 11 | INPUT_SIZE = [299, 299, 3] 12 | MINIBATCH_SIZE = 4 13 | NUM_CLASSES = 1000 14 | NUM_STEPS = 200 15 | 16 | 17 | def get_model(): 18 | input_data = tf.random_uniform([MINIBATCH_SIZE] + INPUT_SIZE) 19 | labels = tf.random_uniform([MINIBATCH_SIZE, NUM_CLASSES]) 20 | logit, _ = inception.inception_v3(input_data, num_classes=NUM_CLASSES) 21 | loss = tf.losses.softmax_cross_entropy(labels, logit) 22 | train_op = tf.train.MomentumOptimizer(0.01, 0.01).minimize(loss) 23 | return train_op 24 | 25 | 26 | def main(): 27 | train_op = get_model() 28 | tracing_server = TracingServer() 29 | with tf.train.MonitoredTrainingSession(hooks=[tracing_server.hook]) as sess: 30 | for _ in range(NUM_STEPS): 31 | sess.run(train_op) 32 | 33 | # Save the tracing session 34 | tracing_server.save_session("session.pickle") 35 | 36 | # Keep the tracing server running beyond training. Remove otherwise. 37 | tracing_server.join() 38 | 39 | 40 | if __name__ == '__main__': 41 | tf.logging.set_verbosity(tf.logging.INFO) 42 | main() 43 | -------------------------------------------------------------------------------- /examples/monkey_patching-example.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python -u 2 | # coding=utf-8 3 | 4 | # Using tracing server with TesorFlow Estimator API 5 | 6 | __author__ = 'Sayed Hadi Hashemi' 7 | 8 | import tensorflow as tf 9 | 10 | import tftracer 11 | tftracer.hook_inject() 12 | 13 | import numpy as np 14 | 15 | INPUT_SIZE = (299, 299, 3) 16 | MINIBATCH_SIZE = 128 17 | NUM_CLASSES = 1000 18 | NUM_STEPS = 500 19 | 20 | 21 | def input_fn(): 22 | dataset = tf.data.Dataset.from_tensor_slices([0]).repeat(MINIBATCH_SIZE) 23 | dataset = dataset.map( 24 | lambda _: 25 | ( 26 | {"x": np.random.uniform(size=INPUT_SIZE)}, 27 | [np.random.random_integers(0, NUM_CLASSES)] 28 | ) 29 | ) 30 | dataset = dataset.repeat(NUM_STEPS).batch(MINIBATCH_SIZE) 31 | return dataset 32 | 33 | 34 | def main(): 35 | estimator = tf.estimator.DNNClassifier( 36 | hidden_units=[10] * 150, 37 | feature_columns=[tf.feature_column.numeric_column("x", shape=INPUT_SIZE)], 38 | n_classes=NUM_CLASSES, 39 | ) 40 | estimator.train(input_fn) 41 | estimator.evaluate(input_fn) 42 | 43 | 44 | if __name__ == '__main__': 45 | tf.logging.set_verbosity(tf.logging.INFO) 46 | main() 47 | -------------------------------------------------------------------------------- /examples/options-example.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python -u 2 | # coding=utf-8 3 | 4 | # Using tracing server with TesorFlow Estimator API 5 | 6 | __author__ = 'Sayed Hadi Hashemi' 7 | import tensorflow as tf 8 | import numpy as np 9 | from tftracer import TracingServer 10 | 11 | # Options 12 | SERVER_IP = "0.0.0.0" 13 | SERVER_PORT = 9999 14 | KEEP_TRACES = 5 15 | START_WEB_SERVER_ON_START = True 16 | 17 | INPUT_SIZE = (299, 299, 3) 18 | MINIBATCH_SIZE = 128 19 | NUM_CLASSES = 1000 20 | NUM_STEPS = 500 21 | 22 | def input_fn(): 23 | dataset = tf.data.Dataset.from_tensor_slices([0]).repeat(MINIBATCH_SIZE) 24 | dataset = dataset.map( 25 | lambda _: 26 | ( 27 | {"x": np.random.uniform(size=INPUT_SIZE)}, 28 | [np.random.random_integers(0, NUM_CLASSES)] 29 | ) 30 | ) 31 | dataset = dataset.repeat(NUM_STEPS).batch(MINIBATCH_SIZE) 32 | return dataset 33 | 34 | 35 | def main(): 36 | estimator = tf.estimator.DNNClassifier( 37 | hidden_units=[10] * 150, 38 | feature_columns=[tf.feature_column.numeric_column("x", shape=INPUT_SIZE)], 39 | n_classes=NUM_CLASSES, 40 | ) 41 | tracing_server = TracingServer( 42 | server_ip=SERVER_IP, 43 | server_port=SERVER_PORT, 44 | keep_traces=KEEP_TRACES, 45 | start_web_server_on_start = START_WEB_SERVER_ON_START 46 | ) 47 | estimator.train(input_fn, hooks=[tracing_server.hook]) 48 | estimator.evaluate(input_fn, hooks=[tracing_server.hook]) 49 | 50 | # Save the tracing session 51 | tracing_server.save_session("session.pickle") 52 | 53 | # Keep the tracing server running beyond training. Remove otherwise. 54 | tracing_server.join() 55 | 56 | 57 | if __name__ == '__main__': 58 | tf.logging.set_verbosity(tf.logging.INFO) 59 | main() 60 | -------------------------------------------------------------------------------- /examples/timeline-example.py: -------------------------------------------------------------------------------- 1 | # #! 
/usr/bin/env python -u 2 | # # coding=utf-8 3 | 4 | # Using tracing server with TesorFlow low level API 5 | import webbrowser 6 | 7 | __author__ = 'Sayed Hadi Hashemi' 8 | import tensorflow as tf 9 | from tensorflow.contrib.slim.nets import inception 10 | from tftracer import Timeline 11 | 12 | INPUT_SIZE = [299, 299, 3] 13 | MINIBATCH_SIZE = 4 14 | NUM_CLASSES = 1000 15 | NUM_STEPS = 200 16 | 17 | 18 | def get_model(): 19 | input_data = tf.random_uniform([MINIBATCH_SIZE] + INPUT_SIZE) 20 | labels = tf.random_uniform([MINIBATCH_SIZE, NUM_CLASSES]) 21 | logit, _ = inception.inception_v3(input_data, num_classes=NUM_CLASSES) 22 | loss = tf.losses.softmax_cross_entropy(labels, logit) 23 | train_op = tf.train.MomentumOptimizer(0.01, 0.01).minimize(loss) 24 | return train_op 25 | 26 | 27 | def main(): 28 | train_op = get_model() 29 | with tf.train.MonitoredTrainingSession() as sess: 30 | with Timeline() as timeline: 31 | sess.run(train_op, **timeline.kwargs) 32 | 33 | # Save 34 | timeline.to_pickle("step.pickle") 35 | 36 | # Load 37 | timeline = Timeline.from_pickle("step.pickle") 38 | 39 | # Visualize 40 | timeline.visualize("step.html") 41 | webbrowser.open_new("step.html") 42 | 43 | 44 | if __name__ == '__main__': 45 | tf.logging.set_verbosity(tf.logging.INFO) 46 | main() 47 | -------------------------------------------------------------------------------- /gallery/README.md: -------------------------------------------------------------------------------- 1 | # Screenshots 2 | 3 | ### Main Interface 4 | ![Main Interface](https://github.com/xldrx/tensorflow-tracer/blob/master/gallery/main.png?raw=true "Main Interface") 5 | 6 | ### Global Tracing 7 | ![Global Tracing](https://github.com/xldrx/tensorflow-tracer/blob/master/gallery/global-tracing.png?raw=true "Global Tracing") 8 | 9 | ### Tracing 10 | ![Tracing](https://github.com/xldrx/tensorflow-tracer/blob/master/gallery/tracing.png?raw=true "Tracing") 11 | 12 | ### Timeline 13 | InceptionV3+Horvod 14 | ![Timeline](https://github.com/xldrx/tensorflow-tracer/blob/master/gallery/horovod-timeline.png?raw=true "Timeline") 15 | 16 | ### Timeline Details 17 | ![Timeline Details](https://github.com/xldrx/tensorflow-tracer/blob/master/gallery/horovod-timeline-details.png?raw=true "Timeline Details") 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /gallery/global-tracing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xldrx/tensorflow-tracer/0db65fe55c2b6acd55d37112effc17aa90886bcf/gallery/global-tracing.png -------------------------------------------------------------------------------- /gallery/horovod-timeline-details.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xldrx/tensorflow-tracer/0db65fe55c2b6acd55d37112effc17aa90886bcf/gallery/horovod-timeline-details.png -------------------------------------------------------------------------------- /gallery/horovod-timeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xldrx/tensorflow-tracer/0db65fe55c2b6acd55d37112effc17aa90886bcf/gallery/horovod-timeline.png -------------------------------------------------------------------------------- /gallery/main.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xldrx/tensorflow-tracer/0db65fe55c2b6acd55d37112effc17aa90886bcf/gallery/main.png -------------------------------------------------------------------------------- /gallery/tracing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xldrx/tensorflow-tracer/0db65fe55c2b6acd55d37112effc17aa90886bcf/gallery/tracing.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bokeh>=1.0 2 | flask 3 | jinja2 4 | tensorflow>=1.8 5 | six 6 | gevent 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | import os 3 | from tftracer.version import __version__ 4 | with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "README.md"), "r") as fp: 5 | long_description = fp.read() 6 | 7 | with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt"), "r") as fp: 8 | requirements = fp.read().split("\n") 9 | 10 | 11 | setup( 12 | name='tensorflow-tracer', 13 | version=__version__, 14 | packages=['tftracer'], 15 | url='https://github.com/xldrx/tensorflow-tracer', 16 | license='Apache-2.0', 17 | author='Sayed Hadi Hashemi', 18 | author_email='SayedHadiHashemi@gmail.com', 19 | description='Runtime Tracing Library for TensorFlow', 20 | long_description=long_description, 21 | long_description_content_type="text/markdown", 22 | classifiers=[ 23 | "Programming Language :: Python :: 3 :: Only", 24 | "Development Status :: 4 - Beta", 25 | "License :: OSI Approved :: Apache Software License", 26 | "Operating System :: OS Independent", 27 | ], 28 | entry_points={ 29 | 'console_scripts': [ 30 | 'tftracer=tftracer.__main__:main', 31 | ], 32 | }, 33 | install_requires=requirements, 34 | package_data={'tftracer': ['resources/*/*']}, 35 | ) 36 | -------------------------------------------------------------------------------- /tftracer/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python -u 2 | # coding=utf-8 3 | 4 | __author__ = 'Sayed Hadi Hashemi' 5 | 6 | from .timeline import Timeline 7 | from .tracing_server import TracingServer 8 | from .monkey_patching import hook_inject 9 | from .version import __version__ -------------------------------------------------------------------------------- /tftracer/__main__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python -u 2 | # coding=utf-8 3 | __author__ = 'Sayed Hadi Hashemi' 4 | 5 | import argparse 6 | import errno 7 | import os 8 | import time 9 | import traceback 10 | from . 
import TracingServer 11 | 12 | FLAGS = None 13 | 14 | 15 | def arg_parser(): 16 | global FLAGS 17 | parser = argparse.ArgumentParser("tftracer") 18 | parser.add_argument( 19 | "--port", 20 | type=int, 21 | help="To what TCP port web server to listen", 22 | default="9999" 23 | ) 24 | parser.add_argument( 25 | "--ip", 26 | type=str, 27 | help="To what IP address web server to listen", 28 | default="0.0.0.0" 29 | ) 30 | parser.add_argument( 31 | "session_file", 32 | type=str, 33 | help="Path to the trace session file" 34 | ) 35 | FLAGS, _ = parser.parse_known_args() 36 | 37 | def main(): 38 | 39 | arg_parser() 40 | 41 | filename = FLAGS.session_file 42 | if not os.path.exists(filename): 43 | print("File not found: {}".format(filename)) 44 | exit(errno.ENOENT) 45 | else: 46 | server = TracingServer(server_port=FLAGS.port, server_ip=FLAGS.ip) 47 | try: 48 | server.load_session(filename) 49 | # todo(xldrx): workaround 50 | time.sleep(5) 51 | server.join() 52 | except Exception as ex: 53 | traceback.print_exc() 54 | print(ex) 55 | server.stop_web_server() 56 | exit() 57 | 58 | 59 | if __name__ == '__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /tftracer/monkey_patching.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python -u 2 | # coding=utf-8 3 | 4 | __author__ = 'Sayed Hadi Hashemi' 5 | 6 | 7 | def __add_tracing_server_hook(hooks): 8 | if hooks is None: 9 | return [hook_inject.__tracing_server.hook] 10 | else: 11 | hooks = list(hooks) 12 | hooks.append(hook_inject.__tracing_server.hook) 13 | return hooks 14 | 15 | 16 | def __new_init(*args, **kwargs): 17 | if hook_inject.__original_init is None: 18 | return 19 | 20 | if "hooks" in hook_inject.__original_init.__code__.co_varnames: 21 | hooks_index = hook_inject.__original_init.__code__.co_varnames.index("hooks") 22 | if len(args) > hooks_index: 23 | args = list(args) 24 | args[hooks_index] = __add_tracing_server_hook(args[hooks_index]) 25 | else: 26 | kwargs["hooks"] = __add_tracing_server_hook(kwargs.get("hooks", None)) 27 | else: 28 | print("'hooks' not in '_MonitoredSession'") 29 | 30 | hook_inject.__original_init(*args, **kwargs) 31 | 32 | 33 | def hook_inject(*args, **kwargs): 34 | """ 35 | (Experimental) Injects a tracing server hook to all instances of ``MonitoredSession`` by by monkey patching 36 | the initializer. This function is an alternative to adding `hooks` to estimator or sessions. 37 | Be aware, monkey patching could cause unexpected errors and is not recommended. 38 | 39 | This function should be called once in the main script preferably before importing anything else. 40 | 41 | Example: 42 | 43 | .. code-block:: python 44 | 45 | import tftracer 46 | tftracer.hook_inject() 47 | ... 48 | 49 | estimator.train(input_fn) 50 | 51 | 52 | Args: 53 | **kwargs: same as :class:`tftracer.TracingServer`. 54 | 55 | Note: 56 | Monkey Patching (as :class:`tftracer.TracingServer`) works only with subclasses of ``MonitoredSession``. 57 | For other ``Session`` types, use :class:`tftracer.Timeline`. 58 | 59 | 60 | 61 | """ 62 | from . 
import TracingServer 63 | from tensorflow.python.training.monitored_session import _MonitoredSession 64 | 65 | if hook_inject.__original_init is None: 66 | hook_inject.__original_init = _MonitoredSession.__init__ 67 | hook_inject.__tracing_server = TracingServer(*args, **kwargs) 68 | _MonitoredSession.__init__ = __new_init 69 | 70 | 71 | hook_inject.__tracing_server = None 72 | hook_inject.__original_init = None 73 | -------------------------------------------------------------------------------- /tftracer/resources/templates/on_change_callback.js: -------------------------------------------------------------------------------- 1 | console.log(button); 2 | button.disabled = false; 3 | button.css_classes = ['center-block', 'xl-sync']; -------------------------------------------------------------------------------- /tftracer/resources/templates/on_click_callback.js: -------------------------------------------------------------------------------- 1 | content = ""; 2 | let elems = $(".bk-tooltip"); 3 | let len = elems.length - 1; 4 | elems.each(function (i) { 5 | content += $(this).html(); 6 | console.log(content); 7 | if (len === i) { 8 | $("#xl-toolbox").html(content); 9 | } 10 | }); 11 | UIkit.modal('#modal-details').show(); -------------------------------------------------------------------------------- /tftracer/resources/templates/on_hover_callback.js: -------------------------------------------------------------------------------- 1 | let tooltips = document.getElementsByClassName("bk-tooltip"); 2 | 3 | for (var i = 0, len = tooltips.length; i < len; i ++) { 4 | let el = tooltips[i]; 5 | el.style.zIndex = "1002"; 6 | let viewportOffset = el.getBoundingClientRect(); 7 | let left = viewportOffset.left; 8 | let right = viewportOffset.right; 9 | let width = (window.innerWidth || document.documentElement.clientWidth); 10 | width -= 20; 11 | let current_left = parseFloat(tooltips[i].style.left); 12 | if (left < 0) { 13 | tooltips[i].style.left = (current_left - left) + "px"; 14 | } else if (right >= width) { 15 | tooltips[i].style.left = (current_left - right + width ) + "px"; 16 | } 17 | } -------------------------------------------------------------------------------- /tftracer/resources/templates/timeline.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 10 | 11 | 12 | 13 | 14 | 16 | 17 | 20 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | {{ title }} 33 | 34 | {{ js_resources|indent(4)|safe }} 35 | 36 | {{ css_resources|indent(4)|safe }} 37 | 38 | {{ plot_script|indent(4)|safe }} 39 | 40 | 105 | 106 | {{ custom_header|indent(4)|safe }} 107 | 108 | 109 | 110 | 111 |
[timeline.html body markup not recoverable; the surviving fragment embeds {{ plot_div|indent(4)|safe }} inside the page layout]
-------------------------------------------------------------------------------- /tftracer/resources/templates/tooltips.html: --------------------------------------------------------------------------------
[tooltips.html markup not recoverable; surviving Bokeh hover fields: Name @name, op @op, Inputs @inputs{safe}, Description @description, Elapsed @duration ms, Start @start ms, End @end ms]
32 | -------------------------------------------------------------------------------- /tftracer/resources/templates/update_ranges.js: -------------------------------------------------------------------------------- 1 | console.log(me); 2 | let start = me.x_range.start; 3 | let end = me.x_range.end; 4 | console.log(start, end); 5 | 6 | for (var model_id in me.document._all_models) { 7 | model = me.document._all_models[model_id]; 8 | if (model.type === "Plot") { 9 | console.log(model); 10 | console.log(model.x_range); 11 | //model.x_range.set('start', start); 12 | //model.x_range.set('end', end); 13 | model.x_range.start = start; 14 | model.x_range.end = end; 15 | } 16 | if (model.type === "Button") { 17 | model.disabled = true; 18 | if (model.label === " Sync") { 19 | model.css_classes = ['xl-hidden']; 20 | } 21 | } 22 | 23 | } -------------------------------------------------------------------------------- /tftracer/resources/web/app.js: -------------------------------------------------------------------------------- 1 | Vue.component('run-card', { 2 | props: ['run'], 3 | template: "#run-card-template", 4 | methods: { 5 | trace: function () { 6 | var card = this; 7 | card["tracing"] = true; 8 | fetch(card.run.trace_url) 9 | .then(function (response) { 10 | app.connection_error = false; 11 | app.update_data(); 12 | }) 13 | .catch(function () { 14 | app.connection_error = true; 15 | card["tracing"] = false; 16 | }); 17 | }, 18 | trace_id: function (trace_id) { 19 | 20 | }, 21 | hide_trace_spinner: function (name) { 22 | console.log(name) 23 | var x = document.getElementsByClassName(name); 24 | for (var i = 0; i < x.length; i++) { 25 | x[i].classList.add("uk-hidden"); 26 | } 27 | } 28 | }, 29 | }) 30 | 31 | var app = new Vue({ 32 | el: '#app', 33 | data: { 34 | running: true, 35 | updating: false, 36 | global_tracing: false, 37 | runs: [], 38 | connection_error: false 39 | }, 40 | methods: { 41 | update_data: function () { 42 | this.updating = true; 43 | fetch("/update") 44 | .then(function (response) { 45 | return response.json(); 46 | }) 47 | .then(function (data) { 48 | app.connection_error = false; 49 | app.running = data.running; 50 | app.global_tracing = data.global_tracing; 51 | app.runs = data.runs; 52 | setTimeout(function () { 53 | app.updating = false; 54 | }, 1000); 55 | }) 56 | .catch(function () { 57 | app.connection_error = true; 58 | app.updating = false; 59 | }) 60 | }, 61 | cancelAutoUpdate: function () { 62 | clearInterval(this.timer) 63 | }, 64 | enable_global_tracing: function () { 65 | UIkit.modal.confirm('Global tracing imposes a significant runtime overhead. 
Continue?', 66 | {labels: {ok: "Enable Global Tracing", cancel: "Cancel"}}).then( 67 | function () { 68 | app.global_tracing = true; 69 | fetch("/enable_global_tracing") 70 | .then(function (response) { 71 | app.connection_error = false; 72 | app.update_data(); 73 | }) 74 | .catch(function () { 75 | app.connection_error = true; 76 | app.global_tracing = false; 77 | }); 78 | }); 79 | }, 80 | disable_global_tracing: function () { 81 | app.global_tracing = false; 82 | fetch("/disable_global_tracing") 83 | .then(function (response) { 84 | app.connection_error = false; 85 | app.update_data(); 86 | }) 87 | .catch(function () { 88 | app.connection_error = true; 89 | app.global_tracing = true; 90 | }); 91 | }, 92 | kill_tracing_server: function () { 93 | UIkit.modal.confirm('Are you sure?', {labels: {ok: "Kill Tracing Server", cancel: "Cancel"}}) 94 | .then(function () { 95 | fetch("/kill_tracing_server") 96 | .then(function (response) { 97 | app.connection_error = false; 98 | }) 99 | .catch(function () { 100 | app.connection_error = true 101 | }); 102 | }, function () { 103 | }); 104 | }, 105 | }, 106 | created: function () { 107 | this.update_data(); 108 | this.timer = setInterval(this.update_data, 5000) 109 | }, 110 | beforeDestroy() { 111 | clearInterval(this.timer); 112 | }, 113 | }); -------------------------------------------------------------------------------- /tftracer/resources/web/main.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 48 | 49 | 50 | 51 | Runtime Visualization 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 |
[main.html body markup not recoverable; surviving UI text: "Connection Error.", "Global tracing imposes a significant runtime overhead.", "There is an active session.", "There is no active session."]
172 | 173 | 174 | 267 | 268 | 274 | 275 | 276 | 277 | 278 | -------------------------------------------------------------------------------- /tftracer/resources/web/tensorflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xldrx/tensorflow-tracer/0db65fe55c2b6acd55d37112effc17aa90886bcf/tftracer/resources/web/tensorflow.png -------------------------------------------------------------------------------- /tftracer/timeline.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python -u 2 | # coding=utf-8 3 | from __future__ import absolute_import 4 | from __future__ import with_statement 5 | 6 | import math 7 | import pickle 8 | import time 9 | from io import open 10 | from .timeline_visualizer import DataLoader, TimelineVisualizer 11 | import tensorflow 12 | __author__ = 'Sayed Hadi Hashemi' 13 | 14 | 15 | class Timeline(object): 16 | """ 17 | This class traces a session run and visualizes the execution timeline. 18 | 19 | Example: 20 | 21 | .. code-block:: python 22 | 23 | with Timeline() as tl: 24 | sess.run(fetches, **tl.kwargs) 25 | 26 | Args: 27 | run_metadata (tensorflow.RunMetadata): If set a web server starts on object initialization. (default: true) 28 | """ 29 | def __init__(self, run_metadata=None, **kwargs): 30 | self._elapsed = 0 31 | self._run_metadata = run_metadata 32 | self._options = None 33 | comm_op_name = kwargs.get("comm_op_name", None) 34 | self._comm_op_name = comm_op_name if comm_op_name is not None else "RecvTensor" 35 | 36 | def __is_communication_op(self, op): 37 | if " = HorovodAllreduce(" in op.timeline_label: 38 | return True 39 | else: 40 | return op.node_name == self._comm_op_name 41 | 42 | def __enter__(self): 43 | from tensorflow import RunMetadata, RunOptions 44 | 45 | self.__start = time.time() 46 | self._run_metadata = RunMetadata() 47 | self._options = RunOptions(trace_level=RunOptions.FULL_TRACE, output_partition_graphs=True) 48 | 49 | return self 50 | 51 | def __exit__(self, *args): 52 | self._elapsed = time.time() - self.__start 53 | 54 | @property 55 | def kwargs(self): 56 | """ 57 | Returns a dict of config_pb2.RunOptions. This object should be unpacked and passed to session.run. 58 | 59 | Example: :: 60 | 61 | session.run(fetches, **timeline.kwargs) 62 | """ 63 | if self._run_metadata is None: 64 | raise Exception("TensorFlow is not found") 65 | return dict(run_metadata=self._run_metadata, options=self._options) 66 | 67 | def visualize(self, output_file=None, device_pattern=None): 68 | """ 69 | Visualizes the runtime_metadata and saves it as a HTML file. 70 | Args: 71 | output_file (str): the output file path. If is None, returns the HTML content instead. 72 | device_pattern (str): a regex pattern used to choose which device to be included. 73 | If None, all devices are used. 74 | 75 | Returns: 76 | str: If output_file is None returns the HTML content, otherwise returns None. 77 | 78 | """ 79 | data_loader = DataLoader(self._run_metadata, device_pattern) 80 | visualizer = TimelineVisualizer(data_loader) 81 | return visualizer.visualize(output_file) 82 | 83 | def step_time(self, device_search_pattern=None): 84 | """ 85 | Calculate the step time. 86 | Args: 87 | device_search_pattern (str): a regex pattern used to choose which device to be included. 88 | If None, all devices are used. 89 | 90 | Returns: 91 | float: the time in seconds. 
92 | 93 | """ 94 | max_time = 0 95 | min_time = math.inf 96 | device_search = "" if device_search_pattern is None else device_search_pattern 97 | all_ops = [] 98 | for device in [d for d in self._run_metadata.step_stats.dev_stats if device_search in d.device]: 99 | all_ops += device.node_stats 100 | for op in sorted(all_ops, key=lambda a: a.all_start_micros): 101 | min_time = min(min_time, op.all_start_micros) 102 | max_time = max(max_time, op.all_start_micros + op.all_end_rel_micros) 103 | return max_time - min_time if min_time != math.inf else 0 104 | 105 | def communication_elapsed_time(self, device_search_pattern=None, exclude_pattern=None): 106 | max_time = 0 107 | min_time = math.inf 108 | device_search = "" if device_search_pattern is None else device_search_pattern 109 | all_ops = [] 110 | for device in [d for d in self._run_metadata.step_stats.dev_stats if device_search in d.device]: 111 | all_ops += device.node_stats 112 | for op in sorted(all_ops, key=lambda a: a.all_start_micros): 113 | if not self.__is_communication_op(op): 114 | continue 115 | if exclude_pattern is not None and exclude_pattern in op.timeline_label: 116 | continue 117 | min_time = min(min_time, op.all_start_micros) 118 | max_time = max(max_time, op.all_start_micros + op.all_end_rel_micros) 119 | return max_time - min_time if min_time != math.inf else 0 120 | 121 | def communication_time(self, device_search_pattern=None, exclude_pattern=None): 122 | device_search = "" if device_search_pattern is None else device_search_pattern 123 | all_ops = [] 124 | for device in [d for d in self._run_metadata.step_stats.dev_stats if device_search in d.device]: 125 | all_ops += device.node_stats 126 | 127 | last_ = -math.inf 128 | total = 0 129 | for op in sorted(all_ops, key=lambda a: a.all_start_micros): 130 | if not self.__is_communication_op(op): 131 | continue 132 | if exclude_pattern is not None and exclude_pattern in op.timeline_label: 133 | continue 134 | if op.all_start_micros > last_: 135 | total += op.all_end_rel_micros 136 | elif op.all_start_micros + op.all_end_rel_micros > last_: 137 | total += op.all_start_micros + op.all_end_rel_micros - last_ 138 | last_ = max(last_, op.all_start_micros + op.all_end_rel_micros) 139 | 140 | return total 141 | 142 | def computation_time(self, device_search_pattern=None, exclude_pattern=None): 143 | device_search = "" if device_search_pattern is None else device_search_pattern 144 | all_ops = [] 145 | for device in [d for d in self._run_metadata.step_stats.dev_stats if device_search in d.device]: 146 | all_ops += device.node_stats 147 | 148 | last_ = -math.inf 149 | total = 0 150 | for op in sorted(all_ops, key=lambda a: a.all_start_micros): 151 | if exclude_pattern is not None and exclude_pattern in op.timeline_label: 152 | continue 153 | if op.all_start_micros > last_: 154 | total += op.all_end_rel_micros 155 | elif op.all_start_micros + op.all_end_rel_micros > last_: 156 | total += op.all_start_micros + op.all_end_rel_micros - last_ 157 | last_ = max(last_, op.all_start_micros + op.all_end_rel_micros) 158 | 159 | return total 160 | 161 | @property 162 | def wall_clock_elapsed(self): 163 | """ 164 | Time elapsed in ``with Timeline()`` statement. 165 | 166 | Returns: 167 | float: the time in seconds. 168 | """ 169 | return self._elapsed 170 | 171 | @classmethod 172 | def from_pickle(cls, pickle_file_name, **kwargs): 173 | """ 174 | Load a timeline form a pickle file. 175 | 176 | Args: 177 | pickle_file_name (str): pickle file path. 178 | **kwargs: same as Timeline class. 
179 | 180 | Returns: 181 | a Timeline object with the content of pickle_file_name. 182 | 183 | """ 184 | with open(pickle_file_name, "rb") as fp: 185 | run_metadata = pickle.load(fp) 186 | return cls(run_metadata=run_metadata, **kwargs) 187 | 188 | def to_pickle(self, pickle_file_name): 189 | """ 190 | Save the timeline in a pickle file. 191 | Returns: 192 | None. 193 | 194 | Raises: 195 | Exception: if the timeline trace is empty. 196 | """ 197 | if self._run_metadata is None: 198 | raise Exception("No data has been collected yet") 199 | 200 | with open(pickle_file_name, "wb") as fp: 201 | pickle.dump(self._run_metadata, fp) 202 | -------------------------------------------------------------------------------- /tftracer/timeline_visualizer.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # coding=utf-8 3 | from __future__ import division 4 | from __future__ import unicode_literals 5 | from __future__ import with_statement 6 | from __future__ import absolute_import 7 | from io import open 8 | import os 9 | import random 10 | import re 11 | from datetime import datetime 12 | 13 | import six 14 | from bokeh.embed import components 15 | from bokeh.layouts import gridplot 16 | from bokeh.models import ColumnDataSource, Range1d, SingleIntervalTicker, WidgetBox, \ 17 | HoverTool, CustomJS, Button, TapTool 18 | from bokeh.plotting import figure 19 | from bokeh.resources import INLINE 20 | from bokeh.util.string import encode_utf8 21 | from jinja2 import Environment, FileSystemLoader 22 | 23 | __author__ = 'Sayed Hadi Hashemi' 24 | 25 | 26 | class TimelineVisualizer: 27 | def __init__(self, data_loader): 28 | self._load_templates() 29 | self._tools = self._get_tools() 30 | self._data_loader = data_loader 31 | self._iteration_time = 0 32 | 33 | def visualize(self, output_file=None): 34 | data = self._data_loader.get_data() 35 | self._iteration_time = max([max([event['end'] for event in device['events']]) for device in data]) 36 | 37 | device_plots = [] 38 | for device in data: 39 | plot, widget_box = self._generate_device_plot(device) 40 | device_plots += [[plot], [widget_box]] 41 | 42 | final_plot = gridplot( 43 | device_plots, 44 | toolbar_options={ 45 | 'logo': None, 46 | }, 47 | sizing_mode='scale_width' 48 | ) 49 | 50 | result = self._export_to_html(final_plot) 51 | 52 | if output_file: 53 | if six.PY2: 54 | with open(output_file, "wb") as fp: 55 | fp.write(result.decode("utf-8").encode('utf-8')) 56 | elif six.PY3: 57 | with open(output_file, "w") as fp: 58 | fp.write(result) 59 | else: 60 | raise Exception("Unsupported Python Version") 61 | else: 62 | return result 63 | 64 | def _load_templates(self): 65 | _template_env = Environment( 66 | loader=FileSystemLoader(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources/templates/')), 67 | ) 68 | 69 | self._js_on_change_callback = _template_env.get_template("on_change_callback.js").render() 70 | self._js_on_click_callback = _template_env.get_template("on_click_callback.js").render() 71 | self._js_update_ranges = _template_env.get_template("update_ranges.js").render() 72 | self._js_on_hover_callback = _template_env.get_template("on_hover_callback.js").render() 73 | self._main_template = _template_env.get_template("timeline.html") 74 | self._tooltips_template = _template_env.get_template("tooltips.html").render() 75 | 76 | def _get_tools(self): 77 | def boxed(content, tag='div'): 78 | return "<{tag} class='xl-box'>{content}".format( 79 | content=content, 80 | 
tag=tag) 81 | 82 | callback = CustomJS(code=self._js_on_hover_callback) 83 | 84 | hover = HoverTool(tooltips=self._tooltips_template, mode='mouse', 85 | point_policy='follow_mouse', attachment='below', show_arrow=False, 86 | anchor="bottom_center", 87 | callback=callback 88 | ) 89 | 90 | tap = TapTool(callback=CustomJS(code=self._js_on_click_callback)) 91 | 92 | tools = "xzoom_in,xzoom_out,xpan,xbox_zoom,xwheel_zoom,xwheel_pan,reset,undo,redo,crosshair".split(',') 93 | tools += [hover, tap] 94 | 95 | return tools 96 | 97 | def _generate_device_plot(self, device_events): 98 | data_source = self._convert_events_to_datasource(device_events['events']) 99 | n_rows = device_events['n_rows'] 100 | if n_rows == 0: 101 | n_rows = 1 102 | elif n_rows == 1: 103 | n_rows = 2 104 | name = device_events['name'] 105 | 106 | plot = figure( 107 | title="{}".format(name), 108 | plot_height=20 * n_rows + 60, 109 | plot_width=1200, 110 | tools=self._tools, 111 | sizing_mode='scale_width', 112 | active_scroll='xwheel_zoom' 113 | ) 114 | plot.hbar( 115 | left='start', 116 | right='end', 117 | y='height', 118 | color='color', 119 | height=0.85, 120 | source=data_source, 121 | hover_fill_alpha=0.5, 122 | line_join='round', 123 | line_cap='round', 124 | hover_line_color='red' 125 | ) 126 | 127 | plot.x_range = Range1d(0, self._iteration_time, bounds="auto") 128 | plot.y_range = Range1d(0, n_rows) 129 | 130 | plot.yaxis.visible = False 131 | plot.ygrid.ticker = SingleIntervalTicker(interval=1) 132 | plot.ygrid.grid_line_color = None 133 | plot.ygrid.band_fill_alpha = 0.1 134 | plot.ygrid.band_fill_color = "gray" 135 | 136 | button = Button(label=" Sync", width=20, button_type='primary', disabled=True) 137 | button.css_classes = ['xl-hidden'] 138 | button.js_on_click( 139 | CustomJS( 140 | args={ 141 | 'me': plot, 142 | }, 143 | code=self._js_update_ranges 144 | ) 145 | ) 146 | 147 | plot.x_range.js_on_change( 148 | 'start', 149 | CustomJS( 150 | args={ 151 | 'button': button, 152 | }, 153 | code=self._js_on_change_callback) 154 | ) 155 | 156 | return plot, WidgetBox(button) 157 | 158 | @staticmethod 159 | def _convert_events_to_datasource(device_data, base_row=0): 160 | def get_col(column_name, offset=None): 161 | if offset: 162 | return [item[column_name] + offset for item in device_data] 163 | else: 164 | return [item[column_name] for item in device_data] 165 | 166 | return ColumnDataSource(data=dict( 167 | duration=get_col("duration"), 168 | start=get_col("start"), 169 | end=get_col('end'), 170 | height=get_col('row', base_row + 0.5), 171 | color=get_col('color'), 172 | row=get_col('row', base_row), 173 | name=get_col('name'), 174 | description=get_col('description'), 175 | details=get_col('details'), 176 | op=get_col('op'), 177 | inputs=["".join(["
  • {}
  • ".format(i) for i in item.split()]) for item in get_col('inputs')] 178 | )) 179 | 180 | def _export_to_html(self, plot): 181 | js_resources = INLINE.render_js() 182 | css_resources = INLINE.render_css() 183 | 184 | script, div = components(plot) 185 | html = self._main_template.render( 186 | plot_script=script, 187 | plot_div=div, 188 | js_resources=js_resources, 189 | css_resources=css_resources, 190 | title="TensorFlow Timeline", 191 | header=str(datetime.now()), 192 | custom_css='', 193 | custom_header='', 194 | custom_js='' 195 | ) 196 | 197 | return encode_utf8(html) 198 | 199 | 200 | class DataLoader: 201 | def __init__(self, run_metadata, device_pattern=None): 202 | self._device_pattern_re = re.compile(device_pattern if device_pattern else "^.*$") 203 | self._step_stats = run_metadata.step_stats 204 | self.comm_op_name = "RecvTensor" 205 | 206 | @staticmethod 207 | def _assign_row(events): 208 | rows = [] 209 | for event in sorted(events, key=lambda x: x['start']): 210 | assigned = False 211 | for i, row in enumerate(rows): 212 | if row <= event['start']: 213 | event['row'] = i 214 | rows[i] = event['end'] 215 | assigned = True 216 | break 217 | if not assigned: 218 | event['row'] = len(rows) 219 | rows.append(event['end']) 220 | return len(rows) 221 | 222 | @staticmethod 223 | def _assign_color(events): 224 | for event in events: 225 | rand = random.Random(event['op']) 226 | event['color'] = "#%02x%02x%02x" % (rand.randint(0, 256), rand.randint(0, 256), rand.randint(0, 256)) 227 | 228 | @staticmethod 229 | def _parse_event_description(label): 230 | """Parses the fields in a node timeline label.""" 231 | # Expects labels of the form: name = op(arg, arg, ...). 232 | match = re.match(r'(.*) = (.*)\((.*)\)', label) 233 | if match is None: 234 | return 'unknown', 'unknown', [] 235 | nn, op, inputs = match.groups() 236 | if not inputs: 237 | inputs = [] 238 | else: 239 | inputs = inputs.split(', ') 240 | return nn, op, inputs 241 | 242 | def _fix_op_names(self, events): 243 | for event in events: 244 | _, op, inputs = self._parse_event_description(event['description']) 245 | if op == "unknown": 246 | op = event['name'] 247 | inputs = "" 248 | event['op'] = op 249 | event['inputs'] = "\n\n".join(inputs) 250 | 251 | def _process_device(self, device_name, node_stats, base_timestamp): 252 | device_events = [] 253 | 254 | for node in node_stats: 255 | device_events.append(dict( 256 | start=(node.all_start_micros - base_timestamp) / 1000, 257 | end=(max(node.all_end_rel_micros, 1) + node.all_start_micros - base_timestamp) / 1000, 258 | duration=node.all_end_rel_micros / 1000, 259 | name=node.node_name, 260 | description=node.timeline_label, 261 | details=str(node).replace("\n", "\n\n") 262 | )) 263 | 264 | self._fix_op_names(device_events) 265 | self._assign_color(device_events) 266 | n_rows = self._assign_row(device_events) 267 | 268 | return dict( 269 | name=device_name, 270 | n_rows=n_rows, 271 | events=device_events 272 | ) 273 | 274 | def _find_minimum_timestamp(self): 275 | return min([ 276 | min([node.all_start_micros for node in device.node_stats]) 277 | for device in self._step_stats.dev_stats if len(self._device_pattern_re.findall(device.device)) > 0 278 | ]) 279 | 280 | def __is_communication_op(self, op): 281 | if " = HorovodAllreduce(" in op.timeline_label: 282 | return True 283 | else: 284 | return op.node_name == self.comm_op_name 285 | 286 | def get_data(self): 287 | stats = self._step_stats 288 | events = [] 289 | 290 | base_timestamp = self._find_minimum_timestamp() 291 | 
292 | for device in stats.dev_stats: 293 | device_name = device.device 294 | if len(self._device_pattern_re.findall(device_name)) == 0: 295 | print(("ignoring device: {}".format(device_name))) 296 | continue 297 | device_events = self._process_device( 298 | device_name, 299 | [node for node in device.node_stats if 300 | not self.__is_communication_op(node)], 301 | base_timestamp 302 | ) 303 | if len(device_events["events"]) > 0: 304 | events.append(device_events) 305 | 306 | device_events = self._process_device( 307 | device_name + " (Communication)", 308 | [node for node in device.node_stats if 309 | self.__is_communication_op(node)], 310 | base_timestamp 311 | ) 312 | if len(device_events["events"]) > 0: 313 | events.append(device_events) 314 | 315 | events.sort(key=lambda x: x['name']) 316 | return events 317 | -------------------------------------------------------------------------------- /tftracer/tracing_server.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python -u 2 | # coding=utf-8 3 | import gzip 4 | import pickle 5 | 6 | from io import BytesIO 7 | 8 | __author__ = 'Sayed Hadi Hashemi' 9 | 10 | import datetime 11 | import json 12 | import os 13 | from collections import OrderedDict 14 | from copy import deepcopy 15 | from gevent.pywsgi import WSGIServer 16 | import flask 17 | import threading 18 | from .timeline import Timeline 19 | import tensorflow as tf 20 | from .version import __version__ 21 | 22 | class VisualizationServerBase: 23 | def __init__(self, server_port=9999, server_ip="0.0.0.0", **kwargs): 24 | self._server_port = server_port 25 | self._server_ip = server_ip 26 | self._wsgi_server = None 27 | self._flask_app = None 28 | self._server_thread = None 29 | 30 | def _start_server(self): 31 | self._flask_app = self._get_flask_app() 32 | self._wsgi_server = WSGIServer((self._server_ip, self._server_port), self._flask_app) 33 | self._wsgi_server.serve_forever() 34 | 35 | def start_web_server(self): 36 | """ 37 | Start a web server in a separate thead. 38 | 39 | Note: 40 | The tracing server keeps track of session runs even without a running web server. 41 | """ 42 | if not self._server_thread: 43 | self._server_thread = threading.Thread(target=self._start_server) 44 | self._server_thread.start() 45 | tf.logging.warn("Tracing Server: http://{}:{}/".format(self._server_ip, 46 | self._server_port)) 47 | 48 | def stop_web_server(self): 49 | """ 50 | Stop the web server. 51 | 52 | Note: 53 | The tracing server keeps track of session runs even after the web server is stopped. 54 | """ 55 | if self._wsgi_server.started: 56 | self._wsgi_server.stop() 57 | if threading.current_thread() != self._server_thread: 58 | self._server_thread.join() 59 | self._server_thread = None 60 | 61 | def join(self): 62 | """ 63 | Wait until the web server is stopped. 
64 | """ 65 | if self._wsgi_server.started: 66 | self._server_thread.join() 67 | 68 | def _get_flask_app(self): 69 | raise NotImplemented 70 | 71 | 72 | class VisualizationServer(VisualizationServerBase): 73 | _static_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources/web/') 74 | 75 | def __init__(self, name, source, **kwargs): 76 | super().__init__(**kwargs) 77 | self._name = name 78 | self._source = source 79 | self._keep_traces = kwargs.get("keep_traces", 5) 80 | 81 | def _handle_update(self): 82 | runs = self._source.get_runs() 83 | response = { 84 | "running": self._source.running, 85 | "global_tracing": self._source.global_tracing, 86 | "runs": deepcopy(runs), 87 | } 88 | for run in response["runs"]: 89 | run_id = run["run_id"] 90 | run["trace_url"] = "/trace/{}".format(run_id) 91 | 92 | run["stats"]["runtime_avg"] = str(run["stats"]["runtimes"]) 93 | del run["stats"]["runtimes"] 94 | 95 | run["stats"]["first_run"] = str(run["stats"]["first_run"]) 96 | run["stats"]["last_run"] = str(run["stats"]["last_run"]) 97 | 98 | run["traces"] = run["traces"][-self._keep_traces:] 99 | 100 | for trace in run["traces"]: 101 | trace_id = trace["trace_id"] 102 | trace["title"] = str(trace["date"]) 103 | del trace["date"] 104 | trace["url"] = "/{}/{}".format(run_id, trace_id) 105 | trace["download_url"] = "/download/{}/{}".format(run_id, trace_id) 106 | 107 | return json.dumps(response) 108 | 109 | def _handle_main(self): 110 | with open(os.path.join(self._static_folder, "main.html")) as fp: 111 | return fp.read() 112 | 113 | def _handle_timelime(self, run_id, trace_id=0): 114 | run_metadata = self._source.get_trace(run_id, trace_id) 115 | if run_metadata is None: 116 | return flask.redirect("/") 117 | else: 118 | result = Timeline(run_metadata=run_metadata).visualize() 119 | return result 120 | 121 | def _handle_download(self, run_id, trace_id=0): 122 | run_metadata = self._source.get_trace(run_id, trace_id) 123 | if run_metadata is None: 124 | return flask.redirect("/") 125 | else: 126 | fp = BytesIO() 127 | pickle.dump(run_metadata, fp) 128 | fp.seek(0, 0) 129 | return flask.send_file(fp, 130 | as_attachment=True, 131 | attachment_filename="run_metadata-{}-{}.pickle".format(run_id, trace_id)) 132 | 133 | def _handle_save_session(self): 134 | data = pickle.dumps(self._source) 135 | gdata = gzip.compress(data) 136 | fp = BytesIO(gdata) 137 | fp.seek(0, 0) 138 | return flask.send_file(fp, 139 | as_attachment=True, 140 | attachment_filename="tracing-session.pickle.gz") 141 | 142 | def _handle_enable_tracing(self, run_id): 143 | self._source.enable_tracing(run_id) 144 | return flask.redirect("/") 145 | 146 | def _handle_enable_global_tracing(self): 147 | self._source.enable_global_tracing() 148 | return flask.redirect("/") 149 | 150 | def _handle_disable_global_tracing(self): 151 | self._source.disable_global_tracing() 152 | return flask.redirect("/") 153 | 154 | def _handle_kill_server(self): 155 | self.stop_web_server() 156 | return flask.redirect("/") 157 | 158 | def _get_flask_app(self): 159 | app = flask.Flask(self._name, static_folder=self._static_folder, static_url_path="/static") 160 | app.route("/")(self._handle_main) 161 | app.route("//")(self._handle_timelime) 162 | app.route("/download//")(self._handle_download) 163 | app.route("/trace/")(self._handle_enable_tracing) 164 | app.route("/update")(self._handle_update) 165 | app.route("/enable_global_tracing")(self._handle_enable_global_tracing) 166 | 
app.route("/disable_global_tracing")(self._handle_disable_global_tracing) 167 | app.route("/kill_tracing_server")(self._handle_kill_server) 168 | app.route("/save_session")(self._handle_save_session) 169 | return app 170 | 171 | 172 | class TracingSource: 173 | tftracer_version = __version__ 174 | 175 | def __init__(self, **kwargs): 176 | self._run_profile = OrderedDict() 177 | self._traces = {} 178 | self.global_tracing = False 179 | self.running = False 180 | self._keep_traces = kwargs.get("keep_traces", 5) 181 | 182 | @staticmethod 183 | def get_run_context_key(run_context): 184 | return repr(run_context.original_args) 185 | 186 | def get_trace(self, run_id, trace_id): 187 | if run_id >= len(self._run_profile): 188 | return None 189 | if trace_id >= len(self._traces[run_id]): 190 | return None 191 | return self._traces[run_id][trace_id] 192 | 193 | def get_runs(self): 194 | return list(self._run_profile.values()) 195 | 196 | def enable_tracing(self, run_id): 197 | profile = list(self._run_profile.values())[run_id] 198 | profile["tracing"] = True 199 | 200 | def enable_global_tracing(self): 201 | self.global_tracing = True 202 | 203 | def disable_global_tracing(self): 204 | self.global_tracing = False 205 | 206 | def is_tracing_on(self, run_context): 207 | if self.global_tracing: 208 | return True 209 | key = self.get_run_context_key(run_context) 210 | if key in self._run_profile: 211 | return self._run_profile[key]["tracing"] 212 | else: 213 | return False 214 | 215 | def before_run(self, run_context): 216 | key = self.get_run_context_key(run_context) 217 | 218 | if key not in self._run_profile: 219 | run_id = len(self._run_profile) 220 | profile = { 221 | "info": { 222 | "fetches": repr(run_context.original_args.fetches), 223 | "feeds": repr(run_context.original_args.feed_dict), 224 | "options": repr(run_context.original_args.options) 225 | }, 226 | "stats": { 227 | "runs": 0, 228 | "traces": 0, 229 | "runtimes": datetime.timedelta(microseconds=0), 230 | "first_run": datetime.datetime.now(), 231 | "last_run": datetime.datetime.now(), 232 | }, 233 | "traces": [ 234 | ], 235 | "key": key, 236 | "run_id": run_id, 237 | "tracing": False, 238 | } 239 | self._run_profile[key] = profile 240 | self._traces[run_id] = [] 241 | else: 242 | profile = self._run_profile[key] 243 | profile["stats"]["last_run"] = datetime.datetime.now() 244 | 245 | def add_run(self, run_context, run_values): 246 | key = self.get_run_context_key(run_context) 247 | profile = self._run_profile[key] 248 | 249 | # stats 250 | num_runs = profile["stats"]["runs"] 251 | old_runtime = profile["stats"]["runtimes"] 252 | profile["stats"]["runs"] += 1 253 | runtime = datetime.datetime.now() - profile["stats"]["last_run"] 254 | profile["stats"]["runtimes"] = (runtime + old_runtime * num_runs) / (num_runs + 1) 255 | 256 | if run_values.run_metadata.ByteSize() > 0: 257 | run_id = profile["run_id"] 258 | trace_id = len(self._traces[run_id]) 259 | self._traces[run_id].append(run_values.run_metadata) 260 | profile["tracing"] = False 261 | profile["traces"].append( 262 | { 263 | "trace_id": trace_id, 264 | "date": datetime.datetime.now() 265 | } 266 | ) 267 | profile["stats"]["traces"] = len(profile["traces"]) 268 | if len(self._traces[run_id]) > self._keep_traces: 269 | self._traces[run_id][0] = None 270 | 271 | 272 | class TracingServerHook(tf.train.SessionRunHook): 273 | def __init__(self, source): 274 | self._source = source 275 | 276 | def begin(self): 277 | super().begin() 278 | self._source.running = True 279 | 280 | def 
after_create_session(self, session, coord): 281 | super().after_create_session(session, coord) 282 | 283 | def before_run(self, run_context): 284 | super().before_run(run_context) 285 | self._source.before_run(run_context) 286 | if self._source.is_tracing_on(run_context): 287 | opts = (tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)) 288 | return tf.train.SessionRunArgs(None, None, options=opts) 289 | else: 290 | return None 291 | 292 | def after_run(self, run_context, run_values): 293 | super().after_run(run_context, run_values) 294 | self._source.add_run(run_context, run_values) 295 | 296 | def end(self, session): 297 | super().end(session) 298 | self._source.running = False 299 | 300 | 301 | class TracingServer(VisualizationServer): 302 | """ 303 | This class provides a ``tf.train.SessionRunHook`` to track session runs as well as a web interface to interact with 304 | users. By default, the web interface is accessible on `http://0.0.0.0:9999`. 305 | The web server stops at the end of the script. Use :func:`tftracer.TracingServer.join` to keep the server alive. 306 | 307 | Example: 308 | Estimator API: 309 | 310 | .. code-block:: python 311 | 312 | tracing_server = TracingServer() 313 | estimator.train(input_fn, hooks=[tracing_server.hook]) 314 | 315 | Low-Level API: 316 | 317 | .. code-block:: python 318 | 319 | tracing_server = TracingServer() 320 | with tf.train.MonitoredTrainingSession(hooks=[tracing_server.hook]): 321 | ... 322 | 323 | Args: 324 | start_web_server_on_start (bool): If true a web server starts on object initialization. (default: true) 325 | server_port (int): TCP port to which web server listens (default: 9999) 326 | server_ip (str): IP Address to which web server listens (default: "0.0.0.0") 327 | keep_traces (int): Number of traces per run which the tracing server should keep. \ 328 | the server discards the oldest traces when exeeced the limit. (default: 5) 329 | """ 330 | 331 | def __init__(self, **kwargs): 332 | self._source = TracingSource(**kwargs) 333 | super().__init__("tftracer", self._source, **kwargs) 334 | start_web_server_on_start = kwargs.get("start_web_server_on_start", True) 335 | if start_web_server_on_start: 336 | self.start_web_server() 337 | self._hook = TracingServerHook(self._source) 338 | 339 | def save_session(self, filename): 340 | """ 341 | Stores the tracing session to a pickle file. 342 | 343 | Args: 344 | filename: path to the trace session file. 345 | """ 346 | with open(filename, "wb") as fp: 347 | pickle.dump(self._source, fp) 348 | 349 | def load_session(self, filename, gziped=None): 350 | """ 351 | Loads a tracing session into the current tracing server. 352 | 353 | Caution: 354 | This action discards the current data in the session. 355 | 356 | Args: 357 | filename: path to the trace session file. 358 | gziped (bool): when set, determines if the trace file is gziped. when None, use gzip if the filename ends with ".gz"; 359 | """ 360 | running = self._source.running 361 | global_tracing = self._source.global_tracing 362 | gziped = filename.endswith(".gz") if gziped is None else gziped 363 | 364 | with open(filename, "rb") as fp: 365 | if gziped: 366 | data = gzip.decompress(fp.read()) 367 | self._source = pickle.loads(data) 368 | else: 369 | self._source = pickle.load(fp) 370 | 371 | self._source.running = running 372 | self._source.global_tracing = global_tracing 373 | 374 | @property 375 | def hook(self): 376 | """ 377 | Returns a ``tensorflow.train.SessionRunHook`` object. 
378 | This object is meant to pass to tensorflow ``estimator`` API or ``MonitoredSession``. 379 | 380 | """ 381 | 382 | return self._hook 383 | -------------------------------------------------------------------------------- /tftracer/version.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python -u 2 | # coding=utf-8 3 | 4 | __author__ = 'Sayed Hadi Hashemi' 5 | 6 | __version__ = '1.1.0' 7 | --------------------------------------------------------------------------------
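
Putting the pieces above together: the library exposes two tracing paths, a long-lived TracingServer hook for estimator/MonitoredSession jobs and a one-shot Timeline context manager for a single session.run. The sketch below is assembled from the docstrings and examples in this repository; it assumes TensorFlow 1.x (>=1.8, per requirements.txt), and build_estimator, input_fn, and the file paths are placeholders for your own code, not part of the library.

#! /usr/bin/env python
# Hedged usage sketch assembled from the tftracer docstrings; `estimator`,
# `input_fn`, and the file paths below are placeholders, not library names.
import tensorflow as tf
import tftracer


def trace_training(estimator, input_fn):
    # Long-running jobs: start the tracing server (web UI on http://0.0.0.0:9999
    # by default) and pass its hook to the estimator, per the TracingServer docstring.
    tracing_server = tftracer.TracingServer(server_port=9999, keep_traces=5)
    estimator.train(input_fn, steps=100, hooks=[tracing_server.hook])

    # Persist the collected runs; the `tftracer` console script can reload them later.
    tracing_server.save_session("tracing-session.pickle")
    tracing_server.stop_web_server()


def trace_single_run(fetches):
    # One-shot tracing: Timeline.kwargs injects RunOptions/RunMetadata into session.run.
    with tf.Session() as sess, tftracer.Timeline() as tl:
        sess.run(fetches, **tl.kwargs)

    tl.visualize(output_file="timeline.html")   # standalone HTML timeline
    tl.to_pickle("run_metadata.pickle")         # raw RunMetadata for later analysis
    print("wall clock elapsed (s):", tl.wall_clock_elapsed)

To revisit a saved session without rerunning training, the console script installed by setup.py serves the same web UI from the pickle file, e.g. `tftracer tracing-session.pickle --port 9999`. Alternatively, `tftracer.hook_inject()` can be called at the top of an existing script to monkey-patch the hook into every MonitoredSession, with the caveats noted in monkey_patching.py.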