├── tensorwatch
├── saliency
│ ├── lime
│ │ ├── __init__.py
│ │ └── wrappers
│ │ │ ├── __init__.py
│ │ │ ├── generic_utils.py
│ │ │ └── scikit_image.py
│ ├── __init__.py
│ ├── README.md
│ ├── gradcam.py
│ ├── occlusion.py
│ ├── lime_image_explainer.py
│ ├── deeplift.py
│ ├── backprop.py
│ └── saliency.py
├── model_graph
│ ├── hiddenlayer
│ │ ├── __init__.py
│ │ ├── README.md
│ │ ├── pytorch_builder.py
│ │ ├── ge.py
│ │ └── tf_builder.py
│ ├── __init__.py
│ └── torchstat_utils.py
├── mpl
│ ├── __init__.py
│ ├── pie_chart.py
│ ├── histogram.py
│ ├── image_plot.py
│ ├── bar_plot.py
│ └── base_mpl_plot.py
├── plotly
│ ├── __init__.py
│ ├── embeddings_plot.py
│ └── base_plotly_plot.py
├── embeddings
│ ├── __init__.py
│ └── tsne_utils.py
├── receptive_field
│ ├── __init__.py
│ └── rf_utils.py
├── array_stream.py
├── stream_union.py
├── __init__.py
├── filtered_stream.py
├── data_utils.py
├── pytorch_utils.py
├── zmq_stream.py
├── file_stream.py
├── repeated_timer.py
├── zmq_mgmt_stream.py
├── imagenet_utils.py
├── stream.py
├── tensor_utils.py
├── text_vis.py
├── watcher_client.py
├── notebook_maker.py
├── watcher.py
├── stream_factory.py
├── evaler.py
├── image_utils.py
├── visualizer.py
└── lv_types.py
├── MANIFEST.in
├── docs
├── images
│ ├── tsne.gif
│ ├── fruits.gif
│ ├── teaser.gif
│ ├── saliency.png
│ ├── draw_model.png
│ ├── model_stats.png
│ ├── quick_start.gif
│ ├── teaser_small.gif
│ ├── lazy_log_array_sum.png
│ └── simple_logging
│ │ ├── line_cell.png
│ │ ├── text_cell.png
│ │ ├── line_cell2.png
│ │ ├── plotly_line.png
│ │ └── text_summary.png
├── simple_logging.md
└── lazy_logging.md
├── data
└── test_images
│ ├── cat.jpg
│ ├── dogs.png
│ └── elephant.png
├── test
├── simple_log
│ ├── quick_start.py
│ ├── sum_lazy.py
│ ├── file_expr.py
│ ├── sum_log.py
│ ├── srv_ij.py
│ ├── cli_file_expr.py
│ ├── cli_sum_log.py
│ └── cli_ij.py
├── pre_train
│ ├── draw_model.py
│ └── tsny.py
├── visualizations
│ ├── confidence_int.py
│ ├── plotly_line.py
│ ├── mpl_line.py
│ ├── arr_img_plot.py
│ ├── histogram.py
│ ├── pie_chart.py
│ ├── line3d_plot.py
│ └── bar_plot.py
├── components
│ ├── circ_ref.py
│ ├── evaler.py
│ ├── stream.py
│ ├── notebook_maker.py
│ ├── watcher.py
│ └── file_only_test.py
├── zmq
│ ├── zmq_watcher_server.py
│ ├── zmq_watcher_client.py
│ ├── zmq_stream.py
│ ├── zmq_pub.py
│ └── zmq_sub.py
├── deps
│ ├── panda.py
│ ├── ipython_widget.py
│ ├── thread.py
│ └── live_graph.py
├── post_train
│ └── saliency.py
├── files
│ └── file_stream.py
├── mnist
│ └── cli_mnist.py
└── test.pyproj
├── .pylintrc
├── update_package.bat
├── SUPPORT.md
├── CHANGELOG.md
├── install_jupyterlab.bat
├── CONTRIBUTING.md
├── tensorwatch.sln
├── setup.py
├── LICENSE.TXT
├── TODO.md
├── NOTICE.md
└── notebooks
└── data_exploration.ipynb
/tensorwatch/saliency/lime/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tensorwatch/saliency/lime/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tensorwatch/model_graph/hiddenlayer/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include tensorwatch/*.txt
2 | include tensorwatch/*.json
--------------------------------------------------------------------------------
/docs/images/tsne.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/tsne.gif
--------------------------------------------------------------------------------
/docs/images/fruits.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/fruits.gif
--------------------------------------------------------------------------------
/docs/images/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/teaser.gif
--------------------------------------------------------------------------------
/data/test_images/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/data/test_images/cat.jpg
--------------------------------------------------------------------------------
/data/test_images/dogs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/data/test_images/dogs.png
--------------------------------------------------------------------------------
/docs/images/saliency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/saliency.png
--------------------------------------------------------------------------------
/docs/images/draw_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/draw_model.png
--------------------------------------------------------------------------------
/docs/images/model_stats.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/model_stats.png
--------------------------------------------------------------------------------
/docs/images/quick_start.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/quick_start.gif
--------------------------------------------------------------------------------
/docs/images/teaser_small.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/teaser_small.gif
--------------------------------------------------------------------------------
/data/test_images/elephant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/data/test_images/elephant.png
--------------------------------------------------------------------------------
/tensorwatch/mpl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
5 |
--------------------------------------------------------------------------------
/tensorwatch/plotly/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
--------------------------------------------------------------------------------
/tensorwatch/saliency/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
--------------------------------------------------------------------------------
/tensorwatch/model_graph/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
--------------------------------------------------------------------------------
/docs/images/lazy_log_array_sum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/lazy_log_array_sum.png
--------------------------------------------------------------------------------
/tensorwatch/embeddings/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
5 |
--------------------------------------------------------------------------------
/tensorwatch/receptive_field/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 |
--------------------------------------------------------------------------------
/docs/images/simple_logging/line_cell.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/simple_logging/line_cell.png
--------------------------------------------------------------------------------
/docs/images/simple_logging/text_cell.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/simple_logging/text_cell.png
--------------------------------------------------------------------------------
/docs/images/simple_logging/line_cell2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/simple_logging/line_cell2.png
--------------------------------------------------------------------------------
/docs/images/simple_logging/plotly_line.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/simple_logging/plotly_line.png
--------------------------------------------------------------------------------
/docs/images/simple_logging/text_summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/simple_logging/text_summary.png
--------------------------------------------------------------------------------
/tensorwatch/model_graph/hiddenlayer/README.md:
--------------------------------------------------------------------------------
1 | # Credits
2 |
3 | Code in this folder is adapted from:
4 |
5 | * https://github.com/waleedka/hiddenlayer
6 |
--------------------------------------------------------------------------------
/tensorwatch/saliency/README.md:
--------------------------------------------------------------------------------
1 | # Credits
2 | Code in this folder is adapted from:
3 |
4 | * https://github.com/yulongwang12/visual-attribution
5 | * https://github.com/marcotcr/lime
6 |
--------------------------------------------------------------------------------
/test/simple_log/quick_start.py:
--------------------------------------------------------------------------------
"""Quick-start demo: log (i, i*i) pairs to the stream 'my_metric' in 'test.log', one per second."""
import time

import tensorwatch as tw

watcher = tw.Watcher(filename='test.log')
metric_stream = watcher.create_stream(name='my_metric')
# optional: watcher.make_notebook()

for step in range(1000):
    # each record is an (x, y) pair
    metric_stream.write((step, step * step))
    time.sleep(1)
--------------------------------------------------------------------------------
/test/pre_train/draw_model.py:
--------------------------------------------------------------------------------
"""Draw the graph of torchvision's VGG16 with tensorwatch and save it as 'abc'."""
import torch
import torchvision.models

import tensorwatch as tw

model = torchvision.models.vgg16()
# presumably (N, C, H, W), as torchvision image models expect
input_shape = [1, 3, 224, 224]

graph_drawing = tw.draw_model(model, input_shape)
graph_drawing.save('abc')

# block until the user responds before exiting
input("Press any key")
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
[MASTER]
# Directories (base names) pylint should skip.
ignore=model_graph,receptive_field,saliency

[TYPECHECK]
# Modules whose members pylint should not try to check.
ignored-modules=numpy,torch,matplotlib,pyplot,zmq

[MESSAGES CONTROL]
# Individual checks, plus the whole C (convention) and R (refactor) categories.
disable=protected-access,broad-except,global-statement,fixme,C,R
--------------------------------------------------------------------------------
/update_package.bat:
--------------------------------------------------------------------------------
REM Build a source distribution and upload everything in dist\ to PyPI.

REM PAUSE ignores any text after it, so print the reminder with ECHO first.
REM The '@' prefix stops these two lines from being echoed themselves.
REM Press Ctrl+C at the prompt to abort.
@ECHO Make sure to increment version in setup.py. Continue?
@PAUSE

python setup.py sdist
REM If the build failed, do not upload whatever is left over in dist\.
IF ERRORLEVEL 1 EXIT /B 1
twine upload --repository-url https://upload.pypi.org/legacy/ dist/*
REM pip install tensorwatch --upgrade
REM pip show tensorwatch

REM pip install yolk3k
REM yolk -V tensorwatch
--------------------------------------------------------------------------------
/test/visualizations/confidence_int.py:
--------------------------------------------------------------------------------
"""Line plot with a band: y = x*x, bounded below by x*x - x and above by x*x + x."""
import tensorwatch as tw

watcher = tw.Watcher()
stream = watcher.create_stream()

line_vis = tw.Visualizer(stream, vis_type='line')
line_vis.show()

for x in map(float, range(10)):
    y = x * x
    stream.write(tw.PointData(x, y, low=y - x, high=y + x))

tw.plt_loop()
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # Support
2 |
3 | We highly recommend taking a look at the source code and contributing to the project. Please also consider [contributing](CONTRIBUTING.md) new features and fixes :).
4 |
5 | * [Join TensorWatch Facebook Group](https://www.facebook.com/groups/378075159472803/)
6 | * [File GitHub Issue](https://github.com/Microsoft/tensorwatch/issues)
--------------------------------------------------------------------------------
/test/components/circ_ref.py:
--------------------------------------------------------------------------------
"""Check whether a deleted WatcherClient is still referenced after garbage collection.

Needs objgraph (pip install objgraph); the back-reference graph is written to b.png.
"""
import gc
import time

import objgraph  # pip install objgraph
import tensorwatch as tw

client = tw.WatcherClient()
time.sleep(10)
del client

# force a full collection so only objects that are still referenced survive
gc.collect()
time.sleep(2)

survivors = objgraph.by_type('WatcherClient')
objgraph.show_backrefs(survivors, refcounts=True, filename='b.png')
--------------------------------------------------------------------------------
/test/components/evaler.py:
--------------------------------------------------------------------------------
"""Post integers to an Evaler and print what its expression returns after each post."""
from tensorwatch import evaler

# Sum of the squares of the even items of `l`
# (presumably `l` holds the values posted so far -- confirm in evaler.py).
EXPR = 'reduce(lambda x,y: (x+y), map(lambda x:(x**2), filter(lambda x: x%2==0, l)))'

ev = evaler.Evaler(expr=EXPR)
for i in range(5):
    print(i, ev.post(i))

# signal end of input; `i` still holds the last posted value here
print(i, ev.post(ended=True))
--------------------------------------------------------------------------------
/test/zmq/zmq_watcher_server.py:
--------------------------------------------------------------------------------
"""Server side of the ZMQ watcher test: observe x = 0..4999, one value per second."""
from tensorwatch.watcher import Watcher
import time

from tensorwatch import utils

utils.set_debug_verbosity(10)


def main():
    watcher = Watcher()
    for counter in range(5000):
        watcher.observe(x=counter)
        time.sleep(1)


main()
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # What's new
2 |
3 | Below is a summarized list of important changes. It does not include minor changes, bug fixes, or documentation updates. This list is updated every few months. For complete detailed changes, please review the [commit history](https://github.com/Microsoft/tensorwatch/commits/master).
4 |
5 | ### May, 2019
6 | * First release!
--------------------------------------------------------------------------------
/test/deps/panda.py:
--------------------------------------------------------------------------------
"""Scratch check: build a pandas DataFrame from plain objects and render it as HTML."""
import pandas as pd


class SomeThing:
    """Plain record whose instance attributes become DataFrame columns."""

    def __init__(self, x, y):
        self.x, self.y = x, y


def styler_html(df):
    """Return the HTML produced by ``df.style`` on old and new pandas alike.

    ``Styler.render()`` was deprecated in pandas 1.4 and removed in 2.0 in
    favour of ``Styler.to_html()`` (added in 1.3), so prefer ``to_html`` and
    fall back to ``render`` only when it is unavailable.
    """
    styler = df.style
    to_html = getattr(styler, "to_html", None)
    return to_html() if to_html is not None else styler.render()


def main():
    things = [SomeThing(1, 2), SomeThing(3, 4), SomeThing(4, 5)]
    df = pd.DataFrame([t.__dict__ for t in things])

    # last row as a plain dict
    print(df.iloc[[-1]].to_dict('records')[0])
    #print(df.to_html())
    print(styler_html(df))


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/test/components/stream.py:
--------------------------------------------------------------------------------
"""Chain three console-debug streams and write a message into each one."""
from tensorwatch.stream import Stream

s1, s2, s3 = (Stream(stream_name=name, console_debug=True) for name in ('s1', 's2', 's3'))

# s1 subscribes to s2, and s2 subscribes to s3
s1.subscribe(s2)
s2.subscribe(s3)

for source, text in ((s3, 'S3 wrote this'), (s2, 'S2 wrote this'), (s1, 'S1 wrote this')):
    source.write(text)
--------------------------------------------------------------------------------
/test/zmq/zmq_watcher_client.py:
--------------------------------------------------------------------------------
"""Client side of the ZMQ watcher test: request a stream of x**2 and print what arrives."""
from tensorwatch.watcher_client import WatcherClient
import time

from tensorwatch import utils

utils.set_debug_verbosity(10)


def main():
    client = WatcherClient()
    squares = client.create_stream(expr='lambda vars:vars.x**2')
    squares.console_debug = True
    # block so the process stays alive while values come in
    input('pause')


main()
--------------------------------------------------------------------------------
/test/components/notebook_maker.py:
--------------------------------------------------------------------------------
"""Create a few differently configured streams and generate a notebook for the watcher."""
import tensorwatch as tw


def main():
    watcher = tw.Watcher()
    default_stream = watcher.create_stream()
    accuracy_args = tw.VisArgs(vis_type='line', xtitle='X-Axis',
                               clear_after_each=False, history_len=2)
    accuracy_stream = watcher.create_stream(name='accuracy', vis_args=accuracy_args)
    loss_stream = watcher.create_stream(name='loss', expr='lambda d:d.loss')
    watcher.make_notebook()


main()
--------------------------------------------------------------------------------
/test/pre_train/tsny.py:
--------------------------------------------------------------------------------
"""Compute t-SNE components for a small per-class MNIST sample and plot them."""
import tensorwatch as tw

# explicit import instead of `from regim import *`: DataUtils is the only name used
from regim import DataUtils

ds = DataUtils.mnist_datasets(linearize=True, train_test=False)
ds = DataUtils.sample_by_class(ds, k=5, shuffle=True, as_np=True, no_test=True)

components = tw.get_tsne_components(ds)
print(components)

# presumably ds[0] holds the flattened images; hover previews reshape them to 28x28
plot = tw.Visualizer(components, hover_images=ds[0], hover_image_reshape=(28, 28),
                     vis_type='tsne')
plot.show()
--------------------------------------------------------------------------------
/test/simple_log/sum_lazy.py:
--------------------------------------------------------------------------------
"""Lazy-logging demo: hand a fresh list of random weights to the watcher every second."""
import random
import time

import tensorwatch as tw

watcher = tw.Watcher()

weights = None
for _step in range(10000):
    weights = [random.random() for _ in range(5)]

    # Only observe the variable; per the author's note this costs almost nothing.
    watcher.observe(weights=weights)

    time.sleep(1)
--------------------------------------------------------------------------------
/test/deps/ipython_widget.py:
--------------------------------------------------------------------------------
"""Report whether we run under IPython and push text into an ipywidgets HTML widget."""
import ipywidgets as widgets
from IPython import get_ipython
from IPython.display import display


class PrinterX:
    """Minimal text sink backed by an ipywidgets HTML widget."""

    def __init__(self):
        self.w = widgets.HTML()

    def show(self):
        """Return the underlying widget so the caller can display it."""
        return self.w

    def write(self, s):
        """Replace the widget's content with the string *s*."""
        self.w.value = s


print("Running from within ipython?", get_ipython() is not None)
p = PrinterX()
# show() only returns the widget. It must go through display(), because it is
# not the last expression of a cell and would otherwise never be rendered.
display(p.show())
p.write('ffffffffff')
--------------------------------------------------------------------------------
/test/components/watcher.py:
--------------------------------------------------------------------------------
"""Feed a watcher expression stream (x**2) into a console-debug stream."""
from tensorwatch.watcher_base import WatcherBase
from tensorwatch.stream import Stream


def main():
    watcher = WatcherBase()
    printer = Stream(stream_name='S1', console_debug=True)
    squares = watcher.create_stream(expr='lambda vars:vars.x**2')
    printer.subscribe(squares)

    for value in range(5):
        watcher.observe(x=value)


main()
--------------------------------------------------------------------------------
/test/simple_log/file_expr.py:
--------------------------------------------------------------------------------
"""Log running integer sums to a file through two expression streams, one per second."""
import time

import tensorwatch as tw
from tensorwatch import utils

utils.set_debug_verbosity(4)

srv = tw.Watcher(filename=r'c:\temp\sum.log')
# each expression maps the observed variables to an (x, y) point
s1 = srv.create_stream('sum', expr='lambda v:(v.i, v.sum)')
s2 = srv.create_stream('sum_2', expr='lambda v:(v.i, v.sum/2)')

# The local is called `total` so it does not shadow the builtin sum().
# The observed variable keeps the name `sum`, because the expressions above read v.sum.
total = 0
for i in range(10000):
    total += i
    srv.observe(i=i, sum=total)
    time.sleep(1)
--------------------------------------------------------------------------------
/test/zmq/zmq_stream.py:
--------------------------------------------------------------------------------
"""Publish the squares of observed values through a writable ZmqStream with console debug on."""
from tensorwatch.watcher_base import WatcherBase
from tensorwatch.zmq_stream import ZmqStream


def main():
    watcher = WatcherBase()
    squares = watcher.create_stream(expr='lambda vars:vars.x**2')

    publisher = ZmqStream(for_write=True, stream_name='ZmqPub', console_debug=True)
    publisher.subscribe(squares)

    for value in range(5):
        watcher.observe(x=value)
    # keep the process (and the publisher) alive until the user presses Enter
    input('paused')


main()
--------------------------------------------------------------------------------
/test/simple_log/sum_log.py:
--------------------------------------------------------------------------------
"""Write two streams once per second: running sums of integers and of random integers."""
import random
import time

import tensorwatch as tw

w = tw.Watcher()
st_isum = w.create_stream('isums')
st_rsum = w.create_stream('rsums')

isum = rsum = 0
for i in range(10000):
    isum += i
    rsum += random.randint(0, i)

    # each stream receives an (index, running total) pair
    for stream, total in ((st_isum, isum), (st_rsum, rsum)):
        stream.write((i, total))

    time.sleep(1)
--------------------------------------------------------------------------------
/install_jupyterlab.bat:
--------------------------------------------------------------------------------
REM Install JupyterLab plus the widget and plotly extensions.
REM CALL is required for conda: on Windows it is normally a batch script
REM (condabin\conda.bat), and running a batch script without CALL never returns
REM to this file, so the remaining lines would be skipped. CALL is harmless for
REM executables, so it is used for jupyter as well.
call conda install -c conda-forge jupyterlab nodejs
call conda install ipywidgets
call conda install -c plotly plotly-orca psutil

REM Raise Node's heap limit for the lab build, then clear the variable again.
set NODE_OPTIONS=--max-old-space-size=4096
call jupyter labextension install @jupyter-widgets/jupyterlab-manager --no-build
call jupyter labextension install plotlywidget --no-build
call jupyter labextension install @jupyterlab/plotly-extension --no-build
call jupyter labextension install jupyterlab-chart-editor --no-build
call jupyter lab build
set NODE_OPTIONS=
--------------------------------------------------------------------------------
/test/zmq/zmq_pub.py:
--------------------------------------------------------------------------------
"""Publish {'a': n} on 'Topic1' every second, with a client/server endpoint that prints messages."""
import tensorwatch as tw
import time
from tensorwatch.zmq_wrapper import ZmqWrapper
from tensorwatch import utils

utils.set_debug_verbosity(10)


def clisrv_callback(clisrv, msg):
    print('from clisrv', msg)


stream = ZmqWrapper.Publication(port=40859)
# kept in a module-level name (presumably so the endpoint stays alive while publishing)
clisrv = ZmqWrapper.ClientServer(40860, True, clisrv_callback)

for n in range(10000):
    stream.send_obj({'a': n}, "Topic1")
    print("sent ", n)
    time.sleep(1)
--------------------------------------------------------------------------------
/test/simple_log/srv_ij.py:
--------------------------------------------------------------------------------
"""Emit nested events forever: after each ev_i sample, send a short burst of ev_j samples."""
import tensorwatch as tw
import time
import random
from tensorwatch import utils

utils.set_debug_verbosity(4)

srv = tw.Watcher()

while True:
    for i in range(1000):
        srv.observe("ev_i", val=i * random.random(), x=i)
        print('sent ev_i ', i)
        time.sleep(1)

        # inner burst of five ev_j samples, closed with end_event
        for j in range(5):
            srv.observe("ev_j", x=j, val=j * random.random())
            print('sent ev_j ', j)
            time.sleep(0.5)
        srv.end_event("ev_j")
    srv.end_event("ev_i")
--------------------------------------------------------------------------------
/test/visualizations/plotly_line.py:
--------------------------------------------------------------------------------
"""Feed (x, x*x) points from a watcher expression stream into a plotly LinePlot."""
from tensorwatch.watcher_base import WatcherBase
from tensorwatch.plotly.line_plot import LinePlot
from tensorwatch.image_utils import plt_loop
from tensorwatch.stream import Stream
from tensorwatch.lv_types import StreamItem


def main():
    watcher = WatcherBase()
    plot = LinePlot()
    points = watcher.create_stream(expr='lambda vars:vars.x')
    plot.subscribe(points)
    plot.show()

    for x in range(5):
        # the expression passes x straight through, so observe the point itself
        watcher.observe(x=(x, x * x))


main()
--------------------------------------------------------------------------------
/test/visualizations/mpl_line.py:
--------------------------------------------------------------------------------
"""Feed (x, x*x) points from a watcher expression stream into a matplotlib LinePlot."""
from tensorwatch.watcher_base import WatcherBase
from tensorwatch.mpl.line_plot import LinePlot
from tensorwatch.image_utils import plt_loop
from tensorwatch.stream import Stream
from tensorwatch.lv_types import StreamItem


def main():
    watcher = WatcherBase()
    plot = LinePlot()
    points = watcher.create_stream(expr='lambda vars:vars.x')
    plot.subscribe(points)
    plot.show()

    for x in range(5):
        # the expression passes x straight through, so observe the point itself
        watcher.observe(x=(x, x * x))
    plt_loop()


main()
--------------------------------------------------------------------------------
/test/zmq/zmq_sub.py:
--------------------------------------------------------------------------------
"""Subscribe to 'Topic1' and send two messages over a client/server endpoint."""
import tensorwatch as tw
import time
from tensorwatch.zmq_wrapper import ZmqWrapper
from tensorwatch import utils


class A:
    """Trivial event sink that prints every object it receives."""

    def on_event(self, obj):
        print(obj)


a = A()

utils.set_debug_verbosity(10)
sub = ZmqWrapper.Subscription(40859, "Topic1", a.on_event)
print("subscriber is waiting")

clisrv = ZmqWrapper.ClientServer(40860, False)
clisrv.send_obj("hello 1")
print('sleeping..')
time.sleep(10)
clisrv.send_obj("hello 2")

print('waiting for key..')
utils.wait_key()
--------------------------------------------------------------------------------
/test/visualizations/arr_img_plot.py:
--------------------------------------------------------------------------------
"""Show the first five fruits-360 and first five MNIST images from an ArrayStream."""
import tensorwatch as tw
import numpy as np
import time
import torchvision.datasets as datasets

fruits_ds = datasets.ImageFolder(r'D:\datasets\fruits-360\Training')
mnist_ds = datasets.MNIST('../data', train=True, download=True)

# first five samples of each dataset, titled by their index
images = []
for dataset in (fruits_ds, mnist_ds):
    images.extend(tw.ImageData(dataset[idx][0], title=str(idx)) for idx in range(5))

stream = tw.ArrayStream(images)

img_plot = tw.Visualizer(stream, vis_type='image', viz_img_scale=3)
img_plot.show()

tw.image_utils.plt_loop()
--------------------------------------------------------------------------------
/test/post_train/saliency.py:
--------------------------------------------------------------------------------
"""Compute and plot saliency maps for a ResNet-50 prediction on a test image."""
import os

from tensorwatch.saliency import saliency
from tensorwatch import image_utils, imagenet_utils, pytorch_utils

# Resolve test images relative to this file (<repo>/data/test_images) so the
# script works from any working directory.
_TEST_IMAGES = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            '..', '..', 'data', 'test_images')

model = pytorch_utils.get_model('resnet50')
# `input_tensor` rather than `input`, which shadowed the builtin.
# 240 is presumably the target ImageNet class index; alternative: 'elephant.png', 101.
raw_input, input_tensor, target_class = pytorch_utils.image_class2tensor(
    os.path.join(_TEST_IMAGES, 'dogs.png'), 240,
    image_transform=imagenet_utils.get_image_transform(), image_convert_mode='RGB')

results = saliency.get_image_saliency_results(model, raw_input, input_tensor, target_class)
figure = saliency.get_image_saliency_plot(results)

image_utils.plt_loop()
--------------------------------------------------------------------------------
/test/deps/thread.py:
--------------------------------------------------------------------------------
"""Show that a main thread waiting on a worker in short join() slices still receives Ctrl+C."""
import threading
import time
import sys


def handler(a, b=None):
    """Console control handler: exit the process with status 1."""
    sys.exit(1)


def install_handler():
    """Register handler() with win32api.SetConsoleCtrlHandler on Windows consoles.

    Does nothing on other platforms, or when stdin is missing or not a TTY.
    """
    if sys.platform == "win32":
        if sys.stdin is not None and sys.stdin.isatty():
            # this is a console-based application
            import win32api
            win32api.SetConsoleCtrlHandler(handler, True)


def work():
    time.sleep(10000)


def wait_interruptibly(thread, poll_interval=0.1):
    """Join *thread* in *poll_interval*-second slices until it has finished.

    Short slices give the main thread regular chances to raise KeyboardInterrupt
    (100 ms is roughly human reaction time). The loop stops once the thread is
    done, instead of spinning on join() of a dead thread forever.
    """
    while thread.is_alive():
        thread.join(poll_interval)


if __name__ == "__main__":
    t = threading.Thread(target=work, name='ThreadTest')
    t.daemon = True  # a daemon thread does not keep the process alive on exit
    t.start()
    # press Ctrl+C while waiting and you will get a KeyboardInterrupt exception
    wait_interruptibly(t)
--------------------------------------------------------------------------------
/tensorwatch/array_stream.py:
--------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .stream import Stream
from .lv_types import StreamItem
import uuid


class ArrayStream(Stream):
    """Stream that writes a fixed in-memory object when it is loaded."""

    def __init__(self, array, stream_name: str = None, console_debug: bool = False):
        super().__init__(stream_name=stream_name, console_debug=console_debug)

        self.stream_name = stream_name
        # written as a single item by load(); None means nothing to write
        self.array = array

    def load(self, from_stream: 'Stream' = None):
        """Write the stored array (if any) in one write() call, then run Stream.load().

        ``from_stream`` is accepted for signature compatibility; it is neither
        used here nor forwarded to the base class.
        """
        array = self.array
        if array is not None:
            self.write(array)
        super().load()
--------------------------------------------------------------------------------
/test/visualizations/histogram.py:
--------------------------------------------------------------------------------
"""Histogram demos: a static histogram, and one redrawn from a fresh batch each frame."""
import tensorwatch as tw
import random, time


def _make_histogram(**extra_args):
    """Create a watcher, a stream and a shown 6-bin histogram; return all three.

    All three objects are returned so the caller keeps them alive.
    """
    w = tw.Watcher()
    s = w.create_stream()
    v = tw.Visualizer(s, vis_type='histogram', bins=6, **extra_args)
    v.show()
    return w, s, v


def static_hist():
    watcher, stream, vis = _make_histogram()
    for _ in range(100):
        stream.write(random.random() * 10)
    tw.plt_loop()


def dynamic_hist():
    watcher, stream, vis = _make_histogram(clear_after_each=True)
    for _ in range(100):
        # each write is a whole batch of 100 samples
        stream.write([random.random() * 10 for _ in range(100)])
        tw.plt_loop(count=3)


dynamic_hist()
--------------------------------------------------------------------------------
/test/visualizations/pie_chart.py:
--------------------------------------------------------------------------------
"""Pie chart demos: a static six-slice chart, and one redrawn from fresh data each frame."""
import tensorwatch as tw
import random, time


def _make_pie(**extra_args):
    """Create a watcher, a stream and a shown pie visualizer; return all three.

    All three objects are returned so the caller keeps them alive.
    """
    w = tw.Watcher()
    s = w.create_stream()
    v = tw.Visualizer(s, vis_type='pie', bins=6, **extra_args)
    v.show()
    return w, s, v


def static_pie():
    watcher, stream, vis = _make_pie()
    for idx in range(6):
        # (label, value, None, offset); presumably the offset pulls slice 3 out -- confirm
        offset = 0.5 if idx == 3 else 0
        stream.write(('label' + str(idx), random.random() * 10, None, offset))
    tw.plt_loop()


def dynamic_pie():
    watcher, stream, vis = _make_pie(clear_after_each=True)
    for _ in range(100):
        slices = [('label' + str(idx), random.random() * 10, None, idx * 0.01)
                  for idx in range(12)]
        stream.write(slices)
        tw.plt_loop(count=3)


#static_pie()
dynamic_pie()
--------------------------------------------------------------------------------
/test/visualizations/line3d_plot.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import random, time
3 |
4 | # TODO: resolve problem with Axis3D?
5 |
6 | def static_line3d():
7 | w = tw.Watcher()
8 | s = w.create_stream()
9 |
10 | v = tw.Visualizer(s, vis_type='line3d')
11 | v.show()
12 |
13 | for i in range(10):
14 | s.write((i, i*i, int(random.random()*10)))
15 |
16 | tw.plt_loop()
17 |
18 | def dynamic_line3d():
19 | w = tw.Watcher()
20 | s = w.create_stream()
21 |
22 | v = tw.Visualizer(s, vis_type='line3d', clear_after_each=True)
23 | v.show()
24 |
25 | for i in range(100):
26 | s.write([(i, random.random()*10, z) for i in range(10) for z in range(10)])
27 | tw.plt_loop(count=3)
28 |
# Run the demo only when executed as a script, so importing this module does
# not open a window and block in the plot loop.
if __name__ == '__main__':
    static_line3d()
    #dynamic_line3d()
31 |
32 |
--------------------------------------------------------------------------------
/test/files/file_stream.py:
--------------------------------------------------------------------------------
1 | from tensorwatch.watcher_base import WatcherBase
2 | from tensorwatch.stream import Stream
3 | from tensorwatch.file_stream import FileStream
4 | from tensorwatch import LinePlot
5 | from tensorwatch.image_utils import plt_loop
6 | import tensorwatch as tw
7 |
def file_write():
    """Create a stream with expression (x, x**2) on device obs.txt and observe x = 0..4."""
    watcher = WatcherBase()
    # Bound to a name only so the stream stays referenced while observing.
    # NOTE(review): neither the stream nor the watcher is closed here, so data may
    # still be buffered when file_read() opens the same file - confirm.
    stream = watcher.create_stream(expr='lambda vars:(vars.x, vars.x**2)',  # pylint: disable=unused-variable
                                   devices=[r'c:\temp\obs.txt'])

    for x_val in range(5):
        watcher.observe(x=x_val)
15 |
def file_read():
    """Open the stream stored on device obs.txt and show it in an 'mpl-line' visualizer."""
    watcher = WatcherBase()
    recorded = watcher.open_stream(devices=[r'c:\temp\obs.txt'])

    line_vis = tw.Visualizer(recorded, vis_type='mpl-line')
    line_vis.show()

    plt_loop()
22 |
def main():
    """Write the demo stream to disk, then read it back and plot it."""
    for step in (file_write, file_read):
        step()
26 |
# Run only as a script so importing this module does not write files or block
# in the plot loop.
if __name__ == '__main__':
    main()
28 |
29 |
--------------------------------------------------------------------------------
/test/simple_log/cli_file_expr.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | from tensorwatch import utils
# Set tensorwatch's debug-log verbosity to 4 for this script (utils.debug_log
# takes a verbosity level; presumably messages at or below 4 are printed - TODO confirm).
utils.set_debug_verbosity(4)


# Commented-out alternative visualizer setups, kept for reference (not executed):
#r = tw.Visualizer(vis_type='mpl-line')
#r.show()
#r2=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', cell=r.cell)
#r3=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', renderer=r2)
10 |
def show_mpl():
    """Load the 'sum' stream from c:/temp/sum.log into a LinePlot and run the plot loop."""
    client = tw.WatcherClient(r'c:\temp\sum.log')
    sum_stream = client.open_stream('sum')

    plot = tw.LinePlot(title='Demo')
    plot.subscribe(sum_stream, xtitle='Index', ytitle='sqrt(ev_i)')

    sum_stream.load()
    plot.show()

    tw.plt_loop()
20 |
def show_text():
    """Show the 'sum_2' stream from c:/temp/sum.log in a default visualizer and wait for Enter."""
    client = tw.WatcherClient(r'c:\temp\sum.log')
    text_vis = tw.Visualizer(client.open_stream('sum_2'))
    text_vis.show()
    input('Waiting')
27 |
# Run only as a script so importing this module does not connect to a watcher
# or block in the plot loop.
if __name__ == '__main__':
    #show_text()
    show_mpl()
30 |
--------------------------------------------------------------------------------
/test/simple_log/cli_sum_log.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | from tensorwatch import utils
# Set tensorwatch's debug-log verbosity to 4 for this script (utils.debug_log
# takes a verbosity level; presumably messages at or below 4 are printed - TODO confirm).
utils.set_debug_verbosity(4)


# Commented-out alternative visualizer setups, kept for reference (not executed):
#r = tw.Visualizer(vis_type='mpl-line')
#r.show()
#r2=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', cell=r.cell)
#r3=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', renderer=r2)
10 |
def show_mpl():
    """Plot the 'isums' and 'rsums' streams from the default WatcherClient as lines."""
    client = tw.WatcherClient()
    isum_stream = client.open_stream('isums')
    rsum_stream = client.open_stream('rsums')

    isum_plot = tw.Visualizer(isum_stream, vis_type='line', xtitle='i', ytitle='isum')
    isum_plot.show()

    # Created with host=isum_plot; bound to a name so it stays referenced
    # while the plot loop runs.
    rsum_plot = tw.Visualizer(rsum_stream, vis_type='line', host=isum_plot, ytitle='rsum')  # pylint: disable=unused-variable

    tw.plt_loop()
22 |
def show_text():
    """Show the 'isums' stream from the default WatcherClient as text and wait for Enter."""
    cli = tw.WatcherClient()
    # The stream must be opened here; it is local to show_mpl() otherwise and
    # referencing it from this function raised NameError.
    st_isum = cli.open_stream('isums')
    text_vis = tw.Visualizer(st_isum, vis_type='text')
    text_vis.show()
    input('Waiting')
28 |
# Run only as a script so importing this module does not connect to a watcher
# or block in the plot loop.
if __name__ == '__main__':
    #show_text()
    show_mpl()
--------------------------------------------------------------------------------
/tensorwatch/stream_union.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .stream import Stream
5 | from typing import Iterator
6 |
class StreamUnion(Stream):
    """Stream that joins a group of child streams.

    With ``for_write=True`` every child subscribes to this stream, so anything
    written here is passed on to all children. With ``for_write=False`` this
    stream subscribes to every child, forming the union of their outputs.
    """
    def __init__(self, child_streams:Iterator[Stream], for_write:bool, stream_name:str=None, console_debug:bool=False) -> None:
        super(StreamUnion, self).__init__(stream_name=stream_name, console_debug=console_debug)

        # Keep references so child streams go away only when the parent does.
        # Materialize once: the annotation allows a one-shot iterator, which the
        # subscription loop below would otherwise exhaust, leaving
        # self.child_streams empty.
        self.child_streams = list(child_streams)

        for child_stream in self.child_streams:
            if for_write:
                # writes to us are forwarded to every child (our listeners)
                child_stream.subscribe(self)
            else:
                # we listen to every child
                self.subscribe(child_stream)
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | This project welcomes contributions and suggestions. Most contributions require you to
4 | agree to a Contributor License Agreement (CLA) declaring that you have the right to,
5 | and actually do, grant us the rights to use your contribution. For details, visit
6 | https://cla.microsoft.com.
7 |
8 | When you submit a pull request, a CLA-bot will automatically determine whether you need
9 | to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
10 | instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
11 |
12 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
13 | For more information, see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
14 | or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
15 |
--------------------------------------------------------------------------------
/tensorwatch.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio 15
4 | VisualStudioVersion = 15.0.27428.2043
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "tensorwatch", "tensorwatch.pyproj", "{CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}"
7 | EndProject
8 | Global
9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
10 | Debug|Any CPU = Debug|Any CPU
11 | Release|Any CPU = Release|Any CPU
12 | EndGlobalSection
13 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
14 | {CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
15 | {CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}.Release|Any CPU.ActiveCfg = Release|Any CPU
16 | EndGlobalSection
17 | GlobalSection(SolutionProperties) = preSolution
18 | HideSolutionNode = FALSE
19 | EndGlobalSection
20 | GlobalSection(ExtensibilityGlobals) = postSolution
21 | SolutionGuid = {99E7AEC7-2CDE-48C8-B98B-4E28E4F840B6}
22 | EndGlobalSection
23 | EndGlobal
24 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import setuptools
5 |
# Read the README as UTF-8 explicitly so the build does not depend on the
# platform's default encoding (e.g. cp1252 on Windows).
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setuptools.setup(
    name="tensorwatch",
    version="0.8.5",
    author="Shital Shah",
    author_email="shitals@microsoft.com",
    description="Interactive Realtime Debugging and Visualization for AI",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/microsoft/tensorwatch",
    packages=setuptools.find_packages(),
    license='MIT',
    # setuptools expects a list here; a tuple only triggers a warning.
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    include_package_data=True,
    install_requires=[
        'matplotlib', 'numpy', 'pyzmq', 'plotly', 'torchstat', 'ipywidgets',
        # 'sklearn' is a deprecated PyPI placeholder that refuses to install;
        # the real distribution is 'scikit-learn' (still imported as sklearn).
        'scikit-learn',
        'nbformat', 'scikit-image',  # , 'receptivefield'
    ],
)
--------------------------------------------------------------------------------
/test/visualizations/bar_plot.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import random, time
3 |
def static_bar():
    """Write ten random ints (0-9) to a 'bar' visualizer, then run the plot loop."""
    watcher = tw.Watcher()
    stream = watcher.create_stream()

    bars = tw.Visualizer(stream, vis_type='bar')
    bars.show()

    for _ in range(10):
        stream.write(int(random.random() * 10))

    tw.plt_loop()
15 |
16 |
def dynamic_bar():
    """Write 100 batches of ten labelled random bars ('a0'..'a9') to a 'bar' visualizer."""
    watcher = tw.Watcher()
    stream = watcher.create_stream()

    bars = tw.Visualizer(stream, vis_type='bar', clear_after_each=True)
    bars.show()

    for _frame in range(100):
        stream.write([('a' + str(k), random.random() * 10) for k in range(10)])
        tw.plt_loop(count=3)
27 |
def dynamic_bar3d():
    """Write 100 batches of a 10x10 grid of (x, random y, z) bars to a 'bar3d' visualizer."""
    watcher = tw.Watcher()
    stream = watcher.create_stream()

    bars = tw.Visualizer(stream, vis_type='bar3d', clear_after_each=True)
    bars.show()

    for _frame in range(100):
        grid = [(x, random.random() * 10, z) for x in range(10) for z in range(10)]
        stream.write(grid)
        tw.plt_loop(count=3)
38 |
# Run the demo only when executed as a script, so importing this module does
# not open a window and block in the plot loop.
if __name__ == '__main__':
    static_bar()
    #dynamic_bar()
    #dynamic_bar3d()
42 |
--------------------------------------------------------------------------------
/test/deps/live_graph.py:
--------------------------------------------------------------------------------
1 | from matplotlib import pyplot as plt
2 | from matplotlib.animation import FuncAnimation
3 | from random import randrange
4 | from threading import Thread
5 | import time
6 |
class LiveGraph:
    """Matplotlib line plot fed by a background thread with random points.

    A daemon thread appends one (x, random y) point per second while a
    FuncAnimation redraws the line every 1000 ms.
    """
    def __init__(self):
        self.x_data, self.y_data = [], []
        self.figure = plt.figure()
        self.line, = plt.plot(self.x_data, self.y_data)
        self.animation = FuncAnimation(self.figure, self.update, interval=1000)
        self.th = Thread(target=self.thread_f, name='LiveGraph', daemon=True)
        self.th.start()

    def update(self, frame):
        """Animation callback: push the current data to the line and rescale the axes."""
        # The producer thread appends to x_data and y_data in two separate steps,
        # so the lists can briefly differ in length. Hand matplotlib equal-length
        # copies (slicing copies) so a redraw never sees mismatched or
        # still-growing data.
        n = min(len(self.x_data), len(self.y_data))
        self.line.set_data(self.x_data[:n], self.y_data[:n])
        axes = self.figure.gca()
        axes.relim()
        axes.autoscale_view()
        return self.line,

    def show(self):
        plt.show()

    def thread_f(self):
        """Producer loop: append (x, randrange(0, 100)) once per second, forever."""
        x = 0
        while True:
            self.x_data.append(x)
            x += 1
            self.y_data.append(randrange(0, 100))
            time.sleep(1)
--------------------------------------------------------------------------------
/LICENSE.TXT:
--------------------------------------------------------------------------------
1 | Copyright (c) Microsoft Corporation. All rights reserved.
2 |
3 | MIT License
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/tensorwatch/receptive_field/rf_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from receptivefield.pytorch import PytorchReceptiveField
5 | #from receptivefield.image import get_default_image
6 | import numpy as np
7 |
def _get_rf(model, sample_pil_img):
    """Build a PytorchReceptiveField for ``model`` and compute its parameters.

    The input shape given to ``compute`` is the shape of ``sample_pil_img``
    after conversion to a numpy array.

    Returns:
        (rf, rf_params): the receptive-field helper and its computed parameters.
    """
    def model_fn():
        # Factory handed to PytorchReceptiveField; returns the model in eval mode.
        model.eval()
        return model

    sample_shape = np.array(sample_pil_img).shape
    rf = PytorchReceptiveField(model_fn)
    return rf, rf.compute(input_shape=sample_shape)
19 |
def plot_receptive_field(model, sample_pil_img, layout=(2, 2), figsize=(6, 6)):
    """Plot receptive-field grids of ``model`` over ``sample_pil_img``.

    Returns whatever ``rf.plot_rf_grids`` returns.
    """
    rf, _ = _get_rf(model, sample_pil_img)
    return rf.plot_rf_grids(custom_image=sample_pil_img, figsize=figsize, layout=layout)
26 |
def plot_grads_at(model, sample_pil_img, feature_map_index=0, point=(8,8), figsize=(6, 6)):
    """Plot the gradient of feature map ``feature_map_index`` at ``point``.

    ``sample_pil_img`` is only used to determine the input shape; the plot
    itself is requested with ``image=None``.
    """
    rf, _ = _get_rf(model, sample_pil_img)
    return rf.plot_gradient_at(fm_id=feature_map_index, point=point,
                               image=None, figsize=figsize)
30 |
--------------------------------------------------------------------------------
/TODO.md:
--------------------------------------------------------------------------------
1 | * Fix cell size issue
2 | * Refactor _plot* interface to accept all values, for ImagePlot only use last value
3 | * Refactor ImagePlot for arbitrary number of images with alpha, cmap
4 | * Change tw.open -> tw.create_viz
5 | * Make sure streams have names as key, each data point has index
6 | * Add tw.open_viz(stream_name, from_index)
7 | * Add persist=device_name option for streams
8 | * Ability to use streams in standalone mode
9 | * tw.create_viz on server side
10 | * tw.log for server side
11 | * experiment with IPC channel
12 | * confusion matrix as in https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html
13 | * Speed up import
14 | * Do linting
15 | * live perf data
16 | * NaN tracing
17 | * PCA
18 | * Remove error if MNIST notebook is on and we run fruits
19 | * Remove 2nd image from fruits
20 | * clear existing streams when starting client
21 | * ImageData should accept numpy array or pillow or torch tensor
22 | * image plot getting refreshed at 12hz instead of 2 hz in MNIST
23 | * image plot doesn't show title
24 | * Animated mesh/surface graph demo
25 | * Move to h5 storage?
26 | * Error envelope
27 | * histogram
28 | * new graph on end
29 | * TF support
30 | * generic visualizer -> Given obj and Box, paint in box
31 | * visualize image and text with attention
32 | * add confidence interval for plotly: https://plot.ly/python/continuous-error-bars/
--------------------------------------------------------------------------------
/tensorwatch/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from typing import Iterable, Sequence, Union
5 |
6 | from .watcher_client import WatcherClient
7 | from .watcher import Watcher
8 | from .watcher_base import WatcherBase
9 |
10 | from .text_vis import TextVis
11 | from .plotly.embeddings_plot import EmbeddingsPlot
12 | from .mpl.line_plot import LinePlot
13 | from .mpl.image_plot import ImagePlot
14 | from .mpl.histogram import Histogram
15 | from .mpl.bar_plot import BarPlot
16 | from .mpl.pie_chart import PieChart
17 | from .visualizer import Visualizer
18 |
19 | from .stream import Stream
20 | from .array_stream import ArrayStream
21 | from .lv_types import PointData, ImageData, VisArgs, StreamItem, PredictionResult
22 | from . import utils
23 |
24 | ###### Import methods for tw namespace #########
25 | #from .receptive_field.rf_utils import plot_receptive_field, plot_grads_at
26 | from .embeddings.tsne_utils import get_tsne_components
27 | from .model_graph.torchstat_utils import model_stats
28 | from .image_utils import show_image, open_image, img2pyt, linear_to_2d, plt_loop
29 | from .data_utils import pyt_ds2list, sample_by_class, col2array, search_similar
30 |
31 |
32 |
def draw_model(model, input_shape=None, orientation='TB'):
    """Build and return a graph drawing of ``model``.

    Args:
        model: the model to draw.
        input_shape: forwarded to ``graph.build_graph``.
        orientation: layout direction, 'TB' by default; use 'LR' for landscape.
    """
    # Imported lazily so the hiddenlayer graph code loads only when a model is drawn.
    from .model_graph.hiddenlayer import graph
    return graph.build_graph(model, input_shape, orientation=orientation)
37 |
38 |
39 |
--------------------------------------------------------------------------------
/tensorwatch/saliency/lime/wrappers/generic_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import inspect
3 | import types
4 |
5 |
def has_arg(fn, arg_name):
    """Return whether the callable ``fn`` accepts ``arg_name`` as a keyword argument.

    Args:
        fn: callable to inspect.
        arg_name: name of the keyword argument to look for.

    Returns:
        bool: True if ``fn`` can be called with ``arg_name=...``.
    """
    if sys.version_info >= (3, 6):
        try:
            sig = inspect.signature(fn)
        except ValueError:
            # Cython callables may not expose a signature; fall back to __call__.
            sig = inspect.signature(fn.__call__)
        param = sig.parameters.get(arg_name)
        keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                         inspect.Parameter.KEYWORD_ONLY)
        return param is not None and param.kind in keyword_kinds

    if sys.version_info >= (3,):
        spec = inspect.getfullargspec(fn)
        return arg_name in spec.args or arg_name in spec.kwonlyargs

    # Python 2: getargspec handles plain functions/methods directly; other
    # callables are inspected through their __call__.
    if isinstance(fn, (types.FunctionType, types.MethodType)):
        spec = inspect.getargspec(fn)
    else:
        try:
            spec = inspect.getargspec(fn.__call__)
        except AttributeError:
            return False
    return arg_name in spec.args
40 |
--------------------------------------------------------------------------------
/tensorwatch/filtered_stream.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .stream import Stream
5 | from typing import Callable, Any
6 |
class FilteredStream(Stream):
    """Stream that passes items from ``source_stream`` through ``filter_expr``.

    ``filter_expr`` is called with a stream item and returns a
    ``(result, is_valid)`` pair; only valid items are passed on. When
    ``filter_expr`` is None every item is valid and unchanged.
    """
    def __init__(self, source_stream:Stream, filter_expr:Callable, stream_name:str=None,
                 console_debug:bool=False)->None:

        if not stream_name:
            stream_name = '{}|{}'.format(source_stream.stream_name, str(filter_expr))
        super(FilteredStream, self).__init__(stream_name=stream_name, console_debug=console_debug)
        self.subscribe(source_stream)
        self.filter_expr = filter_expr

    def _filter(self, stream_item):
        # Returns (result, is_valid); identity pass-through when no filter is set.
        if self.filter_expr is None:
            return (stream_item, True)
        return self.filter_expr(stream_item)

    def write(self, val:Any, from_stream:'Stream'=None):
        result, is_valid = self._filter(self.to_stream_item(val))
        if not is_valid:
            # rejected by the filter: ignore this call
            return None
        return super(FilteredStream, self).write(result)

    def read_all(self, from_stream:'Stream'=None): #override->replacement
        """Yield items from every subscribed source whose filter check passes."""
        for source in self._subscribed_to:
            for stream_item in source.read_all(from_stream=self):
                _, is_valid = self._filter(stream_item)
                # NOTE(review): this yields the original item, whereas write()
                # forwards the filter's result - confirm the difference is intended.
                if is_valid:
                    yield stream_item
35 |
36 |
--------------------------------------------------------------------------------
/tensorwatch/embeddings/tsne_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from sklearn.manifold import TSNE
5 | from sklearn.preprocessing import StandardScaler
6 | import numpy as np
7 |
8 | def _standardize_data(data, col, whitten, flatten):
9 | if col is not None:
10 | data = data[col]
11 |
12 | #TODO: enable auto flattening
13 | #if data is tensor then flatten it first
14 | #if flatten and len(data) > 0 and hasattr(data[0], 'shape') and \
15 | # utils.has_method(data[0], 'reshape'):
16 |
17 | # data = [d.reshape((-1,)) for d in data]
18 |
19 | if whitten:
20 | data = StandardScaler().fit_transform(data)
21 | return data
22 |
def get_tsne_components(data, features_col=0, labels_col=1, whitten=True, n_components=3, perplexity=20, flatten=True, for_plot=True):
    """Run t-SNE on the feature column of ``data``.

    Returns:
        When ``for_plot`` is false, the array returned by ``TSNE.fit_transform``.
        Otherwise a list of rows: each row's t-SNE coordinates followed by
        ``(None, None, None, str(int(label)), label)`` for the low, high,
        annotation, text and color slots.
    """
    features = _standardize_data(data, features_col, whitten, flatten)
    embedded = TSNE(n_components=n_components, perplexity=perplexity).fit_transform(features)

    if not for_plot:
        return embedded

    rows = embedded.tolist()
    labels = data[labels_col]
    # numpy labels are converted to Python scalars via .item()
    labels_are_numpy = isinstance(labels, np.ndarray)
    for idx, row in enumerate(rows):
        label = labels[idx].item() if labels_are_numpy else labels[idx]
        row.extend((None, None, None, str(int(label)), label))
    return rows
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/test/components/file_only_test.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 |
def writer():
    """Write metric1 = (i, i**2) and metric2 = (i, i**3) for i in 0..2 to test.log."""
    watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None)

    for name, power in (('metric1', 2), ('metric2', 3)):
        with watcher.create_stream(name) as stream:
            for i in range(3):
                stream.write((i, i ** power))
13 |
def reader1():
    """Re-open both metric streams from test.log with console_debug on and load them."""
    print('---------------------------reader1---------------------------')
    watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None)

    # Keep every opened stream referenced until the function returns.
    opened = []
    for name in ('metric1', 'metric2'):
        stream = watcher.open_stream(name)
        stream.console_debug = True
        stream.load()
        opened.append(stream)
25 |
def reader2():
    """Print every item of both metric streams read back from test.log."""
    print('---------------------------reader2---------------------------')

    watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None)

    # Keep every opened stream referenced until the function returns.
    opened = []
    for name in ('metric1', 'metric2'):
        stream = watcher.open_stream(name)
        opened.append(stream)
        for item in stream.read_all():
            print(item)
38 |
def reader3():
    """Plot both metric streams from test.log as lines in one visualizer host."""
    print('---------------------------reader3---------------------------')

    watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None)
    stream1, stream2 = [watcher.open_stream(name) for name in ('metric1', 'metric2')]

    vis1 = tw.Visualizer(stream1, vis_type='line')
    # host=vis1; bound to a name so it stays referenced during the plot loop.
    vis2 = tw.Visualizer(stream2, vis_type='line', host=vis1)  # pylint: disable=unused-variable

    vis1.show()

    tw.plt_loop()
52 |
53 |
# Run only as a script so importing this module does not write test.log or
# block in the plot loop.
if __name__ == '__main__':
    writer()
    reader1()
    reader2()
    reader3()
58 |
59 |
--------------------------------------------------------------------------------
/tensorwatch/data_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import random
5 | import scipy.spatial.distance
6 | import heapq
7 | import numpy as np
8 | import torch
9 |
def pyt_tensor2np(pyt_tensor):
    """Convert a PyTorch tensor (or other array-like) to numpy.

    - None -> None
    - torch.Tensor -> numpy array copied to CPU; a single-element 1-D tensor
      (shape (1,)) is unwrapped to its scalar element
    - np.ndarray -> returned unchanged
    - anything else -> np.array(value)
    """
    if pyt_tensor is None:
        return None
    if isinstance(pyt_tensor, torch.Tensor):
        arr = pyt_tensor.data.cpu().numpy()
        # Unwrap only shape-(1,) tensors. Unwrapping every 1-D tensor would keep
        # just the first element of longer vectors and raise IndexError on
        # empty ones.
        if arr.shape == (1,):
            return arr[0]
        return arr
    if isinstance(pyt_tensor, np.ndarray):
        return pyt_tensor
    return np.array(pyt_tensor)
23 |
def pyt_tuple2np(pyt_tuple):
    """Apply pyt_tensor2np to every element of ``pyt_tuple`` and return a tuple."""
    converted = [pyt_tensor2np(item) for item in pyt_tuple]
    return tuple(converted)
26 |
def pyt_ds2list(pyt_ds, count=None):
    """Convert the first ``count`` items of a dataset to tuples of numpy values.

    Args:
        pyt_ds: iterable of tuples (e.g. a PyTorch dataset yielding (x, y)).
        count: number of items to convert; None (default) means ``len(pyt_ds)``.
            ``count=0`` yields an empty list.

    Returns:
        list with ``pyt_tuple2np(item)`` for each converted item.
    """
    from itertools import islice

    # Test for None explicitly so count=0 is honored rather than meaning "all".
    if count is None:
        count = len(pyt_ds)
    # islice stops after `count` items; zip with range() pulled one extra item
    # from the dataset before stopping.
    return [pyt_tuple2np(item) for item in islice(pyt_ds, count)]
30 |
def sample_by_class(data, n_samples, class_col=1, shuffle=True):
    """Pick up to ``n_samples`` rows of each class.

    Args:
        data: sequence of rows; ``row[class_col]`` is the row's class.
        n_samples: maximum number of rows kept per class.
        class_col: index of the class label within each row.
        shuffle: if true, sample from a shuffled copy of ``data``. The caller's
            sequence is not modified.

    Returns:
        list of the kept rows grouped by class, with classes in order of first
        appearance.
    """
    # Work on a copy: random.shuffle is in-place and would otherwise reorder
    # the caller's list (and fail on tuples).
    rows = list(data)
    if shuffle:
        random.shuffle(rows)
    by_class = {}
    for row in rows:
        bucket = by_class.setdefault(row[class_col], [])
        if len(bucket) < n_samples:
            bucket.append(row)
    # Flatten linearly instead of sum(lists, []), which is quadratic.
    return [row for bucket in by_class.values() for row in bucket]
43 |
def col2array(dataset, col):
    """Return a list holding ``row[col]`` for each row of ``dataset``."""
    return list(map(lambda row: row[col], dataset))
46 |
def search_similar(inputs, compare_to, algorithm='euclidean', topk=5, invert_score=True):
    """Keep the ``topk`` highest-scoring candidates from ``compare_to`` for each input.

    Scores are pairwise distances from ``scipy.spatial.distance.cdist`` using
    metric ``algorithm``. With ``invert_score`` they become
    ``1/(distance + 1e-6)``, so nearer candidates score higher.

    Returns:
        one list per input, each a heapq-ordered (not sorted) list of
        ``(score, (candidate_index, input_val, candidate))`` tuples.
    """
    distances = scipy.spatial.distance.cdist(inputs, compare_to, algorithm)
    results = []
    for input_val, row in zip(inputs, distances):
        heap = []
        for idx, (dist, candidate) in enumerate(zip(row, compare_to)):
            score = 1/(dist + 1.0E-6) if invert_score else dist
            entry = (score, (idx, input_val, candidate))
            if len(heap) >= topk:
                # heap is full: push then drop the current lowest score
                heapq.heappushpop(heap, entry)
            else:
                heapq.heappush(heap, entry)
        results.append(heap)
    return results
--------------------------------------------------------------------------------
/tensorwatch/pytorch_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from torchvision import models, transforms
5 | import torch
6 | import torch.nn.functional as F
7 | from . import utils, image_utils
8 | import os
9 |
def get_model(model_name):
    """Return the torchvision model ``model_name`` built with ``pretrained=True``."""
    factory = models.__dict__[model_name]
    return factory(pretrained=True)
13 |
def tensors2batch(tensors, preprocess_transform=None):
    """Stack ``tensors`` into one batch along a new leading dimension.

    Each tensor is first passed through ``preprocess_transform`` when one is given.
    """
    if preprocess_transform:
        tensors = tuple(map(preprocess_transform, tensors))
    batch_items = tensors if utils.is_array_like(tensors) else tuple(tensors)
    return torch.stack(batch_items, dim=0)
20 |
def int2tensor(val):
    """Return a LongTensor of shape (1,) holding ``val``."""
    values = [val]
    return torch.LongTensor(values)
23 |
def image2batch(image, image_transform=None):
    """Turn a single image into a batch of one.

    Uses ``image_transform`` when given, otherwise ``transforms.ToTensor()``,
    then adds a leading batch dimension of size 1.
    """
    to_tensor = image_transform if image_transform else transforms.ToTensor()
    return to_tensor(image).unsqueeze(0)
31 |
def image_class2tensor(image_path, class_index=None, image_convert_mode=None,
                       image_transform=None):
    """Open an image and prepare it as model input.

    Returns:
        (image_pil, input_x, target_class): the opened image, a batch of one
        from image2batch, and ``class_index`` as a LongTensor (None if not given).
    """
    full_path = os.path.abspath(image_path)
    image_pil = image_utils.open_image(full_path, convert_mode=image_convert_mode)
    input_x = image2batch(image_pil, image_transform)
    target_class = None if class_index is None else int2tensor(class_index)
    return image_pil, input_x, target_class
38 |
def batch_predict(model, inputs, input_transform=None, device=None):
    """Run ``model`` in eval mode on a batch stacked from ``inputs``.

    Args:
        model: torch module; it is switched to eval mode and moved to the device.
        inputs: sequence of tensors (each mapped through ``input_transform`` if given).
        input_transform: optional per-input transform.
        device: target device; when falsy, CUDA if available, else CPU.

    Returns:
        the raw model outputs for the batch.
    """
    items = tuple(input_transform(i) for i in inputs) if input_transform else inputs
    batch = torch.stack(items, dim=0)

    target = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    model.to(target)
    return model(batch.to(target))
53 |
def logits2probabilities(logits):
    """Apply softmax over dim 1 (presumably the class dimension - TODO confirm)."""
    probabilities = F.softmax(logits, dim=1)
    return probabilities
56 |
def tensor2numpy(t):
    """Return ``t.data`` moved to CPU as a numpy array."""
    cpu_tensor = t.data.cpu()
    return cpu_tensor.numpy()
59 |
--------------------------------------------------------------------------------
/tensorwatch/saliency/gradcam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .backprop import VanillaGradExplainer
3 |
4 |
5 | def _get_layer(model, key_list):
6 | if key_list is None:
7 | return None
8 |
9 | a = model
10 | for key in key_list:
11 | a = a._modules[key]
12 | return a
13 |
class GradCAMExplainer(VanillaGradExplainer):
    """Grad-CAM saliency built on VanillaGradExplainer.

    Hooks ``target_layer`` (found via ``target_layer_name_keys``) to record its
    activation and the gradient of its output. The activation is the layer
    input when ``use_inp`` is true, otherwise its output.
    """
    def __init__(self, model, target_layer_name_keys=None, use_inp=False):
        super(GradCAMExplainer, self).__init__(model)
        self.target_layer = _get_layer(model, target_layer_name_keys)
        self.use_inp = use_inp
        self.intermediate_act = []
        self.intermediate_grad = []
        self._register_forward_backward_hook()

    def _register_forward_backward_hook(self):
        def save_input(module, inputs, output):
            self.intermediate_act.append(inputs[0].data.clone())

        def save_output(module, inputs, output):
            self.intermediate_act.append(output.data.clone())

        def save_grad(module, grad_inputs, grad_outputs):
            self.intermediate_grad.append(grad_outputs[0].data.clone())

        if self.target_layer is None:
            return
        self.target_layer.register_forward_hook(save_input if self.use_inp else save_output)
        self.target_layer.register_backward_hook(save_grad)

    def _reset_intermediate_lists(self):
        self.intermediate_act = []
        self.intermediate_grad = []

    def explain(self, inp, ind=None, raw_inp=None):
        """Return the Grad-CAM map for ``inp`` / target ``ind``, or None if no gradient was recorded.

        The map is the channel-wise (dim 1) sum of activation times the
        gradient summed over its two trailing dims, clamped at 0, with the
        channel dim kept as size 1.
        """
        self._reset_intermediate_lists()

        super(GradCAMExplainer, self)._backprop(inp, ind)

        if not self.intermediate_grad:
            return None

        grad = self.intermediate_grad[0]
        act = self.intermediate_act[0]

        # Per-channel weights: gradient summed over the last two dims, then two
        # singleton dims appended so they broadcast against the activation.
        weights = grad.sum(-1).sum(-1)[..., None, None]
        cam = (weights * act).sum(1, keepdim=True)
        return torch.clamp(cam, min=0)
63 |
--------------------------------------------------------------------------------
/tensorwatch/zmq_stream.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from typing import Any
5 | from .zmq_wrapper import ZmqWrapper
6 | from .stream import Stream
7 | from .lv_types import DefaultPorts, PublisherTopics
8 | from . import utils
9 |
10 | # on writes send data on ZMQ transport
class ZmqStream(Stream):
    """Stream backed by a ZMQ publication (``for_write``) or subscription.

    When opened for writing, written items are sent over ZMQ on ``topic``.
    When opened for reading, items received by the subscription are written
    into this stream.
    """
    def __init__(self, for_write:bool, port:int=0, topic=PublisherTopics.StreamItem, block_until_connected=True,
                 stream_name:str=None, console_debug:bool=False):
        super(ZmqStream, self).__init__(stream_name=stream_name, console_debug=console_debug)

        self.for_write = for_write
        self._zmq = None

        self.topic = topic
        self._open(for_write, port, block_until_connected)
        utils.debug_log('ZmqStream started', verbosity=1)

    def _open(self, for_write:bool, port:int, block_until_connected:bool):
        # ``port`` is an offset added to the default pub/sub port.
        zmq_port = DefaultPorts.PubSub + port
        if not for_write:
            self._zmq = ZmqWrapper.Subscription(port=zmq_port,
                topic=self.topic, callback=self._on_subscription_item)
            return
        self._zmq = ZmqWrapper.Publication(port=zmq_port,
            block_until_connected=block_until_connected)

    def close(self):
        if self.closed:
            return
        self._zmq.close()
        self._zmq = None
        utils.debug_log('ZmqStream is closed', verbosity=1)
        super(ZmqStream, self).close()

    def _on_subscription_item(self, val:Any):
        utils.debug_log('Received subscription item', verbosity=5)
        self.write(val)

    def write(self, val:Any, from_stream:'Stream'=None, topic=None):
        stream_item = self.to_stream_item(val)

        # Only a publisher sends over ZMQ; in read mode the item came from the
        # subscription and must not be sent back out.
        if self.for_write:
            utils.debug_log('Sent subscription item', verbosity=5)
            self._zmq.send_obj(stream_item, topic or self.topic)
        super(ZmqStream, self).write(stream_item)
52 |
--------------------------------------------------------------------------------
/tensorwatch/saliency/occlusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.autograd import Variable
4 | from skimage.util import view_as_windows
5 |
6 | # modified from https://github.com/marcoancona/DeepExplain/blob/master/deepexplain/tensorflow/methods.py#L291-L342
7 | # note the different dim order in pytorch (NCHW) and tensorflow (NHWC)
8 |
class OcclusionExplainer:
    """Occlusion saliency: zero out sliding windows of the input and record how
    much the score of the predicted class drops.

    Modified from DeepExplain's TensorFlow occlusion method; note PyTorch uses
    NCHW while TensorFlow uses NHWC.
    """
    def __init__(self, model, window_shape=10, step=1):
        self.model = model
        self.window_shape = window_shape
        self.step = step

    def explain(self, inp, ind=None, raw_inp=None):
        # ``ind`` and ``raw_inp`` are accepted for interface compatibility but are
        # unused: the class predicted for the unoccluded input is explained.
        self.model.eval()
        with torch.no_grad():
            return OcclusionExplainer._occlusion(inp, self.model, self.window_shape, self.step)

    @staticmethod
    def _occlusion(inp, model, window_shape, step=None):
        """Return an attribution tensor of the same (n, c, h, w) shape as ``inp``.

        ``window_shape`` is an int (square window spanning all channels) or an
        (h, w, c) tuple. ``step`` is the window stride (default 1).
        """
        n, c, h, w = inp.data.size()
        if isinstance(window_shape, int):
            # Cover every channel; the window was previously hard-coded to
            # 3 channels, which broke inputs with fewer channels.
            window_shape = (window_shape, window_shape, c)

        if step is None:
            step = 1
        total_dim = c * h * w
        index_matrix = np.arange(total_dim).reshape(h, w, c)
        idx_patches = view_as_windows(index_matrix, window_shape, step).reshape(
            (-1,) + window_shape)
        heatmap = np.zeros((n, h, w, c), dtype=np.float32).reshape(-1, total_dim)
        weights = np.zeros_like(heatmap)

        inp_data = inp.data.clone()
        new_inp = Variable(inp_data)
        eval0 = model(new_inp)
        pred_id = eval0.max(1)[1].data[0]

        for p in idx_patches:
            flat_idx = p.flatten()
            mask = np.ones((h, w, c)).flatten()
            mask[flat_idx] = 0
            # Place the mask on the input's device instead of always calling
            # .cuda(), which failed on CPU-only runs and on non-default GPUs.
            th_mask = torch.from_numpy(
                mask.reshape(1, h, w, c).transpose(0, 3, 1, 2)).float().to(inp_data.device)
            masked_xs = Variable(th_mask * inp_data)
            delta = (eval0[0, pred_id] - model(masked_xs)[0, pred_id]).data.cpu().numpy()
            delta_aggregated = np.sum(delta.reshape(n, -1), -1, keepdims=True)
            heatmap[:, flat_idx] += delta_aggregated
            weights[:, flat_idx] += p.size

        attribution = np.reshape(heatmap / (weights + 1e-10), (n, h, w, c)).transpose(0, 3, 1, 2)
        return torch.from_numpy(attribution)
--------------------------------------------------------------------------------
/tensorwatch/file_stream.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .stream import Stream
5 | import pickle, os
6 | from typing import Any
7 | from . import utils
8 | import time
9 |
class FileStream(Stream):
    """Stream persisted to a file as a sequence of pickled stream items.

    Opened for writing, every written item is pickled to the file and forwarded to
    subscribers. Opened for reading, ``read_all``/``load`` replay the pickled items.
    """
    def __init__(self, for_write:bool, file_name:str, stream_name:str=None, console_debug:bool=False):
        super(FileStream, self).__init__(stream_name=stream_name or file_name, console_debug=console_debug)

        self._file = open(file_name, 'wb' if for_write else 'rb')
        self.file_name = file_name
        self.for_write = for_write
        utils.debug_log('FileStream started', self.file_name, verbosity=1)

    def close(self):
        # _file is set to None after the first close, so repeated close() calls are safe
        if self._file is not None and not self._file.closed:
            self._file.close()
            utils.debug_log('FileStream is closed', self.file_name, verbosity=1)
        self._file = None
        super(FileStream, self).close()

    def write(self, val:Any, from_stream:'Stream'=None):
        stream_item = self.to_stream_item(val)

        if self.for_write:
            pickle.dump(stream_item, self._file)
        super(FileStream, self).write(stream_item)

    def read_all(self, from_stream:'Stream'=None):
        """Yield every item stored in the file, then items from upstream streams."""
        if self.for_write:
            raise IOError('Cannot use read() call because FileStream is opened with for_write=True')
        if self._file is not None:
            self._file.seek(0, 0) # we may filter this stream multiple times
            while not utils.is_eof(self._file):
                # NOTE: pickle can execute arbitrary code; only read trusted files
                yield pickle.load(self._file)
        yield from super(FileStream, self).read_all()

    def load(self, from_stream:'Stream'=None):
        """Replay every stored item through write() so subscribers receive it."""
        if self.for_write:
            raise IOError('Cannot use load() call because FileStream is opened with for_write=True')
        if self._file is not None:
            self._file.seek(0, 0) # we may filter this stream multiple times
            while not utils.is_eof(self._file):
                # NOTE: pickle can execute arbitrary code; only read trusted files
                stream_item = pickle.load(self._file)
                self.write(stream_item)
        super(FileStream, self).load()

    def save(self, from_stream:'Stream'=None):
        """Flush pending writes, then propagate save() to subscribers."""
        if self._file is not None and not self._file.closed:
            self._file.flush()
        super(FileStream, self).save(from_stream=from_stream)
57 |
58 |
--------------------------------------------------------------------------------
/tensorwatch/saliency/lime_image_explainer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from skimage.segmentation import mark_boundaries
5 | from torchvision import transforms
6 | import torch
7 | from .lime import lime_image
8 | import numpy as np
9 | from .. import imagenet_utils, pytorch_utils, utils
10 |
class LimeImageExplainer:
    """Wrap LIME's image explainer and return its boundary-marked explanation as a tensor.

    Subclasses can override ``preprocess_input``/``preprocess_label`` to adapt the raw
    image and the requested label before they reach LIME.
    """
    def __init__(self, model, predict_fn):
        self.model = model
        self.predict_fn = predict_fn

    def preprocess_input(self, inp):
        return inp

    def preprocess_label(self, label):
        return label

    def explain(self, inp, ind=None, raw_inp=None, top_labels=5, hide_color=0, num_samples=1000,
                positive_only=True, num_features=5, hide_rest=True, pixel_val_max=255.0):
        """Explain ``raw_inp`` for label ``ind`` (LIME's top label when None).

        Returns a 1 x C x H x W tensor of the explained image with segment boundaries
        marked; ``pixel_val_max`` rescales the image before boundaries are drawn.
        """
        explainer = lime_image.LimeImageExplainer()
        explanation = explainer.explain_instance(self.preprocess_input(raw_inp), self.predict_fn,
            top_labels=top_labels, hide_color=hide_color, num_samples=num_samples)

        label = self.preprocess_label(ind)
        if label is None:
            # only fall back when no label was given; class index 0 is a valid label
            label = explanation.top_labels[0]
        temp, mask = explanation.get_image_and_mask(label, positive_only=positive_only,
            num_features=num_features, hide_rest=hide_rest)

        img = mark_boundaries(temp/pixel_val_max, mask)
        # H x W x C image -> C x H x W tensor, plus a leading batch dimension
        img = torch.from_numpy(img).permute(2, 0, 1)
        return img.unsqueeze(0)
35 |
class LimeImagenetExplainer(LimeImageExplainer):
    """LIME explainer preconfigured for ImageNet classifiers: inputs are resized with the
    ImageNet resize transform and, unless a custom ``predict_fn`` is given, predictions
    go through ``imagenet_utils.predict`` with tensor conversion plus normalization."""
    def __init__(self, model, predict_fn=None):
        chosen_fn = predict_fn or self._imagenet_predict
        super(LimeImagenetExplainer, self).__init__(model, chosen_fn)

    def _preprocess_transform(self):
        steps = [transforms.ToTensor(), imagenet_utils.get_normalize_transform()]
        return transforms.Compose(steps)

    def preprocess_input(self, inp):
        resize = imagenet_utils.get_resize_transform()
        return np.array(resize(inp))

    def preprocess_label(self, label):
        # unwrap tensor-like labels (anything with .item()) into plain Python values
        if label is None or not utils.has_method(label, 'item'):
            return label
        return label.item()

    def _imagenet_predict(self, images):
        transform = self._preprocess_transform()
        probabilities = imagenet_utils.predict(self.model, images, image_transform=transform)
        return pytorch_utils.tensor2numpy(probabilities)
56 |
--------------------------------------------------------------------------------
/tensorwatch/repeated_timer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import threading
5 | import time
6 | import weakref
7 |
class RepeatedTimer:
    """Call ``callback`` every ``secs`` seconds on a background daemon thread.

    Bound-method callbacks are held through ``weakref.WeakMethod`` so the timer does not
    keep their owner alive (calls are skipped once the owner is collected); any other
    callable is held strongly. If ``count`` is truthy the callback runs at most ``count``
    times, after which the thread exits on its own.
    """
    class State:
        Stopped=0
        Paused=1
        Running=2

    def __init__(self, secs, callback, count=None):
        self.secs = secs
        self.callback = RepeatedTimer._make_callback_ref(callback) if callback else None
        self._thread = None
        self._state = RepeatedTimer.State.Stopped
        self.pause_wait = threading.Event()
        self.pause_wait.set()
        self._continue_thread = False
        self.count = count

    @staticmethod
    def _make_callback_ref(callback):
        """Return a zero-arg callable yielding the callback (None once a weak owner dies)."""
        try:
            return weakref.WeakMethod(callback)
        except TypeError:
            # WeakMethod only accepts bound methods; hold other callables strongly
            return lambda: callback

    def start(self):
        self._continue_thread = True
        self.pause_wait.set()
        # set before spawning so a fast-finishing runner's final Stopped is not overwritten
        self._state = RepeatedTimer.State.Running
        if self._thread is None or not self._thread.is_alive():
            self._thread = threading.Thread(target=self._runner, name='RepeatedTimer', daemon=True)
            self._thread.start()

    def stop(self, block=False):
        self.pause_wait.set()
        self._continue_thread = False
        thread = self._thread  # the runner resets self._thread when it exits
        if block and thread is not None and thread.is_alive() \
                and thread is not threading.current_thread():
            thread.join()
        self._state = RepeatedTimer.State.Stopped

    def get_state(self):
        return self._state

    def pause(self):
        if self._state == RepeatedTimer.State.Running:
            self.pause_wait.clear()
            self._state = RepeatedTimer.State.Paused
        # else nothing to do

    def unpause(self):
        if self._state == RepeatedTimer.State.Paused:
            self.pause_wait.set()
            self._state = RepeatedTimer.State.Running
        # else nothing to do

    def _runner(self):
        while self._continue_thread:
            self.pause_wait.wait()
            if not self._continue_thread:  # stopped while paused
                break

            # dereference once so the target cannot vanish between check and call
            fn = self.callback() if self.callback else None
            if fn:
                fn()

            if self.count:
                self.count -= 1
                if not self.count:
                    self._continue_thread = False

            if self._continue_thread:
                time.sleep(self.secs)

        self._thread = None
        self._state = RepeatedTimer.State.Stopped
--------------------------------------------------------------------------------
/test/simple_log/cli_ij.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import time
3 | import math
4 | from tensorwatch import utils
5 | utils.set_debug_verbosity(4)
6 |
def mpl_line_plot():
    """Matplotlib demo: live line plot of 2*sqrt(val) for each 'ev_i' event."""
    client = tw.WatcherClient()
    plot = tw.LinePlot(title='Demo')
    stream = client.create_stream(event_name='ev_i', expr='map(lambda v:math.sqrt(v.val)*2, l)')
    plot.subscribe(stream, xtitle='Index', ytitle='sqrt(ev_i)')
    plot.show()
    tw.plt_loop()
14 |
def mpl_history_plot():
    """Matplotlib demo: plot (val, 2*sqrt(val)) pairs from 'ev_j' with
    clear_after_end=True and a 15-item history_len."""
    client = tw.WatcherClient()
    history_plot = tw.LinePlot(title='History Demo')
    stream = client.create_stream(event_name='ev_j', expr='map(lambda v:(v.val, math.sqrt(v.val)*2), l)')
    history_plot.subscribe(stream, xtitle='Index', ytitle='sqrt(ev_j)',
                           clear_after_end=True, history_len=15)
    history_plot.show()
    tw.plt_loop()
22 |
def show_stream():
    """Text demo: show sqrt(val) of 'ev_i' and val^2 of 'ev_j' in two TextVis widgets."""
    client = tw.WatcherClient()
    specs = (
        ("Subscribing to event ev_i...", "ev_i", 'map(lambda v:math.sqrt(v.val), l)', 'L1'),
        ("Subscribing to event ev_j...", "ev_j", 'map(lambda v:v.val*v.val, l)', 'L2'),
    )
    # keep every stream/visualizer referenced until wait_key() returns, exactly as the
    # separate locals did (streams track each other through weak references)
    alive = []
    for message, event, expr, title in specs:
        print(message)
        stream = client.create_stream(event_name=event, expr=expr)
        vis = tw.TextVis(title=title)
        vis.subscribe(stream)
        vis.show()
        alive.append((stream, vis))

    print("Waiting for key...")
    utils.wait_key()
42 |
# this is no longer directly supported
44 | # TODO: create stream that allows enumeration from buffered values
45 | #def read_stream():
46 | # cli = tw.WatcherClient()
47 |
48 | # with cli.create_stream(event_name="ev_i", expr='map(lambda v:(v.x, math.sqrt(v.val)), l)') as s1:
49 | # for stream_item in s1:
50 | # print(stream_item.value)
51 | # print('done')
52 | # utils.wait_key()
53 |
def plotly_line_graph():
    """Plotly demo: plot (x, sqrt(val)) pairs from 'ev_i' events."""
    client = tw.WatcherClient()
    stream = client.create_stream(event_name="ev_i", expr='map(lambda v:(v.x, math.sqrt(v.val)), l)')

    line_plot = tw.plotly.line_plot.LinePlot()
    line_plot.subscribe(stream)
    line_plot.show()

    utils.wait_key()
63 |
def plotly_history_graph():
    """Plotly demo: plot (x, val) pairs from 'ev_j' with a 15-item history_len."""
    client = tw.WatcherClient()
    line_plot = tw.plotly.line_plot.LinePlot(title='Demo')
    stream = client.create_stream(event_name='ev_j', expr='map(lambda v:(v.x, v.val), l)')
    line_plot.subscribe(stream, ytitle='ev_j', history_len=15)
    line_plot.show()
    utils.wait_key()
71 |
72 |
if __name__ == '__main__':
    # Run one demo at a time; uncomment another to try it instead.
    mpl_line_plot()
    #mpl_history_plot()
    #show_stream()
    #plotly_line_graph()
    #plotly_history_graph()
78 |
--------------------------------------------------------------------------------
/tensorwatch/mpl/pie_chart.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .base_mpl_plot import BaseMplPlot
5 | from .. import image_utils
6 | from .. import utils
7 | import numpy as np
8 |
class PieChart(BaseMplPlot):
    """Matplotlib pie chart; each streamed value is (label, size, color, explode) and
    the whole accumulated series is redrawn on every update."""
    def init_stream_plot(self, stream_vis, autopct=None, colormap=None, color=None,
                         shadow=None, **stream_vis_args):
        # add main subplot
        stream_vis.autopct = autopct
        stream_vis.shadow = True if shadow is None else shadow
        stream_vis.ax = self.get_main_axis()
        stream_vis.series = []
        stream_vis.wedge_artists = [] # every artist from the last ax.pie call

        stream_vis.cmap = image_utils.get_cmap(name=colormap or 'Dark2')
        if color is None:
            if not self.is_3d:
                color = stream_vis.cmap((len(self._stream_vises)%stream_vis.cmap.N)/stream_vis.cmap.N) # pylint: disable=no-member
        stream_vis.color = color

    def clear_artists(self, stream_vis):
        for artist in stream_vis.wedge_artists:
            artist.remove()
        stream_vis.wedge_artists.clear()

    def clear_plot(self, stream_vis, clear_history):
        stream_vis.series.clear()
        self.clear_artists(stream_vis)

    @staticmethod
    def _slice_color(cmap, i):
        """Default color for slice i: cycle a listed colormap, else sample cmap evenly."""
        colors = getattr(cmap, 'colors', None)
        if colors is not None and len(colors):
            return colors[i % len(colors)]
        return cmap((i % cmap.N) / cmap.N)

    def _show_stream_items(self, stream_vis, stream_items):
        """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
        """

        vals = self._extract_vals(stream_items)
        if not len(vals):
            return True

        # pad every value to (label, size, color, explode); extra elements are ignored
        unpacker = lambda a0=None,a1=None,a2=None,a3=None,*_:(a0,a1,a2,a3)
        stream_vis.series.extend([unpacker(*val) for val in vals])
        self.clear_artists(stream_vis)

        series = stream_vis.series
        labels = [t[0] for t in series]
        sizes = [t[1] for t in series]
        colors = [t[2] or PieChart._slice_color(stream_vis.cmap, i) for i, t in enumerate(series)]
        explode = [t[3] or 0 for t in series]

        pie_artists = stream_vis.ax.pie(
            sizes, explode=explode, labels=labels, colors=colors,
            autopct=stream_vis.autopct, shadow=stream_vis.shadow)
        # ax.pie returns (wedges, texts[, autotexts]); keep them all so the next redraw
        # removes old labels and percentages too, not just the wedges
        stream_vis.wedge_artists = [artist for group in pie_artists for artist in group]

        return False
60 |
61 |
--------------------------------------------------------------------------------
/tensorwatch/zmq_mgmt_stream.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from typing import Any
5 | from .zmq_stream import ZmqStream
6 | from .lv_types import PublisherTopics, ServerMgmtMsg, StreamCreateRequest
7 | from .zmq_wrapper import ZmqWrapper
8 | from .lv_types import CliSrvReqTypes, ClientServerRequest
9 | from . import utils
10 |
class ZmqMgmtStream(ZmqStream):
    """Stream on the server-management topic that also issues stream create/delete
    requests over ``clisrv``.

    Every create request sent is remembered by stream name so all of them can be re-sent
    when a server-start management event arrives.
    """
    # default topic is mgmt
    def __init__(self, clisrv:ZmqWrapper.ClientServer, for_write:bool, port:int=0, topic=PublisherTopics.ServerMgmt, block_until_connected=True,
                 stream_name:str=None, console_debug:bool=False):
        super(ZmqMgmtStream, self).__init__(for_write=for_write, port=port, topic=topic,
            block_until_connected=block_until_connected, stream_name=stream_name, console_debug=console_debug)

        self._clisrv = clisrv
        # stream_name -> StreamCreateRequest (builtin `dict`: typing.Dict is not imported here)
        self._stream_reqs: dict = {}

    def write(self, val:Any, from_stream:'Stream'=None):
        r"""Handles server management events.
        """
        stream_item = self.to_stream_item(val)
        mgmt_msg = stream_item.value

        utils.debug_log("Received - SeverMgmtevent", mgmt_msg)
        # if server was restarted then send create stream requests again
        if mgmt_msg.event_name == ServerMgmtMsg.EventServerStart:
            for stream_req in self._stream_reqs.values():
                self._send_create_stream(stream_req)

        super(ZmqMgmtStream, self).write(stream_item)

    def add_stream_req(self, stream_req:StreamCreateRequest)->None:
        self._send_create_stream(stream_req)

        # save this for later for resend if server restarts
        self._stream_reqs[stream_req.stream_name] = stream_req

    # override to send request to server
    def del_stream(self, name:str) -> None:
        clisrv_req = ClientServerRequest(CliSrvReqTypes.del_stream, name)
        self._clisrv.send_obj(clisrv_req)
        self._stream_reqs.pop(name, None)

    def _send_create_stream(self, stream_req):
        utils.debug_log("sending create streamreq...")
        clisrv_req = ClientServerRequest(CliSrvReqTypes.create_stream, stream_req)
        self._clisrv.send_obj(clisrv_req)
        utils.debug_log("sent create streamreq")

    def close(self):
        if not self.closed:
            self._stream_reqs = {}
            self._clisrv = None
            utils.debug_log('ZmqMgmtStream is closed', verbosity=1)
        super(ZmqMgmtStream, self).close()
59 |
--------------------------------------------------------------------------------
/tensorwatch/mpl/histogram.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .base_mpl_plot import BaseMplPlot
5 | from .. import image_utils
6 | from .. import utils
7 | import numpy as np
8 |
class Histogram(BaseMplPlot):
    """Matplotlib histogram that re-bins every value received so far on each update."""
    def init_stream_plot(self, stream_vis,
            xtitle='', ytitle='', ztitle='', colormap=None, color=None,
            bins=None, normed=None, histtype='bar', edge_color=None, linewidth=None,
            opacity=None, **stream_vis_args):

        # add main subplot
        stream_vis.bins, stream_vis.normed, stream_vis.linewidth = bins, normed, (linewidth or 2)
        stream_vis.ax = self.get_main_axis()
        stream_vis.series = []
        stream_vis.bars_artists = [] # stores previously drawn bars

        stream_vis.cmap = image_utils.get_cmap(name=colormap or 'Dark2')
        if color is None:
            if not self.is_3d:
                color = stream_vis.cmap((len(self._stream_vises)%stream_vis.cmap.N)/stream_vis.cmap.N) # pylint: disable=no-member
        stream_vis.color = color
        stream_vis.edge_color = edge_color or 'black'  # black edges unless the caller overrides
        stream_vis.histtype = histtype
        stream_vis.opacity = opacity
        stream_vis.ax.set_xlabel(xtitle)
        stream_vis.ax.xaxis.label.set_style('italic')
        stream_vis.ax.set_ylabel(ytitle)
        stream_vis.ax.yaxis.label.set_color(color)
        stream_vis.ax.yaxis.label.set_style('italic')
        if self.is_3d:
            stream_vis.ax.set_zlabel(ztitle)
            stream_vis.ax.zaxis.label.set_style('italic')

    def is_show_grid(self): #override
        return False

    def clear_artists(self, stream_vis):
        for bar in stream_vis.bars_artists:
            bar.remove()
        stream_vis.bars_artists.clear()

    def clear_plot(self, stream_vis, clear_history):
        stream_vis.series.clear()
        self.clear_artists(stream_vis)

    def _show_stream_items(self, stream_vis, stream_items):
        """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
        """

        vals = self._extract_vals(stream_items)
        if not len(vals):
            return True

        stream_vis.series += vals
        self.clear_artists(stream_vis)
        # `density` replaces hist's `normed` keyword, which matplotlib removed in 3.1
        _, bins, patches = stream_vis.ax.hist(stream_vis.series, bins=stream_vis.bins,
            density=stream_vis.normed, color=stream_vis.color, edgecolor=stream_vis.edge_color,
            histtype=stream_vis.histtype, alpha=stream_vis.opacity,
            linewidth=stream_vis.linewidth)
        # hist returns a BarContainer (a tuple without .clear()) or a list of polygons;
        # store a plain list so clear_artists can empty it on the next update
        stream_vis.bars_artists = list(patches)

        stream_vis.ax.set_xticks(bins)

        return False
71 |
--------------------------------------------------------------------------------
/tensorwatch/imagenet_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from torchvision import transforms
5 | from . import pytorch_utils
6 | import json, os
7 |
def get_image_transform():
    """Compose the ImageNet eval pipeline: resize, convert to tensor, normalize."""
    #TODO: cache these transforms?
    steps = [get_resize_transform(), transforms.ToTensor(), get_normalize_transform()]
    return transforms.Compose(steps)
16 |
def get_resize_transform():
    """Return a transform that resizes images to 224 x 224."""
    target_size = (224, 224)
    return transforms.Resize(target_size)
19 |
def get_normalize_transform():
    """Return per-channel normalization using the ImageNet RGB mean/std constants."""
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return transforms.Normalize(mean=channel_mean, std=channel_std)
23 |
def image2batch(image):
    """Delegate to pytorch_utils.image2batch using the ImageNet preprocessing pipeline."""
    transform = get_image_transform()
    return pytorch_utils.image2batch(image, image_transform=transform)
26 |
def predict(model, images, image_transform=None, device=None):
    """Run ``model`` on ``images`` and return class probabilities.

    The result is 2-dim: one row per input, one column per class. When
    ``image_transform`` is None the default ImageNet transform is applied.
    """
    transform = image_transform or get_image_transform()
    logits = pytorch_utils.batch_predict(model, images, input_transform=transform, device=device)
    return pytorch_utils.logits2probabilities(logits)
32 |
_imagenet_labels = None  # lazily-created shared ImagenetLabels instance
def get_imagenet_labels():
    """Return the module-wide ImagenetLabels, loading the label file on first use."""
    # pylint: disable=global-statement
    global _imagenet_labels
    if not _imagenet_labels:
        _imagenet_labels = ImagenetLabels()
    return _imagenet_labels
39 |
def probabilities2classes(probs, topk=5):
    """Return the top-``topk`` entries of the first row of ``probs`` as
    (probability, class_id, class_label, class_code) tuples."""
    labels = get_imagenet_labels()
    top_values, top_indices = probs.topk(topk)
    first_probs = top_values[0].data.cpu().numpy()
    first_ids = top_indices[0].data.cpu().numpy()
    return tuple((p, c, labels.index2label_text(c), labels.index2label_code(c))
                 for p, c in zip(first_probs, first_ids))
46 |
class ImagenetLabels:
    """Lookup tables between class index, class code and label text.

    The JSON file maps str(index) -> [class_code, label_text, ...]; by default the
    ``imagenet_class_index.json`` file next to this module is loaded.
    """
    def __init__(self, json_path=None):
        if not json_path:
            json_path = os.path.join(os.path.dirname(__file__), 'imagenet_class_index.json')

        with open(os.path.abspath(json_path), "r") as read_file:
            class_json = json.load(read_file)

        self._idx2label, self._idx2cls = [], []
        self._cls2label, self._cls2idx = {}, {}
        for k in range(len(class_json)):
            entry = class_json[str(k)]
            code, text = entry[0], entry[1]
            self._idx2label.append(text)
            self._idx2cls.append(code)
            self._cls2label[code] = text
            self._cls2idx[code] = k

    def index2label_text(self, index):
        """Label text for a class index."""
        return self._idx2label[index]

    def index2label_code(self, index):
        """Class code for a class index."""
        return self._idx2cls[index]

    def label_code2label_text(self, label_code):
        """Label text for a class code."""
        return self._cls2label[label_code]

    def label_code2index(self, label_code):
        """Class index for a class code."""
        return self._cls2idx[label_code]
--------------------------------------------------------------------------------
/tensorwatch/saliency/deeplift.py:
--------------------------------------------------------------------------------
1 | from .backprop import GradxInputExplainer
2 | import types
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 | # Based on formulation in DeepExplain, https://arxiv.org/abs/1711.06104
7 | # https://github.com/marcoancona/DeepExplain/blob/master/deepexplain/tensorflow/methods.py#L221-L272
class DeepLIFTRescaleExplainer(GradxInputExplainer):
    """DeepLIFT (Rescale rule) explainer built on gradient x input.

    Every module whose class name contains 'ReLU' gets a patched forward that records
    its input/output, and a replacement ``backward`` that computes the Rescale
    multiplier from a zero-baseline pass and the real pass. The attribution itself is
    delegated to GradxInputExplainer.explain.
    """
    def __init__(self, model):
        super(DeepLIFTRescaleExplainer, self).__init__(model)
        self._prepare_reference()
        self.baseline_inp = None
        self._override_backward()

    def _prepare_reference(self):
        """Patch ReLU-like modules so each forward appends its input and output to lists."""
        def init_refs(m):
            name = m.__class__.__name__
            if name.find('ReLU') != -1:
                m.ref_inp_list = []
                m.ref_out_list = []

        def ref_forward(self, x):
            # NOTE(review): plain relu replaces any '*ReLU*' module (e.g. LeakyReLU, ReLU6)
            self.ref_inp_list.append(x.data.clone())
            out = F.relu(x)
            self.ref_out_list.append(out.data.clone())
            return out

        def ref_replace(m):
            name = m.__class__.__name__
            if name.find('ReLU') != -1:
                m.forward = types.MethodType(ref_forward, m)

        self.model.apply(init_refs)
        self.model.apply(ref_replace)

    def _reset_preference(self):
        """Clear the recorded inputs/outputs of every patched module."""
        def reset_refs(m):
            name = m.__class__.__name__
            if name.find('ReLU') != -1:
                m.ref_inp_list = []
                m.ref_out_list = []

        self.model.apply(reset_refs)

    def _baseline_forward(self, inp):
        """Run the model on an all-zeros input shaped like ``inp`` to record references."""
        base = self.baseline_inp
        # rebuild the cached baseline whenever the input's shape, dtype or device changes;
        # reusing a stale baseline would mismatch the recorded reference activations
        if base is None or base.data.size() != inp.data.size() \
                or base.data.dtype != inp.data.dtype or base.data.device != inp.data.device:
            zeros = inp.data.clone()
            zeros.fill_(0.0)
            self.baseline_inp = Variable(zeros)
        else:
            self.baseline_inp.fill_(0.0)
        # get ref
        _ = self.model(self.baseline_inp)

    def _override_backward(self):
        """Install the Rescale-rule gradient on ReLU-like modules.

        NOTE(review): autograd does not appear to consult an ``nn.Module.backward``
        attribute; confirm this override is actually invoked (a backward hook may be
        required for the rule to take effect).
        """
        def new_backward(self, grad_out):
            # expects exactly [baseline, real] entries recorded by ref_forward
            ref_inp, inp = self.ref_inp_list
            ref_out, out = self.ref_out_list
            delta_out = out - ref_out
            delta_in = inp - ref_inp
            g1 = (delta_in.abs() > 1e-5).float() * grad_out * \
                delta_out / delta_in
            mask = ((ref_inp + inp) > 0).float()
            g2 = (delta_in.abs() <= 1e-5).float() * 0.5 * mask * grad_out

            return g1 + g2

        def backward_replace(m):
            name = m.__class__.__name__
            if name.find('ReLU') != -1:
                m.backward = types.MethodType(new_backward, m)

        self.model.apply(backward_replace)

    def explain(self, inp, ind=None, raw_inp=None):
        """Return DeepLIFT attributions for ``inp`` w.r.t. class ``ind``."""
        self._reset_preference()
        self._baseline_forward(inp)
        g = super(DeepLIFTRescaleExplainer, self).explain(inp, ind)

        return g
81 |
--------------------------------------------------------------------------------
/tensorwatch/stream.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import weakref, uuid
5 | from typing import Any
6 | from . import utils
7 | from .lv_types import StreamItem
8 |
class Stream:
    """Pub/sub node: items written to a stream are forwarded to every subscriber.

    Subscriber and subscribed-to links are kept in WeakSets, so streams do not keep
    each other alive; ``held_refs`` can hold strong references when that is needed.
    """
    def __init__(self, stream_name:str=None, console_debug:bool=False):
        self._subscribers = weakref.WeakSet()
        self._subscribed_to = weakref.WeakSet()
        self.held_refs = set() # on some rare occasion we might want stream to hold references of other streams
        self.closed = False
        self.console_debug = console_debug
        self.stream_name = stream_name or str(uuid.uuid4()) # useful to use as key and avoid circular references
        self.items_written = 0

    def subscribe(self, stream:'Stream'): # notify other stream
        """Receive everything written to ``stream``."""
        utils.debug_log('{} added {} as subscription'.format(self.stream_name, stream.stream_name))
        stream._subscribers.add(self)
        self._subscribed_to.add(stream)

    def unsubscribe(self, stream:'Stream'):
        """Stop receiving items from ``stream`` and drop any strong reference to it."""
        utils.debug_log('{} removed {} as subscription'.format(self.stream_name, stream.stream_name))
        stream._subscribers.discard(self)
        self._subscribed_to.discard(stream)
        self.held_refs.discard(stream)
        #stream.held_refs.discard(self) # not needed as only subscriber should hold ref

    def to_stream_item(self, val:Any):
        """Wrap ``val`` in a StreamItem (if needed) and stamp stream name and index."""
        stream_item = val if isinstance(val, StreamItem) else \
            StreamItem(value=val, stream_name=self.stream_name)
        if stream_item.stream_name is None:
            stream_item.stream_name = self.stream_name
        if stream_item.item_index is None:
            stream_item.item_index = self.items_written
        return stream_item

    def write(self, val:Any, from_stream:'Stream'=None):
        # if you override write method, first you must call self.to_stream_item
        # so it can stamp the stream name
        stream_item = self.to_stream_item(val)

        if self.console_debug:
            print(self.stream_name, stream_item)

        # iterate over a snapshot: a subscriber may (un)subscribe while handling the
        # item, and mutating a WeakSet during iteration raises RuntimeError
        for subscriber in list(self._subscribers):
            subscriber.write(stream_item, from_stream=self)
        self.items_written += 1

    def read_all(self, from_stream:'Stream'=None):
        """Yield every item available from the streams this one subscribes to."""
        for subscribed_to in list(self._subscribed_to):
            for stream_item in subscribed_to.read_all(from_stream=self):
                yield stream_item

    def load(self, from_stream:'Stream'=None):
        """Ask upstream streams to replay their stored items."""
        for subscribed_to in list(self._subscribed_to):
            subscribed_to.load(from_stream=self)

    def save(self, from_stream:'Stream'=None):
        """Propagate save() to downstream subscribers."""
        for subscriber in list(self._subscribers):
            subscriber.save(from_stream=self)

    def close(self):
        """Detach from every upstream stream; safe to call more than once."""
        if not self.closed:
            for subscribed_to in list(self._subscribed_to):
                subscribed_to._subscribers.discard(self)
            self._subscribed_to.clear()
            self.closed = True

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        self.close()
77 |
78 |
--------------------------------------------------------------------------------
/tensorwatch/tensor_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Sized, Any, List
2 | import numbers
3 | import numpy as np
4 |
class TensorType:
    """String tags for the value families recognised by tensor_type()."""
    Torch = 'torch'        # torch.Tensor
    TF = 'tf'              # TensorFlow EagerTensor
    Numeric = 'numeric'    # any numbers.Number
    Numpy = 'numpy'        # numpy.ndarray
    Other = 'other'        # anything else
11 |
def tensor_type(item:Any)->str:
    """Classify ``item`` as a TensorType tag without importing torch or tensorflow.

    Torch tensors are matched by class identity ('torch'.'Tensor') anywhere in the MRO,
    so subclasses such as ``torch.nn.Parameter`` are recognised too.
    """
    item_cls = type(item)
    for klass in item_cls.__mro__:
        if getattr(klass, '__module__', None) == 'torch' and klass.__name__ == 'Tensor':
            return TensorType.Torch

    if item_cls.__module__.startswith('tensorflow') and item_cls.__name__ == 'EagerTensor':
        return TensorType.TF
    if isinstance(item, numbers.Number):
        return TensorType.Numeric
    if isinstance(item, np.ndarray):
        return TensorType.Numpy
    return TensorType.Other
26 |
def tensor2scaler(item:Any)->numbers.Number:
    """Convert a 0-dim tensor/array (torch, TF or numpy) to a Python scalar.

    None and plain numbers are returned unchanged.
    """
    kind = tensor_type(item)
    if item is None or kind == TensorType.Numeric:
        return item
    value = item.numpy() if kind == TensorType.TF else item
    return value.item() # works for torch and numpy
34 |
def tensor2np(item:Any)->np.ndarray:
    """Convert ``item`` to a numpy array; None and arrays are returned as-is."""
    kind = tensor_type(item)
    if item is None or kind == TensorType.Numpy:
        return item
    converters = {
        TensorType.TF: lambda t: t.numpy(),
        TensorType.Torch: lambda t: t.data.cpu().numpy(),
    }
    # numeric and everything else: let numpy take care of it
    return converters.get(kind, np.array)(item)
45 |
def to_scaler_list(l:Sized)->List[numbers.Number]:
    """Create list of scalars for given list of tensors where each element is 0-dim tensor.

    The element type is decided from the first element. Returns None for None input
    and [] for an empty sequence. Raises ValueError for unsupported element types.
    """
    if l is None:
        return None
    if not len(l):
        return [] # make sure we always return list type

    tt = tensor_type(l[0])
    if tt == TensorType.Torch or tt == TensorType.Numpy:
        return [i.item() for i in l]
    if tt == TensorType.TF:
        return [i.numpy().item() for i in l]
    if tt == TensorType.Numeric:
        return list(l) # convert to list in case l is not list type
    raise ValueError('Cannot convert tensor list to scalar list because list elements '
                     'are of unsupported type ' + tt)
63 |
def to_mean_list(l:Sized)->List[float]:
    """Create list of float means, one per tensor/array/number in ``l``.

    The element type is decided from the first element. Returns None for None input
    and [] for an empty sequence. Raises ValueError for unsupported element types.
    """
    if l is None:
        return None
    if not len(l):
        return []

    tt = tensor_type(l[0])
    if tt == TensorType.Torch:
        # torch .mean() is undefined for integer tensors and returns a 0-dim tensor,
        # so cast non-float tensors and unwrap the result to a Python float
        return [(i if i.is_floating_point() else i.double()).mean().item() for i in l]
    if tt == TensorType.Numpy:
        return [float(i.mean()) for i in l]
    if tt == TensorType.TF:
        return [float(i.numpy().mean()) for i in l]
    if tt == TensorType.Numeric:
        return [float(i) for i in l]
    raise ValueError('Cannot convert tensor list to mean list because list elements '
                     'are of unsupported type ' + tt)
81 |
def to_np_list(l:Sized)->List[np.ndarray]:
    """Convert every element of ``l`` to a numpy array.

    The element type is decided from the first element. Returns None for None input
    and [] for an empty sequence. Raises ValueError for unsupported element types.
    """
    if l is None:
        return None
    if not len(l):
        return []

    tt = tensor_type(l[0])
    if tt == TensorType.Numeric:
        return [np.array(i) for i in l]
    if tt == TensorType.TF:
        return [i.numpy() for i in l]
    if tt == TensorType.Torch:
        return [i.data.cpu().numpy() for i in l]
    if tt == TensorType.Numpy:
        return list(l)
    raise ValueError('Cannot convert tensor list to numpy list because list elements '
                     'are of unsupported type ' + tt)
--------------------------------------------------------------------------------
/tensorwatch/plotly/embeddings_plot.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .. import image_utils
5 | import numpy as np
6 | from .line_plot import LinePlot
7 | import time
8 | from .. import utils
9 |
class EmbeddingsPlot(LinePlot):
    """Scatter plot of (by default 3-D) embeddings built on the plotly LinePlot.

    When ``hover_images`` is given, hovering a point of the first subscribed stream
    shows ``hover_images[index]`` (optionally reshaped by ``hover_image_reshape``) in a
    small matplotlib figure appended to the cell.
    """
    def __init__(self, cell:LinePlot.widgets.Box=None, title=None, show_legend:bool=False, stream_name:str=None, console_debug:bool=False,
                 is_3d:bool=True, hover_images=None, hover_image_reshape=None, **vis_args):
        utils.set_default(vis_args, 'height', '8in')
        super(EmbeddingsPlot, self).__init__(cell, title, show_legend,
            stream_name=stream_name, console_debug=console_debug, is_3d=is_3d, **vis_args)

        import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue
        if hover_images is not None:
            plt.ioff()
            self.image_output = LinePlot.widgets.Output()
            self.image_figure = plt.figure(figsize=(2,2))
            self.image_ax = self.image_figure.add_subplot(111)
            self.cell.children += (self.image_output,)
            plt.ion()
        self.hover_images, self.hover_image_reshape = hover_images, hover_image_reshape
        self.last_ind, self.last_ind_time = -1, 0

    def hover_fn(self, trace, points, state): # pylint: disable=unused-argument
        """Plotly hover callback: show the image of the hovered point once hovering settles."""
        if not points or not points.point_inds:
            return
        ind = points.point_inds[0]
        # valid indices are 0 .. len(hover_images) - 1
        if ind == self.last_ind or ind >= len(self.hover_images) or ind < 0:
            return

        if self.last_ind == -1:
            self.last_ind, self.last_ind_time = ind, time.time()
        else:
            elapsed = time.time() - self.last_ind_time
            if elapsed < 0.3:
                self.last_ind, self.last_ind_time = ind, time.time()
                # NOTE(review): always true inside the < 0.3 branch
                if elapsed < 1:
                    return
                # else too much time since update
            # else we have stable ind

        import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue
        with self.image_output:
            plt.ioff()

            if self.hover_image_reshape:
                img = np.reshape(self.hover_images[ind], self.hover_image_reshape)
            else:
                img = self.hover_images[ind]
            if img is not None:
                LinePlot.display.clear_output(wait=True)
                self.image_ax.imshow(img)
                LinePlot.display.display(self.image_figure)
            plt.ion()

        return None

    def _create_trace(self, stream_vis):
        stream_vis.stream_vis_args.clear() #TODO remove this
        utils.set_default(stream_vis.stream_vis_args, 'draw_line', False)
        utils.set_default(stream_vis.stream_vis_args, 'draw_marker', True)
        utils.set_default(stream_vis.stream_vis_args, 'draw_marker_text', True)
        utils.set_default(stream_vis.stream_vis_args, 'hoverinfo', 'text')
        utils.set_default(stream_vis.stream_vis_args, 'marker', {})

        marker = stream_vis.stream_vis_args['marker']
        utils.set_default(marker, 'size', 6)
        utils.set_default(marker, 'colorscale', 'Jet')
        utils.set_default(marker, 'showscale', False)
        utils.set_default(marker, 'opacity', 0.8)

        return super(EmbeddingsPlot, self)._create_trace(stream_vis)

    def subscribe(self, stream, **stream_vis_args):
        """Subscribe to ``stream`` (forwarding its vis args) and hook hover on the first trace."""
        super(EmbeddingsPlot, self).subscribe(stream, **stream_vis_args)
        stream_vis = self._stream_vises[stream.stream_name]
        if stream_vis.index == 0 and self.hover_images is not None:
            self.widget.data[stream_vis.trace_index].on_hover(self.hover_fn)
--------------------------------------------------------------------------------
/tensorwatch/text_vis.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from . import utils
5 | from .vis_base import VisBase
6 |
class TextVis(VisBase):
    """Visualizer that accumulates stream values as rows of a pandas DataFrame.

    In IPython the table (or its ``describe()`` summary when the subscription used
    ``only_summary=True``) is rendered as HTML into an ``HTML`` widget; otherwise the
    most recent row is printed to the console.
    """
    def __init__(self, cell:VisBase.widgets.Box=None, title:str=None, show_legend:bool=None,
                 stream_name:str=None, console_debug:bool=False, **vis_args):
        import pandas as pd # expensive import

        super(TextVis, self).__init__(VisBase.widgets.HTML(), cell, title, show_legend,
            stream_name=stream_name, console_debug=console_debug, **vis_args)
        self.df = pd.DataFrame([])
        self.SeriesClass = pd.Series
        # DataFrame.append was removed in pandas 2.0, so rows are added via pd.concat
        self._DataFrameClass = pd.DataFrame
        self._concat = pd.concat

    def _append_row(self, row:dict)->None:
        """Append one row, given as {column name: value}, to ``self.df``."""
        self.df = self._concat([self.df, self._DataFrameClass([row])],
                               sort=False, ignore_index=True)

    def _get_column_prefix(self, stream_vis, i):
        """Column name for element ``i`` of values from ``stream_vis``, e.g. '[S.0]:1'."""
        return '[S.{}]:{}'.format(stream_vis.index, i)

    def _get_title(self, stream_vis):
        title = stream_vis.title or 'Stream ' + str(len(self._stream_vises))
        return title

    # this will be called from _show_stream_items
    def _append(self, stream_vis, vals):
        """Add one row per value.

        None/scalars go in column 0, array-likes are spread over columns 0..n-1, and
        any other object contributes its ``__dict__`` as the row.
        """
        if vals is None:
            self._append_row({self._get_column_prefix(stream_vis, 0): None})
            return
        for val in vals:
            if val is None or utils.is_scalar(val):
                self._append_row({self._get_column_prefix(stream_vis, 0): val})
            elif utils.is_array_like(val):
                self._append_row({self._get_column_prefix(stream_vis, i): val_i
                                  for i, val_i in enumerate(val)})
            else:
                self._append_row(val.__dict__)

    def _post_add_subscription(self, stream_vis, **stream_vis_args):
        only_summary = stream_vis_args.get('only_summary', False)
        stream_vis.text = self._get_title(stream_vis)
        stream_vis.only_summary = only_summary

    def clear_plot(self, stream_vis, clear_history):
        # keep the columns, drop all rows
        self.df = self.df.iloc[0:0]

    def _show_stream_items(self, stream_vis, stream_items):
        """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
        """
        for stream_item in stream_items:
            if stream_item.ended:
                self._append_row({'Ended': True})
            else:
                vals = self._extract_vals((stream_item,))
                self._append(stream_vis, vals)
        return False # dirty

    def _post_update_stream_plot(self, stream_vis):
        if VisBase.get_ipython():
            if not stream_vis.only_summary:
                self.widget.value = self.df.to_html(classes=['output_html', 'rendered_html'])
            else:
                self.widget.value = self.df.describe().to_html(classes=['output_html', 'rendered_html'])
            # below doesn't work because of threading issue
            #self.widget.clear_output(wait=True)
            #with self.widget:
            #    display.display(self.df)
        else:
            # iloc[[-1]] raises IndexError on an empty table (e.g. right after clear_plot)
            if len(self.df.index) == 0:
                return
            last_recs = self.df.iloc[[-1]].to_dict('records')
            if len(last_recs) == 1:
                print(last_recs[0])
            else:
                print(last_recs)

    def _show_widget_native(self, blocking:bool):
        return None # we will be using console

    def _show_widget_notebook(self):
        return self.widget
--------------------------------------------------------------------------------
/tensorwatch/model_graph/torchstat_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import torchstat
5 | import pandas as pd
6 |
def model_stats(model, input_shape):
    """Return torchstat's per-layer statistics for ``model`` as a formatted DataFrame.

    When ``input_shape`` has more than three entries only ``input_shape[1:4]`` is used
    (presumably dropping a leading batch dimension - TODO confirm for 5-D inputs).
    """
    shape = input_shape[1:4] if len(input_shape) > 3 else input_shape
    stat = torchstat.statistics.ModelStat(model, shape, 1)
    return _report_format(stat._analyze_model())
13 |
def _round_value(value, binary=False):
    """Format ``value`` with a K/M/G/T suffix, rounded to two decimals.

    Powers of 1024 are used when ``binary`` is true, otherwise powers of 1000.
    Values below one unit (including negatives) are returned as ``str(value)``.
    """
    base = 1024. if binary else 1000.
    for exponent, suffix in ((4, 'T'), (3, 'G'), (2, 'M'), (1, 'K')):
        scale = base ** exponent
        if value // scale > 0:
            return str(round(value / scale, 2)) + suffix
    return str(value)
26 |
27 |
def _report_format(collected_nodes):
    """Build a per-layer statistics table from torchstat nodes, plus a 'total' row.

    Each node must expose ``name``, ``input_shape``, ``output_shape``,
    ``parameter_quantity``, ``inference_memory``, ``MAdd``, ``Flops``, ``duration`` and
    ``Memory`` (a ``(read, write)`` pair).

    Side effect: widens pandas' global display options so the whole table prints.

    Returns:
        pandas.DataFrame with one row per node and a final row labelled 'total'.
        'duration' is replaced by its share 'duration[%]'; 'memory(MB)', 'duration[%]',
        'MAdd' and 'Flops' are formatted as strings, and empty cells are filled with ' '.
    """
    pd.set_option('display.width', 1000)
    pd.set_option('display.max_rows', 10000)
    pd.set_option('display.max_columns', 10000)

    def fmt_shape(shape):
        # each dimension right-aligned in a 3-char field, e.g. (1, 64) -> '  1  64'
        return ' '.join('{:>3d}'.format(e) for e in shape)

    columns = ['module name', 'input shape', 'output shape',
               'params', 'memory(MB)',
               'MAdd', 'duration', 'Flops', 'MemRead(B)', 'MemWrite(B)']
    data = []
    for node in collected_nodes:
        mread, mwrite = node.Memory
        data.append([node.name, fmt_shape(node.input_shape), fmt_shape(node.output_shape),
                     node.parameter_quantity, node.inference_memory, node.MAdd,
                     node.duration, node.Flops, mread, mwrite])
    # explicit columns also make an empty node list produce a well-formed (empty) frame
    df = pd.DataFrame(data, columns=columns)
    df['duration[%]'] = df['duration'] / (df['duration'].sum() + 1e-7)
    df['MemR+W(B)'] = df['MemRead(B)'] + df['MemWrite(B)']
    total_parameters_quantity = df['params'].sum()
    total_memory = df['memory(MB)'].sum()
    total_operation_quantity = df['MAdd'].sum()
    total_flops = df['Flops'].sum()
    total_duration = df['duration[%]'].sum()
    total_mread = df['MemRead(B)'].sum()
    total_mwrite = df['MemWrite(B)'].sum()
    total_memrw = df['MemR+W(B)'].sum()
    del df['duration']

    # Add Total row; read/write totals are the column sums (not the last node's values)
    total_df = pd.Series([total_parameters_quantity, total_memory,
                          total_operation_quantity, total_flops,
                          total_duration, total_mread, total_mwrite, total_memrw],
                         index=['params', 'memory(MB)', 'MAdd', 'Flops', 'duration[%]',
                                'MemRead(B)', 'MemWrite(B)', 'MemR+W(B)'],
                         name='total')
    # DataFrame.append was removed in pandas 2.0; concat a one-row frame labelled 'total'
    df = pd.concat([df, total_df.to_frame().T])

    df = df.fillna(' ')
    df['memory(MB)'] = df['memory(MB)'].apply(lambda x: '{:.2f}'.format(x))
    df['duration[%]'] = df['duration[%]'].apply(lambda x: '{:.2%}'.format(x))
    df['MAdd'] = df['MAdd'].apply(lambda x: '{:,}'.format(x))
    df['Flops'] = df['Flops'].apply(lambda x: '{:,}'.format(x))
    return df
--------------------------------------------------------------------------------
/tensorwatch/watcher_client.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from typing import Any, Dict, Sequence, List
5 | from .zmq_wrapper import ZmqWrapper
6 | from .lv_types import CliSrvReqTypes, ClientServerRequest, DefaultPorts
7 | from .lv_types import VisArgs, PublisherTopics, ServerMgmtMsg, StreamCreateRequest
8 | from .stream import Stream
9 | from .zmq_mgmt_stream import ZmqMgmtStream
10 | from . import utils
11 | from .watcher_base import WatcherBase
12 |
class WatcherClient(WatcherBase):
    r"""Extends watcher to add methods so calls for create and delete stream can be sent to server.

    When ``port`` is not None the client opens a client-side ``ClientServer`` socket on
    ``DefaultPorts.CliSrv + port`` and a ``ZmqMgmtStream`` that receives server management
    events and forwards create/delete stream requests. With ``port=None`` no sockets are
    opened, so ``create_stream``/``del_stream`` are not usable.
    """
    def __init__(self, filename:str=None, port:int=0):
        super(WatcherClient, self).__init__()
        self.port = port
        self.filename = filename

        # define vars in __init__
        self._clisrv = None # client-server sockets allows to send create/del stream requests
        self._zmq_srvmgmt_sub = None
        self._file = None

        self._open()

    def _reset(self):
        self._clisrv = None
        self._zmq_srvmgmt_sub = None
        self._file = None
        utils.debug_log("WatcherClient reset", verbosity=1)
        super(WatcherClient, self)._reset()

    def _open(self):
        if self.port is not None:
            self._clisrv = ZmqWrapper.ClientServer(port=DefaultPorts.CliSrv+self.port,
                is_server=False)
            # create subscription where we will receive server management events
            self._zmq_srvmgmt_sub = ZmqMgmtStream(clisrv=self._clisrv, for_write=False, port=self.port,
                stream_name='zmq_srvmgmt_sub:'+str(self.port)+':False')

    def close(self):
        if not self.closed:
            # sockets only exist when a port was given (same guards as Watcher.close)
            if self._zmq_srvmgmt_sub is not None:
                self._zmq_srvmgmt_sub.close()
            if self._clisrv is not None:
                self._clisrv.close()
            utils.debug_log("WatcherClient is closed", verbosity=1)
        super(WatcherClient, self).close()

    def devices_or_default(self, devices:Sequence[str])->Sequence[str]: # overridden
        """Expand a bare 'tcp' device to 'tcp:<port>'; with no devices, default to file then tcp."""
        # TODO: this method is duplicated in Watcher and WatcherClient

        # make sure TCP port is attached to tcp device
        if devices is not None:
            return ['tcp:' + str(self.port) if device=='tcp' else device for device in devices]

        # if no devices specified then use our filename and tcp:port as default devices
        devices = []
        # first open file device because it may have older data
        if self.filename is not None:
            devices.append('file:' + self.filename)
        if self.port is not None:
            devices.append('tcp:' + str(self.port))
        return devices

    # override to send request to server, instead of underlying WatcherBase base class
    def create_stream(self, name:str=None, devices:Sequence[str]=None, event_name:str='',
        expr=None, throttle:float=1, vis_args:VisArgs=None)->Stream: # overriden
        """Ask the server to create a stream and open a local stream on the same devices.

        Returns None when the request resolves to no devices.
        """
        stream_req = StreamCreateRequest(stream_name=name, devices=self.devices_or_default(devices),
            event_name=event_name, expr=expr, throttle=throttle, vis_args=vis_args)

        self._zmq_srvmgmt_sub.add_stream_req(stream_req)

        # devices_or_default always returns a list, so test for emptiness rather than None
        if stream_req.devices:
            stream = self.open_stream(name=stream_req.stream_name, devices=stream_req.devices)
        else: # we cannot return remote streams that are not backed by a device
            stream = None
        return stream

    # delegates to WatcherBase.open_stream
    def open_stream(self, name:str=None, devices:Sequence[str]=None)->Stream: # overriden
        return super(WatcherClient, self).open_stream(name=name, devices=devices)

    # override to send request to server
    def del_stream(self, name:str) -> None:
        self._zmq_srvmgmt_sub.del_stream(name)
90 |
--------------------------------------------------------------------------------
/docs/simple_logging.md:
--------------------------------------------------------------------------------
1 |
2 | # Simple Logging Tutorial
3 |
4 | In this tutorial, we will use a simple script that maintains two variables. Our goal is to log the values of these variables using TensorWatch and view them in real time, in a variety of ways, from a Jupyter Notebook, demonstrating various features offered by TensorWatch. We will run this script from a console; its code looks like this:
5 |
6 | ```
7 | # available in test/simple_log/sum_log.py
8 |
9 | import time, random
10 | import tensorwatch as tw
11 |
12 | # we will create two streams, one for
13 | # sums of integers, other for random numbers
14 | w = tw.Watcher()
15 | st_isum = w.create_stream('isums')
16 | st_rsum = w.create_stream('rsums')
17 |
18 | isum, rsum = 0, 0
19 | for i in range(10000):
20 | isum += i
21 | rsum += random.randint(0, i)
22 |
23 | # write to streams
24 | st_isum.write((i, isum))
25 | st_rsum.write((i, rsum))
26 |
27 | time.sleep(1)
28 | ```
29 |
30 | ## Basic Concepts
31 |
32 | The `Watcher` object allows you to create streams. You can then write any values into these streams. By default, the `Watcher` object opens TCP/IP sockets so a client can request these streams. As values are written on one end, they get serialized and sent to any client(s) on the other end. We will use a Jupyter Notebook to create a client, open these streams and feed them to various visualizers.
33 |
34 | ## Create the Client
35 | You can either create a new Jupyter Notebook or get the existing one in the repo at `notebooks/simple_logging.ipynb`.
36 |
37 | Let's first do imports:
38 |
39 |
40 | ```python
41 | %matplotlib notebook
42 | import tensorwatch as tw
43 | ```
44 |
45 | Next, open the streams that we created in the above script:
46 |
47 | ```python
48 | cli = tw.WatcherClient()
49 | st_isum = cli.open_stream('isums')
50 | st_rsum = cli.open_stream('rsums')
51 | ```
52 |
53 | Now let's visualize the isum stream in textual format:
54 |
55 |
56 | ```python
57 | text_vis = tw.Visualizer(st_isum, vis_type='text')
58 | text_vis.show()
59 | ```
60 |
61 | You should see a table growing in real time like this; each row is a tuple we wrote to the stream:
62 |
63 |
64 |
65 |
66 | That worked out well! How about feeding the same stream simultaneously to a line graph?
67 |
68 |
69 | ```python
70 | line_plot = tw.Visualizer(st_isum, vis_type='line', xtitle='i', ytitle='isum')
71 | line_plot.show()
72 | ```
73 |
74 |
75 |
76 | Let's take it a step further. Say we plot rsum as well but, just for fun, in the *same plot* as isum above. That's easy with the optional `host` parameter:
77 |
78 | ```python
79 | line_plot = tw.Visualizer(st_rsum, vis_type='line', host=line_plot, ytitle='rsum')
80 | ```
81 | This instantly changes our line plot, and a new Y-axis is automatically added for our second stream. TensorWatch allows you to add multiple streams to the same visualizer.
82 |
83 | Are you curious about the statistics of these two streams? Let's display live statistics for both streams side by side! To do this, we will use the `cell` parameter for the second visualization to place it right next to the first one:
84 |
85 | ```python
86 | istats = tw.Visualizer(st_isum, vis_type='summary')
87 | istats.show()
88 |
89 | rstats = tw.Visualizer(st_rsum, vis_type='summary', cell=istats)
90 | ```
91 |
92 | The output looks like this (each panel showing statistics for the tuple that was logged in the stream):
93 |
94 |
95 |
96 | ## Next Steps
97 |
98 | We have just scratched the surface of what you can do with TensorWatch. You should check out [Lazy Logging Tutorial](lazy_logging.md) and other [notebooks](https://github.com/microsoft/tensorwatch/tree/master/notebooks) as well.
99 |
100 | ## Questions?
101 |
102 | File a [Github issue](https://github.com/microsoft/tensorwatch/issues/new) and let us know if we can improve this tutorial.
--------------------------------------------------------------------------------
/test/mnist/cli_mnist.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import time
3 | import math
4 | from tensorwatch import utils
5 |
# raise tensorwatch's debug-log verbosity for this interactive test script
utils.set_debug_verbosity(4)

def img_in_class():
    """Show the batch images selected server-side by ``topk_all`` (order='dsc') in an image grid."""
    cli_train = tw.WatcherClient()

    # Implicit string concatenation instead of a backslash continuation inside the
    # literal, which silently embedded the next line's indentation into the expression.
    topk_expr = ("topk_all(l, batch_vals=lambda b: (b.batch.loss_all, (b.batch.input, b.batch.output), b.batch.target), "
                 "out_f=image_class_outf, order='dsc')")
    imgs = cli_train.create_stream(event_name='batch', expr=topk_expr, throttle=1)
    img_plot = tw.ImagePlot()
    img_plot.subscribe(imgs, viz_img_scale=3)
    img_plot.show()

    tw.image_utils.plt_loop()
19 |
def show_find_lr():
    """Plot batch loss against the scheduler's current learning rate (LR-finder view)."""
    cli_train = tw.WatcherClient()
    plot = tw.LinePlot()

    train_batch_loss = cli_train.create_stream(event_name='batch',
        expr='lambda d:(d.tt.scheduler.get_lr()[0], d.metrics.batch_loss)')
    # x values come from scheduler.get_lr(), so label the axis as learning rate, not epoch
    plot.subscribe(train_batch_loss, xtitle='Learning Rate', ytitle='Loss')

    utils.wait_key()
29 |
def plot_grads_plotly():
    """Plotly variant of plot_grads: per-layer mean |gradient| from 'batch' events."""
    train_cli = tw.WatcherClient()
    grads = train_cli.create_stream(event_name='batch',
        expr='lambda d:grads_abs_mean(d.model)', throttle=1)
    # LinePlot's first positional parameter is the widget cell, so pass the title by keyword
    p = tw.plotly.line_plot.LinePlot(title='Demo')
    p.subscribe(grads, xtitle='Layer', ytitle='Gradients', history_len=30, new_on_eval=True)
    utils.wait_key()
37 |
38 |
def plot_grads():
    """Plot per-layer mean |gradient| from 'batch' events (throttle=1) with a matplotlib line plot."""
    client = tw.WatcherClient()
    grads_expr = 'lambda d:grads_abs_mean(d.model)'
    grad_stream = client.create_stream(event_name='batch', expr=grads_expr, throttle=1)

    plot = tw.LinePlot()
    plot.subscribe(grad_stream, xtitle='Layer', ytitle='Gradients',
                   clear_after_each=1, history_len=40, dim_history=True)
    plot.show()

    tw.plt_loop()
49 |
def plot_weight():
    """Plot per-layer mean |weight| from 'batch' events (throttle=1) with a matplotlib line plot."""
    client = tw.WatcherClient()
    weights_expr = 'lambda d:weights_abs_mean(d.model)'
    weight_stream = client.create_stream(event_name='batch', expr=weights_expr, throttle=1)

    weight_plot = tw.LinePlot()
    weight_plot.subscribe(weight_stream, xtitle='Layer', ytitle='avg |params|',
                          clear_after_each=1, history_len=40, dim_history=True)
    weight_plot.show()

    tw.plt_loop()
60 |
def epoch_stats():
    """Plot train loss (watcher on port 0) and test accuracy (port 1) per epoch in one chart."""
    train_client, test_client = tw.WatcherClient(port=0), tw.WatcherClient(port=1)

    epoch_plot = tw.LinePlot()

    loss_stream = train_client.create_stream(event_name="epoch",
        expr='lambda v:(v.metrics.epoch_index, v.metrics.epoch_loss)')
    epoch_plot.subscribe(loss_stream, xtitle='Epoch', ytitle='Train Loss')

    acc_stream = test_client.create_stream(event_name="epoch",
        expr='lambda v:(v.metrics.epoch_index, v.metrics.epoch_accuracy)')
    epoch_plot.subscribe(acc_stream, xtitle='Epoch', ytitle='Test Accuracy', ylim=(0,1))

    epoch_plot.show()
    tw.plt_loop()
77 |
78 |
def batch_stats():
    """Plot training batch loss against ``metrics.epochf`` on 'batch' events (mpl line chart)."""
    client = tw.WatcherClient()
    loss_stream = client.create_stream(event_name="batch",
        expr='lambda v:(v.metrics.epochf, v.metrics.batch_loss)', throttle=0.75)

    loss_vis = tw.Visualizer(loss_stream, clear_after_end=False, vis_type='mpl-line',
        xtitle='Epoch', ytitle='Train Loss')

    # (disabled) a second visualizer sharing this figure:
    #train_acc = tw.Visualizer('lambda v:(v.metrics.epochf, v.metrics.epoch_loss)', event_name="batch",
    #    xtitle='Epoch', ytitle='Train Accuracy', clear_after_end=False, yrange=(0,1),
    #    vis=loss_vis, vis_type='mpl-line')

    loss_vis.show()
    tw.plt_loop()
93 |
def text_stats():
    """Show (epoch_index, batch_loss) from 'batch' events in a text table."""
    client = tw.WatcherClient()
    loss_stream = client.create_stream(event_name="batch",
        expr='lambda d:(d.metrics.epoch_index, d.metrics.batch_loss)')

    table = tw.Visualizer(loss_stream, vis_type='text')
    table.show()
    input('Paused...')  # block until Enter so the script does not exit immediately
102 |
103 |
104 |
# guard so importing this module does not start a blocking demo
if __name__ == '__main__':
    # select the demo to run by (un)commenting the calls below
    #epoch_stats()
    #plot_weight()
    #plot_grads()
    #img_in_class()
    text_stats()
    #batch_stats()
--------------------------------------------------------------------------------
/docs/lazy_logging.md:
--------------------------------------------------------------------------------
1 |
2 | # Lazy Logging Tutorial
3 |
4 | In this tutorial, we will use a straightforward script that creates an array of random numbers. Our goal is to examine this array from a Jupyter Notebook while the script below runs separately in a console.
5 |
6 | ```
7 | # available in test/simple_log/sum_lazy.py
8 |
9 | import time, random
10 | import tensorwatch as tw
11 |
12 | # create watcher object as usual
13 | w = tw.Watcher()
14 |
15 | weights = None
16 | for i in range(10000):
17 | weights = [random.random() for _ in range(5)]
18 |
19 | # let watcher observe variables we have
20 | # this has almost no performance cost
21 | w.observe(weights=weights)
22 |
23 | time.sleep(1)
24 | ```
25 |
26 | ## Basic Concepts
27 |
28 | Notice that we give the `Watcher` object an opportunity to *observe* variables. Observing variables is very cheap, so we can observe anything that we might be interested in, for example, an entire model or a batch of data. From the Jupyter Notebook, you can specify an arbitrary *lambda expression* that is executed in the context of these observed variables. The result of this lambda expression is sent back to the Jupyter Notebook as a stream so that you can render it in a visualizer of your choice. That's pretty much it.
29 |
30 | ## How Does This Work?
31 |
32 | TensorWatch has two core classes: `Watcher` and `WatcherClient`. The `Watcher` object opens TCP/IP sockets by default and listens for incoming requests. The `WatcherClient` allows you to connect to a `Watcher` and have it execute any Python [lambda expression](http://book.pythontips.com/en/latest/lambdas.html) you want. These lambda expressions consume a stream of values as input and output another stream of values. They may use [map, filter, and reduce](http://book.pythontips.com/en/latest/map_filter.html) so you can transform values in the input stream, filter them, or aggregate them. Let's see all of this with a simple example.
33 |
34 | ## Create the Client
35 | You can either create a Jupyter Notebook or get the existing one in the repo at `notebooks/lazy_logging.ipynb`.
36 |
37 | Let's first do imports:
38 |
39 |
40 | ```python
41 | %matplotlib notebook
42 | import tensorwatch as tw
43 | ```
44 |
45 | Next, create a stream using a simple lambda expression that sums values in our `weights` array that we are observing in the above script. The input to the lambda expression is an object that has all the variables we are observing.
46 |
47 | ```python
48 | client = tw.WatcherClient()
49 | stream = client.create_stream(expr='lambda d: np.sum(d.weights)')
50 | ```
51 |
52 | Next, send this stream to line chart visualization:
53 |
54 | ```python
55 | line_plot = tw.Visualizer(stream, vis_type='line')
56 | line_plot.show()
57 | ```
58 |
59 | Now, when you run our script `sum_lazy.py` in the console and then run the Jupyter Notebook cells above, you will see the sum of the `weights` array getting plotted in real time:
60 |
61 |
62 |
63 | ## What Just Happened?
64 |
65 | Notice that your script is running in a different process than the Jupyter Notebook. You could be running your Jupyter Notebook on your laptop while your script runs in a VM in the cloud. You can query observed variables at run time! As `Watcher.observe()` is a cheap call, you could observe your entire model, batch inputs and outputs, and so on. You can then slice and dice all these observed variables from the comfort of your Jupyter Notebook while your script is progressing!
66 |
67 | ## Next Steps
68 |
69 | We have just scratched the surface of this new land of lazy logging! You can do much more with this idea. For example, you can create *events* so your lambda expression gets executed only on those events. This is useful for creating streams that generate data per batch or per epoch. You can also use many predefined lambda expressions that allow you to view the top k items by some criterion.
70 |
71 | ## Questions?
72 |
73 | File a [Github issue](https://github.com/microsoft/tensorwatch/issues/new) and let us know if we can improve this tutorial.
--------------------------------------------------------------------------------
/tensorwatch/notebook_maker.py:
--------------------------------------------------------------------------------
1 | import nbformat
2 | from nbformat.v4 import new_code_cell, new_markdown_cell, new_notebook
3 | from nbformat import v3, v4
4 | import codecs
5 | from os import linesep, path
6 | import uuid
7 | from . import utils
8 | import re
9 | from typing import List
10 | from .lv_types import VisArgs
11 |
class NotebookMaker:
    """Build a Jupyter notebook whose cells reconnect to a watcher and visualize its streams.

    Args:
        watcher: object whose ``port``/``filename`` attributes (when present and different
            from WatcherClient's defaults) are reproduced in the generated
            ``tw.WatcherClient(...)`` call.
        filename: output ``.ipynb`` path; defaults to the watcher's filename with an
            ``.ipynb`` extension, else ``'tensorwatch.ipynb'``.
    """
    def __init__(self, watcher, filename:str=None)->None:
        watcher_filename = getattr(watcher, 'filename', None)
        self.filename = filename or \
            (path.splitext(watcher_filename)[0] + '.ipynb' if watcher_filename else \
            'tensorwatch.ipynb')

        self.cells = []
        self._default_vis_args = VisArgs()

        watcher_args_str = NotebookMaker._get_vis_args(watcher)

        # create initial cell
        self.cells.append(new_code_cell(source=linesep.join(
            ['%matplotlib notebook',
             'import tensorwatch as tw',
             'client = tw.WatcherClient({})'.format(watcher_args_str)])))

    @staticmethod
    def _get_vis_args(watcher)->str:
        """Return keyword-argument source such as ``"port=1, filename='x.log'"`` for the
        watcher's ``port``/``filename`` values that differ from WatcherClient defaults."""
        args_strs = []
        for param, default_v in [('port', 0), ('filename', None)]:
            if hasattr(watcher, param):
                v = getattr(watcher, param)
                if v == default_v:
                    continue # client default already matches
                args_strs.append("{}={}".format(param, NotebookMaker._val2str(v)))
        return ', '.join(args_strs)

    @staticmethod
    def _get_stream_identifier(prefix, event_name, stream_name, stream_index)->str:
        """Derive a Python identifier for a stream from its name, event and index."""
        if not stream_name or utils.is_uuid4(stream_name):
            if event_name is not None and event_name != '':
                return '{}_{}_{}'.format(prefix, event_name, stream_index)
            else:
                return prefix + str(stream_index)
        else:
            return '{}{}_{}'.format(prefix, stream_index, utils.str2identifier(stream_name)[:8])

    @staticmethod
    def _val2str(v)->str:
        """Render ``v`` as Python source text for the generated notebook code."""
        # repr() quotes and escapes strings correctly; the old "'{}'" formatting produced
        # broken or wrong literals for embedded quotes and backslashes (e.g. Windows paths)
        # TODO: shall we raise error if non str, bool, number (or its container) parameters?
        return repr(v) if isinstance(v, str) else str(v)

    def _add_vis_args_str(self, stream_info, param_strs:List[str])->None:
        """Append ``name=value`` strings to ``param_strs`` for the stream's vis_args,
        skipping those equal to the ``VisArgs()`` defaults."""
        default_args = self._default_vis_args.__dict__
        if not stream_info.req.vis_args:
            return
        for k, v in stream_info.req.vis_args.__dict__.items():
            if k in default_args:
                default_v = default_args[k]
                if (v is None and default_v is None) or (v==default_v):
                    continue # skip param if its value is not changed from default
            param_strs.append("{}={}".format(k, NotebookMaker._val2str(v)))

    def _get_stream_code(self, event_name, stream_name, stream_index, stream_info)->List[str]:
        """Return source lines opening the stream as ``s<i>`` and showing it via ``v<i>``."""
        lines = []

        stream_identifier = 's'+str(stream_index)
        lines.append("{} = client.open_stream(name={})".format(
            stream_identifier, NotebookMaker._val2str(stream_name)))

        vis_identifier = 'v'+str(stream_index)
        vis_args_strs = ['stream={}'.format(stream_identifier)]
        self._add_vis_args_str(stream_info, vis_args_strs)
        lines.append("{} = tw.Visualizer({})".format(vis_identifier, ', '.join(vis_args_strs)))
        lines.append("{}.show()".format(vis_identifier))
        return lines

    def add_streams(self, event_stream_infos)->None:
        """Add one code cell per stream; ``event_stream_infos`` maps event -> {stream name: info}."""
        stream_index = 0
        for event_name, stream_infos in event_stream_infos.items(): # per event
            for stream_name, stream_info in stream_infos.items():
                lines = self._get_stream_code(event_name, stream_name, stream_index, stream_info)
                self.cells.append(new_code_cell(source=linesep.join(lines)))
                stream_index += 1

    def write(self):
        """Write the collected cells to ``self.filename`` as an nbformat v4 notebook (UTF-8)."""
        nb = new_notebook(cells=self.cells, metadata={'language': 'python',})
        with codecs.open(self.filename, encoding='utf-8', mode='w') as f:
            nbformat.write(nb, f, 4)
88 |
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
/tensorwatch/watcher.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import uuid
5 | from typing import Sequence
6 | from .zmq_wrapper import ZmqWrapper
7 | from .watcher_base import WatcherBase
8 | from .lv_types import CliSrvReqTypes
9 | from .lv_types import DefaultPorts, PublisherTopics, ServerMgmtMsg
10 | from . import utils
11 | import threading, time
12 |
class Watcher(WatcherBase):
    """Server-side watcher that publishes streams and serves client create/delete requests.

    With a non-None ``port`` it opens a server ``ClientServer`` socket on
    ``DefaultPorts.CliSrv + port`` (requests go to ``_clisrv_callback``), opens a
    writable ``tcp:<port>`` publisher stream, and announces ``srv_name`` from a
    background thread.
    """
    def __init__(self, filename:str=None, port:int=0, srv_name:str=None):
        super(Watcher, self).__init__()

        self.port = port
        self.filename = filename

        # used to detect server restarts
        self.srv_name = srv_name or str(uuid.uuid4())

        # define vars in __init__
        self._clisrv = None
        self._zmq_stream_pub = None
        self._file = None
        self._th = None
        # set by close() so a pending server-start announcement is abandoned
        self._th_stop = threading.Event()

        self._open_devices()

    def _open_devices(self):
        if self.port is not None:
            self._clisrv = ZmqWrapper.ClientServer(port=DefaultPorts.CliSrv+self.port,
                is_server=True, callback=self._clisrv_callback)

            # notify existing listeners of our ID
            self._zmq_stream_pub = self._stream_factory.get_streams(stream_types=['tcp:'+str(self.port)], for_write=True)[0]

            # ZMQ quirk: we must wait a bit after opening port and before sending message
            # TODO: can we do better?
            self._th = threading.Thread(target=self._send_server_start)
            self._th.start()

    def _send_server_start(self):
        """Publish a ServerStart management message carrying ``srv_name`` after ~2 s.

        ``Event.wait`` returns True early if close() ran meanwhile, in which case nothing
        is sent; the publisher is re-read because _reset() sets it to None.
        """
        if self._th_stop.wait(2):
            return
        pub = self._zmq_stream_pub
        if pub is None or self.closed:
            return
        pub.write(ServerMgmtMsg(event_name=ServerMgmtMsg.EventServerStart,
            event_args=self.srv_name), topic=PublisherTopics.ServerMgmt)

    def devices_or_default(self, devices:Sequence[str])->Sequence[str]: # overriden
        """Expand a bare 'tcp' device to 'tcp:<port>'; with no devices, default to file then tcp."""
        # TODO: this method is duplicated in Watcher and WatcherClient

        # make sure TCP port is attached to tcp device
        if devices is not None:
            return ['tcp:' + str(self.port) if device=='tcp' else device for device in devices]

        # if no devices specified then use our filename and tcp:port as default devices
        devices = []
        # first open file device because it may have older data
        if self.filename is not None:
            devices.append('file:' + self.filename)
        if self.port is not None:
            devices.append('tcp:' + str(self.port))
        return devices

    def close(self):
        if not self.closed:
            # wake a pending _send_server_start so it exits without publishing
            self._th_stop.set()
            if self._clisrv is not None:
                self._clisrv.close()
            if self._zmq_stream_pub is not None:
                self._zmq_stream_pub.close()
            if self._file is not None:
                self._file.close()
            utils.debug_log("Watcher is closed", verbosity=1)
        super(Watcher, self).close()

    def _reset(self):
        self._clisrv = None
        self._zmq_stream_pub = None
        self._file = None
        self._th = None
        utils.debug_log("Watcher reset", verbosity=1)
        super(Watcher, self)._reset()

    def _clisrv_callback(self, clisrv, clisrv_req): # pylint: disable=unused-argument
        """Dispatch a client request: create_stream (returns None) or del_stream."""
        utils.debug_log("Received client request", clisrv_req.req_type)

        # request = create stream
        if clisrv_req.req_type == CliSrvReqTypes.create_stream:
            stream_req = clisrv_req.req_data
            self.create_stream(name=stream_req.stream_name, devices=stream_req.devices,
                event_name=stream_req.event_name, expr=stream_req.expr, throttle=stream_req.throttle,
                vis_args=stream_req.vis_args)
            return None # ignore return as we can't send back stream obj
        elif clisrv_req.req_type == CliSrvReqTypes.del_stream:
            stream_name = clisrv_req.req_data
            return self.del_stream(stream_name)
        else:
            raise ValueError('ClientServer Request Type {} is not recognized'.format(clisrv_req))
--------------------------------------------------------------------------------
/tensorwatch/mpl/image_plot.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .base_mpl_plot import BaseMplPlot
5 | from .. import utils, image_utils
6 | import numpy as np
7 |
8 | #from IPython import get_ipython
9 |
class ImagePlot(BaseMplPlot):
    """Matplotlib visualizer that draws stream values as a grid of image subplots.

    Each stream value is iterated as a sequence of image-list items (ImagePlotItems).
    Every item fills one cell of a ``rows x cols`` grid, and the item's images are
    stitched side by side in that cell.
    """

    def init_stream_plot(self, stream_vis,
            rows=2, cols=5, img_width=None, img_height=None, img_channels=None,
            colormap=None, viz_img_scale=None, **stream_vis_args):
        """Record grid/image settings on stream_vis and allocate empty subplot slots.

        Args:
            stream_vis: per-stream state object that receives the settings.
            rows, cols: grid size; items beyond rows*cols are not drawn.
            img_width, img_height: passed to image_utils.to_imshow_array (used for
                reshaping linearized pixel arrays).
            img_channels: stored on stream_vis (not read by this class).
            colormap: default colormap when an item does not specify one.
            viz_img_scale: optional skimage rescale factor applied before display.
        """
        stream_vis.rows, stream_vis.cols = rows, cols
        stream_vis.img_channels, stream_vis.colormap = img_channels, colormap
        stream_vis.img_width, stream_vis.img_height = img_width, img_height
        stream_vis.viz_img_scale = viz_img_scale
        # subplots holding each image
        stream_vis.axs = [[None for _ in range(cols)] for _ in range(rows)]
        # AxesImage currently shown in each subplot
        stream_vis.ax_imgs = [[None for _ in range(cols)] for _ in range(rows)]

    def clear_plot(self, stream_vis, clear_history):
        """Blank every displayed image by replacing its data with zeros of the same size."""
        for row in range(stream_vis.rows):
            for col in range(stream_vis.cols):
                img = stream_vis.ax_imgs[row][col]
                if img is not None:
                    x, y = img.get_size()
                    img.set_data(np.zeros((x, y)))

    def _show_stream_items(self, stream_vis, stream_items):
        """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.

        Only the most recent item that has not ended and has a value is drawn, because
        every call repaints the whole grid.
        """
        # as we repaint each image plot, select last if multiple events were pending
        stream_item = None
        for er in reversed(stream_items):
            if not(er.ended or er.value is None):
                stream_item = er
                break
        if stream_item is None:
            return True

        row, col, i = 0, 0, 0
        dirty = False
        # stream_item.value is expected to be ImagePlotItems
        for image_list in stream_item.value:
            # convert to imshow compatible, stitch images
            images = [image_utils.to_imshow_array(img, stream_vis.img_width, stream_vis.img_height)
                      for img in image_list.images if img is not None]
            if not images:
                # nothing to stitch (np.concatenate would raise); the next item takes this cell
                continue
            img_viz = image_utils.stitch_horizontal(images, width_dim=1)

            # resize if requested
            if stream_vis.viz_img_scale is not None:
                import skimage.transform # expensive import done on demand

                if isinstance(img_viz, np.ndarray) and np.issubdtype(img_viz.dtype, np.floating):
                    img_viz = img_viz.clip(-1, 1) # some MNIST images have out of range values causing exception in sklearn
                img_viz = skimage.transform.rescale(img_viz,
                    (stream_vis.viz_img_scale, stream_vis.viz_img_scale), mode='reflect', preserve_range=False)

            # create subplot if it doesn't exist
            ax = stream_vis.axs[row][col]
            if ax is None:
                ax = stream_vis.axs[row][col] = \
                    self.figure.add_subplot(stream_vis.rows, stream_vis.cols, i+1)
                ax.set_xticks([])
                ax.set_yticks([])

            cmap = image_list.cmap or ('Greys' if stream_vis.colormap is None and \
                len(img_viz.shape) == 2 else stream_vis.colormap)

            # drop the image drawn by the previous update: imshow adds a new AxesImage on
            # every call, so without this images pile up (growing memory and blending
            # through each other when alpha < 1)
            prev_img = stream_vis.ax_imgs[row][col]
            if prev_img is not None and prev_img in ax.get_images():
                prev_img.remove()
            stream_vis.ax_imgs[row][col] = ax.imshow(img_viz, interpolation="none", cmap=cmap, alpha=image_list.alpha)
            dirty = True

            # set title
            title = image_list.title
            if title is None:
                title = ''
            if len(title) > 12: #wordwrap if too long
                title = utils.wrap_string(title) if len(title) > 24 else title
                fontsize = 8
            else:
                fontsize = 12
            ax.set_title(title, fontsize=fontsize) #'fontweight': 'light'

            # advance to the next grid cell, row-major; stop when the grid is full
            col = col + 1
            if col >= stream_vis.cols:
                col = 0
                row = row + 1
                if row >= stream_vis.rows:
                    break
            i += 1

        return not dirty

    def has_legend(self):
        """Return whether a legend should be shown (False when show_legend is unset)."""
        return self.show_legend or False
--------------------------------------------------------------------------------
/tensorwatch/stream_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from typing import Dict, Sequence, List, Any
5 | from .zmq_stream import ZmqStream
6 | from .file_stream import FileStream
7 | from .stream import Stream
8 | from .stream_union import StreamUnion
9 | import uuid
10 | from . import utils
11 |
class StreamFactory:
    r"""Allows to create shared stream such as file and ZMQ streams

    Streams are cached by a name derived from their spec and direction, so repeated
    requests for the same TCP port or the same write file return the same object.
    Read-only file streams are never shared, so each one keeps its own seek position.
    """

    def __init__(self)->None:
        self.closed = None
        self._streams:Dict[str, Stream] = None
        self._reset()

    def _reset(self):
        self._streams:Dict[str, Stream] = {}
        self.closed = False

    def close(self):
        """Close every cached stream and forget them; safe to call more than once."""
        if not self.closed:
            for stream in self._streams.values():
                stream.close()
            self._reset()
            self.closed = True

    def __enter__(self):
        return self
    def __exit__(self, exception_type, exception_value, traceback):
        self.close()

    def get_streams(self, stream_types:Sequence[str], for_write:bool=None)->List[Stream]:
        """Return one stream per spec in stream_types (see _create_stream_by_string)."""
        streams = [self._create_stream_by_string(stream_type, for_write) for stream_type in stream_types]
        return streams

    def get_combined_stream(self, stream_types:Sequence[str], for_write:bool=None)->Stream:
        """Return the single stream for stream_types, or a StreamUnion when there are several."""
        streams = [self._create_stream_by_string(stream_type, for_write) for stream_type in stream_types]
        if len(streams) == 1:
            # the stream just created/looked up (self._streams is a dict keyed by name)
            return streams[0]
        else:
            # we create new union of child but this is not necessory
            return StreamUnion(streams, for_write=for_write)

    @staticmethod
    def _get_stream_name(stream_type:str, stream_args:Any, for_write:bool)->str:
        """Build the cache key for a stream."""
        return '{}:{}:{}'.format(stream_type, stream_args, for_write)

    def _create_stream_by_string(self, stream_spec:str, for_write:bool)->Stream:
        """Create (or reuse) the stream described by stream_spec.

        Supported specs:
            'tcp' or 'tcp:<port>'  -> shared ZmqStream (port defaults to 0)
            'file:<path>'          -> FileStream; write streams are shared, read streams are not
            '<path>'               -> same as 'file:<path>' (also Windows drive paths like 'C:/x.log')
            None or ''             -> a plain, unshared Stream()

        Raises:
            ValueError: if the stream type is not recognized.
        """
        parts = stream_spec.split(':', 1) if stream_spec is not None else ['']
        stream_type = parts[0]
        stream_args = parts[1] if len(parts) > 1 else None

        utils.debug_log("Creating stream", (stream_spec, for_write))

        if stream_type == 'tcp':
            port = int(stream_args or 0)
            stream_name = StreamFactory._get_stream_name(stream_type, port, for_write)
            if stream_name not in self._streams:
                self._streams[stream_name] = ZmqStream(for_write=for_write,
                    port=port, stream_name=stream_name, block_until_connected=False)
            # else we already have this stream
            return self._streams[stream_name]

        # must be checked before the bare-file-name fallback below, which would otherwise
        # turn an empty/None spec into a file stream with an empty file name
        if stream_type == '':
            return Stream()

        if stream_args is None: # file name specified without 'file:' prefix
            stream_args = stream_type
            stream_type = 'file'
        if len(stream_type) == 1: # windows drive letter
            stream_type = 'file'
            stream_args = stream_spec

        if stream_type == 'file':
            if stream_args is None:
                raise ValueError('File name must be specified for stream type "file"')
            stream_name = StreamFactory._get_stream_name(stream_type, stream_args, for_write)
            # each read only file stream should be separate stream or otheriwse sharing will
            # change seek positions
            if not for_write:
                stream_name += ':' + str(uuid.uuid4())

                # if write file exist then flush it before read stream would read it
                write_stream_name = StreamFactory._get_stream_name(stream_type, stream_args, True)
                write_file_stream = self._streams.get(write_stream_name, None)
                if write_file_stream:
                    write_file_stream.save()
            if stream_name not in self._streams:
                self._streams[stream_name] = FileStream(for_write=for_write,
                    file_name=stream_args, stream_name=stream_name)
            # else we already have this stream
            return self._streams[stream_name]

        raise ValueError('stream_type "{}" has unknown type'.format(stream_type))
100 |
101 |
102 |
--------------------------------------------------------------------------------
/tensorwatch/saliency/lime/wrappers/scikit_image.py:
--------------------------------------------------------------------------------
1 | import types
2 | from .generic_utils import has_arg
3 | from skimage.segmentation import felzenszwalb, slic, quickshift
4 |
5 |
class BaseWrapper(object):
    """Base class for LIME Scikit-Image wrapper


    Args:
        target_fn: callable function or class instance
        target_params: dict, parameters to pass to the target_fn


    'target_params' takes parameters required to instantiate the
    desired Scikit-Image class/model
    """

    def __init__(self, target_fn=None, **target_params):
        self.target_fn = target_fn
        self.target_params = target_params

    def _check_params(self, parameters):
        """Checks for mistakes in 'parameters'

        Args :
            parameters: dict (or other iterable of names), parameters to be checked

        Raises :
            ValueError: if any parameter is not a valid argument for the target function
            TypeError: if no target function is set and this object is not callable,
                or if parameters is a string
        """
        a_valid_fn = []
        if self.target_fn is None:
            if callable(self):
                a_valid_fn.append(self.__call__)
            else:
                raise TypeError('invalid argument: tested object is not callable, '
                                'please provide a valid target_fn')
        elif isinstance(self.target_fn, (types.FunctionType, types.MethodType)):
            a_valid_fn.append(self.target_fn)
        else:
            a_valid_fn.append(self.target_fn.__call__)

        # a string is iterable but would be checked character by character
        if isinstance(parameters, str):
            raise TypeError('invalid argument: list or dictionnary expected')
        for p in parameters:
            for fn in a_valid_fn:
                if not has_arg(fn, p):
                    raise ValueError('{} is not a valid parameter'.format(p))

    def set_params(self, **params):
        """Sets the parameters of this estimator.
        Args:
            **params: Dictionary of parameter names mapped to their values.

        Raises :
            ValueError: if any parameter is not a valid argument
                        for the target function
        """
        self._check_params(params)
        self.target_params = params

    def filter_params(self, fn, override=None):
        """Filters `target_params` and return those in `fn`'s arguments.
        Args:
            fn : arbitrary function
            override: dict, values to override target_params
        Returns:
            result : dict, dictionary containing variables
            in both target_params and fn's arguments.
        """
        override = override or {}
        result = {name: value for name, value in self.target_params.items()
                  if has_arg(fn, name)}
        result.update(override)
        return result
88 |
89 |
class SegmentationAlgorithm(BaseWrapper):
    """ Define the image segmentation function based on Scikit-Image
            implementation and a set of provided parameters

        Args:
            algo_type: string, segmentation algorithm among the following:
                'quickshift', 'slic', 'felzenszwalb'
            target_params: dict, algorithm parameters (valid model paramters
                as define in Scikit-Image documentation)

        Raises:
            ValueError: if algo_type is not one of the supported algorithms
    """

    def __init__(self, algo_type, **target_params):
        self.algo_type = algo_type
        algorithms = {'quickshift': quickshift, 'felzenszwalb': felzenszwalb, 'slic': slic}
        algo_fn = algorithms.get(algo_type)
        if algo_fn is None:
            # fail fast: previously target_fn was left unset and calling the object
            # failed later with an unrelated AttributeError
            raise ValueError("algo_type '{}' is not supported, expected one of: {}".format(
                algo_type, ', '.join(algorithms)))
        BaseWrapper.__init__(self, algo_fn, **target_params)
        # keep only the parameters the chosen skimage function accepts
        kwargs = self.filter_params(algo_fn)
        self.set_params(**kwargs)

    def __call__(self, *args):
        """Segment args[0] with the configured algorithm and parameters."""
        return self.target_fn(args[0], **self.target_params)
118 |
--------------------------------------------------------------------------------
/tensorwatch/evaler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import threading, sys, logging
5 | from collections.abc import Iterator
6 | from .lv_types import EventData
7 |
8 | # pylint: disable=unused-wildcard-import
9 | # pylint: disable=wildcard-import
10 | # pylint: disable=unused-import
11 | from functools import *
12 | from itertools import *
13 | from statistics import *
14 | import numpy as np
15 | from .evaler_utils import *
16 |
class Evaler:
    """Evaluates a Python expression over events posted from another thread.

    The expression runs on a daemon thread and sees the posted events through the
    generator variable ``l``. Each call to :meth:`post` feeds one event and blocks
    until the expression either produces a value (iterator results) or asks for the
    next event.

    NOTE(review): ``utils`` is not imported explicitly in this file; it presumably
    arrives through ``from .evaler_utils import *`` — confirm, or import it directly.
    """
    class EvalReturn:
        """Outcome of one evaluation step: value, validity flag and raised exception (if any)."""
        def __init__(self, result=None, is_valid=False, exception=None):
            self.result, self.exception, self.is_valid = \
                result, exception, is_valid
        def reset(self):
            self.result, self.exception, self.is_valid = \
                None, None, False

    class PostableIterator:
        """Generator source that is fed one event at a time by post()."""
        def __init__(self, eval_wait):
            self.eval_wait = eval_wait
            self.post_wait = threading.Event()
            self.event_data, self.ended = None, None # define attributes in init
            self.reset()

        def reset(self):
            self.event_data, self.ended = None, False
            self.post_wait.clear()

        def abort(self):
            self.ended = True
            self.post_wait.set()

        def post(self, event_data:EventData=None, ended=False):
            self.event_data, self.ended = event_data, ended
            self.post_wait.set()

        def get_vals(self):
            """Yield posted events until a post with ended=True (or abort) arrives."""
            while True:
                self.post_wait.wait()
                self.post_wait.clear()
                if self.ended:
                    break
                else:
                    yield self.event_data
                    # the expression asked for the next event, so the current one is done:
                    # wake up post(). For a reducing expression (e.g. sum) no value was
                    # produced, so post() sees result=None, is_valid=False
                    self.eval_wait.set()

    def __init__(self, expr):
        self.eval_wait = threading.Event()
        self.reset_wait = threading.Event()
        self.g = Evaler.PostableIterator(self.eval_wait)
        self.expr = expr
        self.eval_return, self.continue_thread = None, None # define in __init__
        self.reset()

        self.th = threading.Thread(target=self._runner, daemon=True, name='evaler')
        self.th.start()
        self.running = True

    def reset(self):
        self.g.reset()
        self.eval_wait.clear()
        self.reset_wait.clear()
        self.eval_return = Evaler.EvalReturn()
        self.continue_thread = True

    def _runner(self):
        """Thread body: evaluate expr once per event sequence until told to stop."""
        while True:
            # this var will be used by eval
            l = self.g.get_vals() # pylint: disable=unused-variable
            try:
                # SECURITY: expr is executed as arbitrary Python code; it may come from a
                # remote client request, so only accept expressions from trusted sources
                result = eval(self.expr) # pylint: disable=eval-used
                if isinstance(result, Iterator):
                    for item in result:
                        self.eval_return = Evaler.EvalReturn(item, True)
                else:
                    self.eval_return = Evaler.EvalReturn(result, True)
            except Exception as ex: # pylint: disable=broad-except
                logging.exception('Exception occured while evaluating expression: ' + self.expr)
                self.eval_return = Evaler.EvalReturn(None, True, ex)
            self.eval_wait.set()
            self.reset_wait.wait()
            if not self.continue_thread:
                break
            # NOTE(review): this reset runs on the runner thread and can race with a
            # post() issued immediately after the previous one returned
            self.reset()
        self.running = False
        utils.debug_log('eval runner ended!')

    def abort(self):
        utils.debug_log('Evaler Aborted')
        self.continue_thread = False
        self.g.abort()
        self.eval_wait.set()
        self.reset_wait.set()

    def post(self, event_data:EventData=None, ended=False, continue_thread=True):
        """Feed one event to the expression and wait for its outcome.

        Args:
            event_data: the event handed to the expression through ``l``.
            ended: True to end the event sequence so reducing expressions complete.
            continue_thread: False to stop the runner once the current evaluation ends.

        Returns:
            EvalReturn. If the runner thread is no longer running, an empty EvalReturn
            (result=None, is_valid=False) is returned.
        """
        if not self.running:
            utils.debug_log('post was called when Evaler is not running')
            # same type as every other return path, so callers can read .is_valid
            return Evaler.EvalReturn()
        self.eval_return.reset()
        # publish continue_thread before waking the runner: once the expression finishes,
        # the runner reads it right after reset_wait, which may already be set
        self.continue_thread = continue_thread
        self.g.post(event_data, ended)
        self.eval_wait.wait()
        self.eval_wait.clear()
        # save result before it would get reset
        eval_return = self.eval_return
        self.reset_wait.set()
        if isinstance(eval_return.result, Iterator):
            eval_return.result = list(eval_return.result)
        return eval_return

    def join(self):
        self.th.join()
123 |
--------------------------------------------------------------------------------
/tensorwatch/plotly/base_plotly_plot.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from ..vis_base import VisBase
5 | import time
6 | from abc import abstractmethod
7 | from .. import utils
8 |
9 |
class BasePlotlyPlot(VisBase):
    """Shared plumbing for visualizers drawn on a plotly ``FigureWidget``.

    Creates the widget, manages a per-stream history of traces (a new trace until
    ``history_len`` exist, then the oldest one is cleared and reused, with optional
    dimming of older traces) and provides axis defaults. Derived classes implement
    the abstract hooks at the end of the class.
    """
    def __init__(self, cell:VisBase.widgets.Box=None, title=None, show_legend:bool=None, is_3d:bool=False,
                 stream_name:str=None, console_debug:bool=False, **vis_args):
        import plotly.graph_objs as go # function-level import as this takes long time
        super(BasePlotlyPlot, self).__init__(go.FigureWidget(), cell, title, show_legend,
                                             stream_name=stream_name, console_debug=console_debug, **vis_args)

        self.is_3d = is_3d
        self.widget.layout.title = title
        # legend is shown unless explicitly disabled
        self.widget.layout.showlegend = True if show_legend is None else show_legend

    def _add_trace(self, stream_vis):
        """Create a trace for stream_vis, append it to the widget and record its index."""
        stream_vis.trace_index = len(self.widget.data)
        new_trace = self._create_trace(stream_vis)
        if stream_vis.opacity is not None:
            new_trace.opacity = stream_vis.opacity
        self.widget.add_trace(new_trace)

    def _add_trace_with_history(self, stream_vis):
        """Pick the trace that receives the next data, creating or recycling one."""
        if len(stream_vis.trace_history) < stream_vis.history_len:
            # history buffer not full yet: add a brand new trace
            self._add_trace(stream_vis)
            stream_vis.trace_history.append(len(self.widget.data) - 1)
            stream_vis.cur_history_index = len(stream_vis.trace_history) - 1
        else:
            # buffer full: step circularly to the oldest trace and reuse it
            stream_vis.cur_history_index = (stream_vis.cur_history_index + 1) % stream_vis.history_len
            stream_vis.trace_index = stream_vis.trace_history[stream_vis.cur_history_index]
            self.clear_plot(stream_vis, False)
            self.widget.data[stream_vis.trace_index].opacity = stream_vis.opacity or 1

        history_count = len(stream_vis.trace_history)
        if not stream_vis.dim_history or history_count <= 1:
            return

        top_opacity = stream_vis.opacity or 1
        dimmed_count = history_count - 1
        alphas = list(utils.frange(top_opacity*0.8, top_opacity*0.05, steps=dimmed_count))
        # visit every trace except the current one, starting right after it
        # NOTE(review): if frange runs from its first to its second argument, the oldest
        # trace gets the highest alpha — confirm this ordering is intended
        start = stream_vis.cur_history_index + 1
        for n, pos in enumerate(range(start, start + dimmed_count)):
            older_trace = stream_vis.trace_history[pos % history_count]
            self.widget.data[older_trace].opacity = alphas[n]

    @staticmethod
    def get_pallet_color(i:int):
        """Return the i-th plotly default color, cycling through the palette."""
        import plotly # function-level import as this takes long time
        palette = plotly.colors.DEFAULT_PLOTLY_COLORS
        return palette[i % len(palette)]

    @staticmethod
    def _get_axis_common_props(title:str, axis_range:tuple):
        """Build the axis properties shared by all plots, adding title/range when given."""
        props = dict(showline=True, showgrid=True, showticklabels=True, ticks='inside')
        if title:
            props['title'] = title
        if axis_range:
            props['range'] = list(axis_range)
        return props

    def _can_update_stream_plots(self):
        # throttle redraws to at most one every 0.5 s (make configurable)
        elapsed = time.time() - self.q_last_processed
        return elapsed > 0.5

    def _post_add_subscription(self, stream_vis, **stream_vis_args):
        """Initialize trace history and layout for a newly subscribed stream."""
        stream_vis.trace_history, stream_vis.cur_history_index = [], None
        self._add_trace_with_history(stream_vis)
        self._setup_layout(stream_vis)

        if not self.widget.layout.title:
            self.widget.layout.title = stream_vis.title
        # TODO: better way for below?
        if stream_vis.history_len > 1:
            # one legend entry per historical trace would be noise
            self.widget.layout.showlegend = False

    def _show_widget_native(self, blocking:bool):
        #TODO: save image, spawn browser?
        pass

    def _show_widget_notebook(self):
        # the FigureWidget renders itself in the notebook
        return None

    def _post_update_stream_plot(self, stream_vis):
        # not needed for plotly as FigureWidget stays upto date
        pass

    @abstractmethod
    def clear_plot(self, stream_vis, clear_history):
        """(for derived class) Clears the data in specified plot before new data is redrawn"""
        pass

    @abstractmethod
    def _show_stream_items(self, stream_vis, stream_items):
        """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
        """
        pass

    @abstractmethod
    def _setup_layout(self, stream_vis):
        """(for derived class) Configure widget layout/axes for stream_vis."""
        pass

    @abstractmethod
    def _create_trace(self, stream_vis):
        """(for derived class) Return a new plotly trace for stream_vis."""
        pass
110 |
--------------------------------------------------------------------------------
/tensorwatch/image_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import numpy as np
5 | import math
6 | import time
7 |
def guess_image_dims(img):
    """Guess ``(channels, height, width)`` for a linearized square image.

    For a 1-D input of N pixels the image is assumed square. It is single channel
    (e.g. MNIST) when N is a perfect square, otherwise 3 channels (e.g. CIFAR,
    ImageNet) when N/3 is a perfect square. Any other 1-D length raises ValueError.
    Inputs that are not 1-D are returned unchanged as ``img.shape``.
    """
    if len(img.shape) != 1:
        return img.shape

    pixel_count = img.shape[0]
    side = round(math.sqrt(pixel_count))
    if side * side == pixel_count:
        return (1, side, side)

    side = round(math.sqrt(pixel_count / 3))
    if side * side * 3 != pixel_count:
        raise ValueError("Cannot guess image dimensions for linearized pixels")
    return (3, side, side)
20 |
def to_imshow_array(img, width=None, height=None):
    """Convert an image into an array accepted by matplotlib's ``imshow``.

    Arrays from PyTorch have shape ``[[channels,] height, width]`` while imshow needs
    ``[height, width[, channels]]``. PIL images with at least 2 dims are returned as
    plain numpy arrays. Inputs with more than 3 dims are reduced by taking the first
    entry of each extra leading (batch) dimension.

    Args:
        img: PIL image, array with a ``.shape``, or None.
        width, height: used to reshape 1-D (linearized) pixel arrays. When either is
            missing they are guessed with guess_image_dims.

    Returns:
        The converted array, or None when img is None or has zero dimensions.
    """
    if img is None:
        return None

    try:
        from PIL import Image
    except ImportError:
        Image = None # without PIL installed, img cannot be a PIL image

    if Image is not None and isinstance(img, Image.Image):
        img = np.array(img)
        if len(img.shape) >= 2:
            return img # img is already compatible to imshow

    # force max 3 dimensions by selecting the first item of each extra leading dim.
    # Indexing with img[0] drops the dimension; slicing img[0:1,:,:] would keep it.
    # TODO allow config
    while len(img.shape) > 3:
        img = img[0]

    if len(img.shape) == 1: # linearized pixels typically used for MLPs
        if not(width and height):
            # pylint: disable=unused-variable
            channels, height, width = guess_image_dims(img)
        img = img.reshape((-1, height, width))

    if len(img.shape) == 3:
        if img.shape[0] == 1: # single channel images
            img = img.squeeze(0)
        else:
            # [C, H, W] -> [W, H, C] -> [H, W, C]
            img = np.swapaxes(img, 0, 2)
            img = np.swapaxes(img, 0, 1)
    elif len(img.shape) == 2:
        # NOTE(review): a plain 2-D array is transposed here, while a squeezed
        # single-channel [1, H, W] array above keeps its orientation — confirm intended
        img = np.swapaxes(img, 0, 1) # transpose H,W for imshow
    else: #zero dimensions
        img = None

    return img
56 |
def stitch_horizontal(images, width_dim=1):
    """Concatenate images side by side along width_dim (1 for imshow arrays, 2 for PyTorch arrays)."""
    stitched = np.concatenate(images, axis=width_dim)
    return stitched
60 |
def _resize_image(img, size=None):
    """Reshape img to size; with no size, fold a 1-D array into a near-square 2-D grid.

    Anything else (no size and not a 1-D array) is returned unchanged.
    """
    if size is None:
        if not (hasattr(img, 'shape') and len(img.shape) == 1):
            return img
        # make guess for 1-dim tensors
        n_pixels = img.shape[0]
        rows = int(math.sqrt(n_pixels))
        size = rows, int(n_pixels / rows)
    return np.reshape(img, size)
70 |
def show_image(img, size=None, alpha=None, cmap=None,
               img2=None, size2=None, alpha2=None, cmap2=None, ax=None):
    """Draw img (and optionally img2 on top of it) with imshow.

    Both images go through _resize_image with their own size. When ax is given, the
    images are drawn on it and ax is returned. Otherwise pyplot is used and the
    result of plt.show() is returned.
    """
    import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue

    base = _resize_image(img, size)
    overlay = _resize_image(img2, size2)

    target = ax or plt
    target.imshow(base, alpha=alpha, cmap=cmap)
    if overlay is not None:
        target.imshow(overlay, alpha=alpha2, cmap=cmap2)

    return ax or plt.show()
89 |
90 | img = Image.open(path)
91 | if resize is not None:
92 | img = img.resize(resize, resample)
93 | if convert_mode is not None:
94 | img = img.convert(convert_mode)
95 | return img
96 |
def img2pyt(img, add_batch_dim=True, resize=None):
    """Convert an image into a PyTorch tensor via torchvision's ToTensor.

    Args:
        img: image accepted by ``transforms.ToTensor`` (e.g. a PIL image).
        add_batch_dim: prepend a batch dimension of size 1 (in place).
        resize: optional output size. An int gives a ``resize x resize`` image
            (shorter side resized, then center-cropped); a ``(h, w)`` tuple gives
            exactly that size.
    """
    from torchvision import transforms # expensive function-level import

    ts = []
    if resize is not None:
        # deterministic resize; RandomResizedCrop is a training augmentation that
        # crops a random area/aspect ratio, making the result differ on every call
        ts.append(transforms.Resize(resize))
        ts.append(transforms.CenterCrop(resize))
    ts.append(transforms.ToTensor())
    img_pyt = transforms.Compose(ts)(img)
    if add_batch_dim:
        img_pyt.unsqueeze_(0)
    return img_pyt
108 |
def linear_to_2d(img, size=None):
    """Reshape img to size, or fold a 1-D array into a near-square (h, w) grid when size is None.

    Inputs that are not 1-D arrays are returned unchanged when size is None.
    """
    if size is not None:
        return np.reshape(img, size)
    if hasattr(img, 'shape') and len(img.shape) == 1:
        length = img.shape[0]
        h = int(math.sqrt(length))
        return np.reshape(img, (h, int(length / h)))
    return img
118 |
def stack_images(imgs):
    """Stack images horizontally (column-wise) with numpy.hstack."""
    stacked = np.hstack(imgs)
    return stacked
121 |
def plt_loop(count=None, sleep_time=1, plt_pause=0.01):
    """Keep matplotlib windows responsive until a key is pressed or count iterations pass.

    Each iteration waits plt_pause seconds for a button press, pauses pyplot, then
    sleeps sleep_time seconds. count=None loops until a key press.
    """
    import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue

    remaining = count
    while remaining is None or remaining > 0:
        # waitforbuttonpress returns True for a key press, False for a click, None on timeout
        if plt.waitforbuttonpress(plt_pause):
            break
        plt.pause(plt_pause)
        time.sleep(sleep_time)
        if remaining is not None:
            remaining -= 1
133 |
def get_cmap(name:str):
    """Return the matplotlib colormap called name (the default colormap if name is None).

    Uses ``pyplot.get_cmap``. ``matplotlib.cm.get_cmap`` (reached via ``plt.cm``) was
    deprecated in matplotlib 3.7 and removed in 3.9, while ``pyplot.get_cmap`` is
    available in both old and new versions.
    """
    import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue
    return plt.get_cmap(name)
137 |
138 |
--------------------------------------------------------------------------------
/tensorwatch/visualizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .stream import Stream
5 | from .vis_base import VisBase
6 | from . import mpl
7 | from . import plotly
8 |
class Visualizer:
    """Constructs visualizer for specified vis_type.

    Either creates a new host visualizer for vis_type, or (when host is given) adds
    the stream to host's existing visualizer. The stream is then subscribed and loaded.

    NOTE: If you modify arguments here then also sync VisArgs constructor.
    """
    def __init__(self, stream:Stream, vis_type:str=None, host:'Visualizer'=None,
            cell:'Visualizer'=None, title:str=None,
            clear_after_end=False, clear_after_each=False, history_len=1, dim_history=True, opacity=None,

            rows=2, cols=5, img_width=None, img_height=None, img_channels=None,
            colormap=None, viz_img_scale=None,

            # these image params are for hover on point for t-sne
            hover_images=None, hover_image_reshape=None, cell_width:str=None, cell_height:str=None,

            only_summary=False, separate_yaxis=True, xtitle=None, ytitle=None, ztitle=None, color=None,
            xrange=None, yrange=None, zrange=None, draw_line=True, draw_marker=False,

            # histogram
            bins=None, normed=None, histtype='bar', edge_color=None, linewidth=None, bar_width=None,

            # pie chart
            autopct=None, shadow=None,

            vis_args=None, stream_vis_args=None)->None:

        # None defaults avoid sharing one mutable dict across all calls
        vis_args = {} if vis_args is None else vis_args
        stream_vis_args = {} if stream_vis_args is None else stream_vis_args

        cell = cell._host_base.cell if cell is not None else None

        if host:
            self._host_base = host._host_base
        else:
            self._host_base = self._get_vis_base(vis_type, cell, title, hover_images=hover_images, hover_image_reshape=hover_image_reshape,
                cell_width=cell_width, cell_height=cell_height,
                **vis_args)

        self._host_base.subscribe(stream, show=False, clear_after_end=clear_after_end, clear_after_each=clear_after_each,
            history_len=history_len, dim_history=dim_history, opacity=opacity,
            # vis_type 'summary' always forces only_summary
            only_summary=only_summary if vis_type is None or 'summary' != vis_type else True,
            separate_yaxis=separate_yaxis, xtitle=xtitle, ytitle=ytitle, ztitle=ztitle, color=color,
            xrange=xrange, yrange=yrange, zrange=zrange,
            # draw_line is only honored for scatter types; all other types draw lines
            draw_line=draw_line if vis_type is not None and 'scatter' in vis_type else True,
            draw_marker=draw_marker,
            rows=rows, cols=cols, img_width=img_width, img_height=img_height, img_channels=img_channels,
            colormap=colormap, viz_img_scale=viz_img_scale,
            bins=bins, normed=normed, histtype=histtype, edge_color=edge_color, linewidth=linewidth, bar_width = bar_width,
            autopct=autopct, shadow=shadow,
            **stream_vis_args)

        stream.load()

    def show(self):
        return self._host_base.show()

    def _get_vis_base(self, vis_type, cell:VisBase.widgets.Box, title, hover_images=None, hover_image_reshape=None, cell_width=None, cell_height=None, **vis_args)->VisBase:
        """Instantiate the VisBase implementation that renders vis_type.

        Raises:
            ValueError: if vis_type is not recognized.
        """
        if vis_type is None or vis_type in ['line',
                        'mpl-line', 'mpl-line3d', 'mpl-scatter3d', 'mpl-scatter']:
            return mpl.line_plot.LinePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height,
                                          is_3d=vis_type is not None and vis_type.endswith('3d'), **vis_args)
        if vis_type in ['image', 'mpl-image']:
            return mpl.image_plot.ImagePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
        if vis_type in ['bar', 'bar3d']:
            return mpl.bar_plot.BarPlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height,
                                        is_3d=vis_type.endswith('3d'), **vis_args)
        if vis_type in ['histogram']:
            return mpl.histogram.Histogram(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
        if vis_type in ['pie']:
            return mpl.pie_chart.PieChart(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
        if vis_type in ['text', 'summary']:
            from .text_vis import TextVis
            return TextVis(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
        if vis_type in ['line3d', 'scatter', 'scatter3d',
                        'plotly-line', 'plotly-line3d', 'plotly-scatter', 'plotly-scatter3d', 'mesh3d']:
            return plotly.line_plot.LinePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height,
                                             is_3d=vis_type.endswith('3d'), **vis_args)
        if vis_type in ['tsne', 'embeddings', 'tsne2d', 'embeddings2d']:
            return plotly.embeddings_plot.EmbeddingsPlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height,
                                                         is_3d='2d' not in vis_type,
                                                         hover_images=hover_images, hover_image_reshape=hover_image_reshape, **vis_args)
        raise ValueError('Render vis_type parameter has invalid value: "{}"'.format(vis_type))
89 |
--------------------------------------------------------------------------------
/tensorwatch/saliency/backprop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.autograd import Variable, Function
3 | import torch
4 | import types
5 |
6 |
class VanillaGradExplainer(object):
    """Explains a prediction by the gradient of the selected output score w.r.t. the input."""
    def __init__(self, model):
        self.model = model

    def _backprop(self, inp, ind):
        """Backpropagate the selected output score of every sample to the input.

        Args:
            inp: input batch; its ``requires_grad`` is switched on and its existing
                ``.grad`` (if any) is zeroed.
            ind: tensor of target output indices, one per sample, or None to use
                the argmax of the model output.

        Returns:
            ``inp.grad`` after the backward pass.
        """
        inp.requires_grad = True
        if inp.grad is not None:
            inp.grad.zero_()
        # ind defaults to None (see explain), so check it before touching .grad
        if ind is not None and ind.grad is not None:
            ind.grad.zero_()
        self.model.eval()
        self.model.zero_grad()

        output = self.model(inp)
        if ind is None:
            ind = output.max(1)[1]
        # one-hot upstream gradient selecting output[b, ind[b]] for every sample b
        grad_out = torch.zeros_like(output)
        grad_out.scatter_(1, ind.unsqueeze(0).t(), 1.0)
        output.backward(grad_out)
        return inp.grad

    def explain(self, inp, ind=None, raw_inp=None):
        """Return the input gradient for target ind (argmax class when None)."""
        return self._backprop(inp, ind)
31 |
32 |
class GradxInputExplainer(VanillaGradExplainer):
    """Gradient * input attribution: the input gradient scaled elementwise by the input."""
    def __init__(self, model):
        super(GradxInputExplainer, self).__init__(model)

    def explain(self, inp, ind=None, raw_inp=None):
        return inp * self._backprop(inp, ind)
40 |
41 |
class SaliencyExplainer(VanillaGradExplainer):
    """Saliency map: absolute value of the input gradient."""
    def __init__(self, model):
        super(SaliencyExplainer, self).__init__(model)

    def explain(self, inp, ind=None, raw_inp=None):
        return self._backprop(inp, ind).abs()
49 |
50 |
class IntegrateGradExplainer(VanillaGradExplainer):
    """Integrated Gradients with an all-zero baseline.

    The path integral is approximated by a right Riemann sum over ``steps`` points
    alpha = 1/steps, 2/steps, ..., 1. The sum of the gradients at ``alpha * inp`` is
    divided by ``steps`` and multiplied elementwise by ``inp``.
    """
    def __init__(self, model, steps=100):
        super(IntegrateGradExplainer, self).__init__(model)
        self.steps = steps

    def explain(self, inp, ind=None, raw_inp=None):
        inp_data = inp.clone()
        if ind is None:
            # fix the target class from the real input; re-picking the argmax at every
            # scaled input would mix gradients of different classes along the path
            self.model.eval()
            with torch.no_grad():
                ind = self.model(inp_data).max(1)[1]

        grad = 0
        # include the endpoint alpha == 1 so there are exactly `steps` terms
        for step in range(1, self.steps + 1):
            alpha = step / self.steps
            new_inp = (inp_data * alpha).detach().requires_grad_(True)
            grad += self._backprop(new_inp, ind)

        return grad * inp_data / self.steps
66 |
67 |
class DeconvExplainer(VanillaGradExplainer):
    """Deconvnet attribution: every ReLU's backward pass keeps only positive gradients.

    Note that the model's ReLU modules are patched in place when this explainer is built.
    """
    def __init__(self, model):
        super(DeconvExplainer, self).__init__(model)
        self._override_backward()

    def _override_backward(self):
        class _ReLU(Function):
            @staticmethod
            def forward(ctx, x):
                return torch.clamp(x, min=0)

            @staticmethod
            def backward(ctx, grad_output):
                # deconvnet rule: ignore the forward activation, rectify the gradient
                return torch.clamp(grad_output, min=0)

        def deconv_forward(module, x):
            return _ReLU.apply(x)

        def patch_relu(module):
            # matched by class name so any ReLU module class named 'ReLU' is patched
            if module.__class__.__name__ == 'ReLU':
                module.forward = types.MethodType(deconv_forward, module)

        self.model.apply(patch_relu)
93 |
94 |
class GuidedBackpropExplainer(VanillaGradExplainer):
    """Guided-backpropagation attribution.

    On construction, every submodule of ``model`` whose class is named
    ``'ReLU'`` gets its ``forward`` replaced in place. In the backward pass,
    gradient then flows only where both the forward activation and the
    incoming gradient are positive.
    """

    def __init__(self, model):
        super(GuidedBackpropExplainer, self).__init__(model)
        self._override_backward()

    def _override_backward(self):
        class _ReLU(Function):
            @staticmethod
            def forward(ctx, input):
                output = torch.clamp(input, min=0)
                ctx.save_for_backward(output)
                return output

            @staticmethod
            def backward(ctx, grad_output):
                output, = ctx.saved_tensors
                # masks follow grad_output's dtype (not hard-coded float32)
                forward_mask = (output > 0).type_as(grad_output)
                backward_mask = (grad_output > 0).type_as(grad_output)
                # Build a new tensor instead of overwriting grad_output in place:
                # autograd may share that gradient buffer with other branches.
                return forward_mask * backward_mask * grad_output

        def new_forward(self, x):
            return _ReLU.apply(x)

        def replace(m):
            if m.__class__.__name__ == 'ReLU':
                m.forward = types.MethodType(new_forward, m)

        self.model.apply(replace)
125 |
126 |
# modified from https://github.com/PAIR-code/saliency/blob/master/saliency/base.py#L80
class SmoothGradExplainer(object):
    """SmoothGrad: average a base explainer's attribution over noisy inputs.

    model: network used to build the default base explainer
        (``VanillaGradExplainer``) when ``base_explainer`` is not given.
    base_explainer: object with ``explain(inp, ind)`` returning a tensor.
    stdev_spread: noise std as a fraction of the input range (max - min).
    nsamples: number of noisy samples to average.
    magnitude: if True, average squared attributions instead of raw ones.
    """

    def __init__(self, model, base_explainer=None, stdev_spread=0.15,
                 nsamples=25, magnitude=True):
        self.base_explainer = base_explainer or VanillaGradExplainer(model)
        self.stdev_spread = stdev_spread
        self.nsamples = nsamples
        self.magnitude = magnitude

    def explain(self, inp, ind=None, raw_inp=None):
        """Return the mean (squared, if ``magnitude``) attribution over noisy copies."""
        stdev = self.stdev_spread * (inp.max() - inp.min())

        total_gradients = 0

        for _ in range(self.nsamples):
            noise = torch.randn_like(inp) * stdev

            # Detach so each noisy sample is a fresh leaf tensor. The base
            # explainers set requires_grad on their input, which is only
            # allowed on leaves. The previous retain_grad() call raised when
            # inp did not require grad.
            noisy_inp = (inp + noise).detach()
            grad = self.base_explainer.explain(noisy_inp, ind)

            if self.magnitude:
                total_gradients += grad ** 2
            else:
                total_gradients += grad

        return total_gradients / self.nsamples
--------------------------------------------------------------------------------
/tensorwatch/saliency/saliency.py:
--------------------------------------------------------------------------------
1 | from .gradcam import GradCAMExplainer
2 | from .backprop import VanillaGradExplainer, GradxInputExplainer, SaliencyExplainer, \
3 | IntegrateGradExplainer, DeconvExplainer, GuidedBackpropExplainer, SmoothGradExplainer
4 | from .deeplift import DeepLIFTRescaleExplainer
5 | from .occlusion import OcclusionExplainer
6 | from .epsilon_lrp import EpsilonLrp
7 | from .lime_image_explainer import LimeImageExplainer, LimeImagenetExplainer
8 | import skimage.transform
9 | import torch
10 | import math
11 | from .. import image_utils
12 |
class ImageSaliencyResult:
    """One saliency map paired with its source image, ready for plotting.

    raw_image: the original image, drawn underneath the map.
    saliency: the saliency map produced by an explainer.
    title: caption for the subplot (typically the method name).
    saliency_alpha: opacity of the saliency overlay.
    saliency_cmap: colormap name for the saliency overlay.
    """

    def __init__(self, raw_image, saliency, title, saliency_alpha=0.4, saliency_cmap='jet'):
        self.raw_image = raw_image
        self.saliency = saliency
        self.title = title
        self.saliency_alpha = saliency_alpha
        self.saliency_cmap = saliency_cmap
17 |
18 | def _get_explainer(explainer_name, model, layer_path=None):
19 | if explainer_name == 'gradcam':
20 | return GradCAMExplainer(model, target_layer_name_keys=layer_path, use_inp=True)
21 | if explainer_name == 'vanilla_grad':
22 | return VanillaGradExplainer(model)
23 | if explainer_name == 'grad_x_input':
24 | return GradxInputExplainer(model)
25 | if explainer_name == 'saliency':
26 | return SaliencyExplainer(model)
27 | if explainer_name == 'integrate_grad':
28 | return IntegrateGradExplainer(model)
29 | if explainer_name == 'deconv':
30 | return DeconvExplainer(model)
31 | if explainer_name == 'guided_backprop':
32 | return GuidedBackpropExplainer(model)
33 | if explainer_name == 'smooth_grad':
34 | return SmoothGradExplainer(model)
35 | if explainer_name == 'deeplift':
36 | return DeepLIFTRescaleExplainer(model)
37 | if explainer_name == 'occlusion':
38 | return OcclusionExplainer(model)
39 | if explainer_name == 'lrp':
40 | return EpsilonLrp(model)
41 | if explainer_name == 'lime_imagenet':
42 | return LimeImagenetExplainer(model)
43 |
44 | raise ValueError('Explainer {} is not recognized'.format(explainer_name))
45 |
46 | def _get_layer_path(model):
47 | if model.__class__.__name__ == 'VGG':
48 | return ['features', '30'] # pool5
49 | elif model.__class__.__name__ == 'GoogleNet':
50 | return ['pool5']
51 | elif model.__class__.__name__ == 'ResNet':
52 | return ['avgpool'] #layer4
53 | elif model.__class__.__name__ == 'Inception3':
54 | return ['Mixed_7c', 'branch_pool'] # ['conv2d_94'], 'mixed10'
55 | else: #TODO: guess layer for other networks?
56 | return None
57 |
def get_saliency(model, raw_input, input, label, method='integrate_grad', layer_path=None):
    """Compute a normalized saliency map for ``input`` using explainer ``method``.

    Moves ``model``, ``input`` and ``label`` (if given) to CUDA when available,
    else CPU. Puts the model in eval mode and clears stale gradients.
    ``layer_path`` defaults to the architecture-specific guess from
    ``_get_layer_path``.

    Returns a numpy array scaled to [0, 1), or None if the explainer returns
    None. The map is |saliency| summed over dim 1 for the first batch item
    (presumably channels of an NCHW batch - TODO confirm).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)
    input = input.to(device)
    has_label = label is not None
    if has_label:
        label = label.to(device)

    # drop gradients left over from earlier runs
    if input.grad is not None:
        input.grad.zero_()
    if has_label and label.grad is not None:
        label.grad.zero_()
    model.eval()
    model.zero_grad()

    explainer = _get_explainer(method, model, layer_path or _get_layer_path(model))
    saliency = explainer.explain(input, label, raw_input)
    if saliency is None:
        return None

    saliency = saliency.abs().sum(dim=1)[0].squeeze()
    saliency -= saliency.min()
    saliency /= (saliency.max() + 1e-20)  # epsilon guards an all-zero map
    return saliency.detach().cpu().numpy()
86 |
def get_image_saliency_results(model, raw_image, input, label,
                               methods=('lime_imagenet', 'gradcam', 'smooth_grad',
                                        'guided_backprop', 'deeplift', 'grad_x_input'),
                               layer_path=None):
    """Run ``get_saliency`` for each method and collect the non-None results.

    layer_path: forwarded to every ``get_saliency`` call. It was previously
        accepted but silently ignored.
    methods: iterable of explainer names. The default is now an immutable
        tuple rather than a shared mutable list.

    Returns a list of ``ImageSaliencyResult`` titled by method name.
    """
    results = []
    for method in methods:
        sal = get_saliency(model, raw_image, input, label, method=method,
                           layer_path=layer_path)

        if sal is not None:
            results.append(ImageSaliencyResult(raw_image, sal, method))
    return results
98 |
def get_image_saliency_plot(image_saliency_results, cols:int=2, figsize:tuple=None):
    """Lay out saliency results in a grid of ``cols`` columns and return the figure.

    Each subplot shows the raw image overlaid with its saliency map, resized to
    the image's (height, width). ``figsize`` defaults to (8, 3 * rows).
    """
    import matplotlib.pyplot as plt  # delayed import due to matplotlib threading issue

    n_rows = math.ceil(len(image_saliency_results) / cols)
    if not figsize:
        figsize = (8, 3 * n_rows)
    figure = plt.figure(figsize=figsize)

    for position, result in enumerate(image_saliency_results, start=1):
        ax = figure.add_subplot(n_rows, cols, position)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(result.title, fontdict={'fontsize': 24})

        target_shape = (result.raw_image.height, result.raw_image.width)
        upsampled = skimage.transform.resize(result.saliency, target_shape, mode='reflect')

        image_utils.show_image(result.raw_image, img2=upsampled,
                               alpha2=result.saliency_alpha, cmap2=result.saliency_cmap, ax=ax)
    return figure
120 |
--------------------------------------------------------------------------------
/tensorwatch/model_graph/hiddenlayer/pytorch_builder.py:
--------------------------------------------------------------------------------
1 | """
2 | HiddenLayer
3 |
4 | PyTorch graph importer.
5 |
6 | Written by Waleed Abdulla
7 | Licensed under the MIT License
8 | """
9 |
10 | from __future__ import absolute_import, division, print_function
11 | import re
12 | from .graph import Graph, Node
13 | from . import transforms as ht
14 | import torch
15 | from collections import abc
16 | import numpy as np
17 |
# PyTorch Graph Transforms
# Op-level rewrites for graphs imported from PyTorch traces. The entries after
# the first match names without the "onnx::" prefix, so they presumably rely
# on these transforms being applied in list order - TODO confirm against the
# caller.
FRAMEWORK_TRANSFORMS = [
    # Hide onnx: prefix
    ht.Rename(op=r"onnx::(.*)", to=r"\1"),
    # ONNX uses Gemm for linear layers (stands for General Matrix Multiplication).
    # It's an odd name that no one recognizes. Rename it.
    ht.Rename(op=r"Gemm", to=r"Linear"),
    # PyTorch layers that don't have an ONNX counterpart
    ht.Rename(op=r"aten::max\_pool2d\_with\_indices", to="MaxPool"),
    # Shorten op name
    ht.Rename(op=r"BatchNormalization", to="BatchNorm"),
]
30 |
31 |
def dump_pytorch_graph(graph):
    """Print a table of every node in a PyTorch graph (debugging aid).

    Columns: node kind, scope name, unique ids of input values, unique ids of
    output values.
    """
    row = "{:25} {:40} {} -> {}"
    print(row.format("kind", "scopeName", "inputs", "outputs"))
    for node in graph.nodes():
        input_ids = [value.unique() for value in node.inputs()]
        output_ids = [value.unique() for value in node.outputs()]
        print(row.format(node.kind(), node.scopeName(), input_ids, output_ids))
41 |
42 |
def pytorch_id(node):
    """Return a unique ID for a node.

    The ID is the node's scope name, then ``/outputs/``, then the names of its
    output values joined with ``/``.
    """
    # After ONNX simplification, the scopeName is not unique anymore,
    # so append node outputs to guarantee uniqueness.
    output_names = []
    for value in node.outputs():
        # Value.uniqueName() was renamed to Value.debugName() in newer PyTorch
        # releases; use whichever this build provides.
        get_name = getattr(value, "uniqueName", None) or value.debugName
        output_names.append(get_name())
    return node.scopeName() + "/outputs/" + "/".join(output_names)
48 |
49 |
def get_shape(torch_node):
    """Return the output shape of the given PyTorch node, or None.

    The shape is parsed from the text of the node's first output value, e.g.
    ``Float(1, 64, 112, 112)`` -> ``(1, 64, 112, 112)``. Only leading integer
    dimensions are kept, so trailing annotations such as ``strides=[...]`` are
    ignored. Returns None when the node has no outputs, has no ``Float(...)``
    annotation, or no integer dimension can be read.
    """
    # Extract node output shape from the node string representation.
    # This is a hack because there doesn't seem to be an official way to do it.
    # See my question in the PyTorch forum:
    # https://discuss.pytorch.org/t/node-output-shape-from-trace-graph/24351/2
    # TODO: find a better way to extract output shape
    # TODO: Assuming the node has one output. Update if we encounter a multi-output node.
    first_output = next(iter(torch_node.outputs()), None)
    if first_output is None:
        return None  # previously a bare next() leaked StopIteration here

    m = re.match(r".*Float\(([^)]*)\)", str(first_output))
    if not m:
        return None

    dims = []
    for part in m.group(1).split(","):
        part = part.strip()
        if not part.isdecimal():
            break  # stop at annotations like "strides=[...]" or an empty tail
        dims.append(int(part))
    return tuple(dims) if dims else None
66 |
def calc_rf(model, input_shape):
    """Empirically measure the receptive field of ``model``.

    Procedure:
    - Every trainable parameter whose name contains 'bias' is temporarily set
      to 0; any other whose name contains 'weight' is set to 1.
    - An all-ones tensor of ``input_shape`` is fed through the model.
    - A unit gradient is backpropagated from the centre of output (batch 0,
      channel 0).
    - The extent of non-zero gradient in ``input.grad[0, 0]`` is measured.

    The parameters are restored afterwards; the previous version left the model
    with its weights overwritten. Buffers (e.g. normalization running
    statistics) are not saved or restored.

    Returns a list with one receptive-field size per axis of
    ``input.grad[0, 0]``.
    """
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    saved = [p.detach().clone() for _, p in trainable]

    try:
        for n, p in trainable:
            if 'bias' in n:
                p.data.fill_(0)
            elif 'weight' in n:
                p.data.fill_(1)

        inp = torch.ones(input_shape, requires_grad=True)
        output = model(inp)
        out_shape = output.size()
        # index 0 for the first two axes (batch, channel), centre for the rest;
        # integer division keeps the index valid (the old `/ 2` gave floats)
        centre = tuple(0 if axis < 2 else size // 2 for axis, size in enumerate(out_shape))
        grad = torch.zeros(out_shape)
        grad[centre] = 1
        output.backward(gradient=grad)
    finally:
        with torch.no_grad():
            for (_, p), original in zip(trainable, saved):
                p.copy_(original)

    grad_np = inp.grad[0, 0].numpy()
    idx_nonzeros = np.where(grad_np != 0)
    return [int(np.max(idx) - np.min(idx) + 1) for idx in idx_nonzeros]
95 |
def import_graph(hl_graph, model, args, input_names=None, verbose=False):
    """Trace ``model`` and add its nodes and edges to ``hl_graph``.

    hl_graph: HiddenLayer graph that receives the nodes/edges; it is returned.
    model: PyTorch module to trace.
    args: input tensor, or an array-like shape used to build an all-ones
        tensor. None means ``[1, 3, 224, 224]``.
    input_names: currently unused (see TODO).
    verbose: if True, print the traced graph's node table.
    """
    # TODO: add input names to graph

    if args is None:
        args = [1, 3, 224, 224]  # assume ImageNet default

    # If args is not a Tensor but is array-like, convert it to a torch tensor.
    # (bytes, bytearray) replaces the deprecated collections.abc.ByteString.
    if not isinstance(args, torch.Tensor) and \
            hasattr(args, "__len__") and hasattr(args, '__getitem__') and \
            not isinstance(args, (str, bytes, bytearray)):
        args = torch.ones(args)

    # Run the Pytorch graph to get a trace and generate a graph from it
    trace, out = torch.jit.get_trace_graph(model, args)
    torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
    torch_graph = trace.graph()

    # Dump list of nodes (DEBUG only)
    if verbose:
        dump_pytorch_graph(torch_graph)

    torch_nodes = list(torch_graph.nodes())
    # Index consumers once (value id -> positions of nodes reading it). This
    # replaces the previous per-node rescan of every node (O(n^2)).
    consumers = {}
    for position, node in enumerate(torch_nodes):
        for value_id in {i.unique() for i in node.inputs()}:
            consumers.setdefault(value_id, []).append(position)

    # Loop through nodes and build HL graph
    for torch_node in torch_nodes:
        op = torch_node.kind()
        params = {k: torch_node[k] for k in torch_node.attributeNames()}
        # TODO: inputs = [i.unique() for i in node.inputs()]
        outputs = [o.unique() for o in torch_node.outputs()]
        shape = get_shape(torch_node)
        hl_node = Node(uid=pytorch_id(torch_node), name=None, op=op,
                       output_shape=shape, params=params)
        hl_graph.add_node(hl_node)
        # One edge per consuming node, emitted in graph order as before
        targets = sorted({t for o in outputs for t in consumers.get(o, ())})
        for t in targets:
            hl_graph.add_edge_by_id(pytorch_id(torch_node), pytorch_id(torch_nodes[t]), shape)
    return hl_graph
138 |
--------------------------------------------------------------------------------
/tensorwatch/model_graph/hiddenlayer/ge.py:
--------------------------------------------------------------------------------
1 | """
2 | HiddenLayer
3 |
4 | Implementation of graph expressions to find nodes in a graph based on a pattern.
5 |
6 | Written by Waleed Abdulla
7 | Licensed under the MIT License
8 | """
9 |
10 | import re
11 |
12 |
13 |
class GEParser():
    """Recursive-descent parser for graph expressions (GE).

    Grammar, as implemented by the methods below:
        parse      := serial | parallel | expression
        serial     := expression (">" expression)+
        parallel   := expression ("|" expression)+
        expression := "(" (serial | parallel | op) ")" | op
        op         := <word characters> condition?
        condition  := "[" ("1x1" | "3x3") "]"   (parsed but not used yet)

    ``self.index`` is the read position in ``self.text``. The composite rules
    rewind it to their start position when they fail.
    """

    def __init__(self, text):
        self.index = 0
        self.text = text

    def parse(self):
        """Return the pattern for the whole expression, or None."""
        for rule in (self.serial, self.parallel, self.expression):
            pattern = rule()
            if pattern:
                return pattern
        return None

    def _joined(self, separator, pattern_cls):
        """Parse 2+ expressions separated by ``separator`` into ``pattern_cls``."""
        start = self.index
        parts = []
        while not parts or self.token(separator):
            part = self.expression()
            if not part:
                break
            parts.append(part)
        if len(parts) >= 2:
            return pattern_cls(parts)
        self.index = start  # no match: rewind
        return None

    def parallel(self):
        return self._joined("|", ParallelPattern)

    def serial(self):
        return self._joined(">", SerialPattern)

    def expression(self):
        start = self.index
        if self.token("("):
            inner = self.serial() or self.parallel() or self.op()
            if inner and self.token(")"):
                return inner
            self.index = start  # not a valid group: rewind past "("
        return self.op()

    def op(self):
        name = self.re(r"\w+")
        if not name:
            return None
        return NodePattern(name, self.condition())

    def condition(self):
        # TODO: not implemented yet. This function is a placeholder
        start = self.index
        if self.token("["):
            size = self.token("1x1") or self.token("3x3")
            if size and self.token("]"):
                return size
        self.index = start
        return None

    def token(self, s):
        """Consume ``s`` (with surrounding whitespace); return ``s`` or None."""
        return self.re(r"\s*({})\s*".format(re.escape(s)), 1)

    def string(self, s):
        """Consume the literal ``s`` at the current position; return it or None."""
        end = self.index + len(s)
        if self.text[self.index:end] != s:
            return None
        self.index = end
        return s

    def re(self, regex, group=0):
        """Match ``regex`` at the current position.

        On success, advance past the whole match and return ``group``;
        otherwise return None.
        """
        found = re.match(regex, self.text[self.index:])
        if found is None:
            return None
        self.index += len(found.group(0))
        return found.group(group)
88 |
89 |
class NodePattern():
    """Pattern matching a single graph node by its ``op`` name.

    ``condition`` is stored but not consulted when matching.
    """

    def __init__(self, op, condition=None):
        self.op = op
        self.condition = condition  # TODO: not implemented yet

    def match(self, graph, node):
        """Return ``([node], following)`` on a match, else ``([], None)``.

        ``following`` is the sole successor when ``graph.outgoing(node)`` has
        exactly one element, otherwise that whole list. A list of nodes never
        matches.
        """
        if isinstance(node, list) or self.op != node.op:
            return [], None
        successors = graph.outgoing(node)
        following = successors[0] if len(successors) == 1 else successors
        return [node], following
105 |
106 |
class SerialPattern():
    """Pattern matching a chain of sub-patterns.

    Each sub-pattern is applied to whatever the previous one reported as
    ``following``.
    """

    def __init__(self, patterns):
        self.patterns = patterns

    def match(self, graph, node):
        """Return ``(matched_nodes, following)`` or ``([], None)`` if any link fails."""
        matched = []
        current = node
        for pattern in self.patterns:
            found, following = pattern.match(graph, current)
            if not found:
                return [], None
            matched.extend(found)
            current = following  # might be more than one node
        return matched, following
121 |
122 |
class ParallelPattern():
    """Pattern matching sibling branches that must all end in the same node.

    ``patterns`` is a list of sub-patterns, one per branch. Each sub-pattern is
    paired with the first still-unmatched candidate node it matches.
    """
    def __init__(self, patterns):
        self.patterns = patterns

    def match(self, graph, nodes):
        """Return ``(matched_nodes, end_node)`` or ``([], None)``.

        nodes: a single node (its siblings per ``graph.siblings`` become the
            candidates) or a list of candidate nodes that must share the same
            incoming nodes.
        """
        if not nodes:
            return [], None
        nodes = nodes if isinstance(nodes, list) else [nodes]
        # If a single node, assume we need to match with its siblings
        if len(nodes) == 1:
            nodes = graph.siblings(nodes[0])
        else:
            # Verify all nodes have the same parent or all have no parent
            parents = [graph.incoming(n) for n in nodes]
            matches = [set(p) == set(parents[0]) for p in parents[1:]]
            if not all(matches):
                return [], None

        # TODO: If more nodes than patterns, we should consider
        # all permutations of the nodes
        if len(self.patterns) != len(nodes):
            return [], None

        # Work on copies: matched candidates are removed from `nodes` below.
        patterns = self.patterns.copy()
        nodes = nodes.copy()
        all_matches = []
        end_node = None
        for p in patterns:
            found = False
            for n in nodes:
                matches, following = p.match(graph, n)
                if matches:
                    found = True
                    # safe despite iterating `nodes`: we break right after
                    nodes.remove(n)
                    all_matches.extend(matches)
                    # Verify all branches end in the same node
                    # NOTE(review): this is a truthiness test, so a branch whose
                    # `following` is None or an empty list sets no reference,
                    # and a later branch ending elsewhere is still accepted -
                    # confirm this is intended.
                    if end_node:
                        if end_node != following:
                            return [], None
                    else:
                        end_node = following
                    break
            if not found:
                return [], None
        return all_matches, end_node
168 |
169 |
170 |
--------------------------------------------------------------------------------
/tensorwatch/model_graph/hiddenlayer/tf_builder.py:
--------------------------------------------------------------------------------
1 | """
2 | HiddenLayer
3 |
4 | TensorFlow graph importer.
5 |
6 | Written by Phil Ferriere. Edits by Waleed Abdulla.
7 | Licensed under the MIT License
8 | """
9 |
10 | from __future__ import absolute_import, division, print_function, unicode_literals
11 | import logging
12 | import tensorflow as tf
13 | from .graph import Graph, Node
14 | from . import transforms as ht
15 |
16 |
# Graph transforms for graphs imported from TensorFlow. Later entries refer to
# op names produced by earlier ones (e.g. the Fold patterns use "Conv" and
# "Linear", created by the Renames above them). They therefore presumably
# depend on being applied in list order - TODO confirm against the caller.
FRAMEWORK_TRANSFORMS = [
    # Rename VariableV2 op to Variable. Same for anything V2, V3, ...etc.
    ht.Rename(op=r"(\w+)V\d", to=r"\1"),
    # Drop constant/variable bookkeeping ops
    ht.Prune("Const"),
    ht.Prune("PlaceholderWithDefault"),
    ht.Prune("Variable"),
    ht.Prune("VarIsInitializedOp"),
    ht.Prune("VarHandleOp"),
    ht.Prune("ReadVariableOp"),
    # Drop assignment branches and optimizer update ops
    ht.PruneBranch("Assign"),
    ht.PruneBranch("AssignSub"),
    ht.PruneBranch("AssignAdd"),
    ht.PruneBranch("AssignVariableOp"),
    ht.Prune("ApplyMomentum"),
    ht.Prune("ApplyAdam"),
    ht.FoldId(r"^(gradients)/.*", "NoOp"),  # Fold to NoOp then delete in the next step
    ht.Prune("NoOp"),
    # Friendlier op names
    ht.Rename(op=r"DepthwiseConv2dNative", to="SeparableConv"),
    ht.Rename(op=r"Conv2D", to="Conv"),
    ht.Rename(op=r"FusedBatchNorm", to="BatchNorm"),
    ht.Rename(op=r"MatMul", to="Linear"),
    # Collapse common multi-op idioms into a single node
    ht.Fold("Conv > BiasAdd", "__first__"),
    ht.Fold("Linear > BiasAdd", "__first__"),
    ht.Fold("Shape > StridedSlice > Pack > Reshape", "__last__"),
    ht.FoldId(r"(.+)/dropout/.*", "Dropout"),
    ht.FoldId(r"(softmax_cross\_entropy)\_with\_logits.*", "SoftmaxCrossEntropy"),
]
44 |
45 |
def dump_tf_graph(tfgraph, tfgraphdef):
    """Print a table of every node in a TF graph (debugging aid).

    tfgraph: A TF Graph object (used to look up each node's output shape).
    tfgraphdef: A TF GraphDef object (the nodes that are listed).
    """
    print("Nodes ({})".format(len(tfgraphdef.node)))
    row = "{:15} {:59} {:20} {}"
    print(row.format("kind", "scopeName", "shape", "inputs"))
    for node_def in tfgraphdef.node:
        shape = tf.graph_util.tensor_shape_from_node_def_name(tfgraph, node_def.name)
        print(row.format(node_def.op, node_def.name, str(shape), node_def.input))
60 |
61 |
def import_graph(hl_graph, tf_graph, output=None, verbose=False):
    """Convert TF graph to directed graph.

    hl_graph: HiddenLayer graph that receives the nodes/edges; it is returned.
    tf_graph: A TF Graph object.
    output: Name of the output node (string). Not used by this function.
    verbose: Set to True for debug print output.

    Nodes that ``import_node`` cannot read are skipped (logged when verbose).
    """
    # Get clean(er) list of nodes
    graph_def = tf_graph.as_graph_def(add_shapes=True)
    graph_def = tf.graph_util.remove_training_nodes(graph_def)

    # Dump list of TF nodes (DEBUG only)
    if verbose:
        dump_tf_graph(tf_graph, graph_def)

    # Index consumers once: input name -> names of nodes listing it as input,
    # in graph order, one entry per node. This replaces a per-node rescan of
    # the whole graph (O(n^2)).
    consumers = {}
    for target_node in graph_def.node:
        for input_name in dict.fromkeys(target_node.input):
            consumers.setdefault(input_name, []).append(target_node.name)

    # Loop through nodes and build the matching directed graph
    for tf_node in graph_def.node:
        # Read node details. `except Exception` replaces a bare `except:` so
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        try:
            op, uid, name, shape, params = import_node(tf_node, tf_graph, verbose)
        except Exception:
            if verbose:
                logging.exception("Failed to read node {}".format(tf_node))
            continue

        # Add node
        hl_node = Node(uid=uid, name=name, op=op, output_shape=shape, params=params)
        hl_graph.add_node(hl_node)

        # Add edges
        for target_name in consumers.get(uid, ()):
            hl_graph.add_edge_by_id(uid, target_name, shape)
    return hl_graph
96 |
97 |
def import_node(tf_node, tf_graph, verbose=False):
    """Read the details HiddenLayer needs from a single TF NodeDef.

    tf_node: a NodeDef from ``tf_graph.as_graph_def()``.
    tf_graph: the TF Graph the node belongs to (used for shape lookups).
    verbose: if True, log failures to read the node's output shape.

    Returns ``(op, uid, name, shape, params)``:
    - ``uid`` is the node name and ``name`` is always None.
    - ``shape`` is a list when the rank is known, the raw TensorShape when it
      is not, and None for NoOp nodes or when the lookup failed.
    - ``params`` may contain ``kernel_shape`` and ``stride``.

    Errors from the kernel-shape lookup of Conv2D / DepthwiseConv2dNative
    nodes are not caught here.
    """
    # Operation type and name
    op = tf_node.op
    uid = tf_node.name
    name = None

    # Shape (best effort)
    shape = None
    if op != "NoOp":
        try:
            shape = tf.graph_util.tensor_shape_from_node_def_name(tf_graph, tf_node.name)
            # If the shape is known, convert it to a list
            if shape.ndims is not None:
                shape = shape.as_list()
        except Exception:  # was a bare `except:`, which also caught KeyboardInterrupt
            if verbose:
                logging.exception("Error reading shape of {}".format(tf_node.name))

    def int_list(attr_name):
        """Integers stored in the list-valued node attribute ``attr_name``."""
        return [int(a) for a in tf_node.attr[attr_name].list.i]

    # Parameters
    # At this stage, we really only care about two parameters:
    # 1/ the kernel size used by convolution and pooling layers
    # 2/ the stride used by convolutional and pooling layers (TODO: not fully working yet)
    params = {}
    if op in ("Conv2D", "DepthwiseConv2dNative"):
        # The kernel size is not stored on the convolution itself but on its
        # weight input, whose shape is [kernel, kernel, in_channels, filters].
        kernel_shape = tf.graph_util.tensor_shape_from_node_def_name(tf_graph, tf_node.input[1])
        params["kernel_shape"] = [int(a) for a in kernel_shape][0:2]
        if 'strides' in tf_node.attr:
            params["stride"] = int_list('strides')[1:3]
    elif op in ("MaxPool", "AvgPool"):
        # See https://stackoverflow.com/questions/44124942/how-to-access-values-in-protos-in-tensorflow
        if 'ksize' in tf_node.attr:
            params["kernel_shape"] = int_list('ksize')[1:3]
        if 'strides' in tf_node.attr:
            params["stride"] = int_list('strides')[1:3]

    return op, uid, name, shape, params
143 |
--------------------------------------------------------------------------------
/tensorwatch/mpl/bar_plot.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from .base_mpl_plot import BaseMplPlot
5 | from .. import utils
6 | from .. import image_utils
7 | import numpy as np
8 | from itertools import groupby
9 | import operator
10 |
class BarPlot(BaseMplPlot):
    """Matplotlib bar chart for stream values.

    In 2D, bars are keyed by x. In 3D, bars are grouped by z and each z group
    is drawn with its own color.
    """

    def init_stream_plot(self, stream_vis,
            xtitle='', ytitle='', ztitle='', colormap=None, color=None,
            edge_color=None, linewidth=None, align=None, bar_width=None,
            opacity=None, **stream_vis_args):
        """Initialize per-stream bar styling and axis labels.

        Defaults: linewidth 2, bar_width 1, colormap 'Set3', edge_color
        'black'. In 2D with no ``color``, a color is picked from the colormap
        based on the number of existing streams.
        """
        # add main subplot
        stream_vis.align, stream_vis.linewidth = align, (linewidth or 2)
        stream_vis.bar_width = bar_width or 1
        stream_vis.ax = self.get_main_axis()
        stream_vis.series = {}
        stream_vis.bars_artists = []  # stores previously drawn bars

        stream_vis.cmap = image_utils.get_cmap(colormap or 'Set3')
        if color is None and not self.is_3d:
            # this color used to be computed and then discarded
            color = stream_vis.cmap((len(self._stream_vises) % stream_vis.cmap.N) / stream_vis.cmap.N)  # pylint: disable=no-member
        stream_vis.color = color
        # honour the edge_color argument (it was previously ignored)
        stream_vis.edge_color = edge_color or 'black'
        stream_vis.opacity = opacity
        stream_vis.ax.set_xlabel(xtitle)
        stream_vis.ax.xaxis.label.set_style('italic')
        stream_vis.ax.set_ylabel(ytitle)
        if color is not None:
            stream_vis.ax.yaxis.label.set_color(color)
        stream_vis.ax.yaxis.label.set_style('italic')
        if self.is_3d:
            stream_vis.ax.set_zlabel(ztitle)
            stream_vis.ax.zaxis.label.set_style('italic')

    def is_show_grid(self):  # override
        return False

    def clear_artists(self, stream_vis):
        """Remove all previously drawn bars from the axes."""
        for bar in stream_vis.bars_artists:
            bar.remove()
        stream_vis.bars_artists.clear()

    def clear_plot(self, stream_vis, clear_history):
        stream_vis.series.clear()
        self.clear_artists(stream_vis)

    @staticmethod
    def _val2tuple(val, x) -> tuple:
        """Normalize a stream value to ``(x, y, z, color, label)``.

        Accepted forms:
        - scalar y, (y,)
        - (x, y) or (label, y)
        - (x, y, z) or (x, y, label)
        - (x, y, z, color)
        Values with 5+ items are returned unchanged (as a tuple). ``x`` is the
        position used when the value carries none. z defaults to 0.

        The previous version unpacked every array into exactly 3 items, which
        left most branches dead. For example, (y,) became x=y, y=None, and
        (x, y, label) produced a 4-tuple.
        """
        if not utils.is_array_like(val):  # scalar
            return (x, val, 0, None, None)

        t = tuple(val)
        if len(t) == 0:
            raise ValueError('Cannot plot an empty value as a bar')
        if len(t) == 1:  # (y,)
            return (x, t[0], 0, None, None)
        if len(t) == 2:  # (label, y) or (x, y)
            if isinstance(t[0], str):
                return (x, t[1], 0, None, t[0])
            return (t[0], t[1], 0, None, None)
        if len(t) == 3:  # (x, y, label) or (x, y, z)
            if isinstance(t[2], str):
                return (t[0], t[1], 0, None, t[2])
            return t + (None, None)
        if len(t) == 4:  # we assume (x, y, z, color)
            return t + (None,)
        return t  # else leave it alone

    @staticmethod
    def _group_color(stream_vis, group_index):
        """Color for the ``group_index``-th z group in 3D mode."""
        color = stream_vis.color
        if isinstance(color, str):
            # a single color name: indexing it would yield single characters
            return color
        if color is None:
            palette = getattr(stream_vis.cmap, 'colors', None)
            if palette is None:
                # colormaps without a .colors list: sample evenly instead
                n = stream_vis.cmap.N
                return stream_vis.cmap((group_index % n) / n)
            color = palette
        return color[group_index % len(color)]

    def _show_stream_items(self, stream_vis, stream_items):
        """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
        """
        vals = self._extract_vals(stream_items)
        if not len(vals):
            return True

        # `align` is only forwarded when set (it was previously ignored)
        extra_bar_args = {}
        if stream_vis.align is not None:
            extra_bar_args['align'] = stream_vis.align

        if not self.is_3d:
            existing_len = len(stream_vis.series)
            for i, val in enumerate(vals):
                t = BarPlot._val2tuple(val, i + existing_len)
                stream_vis.series[t[0]] = t  # merge x with previous items
            series = list(stream_vis.series.values())
            x = [t[0] for t in series]
            y = [t[1] for t in series]
            labels = [t[4] for t in series]

            self.clear_artists(stream_vis)  # remove previous bars
            bar_container = stream_vis.ax.bar(x, y,
                width=stream_vis.bar_width,
                tick_label=labels if any(l is not None for l in labels) else None,
                color=stream_vis.color, edgecolor=stream_vis.edge_color,
                alpha=stream_vis.opacity, linewidth=stream_vis.linewidth,
                **extra_bar_args)
            stream_vis.bars_artists = list(bar_container.patches)
        else:
            for val in vals:
                t = BarPlot._val2tuple(val, None)  # 3d expects explicit x, y, z
                stream_vis.series.setdefault(t[2], []).append(t)  # merge z with previous items

            # Clear once before drawing all groups; clearing inside the loop
            # used to leave only the last z group on screen.
            self.clear_artists(stream_vis)

            # sort by z so we have consistent colors
            ts = sorted(stream_vis.series.items(), key=lambda g: g[0])
            for zi, (z, tg) in enumerate(ts):
                x = [t[0] for t in tg]
                y = [t[1] for t in tg]
                labels = [t[4] for t in tg]
                bar_container = stream_vis.ax.bar(x, y, zs=z, zdir='y',
                    width=stream_vis.bar_width,
                    tick_label=labels if any(l is not None for l in labels) else None,
                    color=BarPlot._group_color(stream_vis, zi),
                    edgecolor=stream_vis.edge_color,
                    alpha=stream_vis.opacity or 0.8, linewidth=stream_vis.linewidth,
                    **extra_bar_args)
                stream_vis.bars_artists.extend(bar_container.patches)

        #stream_vis.ax.relim()
        #stream_vis.ax.autoscale_view()

        return False
128 |
129 |
--------------------------------------------------------------------------------
/test/test.pyproj:
--------------------------------------------------------------------------------
1 |
2 |
3 | Debug
4 | 2.0
5 | 9a7fe67e-93f0-42b5-b58f-77320fc639e4
6 |
7 |
8 | post_train\saliency.py
9 |
10 |
11 | .
12 | .
13 | test
14 | test
15 | Global|ContinuumAnalytics|Anaconda36-64
16 | False
17 |
18 |
19 | true
20 | false
21 |
22 |
23 | true
24 | false
25 |
26 |
27 |
28 |
29 |
30 | Code
31 |
32 |
33 | Code
34 |
35 |
36 | Code
37 |
38 |
39 |
40 | Code
41 |
42 |
43 | Code
44 |
45 |
46 | Code
47 |
48 |
49 | Code
50 |
51 |
52 | Code
53 |
54 |
55 | Code
56 |
57 |
58 | Code
59 |
60 |
61 | Code
62 |
63 |
64 | Code
65 |
66 |
67 | Code
68 |
69 |
70 | Code
71 |
72 |
73 | Code
74 |
75 |
76 | Code
77 |
78 |
79 | Code
80 |
81 |
82 | Code
83 |
84 |
85 | Code
86 |
87 |
88 | Code
89 |
90 |
91 | Code
92 |
93 |
94 | Code
95 |
96 |
97 | Code
98 |
99 |
100 | Code
101 |
102 |
103 | Code
104 |
105 |
106 | Code
107 |
108 |
109 | Code
110 |
111 |
112 | Code
113 |
114 |
115 | Code
116 |
117 |
118 | Code
119 |
120 |
121 |
122 | Code
123 |
124 |
125 | Code
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 | tensorwatch
134 | {cc8abc7f-ede1-4e13-b6b7-0041a5ec66a7}
135 | True
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
153 |
154 |
155 |
156 |
157 |
158 |
--------------------------------------------------------------------------------
/NOTICE.md:
--------------------------------------------------------------------------------
1 | NOTICES AND INFORMATION
2 | Do Not Translate or Localize
3 |
4 | This software incorporates material from third parties. Microsoft makes certain
5 | open source code available at http://3rdpartysource.microsoft.com, or you may
6 | send a check or money order for US $5.00, including the product name, the open
7 | source component name, and version number, to:
8 |
9 | Source Code Compliance Team
10 | Microsoft Corporation
11 | One Microsoft Way
12 | Redmond, WA 98052
13 | USA
14 |
15 | Notwithstanding any other terms, you may reverse engineer this software to the
16 | extent required to debug changes to any libraries licensed under the GNU Lesser
17 | General Public License.
18 |
19 |
20 | **Component.** https://github.com/Swall0w/torchstat
21 |
22 | **Open Source License/Copyright Notice.**
23 | MIT License
24 |
25 | Copyright (c) 2018 Swall0w - Alan
26 |
27 | Permission is hereby granted, free of charge, to any person obtaining a copy
28 | of this software and associated documentation files (the "Software"), to deal
29 | in the Software without restriction, including without limitation the rights
30 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
31 | copies of the Software, and to permit persons to whom the Software is
32 | furnished to do so, subject to the following conditions:
33 |
34 | The above copyright notice and this permission notice shall be included in all
35 | copies or substantial portions of the Software.
36 |
37 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
38 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
39 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
40 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
41 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
42 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
43 | SOFTWARE.
44 |
45 |
46 | **Component.** https://github.com/waleedka/hiddenlayer
47 |
48 | **Open Source License/Copyright Notice.**
49 | MIT License
50 |
51 | Copyright (c) 2018 Waleed Abdulla, Phil Ferriere
52 |
53 | Permission is hereby granted, free of charge, to any person obtaining a copy
54 | of this software and associated documentation files (the "Software"), to deal
55 | in the Software without restriction, including without limitation the rights
56 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
57 | copies of the Software, and to permit persons to whom the Software is
58 | furnished to do so, subject to the following conditions:
59 |
60 | The above copyright notice and this permission notice shall be included in all
61 | copies or substantial portions of the Software.
62 |
63 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
64 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
65 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
66 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
67 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
68 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
69 | SOFTWARE.
70 |
71 |
72 | **Component.** https://github.com/yulongwang12/visual-attribution
73 |
74 | **Open Source License/Copyright Notice.**
75 | BSD 2-Clause License
76 |
77 | Copyright (c) 2019, Yulong Wang
78 | All rights reserved.
79 |
80 | Redistribution and use in source and binary forms, with or without
81 | modification, are permitted provided that the following conditions are met:
82 |
83 | * Redistributions of source code must retain the above copyright notice, this
84 | list of conditions and the following disclaimer.
85 |
86 | * Redistributions in binary form must reproduce the above copyright notice,
87 | this list of conditions and the following disclaimer in the documentation
88 | and/or other materials provided with the distribution.
89 |
90 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
91 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
92 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
93 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
94 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
95 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
96 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
97 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
98 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
99 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
100 |
101 |
102 | **Component.** https://github.com/marcotcr/lime
103 |
104 | **Open Source License/Copyright Notice.**
105 | Copyright (c) 2016, Marco Tulio Correia Ribeiro
106 | All rights reserved.
107 |
108 | Redistribution and use in source and binary forms, with or without
109 | modification, are permitted provided that the following conditions are met:
110 |
111 | * Redistributions of source code must retain the above copyright notice, this
112 | list of conditions and the following disclaimer.
113 |
114 | * Redistributions in binary form must reproduce the above copyright notice,
115 | this list of conditions and the following disclaimer in the documentation
116 | and/or other materials provided with the distribution.
117 |
118 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
119 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
120 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
121 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
122 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
123 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
124 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
125 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
126 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
127 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/notebooks/data_exploration.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Using TensorWatch for Data Exploration\n",
8 | "\n",
9 | "In this tutorial, we will show how to quickly use TensorWatch for exploring data using a dimensionality reduction technique called [t-SNE](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding). We plan to implement many other techniques in the future (please feel free to [contribute](https://github.com/microsoft/tensorwatch/blob/master/CONTRIBUTING.md)!)."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "## Installing regim\n",
17 | "This tutorial will use a small Python package called `regim`. It has a few utility classes to quickly work with PyTorch datasets. \n",
18 | "\n",
19 | "To install regim, clone the repo from GitHub and do a local package install:\n",
20 | "\n",
21 | "```\n",
22 | "git clone https://github.com/sytelus/regim.git\n",
23 | "cd regim\n",
24 | "pip install -e .\n",
25 | "```\n",
26 | "\n",
27 | "Now that we are done with that, let's do our imports:"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 1,
33 | "metadata": {
34 | "scrolled": true
35 | },
36 | "outputs": [],
37 | "source": [
38 | "import tensorwatch as tw\n",
39 | "from regim import DataUtils"
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {},
45 | "source": [
46 | "First we will get the MNIST dataset. The `regim` package has a `DataUtils` class that lets us get the entire MNIST dataset without a train/test split, with each image reshaped into a vector of 784 integers instead of a 28x28 matrix."
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 2,
52 | "metadata": {
53 | "scrolled": true
54 | },
55 | "outputs": [],
56 | "source": [
57 | "ds = DataUtils.mnist_datasets(linearize=True, train_test=False)"
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "metadata": {},
63 | "source": [
64 | "The MNIST dataset is large enough that t-SNE can take a long time on slow computers, so we will apply t-SNE to a sample of this dataset. The `regim` package has a utility method that allows us to take `k` random samples for each class. We also set `as_np=True` to convert the images from PyTorch tensors to numpy arrays. The `no_test=True` parameter indicates that we don't want to split our data into train and test sets. The return value is a tuple of two numpy arrays, one containing the input images and the other containing the labels."
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "inputs, labels = DataUtils.sample_by_class(ds, k=50, shuffle=True, as_np=True, no_test=True)"
74 | ]
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "metadata": {},
79 | "source": [
80 | "We are now ready to supply this dataset to TensorWatch, and in just one line we can get lower-dimensional components. The `get_tsne_components` method takes a tuple of inputs and labels. The optional parameters `features_col=0` and `labels_col=1` tell which member of the tuple holds the input features and which holds the ground-truth labels. Another optional parameter, `n_components=3`, says that we should generate 3 components for each data point."
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": 4,
86 | "metadata": {
87 | "scrolled": true
88 | },
89 | "outputs": [],
90 | "source": [
91 | "components = tw.get_tsne_components((inputs, labels))"
92 | ]
93 | },
94 | {
95 | "cell_type": "markdown",
96 | "metadata": {},
97 | "source": [
98 | "Now that we have 3D components for each data point in our dataset, let's plot them! For this purpose, we use the `ArrayStream` class from TensorWatch, which allows you to convert any iterable into a TensorWatch stream. We then supply this stream to the `Visualizer` class, asking it to use the `tsne` visualization type, which is just a fancy 3D scatter plot."
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": 5,
104 | "metadata": {
105 | "scrolled": true
106 | },
107 | "outputs": [
108 | {
109 | "data": {
110 | "application/vnd.jupyter.widget-view+json": {
111 | "model_id": "f86416164bb04e729ea16cb71a666a70",
112 | "version_major": 2,
113 | "version_minor": 0
114 | },
115 | "text/plain": [
116 | "HBox(children=(FigureWidget({\n",
117 | " 'data': [{'hoverinfo': 'text',\n",
118 | " 'line': {'color': 'rgb(31, 119,…"
119 | ]
120 | },
121 | "metadata": {},
122 | "output_type": "display_data"
123 | }
124 | ],
125 | "source": [
126 | "comp_stream = tw.ArrayStream(components)\n",
127 | "vis = tw.Visualizer(comp_stream, vis_type='tsne', \n",
128 | " hover_images=inputs, hover_image_reshape=(28,28))\n",
129 | "vis.show()"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {},
135 | "source": [
136 | "Notice that as you hover the mouse over each point on the graph, you will see the MNIST image associated with that point. How does that work? We are supplying the input images to the `hover_images` parameter. As our images are 1-dimensional arrays of 784 values, we also set the `hover_image_reshape` parameter to reshape them to 28x28. You can customize hover behavior by attaching a new function to the vis.[hover_fn](https://github.com/microsoft/tensorwatch/blob/master/tensorwatch/plotly/embeddings_plot.py#L28) member."
137 | ]
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "metadata": {},
142 | "source": [
143 | "## Questions?\n",
144 | "\n",
145 | "Please file a [Github issue](https://github.com/microsoft/tensorwatch/issues/new) and let us know if we can improve this tutorial."
146 | ]
147 | }
148 | ],
149 | "metadata": {
150 | "kernelspec": {
151 | "display_name": "Python 3",
152 | "language": "python",
153 | "name": "python3"
154 | },
155 | "language_info": {
156 | "codemirror_mode": {
157 | "name": "ipython",
158 | "version": 3
159 | },
160 | "file_extension": ".py",
161 | "mimetype": "text/x-python",
162 | "name": "python",
163 | "nbconvert_exporter": "python",
164 | "pygments_lexer": "ipython3",
165 | "version": "3.6.5"
166 | }
167 | },
168 | "nbformat": 4,
169 | "nbformat_minor": 2
170 | }
171 |
--------------------------------------------------------------------------------
/tensorwatch/mpl/base_mpl_plot.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | #from IPython import get_ipython, display
5 | #if get_ipython():
6 | # get_ipython().magic('matplotlib notebook')
7 |
8 | #import matplotlib
9 | #if os.name == 'posix' and "DISPLAY" not in os.environ:
10 | # matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab!
11 |
12 | #from ipywidgets.widgets.interaction import show_inline_matplotlib_plots
13 | #from ipykernel.pylab.backend_inline import flush_figures
14 |
15 | from ..vis_base import VisBase
16 |
17 | import sys, logging
18 | from abc import abstractmethod
19 | from .. import utils
20 |
21 |
class BaseMplPlot(VisBase):
    """Base class for visualizers that draw with matplotlib.

    The figure is created by ``init_fig`` when a stream subscription is
    added, and is refreshed by a matplotlib ``FuncAnimation`` (started in
    ``show``) that periodically calls ``_update_stream_plots``. Derived
    classes implement ``init_stream_plot`` to create per-stream plot objects.
    """

    def __init__(self, cell:VisBase.widgets.Box=None, title:str=None, show_legend:bool=None, is_3d:bool=False,
                 stream_name:str=None, console_debug:bool=False, **vis_args):
        super(BaseMplPlot, self).__init__(VisBase.widgets.Output(), cell, title, show_legend,
                                          stream_name=stream_name, console_debug=console_debug, **vis_args)

        self._fig_init_done = False
        self.show_legend = show_legend
        self.is_3d = is_3d
        if is_3d:
            # Only the import's side effect matters here (the name is unused).
            # Presumably it registers the '3d' projection on older matplotlib
            # versions so add_subplot(..., projection='3d') works - confirm.
            from mpl_toolkits.mplot3d import Axes3D  # pylint: disable=unused-import,unused-variable
        # graph objects, created lazily by init_fig() / get_main_axis()
        self.figure = None
        self._ax_main = None
        # matplotlib animation driving periodic redraws; created in show()
        self.animation = None
        self.anim_interval = None  # seconds; set by init_fig()

    def init_fig(self, anim_interval:float=1.0):
        """(for derived class) Initializes matplotlib figure.

        Only the first call does any work; later calls return immediately.

        Args:
            anim_interval: time between animation refreshes, in seconds.

        Returns:
            True if the figure was created by this call, False if it had
            already been initialized.
        """
        if self._fig_init_done:
            return False

        import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue

        # create figure; the animation itself is created later in show()
        self.figure = plt.figure(figsize=(8, 3))
        self.anim_interval = anim_interval

        # default color palette
        plt.set_cmap('Dark2')
        plt.rcParams['image.cmap'] = 'Dark2'

        self._fig_init_done = True
        return True

    def get_main_axis(self):
        """Return the main axes of the figure, creating it on first use.

        The axes occupies the whole figure (subplot 111), uses the '3d'
        projection when ``is_3d`` is set, shows a grid according to
        ``is_show_grid()``, has light-gray top/right spines and a bold title
        when ``self.title`` is set. Requires ``self.figure`` to exist
        (see ``init_fig``).
        """
        if self._ax_main is None:
            # by default assign one subplot to whole graph
            self._ax_main = self.figure.add_subplot(111,
                projection='3d' if self.is_3d else None)
            self._ax_main.grid(self.is_show_grid())
            # change the color of the top and right spines to opaque gray
            self._ax_main.spines['right'].set_color((.8,.8,.8))
            self._ax_main.spines['top'].set_color((.8,.8,.8))
            if self.title is not None:
                title = self._ax_main.set_title(self.title)
                title.set_weight('bold')
        return self._ax_main

    # overridable
    def is_show_grid(self):
        """Whether the main axes shows a grid; derived classes may override."""
        return True

    def _on_update(self, frame): # pylint: disable=unused-argument
        """FuncAnimation callback: redraw all stream plots.

        Exceptions are caught, stored in ``self.last_ex`` and logged: an
        exception escaping this callback would stop the animation, after
        which there would be no further plot updates.
        """
        try:
            self._update_stream_plots()
        except Exception as ex:
            # TODO: these exceptions occur in the background and do not pop
            # up in Jupyter Notebook, hence they are recorded and logged
            self.last_ex = ex
            logging.exception('Exception in matplotlib update loop')

    def show(self, blocking=False):
        """Start the refresh animation (if not yet shown and an interval is set), then show."""
        if not self.is_shown and self.anim_interval:
            from matplotlib.animation import FuncAnimation # function-level import as this one is expensive
            # FuncAnimation expects its interval in milliseconds
            self.animation = FuncAnimation(self.figure, self._on_update, interval=self.anim_interval*1000.0)
        super(BaseMplPlot, self).show(blocking)

    def _post_update_stream_plot(self, stream_vis):
        """Finish a plot update: re-layout if needed and repaint inside the output widget."""
        import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue

        utils.debug_log("Plot updated", stream_vis.stream.stream_name, verbosity=5)

        if self.layout_dirty:
            # do not do tight_layout() call on every update
            # that would jumble up the graphs! it should only be called
            # once each time there is change in layout
            self.figure.tight_layout()
            self.layout_dirty = False

        # NOTE: forcing a redraw via figure.canvas.draw()/flush_events() works
        # in native UX but not in Jupyter Notebook, so it is not done here.

        if self._use_hbox and VisBase.get_ipython():
            self.widget.clear_output(wait=True)
            with self.widget:
                # NOTE(review): the documented pyplot.show() API takes no
                # figure argument; the figure is forwarded to the backend's
                # show(). Kept as-is because it is the only approach found to
                # repaint inside the widget - confirm the intended effect.
                plt.show(self.figure)

    def _post_add_subscription(self, stream_vis, **stream_vis_args):
        """Ensure the figure exists, create the stream's plot, then refresh legend and spacing."""
        # make sure figure is initialized
        self.init_fig()
        self.init_stream_plot(stream_vis, **stream_vis_args)

        # redo the legend
        if self.show_legend:
            self.figure.legend(loc='lower right')
        # adjust spacing on *this* figure: pyplot.subplots_adjust() would act
        # on whichever figure pyplot considers current, which may belong to
        # another plot
        self.figure.subplots_adjust(hspace=0.6)

    def _show_widget_native(self, blocking:bool):
        """Show the figure in a native matplotlib window."""
        import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue

        return plt.show(block=blocking)

    def _show_widget_notebook(self):
        """Return nothing to display in the notebook.

        With %matplotlib notebook the spawned figure is detected and painted
        automatically; returning self.figure would show it twice.
        """
        return None

    def _can_update_stream_plots(self):
        # False: updates are flushed by the animation interval timer instead
        return False

    @abstractmethod
    def init_stream_plot(self, stream_vis, **stream_vis_args):
        """(for derived class) Create new plot info for this stream"""
        pass
170 |
--------------------------------------------------------------------------------
/tensorwatch/lv_types.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from typing import List, Callable, Any, Sequence, Hashable
5 | from . import utils
6 | import uuid
7 |
8 |
class EventData:
    """Attribute bag holding the variables captured for one event.

    Every key of ``globals_val`` (when it is not None) and then every keyword
    argument becomes an instance attribute; keyword arguments are applied
    last, so they win over same-named keys from ``globals_val``.
    """

    def __init__(self, globals_val, **vars_val):
        sources = [] if globals_val is None else [globals_val]
        sources.append(vars_val)
        for source in sources:
            for name in source:
                setattr(self, name, source[name])

    def __str__(self):
        # scalars are printed bare, everything else is wrapped in quotes
        pieces = []
        for name, val in self.__dict__.items():
            template = '{key}={value}' if utils.is_scalar(val) else '{key}="{value}"'
            pieces.append(template.format(key=name, value=val))
        return ', '.join(pieces)

# list of captured events
EventsVars = List[EventData]
29 |
class StreamItem:
    """A single value emitted on a stream, together with its bookkeeping flags."""

    def __init__(self, value:Any, stream_name:str=None, item_index:int=None,
                 ended:bool=False, exception:Exception=None, stream_reset:bool=False):
        # payload (or the exception raised while producing it)
        self.value, self.exception = value, exception
        # which stream this item belongs to and its position in it
        self.stream_name, self.item_index = stream_name, item_index
        # lifecycle flags
        self.ended, self.stream_reset = ended, stream_reset

    def __repr__(self):
        return str(vars(self))
42 |
# Type of an event-evaluation callable: it receives the list of captured
# EventData objects (EventsVars) and returns a StreamItem.
EventEvalFunc = Callable[[EventsVars], StreamItem]
44 |
45 |
class VisArgs:
    """Container for visualizer parameters.

    These are the same parameters as the Visualizer constructor.
    NOTE: If you modify arguments here, also keep the Visualizer constructor in sync.
    """
    def __init__(self, vis_type:str=None, host:'Visualizer'=None,
                 cell:'Visualizer'=None, title:str=None,
                 clear_after_end=False, clear_after_each=False, history_len=1, dim_history=True, opacity=None,

                 rows=2, cols=5, img_width=None, img_height=None, img_channels=None,
                 colormap=None, viz_img_scale=None,

                 # these image params are for hover on point for t-sne
                 hover_images=None, hover_image_reshape=None, cell_width:str=None, cell_height:str=None,

                 only_summary=False, separate_yaxis=True, xtitle=None, ytitle=None, ztitle=None, color=None,
                 xrange=None, yrange=None, zrange=None, draw_line=True, draw_marker=False,

                 # histogram
                 bins=None, normed=None, histtype='bar', edge_color=None, linewidth=None, bar_width=None,

                 # pie chart
                 autopct=None, shadow=None,

                 vis_args:dict=None, stream_vis_args:dict=None)->None:

        # general
        self.vis_type = vis_type
        self.host = host
        self.cell = cell
        self.title = title
        self.clear_after_end = clear_after_end
        self.clear_after_each = clear_after_each
        self.history_len = history_len
        self.dim_history = dim_history
        self.opacity = opacity

        # image grid
        self.rows = rows
        self.cols = cols
        self.img_width = img_width
        self.img_height = img_height
        self.img_channels = img_channels
        self.colormap = colormap
        self.viz_img_scale = viz_img_scale

        # image shown on hover over a point (t-sne) and cell size
        self.hover_images = hover_images
        self.hover_image_reshape = hover_image_reshape
        self.cell_width = cell_width
        self.cell_height = cell_height

        # axes and line/marker options
        self.only_summary = only_summary
        self.separate_yaxis = separate_yaxis
        self.xtitle = xtitle
        self.ytitle = ytitle
        self.ztitle = ztitle
        self.color = color
        self.xrange = xrange
        self.yrange = yrange
        self.zrange = zrange
        self.draw_line = draw_line
        self.draw_marker = draw_marker

        # histogram
        self.bins = bins
        self.normed = normed
        self.histtype = histtype
        self.edge_color = edge_color
        self.linewidth = linewidth
        self.bar_width = bar_width

        # pie chart
        self.autopct = autopct
        self.shadow = shadow

        # pass-through argument dictionaries
        self.vis_args = vis_args
        self.stream_vis_args = stream_vis_args
97 |
98 |
class StreamCreateRequest:
    """Request to create a stream from an expression evaluated on an event.

    When ``stream_name`` is empty or None, a random UUID string is used.
    """
    def __init__(self, stream_name:str, devices:Sequence[str]=None, event_name:str='',
                 expr:str=None, throttle:float=None, vis_args:VisArgs=None):
        self.event_name, self.expr = event_name, expr
        self.stream_name = stream_name if stream_name else str(uuid.uuid4())
        self.devices, self.vis_args = devices, vis_args

        # reference values observed for max throughput on a Lenovo P50
        # laptop with MNIST:
        #   text console -> 0.1s
        #   matplotlib line graph -> 0.5s
        self.throttle = throttle
112 |
class ClientServerRequest:
    """Envelope for a client-to-server request: a type tag plus its payload."""
    def __init__(self, req_type:str, req_data:Any):
        # req_type: presumably one of the CliSrvReqTypes values - confirm with callers
        self.req_type, self.req_data = req_type, req_data
117 |
class CliSrvReqTypes:
    """String constants naming the client-server request types."""
    # ask the server to create a stream
    create_stream = 'CreateStream'
    # ask the server to delete a stream
    del_stream = 'DeleteStream'
121 |
class StreamVisInfo:
    """Per-stream visualization state: the stream and its display options."""
    def __init__(self, stream, title, clear_after_end,
                 clear_after_each, history_len, dim_history, opacity,
                 index, stream_vis_args, last_update):
        self.stream = stream
        self.title = title
        self.opacity = opacity
        self.clear_after_end = clear_after_end
        self.clear_after_each = clear_after_each
        self.history_len = history_len
        self.dim_history = dim_history
        self.index = index
        self.stream_vis_args = stream_vis_args
        self.last_update = last_update
131 |
class ImageData:
    """One or more images plus display options (title, alpha, colormap).

    ``images`` is always stored as a tuple: any non-tuple argument, including
    the default None, is wrapped into a 1-tuple.
    """
    # images are numpy array of shape [[channels,] height, width]
    def __init__(self, images=None, title=None, alpha=None, cmap=None):
        self.images = images if isinstance(images, tuple) else (images,)
        self.alpha = alpha
        self.cmap = cmap
        self.title = title
138 |
class PointData:
    """A single plotted point with optional confidence bounds, labels and color."""
    def __init__(self, x:float=None, y:float=None, z:float=None, low:float=None, high:float=None,
                 annotation:Any=None, text:Any=None, color:Any=None)->None:
        # coordinates
        self.x, self.y, self.z = x, y, z
        # confidence interval
        self.low, self.high = low, high
        # labels
        self.annotation, self.text = annotation, text
        # typically string like '#d62728'
        self.color = color
150 |
class PredictionResult:
    """Outcome of a model prediction on one input: loss, predicted class and related tensors."""
    def __init__(self, loss:float=None, class_id:Hashable=None, probability:float=None,
                 inputs:Any=None, outputs:Any=None, targets:Any=None, others:Any=None):
        # prediction summary
        self.loss, self.class_id, self.probability = loss, class_id, probability
        # associated data
        self.inputs, self.outputs, self.targets, self.others = inputs, outputs, targets, others
161 |
class DefaultPorts:
    """Default port numbers for the publish/subscribe and client-server channels."""
    # publish/subscribe channel
    PubSub = 40859
    # client-server request channel
    CliSrv = 41459
165 |
class PublisherTopics:
    """Topic names for published messages: stream items and server management."""
    StreamItem = 'StreamItem'
    ServerMgmt = 'ServerMgmt'
169 |
class ServerMgmtMsg:
    """Server management message: an event name plus optional event arguments."""
    # event name for the server-start event
    EventServerStart = 'ServerStart'

    def __init__(self, event_name:str, event_args:Any=None):
        self.event_name, self.event_args = event_name, event_args
175 |
--------------------------------------------------------------------------------