├── .gitignore ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── Readme.md ├── appveyor.yml ├── etc ├── cm.gif ├── make_gif.txt ├── scatter.png └── sgd.gif ├── requirements.txt ├── setup.py └── tfmpl ├── __init__.py ├── create.py ├── figure.py ├── meta.py ├── plots ├── __init__.py └── confusion_matrix.py ├── samples ├── mnist.py ├── scatter.py └── sgd.py └── tests ├── __init__.py ├── test_confusion.py └── test_figure.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | log/ 104 | MNIST_data/ -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | env: 3 | - PYTHON=3.5 4 | - PYTHON=3.6 5 | install: 6 | # Install conda 7 | - if [[ "$PYTHON" == "2.7" ]]; then 8 | wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh; 9 | else 10 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; 11 | fi 12 | - bash miniconda.sh -b -p $HOME/miniconda 13 | - export PATH="$HOME/miniconda/bin:$PATH" 14 | - hash -r 15 | - conda config --set always_yes yes --set changeps1 no 16 | - conda config --add channels pandas 17 | - conda update -q conda 18 | - conda info -a 19 | 20 | # Install deps 21 | - deps='pip' 22 | - conda create -q -n pyenv python=$PYTHON $deps 23 | - source activate pyenv 24 | - python -m pip install -U pip 25 | - pip install tensorflow 26 | - pip install matplotlib 27 | - pip install pytest 28 | - pip install . 29 | 30 | script: pytest -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Christoph Heindl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md LICENSE requirements.txt -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | ### **tf-matplotlib** - seamless integration of matplotlib figures into TensorFlow summaries 2 | 3 | **tf-matplotlib** renders your everyday matplotlib figures tinside TensorFlow's Tensorboard visualization interface. The library 4 | - takes care of evaluating input tensors prior to plotting, 5 | - avoids matplotlib threading issues, 6 | - support multiple figures and, 7 | - provides blitting for runtime critical plotting. 8 | 9 | The following TensorFlow summary is generated by [sgd.py](tfmpl/samples/sgd.py). It plots the progress of gradient descent optimizers on a test surface. To avoid redrawing the test surface, it makes use of blitting. See [usage](#usage) below for a more introductory example. 10 | 11 | ![](etc/sgd.gif) 12 | 13 | ### Installation 14 | 15 | ``` 16 | pip install tfmpl 17 | ``` 18 | 19 | Requirements 20 | - Python 3.5/3.6 21 | - TensorFlow 1.x 22 | - matplotlib 2.2.0 23 | 24 | ### Build status 25 | 26 | |Branch|Linux|Windows| 27 | |------|------|------| 28 | |master|![](https://travis-ci.org/cheind/tf-matplotlib.svg?branch=master)| ![](https://ci.appveyor.com/api/projects/status/reo8nucumqhb93q5/branch/master?svg=true) | 29 | |develop|![](https://travis-ci.org/cheind/tf-matplotlib.svg?branch=master)|![](https://ci.appveyor.com/api/projects/status/reo8nucumqhb93q5/branch/develop?svg=true)| 30 | 31 | ### Usage 32 | 33 | 34 | Below are the relevant snippets to render a simple scatter plot. See [scatter.py](tfmpl/samples/scatter.py) for the complete self-contained example. 35 | 36 | ```python 37 | import tensorflow as tf 38 | import numpy as np 39 | 40 | import tfmpl 41 | 42 | @tfmpl.figure_tensor 43 | def draw_scatter(scaled, colors): 44 | '''Draw scatter plots. One for each color.''' 45 | figs = tfmpl.create_figures(len(colors), figsize=(4,4)) 46 | for idx, f in enumerate(figs): 47 | ax = f.add_subplot(111) 48 | ax.axis('off') 49 | ax.scatter(scaled[:, 0], scaled[:, 1], c=colors[idx]) 50 | f.tight_layout() 51 | 52 | return figs 53 | 54 | with tf.Session(graph=tf.Graph()) as sess: 55 | 56 | # A point cloud that can be scaled by the user 57 | points = tf.constant( 58 | np.random.normal(loc=0.0, scale=1.0, size=(100, 2)).astype(np.float32) 59 | ) 60 | scale = tf.placeholder(tf.float32) 61 | scaled = points*scale 62 | 63 | # Note, `scaled` above is a tensor. Its being passed `draw_scatter` below. 64 | # However, when `draw_scatter` is invoked, the tensor will be evaluated and a 65 | # numpy array representing its content is provided. 66 | image_tensor = draw_scatter(scaled, ['r', 'g']) 67 | image_summary = tf.summary.image('scatter', image_tensor) 68 | all_summaries = tf.summary.merge_all() 69 | 70 | writer = tf.summary.FileWriter('log', sess.graph) 71 | summary = sess.run(all_summaries, feed_dict={scale: 2.}) 72 | writer.add_summary(summary, global_step=0) 73 | ``` 74 | 75 | ![](etc/scatter.png) 76 | 77 | ### Draw utilities 78 | 79 | When doing classification, a common task is to generate a confusion matrix. **tf-matplotlib** provides `tfmpl.draw.confusion_matrix` to quickly generate such a plot from labels and predictions. The following plot shows classification training progress on the MNIST classification task. Full sample code is provided in [mnist.py](tfmpl/samples/mnist.py). 80 | 81 | ![](etc/cm.gif) 82 | 83 | ### License 84 | 85 | ``` 86 | MIT License 87 | 88 | Copyright (c) 2018 Christoph Heindl 89 | 90 | Permission is hereby granted, free of charge, to any person obtaining a copy 91 | of this software and associated documentation files (the "Software"), to deal 92 | in the Software without restriction, including without limitation the rights 93 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 94 | copies of the Software, and to permit persons to whom the Software is 95 | furnished to do so, subject to the following conditions: 96 | 97 | The above copyright notice and this permission notice shall be included in all 98 | copies or substantial portions of the Software. 99 | 100 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 101 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 102 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 103 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 104 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 105 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 106 | SOFTWARE. 107 | ``` 108 | 109 | 110 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | environment: 2 | TWINE_USERNAME: cheind 3 | TWINE_PASSWORD: 4 | secure: hnxMBvmJAGM1rQVOUbkGvQ== 5 | 6 | # http://www.appveyor.com/docs/installed-software#python 7 | matrix: 8 | - PYTHON: "C:\\Miniconda36-x64" 9 | PYTHON_VERSION: "3.6" 10 | PYTHON_ARCH: "64" 11 | - PYTHON: "C:\\Miniconda35-x64" 12 | PYTHON_VERSION: "3.5" 13 | PYTHON_ARCH: "64" 14 | 15 | install: 16 | - set "CONDA_ROOT=%PYTHON%" 17 | - set "PATH=%CONDA_ROOT%;%CONDA_ROOT%\Scripts;%CONDA_ROOT%\Library\bin;%PATH%" 18 | - conda config --set always_yes yes 19 | - conda update -q conda 20 | - conda config --set auto_update_conda no 21 | - conda install -q pip pytest numpy 22 | - python -m pip install -U pip 23 | - pip install wheel 24 | - pip install --upgrade --ignore-installed setuptools 25 | - pip install tensorflow 26 | - pip install matplotlib 27 | 28 | build_script: 29 | - python setup.py sdist 30 | 31 | test_script: 32 | # Try building source wheel and install 33 | - ps: >- 34 | $wheel = cmd /r dir .\dist\*.tar.gz /b/s; 35 | pip install --verbose $wheel 36 | - pytest --pyargs tfmpl 37 | 38 | on_success: 39 | ps: >- 40 | if ($env:APPVEYOR_REPO_BRANCH -eq "master") { 41 | Write-Output ("Deploying to PyPI") 42 | pip install --upgrade twine 43 | # If powershell ever sees anything on stderr it thinks it's a fail. 44 | # So we use cmd to redirect stderr to stdout before PS can see it. 45 | cmd /c 'twine upload --skip-existing dist\* 2>&1' 46 | } else { 47 | Write-Output "Not deploying as this is not a tagged commit or commit on master" 48 | } 49 | 50 | artifacts: 51 | - path: "dist\\*.tar.gz" 52 | name: Wheels 53 | 54 | notifications: 55 | - provider: Email 56 | to: 57 | - christoph.heindl@email.com 58 | on_build_success: true 59 | on_build_failure: true 60 | 61 | branches: 62 | only: 63 | - master 64 | - develop -------------------------------------------------------------------------------- /etc/cm.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/tf-matplotlib/c6904d3d2d306d9a479c24fbcb1f674a57dafd0e/etc/cm.gif -------------------------------------------------------------------------------- /etc/make_gif.txt: -------------------------------------------------------------------------------- 1 | ffmpeg -y -i sgd.mp4 -vf fps=10,scale=320:-1:flags=lanczos,palettegen palette.png 2 | ffmpeg -i sgd.mp4 -i palette.png -filter_complex "fps=10,scale=960:-1:flags=lanczos[x];[x][1:v]paletteuse" output.gif -------------------------------------------------------------------------------- /etc/scatter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/tf-matplotlib/c6904d3d2d306d9a479c24fbcb1f674a57dafd0e/etc/scatter.png -------------------------------------------------------------------------------- /etc/sgd.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheind/tf-matplotlib/c6904d3d2d306d9a479c24fbcb1f674a57dafd0e/etc/sgd.gif -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=2.0.2 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Christoph Heindl. 2 | # 3 | # Licensed under MIT License 4 | # ============================================================ 5 | 6 | import os 7 | try: 8 | from setuptools import setup 9 | except ImportError: 10 | from distutils.core import setup 11 | 12 | with open('requirements.txt') as f: 13 | required = f.read().splitlines() 14 | 15 | setup( 16 | name='tfmpl', 17 | version=open('tfmpl/__init__.py').readlines()[-1].split()[-1].strip('\''), 18 | description='Seamlessly integrate matplotlib figures tensorflow summaries.', 19 | author='Christoph Heindl', 20 | url='https://github.com/cheind/tf-matplotlib', 21 | license='MIT', 22 | install_requires=required, 23 | packages=['tfmpl', 'tfmpl.plots', 'tfmpl.samples', 'tfmpl.tests'], 24 | include_package_data=True, 25 | keywords='tensorflow matplotlib tensorboard' 26 | ) -------------------------------------------------------------------------------- /tfmpl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Christoph Heindl. 2 | # 3 | # Licensed under MIT License 4 | # ============================================================ 5 | 6 | from tfmpl.figure import figure_tensor, blittable_figure_tensor 7 | from tfmpl.create import create_figure, create_figures 8 | import tfmpl.plots 9 | 10 | # Needs to be last line 11 | __version__ = '1.0.2' -------------------------------------------------------------------------------- /tfmpl/create.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Christoph Heindl. 2 | # 3 | # Licensed under MIT License 4 | # ============================================================ 5 | 6 | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas 7 | from matplotlib.figure import Figure 8 | 9 | def create_figure(*fig_args, **fig_kwargs): 10 | '''Create a single figure. 11 | 12 | Args and Kwargs are passed to `matplotlib.figure.Figure`. 13 | 14 | This routine is provided in order to avoid usage of pyplot which 15 | is stateful and not thread safe. As drawing routines in tf-matplotlib 16 | are called from py-funcs in their respective thread, avoid usage 17 | of pyplot where possible. 18 | ''' 19 | 20 | fig = Figure(*fig_args, **fig_kwargs) 21 | # Attach canvas 22 | FigureCanvas(fig) 23 | return fig 24 | 25 | def create_figures(n, *fig_args, **fig_kwargs): 26 | '''Create multiple figures. 27 | 28 | Args and Kwargs are passed to `matplotlib.figure.Figure`. 29 | 30 | This routine is provided in order to avoid usage of pyplot which 31 | is stateful and not thread safe. As drawing routines in tf-matplotlib 32 | are called from py-funcs in their respective thread, avoid usage 33 | of pyplot where possible. 34 | ''' 35 | return [create_figure(*fig_args, **fig_kwargs) for _ in range(n)] -------------------------------------------------------------------------------- /tfmpl/figure.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Christoph Heindl. 2 | # 3 | # Licensed under MIT License 4 | # ============================================================ 5 | 6 | import tensorflow as tf 7 | import traceback 8 | import numpy as np 9 | from functools import wraps 10 | 11 | from tfmpl.meta import vararg_decorator, as_list 12 | from tfmpl.meta import PositionalTensorArgs 13 | 14 | def figure_buffer(figs): 15 | '''Extract raw image buffer from matplotlib figure shaped as 1xHxWx3.''' 16 | assert len(figs) > 0, 'No figure buffers given. Forgot to return from draw call?' 17 | buffers = [] 18 | w, h = figs[0].canvas.get_width_height() 19 | for f in figs: 20 | wf, hf = f.canvas.get_width_height() 21 | assert wf == w and hf == h, 'All canvas objects need to have same size' 22 | buffers.append(np.fromstring(f.canvas.tostring_rgb(), dtype=np.uint8).reshape(h, w, 3)) 23 | 24 | return np.stack(buffers) # NxHxWx3 25 | 26 | @vararg_decorator 27 | def figure_tensor(func, **tf_pyfunc_kwargs): 28 | '''Decorate matplotlib drawing routines. 29 | 30 | This dectorator is meant to decorate functions that return matplotlib 31 | figures. The decorated function has to have the following signature 32 | 33 | def decorated(*args, **kwargs) -> figure or iterable of figures 34 | 35 | where `*args` can be any positional argument and `**kwargs` are any 36 | keyword arguments. The decorated function returns a tensor of shape 37 | `[NumFigures, Height, Width, 3]` of type `tf.uint8`. 38 | 39 | The drawing code is invoked during running of TensorFlow sessions, 40 | at a time when all positional tensor arguments have been evaluated 41 | by the session. The decorated function is then passed the tensor values. 42 | All non tensor arguments remain unchanged. 43 | ''' 44 | 45 | name = tf_pyfunc_kwargs.pop('name', func.__name__) 46 | 47 | @wraps(func) 48 | def wrapper(*func_args, **func_kwargs): 49 | tf_args = PositionalTensorArgs(func_args) 50 | 51 | def pyfnc_callee(*tensor_values, **unused): 52 | try: 53 | figs = as_list(func(*tf_args.mix_args(tensor_values), **func_kwargs)) 54 | for f in figs: 55 | f.canvas.draw() 56 | return figure_buffer(figs) 57 | except Exception: 58 | print('-'*5 + 'tfmpl catched exception' + '-'*5) 59 | print(traceback.format_exc()) 60 | print('-'*20) 61 | raise 62 | 63 | return tf.py_func(pyfnc_callee, tf_args.tensor_args, tf.uint8, name=name, **tf_pyfunc_kwargs) 64 | return wrapper 65 | 66 | @vararg_decorator 67 | def blittable_figure_tensor(func, init_func, **tf_pyfunc_kwargs): 68 | '''Decorate matplotlib drawing routines with blitting support. 69 | 70 | This dectorator is meant to decorate functions that return matplotlib 71 | figures. The decorated function has to have the following signature 72 | 73 | def decorated(*args, **kwargs) -> iterable of artists 74 | 75 | where `*args` can be any positional argument and `**kwargs` are any 76 | keyword arguments. The decorated function returns a tensor of shape 77 | `[NumFigures, Height, Width, 3]` of type `tf.uint8`. 78 | 79 | Besides the actual drawing function, `blittable_figure_tensor` requires 80 | a `init_func` argument with the following signature 81 | 82 | def init(*args, **kwargs) -> iterable of figures, iterable of artists 83 | 84 | The init function is meant to create and initialize figures, as well as to 85 | perform drawing that is meant to be done only once. Any set of artits to be 86 | updated in later drawing calls should also be allocated in init. The 87 | initialize function must have the same positional and keyword arguments 88 | as the decorated function. It is called once before the decorated function 89 | is called. 90 | 91 | The drawing code / init function is invoked during running of TensorFlow 92 | sessions, at a time when all positional tensor arguments have been 93 | evaluated by the session. The decorated / init function is then passed the 94 | tensor values. All non tensor arguments remain unchanged. 95 | ''' 96 | name = tf_pyfunc_kwargs.pop('name', func.__name__) 97 | assert callable(init_func), 'Init function not callable' 98 | 99 | @wraps(func) 100 | def wrapper(*func_args, **func_kwargs): 101 | figs = None 102 | bgs = None 103 | 104 | tf_args = PositionalTensorArgs(func_args) 105 | 106 | def pyfnc_callee(*tensor_values, **unused): 107 | 108 | try: 109 | nonlocal figs, bgs 110 | pos_args = tf_args.mix_args(tensor_values) 111 | 112 | if figs is None: 113 | figs, artists = init_func(*pos_args, **func_kwargs) 114 | figs = as_list(figs) 115 | artists = as_list(artists) 116 | for f in figs: 117 | f.canvas.draw() 118 | for a in artists: 119 | a.set_animated(True) 120 | bgs = [f.canvas.copy_from_bbox(f.bbox) for f in figs] 121 | 122 | artists = as_list(func(*pos_args, **func_kwargs)) 123 | 124 | for f, bg in zip(figs, bgs): 125 | f.canvas.restore_region(bg) 126 | for a in artists: 127 | a.axes.draw_artist(a) 128 | for f in figs: 129 | f.canvas.blit(f.bbox) 130 | 131 | return figure_buffer(figs) 132 | except Exception: 133 | print('-'*5 + 'tfmpl catched exception' + '-'*5) 134 | print(traceback.format_exc()) 135 | print('-'*20) 136 | raise 137 | 138 | return tf.py_func(pyfnc_callee, tf_args.tensor_args, tf.uint8, name=name, **tf_pyfunc_kwargs) 139 | return wrapper -------------------------------------------------------------------------------- /tfmpl/meta.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Christoph Heindl. 2 | # 3 | # Licensed under MIT License 4 | # ============================================================ 5 | 6 | from functools import wraps 7 | import tensorflow as tf 8 | from tensorflow.contrib.framework import is_tensor 9 | from collections import Sequence 10 | 11 | def vararg_decorator(f): 12 | '''Decorator to handle variable argument decorators.''' 13 | 14 | @wraps(f) 15 | def decorator(*args, **kwargs): 16 | if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): 17 | return f(args[0]) 18 | else: 19 | return lambda realf: f(realf, *args, **kwargs) 20 | 21 | return decorator 22 | 23 | class PositionalTensorArgs: 24 | '''Handle tensor arguments.''' 25 | 26 | def __init__(self, args): 27 | self.args = args 28 | self.tf_args = [(i,a) for i,a in enumerate(args) if is_tensor(a)] 29 | 30 | @property 31 | def tensor_args(self): 32 | return [a for i,a in self.tf_args] 33 | 34 | def mix_args(self, tensor_values): 35 | args = list(self.args) 36 | for i, (j, _) in enumerate(self.tf_args): 37 | args[j] = tensor_values[i] 38 | return args 39 | 40 | def as_list(x): 41 | '''Ensure `x` is of list type.''' 42 | 43 | if x is None: 44 | x = [] 45 | elif not isinstance(x, Sequence): 46 | x = [x] 47 | return list(x) -------------------------------------------------------------------------------- /tfmpl/plots/__init__.py: -------------------------------------------------------------------------------- 1 | import tfmpl.plots.confusion_matrix -------------------------------------------------------------------------------- /tfmpl/plots/confusion_matrix.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Christoph Heindl. 2 | # 3 | # Licensed under MIT License 4 | # ============================================================ 5 | 6 | import numpy as np 7 | import re 8 | from textwrap import wrap 9 | from itertools import product 10 | 11 | def from_labels_and_predictions(labels, predictions, num_classes): 12 | '''Compute a confusion matrix from labels and predictions. 13 | 14 | A drop-in replacement for tf.confusion_matrix that works on CPU data 15 | and not tensors. 16 | 17 | Params 18 | ------ 19 | labels : array-like 20 | 1-D array of real labels for classification 21 | predicitions: array-like 22 | 1-D array of predicted label classes 23 | num_classes: scalar 24 | Total number of classes 25 | 26 | Returns 27 | ------- 28 | matrix : NxN array 29 | Array of shape [num_classes, num_classes] containing the confusion values. 30 | ''' 31 | assert len(labels) == len(predictions) 32 | cm = np.zeros((num_classes, num_classes), dtype=np.int32) 33 | for i in range(len(labels)): 34 | cm[labels[i], predictions[i]] += 1 35 | return cm 36 | 37 | def draw(ax, cm, axis_labels=None, normalize=False): 38 | '''Plot a confusion matrix. 39 | 40 | Inspired by 41 | https://stackoverflow.com/questions/41617463/tensorflow-confusion-matrix-in-tensorboard 42 | 43 | Params 44 | ------ 45 | ax : axis 46 | Axis to plot on 47 | cm : NxN array 48 | Confusion matrix 49 | 50 | Kwargs 51 | ------ 52 | axis_labels : array-like 53 | Array of size N containing axis labels 54 | normalize : bool 55 | Whether to plot counts or ratios. 56 | ''' 57 | 58 | cm = np.asarray(cm) 59 | num_classes = cm.shape[0] 60 | 61 | if normalize: 62 | with np.errstate(invalid='ignore', divide='ignore'): 63 | cm = cm / cm.sum(1, keepdims=True) 64 | cm = np.nan_to_num(cm, copy=True) 65 | 66 | po = np.get_printoptions() 67 | np.set_printoptions(precision=2) 68 | 69 | ax.imshow(cm, cmap='Oranges') 70 | 71 | ticks = np.arange(num_classes) 72 | 73 | ax.set_xlabel('Predicted') 74 | ax.set_xticks(ticks) 75 | ax.xaxis.set_label_position('bottom') 76 | ax.xaxis.tick_bottom() 77 | 78 | ax.set_ylabel('Actual') 79 | ax.set_yticks(ticks) 80 | ax.yaxis.set_label_position('left') 81 | ax.yaxis.tick_left() 82 | 83 | if axis_labels is not None: 84 | ticklabels = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in axis_labels] 85 | ticklabels = ['\n'.join(wrap(l, 20)) for l in ticklabels] 86 | ax.set_xticklabels(ticklabels, rotation=-90, ha='center') 87 | ax.set_yticklabels(ticklabels, va ='center') 88 | 89 | for i, j in product(range(num_classes), range(num_classes)): 90 | if cm[i,j] == 0: 91 | txt = '.' 92 | elif normalize: 93 | txt = '{:.2f}'.format(cm[i,j]) 94 | else: 95 | txt = '{}'.format(cm[i,j]) 96 | ax.text(j, i, txt, horizontalalignment="center", verticalalignment='center', color= "black", fontsize=7) 97 | 98 | np.set_printoptions(**po) -------------------------------------------------------------------------------- /tfmpl/samples/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Christoph Heindl. 2 | # 3 | # Licensed under MIT License 4 | # ============================================================ 5 | """Show usage of confusion matrix visualization. 6 | 7 | Using a simple MNIST classifier taken from 8 | https://github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/examples/tutorials/mnist/mnist_softmax.py 9 | 10 | Code is modified to slow down convergence so that 11 | time-stepping confusion matrix in Tensorboard has a 12 | better visual effect. 13 | """ 14 | 15 | from tensorflow.examples.tutorials.mnist import input_data 16 | from datetime import datetime 17 | import tensorflow as tf 18 | import numpy as np 19 | import os 20 | 21 | import tfmpl 22 | 23 | @tfmpl.figure_tensor 24 | def draw_confusion_matrix(matrix): 25 | '''Draw confusion matrix for MNIST.''' 26 | fig = tfmpl.create_figure(figsize=(7,7)) 27 | ax = fig.add_subplot(111) 28 | ax.set_title('Confusion matrix for MNIST classification') 29 | 30 | tfmpl.plots.confusion_matrix.draw( 31 | ax, matrix, 32 | axis_labels=['Digit ' + str(x) for x in range(10)], 33 | normalize=True 34 | ) 35 | 36 | return fig 37 | 38 | if __name__ == '__main__': 39 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 40 | 41 | with tf.Session(graph=tf.Graph()) as sess: 42 | 43 | # Create the model 44 | x = tf.placeholder(tf.float32, [None, 784]) 45 | W = tf.Variable(tf.zeros([784, 10])) 46 | b = tf.Variable(tf.zeros([10])) 47 | y = tf.matmul(x, W) + b 48 | 49 | y_ = tf.placeholder(tf.float32, [None, 10]) 50 | cross_entropy = tf.reduce_mean( 51 | tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=y) 52 | ) 53 | train = tf.train.GradientDescentOptimizer(1e-3).minimize(cross_entropy) 54 | 55 | preds = tf.argmax(y, 1) 56 | labels = tf.argmax(y_, 1) 57 | 58 | # Compute confusion matrix 59 | matrix = tf.confusion_matrix(labels, preds, num_classes=10) 60 | 61 | # Get a image tensor for summary usage 62 | image_tensor = draw_confusion_matrix(matrix) 63 | 64 | image_summary = tf.summary.image('confusion_matrix', image_tensor) 65 | all_summaries = tf.summary.merge_all() 66 | 67 | os.makedirs('log', exist_ok=True) 68 | now = datetime.now() 69 | logdir = "log/" + now.strftime("%Y%m%d-%H%M%S") + "/" 70 | writer = tf.summary.FileWriter(logdir, sess.graph) 71 | 72 | # Train 73 | sess.run(tf.global_variables_initializer()) 74 | for i in range(1000): 75 | batch_xs, batch_ys = mnist.train.next_batch(10) 76 | sess.run(train, feed_dict={x: batch_xs, y_: batch_ys}) 77 | 78 | if i % 10 == 0: 79 | print(f'Iteration {i}') 80 | summary = sess.run(all_summaries, feed_dict={x: mnist.test.images, y_: mnist.test.labels}) 81 | writer.add_summary(summary, global_step=i) 82 | writer.flush() 83 | 84 | correct_prediction = tf.equal(preds, labels) 85 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 86 | print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) 87 | -------------------------------------------------------------------------------- /tfmpl/samples/scatter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Christoph Heindl. 2 | # 3 | # Licensed under MIT License 4 | # ============================================================ 5 | 6 | from datetime import datetime 7 | import tensorflow as tf 8 | import numpy as np 9 | import os 10 | 11 | import tfmpl 12 | 13 | if __name__ == '__main__': 14 | 15 | with tf.Session(graph=tf.Graph()) as sess: 16 | 17 | @tfmpl.figure_tensor 18 | def draw_scatter(scaled, colors): 19 | '''Draw scatter plots. One for each color.''' 20 | figs = tfmpl.create_figures(len(colors), figsize=(4,4)) 21 | for idx, f in enumerate(figs): 22 | ax = f.add_subplot(111) 23 | ax.axis('off') 24 | ax.scatter(scaled[:, 0], scaled[:, 1], c=colors[idx]) 25 | f.tight_layout() 26 | 27 | return figs 28 | 29 | points = tf.constant(np.random.normal(loc=0.0, scale=1.0, size=(100, 2)).astype(np.float32)) 30 | scale = tf.placeholder(tf.float32) 31 | scaled = points*scale 32 | 33 | image_tensor = draw_scatter(scaled, ['r', 'g']) 34 | image_summary = tf.summary.image('scatter', image_tensor) 35 | all_summaries = tf.summary.merge_all() 36 | 37 | os.makedirs('log', exist_ok=True) 38 | now = datetime.now() 39 | logdir = "log/" + now.strftime("%Y%m%d-%H%M%S") + "/" 40 | writer = tf.summary.FileWriter(logdir, sess.graph) 41 | 42 | summary = sess.run(all_summaries, feed_dict={scale: 2.}) 43 | writer.add_summary(summary, global_step=0) 44 | writer.flush() 45 | -------------------------------------------------------------------------------- /tfmpl/samples/sgd.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Christoph Heindl. 2 | # 3 | # Licensed under MIT License 4 | # ============================================================ 5 | 6 | from mpl_toolkits.mplot3d.art3d import Line3DCollection 7 | from mpl_toolkits.mplot3d import Axes3D 8 | from matplotlib.colors import LogNorm 9 | from matplotlib import cm 10 | from datetime import datetime 11 | import tensorflow as tf 12 | import numpy as np 13 | import os 14 | 15 | import tfmpl 16 | 17 | if __name__ == '__main__': 18 | 19 | with tf.Session(graph=tf.Graph()) as sess: 20 | 21 | def beale(x, y): 22 | '''Beale surface for optimization tests.''' 23 | with tf.name_scope('beale', [x, y]): 24 | return (1.5 - x + x*y)**2 + (2.25 - x + x*y**2)**2 + (2.625 - x + x*y**3)**2 25 | 26 | # List of optimizers to compare 27 | optimizers = [ 28 | (tf.train.GradientDescentOptimizer(1e-3), 'SGD'), 29 | (tf.train.AdagradOptimizer(1e-1), 'Adagrad'), 30 | (tf.train.AdadeltaOptimizer(1e2), 'Adadelta'), 31 | (tf.train.AdamOptimizer(1e-1), 'Adam'), 32 | ] 33 | 34 | paths = [] 35 | history = [] 36 | 37 | def init_fig(*args, **kwargs): 38 | '''Initialize figures.''' 39 | fig = tfmpl.create_figure(figsize=(8,6)) 40 | ax = fig.add_subplot(111, projection='3d', elev=50, azim=-30) 41 | ax.w_xaxis.set_pane_color((1.0,1.0,1.0,1.0)) 42 | ax.w_yaxis.set_pane_color((1.0,1.0,1.0,1.0)) 43 | ax.w_zaxis.set_pane_color((1.0,1.0,1.0,1.0)) 44 | ax.set_title('Gradient descent on Beale surface') 45 | ax.set_xlabel('$x$') 46 | ax.set_ylabel('$y$') 47 | ax.set_zlabel('beale($x$,$y$)') 48 | 49 | xx, yy = np.meshgrid(np.linspace(-4.5, 4.5, 40), np.linspace(-4.5, 4.5, 40)) 50 | zz = beale(xx, yy) 51 | ax.plot_surface(xx, yy, zz, norm=LogNorm(), rstride=1, cstride=1, edgecolor='none', alpha=.8, cmap=cm.jet) 52 | ax.plot([3], [.5], [beale(3, .5)], 'k*', markersize=5) 53 | 54 | for o in optimizers: 55 | path, = ax.plot([],[],[], label=o[1]) 56 | paths.append(path) 57 | 58 | ax.legend(loc='upper left') 59 | fig.tight_layout() 60 | 61 | return fig, paths 62 | 63 | @tfmpl.blittable_figure_tensor(init_func=init_fig) 64 | def draw(xy, z): 65 | '''Updates paths for each optimizer.''' 66 | history.append(np.c_[xy, z]) 67 | xyz = np.stack(history) #NxMx3 68 | for idx, path in enumerate(paths): 69 | path.set_data(xyz[:, idx, 0], xyz[:, idx, 1]) 70 | path.set_3d_properties(xyz[:, idx, 2]) 71 | 72 | return paths 73 | 74 | # Create variables for each optimizer 75 | start = tf.constant_initializer([3., 4.], dtype=tf.float32) 76 | xys = [tf.get_variable(f'xy_{o[1]}', 2, tf.float32, initializer=start) for o in optimizers] 77 | zs = [beale(xy[0], xy[1]) for xy in xys] 78 | 79 | # Define optimization target 80 | train = [] 81 | for idx, (opt, name) in enumerate(optimizers): 82 | grads_and_vars = opt.compute_gradients(zs[idx], xys[idx]) 83 | clipped = [(tf.clip_by_value(g, -10, 10), v) for g, v in grads_and_vars] 84 | train.append(opt.apply_gradients(clipped)) 85 | 86 | # Generate summary 87 | image_tensor = draw(tf.stack(xys), tf.stack(zs)) 88 | image_summary = tf.summary.image('optimization', image_tensor) 89 | all_summaries = tf.summary.merge_all() 90 | 91 | # Alloc summary writer 92 | os.makedirs('log', exist_ok=True) 93 | now = datetime.now() 94 | logdir = "log/" + now.strftime("%Y%m%d-%H%M%S") + "/" 95 | writer = tf.summary.FileWriter(logdir, sess.graph) 96 | 97 | # Run optimization, write summary every now and then. 98 | init = tf.global_variables_initializer() 99 | sess.run(init) 100 | for i in range(200): 101 | if i % 10 == 0: 102 | summary = sess.run(all_summaries) 103 | writer.add_summary(summary, global_step=i) 104 | writer.flush() 105 | sess.run(train) -------------------------------------------------------------------------------- /tfmpl/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Christoph Heindl. 2 | # 3 | # Licensed under MIT License 4 | # ============================================================ -------------------------------------------------------------------------------- /tfmpl/tests/test_confusion.py: -------------------------------------------------------------------------------- 1 | import tfmpl 2 | import numpy as np 3 | 4 | def test_confusion_matrix(): 5 | cm = tfmpl.plots.confusion_matrix.from_labels_and_predictions([1, 2, 4], [2, 2, 4], num_classes=5) 6 | exp = np.zeros((5,5), dtype=int) 7 | exp[1,2] = exp[2,2] = exp[4,4] = 1 8 | np.testing.assert_allclose(cm, exp) 9 | -------------------------------------------------------------------------------- /tfmpl/tests/test_figure.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Christoph Heindl. 2 | # 3 | # Licensed under MIT License 4 | # ============================================================ 5 | 6 | import tensorflow as tf 7 | import tfmpl 8 | import numpy as np 9 | 10 | def test_arguments(): 11 | 12 | debug = {} 13 | 14 | @tfmpl.figure_tensor 15 | def draw(a, b, c, d=None, e=None): 16 | debug['a'] = a 17 | debug['b'] = b 18 | debug['c'] = c 19 | debug['d'] = d 20 | debug['e'] = e 21 | 22 | return tfmpl.create_figure() 23 | 24 | with tf.Session(graph=tf.Graph()) as sess: 25 | a = tf.constant(0) 26 | c = tf.placeholder(tf.float32) 27 | 28 | tensor = draw(a, [0,1], c, d='d', e='e') 29 | sess.run(tensor, feed_dict={c: np.zeros((2,2))}) 30 | 31 | assert debug['a'] == 0 32 | assert debug['b'] == [0,1] 33 | np.testing.assert_allclose(debug['c'], np.zeros((2,2))) 34 | debug['d'] = 'd' 35 | debug['e'] = 'e' 36 | 37 | 38 | def test_arguments_blittable(): 39 | 40 | debug = {} 41 | 42 | def init(a, b, c, d=None, e=None): 43 | debug['init_args'] = [a, b, c, d, e] 44 | return tfmpl.create_figure(), None 45 | 46 | @tfmpl.blittable_figure_tensor(init_func=init) 47 | def draw(a, b, c, d=None, e=None): 48 | debug['args'] = [a, b, c, d, e] 49 | 50 | with tf.Session(graph=tf.Graph()) as sess: 51 | a = tf.constant(0) 52 | c = tf.placeholder(tf.float32) 53 | 54 | tensor = draw(a, [0,1], c, d='d', e='e') 55 | sess.run(tensor, feed_dict={c: np.zeros((2,2))}) 56 | 57 | assert debug['init_args'][0] == 0 58 | assert debug['init_args'][1] == [0,1] 59 | np.testing.assert_allclose(debug['init_args'][2], np.zeros((2,2))) 60 | assert debug['init_args'][3] == 'd' 61 | assert debug['init_args'][4] == 'e' 62 | 63 | assert debug['args'][0] == 0 64 | assert debug['args'][1] == [0,1] 65 | np.testing.assert_allclose(debug['args'][2], np.zeros((2,2))) 66 | assert debug['args'][3] == 'd' 67 | assert debug['args'][4] == 'e' 68 | 69 | def test_callcount(): 70 | 71 | debug = {} 72 | debug['called'] = 0 73 | debug['a'] = [] 74 | 75 | @tfmpl.figure_tensor 76 | def draw(a): 77 | debug['called'] += 1 78 | debug['a'].append(a) 79 | return tfmpl.create_figure() 80 | 81 | with tf.Session(graph=tf.Graph()) as sess: 82 | a = tf.placeholder(tf.float32) 83 | 84 | tensor = draw(a) 85 | 86 | for i in range(5): 87 | sess.run(tensor, feed_dict={a: i}) 88 | 89 | assert debug['called'] == 5 90 | np.testing.assert_allclose(debug['a'], [0,1,2,3,4]) 91 | 92 | def test_callcount_blittable(): 93 | 94 | debug = {} 95 | debug['init_called'] = 0 96 | debug['draw_called'] = 0 97 | debug['a'] = [] 98 | debug['a_init'] = [] 99 | 100 | def init(a): 101 | debug['init_called'] += 1 102 | debug['a_init'] = a 103 | return tfmpl.create_figure(), None 104 | 105 | @tfmpl.blittable_figure_tensor(init_func=init) 106 | def draw(a): 107 | debug['draw_called'] += 1 108 | debug['a'].append(a) 109 | 110 | with tf.Session(graph=tf.Graph()) as sess: 111 | a = tf.placeholder(tf.float32) 112 | 113 | tensor = draw(a) 114 | 115 | for i in range(5): 116 | sess.run(tensor, feed_dict={a: i}) 117 | 118 | assert debug['init_called'] == 1 119 | assert debug['draw_called'] == 5 120 | assert debug['a_init'] == 0 121 | np.testing.assert_allclose(debug['a'], [0,1,2,3,4]) 122 | 123 | def test_callcount_blittable(): 124 | 125 | debug = {} 126 | debug['init_called'] = 0 127 | debug['draw_called'] = 0 128 | debug['a'] = [] 129 | debug['a_init'] = [] 130 | 131 | def init(a): 132 | debug['init_called'] += 1 133 | debug['a_init'] = a 134 | return tfmpl.create_figure(), None 135 | 136 | @tfmpl.blittable_figure_tensor(init_func=init) 137 | def draw(a): 138 | debug['draw_called'] += 1 139 | debug['a'].append(a) 140 | 141 | with tf.Session(graph=tf.Graph()) as sess: 142 | a = tf.placeholder(tf.float32) 143 | 144 | tensor = draw(a) 145 | 146 | for i in range(5): 147 | sess.run(tensor, feed_dict={a: i}) 148 | 149 | assert debug['init_called'] == 1 150 | assert debug['draw_called'] == 5 151 | assert debug['a_init'] == 0 152 | np.testing.assert_allclose(debug['a'], [0,1,2,3,4]) 153 | 154 | def test_draw(): 155 | 156 | @tfmpl.figure_tensor 157 | def draw(): 158 | figs = tfmpl.create_figures(2, figsize=(4,3), dpi=100) 159 | 160 | figs[0].patch.set_facecolor('red') 161 | figs[1].patch.set_facecolor((0, 1, 0)) 162 | 163 | return figs 164 | 165 | with tf.Session(graph=tf.Graph()) as sess: 166 | a = tf.placeholder(tf.float32) 167 | 168 | tensor = draw() 169 | 170 | imgs = sess.run(tensor) 171 | assert imgs.shape == (2, 300, 400, 3) 172 | np.testing.assert_allclose(imgs[0], np.tile([255, 0, 0], (300, 400, 1))) 173 | np.testing.assert_allclose(imgs[1], np.tile([0, 255, 0], (300, 400, 1))) 174 | 175 | def test_draw_blittable(): 176 | import matplotlib.patches as patches 177 | 178 | rect = None 179 | 180 | def init(t): 181 | nonlocal rect 182 | fig = tfmpl.create_figure(figsize=(4,4), dpi=100) 183 | ax = fig.add_axes([0,0,1,1]) 184 | ax.invert_yaxis() 185 | rect = ax.add_patch(patches.Rectangle((0,0), 0.1, 0.1, facecolor=(0,1,0))) 186 | return fig, rect 187 | 188 | @tfmpl.blittable_figure_tensor(init_func=init) 189 | def draw(t): 190 | rect.set_xy((t,t)) 191 | return rect 192 | 193 | with tf.Session(graph=tf.Graph()) as sess: 194 | t = tf.placeholder(tf.float32) 195 | tensor = draw(t) 196 | 197 | imgs = sess.run(tensor, feed_dict={t:0}) 198 | assert imgs.shape == (1, 400, 400, 3) 199 | 200 | np.testing.assert_allclose(imgs[0, :40, :40], np.tile([0, 255, 0], (40, 40, 1))) 201 | 202 | imgs = sess.run(tensor, feed_dict={t:0.5}) 203 | assert imgs.shape == (1, 400, 400, 3) 204 | np.testing.assert_allclose(imgs[0, 200:240, 200:240], np.tile([0, 255, 0], (40, 40, 1))) 205 | --------------------------------------------------------------------------------