├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── .gitignore ├── LICENSE ├── README.md ├── github └── parallel_matplotlib_grid.svg ├── parallelplot ├── __init__.py └── plot.py └── setup.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG] " 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **Minimal Code To Reproduce** 14 | Minimal code example to reproduce the bug. 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Screenshots** 20 | If applicable, add screenshots to help explain your problem. 21 | 22 | **Environment (please complete the following information):** 23 | - OS: [e.g. Ubuntu 18.04] 24 | - Python [e.g. Python 3.8] 25 | - Module Version or git tag / hash etc. [e.g. 0.1] 26 | 27 | **Additional context** 28 | Add any other context about the problem here. 29 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[REQUEST]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 
18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Paul Gavrikov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Parallel generation of grid-like plots using matplotlib 2 | 3 | This Python 3 module helps you speed up the generation of subplots in pseudo-parallel mode using `matplotlib` and `multiprocessing`. This can be useful if you are dealing with expensive preprocessing or plotting tasks, such as a violin plot per subplot. 4 | 5 | ![Operation overview](github/parallel_matplotlib_grid.svg) 6 | 7 | ## How does it work? 8 | 9 | This library uses Python's `multiprocessing` module to plot each cell individually. If provided, each process will first evaluate a user-defined preprocessing function. Afterwards, every process will call a second user-defined plotting function, providing matplotlib axes to plot on. All created plots are then stored as images, retrieved by the main process, and assembled into a grid of subplots without any decoration. 10 | 11 | ## How do I install this module? 12 | 13 | This module is in a very early stage, so no PyPI releases are currently provided. However, you can simply install this module from git: 14 | ```bash 15 | pip install git+https://github.com/paulgavrikov/parallel-matplotlib-grid/ 16 | ``` 17 | 18 | ## How do I use it? 19 | 20 | Aside from the data, all you need to provide is the grid layout `grid_shape` and a plotting function `plot_fn`.
21 | Here is an example: 22 | 23 | ```python 24 | from parallelplot import parallel_plot 25 | 26 | import matplotlib.pyplot as plt 27 | import numpy as np 28 | 29 | 30 | def violin(data, fig, axes): 31 | axes.violinplot(data) 32 | 33 | 34 | # Gen some fake data 35 | X = np.random.uniform(low=-1, high=1, size=(30, 512, 512)) 36 | 37 | parallel_plot(plot_fn=violin, data=X, grid_shape=(3, 10)) 38 | plt.show() 39 | ``` 40 | 41 | Want to preprocess your data before plotting? No problem! Just provide `preprocessing_fn`. 42 | Here is an example where we apply a PCA transformation: 43 | 44 | ```python 45 | from parallelplot import parallel_plot 46 | 47 | import matplotlib.pyplot as plt 48 | import numpy as np 49 | from sklearn.decomposition import PCA 50 | 51 | 52 | def preprocess(data): 53 | return PCA().fit_transform(data) 54 | 55 | 56 | def violin(data, fig, axes): 57 | axes.violinplot(data) 58 | 59 | 60 | # Gen some fake data 61 | X = np.random.uniform(low=-1, high=1, size=(30, 512, 512)) 62 | 63 | parallel_plot(plot_fn=violin, data=X, grid_shape=(3, 10), preprocessing_fn=preprocess) 64 | plt.show() 65 | 66 | ``` 67 | 68 | ## When should I *not* use this library? 69 | 70 | There are some cases where this module is either useless or adds overhead. Here are a few of those: 71 | - Your plot function and preprocessing functions execute fast, but your data is big. `multiprocessing` uses `pickle` as 72 | the input / output format of process tasks, which requires data to be serialized. This can introduce a significant 73 | overhead. 74 | - Your data is larger than 4 GiB. Depending on your Python version, `multiprocessing` may use an older `pickle` format that only supports 75 | data up to 4 GiB in size. There are ways to bypass that, but as pickling is slow, the computational overhead is 76 | probably not worth it. 77 | - You only have one core available. Sorry 'bout that. 78 | 79 | 80 | ## How do I contribute?
81 | 82 | Just create a PR or feel free to raise an issue for questions, feature-requests etc. 83 | -------------------------------------------------------------------------------- /github/parallel_matplotlib_grid.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 |
DATA
DATA
PREPROCESSING
FUNC (OPTIONAL)
PREPROCESSING...
PLOT FUNC
PLOT FUNC
PROCESS POOL
PROCESS POOL
Viewer does not support full SVG 1.1
-------------------------------------------------------------------------------- /parallelplot/__init__.py: -------------------------------------------------------------------------------- 1 | from parallelplot.plot import parallel_plot -------------------------------------------------------------------------------- /parallelplot/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import os 4 | import shutil 5 | import functools 6 | from multiprocessing import get_context 7 | import io 8 | import numpy as np 9 | from PIL import Image 10 | 11 | 12 | CACHE_DIR = ".figcache" 13 | DPI = 96 14 | 15 | 16 | # noinspection PyUnresolvedReferences 17 | def _parallel_plot_worker(args, plot_fn, fig_size, in_memory, preprocessing_fn=None, dpi=96, pad_inches=0.1): 18 | index, data = args 19 | 20 | if preprocessing_fn is not None: 21 | data = preprocessing_fn(data) 22 | 23 | fig = plt.figure(figsize=fig_size) 24 | matplotlib.font_manager._get_font.cache_clear() # necessary to reduce text corruption artifacts 25 | axes = plt.axes() 26 | 27 | plot_fn(data, fig, axes) 28 | 29 | if not in_memory: 30 | path = f"{CACHE_DIR}/{index}.temp.png" 31 | plt.savefig(path, format="png", bbox_inches="tight", dpi=dpi, pad_inches=pad_inches) 32 | plt.close() 33 | return index, path 34 | else: 35 | buf = io.BytesIO() 36 | fig.savefig(buf, format="png", bbox_inches="tight", dpi=dpi, pad_inches=pad_inches) 37 | buf.seek(0) 38 | img = np.array(Image.open(buf)) 39 | buf.close() 40 | plt.close() 41 | return index, img 42 | 43 | 44 | def _make_subplots(data, plot_fn, n_rows, n_cols, grid_cell_size, total=None, preprocessing_fn=None, switch_axis=False, 45 | show_progress=True, in_memory=True, mp_context=None, dpi=96, pad_inches=0.1, max_workers=-1): 46 | 47 | if total is None: 48 | total = len(data) 49 | 50 | worker_func = functools.partial(_parallel_plot_worker, plot_fn=plot_fn, fig_size=grid_cell_size, 51 
| in_memory=in_memory, preprocessing_fn=preprocessing_fn, dpi=dpi, 52 | pad_inches=pad_inches) 53 | 54 | subplot_results = np.empty((n_rows, n_cols), dtype=object) 55 | 56 | workers = min(total, os.cpu_count()) 57 | if max_workers != -1: 58 | workers = min(workers, max_workers) 59 | 60 | with get_context(mp_context).Pool(workers) as pool: 61 | 62 | iterator = pool.imap_unordered(worker_func, zip(range(total), data)) 63 | if show_progress: 64 | from tqdm.auto import tqdm 65 | iterator = tqdm(iterator, total=total) 66 | 67 | for index, img in iterator: 68 | 69 | if not in_memory: 70 | img = plt.imread(img) 71 | 72 | if switch_axis: 73 | c = int(index / n_rows) 74 | r = index % n_rows 75 | else: 76 | c = index % n_cols 77 | r = int(index / n_cols) 78 | 79 | subplot_results[r, c] = img 80 | 81 | return subplot_results 82 | 83 | 84 | def parallel_plot(plot_fn, data, grid_shape, total=None, preprocessing_fn=None, col_labels=None, row_labels=None, 85 | grid_cell_size=(6, 12), switch_axis=False, cleanup=True, show_progress=True, in_memory=True, 86 | mp_context=None, max_workers=-1): 87 | """ 88 | Generate a grid of plots, where each plot inside the grid is generated by another process, 89 | effectively allowing parallel plot generation. 90 | 91 | :param plot_fn: Plot function that will be called from the process context. Lambda expressions are not supported. 92 | :param data: Iterable data with length of rows * cols. 93 | :param grid_shape: Shape of the grid as (rows, cols) tuple. For a horizontal list provide (N, 1) and for a vertical 94 | list provide (1, N). 95 | :param total: Length of the data. Must be provided if the passed data length cannot be accessed by calling len() 96 | e.g. on generators 97 | :param preprocessing_fn: Optional preprocessing function that is called from the process context on the data chunk 98 | before plotting. Lambda expressions are not supported. 99 | :param col_labels: Optional list of column labels. 
100 | :param row_labels: Optional list of row labels. 101 | :param grid_cell_size: Size of each cell (subplot) as (width, height) tuple. This has a direct impact on the parent 102 | plot size. 103 | :param switch_axis: If false the grid will be populated from left to right, top to bottom. Otherwise, it will be 104 | populated top to bottom, left to right. 105 | :param cleanup: If true, the generated cache directory will be deleted before finishing. Can be useful for 106 | debugging. Only active when in_memory is False. 107 | :param show_progress: If true, shows a progressbar of the plotting. Requires the tqdm module. 108 | :param in_memory: If true (Default) will pass images directly to main instead of writing to drive. 109 | :param mp_context: str that identifies which spawn-method multiprocessing should use. OS- dependant and typically 110 | one of "fork", "forksever", or "spawn". Leave None for system default. 111 | :param max_workers: Maximum number of worker processes to spawn. Leave at -1 to use as many as possible. 112 | :return: fig, axes of the parent plot. 
113 | """ 114 | 115 | if not in_memory and not os.path.exists(CACHE_DIR): 116 | os.mkdir(CACHE_DIR) 117 | 118 | n_rows = grid_shape[0] 119 | n_cols = grid_shape[1] 120 | 121 | full_fig_size = (n_cols * grid_cell_size[0], n_rows * grid_cell_size[1]) 122 | 123 | fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=full_fig_size) 124 | 125 | if col_labels is not None: 126 | for label, ax in zip(col_labels, axes[0]): 127 | ax.set_title(label, loc="center", wrap=True) 128 | 129 | if row_labels is not None: 130 | for label, ax in zip(row_labels, axes[:, 0]): 131 | ax.set_ylabel(label, loc="center", wrap=True) 132 | 133 | for ax in axes.ravel(): 134 | ax.get_xaxis().set_ticks([]) 135 | ax.get_yaxis().set_ticks([]) 136 | for spine in ax.spines.values(): 137 | spine.set_visible(False) 138 | 139 | subplots = _make_subplots(data, plot_fn, n_rows, n_cols, grid_cell_size, total, preprocessing_fn, switch_axis, 140 | show_progress, in_memory, mp_context, dpi=DPI, max_workers=max_workers) 141 | 142 | for ax, img in zip(axes.ravel(), subplots.ravel()): 143 | ax.imshow(img) 144 | 145 | plt.subplots_adjust(hspace=0, wspace=0) 146 | 147 | if not in_memory and cleanup: 148 | shutil.rmtree(CACHE_DIR) 149 | 150 | return fig, axes 151 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="parallelplot", 8 | version="0.0.4", 9 | author="Paul Gavrikov", 10 | author_email="paul.gavrikov@hs-offenburg.de", 11 | description="Parallel plotting of matplotlib subplots.", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/paulgavrikov/parallel-matplolib-grid", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | 
"License :: OSI Approved :: MIT License", 19 | "Operating System :: OS Independent", 20 | ], 21 | install_requires=[ 22 | "matplotlib", 23 | "pillow", 24 | "numpy" 25 | ], 26 | python_requires=">=3.6" 27 | ) 28 | --------------------------------------------------------------------------------