├── .gitignore ├── .pylintrc ├── .vscode └── launch.json ├── CHANGELOG.md ├── CONTRIBUTING.md ├── ISSUE_TEMPLATE.md ├── LICENSE.txt ├── MANIFEST.in ├── NOTICE.md ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── TODO.md ├── data └── test_images │ ├── cat.jpg │ ├── dogs.png │ ├── elephant.png │ ├── imagenet_class_index.json │ └── imagenet_synsets.txt ├── docs ├── images │ ├── draw_model.png │ ├── fruits.gif │ ├── lazy_log_array_sum.png │ ├── model_stats.png │ ├── quick_start.gif │ ├── saliency.png │ ├── simple_logging │ │ ├── full-page2.png │ │ ├── line_cell.png │ │ ├── line_cell2.png │ │ ├── plotly_line.png │ │ ├── text_cell.png │ │ └── text_summary.png │ ├── teaser.gif │ ├── teaser_small.gif │ └── tsne.gif ├── lazy_logging.md ├── paper │ ├── .gitignore │ ├── ACM-Reference-Format.bbx │ ├── ACM-Reference-Format.bst │ ├── ACM-Reference-Format.cbx │ ├── ACM-Reference-Format.dbx │ ├── TensorWatch_Collaboration.png │ ├── acmart.cls │ ├── main.pdf │ ├── main.tex │ ├── sample-base.bib │ ├── tensorwatch-screenshot.png │ └── tensorwatch-screenshot2.png └── simple_logging.md ├── install_jupyterlab.bat ├── notebooks ├── cnn_pred_explain.ipynb ├── data_exploration.ipynb ├── fruits_analysis.ipynb ├── lazy_logging.ipynb ├── mnist.ipynb ├── network_arch.ipynb ├── receptive_field.ipynb └── simple_logging.ipynb ├── setup.py ├── tensorwatch.pyproj ├── tensorwatch.sln ├── tensorwatch ├── __init__.py ├── array_stream.py ├── data_utils.py ├── embeddings │ ├── __init__.py │ └── tsne_utils.py ├── evaler.py ├── evaler_utils.py ├── file_stream.py ├── filtered_stream.py ├── image_utils.py ├── imagenet_class_index.json ├── imagenet_synsets.txt ├── imagenet_utils.py ├── lv_types.py ├── model_graph │ ├── __init__.py │ ├── hiddenlayer │ │ ├── README.md │ │ ├── __init__.py │ │ ├── distiller.py │ │ ├── distiller_utils.py │ │ ├── ge.py │ │ ├── graph.py │ │ ├── pytorch_builder.py │ │ ├── pytorch_builder_grad.py │ │ ├── pytorch_builder_trace.py │ │ ├── pytorch_draw_model.py │ │ ├── summary_graph.py │ 
│ ├── tf_builder.py │ │ └── transforms.py │ ├── torchstat │ │ ├── README.md │ │ ├── __init__.py │ │ ├── analyzer.py │ │ ├── compute_flops.py │ │ ├── compute_madd.py │ │ ├── compute_memory.py │ │ ├── reporter.py │ │ └── stat_tree.py │ └── torchstat_utils.py ├── mpl │ ├── __init__.py │ ├── bar_plot.py │ ├── base_mpl_plot.py │ ├── histogram.py │ ├── image_plot.py │ ├── line_plot.py │ └── pie_chart.py ├── notebook_maker.py ├── plotly │ ├── __init__.py │ ├── base_plotly_plot.py │ ├── embeddings_plot.py │ └── line_plot.py ├── pytorch_utils.py ├── receptive_field │ ├── __init__.py │ └── rf_utils.py ├── repeated_timer.py ├── saliency │ ├── README.md │ ├── __init__.py │ ├── backprop.py │ ├── deeplift.py │ ├── epsilon_lrp.py │ ├── gradcam.py │ ├── inverter_util.py │ ├── lime │ │ ├── __init__.py │ │ ├── lime_base.py │ │ ├── lime_image.py │ │ └── wrappers │ │ │ ├── __init__.py │ │ │ ├── generic_utils.py │ │ │ └── scikit_image.py │ ├── lime_image_explainer.py │ ├── occlusion.py │ └── saliency.py ├── stream.py ├── stream_factory.py ├── stream_union.py ├── tensor_utils.py ├── text_vis.py ├── utils.py ├── vis_base.py ├── visualizer.py ├── watcher.py ├── watcher_base.py ├── watcher_client.py ├── zmq_mgmt_stream.py ├── zmq_stream.py └── zmq_wrapper.py ├── test ├── components │ ├── circ_ref.py │ ├── evaler.py │ ├── file_only_test.py │ ├── notebook_maker.py │ ├── stream.py │ └── watcher.py ├── deps │ ├── ipython_widget.py │ ├── live_graph.py │ ├── panda.py │ └── thread.py ├── files │ └── file_stream.py ├── mnist │ └── cli_mnist.py ├── post_train │ └── saliency.py ├── pre_train │ ├── dot_manual.bat │ ├── draw_cust_model.py │ ├── draw_model.py │ ├── model_stats.py │ ├── model_stats_perf.py │ ├── sample.dot │ ├── test_pydot.py │ └── tsny.py ├── simple_log │ ├── cli_file_expr.py │ ├── cli_ij.py │ ├── cli_sum_log.py │ ├── file_expr.py │ ├── quick_start.py │ ├── srv_ij.py │ ├── sum_lazy.py │ └── sum_log.py ├── test.pyproj ├── visualizations │ ├── arr_img_plot.py │ ├── arr_mpl_line.py │ ├── 
bar_plot.py │ ├── confidence_int.py │ ├── histogram.py │ ├── line3d_plot.py │ ├── mpl_line.py │ ├── pie_chart.py │ └── plotly_line.py └── zmq │ ├── zmq_pub.py │ ├── zmq_stream.py │ ├── zmq_sub.py │ ├── zmq_watcher_client.py │ └── zmq_watcher_server.py └── update_package.bat /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | ignore=model_graph, 3 | receptive_field, 4 | saliency 5 | 6 | [TYPECHECK] 7 | ignored-modules=numpy,torch,matplotlib,pyplot,zmq 8 | 9 | [MESSAGES CONTROL] 10 | disable=protected-access, 11 | broad-except, 12 | global-statement, 13 | fixme, 14 | C, 15 | R -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal" 13 | } 14 | ] 15 | } -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # What's new 2 | 3 | Below is summarized list of important changes. This does not include minor/less important changes or bug fixes or documentation update. This list updated every few months. For complete detailed changes, please review [commit history](https://github.com/Microsoft/tensorwatch/commits/master). 4 | 5 | ### May, 2019 6 | * First release! 
-------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | This project welcomes contributions and suggestions. Most contributions require you to 4 | agree to a Contributor License Agreement (CLA) declaring that you have the right to, 5 | and actually do, grant us the rights to use your contribution. For details, visit 6 | https://cla.microsoft.com. 7 | 8 | When you submit a pull request, a CLA-bot will automatically determine whether you need 9 | to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the 10 | instructions provided by the bot. You will only need to do this once across all repositories using our CLA. 11 | 12 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 13 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 14 | or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 15 | -------------------------------------------------------------------------------- /ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Read This First 2 | 3 | * Make sure to describe **all the steps** to reproduce the issue 4 | * Include the full error message in the description 5 | * Add OS version, Python version, PyTorch version if applicable 6 | 7 | Remember: if we cannot reproduce your problem, we cannot find a solution! 8 | 9 | 10 | **What's better than filing an issue?
Filing a pull request :).** 11 | 12 | ------------------------------------ (Remove above before filing the issue) ------------------------------------ 13 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) Microsoft Corporation. All rights reserved. 2 | 3 | MIT License 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include tensorwatch/*.txt 2 | include tensorwatch/*.json -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. 
Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | We highly recommend to take a look at source code and contribute to the project. Also please consider [contributing](CONTRIBUTING.md) new features and fixes :). 
4 | 5 | * [Join TensorWatch Facebook Group](https://www.facebook.com/groups/378075159472803/) 6 | * [File GitHub Issue](https://github.com/Microsoft/tensorwatch/issues) -------------------------------------------------------------------------------- /TODO.md: -------------------------------------------------------------------------------- 1 | * Fix cell size issue 2 | * Refactor _plot* interface to accept all values, for ImagePlot only use last value 3 | * Refactor ImagePlot for arbitrary number of images with alpha, cmap 4 | * Change tw.open -> tw.create_viz 5 | * Make sure streams have names as key, each data point has index 6 | * Add tw.open_viz(stream_name, from_index) 7 | * Add persist=device_name option for streams 8 | * Ability to use streams in standalone mode 9 | * tw.create_viz on server side 10 | * tw.log for server side 11 | * experiment with IPC channel 12 | * confusion matrix as in https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html 13 | * Speed up import 14 | * Do linting 15 | * live perf data 16 | * NaN tracing 17 | * PCA 18 | * Remove error if MNIST notebook is on and we run fruits 19 | * Remove 2nd image from fruits 20 | * clear existing streams when starting client 21 | * ImageData should accept numpy array or pillow or torch tensor 22 | * image plot getting refreshed at 12 Hz instead of 2 Hz in MNIST 23 | * image plot doesn't show title 24 | * Animated mesh/surface graph demo 25 | * Move to h5 storage?
26 | * Error envelop 27 | * histogram 28 | * new graph on end 29 | * TF support 30 | * generic visualizer -> Given obj and Box, paint in box 31 | * visualize image and text with attention 32 | * add confidence interval for plotly: https://plot.ly/python/continuous-error-bars/ -------------------------------------------------------------------------------- /data/test_images/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/data/test_images/cat.jpg -------------------------------------------------------------------------------- /data/test_images/dogs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/data/test_images/dogs.png -------------------------------------------------------------------------------- /data/test_images/elephant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/data/test_images/elephant.png -------------------------------------------------------------------------------- /docs/images/draw_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/draw_model.png -------------------------------------------------------------------------------- /docs/images/fruits.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/fruits.gif -------------------------------------------------------------------------------- /docs/images/lazy_log_array_sum.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/lazy_log_array_sum.png -------------------------------------------------------------------------------- /docs/images/model_stats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/model_stats.png -------------------------------------------------------------------------------- /docs/images/quick_start.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/quick_start.gif -------------------------------------------------------------------------------- /docs/images/saliency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/saliency.png -------------------------------------------------------------------------------- /docs/images/simple_logging/full-page2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/simple_logging/full-page2.png -------------------------------------------------------------------------------- /docs/images/simple_logging/line_cell.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/simple_logging/line_cell.png -------------------------------------------------------------------------------- /docs/images/simple_logging/line_cell2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/simple_logging/line_cell2.png -------------------------------------------------------------------------------- /docs/images/simple_logging/plotly_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/simple_logging/plotly_line.png -------------------------------------------------------------------------------- /docs/images/simple_logging/text_cell.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/simple_logging/text_cell.png -------------------------------------------------------------------------------- /docs/images/simple_logging/text_summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/simple_logging/text_summary.png -------------------------------------------------------------------------------- /docs/images/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/teaser.gif -------------------------------------------------------------------------------- /docs/images/teaser_small.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/teaser_small.gif -------------------------------------------------------------------------------- /docs/images/tsne.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/tsne.gif -------------------------------------------------------------------------------- /docs/lazy_logging.md: -------------------------------------------------------------------------------- 1 | 2 | # Lazy Logging Tutorial 3 | 4 | In this tutorial, we will use a straightforward script that creates an array of random numbers. Our goal is to examine this array from a Jupyter Notebook while the script below is running separately in a console. 5 | 6 | ``` 7 | # available in test/simple_log/sum_lazy.py 8 | 9 | import time, random 10 | import tensorwatch as tw 11 | 12 | # create watcher object as usual 13 | w = tw.Watcher() 14 | 15 | weights = None 16 | for i in range(10000): 17 | weights = [random.random() for _ in range(5)] 18 | 19 | # let watcher observe variables we have 20 | # this has almost no performance cost 21 | w.observe(weights=weights) 22 | 23 | time.sleep(1) 24 | ``` 25 | 26 | ## Basic Concepts 27 | 28 | Notice that we give the `Watcher` object an opportunity to *observe* variables. Observing variables is very cheap, so we can observe anything that we might be interested in, for example, an entire model or a batch of data. From the Jupyter Notebook, you can specify an arbitrary *lambda expression* that we execute in the context of these observed variables. The result of this lambda expression is sent as a stream back to the Jupyter Notebook so that you can render it in the visualizer of your choice. That's pretty much it. 29 | 30 | ## How Does This Work? 31 | 32 | TensorWatch has two core classes: `Watcher` and `WatcherClient`. The `Watcher` object opens TCP/IP sockets by default and listens for incoming requests. The `WatcherClient` allows you to connect to `Watcher` and have it execute any Python [lambda expression](http://book.pythontips.com/en/latest/lambdas.html) you want.
The Python lambda expressions consume a stream of values as input and output another stream of values. The lambda expressions may typically contain [map, filter, and reduce](http://book.pythontips.com/en/latest/map_filter.html) so you can transform values in the input stream, filter them, or aggregate them. Let's see all of this with a simple example. 33 | 34 | ## Create the Client 35 | You can either create a Jupyter Notebook or get the existing one in the repo at `notebooks/lazy_logging.ipynb`. 36 | 37 | Let's first do the imports: 38 | 39 | 40 | ```python 41 | %matplotlib notebook 42 | import tensorwatch as tw 43 | ``` 44 | 45 | Next, create a stream using a simple lambda expression that sums values in our `weights` array that we are observing in the above script. The input to the lambda expression is an object that has all the variables we are observing. 46 | 47 | ```python 48 | client = tw.WatcherClient() 49 | stream = client.create_stream(expr='lambda d: np.sum(d.weights)') 50 | ``` 51 | 52 | Next, send this stream to a line chart visualization: 53 | 54 | ```python 55 | line_plot = tw.Visualizer(stream, vis_type='line') 56 | line_plot.show() 57 | ``` 58 | 59 | Now, when you run our script `sum_lazy.py` in the console and then run the above cells in the Jupyter Notebook, you will see the sum of the `weights` array getting plotted in real time: 60 | 61 | 62 | 63 | ## What Just Happened? 64 | 65 | Notice that your script is running in a different process than the Jupyter Notebook. You could be running your Jupyter Notebook on your laptop while your script runs on some VM in the cloud. You can query observed variables at run time! As `Watcher.observe()` is a cheap call, you could observe your entire model, batch inputs and outputs, and so on. You can then slice and dice all these observed variables from the comfort of your Jupyter Notebook while your script is progressing!
66 | 67 | ## Next Steps 68 | 69 | We have just scratched the surface in this new land of lazy logging! You can do much more with this idea. For example, you can create *events* so your lambda expression gets executed only on those events. This is useful for creating streams that generate data per batch or per epoch. You can also use many predefined lambda expressions that allow you to view the top k items by some criterion. 70 | 71 | ## Questions? 72 | 73 | File a [GitHub issue](https://github.com/microsoft/tensorwatch/issues/new) and let us know if we can improve this tutorial. 74 | -------------------------------------------------------------------------------- /docs/paper/.gitignore: -------------------------------------------------------------------------------- 1 | ## Core latex/pdflatex auxiliary files: 2 | *.aux 3 | *.lof 4 | *.log 5 | *.lot 6 | *.fls 7 | *.out 8 | *.toc 9 | *.fmt 10 | *.fot 11 | *.cb 12 | *.cb2 13 | .*.lb 14 | 15 | ## Intermediate documents: 16 | *.dvi 17 | *.xdv 18 | *-converted-to.* 19 | # these rules might exclude image files for figures etc.
20 | # *.ps 21 | # *.eps 22 | # *.pdf 23 | 24 | ## Bibliography auxiliary files (bibtex/biblatex/biber): 25 | *.bbl 26 | *.bcf 27 | *.blg 28 | *-blx.aux 29 | *-blx.bib 30 | *.run.xml 31 | 32 | ## Build tool auxiliary files: 33 | *.fdb_latexmk 34 | *.synctex 35 | *.synctex(busy) 36 | *.synctex.gz 37 | *.synctex.gz(busy) 38 | *.pdfsync 39 | 40 | ## Build tool directories for auxiliary files 41 | # latexrun 42 | latex.out/ 43 | 44 | ## Auxiliary and intermediate files from other packages: 45 | # algorithms 46 | *.alg 47 | *.loa 48 | 49 | # achemso 50 | acs-*.bib 51 | 52 | # amsthm 53 | *.thm 54 | 55 | # beamer 56 | *.nav 57 | *.pre 58 | *.snm 59 | *.vrb 60 | 61 | # changes 62 | *.soc 63 | 64 | # comment 65 | *.cut 66 | 67 | # cprotect 68 | *.cpt 69 | 70 | # elsarticle (documentclass of Elsevier journals) 71 | *.spl 72 | 73 | # endnotes 74 | *.ent 75 | 76 | # fixme 77 | *.lox 78 | 79 | # feynmf/feynmp 80 | *.mf 81 | *.mp 82 | *.t[1-9] 83 | *.t[1-9][0-9] 84 | *.tfm 85 | 86 | #(r)(e)ledmac/(r)(e)ledpar 87 | *.end 88 | *.?end 89 | *.[1-9] 90 | *.[1-9][0-9] 91 | *.[1-9][0-9][0-9] 92 | *.[1-9]R 93 | *.[1-9][0-9]R 94 | *.[1-9][0-9][0-9]R 95 | *.eledsec[1-9] 96 | *.eledsec[1-9]R 97 | *.eledsec[1-9][0-9] 98 | *.eledsec[1-9][0-9]R 99 | *.eledsec[1-9][0-9][0-9] 100 | *.eledsec[1-9][0-9][0-9]R 101 | 102 | # glossaries 103 | *.acn 104 | *.acr 105 | *.glg 106 | *.glo 107 | *.gls 108 | *.glsdefs 109 | *.lzo 110 | *.lzs 111 | 112 | # uncomment this for glossaries-extra (will ignore makeindex's style files!) 
113 | # *.ist 114 | 115 | # gnuplottex 116 | *-gnuplottex-* 117 | 118 | # gregoriotex 119 | *.gaux 120 | *.gtex 121 | 122 | # htlatex 123 | *.4ct 124 | *.4tc 125 | *.idv 126 | *.lg 127 | *.trc 128 | *.xref 129 | 130 | # hyperref 131 | *.brf 132 | 133 | # knitr 134 | *-concordance.tex 135 | # TODO Comment the next line if you want to keep your tikz graphics files 136 | *.tikz 137 | *-tikzDictionary 138 | 139 | # listings 140 | *.lol 141 | 142 | # luatexja-ruby 143 | *.ltjruby 144 | 145 | # makeidx 146 | *.idx 147 | *.ilg 148 | *.ind 149 | 150 | # minitoc 151 | *.maf 152 | *.mlf 153 | *.mlt 154 | *.mtc[0-9]* 155 | *.slf[0-9]* 156 | *.slt[0-9]* 157 | *.stc[0-9]* 158 | 159 | # minted 160 | _minted* 161 | *.pyg 162 | 163 | # morewrites 164 | *.mw 165 | 166 | # nomencl 167 | *.nlg 168 | *.nlo 169 | *.nls 170 | 171 | # pax 172 | *.pax 173 | 174 | # pdfpcnotes 175 | *.pdfpc 176 | 177 | # sagetex 178 | *.sagetex.sage 179 | *.sagetex.py 180 | *.sagetex.scmd 181 | 182 | # scrwfile 183 | *.wrt 184 | 185 | # sympy 186 | *.sout 187 | *.sympy 188 | sympy-plots-for-*.tex/ 189 | 190 | # pdfcomment 191 | *.upa 192 | *.upb 193 | 194 | # pythontex 195 | *.pytxcode 196 | pythontex-files-*/ 197 | 198 | # tcolorbox 199 | *.listing 200 | 201 | # thmtools 202 | *.loe 203 | 204 | # TikZ & PGF 205 | *.dpth 206 | *.md5 207 | *.auxlock 208 | 209 | # todonotes 210 | *.tdo 211 | 212 | # vhistory 213 | *.hst 214 | *.ver 215 | 216 | # easy-todo 217 | *.lod 218 | 219 | # xcolor 220 | *.xcp 221 | 222 | # xmpincl 223 | *.xmpi 224 | 225 | # xindy 226 | *.xdy 227 | 228 | # xypic precompiled matrices and outlines 229 | *.xyc 230 | *.xyd 231 | 232 | # endfloat 233 | *.ttt 234 | *.fff 235 | 236 | # Latexian 237 | TSWLatexianTemp* 238 | 239 | ## Editors: 240 | # WinEdt 241 | *.bak 242 | *.sav 243 | 244 | # Texpad 245 | .texpadtmp 246 | 247 | # LyX 248 | *.lyx~ 249 | 250 | # Kile 251 | *.backup 252 | 253 | # gummi 254 | .*.swp 255 | 256 | # KBibTeX 257 | *~[0-9]* 258 | 259 | # TeXnicCenter 260 | *.tps 261 | 
262 | # auto folder when using emacs and auctex 263 | ./auto/* 264 | *.el 265 | 266 | # expex forward references with \gathertags 267 | *-tags.tex 268 | 269 | # standalone packages 270 | *.sta 271 | 272 | # Makeindex log files 273 | *.lpz -------------------------------------------------------------------------------- /docs/paper/ACM-Reference-Format.cbx: -------------------------------------------------------------------------------- 1 | \ProvidesFile{ACM-Reference-Format.cbx}[2017-09-27 v0.1] 2 | 3 | \RequireCitationStyle{numeric} 4 | 5 | \endinput 6 | -------------------------------------------------------------------------------- /docs/paper/ACM-Reference-Format.dbx: -------------------------------------------------------------------------------- 1 | % Teach biblatex about numpages field 2 | \DeclareDatamodelFields[type=field, datatype=literal]{numpages} 3 | \DeclareDatamodelEntryfields{numpages} 4 | 5 | % Teach biblatex about articleno field 6 | \DeclareDatamodelFields[type=field, datatype=literal]{articleno} 7 | \DeclareDatamodelEntryfields{articleno} 8 | 9 | % Teach biblatex about urls field 10 | \DeclareDatamodelFields[type=list, datatype=uri]{urls} 11 | \DeclareDatamodelEntryfields{urls} 12 | 13 | % Teach biblatex about school field 14 | \DeclareDatamodelFields[type=list, datatype=literal]{school} 15 | \DeclareDatamodelEntryfields[thesis]{school} 16 | 17 | \DeclareDatamodelFields[type=field, datatype=literal]{key} 18 | \DeclareDatamodelEntryfields{key} -------------------------------------------------------------------------------- /docs/paper/TensorWatch_Collaboration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/paper/TensorWatch_Collaboration.png -------------------------------------------------------------------------------- /docs/paper/main.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/paper/main.pdf -------------------------------------------------------------------------------- /docs/paper/tensorwatch-screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/paper/tensorwatch-screenshot.png -------------------------------------------------------------------------------- /docs/paper/tensorwatch-screenshot2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/paper/tensorwatch-screenshot2.png -------------------------------------------------------------------------------- /docs/simple_logging.md: -------------------------------------------------------------------------------- 1 | 2 | # Simple Logging Tutorial 3 | 4 | In this tutorial, we will use a simple script that maintains two variables. Our goal is to log the values of these variables using TensorWatch and view them in real time in a variety of ways from a Jupyter Notebook, demonstrating various features offered by TensorWatch.
We will execute this script from a console, and its code looks like this: 5 | 6 | ``` 7 | # available in test/simple_log/sum_log.py 8 | 9 | import time, random 10 | import tensorwatch as tw 11 | 12 | # we will create two streams, one for 13 | # sums of integers, other for random numbers 14 | w = tw.Watcher() 15 | st_isum = w.create_stream('isums') 16 | st_rsum = w.create_stream('rsums') 17 | 18 | isum, rsum = 0, 0 19 | for i in range(10000): 20 | isum += i 21 | rsum += random.randint(0, i) 22 | 23 | # write to streams 24 | st_isum.write((i, isum)) 25 | st_rsum.write((i, rsum)) 26 | 27 | time.sleep(1) 28 | ``` 29 | 30 | ## Basic Concepts 31 | 32 | The `Watcher` object allows you to create streams. You can then write any values into these streams. By default, the `Watcher` object opens up TCP/IP sockets so a client can request these streams. As values are being written on one end, they get serialized and sent to any client(s) on the other end. We will use a Jupyter Notebook to create a client, open these streams, and feed them to various visualizers. 33 | 34 | ## Create the Client 35 | You can either create a new Jupyter Notebook or get the existing one in the repo at `notebooks/simple_logging.ipynb`. 36 | 37 | Let's first do the imports: 38 | 39 | 40 | ```python 41 | %matplotlib notebook 42 | import tensorwatch as tw 43 | ``` 44 | 45 | Next, open the streams that we created in the above script: 46 | 47 | ```python 48 | cli = tw.WatcherClient() 49 | st_isum = cli.open_stream('isums') 50 | st_rsum = cli.open_stream('rsums') 51 | ``` 52 | 53 | Now let's visualize the isum stream in textual format: 54 | 55 | 56 | ```python 57 | text_vis = tw.Visualizer(st_isum, vis_type='text') 58 | text_vis.show() 59 | ``` 60 | 61 | You should see a table growing in real time like this, where each row is a tuple we wrote to the stream: 62 | 63 | 64 | 65 | 66 | That worked out well! How about feeding the same stream simultaneously to plot a line graph?
67 | 68 | 69 | ```python 70 | line_plot = tw.Visualizer(st_isum, vis_type='line', xtitle='i', ytitle='isum') 71 | line_plot.show() 72 | ``` 73 | 74 | 75 | 76 | Let's take it a step further. Say we plot rsum as well but, just for fun, in the *same plot* as isum above. That's easy with the optional `host` parameter: 77 | 78 | ```python 79 | line_plot = tw.Visualizer(st_rsum, vis_type='line', host=line_plot, ytitle='rsum') 80 | ``` 81 | This instantly updates our line plot, and a new Y-axis is automatically added for our second stream. TensorWatch allows you to add multiple streams to the same visualizer. 82 | 83 | Are you curious about the statistics of these two streams? Let's display the live statistics of both streams side by side! To do this, we will use the `cell` parameter for the second visualization to place it right next to the first one: 84 | 85 | ```python 86 | istats = tw.Visualizer(st_isum, vis_type='summary') 87 | istats.show() 88 | 89 | rstats = tw.Visualizer(st_rsum, vis_type='summary', cell=istats) 90 | ``` 91 | 92 | The output looks like this (each panel shows statistics for the tuples that were logged in its stream): 93 | 94 | 95 | 96 | ## Next Steps 97 | 98 | We have just scratched the surface of what you can do with TensorWatch. You should check out the [Lazy Logging Tutorial](lazy_logging.md) and the other [notebooks](https://github.com/microsoft/tensorwatch/tree/master/notebooks) as well. 99 | 100 | ## Questions? 101 | 102 | File a [GitHub issue](https://github.com/microsoft/tensorwatch/issues/new) and let us know if we can improve this tutorial. 
-------------------------------------------------------------------------------- /install_jupyterlab.bat: -------------------------------------------------------------------------------- 1 | conda install -c conda-forge jupyterlab nodejs 2 | conda install ipywidgets 3 | conda install -c plotly plotly-orca psutil 4 | 5 | set NODE_OPTIONS=--max-old-space-size=4096 6 | jupyter labextension install @jupyter-widgets/jupyterlab-manager --no-build 7 | jupyter labextension install plotlywidget --no-build 8 | jupyter labextension install @jupyterlab/plotly-extension --no-build 9 | jupyter labextension install jupyterlab-chart-editor --no-build 10 | jupyter lab build 11 | set NODE_OPTIONS= -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import setuptools 5 | 6 | with open("README.md", "r") as fh: 7 | long_description = fh.read() 8 | 9 | setuptools.setup( 10 | name="tensorwatch", 11 | version="0.9.1", 12 | author="Shital Shah", 13 | author_email="shitals@microsoft.com", 14 | description="Interactive Realtime Debugging and Visualization for AI", 15 | long_description=long_description, 16 | long_description_content_type="text/markdown", 17 | url="https://github.com/microsoft/tensorwatch", 18 | packages=setuptools.find_packages(), 19 | license='MIT', 20 | classifiers=( 21 | "Programming Language :: Python :: 3", 22 | "License :: OSI Approved :: MIT License", 23 | "Operating System :: OS Independent", 24 | ), 25 | include_package_data=True, 26 | install_requires=[ 27 | 'matplotlib', 'numpy', 'pyzmq', 'plotly', 'ipywidgets', 28 | 'pydot>=1.4.2', 29 | 'nbformat', 'scikit-image', 'nbformat', 'pyyaml', 'scikit-image', 'graphviz' # , 'receptivefield' 30 | ] 31 | ) 32 | -------------------------------------------------------------------------------- /tensorwatch.sln: 
-------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 15 4 | VisualStudioVersion = 15.0.27428.2043 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "tensorwatch", "tensorwatch.pyproj", "{CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|Any CPU = Debug|Any CPU 11 | Release|Any CPU = Release|Any CPU 12 | EndGlobalSection 13 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 14 | {CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 15 | {CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}.Release|Any CPU.ActiveCfg = Release|Any CPU 16 | EndGlobalSection 17 | GlobalSection(SolutionProperties) = preSolution 18 | HideSolutionNode = FALSE 19 | EndGlobalSection 20 | GlobalSection(ExtensibilityGlobals) = postSolution 21 | SolutionGuid = {99E7AEC7-2CDE-48C8-B98B-4E28E4F840B6} 22 | EndGlobalSection 23 | EndGlobal 24 | -------------------------------------------------------------------------------- /tensorwatch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | from typing import Iterable, Sequence, Union 5 | 6 | from .watcher_client import WatcherClient 7 | from .watcher import Watcher 8 | from .watcher_base import WatcherBase 9 | 10 | from .text_vis import TextVis 11 | from .plotly.embeddings_plot import EmbeddingsPlot 12 | from .mpl.line_plot import LinePlot 13 | from .mpl.image_plot import ImagePlot 14 | from .mpl.histogram import Histogram 15 | from .mpl.bar_plot import BarPlot 16 | from .mpl.pie_chart import PieChart 17 | from .visualizer import Visualizer 18 | 19 | from .stream import Stream 20 | from .array_stream import ArrayStream 21 | from .lv_types import PointData, ImageData, VisArgs, StreamItem, PredictionResult 22 | from . import utils 23 | 24 | ###### Import methods for tw namespace ######### 25 | #from .receptive_field.rf_utils import plot_receptive_field, plot_grads_at 26 | from .embeddings.tsne_utils import get_tsne_components 27 | from .model_graph.torchstat_utils import model_stats, ModelStats 28 | from .image_utils import show_image, open_image, img2pyt, linear_to_2d, plt_loop 29 | from .data_utils import pyt_ds2list, sample_by_class, col2array, search_similar 30 | 31 | 32 | 33 | def draw_model(model, input_shape=None, orientation='TB', png_filename=None): # orientation='LR' for landscape. NOTE(review): orientation and png_filename are accepted but not forwarded to draw_graph below 34 | from .model_graph.hiddenlayer import pytorch_draw_model # imported lazily, only when a model is actually drawn 35 | g = pytorch_draw_model.draw_graph(model, input_shape) 36 | return g 37 | 38 | 39 | -------------------------------------------------------------------------------- /tensorwatch/array_stream.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | from .stream import Stream 5 | from .lv_types import StreamItem 6 | import uuid 7 | 8 | class ArrayStream(Stream): 9 | def __init__(self, array, stream_name:str=None, console_debug:bool=False): 10 | super(ArrayStream, self).__init__(stream_name=stream_name, console_debug=console_debug) 11 | 12 | self.stream_name = stream_name 13 | self.array = array 14 | 15 | def load(self, from_stream:'Stream'=None): 16 | if self.array is not None: 17 | self.write(self.array) 18 | super(ArrayStream, self).load() 19 | -------------------------------------------------------------------------------- /tensorwatch/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import random 5 | import scipy.spatial.distance 6 | import heapq 7 | import numpy as np 8 | import torch 9 | 10 | def pyt_tensor2np(pyt_tensor): 11 | if pyt_tensor is None: 12 | return None 13 | if isinstance(pyt_tensor, torch.Tensor): 14 | n = pyt_tensor.data.cpu().numpy() 15 | if len(n.shape) == 1: 16 | return n[0] 17 | else: 18 | return n 19 | elif isinstance(pyt_tensor, np.ndarray): 20 | return pyt_tensor 21 | else: 22 | return np.array(pyt_tensor) 23 | 24 | def pyt_tuple2np(pyt_tuple): 25 | return tuple((pyt_tensor2np(t) for t in pyt_tuple)) 26 | 27 | def pyt_ds2list(pyt_ds, count=None): 28 | count = count or len(pyt_ds) 29 | return [pyt_tuple2np(t) for t, c in zip(pyt_ds, range(count))] 30 | 31 | def sample_by_class(data, n_samples, class_col=1, shuffle=True): 32 | if shuffle: 33 | random.shuffle(data) 34 | samples = {} 35 | for i, t in enumerate(data): 36 | cls = t[class_col] 37 | if cls not in samples: 38 | samples[cls] = [] 39 | if len(samples[cls]) < n_samples: 40 | samples[cls].append(data[i]) 41 | samples = sum(samples.values(), []) 42 | return samples 43 | 44 | def col2array(dataset, col): 45 | return [row[col] for row in dataset] 46 | 47 | def search_similar(inputs, 
compare_to, algorithm='euclidean', topk=5, invert_score=True): 48 | all_scores = scipy.spatial.distance.cdist(inputs, compare_to, algorithm) 49 | all_results = [] 50 | for input_val, scores in zip(inputs, all_scores): 51 | result = [] 52 | for i, (score, data) in enumerate(zip(scores, compare_to)): 53 | if invert_score: 54 | score = 1/(score + 1.0E-6) 55 | if len(result) < topk: 56 | heapq.heappush(result, (score, (i, input_val, data))) 57 | else: 58 | heapq.heappushpop(result, (score, (i, input_val, data))) 59 | all_results.append(result) 60 | return all_results -------------------------------------------------------------------------------- /tensorwatch/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | 5 | -------------------------------------------------------------------------------- /tensorwatch/embeddings/tsne_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | from sklearn.manifold import TSNE 5 | from sklearn.preprocessing import StandardScaler 6 | import numpy as np 7 | 8 | def _standardize_data(data, col, whitten, flatten): # take data[col] (when col is not None) and, if whitten, rescale with StandardScaler; NOTE(review): flatten is currently unused 9 | if col is not None: 10 | data = data[col] 11 | 12 | #TODO: enable auto flattening 13 | #if data is tensor then flatten it first 14 | #if flatten and len(data) > 0 and hasattr(data[0], 'shape') and \ 15 | # utils.has_method(data[0], 'reshape'): 16 | 17 | # data = [d.reshape((-1,)) for d in data] 18 | 19 | if whitten: 20 | data = StandardScaler().fit_transform(data) 21 | return data 22 | 23 | def get_tsne_components(data, features_col=0, labels_col=1, whitten=True, n_components=3, perplexity=20, flatten=True, for_plot=True): # t-SNE of data[features_col]; with for_plot, returns lists of components extended with (None, None, None, str(int(label)), label) 24 | features = _standardize_data(data, features_col, whitten, flatten) 25 | tsne = TSNE(n_components=n_components, perplexity=perplexity) 26 | tsne_results = tsne.fit_transform(features) 27 | 28 | if for_plot: 29 | comps = tsne_results.tolist() 30 | labels = data[labels_col] # presumably one label per feature row — confirm with callers 31 | for i, item in enumerate(comps): 32 | # low, high, annotation, text, color 33 | label = labels[i] 34 | if isinstance(labels, np.ndarray): 35 | label = label.item() 36 | item.extend((None, None, None, str(int(label)), label)) 37 | return comps 38 | return tsne_results 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /tensorwatch/evaler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | import threading, sys, logging 5 | from collections.abc import Iterator 6 | from .lv_types import EventData 7 | 8 | # pylint: disable=unused-wildcard-import 9 | # pylint: disable=wildcard-import 10 | # pylint: disable=unused-import 11 | from functools import * 12 | from itertools import * 13 | from statistics import * 14 | import numpy as np 15 | from .evaler_utils import * # NOTE(review): `utils` (used in _runner/abort/post) is not imported explicitly here; presumably re-exported by evaler_utils — confirm 16 | 17 | class Evaler: # evaluates the expression string self.expr on a daemon thread; the expression reads posted events through the local name `l` 18 | class EvalReturn: # result / exception / is_valid of one evaluation 19 | def __init__(self, result=None, is_valid=False, exception=None): 20 | self.result, self.exception, self.is_valid = \ 21 | result, exception, is_valid 22 | def reset(self): 23 | self.result, self.exception, self.is_valid = \ 24 | None, None, False 25 | 26 | class PostableIterator: # generator fed by post(): each posted event_data is yielded to the running expression 27 | def __init__(self, eval_wait): 28 | self.eval_wait = eval_wait 29 | self.post_wait = threading.Event() 30 | self.event_data, self.ended = None, None # define attributes in init 31 | self.reset() 32 | 33 | def reset(self): 34 | self.event_data, self.ended = None, False 35 | self.post_wait.clear() 36 | 37 | def abort(self): 38 | self.ended = True 39 | self.post_wait.set() 40 | 41 | def post(self, event_data:EventData=None, ended=False): 42 | self.event_data, self.ended = event_data, ended 43 | self.post_wait.set() 44 | 45 | def get_vals(self): # blocks until post()/abort(); returns once `ended` is set 46 | while True: 47 | self.post_wait.wait() 48 | self.post_wait.clear() 49 | if self.ended: 50 | break 51 | else: 52 | yield self.event_data 53 | # below will cause result=None, is_valid=False when 54 | # expression has reduce 55 | self.eval_wait.set() 56 | 57 | def __init__(self, expr): # starts the evaluator thread immediately 58 | self.eval_wait = threading.Event() 59 | self.reset_wait = threading.Event() 60 | self.g = Evaler.PostableIterator(self.eval_wait) 61 | self.expr = expr 62 | self.eval_return, self.continue_thread = None, None # define in __init__ 63 | self.reset() 64 | 65 | self.th = threading.Thread(target=self._runner, daemon=True, name='evaler') 66 | self.th.start() 67 | self.running = True 68 | 69 | def reset(self): 70 | self.g.reset() 71 | 
self.eval_wait.clear() 72 | self.reset_wait.clear() 73 | self.eval_return = Evaler.EvalReturn() 74 | self.continue_thread = True 75 | 76 | def _runner(self): # thread body: evaluate self.expr once per reset cycle until continue_thread is False 77 | while True: 78 | # this var will be used by eval; NOTE(review): self.expr goes straight to eval(), so it must come from a trusted source 79 | l = self.g.get_vals() # pylint: disable=unused-variable 80 | try: 81 | result = eval(self.expr) # pylint: disable=eval-used 82 | if isinstance(result, Iterator): 83 | for item in result: 84 | self.eval_return = Evaler.EvalReturn(item, True) 85 | else: 86 | self.eval_return = Evaler.EvalReturn(result, True) 87 | except Exception as ex: # pylint: disable=broad-except 88 | logging.exception('Exception occured while evaluating expression: ' + self.expr) 89 | self.eval_return = Evaler.EvalReturn(None, True, ex) 90 | self.eval_wait.set() 91 | self.reset_wait.wait() 92 | if not self.continue_thread: 93 | break 94 | self.reset() 95 | self.running = False 96 | utils.debug_log('eval runner ended!') 97 | 98 | def abort(self): # stop the runner: end the generator and release every wait 99 | utils.debug_log('Evaler Aborted') 100 | self.continue_thread = False 101 | self.g.abort() 102 | self.eval_wait.set() 103 | self.reset_wait.set() 104 | 105 | def post(self, event_data:EventData=None, ended=False, continue_thread=True): # feed one event and block until eval_wait is signalled 106 | if not self.running: 107 | utils.debug_log('post was called when Evaler is not running') 108 | return None, False # NOTE(review): a tuple here but an EvalReturn on the normal path — callers must handle both 109 | self.eval_return.reset() 110 | self.g.post(event_data, ended) 111 | self.eval_wait.wait() 112 | self.eval_wait.clear() 113 | # save result before it would get reset 114 | eval_return = self.eval_return 115 | self.reset_wait.set() 116 | self.continue_thread = continue_thread 117 | if isinstance(eval_return.result, Iterator): 118 | eval_return.result = list(eval_return.result) 119 | return eval_return 120 | 121 | def join(self): 122 | self.th.join() 123 | -------------------------------------------------------------------------------- /tensorwatch/file_stream.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # Licensed under the MIT license. 3 | 4 | from .stream import Stream 5 | import pickle, os 6 | from typing import Any 7 | from . import utils 8 | import time 9 | 10 | class FileStream(Stream): 11 | def __init__(self, for_write:bool, file_name:str, stream_name:str=None, console_debug:bool=False): 12 | super(FileStream, self).__init__(stream_name=stream_name or file_name, console_debug=console_debug) 13 | 14 | self._file = open(file_name, 'wb' if for_write else 'rb') 15 | self.file_name = file_name 16 | self.for_write = for_write 17 | utils.debug_log('FileStream started', os.path.realpath(self._file.name), verbosity=0) 18 | 19 | def close(self): 20 | if not self._file.closed: 21 | self._file.close() 22 | self._file = None 23 | utils.debug_log('FileStream is closed', os.path.realpath(self._file.name), verbosity=0) 24 | super(FileStream, self).close() 25 | 26 | def write(self, val:Any, from_stream:'Stream'=None): 27 | stream_item = self.to_stream_item(val) 28 | 29 | if self.for_write: 30 | pickle.dump(stream_item, self._file) 31 | self._file.flush() 32 | #os.fsync() 33 | super(FileStream, self).write(stream_item) 34 | 35 | def read_all(self, from_stream:'Stream'=None): 36 | if self.for_write: 37 | raise IOError('Cannot use read() call because FileSteam is opened with for_write=True') 38 | if self._file is not None: 39 | self._file.seek(0, 0) # we may filter this stream multiple times 40 | while not utils.is_eof(self._file): 41 | yield pickle.load(self._file) 42 | for item in super(FileStream, self).read_all(): 43 | yield item 44 | 45 | def load(self, from_stream:'Stream'=None): 46 | if self.for_write: 47 | raise IOError('Cannot use load() call because FileSteam is opened with for_write=True') 48 | if self._file is not None: 49 | self._file.seek(0, 0) # we may filter this stream multiple times 50 | while not utils.is_eof(self._file): 51 | stream_item = pickle.load(self._file) 52 | self.write(stream_item) 53 | super(FileStream, self).load() 54 | 55 | def save(self, 
from_stream:'Stream'=None): 56 | if not self._file.closed: 57 | self._file.flush() 58 | super(FileStream, self).save(val) 59 | 60 | -------------------------------------------------------------------------------- /tensorwatch/filtered_stream.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .stream import Stream 5 | from typing import Callable, Any 6 | 7 | class FilteredStream(Stream): 8 | def __init__(self, source_stream:Stream, filter_expr:Callable, stream_name:str=None, 9 | console_debug:bool=False)->None: 10 | 11 | stream_name = stream_name or '{}|{}'.format(source_stream.stream_name, str(filter_expr)) 12 | super(FilteredStream, self).__init__(stream_name=stream_name, console_debug=console_debug) 13 | self.subscribe(source_stream) 14 | self.filter_expr = filter_expr 15 | 16 | def _filter(self, stream_item): 17 | return self.filter_expr(stream_item) \ 18 | if self.filter_expr is not None \ 19 | else (stream_item, True) 20 | 21 | def write(self, val:Any, from_stream:'Stream'=None): 22 | stream_item = self.to_stream_item(val) 23 | 24 | result, is_valid = self._filter(stream_item) 25 | if is_valid: 26 | return super(FilteredStream, self).write(result) 27 | # else ignore this call 28 | 29 | def read_all(self, from_stream:'Stream'=None): #override->replacement 30 | for subscribed_to in self._subscribed_to: 31 | for stream_item in subscribed_to.read_all(from_stream=self): 32 | result, is_valid = self._filter(stream_item) 33 | if is_valid: 34 | yield stream_item 35 | 36 | -------------------------------------------------------------------------------- /tensorwatch/image_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | import numpy as np 5 | import math 6 | import time 7 | 8 | def guess_image_dims(img): 9 | if len(img.shape) == 1: 10 | # assume 2D monochrome (MNIST) 11 | width = height = round(math.sqrt(img.shape[0])) 12 | if width*height != img.shape[0]: 13 | # assume 3 channels (CFAR, ImageNet) 14 | width = height = round(math.sqrt(img.shape[0] / 3)) 15 | if width*height*3 != img.shape[0]: 16 | raise ValueError("Cannot guess image dimensions for linearized pixels") 17 | return (3, height, width) 18 | return (1, height, width) 19 | return img.shape 20 | 21 | def to_imshow_array(img, width=None, height=None): 22 | # array from Pytorch has shape: [[channels,] height, width] 23 | # image needed for imshow needs: [height, width, channels] 24 | from PIL import Image 25 | 26 | if img is not None: 27 | if isinstance(img, Image.Image): 28 | img = np.array(img) 29 | if len(img.shape) >= 2: 30 | return img # img is already compatible to imshow 31 | 32 | # force max 3 dimensions 33 | if len(img.shape) > 3: 34 | # TODO allow config 35 | # select first one in batch 36 | img = img[0:1,:,:] 37 | 38 | if len(img.shape) == 1: # linearized pixels typically used for MLPs 39 | if not(width and height): 40 | # pylint: disable=unused-variable 41 | channels, height, width = guess_image_dims(img) 42 | img = img.reshape((-1, height, width)) 43 | 44 | if len(img.shape) == 3: 45 | if img.shape[0] == 1: # single channel images 46 | img = img.squeeze(0) 47 | else: 48 | img = np.swapaxes(img, 0, 2) # transpose H,W for imshow 49 | img = np.swapaxes(img, 0, 1) 50 | elif len(img.shape) == 2: 51 | img = np.swapaxes(img, 0, 1) # transpose H,W for imshow 52 | else: #zero dimensions 53 | img = None 54 | 55 | return img 56 | 57 | #width_dim=1 for imshow, 2 for pytorch arrays 58 | def stitch_horizontal(images, width_dim=1): 59 | return np.concatenate(images, axis=width_dim) 60 | 61 | def _resize_image(img, size=None): 62 | if size is not None or (hasattr(img, 'shape') and len(img.shape) == 1): 63 | if size is 
None: 64 | # make guess for 1-dim tensors 65 | h = int(math.sqrt(img.shape[0])) 66 | w = int(img.shape[0] / h) 67 | size = h,w 68 | img = np.reshape(img, size) 69 | return img 70 | 71 | def show_image(img, size=None, alpha=None, cmap=None, 72 | img2=None, size2=None, alpha2=None, cmap2=None, ax=None): 73 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue 74 | 75 | img =_resize_image(img, size) 76 | img2 =_resize_image(img2, size2) 77 | 78 | (ax or plt).imshow(img, alpha=alpha, cmap=cmap) 79 | 80 | if img2 is not None: 81 | (ax or plt).imshow(img2, alpha=alpha2, cmap=cmap2) 82 | 83 | return ax or plt.show() 84 | 85 | # convert_mode param is mode: https://pillow.readthedocs.io/en/5.1.x/handbook/concepts.html#modes 86 | # use convert_mode='RGB' to force 3 channels 87 | def open_image(path, resize=None, resample=1, convert_mode=None): # Image.ANTIALIAS==1 88 | from PIL import Image 89 | 90 | img = Image.open(path) 91 | if resize is not None: 92 | img = img.resize(resize, resample) 93 | if convert_mode is not None: 94 | img = img.convert(convert_mode) 95 | return img 96 | 97 | def img2pyt(img, add_batch_dim=True, resize=None): 98 | from torchvision import transforms # expensive function-level import 99 | 100 | ts = [] 101 | if resize is not None: 102 | ts.append(transforms.RandomResizedCrop(resize)) 103 | ts.append(transforms.ToTensor()) 104 | img_pyt = transforms.Compose(ts)(img) 105 | if add_batch_dim: 106 | img_pyt.unsqueeze_(0) 107 | return img_pyt 108 | 109 | def linear_to_2d(img, size=None): 110 | if size is not None or (hasattr(img, 'shape') and len(img.shape) == 1): 111 | if size is None: 112 | # make guess for 1-dim tensors 113 | h = int(math.sqrt(img.shape[0])) 114 | w = int(img.shape[0] / h) 115 | size = h,w 116 | img = np.reshape(img, size) 117 | return img 118 | 119 | def stack_images(imgs): 120 | return np.hstack(imgs) 121 | 122 | def plt_loop(count=None, sleep_time=1, plt_pause=0.01): 123 | import matplotlib.pyplot as plt 
# delayed import due to matplotlib threading issue 124 | 125 | #plt.ion() 126 | #plt.show(block=False) 127 | while((count is None or count > 0) and not plt.waitforbuttonpress(plt_pause)): 128 | #plt.draw() 129 | plt.pause(plt_pause) 130 | time.sleep(sleep_time) 131 | if count is not None: 132 | count = count - 1 133 | 134 | def get_cmap(name:str): 135 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue 136 | return plt.cm.get_cmap(name=name) 137 | 138 | -------------------------------------------------------------------------------- /tensorwatch/imagenet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from torchvision import transforms 5 | from . import pytorch_utils 6 | import json, os 7 | 8 | def get_image_transform(): 9 | transf = transforms.Compose([ #TODO: cache these transforms? 10 | get_resize_transform(), 11 | transforms.ToTensor(), 12 | get_normalize_transform() 13 | ]) 14 | 15 | return transf 16 | 17 | def get_resize_transform(): 18 | return transforms.Resize((224, 224)) 19 | 20 | def get_normalize_transform(): 21 | return transforms.Normalize(mean=[0.485, 0.456, 0.406], 22 | std=[0.229, 0.224, 0.225]) 23 | 24 | def image2batch(image): 25 | return pytorch_utils.image2batch(image, image_transform=get_image_transform()) 26 | 27 | def predict(model, images, image_transform=None, device=None): 28 | logits = pytorch_utils.batch_predict(model, images, 29 | input_transform=image_transform or get_image_transform(), device=device) 30 | probs = pytorch_utils.logits2probabilities(logits) #2-dim array, one column per class, one row per input 31 | return probs 32 | 33 | _imagenet_labels = None 34 | def get_imagenet_labels(): 35 | # pylint: disable=global-statement 36 | global _imagenet_labels 37 | _imagenet_labels = _imagenet_labels or ImagenetLabels() 38 | return _imagenet_labels 39 | 40 | def 
probabilities2classes(probs, topk=5): 41 | labels = get_imagenet_labels() 42 | top_probs = probs.topk(topk) 43 | # return (probability, class_id, class_label, class_code) 44 | return tuple((p,c, labels.index2label_text(c), labels.index2label_code(c)) \ 45 | for p, c in zip(top_probs[0][0].data.cpu().numpy(), top_probs[1][0].data.cpu().numpy())) 46 | 47 | class ImagenetLabels: 48 | def __init__(self, json_path=None): 49 | self._idx2label = [] 50 | self._idx2cls = [] 51 | self._cls2label = {} 52 | self._cls2idx = {} 53 | 54 | json_path = json_path or os.path.join(os.path.dirname(__file__), 'imagenet_class_index.json') 55 | 56 | with open(os.path.abspath(json_path), "r") as read_file: 57 | class_json = json.load(read_file) 58 | self._idx2label = [class_json[str(k)][1] for k in range(len(class_json))] 59 | self._idx2cls = [class_json[str(k)][0] for k in range(len(class_json))] 60 | self._cls2label = {class_json[str(k)][0]: class_json[str(k)][1] for k in range(len(class_json))} 61 | self._cls2idx = {class_json[str(k)][0]: k for k in range(len(class_json))} 62 | 63 | def index2label_text(self, index): 64 | return self._idx2label[index] 65 | def index2label_code(self, index): 66 | return self._idx2cls[index] 67 | def label_code2label_text(self, label_code): 68 | return self._cls2label[label_code] 69 | def label_code2index(self, label_code): 70 | return self._cls2idx[label_code] -------------------------------------------------------------------------------- /tensorwatch/model_graph/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | -------------------------------------------------------------------------------- /tensorwatch/model_graph/hiddenlayer/README.md: -------------------------------------------------------------------------------- 1 | # Credits 2 | 3 | Code in this folder is adopted from, 4 | 5 | * https://github.com/waleedka/hiddenlayer 6 | -------------------------------------------------------------------------------- /tensorwatch/model_graph/hiddenlayer/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tensorwatch/model_graph/hiddenlayer/distiller.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2018 Intel Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import torch 18 | from .distiller_utils import * 19 | 20 | import logging 21 | logging.captureWarnings(True) # NOTE(review): import-time side effect — redirects warnings.warn output to logging for the whole process 22 | 23 | def model_find_param_name(model, param_to_find): 24 | """Look up the name of a model parameter. 25 | 26 | Arguments: 27 | model: the model to search 28 | param_to_find: the parameter whose name we want to look up 29 | 30 | Returns: 31 | The parameter name (string) or None, if the parameter was not found. 
32 | """ 33 | for name, param in model.named_parameters(): 34 | if param is param_to_find: 35 | return name 36 | return None 37 | 38 | 39 | def model_find_module_name(model, module_to_find): 40 | """Look up the name of a module in a model. 41 | 42 | Arguments: 43 | model: the model to search 44 | module_to_find: the module whose name we want to look up 45 | 46 | Returns: 47 | The module name (string) or None, if the module was not found. 48 | """ 49 | for name, m in model.named_modules(): 50 | if m == module_to_find: # NOTE(review): equality rather than the identity check used in model_find_param_name; confirm modules don't override __eq__ 51 | return name 52 | return None 53 | 54 | 55 | def model_find_param(model, param_to_find_name): 56 | """Look up a model parameter by its name. 57 | 58 | Arguments: 59 | model: the model to search 60 | param_to_find_name: the name of the parameter that we are searching for 61 | 62 | Returns: 63 | The parameter or None, if the parameter name was not found. 64 | """ 65 | for name, param in model.named_parameters(): 66 | if name == param_to_find_name: 67 | return param 68 | return None 69 | 70 | 71 | def model_find_module(model, module_to_find): 72 | """Given a module name, find the module in the provided model. 73 | 74 | Arguments: 75 | model: the model to search 76 | module_to_find: the name of the module to look up 77 | 78 | Returns: 79 | The module or None, if the module was not found. 80 | """ 81 | for name, m in model.named_modules(): 82 | if name == module_to_find: 83 | return m 84 | return None 85 | 86 | -------------------------------------------------------------------------------- /tensorwatch/model_graph/hiddenlayer/ge.py: -------------------------------------------------------------------------------- 1 | """ 2 | HiddenLayer 3 | 4 | Implementation of graph expressions to find nodes in a graph based on a pattern. 
5 | 6 | Written by Waleed Abdulla 7 | Licensed under the MIT License 8 | """ 9 | 10 | import re 11 | 12 | 13 | 14 | class GEParser(): # recursive-descent parser: ops joined by ">" (serial) or "|" (parallel), grouped with parentheses 15 | def __init__(self, text): 16 | self.index = 0 17 | self.text = text 18 | 19 | def parse(self): # SerialPattern, ParallelPattern or NodePattern; None if nothing matches 20 | return self.serial() or self.parallel() or self.expression() 21 | 22 | def parallel(self): # two or more "|"-separated expressions -> ParallelPattern; otherwise restore self.index and return None 23 | index = self.index 24 | expressions = [] 25 | while len(expressions) == 0 or self.token("|"): 26 | e = self.expression() 27 | if not e: 28 | break 29 | expressions.append(e) 30 | if len(expressions) >= 2: 31 | return ParallelPattern(expressions) 32 | # No match. Reset index 33 | self.index = index 34 | 35 | def serial(self): # two or more ">"-separated expressions -> SerialPattern; otherwise restore self.index and return None 36 | index = self.index 37 | expressions = [] 38 | while len(expressions) == 0 or self.token(">"): 39 | e = self.expression() 40 | if not e: 41 | break 42 | expressions.append(e) 43 | 44 | if len(expressions) >= 2: 45 | return SerialPattern(expressions) 46 | self.index = index 47 | 48 | def expression(self): # parenthesized sub-expression, else a single op 49 | index = self.index 50 | 51 | if self.token("("): 52 | e = self.serial() or self.parallel() or self.op() 53 | if e and self.token(")"): 54 | return e 55 | self.index = index 56 | e = self.op() 57 | return e 58 | 59 | def op(self): # a word-character op name plus optional condition -> NodePattern 60 | t = self.re(r"\w+") 61 | if t: 62 | c = self.condition() 63 | return NodePattern(t, c) 64 | 65 | def condition(self): 66 | # TODO: not implemented yet. 
This function is a placeholder 67 | index = self.index 68 | if self.token("["): 69 | c = self.token("1x1") or self.token("3x3") 70 | if c: 71 | if self.token("]"): 72 | return c 73 | self.index = index 74 | 75 | def token(self, s): # match literal s, ignoring surrounding whitespace 76 | return self.re(r"\s*(" + re.escape(s) + r")\s*", 1) 77 | 78 | def string(self, s): # NOTE(review): not called anywhere in this module 79 | if s == self.text[self.index:self.index+len(s)]: 80 | self.index += len(s) 81 | return s 82 | 83 | def re(self, regex, group=0): # match regex at self.index; on success advance past the match and return the requested group 84 | m = re.match(regex, self.text[self.index:]) # the module-level `re`, not this method 85 | if m: 86 | self.index += len(m.group(0)) 87 | return m.group(group) 88 | 89 | 90 | class NodePattern(): # matches one node whose .op equals self.op 91 | def __init__(self, op, condition=None): 92 | self.op = op 93 | self.condition = condition # TODO: not implemented yet 94 | 95 | def match(self, graph, node): # ([node], following) on a match — following is the sole successor or the list of successors; ([], None) otherwise 96 | if isinstance(node, list): 97 | return [], None 98 | if self.op == node.op: 99 | following = graph.outgoing(node) 100 | if len(following) == 1: 101 | following = following[0] 102 | return [node], following 103 | else: 104 | return [], None 105 | 106 | 107 | class SerialPattern(): # matches its patterns one after another, each starting where the previous one ended 108 | def __init__(self, patterns): 109 | self.patterns = patterns 110 | 111 | def match(self, graph, node): 112 | all_matches = [] 113 | for i, p in enumerate(self.patterns): 114 | matches, following = p.match(graph, node) 115 | if not matches: 116 | return [], None 117 | all_matches.extend(matches) 118 | if i < len(self.patterns) - 1: 119 | node = following # Might be more than one node 120 | return all_matches, following 121 | 122 | 123 | class ParallelPattern(): # matches sibling branches with the same parents that all end at the same node 124 | def __init__(self, patterns): 125 | self.patterns = patterns 126 | 127 | def match(self, graph, nodes): # (all matched nodes, common end node) or ([], None) 128 | if not nodes: 129 | return [], None 130 | nodes = nodes if isinstance(nodes, list) else [nodes] 131 | # If a single node, assume we need to match with its siblings 132 | if len(nodes) == 1: 133 | nodes = graph.siblings(nodes[0]) 134 | else: 135 | # Verify all nodes have the same parent or all have no parent 136 | parents = [graph.incoming(n) for n 
in nodes] 137 | matches = [set(p) == set(parents[0]) for p in parents[1:]] 138 | if not all(matches): 139 | return [], None 140 | 141 | # TODO: If more nodes than patterns, we should consider 142 | # all permutations of the nodes 143 | if len(self.patterns) != len(nodes): 144 | return [], None 145 | 146 | patterns = self.patterns.copy() 147 | nodes = nodes.copy() 148 | all_matches = [] 149 | end_node = None 150 | for p in patterns: 151 | found = False 152 | for n in nodes: 153 | matches, following = p.match(graph, n) 154 | if matches: 155 | found = True 156 | nodes.remove(n) 157 | all_matches.extend(matches) 158 | # Verify all branches end in the same node 159 | if end_node: 160 | if end_node != following: 161 | return [], None 162 | else: 163 | end_node = following 164 | break 165 | if not found: 166 | return [], None 167 | return all_matches, end_node 168 | 169 | 170 | -------------------------------------------------------------------------------- /tensorwatch/model_graph/hiddenlayer/pytorch_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | HiddenLayer 3 | 4 | PyTorch graph importer. 5 | 6 | Written by Waleed Abdulla 7 | Licensed under the MIT License 8 | """ 9 | 10 | from __future__ import absolute_import, division, print_function 11 | import re 12 | from .graph import Graph, Node 13 | from . import transforms as ht 14 | import torch 15 | from collections import abc 16 | import numpy as np 17 | 18 | # PyTorch Graph Transforms 19 | FRAMEWORK_TRANSFORMS = [ 20 | # Hide onnx: prefix 21 | ht.Rename(op=r"onnx::(.*)", to=r"\1"), 22 | # ONNX uses Gemm for linear layers (stands for General Matrix Multiplication). 23 | # It's an odd name that no one recognizes. Rename it. 
24 | ht.Rename(op=r"Gemm", to=r"Linear"), 25 | # PyTorch layers that don't have an ONNX counterpart 26 | ht.Rename(op=r"aten::max\_pool2d\_with\_indices", to="MaxPool"), 27 | # Shorten op name 28 | ht.Rename(op=r"BatchNormalization", to="BatchNorm"), 29 | ] 30 | 31 | 32 | def dump_pytorch_graph(graph): # prints one row per node: kind, scope name, input ids, output ids 33 | """List all the nodes in a PyTorch graph.""" 34 | f = "{:25} {:40} {} -> {}" 35 | print(f.format("kind", "scopeName", "inputs", "outputs")) 36 | for node in graph.nodes(): 37 | print(f.format(node.kind(), node.scopeName(), 38 | [i.unique() for i in node.inputs()], 39 | [i.unique() for i in node.outputs()] 40 | )) 41 | 42 | 43 | def pytorch_id(node): # "<scopeName>/outputs/<debugName>/<debugName>..." 44 | """Returns a unique ID for a node.""" 45 | # After ONNX simplification, the scopeName is not unique anymore 46 | # so append node outputs to guarantee uniqueness 47 | return node.scopeName() + "/outputs/" + "/".join([o.debugName() for o in node.outputs()]) 48 | 49 | 50 | 51 | def get_shape(torch_node): 52 | """Return the output shape of the given Pytorch node.""" 53 | # Extract node output shape from the node string representation 54 | # This is a hack because there doesn't seem to be an official way to do it. 55 | # See my quesiton in the PyTorch forum: 56 | # https://discuss.pytorch.org/t/node-output-shape-from-trace-graph/24351/2 57 | # TODO: find a better way to extract output shape 58 | # TODO: Assuming the node has one output. Update if we encounter a multi-output node.
59 | m = re.match(r".*Float\(([\d\s\,]+)\).*", str(next(torch_node.outputs()))) 60 | if m: 61 | shape = m.group(1) 62 | shape = shape.split(",") 63 | shape = tuple(map(int, shape)) 64 | else: 65 | shape = None 66 | return shape 67 | 68 | def calc_rf(model, input_shape): 69 | for n, p in model.named_parameters(): 70 | if not p.requires_grad: 71 | continue; 72 | if 'bias' in n: 73 | p.data.fill_(0) 74 | elif 'weight' in n: 75 | p.data.fill_(1) 76 | 77 | input = torch.ones(input_shape, requires_grad=True) 78 | output = model(input) 79 | out_shape = output.size() 80 | ndims = len(out_shape) 81 | grad = torch.zeros(out_shape) 82 | l_tmp=[] 83 | for i in xrange(ndims): 84 | if i==0 or i==1:#batch or channel 85 | l_tmp.append(0) 86 | else: 87 | l_tmp.append(out_shape[i]/2) 88 | 89 | grad[tuple(l_tmp)] = 1 90 | output.backward(gradient=grad) 91 | grad_np = img_.grad[0,0].data.numpy() 92 | idx_nonzeros = np.where(grad_np!=0) 93 | RF=[np.max(idx)-np.min(idx)+1 for idx in idx_nonzeros] 94 | 95 | return RF 96 | 97 | def import_graph(hl_graph, model, args, input_names=None, verbose=False): 98 | # TODO: add input names to graph 99 | 100 | if args is None: 101 | args = [1, 3, 224, 224] # assume ImageNet default 102 | 103 | # if args is not Tensor but is array like then convert it to torch tensor 104 | if not isinstance(args, torch.Tensor) and \ 105 | hasattr(args, "__len__") and hasattr(args, '__getitem__') and \ 106 | not isinstance(args, (str, abc.ByteString)): 107 | args = torch.ones(args) 108 | 109 | # Run the Pytorch graph to get a trace and generate a graph from it 110 | with torch.onnx.set_training(model, False): 111 | try: 112 | trace = torch.jit.trace(model, args) 113 | torch.onnx._optimize_trace(trace) 114 | torch_graph = trace.graph 115 | except RuntimeError as e: 116 | print(e) 117 | print('Error occured when creating jit trace for model.') 118 | raise e 119 | 120 | # Dump list of nodes (DEBUG only) 121 | if verbose: 122 | dump_pytorch_graph(torch_graph) 123 | 124 | # 
Loop through nodes and build HL graph 125 | nodes = list(torch_graph.nodes()) 126 | inps = [(n, [i.unique() for i in n.inputs()]) for n in nodes] 127 | for i, torch_node in enumerate(nodes): 128 | # Op 129 | op = torch_node.kind() 130 | # Parameters 131 | params = {k: torch_node[k] for k in torch_node.attributeNames()} 132 | # Inputs/outputs 133 | # TODO: inputs = [i.unique() for i in node.inputs()] 134 | outputs = [o.unique() for o in torch_node.outputs()] 135 | # Get output shape 136 | shape = get_shape(torch_node) 137 | # Add HL node 138 | hl_node = Node(uid=pytorch_id(torch_node), name=None, op=op, 139 | output_shape=shape, params=params) 140 | hl_graph.add_node(hl_node) 141 | # Add edges 142 | for target_torch_node,target_inputs in inps: 143 | if set(outputs) & set(target_inputs): 144 | hl_graph.add_edge_by_id(pytorch_id(torch_node), pytorch_id(target_torch_node), shape) 145 | return hl_graph 146 | -------------------------------------------------------------------------------- /tensorwatch/model_graph/hiddenlayer/pytorch_builder_grad.py: -------------------------------------------------------------------------------- 1 | ''' 2 | File name: plot-pytorch-autograd-graph.py 3 | Author: Ludovic Trottier 4 | Date created: November 8, 2017. 5 | Date last modified: November 8, 2017 6 | Credits: moskomule (https://discuss.pytorch.org/t/print-autograd-graph/692/15) 7 | ''' 8 | from graphviz import Digraph 9 | import torch 10 | from torch.autograd import Variable 11 | 12 | from . import transforms as ht 13 | from collections import abc 14 | import numpy as np 15 | from .graph import Graph, Node 16 | 17 | # PyTorch Graph Transforms 18 | FRAMEWORK_TRANSFORMS = [ 19 | # Hide onnx: prefix 20 | ht.Rename(op=r"onnx::(.*)", to=r"\1"), 21 | # ONNX uses Gemm for linear layers (stands for General Matrix Multiplication). 22 | # It's an odd name that noone recognizes. Rename it. 
23 | ht.Rename(op=r"Gemm", to=r"Linear"), 24 | # PyTorch layers that don't have an ONNX counterpart 25 | ht.Rename(op=r"aten::max\_pool2d\_with\_indices", to="MaxPool"), 26 | # Shorten op name 27 | ht.Rename(op=r"BatchNormalization", to="BatchNorm"), 28 | ] 29 | 30 | def add_node2dot(dot, var, id, label, op=None, output_shape=None, params=None): 31 | hl_node = Node(uid=id, name=op, op=label, 32 | output_shape=output_shape, params=params) 33 | dot.add_node(hl_node) 34 | 35 | def make_dot(var, params, dot): 36 | """ Produces Graphviz representation of PyTorch autograd graph. 37 | 38 | Blue nodes are trainable Variables (weights, bias). 39 | Orange node are saved tensors for the backward pass. 40 | 41 | Args: 42 | var: output Variable 43 | params: list of (name, Parameters) 44 | """ 45 | param_map2 = {k:v for k, v in params} 46 | print(param_map2) 47 | param_map = {id(v): k for k, v in params} 48 | 49 | 50 | 51 | node_attr = dict(style='filled', 52 | shape='box', 53 | align='left', 54 | fontsize='12', 55 | ranksep='0.1', 56 | height='0.2') 57 | 58 | # dot = Digraph( 59 | # filename='network', 60 | # format='pdf', 61 | # node_attr=node_attr, 62 | # graph_attr=dict(size="12,12")) 63 | seen = set() 64 | 65 | def add_nodes(dot, var): 66 | if var not in seen: 67 | 68 | node_id = str(id(var)) 69 | 70 | if torch.is_tensor(var): 71 | node_label = "saved tensor\n{}".format(tuple(var.size())) 72 | add_node2dot(dot, var, node_id, node_label, op=None) 73 | 74 | elif hasattr(var, 'variable'): 75 | variable_name = param_map.get(id(var.variable)) 76 | variable_size = tuple(var.variable.size()) 77 | node_name = "{}\n{}".format(variable_name, variable_size) 78 | add_node2dot(dot, var, node_id, node_name, op=None) 79 | 80 | else: 81 | node_label = type(var).__name__.replace('Backward', '') 82 | add_node2dot(dot, var, node_id, node_label, op=None) 83 | 84 | seen.add(var) 85 | 86 | if hasattr(var, 'next_functions'): 87 | for u in var.next_functions: 88 | if u[0] is not None: 89 | 
dot.add_edge_by_id(str(id(u[0])), str(id(var)), None) 90 | add_nodes(dot, u[0]) 91 | 92 | if hasattr(var, 'saved_tensors'): 93 | for t in var.saved_tensors: 94 | dot.add_edge_by_id(str(id(t)), str(id(var)), None) 95 | add_nodes(dot, t) 96 | 97 | add_nodes(dot, var.grad_fn) 98 | 99 | return dot 100 | 101 | def import_graph(hl_graph, model, args, input_names=None, verbose=False): 102 | if args is None: 103 | args = [1, 3, 224, 224] # assume ImageNet default 104 | 105 | # if args is not Tensor but is array like then convert it to torch tensor 106 | if not isinstance(args, torch.Tensor) and \ 107 | hasattr(args, "__len__") and hasattr(args, '__getitem__') and \ 108 | not isinstance(args, (str, abc.ByteString)): 109 | args = torch.ones(args) 110 | 111 | y = model(args) 112 | g = make_dot(y, model.named_parameters(), hl_graph) 113 | return hl_graph -------------------------------------------------------------------------------- /tensorwatch/model_graph/hiddenlayer/pytorch_builder_trace.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard._pytorch_graph import GraphPy, NodePyIO, NodePyOP 2 | import torch 3 | 4 | from . import transforms as ht 5 | from collections import abc 6 | import numpy as np 7 | 8 | # PyTorch Graph Transforms 9 | FRAMEWORK_TRANSFORMS = [ 10 | # Hide onnx: prefix 11 | ht.Rename(op=r"onnx::(.*)", to=r"\1"), 12 | # ONNX uses Gemm for linear layers (stands for General Matrix Multiplication). 13 | # It's an odd name that noone recognizes. Rename it. 
14 | ht.Rename(op=r"Gemm", to=r"Linear"), 15 | # PyTorch layers that don't have an ONNX counterpart 16 | ht.Rename(op=r"aten::max\_pool2d\_with\_indices", to="MaxPool"), 17 | # Shorten op name 18 | ht.Rename(op=r"BatchNormalization", to="BatchNorm"), 19 | ] 20 | 21 | def parse(graph, args=None, omit_useless_nodes=True): 22 | """This method parses an optimized PyTorch model graph and produces 23 | a list of nodes and node stats for eventual conversion to TensorBoard 24 | protobuf format. 25 | Args: 26 | graph (PyTorch module): The model to be parsed. 27 | args (tuple): input tensor[s] for the model. 28 | omit_useless_nodes (boolean): Whether to remove nodes from the graph. 29 | """ 30 | n_inputs = len(args) 31 | 32 | scope = {} 33 | nodes_py = GraphPy() 34 | for i, node in enumerate(graph.inputs()): 35 | if omit_useless_nodes: 36 | if len(node.uses()) == 0: # number of user of the node (= number of outputs/ fanout) 37 | continue 38 | 39 | if i < n_inputs: 40 | nodes_py.append(NodePyIO(node, 'input')) 41 | else: 42 | nodes_py.append(NodePyIO(node)) # parameter 43 | 44 | for node in graph.nodes(): 45 | nodes_py.append(NodePyOP(node)) 46 | 47 | for node in graph.outputs(): # must place last. 48 | NodePyIO(node, 'output') 49 | nodes_py.find_common_root() 50 | nodes_py.populate_namespace_from_OP_to_IO() 51 | return nodes_py 52 | 53 | 54 | def graph(model, args, verbose=False): 55 | """ 56 | This method processes a PyTorch model and produces a `GraphDef` proto 57 | that can be logged to TensorBoard. 58 | Args: 59 | model (PyTorch module): The model to be parsed. 60 | args (tuple): input tensor[s] for the model. 61 | verbose (bool): Whether to print out verbose information while 62 | processing. 63 | """ 64 | with torch.onnx.set_training(model, False): # TODO: move outside of torch.onnx? 
65 | try: 66 | trace = torch.jit.trace(model, args) 67 | graph = trace.graph 68 | except RuntimeError as e: 69 | print(e) 70 | print('Error occurs, No graph saved') 71 | raise e 72 | 73 | if verbose: 74 | print(graph) 75 | return parse(graph, args) 76 | 77 | 78 | def import_graph(hl_graph, model, args, input_names=None, verbose=False): 79 | # TODO: add input names to graph 80 | 81 | if args is None: 82 | args = [1, 3, 224, 224] # assume ImageNet default 83 | 84 | # if args is not Tensor but is array like then convert it to torch tensor 85 | if not isinstance(args, torch.Tensor) and \ 86 | hasattr(args, "__len__") and hasattr(args, '__getitem__') and \ 87 | not isinstance(args, (str, abc.ByteString)): 88 | args = torch.ones(args) 89 | 90 | graph_py = graph(model, args, verbose) 91 | 92 | # # Loop through nodes and build HL graph 93 | # nodes = list(torch_graph.nodes()) 94 | # inps = [(n, [i.unique() for i in n.inputs()]) for n in nodes] 95 | # for i, torch_node in enumerate(nodes): 96 | # # Op 97 | # op = torch_node.kind() 98 | # # Parameters 99 | # params = {k: torch_node[k] for k in torch_node.attributeNames()} 100 | # # Inputs/outputs 101 | # # TODO: inputs = [i.unique() for i in node.inputs()] 102 | # outputs = [o.unique() for o in torch_node.outputs()] 103 | # # Get output shape 104 | # shape = get_shape(torch_node) 105 | # # Add HL node 106 | # hl_node = Node(uid=pytorch_id(torch_node), name=None, op=op, 107 | # output_shape=shape, params=params) 108 | # hl_graph.add_node(hl_node) 109 | # # Add edges 110 | # for target_torch_node,target_inputs in inps: 111 | # if set(outputs) & set(target_inputs): 112 | # hl_graph.add_edge_by_id(pytorch_id(torch_node), pytorch_id(target_torch_node), shape) 113 | return hl_graph 114 | -------------------------------------------------------------------------------- /tensorwatch/model_graph/torchstat/README.md: -------------------------------------------------------------------------------- 1 | # Credits 2 | 3 | Code in this 
folder is almost as-is from torchstat repository located at https://github.com/Swall0w/torchstat. 4 | 5 | Additional merges are from: 6 | - https://github.com/kenshohara/torchstat 7 | - https://github.com/lyakaap/torchstat -------------------------------------------------------------------------------- /tensorwatch/model_graph/torchstat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/tensorwatch/model_graph/torchstat/__init__.py -------------------------------------------------------------------------------- /tensorwatch/model_graph/torchstat/compute_flops.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import math 5 | 6 | def compute_flops(module, inp, out): 7 | if isinstance(module, nn.Conv2d): 8 | return compute_Conv2d_flops(module, inp[0], out[0]) 9 | elif isinstance(module, nn.BatchNorm2d): 10 | return compute_BatchNorm2d_flops(module, inp[0], out[0]) 11 | elif isinstance(module, (nn.AvgPool2d, nn.MaxPool2d)): 12 | return compute_Pool2d_flops(module, inp[0], out[0]) 13 | elif isinstance(module, (nn.AdaptiveAvgPool2d, nn.AdaptiveMaxPool2d)): 14 | return compute_adaptivepool_flops(module, inp[0], out[0]) 15 | elif isinstance(module, (nn.ReLU, nn.ReLU6, nn.PReLU, nn.ELU, nn.LeakyReLU)): 16 | return compute_ReLU_flops(module, inp[0], out[0]) 17 | elif isinstance(module, nn.Upsample): 18 | return compute_Upsample_flops(module, inp[0], out[0]) 19 | elif isinstance(module, nn.Linear): 20 | return compute_Linear_flops(module, inp[0], out[0]) 21 | else: 22 | #print(f"[Flops]: {type(module).__name__} is not supported!") 23 | return 0 24 | pass 25 | 26 | 27 | def compute_Conv2d_flops(module, inp, out): 28 | # Can have multiple inputs, getting the first one 29 | assert isinstance(module, nn.Conv2d) 30 | assert len(inp.size()) 
== 4 and len(inp.size()) == len(out.size()) 31 | 32 | batch_size = inp.size()[0] 33 | in_c = inp.size()[1] 34 | k_h, k_w = module.kernel_size 35 | out_c, out_h, out_w = out.size()[1:] 36 | groups = module.groups 37 | 38 | filters_per_channel = out_c // groups 39 | conv_per_position_flops = k_h * k_w * in_c * filters_per_channel 40 | active_elements_count = batch_size * out_h * out_w 41 | 42 | total_conv_flops = conv_per_position_flops * active_elements_count 43 | 44 | bias_flops = 0 45 | if module.bias is not None: 46 | bias_flops = out_c * active_elements_count 47 | 48 | total_flops = total_conv_flops + bias_flops 49 | return total_flops 50 | 51 | def compute_adaptivepool_flops(module, input, output): 52 | # credits: https://github.com/xternalz/SDPoint/blob/master/utils/flops.py 53 | batch_size = input.size(0) 54 | input_planes = input.size(1) 55 | input_height = input.size(2) 56 | input_width = input.size(3) 57 | 58 | flops = 0 59 | for i in range(output.size(2)): 60 | y_start = int(math.floor(float(i * input_height) / output.size(2))) 61 | y_end = int(math.ceil(float((i + 1) * input_height) / output.size(2))) 62 | for j in range(output.size(3)): 63 | x_start = int(math.floor(float(j * input_width) / output.size(3))) 64 | x_end = int(math.ceil(float((j + 1) * input_width) / output.size(3))) 65 | 66 | flops += batch_size * input_planes * (y_end-y_start+1) * (x_end-x_start+1) 67 | return flops 68 | 69 | def compute_BatchNorm2d_flops(module, inp, out): 70 | assert isinstance(module, nn.BatchNorm2d) 71 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 72 | in_c, in_h, in_w = inp.size()[1:] 73 | batch_flops = np.prod(inp.shape) 74 | if module.affine: 75 | batch_flops *= 2 76 | return batch_flops 77 | 78 | 79 | def compute_ReLU_flops(module, inp, out): 80 | assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.PReLU, nn.ELU, nn.LeakyReLU)) 81 | batch_size = inp.size()[0] 82 | active_elements_count = batch_size 83 | 84 | for s in inp.size()[1:]: 85 | 
active_elements_count *= s 86 | 87 | return active_elements_count 88 | 89 | 90 | def compute_Pool2d_flops(module, input, out): 91 | batch_size = input.size(0) 92 | input_planes = input.size(1) 93 | input_height = input.size(2) 94 | input_width = input.size(3) 95 | kernel_size = ('int' in str(type(module.kernel_size))) and [module.kernel_size, module.kernel_size] or module.kernel_size 96 | kernel_ops = kernel_size[0] * kernel_size[1] 97 | stride = ('int' in str(type(module.stride))) and [module.stride, module.stride] or module.stride 98 | padding = ('int' in str(type(module.padding))) and [module.padding, module.padding] or module.padding 99 | 100 | output_width = math.floor((input_width + 2 * padding[0] - kernel_size[0]) / float(stride[0]) + 1) 101 | output_height = math.floor((input_height + 2 * padding[1] - kernel_size[1]) / float(stride[0]) + 1) 102 | return batch_size * input_planes * output_width * output_height * kernel_ops 103 | 104 | 105 | def compute_Linear_flops(module, inp, out): 106 | assert isinstance(module, nn.Linear) 107 | assert len(inp.size()) == 2 and len(out.size()) == 2 108 | batch_size = inp.size()[0] 109 | return batch_size * inp.size()[1] * out.size()[1] 110 | 111 | def compute_Upsample_flops(module, inp, out): 112 | assert isinstance(module, nn.Upsample) 113 | output_size = out[0] 114 | batch_size = inp.size()[0] 115 | output_elements_count = batch_size 116 | for s in output_size.shape[1:]: 117 | output_elements_count *= s 118 | 119 | return output_elements_count 120 | -------------------------------------------------------------------------------- /tensorwatch/model_graph/torchstat/compute_memory.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def compute_memory(module, inp, out): 7 | if isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU)): 8 | return compute_ReLU_memory(module, inp[0], out[0]) 9 | elif 
isinstance(module, nn.PReLU): 10 | return compute_PReLU_memory(module, inp[0], out[0]) 11 | elif isinstance(module, nn.Conv2d): 12 | return compute_Conv2d_memory(module, inp[0], out[0]) 13 | elif isinstance(module, nn.BatchNorm2d): 14 | return compute_BatchNorm2d_memory(module, inp[0], out[0]) 15 | elif isinstance(module, nn.Linear): 16 | return compute_Linear_memory(module, inp[0], out[0]) 17 | elif isinstance(module, (nn.AvgPool2d, nn.MaxPool2d)): 18 | return compute_Pool2d_memory(module, inp[0], out[0]) 19 | else: 20 | #print(f"[Memory]: {type(module).__name__} is not supported!") 21 | return 0, 0 22 | pass 23 | 24 | 25 | def num_params(module): 26 | return sum(p.numel() for p in module.parameters() if p.requires_grad) 27 | 28 | 29 | def compute_ReLU_memory(module, inp, out): 30 | assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU)) 31 | 32 | mread = inp.numel() 33 | mwrite = out.numel() 34 | 35 | return mread*inp.element_size(), mwrite*out.element_size() 36 | 37 | 38 | def compute_PReLU_memory(module, inp, out): 39 | assert isinstance(module, nn.PReLU) 40 | 41 | batch_size = inp.size()[0] 42 | mread = batch_size * (inp[0].numel() + num_params(module)) 43 | mwrite = out.numel() 44 | 45 | return mread*inp.element_size(), mwrite*out.element_size() 46 | 47 | 48 | def compute_Conv2d_memory(module, inp, out): 49 | # Can have multiple inputs, getting the first one 50 | assert isinstance(module, nn.Conv2d) 51 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 52 | 53 | batch_size = inp.size()[0] 54 | 55 | # This includes weights with bias if the module contains it. 
56 | mread = batch_size * (inp[0].numel() + num_params(module)) 57 | mwrite = out.numel() 58 | 59 | return mread*inp.element_size(), mwrite*out.element_size() 60 | 61 | 62 | def compute_BatchNorm2d_memory(module, inp, out): 63 | assert isinstance(module, nn.BatchNorm2d) 64 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 65 | 66 | batch_size, in_c, in_h, in_w = inp.size() 67 | mread = batch_size * (inp[0].numel() + 2 * in_c) 68 | mwrite = out.numel() 69 | 70 | return mread*inp.element_size(), mwrite*out.element_size() 71 | 72 | 73 | def compute_Linear_memory(module, inp, out): 74 | assert isinstance(module, nn.Linear) 75 | assert len(inp.size()) == 2 and len(out.size()) == 2 76 | 77 | batch_size = inp.size()[0] 78 | 79 | # This includes weights with bias if the module contains it. 80 | mread = batch_size * (inp[0].numel() + num_params(module)) 81 | mwrite = out.numel() 82 | 83 | return mread*inp.element_size(), mwrite*out.element_size() 84 | 85 | def compute_Pool2d_memory(module, inp, out): 86 | assert isinstance(module, (nn.MaxPool2d, nn.AvgPool2d)) 87 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 88 | 89 | mread = inp.numel() 90 | mwrite = out.numel() 91 | 92 | return mread*inp.element_size(), mwrite*out.element_size() 93 | -------------------------------------------------------------------------------- /tensorwatch/model_graph/torchstat/reporter.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | 4 | pd.set_option('display.width', 1000) 5 | pd.set_option('display.max_rows', 10000) 6 | pd.set_option('display.max_columns', 10000) 7 | 8 | 9 | def round_value(value, binary=False): 10 | divisor = 1024. if binary else 1000. 
11 | 12 | if value // divisor**4 > 0: 13 | return str(round(value / divisor**4, 2)) + 'T' 14 | elif value // divisor**3 > 0: 15 | return str(round(value / divisor**3, 2)) + 'G' 16 | elif value // divisor**2 > 0: 17 | return str(round(value / divisor**2, 2)) + 'M' 18 | elif value // divisor > 0: 19 | return str(round(value / divisor, 2)) + 'K' 20 | return str(value) 21 | 22 | 23 | def report_format(collected_nodes): 24 | data = list() 25 | for node in collected_nodes: 26 | name = node.name 27 | input_shape = ' '.join(['{:>3d}'] * len(node.input_shape)).format( 28 | *[e for e in node.input_shape]) 29 | output_shape = ' '.join(['{:>3d}'] * len(node.output_shape)).format( 30 | *[e for e in node.output_shape]) 31 | parameter_quantity = node.parameter_quantity 32 | inference_memory = node.inference_memory 33 | MAdd = node.MAdd 34 | Flops = node.Flops 35 | mread, mwrite = [i for i in node.Memory] 36 | duration = node.duration 37 | data.append([name, input_shape, output_shape, parameter_quantity, 38 | inference_memory, MAdd, duration, Flops, mread, 39 | mwrite]) 40 | df = pd.DataFrame(data) 41 | df.columns = ['module name', 'input shape', 'output shape', 42 | 'params', 'memory(MB)', 43 | 'MAdd', 'duration', 'Flops', 'MemRead(B)', 'MemWrite(B)'] 44 | df['duration[%]'] = df['duration'] / (df['duration'].sum() + 1e-7) 45 | df['MemR+W(B)'] = df['MemRead(B)'] + df['MemWrite(B)'] 46 | total_parameters_quantity = df['params'].sum() 47 | total_memory = df['memory(MB)'].sum() 48 | total_operation_quantity = df['MAdd'].sum() 49 | total_flops = df['Flops'].sum() 50 | total_duration = df['duration[%]'].sum() 51 | total_mread = df['MemRead(B)'].sum() 52 | total_mwrite = df['MemWrite(B)'].sum() 53 | total_memrw = df['MemR+W(B)'].sum() 54 | del df['duration'] 55 | 56 | # Add Total row 57 | total_df = pd.Series([total_parameters_quantity, total_memory, 58 | total_operation_quantity, total_flops, 59 | total_duration, mread, mwrite, total_memrw], 60 | index=['params', 'memory(MB)', 'MAdd', 
'Flops', 'duration[%]', 61 | 'MemRead(B)', 'MemWrite(B)', 'MemR+W(B)'], 62 | name='total') 63 | df = df.append(total_df) 64 | 65 | df = df.fillna(' ') 66 | df['memory(MB)'] = df['memory(MB)'].apply( 67 | lambda x: '{:.2f}'.format(x)) 68 | df['duration[%]'] = df['duration[%]'].apply(lambda x: '{:.2%}'.format(x)) 69 | df['MAdd'] = df['MAdd'].apply(lambda x: '{:,}'.format(x)) 70 | df['Flops'] = df['Flops'].apply(lambda x: '{:,}'.format(x)) 71 | 72 | summary = str(df) + '\n' 73 | summary += "=" * len(str(df).split('\n')[0]) 74 | summary += '\n' 75 | summary += "Total params: {:,}\n".format(total_parameters_quantity) 76 | 77 | summary += "-" * len(str(df).split('\n')[0]) 78 | summary += '\n' 79 | summary += "Total memory: {:.2f}MB\n".format(total_memory) 80 | summary += "Total MAdd: {}MAdd\n".format(round_value(total_operation_quantity)) 81 | summary += "Total Flops: {}Flops\n".format(round_value(total_flops)) 82 | summary += "Total MemR+W: {}B\n".format(round_value(total_memrw, True)) 83 | return summary 84 | -------------------------------------------------------------------------------- /tensorwatch/model_graph/torchstat_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
from .torchstat import analyzer
import pandas as pd
import copy


class LayerStats:
    """Statistics of one layer, copied from a torchstat analyzer node."""

    def __init__(self, node) -> None:
        self.name = node.name
        self.input_shape = node.input_shape
        self.output_shape = node.output_shape
        self.parameters = node.parameter_quantity
        self.inference_memory = node.inference_memory
        self.MAdd = node.MAdd
        self.Flops = node.Flops
        # node.Memory is a (bytes read, bytes written) pair
        self.mread, self.mwrite = node.Memory[0], node.Memory[1]
        self.duration = node.duration


class ModelStats(LayerStats):
    """Per-layer statistics of a model plus their totals.

    Exposes the same attributes as ``LayerStats`` (summed over all layers) so
    the model can be rendered as a total row next to its layers.
    ``LayerStats.__init__`` is intentionally not called: there is no single
    analyzer node describing the whole model.
    """

    def __init__(self, model, input_shape, clone_model=False) -> None:
        if clone_model:
            # Analyze a private copy so the caller's model object is not used.
            model = copy.deepcopy(model)
        collected_nodes = analyzer.analyze(model, input_shape, 1)
        self.layer_stats = [LayerStats(node) for node in collected_nodes]

        self.name = 'Model'
        self.input_shape = input_shape
        # With no analyzed layers there is no last layer to take the shape from.
        self.output_shape = self.layer_stats[-1].output_shape if self.layer_stats else None
        self.parameters = sum(s.parameters for s in self.layer_stats)
        self.inference_memory = sum(s.inference_memory for s in self.layer_stats)
        self.MAdd = sum(s.MAdd for s in self.layer_stats)
        self.Flops = sum(s.Flops for s in self.layer_stats)
        self.mread = sum(s.mread for s in self.layer_stats)
        self.mwrite = sum(s.mwrite for s in self.layer_stats)
        self.duration = sum(s.duration for s in self.layer_stats)


def model_stats(model, input_shape, clone_model=False):
    """Analyze ``model`` and return its statistics as a DataFrame.

    Args:
        model: the module to analyze.
        input_shape: input shape forwarded to ``analyzer.analyze``.
        clone_model: analyze a deep copy of ``model`` instead of ``model``.
    """
    ms = ModelStats(model, input_shape, clone_model=clone_model)
    return model_stats2df(ms)


def _round_value(value, binary=False):
    """Format ``value`` with a K/M/G/T suffix (base 1024 when ``binary``)."""
    divisor = 1024. if binary else 1000.
    for power, suffix in ((4, 'T'), (3, 'G'), (2, 'M'), (1, 'K')):
        if value // divisor**power > 0:
            return str(round(value / divisor**power, 2)) + suffix
    return str(value)


# LayerStats/ModelStats attribute names, in report column order.
_STAT_COLUMNS = ['name', 'input_shape', 'output_shape', 'parameters',
                 'inference_memory', 'MAdd', 'Flops', 'mread', 'mwrite',
                 'duration']


def model_stats2df(model_stats: ModelStats) -> pd.DataFrame:
    """Tabulate ``model_stats``: one row per layer followed by a total row.

    Count columns are rendered as strings with thousands separators and the
    columns are renamed to human-readable headers. Side effect: widens the
    global pandas display options so the table prints untruncated.
    """
    pd.set_option('display.width', 1000)
    pd.set_option('display.max_rows', 10000)
    pd.set_option('display.max_columns', 10000)

    # Explicit columns keep the frame well-formed even when there are no layers.
    df = pd.DataFrame([s.__dict__ for s in model_stats.layer_stats],
                      columns=_STAT_COLUMNS)
    total_df = pd.Series(model_stats.__dict__, name='Total')
    # DataFrame.append was removed in pandas 2.0; appending the Series as a
    # one-row frame is what append did internally.
    total_row = total_df[df.columns].to_frame().T.infer_objects()
    df = pd.concat([df, total_row], ignore_index=True)

    df = df.fillna(' ')
    for c in ['MAdd', 'Flops', 'parameters', 'inference_memory', 'mread', 'mwrite']:
        df[c] = df[c].apply('{:,}'.format)

    return df.rename(columns={'name': 'module name',
                              'input_shape': 'input shape',
                              'output_shape': 'output shape',
                              'inference_memory': 'infer memory(MB)',
                              'mread': 'MemRead(B)',
                              'mwrite': 'MemWrite(B)'
                              })

# --------------------------------------------------------------------------------
# /tensorwatch/mpl/__init__.py:
-------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | 5 | -------------------------------------------------------------------------------- /tensorwatch/mpl/histogram.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .base_mpl_plot import BaseMplPlot 5 | from .. import image_utils 6 | from .. import utils 7 | import numpy as np 8 | 9 | class Histogram(BaseMplPlot): 10 | def init_stream_plot(self, stream_vis, 11 | xtitle='', ytitle='', ztitle='', colormap=None, color=None, 12 | bins=None, normed=None, histtype='bar', edge_color=None, linewidth=None, 13 | opacity=None, **stream_vis_args): 14 | 15 | # add main subplot 16 | stream_vis.bins, stream_vis.normed, stream_vis.linewidth = bins, normed, (linewidth or 2) 17 | stream_vis.ax = self.get_main_axis() 18 | stream_vis.series = [] 19 | stream_vis.bars_artists = [] # stores previously drawn bars 20 | 21 | stream_vis.cmap = image_utils.get_cmap(name=colormap or 'Dark2') 22 | if color is None: 23 | if not self.is_3d: 24 | color = stream_vis.cmap((len(self._stream_vises)%stream_vis.cmap.N)/stream_vis.cmap.N) # pylint: disable=no-member 25 | stream_vis.color = color 26 | stream_vis.edge_color = 'black' 27 | stream_vis.histtype = histtype 28 | stream_vis.opacity = opacity 29 | stream_vis.ax.set_xlabel(xtitle) 30 | stream_vis.ax.xaxis.label.set_style('italic') 31 | stream_vis.ax.set_ylabel(ytitle) 32 | stream_vis.ax.yaxis.label.set_color(color) 33 | stream_vis.ax.yaxis.label.set_style('italic') 34 | if self.is_3d: 35 | stream_vis.ax.set_zlabel(ztitle) 36 | stream_vis.ax.zaxis.label.set_style('italic') 37 | 38 | def is_show_grid(self): #override 39 | return False 40 | 41 | def clear_artists(self, stream_vis): 42 | for bar in stream_vis.bars_artists: 43 | bar.remove() 44 | 
stream_vis.bars_artists.clear() 45 | 46 | def clear_plot(self, stream_vis, clear_history): 47 | stream_vis.series.clear() 48 | self.clear_artists(stream_vis) 49 | 50 | def _show_stream_items(self, stream_vis, stream_items): 51 | """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True. 52 | """ 53 | 54 | vals = self._extract_vals(stream_items) 55 | if not len(vals): 56 | return True 57 | 58 | stream_vis.series += vals 59 | self.clear_artists(stream_vis) 60 | n, bins, stream_vis.bars_artists = stream_vis.ax.hist(stream_vis.series, bins=stream_vis.bins, 61 | normed=stream_vis.normed, color=stream_vis.color, edgecolor=stream_vis.edge_color, 62 | histtype=stream_vis.histtype, alpha=stream_vis.opacity, 63 | linewidth=stream_vis.linewidth) 64 | 65 | stream_vis.ax.set_xticks(bins) 66 | 67 | #stream_vis.ax.relim() 68 | #stream_vis.ax.autoscale_view() 69 | 70 | return False 71 | -------------------------------------------------------------------------------- /tensorwatch/mpl/image_plot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .base_mpl_plot import BaseMplPlot 5 | from .. 
import utils, image_utils 6 | import numpy as np 7 | 8 | #from IPython import get_ipython 9 | 10 | class ImagePlot(BaseMplPlot): 11 | def init_stream_plot(self, stream_vis, 12 | rows=2, cols=5, img_width=None, img_height=None, img_channels=None, 13 | colormap=None, viz_img_scale=None, **stream_vis_args): 14 | stream_vis.rows, stream_vis.cols = rows, cols 15 | stream_vis.img_channels, stream_vis.colormap = img_channels, colormap 16 | stream_vis.img_width, stream_vis.img_height = img_width, img_height 17 | stream_vis.viz_img_scale = viz_img_scale 18 | # subplots holding each image 19 | stream_vis.axs = [[None for _ in range(cols)] for _ in range(rows)] 20 | # axis image 21 | stream_vis.ax_imgs = [[None for _ in range(cols)] for _ in range(rows)] 22 | 23 | def clear_plot(self, stream_vis, clear_history): 24 | for row in range(stream_vis.rows): 25 | for col in range(stream_vis.cols): 26 | img = stream_vis.ax_imgs[row][col] 27 | if img: 28 | x, y = img.get_size() 29 | img.set_data(np.zeros((x, y))) 30 | 31 | def _show_stream_items(self, stream_vis, stream_items): 32 | """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True. 
33 | """ 34 | 35 | # as we repaint each image plot, select last if multiple events were pending 36 | stream_item = None 37 | for er in reversed(stream_items): 38 | if not(er.ended or er.value is None): 39 | stream_item = er 40 | break 41 | if stream_item is None: 42 | return True 43 | 44 | row, col, i = 0, 0, 0 45 | dirty = False 46 | # stream_item.value is expected to be ImagePlotItems 47 | for image_list in stream_item.value: 48 | # convert to imshow compatible, stitch images 49 | images = [image_utils.to_imshow_array(img, stream_vis.img_width, stream_vis.img_height) \ 50 | for img in image_list.images if img is not None] 51 | img_viz = image_utils.stitch_horizontal(images, width_dim=1) 52 | 53 | # resize if requested 54 | if stream_vis.viz_img_scale is not None: 55 | import skimage.transform # expensive import done on demand 56 | 57 | if isinstance(img_viz, np.ndarray) and np.issubdtype(img_viz.dtype, np.floating): 58 | img_viz = img_viz.clip(-1, 1) # some MNIST images have out of range values causing exception in sklearn 59 | img_viz = skimage.transform.rescale(img_viz, 60 | (stream_vis.viz_img_scale, stream_vis.viz_img_scale), mode='reflect', preserve_range=False) 61 | 62 | # create subplot if it doesn't exist 63 | ax = stream_vis.axs[row][col] 64 | if ax is None: 65 | ax = stream_vis.axs[row][col] = \ 66 | self.figure.add_subplot(stream_vis.rows, stream_vis.cols, i+1) 67 | ax.set_xticks([]) 68 | ax.set_yticks([]) 69 | 70 | cmap = image_list.cmap or ('Greys' if stream_vis.colormap is None and \ 71 | len(img_viz.shape) == 2 else stream_vis.colormap) 72 | 73 | stream_vis.ax_imgs[row][col] = ax.imshow(img_viz, interpolation="none", cmap=cmap, alpha=image_list.alpha) 74 | dirty = True 75 | 76 | # set title 77 | title = image_list.title 78 | if len(title) > 12: #wordwrap if too long 79 | title = utils.wrap_string(title) if len(title) > 24 else title 80 | fontsize = 8 81 | else: 82 | fontsize = 12 83 | ax.set_title(title, fontsize=fontsize) #'fontweight': 'light' 84 
| 85 | #ax.autoscale_view() # not needed 86 | col = col + 1 87 | if col >= stream_vis.cols: 88 | col = 0 89 | row = row + 1 90 | if row >= stream_vis.rows: 91 | break 92 | i += 1 93 | 94 | return not dirty 95 | 96 | 97 | def has_legend(self): 98 | return self.show_legend or False -------------------------------------------------------------------------------- /tensorwatch/mpl/pie_chart.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .base_mpl_plot import BaseMplPlot 5 | from .. import image_utils 6 | from .. import utils 7 | import numpy as np 8 | 9 | class PieChart(BaseMplPlot): 10 | def init_stream_plot(self, stream_vis, autopct=None, colormap=None, color=None, 11 | shadow=None, **stream_vis_args): 12 | 13 | # add main subplot 14 | stream_vis.autopct, stream_vis.shadow = autopct, True if shadow is None else shadow 15 | stream_vis.ax = self.get_main_axis() 16 | stream_vis.series = [] 17 | stream_vis.wedge_artists = [] # stores previously drawn bars 18 | 19 | stream_vis.cmap = image_utils.get_cmap(name=colormap or 'Dark2') 20 | if color is None: 21 | if not self.is_3d: 22 | color = stream_vis.cmap((len(self._stream_vises)%stream_vis.cmap.N)/stream_vis.cmap.N) # pylint: disable=no-member 23 | stream_vis.color = color 24 | 25 | def clear_artists(self, stream_vis): 26 | for artist in stream_vis.wedge_artists: 27 | artist.remove() 28 | stream_vis.wedge_artists.clear() 29 | 30 | def clear_plot(self, stream_vis, clear_history): 31 | stream_vis.series.clear() 32 | self.clear_artists(stream_vis) 33 | 34 | def _show_stream_items(self, stream_vis, stream_items): 35 | """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True. 
36 | """ 37 | 38 | vals = self._extract_vals(stream_items) 39 | if not len(vals): 40 | return True 41 | 42 | # make sure tuple has 4 elements 43 | unpacker = lambda a0=None,a1=None,a2=None,a3=None,*_:(a0,a1,a2,a3) 44 | stream_vis.series.extend([unpacker(*val) for val in vals]) 45 | self.clear_artists(stream_vis) 46 | 47 | labels, sizes, colors, explode = \ 48 | [t[0] for t in stream_vis.series], \ 49 | [t[1] for t in stream_vis.series], \ 50 | [(t[2] or stream_vis.cmap.colors[i % len(stream_vis.cmap.colors)]) \ 51 | for i, t in enumerate(stream_vis.series)], \ 52 | [t[3] or 0 for t in stream_vis.series], 53 | 54 | 55 | stream_vis.wedge_artists, *_ = stream_vis.ax.pie( \ 56 | sizes, explode=explode, labels=labels, colors=colors, 57 | autopct=stream_vis.autopct, shadow=stream_vis.shadow) 58 | 59 | return False 60 | 61 | -------------------------------------------------------------------------------- /tensorwatch/notebook_maker.py: -------------------------------------------------------------------------------- 1 | import nbformat 2 | from nbformat.v4 import new_code_cell, new_markdown_cell, new_notebook 3 | from nbformat import v3, v4 4 | import codecs 5 | from os import linesep, path 6 | import uuid 7 | from . 
import utils 8 | import re 9 | from typing import List 10 | from .lv_types import VisArgs 11 | 12 | class NotebookMaker: 13 | def __init__(self, watcher, filename:str=None)->None: 14 | self.filename = filename or \ 15 | (path.splitext(watcher.filename)[0] + '.ipynb' if watcher.filename else \ 16 | 'tensorwatch.ipynb') 17 | 18 | self.cells = [] 19 | self._default_vis_args = VisArgs() 20 | 21 | watcher_args_str = NotebookMaker._get_vis_args(watcher) 22 | 23 | # create initial cell 24 | self.cells.append(new_code_cell(source=linesep.join( 25 | ['%matplotlib notebook', 26 | 'import tensorwatch as tw', 27 | 'client = tw.WatcherClient({})'.format(NotebookMaker._get_vis_args(watcher))]))) 28 | 29 | def _get_vis_args(watcher)->str: 30 | args_strs = [] 31 | for param, default_v in [('port', 0), ('filename', None)]: 32 | if hasattr(watcher, param): 33 | v = getattr(watcher, param) 34 | if v==default_v or (v is None and default_v is None): 35 | continue 36 | args_strs.append("{}={}".format(param, NotebookMaker._val2str(v))) 37 | return ', '.join(args_strs) 38 | 39 | def _get_stream_identifier(prefix, event_name, stream_name, stream_index)->str: 40 | if not stream_name or utils.is_uuid4(stream_name): 41 | if event_name is not None and event_name != '': 42 | return '{}_{}_{}'.format(prefix, event_name, stream_index) 43 | else: 44 | return prefix + str(stream_index) 45 | else: 46 | return '{}{}_{}'.format(prefix, stream_index, utils.str2identifier(stream_name)[:8]) 47 | 48 | def _val2str(v)->str: 49 | # TODO: shall we raise error if non str, bool, number (or its container) parameters? 
50 | return str(v) if not isinstance(v, str) else "'{}'".format(v) 51 | 52 | def _add_vis_args_str(self, stream_info, param_strs:List[str])->None: 53 | default_args = self._default_vis_args.__dict__ 54 | if not stream_info.req.vis_args: 55 | return 56 | for k, v in stream_info.req.vis_args.__dict__.items(): 57 | if k in default_args: 58 | default_v = default_args[k] 59 | if (v is None and default_v is None) or (v==default_v): 60 | continue # skip param if its value is not changed from default 61 | param_strs.append("{}={}".format(k, NotebookMaker._val2str(v))) 62 | 63 | def _get_stream_code(self, event_name, stream_name, stream_index, stream_info)->List[str]: 64 | lines = [] 65 | 66 | stream_identifier = 's'+str(stream_index) 67 | lines.append("{} = client.open_stream(name='{}')".format(stream_identifier, stream_name)) 68 | 69 | vis_identifier = 'v'+str(stream_index) 70 | vis_args_strs = ['stream={}'.format(stream_identifier)] 71 | self._add_vis_args_str(stream_info, vis_args_strs) 72 | lines.append("{} = tw.Visualizer({})".format(vis_identifier, ', '.join(vis_args_strs))) 73 | lines.append("{}.show()".format(vis_identifier)) 74 | return lines 75 | 76 | def add_streams(self, event_stream_infos)->None: 77 | stream_index = 0 78 | for event_name, stream_infos in event_stream_infos.items(): # per event 79 | for stream_name, stream_info in stream_infos.items(): 80 | lines = self._get_stream_code(event_name, stream_name, stream_index, stream_info) 81 | self.cells.append(new_code_cell(source=linesep.join(lines))) 82 | stream_index += 1 83 | 84 | def write(self): 85 | nb = new_notebook(cells=self.cells, metadata={'language': 'python',}) 86 | with codecs.open(self.filename, encoding='utf-8', mode='w') as f: 87 | utils.debug_log('Notebook created', path.realpath(f.name), verbosity=0) 88 | nbformat.write(nb, f, 4) 89 | 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /tensorwatch/plotly/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | -------------------------------------------------------------------------------- /tensorwatch/plotly/base_plotly_plot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from ..vis_base import VisBase 5 | import time 6 | from abc import abstractmethod 7 | from .. import utils 8 | 9 | 10 | class BasePlotlyPlot(VisBase): 11 | def __init__(self, cell:VisBase.widgets.Box=None, title=None, show_legend:bool=None, is_3d:bool=False, 12 | stream_name:str=None, console_debug:bool=False, **vis_args): 13 | import plotly.graph_objs as go # function-level import as this takes long time 14 | super(BasePlotlyPlot, self).__init__(go.FigureWidget(), cell, title, show_legend, 15 | stream_name=stream_name, console_debug=console_debug, **vis_args) 16 | 17 | self.is_3d = is_3d 18 | self.widget.layout.title = title 19 | self.widget.layout.showlegend = show_legend if show_legend is not None else True 20 | 21 | def _add_trace(self, stream_vis): 22 | stream_vis.trace_index = len(self.widget.data) 23 | trace = self._create_trace(stream_vis) 24 | if stream_vis.opacity is not None: 25 | trace.opacity = stream_vis.opacity 26 | self.widget.add_trace(trace) 27 | 28 | def _add_trace_with_history(self, stream_vis): 29 | # if history buffer isn't full 30 | if stream_vis.history_len > len(stream_vis.trace_history): 31 | self._add_trace(stream_vis) 32 | stream_vis.trace_history.append(len(self.widget.data)-1) 33 | stream_vis.cur_history_index = len(stream_vis.trace_history)-1 34 | #if stream_vis.cur_history_index: 35 | # self.widget.data[trace_index].showlegend = False 36 | else: 37 | # rotate trace 38 | stream_vis.cur_history_index = (stream_vis.cur_history_index + 1) % stream_vis.history_len 39 | 
stream_vis.trace_index = stream_vis.trace_history[stream_vis.cur_history_index] 40 | self.clear_plot(stream_vis, False) 41 | self.widget.data[stream_vis.trace_index].opacity = stream_vis.opacity or 1 42 | 43 | cur_history_len = len(stream_vis.trace_history) 44 | if stream_vis.dim_history and cur_history_len > 1: 45 | max_opacity = stream_vis.opacity or 1 46 | min_alpha, max_alpha, dimmed_len = max_opacity*0.05, max_opacity*0.8, cur_history_len-1 47 | alphas = list(utils.frange(max_alpha, min_alpha, steps=dimmed_len)) 48 | for i, thi in enumerate(range(stream_vis.cur_history_index+1, 49 | stream_vis.cur_history_index+cur_history_len)): 50 | trace_index = stream_vis.trace_history[thi % cur_history_len] 51 | self.widget.data[trace_index].opacity = alphas[i] 52 | 53 | @staticmethod 54 | def get_pallet_color(i:int): 55 | import plotly # function-level import as this takes long time 56 | return plotly.colors.DEFAULT_PLOTLY_COLORS[i % len(plotly.colors.DEFAULT_PLOTLY_COLORS)] 57 | 58 | @staticmethod 59 | def _get_axis_common_props(title:str, axis_range:tuple): 60 | props = {'showline':True, 'showgrid': True, 61 | 'showticklabels': True, 'ticks':'inside'} 62 | if title: 63 | props['title'] = title 64 | if axis_range: 65 | props['range'] = list(axis_range) 66 | return props 67 | 68 | def _can_update_stream_plots(self): 69 | return time.time() - self.q_last_processed > 0.5 # make configurable 70 | 71 | def _post_add_subscription(self, stream_vis, **stream_vis_args): 72 | stream_vis.trace_history, stream_vis.cur_history_index = [], None 73 | self._add_trace_with_history(stream_vis) 74 | self._setup_layout(stream_vis) 75 | 76 | if not self.widget.layout.title: 77 | self.widget.layout.title = stream_vis.title 78 | # TODO: better way for below? 79 | if stream_vis.history_len > 1: 80 | self.widget.layout.showlegend = False 81 | 82 | def _show_widget_native(self, blocking:bool): 83 | pass 84 | #TODO: save image, spawn browser? 
85 | 86 | def _show_widget_notebook(self): 87 | #plotly.offline.iplot(self.widget) 88 | return None 89 | 90 | def _post_update_stream_plot(self, stream_vis): 91 | # not needed for plotly as FigureWidget stays upto date 92 | pass 93 | 94 | @abstractmethod 95 | def clear_plot(self, stream_vis, clear_history): 96 | """(for derived class) Clears the data in specified plot before new data is redrawn""" 97 | pass 98 | @abstractmethod 99 | def _show_stream_items(self, stream_vis, stream_items): 100 | """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True. 101 | """ 102 | 103 | pass 104 | @abstractmethod 105 | def _setup_layout(self, stream_vis): 106 | pass 107 | @abstractmethod 108 | def _create_trace(self, stream_vis): 109 | pass 110 | -------------------------------------------------------------------------------- /tensorwatch/plotly/embeddings_plot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .. import image_utils 5 | import numpy as np 6 | from .line_plot import LinePlot 7 | import time 8 | from .. 
import utils 9 | 10 | class EmbeddingsPlot(LinePlot): 11 | def __init__(self, cell:LinePlot.widgets.Box=None, title=None, show_legend:bool=False, stream_name:str=None, console_debug:bool=False, 12 | is_3d:bool=True, hover_images=None, hover_image_reshape=None, **vis_args): 13 | utils.set_default(vis_args, 'height', '8in') 14 | super(EmbeddingsPlot, self).__init__(cell, title, show_legend, 15 | stream_name=stream_name, console_debug=console_debug, is_3d=is_3d, **vis_args) 16 | 17 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue 18 | if hover_images is not None: 19 | plt.ioff() 20 | self.image_output = LinePlot.widgets.Output() 21 | self.image_figure = plt.figure(figsize=(2,2)) 22 | self.image_ax = self.image_figure.add_subplot(111) 23 | self.cell.children += (self.image_output,) 24 | plt.ion() 25 | self.hover_images, self.hover_image_reshape = hover_images, hover_image_reshape 26 | self.last_ind, self.last_ind_time = -1, 0 27 | 28 | def hover_fn(self, trace, points, state): # pylint: disable=unused-argument 29 | if not points: 30 | return 31 | ind = points.point_inds[0] 32 | if ind == self.last_ind or ind > len(self.hover_images) or ind < 0: 33 | return 34 | 35 | if self.last_ind == -1: 36 | self.last_ind, self.last_ind_time = ind, time.time() 37 | else: 38 | elapsed = time.time() - self.last_ind_time 39 | if elapsed < 0.3: 40 | self.last_ind, self.last_ind_time = ind, time.time() 41 | if elapsed < 1: 42 | return 43 | # else too much time since update 44 | # else we have stable ind 45 | 46 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue 47 | with self.image_output: 48 | plt.ioff() 49 | 50 | if self.hover_image_reshape: 51 | img = np.reshape(self.hover_images[ind], self.hover_image_reshape) 52 | else: 53 | img = self.hover_images[ind] 54 | if img is not None: 55 | LinePlot.display.clear_output(wait=True) 56 | self.image_ax.imshow(img) 57 | LinePlot.display.display(self.image_figure) 58 | 
plt.ion() 59 | 60 | return None 61 | 62 | def _create_trace(self, stream_vis): 63 | stream_vis.stream_vis_args.clear() #TODO remove this 64 | utils.set_default(stream_vis.stream_vis_args, 'draw_line', False) 65 | utils.set_default(stream_vis.stream_vis_args, 'draw_marker', True) 66 | utils.set_default(stream_vis.stream_vis_args, 'draw_marker_text', True) 67 | utils.set_default(stream_vis.stream_vis_args, 'hoverinfo', 'text') 68 | utils.set_default(stream_vis.stream_vis_args, 'marker', {}) 69 | 70 | marker = stream_vis.stream_vis_args['marker'] 71 | utils.set_default(marker, 'size', 6) 72 | utils.set_default(marker, 'colorscale', 'Jet') 73 | utils.set_default(marker, 'showscale', False) 74 | utils.set_default(marker, 'opacity', 0.8) 75 | 76 | return super(EmbeddingsPlot, self)._create_trace(stream_vis) 77 | 78 | def subscribe(self, stream, **stream_vis_args): 79 | super(EmbeddingsPlot, self).subscribe(stream) 80 | stream_vis = self._stream_vises[stream.stream_name] 81 | if stream_vis.index == 0 and self.hover_images is not None: 82 | self.widget.data[stream_vis.trace_index].on_hover(self.hover_fn) -------------------------------------------------------------------------------- /tensorwatch/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from torchvision import models, transforms 5 | import torch 6 | import torch.nn.functional as F 7 | from . 
import utils, image_utils 8 | import os 9 | 10 | def get_model(model_name): 11 | model = models.__dict__[model_name](pretrained=True) 12 | return model 13 | 14 | def tensors2batch(tensors, preprocess_transform=None): 15 | if preprocess_transform: 16 | tensors = tuple(preprocess_transform(i) for i in tensors) 17 | if not utils.is_array_like(tensors): 18 | tensors = tuple(tensors) 19 | return torch.stack(tensors, dim=0) 20 | 21 | def int2tensor(val): 22 | return torch.LongTensor([val]) 23 | 24 | def image2batch(image, image_transform=None): 25 | if image_transform: 26 | input_x = image_transform(image) 27 | else: # if no transforms supplied then just convert PIL image to tensor 28 | input_x = transforms.ToTensor()(image) 29 | input_x = input_x.unsqueeze(0) #convert to batch of 1 30 | return input_x 31 | 32 | def image_class2tensor(image_path, class_index=None, image_convert_mode=None, 33 | image_transform=None): 34 | image_pil = image_utils.open_image(os.path.abspath(image_path), convert_mode=image_convert_mode) 35 | input_x = image2batch(image_pil, image_transform) 36 | target_class = int2tensor(class_index) if class_index is not None else None 37 | return image_pil, input_x, target_class 38 | 39 | def batch_predict(model, inputs, input_transform=None, device=None): 40 | if input_transform: 41 | batch = torch.stack(tuple(input_transform(i) for i in inputs), dim=0) 42 | else: 43 | batch = torch.stack(inputs, dim=0) 44 | 45 | device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") 46 | model.eval() 47 | model.to(device) 48 | batch = batch.to(device) 49 | 50 | outputs = model(batch) 51 | 52 | return outputs 53 | 54 | def logits2probabilities(logits): 55 | return F.softmax(logits, dim=1) 56 | 57 | def tensor2numpy(t): 58 | return t.data.cpu().numpy() 59 | -------------------------------------------------------------------------------- /tensorwatch/receptive_field/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | -------------------------------------------------------------------------------- /tensorwatch/receptive_field/rf_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from receptivefield.pytorch import PytorchReceptiveField 5 | #from receptivefield.image import get_default_image 6 | import numpy as np 7 | 8 | def _get_rf(model, sample_pil_img): 9 | # define model functions 10 | def model_fn(): 11 | model.eval() 12 | return model 13 | 14 | input_shape = np.array(sample_pil_img).shape 15 | 16 | rf = PytorchReceptiveField(model_fn) 17 | rf_params = rf.compute(input_shape=input_shape) 18 | return rf, rf_params 19 | 20 | def plot_receptive_field(model, sample_pil_img, layout=(2, 2), figsize=(6, 6)): 21 | rf, rf_params = _get_rf(model, sample_pil_img) # pylint: disable=unused-variable 22 | return rf.plot_rf_grids( 23 | custom_image=sample_pil_img, 24 | figsize=figsize, 25 | layout=layout) 26 | 27 | def plot_grads_at(model, sample_pil_img, feature_map_index=0, point=(8,8), figsize=(6, 6)): 28 | rf, rf_params = _get_rf(model, sample_pil_img) # pylint: disable=unused-variable 29 | return rf.plot_gradient_at(fm_id=feature_map_index, point=point, image=None, figsize=figsize) 30 | -------------------------------------------------------------------------------- /tensorwatch/repeated_timer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | import threading 5 | import time 6 | import weakref 7 | 8 | class RepeatedTimer: 9 | class State: 10 | Stopped=0 11 | Paused=1 12 | Running=2 13 | 14 | def __init__(self, secs, callback, count=None): 15 | self.secs = secs 16 | self.callback = weakref.WeakMethod(callback) if callback else None 17 | self._thread = None 18 | self._state = RepeatedTimer.State.Stopped 19 | self.pause_wait = threading.Event() 20 | self.pause_wait.set() 21 | self._continue_thread = False 22 | self.count = count 23 | 24 | def start(self): 25 | self._continue_thread = True 26 | self.pause_wait.set() 27 | if self._thread is None or not self._thread.isAlive(): 28 | self._thread = threading.Thread(target=self._runner, name='RepeatedTimer', daemon=True) 29 | self._thread.start() 30 | self._state = RepeatedTimer.State.Running 31 | 32 | def stop(self, block=False): 33 | self.pause_wait.set() 34 | self._continue_thread = False 35 | if block and not (self._thread is None or not self._thread.isAlive()): 36 | self._thread.join() 37 | self._state = RepeatedTimer.State.Stopped 38 | 39 | def get_state(self): 40 | return self._state 41 | 42 | 43 | def pause(self): 44 | if self._state == RepeatedTimer.State.Running: 45 | self.pause_wait.clear() 46 | self._state = RepeatedTimer.State.Paused 47 | # else nothing to do 48 | def unpause(self): 49 | if self._state == RepeatedTimer.State.Paused: 50 | self.pause_wait.set() 51 | if self._state == RepeatedTimer.State.Paused: 52 | self._state = RepeatedTimer.State.Running 53 | # else nothing to do 54 | 55 | def _runner(self): 56 | while (self._continue_thread): 57 | if self.count: 58 | self.count -= 0 59 | if not self.count: 60 | self._continue_thread = False 61 | 62 | if self._continue_thread: 63 | self.pause_wait.wait() 64 | if self.callback and self.callback(): 65 | self.callback()() 66 | 67 | if self._continue_thread: 68 | time.sleep(self.secs) 69 | 70 | self._thread = None 71 | self._state = RepeatedTimer.State.Stopped 
-------------------------------------------------------------------------------- /tensorwatch/saliency/README.md: -------------------------------------------------------------------------------- 1 | # Credits 2 | Code in this folder is adopted from 3 | 4 | * https://github.com/yulongwang12/visual-attribution 5 | * https://github.com/marcotcr/lime 6 | -------------------------------------------------------------------------------- /tensorwatch/saliency/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | -------------------------------------------------------------------------------- /tensorwatch/saliency/backprop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.autograd import Variable, Function 3 | import torch 4 | import types 5 | 6 | 7 | class VanillaGradExplainer(object): 8 | def __init__(self, model): 9 | self.model = model 10 | 11 | def _backprop(self, inp, ind): 12 | inp.requires_grad = True 13 | if inp.grad is not None: 14 | inp.grad.zero_() 15 | if ind.grad is not None: 16 | ind.grad.zero_() 17 | self.model.eval() 18 | self.model.zero_grad() 19 | 20 | output = self.model(inp) 21 | if ind is None: 22 | ind = output.max(1)[1] 23 | grad_out = output.clone() 24 | grad_out.fill_(0.0) 25 | grad_out.scatter_(1, ind.unsqueeze(0).t(), 1.0) 26 | output.backward(grad_out) 27 | return inp.grad 28 | 29 | def explain(self, inp, ind=None, raw_inp=None): 30 | return self._backprop(inp, ind) 31 | 32 | 33 | class GradxInputExplainer(VanillaGradExplainer): 34 | def __init__(self, model): 35 | super(GradxInputExplainer, self).__init__(model) 36 | 37 | def explain(self, inp, ind=None, raw_inp=None): 38 | grad = self._backprop(inp, ind) 39 | return inp * grad 40 | 41 | 42 | class SaliencyExplainer(VanillaGradExplainer): 43 | def __init__(self, model): 44 | 
super(SaliencyExplainer, self).__init__(model) 45 | 46 | def explain(self, inp, ind=None, raw_inp=None): 47 | grad = self._backprop(inp, ind) 48 | return grad.abs() 49 | 50 | 51 | class IntegrateGradExplainer(VanillaGradExplainer): 52 | def __init__(self, model, steps=100): 53 | super(IntegrateGradExplainer, self).__init__(model) 54 | self.steps = steps 55 | 56 | def explain(self, inp, ind=None, raw_inp=None): 57 | grad = 0 58 | inp_data = inp.clone() 59 | 60 | for alpha in np.arange(1 / self.steps, 1.0, 1 / self.steps): 61 | new_inp = Variable(inp_data * alpha, requires_grad=True) 62 | g = self._backprop(new_inp, ind) 63 | grad += g 64 | 65 | return grad * inp_data / self.steps 66 | 67 | 68 | class DeconvExplainer(VanillaGradExplainer): 69 | def __init__(self, model): 70 | super(DeconvExplainer, self).__init__(model) 71 | self._override_backward() 72 | 73 | def _override_backward(self): 74 | class _ReLU(Function): 75 | @staticmethod 76 | def forward(ctx, input): 77 | output = torch.clamp(input, min=0) 78 | return output 79 | 80 | @staticmethod 81 | def backward(ctx, grad_output): 82 | grad_inp = torch.clamp(grad_output, min=0) 83 | return grad_inp 84 | 85 | def new_forward(self, x): 86 | return _ReLU.apply(x) 87 | 88 | def replace(m): 89 | if m.__class__.__name__ == 'ReLU': 90 | m.forward = types.MethodType(new_forward, m) 91 | 92 | self.model.apply(replace) 93 | 94 | 95 | class GuidedBackpropExplainer(VanillaGradExplainer): 96 | def __init__(self, model): 97 | super(GuidedBackpropExplainer, self).__init__(model) 98 | self._override_backward() 99 | 100 | def _override_backward(self): 101 | class _ReLU(Function): 102 | @staticmethod 103 | def forward(ctx, input): 104 | output = torch.clamp(input, min=0) 105 | ctx.save_for_backward(output) 106 | return output 107 | 108 | @staticmethod 109 | def backward(ctx, grad_output): 110 | output, = ctx.saved_tensors 111 | mask1 = (output > 0).float() 112 | mask2 = (grad_output > 0).float() 113 | grad_inp = mask1 * mask2 * 
grad_output 114 | grad_output.copy_(grad_inp) 115 | return grad_output 116 | 117 | def new_forward(self, x): 118 | return _ReLU.apply(x) 119 | 120 | def replace(m): 121 | if m.__class__.__name__ == 'ReLU': 122 | m.forward = types.MethodType(new_forward, m) 123 | 124 | self.model.apply(replace) 125 | 126 | 127 | # modified from https://github.com/PAIR-code/saliency/blob/master/saliency/base.py#L80 128 | class SmoothGradExplainer(object): 129 | def __init__(self, model, base_explainer=None, stdev_spread=0.15, 130 | nsamples=25, magnitude=True): 131 | self.base_explainer = base_explainer or VanillaGradExplainer(model) 132 | self.stdev_spread = stdev_spread 133 | self.nsamples = nsamples 134 | self.magnitude = magnitude 135 | 136 | def explain(self, inp, ind=None, raw_inp=None): 137 | stdev = self.stdev_spread * (inp.max() - inp.min()) 138 | 139 | total_gradients = 0 140 | 141 | for i in range(self.nsamples): 142 | noise = torch.randn_like(inp) * stdev 143 | 144 | noisy_inp = inp + noise 145 | noisy_inp.retain_grad() 146 | grad = self.base_explainer.explain(noisy_inp, ind) 147 | 148 | if self.magnitude: 149 | total_gradients += grad ** 2 150 | else: 151 | total_gradients += grad 152 | 153 | return total_gradients / self.nsamples -------------------------------------------------------------------------------- /tensorwatch/saliency/deeplift.py: -------------------------------------------------------------------------------- 1 | from .backprop import GradxInputExplainer 2 | import types 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | # Based on formulation in DeepExplain, https://arxiv.org/abs/1711.06104 7 | # https://github.com/marcoancona/DeepExplain/blob/master/deepexplain/tensorflow/methods.py#L221-L272 8 | class DeepLIFTRescaleExplainer(GradxInputExplainer): 9 | def __init__(self, model): 10 | super(DeepLIFTRescaleExplainer, self).__init__(model) 11 | self._prepare_reference() 12 | self.baseline_inp = None 13 | 
self._override_backward() 14 | 15 | def _prepare_reference(self): 16 | def init_refs(m): 17 | name = m.__class__.__name__ 18 | if name.find('ReLU') != -1: 19 | m.ref_inp_list = [] 20 | m.ref_out_list = [] 21 | 22 | def ref_forward(self, x): 23 | self.ref_inp_list.append(x.data.clone()) 24 | out = F.relu(x) 25 | self.ref_out_list.append(out.data.clone()) 26 | return out 27 | 28 | def ref_replace(m): 29 | name = m.__class__.__name__ 30 | if name.find('ReLU') != -1: 31 | m.forward = types.MethodType(ref_forward, m) 32 | 33 | self.model.apply(init_refs) 34 | self.model.apply(ref_replace) 35 | 36 | def _reset_preference(self): 37 | def reset_refs(m): 38 | name = m.__class__.__name__ 39 | if name.find('ReLU') != -1: 40 | m.ref_inp_list = [] 41 | m.ref_out_list = [] 42 | 43 | self.model.apply(reset_refs) 44 | 45 | def _baseline_forward(self, inp): 46 | if self.baseline_inp is None: 47 | self.baseline_inp = inp.data.clone() 48 | self.baseline_inp.fill_(0.0) 49 | self.baseline_inp = Variable(self.baseline_inp) 50 | else: 51 | self.baseline_inp.fill_(0.0) 52 | # get ref 53 | _ = self.model(self.baseline_inp) 54 | 55 | def _override_backward(self): 56 | def new_backward(self, grad_out): 57 | ref_inp, inp = self.ref_inp_list 58 | ref_out, out = self.ref_out_list 59 | delta_out = out - ref_out 60 | delta_in = inp - ref_inp 61 | g1 = (delta_in.abs() > 1e-5).float() * grad_out * \ 62 | delta_out / delta_in 63 | mask = ((ref_inp + inp) > 0).float() 64 | g2 = (delta_in.abs() <= 1e-5).float() * 0.5 * mask * grad_out 65 | 66 | return g1 + g2 67 | 68 | def backward_replace(m): 69 | name = m.__class__.__name__ 70 | if name.find('ReLU') != -1: 71 | m.backward = types.MethodType(new_backward, m) 72 | 73 | self.model.apply(backward_replace) 74 | 75 | def explain(self, inp, ind=None, raw_inp=None): 76 | self._reset_preference() 77 | self._baseline_forward(inp) 78 | g = super(DeepLIFTRescaleExplainer, self).explain(inp, ind) 79 | 80 | return g 81 | 
# --------------------------------------------------------------------------------
# /tensorwatch/saliency/gradcam.py:
# --------------------------------------------------------------------------------
import torch
from .backprop import VanillaGradExplainer


def _get_layer(model, key_list):
    """Return the sub-module reached by following ``key_list`` through ``_modules``.

    Args:
        model: root module.
        key_list: sequence of child-module keys, or None.

    Returns:
        The sub-module, or None when ``key_list`` is None.

    Raises:
        KeyError: if a key does not name a child at that level.
    """
    if key_list is None:
        return None

    layer = model
    for key in key_list:
        layer = layer._modules[key]
    return layer


class GradCAMExplainer(VanillaGradExplainer):
    """Grad-CAM style saliency computed from one target layer.

    Args:
        model: network to explain.
        target_layer_name_keys: ``_modules`` keys leading to the target layer
            (see ``_get_layer``). None means there is no target layer, in which
            case ``explain`` returns None.
        use_inp: record the layer's input activation instead of its output.
    """

    def __init__(self, model, target_layer_name_keys=None, use_inp=False):
        super(GradCAMExplainer, self).__init__(model)
        self.target_layer = _get_layer(model, target_layer_name_keys)
        self.use_inp = use_inp
        self.intermediate_act = []
        self.intermediate_grad = []
        self.handle_forward_hook = None
        self.handle_backward_hook = None
        self._register_forward_backward_hook()

    def _register_forward_backward_hook(self):
        """Attach hooks that record the target layer's activation and output gradient."""
        def forward_hook_input(m, i, o):
            self.intermediate_act.append(i[0].data.clone())

        def forward_hook_output(m, i, o):
            self.intermediate_act.append(o.data.clone())

        def backward_hook(m, grad_i, grad_o):
            self.intermediate_grad.append(grad_o[0].data.clone())

        if self.target_layer is not None:
            if self.use_inp:
                self.handle_forward_hook = self.target_layer.register_forward_hook(forward_hook_input)
            else:
                self.handle_forward_hook = self.target_layer.register_forward_hook(forward_hook_output)

            self.handle_backward_hook = self.target_layer.register_backward_hook(backward_hook)

    def _remove_hooks(self):
        """Detach any registered hooks; safe to call when none are attached."""
        for handle in (self.handle_forward_hook, self.handle_backward_hook):
            if handle is not None:
                handle.remove()
        self.handle_forward_hook = None
        self.handle_backward_hook = None

    def _reset_intermediate_lists(self):
        self.intermediate_act = []
        self.intermediate_grad = []

    def explain(self, inp, ind=None, raw_inp=None):
        """Return the clamped class-activation map for ``inp`` and target ``ind``.

        ``raw_inp`` is accepted for interface compatibility and unused.

        Returns:
            Tensor with a singleton dim 1 (activation channels summed), or None
            when there is no target layer or no gradient was recorded.
        """
        if self.target_layer is None:
            # Nothing is hooked, so there is nothing to compute. The previous code
            # crashed here calling .remove() on a None handle.
            return None

        self._reset_intermediate_lists()
        # Hooks are detached after every call so the model is left clean.
        # Re-attach them here so repeated calls on the same explainer still work.
        if self.handle_forward_hook is None:
            self._register_forward_backward_hook()
        try:
            _ = super(GradCAMExplainer, self)._backprop(inp, ind)
        finally:
            self._remove_hooks()

        if not self.intermediate_grad or not self.intermediate_act:
            return None

        grad = self.intermediate_grad[0]
        act = self.intermediate_act[0]

        # Per-channel weights: gradient summed over the last two dims (presumably
        # spatial H, W), kept broadcastable against the activation.
        weights = grad.sum(-1).sum(-1).unsqueeze(-1).unsqueeze(-1)
        cam = (weights * act).sum(1).unsqueeze(1)
        return torch.clamp(cam, min=0)


# --------------------------------------------------------------------------------
# /tensorwatch/saliency/lime/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/tensorwatch/saliency/lime/__init__.py
# --------------------------------------------------------------------------------
# /tensorwatch/saliency/lime/wrappers/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/tensorwatch/saliency/lime/wrappers/__init__.py
# --------------------------------------------------------------------------------
# /tensorwatch/saliency/lime/wrappers/generic_utils.py:
# --------------------------------------------------------------------------------
import sys
import inspect
import types


def has_arg(fn, arg_name):
    """Checks if a callable accepts a given keyword argument.

    Args:
        fn: callable to inspect
        arg_name: string, keyword argument name to check

    Returns:
        bool, whether `fn` accepts a `arg_name` keyword argument.
    """
    if sys.version_info < (3,):
        if isinstance(fn, (types.FunctionType, types.MethodType)):
            arg_spec = inspect.getargspec(fn)
        else:
            try:
                arg_spec = inspect.getargspec(fn.__call__)
            except AttributeError:
                return False
        return arg_name in arg_spec.args
    elif sys.version_info < (3, 6):
        arg_spec = inspect.getfullargspec(fn)
        return arg_name in arg_spec.args or arg_name in arg_spec.kwonlyargs
    else:
        try:
            signature = inspect.signature(fn)
        except ValueError:
            # handling Cython: fall back to the signature of __call__
            signature = inspect.signature(fn.__call__)
        parameter = signature.parameters.get(arg_name)
        if parameter is None:
            return False
        # only parameters that can be passed by keyword count
        return parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                  inspect.Parameter.KEYWORD_ONLY)


# --------------------------------------------------------------------------------
# /tensorwatch/saliency/lime/wrappers/scikit_image.py:
# --------------------------------------------------------------------------------
import types
from .generic_utils import has_arg
from skimage.segmentation import felzenszwalb, slic, quickshift


class BaseWrapper(object):
    """Base class for LIME Scikit-Image wrapper

    Args:
        target_fn: callable function or class instance
        target_params: dict, parameters to pass to the target_fn

    'target_params' takes parameters required to instantiate the
    desired Scikit-Image class/model
    """

    def __init__(self, target_fn=None, **target_params):
        self.target_fn = target_fn
        self.target_params = target_params

    def _check_params(self, parameters):
        """Checks for mistakes in 'parameters'

        Args:
            parameters: dict, parameters to be checked

        Raises:
            ValueError: if any parameter is not a valid argument for the target function
            TypeError: if the target function is not defined and this object is not
                callable, or if 'parameters' is a string
        """
        a_valid_fn = []
        if self.target_fn is None:
            if callable(self):
                a_valid_fn.append(self.__call__)
            else:
                raise TypeError('invalid argument: tested object is not callable, '
                                'please provide a valid target_fn')
        elif isinstance(self.target_fn, (types.FunctionType, types.MethodType)):
            a_valid_fn.append(self.target_fn)
        else:
            a_valid_fn.append(self.target_fn.__call__)

        if isinstance(parameters, str):
            raise TypeError('invalid argument: list or dictionnary expected')

        for p in parameters:
            for fn in a_valid_fn:
                if not has_arg(fn, p):
                    raise ValueError('{} is not a valid parameter'.format(p))

    def set_params(self, **params):
        """Sets the parameters of this estimator.

        Args:
            **params: Dictionary of parameter names mapped to their values.

        Raises:
            ValueError: if any parameter is not a valid argument
                for the target function
        """
        self._check_params(params)
        self.target_params = params

    def filter_params(self, fn, override=None):
        """Filters `target_params` and returns those in `fn`'s arguments.

        Args:
            fn: arbitrary function
            override: dict, values to override target_params

        Returns:
            result: dict, dictionary containing variables
                in both target_params and fn's arguments.
        """
        override = override or {}
        result = {name: value for name, value in self.target_params.items()
                  if has_arg(fn, name)}
        result.update(override)
        return result


class SegmentationAlgorithm(BaseWrapper):
    """Define the image segmentation function based on Scikit-Image
    implementation and a set of provided parameters

    Args:
        algo_type: string, segmentation algorithm among the following:
            'quickshift', 'slic', 'felzenszwalb'
        target_params: dict, algorithm parameters (valid model parameters
            as defined in Scikit-Image documentation)

    Raises:
        ValueError: if algo_type is not one of the supported algorithms.
    """

    def __init__(self, algo_type, **target_params):
        self.algo_type = algo_type
        algorithms = {'quickshift': quickshift, 'felzenszwalb': felzenszwalb, 'slic': slic}
        if algo_type not in algorithms:
            # previously an unknown name left the object without target_fn,
            # failing only later on first call
            raise ValueError("Unknown segmentation algorithm '{}'; expected one of "
                             "'quickshift', 'felzenszwalb', 'slic'".format(algo_type))
        target_fn = algorithms[algo_type]
        BaseWrapper.__init__(self, target_fn, **target_params)
        # keep only the parameters the chosen algorithm actually accepts
        self.set_params(**self.filter_params(target_fn))

    def __call__(self, *args):
        return self.target_fn(args[0], **self.target_params)


# --------------------------------------------------------------------------------
# /tensorwatch/saliency/lime_image_explainer.py:
# --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from skimage.segmentation import mark_boundaries
from torchvision import transforms
import torch
from .lime import lime_image
import numpy as np
from .. import imagenet_utils, pytorch_utils, utils


class LimeImageExplainer:
    """Wraps LIME's image explainer and returns a boundary-marked image tensor."""

    def __init__(self, model, predict_fn):
        self.model = model
        self.predict_fn = predict_fn

    def preprocess_input(self, inp):
        return inp

    def preprocess_label(self, label):
        return label

    def explain(self, inp, ind=None, raw_inp=None, top_labels=5, hide_color=0, num_samples=1000,
                positive_only=True, num_features=5, hide_rest=True, pixel_val_max=255.0):
        """Explain ``raw_inp`` with LIME.

        ``inp`` is accepted for interface compatibility and unused. The remaining
        keyword arguments are forwarded to LIME; before this fix they were
        silently replaced by hard-coded values.

        Args:
            ind: label to explain. None means LIME's top predicted label. Label 0
                is honoured; the previous ``or`` check treated it as missing.
            pixel_val_max: divisor applied to LIME's image before marking boundaries.

        Returns:
            Tensor of shape (1, channels, rows, cols).
        """
        explainer = lime_image.LimeImageExplainer()
        explanation = explainer.explain_instance(self.preprocess_input(raw_inp), self.predict_fn,
                                                 top_labels=top_labels, hide_color=hide_color,
                                                 num_samples=num_samples)

        label = self.preprocess_label(ind)
        if label is None:
            label = explanation.top_labels[0]
        temp, mask = explanation.get_image_and_mask(label, positive_only=positive_only,
                                                    num_features=num_features, hide_rest=hide_rest)

        img = mark_boundaries(temp / pixel_val_max, mask)
        # (rows, cols, channels) -> (channels, rows, cols); same result as the two transposes
        img = torch.from_numpy(img).permute(2, 0, 1)
        return img.unsqueeze(0)


class LimeImagenetExplainer(LimeImageExplainer):
    """LIME explainer preconfigured for ImageNet models."""

    def __init__(self, model, predict_fn=None):
        super(LimeImagenetExplainer, self).__init__(model, predict_fn or self._imagenet_predict)

    def _preprocess_transform(self):
        transf = transforms.Compose([
            transforms.ToTensor(),
            imagenet_utils.get_normalize_transform()
        ])

        return transf

    def preprocess_input(self, inp):
        return np.array(imagenet_utils.get_resize_transform()(inp))

    def preprocess_label(self, label):
        # tensors/numpy scalars -> plain Python numbers usable as LIME label keys
        return label.item() if label is not None and utils.has_method(label, 'item') else label

    def _imagenet_predict(self, images):
        probs = imagenet_utils.predict(self.model, images, image_transform=self._preprocess_transform())
        return pytorch_utils.tensor2numpy(probs)


# --------------------------------------------------------------------------------
# /tensorwatch/saliency/occlusion.py:
# --------------------------------------------------------------------------------
import torch
import numpy as np
from torch.autograd import Variable
from skimage.util import view_as_windows

# modified from https://github.com/marcoancona/DeepExplain/blob/master/deepexplain/tensorflow/methods.py#L291-L342
# note the different dim order in pytorch (NCHW) and tensorflow (NHWC)


class OcclusionExplainer:
    """Occlusion saliency: score drop of the predicted class when a window is zeroed.

    Args:
        model: network returning (batch, classes) scores.
        window_shape: int side of a square window spanning all channels, or an
            explicit (h, w, c) tuple.
        step: stride passed to ``view_as_windows``.
    """

    def __init__(self, model, window_shape=10, step=1):
        self.model = model
        self.window_shape = window_shape
        self.step = step

    def explain(self, inp, ind=None, raw_inp=None):
        """Return an attribution tensor shaped like ``inp``.

        ``ind`` and ``raw_inp`` are accepted for interface compatibility and unused;
        each sample's own predicted class is explained.
        """
        self.model.eval()
        with torch.no_grad():
            return OcclusionExplainer._occlusion(inp, self.model, self.window_shape, self.step)

    @staticmethod
    def _occlusion(inp, model, window_shape, step=None):
        if step is None:
            step = 1
        n, c, h, w = inp.data.size()
        if isinstance(window_shape, int):
            # Square window covering every channel. This was previously hard-coded
            # to 3 channels, which failed for grayscale input.
            window_shape = (window_shape, window_shape, c)

        total_dim = c * h * w
        # Flat element indices laid out as (h, w, c), so each window picks a spatial
        # patch across channels.
        index_matrix = np.arange(total_dim).reshape(h, w, c)
        idx_patches = view_as_windows(index_matrix, window_shape, step).reshape(
            (-1,) + window_shape)
        heatmap = np.zeros((n, total_dim), dtype=np.float32)
        weights = np.zeros_like(heatmap)

        inp_data = inp.data.clone()
        eval0 = model(inp_data)
        pred_ids = eval0.max(1, keepdim=True)[1]   # (n, 1): predicted class per sample
        base_scores = eval0.gather(1, pred_ids)    # (n, 1)

        for p in idx_patches:
            flat_idx = p.flatten()
            mask = np.ones(total_dim)
            mask[flat_idx] = 0
            # Build the mask on the input's device. It was hard-coded to .cuda()
            # before, which crashed on CPU.
            th_mask = torch.from_numpy(
                mask.reshape(1, h, w, c).transpose(0, 3, 1, 2)).float().to(inp_data.device)
            masked_scores = model(th_mask * inp_data).gather(1, pred_ids)
            delta = (base_scores - masked_scores).cpu().numpy().reshape(n, 1)
            heatmap[:, flat_idx] += delta
            weights[:, flat_idx] += p.size

        attribution = np.reshape(heatmap / (weights + 1e-10), (n, h, w, c)).transpose(0, 3, 1, 2)
        return torch.from_numpy(attribution)


# --------------------------------------------------------------------------------
# /tensorwatch/saliency/saliency.py:
# --------------------------------------------------------------------------------
from .gradcam import GradCAMExplainer
from .backprop import VanillaGradExplainer, GradxInputExplainer, SaliencyExplainer, \
    IntegrateGradExplainer, DeconvExplainer, GuidedBackpropExplainer, SmoothGradExplainer
from .deeplift import DeepLIFTRescaleExplainer
from .occlusion import OcclusionExplainer
from .epsilon_lrp import EpsilonLrp
from .lime_image_explainer import LimeImageExplainer, LimeImagenetExplainer
import skimage.transform
import torch
import math
from .. import image_utils

# Default methods for get_image_saliency_results (a tuple, so the default is immutable).
_DEFAULT_IMAGE_SALIENCY_METHODS = ('lime_imagenet', 'gradcam', 'smooth_grad',
                                   'guided_backprop', 'deeplift', 'grad_x_input')

# Default Grad-CAM target layer per model class name.
_DEFAULT_LAYER_PATHS = {
    'VGG': ['features', '30'],                  # pool5
    'GoogleNet': ['pool5'],
    'ResNet': ['avgpool'],                      # layer4
    'Inception3': ['Mixed_7c', 'branch_pool'],  # ['conv2d_94'], 'mixed10'
}


class ImageSaliencyResult:
    """Container for one saliency map plus the data needed to plot it."""

    def __init__(self, raw_image, saliency, title, saliency_alpha=0.4, saliency_cmap='jet'):
        self.raw_image, self.saliency, self.title = raw_image, saliency, title
        self.saliency_alpha, self.saliency_cmap = saliency_alpha, saliency_cmap


def _get_explainer(explainer_name, model, layer_path=None):
    """Instantiate the explainer registered under ``explainer_name``.

    Raises:
        ValueError: if the name is unknown.
    """
    factories = {
        'gradcam': lambda: GradCAMExplainer(model, target_layer_name_keys=layer_path, use_inp=True),
        'vanilla_grad': lambda: VanillaGradExplainer(model),
        'grad_x_input': lambda: GradxInputExplainer(model),
        'saliency': lambda: SaliencyExplainer(model),
        'integrate_grad': lambda: IntegrateGradExplainer(model),
        'deconv': lambda: DeconvExplainer(model),
        'guided_backprop': lambda: GuidedBackpropExplainer(model),
        'smooth_grad': lambda: SmoothGradExplainer(model),
        'deeplift': lambda: DeepLIFTRescaleExplainer(model),
        'occlusion': lambda: OcclusionExplainer(model),
        'lrp': lambda: EpsilonLrp(model),
        'lime_imagenet': lambda: LimeImagenetExplainer(model),
    }
    factory = factories.get(explainer_name)
    if factory is None:
        raise ValueError('Explainer {} is not recognized'.format(explainer_name))
    return factory()


def _get_layer_path(model):
    """Return a known Grad-CAM layer path for ``model``'s class, or None."""
    # TODO: guess layer for other networks?
    path = _DEFAULT_LAYER_PATHS.get(model.__class__.__name__)
    return list(path) if path is not None else None  # copy so callers can't mutate the table


def get_saliency(model, raw_input, input, label, device=None, method='integrate_grad', layer_path=None):
    """Compute a normalized saliency map with the explainer named ``method``.

    Args:
        device: ``torch.device`` or device string. Anything else means
            "cuda if available, else cpu".
        layer_path: Grad-CAM target layer keys. Defaults to ``_get_layer_path(model)``.

    Returns:
        numpy array shifted so its minimum is 0 and scaled so its maximum is
        about 1, or None when the explainer produced nothing. The explainer
        output has dim 1 (presumably channels) summed and only sample 0 kept.
    """
    if isinstance(device, str):
        # device strings used to be silently replaced by the automatic choice
        device = torch.device(device)
    if not isinstance(device, torch.device):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)
    input = input.to(device)
    if label is not None:
        label = label.to(device)

    if input.grad is not None:
        input.grad.zero_()
    if label is not None and label.grad is not None:
        label.grad.zero_()
    model.eval()
    model.zero_grad()

    layer_path = layer_path or _get_layer_path(model)
    exp = _get_explainer(method, model, layer_path)
    saliency = exp.explain(input, label, raw_input)

    if saliency is None:
        return None

    saliency = saliency.detach().abs().sum(dim=1)[0].squeeze()
    saliency -= saliency.min()
    saliency /= (saliency.max() + 1e-20)
    return saliency.cpu().numpy()


def get_image_saliency_results(model, raw_image, input, label,
                               device=None,
                               methods=_DEFAULT_IMAGE_SALIENCY_METHODS,
                               layer_path=None):
    """Run each method in ``methods`` and collect the non-None results.

    ``layer_path`` is forwarded to ``get_saliency``; it was previously ignored.
    """
    results = []
    for method in methods:
        sal = get_saliency(model, raw_image, input, label, device=device, method=method,
                           layer_path=layer_path)
        if sal is not None:
            results.append(ImageSaliencyResult(raw_image, sal, method))
    return results


def get_image_saliency_plot(image_saliency_results, cols:int=2, figsize:tuple=None):
    """Plot each result's saliency, resized to its raw image, in a grid of ``cols`` columns."""
    import matplotlib.pyplot as plt  # delayed import due to matplotlib threading issue

    rows = math.ceil(len(image_saliency_results) / cols)
    figsize = figsize or (8, 3 * rows)
    figure = plt.figure(figsize=figsize)

    for i, r in enumerate(image_saliency_results):
        ax = figure.add_subplot(rows, cols, i + 1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(r.title, fontdict={'fontsize': 24})  # 'fontweight': 'light'

        saliency_upsampled = skimage.transform.resize(r.saliency,
                                                      (r.raw_image.height, r.raw_image.width),
                                                      mode='reflect')

        image_utils.show_image(r.raw_image, img2=saliency_upsampled,
                               alpha2=r.saliency_alpha, cmap2=r.saliency_cmap, ax=ax)
    return figure


# --------------------------------------------------------------------------------
# /tensorwatch/stream.py:
# --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import weakref, uuid
from typing import Any
from . import utils
from .lv_types import StreamItem


class Stream:
    """
    Stream allows you to write values into it. One stream can subscribe to many streams. When a value is
    written in the stream, all subscribers also get that value written into them (the from_stream parameter
    is set to the source stream).

    You can read values from a stream by calling the read_all method, which yields the values produced by
    the streams this stream is subscribed to. You can also load all values from the stream. The default
    stream won't load anything but derived streams like file may load all values from the file and write
    into themselves.

    You can think of stream as a pipe that can be chained to other pipes. As value is put in pipe, it travels
    through connected pipes. The read_all method is like a tap that you can use to read values from the pipe.
    The load method is for specialized streams that may generate values in that pipe.

    Subscriber sets are iterated over snapshots, so subscribing or unsubscribing while values are flowing
    does not break iteration. No locks are taken here.
    """
    def __init__(self, stream_name:str=None, console_debug:bool=False):
        self._subscribers = weakref.WeakSet()
        self._subscribed_to = weakref.WeakSet()
        self.held_refs = set()  # on some rare occasion we might want stream to hold references of other streams
        self.closed = False
        self.console_debug = console_debug
        self.stream_name = stream_name or str(uuid.uuid4())  # useful to use as key and avoid circular references
        self.items_written = 0

    def subscribe(self, stream:'Stream'):  # notify other stream
        utils.debug_log('{} added {} as subscription'.format(self.stream_name, stream.stream_name))
        stream._subscribers.add(self)
        self._subscribed_to.add(stream)

    def unsubscribe(self, stream:'Stream'):
        utils.debug_log('{} removed {} as subscription'.format(self.stream_name, stream.stream_name))
        stream._subscribers.discard(self)
        self._subscribed_to.discard(stream)
        self.held_refs.discard(stream)
        # stream.held_refs is not touched: only the subscriber should hold a ref

    def to_stream_item(self, val:Any):
        """Wrap ``val`` in a StreamItem (if needed), filling in stream name and item index."""
        stream_item = val if isinstance(val, StreamItem) else \
            StreamItem(value=val, stream_name=self.stream_name)
        if stream_item.stream_name is None:
            stream_item.stream_name = self.stream_name
        if stream_item.item_index is None:
            stream_item.item_index = self.items_written
        return stream_item

    def write(self, val:Any, from_stream:'Stream'=None):
        # If you override write, first call self.to_stream_item so it can
        # stamp the stream name on the item.
        stream_item = self.to_stream_item(val)

        if self.console_debug:
            print(self.stream_name, stream_item)

        # Iterate a snapshot: a subscriber's write may (un)subscribe streams, and
        # mutating a WeakSet while iterating it raises RuntimeError.
        for subscriber in list(self._subscribers):
            subscriber.write(stream_item, from_stream=self)
        self.items_written += 1

    def read_all(self, from_stream:'Stream'=None):
        for subscribed_to in list(self._subscribed_to):
            for stream_item in subscribed_to.read_all(from_stream=self):
                yield stream_item

    def load(self, from_stream:'Stream'=None):
        for subscribed_to in list(self._subscribed_to):
            subscribed_to.load(from_stream=self)

    def save(self, from_stream:'Stream'=None):
        for subscriber in list(self._subscribers):
            subscriber.save(from_stream=self)

    def close(self):
        """Detach from every stream this one subscribed to (idempotent)."""
        if not self.closed:
            for subscribed_to in list(self._subscribed_to):
                subscribed_to._subscribers.discard(self)
            self._subscribed_to.clear()
            self.closed = True

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        self.close()


# --------------------------------------------------------------------------------
# /tensorwatch/stream_factory.py:
# --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Dict, Sequence, List, Any
from .zmq_stream import ZmqStream
from .file_stream import FileStream
from .stream import Stream
from .stream_union import StreamUnion
import uuid
from . import utils


class StreamFactory:
    r"""Allows to create shared stream such as file and ZMQ streams
    """

    def __init__(self)->None:
        self.closed = None
        self._streams:Dict[str, Stream] = None
        self._reset()

    def _reset(self):
        self._streams:Dict[str, Stream] = {}
        self.closed = False

    def close(self):
        """Close every cached stream and forget them (idempotent)."""
        if not self.closed:
            for stream in self._streams.values():
                stream.close()
            self._reset()
            self.closed = True

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        self.close()

    def get_streams(self, stream_types:Sequence[str], for_write:bool=None)->List[Stream]:
        return [self._create_stream_by_string(stream_type, for_write) for stream_type in stream_types]

    def get_combined_stream(self, stream_types:Sequence[str], for_write:bool=None)->Stream:
        """Return the single stream for one spec, or a StreamUnion over several."""
        streams = self.get_streams(stream_types, for_write)
        if len(streams) == 1:
            # was self._streams[0]: a KeyError, since _streams is keyed by stream name
            return streams[0]
        # we create new union of child but this is not necessary
        return StreamUnion(streams, for_write=for_write)

    @staticmethod
    def _get_stream_name(stream_type:str, stream_args:Any, for_write:bool)->str:
        return '{}:{}:{}'.format(stream_type, stream_args, for_write)

    def _create_stream_by_string(self, stream_spec:str, for_write:bool)->Stream:
        """Create (or reuse) the stream described by ``stream_spec``.

        Spec forms:
            None or ''            -> new in-memory Stream (not cached)
            'tcp' or 'tcp:<port>' -> shared ZmqStream (port defaults to 0)
            'file:<path>', '<path>' -> FileStream; read streams are never shared
            'C:/dir/file'         -> a single-letter prefix is a Windows drive, i.e. a file
            ':<anything>'         -> new in-memory Stream

        Raises:
            ValueError: for an unknown prefix, or 'file' with no file name.
        """
        parts = stream_spec.split(':', 1) if stream_spec is not None else ['']
        stream_type = parts[0]
        stream_args = parts[1] if len(parts) > 1 else None

        utils.debug_log("Creating stream", (stream_spec, for_write))

        if stream_type == '' and stream_args is None:
            # No spec at all. This used to fall through to the bare-file-name rule
            # below and try to open a file named ''.
            return Stream()

        if stream_type == 'tcp':
            port = int(stream_args or 0)
            stream_name = StreamFactory._get_stream_name(stream_type, port, for_write)
            if stream_name not in self._streams:
                self._streams[stream_name] = ZmqStream(for_write=for_write,
                    port=port, stream_name=stream_name, block_until_connected=False)
            # else we already have this stream
            return self._streams[stream_name]

        if stream_args is None:  # file name specified without 'file:' prefix
            stream_args = stream_type
            stream_type = 'file'
        if len(stream_type) == 1:  # windows drive letter
            stream_type = 'file'
            stream_args = stream_spec

        if stream_type == 'file':
            if stream_args is None:
                raise ValueError('File name must be specified for stream type "file"')
            stream_name = StreamFactory._get_stream_name(stream_type, stream_args, for_write)
            # Each read-only file stream must be separate; sharing one would
            # change seek positions.
            if not for_write:
                stream_name += ':' + str(uuid.uuid4())

                # if a write stream for this file exists, flush it before the read stream reads it
                write_stream_name = StreamFactory._get_stream_name(stream_type, stream_args, True)
                write_file_stream = self._streams.get(write_stream_name, None)
                if write_file_stream:
                    write_file_stream.save()
            if stream_name not in self._streams:
                self._streams[stream_name] = FileStream(for_write=for_write,
                    file_name=stream_args, stream_name=stream_name)
            # else we already have this stream
            return self._streams[stream_name]

        if stream_type == '':
            return Stream()

        raise ValueError('stream_type "{}" has unknown type'.format(stream_type))


# --------------------------------------------------------------------------------
# /tensorwatch/stream_union.py:
# --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .stream import Stream
from typing import Iterator


class StreamUnion(Stream):
    """Stream joining several child streams.

    With ``for_write`` truthy, the children subscribe to this stream, so values
    written here reach all of them. Otherwise this stream subscribes to every
    child, so read_all/load see the union of the children.
    """
    def __init__(self, child_streams:Iterator[Stream], for_write:bool, stream_name:str=None, console_debug:bool=False) -> None:
        super(StreamUnion, self).__init__(stream_name=stream_name, console_debug=console_debug)

        # Save strong references: child streams go away only if the parent goes away.
        # Materialize first; a one-shot iterator would otherwise be exhausted by the
        # loop below, leaving nothing referenced here.
        self.child_streams = list(child_streams)

        if for_write:
            # when someone writes to us, we write to all our listeners
            for child_stream in self.child_streams:
                child_stream.subscribe(self)
        else:
            # union of all child streams
            for child_stream in self.child_streams:
                self.subscribe(child_stream)


# --------------------------------------------------------------------------------
# /tensorwatch/tensor_utils.py:
# --------------------------------------------------------------------------------
from typing import Sized, Any, List
import numbers
import numpy as np


class TensorType:
    Torch = 'torch'
    TF = 'tf'
    Numeric = 'numeric'
    Numpy = 'numpy'
    Other = 'other'


def _unsupported_type_error(target:str, tt:str)->ValueError:
    """Build the error raised when list elements have an unsupported tensor type."""
    return ValueError('Cannot convert tensor list to {} list because list element '
                      'are of unsupported type {}'.format(target, tt))


def tensor_type(item:Any)->str:
    """Classify ``item`` by its class module and name into a TensorType value."""
    module_name = type(item).__module__
    class_name = type(item).__name__

    if module_name == 'torch' and class_name == 'Tensor':
        return TensorType.Torch
    elif module_name.startswith('tensorflow') and class_name == 'EagerTensor':
        return TensorType.TF
    elif isinstance(item, numbers.Number):
        return TensorType.Numeric
    elif isinstance(item, np.ndarray):
        return TensorType.Numpy
    else:
        return TensorType.Other


def tensor2scaler(item:Any)->numbers.Number:
    tt = tensor_type(item)
    if item is None or tt == TensorType.Numeric:
        return item
    if tt == TensorType.TF:
        item = item.numpy()
    return item.item()  # works for torch and numpy


def tensor2np(item:Any)->np.ndarray:
    tt = tensor_type(item)
    if item is None or tt == TensorType.Numpy:
        return item
    elif tt == TensorType.TF:
        return item.numpy()
    elif tt == TensorType.Torch:
        return item.data.cpu().numpy()
    else:  # numeric and everything else, let np take care
        return np.array(item)


def to_scaler_list(l:Sized)->List[numbers.Number]:
    """Create list of scalars for given list of tensors where each element is 0-dim tensor.

    The element type is decided from ``l[0]``. Returns None for None and [] for
    an empty input.
    """
    if l is not None and len(l):
        tt = tensor_type(l[0])
        if tt == TensorType.Torch or tt == TensorType.Numpy:
            return [i.item() for i in l]
        elif tt == TensorType.TF:
            return [i.numpy().item() for i in l]
        elif tt == TensorType.Numeric:
            # convert to list in case l is not list type
            return list(l)
        else:
            raise _unsupported_type_error('scaler', tt)
    else:
        return None if l is None else []  # make sure we always return list type


def to_mean_list(l:Sized)->List[float]:
    """Create list of Python floats, the mean of each tensor in ``l``.

    The element type is decided from ``l[0]``. Returns None for None and [] for
    an empty input. Torch means used to be returned as 0-dim tensors.
    """
    if l is not None and len(l):
        tt = tensor_type(l[0])
        if tt == TensorType.Torch:
            return [i.mean().item() for i in l]
        elif tt == TensorType.Numpy:
            return [float(i.mean()) for i in l]
        elif tt == TensorType.TF:
            return [float(i.numpy().mean()) for i in l]
        elif tt == TensorType.Numeric:
            # convert to list in case l is not list type
            return [float(i) for i in l]
        else:
            raise _unsupported_type_error('scaler', tt)
    else:
        return None if l is None else []


def to_np_list(l:Sized)->List[np.ndarray]:
    """Convert each tensor in ``l`` to a numpy array (type decided from ``l[0]``)."""
    if l is not None and len(l):
        tt = tensor_type(l[0])
        if tt == TensorType.Numeric:
            return [np.array(i) for i in l]
        if tt == TensorType.TF:
            return [i.numpy() for i in l]
        if tt == TensorType.Torch:
            return [i.data.cpu().numpy() for i in l]
        if tt == TensorType.Numpy:
            return list(l)
        raise _unsupported_type_error('numpy', tt)
    else:
        return None if l is None else []


# --------------------------------------------------------------------------------
# /tensorwatch/text_vis.py:
# --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from . import utils
from .vis_base import VisBase


class TextVis(VisBase):
    """Tabular text visualizer: each stream value becomes a row of a pandas DataFrame."""

    def __init__(self, cell:VisBase.widgets.Box=None, title:str=None, show_legend:bool=None,
                 stream_name:str=None, console_debug:bool=False, **vis_args):
        import pandas as pd  # expensive import

        super(TextVis, self).__init__(VisBase.widgets.HTML(), cell, title, show_legend,
                                      stream_name=stream_name, console_debug=console_debug, **vis_args)
        self.df = pd.DataFrame([])
        self.SeriesClass = pd.Series  # kept for backward compatibility

    def _get_column_prefix(self, stream_vis, i):
        return '[S.{}]:{}'.format(stream_vis.index, i)

    def _get_title(self, stream_vis):
        title = stream_vis.title or 'Stream ' + str(len(self._stream_vises))
        return title

    def _append_row(self, row:dict):
        """Append one row (column name -> value) to ``self.df``.

        Uses ``pandas.concat`` because ``DataFrame.append`` was removed in pandas 2.0.
        """
        import pandas as pd  # cached in sys.modules by __init__
        row_df = pd.DataFrame([row])
        if len(self.df) == 0 and len(self.df.columns) == 0:
            # nothing to merge with yet; avoids concat's empty-frame dtype warnings
            self.df = row_df
        else:
            self.df = pd.concat([self.df, row_df], sort=False, ignore_index=True)

    # this will be called from _show_stream_items
    def _append(self, stream_vis, vals):
        if vals is None:
            self._append_row({self._get_column_prefix(stream_vis, 0): None})
            return
        for val in vals:
            if val is None or utils.is_scalar(val):
                self._append_row({self._get_column_prefix(stream_vis, 0): val})
            elif utils.is_array_like(val):
                self._append_row({self._get_column_prefix(stream_vis, i): val_i
                                  for i, val_i in enumerate(val)})
            else:
                self._append_row(val.__dict__)

    def _post_add_subscription(self, stream_vis, **stream_vis_args):
        only_summary = stream_vis_args.get('only_summary', False)
        stream_vis.text = self._get_title(stream_vis)
        stream_vis.only_summary = only_summary

    def clear_plot(self, stream_vis, clear_history):
        self.df = self.df.iloc[0:0]

    def _show_stream_items(self, stream_vis, stream_items):
        """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
        """
        for stream_item in stream_items:
            if stream_item.ended:
                self._append_row({'Ended': True})
            else:
                vals = self._extract_vals((stream_item,))
                self._append(stream_vis, vals)
        return False  # dirty

    def _post_update_stream_plot(self, stream_vis):
        if VisBase.get_ipython():
            table = self.df.describe() if stream_vis.only_summary else self.df
            self.widget.value = table.to_html(classes=['output_html', 'rendered_html'])
            # Rendering through widget.clear_output()/display() doesn't work
            # because of a threading issue, so the HTML value is set directly.
        else:
            if len(self.df) == 0:
                return  # nothing to print yet (e.g. right after clear_plot)
            last_recs = self.df.iloc[[-1]].to_dict('records')
            if len(last_recs) == 1:
                print(last_recs[0])
            else:
                print(last_recs)

    def _show_widget_native(self, blocking:bool):
        return None  # we will be using console

    def _show_widget_notebook(self):
        return self.widget


# --------------------------------------------------------------------------------
# /tensorwatch/visualizer.py:
# --------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
3 | 4 | from .stream import Stream 5 | from .vis_base import VisBase 6 | from . import mpl 7 | from . import plotly 8 | 9 | class Visualizer: 10 | """Constructs visualizer for specified vis_type. 11 | 12 | NOTE: If you modify arguments here then also sync VisArgs contructor. 13 | """ 14 | def __init__(self, stream:Stream, vis_type:str=None, host:'Visualizer'=None, 15 | cell:'Visualizer'=None, title:str=None, 16 | clear_after_end=False, clear_after_each=False, history_len=1, dim_history=True, opacity=None, 17 | 18 | rows=2, cols=5, img_width=None, img_height=None, img_channels=None, 19 | colormap=None, viz_img_scale=None, 20 | 21 | # these image params are for hover on point for t-sne 22 | hover_images=None, hover_image_reshape=None, cell_width:str=None, cell_height:str=None, 23 | 24 | only_summary=False, separate_yaxis=True, xtitle=None, ytitle=None, ztitle=None, color=None, 25 | xrange=None, yrange=None, zrange=None, draw_line=True, draw_marker=False, 26 | 27 | # histogram 28 | bins=None, normed=None, histtype='bar', edge_color=None, linewidth=None, bar_width=None, 29 | 30 | # pie chart 31 | autopct=None, shadow=None, 32 | 33 | vis_args={}, stream_vis_args={})->None: 34 | 35 | cell = cell._host_base.cell if cell is not None else None 36 | 37 | if host: 38 | self._host_base = host._host_base 39 | else: 40 | self._host_base = self._get_vis_base(vis_type, cell, title, hover_images=hover_images, hover_image_reshape=hover_image_reshape, 41 | cell_width=cell_width, cell_height=cell_height, 42 | **vis_args) 43 | 44 | self._host_base.subscribe(stream, show=False, clear_after_end=clear_after_end, clear_after_each=clear_after_each, 45 | history_len=history_len, dim_history=dim_history, opacity=opacity, 46 | only_summary=only_summary if vis_type is None or 'summary' != vis_type else True, 47 | separate_yaxis=separate_yaxis, xtitle=xtitle, ytitle=ytitle, ztitle=ztitle, color=color, 48 | xrange=xrange, yrange=yrange, zrange=zrange, 49 | draw_line=draw_line if vis_type is not 
None and 'scatter' in vis_type else True, 50 | draw_marker=draw_marker, 51 | rows=rows, cols=cols, img_width=img_width, img_height=img_height, img_channels=img_channels, 52 | colormap=colormap, viz_img_scale=viz_img_scale, 53 | bins=bins, normed=normed, histtype=histtype, edge_color=edge_color, linewidth=linewidth, bar_width = bar_width, 54 | autopct=autopct, shadow=shadow, 55 | **stream_vis_args) 56 | 57 | stream.load() 58 | 59 | def show(self): 60 | return self._host_base.show() 61 | 62 | def save(self, filepath:str)->None: 63 | self._host_base.save(filepath) 64 | 65 | def _get_vis_base(self, vis_type, cell:VisBase.widgets.Box, title, hover_images=None, hover_image_reshape=None, cell_width=None, cell_height=None, **vis_args)->VisBase: 66 | if vis_type is None or vis_type in ['line', 67 | 'mpl-line', 'mpl-line3d', 'mpl-scatter3d', 'mpl-scatter']: 68 | return mpl.line_plot.LinePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, 69 | is_3d=vis_type is not None and vis_type.endswith('3d'), **vis_args) 70 | if vis_type in ['image', 'mpl-image']: 71 | return mpl.image_plot.ImagePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) 72 | if vis_type in ['bar', 'bar3d']: 73 | return mpl.bar_plot.BarPlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, 74 | is_3d=vis_type.endswith('3d'), **vis_args) 75 | if vis_type in ['histogram']: 76 | return mpl.histogram.Histogram(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) 77 | if vis_type in ['pie']: 78 | return mpl.pie_chart.PieChart(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) 79 | if vis_type in ['text', 'summary']: 80 | from .text_vis import TextVis 81 | return TextVis(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) 82 | if vis_type in ['line3d', 'scatter', 'scatter3d', 83 | 'plotly-line', 'plotly-line3d', 'plotly-scatter', 
'plotly-scatter3d', 'mesh3d']: 84 | return plotly.line_plot.LinePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, 85 | is_3d=vis_type.endswith('3d'), **vis_args) 86 | if vis_type in ['tsne', 'embeddings', 'tsne2d', 'embeddings2d']: 87 | return plotly.embeddings_plot.EmbeddingsPlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, 88 | is_3d='2d' not in vis_type, 89 | hover_images=hover_images, hover_image_reshape=hover_image_reshape, **vis_args) 90 | else: 91 | raise ValueError('Render vis_type parameter has invalid value: "{}"'.format(vis_type)) 92 | -------------------------------------------------------------------------------- /tensorwatch/watcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import uuid 5 | from typing import Sequence 6 | from .zmq_wrapper import ZmqWrapper 7 | from .watcher_base import WatcherBase 8 | from .lv_types import CliSrvReqTypes 9 | from .lv_types import DefaultPorts, PublisherTopics, ServerMgmtMsg 10 | from . 
import utils 11 | import threading, time 12 | 13 | class Watcher(WatcherBase): 14 | def __init__(self, filename:str=None, port:int=0, srv_name:str=None): 15 | super(Watcher, self).__init__() 16 | 17 | self.port = port 18 | self.filename = filename 19 | 20 | # used to detect server restarts 21 | self.srv_name = srv_name or str(uuid.uuid4()) 22 | 23 | # define vars in __init__ 24 | self._clisrv = None 25 | self._zmq_stream_pub = None 26 | self._file = None 27 | self._th = None 28 | 29 | self._open_devices() 30 | 31 | def _open_devices(self): 32 | if self.port is not None: 33 | self._clisrv = ZmqWrapper.ClientServer(port=DefaultPorts.CliSrv+self.port, 34 | is_server=True, callback=self._clisrv_callback) 35 | 36 | # notify existing listeners of our ID 37 | self._zmq_stream_pub = self._stream_factory.get_streams(stream_types=['tcp:'+str(self.port)], for_write=True)[0] 38 | 39 | # ZMQ quirk: we must wait a bit after opening port and before sending message 40 | # TODO: can we do better? 41 | self._th = threading.Thread(target=self._send_server_start) 42 | self._th.start() 43 | 44 | def _send_server_start(self): 45 | time.sleep(2) 46 | self._zmq_stream_pub.write(ServerMgmtMsg(event_name=ServerMgmtMsg.EventServerStart, 47 | event_args=self.srv_name), topic=PublisherTopics.ServerMgmt) 48 | 49 | def devices_or_default(self, devices:Sequence[str])->Sequence[str]: # overriden 50 | # TODO: this method is duplicated in Watcher and WatcherClient 51 | 52 | # make sure TCP port is attached to tcp device 53 | if devices is not None: 54 | return ['tcp:' + str(self.port) if device=='tcp' else device for device in devices] 55 | 56 | # if no devices specified then use our filename and tcp:port as default devices 57 | devices = [] 58 | # first open file device because it may have older data 59 | if self.filename is not None: 60 | devices.append('file:' + self.filename) 61 | if self.port is not None: 62 | devices.append('tcp:' + str(self.port)) 63 | return devices 64 | 65 | def 
close(self): 66 | if not self.closed: 67 | if self._clisrv is not None: 68 | self._clisrv.close() 69 | if self._zmq_stream_pub is not None: 70 | self._zmq_stream_pub.close() 71 | if self._file is not None: 72 | self._file.close() 73 | utils.debug_log("Watcher is closed", verbosity=1) 74 | super(Watcher, self).close() 75 | 76 | def _reset(self): 77 | self._clisrv = None 78 | self._zmq_stream_pub = None 79 | self._file = None 80 | self._th = None 81 | utils.debug_log("Watcher reset", verbosity=1) 82 | super(Watcher, self)._reset() 83 | 84 | def _clisrv_callback(self, clisrv, clisrv_req): # pylint: disable=unused-argument 85 | utils.debug_log("Received client request", clisrv_req.req_type) 86 | 87 | # request = create stream 88 | if clisrv_req.req_type == CliSrvReqTypes.create_stream: 89 | stream_req = clisrv_req.req_data 90 | self.create_stream(name=stream_req.stream_name, devices=stream_req.devices, 91 | event_name=stream_req.event_name, expr=stream_req.expr, throttle=stream_req.throttle, 92 | vis_args=stream_req.vis_args) 93 | return None # ignore return as we can't send back stream obj 94 | elif clisrv_req.req_type == CliSrvReqTypes.del_stream: 95 | stream_name = clisrv_req.req_data 96 | return self.del_stream(stream_name) 97 | else: 98 | raise ValueError('ClientServer Request Type {} is not recognized'.format(clisrv_req)) -------------------------------------------------------------------------------- /tensorwatch/watcher_client.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from typing import Any, Dict, Sequence, List 5 | from .zmq_wrapper import ZmqWrapper 6 | from .lv_types import CliSrvReqTypes, ClientServerRequest, DefaultPorts 7 | from .lv_types import VisArgs, PublisherTopics, ServerMgmtMsg, StreamCreateRequest 8 | from .stream import Stream 9 | from .zmq_mgmt_stream import ZmqMgmtStream 10 | from . 
import utils 11 | from .watcher_base import WatcherBase 12 | 13 | class WatcherClient(WatcherBase): 14 | r"""Extends watcher to add methods so calls for create and delete stream can be sent to server. 15 | """ 16 | def __init__(self, filename:str=None, port:int=0): 17 | super(WatcherClient, self).__init__() 18 | self.port = port 19 | self.filename = filename 20 | 21 | # define vars in __init__ 22 | self._clisrv = None # client-server sockets allows to send create/del stream requests 23 | self._zmq_srvmgmt_sub = None 24 | self._file = None 25 | 26 | self._open() 27 | 28 | def _reset(self): 29 | self._clisrv = None 30 | self._zmq_srvmgmt_sub = None 31 | self._file = None 32 | utils.debug_log("WatcherClient reset", verbosity=1) 33 | super(WatcherClient, self)._reset() 34 | 35 | def _open(self): 36 | if self.port is not None: 37 | self._clisrv = ZmqWrapper.ClientServer(port=DefaultPorts.CliSrv+self.port, 38 | is_server=False) 39 | # create subscription where we will receive server management events 40 | self._zmq_srvmgmt_sub = ZmqMgmtStream(clisrv=self._clisrv, for_write=False, port=self.port, 41 | stream_name='zmq_srvmgmt_sub:'+str(self.port)+':False') 42 | 43 | def close(self): 44 | if not self.closed: 45 | self._zmq_srvmgmt_sub.close() 46 | self._clisrv.close() 47 | utils.debug_log("WatcherClient is closed", verbosity=1) 48 | super(WatcherClient, self).close() 49 | 50 | def devices_or_default(self, devices:Sequence[str])->Sequence[str]: # overridden 51 | # TODO: this method is duplicated in Watcher and WatcherClient 52 | 53 | # make sure TCP port is attached to tcp device 54 | if devices is not None: 55 | return ['tcp:' + str(self.port) if device=='tcp' else device for device in devices] 56 | 57 | # if no devices specified then use our filename and tcp:port as default devices 58 | devices = [] 59 | # first open file device because it may have older data 60 | if self.filename is not None: 61 | devices.append('file:' + self.filename) 62 | if self.port is not None: 63 
| devices.append('tcp:' + str(self.port)) 64 | return devices 65 | 66 | # override to send request to server, instead of underlying WatcherBase base class 67 | def create_stream(self, name:str=None, devices:Sequence[str]=None, event_name:str='', 68 | expr=None, throttle:float=1, vis_args:VisArgs=None)->Stream: # overriden 69 | 70 | stream_req = StreamCreateRequest(stream_name=name, devices=self.devices_or_default(devices), 71 | event_name=event_name, expr=expr, throttle=throttle, vis_args=vis_args) 72 | 73 | self._zmq_srvmgmt_sub.add_stream_req(stream_req) 74 | 75 | if stream_req.devices is not None: 76 | stream = self.open_stream(name=stream_req.stream_name, devices=stream_req.devices) 77 | else: # we cannot return remote streams that are not backed by a device 78 | stream = None 79 | return stream 80 | 81 | # override to set devices default to tcp 82 | def open_stream(self, name:str=None, devices:Sequence[str]=None)->Stream: # overriden 83 | return super(WatcherClient, self).open_stream(name=name, devices=devices) 84 | 85 | 86 | # override to send request to server 87 | def del_stream(self, name:str) -> None: 88 | self._zmq_srvmgmt_sub.del_stream(name) 89 | 90 | -------------------------------------------------------------------------------- /tensorwatch/zmq_mgmt_stream.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from typing import Any 5 | from .zmq_stream import ZmqStream 6 | from .lv_types import PublisherTopics, ServerMgmtMsg, StreamCreateRequest 7 | from .zmq_wrapper import ZmqWrapper 8 | from .lv_types import CliSrvReqTypes, ClientServerRequest 9 | from . 
import utils 10 | 11 | class ZmqMgmtStream(ZmqStream): 12 | # default topic is mgmt 13 | def __init__(self, clisrv:ZmqWrapper.ClientServer, for_write:bool, port:int=0, topic=PublisherTopics.ServerMgmt, block_until_connected=True, 14 | stream_name:str=None, console_debug:bool=False): 15 | super(ZmqMgmtStream, self).__init__(for_write=for_write, port=port, topic=topic, 16 | block_until_connected=block_until_connected, stream_name=stream_name, console_debug=console_debug) 17 | 18 | self._clisrv = clisrv 19 | self._stream_reqs:Dict[str,StreamCreateRequest] = {} 20 | 21 | def write(self, val:Any, from_stream:'Stream'=None): 22 | r"""Handles server management events. 23 | """ 24 | stream_item = self.to_stream_item(val) 25 | mgmt_msg = stream_item.value 26 | 27 | utils.debug_log("Received - SeverMgmtevent", mgmt_msg) 28 | # if server was restarted then send create stream requests again 29 | if mgmt_msg.event_name == ServerMgmtMsg.EventServerStart: 30 | for stream_req in self._stream_reqs.values(): 31 | self._send_create_stream(stream_req) 32 | 33 | super(ZmqMgmtStream, self).write(stream_item) 34 | 35 | def add_stream_req(self, stream_req:StreamCreateRequest)->None: 36 | self._send_create_stream(stream_req) 37 | 38 | # save this for later for resend if server restarts 39 | self._stream_reqs[stream_req.stream_name] = stream_req 40 | 41 | # override to send request to server 42 | def del_stream(self, name:str) -> None: 43 | clisrv_req = ClientServerRequest(CliSrvReqTypes.del_stream, name) 44 | self._clisrv.send_obj(clisrv_req) 45 | self._stream_reqs.pop(name, None) 46 | 47 | def _send_create_stream(self, stream_req): 48 | utils.debug_log("sending create streamreq...") 49 | clisrv_req = ClientServerRequest(CliSrvReqTypes.create_stream, stream_req) 50 | self._clisrv.send_obj(clisrv_req) 51 | utils.debug_log("sent create streamreq") 52 | 53 | def close(self): 54 | if not self.closed: 55 | self._stream_reqs = {} 56 | self._clisrv = None 57 | utils.debug_log('ZmqMgmtStream is 
closed', verbosity=1) 58 | super(ZmqMgmtStream, self).close() 59 | -------------------------------------------------------------------------------- /tensorwatch/zmq_stream.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from typing import Any 5 | from .zmq_wrapper import ZmqWrapper 6 | from .stream import Stream 7 | from .lv_types import DefaultPorts, PublisherTopics 8 | from . import utils 9 | 10 | # on writes send data on ZMQ transport 11 | class ZmqStream(Stream): 12 | def __init__(self, for_write:bool, port:int=0, topic=PublisherTopics.StreamItem, block_until_connected=True, 13 | stream_name:str=None, console_debug:bool=False): 14 | super(ZmqStream, self).__init__(stream_name=stream_name, console_debug=console_debug) 15 | 16 | self.for_write = for_write 17 | self._zmq = None 18 | 19 | self.topic = topic 20 | self._open(for_write, port, block_until_connected) 21 | utils.debug_log('ZmqStream started', verbosity=1) 22 | 23 | def _open(self, for_write:bool, port:int, block_until_connected:bool): 24 | if for_write: 25 | self._zmq = ZmqWrapper.Publication(port=DefaultPorts.PubSub+port, 26 | block_until_connected=block_until_connected) 27 | else: 28 | self._zmq = ZmqWrapper.Subscription(port=DefaultPorts.PubSub+port, 29 | topic=self.topic, callback=self._on_subscription_item) 30 | 31 | def close(self): 32 | if not self.closed: 33 | self._zmq.close() 34 | self._zmq = None 35 | utils.debug_log('ZmqStream is closed', verbosity=1) 36 | super(ZmqStream, self).close() 37 | 38 | def _on_subscription_item(self, val:Any): 39 | utils.debug_log('Received subscription item', verbosity=5) 40 | self.write(val) 41 | 42 | def write(self, val:Any, from_stream:'Stream'=None, topic=None): 43 | stream_item = self.to_stream_item(val) 44 | 45 | if self.for_write: 46 | topic = topic or self.topic 47 | utils.debug_log('Sent subscription item', verbosity=5) 48 | 
self._zmq.send_obj(stream_item, topic) 49 | # else if this was opened for read then we have subscription and 50 | # we shouldn't be calling send_obj 51 | super(ZmqStream, self).write(stream_item) 52 | -------------------------------------------------------------------------------- /test/components/circ_ref.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import objgraph, time #pip install objgraph 3 | 4 | cli = tw.WatcherClient() 5 | time.sleep(10) 6 | del cli 7 | 8 | import gc 9 | gc.collect() 10 | 11 | import time 12 | time.sleep(2) 13 | 14 | objgraph.show_backrefs(objgraph.by_type('WatcherClient'), refcounts=True, filename='b.png') 15 | 16 | -------------------------------------------------------------------------------- /test/components/evaler.py: -------------------------------------------------------------------------------- 1 | #import tensorwatch as tw 2 | from tensorwatch import evaler 3 | 4 | 5 | e = evaler.Evaler(expr='reduce(lambda x,y: (x+y), map(lambda x:(x**2), filter(lambda x: x%2==0, l)))') 6 | for i in range(5): 7 | eval_return = e.post(i) 8 | print(i, eval_return) 9 | eval_return = e.post(ended=True) 10 | print(i, eval_return) 11 | -------------------------------------------------------------------------------- /test/components/file_only_test.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | 3 | def writer(): 4 | watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None) 5 | 6 | with watcher.create_stream('metric1') as stream1: 7 | for i in range(3): 8 | stream1.write((i, i*i)) 9 | 10 | with watcher.create_stream('metric2') as stream2: 11 | for i in range(3): 12 | stream2.write((i, i*i*i)) 13 | 14 | def reader1(): 15 | print('---------------------------reader1---------------------------') 16 | watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None) 17 | 18 | stream1 = watcher.open_stream('metric1') 19 | 
stream1.console_debug = True 20 | stream1.load() 21 | 22 | stream2 = watcher.open_stream('metric2') 23 | stream2.console_debug = True 24 | stream2.load() 25 | 26 | def reader2(): 27 | print('---------------------------reader2---------------------------') 28 | 29 | watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None) 30 | 31 | stream1 = watcher.open_stream('metric1') 32 | for item in stream1.read_all(): 33 | print(item) 34 | 35 | stream2 = watcher.open_stream('metric2') 36 | for item in stream2.read_all(): 37 | print(item) 38 | 39 | def reader3(): 40 | print('---------------------------reader3---------------------------') 41 | 42 | watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None) 43 | stream1 = watcher.open_stream('metric1') 44 | stream2 = watcher.open_stream('metric2') 45 | 46 | vis1 = tw.Visualizer(stream1, vis_type='line') 47 | vis2 = tw.Visualizer(stream2, vis_type='line', host=vis1) 48 | 49 | vis1.show() 50 | 51 | tw.plt_loop() 52 | 53 | 54 | writer() 55 | reader1() 56 | reader2() 57 | reader3() 58 | 59 | -------------------------------------------------------------------------------- /test/components/notebook_maker.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | 3 | 4 | def main(): 5 | w = tw.Watcher() 6 | s1 = w.create_stream() 7 | s2 = w.create_stream(name='accuracy', vis_args=tw.VisArgs(vis_type='line', xtitle='X-Axis', clear_after_each=False, history_len=2)) 8 | s3 = w.create_stream(name='loss', expr='lambda d:d.loss') 9 | w.make_notebook() 10 | 11 | main() 12 | 13 | -------------------------------------------------------------------------------- /test/components/stream.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.stream import Stream 2 | 3 | 4 | s1 = Stream(stream_name='s1', console_debug=True) 5 | s2 = Stream(stream_name='s2', console_debug=True) 6 | s3 = Stream(stream_name='s3', console_debug=True) 7 | 
8 | s1.subscribe(s2) 9 | s2.subscribe(s3) 10 | 11 | s3.write('S3 wrote this') 12 | s2.write('S2 wrote this') 13 | s1.write('S1 wrote this') 14 | 15 | -------------------------------------------------------------------------------- /test/components/watcher.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.watcher_base import WatcherBase 2 | from tensorwatch.stream import Stream 3 | 4 | def main(): 5 | watcher = WatcherBase() 6 | console_pub = Stream(stream_name = 'S1', console_debug=True) 7 | stream = watcher.create_stream(expr='lambda vars:vars.x**2') 8 | console_pub.subscribe(stream) 9 | 10 | for i in range(5): 11 | watcher.observe(x=i) 12 | 13 | main() 14 | 15 | 16 | -------------------------------------------------------------------------------- /test/deps/ipython_widget.py: -------------------------------------------------------------------------------- 1 | import ipywidgets as widgets 2 | from IPython import get_ipython 3 | 4 | class PrinterX: 5 | def __init__(self): 6 | self.w = w=widgets.HTML() 7 | def show(self): 8 | return self.w 9 | def write(self,s): 10 | self.w.value = s 11 | 12 | print("Running from within ipython?", get_ipython() is not None) 13 | p=PrinterX() 14 | p.show() 15 | p.write('ffffffffff') 16 | -------------------------------------------------------------------------------- /test/deps/live_graph.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | from matplotlib.animation import FuncAnimation 3 | from random import randrange 4 | from threading import Thread 5 | import time 6 | 7 | class LiveGraph: 8 | def __init__(self): 9 | self.x_data, self.y_data = [], [] 10 | self.figure = plt.figure() 11 | self.line, = plt.plot(self.x_data, self.y_data) 12 | self.animation = FuncAnimation(self.figure, self.update, interval=1000) 13 | self.th = Thread(target=self.thread_f, name='LiveGraph', daemon=True) 14 | 
self.th.start() 15 | 16 | def update(self, frame): 17 | self.line.set_data(self.x_data, self.y_data) 18 | self.figure.gca().relim() 19 | self.figure.gca().autoscale_view() 20 | return self.line, 21 | 22 | def show(self): 23 | plt.show() 24 | 25 | def thread_f(self): 26 | x = 0 27 | while True: 28 | self.x_data.append(x) 29 | x += 1 30 | self.y_data.append(randrange(0, 100)) 31 | time.sleep(1) -------------------------------------------------------------------------------- /test/deps/panda.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | class SomeThing: 4 | def __init__(self, x, y): 5 | self.x, self.y = x, y 6 | 7 | things = [SomeThing(1,2), SomeThing(3,4), SomeThing(4,5)] 8 | 9 | df = pd.DataFrame([t.__dict__ for t in things ]) 10 | 11 | print(df.iloc[[-1]].to_dict('records')[0]) 12 | #print(df.to_html()) 13 | print(df.style.render()) 14 | -------------------------------------------------------------------------------- /test/deps/thread.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | import sys 4 | 5 | def handler(a,b=None): 6 | sys.exit(1) 7 | def install_handler(): 8 | if sys.platform == "win32": 9 | if sys.stdin is not None and sys.stdin.isatty(): 10 | #this is Console based application 11 | import win32api 12 | win32api.SetConsoleCtrlHandler(handler, True) 13 | 14 | 15 | def work(): 16 | time.sleep(10000) 17 | t = threading.Thread(target=work, name='ThreadTest') 18 | t.daemon = True 19 | t.start() 20 | while(True): 21 | t.join(0.1) #100ms ~ typical human response 22 | # you will get KeyboardIntrupt exception 23 | 24 | -------------------------------------------------------------------------------- /test/files/file_stream.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.watcher_base import WatcherBase 2 | from tensorwatch.stream import Stream 3 | from 
tensorwatch.file_stream import FileStream 4 | from tensorwatch import LinePlot 5 | from tensorwatch.image_utils import plt_loop 6 | import tensorwatch as tw 7 | 8 | def file_write(): 9 | watcher = WatcherBase() 10 | stream = watcher.create_stream(expr='lambda vars:(vars.x, vars.x**2)', 11 | devices=[r'c:\temp\obs.txt']) 12 | 13 | for i in range(5): 14 | watcher.observe(x=i) 15 | 16 | def file_read(): 17 | watcher = WatcherBase() 18 | stream = watcher.open_stream(devices=[r'c:\temp\obs.txt']) 19 | vis = tw.Visualizer(stream, vis_type='mpl-line') 20 | vis.show() 21 | plt_loop() 22 | 23 | def main(): 24 | file_write() 25 | file_read() 26 | 27 | main() 28 | 29 | -------------------------------------------------------------------------------- /test/mnist/cli_mnist.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import time 3 | import math 4 | from tensorwatch import utils 5 | 6 | utils.set_debug_verbosity(4) 7 | 8 | def img_in_class(): 9 | cli_train = tw.WatcherClient() 10 | 11 | imgs = cli_train.create_stream(event_name='batch', 12 | expr="topk_all(l, batch_vals=lambda b: (b.batch.loss_all, (b.batch.input, b.batch.output), b.batch.target), \ 13 | out_f=image_class_outf, order='dsc')", throttle=1) 14 | img_plot = tw.ImagePlot() 15 | img_plot.subscribe(imgs, viz_img_scale=3) 16 | img_plot.show() 17 | 18 | tw.image_utils.plt_loop() 19 | 20 | def show_find_lr(): 21 | cli_train = tw.WatcherClient() 22 | plot = tw.LinePlot() 23 | 24 | train_batch_loss = cli_train.create_stream(event_name='batch', 25 | expr='lambda d:(d.tt.scheduler.get_lr()[0], d.metrics.batch_loss)') 26 | plot.subscribe(train_batch_loss, xtitle='Epoch', ytitle='Loss') 27 | 28 | utils.wait_key() 29 | 30 | def plot_grads_plotly(): 31 | train_cli = tw.WatcherClient() 32 | grads = train_cli.create_stream(event_name='batch', 33 | expr='lambda d:grads_abs_mean(d.model)', throttle=1) 34 | p = tw.plotly.line_plot.LinePlot('Demo') 35 | 
p.subscribe(grads, xtitle='Layer', ytitle='Gradients', history_len=30, new_on_eval=True) 36 | utils.wait_key() 37 | 38 | 39 | def plot_grads(): 40 | train_cli = tw.WatcherClient() 41 | 42 | grads = train_cli.create_stream(event_name='batch', 43 | expr='lambda d:grads_abs_mean(d.model)', throttle=1) 44 | grad_plot = tw.LinePlot() 45 | grad_plot.subscribe(grads, xtitle='Layer', ytitle='Gradients', clear_after_each=1, history_len=40, dim_history=True) 46 | grad_plot.show() 47 | 48 | tw.plt_loop() 49 | 50 | def plot_weight(): 51 | train_cli = tw.WatcherClient() 52 | 53 | params = train_cli.create_stream(event_name='batch', 54 | expr='lambda d:weights_abs_mean(d.model)', throttle=1) 55 | params_plot = tw.LinePlot() 56 | params_plot.subscribe(params, xtitle='Layer', ytitle='avg |params|', clear_after_each=1, history_len=40, dim_history=True) 57 | params_plot.show() 58 | 59 | tw.plt_loop() 60 | 61 | def epoch_stats(): 62 | train_cli = tw.WatcherClient(port=0) 63 | test_cli = tw.WatcherClient(port=1) 64 | 65 | plot = tw.LinePlot() 66 | 67 | train_loss = train_cli.create_stream(event_name="epoch", 68 | expr='lambda v:(v.metrics.epoch_index, v.metrics.epoch_loss)') 69 | plot.subscribe(train_loss, xtitle='Epoch', ytitle='Train Loss') 70 | 71 | test_acc = test_cli.create_stream(event_name="epoch", 72 | expr='lambda v:(v.metrics.epoch_index, v.metrics.epoch_accuracy)') 73 | plot.subscribe(test_acc, xtitle='Epoch', ytitle='Test Accuracy', ylim=(0,1)) 74 | 75 | plot.show() 76 | tw.plt_loop() 77 | 78 | 79 | def batch_stats(): 80 | train_cli = tw.WatcherClient() 81 | stream = train_cli.create_stream(event_name="batch", 82 | expr='lambda v:(v.metrics.epochf, v.metrics.batch_loss)', throttle=0.75) 83 | 84 | train_loss = tw.Visualizer(stream, clear_after_end=False, vis_type='mpl-line', 85 | xtitle='Epoch', ytitle='Train Loss') 86 | 87 | #train_acc = tw.Visualizer('lambda v:(v.metrics.epochf, v.metrics.epoch_loss)', event_name="batch", 88 | # xtitle='Epoch', ytitle='Train Accuracy', 
clear_after_end=False, yrange=(0,1), 89 | # vis=train_loss, vis_type='mpl-line') 90 | 91 | train_loss.show() 92 | tw.plt_loop() 93 | 94 | def text_stats(): 95 | train_cli = tw.WatcherClient() 96 | stream = train_cli.create_stream(event_name="batch", 97 | expr='lambda d:(d.metrics.epoch_index, d.metrics.batch_loss)') 98 | 99 | trl = tw.Visualizer(stream, vis_type='text') 100 | trl.show() 101 | input('Paused...') 102 | 103 | 104 | 105 | #epoch_stats() 106 | #plot_weight() 107 | #plot_grads() 108 | #img_in_class() 109 | text_stats() 110 | #batch_stats() -------------------------------------------------------------------------------- /test/post_train/saliency.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.saliency import saliency 2 | from tensorwatch import image_utils, imagenet_utils, pytorch_utils 3 | 4 | model = pytorch_utils.get_model('resnet50') 5 | raw_input, input, target_class = pytorch_utils.image_class2tensor('../data/test_images/dogs.png', 240, #'../data/elephant.png', 101, 6 | image_transform=imagenet_utils.get_image_transform(), image_convert_mode='RGB') 7 | 8 | results = saliency.get_image_saliency_results(model, raw_input, input, target_class) 9 | figure = saliency.get_image_saliency_plot(results) 10 | 11 | image_utils.plt_loop() 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /test/pre_train/dot_manual.bat: -------------------------------------------------------------------------------- 1 | dot -Tpng -O test\pre_train\sample.dot -------------------------------------------------------------------------------- /test/pre_train/draw_cust_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models 3 | import tensorwatch as tw 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision 9 | from torch.utils.tensorboard 
import SummaryWriter 10 | from torchvision import datasets, transforms 11 | writer = SummaryWriter() 12 | class Net(nn.Module): 13 | 14 | def __init__(self): 15 | super(Net, self).__init__() 16 | self.layer = nn.Sequential( 17 | nn.Conv2d(3, 64, 7, padding=1, stride=1), 18 | nn.ReLU(), 19 | nn.MaxPool2d(1, 2), 20 | nn.BatchNorm2d(64), 21 | nn.AvgPool2d(1, 1), 22 | nn.Dropout(0.5), 23 | nn.Linear(110,30) 24 | ) 25 | 26 | def forward(self, x): 27 | return self.layer(x) 28 | 29 | 30 | net = Net() 31 | args = torch.ones([1, 3, 224, 224]) 32 | # writer.add_graph(net, args) 33 | # writer.close() 34 | 35 | #vgg16_model = torchvision.models.vgg16() 36 | 37 | drawing = tw.draw_model(net, [1, 3, 224, 224]) 38 | drawing.save('abc2.png') 39 | 40 | input("Press any key") -------------------------------------------------------------------------------- /test/pre_train/draw_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models 3 | import tensorwatch as tw 4 | 5 | vgg16_model = torchvision.models.vgg16() 6 | 7 | drawing = tw.draw_model(vgg16_model, [1, 3, 224, 224]) 8 | drawing.save('abc.png') 9 | 10 | input("Press any key") -------------------------------------------------------------------------------- /test/pre_train/model_stats.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import torchvision.models 3 | 4 | model_names = ['alexnet', 'resnet18', 'resnet34', 'resnet101', 'densenet121'] 5 | 6 | for model_name in model_names: 7 | model = getattr(torchvision.models, model_name)() 8 | model_stats = tw.ModelStats(model, [1, 3, 224, 224], clone_model=False) 9 | print(f'{model_name}: flops={model_stats.Flops}, parameters={model_stats.parameters}, memory={model_stats.inference_memory}') -------------------------------------------------------------------------------- /test/pre_train/model_stats_perf.py: 
-------------------------------------------------------------------------------- 1 | import copy 2 | import tensorwatch as tw 3 | import torchvision.models 4 | import torch 5 | import time 6 | 7 | model = getattr(torchvision.models, 'densenet201')() 8 | 9 | def model_timing(model): 10 | st = time.time() 11 | for _ in range(20): 12 | batch = torch.rand([64, 3, 224, 224]) 13 | y = model(batch) 14 | return time.time()-st 15 | 16 | print(model_timing(model)) 17 | model_stats = tw.ModelStats(model, [1, 3, 224, 224], clone_model=False) 18 | print(f'flops={model_stats.Flops}, parameters={model_stats.parameters}, memory={model_stats.inference_memory}') 19 | print(model_timing(model)) 20 | -------------------------------------------------------------------------------- /test/pre_train/test_pydot.py: -------------------------------------------------------------------------------- 1 | import pydot 2 | 3 | g = pydot.Dot() 4 | g.set_type('digraph') 5 | node = pydot.Node('legend') 6 | node.set("shape", 'box') 7 | g.add_node(node) 8 | node.set('label', 'mine') 9 | s = g.to_string() 10 | expected = 'digraph G {\nlegend [label=mine, shape=box];\n}\n' 11 | assert s == expected 12 | print(s) 13 | png = g.create_png() 14 | print(png) -------------------------------------------------------------------------------- /test/pre_train/tsny.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | 3 | from regim import * 4 | ds = DataUtils.mnist_datasets(linearize=True, train_test=False) 5 | ds = DataUtils.sample_by_class(ds, k=5, shuffle=True, as_np=True, no_test=True) 6 | 7 | comps = tw.get_tsne_components(ds) 8 | print(comps) 9 | plot = tw.Visualizer(comps, hover_images=ds[0], hover_image_reshape=(28,28), vis_type='tsne') 10 | plot.show() -------------------------------------------------------------------------------- /test/simple_log/cli_file_expr.py: 
-------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | from tensorwatch import utils 3 | utils.set_debug_verbosity(4) 4 | 5 | 6 | #r = tw.Visualizer(vis_type='mpl-line') 7 | #r.show() 8 | #r2=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', cell=r.cell) 9 | #r3=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', renderer=r2) 10 | 11 | def show_mpl(): 12 | cli = tw.WatcherClient(r'c:\temp\sum.log') 13 | s1 = cli.open_stream('sum') 14 | p = tw.LinePlot(title='Demo') 15 | p.subscribe(s1, xtitle='Index', ytitle='sqrt(ev_i)') 16 | s1.load() 17 | p.show() 18 | 19 | tw.plt_loop() 20 | 21 | def show_text(): 22 | cli = tw.WatcherClient(r'c:\temp\sum.log') 23 | s1 = cli.open_stream('sum_2') 24 | text = tw.Visualizer(s1) 25 | text.show() 26 | input('Waiting') 27 | 28 | #show_text() 29 | show_mpl() 30 | -------------------------------------------------------------------------------- /test/simple_log/cli_ij.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import time 3 | import math 4 | from tensorwatch import utils 5 | utils.set_debug_verbosity(4) 6 | 7 | def mpl_line_plot(): 8 | cli = tw.WatcherClient() 9 | p = tw.LinePlot(title='Demo') 10 | s1 = cli.create_stream(event_name='ev_i', expr='map(lambda v:math.sqrt(v.val)*2, l)') 11 | p.subscribe(s1, xtitle='Index', ytitle='sqrt(ev_i)') 12 | p.show() 13 | tw.plt_loop() 14 | 15 | def mpl_history_plot(): 16 | cli = tw.WatcherClient() 17 | p2 = tw.LinePlot(title='History Demo') 18 | p2s1 = cli.create_stream(event_name='ev_j', expr='map(lambda v:(v.val, math.sqrt(v.val)*2), l)') 19 | p2.subscribe(p2s1, xtitle='Index', ytitle='sqrt(ev_j)', clear_after_end=True, history_len=15) 20 | p2.show() 21 | tw.plt_loop() 22 | 23 | def show_stream(): 24 | cli = tw.WatcherClient() 25 | 26 | print("Subscribing to event ev_i...") 27 | s1 = cli.create_stream(event_name="ev_i", expr='map(lambda v:math.sqrt(v.val), l)') 28 | 
r1 = tw.TextVis(title='L1') 29 | r1.subscribe(s1) 30 | r1.show() 31 | 32 | print("Subscribing to event ev_j...") 33 | s2 = cli.create_stream(event_name="ev_j", expr='map(lambda v:v.val*v.val, l)') 34 | r2 = tw.TextVis(title='L2') 35 | r2.subscribe(s2) 36 | 37 | r2.show() 38 | 39 | print("Waiting for key...") 40 | 41 | utils.wait_key() 42 | 43 | # this no longer directly supported 44 | # TODO: create stream that allows enumeration from buffered values 45 | #def read_stream(): 46 | # cli = tw.WatcherClient() 47 | 48 | # with cli.create_stream(event_name="ev_i", expr='map(lambda v:(v.x, math.sqrt(v.val)), l)') as s1: 49 | # for stream_item in s1: 50 | # print(stream_item.value) 51 | # print('done') 52 | # utils.wait_key() 53 | 54 | def plotly_line_graph(): 55 | cli = tw.WatcherClient() 56 | s1 = cli.create_stream(event_name="ev_i", expr='map(lambda v:(v.x, math.sqrt(v.val)), l)') 57 | 58 | p = tw.plotly.line_plot.LinePlot() 59 | p.subscribe(s1) 60 | p.show() 61 | 62 | utils.wait_key() 63 | 64 | def plotly_history_graph(): 65 | cli = tw.WatcherClient() 66 | p = tw.plotly.line_plot.LinePlot(title='Demo') 67 | s2 = cli.create_stream(event_name='ev_j', expr='map(lambda v:(v.x, v.val), l)') 68 | p.subscribe(s2, ytitle='ev_j', history_len=15) 69 | p.show() 70 | utils.wait_key() 71 | 72 | 73 | mpl_line_plot() 74 | #mpl_history_plot() 75 | #show_stream() 76 | #plotly_line_graph() 77 | #plotly_history_graph() 78 | -------------------------------------------------------------------------------- /test/simple_log/cli_sum_log.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | from tensorwatch import utils 3 | utils.set_debug_verbosity(4) 4 | 5 | 6 | #r = tw.Visualizer(vis_type='mpl-line') 7 | #r.show() 8 | #r2=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', cell=r.cell) 9 | #r3=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', renderer=r2) 10 | 11 | def show_mpl(): 12 | cli = tw.WatcherClient() 13 | st_isum = 
cli.open_stream('isums') 14 | st_rsum = cli.open_stream('rsums') 15 | 16 | line_plot = tw.Visualizer(st_isum, vis_type='line', xtitle='i', ytitle='isum') 17 | line_plot.show() 18 | 19 | line_plot2 = tw.Visualizer(st_rsum, vis_type='line', host=line_plot, ytitle='rsum') 20 | 21 | tw.plt_loop() 22 | 23 | def show_text(): 24 | cli = tw.WatcherClient() 25 | text_vis = tw.Visualizer(st_isum, vis_type='text') 26 | text_vis.show() 27 | input('Waiting') 28 | 29 | #show_text() 30 | show_mpl() -------------------------------------------------------------------------------- /test/simple_log/file_expr.py: -------------------------------------------------------------------------------- 1 | import time 2 | import tensorwatch as tw 3 | from tensorwatch import utils 4 | utils.set_debug_verbosity(4) 5 | 6 | srv = tw.Watcher(filename=r'c:\temp\sum.log') 7 | s1 = srv.create_stream('sum', expr='lambda v:(v.i, v.sum)') 8 | s2 = srv.create_stream('sum_2', expr='lambda v:(v.i, v.sum/2)') 9 | 10 | sum = 0 11 | for i in range(10000): 12 | sum += i 13 | srv.observe(i=i, sum=sum) 14 | #print(i, sum) 15 | time.sleep(1) 16 | 17 | -------------------------------------------------------------------------------- /test/simple_log/quick_start.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import time 3 | 4 | w = tw.Watcher(filename='test.log') 5 | s = w.create_stream(name='my_metric') 6 | #w.make_notebook() 7 | 8 | for i in range(1000): 9 | s.write((i, i*i)) 10 | time.sleep(1) 11 | -------------------------------------------------------------------------------- /test/simple_log/srv_ij.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import time 3 | import random 4 | from tensorwatch import utils 5 | 6 | utils.set_debug_verbosity(4) 7 | 8 | srv = tw.Watcher() 9 | 10 | while(True): 11 | for i in range(1000): 12 | srv.observe("ev_i", val=i*random.random(), x=i) 13 
| print('sent ev_i ', i) 14 | time.sleep(1) 15 | for j in range(5): 16 | srv.observe("ev_j", x=j, val=j*random.random()) 17 | print('sent ev_j ', j) 18 | time.sleep(0.5) 19 | srv.end_event("ev_j") 20 | srv.end_event("ev_i") -------------------------------------------------------------------------------- /test/simple_log/sum_lazy.py: -------------------------------------------------------------------------------- 1 | import time, random 2 | import tensorwatch as tw 3 | 4 | # create watcher object as usual 5 | w = tw.Watcher() 6 | 7 | weights = None 8 | for i in range(10000): 9 | weights = [random.random() for _ in range(5)] 10 | 11 | # let watcher observe variables we have 12 | # this has almost no performance cost 13 | w.observe(weights=weights) 14 | 15 | time.sleep(1) 16 | 17 | -------------------------------------------------------------------------------- /test/simple_log/sum_log.py: -------------------------------------------------------------------------------- 1 | import time, random 2 | import tensorwatch as tw 3 | 4 | # we will create two streams, one for 5 | # sums of integers, other for random numbers 6 | w = tw.Watcher() 7 | st_isum = w.create_stream('isums') 8 | st_rsum = w.create_stream('rsums') 9 | 10 | isum, rsum = 0, 0 11 | for i in range(10000): 12 | isum += i 13 | rsum += random.randint(0,i) 14 | 15 | # write to streams 16 | st_isum.write((i, isum)) 17 | st_rsum.write((i, rsum)) 18 | 19 | time.sleep(1) 20 | -------------------------------------------------------------------------------- /test/visualizations/arr_img_plot.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import numpy as np 3 | import time 4 | import torchvision.datasets as datasets 5 | 6 | fruits_ds = datasets.ImageFolder(r'D:\datasets\fruits-360\Training') 7 | mnist_ds = datasets.MNIST('../data', train=True, download=True) 8 | 9 | images = [tw.ImageData(fruits_ds[i][0], title=str(i)) for i in range(5)] + \ 10 | 
[tw.ImageData(mnist_ds[i][0], title=str(i)) for i in range(5)] 11 | 12 | stream = tw.ArrayStream(images) 13 | 14 | img_plot = tw.Visualizer(stream, vis_type='image', viz_img_scale=3) 15 | img_plot.show() 16 | 17 | tw.image_utils.plt_loop() -------------------------------------------------------------------------------- /test/visualizations/arr_mpl_line.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | 3 | stream = tw.ArrayStream([(i, i*i) for i in range(50)]) 4 | img_plot = tw.Visualizer(stream, vis_type='mpl-line', viz_img_scale=3, xtitle='Epochs', ytitle='Gain') 5 | # img_plot.show() 6 | # tw.plt_loop() 7 | img_plot.save(r'c:\temp\fig1.png') 8 | -------------------------------------------------------------------------------- /test/visualizations/bar_plot.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import random, time 3 | 4 | def static_bar(): 5 | w = tw.Watcher() 6 | s = w.create_stream() 7 | 8 | v = tw.Visualizer(s, vis_type='bar') 9 | v.show() 10 | 11 | for i in range(10): 12 | s.write(int(random.random()*10)) 13 | 14 | tw.plt_loop() 15 | 16 | 17 | def dynamic_bar(): 18 | w = tw.Watcher() 19 | s = w.create_stream() 20 | 21 | v = tw.Visualizer(s, vis_type='bar', clear_after_each=True) 22 | v.show() 23 | 24 | for i in range(100): 25 | s.write([('a'+str(i), random.random()*10) for i in range(10)]) 26 | tw.plt_loop(count=3) 27 | 28 | def dynamic_bar3d(): 29 | w = tw.Watcher() 30 | s = w.create_stream() 31 | 32 | v = tw.Visualizer(s, vis_type='bar3d', clear_after_each=True) 33 | v.show() 34 | 35 | for i in range(100): 36 | s.write([(i, random.random()*10, z) for i in range(10) for z in range(10)]) 37 | tw.plt_loop(count=3) 38 | 39 | static_bar() 40 | #dynamic_bar() 41 | #dynamic_bar3d() 42 | -------------------------------------------------------------------------------- /test/visualizations/confidence_int.py: 
-------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | 3 | w = tw.Watcher() 4 | s = w.create_stream() 5 | 6 | v = tw.Visualizer(s, vis_type='line') 7 | v.show() 8 | 9 | for i in range(10): 10 | i = float(i) 11 | s.write(tw.PointData(i, i*i, low=i*i-i, high=i*i+i)) 12 | 13 | tw.plt_loop() -------------------------------------------------------------------------------- /test/visualizations/histogram.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import random, time 3 | 4 | def static_hist(): 5 | w = tw.Watcher() 6 | s = w.create_stream() 7 | 8 | v = tw.Visualizer(s, vis_type='histogram', bins=6) 9 | v.show() 10 | 11 | for _ in range(100): 12 | s.write(random.random()*10) 13 | 14 | tw.plt_loop() 15 | 16 | 17 | def dynamic_hist(): 18 | w = tw.Watcher() 19 | s = w.create_stream() 20 | 21 | v = tw.Visualizer(s, vis_type='histogram', bins=6, clear_after_each=True) 22 | v.show() 23 | 24 | for _ in range(100): 25 | s.write([random.random()*10 for _ in range(100)]) 26 | tw.plt_loop(count=3) 27 | 28 | dynamic_hist() -------------------------------------------------------------------------------- /test/visualizations/line3d_plot.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import random, time 3 | 4 | # TODO: resolve problem with Axis3D? 
5 | 6 | def static_line3d(): 7 | w = tw.Watcher() 8 | s = w.create_stream() 9 | 10 | v = tw.Visualizer(s, vis_type='line3d') 11 | v.show() 12 | 13 | for i in range(10): 14 | s.write((i, i*i, int(random.random()*10))) 15 | 16 | tw.plt_loop() 17 | 18 | def dynamic_line3d(): 19 | w = tw.Watcher() 20 | s = w.create_stream() 21 | 22 | v = tw.Visualizer(s, vis_type='line3d', clear_after_each=True) 23 | v.show() 24 | 25 | for i in range(100): 26 | s.write([(i, random.random()*10, z) for i in range(10) for z in range(10)]) 27 | tw.plt_loop(count=3) 28 | 29 | static_line3d() 30 | #dynamic_line3d() 31 | 32 | -------------------------------------------------------------------------------- /test/visualizations/mpl_line.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.watcher_base import WatcherBase 2 | from tensorwatch.mpl.line_plot import LinePlot 3 | from tensorwatch.image_utils import plt_loop 4 | from tensorwatch.stream import Stream 5 | from tensorwatch.lv_types import StreamItem 6 | 7 | 8 | def main(): 9 | watcher = WatcherBase() 10 | line_plot = LinePlot() 11 | stream = watcher.create_stream(expr='lambda vars:vars.x') 12 | line_plot.subscribe(stream) 13 | line_plot.show() 14 | 15 | for i in range(5): 16 | watcher.observe(x=(i, i*i)) 17 | plt_loop() 18 | 19 | main() 20 | 21 | -------------------------------------------------------------------------------- /test/visualizations/pie_chart.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import random, time 3 | 4 | def static_pie(): 5 | w = tw.Watcher() 6 | s = w.create_stream() 7 | 8 | v = tw.Visualizer(s, vis_type='pie', bins=6) 9 | v.show() 10 | 11 | for i in range(6): 12 | s.write(('label'+str(i), random.random()*10, None, 0.5 if i==3 else 0)) 13 | 14 | tw.plt_loop() 15 | 16 | 17 | def dynamic_pie(): 18 | w = tw.Watcher() 19 | s = w.create_stream() 20 | 21 | v = tw.Visualizer(s, vis_type='pie', 
bins=6, clear_after_each=True) 22 | v.show() 23 | 24 | for _ in range(100): 25 | s.write([('label'+str(i), random.random()*10, None, i*0.01) for i in range(12)]) 26 | tw.plt_loop(count=3) 27 | 28 | #static_pie() 29 | dynamic_pie() 30 | -------------------------------------------------------------------------------- /test/visualizations/plotly_line.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.watcher_base import WatcherBase 2 | from tensorwatch.plotly.line_plot import LinePlot 3 | from tensorwatch.image_utils import plt_loop 4 | from tensorwatch.stream import Stream 5 | from tensorwatch.lv_types import StreamItem 6 | 7 | 8 | def main(): 9 | watcher = WatcherBase() 10 | line_plot = LinePlot() 11 | stream = watcher.create_stream(expr='lambda vars:vars.x') 12 | line_plot.subscribe(stream) 13 | line_plot.show() 14 | 15 | for i in range(5): 16 | watcher.observe(x=(i, i*i)) 17 | 18 | main() 19 | 20 | -------------------------------------------------------------------------------- /test/zmq/zmq_pub.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import time 3 | from tensorwatch.zmq_wrapper import ZmqWrapper 4 | from tensorwatch import utils 5 | 6 | utils.set_debug_verbosity(10) 7 | 8 | def clisrv_callback(clisrv, msg): 9 | print('from clisrv', msg) 10 | 11 | stream = ZmqWrapper.Publication(port = 40859) 12 | clisrv = ZmqWrapper.ClientServer(40860, True, clisrv_callback) 13 | 14 | for i in range(10000): 15 | stream.send_obj({'a': i}, "Topic1") 16 | print("sent ", i) 17 | time.sleep(1) 18 | -------------------------------------------------------------------------------- /test/zmq/zmq_stream.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.watcher_base import WatcherBase 2 | from tensorwatch.zmq_stream import ZmqStream 3 | 4 | def main(): 5 | watcher = WatcherBase() 6 | stream = 
watcher.create_stream(expr='lambda vars:vars.x**2') 7 | 8 | zmq_pub = ZmqStream(for_write=True, stream_name = 'ZmqPub', console_debug=True) 9 | zmq_pub.subscribe(stream) 10 | 11 | for i in range(5): 12 | watcher.observe(x=i) 13 | input('paused') 14 | 15 | main() 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /test/zmq/zmq_sub.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import time 3 | from tensorwatch.zmq_wrapper import ZmqWrapper 4 | from tensorwatch import utils 5 | 6 | class A: 7 | def on_event(self, obj): 8 | print(obj) 9 | 10 | a = A() 11 | 12 | 13 | utils.set_debug_verbosity(10) 14 | sub = ZmqWrapper.Subscription(40859, "Topic1", a.on_event) 15 | print("subscriber is waiting") 16 | 17 | clisrv = ZmqWrapper.ClientServer(40860, False) 18 | clisrv.send_obj("hello 1") 19 | print('sleeping..') 20 | time.sleep(10) 21 | clisrv.send_obj("hello 2") 22 | 23 | print('waiting for key..') 24 | utils.wait_key() -------------------------------------------------------------------------------- /test/zmq/zmq_watcher_client.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.watcher_client import WatcherClient 2 | import time 3 | 4 | from tensorwatch import utils 5 | utils.set_debug_verbosity(10) 6 | 7 | 8 | def main(): 9 | watcher = WatcherClient() 10 | stream = watcher.create_stream(expr='lambda vars:vars.x**2') 11 | stream.console_debug = True 12 | input('pause') 13 | 14 | main() 15 | 16 | -------------------------------------------------------------------------------- /test/zmq/zmq_watcher_server.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.watcher import Watcher 2 | import time 3 | 4 | from tensorwatch import utils 5 | utils.set_debug_verbosity(10) 6 | 7 | 8 | def main(): 9 | watcher = Watcher() 10 | 11 | for i in range(5000): 12 | 
watcher.observe(x=i) 13 | # print(i) 14 | time.sleep(1) 15 | 16 | main() 17 | -------------------------------------------------------------------------------- /update_package.bat: -------------------------------------------------------------------------------- 1 | PAUSE Make sure to increment version in setup.py. Continue? 2 | python setup.py sdist 3 | twine upload --repository-url https://upload.pypi.org/legacy/ dist/* 4 | REM pip install tensorwatch --upgrade 5 | REM pip show tensorwatch 6 | 7 | REM pip install yolk3k 8 | REM yolk -V tensorwatch --------------------------------------------------------------------------------