├── .gitignore
├── LICENCE
├── README.md
├── doc
├── READMEExample.svg
└── Summary.svg
├── examples
├── README_example.ipynb
├── contour_colorbar.ipynb
├── grid_axes.ipynb
├── one_axes.ipynb
├── profile_solver.py
├── ten_simple_rules_demo.ipynb
├── tutorial.ipynb
└── two_axes.ipynb
├── logo.svg
├── pyproject.toml
├── src
└── mpllayout
│ ├── __init__.py
│ ├── constraints.py
│ ├── constructions.py
│ ├── containers.py
│ ├── geometry.py
│ ├── layout.py
│ ├── matplotlibutils.py
│ ├── primitives.py
│ ├── solver.py
│ └── ui.py
└── tests
├── __init__.py
├── fixture_primitives.py
├── test_constraints.py
├── test_constructions.py
├── test_containers.py
├── test_layout.py
├── test_primitives.py
└── test_solve.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # My rule
2 | prototype.py
3 |
4 | # Benchmarking files
5 | *.prof
6 | *.profile
7 | *.lprof
8 |
9 | # Figures / images / video
10 | *.png
11 | *.svg
12 |
13 | # Byte-compiled / optimized / DLL files
14 | __pycache__/
15 | *.py[cod]
16 | *$py.class
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | share/python-wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .nox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | *.py,cover
62 | .hypothesis/
63 | .pytest_cache/
64 | cover/
65 |
66 | # Translations
67 | *.mo
68 | *.pot
69 |
70 | # Django stuff:
71 | *.log
72 | local_settings.py
73 | db.sqlite3
74 | db.sqlite3-journal
75 |
76 | # Flask stuff:
77 | instance/
78 | .webassets-cache
79 |
80 | # Scrapy stuff:
81 | .scrapy
82 |
83 | # Sphinx documentation
84 | docs/_build/
85 |
86 | # PyBuilder
87 | .pybuilder/
88 | target/
89 |
90 | # Jupyter Notebook
91 | .ipynb_checkpoints
92 |
93 | # IPython
94 | profile_default/
95 | ipython_config.py
96 |
97 | # pyenv
98 | # For a library or package, you might want to ignore these files since the code is
99 | # intended to run in multiple environments; otherwise, check them in:
100 | # .python-version
101 |
102 | # pipenv
103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
106 | # install all needed dependencies.
107 | #Pipfile.lock
108 |
109 | # poetry
110 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
111 | # This is especially recommended for binary packages to ensure reproducibility, and is more
112 | # commonly ignored for libraries.
113 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
114 | #poetry.lock
115 |
116 | # pdm
117 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
118 | #pdm.lock
119 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
120 | # in version control.
121 | # https://pdm.fming.dev/#use-with-ide
122 | .pdm.toml
123 |
124 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
125 | __pypackages__/
126 |
127 | # Celery stuff
128 | celerybeat-schedule
129 | celerybeat.pid
130 |
131 | # SageMath parsed files
132 | *.sage.py
133 |
134 | # Environments
135 | .env
136 | .venv
137 | env/
138 | venv/
139 | ENV/
140 | env.bak/
141 | venv.bak/
142 |
143 | # Spyder project settings
144 | .spyderproject
145 | .spyproject
146 |
147 | # Rope project settings
148 | .ropeproject
149 |
150 | # mkdocs documentation
151 | /site
152 |
153 | # mypy
154 | .mypy_cache/
155 | .dmypy.json
156 | dmypy.json
157 |
158 | # Pyre type checker
159 | .pyre/
160 |
161 | # pytype static type analyzer
162 | .pytype/
163 |
164 | # Cython debug symbols
165 | cython_debug/
166 |
167 | # PyCharm
168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
170 | # and can be added to the global gitignore or merged into this file. For a more nuclear
171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
172 | #.idea/
173 |
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | Copyright 2023 Jonathan Deng
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MPLLayout
2 |
3 | ![MPLLayout logo](logo.svg)
4 |
5 | ## Summary
6 |
7 | MPLLayout is a package to create precise figure layouts in [matplotlib](https://matplotlib.org/).
8 | It works by modelling figure elements using geometric primitives (for example, text anchors are points, the figure is a quadrilateral, etc.), then constraining the sizes and positions of these elements using geometric constraints (for example, fixing the width of a figure, constraining axes sides to be collinear, constraining axes to lie on a grid, etc.).
9 |
10 | Using this approach, MPLLayout can:
11 |
12 | * align figure elements (axes, text label location, x and y axis, etc.),
13 | * specify margins around axes,
14 | * create templates for figures across different mediums (posters, manuscripts, slides, etc.),
15 | * and more!
16 |
17 | ## Basic usage
18 |
19 | The basic workflow to create a constrained layout is demonstrated through this example with a single axes figure.
20 | To create a figure 5 in wide and 4 in tall, use the code:
21 |
22 | ```python
23 | import numpy as np
24 | from matplotlib import pyplot as plt
25 |
26 | from mpllayout import layout as ly
27 | from mpllayout import primitives as pr
28 | from mpllayout import constraints as cr
29 | from mpllayout import matplotlibutils as mputils
30 | from mpllayout import solver
31 |
32 | # The layout object stores the geometry and constraints defining the layout
33 | layout = ly.Layout()
34 |
35 | # To add geometry, pass the geometry primitive and a string key.
36 | # Naming the `Quadrilateral` "Figure" causes mpllayout to identify it as the figure
37 | layout.add_prim(pr.Quadrilateral(), "Figure")
38 |
39 | # To create a constraint, pass the constraint, the keys of the geometry to constrain,
40 | # and the constraint arguments.
41 | # Constraint documentation describes what kind of geometry can be constrained and
42 | # any constraint arguments.
43 | layout.add_constraint(cr.Box(), ("Figure",), ())
44 | layout.add_constraint(cr.Width(), ("Figure",), (5.0,))
45 | layout.add_constraint(cr.Height(), ("Figure",), (4.0,))
46 | ```
47 |
48 | To create an axes inside the figure with desired margins around the edges, use the code:
49 |
50 | ```python
51 | # To add an axes, pass the `Axes` primitive
52 | # The `Axes` is a container of quadrilaterals representing the drawing area (frame),
53 | # as well as, optionally, the x-axis and y-axis
54 | layout.add_prim(pr.Axes(), "MyAxes")
55 |
56 | # Constrain the axes drawing area to a box
57 | layout.add_constraint(cr.Box(), ("MyAxes/Frame",), ())
58 | # Set "inner" margins around the outside of the axes frame to the figure
59 | # The inner margin is the distance from a `Quadrilateral` inside another
60 | # `Quadrilateral`
61 | layout.add_constraint(cr.InnerMargin(side="bottom"), ("MyAxes/Frame", "Figure"), (.5,))
62 | layout.add_constraint(cr.InnerMargin(side="top"), ("MyAxes/Frame", "Figure"), (.5,))
63 | layout.add_constraint(cr.InnerMargin(side="left"), ("MyAxes/Frame", "Figure"), (2.0,))
64 | layout.add_constraint(cr.InnerMargin(side="right"), ("MyAxes/Frame", "Figure"), (0.5,))
65 | ```
66 |
67 | The desired layout can then be solved and used to create a plot.
68 |
69 | ```python
70 | # Solve the constrained layout, then create the figure and any axes from the solved geometry
71 | solved_prims, *_ = solver.solve(layout)
72 | fig, axs = mputils.subplots(solved_prims)
73 | x = np.linspace(0, 2*np.pi)
74 | axs["MyAxes"].plot(np.sin(x))
75 | ```
76 |
77 | The code above results in the figure:
78 | ![Figure produced by the README example](doc/READMEExample.svg)
79 |
80 | While this approach requires more code to specify the layout of figure elements, it allows precise specification of the layout.
81 | This can be useful, for example, in publication documents where precise figure, axes, and font sizes are desired.
82 | In addition, layouts can be adjusted (for example, by adjusting margin arguments) and serve as a template to generate multiple figures.
83 |
84 | More examples can be found in the `examples` folder which demonstrate other constraints and geometric primitives used to achieve more complicated layouts.
85 | The example given above can be found at `examples/README_example.ipynb`.
86 | The tutorial notebook in `examples/tutorial.ipynb` demonstrates the basic usage of the package and explains some of the commonly used geometric constraints.
87 | Other examples are also given in the `examples` folder.
88 | The notebook at `examples/ten_simple_rules_demo.ipynb` contains an interactive demo to recreate a figure from ["Ten Simple Rules For Better Figures"](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1003833) (Rougier, Droettboom and Bourne 2014).
89 |
90 | ## Motivation
91 |
92 | Matplotlib contains several strategies for creating figure layouts (for example, `GridSpec` and `subplots` for grid-based layouts).
93 | While these approaches work well, greater control over figure element positions is sometimes desirable;
94 | for example, when preparing figures for published documents, research papers, or slides.
95 |
96 | ## Installation
97 |
98 | You can install the package from PyPI using
99 |
100 | ```bash
101 | pip install matplotlib-layout
102 | ```
103 |
104 | Alternatively, clone the repository to a local directory.
105 | Navigate to the project directory and run
106 |
107 | ```bash
108 | pip install .
109 | ```
110 |
111 | The package requires `numpy`, `matplotlib`, and `jax`.
112 |
113 | ## Contributing
114 |
115 | This project is a work in progress so there are likely bugs and missing features.
116 | If you would like to contribute a bug fix, a feature, a refactor, etc., thank you!
117 | All contributions are welcome.
118 |
119 | ## Similar Projects
120 |
121 | A similar project with a geometric constraint solver is [`pygeosolve`](https://github.com/SeanDS/pygeosolve).
122 | There is also another project prototype for a constraint-based layout engine for `matplotlib` [`MplLayouter`](https://github.com/Tillsten/MplLayouter), although it doesn't seem active as of 2023.
123 |
--------------------------------------------------------------------------------
/doc/READMEExample.svg:
--------------------------------------------------------------------------------
1 |
2 |
4 |
589 |
--------------------------------------------------------------------------------
/examples/README_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "302445fb",
6 | "metadata": {},
7 | "source": [
8 | "# Basic Example\n",
9 | "\n",
10 | "This notebook demonstrates the basic example shown in the readme"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "id": "ffeacd16",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "# Ensure that the notebook shows whitespace in the plot\n",
21 | "%config InlineBackend.print_figure_kwargs = {'bbox_inches': None}"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "id": "e8888d26",
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "import numpy as np\n",
32 | "from matplotlib import pyplot as plt\n",
33 | "\n",
34 | "from mpllayout import layout as ly\n",
35 | "from mpllayout import primitives as pr\n",
36 | "from mpllayout import constraints as cr\n",
37 | "from mpllayout import matplotlibutils as mputils\n",
38 | "from mpllayout import solver\n",
39 | "\n",
40 | "# The layout object stores the geometry and constraints defining the layout\n",
41 | "layout = ly.Layout()\n",
42 | "\n",
43 | "# To add geometry, pass the geometry primitive and a string key.\n",
44 | "# Naming the `Quadrilateral` \"Figure\" causes mpllayout to identify it as the figure\n",
45 | "layout.add_prim(pr.Quadrilateral(), \"Figure\")\n",
46 | "\n",
47 | "# To create a constraint, pass the constraint, the keys of the geometry to constrain,\n",
48 | "# and the constraint arguments.\n",
49 | "# Constraint documentation describes what kind of geometry can be constrained and\n",
50 | "# any constraint arguments.\n",
51 | "layout.add_constraint(cr.Box(), (\"Figure\",), ())\n",
52 | "layout.add_constraint(cr.Width(), (\"Figure\",), (5.0,))\n",
53 | "layout.add_constraint(cr.Height(), (\"Figure\",), (4.0,))\n",
54 | "\n",
55 | "# To add an axes, pass the `Axes` primitive\n",
56 | "# The `Axes` is a container of quadrilaterals representing the drawing area (frame),\n",
57 | "# as well as, optionally, the x-axis and y-axis\n",
58 | "layout.add_prim(pr.Axes(), \"MyAxes\")\n",
59 | "\n",
60 | "# Constrain the axes drawing area to a box\n",
61 | "layout.add_constraint(cr.Box(), (\"MyAxes/Frame\",), ())\n",
62 | "# Set \"inner\" margins around the outside of the axes frame to the figure\n",
63 | "# The inner margin is the distance from a `Quadrilateral` inside another\n",
64 | "# `Quadrilateral`\n",
65 | "layout.add_constraint(cr.InnerMargin(side=\"bottom\"), (\"MyAxes/Frame\", \"Figure\"), (.5,))\n",
66 | "layout.add_constraint(cr.InnerMargin(side=\"top\"), (\"MyAxes/Frame\", \"Figure\"), (.5,))\n",
67 | "layout.add_constraint(cr.InnerMargin(side=\"left\"), (\"MyAxes/Frame\", \"Figure\"), (2.0,))\n",
68 | "layout.add_constraint(cr.InnerMargin(side=\"right\"), (\"MyAxes/Frame\", \"Figure\"), (0.5,))\n",
69 | "\n",
70 | "# Solve the constrained layout for geometry that satisfies the constraints\n",
71 | "solved_prims, *_ = solver.solve(layout)"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": null,
77 | "id": "7166f72e",
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "# Create the figure and any axes from the solved geometry\n",
82 | "fig, axs = mputils.subplots(solved_prims)\n",
83 | "\n",
84 | "x = np.linspace(0, 2*np.pi)\n",
85 | "axs[\"MyAxes\"].plot(np.sin(x))\n",
86 | "axs[\"MyAxes\"].set_xlabel(\"x\")\n",
87 | "axs[\"MyAxes\"].set_ylabel(\"y\")\n",
88 | "\n",
89 | "fig.savefig(\"READMEExample.svg\")\n"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "id": "db55b8dd",
96 | "metadata": {},
97 | "outputs": [],
98 | "source": []
99 | }
100 | ],
101 | "metadata": {
102 | "kernelspec": {
103 | "display_name": "numerics",
104 | "language": "python",
105 | "name": "python3"
106 | },
107 | "language_info": {
108 | "codemirror_mode": {
109 | "name": "ipython",
110 | "version": 3
111 | },
112 | "file_extension": ".py",
113 | "mimetype": "text/x-python",
114 | "name": "python",
115 | "nbconvert_exporter": "python",
116 | "pygments_lexer": "ipython3",
117 | "version": "3.12.7"
118 | }
119 | },
120 | "nbformat": 4,
121 | "nbformat_minor": 5
122 | }
123 |
--------------------------------------------------------------------------------
/examples/contour_colorbar.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Contour plot with colorbar example\n"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import numpy as np\n",
17 | "\n",
18 | "from mpllayout import primitives as pr\n",
19 | "from mpllayout import constraints as cr\n",
20 | "from mpllayout import ui\n",
21 | "from mpllayout.solver import solve\n",
22 | "from mpllayout.layout import Layout\n",
23 | "\n",
24 | "from mpllayout.matplotlibutils import subplots"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "# Use this so matplotlib figures are shown with whitespace\n",
34 | "%matplotlib inline\n",
35 | "%config InlineBackend.print_figure_kwargs = {'bbox_inches': None}"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {},
41 | "source": [
42 | "## Specify the layout"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "## Create the figure and two axes (one for the contour plot and one for the colorbar)\n",
52 | "\n",
53 | "### Primitives\n",
54 | "# Create a `Layout` to track all geometric primitives and constraints\n",
55 | "layout = Layout()\n",
56 | "\n",
57 | "# Create the main geometric primitives\n",
58 | "layout.add_prim(pr.Quadrilateral(), \"Figure\")\n",
59 | "layout.add_prim(pr.Axes(), \"AxesContour\")\n",
60 | "layout.add_prim(pr.Axes(), \"AxesColorbar\")\n",
61 | "\n",
62 | "### Constraints\n",
63 | "\n",
64 | "# Make all the quadrilaterals rectangular\n",
65 | "layout.add_constraint(cr.Box(), (\"Figure\",), ())\n",
66 | "layout.add_constraint(cr.Box(), (\"AxesContour/Frame\",), ())\n",
67 | "layout.add_constraint(cr.Box(), (\"AxesColorbar/Frame\",), ())\n",
68 | "\n",
69 | "## Figure size\n",
70 | "# Set the figure width\n",
71 | "layout.add_constraint(cr.Length(), (\"Figure/Line0\",), (5,))\n",
72 | "\n",
73 | "# Fix the figure's lower left corner to the origin\n",
74 | "layout.add_constraint(cr.Fix(), (\"Figure/Line0/Point0\",), (np.array([0, 0]),))\n",
75 | "\n",
76 | "## Axes sizes\n",
77 | "\n",
78 | "# Set the aspect ratio to match physical dimensions\n",
79 | "layout.add_constraint(\n",
80 | " cr.RelativeLength(), (\"AxesContour/Frame/Line0\", \"AxesContour/Frame/Line1\"), (2,)\n",
81 | ")\n",
82 | "\n",
83 | "# Set the color bar height to 1/8 inch\n",
84 | "layout.add_constraint(cr.Length(), (\"AxesColorbar/Frame/Line1\",), (1 / 8,))\n",
85 | "\n",
86 | "## Align the color bar with the contour plot\n",
87 | "# Make the left/right sides collinear\n",
88 | "layout.add_constraint(\n",
89 | " cr.Collinear(), (\"AxesContour/Frame/Line3\", \"AxesColorbar/Frame/Line3\"), ()\n",
90 | ")\n",
91 | "layout.add_constraint(\n",
92 | " cr.Collinear(), (\"AxesContour/Frame/Line1\", \"AxesColorbar/Frame/Line1\"), ()\n",
93 | ")\n",
94 | "\n",
95 | "## Margins\n",
96 | "# Place the color bar 1/16 inch above the contour plot\n",
97 | "layout.add_constraint(\n",
98 | " cr.MidpointYDistance(),\n",
99 | " (\"AxesContour/Frame/Line2\", \"AxesColorbar/Frame/Line0\"),\n",
100 | " (1 / 16,)\n",
101 | ")\n",
102 | "\n",
103 | "# Set the left margin to 6/8 inch from the contour plot\n",
104 | "layout.add_constraint(\n",
105 | " cr.MidpointXDistance(), (\"Figure/Line3\", \"AxesContour/Frame/Line3\"), (6 / 8,)\n",
106 | ")\n",
107 | "\n",
108 | "# Set the right margin 1/8 inch from the contour plot\n",
109 | "layout.add_constraint(\n",
110 | " cr.MidpointXDistance(), (\"AxesContour/Frame/Line1\", \"Figure/Line1\"), (1 / 8,)\n",
111 | ")\n",
112 | "\n",
113 | "# Set the top margin to 1/2 inch\n",
114 | "layout.add_constraint(\n",
115 | " cr.MidpointYDistance(), (\"AxesColorbar/Frame/Line2\", \"Figure/Line2\"), (1 / 2,)\n",
116 | ")\n",
117 | "\n",
118 | "# Set the bottom margin to 6/8 inch\n",
119 | "layout.add_constraint(\n",
120 | " cr.MidpointYDistance(), (\"Figure/Line0\", \"AxesContour/Frame/Line0\"), (6 / 8,)\n",
121 | ")"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "metadata": {},
127 | "source": [
128 | "## Solve the layout"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "metadata": {},
135 | "outputs": [],
136 | "source": [
137 | "## Solve the constrained layout\n",
138 | "\n",
139 | "root_prim_n, solve_info = solve(layout)\n",
140 | "\n",
141 | "# Visualize the layout\n",
142 | "fig_layout, ax_layout = ui.figure_prims(root_prim_n)\n",
143 | "\n",
144 | "fig_layout"
145 | ]
146 | },
147 | {
148 | "cell_type": "markdown",
149 | "metadata": {},
150 | "source": [
151 | "## Plot a contour plot and colorbar using the layout"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": null,
157 | "metadata": {},
158 | "outputs": [],
159 | "source": [
160 | "# This creates the figure and axes based on the layout\n",
161 | "fig, axs = subplots(root_prim_n)\n",
162 | "# The sizes of `fig` and axes in `axs` will reflect the constraints\n",
163 | "# `axs['AxesColorbar']` is the colorbar and `axs['AxesContour']` is the contour\n",
164 | "\n",
165 | "\n",
166 | "print(fig.get_size_inches())\n",
167 | "\n",
168 | "x = np.linspace(0, 10, 51)\n",
169 | "y = np.linspace(0, 5, 26)\n",
170 | "xx, yy = np.meshgrid(x, y)\n",
171 | "z = (xx - 5) ** 2 + (yy - 2.5) ** 2\n",
172 | "\n",
173 | "cset = axs[\"AxesContour\"].contourf(x, y, z)\n",
174 | "\n",
175 | "axs[\"AxesContour\"].set_xlabel(\"x [cm]\")\n",
176 | "axs[\"AxesContour\"].set_ylabel(\"y [cm]\")\n",
177 | "\n",
178 | "fig.colorbar(cset, cax=axs[\"AxesColorbar\"], orientation=\"horizontal\")\n",
179 | "axs[\"AxesColorbar\"].xaxis.set_label_text(\"z [cm]\")\n",
180 | "axs[\"AxesColorbar\"].xaxis.set_tick_params(\n",
181 | " top=True, labeltop=True, bottom=False, labelbottom=False\n",
182 | ")\n",
183 | "axs[\"AxesColorbar\"].xaxis.set_label_position(position=\"top\")\n",
184 | "\n",
185 | "fig.savefig(\"contour_colorbar.png\")\n",
186 | "\n",
187 | "fig"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": null,
193 | "metadata": {},
194 | "outputs": [],
195 | "source": []
196 | }
197 | ],
198 | "metadata": {
199 | "kernelspec": {
200 | "display_name": "numerics",
201 | "language": "python",
202 | "name": "python3"
203 | },
204 | "language_info": {
205 | "codemirror_mode": {
206 | "name": "ipython",
207 | "version": 3
208 | },
209 | "file_extension": ".py",
210 | "mimetype": "text/x-python",
211 | "name": "python",
212 | "nbconvert_exporter": "python",
213 | "pygments_lexer": "ipython3",
214 | "version": "3.12.7"
215 | }
216 | },
217 | "nbformat": 4,
218 | "nbformat_minor": 2
219 | }
220 |
--------------------------------------------------------------------------------
/examples/grid_axes.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Grid of fixed aspect axes example\n",
8 | "\n",
9 | "This example illustrates how to create a grid of axes with fixed-aspect ratios and margins around the figure.\n",
10 | "The width of the figure is fixed (for example, the width could be constrained by the column width in a journal) while the height automatically adjusts from the given constraints and axes width.\n",
11 | "This is difficult to accomplish directly in `matplotlib` without trial-and-error because the axes grid height isn't known until the figure is plotted."
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "import itertools\n",
21 | "\n",
22 | "import numpy as np\n",
23 | "\n",
24 | "from mpllayout import solver\n",
25 | "from mpllayout import primitives as pr\n",
26 | "from mpllayout import constraints as co\n",
27 | "from mpllayout import layout as lay\n",
28 | "from mpllayout import matplotlibutils as lplt\n",
29 | "from mpllayout import ui"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {},
35 | "source": [
36 | "## Specify the layout"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "layout = lay.Layout()\n",
46 | "\n",
47 | "## Create the figure box\n",
48 | "layout.add_prim(pr.Quadrilateral(), \"Figure\")\n",
49 | "layout.add_constraint(co.Box(), (\"Figure\",), ())\n",
50 | "\n",
51 | "## Constrain the figure width\n",
52 | "# Note that the figure height isn't directly constrained because other\n",
53 | "# constraints (margins, the axes aspect ratio, etc.) implicitly determine what\n",
54 | "# the figure height should be.\n",
55 | "fig_width = 6\n",
56 | "layout.add_constraint(co.Width(), (\"Figure\",), (fig_width,))\n",
57 | "\n",
58 | "# Fix the figure bottom left corner to (0, 0)\n",
59 | "\n",
60 | "layout.add_constraint(co.Fix(), (\"Figure/Line0/Point0\",), (np.array([0, 0]),))\n",
61 | "\n",
62 | "## Create the axes\n",
63 | "\n",
64 | "# You can change the number of rows and columns in the axes here.\n",
65 | "num_row, num_col = (3, 4)\n",
66 | "axes_shape = (num_row, num_col)\n",
67 | "num_axes = int(np.prod(axes_shape))\n",
68 | "\n",
69 | "axes_keys = [\n",
70 | " f\"Axes[{row}, {col}]\"\n",
71 | " for row, col in itertools.product(range(num_row), range(num_col))\n",
72 | "]\n",
73 | "for axes_key in axes_keys:\n",
74 | " layout.add_prim(pr.Axes(), axes_key)\n",
75 | " layout.add_constraint(co.Box(), (f\"{axes_key}/Frame\",), ())\n",
76 | "\n",
77 | "# First constrain the top-left corner axes aspect ratio to be square.\n",
78 | "# Note this assumes the [0, 0] element is the top-left corner but other conventions\n",
79 | "# are possible.\n",
80 | "layout.add_constraint(co.AspectRatio(), (\"Axes[0, 0]/Frame\",), (1,))\n",
81 | "\n",
82 | "## Constrain the axes in a grid\n",
83 | "# You can either use a `Grid` constraint or manually apply the constraint grid\n",
84 | "\n",
85 | "margin_inner = 0.1\n",
86 | "\n",
87 | "# Approach using `co.Grid`:\n",
88 | "col_widths = np.ones(num_col-1)\n",
89 | "row_heights = np.ones(num_row-1)\n",
90 | "col_margins = margin_inner * np.ones(num_col-1)\n",
91 | "row_margins = margin_inner * np.ones(num_row-1)\n",
92 | "\n",
93 | "layout.add_constraint(\n",
94 | " co.Grid(shape=axes_shape),\n",
95 | " [f\"{axes_key}/Frame\" for axes_key in axes_keys],\n",
96 | " (col_widths, row_heights, col_margins, row_margins)\n",
97 | ")\n",
98 | "\n",
99 | "# Manual approach:\n",
100 | "\n",
101 | "# Constrain them on a rectilinear grid\n",
102 | "# (the x/y grid lines are still free to move left-right and up-down)\n",
103 | "# layout.add_constraint(\n",
104 | "# co.RectilinearGrid(shape=axes_shape),\n",
105 | "# [f\"{axes_key}/Frame\" for axes_key in axes_keys],\n",
106 | "# ()\n",
107 | "# )\n",
108 | "\n",
109 | "# # Because the axes lie on a grid, you only need to set widths/horizontal\n",
110 | "# # margins for a single row of axes, then heights/vertical margins for a single\n",
111 | "# # column of axes.\n",
112 | "\n",
113 | "# for col, width, margin in zip(range(1, num_col), col_widths, col_margins):\n",
114 | "# # Set equal widths for row 0\n",
115 | "# layout.add_constraint(\n",
116 | "# co.RelativeLength(),\n",
117 | "# (f\"Axes[0, {col}]/Frame/Line0\", \"Axes[0, 0]/Frame/Line0\"),\n",
118 | "# (width,)\n",
119 | "# )\n",
120 | "# # Set interior horizontal margin\n",
121 | "# layout.add_constraint(\n",
122 | "# co.OuterMargin(side='right'),\n",
123 | "# (f\"Axes[0, {col-1}]/Frame\", f\"Axes[0, {col}]/Frame\"),\n",
124 | "# (margin,)\n",
125 | "# )\n",
126 | "# for row, height, margin in zip(range(1, num_row), row_heights, row_margins):\n",
127 | "# # Set equal heights for column 0\n",
128 | "# layout.add_constraint(\n",
129 | "# co.RelativeLength(),\n",
130 | "# (f\"Axes[{row}, 0]/Frame/Line1\", \"Axes[0, 0]/Frame/Line1\"),\n",
131 | "# (height,)\n",
132 | "# )\n",
133 | "# # Set interior vertical margin\n",
134 | "# layout.add_constraint(\n",
135 | "# co.OuterMargin(side='bottom'),\n",
136 | "# (f\"Axes[{row-1}, 0]/Frame\", f\"Axes[{row}, 0]/Frame\"),\n",
137 | "# (margin,)\n",
138 | "# )\n",
139 | "\n",
140 | "## Constrain margins around the figure\n",
141 | "\n",
142 | "# Constrain top/bottom margins\n",
143 | "margin_top = 0.2\n",
144 | "margin_bottom = 0.2\n",
145 | "layout.add_constraint(\n",
146 | " co.InnerMargin(side='top'), (\"Axes[0, 0]/Frame\", \"Figure\"), (margin_top,)\n",
147 | ")\n",
148 | "layout.add_constraint(\n",
149 | " co.InnerMargin(side='bottom'), (f\"Axes[{num_row-1}, 0]/Frame\", \"Figure\"), (margin_bottom, )\n",
150 | ")\n",
151 | "\n",
152 | "# Constrain left/right margins\n",
153 | "margin_left = 0.2\n",
154 | "margin_right = 0.2\n",
155 | "layout.add_constraint(\n",
156 | " co.InnerMargin(side='left'),\n",
157 | " (\"Axes[0, 0]/Frame\", \"Figure\"), (margin_left,)\n",
158 | ")\n",
159 | "layout.add_constraint(\n",
160 | " co.InnerMargin(side='right'),\n",
161 | " (f\"Axes[0, {num_col-1}]/Frame\", \"Figure\"), (margin_right,)\n",
162 | ")"
163 | ]
164 | },
165 | {
166 | "cell_type": "markdown",
167 | "metadata": {},
168 | "source": [
169 | "## Solve the layout"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": null,
175 | "metadata": {},
176 | "outputs": [],
177 | "source": [
178 | "## Solve the constraints and form the figure/axes layout\n",
179 | "prim_tree_n, info = solver.solve(layout)\n",
180 | "print(info)\n",
181 | "\n",
182 | "# This is the solved layout:\n",
183 | "_fig, _ = ui.figure_prims(prim_tree_n)\n",
184 | "_fig.savefig(\"grid_axes_layout.png\", dpi=300)"
185 | ]
186 | },
187 | {
188 | "cell_type": "markdown",
189 | "metadata": {},
190 | "source": [
191 | "## Plot a grid of images using the layout"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": null,
197 | "metadata": {},
198 | "outputs": [],
199 | "source": [
200 | "# Use the layout to plot randomly generated 10x10 pixel images\n",
201 | "fig, axs = lplt.subplots(prim_tree_n)\n",
202 | "\n",
203 | "for key, ax in axs.items():\n",
204 | " ax.set_axis_off()\n",
205 | "\n",
206 | " ax.imshow(np.random.rand(10, 10))\n",
207 | "\n",
208 | "# x = np.linspace(0, 1)\n",
209 | "# axs['Axes1'].plot(x, x**2)\n",
210 | "\n",
211 | "fig.savefig(\"grid_axes.png\")\n",
212 | "# fig"
213 | ]
214 | }
215 | ],
216 | "metadata": {
217 | "kernelspec": {
218 | "display_name": "numerics",
219 | "language": "python",
220 | "name": "python3"
221 | },
222 | "language_info": {
223 | "codemirror_mode": {
224 | "name": "ipython",
225 | "version": 3
226 | },
227 | "file_extension": ".py",
228 | "mimetype": "text/x-python",
229 | "name": "python",
230 | "nbconvert_exporter": "python",
231 | "pygments_lexer": "ipython3",
232 | "version": "3.12.7"
233 | }
234 | },
235 | "nbformat": 4,
236 | "nbformat_minor": 2
237 | }
238 |
--------------------------------------------------------------------------------
/examples/one_axes.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# One axes figure example\n",
8 | "\n",
9 | "This example demonstrates how to create a one axes figure with fixed margins around the axes using both `mpllayout` and pure `matplotlib`."
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import numpy as np\n",
19 | "from matplotlib import pyplot as plt\n",
20 | "\n",
21 | "from mpllayout import primitives as pr\n",
22 | "from mpllayout import constraints as co\n",
23 | "from mpllayout import layout as lay\n",
24 | "from mpllayout import solver\n",
25 | "from mpllayout import matplotlibutils as lplt\n"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "## Specify and solve a layout using `mpllayout`"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "layout = lay.Layout()\n",
42 | "\n",
43 | "## Create the figure and axes\n",
44 | "layout.add_prim(pr.Quadrilateral(), \"Figure\")\n",
45 | "layout.add_prim(pr.Axes(), \"Axes\")\n",
46 | "\n",
47 | "## Constrain the figure and axes quadrilaterals to be rectangular\n",
48 | "layout.add_constraint(co.Box(), (\"Figure\",), ())\n",
49 | "layout.add_constraint(co.Box(), (\"Axes/Frame\",), ())\n",
50 | "\n",
51 | "## Constrain the figure size\n",
52 | "fig_width, fig_height = 6, 3\n",
53 | "layout.add_constraint(co.XLength(), (\"Figure/Line0\",), (fig_width,))\n",
54 | "layout.add_constraint(co.YLength(), (\"Figure/Line1\",), (fig_height,))\n",
55 | "\n",
56 | "## Constrain 'Axes' margins\n",
57 | "# Constrain left/right margins\n",
58 | "margin_left = 1.1\n",
59 | "margin_right = 1.1\n",
60 | "layout.add_constraint(\n",
61 | " co.InnerMargin(side='left'), (\"Axes/Frame\", \"Figure\"), (margin_left,)\n",
62 | ")\n",
63 | "layout.add_constraint(\n",
64 | " co.InnerMargin(side='right'), (\"Axes/Frame\", \"Figure\"), (margin_right,)\n",
65 | ")\n",
66 | "\n",
67 | "# Constrain top/bottom margins\n",
68 | "margin_top = 1.1\n",
69 | "margin_bottom = 0.5\n",
70 | "layout.add_constraint(\n",
71 | " co.InnerMargin(side='bottom'), (\"Axes/Frame\", \"Figure\"), (margin_bottom,)\n",
72 | ")\n",
73 | "layout.add_constraint(\n",
74 | " co.InnerMargin(side='top'), (\"Axes/Frame\", \"Figure\"), (margin_top,)\n",
75 | ")\n",
76 | "\n",
77 | "## Solve the constraints\n",
78 | "prim_tree_n, info = solver.solve(layout)"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "# Use the solved layout to form the figure and axes objects, then plot\n",
88 | "fig, axs = lplt.subplots(prim_tree_n)\n",
89 | "\n",
90 | "ax = axs[\"Axes\"]\n",
91 | "\n",
92 | "x = np.linspace(0, 1)\n",
93 | "ax.plot(x, x**2)\n",
94 | "\n",
95 | "ax.set_xlabel(\"x\")\n",
96 | "ax.set_ylabel(\"y\")"
97 | ]
98 | },
99 | {
100 | "cell_type": "markdown",
101 | "metadata": {},
102 | "source": [
103 | "## Specify and 'solve' a layout using pure `matplotlib`\n",
104 | "\n",
105 | "There are multiple approaches to create the above layout using pure `matplotlib`; however, some may not achieve the precisely specified margins.\n",
106 | "To create the layout using pure `matplotlib`, the axes position that satisfies the desired margins and figure size must be determined.\n",
107 | "This essentially involves manually \"solving\" the system of constraints.\n",
108 | "For the simple case of fixed margins around a single axes, the solution is given by the code below.\n",
109 | "This is the basic process that `mpllayout` performs but for general constraints, which allows for more complicated layouts."
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "## This code determines the axes sizes needed to achieve the desired margins\n",
119 | "\n",
120 | "# Specify the desired figure size and margins\n",
121 | "fig_width, fig_height = 6, 3\n",
122 | "\n",
123 | "margin_left = 1.1\n",
124 | "margin_right = 1.1\n",
125 | "\n",
126 | "margin_top = 1.1\n",
127 | "margin_bottom = 0.5\n",
128 | "\n",
129 | "# The code below \"solves\" the axes position needed to achieve the given margins.\n",
130 | "# This approach maintains the same margins even if the figure\n",
131 | "# dimensions are changed.\n",
132 | "# This is essentially what `mpllayout` does.\n",
133 | "coord_botleft = np.array((margin_left, margin_bottom))\n",
134 | "coord_topright = np.array((fig_width - margin_right, fig_height-margin_top))\n",
135 | "\n",
136 | "# Scale the coordinates relative to the figure size since this is how `matplotlib`\n",
137 | "# interprets the axes position\n",
138 | "coord_botleft = coord_botleft / (fig_width, fig_height)\n",
139 | "coord_topright = coord_topright / (fig_width, fig_height)\n",
140 | "\n",
141 | "# Determine the axes width and height\n",
142 | "axes_width, axes_height = (coord_topright - coord_botleft)"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": null,
148 | "metadata": {},
149 | "outputs": [],
150 | "source": [
151 | "# Use the \"solved\" axes position and figure size to create the figure\n",
152 | "fig = plt.figure(figsize=(fig_width, fig_height))\n",
153 | "ax = fig.add_axes((*coord_botleft, axes_width, axes_height))\n",
154 | "\n",
155 | "x = np.linspace(0, 1)\n",
156 | "ax.plot(x, x**2)\n",
157 | "\n",
158 | "ax.set_xlabel(\"x\")\n",
159 | "ax.set_ylabel(\"y\")\n"
160 | ]
161 | }
162 | ],
163 | "metadata": {
164 | "kernelspec": {
165 | "display_name": "numerics",
166 | "language": "python",
167 | "name": "python3"
168 | },
169 | "language_info": {
170 | "codemirror_mode": {
171 | "name": "ipython",
172 | "version": 3
173 | },
174 | "file_extension": ".py",
175 | "mimetype": "text/x-python",
176 | "name": "python",
177 | "nbconvert_exporter": "python",
178 | "pygments_lexer": "ipython3",
179 | "version": "3.12.7"
180 | }
181 | },
182 | "nbformat": 4,
183 | "nbformat_minor": 2
184 | }
185 |
--------------------------------------------------------------------------------
/examples/profile_solver.py:
--------------------------------------------------------------------------------
1 | """
2 | Profile runtimes of key functions
3 | """
4 |
5 | import typing as tp
6 | from typing import Optional
7 |
8 | import cProfile
9 | import pstats
10 |
11 | import numpy as np
12 |
13 | from mpllayout import layout as lay
14 | from mpllayout import primitives as pr
15 | from mpllayout import constraints as co
16 | from mpllayout import containers as cn
17 | from mpllayout import solver
18 |
19 |
def gen_layout(axes_shape: Optional[tuple[int, int]] = (3, 3)) -> lay.Layout:
    """
    Return a layout of a figure containing a grid of axes

    This layout is used as a benchmark problem for profiling the solver.
    It constrains the figure width and position, arranges the axes on a grid,
    fixes the aspect ratio of the first axes and sets margins between the
    axes grid and the figure.

    Parameters
    ----------
    axes_shape: Optional[tuple[int, int]]
        The number of rows and columns of axes in the grid.
        If ``None``, the default ``(3, 3)`` grid is used.

    Returns
    -------
    lay.Layout
        The layout containing the figure/axes primitives and constraints
    """
    # The annotation allows `None`; previously that crashed when unpacking
    if axes_shape is None:
        axes_shape = (3, 3)
    num_row, num_col = axes_shape
    num_axes = num_row * num_col

    layout = lay.Layout()

    ## Create an origin point
    layout.add_prim(pr.Point([0, 0]), "Origin")
    layout.add_constraint(co.Fix(), ("Origin",), (np.array([0, 0]),))

    ## Create the figure box
    layout.add_prim(pr.Quadrilateral(), "Figure")
    layout.add_constraint(co.Box(), ("Figure",), ())

    ## Constrain the figure width and position
    # Only the width is fixed. The height is presumably determined by the grid,
    # the first axes' aspect ratio and the top/bottom margins, so also fixing
    # it could over-constrain the layout.
    fig_width = 6
    layout.add_constraint(co.Length(), ("Figure/Line0",), (fig_width,))
    layout.add_constraint(co.Coincident(), ("Figure/Line0/Point0", "Origin"), ())

    ## Create the axes boxes
    axes_keys = [f"Axes{n}" for n in range(num_axes)]
    for key in axes_keys:
        layout.add_prim(pr.Axes(), key)
        layout.add_constraint(co.Box(), (f"{key}/Frame",), ())

    ## Constrain the axes in a grid
    # Four per-column/per-row parameter lists; see `co.Grid` for their meaning
    grid_params = (
        (num_col - 1) * [1 / 16],
        (num_row - 1) * [1 / 16],
        (num_col - 1) * [1],
        (num_row - 1) * [1],
    )
    layout.add_constraint(
        co.Grid(axes_shape),
        tuple(f"{key}/Frame" for key in axes_keys),
        grid_params,
    )

    # Constrain the first axis aspect ratio
    layout.add_constraint(
        co.RelativeLength(), ("Axes0/Frame/Line0", "Axes0/Frame/Line1"), (2,)
    )

    # Constrain top/bottom margins
    # NOTE: assumes row-major grid ordering, i.e. `Axes0` is the top-left
    # axes and the last axes is the bottom-right one
    margin_top = 1.1
    margin_bottom = 0.5
    layout.add_constraint(
        co.DirectedDistance(),
        ("Axes0/Frame/Line1/Point1", "Figure/Line1/Point1"),
        (np.array([0, 1]), margin_top),
    )
    layout.add_constraint(
        co.DirectedDistance(),
        (f"{axes_keys[-1]}/Frame/Line1/Point0", "Figure/Line1/Point0"),
        (np.array([0, -1]), margin_bottom),
    )

    # Constrain left/right margins
    # `axes_keys[num_col - 1]` is the last axes in the first row
    margin_left = 0.2
    margin_right = 0.3
    layout.add_constraint(
        co.DirectedDistance(),
        ("Axes0/Frame/Line0/Point0", "Figure/Line0/Point0"),
        (np.array([-1, 0]), margin_left),
    )
    layout.add_constraint(
        co.DirectedDistance(),
        (f"{axes_keys[num_col - 1]}/Frame/Line1/Point1", "Figure/Line1/Point1"),
        (np.array([1, 0]), margin_right),
    )
    return layout
91 |
92 |
def profile_call(func, args, prof_path, num_stats=20):
    """
    Profile a call of ``func(*args)`` and print the most expensive entries

    ``func`` is first called once without profiling. This warm-up keeps
    one-time costs out of the profile (e.g. presumably JIT tracing when the
    called code uses `jax.jit` functions). A second call is then profiled.
    The raw profile is written to ``prof_path`` and the top ``num_stats``
    entries, sorted by cumulative time, are printed.

    Parameters
    ----------
    func: Callable
        The function to profile
    args: tuple
        Positional arguments passed to ``func``
    prof_path: str
        File path the raw profile data is dumped to
    num_stats: int
        Number of stats entries to print

    Returns
    -------
    pstats.Stats
        The statistics loaded from ``prof_path``
    """
    # Warm-up call (not profiled)
    func(*args)

    # Profile the call directly instead of exec-ing a string statement, which
    # would implicitly depend on names in the `__main__` namespace
    with cProfile.Profile() as profiler:
        func(*args)
    profiler.dump_stats(prof_path)

    stats = pstats.Stats(prof_path)
    stats.sort_stats("cumtime").print_stats(num_stats)
    return stats


if __name__ == "__main__":
    layout = gen_layout((12, 12))

    constraints, constraint_graph, constraint_params = layout.flat_constraints()

    # Profile the residual assembly with plain (non-jitted) constraints
    profile_call(
        solver.assem_constraint_residual,
        (layout.root_prim, constraints, constraint_graph, constraint_params),
        "profile_wo_jax.prof",
    )

    # `jax` is only needed for this comparison so it's imported here
    import jax

    # Profile the residual assembly with jitted constraints
    constraints_jit = [jax.jit(c) for c in constraints]
    profile_call(
        solver.assem_constraint_residual,
        (layout.root_prim, constraints_jit, constraint_graph, constraint_params),
        "profile_w_jax.prof",
    )
119 |
--------------------------------------------------------------------------------
/examples/ten_simple_rules_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "869c5e8b-6a10-48dd-855c-9aa540ec082a",
6 | "metadata": {},
7 | "source": [
8 | "# \"Ten Simple Rules for Better Figures\" example\n",
9 | "\n",
10 | "`mpllayout` models axes and other elements in figures as geometric primitives which can be constrained relative to each other. \n",
11 | "This gives a flexible way to precisely position figure elements.\n",
12 | "\n",
13 | "The following demo produces Figure 1 from the paper \"Ten Simple Rules for Better Figures\" (Rougier NP, Droettboom M, Bourne PE (2014) Ten Simple Rules for Better Figures. PLOS Computational Biology 10(9): e1003833. https://doi.org/10.1371/journal.pcbi.1003833).\n",
14 | "This figure is itself a remake of one originally published in the [New York Times](https://archive.nytimes.com/www.nytimes.com/imagepages/2007/07/29/health/29cancer.graph.web.html?action=click&module=RelatedCoverage&pgtype=Article&region=Footer).\n",
15 | "\n",
16 | "The sections below illustrate how to create the above figure with a grid-type constraint (`RectilinearGrid`) as well as more basic constraints."
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "id": "97be0bae-5121-4a0f-bbcd-aab3a5bea0bd",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "# Import the relevant packages\n",
27 | "\n",
28 | "import matplotlib.pyplot as plt\n",
29 | "import numpy as np\n",
30 | "\n",
31 | "from mpllayout.solver import solve\n",
32 | "from mpllayout.layout import Layout, update_layout_constraints\n",
33 | "from mpllayout import primitives as pr\n",
34 | "from mpllayout import constraints as co\n",
35 | "from mpllayout.matplotlibutils import subplots, update_subplots, find_axis_position\n",
36 | "from mpllayout import ui"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "id": "5c3a8af3",
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "def figure_layout(layout):\n",
47 | " \"\"\"\n",
48 | " Return a figure of the layout\n",
49 | " \"\"\"\n",
50 | " prims_n, solve_info = solve(layout)\n",
51 | " return ui.figure_prims(prims_n)"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "id": "a1fc5432",
57 | "metadata": {},
58 | "source": [
59 | "## Specify the `layout`\n",
60 | "\n",
61 | "The `layout` is a collection of:\n",
62 | "- geometric primitives to represent figure elements\n",
63 | "- and constraints which position the figure elements\n"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "id": "a171b105",
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "layout = Layout()"
74 | ]
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "id": "7c231206",
79 | "metadata": {},
80 | "source": [
81 | "### Add geometric primitives\n",
82 | "\n",
83 | "Represent the `mpl.Figure` by a `Quadrilateral` and `mpl.Axes` by an `Axes` primitive.\n",
84 | "The `Axes` primitive contains a `Quadrilateral` to represent the frame, and optionally, a `Quadrilateral` and `Point` to represent the x axis and y axis."
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "id": "08134366",
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "# First add a box called \"Figure\" to the layout\n",
95 | "layout.add_prim(pr.Quadrilateral(), \"Figure\")\n",
96 | "\n",
97 | "# Then add \"axes\" primitives to represent a left, middle, and right axes\n",
98 | "# Axes can contain `Quadrilaterals` and `Point` primitives to represent the\n",
99 | "# axes frame, x/y axis, and axis labels\n",
100 | "layout.add_prim(pr.Axes(xaxis=True, yaxis=True), \"AxesLeft\")\n",
101 | "layout.add_prim(pr.Axes(), \"AxesMid\")\n",
102 | "layout.add_prim(pr.Axes(xaxis=True, yaxis=True), \"AxesRight\")"
103 | ]
104 | },
105 | {
106 | "cell_type": "markdown",
107 | "id": "cc84cd74",
108 | "metadata": {},
109 | "source": [
110 | "### Add geometric constraints"
111 | ]
112 | },
113 | {
114 | "cell_type": "markdown",
115 | "id": "3de3df32",
116 | "metadata": {},
117 | "source": [
118 | "#### Make all `Quadrilateral`s rectangular\n",
119 | "\n",
120 | "`mpllayout` doesn't automatically constrain quadrilaterals to be rectangular (unlike the figure and axes frames in matplotlib), so they must be constrained explicitly."
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "id": "53d68f61",
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "# `co.Box` forces quadrilateral sides to be vertical and tops/bottoms to be horizontal\n",
131 | "# It has no parameters so the last argument is an empty tuple\n",
132 | "layout.add_constraint(co.Box(), (\"Figure\",), ())\n",
133 | "\n",
134 | "# \"AxesMid\" only has a frame (no x/y axis)\n",
135 | "layout.add_constraint(co.Box(), (\"AxesMid/Frame\",), ())\n",
136 | "\n",
137 | "# Here we constrain all child quads of the left and right axes to be boxes\n",
138 | "for axes_key in [\"AxesLeft\", \"AxesRight\"]:\n",
139 | " for quad_key in [\"Frame\", \"XAxis\", \"YAxis\"]:\n",
140 | " layout.add_constraint(co.Box(), (f\"{axes_key}/{quad_key}\",), ())"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": null,
146 | "id": "5262250b",
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "# This plots the created geometry\n",
151 | "# Note that by default all the quads are unit squares\n",
152 | "figure_layout(layout)"
153 | ]
154 | },
155 | {
156 | "cell_type": "markdown",
157 | "id": "6e2224b7",
158 | "metadata": {},
159 | "source": [
160 | "#### Fix the Figure dimensions and position\n",
161 | "\n",
162 | "Set the figure width/height and fix the bottom left point to the origin"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": null,
168 | "id": "3e9b9643",
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "## Set the figure dimensions\n",
173 | "\n",
174 | "# Fix the bottom left point of 'Figure' to the origin\n",
175 | "layout.add_constraint(co.Fix(), (\"Figure/Line0/Point0\",), (np.array([0, 0]),))\n",
176 | "\n",
177 | "# Set the 'Figure' width and height\n",
178 | "fig_width, fig_height = (12, 7)\n",
179 | "layout.add_constraint(co.Length(), (\"Figure/Line1\",), (fig_height,))\n",
180 | "layout.add_constraint(co.Length(), (\"Figure/Line0\",), (fig_width,))"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "id": "9d25658e",
187 | "metadata": {},
188 | "outputs": [],
189 | "source": [
190 | "# Note the figure quadrilateral is 12\" by 7\"\n",
191 | "# The remaining axes are unit squares since they haven't been constrained yet\n",
192 | "fig, ax = figure_layout(layout)"
193 | ]
194 | },
195 | {
196 | "cell_type": "markdown",
197 | "id": "56df289b",
198 | "metadata": {},
199 | "source": [
200 | "#### Constrain the `Axes` to a 1 by 3 rectilinear grid\n",
201 | "\n",
202 | "We can force the left, middle, and right axes to align on a 1 by 3 grid and set their relative widths.\n",
203 | "\n"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "id": "5e1eab5c",
210 | "metadata": {},
211 | "outputs": [],
212 | "source": [
213 | "# Align the axes on a 1x3 rectilinear grid\n",
214 | "shape = (1, 3)\n",
215 | "layout.add_constraint(\n",
216 | " co.RectilinearGrid(shape),\n",
217 | " (\"AxesLeft/Frame\", \"AxesMid/Frame\", \"AxesRight/Frame\"),\n",
218 | " ()\n",
219 | ")"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": null,
225 | "id": "6d1b513b",
226 | "metadata": {},
227 | "outputs": [],
228 | "source": [
229 | "# Set zero margins between the left/right axes and the middle axes\n",
230 | "layout.add_constraint(co.OuterMargin(side='left'), (\"AxesMid/Frame\", \"AxesLeft/Frame\"), (0,))\n",
231 | "layout.add_constraint(co.OuterMargin(side='right'), (\"AxesMid/Frame\", \"AxesRight/Frame\"), (0,))\n",
232 | "\n",
233 | "# Make the left/right axes the same width and the central axes 0.5 that width\n",
234 | "layout.add_constraint(\n",
235 | " co.RelativeLength(), (\"AxesRight/Frame/Line0\", \"AxesLeft/Frame/Line0\"), (1.0,)\n",
236 | ")\n",
237 | "layout.add_constraint(\n",
238 | " co.RelativeLength(), (\"AxesMid/Frame/Line0\", \"AxesLeft/Frame/Line0\"), (0.5,)\n",
239 | ")"
240 | ]
241 | },
242 | {
243 | "cell_type": "code",
244 | "execution_count": null,
245 | "id": "b007f45b",
246 | "metadata": {},
247 | "outputs": [],
248 | "source": [
249 | "# Note the 3 axes are now aligned\n",
250 | "# It's difficult to see because the left and right axes also have x/y axes that are shown\n",
251 | "fig, ax = figure_layout(layout)"
252 | ]
253 | },
254 | {
255 | "cell_type": "markdown",
256 | "id": "129f4a98",
257 | "metadata": {},
258 | "source": [
259 | "#### Position the x-axis and y-axis for left and right axes"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": null,
265 | "id": "b4730103",
266 | "metadata": {},
267 | "outputs": [],
268 | "source": [
269 | "# These constraints fix the x/y axis to one side of the axes\n",
270 | "# When creating the figure from a layout, these axis positions will be inherited\n",
271 | "layout.add_constraint(co.PositionXAxis(side='top'), (\"AxesLeft\", ), ())\n",
272 | "layout.add_constraint(co.PositionYAxis(side='right'), (\"AxesLeft\", ), ())\n",
273 | "\n",
274 | "layout.add_constraint(co.PositionXAxis(side='top'), (\"AxesRight\", ), ())\n",
275 | "layout.add_constraint(co.PositionYAxis(side='left'), (\"AxesRight\", ), ())\n"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": null,
281 | "id": "87f32652",
282 | "metadata": {},
283 | "outputs": [],
284 | "source": [
285 | "# These constraints set the variable width of the y-axis and variable height of the x-axis\n",
286 | "# The axis dimensions are variable since they depend on the size of any tick labels\n",
287 | "# Axis dimensions can be updated using `update_layout_constraints` after\n",
288 | "# axis text has been generated\n",
289 | "# Note that axis labels aren't included in the size of the axis!\n",
290 | "layout.add_constraint(\n",
291 | " co.XAxisThickness(), (f\"AxesLeft/XAxis\",), (None,)\n",
292 | ")\n",
293 | "layout.add_constraint(\n",
294 | " co.YAxisThickness(), (f\"AxesLeft/YAxis\",), (None,)\n",
295 | ")\n",
296 | "layout.add_constraint(\n",
297 | " co.XAxisThickness(), (f\"AxesRight/XAxis\",), (None,)\n",
298 | ")\n",
299 | "layout.add_constraint(\n",
300 | " co.YAxisThickness(), (f\"AxesRight/YAxis\",), (None,)\n",
301 | ")"
302 | ]
303 | },
304 | {
305 | "cell_type": "code",
306 | "execution_count": null,
307 | "id": "cfd1ded6",
308 | "metadata": {},
309 | "outputs": [],
310 | "source": [
311 | "# Note the x axis is now stuck to the top of each axes\n",
312 | "figure_layout(layout)"
313 | ]
314 | },
315 | {
316 | "cell_type": "markdown",
317 | "id": "32a9a09b",
318 | "metadata": {},
319 | "source": [
320 | "#### Set margins\n",
321 | "\n",
322 | "Note that earlier we never set the absolute width of the axes; to ensure nice whitespace we can specify margins to indirectly set the axes dimensions."
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": null,
328 | "id": "9721f9a5",
329 | "metadata": {},
330 | "outputs": [],
331 | "source": [
332 | "# Set the top/bottom margins\n",
333 | "# The top margin will be set above the x-axis bounding box, which ensures the text won't be cut off at the figure edge\n",
334 | "margin_top, margin_bottom = (0.5, 0.5)\n",
335 | "\n",
336 | "# The `InnerMargin` constraint sets the gap between an\n",
337 | "# inner quad (\"AxesRight/XAxis\") and an outer quad (\"Figure\")\n",
338 | "layout.add_constraint(\n",
339 | " co.InnerMargin(side=\"top\"), (\"AxesRight/XAxis\", \"Figure\"), (margin_top,)\n",
340 | ")\n",
341 | "\n",
342 | "layout.add_constraint(\n",
343 | " co.InnerMargin(side=\"bottom\"), (\"AxesRight/Frame\", \"Figure\"), (margin_bottom,)\n",
344 | ")\n",
345 | "\n",
346 | "# Set the left/right margins\n",
347 | "margin_left, margin_right = (0.5, 0.5)\n",
348 | "\n",
349 | "layout.add_constraint(\n",
350 | " co.InnerMargin(side='left'), (\"AxesLeft/Frame\", \"Figure\"), (margin_left,)\n",
351 | ")\n",
352 | "\n",
353 | "layout.add_constraint(\n",
354 | " co.InnerMargin(side='right'), (\"AxesRight/Frame\", \"Figure\"), (margin_right,)\n",
355 | ")"
356 | ]
357 | },
358 | {
359 | "cell_type": "code",
360 | "execution_count": null,
361 | "id": "453aa970",
362 | "metadata": {},
363 | "outputs": [],
364 | "source": [
365 | "# Now the margins are all constrained!\n",
366 | "# The grid arrangement of the left/middle/right axes is clearer since the axes have been moved apart\n",
367 | "fig, ax = figure_layout(layout)"
368 | ]
369 | },
370 | {
371 | "cell_type": "markdown",
372 | "id": "0866e7b2-c69c-43d1-9bc6-65de4f26ff9a",
373 | "metadata": {},
374 | "source": [
375 | "## Solve the layout\n",
376 | "\n",
377 | "We can solve the `layout` to determine a set of primitives that satisfy the constraints.\n",
378 | "The solved primitives are then used to generate matplotlib figure and axes objects that reflect the layout.\n",
379 | "\n",
380 | "This is nice because the figure design and arrangement are separated from the plotting of data."
381 | ]
382 | },
383 | {
384 | "cell_type": "code",
385 | "execution_count": null,
386 | "id": "6a879951-73f3-4d43-998a-564647a94f44",
387 | "metadata": {},
388 | "outputs": [],
389 | "source": [
390 | "prim_tree_n, solve_info = solve(layout)"
391 | ]
392 | },
393 | {
394 | "cell_type": "code",
395 | "execution_count": null,
396 | "id": "35f187c0",
397 | "metadata": {},
398 | "outputs": [],
399 | "source": [
400 | "# This is the layout before the plot has been generated.\n",
401 | "# Note that the x axis thicknesses are 0 since no x tick labels exist.\n",
402 | "# After plotting the data, the axis thicknesses can be updated to account for this.\n",
403 | "fig, ax = figure_layout(layout)"
404 | ]
405 | },
406 | {
407 | "cell_type": "markdown",
408 | "id": "b678d9f3",
409 | "metadata": {},
410 | "source": [
411 | "## Plot the \"Ten Simple Rules for Better Figures\" dataset using the layout\n",
412 | "\n",
413 | "We can use the generated figure and axes to plot data now."
414 | ]
415 | },
416 | {
417 | "cell_type": "code",
418 | "execution_count": null,
419 | "id": "5a3fc270-4aac-4abe-943d-6c90e9c7f88e",
420 | "metadata": {},
421 | "outputs": [],
422 | "source": [
423 | "# The data below is approximated from a New York Times article (linked in the introduction above)\n",
424 | "# and is adapted from the figure-1.py file available at (https://github.com/rougier/ten-rules)\n",
425 | "\n",
426 | "diseases = [\n",
427 | " \"Kidney Cancer\",\n",
428 | " \"Bladder Cancer\",\n",
429 | " \"Esophageal Cancer\",\n",
430 | " \"Ovarian Cancer\",\n",
431 | " \"Liver Cancer\",\n",
432 | " \"Non-Hodgkin's\\nlymphoma\",\n",
433 | " \"Leukemia\",\n",
434 | " \"Prostate Cancer\",\n",
435 | " \"Pancreatic Cancer\",\n",
436 | " \"Breast Cancer\",\n",
437 | " \"Colorectal Cancer\",\n",
438 | " \"Lung Cancer\",\n",
439 | "]\n",
440 | "men_deaths = [\n",
441 | " 10000,\n",
442 | " 12000,\n",
443 | " 13000,\n",
444 | " 0,\n",
445 | " 14000,\n",
446 | " 12000,\n",
447 | " 16000,\n",
448 | " 25000,\n",
449 | " 20000,\n",
450 | " 500,\n",
451 | " 25000,\n",
452 | " 80000,\n",
453 | "]\n",
454 | "men_cases = [\n",
455 | " 30000,\n",
456 | " 50000,\n",
457 | " 13000,\n",
458 | " 0,\n",
459 | " 16000,\n",
460 | " 30000,\n",
461 | " 25000,\n",
462 | " 220000,\n",
463 | " 22000,\n",
464 | " 600,\n",
465 | " 55000,\n",
466 | " 115000,\n",
467 | "]\n",
468 | "women_deaths = [\n",
469 | " 6000,\n",
470 | " 5500,\n",
471 | " 5000,\n",
472 | " 20000,\n",
473 | " 9000,\n",
474 | " 12000,\n",
475 | " 13000,\n",
476 | " 0,\n",
477 | " 19000,\n",
478 | " 40000,\n",
479 | " 30000,\n",
480 | " 70000,\n",
481 | "]\n",
482 | "women_cases = [\n",
483 | " 20000,\n",
484 | " 18000,\n",
485 | " 5000,\n",
486 | " 25000,\n",
487 | " 9000,\n",
488 | " 29000,\n",
489 | " 24000,\n",
490 | " 0,\n",
491 | " 21000,\n",
492 | " 160000,\n",
493 | " 55000,\n",
494 | " 97000,\n",
495 | "]\n",
496 | "\n",
497 | "y_diseases = np.arange(len(diseases))"
498 | ]
499 | },
500 | {
501 | "cell_type": "code",
502 | "execution_count": null,
503 | "id": "e1dd6749-4b65-4617-9152-e56c28f16b04",
504 | "metadata": {},
505 | "outputs": [],
506 | "source": [
507 | "def format_axes(ax):\n",
508 | " \"\"\"\n",
509 | " Apply the Axes formatting used in \"Ten Simple Rules\"\n",
510 | " \"\"\"\n",
511 | " if not ax.xaxis.get_inverted():\n",
512 | " origin_side = \"left\"\n",
513 | " far_side = \"right\"\n",
514 | " else:\n",
515 | " origin_side = \"right\"\n",
516 | " far_side = \"left\"\n",
517 | "\n",
518 | " ax.spines[far_side].set_color(\"none\")\n",
519 | " ax.spines[origin_side].set_zorder(10)\n",
520 | " ax.spines[\"bottom\"].set_color(\"none\")\n",
521 | "\n",
522 | " # ax.xaxis.set_ticks_position(\"top\")\n",
523 | "\n",
524 | " # ax.yaxis.set_ticks_position(origin_side)\n",
525 | " ax.yaxis.set_ticks(y_diseases, labels=[\"\"] * len(y_diseases))\n",
526 | "\n",
527 | " ax.spines[\"top\"].set_position((\"data\", len(diseases) + 0.25))\n",
528 | " ax.spines[\"top\"].set_color(\"w\")"
529 | ]
530 | },
531 | {
532 | "cell_type": "markdown",
533 | "id": "d7611997",
534 | "metadata": {},
535 | "source": [
536 | "#### Plot with the initial layout\n",
537 | "\n",
538 | "The x and y axis thicknesses are not known since the tick and label sizes cannot be known a priori.\n",
539 | "To account for this, simply plot the figure with the current layout (this may result in cut-off labels). \n",
540 | "You can then update the layout with the plotted axes to account for the axis thicknesses."
541 | ]
542 | },
543 | {
544 | "cell_type": "code",
545 | "execution_count": null,
546 | "id": "4da120e8-d200-4ad1-9714-49adf76afa1e",
547 | "metadata": {},
548 | "outputs": [],
549 | "source": [
550 | "## Here we plot the actual NYT figure from the article\n",
551 | "\n",
552 | "# The `subplots` function uses the solved primitives to create figure and axes objects with the determined sizes\n",
553 | "# `axs` is a dictionary with keys matching the axes names\n",
554 | "fig, axs = subplots(prim_tree_n)\n",
555 | "\n",
556 | "for ax in axs.values():\n",
557 | " ax.set_xlim(0, 200000)\n",
558 | "\n",
559 | "# Plot the men's/women's data\n",
560 | "axs[\"AxesLeft\"].barh(y_diseases, women_cases, height=0.8, fc=\"red\", alpha=0.1)\n",
561 | "axs[\"AxesLeft\"].barh(y_diseases, women_deaths, height=0.55, fc=\"red\", alpha=0.5)\n",
562 | "axs[\"AxesLeft\"].xaxis.set_inverted(True)\n",
563 | "\n",
564 | "axs[\"AxesRight\"].barh(y_diseases, men_cases, height=0.8, fc=\"blue\", alpha=0.1)\n",
565 | "axs[\"AxesRight\"].barh(y_diseases, men_deaths, height=0.55, fc=\"blue\", alpha=0.5)\n",
566 | "\n",
567 | "axs_labels = [\"AxesLeft\", \"AxesRight\"]\n",
568 | "axs_categories = [\"women\", \"men\"]\n",
569 | "category_to_ax = {\n",
570 | " category: axs[key]\n",
571 | " for category, key in zip(axs_categories, axs_labels)\n",
572 | "}\n",
573 | "for category, ax in category_to_ax.items():\n",
574 | " format_axes(ax)\n",
575 | " ax.set_xticks(\n",
576 | " [0, 50000, 100000, 150000, 200000],\n",
577 | " [category.upper(), \"50,000\", \"100,000\", \"150,000\", \"200,000\"],\n",
578 | " )\n",
579 | " ax.grid(which=\"major\", axis=\"x\", color=\"white\")\n",
580 | " ax.get_xticklabels()[0].set_weight(\"bold\")\n",
581 | "\n",
582 | "# Add ylabels to 'AxesMid'\n",
583 | "axs[\"AxesMid\"].set_axis_off()\n",
584 | "axs[\"AxesMid\"].set_ylim(axs[\"AxesLeft\"].get_ylim())\n",
585 | "axs[\"AxesMid\"].set_xlim(-1, 1)\n",
586 | "\n",
587 | "for y, disease_name in zip(y_diseases, diseases):\n",
588 | " axs[\"AxesMid\"].text(0, y, disease_name, ha=\"center\", va=\"center\")\n",
589 | "\n",
590 | "# Add the \"NEW CASES\" and \"DEATHS\" annotations\n",
591 | "# The devil is in the details...\n",
592 | "arrowprops = {\"arrowstyle\": \"-\", \"connectionstyle\": \"angle,angleA=0,angleB=90,rad=0\"}\n",
593 | "\n",
594 | "x = women_cases[-1]\n",
595 | "y = y_diseases[-1]\n",
596 | "axs[\"AxesLeft\"].annotate(\n",
597 | " \"NEW CASES\",\n",
598 | " xy=(0.9 * x, y),\n",
599 | " xycoords=\"data\",\n",
600 | " ha=\"right\",\n",
601 | " fontsize=10,\n",
602 | " xytext=(-40, -3),\n",
603 | " textcoords=\"offset points\",\n",
604 | " arrowprops=arrowprops,\n",
605 | ")\n",
606 | "\n",
607 | "x = women_deaths[-1]\n",
608 | "axs[\"AxesLeft\"].annotate(\n",
609 | " \"DEATHS\",\n",
610 | " xy=(0.85 * x, y),\n",
611 | " xycoords=\"data\",\n",
612 | " ha=\"right\",\n",
613 | " fontsize=10,\n",
614 | " xytext=(-50, -25),\n",
615 | " textcoords=\"offset points\",\n",
616 | " arrowprops=arrowprops,\n",
617 | ")\n",
618 | "\n",
619 | "x = men_cases[-1]\n",
620 | "axs[\"AxesRight\"].annotate(\n",
621 | " \"NEW CASES\",\n",
622 | " xy=(0.9 * x, y),\n",
623 | " xycoords=\"data\",\n",
624 | " ha=\"left\",\n",
625 | " fontsize=10,\n",
626 | " xytext=(+40, -3),\n",
627 | " textcoords=\"offset points\",\n",
628 | " arrowprops=arrowprops,\n",
629 | ")\n",
630 | "\n",
631 | "x = men_deaths[-1]\n",
632 | "axs[\"AxesRight\"].annotate(\n",
633 | " \"DEATHS\",\n",
634 | " xy=(0.9 * x, y),\n",
635 | " xycoords=\"data\",\n",
636 | " ha=\"left\",\n",
637 | " fontsize=10,\n",
638 | " xytext=(+50, -25),\n",
639 | " textcoords=\"offset points\",\n",
640 | " arrowprops=arrowprops,\n",
641 | ")\n",
642 | "\n",
643 | "# Add the caption text\n",
644 | "axs[\"AxesLeft\"].text(\n",
645 | " 165000, 8.2, \"Leading Causes\\nOf Cancer Deaths\", fontsize=18, va=\"top\"\n",
646 | ")\n",
647 | "axs[\"AxesLeft\"].text(\n",
648 | " 165000,\n",
649 | " 7,\n",
650 | " \"In 2007, there were more\\n\"\n",
651 | " \"than 1.4 million new cases\\n\"\n",
652 | " \"of cancer in the United States.\",\n",
653 | " va=\"top\",\n",
654 | " fontsize=10,\n",
655 | ")\n",
656 | "\n",
657 | "fig.savefig(\"ten_simple_rules_demo_no_axis_thickness.svg\")"
658 | ]
659 | },
660 | {
661 | "cell_type": "markdown",
662 | "id": "9e5c12c7",
663 | "metadata": {},
664 | "source": [
665 | "#### Update x and y axis thicknesses in the plot\n",
666 | "\n",
667 | "Now that data is plotted, the x/y axes have tick labels that change their thickness.\n",
668 | "You can use `update_layout_constraints` and the `XAxisThickness` and `YAxisThickness` constraints to update the axis sizes and adjust the figure layout."
669 | ]
670 | },
671 | {
672 | "cell_type": "code",
673 | "execution_count": null,
674 | "id": "e660bb3a",
675 | "metadata": {},
676 | "outputs": [],
677 | "source": [
678 | "# Update bounding boxes for the x/y axes now that text has been inserted\n",
679 | "# This will update the layout of axes\n",
680 | "layout = update_layout_constraints(layout, axs)\n",
681 | "\n",
682 | "prim_tree_n, info = solve(layout)\n",
683 | "# This updates the original `fig` and `axs` plot with the new layout\n",
684 | "fig, axs = update_subplots(prim_tree_n, \"Figure\", fig, axs)\n",
685 | "fig.savefig(\"ten_simple_rules_demo.svg\")\n",
686 | "\n",
687 | "fig"
688 | ]
689 | },
690 | {
691 | "cell_type": "code",
692 | "execution_count": null,
693 | "id": "f12e2ddf",
694 | "metadata": {},
695 | "outputs": [],
696 | "source": [
697 | "# If you plot the layout after axis sizes are updated, you can see the altered dimensions!\n",
698 | "fig, ax = figure_layout(layout)"
699 | ]
700 | },
701 | {
702 | "cell_type": "code",
703 | "execution_count": null,
704 | "id": "f983b76f",
705 | "metadata": {},
706 | "outputs": [],
707 | "source": []
708 | }
709 | ],
710 | "metadata": {
711 | "kernelspec": {
712 | "display_name": "numerics",
713 | "language": "python",
714 | "name": "python3"
715 | },
716 | "language_info": {
717 | "codemirror_mode": {
718 | "name": "ipython",
719 | "version": 3
720 | },
721 | "file_extension": ".py",
722 | "mimetype": "text/x-python",
723 | "name": "python",
724 | "nbconvert_exporter": "python",
725 | "pygments_lexer": "ipython3",
726 | "version": "3.12.7"
727 | }
728 | },
729 | "nbformat": 4,
730 | "nbformat_minor": 5
731 | }
732 |
--------------------------------------------------------------------------------
/examples/tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Tutorial\n",
8 | "\n",
9 | "This tutorial demonstrates the main workflow for using `mpllayout`.\n",
10 | "\n",
11 | "The workflow involves just a few high-level steps:\n",
12 | "1. Create a layout object to store the layout, `layout = Layout()`\n",
13 | "2. Add geometric primitives to `layout` using `layout.add_prim`. These primitives represent figure elements.\n",
14 | "3. Add geometric constraints to `layout` using `layout.add_constraint` to constrain the primitives.\n",
15 | "4. Solve the constrained layout of primitives using `constrained_prims, solve_info = solve(layout)`\n",
16 | "5. Generate a figure and axes to plot in using `fig, axs = subplots(constrained_prims)`\n",
17 | "\n",
18 | "The generated `fig` and `axs` will reflect the constrained layout."
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "import numpy as np\n",
28 | "\n",
29 | "import matplotlib as mpl\n",
30 | "from matplotlib import pyplot as plt\n",
31 | "\n",
32 | "# `layout` contains the `Layout` class and related functions\n",
33 | "from mpllayout import layout as lay\n",
34 | "# `primitives` contains primitives and `constraints` constraints\n",
35 | "from mpllayout import primitives as pr\n",
36 | "from mpllayout import constraints as co\n",
37 | "# `solve` is used to solve the constrained layout\n",
38 | "from mpllayout.solver import solve\n",
39 | "\n",
40 | "# `subplots` and `update_subplots` are used to create matplotlib figure and\n",
41 | "# axes objects from geometric primitives\n",
42 | "from mpllayout.matplotlibutils import subplots, update_subplots\n",
43 | "\n",
44 | "# `ui` contains functions to visualize primitives\n",
45 | "from mpllayout import ui"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "metadata": {},
51 | "source": [
52 | "## Step 1: Create the layout"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "# Create the layout to store constraints and primitives\n",
62 | "layout = lay.Layout()"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "metadata": {},
68 | "source": [
69 | "## Step 2: Add geometric primitives\n",
70 | "\n",
71 | "Geometric primitives represent geometry and are defined in `mpllayout.primitives`.\n",
72 | "Each primitive consists of a parameter vector (`primitive.value`) with child primitives (`primitive[\"child_key\"]`).\n",
73 | "For example:\n",
74 | "\n",
75 | "* `Point` represents a point and has a parameter vector containing its coordinates with no child primitives\n",
76 | "* `Line` represents a straight line segment, has no parameter vector, and contains two points representing the start point (`line[\"Point0\"]`) and end point (`line[\"Point1\"]`)\n",
77 | "* Other primitives are documented in the module\n",
78 | "\n",
79 | "Geometric primitives are added using the call\n",
80 | "`layout.add_prim(primitive, key)`\n",
81 | "where `primitive` is a geometric primitive object and `key` is a string used to identify it."
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": null,
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "# A `Quadrilateral` is a 4 sided polygon which can be used to represent the figure box.\n",
91 | "# Naming the quad \"Figure\" will cause the `subplots` command to create a figure of the same size.\n",
92 | "layout.add_prim(pr.Quadrilateral(), \"Figure\")\n",
93 | "\n",
94 | "# The `Axes` primitive is a collection of quadrilaterals and points used to represent an axes.\n",
95 | "# The child primitives of `Axes` are\n",
96 | "# - \"Frame\": a `Quadrilateral` representing the plotting area of the axes\n",
97 | "# - \"XAxis\": a `Quadrilateral` bounding x-axis ticks and tick labels\n",
98 | "# - \"XAxisLabel\": a `Point` for the x-axis label text anchor\n",
99 | "# - \"YAxis\": a `Quadrilateral` bounding y-axis ticks and tick labels\n",
100 | "# - \"YAxisLabel\": a `Point` for the y-axis label text anchor\n",
101 | "# The x/y axis can be optionally included by kwargs as seen below\n",
102 | "layout.add_prim(pr.Axes(xaxis=True, yaxis=True), \"Axes\")"
103 | ]
104 | },
105 | {
106 | "cell_type": "markdown",
107 | "metadata": {},
108 | "source": [
109 | "## Step 3: Add geometric constraints\n",
110 | "\n",
111 | "Geometric constraints represent constraints or conditions on primitives and are defined in `mpllayout.constraints`.\n",
112 | "Every constraint has a method representing the condition\n",
113 | "`Constraint.assem_res(prims, **kwargs)`, \n",
114 | "where `prims` is a tuple of primitives the constraint applies to and `**kwargs` are constraint specific parameters.\n",
115 | "The constraint is satisfied when `assem_res` is 0.\n",
116 | "\n",
117 | "For example:\n",
118 | "\n",
119 | "* `Coincident().assem_res((pointa, pointb))` represents the coincidence error between the points `pointa` and `pointb` (no parameters are needed).\n",
120 | "* `Length().assem_res((linea,), length=7)` represents the length error for `linea` compared to the desired length of 7.\n",
121 | "* Other constraints are documented in the module\n",
122 | "\n",
123 | "Geometric constraints are added to a layout using the call\n",
124 | "`layout.add_constraint(constraint, prim_keys, constraint_params)`,\n",
125 | "where \n",
126 | "\n",
127 | "* `constraint` is the geometric constraint object\n",
128 | "* `prim_keys` is a tuple of primitive keys representing the primitives to constrain\n",
129 | "* `constraint_params` is a dictionary or tuple of constraint specific parameters.\n",
130 | "\n",
131 | "`prim_keys` can recursively indicate primitives uses slash separated keys.\n",
132 | "For example the tuple `(\"Figure/Line0/Point0\", \"Axes/Frame/Line0\")` represents the point 0 of the figure quadrilateral (the bottom left) and line 0 of the axes frame quadrilateral (the bottom line).\n",
133 | "\n",
134 | "`constraint_params` represents the `**kwargs` of `assem_res`.\n",
135 | "This can be either a dictionary or tuple representing the kwargs.\n",
136 | "\n",
137 | "The next few sections add sets of constraints and plot the resulting constrained layout to illustrate their effect.\n"
138 | ]
139 | },
140 | {
141 | "cell_type": "markdown",
142 | "metadata": {},
143 | "source": [
144 | "### Make all quadrilaterals rectangular\n",
145 | "\n",
146 | "`Quadrilateral`s have unknown coordinates for each of their 4 corners, so they can take any 4-sided shape.\n",
147 | "To make them rectangular boxes like axes and figures, apply the `Box` constraint."
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": null,
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "# NOTE: This step is needed because `Quadrilateral`s don't have to be rectangular\n",
157 | "layout.add_constraint(co.Box(), (\"Figure\",), ())\n",
158 | "layout.add_constraint(co.Box(), (\"Axes/Frame\",), ())\n",
159 | "layout.add_constraint(co.Box(), (\"Axes/XAxis\",), ())\n",
160 | "layout.add_constraint(co.Box(), (\"Axes/YAxis\",), ())"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "metadata": {},
167 | "outputs": [],
168 | "source": [
169 | "# The layout currently looks like:\n",
170 | "_fig, _ = ui.figure_prims(solve(layout)[0], fig_size=(5, 5))"
171 | ]
172 | },
173 | {
174 | "cell_type": "markdown",
175 | "metadata": {},
176 | "source": [
177 | "### Fix the figure position and size"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": null,
183 | "metadata": {},
184 | "outputs": [],
185 | "source": [
186 | "# Fix the figure bottom left to the origin\n",
187 | "layout.add_constraint(co.Fix(), (\"Figure/Line0/Point0\",), (np.array([0, 0]),))\n",
188 | "\n",
189 | "# Fix the figure width and height\n",
190 | "fig_width, fig_height = 6, 3\n",
191 | "layout.add_constraint(co.XLength(), (\"Figure/Line0\",), (fig_width,))\n",
192 | "layout.add_constraint(co.YLength(), (\"Figure/Line1\",), (fig_height,))"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": null,
198 | "metadata": {},
199 | "outputs": [],
200 | "source": [
201 | "# The layout currently looks like:\n",
202 | "_fig, _ = ui.figure_prims(solve(layout)[0], fig_size=(5, 5))"
203 | ]
204 | },
205 | {
206 | "cell_type": "markdown",
207 | "metadata": {},
208 | "source": [
209 | "### Position the x and y axis and axis labels"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": null,
215 | "metadata": {},
216 | "outputs": [],
217 | "source": [
218 | "## Position the x axis on top and the y axis on the right\n",
219 | "# When creating axes from the primitives, `subplots` will detect axis\n",
220 | "# positions and set axis properties to reflect them.\n",
221 | "layout.add_constraint(co.PositionXAxis(side='top'), (\"Axes\", ), ())\n",
222 | "layout.add_constraint(co.PositionYAxis(side='right'), (\"Axes\", ), ())\n",
223 | "\n",
224 | "# Link x/y axis width/height to axis sizes in matplotlib.\n",
225 | "# Axis sizes change depending on the size of their tick labels so the\n",
226 | "# axis width/height must be linked to matplotlib and updated from plot\n",
227 | "# elements.\n",
228 | "layout.add_constraint(\n",
229 | " co.XAxisThickness(), (\"Axes/XAxis\",), (None,),\n",
230 | ")\n",
231 | "layout.add_constraint(\n",
232 | " co.YAxisThickness(), (\"Axes/YAxis\",), (None,),\n",
233 | ")\n",
234 | "\n",
235 | "## Position the x/y axis label text anchors\n",
236 | "# When creating axes from the primitives, `subplots` will detect these and set\n",
237 | "# their locations\n",
238 | "on_line = co.RelativePointOnLineDistance()\n",
239 | "to_line = co.PointToLineDistance()\n",
240 | "\n",
241 | "## Pad the x/y axis label from the axis bbox\n",
242 | "pad = 1/16\n",
243 | "layout.add_constraint(to_line, (\"Axes/XAxisLabel\", \"Axes/XAxis/Line2\"), (True, pad))\n",
244 | "layout.add_constraint(to_line, (\"Axes/YAxisLabel\", \"Axes/YAxis/Line1\"), (True, pad))\n",
245 | "\n",
246 | "## Center the axis labels halfway along the axes width/height\n",
247 | "layout.add_constraint(co.PositionXAxisLabel(), (\"Axes\",), (0.5,))\n",
248 | "layout.add_constraint(co.PositionYAxisLabel(), (\"Axes\",), (0.5,))"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": null,
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "# The layout currently looks like:\n",
258 | "_fig, _ = ui.figure_prims(solve(layout)[0], fig_size=(5, 5))"
259 | ]
260 | },
261 | {
262 | "cell_type": "markdown",
263 | "metadata": {},
264 | "source": [
265 | "### Set margins between the axes and figure"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": null,
271 | "metadata": {},
272 | "outputs": [],
273 | "source": [
274 | "## Constrain margins around the axes to the figure\n",
275 | "# Constrain left/right margins\n",
276 | "margin_left = 0.1\n",
277 | "margin_right = 1/4\n",
278 | "\n",
279 | "layout.add_constraint(\n",
280 | " co.InnerMargin(side='left'), (\"Axes/Frame\", \"Figure\"), (margin_left,)\n",
281 | ")\n",
282 | "layout.add_constraint(\n",
283 | " co.InnerMargin(side='right'), (\"Axes/YAxis\", \"Figure\"), (margin_right,)\n",
284 | ")\n",
285 | "\n",
286 | "# Constrain top/bottom margins\n",
287 | "margin_top = 1/4\n",
288 | "margin_bottom = 0.1\n",
289 | "layout.add_constraint(\n",
290 | " co.InnerMargin(side='bottom'), (\"Axes/Frame\", \"Figure\"), (margin_bottom,)\n",
291 | ")\n",
292 | "layout.add_constraint(\n",
293 | " co.InnerMargin(side='top'), (\"Axes/XAxis\", \"Figure\"), (margin_top,)\n",
294 | ")\n"
295 | ]
296 | },
297 | {
298 | "cell_type": "code",
299 | "execution_count": null,
300 | "metadata": {},
301 | "outputs": [],
302 | "source": [
303 | "# The layout currently looks like:\n",
304 | "_fig, _ = ui.figure_prims(solve(layout)[0], fig_size=(5, 5))"
305 | ]
306 | },
307 | {
308 | "cell_type": "markdown",
309 | "metadata": {},
310 | "source": [
311 | "## Step 4: Solve the layout "
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": null,
317 | "metadata": {},
318 | "outputs": [],
319 | "source": [
320 | "## Solve the constraints and form the figure/axes layout\n",
321 | "prim_tree_n, solve_info = solve(layout)\n",
322 | "\n",
323 | "print(f\"Absolute errors: {solve_info['abs_errs']}\")\n",
324 | "print(f\"Relative errors: {solve_info['rel_errs']}\")"
325 | ]
326 | },
327 | {
328 | "cell_type": "markdown",
329 | "metadata": {},
330 | "source": [
331 | "## Step 5: Plot a figure using the layout "
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": null,
337 | "metadata": {},
338 | "outputs": [],
339 | "source": [
340 | "## Plot into the generated figure and axes\n",
341 | "fig, axs = subplots(prim_tree_n)\n",
342 | "\n",
343 | "x = np.linspace(0, 1)\n",
344 | "axs[\"Axes\"].plot(x, x**2)\n",
345 | "\n",
346 | "axs[\"Axes\"].xaxis.set_label_text(\"My x label\", ha=\"center\", va=\"bottom\")\n",
347 | "axs[\"Axes\"].yaxis.set_label_text(\"My y label\", ha=\"center\", va=\"bottom\", rotation=-90)\n",
348 | "\n",
349 | "ax = axs[\"Axes\"]\n",
350 | "\n",
351 | "# Using the generated axes and x/y axis contents, the layout constraints\n",
352 | "# can be updated with those matplotlib elements\n",
353 | "layout = lay.update_layout_constraints(layout, axs)\n",
354 | "prim_tree_n, solve_info = solve(layout)\n",
355 | "\n",
356 | "# This updates the figure and axes using the updated layout\n",
357 | "update_subplots(prim_tree_n, \"Figure\", fig, axs)"
358 | ]
359 | },
360 | {
361 | "cell_type": "code",
362 | "execution_count": null,
363 | "metadata": {},
364 | "outputs": [],
365 | "source": [
366 | "# The layout currently looks like:\n",
367 | "_fig, _ = ui.figure_prims(solve(layout)[0], fig_size=(5, 5))\n",
368 | "\n",
369 | "# Note that x and y axis dimensions have adjusted"
370 | ]
371 | },
372 | {
373 | "cell_type": "code",
374 | "execution_count": null,
375 | "metadata": {},
376 | "outputs": [],
377 | "source": []
378 | }
379 | ],
380 | "metadata": {
381 | "kernelspec": {
382 | "display_name": "numerics",
383 | "language": "python",
384 | "name": "python3"
385 | },
386 | "language_info": {
387 | "codemirror_mode": {
388 | "name": "ipython",
389 | "version": 3
390 | },
391 | "file_extension": ".py",
392 | "mimetype": "text/x-python",
393 | "name": "python",
394 | "nbconvert_exporter": "python",
395 | "pygments_lexer": "ipython3",
396 | "version": "3.12.7"
397 | }
398 | },
399 | "nbformat": 4,
400 | "nbformat_minor": 2
401 | }
402 |
--------------------------------------------------------------------------------
/examples/two_axes.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Two axes figure example\n",
8 | "\n",
9 | "This example illustrates creating a two-axes layout.\n"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import numpy as np\n",
19 | "\n",
20 | "from mpllayout import layout as lay\n",
21 | "from mpllayout import primitives as pr\n",
22 | "from mpllayout import constraints as co\n",
23 | "from mpllayout import solver\n",
24 | "from mpllayout import ui\n",
25 | "from mpllayout import matplotlibutils as lplt"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "## Specify the layout"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "# Create the layout object\n",
42 | "layout = lay.Layout()\n",
43 | "\n",
44 | "## Create an origin point fixed at (0, 0)\n",
45 | "layout.add_prim(pr.Point(), \"Origin\")\n",
46 | "layout.add_constraint(co.Fix(), (\"Origin\",), (np.array([0, 0]),))\n",
47 | "\n",
48 | "# Create a box to represent the figure\n",
49 | "layout.add_prim(pr.Quadrilateral(), \"Figure\")\n",
50 | "layout.add_constraint(co.Box(), (\"Figure\",), ())\n",
51 | "\n",
52 | "# Create `Axes` objects to represent the left and right axes\n",
53 | "layout.add_prim(pr.Axes(), \"AxesLeft\")\n",
54 | "layout.add_prim(pr.Axes(), \"AxesRight\")\n",
55 | "layout.add_constraint(co.Box(), (\"AxesLeft/Frame\",), ())\n",
56 | "layout.add_constraint(co.Box(), (\"AxesRight/Frame\",), ())\n",
57 | "\n",
58 | "# Constrain the width and height of the figure\n",
59 | "fig_width, fig_height = 6, 3\n",
60 | "layout.add_constraint(co.Width(), (\"Figure\",), (fig_width,))\n",
61 | "layout.add_constraint(co.Height(), (\"Figure\",), (fig_height,))\n",
62 | "\n",
63 | "# Constrain the bottom-left corner of the figure to the origin\n",
64 | "layout.add_constraint(co.Coincident(), (\"Figure/Line0/Point0\", \"Origin\"), ())\n",
65 | "\n",
66 | "# Constrain the left/right margins to `AxesLeft` and `AxesRight`, respectively\n",
67 | "margin_left = 0.5\n",
68 | "margin_right = 0.5\n",
69 | "\n",
70 | "layout.add_constraint(\n",
71 | " co.InnerMargin(side='left'), (\"AxesLeft/Frame\", \"Figure\",), (margin_left,)\n",
72 | ")\n",
73 | "layout.add_constraint(\n",
74 | " co.InnerMargin(side='right'), (\"AxesRight/Frame\", \"Figure\"), (margin_right,)\n",
75 | ")\n",
76 | "\n",
77 | "# Constrain the gap between the left and right axes\n",
78 | "margin_inter = 0.5\n",
79 | "layout.add_constraint(\n",
80 | " co.OuterMargin(side='right'), (\"AxesLeft/Frame\", \"AxesRight/Frame\"), (margin_inter,)\n",
81 | ")\n",
82 | "\n",
83 | "# Constrain the top/bottom margins on the left axes ('AxesLeft')\n",
84 | "# We can align the left/right axes to implicitly set those margins\n",
85 | "margin_top = 1.0\n",
86 | "margin_bottom = 0.5\n",
87 | "layout.add_constraint(\n",
88 | " co.InnerMargin(side=\"bottom\"), (\"AxesLeft/Frame\", \"Figure\"), (margin_bottom,)\n",
89 | ")\n",
90 | "layout.add_constraint(\n",
91 | " co.InnerMargin(side=\"top\"), (\"AxesLeft/Frame\", \"Figure\"), (margin_top,)\n",
92 | ")\n",
93 | "\n",
94 | "# Align the left/right axes in a row\n",
95 | "layout.add_constraint(co.AlignRow(), (\"AxesLeft/Frame\", \"AxesRight/Frame\"), ())\n",
96 | "\n",
97 | "# Constrain the width of 'AxesLeft'\n",
98 | "# Note that the right axes width is already constrained through the margins and\n",
99 | "# known left axes width\n",
100 | "width = 2\n",
101 | "layout.add_constraint(co.Width(), (\"AxesLeft/Frame\",), (width,))"
102 | ]
103 | },
104 | {
105 | "cell_type": "markdown",
106 | "metadata": {},
107 | "source": [
108 | "## Solve the layout"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "# Solve and plot the constrained layout\n",
118 | "prims, info = solver.solve(layout, max_iter=40, rel_tol=1e-9)\n",
119 | "\n",
120 | "fig_layout, ax_layout = ui.figure_prims(prims)"
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "metadata": {},
126 | "source": [
127 | "## Plot a figure using the layout"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": null,
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "## Create a figure and axes from the constrained primitives\n",
137 | "fig, axs = lplt.subplots(prims)\n",
138 | "\n",
139 | "x = np.linspace(0, 1)\n",
140 | "axs[\"AxesLeft\"].plot(x, 4 * x)\n",
141 | "axs[\"AxesRight\"].plot(x, x**2)\n",
142 | "\n",
143 | "fig.savefig(\"two_axes.png\")\n"
144 | ]
145 | }
146 | ],
147 | "metadata": {
148 | "kernelspec": {
149 | "display_name": "numerics",
150 | "language": "python",
151 | "name": "python3"
152 | },
153 | "language_info": {
154 | "codemirror_mode": {
155 | "name": "ipython",
156 | "version": 3
157 | },
158 | "file_extension": ".py",
159 | "mimetype": "text/x-python",
160 | "name": "python",
161 | "nbconvert_exporter": "python",
162 | "pygments_lexer": "ipython3",
163 | "version": "3.12.7"
164 | }
165 | },
166 | "nbformat": 4,
167 | "nbformat_minor": 2
168 | }
169 |
--------------------------------------------------------------------------------
/logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
255 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0", "jax", "numpy", "matplotlib"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "matplotlib-layout"
7 | version = "0.1.0"
8 | authors = [
9 | { name="Jonathan Deng", email="jonathan.j.deng@gmail.com" },
10 | ]
11 | description = "A package for laying out figures with geometric constraints"
12 | readme = "README.md"
13 | requires-python = ">=3.10"
14 | classifiers = [
15 | "Programming Language :: Python :: 3",
16 | "License :: OSI Approved :: MIT License",
17 | "Operating System :: OS Independent",
18 | ]
19 |
20 | [project.urls]
21 | "Homepage" = "https://github.com/jon-deng/mpl-layout"
22 | "Bug Tracker" = "https://github.com/jon-deng/mpl-layout/issues"
23 |
--------------------------------------------------------------------------------
/src/mpllayout/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jon-deng/mpl-layout/6a10aeff3dfa6e989cc8d9beefa55b78527ebe6a/src/mpllayout/__init__.py
--------------------------------------------------------------------------------
/src/mpllayout/containers.py:
--------------------------------------------------------------------------------
1 | """
2 | Tree class (`Node`) definition and related utilities for trees
3 |
4 | This module defines a tree class `Node` and related utilities.
5 | This class is used by itself as well as to define geometric primitives and constraints.
6 | """
7 |
8 | from typing import TypeVar, Generic, Any, Iterable, Callable
9 |
10 | import itertools
11 | import functools
12 |
13 | import jax
14 |
15 | TValue = TypeVar("TValue")
16 | TNode = TypeVar("TNode", bound="Node")
17 |
18 | class Node(Generic[TValue]):
19 | """
20 | Tree structure with labelled child nodes
21 |
22 | Parameters
23 | ----------
24 | value: TValue
25 | A value associated with the node
26 | children: dict[str, Node]
27 | A dictionary of child nodes
28 |
29 | Attributes
30 | ----------
31 | value: TValue
32 | The value stored in the node
33 | children: dict[str, Node]
34 | A dictionary of child nodes
35 | """
36 |
37 | def __init__(self: TNode, value: TValue, children: dict[str, TNode]):
38 | assert isinstance(children, dict)
39 | self._value = value
40 | self._children = children
41 |
42 | @classmethod
43 | def from_tree(cls, value: TValue, children: dict[str, TNode]):
44 | """
45 | Return any `Node` subclass from its value and children
46 |
47 | This method is needed because some `Node` subclasses have different `__init__`
48 | signatures.
49 | `from_tree` can be used to recreate any `Node` subclass using just a known value
50 | and children.
51 | This is particularly important for flattening and unflattening a tree.
52 | """
53 | node = super().__new__(cls)
54 | Node.__init__(node, value, children)
55 | return node
56 |
57 |
58 | ## Tree methods
59 |
60 | @property
61 | def value(self) -> TValue:
62 | """
63 | Return the node value
64 | """
65 | return self._value
66 |
67 | @property
68 | def children(self: TNode) -> dict[str, TNode]:
69 | """
70 | Return all child nodes
71 | """
72 | return self._children
73 |
74 | def node_height(self) -> int:
75 | """
76 | Return the height of a node
77 |
78 | Returns
79 | -------
80 | int
81 | The node height
82 |
83 | The node height is the number of edges from the current node to the
84 | "furthest" child node.
85 | """
86 | if len(self) == 0:
87 | return 0
88 | else:
89 | return 1 + max(child.node_height() for _, child in self.items())
90 |
91 | def get_child(self: TNode, key: str) -> TNode:
92 | split_key = key.split("/", 1)
93 | parent_key, child_keys = split_key[0], split_key[1:]
94 |
95 | try:
96 | if len(child_keys) == 0:
97 | return self._get_child_nonrecursive(parent_key)
98 | else:
99 | return self.children[parent_key].get_child(child_keys[0])
100 | except KeyError as err:
101 | raise KeyError(f"{key}") from err
102 |
103 | def _get_child_nonrecursive(self: TNode, key: str) -> TNode:
104 | return self.children[key]
105 |
106 | def add_child(self: TNode, key: str, child: TNode):
107 | """
108 | Add a child node at the given key
109 |
110 | Raises an error if the key already exists.
111 |
112 | Parameters
113 | ----------
114 | key: str
115 | A child node key
116 |
117 | see `__getitem__`
118 | node: Node
119 | The node to set
120 | """
121 | split_key = key.split("/", 1)
122 | parent_key, child_keys = split_key[0], split_key[1:]
123 |
124 | try:
125 | if len(child_keys) == 0:
126 | self._add_child_nonrecursive(parent_key, child)
127 | else:
128 | self.children[parent_key].add_child(child_keys[0], child)
129 |
130 | except KeyError as err:
131 | raise KeyError(f"{key}") from err
132 |
133 | def _add_child_nonrecursive(self: TNode, key: str, child: TNode):
134 | """
135 | Add a primitive indexed by a key
136 |
137 | Base case of recursive `add_child`
138 | """
139 | if key in self.children:
140 | raise KeyError(f"{key}")
141 | else:
142 | self.children[key] = child
143 |
144 | def copy(self):
145 | def identity(value: TValue) -> TValue:
146 | return value
147 | return map(identity, self)
148 |
149 | ## String methods
150 |
151 | def __repr__(self) -> str:
152 | keys_repr = ", ".join(self.keys())
153 | children_repr = ", ".join([node.__repr__() for _, node in self.children.items()])
154 | return f"{type(self).__name__}({self.value}, ({children_repr}), ({keys_repr}))"
155 |
156 | def __str__(self) -> str:
157 | return self.__repr__()
158 |
159 | ## Dict methods
160 |
161 | def __iter__(self):
162 | return self.children.__iter__()
163 |
164 | def __contains__(self, key: str) -> bool:
165 | split_keys = key.split("/", 1)
166 | parent_key = split_keys[0]
167 | child_key = "/".join(split_keys[1:])
168 |
169 | if child_key == "":
170 | return parent_key in self.children
171 | else:
172 | return child_key in self[parent_key]
173 |
174 | def __len__(self) -> int:
175 | return len(self.children)
176 |
177 | def keys(self) -> list[str]:
178 | """
179 | Return child keys
180 | """
181 | return list(self.children.keys())
182 |
183 | def values(self: TNode) -> list[TNode]:
184 | """
185 | Return child nodes
186 | """
187 | return list(self.children.values())
188 |
189 | def items(self):
190 | """
191 | Return an iterator of (child key, child node) pairs
192 | """
193 | return self.children.items()
194 |
195 | def __setitem__(self: TNode, key: str, node: TNode):
196 | """
197 | Set the child node at the given key
198 |
199 | Raises an error if the key doesn't exist.
200 |
201 | Parameters
202 | ----------
203 | key: str
204 | A child node key
205 |
206 | see `__getitem__`
207 | node: Node
208 | The node to set
209 | """
210 | # This splits `key = 'a/b/c/d'`
211 | # into `parent_key = 'a/b/c'` and `child_key = 'd'`
212 | split_keys = key.split("/")
213 | parent_key = "/".join(split_keys[:-1])
214 | child_key = split_keys[-1]
215 |
216 | if key not in self:
217 | raise KeyError(key)
218 | else:
219 | if parent_key == "":
220 | parent = self
221 | else:
222 | parent = self[parent_key]
223 | parent.children[child_key] = node
224 |
225 | def __getitem__(self: TNode, key: str | int | slice) -> TNode | list[TNode]:
226 | """
227 | Return the value indexed by a slash-separated key
228 |
229 | Parameters
230 | ----------
231 | key: str | int | slice
232 | A child node key
233 |
234 | The interpretation depends on the key type:
235 | - `str` keys indicate a child key and can be slash separated to denote
236 | child keys of child keys,
237 | for example, 'childa/grandchildb/greatgrandchildc'.
238 | - `int` keys indicate a child by integer index.
239 | - `slice` keys indicate a range of children.
240 | """
241 | if isinstance(key, str):
242 | return self.get_child(key)
243 | elif isinstance(key, (int, slice)):
244 | return list(self.children.values())[key]
245 | else:
246 | raise TypeError("")
247 |
248 |
# Type variable for the items counted by `ItemCounter`
TItem = TypeVar("TItem")
250 |
251 | class ItemCounter(Generic[TItem]):
252 | """
253 | Count the number of added items by category (a string)
254 |
255 | This is used to generate unique string keys for objects
256 | (see `Layout.add_constraint`).
257 |
258 | Parameters
259 | ----------
260 | categorize: Callable[[TItem], str]
261 | A function that returns the category (string) of an item
262 |
263 | Attributes
264 | ----------
265 | category_counts: dict[str, int]
266 | The number of items added to each category
267 | """
268 |
269 | @staticmethod
270 | def categorize_by_classname(item: TItem) -> str:
271 | return type(item).__name__
272 |
273 | def __init__(self, categorize: Callable[[TItem], str] = categorize_by_classname):
274 | self._category_counts = {}
275 | self._categorize = categorize
276 |
277 | @property
278 | def category_counts(self) -> dict[str, int]:
279 | return self._category_counts
280 |
281 | def __contains__(self, key):
282 | return key in self._p
283 |
284 | def categorize(self, item: TItem) -> str:
285 | """
286 | Return the category string of an item
287 | """
288 | return self._categorize(item)
289 |
290 | def add_item(self, item: TItem) -> str:
291 | """
292 | Add an item
293 |
294 | Parameters
295 | ----------
296 | item: TItem
297 | The item to add
298 |
299 | Returns
300 | -------
301 | str
302 | A string identifying the added item's category and count
303 | """
304 | category = self.categorize(item)
305 | if category in self.category_counts:
306 | self.category_counts[category] += 1
307 | else:
308 | self.category_counts[category] = 1
309 |
310 | count = self.category_counts[category] - 1
311 | return f"{category}{count}"
312 |
313 | def add_item_until_valid(self, item: TItem, valid: Callable[[str], bool]) -> str:
314 | """
315 | Add an item until the return item key is valid
316 |
317 | This can be used to keep adding items until a unique key is generated for some
318 | existing collection.
319 | For example, if a dictionary of items already exists, this function can be used
320 | to generate new item keys until one that doesn't already exist in the dictionary
321 | is found.
322 |
323 | Parameters
324 | ----------
325 | item: TItem
326 | The item to add
327 | valid: Callable[[str], bool]
328 | The condition the generated item key must satisfy
329 |
330 | Returns
331 | -------
332 | str
333 | A string identifying the added item's category and count
334 | """
335 | key = self.add_item(item)
336 | while not valid(key):
337 | key = self.add_item(item)
338 |
339 | return key
340 |
341 | def add_item_to_nodes(self, item: TItem, *nodes: tuple[Node, ...]) -> str:
342 | """
343 | Add an item until the item key is unique within a set of trees
344 |
345 | This is used generate a unique key for an item for a set of existing trees
346 | (`Node`).
347 |
348 | Parameters
349 | ----------
350 | item: TItem
351 | The item to add
352 | *nodes: tuple[Node, ...]
353 | The set of trees
354 |
355 | The returned item key should exist in these trees.
356 |
357 | Returns
358 | -------
359 | str
360 | A string identifying the added item's category and count
361 | """
362 | def valid(key):
363 | key_notin_nodes = (key not in node for node in nodes)
364 | return all(key_notin_nodes)
365 | return self.add_item_until_valid(item, valid)
366 |
367 |
## Node functions

U = TypeVar("U")

def map(
    function: Callable[[TValue], U],
    node: Node[TValue]
) -> Node[U]:
    """
    Return a new tree with `function` applied to every node value

    The result has the same keys and structure as `node`. Every node in it is
    rebuilt as a plain `Node`; the input node types are not kept.
    """
    # Swap each node type for `Node` and each value for its mapped value
    mapped_structs = [
        (key, Node, function(value), child_keys)
        for key, _original_type, value, child_keys in flatten('', node)
    ]
    mapped_root, _leftover = unflatten(mapped_structs)
    return mapped_root
387 |
def accumulate(
    function: Callable[[TValue, TValue], TValue],
    node: Node[TValue],
    initial: TValue
) -> Node[TValue]:
    """
    Return a tree where each node value accumulates its subtree's values

    - A leaf's value becomes ``function(leaf.value, initial)``.
    - A non-leaf's value becomes the left fold of `function` over the node's
      own value followed by its accumulated children's values, in child
      order; `initial` is not used for non-leaf nodes.

    The result is built from plain `Node` instances with the same keys as `node`.
    """
    # Accumulate every subtree first (depth-first recursion)
    acc_children = {
        child_key: accumulate(function, child, initial)
        for child_key, child in node.items()
    }

    if not acc_children:
        return Node(function(node.value, initial), acc_children)

    child_values = [acc_child.value for acc_child in acc_children.values()]
    return Node(functools.reduce(function, child_values, node.value), acc_children)
409 |
## Manual flattening/unflattening implementation

def iter_flat(root_key: str, root_node: TNode) -> Iterable[tuple[str, TNode]]:
    """
    Return an iterator over a node and all its descendants, depth-first

    Parameters
    ----------
    root_key: str
        A key for the node

        Each descendant's key is this key extended with its child keys, each
        preceded by a '/' separator.
    root_node: TNode
        The root node

    Returns
    -------
    Iterable[tuple[str, TNode]]
        `(key, node)` pairs, starting with `(root_key, root_node)`
    """
    # TODO: Fix mypy typing errors here

    # Child iterators are created immediately (one per current child), so the
    # set of children visited is fixed when this function is called
    head = ((root_key, root_node),)
    descendants = [
        iter_flat(f"{root_key}/{child_key}", child_node)
        for child_key, child_node in root_node.items()
    ]
    return itertools.chain.from_iterable((head, *descendants))
441 |
442 |
FlatNodeStructure = tuple[str, type[Node], TValue, list[str]]

def flatten(root_key: str, root_node: TNode) -> list[FlatNodeStructure]:
    """
    Return a flattened list of node structures for a root node (recursively depth-first)

    Parameters
    ----------
    root_key: str
        A key for the node
    root_node: TNode
        The root node

    Returns
    -------
    list[FlatNodeStructure]
        A list of node structures

        Each node structure is a tuple `(key, node_type, value, child_keys)`
        representing one node; see `unflatten` for the inverse.
    """
    # Copy the child keys into a list so each structure matches the
    # `list[str]` in `FlatNodeStructure` and doesn't alias whatever collection
    # `node.keys()` returns
    return [
        (key, type(node), node.value, list(node.keys()))
        for key, node in iter_flat(root_key, root_node)
    ]
468 |
def unflatten(
    node_structs: list[FlatNodeStructure],
) -> tuple[TNode, list[FlatNodeStructure]]:
    """
    Return the root node from a flat representation

    Parameters
    ----------
    node_structs: list[FlatNodeStructure]
        The flat representation (depth-first, as produced by `flatten`)

    Returns
    -------
    TNode
        The root node
    list[FlatNodeStructure]
        A "leftover" flat node representation

        This list should be empty if the flat node representation only contains
        nodes that belong to the root node.

    Raises
    ------
    IndexError
        If `node_structs` has too few structures to build the root node
    """
    node, end = _unflatten_from(node_structs, 0)
    return node, node_structs[end:]


def _unflatten_from(
    node_structs: list[FlatNodeStructure], start: int
) -> tuple[TNode, int]:
    """
    Return the node whose structure is at `start` and the index just past its subtree
    """
    _key, node_type, value, child_keys = node_structs[start]

    # Structures are depth-first, so each child's subtree begins where the
    # previous child's subtree ended; walking an index avoids re-slicing the
    # list at every level
    index = start + 1
    children = {}
    for child_key in child_keys:
        child, index = _unflatten_from(node_structs, index)
        children[child_key] = child

    return node_type.from_tree(value, children), index
502 |
503 |
## pytree flattening/unflattening implementation
# `_make_flatten_unflatten` builds the flatten/unflatten function pair used to
# register `Node` classes as a `jax.pytree`, so jax can flatten/unflatten them
FlatNode = tuple[TValue, dict[str, TNode]]
AuxData = Any

def _make_flatten_unflatten(node_type: type[TNode]):
    """
    Return `(flatten, unflatten)` functions for registering `node_type` with jax

    A node flattens to its value and children dictionary; no auxiliary data is
    kept, so unflattening rebuilds the node with `node_type.from_tree`.
    """

    def _flatten_node(node: TNode) -> tuple[FlatNode, AuxData]:
        return (node.value, node.children), None

    def _unflatten_node(aux_data: AuxData, flat_node: FlatNode) -> TNode:
        node_value, node_children = flat_node
        return node_type.from_tree(node_value, node_children)

    return _flatten_node, _unflatten_node
522 |
523 |
## Register `Node` as `jax.pytree`
# Registration lets `jax.tree_util` flatten `Node` trees into their values and
# rebuild them with the functions from `_make_flatten_unflatten`.
# NOTE: the loop names `_NodeType`, `_flatten` and `_unflatten` remain defined
# at module level after the loop.
for _NodeType in [Node]:
    _flatten, _unflatten = _make_flatten_unflatten(_NodeType)
    jax.tree_util.register_pytree_node(_NodeType, _flatten, _unflatten)
528 |
--------------------------------------------------------------------------------
/src/mpllayout/geometry.py:
--------------------------------------------------------------------------------
1 | """
2 | Geometric primitives and constraints
3 | """
4 |
5 | from .primitives import *
6 | from .constraints import *
7 |
--------------------------------------------------------------------------------
/src/mpllayout/layout.py:
--------------------------------------------------------------------------------
1 | """
2 | Layout class and associated utilities
3 |
4 | A `Layout` is the class used to represent an arrangement (layout) of figure
5 | elements.
6 | """
7 |
8 | from typing import Optional, Any
9 | from numpy.typing import NDArray
10 |
11 | from matplotlib.axes import Axes
12 | import numpy as np
13 |
14 | from . import primitives as pr
15 | from . import constraints as cr
16 | from .containers import ItemCounter, iter_flat
17 |
# Type aliases for lists of tuples of integer indices / string keys
# NOTE(review): neither alias is used elsewhere in this module; confirm whether
# other modules import them before removing
IntGraph = list[tuple[int, ...]]
StrGraph = list[tuple[str, ...]]
20 |
21 |
class Layout:
    """
    A collection of geometric primitives and constraints

    A `Layout` stores a collection of geometric primitives and the constraints
    on those primitives to represent an arrangement of figure elements.

    Parameters
    ----------
    root_prim: Optional[pr.PrimitiveNode]
        A root primitive to store all geometric primitives
    root_constraint: Optional[cr.ConstraintNode]
        A root constraint to store all geometric constraints
    root_prim_keys: Optional[cr.PrimKeysNode]
        A root primitive arguments tree for `root_constraint`

        Each value in `root_prim_keys` is a tuple of primitive keys
        for the corresponding constraint.
        The tuple of primitive keys indicate primitives from `root_prim`.
    root_param: Optional[cr.ParamsNode]
        A root parameter tree for `root_constraint`

        Each value in `root_param` is a constraint residual parameter
        for the corresponding constraint in `root_constraint`.
    constraint_counter: Optional[ItemCounter]
        An item counter used to create automatic keys for constraints

        This does not have to be supplied.
        Supplying it will change any automatically generated constraint keys.

    Any argument left as `None` is replaced by an empty container (or a new
    `ItemCounter`).

    Attributes
    ----------
    root_prim: pr.PrimitiveNode
        See `Parameters`
    root_constraint: cr.ConstraintNode
        See `Parameters`
    root_prim_keys: cr.PrimKeysNode
        See `Parameters`
    root_param: cr.ParamsNode
        See `Parameters`
    """

    def __init__(
        self,
        root_prim: Optional[pr.PrimitiveNode] = None,
        root_constraint: Optional[cr.ConstraintNode] = None,
        root_prim_keys: Optional[cr.PrimKeysNode] = None,
        root_param: Optional[cr.ParamsNode] = None,
        constraint_counter: Optional[ItemCounter] = None,
    ):

        if root_prim is None:
            root_prim = pr.PrimitiveNode.from_tree(np.array([]), {})
        if root_constraint is None:
            root_constraint = cr.ConstraintNode.from_tree(None, {})
        if root_prim_keys is None:
            root_prim_keys = cr.PrimKeysNode(None, {})
        if root_param is None:
            root_param = cr.ParamsNode(None, {})
        if constraint_counter is None:
            constraint_counter = ItemCounter()

        self._root_prim = root_prim
        self._root_constraint = root_constraint
        self._root_constraint_prim_keys = root_prim_keys
        self._root_constraint_param = root_param

        self._constraint_counter = constraint_counter

    @property
    def root_prim(self) -> pr.PrimitiveNode:
        return self._root_prim

    @property
    def root_constraint(self) -> cr.ConstraintNode:
        return self._root_constraint

    @property
    def root_prim_keys(self) -> cr.PrimKeysNode:
        return self._root_constraint_prim_keys

    @property
    def root_param(self) -> cr.ParamsNode:
        return self._root_constraint_param

    def flat_constraints(
        self
    ) -> tuple[list[cr.Constraint], list[cr.PrimKeys], list[cr.Params]]:
        """
        Return flat constraints, primitive argument keys and parameters

        Each list is in depth-first order of its tree, so entries at the same
        index correspond when the three trees share one key structure (as
        `add_constraint` builds them).

        Returns
        -------
        constraints: list[cr.Constraint]
            A flat list of constraints from the root constraint
        prim_keys: list[cr.PrimKeys]
            A list of primitive keys for each constraint
        params: list[cr.Params]
            A list of residual parameters for each constraint
        """
        # The `[1:]` removes the 'root' constraint which is just a container
        constraints = [
            node for _, node in iter_flat('', self.root_constraint)
        ][1:]
        prim_keys = [
            node.value for _, node in iter_flat('', self.root_prim_keys)
        ][1:]
        params = [
            node.value for _, node in iter_flat('', self.root_param)
        ][1:]

        # NOTE(review): `assem_atleast_1d` is accessed, not called; presumably
        # a property returning a constraint; confirm
        constraints = [c.assem_atleast_1d for c in constraints]

        return constraints, prim_keys, params

    def add_prim(self, prim: pr.Primitive, key: str):
        """
        Add a primitive to the layout

        The primitive will be added to `self.root_prim` using the given `key`.

        Parameters
        ----------
        prim: pr.Primitive
            The primitive to add
        key: str
            The key for the primitive
        """
        self.root_prim.add_child(key, prim)

    def add_prims(self, prims: dict[str, pr.Primitive]):
        """
        Add multiple primitives to the layout

        Parameters
        ----------
        prims: dict[str, pr.Primitive]
            A dictionary of primitives to add
        """
        for key, prim in prims.items():
            self.add_prim(prim, key)

    def add_constraint(
        self,
        constraint: cr.Constraint,
        prim_keys: cr.PrimKeys,
        param: cr.Params,
        key: str = ""
    ):
        """
        Add a constraint between primitives

        Parameters
        ----------
        constraint: cr.Constraint
            The constraint to add
        prim_keys: cr.PrimKeys
            A tuple of primitive keys the constraint applies to

            The primitive keys refer to primitives in `self.root_prim`.
        param: cr.Params
            Parameters for the constraint
        key: str
            An optional key to identify the constraint

            If not supplied, a key will be automatically generated using
            `_constraint_counter`.

        Raises
        ------
        RuntimeError
            If validating the primitives or parameters raises a `TypeError`
        """
        # Notify the user if the input primitives or parameters to the constraint
        # are incorrect.
        # Validation runs before any automatic key is generated so a rejected
        # constraint doesn't consume a counter value.
        try:
            constraint.validate_prims(
                tuple(self.root_prim[prim_key] for prim_key in prim_keys)
            )
        except TypeError as exc:
            raise RuntimeError(f"Wrong primitives {prim_keys} for constraint {type(constraint)}\n{exc}") from exc

        try:
            constraint.validate_params(param)
        except TypeError as exc:
            raise RuntimeError(f"Wrong parameters {param} for constraint {type(constraint)}\n{exc}") from exc

        if key == "":
            # The key must be unique across all three parallel constraint trees
            key = self._constraint_counter.add_item_to_nodes(
                constraint, self.root_constraint, self.root_prim_keys, self.root_param
            )

        self.root_constraint.add_child(key, constraint)
        self.root_prim_keys.add_child(key, constraint.root_prim_keys(prim_keys))
        self.root_param.add_child(key, constraint.root_params(param))
211 |
def update_layout_constraints(
    layout: Layout,
    axs: dict[str, Axes]
) -> Layout:
    """
    Update layout constraints that depend on `matplotlib` elements

    Some constraints have parameters that depend on `matplotlib`.
    This function should identify these constraints in a `Layout` and replace
    their parameters with the correct `matplotlib` element.
    Currently this is only implemented for `XAxisThickness` and `YAxisThickness`.

    Parameters
    ----------
    layout: Layout
        The layout
    axs: dict[str, Axes]
        The `matplotlib` axes objects

        The key for every `matplotlib.Axes` should match a corresponding
        `pr.Axes` in the layout.

    Returns
    -------
    Layout
        A new layout sharing `layout`'s primitive, constraint and primitive-key
        trees, with an updated parameter tree

        The constraint counter is not carried over; the new layout gets a
        fresh one.
    """
    # Each axis-thickness constraint type and the `matplotlib.Axes` attribute
    # holding the matching axis
    axis_attrs = (
        (cr.XAxisThickness, "xaxis"),
        (cr.YAxisThickness, "yaxis"),
    )

    constraintkey_to_param = {}
    for key, constraint in iter_flat('', layout.root_constraint):
        # `key[1:]` removes the initial "/" from the key
        key = key[1:]
        # The empty key is the root constraint, which is just a container
        if key == "":
            continue

        for constraint_type, axis_attr in axis_attrs:
            if isinstance(constraint, constraint_type):
                axis_key, = layout.root_prim_keys[key].value
                # The first component of the axis primitive key is used as the
                # `matplotlib` axes key
                axes_key = axis_key.split("/", 1)[0]
                constraintkey_to_param[key] = (getattr(axs[axes_key], axis_attr),)

    new_root_param = update_root_param(
        layout.root_constraint,
        layout.root_param,
        constraintkey_to_param
    )

    return Layout(
        layout.root_prim,
        layout.root_constraint,
        layout.root_prim_keys,
        new_root_param
    )
269 |
def update_root_param(
    root_constraint: cr.ConstraintNode,
    root_param: cr.ParamsNode,
    constraintkey_to_param: dict[str, cr.Params]
) -> cr.ParamsNode:
    """
    Return an updated copy of a root constraint parameters node

    Parameters
    ----------
    root_constraint: cr.ConstraintNode
        The root constraint
    root_param: cr.ParamsNode
        The corresponding root parameters node
    constraintkey_to_param: dict[str, cr.Params]
        A mapping of constraint keys to replacement constraint parameters

        Each constraint key should indicate a node in `root_constraint` and
        `root_param`.

    Returns
    -------
    cr.ParamsNode
        The node from `root_param.copy()` with each listed constraint's
        parameters replaced by `constraint.root_params(param)`
    """
    updated_root_param = root_param.copy()
    for constraint_key in constraintkey_to_param:
        replacement = constraintkey_to_param[constraint_key]
        target_constraint = root_constraint[constraint_key]
        updated_root_param[constraint_key] = target_constraint.root_params(replacement)
    return updated_root_param
296 |
--------------------------------------------------------------------------------
/src/mpllayout/matplotlibutils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for creating `matplotlib` elements from geometric primitives
3 | """
4 |
5 | from typing import Optional
6 | from numpy.typing import NDArray
7 | import warnings
8 |
9 | import numpy as np
10 |
11 | from matplotlib import pyplot as plt
12 | from matplotlib.figure import Figure
13 | from matplotlib.axes import Axes
14 |
15 | from . import primitives as pr
16 | from . import constraints as cr
17 |
# TODO: (not critical) Should special primitive classes indicate `matplotlib` figures and axes?
# I'm not certain that would be beneficial here.
# If so, this should be done for both `subplots` and `update_subplots`.
def subplots(
    root_prim: pr.Primitive,
    fig_key: str = "Figure",
    axs_keys: Optional[list[str]] = None,
) -> tuple[Figure, dict[str, Axes]]:
    """
    Create matplotlib `Figure` and `Axes` objects from geometric primitives

    The `Figure` and `Axes` objects are extracted based on labels in the primitive tree
    and have sizes and positions from their corresponding primitives.

    Parameters
    ----------
    root_prim: pr.Primitive
        The root primitive
    fig_key: str
        The quadrilateral key corresponding to the figure

        The key is "Figure" by default.
    axs_keys: Optional[list[str]]
        Axes keys

        If supplied, only these axes keys will be used to generate `Axes` instances.
        Otherwise, every child key of `root_prim` containing "Axes" is used.

    Returns
    -------
    fig: Figure
        The matplotlib `Figure`
    axs: dict[str, Axes]
        The matplotlib `Axes` instances
    """
    # The initial figure size is a placeholder; `update_subplots` resizes it
    fig = plt.figure(figsize=(1, 1))

    # Without explicit keys, any child key containing "Axes" denotes an axes
    if axs_keys is None:
        axs_keys = [child_key for child_key in root_prim.keys() if "Axes" in child_key]

    # Every axes starts on the full figure; `update_subplots` positions it
    key_to_ax = {}
    for ax_key in axs_keys:
        key_to_ax[ax_key] = fig.add_axes((0, 0, 1, 1))

    fig, key_to_ax = update_subplots(root_prim, fig_key, fig, key_to_ax)
    return fig, key_to_ax
67 |
68 |
def update_subplots(
    root_prim: pr.Primitive, fig_key: str, fig: Figure, axs: dict[str, Axes],
) -> tuple[Figure, dict[str, Axes]]:
    """
    Update matplotlib `Figure` and `Axes` object positions from primitives

    The `Figure` and `Axes` objects are extracted based on labels in the primitive tree
    and have sizes and positions updated from their corresponding primitives.

    Parameters
    ----------
    root_prim: pr.Primitive
        The root primitive
    fig_key: str
        The quadrilateral key in `root_prim` corresponding to the figure
    fig: Figure
        The `Figure` to update
    axs: dict[str, Axes]
        The `Axes` objects to update

    Returns
    -------
    fig: Figure
        The updated matplotlib `Figure`
    axs: dict[str, Axes]
        The updated matplotlib `Axes` instances
    """
    # Set Figure size; all positions below are measured relative to the
    # figure's bottom-left corner (`fig_origin`)
    fig_quad = root_prim[fig_key]
    fig_origin = fig_quad['Line0/Point0'].value
    fig_size = np.array(width_and_height_from_quad(fig_quad))
    fig.set_size_inches(fig_size)

    # Set Axes properties/position
    for key, ax in axs.items():
        # Set Axes dimensions
        frame_quad = root_prim[f"{key}/Frame"]
        ax.set_position(rect_from_box(frame_quad, fig_origin, fig_size))

        # Set x/y axis properties
        for axis_prefix, axis in zip(("X", "Y"), (ax.xaxis, ax.yaxis)):
            # Set the axis label position
            axis_label = f"{axis_prefix}AxisLabel"
            if axis_label in root_prim[key]:
                axis_label_point: pr.Point = root_prim[f"{key}/{axis_label}"]
                # Convert to figure-fraction coordinates the same way
                # `rect_from_box` does for the axes frame (shift then scale)
                label_coords = (axis_label_point.value - fig_origin) / fig_size
                axis.set_label_coords(*label_coords, transform=fig.transFigure)

            # Set the axis tick position
            axis_bbox = f"{axis_prefix}Axis"
            if axis_bbox in root_prim[key]:
                axis_quad = root_prim[f"{key}/{axis_bbox}"]
                axis_tick_position = find_axis_position(frame_quad, axis_quad)
                axis.set_ticks_position(axis_tick_position)

    return fig, axs
130 |
def find_axis_position(axes_frame: pr.Quadrilateral, axis: pr.Quadrilateral) -> str:
    """
    Return which side of an axes frame an axis sits on

    Each side of the frame is paired with the facing side of the axis
    quadrilateral, and the side whose `CoincidentLines` residual (with
    `reverse=True`) has the smallest norm is chosen. A warning is issued when
    even the best residual isn't close to zero.

    Parameters
    ----------
    axes_frame: pr.Quadrilateral
        The axes frame
    axis: pr.Quadrilateral
        The axes axis

        This can be any x or y axis

    Returns
    -------
    position: str
        One of ('bottom', 'top', 'left', 'right') indicating the axis position
    """
    coincident_line = cr.CoincidentLines()
    params = {"reverse": True}

    # Frame side and the opposite side of the axis quadrilateral, per position
    line_pairs = {
        "bottom": (axes_frame["Line0"], axis["Line2"]),
        "top": (axes_frame["Line2"], axis["Line0"]),
        "left": (axes_frame["Line3"], axis["Line1"]),
        "right": (axes_frame["Line1"], axis["Line3"]),
    }
    positions = tuple(line_pairs)
    residual_norms = np.array([
        np.linalg.norm(coincident_line(lines, params))
        for lines in line_pairs.values()
    ])

    if not np.isclose(residual_norms.min(), 0):
        warnings.warn("The axis isn't closely aligned with any of the axes sides")
    return positions[int(np.argmin(residual_norms))]
165 |
166 |
def width_and_height_from_quad(quad: pr.Quadrilateral) -> tuple[float, float]:
    """
    Return the width and height of a quadrilateral

    The size is measured from the bottom-left corner (`Line0/Point0`) to the
    top-right corner (`Line1/Point1`).

    Parameters
    ----------
    quad: pr.Quadrilateral
        The quadrilateral

    Returns
    -------
    tuple[float, float]
        The width and height
    """
    (x_left, y_bottom) = quad["Line0/Point0"].value
    (x_right, y_top) = quad["Line1/Point1"].value

    width = x_right - x_left
    height = y_top - y_bottom
    return width, height
188 |
189 |
def rect_from_box(
    quad: pr.Quadrilateral,
    fig_origin: NDArray,
    fig_size: Optional[NDArray] = None
) -> tuple[float, float, float, float]:
    """
    Return a `rect` tuple, `(left, bottom, width, height)`, from a quadrilateral

    This tuple of quadrilateral information can be used to create a `Bbox` or `Axes`
    object in `matplotlib`.

    Parameters
    ----------
    quad: pr.Quadrilateral
        The quadrilateral
    fig_origin: NDArray
        Coordinates for the figure bottom left corner
    fig_size: Optional[NDArray]
        The width and height of the figure

        This should be supplied so that the rect tuple has units relative to the figure.
        Some matplotlib `Axes` constructors accept the rect tuple in figure units by default.
        If not supplied, a size of `(1, 1)` is used (no scaling).

    Returns
    -------
    xmin, ymin, width, height: tuple[float, float, float, float]
    """
    # Build the default per call rather than sharing one array default
    # across calls
    if fig_size is None:
        fig_size = np.array((1, 1))

    # Shift corners to the figure origin, then scale to figure units
    xmin, ymin = (quad["Line0/Point0"].value - fig_origin) / fig_size
    xmax, ymax = (quad["Line1/Point1"].value - fig_origin) / fig_size

    return (xmin, ymin, xmax - xmin, ymax - ymin)
227 |
--------------------------------------------------------------------------------
/src/mpllayout/primitives.py:
--------------------------------------------------------------------------------
1 | """
2 | Geometric primitive definitions
3 | """
4 |
5 | from typing import Optional, TypeVar, Any
6 | from numpy.typing import NDArray
7 |
8 | import numpy as np
9 | import jax
10 |
11 | # from .containers import Node, _make_flatten_unflatten, iter_flat, unflatten, FlatNodeStructure
12 | import mpllayout.containers as cn
13 |
14 |
15 | ## Generic primitive class/interface
16 | # You can create specific primitive definitions by inheriting from these and
17 | # defining appropriate class attributes
18 |
19 | TPrim = TypeVar("TPrim", bound="PrimitiveNode")
20 | PrimValue = NDArray[np.float64]
21 | PrimTypes = tuple[type[TPrim] | None, ...]
22 | PrimTypesSignature = PrimTypes | set[PrimTypes]
23 | PrimNodeSignature = tuple[int, PrimTypesSignature]
24 |
25 | def validate_prims(
26 | prims: list[TPrim], prim_types: PrimTypes
27 | ) -> tuple[bool, str]:
28 | """
29 | Return whether input primitives are valid
30 |
31 | Parameters
32 | ----------
33 | prims: list[TPrim]
34 | Primitives to validate
35 | prim_types: PrimTypes
36 | A tuple of primitive types that `prims` must match
37 |
38 | Raises
39 | -------
40 | ValueError
41 | """
42 | ## Pre-process `prim_types` into a simple tuple of primitive types
43 | # Expand any `ellipsis` in `prim_types` by treating types before the ...
44 | # as repeating units
45 | if len(prim_types) > 0:
46 | if prim_types[-1] == Ellipsis:
47 | repeat_prim_types = prim_types[:-1]
48 | n_repeat = len(prims) // len(repeat_prim_types)
49 | prim_types = n_repeat * repeat_prim_types
50 |
51 | # Replace any `None` primitive types with `PrimitiveNode`
52 | prim_types = [
53 | PrimitiveNode if prim_type is None else prim_type
54 | for prim_type in prim_types
55 | ]
56 |
57 | ## Check `prims` has the right length
58 | if len(prims) != len(prim_types):
59 | raise ValueError(
60 | f"Invalid `prims` must be {len(prim_types)} not {len(prims)}"
61 | )
62 |
63 | ## Check `prims` are the right types
64 | _prim_types = [type(prim) for prim in prims]
65 | if not all(
66 | issubclass(_prim_type, prim_type)
67 | for _prim_type, prim_type in zip(_prim_types, prim_types)
68 | ):
69 | raise ValueError(
70 | f"Invalid `prims` types must be {prim_types} not {_prim_types}"
71 | )
72 |
73 |
# TODO: Improve primitive array data structure?
# Each primitive node stores its own parameter vector but this doesn't
# handle connectivity information. For example, the same point coordinate is
# owned by two lines in a polygon

class PrimitiveNode(cn.Node[NDArray[np.float64]]):
    """
    Node representation of a geometric primitive

    A geometric primitive (prim for short) is represented by a parameter vector
    and child primitives. For example, in 2D, a point has a 2 element parameter
    vector representing (x, y) coordinates and no child prims. A straight
    line segment has an empty parameter vector with two point child prims
    representing the start and end points.

    Parameters
    ----------
    value: PrimValue
        Parameter vector for the prim
    children: dict[str, TPrim]
        Child prims

    Attributes
    ----------
    signature: PrimNodeSignature
        A tuple specifying valid parameter vector and child types

        A signature has two components
        ``(param_size, prim_types) = signature``,
        where `param_size` is the size of the parameter vector and `prim_types`
        indicates valid child primitives (a single tuple of types, or a set of
        alternative tuples of which at least one must match).
    """

    # NOTE: I used `None` to indicate `PrimitiveNode` because the name isn't
    # available within the class itself
    signature: PrimNodeSignature = (0, (None, ...))

    def __init__(self, value: PrimValue, children: dict[str, TPrim]):

        # Type checks that value is an array and children are prims.
        # `jax` arrays are accepted as well since `Primitive.__init__`
        # explicitly passes them through to this constructor.
        assert isinstance(value, (np.ndarray, jax.numpy.ndarray)), (
            f"`value` must be an array, not {type(value)}"
        )
        assert all(
            isinstance(cprim, PrimitiveNode) for cprim in children.values()
        ), "All `children` must be `PrimitiveNode` instances"

        param_size, prim_type_sig = self.signature
        assert len(value) == param_size, (
            f"`value` must have length {param_size}, not {len(value)}"
        )

        prims = list(children.values())
        if isinstance(prim_type_sig, tuple):
            # A single signature; `validate_prims` raises `ValueError` on mismatch
            validate_prims(prims, prim_type_sig)
        elif isinstance(prim_type_sig, set):
            # Alternative signatures; at least one has to match
            def match(prims, prim_types) -> bool:
                try:
                    validate_prims(prims, prim_types)
                except ValueError:
                    return False
                return True

            if not any(
                match(prims, prim_types) for prim_types in prim_type_sig
            ):
                raise ValueError(
                    f"No matching signatures for `prims` in {prim_type_sig}"
                )

        super().__init__(value, children)
146 |
147 |
class Primitive(PrimitiveNode):
    """
    Primitive parameterized by a value, primitives and keyword arguments

    To define a primitive, subclass `Primitive` and define the class methods
    `init_children`, `default_value` and `default_prims`.

    Parameters
    ----------
    value: Optional[PrimValue]
        Parameter vector for the primitive

        Lists and tuples are converted to arrays; `None` uses `default_value`.
    prims: Optional[list[TPrim]]
        Parameterizing primitives

        Note that these are often the same as child primitives but this isn't
        always the case. `None` uses `default_prims`.
    **kwargs: Any
        Additional keyword arguments

    Attributes
    ----------
    see `PrimitiveNode`

    Raises
    ------
    TypeError
        If `value` is not `None`, a list, a tuple or an array
    """

    def __init__(
        self,
        value: Optional[NDArray] = None,
        prims: Optional[list[TPrim]] = None,
        **kwargs: Any
    ):
        # Normalize `value` to an array (numpy and jax arrays pass through)
        if value is None:
            value = self.default_value(**kwargs)
        elif isinstance(value, (list, tuple)):
            value = np.array(value)
        elif not isinstance(value, (np.ndarray, jax.numpy.ndarray)):
            raise TypeError(
                f"`value` must be `None`, a list, tuple or array, not {type(value)}"
            )

        if prims is None:
            prims = self.default_prims(**kwargs)

        child_keys, child_prims = self.init_children(prims, **kwargs)

        super().__init__(value, dict(zip(child_keys, child_prims)))

    @classmethod
    def init_children(
        cls, prims: list[TPrim], **kwargs: Any
    ) -> tuple[list[str], list[TPrim]]:
        """
        Return child primitives from parameterizing primitives

        Parameters
        ----------
        prims: list[TPrim]
            Parameterizing primitives

        Returns
        -------
        List[str]
            Child primitive keys
        List[TPrim]
            Child primitives
        """
        raise NotImplementedError()

    @classmethod
    def default_value(cls, **kwargs: Any) -> NDArray:
        """
        Return a default parameter vector
        """
        raise NotImplementedError()

    @classmethod
    def default_prims(cls, **kwargs: Any) -> list[TPrim]:
        """
        Return default parameterizing primitives
        """
        raise NotImplementedError()
230 |
231 |
class StaticPrimitive(Primitive):
    """
    A `Primitive` that takes no keyword arguments

    Because nothing beyond `value` and `prims` parameterizes a static primitive,
    its parameter vector shape and child primitive types are fixed, and
    subclasses implement the `Primitive` class methods without `**kwargs`.

    Parameters
    ----------
    See `Primitive`
    """

    def __init__(
        self,
        value: Optional[NDArray] = None,
        prims: Optional[list[TPrim]] = None,
    ):
        # No keyword arguments are forwarded to `Primitive`
        super().__init__(value, prims)

    @classmethod
    def init_children(cls, prims: list[TPrim]) -> tuple[list[str], list[TPrim]]:
        """See `Primitive.init_children`"""
        raise NotImplementedError

    @classmethod
    def default_value(cls) -> NDArray:
        """See `Primitive.default_value`"""
        raise NotImplementedError

    @classmethod
    def default_prims(cls) -> list[TPrim]:
        """See `Primitive.default_prims`"""
        raise NotImplementedError
273 |
274 | ## Primitive definitions
275 |
class Point(StaticPrimitive):
    """
    A point

    Parameters
    ----------
    value: Optional[NDArray] with shape (2,)
        The point coordinate
    prims: Optional[tuple[]]
        An empty tuple
    """

    signature = (2, ())

    @classmethod
    def init_children(cls, prims):
        # A point is a leaf primitive: it has no child primitives
        return (), ()

    @classmethod
    def default_value(cls):
        # Use floating point coordinates; an integer array would give the point
        # an integer dtype, and `jax.jacfwd` (used by the solver) only accepts
        # floating point inputs
        return np.array([0.0, 0.0])

    @classmethod
    def default_prims(cls):
        return ()
301 |
302 |
class Line(StaticPrimitive):
    """
    A straight line segment between two points

    Parameters
    ----------
    value: Optional[NDArray] with shape (0,)
        An empty vector
    prims: Optional[tuple[Point, Point]]
        A tuple containing the line start and end point
    """

    signature = (0, (Point, Point))

    @classmethod
    def init_children(cls, prims: tuple[Point, Point]):
        # The start and end points are keyed 'Point0' and 'Point1'
        return ("Point0", "Point1"), prims

    @classmethod
    def default_value(cls):
        return np.array(())

    @classmethod
    def default_prims(cls):
        # A unit-length vertical segment starting at the origin. Floating point
        # coordinates keep point values consistent with the solver's float
        # parameter vector.
        return (Point([0.0, 0.0]), Point([0.0, 1.0]))
328 |
329 |
class Polygon(Primitive):
    """
    A polygon with straight-line edges through a set of points

    The polygon tree structure contains a sequence of lines where the end point
    of each line joins the start point of the next line forming a closed loop.

    Parameters
    ----------
    value: Optional[NDArray] with shape (0,)
        An empty vector
    prims: Optional[tuple[Point, ...]]
        A sequence of points representing polygon vertices

        The number of resulting polygon edges (`Line` instances) is the same as
        the number of points.

        Child primitives are lines joining the given vertices, keyed 'Line0',
        'Line1', ... in vertex order.
    size: int
        The number of polygon vertices

        This is only used to generate default vertices when `prims` is not
        given.
    """

    signature = (0, (Line, ...))

    def __init__(
        self,
        value: Optional[NDArray] = None,
        prims: Optional[list[TPrim]] = None,
        size: int = 3
    ):
        super().__init__(value=value, prims=prims, size=size)

    @classmethod
    def init_children(
        cls, prims: list[Point], size: int=3
    ):
        # Accept any sequence of points (list, tuple, ...)
        points = list(prims)
        # Pair each vertex with the next one, wrapping the last vertex back to
        # the first so the lines form a closed loop
        next_points = points[1:] + points[:1]
        child_prims = [
            Line(np.array([]), (start, end))
            for start, end in zip(points, next_points)
        ]
        child_keys = [f"Line{n}" for n in range(len(child_prims))]
        return child_keys, child_prims

    @classmethod
    def default_value(cls, size: int=3):
        return np.array(())

    @classmethod
    def default_prims(cls, size: int=3):
        # Generate `size` points evenly spaced around the unit circle
        ii = np.arange(size)
        xs = np.cos(2*np.pi/size * ii)
        ys = np.sin(2*np.pi/size * ii)
        return [Point((x, y)) for x, y in zip(xs, ys)]
385 |
386 |
class Quadrilateral(Polygon):
    """
    A quadrilateral (4 sided polygon)

    For modelling rectangles in matplotlib (`axes`, `bbox`, etc.) the lines are
    treated as the bottom, right, top, and left of a box. This corresponds to
    vertices ordered counter-clockwise starting from the bottom-left, as in the
    default unit square. Specifically, the lines correspond to:
    - 'Line0' : bottom
    - 'Line1' : right
    - 'Line2' : top
    - 'Line3' : left

    Parameters
    ----------
    value: Optional[NDArray] with shape (0,)
        An empty vector
    prims: Optional[tuple[Point, Point, Point, Point]]
        A tuple of 4 vertices for the quadrilateral
    children: Optional[tuple[Point, Point, Point, Point]]
        An alias of `prims`, kept for backward compatibility

    Raises
    ------
    TypeError
        If both `prims` and `children` are given
    """

    signature = (0, (Line, Line, Line, Line))

    def __init__(
        self,
        value: Optional[NDArray] = None,
        prims: Optional[list[Point]] = None,
        children: Optional[list[Point]] = None,
    ):
        # `prims` matches the keyword used by every other primitive; `children`
        # is still accepted so existing callers keep working
        if children is not None:
            if prims is not None:
                raise TypeError("Pass only one of `prims` or `children`")
            prims = children
        super().__init__(value, prims, size=4)

    @classmethod
    def default_value(cls, size: int=4):
        return np.array(())

    @classmethod
    def default_prims(cls, size: int=4):
        # Generate a unit square with floating point vertex coordinates
        xs = [0.0, 1.0, 1.0, 0.0]
        ys = [0.0, 0.0, 1.0, 1.0]
        return [Point((x, y)) for x, y in zip(xs, ys)]
426 |
# Child primitives for a single axis: the axis quadrilateral and its label point
AxisPrims = tuple[Quadrilateral, Point]
# Allowed child primitive tuples for `Axes`: a frame quadrilateral followed by
# 0 to 4 `AxisPrims` pairs (one pair per included axis)
# NOTE: unpacking (`*AxisPrims`) inside a subscript requires Python >= 3.11
AxesChildPrims = (
    tuple[Quadrilateral]
    | tuple[Quadrilateral, *AxisPrims]
    | tuple[Quadrilateral, *AxisPrims, *AxisPrims]
    | tuple[Quadrilateral, *AxisPrims, *AxisPrims, *AxisPrims]
    | tuple[Quadrilateral, *AxisPrims, *AxisPrims, *AxisPrims, *AxisPrims]
)
435 |
class Axes(Primitive):
    """
    A collection of quadrilaterals and points representing an axes

    Child primitives are:
    - `axes['Frame']` : A `Quadrilateral` representing the plotting area
    - `axes['XAxis']` : A `Quadrilateral` representing the x-axis
    - `axes['XAxisLabel']` : A `Point` representing the x-axis label anchor
    - `axes['YAxis']` : A `Quadrilateral` representing the y-axis
    - `axes['YAxisLabel']` : A `Point` representing the y-axis label anchor
    - `axes['TwinXAxis']` : A `Quadrilateral` representing the twin x-axis
    - `axes['TwinXAxisLabel']` : A `Point` representing the twin x-axis label anchor
    - `axes['TwinYAxis']` : A `Quadrilateral` representing the twin y-axis
    - `axes['TwinYAxisLabel']` : A `Point` representing the twin y-axis label anchor

    'Frame' is always present; each axis and its label are only present when
    the corresponding flag is set.

    Parameters
    ----------
    value: Optional[NDArray] with shape (0,)
        An empty vector
    prims: Optional[AxesChildPrims]
        A tuple of quadrilaterals and points

        The tuple holds the frame quadrilateral, followed by a
        (quadrilateral, label point) pair for each included axis. The pairs
        are in the order x-axis, y-axis, twin x-axis, twin y-axis.
    xaxis, yaxis: bool
        Whether to include an x/y axis and the corresponding label

        If false for a given axis, the corresponding child primitives will not
        be present.
    twinx, twiny: bool
        Whether to include a twin x/y axis and the corresponding label
    """

    _AxisPrimClasses = (Quadrilateral, Point)
    signature = (
        0,
        {
            (Quadrilateral,),
            (Quadrilateral, *(1 * _AxisPrimClasses)),
            (Quadrilateral, *(2 * _AxisPrimClasses)),
            (Quadrilateral, *(3 * _AxisPrimClasses)),
            (Quadrilateral, *(4 * _AxisPrimClasses)),
        }
    )

    def __init__(
        self,
        value: Optional[NDArray]=None,
        prims: Optional[AxesChildPrims]=None,
        xaxis: bool=False,
        yaxis: bool=False,
        twinx: bool=False,
        twiny: bool=False
    ):
        super().__init__(value, prims, xaxis=xaxis, yaxis=yaxis, twinx=twinx, twiny=twiny)

    @classmethod
    def _axis_names(
        cls, xaxis: bool, yaxis: bool, twinx: bool, twiny: bool
    ) -> tuple[str, ...]:
        """
        Return the names of the included axes in child primitive order
        """
        candidates = (
            ("XAxis", xaxis),
            ("YAxis", yaxis),
            ("TwinXAxis", twinx),
            ("TwinYAxis", twiny),
        )
        return tuple(name for name, include in candidates if include)

    @classmethod
    def init_children(
        cls,
        prims: AxesChildPrims,
        xaxis: bool=False,
        yaxis: bool=False,
        twinx: bool=False,
        twiny: bool=False
    ) -> tuple[list[str], AxesChildPrims]:
        # Each included axis contributes its quadrilateral key and label key
        axis_keys = tuple(
            key
            for name in cls._axis_names(xaxis, yaxis, twinx, twiny)
            for key in (name, f"{name}Label")
        )
        keys = ("Frame",) + axis_keys
        return (keys, prims)

    @classmethod
    def default_value(
        cls,
        xaxis: bool=False,
        yaxis: bool=False,
        twinx: bool=False,
        twiny: bool=False
    ):
        return np.array(())

    @classmethod
    def default_prims(
        cls,
        xaxis: bool=False,
        yaxis: bool=False,
        twinx: bool=False,
        twiny: bool=False
    ):
        # A fresh (quadrilateral, label point) pair for each included axis, in
        # the same order as the keys from `init_children`
        axis_prims = tuple(
            prim
            for _ in cls._axis_names(xaxis, yaxis, twinx, twiny)
            for prim in (Quadrilateral(), Point())
        )
        return (Quadrilateral(),) + axis_prims
560 |
561 |
## Register `Primitive` classes as `jax.pytree`
# `jax.tree_util.register_pytree_node` registers exact types; subclasses are not
# registered automatically, so every primitive class is listed explicitly.
# NOTE(review): `StaticPrimitive` is not listed; presumably it is never
# instantiated directly (only its subclasses are) — confirm.
_PrimitiveClasses = [
    Primitive,
    PrimitiveNode,
    Point,
    Line,
    Polygon,
    Quadrilateral,
    Axes,
]
for _PrimitiveClass in _PrimitiveClasses:
    # Build a flatten/unflatten pair specific to this class
    _flatten_primitive, _unflatten_primitive = cn._make_flatten_unflatten(_PrimitiveClass)
    jax.tree_util.register_pytree_node(
        _PrimitiveClass, _flatten_primitive, _unflatten_primitive
    )
577 |
578 |
579 | ## Primitive value vector methods
580 | # These are used to get the primitive parameter vector from a primitive
581 | # and to update primitives with new parameters
582 |
def filter_unique_values_from_prim(
    root_prim: Primitive,
) -> tuple[dict[str, int], list[NDArray]]:
    """
    Return the unique primitive values in a root primitive and index them

    Note that primitives in a primitive node are not necessarily unique; for
    example, `Point`s are shared between lines in a polygon. Primitives are
    treated as the same when their `value` arrays are the same object
    (compared by `id`).

    When solving a set of geometric constraints, the geometric constraint
    residual should be linked to a function of unique primitives only.

    Parameters
    ----------
    root_prim: Primitive
        The primitive tree to search

    Returns
    -------
    prim_to_idx: dict[str, int]
        A mapping from each primitive key to the index of its value in `values`
    values: list[NDArray]
        The unique primitive values, in order of first appearance
    """
    value_id_to_idx: dict[int, int] = {}
    values: list[NDArray] = []
    prim_to_idx: dict[str, int] = {}

    for key, prim in cn.iter_flat("", root_prim):
        value_id = id(prim.value)
        # The first primitive seen with a given value object owns a new index
        if value_id not in value_id_to_idx:
            value_id_to_idx[value_id] = len(values)
            values.append(prim.value)
        prim_to_idx[key] = value_id_to_idx[value_id]

    return prim_to_idx, values
619 |
def build_prim_from_unique_values(
    flat_prim: list[cn.FlatNodeStructure], prim_to_idx: dict[str, int], values: list[NDArray]
) -> Primitive:
    """
    Return a new primitive with values updated from unique values

    Parameters
    ----------
    flat_prim: list[FlatNodeStructure]
        The flat primitive tree (see `flatten`)

        Each entry is a `(prim_key, PrimType, value, child_keys)` tuple; the
        old `value` is replaced.
    prim_to_idx: dict[str, int]
        A mapping from each primitive key to a unique primitive value in `values`
    values: list[NDArray]
        A list of unique primitive values (see `filter_unique_values_from_prim`)

    Returns
    -------
    Primitive
        The new primitive with updated values
    """
    # A single pass over `flat_prim` so that one-shot iterables also work
    new_prim_structs = [
        (prim_key, PrimType, values[prim_to_idx[prim_key]], child_keys)
        for prim_key, PrimType, _old_value, child_keys in flat_prim
    ]
    return cn.unflatten(new_prim_structs)[0]
649 |
--------------------------------------------------------------------------------
/src/mpllayout/solver.py:
--------------------------------------------------------------------------------
1 | """
2 | Solvers for constrained geometric primitives
3 | """
4 |
5 | from typing import Any
6 | from numpy.typing import NDArray
7 |
8 | import warnings
9 |
10 | import jax
11 | from jax import numpy as jnp
12 | import numpy as np
13 | from scipy.optimize import minimize, OptimizeResult
14 |
15 | from . import primitives as pr
16 | from . import constraints as cr
17 | from . import containers as cn
18 | from . import layout as lay
19 |
# Graphs given as lists of tuples of integer indices / string keys
# NOTE(review): `IntGraph` and `StrGraph` are not used in this module — confirm
# whether they are still needed elsewhere
IntGraph = list[tuple[int, ...]]
StrGraph = list[tuple[str, ...]]

# Solver diagnostics returned alongside the solved primitive tree
# (see `solve` for the keys)
SolverInfo = dict[str, Any]
24 |
def solve(
    layout: lay.Layout,
    abs_tol: float = 1e-10,
    rel_tol: float = 1e-7,
    max_iter: int = 10,
    method: str='newton'
) -> tuple[pr.PrimitiveNode, SolverInfo]:
    """
    Return geometric primitives that satisfy constraints

    Solves a (possibly non-linear) set of geometric constraints iteratively,
    dispatching to the solver selected by `method`.

    Parameters
    ----------
    layout: lay.Layout
        The layout of geometric primitives and constraints to solve
    abs_tol, rel_tol: float
        The absolute and relative tolerance for the iterative solution
    max_iter: int
        The maximum number of iterations for the iterative solution
    method: str
        A solver method; 'newton' uses `solve_newton` and 'minimize' uses
        `solve_minimize`

    Returns
    -------
    pr.PrimitiveNode
        A primitive tree satisfying the constraints
    SolverInfo
        Information about the iterative solution

        Keys are:
        'abs_errs':
            A list of absolute errors for each solver iteration.
            This is the 2-norm of the constraint residual vector.
        'rel_errs':
            A list of relative errors for each solver iteration.
            This is the absolute error at each iteration, relative to the
            initial absolute error.

    Raises
    ------
    ValueError
        If `method` is not a known solver method
    """
    available_solvers = (
        ('newton', solve_newton),
        ('minimize', solve_minimize),
    )
    for solver_name, solver_fn in available_solvers:
        if method == solver_name:
            return solver_fn(layout, abs_tol, rel_tol, max_iter)
    raise ValueError(f"Invalid `method` {method}")
71 |
72 | # TODO: Add sparse jacobian assembly and solving via sparse least squares
73 |
def solve_newton(
    layout: lay.Layout,
    abs_tol: float = 1e-10,
    rel_tol: float = 1e-7,
    max_iter: int = 10,
) -> tuple[pr.PrimitiveNode, SolverInfo]:
    """
    Return geometric primitives that satisfy constraints using a newton method

    Each iteration solves the linearized constraint equations in a
    least-squares sense (`np.linalg.lstsq`) and applies the resulting update.

    Parameters
    ----------
    Parameters match those for `solve` except for `method`

    See `solve` for more details.

    Returns
    -------
    Returns match those for `solve`

    Each entry of 'abs_errs' is the residual norm at the start of an
    iteration, i.e. before that iteration's update is applied.
    """

    ## Set-up assembly function for the global residual as a function of a global
    ## parameter list

    # `prim_idx_bounds` stores the right/left indices for each primitive's
    # parameter vector in the global parameter vector array
    # For primitive with index `n`, for example,
    # `prim_idx_bounds[n], prim_idx_bounds[n+1]` are the indices between which
    # the parameter vectors are stored.
    flat_prim = cn.flatten('', layout.root_prim)
    prim_graph, prim_values = pr.filter_unique_values_from_prim(layout.root_prim)
    prim_idx_bounds = np.cumsum([0] + [value.size for value in prim_values])
    # `jax.jacfwd` rejects integer-valued inputs, so make sure the global
    # parameter vector is floating point even if every primitive value is an
    # integer array
    global_param_n = np.concatenate(prim_values).astype(float)

    constraints, constraint_graph, constraint_params = layout.flat_constraints()

    def split_global_param(global_param):
        # Slice the global parameter vector into one vector per unique primitive
        return [
            global_param[idx_start:idx_end]
            for idx_start, idx_end in zip(prim_idx_bounds[:-1], prim_idx_bounds[1:])
        ]

    @jax.jit
    def assem_global_res(global_param):
        root_prim = pr.build_prim_from_unique_values(
            flat_prim, prim_graph, split_global_param(global_param)
        )
        residuals = assem_constraint_residual(
            root_prim, constraints, constraint_graph, constraint_params
        )
        return jnp.concatenate(residuals)

    assem_global_jac = jax.jacfwd(assem_global_res)

    ## Iteratively minimize the global residual as function of the global parameter vector
    abs_errs = []
    rel_errs = []

    n = 0
    abs_err = np.inf
    rel_err = np.inf
    while (abs_err > abs_tol) and (rel_err > rel_tol) and (n < max_iter):

        global_res = assem_global_res(global_param_n)
        global_jac = assem_global_jac(global_param_n)

        # Only the least-squares solution is needed; residuals, rank and
        # singular values are discarded
        dglobal_param, *_ = np.linalg.lstsq(global_jac, -global_res, rcond=None)
        global_param_n = global_param_n + dglobal_param

        n += 1
        abs_err = np.linalg.norm(global_res)
        abs_errs.append(abs_err)
        # A zero initial error gives 0/0; the resulting NaN (whose warning is
        # silenced) ends the loop since `nan > rel_tol` is False
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            rel_err = abs_errs[-1] / abs_errs[0]
        rel_errs.append(rel_err)

    nonlinear_solve_info = {"abs_errs": abs_errs, "rel_errs": rel_errs}

    ## Build a new primitive tree from the global parameter vector
    prim_params_n = [
        np.array(prim_param) for prim_param in split_global_param(global_param_n)
    ]
    root_prim_n = pr.build_prim_from_unique_values(flat_prim, prim_graph, prim_params_n)

    return root_prim_n, nonlinear_solve_info
158 |
159 |
def solve_minimize(
    layout: lay.Layout,
    abs_tol: float = 1e-10,
    rel_tol: float = 1e-7,
    max_iter: int = 10,
) -> tuple[pr.PrimitiveNode, SolverInfo]:
    """
    Return geometric primitives that satisfy constraints using minimization (L-BFGS-B)

    The minimization strategies are from `scipy`. The objective is the sum of
    squared constraint residuals.

    Parameters
    ----------
    Parameters match those for `solve` except for `method`

    See `solve` for more details.

    Returns
    -------
    Returns match those for `solve`

    As documented in `solve`, 'abs_errs' holds the 2-norm of the constraint
    residual vector (the square root of the objective). One entry is recorded
    per minimizer iteration, and 'rel_errs' is relative to the first entry.
    """

    ## Set-up assembly function for the global residual as a function of a global
    ## parameter list

    # `prim_idx_bounds` stores the right/left indices for each primitive's
    # parameter vector in the global parameter vector array
    # For primitive with index `n`, for example,
    # `prim_idx_bounds[n], prim_idx_bounds[n+1]` are the indices between which
    # the parameter vectors are stored.
    flat_prim = cn.flatten('', layout.root_prim)
    prim_graph, prim_values = pr.filter_unique_values_from_prim(layout.root_prim)
    prim_idx_bounds = np.cumsum([0] + [value.size for value in prim_values])
    global_param_n = np.concatenate(prim_values)

    constraints, constraint_graph, constraint_params = layout.flat_constraints()

    def split_global_param(global_param):
        # Slice the global parameter vector into one vector per unique primitive
        return [
            global_param[idx_start:idx_end]
            for idx_start, idx_end in zip(prim_idx_bounds[:-1], prim_idx_bounds[1:])
        ]

    @jax.jit
    def assem_objective(global_param):
        root_prim = pr.build_prim_from_unique_values(
            flat_prim, prim_graph, split_global_param(global_param)
        )
        residuals = assem_constraint_residual(
            root_prim, constraints, constraint_graph, constraint_params
        )
        return jnp.sum(jnp.concatenate(residuals)**2)

    class MinHistory:
        """
        Record the error history and stop the minimizer once tolerances are met
        """

        def __init__(self):
            self.abs_errs = []
            self.rel_errs = []

        def callback(self, intermediate_result: OptimizeResult):
            # The objective is the squared residual norm; take the square root
            # so `abs_errs`/`abs_tol` use the same 2-norm as `solve_newton`
            abs_err = np.sqrt(intermediate_result['fun'])
            self.abs_errs.append(abs_err)

            # `abs_err` is a numpy float, so a zero reference error gives
            # inf/nan (warning silenced) instead of raising ZeroDivisionError
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                rel_err = abs_err / self.abs_errs[0]
            self.rel_errs.append(rel_err)

            # Raising `StopIteration` in the callback terminates `minimize`
            if rel_err < rel_tol or abs_err < abs_tol:
                raise StopIteration()

    min_hist = MinHistory()

    ## Iteratively minimize the global residual as function of the global parameter vector

    # TODO: (not critical) Implement other optimization solvers besides 'L-BFGS-B'
    res = minimize(
        jax.value_and_grad(assem_objective),
        global_param_n,
        method='L-BFGS-B',
        jac=True,
        callback=min_hist.callback,
        options={'maxiter': max_iter}
    )
    global_param_n = res['x']

    prim_params_n = [
        np.array(prim_param) for prim_param in split_global_param(global_param_n)
    ]
    root_prim_n = pr.build_prim_from_unique_values(flat_prim, prim_graph, prim_params_n)

    nonlinear_solve_info = {
        "abs_errs": min_hist.abs_errs, "rel_errs": min_hist.rel_errs
    }

    return root_prim_n, nonlinear_solve_info
251 |
252 |
def assem_constraint_residual(
    root_prim: pr.Primitive,
    constraints: list[cr.Constraint],
    constraint_graph: list[cr.PrimKeys],
    constraint_params: list[cr.Params]
) -> list[NDArray]:
    """
    Evaluate every constraint and collect the residual vectors

    Parameters
    ----------
    root_prim: pr.Primitive
        The primitive tree the constraints act on
    constraints: list[cr.Constraint]
        The constraints to evaluate
    constraint_graph: list[cr.PrimKeys]
        For each constraint, the keys of the primitives in `root_prim` it acts on
    constraint_params: list[cr.Params]
        For each constraint, the extra positional parameters passed after the
        primitives

    Returns
    -------
    residuals: list[NDArray]
        One residual vector per constraint, in the order of `constraints`
    """
    residuals = []
    for constraint, prim_keys, param in zip(constraints, constraint_graph, constraint_params):
        # Look up the constrained primitives by key, then evaluate the constraint
        constrained_prims = tuple(root_prim[key] for key in prim_keys)
        residuals.append(constraint(constrained_prims, *param))
    return residuals
285 |
--------------------------------------------------------------------------------
/src/mpllayout/ui.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for visualizing primitives and constraints
3 | """
4 |
5 | from typing import Callable, Optional
6 | from matplotlib.figure import Figure
7 | from matplotlib.axes import Axes
8 |
9 | import matplotlib as mpl
10 | from matplotlib.patches import Polygon
11 | from matplotlib import pyplot as plt
12 | from matplotlib.colors import Colormap
13 | import numpy as np
14 |
15 | from . import primitives as pr
16 | from . import constraints as cr
17 | from . import constructions as cn
18 |
19 | ## Functions for plotting geometric primitives
20 |
21 |
def plot_point(ax: Axes, point: pr.Point, label: Optional[str]=None, **kwargs):
    """
    Plot a point

    Parameters
    ----------
    ax: Axes
        The axes to plot in
    point: pr.Point
        The point to plot
    label: Optional[str]
        A label, centred on the point; no label is drawn if `None`
    **kwargs
        Additional keyword arguments for plotting

        These are passed to both the marker (`ax.plot`) and the label
        (`ax.annotate`).
    """
    x, y = point.value
    ax.plot([x], [y], marker=".", **kwargs)
    # Skip the annotation entirely rather than adding an empty text artist
    if label is not None:
        ax.annotate(label, (x, y), ha='center', **kwargs)
40 |
def rotation_from_line(line: pr.Line) -> float:
    """
    Return the rotation of a line vector in degrees

    The angle is measured counter-clockwise from the positive x-axis to the
    line vector (start point to end point) and lies in [0, 360).
    """
    line_vec = cn.LineVector.assem((line,))

    # `arctan2` resolves the quadrant from both components. The previous
    # `arccos(x) + 180` for downward vectors gave the wrong slope, and `arccos`
    # can return NaN when rounding pushes a unit component above 1.
    theta = np.degrees(np.arctan2(line_vec[1], line_vec[0]))
    return float(theta % 360)
54 |
def plot_line(ax: Axes, line: pr.Line, label: Optional[str]=None, **kwargs):
    """
    Plot a line

    Zero length lines are not plotted.

    Parameters
    ----------
    ax: Axes
        The axes to plot in
    line: pr.Line
        The line to plot
    label: Optional[str]
        A label, placed at the line midpoint and rotated along the line; no
        label is drawn if `None`
    **kwargs
        Additional keyword arguments for plotting

        These are passed to both the line (`ax.plot`) and the label
        (`ax.annotate`).
    """
    # Don't plot zero length lines
    if cn.Length.assem((line,)) == 0:
        return

    xs = np.array([point.value[0] for point in line.values()])
    ys = np.array([point.value[1] for point in line.values()])
    ax.plot(xs, ys, **kwargs)

    # Skip the annotation entirely rather than adding an empty text artist
    if label is not None:
        xmid = 1/2*xs.sum()
        ymid = 1/2*ys.sum()
        theta = rotation_from_line(line)
        ax.annotate(label, (xmid, ymid), ha='center', va='baseline', rotation=theta, **kwargs)
80 |
def plot_polygon(ax: Axes, polygon: pr.Polygon, label: Optional[str]=None, **kwargs):
    """
    Plot a `Polygon` as a filled patch

    Parameters
    ----------
    ax: Axes
        The axes to plot in
    polygon: pr.Polygon
        The polygon to plot
    label: Optional[str]
        A label, placed just above and to the right of the first vertex; no
        label is drawn if `None`
    **kwargs
        Additional keyword arguments for plotting

        The patch uses 10% of the given 'alpha'. If 'alpha' is missing or
        `None`, an alpha of 1 is assumed. The label uses the keyword arguments
        unchanged.
    """
    # The polygon vertices are the start points of its edge lines
    origin = polygon["Line0"]["Point0"].value
    points = [
        polygon[f"Line{ii}"]["Point0"] for ii in range(len(polygon))
    ]
    verts = np.array([point.value for point in points])

    # 'alpha' is optional so the function can be called directly
    alpha = kwargs.get('alpha')
    patch_kwargs = {**kwargs, 'alpha': 0.1 * (1.0 if alpha is None else alpha)}
    poly_patch = Polygon(verts, closed=True, **patch_kwargs)
    ax.add_patch(poly_patch)

    if label is not None:
        # Place the label at the first point
        ax.annotate(
            label,
            origin,
            xycoords="data",
            xytext=(2.0, 2.0),
            textcoords="offset points",
            ha="left",
            va="bottom",
            **kwargs
        )
119 |
def plot_generic_prim(ax: Axes, prim: pr.Primitive, label: Optional[str]=None, **kwargs):
    """
    Draw nothing for a primitive that has no dedicated plotting function

    This is the fallback chosen by `make_plot`. It accepts the same arguments
    as the other `plot_...` functions so callers can treat all of them alike.
    """
    return None
122 |
def plot_prim(
    ax: Axes,
    prim: pr.Primitive,
    prim_key: str='',
    max_label_depth: int = 99,
    **kwargs
):
    """
    Recursively plot a primitive and all of its child primitives

    Parameters
    ----------
    ax: Axes
        The axes to plot in
    prim: pr.Primitive
        The primitive to plot
    prim_key: str
        The '/'-separated key of `prim`

        The number of '/' separators is the primitive depth, which controls
        the label and opacity. An empty key gives no label.
    max_label_depth: int
        Primitives deeper than this are not labelled
    **kwargs
        Additional keyword arguments for plotting

        These must not include `label` or `alpha`, which are set here.
    """
    split_key = prim_key.split("/")

    # The primitive depth is how far away from the root node the primitive is.
    depth = len(split_key) - 1

    # The primitive height is the maximum primitive depth for all child
    # primitives.
    prim_height = prim.node_height() + depth

    # Add labels for primitives before a maximum depth
    if depth > max_label_depth or prim_key == '':
        label = None
    else:
        label = f"{depth*'.'}/{split_key[-1]}"

    # Use the prim height to control opacity.
    # The root primitive has 100% opacity while deeper child primitives
    # have lower opacity.
    if prim_height == 0:
        s = 1
    else:
        s = (prim_height - depth)/prim_height
    alpha = 1*s + 0.2*(1-s)

    plot = make_plot(prim)
    plot(ax, prim, label=label, alpha=alpha, **kwargs)

    # NOTE: Line start and end points are not plotted because the labels get
    # crowded
    # TODO: Implement nice plots of points in lines
    #   Should try to avoid label overlap
    if not isinstance(prim, pr.Line):
        for child_key, child_prim in prim.items():
            plot_prim(
                ax,
                child_prim,
                prim_key=f'{prim_key}/{child_key}',
                max_label_depth=max_label_depth,
                **kwargs
            )
187 |
188 |
189 | ## Functions for plotting arbitrary geometric primitives
def make_plot(
    prim: pr.Primitive,
) -> Callable[..., None]:
    """
    Return a function that can plot a `pr.Primitive` object

    Parameters
    ----------
    prim: pr.Primitive
        The primitive to plot

    Returns
    -------
    Callable[..., None]
        A function called as `plot(ax, prim, label=None, **kwargs)`

        This function is one of the above `plot_...` functions
        (see `plot_point`, `plot_line`, etc.). Subclasses use their base
        class plotter; for example, a `pr.Quadrilateral` is plotted as a
        `pr.Polygon`. Other primitives get `plot_generic_prim`, which draws
        nothing.
    """
    plotters = (
        (pr.Point, plot_point),
        (pr.Line, plot_line),
        (pr.Polygon, plot_polygon),
    )
    for prim_type, plotter in plotters:
        if isinstance(prim, prim_type):
            return plotter
    return plot_generic_prim
217 |
218 |
def plot_prims(ax: Axes, root_prim: pr.Primitive, cmap: Colormap=mpl.colormaps['viridis']):
    """
    Plot every child primitive of a root primitive, recursively

    Each direct child of `root_prim` (together with its descendants) gets one
    colour, sampled at evenly spaced positions of `cmap`.

    Parameters
    ----------
    ax: Axes
        The axes to plot in
    root_prim: pr.Primitive
        The primitive whose children are plotted
    cmap: Colormap
        The colormap that supplies a colour for each direct child
    """
    child_count = len(root_prim)
    for position, (child_key, child_prim) in enumerate(root_prim.items()):
        plot_prim(ax, child_prim, prim_key=child_key, color=cmap(position / child_count))
234 |
235 |
def figure_prims(
    root_prim: pr.Primitive,
    fig_size: tuple[float, float] = (8, 8),
    major_tick_interval: float = 1.0,
    minor_tick_interval: float = 1/8
) -> tuple[Figure, Axes]:
    """
    Create a new figure showing a primitive on a gridded, equal-aspect axes

    Parameters
    ----------
    root_prim: pr.Primitive
        The primitive to plot
    fig_size: tuple[float, float]
        The figure size
    major_tick_interval, minor_tick_interval: float, float
        Spacing of the major and minor ticks on both axes

        The defaults of 1 and 1/8 suit dimensions measured in inches.

    Returns
    -------
    fig: Figure
        The new figure
    ax: Axes
        The axes containing the plot
    """
    figure, axes = plt.subplots(1, 1, figsize=fig_size)

    # Identical tick spacing on both coordinate axes
    for coord_axis in (axes.xaxis, axes.yaxis):
        coord_axis.set_minor_locator(mpl.ticker.MultipleLocator(minor_tick_interval))
        coord_axis.set_major_locator(mpl.ticker.MultipleLocator(major_tick_interval))

    # Equal aspect ratio so shapes are drawn without distortion
    axes.set_aspect(1)
    axes.grid()

    axes.set_xlabel("x [in]")
    axes.set_ylabel("y [in]")

    plot_prims(axes, root_prim)

    return (figure, axes)
279 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jon-deng/mpl-layout/6a10aeff3dfa6e989cc8d9beefa55b78527ebe6a/tests/__init__.py
--------------------------------------------------------------------------------
/tests/fixture_primitives.py:
--------------------------------------------------------------------------------
1 | """
2 | Fixtures to create primitives
3 | """
4 |
5 | from numpy.typing import NDArray
6 |
7 | import pytest
8 | import itertools
9 |
10 | import numpy as np
11 |
12 | from mpllayout import primitives as pr
13 |
14 | class GeometryFixtures:
15 | """
16 | Utilities to help create primitives
17 | """
18 |
19 | ## Coordinate axes
20 |
21 | @pytest.fixture(
22 | params=[
23 | ('x', np.array([1, 0])),
24 | ('y', np.array([0, 1]))
25 | ]
26 | )
27 | def axis_name_dir(self, request):
28 | return request.param
29 |
30 | @pytest.fixture()
31 | def axis_name(self, axis_name_dir):
32 | return axis_name_dir[0]
33 |
34 | @pytest.fixture()
35 | def axis_dir(self, axis_name_dir):
36 | return axis_name_dir[1]
37 |
38 | ## Point creation
39 | def make_point(self, coord):
40 | """
41 | Return a `pr.Point` at the given coordinates
42 | """
43 | return pr.Point(value=coord)
44 |
45 | def make_relative_point(self, point: pr.Point, displacement: NDArray):
46 | """
47 | Return a `pr.Point` displaced from a given point
48 | """
49 | return pr.Point(value=point.value + displacement)
50 |
51 | ## Line creation
52 | def make_rotation(self, theta: float):
53 | rot_mat = np.array(
54 | [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
55 | )
56 | return rot_mat
57 |
58 | def make_line(self, origin: NDArray, line_vec: NDArray):
59 | """
60 | Return a `pr.Line` with given origin and line vector
61 | """
62 | coords = (origin, origin + line_vec)
63 | return pr.Line(value=[], prims=tuple(pr.Point(x) for x in coords))
64 |
65 | def make_relative_line(
66 | self, line: pr.Line, translation: NDArray, deformation: NDArray
67 | ):
68 | """
69 | Return a `pr.Line` deformed about it's start point then translated
70 | """
71 | lineb_vec = line[1].value - line[0].value
72 | lineb_vec = deformation @ lineb_vec
73 |
74 | lineb_start = line[0].value + translation
75 | return self.make_line(lineb_start, lineb_vec)
76 |
77 | def make_relline_about_mid(
78 | self, line: pr.Line, translation: NDArray, deformation: NDArray
79 | ):
80 | """
81 | Return a `pr.Line` deformed about it's midpoint then translated
82 | """
83 | lineb_vec = line[1].value - line[0].value
84 | lineb_vec = deformation @ lineb_vec
85 |
86 | lineb_mid = 1/2*(line[0].value + line[1].value) + translation
87 | lineb_start = lineb_mid - lineb_vec/2
88 | return self.make_line(lineb_start, lineb_vec)
89 |
90 | ## Quadrilateral creation
def make_quad(self, displacement, deformation):
    """
    Return a `pr.Quadrilateral` made by deforming, then translating, a unit square

    Each unit-square vertex `v`, listed as (0, 0), (1, 0), (1, 1), (0, 1),
    is mapped to `deformation @ v + displacement`.
    """
    unit_square = np.array([[0, 0], [1, 0], [1, 1], [0, 1]])
    # Contracting the last axes of both arrays applies `deformation @ v` to
    # every vertex row at once
    mapped = np.tensordot(unit_square, deformation, axes=(-1, -1)) + displacement

    return pr.Quadrilateral(
        value=[], children=tuple(pr.Point(vertex) for vertex in mapped)
    )
103 |
def make_quad_grid(
    self,
    translation: NDArray,
    col_margins: NDArray,
    row_margins: NDArray,
    col_widths: NDArray,
    row_heights: NDArray,
):
    """
    Return the quadrilaterals of an (M, N) grid as a flat, row-major tuple

    Rows are numbered top to bottom and columns left to right. The tuple
    lists row 0 from left to right, then row 1, and so on.

    Parameters
    ----------
    translation: NDArray (2,)
        Position of the top-left quadrilateral (the displacement passed to
        `make_quad` for it)
    col_margins, row_margins: NDArray (N-1,), (M-1,)
        Gaps between adjacent columns and between adjacent rows
    col_widths, row_heights: NDArray (N,), (M,)
        Column widths and row heights

    Returns
    -------
    tuple
        The M*N quadrilaterals created by `self.make_quad`
    """
    x_proj = np.outer([1, 0], [1, 0])
    y_proj = np.outer([0, 1], [0, 1])
    # Each column scales x by its width; each row scales y by its height
    col_defs = col_widths[:, None, None] * x_proj + y_proj
    row_defs = x_proj + row_heights[:, None, None] * y_proj

    # Column offsets step right by (previous width + margin)
    col_xs = np.cumsum(
        np.concatenate((translation[[0]], col_widths[:-1] + col_margins))
    )
    # Row offsets step down by (this row's height + margin)
    row_ys = np.cumsum(
        np.concatenate((translation[[1]], -row_heights[1:] - row_margins))
    )
    col_trans = np.column_stack((col_xs, np.zeros_like(col_xs)))
    row_trans = np.column_stack((np.zeros_like(row_ys), row_ys))

    return tuple(
        self.make_quad(drow + dcol, row_def @ col_def)
        for drow, row_def in zip(row_trans, row_defs)
        for dcol, col_def in zip(col_trans, col_defs)
    )
--------------------------------------------------------------------------------
/tests/test_constructions.py:
--------------------------------------------------------------------------------
1 | """
2 | Test geometric constructions
3 | """
4 |
5 | import pytest
6 |
7 | from numpy.typing import NDArray
8 |
9 | import numpy as np
10 |
11 | from mpllayout import primitives as pr
12 | from mpllayout import constructions as con
13 | from mpllayout.containers import Node, accumulate
14 |
15 | from tests.fixture_primitives import GeometryFixtures
16 |
class TestConstructionValidation(GeometryFixtures):
    """
    Check that constructions reject badly formed primitive arguments
    """

    def test_validate_prims(self):
        midpoint = con.Midpoint()

        line = self.make_line(np.random.rand(2), np.random.rand(2))
        point_a = self.make_point(np.random.rand(2))
        point_b = self.make_point(np.random.rand(2))

        # (primitive argument, expected `TypeError` message pattern)
        bad_calls = (
            # A bare primitive rather than a tuple of primitives
            (line, r"Expected tuple of primitives.*"),
            # The wrong number of primitives
            ((point_a, point_b), r"Expected [0-9]* primitives.*"),
            # The right number of primitives, but of the wrong type
            ((point_a,), r"Expected primitive types.*"),
        )
        for prims, pattern in bad_calls:
            with pytest.raises(TypeError, match=pattern) as exc_info:
                midpoint(prims)
            print(exc_info.value)
39 |
class TestConstructionFunctions(GeometryFixtures):
    """
    Test the functions that transform and combine constructions
    """

    def test_transform_map(self):
        num_point = 2
        coords = np.random.rand(num_point, 2)
        points = tuple(pr.Point(value=coord) for coord in coords)

        coords_ref = np.concatenate([con.Coordinate()((point,)) for point in points])
        map_coordinate = con.transform_map(con.Coordinate(), num_point*[pr.Point])
        coords_map = map_coordinate(points)

        assert np.all(np.isclose(coords_ref, coords_map))

    def test_transform_constraint(self):
        # A constraint built from a construction, evaluated with the
        # construction's own value, should have a zero residual
        construction = con.OuterMargin(side='right')
        quada = self.make_quad(np.zeros(2), np.diag(np.ones(2)))
        quadb = self.make_quad(np.array([1.5, 0]), np.diag(np.ones(2)))
        prims = (quada, quadb)

        value = construction(prims)

        constraint = con.transform_constraint(construction)
        res = constraint(prims, value)
        assert np.all(np.isclose(res, 0))

        construction = con.Coordinate()
        point = self.make_point(np.random.rand(2))
        prims = (point,)

        value = construction(prims)

        constraint = con.transform_constraint(construction)
        res = constraint(prims, value)
        assert np.all(np.isclose(res, 0))

    def test_transform_sum(self):
        construction = con.Coordinate()
        sum_construction = con.transform_sum(construction, construction)

        prims_a = (self.make_point(np.random.rand(2)),)
        prims_b = (self.make_point(np.random.rand(2)),)
        params_a = ()
        params_b = ()

        res_a = construction(prims_a, *params_a) + construction(prims_b, *params_b)
        res_b = sum_construction(prims_a + prims_b, *(params_a + params_b))

        assert np.all(np.isclose(res_a, res_b))

        # The `+` operator should behave like `transform_sum`
        res_b = (construction + construction)(prims_a+prims_b, *(params_a + params_b))

        assert np.all(np.isclose(res_a, res_b))

    def test_transform_scalar_mul(self):
        """
        Test multiplication of a construction by a constant and by a `Scalar`
        """
        construction = con.Coordinate()

        prims = (self.make_point(np.random.rand(2)),)
        cons_params = ()
        scalar = np.random.rand()
        res_a = scalar * construction(prims, *cons_params)

        # Test constant version
        params = cons_params + ()

        mul_construction = con.transform_scalar_mul(construction, scalar)
        res_b = mul_construction(prims, *params)
        assert np.all(np.isclose(res_a, res_b))

        mul_construction = scalar * construction
        res_b = mul_construction(prims, *params)
        assert np.all(np.isclose(res_a, res_b))

        # Test non constant version (the scalar is passed as a parameter)
        params = cons_params + (scalar,)

        mul_construction = con.transform_scalar_mul(construction, con.Scalar())
        res_b = mul_construction(prims, *params)
        assert np.all(np.isclose(res_a, res_b))

        mul_construction = con.Scalar() * construction
        res_b = mul_construction(prims, *params)
        assert np.all(np.isclose(res_a, res_b))

    def test_transform_scalar_div(self):
        """
        Test division of a construction by a constant and by a `Scalar`

        This was previously also named `test_transform_scalar_mul`, which
        shadowed the multiplication test so that it never ran.
        """
        construction = con.Coordinate()

        prims = (self.make_point(np.random.rand(2)),)
        cons_params = ()
        scalar = np.random.rand()
        res_a = construction(prims, *cons_params) / scalar

        # Test constant version
        params = cons_params + ()

        div_construction = con.transform_scalar_div(construction, scalar)
        res_b = div_construction(prims, *params)
        assert np.all(np.isclose(res_a, res_b))

        div_construction = construction / scalar
        res_b = div_construction(prims, *params)
        assert np.all(np.isclose(res_a, res_b))

        # Test non constant version (the scalar is passed as a parameter)
        params = cons_params + (scalar,)

        div_construction = con.transform_scalar_div(construction, con.Scalar())
        res_b = div_construction(prims, *params)
        assert np.all(np.isclose(res_a, res_b))

        div_construction = construction / con.Scalar()
        res_b = div_construction(prims, *params)
        assert np.all(np.isclose(res_a, res_b))

    def test_transform_scalar_pow(self):
        """
        Test raising a construction to a constant power and to a `Scalar` power
        """
        construction = con.Coordinate()

        prims = (self.make_point(np.random.rand(2)),)
        cons_params = ()
        scalar = np.random.rand()
        res_a = construction(prims, *cons_params) ** scalar

        # Test constant version
        params = cons_params + ()

        pow_construction = con.transform_scalar_pow(construction, scalar)
        res_b = pow_construction(prims, *params)
        assert np.all(np.isclose(res_a, res_b))

        pow_construction = construction ** scalar
        res_b = pow_construction(prims, *params)
        assert np.all(np.isclose(res_a, res_b))

        # Test non constant version (the scalar is passed as a parameter)
        params = cons_params + (scalar,)

        pow_construction = con.transform_scalar_pow(construction, con.Scalar())
        res_b = pow_construction(prims, *params)
        assert np.all(np.isclose(res_a, res_b))

        pow_construction = construction ** con.Scalar()
        res_b = pow_construction(prims, *params)
        assert np.all(np.isclose(res_a, res_b))

    def test_transform_partial(self):
        # Partially applying a parameter should match passing it explicitly
        cons = con.DirectedLength()

        line = self.make_line(np.zeros(2), np.random.rand(2))

        direction = np.random.rand(2)
        direction = direction / np.linalg.norm(direction)
        length_ref = cons((line,), direction)

        length_test = con.transform_partial(cons, direction)((line,))

        assert np.all(np.isclose(length_test, length_ref))

        # Partial application with no parameters should be a no-op
        cons = con.Width()

        quad = self.make_quad(
            np.zeros(2), np.diag(np.random.rand(2))
        )

        res_ref = cons((quad,))

        res_test = con.transform_partial(cons)((quad,))

        assert np.all(np.isclose(res_test, res_ref))
208 |
209 |
class TestNull:
    """
    Test constructions evaluated on an empty tuple of primitives
    """

    @pytest.fixture()
    def size_node(self):
        vec3 = Node(0, {'Vec4': Node(1, {})})
        return Node(
            0,
            {'Vec1': Node(1, {}), 'Vec2': Node(2, {}), 'Vec3': vec3}
        )

    @pytest.fixture()
    def vector(self, size_node: Node[int]):
        # Total size is the sum of the sizes stored over the whole tree
        total_size = accumulate(lambda x, y: x + y, size_node, 0).value
        return np.random.rand(total_size)

    def test_Vector(self, size_node: Node[int], vector: NDArray):
        vector_test = con.Vector(size_node).assem((), vector)
        print(vector_test, vector)

        assert np.all(np.isclose(vector_test, vector))
238 |
239 |
class TestPoint(GeometryFixtures):
    """
    Test constructions with signature `[Point]`
    """

    @pytest.fixture()
    def coordinate(self):
        return np.random.rand(2)

    def test_Coordinate(self, coordinate):
        computed = con.Coordinate()((self.make_point(coordinate),))
        assert np.all(np.isclose(computed - coordinate, 0))
253 |
254 |
class TestLine(GeometryFixtures):
    """
    Test constructions with signature `[Line]`
    """

    @pytest.fixture()
    def length(self):
        return np.random.rand()

    @pytest.fixture()
    def direction(self):
        vec = np.random.rand(2)
        return vec / np.linalg.norm(vec)

    @pytest.fixture()
    def linea(self, length, direction):
        start = np.random.rand(2)
        return self.make_line(start, length * direction)

    def test_Length(self, linea, length):
        residual = con.Length()((linea,)) - length
        assert np.all(np.isclose(residual, 0))

    def test_DirectedLength(self, length, direction):
        # The line direction is random and not normalized; only the projection
        # onto `direction` is checked
        line_vec = length * np.random.rand(2)
        line = self.make_line((0, 0), line_vec)

        expected = np.dot(line_vec, direction)
        residual = con.DirectedLength()((line,), direction) - expected
        assert np.all(np.isclose(residual, 0))

    @pytest.fixture()
    def XYLength(self, axis_name):
        return con.XLength if axis_name == 'x' else con.YLength

    def test_XYLength(self, XYLength, axis_dir, length):
        line_vec = length * np.random.rand(2)
        line = self.make_line((0, 0), line_vec)

        expected = np.dot(line_vec, axis_dir)
        residual = XYLength()((line,)) - expected

        assert np.all(np.isclose(residual, 0))
305 |
306 |
class TestQuadrilateral(GeometryFixtures):
    """
    Test constructions with signature `[Quadrilateral]`
    """

    @pytest.fixture()
    def quada(self):
        origin = np.random.rand(2)
        deformation = np.random.rand(2, 2)
        return self.make_quad(origin, deformation)

    @pytest.fixture()
    def aspect_ratio(self, quada):
        # Ratio of the lengths of the quad's `Line0` and `Line1` sides
        width, height = (
            np.linalg.norm(con.LineVector.assem((quada[key],)))
            for key in ('Line0', 'Line1')
        )
        return width/height

    def test_AspectRatio(self, quada: pr.Quadrilateral, aspect_ratio: float):
        residual = con.AspectRatio()((quada,)) - aspect_ratio
        assert np.all(np.isclose(residual, 0))
325 |
326 |
class TestQuadrilateralQuadrilateral(GeometryFixtures):
    """
    Test constructions with signature `[Quadrilateral, Quadrilateral]`
    """

    @pytest.fixture()
    def margin(self):
        return np.random.rand()

    def _random_box(self):
        """Return an axis-aligned box with random size and position"""
        size = np.random.rand(2)
        origin = np.random.rand(2)
        return self.make_quad(origin, np.diag(size))

    @staticmethod
    def _corners(box):
        """Return the (bottom-left, top-right) corner coordinates of `box`"""
        return box['Line0/Point0'].value, box['Line1/Point1'].value

    @pytest.fixture()
    def boxa(self):
        return self._random_box()

    @pytest.fixture()
    def boxb(self):
        return self._random_box()

    @pytest.fixture(params=('bottom', 'top', 'left', 'right'))
    def margin_side(self, request):
        return request.param

    @pytest.fixture()
    def outer_margin(self, boxa, boxb, margin_side):
        a_botleft, a_topright = self._corners(boxa)
        b_botleft, b_topright = self._corners(boxb)

        # side -> (corner difference, coordinate index)
        outer = {
            'left': (a_botleft - b_topright, 0),
            'right': (b_botleft - a_topright, 0),
            'bottom': (a_botleft - b_topright, 1),
            'top': (b_botleft - a_topright, 1),
        }
        diff, index = outer[margin_side]
        return diff[index]

    def test_OuterMargin(self, boxa, boxb, outer_margin, margin_side):
        residual = con.OuterMargin(side=margin_side)((boxa, boxb)) - outer_margin
        assert np.all(np.isclose(residual, 0))

    @pytest.fixture()
    def inner_margin(self, boxa, boxb, margin_side):
        a_botleft, a_topright = self._corners(boxa)
        b_botleft, b_topright = self._corners(boxb)

        # side -> (corner difference, coordinate index)
        inner = {
            'left': (a_botleft - b_botleft, 0),
            'right': (b_topright - a_topright, 0),
            'bottom': (a_botleft - b_botleft, 1),
            'top': (b_topright - a_topright, 1),
        }
        diff, index = inner[margin_side]
        return diff[index]

    def test_InnerMargin(self, boxa, boxb, inner_margin, margin_side):
        residual = con.InnerMargin(side=margin_side)((boxa, boxb)) - inner_margin
        assert np.all(np.isclose(residual, 0))
393 |
394 |
--------------------------------------------------------------------------------
/tests/test_containers.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for `mpllayout.containers`
3 | """
4 |
5 | import pytest
6 |
7 | from timeit import timeit
8 |
9 | import numpy as np
10 |
11 | from mpllayout import containers as cn
12 |
13 |
class NodeFixtures:
    """
    Helpers and fixtures that build `cn.Node` trees for tests
    """

    def random_node(
        self,
        value: float,
        depth: int=0,
        min_children: int=1,
        max_children: int=1,
        min_depth: int=0,
        max_depth: int=0
    ):
        """
        Return a randomly shaped `cn.Node` tree whose root has `value`

        Parameters
        ----------
        value: float
            Value of the root node
        depth: int
            Depth of the root node; incremented on each recursion
        min_children, max_children: int
            Inclusive bounds on the number of children of a non-leaf node
        min_depth, max_depth: int
            The probability that a node is a leaf rises linearly from 0 at
            `min_depth` to 1 at `max_depth`. If `max_depth <= min_depth`,
            nodes stop (become leaves) once `depth >= max_depth`.

        Returns
        -------
        cn.Node
            Tree whose child values are drawn from `np.random.rand`
        """
        # `randint` excludes its upper bound; add 1 so `max_children` is
        # inclusive (otherwise the default 1, 1 raises `ValueError`)
        num_child = np.random.randint(min_children, max_children + 1)
        child_values = np.random.rand(num_child)

        # Give a linear distribution for the probability of stopping
        def probability_stop(depth: int):
            if max_depth <= min_depth:
                # Degenerate depth range: a step at `max_depth` (this also
                # avoids dividing by zero for the default `0, 0`)
                return 1.0 if depth >= max_depth else 0.0
            _p = (depth - min_depth) / (max_depth - min_depth)
            return np.clip(_p, 0, 1)

        if np.random.rand() < probability_stop(depth):
            children = {}
        else:
            child_kwargs = {
                'depth': depth+1,
                'min_children': min_children,
                'max_children': max_children,
                'min_depth': min_depth,
                'max_depth': max_depth
            }
            children = {
                f"Child{n:d}": self.random_node(child_value, **child_kwargs)
                for n, child_value in enumerate(child_values)
            }
        return cn.Node(value, children)

    @pytest.fixture(params=(0, 1, 2))
    def num_children(self, request):
        return request.param

    @pytest.fixture()
    def root_value(self):
        return np.random.rand()

    @pytest.fixture()
    def child_values(self, num_children: int):
        return np.random.rand(num_children)

    @pytest.fixture()
    def node(self, root_value: float, child_values: np.ndarray):
        """
        Return nodes with varying numbers of children

        This covers the case of a leaf node, a node with one child, and a node
        with two children.
        """
        children = {
            f'a{n}': cn.Node(child_value, {})
            for n, child_value in enumerate(child_values)
        }
        return cn.Node(root_value, children)
74 |
75 |
class TestNode(NodeFixtures):
    """
    Test `cn.Node` structure queries and flattening round trips
    """

    @staticmethod
    def _print_durations(namespace: dict, *label_stmts: tuple[str, str]):
        """
        Time each `(label, statement)` pair over `N` runs and print the mean

        `namespace` is used as the globals the timed statements run in.
        """
        N = int(1e5)
        for label, stmt in label_stmts:
            duration = timeit(stmt, globals=namespace, number=N)
            print(f"{label} duration: {duration/N: .2e} s")

    def test_node_height(self):
        # A node without children has height 0
        assert cn.Node(0, {}).node_height() == 0

        # The height follows the deepest branch, whichever child it is under
        deep_first = cn.Node(
            0, {'a1': cn.Node(0, {'b1': cn.Node(0, {})}), 'a2': cn.Node(0, {})}
        )
        assert deep_first.node_height() == 2

        deep_second = cn.Node(
            0, {'a1': cn.Node(0, {}), 'a2': cn.Node(0, {'b1': cn.Node(0, {})})}
        )
        assert deep_second.node_height() == 2

    def test_repr(self, node: cn.Node):
        print(node)

    def test_iter_flat(self, node: cn.Node):
        flat_items = [{key: child} for key, child in cn.iter_flat("", node)]
        print(flat_items)

    def test_flatten_unflatten_python(self, node: cn.Node):
        fnode_structs = cn.flatten("root", node)
        reconstructed_node, _ = cn.unflatten(fnode_structs)

        print(fnode_structs)

        assert str(node) == str(reconstructed_node)
        print(node)
        print(reconstructed_node)

        # The timed statements look up `cn`, `node` and `fnode_structs`
        namespace = {**globals(), **locals()}
        self._print_durations(
            namespace,
            ("Flattening", "cn.flatten('root', node)"),
            ("Unflattening", "cn.unflatten(fnode_structs)"),
        )

    def test_flatten_unflatten_jax(self, node: cn.Node):
        import jax

        flat_tree, flat_tree_def = jax.tree_util.tree_flatten(node)
        reconstructed_node = jax.tree_util.tree_unflatten(flat_tree_def, flat_tree)

        assert str(node) == str(reconstructed_node)
        print(node)
        print(reconstructed_node)

        # `jax` is a local import, so it must come from `locals()`
        namespace = {**globals(), **locals()}
        self._print_durations(
            namespace,
            ("Flattening", "jax.tree_util.tree_flatten(node)"),
            ("Unflattening", "jax.tree_util.tree_unflatten(flat_tree_def, flat_tree)"),
        )
136 |
137 |
class TestFunctions(NodeFixtures):
    """
    Test `cn.map` and `cn.accumulate`
    """

    def test_map(self, node):
        def increment(x):
            return x + 1

        mapped = cn.map(increment, node)

        expected = [increment(sub.value) for _, sub in cn.iter_flat('', node)]
        actual = [sub.value for _, sub in cn.iter_flat('', mapped)]
        assert np.all(np.isclose(actual, expected))

    def test_accumulate(self, node, root_value, child_values):
        # For a depth-1 tree the accumulated root is the sum of all values
        summed = cn.accumulate(lambda x, y: x + y, node, initial=0)

        assert np.isclose(summed.value, root_value + np.sum(child_values))
157 |
--------------------------------------------------------------------------------
/tests/test_layout.py:
--------------------------------------------------------------------------------
1 | """
2 | Test `layout`
3 | """
4 |
5 | import pytest
6 |
7 | from pprint import pprint
8 |
9 | import numpy as np
10 |
11 | from mpllayout import primitives as pr
12 | from mpllayout import constraints as co
13 | from mpllayout import layout as lat
14 | from mpllayout import containers as cn
15 |
16 |
class TestPrimitiveTree:
    """
    Smoke tests for adding primitives to a `cn.Node` and rebuilding them
    """

    @pytest.fixture()
    def prim_node(self):
        # Empty root node that test primitives are added to
        return cn.Node(np.array([]), {})

    def test_set_prim(self, prim_node):
        prim_node.add_child("MyBox", pr.Quadrilateral())

        pprint("Keys:")
        pprint(prim_node["MyBox"].keys())

    def test_build_primtree(self, prim_node):
        point_a = pr.Point([0, 0])
        point_b = pr.Point([1, 1])
        prim_node.add_child("PointA", point_a)
        prim_node.add_child("LineA", pr.Line([], (point_a, point_b)))
        prim_node.add_child("MySpecialBox", pr.Quadrilateral())

        # Round trip: extract the unique values, replace them with random
        # values of the same shapes, then rebuild the primitive tree
        prim_graph, prim_values = pr.filter_unique_values_from_prim(prim_node)
        new_params = [np.random.rand(*value.shape) for value in prim_values]
        pr.build_prim_from_unique_values(
            cn.flatten('', prim_node), prim_graph, new_params
        )
68 |
69 |
class TestLayout:
    """
    Smoke test building a `lat.Layout` and flattening its constraints
    """

    def test_layout(self):
        layout = lat.Layout()

        # A single box with its `Line0/Point0` vertex fixed at the origin
        layout.add_prim(pr.Quadrilateral(), "MyBox")
        layout.add_constraint(co.Box(), ("MyBox",), ())
        layout.add_constraint(co.Fix(), ("MyBox/Line0/Point0",), ([0, 0],))

        pprint(layout.root_prim)
        constraints, argkeys, params = layout.flat_constraints()

        print("Flat constraints: ")
        for label, item in (
            ("Constraints:", constraints),
            ("Constraints argument keys:", argkeys),
            ("Constraints parameter vector:", params),
        ):
            print(label)
            pprint(item)
90 |
91 |
--------------------------------------------------------------------------------
/tests/test_primitives.py:
--------------------------------------------------------------------------------
1 | """
2 | Test geometric primitives
3 | """
4 |
5 | import pytest
6 |
7 | import typing as tp
8 |
9 | import itertools
10 | from pprint import pprint
11 |
12 | import numpy as np
13 |
14 | from mpllayout import primitives as pr
15 |
16 |
class TestPrimitives:
    """
    Smoke tests that primitives can be built with default arguments and
    round-tripped through JAX pytrees
    """

    def test_Quadrilateral(self):
        pr.Quadrilateral()

    def test_Line(self):
        pr.Line()

    def test_Point(self):
        pr.Point()

    def test_Polygon(self):
        pr.Polygon()

    def test_Primitive_jax_pytree(self):
        from jax import tree_util

        for prim in (pr.Point(), pr.Line(), pr.Quadrilateral()):
            print(f"\nTesting primitive type {type(prim).__name__}")
            print("Leaves:", tree_util.tree_leaves(prim))

            value_flat, value_tree = tree_util.tree_flatten(prim)
            reconstructed_prim = tree_util.tree_unflatten(value_tree, value_flat)
            print("tree_util.tree_flatten:", value_flat, value_tree)
            print("tree_util.tree_unflatten:", reconstructed_prim)

        # Reference: leaves of a plain nested list
        leaves = tree_util.tree_leaves([0, 1, 2, 3, [4, 5, [6, 7, [8]]]])
        print(leaves)
        print([type(leaf) for leaf in leaves])
52 |
--------------------------------------------------------------------------------
/tests/test_solve.py:
--------------------------------------------------------------------------------
1 | """
2 | Test `solver`
3 | """
4 |
5 | import pytest
6 |
7 | import time
8 | from pprint import pprint
9 |
10 | import numpy as np
11 |
12 | from mpllayout import primitives as pr
13 | from mpllayout import constraints as co
14 | from mpllayout import layout as lay
15 | from mpllayout import containers as cn
16 | from mpllayout import solver
17 |
18 |
class TestPrimitiveTree:
    """
    Test `solver` on a single-box layout and on a grid-of-axes layout
    """

    @staticmethod
    def _print_duration(fun, num_calls: int = 50):
        """
        Call `fun()` `num_calls` times and print the total elapsed time
        """
        # `perf_counter` is monotonic and high-resolution, unlike `time.time`
        t0 = time.perf_counter()
        for _ in range(num_calls):
            fun()
        t1 = time.perf_counter()
        print(f"Duration {t1-t0:.2e} s")

    @pytest.fixture()
    def layout(self):
        """
        Return a layout with one box with a fixed corner and fixed side lengths
        """
        layout = lay.Layout()

        # The initial vertices are not a box, so the solver has to move them
        verts = np.array([[0.1, 0.2], [1.0, 2.0], [2.0, 2.0], [3.0, 3.0]])

        layout.add_prim(
            pr.Quadrilateral(children=[pr.Point(vert) for vert in verts]),
            "MyFavouriteBox",
        )
        layout.add_constraint(co.Box(), ("MyFavouriteBox",), ())
        layout.add_constraint(
            co.Fix(), ("MyFavouriteBox/Line0/Point0",), (np.array([0, 0]),)
        )

        layout.add_constraint(co.Length(), ("MyFavouriteBox/Line0",), (5.0,))
        layout.add_constraint(co.Length(), ("MyFavouriteBox/Line1",), (5.1,))
        return layout

    @pytest.fixture(params=[(5, 6)])
    def axes_shape(self, request):
        return request.param

    @pytest.fixture()
    def layout_grid(self, axes_shape):
        """
        Return a layout of a figure containing an `axes_shape` grid of axes
        """
        layout = lay.Layout()
        ## Create an origin point
        layout.add_prim(pr.Point(), "Origin")
        layout.add_constraint(co.Fix(), ("Origin",), (np.array([0, 0]),))

        ## Create the figure box
        layout.add_prim(pr.Quadrilateral(), "Figure")
        layout.add_constraint(co.Box(), ("Figure",), ())

        ## Constrain the figure width and position
        # The figure height is not constrained directly; presumably it is
        # determined by the axes grid, aspect ratio and top/bottom margins below
        fig_width = 6
        layout.add_constraint(
            co.Length(), ("Figure/Line0",), (fig_width,)
        )
        layout.add_constraint(
            co.Coincident(), ("Figure/Line0/Point0", "Origin"), ()
        )

        ## Create the axes boxes
        num_row, num_col = axes_shape
        num_axes = int(np.prod(axes_shape))
        for n in range(num_axes):
            layout.add_prim(pr.Quadrilateral(), f"Axes{n}")
            layout.add_constraint(co.Box(), (f"Axes{n}",), ())

        ## Constrain the axes in a grid
        grid_param = (
            (num_col - 1) * [1],
            (num_row - 1) * [1],
            (num_col - 1) * [1 / 16],
            (num_row - 1) * [1 / 16],
        )
        layout.add_constraint(
            co.Grid(axes_shape),
            tuple(f"Axes{n}" for n in range(num_axes)),
            grid_param
        )

        # Constrain the first axis aspect ratio
        layout.add_constraint(
            co.RelativeLength(), ("Axes0/Line0", "Axes0/Line1"), (2,)
        )

        # Constrain top/bottom margins
        margin_top = 1.1
        margin_bottom = 0.5
        layout.add_constraint(
            co.DirectedDistance(),
            ("Axes0/Line1/Point1", "Figure/Line1/Point1"),
            (np.array([0, 1]), margin_top)
        )
        layout.add_constraint(
            co.DirectedDistance(),
            (f"Axes{num_axes-1}/Line1/Point0", "Figure/Line1/Point0"),
            (np.array([0, -1]), margin_bottom)
        )

        # Constrain left/right margins
        margin_left = 0.2
        margin_right = 0.3
        layout.add_constraint(
            co.DirectedDistance(),
            ("Axes0/Line0/Point0", "Figure/Line0/Point0"),
            (np.array([-1, 0]), margin_left)
        )
        layout.add_constraint(
            co.DirectedDistance(),
            (f"Axes{num_col-1}/Line1/Point1", "Figure/Line1/Point1"),
            (np.array([1, 0]), margin_right)
        )
        return layout

    def test_assem_constraint_residual(self, layout_grid: lay.Layout):
        layout = layout_grid

        root_prim = layout.root_prim
        flat_constraints = layout.flat_constraints()

        # Plain call
        self._print_duration(
            lambda: solver.assem_constraint_residual(root_prim, *flat_constraints)
        )

        # `jax.jit` individual constraint functions
        import jax

        constraints = flat_constraints[0]
        constraints_jit = [jax.jit(constraint) for constraint in constraints]
        flat_constraints_jit = (constraints_jit,) + flat_constraints[1:]
        # Warm-up call so that compilation is excluded from the timing
        solver.assem_constraint_residual(root_prim, *flat_constraints_jit)
        self._print_duration(
            lambda: solver.assem_constraint_residual(root_prim, *flat_constraints_jit)
        )

        # `jax.jit` the overall function
        @jax.jit
        def assem_constraint_residual(root_prim):
            return solver.assem_constraint_residual(
                root_prim, *flat_constraints
            )

        assem_constraint_residual(root_prim)
        # NOTE(review): JAX dispatches asynchronously, so this timing may
        # exclude device compute; consider blocking on the result
        self._print_duration(lambda: assem_constraint_residual(root_prim))

    @pytest.fixture(
        params=('newton', 'minimize')
    )
    def method(self, request):
        return request.param

    def test_solve(self, layout: lay.Layout, method: str):
        t0 = time.perf_counter()
        prim_tree_n, solve_info = solver.solve(
            layout, method=method, max_iter=100
        )
        t1 = time.perf_counter()
        print(f"Solve took {t1-t0:.2e} s")

        prim_keys_to_value = {
            key: prim.value for key, prim in cn.iter_flat("", prim_tree_n)
        }
        pprint(prim_keys_to_value)
        pprint(solve_info)
187 |
--------------------------------------------------------------------------------