├── .gitignore ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.md ├── mplexporter ├── __init__.py ├── _py3k_compat.py ├── convertors.py ├── exporter.py ├── renderers │ ├── __init__.py │ ├── base.py │ ├── fake_renderer.py │ ├── vega_renderer.py │ └── vincent_renderer.py ├── tests │ ├── __init__.py │ ├── test_basic.py │ ├── test_convertors.py │ └── test_utils.py ├── tools.py └── utils.py ├── notebooks ├── VegaTest.ipynb └── VincentTest.ipynb └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | 3 | # C extensions 4 | *.so 5 | 6 | # Packages 7 | *.egg 8 | *.egg-info 9 | dist 10 | build 11 | eggs 12 | parts 13 | bin 14 | var 15 | sdist 16 | develop-eggs 17 | .installed.cfg 18 | lib 19 | lib64 20 | __pycache__ 21 | 22 | # Installer logs 23 | pip-log.txt 24 | 25 | # Unit test / coverage reports 26 | .coverage 27 | .tox 28 | nosetests.xml 29 | 30 | # Translations 31 | *.mo 32 | 33 | # Mr Developer 34 | .mr.developer.cfg 35 | .project 36 | .pydevproject 37 | 38 | 39 | # emacs backup files 40 | *~ 41 | 42 | # ipython backups 43 | .ipynb_checkpoints 44 | 45 | # os x files 46 | .DS_Store 47 | 48 | # VIM 49 | *.sw* 50 | 51 | # virtualenv 52 | virtualenv/ 53 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: python 3 | 4 | python: 5 | - 2.7 6 | - 3.5 7 | 8 | env: 9 | - DEPS="numpy=1.11 matplotlib=2.2.3 jinja2=2.8 pandas=0.18 nose2" 10 | 11 | install: 12 | - conda create -n testenv --yes python=$TRAVIS_PYTHON_VERSION 13 | - source activate testenv 14 | - conda install --yes $DEPS 15 | - python setup.py install 16 | 17 | before_install: 18 | # setup virtual x 19 | - "export DISPLAY=:99.0" 20 | - "sh -e /etc/init.d/xvfb start" 21 | # then install python version to test 22 | - wget https://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh 
-O miniconda.sh 23 | - chmod +x miniconda.sh 24 | - bash miniconda.sh -b -p $HOME/miniconda 25 | - export PATH="$HOME/miniconda/bin:$PATH" 26 | # Learned the hard way: miniconda is not always up-to-date with conda. 27 | - conda update --yes conda 28 | 29 | script: 30 | - MPLBE=Agg nose2 mplexporter 31 | 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014, mpld3 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, this 11 | list of conditions and the following disclaimer in the documentation and/or 12 | other materials provided with the distribution. 13 | 14 | * Neither the name of the {organization} nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.md 2 | include LICENSE 3 | 4 | recursive-include mplexporter *.py 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | mplexporter 2 | =========== 3 | 4 | A proof of concept general matplotlib exporter 5 | -------------------------------------------------------------------------------- /mplexporter/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | MPLBE = os.environ.get('MPLBE', False) 4 | 5 | if MPLBE: 6 | import matplotlib 7 | matplotlib.use(MPLBE) 8 | 9 | import matplotlib.pyplot as plt 10 | from .renderers import Renderer 11 | from .exporter import Exporter 12 | from .convertors import StrMethodTickFormatterConvertor 13 | -------------------------------------------------------------------------------- /mplexporter/_py3k_compat.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple fixes for Python 2/3 compatibility 3 | """ 4 | import sys 5 | PY3K = sys.version_info[0] >= 3 6 | 7 | 8 | if PY3K: 9 | import builtins 10 | import functools 11 | reduce = functools.reduce 12 | zip = builtins.zip 13 | 
class StrMethodTickFormatterConvertor(object):
    """Translate a matplotlib ``StrMethodFormatter`` format string.

    The format string of the wrapped formatter is split into the literal
    text before the ``x`` field, the format spec of the ``x`` field, and
    the literal text after it.

    Parameters
    ----------
    formatter : matplotlib.ticker.StrMethodFormatter
        The tick formatter whose ``fmt`` string is converted.
    output_format : string (optional)
        One of ``SUPPORTED_OUTPUT_FORMATS``.  Only ``"d3-format"`` is
        currently defined.

    Raises
    ------
    AssertionError
        If ``output_format`` is not in ``SUPPORTED_OUTPUT_FORMATS``.
    ValueError
        If ``formatter`` is not a ``StrMethodFormatter``.
    """

    STRING_FORMAT_D3 = "d3-format"

    SUPPORTED_OUTPUT_FORMATS = (
        STRING_FORMAT_D3,
    )

    def __init__(self, formatter, output_format=STRING_FORMAT_D3):
        # Raise explicitly rather than through an ``assert`` statement so the
        # check is not stripped under ``python -O``.  AssertionError is kept
        # so callers observe the same exception type as before.
        if output_format not in self.SUPPORTED_OUTPUT_FORMATS:
            raise AssertionError("Unknown output_format")
        if not isinstance(formatter, matplotlib.ticker.StrMethodFormatter):
            raise ValueError("Formatter must be of type `matplotlib.ticker.StrMethodFormatter`")
        self.formatter = formatter
        self.output_format = output_format

    @property
    def is_output_d3(self):
        """True when the selected output format is ``STRING_FORMAT_D3``."""
        return self.output_format == self.STRING_FORMAT_D3

    def export_mpl_format_str_d3(self, mpl_format_str):
        """Split a ``str.format`` style string around its ``x`` field.

        Parameters
        ----------
        mpl_format_str : string
            A format string such as ``"${x:,.2f} USD"``.

        Returns
        -------
        dict
            ``{"format_string": spec, "prefix": text, "suffix": text}`` where
            ``spec`` is the format spec of the last ``x`` field seen (empty
            if there is none), ``prefix`` is all literal text up to the first
            ``x`` field and ``suffix`` is all literal text after it.

        Raises
        ------
        ValueError
            If the output format is d3 and a second ``x`` field with a
            non-empty format spec follows one that already had a spec.
        """
        prefixes = []
        suffixes = []
        before_x = True
        format_spec_for_d3 = ""
        for literal_text, field_name, format_spec, conversion in Formatter().parse(mpl_format_str):
            # Literal text belongs to the prefix until the first ``x`` field
            # has been consumed, and to the suffix afterwards.
            if before_x:
                prefixes.append(literal_text)
            else:
                suffixes.append(literal_text)

            if field_name != "x":
                continue

            if format_spec and format_spec_for_d3 and self.is_output_d3:
                raise ValueError("D3 doesn't support multiple conversions")

            before_x = False
            format_spec_for_d3 = format_spec

        return {
            "format_string": format_spec_for_d3,
            "prefix": "".join(prefixes),
            "suffix": "".join(suffixes),
        }

    @property
    def output(self):
        """Converted description of ``self.formatter.fmt``.

        Returns ``None`` when the output format is not d3; the branch exists
        in case another output format is added later.
        """
        if self.is_output_d3:
            return self.export_mpl_format_str_d3(self.formatter.fmt)
        return None
@staticmethod
def process_transform(transform, ax=None, data=None, return_trans=False,
                      force_trans=None):
    """Classify a transform and optionally map data into that system.

    The transform is compared against the axes' data, axes and figure
    transforms (and finally the identity transform).  The first one it
    contains as a branch determines the coordinate code, and that branch
    is subtracted from the transform.

    Parameters
    ----------
    transform : matplotlib Transform object
        The transform applied to the data
    ax : matplotlib Axes object (optional)
        The axes the data is associated with.  Without it the code is
        always "display" and the transform is left as is.
    data : ndarray (optional)
        The array of data to be transformed.
    return_trans : bool (optional)
        If true, also return the final transform
    force_trans : matplotlib.transform instance (optional)
        If supplied, the data is first mapped through
        ``transform - force_trans`` and ``force_trans`` then replaces
        ``transform`` for the rest of the processing.

    Returns
    -------
    code : string
        One of "data", "axes", "figure" or "display".
    new_data : ndarray
        The data pushed through the remaining transform.
        Returned only if data is specified.
    transform : matplotlib transform
        The remaining transform.  Returned only if return_trans is True,
        and always as the last element.

    When neither ``data`` nor ``return_trans`` is given, the bare code
    string is returned instead of a tuple.
    """
    if isinstance(transform, transforms.BlendedGenericTransform):
        warnings.warn("Blended transforms not yet supported. "
                      "Zoom behavior may not work as expected.")

    if force_trans is not None:
        if data is not None:
            data = (transform - force_trans).transform(data)
        transform = force_trans

    code = "display"
    if ax is not None:
        candidates = (
            ("data", ax.transData),
            ("axes", ax.transAxes),
            ("figure", ax.figure.transFigure),
            ("display", transforms.IdentityTransform()),
        )
        for name, base in candidates:
            if not transform.contains_branch(base):
                continue
            remaining = transform - base
            code, transform = name, remaining
            break

    # Assemble the result in the fixed order: code, data, transform.
    result = [code]
    if data is not None:
        result.append(transform.transform(data))
    if return_trans:
        result.append(transform)

    if len(result) == 1:
        return code
    return tuple(result)
def crawl_legend(self, ax, legend):
    """Draw every element of a legend through the exporter's draw methods.

    The elements are gathered with ``utils.iter_all_children`` from the
    legend's ``_legend_box`` (containers skipped), followed by
    ``legend.legendPatch``.  Every element is drawn with ``ax.transAxes``
    as the forced transform.

    Parameters
    ----------
    ax : matplotlib Axes object
        The axes the legend belongs to.
    legend : matplotlib Legend object
        The legend whose children are exported.

    Notes
    -----
    The z-order of each legend artist is raised in place via
    ``set_zorder``, so the matplotlib objects themselves are modified.
    Elements of an unsupported type, or whose draw call raises
    ``NotImplementedError``, produce a warning and are skipped.
    """
    not_implemented = "Legend element %s not implemented"

    children = list(utils.iter_all_children(legend._legend_box,
                                            skipContainers=True))
    children.append(legend.legendPatch)
    for child in children:
        # force a large zorder so it appears on top
        child.set_zorder(1E6 + child.get_zorder())

        # reorder border box to make sure marks are visible
        if isinstance(child, matplotlib.patches.FancyBboxPatch):
            child.set_zorder(child.get_zorder() - 1)

        try:
            # Dispatch on the kind of artist; the order of the checks is
            # significant because FancyBboxPatch is itself a Patch.
            if isinstance(child, matplotlib.patches.Patch):
                self.draw_patch(ax, child, force_trans=ax.transAxes)
            elif isinstance(child, matplotlib.text.Text):
                # Text whose content is the literal string 'None' is not
                # exported.
                if child.get_text() != 'None':
                    self.draw_text(ax, child, force_trans=ax.transAxes)
            elif isinstance(child, matplotlib.lines.Line2D):
                self.draw_line(ax, child, force_trans=ax.transAxes)
            elif isinstance(child, matplotlib.collections.Collection):
                self.draw_collection(ax, child,
                                     force_offsettrans=ax.transAxes)
            else:
                warnings.warn(not_implemented % child)
        except NotImplementedError:
            warnings.warn(not_implemented % child)
def draw_text(self, ax, text, force_trans=None, text_type=None):
    """Export a matplotlib text object through ``renderer.draw_text``.

    Nothing is sent to the renderer when the text content is empty.

    Parameters
    ----------
    ax : matplotlib Axes object
        The axes the text is associated with.
    text : matplotlib Text object
        The text artist to export.
    force_trans : matplotlib transform (optional)
        Passed on to ``process_transform`` as ``force_trans``.
    text_type : string or None (optional)
        Forwarded unchanged to the renderer.
    """
    content = text.get_text()
    if not content:
        return

    text_transform = text.get_transform()
    raw_position = text.get_position()
    coords, position = self.process_transform(text_transform, ax,
                                              raw_position,
                                              force_trans=force_trans)
    text_style = utils.get_text_style(text)
    self.renderer.draw_text(text=content,
                            position=position,
                            coordinates=coords,
                            text_type=text_type,
                            style=text_style,
                            mplobj=text)
def draw_image(self, ax, image):
    """Export a matplotlib image object through ``renderer.draw_image``.

    The image is encoded with ``utils.image_to_base64`` and reported in
    "data" coordinates together with its extent, alpha and zorder.

    Parameters
    ----------
    ax : matplotlib Axes object
        The axes containing the image (not used by this method).
    image : matplotlib image object
        The image artist to export.
    """
    encoded = utils.image_to_base64(image)
    extent = image.get_extent()
    image_style = {"alpha": image.get_alpha(),
                   "zorder": image.get_zorder()}
    self.renderer.draw_image(imdata=encoded,
                             extent=extent,
                             coordinates="data",
                             style=image_style,
                             mplobj=image)
The base renderer class is :class:`Renderer`, an 6 | abstract base class 7 | """ 8 | 9 | from .base import Renderer 10 | from .vega_renderer import VegaRenderer, fig_to_vega 11 | from .vincent_renderer import VincentRenderer, fig_to_vincent 12 | from .fake_renderer import FakeRenderer, FullFakeRenderer 13 | -------------------------------------------------------------------------------- /mplexporter/renderers/base.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import itertools 3 | from contextlib import contextmanager 4 | from packaging.version import Version 5 | 6 | import numpy as np 7 | import matplotlib as mpl 8 | from matplotlib import transforms 9 | 10 | from .. import utils 11 | from .. import _py3k_compat as py3k 12 | 13 | 14 | class Renderer(object): 15 | @staticmethod 16 | def ax_zoomable(ax): 17 | return bool(ax and ax.get_navigate()) 18 | 19 | @staticmethod 20 | def ax_has_xgrid(ax): 21 | return bool(ax and ax.xaxis._major_tick_kw['gridOn'] and ax.yaxis.get_gridlines()) 22 | 23 | @staticmethod 24 | def ax_has_ygrid(ax): 25 | return bool(ax and ax.yaxis._major_tick_kw['gridOn'] and ax.yaxis.get_gridlines()) 26 | 27 | @property 28 | def current_ax_zoomable(self): 29 | return self.ax_zoomable(self._current_ax) 30 | 31 | @property 32 | def current_ax_has_xgrid(self): 33 | return self.ax_has_xgrid(self._current_ax) 34 | 35 | @property 36 | def current_ax_has_ygrid(self): 37 | return self.ax_has_ygrid(self._current_ax) 38 | 39 | @contextmanager 40 | def draw_figure(self, fig, props): 41 | if hasattr(self, "_current_fig") and self._current_fig is not None: 42 | warnings.warn("figure embedded in figure: something is wrong") 43 | self._current_fig = fig 44 | self._fig_props = props 45 | self.open_figure(fig=fig, props=props) 46 | yield 47 | self.close_figure(fig=fig) 48 | self._current_fig = None 49 | self._fig_props = {} 50 | 51 | @contextmanager 52 | def draw_axes(self, ax, props): 53 | if 
@contextmanager
def draw_legend(self, legend, props):
    """Context manager bracketing the drawing of one legend.

    On entry the legend and its properties are stored on the renderer and
    ``open_legend`` is called; on normal exit ``close_legend`` is called.
    The stored legend state is cleared on exit whether or not the body
    (or ``open_legend`` / ``close_legend``) raised.

    Parameters
    ----------
    legend : matplotlib.legend.Legend
        The legend being drawn.
    props : dictionary
        The dictionary of legend properties.
    """
    self._current_legend = legend
    self._legend_props = props
    try:
        self.open_legend(legend=legend, props=props)
        yield
        # Only reached when the body completed without raising, so a
        # failed draw does not emit a closing command.
        self.close_legend(legend=legend)
    finally:
        # Always drop the per-legend state so an exception cannot leave a
        # stale legend attached to the renderer.
        self._current_legend = None
        self._legend_props = {}
def draw_marked_line(self, data, coordinates, linestyle, markerstyle,
                     label, mplobj=None):
    """Draw a line that may also carry markers.

    Unless a renderer overrides this method, the line part is delegated to
    ``draw_line`` when ``linestyle`` is not None and the marker part to
    ``draw_markers`` when ``markerstyle`` is not None, in that order.  Both
    are called when both styles are present.
    """
    dispatch = (("draw_line", linestyle),
                ("draw_markers", markerstyle))
    for method_name, part_style in dispatch:
        if part_style is None:
            continue
        draw_part = getattr(self, method_name)
        draw_part(data, coordinates, part_style, label, mplobj)
@staticmethod
def _iter_path_collection(paths, path_transforms, offsets, styles):
    """Build an iterator over the elements of the path collection.

    Each yielded tuple is ``(path, path_transform, offset, edgecolor,
    linewidth, facecolor)``.  Every input sequence is cycled, and the
    iterator stops after ``max(len(paths), len(offsets))`` elements.

    Parameters
    ----------
    paths : list
        The paths of the collection.
    path_transforms : sequence or None
        Per-path transformation matrices.  ``None`` or an empty sequence
        is replaced by a single 3x3 identity matrix.
    offsets : sequence
        The offsets of the collection.
    styles : dictionary
        Must provide 'edgecolor', 'facecolor' and 'linewidth'.  Empty
        edge or face colors are replaced by ``['none']``.
    """
    N = max(len(paths), len(offsets))

    # A missing or empty set of path transforms means "no extra transform".
    # itertools.cycle() over an empty sequence yields nothing, which would
    # make the zip below empty and silently drop every path, so substitute
    # a single identity matrix regardless of the matplotlib version.
    if path_transforms is None or np.size(path_transforms) == 0:
        path_transforms = [np.eye(3)]

    edgecolor = styles['edgecolor']
    if np.size(edgecolor) == 0:
        edgecolor = ['none']
    facecolor = styles['facecolor']
    if np.size(facecolor) == 0:
        facecolor = ['none']

    elements = [paths, path_transforms, offsets,
                edgecolor, styles['linewidth'], facecolor]

    # The py3k helpers are lazy on both Python 2 and 3, which is required
    # because the cycled inputs are infinite.
    it = itertools
    return it.islice(py3k.zip(*py3k.map(it.cycle, elements)), N)
226 | 227 | Parameters 228 | ---------- 229 | paths : list 230 | list of tuples, where each tuple has two elements: 231 | (data, pathcodes). See draw_path() for a description of these. 232 | path_coordinates: string 233 | the coordinates code for the paths, which should be either 234 | 'data' for data coordinates, or 'figure' for figure (pixel) 235 | coordinates. 236 | path_transforms: array_like 237 | an array of shape (*, 3, 3), giving a series of 2D Affine 238 | transforms for the paths. These encode translations, rotations, 239 | and scalings in the standard way. 240 | offsets: array_like 241 | An array of offsets of shape (N, 2) 242 | offset_coordinates : string 243 | the coordinates code for the offsets, which should be either 244 | 'data' for data coordinates, or 'figure' for figure (pixel) 245 | coordinates. 246 | offset_order : string 247 | either "before" or "after". This specifies whether the offset 248 | is applied before the path transform, or after. The matplotlib 249 | backend equivalent is "before"->"data", "after"->"screen". 250 | styles: dictionary 251 | A dictionary in which each value is a list of length N, containing 252 | the style(s) for the paths. 
def draw_markers(self, data, coordinates, style, label, mplobj=None):
    """
    Draw a set of markers by calling draw_path() once per data point.

    Renderers should generally overload this method to provide a more
    efficient implementation.

    Parameters
    ----------
    data : array_like
        The marker positions; each element is passed to draw_path() as
        the ``offset`` of one marker.
    coordinates : string
        The coordinate code of ``data``; forwarded to draw_path() as
        ``offset_coordinates``.
    style : dictionary
        Must provide 'markerpath' (a ``(vertices, pathcodes)`` pair) and
        the keys 'alpha', 'edgecolor', 'facecolor', 'zorder' and
        'edgewidth'.
    label : string
        Not used by this default implementation.
    mplobj : matplotlib object
        the matplotlib plot element which generated this marker collection
    """
    marker_vertices, marker_codes = style['markerpath']

    copied_keys = ('alpha', 'edgecolor', 'facecolor', 'zorder',
                   'edgewidth')
    pathstyle = {key: style[key] for key in copied_keys}
    pathstyle['dasharray'] = "10,0"

    # The marker shape itself is in "points"; only the offset lives in
    # the caller's coordinate system.
    for point in data:
        self.draw_path(data=marker_vertices,
                       coordinates="points",
                       pathcodes=marker_codes,
                       style=pathstyle,
                       offset=point,
                       offset_coordinates=coordinates,
                       mplobj=mplobj)
class FakeRenderer(Renderer):
    """
    Fake Renderer

    This is a fake renderer which simply builds a text listing of the
    elements found in the plot(s), one line per renderer call, in the
    ``output`` attribute.  It is used in the unit tests for the package.

    Below are the methods your renderer must implement. You are free to do
    anything you wish within the renderer (e.g. build an XML or JSON
    representation, call an external API, etc.)  Here the renderer just
    builds a simple string representation for testing purposes.
    """
    def __init__(self):
        # Accumulated text; every method below appends one line to it.
        self.output = ""

    def open_figure(self, fig, props):
        """Record that a figure was opened; ``fig`` and ``props`` are ignored."""
        self.output += "opening figure\n"

    def close_figure(self, fig):
        """Record that a figure was closed."""
        self.output += "closing figure\n"

    def open_axes(self, ax, props):
        """Record that an axes was opened; ``ax`` and ``props`` are ignored."""
        self.output += " opening axes\n"

    def close_axes(self, ax):
        """Record that an axes was closed."""
        self.output += " closing axes\n"

    def open_legend(self, legend, props):
        """Record that a legend was opened; the arguments are ignored."""
        self.output += " opening legend\n"

    def close_legend(self, legend):
        """Record that a legend was closed."""
        self.output += " closing legend\n"

    def draw_text(self, text, position, coordinates, style,
                  text_type=None, mplobj=None):
        """Record a text element with its content and ``text_type``."""
        self.output += " draw text '{0}' {1}\n".format(text, text_type)

    def draw_path(self, data, coordinates, pathcodes, style,
                  offset=None, offset_coordinates="data", mplobj=None):
        """Record a path with its vertex count (``data.shape[0]``)."""
        self.output += " draw path with {0} vertices\n".format(data.shape[0])

    def draw_image(self, imdata, extent, coordinates, style, mplobj=None):
        """Record an image with the length of its encoded data."""
        self.output += " draw image of size {0}\n".format(len(imdata))
class VegaRenderer(Renderer):
    """
    Renderer which collects the pieces of a Vega specification.

    Each callback appends to (or replaces) one of the attributes ``data``,
    ``scales``, ``axes`` and ``marks``, which are created in open_figure().
    Only a single axes and only 'data' coordinates are supported.
    """
    def open_figure(self, fig, props):
        """Reset the accumulated state and record the figure size.

        Parameters
        ----------
        fig : matplotlib figure (unused here)
        props : dictionary
            must contain 'figwidth', 'figheight' and 'dpi'; the stored
            width and height are ``int(size * dpi)``.
        """
        self.props = props
        self.figwidth = int(props['figwidth'] * props['dpi'])
        self.figheight = int(props['figheight'] * props['dpi'])
        self.data = []
        self.scales = []
        self.axes = []
        self.marks = []

    def open_axes(self, ax, props):
        """Define the x/y axes and linear scales from the axes limits.

        A second axes is not supported: a warning is emitted and the axes
        and scales of the previous one are replaced.

        Parameters
        ----------
        ax : matplotlib axes (unused here)
        props : dictionary
            must contain 'xlim' and 'ylim', used as the scale domains.
        """
        if len(self.axes) > 0:
            warnings.warn("multiple axes not yet supported")
        self.axes = [dict(type="x", scale="x", ticks=10),
                     dict(type="y", scale="y", ticks=10)]
        self.scales = [dict(name="x",
                            domain=props['xlim'],
                            type="linear",
                            range="width",
                            ),
                       dict(name="y",
                            domain=props['ylim'],
                            type="linear",
                            range="height",
                            ), ]

    def draw_line(self, data, coordinates, style, label, mplobj=None):
        """Add a data table and a 'line' mark for the given points.

        Parameters
        ----------
        data : array_like
            sequence of points; element 0 and 1 of each point are used as
            x and y.
        coordinates : string
            only 'data' is supported; anything else is skipped with a
            warning.
        style : dictionary
            'color', 'alpha' and 'linewidth' are used.
        label, mplobj : unused here
        """
        if coordinates != 'data':
            warnings.warn("Only data coordinates supported. Skipping this")
            # The warning promises a skip: without this return the points
            # would be plotted against the data scales regardless.
            return
        dataname = "table{0:03d}".format(len(self.data) + 1)

        # TODO: respect the other style settings
        self.data.append({'name': dataname,
                          'values': [dict(x=d[0], y=d[1]) for d in data]})
        self.marks.append({'type': 'line',
                           'from': {'data': dataname},
                           'properties': {
                               "enter": {
                                   "interpolate": {"value": "monotone"},
                                   "x": {"scale": "x", "field": "data.x"},
                                   "y": {"scale": "y", "field": "data.y"},
                                   "stroke": {"value": style['color']},
                                   "strokeOpacity": {"value": style['alpha']},
                                   "strokeWidth": {"value": style['linewidth']},
                               }
                           }
                           })

    def draw_markers(self, data, coordinates, style, label, mplobj=None):
        """Add a data table and a 'symbol' mark for the given points.

        Parameters
        ----------
        data : array_like
            sequence of points; element 0 and 1 of each point are used as
            x and y.
        coordinates : string
            only 'data' is supported; anything else is skipped with a
            warning.
        style : dictionary
            'facecolor', 'edgecolor', 'alpha' and 'edgewidth' are used.
        label, mplobj : unused here
        """
        if coordinates != 'data':
            warnings.warn("Only data coordinates supported. Skipping this")
            # See draw_line: actually skip what the warning says is skipped.
            return
        dataname = "table{0:03d}".format(len(self.data) + 1)

        # TODO: respect the other style settings
        self.data.append({'name': dataname,
                          'values': [dict(x=d[0], y=d[1]) for d in data]})
        self.marks.append({'type': 'symbol',
                           'from': {'data': dataname},
                           'properties': {
                               "enter": {
                                   "interpolate": {"value": "monotone"},
                                   "x": {"scale": "x", "field": "data.x"},
                                   "y": {"scale": "y", "field": "data.y"},
                                   "fill": {"value": style['facecolor']},
                                   "fillOpacity": {"value": style['alpha']},
                                   "stroke": {"value": style['edgecolor']},
                                   "strokeOpacity": {"value": style['alpha']},
                                   "strokeWidth": {"value": style['edgewidth']},
                               }
                           }
                           })

    def draw_text(self, text, position, coordinates, style,
                  text_type=None, mplobj=None):
        """Use x/y label text as the title of the matching axis.

        Text of any other ``text_type`` is ignored.
        """
        if text_type == 'xlabel':
            self.axes[0]['title'] = text
        elif text_type == 'ylabel':
            self.axes[1]['title'] = text
marks=renderer.marks) 96 | 97 | def html(self): 98 | """Build the HTML representation for IPython.""" 99 | id = random.randint(0, 2 ** 16) 100 | html = '
class VincentRenderer(Renderer):
    """
    Renderer which turns the first line or marker set into a vincent chart.

    The resulting chart (or None if nothing was drawn) is stored in
    ``self.chart``. Only one plot element and only 'data' coordinates are
    supported.
    """
    def open_figure(self, fig, props):
        """Clear the chart and record the figure size.

        Parameters
        ----------
        fig : matplotlib figure (unused here)
        props : dictionary
            must contain 'figwidth', 'figheight' and 'dpi'; the stored
            width and height are ``int(size * dpi)``.
        """
        self.chart = None
        self.figwidth = int(props['figwidth'] * props['dpi'])
        self.figheight = int(props['figheight'] * props['dpi'])

    def draw_line(self, data, coordinates, style, label, mplobj=None):
        """Build a ``vincent.Line`` chart from the given points.

        Parameters
        ----------
        data : array
            indexed as ``data[:, 0]`` (x) and ``data[:, 1]`` (y).
        coordinates : string
            only 'data' is supported; anything else is skipped with a
            warning.
        style : dictionary
            'color' is used for the color scale range.
        label, mplobj : unused here
        """
        import vincent  # only import if VincentRenderer is used
        if coordinates != 'data':
            warnings.warn("Only data coordinates supported. Skipping this")
            # The warning promises a skip: without this return the element
            # would still become the chart.
            return
        linedata = {'x': data[:, 0],
                    'y': data[:, 1]}
        line = vincent.Line(linedata, iter_idx='x',
                            width=self.figwidth, height=self.figheight)

        # TODO: respect the other style settings
        line.scales['color'].range = [style['color']]

        if self.chart is None:
            self.chart = line
        else:
            warnings.warn("Multiple plot elements not yet supported")

    def draw_markers(self, data, coordinates, style, label, mplobj=None):
        """Build a ``vincent.Scatter`` chart from the given points.

        Parameters
        ----------
        data : array
            indexed as ``data[:, 0]`` (x) and ``data[:, 1]`` (y).
        coordinates : string
            only 'data' is supported; anything else is skipped with a
            warning.
        style : dictionary
            'facecolor' is used for the color scale range.
        label, mplobj : unused here
        """
        import vincent  # only import if VincentRenderer is used
        if coordinates != 'data':
            warnings.warn("Only data coordinates supported. Skipping this")
            # See draw_line: actually skip what the warning says is skipped.
            return
        markerdata = {'x': data[:, 0],
                      'y': data[:, 1]}
        markers = vincent.Scatter(markerdata, iter_idx='x',
                                  width=self.figwidth, height=self.figheight)

        # TODO: respect the other style settings
        markers.scales['color'].range = [style['facecolor']]

        if self.chart is None:
            self.chart = markers
        else:
            warnings.warn("Multiple plot elements not yet supported")
def _assert_output_equal(text1, text2):
    """Assert that two renderer outputs are equal up to whitespace.

    Both texts are stripped and split on whitespace, so indentation and
    line breaks are ignored and the comparison is token by token.

    Parameters
    ----------
    text1, text2 : string
        the outputs to compare

    Raises
    ------
    AssertionError
        if the token counts differ or any pair of tokens differs.
    """
    tokens1 = text1.strip().split()
    tokens2 = text2.strip().split()
    # zip() stops at the shorter input, so without this check a truncated
    # (or over-long) output would compare equal.
    assert len(tokens1) == len(tokens2), (
        "outputs differ in length: {0} != {1} tokens".format(len(tokens1),
                                                             len(tokens2)))
    for token1, token2 in zip(tokens1, tokens2):
        assert token1 == token2
closing axes 92 | closing figure 93 | """) 94 | 95 | 96 | def test_text(): 97 | fig, ax = plt.subplots() 98 | ax.set_xlabel("my x label") 99 | ax.set_ylabel("my y label") 100 | ax.set_title("my title") 101 | ax.text(0.5, 0.5, "my text") 102 | 103 | _assert_output_equal(fake_renderer_output(fig, FakeRenderer), 104 | """ 105 | opening figure 106 | opening axes 107 | draw text 'my text' None 108 | draw text 'my x label' xlabel 109 | draw text 'my y label' ylabel 110 | draw text 'my title' title 111 | closing axes 112 | closing figure 113 | """) 114 | 115 | 116 | def test_path(): 117 | fig, ax = plt.subplots() 118 | ax.add_patch(plt.Circle((0, 0), 1)) 119 | ax.add_patch(plt.Rectangle((0, 0), 1, 2)) 120 | 121 | _assert_output_equal(fake_renderer_output(fig, FakeRenderer), 122 | """ 123 | opening figure 124 | opening axes 125 | draw path with 25 vertices 126 | draw path with 4 vertices 127 | closing axes 128 | closing figure 129 | """) 130 | 131 | def test_Figure(): 132 | """ if the fig is not associated with a canvas, FakeRenderer shall 133 | not fail. 
""" 134 | fig = plt.Figure() 135 | ax = fig.add_subplot(111) 136 | ax.add_patch(plt.Circle((0, 0), 1)) 137 | ax.add_patch(plt.Rectangle((0, 0), 1, 2)) 138 | 139 | _assert_output_equal(fake_renderer_output(fig, FakeRenderer), 140 | """ 141 | opening figure 142 | opening axes 143 | draw path with 25 vertices 144 | draw path with 4 vertices 145 | closing axes 146 | closing figure 147 | """) 148 | 149 | def test_multiaxes(): 150 | fig, ax = plt.subplots(2) 151 | ax[0].plot(range(4)) 152 | ax[1].plot(range(10)) 153 | 154 | _assert_output_equal(fake_renderer_output(fig, FakeRenderer), 155 | """ 156 | opening figure 157 | opening axes 158 | draw path with 4 vertices 159 | closing axes 160 | opening axes 161 | draw path with 10 vertices 162 | closing axes 163 | closing figure 164 | """) 165 | 166 | 167 | def test_image(): 168 | # Test fails for matplotlib 1.5+ because the size of the image 169 | # generated by matplotlib has changed. 170 | if Version(matplotlib.__version__) >= Version('1.5.0'): 171 | raise SkipTest("Test fails for matplotlib version > 1.5.0"); 172 | np.random.seed(0) # image size depends on the seed 173 | fig, ax = plt.subplots(figsize=(2, 2)) 174 | ax.imshow(np.random.random((10, 10)), 175 | cmap=plt.cm.jet, interpolation='nearest') 176 | _assert_output_equal(fake_renderer_output(fig, FakeRenderer), 177 | """ 178 | opening figure 179 | opening axes 180 | draw image of size 1240 181 | closing axes 182 | closing figure 183 | """) 184 | 185 | 186 | def test_legend(): 187 | fig, ax = plt.subplots() 188 | ax.plot([1, 2, 3], label='label') 189 | ax.legend().set_visible(False) 190 | _assert_output_equal(fake_renderer_output(fig, FakeRenderer), 191 | """ 192 | opening figure 193 | opening axes 194 | draw path with 3 vertices 195 | opening legend 196 | closing legend 197 | closing axes 198 | closing figure 199 | """) 200 | 201 | 202 | def test_legend_dots(): 203 | raise SkipTest("This works visually so skipping for now") 204 | fig, ax = plt.subplots() 205 | 
ax.plot([1, 2, 3], label='label') 206 | ax.plot([2, 2, 2], 'o', label='dots') 207 | ax.legend().set_visible(True) 208 | _assert_output_equal(fake_renderer_output(fig, FullFakeRenderer), 209 | """ 210 | opening figure 211 | opening axes 212 | draw line with 3 points 213 | draw 3 markers 214 | opening legend 215 | draw line with 2 points 216 | draw text 'label' None 217 | draw 2 markers 218 | draw text 'dots' None 219 | draw path with 4 vertices 220 | closing legend 221 | closing axes 222 | closing figure 223 | """) 224 | 225 | 226 | def test_blended(): 227 | fig, ax = plt.subplots() 228 | ax.axvline(0) 229 | #assert_warns(UserWarning, fake_renderer_output, fig, FakeRenderer) 230 | -------------------------------------------------------------------------------- /mplexporter/tests/test_convertors.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from matplotlib import ticker 3 | from .. import StrMethodTickFormatterConvertor 4 | 5 | class TickFormatConvertorTestCase(unittest.TestCase): 6 | 7 | def test_001_format_strings(self): 8 | conversions = [ 9 | ("{x}", {'format_string': '', 'prefix': '', 'suffix': ''}), 10 | ("{x:#x}", {'format_string': '#x', 'prefix': '', 'suffix': ''}), 11 | ("{x:.2f}", {'format_string': '.2f', 'prefix': '', 'suffix': ''}), 12 | ("{x:.2%}", {'format_string': '.2%', 'prefix': '', 'suffix': ''}), 13 | ("P{x:.2%}", {'format_string': '.2%', 'prefix': 'P', 'suffix': ''}), 14 | ("P{x:.2%} 100", {'format_string': '.2%', 'prefix': 'P', 'suffix': ' 100'}), 15 | ] 16 | for mpl_fmt, d3_fmt in conversions: 17 | formatter = ticker.StrMethodFormatter(mpl_fmt) 18 | cnvrt = StrMethodTickFormatterConvertor(formatter) 19 | self.assertEqual(cnvrt.output, d3_fmt) 20 | -------------------------------------------------------------------------------- /mplexporter/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from numpy.testing import 
assert_allclose, assert_equal 2 | from . import plt 3 | from .. import utils 4 | 5 | 6 | def test_path_data(): 7 | circle = plt.Circle((0, 0), 1) 8 | vertices, codes = utils.SVG_path(circle.get_path()) 9 | 10 | assert_allclose(vertices.shape, (25, 2)) 11 | assert_equal(codes, ['M', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'Z']) 12 | 13 | 14 | def test_linestyle(): 15 | linestyles = {'solid': 'none', '-': 'none', 16 | #'dashed': '6,6', '--': '6,6', 17 | #'dotted': '2,2', ':': '2,2', 18 | #'dashdot': '4,4,2,4', '-.': '4,4,2,4', 19 | '': None, 'None': None} 20 | 21 | for ls, result in linestyles.items(): 22 | line, = plt.plot([1, 2, 3], linestyle=ls) 23 | assert_equal(utils.get_dasharray(line), result) 24 | 25 | 26 | def test_axis_w_fixed_formatter(): 27 | positions, labels = [0, 1, 10], ['A','B','C'] 28 | 29 | plt.xticks(positions, labels) 30 | props = utils.get_axis_properties(plt.gca().xaxis) 31 | 32 | assert_equal(props['tickvalues'], positions) 33 | # NOTE: Issue #471 34 | # assert_equal(props['tickformat'], labels) 35 | 36 | -------------------------------------------------------------------------------- /mplexporter/tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tools for matplotlib plot exporting 3 | """ 4 | 5 | 6 | def ipynb_vega_init(): 7 | """Initialize the IPython notebook display elements 8 | 9 | This function borrows heavily from the excellent vincent package: 10 | http://github.com/wrobstory/vincent 11 | """ 12 | try: 13 | from IPython.core.display import display, HTML 14 | except ImportError: 15 | print('IPython Notebook could not be loaded.') 16 | 17 | require_js = ''' 18 | if (window['d3'] === undefined) {{ 19 | require.config({{ paths: {{d3: "http://d3js.org/d3.v3.min"}} }}); 20 | require(["d3"], function(d3) {{ 21 | window.d3 = d3; 22 | {0} 23 | }}); 24 | }}; 25 | if (window['topojson'] === undefined) {{ 26 | require.config( 27 | {{ paths: {{topojson: "http://d3js.org/topojson.v1.min"}} }} 28 | 
def export_color(color):
    """Convert matplotlib color code to hex color or RGBA color

    Parameters
    ----------
    color : matplotlib color specification or None

    Returns
    -------
    string
        'none' if color is None or fully transparent (alpha == 0),
        '#RRGGBB' if it is fully opaque (alpha == 1),
        'rgba(r, g, b, a)' with 0-255 integer channels otherwise.
    """
    if color is None:
        return 'none'

    # Convert once and reuse; to_rgba gives four floats in [0, 1].
    red, green, blue, alpha = colorConverter.to_rgba(color)
    if alpha == 0:
        return 'none'

    # Round (rather than truncate) so that both output forms agree and
    # float error in e.g. 31/255 * 255 cannot produce an off-by-one channel.
    channels = [int(np.round(val * 255)) for val in (red, green, blue)]
    if alpha == 1:
        return '#{0:02X}{1:02X}{2:02X}'.format(*channels)
    return ("rgba(" + ", ".join(str(val) for val in channels)
            + ', ' + str(alpha) + ")")
Note that some Path 95 | codes require multiple vertices, so the length of these vertices may 96 | be longer than the list of path codes. 97 | path_codes : list 98 | A length N list of single-character path codes, N <= M. Each code is 99 | a single character, in ['L','M','S','C','Z']. See the standard SVG 100 | path specification for a description of these. 101 | """ 102 | if transform is not None: 103 | path = path.transformed(transform) 104 | 105 | vc_tuples = [(vertices if path_code != Path.CLOSEPOLY else [], 106 | PATH_DICT[path_code]) 107 | for (vertices, path_code) 108 | in path.iter_segments(simplify=simplify)] 109 | 110 | if not vc_tuples: 111 | # empty path is a special case 112 | return np.zeros((0, 2)), [] 113 | else: 114 | vertices, codes = zip(*vc_tuples) 115 | vertices = np.array(list(itertools.chain(*vertices))).reshape(-1, 2) 116 | return vertices, list(codes) 117 | 118 | 119 | def get_path_style(path, fill=True): 120 | """Get the style dictionary for matplotlib path objects""" 121 | style = {} 122 | style['alpha'] = path.get_alpha() 123 | if style['alpha'] is None: 124 | style['alpha'] = 1 125 | style['edgecolor'] = export_color(path.get_edgecolor()) 126 | if fill: 127 | style['facecolor'] = export_color(path.get_facecolor()) 128 | else: 129 | style['facecolor'] = 'none' 130 | style['edgewidth'] = path.get_linewidth() 131 | style['dasharray'] = get_dasharray(path) 132 | style['zorder'] = path.get_zorder() 133 | return style 134 | 135 | 136 | def get_line_style(line): 137 | """Get the style dictionary for matplotlib line objects""" 138 | style = {} 139 | style['alpha'] = line.get_alpha() 140 | if style['alpha'] is None: 141 | style['alpha'] = 1 142 | style['color'] = export_color(line.get_color()) 143 | style['linewidth'] = line.get_linewidth() 144 | style['dasharray'] = get_dasharray(line) 145 | style['zorder'] = line.get_zorder() 146 | style['drawstyle'] = line.get_drawstyle() 147 | return style 148 | 149 | 150 | def get_marker_style(line): 151 | 
"""Get the style dictionary for matplotlib marker objects""" 152 | style = {} 153 | style['alpha'] = line.get_alpha() 154 | if style['alpha'] is None: 155 | style['alpha'] = 1 156 | 157 | style['facecolor'] = export_color(line.get_markerfacecolor()) 158 | style['edgecolor'] = export_color(line.get_markeredgecolor()) 159 | style['edgewidth'] = line.get_markeredgewidth() 160 | 161 | style['marker'] = line.get_marker() 162 | markerstyle = MarkerStyle(line.get_marker()) 163 | markersize = line.get_markersize() 164 | markertransform = (markerstyle.get_transform() 165 | + Affine2D().scale(markersize, -markersize)) 166 | style['markerpath'] = SVG_path(markerstyle.get_path(), 167 | markertransform) 168 | style['markersize'] = markersize 169 | style['zorder'] = line.get_zorder() 170 | return style 171 | 172 | 173 | def get_text_style(text): 174 | """Return the text style dict for a text instance""" 175 | style = {} 176 | style['alpha'] = text.get_alpha() 177 | if style['alpha'] is None: 178 | style['alpha'] = 1 179 | style['fontsize'] = text.get_size() 180 | style['color'] = export_color(text.get_color()) 181 | style['halign'] = text.get_horizontalalignment() # left, center, right 182 | style['valign'] = text.get_verticalalignment() # baseline, center, top 183 | style['malign'] = text._multialignment # text alignment when '\n' in text 184 | style['rotation'] = text.get_rotation() 185 | style['zorder'] = text.get_zorder() 186 | return style 187 | 188 | 189 | def get_axis_properties(axis): 190 | """Return the property dictionary for a matplotlib.Axis instance""" 191 | props = {} 192 | label1On = axis._major_tick_kw.get('label1On', True) 193 | 194 | if isinstance(axis, matplotlib.axis.XAxis): 195 | if label1On: 196 | props['position'] = "bottom" 197 | else: 198 | props['position'] = "top" 199 | elif isinstance(axis, matplotlib.axis.YAxis): 200 | if label1On: 201 | props['position'] = "left" 202 | else: 203 | props['position'] = "right" 204 | else: 205 | raise ValueError("{0} 
should be an Axis instance".format(axis)) 206 | 207 | # Use tick values if appropriate 208 | locator = axis.get_major_locator() 209 | props['nticks'] = len(locator()) 210 | if isinstance(locator, ticker.FixedLocator): 211 | props['tickvalues'] = list(locator()) 212 | else: 213 | props['tickvalues'] = None 214 | 215 | # Find tick formats 216 | props['tickformat_formatter'] = "" 217 | formatter = axis.get_major_formatter() 218 | if isinstance(formatter, ticker.NullFormatter): 219 | props['tickformat'] = "" 220 | elif isinstance(formatter, ticker.StrMethodFormatter): 221 | convertor = StrMethodTickFormatterConvertor(formatter) 222 | props['tickformat'] = convertor.output 223 | props['tickformat_formatter'] = "str_method" 224 | elif isinstance(formatter, ticker.PercentFormatter): 225 | props['tickformat'] = { 226 | "xmax": formatter.xmax, 227 | "decimals": formatter.decimals, 228 | "symbol": formatter.symbol, 229 | } 230 | props['tickformat_formatter'] = "percent" 231 | elif hasattr(ticker, 'IndexFormatter') and isinstance(formatter, ticker.IndexFormatter): 232 | # IndexFormatter was dropped in matplotlib 3.5 233 | props['tickformat'] = [text.get_text() for text in axis.get_ticklabels()] 234 | props['tickformat_formatter'] = "index" 235 | elif isinstance(formatter, ticker.FixedFormatter): 236 | props['tickformat'] = list(formatter.seq) 237 | props['tickformat_formatter'] = "fixed" 238 | elif not any(label.get_visible() for label in axis.get_ticklabels()): 239 | props['tickformat'] = "" 240 | else: 241 | props['tickformat'] = None 242 | 243 | # Get axis scale 244 | props['scale'] = axis.get_scale() 245 | 246 | # Get major tick label size (assumes that's all we really care about!) 
def get_grid_style(axis):
    """Describe the major grid of a matplotlib axis.

    Returns ``{"gridOn": False}`` when the major-tick 'gridOn' setting is
    off or the axis has no grid lines. Otherwise the color, alpha and dash
    array are read from the first grid line and returned together with
    ``gridOn=True``.
    """
    lines = axis.get_gridlines()
    if not axis._major_tick_kw['gridOn'] or len(lines) == 0:
        return {"gridOn": False}

    # The first grid line is taken as representative of all of them.
    first = lines[0]
    color = export_color(first.get_color())
    alpha = first.get_alpha()
    dasharray = get_dasharray(first)
    return {'gridOn': True,
            'color': color,
            'dasharray': dasharray,
            'alpha': alpha}
def iter_all_children(obj, skipContainers=False):
    """
    Returns an iterator over all children and nested children using
    obj's get_children() method

    Each object is yielded once, parents before their descendants.
    If obj itself has no children (or no get_children() method), obj is
    the only item yielded.

    if skipContainers is true, only childless objects are returned.
    """
    children = obj.get_children() if hasattr(obj, 'get_children') else ()
    if len(children) == 0:
        yield obj
        return

    for child in children:
        # A childless child is yielded by the recursive call below, so only
        # containers are yielded here; yielding every child at this point
        # would report each leaf twice.
        is_container = (hasattr(child, 'get_children')
                        and len(child.get_children()) > 0)
        if is_container and not skipContainers:
            yield child
        # could use `yield from` in python 3...
        for descendant in iter_all_children(child, skipContainers):
            yield descendant
374 | 375 | Returns 376 | ------- 377 | image_base64 : string 378 | The UTF8-encoded base64 string representation of the png image. 379 | """ 380 | ax = image.axes 381 | binary_buffer = io.BytesIO() 382 | 383 | # image is saved in axes coordinates: we need to temporarily 384 | # set the correct limits to get the correct image 385 | lim = ax.axis() 386 | ax.axis(image.get_extent()) 387 | image.write_png(binary_buffer) 388 | ax.axis(lim) 389 | 390 | binary_buffer.seek(0) 391 | return base64.b64encode(binary_buffer.read()).decode('utf-8') 392 | -------------------------------------------------------------------------------- /notebooks/VegaTest.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "name": "" 4 | }, 5 | "nbformat": 3, 6 | "nbformat_minor": 0, 7 | "worksheets": [ 8 | { 9 | "cells": [ 10 | { 11 | "cell_type": "heading", 12 | "level": 1, 13 | "metadata": {}, 14 | "source": [ 15 | "Matplotlib to Vega Example" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "This notebook contains some examples of converting matplotlib plots to vega." 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "collapsed": false, 28 | "input": [ 29 | "%matplotlib inline\n", 30 | "import matplotlib.pyplot as plt\n", 31 | "import numpy as np\n", 32 | "from mplexporter.renderers import fig_to_vega" 33 | ], 34 | "language": "python", 35 | "metadata": {}, 36 | "outputs": [], 37 | "prompt_number": 1 38 | }, 39 | { 40 | "cell_type": "code", 41 | "collapsed": false, 42 | "input": [ 43 | "from mplexporter.tools import ipynb_vega_init\n", 44 | "ipynb_vega_init()" 45 | ], 46 | "language": "python", 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "html": [ 51 | "" 76 | ], 77 | "metadata": {}, 78 | "output_type": "display_data", 79 | "text": [ 80 | "