├── .gitignore ├── .pre-commit-config.yaml ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── altair_pandas ├── __init__.py ├── _core.py ├── _misc.py ├── conftest.py └── test_plotting.py ├── images └── example.png ├── requirements.txt ├── requirements_dev.txt ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 
98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | 107 | 108 | # vscode 109 | .vscode/ 110 | 111 | #examples 112 | examples/*.ipynb 113 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: setup.py 2 | repos: 3 | - repo: https://github.com/psf/black 4 | rev: stable 5 | hooks: 6 | - id: black 7 | language_version: python3 -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | matrix: 4 | include: 5 | - python: 3.6 6 | - python: 3.7 7 | - python: 3.8 8 | 9 | before_install: 10 | - pip install pip --upgrade; 11 | - pip install -r requirements_dev.txt 12 | 13 | install: 14 | - pip install -e .; 15 | 16 | script: 17 | - black --check . 18 | - flake8 altair_pandas; 19 | - python -m pytest --backend_name=altair altair_pandas; 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Altair 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. 
Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.md 2 | include LICENSE 3 | include requirements.txt 4 | recursive-include altair_pandas *.py 5 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: install 2 | 3 | install: 4 | python setup.py install 5 | 6 | test : 7 | black . 
8 | flake8 altair_pandas 9 | python -m pytest --pyargs --doctest-modules altair_pandas 10 | 11 | test-coverage: 12 | python -m pytest --pyargs --doctest-modules --cov=altair_pandas --cov-report term altair_pandas 13 | 14 | test-coverage-html: 15 | python -m pytest --pyargs --doctest-modules --cov=altair_pandas --cov-report html altair_pandas 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # altair_pandas 2 | 3 | [![build status](http://img.shields.io/travis/altair-viz/altair_pandas/master.svg?style=flat)](https://travis-ci.org/altair-viz/altair_pandas) 4 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 5 | 6 | Altair backend for pandas plotting functions. 7 | 8 | **Note: this package is a work in progress** 9 | 10 | ## Installation 11 | Altair pandas backend works with pandas version 0.25.1 or newer. 12 | ``` 13 | $ pip install git+https://github.com/altair-viz/altair_pandas 14 | $ pip install -U pandas 15 | ``` 16 | 17 | ## Usage 18 | In a Jupyter notebook with [Altair](http://altair-viz.github.io) properly configured: 19 | ```python 20 | import pandas as pd 21 | import numpy as np 22 | pd.set_option('plotting.backend', 'altair') # Installing altair_pandas registers this. 
23 | 24 | data = pd.Series(np.random.randn(100).cumsum()) 25 | data.plot() 26 | ``` 27 | ![Altair-Pandas Visualization](https://raw.githubusercontent.com/altair-viz/altair_pandas/master/images/example.png) 28 | 29 | The goal of this package is to implement all of [Pandas' Plotting API](https://pandas.pydata.org/pandas-docs/stable/user_guide/visualization.html) -------------------------------------------------------------------------------- /altair_pandas/__init__.py: -------------------------------------------------------------------------------- 1 | """Altair plotting extension for pandas.""" 2 | __version__ = "0.1.0dev0" 3 | __all__ = ["boxplot_frame", "plot", "hist_frame", "hist_series", "scatter_matrix"] 4 | 5 | from ._core import boxplot_frame, plot, hist_frame, hist_series 6 | from ._misc import scatter_matrix 7 | -------------------------------------------------------------------------------- /altair_pandas/_core.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import altair as alt 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | def _valid_column(column_name): 9 | """Return a valid column name.""" 10 | return str(column_name) 11 | 12 | 13 | def _get_fontsize(size_name): 14 | """Return a fontsize based on matplotlib labels.""" 15 | font_sizes = { 16 | "xx-small": 5.79, 17 | "x-small": 6.94, 18 | "small": 8.33, 19 | "medium": 10.0, 20 | "large": 12.0, 21 | "x-large": 14.4, 22 | "xx-large": 17.28, 23 | "larger": 12.0, 24 | "smaller": 8.33, 25 | } 26 | return font_sizes[size_name] 27 | 28 | 29 | def _get_layout(panels, layout=None): 30 | """Compute the layout for a gridded chart. 31 | 32 | Parameters 33 | ---------- 34 | panels : int 35 | Number of panels in the chart. 36 | layout : tuple of ints 37 | Control the layout. Negative entries will be inferred 38 | from the number of panels. 
39 | 40 | Returns 41 | ------- 42 | nrows, ncols : int, int 43 | number of rows and columns in the resulting layout. 44 | 45 | Examples 46 | -------- 47 | >>> _get_layout(6, (2, 3)) 48 | (2, 3) 49 | >>> _get_layout(6, (1, -1)) 50 | (1, 6) 51 | >>> _get_layout(6, (-1, 2)) 52 | (3, 2) 53 | """ 54 | if layout is None: 55 | layout = (-1, 2) 56 | if len(layout) != 2: 57 | raise ValueError("layout should have two elements") 58 | if layout[0] < 0 and layout[1] < 0: 59 | raise ValueError("At least one dimension of layout must be positive") 60 | if layout[0] < 0: 61 | layout = (int(np.ceil(panels / layout[1])), layout[1]) 62 | if layout[1] < 0: 63 | layout = (layout[0], int(np.ceil(panels / layout[0]))) 64 | if panels > layout[0] * layout[1]: 65 | raise ValueError(f"layout {layout[0]}x{layout[1]} must be larger than {panels}") 66 | return layout 67 | 68 | 69 | class _PandasPlotter: 70 | """Base class for pandas plotting.""" 71 | 72 | @classmethod 73 | def create(cls, data): 74 | if isinstance(data, pd.Series): 75 | return _SeriesPlotter(data) 76 | elif isinstance(data, pd.DataFrame): 77 | return _DataFramePlotter(data) 78 | else: 79 | raise NotImplementedError(f"data of type {type(data)}") 80 | 81 | def _get_mark_def(self, mark, kwargs): 82 | if isinstance(mark, str): 83 | mark = {"type": mark} 84 | if isinstance(kwargs.get("alpha"), float): 85 | mark["opacity"] = kwargs.pop("alpha") 86 | if isinstance(kwargs.get("color"), str): 87 | mark["color"] = kwargs.pop("color") 88 | return mark 89 | 90 | def _kde(self, data, bw_method=None, ind=None, **kwargs): 91 | if bw_method == "scott" or bw_method is None: 92 | bandwidth = 0 93 | elif bw_method == "silverman": 94 | # Implementation taken from 95 | # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.gaussian_kde.html 96 | n = data.shape[0] 97 | d = 1 98 | bandwidth = (n * (d + 2) / 4.0) ** (-1.0 / (d + 4)) 99 | elif callable(bw_method): 100 | if 1 < data.shape[1]: 101 | warnings.warn( 102 | "Using a callable 
argument for ind using the Altair" 103 | " plotting backend sets the bandwidth for all" 104 | " columns", 105 | category=UserWarning, 106 | ) 107 | bandwidth = bw_method(data) 108 | else: 109 | bandwidth = bw_method 110 | 111 | if ind is None: 112 | steps = 1_000 113 | elif isinstance(ind, np.ndarray): 114 | warnings.warn( 115 | "The Altair plotting backend does not support sequences for ind", 116 | category=UserWarning, 117 | ) 118 | steps = 1_000 119 | else: 120 | steps = ind 121 | 122 | chart = ( 123 | alt.Chart(data, mark=self._get_mark_def("area", kwargs)) 124 | .transform_fold( 125 | data.columns.to_numpy(), 126 | as_=["Column", "value"], 127 | ) 128 | .transform_density( 129 | density="value", 130 | bandwidth=bandwidth, 131 | groupby=["Column"], 132 | # Manually setting domain to min and max makes kde look 133 | # more uniform 134 | extent=[data.min().min(), data.max().max()], 135 | steps=steps, 136 | ) 137 | .encode( 138 | x=alt.X("value", type="quantitative"), 139 | y=alt.Y("density", type="quantitative", stack="zero"), 140 | tooltip=[ 141 | alt.Tooltip("value", type="quantitative"), 142 | alt.Tooltip("density", type="quantitative"), 143 | alt.Tooltip("Column", type="nominal"), 144 | ], 145 | ) 146 | .interactive() 147 | ) 148 | # If there is only one column, do not encode color so that user 149 | # can pass optional color kwarg into mark 150 | if 1 < data.shape[1]: 151 | chart = chart.encode(color=alt.Color("Column", type="nominal")) 152 | return chart 153 | 154 | 155 | class _SeriesPlotter(_PandasPlotter): 156 | """Functionality for plotting of pandas Series.""" 157 | 158 | def __init__(self, data): 159 | if not isinstance(data, pd.Series): 160 | raise ValueError(f"data: expected pd.Series; got {type(data)}") 161 | self._data = data 162 | 163 | def _preprocess_data(self, with_index=True): 164 | # TODO: do this without copy? 
165 | data = self._data 166 | if with_index: 167 | if isinstance(data.index, pd.MultiIndex): 168 | data = data.copy() 169 | data.index = pd.Index( 170 | [str(i) for i in data.index], name=data.index.name 171 | ) 172 | data = data.reset_index() 173 | else: 174 | data = data.to_frame() 175 | # Column names must all be strings. 176 | return data.rename(columns=_valid_column) 177 | 178 | def _xy(self, mark, **kwargs): 179 | data = self._preprocess_data(with_index=True) 180 | return ( 181 | alt.Chart(data, mark=self._get_mark_def(mark, kwargs)) 182 | .encode( 183 | x=alt.X(data.columns[0], title=None), 184 | y=alt.Y(data.columns[1], title=None), 185 | tooltip=list(data.columns), 186 | ) 187 | .interactive() 188 | ) 189 | 190 | def line(self, **kwargs): 191 | return self._xy("line", **kwargs) 192 | 193 | def bar(self, **kwargs): 194 | return self._xy({"type": "bar", "orient": "vertical"}, **kwargs) 195 | 196 | def barh(self, **kwargs): 197 | chart = self._xy({"type": "bar", "orient": "horizontal"}, **kwargs) 198 | chart.encoding.x, chart.encoding.y = chart.encoding.y, chart.encoding.x 199 | return chart 200 | 201 | def area(self, **kwargs): 202 | return self._xy(mark="area", **kwargs) 203 | 204 | def scatter(self, **kwargs): 205 | raise ValueError("kind='scatter' can only be used for DataFrames.") 206 | 207 | def hist(self, bins=None, orientation="vertical", **kwargs): 208 | data = self._preprocess_data(with_index=False) 209 | column = data.columns[0] 210 | if isinstance(bins, int): 211 | bins = alt.Bin(maxbins=bins) 212 | elif bins is None: 213 | bins = True 214 | if orientation == "vertical": 215 | Indep, Dep = alt.X, alt.Y 216 | elif orientation == "horizontal": 217 | Indep, Dep = alt.Y, alt.X 218 | else: 219 | raise ValueError("orientation must be 'horizontal' or 'vertical'.") 220 | 221 | mark = self._get_mark_def({"type": "bar", "orient": orientation}, kwargs) 222 | return alt.Chart(data, mark=mark).encode( 223 | Indep(column, title=None, bin=bins), Dep("count()", 
title="Frequency") 224 | ) 225 | 226 | def hist_series(self, **kwargs): 227 | return self.hist(**kwargs) 228 | 229 | def box(self, vert=True, **kwargs): 230 | data = self._preprocess_data(with_index=False) 231 | chart = ( 232 | alt.Chart(data) 233 | .transform_fold(list(data.columns), as_=["column", "value"]) 234 | .mark_boxplot() 235 | .encode(x=alt.X("column:N", title=None), y="value:Q") 236 | ) 237 | if not vert: 238 | chart.encoding.x, chart.encoding.y = chart.encoding.y, chart.encoding.x 239 | return chart 240 | 241 | def kde(self, bw_method=None, ind=None, **kwargs): 242 | data = self._preprocess_data(with_index=False) 243 | return self._kde(data, bw_method=bw_method, ind=ind, **kwargs) 244 | 245 | 246 | class _DataFramePlotter(_PandasPlotter): 247 | """Functionality for plotting of pandas DataFrames.""" 248 | 249 | def __init__(self, data): 250 | if not isinstance(data, pd.DataFrame): 251 | raise ValueError(f"data: expected pd.DataFrame; got {type(data)}") 252 | self._data = data 253 | 254 | def _preprocess_data(self, with_index=True, usecols=None): 255 | data = self._data.rename(columns=_valid_column) 256 | if usecols is not None: 257 | data = data[usecols] 258 | if with_index: 259 | if isinstance(data.index, pd.MultiIndex): 260 | data.index = pd.Index( 261 | [str(i) for i in data.index], name=data.index.name 262 | ) 263 | return data.reset_index() 264 | return data 265 | 266 | def _xy(self, mark, x=None, y=None, stacked=False, subplots=False, **kwargs): 267 | data = self._preprocess_data(with_index=True) 268 | 269 | if x is None: 270 | x = data.columns[0] 271 | else: 272 | x = _valid_column(x) 273 | assert x in data.columns 274 | 275 | if y is None: 276 | y_values = list(data.columns[1:]) 277 | else: 278 | y = _valid_column(y) 279 | assert y in data.columns 280 | y_values = [y] 281 | 282 | chart = ( 283 | alt.Chart(data, mark=self._get_mark_def(mark, kwargs)) 284 | .transform_fold(y_values, as_=["column", "value"]) 285 | .encode( 286 | x=x, 287 | 
y=alt.Y("value:Q", title=None, stack=stacked), 288 | color=alt.Color("column:N", title=None), 289 | tooltip=[x] + y_values, 290 | ) 291 | .interactive() 292 | ) 293 | 294 | if subplots: 295 | nrows, ncols = _get_layout(len(y_values), kwargs.get("layout", (-1, 1))) 296 | chart = chart.encode(facet=alt.Facet("column:N", title=None)).properties( 297 | columns=ncols 298 | ) 299 | 300 | return chart 301 | 302 | def line(self, x=None, y=None, **kwargs): 303 | return self._xy("line", x, y, **kwargs) 304 | 305 | def area(self, x=None, y=None, stacked=True, **kwargs): 306 | mark = "area" if stacked else {"type": "area", "line": True, "opacity": 0.5} 307 | return self._xy(mark, x, y, stacked, **kwargs) 308 | 309 | # TODO: bars should be grouped, not stacked. 310 | def bar(self, x=None, y=None, **kwargs): 311 | return self._xy({"type": "bar", "orient": "vertical"}, x, y, **kwargs) 312 | 313 | def barh(self, x=None, y=None, **kwargs): 314 | chart = self._xy({"type": "bar", "orient": "horizontal"}, x, y, **kwargs) 315 | chart.encoding.x, chart.encoding.y = chart.encoding.y, chart.encoding.x 316 | return chart 317 | 318 | def scatter(self, x, y, c=None, s=None, **kwargs): 319 | if x is None or y is None: 320 | raise ValueError("kind='scatter' requires 'x' and 'y' arguments.") 321 | encodings = {"x": _valid_column(x), "y": _valid_column(y)} 322 | if c is not None: 323 | encodings["color"] = _valid_column(c) 324 | if s is not None: 325 | encodings["size"] = _valid_column(s) 326 | columns = list(set(encodings.values())) 327 | data = self._preprocess_data(with_index=False, usecols=columns) 328 | encodings["tooltip"] = columns 329 | mark = self._get_mark_def("point", kwargs) 330 | return alt.Chart(data, mark=mark).encode(**encodings).interactive() 331 | 332 | def hist(self, bins=None, stacked=None, orientation="vertical", **kwargs): 333 | data = self._preprocess_data(with_index=False) 334 | if isinstance(bins, int): 335 | bins = alt.Bin(maxbins=bins) 336 | elif bins is None: 337 | 
bins = True 338 | if orientation == "vertical": 339 | Indep, Dep = alt.X, alt.Y 340 | elif orientation == "horizontal": 341 | Indep, Dep = alt.Y, alt.X 342 | else: 343 | raise ValueError("orientation must be 'horizontal' or 'vertical'.") 344 | 345 | mark = self._get_mark_def({"type": "bar", "orient": orientation}, kwargs) 346 | chart = ( 347 | alt.Chart(data, mark=mark) 348 | .transform_fold(list(data.columns), as_=["column", "value"]) 349 | .encode( 350 | Indep("value:Q", title=None, bin=bins), 351 | Dep("count()", title="Frequency", stack=stacked), 352 | color="column:N", 353 | ) 354 | ) 355 | 356 | if kwargs.get("subplots"): 357 | nrows, ncols = _get_layout(data.shape[1], kwargs.get("layout", (-1, 1))) 358 | chart = chart.encode(facet=alt.Facet("column:N", title=None)).properties( 359 | columns=ncols 360 | ) 361 | 362 | return chart 363 | 364 | def hist_frame(self, column=None, layout=(-1, 2), **kwargs): 365 | if column is not None: 366 | if isinstance(column, str): 367 | column = [column] 368 | data = self._preprocess_data(with_index=False, usecols=column) 369 | data = data._get_numeric_data() 370 | nrows, ncols = _get_layout(data.shape[1], layout) 371 | return ( 372 | alt.Chart(data, mark=self._get_mark_def("bar", kwargs)) 373 | .encode( 374 | x=alt.X(alt.repeat("repeat"), type="quantitative", bin=True), 375 | y=alt.Y("count()", title="Frequency"), 376 | ) 377 | .repeat(repeat=list(data.columns), columns=ncols) 378 | ) 379 | 380 | def box( 381 | self, 382 | vert=True, 383 | column=None, 384 | by=None, 385 | fontsize=None, 386 | rot=0, 387 | grid=True, 388 | figsize=None, 389 | layout=None, 390 | return_type=None, 391 | **kwargs, 392 | ): 393 | data = self._preprocess_data(with_index=False) 394 | 395 | if column is not None: 396 | columns = [column] if isinstance(column, str) else column 397 | else: 398 | columns = data.select_dtypes(np.number).columns 399 | if by is not None: 400 | columns = columns.difference(pd.Index(list(by))) 401 | 402 | if by is not None: 
403 | if np.iterable(by) and not isinstance(by, str) and 1 < len(by): 404 | by_identifier = ", ".join(by) 405 | by_title = f"[{by_identifier}]" 406 | # Check that name doesn't overlap with existing 407 | # columns 408 | # If it does, assign a unique name 409 | by_column = ( 410 | by_identifier if by_identifier not in columns else "".join(columns) 411 | ) 412 | data[by_column] = data[by].apply( 413 | lambda row: f"({', '.join(row)})", axis=1 414 | ) 415 | panels = data[by_column].nunique() 416 | else: 417 | by = by.pop() if not isinstance(by, str) else by 418 | panels = data[by].nunique() 419 | by_title = by 420 | by_column = by 421 | x_column = by_column 422 | x_title = by_title 423 | else: 424 | panels = 1 425 | x_column = "Column" 426 | x_title = None 427 | 428 | mark_args = { 429 | kwarg: value 430 | for kwarg, value in kwargs.items() 431 | if kwarg in {"alpha", "color"} 432 | } 433 | 434 | # Matplotlib measures counterclockwise, while Vega-Lite measures 435 | # clockwise 436 | # Convert counterclockwise to clockwise 437 | label_angle = 360 - rot 438 | 439 | if return_type is not None: 440 | warnings.warn( 441 | "Different return types are not implimented for the Altair backend.", 442 | category=UserWarning, 443 | ) 444 | 445 | _, n_columns = _get_layout(panels, layout=layout) 446 | 447 | chart = ( 448 | alt.Chart(data) 449 | .transform_fold(list(columns), as_=["Column", "Value"]) 450 | .mark_boxplot(**mark_args) 451 | .encode( 452 | x=alt.X( 453 | x_column, 454 | title=x_title, 455 | type="nominal", 456 | axis=alt.Axis(labelAngle=label_angle, grid=grid), 457 | ), 458 | y=alt.Y("Value", type="quantitative", axis=alt.Axis(grid=grid)), 459 | tooltip=[ 460 | alt.Tooltip(x_column, title=x_title, type="nominal"), 461 | alt.Tooltip("Value", type="quantitative"), 462 | ], 463 | ) 464 | .interactive() 465 | ) 466 | 467 | if not vert: 468 | chart.encoding.x, chart.encoding.y = chart.encoding.y, chart.encoding.x 469 | 470 | if by is not None: 471 | chart = chart.facet( 
472 | facet=alt.Facet("Column", title=None, type="nominal"), 473 | columns=n_columns, 474 | ).properties(title=f"Boxplot grouped by {by}") 475 | 476 | if fontsize is not None: 477 | size = _get_fontsize(fontsize) if isinstance(fontsize, str) else fontsize 478 | chart = chart.configure_axis( 479 | labelFontSize=size, 480 | titleFontSize=size, 481 | ) 482 | 483 | if figsize is not None: 484 | width, height = figsize 485 | chart = chart.configure_view( 486 | continuousHeight=height, 487 | discreteHeight=height, 488 | continuousWidth=width, 489 | discreteWidth=width, 490 | ) 491 | 492 | return chart 493 | 494 | def kde(self, bw_method=None, ind=None, **kwargs): 495 | data = self._preprocess_data(with_index=False) 496 | return self._kde(data, bw_method=bw_method, ind=ind, **kwargs) 497 | 498 | def hexbin(self, x, y, C=None, reduce_C_function=None, gridsize=None, **kwargs): 499 | data = self._preprocess_data(with_index=False) 500 | 501 | if np.iterable(gridsize): 502 | x_bins, y_bins = gridsize 503 | else: 504 | x_bins = 100 if gridsize is None else gridsize 505 | # Since rectangles are being used here, 506 | # instead of hexagons like in Matplotlib, 507 | # set default y_bins equal to x_bins 508 | y_bins = x_bins 509 | 510 | x_step = (data[x].max() - data[x].min()) / x_bins 511 | y_step = (data[y].max() - data[y].min()) / y_bins 512 | 513 | # Default set to bluegreen to match Matplotlib's default 514 | color_scheme = kwargs.pop("cmap", "bluegreen") 515 | 516 | if C is not None: 517 | reduce_C_function = ( 518 | np.mean if reduce_C_function is None else reduce_C_function 519 | ) 520 | # Make sure column is not overwritten if C is one 521 | # of the coordinate columns 522 | color_shorthand = C if C not in (x, y) else f"reduced_{C}" 523 | data[color_shorthand] = data.groupby( 524 | [ 525 | pd.cut(data[x], bins=x_bins), 526 | pd.cut(data[y], bins=y_bins), 527 | ] 528 | )[C].transform(reduce_C_function) 529 | # All reduced values will be identical across rows that 530 | # 
belong to the same bin 531 | # Since the median of a collection of identical values is 532 | # the value itself, the median is used here as a way to pass 533 | # the reduced value per bin to Altair 534 | color_aggregate = "median" 535 | color_title = C 536 | else: 537 | color_shorthand = x 538 | color_aggregate = "count" 539 | color_title = "Count" 540 | 541 | chart = ( 542 | alt.Chart(data) 543 | .mark_rect(**kwargs) 544 | .encode( 545 | x=alt.X(x, bin=alt.Bin(step=x_step)), 546 | y=alt.Y(y, bin=alt.Bin(step=y_step)), 547 | color=alt.Color( 548 | color_shorthand, 549 | aggregate=color_aggregate, 550 | scale=alt.Scale(scheme=color_scheme), 551 | title=color_title, 552 | type="quantitative", 553 | ), 554 | tooltip=[ 555 | alt.Tooltip(x, bin=alt.Bin(step=x_step), type="quantitative"), 556 | alt.Tooltip(y, bin=alt.Bin(step=y_step), type="quantitative"), 557 | alt.Tooltip( 558 | color_shorthand, 559 | aggregate=color_aggregate, 560 | title=color_title, 561 | type="quantitative", 562 | ), 563 | ], 564 | ) 565 | .interactive() 566 | ) 567 | 568 | return chart 569 | 570 | 571 | def plot(data, kind="line", **kwargs): 572 | """Pandas plotting interface for Altair.""" 573 | plotter = _PandasPlotter.create(data) 574 | 575 | if hasattr(plotter, kind): 576 | plotfunc = getattr(plotter, kind) 577 | else: 578 | raise NotImplementedError(f"kind='{kind}' for data of type {type(data)}") 579 | 580 | return plotfunc(**kwargs) 581 | 582 | 583 | def hist_frame(data, **kwargs): 584 | return _PandasPlotter.create(data).hist_frame(**kwargs) 585 | 586 | 587 | def hist_series(data, **kwargs): 588 | return _PandasPlotter.create(data).hist_series(**kwargs) 589 | 590 | 591 | def boxplot_frame( 592 | data, 593 | column=None, 594 | by=None, 595 | fontsize=None, 596 | rot=0, 597 | grid=True, 598 | figsize=None, 599 | layout=None, 600 | return_type=None, 601 | **kwargs, 602 | ): 603 | return _PandasPlotter.create(data).box( 604 | column=column, 605 | by=by, 606 | fontsize=fontsize, 607 | rot=rot, 
608 | grid=grid, 609 | figsize=figsize, 610 | layout=layout, 611 | return_type=return_type, 612 | **kwargs, 613 | ) 614 | -------------------------------------------------------------------------------- /altair_pandas/_misc.py: -------------------------------------------------------------------------------- 1 | import altair as alt 2 | from typing import Union, List 3 | import pandas as pd 4 | 5 | tooltipList = List[alt.Tooltip] 6 | 7 | 8 | def _preprocess_data(data): 9 | for indx in ("index", "columns"): 10 | if isinstance(getattr(data, indx), pd.MultiIndex): 11 | setattr( 12 | data, 13 | indx, 14 | pd.Index( 15 | [str(i) for i in getattr(data, indx)], name=getattr(data, indx).name 16 | ), 17 | ) 18 | # Column names must all be strings. 19 | return data.rename(columns=str).copy() 20 | 21 | 22 | def _process_tooltip(tooltip): 23 | """converts tooltip els to string if needed""" 24 | if isinstance(tooltip, list) and not isinstance(tooltip[0], alt.Tooltip): 25 | tooltip = [str(el) for el in tooltip] 26 | 27 | return tooltip 28 | 29 | 30 | def scatter_matrix( 31 | df, 32 | color: Union[str, None] = None, 33 | alpha: float = 1.0, 34 | tooltip: Union[List[str], tooltipList, None] = None, 35 | **kwargs 36 | ) -> alt.Chart: 37 | """ plots a scatter matrix 38 | 39 | At the moment does not support neither histogram nor kde; 40 | Uses f-f scatterplots instead. Interactive and with a cusotmizable 41 | tooltip 42 | 43 | Parameters 44 | ---------- 45 | df : DataFame 46 | DataFame to be used for scatterplot. Only numeric columns will be included. 47 | color : string [optional] 48 | Can be a column name or specific color value (hex, webcolors). 49 | alpha : float 50 | Opacity of the markers, within [0,1] 51 | tooltip: list [optional] 52 | List of specific column names or alt.Tooltip objects. If none (default), 53 | will show all columns. 
54 | """ 55 | dfc = _preprocess_data(df) 56 | tooltip = _process_tooltip(tooltip) or dfc.columns.tolist() 57 | cols = dfc._get_numeric_data().columns.tolist() 58 | 59 | chart = ( 60 | alt.Chart(dfc) 61 | .mark_circle() 62 | .encode( 63 | x=alt.X(alt.repeat("column"), type="quantitative"), 64 | y=alt.X(alt.repeat("row"), type="quantitative"), 65 | opacity=alt.value(alpha), 66 | tooltip=tooltip, 67 | ) 68 | .properties(width=150, height=150) 69 | ) 70 | 71 | if color: 72 | color = str(color) 73 | 74 | if color in dfc: 75 | color = alt.Color(color) 76 | if "colormap" in kwargs: 77 | color.scale = alt.Scale(scheme=kwargs.get("colormap")) 78 | else: 79 | color = alt.value(color) 80 | chart = chart.encode(color=color) 81 | 82 | return chart.repeat(row=cols, column=cols).interactive() 83 | -------------------------------------------------------------------------------- /altair_pandas/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pandas as pd 3 | 4 | 5 | def pytest_addoption(parser): 6 | parser.addoption( 7 | "--backend_name", 8 | action="store", 9 | default="altair_pandas", 10 | help="Plotting backend to use.", 11 | ) 12 | 13 | 14 | @pytest.fixture(scope="session") 15 | def with_plotting_backend(request): 16 | default = pd.get_option("plotting.backend") 17 | pd.set_option("plotting.backend", request.config.getoption("backend_name")) 18 | yield 19 | try: 20 | pd.set_option("plotting.backend", default) 21 | except ImportError: 22 | pass # matplotlib is not installed. 
23 | -------------------------------------------------------------------------------- /altair_pandas/test_plotting.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import pandas as pd 4 | import altair as alt 5 | 6 | 7 | @pytest.fixture 8 | def series(): 9 | return pd.Series(range(5), name="data_name") 10 | 11 | 12 | @pytest.fixture 13 | def dataframe(): 14 | return pd.DataFrame({"x": range(5), "y": range(5)}) 15 | 16 | 17 | def _expected_mark(kind): 18 | marks = {"barh": "bar", "hist": "bar", "box": "boxplot"} 19 | return marks.get(kind, kind) 20 | 21 | 22 | @pytest.mark.parametrize( 23 | "data", 24 | [ 25 | pd.Series( 26 | range(6), index=pd.MultiIndex.from_product([["a", "b", "c"], [1, 2]]) 27 | ), 28 | pd.DataFrame( 29 | {"x": range(6)}, index=pd.MultiIndex.from_product([["a", "b", "c"], [1, 2]]) 30 | ), 31 | ], 32 | ) 33 | def test_multiindex(data, with_plotting_backend): 34 | chart = data.plot.bar() 35 | spec = chart.to_dict() 36 | assert list(chart.data.iloc[:, 0]) == [str(i) for i in data.index] 37 | assert spec["encoding"]["x"]["field"] == "index" 38 | assert spec["encoding"]["x"]["type"] == "nominal" 39 | 40 | 41 | def test_nonstring_column_names(with_plotting_backend): 42 | data = pd.DataFrame(np.ones((3, 4))) 43 | chart = data.plot.scatter(x=0, y=1, c=2, s=3) 44 | 45 | # Ensure data is not modified 46 | assert list(data.columns) == list(range(4)) 47 | # Ensure chart data has string columns 48 | assert set(chart.data.columns) == {str(i) for i in range(4)} 49 | 50 | spec = chart.to_dict() 51 | assert spec["encoding"]["x"]["field"] == "0" 52 | assert spec["encoding"]["y"]["field"] == "1" 53 | assert spec["encoding"]["color"]["field"] == "2" 54 | assert spec["encoding"]["size"]["field"] == "3" 55 | 56 | 57 | @pytest.mark.parametrize("kind", ["line", "area", "bar", "barh"]) 58 | def test_series_basic_plot(series, kind, with_plotting_backend): 59 | chart = series.plot(kind=kind) 
60 | spec = chart.to_dict() 61 | 62 | x, y = "x", "y" 63 | if kind == "bar": 64 | assert spec["mark"]["orient"] == "vertical" 65 | if kind == "barh": 66 | assert spec["mark"]["orient"] == "horizontal" 67 | x, y = y, x 68 | 69 | assert spec["mark"]["type"] == _expected_mark(kind) 70 | assert spec["encoding"][x]["field"] == "index" 71 | assert spec["encoding"][y]["field"] == "data_name" 72 | 73 | 74 | @pytest.mark.parametrize("stacked", [True, False]) 75 | @pytest.mark.parametrize("subplots", [False, True]) 76 | @pytest.mark.parametrize("kind", ["line", "area", "bar", "barh"]) 77 | def test_dataframe_basic_plot( 78 | dataframe, kind, stacked, subplots, with_plotting_backend 79 | ): 80 | chart = dataframe.plot(kind=kind, stacked=stacked, subplots=subplots) 81 | spec = chart.to_dict() 82 | 83 | x, y = "x", "y" 84 | if kind == "bar": 85 | assert spec["mark"]["orient"] == "vertical" 86 | if kind == "barh": 87 | assert spec["mark"]["orient"] == "horizontal" 88 | x, y = y, x 89 | 90 | assert spec["mark"]["type"] == _expected_mark(kind) 91 | assert spec["encoding"][x]["field"] == "index" 92 | assert spec["encoding"][y]["field"] == "value" 93 | assert spec["encoding"][y]["stack"] == stacked 94 | assert spec["encoding"]["color"]["field"] == "column" 95 | assert spec["transform"][0]["fold"] == ["x", "y"] 96 | if subplots: 97 | assert spec["encoding"]["facet"]["field"] == "column" 98 | assert spec["columns"] == 1 99 | else: 100 | assert "facet" not in spec["encoding"] 101 | 102 | 103 | def test_series_barh(series, with_plotting_backend): 104 | chart = series.plot.barh() 105 | spec = chart.to_dict() 106 | assert spec["mark"] == {"type": "bar", "orient": "horizontal"} 107 | assert spec["encoding"]["y"]["field"] == "index" 108 | assert spec["encoding"]["x"]["field"] == "data_name" 109 | 110 | 111 | def test_dataframe_barh(dataframe, with_plotting_backend): 112 | chart = dataframe.plot.barh() 113 | spec = chart.to_dict() 114 | assert spec["mark"] == {"type": "bar", "orient": 
"horizontal"} 115 | assert spec["encoding"]["y"]["field"] == "index" 116 | assert spec["encoding"]["x"]["field"] == "value" 117 | assert spec["encoding"]["color"]["field"] == "column" 118 | assert spec["transform"][0]["fold"] == ["x", "y"] 119 | 120 | 121 | def test_series_scatter_plot(series, with_plotting_backend): 122 | with pytest.raises(ValueError): 123 | series.plot.scatter("x", "y") 124 | 125 | 126 | def test_dataframe_scatter_plot(dataframe, with_plotting_backend): 127 | dataframe["c"] = range(len(dataframe)) 128 | chart = dataframe.plot.scatter("x", "y", c="y", s="x") 129 | spec = chart.to_dict() 130 | assert spec["mark"] == {"type": "point"} 131 | assert spec["encoding"]["x"]["field"] == "x" 132 | assert spec["encoding"]["y"]["field"] == "y" 133 | assert spec["encoding"]["color"]["field"] == "y" 134 | assert spec["encoding"]["size"]["field"] == "x" 135 | 136 | 137 | @pytest.mark.parametrize("bins", [None, 10]) 138 | @pytest.mark.parametrize("orientation", ["vertical", "horizontal"]) 139 | def test_series_hist(series, bins, orientation, with_plotting_backend): 140 | chart = series.plot.hist(bins=bins, orientation=orientation) 141 | spec = chart.to_dict() 142 | x, y = ("x", "y") if orientation == "vertical" else ("y", "x") 143 | 144 | assert spec["mark"]["type"] == "bar" 145 | assert spec["mark"]["orient"] == orientation 146 | assert spec["encoding"][x]["field"] == "data_name" 147 | assert "field" not in spec["encoding"][y] 148 | exp_bin = True if bins is None else {"maxbins": bins} 149 | assert spec["encoding"][x]["bin"] == exp_bin 150 | 151 | 152 | @pytest.mark.parametrize("bins", [None, 10]) 153 | @pytest.mark.parametrize("stacked", [None, True, False]) 154 | @pytest.mark.parametrize("orientation", ["vertical", "horizontal"]) 155 | def test_dataframe_hist(dataframe, bins, stacked, orientation, with_plotting_backend): 156 | chart = dataframe.plot.hist(bins=bins, stacked=stacked, orientation=orientation) 157 | spec = chart.to_dict() 158 | x, y = ("x", "y") 
if orientation == "vertical" else ("y", "x") 159 | assert spec["mark"]["type"] == "bar" 160 | assert spec["mark"]["orient"] == orientation 161 | assert spec["encoding"][x]["field"] == "value" 162 | assert "field" not in spec["encoding"][y] 163 | assert spec["encoding"]["color"]["field"] == "column" 164 | assert spec["transform"][0]["fold"] == ["x", "y"] 165 | exp_bin = True if bins is None else {"maxbins": bins} 166 | assert spec["encoding"][x]["bin"] == exp_bin 167 | assert spec["encoding"][y]["stack"] == (True if stacked else stacked) 168 | 169 | 170 | @pytest.mark.parametrize("vert", [True, False]) 171 | def test_series_boxplot(series, vert, with_plotting_backend): 172 | chart = series.plot.box(vert=vert) 173 | spec = chart.to_dict() 174 | assert spec["mark"] == "boxplot" 175 | assert spec["transform"][0]["fold"] == ["data_name"] 176 | fields = ["column", "value"] if vert else ["value", "column"] 177 | assert spec["encoding"]["x"]["field"] == fields[0] 178 | assert spec["encoding"]["y"]["field"] == fields[1] 179 | 180 | 181 | @pytest.mark.parametrize("vert", [True, False]) 182 | def test_dataframe_boxplot(dataframe, vert, with_plotting_backend): 183 | chart = dataframe.plot.box(vert=vert) 184 | spec = chart.to_dict() 185 | assert spec["mark"] == "boxplot" 186 | assert spec["transform"][0]["fold"] == ["x", "y"] 187 | fields = ["Column", "Value"] if vert else ["Value", "Column"] 188 | assert spec["encoding"]["x"]["field"] == fields[0] 189 | assert spec["encoding"]["y"]["field"] == fields[1] 190 | 191 | 192 | def test_hist_series(series, with_plotting_backend): 193 | chart = series.hist() 194 | spec = chart.to_dict() 195 | assert spec["mark"]["type"] == "bar" 196 | assert spec["encoding"]["x"]["field"] == "data_name" 197 | assert "field" not in spec["encoding"]["y"] 198 | assert spec["encoding"]["x"]["bin"] == {"maxbins": 10} 199 | 200 | 201 | def test_hist_frame(dataframe, with_plotting_backend): 202 | chart = dataframe.hist(layout=(-1, 1)) 203 | spec = 
chart.to_dict() 204 | assert spec["repeat"] == ["x", "y"] 205 | assert spec["columns"] == 1 206 | assert spec["spec"]["mark"] == {"type": "bar"} 207 | assert spec["spec"]["encoding"]["x"]["field"] == {"repeat": "repeat"} 208 | assert spec["spec"]["encoding"]["x"]["bin"] is True 209 | assert "field" not in spec["spec"]["encoding"]["y"] 210 | 211 | 212 | @pytest.mark.parametrize("kind", ["hist", "line", "bar", "barh"]) 213 | def test_dataframe_mark_properties(dataframe, kind, with_plotting_backend): 214 | chart = dataframe.plot(kind=kind, alpha=0.5, color="red") 215 | spec = chart.to_dict() 216 | assert spec["mark"]["type"] == _expected_mark(kind) 217 | assert spec["mark"]["opacity"] == 0.5 218 | assert spec["mark"]["color"] == "red" 219 | 220 | 221 | @pytest.mark.parametrize("kind", ["hist", "line", "bar", "barh"]) 222 | def test_series_mark_properties(series, kind, with_plotting_backend): 223 | chart = series.plot(kind=kind, alpha=0.5, color="red") 224 | spec = chart.to_dict() 225 | assert spec["mark"]["type"] == _expected_mark(kind) 226 | assert spec["mark"]["opacity"] == 0.5 227 | assert spec["mark"]["color"] == "red" 228 | 229 | 230 | @pytest.mark.parametrize("stacked", [True, False]) 231 | def test_dataframe_area(dataframe, stacked, with_plotting_backend): 232 | chart = dataframe.plot.area(stacked=stacked) 233 | spec = chart.to_dict() 234 | mark = ( 235 | {"type": "area"} if stacked else {"type": "area", "line": True, "opacity": 0.5} 236 | ) 237 | assert spec["mark"] == mark 238 | for k, v in {"x": "index", "y": "value", "color": "column"}.items(): 239 | assert spec["encoding"][k]["field"] == v 240 | assert spec["transform"][0]["fold"] == ["x", "y"] 241 | 242 | 243 | @pytest.mark.parametrize("alpha", [1.0, 0.2]) 244 | @pytest.mark.parametrize("color", [None, "x", "z"]) 245 | @pytest.mark.parametrize( 246 | "tooltip", 247 | [ 248 | None, 249 | ["x", "y"], 250 | [alt.Tooltip("x", format="$.2f"), alt.Tooltip("z", format=".0%")], 251 | ], 252 | ) 253 | def 
test_scatter_matrix(dataframe, alpha, color, tooltip, with_plotting_backend): 254 | from altair_pandas import scatter_matrix 255 | 256 | dataframe["z"] = ["A", "B", "C", "D", "E"] 257 | 258 | chart = scatter_matrix(dataframe, alpha=alpha, color=color, tooltip=tooltip) 259 | spec = chart.to_dict() 260 | 261 | cols = dataframe._get_numeric_data().columns.astype(str).tolist() 262 | for k, v in spec["repeat"].items(): 263 | assert set(v) == set(cols) 264 | 265 | if color is None: 266 | assert "color" not in spec["spec"]["encoding"] 267 | elif color == "x": 268 | assert spec["spec"]["encoding"]["color"] == { 269 | "type": "quantitative", 270 | "field": "x", 271 | } 272 | 273 | assert spec["spec"]["encoding"]["opacity"] == {"value": alpha} 274 | 275 | if tooltip is None: 276 | assert set(el["field"] for el in spec["spec"]["encoding"]["tooltip"]) == { 277 | "x", 278 | "y", 279 | "z", 280 | } 281 | elif tooltip == ["x", "y"]: 282 | assert len(spec["spec"]["encoding"]["tooltip"]) == 2 283 | assert set(el["field"] for el in spec["spec"]["encoding"]["tooltip"]) == { 284 | "x", 285 | "y", 286 | } 287 | else: 288 | assert len(spec["spec"]["encoding"]["tooltip"]) == 2 289 | assert set(el["field"] for el in spec["spec"]["encoding"]["tooltip"]) == { 290 | "x", 291 | "z", 292 | } 293 | assert spec["spec"]["encoding"]["tooltip"][0]["format"] == "$.2f" 294 | 295 | 296 | @pytest.mark.parametrize("colormap", ["viridis", "goldgreen"]) 297 | @pytest.mark.parametrize("color", ["x", "z"]) 298 | def test_scatter_colormap(dataframe, colormap, color, with_plotting_backend): 299 | from altair_pandas import scatter_matrix 300 | 301 | if color == "z": 302 | dataframe["z"] = ["A", "B", "C", "D", "E"] 303 | 304 | chart = scatter_matrix(dataframe, color=color, colormap=colormap) 305 | spec = chart.to_dict() 306 | 307 | assert spec["spec"]["encoding"]["color"]["scale"]["scheme"] == colormap 308 | 309 | 310 | @pytest.mark.parametrize( 311 | "indx, data", 312 | { 313 | "index": pd.DataFrame( 314 | 
{"x": range(6)}, index=pd.MultiIndex.from_product([["a", "b", "c"], [1, 2]]) 315 | ), 316 | "columns": pd.DataFrame( 317 | {"x": range(6)}, index=pd.MultiIndex.from_product([["a", "b", "c"], [1, 2]]) 318 | ).T, 319 | }.items(), 320 | ) 321 | def test_scatter_multiindex(indx, data, with_plotting_backend): 322 | from altair_pandas import scatter_matrix 323 | 324 | chart = scatter_matrix(data) 325 | spec = chart.to_dict() 326 | 327 | cols = ( 328 | {"x"} 329 | if indx == "index" 330 | else ({"('b', 2)", "('b', 1)", "('c', 2)", "('a', 2)", "('c', 1)", "('a', 1)"}) 331 | ) 332 | 333 | for k, v in spec["repeat"].items(): 334 | assert set(v) == cols 335 | 336 | 337 | @pytest.mark.parametrize( 338 | "data", 339 | [ 340 | pd.DataFrame({"a": np.arange(12), "b": np.arange(12, 24)}), 341 | pd.DataFrame({"a": np.arange(12)}), 342 | pd.Series(np.arange(12)), 343 | ], 344 | ) 345 | @pytest.mark.parametrize( 346 | "bw_method, bandwidth", 347 | [ 348 | (None, 0), 349 | ("scott", 0), 350 | ("silverman", 0.6443940149772542), 351 | (lambda data: 0.3, 0.3), 352 | ], 353 | ) 354 | @pytest.mark.parametrize("ind, steps", [(None, 1_000), (500, 500)]) 355 | def test_kde(data, bw_method, bandwidth, ind, steps, with_plotting_backend): 356 | chart = data.plot(kind="kde", bw_method=bw_method, ind=ind) 357 | spec = chart.to_dict() 358 | 359 | density_attributes = spec["transform"][1] 360 | assert density_attributes["bandwidth"] == pytest.approx(bandwidth) 361 | assert density_attributes["extent"] == [ 362 | data.to_numpy().min(), 363 | data.to_numpy().max(), 364 | ] 365 | assert density_attributes["groupby"] == ["Column"] 366 | assert density_attributes["steps"] == steps 367 | if 1 < len(data.shape) and 1 < data.shape[1]: 368 | assert spec["encoding"]["color"]["field"] == "Column" 369 | 370 | 371 | def test_kde_warns_callable_bw_method(dataframe, with_plotting_backend): 372 | with pytest.warns(UserWarning): 373 | dataframe.plot(kind="kde", bw_method=lambda data: 0) 374 | 375 | 376 | def 
test_kde_warns_array_ind(series): 377 | with pytest.warns(UserWarning): 378 | series.plot(kind="kde", ind=np.arange(5)) 379 | 380 | 381 | def test_set_color_kde(series, with_plotting_backend): 382 | mark_color = "#6300EE" 383 | chart = series.plot(kind="kde", color=mark_color) 384 | spec = chart.to_dict() 385 | assert spec["mark"]["color"] == mark_color 386 | 387 | 388 | def test_set_alpha_kde(dataframe, with_plotting_backend): 389 | alpha = 0.2 390 | chart = dataframe.plot(kind="kde", alpha=alpha) 391 | spec = chart.to_dict() 392 | assert spec["mark"]["opacity"] == alpha 393 | 394 | 395 | @pytest.mark.parametrize("gridsize", [None, 10, (5, 15)]) 396 | def test_hexbin(dataframe, gridsize, with_plotting_backend): 397 | chart = dataframe.plot(kind="hexbin", x="x", y="y", gridsize=gridsize) 398 | spec = chart.to_dict() 399 | 400 | if np.iterable(gridsize): 401 | x_bins, y_bins = gridsize 402 | else: 403 | x_bins = 100 if gridsize is None else gridsize 404 | y_bins = x_bins 405 | 406 | x_step = (dataframe["x"].max() - dataframe["y"].min()) / x_bins 407 | y_step = (dataframe["y"].max() - dataframe["y"].min()) / y_bins 408 | 409 | encoding = spec["encoding"] 410 | assert encoding["x"]["bin"]["step"] == pytest.approx(x_step) 411 | assert encoding["y"]["bin"]["step"] == pytest.approx(y_step) 412 | assert encoding["color"]["field"] == "x" 413 | assert encoding["color"]["aggregate"] == "count" 414 | 415 | 416 | @pytest.mark.parametrize( 417 | "reduce_C_function, first_color_value", 418 | [ 419 | (None, 1.5), 420 | (np.max, 3), 421 | (np.sum, 6), 422 | ], 423 | ) 424 | def test_hexbin_C(reduce_C_function, first_color_value, with_plotting_backend): 425 | dataframe = pd.DataFrame( 426 | {"x": np.arange(20), "y": np.arange(20, 40), "C": np.arange(20)} 427 | ) 428 | chart = dataframe.plot( 429 | kind="hexbin", 430 | x="x", 431 | y="y", 432 | C="C", 433 | reduce_C_function=reduce_C_function, 434 | gridsize=5, 435 | ) 436 | spec = chart.to_dict() 437 | 438 | dataset = 
spec["datasets"] 439 | assert dataset[list(dataset.keys())[0]][0]["C"] == first_color_value 440 | assert spec["encoding"]["color"]["aggregate"] == "median" 441 | assert spec["encoding"]["color"]["title"] == "C" 442 | 443 | 444 | def test_hexbin_C_equals_x(dataframe, with_plotting_backend): 445 | chart = dataframe.plot( 446 | kind="hexbin", x="x", y="y", C="x", reduce_C_function=lambda df: 1 447 | ) 448 | spec = chart.to_dict() 449 | 450 | dataset = spec["datasets"] 451 | assert dataset[list(dataset.keys())[0]][0]["reduced_x"] == 1 452 | 453 | 454 | def test_hexbin_cmap(dataframe, with_plotting_backend): 455 | chart = dataframe.plot(kind="hexbin", x="x", y="y", cmap="blue") 456 | spec = chart.to_dict() 457 | 458 | assert spec["encoding"]["color"]["scale"]["scheme"] == "blue" 459 | 460 | 461 | def test_boxplot(dataframe, with_plotting_backend): 462 | chart = dataframe.boxplot() 463 | spec = chart.to_dict() 464 | encoding = spec["encoding"] 465 | 466 | assert spec["mark"] == "boxplot" 467 | assert encoding["x"]["field"] == "Column" 468 | assert encoding["x"]["type"] == "nominal" 469 | assert encoding["y"]["field"] == "Value" 470 | assert encoding["y"]["type"] == "quantitative" 471 | 472 | 473 | @pytest.mark.parametrize( 474 | "column, fold", 475 | [ 476 | (None, ["Col1", "Col2", "Col3"]), 477 | ("Col1", ["Col1"]), 478 | ("Col2", ["Col2"]), 479 | (["Col1", "Col3"], ["Col1", "Col3"]), 480 | ], 481 | ) 482 | def test_boxplot_column(column, fold, with_plotting_backend): 483 | df = pd.DataFrame(np.random.randn(10, 3), columns=["Col1", "Col2", "Col3"]) 484 | df["X"] = pd.Series(["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"]) 485 | df["Y"] = pd.Series(["A", "B", "A", "B", "A", "B", "A", "B", "A", "B"]) 486 | chart = df.boxplot(column=column) 487 | spec = chart.to_dict() 488 | 489 | assert spec["transform"][0]["fold"] == fold 490 | 491 | 492 | @pytest.mark.parametrize("by, field", [("X", "X"), ("Y", "Y"), (["X", "Y"], "X, Y")]) 493 | def test_boxplot_by(by, field, 
with_plotting_backend): 494 | df = pd.DataFrame(np.random.randn(10, 3), columns=["Col1", "Col2", "Col3"]) 495 | df["X"] = pd.Series(["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"]) 496 | df["Y"] = pd.Series(["A", "B", "A", "B", "A", "B", "A", "B", "A", "B"]) 497 | chart = df.boxplot(by=by) 498 | spec = chart.to_dict() 499 | 500 | assert spec["facet"]["field"] == "Column" 501 | assert spec["spec"]["encoding"]["x"]["field"] == field 502 | 503 | 504 | def test_boxplot_fontsize(dataframe, with_plotting_backend): 505 | fontsize = 100 506 | chart = dataframe.boxplot(fontsize=fontsize) 507 | axis = chart.to_dict()["config"]["axis"] 508 | 509 | assert axis["labelFontSize"] == 100 510 | assert axis["titleFontSize"] == 100 511 | 512 | 513 | def test_boxplot_rot(dataframe, with_plotting_backend): 514 | rot = 45 515 | chart = dataframe.boxplot(rot=rot) 516 | x_encoding = chart["encoding"]["x"] 517 | 518 | assert x_encoding["axis"]["labelAngle"] == 360 - rot 519 | 520 | 521 | def test_boxplot_grid(dataframe, with_plotting_backend): 522 | chart = dataframe.boxplot(grid=False) 523 | encoding = chart.to_dict()["encoding"] 524 | 525 | assert encoding["x"]["axis"]["grid"] is False 526 | assert encoding["y"]["axis"]["grid"] is False 527 | 528 | 529 | def test_boxplot_figsize(dataframe, with_plotting_backend): 530 | width = 500 531 | height = 300 532 | chart = dataframe.boxplot(figsize=(width, height)) 533 | view = chart.to_dict()["config"]["view"] 534 | 535 | assert view["continuousHeight"] == 300 536 | assert view["continuousWidth"] == 500 537 | assert view["discreteHeight"] == 300 538 | assert view["discreteWidth"] == 500 539 | 540 | 541 | def test_boxplot_layout(with_plotting_backend): 542 | df = pd.DataFrame(np.random.randn(10, 3), columns=["Col1", "Col2", "Col3"]) 543 | df["X"] = pd.Series(["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"]) 544 | chart = df.boxplot(by="X", layout=(3, 1)) 545 | 546 | assert chart.to_dict()["columns"] == 1 547 | 548 | 549 | def 
test_boxplot_warn_return_type(dataframe, with_plotting_backend): 550 | with pytest.warns(UserWarning): 551 | dataframe.boxplot(return_type="dict") 552 | -------------------------------------------------------------------------------- /images/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/altair-viz/altair_pandas/506d0b937f1ac047168c88f814c1e3266490a9cc/images/example.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | altair>=3.0 2 | pandas>=0.25.1 3 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | flake8 2 | black 3 | pytest 4 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | license_file = LICENSE 4 | 5 | [bdist_wheel] 6 | universal = 1 7 | 8 | 9 | [flake8] 10 | max-line-length = 88 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import re 4 | 5 | try: 6 | from setuptools import setup 7 | except ImportError: 8 | from distutils.core import setup 9 | 10 | # ============================================================================== 11 | # Utilities 12 | # ============================================================================== 13 | 14 | 15 | def read(path, encoding="utf-8"): 16 | path = os.path.join(os.path.dirname(__file__), path) 17 | with io.open(path, encoding=encoding) as fp: 18 | return fp.read() 19 | 20 | 21 | def get_install_requirements(path): 22 | content = 
read(path) 23 | return [req for req in content.split("\n") if req != "" and not req.startswith("#")] 24 | 25 | 26 | def version(path): 27 | """Obtain the packge version from a python file e.g. pkg/__init__.py 28 | 29 | See . 30 | """ 31 | version_file = read(path) 32 | version_match = re.search( 33 | r"""^__version__ = ['"]([^'"]*)['"]""", version_file, re.M 34 | ) 35 | if version_match: 36 | return version_match.group(1) 37 | raise RuntimeError("Unable to find version string.") 38 | 39 | 40 | HERE = os.path.abspath(os.path.dirname(__file__)) 41 | 42 | # From https://github.com/jupyterlab/jupyterlab/blob/master/setupbase.py, BSD licensed 43 | def find_packages(top=HERE): 44 | """ 45 | Find all of the packages. 46 | """ 47 | packages = [] 48 | for d, dirs, _ in os.walk(top, followlinks=True): 49 | if os.path.exists(os.path.join(d, "__init__.py")): 50 | packages.append(os.path.relpath(d, top).replace(os.path.sep, ".")) 51 | elif d != top: 52 | # Do not look for packages in subfolders if current is not a package 53 | dirs[:] = [] 54 | return packages 55 | 56 | 57 | # ============================================================================== 58 | # Variables 59 | # ============================================================================== 60 | 61 | DESCRIPTION = "Altair backend for pandas plotting." 
62 | LONG_DESCRIPTION = read("README.md") 63 | LONG_DESCRIPTION_CONTENT_TYPE = "text/markdown" 64 | NAME = "altair_pandas" 65 | PACKAGES = find_packages() 66 | AUTHOR = "Jake VanderPlas" 67 | AUTHOR_EMAIL = "jakevdp@google.com" 68 | URL = "http://github.com/altair-viz/altair_pandas/" 69 | DOWNLOAD_URL = "http://github.com/altair-viz/altair_pandas/" 70 | LICENSE = "BSD 3-clause" 71 | INSTALL_REQUIRES = get_install_requirements("requirements.txt") 72 | VERSION = version("altair_pandas/__init__.py")  # parsed from __version__ 73 | ENTRYPOINTS = {"pandas_plotting_backends": ["altair = altair_pandas"]} 74 | 75 | # Registers this package as "altair" in the pandas_plotting_backends group. 76 | setup( 77 | name=NAME, 78 | version=VERSION, 79 | description=DESCRIPTION, 80 | long_description=LONG_DESCRIPTION, 81 | long_description_content_type=LONG_DESCRIPTION_CONTENT_TYPE, 82 | author=AUTHOR, 83 | author_email=AUTHOR_EMAIL, 84 | url=URL, 85 | download_url=DOWNLOAD_URL, 86 | license=LICENSE, 87 | packages=PACKAGES, 88 | entry_points=ENTRYPOINTS, 89 | include_package_data=True, 90 | install_requires=INSTALL_REQUIRES, 91 | python_requires=">=3.6", 92 | classifiers=[ 93 | "Development Status :: 5 - Production/Stable", 94 | "Environment :: Console", 95 | "Intended Audience :: Science/Research", 96 | "License :: OSI Approved :: BSD License", 97 | "Natural Language :: English", 98 | "Programming Language :: Python :: 3.6", 99 | "Programming Language :: Python :: 3.7", 100 | "Programming Language :: Python :: 3.8", 101 | ], 102 | ) 103 | --------------------------------------------------------------------------------