├── .flake8
├── .github
    └── workflows
    │   └── tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── MANIFEST.in
├── README.md
├── examples
    ├── sample_figure.png
    └── sample_figure_layout.png
├── figrid
    ├── __init__.py
    └── example_figures
    │   └── __init__.py
├── figrid_example_notebook.ipynb
├── setup.cfg
├── setup.py
└── tests
    ├── __init__.py
    └── test_figrid.py


/.flake8:
--------------------------------------------------------------------------------
[flake8]
# Match black's default line length
max-line-length = 88
# See https://github.com/PyCQA/pycodestyle/issues/373
extend-ignore = E203
exclude = .git,__pycache__,build,dist
per-file-ignores =
    # Allow unused imports in __init__.py
    __init__.py: F401

--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
name: Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.13"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e ".[dev]"
      - name: Check formatting with black
        run: |
          black --check .
      - name: Lint with flake8
        run: |
          flake8 .

  test:
    needs: lint
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e ".[dev]"
      - name: Run tests with coverage
        run: |
          pytest --cov=figrid --cov-report=term-missing

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Testing
.coverage
.coverage.*
.pytest_cache/
htmlcov/
coverage.xml

# IDE
.idea/
.vscode/
*.swp
.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# visual studio code
.vscode

# dist
/dist

# notebook checkpoints
*checkpoint*

--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
repos:
  - repo: https://github.com/psf/black
    rev: 23.12.1
    hooks:
      - id: black
        language_version: python3
  - repo: https://github.com/pycqa/flake8
    rev: 6.1.0
    hooks:
      - id: flake8

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Doug Ollerenshaw

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
recursive-include figrid *.py
include README.md LICENSE

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# figrid
A wrapper for the matplotlib gridspec function. Designed to make it easy to place axes on a pre-defined grid on a figure canvas. For example, maybe you want to lay out axes like this:

Example Layout

## try it out in colab
Open In Colab

## how it works
The fundamental function to use is `place_axes_on_grid`.
This will generate an evenly spaced 100x100 grid on the desired figure canvas. You can then specify how much of the figure canvas a given axis (or set of axes) will span.

## what it's good for
Maybe it's just me, but I've always found matplotlib's `gridspec` function to be confusing. And simple NxM grids with the `subplots` function can be too limiting. This makes it easy to place any number of axes at arbitrary locations on a figure. It's handy for making figures for publication.

## a sample workflow
1) Make some functions to generate the various subplots you want to display on a figure. Those functions should take an axis handle as an input.
2) Define a figure canvas of the desired size.
3) Define your axes, specifying their locations using `figrid.place_axes_on_grid()` (a dictionary is a useful data structure for storing your axis handles).
4) Call your plotting functions with the axes as inputs.
5) Add some axis labels using `figrid.add_labels()` that you can refer to from your figure legend.

## installation:

    pip install figrid

For development installation with testing dependencies:

    pip install -e ".[dev]"

## syntax
`figrid.place_axes_on_grid` takes the following inputs:
* fig - the figure handle on which the axis will be placed
* xspan - a two-element list or tuple defining the left and right edges of the axis, respectively. Numbers should be floats ranging from 0 to 1; they are mapped onto the 100x100 grid with `int(100 * value)`, so positions resolve to 2 decimal places.
* yspan - a two-element list or tuple defining the top and bottom edges of the axis, respectively. Numbers should be floats ranging from 0 to 1; they are mapped onto the 100x100 grid with `int(100 * value)`, so positions resolve to 2 decimal places.
* dim - a two-element list or tuple defining the number of rows and columns of subaxes. Default = [1, 1], giving a single axis.
* wspace - a float defining the horizontal (width) space between subplots (if dim is specified)
* hspace - a float defining the vertical (height) space between subplots (if dim is specified)

## sample use:

some imports:

    # import the package as fg
    import figrid as fg

    # import example figure code
    import figrid.example_figures as example_figures

    # import matplotlib
    import matplotlib.pyplot as plt

define a function to lay out the axes on a figure

    # define function to set up figure and axes
    def make_fig_ax():
        fig = plt.figure(figsize=(11,8.5))
        ax = {
            'panel_A': fg.place_axes_on_grid(fig, xspan=[0.05, 0.3], yspan=[0.05, 0.45]),
            'panel_B': fg.place_axes_on_grid(fig, xspan=[0.4, 1], yspan=[0.05, 0.45], dim=[3, 1], hspace=0.4),
            'panel_C': fg.place_axes_on_grid(fig, xspan=[0.05, 0.4], yspan=[0.57, 1]),
            'panel_D': fg.place_axes_on_grid(fig, xspan=[0.5, 1], yspan=[0.57, 1])
        }

        return fig, ax

make the figure

    # call function to make figure and axes
    fig, ax = make_fig_ax()

    # call individual plotting functions, with axes as inputs
    example_figures.heatmap(ax['panel_A'])
    example_figures.sinusoids(ax['panel_B'])
    example_figures.violins(ax['panel_C'])
    example_figures.scatterplot(ax['panel_D'])

add some labels

    labels = [
        {'label_text':'A', 'xpos':0,    'ypos':0.05, 'fontsize':20, 'weight': 'bold', 'ha': 'right', 'va': 'bottom'},
        {'label_text':'B', 'xpos':0.37, 'ypos':0.05, 'fontsize':20, 'weight': 'bold', 'ha': 'right', 'va': 'bottom'},
        {'label_text':'C', 'xpos':0,    'ypos':0.55, 'fontsize':20, 'weight': 'bold', 'ha': 'right', 'va': 'bottom'},
        {'label_text':'D', 'xpos':0.45, 'ypos':0.55, 'fontsize':20, 'weight': 'bold', 'ha': 'right', 'va': 'bottom'},
    ]
    fg.add_labels(fig, labels)

Then we have this:

Example Figure

## development and testing

[![Tests](https://github.com/dougollerenshaw/figrid/actions/workflows/tests.yml/badge.svg)](https://github.com/dougollerenshaw/figrid/actions/workflows/tests.yml)

For development, install with testing dependencies:

    pip install -e ".[dev]"

To set up pre-commit hooks for automatic code formatting:

    pre-commit install

This will automatically run black (code formatter) and flake8 (linter) on your commits.

To run the tests:

    pytest

To run tests with coverage reporting:

    pytest --cov=figrid --cov-report=term-missing

To manually format code:

    black .

To manually check code style:

    flake8 .

Tests and code quality checks are automatically run on push and pull request to the main branch using GitHub Actions.

--------------------------------------------------------------------------------
/examples/sample_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dougollerenshaw/figrid/d781c7ff14feb18ae0ca0c322ac03aead9aab266/examples/sample_figure.png

--------------------------------------------------------------------------------
/examples/sample_figure_layout.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dougollerenshaw/figrid/d781c7ff14feb18ae0ca0c322ac03aead9aab266/examples/sample_figure_layout.png

--------------------------------------------------------------------------------
/figrid/__init__.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

__version__ = "0.1.7"


def place_axes_on_grid(
    fig,
    dim=[1, 1],
    xspan=[0, 1],
    yspan=[0, 1],
    wspace=None,
    hspace=None,
    sharex=False,
    sharey=False,
    frameon=True,
):
    """
    Takes a figure with a gridspec defined and places an array of sub-axes on a portion
    of the gridspec.

    Takes as arguments:
        fig: figure handle - required
        dim: number of rows and columns in the subaxes - defaults to 1x1
        xspan: fraction of figure that the subaxes subtends in the x-direction
            (0 = left edge, 1 = right edge)
        yspan: fraction of figure that the subaxes subtends in the y-direction
            (0 = top edge, 1 = bottom edge)
        wspace and hspace: white space between subaxes in horizontal and vertical
            directions, respectively
    returns:
        subaxes handles
    """

    outer_grid = gridspec.GridSpec(100, 100)
    inner_grid = gridspec.GridSpecFromSubplotSpec(
        dim[0],
        dim[1],
        subplot_spec=outer_grid[
            int(100 * yspan[0]) : int(100 * yspan[1]),
            int(100 * xspan[0]) : int(100 * xspan[1]),
        ],
        wspace=wspace,
        hspace=hspace,
    )

    # NOTE: A cleaner way to do this is with list comprehension:
    # inner_ax = [[0 for ii in range(dim[1])] for ii in range(dim[0])]
    inner_ax = dim[0] * [
        dim[1] * [fig]
    ]  # fill with figure objects to prevent later errors
    inner_ax = np.array(inner_ax)
    idx = 0
    for row in range(dim[0]):
        for col in range(dim[1]):
            if row > 0 and sharex:
                share_x_with = inner_ax[0][col]
            else:
                share_x_with = None

            if col > 0 and sharey:
                share_y_with = inner_ax[row][0]
            else:
                share_y_with = None

            inner_ax[row][col] = plt.Subplot(
                fig,
                inner_grid[idx],
                sharex=share_x_with,
                sharey=share_y_with,
                frameon=frameon,
            )

            if row == dim[0] - 1 and sharex:
                # For shared x-axes, only show tick labels on the bottom subplot
                inner_ax[row][col].xaxis.set_ticks_position("bottom")
            elif row < dim[0] and sharex:
                # Hide tick labels (but keep ticks) on all but the bottom subplot
                plt.setp(inner_ax[row][col].get_xticklabels(), visible=False)

            if col == 0 and sharey:
                # For shared y-axes, only show tick labels on the leftmost subplot
                inner_ax[row][col].yaxis.set_ticks_position("left")
            elif col > 0 and sharey:
                # Hide tick labels (but keep ticks) on all but the leftmost subplot
                plt.setp(inner_ax[row][col].get_yticklabels(), visible=False)

            fig.add_subplot(inner_ax[row, col])
            idx += 1

    inner_ax = np.array(inner_ax).squeeze().tolist()  # remove redundant dimension
    return inner_ax


def add_label(fig, label_text, xpos, ypos, **kwargs):
    """
    add a single label to a figure canvas using the place_axes_on_grid infrastructure
    inputs:
        fig: figure handle
        label_text : text of label,
        xpos, ypos: floats from 0 to 1 defining where on the canvas the label should be
        kwargs: additional keyword arguments for matplotlib text()
    """
    label_axis = place_axes_on_grid(
        fig,
        xspan=[xpos, xpos + 0.01],
        yspan=[ypos, ypos + 0.01],
    )
    label_axis.text(0, 0, label_text, **kwargs)
    label_axis.axis("off")


def add_labels(fig, labels):
    """
    Add multiple labels to a figure canvas using the place_axes_on_grid infrastructure.

    inputs:
        fig: figure handle
        labels: a list of dictionaries with the following key/value pairs:
            * label_text (required): text of label
            * xpos (required): float from 0 to 1 defining horizontal position
            * ypos (required): float from 0 to 1 defining vertical position
            * any additional keyword arguments that can be passed to the matplotlib
              text function (e.g., fontsize, weight, etc)
    """
    for label in labels:
        add_label(fig, **label)


def scalebar(
    axis,
    x_pos,
    y_pos,
    x_length=None,
    y_length=None,
    x_text=None,
    y_text=None,
    x_buffer=0.25,
    y_buffer=0.25,
    scalebar_color="black",
    text_color="black",
    fontsize=10,
    linewidth=3,
):
    """
    add a scalebar
    input params:
        axis: axis on which to add scalebar
        x_pos: x position, in data coordinates (pixels for an image axis)
        y_pos: y position, in data coordinates (pixels for an image axis)
    """
    if x_length is not None:
        axis.plot(
            [x_pos, x_pos + x_length],
            [y_pos, y_pos],
            color=scalebar_color,
            linewidth=linewidth,
        )
        axis.text(
            x_pos + x_length / 2,
            y_pos - y_buffer,
            x_text,
            color=text_color,
            fontsize=fontsize,
            ha="center",
            va="top",
        )

    if y_length is not None:
        axis.plot(
            [x_pos, x_pos],
            [y_pos, y_pos + y_length],
            color=scalebar_color,
            linewidth=linewidth,
        )

        axis.text(
            x_pos - x_buffer,
            y_pos + y_length / 2,
            y_text,
            color=text_color,
            fontsize=fontsize,
            ha="right",
            va="center",
        )

--------------------------------------------------------------------------------
/figrid/example_figures/__init__.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import figrid as fg


def heatmap(axis):
    """
    plot a random pixel image
    input: a single axis handle
    """
    axis.imshow(np.random.randn(100, 100), cmap="gray")
    axis.set_title("An image")
    axis.axis("off")

    # add a scalebar
    fg.scalebar(
        axis=axis,
        x_pos=10,
        y_pos=85,
        x_length=30,
        x_text="60 um",
        scalebar_color="white",
        text_color="white",
        fontsize=12,
        y_buffer=-2,
    )


def sinusoids(axis):
    """
    plot 3 sinusoids plus a scalebar
    input: a 3 row by 1 column array of axis handles
    """
    t = np.arange(0, 10, 0.01)
    for row in range(3):
        f = 0.8 * (row + 1)
        axis[row].plot(t, np.sin(2 * np.pi * f * t), color="black", linewidth=2)
        axis[row].axis("off")
        axis[row].set_xlim(-0.5, 10.5)
        axis[row].set_ylim(-1.35, 1.05)
        axis[row].set_title("frequency = {:0.1f} Hz".format(f))

    # add a scalebar
    fg.scalebar(
        axis=axis[2],
        x_pos=-0.25,
        y_pos=-1.25,
        x_length=1,
        y_length=1,
        x_text="1 s",
        y_text="1 u",
    )


def violins(axis):
    """
    violinplot example from:
    https://seaborn.pydata.org/examples/simple_violinplots.html
    """
    # Create a random dataset across several variables
    rs = np.random.default_rng(0)
    n, p = 40, 8
    d = rs.normal(0, 2, (n, p))
    d += np.log(np.arange(1, p + 1)) * -5 + 10

    # Show each distribution with both violins and points
    sns.violinplot(data=d, palette="light:g", inner="points", orient="h", ax=axis)
    sns.despine()


def scatterplot(axis):
    """
    scatterplot example from:
    https://seaborn.pydata.org/examples/layered_bivariate_plot.html
    """
    # Simulate data from a bivariate Gaussian
    n = 10000
    mean = [0, 0]
    cov = [(2, 0.4), (0.4, 0.2)]
    rng = np.random.RandomState(0)
    x, y = rng.multivariate_normal(mean, cov, n).T

    # Draw a scatterplot with density contours
    sns.scatterplot(x=x, y=y, s=5, color=".15", ax=axis)
    sns.kdeplot(x=x, y=y, levels=5, color="w", linewidths=1, ax=axis)

--------------------------------------------------------------------------------
/figrid_example_notebook.ipynb:
--------------------------------------------------------------------------------
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "figrid_example_notebook.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "authorship_tag": "ABX9TyO01ULmqHO0gKMYxabl74rd",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "\"Open"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iU5XTqLcBktY"
      },
      "source": [
        "# A simple example use case\n",
        "This example shows how to place four figures on a figure canvas using `figrid.place_axes_on_grid`. \n",
        "It uses plots that are generated in the `example_figures` module of the package. \n",
        "Plots are arranged in desired locations, then figure labels are added."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yEPnvyXhBngr"
      },
      "source": [
        "## Install figrid\n",
        "First install figrid in the current environment using pip"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OVpb05qhBM7N"
      },
      "source": [
        "!pip install figrid"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3qHXBOH8Bz9S"
      },
      "source": [
        "## Imports\n",
        "Now import figrid, example_figure definitions, and matplotlib"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "k0GEpkUvBO0Q"
      },
      "source": [
        "# import the package as fg\n",
        "import figrid as fg\n",
        "\n",
        "# import code for example figures\n",
        "import figrid.example_figures as example_figures\n",
        "\n",
        "# import matplotlib\n",
        "import matplotlib.pyplot as plt"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_kPNwzuKByKU"
      },
      "source": [
        "## Define a function to make a figure and place axes\n",
        "Use `figrid.place_axes_on_grid` to define four axes on an 11 x 8.5 inch figure canvas"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wxBkoalxBTwU"
      },
      "source": [
        "def make_fig_ax():\n",
        "    fig = plt.figure(figsize=(11,8.5))\n",
        "    ax = {\n",
        "        'panel_A': fg.place_axes_on_grid(fig, xspan=[0.05, 0.3], yspan=[0.05, 0.45]),\n",
        "        'panel_B': fg.place_axes_on_grid(fig, xspan=[0.4, 1], yspan=[0.05, 0.45], dim=[3, 1], hspace=0.4),\n",
        "        'panel_C': fg.place_axes_on_grid(fig, xspan=[0.05, 0.4], yspan=[0.57, 1]),\n",
        "        'panel_D': fg.place_axes_on_grid(fig, xspan=[0.5, 1], yspan=[0.57, 1])\n",
        "    }\n",
        "    \n",
        "    return fig, ax"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kyV1sguvB_lX"
      },
      "source": [
        "## Make the figure\n",
        "Add the plots to the axes, then add some labels"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 503
        },
        "id": "CjixSobGBa_I",
        "outputId": "a77e739d-5bca-4eb4-e463-ee1abe9326fa"
      },
      "source": [
        "fig, ax = make_fig_ax()\n",
        "\n",
        "example_figures.heatmap(ax['panel_A'])\n",
        "example_figures.sinusoids(ax['panel_B'])\n",
        "example_figures.violins(ax['panel_C'])\n",
        "example_figures.scatterplot(ax['panel_D'])\n",
        "\n",
        "# add labels\n",
        "labels = [\n",
        "    {'label_text':'A', 'xpos':0, 'ypos':0.05, 'fontsize':20, 'weight': 'bold', 'ha': 'right', 'va': 'bottom'},\n",
        "    {'label_text':'B', 'xpos':0.37, 'ypos':0.05, 'fontsize':20, 'weight': 'bold', 'ha': 'right', 'va': 'bottom'},\n",
        "    {'label_text':'C', 'xpos':0, 'ypos':0.55, 'fontsize':20, 'weight': 'bold', 'ha': 'right', 'va': 'bottom'},\n",
        "    {'label_text':'D', 'xpos':0.45, 'ypos':0.55, 'fontsize':20, 'weight': 'bold', 'ha': 'right', 'va': 'bottom'},\n",
        "]\n",
        "fg.add_labels(fig, labels)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "\n",
            "text/plain": [
              ""
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    }
  ]
}

--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[metadata]
long_description = file: README.md
long_description_content_type = text/markdown

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

"""
figrid provides a set of convenience functions to make it easier to
place axes on a grid using matplotlib
"""
setup(
    name="figrid",
    version="0.1.7",
    packages=["figrid"],
    include_package_data=True,
    description="Formats multipanel figures",
    url="https://github.com/dougollerenshaw/figrid",
    author="Doug Ollerenshaw",
    author_email="d.ollerenshaw@gmail.com",
    license="MIT",
    install_requires=["matplotlib", "numpy", "seaborn"],
    extras_require={
        "dev": [
            "pytest>=7.0.0",
            "pytest-cov>=4.0.0",
            "black>=23.0.0",
            "flake8>=6.0.0",
            "pre-commit>=3.0.0",
        ],
    },
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Natural Language :: English",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Programming Language :: Python :: 3.13",
    ],
)

--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dougollerenshaw/figrid/d781c7ff14feb18ae0ca0c322ac03aead9aab266/tests/__init__.py

--------------------------------------------------------------------------------
/tests/test_figrid.py:
--------------------------------------------------------------------------------
import pytest
import matplotlib.pyplot as plt
from figrid import place_axes_on_grid, add_label, add_labels, scalebar
from figrid.example_figures import heatmap, sinusoids, violins, scatterplot


@pytest.fixture
def figure():
    """Create a new figure for each test."""
    fig = plt.figure(figsize=(10, 10))
    yield fig
    plt.close(fig)


def test_example_layout(figure):
    """Test the complete layout example from README."""
    ax = {
        "panel_A": place_axes_on_grid(figure, xspan=[0.05, 0.3], yspan=[0.05, 0.45]),
        "panel_B": place_axes_on_grid(
            figure, xspan=[0.4, 1], yspan=[0.05, 0.45], dim=[3, 1], hspace=0.4
        ),
        "panel_C": place_axes_on_grid(figure, xspan=[0.05, 0.4], yspan=[0.57, 1]),
        "panel_D": place_axes_on_grid(figure, xspan=[0.5, 1], yspan=[0.57, 1]),
    }

    # Verify all panels exist
    assert all(key in ax for key in ["panel_A", "panel_B", "panel_C", "panel_D"])

    # Verify panel_B is a 3x1 grid
    assert len(ax["panel_B"]) == 3

    # Add example labels
    labels = [
        {
            "label_text": "A",
            "xpos": 0,
            "ypos": 0.05,
            "fontsize": 20,
            "weight": "bold",
            "ha": "right",
            "va": "bottom",
        },
        {
            "label_text": "B",
            "xpos": 0.37,
            "ypos": 0.05,
            "fontsize": 20,
            "weight": "bold",
            "ha": "right",
            "va": "bottom",
        },
        {
            "label_text": "C",
            "xpos": 0,
            "ypos": 0.55,
            "fontsize": 20,
            "weight": "bold",
            "ha": "right",
            "va": "bottom",
        },
        {
            "label_text": "D",
            "xpos": 0.45,
            "ypos": 0.55,
            "fontsize": 20,
            "weight": "bold",
            "ha": "right",
            "va": "bottom",
        },
    ]
    add_labels(figure, labels)


def test_single_axis(figure):
    """Test creating a single axis."""
    ax = place_axes_on_grid(figure)
    assert isinstance(ax, plt.Axes)


def test_multi_axis_grid(figure):
    """Test creating a multi-axis grid."""
    axes = place_axes_on_grid(figure, dim=[2, 2])
    assert len(axes) == 2
    assert len(axes[0]) == 2


def test_shared_axes(figure):
    """Test axes with a shared x axis."""
    axes = place_axes_on_grid(figure, dim=[2, 1], sharex=True)

    # Axes should exist and be properly configured
    assert len(axes) == 2

    # Bottom axis should show ticks, top axis should share x scale
    assert axes[0].get_shared_x_axes().joined(axes[0], axes[1])
    assert axes[1].xaxis.get_ticks_position() == "bottom"

    # Check that top axis has hidden tick labels
    assert not any(label.get_visible() for label in axes[0].get_xticklabels())
    assert all(label.get_visible() for label in axes[1].get_xticklabels())


def test_axis_positioning(figure):
    """Test custom axis positioning."""
    ax = place_axes_on_grid(figure, dim=[1, 1], xspan=[0.2, 0.8], yspan=[0.2, 0.8])
    bbox = ax.get_position()
    # Allow for some margin adjustment while ensuring the axis is roughly where expected
    assert 0.2 < bbox.x0 < 0.3  # Axis starts in the first third
    assert 0.7 < bbox.x1 < 0.8  # Axis ends in the last third
    assert 0.2 < bbox.y0 < 0.3  # Similar bounds for y-axis


def test_single_label(figure):
    """Test adding a single label."""
    label_text = "Test Label"
    add_label(figure, label_text, 0.5, 0.5, fontsize=12)
    all_texts = []
    for ax in figure.axes:
        all_texts.extend(ax.texts)
    assert any(text.get_text() == label_text for text in all_texts)


def test_multiple_labels(figure):
    """Test adding multiple labels."""
    labels = [
        {"label_text": "A", "xpos": 0.1, "ypos": 0.1, "fontsize": 12},
        {"label_text": "B", "xpos": 0.9, "ypos": 0.1, "fontsize": 12},
    ]
    add_labels(figure, labels)
    all_texts = []
    for ax in figure.axes:
        all_texts.extend(ax.texts)
    assert any(text.get_text() == "A" for text in all_texts)
    assert any(text.get_text() == "B" for text in all_texts)


def test_scalebar_creation(figure):
    """Test adding a scalebar to an axis."""
    ax = place_axes_on_grid(figure)
    scalebar(
        ax,
        x_pos=0.5,
        y_pos=0.5,
        x_length=1.0,
        y_length=1.0,
        x_text="1 unit",
        y_text="1 unit",
    )
    assert len(ax.lines) > 0  # Has at least one line
    assert len(ax.texts) > 0  # Has at least one text element


def test_example_heatmap(figure):
    """Test the heatmap example figure."""
    ax = place_axes_on_grid(figure)
    heatmap(ax)
    # Check that image was added
    assert len(ax.images) == 1
    # Check that scalebar was added
    assert len(ax.lines) > 0
    assert len(ax.texts) > 0


def test_example_sinusoids(figure):
    """Test the sinusoids example figure."""
    axes = place_axes_on_grid(figure, dim=[3, 1])
    sinusoids(axes)
    # Check that each subplot has the sinusoid
    for ax in axes:
        assert len(ax.lines) > 0  # At least one line exists
    # Check that scalebar was added to bottom plot
    assert len(axes[2].texts) > 0


def test_example_violins(figure):
    """Test the violin plot example figure."""
    ax = place_axes_on_grid(figure)
    violins(ax)
    # Check that violin plot elements exist
    assert len(ax.collections) > 0  # Violin shapes and points


def test_example_scatterplot(figure):
    """Test the scatter plot example figure."""
    ax = place_axes_on_grid(figure)
    scatterplot(ax)
    # Check that plot elements exist
    assert len(ax.collections) > 0  # Scatter points and contours

--------------------------------------------------------------------------------