├── .gitignore
├── .pylintrc
├── .vscode
│   └── launch.json
├── CHANGELOG.md
├── CONTRIBUTING.md
├── ISSUE_TEMPLATE.md
├── LICENSE.txt
├── MANIFEST.in
├── NOTICE.md
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── TODO.md
├── data
│   └── test_images
│       ├── cat.jpg
│       ├── dogs.png
│       ├── elephant.png
│       ├── imagenet_class_index.json
│       └── imagenet_synsets.txt
├── docs
│   ├── images
│   │   ├── draw_model.png
│   │   ├── fruits.gif
│   │   ├── lazy_log_array_sum.png
│   │   ├── model_stats.png
│   │   ├── quick_start.gif
│   │   ├── saliency.png
│   │   ├── simple_logging
│   │   │   ├── full-page2.png
│   │   │   ├── line_cell.png
│   │   │   ├── line_cell2.png
│   │   │   ├── plotly_line.png
│   │   │   ├── text_cell.png
│   │   │   └── text_summary.png
│   │   ├── teaser.gif
│   │   ├── teaser_small.gif
│   │   └── tsne.gif
│   ├── lazy_logging.md
│   ├── paper
│   │   ├── .gitignore
│   │   ├── ACM-Reference-Format.bbx
│   │   ├── ACM-Reference-Format.bst
│   │   ├── ACM-Reference-Format.cbx
│   │   ├── ACM-Reference-Format.dbx
│   │   ├── TensorWatch_Collaboration.png
│   │   ├── acmart.cls
│   │   ├── main.pdf
│   │   ├── main.tex
│   │   ├── sample-base.bib
│   │   ├── tensorwatch-screenshot.png
│   │   └── tensorwatch-screenshot2.png
│   └── simple_logging.md
├── install_jupyterlab.bat
├── notebooks
│   ├── cnn_pred_explain.ipynb
│   ├── data_exploration.ipynb
│   ├── fruits_analysis.ipynb
│   ├── lazy_logging.ipynb
│   ├── mnist.ipynb
│   ├── network_arch.ipynb
│   ├── receptive_field.ipynb
│   └── simple_logging.ipynb
├── setup.py
├── tensorwatch.pyproj
├── tensorwatch.sln
├── tensorwatch
│   ├── __init__.py
│   ├── array_stream.py
│   ├── data_utils.py
│   ├── embeddings
│   │   ├── __init__.py
│   │   └── tsne_utils.py
│   ├── evaler.py
│   ├── evaler_utils.py
│   ├── file_stream.py
│   ├── filtered_stream.py
│   ├── image_utils.py
│   ├── imagenet_class_index.json
│   ├── imagenet_synsets.txt
│   ├── imagenet_utils.py
│   ├── lv_types.py
│   ├── model_graph
│   │   ├── __init__.py
│   │   ├── hiddenlayer
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── distiller.py
│   │   │   ├── distiller_utils.py
│   │   │   ├── ge.py
│   │   │   ├── graph.py
│   │   │   ├── pytorch_builder.py
│   │   │   ├── pytorch_builder_grad.py
│   │   │   ├── pytorch_builder_trace.py
│   │   │   ├── pytorch_draw_model.py
│   │   │   ├── summary_graph.py
│   │   │   ├── tf_builder.py
│   │   │   └── transforms.py
│   │   ├── torchstat
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── analyzer.py
│   │   │   ├── compute_flops.py
│   │   │   ├── compute_madd.py
│   │   │   ├── compute_memory.py
│   │   │   ├── reporter.py
│   │   │   └── stat_tree.py
│   │   └── torchstat_utils.py
│   ├── mpl
│   │   ├── __init__.py
│   │   ├── bar_plot.py
│   │   ├── base_mpl_plot.py
│   │   ├── histogram.py
│   │   ├── image_plot.py
│   │   ├── line_plot.py
│   │   └── pie_chart.py
│   ├── notebook_maker.py
│   ├── plotly
│   │   ├── __init__.py
│   │   ├── base_plotly_plot.py
│   │   ├── embeddings_plot.py
│   │   └── line_plot.py
│   ├── pytorch_utils.py
│   ├── receptive_field
│   │   ├── __init__.py
│   │   └── rf_utils.py
│   ├── repeated_timer.py
│   ├── saliency
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── backprop.py
│   │   ├── deeplift.py
│   │   ├── epsilon_lrp.py
│   │   ├── gradcam.py
│   │   ├── inverter_util.py
│   │   ├── lime
│   │   │   ├── __init__.py
│   │   │   ├── lime_base.py
│   │   │   ├── lime_image.py
│   │   │   └── wrappers
│   │   │       ├── __init__.py
│   │   │       ├── generic_utils.py
│   │   │       └── scikit_image.py
│   │   ├── lime_image_explainer.py
│   │   ├── occlusion.py
│   │   └── saliency.py
│   ├── stream.py
│   ├── stream_factory.py
│   ├── stream_union.py
│   ├── tensor_utils.py
│   ├── text_vis.py
│   ├── utils.py
│   ├── vis_base.py
│   ├── visualizer.py
│   ├── watcher.py
│   ├── watcher_base.py
│   ├── watcher_client.py
│   ├── zmq_mgmt_stream.py
│   ├── zmq_stream.py
│   └── zmq_wrapper.py
├── test
│   ├── components
│   │   ├── circ_ref.py
│   │   ├── evaler.py
│   │   ├── file_only_test.py
│   │   ├── notebook_maker.py
│   │   ├── stream.py
│   │   └── watcher.py
│   ├── deps
│   │   ├── ipython_widget.py
│   │   ├── live_graph.py
│   │   ├── panda.py
│   │   └── thread.py
│   ├── files
│   │   └── file_stream.py
│   ├── mnist
│   │   └── cli_mnist.py
│   ├── post_train
│   │   └── saliency.py
│   ├── pre_train
│   │   ├── dot_manual.bat
│   │   ├── draw_cust_model.py
│   │   ├── draw_model.py
│   │   ├── model_stats.py
│   │   ├── model_stats_perf.py
│   │   ├── sample.dot
│   │   ├── test_pydot.py
│   │   └── tsny.py
│   ├── simple_log
│   │   ├── cli_file_expr.py
│   │   ├── cli_ij.py
│   │   ├── cli_sum_log.py
│   │   ├── file_expr.py
│   │   ├── quick_start.py
│   │   ├── srv_ij.py
│   │   ├── sum_lazy.py
│   │   └── sum_log.py
│   ├── test.pyproj
│   ├── visualizations
│   │   ├── arr_img_plot.py
│   │   ├── arr_mpl_line.py
│   │   ├── bar_plot.py
│   │   ├── confidence_int.py
│   │   ├── histogram.py
│   │   ├── line3d_plot.py
│   │   ├── mpl_line.py
│   │   ├── pie_chart.py
│   │   └── plotly_line.py
│   └── zmq
│       ├── zmq_pub.py
│       ├── zmq_stream.py
│       ├── zmq_sub.py
│       ├── zmq_watcher_client.py
│       └── zmq_watcher_server.py
└── update_package.bat
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 | ignore=model_graph,
3 | receptive_field,
4 | saliency
5 |
6 | [TYPECHECK]
7 | ignored-modules=numpy,torch,matplotlib,pyplot,zmq
8 |
9 | [MESSAGES CONTROL]
10 | disable=protected-access,
11 | broad-except,
12 | global-statement,
13 | fixme,
14 | C,
15 | R
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python: Current File",
9 | "type": "python",
10 | "request": "launch",
11 | "program": "${file}",
12 | "console": "integratedTerminal"
13 | }
14 | ]
15 | }
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # What's new
2 |
3 | Below is a summarized list of important changes. It does not include minor changes, bug fixes, or documentation updates. This list is updated every few months; for the complete details, please review the [commit history](https://github.com/Microsoft/tensorwatch/commits/master).
4 |
5 | ### May, 2019
6 | * First release!
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | This project welcomes contributions and suggestions. Most contributions require you to
4 | agree to a Contributor License Agreement (CLA) declaring that you have the right to,
5 | and actually do, grant us the rights to use your contribution. For details, visit
6 | https://cla.microsoft.com.
7 |
8 | When you submit a pull request, a CLA-bot will automatically determine whether you need
9 | to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
10 | instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
11 |
12 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
13 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
14 | or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
15 |
--------------------------------------------------------------------------------
/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # Read This First
2 |
3 | * Make sure to describe **all the steps** to reproduce the issue
4 | * Include the full error message in the description
5 | * Add OS version, Python version, and PyTorch version if applicable
6 |
7 | Remember: if we cannot reproduce your problem, we cannot find a solution!
8 |
9 |
10 | **What's better than filing an issue? Filing a pull request :).**
11 |
12 | ------------------------------------ (Remove above before filing the issue) ------------------------------------
13 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) Microsoft Corporation. All rights reserved.
2 |
3 | MIT License
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include tensorwatch/*.txt
2 | include tensorwatch/*.json
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # Support
2 |
3 | We highly recommend taking a look at the source code and contributing to the project. Please consider [contributing](CONTRIBUTING.md) new features and fixes :).
4 |
5 | * [Join TensorWatch Facebook Group](https://www.facebook.com/groups/378075159472803/)
6 | * [File GitHub Issue](https://github.com/Microsoft/tensorwatch/issues)
--------------------------------------------------------------------------------
/TODO.md:
--------------------------------------------------------------------------------
1 | * Fix cell size issue
2 | * Refactor _plot* interface to accept all values, for ImagePlot only use last value
3 | * Refactor ImagePlot for arbitrary number of images with alpha, cmap
4 | * Change tw.open -> tw.create_viz
5 | * Make sure streams have names as key, each data point has index
6 | * Add tw.open_viz(stream_name, from_index)
7 | * Add persist=device_name option for streams
8 | * Ability to use streams in standalone mode
9 | * tw.create_viz on server side
10 | * tw.log for server side
11 | * experiment with IPC channel
12 | * confusion matrix as in https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html
13 | * Speed up import
14 | * Do linting
15 | * live perf data
16 | * NaN tracing
17 | * PCA
18 | * Remove error if MNIST notebook is on and we run fruits
19 | * Remove 2nd image from fruits
20 | * clear existing streams when starting client
21 | * ImageData should accept numpy array, pillow image, or torch tensor
22 | * image plot getting refreshed at 12 Hz instead of 2 Hz in MNIST
23 | * image plot doesn't show title
24 | * Animated mesh/surface graph demo
25 | * Move to h5 storage?
26 | * Error envelop
27 | * histogram
28 | * new graph on end
29 | * TF support
30 | * generic visualizer -> Given obj and Box, paint in box
31 | * visualize image and text with attention
32 | * add confidence interval for plotly: https://plot.ly/python/continuous-error-bars/
--------------------------------------------------------------------------------
/data/test_images/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/data/test_images/cat.jpg
--------------------------------------------------------------------------------
/data/test_images/dogs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/data/test_images/dogs.png
--------------------------------------------------------------------------------
/data/test_images/elephant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/data/test_images/elephant.png
--------------------------------------------------------------------------------
/docs/images/draw_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/draw_model.png
--------------------------------------------------------------------------------
/docs/images/fruits.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/fruits.gif
--------------------------------------------------------------------------------
/docs/images/lazy_log_array_sum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/lazy_log_array_sum.png
--------------------------------------------------------------------------------
/docs/images/model_stats.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/model_stats.png
--------------------------------------------------------------------------------
/docs/images/quick_start.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/quick_start.gif
--------------------------------------------------------------------------------
/docs/images/saliency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/saliency.png
--------------------------------------------------------------------------------
/docs/images/simple_logging/full-page2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/simple_logging/full-page2.png
--------------------------------------------------------------------------------
/docs/images/simple_logging/line_cell.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/simple_logging/line_cell.png
--------------------------------------------------------------------------------
/docs/images/simple_logging/line_cell2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/simple_logging/line_cell2.png
--------------------------------------------------------------------------------
/docs/images/simple_logging/plotly_line.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/simple_logging/plotly_line.png
--------------------------------------------------------------------------------
/docs/images/simple_logging/text_cell.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/simple_logging/text_cell.png
--------------------------------------------------------------------------------
/docs/images/simple_logging/text_summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/simple_logging/text_summary.png
--------------------------------------------------------------------------------
/docs/images/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/teaser.gif
--------------------------------------------------------------------------------
/docs/images/teaser_small.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/teaser_small.gif
--------------------------------------------------------------------------------
/docs/images/tsne.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/images/tsne.gif
--------------------------------------------------------------------------------
/docs/lazy_logging.md:
--------------------------------------------------------------------------------
1 |
2 | # Lazy Logging Tutorial
3 |
4 | In this tutorial, we will use a straightforward script that creates an array of random numbers. Our goal is to examine this array from a Jupyter Notebook while the script below runs separately in a console.
5 |
6 | ```python
7 | # available in test/simple_log/sum_lazy.py
8 |
9 | import time, random
10 | import tensorwatch as tw
11 |
12 | # create watcher object as usual
13 | w = tw.Watcher()
14 |
15 | weights = None
16 | for i in range(10000):
17 | weights = [random.random() for _ in range(5)]
18 |
19 | # let watcher observe variables we have
20 | # this has almost no performance cost
21 | w.observe(weights=weights)
22 |
23 | time.sleep(1)
24 | ```
25 |
26 | ## Basic Concepts
27 |
28 | Notice that we give the `Watcher` object an opportunity to *observe* variables. Observing variables is very cheap, so we can observe anything we might be interested in, for example, an entire model or a batch of data. From the Jupyter Notebook, you can specify an arbitrary *lambda expression* that is executed in the context of these observed variables. The result of this lambda expression is sent back to the Jupyter Notebook as a stream so that you can render it in the visualizer of your choice. That's pretty much it.
29 |
30 | ## How Does This Work?
31 |
32 | TensorWatch has two core classes: `Watcher` and `WatcherClient`. By default, the `Watcher` object opens TCP/IP sockets and listens for incoming requests. The `WatcherClient` allows you to connect to a `Watcher` and have it execute any Python [lambda expression](http://book.pythontips.com/en/latest/lambdas.html) you want. These lambda expressions consume a stream of values as input and output another stream of values. They typically contain [map, filter, and reduce](http://book.pythontips.com/en/latest/map_filter.html) operations, so you can transform values in the input stream, filter them, or aggregate them. Let's see all this with a simple example.
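
Before we get to TensorWatch itself, here is a plain-Python sketch, independent of TensorWatch's API, of how map, filter, and reduce style expressions transform a stream of values (all names below are illustrative):

```python
from functools import reduce

# a stand-in for the stream of observed values arriving from the script
events = [{'weights': [1, 2]}, {'weights': [5, 7]}, {'weights': [9, 1]}]

sums = map(lambda d: sum(d['weights']), events)   # transform: one sum per event
large = filter(lambda s: s > 5, sums)             # filter: keep only large sums
total = reduce(lambda a, b: a + b, large)         # aggregate: a single value
print(total)  # 22
```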
33 |
34 | ## Create the Client
35 | You can either create a new Jupyter Notebook or use the existing one in the repo at `notebooks/lazy_logging.ipynb`.
36 |
37 | Let's first do imports:
38 |
39 |
40 | ```python
41 | %matplotlib notebook
42 | import tensorwatch as tw
43 | ```
44 |
45 | Next, create a stream using a simple lambda expression that sums the values in the `weights` array that we are observing in the above script. The input to the lambda expression is an object that holds all the variables we are observing.
46 |
47 | ```python
48 | client = tw.WatcherClient()
49 | stream = client.create_stream(expr='lambda d: np.sum(d.weights)')
50 | ```
51 |
52 | Next, send this stream to a line chart visualization:
53 |
54 | ```python
55 | line_plot = tw.Visualizer(stream, vis_type='line')
56 | line_plot.show()
57 | ```
58 |
59 | Now when you run the script `sum_lazy.py` in the console and then the above Jupyter Notebook, you will see the sum of the `weights` array plotted in real-time:
60 |
61 |
62 |
63 | ## What Just Happened?
64 |
65 | Notice that your script runs in a different process than the Jupyter Notebook. Your Jupyter Notebook could be running on your laptop while your script runs in some VM in the cloud. You can query observed variables at runtime! As `Watcher.observe()` is a cheap call, you could observe your entire model, batch inputs and outputs, and so on. You can then slice and dice all these observed variables from the comfort of your Jupyter Notebook while your script is running!
66 |
67 | ## Next Steps
68 |
69 | We have just scratched the surface in this new land of lazy logging! You can do much more with this idea. For example, you can create *events* so that your lambda expression is executed only on those events; this is useful for creating streams that generate data per batch or per epoch, as sketched below. You can also use many predefined lambda expressions that allow you to view the top k items by some criterion.
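
As a rough sketch of what event-based observation could look like (the `event_name` parameter and the `loss` variable below are assumptions for illustration; verify the exact signatures in `tensorwatch/watcher_base.py`):

```python
# in the training script: observe under a named event, e.g. once per batch
# NOTE: event_name is an assumption based on the Watcher API in this repo
w.observe(event_name='batch', loss=loss, weights=weights)

# in the notebook: the expression is evaluated only when 'batch' events arrive
stream = client.create_stream(expr='lambda d: np.sum(d.weights)', event_name='batch')
```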
70 |
71 | ## Questions?
72 |
73 | File a [GitHub issue](https://github.com/microsoft/tensorwatch/issues/new) and let us know if we can improve this tutorial.
74 |
--------------------------------------------------------------------------------
/docs/paper/.gitignore:
--------------------------------------------------------------------------------
1 | ## Core latex/pdflatex auxiliary files:
2 | *.aux
3 | *.lof
4 | *.log
5 | *.lot
6 | *.fls
7 | *.out
8 | *.toc
9 | *.fmt
10 | *.fot
11 | *.cb
12 | *.cb2
13 | .*.lb
14 |
15 | ## Intermediate documents:
16 | *.dvi
17 | *.xdv
18 | *-converted-to.*
19 | # these rules might exclude image files for figures etc.
20 | # *.ps
21 | # *.eps
22 | # *.pdf
23 |
24 | ## Bibliography auxiliary files (bibtex/biblatex/biber):
25 | *.bbl
26 | *.bcf
27 | *.blg
28 | *-blx.aux
29 | *-blx.bib
30 | *.run.xml
31 |
32 | ## Build tool auxiliary files:
33 | *.fdb_latexmk
34 | *.synctex
35 | *.synctex(busy)
36 | *.synctex.gz
37 | *.synctex.gz(busy)
38 | *.pdfsync
39 |
40 | ## Build tool directories for auxiliary files
41 | # latexrun
42 | latex.out/
43 |
44 | ## Auxiliary and intermediate files from other packages:
45 | # algorithms
46 | *.alg
47 | *.loa
48 |
49 | # achemso
50 | acs-*.bib
51 |
52 | # amsthm
53 | *.thm
54 |
55 | # beamer
56 | *.nav
57 | *.pre
58 | *.snm
59 | *.vrb
60 |
61 | # changes
62 | *.soc
63 |
64 | # comment
65 | *.cut
66 |
67 | # cprotect
68 | *.cpt
69 |
70 | # elsarticle (documentclass of Elsevier journals)
71 | *.spl
72 |
73 | # endnotes
74 | *.ent
75 |
76 | # fixme
77 | *.lox
78 |
79 | # feynmf/feynmp
80 | *.mf
81 | *.mp
82 | *.t[1-9]
83 | *.t[1-9][0-9]
84 | *.tfm
85 |
86 | #(r)(e)ledmac/(r)(e)ledpar
87 | *.end
88 | *.?end
89 | *.[1-9]
90 | *.[1-9][0-9]
91 | *.[1-9][0-9][0-9]
92 | *.[1-9]R
93 | *.[1-9][0-9]R
94 | *.[1-9][0-9][0-9]R
95 | *.eledsec[1-9]
96 | *.eledsec[1-9]R
97 | *.eledsec[1-9][0-9]
98 | *.eledsec[1-9][0-9]R
99 | *.eledsec[1-9][0-9][0-9]
100 | *.eledsec[1-9][0-9][0-9]R
101 |
102 | # glossaries
103 | *.acn
104 | *.acr
105 | *.glg
106 | *.glo
107 | *.gls
108 | *.glsdefs
109 | *.lzo
110 | *.lzs
111 |
112 | # uncomment this for glossaries-extra (will ignore makeindex's style files!)
113 | # *.ist
114 |
115 | # gnuplottex
116 | *-gnuplottex-*
117 |
118 | # gregoriotex
119 | *.gaux
120 | *.gtex
121 |
122 | # htlatex
123 | *.4ct
124 | *.4tc
125 | *.idv
126 | *.lg
127 | *.trc
128 | *.xref
129 |
130 | # hyperref
131 | *.brf
132 |
133 | # knitr
134 | *-concordance.tex
135 | # TODO Comment the next line if you want to keep your tikz graphics files
136 | *.tikz
137 | *-tikzDictionary
138 |
139 | # listings
140 | *.lol
141 |
142 | # luatexja-ruby
143 | *.ltjruby
144 |
145 | # makeidx
146 | *.idx
147 | *.ilg
148 | *.ind
149 |
150 | # minitoc
151 | *.maf
152 | *.mlf
153 | *.mlt
154 | *.mtc[0-9]*
155 | *.slf[0-9]*
156 | *.slt[0-9]*
157 | *.stc[0-9]*
158 |
159 | # minted
160 | _minted*
161 | *.pyg
162 |
163 | # morewrites
164 | *.mw
165 |
166 | # nomencl
167 | *.nlg
168 | *.nlo
169 | *.nls
170 |
171 | # pax
172 | *.pax
173 |
174 | # pdfpcnotes
175 | *.pdfpc
176 |
177 | # sagetex
178 | *.sagetex.sage
179 | *.sagetex.py
180 | *.sagetex.scmd
181 |
182 | # scrwfile
183 | *.wrt
184 |
185 | # sympy
186 | *.sout
187 | *.sympy
188 | sympy-plots-for-*.tex/
189 |
190 | # pdfcomment
191 | *.upa
192 | *.upb
193 |
194 | # pythontex
195 | *.pytxcode
196 | pythontex-files-*/
197 |
198 | # tcolorbox
199 | *.listing
200 |
201 | # thmtools
202 | *.loe
203 |
204 | # TikZ & PGF
205 | *.dpth
206 | *.md5
207 | *.auxlock
208 |
209 | # todonotes
210 | *.tdo
211 |
212 | # vhistory
213 | *.hst
214 | *.ver
215 |
216 | # easy-todo
217 | *.lod
218 |
219 | # xcolor
220 | *.xcp
221 |
222 | # xmpincl
223 | *.xmpi
224 |
225 | # xindy
226 | *.xdy
227 |
228 | # xypic precompiled matrices and outlines
229 | *.xyc
230 | *.xyd
231 |
232 | # endfloat
233 | *.ttt
234 | *.fff
235 |
236 | # Latexian
237 | TSWLatexianTemp*
238 |
239 | ## Editors:
240 | # WinEdt
241 | *.bak
242 | *.sav
243 |
244 | # Texpad
245 | .texpadtmp
246 |
247 | # LyX
248 | *.lyx~
249 |
250 | # Kile
251 | *.backup
252 |
253 | # gummi
254 | .*.swp
255 |
256 | # KBibTeX
257 | *~[0-9]*
258 |
259 | # TeXnicCenter
260 | *.tps
261 |
262 | # auto folder when using emacs and auctex
263 | ./auto/*
264 | *.el
265 |
266 | # expex forward references with \gathertags
267 | *-tags.tex
268 |
269 | # standalone packages
270 | *.sta
271 |
272 | # Makeindex log files
273 | *.lpz
--------------------------------------------------------------------------------
/docs/paper/ACM-Reference-Format.cbx:
--------------------------------------------------------------------------------
1 | \ProvidesFile{ACM-Reference-Format.cbx}[2017-09-27 v0.1]
2 |
3 | \RequireCitationStyle{numeric}
4 |
5 | \endinput
6 |
--------------------------------------------------------------------------------
/docs/paper/ACM-Reference-Format.dbx:
--------------------------------------------------------------------------------
1 | % Teach biblatex about numpages field
2 | \DeclareDatamodelFields[type=field, datatype=literal]{numpages}
3 | \DeclareDatamodelEntryfields{numpages}
4 |
5 | % Teach biblatex about articleno field
6 | \DeclareDatamodelFields[type=field, datatype=literal]{articleno}
7 | \DeclareDatamodelEntryfields{articleno}
8 |
9 | % Teach biblatex about urls field
10 | \DeclareDatamodelFields[type=list, datatype=uri]{urls}
11 | \DeclareDatamodelEntryfields{urls}
12 |
13 | % Teach biblatex about school field
14 | \DeclareDatamodelFields[type=list, datatype=literal]{school}
15 | \DeclareDatamodelEntryfields[thesis]{school}
16 |
17 | \DeclareDatamodelFields[type=field, datatype=literal]{key}
18 | \DeclareDatamodelEntryfields{key}
--------------------------------------------------------------------------------
/docs/paper/TensorWatch_Collaboration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/paper/TensorWatch_Collaboration.png
--------------------------------------------------------------------------------
/docs/paper/main.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/paper/main.pdf
--------------------------------------------------------------------------------
/docs/paper/tensorwatch-screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/paper/tensorwatch-screenshot.png
--------------------------------------------------------------------------------
/docs/paper/tensorwatch-screenshot2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/docs/paper/tensorwatch-screenshot2.png
--------------------------------------------------------------------------------
/docs/simple_logging.md:
--------------------------------------------------------------------------------
1 |
2 | # Simple Logging Tutorial
3 |
4 | In this tutorial, we will use a simple script that maintains two variables. Our goal is to log the values of these variables using TensorWatch and view them in real-time, in a variety of ways, from a Jupyter Notebook, demonstrating various features offered by TensorWatch. We will execute this script from a console; its code looks like this:
5 |
6 | ```python
7 | # available in test/simple_log/sum_log.py
8 |
9 | import time, random
10 | import tensorwatch as tw
11 |
12 | # we will create two streams, one for
13 | # sums of integers, other for random numbers
14 | w = tw.Watcher()
15 | st_isum = w.create_stream('isums')
16 | st_rsum = w.create_stream('rsums')
17 |
18 | isum, rsum = 0, 0
19 | for i in range(10000):
20 | isum += i
21 |     rsum += random.randint(0, i)
22 |
23 | # write to streams
24 | st_isum.write((i, isum))
25 | st_rsum.write((i, rsum))
26 |
27 | time.sleep(1)
28 | ```
29 |
30 | ## Basic Concepts
31 |
32 | The `Watcher` object allows you to create streams. You can then write any values into these streams. By default, the `Watcher` object opens TCP/IP sockets so a client can request these streams. As values are written on one end, they are serialized and sent to any client(s) on the other end. We will use a Jupyter Notebook to create a client, open these streams, and feed them to various visualizers.
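
Sockets are not the only transport: a stream can also be persisted to a file and replayed later. Here is a minimal sketch (the `filename` parameter is an assumption here; verify it against the `Watcher` constructor):

```python
import tensorwatch as tw

# writer side: log values to a file instead of a socket
w = tw.Watcher(filename='test.log')  # filename parameter assumed
s = w.create_stream('my_metric')
s.write((1, 42))

# reader side: replay the stream from the same file
client = tw.WatcherClient(filename='test.log')
s2 = client.open_stream('my_metric')
```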
33 |
34 | ## Create the Client
35 | You can either create a new Jupyter Notebook or use the existing one in the repo at `notebooks/simple_logging.ipynb`.
36 |
37 | Let's first do imports:
38 |
39 |
40 | ```python
41 | %matplotlib notebook
42 | import tensorwatch as tw
43 | ```
44 |
45 | Next, open the streams that we created in the above script:
46 |
47 | ```python
48 | cli = tw.WatcherClient()
49 | st_isum = cli.open_stream('isums')
50 | st_rsum = cli.open_stream('rsums')
51 | ```
52 |
53 | Now let's visualize the `isum` stream in textual format:
54 |
55 |
56 | ```python
57 | text_vis = tw.Visualizer(st_isum, vis_type='text')
58 | text_vis.show()
59 | ```
60 |
61 | You should see a table growing in real-time like this; each row is a tuple we wrote to the stream:
62 |
63 |
64 |
65 |
66 | That worked out well! How about simultaneously feeding the same stream to a line graph?
67 |
68 |
69 | ```python
70 | line_plot = tw.Visualizer(st_isum, vis_type='line', xtitle='i', ytitle='isum')
71 | line_plot.show()
72 | ```
73 |
74 |
75 |
76 | Let's take it a step further. Say we plot `rsum` as well but, just for fun, in the *same plot* as `isum` above. That's easy with the optional `host` parameter:
77 |
78 | ```python
79 | line_plot = tw.Visualizer(st_rsum, vis_type='line', host=line_plot, ytitle='rsum')
80 | ```
81 | This instantly changes our line plot: a new Y-axis is automatically added for our second stream. TensorWatch allows you to add multiple streams to the same visualizer.
82 |
83 | Are you curious about the statistics of these two streams? Let's display the live statistics of both streams side by side! To do this, we will use the `cell` parameter for the second visualization to place it right next to the first one:
84 |
85 | ```python
86 | istats = tw.Visualizer(st_isum, vis_type='summary')
87 | istats.show()
88 |
89 | rstats = tw.Visualizer(st_rsum, vis_type='summary', cell=istats)
90 | ```
91 |
92 | The output looks like this (each panel shows statistics for the tuples logged to its stream):
93 |
94 |
95 |
96 | ## Next Steps
97 |
98 | We have just scratched the surface of what you can do with TensorWatch. You should check out [Lazy Logging Tutorial](lazy_logging.md) and other [notebooks](https://github.com/microsoft/tensorwatch/tree/master/notebooks) as well.
99 |
100 | ## Questions?
101 |
102 | File a [GitHub issue](https://github.com/microsoft/tensorwatch/issues/new) and let us know if we can improve this tutorial.
--------------------------------------------------------------------------------
/install_jupyterlab.bat:
--------------------------------------------------------------------------------
1 | conda install -c conda-forge jupyterlab nodejs
2 | conda install ipywidgets
3 | conda install -c plotly plotly-orca psutil
4 |
5 | set NODE_OPTIONS=--max-old-space-size=4096
6 | jupyter labextension install @jupyter-widgets/jupyterlab-manager --no-build
7 | jupyter labextension install plotlywidget --no-build
8 | jupyter labextension install @jupyterlab/plotly-extension --no-build
9 | jupyter labextension install jupyterlab-chart-editor --no-build
10 | jupyter lab build
11 | set NODE_OPTIONS=
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import setuptools
5 |
6 | with open("README.md", "r") as fh:
7 | long_description = fh.read()
8 |
9 | setuptools.setup(
10 | name="tensorwatch",
11 | version="0.9.1",
12 | author="Shital Shah",
13 | author_email="shitals@microsoft.com",
14 | description="Interactive Realtime Debugging and Visualization for AI",
15 | long_description=long_description,
16 | long_description_content_type="text/markdown",
17 | url="https://github.com/microsoft/tensorwatch",
18 | packages=setuptools.find_packages(),
19 | license='MIT',
20 | classifiers=(
21 | "Programming Language :: Python :: 3",
22 | "License :: OSI Approved :: MIT License",
23 | "Operating System :: OS Independent",
24 | ),
25 | include_package_data=True,
26 | install_requires=[
27 | 'matplotlib', 'numpy', 'pyzmq', 'plotly', 'ipywidgets',
28 | 'pydot>=1.4.2',
29 |         'nbformat', 'scikit-image', 'pyyaml', 'graphviz' # , 'receptivefield'
30 | ]
31 | )
32 |
--------------------------------------------------------------------------------
/tensorwatch.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio 15
4 | VisualStudioVersion = 15.0.27428.2043
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "tensorwatch", "tensorwatch.pyproj", "{CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}"
7 | EndProject
8 | Global
9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
10 | Debug|Any CPU = Debug|Any CPU
11 | Release|Any CPU = Release|Any CPU
12 | EndGlobalSection
13 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
14 | {CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
15 | {CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}.Release|Any CPU.ActiveCfg = Release|Any CPU
16 | EndGlobalSection
17 | GlobalSection(SolutionProperties) = preSolution
18 | HideSolutionNode = FALSE
19 | EndGlobalSection
20 | GlobalSection(ExtensibilityGlobals) = postSolution
21 | SolutionGuid = {99E7AEC7-2CDE-48C8-B98B-4E28E4F840B6}
22 | EndGlobalSection
23 | EndGlobal
24 |
--------------------------------------------------------------------------------
/tensorwatch/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from typing import Iterable, Sequence, Union
5 |
6 | from .watcher_client import WatcherClient
7 | from .watcher import Watcher
8 | from .watcher_base import WatcherBase
9 |
10 | from .text_vis import TextVis
11 | from .plotly.embeddings_plot import EmbeddingsPlot
12 | from .mpl.line_plot import LinePlot
13 | from .mpl.image_plot import ImagePlot
14 | from .mpl.histogram import Histogram
15 | from .mpl.bar_plot import BarPlot
16 | from .mpl.pie_chart import PieChart
17 | from .visualizer import Visualizer
18 |
19 | from .stream import Stream
20 | from .array_stream import ArrayStream
21 | from .lv_types import PointData, ImageData, VisArgs, StreamItem, PredictionResult
22 | from . import utils
23 |
24 | ###### Import methods for tw namespace #########
25 | #from .receptive_field.rf_utils import plot_receptive_field, plot_grads_at
26 | from .embeddings.tsne_utils import get_tsne_components
27 | from .model_graph.torchstat_utils import model_stats, ModelStats
28 | from .image_utils import show_image, open_image, img2pyt, linear_to_2d, plt_loop
29 | from .data_utils import pyt_ds2list, sample_by_class, col2array, search_similar
30 |
31 |
32 |
33 | def draw_model(model, input_shape=None, orientation='TB', png_filename=None): # orientation='LR' for landscape
34 | from .model_graph.hiddenlayer import pytorch_draw_model
35 | g = pytorch_draw_model.draw_graph(model, input_shape)
36 | return g
37 |
38 |
39 |
--------------------------------------------------------------------------------
/tensorwatch/array_stream.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .stream import Stream
5 | from .lv_types import StreamItem
6 | import uuid
7 |
8 | class ArrayStream(Stream):
9 | def __init__(self, array, stream_name:str=None, console_debug:bool=False):
10 | super(ArrayStream, self).__init__(stream_name=stream_name, console_debug=console_debug)
11 |
12 | self.stream_name = stream_name
13 | self.array = array
14 |
15 | def load(self, from_stream:'Stream'=None):
16 | if self.array is not None:
17 | self.write(self.array)
18 | super(ArrayStream, self).load()
19 |
--------------------------------------------------------------------------------
/tensorwatch/data_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import random
5 | import scipy.spatial.distance
6 | import heapq
7 | import numpy as np
8 | import torch
9 |
10 | def pyt_tensor2np(pyt_tensor):
11 | if pyt_tensor is None:
12 | return None
13 | if isinstance(pyt_tensor, torch.Tensor):
14 | n = pyt_tensor.data.cpu().numpy()
15 | if len(n.shape) == 1:
16 | return n[0]
17 | else:
18 | return n
19 | elif isinstance(pyt_tensor, np.ndarray):
20 | return pyt_tensor
21 | else:
22 | return np.array(pyt_tensor)
23 |
24 | def pyt_tuple2np(pyt_tuple):
25 | return tuple((pyt_tensor2np(t) for t in pyt_tuple))
26 |
27 | def pyt_ds2list(pyt_ds, count=None):
28 | count = count or len(pyt_ds)
29 | return [pyt_tuple2np(t) for t, c in zip(pyt_ds, range(count))]
30 |
31 | def sample_by_class(data, n_samples, class_col=1, shuffle=True):
32 | if shuffle:
33 | random.shuffle(data)
34 | samples = {}
35 | for i, t in enumerate(data):
36 | cls = t[class_col]
37 | if cls not in samples:
38 | samples[cls] = []
39 | if len(samples[cls]) < n_samples:
40 | samples[cls].append(data[i])
41 | samples = sum(samples.values(), [])
42 | return samples
43 |
44 | def col2array(dataset, col):
45 | return [row[col] for row in dataset]
46 |
47 | def search_similar(inputs, compare_to, algorithm='euclidean', topk=5, invert_score=True):
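    # For each input row, keep the topk rows of compare_to in a heap of
    # (score, (index, input_row, candidate_row)) tuples; with invert_score=True,
    # score = 1/(distance + 1e-6), so a larger score means a closer match.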
48 | all_scores = scipy.spatial.distance.cdist(inputs, compare_to, algorithm)
49 | all_results = []
50 | for input_val, scores in zip(inputs, all_scores):
51 | result = []
52 | for i, (score, data) in enumerate(zip(scores, compare_to)):
53 | if invert_score:
54 | score = 1/(score + 1.0E-6)
55 | if len(result) < topk:
56 | heapq.heappush(result, (score, (i, input_val, data)))
57 | else:
58 | heapq.heappushpop(result, (score, (i, input_val, data)))
59 | all_results.append(result)
60 | return all_results
--------------------------------------------------------------------------------
/tensorwatch/embeddings/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
5 |
--------------------------------------------------------------------------------
/tensorwatch/embeddings/tsne_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from sklearn.manifold import TSNE
5 | from sklearn.preprocessing import StandardScaler
6 | import numpy as np
7 |
8 | def _standardize_data(data, col, whitten, flatten):
9 | if col is not None:
10 | data = data[col]
11 |
12 | #TODO: enable auto flattening
13 | #if data is tensor then flatten it first
14 | #if flatten and len(data) > 0 and hasattr(data[0], 'shape') and \
15 | # utils.has_method(data[0], 'reshape'):
16 |
17 | # data = [d.reshape((-1,)) for d in data]
18 |
19 | if whitten:
20 | data = StandardScaler().fit_transform(data)
21 | return data
22 |
23 | def get_tsne_components(data, features_col=0, labels_col=1, whitten=True, n_components=3, perplexity=20, flatten=True, for_plot=True):
24 | features = _standardize_data(data, features_col, whitten, flatten)
25 | tsne = TSNE(n_components=n_components, perplexity=perplexity)
26 | tsne_results = tsne.fit_transform(features)
27 |
28 | if for_plot:
29 | comps = tsne_results.tolist()
30 | labels = data[labels_col]
31 | for i, item in enumerate(comps):
32 | # low, high, annotation, text, color
33 | label = labels[i]
34 | if isinstance(labels, np.ndarray):
35 | label = label.item()
36 | item.extend((None, None, None, str(int(label)), label))
37 | return comps
38 | return tsne_results
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/tensorwatch/evaler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import threading, sys, logging
5 | from collections.abc import Iterator
6 | from .lv_types import EventData
7 |
8 | # pylint: disable=unused-wildcard-import
9 | # pylint: disable=wildcard-import
10 | # pylint: disable=unused-import
11 | from functools import *
12 | from itertools import *
13 | from statistics import *
14 | import numpy as np
15 | from .evaler_utils import *
16 |
17 | class Evaler:
18 | class EvalReturn:
19 | def __init__(self, result=None, is_valid=False, exception=None):
20 | self.result, self.exception, self.is_valid = \
21 | result, exception, is_valid
22 | def reset(self):
23 | self.result, self.exception, self.is_valid = \
24 | None, None, False
25 |
26 | class PostableIterator:
27 | def __init__(self, eval_wait):
28 | self.eval_wait = eval_wait
29 | self.post_wait = threading.Event()
30 | self.event_data, self.ended = None, None # define attributes in init
31 | self.reset()
32 |
33 | def reset(self):
34 | self.event_data, self.ended = None, False
35 | self.post_wait.clear()
36 |
37 | def abort(self):
38 | self.ended = True
39 | self.post_wait.set()
40 |
41 | def post(self, event_data:EventData=None, ended=False):
42 | self.event_data, self.ended = event_data, ended
43 | self.post_wait.set()
44 |
45 | def get_vals(self):
46 | while True:
47 | self.post_wait.wait()
48 | self.post_wait.clear()
49 | if self.ended:
50 | break
51 | else:
52 | yield self.event_data
53 | # below will cause result=None, is_valid=False when
54 | # expression has reduce
55 | self.eval_wait.set()
56 |
57 | def __init__(self, expr):
58 | self.eval_wait = threading.Event()
59 | self.reset_wait = threading.Event()
60 | self.g = Evaler.PostableIterator(self.eval_wait)
61 | self.expr = expr
62 | self.eval_return, self.continue_thread = None, None # define in __init__
63 | self.reset()
64 |
65 | self.th = threading.Thread(target=self._runner, daemon=True, name='evaler')
66 | self.th.start()
67 | self.running = True
68 |
69 | def reset(self):
70 | self.g.reset()
71 | self.eval_wait.clear()
72 | self.reset_wait.clear()
73 | self.eval_return = Evaler.EvalReturn()
74 | self.continue_thread = True
75 |
76 | def _runner(self):
77 | while True:
78 | # this var will be used by eval
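            # NOTE: expressions consume posted values by referencing this
            # iterator as `l`, e.g. 'map(lambda d: d.x, l)' (an inference from
            # how eval() below resolves names from this local scope)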
79 | l = self.g.get_vals() # pylint: disable=unused-variable
80 | try:
81 | result = eval(self.expr) # pylint: disable=eval-used
82 | if isinstance(result, Iterator):
83 | for item in result:
84 | self.eval_return = Evaler.EvalReturn(item, True)
85 | else:
86 | self.eval_return = Evaler.EvalReturn(result, True)
87 | except Exception as ex: # pylint: disable=broad-except
88 |                 logging.exception('Exception occurred while evaluating expression: ' + self.expr)
89 | self.eval_return = Evaler.EvalReturn(None, True, ex)
90 | self.eval_wait.set()
91 | self.reset_wait.wait()
92 | if not self.continue_thread:
93 | break
94 | self.reset()
95 | self.running = False
96 | utils.debug_log('eval runner ended!')
97 |
98 | def abort(self):
99 | utils.debug_log('Evaler Aborted')
100 | self.continue_thread = False
101 | self.g.abort()
102 | self.eval_wait.set()
103 | self.reset_wait.set()
104 |
105 | def post(self, event_data:EventData=None, ended=False, continue_thread=True):
106 | if not self.running:
107 | utils.debug_log('post was called when Evaler is not running')
108 | return None, False
109 | self.eval_return.reset()
110 | self.g.post(event_data, ended)
111 | self.eval_wait.wait()
112 | self.eval_wait.clear()
113 | # save result before it would get reset
114 | eval_return = self.eval_return
115 | self.reset_wait.set()
116 | self.continue_thread = continue_thread
117 | if isinstance(eval_return.result, Iterator):
118 | eval_return.result = list(eval_return.result)
119 | return eval_return
120 |
121 | def join(self):
122 | self.th.join()
123 |
--------------------------------------------------------------------------------
/tensorwatch/file_stream.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .stream import Stream
5 | import pickle, os
6 | from typing import Any
7 | from . import utils
8 | import time
9 |
10 | class FileStream(Stream):
11 | def __init__(self, for_write:bool, file_name:str, stream_name:str=None, console_debug:bool=False):
12 | super(FileStream, self).__init__(stream_name=stream_name or file_name, console_debug=console_debug)
13 |
14 | self._file = open(file_name, 'wb' if for_write else 'rb')
15 | self.file_name = file_name
16 | self.for_write = for_write
17 | utils.debug_log('FileStream started', os.path.realpath(self._file.name), verbosity=0)
18 |
19 | def close(self):
20 | if not self._file.closed:
21 | self._file.close()
22 |             utils.debug_log('FileStream is closed', os.path.realpath(self._file.name), verbosity=0)
23 |             self._file = None
24 | super(FileStream, self).close()
25 |
26 | def write(self, val:Any, from_stream:'Stream'=None):
27 | stream_item = self.to_stream_item(val)
28 |
29 | if self.for_write:
30 | pickle.dump(stream_item, self._file)
31 | self._file.flush()
32 | #os.fsync()
33 | super(FileStream, self).write(stream_item)
34 |
35 | def read_all(self, from_stream:'Stream'=None):
36 | if self.for_write:
37 |             raise IOError('Cannot use read_all() because FileStream is opened with for_write=True')
38 | if self._file is not None:
39 | self._file.seek(0, 0) # we may filter this stream multiple times
40 | while not utils.is_eof(self._file):
41 | yield pickle.load(self._file)
42 | for item in super(FileStream, self).read_all():
43 | yield item
44 |
45 | def load(self, from_stream:'Stream'=None):
46 | if self.for_write:
47 |             raise IOError('Cannot use load() because FileStream is opened with for_write=True')
48 | if self._file is not None:
49 | self._file.seek(0, 0) # we may filter this stream multiple times
50 | while not utils.is_eof(self._file):
51 | stream_item = pickle.load(self._file)
52 | self.write(stream_item)
53 | super(FileStream, self).load()
54 |
55 | def save(self, from_stream:'Stream'=None):
56 | if not self._file.closed:
57 | self._file.flush()
58 |         super(FileStream, self).save()
59 |
60 |
--------------------------------------------------------------------------------
/tensorwatch/filtered_stream.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .stream import Stream
5 | from typing import Callable, Any
6 |
7 | class FilteredStream(Stream):
8 | def __init__(self, source_stream:Stream, filter_expr:Callable, stream_name:str=None,
9 | console_debug:bool=False)->None:
10 |
11 | stream_name = stream_name or '{}|{}'.format(source_stream.stream_name, str(filter_expr))
12 | super(FilteredStream, self).__init__(stream_name=stream_name, console_debug=console_debug)
13 | self.subscribe(source_stream)
14 | self.filter_expr = filter_expr
15 |
16 | def _filter(self, stream_item):
17 | return self.filter_expr(stream_item) \
18 | if self.filter_expr is not None \
19 | else (stream_item, True)
20 |
21 | def write(self, val:Any, from_stream:'Stream'=None):
22 | stream_item = self.to_stream_item(val)
23 |
24 | result, is_valid = self._filter(stream_item)
25 | if is_valid:
26 | return super(FilteredStream, self).write(result)
27 | # else ignore this call
28 |
29 | def read_all(self, from_stream:'Stream'=None): #override->replacement
30 | for subscribed_to in self._subscribed_to:
31 | for stream_item in subscribed_to.read_all(from_stream=self):
32 | result, is_valid = self._filter(stream_item)
33 | if is_valid:
34 | yield stream_item
35 |
36 |
--------------------------------------------------------------------------------
/tensorwatch/image_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import numpy as np
5 | import math
6 | import time
7 |
8 | def guess_image_dims(img):
9 | if len(img.shape) == 1:
10 | # assume 2D monochrome (MNIST)
11 | width = height = round(math.sqrt(img.shape[0]))
12 | if width*height != img.shape[0]:
13 |             # assume 3 channels (CIFAR, ImageNet)
14 | width = height = round(math.sqrt(img.shape[0] / 3))
15 | if width*height*3 != img.shape[0]:
16 | raise ValueError("Cannot guess image dimensions for linearized pixels")
17 | return (3, height, width)
18 | return (1, height, width)
19 | return img.shape
20 |
21 | def to_imshow_array(img, width=None, height=None):
22 | # array from Pytorch has shape: [[channels,] height, width]
23 | # image needed for imshow needs: [height, width, channels]
24 | from PIL import Image
25 |
26 | if img is not None:
27 | if isinstance(img, Image.Image):
28 | img = np.array(img)
29 | if len(img.shape) >= 2:
30 | return img # img is already compatible to imshow
31 |
32 | # force max 3 dimensions
33 | if len(img.shape) > 3:
34 | # TODO allow config
35 | # select first one in batch
36 | img = img[0:1,:,:]
37 |
38 | if len(img.shape) == 1: # linearized pixels typically used for MLPs
39 | if not(width and height):
40 | # pylint: disable=unused-variable
41 | channels, height, width = guess_image_dims(img)
42 | img = img.reshape((-1, height, width))
43 |
44 | if len(img.shape) == 3:
45 | if img.shape[0] == 1: # single channel images
46 | img = img.squeeze(0)
47 | else:
48 | img = np.swapaxes(img, 0, 2) # transpose H,W for imshow
49 | img = np.swapaxes(img, 0, 1)
50 | elif len(img.shape) == 2:
51 | img = np.swapaxes(img, 0, 1) # transpose H,W for imshow
52 | else: #zero dimensions
53 | img = None
54 |
55 | return img
56 |
57 | #width_dim=1 for imshow, 2 for pytorch arrays
58 | def stitch_horizontal(images, width_dim=1):
59 | return np.concatenate(images, axis=width_dim)
60 |
61 | def _resize_image(img, size=None):
62 | if size is not None or (hasattr(img, 'shape') and len(img.shape) == 1):
63 | if size is None:
64 | # make guess for 1-dim tensors
65 | h = int(math.sqrt(img.shape[0]))
66 | w = int(img.shape[0] / h)
67 | size = h,w
68 | img = np.reshape(img, size)
69 | return img
70 |
71 | def show_image(img, size=None, alpha=None, cmap=None,
72 | img2=None, size2=None, alpha2=None, cmap2=None, ax=None):
73 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue
74 |
75 | img =_resize_image(img, size)
76 | img2 =_resize_image(img2, size2)
77 |
78 | (ax or plt).imshow(img, alpha=alpha, cmap=cmap)
79 |
80 | if img2 is not None:
81 | (ax or plt).imshow(img2, alpha=alpha2, cmap=cmap2)
82 |
83 | return ax or plt.show()
84 |
85 | # convert_mode param is mode: https://pillow.readthedocs.io/en/5.1.x/handbook/concepts.html#modes
86 | # use convert_mode='RGB' to force 3 channels
87 | def open_image(path, resize=None, resample=1, convert_mode=None): # Image.ANTIALIAS==1
88 | from PIL import Image
89 |
90 | img = Image.open(path)
91 | if resize is not None:
92 | img = img.resize(resize, resample)
93 | if convert_mode is not None:
94 | img = img.convert(convert_mode)
95 | return img
96 |
97 | def img2pyt(img, add_batch_dim=True, resize=None):
98 | from torchvision import transforms # expensive function-level import
99 |
100 | ts = []
101 | if resize is not None:
102 | ts.append(transforms.RandomResizedCrop(resize))
103 | ts.append(transforms.ToTensor())
104 | img_pyt = transforms.Compose(ts)(img)
105 | if add_batch_dim:
106 | img_pyt.unsqueeze_(0)
107 | return img_pyt
108 |
109 | def linear_to_2d(img, size=None):
110 | if size is not None or (hasattr(img, 'shape') and len(img.shape) == 1):
111 | if size is None:
112 | # make guess for 1-dim tensors
113 | h = int(math.sqrt(img.shape[0]))
114 | w = int(img.shape[0] / h)
115 | size = h,w
116 | img = np.reshape(img, size)
117 | return img
118 |
119 | def stack_images(imgs):
120 | return np.hstack(imgs)
121 |
122 | def plt_loop(count=None, sleep_time=1, plt_pause=0.01):
123 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue
124 |
125 | #plt.ion()
126 | #plt.show(block=False)
127 | while((count is None or count > 0) and not plt.waitforbuttonpress(plt_pause)):
128 | #plt.draw()
129 | plt.pause(plt_pause)
130 | time.sleep(sleep_time)
131 | if count is not None:
132 | count = count - 1
133 |
134 | def get_cmap(name:str):
135 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue
136 | return plt.cm.get_cmap(name=name)
137 |
138 |
--------------------------------------------------------------------------------
/tensorwatch/imagenet_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from torchvision import transforms
5 | from . import pytorch_utils
6 | import json, os
7 |
8 | def get_image_transform():
9 | transf = transforms.Compose([ #TODO: cache these transforms?
10 | get_resize_transform(),
11 | transforms.ToTensor(),
12 | get_normalize_transform()
13 | ])
14 |
15 | return transf
16 |
17 | def get_resize_transform():
18 | return transforms.Resize((224, 224))
19 |
20 | def get_normalize_transform():
21 | return transforms.Normalize(mean=[0.485, 0.456, 0.406],
22 | std=[0.229, 0.224, 0.225])
23 |
24 | def image2batch(image):
25 | return pytorch_utils.image2batch(image, image_transform=get_image_transform())
26 |
27 | def predict(model, images, image_transform=None, device=None):
28 | logits = pytorch_utils.batch_predict(model, images,
29 | input_transform=image_transform or get_image_transform(), device=device)
30 | probs = pytorch_utils.logits2probabilities(logits) #2-dim array, one column per class, one row per input
31 | return probs
32 |
33 | _imagenet_labels = None
34 | def get_imagenet_labels():
35 | # pylint: disable=global-statement
36 | global _imagenet_labels
37 | _imagenet_labels = _imagenet_labels or ImagenetLabels()
38 | return _imagenet_labels
39 |
40 | def probabilities2classes(probs, topk=5):
41 | labels = get_imagenet_labels()
42 | top_probs = probs.topk(topk)
43 | # return (probability, class_id, class_label, class_code)
44 | return tuple((p,c, labels.index2label_text(c), labels.index2label_code(c)) \
45 | for p, c in zip(top_probs[0][0].data.cpu().numpy(), top_probs[1][0].data.cpu().numpy()))
46 |
47 | class ImagenetLabels:
48 | def __init__(self, json_path=None):
49 | self._idx2label = []
50 | self._idx2cls = []
51 | self._cls2label = {}
52 | self._cls2idx = {}
53 |
54 | json_path = json_path or os.path.join(os.path.dirname(__file__), 'imagenet_class_index.json')
55 |
56 | with open(os.path.abspath(json_path), "r") as read_file:
57 | class_json = json.load(read_file)
58 | self._idx2label = [class_json[str(k)][1] for k in range(len(class_json))]
59 | self._idx2cls = [class_json[str(k)][0] for k in range(len(class_json))]
60 | self._cls2label = {class_json[str(k)][0]: class_json[str(k)][1] for k in range(len(class_json))}
61 | self._cls2idx = {class_json[str(k)][0]: k for k in range(len(class_json))}
62 |
63 | def index2label_text(self, index):
64 | return self._idx2label[index]
65 | def index2label_code(self, index):
66 | return self._idx2cls[index]
67 | def label_code2label_text(self, label_code):
68 | return self._cls2label[label_code]
69 | def label_code2index(self, label_code):
70 | return self._cls2idx[label_code]
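# Illustrative usage sketch (the model and image path are examples, not part of
# this module):
#
#   import torchvision.models as models
#   from . import image_utils
#   model = models.resnet50(pretrained=True)
#   img = image_utils.open_image('data/test_images/dogs.png', convert_mode='RGB')
#   probs = predict(model, [img])
#   for p, class_id, label, code in probabilities2classes(probs, topk=5):
#       print(p, class_id, label, code)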
--------------------------------------------------------------------------------
/tensorwatch/model_graph/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
--------------------------------------------------------------------------------
/tensorwatch/model_graph/hiddenlayer/README.md:
--------------------------------------------------------------------------------
1 | # Credits
2 |
3 | Code in this folder is adapted from:
4 |
5 | * https://github.com/waleedka/hiddenlayer
6 |
--------------------------------------------------------------------------------
/tensorwatch/model_graph/hiddenlayer/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tensorwatch/model_graph/hiddenlayer/distiller.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2018 Intel Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 |
17 | import torch
18 | from .distiller_utils import *
19 |
20 | import logging
21 | logging.captureWarnings(True)
22 |
23 | def model_find_param_name(model, param_to_find):
24 | """Look up the name of a model parameter.
25 |
26 | Arguments:
27 | model: the model to search
28 | param_to_find: the parameter whose name we want to look up
29 |
30 | Returns:
31 | The parameter name (string) or None, if the parameter was not found.
32 | """
33 | for name, param in model.named_parameters():
34 | if param is param_to_find:
35 | return name
36 | return None
37 |
38 |
39 | def model_find_module_name(model, module_to_find):
40 | """Look up the name of a module in a model.
41 |
42 | Arguments:
43 | model: the model to search
44 | module_to_find: the module whose name we want to look up
45 |
46 | Returns:
47 | The module name (string) or None, if the module was not found.
48 | """
49 | for name, m in model.named_modules():
50 | if m == module_to_find:
51 | return name
52 | return None
53 |
54 |
55 | def model_find_param(model, param_to_find_name):
56 | """Look a model parameter by its name
57 |
58 | Arguments:
59 | model: the model to search
60 | param_to_find_name: the name of the parameter that we are searching for
61 |
62 | Returns:
63 | The parameter or None, if the parameter name was not found.
64 | """
65 | for name, param in model.named_parameters():
66 | if name == param_to_find_name:
67 | return param
68 | return None
69 |
70 |
71 | def model_find_module(model, module_to_find):
72 | """Given a module name, find the module in the provided model.
73 |
74 | Arguments:
75 | model: the model to search
77 | module_to_find: the name of the module we want to find
77 |
78 | Returns:
79 | The module or None, if the module was not found.
80 | """
81 | for name, m in model.named_modules():
82 | if name == module_to_find:
83 | return m
84 | return None
85 |
86 |
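# Illustrative sketch of the four lookup helpers (the model and names below are
# examples, not part of this module):
#
#   import torchvision.models as models
#   model = models.vgg16()
#   conv = model_find_module(model, 'features.0')       # name -> module
#   name = model_find_module_name(model, conv)          # module -> 'features.0'
#   w = model_find_param(model, 'features.0.weight')    # name -> parameter
#   pname = model_find_param_name(model, w)             # parameter -> name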
--------------------------------------------------------------------------------
/tensorwatch/model_graph/hiddenlayer/ge.py:
--------------------------------------------------------------------------------
1 | """
2 | HiddenLayer
3 |
4 | Implementation of graph expressions to find nodes in a graph based on a pattern.
5 |
6 | Written by Waleed Abdulla
7 | Licensed under the MIT License
8 | """
9 |
10 | import re
11 |
12 |
13 |
14 | class GEParser():
15 | def __init__(self, text):
16 | self.index = 0
17 | self.text = text
18 |
19 | def parse(self):
20 | return self.serial() or self.parallel() or self.expression()
21 |
22 | def parallel(self):
23 | index = self.index
24 | expressions = []
25 | while len(expressions) == 0 or self.token("|"):
26 | e = self.expression()
27 | if not e:
28 | break
29 | expressions.append(e)
30 | if len(expressions) >= 2:
31 | return ParallelPattern(expressions)
32 | # No match. Reset index
33 | self.index = index
34 |
35 | def serial(self):
36 | index = self.index
37 | expressions = []
38 | while len(expressions) == 0 or self.token(">"):
39 | e = self.expression()
40 | if not e:
41 | break
42 | expressions.append(e)
43 |
44 | if len(expressions) >= 2:
45 | return SerialPattern(expressions)
46 | self.index = index
47 |
48 | def expression(self):
49 | index = self.index
50 |
51 | if self.token("("):
52 | e = self.serial() or self.parallel() or self.op()
53 | if e and self.token(")"):
54 | return e
55 | self.index = index
56 | e = self.op()
57 | return e
58 |
59 | def op(self):
60 | t = self.re(r"\w+")
61 | if t:
62 | c = self.condition()
63 | return NodePattern(t, c)
64 |
65 | def condition(self):
66 | # TODO: not implemented yet. This function is a placeholder
67 | index = self.index
68 | if self.token("["):
69 | c = self.token("1x1") or self.token("3x3")
70 | if c:
71 | if self.token("]"):
72 | return c
73 | self.index = index
74 |
75 | def token(self, s):
76 | return self.re(r"\s*(" + re.escape(s) + r")\s*", 1)
77 |
78 | def string(self, s):
79 | if s == self.text[self.index:self.index+len(s)]:
80 | self.index += len(s)
81 | return s
82 |
83 | def re(self, regex, group=0):
84 | m = re.match(regex, self.text[self.index:])
85 | if m:
86 | self.index += len(m.group(0))
87 | return m.group(group)
88 |
89 |
90 | class NodePattern():
91 | def __init__(self, op, condition=None):
92 | self.op = op
93 | self.condition = condition # TODO: not implemented yet
94 |
95 | def match(self, graph, node):
96 | if isinstance(node, list):
97 | return [], None
98 | if self.op == node.op:
99 | following = graph.outgoing(node)
100 | if len(following) == 1:
101 | following = following[0]
102 | return [node], following
103 | else:
104 | return [], None
105 |
106 |
107 | class SerialPattern():
108 | def __init__(self, patterns):
109 | self.patterns = patterns
110 |
111 | def match(self, graph, node):
112 | all_matches = []
113 | for i, p in enumerate(self.patterns):
114 | matches, following = p.match(graph, node)
115 | if not matches:
116 | return [], None
117 | all_matches.extend(matches)
118 | if i < len(self.patterns) - 1:
119 | node = following # Might be more than one node
120 | return all_matches, following
121 |
122 |
123 | class ParallelPattern():
124 | def __init__(self, patterns):
125 | self.patterns = patterns
126 |
127 | def match(self, graph, nodes):
128 | if not nodes:
129 | return [], None
130 | nodes = nodes if isinstance(nodes, list) else [nodes]
131 | # If a single node, assume we need to match with its siblings
132 | if len(nodes) == 1:
133 | nodes = graph.siblings(nodes[0])
134 | else:
135 | # Verify all nodes have the same parent or all have no parent
136 | parents = [graph.incoming(n) for n in nodes]
137 | matches = [set(p) == set(parents[0]) for p in parents[1:]]
138 | if not all(matches):
139 | return [], None
140 |
141 | # TODO: If more nodes than patterns, we should consider
142 | # all permutations of the nodes
143 | if len(self.patterns) != len(nodes):
144 | return [], None
145 |
146 | patterns = self.patterns.copy()
147 | nodes = nodes.copy()
148 | all_matches = []
149 | end_node = None
150 | for p in patterns:
151 | found = False
152 | for n in nodes:
153 | matches, following = p.match(graph, n)
154 | if matches:
155 | found = True
156 | nodes.remove(n)
157 | all_matches.extend(matches)
158 | # Verify all branches end in the same node
159 | if end_node:
160 | if end_node != following:
161 | return [], None
162 | else:
163 | end_node = following
164 | break
165 | if not found:
166 | return [], None
167 | return all_matches, end_node
168 |
169 |
170 |
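# Illustrative sketch of the expression syntax (op names are examples; matching
# additionally needs a Graph object providing outgoing()/incoming()/siblings(),
# defined in graph.py):
#
#   chain = GEParser("Conv > BatchNorm > Relu").parse()   # SerialPattern
#   forks = GEParser("Conv | MaxPool").parse()            # ParallelPattern
#   matches, following = chain.match(graph, node)         # ([], None) on no match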
--------------------------------------------------------------------------------
/tensorwatch/model_graph/hiddenlayer/pytorch_builder.py:
--------------------------------------------------------------------------------
1 | """
2 | HiddenLayer
3 |
4 | PyTorch graph importer.
5 |
6 | Written by Waleed Abdulla
7 | Licensed under the MIT License
8 | """
9 |
10 | from __future__ import absolute_import, division, print_function
11 | import re
12 | from .graph import Graph, Node
13 | from . import transforms as ht
14 | import torch
15 | from collections import abc
16 | import numpy as np
17 |
18 | # PyTorch Graph Transforms
19 | FRAMEWORK_TRANSFORMS = [
20 | # Hide onnx: prefix
21 | ht.Rename(op=r"onnx::(.*)", to=r"\1"),
22 | # ONNX uses Gemm for linear layers (stands for General Matrix Multiplication).
23 | # It's an odd name that no one recognizes. Rename it.
24 | ht.Rename(op=r"Gemm", to=r"Linear"),
25 | # PyTorch layers that don't have an ONNX counterpart
26 | ht.Rename(op=r"aten::max\_pool2d\_with\_indices", to="MaxPool"),
27 | # Shorten op name
28 | ht.Rename(op=r"BatchNormalization", to="BatchNorm"),
29 | ]
30 |
31 |
32 | def dump_pytorch_graph(graph):
33 | """List all the nodes in a PyTorch graph."""
34 | f = "{:25} {:40} {} -> {}"
35 | print(f.format("kind", "scopeName", "inputs", "outputs"))
36 | for node in graph.nodes():
37 | print(f.format(node.kind(), node.scopeName(),
38 | [i.unique() for i in node.inputs()],
39 | [i.unique() for i in node.outputs()]
40 | ))
41 |
42 |
43 | def pytorch_id(node):
44 | """Returns a unique ID for a node."""
45 | # After ONNX simplification, the scopeName is not unique anymore
46 | # so append node outputs to guarantee uniqueness
47 | return node.scopeName() + "/outputs/" + "/".join([o.debugName() for o in node.outputs()])
48 |
49 |
50 |
51 | def get_shape(torch_node):
52 | """Return the output shape of the given Pytorch node."""
53 | # Extract node output shape from the node string representation
54 | # This is a hack because there doesn't seem to be an official way to do it.
55 | # See my question in the PyTorch forum:
56 | # https://discuss.pytorch.org/t/node-output-shape-from-trace-graph/24351/2
57 | # TODO: find a better way to extract output shape
58 | # TODO: Assuming the node has one output. Update if we encounter a multi-output node.
59 | m = re.match(r".*Float\(([\d\s\,]+)\).*", str(next(torch_node.outputs())))
60 | if m:
61 | shape = m.group(1)
62 | shape = shape.split(",")
63 | shape = tuple(map(int, shape))
64 | else:
65 | shape = None
66 | return shape
67 |
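# Illustrative example of the node string this regex targets (the exact format
# varies across PyTorch versions):
#
#   s = "%1 : Float(1, 64, 112, 112) = onnx::Conv(%input, %2)"
#   m = re.match(r".*Float\(([\d\s\,]+)\).*", s)
#   m.group(1)   # '1, 64, 112, 112' -> parsed to the tuple (1, 64, 112, 112)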
68 | def calc_rf(model, input_shape):
69 | for n, p in model.named_parameters():
70 | if not p.requires_grad:
71 | continue
72 | if 'bias' in n:
73 | p.data.fill_(0)
74 | elif 'weight' in n:
75 | p.data.fill_(1)
76 |
77 | input = torch.ones(input_shape, requires_grad=True)
78 | output = model(input)
79 | out_shape = output.size()
80 | ndims = len(out_shape)
81 | grad = torch.zeros(out_shape)
82 | l_tmp=[]
83 | for i in range(ndims):
84 | if i==0 or i==1:#batch or channel
85 | l_tmp.append(0)
86 | else:
87 | l_tmp.append(out_shape[i]//2) # integer index at spatial center
88 |
89 | grad[tuple(l_tmp)] = 1
90 | output.backward(gradient=grad)
91 | grad_np = input.grad[0,0].data.numpy()
92 | idx_nonzeros = np.where(grad_np!=0)
93 | RF=[np.max(idx)-np.min(idx)+1 for idx in idx_nonzeros]
94 |
95 | return RF
96 |
97 | def import_graph(hl_graph, model, args, input_names=None, verbose=False):
98 | # TODO: add input names to graph
99 |
100 | if args is None:
101 | args = [1, 3, 224, 224] # assume ImageNet default
102 |
103 | # if args is not Tensor but is array like then convert it to torch tensor
104 | if not isinstance(args, torch.Tensor) and \
105 | hasattr(args, "__len__") and hasattr(args, '__getitem__') and \
106 | not isinstance(args, (str, abc.ByteString)):
107 | args = torch.ones(args)
108 |
109 | # Run the Pytorch graph to get a trace and generate a graph from it
110 | with torch.onnx.set_training(model, False): # note: older torch API; removed in later releases
111 | try:
112 | trace = torch.jit.trace(model, args)
113 | torch.onnx._optimize_trace(trace)
114 | torch_graph = trace.graph
115 | except RuntimeError as e:
116 | print(e)
117 | print('Error occurred while creating JIT trace for the model.')
118 | raise e
119 |
120 | # Dump list of nodes (DEBUG only)
121 | if verbose:
122 | dump_pytorch_graph(torch_graph)
123 |
124 | # Loop through nodes and build HL graph
125 | nodes = list(torch_graph.nodes())
126 | inps = [(n, [i.unique() for i in n.inputs()]) for n in nodes]
127 | for i, torch_node in enumerate(nodes):
128 | # Op
129 | op = torch_node.kind()
130 | # Parameters
131 | params = {k: torch_node[k] for k in torch_node.attributeNames()}
132 | # Inputs/outputs
133 | # TODO: inputs = [i.unique() for i in node.inputs()]
134 | outputs = [o.unique() for o in torch_node.outputs()]
135 | # Get output shape
136 | shape = get_shape(torch_node)
137 | # Add HL node
138 | hl_node = Node(uid=pytorch_id(torch_node), name=None, op=op,
139 | output_shape=shape, params=params)
140 | hl_graph.add_node(hl_node)
141 | # Add edges
142 | for target_torch_node,target_inputs in inps:
143 | if set(outputs) & set(target_inputs):
144 | hl_graph.add_edge_by_id(pytorch_id(torch_node), pytorch_id(target_torch_node), shape)
145 | return hl_graph
146 |
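# Illustrative usage (assumes Graph() can be constructed with defaults, per
# graph.py; the model is an example):
#
#   import torchvision.models as models
#   hl_graph = import_graph(Graph(), models.resnet18(), args=[1, 3, 224, 224])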
--------------------------------------------------------------------------------
/tensorwatch/model_graph/hiddenlayer/pytorch_builder_grad.py:
--------------------------------------------------------------------------------
1 | '''
2 | File name: plot-pytorch-autograd-graph.py
3 | Author: Ludovic Trottier
4 | Date created: November 8, 2017.
5 | Date last modified: November 8, 2017
6 | Credits: moskomule (https://discuss.pytorch.org/t/print-autograd-graph/692/15)
7 | '''
8 | from graphviz import Digraph
9 | import torch
10 | from torch.autograd import Variable
11 |
12 | from . import transforms as ht
13 | from collections import abc
14 | import numpy as np
15 | from .graph import Graph, Node
16 |
17 | # PyTorch Graph Transforms
18 | FRAMEWORK_TRANSFORMS = [
19 | # Hide onnx: prefix
20 | ht.Rename(op=r"onnx::(.*)", to=r"\1"),
21 | # ONNX uses Gemm for linear layers (stands for General Matrix Multiplication).
22 | # It's an odd name that no one recognizes. Rename it.
23 | ht.Rename(op=r"Gemm", to=r"Linear"),
24 | # PyTorch layers that don't have an ONNX counterpart
25 | ht.Rename(op=r"aten::max\_pool2d\_with\_indices", to="MaxPool"),
26 | # Shorten op name
27 | ht.Rename(op=r"BatchNormalization", to="BatchNorm"),
28 | ]
29 |
30 | def add_node2dot(dot, var, id, label, op=None, output_shape=None, params=None):
31 | hl_node = Node(uid=id, name=op, op=label,
32 | output_shape=output_shape, params=params)
33 | dot.add_node(hl_node)
34 |
35 | def make_dot(var, params, dot):
36 | """ Produces Graphviz representation of PyTorch autograd graph.
37 |
38 | Blue nodes are trainable Variables (weights, bias).
39 | Orange nodes are saved tensors for the backward pass.
40 |
41 | Args:
42 | var: output Variable
43 | params: list of (name, Parameters)
44 | """
45 | # params may be a generator (e.g. model.named_parameters()),
46 | # so build the id -> name map in a single pass
47 | param_map = {id(v): k for k, v in params}
48 |
49 |
50 |
51 | node_attr = dict(style='filled',
52 | shape='box',
53 | align='left',
54 | fontsize='12',
55 | ranksep='0.1',
56 | height='0.2')
57 |
58 | # dot = Digraph(
59 | # filename='network',
60 | # format='pdf',
61 | # node_attr=node_attr,
62 | # graph_attr=dict(size="12,12"))
63 | seen = set()
64 |
65 | def add_nodes(dot, var):
66 | if var not in seen:
67 |
68 | node_id = str(id(var))
69 |
70 | if torch.is_tensor(var):
71 | node_label = "saved tensor\n{}".format(tuple(var.size()))
72 | add_node2dot(dot, var, node_id, node_label, op=None)
73 |
74 | elif hasattr(var, 'variable'):
75 | variable_name = param_map.get(id(var.variable))
76 | variable_size = tuple(var.variable.size())
77 | node_name = "{}\n{}".format(variable_name, variable_size)
78 | add_node2dot(dot, var, node_id, node_name, op=None)
79 |
80 | else:
81 | node_label = type(var).__name__.replace('Backward', '')
82 | add_node2dot(dot, var, node_id, node_label, op=None)
83 |
84 | seen.add(var)
85 |
86 | if hasattr(var, 'next_functions'):
87 | for u in var.next_functions:
88 | if u[0] is not None:
89 | dot.add_edge_by_id(str(id(u[0])), str(id(var)), None)
90 | add_nodes(dot, u[0])
91 |
92 | if hasattr(var, 'saved_tensors'):
93 | for t in var.saved_tensors:
94 | dot.add_edge_by_id(str(id(t)), str(id(var)), None)
95 | add_nodes(dot, t)
96 |
97 | add_nodes(dot, var.grad_fn)
98 |
99 | return dot
100 |
101 | def import_graph(hl_graph, model, args, input_names=None, verbose=False):
102 | if args is None:
103 | args = [1, 3, 224, 224] # assume ImageNet default
104 |
105 | # if args is not Tensor but is array like then convert it to torch tensor
106 | if not isinstance(args, torch.Tensor) and \
107 | hasattr(args, "__len__") and hasattr(args, '__getitem__') and \
108 | not isinstance(args, (str, abc.ByteString)):
109 | args = torch.ones(args)
110 |
111 | y = model(args)
112 | g = make_dot(y, model.named_parameters(), hl_graph)
113 | return hl_graph
--------------------------------------------------------------------------------
/tensorwatch/model_graph/hiddenlayer/pytorch_builder_trace.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard._pytorch_graph import GraphPy, NodePyIO, NodePyOP
2 | import torch
3 |
4 | from . import transforms as ht
5 | from collections import abc
6 | import numpy as np
7 |
8 | # PyTorch Graph Transforms
9 | FRAMEWORK_TRANSFORMS = [
10 | # Hide onnx: prefix
11 | ht.Rename(op=r"onnx::(.*)", to=r"\1"),
12 | # ONNX uses Gemm for linear layers (stands for General Matrix Multiplication).
13 | # It's an odd name that no one recognizes. Rename it.
14 | ht.Rename(op=r"Gemm", to=r"Linear"),
15 | # PyTorch layers that don't have an ONNX counterpart
16 | ht.Rename(op=r"aten::max\_pool2d\_with\_indices", to="MaxPool"),
17 | # Shorten op name
18 | ht.Rename(op=r"BatchNormalization", to="BatchNorm"),
19 | ]
20 |
21 | def parse(graph, args=None, omit_useless_nodes=True):
22 | """This method parses an optimized PyTorch model graph and produces
23 | a list of nodes and node stats for eventual conversion to TensorBoard
24 | protobuf format.
25 | Args:
26 | graph: the optimized PyTorch JIT graph to be parsed.
27 | args (tuple): input tensor[s] for the model.
28 | omit_useless_nodes (boolean): Whether to remove unused nodes (fanout 0) from the graph.
29 | """
30 | n_inputs = len(args)
31 |
32 | scope = {}
33 | nodes_py = GraphPy()
34 | for i, node in enumerate(graph.inputs()):
35 | if omit_useless_nodes:
36 | if len(node.uses()) == 0: # number of users of the node (i.e. its fanout)
37 | continue
38 |
39 | if i < n_inputs:
40 | nodes_py.append(NodePyIO(node, 'input'))
41 | else:
42 | nodes_py.append(NodePyIO(node)) # parameter
43 |
44 | for node in graph.nodes():
45 | nodes_py.append(NodePyOP(node))
46 |
47 | for node in graph.outputs(): # must place last.
48 | NodePyIO(node, 'output') # note: instance is created but never appended, as in upstream torch.utils.tensorboard._pytorch_graph
49 | nodes_py.find_common_root()
50 | nodes_py.populate_namespace_from_OP_to_IO()
51 | return nodes_py
52 |
53 |
54 | def graph(model, args, verbose=False):
55 | """
56 | This method traces a PyTorch model and returns the parsed node list
57 | (a GraphPy instance) rather than a TensorBoard `GraphDef` proto.
58 | Args:
59 | model (PyTorch module): The model to be parsed.
60 | args (tuple): input tensor[s] for the model.
61 | verbose (bool): Whether to print out verbose information while
62 | processing.
63 | """
64 | with torch.onnx.set_training(model, False): # TODO: move outside of torch.onnx?
65 | try:
66 | trace = torch.jit.trace(model, args)
67 | graph = trace.graph
68 | except RuntimeError as e:
69 | print(e)
70 | print('Error occurred; no graph saved')
71 | raise e
72 |
73 | if verbose:
74 | print(graph)
75 | return parse(graph, args)
76 |
77 |
78 | def import_graph(hl_graph, model, args, input_names=None, verbose=False):
79 | # TODO: add input names to graph
80 |
81 | if args is None:
82 | args = [1, 3, 224, 224] # assume ImageNet default
83 |
84 | # if args is not Tensor but is array like then convert it to torch tensor
85 | if not isinstance(args, torch.Tensor) and \
86 | hasattr(args, "__len__") and hasattr(args, '__getitem__') and \
87 | not isinstance(args, (str, abc.ByteString)):
88 | args = torch.ones(args)
89 |
90 | graph_py = graph(model, args, verbose)
91 |
92 | # # Loop through nodes and build HL graph
93 | # nodes = list(torch_graph.nodes())
94 | # inps = [(n, [i.unique() for i in n.inputs()]) for n in nodes]
95 | # for i, torch_node in enumerate(nodes):
96 | # # Op
97 | # op = torch_node.kind()
98 | # # Parameters
99 | # params = {k: torch_node[k] for k in torch_node.attributeNames()}
100 | # # Inputs/outputs
101 | # # TODO: inputs = [i.unique() for i in node.inputs()]
102 | # outputs = [o.unique() for o in torch_node.outputs()]
103 | # # Get output shape
104 | # shape = get_shape(torch_node)
105 | # # Add HL node
106 | # hl_node = Node(uid=pytorch_id(torch_node), name=None, op=op,
107 | # output_shape=shape, params=params)
108 | # hl_graph.add_node(hl_node)
109 | # # Add edges
110 | # for target_torch_node,target_inputs in inps:
111 | # if set(outputs) & set(target_inputs):
112 | # hl_graph.add_edge_by_id(pytorch_id(torch_node), pytorch_id(target_torch_node), shape)
113 | return hl_graph
114 |
--------------------------------------------------------------------------------
/tensorwatch/model_graph/torchstat/README.md:
--------------------------------------------------------------------------------
1 | # Credits
2 |
3 | Code in this folder is almost as-is from the torchstat repository at https://github.com/Swall0w/torchstat.
4 |
5 | Additional merges are from:
6 | - https://github.com/kenshohara/torchstat
7 | - https://github.com/lyakaap/torchstat
--------------------------------------------------------------------------------
/tensorwatch/model_graph/torchstat/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/tensorwatch/model_graph/torchstat/__init__.py
--------------------------------------------------------------------------------
/tensorwatch/model_graph/torchstat/compute_flops.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import numpy as np
4 | import math
5 |
6 | def compute_flops(module, inp, out):
7 | if isinstance(module, nn.Conv2d):
8 | return compute_Conv2d_flops(module, inp[0], out[0])
9 | elif isinstance(module, nn.BatchNorm2d):
10 | return compute_BatchNorm2d_flops(module, inp[0], out[0])
11 | elif isinstance(module, (nn.AvgPool2d, nn.MaxPool2d)):
12 | return compute_Pool2d_flops(module, inp[0], out[0])
13 | elif isinstance(module, (nn.AdaptiveAvgPool2d, nn.AdaptiveMaxPool2d)):
14 | return compute_adaptivepool_flops(module, inp[0], out[0])
15 | elif isinstance(module, (nn.ReLU, nn.ReLU6, nn.PReLU, nn.ELU, nn.LeakyReLU)):
16 | return compute_ReLU_flops(module, inp[0], out[0])
17 | elif isinstance(module, nn.Upsample):
18 | return compute_Upsample_flops(module, inp[0], out[0])
19 | elif isinstance(module, nn.Linear):
20 | return compute_Linear_flops(module, inp[0], out[0])
21 | else:
22 | #print(f"[Flops]: {type(module).__name__} is not supported!")
23 | return 0
24 |
25 |
26 |
27 | def compute_Conv2d_flops(module, inp, out):
28 | # Can have multiple inputs, getting the first one
29 | assert isinstance(module, nn.Conv2d)
30 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
31 |
32 | batch_size = inp.size()[0]
33 | in_c = inp.size()[1]
34 | k_h, k_w = module.kernel_size
35 | out_c, out_h, out_w = out.size()[1:]
36 | groups = module.groups
37 |
38 | filters_per_channel = out_c // groups
39 | conv_per_position_flops = k_h * k_w * in_c * filters_per_channel
40 | active_elements_count = batch_size * out_h * out_w
41 |
42 | total_conv_flops = conv_per_position_flops * active_elements_count
43 |
44 | bias_flops = 0
45 | if module.bias is not None:
46 | bias_flops = out_c * active_elements_count
47 |
48 | total_flops = total_conv_flops + bias_flops
49 | return total_flops
50 |
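# Worked example (numbers are illustrative): a 3x3 conv with in_c=3, out_c=64,
# groups=1 and bias, producing a 1x64x224x224 output:
#   conv_per_position_flops = 3*3*3*64       = 1,728
#   active_elements_count   = 1*224*224      = 50,176
#   total_conv_flops        = 1,728*50,176   = 86,704,128
#   bias_flops              = 64*50,176      = 3,211,264
#   total_flops             = 89,915,392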
51 | def compute_adaptivepool_flops(module, input, output):
52 | # credits: https://github.com/xternalz/SDPoint/blob/master/utils/flops.py
53 | batch_size = input.size(0)
54 | input_planes = input.size(1)
55 | input_height = input.size(2)
56 | input_width = input.size(3)
57 |
58 | flops = 0
59 | for i in range(output.size(2)):
60 | y_start = int(math.floor(float(i * input_height) / output.size(2)))
61 | y_end = int(math.ceil(float((i + 1) * input_height) / output.size(2)))
62 | for j in range(output.size(3)):
63 | x_start = int(math.floor(float(j * input_width) / output.size(3)))
64 | x_end = int(math.ceil(float((j + 1) * input_width) / output.size(3)))
65 |
66 | flops += batch_size * input_planes * (y_end-y_start+1) * (x_end-x_start+1)
67 | return flops
68 |
69 | def compute_BatchNorm2d_flops(module, inp, out):
70 | assert isinstance(module, nn.BatchNorm2d)
71 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
72 | in_c, in_h, in_w = inp.size()[1:]
73 | batch_flops = np.prod(inp.shape)
74 | if module.affine:
75 | batch_flops *= 2
76 | return batch_flops
77 |
78 |
79 | def compute_ReLU_flops(module, inp, out):
80 | assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.PReLU, nn.ELU, nn.LeakyReLU))
81 | batch_size = inp.size()[0]
82 | active_elements_count = batch_size
83 |
84 | for s in inp.size()[1:]:
85 | active_elements_count *= s
86 |
87 | return active_elements_count
88 |
89 |
90 | def compute_Pool2d_flops(module, input, out):
91 | batch_size = input.size(0)
92 | input_planes = input.size(1)
93 | input_height = input.size(2)
94 | input_width = input.size(3)
95 | kernel_size = (module.kernel_size, module.kernel_size) if isinstance(module.kernel_size, int) else module.kernel_size
96 | kernel_ops = kernel_size[0] * kernel_size[1]
97 | stride = (module.stride, module.stride) if isinstance(module.stride, int) else module.stride
98 | padding = (module.padding, module.padding) if isinstance(module.padding, int) else module.padding
99 |
100 | output_width = math.floor((input_width + 2 * padding[0] - kernel_size[0]) / float(stride[0]) + 1)
101 | output_height = math.floor((input_height + 2 * padding[1] - kernel_size[1]) / float(stride[1]) + 1) # stride[1], consistent with the indexing above
102 | return batch_size * input_planes * output_width * output_height * kernel_ops
103 |
104 |
105 | def compute_Linear_flops(module, inp, out):
106 | assert isinstance(module, nn.Linear)
107 | assert len(inp.size()) == 2 and len(out.size()) == 2
108 | batch_size = inp.size()[0]
109 | return batch_size * inp.size()[1] * out.size()[1]
110 |
111 | def compute_Upsample_flops(module, inp, out):
112 | assert isinstance(module, nn.Upsample)
113 | output_size = out[0]
114 | batch_size = inp.size()[0]
115 | output_elements_count = batch_size
116 | for s in output_size.shape[1:]:
117 | output_elements_count *= s
118 |
119 | return output_elements_count
120 |
--------------------------------------------------------------------------------
/tensorwatch/model_graph/torchstat/compute_memory.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import numpy as np
4 |
5 |
6 | def compute_memory(module, inp, out):
7 | if isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU)):
8 | return compute_ReLU_memory(module, inp[0], out[0])
9 | elif isinstance(module, nn.PReLU):
10 | return compute_PReLU_memory(module, inp[0], out[0])
11 | elif isinstance(module, nn.Conv2d):
12 | return compute_Conv2d_memory(module, inp[0], out[0])
13 | elif isinstance(module, nn.BatchNorm2d):
14 | return compute_BatchNorm2d_memory(module, inp[0], out[0])
15 | elif isinstance(module, nn.Linear):
16 | return compute_Linear_memory(module, inp[0], out[0])
17 | elif isinstance(module, (nn.AvgPool2d, nn.MaxPool2d)):
18 | return compute_Pool2d_memory(module, inp[0], out[0])
19 | else:
20 | #print(f"[Memory]: {type(module).__name__} is not supported!")
21 | return 0, 0
22 |
23 |
24 |
25 | def num_params(module):
26 | return sum(p.numel() for p in module.parameters() if p.requires_grad)
27 |
28 |
29 | def compute_ReLU_memory(module, inp, out):
30 | assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU))
31 |
32 | mread = inp.numel()
33 | mwrite = out.numel()
34 |
35 | return mread*inp.element_size(), mwrite*out.element_size()
36 |
37 |
38 | def compute_PReLU_memory(module, inp, out):
39 | assert isinstance(module, nn.PReLU)
40 |
41 | batch_size = inp.size()[0]
42 | mread = batch_size * (inp[0].numel() + num_params(module))
43 | mwrite = out.numel()
44 |
45 | return mread*inp.element_size(), mwrite*out.element_size()
46 |
47 |
48 | def compute_Conv2d_memory(module, inp, out):
49 | # Can have multiple inputs, getting the first one
50 | assert isinstance(module, nn.Conv2d)
51 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
52 |
53 | batch_size = inp.size()[0]
54 |
55 | # This includes weights with bias if the module contains it.
56 | mread = batch_size * (inp[0].numel() + num_params(module))
57 | mwrite = out.numel()
58 |
59 | return mread*inp.element_size(), mwrite*out.element_size()
60 |
61 |
62 | def compute_BatchNorm2d_memory(module, inp, out):
63 | assert isinstance(module, nn.BatchNorm2d)
64 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
65 |
66 | batch_size, in_c, in_h, in_w = inp.size()
67 | mread = batch_size * (inp[0].numel() + 2 * in_c)
68 | mwrite = out.numel()
69 |
70 | return mread*inp.element_size(), mwrite*out.element_size()
71 |
72 |
73 | def compute_Linear_memory(module, inp, out):
74 | assert isinstance(module, nn.Linear)
75 | assert len(inp.size()) == 2 and len(out.size()) == 2
76 |
77 | batch_size = inp.size()[0]
78 |
79 | # This includes weights with bias if the module contains it.
80 | mread = batch_size * (inp[0].numel() + num_params(module))
81 | mwrite = out.numel()
82 |
83 | return mread*inp.element_size(), mwrite*out.element_size()
84 |
85 | def compute_Pool2d_memory(module, inp, out):
86 | assert isinstance(module, (nn.MaxPool2d, nn.AvgPool2d))
87 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
88 |
89 | mread = inp.numel()
90 | mwrite = out.numel()
91 |
92 | return mread*inp.element_size(), mwrite*out.element_size()
93 |
--------------------------------------------------------------------------------
/tensorwatch/model_graph/torchstat/reporter.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 |
4 | pd.set_option('display.width', 1000)
5 | pd.set_option('display.max_rows', 10000)
6 | pd.set_option('display.max_columns', 10000)
7 |
8 |
9 | def round_value(value, binary=False):
10 | divisor = 1024. if binary else 1000.
11 |
12 | if value // divisor**4 > 0:
13 | return str(round(value / divisor**4, 2)) + 'T'
14 | elif value // divisor**3 > 0:
15 | return str(round(value / divisor**3, 2)) + 'G'
16 | elif value // divisor**2 > 0:
17 | return str(round(value / divisor**2, 2)) + 'M'
18 | elif value // divisor > 0:
19 | return str(round(value / divisor, 2)) + 'K'
20 | return str(value)
21 |
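# Worked examples (illustrative):
#   round_value(1234)                 -> '1.23K'  (decimal, divisor 1000)
#   round_value(89915392)             -> '89.92M'
#   round_value(1048576, binary=True) -> '1.0M'   (binary, divisor 1024)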
22 |
23 | def report_format(collected_nodes):
24 | data = list()
25 | for node in collected_nodes:
26 | name = node.name
27 | input_shape = ' '.join(['{:>3d}'] * len(node.input_shape)).format(
28 | *[e for e in node.input_shape])
29 | output_shape = ' '.join(['{:>3d}'] * len(node.output_shape)).format(
30 | *[e for e in node.output_shape])
31 | parameter_quantity = node.parameter_quantity
32 | inference_memory = node.inference_memory
33 | MAdd = node.MAdd
34 | Flops = node.Flops
35 | mread, mwrite = node.Memory
36 | duration = node.duration
37 | data.append([name, input_shape, output_shape, parameter_quantity,
38 | inference_memory, MAdd, duration, Flops, mread,
39 | mwrite])
40 | df = pd.DataFrame(data)
41 | df.columns = ['module name', 'input shape', 'output shape',
42 | 'params', 'memory(MB)',
43 | 'MAdd', 'duration', 'Flops', 'MemRead(B)', 'MemWrite(B)']
44 | df['duration[%]'] = df['duration'] / (df['duration'].sum() + 1e-7)
45 | df['MemR+W(B)'] = df['MemRead(B)'] + df['MemWrite(B)']
46 | total_parameters_quantity = df['params'].sum()
47 | total_memory = df['memory(MB)'].sum()
48 | total_operation_quantity = df['MAdd'].sum()
49 | total_flops = df['Flops'].sum()
50 | total_duration = df['duration[%]'].sum()
51 | total_mread = df['MemRead(B)'].sum()
52 | total_mwrite = df['MemWrite(B)'].sum()
53 | total_memrw = df['MemR+W(B)'].sum()
54 | del df['duration']
55 |
56 | # Add Total row
57 | total_df = pd.Series([total_parameters_quantity, total_memory,
58 | total_operation_quantity, total_flops,
59 | total_duration, total_mread, total_mwrite, total_memrw],
60 | index=['params', 'memory(MB)', 'MAdd', 'Flops', 'duration[%]',
61 | 'MemRead(B)', 'MemWrite(B)', 'MemR+W(B)'],
62 | name='total')
63 | df = df.append(total_df) # note: DataFrame.append was removed in pandas 2.0; use pd.concat there
64 |
65 | df = df.fillna(' ')
66 | df['memory(MB)'] = df['memory(MB)'].apply(
67 | lambda x: '{:.2f}'.format(x))
68 | df['duration[%]'] = df['duration[%]'].apply(lambda x: '{:.2%}'.format(x))
69 | df['MAdd'] = df['MAdd'].apply(lambda x: '{:,}'.format(x))
70 | df['Flops'] = df['Flops'].apply(lambda x: '{:,}'.format(x))
71 |
72 | summary = str(df) + '\n'
73 | summary += "=" * len(str(df).split('\n')[0])
74 | summary += '\n'
75 | summary += "Total params: {:,}\n".format(total_parameters_quantity)
76 |
77 | summary += "-" * len(str(df).split('\n')[0])
78 | summary += '\n'
79 | summary += "Total memory: {:.2f}MB\n".format(total_memory)
80 | summary += "Total MAdd: {}MAdd\n".format(round_value(total_operation_quantity))
81 | summary += "Total Flops: {}Flops\n".format(round_value(total_flops))
82 | summary += "Total MemR+W: {}B\n".format(round_value(total_memrw, True))
83 | return summary
84 |
--------------------------------------------------------------------------------
/tensorwatch/model_graph/torchstat_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .torchstat import analyzer
5 | import pandas as pd
6 | import copy
7 |
8 | class LayerStats:
9 | def __init__(self, node) -> None:
10 | self.name = node.name
11 | self.input_shape = node.input_shape
12 | self.output_shape = node.output_shape
13 | self.parameters = node.parameter_quantity
14 | self.inference_memory = node.inference_memory
15 | self.MAdd = node.MAdd
16 | self.Flops = node.Flops
17 | self.mread, self.mwrite = node.Memory[0], node.Memory[1]
18 | self.duration = node.duration
19 |
20 | class ModelStats(LayerStats):
21 | def __init__(self, model, input_shape, clone_model=False) -> None:
22 | if clone_model:
23 | model = copy.deepcopy(model)
24 | collected_nodes = analyzer.analyze(model, input_shape, 1)
25 | self.layer_stats = []
26 | for node in collected_nodes:
27 | self.layer_stats.append(LayerStats(node))
28 |
29 | self.name = 'Model'
30 | self.input_shape = input_shape
31 | self.output_shape = self.layer_stats[-1].output_shape
32 | self.parameters = sum((l.parameters for l in self.layer_stats))
33 | self.inference_memory = sum((l.inference_memory for l in self.layer_stats))
34 | self.MAdd = sum((l.MAdd for l in self.layer_stats))
35 | self.Flops = sum((l.Flops for l in self.layer_stats))
36 | self.mread = sum((l.mread for l in self.layer_stats))
37 | self.mwrite = sum((l.mwrite for l in self.layer_stats))
38 | self.duration = sum((l.duration for l in self.layer_stats))
39 |
40 | def model_stats(model, input_shape):
41 | ms = ModelStats(model, input_shape)
42 | return model_stats2df(ms)
43 |
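# Illustrative usage (the model is an example; input_shape is [N, C, H, W]):
#
#   import torchvision.models as models
#   df = model_stats(models.alexnet(), [1, 3, 224, 224])
#   print(df)   # one row per layer plus a 'Total' row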
44 | def _round_value(value, binary=False):
45 | divisor = 1024. if binary else 1000.
46 |
47 | if value // divisor**4 > 0:
48 | return str(round(value / divisor**4, 2)) + 'T'
49 | elif value // divisor**3 > 0:
50 | return str(round(value / divisor**3, 2)) + 'G'
51 | elif value // divisor**2 > 0:
52 | return str(round(value / divisor**2, 2)) + 'M'
53 | elif value // divisor > 0:
54 | return str(round(value / divisor, 2)) + 'K'
55 | return str(value)
56 |
57 |
58 | def model_stats2df(model_stats:ModelStats):
59 | pd.set_option('display.width', 1000)
60 | pd.set_option('display.max_rows', 10000)
61 | pd.set_option('display.max_columns', 10000)
62 |
63 | df = pd.DataFrame([l.__dict__ for l in model_stats.layer_stats])
64 | total_df = pd.Series(model_stats.__dict__, name='Total')
65 | df = df.append(total_df[df.columns], ignore_index=True) # note: DataFrame.append was removed in pandas 2.0; use pd.concat there
66 |
67 | df = df.fillna(' ')
68 | # df['memory(MB)'] = df['memory(MB)'].apply(
69 | # lambda x: '{:.2f}'.format(x))
70 | # df['duration[%]'] = df['duration[%]'].apply(lambda x: '{:.2%}'.format(x))
71 | for c in ['MAdd', 'Flops', 'parameters', 'inference_memory', 'mread', 'mwrite']:
72 | df[c] = df[c].apply(lambda x: '{:,}'.format(x))
73 |
74 | df.rename(columns={'name': 'module name',
75 | 'input_shape': 'input shape',
76 | 'output_shape': 'output shape',
77 | 'inference_memory': 'infer memory(MB)',
78 | 'mread': 'MemRead(B)',
79 | 'mwrite': 'MemWrite(B)'
80 | }, inplace=True)
81 |
82 | #summary = "Total params: {:,}\n".format(total_parameters_quantity)
83 |
84 | #summary += "-" * len(str(df).split('\n')[0])
85 | #summary += '\n'
86 | #summary += "Total memory: {:.2f}MB\n".format(total_memory)
87 | #summary += "Total MAdd: {}MAdd\n".format(_round_value(total_operation_quantity))
88 | #summary += "Total Flops: {}Flops\n".format(_round_value(total_flops))
89 | #summary += "Total MemR+W: {}B\n".format(_round_value(total_memrw, True))
90 | return df
91 |
--------------------------------------------------------------------------------
/tensorwatch/mpl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
5 |
--------------------------------------------------------------------------------
/tensorwatch/mpl/histogram.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .base_mpl_plot import BaseMplPlot
5 | from .. import image_utils
6 | from .. import utils
7 | import numpy as np
8 |
9 | class Histogram(BaseMplPlot):
10 | def init_stream_plot(self, stream_vis,
11 | xtitle='', ytitle='', ztitle='', colormap=None, color=None,
12 | bins=None, normed=None, histtype='bar', edge_color=None, linewidth=None,
13 | opacity=None, **stream_vis_args):
14 |
15 | # add main subplot
16 | stream_vis.bins, stream_vis.normed, stream_vis.linewidth = bins, normed, (linewidth or 2)
17 | stream_vis.ax = self.get_main_axis()
18 | stream_vis.series = []
19 | stream_vis.bars_artists = [] # stores previously drawn bars
20 |
21 | stream_vis.cmap = image_utils.get_cmap(name=colormap or 'Dark2')
22 | if color is None:
23 | if not self.is_3d:
24 | color = stream_vis.cmap((len(self._stream_vises)%stream_vis.cmap.N)/stream_vis.cmap.N) # pylint: disable=no-member
25 | stream_vis.color = color
26 | stream_vis.edge_color = 'black'
27 | stream_vis.histtype = histtype
28 | stream_vis.opacity = opacity
29 | stream_vis.ax.set_xlabel(xtitle)
30 | stream_vis.ax.xaxis.label.set_style('italic')
31 | stream_vis.ax.set_ylabel(ytitle)
32 | stream_vis.ax.yaxis.label.set_color(color)
33 | stream_vis.ax.yaxis.label.set_style('italic')
34 | if self.is_3d:
35 | stream_vis.ax.set_zlabel(ztitle)
36 | stream_vis.ax.zaxis.label.set_style('italic')
37 |
38 | def is_show_grid(self): #override
39 | return False
40 |
41 | def clear_artists(self, stream_vis):
42 | for bar in stream_vis.bars_artists:
43 | bar.remove()
44 | stream_vis.bars_artists.clear()
45 |
46 | def clear_plot(self, stream_vis, clear_history):
47 | stream_vis.series.clear()
48 | self.clear_artists(stream_vis)
49 |
50 | def _show_stream_items(self, stream_vis, stream_items):
51 | """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
52 | """
53 |
54 | vals = self._extract_vals(stream_items)
55 | if not len(vals):
56 | return True
57 |
58 | stream_vis.series += vals
59 | self.clear_artists(stream_vis)
60 | n, bins, stream_vis.bars_artists = stream_vis.ax.hist(stream_vis.series, bins=stream_vis.bins,
61 | normed=stream_vis.normed, color=stream_vis.color, edgecolor=stream_vis.edge_color, # 'normed' exists only in matplotlib < 3.1 (renamed to 'density')
62 | histtype=stream_vis.histtype, alpha=stream_vis.opacity,
63 | linewidth=stream_vis.linewidth)
64 |
65 | stream_vis.ax.set_xticks(bins)
66 |
67 | #stream_vis.ax.relim()
68 | #stream_vis.ax.autoscale_view()
69 |
70 | return False
71 |
--------------------------------------------------------------------------------
/tensorwatch/mpl/image_plot.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .base_mpl_plot import BaseMplPlot
5 | from .. import utils, image_utils
6 | import numpy as np
7 |
8 | #from IPython import get_ipython
9 |
10 | class ImagePlot(BaseMplPlot):
11 | def init_stream_plot(self, stream_vis,
12 | rows=2, cols=5, img_width=None, img_height=None, img_channels=None,
13 | colormap=None, viz_img_scale=None, **stream_vis_args):
14 | stream_vis.rows, stream_vis.cols = rows, cols
15 | stream_vis.img_channels, stream_vis.colormap = img_channels, colormap
16 | stream_vis.img_width, stream_vis.img_height = img_width, img_height
17 | stream_vis.viz_img_scale = viz_img_scale
18 | # subplots holding each image
19 | stream_vis.axs = [[None for _ in range(cols)] for _ in range(rows)]
20 | # axis image
21 | stream_vis.ax_imgs = [[None for _ in range(cols)] for _ in range(rows)]
22 |
23 | def clear_plot(self, stream_vis, clear_history):
24 | for row in range(stream_vis.rows):
25 | for col in range(stream_vis.cols):
26 | img = stream_vis.ax_imgs[row][col]
27 | if img:
28 | x, y = img.get_size()
29 | img.set_data(np.zeros((x, y)))
30 |
31 | def _show_stream_items(self, stream_vis, stream_items):
32 | """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
33 | """
34 |
35 | # as we repaint each image plot, select last if multiple events were pending
36 | stream_item = None
37 | for er in reversed(stream_items):
38 | if not(er.ended or er.value is None):
39 | stream_item = er
40 | break
41 | if stream_item is None:
42 | return True
43 |
44 | row, col, i = 0, 0, 0
45 | dirty = False
46 | # stream_item.value is expected to be ImagePlotItems
47 | for image_list in stream_item.value:
48 | # convert to imshow compatible, stitch images
49 | images = [image_utils.to_imshow_array(img, stream_vis.img_width, stream_vis.img_height) \
50 | for img in image_list.images if img is not None]
51 | img_viz = image_utils.stitch_horizontal(images, width_dim=1)
52 |
53 | # resize if requested
54 | if stream_vis.viz_img_scale is not None:
55 | import skimage.transform # expensive import done on demand
56 |
57 | if isinstance(img_viz, np.ndarray) and np.issubdtype(img_viz.dtype, np.floating):
58 | img_viz = img_viz.clip(-1, 1) # some MNIST images have out of range values causing exception in skimage
59 | img_viz = skimage.transform.rescale(img_viz,
60 | (stream_vis.viz_img_scale, stream_vis.viz_img_scale), mode='reflect', preserve_range=False)
61 |
62 | # create subplot if it doesn't exist
63 | ax = stream_vis.axs[row][col]
64 | if ax is None:
65 | ax = stream_vis.axs[row][col] = \
66 | self.figure.add_subplot(stream_vis.rows, stream_vis.cols, i+1)
67 | ax.set_xticks([])
68 | ax.set_yticks([])
69 |
70 | cmap = image_list.cmap or ('Greys' if stream_vis.colormap is None and \
71 | len(img_viz.shape) == 2 else stream_vis.colormap)
72 |
73 | stream_vis.ax_imgs[row][col] = ax.imshow(img_viz, interpolation="none", cmap=cmap, alpha=image_list.alpha)
74 | dirty = True
75 |
76 | # set title
77 | title = image_list.title
78 | if len(title) > 12: #wordwrap if too long
79 | title = utils.wrap_string(title) if len(title) > 24 else title
80 | fontsize = 8
81 | else:
82 | fontsize = 12
83 | ax.set_title(title, fontsize=fontsize) #'fontweight': 'light'
84 |
85 | #ax.autoscale_view() # not needed
86 | col = col + 1
87 | if col >= stream_vis.cols:
88 | col = 0
89 | row = row + 1
90 | if row >= stream_vis.rows:
91 | break
92 | i += 1
93 |
94 | return not dirty
95 |
96 |
97 | def has_legend(self):
98 | return self.show_legend or False
--------------------------------------------------------------------------------
/tensorwatch/mpl/pie_chart.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .base_mpl_plot import BaseMplPlot
5 | from .. import image_utils
6 | from .. import utils
7 | import numpy as np
8 |
9 | class PieChart(BaseMplPlot):
10 | def init_stream_plot(self, stream_vis, autopct=None, colormap=None, color=None,
11 | shadow=None, **stream_vis_args):
12 |
13 | # add main subplot
14 | stream_vis.autopct, stream_vis.shadow = autopct, True if shadow is None else shadow
15 | stream_vis.ax = self.get_main_axis()
16 | stream_vis.series = []
17 | stream_vis.wedge_artists = [] # stores previously drawn wedges
18 |
19 | stream_vis.cmap = image_utils.get_cmap(name=colormap or 'Dark2')
20 | if color is None:
21 | if not self.is_3d:
22 | color = stream_vis.cmap((len(self._stream_vises)%stream_vis.cmap.N)/stream_vis.cmap.N) # pylint: disable=no-member
23 | stream_vis.color = color
24 |
25 | def clear_artists(self, stream_vis):
26 | for artist in stream_vis.wedge_artists:
27 | artist.remove()
28 | stream_vis.wedge_artists.clear()
29 |
30 | def clear_plot(self, stream_vis, clear_history):
31 | stream_vis.series.clear()
32 | self.clear_artists(stream_vis)
33 |
34 | def _show_stream_items(self, stream_vis, stream_items):
35 | """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
36 | """
37 |
38 | vals = self._extract_vals(stream_items)
39 | if not len(vals):
40 | return True
41 |
42 | # make sure tuple has 4 elements
43 | unpacker = lambda a0=None,a1=None,a2=None,a3=None,*_:(a0,a1,a2,a3)
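# e.g. unpacker('cats', 30) -> ('cats', 30, None, None); extra items are dropped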
44 | stream_vis.series.extend([unpacker(*val) for val in vals])
45 | self.clear_artists(stream_vis)
46 |
47 | labels, sizes, colors, explode = \
48 | [t[0] for t in stream_vis.series], \
49 | [t[1] for t in stream_vis.series], \
50 | [(t[2] or stream_vis.cmap.colors[i % len(stream_vis.cmap.colors)]) \
51 | for i, t in enumerate(stream_vis.series)], \
52 | [t[3] or 0 for t in stream_vis.series],
53 |
54 |
55 | stream_vis.wedge_artists, *_ = stream_vis.ax.pie( \
56 | sizes, explode=explode, labels=labels, colors=colors,
57 | autopct=stream_vis.autopct, shadow=stream_vis.shadow)
58 |
59 | return False
60 |
61 |
--------------------------------------------------------------------------------
/tensorwatch/notebook_maker.py:
--------------------------------------------------------------------------------
1 | import nbformat
2 | from nbformat.v4 import new_code_cell, new_markdown_cell, new_notebook
3 | from nbformat import v3, v4
4 | import codecs
5 | from os import linesep, path
6 | import uuid
7 | from . import utils
8 | import re
9 | from typing import List
10 | from .lv_types import VisArgs
11 |
12 | class NotebookMaker:
13 | def __init__(self, watcher, filename:str=None)->None:
14 | self.filename = filename or \
15 | (path.splitext(watcher.filename)[0] + '.ipynb' if watcher.filename else \
16 | 'tensorwatch.ipynb')
17 |
18 | self.cells = []
19 | self._default_vis_args = VisArgs()
20 |
21 | watcher_args_str = NotebookMaker._get_vis_args(watcher)
22 |
23 | # create initial cell
24 | self.cells.append(new_code_cell(source=linesep.join(
25 | ['%matplotlib notebook',
26 | 'import tensorwatch as tw',
27 | 'client = tw.WatcherClient({})'.format(watcher_args_str)])))
28 |
29 | def _get_vis_args(watcher)->str:
30 | args_strs = []
31 | for param, default_v in [('port', 0), ('filename', None)]:
32 | if hasattr(watcher, param):
33 | v = getattr(watcher, param)
34 | if v==default_v or (v is None and default_v is None):
35 | continue
36 | args_strs.append("{}={}".format(param, NotebookMaker._val2str(v)))
37 | return ', '.join(args_strs)
38 |
39 | def _get_stream_identifier(prefix, event_name, stream_name, stream_index)->str:
40 | if not stream_name or utils.is_uuid4(stream_name):
41 | if event_name is not None and event_name != '':
42 | return '{}_{}_{}'.format(prefix, event_name, stream_index)
43 | else:
44 | return prefix + str(stream_index)
45 | else:
46 | return '{}{}_{}'.format(prefix, stream_index, utils.str2identifier(stream_name)[:8])
47 |
48 | def _val2str(v)->str:
49 | # TODO: should we raise an error for parameters that aren't str, bool, or numbers (or containers of these)?
50 | return str(v) if not isinstance(v, str) else "'{}'".format(v)
51 |
52 | def _add_vis_args_str(self, stream_info, param_strs:List[str])->None:
53 | default_args = self._default_vis_args.__dict__
54 | if not stream_info.req.vis_args:
55 | return
56 | for k, v in stream_info.req.vis_args.__dict__.items():
57 | if k in default_args:
58 | default_v = default_args[k]
59 | if (v is None and default_v is None) or (v==default_v):
60 | continue # skip param if its value is not changed from default
61 | param_strs.append("{}={}".format(k, NotebookMaker._val2str(v)))
62 |
63 | def _get_stream_code(self, event_name, stream_name, stream_index, stream_info)->List[str]:
64 | lines = []
65 |
66 | stream_identifier = 's'+str(stream_index)
67 | lines.append("{} = client.open_stream(name='{}')".format(stream_identifier, stream_name))
68 |
69 | vis_identifier = 'v'+str(stream_index)
70 | vis_args_strs = ['stream={}'.format(stream_identifier)]
71 | self._add_vis_args_str(stream_info, vis_args_strs)
72 | lines.append("{} = tw.Visualizer({})".format(vis_identifier, ', '.join(vis_args_strs)))
73 | lines.append("{}.show()".format(vis_identifier))
74 | return lines
75 |
76 | def add_streams(self, event_stream_infos)->None:
77 | stream_index = 0
78 | for event_name, stream_infos in event_stream_infos.items(): # per event
79 | for stream_name, stream_info in stream_infos.items():
80 | lines = self._get_stream_code(event_name, stream_name, stream_index, stream_info)
81 | self.cells.append(new_code_cell(source=linesep.join(lines)))
82 | stream_index += 1
83 |
84 | def write(self):
85 | nb = new_notebook(cells=self.cells, metadata={'language': 'python',})
86 | with codecs.open(self.filename, encoding='utf-8', mode='w') as f:
87 | utils.debug_log('Notebook created', path.realpath(f.name), verbosity=0)
88 | nbformat.write(nb, f, 4)
89 |
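# Illustrative usage (assumes a Watcher logging to a file, as in the
# simple_logging docs):
#
#   import tensorwatch as tw
#   watcher = tw.Watcher(filename='test.log')
#   nm = NotebookMaker(watcher)   # derives 'test.ipynb' from the watcher filename
#   nm.write()                    # call add_streams() first to emit visualizer cells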
90 |
91 |
92 |
93 |
--------------------------------------------------------------------------------
/tensorwatch/plotly/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
--------------------------------------------------------------------------------
/tensorwatch/plotly/base_plotly_plot.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from ..vis_base import VisBase
5 | import time
6 | from abc import abstractmethod
7 | from .. import utils
8 |
9 |
10 | class BasePlotlyPlot(VisBase):
11 | def __init__(self, cell:VisBase.widgets.Box=None, title=None, show_legend:bool=None, is_3d:bool=False,
12 | stream_name:str=None, console_debug:bool=False, **vis_args):
13 | import plotly.graph_objs as go # function-level import as this takes a long time
14 | super(BasePlotlyPlot, self).__init__(go.FigureWidget(), cell, title, show_legend,
15 | stream_name=stream_name, console_debug=console_debug, **vis_args)
16 |
17 | self.is_3d = is_3d
18 | self.widget.layout.title = title
19 | self.widget.layout.showlegend = show_legend if show_legend is not None else True
20 |
21 | def _add_trace(self, stream_vis):
22 | stream_vis.trace_index = len(self.widget.data)
23 | trace = self._create_trace(stream_vis)
24 | if stream_vis.opacity is not None:
25 | trace.opacity = stream_vis.opacity
26 | self.widget.add_trace(trace)
27 |
28 | def _add_trace_with_history(self, stream_vis):
29 | # if history buffer isn't full
30 | if stream_vis.history_len > len(stream_vis.trace_history):
31 | self._add_trace(stream_vis)
32 | stream_vis.trace_history.append(len(self.widget.data)-1)
33 | stream_vis.cur_history_index = len(stream_vis.trace_history)-1
34 | #if stream_vis.cur_history_index:
35 | # self.widget.data[trace_index].showlegend = False
36 | else:
37 | # rotate trace
38 | stream_vis.cur_history_index = (stream_vis.cur_history_index + 1) % stream_vis.history_len
39 | stream_vis.trace_index = stream_vis.trace_history[stream_vis.cur_history_index]
40 | self.clear_plot(stream_vis, False)
41 | self.widget.data[stream_vis.trace_index].opacity = stream_vis.opacity or 1
42 |
43 | cur_history_len = len(stream_vis.trace_history)
44 | if stream_vis.dim_history and cur_history_len > 1:
45 | max_opacity = stream_vis.opacity or 1
46 | min_alpha, max_alpha, dimmed_len = max_opacity*0.05, max_opacity*0.8, cur_history_len-1
47 | alphas = list(utils.frange(max_alpha, min_alpha, steps=dimmed_len))
48 | for i, thi in enumerate(range(stream_vis.cur_history_index+1,
49 | stream_vis.cur_history_index+cur_history_len)):
50 | trace_index = stream_vis.trace_history[thi % cur_history_len]
51 | self.widget.data[trace_index].opacity = alphas[i]
52 |
53 | @staticmethod
54 | def get_pallet_color(i:int):
55 |         import plotly # function-level import as this takes a long time
56 | return plotly.colors.DEFAULT_PLOTLY_COLORS[i % len(plotly.colors.DEFAULT_PLOTLY_COLORS)]
57 |
58 | @staticmethod
59 | def _get_axis_common_props(title:str, axis_range:tuple):
60 | props = {'showline':True, 'showgrid': True,
61 | 'showticklabels': True, 'ticks':'inside'}
62 | if title:
63 | props['title'] = title
64 | if axis_range:
65 | props['range'] = list(axis_range)
66 | return props
67 |
68 | def _can_update_stream_plots(self):
69 | return time.time() - self.q_last_processed > 0.5 # make configurable
70 |
71 | def _post_add_subscription(self, stream_vis, **stream_vis_args):
72 | stream_vis.trace_history, stream_vis.cur_history_index = [], None
73 | self._add_trace_with_history(stream_vis)
74 | self._setup_layout(stream_vis)
75 |
76 | if not self.widget.layout.title:
77 | self.widget.layout.title = stream_vis.title
78 | # TODO: better way for below?
79 | if stream_vis.history_len > 1:
80 | self.widget.layout.showlegend = False
81 |
82 | def _show_widget_native(self, blocking:bool):
83 | pass
84 | #TODO: save image, spawn browser?
85 |
86 | def _show_widget_notebook(self):
87 | #plotly.offline.iplot(self.widget)
88 | return None
89 |
90 | def _post_update_stream_plot(self, stream_vis):
91 |         # not needed for plotly as FigureWidget stays up to date
92 | pass
93 |
94 | @abstractmethod
95 | def clear_plot(self, stream_vis, clear_history):
96 | """(for derived class) Clears the data in specified plot before new data is redrawn"""
97 | pass
98 | @abstractmethod
99 | def _show_stream_items(self, stream_vis, stream_items):
100 |         """Paint the given stream_items into the visualizer. Return False if the visualizer is dirty, else True.
101 | """
102 |
103 | pass
104 | @abstractmethod
105 | def _setup_layout(self, stream_vis):
106 | pass
107 | @abstractmethod
108 | def _create_trace(self, stream_vis):
109 | pass
110 |
--------------------------------------------------------------------------------
/tensorwatch/plotly/embeddings_plot.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .. import image_utils
5 | import numpy as np
6 | from .line_plot import LinePlot
7 | import time
8 | from .. import utils
9 |
10 | class EmbeddingsPlot(LinePlot):
11 | def __init__(self, cell:LinePlot.widgets.Box=None, title=None, show_legend:bool=False, stream_name:str=None, console_debug:bool=False,
12 | is_3d:bool=True, hover_images=None, hover_image_reshape=None, **vis_args):
13 | utils.set_default(vis_args, 'height', '8in')
14 | super(EmbeddingsPlot, self).__init__(cell, title, show_legend,
15 | stream_name=stream_name, console_debug=console_debug, is_3d=is_3d, **vis_args)
16 |
17 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue
18 | if hover_images is not None:
19 | plt.ioff()
20 | self.image_output = LinePlot.widgets.Output()
21 | self.image_figure = plt.figure(figsize=(2,2))
22 | self.image_ax = self.image_figure.add_subplot(111)
23 | self.cell.children += (self.image_output,)
24 | plt.ion()
25 | self.hover_images, self.hover_image_reshape = hover_images, hover_image_reshape
26 | self.last_ind, self.last_ind_time = -1, 0
27 |
28 | def hover_fn(self, trace, points, state): # pylint: disable=unused-argument
29 | if not points:
30 | return
31 | ind = points.point_inds[0]
32 |         if ind == self.last_ind or ind >= len(self.hover_images) or ind < 0:
33 | return
34 |
35 | if self.last_ind == -1:
36 | self.last_ind, self.last_ind_time = ind, time.time()
37 | else:
38 | elapsed = time.time() - self.last_ind_time
39 | if elapsed < 0.3:
40 | self.last_ind, self.last_ind_time = ind, time.time()
41 | if elapsed < 1:
42 | return
43 | # else too much time since update
44 | # else we have stable ind
45 |
46 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue
47 | with self.image_output:
48 | plt.ioff()
49 |
50 | if self.hover_image_reshape:
51 | img = np.reshape(self.hover_images[ind], self.hover_image_reshape)
52 | else:
53 | img = self.hover_images[ind]
54 | if img is not None:
55 | LinePlot.display.clear_output(wait=True)
56 | self.image_ax.imshow(img)
57 | LinePlot.display.display(self.image_figure)
58 | plt.ion()
59 |
60 | return None
61 |
62 | def _create_trace(self, stream_vis):
63 | stream_vis.stream_vis_args.clear() #TODO remove this
64 | utils.set_default(stream_vis.stream_vis_args, 'draw_line', False)
65 | utils.set_default(stream_vis.stream_vis_args, 'draw_marker', True)
66 | utils.set_default(stream_vis.stream_vis_args, 'draw_marker_text', True)
67 | utils.set_default(stream_vis.stream_vis_args, 'hoverinfo', 'text')
68 | utils.set_default(stream_vis.stream_vis_args, 'marker', {})
69 |
70 | marker = stream_vis.stream_vis_args['marker']
71 | utils.set_default(marker, 'size', 6)
72 | utils.set_default(marker, 'colorscale', 'Jet')
73 | utils.set_default(marker, 'showscale', False)
74 | utils.set_default(marker, 'opacity', 0.8)
75 |
76 | return super(EmbeddingsPlot, self)._create_trace(stream_vis)
77 |
78 | def subscribe(self, stream, **stream_vis_args):
79 | super(EmbeddingsPlot, self).subscribe(stream)
80 | stream_vis = self._stream_vises[stream.stream_name]
81 | if stream_vis.index == 0 and self.hover_images is not None:
82 | self.widget.data[stream_vis.trace_index].on_hover(self.hover_fn)
--------------------------------------------------------------------------------
/tensorwatch/pytorch_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from torchvision import models, transforms
5 | import torch
6 | import torch.nn.functional as F
7 | from . import utils, image_utils
8 | import os
9 |
10 | def get_model(model_name):
11 | model = models.__dict__[model_name](pretrained=True)
12 | return model
13 |
14 | def tensors2batch(tensors, preprocess_transform=None):
15 | if preprocess_transform:
16 | tensors = tuple(preprocess_transform(i) for i in tensors)
17 | if not utils.is_array_like(tensors):
18 | tensors = tuple(tensors)
19 | return torch.stack(tensors, dim=0)
20 |
21 | def int2tensor(val):
22 | return torch.LongTensor([val])
23 |
24 | def image2batch(image, image_transform=None):
25 | if image_transform:
26 | input_x = image_transform(image)
27 | else: # if no transforms supplied then just convert PIL image to tensor
28 | input_x = transforms.ToTensor()(image)
29 | input_x = input_x.unsqueeze(0) #convert to batch of 1
30 | return input_x
31 |
32 | def image_class2tensor(image_path, class_index=None, image_convert_mode=None,
33 | image_transform=None):
34 | image_pil = image_utils.open_image(os.path.abspath(image_path), convert_mode=image_convert_mode)
35 | input_x = image2batch(image_pil, image_transform)
36 | target_class = int2tensor(class_index) if class_index is not None else None
37 | return image_pil, input_x, target_class
38 |
39 | def batch_predict(model, inputs, input_transform=None, device=None):
40 | if input_transform:
41 | batch = torch.stack(tuple(input_transform(i) for i in inputs), dim=0)
42 | else:
43 | batch = torch.stack(inputs, dim=0)
44 |
45 | device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
46 | model.eval()
47 | model.to(device)
48 | batch = batch.to(device)
49 |
50 | outputs = model(batch)
51 |
52 | return outputs
53 |
54 | def logits2probabilities(logits):
55 | return F.softmax(logits, dim=1)
56 |
57 | def tensor2numpy(t):
58 | return t.data.cpu().numpy()
59 |
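60 |
61 | # A minimal usage sketch (not part of the original module; the image path and
62 | # ImageNet class index 282 are illustrative only, assuming torchvision):
63 | #
64 | #   model = get_model('resnet50')
65 | #   image_pil, input_x, target_class = image_class2tensor('data/test_images/cat.jpg', 282)
66 | #   probs = logits2probabilities(batch_predict(model, [input_x[0]]))
67 | #   top_class = tensor2numpy(probs).argmax()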
--------------------------------------------------------------------------------
/tensorwatch/receptive_field/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
--------------------------------------------------------------------------------
/tensorwatch/receptive_field/rf_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from receptivefield.pytorch import PytorchReceptiveField
5 | #from receptivefield.image import get_default_image
6 | import numpy as np
7 |
8 | def _get_rf(model, sample_pil_img):
9 | # define model functions
10 | def model_fn():
11 | model.eval()
12 | return model
13 |
14 | input_shape = np.array(sample_pil_img).shape
15 |
16 | rf = PytorchReceptiveField(model_fn)
17 | rf_params = rf.compute(input_shape=input_shape)
18 | return rf, rf_params
19 |
20 | def plot_receptive_field(model, sample_pil_img, layout=(2, 2), figsize=(6, 6)):
21 | rf, rf_params = _get_rf(model, sample_pil_img) # pylint: disable=unused-variable
22 | return rf.plot_rf_grids(
23 | custom_image=sample_pil_img,
24 | figsize=figsize,
25 | layout=layout)
26 |
27 | def plot_grads_at(model, sample_pil_img, feature_map_index=0, point=(8,8), figsize=(6, 6)):
28 | rf, rf_params = _get_rf(model, sample_pil_img) # pylint: disable=unused-variable
29 | return rf.plot_gradient_at(fm_id=feature_map_index, point=point, image=None, figsize=figsize)
30 |
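31 | # Hedged usage sketch: the external `receptivefield` package requires the model
32 | # to expose its feature maps in the format PytorchReceptiveField expects, so a
33 | # plain torchvision model may not work as-is; `rf_ready_model` is illustrative.
34 | #
35 | #   from PIL import Image
36 | #   img = Image.open('data/test_images/cat.jpg')
37 | #   fig = plot_receptive_field(rf_ready_model, img, layout=(2, 2))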
--------------------------------------------------------------------------------
/tensorwatch/repeated_timer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import threading
5 | import time
6 | import weakref
7 |
8 | class RepeatedTimer:
9 | class State:
10 | Stopped=0
11 | Paused=1
12 | Running=2
13 |
14 | def __init__(self, secs, callback, count=None):
15 | self.secs = secs
16 | self.callback = weakref.WeakMethod(callback) if callback else None
17 | self._thread = None
18 | self._state = RepeatedTimer.State.Stopped
19 | self.pause_wait = threading.Event()
20 | self.pause_wait.set()
21 | self._continue_thread = False
22 | self.count = count
23 |
24 | def start(self):
25 | self._continue_thread = True
26 | self.pause_wait.set()
27 |         if self._thread is None or not self._thread.is_alive():
28 | self._thread = threading.Thread(target=self._runner, name='RepeatedTimer', daemon=True)
29 | self._thread.start()
30 | self._state = RepeatedTimer.State.Running
31 |
32 | def stop(self, block=False):
33 | self.pause_wait.set()
34 | self._continue_thread = False
35 |         if block and self._thread is not None and self._thread.is_alive():
36 | self._thread.join()
37 | self._state = RepeatedTimer.State.Stopped
38 |
39 | def get_state(self):
40 | return self._state
41 |
42 |
43 | def pause(self):
44 | if self._state == RepeatedTimer.State.Running:
45 | self.pause_wait.clear()
46 | self._state = RepeatedTimer.State.Paused
47 | # else nothing to do
48 |     def unpause(self):
49 |         if self._state == RepeatedTimer.State.Paused:
50 |             # release the wait and transition back to running
51 |             self.pause_wait.set()
52 |             self._state = RepeatedTimer.State.Running
53 |         # else nothing to do
54 |
55 | def _runner(self):
56 | while (self._continue_thread):
57 | if self.count:
58 |                 self.count -= 1
59 | if not self.count:
60 | self._continue_thread = False
61 |
62 | if self._continue_thread:
63 | self.pause_wait.wait()
64 | if self.callback and self.callback():
65 | self.callback()()
66 |
67 | if self._continue_thread:
68 | time.sleep(self.secs)
69 |
70 | self._thread = None
71 | self._state = RepeatedTimer.State.Stopped
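72 |
73 | # Minimal usage sketch (not part of the original module): the callback is held
74 | # via weakref.WeakMethod, so it must be a bound method whose owner stays alive;
75 | # a plain function would raise TypeError.
76 | #
77 | #   class Ticker:
78 | #       def tick(self):
79 | #           print('tick')
80 | #
81 | #   ticker = Ticker()
82 | #   timer = RepeatedTimer(1.0, ticker.tick, count=3)  # fires ~3 times, 1s apart
83 | #   timer.start()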
--------------------------------------------------------------------------------
/tensorwatch/saliency/README.md:
--------------------------------------------------------------------------------
1 | # Credits
2 | Code in this folder is adapted from
3 |
4 | * https://github.com/yulongwang12/visual-attribution
5 | * https://github.com/marcotcr/lime
6 |
--------------------------------------------------------------------------------
/tensorwatch/saliency/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
--------------------------------------------------------------------------------
/tensorwatch/saliency/backprop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.autograd import Variable, Function
3 | import torch
4 | import types
5 |
6 |
7 | class VanillaGradExplainer(object):
8 | def __init__(self, model):
9 | self.model = model
10 |
11 | def _backprop(self, inp, ind):
12 | inp.requires_grad = True
13 | if inp.grad is not None:
14 | inp.grad.zero_()
15 |         if ind is not None and ind.grad is not None:
16 | ind.grad.zero_()
17 | self.model.eval()
18 | self.model.zero_grad()
19 |
20 | output = self.model(inp)
21 | if ind is None:
22 | ind = output.max(1)[1]
23 | grad_out = output.clone()
24 | grad_out.fill_(0.0)
25 | grad_out.scatter_(1, ind.unsqueeze(0).t(), 1.0)
26 | output.backward(grad_out)
27 | return inp.grad
28 |
29 | def explain(self, inp, ind=None, raw_inp=None):
30 | return self._backprop(inp, ind)
31 |
32 |
33 | class GradxInputExplainer(VanillaGradExplainer):
34 | def __init__(self, model):
35 | super(GradxInputExplainer, self).__init__(model)
36 |
37 | def explain(self, inp, ind=None, raw_inp=None):
38 | grad = self._backprop(inp, ind)
39 | return inp * grad
40 |
41 |
42 | class SaliencyExplainer(VanillaGradExplainer):
43 | def __init__(self, model):
44 | super(SaliencyExplainer, self).__init__(model)
45 |
46 | def explain(self, inp, ind=None, raw_inp=None):
47 | grad = self._backprop(inp, ind)
48 | return grad.abs()
49 |
50 |
51 | class IntegrateGradExplainer(VanillaGradExplainer):
52 | def __init__(self, model, steps=100):
53 | super(IntegrateGradExplainer, self).__init__(model)
54 | self.steps = steps
55 |
56 | def explain(self, inp, ind=None, raw_inp=None):
57 | grad = 0
58 | inp_data = inp.clone()
59 |
60 | for alpha in np.arange(1 / self.steps, 1.0, 1 / self.steps):
61 | new_inp = Variable(inp_data * alpha, requires_grad=True)
62 | g = self._backprop(new_inp, ind)
63 | grad += g
64 |
65 | return grad * inp_data / self.steps
66 |
67 |
68 | class DeconvExplainer(VanillaGradExplainer):
69 | def __init__(self, model):
70 | super(DeconvExplainer, self).__init__(model)
71 | self._override_backward()
72 |
73 | def _override_backward(self):
74 | class _ReLU(Function):
75 | @staticmethod
76 | def forward(ctx, input):
77 | output = torch.clamp(input, min=0)
78 | return output
79 |
80 | @staticmethod
81 | def backward(ctx, grad_output):
82 | grad_inp = torch.clamp(grad_output, min=0)
83 | return grad_inp
84 |
85 | def new_forward(self, x):
86 | return _ReLU.apply(x)
87 |
88 | def replace(m):
89 | if m.__class__.__name__ == 'ReLU':
90 | m.forward = types.MethodType(new_forward, m)
91 |
92 | self.model.apply(replace)
93 |
94 |
95 | class GuidedBackpropExplainer(VanillaGradExplainer):
96 | def __init__(self, model):
97 | super(GuidedBackpropExplainer, self).__init__(model)
98 | self._override_backward()
99 |
100 | def _override_backward(self):
101 | class _ReLU(Function):
102 | @staticmethod
103 | def forward(ctx, input):
104 | output = torch.clamp(input, min=0)
105 | ctx.save_for_backward(output)
106 | return output
107 |
108 | @staticmethod
109 | def backward(ctx, grad_output):
110 | output, = ctx.saved_tensors
111 | mask1 = (output > 0).float()
112 | mask2 = (grad_output > 0).float()
113 | grad_inp = mask1 * mask2 * grad_output
114 | grad_output.copy_(grad_inp)
115 | return grad_output
116 |
117 | def new_forward(self, x):
118 | return _ReLU.apply(x)
119 |
120 | def replace(m):
121 | if m.__class__.__name__ == 'ReLU':
122 | m.forward = types.MethodType(new_forward, m)
123 |
124 | self.model.apply(replace)
125 |
126 |
127 | # modified from https://github.com/PAIR-code/saliency/blob/master/saliency/base.py#L80
128 | class SmoothGradExplainer(object):
129 | def __init__(self, model, base_explainer=None, stdev_spread=0.15,
130 | nsamples=25, magnitude=True):
131 | self.base_explainer = base_explainer or VanillaGradExplainer(model)
132 | self.stdev_spread = stdev_spread
133 | self.nsamples = nsamples
134 | self.magnitude = magnitude
135 |
136 | def explain(self, inp, ind=None, raw_inp=None):
137 | stdev = self.stdev_spread * (inp.max() - inp.min())
138 |
139 | total_gradients = 0
140 |
141 | for i in range(self.nsamples):
142 | noise = torch.randn_like(inp) * stdev
143 |
144 | noisy_inp = inp + noise
145 | noisy_inp.retain_grad()
146 | grad = self.base_explainer.explain(noisy_inp, ind)
147 |
148 | if self.magnitude:
149 | total_gradients += grad ** 2
150 | else:
151 | total_gradients += grad
152 |
153 | return total_gradients / self.nsamples
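154 |
155 | # Hedged usage sketch (assumes torchvision; the random input is a stand-in for
156 | # a preprocessed image batch, not a meaningful sample):
157 | #
158 | #   from torchvision import models
159 | #   model = models.resnet18(pretrained=True)
160 | #   inp = torch.randn(1, 3, 224, 224)
161 | #   heatmap = SaliencyExplainer(model).explain(inp)  # |grad|, same shape as inp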
--------------------------------------------------------------------------------
/tensorwatch/saliency/deeplift.py:
--------------------------------------------------------------------------------
1 | from .backprop import GradxInputExplainer
2 | import types
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 | # Based on formulation in DeepExplain, https://arxiv.org/abs/1711.06104
7 | # https://github.com/marcoancona/DeepExplain/blob/master/deepexplain/tensorflow/methods.py#L221-L272
8 | class DeepLIFTRescaleExplainer(GradxInputExplainer):
9 | def __init__(self, model):
10 | super(DeepLIFTRescaleExplainer, self).__init__(model)
11 | self._prepare_reference()
12 | self.baseline_inp = None
13 | self._override_backward()
14 |
15 | def _prepare_reference(self):
16 | def init_refs(m):
17 | name = m.__class__.__name__
18 | if name.find('ReLU') != -1:
19 | m.ref_inp_list = []
20 | m.ref_out_list = []
21 |
22 | def ref_forward(self, x):
23 | self.ref_inp_list.append(x.data.clone())
24 | out = F.relu(x)
25 | self.ref_out_list.append(out.data.clone())
26 | return out
27 |
28 | def ref_replace(m):
29 | name = m.__class__.__name__
30 | if name.find('ReLU') != -1:
31 | m.forward = types.MethodType(ref_forward, m)
32 |
33 | self.model.apply(init_refs)
34 | self.model.apply(ref_replace)
35 |
36 |     def _reset_references(self):
37 | def reset_refs(m):
38 | name = m.__class__.__name__
39 | if name.find('ReLU') != -1:
40 | m.ref_inp_list = []
41 | m.ref_out_list = []
42 |
43 | self.model.apply(reset_refs)
44 |
45 | def _baseline_forward(self, inp):
46 | if self.baseline_inp is None:
47 | self.baseline_inp = inp.data.clone()
48 | self.baseline_inp.fill_(0.0)
49 | self.baseline_inp = Variable(self.baseline_inp)
50 | else:
51 | self.baseline_inp.fill_(0.0)
52 | # get ref
53 | _ = self.model(self.baseline_inp)
54 |
55 | def _override_backward(self):
56 | def new_backward(self, grad_out):
57 | ref_inp, inp = self.ref_inp_list
58 | ref_out, out = self.ref_out_list
59 | delta_out = out - ref_out
60 | delta_in = inp - ref_inp
61 | g1 = (delta_in.abs() > 1e-5).float() * grad_out * \
62 | delta_out / delta_in
63 | mask = ((ref_inp + inp) > 0).float()
64 | g2 = (delta_in.abs() <= 1e-5).float() * 0.5 * mask * grad_out
65 |
66 | return g1 + g2
67 |
68 | def backward_replace(m):
69 | name = m.__class__.__name__
70 | if name.find('ReLU') != -1:
71 | m.backward = types.MethodType(new_backward, m)
72 |
73 | self.model.apply(backward_replace)
74 |
75 | def explain(self, inp, ind=None, raw_inp=None):
76 |         self._reset_references()
77 | self._baseline_forward(inp)
78 | g = super(DeepLIFTRescaleExplainer, self).explain(inp, ind)
79 |
80 | return g
81 |
--------------------------------------------------------------------------------
/tensorwatch/saliency/gradcam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .backprop import VanillaGradExplainer
3 |
4 | def _get_layer(model, key_list):
5 | if key_list is None:
6 | return None
7 |
8 | a = model
9 | for key in key_list:
10 | a = a._modules[key]
11 |
12 | return a
13 |
14 | class GradCAMExplainer(VanillaGradExplainer):
15 | def __init__(self, model, target_layer_name_keys=None, use_inp=False):
16 | super(GradCAMExplainer, self).__init__(model)
17 | self.target_layer = _get_layer(model, target_layer_name_keys)
18 | self.use_inp = use_inp
19 | self.intermediate_act = []
20 | self.intermediate_grad = []
21 | self.handle_forward_hook = None
22 | self.handle_backward_hook = None
23 | self._register_forward_backward_hook()
24 |
25 | def _register_forward_backward_hook(self):
26 | def forward_hook_input(m, i, o):
27 | self.intermediate_act.append(i[0].data.clone())
28 |
29 | def forward_hook_output(m, i, o):
30 | self.intermediate_act.append(o.data.clone())
31 |
32 | def backward_hook(m, grad_i, grad_o):
33 | self.intermediate_grad.append(grad_o[0].data.clone())
34 |
35 | if self.target_layer is not None:
36 | if self.use_inp:
37 | self.handle_forward_hook = self.target_layer.register_forward_hook(forward_hook_input)
38 | else:
39 | self.handle_forward_hook = self.target_layer.register_forward_hook(forward_hook_output)
40 |
41 | self.handle_backward_hook = self.target_layer.register_backward_hook(backward_hook)
42 |
43 | def _reset_intermediate_lists(self):
44 | self.intermediate_act = []
45 | self.intermediate_grad = []
46 |
47 | def explain(self, inp, ind=None, raw_inp=None):
48 | self._reset_intermediate_lists()
49 | _ = super(GradCAMExplainer, self)._backprop(inp, ind)
50 | self.handle_forward_hook.remove()
51 | self.handle_backward_hook.remove()
52 | if len(self.intermediate_grad):
53 | grad = self.intermediate_grad[0]
54 | act = self.intermediate_act[0]
55 |
56 | weights = grad.sum(-1).sum(-1).unsqueeze(-1).unsqueeze(-1)
57 | cam = weights * act
58 | cam = cam.sum(1).unsqueeze(1)
59 |
60 | cam = torch.clamp(cam, min=0)
61 | return cam
62 | else:
63 | return None
64 |
65 |
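66 | # Hedged usage sketch: for a torchvision ResNet this repo elsewhere uses the
67 | # layer path ['avgpool']; with use_inp=True the hook captures that layer's
68 | # input activations. explain returns None if no gradient was captured.
69 | #
70 | #   explainer = GradCAMExplainer(model, target_layer_name_keys=['avgpool'], use_inp=True)
71 | #   cam = explainer.explain(inp)  # coarse class activation map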
--------------------------------------------------------------------------------
/tensorwatch/saliency/lime/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/tensorwatch/saliency/lime/__init__.py
--------------------------------------------------------------------------------
/tensorwatch/saliency/lime/wrappers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/tensorwatch/05121c60bf1d336634d1cc2f50a55046d4eec10b/tensorwatch/saliency/lime/wrappers/__init__.py
--------------------------------------------------------------------------------
/tensorwatch/saliency/lime/wrappers/generic_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import inspect
3 | import types
4 |
5 |
6 | def has_arg(fn, arg_name):
7 | """Checks if a callable accepts a given keyword argument.
8 |
9 | Args:
10 | fn: callable to inspect
11 | arg_name: string, keyword argument name to check
12 |
13 | Returns:
14 | bool, whether `fn` accepts a `arg_name` keyword argument.
15 | """
16 | if sys.version_info < (3,):
17 | if isinstance(fn, types.FunctionType) or isinstance(fn, types.MethodType):
18 | arg_spec = inspect.getargspec(fn)
19 | else:
20 | try:
21 | arg_spec = inspect.getargspec(fn.__call__)
22 | except AttributeError:
23 | return False
24 | return (arg_name in arg_spec.args)
25 | elif sys.version_info < (3, 6):
26 | arg_spec = inspect.getfullargspec(fn)
27 | return (arg_name in arg_spec.args or
28 | arg_name in arg_spec.kwonlyargs)
29 | else:
30 | try:
31 | signature = inspect.signature(fn)
32 | except ValueError:
33 | # handling Cython
34 | signature = inspect.signature(fn.__call__)
35 | parameter = signature.parameters.get(arg_name)
36 | if parameter is None:
37 | return False
38 | return (parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
39 | inspect.Parameter.KEYWORD_ONLY))
40 |
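41 | # Sketch of the contract (illustrative):
42 | #
43 | #   def f(a, b=1): ...
44 | #   has_arg(f, 'b')  # -> True
45 | #   has_arg(f, 'c')  # -> False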
--------------------------------------------------------------------------------
/tensorwatch/saliency/lime/wrappers/scikit_image.py:
--------------------------------------------------------------------------------
1 | import types
2 | from .generic_utils import has_arg
3 | from skimage.segmentation import felzenszwalb, slic, quickshift
4 |
5 |
6 | class BaseWrapper(object):
7 | """Base class for LIME Scikit-Image wrapper
8 |
9 |
10 | Args:
11 | target_fn: callable function or class instance
12 | target_params: dict, parameters to pass to the target_fn
13 |
14 |
15 |     'target_params' takes the parameters required to instantiate the
16 | desired Scikit-Image class/model
17 | """
18 |
19 | def __init__(self, target_fn=None, **target_params):
20 | self.target_fn = target_fn
21 | self.target_params = target_params
22 |
26 | def _check_params(self, parameters):
27 | """Checks for mistakes in 'parameters'
28 |
29 | Args :
30 | parameters: dict, parameters to be checked
31 |
32 | Raises :
33 | ValueError: if any parameter is not a valid argument for the target function
34 | or the target function is not defined
35 | TypeError: if argument parameters is not iterable
36 | """
37 | a_valid_fn = []
38 | if self.target_fn is None:
39 | if callable(self):
40 | a_valid_fn.append(self.__call__)
41 | else:
42 | raise TypeError('invalid argument: tested object is not callable,\
43 | please provide a valid target_fn')
44 | elif isinstance(self.target_fn, types.FunctionType) \
45 | or isinstance(self.target_fn, types.MethodType):
46 | a_valid_fn.append(self.target_fn)
47 | else:
48 | a_valid_fn.append(self.target_fn.__call__)
49 |
50 | if not isinstance(parameters, str):
51 | for p in parameters:
52 | for fn in a_valid_fn:
53 | if has_arg(fn, p):
54 | pass
55 | else:
56 | raise ValueError('{} is not a valid parameter'.format(p))
57 | else:
58 |             raise TypeError('invalid argument: list or dictionary expected')
59 |
60 | def set_params(self, **params):
61 | """Sets the parameters of this estimator.
62 | Args:
63 | **params: Dictionary of parameter names mapped to their values.
64 |
65 | Raises :
66 | ValueError: if any parameter is not a valid argument
67 | for the target function
68 | """
69 | self._check_params(params)
70 | self.target_params = params
71 |
72 | def filter_params(self, fn, override=None):
73 | """Filters `target_params` and return those in `fn`'s arguments.
74 | Args:
75 | fn : arbitrary function
76 | override: dict, values to override target_params
77 | Returns:
78 | result : dict, dictionary containing variables
79 | in both target_params and fn's arguments.
80 | """
81 | override = override or {}
82 | result = {}
83 | for name, value in self.target_params.items():
84 | if has_arg(fn, name):
85 | result.update({name: value})
86 | result.update(override)
87 | return result
88 |
89 |
90 | class SegmentationAlgorithm(BaseWrapper):
91 | """ Define the image segmentation function based on Scikit-Image
92 | implementation and a set of provided parameters
93 |
94 | Args:
95 | algo_type: string, segmentation algorithm among the following:
96 | 'quickshift', 'slic', 'felzenszwalb'
97 |         target_params: dict, algorithm parameters (valid model parameters
98 |             as defined in the Scikit-Image documentation)
99 | """
100 |
101 | def __init__(self, algo_type, **target_params):
102 | self.algo_type = algo_type
103 | if (self.algo_type == 'quickshift'):
104 | BaseWrapper.__init__(self, quickshift, **target_params)
105 | kwargs = self.filter_params(quickshift)
106 | self.set_params(**kwargs)
107 | elif (self.algo_type == 'felzenszwalb'):
108 | BaseWrapper.__init__(self, felzenszwalb, **target_params)
109 | kwargs = self.filter_params(felzenszwalb)
110 | self.set_params(**kwargs)
111 | elif (self.algo_type == 'slic'):
112 | BaseWrapper.__init__(self, slic, **target_params)
113 | kwargs = self.filter_params(slic)
114 | self.set_params(**kwargs)
115 |
116 | def __call__(self, *args):
117 | return self.target_fn(args[0], **self.target_params)
118 |
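119 | # Hedged usage sketch (parameter values are illustrative; see the scikit-image
120 | # documentation for the valid arguments of each algorithm):
121 | #
122 | #   import numpy as np
123 | #   segmenter = SegmentationAlgorithm('slic', n_segments=100, compactness=10)
124 | #   segments = segmenter(np.random.rand(64, 64, 3))  # integer label image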
--------------------------------------------------------------------------------
/tensorwatch/saliency/lime_image_explainer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from skimage.segmentation import mark_boundaries
5 | from torchvision import transforms
6 | import torch
7 | from .lime import lime_image
8 | import numpy as np
9 | from .. import imagenet_utils, pytorch_utils, utils
10 |
11 | class LimeImageExplainer:
12 | def __init__(self, model, predict_fn):
13 | self.model = model
14 | self.predict_fn = predict_fn
15 |
16 | def preprocess_input(self, inp):
17 | return inp
18 | def preprocess_label(self, label):
19 | return label
20 |
21 | def explain(self, inp, ind=None, raw_inp=None, top_labels=5, hide_color=0, num_samples=1000,
22 | positive_only=True, num_features=5, hide_rest=True, pixel_val_max=255.0):
23 | explainer = lime_image.LimeImageExplainer()
24 |         explanation = explainer.explain_instance(self.preprocess_input(raw_inp), self.predict_fn,
25 |                                                  top_labels=top_labels, hide_color=hide_color, num_samples=num_samples)
26 |
27 |         temp, mask = explanation.get_image_and_mask(self.preprocess_label(ind) or explanation.top_labels[0],
28 |                                                     positive_only=positive_only, num_features=num_features, hide_rest=hide_rest)
29 |
30 | img = mark_boundaries(temp/pixel_val_max, mask)
31 | img = torch.from_numpy(img)
32 | img = torch.transpose(img, 0, 2)
33 | img = torch.transpose(img, 1, 2)
34 | return img.unsqueeze(0)
35 |
36 | class LimeImagenetExplainer(LimeImageExplainer):
37 | def __init__(self, model, predict_fn=None):
38 | super(LimeImagenetExplainer, self).__init__(model, predict_fn or self._imagenet_predict)
39 |
40 | def _preprocess_transform(self):
41 | transf = transforms.Compose([
42 | transforms.ToTensor(),
43 | imagenet_utils.get_normalize_transform()
44 | ])
45 |
46 | return transf
47 |
48 | def preprocess_input(self, inp):
49 | return np.array(imagenet_utils.get_resize_transform()(inp))
50 | def preprocess_label(self, label):
51 | return label.item() if label is not None and utils.has_method(label, 'item') else label
52 |
53 | def _imagenet_predict(self, images):
54 | probs = imagenet_utils.predict(self.model, images, image_transform=self._preprocess_transform())
55 | return pytorch_utils.tensor2numpy(probs)
56 |
--------------------------------------------------------------------------------
/tensorwatch/saliency/occlusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.autograd import Variable
4 | from skimage.util import view_as_windows
5 |
6 | # modified from https://github.com/marcoancona/DeepExplain/blob/master/deepexplain/tensorflow/methods.py#L291-L342
7 | # note the different dim order in pytorch (NCHW) and tensorflow (NHWC)
8 |
9 | class OcclusionExplainer:
10 | def __init__(self, model, window_shape=10, step=1):
11 | self.model = model
12 | self.window_shape = window_shape
13 | self.step = step
14 |
15 | def explain(self, inp, ind=None, raw_inp=None):
16 | self.model.eval()
17 | with torch.no_grad():
18 | return OcclusionExplainer._occlusion(inp, self.model, self.window_shape, self.step)
19 |
20 | @staticmethod
21 | def _occlusion(inp, model, window_shape, step=None):
22 | if type(window_shape) == int:
23 | window_shape = (window_shape, window_shape, 3)
24 |
25 | if step is None:
26 | step = 1
27 | n, c, h, w = inp.data.size()
28 | total_dim = c * h * w
29 | index_matrix = np.arange(total_dim).reshape(h, w, c)
30 | idx_patches = view_as_windows(index_matrix, window_shape, step).reshape(
31 | (-1,) + window_shape)
32 | heatmap = np.zeros((n, h, w, c), dtype=np.float32).reshape((-1), total_dim)
33 | weights = np.zeros_like(heatmap)
34 |
35 | inp_data = inp.data.clone()
36 | new_inp = Variable(inp_data)
37 | eval0 = model(new_inp)
38 | pred_id = eval0.max(1)[1].data[0]
39 |
40 | for i, p in enumerate(idx_patches):
41 | mask = np.ones((h, w, c)).flatten()
42 | mask[p.flatten()] = 0
43 |             th_mask = torch.from_numpy(mask.reshape(1, h, w, c).transpose(0, 3, 1, 2)).float().to(inp_data.device)
44 | masked_xs = Variable(th_mask * inp_data)
45 | delta = (eval0[0, pred_id] - model(masked_xs)[0, pred_id]).data.cpu().numpy()
46 | delta_aggregated = np.sum(delta.reshape(n, -1), -1, keepdims=True)
47 | heatmap[:, p.flatten()] += delta_aggregated
48 | weights[:, p.flatten()] += p.size
49 |
50 | attribution = np.reshape(heatmap / (weights + 1e-10), (n, h, w, c)).transpose(0, 3, 1, 2)
51 | return torch.from_numpy(attribution)
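52 |
53 | # Hedged usage sketch: with an int window_shape w, the occluded patch is
54 | # (w, w, 3), so the input is assumed to be a 3-channel NCHW batch of size 1:
55 | #
56 | #   explainer = OcclusionExplainer(model, window_shape=10, step=5)
57 | #   attribution = explainer.explain(inp)  # same NCHW shape as inp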
--------------------------------------------------------------------------------
/tensorwatch/saliency/saliency.py:
--------------------------------------------------------------------------------
1 | from .gradcam import GradCAMExplainer
2 | from .backprop import VanillaGradExplainer, GradxInputExplainer, SaliencyExplainer, \
3 | IntegrateGradExplainer, DeconvExplainer, GuidedBackpropExplainer, SmoothGradExplainer
4 | from .deeplift import DeepLIFTRescaleExplainer
5 | from .occlusion import OcclusionExplainer
6 | from .epsilon_lrp import EpsilonLrp
7 | from .lime_image_explainer import LimeImageExplainer, LimeImagenetExplainer
8 | import skimage.transform
9 | import torch
10 | import math
11 | from .. import image_utils
12 |
13 | class ImageSaliencyResult:
14 | def __init__(self, raw_image, saliency, title, saliency_alpha=0.4, saliency_cmap='jet'):
15 | self.raw_image, self.saliency, self.title = raw_image, saliency, title
16 | self.saliency_alpha, self.saliency_cmap = saliency_alpha, saliency_cmap
17 |
18 | def _get_explainer(explainer_name, model, layer_path=None):
19 | if explainer_name == 'gradcam':
20 | return GradCAMExplainer(model, target_layer_name_keys=layer_path, use_inp=True)
21 | if explainer_name == 'vanilla_grad':
22 | return VanillaGradExplainer(model)
23 | if explainer_name == 'grad_x_input':
24 | return GradxInputExplainer(model)
25 | if explainer_name == 'saliency':
26 | return SaliencyExplainer(model)
27 | if explainer_name == 'integrate_grad':
28 | return IntegrateGradExplainer(model)
29 | if explainer_name == 'deconv':
30 | return DeconvExplainer(model)
31 | if explainer_name == 'guided_backprop':
32 | return GuidedBackpropExplainer(model)
33 | if explainer_name == 'smooth_grad':
34 | return SmoothGradExplainer(model)
35 | if explainer_name == 'deeplift':
36 | return DeepLIFTRescaleExplainer(model)
37 | if explainer_name == 'occlusion':
38 | return OcclusionExplainer(model)
39 | if explainer_name == 'lrp':
40 | return EpsilonLrp(model)
41 | if explainer_name == 'lime_imagenet':
42 | return LimeImagenetExplainer(model)
43 |
44 | raise ValueError('Explainer {} is not recognized'.format(explainer_name))
45 |
46 | def _get_layer_path(model):
47 | if model.__class__.__name__ == 'VGG':
48 | return ['features', '30'] # pool5
49 | elif model.__class__.__name__ == 'GoogleNet':
50 | return ['pool5']
51 | elif model.__class__.__name__ == 'ResNet':
52 | return ['avgpool'] #layer4
53 | elif model.__class__.__name__ == 'Inception3':
54 | return ['Mixed_7c', 'branch_pool'] # ['conv2d_94'], 'mixed10'
55 | else: #TODO: guess layer for other networks?
56 | return None
57 |
58 | def get_saliency(model, raw_input, input, label, device=None, method='integrate_grad', layer_path=None):
59 |     if device is None or not isinstance(device, torch.device):
60 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61 |
62 | model.to(device)
63 | input = input.to(device)
64 | if label is not None:
65 | label = label.to(device)
66 |
67 | if input.grad is not None:
68 | input.grad.zero_()
69 | if label is not None and label.grad is not None:
70 | label.grad.zero_()
71 | model.eval()
72 | model.zero_grad()
73 |
74 | layer_path = layer_path or _get_layer_path(model)
75 | exp = _get_explainer(method, model, layer_path)
76 | saliency = exp.explain(input, label, raw_input)
77 |
78 | if saliency is not None:
79 | saliency = saliency.abs().sum(dim=1)[0].squeeze()
80 | saliency -= saliency.min()
81 | saliency /= (saliency.max() + 1e-20)
82 |
83 | return saliency.detach().cpu().numpy()
84 | else:
85 | return None
86 |
87 | def get_image_saliency_results(model, raw_image, input, label,
88 | device=None,
89 | methods=['lime_imagenet', 'gradcam', 'smooth_grad',
90 | 'guided_backprop', 'deeplift', 'grad_x_input'],
91 | layer_path=None):
92 | results = []
93 | for method in methods:
94 | sal = get_saliency(model, raw_image, input, label, device=device, method=method)
95 |
96 | if sal is not None:
97 | results.append(ImageSaliencyResult(raw_image, sal, method))
98 | return results
99 |
100 | def get_image_saliency_plot(image_saliency_results, cols:int=2, figsize:tuple=None):
101 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue
102 |
103 | rows = math.ceil(len(image_saliency_results) / cols)
104 | figsize=figsize or (8, 3 * rows)
105 | figure = plt.figure(figsize=figsize)
106 |
107 | for i, r in enumerate(image_saliency_results):
108 | ax = figure.add_subplot(rows, cols, i+1)
109 | ax.set_xticks([])
110 | ax.set_yticks([])
111 | ax.set_title(r.title, fontdict={'fontsize': 24}) #'fontweight': 'light'
112 |
113 | #upsampler = nn.Upsample(size=(raw_image.height, raw_image.width), mode='bilinear')
114 | saliency_upsampled = skimage.transform.resize(r.saliency,
115 | (r.raw_image.height, r.raw_image.width),
116 | mode='reflect')
117 |
118 | image_utils.show_image(r.raw_image, img2=saliency_upsampled,
119 | alpha2=r.saliency_alpha, cmap2=r.saliency_cmap, ax=ax)
120 | return figure
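121 |
122 | # Hedged end-to-end sketch (the image path and class index are illustrative;
123 | # the helpers come from pytorch_utils in this package):
124 | #
125 | #   from .. import pytorch_utils
126 | #   model = pytorch_utils.get_model('resnet50')
127 | #   raw_img, input_x, label = pytorch_utils.image_class2tensor('data/test_images/dogs.png', 240)
128 | #   results = get_image_saliency_results(model, raw_img, input_x, label, methods=['saliency'])
129 | #   figure = get_image_saliency_plot(results)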
--------------------------------------------------------------------------------
/tensorwatch/stream.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import weakref, uuid
5 | from typing import Any
6 | from . import utils
7 | from .lv_types import StreamItem
8 |
9 | class Stream:
10 | """
11 |     Stream allows you to write values into it. One stream can subscribe to many other streams. When a value
12 |     is written to a stream, that value is also written to all of its subscribers (with the from_stream
13 |     parameter set to the source stream).
14 |
15 |     You can read values from a stream by calling the read_all method. This yields values as subscribed
16 |     streams write into this stream. You can also load all values into the stream: the base Stream loads
17 |     nothing, but derived streams such as FileStream may load all values from a file and write them into themselves.
18 |
19 |     You can think of a stream as a pipe that can be chained to other pipes. As a value is put in the pipe,
20 |     it travels through the connected pipes. The read_all method is like a tap for reading values from the
21 |     pipe, while the load method is for specialized streams that can generate values into the pipe.
22 |
23 |     The Stream class is fully thread-safe.
24 | """
25 | def __init__(self, stream_name:str=None, console_debug:bool=False):
26 | self._subscribers = weakref.WeakSet()
27 | self._subscribed_to = weakref.WeakSet()
28 |         self.held_refs = set() # on rare occasions we might want a stream to hold references to other streams
29 | self.closed = False
30 | self.console_debug = console_debug
31 | self.stream_name = stream_name or str(uuid.uuid4()) # useful to use as key and avoid circular references
32 | self.items_written = 0
33 |
34 | def subscribe(self, stream:'Stream'): # notify other stream
35 | utils.debug_log('{} added {} as subscription'.format(self.stream_name, stream.stream_name))
36 | stream._subscribers.add(self)
37 | self._subscribed_to.add(stream)
38 |
39 | def unsubscribe(self, stream:'Stream'):
40 | utils.debug_log('{} removed {} as subscription'.format(self.stream_name, stream.stream_name))
41 | stream._subscribers.discard(self)
42 | self._subscribed_to.discard(stream)
43 | self.held_refs.discard(stream)
44 | #stream.held_refs.discard(self) # not needed as only subscriber should hold ref
45 |
46 | def to_stream_item(self, val:Any):
47 | stream_item = val if isinstance(val, StreamItem) else \
48 | StreamItem(value=val, stream_name=self.stream_name)
49 | if stream_item.stream_name is None:
50 | stream_item.stream_name = self.stream_name
51 | if stream_item.item_index is None:
52 | stream_item.item_index = self.items_written
53 | return stream_item
54 |
55 | def write(self, val:Any, from_stream:'Stream'=None):
56 |         # if you override the write method, you must first call self.to_stream_item
57 |         # so it can stamp the stream name and item index
58 | stream_item = self.to_stream_item(val)
59 |
60 | if self.console_debug:
61 | print(self.stream_name, stream_item)
62 |
63 | for subscriber in self._subscribers:
64 | subscriber.write(stream_item, from_stream=self)
65 | self.items_written += 1
66 |
67 | def read_all(self, from_stream:'Stream'=None):
68 | for subscribed_to in self._subscribed_to:
69 | for stream_item in subscribed_to.read_all(from_stream=self):
70 | yield stream_item
71 |
72 | def load(self, from_stream:'Stream'=None):
73 | for subscribed_to in self._subscribed_to:
74 | subscribed_to.load(from_stream=self)
75 |
76 | def save(self, from_stream:'Stream'=None):
77 | for subscriber in self._subscribers:
78 | subscriber.save(from_stream=self)
79 |
80 | def close(self):
81 | if not self.closed:
82 | for subscribed_to in self._subscribed_to:
83 | subscribed_to._subscribers.discard(self)
84 | self._subscribed_to.clear()
85 | self.closed = True
86 |
87 | def __enter__(self):
88 | return self
89 |
90 | def __exit__(self, exception_type, exception_value, traceback):
91 | self.close()
92 |
93 |
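94 | # Minimal sketch of the pipe model described in the docstring (illustrative):
95 | #
96 | #   source = Stream(stream_name='source')
97 | #   sink = Stream(stream_name='sink', console_debug=True)
98 | #   sink.subscribe(source)   # sink now receives whatever source writes
99 | #   source.write(42)         # sink prints the StreamItem carrying value 42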
--------------------------------------------------------------------------------
/tensorwatch/stream_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from typing import Dict, Sequence, List, Any
5 | from .zmq_stream import ZmqStream
6 | from .file_stream import FileStream
7 | from .stream import Stream
8 | from .stream_union import StreamUnion
9 | import uuid
10 | from . import utils
11 |
12 | class StreamFactory:
13 |     r"""Creates shared streams such as file and ZMQ streams
14 | """
15 |
16 | def __init__(self)->None:
17 | self.closed = None
18 | self._streams:Dict[str, Stream] = None
19 | self._reset()
20 |
21 | def _reset(self):
22 | self._streams:Dict[str, Stream] = {}
23 | self.closed = False
24 |
25 | def close(self):
26 | if not self.closed:
27 | for stream in self._streams.values():
28 | stream.close()
29 | self._reset()
30 | self.closed = True
31 |
32 | def __enter__(self):
33 | return self
34 | def __exit__(self, exception_type, exception_value, traceback):
35 | self.close()
36 |
37 | def get_streams(self, stream_types:Sequence[str], for_write:bool=None)->List[Stream]:
38 | streams = [self._create_stream_by_string(stream_type, for_write) for stream_type in stream_types]
39 | return streams
40 |
41 | def get_combined_stream(self, stream_types:Sequence[str], for_write:bool=None)->Stream:
42 | streams = [self._create_stream_by_string(stream_type, for_write) for stream_type in stream_types]
43 | if len(streams) == 1:
44 |             return streams[0]
45 |         else:
46 |             # for multiple streams we create a union of the children, though this is not strictly necessary
47 | return StreamUnion(streams, for_write=for_write)
48 |
49 | def _get_stream_name(stream_type:str, stream_args:Any, for_write:bool)->str:
50 | return '{}:{}:{}'.format(stream_type, stream_args, for_write)
51 |
52 | def _create_stream_by_string(self, stream_spec:str, for_write:bool)->Stream:
53 | parts = stream_spec.split(':', 1) if stream_spec is not None else ['']
54 | stream_type = parts[0]
55 | stream_args = parts[1] if len(parts) > 1 else None
56 |
57 | utils.debug_log("Creating stream", (stream_spec, for_write))
58 |
59 | if stream_type == 'tcp':
60 | port = int(stream_args or 0)
61 | stream_name = StreamFactory._get_stream_name(stream_type, port, for_write)
62 | if stream_name not in self._streams:
63 | self._streams[stream_name] = ZmqStream(for_write=for_write,
64 | port=port, stream_name=stream_name, block_until_connected=False)
65 | # else we already have this stream
66 | return self._streams[stream_name]
67 |
68 |
69 | if stream_args is None: # file name specified without 'file:' prefix
70 | stream_args = stream_type
71 | stream_type = 'file'
72 | if len(stream_type) == 1: # windows drive letter
73 | stream_type = 'file'
74 | stream_args = stream_spec
75 |
76 | if stream_type == 'file':
77 | if stream_args is None:
78 | raise ValueError('File name must be specified for stream type "file"')
79 | stream_name = StreamFactory._get_stream_name(stream_type, stream_args, for_write)
80 |             # each read-only file stream should be a separate stream, otherwise sharing
81 |             # would change seek positions
82 | if not for_write:
83 | stream_name += ':' + str(uuid.uuid4())
84 |
85 |             # if a write file stream exists then flush it before the read stream reads it
86 | write_stream_name = StreamFactory._get_stream_name(stream_type, stream_args, True)
87 | write_file_stream = self._streams.get(write_stream_name, None)
88 | if write_file_stream:
89 | write_file_stream.save()
90 | if stream_name not in self._streams:
91 | self._streams[stream_name] = FileStream(for_write=for_write,
92 | file_name=stream_args, stream_name=stream_name)
93 | # else we already have this stream
94 | return self._streams[stream_name]
95 |
96 | if stream_type == '':
97 | return Stream()
98 |
99 | raise ValueError('stream_type "{}" has unknown type'.format(stream_type))
100 |
101 |
102 |
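103 | # Hedged usage sketch: 'tcp:<port>' creates a shared ZMQ stream and 'file:<path>'
104 | # (or a bare path) creates a file stream; the file name below is illustrative.
105 | #
106 | #   factory = StreamFactory()
107 | #   stream = factory.get_combined_stream(['file:run1.twl'], for_write=True)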
--------------------------------------------------------------------------------
/tensorwatch/stream_union.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .stream import Stream
5 | from typing import Iterator
6 |
7 | class StreamUnion(Stream):
8 | def __init__(self, child_streams:Iterator[Stream], for_write:bool, stream_name:str=None, console_debug:bool=False) -> None:
9 | super(StreamUnion, self).__init__(stream_name=stream_name, console_debug=console_debug)
10 |
11 |         # save references; child streams go away only when the parent goes away
12 | self.child_streams = child_streams
13 |
14 | # when someone does write to us, we write to all our listeners
15 | if for_write:
16 | for child_stream in child_streams:
17 | child_stream.subscribe(self)
18 | else:
19 | # union of all child streams
20 | for child_stream in child_streams:
21 | self.subscribe(child_stream)
--------------------------------------------------------------------------------
/tensorwatch/tensor_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Sized, Any, List
2 | import numbers
3 | import numpy as np
4 |
5 | class TensorType:
6 | Torch = 'torch'
7 | TF = 'tf'
8 | Numeric = 'numeric'
9 | Numpy = 'numpy'
10 | Other = 'other'
11 |
12 | def tensor_type(item:Any)->str:
13 | module_name = type(item).__module__
14 | class_name = type(item).__name__
15 |
16 | if module_name=='torch' and class_name=='Tensor':
17 | return TensorType.Torch
18 | elif module_name.startswith('tensorflow') and class_name=='EagerTensor':
19 | return TensorType.TF
20 | elif isinstance(item, numbers.Number):
21 | return TensorType.Numeric
22 | elif isinstance(item, np.ndarray):
23 | return TensorType.Numpy
24 | else:
25 | return TensorType.Other
26 |
27 | def tensor2scaler(item:Any)->numbers.Number:
28 | tt = tensor_type(item)
29 | if item is None or tt == TensorType.Numeric:
30 | return item
31 | if tt == TensorType.TF:
32 | item = item.numpy()
33 | return item.item() # works for torch and numpy
34 |
35 | def tensor2np(item:Any)->np.ndarray:
36 | tt = tensor_type(item)
37 | if item is None or tt == TensorType.Numpy:
38 | return item
39 | elif tt == TensorType.TF:
40 | return item.numpy()
41 | elif tt == TensorType.Torch:
42 | return item.data.cpu().numpy()
43 | else: # numeric and everything else, let np take care
44 | return np.array(item)
45 |
46 | def to_scaler_list(l:Sized)->List[numbers.Number]:
47 |     """Create a list of scalars from the given list of tensors, where each element is a 0-dim tensor
48 | """
49 | if l is not None and len(l):
50 | tt = tensor_type(l[0])
51 | if tt == TensorType.Torch or tt == TensorType.Numpy:
52 | return [i.item() for i in l]
53 | elif tt == TensorType.TF:
54 | return [i.numpy().item() for i in l]
55 | elif tt == TensorType.Numeric:
56 | # convert to list in case l is not list type
57 | return [i for i in l]
58 | else:
59 |             raise ValueError('Cannot convert tensor list to scalar list \
60 | because list elements are of unsupported type ' + tt)
61 | else:
62 | return None if l is None else [] # make sure we always return list type
63 |
64 | def to_mean_list(l:Sized)->List[float]:
65 |     """Create a list of means for the given list of tensors
66 | """
67 | if l is not None and len(l):
68 | tt = tensor_type(l[0])
69 | if tt == TensorType.Torch or tt == TensorType.Numpy:
70 | return [i.mean() for i in l]
71 | elif tt == TensorType.TF:
72 | return [i.numpy().mean() for i in l]
73 | elif tt == TensorType.Numeric:
74 | # convert to list in case l is not list type
75 | return [float(i) for i in l]
76 | else:
77 |             raise ValueError('Cannot convert tensor list to mean list \
78 | because list elements are of unsupported type ' + tt)
79 | else:
80 | return None if l is None else []
81 |
82 | def to_np_list(l:Sized)->List[np.ndarray]:
83 | if l is not None and len(l):
84 | tt = tensor_type(l[0])
85 | if tt == TensorType.Numeric:
86 | return [np.array(i) for i in l]
87 | if tt == TensorType.TF:
88 | return [i.numpy() for i in l]
89 | if tt == TensorType.Torch:
90 | return [i.data.cpu().numpy() for i in l]
91 | if tt == TensorType.Numpy:
92 | return [i for i in l]
93 |         raise ValueError('Cannot convert tensor list to numpy list \
94 | because list elements are of unsupported type ' + tt)
95 | else:
96 | return None if l is None else []
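97 |
98 | # Sketch of the conversions above (illustrative):
99 | #
100 | #   import torch
101 | #   to_scaler_list([torch.tensor(1.0), torch.tensor(2.0)])  # -> [1.0, 2.0]
102 | #   tensor2np(torch.ones(2, 2))                             # -> 2x2 numpy array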
--------------------------------------------------------------------------------
/tensorwatch/text_vis.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from . import utils
5 | from .vis_base import VisBase
6 |
7 | class TextVis(VisBase):
8 | def __init__(self, cell:VisBase.widgets.Box=None, title:str=None, show_legend:bool=None,
9 | stream_name:str=None, console_debug:bool=False, **vis_args):
10 | import pandas as pd # expensive import
11 |
12 | super(TextVis, self).__init__(VisBase.widgets.HTML(), cell, title, show_legend,
13 | stream_name=stream_name, console_debug=console_debug, **vis_args)
14 | self.df = pd.DataFrame([])
15 | self.SeriesClass = pd.Series
16 |
17 | def _get_column_prefix(self, stream_vis, i):
18 | return '[S.{}]:{}'.format(stream_vis.index, i)
19 |
20 | def _get_title(self, stream_vis):
21 | title = stream_vis.title or 'Stream ' + str(len(self._stream_vises))
22 | return title
23 |
24 | # this will be called from _show_stream_items
25 | def _append(self, stream_vis, vals):
26 | if vals is None:
27 | self.df = self.df.append(self.SeriesClass({self._get_column_prefix(stream_vis, 0) : None}),
28 | sort=False, ignore_index=True)
29 | return
30 | for val in vals:
31 | if val is None or utils.is_scalar(val):
32 | self.df = self.df.append(self.SeriesClass({self._get_column_prefix(stream_vis, 0) : val}),
33 | sort=False, ignore_index=True)
34 | elif utils.is_array_like(val):
35 | val_dict = {}
36 | for i,val_i in enumerate(val):
37 | val_dict[self._get_column_prefix(stream_vis, i)] = val_i
38 | self.df = self.df.append(self.SeriesClass(val_dict), sort=False, ignore_index=True)
39 | else:
40 | self.df = self.df.append(self.SeriesClass(val.__dict__), sort=False, ignore_index=True)
41 |
42 | def _post_add_subscription(self, stream_vis, **stream_vis_args):
43 | only_summary = stream_vis_args.get('only_summary', False)
44 | stream_vis.text = self._get_title(stream_vis)
45 | stream_vis.only_summary = only_summary
46 |
47 | def clear_plot(self, stream_vis, clear_history):
48 | self.df = self.df.iloc[0:0]
49 |
50 | def _show_stream_items(self, stream_vis, stream_items):
51 |         """Paint the given stream_items into the visualizer. Return False if the visualizer is dirty, else True.
52 | """
53 | for stream_item in stream_items:
54 | if stream_item.ended:
55 | self.df = self.df.append(self.SeriesClass({'Ended':True}),
56 | sort=False, ignore_index=True)
57 | else:
58 | vals = self._extract_vals((stream_item,))
59 | self._append(stream_vis, vals)
60 | return False # dirty
61 |
62 | def _post_update_stream_plot(self, stream_vis):
63 | if VisBase.get_ipython():
64 | if not stream_vis.only_summary:
65 | self.widget.value = self.df.to_html(classes=['output_html', 'rendered_html'])
66 | else:
67 | self.widget.value = self.df.describe().to_html(classes=['output_html', 'rendered_html'])
68 | # below doesn't work because of threading issue
69 | #self.widget.clear_output(wait=True)
70 | #with self.widget:
71 | # display.display(self.df)
72 | else:
73 | last_recs = self.df.iloc[[-1]].to_dict('records')
74 | if len(last_recs) == 1:
75 | print(last_recs[0])
76 | else:
77 | print(last_recs)
78 |
79 | def _show_widget_native(self, blocking:bool):
80 | return None # we will be using console
81 |
82 | def _show_widget_notebook(self):
83 | return self.widget
--------------------------------------------------------------------------------
/tensorwatch/visualizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .stream import Stream
5 | from .vis_base import VisBase
6 | from . import mpl
7 | from . import plotly
8 |
9 | class Visualizer:
10 | """Constructs visualizer for specified vis_type.
11 |
12 |     NOTE: If you modify arguments here then also sync the VisArgs constructor.
13 | """
14 | def __init__(self, stream:Stream, vis_type:str=None, host:'Visualizer'=None,
15 | cell:'Visualizer'=None, title:str=None,
16 | clear_after_end=False, clear_after_each=False, history_len=1, dim_history=True, opacity=None,
17 |
18 | rows=2, cols=5, img_width=None, img_height=None, img_channels=None,
19 | colormap=None, viz_img_scale=None,
20 |
21 | # these image params are for hover on point for t-sne
22 | hover_images=None, hover_image_reshape=None, cell_width:str=None, cell_height:str=None,
23 |
24 | only_summary=False, separate_yaxis=True, xtitle=None, ytitle=None, ztitle=None, color=None,
25 | xrange=None, yrange=None, zrange=None, draw_line=True, draw_marker=False,
26 |
27 | # histogram
28 | bins=None, normed=None, histtype='bar', edge_color=None, linewidth=None, bar_width=None,
29 |
30 | # pie chart
31 | autopct=None, shadow=None,
32 |
33 | vis_args={}, stream_vis_args={})->None:
34 |
35 | cell = cell._host_base.cell if cell is not None else None
36 |
37 | if host:
38 | self._host_base = host._host_base
39 | else:
40 | self._host_base = self._get_vis_base(vis_type, cell, title, hover_images=hover_images, hover_image_reshape=hover_image_reshape,
41 | cell_width=cell_width, cell_height=cell_height,
42 | **vis_args)
43 |
44 | self._host_base.subscribe(stream, show=False, clear_after_end=clear_after_end, clear_after_each=clear_after_each,
45 | history_len=history_len, dim_history=dim_history, opacity=opacity,
46 | only_summary=only_summary if vis_type is None or 'summary' != vis_type else True,
47 | separate_yaxis=separate_yaxis, xtitle=xtitle, ytitle=ytitle, ztitle=ztitle, color=color,
48 | xrange=xrange, yrange=yrange, zrange=zrange,
49 | draw_line=draw_line if vis_type is not None and 'scatter' in vis_type else True,
50 | draw_marker=draw_marker,
51 | rows=rows, cols=cols, img_width=img_width, img_height=img_height, img_channels=img_channels,
52 | colormap=colormap, viz_img_scale=viz_img_scale,
53 | bins=bins, normed=normed, histtype=histtype, edge_color=edge_color, linewidth=linewidth, bar_width = bar_width,
54 | autopct=autopct, shadow=shadow,
55 | **stream_vis_args)
56 |
57 | stream.load()
58 |
59 | def show(self):
60 | return self._host_base.show()
61 |
62 | def save(self, filepath:str)->None:
63 | self._host_base.save(filepath)
64 |
65 | def _get_vis_base(self, vis_type, cell:VisBase.widgets.Box, title, hover_images=None, hover_image_reshape=None, cell_width=None, cell_height=None, **vis_args)->VisBase:
66 | if vis_type is None or vis_type in ['line',
67 | 'mpl-line', 'mpl-line3d', 'mpl-scatter3d', 'mpl-scatter']:
68 | return mpl.line_plot.LinePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height,
69 | is_3d=vis_type is not None and vis_type.endswith('3d'), **vis_args)
70 | if vis_type in ['image', 'mpl-image']:
71 | return mpl.image_plot.ImagePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
72 | if vis_type in ['bar', 'bar3d']:
73 | return mpl.bar_plot.BarPlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height,
74 | is_3d=vis_type.endswith('3d'), **vis_args)
75 | if vis_type in ['histogram']:
76 | return mpl.histogram.Histogram(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
77 | if vis_type in ['pie']:
78 | return mpl.pie_chart.PieChart(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
79 | if vis_type in ['text', 'summary']:
80 | from .text_vis import TextVis
81 | return TextVis(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
82 | if vis_type in ['line3d', 'scatter', 'scatter3d',
83 | 'plotly-line', 'plotly-line3d', 'plotly-scatter', 'plotly-scatter3d', 'mesh3d']:
84 | return plotly.line_plot.LinePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height,
85 | is_3d=vis_type.endswith('3d'), **vis_args)
86 | if vis_type in ['tsne', 'embeddings', 'tsne2d', 'embeddings2d']:
87 | return plotly.embeddings_plot.EmbeddingsPlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height,
88 | is_3d='2d' not in vis_type,
89 | hover_images=hover_images, hover_image_reshape=hover_image_reshape, **vis_args)
90 | else:
91 | raise ValueError('vis_type parameter has invalid value: "{}"'.format(vis_type))
92 |
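A minimal usage sketch of the vis_type dispatch above (stream names and data are illustrative; the host= pattern mirrors test/simple_log/cli_sum_log.py). Passing host= skips _get_vis_base and reuses the first visualizer's vis base, so both streams render on one plot:

    import tensorwatch as tw

    w = tw.Watcher()
    s1 = w.create_stream('isum')
    s2 = w.create_stream('rsum')

    # 'line' (or vis_type=None) resolves to mpl.line_plot.LinePlot in _get_vis_base
    v1 = tw.Visualizer(s1, vis_type='line', xtitle='i', ytitle='isum')
    # host= overlays s2 on v1's existing plot instead of creating a new one
    v2 = tw.Visualizer(s2, vis_type='line', host=v1, ytitle='rsum')
    v1.show()

    for i in range(10):
        s1.write((i, i))
        s2.write((i, i*i))
    tw.plt_loop()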
--------------------------------------------------------------------------------
/tensorwatch/watcher.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import uuid
5 | from typing import Sequence
6 | from .zmq_wrapper import ZmqWrapper
7 | from .watcher_base import WatcherBase
8 | from .lv_types import CliSrvReqTypes
9 | from .lv_types import DefaultPorts, PublisherTopics, ServerMgmtMsg
10 | from . import utils
11 | import threading, time
12 |
13 | class Watcher(WatcherBase):
14 | def __init__(self, filename:str=None, port:int=0, srv_name:str=None):
15 | super(Watcher, self).__init__()
16 |
17 | self.port = port
18 | self.filename = filename
19 |
20 | # used to detect server restarts
21 | self.srv_name = srv_name or str(uuid.uuid4())
22 |
23 | # define vars in __init__
24 | self._clisrv = None
25 | self._zmq_stream_pub = None
26 | self._file = None
27 | self._th = None
28 |
29 | self._open_devices()
30 |
31 | def _open_devices(self):
32 | if self.port is not None:
33 | self._clisrv = ZmqWrapper.ClientServer(port=DefaultPorts.CliSrv+self.port,
34 | is_server=True, callback=self._clisrv_callback)
35 |
36 | # notify existing listeners of our ID
37 | self._zmq_stream_pub = self._stream_factory.get_streams(stream_types=['tcp:'+str(self.port)], for_write=True)[0]
38 |
39 | # ZMQ quirk: we must wait a bit after opening port and before sending message
40 | # TODO: can we do better?
41 | self._th = threading.Thread(target=self._send_server_start)
42 | self._th.start()
43 |
44 | def _send_server_start(self):
45 | time.sleep(2)
46 | self._zmq_stream_pub.write(ServerMgmtMsg(event_name=ServerMgmtMsg.EventServerStart,
47 | event_args=self.srv_name), topic=PublisherTopics.ServerMgmt)
48 |
49 | def devices_or_default(self, devices:Sequence[str])->Sequence[str]: # overridden
50 | # TODO: this method is duplicated in Watcher and WatcherClient
51 |
52 | # make sure TCP port is attached to tcp device
53 | if devices is not None:
54 | return ['tcp:' + str(self.port) if device=='tcp' else device for device in devices]
55 |
56 | # if no devices specified then use our filename and tcp:port as default devices
57 | devices = []
58 | # first open file device because it may have older data
59 | if self.filename is not None:
60 | devices.append('file:' + self.filename)
61 | if self.port is not None:
62 | devices.append('tcp:' + str(self.port))
63 | return devices
64 |
65 | def close(self):
66 | if not self.closed:
67 | if self._clisrv is not None:
68 | self._clisrv.close()
69 | if self._zmq_stream_pub is not None:
70 | self._zmq_stream_pub.close()
71 | if self._file is not None:
72 | self._file.close()
73 | utils.debug_log("Watcher is closed", verbosity=1)
74 | super(Watcher, self).close()
75 |
76 | def _reset(self):
77 | self._clisrv = None
78 | self._zmq_stream_pub = None
79 | self._file = None
80 | self._th = None
81 | utils.debug_log("Watcher reset", verbosity=1)
82 | super(Watcher, self)._reset()
83 |
84 | def _clisrv_callback(self, clisrv, clisrv_req): # pylint: disable=unused-argument
85 | utils.debug_log("Received client request", clisrv_req.req_type)
86 |
87 | # request = create stream
88 | if clisrv_req.req_type == CliSrvReqTypes.create_stream:
89 | stream_req = clisrv_req.req_data
90 | self.create_stream(name=stream_req.stream_name, devices=stream_req.devices,
91 | event_name=stream_req.event_name, expr=stream_req.expr, throttle=stream_req.throttle,
92 | vis_args=stream_req.vis_args)
93 | return None # ignore return as we can't send back stream obj
94 | elif clisrv_req.req_type == CliSrvReqTypes.del_stream:
95 | stream_name = clisrv_req.req_data
96 | return self.del_stream(stream_name)
97 | else:
98 | raise ValueError('ClientServer Request Type {} is not recognized'.format(clisrv_req.req_type))
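A quick sketch of how devices_or_default above resolves devices, assuming filename='test.log' and the default port 0 (the commented values follow directly from the code above):

    import tensorwatch as tw

    w = tw.Watcher(filename='test.log', port=0)
    # no devices given: file device first (it may hold older data), then tcp
    print(w.devices_or_default(None))      # ['file:test.log', 'tcp:0']
    # a bare 'tcp' entry gets the watcher's port attached
    print(w.devices_or_default(['tcp']))   # ['tcp:0']
    w.close()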
--------------------------------------------------------------------------------
/tensorwatch/watcher_client.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from typing import Any, Dict, Sequence, List
5 | from .zmq_wrapper import ZmqWrapper
6 | from .lv_types import CliSrvReqTypes, ClientServerRequest, DefaultPorts
7 | from .lv_types import VisArgs, PublisherTopics, ServerMgmtMsg, StreamCreateRequest
8 | from .stream import Stream
9 | from .zmq_mgmt_stream import ZmqMgmtStream
10 | from . import utils
11 | from .watcher_base import WatcherBase
12 |
13 | class WatcherClient(WatcherBase):
14 | r"""Extends WatcherBase so that create-stream and delete-stream calls are sent to the server.
15 | """
16 | def __init__(self, filename:str=None, port:int=0):
17 | super(WatcherClient, self).__init__()
18 | self.port = port
19 | self.filename = filename
20 |
21 | # define vars in __init__
22 | self._clisrv = None # client-server socket used to send create/del stream requests
23 | self._zmq_srvmgmt_sub = None
24 | self._file = None
25 |
26 | self._open()
27 |
28 | def _reset(self):
29 | self._clisrv = None
30 | self._zmq_srvmgmt_sub = None
31 | self._file = None
32 | utils.debug_log("WatcherClient reset", verbosity=1)
33 | super(WatcherClient, self)._reset()
34 |
35 | def _open(self):
36 | if self.port is not None:
37 | self._clisrv = ZmqWrapper.ClientServer(port=DefaultPorts.CliSrv+self.port,
38 | is_server=False)
39 | # create subscription where we will receive server management events
40 | self._zmq_srvmgmt_sub = ZmqMgmtStream(clisrv=self._clisrv, for_write=False, port=self.port,
41 | stream_name='zmq_srvmgmt_sub:'+str(self.port)+':False')
42 |
43 | def close(self):
44 | if not self.closed:
45 | self._zmq_srvmgmt_sub.close()
46 | self._clisrv.close()
47 | utils.debug_log("WatcherClient is closed", verbosity=1)
48 | super(WatcherClient, self).close()
49 |
50 | def devices_or_default(self, devices:Sequence[str])->Sequence[str]: # overridden
51 | # TODO: this method is duplicated in Watcher and WatcherClient
52 |
53 | # make sure TCP port is attached to tcp device
54 | if devices is not None:
55 | return ['tcp:' + str(self.port) if device=='tcp' else device for device in devices]
56 |
57 | # if no devices specified then use our filename and tcp:port as default devices
58 | devices = []
59 | # first open file device because it may have older data
60 | if self.filename is not None:
61 | devices.append('file:' + self.filename)
62 | if self.port is not None:
63 | devices.append('tcp:' + str(self.port))
64 | return devices
65 |
66 | # override to send request to server, instead of underlying WatcherBase base class
67 | def create_stream(self, name:str=None, devices:Sequence[str]=None, event_name:str='',
68 | expr=None, throttle:float=1, vis_args:VisArgs=None)->Stream: # overridden
69 |
70 | stream_req = StreamCreateRequest(stream_name=name, devices=self.devices_or_default(devices),
71 | event_name=event_name, expr=expr, throttle=throttle, vis_args=vis_args)
72 |
73 | self._zmq_srvmgmt_sub.add_stream_req(stream_req)
74 |
75 | if stream_req.devices is not None:
76 | stream = self.open_stream(name=stream_req.stream_name, devices=stream_req.devices)
77 | else: # we cannot return remote streams that are not backed by a device
78 | stream = None
79 | return stream
80 |
81 | # override to set devices default to tcp
82 | def open_stream(self, name:str=None, devices:Sequence[str]=None)->Stream: # overridden
83 | return super(WatcherClient, self).open_stream(name=name, devices=devices)
84 |
85 |
86 | # override to send request to server
87 | def del_stream(self, name:str) -> None:
88 | self._zmq_srvmgmt_sub.del_stream(name)
89 |
90 |
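A sketch of the client side of the create_stream round trip above, assuming a Watcher server is already running in another process (see test/zmq/zmq_watcher_server.py); the expr is illustrative:

    import tensorwatch as tw

    cli = tw.WatcherClient()
    # sends a StreamCreateRequest to the server, then opens a local
    # tcp-backed stream (devices default to ['tcp:0']) for the results
    stream = cli.create_stream(expr='lambda vars: vars.x**2')
    vis = tw.Visualizer(stream, vis_type='line')
    vis.show()
    tw.plt_loop()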
--------------------------------------------------------------------------------
/tensorwatch/zmq_mgmt_stream.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from typing import Any, Dict
5 | from .zmq_stream import ZmqStream
6 | from .lv_types import PublisherTopics, ServerMgmtMsg, StreamCreateRequest
7 | from .zmq_wrapper import ZmqWrapper
8 | from .lv_types import CliSrvReqTypes, ClientServerRequest
9 | from . import utils
10 |
11 | class ZmqMgmtStream(ZmqStream):
12 | # default topic is mgmt
13 | def __init__(self, clisrv:ZmqWrapper.ClientServer, for_write:bool, port:int=0, topic=PublisherTopics.ServerMgmt, block_until_connected=True,
14 | stream_name:str=None, console_debug:bool=False):
15 | super(ZmqMgmtStream, self).__init__(for_write=for_write, port=port, topic=topic,
16 | block_until_connected=block_until_connected, stream_name=stream_name, console_debug=console_debug)
17 |
18 | self._clisrv = clisrv
19 | self._stream_reqs:Dict[str,StreamCreateRequest] = {}
20 |
21 | def write(self, val:Any, from_stream:'Stream'=None):
22 | r"""Handles server management events.
23 | """
24 | stream_item = self.to_stream_item(val)
25 | mgmt_msg = stream_item.value
26 |
27 | utils.debug_log("Received - SeverMgmtevent", mgmt_msg)
28 | # if server was restarted then send create stream requests again
29 | if mgmt_msg.event_name == ServerMgmtMsg.EventServerStart:
30 | for stream_req in self._stream_reqs.values():
31 | self._send_create_stream(stream_req)
32 |
33 | super(ZmqMgmtStream, self).write(stream_item)
34 |
35 | def add_stream_req(self, stream_req:StreamCreateRequest)->None:
36 | self._send_create_stream(stream_req)
37 |
38 | # save this for later for resend if server restarts
39 | self._stream_reqs[stream_req.stream_name] = stream_req
40 |
41 | # override to send request to server
42 | def del_stream(self, name:str) -> None:
43 | clisrv_req = ClientServerRequest(CliSrvReqTypes.del_stream, name)
44 | self._clisrv.send_obj(clisrv_req)
45 | self._stream_reqs.pop(name, None)
46 |
47 | def _send_create_stream(self, stream_req):
48 | utils.debug_log("sending create stream request...")
49 | clisrv_req = ClientServerRequest(CliSrvReqTypes.create_stream, stream_req)
50 | self._clisrv.send_obj(clisrv_req)
51 | utils.debug_log("sent create stream request")
52 |
53 | def close(self):
54 | if not self.closed:
55 | self._stream_reqs = {}
56 | self._clisrv = None
57 | utils.debug_log('ZmqMgmtStream is closed', verbosity=1)
58 | super(ZmqMgmtStream, self).close()
59 |
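A sketch of the request objects ZmqMgmtStream puts on the wire (types come from lv_types as imported above; field values are illustrative). add_stream_req sends the wrapped request via clisrv.send_obj and caches it, and on ServerMgmtMsg.EventServerStart every cached request is replayed:

    from tensorwatch.lv_types import (CliSrvReqTypes, ClientServerRequest,
                                      StreamCreateRequest)

    stream_req = StreamCreateRequest(stream_name='loss', devices=['tcp:0'],
                                     event_name='batch', expr='lambda d: d.loss',
                                     throttle=1, vis_args=None)
    # what _send_create_stream sends: the request type plus its payload
    wire_msg = ClientServerRequest(CliSrvReqTypes.create_stream, stream_req)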
--------------------------------------------------------------------------------
/tensorwatch/zmq_stream.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from typing import Any
5 | from .zmq_wrapper import ZmqWrapper
6 | from .stream import Stream
7 | from .lv_types import DefaultPorts, PublisherTopics
8 | from . import utils
9 |
10 | # on writes send data on ZMQ transport
11 | class ZmqStream(Stream):
12 | def __init__(self, for_write:bool, port:int=0, topic=PublisherTopics.StreamItem, block_until_connected=True,
13 | stream_name:str=None, console_debug:bool=False):
14 | super(ZmqStream, self).__init__(stream_name=stream_name, console_debug=console_debug)
15 |
16 | self.for_write = for_write
17 | self._zmq = None
18 |
19 | self.topic = topic
20 | self._open(for_write, port, block_until_connected)
21 | utils.debug_log('ZmqStream started', verbosity=1)
22 |
23 | def _open(self, for_write:bool, port:int, block_until_connected:bool):
24 | if for_write:
25 | self._zmq = ZmqWrapper.Publication(port=DefaultPorts.PubSub+port,
26 | block_until_connected=block_until_connected)
27 | else:
28 | self._zmq = ZmqWrapper.Subscription(port=DefaultPorts.PubSub+port,
29 | topic=self.topic, callback=self._on_subscription_item)
30 |
31 | def close(self):
32 | if not self.closed:
33 | self._zmq.close()
34 | self._zmq = None
35 | utils.debug_log('ZmqStream is closed', verbosity=1)
36 | super(ZmqStream, self).close()
37 |
38 | def _on_subscription_item(self, val:Any):
39 | utils.debug_log('Received subscription item', verbosity=5)
40 | self.write(val)
41 |
42 | def write(self, val:Any, from_stream:'Stream'=None, topic=None):
43 | stream_item = self.to_stream_item(val)
44 |
45 | if self.for_write:
46 | topic = topic or self.topic
47 | utils.debug_log('Sent subscription item', verbosity=5)
48 | self._zmq.send_obj(stream_item, topic)
49 | # else if this was opened for read then we have subscription and
50 | # we shouldn't be calling send_obj
51 | super(ZmqStream, self).write(stream_item)
52 |
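A two-process sketch of ZmqStream (ports are offsets from DefaultPorts.PubSub, per _open above; stream names are illustrative):

    from tensorwatch.zmq_stream import ZmqStream

    # process A: for_write=True opens a Publication; write() sends each item
    pub = ZmqStream(for_write=True, stream_name='pub', console_debug=True)
    pub.write((1, 42))

    # process B: for_write=False opens a Subscription; received items are
    # re-written into this stream, so its own subscribers (e.g. a Visualizer)
    # see them
    sub = ZmqStream(for_write=False, stream_name='sub', console_debug=True)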
--------------------------------------------------------------------------------
/test/components/circ_ref.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import objgraph, time #pip install objgraph
3 |
4 | cli = tw.WatcherClient()
5 | time.sleep(10)
6 | del cli
7 |
8 | import gc
9 | gc.collect()
10 |
11 | time.sleep(2)
12 |
13 | objgraph.show_backrefs(objgraph.by_type('WatcherClient'), refcounts=True, filename='b.png')
14 |
15 |
--------------------------------------------------------------------------------
/test/components/evaler.py:
--------------------------------------------------------------------------------
1 | #import tensorwatch as tw
2 | from tensorwatch import evaler
3 |
4 |
5 | e = evaler.Evaler(expr='reduce(lambda x,y: (x+y), map(lambda x:(x**2), filter(lambda x: x%2==0, l)))')
6 | for i in range(5):
7 | eval_return = e.post(i)
8 | print(i, eval_return)
9 | eval_return = e.post(ended=True)
10 | print(i, eval_return)
11 |
--------------------------------------------------------------------------------
/test/components/file_only_test.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 |
3 | def writer():
4 | watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None)
5 |
6 | with watcher.create_stream('metric1') as stream1:
7 | for i in range(3):
8 | stream1.write((i, i*i))
9 |
10 | with watcher.create_stream('metric2') as stream2:
11 | for i in range(3):
12 | stream2.write((i, i*i*i))
13 |
14 | def reader1():
15 | print('---------------------------reader1---------------------------')
16 | watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None)
17 |
18 | stream1 = watcher.open_stream('metric1')
19 | stream1.console_debug = True
20 | stream1.load()
21 |
22 | stream2 = watcher.open_stream('metric2')
23 | stream2.console_debug = True
24 | stream2.load()
25 |
26 | def reader2():
27 | print('---------------------------reader2---------------------------')
28 |
29 | watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None)
30 |
31 | stream1 = watcher.open_stream('metric1')
32 | for item in stream1.read_all():
33 | print(item)
34 |
35 | stream2 = watcher.open_stream('metric2')
36 | for item in stream2.read_all():
37 | print(item)
38 |
39 | def reader3():
40 | print('---------------------------reader3---------------------------')
41 |
42 | watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None)
43 | stream1 = watcher.open_stream('metric1')
44 | stream2 = watcher.open_stream('metric2')
45 |
46 | vis1 = tw.Visualizer(stream1, vis_type='line')
47 | vis2 = tw.Visualizer(stream2, vis_type='line', host=vis1)
48 |
49 | vis1.show()
50 |
51 | tw.plt_loop()
52 |
53 |
54 | writer()
55 | reader1()
56 | reader2()
57 | reader3()
58 |
59 |
--------------------------------------------------------------------------------
/test/components/notebook_maker.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 |
3 |
4 | def main():
5 | w = tw.Watcher()
6 | s1 = w.create_stream()
7 | s2 = w.create_stream(name='accuracy', vis_args=tw.VisArgs(vis_type='line', xtitle='X-Axis', clear_after_each=False, history_len=2))
8 | s3 = w.create_stream(name='loss', expr='lambda d:d.loss')
9 | w.make_notebook()
10 |
11 | main()
12 |
13 |
--------------------------------------------------------------------------------
/test/components/stream.py:
--------------------------------------------------------------------------------
1 | from tensorwatch.stream import Stream
2 |
3 |
4 | s1 = Stream(stream_name='s1', console_debug=True)
5 | s2 = Stream(stream_name='s2', console_debug=True)
6 | s3 = Stream(stream_name='s3', console_debug=True)
7 |
8 | s1.subscribe(s2)
9 | s2.subscribe(s3)
10 |
11 | s3.write('S3 wrote this')
12 | s2.write('S2 wrote this')
13 | s1.write('S1 wrote this')
14 |
15 |
--------------------------------------------------------------------------------
/test/components/watcher.py:
--------------------------------------------------------------------------------
1 | from tensorwatch.watcher_base import WatcherBase
2 | from tensorwatch.stream import Stream
3 |
4 | def main():
5 | watcher = WatcherBase()
6 | console_pub = Stream(stream_name = 'S1', console_debug=True)
7 | stream = watcher.create_stream(expr='lambda vars:vars.x**2')
8 | console_pub.subscribe(stream)
9 |
10 | for i in range(5):
11 | watcher.observe(x=i)
12 |
13 | main()
14 |
15 |
16 |
--------------------------------------------------------------------------------
/test/deps/ipython_widget.py:
--------------------------------------------------------------------------------
1 | import ipywidgets as widgets
2 | from IPython import get_ipython
3 |
4 | class PrinterX:
5 | def __init__(self):
6 | self.w = widgets.HTML()
7 | def show(self):
8 | return self.w
9 | def write(self,s):
10 | self.w.value = s
11 |
12 | print("Running from within ipython?", get_ipython() is not None)
13 | p=PrinterX()
14 | p.show()
15 | p.write('ffffffffff')
16 |
--------------------------------------------------------------------------------
/test/deps/live_graph.py:
--------------------------------------------------------------------------------
1 | from matplotlib import pyplot as plt
2 | from matplotlib.animation import FuncAnimation
3 | from random import randrange
4 | from threading import Thread
5 | import time
6 |
7 | class LiveGraph:
8 | def __init__(self):
9 | self.x_data, self.y_data = [], []
10 | self.figure = plt.figure()
11 | self.line, = plt.plot(self.x_data, self.y_data)
12 | self.animation = FuncAnimation(self.figure, self.update, interval=1000)
13 | self.th = Thread(target=self.thread_f, name='LiveGraph', daemon=True)
14 | self.th.start()
15 |
16 | def update(self, frame):
17 | self.line.set_data(self.x_data, self.y_data)
18 | self.figure.gca().relim()
19 | self.figure.gca().autoscale_view()
20 | return self.line,
21 |
22 | def show(self):
23 | plt.show()
24 |
25 | def thread_f(self):
26 | x = 0
27 | while True:
28 | self.x_data.append(x)
29 | x += 1
30 | self.y_data.append(randrange(0, 100))
31 | time.sleep(1)
--------------------------------------------------------------------------------
/test/deps/panda.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | class SomeThing:
4 | def __init__(self, x, y):
5 | self.x, self.y = x, y
6 |
7 | things = [SomeThing(1,2), SomeThing(3,4), SomeThing(4,5)]
8 |
9 | df = pd.DataFrame([t.__dict__ for t in things ])
10 |
11 | print(df.iloc[[-1]].to_dict('records')[0])
12 | #print(df.to_html())
13 | print(df.style.render())
14 |
--------------------------------------------------------------------------------
/test/deps/thread.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import time
3 | import sys
4 |
5 | def handler(a,b=None):
6 | sys.exit(1)
7 | def install_handler():
8 | if sys.platform == "win32":
9 | if sys.stdin is not None and sys.stdin.isatty():
10 | #this is Console based application
11 | import win32api
12 | win32api.SetConsoleCtrlHandler(handler, True)
13 |
14 |
15 | def work():
16 | time.sleep(10000)
17 | t = threading.Thread(target=work, name='ThreadTest')
18 | t.daemon = True
19 | t.start()
20 | while True:
21 | t.join(0.1) # 100ms ~ typical human response time
22 | # you will get KeyboardInterrupt exception
23 |
24 |
--------------------------------------------------------------------------------
/test/files/file_stream.py:
--------------------------------------------------------------------------------
1 | from tensorwatch.watcher_base import WatcherBase
2 | from tensorwatch.stream import Stream
3 | from tensorwatch.file_stream import FileStream
4 | from tensorwatch import LinePlot
5 | from tensorwatch.image_utils import plt_loop
6 | import tensorwatch as tw
7 |
8 | def file_write():
9 | watcher = WatcherBase()
10 | stream = watcher.create_stream(expr='lambda vars:(vars.x, vars.x**2)',
11 | devices=[r'c:\temp\obs.txt'])
12 |
13 | for i in range(5):
14 | watcher.observe(x=i)
15 |
16 | def file_read():
17 | watcher = WatcherBase()
18 | stream = watcher.open_stream(devices=[r'c:\temp\obs.txt'])
19 | vis = tw.Visualizer(stream, vis_type='mpl-line')
20 | vis.show()
21 | plt_loop()
22 |
23 | def main():
24 | file_write()
25 | file_read()
26 |
27 | main()
28 |
29 |
--------------------------------------------------------------------------------
/test/mnist/cli_mnist.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import time
3 | import math
4 | from tensorwatch import utils
5 |
6 | utils.set_debug_verbosity(4)
7 |
8 | def img_in_class():
9 | cli_train = tw.WatcherClient()
10 |
11 | imgs = cli_train.create_stream(event_name='batch',
12 | expr="topk_all(l, batch_vals=lambda b: (b.batch.loss_all, (b.batch.input, b.batch.output), b.batch.target), \
13 | out_f=image_class_outf, order='dsc')", throttle=1)
14 | img_plot = tw.ImagePlot()
15 | img_plot.subscribe(imgs, viz_img_scale=3)
16 | img_plot.show()
17 |
18 | tw.image_utils.plt_loop()
19 |
20 | def show_find_lr():
21 | cli_train = tw.WatcherClient()
22 | plot = tw.LinePlot()
23 |
24 | train_batch_loss = cli_train.create_stream(event_name='batch',
25 | expr='lambda d:(d.tt.scheduler.get_lr()[0], d.metrics.batch_loss)')
26 | plot.subscribe(train_batch_loss, xtitle='Epoch', ytitle='Loss')
27 |
28 | utils.wait_key()
29 |
30 | def plot_grads_plotly():
31 | train_cli = tw.WatcherClient()
32 | grads = train_cli.create_stream(event_name='batch',
33 | expr='lambda d:grads_abs_mean(d.model)', throttle=1)
34 | p = tw.plotly.line_plot.LinePlot('Demo')
35 | p.subscribe(grads, xtitle='Layer', ytitle='Gradients', history_len=30, new_on_eval=True)
36 | utils.wait_key()
37 |
38 |
39 | def plot_grads():
40 | train_cli = tw.WatcherClient()
41 |
42 | grads = train_cli.create_stream(event_name='batch',
43 | expr='lambda d:grads_abs_mean(d.model)', throttle=1)
44 | grad_plot = tw.LinePlot()
45 | grad_plot.subscribe(grads, xtitle='Layer', ytitle='Gradients', clear_after_each=1, history_len=40, dim_history=True)
46 | grad_plot.show()
47 |
48 | tw.plt_loop()
49 |
50 | def plot_weight():
51 | train_cli = tw.WatcherClient()
52 |
53 | params = train_cli.create_stream(event_name='batch',
54 | expr='lambda d:weights_abs_mean(d.model)', throttle=1)
55 | params_plot = tw.LinePlot()
56 | params_plot.subscribe(params, xtitle='Layer', ytitle='avg |params|', clear_after_each=1, history_len=40, dim_history=True)
57 | params_plot.show()
58 |
59 | tw.plt_loop()
60 |
61 | def epoch_stats():
62 | train_cli = tw.WatcherClient(port=0)
63 | test_cli = tw.WatcherClient(port=1)
64 |
65 | plot = tw.LinePlot()
66 |
67 | train_loss = train_cli.create_stream(event_name="epoch",
68 | expr='lambda v:(v.metrics.epoch_index, v.metrics.epoch_loss)')
69 | plot.subscribe(train_loss, xtitle='Epoch', ytitle='Train Loss')
70 |
71 | test_acc = test_cli.create_stream(event_name="epoch",
72 | expr='lambda v:(v.metrics.epoch_index, v.metrics.epoch_accuracy)')
73 | plot.subscribe(test_acc, xtitle='Epoch', ytitle='Test Accuracy', ylim=(0,1))
74 |
75 | plot.show()
76 | tw.plt_loop()
77 |
78 |
79 | def batch_stats():
80 | train_cli = tw.WatcherClient()
81 | stream = train_cli.create_stream(event_name="batch",
82 | expr='lambda v:(v.metrics.epochf, v.metrics.batch_loss)', throttle=0.75)
83 |
84 | train_loss = tw.Visualizer(stream, clear_after_end=False, vis_type='mpl-line',
85 | xtitle='Epoch', ytitle='Train Loss')
86 |
87 | #train_acc = tw.Visualizer('lambda v:(v.metrics.epochf, v.metrics.epoch_loss)', event_name="batch",
88 | # xtitle='Epoch', ytitle='Train Accuracy', clear_after_end=False, yrange=(0,1),
89 | # vis=train_loss, vis_type='mpl-line')
90 |
91 | train_loss.show()
92 | tw.plt_loop()
93 |
94 | def text_stats():
95 | train_cli = tw.WatcherClient()
96 | stream = train_cli.create_stream(event_name="batch",
97 | expr='lambda d:(d.metrics.epoch_index, d.metrics.batch_loss)')
98 |
99 | trl = tw.Visualizer(stream, vis_type='text')
100 | trl.show()
101 | input('Paused...')
102 |
103 |
104 |
105 | #epoch_stats()
106 | #plot_weight()
107 | #plot_grads()
108 | #img_in_class()
109 | text_stats()
110 | #batch_stats()
--------------------------------------------------------------------------------
/test/post_train/saliency.py:
--------------------------------------------------------------------------------
1 | from tensorwatch.saliency import saliency
2 | from tensorwatch import image_utils, imagenet_utils, pytorch_utils
3 |
4 | model = pytorch_utils.get_model('resnet50')
5 | raw_input, input, target_class = pytorch_utils.image_class2tensor('../data/test_images/dogs.png', 240, #'../data/elephant.png', 101,
6 | image_transform=imagenet_utils.get_image_transform(), image_convert_mode='RGB')
7 |
8 | results = saliency.get_image_saliency_results(model, raw_input, input, target_class)
9 | figure = saliency.get_image_saliency_plot(results)
10 |
11 | image_utils.plt_loop()
12 |
13 |
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/test/pre_train/dot_manual.bat:
--------------------------------------------------------------------------------
1 | dot -Tpng -O test\pre_train\sample.dot
--------------------------------------------------------------------------------
/test/pre_train/draw_cust_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.models
3 | import tensorwatch as tw
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torchvision
9 | from torch.utils.tensorboard import SummaryWriter
10 | from torchvision import datasets, transforms
11 | writer = SummaryWriter()
12 | class Net(nn.Module):
13 |
14 | def __init__(self):
15 | super(Net, self).__init__()
16 | self.layer = nn.Sequential(
17 | nn.Conv2d(3, 64, 7, padding=1, stride=1),
18 | nn.ReLU(),
19 | nn.MaxPool2d(1, 2),
20 | nn.BatchNorm2d(64),
21 | nn.AvgPool2d(1, 1),
22 | nn.Dropout(0.5),
23 | nn.Linear(110,30)
24 | )
25 |
26 | def forward(self, x):
27 | return self.layer(x)
28 |
29 |
30 | net = Net()
31 | args = torch.ones([1, 3, 224, 224])
32 | # writer.add_graph(net, args)
33 | # writer.close()
34 |
35 | #vgg16_model = torchvision.models.vgg16()
36 |
37 | drawing = tw.draw_model(net, [1, 3, 224, 224])
38 | drawing.save('abc2.png')
39 |
40 | input("Press any key")
--------------------------------------------------------------------------------
/test/pre_train/draw_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.models
3 | import tensorwatch as tw
4 |
5 | vgg16_model = torchvision.models.vgg16()
6 |
7 | drawing = tw.draw_model(vgg16_model, [1, 3, 224, 224])
8 | drawing.save('abc.png')
9 |
10 | input("Press any key")
--------------------------------------------------------------------------------
/test/pre_train/model_stats.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import torchvision.models
3 |
4 | model_names = ['alexnet', 'resnet18', 'resnet34', 'resnet101', 'densenet121']
5 |
6 | for model_name in model_names:
7 | model = getattr(torchvision.models, model_name)()
8 | model_stats = tw.ModelStats(model, [1, 3, 224, 224], clone_model=False)
9 | print(f'{model_name}: flops={model_stats.Flops}, parameters={model_stats.parameters}, memory={model_stats.inference_memory}')
--------------------------------------------------------------------------------
/test/pre_train/model_stats_perf.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import tensorwatch as tw
3 | import torchvision.models
4 | import torch
5 | import time
6 |
7 | model = getattr(torchvision.models, 'densenet201')()
8 |
9 | def model_timing(model):
10 | st = time.time()
11 | for _ in range(20):
12 | batch = torch.rand([64, 3, 224, 224])
13 | y = model(batch)
14 | return time.time()-st
15 |
16 | print(model_timing(model))
17 | model_stats = tw.ModelStats(model, [1, 3, 224, 224], clone_model=False)
18 | print(f'flops={model_stats.Flops}, parameters={model_stats.parameters}, memory={model_stats.inference_memory}')
19 | print(model_timing(model))
20 |
--------------------------------------------------------------------------------
/test/pre_train/test_pydot.py:
--------------------------------------------------------------------------------
1 | import pydot
2 |
3 | g = pydot.Dot()
4 | g.set_type('digraph')
5 | node = pydot.Node('legend')
6 | node.set("shape", 'box')
7 | g.add_node(node)
8 | node.set('label', 'mine')
9 | s = g.to_string()
10 | expected = 'digraph G {\nlegend [label=mine, shape=box];\n}\n'
11 | assert s == expected
12 | print(s)
13 | png = g.create_png()
14 | print(png)
--------------------------------------------------------------------------------
/test/pre_train/tsny.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 |
3 | from regim import *
4 | ds = DataUtils.mnist_datasets(linearize=True, train_test=False)
5 | ds = DataUtils.sample_by_class(ds, k=5, shuffle=True, as_np=True, no_test=True)
6 |
7 | comps = tw.get_tsne_components(ds)
8 | print(comps)
9 | plot = tw.Visualizer(comps, hover_images=ds[0], hover_image_reshape=(28,28), vis_type='tsne')
10 | plot.show()
--------------------------------------------------------------------------------
/test/simple_log/cli_file_expr.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | from tensorwatch import utils
3 | utils.set_debug_verbosity(4)
4 |
5 |
6 | #r = tw.Visualizer(vis_type='mpl-line')
7 | #r.show()
8 | #r2=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', cell=r.cell)
9 | #r3=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', renderer=r2)
10 |
11 | def show_mpl():
12 | cli = tw.WatcherClient(r'c:\temp\sum.log')
13 | s1 = cli.open_stream('sum')
14 | p = tw.LinePlot(title='Demo')
15 | p.subscribe(s1, xtitle='Index', ytitle='sqrt(ev_i)')
16 | s1.load()
17 | p.show()
18 |
19 | tw.plt_loop()
20 |
21 | def show_text():
22 | cli = tw.WatcherClient(r'c:\temp\sum.log')
23 | s1 = cli.open_stream('sum_2')
24 | text = tw.Visualizer(s1)
25 | text.show()
26 | input('Waiting')
27 |
28 | #show_text()
29 | show_mpl()
30 |
--------------------------------------------------------------------------------
/test/simple_log/cli_ij.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import time
3 | import math
4 | from tensorwatch import utils
5 | utils.set_debug_verbosity(4)
6 |
7 | def mpl_line_plot():
8 | cli = tw.WatcherClient()
9 | p = tw.LinePlot(title='Demo')
10 | s1 = cli.create_stream(event_name='ev_i', expr='map(lambda v:math.sqrt(v.val)*2, l)')
11 | p.subscribe(s1, xtitle='Index', ytitle='sqrt(ev_i)')
12 | p.show()
13 | tw.plt_loop()
14 |
15 | def mpl_history_plot():
16 | cli = tw.WatcherClient()
17 | p2 = tw.LinePlot(title='History Demo')
18 | p2s1 = cli.create_stream(event_name='ev_j', expr='map(lambda v:(v.val, math.sqrt(v.val)*2), l)')
19 | p2.subscribe(p2s1, xtitle='Index', ytitle='sqrt(ev_j)', clear_after_end=True, history_len=15)
20 | p2.show()
21 | tw.plt_loop()
22 |
23 | def show_stream():
24 | cli = tw.WatcherClient()
25 |
26 | print("Subscribing to event ev_i...")
27 | s1 = cli.create_stream(event_name="ev_i", expr='map(lambda v:math.sqrt(v.val), l)')
28 | r1 = tw.TextVis(title='L1')
29 | r1.subscribe(s1)
30 | r1.show()
31 |
32 | print("Subscribing to event ev_j...")
33 | s2 = cli.create_stream(event_name="ev_j", expr='map(lambda v:v.val*v.val, l)')
34 | r2 = tw.TextVis(title='L2')
35 | r2.subscribe(s2)
36 |
37 | r2.show()
38 |
39 | print("Waiting for key...")
40 |
41 | utils.wait_key()
42 |
43 | # this no longer directly supported
44 | # TODO: create stream that allows enumeration from buffered values
45 | #def read_stream():
46 | # cli = tw.WatcherClient()
47 |
48 | # with cli.create_stream(event_name="ev_i", expr='map(lambda v:(v.x, math.sqrt(v.val)), l)') as s1:
49 | # for stream_item in s1:
50 | # print(stream_item.value)
51 | # print('done')
52 | # utils.wait_key()
53 |
54 | def plotly_line_graph():
55 | cli = tw.WatcherClient()
56 | s1 = cli.create_stream(event_name="ev_i", expr='map(lambda v:(v.x, math.sqrt(v.val)), l)')
57 |
58 | p = tw.plotly.line_plot.LinePlot()
59 | p.subscribe(s1)
60 | p.show()
61 |
62 | utils.wait_key()
63 |
64 | def plotly_history_graph():
65 | cli = tw.WatcherClient()
66 | p = tw.plotly.line_plot.LinePlot(title='Demo')
67 | s2 = cli.create_stream(event_name='ev_j', expr='map(lambda v:(v.x, v.val), l)')
68 | p.subscribe(s2, ytitle='ev_j', history_len=15)
69 | p.show()
70 | utils.wait_key()
71 |
72 |
73 | mpl_line_plot()
74 | #mpl_history_plot()
75 | #show_stream()
76 | #plotly_line_graph()
77 | #plotly_history_graph()
78 |
--------------------------------------------------------------------------------
/test/simple_log/cli_sum_log.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | from tensorwatch import utils
3 | utils.set_debug_verbosity(4)
4 |
5 |
6 | #r = tw.Visualizer(vis_type='mpl-line')
7 | #r.show()
8 | #r2=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', cell=r.cell)
9 | #r3=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', renderer=r2)
10 |
11 | def show_mpl():
12 | cli = tw.WatcherClient()
13 | st_isum = cli.open_stream('isums')
14 | st_rsum = cli.open_stream('rsums')
15 |
16 | line_plot = tw.Visualizer(st_isum, vis_type='line', xtitle='i', ytitle='isum')
17 | line_plot.show()
18 |
19 | line_plot2 = tw.Visualizer(st_rsum, vis_type='line', host=line_plot, ytitle='rsum')
20 |
21 | tw.plt_loop()
22 |
23 | def show_text():
24 | cli = tw.WatcherClient()
25 | st_isum = cli.open_stream('isums')
26 | text_vis = tw.Visualizer(st_isum, vis_type='text')
27 | text_vis.show()
28 | input('Waiting')
29 |
30 | #show_text()
31 | show_mpl()
--------------------------------------------------------------------------------
/test/simple_log/file_expr.py:
--------------------------------------------------------------------------------
1 | import time
2 | import tensorwatch as tw
3 | from tensorwatch import utils
4 | utils.set_debug_verbosity(4)
5 |
6 | srv = tw.Watcher(filename=r'c:\temp\sum.log')
7 | s1 = srv.create_stream('sum', expr='lambda v:(v.i, v.sum)')
8 | s2 = srv.create_stream('sum_2', expr='lambda v:(v.i, v.sum/2)')
9 |
10 | sum = 0
11 | for i in range(10000):
12 | sum += i
13 | srv.observe(i=i, sum=sum)
14 | #print(i, sum)
15 | time.sleep(1)
16 |
17 |
--------------------------------------------------------------------------------
/test/simple_log/quick_start.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import time
3 |
4 | w = tw.Watcher(filename='test.log')
5 | s = w.create_stream(name='my_metric')
6 | #w.make_notebook()
7 |
8 | for i in range(1000):
9 | s.write((i, i*i))
10 | time.sleep(1)
11 |
--------------------------------------------------------------------------------
/test/simple_log/srv_ij.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import time
3 | import random
4 | from tensorwatch import utils
5 |
6 | utils.set_debug_verbosity(4)
7 |
8 | srv = tw.Watcher()
9 |
10 | while(True):
11 | for i in range(1000):
12 | srv.observe("ev_i", val=i*random.random(), x=i)
13 | print('sent ev_i ', i)
14 | time.sleep(1)
15 | for j in range(5):
16 | srv.observe("ev_j", x=j, val=j*random.random())
17 | print('sent ev_j ', j)
18 | time.sleep(0.5)
19 | srv.end_event("ev_j")
20 | srv.end_event("ev_i")
--------------------------------------------------------------------------------
/test/simple_log/sum_lazy.py:
--------------------------------------------------------------------------------
1 | import time, random
2 | import tensorwatch as tw
3 |
4 | # create watcher object as usual
5 | w = tw.Watcher()
6 |
7 | weights = None
8 | for i in range(10000):
9 | weights = [random.random() for _ in range(5)]
10 |
11 | # let watcher observe variables we have
12 | # this has almost no performance cost
13 | w.observe(weights=weights)
14 |
15 | time.sleep(1)
16 |
17 |
--------------------------------------------------------------------------------
/test/simple_log/sum_log.py:
--------------------------------------------------------------------------------
1 | import time, random
2 | import tensorwatch as tw
3 |
4 | # we will create two streams, one for
5 | # sums of integers, other for random numbers
6 | w = tw.Watcher()
7 | st_isum = w.create_stream('isums')
8 | st_rsum = w.create_stream('rsums')
9 |
10 | isum, rsum = 0, 0
11 | for i in range(10000):
12 | isum += i
13 | rsum += random.randint(0,i)
14 |
15 | # write to streams
16 | st_isum.write((i, isum))
17 | st_rsum.write((i, rsum))
18 |
19 | time.sleep(1)
20 |
--------------------------------------------------------------------------------
/test/visualizations/arr_img_plot.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import numpy as np
3 | import time
4 | import torchvision.datasets as datasets
5 |
6 | fruits_ds = datasets.ImageFolder(r'D:\datasets\fruits-360\Training')
7 | mnist_ds = datasets.MNIST('../data', train=True, download=True)
8 |
9 | images = [tw.ImageData(fruits_ds[i][0], title=str(i)) for i in range(5)] + \
10 | [tw.ImageData(mnist_ds[i][0], title=str(i)) for i in range(5)]
11 |
12 | stream = tw.ArrayStream(images)
13 |
14 | img_plot = tw.Visualizer(stream, vis_type='image', viz_img_scale=3)
15 | img_plot.show()
16 |
17 | tw.image_utils.plt_loop()
--------------------------------------------------------------------------------
/test/visualizations/arr_mpl_line.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 |
3 | stream = tw.ArrayStream([(i, i*i) for i in range(50)])
4 | img_plot = tw.Visualizer(stream, vis_type='mpl-line', viz_img_scale=3, xtitle='Epochs', ytitle='Gain')
5 | # img_plot.show()
6 | # tw.plt_loop()
7 | img_plot.save(r'c:\temp\fig1.png')
8 |
--------------------------------------------------------------------------------
/test/visualizations/bar_plot.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import random, time
3 |
4 | def static_bar():
5 | w = tw.Watcher()
6 | s = w.create_stream()
7 |
8 | v = tw.Visualizer(s, vis_type='bar')
9 | v.show()
10 |
11 | for i in range(10):
12 | s.write(int(random.random()*10))
13 |
14 | tw.plt_loop()
15 |
16 |
17 | def dynamic_bar():
18 | w = tw.Watcher()
19 | s = w.create_stream()
20 |
21 | v = tw.Visualizer(s, vis_type='bar', clear_after_each=True)
22 | v.show()
23 |
24 | for i in range(100):
25 | s.write([('a'+str(i), random.random()*10) for i in range(10)])
26 | tw.plt_loop(count=3)
27 |
28 | def dynamic_bar3d():
29 | w = tw.Watcher()
30 | s = w.create_stream()
31 |
32 | v = tw.Visualizer(s, vis_type='bar3d', clear_after_each=True)
33 | v.show()
34 |
35 | for i in range(100):
36 | s.write([(i, random.random()*10, z) for i in range(10) for z in range(10)])
37 | tw.plt_loop(count=3)
38 |
39 | static_bar()
40 | #dynamic_bar()
41 | #dynamic_bar3d()
42 |
--------------------------------------------------------------------------------
/test/visualizations/confidence_int.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 |
3 | w = tw.Watcher()
4 | s = w.create_stream()
5 |
6 | v = tw.Visualizer(s, vis_type='line')
7 | v.show()
8 |
9 | for i in range(10):
10 | i = float(i)
11 | s.write(tw.PointData(i, i*i, low=i*i-i, high=i*i+i))
12 |
13 | tw.plt_loop()
--------------------------------------------------------------------------------
/test/visualizations/histogram.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import random, time
3 |
4 | def static_hist():
5 | w = tw.Watcher()
6 | s = w.create_stream()
7 |
8 | v = tw.Visualizer(s, vis_type='histogram', bins=6)
9 | v.show()
10 |
11 | for _ in range(100):
12 | s.write(random.random()*10)
13 |
14 | tw.plt_loop()
15 |
16 |
17 | def dynamic_hist():
18 | w = tw.Watcher()
19 | s = w.create_stream()
20 |
21 | v = tw.Visualizer(s, vis_type='histogram', bins=6, clear_after_each=True)
22 | v.show()
23 |
24 | for _ in range(100):
25 | s.write([random.random()*10 for _ in range(100)])
26 | tw.plt_loop(count=3)
27 |
28 | dynamic_hist()
--------------------------------------------------------------------------------
/test/visualizations/line3d_plot.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import random, time
3 |
4 | # TODO: resolve problem with Axis3D?
5 |
6 | def static_line3d():
7 | w = tw.Watcher()
8 | s = w.create_stream()
9 |
10 | v = tw.Visualizer(s, vis_type='line3d')
11 | v.show()
12 |
13 | for i in range(10):
14 | s.write((i, i*i, int(random.random()*10)))
15 |
16 | tw.plt_loop()
17 |
18 | def dynamic_line3d():
19 | w = tw.Watcher()
20 | s = w.create_stream()
21 |
22 | v = tw.Visualizer(s, vis_type='line3d', clear_after_each=True)
23 | v.show()
24 |
25 | for i in range(100):
26 | s.write([(i, random.random()*10, z) for i in range(10) for z in range(10)])
27 | tw.plt_loop(count=3)
28 |
29 | static_line3d()
30 | #dynamic_line3d()
31 |
32 |
--------------------------------------------------------------------------------
/test/visualizations/mpl_line.py:
--------------------------------------------------------------------------------
1 | from tensorwatch.watcher_base import WatcherBase
2 | from tensorwatch.mpl.line_plot import LinePlot
3 | from tensorwatch.image_utils import plt_loop
4 | from tensorwatch.stream import Stream
5 | from tensorwatch.lv_types import StreamItem
6 |
7 |
8 | def main():
9 | watcher = WatcherBase()
10 | line_plot = LinePlot()
11 | stream = watcher.create_stream(expr='lambda vars:vars.x')
12 | line_plot.subscribe(stream)
13 | line_plot.show()
14 |
15 | for i in range(5):
16 | watcher.observe(x=(i, i*i))
17 | plt_loop()
18 |
19 | main()
20 |
21 |
--------------------------------------------------------------------------------
/test/visualizations/pie_chart.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import random, time
3 |
4 | def static_pie():
5 | w = tw.Watcher()
6 | s = w.create_stream()
7 |
8 | v = tw.Visualizer(s, vis_type='pie', bins=6)
9 | v.show()
10 |
11 | for i in range(6):
12 | s.write(('label'+str(i), random.random()*10, None, 0.5 if i==3 else 0))
13 |
14 | tw.plt_loop()
15 |
16 |
17 | def dynamic_pie():
18 | w = tw.Watcher()
19 | s = w.create_stream()
20 |
21 | v = tw.Visualizer(s, vis_type='pie', bins=6, clear_after_each=True)
22 | v.show()
23 |
24 | for _ in range(100):
25 | s.write([('label'+str(i), random.random()*10, None, i*0.01) for i in range(12)])
26 | tw.plt_loop(count=3)
27 |
28 | #static_pie()
29 | dynamic_pie()
30 |
--------------------------------------------------------------------------------
/test/visualizations/plotly_line.py:
--------------------------------------------------------------------------------
1 | from tensorwatch.watcher_base import WatcherBase
2 | from tensorwatch.plotly.line_plot import LinePlot
3 | from tensorwatch.image_utils import plt_loop
4 | from tensorwatch.stream import Stream
5 | from tensorwatch.lv_types import StreamItem
6 |
7 |
8 | def main():
9 | watcher = WatcherBase()
10 | line_plot = LinePlot()
11 | stream = watcher.create_stream(expr='lambda vars:vars.x')
12 | line_plot.subscribe(stream)
13 | line_plot.show()
14 |
15 | for i in range(5):
16 | watcher.observe(x=(i, i*i))
17 |
18 | main()
19 |
20 |
--------------------------------------------------------------------------------
/test/zmq/zmq_pub.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import time
3 | from tensorwatch.zmq_wrapper import ZmqWrapper
4 | from tensorwatch import utils
5 |
6 | utils.set_debug_verbosity(10)
7 |
8 | def clisrv_callback(clisrv, msg):
9 | print('from clisrv', msg)
10 |
11 | stream = ZmqWrapper.Publication(port = 40859)
12 | clisrv = ZmqWrapper.ClientServer(40860, True, clisrv_callback)
13 |
14 | for i in range(10000):
15 | stream.send_obj({'a': i}, "Topic1")
16 | print("sent ", i)
17 | time.sleep(1)
18 |
--------------------------------------------------------------------------------
/test/zmq/zmq_stream.py:
--------------------------------------------------------------------------------
1 | from tensorwatch.watcher_base import WatcherBase
2 | from tensorwatch.zmq_stream import ZmqStream
3 |
4 | def main():
5 | watcher = WatcherBase()
6 | stream = watcher.create_stream(expr='lambda vars:vars.x**2')
7 |
8 | zmq_pub = ZmqStream(for_write=True, stream_name = 'ZmqPub', console_debug=True)
9 | zmq_pub.subscribe(stream)
10 |
11 | for i in range(5):
12 | watcher.observe(x=i)
13 | input('paused')
14 |
15 | main()
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/test/zmq/zmq_sub.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import time
3 | from tensorwatch.zmq_wrapper import ZmqWrapper
4 | from tensorwatch import utils
5 |
6 | class A:
7 | def on_event(self, obj):
8 | print(obj)
9 |
10 | a = A()
11 |
12 |
13 | utils.set_debug_verbosity(10)
14 | sub = ZmqWrapper.Subscription(40859, "Topic1", a.on_event)
15 | print("subscriber is waiting")
16 |
17 | clisrv = ZmqWrapper.ClientServer(40860, False)
18 | clisrv.send_obj("hello 1")
19 | print('sleeping..')
20 | time.sleep(10)
21 | clisrv.send_obj("hello 2")
22 |
23 | print('waiting for key..')
24 | utils.wait_key()
--------------------------------------------------------------------------------
/test/zmq/zmq_watcher_client.py:
--------------------------------------------------------------------------------
1 | from tensorwatch.watcher_client import WatcherClient
2 | import time
3 |
4 | from tensorwatch import utils
5 | utils.set_debug_verbosity(10)
6 |
7 |
8 | def main():
9 | watcher = WatcherClient()
10 | stream = watcher.create_stream(expr='lambda vars:vars.x**2')
11 | stream.console_debug = True
12 | input('pause')
13 |
14 | main()
15 |
16 |
--------------------------------------------------------------------------------
/test/zmq/zmq_watcher_server.py:
--------------------------------------------------------------------------------
1 | from tensorwatch.watcher import Watcher
2 | import time
3 |
4 | from tensorwatch import utils
5 | utils.set_debug_verbosity(10)
6 |
7 |
8 | def main():
9 | watcher = Watcher()
10 |
11 | for i in range(5000):
12 | watcher.observe(x=i)
13 | # print(i)
14 | time.sleep(1)
15 |
16 | main()
17 |
--------------------------------------------------------------------------------
/update_package.bat:
--------------------------------------------------------------------------------
1 | ECHO Make sure to increment version in setup.py. Continue?
2 | PAUSE
3 | python setup.py sdist
4 | twine upload --repository-url https://upload.pypi.org/legacy/ dist/*
5 | REM pip install tensorwatch --upgrade
6 | REM pip show tensorwatch
7 |
8 | REM pip install yolk3k
9 | REM yolk -V tensorwatch
--------------------------------------------------------------------------------