├── tensorwatch ├── saliency │ ├── lime │ │ ├── __init__.py │ │ └── wrappers │ │ │ ├── __init__.py │ │ │ ├── generic_utils.py │ │ │ └── scikit_image.py │ ├── __init__.py │ ├── README.md │ ├── gradcam.py │ ├── occlusion.py │ ├── lime_image_explainer.py │ ├── deeplift.py │ ├── backprop.py │ └── saliency.py ├── model_graph │ ├── hiddenlayer │ │ ├── __init__.py │ │ ├── README.md │ │ ├── pytorch_builder.py │ │ ├── ge.py │ │ └── tf_builder.py │ ├── __init__.py │ └── torchstat_utils.py ├── mpl │ ├── __init__.py │ ├── pie_chart.py │ ├── histogram.py │ ├── image_plot.py │ ├── bar_plot.py │ └── base_mpl_plot.py ├── plotly │ ├── __init__.py │ ├── embeddings_plot.py │ └── base_plotly_plot.py ├── embeddings │ ├── __init__.py │ └── tsne_utils.py ├── receptive_field │ ├── __init__.py │ └── rf_utils.py ├── array_stream.py ├── stream_union.py ├── __init__.py ├── filtered_stream.py ├── data_utils.py ├── pytorch_utils.py ├── zmq_stream.py ├── file_stream.py ├── repeated_timer.py ├── zmq_mgmt_stream.py ├── imagenet_utils.py ├── stream.py ├── tensor_utils.py ├── text_vis.py ├── watcher_client.py ├── notebook_maker.py ├── watcher.py ├── stream_factory.py ├── evaler.py ├── image_utils.py ├── visualizer.py └── lv_types.py ├── MANIFEST.in ├── docs ├── images │ ├── tsne.gif │ ├── fruits.gif │ ├── teaser.gif │ ├── saliency.png │ ├── draw_model.png │ ├── model_stats.png │ ├── quick_start.gif │ ├── teaser_small.gif │ ├── lazy_log_array_sum.png │ └── simple_logging │ │ ├── line_cell.png │ │ ├── text_cell.png │ │ ├── line_cell2.png │ │ ├── plotly_line.png │ │ └── text_summary.png ├── simple_logging.md └── lazy_logging.md ├── data └── test_images │ ├── cat.jpg │ ├── dogs.png │ └── elephant.png ├── test ├── simple_log │ ├── quick_start.py │ ├── sum_lazy.py │ ├── file_expr.py │ ├── sum_log.py │ ├── srv_ij.py │ ├── cli_file_expr.py │ ├── cli_sum_log.py │ └── cli_ij.py ├── pre_train │ ├── draw_model.py │ └── tsny.py ├── visualizations │ ├── confidence_int.py │ ├── plotly_line.py │ ├── mpl_line.py │ ├── arr_img_plot.py │ ├── histogram.py │ ├── pie_chart.py │ ├── line3d_plot.py │ └── bar_plot.py ├── components │ ├── circ_ref.py │ ├── evaler.py │ ├── stream.py │ ├── notebook_maker.py │ ├── watcher.py │ └── file_only_test.py ├── zmq │ ├── zmq_watcher_server.py │ ├── zmq_watcher_client.py │ ├── zmq_stream.py │ ├── zmq_pub.py │ └── zmq_sub.py ├── deps │ ├── panda.py │ ├── ipython_widget.py │ ├── thread.py │ └── live_graph.py ├── post_train │ └── saliency.py ├── files │ └── file_stream.py ├── mnist │ └── cli_mnist.py └── test.pyproj ├── .pylintrc ├── update_package.bat ├── SUPPORT.md ├── CHANGELOG.md ├── install_jupyterlab.bat ├── CONTRIBUTING.md ├── tensorwatch.sln ├── setup.py ├── LICENSE.TXT ├── TODO.md ├── NOTICE.md └── notebooks └── data_exploration.ipynb /tensorwatch/saliency/lime/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorwatch/saliency/lime/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorwatch/model_graph/hiddenlayer/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include tensorwatch/*.txt 2 | include 
tensorwatch/*.json -------------------------------------------------------------------------------- /docs/images/tsne.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/tsne.gif -------------------------------------------------------------------------------- /docs/images/fruits.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/fruits.gif -------------------------------------------------------------------------------- /docs/images/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/teaser.gif -------------------------------------------------------------------------------- /data/test_images/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/data/test_images/cat.jpg -------------------------------------------------------------------------------- /data/test_images/dogs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/data/test_images/dogs.png -------------------------------------------------------------------------------- /docs/images/saliency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/saliency.png -------------------------------------------------------------------------------- /docs/images/draw_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/draw_model.png -------------------------------------------------------------------------------- /docs/images/model_stats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/model_stats.png -------------------------------------------------------------------------------- /docs/images/quick_start.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/quick_start.gif -------------------------------------------------------------------------------- /docs/images/teaser_small.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/teaser_small.gif -------------------------------------------------------------------------------- /data/test_images/elephant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/data/test_images/elephant.png -------------------------------------------------------------------------------- /tensorwatch/mpl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
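# Package marker for the matplotlib-based visualizers (image_plot, histogram, bar_plot, pie_chart and base_mpl_plot in this folder, per the tree above).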
3 | 4 | 5 | -------------------------------------------------------------------------------- /tensorwatch/plotly/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | -------------------------------------------------------------------------------- /tensorwatch/saliency/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | -------------------------------------------------------------------------------- /tensorwatch/model_graph/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | -------------------------------------------------------------------------------- /docs/images/lazy_log_array_sum.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/lazy_log_array_sum.png -------------------------------------------------------------------------------- /tensorwatch/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | 5 | -------------------------------------------------------------------------------- /tensorwatch/receptive_field/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | -------------------------------------------------------------------------------- /docs/images/simple_logging/line_cell.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/simple_logging/line_cell.png -------------------------------------------------------------------------------- /docs/images/simple_logging/text_cell.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/simple_logging/text_cell.png -------------------------------------------------------------------------------- /docs/images/simple_logging/line_cell2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/simple_logging/line_cell2.png -------------------------------------------------------------------------------- /docs/images/simple_logging/plotly_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/simple_logging/plotly_line.png -------------------------------------------------------------------------------- /docs/images/simple_logging/text_summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/tensorwatch/master/docs/images/simple_logging/text_summary.png -------------------------------------------------------------------------------- /tensorwatch/model_graph/hiddenlayer/README.md: -------------------------------------------------------------------------------- 1 | # Credits 2 | 3 | Code in this folder is 
adapted from 4 | 5 | * https://github.com/waleedka/hiddenlayer 6 | -------------------------------------------------------------------------------- /tensorwatch/saliency/README.md: -------------------------------------------------------------------------------- 1 | # Credits 2 | Code in this folder is adapted from 3 | 4 | * https://github.com/yulongwang12/visual-attribution 5 | * https://github.com/marcotcr/lime 6 | -------------------------------------------------------------------------------- /test/simple_log/quick_start.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import time 3 | 4 | w = tw.Watcher(filename='test.log') 5 | s = w.create_stream(name='my_metric') 6 | #w.make_notebook() 7 | 8 | for i in range(1000): 9 | s.write((i, i*i)) 10 | time.sleep(1) 11 | -------------------------------------------------------------------------------- /test/pre_train/draw_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models 3 | import tensorwatch as tw 4 | 5 | vgg16_model = torchvision.models.vgg16() 6 | 7 | drawing = tw.draw_model(vgg16_model, [1, 3, 224, 224]) 8 | drawing.save('abc') 9 | 10 | input("Press any key") -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | ignore=model_graph, 3 | receptive_field, 4 | saliency 5 | 6 | [TYPECHECK] 7 | ignored-modules=numpy,torch,matplotlib,pyplot,zmq 8 | 9 | [MESSAGES CONTROL] 10 | disable=protected-access, 11 | broad-except, 12 | global-statement, 13 | fixme, 14 | C, 15 | R -------------------------------------------------------------------------------- /update_package.bat: -------------------------------------------------------------------------------- 1 | ECHO Make sure to increment version in setup.py. Continue? 2 | PAUSE 3 | python setup.py sdist 4 | twine upload --repository-url https://upload.pypi.org/legacy/ dist/* 5 | REM pip install tensorwatch --upgrade 6 | REM pip show tensorwatch 7 | 8 | REM pip install yolk3k 9 | REM yolk -V tensorwatch -------------------------------------------------------------------------------- /test/visualizations/confidence_int.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | 3 | w = tw.Watcher() 4 | s = w.create_stream() 5 | 6 | v = tw.Visualizer(s, vis_type='line') 7 | v.show() 8 | 9 | for i in range(10): 10 | i = float(i) 11 | s.write(tw.PointData(i, i*i, low=i*i-i, high=i*i+i)) 12 | 13 | tw.plt_loop() -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | We highly recommend taking a look at the source code. Please also consider [contributing](CONTRIBUTING.md) new features and fixes :).
4 | 5 | * [Join TensorWatch Facebook Group](https://www.facebook.com/groups/378075159472803/) 6 | * [File GitHub Issue](https://github.com/Microsoft/tensorwatch/issues) -------------------------------------------------------------------------------- /test/components/circ_ref.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import objgraph, time # pip install objgraph 3 | 4 | cli = tw.WatcherClient() 5 | time.sleep(10) 6 | del cli 7 | 8 | import gc 9 | gc.collect() 10 | 11 | time.sleep(2) 12 | 13 | objgraph.show_backrefs(objgraph.by_type('WatcherClient'), refcounts=True, filename='b.png') 14 | 15 | -------------------------------------------------------------------------------- /test/components/evaler.py: -------------------------------------------------------------------------------- 1 | #import tensorwatch as tw 2 | from tensorwatch import evaler 3 | 4 | 5 | e = evaler.Evaler(expr='reduce(lambda x,y: (x+y), map(lambda x:(x**2), filter(lambda x: x%2==0, l)))') 6 | for i in range(5): 7 | eval_return = e.post(i) 8 | print(i, eval_return) 9 | eval_return = e.post(ended=True) 10 | print(i, eval_return) 11 | -------------------------------------------------------------------------------- /test/zmq/zmq_watcher_server.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.watcher import Watcher 2 | import time 3 | 4 | from tensorwatch import utils 5 | utils.set_debug_verbosity(10) 6 | 7 | 8 | def main(): 9 | watcher = Watcher() 10 | 11 | for i in range(5000): 12 | watcher.observe(x=i) 13 | # print(i) 14 | time.sleep(1) 15 | 16 | main() 17 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # What's new 2 | 3 | Below is a summarized list of important changes. It does not include minor changes, bug fixes, or documentation updates. This list is updated every few months. For complete detailed changes, please review [commit history](https://github.com/Microsoft/tensorwatch/commits/master). 4 | 5 | ### May, 2019 6 | * First release!
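The expression given to Evaler in test/components/evaler.py above computes the sum of squares of the even values seen so far. Assuming each value passed to e.post() is accumulated into the list `l` (an assumption about Evaler's internals, suggested by the expression's free variable `l`), a self-contained plain-Python equivalent is:

from functools import reduce

l = [0, 1, 2, 3, 4]  # the values posted via e.post(i)
result = reduce(lambda x, y: x + y,
                map(lambda x: x**2,
                    filter(lambda x: x % 2 == 0, l)))
assert result == 20  # evens 0, 2, 4 -> squares 0, 4, 16 -> sum 20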
-------------------------------------------------------------------------------- /test/deps/panda.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | class SomeThing: 4 | def __init__(self, x, y): 5 | self.x, self.y = x, y 6 | 7 | things = [SomeThing(1,2), SomeThing(3,4), SomeThing(4,5)] 8 | 9 | df = pd.DataFrame([t.__dict__ for t in things ]) 10 | 11 | print(df.iloc[[-1]].to_dict('records')[0]) 12 | #print(df.to_html()) 13 | print(df.style.render()) 14 | -------------------------------------------------------------------------------- /test/components/stream.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.stream import Stream 2 | 3 | 4 | s1 = Stream(stream_name='s1', console_debug=True) 5 | s2 = Stream(stream_name='s2', console_debug=True) 6 | s3 = Stream(stream_name='s3', console_debug=True) 7 | 8 | s1.subscribe(s2) 9 | s2.subscribe(s3) 10 | 11 | s3.write('S3 wrote this') 12 | s2.write('S2 wrote this') 13 | s1.write('S1 wrote this') 14 | 15 | -------------------------------------------------------------------------------- /test/zmq/zmq_watcher_client.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.watcher_client import WatcherClient 2 | import time 3 | 4 | from tensorwatch import utils 5 | utils.set_debug_verbosity(10) 6 | 7 | 8 | def main(): 9 | watcher = WatcherClient() 10 | stream = watcher.create_stream(expr='lambda vars:vars.x**2') 11 | stream.console_debug = True 12 | input('pause') 13 | 14 | main() 15 | 16 | -------------------------------------------------------------------------------- /test/components/notebook_maker.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | 3 | 4 | def main(): 5 | w = tw.Watcher() 6 | s1 = w.create_stream() 7 | s2 = w.create_stream(name='accuracy', vis_args=tw.VisArgs(vis_type='line', xtitle='X-Axis', clear_after_each=False, history_len=2)) 8 | s3 = w.create_stream(name='loss', expr='lambda d:d.loss') 9 | w.make_notebook() 10 | 11 | main() 12 | 13 | -------------------------------------------------------------------------------- /test/pre_train/tsny.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | 3 | from regim import * 4 | ds = DataUtils.mnist_datasets(linearize=True, train_test=False) 5 | ds = DataUtils.sample_by_class(ds, k=5, shuffle=True, as_np=True, no_test=True) 6 | 7 | comps = tw.get_tsne_components(ds) 8 | print(comps) 9 | plot = tw.Visualizer(comps, hover_images=ds[0], hover_image_reshape=(28,28), vis_type='tsne') 10 | plot.show() -------------------------------------------------------------------------------- /test/simple_log/sum_lazy.py: -------------------------------------------------------------------------------- 1 | import time, random 2 | import tensorwatch as tw 3 | 4 | # create watcher object as usual 5 | w = tw.Watcher() 6 | 7 | weights = None 8 | for i in range(10000): 9 | weights = [random.random() for _ in range(5)] 10 | 11 | # let watcher observe variables we have 12 | # this has almost no performance cost 13 | w.observe(weights=weights) 14 | 15 | time.sleep(1) 16 | 17 | -------------------------------------------------------------------------------- /test/deps/ipython_widget.py: -------------------------------------------------------------------------------- 1 | import ipywidgets as widgets 2 | from IPython import 
get_ipython 3 | 4 | class PrinterX: 5 | def __init__(self): 6 | self.w = widgets.HTML() 7 | def show(self): 8 | return self.w 9 | def write(self, s): 10 | self.w.value = s 11 | 12 | print("Running from within ipython?", get_ipython() is not None) 13 | p = PrinterX() 14 | p.show() 15 | p.write('ffffffffff') 16 | -------------------------------------------------------------------------------- /test/components/watcher.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.watcher_base import WatcherBase 2 | from tensorwatch.stream import Stream 3 | 4 | def main(): 5 | watcher = WatcherBase() 6 | console_pub = Stream(stream_name = 'S1', console_debug=True) 7 | stream = watcher.create_stream(expr='lambda vars:vars.x**2') 8 | console_pub.subscribe(stream) 9 | 10 | for i in range(5): 11 | watcher.observe(x=i) 12 | 13 | main() 14 | 15 | 16 | -------------------------------------------------------------------------------- /test/simple_log/file_expr.py: -------------------------------------------------------------------------------- 1 | import time 2 | import tensorwatch as tw 3 | from tensorwatch import utils 4 | utils.set_debug_verbosity(4) 5 | 6 | srv = tw.Watcher(filename=r'c:\temp\sum.log') 7 | s1 = srv.create_stream('sum', expr='lambda v:(v.i, v.sum)') 8 | s2 = srv.create_stream('sum_2', expr='lambda v:(v.i, v.sum/2)') 9 | 10 | sum = 0 11 | for i in range(10000): 12 | sum += i 13 | srv.observe(i=i, sum=sum) 14 | #print(i, sum) 15 | time.sleep(1) 16 | 17 | -------------------------------------------------------------------------------- /test/zmq/zmq_stream.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.watcher_base import WatcherBase 2 | from tensorwatch.zmq_stream import ZmqStream 3 | 4 | def main(): 5 | watcher = WatcherBase() 6 | stream = watcher.create_stream(expr='lambda vars:vars.x**2') 7 | 8 | zmq_pub = ZmqStream(for_write=True, stream_name = 'ZmqPub', console_debug=True) 9 | zmq_pub.subscribe(stream) 10 | 11 | for i in range(5): 12 | watcher.observe(x=i) 13 | input('paused') 14 | 15 | main() 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /test/simple_log/sum_log.py: -------------------------------------------------------------------------------- 1 | import time, random 2 | import tensorwatch as tw 3 | 4 | # we will create two streams: one for sums of 5 | # integers, the other for sums of random numbers 6 | w = tw.Watcher() 7 | st_isum = w.create_stream('isums') 8 | st_rsum = w.create_stream('rsums') 9 | 10 | isum, rsum = 0, 0 11 | for i in range(10000): 12 | isum += i 13 | rsum += random.randint(0,i) 14 | 15 | # write to streams 16 | st_isum.write((i, isum)) 17 | st_rsum.write((i, rsum)) 18 | 19 | time.sleep(1) 20 | -------------------------------------------------------------------------------- /install_jupyterlab.bat: -------------------------------------------------------------------------------- 1 | conda install -c conda-forge jupyterlab nodejs 2 | conda install ipywidgets 3 | conda install -c plotly plotly-orca psutil 4 | 5 | set NODE_OPTIONS=--max-old-space-size=4096 6 | jupyter labextension install @jupyter-widgets/jupyterlab-manager --no-build 7 | jupyter labextension install plotlywidget --no-build 8 | jupyter labextension install @jupyterlab/plotly-extension --no-build 9 | jupyter labextension install jupyterlab-chart-editor --no-build 10 | jupyter lab build 11 | set NODE_OPTIONS=
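After running install_jupyterlab.bat above, the installs can be verified with the standard JupyterLab command `jupyter labextension list`, which should show each of the extensions above as enabled.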
-------------------------------------------------------------------------------- /test/zmq/zmq_pub.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import time 3 | from tensorwatch.zmq_wrapper import ZmqWrapper 4 | from tensorwatch import utils 5 | 6 | utils.set_debug_verbosity(10) 7 | 8 | def clisrv_callback(clisrv, msg): 9 | print('from clisrv', msg) 10 | 11 | stream = ZmqWrapper.Publication(port = 40859) 12 | clisrv = ZmqWrapper.ClientServer(40860, True, clisrv_callback) 13 | 14 | for i in range(10000): 15 | stream.send_obj({'a': i}, "Topic1") 16 | print("sent ", i) 17 | time.sleep(1) 18 | -------------------------------------------------------------------------------- /test/simple_log/srv_ij.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import time 3 | import random 4 | from tensorwatch import utils 5 | 6 | utils.set_debug_verbosity(4) 7 | 8 | srv = tw.Watcher() 9 | 10 | while(True): 11 | for i in range(1000): 12 | srv.observe("ev_i", val=i*random.random(), x=i) 13 | print('sent ev_i ', i) 14 | time.sleep(1) 15 | for j in range(5): 16 | srv.observe("ev_j", x=j, val=j*random.random()) 17 | print('sent ev_j ', j) 18 | time.sleep(0.5) 19 | srv.end_event("ev_j") 20 | srv.end_event("ev_i") -------------------------------------------------------------------------------- /test/visualizations/plotly_line.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.watcher_base import WatcherBase 2 | from tensorwatch.plotly.line_plot import LinePlot 3 | from tensorwatch.image_utils import plt_loop 4 | from tensorwatch.stream import Stream 5 | from tensorwatch.lv_types import StreamItem 6 | 7 | 8 | def main(): 9 | watcher = WatcherBase() 10 | line_plot = LinePlot() 11 | stream = watcher.create_stream(expr='lambda vars:vars.x') 12 | line_plot.subscribe(stream) 13 | line_plot.show() 14 | 15 | for i in range(5): 16 | watcher.observe(x=(i, i*i)) 17 | 18 | main() 19 | 20 | -------------------------------------------------------------------------------- /test/visualizations/mpl_line.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.watcher_base import WatcherBase 2 | from tensorwatch.mpl.line_plot import LinePlot 3 | from tensorwatch.image_utils import plt_loop 4 | from tensorwatch.stream import Stream 5 | from tensorwatch.lv_types import StreamItem 6 | 7 | 8 | def main(): 9 | watcher = WatcherBase() 10 | line_plot = LinePlot() 11 | stream = watcher.create_stream(expr='lambda vars:vars.x') 12 | line_plot.subscribe(stream) 13 | line_plot.show() 14 | 15 | for i in range(5): 16 | watcher.observe(x=(i, i*i)) 17 | plt_loop() 18 | 19 | main() 20 | 21 | -------------------------------------------------------------------------------- /test/zmq/zmq_sub.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import time 3 | from tensorwatch.zmq_wrapper import ZmqWrapper 4 | from tensorwatch import utils 5 | 6 | class A: 7 | def on_event(self, obj): 8 | print(obj) 9 | 10 | a = A() 11 | 12 | 13 | utils.set_debug_verbosity(10) 14 | sub = ZmqWrapper.Subscription(40859, "Topic1", a.on_event) 15 | print("subscriber is waiting") 16 | 17 | clisrv = ZmqWrapper.ClientServer(40860, False) 18 | clisrv.send_obj("hello 1") 19 | print('sleeping..') 20 | time.sleep(10) 21 | clisrv.send_obj("hello 2") 22 | 23 | print('waiting for 
key..') 24 | utils.wait_key() -------------------------------------------------------------------------------- /test/visualizations/arr_img_plot.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import numpy as np 3 | import time 4 | import torchvision.datasets as datasets 5 | 6 | fruits_ds = datasets.ImageFolder(r'D:\datasets\fruits-360\Training') 7 | mnist_ds = datasets.MNIST('../data', train=True, download=True) 8 | 9 | images = [tw.ImageData(fruits_ds[i][0], title=str(i)) for i in range(5)] + \ 10 | [tw.ImageData(mnist_ds[i][0], title=str(i)) for i in range(5)] 11 | 12 | stream = tw.ArrayStream(images) 13 | 14 | img_plot = tw.Visualizer(stream, vis_type='image', viz_img_scale=3) 15 | img_plot.show() 16 | 17 | tw.image_utils.plt_loop() -------------------------------------------------------------------------------- /test/post_train/saliency.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.saliency import saliency 2 | from tensorwatch import image_utils, imagenet_utils, pytorch_utils 3 | 4 | model = pytorch_utils.get_model('resnet50') 5 | raw_input, input, target_class = pytorch_utils.image_class2tensor('../data/test_images/dogs.png', 240, #'../data/elephant.png', 101, 6 | image_transform=imagenet_utils.get_image_transform(), image_convert_mode='RGB') 7 | 8 | results = saliency.get_image_saliency_results(model, raw_input, input, target_class) 9 | figure = saliency.get_image_saliency_plot(results) 10 | 11 | image_utils.plt_loop() 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /test/deps/thread.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | import sys 4 | 5 | def handler(a, b=None): 6 | sys.exit(1) 7 | def install_handler(): 8 | if sys.platform == "win32": 9 | if sys.stdin is not None and sys.stdin.isatty(): 10 | # this is a console-based application 11 | import win32api 12 | win32api.SetConsoleCtrlHandler(handler, True) 13 | 14 | 15 | def work(): 16 | time.sleep(10000) 17 | t = threading.Thread(target=work, name='ThreadTest') 18 | t.daemon = True 19 | t.start() 20 | while True: 21 | t.join(0.1) # 100ms ~ typical human response time 22 | # you will get a KeyboardInterrupt exception here on Ctrl+C 23 | 24 | -------------------------------------------------------------------------------- /tensorwatch/array_stream.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license.
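# ArrayStream adapts an in-memory array to the Stream interface: load() below pushes the whole array to subscribers as a single write.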
3 | 4 | from .stream import Stream 5 | from .lv_types import StreamItem 6 | import uuid 7 | 8 | class ArrayStream(Stream): 9 | def __init__(self, array, stream_name:str=None, console_debug:bool=False): 10 | super(ArrayStream, self).__init__(stream_name=stream_name, console_debug=console_debug) 11 | 12 | self.stream_name = stream_name 13 | self.array = array 14 | 15 | def load(self, from_stream:'Stream'=None): 16 | if self.array is not None: 17 | self.write(self.array) 18 | super(ArrayStream, self).load() 19 | -------------------------------------------------------------------------------- /test/visualizations/histogram.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import random, time 3 | 4 | def static_hist(): 5 | w = tw.Watcher() 6 | s = w.create_stream() 7 | 8 | v = tw.Visualizer(s, vis_type='histogram', bins=6) 9 | v.show() 10 | 11 | for _ in range(100): 12 | s.write(random.random()*10) 13 | 14 | tw.plt_loop() 15 | 16 | 17 | def dynamic_hist(): 18 | w = tw.Watcher() 19 | s = w.create_stream() 20 | 21 | v = tw.Visualizer(s, vis_type='histogram', bins=6, clear_after_each=True) 22 | v.show() 23 | 24 | for _ in range(100): 25 | s.write([random.random()*10 for _ in range(100)]) 26 | tw.plt_loop(count=3) 27 | 28 | dynamic_hist() -------------------------------------------------------------------------------- /test/visualizations/pie_chart.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import random, time 3 | 4 | def static_pie(): 5 | w = tw.Watcher() 6 | s = w.create_stream() 7 | 8 | v = tw.Visualizer(s, vis_type='pie', bins=6) 9 | v.show() 10 | 11 | for i in range(6): 12 | s.write(('label'+str(i), random.random()*10, None, 0.5 if i==3 else 0)) 13 | 14 | tw.plt_loop() 15 | 16 | 17 | def dynamic_pie(): 18 | w = tw.Watcher() 19 | s = w.create_stream() 20 | 21 | v = tw.Visualizer(s, vis_type='pie', bins=6, clear_after_each=True) 22 | v.show() 23 | 24 | for _ in range(100): 25 | s.write([('label'+str(i), random.random()*10, None, i*0.01) for i in range(12)]) 26 | tw.plt_loop(count=3) 27 | 28 | #static_pie() 29 | dynamic_pie() 30 | -------------------------------------------------------------------------------- /test/visualizations/line3d_plot.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import random, time 3 | 4 | # TODO: resolve problem with Axis3D? 
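# (A guess at the cause: older matplotlib requires `from mpl_toolkits.mplot3d import Axes3D` to register the 3D projection before 3D axes are created.)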
5 | 6 | def static_line3d(): 7 | w = tw.Watcher() 8 | s = w.create_stream() 9 | 10 | v = tw.Visualizer(s, vis_type='line3d') 11 | v.show() 12 | 13 | for i in range(10): 14 | s.write((i, i*i, int(random.random()*10))) 15 | 16 | tw.plt_loop() 17 | 18 | def dynamic_line3d(): 19 | w = tw.Watcher() 20 | s = w.create_stream() 21 | 22 | v = tw.Visualizer(s, vis_type='line3d', clear_after_each=True) 23 | v.show() 24 | 25 | for i in range(100): 26 | s.write([(i, random.random()*10, z) for i in range(10) for z in range(10)]) 27 | tw.plt_loop(count=3) 28 | 29 | static_line3d() 30 | #dynamic_line3d() 31 | 32 | -------------------------------------------------------------------------------- /test/files/file_stream.py: -------------------------------------------------------------------------------- 1 | from tensorwatch.watcher_base import WatcherBase 2 | from tensorwatch.stream import Stream 3 | from tensorwatch.file_stream import FileStream 4 | from tensorwatch import LinePlot 5 | from tensorwatch.image_utils import plt_loop 6 | import tensorwatch as tw 7 | 8 | def file_write(): 9 | watcher = WatcherBase() 10 | stream = watcher.create_stream(expr='lambda vars:(vars.x, vars.x**2)', 11 | devices=[r'c:\temp\obs.txt']) 12 | 13 | for i in range(5): 14 | watcher.observe(x=i) 15 | 16 | def file_read(): 17 | watcher = WatcherBase() 18 | stream = watcher.open_stream(devices=[r'c:\temp\obs.txt']) 19 | vis = tw.Visualizer(stream, vis_type='mpl-line') 20 | vis.show() 21 | plt_loop() 22 | 23 | def main(): 24 | file_write() 25 | file_read() 26 | 27 | main() 28 | 29 | -------------------------------------------------------------------------------- /test/simple_log/cli_file_expr.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | from tensorwatch import utils 3 | utils.set_debug_verbosity(4) 4 | 5 | 6 | #r = tw.Visualizer(vis_type='mpl-line') 7 | #r.show() 8 | #r2=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', cell=r.cell) 9 | #r3=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', renderer=r2) 10 | 11 | def show_mpl(): 12 | cli = tw.WatcherClient(r'c:\temp\sum.log') 13 | s1 = cli.open_stream('sum') 14 | p = tw.LinePlot(title='Demo') 15 | p.subscribe(s1, xtitle='Index', ytitle='sqrt(ev_i)') 16 | s1.load() 17 | p.show() 18 | 19 | tw.plt_loop() 20 | 21 | def show_text(): 22 | cli = tw.WatcherClient(r'c:\temp\sum.log') 23 | s1 = cli.open_stream('sum_2') 24 | text = tw.Visualizer(s1) 25 | text.show() 26 | input('Waiting') 27 | 28 | #show_text() 29 | show_mpl() -------------------------------------------------------------------------------- /test/simple_log/cli_sum_log.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | from tensorwatch import utils 3 | utils.set_debug_verbosity(4) 4 | 5 | 6 | #r = tw.Visualizer(vis_type='mpl-line') 7 | #r.show() 8 | #r2=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', cell=r.cell) 9 | #r3=tw.Visualizer('map(lambda x:math.sqrt(x.sum), l)', renderer=r2) 10 | 11 | def show_mpl(): 12 | cli = tw.WatcherClient() 13 | st_isum = cli.open_stream('isums') 14 | st_rsum = cli.open_stream('rsums') 15 | 16 | line_plot = tw.Visualizer(st_isum, vis_type='line', xtitle='i', ytitle='isum') 17 | line_plot.show() 18 | 19 | line_plot2 = tw.Visualizer(st_rsum, vis_type='line', host=line_plot, ytitle='rsum') 20 | 21 | tw.plt_loop() 22 | 23 | def show_text(): 24 | cli = tw.WatcherClient(); st_isum = cli.open_stream('isums') # open the stream logged by sum_log.py 25 | text_vis = tw.Visualizer(st_isum, vis_type='text') 26 | text_vis.show() 27 |
input('Waiting') 28 | 29 | #show_text() 30 | show_mpl() -------------------------------------------------------------------------------- /tensorwatch/stream_union.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .stream import Stream 5 | from typing import Iterator 6 | 7 | class StreamUnion(Stream): 8 | def __init__(self, child_streams:Iterator[Stream], for_write:bool, stream_name:str=None, console_debug:bool=False) -> None: 9 | super(StreamUnion, self).__init__(stream_name=stream_name, console_debug=console_debug) 10 | 11 | # save references; child streams go away only when the parent goes away 12 | self.child_streams = child_streams 13 | 14 | # when someone writes to us, we write to all our listeners 15 | if for_write: 16 | for child_stream in child_streams: 17 | child_stream.subscribe(self) 18 | else: 19 | # union of all child streams 20 | for child_stream in child_streams: 21 | self.subscribe(child_stream) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | This project welcomes contributions and suggestions. Most contributions require you to 4 | agree to a Contributor License Agreement (CLA) declaring that you have the right to, 5 | and actually do, grant us the rights to use your contribution. For details, visit 6 | https://cla.microsoft.com. 7 | 8 | When you submit a pull request, a CLA-bot will automatically determine whether you need 9 | to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the 10 | instructions provided by the bot. You will only need to do this once across all repositories using our CLA. 11 | 12 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 13 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 14 | or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
15 | -------------------------------------------------------------------------------- /tensorwatch.sln: -------------------------------------------------------------------------------- 1 | 2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 15 4 | VisualStudioVersion = 15.0.27428.2043 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "tensorwatch", "tensorwatch.pyproj", "{CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|Any CPU = Debug|Any CPU 11 | Release|Any CPU = Release|Any CPU 12 | EndGlobalSection 13 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 14 | {CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 15 | {CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}.Release|Any CPU.ActiveCfg = Release|Any CPU 16 | EndGlobalSection 17 | GlobalSection(SolutionProperties) = preSolution 18 | HideSolutionNode = FALSE 19 | EndGlobalSection 20 | GlobalSection(ExtensibilityGlobals) = postSolution 21 | SolutionGuid = {99E7AEC7-2CDE-48C8-B98B-4E28E4F840B6} 22 | EndGlobalSection 23 | EndGlobal 24 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import setuptools 5 | 6 | with open("README.md", "r") as fh: 7 | long_description = fh.read() 8 | 9 | setuptools.setup( 10 | name="tensorwatch", 11 | version="0.8.5", 12 | author="Shital Shah", 13 | author_email="shitals@microsoft.com", 14 | description="Interactive Realtime Debugging and Visualization for AI", 15 | long_description=long_description, 16 | long_description_content_type="text/markdown", 17 | url="https://github.com/microsoft/tensorwatch", 18 | packages=setuptools.find_packages(), 19 | license='MIT', 20 | classifiers=( 21 | "Programming Language :: Python :: 3", 22 | "License :: OSI Approved :: MIT License", 23 | "Operating System :: OS Independent", 24 | ), 25 | include_package_data=True, 26 | install_requires=[ 27 | 'matplotlib', 'numpy', 'pyzmq', 'plotly', 'torchstat', 'ipywidgets', 'sklearn', 'nbformat', 'scikit-image' # , 'receptivefield' 28 | ] 29 | ) -------------------------------------------------------------------------------- /test/visualizations/bar_plot.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | import random, time 3 | 4 | def static_bar(): 5 | w = tw.Watcher() 6 | s = w.create_stream() 7 | 8 | v = tw.Visualizer(s, vis_type='bar') 9 | v.show() 10 | 11 | for i in range(10): 12 | s.write(int(random.random()*10)) 13 | 14 | tw.plt_loop() 15 | 16 | 17 | def dynamic_bar(): 18 | w = tw.Watcher() 19 | s = w.create_stream() 20 | 21 | v = tw.Visualizer(s, vis_type='bar', clear_after_each=True) 22 | v.show() 23 | 24 | for i in range(100): 25 | s.write([('a'+str(i), random.random()*10) for i in range(10)]) 26 | tw.plt_loop(count=3) 27 | 28 | def dynamic_bar3d(): 29 | w = tw.Watcher() 30 | s = w.create_stream() 31 | 32 | v = tw.Visualizer(s, vis_type='bar3d', clear_after_each=True) 33 | v.show() 34 | 35 | for i in range(100): 36 | s.write([(i, random.random()*10, z) for i in range(10) for z in range(10)]) 37 | tw.plt_loop(count=3) 38 | 39 | static_bar() 40 | #dynamic_bar() 41 | #dynamic_bar3d() 42 | 
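The visualization tests above (histogram, pie chart, line3d, bar) all share one skeleton: create a Watcher, create a stream from it, attach a Visualizer with the desired vis_type, write data points, then pump the matplotlib loop. A minimal sketch of that skeleton, distilled from the tests themselves:

import tensorwatch as tw

w = tw.Watcher()                       # the watcher owns the streams
s = w.create_stream()                  # stream that data points are written to
v = tw.Visualizer(s, vis_type='line')  # any vis_type used by the tests above
v.show()

for i in range(10):
    s.write((i, i * i))                # (x, y) data points

tw.plt_loop()                          # keep the matplotlib window updating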
-------------------------------------------------------------------------------- /test/deps/live_graph.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | from matplotlib.animation import FuncAnimation 3 | from random import randrange 4 | from threading import Thread 5 | import time 6 | 7 | class LiveGraph: 8 | def __init__(self): 9 | self.x_data, self.y_data = [], [] 10 | self.figure = plt.figure() 11 | self.line, = plt.plot(self.x_data, self.y_data) 12 | self.animation = FuncAnimation(self.figure, self.update, interval=1000) 13 | self.th = Thread(target=self.thread_f, name='LiveGraph', daemon=True) 14 | self.th.start() 15 | 16 | def update(self, frame): 17 | self.line.set_data(self.x_data, self.y_data) 18 | self.figure.gca().relim() 19 | self.figure.gca().autoscale_view() 20 | return self.line, 21 | 22 | def show(self): 23 | plt.show() 24 | 25 | def thread_f(self): 26 | x = 0 27 | while True: 28 | self.x_data.append(x) 29 | x += 1 30 | self.y_data.append(randrange(0, 100)) 31 | time.sleep(1) -------------------------------------------------------------------------------- /LICENSE.TXT: -------------------------------------------------------------------------------- 1 | Copyright (c) Microsoft Corporation. All rights reserved. 2 | 3 | MIT License 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /tensorwatch/receptive_field/rf_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
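# Thin wrappers over the external `receptivefield` package; note that setup.py keeps 'receptivefield' commented out of install_requires, so it must be installed separately.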
3 | 4 | from receptivefield.pytorch import PytorchReceptiveField 5 | #from receptivefield.image import get_default_image 6 | import numpy as np 7 | 8 | def _get_rf(model, sample_pil_img): 9 | # define model functions 10 | def model_fn(): 11 | model.eval() 12 | return model 13 | 14 | input_shape = np.array(sample_pil_img).shape 15 | 16 | rf = PytorchReceptiveField(model_fn) 17 | rf_params = rf.compute(input_shape=input_shape) 18 | return rf, rf_params 19 | 20 | def plot_receptive_field(model, sample_pil_img, layout=(2, 2), figsize=(6, 6)): 21 | rf, rf_params = _get_rf(model, sample_pil_img) # pylint: disable=unused-variable 22 | return rf.plot_rf_grids( 23 | custom_image=sample_pil_img, 24 | figsize=figsize, 25 | layout=layout) 26 | 27 | def plot_grads_at(model, sample_pil_img, feature_map_index=0, point=(8,8), figsize=(6, 6)): 28 | rf, rf_params = _get_rf(model, sample_pil_img) # pylint: disable=unused-variable 29 | return rf.plot_gradient_at(fm_id=feature_map_index, point=point, image=None, figsize=figsize) 30 | -------------------------------------------------------------------------------- /TODO.md: -------------------------------------------------------------------------------- 1 | * Fix cell size issue 2 | * Refactor _plot* interface to accept all values, for ImagePlot only use last value 3 | * Refactor ImagePlot for arbitrary number of images with alpha, cmap 4 | * Change tw.open -> tw.create_viz 5 | * Make sure streams have names as key, each data point has index 6 | * Add tw.open_viz(stream_name, from_index) 7 | * Add persist=device_name option for streams 8 | * Ability to use streams in standalone mode 9 | * tw.create_viz on server side 10 | * tw.log for server side 11 | * experiment with IPC channel 12 | * confusion matrix as in https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html 13 | * Speed up import 14 | * Do linting 15 | * live perf data 16 | * NaN tracing 17 | * PCA 18 | * Remove error if MNIST notebook is on and we run fruits 19 | * Remove 2nd image from fruits 20 | * clear existing streams when starting client 21 | * ImageData should accept numpy array, pillow image, or torch tensor 22 | * image plot getting refreshed at 12 Hz instead of 2 Hz in MNIST 23 | * image plot doesn't show title 24 | * Animated mesh/surface graph demo 25 | * Move to h5 storage? 26 | * Error envelope 27 | * histogram 28 | * new graph on end 29 | * TF support 30 | * generic visualizer -> Given obj and Box, paint in box 31 | * visualize image and text with attention 32 | * add confidence interval for plotly: https://plot.ly/python/continuous-error-bars/ -------------------------------------------------------------------------------- /tensorwatch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license.
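# Top-level package: re-exports the watcher, stream and visualizer types below so that typical usage is simply `import tensorwatch as tw`.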
3 | 4 | from typing import Iterable, Sequence, Union 5 | 6 | from .watcher_client import WatcherClient 7 | from .watcher import Watcher 8 | from .watcher_base import WatcherBase 9 | 10 | from .text_vis import TextVis 11 | from .plotly.embeddings_plot import EmbeddingsPlot 12 | from .mpl.line_plot import LinePlot 13 | from .mpl.image_plot import ImagePlot 14 | from .mpl.histogram import Histogram 15 | from .mpl.bar_plot import BarPlot 16 | from .mpl.pie_chart import PieChart 17 | from .visualizer import Visualizer 18 | 19 | from .stream import Stream 20 | from .array_stream import ArrayStream 21 | from .lv_types import PointData, ImageData, VisArgs, StreamItem, PredictionResult 22 | from . import utils 23 | 24 | ###### Import methods for tw namespace ######### 25 | #from .receptive_field.rf_utils import plot_receptive_field, plot_grads_at 26 | from .embeddings.tsne_utils import get_tsne_components 27 | from .model_graph.torchstat_utils import model_stats 28 | from .image_utils import show_image, open_image, img2pyt, linear_to_2d, plt_loop 29 | from .data_utils import pyt_ds2list, sample_by_class, col2array, search_similar 30 | 31 | 32 | 33 | def draw_model(model, input_shape=None, orientation='TB'): # orientation='LR' for landscape 34 | from .model_graph.hiddenlayer import graph 35 | g = graph.build_graph(model, input_shape, orientation=orientation) 36 | return g 37 | 38 | 39 | -------------------------------------------------------------------------------- /tensorwatch/saliency/lime/wrappers/generic_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import inspect 3 | import types 4 | 5 | 6 | def has_arg(fn, arg_name): 7 | """Checks if a callable accepts a given keyword argument. 8 | 9 | Args: 10 | fn: callable to inspect 11 | arg_name: string, keyword argument name to check 12 | 13 | Returns: 14 | bool, whether `fn` accepts an `arg_name` keyword argument. 15 | """ 16 | if sys.version_info < (3,): 17 | if isinstance(fn, types.FunctionType) or isinstance(fn, types.MethodType): 18 | arg_spec = inspect.getargspec(fn) 19 | else: 20 | try: 21 | arg_spec = inspect.getargspec(fn.__call__) 22 | except AttributeError: 23 | return False 24 | return (arg_name in arg_spec.args) 25 | elif sys.version_info < (3, 6): 26 | arg_spec = inspect.getfullargspec(fn) 27 | return (arg_name in arg_spec.args or 28 | arg_name in arg_spec.kwonlyargs) 29 | else: 30 | try: 31 | signature = inspect.signature(fn) 32 | except ValueError: 33 | # handling Cython 34 | signature = inspect.signature(fn.__call__) 35 | parameter = signature.parameters.get(arg_name) 36 | if parameter is None: 37 | return False 38 | return (parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, 39 | inspect.Parameter.KEYWORD_ONLY)) 40 | -------------------------------------------------------------------------------- /tensorwatch/filtered_stream.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license.
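# FilteredStream forwards items from a source stream through filter_expr, which must return a (result, is_valid) tuple; items with is_valid=False are dropped.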
3 | 4 | from .stream import Stream 5 | from typing import Callable, Any 6 | 7 | class FilteredStream(Stream): 8 | def __init__(self, source_stream:Stream, filter_expr:Callable, stream_name:str=None, 9 | console_debug:bool=False)->None: 10 | 11 | stream_name = stream_name or '{}|{}'.format(source_stream.stream_name, str(filter_expr)) 12 | super(FilteredStream, self).__init__(stream_name=stream_name, console_debug=console_debug) 13 | self.subscribe(source_stream) 14 | self.filter_expr = filter_expr 15 | 16 | def _filter(self, stream_item): 17 | return self.filter_expr(stream_item) \ 18 | if self.filter_expr is not None \ 19 | else (stream_item, True) 20 | 21 | def write(self, val:Any, from_stream:'Stream'=None): 22 | stream_item = self.to_stream_item(val) 23 | 24 | result, is_valid = self._filter(stream_item) 25 | if is_valid: 26 | return super(FilteredStream, self).write(result) 27 | # else ignore this call 28 | 29 | def read_all(self, from_stream:'Stream'=None): #override->replacement 30 | for subscribed_to in self._subscribed_to: 31 | for stream_item in subscribed_to.read_all(from_stream=self): 32 | result, is_valid = self._filter(stream_item) 33 | if is_valid: 34 | yield stream_item 35 | 36 | -------------------------------------------------------------------------------- /tensorwatch/embeddings/tsne_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from sklearn.manifold import TSNE 5 | from sklearn.preprocessing import StandardScaler 6 | import numpy as np 7 | 8 | def _standardize_data(data, col, whitten, flatten): 9 | if col is not None: 10 | data = data[col] 11 | 12 | #TODO: enable auto flattening 13 | #if data is tensor then flatten it first 14 | #if flatten and len(data) > 0 and hasattr(data[0], 'shape') and \ 15 | # utils.has_method(data[0], 'reshape'): 16 | 17 | # data = [d.reshape((-1,)) for d in data] 18 | 19 | if whitten: 20 | data = StandardScaler().fit_transform(data) 21 | return data 22 | 23 | def get_tsne_components(data, features_col=0, labels_col=1, whitten=True, n_components=3, perplexity=20, flatten=True, for_plot=True): 24 | features = _standardize_data(data, features_col, whitten, flatten) 25 | tsne = TSNE(n_components=n_components, perplexity=perplexity) 26 | tsne_results = tsne.fit_transform(features) 27 | 28 | if for_plot: 29 | comps = tsne_results.tolist() 30 | labels = data[labels_col] 31 | for i, item in enumerate(comps): 32 | # low, high, annotation, text, color 33 | label = labels[i] 34 | if isinstance(labels, np.ndarray): 35 | label = label.item() 36 | item.extend((None, None, None, str(int(label)), label)) 37 | return comps 38 | return tsne_results 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /test/components/file_only_test.py: -------------------------------------------------------------------------------- 1 | import tensorwatch as tw 2 | 3 | def writer(): 4 | watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None) 5 | 6 | with watcher.create_stream('metric1') as stream1: 7 | for i in range(3): 8 | stream1.write((i, i*i)) 9 | 10 | with watcher.create_stream('metric2') as stream2: 11 | for i in range(3): 12 | stream2.write((i, i*i*i)) 13 | 14 | def reader1(): 15 | print('---------------------------reader1---------------------------') 16 | watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None) 17 | 18 | stream1 = watcher.open_stream('metric1') 19 | 
stream1.console_debug = True 20 | stream1.load() 21 | 22 | stream2 = watcher.open_stream('metric2') 23 | stream2.console_debug = True 24 | stream2.load() 25 | 26 | def reader2(): 27 | print('---------------------------reader2---------------------------') 28 | 29 | watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None) 30 | 31 | stream1 = watcher.open_stream('metric1') 32 | for item in stream1.read_all(): 33 | print(item) 34 | 35 | stream2 = watcher.open_stream('metric2') 36 | for item in stream2.read_all(): 37 | print(item) 38 | 39 | def reader3(): 40 | print('---------------------------reader3---------------------------') 41 | 42 | watcher = tw.Watcher(filename=r'c:\temp\test.log', port=None) 43 | stream1 = watcher.open_stream('metric1') 44 | stream2 = watcher.open_stream('metric2') 45 | 46 | vis1 = tw.Visualizer(stream1, vis_type='line') 47 | vis2 = tw.Visualizer(stream2, vis_type='line', host=vis1) 48 | 49 | vis1.show() 50 | 51 | tw.plt_loop() 52 | 53 | 54 | writer() 55 | reader1() 56 | reader2() 57 | reader3() 58 | 59 | -------------------------------------------------------------------------------- /tensorwatch/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import random 5 | import scipy.spatial.distance 6 | import heapq 7 | import numpy as np 8 | import torch 9 | 10 | def pyt_tensor2np(pyt_tensor): 11 | if pyt_tensor is None: 12 | return None 13 | if isinstance(pyt_tensor, torch.Tensor): 14 | n = pyt_tensor.data.cpu().numpy() 15 | if len(n.shape) == 1: 16 | return n[0] 17 | else: 18 | return n 19 | elif isinstance(pyt_tensor, np.ndarray): 20 | return pyt_tensor 21 | else: 22 | return np.array(pyt_tensor) 23 | 24 | def pyt_tuple2np(pyt_tuple): 25 | return tuple((pyt_tensor2np(t) for t in pyt_tuple)) 26 | 27 | def pyt_ds2list(pyt_ds, count=None): 28 | count = count or len(pyt_ds) 29 | return [pyt_tuple2np(t) for t, c in zip(pyt_ds, range(count))] 30 | 31 | def sample_by_class(data, n_samples, class_col=1, shuffle=True): 32 | if shuffle: 33 | random.shuffle(data) 34 | samples = {} 35 | for i, t in enumerate(data): 36 | cls = t[class_col] 37 | if cls not in samples: 38 | samples[cls] = [] 39 | if len(samples[cls]) < n_samples: 40 | samples[cls].append(data[i]) 41 | samples = sum(samples.values(), []) 42 | return samples 43 | 44 | def col2array(dataset, col): 45 | return [row[col] for row in dataset] 46 | 47 | def search_similar(inputs, compare_to, algorithm='euclidean', topk=5, invert_score=True): 48 | all_scores = scipy.spatial.distance.cdist(inputs, compare_to, algorithm) 49 | all_results = [] 50 | for input_val, scores in zip(inputs, all_scores): 51 | result = [] 52 | for i, (score, data) in enumerate(zip(scores, compare_to)): 53 | if invert_score: 54 | score = 1/(score + 1.0E-6) 55 | if len(result) < topk: 56 | heapq.heappush(result, (score, (i, input_val, data))) 57 | else: 58 | heapq.heappushpop(result, (score, (i, input_val, data))) 59 | all_results.append(result) 60 | return all_results -------------------------------------------------------------------------------- /tensorwatch/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from torchvision import models, transforms 5 | import torch 6 | import torch.nn.functional as F 7 | from . 
import utils, image_utils 8 | import os 9 | 10 | def get_model(model_name): 11 | model = models.__dict__[model_name](pretrained=True) 12 | return model 13 | 14 | def tensors2batch(tensors, preprocess_transform=None): 15 | if preprocess_transform: 16 | tensors = tuple(preprocess_transform(i) for i in tensors) 17 | if not utils.is_array_like(tensors): 18 | tensors = tuple(tensors) 19 | return torch.stack(tensors, dim=0) 20 | 21 | def int2tensor(val): 22 | return torch.LongTensor([val]) 23 | 24 | def image2batch(image, image_transform=None): 25 | if image_transform: 26 | input_x = image_transform(image) 27 | else: # if no transforms supplied then just convert PIL image to tensor 28 | input_x = transforms.ToTensor()(image) 29 | input_x = input_x.unsqueeze(0) #convert to batch of 1 30 | return input_x 31 | 32 | def image_class2tensor(image_path, class_index=None, image_convert_mode=None, 33 | image_transform=None): 34 | image_pil = image_utils.open_image(os.path.abspath(image_path), convert_mode=image_convert_mode) 35 | input_x = image2batch(image_pil, image_transform) 36 | target_class = int2tensor(class_index) if class_index is not None else None 37 | return image_pil, input_x, target_class 38 | 39 | def batch_predict(model, inputs, input_transform=None, device=None): 40 | if input_transform: 41 | batch = torch.stack(tuple(input_transform(i) for i in inputs), dim=0) 42 | else: 43 | batch = torch.stack(inputs, dim=0) 44 | 45 | device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") 46 | model.eval() 47 | model.to(device) 48 | batch = batch.to(device) 49 | 50 | outputs = model(batch) 51 | 52 | return outputs 53 | 54 | def logits2probabilities(logits): 55 | return F.softmax(logits, dim=1) 56 | 57 | def tensor2numpy(t): 58 | return t.data.cpu().numpy() 59 | -------------------------------------------------------------------------------- /tensorwatch/saliency/gradcam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .backprop import VanillaGradExplainer 3 | 4 | 5 | def _get_layer(model, key_list): 6 | if key_list is None: 7 | return None 8 | 9 | a = model 10 | for key in key_list: 11 | a = a._modules[key] 12 | return a 13 | 14 | class GradCAMExplainer(VanillaGradExplainer): 15 | def __init__(self, model, target_layer_name_keys=None, use_inp=False): 16 | super(GradCAMExplainer, self).__init__(model) 17 | self.target_layer = _get_layer(model, target_layer_name_keys) 18 | self.use_inp = use_inp 19 | self.intermediate_act = [] 20 | self.intermediate_grad = [] 21 | self._register_forward_backward_hook() 22 | 23 | def _register_forward_backward_hook(self): 24 | def forward_hook_input(m, i, o): 25 | self.intermediate_act.append(i[0].data.clone()) 26 | 27 | def forward_hook_output(m, i, o): 28 | self.intermediate_act.append(o.data.clone()) 29 | 30 | def backward_hook(m, grad_i, grad_o): 31 | self.intermediate_grad.append(grad_o[0].data.clone()) 32 | 33 | if self.target_layer is not None: 34 | if self.use_inp: 35 | self.target_layer.register_forward_hook(forward_hook_input) 36 | else: 37 | self.target_layer.register_forward_hook(forward_hook_output) 38 | 39 | self.target_layer.register_backward_hook(backward_hook) 40 | 41 | def _reset_intermediate_lists(self): 42 | self.intermediate_act = [] 43 | self.intermediate_grad = [] 44 | 45 | def explain(self, inp, ind=None, raw_inp=None): 46 | self._reset_intermediate_lists() 47 | 48 | _ = super(GradCAMExplainer, self)._backprop(inp, ind) 49 | 50 | if len(self.intermediate_grad): 
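# Grad-CAM: weight each activation map by its spatially-summed gradient, sum over channels, then clamp at zero (the ReLU in the Grad-CAM formula):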
51 | grad = self.intermediate_grad[0] 52 | act = self.intermediate_act[0] 53 | 54 | weights = grad.sum(-1).sum(-1).unsqueeze(-1).unsqueeze(-1) 55 | cam = weights * act 56 | cam = cam.sum(1).unsqueeze(1) 57 | 58 | cam = torch.clamp(cam, min=0) 59 | 60 | return cam 61 | else: 62 | return None 63 | -------------------------------------------------------------------------------- /tensorwatch/zmq_stream.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from typing import Any 5 | from .zmq_wrapper import ZmqWrapper 6 | from .stream import Stream 7 | from .lv_types import DefaultPorts, PublisherTopics 8 | from . import utils 9 | 10 | # on writes send data on ZMQ transport 11 | class ZmqStream(Stream): 12 | def __init__(self, for_write:bool, port:int=0, topic=PublisherTopics.StreamItem, block_until_connected=True, 13 | stream_name:str=None, console_debug:bool=False): 14 | super(ZmqStream, self).__init__(stream_name=stream_name, console_debug=console_debug) 15 | 16 | self.for_write = for_write 17 | self._zmq = None 18 | 19 | self.topic = topic 20 | self._open(for_write, port, block_until_connected) 21 | utils.debug_log('ZmqStream started', verbosity=1) 22 | 23 | def _open(self, for_write:bool, port:int, block_until_connected:bool): 24 | if for_write: 25 | self._zmq = ZmqWrapper.Publication(port=DefaultPorts.PubSub+port, 26 | block_until_connected=block_until_connected) 27 | else: 28 | self._zmq = ZmqWrapper.Subscription(port=DefaultPorts.PubSub+port, 29 | topic=self.topic, callback=self._on_subscription_item) 30 | 31 | def close(self): 32 | if not self.closed: 33 | self._zmq.close() 34 | self._zmq = None 35 | utils.debug_log('ZmqStream is closed', verbosity=1) 36 | super(ZmqStream, self).close() 37 | 38 | def _on_subscription_item(self, val:Any): 39 | utils.debug_log('Received subscription item', verbosity=5) 40 | self.write(val) 41 | 42 | def write(self, val:Any, from_stream:'Stream'=None, topic=None): 43 | stream_item = self.to_stream_item(val) 44 | 45 | if self.for_write: 46 | topic = topic or self.topic 47 | utils.debug_log('Sent subscription item', verbosity=5) 48 | self._zmq.send_obj(stream_item, topic) 49 | # else if this was opened for read then we have subscription and 50 | # we shouldn't be calling send_obj 51 | super(ZmqStream, self).write(stream_item) 52 | -------------------------------------------------------------------------------- /tensorwatch/saliency/occlusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.autograd import Variable 4 | from skimage.util import view_as_windows 5 | 6 | # modified from https://github.com/marcoancona/DeepExplain/blob/master/deepexplain/tensorflow/methods.py#L291-L342 7 | # note the different dim order in pytorch (NCHW) and tensorflow (NHWC) 8 | 9 | class OcclusionExplainer: 10 | def __init__(self, model, window_shape=10, step=1): 11 | self.model = model 12 | self.window_shape = window_shape 13 | self.step = step 14 | 15 | def explain(self, inp, ind=None, raw_inp=None): 16 | self.model.eval() 17 | with torch.no_grad(): 18 | return OcclusionExplainer._occlusion(inp, self.model, self.window_shape, self.step) 19 | 20 | @staticmethod 21 | def _occlusion(inp, model, window_shape, step=None): 22 | if type(window_shape) == int: 23 | window_shape = (window_shape, window_shape, 3) 24 | 25 | if step is None: 26 | step = 1 27 | n, c, h, w = 
inp.data.size()
28 |         total_dim = c * h * w
29 |         index_matrix = np.arange(total_dim).reshape(h, w, c)
30 |         idx_patches = view_as_windows(index_matrix, window_shape, step).reshape(
31 |             (-1,) + window_shape)
32 |         heatmap = np.zeros((n, h, w, c), dtype=np.float32).reshape(-1, total_dim)
33 |         weights = np.zeros_like(heatmap)
34 | 
35 |         inp_data = inp.data.clone()
36 |         new_inp = Variable(inp_data)
37 |         eval0 = model(new_inp)
38 |         pred_id = eval0.max(1)[1].data[0] # prediction taken from the first sample; batch size of 1 is assumed
39 | 
40 |         for i, p in enumerate(idx_patches):
41 |             mask = np.ones((h, w, c)).flatten()
42 |             mask[p.flatten()] = 0
43 |             th_mask = torch.from_numpy(mask.reshape(1, h, w, c).transpose(0, 3, 1, 2)).float().to(inp_data.device) # follow the input's device instead of assuming CUDA
44 |             masked_xs = Variable(th_mask * inp_data)
45 |             delta = (eval0[0, pred_id] - model(masked_xs)[0, pred_id]).data.cpu().numpy()
46 |             delta_aggregated = np.sum(delta.reshape(n, -1), -1, keepdims=True)
47 |             heatmap[:, p.flatten()] += delta_aggregated
48 |             weights[:, p.flatten()] += p.size
49 | 
50 |         attribution = np.reshape(heatmap / (weights + 1e-10), (n, h, w, c)).transpose(0, 3, 1, 2)
51 |         return torch.from_numpy(attribution)
--------------------------------------------------------------------------------
/tensorwatch/file_stream.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | from .stream import Stream
5 | import pickle, os
6 | from typing import Any
7 | from . import utils
8 | import time
9 | 
10 | class FileStream(Stream):
11 |     def __init__(self, for_write:bool, file_name:str, stream_name:str=None, console_debug:bool=False):
12 |         super(FileStream, self).__init__(stream_name=stream_name or file_name, console_debug=console_debug)
13 | 
14 |         self._file = open(file_name, 'wb' if for_write else 'rb')
15 |         self.file_name = file_name
16 |         self.for_write = for_write
17 |         utils.debug_log('FileStream started', self.file_name, verbosity=1)
18 | 
19 |     def close(self):
20 |         if not self._file.closed:
21 |             self._file.close()
22 |             self._file = None
23 |             utils.debug_log('FileStream is closed', self.file_name, verbosity=1)
24 |             super(FileStream, self).close()
25 | 
26 |     def write(self, val:Any, from_stream:'Stream'=None):
27 |         stream_item = self.to_stream_item(val)
28 | 
29 |         if self.for_write:
30 |             pickle.dump(stream_item, self._file)
31 |         super(FileStream, self).write(stream_item)
32 | 
33 |     def read_all(self, from_stream:'Stream'=None):
34 |         if self.for_write:
35 |             raise IOError('Cannot use read_all() because FileStream is opened with for_write=True')
36 |         if self._file is not None:
37 |             self._file.seek(0, 0) # we may filter this stream multiple times
38 |             while not utils.is_eof(self._file):
39 |                 yield pickle.load(self._file)
40 |         for item in super(FileStream, self).read_all():
41 |             yield item
42 | 
43 |     def load(self, from_stream:'Stream'=None):
44 |         if self.for_write:
45 |             raise IOError('Cannot use load() because FileStream is opened with for_write=True')
46 |         if self._file is not None:
47 |             self._file.seek(0, 0) # we may filter this stream multiple times
48 |             while not utils.is_eof(self._file):
49 |                 stream_item = pickle.load(self._file)
50 |                 self.write(stream_item)
51 |         super(FileStream, self).load()
52 | 
53 |     def save(self, from_stream:'Stream'=None):
54 |         if not self._file.closed:
55 |             self._file.flush()
56 |         super(FileStream, self).save(from_stream)
57 | 
--------------------------------------------------------------------------------
/tensorwatch/saliency/lime_image_explainer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | from skimage.segmentation import mark_boundaries
5 | from torchvision import transforms
6 | import torch
7 | from .lime import lime_image
8 | import numpy as np
9 | from .. import imagenet_utils, pytorch_utils, utils
10 | 
11 | class LimeImageExplainer:
12 |     def __init__(self, model, predict_fn):
13 |         self.model = model
14 |         self.predict_fn = predict_fn
15 | 
16 |     def preprocess_input(self, inp):
17 |         return inp
18 |     def preprocess_label(self, label):
19 |         return label
20 | 
21 |     def explain(self, inp, ind=None, raw_inp=None, top_labels=5, hide_color=0, num_samples=1000,
22 |                 positive_only=True, num_features=5, hide_rest=True, pixel_val_max=255.0):
23 |         explainer = lime_image.LimeImageExplainer()
24 |         explanation = explainer.explain_instance(self.preprocess_input(raw_inp), self.predict_fn,
25 |                                                  top_labels=top_labels, hide_color=hide_color, num_samples=num_samples)
26 | 
27 |         temp, mask = explanation.get_image_and_mask(self.preprocess_label(ind) if ind is not None else explanation.top_labels[0],
28 |                                                     positive_only=positive_only, num_features=num_features, hide_rest=hide_rest)
29 | 
30 |         img = mark_boundaries(temp/pixel_val_max, mask)
31 |         img = torch.from_numpy(img)
32 |         img = torch.transpose(img, 0, 2)
33 |         img = torch.transpose(img, 1, 2) # HWC -> CHW
34 |         return img.unsqueeze(0)
35 | 
36 | class LimeImagenetExplainer(LimeImageExplainer):
37 |     def __init__(self, model, predict_fn=None):
38 |         super(LimeImagenetExplainer, self).__init__(model, predict_fn or self._imagenet_predict)
39 | 
40 |     def _preprocess_transform(self):
41 |         transf = transforms.Compose([
42 |             transforms.ToTensor(),
43 |             imagenet_utils.get_normalize_transform()
44 |         ])
45 | 
46 |         return transf
47 | 
48 |     def preprocess_input(self, inp):
49 |         return np.array(imagenet_utils.get_resize_transform()(inp))
50 |     def preprocess_label(self, label):
51 |         return label.item() if label is not None and utils.has_method(label, 'item') else label
52 | 
53 |     def _imagenet_predict(self, images):
54 |         probs = imagenet_utils.predict(self.model, images, image_transform=self._preprocess_transform())
55 |         return pytorch_utils.tensor2numpy(probs)
--------------------------------------------------------------------------------
/tensorwatch/repeated_timer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
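# Usage sketch (illustrative; the callback object below is an assumption, not
# part of this module). The callback is held via weakref.WeakMethod, so it must
# be a *bound method* whose owner outlives the timer:
#
#     timer = RepeatedTimer(secs=2, callback=some_object.on_tick)
#     timer.start()
#     ...
#     timer.stop(block=True)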
3 | 
4 | import threading
5 | import time
6 | import weakref
7 | 
8 | class RepeatedTimer:
9 |     class State:
10 |         Stopped=0
11 |         Paused=1
12 |         Running=2
13 | 
14 |     def __init__(self, secs, callback, count=None):
15 |         self.secs = secs
16 |         self.callback = weakref.WeakMethod(callback) if callback else None
17 |         self._thread = None
18 |         self._state = RepeatedTimer.State.Stopped
19 |         self.pause_wait = threading.Event()
20 |         self.pause_wait.set()
21 |         self._continue_thread = False
22 |         self.count = count
23 | 
24 |     def start(self):
25 |         self._continue_thread = True
26 |         self.pause_wait.set()
27 |         if self._thread is None or not self._thread.is_alive():
28 |             self._thread = threading.Thread(target=self._runner, name='RepeatedTimer', daemon=True)
29 |             self._thread.start()
30 |         self._state = RepeatedTimer.State.Running
31 | 
32 |     def stop(self, block=False):
33 |         self.pause_wait.set()
34 |         self._continue_thread = False
35 |         if block and not (self._thread is None or not self._thread.is_alive()):
36 |             self._thread.join()
37 |         self._state = RepeatedTimer.State.Stopped
38 | 
39 |     def get_state(self):
40 |         return self._state
41 | 
42 | 
43 |     def pause(self):
44 |         if self._state == RepeatedTimer.State.Running:
45 |             self.pause_wait.clear()
46 |             self._state = RepeatedTimer.State.Paused
47 |         # else nothing to do
48 |     def unpause(self):
49 |         if self._state == RepeatedTimer.State.Paused:
50 |             self.pause_wait.set()
51 |             self._state = RepeatedTimer.State.Running
52 |         # else nothing to do
53 | 
54 |     def _runner(self):
55 |         while self._continue_thread:
56 |             if self.count:
57 |                 self.count -= 1
58 |                 if not self.count:
59 |                     self._continue_thread = False
60 | 
61 |             if self._continue_thread:
62 |                 self.pause_wait.wait()
63 |                 if self.callback and self.callback(): # WeakMethod deref; None means the target was collected
64 |                     self.callback()()
65 | 
66 |             if self._continue_thread:
67 |                 time.sleep(self.secs)
68 | 
69 |         self._thread = None
70 |         self._state = RepeatedTimer.State.Stopped
--------------------------------------------------------------------------------
/test/simple_log/cli_ij.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import time
3 | import math
4 | from tensorwatch import utils
5 | utils.set_debug_verbosity(4)
6 | 
7 | def mpl_line_plot():
8 |     cli = tw.WatcherClient()
9 |     p = tw.LinePlot(title='Demo')
10 |     s1 = cli.create_stream(event_name='ev_i', expr='map(lambda v:math.sqrt(v.val)*2, l)')
11 |     p.subscribe(s1, xtitle='Index', ytitle='sqrt(ev_i)')
12 |     p.show()
13 |     tw.plt_loop()
14 | 
15 | def mpl_history_plot():
16 |     cli = tw.WatcherClient()
17 |     p2 = tw.LinePlot(title='History Demo')
18 |     p2s1 = cli.create_stream(event_name='ev_j', expr='map(lambda v:(v.val, math.sqrt(v.val)*2), l)')
19 |     p2.subscribe(p2s1, xtitle='Index', ytitle='sqrt(ev_j)', clear_after_end=True, history_len=15)
20 |     p2.show()
21 |     tw.plt_loop()
22 | 
23 | def show_stream():
24 |     cli = tw.WatcherClient()
25 | 
26 |     print("Subscribing to event ev_i...")
27 |     s1 = cli.create_stream(event_name="ev_i", expr='map(lambda v:math.sqrt(v.val), l)')
28 |     r1 = tw.TextVis(title='L1')
29 |     r1.subscribe(s1)
30 |     r1.show()
31 | 
32 |     print("Subscribing to event ev_j...")
33 |     s2 = cli.create_stream(event_name="ev_j", expr='map(lambda v:v.val*v.val, l)')
34 |     r2 = tw.TextVis(title='L2')
35 |     r2.subscribe(s2)
36 | 
37 |     r2.show()
38 | 
39 |     print("Waiting for key...")
40 | 
41 |     utils.wait_key()
42 | 
43 | # this is no longer directly supported
44 | # TODO: create stream that allows enumeration from buffered values
45 | #def read_stream():
46 | #    cli = tw.WatcherClient()
47 | 
48 | #    with cli.create_stream(event_name="ev_i", expr='map(lambda v:(v.x, math.sqrt(v.val)), l)') as s1:
49 | #        for stream_item in s1:
50 | #            print(stream_item.value)
51 | #    print('done')
52 | #    utils.wait_key()
53 | 
54 | def plotly_line_graph():
55 |     cli = tw.WatcherClient()
56 |     s1 = cli.create_stream(event_name="ev_i", expr='map(lambda v:(v.x, math.sqrt(v.val)), l)')
57 | 
58 |     p = tw.plotly.line_plot.LinePlot()
59 |     p.subscribe(s1)
60 |     p.show()
61 | 
62 |     utils.wait_key()
63 | 
64 | def plotly_history_graph():
65 |     cli = tw.WatcherClient()
66 |     p = tw.plotly.line_plot.LinePlot(title='Demo')
67 |     s2 = cli.create_stream(event_name='ev_j', expr='map(lambda v:(v.x, v.val), l)')
68 |     p.subscribe(s2, ytitle='ev_j', history_len=15)
69 |     p.show()
70 |     utils.wait_key()
71 | 
72 | 
73 | mpl_line_plot()
74 | #mpl_history_plot()
75 | #show_stream()
76 | #plotly_line_graph()
77 | #plotly_history_graph()
78 | 
--------------------------------------------------------------------------------
/tensorwatch/mpl/pie_chart.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | from .base_mpl_plot import BaseMplPlot
5 | from .. import image_utils
6 | from .. import utils
7 | import numpy as np
8 | 
9 | class PieChart(BaseMplPlot):
10 |     def init_stream_plot(self, stream_vis, autopct=None, colormap=None, color=None,
11 |                          shadow=None, **stream_vis_args):
12 | 
13 |         # add main subplot
14 |         stream_vis.autopct, stream_vis.shadow = autopct, True if shadow is None else shadow
15 |         stream_vis.ax = self.get_main_axis()
16 |         stream_vis.series = []
17 |         stream_vis.wedge_artists = [] # stores previously drawn wedges
18 | 
19 |         stream_vis.cmap = image_utils.get_cmap(name=colormap or 'Dark2')
20 |         if color is None:
21 |             if not self.is_3d:
22 |                 color = stream_vis.cmap((len(self._stream_vises)%stream_vis.cmap.N)/stream_vis.cmap.N) # pylint: disable=no-member
23 |         stream_vis.color = color
24 | 
25 |     def clear_artists(self, stream_vis):
26 |         for artist in stream_vis.wedge_artists:
27 |             artist.remove()
28 |         stream_vis.wedge_artists.clear()
29 | 
30 |     def clear_plot(self, stream_vis, clear_history):
31 |         stream_vis.series.clear()
32 |         self.clear_artists(stream_vis)
33 | 
34 |     def _show_stream_items(self, stream_vis, stream_items):
35 |         """Paint the given stream_items into the visualizer. Return False if the visualizer is now dirty, True otherwise.
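        Each value is treated as a (label, size, color, explode) tuple; missing
        elements fall back to the colormap color and zero explode, and the whole
        pie is redrawn from the accumulated series on every update.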
36 |         """
37 | 
38 |         vals = self._extract_vals(stream_items)
39 |         if not len(vals):
40 |             return True
41 | 
42 |         # make sure each tuple has 4 elements
43 |         unpacker = lambda a0=None,a1=None,a2=None,a3=None,*_:(a0,a1,a2,a3)
44 |         stream_vis.series.extend([unpacker(*val) for val in vals])
45 |         self.clear_artists(stream_vis)
46 | 
47 |         labels, sizes, colors, explode = \
48 |             [t[0] for t in stream_vis.series], \
49 |             [t[1] for t in stream_vis.series], \
50 |             [(t[2] or stream_vis.cmap.colors[i % len(stream_vis.cmap.colors)]) \
51 |                 for i, t in enumerate(stream_vis.series)], \
52 |             [t[3] or 0 for t in stream_vis.series]
53 | 
54 | 
55 |         stream_vis.wedge_artists, *_ = stream_vis.ax.pie( \
56 |             sizes, explode=explode, labels=labels, colors=colors,
57 |             autopct=stream_vis.autopct, shadow=stream_vis.shadow)
58 | 
59 |         return False
60 | 
61 | 
--------------------------------------------------------------------------------
/tensorwatch/zmq_mgmt_stream.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | from typing import Any, Dict
5 | from .zmq_stream import ZmqStream
6 | from .lv_types import PublisherTopics, ServerMgmtMsg, StreamCreateRequest
7 | from .zmq_wrapper import ZmqWrapper
8 | from .lv_types import CliSrvReqTypes, ClientServerRequest
9 | from . import utils
10 | 
11 | class ZmqMgmtStream(ZmqStream):
12 |     # default topic is mgmt
13 |     def __init__(self, clisrv:ZmqWrapper.ClientServer, for_write:bool, port:int=0, topic=PublisherTopics.ServerMgmt, block_until_connected=True,
14 |                  stream_name:str=None, console_debug:bool=False):
15 |         super(ZmqMgmtStream, self).__init__(for_write=for_write, port=port, topic=topic,
16 |             block_until_connected=block_until_connected, stream_name=stream_name, console_debug=console_debug)
17 | 
18 |         self._clisrv = clisrv
19 |         self._stream_reqs:Dict[str,StreamCreateRequest] = {}
20 | 
21 |     def write(self, val:Any, from_stream:'Stream'=None):
22 |         r"""Handles server management events.
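        If the server reports a restart (ServerMgmtMsg.EventServerStart), all
        stream-create requests recorded so far are sent to it again so that
        client subscriptions survive server restarts.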
23 |         """
24 |         stream_item = self.to_stream_item(val)
25 |         mgmt_msg = stream_item.value
26 | 
27 |         utils.debug_log("Received - ServerMgmtEvent", mgmt_msg)
28 |         # if server was restarted then send create stream requests again
29 |         if mgmt_msg.event_name == ServerMgmtMsg.EventServerStart:
30 |             for stream_req in self._stream_reqs.values():
31 |                 self._send_create_stream(stream_req)
32 | 
33 |         super(ZmqMgmtStream, self).write(stream_item)
34 | 
35 |     def add_stream_req(self, stream_req:StreamCreateRequest)->None:
36 |         self._send_create_stream(stream_req)
37 | 
38 |         # save this for a later resend in case the server restarts
39 |         self._stream_reqs[stream_req.stream_name] = stream_req
40 | 
41 |     # override to send request to server
42 |     def del_stream(self, name:str) -> None:
43 |         clisrv_req = ClientServerRequest(CliSrvReqTypes.del_stream, name)
44 |         self._clisrv.send_obj(clisrv_req)
45 |         self._stream_reqs.pop(name, None)
46 | 
47 |     def _send_create_stream(self, stream_req):
48 |         utils.debug_log("sending create stream req...")
49 |         clisrv_req = ClientServerRequest(CliSrvReqTypes.create_stream, stream_req)
50 |         self._clisrv.send_obj(clisrv_req)
51 |         utils.debug_log("sent create stream req")
52 | 
53 |     def close(self):
54 |         if not self.closed:
55 |             self._stream_reqs = {}
56 |             self._clisrv = None
57 |             utils.debug_log('ZmqMgmtStream is closed', verbosity=1)
58 |             super(ZmqMgmtStream, self).close()
59 | 
--------------------------------------------------------------------------------
/tensorwatch/mpl/histogram.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | from .base_mpl_plot import BaseMplPlot
5 | from .. import image_utils
6 | from .. import utils
7 | import numpy as np
8 | 
9 | class Histogram(BaseMplPlot):
10 |     def init_stream_plot(self, stream_vis,
11 |                          xtitle='', ytitle='', ztitle='', colormap=None, color=None,
12 |                          bins=None, normed=None, histtype='bar', edge_color=None, linewidth=None,
13 |                          opacity=None, **stream_vis_args):
14 | 
15 |         # add main subplot
16 |         stream_vis.bins, stream_vis.normed, stream_vis.linewidth = bins, normed, (linewidth or 2)
17 |         stream_vis.ax = self.get_main_axis()
18 |         stream_vis.series = []
19 |         stream_vis.bars_artists = [] # stores previously drawn bars
20 | 
21 |         stream_vis.cmap = image_utils.get_cmap(name=colormap or 'Dark2')
22 |         if color is None:
23 |             if not self.is_3d:
24 |                 color = stream_vis.cmap((len(self._stream_vises)%stream_vis.cmap.N)/stream_vis.cmap.N) # pylint: disable=no-member
25 |         stream_vis.color = color
26 |         stream_vis.edge_color = edge_color or 'black'
27 |         stream_vis.histtype = histtype
28 |         stream_vis.opacity = opacity
29 |         stream_vis.ax.set_xlabel(xtitle)
30 |         stream_vis.ax.xaxis.label.set_style('italic')
31 |         stream_vis.ax.set_ylabel(ytitle)
32 |         stream_vis.ax.yaxis.label.set_color(color)
33 |         stream_vis.ax.yaxis.label.set_style('italic')
34 |         if self.is_3d:
35 |             stream_vis.ax.set_zlabel(ztitle)
36 |             stream_vis.ax.zaxis.label.set_style('italic')
37 | 
38 |     def is_show_grid(self): #override
39 |         return False
40 | 
41 |     def clear_artists(self, stream_vis):
42 |         for bar in stream_vis.bars_artists:
43 |             bar.remove()
44 |         stream_vis.bars_artists.clear()
45 | 
46 |     def clear_plot(self, stream_vis, clear_history):
47 |         stream_vis.series.clear()
48 |         self.clear_artists(stream_vis)
49 | 
50 |     def _show_stream_items(self, stream_vis, stream_items):
51 |         """Paint the given stream_items into the visualizer. Return False if the visualizer is now dirty, True otherwise.
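        New values are appended to the accumulated series; the histogram is then
        re-binned and redrawn from scratch on every update.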
52 |         """
53 | 
54 |         vals = self._extract_vals(stream_items)
55 |         if not len(vals):
56 |             return True
57 | 
58 |         stream_vis.series += vals
59 |         self.clear_artists(stream_vis)
60 |         n, bins, stream_vis.bars_artists = stream_vis.ax.hist(stream_vis.series, bins=stream_vis.bins,
61 |             density=stream_vis.normed, color=stream_vis.color, edgecolor=stream_vis.edge_color, # 'density' is the matplotlib 3.1+ replacement for the removed 'normed' kwarg
62 |             histtype=stream_vis.histtype, alpha=stream_vis.opacity,
63 |             linewidth=stream_vis.linewidth)
64 | 
65 |         stream_vis.ax.set_xticks(bins)
66 | 
67 |         #stream_vis.ax.relim()
68 |         #stream_vis.ax.autoscale_view()
69 | 
70 |         return False
71 | 
--------------------------------------------------------------------------------
/tensorwatch/imagenet_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | from torchvision import transforms
5 | from . import pytorch_utils
6 | import json, os
7 | 
8 | def get_image_transform():
9 |     transf = transforms.Compose([ #TODO: cache these transforms?
10 |         get_resize_transform(),
11 |         transforms.ToTensor(),
12 |         get_normalize_transform()
13 |     ])
14 | 
15 |     return transf
16 | 
17 | def get_resize_transform():
18 |     return transforms.Resize((224, 224))
19 | 
20 | def get_normalize_transform():
21 |     return transforms.Normalize(mean=[0.485, 0.456, 0.406],
22 |                                 std=[0.229, 0.224, 0.225])
23 | 
24 | def image2batch(image):
25 |     return pytorch_utils.image2batch(image, image_transform=get_image_transform())
26 | 
27 | def predict(model, images, image_transform=None, device=None):
28 |     logits = pytorch_utils.batch_predict(model, images,
29 |         input_transform=image_transform or get_image_transform(), device=device)
30 |     probs = pytorch_utils.logits2probabilities(logits) # 2-dim array, one column per class, one row per input
31 |     return probs
32 | 
33 | _imagenet_labels = None
34 | def get_imagenet_labels():
35 |     # pylint: disable=global-statement
36 |     global _imagenet_labels
37 |     _imagenet_labels = _imagenet_labels or ImagenetLabels()
38 |     return _imagenet_labels
39 | 
40 | def probabilities2classes(probs, topk=5):
41 |     labels = get_imagenet_labels()
42 |     top_probs = probs.topk(topk)
43 |     # return (probability, class_id, class_label, class_code)
44 |     return tuple((p, c, labels.index2label_text(c), labels.index2label_code(c)) \
45 |         for p, c in zip(top_probs[0][0].data.cpu().numpy(), top_probs[1][0].data.cpu().numpy()))
46 | 
47 | class ImagenetLabels:
48 |     def __init__(self, json_path=None):
49 |         self._idx2label = []
50 |         self._idx2cls = []
51 |         self._cls2label = {}
52 |         self._cls2idx = {}
53 | 
54 |         json_path = json_path or os.path.join(os.path.dirname(__file__), 'imagenet_class_index.json')
55 | 
56 |         with open(os.path.abspath(json_path), "r") as read_file:
57 |             class_json = json.load(read_file)
58 |             self._idx2label = [class_json[str(k)][1] for k in range(len(class_json))]
59 |             self._idx2cls = [class_json[str(k)][0] for k in range(len(class_json))]
60 |             self._cls2label = {class_json[str(k)][0]: class_json[str(k)][1] for k in range(len(class_json))}
61 |             self._cls2idx = {class_json[str(k)][0]: k for k in range(len(class_json))}
62 | 
63 |     def index2label_text(self, index):
64 |         return self._idx2label[index]
65 |     def index2label_code(self, index):
66 |         return self._idx2cls[index]
67 |     def label_code2label_text(self, label_code):
68 |         return self._cls2label[label_code]
69 |     def label_code2index(self, label_code):
70 |         return self._cls2idx[label_code]
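71 | 
72 | # Illustrative usage sketch (added for documentation; the torchvision model
73 | # choice and the image path are assumptions, not part of this module):
74 | if __name__ == '__main__':
75 |     from torchvision import models
76 |     from PIL import Image
77 |     model = models.resnet18(pretrained=True)
78 |     image = Image.open('data/test_images/dogs.png').convert('RGB')
79 |     probs = predict(model, [image]) # one row of class probabilities per image
80 |     for p, class_id, label, code in probabilities2classes(probs, topk=5):
81 |         print('{:.4f}  {}  {}  {}'.format(p, class_id, label, code))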
--------------------------------------------------------------------------------
/tensorwatch/saliency/deeplift.py:
--------------------------------------------------------------------------------
1 | from .backprop import GradxInputExplainer
2 | import types
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | 
6 | # Based on formulation in DeepExplain, https://arxiv.org/abs/1711.06104
7 | # https://github.com/marcoancona/DeepExplain/blob/master/deepexplain/tensorflow/methods.py#L221-L272
8 | class DeepLIFTRescaleExplainer(GradxInputExplainer):
9 |     def __init__(self, model):
10 |         super(DeepLIFTRescaleExplainer, self).__init__(model)
11 |         self._prepare_reference()
12 |         self.baseline_inp = None
13 |         self._override_backward()
14 | 
15 |     def _prepare_reference(self):
16 |         def init_refs(m):
17 |             name = m.__class__.__name__
18 |             if name.find('ReLU') != -1:
19 |                 m.ref_inp_list = []
20 |                 m.ref_out_list = []
21 | 
22 |         def ref_forward(self, x):
23 |             self.ref_inp_list.append(x.data.clone())
24 |             out = F.relu(x)
25 |             self.ref_out_list.append(out.data.clone())
26 |             return out
27 | 
28 |         def ref_replace(m):
29 |             name = m.__class__.__name__
30 |             if name.find('ReLU') != -1:
31 |                 m.forward = types.MethodType(ref_forward, m)
32 | 
33 |         self.model.apply(init_refs)
34 |         self.model.apply(ref_replace)
35 | 
36 |     def _reset_reference(self):
37 |         def reset_refs(m):
38 |             name = m.__class__.__name__
39 |             if name.find('ReLU') != -1:
40 |                 m.ref_inp_list = []
41 |                 m.ref_out_list = []
42 | 
43 |         self.model.apply(reset_refs)
44 | 
45 |     def _baseline_forward(self, inp):
46 |         if self.baseline_inp is None:
47 |             self.baseline_inp = inp.data.clone()
48 |             self.baseline_inp.fill_(0.0)
49 |             self.baseline_inp = Variable(self.baseline_inp)
50 |         else:
51 |             self.baseline_inp.fill_(0.0)
52 |         # get ref
53 |         _ = self.model(self.baseline_inp)
54 | 
55 |     def _override_backward(self):
56 |         def new_backward(self, grad_out):
57 |             ref_inp, inp = self.ref_inp_list
58 |             ref_out, out = self.ref_out_list
59 |             delta_out = out - ref_out
60 |             delta_in = inp - ref_inp
61 |             # rescale rule: where |delta_in| is non-negligible, propagate grad_out * delta_out/delta_in
62 |             g1 = (delta_in.abs() > 1e-5).float() * grad_out * \
63 |                  delta_out / delta_in
64 |             mask = ((ref_inp + inp) > 0).float()
65 |             # near-zero delta_in: fall back to the linear approximation on active units
66 |             g2 = (delta_in.abs() <= 1e-5).float() * 0.5 * mask * grad_out
67 | 
68 |             return g1 + g2
69 | 
70 |         def backward_replace(m):
71 |             name = m.__class__.__name__
72 |             if name.find('ReLU') != -1:
73 |                 m.backward = types.MethodType(new_backward, m)
74 | 
75 |         self.model.apply(backward_replace)
76 | 
77 |     def explain(self, inp, ind=None, raw_inp=None):
78 |         self._reset_reference()
79 |         self._baseline_forward(inp)
80 |         g = super(DeepLIFTRescaleExplainer, self).explain(inp, ind)
81 | 
82 |         return g
83 | 
--------------------------------------------------------------------------------
/tensorwatch/stream.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | import weakref, uuid
5 | from typing import Any
6 | from . import utils
7 | from .lv_types import StreamItem
8 | 
9 | class Stream:
10 |     def __init__(self, stream_name:str=None, console_debug:bool=False):
11 |         self._subscribers = weakref.WeakSet()
12 |         self._subscribed_to = weakref.WeakSet()
13 |         self.held_refs = set() # on some rare occasions we might want a stream to hold references to other streams
14 |         self.closed = False
15 |         self.console_debug = console_debug
16 |         self.stream_name = stream_name or str(uuid.uuid4()) # useful to use as key and avoid circular references
17 |         self.items_written = 0
18 | 
19 |     def subscribe(self, stream:'Stream'): # notify other stream
20 |         utils.debug_log('{} added {} as subscription'.format(self.stream_name, stream.stream_name))
21 |         stream._subscribers.add(self)
22 |         self._subscribed_to.add(stream)
23 | 
24 |     def unsubscribe(self, stream:'Stream'):
25 |         utils.debug_log('{} removed {} as subscription'.format(self.stream_name, stream.stream_name))
26 |         stream._subscribers.discard(self)
27 |         self._subscribed_to.discard(stream)
28 |         self.held_refs.discard(stream)
29 |         #stream.held_refs.discard(self) # not needed as only subscriber should hold ref
30 | 
31 |     def to_stream_item(self, val:Any):
32 |         stream_item = val if isinstance(val, StreamItem) else \
33 |             StreamItem(value=val, stream_name=self.stream_name)
34 |         if stream_item.stream_name is None:
35 |             stream_item.stream_name = self.stream_name
36 |         if stream_item.item_index is None:
37 |             stream_item.item_index = self.items_written
38 |         return stream_item
39 | 
40 |     def write(self, val:Any, from_stream:'Stream'=None):
41 |         # if you override the write method, you must first call self.to_stream_item
42 |         # so it can stamp the stream name and item index
43 |         stream_item = self.to_stream_item(val)
44 | 
45 |         if self.console_debug:
46 |             print(self.stream_name, stream_item)
47 | 
48 |         for subscriber in self._subscribers:
49 |             subscriber.write(stream_item, from_stream=self)
50 |         self.items_written += 1
51 | 
52 |     def read_all(self, from_stream:'Stream'=None):
53 |         for subscribed_to in self._subscribed_to:
54 |             for stream_item in subscribed_to.read_all(from_stream=self):
55 |                 yield stream_item
56 | 
57 |     def load(self, from_stream:'Stream'=None):
58 |         for subscribed_to in self._subscribed_to:
59 |             subscribed_to.load(from_stream=self)
60 | 
61 |     def save(self, from_stream:'Stream'=None):
62 |         for subscriber in self._subscribers:
63 |             subscriber.save(from_stream=self)
64 | 
65 |     def close(self):
66 |         if not self.closed:
67 |             for subscribed_to in self._subscribed_to:
68 |                 subscribed_to._subscribers.discard(self)
69 |             self._subscribed_to.clear()
70 |             self.closed = True
71 | 
72 |     def __enter__(self):
73 |         return self
74 | 
75 |     def __exit__(self, exception_type, exception_value, traceback):
76 |         self.close()
77 | 
78 | 
--------------------------------------------------------------------------------
/tensorwatch/tensor_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Sized, Any, List
2 | import numbers
3 | import numpy as np
4 | 
5 | class TensorType:
6 |     Torch = 'torch'
7 |     TF = 'tf'
8 |     Numeric = 'numeric'
9 |     Numpy = 'numpy'
10 |     Other = 'other'
11 | 
12 | def tensor_type(item:Any)->str:
13 |     module_name = type(item).__module__
14 |     class_name = type(item).__name__
15 | 
16 |     if module_name=='torch' and class_name=='Tensor':
17 |         return TensorType.Torch
18 |     elif module_name.startswith('tensorflow') and class_name=='EagerTensor':
19 |         return TensorType.TF
20 |     elif isinstance(item, numbers.Number):
21 |         return TensorType.Numeric
22 |     elif isinstance(item, np.ndarray):
23 |         return TensorType.Numpy
24 |     else:
25 |         return TensorType.Other
26 | 
27 | def tensor2scaler(item:Any)->numbers.Number:
28 |     tt = tensor_type(item)
29 |     if item is None or tt == TensorType.Numeric:
30 |         return item
31 |     if tt == TensorType.TF:
32 |         item = item.numpy()
33 |     return item.item() # works for torch and numpy
34 | 
35 | def tensor2np(item:Any)->np.ndarray:
36 |     tt = tensor_type(item)
37 |     if item is None or tt == TensorType.Numpy:
38 |         return item
39 |     elif tt == TensorType.TF:
40 |         return item.numpy()
41 |     elif tt == TensorType.Torch:
42 |         return item.data.cpu().numpy()
43 |     else: # numeric and everything else, let np take care
44 |         return np.array(item)
45 | 
46 | def to_scaler_list(l:Sized)->List[numbers.Number]:
47 |     """Create a list of scalars from a list of tensors where each element is a 0-dim tensor
48 |     """
49 |     if l is not None and len(l):
50 |         tt = tensor_type(l[0])
51 |         if tt == TensorType.Torch or tt == TensorType.Numpy:
52 |             return [i.item() for i in l]
53 |         elif tt == TensorType.TF:
54 |             return [i.numpy().item() for i in l]
55 |         elif tt == TensorType.Numeric:
56 |             # convert to list in case l is not list type
57 |             return [i for i in l]
58 |         else:
59 |             raise ValueError('Cannot convert tensor list to scalar list \
60 |                 because list elements are of unsupported type ' + tt)
61 |     else:
62 |         return None if l is None else [] # make sure we always return list type
63 | 
64 | def to_mean_list(l:Sized)->List[float]:
65 |     """Create a list of per-element means from a list of tensors
66 |     """
67 |     if l is not None and len(l):
68 |         tt = tensor_type(l[0])
69 |         if tt == TensorType.Torch or tt == TensorType.Numpy:
70 |             return [i.mean() for i in l]
71 |         elif tt == TensorType.TF:
72 |             return [i.numpy().mean() for i in l]
73 |         elif tt == TensorType.Numeric:
74 |             # convert to list in case l is not list type
75 |             return [float(i) for i in l]
76 |         else:
77 |             raise ValueError('Cannot convert tensor list to mean list \
78 |                 because list elements are of unsupported type ' + tt)
79 |     else:
80 |         return None if l is None else []
81 | 
82 | def to_np_list(l:Sized)->List[np.ndarray]:
83 |     if l is not None and len(l):
84 |         tt = tensor_type(l[0])
85 |         if tt == TensorType.Numeric:
86 |             return [np.array(i) for i in l]
87 |         if tt == TensorType.TF:
88 |             return [i.numpy() for i in l]
89 |         if tt == TensorType.Torch:
90 |             return [i.data.cpu().numpy() for i in l]
91 |         if tt == TensorType.Numpy:
92 |             return [i for i in l]
93 |         raise ValueError('Cannot convert tensor list to numpy array list \
94 |             because list elements are of unsupported type ' + tt)
95 |     else:
96 |         return None if l is None else []
--------------------------------------------------------------------------------
/tensorwatch/plotly/embeddings_plot.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | from .. import image_utils
5 | import numpy as np
6 | from .line_plot import LinePlot
7 | import time
8 | from .. import utils
9 | 
10 | class EmbeddingsPlot(LinePlot):
11 |     def __init__(self, cell:LinePlot.widgets.Box=None, title=None, show_legend:bool=False, stream_name:str=None, console_debug:bool=False,
12 |                  is_3d:bool=True, hover_images=None, hover_image_reshape=None, **vis_args):
13 |         utils.set_default(vis_args, 'height', '8in')
14 |         super(EmbeddingsPlot, self).__init__(cell, title, show_legend,
15 |             stream_name=stream_name, console_debug=console_debug, is_3d=is_3d, **vis_args)
16 | 
17 |         import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue
18 |         if hover_images is not None:
19 |             plt.ioff()
20 |             self.image_output = LinePlot.widgets.Output()
21 |             self.image_figure = plt.figure(figsize=(2,2))
22 |             self.image_ax = self.image_figure.add_subplot(111)
23 |             self.cell.children += (self.image_output,)
24 |             plt.ion()
25 |         self.hover_images, self.hover_image_reshape = hover_images, hover_image_reshape
26 |         self.last_ind, self.last_ind_time = -1, 0
27 | 
28 |     def hover_fn(self, trace, points, state): # pylint: disable=unused-argument
29 |         if not points:
30 |             return
31 |         ind = points.point_inds[0]
32 |         if ind == self.last_ind or ind >= len(self.hover_images) or ind < 0:
33 |             return
34 | 
35 |         if self.last_ind == -1:
36 |             self.last_ind, self.last_ind_time = ind, time.time()
37 |         else:
38 |             elapsed = time.time() - self.last_ind_time
39 |             if elapsed < 0.3:
40 |                 self.last_ind, self.last_ind_time = ind, time.time()
41 |                 if elapsed < 1:
42 |                     return
43 |                 # else too much time since update
44 |             # else we have a stable ind
45 | 
46 |         import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue
47 |         with self.image_output:
48 |             plt.ioff()
49 | 
50 |             if self.hover_image_reshape:
51 |                 img = np.reshape(self.hover_images[ind], self.hover_image_reshape)
52 |             else:
53 |                 img = self.hover_images[ind]
54 |             if img is not None:
55 |                 LinePlot.display.clear_output(wait=True)
56 |                 self.image_ax.imshow(img)
57 |                 LinePlot.display.display(self.image_figure)
58 |             plt.ion()
59 | 
60 |         return None
61 | 
62 |     def _create_trace(self, stream_vis):
63 |         stream_vis.stream_vis_args.clear() #TODO remove this
64 |         utils.set_default(stream_vis.stream_vis_args, 'draw_line', False)
65 |         utils.set_default(stream_vis.stream_vis_args, 'draw_marker', True)
66 |         utils.set_default(stream_vis.stream_vis_args, 'draw_marker_text', True)
67 |         utils.set_default(stream_vis.stream_vis_args, 'hoverinfo', 'text')
68 |         utils.set_default(stream_vis.stream_vis_args, 'marker', {})
69 | 
70 |         marker = stream_vis.stream_vis_args['marker']
71 |         utils.set_default(marker, 'size', 6)
72 |         utils.set_default(marker, 'colorscale', 'Jet')
73 |         utils.set_default(marker, 'showscale', False)
74 |         utils.set_default(marker, 'opacity', 0.8)
75 | 
76 |         return super(EmbeddingsPlot, self)._create_trace(stream_vis)
77 | 
78 |     def subscribe(self, stream, **stream_vis_args):
79 |         super(EmbeddingsPlot, self).subscribe(stream)
80 |         stream_vis = self._stream_vises[stream.stream_name]
81 |         if stream_vis.index == 0 and self.hover_images is not None:
82 |             self.widget.data[stream_vis.trace_index].on_hover(self.hover_fn)
--------------------------------------------------------------------------------
/tensorwatch/text_vis.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | from . import utils
5 | from .vis_base import VisBase
6 | 
7 | class TextVis(VisBase):
8 |     def __init__(self, cell:VisBase.widgets.Box=None, title:str=None, show_legend:bool=None,
9 |                  stream_name:str=None, console_debug:bool=False, **vis_args):
10 |         import pandas as pd # expensive import
11 | 
12 |         super(TextVis, self).__init__(VisBase.widgets.HTML(), cell, title, show_legend,
13 |             stream_name=stream_name, console_debug=console_debug, **vis_args)
14 |         self.df = pd.DataFrame([])
15 |         self.SeriesClass = pd.Series
16 | 
17 |     def _get_column_prefix(self, stream_vis, i):
18 |         return '[S.{}]:{}'.format(stream_vis.index, i)
19 | 
20 |     def _get_title(self, stream_vis):
21 |         title = stream_vis.title or 'Stream ' + str(len(self._stream_vises))
22 |         return title
23 | 
24 |     # this will be called from _show_stream_items
25 |     def _append(self, stream_vis, vals):
26 |         if vals is None:
27 |             self.df = self.df.append(self.SeriesClass({self._get_column_prefix(stream_vis, 0) : None}),
28 |                 sort=False, ignore_index=True)
29 |             return
30 |         for val in vals:
31 |             if val is None or utils.is_scalar(val):
32 |                 self.df = self.df.append(self.SeriesClass({self._get_column_prefix(stream_vis, 0) : val}),
33 |                     sort=False, ignore_index=True)
34 |             elif utils.is_array_like(val):
35 |                 val_dict = {}
36 |                 for i,val_i in enumerate(val):
37 |                     val_dict[self._get_column_prefix(stream_vis, i)] = val_i
38 |                 self.df = self.df.append(self.SeriesClass(val_dict), sort=False, ignore_index=True)
39 |             else:
40 |                 self.df = self.df.append(self.SeriesClass(val.__dict__), sort=False, ignore_index=True)
41 | 
42 |     def _post_add_subscription(self, stream_vis, **stream_vis_args):
43 |         only_summary = stream_vis_args.get('only_summary', False)
44 |         stream_vis.text = self._get_title(stream_vis)
45 |         stream_vis.only_summary = only_summary
46 | 
47 |     def clear_plot(self, stream_vis, clear_history):
48 |         self.df = self.df.iloc[0:0]
49 | 
50 |     def _show_stream_items(self, stream_vis, stream_items):
51 |         """Paint the given stream_items into the visualizer. Return False if the visualizer is now dirty, True otherwise.
52 |         """
53 |         for stream_item in stream_items:
54 |             if stream_item.ended:
55 |                 self.df = self.df.append(self.SeriesClass({'Ended':True}),
56 |                     sort=False, ignore_index=True)
57 |             else:
58 |                 vals = self._extract_vals((stream_item,))
59 |                 self._append(stream_vis, vals)
60 |         return False # dirty
61 | 
62 |     def _post_update_stream_plot(self, stream_vis):
63 |         if VisBase.get_ipython():
64 |             if not stream_vis.only_summary:
65 |                 self.widget.value = self.df.to_html(classes=['output_html', 'rendered_html'])
66 |             else:
67 |                 self.widget.value = self.df.describe().to_html(classes=['output_html', 'rendered_html'])
68 |             # below doesn't work because of a threading issue
69 |             #self.widget.clear_output(wait=True)
70 |             #with self.widget:
71 |             #    display.display(self.df)
72 |         else:
73 |             last_recs = self.df.iloc[[-1]].to_dict('records')
74 |             if len(last_recs) == 1:
75 |                 print(last_recs[0])
76 |             else:
77 |                 print(last_recs)
78 | 
79 |     def _show_widget_native(self, blocking:bool):
80 |         return None # we will be using console
81 | 
82 |     def _show_widget_notebook(self):
83 |         return self.widget
--------------------------------------------------------------------------------
/tensorwatch/model_graph/torchstat_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
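# Illustrative usage (added comment; the torchvision model is an assumption,
# not part of this module):
#
#     import torchvision.models as models
#     df = model_stats(models.resnet18(), [1, 3, 224, 224])
#     print(df)  # per-module params, memory, MAdd, Flops, plus a 'total' row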
3 | 
4 | import torchstat
5 | import pandas as pd
6 | 
7 | def model_stats(model, input_shape):
8 |     if len(input_shape) > 3:
9 |         input_shape = input_shape[1:4]
10 |     ms = torchstat.statistics.ModelStat(model, input_shape, 1)
11 |     collected_nodes = ms._analyze_model()
12 |     return _report_format(collected_nodes)
13 | 
14 | def _round_value(value, binary=False):
15 |     divisor = 1024. if binary else 1000.
16 | 
17 |     if value // divisor**4 > 0:
18 |         return str(round(value / divisor**4, 2)) + 'T'
19 |     elif value // divisor**3 > 0:
20 |         return str(round(value / divisor**3, 2)) + 'G'
21 |     elif value // divisor**2 > 0:
22 |         return str(round(value / divisor**2, 2)) + 'M'
23 |     elif value // divisor > 0:
24 |         return str(round(value / divisor, 2)) + 'K'
25 |     return str(value)
26 | 
27 | 
28 | def _report_format(collected_nodes):
29 |     pd.set_option('display.width', 1000)
30 |     pd.set_option('display.max_rows', 10000)
31 |     pd.set_option('display.max_columns', 10000)
32 | 
33 |     data = list()
34 |     for node in collected_nodes:
35 |         name = node.name
36 |         input_shape = ' '.join(['{:>3d}'] * len(node.input_shape)).format(
37 |             *[e for e in node.input_shape])
38 |         output_shape = ' '.join(['{:>3d}'] * len(node.output_shape)).format(
39 |             *[e for e in node.output_shape])
40 |         parameter_quantity = node.parameter_quantity
41 |         inference_memory = node.inference_memory
42 |         MAdd = node.MAdd
43 |         Flops = node.Flops
44 |         mread, mwrite = [i for i in node.Memory]
45 |         duration = node.duration
46 |         data.append([name, input_shape, output_shape, parameter_quantity,
47 |                      inference_memory, MAdd, duration, Flops, mread,
48 |                      mwrite])
49 |     df = pd.DataFrame(data)
50 |     df.columns = ['module name', 'input shape', 'output shape',
51 |                   'params', 'memory(MB)',
52 |                   'MAdd', 'duration', 'Flops', 'MemRead(B)', 'MemWrite(B)']
53 |     df['duration[%]'] = df['duration'] / (df['duration'].sum() + 1e-7)
54 |     df['MemR+W(B)'] = df['MemRead(B)'] + df['MemWrite(B)']
55 |     total_parameters_quantity = df['params'].sum()
56 |     total_memory = df['memory(MB)'].sum()
57 |     total_operation_quantity = df['MAdd'].sum()
58 |     total_flops = df['Flops'].sum()
59 |     total_duration = df['duration[%]'].sum()
60 |     total_mread = df['MemRead(B)'].sum()
61 |     total_mwrite = df['MemWrite(B)'].sum()
62 |     total_memrw = df['MemR+W(B)'].sum()
63 |     del df['duration']
64 | 
65 |     # Add Total row (use the summed totals, not the last node's values)
66 |     total_df = pd.Series([total_parameters_quantity, total_memory,
67 |                           total_operation_quantity, total_flops,
68 |                           total_duration, total_mread, total_mwrite, total_memrw],
69 |                          index=['params', 'memory(MB)', 'MAdd', 'Flops', 'duration[%]',
70 |                                 'MemRead(B)', 'MemWrite(B)', 'MemR+W(B)'],
71 |                          name='total')
72 |     df = df.append(total_df)
73 | 
74 |     df = df.fillna(' ')
75 |     df['memory(MB)'] = df['memory(MB)'].apply(
76 |         lambda x: '{:.2f}'.format(x))
77 |     df['duration[%]'] = df['duration[%]'].apply(lambda x: '{:.2%}'.format(x))
78 |     df['MAdd'] = df['MAdd'].apply(lambda x: '{:,}'.format(x))
79 |     df['Flops'] = df['Flops'].apply(lambda x: '{:,}'.format(x))
80 | 
81 |     #summary = "Total params: {:,}\n".format(total_parameters_quantity)
82 | 
83 |     #summary += "-" * len(str(df).split('\n')[0])
84 |     #summary += '\n'
85 |     #summary += "Total memory: {:.2f}MB\n".format(total_memory)
86 |     #summary += "Total MAdd: {}MAdd\n".format(_round_value(total_operation_quantity))
87 |     #summary += "Total Flops: {}Flops\n".format(_round_value(total_flops))
88 |     #summary += "Total MemR+W: {}B\n".format(_round_value(total_memrw, True))
89 |     return df
--------------------------------------------------------------------------------
/tensorwatch/watcher_client.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | from typing import Any, Dict, Sequence, List
5 | from .zmq_wrapper import ZmqWrapper
6 | from .lv_types import CliSrvReqTypes, ClientServerRequest, DefaultPorts
7 | from .lv_types import VisArgs, PublisherTopics, ServerMgmtMsg, StreamCreateRequest
8 | from .stream import Stream
9 | from .zmq_mgmt_stream import ZmqMgmtStream
10 | from . import utils
11 | from .watcher_base import WatcherBase
12 | 
13 | class WatcherClient(WatcherBase):
14 |     r"""Extends watcher to add methods so calls for create and delete stream can be sent to server.
15 |     """
16 |     def __init__(self, filename:str=None, port:int=0):
17 |         super(WatcherClient, self).__init__()
18 |         self.port = port
19 |         self.filename = filename
20 | 
21 |         # define vars in __init__
22 |         self._clisrv = None # client-server sockets allow us to send create/del stream requests
23 |         self._zmq_srvmgmt_sub = None
24 |         self._file = None
25 | 
26 |         self._open()
27 | 
28 |     def _reset(self):
29 |         self._clisrv = None
30 |         self._zmq_srvmgmt_sub = None
31 |         self._file = None
32 |         utils.debug_log("WatcherClient reset", verbosity=1)
33 |         super(WatcherClient, self)._reset()
34 | 
35 |     def _open(self):
36 |         if self.port is not None:
37 |             self._clisrv = ZmqWrapper.ClientServer(port=DefaultPorts.CliSrv+self.port,
38 |                 is_server=False)
39 |             # create subscription where we will receive server management events
40 |             self._zmq_srvmgmt_sub = ZmqMgmtStream(clisrv=self._clisrv, for_write=False, port=self.port,
41 |                 stream_name='zmq_srvmgmt_sub:'+str(self.port)+':False')
42 | 
43 |     def close(self):
44 |         if not self.closed:
45 |             self._zmq_srvmgmt_sub.close()
46 |             self._clisrv.close()
47 |             utils.debug_log("WatcherClient is closed", verbosity=1)
48 |         super(WatcherClient, self).close()
49 | 
50 |     def devices_or_default(self, devices:Sequence[str])->Sequence[str]: # overridden
51 |         # TODO: this method is duplicated in Watcher and WatcherClient
52 | 
53 |         # make sure TCP port is attached to tcp device
54 |         if devices is not None:
55 |             return ['tcp:' + str(self.port) if device=='tcp' else device for device in devices]
56 | 
57 |         # if no devices are specified then use our filename and tcp:port as the default devices
58 |         devices = []
59 |         # first open the file device because it may have older data
60 |         if self.filename is not None:
61 |             devices.append('file:' + self.filename)
62 |         if self.port is not None:
63 |             devices.append('tcp:' + str(self.port))
64 |         return devices
65 | 
66 |     # override to send the request to the server instead of the underlying WatcherBase implementation
67 |     def create_stream(self, name:str=None, devices:Sequence[str]=None, event_name:str='',
68 |                       expr=None, throttle:float=1, vis_args:VisArgs=None)->Stream: # overridden
69 | 
70 |         stream_req = StreamCreateRequest(stream_name=name, devices=self.devices_or_default(devices),
71 |             event_name=event_name, expr=expr, throttle=throttle, vis_args=vis_args)
72 | 
73 |         self._zmq_srvmgmt_sub.add_stream_req(stream_req)
74 | 
75 |         if stream_req.devices is not None:
76 |             stream = self.open_stream(name=stream_req.stream_name, devices=stream_req.devices)
77 |         else: # we cannot return remote streams that are not backed by a device
78 |             stream = None
79 |         return stream
80 | 
81 |     # override to default devices to tcp
82 |     def open_stream(self, name:str=None, devices:Sequence[str]=None)->Stream: # overridden
83 |         return super(WatcherClient, self).open_stream(name=name, devices=devices)
84 | 
85 | 
86 |     # override to send request to server
87 |     def del_stream(self, name:str) -> None:
88 |         self._zmq_srvmgmt_sub.del_stream(name)
89 | 
90 | 
--------------------------------------------------------------------------------
/docs/simple_logging.md:
--------------------------------------------------------------------------------
1 | 
2 | # Simple Logging Tutorial
3 | 
4 | In this tutorial, we will use a simple script that maintains two variables. Our goal is to log the values of these variables using TensorWatch and view them in real-time, in a variety of ways, from a Jupyter Notebook, demonstrating various features offered by TensorWatch. We will execute this script from a console; its code looks like this:
5 | 
6 | ```
7 | # available in test/simple_log/sum_log.py
8 | 
9 | import time, random
10 | import tensorwatch as tw
11 | 
12 | # we will create two streams, one for
13 | # sums of integers, the other for random numbers
14 | w = tw.Watcher()
15 | st_isum = w.create_stream('isums')
16 | st_rsum = w.create_stream('rsums')
17 | 
18 | isum, rsum = 0, 0
19 | for i in range(10000):
20 |     isum += i
21 |     rsum += random.randint(0, i)
22 | 
23 |     # write to streams
24 |     st_isum.write((i, isum))
25 |     st_rsum.write((i, rsum))
26 | 
27 |     time.sleep(1)
28 | ```
29 | 
30 | ## Basic Concepts
31 | 
32 | The `Watcher` object allows you to create streams. You can then write any values into these streams. By default, the `Watcher` object opens up TCP/IP sockets so a client can request these streams. As values are written at one end, they get serialized and sent to any client(s) at the other end. We will use a Jupyter Notebook to create a client, open these streams and feed them to various visualizers.
33 | 
34 | ## Create the Client
35 | You can either create a new Jupyter Notebook or get the existing one in the repo at `notebooks/simple_logging.ipynb`.
36 | 
37 | Let's first do imports:
38 | 
39 | 
40 | ```python
41 | %matplotlib notebook
42 | import tensorwatch as tw
43 | ```
44 | 
45 | Next, open the streams that we had created in the above script:
46 | 
47 | ```python
48 | cli = tw.WatcherClient()
49 | st_isum = cli.open_stream('isums')
50 | st_rsum = cli.open_stream('rsums')
51 | ```
52 | 
53 | Now let's visualize the isum stream in textual format:
54 | 
55 | 
56 | ```python
57 | text_vis = tw.Visualizer(st_isum, vis_type='text')
58 | text_vis.show()
59 | ```
60 | 
61 | You should see a growing table in real-time like this; each row is a tuple we wrote to the stream:
62 | 
63 | 
64 | 
65 | 
66 | That worked out well! How about simultaneously feeding the same stream into a line graph?
67 | 
68 | 
69 | ```python
70 | line_plot = tw.Visualizer(st_isum, vis_type='line', xtitle='i', ytitle='isum')
71 | line_plot.show()
72 | ```
73 | 
74 | 
75 | 
76 | Let's take this a step further. Say we plot rsum as well but, just for fun, in the *same plot* as isum above. That's easy with the optional `host` parameter:
77 | 
78 | ```python
79 | line_plot = tw.Visualizer(st_rsum, vis_type='line', host=line_plot, ytitle='rsum')
80 | ```
81 | This instantly changes our line plot and a new Y-axis is automatically added for our second stream. TensorWatch allows you to combine multiple streams in the same visualizer.
82 | 
83 | Are you curious about the statistics of these two streams? Let's display the live statistics of both streams side by side!
To do this, we will use the `cell` parameter for the second visualization to place it right next to the first one:
84 | 
85 | ```python
86 | istats = tw.Visualizer(st_isum, vis_type='summary')
87 | istats.show()
88 | 
89 | rstats = tw.Visualizer(st_rsum, vis_type='summary', cell=istats)
90 | ```
91 | 
92 | The output looks like this (each panel showing statistics for the tuple that was logged in the stream):
93 | 
94 | 
95 | 
96 | ## Next Steps
97 | 
98 | We have just scratched the surface of what you can do with TensorWatch. You should check out the [Lazy Logging Tutorial](lazy_logging.md) and other [notebooks](https://github.com/microsoft/tensorwatch/tree/master/notebooks) as well.
99 | 
100 | ## Questions?
101 | 
102 | File a [Github issue](https://github.com/microsoft/tensorwatch/issues/new) and let us know if we can improve this tutorial.
--------------------------------------------------------------------------------
/test/mnist/cli_mnist.py:
--------------------------------------------------------------------------------
1 | import tensorwatch as tw
2 | import time
3 | import math
4 | from tensorwatch import utils
5 | 
6 | utils.set_debug_verbosity(4)
7 | 
8 | def img_in_class():
9 |     cli_train = tw.WatcherClient()
10 | 
11 |     imgs = cli_train.create_stream(event_name='batch',
12 |         expr="topk_all(l, batch_vals=lambda b: (b.batch.loss_all, (b.batch.input, b.batch.output), b.batch.target), \
13 |             out_f=image_class_outf, order='dsc')", throttle=1)
14 |     img_plot = tw.ImagePlot()
15 |     img_plot.subscribe(imgs, viz_img_scale=3)
16 |     img_plot.show()
17 | 
18 |     tw.image_utils.plt_loop()
19 | 
20 | def show_find_lr():
21 |     cli_train = tw.WatcherClient()
22 |     plot = tw.LinePlot()
23 | 
24 |     train_batch_loss = cli_train.create_stream(event_name='batch',
25 |         expr='lambda d:(d.tt.scheduler.get_lr()[0], d.metrics.batch_loss)')
26 |     plot.subscribe(train_batch_loss, xtitle='Epoch', ytitle='Loss')
27 | 
28 |     utils.wait_key()
29 | 
30 | def plot_grads_plotly():
31 |     train_cli = tw.WatcherClient()
32 |     grads = train_cli.create_stream(event_name='batch',
33 |         expr='lambda d:grads_abs_mean(d.model)', throttle=1)
34 |     p = tw.plotly.line_plot.LinePlot('Demo')
35 |     p.subscribe(grads, xtitle='Layer', ytitle='Gradients', history_len=30, new_on_eval=True)
36 |     utils.wait_key()
37 | 
38 | 
39 | def plot_grads():
40 |     train_cli = tw.WatcherClient()
41 | 
42 |     grads = train_cli.create_stream(event_name='batch',
43 |         expr='lambda d:grads_abs_mean(d.model)', throttle=1)
44 |     grad_plot = tw.LinePlot()
45 |     grad_plot.subscribe(grads, xtitle='Layer', ytitle='Gradients', clear_after_each=1, history_len=40, dim_history=True)
46 |     grad_plot.show()
47 | 
48 |     tw.plt_loop()
49 | 
50 | def plot_weight():
51 |     train_cli = tw.WatcherClient()
52 | 
53 |     params = train_cli.create_stream(event_name='batch',
54 |         expr='lambda d:weights_abs_mean(d.model)', throttle=1)
55 |     params_plot = tw.LinePlot()
56 |     params_plot.subscribe(params, xtitle='Layer', ytitle='avg |params|', clear_after_each=1, history_len=40, dim_history=True)
57 |     params_plot.show()
58 | 
59 |     tw.plt_loop()
60 | 
61 | def epoch_stats():
62 |     train_cli = tw.WatcherClient(port=0)
63 |     test_cli = tw.WatcherClient(port=1)
64 | 
65 |     plot = tw.LinePlot()
66 | 
67 |     train_loss = train_cli.create_stream(event_name="epoch",
68 |         expr='lambda v:(v.metrics.epoch_index, v.metrics.epoch_loss)')
69 |     plot.subscribe(train_loss, xtitle='Epoch', ytitle='Train Loss')
70 | 
71 |     test_acc = test_cli.create_stream(event_name="epoch",
72 |         expr='lambda v:(v.metrics.epoch_index, v.metrics.epoch_accuracy)')
73 |     plot.subscribe(test_acc, xtitle='Epoch', ytitle='Test Accuracy', ylim=(0,1))
74 | 
75 |     plot.show()
76 |     tw.plt_loop()
77 | 
78 | 
79 | def batch_stats():
80 |     train_cli = tw.WatcherClient()
81 |     stream = train_cli.create_stream(event_name="batch",
82 |         expr='lambda v:(v.metrics.epochf, v.metrics.batch_loss)', throttle=0.75)
83 | 
84 |     train_loss = tw.Visualizer(stream, clear_after_end=False, vis_type='mpl-line',
85 |         xtitle='Epoch', ytitle='Train Loss')
86 | 
87 |     #train_acc = tw.Visualizer('lambda v:(v.metrics.epochf, v.metrics.epoch_loss)', event_name="batch",
88 |     #    xtitle='Epoch', ytitle='Train Accuracy', clear_after_end=False, yrange=(0,1),
89 |     #    vis=train_loss, vis_type='mpl-line')
90 | 
91 |     train_loss.show()
92 |     tw.plt_loop()
93 | 
94 | def text_stats():
95 |     train_cli = tw.WatcherClient()
96 |     stream = train_cli.create_stream(event_name="batch",
97 |         expr='lambda d:(d.metrics.epoch_index, d.metrics.batch_loss)')
98 | 
99 |     trl = tw.Visualizer(stream, vis_type='text')
100 |     trl.show()
101 |     input('Paused...')
102 | 
103 | 
104 | 
105 | #epoch_stats()
106 | #plot_weight()
107 | #plot_grads()
108 | #img_in_class()
109 | text_stats()
110 | #batch_stats()
--------------------------------------------------------------------------------
/docs/lazy_logging.md:
--------------------------------------------------------------------------------
1 | 
2 | # Lazy Logging Tutorial
3 | 
4 | In this tutorial, we will use a straightforward script that creates an array of random numbers. Our goal is to examine this array from a Jupyter Notebook while the script below runs separately in a console.
5 | 
6 | ```
7 | # available in test/simple_log/sum_lazy.py
8 | 
9 | import time, random
10 | import tensorwatch as tw
11 | 
12 | # create watcher object as usual
13 | w = tw.Watcher()
14 | 
15 | weights = None
16 | for i in range(10000):
17 |     weights = [random.random() for _ in range(5)]
18 | 
19 |     # let watcher observe variables we have
20 |     # this has almost no performance cost
21 |     w.observe(weights=weights)
22 | 
23 |     time.sleep(1)
24 | ```
25 | 
26 | ## Basic Concepts
27 | 
28 | Notice that we give the `Watcher` object an opportunity to *observe* variables. Observing variables is very cheap, so we can observe anything we might be interested in, for example, an entire model or a batch of data. From the Jupyter Notebook, you can specify an arbitrary *lambda expression* that is executed in the context of these observed variables. The result of the lambda expression is sent back to the Jupyter Notebook as a stream that you can render in the visualizer of your choice. That's pretty much it.
29 | 
30 | ## How Does This Work?
31 | 
32 | TensorWatch has two core classes: `Watcher` and `WatcherClient`. The `Watcher` object opens up TCP/IP sockets by default and listens for incoming requests. The `WatcherClient` allows you to connect to a `Watcher` and have it execute any Python [lambda expression](http://book.pythontips.com/en/latest/lambdas.html) you want. The lambda expressions consume a stream of values as input and output another stream of values. They typically contain [map, filter, and reduce](http://book.pythontips.com/en/latest/map_filter.html) so you can transform values in the input stream, filter them, or aggregate them. Let's see all of this with a simple example.
33 | 
34 | ## Create the Client
35 | You can either create a Jupyter Notebook or get the existing one in the repo at `notebooks/lazy_logging.ipynb`.
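If you just want a quick taste of the map/filter style promised above before we build the notebook step by step, here is a hedged sketch (the expression string is illustrative, not taken from the repo; `weights` is the variable observed by the script above). It averages the weights, but only for iterations where the largest weight exceeds 0.9:

```python
client = tw.WatcherClient()
# keep only "interesting" iterations, then reduce each to a single number
stream = client.create_stream(
    expr='map(lambda d: np.mean(d.weights), filter(lambda d: np.max(d.weights) > 0.9, l))')
tw.Visualizer(stream, vis_type='text').show()
```

The rest of this tutorial builds the same pieces one step at a time.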
36 | 37 | Let's first do imports: 38 | 39 | 40 | ```python 41 | %matplotlib notebook 42 | import tensorwatch as tw 43 | ``` 44 | 45 | Next, create a stream using a simple lambda expression that sums the values in the `weights` array we are observing in the above script. The input to the lambda expression is an object that holds all the variables we are observing. 46 | 47 | ```python 48 | client = tw.WatcherClient() 49 | stream = client.create_stream(expr='lambda d: np.sum(d.weights)') 50 | ``` 51 | 52 | Next, send this stream to a line chart visualization: 53 | 54 | ```python 55 | line_plot = tw.Visualizer(stream, vis_type='line') 56 | line_plot.show() 57 | ``` 58 | 59 | Now when you run the script `sum_lazy.py` in the console and then the Jupyter Notebook above, you will see the sum of the `weights` array plotted in real time: 60 | 61 | 62 | 63 | ## What Just Happened? 64 | 65 | Notice that your script runs in a different process from the Jupyter Notebook. You could be running your Jupyter Notebook on your laptop while your script runs in some VM in the cloud. You can query the observed variables at run time! As `Watcher.observe()` is a cheap call, you could observe your entire model, batch inputs and outputs, and so on. You can then slice and dice all these observed variables from the comfort of your Jupyter Notebook while your script is progressing! 66 | 67 | ## Next Steps 68 | 69 | We have just scratched the surface of this new land of lazy logging! You can do much more with this idea. For example, you can create *events* so that your lambda expression gets executed only on those events. This is useful for creating streams that generate data per batch or per epoch. You can also use many predefined lambda expressions that allow you to view the top k items by some criterion. 70 | 71 | ## Questions? 72 | 73 | File a [GitHub issue](https://github.com/microsoft/tensorwatch/issues/new) and let us know if we can improve this tutorial. -------------------------------------------------------------------------------- /tensorwatch/notebook_maker.py: -------------------------------------------------------------------------------- 1 | import nbformat 2 | from nbformat.v4 import new_code_cell, new_markdown_cell, new_notebook 3 | from nbformat import v3, v4 4 | import codecs 5 | from os import linesep, path 6 | import uuid 7 | from . 
import utils 8 | import re 9 | from typing import List 10 | from .lv_types import VisArgs 11 | 12 | class NotebookMaker: 13 | def __init__(self, watcher, filename:str=None)->None: 14 | self.filename = filename or \ 15 | (path.splitext(watcher.filename)[0] + '.ipynb' if watcher.filename else \ 16 | 'tensorwatch.ipynb') 17 | 18 | self.cells = [] 19 | self._default_vis_args = VisArgs() 20 | 21 | watcher_args_str = NotebookMaker._get_vis_args(watcher) 22 | 23 | # create initial cell 24 | self.cells.append(new_code_cell(source=linesep.join( 25 | ['%matplotlib notebook', 26 | 'import tensorwatch as tw', 27 | 'client = tw.WatcherClient({})'.format(NotebookMaker._get_vis_args(watcher))]))) 28 | 29 | def _get_vis_args(watcher)->str: 30 | args_strs = [] 31 | for param, default_v in [('port', 0), ('filename', None)]: 32 | if hasattr(watcher, param): 33 | v = getattr(watcher, param) 34 | if v==default_v or (v is None and default_v is None): 35 | continue 36 | args_strs.append("{}={}".format(param, NotebookMaker._val2str(v))) 37 | return ', '.join(args_strs) 38 | 39 | def _get_stream_identifier(prefix, event_name, stream_name, stream_index)->str: 40 | if not stream_name or utils.is_uuid4(stream_name): 41 | if event_name is not None and event_name != '': 42 | return '{}_{}_{}'.format(prefix, event_name, stream_index) 43 | else: 44 | return prefix + str(stream_index) 45 | else: 46 | return '{}{}_{}'.format(prefix, stream_index, utils.str2identifier(stream_name)[:8]) 47 | 48 | def _val2str(v)->str: 49 | # TODO: shall we raise error if non str, bool, number (or its container) parameters? 50 | return str(v) if not isinstance(v, str) else "'{}'".format(v) 51 | 52 | def _add_vis_args_str(self, stream_info, param_strs:List[str])->None: 53 | default_args = self._default_vis_args.__dict__ 54 | if not stream_info.req.vis_args: 55 | return 56 | for k, v in stream_info.req.vis_args.__dict__.items(): 57 | if k in default_args: 58 | default_v = default_args[k] 59 | if (v is None and default_v is None) or (v==default_v): 60 | continue # skip param if its value is not changed from default 61 | param_strs.append("{}={}".format(k, NotebookMaker._val2str(v))) 62 | 63 | def _get_stream_code(self, event_name, stream_name, stream_index, stream_info)->List[str]: 64 | lines = [] 65 | 66 | stream_identifier = 's'+str(stream_index) 67 | lines.append("{} = client.open_stream(name='{}')".format(stream_identifier, stream_name)) 68 | 69 | vis_identifier = 'v'+str(stream_index) 70 | vis_args_strs = ['stream={}'.format(stream_identifier)] 71 | self._add_vis_args_str(stream_info, vis_args_strs) 72 | lines.append("{} = tw.Visualizer({})".format(vis_identifier, ', '.join(vis_args_strs))) 73 | lines.append("{}.show()".format(vis_identifier)) 74 | return lines 75 | 76 | def add_streams(self, event_stream_infos)->None: 77 | stream_index = 0 78 | for event_name, stream_infos in event_stream_infos.items(): # per event 79 | for stream_name, stream_info in stream_infos.items(): 80 | lines = self._get_stream_code(event_name, stream_name, stream_index, stream_info) 81 | self.cells.append(new_code_cell(source=linesep.join(lines))) 82 | stream_index += 1 83 | 84 | def write(self): 85 | nb = new_notebook(cells=self.cells, metadata={'language': 'python',}) 86 | with codecs.open(self.filename, encoding='utf-8', mode='w') as f: 87 | nbformat.write(nb, f, 4) 88 | 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /tensorwatch/watcher.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import uuid 5 | from typing import Sequence 6 | from .zmq_wrapper import ZmqWrapper 7 | from .watcher_base import WatcherBase 8 | from .lv_types import CliSrvReqTypes 9 | from .lv_types import DefaultPorts, PublisherTopics, ServerMgmtMsg 10 | from . import utils 11 | import threading, time 12 | 13 | class Watcher(WatcherBase): 14 | def __init__(self, filename:str=None, port:int=0, srv_name:str=None): 15 | super(Watcher, self).__init__() 16 | 17 | self.port = port 18 | self.filename = filename 19 | 20 | # used to detect server restarts 21 | self.srv_name = srv_name or str(uuid.uuid4()) 22 | 23 | # define vars in __init__ 24 | self._clisrv = None 25 | self._zmq_stream_pub = None 26 | self._file = None 27 | self._th = None 28 | 29 | self._open_devices() 30 | 31 | def _open_devices(self): 32 | if self.port is not None: 33 | self._clisrv = ZmqWrapper.ClientServer(port=DefaultPorts.CliSrv+self.port, 34 | is_server=True, callback=self._clisrv_callback) 35 | 36 | # notify existing listeners of our ID 37 | self._zmq_stream_pub = self._stream_factory.get_streams(stream_types=['tcp:'+str(self.port)], for_write=True)[0] 38 | 39 | # ZMQ quirk: we must wait a bit after opening port and before sending message 40 | # TODO: can we do better? 41 | self._th = threading.Thread(target=self._send_server_start) 42 | self._th.start() 43 | 44 | def _send_server_start(self): 45 | time.sleep(2) 46 | self._zmq_stream_pub.write(ServerMgmtMsg(event_name=ServerMgmtMsg.EventServerStart, 47 | event_args=self.srv_name), topic=PublisherTopics.ServerMgmt) 48 | 49 | def devices_or_default(self, devices:Sequence[str])->Sequence[str]: # overriden 50 | # TODO: this method is duplicated in Watcher and WatcherClient 51 | 52 | # make sure TCP port is attached to tcp device 53 | if devices is not None: 54 | return ['tcp:' + str(self.port) if device=='tcp' else device for device in devices] 55 | 56 | # if no devices specified then use our filename and tcp:port as default devices 57 | devices = [] 58 | # first open file device because it may have older data 59 | if self.filename is not None: 60 | devices.append('file:' + self.filename) 61 | if self.port is not None: 62 | devices.append('tcp:' + str(self.port)) 63 | return devices 64 | 65 | def close(self): 66 | if not self.closed: 67 | if self._clisrv is not None: 68 | self._clisrv.close() 69 | if self._zmq_stream_pub is not None: 70 | self._zmq_stream_pub.close() 71 | if self._file is not None: 72 | self._file.close() 73 | utils.debug_log("Watcher is closed", verbosity=1) 74 | super(Watcher, self).close() 75 | 76 | def _reset(self): 77 | self._clisrv = None 78 | self._zmq_stream_pub = None 79 | self._file = None 80 | self._th = None 81 | utils.debug_log("Watcher reset", verbosity=1) 82 | super(Watcher, self)._reset() 83 | 84 | def _clisrv_callback(self, clisrv, clisrv_req): # pylint: disable=unused-argument 85 | utils.debug_log("Received client request", clisrv_req.req_type) 86 | 87 | # request = create stream 88 | if clisrv_req.req_type == CliSrvReqTypes.create_stream: 89 | stream_req = clisrv_req.req_data 90 | self.create_stream(name=stream_req.stream_name, devices=stream_req.devices, 91 | event_name=stream_req.event_name, expr=stream_req.expr, throttle=stream_req.throttle, 92 | vis_args=stream_req.vis_args) 93 | return None # ignore return as we can't send back stream obj 94 | elif clisrv_req.req_type == 
CliSrvReqTypes.del_stream: 95 | stream_name = clisrv_req.req_data 96 | return self.del_stream(stream_name) 97 | else: 98 | raise ValueError('ClientServer Request Type {} is not recognized'.format(clisrv_req)) -------------------------------------------------------------------------------- /tensorwatch/mpl/image_plot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .base_mpl_plot import BaseMplPlot 5 | from .. import utils, image_utils 6 | import numpy as np 7 | 8 | #from IPython import get_ipython 9 | 10 | class ImagePlot(BaseMplPlot): 11 | def init_stream_plot(self, stream_vis, 12 | rows=2, cols=5, img_width=None, img_height=None, img_channels=None, 13 | colormap=None, viz_img_scale=None, **stream_vis_args): 14 | stream_vis.rows, stream_vis.cols = rows, cols 15 | stream_vis.img_channels, stream_vis.colormap = img_channels, colormap 16 | stream_vis.img_width, stream_vis.img_height = img_width, img_height 17 | stream_vis.viz_img_scale = viz_img_scale 18 | # subplots holding each image 19 | stream_vis.axs = [[None for _ in range(cols)] for _ in range(rows)] 20 | # axis image 21 | stream_vis.ax_imgs = [[None for _ in range(cols)] for _ in range(rows)] 22 | 23 | def clear_plot(self, stream_vis, clear_history): 24 | for row in range(stream_vis.rows): 25 | for col in range(stream_vis.cols): 26 | img = stream_vis.ax_imgs[row][col] 27 | if img: 28 | x, y = img.get_size() 29 | img.set_data(np.zeros((x, y))) 30 | 31 | def _show_stream_items(self, stream_vis, stream_items): 32 | """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True. 33 | """ 34 | 35 | # as we repaint each image plot, select last if multiple events were pending 36 | stream_item = None 37 | for er in reversed(stream_items): 38 | if not(er.ended or er.value is None): 39 | stream_item = er 40 | break 41 | if stream_item is None: 42 | return True 43 | 44 | row, col, i = 0, 0, 0 45 | dirty = False 46 | # stream_item.value is expected to be ImagePlotItems 47 | for image_list in stream_item.value: 48 | # convert to imshow compatible, stitch images 49 | images = [image_utils.to_imshow_array(img, stream_vis.img_width, stream_vis.img_height) \ 50 | for img in image_list.images if img is not None] 51 | img_viz = image_utils.stitch_horizontal(images, width_dim=1) 52 | 53 | # resize if requested 54 | if stream_vis.viz_img_scale is not None: 55 | import skimage.transform # expensive import done on demand 56 | 57 | if isinstance(img_viz, np.ndarray) and np.issubdtype(img_viz.dtype, np.floating): 58 | img_viz = img_viz.clip(-1, 1) # some MNIST images have out of range values causing exception in sklearn 59 | img_viz = skimage.transform.rescale(img_viz, 60 | (stream_vis.viz_img_scale, stream_vis.viz_img_scale), mode='reflect', preserve_range=False) 61 | 62 | # create subplot if it doesn't exist 63 | ax = stream_vis.axs[row][col] 64 | if ax is None: 65 | ax = stream_vis.axs[row][col] = \ 66 | self.figure.add_subplot(stream_vis.rows, stream_vis.cols, i+1) 67 | ax.set_xticks([]) 68 | ax.set_yticks([]) 69 | 70 | cmap = image_list.cmap or ('Greys' if stream_vis.colormap is None and \ 71 | len(img_viz.shape) == 2 else stream_vis.colormap) 72 | 73 | stream_vis.ax_imgs[row][col] = ax.imshow(img_viz, interpolation="none", cmap=cmap, alpha=image_list.alpha) 74 | dirty = True 75 | 76 | # set title 77 | title = image_list.title 78 | if len(title) > 12: #wordwrap if too long 79 | title = 
utils.wrap_string(title) if len(title) > 24 else title 80 | fontsize = 8 81 | else: 82 | fontsize = 12 83 | ax.set_title(title, fontsize=fontsize) #'fontweight': 'light' 84 | 85 | #ax.autoscale_view() # not needed 86 | col = col + 1 87 | if col >= stream_vis.cols: 88 | col = 0 89 | row = row + 1 90 | if row >= stream_vis.rows: 91 | break 92 | i += 1 93 | 94 | return not dirty 95 | 96 | 97 | def has_legend(self): 98 | return self.show_legend or False -------------------------------------------------------------------------------- /tensorwatch/stream_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from typing import Dict, Sequence, List, Any 5 | from .zmq_stream import ZmqStream 6 | from .file_stream import FileStream 7 | from .stream import Stream 8 | from .stream_union import StreamUnion 9 | import uuid 10 | from . import utils 11 | 12 | class StreamFactory: 13 | r"""Allows to create shared stream such as file and ZMQ streams 14 | """ 15 | 16 | def __init__(self)->None: 17 | self.closed = None 18 | self._streams:Dict[str, Stream] = None 19 | self._reset() 20 | 21 | def _reset(self): 22 | self._streams:Dict[str, Stream] = {} 23 | self.closed = False 24 | 25 | def close(self): 26 | if not self.closed: 27 | for stream in self._streams.values(): 28 | stream.close() 29 | self._reset() 30 | self.closed = True 31 | 32 | def __enter__(self): 33 | return self 34 | def __exit__(self, exception_type, exception_value, traceback): 35 | self.close() 36 | 37 | def get_streams(self, stream_types:Sequence[str], for_write:bool=None)->List[Stream]: 38 | streams = [self._create_stream_by_string(stream_type, for_write) for stream_type in stream_types] 39 | return streams 40 | 41 | def get_combined_stream(self, stream_types:Sequence[str], for_write:bool=None)->Stream: 42 | streams = [self._create_stream_by_string(stream_type, for_write) for stream_type in stream_types] 43 | if len(streams) == 1: 44 | return self._streams[0] 45 | else: 46 | # we create new union of child but this is not necessory 47 | return StreamUnion(streams, for_write=for_write) 48 | 49 | def _get_stream_name(stream_type:str, stream_args:Any, for_write:bool)->str: 50 | return '{}:{}:{}'.format(stream_type, stream_args, for_write) 51 | 52 | def _create_stream_by_string(self, stream_spec:str, for_write:bool)->Stream: 53 | parts = stream_spec.split(':', 1) if stream_spec is not None else [''] 54 | stream_type = parts[0] 55 | stream_args = parts[1] if len(parts) > 1 else None 56 | 57 | utils.debug_log("Creating stream", (stream_spec, for_write)) 58 | 59 | if stream_type == 'tcp': 60 | port = int(stream_args or 0) 61 | stream_name = StreamFactory._get_stream_name(stream_type, port, for_write) 62 | if stream_name not in self._streams: 63 | self._streams[stream_name] = ZmqStream(for_write=for_write, 64 | port=port, stream_name=stream_name, block_until_connected=False) 65 | # else we already have this stream 66 | return self._streams[stream_name] 67 | 68 | 69 | if stream_args is None: # file name specified without 'file:' prefix 70 | stream_args = stream_type 71 | stream_type = 'file' 72 | if len(stream_type) == 1: # windows drive letter 73 | stream_type = 'file' 74 | stream_args = stream_spec 75 | 76 | if stream_type == 'file': 77 | if stream_args is None: 78 | raise ValueError('File name must be specified for stream type "file"') 79 | stream_name = StreamFactory._get_stream_name(stream_type, stream_args, for_write) 
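            # the registry key has the form '<stream_type>:<stream_args>:<for_write>',
            # e.g. a writable stream for a (hypothetical) file 'log.twl' would be
            # shared under the key 'file:log.twl:True'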
80 | # each read-only file stream should be a separate stream, otherwise sharing would 81 | # change seek positions 82 | if not for_write: 83 | stream_name += ':' + str(uuid.uuid4()) 84 | 85 | # if a write file stream exists then flush it before the read stream reads it 86 | write_stream_name = StreamFactory._get_stream_name(stream_type, stream_args, True) 87 | write_file_stream = self._streams.get(write_stream_name, None) 88 | if write_file_stream: 89 | write_file_stream.save() 90 | if stream_name not in self._streams: 91 | self._streams[stream_name] = FileStream(for_write=for_write, 92 | file_name=stream_args, stream_name=stream_name) 93 | # else we already have this stream 94 | return self._streams[stream_name] 95 | 96 | if stream_type == '': 97 | return Stream() 98 | 99 | raise ValueError('stream_type "{}" has unknown type'.format(stream_type)) 100 | 101 | 102 | -------------------------------------------------------------------------------- /tensorwatch/saliency/lime/wrappers/scikit_image.py: -------------------------------------------------------------------------------- 1 | import types 2 | from .generic_utils import has_arg 3 | from skimage.segmentation import felzenszwalb, slic, quickshift 4 | 5 | 6 | class BaseWrapper(object): 7 | """Base class for LIME Scikit-Image wrapper 8 | 9 | 10 | Args: 11 | target_fn: callable function or class instance 12 | target_params: dict, parameters to pass to the target_fn 13 | 14 | 15 | 'target_params' takes parameters required to instantiate the 16 | desired Scikit-Image class/model 17 | """ 18 | 19 | def __init__(self, target_fn=None, **target_params): 20 | self.target_fn = target_fn 21 | self.target_params = target_params 22 | 23 | self.target_fn = target_fn 24 | self.target_params = target_params 25 | 26 | def _check_params(self, parameters): 27 | """Checks for mistakes in 'parameters' 28 | 29 | Args : 30 | parameters: dict, parameters to be checked 31 | 32 | Raises : 33 | ValueError: if any parameter is not a valid argument for the target function 34 | or the target function is not defined 35 | TypeError: if argument parameters is not iterable 36 | """ 37 | a_valid_fn = [] 38 | if self.target_fn is None: 39 | if callable(self): 40 | a_valid_fn.append(self.__call__) 41 | else: 42 | raise TypeError('invalid argument: tested object is not callable,\ 43 | please provide a valid target_fn') 44 | elif isinstance(self.target_fn, types.FunctionType) \ 45 | or isinstance(self.target_fn, types.MethodType): 46 | a_valid_fn.append(self.target_fn) 47 | else: 48 | a_valid_fn.append(self.target_fn.__call__) 49 | 50 | if not isinstance(parameters, str): 51 | for p in parameters: 52 | for fn in a_valid_fn: 53 | if has_arg(fn, p): 54 | pass 55 | else: 56 | raise ValueError('{} is not a valid parameter'.format(p)) 57 | else: 58 | raise TypeError('invalid argument: list or dictionary expected') 59 | 60 | def set_params(self, **params): 61 | """Sets the parameters of this estimator. 62 | Args: 63 | **params: Dictionary of parameter names mapped to their values. 64 | 65 | Raises : 66 | ValueError: if any parameter is not a valid argument 67 | for the target function 68 | """ 69 | self._check_params(params) 70 | self.target_params = params 71 | 72 | def filter_params(self, fn, override=None): 73 | """Filters `target_params` and returns those in `fn`'s arguments. 74 | Args: 75 | fn : arbitrary function 76 | override: dict, values to override target_params 77 | Returns: 78 | result : dict, dictionary containing variables 79 | in both target_params and fn's arguments. 
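            Example (illustrative): for a wrapper built over skimage's slic,
            filter_params(slic) keeps only those entries of target_params
            that are actual arguments of the slic function.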
80 | """ 81 | override = override or {} 82 | result = {} 83 | for name, value in self.target_params.items(): 84 | if has_arg(fn, name): 85 | result.update({name: value}) 86 | result.update(override) 87 | return result 88 | 89 | 90 | class SegmentationAlgorithm(BaseWrapper): 91 | """ Define the image segmentation function based on Scikit-Image 92 | implementation and a set of provided parameters 93 | 94 | Args: 95 | algo_type: string, segmentation algorithm among the following: 96 | 'quickshift', 'slic', 'felzenszwalb' 97 | target_params: dict, algorithm parameters (valid model paramters 98 | as define in Scikit-Image documentation) 99 | """ 100 | 101 | def __init__(self, algo_type, **target_params): 102 | self.algo_type = algo_type 103 | if (self.algo_type == 'quickshift'): 104 | BaseWrapper.__init__(self, quickshift, **target_params) 105 | kwargs = self.filter_params(quickshift) 106 | self.set_params(**kwargs) 107 | elif (self.algo_type == 'felzenszwalb'): 108 | BaseWrapper.__init__(self, felzenszwalb, **target_params) 109 | kwargs = self.filter_params(felzenszwalb) 110 | self.set_params(**kwargs) 111 | elif (self.algo_type == 'slic'): 112 | BaseWrapper.__init__(self, slic, **target_params) 113 | kwargs = self.filter_params(slic) 114 | self.set_params(**kwargs) 115 | 116 | def __call__(self, *args): 117 | return self.target_fn(args[0], **self.target_params) 118 | -------------------------------------------------------------------------------- /tensorwatch/evaler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import threading, sys, logging 5 | from collections.abc import Iterator 6 | from .lv_types import EventData 7 | 8 | # pylint: disable=unused-wildcard-import 9 | # pylint: disable=wildcard-import 10 | # pylint: disable=unused-import 11 | from functools import * 12 | from itertools import * 13 | from statistics import * 14 | import numpy as np 15 | from .evaler_utils import * 16 | 17 | class Evaler: 18 | class EvalReturn: 19 | def __init__(self, result=None, is_valid=False, exception=None): 20 | self.result, self.exception, self.is_valid = \ 21 | result, exception, is_valid 22 | def reset(self): 23 | self.result, self.exception, self.is_valid = \ 24 | None, None, False 25 | 26 | class PostableIterator: 27 | def __init__(self, eval_wait): 28 | self.eval_wait = eval_wait 29 | self.post_wait = threading.Event() 30 | self.event_data, self.ended = None, None # define attributes in init 31 | self.reset() 32 | 33 | def reset(self): 34 | self.event_data, self.ended = None, False 35 | self.post_wait.clear() 36 | 37 | def abort(self): 38 | self.ended = True 39 | self.post_wait.set() 40 | 41 | def post(self, event_data:EventData=None, ended=False): 42 | self.event_data, self.ended = event_data, ended 43 | self.post_wait.set() 44 | 45 | def get_vals(self): 46 | while True: 47 | self.post_wait.wait() 48 | self.post_wait.clear() 49 | if self.ended: 50 | break 51 | else: 52 | yield self.event_data 53 | # below will cause result=None, is_valid=False when 54 | # expression has reduce 55 | self.eval_wait.set() 56 | 57 | def __init__(self, expr): 58 | self.eval_wait = threading.Event() 59 | self.reset_wait = threading.Event() 60 | self.g = Evaler.PostableIterator(self.eval_wait) 61 | self.expr = expr 62 | self.eval_return, self.continue_thread = None, None # define in __init__ 63 | self.reset() 64 | 65 | self.th = threading.Thread(target=self._runner, daemon=True, name='evaler') 
66 | self.th.start() 67 | self.running = True 68 | 69 | def reset(self): 70 | self.g.reset() 71 | self.eval_wait.clear() 72 | self.reset_wait.clear() 73 | self.eval_return = Evaler.EvalReturn() 74 | self.continue_thread = True 75 | 76 | def _runner(self): 77 | while True: 78 | # this var will be used by eval 79 | l = self.g.get_vals() # pylint: disable=unused-variable 80 | try: 81 | result = eval(self.expr) # pylint: disable=eval-used 82 | if isinstance(result, Iterator): 83 | for item in result: 84 | self.eval_return = Evaler.EvalReturn(item, True) 85 | else: 86 | self.eval_return = Evaler.EvalReturn(result, True) 87 | except Exception as ex: # pylint: disable=broad-except 88 | logging.exception('Exception occurred while evaluating expression: ' + self.expr) 89 | self.eval_return = Evaler.EvalReturn(None, True, ex) 90 | self.eval_wait.set() 91 | self.reset_wait.wait() 92 | if not self.continue_thread: 93 | break 94 | self.reset() 95 | self.running = False 96 | utils.debug_log('eval runner ended!') 97 | 98 | def abort(self): 99 | utils.debug_log('Evaler Aborted') 100 | self.continue_thread = False 101 | self.g.abort() 102 | self.eval_wait.set() 103 | self.reset_wait.set() 104 | 105 | def post(self, event_data:EventData=None, ended=False, continue_thread=True): 106 | if not self.running: 107 | utils.debug_log('post was called when Evaler is not running') 108 | return None, False 109 | self.eval_return.reset() 110 | self.g.post(event_data, ended) 111 | self.eval_wait.wait() 112 | self.eval_wait.clear() 113 | # save the result before it gets reset 114 | eval_return = self.eval_return 115 | self.reset_wait.set() 116 | self.continue_thread = continue_thread 117 | if isinstance(eval_return.result, Iterator): 118 | eval_return.result = list(eval_return.result) 119 | return eval_return 120 | 121 | def join(self): 122 | self.th.join() 123 | -------------------------------------------------------------------------------- /tensorwatch/plotly/base_plotly_plot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from ..vis_base import VisBase 5 | import time 6 | from abc import abstractmethod 7 | from .. 
import utils 8 | 9 | 10 | class BasePlotlyPlot(VisBase): 11 | def __init__(self, cell:VisBase.widgets.Box=None, title=None, show_legend:bool=None, is_3d:bool=False, 12 | stream_name:str=None, console_debug:bool=False, **vis_args): 13 | import plotly.graph_objs as go # function-level import as this takes long time 14 | super(BasePlotlyPlot, self).__init__(go.FigureWidget(), cell, title, show_legend, 15 | stream_name=stream_name, console_debug=console_debug, **vis_args) 16 | 17 | self.is_3d = is_3d 18 | self.widget.layout.title = title 19 | self.widget.layout.showlegend = show_legend if show_legend is not None else True 20 | 21 | def _add_trace(self, stream_vis): 22 | stream_vis.trace_index = len(self.widget.data) 23 | trace = self._create_trace(stream_vis) 24 | if stream_vis.opacity is not None: 25 | trace.opacity = stream_vis.opacity 26 | self.widget.add_trace(trace) 27 | 28 | def _add_trace_with_history(self, stream_vis): 29 | # if history buffer isn't full 30 | if stream_vis.history_len > len(stream_vis.trace_history): 31 | self._add_trace(stream_vis) 32 | stream_vis.trace_history.append(len(self.widget.data)-1) 33 | stream_vis.cur_history_index = len(stream_vis.trace_history)-1 34 | #if stream_vis.cur_history_index: 35 | # self.widget.data[trace_index].showlegend = False 36 | else: 37 | # rotate trace 38 | stream_vis.cur_history_index = (stream_vis.cur_history_index + 1) % stream_vis.history_len 39 | stream_vis.trace_index = stream_vis.trace_history[stream_vis.cur_history_index] 40 | self.clear_plot(stream_vis, False) 41 | self.widget.data[stream_vis.trace_index].opacity = stream_vis.opacity or 1 42 | 43 | cur_history_len = len(stream_vis.trace_history) 44 | if stream_vis.dim_history and cur_history_len > 1: 45 | max_opacity = stream_vis.opacity or 1 46 | min_alpha, max_alpha, dimmed_len = max_opacity*0.05, max_opacity*0.8, cur_history_len-1 47 | alphas = list(utils.frange(max_alpha, min_alpha, steps=dimmed_len)) 48 | for i, thi in enumerate(range(stream_vis.cur_history_index+1, 49 | stream_vis.cur_history_index+cur_history_len)): 50 | trace_index = stream_vis.trace_history[thi % cur_history_len] 51 | self.widget.data[trace_index].opacity = alphas[i] 52 | 53 | @staticmethod 54 | def get_pallet_color(i:int): 55 | import plotly # function-level import as this takes long time 56 | return plotly.colors.DEFAULT_PLOTLY_COLORS[i % len(plotly.colors.DEFAULT_PLOTLY_COLORS)] 57 | 58 | @staticmethod 59 | def _get_axis_common_props(title:str, axis_range:tuple): 60 | props = {'showline':True, 'showgrid': True, 61 | 'showticklabels': True, 'ticks':'inside'} 62 | if title: 63 | props['title'] = title 64 | if axis_range: 65 | props['range'] = list(axis_range) 66 | return props 67 | 68 | def _can_update_stream_plots(self): 69 | return time.time() - self.q_last_processed > 0.5 # make configurable 70 | 71 | def _post_add_subscription(self, stream_vis, **stream_vis_args): 72 | stream_vis.trace_history, stream_vis.cur_history_index = [], None 73 | self._add_trace_with_history(stream_vis) 74 | self._setup_layout(stream_vis) 75 | 76 | if not self.widget.layout.title: 77 | self.widget.layout.title = stream_vis.title 78 | # TODO: better way for below? 79 | if stream_vis.history_len > 1: 80 | self.widget.layout.showlegend = False 81 | 82 | def _show_widget_native(self, blocking:bool): 83 | pass 84 | #TODO: save image, spawn browser? 
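        # plotly's FigureWidget is an ipywidget, so unlike the matplotlib
        # visualizers there is no native window loop to drive here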
85 | 86 | def _show_widget_notebook(self): 87 | #plotly.offline.iplot(self.widget) 88 | return None 89 | 90 | def _post_update_stream_plot(self, stream_vis): 91 | # not needed for plotly as FigureWidget stays upto date 92 | pass 93 | 94 | @abstractmethod 95 | def clear_plot(self, stream_vis, clear_history): 96 | """(for derived class) Clears the data in specified plot before new data is redrawn""" 97 | pass 98 | @abstractmethod 99 | def _show_stream_items(self, stream_vis, stream_items): 100 | """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True. 101 | """ 102 | 103 | pass 104 | @abstractmethod 105 | def _setup_layout(self, stream_vis): 106 | pass 107 | @abstractmethod 108 | def _create_trace(self, stream_vis): 109 | pass 110 | -------------------------------------------------------------------------------- /tensorwatch/image_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import numpy as np 5 | import math 6 | import time 7 | 8 | def guess_image_dims(img): 9 | if len(img.shape) == 1: 10 | # assume 2D monochrome (MNIST) 11 | width = height = round(math.sqrt(img.shape[0])) 12 | if width*height != img.shape[0]: 13 | # assume 3 channels (CFAR, ImageNet) 14 | width = height = round(math.sqrt(img.shape[0] / 3)) 15 | if width*height*3 != img.shape[0]: 16 | raise ValueError("Cannot guess image dimensions for linearized pixels") 17 | return (3, height, width) 18 | return (1, height, width) 19 | return img.shape 20 | 21 | def to_imshow_array(img, width=None, height=None): 22 | # array from Pytorch has shape: [[channels,] height, width] 23 | # image needed for imshow needs: [height, width, channels] 24 | from PIL import Image 25 | 26 | if img is not None: 27 | if isinstance(img, Image.Image): 28 | img = np.array(img) 29 | if len(img.shape) >= 2: 30 | return img # img is already compatible to imshow 31 | 32 | # force max 3 dimensions 33 | if len(img.shape) > 3: 34 | # TODO allow config 35 | # select first one in batch 36 | img = img[0:1,:,:] 37 | 38 | if len(img.shape) == 1: # linearized pixels typically used for MLPs 39 | if not(width and height): 40 | # pylint: disable=unused-variable 41 | channels, height, width = guess_image_dims(img) 42 | img = img.reshape((-1, height, width)) 43 | 44 | if len(img.shape) == 3: 45 | if img.shape[0] == 1: # single channel images 46 | img = img.squeeze(0) 47 | else: 48 | img = np.swapaxes(img, 0, 2) # transpose H,W for imshow 49 | img = np.swapaxes(img, 0, 1) 50 | elif len(img.shape) == 2: 51 | img = np.swapaxes(img, 0, 1) # transpose H,W for imshow 52 | else: #zero dimensions 53 | img = None 54 | 55 | return img 56 | 57 | #width_dim=1 for imshow, 2 for pytorch arrays 58 | def stitch_horizontal(images, width_dim=1): 59 | return np.concatenate(images, axis=width_dim) 60 | 61 | def _resize_image(img, size=None): 62 | if size is not None or (hasattr(img, 'shape') and len(img.shape) == 1): 63 | if size is None: 64 | # make guess for 1-dim tensors 65 | h = int(math.sqrt(img.shape[0])) 66 | w = int(img.shape[0] / h) 67 | size = h,w 68 | img = np.reshape(img, size) 69 | return img 70 | 71 | def show_image(img, size=None, alpha=None, cmap=None, 72 | img2=None, size2=None, alpha2=None, cmap2=None, ax=None): 73 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue 74 | 75 | img =_resize_image(img, size) 76 | img2 =_resize_image(img2, size2) 77 | 78 | (ax or 
plt).imshow(img, alpha=alpha, cmap=cmap) 79 | 80 | if img2 is not None: 81 | (ax or plt).imshow(img2, alpha=alpha2, cmap=cmap2) 82 | 83 | return ax or plt.show() 84 | 85 | # convert_mode param is mode: https://pillow.readthedocs.io/en/5.1.x/handbook/concepts.html#modes 86 | # use convert_mode='RGB' to force 3 channels 87 | def open_image(path, resize=None, resample=1, convert_mode=None): # Image.ANTIALIAS==1 88 | from PIL import Image 89 | 90 | img = Image.open(path) 91 | if resize is not None: 92 | img = img.resize(resize, resample) 93 | if convert_mode is not None: 94 | img = img.convert(convert_mode) 95 | return img 96 | 97 | def img2pyt(img, add_batch_dim=True, resize=None): 98 | from torchvision import transforms # expensive function-level import 99 | 100 | ts = [] 101 | if resize is not None: 102 | ts.append(transforms.RandomResizedCrop(resize)) 103 | ts.append(transforms.ToTensor()) 104 | img_pyt = transforms.Compose(ts)(img) 105 | if add_batch_dim: 106 | img_pyt.unsqueeze_(0) 107 | return img_pyt 108 | 109 | def linear_to_2d(img, size=None): 110 | if size is not None or (hasattr(img, 'shape') and len(img.shape) == 1): 111 | if size is None: 112 | # make guess for 1-dim tensors 113 | h = int(math.sqrt(img.shape[0])) 114 | w = int(img.shape[0] / h) 115 | size = h,w 116 | img = np.reshape(img, size) 117 | return img 118 | 119 | def stack_images(imgs): 120 | return np.hstack(imgs) 121 | 122 | def plt_loop(count=None, sleep_time=1, plt_pause=0.01): 123 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue 124 | 125 | #plt.ion() 126 | #plt.show(block=False) 127 | while((count is None or count > 0) and not plt.waitforbuttonpress(plt_pause)): 128 | #plt.draw() 129 | plt.pause(plt_pause) 130 | time.sleep(sleep_time) 131 | if count is not None: 132 | count = count - 1 133 | 134 | def get_cmap(name:str): 135 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue 136 | return plt.cm.get_cmap(name=name) 137 | 138 | -------------------------------------------------------------------------------- /tensorwatch/visualizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .stream import Stream 5 | from .vis_base import VisBase 6 | from . import mpl 7 | from . import plotly 8 | 9 | class Visualizer: 10 | """Constructs the visualizer for the specified vis_type. 11 | 12 | NOTE: If you modify arguments here then also sync the VisArgs constructor. 
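    Example (assuming `stream` is an already-created TensorWatch stream):

        vis = Visualizer(stream, vis_type='line', xtitle='Epoch', ytitle='Loss')
        vis.show()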
13 | """ 14 | def __init__(self, stream:Stream, vis_type:str=None, host:'Visualizer'=None, 15 | cell:'Visualizer'=None, title:str=None, 16 | clear_after_end=False, clear_after_each=False, history_len=1, dim_history=True, opacity=None, 17 | 18 | rows=2, cols=5, img_width=None, img_height=None, img_channels=None, 19 | colormap=None, viz_img_scale=None, 20 | 21 | # these image params are for hover on point for t-sne 22 | hover_images=None, hover_image_reshape=None, cell_width:str=None, cell_height:str=None, 23 | 24 | only_summary=False, separate_yaxis=True, xtitle=None, ytitle=None, ztitle=None, color=None, 25 | xrange=None, yrange=None, zrange=None, draw_line=True, draw_marker=False, 26 | 27 | # histogram 28 | bins=None, normed=None, histtype='bar', edge_color=None, linewidth=None, bar_width=None, 29 | 30 | # pie chart 31 | autopct=None, shadow=None, 32 | 33 | vis_args={}, stream_vis_args={})->None: 34 | 35 | cell = cell._host_base.cell if cell is not None else None 36 | 37 | if host: 38 | self._host_base = host._host_base 39 | else: 40 | self._host_base = self._get_vis_base(vis_type, cell, title, hover_images=hover_images, hover_image_reshape=hover_image_reshape, 41 | cell_width=cell_width, cell_height=cell_height, 42 | **vis_args) 43 | 44 | self._host_base.subscribe(stream, show=False, clear_after_end=clear_after_end, clear_after_each=clear_after_each, 45 | history_len=history_len, dim_history=dim_history, opacity=opacity, 46 | only_summary=only_summary if vis_type is None or 'summary' != vis_type else True, 47 | separate_yaxis=separate_yaxis, xtitle=xtitle, ytitle=ytitle, ztitle=ztitle, color=color, 48 | xrange=xrange, yrange=yrange, zrange=zrange, 49 | draw_line=draw_line if vis_type is not None and 'scatter' in vis_type else True, 50 | draw_marker=draw_marker, 51 | rows=rows, cols=cols, img_width=img_width, img_height=img_height, img_channels=img_channels, 52 | colormap=colormap, viz_img_scale=viz_img_scale, 53 | bins=bins, normed=normed, histtype=histtype, edge_color=edge_color, linewidth=linewidth, bar_width = bar_width, 54 | autopct=autopct, shadow=shadow, 55 | **stream_vis_args) 56 | 57 | stream.load() 58 | 59 | def show(self): 60 | return self._host_base.show() 61 | 62 | def _get_vis_base(self, vis_type, cell:VisBase.widgets.Box, title, hover_images=None, hover_image_reshape=None, cell_width=None, cell_height=None, **vis_args)->VisBase: 63 | if vis_type is None or vis_type in ['line', 64 | 'mpl-line', 'mpl-line3d', 'mpl-scatter3d', 'mpl-scatter']: 65 | return mpl.line_plot.LinePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, 66 | is_3d=vis_type is not None and vis_type.endswith('3d'), **vis_args) 67 | if vis_type in ['image', 'mpl-image']: 68 | return mpl.image_plot.ImagePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) 69 | if vis_type in ['bar', 'bar3d']: 70 | return mpl.bar_plot.BarPlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, 71 | is_3d=vis_type.endswith('3d'), **vis_args) 72 | if vis_type in ['histogram']: 73 | return mpl.histogram.Histogram(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) 74 | if vis_type in ['pie']: 75 | return mpl.pie_chart.PieChart(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) 76 | if vis_type in ['text', 'summary']: 77 | from .text_vis import TextVis 78 | return TextVis(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) 79 | if vis_type in ['line3d', 
'scatter', 'scatter3d', 80 | 'plotly-line', 'plotly-line3d', 'plotly-scatter', 'plotly-scatter3d', 'mesh3d']: 81 | return plotly.line_plot.LinePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, 82 | is_3d=vis_type.endswith('3d'), **vis_args) 83 | if vis_type in ['tsne', 'embeddings', 'tsne2d', 'embeddings2d']: 84 | return plotly.embeddings_plot.EmbeddingsPlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, 85 | is_3d='2d' not in vis_type, 86 | hover_images=hover_images, hover_image_reshape=hover_image_reshape, **vis_args) 87 | else: 88 | raise ValueError('Render vis_type parameter has invalid value: "{}"'.format(vis_type)) 89 | -------------------------------------------------------------------------------- /tensorwatch/saliency/backprop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.autograd import Variable, Function 3 | import torch 4 | import types 5 | 6 | 7 | class VanillaGradExplainer(object): 8 | def __init__(self, model): 9 | self.model = model 10 | 11 | def _backprop(self, inp, ind): 12 | inp.requires_grad = True 13 | if inp.grad is not None: 14 | inp.grad.zero_() 15 | if ind.grad is not None: 16 | ind.grad.zero_() 17 | self.model.eval() 18 | self.model.zero_grad() 19 | 20 | output = self.model(inp) 21 | if ind is None: 22 | ind = output.max(1)[1] 23 | grad_out = output.clone() 24 | grad_out.fill_(0.0) 25 | grad_out.scatter_(1, ind.unsqueeze(0).t(), 1.0) 26 | output.backward(grad_out) 27 | return inp.grad 28 | 29 | def explain(self, inp, ind=None, raw_inp=None): 30 | return self._backprop(inp, ind) 31 | 32 | 33 | class GradxInputExplainer(VanillaGradExplainer): 34 | def __init__(self, model): 35 | super(GradxInputExplainer, self).__init__(model) 36 | 37 | def explain(self, inp, ind=None, raw_inp=None): 38 | grad = self._backprop(inp, ind) 39 | return inp * grad 40 | 41 | 42 | class SaliencyExplainer(VanillaGradExplainer): 43 | def __init__(self, model): 44 | super(SaliencyExplainer, self).__init__(model) 45 | 46 | def explain(self, inp, ind=None, raw_inp=None): 47 | grad = self._backprop(inp, ind) 48 | return grad.abs() 49 | 50 | 51 | class IntegrateGradExplainer(VanillaGradExplainer): 52 | def __init__(self, model, steps=100): 53 | super(IntegrateGradExplainer, self).__init__(model) 54 | self.steps = steps 55 | 56 | def explain(self, inp, ind=None, raw_inp=None): 57 | grad = 0 58 | inp_data = inp.clone() 59 | 60 | for alpha in np.arange(1 / self.steps, 1.0, 1 / self.steps): 61 | new_inp = Variable(inp_data * alpha, requires_grad=True) 62 | g = self._backprop(new_inp, ind) 63 | grad += g 64 | 65 | return grad * inp_data / self.steps 66 | 67 | 68 | class DeconvExplainer(VanillaGradExplainer): 69 | def __init__(self, model): 70 | super(DeconvExplainer, self).__init__(model) 71 | self._override_backward() 72 | 73 | def _override_backward(self): 74 | class _ReLU(Function): 75 | @staticmethod 76 | def forward(ctx, input): 77 | output = torch.clamp(input, min=0) 78 | return output 79 | 80 | @staticmethod 81 | def backward(ctx, grad_output): 82 | grad_inp = torch.clamp(grad_output, min=0) 83 | return grad_inp 84 | 85 | def new_forward(self, x): 86 | return _ReLU.apply(x) 87 | 88 | def replace(m): 89 | if m.__class__.__name__ == 'ReLU': 90 | m.forward = types.MethodType(new_forward, m) 91 | 92 | self.model.apply(replace) 93 | 94 | 95 | class GuidedBackpropExplainer(VanillaGradExplainer): 96 | def __init__(self, model): 97 | super(GuidedBackpropExplainer, 
self).__init__(model) 98 | self._override_backward() 99 | 100 | def _override_backward(self): 101 | class _ReLU(Function): 102 | @staticmethod 103 | def forward(ctx, input): 104 | output = torch.clamp(input, min=0) 105 | ctx.save_for_backward(output) 106 | return output 107 | 108 | @staticmethod 109 | def backward(ctx, grad_output): 110 | output, = ctx.saved_tensors 111 | mask1 = (output > 0).float() 112 | mask2 = (grad_output > 0).float() 113 | grad_inp = mask1 * mask2 * grad_output 114 | grad_output.copy_(grad_inp) 115 | return grad_output 116 | 117 | def new_forward(self, x): 118 | return _ReLU.apply(x) 119 | 120 | def replace(m): 121 | if m.__class__.__name__ == 'ReLU': 122 | m.forward = types.MethodType(new_forward, m) 123 | 124 | self.model.apply(replace) 125 | 126 | 127 | # modified from https://github.com/PAIR-code/saliency/blob/master/saliency/base.py#L80 128 | class SmoothGradExplainer(object): 129 | def __init__(self, model, base_explainer=None, stdev_spread=0.15, 130 | nsamples=25, magnitude=True): 131 | self.base_explainer = base_explainer or VanillaGradExplainer(model) 132 | self.stdev_spread = stdev_spread 133 | self.nsamples = nsamples 134 | self.magnitude = magnitude 135 | 136 | def explain(self, inp, ind=None, raw_inp=None): 137 | stdev = self.stdev_spread * (inp.max() - inp.min()) 138 | 139 | total_gradients = 0 140 | 141 | for i in range(self.nsamples): 142 | noise = torch.randn_like(inp) * stdev 143 | 144 | noisy_inp = inp + noise 145 | noisy_inp.retain_grad() 146 | grad = self.base_explainer.explain(noisy_inp, ind) 147 | 148 | if self.magnitude: 149 | total_gradients += grad ** 2 150 | else: 151 | total_gradients += grad 152 | 153 | return total_gradients / self.nsamples -------------------------------------------------------------------------------- /tensorwatch/saliency/saliency.py: -------------------------------------------------------------------------------- 1 | from .gradcam import GradCAMExplainer 2 | from .backprop import VanillaGradExplainer, GradxInputExplainer, SaliencyExplainer, \ 3 | IntegrateGradExplainer, DeconvExplainer, GuidedBackpropExplainer, SmoothGradExplainer 4 | from .deeplift import DeepLIFTRescaleExplainer 5 | from .occlusion import OcclusionExplainer 6 | from .epsilon_lrp import EpsilonLrp 7 | from .lime_image_explainer import LimeImageExplainer, LimeImagenetExplainer 8 | import skimage.transform 9 | import torch 10 | import math 11 | from .. 
import image_utils 12 | 13 | class ImageSaliencyResult: 14 | def __init__(self, raw_image, saliency, title, saliency_alpha=0.4, saliency_cmap='jet'): 15 | self.raw_image, self.saliency, self.title = raw_image, saliency, title 16 | self.saliency_alpha, self.saliency_cmap = saliency_alpha, saliency_cmap 17 | 18 | def _get_explainer(explainer_name, model, layer_path=None): 19 | if explainer_name == 'gradcam': 20 | return GradCAMExplainer(model, target_layer_name_keys=layer_path, use_inp=True) 21 | if explainer_name == 'vanilla_grad': 22 | return VanillaGradExplainer(model) 23 | if explainer_name == 'grad_x_input': 24 | return GradxInputExplainer(model) 25 | if explainer_name == 'saliency': 26 | return SaliencyExplainer(model) 27 | if explainer_name == 'integrate_grad': 28 | return IntegrateGradExplainer(model) 29 | if explainer_name == 'deconv': 30 | return DeconvExplainer(model) 31 | if explainer_name == 'guided_backprop': 32 | return GuidedBackpropExplainer(model) 33 | if explainer_name == 'smooth_grad': 34 | return SmoothGradExplainer(model) 35 | if explainer_name == 'deeplift': 36 | return DeepLIFTRescaleExplainer(model) 37 | if explainer_name == 'occlusion': 38 | return OcclusionExplainer(model) 39 | if explainer_name == 'lrp': 40 | return EpsilonLrp(model) 41 | if explainer_name == 'lime_imagenet': 42 | return LimeImagenetExplainer(model) 43 | 44 | raise ValueError('Explainer {} is not recognized'.format(explainer_name)) 45 | 46 | def _get_layer_path(model): 47 | if model.__class__.__name__ == 'VGG': 48 | return ['features', '30'] # pool5 49 | elif model.__class__.__name__ == 'GoogleNet': 50 | return ['pool5'] 51 | elif model.__class__.__name__ == 'ResNet': 52 | return ['avgpool'] #layer4 53 | elif model.__class__.__name__ == 'Inception3': 54 | return ['Mixed_7c', 'branch_pool'] # ['conv2d_94'], 'mixed10' 55 | else: #TODO: guess layer for other networks? 
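        # architecture not recognized; get_saliency() will pass None through to
        # the explainer, and callers can supply layer_path explicitly instead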
56 | return None 57 | 58 | def get_saliency(model, raw_input, input, label, method='integrate_grad', layer_path=None): 59 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 60 | 61 | model.to(device) 62 | input = input.to(device) 63 | if label is not None: 64 | label = label.to(device) 65 | 66 | if input.grad is not None: 67 | input.grad.zero_() 68 | if label is not None and label.grad is not None: 69 | label.grad.zero_() 70 | model.eval() 71 | model.zero_grad() 72 | 73 | layer_path = layer_path or _get_layer_path(model) 74 | 75 | exp = _get_explainer(method, model, layer_path) 76 | saliency = exp.explain(input, label, raw_input) 77 | 78 | if saliency is not None: 79 | saliency = saliency.abs().sum(dim=1)[0].squeeze() 80 | saliency -= saliency.min() 81 | saliency /= (saliency.max() + 1e-20) 82 | 83 | return saliency.detach().cpu().numpy() 84 | else: 85 | return None 86 | 87 | def get_image_saliency_results(model, raw_image, input, label, 88 | methods=['lime_imagenet', 'gradcam', 'smooth_grad', 89 | 'guided_backprop', 'deeplift', 'grad_x_input'], 90 | layer_path=None): 91 | results = [] 92 | for method in methods: 93 | sal = get_saliency(model, raw_image, input, label, method=method) 94 | 95 | if sal is not None: 96 | results.append(ImageSaliencyResult(raw_image, sal, method)) 97 | return results 98 | 99 | def get_image_saliency_plot(image_saliency_results, cols:int=2, figsize:tuple=None): 100 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue 101 | 102 | rows = math.ceil(len(image_saliency_results) / cols) 103 | figsize=figsize or (8, 3 * rows) 104 | figure = plt.figure(figsize=figsize) 105 | 106 | for i, r in enumerate(image_saliency_results): 107 | ax = figure.add_subplot(rows, cols, i+1) 108 | ax.set_xticks([]) 109 | ax.set_yticks([]) 110 | ax.set_title(r.title, fontdict={'fontsize': 24}) #'fontweight': 'light' 111 | 112 | #upsampler = nn.Upsample(size=(raw_image.height, raw_image.width), mode='bilinear') 113 | saliency_upsampled = skimage.transform.resize(r.saliency, 114 | (r.raw_image.height, r.raw_image.width), 115 | mode='reflect') 116 | 117 | image_utils.show_image(r.raw_image, img2=saliency_upsampled, 118 | alpha2=r.saliency_alpha, cmap2=r.saliency_cmap, ax=ax) 119 | return figure 120 | -------------------------------------------------------------------------------- /tensorwatch/model_graph/hiddenlayer/pytorch_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | HiddenLayer 3 | 4 | PyTorch graph importer. 5 | 6 | Written by Waleed Abdulla 7 | Licensed under the MIT License 8 | """ 9 | 10 | from __future__ import absolute_import, division, print_function 11 | import re 12 | from .graph import Graph, Node 13 | from . import transforms as ht 14 | import torch 15 | from collections import abc 16 | import numpy as np 17 | 18 | # PyTorch Graph Transforms 19 | FRAMEWORK_TRANSFORMS = [ 20 | # Hide onnx: prefix 21 | ht.Rename(op=r"onnx::(.*)", to=r"\1"), 22 | # ONNX uses Gemm for linear layers (stands for General Matrix Multiplication). 23 | # It's an odd name that noone recognizes. Rename it. 
24 | ht.Rename(op=r"Gemm", to=r"Linear"), 25 | # PyTorch layers that don't have an ONNX counterpart 26 | ht.Rename(op=r"aten::max\_pool2d\_with\_indices", to="MaxPool"), 27 | # Shorten op name 28 | ht.Rename(op=r"BatchNormalization", to="BatchNorm"), 29 | ] 30 | 31 | 32 | def dump_pytorch_graph(graph): 33 | """List all the nodes in a PyTorch graph.""" 34 | f = "{:25} {:40} {} -> {}" 35 | print(f.format("kind", "scopeName", "inputs", "outputs")) 36 | for node in graph.nodes(): 37 | print(f.format(node.kind(), node.scopeName(), 38 | [i.unique() for i in node.inputs()], 39 | [i.unique() for i in node.outputs()] 40 | )) 41 | 42 | 43 | def pytorch_id(node): 44 | """Returns a unique ID for a node.""" 45 | # After ONNX simplification, the scopeName is not unique anymore 46 | # so append node outputs to guarantee uniqueness 47 | return node.scopeName() + "/outputs/" + "/".join([o.uniqueName() for o in node.outputs()]) 48 | 49 | 50 | def get_shape(torch_node): 51 | """Return the output shape of the given Pytorch node.""" 52 | # Extract node output shape from the node string representation 53 | # This is a hack because there doesn't seem to be an official way to do it. 54 | # See my question in the PyTorch forum: 55 | # https://discuss.pytorch.org/t/node-output-shape-from-trace-graph/24351/2 56 | # TODO: find a better way to extract output shape 57 | # TODO: Assuming the node has one output. Update if we encounter a multi-output node. 58 | m = re.match(r".*Float\(([\d\s\,]+)\).*", str(next(torch_node.outputs()))) 59 | if m: 60 | shape = m.group(1) 61 | shape = shape.split(",") 62 | shape = tuple(map(int, shape)) 63 | else: 64 | shape = None 65 | return shape 66 | 67 | def calc_rf(model, input_shape): 68 | for n, p in model.named_parameters(): 69 | if not p.requires_grad: 70 | continue 71 | if 'bias' in n: 72 | p.data.fill_(0) 73 | elif 'weight' in n: 74 | p.data.fill_(1) 75 | 76 | input = torch.ones(input_shape, requires_grad=True) 77 | output = model(input) 78 | out_shape = output.size() 79 | ndims = len(out_shape) 80 | grad = torch.zeros(out_shape) 81 | l_tmp=[] 82 | for i in range(ndims): 83 | if i==0 or i==1: # batch or channel dims get index 0 84 | l_tmp.append(0) 85 | else: # spatial dims: index the center element 86 | l_tmp.append(out_shape[i]//2) 87 | 88 | grad[tuple(l_tmp)] = 1 89 | output.backward(gradient=grad) 90 | grad_np = input.grad[0,0].data.numpy() 91 | idx_nonzeros = np.where(grad_np!=0) 92 | RF=[np.max(idx)-np.min(idx)+1 for idx in idx_nonzeros] 93 | 94 | return RF 95 | 96 | def import_graph(hl_graph, model, args, input_names=None, verbose=False): 97 | # TODO: add input names to graph 98 | 99 | if args is None: 100 | args = [1, 3, 224, 224] # assume ImageNet default 101 | 102 | # if args is not Tensor but is array like then convert it to torch tensor 103 | if not isinstance(args, torch.Tensor) and \ 104 | hasattr(args, "__len__") and hasattr(args, '__getitem__') and \ 105 | not isinstance(args, (str, abc.ByteString)): 106 | args = torch.ones(args) 107 | 108 | # Run the Pytorch graph to get a trace and generate a graph from it 109 | trace, out = torch.jit.get_trace_graph(model, args) 110 | torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX) 111 | torch_graph = trace.graph() 112 | 113 | # Dump list of nodes (DEBUG only) 114 | if verbose: 115 | dump_pytorch_graph(torch_graph) 116 | 117 | # Loop through nodes and build HL graph 118 | for torch_node in torch_graph.nodes(): 119 | # Op 120 | op = torch_node.kind() 121 | # Parameters 122 | params = {k: torch_node[k] for k in torch_node.attributeNames()} 123 | # Inputs/outputs 
124 | # TODO: inputs = [i.unique() for i in node.inputs()] 125 | outputs = [o.unique() for o in torch_node.outputs()] 126 | # Get output shape 127 | shape = get_shape(torch_node) 128 | # Add HL node 129 | hl_node = Node(uid=pytorch_id(torch_node), name=None, op=op, 130 | output_shape=shape, params=params) 131 | hl_graph.add_node(hl_node) 132 | # Add edges 133 | for target_torch_node in torch_graph.nodes(): 134 | target_inputs = [i.unique() for i in target_torch_node.inputs()] 135 | if set(outputs) & set(target_inputs): 136 | hl_graph.add_edge_by_id(pytorch_id(torch_node), pytorch_id(target_torch_node), shape) 137 | return hl_graph 138 | -------------------------------------------------------------------------------- /tensorwatch/model_graph/hiddenlayer/ge.py: -------------------------------------------------------------------------------- 1 | """ 2 | HiddenLayer 3 | 4 | Implementation graph expressions to find nodes in a graph based on a pattern. 5 | 6 | Written by Waleed Abdulla 7 | Licensed under the MIT License 8 | """ 9 | 10 | import re 11 | 12 | 13 | 14 | class GEParser(): 15 | def __init__(self, text): 16 | self.index = 0 17 | self.text = text 18 | 19 | def parse(self): 20 | return self.serial() or self.parallel() or self.expression() 21 | 22 | def parallel(self): 23 | index = self.index 24 | expressions = [] 25 | while len(expressions) == 0 or self.token("|"): 26 | e = self.expression() 27 | if not e: 28 | break 29 | expressions.append(e) 30 | if len(expressions) >= 2: 31 | return ParallelPattern(expressions) 32 | # No match. Reset index 33 | self.index = index 34 | 35 | def serial(self): 36 | index = self.index 37 | expressions = [] 38 | while len(expressions) == 0 or self.token(">"): 39 | e = self.expression() 40 | if not e: 41 | break 42 | expressions.append(e) 43 | 44 | if len(expressions) >= 2: 45 | return SerialPattern(expressions) 46 | self.index = index 47 | 48 | def expression(self): 49 | index = self.index 50 | 51 | if self.token("("): 52 | e = self.serial() or self.parallel() or self.op() 53 | if e and self.token(")"): 54 | return e 55 | self.index = index 56 | e = self.op() 57 | return e 58 | 59 | def op(self): 60 | t = self.re(r"\w+") 61 | if t: 62 | c = self.condition() 63 | return NodePattern(t, c) 64 | 65 | def condition(self): 66 | # TODO: not implemented yet. 
This function is a placeholder 67 | index = self.index 68 | if self.token("["): 69 | c = self.token("1x1") or self.token("3x3") 70 | if c: 71 | if self.token("]"): 72 | return c 73 | self.index = index 74 | 75 | def token(self, s): 76 | return self.re(r"\s*(" + re.escape(s) + r")\s*", 1) 77 | 78 | def string(self, s): 79 | if s == self.text[self.index:self.index+len(s)]: 80 | self.index += len(s) 81 | return s 82 | 83 | def re(self, regex, group=0): 84 | m = re.match(regex, self.text[self.index:]) 85 | if m: 86 | self.index += len(m.group(0)) 87 | return m.group(group) 88 | 89 | 90 | class NodePattern(): 91 | def __init__(self, op, condition=None): 92 | self.op = op 93 | self.condition = condition # TODO: not implemented yet 94 | 95 | def match(self, graph, node): 96 | if isinstance(node, list): 97 | return [], None 98 | if self.op == node.op: 99 | following = graph.outgoing(node) 100 | if len(following) == 1: 101 | following = following[0] 102 | return [node], following 103 | else: 104 | return [], None 105 | 106 | 107 | class SerialPattern(): 108 | def __init__(self, patterns): 109 | self.patterns = patterns 110 | 111 | def match(self, graph, node): 112 | all_matches = [] 113 | for i, p in enumerate(self.patterns): 114 | matches, following = p.match(graph, node) 115 | if not matches: 116 | return [], None 117 | all_matches.extend(matches) 118 | if i < len(self.patterns) - 1: 119 | node = following # Might be more than one node 120 | return all_matches, following 121 | 122 | 123 | class ParallelPattern(): 124 | def __init__(self, patterns): 125 | self.patterns = patterns 126 | 127 | def match(self, graph, nodes): 128 | if not nodes: 129 | return [], None 130 | nodes = nodes if isinstance(nodes, list) else [nodes] 131 | # If a single node, assume we need to match with its siblings 132 | if len(nodes) == 1: 133 | nodes = graph.siblings(nodes[0]) 134 | else: 135 | # Verify all nodes have the same parent or all have no parent 136 | parents = [graph.incoming(n) for n in nodes] 137 | matches = [set(p) == set(parents[0]) for p in parents[1:]] 138 | if not all(matches): 139 | return [], None 140 | 141 | # TODO: If more nodes than patterns, we should consider 142 | # all permutations of the nodes 143 | if len(self.patterns) != len(nodes): 144 | return [], None 145 | 146 | patterns = self.patterns.copy() 147 | nodes = nodes.copy() 148 | all_matches = [] 149 | end_node = None 150 | for p in patterns: 151 | found = False 152 | for n in nodes: 153 | matches, following = p.match(graph, n) 154 | if matches: 155 | found = True 156 | nodes.remove(n) 157 | all_matches.extend(matches) 158 | # Verify all branches end in the same node 159 | if end_node: 160 | if end_node != following: 161 | return [], None 162 | else: 163 | end_node = following 164 | break 165 | if not found: 166 | return [], None 167 | return all_matches, end_node 168 | 169 | 170 | -------------------------------------------------------------------------------- /tensorwatch/model_graph/hiddenlayer/tf_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | HiddenLayer 3 | 4 | TensorFlow graph importer. 5 | 6 | Written by Phil Ferriere. Edits by Waleed Abdulla. 7 | Licensed under the MIT License 8 | """ 9 | 10 | from __future__ import absolute_import, division, print_function, unicode_literals 11 | import logging 12 | import tensorflow as tf 13 | from .graph import Graph, Node 14 | from . import transforms as ht 15 | 16 | 17 | FRAMEWORK_TRANSFORMS = [ 18 | # Rename VariableV2 op to Variable. 
Same for anything V2, V3, etc. 19 | ht.Rename(op=r"(\w+)V\d", to=r"\1"), 20 | ht.Prune("Const"), 21 | ht.Prune("PlaceholderWithDefault"), 22 | ht.Prune("Variable"), 23 | ht.Prune("VarIsInitializedOp"), 24 | ht.Prune("VarHandleOp"), 25 | ht.Prune("ReadVariableOp"), 26 | ht.PruneBranch("Assign"), 27 | ht.PruneBranch("AssignSub"), 28 | ht.PruneBranch("AssignAdd"), 29 | ht.PruneBranch("AssignVariableOp"), 30 | ht.Prune("ApplyMomentum"), 31 | ht.Prune("ApplyAdam"), 32 | ht.FoldId(r"^(gradients)/.*", "NoOp"), # Fold to NoOp then delete in the next step 33 | ht.Prune("NoOp"), 34 | ht.Rename(op=r"DepthwiseConv2dNative", to="SeparableConv"), 35 | ht.Rename(op=r"Conv2D", to="Conv"), 36 | ht.Rename(op=r"FusedBatchNorm", to="BatchNorm"), 37 | ht.Rename(op=r"MatMul", to="Linear"), 38 | ht.Fold("Conv > BiasAdd", "__first__"), 39 | ht.Fold("Linear > BiasAdd", "__first__"), 40 | ht.Fold("Shape > StridedSlice > Pack > Reshape", "__last__"), 41 | ht.FoldId(r"(.+)/dropout/.*", "Dropout"), 42 | ht.FoldId(r"(softmax_cross\_entropy)\_with\_logits.*", "SoftmaxCrossEntropy"), 43 | ] 44 | 45 | 46 | def dump_tf_graph(tfgraph, tfgraphdef): 47 | """List all the nodes in a TF graph. 48 | tfgraph: A TF Graph object. 49 | tfgraphdef: A TF GraphDef object. 50 | """ 51 | print("Nodes ({})".format(len(tfgraphdef.node))) 52 | f = "{:15} {:59} {:20} {}" 53 | print(f.format("kind", "scopeName", "shape", "inputs")) 54 | for node in tfgraphdef.node: 55 | scopename = node.name 56 | kind = node.op 57 | inputs = node.input 58 | shape = tf.graph_util.tensor_shape_from_node_def_name(tfgraph, scopename) 59 | print(f.format(kind, scopename, str(shape), inputs)) 60 | 61 |
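# Editor's note: the following is an illustrative sketch, not part of the original source.
# The string patterns passed to ht.Fold above (e.g. "Conv > BiasAdd") appear to use the
# graph-expression syntax implemented by GEParser in ge.py; parsing and applying one by
# hand might look like this (hl_graph and node are assumed to exist):
#
#   from .ge import GEParser
#   pattern = GEParser("Conv > BiasAdd").parse()        # SerialPattern of two NodePatterns
#   matches, following = pattern.match(hl_graph, node)  # the nodes a Fold would collapse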
62 | def import_graph(hl_graph, tf_graph, output=None, verbose=False): 63 | """Convert a TF graph to a directed graph. 64 | tf_graph: A TF Graph object. 65 | output: Name of the output node (string). 66 | verbose: Set to True for debug print output 67 | """ 68 | # Get clean(er) list of nodes 69 | graph_def = tf_graph.as_graph_def(add_shapes=True) 70 | graph_def = tf.graph_util.remove_training_nodes(graph_def) 71 | 72 | # Dump list of TF nodes (DEBUG only) 73 | if verbose: 74 | dump_tf_graph(tf_graph, graph_def) 75 | 76 | # Loop through nodes and build the matching directed graph 77 | for tf_node in graph_def.node: 78 | # Read node details 79 | try: 80 | op, uid, name, shape, params = import_node(tf_node, tf_graph, verbose) 81 | except Exception: 82 | if verbose: 83 | logging.exception("Failed to read node {}".format(tf_node)) 84 | continue 85 | 86 | # Add node 87 | hl_node = Node(uid=uid, name=name, op=op, output_shape=shape, params=params) 88 | hl_graph.add_node(hl_node) 89 | 90 | # Add edges 91 | for target_node in graph_def.node: 92 | target_inputs = target_node.input 93 | if uid in target_inputs: 94 | hl_graph.add_edge_by_id(uid, target_node.name, shape) 95 | return hl_graph 96 | 97 | 98 | def import_node(tf_node, tf_graph, verbose=False): 99 | # Operation type and name 100 | op = tf_node.op 101 | uid = tf_node.name 102 | name = None 103 | 104 | # Shape 105 | shape = None 106 | if tf_node.op != "NoOp": 107 | try: 108 | shape = tf.graph_util.tensor_shape_from_node_def_name(tf_graph, tf_node.name) 109 | # If the shape is known, convert it to a list 110 | if shape.ndims is not None: 111 | shape = shape.as_list() 112 | except Exception: 113 | if verbose: 114 | logging.exception("Error reading shape of {}".format(tf_node.name)) 115 | 116 | # Parameters 117 | # At this stage, we really only care about two parameters: 118 | # 1/ the kernel size used by convolution layers 119 | # 2/ the stride used by convolutional and pooling layers (TODO: not fully working yet) 120 | 121 | # 1/ The kernel size is actually not stored in the convolution tensor but in its weight input. 122 | # The weights input has the shape [shape=[kernel, kernel, in_channels, filters]] 123 | # So we must fish for it 124 | params = {} 125 | if op == "Conv2D" or op == "DepthwiseConv2dNative": 126 | kernel_shape = tf.graph_util.tensor_shape_from_node_def_name(tf_graph, tf_node.input[1]) 127 | kernel_shape = [int(a) for a in kernel_shape] 128 | params["kernel_shape"] = kernel_shape[0:2] 129 | if 'strides' in tf_node.attr.keys(): 130 | strides = [int(a) for a in tf_node.attr['strides'].list.i] 131 | params["stride"] = strides[1:3] 132 | elif op == "MaxPool" or op == "AvgPool": 133 | # 2/ the stride used by pooling layers 134 | # See https://stackoverflow.com/questions/44124942/how-to-access-values-in-protos-in-tensorflow 135 | if 'ksize' in tf_node.attr.keys(): 136 | kernel_shape = [int(a) for a in tf_node.attr['ksize'].list.i] 137 | params["kernel_shape"] = kernel_shape[1:3] 138 | if 'strides' in tf_node.attr.keys(): 139 | strides = [int(a) for a in tf_node.attr['strides'].list.i] 140 | params["stride"] = strides[1:3] 141 | 142 | return op, uid, name, shape, params 143 | -------------------------------------------------------------------------------- /tensorwatch/mpl/bar_plot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .base_mpl_plot import BaseMplPlot 5 | from .. import utils 6 | from .. import image_utils
7 | import numpy as np 8 | from itertools import groupby 9 | import operator 10 | 11 | class BarPlot(BaseMplPlot): 12 | def init_stream_plot(self, stream_vis, 13 | xtitle='', ytitle='', ztitle='', colormap=None, color=None, 14 | edge_color=None, linewidth=None, align=None, bar_width=None, 15 | opacity=None, **stream_vis_args): 16 | 17 | # add main subplot 18 | stream_vis.align, stream_vis.linewidth = align, (linewidth or 2) 19 | stream_vis.bar_width = bar_width or 1 20 | stream_vis.ax = self.get_main_axis() 21 | stream_vis.series = {} 22 | stream_vis.bars_artists = [] # stores previously drawn bars 23 | 24 | stream_vis.cmap = image_utils.get_cmap(colormap or 'Set3') 25 | if color is None: 26 | if not self.is_3d: 27 | color = stream_vis.cmap((len(self._stream_vises)%stream_vis.cmap.N)/stream_vis.cmap.N) # pick the next palette color; pylint: disable=no-member 28 | stream_vis.color = color 29 | stream_vis.edge_color = 'black' 30 | stream_vis.opacity = opacity 31 | stream_vis.ax.set_xlabel(xtitle) 32 | stream_vis.ax.xaxis.label.set_style('italic') 33 | stream_vis.ax.set_ylabel(ytitle) 34 | stream_vis.ax.yaxis.label.set_color(color) 35 | stream_vis.ax.yaxis.label.set_style('italic') 36 | if self.is_3d: 37 | stream_vis.ax.set_zlabel(ztitle) 38 | stream_vis.ax.zaxis.label.set_style('italic') 39 | 40 | def is_show_grid(self): #override 41 | return False 42 | 43 | def clear_artists(self, stream_vis): 44 | for bar in stream_vis.bars_artists: 45 | bar.remove() 46 | stream_vis.bars_artists.clear() 47 | 48 | def clear_plot(self, stream_vis, clear_history): 49 | stream_vis.series.clear() 50 | self.clear_artists(stream_vis) 51 | 52 | def _val2tuple(val, x)->tuple: 53 | """Accept scalar val, (y,), (x, y), (label, y), (x, y, label), (x, y, z) or (x, y, z, color); return normalized 5-tuple (x, y, z, color, label) 54 | """ 55 | if utils.is_array_like(val): 56 | unpacker = lambda a0=None,a1=None,a2=None,a3=None,*_:(a0,a1,a2,a3) 57 | a0, a1, a2, a3 = unpacker(*val) 58 | if len(val) == 1: # (y,) 59 | t = (x, a0, 0, None, None) 60 | elif len(val) == 2: # (label, y) or (x, y) 61 | t = (x, a1, 0, None, a0) if isinstance(a0, str) else (a0, a1, 0, None, None) 62 | elif len(val) == 3: # either (x, y, label) or (x, y, z) 63 | t = (a0, a1, 0, None, a2) if isinstance(a2, str) else (a0, a1, a2, None, None) 64 | else: # assume (x, y, z, color, ...) 65 | t = (a0, a1, a2, a3, None) 66 | # note: we branch on len(val), not len(t), as the unpacker always yields 4 slots 67 | else: # scalar 68 | t = (x, val, 0, None, None) 69 | 70 | return t 71 | 72 |
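# Editor's note: an illustrative sketch of how the corrected _val2tuple above normalizes
# values (not part of the original source):
#
#   BarPlot._val2tuple(3.0, 5)                  # -> (5, 3.0, 0, None, None)
#   BarPlot._val2tuple(('apples', 7), 0)        # -> (0, 7, 0, None, 'apples')
#   BarPlot._val2tuple((2, 7, 'apples'), None)  # -> (2, 7, 0, None, 'apples')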
73 | def _show_stream_items(self, stream_vis, stream_items): 74 | """Paint the given stream_items into the visualizer. If the visualizer is dirty then return False else True. 75 | """ 76 | 77 | vals = self._extract_vals(stream_items) 78 | if not len(vals): 79 | return True 80 | 81 | if not self.is_3d: 82 | existing_len = len(stream_vis.series) 83 | for i,val in enumerate(vals): 84 | t = BarPlot._val2tuple(val, i+existing_len) 85 | stream_vis.series[t[0]] = t # merge x with previous items 86 | x, y, labels = [t[0] for t in stream_vis.series.values()], \ 87 | [t[1] for t in stream_vis.series.values()], \ 88 | [t[4] for t in stream_vis.series.values()] 89 | 90 | self.clear_artists(stream_vis) # remove previous bars 91 | bar_container = stream_vis.ax.bar(x, y, 92 | width=stream_vis.bar_width, 93 | tick_label = labels if any(l is not None for l in labels) else None, 94 | color=stream_vis.color, edgecolor=stream_vis.edge_color, 95 | alpha=stream_vis.opacity, linewidth=stream_vis.linewidth) 96 | stream_vis.bars_artists = list(bar_container.patches) 97 | else: 98 | for i,val in enumerate(vals): 99 | t = BarPlot._val2tuple(val, None) # we should not use the i parameter as 3d expects x,y,z 100 | z = t[2] 101 | if z not in stream_vis.series: 102 | stream_vis.series[z] = [] 103 | stream_vis.series[t[2]] += [t] # merge z with previous items 104 | 105 | # sort by z so we have consistent colors 106 | ts = sorted(stream_vis.series.items(), key=lambda g: g[0]) 107 | 108 | self.clear_artists(stream_vis) # remove previous bars once, before drawing all z groups 109 | for zi, (z, tg) in enumerate(ts): 110 | x, y, labels = [t[0] for t in tg], \ 111 | [t[1] for t in tg], \ 112 | [t[4] for t in tg] 113 | colors = stream_vis.color or stream_vis.cmap.colors 114 | color = colors[zi % len(colors)] 115 | 116 | bar_container = stream_vis.ax.bar(x, y, zs=z, zdir='y', 117 | width=stream_vis.bar_width, 118 | tick_label = labels if any(l is not None for l in labels) else None, 119 | color=color, 120 | edgecolor=stream_vis.edge_color, 121 | alpha=stream_vis.opacity or 0.8, linewidth=stream_vis.linewidth) 122 | # accumulate patches from every z group so the next update can clear them all 123 | stream_vis.bars_artists += list(bar_container.patches) 124 | 125 | #stream_vis.ax.relim() 126 | #stream_vis.ax.autoscale_view() 127 | 128 | return False 129 | -------------------------------------------------------------------------------- /test/test.pyproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | Debug 4 | 2.0 5 | 9a7fe67e-93f0-42b5-b58f-77320fc639e4 6 | 7 | 8 | post_train\saliency.py 9 | 10 | 11 | . 12 | .
13 | test 14 | test 15 | Global|ContinuumAnalytics|Anaconda36-64 16 | False 17 | 18 | 19 | true 20 | false 21 | 22 | 23 | true 24 | false 25 | 26 | 27 | 28 | 29 | 30 | Code 31 | 32 | 33 | Code 34 | 35 | 36 | Code 37 | 38 | 39 | 40 | Code 41 | 42 | 43 | Code 44 | 45 | 46 | Code 47 | 48 | 49 | Code 50 | 51 | 52 | Code 53 | 54 | 55 | Code 56 | 57 | 58 | Code 59 | 60 | 61 | Code 62 | 63 | 64 | Code 65 | 66 | 67 | Code 68 | 69 | 70 | Code 71 | 72 | 73 | Code 74 | 75 | 76 | Code 77 | 78 | 79 | Code 80 | 81 | 82 | Code 83 | 84 | 85 | Code 86 | 87 | 88 | Code 89 | 90 | 91 | Code 92 | 93 | 94 | Code 95 | 96 | 97 | Code 98 | 99 | 100 | Code 101 | 102 | 103 | Code 104 | 105 | 106 | Code 107 | 108 | 109 | Code 110 | 111 | 112 | Code 113 | 114 | 115 | Code 116 | 117 | 118 | Code 119 | 120 | 121 | 122 | Code 123 | 124 | 125 | Code 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | tensorwatch 134 | {cc8abc7f-ede1-4e13-b6b7-0041a5ec66a7} 135 | True 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 153 | 154 | 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /NOTICE.md: -------------------------------------------------------------------------------- 1 | NOTICES AND INFORMATION 2 | Do Not Translate or Localize 3 | 4 | This software incorporates material from third parties. Microsoft makes certain 5 | open source code available at http://3rdpartysource.microsoft.com, or you may 6 | send a check or money order for US $5.00, including the product name, the open 7 | source component name, and version number, to: 8 | 9 | Source Code Compliance Team 10 | Microsoft Corporation 11 | One Microsoft Way 12 | Redmond, WA 98052 13 | USA 14 | 15 | Notwithstanding any other terms, you may reverse engineer this software to the 16 | extent required to debug changes to any libraries licensed under the GNU Lesser 17 | General Public License. 18 | 19 | 20 | **Component.** https://github.com/Swall0w/torchstat 21 | 22 | **Open Source License/Copyright Notice.** 23 | MIT License 24 | 25 | Copyright (c) 2018 Swall0w - Alan 26 | 27 | Permission is hereby granted, free of charge, to any person obtaining a copy 28 | of this software and associated documentation files (the "Software"), to deal 29 | in the Software without restriction, including without limitation the rights 30 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 31 | copies of the Software, and to permit persons to whom the Software is 32 | furnished to do so, subject to the following conditions: 33 | 34 | The above copyright notice and this permission notice shall be included in all 35 | copies or substantial portions of the Software. 36 | 37 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 38 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 39 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 40 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 41 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 42 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 43 | SOFTWARE. 
44 | 45 | 46 | **Component.** https://github.com/waleedka/hiddenlayer 47 | 48 | **Open Source License/Copyright Notice.** 49 | MIT License 50 | 51 | Copyright (c) 2018 Waleed Abdulla, Phil Ferriere 52 | 53 | Permission is hereby granted, free of charge, to any person obtaining a copy 54 | of this software and associated documentation files (the "Software"), to deal 55 | in the Software without restriction, including without limitation the rights 56 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 57 | copies of the Software, and to permit persons to whom the Software is 58 | furnished to do so, subject to the following conditions: 59 | 60 | The above copyright notice and this permission notice shall be included in all 61 | copies or substantial portions of the Software. 62 | 63 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 64 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 65 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 66 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 67 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 68 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 69 | SOFTWARE. 70 | 71 | 72 | **Component.** https://github.com/yulongwang12/visual-attribution 73 | 74 | **Open Source License/Copyright Notice.** 75 | BSD 2-Clause License 76 | 77 | Copyright (c) 2019, Yulong Wang 78 | All rights reserved. 79 | 80 | Redistribution and use in source and binary forms, with or without 81 | modification, are permitted provided that the following conditions are met: 82 | 83 | * Redistributions of source code must retain the above copyright notice, this 84 | list of conditions and the following disclaimer. 85 | 86 | * Redistributions in binary form must reproduce the above copyright notice, 87 | this list of conditions and the following disclaimer in the documentation 88 | and/or other materials provided with the distribution. 89 | 90 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 91 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 92 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 93 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 94 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 95 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 96 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 97 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 98 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 99 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 100 | 101 | 102 | **Component.** https://github.com/marcotcr/lime 103 | 104 | **Open Source License/Copyright Notice.** 105 | Copyright (c) 2016, Marco Tulio Correia Ribeiro 106 | All rights reserved. 107 | 108 | Redistribution and use in source and binary forms, with or without 109 | modification, are permitted provided that the following conditions are met: 110 | 111 | * Redistributions of source code must retain the above copyright notice, this 112 | list of conditions and the following disclaimer. 
113 | 114 | * Redistributions in binary form must reproduce the above copyright notice, 115 | this list of conditions and the following disclaimer in the documentation 116 | and/or other materials provided with the distribution. 117 | 118 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 119 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 120 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 121 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 122 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 123 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 124 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 125 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 126 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 127 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /notebooks/data_exploration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Using TensorWatch for Data Exploration\n", 8 | "\n", 9 | "In this tutorial, we will show how to quickly use TensorWatch for exploring data using the dimensionality reduction technique called [t-sne](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding). We plan to implement many other techniques in the future (please feel free to [contribute](https://github.com/microsoft/tensorwatch/blob/master/CONTRIBUTING.md)!)." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Installing regim\n", 17 | "This tutorial uses a small Python package called `regim`. It has a few utility classes for quickly working with PyTorch datasets. \n", 18 | "\n", 19 | "To install regim, clone the repo from GitHub and do a local package install:\n", 20 | "\n", 21 | "```\n", 22 | "git clone https://github.com/sytelus/regim.git\n", 23 | "cd regim\n", 24 | "pip install -e .\n", 25 | "```\n", 26 | "\n", 27 | "Now that we are done with that, let's do our imports:" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 1, 33 | "metadata": { 34 | "scrolled": true 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "import tensorwatch as tw\n", 39 | "from regim import DataUtils" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "First we will get the MNIST dataset. The `regim` package has a DataUtils class that lets us get the entire MNIST dataset without a train/test split, reshaping each image as a vector of 784 integers instead of a 28x28 matrix." 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "metadata": { 53 | "scrolled": true 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "ds = DataUtils.mnist_datasets(linearize=True, train_test=False)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "The MNIST dataset is big enough that t-sne can take a long time on slow computers, so we will apply t-sne to a sample of the dataset. The `regim` package has a utility method that lets us take `k` random samples for each class. We also set `as_np=True` to convert images from PyTorch tensors to numpy arrays.
The `no_test=True` parameter tells the method not to split our data into train and test sets. The return value is a tuple of two numpy arrays, one containing the input images and the other the labels." 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "inputs, labels = DataUtils.sample_by_class(ds, k=50, shuffle=True, as_np=True, no_test=True)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "We are now ready to supply this dataset to TensorWatch, and in just one line we can get the lower-dimensional components. The `get_tsne_components` method takes a tuple of inputs and labels. The optional parameters `features_col=0` and `labels_col=1` tell which members of the tuple are the input features and the truth labels. Another optional parameter, `n_components=3`, says that we should generate 3 components for each data point." 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 4, 86 | "metadata": { 87 | "scrolled": true 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | "components = tw.get_tsne_components((inputs, labels))" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "Now that we have a 3D component for each data point in our dataset, let's plot it! For this purpose, we use the `ArrayStream` class from TensorWatch, which allows you to convert any iterable into a TensorWatch stream. We then supply this stream to the Visualizer class, asking it to use the `tsne` visualization type, which is just a fancy 3D scatter plot." 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 5, 104 | "metadata": { 105 | "scrolled": true 106 | }, 107 | "outputs": [ 108 | { 109 | "data": { 110 | "application/vnd.jupyter.widget-view+json": { 111 | "model_id": "f86416164bb04e729ea16cb71a666a70", 112 | "version_major": 2, 113 | "version_minor": 0 114 | }, 115 | "text/plain": [ 116 | "HBox(children=(FigureWidget({\n", 117 | " 'data': [{'hoverinfo': 'text',\n", 118 | " 'line': {'color': 'rgb(31, 119,…" 119 | ] 120 | }, 121 | "metadata": {}, 122 | "output_type": "display_data" 123 | } 124 | ], 125 | "source": [ 126 | "comp_stream = tw.ArrayStream(components)\n", 127 | "vis = tw.Visualizer(comp_stream, vis_type='tsne', \n", 128 | " hover_images=inputs, hover_image_reshape=(28,28))\n", 129 | "vis.show()" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "Notice that as you hover the mouse over each point on the graph, you will see the MNIST image associated with that point. How does that work? We are supplying the input images to the `hover_images` parameter. As our images are 1-dimensional arrays of 784 values, we also set the `hover_image_reshape` parameter to reshape them to 28x28. You can customize the hover behavior by attaching a new function to the vis.[hover_fn](https://github.com/microsoft/tensorwatch/blob/master/tensorwatch/plotly/embeddings_plot.py#L28) member." 137 | ] 138 | },
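{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "As a quick aside (an editor's illustrative sketch, not part of the original tutorial): since the rows of `inputs` are flattened 784-value vectors, you can recover a single image yourself, which is essentially what `hover_image_reshape=(28,28)` asks the visualizer to do:\n",
  "\n",
  "```python\n",
  "img = inputs[0].reshape(28, 28)  # 784-vector back to a 28x28 image\n",
  "```"
 ]
},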
139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "## Questions?\n", 144 | "\n", 145 | "Please file a [Github issue](https://github.com/microsoft/tensorwatch/issues/new) and let us know if we can improve this tutorial." 146 | ] 147 | } 148 | ], 149 | "metadata": { 150 | "kernelspec": { 151 | "display_name": "Python 3", 152 | "language": "python", 153 | "name": "python3" 154 | }, 155 | "language_info": { 156 | "codemirror_mode": { 157 | "name": "ipython", 158 | "version": 3 159 | }, 160 | "file_extension": ".py", 161 | "mimetype": "text/x-python", 162 | "name": "python", 163 | "nbconvert_exporter": "python", 164 | "pygments_lexer": "ipython3", 165 | "version": "3.6.5" 166 | } 167 | }, 168 | "nbformat": 4, 169 | "nbformat_minor": 2 170 | } 171 | -------------------------------------------------------------------------------- /tensorwatch/mpl/base_mpl_plot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | #from IPython import get_ipython, display 5 | #if get_ipython(): 6 | # get_ipython().magic('matplotlib notebook') 7 | 8 | #import matplotlib 9 | #if os.name == 'posix' and "DISPLAY" not in os.environ: 10 | # matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab! 11 | 12 | #from ipywidgets.widgets.interaction import show_inline_matplotlib_plots 13 | #from ipykernel.pylab.backend_inline import flush_figures 14 | 15 | from ..vis_base import VisBase 16 | 17 | import sys, logging 18 | from abc import abstractmethod 19 | from .. import utils 20 | 21 | 22 | class BaseMplPlot(VisBase): 23 | def __init__(self, cell:VisBase.widgets.Box=None, title:str=None, show_legend:bool=None, is_3d:bool=False, 24 | stream_name:str=None, console_debug:bool=False, **vis_args): 25 | super(BaseMplPlot, self).__init__(VisBase.widgets.Output(), cell, title, show_legend, 26 | stream_name=stream_name, console_debug=console_debug, **vis_args) 27 | 28 | self._fig_init_done = False 29 | self.show_legend = show_legend 30 | self.is_3d = is_3d 31 | if is_3d: 32 | # importing Axes3D registers the '3d' projection with matplotlib 33 | from mpl_toolkits.mplot3d import Axes3D 34 | # graph objects 35 | self.figure = None 36 | self._ax_main = None 37 | # matplotlib animation 38 | self.animation = None 39 | self.anim_interval = None 40 | #print(matplotlib.get_backend()) 41 | #display.display(self.cell) 42 | 43 | # anim_interval in seconds 44 | def init_fig(self, anim_interval:float=1.0): 45 | """(for derived class) Initializes the matplotlib figure""" 46 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue 47 | 48 | if self._fig_init_done: 49 | return False 50 | 51 | # create figure and animation 52 | self.figure = plt.figure(figsize=(8, 3)) 53 | self.anim_interval = anim_interval 54 | 55 | # default color palette 56 | 57 | 58 | plt.set_cmap('Dark2') 59 | plt.rcParams['image.cmap']='Dark2' 60 | 61 | self._fig_init_done = True 62 | return True 63 | 64 | def get_main_axis(self): 65 | # if we don't yet have main axis, create one 66 | if not self._ax_main: 67 | # by default assign one subplot to whole graph 68 | self._ax_main = self.figure.add_subplot(111, 69 | projection=None if not self.is_3d else '3d') 70 | self._ax_main.grid(self.is_show_grid()) 71 | # change the color of the top and right spines to opaque gray 72 | self._ax_main.spines['right'].set_color((.8,.8,.8)) 73 | self._ax_main.spines['top'].set_color((.8,.8,.8)) 74 | if self.title is not None: 75 | title = self._ax_main.set_title(self.title) 76 | title.set_weight('bold') 77 | return self._ax_main 78 | 79 | # overridable 80 | def is_show_grid(self): 81 | return True 82 |
83 | def _on_update(self, frame): # pylint: disable=unused-argument 84 | try: 85 | self._update_stream_plots() 86 | except Exception as ex: 87 | # when an exception occurs here, the animation will stop and there 88 | # will be no further plot updates 89 | # TODO: maybe we don't need all of the below, but none of them 90 | # pop up the exception in Jupyter Notebook because these 91 | # exceptions occur in the background? 92 | self.last_ex = ex 93 | logging.exception('Exception in matplotlib update loop') 94 | 95 | 96 | def show(self, blocking=False): 97 | if not self.is_shown and self.anim_interval: 98 | from matplotlib.animation import FuncAnimation # function-level import as this one is expensive 99 | self.animation = FuncAnimation(self.figure, self._on_update, interval=self.anim_interval*1000.0) 100 | super(BaseMplPlot, self).show(blocking) 101 | 102 | def _post_update_stream_plot(self, stream_vis): 103 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue 104 | 105 | utils.debug_log("Plot updated", stream_vis.stream.stream_name, verbosity=5) 106 | 107 | if self.layout_dirty: 108 | # do not call tight_layout() on every update -- 109 | # that would jumble up the graphs! It should only be called 110 | # once each time there is a change in layout 111 | self.figure.tight_layout() 112 | self.layout_dirty = False 113 | 114 | # the below forces a redraw and it was helpful to 115 | # repaint even if there was an error in the interval loop, 116 | # but it only works in native UX, not in Jupyter Notebook 117 | #self.figure.canvas.draw() 118 | #self.figure.canvas.flush_events() 119 | 120 | if self._use_hbox and VisBase.get_ipython(): 121 | self.widget.clear_output(wait=True) 122 | with self.widget: 123 | plt.show(self.figure) 124 | 125 | # everything else that doesn't work 126 | #self.figure.show() 127 | #display.clear_output(wait=True) 128 | #display.display(self.figure) 129 | #flush_figures() 130 | #plt.show() 131 | #show_inline_matplotlib_plots() 132 | #elif not get_ipython(): 133 | # self.figure.canvas.draw() 134 | 135 | def _post_add_subscription(self, stream_vis, **stream_vis_args): 136 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue 137 | 138 | # make sure figure is initialized 139 | self.init_fig() 140 | self.init_stream_plot(stream_vis, **stream_vis_args) 141 | 142 | # redo the legend 143 | #self.figure.legend(loc='center right', bbox_to_anchor=(1.5, 0.5)) 144 | if self.show_legend: 145 | self.figure.legend(loc='lower right') 146 | plt.subplots_adjust(hspace=0.6) 147 | 148 | def _show_widget_native(self, blocking:bool): 149 | import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue 150 | 151 | #plt.ion() 152 | #plt.show() 153 | return plt.show(block=blocking) 154 | 155 | def _show_widget_notebook(self): 156 | # no need to return anything because %matplotlib notebook will 157 | # detect spawning of figure and paint it 158 | # if self.figure is returned then you will see two of them 159 | return None 160 | #plt.show() 161 | #return self.figure 162 | 163 | def _can_update_stream_plots(self): 164 | return False # we run an interval timer which will do the flushing 165 | 166 | @abstractmethod 167 | def init_stream_plot(self, stream_vis, **stream_vis_args): 168 | """(for derived class) Create new plot info for this stream""" 169 | pass 170 | -------------------------------------------------------------------------------- /tensorwatch/lv_types.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from typing import List, Callable, Any, Sequence, Hashable 5 | from . import utils 6 | import uuid 7 | 8 | 9 | class EventData: 10 | def __init__(self, globals_val, **vars_val): 11 | if globals_val is not None: 12 | for key in globals_val: 13 | setattr(self, key, globals_val[key]) 14 | for key in vars_val: 15 | setattr(self, key, vars_val[key]) 16 | 17 | def __str__(self): 18 | sb = [] 19 | for key in self.__dict__: 20 | val = self.__dict__[key] 21 | if utils.is_scalar(val): 22 | sb.append('{key}={value}'.format(key=key, value=val)) 23 | else: 24 | sb.append('{key}="{value}"'.format(key=key, value=val)) 25 | 26 | return ', '.join(sb) 27 | 28 | EventsVars = List[EventData] 29 | 30 | class StreamItem: 31 | def __init__(self, value:Any, stream_name:str=None, item_index:int=None, 32 | ended:bool=False, exception:Exception=None, stream_reset:bool=False): 33 | self.value = value 34 | self.exception = exception 35 | self.stream_name = stream_name 36 | self.item_index = item_index 37 | self.ended = ended 38 | self.stream_reset = stream_reset 39 | 40 | def __repr__(self): 41 | return str(self.__dict__) 42 | 43 | EventEvalFunc = Callable[[EventsVars], StreamItem] 44 | 45 | 46 | class VisArgs: 47 | """Provides a container for visualizer parameters 48 | 49 | These are the same parameters as the Visualizer constructor. 50 | NOTE: If you modify arguments here then also sync the Visualizer constructor. 51 | """ 52 | def __init__(self, vis_type:str=None, host:'Visualizer'=None, 53 | cell:'Visualizer'=None, title:str=None, 54 | clear_after_end=False, clear_after_each=False, history_len=1, dim_history=True, opacity=None, 55 | 56 | rows=2, cols=5, img_width=None, img_height=None, img_channels=None, 57 | colormap=None, viz_img_scale=None, 58 | 59 | # these image params are for hover on point for t-sne 60 | hover_images=None, hover_image_reshape=None, cell_width:str=None, cell_height:str=None, 61 | 62 | only_summary=False, separate_yaxis=True, xtitle=None, ytitle=None, ztitle=None, color=None, 63 | xrange=None, yrange=None, zrange=None, draw_line=True, draw_marker=False, 64 | 65 | # histogram 66 | bins=None, normed=None, histtype='bar', edge_color=None, linewidth=None, bar_width=None, 67 | 68 | # pie chart 69 | autopct=None, shadow=None, 70 | 71 | vis_args:dict=None, stream_vis_args:dict=None)->None: 72 | 73 | self.vis_type, self.host = vis_type, host 74 | self.cell, self.title = cell, title 75 | self.clear_after_end, self.clear_after_each, self.history_len, self.dim_history, self.opacity = \ 76 | clear_after_end, clear_after_each, history_len, dim_history, opacity 77 | self.rows, self.cols, self.img_width, self.img_height, self.img_channels = \ 78 | rows, cols, img_width, img_height, img_channels 79 | self.colormap, self.viz_img_scale = colormap, viz_img_scale 80 | 81 | # these image params are for hover on point for t-sne 82 | self.hover_images, self.hover_image_reshape, self.cell_width, self.cell_height = \ 83 | hover_images, hover_image_reshape, cell_width, cell_height 84 | 85 | self.only_summary, self.separate_yaxis, self.xtitle, self.ytitle, self.ztitle, self.color = \ 86 | only_summary, separate_yaxis, xtitle, ytitle, ztitle, color 87 | self.xrange, self.yrange, self.zrange, self.draw_line, self.draw_marker = \ 88 | xrange, yrange, zrange, draw_line, draw_marker 89 | 90 | # histogram 91 | self.bins, self.normed, self.histtype, self.edge_color, self.linewidth, self.bar_width = \ 92 | bins, normed, histtype, edge_color, linewidth, bar_width 93 | # pie chart 94 | self.autopct, self.shadow = autopct, shadow 95 | 96 | self.vis_args, self.stream_vis_args = vis_args, stream_vis_args 97 | 98 |
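# Editor's note: an illustrative sketch, not part of the original source. A typical
# VisArgs instance for a simple chart might look like the following ('line' is assumed
# to be a supported vis_type; the parameters mirror the Visualizer constructor as the
# docstring above notes):
#
#   args = VisArgs(vis_type='line', title='Training loss',
#                  xtitle='epoch', ytitle='loss', history_len=10)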
99 | class StreamCreateRequest: 100 | def __init__(self, stream_name:str, devices:Sequence[str]=None, event_name:str='', 101 | expr:str=None, throttle:float=None, vis_args:VisArgs=None): 102 | self.event_name = event_name 103 | self.expr = expr 104 | self.stream_name = stream_name or str(uuid.uuid4()) 105 | self.devices = devices 106 | self.vis_args = vis_args 107 | 108 | # max throughput on a Lenovo P50 laptop for MNIST 109 | # text console -> 0.1s 110 | # matplotlib line graph -> 0.5s 111 | self.throttle = throttle 112 | 113 | class ClientServerRequest: 114 | def __init__(self, req_type:str, req_data:Any): 115 | self.req_type = req_type 116 | self.req_data = req_data 117 | 118 | class CliSrvReqTypes: 119 | create_stream = 'CreateStream' 120 | del_stream = 'DeleteStream' 121 | 122 | class StreamVisInfo: 123 | def __init__(self, stream, title, clear_after_end, 124 | clear_after_each, history_len, dim_history, opacity, 125 | index, stream_vis_args, last_update): 126 | self.stream = stream 127 | self.title, self.opacity = title, opacity 128 | self.clear_after_end, self.clear_after_each = clear_after_end, clear_after_each 129 | self.history_len, self.dim_history = history_len, dim_history 130 | self.index, self.stream_vis_args, self.last_update = index, stream_vis_args, last_update 131 | 132 | class ImageData: 133 | # images are numpy arrays of shape [[channels,] height, width] 134 | def __init__(self, images=None, title=None, alpha=None, cmap=None): 135 | if not isinstance(images, tuple): 136 | images = (images,) 137 | self.images, self.alpha, self.cmap, self.title = images, alpha, cmap, title 138 | 139 | class PointData: 140 | def __init__(self, x:float=None, y:float=None, z:float=None, low:float=None, high:float=None, 141 | annotation:Any=None, text:Any=None, color:Any=None)->None: 142 | self.x = x 143 | self.y = y 144 | self.z = z 145 | self.low = low # confidence interval 146 | self.high = high 147 | self.annotation = annotation 148 | self.text = text 149 | self.color = color # typically a string like '#d62728' 150 | 151 | class PredictionResult: 152 | def __init__(self, loss:float=None, class_id:Hashable=None, probability:float=None, 153 | inputs:Any=None, outputs:Any=None, targets:Any=None, others:Any=None): 154 | self.loss = loss 155 | self.class_id = class_id 156 | self.probability = probability 157 | self.inputs = inputs 158 | self.outputs = outputs 159 | self.targets = targets 160 | self.others = others 161 | 162 | class DefaultPorts: 163 | PubSub = 40859 164 | CliSrv = 41459 165 | 166 | class PublisherTopics: 167 | StreamItem = 'StreamItem' 168 | ServerMgmt = 'ServerMgmt' 169 | 170 | class ServerMgmtMsg: 171 | EventServerStart = 'ServerStart' 172 | def __init__(self, event_name:str, event_args:Any=None): 173 | self.event_name = event_name 174 | self.event_args = event_args 175 | --------------------------------------------------------------------------------