├── .gitignore ├── LICENCE ├── README.md ├── doc ├── READMEExample.svg └── Summary.svg ├── examples ├── README_example.ipynb ├── contour_colorbar.ipynb ├── grid_axes.ipynb ├── one_axes.ipynb ├── profile_solver.py ├── ten_simple_rules_demo.ipynb ├── tutorial.ipynb └── two_axes.ipynb ├── logo.svg ├── pyproject.toml ├── src └── mpllayout │ ├── __init__.py │ ├── constraints.py │ ├── constructions.py │ ├── containers.py │ ├── geometry.py │ ├── layout.py │ ├── matplotlibutils.py │ ├── primitives.py │ ├── solver.py │ └── ui.py └── tests ├── __init__.py ├── fixture_primitives.py ├── test_constraints.py ├── test_constructions.py ├── test_containers.py ├── test_layout.py ├── test_primitives.py └── test_solve.py /.gitignore: -------------------------------------------------------------------------------- 1 | # My rule 2 | prototype.py 3 | 4 | # Benchmarking files 5 | *.prof 6 | *.profile 7 | *.lprof 8 | 9 | # Figures / images / video 10 | *.png 11 | *.svg 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | cover/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | db.sqlite3-journal 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | .pybuilder/ 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | # For a library or package, you might want to ignore these files since the code is 99 | # intended to run in multiple environments; otherwise, check them in: 100 | # .python-version 101 | 102 | # pipenv 103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 106 | # install all needed dependencies. 107 | #Pipfile.lock 108 | 109 | # poetry 110 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 111 | # This is especially recommended for binary packages to ensure reproducibility, and is more 112 | # commonly ignored for libraries. 113 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 114 | #poetry.lock 115 | 116 | # pdm 117 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
118 | #pdm.lock 119 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 120 | # in version control. 121 | # https://pdm.fming.dev/#use-with-ide 122 | .pdm.toml 123 | 124 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 125 | __pypackages__/ 126 | 127 | # Celery stuff 128 | celerybeat-schedule 129 | celerybeat.pid 130 | 131 | # SageMath parsed files 132 | *.sage.py 133 | 134 | # Environments 135 | .env 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | # pytype static type analyzer 162 | .pytype/ 163 | 164 | # Cython debug symbols 165 | cython_debug/ 166 | 167 | # PyCharm 168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 170 | # and can be added to the global gitignore or merged into this file. For a more nuclear 171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
172 | #.idea/ 173 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Jonathan Deng 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MPLLayout 2 | 3 | ![Project logo](logo.svg) 4 | 5 | ## Summary 6 | 7 | MPLLayout is a package to create precise figure layouts in [matplotlib](https://matplotlib.org/). 8 | It works by modelling figure elements using geometric primitives (for example, text anchors are points, the figure is a quadrilateral, etc.), then constraining the sizes and positions of these elements using geometric constraints (for example, fixing the width of a figure, constraining axes sides to be collinear, constraining axes to lie on a grid, etc.). 
9 | 10 | Using this approach, MPLLayout can: 11 | 12 | * align figure elements (axes, text label locations, x- and y-axes, etc.), 13 | * specify margins around axes, 14 | * create templates for figures across different media (posters, manuscripts, slides, etc.), 15 | * and more! 16 | 17 | ## Basic usage 18 | 19 | The basic workflow to create a constrained layout is demonstrated through this example with a single-axes figure. 20 | To create a figure 5 in wide and 4 in tall, use the code: 21 | 22 | ```python 23 | import numpy as np 24 | from matplotlib import pyplot as plt 25 | 26 | from mpllayout import layout as ly 27 | from mpllayout import primitives as pr 28 | from mpllayout import constraints as cr 29 | from mpllayout import matplotlibutils as mputils 30 | from mpllayout import solver 31 | 32 | # The layout object stores the geometry and constraints defining the layout 33 | layout = ly.Layout() 34 | 35 | # To add geometry, pass the geometry primitive and a string key. 36 | # Naming the `Quadrilateral` "Figure" will cause mpllayout to identify it as a figure 37 | layout.add_prim(pr.Quadrilateral(), "Figure") 38 | 39 | # To create a constraint, pass the constraint, the geometry to constrain, and the 40 | # constraint arguments. 41 | # Constraint documentation describes what kind of geometry can be constrained and 42 | # any constraint arguments.
43 | layout.add_constraint(cr.Box(), ("Figure",), ()) 44 | layout.add_constraint(cr.Width(), ("Figure",), (5.0,)) 45 | layout.add_constraint(cr.Height(), ("Figure",), (4.0,)) 46 | ``` 47 | 48 | To create an axes inside the figure with desired margins around the edges, use the code: 49 | 50 | ```python 51 | # To add an axes, pass the `Axes` primitive 52 | # The `Axes` is a container of Quadrilaterals representing the drawing area (frame), 53 | # as well as, optionally, the x-axis and y-axis 54 | layout.add_prim(pr.Axes(), "MyAxes") 55 | 56 | # Constrain the axes drawing area to a box 57 | layout.add_constraint(cr.Box(), ("MyAxes/Frame",), ()) 58 | # Set "inner" margins around the outside of the axes frame to the figure 59 | # The inner margin is the distance from a `Quadrilateral` inside another 60 | # `Quadrilateral` 61 | layout.add_constraint(cr.InnerMargin(side="bottom"), ("MyAxes/Frame", "Figure"), (.5,)) 62 | layout.add_constraint(cr.InnerMargin(side="top"), ("MyAxes/Frame", "Figure"), (.5,)) 63 | layout.add_constraint(cr.InnerMargin(side="left"), ("MyAxes/Frame", "Figure"), (2.0,)) 64 | layout.add_constraint(cr.InnerMargin(side="right"), ("MyAxes/Frame", "Figure"), (0.5,)) 65 | ``` 66 | 67 | The desired layout can then be solved and used to create a plot. 68 | 69 | ```python 70 | solved_prims, *_ = solver.solve(layout)  # Solve for geometry that satisfies the constraints 71 | fig, axs = mputils.subplots(solved_prims)  # Create the figure and any axes from the solved geometry 72 | 73 | x = np.linspace(0, 2*np.pi) 74 | axs["MyAxes"].plot(np.sin(x)) 75 | ``` 76 | 77 | The code above results in the figure: 78 | ![README example figure](doc/READMEExample.svg) 79 | 80 | While this approach requires more code to specify the layout of figure elements, it allows precise specification of the layout. 81 | This can be useful, for example, in publication documents where precise figure, axes, and font sizes are desired. 82 | In addition, layouts can be adjusted (for example, by adjusting margin arguments) and serve as a template to generate multiple figures.
83 | 84 | More examples, which demonstrate other constraints and geometric primitives used to achieve more complicated layouts, can be found in the `examples` folder. 85 | The example given above can be found at `examples/README_example.ipynb`. 86 | The tutorial notebook in `examples/tutorial.ipynb` demonstrates the basic usage of the package and explains some of the commonly used geometric constraints. 87 | Other examples are also given in the `examples` folder. 88 | The notebook at `examples/ten_simple_rules_demo.ipynb` contains an interactive demo to recreate a figure from ["Ten Simple Rules for Better Figures"](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1003833) (Rougier, Droettboom and Bourne 2014). 89 | 90 | ## Motivation 91 | 92 | Matplotlib contains several strategies for creating figure layouts (for example, `GridSpec` and `subplots` for grid-based layouts). 93 | While these approaches work well, greater control over figure element positions is sometimes desirable; 94 | for example, when preparing figures for published documents, research papers, or slides. 95 | 96 | ## Installation 97 | 98 | You can install the package from PyPI using 99 | 100 | ```bash 101 | pip install matplotlib-layout 102 | ``` 103 | 104 | Alternatively, clone the repository to a local directory. 105 | Navigate to the project directory and run 106 | 107 | ```bash 108 | pip install . 109 | ``` 110 | 111 | The package requires `numpy`, `matplotlib`, and `jax`. 112 | 113 | ## Contributing 114 | 115 | This project is a work in progress, so there are likely bugs and missing features. 116 | If you would like to contribute a bug fix, feature, refactor, etc., thank you! 117 | All contributions are welcome. 118 | 119 | ## Similar Projects 120 | 121 | A similar project with a geometric constraint solver is [`pygeosolve`](https://github.com/SeanDS/pygeosolve).
122 | There is also another project prototype for a constraint-based layout engine for `matplotlib` [`MplLayouter`](https://github.com/Tillsten/MplLayouter), although it doesn't seem active as of 2023. 123 | -------------------------------------------------------------------------------- /doc/READMEExample.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 8 | 9 | 2025-05-04T14:33:25.943089 10 | image/svg+xml 11 | 12 | 13 | Matplotlib v3.9.2, https://matplotlib.org/ 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 30 | 31 | 32 | 33 | 39 | 40 | 41 | 42 | 43 | 44 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 304 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 323 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 359 | 360 | 361 | 362 | 363 | 364 | 365 | 366 | 367 | 368 | 369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | 378 | 379 | 380 | 381 | 382 | 383 | 384 | 385 | 386 | 387 | 388 | 389 | 390 | 391 | 392 | 393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 401 | 402 | 403 | 404 | 405 | 406 | 407 | 408 | 409 | 410 | 411 | 412 | 413 | 414 | 415 | 416 | 417 | 418 | 419 | 420 | 421 | 422 | 423 | 424 | 425 | 426 | 427 | 428 | 429 | 430 | 431 | 432 | 433 | 434 | 435 | 436 | 437 | 438 | 439 | 440 | 441 | 442 | 443 | 444 
| 445 | 446 | 447 | 448 | 449 | 450 | 451 | 452 | 453 | 454 | 455 | 456 | 457 | 458 | 459 | 460 | 461 | 462 | 463 | 464 | 465 | 466 | 467 | 468 | 469 | 470 | 471 | 472 | 473 | 474 | 475 | 476 | 477 | 478 | 479 | 480 | 481 | 482 | 483 | 484 | 485 | 486 | 503 | 504 | 505 | 506 | 507 | 508 | 509 | 560 | 561 | 562 | 565 | 566 | 567 | 570 | 571 | 572 | 575 | 576 | 577 | 580 | 581 | 582 | 583 | 584 | 585 | 586 | 587 | 588 | 589 | -------------------------------------------------------------------------------- /examples/README_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "302445fb", 6 | "metadata": {}, 7 | "source": [ 8 | "# Basic Example\n", 9 | "\n", 10 | "This notebook demonstrates the basic example shown in the readme" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "ffeacd16", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "# Ensure that the notebook shows whitespace in the plot\n", 21 | "%config InlineBackend.print_figure_kwargs = {'bbox_inches': None}" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "e8888d26", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import numpy as np\n", 32 | "from matplotlib import pyplot as plt\n", 33 | "\n", 34 | "from mpllayout import layout as ly\n", 35 | "from mpllayout import primitives as pr\n", 36 | "from mpllayout import constraints as cr\n", 37 | "from mpllayout import matplotlibutils as mputils\n", 38 | "from mpllayout import solver\n", 39 | "\n", 40 | "# The layout object stores the geometry and constraints defining the layout\n", 41 | "layout = ly.Layout()\n", 42 | "\n", 43 | "# To add geometry, pass the: geometry primitive and a string key.\n", 44 | "# Naming the `Quadrilateral`, \"Figure\" will cause mpllayout to identify it as a figure\n", 45 | "layout.add_prim(pr.Quadrilateral(), 
\"Figure\")\n", 46 | "\n", 47 | "# To create a constraint, pass the: constraint, geometry to constrain, and\n", 48 | "# constraint arguments.\n", 49 | "# Constraint documentation describes what kind of geometry can be constrained and\n", 50 | "# any constraint arguments.\n", 51 | "layout.add_constraint(cr.Box(), (\"Figure\",), ())\n", 52 | "layout.add_constraint(cr.Width(), (\"Figure\",), (5.0,))\n", 53 | "layout.add_constraint(cr.Height(), (\"Figure\",), (4.0,))\n", 54 | "\n", 55 | "# To add an axes, pass the `Axes` primitive\n", 56 | "# The `Axes` is container of Quadrilaterals representing the drawing area (frame),\n", 57 | "# as well as, optionally, the x-axis and y-axis\n", 58 | "layout.add_prim(pr.Axes(), \"MyAxes\")\n", 59 | "\n", 60 | "# Constrain the axes drawing area to a box\n", 61 | "layout.add_constraint(cr.Box(), (\"MyAxes/Frame\",), ())\n", 62 | "# Set \"inner\" margins around the outside of the axes frame to the figure\n", 63 | "# The inner margin is the distance from a `Quadrilateral` inside another\n", 64 | "# `Quadrilateral`\n", 65 | "layout.add_constraint(cr.InnerMargin(side=\"bottom\"), (\"MyAxes/Frame\", \"Figure\"), (.5,))\n", 66 | "layout.add_constraint(cr.InnerMargin(side=\"top\"), (\"MyAxes/Frame\", \"Figure\"), (.5,))\n", 67 | "layout.add_constraint(cr.InnerMargin(side=\"left\"), (\"MyAxes/Frame\", \"Figure\"), (2.0,))\n", 68 | "layout.add_constraint(cr.InnerMargin(side=\"right\"), (\"MyAxes/Frame\", \"Figure\"), (0.5,))\n", 69 | "\n", 70 | "# Solve the constrained layout for geometry that satisfies the constraints\n", 71 | "solved_prims, *_ = solver.solve(layout)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "id": "7166f72e", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "# Create the figure and any axes from the solved geometry\n", 82 | "fig, axs = mputils.subplots(solved_prims)\n", 83 | "\n", 84 | "x = np.linspace(0, 2*np.pi)\n", 85 | "axs[\"MyAxes\"].plot(np.sin(x))\n", 86 | 
"axs[\"MyAxes\"].set_xlabel(\"x\")\n", 87 | "axs[\"MyAxes\"].set_ylabel(\"y\")\n", 88 | "\n", 89 | "fig.savefig(\"READMEExample.svg\")\n" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "id": "db55b8dd", 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [] 99 | } 100 | ], 101 | "metadata": { 102 | "kernelspec": { 103 | "display_name": "numerics", 104 | "language": "python", 105 | "name": "python3" 106 | }, 107 | "language_info": { 108 | "codemirror_mode": { 109 | "name": "ipython", 110 | "version": 3 111 | }, 112 | "file_extension": ".py", 113 | "mimetype": "text/x-python", 114 | "name": "python", 115 | "nbconvert_exporter": "python", 116 | "pygments_lexer": "ipython3", 117 | "version": "3.12.7" 118 | } 119 | }, 120 | "nbformat": 4, 121 | "nbformat_minor": 5 122 | } 123 | -------------------------------------------------------------------------------- /examples/contour_colorbar.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Contour plot with colorbar example\n" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "\n", 18 | "from mpllayout import primitives as pr\n", 19 | "from mpllayout import constraints as cr\n", 20 | "from mpllayout import ui\n", 21 | "from mpllayout.solver import solve\n", 22 | "from mpllayout.layout import Layout\n", 23 | "\n", 24 | "from mpllayout.matplotlibutils import subplots" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "# Use this so matplotlib figures are showed with whitespace\n", 34 | "%matplotlib inline\n", 35 | "%config InlineBackend.print_figure_kwargs = {'bbox_inches': None}" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": 
{}, 41 | "source": [ 42 | "## Specify the layout" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "## Create the figure and two axes (one for the contour plot and one for the colorbar)\n", 52 | "\n", 53 | "### Primitives\n", 54 | "# Create a `Layout` to track all geometric primitives and constraints\n", 55 | "layout = Layout()\n", 56 | "\n", 57 | "# Create the main geometric primitives\n", 58 | "layout.add_prim(pr.Quadrilateral(), \"Figure\")\n", 59 | "layout.add_prim(pr.Axes(), \"AxesContour\")\n", 60 | "layout.add_prim(pr.Axes(), \"AxesColorbar\")\n", 61 | "\n", 62 | "### Constraints\n", 63 | "\n", 64 | "# Make all the quadrilaterals rectangular\n", 65 | "layout.add_constraint(cr.Box(), (\"Figure\",), ())\n", 66 | "layout.add_constraint(cr.Box(), (\"AxesContour/Frame\",), ())\n", 67 | "layout.add_constraint(cr.Box(), (\"AxesColorbar/Frame\",), ())\n", 68 | "\n", 69 | "## Figure size\n", 70 | "# Set the figure width\n", 71 | "layout.add_constraint(cr.Length(), (\"Figure/Line0\",), (5,))\n", 72 | "\n", 73 | "# Fix the figure's lower left corner to the origin\n", 74 | "layout.add_constraint(cr.Fix(), (\"Figure/Line0/Point0\",), (np.array([0, 0]),))\n", 75 | "\n", 76 | "## Axes sizes\n", 77 | "\n", 78 | "# Set the aspect ratio to match physical dimensions\n", 79 | "layout.add_constraint(\n", 80 | " cr.RelativeLength(), (\"AxesContour/Frame/Line0\", \"AxesContour/Frame/Line1\"), (2,)\n", 81 | ")\n", 82 | "\n", 83 | "# Set the color bar height to 1/8 inch\n", 84 | "layout.add_constraint(cr.Length(), (\"AxesColorbar/Frame/Line1\",), (1 / 8,))\n", 85 | "\n", 86 | "## Align the color bar with the contour plot\n", 87 | "# Make the left/right sides collinear\n", 88 | "layout.add_constraint(\n", 89 | " cr.Collinear(), (\"AxesContour/Frame/Line3\", \"AxesColorbar/Frame/Line3\"), ()\n", 90 | ")\n", 91 | "layout.add_constraint(\n", 92 | " cr.Collinear(), (\"AxesContour/Frame/Line1\",
\"AxesColorbar/Frame/Line1\"), ()\n", 93 | ")\n", 94 | "\n", 95 | "## Margins\n", 96 | "# Place the color bar 1/16 inch above the contour plot\n", 97 | "layout.add_constraint(\n", 98 | " cr.MidpointYDistance(),\n", 99 | " (\"AxesContour/Frame/Line2\", \"AxesColorbar/Frame/Line0\"),\n", 100 | " (1 / 16,)\n", 101 | ")\n", 102 | "\n", 103 | "# Set the left margin to 6/8 inch from the contour plot\n", 104 | "layout.add_constraint(\n", 105 | " cr.MidpointXDistance(), (\"Figure/Line3\", \"AxesContour/Frame/Line3\"), (6 / 8,)\n", 106 | ")\n", 107 | "\n", 108 | "# Set the right margin 1/8 inch from the contour plot\n", 109 | "layout.add_constraint(\n", 110 | " cr.MidpointXDistance(), (\"AxesContour/Frame/Line1\", \"Figure/Line1\"), (1 / 8,)\n", 111 | ")\n", 112 | "\n", 113 | "# Set the top margin to 1/2 inch\n", 114 | "layout.add_constraint(\n", 115 | " cr.MidpointYDistance(), (\"AxesColorbar/Frame/Line2\", \"Figure/Line2\"), (1 / 2,)\n", 116 | ")\n", 117 | "\n", 118 | "# Set the bottom margin to 6/8 inch\n", 119 | "layout.add_constraint(\n", 120 | " cr.MidpointYDistance(), (\"Figure/Line0\", \"AxesContour/Frame/Line0\"), (6 / 8,)\n", 121 | ")" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "## Solve the layout" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "## Solve the constrained layout\n", 138 | "\n", 139 | "root_prim_n, solve_info = solve(layout)\n", 140 | "\n", 141 | "# Visualize the layout\n", 142 | "fig_layout, ax_layout = ui.figure_prims(root_prim_n)\n", 143 | "\n", 144 | "fig_layout" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "## Plot a contour plot and colorbar using the layout" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "# This creates the 
figure and axes based on the layout\n", 161 | "fig, axs = subplots(root_prim_n)\n", 162 | "# The sizes of `fig` and axes in `axs` will reflect the constraints\n", 163 | "# `axs['AxesColorbar']` is the colorbar and `axs['AxesContour']` is the contour\n", 164 | "\n", 165 | "\n", 166 | "print(fig.get_size_inches())\n", 167 | "\n", 168 | "x = np.linspace(0, 10, 51)\n", 169 | "y = np.linspace(0, 5, 26)\n", 170 | "xx, yy = np.meshgrid(x, y)\n", 171 | "z = (xx - 5) ** 2 + (yy - 2.5) ** 2\n", 172 | "\n", 173 | "cset = axs[\"AxesContour\"].contourf(x, y, z)\n", 174 | "\n", 175 | "axs[\"AxesContour\"].set_xlabel(\"x [cm]\")\n", 176 | "axs[\"AxesContour\"].set_ylabel(\"y [cm]\")\n", 177 | "\n", 178 | "fig.colorbar(cset, cax=axs[\"AxesColorbar\"], orientation=\"horizontal\")\n", 179 | "axs[\"AxesColorbar\"].xaxis.set_label_text(\"z [cm]\")\n", 180 | "axs[\"AxesColorbar\"].xaxis.set_tick_params(\n", 181 | " top=True, labeltop=True, bottom=False, labelbottom=False\n", 182 | ")\n", 183 | "axs[\"AxesColorbar\"].xaxis.set_label_position(position=\"top\")\n", 184 | "\n", 185 | "fig.savefig(\"contour_colorbar.png\")\n", 186 | "\n", 187 | "fig" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [] 196 | } 197 | ], 198 | "metadata": { 199 | "kernelspec": { 200 | "display_name": "numerics", 201 | "language": "python", 202 | "name": "python3" 203 | }, 204 | "language_info": { 205 | "codemirror_mode": { 206 | "name": "ipython", 207 | "version": 3 208 | }, 209 | "file_extension": ".py", 210 | "mimetype": "text/x-python", 211 | "name": "python", 212 | "nbconvert_exporter": "python", 213 | "pygments_lexer": "ipython3", 214 | "version": "3.12.7" 215 | } 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 2 219 | } 220 | -------------------------------------------------------------------------------- /examples/grid_axes.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Grid of fixed aspect axes example\n", 8 | "\n", 9 | "This example illustrates how to create a grid of axes with fixed-aspect ratios and margins around the figure.\n", 10 | "The width of the figure is fixed (for example, the width could be constrained by the column width in a journal) while the height automatically adjusts from the given constraints and axes width.\n", 11 | "This is difficult to accomplish directly in `matplotlib` without trial-and-error because the axes grid height isn't known until the figure is plotted." 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import itertools\n", 21 | "\n", 22 | "import numpy as np\n", 23 | "\n", 24 | "from mpllayout import solver\n", 25 | "from mpllayout import primitives as pr\n", 26 | "from mpllayout import constraints as co\n", 27 | "from mpllayout import layout as lay\n", 28 | "from mpllayout import matplotlibutils as lplt\n", 29 | "from mpllayout import ui" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "## Specify the layout" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "layout = lay.Layout()\n", 46 | "\n", 47 | "## Create the figure box\n", 48 | "layout.add_prim(pr.Quadrilateral(), \"Figure\")\n", 49 | "layout.add_constraint(co.Box(), (\"Figure\",), ())\n", 50 | "\n", 51 | "## Constrain the figure width\n", 52 | "# Note that the figure height isn't directly constrainted because other\n", 53 | "# constraints (margins, the axes aspect ratio, etc.) 
implicity determine what\n", 54 | "# the figure height should be.\n", 55 | "fig_width = 6\n", 56 | "layout.add_constraint(co.Width(), (\"Figure\",), (fig_width,))\n", 57 | "\n", 58 | "# Fix the figure bottom left corner to (0, 0)\n", 59 | "\n", 60 | "layout.add_constraint(co.Fix(), (\"Figure/Line0/Point0\",), (np.array([0, 0]),))\n", 61 | "\n", 62 | "## Create the axes\n", 63 | "\n", 64 | "# You can change the number of rows and columns in the axes here.\n", 65 | "num_row, num_col = (3, 4)\n", 66 | "axes_shape = (num_row, num_col)\n", 67 | "num_axes = int(np.prod(axes_shape))\n", 68 | "\n", 69 | "axes_keys = [\n", 70 | " f\"Axes[{row}, {col}]\"\n", 71 | " for row, col in itertools.product(range(num_row), range(num_col))\n", 72 | "]\n", 73 | "for axes_key in axes_keys:\n", 74 | " layout.add_prim(pr.Axes(), axes_key)\n", 75 | " layout.add_constraint(co.Box(), (f\"{axes_key}/Frame\",), ())\n", 76 | "\n", 77 | "# First constrain the top-left corner axes aspect ratio to be square.\n", 78 | "# Note this assumes the [0, 0] element is the top-left corner but other conventions\n", 79 | "# are possible.\n", 80 | "layout.add_constraint(co.AspectRatio(), (\"Axes[0, 0]/Frame\",), (1,))\n", 81 | "\n", 82 | "## Constrain the axes in a grid\n", 83 | "# You can either use a `Grid` constraint or manually apply the constraint grid\n", 84 | "\n", 85 | "margin_inner = 0.1\n", 86 | "\n", 87 | "# Approach using `co.Grid`:\n", 88 | "col_widths = np.ones(num_col-1)\n", 89 | "row_heights = np.ones(num_row-1)\n", 90 | "col_margins = margin_inner * np.ones(num_col-1)\n", 91 | "row_margins = margin_inner * np.ones(num_row-1)\n", 92 | "\n", 93 | "layout.add_constraint(\n", 94 | " co.Grid(shape=axes_shape),\n", 95 | " [f\"{axes_key}/Frame\" for axes_key in axes_keys],\n", 96 | " (col_widths, row_heights, col_margins, row_margins)\n", 97 | ")\n", 98 | "\n", 99 | "# Manual approach:\n", 100 | "\n", 101 | "# Constrain them on a rectilinear grid\n", 102 | "# (the x/y grid lines are still free to 
move left-right and up-down)\n", 103 | "# layout.add_constraint(\n", 104 | "# co.RectilinearGrid(shape=axes_shape),\n", 105 | "# [f\"{axes_key}/Frame\" for axes_key in axes_keys],\n", 106 | "# ()\n", 107 | "# )\n", 108 | "\n", 109 | "# # Because the axes lie on a grid, you only need to set widths/horizontal\n", 110 | "# # margins for a single row or axes, then heights/vertical margins for a single\n", 111 | "# # column of axes.\n", 112 | "\n", 113 | "# for col, width, margin in zip(range(1, num_col), col_widths, col_margins):\n", 114 | "# # Set equal widths for row 0\n", 115 | "# layout.add_constraint(\n", 116 | "# co.RelativeLength(),\n", 117 | "# (f\"Axes[0, {col}]/Frame/Line0\", \"Axes[0, 0]/Frame/Line0\"),\n", 118 | "# (width,)\n", 119 | "# )\n", 120 | "# # Set interior horizontal margin\n", 121 | "# layout.add_constraint(\n", 122 | "# co.OuterMargin(side='right'),\n", 123 | "# (f\"Axes[0, {col-1}]/Frame\", f\"Axes[0, {col}]/Frame\"),\n", 124 | "# (margin,)\n", 125 | "# )\n", 126 | "# for row, height, margin in zip(range(1, num_row), row_heights, row_margins):\n", 127 | "# # Set equal heights for column 0\n", 128 | "# layout.add_constraint(\n", 129 | "# co.RelativeLength(),\n", 130 | "# (f\"Axes[{row}, 0]/Frame/Line1\", \"Axes[0, 0]/Frame/Line1\"),\n", 131 | "# (height,)\n", 132 | "# )\n", 133 | "# # Set interior vertical margin\n", 134 | "# layout.add_constraint(\n", 135 | "# co.OuterMargin(side='bottom'),\n", 136 | "# (f\"Axes[{row-1}, 0]/Frame\", f\"Axes[{row}, 0]/Frame\"),\n", 137 | "# (margin,)\n", 138 | "# )\n", 139 | "\n", 140 | "## Constrain margins around the figure\n", 141 | "\n", 142 | "# Constrain top/bottom margins\n", 143 | "margin_top = 0.2\n", 144 | "margin_bottom = 0.2\n", 145 | "layout.add_constraint(\n", 146 | " co.InnerMargin(side='top'), (\"Axes[0, 0]/Frame\", \"Figure\"), (margin_top,)\n", 147 | ")\n", 148 | "layout.add_constraint(\n", 149 | " co.InnerMargin(side='bottom'), (f\"Axes[{num_row-1}, 0]/Frame\", \"Figure\"), (margin_bottom, 
)\n", 150 | ")\n", 151 | "\n", 152 | "# Constrain left/right margins\n", 153 | "margin_left = 0.2\n", 154 | "margin_right = 0.2\n", 155 | "layout.add_constraint(\n", 156 | " co.InnerMargin(side='left'),\n", 157 | " (\"Axes[0, 0]/Frame\", \"Figure\"), (margin_left,)\n", 158 | ")\n", 159 | "layout.add_constraint(\n", 160 | " co.InnerMargin(side='right'),\n", 161 | " (f\"Axes[0, {num_col-1}]/Frame\", \"Figure\"), (margin_right,)\n", 162 | ")" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "## Solve the layout" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "## Solve the constraints and form the figure/axes layout\n", 179 | "prim_tree_n, info = solver.solve(layout)\n", 180 | "print(info)\n", 181 | "\n", 182 | "# This is the solved layout:\n", 183 | "_fig, _ = ui.figure_prims(prim_tree_n)\n", 184 | "_fig.savefig(\"grid_axes_layout.png\", dpi=300)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "## Plot a grid of images using the layout" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "# Use the layout to plot randomly generated 10x10 pixel images\n", 201 | "fig, axs = lplt.subplots(prim_tree_n)\n", 202 | "\n", 203 | "for key, ax in axs.items():\n", 204 | " ax.set_axis_off()\n", 205 | "\n", 206 | " ax.imshow(np.random.rand(10, 10))\n", 207 | "\n", 208 | "# x = np.linspace(0, 1)\n", 209 | "# axs['Axes1'].plot(x, x**2)\n", 210 | "\n", 211 | "fig.savefig(\"grid_axes.png\")\n", 212 | "# fig" 213 | ] 214 | } 215 | ], 216 | "metadata": { 217 | "kernelspec": { 218 | "display_name": "numerics", 219 | "language": "python", 220 | "name": "python3" 221 | }, 222 | "language_info": { 223 | "codemirror_mode": { 224 | "name": "ipython", 225 | "version": 3 226 | }, 227 | 
"file_extension": ".py", 228 | "mimetype": "text/x-python", 229 | "name": "python", 230 | "nbconvert_exporter": "python", 231 | "pygments_lexer": "ipython3", 232 | "version": "3.12.7" 233 | } 234 | }, 235 | "nbformat": 4, 236 | "nbformat_minor": 2 237 | } 238 | -------------------------------------------------------------------------------- /examples/one_axes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# One axes figure example\n", 8 | "\n", 9 | "This example demonstrates how to create a one axes figure with fixed margins around the axes using both `mpllayout` and pure `matplotlib`." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import numpy as np\n", 19 | "from matplotlib import pyplot as plt\n", 20 | "\n", 21 | "from mpllayout import primitives as pr\n", 22 | "from mpllayout import constraints as co\n", 23 | "from mpllayout import layout as lay\n", 24 | "from mpllayout import solver\n", 25 | "from mpllayout import matplotlibutils as lplt\n" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## Specify and solve a layout using `mpllayout`" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "layout = lay.Layout()\n", 42 | "\n", 43 | "## Create the figure and axes\n", 44 | "layout.add_prim(pr.Quadrilateral(), \"Figure\")\n", 45 | "layout.add_prim(pr.Axes(), \"Axes\")\n", 46 | "\n", 47 | "## Constrain figure and axes quadirlateral to be rectangular\n", 48 | "layout.add_constraint(co.Box(), (\"Figure\",), ())\n", 49 | "layout.add_constraint(co.Box(), (\"Axes/Frame\",), ())\n", 50 | "\n", 51 | "## Constrain the figure size\n", 52 | "fig_width, fig_height = 6, 3\n", 53 | 
"layout.add_constraint(co.XLength(), (\"Figure/Line0\",), (fig_width,))\n", 54 | "layout.add_constraint(co.YLength(), (\"Figure/Line1\",), (fig_height,))\n", 55 | "\n", 56 | "## Constrain 'Axes' margins\n", 57 | "# Constrain left/right margins\n", 58 | "margin_left = 1.1\n", 59 | "margin_right = 1.1\n", 60 | "layout.add_constraint(\n", 61 | " co.InnerMargin(side='left'), (\"Axes/Frame\", \"Figure\"), (margin_left,)\n", 62 | ")\n", 63 | "layout.add_constraint(\n", 64 | " co.InnerMargin(side='right'), (\"Axes/Frame\", \"Figure\"), (margin_right,)\n", 65 | ")\n", 66 | "\n", 67 | "# Constrain top/bottom margins\n", 68 | "margin_top = 1.1\n", 69 | "margin_bottom = 0.5\n", 70 | "layout.add_constraint(\n", 71 | " co.InnerMargin(side='bottom'), (\"Axes/Frame\", \"Figure\"), (margin_bottom,)\n", 72 | ")\n", 73 | "layout.add_constraint(\n", 74 | " co.InnerMargin(side='top'), (\"Axes/Frame\", \"Figure\"), (margin_top,)\n", 75 | ")\n", 76 | "\n", 77 | "## Solve the constraints\n", 78 | "prim_tree_n, info = solver.solve(layout)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "# Used the solved layout to form the figure and axes objects, then plot\n", 88 | "fig, axs = lplt.subplots(prim_tree_n)\n", 89 | "\n", 90 | "ax = axs[\"Axes\"]\n", 91 | "\n", 92 | "x = np.linspace(0, 1)\n", 93 | "ax.plot(x, x**2)\n", 94 | "\n", 95 | "ax.set_xlabel(\"x\")\n", 96 | "ax.set_ylabel(\"y\")" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "## Specify and 'solve' a layout using pure `matplotlib`\n", 104 | "\n", 105 | "There are multiple approaches to create the above layout using pure `matplotlib`; however, some may not have the precise margins specified.\n", 106 | "To create the layout using pure `matplotlib`, the axes position that satisfies the desired margins and figure size must be determined.\n", 107 | "This essentially involves manually 
\"solving\" the system of constraints.\n", 108 | "For the simple case of fixed margins around a single axes, the solution is given by the code below.\n", 109 | "This is the basic process that `mpllayout` performs but for general constraints, which allows for more complicated layouts." 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "## This code determines the axes sizes needed to achieve the desired margins\n", 119 | "\n", 120 | "# Specify the desired figure size and margins\n", 121 | "fig_width, fig_height = 6, 3\n", 122 | "\n", 123 | "margin_left = 1.1\n", 124 | "margin_right = 1.1\n", 125 | "\n", 126 | "margin_top = 1.1\n", 127 | "margin_bottom = 0.5\n", 128 | "\n", 129 | "# The code below \"solves\" the axes position needed to achieve the given margins.\n", 130 | "# Using this approach calculates maintains the same margins even if the figure\n", 131 | "# dimensions are changed.\n", 132 | "# This is essentially what `mpllayout` does.\n", 133 | "coord_botleft = np.array((margin_left, margin_bottom))\n", 134 | "coord_topright = np.array((fig_width - margin_right, fig_height-margin_top))\n", 135 | "\n", 136 | "# Scale the coordinates relative to the figure size since this is how `matplotlib`\n", 137 | "# interprets the axes position\n", 138 | "coord_botleft = coord_botleft / (fig_width, fig_height)\n", 139 | "coord_topright = coord_topright / (fig_width, fig_height)\n", 140 | "\n", 141 | "# Determine the axes width and height\n", 142 | "axes_width, axes_height = (coord_topright - coord_botleft)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "# Use the \"solved\" axes position and figure size to create the figure\n", 152 | "fig = plt.figure(figsize=(fig_width, fig_height))\n", 153 | "ax = fig.add_axes((*coord_botleft, axes_width, axes_height))\n", 154 | "\n", 155 
| "x = np.linspace(0, 1)\n", 156 | "ax.plot(x, x**2)\n", 157 | "\n", 158 | "ax.set_xlabel(\"x\")\n", 159 | "ax.set_ylabel(\"y\")\n" 160 | ] 161 | } 162 | ], 163 | "metadata": { 164 | "kernelspec": { 165 | "display_name": "numerics", 166 | "language": "python", 167 | "name": "python3" 168 | }, 169 | "language_info": { 170 | "codemirror_mode": { 171 | "name": "ipython", 172 | "version": 3 173 | }, 174 | "file_extension": ".py", 175 | "mimetype": "text/x-python", 176 | "name": "python", 177 | "nbconvert_exporter": "python", 178 | "pygments_lexer": "ipython3", 179 | "version": "3.12.7" 180 | } 181 | }, 182 | "nbformat": 4, 183 | "nbformat_minor": 2 184 | } 185 | -------------------------------------------------------------------------------- /examples/profile_solver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Profile runtimes of key functions 3 | """ 4 | 5 | import typing as tp 6 | from typing import Optional 7 | 8 | import cProfile 9 | import pstats 10 | 11 | import numpy as np 12 | 13 | from mpllayout import layout as lay 14 | from mpllayout import primitives as pr 15 | from mpllayout import constraints as co 16 | from mpllayout import containers as cn 17 | from mpllayout import solver 18 | 19 | 20 | def gen_layout(axes_shape: Optional[tuple[int, ...]] = (3, 3)) -> lay.Layout: # Return a profiling layout: a num_row x num_col grid of Box-constrained axes inside a fixed-width "Figure" box. NOTE(review): despite `Optional`, axes_shape must be a 2-tuple; None fails at the unpack on line 37 21 | layout = lay.Layout() 22 | 23 | ## Create an origin point 24 | layout.add_prim(pr.Point([0, 0]), "Origin") 25 | layout.add_constraint(co.Fix(), ("Origin",), (np.array([0, 0]),)) 26 | 27 | ## Create the figure box 28 | layout.add_prim(pr.Quadrilateral(), "Figure") 29 | layout.add_constraint(co.Box(), ("Figure",), ()) 30 | 31 | ## Constrain the figure size and position 32 | fig_width, fig_height = 6, 3 # NOTE(review): fig_height is unused; only the figure width (Line0) is constrained 33 | layout.add_constraint(co.Length(), ("Figure/Line0",), (fig_width,)) 34 | layout.add_constraint(co.Coincident(), ("Figure/Line0/Point0", "Origin"), ()) 35 | 36 | ## Create the axes boxes 37 | num_row, num_col = axes_shape # NOTE(review): recomputed identically on line 45 38 | num_axes =
int(np.prod(axes_shape)) 39 | verts = [[0, 0], [5, 0], [5, 5], [0, 5]] # NOTE(review): unused 40 | for n in range(num_axes): 41 | layout.add_prim(pr.Axes(), f"Axes{n}") 42 | layout.add_constraint(co.Box(), (f"Axes{n}/Frame",), ()) 43 | 44 | ## Constrain the axes in a grid 45 | num_row, num_col = axes_shape 46 | grid_kwargs = ( # presumably (column margins, row margins, column width ratios, row height ratios); confirm against co.Grid 47 | (num_col - 1) * [1 / 16], 48 | (num_row - 1) * [1 / 16], 49 | (num_col - 1) * [1], 50 | (num_row - 1) * [1], 51 | ) 52 | layout.add_constraint( 53 | co.Grid(axes_shape), 54 | tuple(f"Axes{n}/Frame" for n in range(num_axes)), 55 | grid_kwargs 56 | ) 57 | 58 | # Constrain the first axis aspect ratio 59 | layout.add_constraint( 60 | co.RelativeLength(), ("Axes0/Frame/Line0", "Axes0/Frame/Line1"), (2,) 61 | ) 62 | 63 | # Constrain top/bottom margins 64 | margin_top = 1.1 65 | margin_bottom = 0.5 66 | layout.add_constraint( 67 | co.DirectedDistance(), 68 | ("Axes0/Frame/Line1/Point1", "Figure/Line1/Point1"), 69 | (np.array([0, 1]), margin_top) 70 | ) 71 | layout.add_constraint( 72 | co.DirectedDistance(), 73 | (f"Axes{num_axes-1}/Frame/Line1/Point0", "Figure/Line1/Point0"), # assumes row-major Axes{n} numbering, so the last axes lies in the bottom row; confirm against co.Grid 74 | (np.array([0, -1]), margin_bottom) 75 | ) 76 | 77 | # Constrain left/right margins 78 | margin_left = 0.2 79 | margin_right = 0.3 80 | layout.add_constraint( 81 | co.DirectedDistance(), 82 | ("Axes0/Frame/Line0/Point0", "Figure/Line0/Point0"), 83 | (np.array([-1, 0]), margin_left) 84 | ) 85 | layout.add_constraint( 86 | co.DirectedDistance(), 87 | (f"Axes{num_col-1}/Frame/Line1/Point1", "Figure/Line1/Point1"), # assumes row-major numbering, so Axes{num_col-1} ends the top row; confirm against co.Grid 88 | (np.array([1, 0]), margin_right) 89 | ) 90 | return layout 91 | 92 | 93 | if __name__ == "__main__": 94 | layout = gen_layout((12, 12)) 95 | 96 | constraints, constraint_graph, constraint_params = layout.flat_constraints() 97 | # Evaluate once before profiling, presumably so one-time setup cost is excluded from the profile 98 | solver.assem_constraint_residual( 99 | layout.root_prim, constraints, constraint_graph, constraint_params 100 | ) 101 | stmt = "solver.assem_constraint_residual(layout.root_prim, constraints, constraint_graph, constraint_params)" 102 |
cProfile.run(stmt, "profile_wo_jax.prof") 103 | stats = pstats.Stats("profile_wo_jax.prof") 104 | stats.sort_stats("cumtime").print_stats(20) 105 | 106 | import jax 107 | 108 | constraints_jit = [jax.jit(c) for c in constraints] 109 | solver.assem_constraint_residual( # jax.jit compiles on the first call, so this call presumably keeps compilation out of the profiled run below 110 | layout.root_prim, 111 | constraints_jit, 112 | constraint_graph, 113 | constraint_params 114 | ) 115 | stmt = "solver.assem_constraint_residual(layout.root_prim, constraints_jit, constraint_graph, constraint_params)" 116 | cProfile.run(stmt, "profile_w_jax.prof") 117 | stats = pstats.Stats("profile_w_jax.prof") 118 | stats.sort_stats("cumtime").print_stats(20) 119 | -------------------------------------------------------------------------------- /examples/ten_simple_rules_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "869c5e8b-6a10-48dd-855c-9aa540ec082a", 6 | "metadata": {}, 7 | "source": [ 8 | "# \"Ten Simple Rules for Better Figures\" example\n", 9 | "\n", 10 | "`mpllayout` models axes and other elements in figures as geometric primitives which can be constrained relative to each other. \n", 11 | "This gives a flexible way to precisely position figure elements.\n", 12 | "\n", 13 | "The following demo produces Figure 1 from the paper \"Ten Simple Rules for Better Figures\" (Rougier NP, Droettboom M, Bourne PE (2014) Ten Simple Rules for Better Figures. PLOS Computational Biology 10(9): e1003833. https://doi.org/10.1371/journal.pcbi.1003833).\n", 14 | "This figure is itself, a remake of one originally published in the [New York Times](https://archive.nytimes.com/www.nytimes.com/imagepages/2007/07/29/health/29cancer.graph.web.html?action=click&module=RelatedCoverage&pgtype=Article®ion=Footer).\n", 15 | "\n", 16 | "The below two sections illustrate how to create the above figure with a `Grid` type constraint as well as more basic constraints."
17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "97be0bae-5121-4a0f-bbcd-aab3a5bea0bd", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "# Import the relevant packages\n", 27 | "\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "import numpy as np\n", 30 | "\n", 31 | "from mpllayout.solver import solve\n", 32 | "from mpllayout.layout import Layout, update_layout_constraints\n", 33 | "from mpllayout import primitives as pr\n", 34 | "from mpllayout import constraints as co\n", 35 | "from mpllayout.matplotlibutils import subplots, update_subplots, find_axis_position\n", 36 | "from mpllayout import ui" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "id": "5c3a8af3", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "def figure_layout(layout):\n", 47 | " \"\"\"\n", 48 | " Return a figure of the layout\n", 49 | " \"\"\"\n", 50 | " prims_n, solve_info = solve(layout)\n", 51 | " return ui.figure_prims(prims_n)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "id": "a1fc5432", 57 | "metadata": {}, 58 | "source": [ 59 | "## Specify the `layout`\n", 60 | "\n", 61 | "The `layout` is a collection of:\n", 62 | "- geometric primitives to represent figure elements\n", 63 | "- and constraints which position the figure elements\n" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "id": "a171b105", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "layout = Layout()" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "id": "7c231206", 79 | "metadata": {}, 80 | "source": [ 81 | "### Add geometric primitives\n", 82 | "\n", 83 | "Represent the `mpl.Figure` by a `Quadrilateral` and `mpl.Axes` by an `Axes` primitive.\n", 84 | "The `Axes` primitive contains a `Quadrilateral` to represent the frame, and optionally, a `Quadrilateral` and `Point` to represent the x axis and y axis." 
85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "id": "08134366", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "# First, add a box called \"Figure\" to the layout\n", 95 | "layout.add_prim(pr.Quadrilateral(), \"Figure\")\n", 96 | "\n", 97 | "# Then add \"axes\" primitives to represent left, middle, and right axes\n", 98 | "# Axes can contain `Quadrilaterals` and `Point` primitives to represent the\n", 99 | "# axes frame, x/y axis, and axis labels\n", 100 | "layout.add_prim(pr.Axes(xaxis=True, yaxis=True), \"AxesLeft\")\n", 101 | "layout.add_prim(pr.Axes(), \"AxesMid\")\n", 102 | "layout.add_prim(pr.Axes(xaxis=True, yaxis=True), \"AxesRight\")" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "id": "cc84cd74", 108 | "metadata": {}, 109 | "source": [ 110 | "### Add geometric constraints" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "id": "3de3df32", 116 | "metadata": {}, 117 | "source": [ 118 | "#### Make all `Quadrilateral`s rectangular\n", 119 | "\n", 120 | "Unlike the figure and axes frames in matplotlib, quadrilaterals in `mpllayout` are not rectangular by default, so they must be constrained."
121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "id": "53d68f61", 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "# `co.Box` forces quadrilateral sides to be vertical and tops/bottoms to be horizontal\n", 131 | "# It has no parameters, so the last argument is an empty tuple\n", 132 | "layout.add_constraint(co.Box(), (\"Figure\",), ())\n", 133 | "\n", 134 | "# \"AxesMid\" only has a frame (no x/y axis)\n", 135 | "layout.add_constraint(co.Box(), (\"AxesMid/Frame\",), ())\n", 136 | "\n", 137 | "# Here we constrain all child quads of the left and right axes to be boxes\n", 138 | "for axes_key in [\"AxesLeft\", \"AxesRight\"]:\n", 139 | " for quad_key in [\"Frame\", \"XAxis\", \"YAxis\"]:\n", 140 | " layout.add_constraint(co.Box(), (f\"{axes_key}/{quad_key}\",), ())" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "id": "5262250b", 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "# This plots the created geometry\n", 151 | "# Note that by default all the quads are unit squares\n", 152 | "figure_layout(layout)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "id": "6e2224b7", 158 | "metadata": {}, 159 | "source": [ 160 | "#### Fix the Figure dimensions and position\n", 161 | "\n", 162 | "Set the figure width/height and fix the bottom left point to the origin" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "id": "3e9b9643", 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "## Set the figure dimensions\n", 173 | "\n", 174 | "# Fix the bottom left point of 'Figure' to the origin\n", 175 | "layout.add_constraint(co.Fix(), (\"Figure/Line0/Point0\",), (np.array([0, 0]),))\n", 176 | "\n", 177 | "# Set the 'Figure' width and height\n", 178 | "fig_width, fig_height = (12, 7)\n", 179 | "layout.add_constraint(co.Length(), (\"Figure/Line1\",), (fig_height,))\n", 180 |
"layout.add_constraint(co.Length(), (\"Figure/Line0\",), (fig_width,))" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "id": "9d25658e", 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "# Note the figure quadrilateral is 12\" by 7\"\n", 191 | "# The remaining axes are unit squares since they haven't been constrained yet\n", 192 | "fig, ax = figure_layout(layout)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "id": "56df289b", 198 | "metadata": {}, 199 | "source": [ 200 | "#### Constrain the `Axes` to a 1 by 3 rectilinear grid\n", 201 | "\n", 202 | "We can force the left, middle, and right axes to align on 1 by 3 grid and set their relative widths.\n", 203 | "\n" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "id": "5e1eab5c", 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "# Align the axes on a 1x3 rectilinear grid\n", 214 | "shape = (1, 3)\n", 215 | "layout.add_constraint(\n", 216 | " co.RectilinearGrid(shape),\n", 217 | " (\"AxesLeft/Frame\", \"AxesMid/Frame\", \"AxesRight/Frame\"),\n", 218 | " ()\n", 219 | ")" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "id": "6d1b513b", 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "# Set zeros margins between left/right axes and the middle axes\n", 230 | "layout.add_constraint(co.OuterMargin(side='left'), (\"AxesMid/Frame\", \"AxesLeft/Frame\"), (0,))\n", 231 | "layout.add_constraint(co.OuterMargin(side='right'), (\"AxesMid/Frame\", \"AxesRight/Frame\"), (0,))\n", 232 | "\n", 233 | "# Make the left/right axes the same width and the central axes 0.5 that width\n", 234 | "layout.add_constraint(\n", 235 | " co.RelativeLength(), (\"AxesRight/Frame/Line0\", \"AxesLeft/Frame/Line0\"), (1.0,)\n", 236 | ")\n", 237 | "layout.add_constraint(\n", 238 | " co.RelativeLength(), (\"AxesMid/Frame/Line0\", \"AxesLeft/Frame/Line0\"), 
(0.5,)\n", 239 | ")" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "id": "b007f45b", 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "# Note the 3 axes are now aligned\n", 250 | "# It's difficult to see because the left and right axes also have x/y axes that are shown\n", 251 | "fig, ax = figure_layout(layout)" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "id": "129f4a98", 257 | "metadata": {}, 258 | "source": [ 259 | "#### Position the x-axis and y-axis for left and right axes" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "id": "b4730103", 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "# These constraints fix the x/y axis to one side of the axes\n", 270 | "# When creating the figure from a layout, these axis positions will be inherited\n", 271 | "layout.add_constraint(co.PositionXAxis(side='top'), (\"AxesLeft\", ), ())\n", 272 | "layout.add_constraint(co.PositionYAxis(side='right'), (\"AxesLeft\", ), ())\n", 273 | "\n", 274 | "layout.add_constraint(co.PositionXAxis(side='top'), (\"AxesRight\", ), ())\n", 275 | "layout.add_constraint(co.PositionYAxis(side='left'), (\"AxesRight\", ), ())\n" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "id": "87f32652", 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "# These constraints set the variable width of the y-axis and variable height of the x-axis\n", 286 | "# The axis dimensions are variable since they depend on the size of any tick labels\n", 287 | "# Axis dimensions can be updated using `update_layout_constraints` after\n", 288 | "# axis text has been generated\n", 289 | "# Note that axis labels aren't included in the size of the axis!\n", 290 | "layout.add_constraint(\n", 291 | " co.XAxisThickness(), (f\"AxesLeft/XAxis\",), (None,)\n", 292 | ")\n", 293 | "layout.add_constraint(\n", 294 | " co.YAxisThickness(),
(f\"AxesLeft/YAxis\",), (None,)\n", 295 | ")\n", 296 | "layout.add_constraint(\n", 297 | " co.XAxisThickness(), (f\"AxesRight/XAxis\",), (None,)\n", 298 | ")\n", 299 | "layout.add_constraint(\n", 300 | " co.YAxisThickness(), (f\"AxesRight/YAxis\",), (None,)\n", 301 | ")" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "id": "cfd1ded6", 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "# Note the x axis is now stuck to the top of each axes\n", 312 | "figure_layout(layout)" 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "id": "32a9a09b", 318 | "metadata": {}, 319 | "source": [ 320 | "#### Set margins\n", 321 | "\n", 322 | "Note that earlier we never set the absolute width of the axes; to ensure nice whitespace we can specify margins to indirectly set the axes dimensions." 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "id": "9721f9a5", 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "# Set the top/bottom margins\n", 333 | "# The top margin is measured above the x-axis bounding box, which ensures the text won't be cut off by the figure edge\n", 334 | "margin_top, margin_bottom = (0.5, 0.5)\n", 335 | "\n", 336 | "# The `InnerMargin` constraint sets the gap between an\n", 337 | "# inner quad (\"AxesRight/XAxis\") and an outer quad (\"Figure\")\n", 338 | "layout.add_constraint(\n", 339 | " co.InnerMargin(side=\"top\"), (\"AxesRight/XAxis\", \"Figure\"), (margin_top,)\n", 340 | ")\n", 341 | "\n", 342 | "layout.add_constraint(\n", 343 | " co.InnerMargin(side=\"bottom\"), (\"AxesRight/Frame\", \"Figure\"), (margin_bottom,)\n", 344 | ")\n", 345 | "\n", 346 | "# Set the left/right margins\n", 347 | "margin_left, margin_right = (0.5, 0.5)\n", 348 | "\n", 349 | "layout.add_constraint(\n", 350 | " co.InnerMargin(side='left'), (\"AxesLeft/Frame\", \"Figure\"), (margin_left,)\n", 351 | ")\n", 352 | "\n", 353 | "layout.add_constraint(\n", 354 | "
co.InnerMargin(side='right'), (\"AxesRight/Frame\", \"Figure\"), (margin_right,)\n", 355 | ")" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": null, 361 | "id": "453aa970", 362 | "metadata": {}, 363 | "outputs": [], 364 | "source": [ 365 | "# Now the margins are all constrained!\n", 366 | "# The grid arrangement of the left/middle/right axes is clearer since the axes have been moved apart\n", 367 | "fig, ax = figure_layout(layout)" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "id": "0866e7b2-c69c-43d1-9bc6-65de4f26ff9a", 373 | "metadata": {}, 374 | "source": [ 375 | "## Solve the layout\n", 376 | "\n", 377 | "We can solve the `layout` to determine a set of primitives that satisfy the constraints.\n", 378 | "The solved primitives are then used to generate matplotlib figure and axes objects that reflect the layout.\n", 379 | "\n", 380 | "This is nice because the figure design and arrangement is separated from the plotting of data." 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "id": "6a879951-73f3-4d43-998a-564647a94f44", 387 | "metadata": {}, 388 | "outputs": [], 389 | "source": [ 390 | "prim_tree_n, solve_info = solve(layout)" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": null, 396 | "id": "35f187c0", 397 | "metadata": {}, 398 | "outputs": [], 399 | "source": [ 400 | "# This is the layout before the plot has been generated.\n", 401 | "# Note that the x axis thicknesses are 0 since no x tick labels exist.\n", 402 | "# After plotting the data, the axis thicknesses can be updated to account for this.\n", 403 | "fig, ax = figure_layout(layout)" 404 | ] 405 | }, 406 | { 407 | "cell_type": "markdown", 408 | "id": "b678d9f3", 409 | "metadata": {}, 410 | "source": [ 411 | "## Plot the \"Ten Simple Rules for Better Figures\" dataset using the layout\n", 412 | "\n", 413 | "We can use the generated figure and axes to plot data now." 
414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "id": "5a3fc270-4aac-4abe-943d-6c90e9c7f88e", 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [ 423 | "# The data below is approximated from a New York Times article ()\n", 424 | "# and is adapted from the figure-1.py file available at (https://github.com/rougier/ten-rules)\n", 425 | "\n", 426 | "diseases = [\n", 427 | " \"Kidney Cancer\",\n", 428 | " \"Bladder Cancer\",\n", 429 | " \"Esophageal Cancer\",\n", 430 | " \"Ovarian Cancer\",\n", 431 | " \"Liver Cancer\",\n", 432 | " \"Non-Hodgkin's\\nlymphoma\",\n", 433 | " \"Leukemia\",\n", 434 | " \"Prostate Cancer\",\n", 435 | " \"Pancreatic Cancer\",\n", 436 | " \"Breast Cancer\",\n", 437 | " \"Colorectal Cancer\",\n", 438 | " \"Lung Cancer\",\n", 439 | "]\n", 440 | "men_deaths = [\n", 441 | " 10000,\n", 442 | " 12000,\n", 443 | " 13000,\n", 444 | " 0,\n", 445 | " 14000,\n", 446 | " 12000,\n", 447 | " 16000,\n", 448 | " 25000,\n", 449 | " 20000,\n", 450 | " 500,\n", 451 | " 25000,\n", 452 | " 80000,\n", 453 | "]\n", 454 | "men_cases = [\n", 455 | " 30000,\n", 456 | " 50000,\n", 457 | " 13000,\n", 458 | " 0,\n", 459 | " 16000,\n", 460 | " 30000,\n", 461 | " 25000,\n", 462 | " 220000,\n", 463 | " 22000,\n", 464 | " 600,\n", 465 | " 55000,\n", 466 | " 115000,\n", 467 | "]\n", 468 | "women_deaths = [\n", 469 | " 6000,\n", 470 | " 5500,\n", 471 | " 5000,\n", 472 | " 20000,\n", 473 | " 9000,\n", 474 | " 12000,\n", 475 | " 13000,\n", 476 | " 0,\n", 477 | " 19000,\n", 478 | " 40000,\n", 479 | " 30000,\n", 480 | " 70000,\n", 481 | "]\n", 482 | "women_cases = [\n", 483 | " 20000,\n", 484 | " 18000,\n", 485 | " 5000,\n", 486 | " 25000,\n", 487 | " 9000,\n", 488 | " 29000,\n", 489 | " 24000,\n", 490 | " 0,\n", 491 | " 21000,\n", 492 | " 160000,\n", 493 | " 55000,\n", 494 | " 97000,\n", 495 | "]\n", 496 | "\n", 497 | "y_diseases = np.arange(len(diseases))" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | 
"execution_count": null, 503 | "id": "e1dd6749-4b65-4617-9152-e56c28f16b04", 504 | "metadata": {}, 505 | "outputs": [], 506 | "source": [ 507 | "def format_axes(ax):\n", 508 | " \"\"\"\n", 509 | " Apply the Axes formatting used in \"Ten Simple Rules\"\n", 510 | " \"\"\"\n", 511 | " if not ax.xaxis.get_inverted():\n", 512 | " origin_side = \"left\"\n", 513 | " far_side = \"right\"\n", 514 | " else:\n", 515 | " origin_side = \"right\"\n", 516 | " far_side = \"left\"\n", 517 | "\n", 518 | " ax.spines[far_side].set_color(\"none\")\n", 519 | " ax.spines[origin_side].set_zorder(10)\n", 520 | " ax.spines[\"bottom\"].set_color(\"none\")\n", 521 | "\n", 522 | " # ax.xaxis.set_ticks_position(\"top\")\n", 523 | "\n", 524 | " # ax.yaxis.set_ticks_position(origin_side)\n", 525 | " ax.yaxis.set_ticks(y_diseases, labels=[\"\"] * len(y_diseases))\n", 526 | "\n", 527 | " ax.spines[\"top\"].set_position((\"data\", len(diseases) + 0.25))\n", 528 | " ax.spines[\"top\"].set_color(\"w\")" 529 | ] 530 | }, 531 | { 532 | "cell_type": "markdown", 533 | "id": "d7611997", 534 | "metadata": {}, 535 | "source": [ 536 | "#### Plot with the initial layout\n", 537 | "\n", 538 | "The x and y axis thicknesses are not known since the ticks and label sizes cannot be known apriori.\n", 539 | "To account for this, simply plot the figure with the current layout (this may result in cut-off labels). \n", 540 | "You can then update the layout with the plotted axes to account for the axis thicknesses." 
541 | ] 542 | }, 543 | { 544 | "cell_type": "code", 545 | "execution_count": null, 546 | "id": "4da120e8-d200-4ad1-9714-49adf76afa1e", 547 | "metadata": {}, 548 | "outputs": [], 549 | "source": [ 550 | "## Here we plot the actual NYT figure from the article\n", 551 | "\n", 552 | "# The `subplots` function uses the solved primitives to create figure and axes objects with the determined sizes\n", 553 | "# `axs` is a dictionary with keys matching the axes names\n", 554 | "fig, axs = subplots(prim_tree_n)\n", 555 | "\n", 556 | "for ax in axs.values():\n", 557 | " ax.set_xlim(0, 200000)\n", 558 | "\n", 559 | "# Plot the men/womens data\n", 560 | "axs[\"AxesLeft\"].barh(y_diseases, women_cases, height=0.8, fc=\"red\", alpha=0.1)\n", 561 | "axs[\"AxesLeft\"].barh(y_diseases, women_deaths, height=0.55, fc=\"red\", alpha=0.5)\n", 562 | "axs[\"AxesLeft\"].xaxis.set_inverted(True)\n", 563 | "\n", 564 | "axs[\"AxesRight\"].barh(y_diseases, men_cases, height=0.8, fc=\"blue\", alpha=0.1)\n", 565 | "axs[\"AxesRight\"].barh(y_diseases, men_deaths, height=0.55, fc=\"blue\", alpha=0.5)\n", 566 | "\n", 567 | "axs_labels = [\"AxesLeft\", \"AxesRight\"]\n", 568 | "axs_categories = [\"women\", \"men\"]\n", 569 | "category_to_ax = {\n", 570 | " category: axs[key]\n", 571 | " for category, key in zip(axs_categories, axs_labels)\n", 572 | "}\n", 573 | "for category, ax in category_to_ax.items():\n", 574 | " format_axes(ax)\n", 575 | " ax.set_xticks(\n", 576 | " [0, 50000, 100000, 150000, 200000],\n", 577 | " [category.upper(), \"50,000\", \"100,000\", \"150,000\", \"200,000\"],\n", 578 | " )\n", 579 | " ax.grid(which=\"major\", axis=\"x\", color=\"white\")\n", 580 | " ax.get_xticklabels()[0].set_weight(\"bold\")\n", 581 | "\n", 582 | "# Add ylabels to 'AxesMid'\n", 583 | "axs[\"AxesMid\"].set_axis_off()\n", 584 | "axs[\"AxesMid\"].set_ylim(axs[\"AxesLeft\"].get_ylim())\n", 585 | "axs[\"AxesMid\"].set_xlim(-1, 1)\n", 586 | "\n", 587 | "for y, disease_name in zip(y_diseases, diseases):\n", 
588 | " axs[\"AxesMid\"].text(0, y, disease_name, ha=\"center\", va=\"center\")\n", 589 | "\n", 590 | "# Add the \"NEW CASES\" and \"DEATHS\" annotations\n", 591 | "# Devil hides in the details...\n", 592 | "arrowprops = {\"arrowstyle\": \"-\", \"connectionstyle\": \"angle,angleA=0,angleB=90,rad=0\"}\n", 593 | "\n", 594 | "x = women_cases[-1]\n", 595 | "y = y_diseases[-1]\n", 596 | "axs[\"AxesLeft\"].annotate(\n", 597 | " \"NEW CASES\",\n", 598 | " xy=(0.9 * x, y),\n", 599 | " xycoords=\"data\",\n", 600 | " ha=\"right\",\n", 601 | " fontsize=10,\n", 602 | " xytext=(-40, -3),\n", 603 | " textcoords=\"offset points\",\n", 604 | " arrowprops=arrowprops,\n", 605 | ")\n", 606 | "\n", 607 | "x = women_deaths[-1]\n", 608 | "axs[\"AxesLeft\"].annotate(\n", 609 | " \"DEATHS\",\n", 610 | " xy=(0.85 * x, y),\n", 611 | " xycoords=\"data\",\n", 612 | " ha=\"right\",\n", 613 | " fontsize=10,\n", 614 | " xytext=(-50, -25),\n", 615 | " textcoords=\"offset points\",\n", 616 | " arrowprops=arrowprops,\n", 617 | ")\n", 618 | "\n", 619 | "x = men_cases[-1]\n", 620 | "axs[\"AxesRight\"].annotate(\n", 621 | " \"NEW CASES\",\n", 622 | " xy=(0.9 * x, y),\n", 623 | " xycoords=\"data\",\n", 624 | " ha=\"left\",\n", 625 | " fontsize=10,\n", 626 | " xytext=(+40, -3),\n", 627 | " textcoords=\"offset points\",\n", 628 | " arrowprops=arrowprops,\n", 629 | ")\n", 630 | "\n", 631 | "x = men_deaths[-1]\n", 632 | "axs[\"AxesRight\"].annotate(\n", 633 | " \"DEATHS\",\n", 634 | " xy=(0.9 * x, y),\n", 635 | " xycoords=\"data\",\n", 636 | " ha=\"left\",\n", 637 | " fontsize=10,\n", 638 | " xytext=(+50, -25),\n", 639 | " textcoords=\"offset points\",\n", 640 | " arrowprops=arrowprops,\n", 641 | ")\n", 642 | "\n", 643 | "# Add the caption text\n", 644 | "axs[\"AxesLeft\"].text(\n", 645 | " 165000, 8.2, \"Leading Causes\\nOf Cancer Deaths\", fontsize=18, va=\"top\"\n", 646 | ")\n", 647 | "axs[\"AxesLeft\"].text(\n", 648 | " 165000,\n", 649 | " 7,\n", 650 | " \"In 2007, there were more\\n\"\n", 651 | " 
\"than 1.4 million new cases\\n\"\n", 652 | " \"of cancer in the United States.\",\n", 653 | " va=\"top\",\n", 654 | " fontsize=10,\n", 655 | ")\n", 656 | "\n", 657 | "fig.savefig(\"ten_simple_rules_demo_no_axis_thickness.svg\")" 658 | ] 659 | }, 660 | { 661 | "cell_type": "markdown", 662 | "id": "9e5c12c7", 663 | "metadata": {}, 664 | "source": [ 665 | "#### Update x and y axis thicknesses in the plot\n", 666 | "\n", 667 | "Now that data is plotted, the x/y axes have tick labels that change their thickness.\n", 668 | "You can use `update_layout_constraints` and the `XAxisThickness` and `YAxisThickness` constraints to update the axis sizes and adjust the figure layout." 669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "execution_count": null, 674 | "id": "e660bb3a", 675 | "metadata": {}, 676 | "outputs": [], 677 | "source": [ 678 | "# Update bounding boxes for the x/y axes now that text has been inserted\n", 679 | "# This will update the layout of axes\n", 680 | "layout = update_layout_constraints(layout, axs)\n", 681 | "\n", 682 | "prim_tree_n, info = solve(layout)\n", 683 | "# This updates the original `fig` and `axs` plot with the new layout\n", 684 | "fig, axs = update_subplots(prim_tree_n, \"Figure\", fig, axs)\n", 685 | "fig.savefig(\"ten_simple_rules_demo.svg\")\n", 686 | "\n", 687 | "fig" 688 | ] 689 | }, 690 | { 691 | "cell_type": "code", 692 | "execution_count": null, 693 | "id": "f12e2ddf", 694 | "metadata": {}, 695 | "outputs": [], 696 | "source": [ 697 | "# If you plot the layout after axis sizes are updated, you can see the altered dimensions!\n", 698 | "fig, ax = figure_layout(layout)" 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": null, 704 | "id": "f983b76f", 705 | "metadata": {}, 706 | "outputs": [], 707 | "source": [] 708 | } 709 | ], 710 | "metadata": { 711 | "kernelspec": { 712 | "display_name": "numerics", 713 | "language": "python", 714 | "name": "python3" 715 | }, 716 | "language_info": { 717 |
"codemirror_mode": { 718 | "name": "ipython", 719 | "version": 3 720 | }, 721 | "file_extension": ".py", 722 | "mimetype": "text/x-python", 723 | "name": "python", 724 | "nbconvert_exporter": "python", 725 | "pygments_lexer": "ipython3", 726 | "version": "3.12.7" 727 | } 728 | }, 729 | "nbformat": 4, 730 | "nbformat_minor": 5 731 | } 732 | -------------------------------------------------------------------------------- /examples/tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Tutorial\n", 8 | "\n", 9 | "This tutorial demonstrates the main workflow for using `mpllayout`.\n", 10 | "\n", 11 | "The workflow involves just a few high-level steps:\n", 12 | "1. Create a layout object to store the layout, `layout = Layout()`\n", 13 | "2. Add geometric primitives to `layout` using `layout.add_prim`. These primitives represent figure elements.\n", 14 | "3. Add geometric constraints to `layout` using `layout.add_constraint` to constrain the primitives.\n", 15 | "4. Solve the constrained layout of primitives using `constrained_prims, solve_info = solve(layout)`\n", 16 | "5. Generate a figure and axes to plot in using `fig, axs = subplots(constrained_prims)`\n", 17 | "\n", 18 | "The generated `fig` and `axs` will reflect the constrained layout." 
19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import numpy as np\n", 28 | "\n", 29 | "import matplotlib as mpl\n", 30 | "from matplotlib import pyplot as plt\n", 31 | "\n", 32 | "# `layout` contains the `Layout` class and related functions\n", 33 | "from mpllayout import layout as lay\n", 34 | "# `primitives` contains primitives and `constraints` constraints\n", 35 | "from mpllayout import primitives as pr\n", 36 | "from mpllayout import constraints as co\n", 37 | "# `solve` is used to solve the constrained layout\n", 38 | "from mpllayout.solver import solve\n", 39 | "\n", 40 | "# `subplots` and `update_subplots` are used to create matplotlib figure and\n", 41 | "# axes objects from geometric primitives\n", 42 | "from mpllayout.matplotlibutils import subplots, update_subplots\n", 43 | "\n", 44 | "# `ui` contains functions to visualize primitives\n", 45 | "from mpllayout import ui" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "## Step 1: Create the layout" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "# Create the layout to store constraints and primitives\n", 62 | "layout = lay.Layout()" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "## Step 2: Add geometric primitives\n", 70 | "\n", 71 | "Geometric primitives represent geometry and are defined in `mpllayout.primitives`.\n", 72 | "Each primitive consists of a parameter vector (`primitive.value`) with child primitives (`primitive[\"child_key\"]`).\n", 73 | "For example:\n", 74 | "\n", 75 | "* `Point` represents a point and has a parameter vector containing its coordinates with no child primitives\n", 76 | "* `Line` represents a straight line segment, has no parameter vector, and contains two points representing the start point 
(`line[\"Point0\"]`) and end point (`line[\"Point1\"]`)\n", 77 | "* Other primitives are documented in the module\n", 78 | "\n", 79 | "Geometric primitives are added using the call\n", 80 | "`layout.add_prim(primitive, key)`\n", 81 | "where `primitive` is a geometric primitive object and `key` is a string used to identify it." 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "# A `Quadrilateral` is a 4 sided polygon which can be used to represent the figure box.\n", 91 | "# Naming the quad \"Figure\" will cause the `subplots` command to create a figure of the same size.\n", 92 | "layout.add_prim(pr.Quadrilateral(), \"Figure\")\n", 93 | "\n", 94 | "# The `Axes` primitive is a collection of quadrilaterals and points used to represent an axes.\n", 95 | "# The child primitives of `Axes` are\n", 96 | "# - \"Frame\": a `Quadrilateral` representing the plotting area of the axes\n", 97 | "# - \"XAxis\": a `Quadrilateral` bounding x-axis ticks and tick labels\n", 98 | "# - \"XAxisLabel\": a `Point` for the x-axis label text anchor\n", 99 | "# - \"YAxis\": a `Quadrilateral` bounding y-axis ticks and tick labels\n", 100 | "# - \"YAxisLabel\": a `Point` for the y-axis label text anchor\n", 101 | "# The x/y axis can be optionally included by kwargs as seen below\n", 102 | "layout.add_prim(pr.Axes(xaxis=True, yaxis=True), \"Axes\")" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "## Step 3: Add geometric constraints\n", 110 | "\n", 111 | "Geometric constraints represent constraints or conditions on primitives and are defined in `mpllayout.constraints`.\n", 112 | "Every constraint has a method representing the condition\n", 113 | "`Constraint.assem_res(prims, **kwargs)`, \n", 114 | "where `prims` is a tuple of primitives the constraint applies to and `**kwargs` are constraint specific parameters.\n", 115 | "The constraint is satisfied 
when `assem_res` is 0.\n", 116 | "\n", 117 | "For example:\n", 118 | "\n", 119 | "* `Coincident().assem_res((pointa, pointb))` represents the coincidence error between the points `pointa` and `pointb` (no parameters are needed).\n", 120 | "* `Length().assem_res((linea,), length=7)` represents the length error for `linea` compared to the desired length of 7.\n", 121 | "* Other constraints are documented in the module\n", 122 | "\n", 123 | "Geometric constraints are added to a layout using the call\n", 124 | "`layout.add_constraint(constraint, prim_keys, constraint_params)`,\n", 125 | "where \n", 126 | "\n", 127 | "* `constraint` is the geometric constraint object\n", 128 | "* `prim_keys` is a tuple of primitive keys representing the primitives to constrain\n", 129 | "* `constraint_params` is a dictionary or tuple of constraint specific parameters.\n", 130 | "\n", 131 | "`prim_keys` can recursively indicate primitives using slash-separated keys.\n", 132 | "For example, the tuple `(\"Figure/Line0/Point0\", \"Axes/Frame/Line0\")` represents point 0 of the figure quadrilateral (the bottom left) and line 0 of the axes frame quadrilateral (the bottom line).\n", 133 | "\n", 134 | "`constraint_params` represents the `**kwargs` of `assem_res`.\n", 135 | "This can be either a dictionary or tuple representing the kwargs.\n", 136 | "\n", 137 | "The next few sections add sets of constraints and plot the resulting constrained layout to illustrate their effect.\n" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "### Make all quadrilaterals rectangular\n", 145 | "\n", 146 | "`Quadrilateral`s have unknown coordinates for each of their 4 corners.\n", 147 | "To make them rectangular boxes like axes and figures, apply the `Box` constraint."
148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "# NOTE: This step is needed because `Quadrilateral`s don't have to be rectangular\n", 157 | "layout.add_constraint(co.Box(), (\"Figure\",), ())\n", 158 | "layout.add_constraint(co.Box(), (\"Axes/Frame\",), ())\n", 159 | "layout.add_constraint(co.Box(), (\"Axes/XAxis\",), ())\n", 160 | "layout.add_constraint(co.Box(), (\"Axes/YAxis\",), ())" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "# The layout currently looks like:\n", 170 | "_fig, _ = ui.figure_prims(solve(layout)[0], fig_size=(5, 5))" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "### Fix the figure position and size" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "# Fix the figure bottom left to the origin\n", 187 | "layout.add_constraint(co.Fix(), (\"Figure/Line0/Point0\",), (np.array([0, 0]),))\n", 188 | "\n", 189 | "# Fix the figure width and height\n", 190 | "fig_width, fig_height = 6, 3\n", 191 | "layout.add_constraint(co.XLength(), (\"Figure/Line0\",), (fig_width,))\n", 192 | "layout.add_constraint(co.YLength(), (\"Figure/Line1\",), (fig_height,))" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "# The layout currently looks like:\n", 202 | "_fig, _ = ui.figure_prims(solve(layout)[0], fig_size=(5, 5))" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "### Position the x and y axes and axis labels" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [
218 | "## Position the x axis on top and the y axis on the right\n", 219 | "# When creating axes from the primitives, `subplots` will detect axis\n", 220 | "# positions and set axis properties to reflect them.\n", 221 | "layout.add_constraint(co.PositionXAxis(side='top'), (\"Axes\", ), ())\n", 222 | "layout.add_constraint(co.PositionYAxis(side='right'), (\"Axes\", ), ())\n", 223 | "\n", 224 | "# Link x/y axis width/height to axis sizes in matplotlib.\n", 225 | "# Axis sizes change depending on the size of their tick labels so the\n", 226 | "# axis width/height must be linked to matplotlib and updated from plot\n", 227 | "# elements.\n", 228 | "layout.add_constraint(\n", 229 | " co.XAxisThickness(), (\"Axes/XAxis\",), (None,),\n", 230 | ")\n", 231 | "layout.add_constraint(\n", 232 | " co.YAxisThickness(), (\"Axes/YAxis\",), (None,),\n", 233 | ")\n", 234 | "\n", 235 | "## Position the x/y axis label text anchors\n", 236 | "# When creating axes from the primitives, `subplots` will detect these and set\n", 237 | "# their locations\n", 238 | "on_line = co.RelativePointOnLineDistance()\n", 239 | "to_line = co.PointToLineDistance()\n", 240 | "\n", 241 | "## Pad the x/y axis label from the axis bbox\n", 242 | "pad = 1/16\n", 243 | "layout.add_constraint(to_line, (\"Axes/XAxisLabel\", \"Axes/XAxis/Line2\"), (True, pad))\n", 244 | "layout.add_constraint(to_line, (\"Axes/YAxisLabel\", \"Axes/YAxis/Line1\"), (True, pad))\n", 245 | "\n", 246 | "## Center the axis labels halfway along the axes width/height\n", 247 | "layout.add_constraint(co.PositionXAxisLabel(), (\"Axes\",), (0.5,))\n", 248 | "layout.add_constraint(co.PositionYAxisLabel(), (\"Axes\",), (0.5,))" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "# The layout currently looks like:\n", 258 | "_fig, _ = ui.figure_prims(solve(layout)[0], fig_size=(5, 5))" 259 | ] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 |
"metadata": {}, 264 | "source": [ 265 | "### Set margins between the axes and figure" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "## Constrain margins around the axes to the figure\n", 275 | "# Constrain left/right margins\n", 276 | "margin_left = 0.1\n", 277 | "margin_right = 1/4\n", 278 | "\n", 279 | "layout.add_constraint(\n", 280 | " co.InnerMargin(side='left'), (\"Axes/Frame\", \"Figure\"), (margin_left,)\n", 281 | ")\n", 282 | "layout.add_constraint(\n", 283 | " co.InnerMargin(side='right'), (\"Axes/YAxis\", \"Figure\"), (margin_right,)\n", 284 | ")\n", 285 | "\n", 286 | "# Constrain top/bottom margins\n", 287 | "margin_top = 1/4\n", 288 | "margin_bottom = 0.1\n", 289 | "layout.add_constraint(\n", 290 | " co.InnerMargin(side='bottom'), (\"Axes/Frame\", \"Figure\"), (margin_bottom,)\n", 291 | ")\n", 292 | "layout.add_constraint(\n", 293 | " co.InnerMargin(side='top'), (\"Axes/XAxis\", \"Figure\"), (margin_top,)\n", 294 | ")\n" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "# The layout currently looks like:\n", 304 | "_fig, _ = ui.figure_prims(solve(layout)[0], fig_size=(5, 5))" 305 | ] 306 | }, 307 | { 308 | "cell_type": "markdown", 309 | "metadata": {}, 310 | "source": [ 311 | "## Step 4: Solve the layout " 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "## Solve the constraints and form the figure/axes layout\n", 321 | "prim_tree_n, solve_info = solve(layout)\n", 322 | "\n", 323 | "print(f\"Absolute errors: {solve_info['abs_errs']}\")\n", 324 | "print(f\"Relative errors: {solve_info['rel_errs']}\")" 325 | ] 326 | }, 327 | { 328 | "cell_type": "markdown", 329 | "metadata": {}, 330 | "source": [ 331 | "## Step 5: Plot a figure using the layout " 332 | 
] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": null, 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [ 340 | "## Plot into the generated figure and axes\n", 341 | "fig, axs = subplots(prim_tree_n)\n", 342 | "\n", 343 | "x = np.linspace(0, 1)\n", 344 | "axs[\"Axes\"].plot(x, x**2)\n", 345 | "\n", 346 | "axs[\"Axes\"].xaxis.set_label_text(\"My x label\", ha=\"center\", va=\"bottom\")\n", 347 | "axs[\"Axes\"].yaxis.set_label_text(\"My y label\", ha=\"center\", va=\"bottom\", rotation=-90)\n", 348 | "\n", 349 | "ax = axs[\"Axes\"]\n", 350 | "\n", 351 | "# Using the generated axes and x/y axis contents, the layout constraints\n", 352 | "# can be updated with those matplotlib elements\n", 353 | "layout = lay.update_layout_constraints(layout, axs)\n", 354 | "prim_tree_n, solve_info = solve(layout)\n", 355 | "\n", 356 | "# This updates the figure and axes using the updated layout\n", 357 | "update_subplots(prim_tree_n, \"Figure\", fig, axs)" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": null, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "# The layout currently looks like:\n", 367 | "_fig, _ = ui.figure_prims(solve(layout)[0], fig_size=(5, 5))\n", 368 | "\n", 369 | "# Note that x and y axis dimensions have adjusted" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": null, 375 | "metadata": {}, 376 | "outputs": [], 377 | "source": [] 378 | } 379 | ], 380 | "metadata": { 381 | "kernelspec": { 382 | "display_name": "numerics", 383 | "language": "python", 384 | "name": "python3" 385 | }, 386 | "language_info": { 387 | "codemirror_mode": { 388 | "name": "ipython", 389 | "version": 3 390 | }, 391 | "file_extension": ".py", 392 | "mimetype": "text/x-python", 393 | "name": "python", 394 | "nbconvert_exporter": "python", 395 | "pygments_lexer": "ipython3", 396 | "version": "3.12.7" 397 | } 398 | }, 399 | "nbformat": 4, 400 | "nbformat_minor": 2 401 | } 402 | 
-------------------------------------------------------------------------------- /examples/two_axes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Two axes figure example\n", 8 | "\n", 9 | "This example illustrates creating a two-axes layout.\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import numpy as np\n", 19 | "\n", 20 | "from mpllayout import layout as lay\n", 21 | "from mpllayout import primitives as pr\n", 22 | "from mpllayout import constraints as co\n", 23 | "from mpllayout import solver\n", 24 | "from mpllayout import ui\n", 25 | "from mpllayout import matplotlibutils as lplt" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## Specify the layout" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# Create the layout object\n", 42 | "layout = lay.Layout()\n", 43 | "\n", 44 | "## Create an origin point fixed at (0, 0)\n", 45 | "layout.add_prim(pr.Point(), \"Origin\")\n", 46 | "layout.add_constraint(co.Fix(), (\"Origin\",), (np.array([0, 0]),))\n", 47 | "\n", 48 | "# Create a box to represent the figure\n", 49 | "layout.add_prim(pr.Quadrilateral(), \"Figure\")\n", 50 | "layout.add_constraint(co.Box(), (\"Figure\",), ())\n", 51 | "\n", 52 | "# Create `Axes` objects to represent the left and right axes\n", 53 | "layout.add_prim(pr.Axes(), \"AxesLeft\")\n", 54 | "layout.add_prim(pr.Axes(), \"AxesRight\")\n", 55 | "layout.add_constraint(co.Box(), (\"AxesLeft/Frame\",), ())\n", 56 | "layout.add_constraint(co.Box(), (\"AxesRight/Frame\",), ())\n", 57 | "\n", 58 | "# Constrain the width and height of the figure\n", 59 | "fig_width, fig_height = 6, 3\n", 60 | "layout.add_constraint(co.Width(), 
(\"Figure\",), (fig_width,))\n", 61 | "layout.add_constraint(co.Height(), (\"Figure\",), (fig_height,))\n", 62 | "\n", 63 | "# Constrain the bottom-left corner of the figure to the origin\n", 64 | "layout.add_constraint(co.Coincident(), (\"Figure/Line0/Point0\", \"Origin\"), ())\n", 65 | "\n", 66 | "# Constrain the left/right margins to `AxesLeft` and `AxesRight`, respectively\n", 67 | "margin_left = 0.5\n", 68 | "margin_right = 0.5\n", 69 | "\n", 70 | "layout.add_constraint(\n", 71 | " co.InnerMargin(side='left'), (\"AxesLeft/Frame\", \"Figure\",), (margin_left,)\n", 72 | ")\n", 73 | "layout.add_constraint(\n", 74 | " co.InnerMargin(side='right'), (\"AxesRight/Frame\", \"Figure\"), (margin_right,)\n", 75 | ")\n", 76 | "\n", 77 | "# Constrain the gap between the left and right axes\n", 78 | "margin_inter = 0.5\n", 79 | "layout.add_constraint(\n", 80 | " co.OuterMargin(side='right'), (\"AxesLeft/Frame\", \"AxesRight/Frame\"), (margin_inter,)\n", 81 | ")\n", 82 | "\n", 83 | "# Constrain the top/bottom margins on the left axes ('AxesLeft')\n", 84 | "# Aligning the left/right axes in a row (below) implicitly sets the same margins for 'AxesRight'\n", 85 | "margin_top = 1.0\n", 86 | "margin_bottom = 0.5\n", 87 | "layout.add_constraint(\n", 88 | " co.InnerMargin(side=\"bottom\"), (\"AxesLeft/Frame\", \"Figure\"), (margin_bottom,)\n", 89 | ")\n", 90 | "layout.add_constraint(\n", 91 | " co.InnerMargin(side=\"top\"), (\"AxesLeft/Frame\", \"Figure\"), (margin_top,)\n", 92 | ")\n", 93 | "\n", 94 | "# Align the left/right axes in a row\n", 95 | "layout.add_constraint(co.AlignRow(), (\"AxesLeft/Frame\", \"AxesRight/Frame\"), ())\n", 96 | "\n", 97 | "# Constrain the width of 'AxesLeft'\n", 98 | "# Note that the right axes width is already constrained through the margins and\n", 99 | "# known left axes width\n", 100 | "width = 2\n", 101 | "layout.add_constraint(co.Width(), (\"AxesLeft/Frame\",), (width,))" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "## 
Solve the layout" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "# Solve and plot the constrained layout\n", 118 | "prims, info = solver.solve(layout, max_iter=40, rel_tol=1e-9)\n", 119 | "\n", 120 | "fig_layout, ax_layout = ui.figure_prims(prims)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "## Plot a figure using the layout" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "## Create a figure and axes from the constrained primitives\n", 137 | "fig, axs = lplt.subplots(prims)\n", 138 | "\n", 139 | "x = np.linspace(0, 1)\n", 140 | "axs[\"AxesLeft\"].plot(x, 4 * x)\n", 141 | "axs[\"AxesRight\"].plot(x, x**2)\n", 142 | "\n", 143 | "fig.savefig(\"two_axes.png\")\n" 144 | ] 145 | } 146 | ], 147 | "metadata": { 148 | "kernelspec": { 149 | "display_name": "numerics", 150 | "language": "python", 151 | "name": "python3" 152 | }, 153 | "language_info": { 154 | "codemirror_mode": { 155 | "name": "ipython", 156 | "version": 3 157 | }, 158 | "file_extension": ".py", 159 | "mimetype": "text/x-python", 160 | "name": "python", 161 | "nbconvert_exporter": "python", 162 | "pygments_lexer": "ipython3", 163 | "version": "3.12.7" 164 | } 165 | }, 166 | "nbformat": 4, 167 | "nbformat_minor": 2 168 | } 169 | -------------------------------------------------------------------------------- /logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 17 | 19 | 43 | 52 | 53 | 55 | 56 | 58 | image/svg+xml 59 | 61 | 62 | 63 | 64 | 65 | 69 | 77 | 85 | 93 | 101 | 109 | 117 | 125 | 133 | 141 | 149 | 157 | 165 | 173 | 181 | 189 | 197 | 205 | 213 | 221 | 229 | 237 | 245 | 253 | 254 | 255 | -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "jax", "numpy", "matplotlib"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "matplotlib-layout" 7 | version = "0.1.0" 8 | authors = [ 9 | { name="Jonathan Deng", email="jonathan.j.deng@gmail.com" }, 10 | ] 11 | description = "A package for laying out figures with geometric constraints" 12 | readme = "README.md" 13 | requires-python = ">=3.7" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: MIT License", 17 | "Operating System :: OS Independent", 18 | ] 19 | 20 | [project.urls] 21 | "Homepage" = "https://github.com/jon-deng/mpl-layout" 22 | "Bug Tracker" = "https://github.com/jon-deng/mpl-layout/issues" 23 | -------------------------------------------------------------------------------- /src/mpllayout/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jon-deng/mpl-layout/6a10aeff3dfa6e989cc8d9beefa55b78527ebe6a/src/mpllayout/__init__.py -------------------------------------------------------------------------------- /src/mpllayout/containers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tree class (`Node`) definition and related utilities for trees 3 | 4 | This module defines a tree class `Node` and related utilities. 5 | This class is used by itself as well as to define geometric primitives and constraints. 
6 | """ 7 | 8 | from typing import TypeVar, Generic, Any, Iterable, Callable 9 | 10 | import itertools 11 | import functools 12 | 13 | import jax 14 | 15 | TValue = TypeVar("TValue") 16 | TNode = TypeVar("TNode", bound="Node") 17 | 18 | class Node(Generic[TValue]): 19 | """ 20 | Tree structure with labelled child nodes 21 | 22 | Parameters 23 | ---------- 24 | value: TValue 25 | A value associated with the node 26 | children: dict[str, Node] 27 | A dictionary of child nodes 28 | 29 | Attributes 30 | ---------- 31 | value: TValue 32 | The value stored in the node 33 | children: dict[str, Node] 34 | A dictionary of child nodes 35 | """ 36 | 37 | def __init__(self: TNode, value: TValue, children: dict[str, TNode]): 38 | assert isinstance(children, dict) 39 | self._value = value 40 | self._children = children 41 | 42 | @classmethod 43 | def from_tree(cls, value: TValue, children: dict[str, TNode]): 44 | """ 45 | Return any `Node` subclass from its value and children 46 | 47 | This method is needed because some `Node` subclasses have different `__init__` 48 | signatures. 49 | `from_tree` can be used to recreate any `Node` subclass using just a known value 50 | and children. 51 | This is particularly important for flattening and unflattening a tree. 52 | """ 53 | node = super().__new__(cls) 54 | Node.__init__(node, value, children) 55 | return node 56 | 57 | 58 | ## Tree methods 59 | 60 | @property 61 | def value(self) -> TValue: 62 | """ 63 | Return the node value 64 | """ 65 | return self._value 66 | 67 | @property 68 | def children(self: TNode) -> dict[str, TNode]: 69 | """ 70 | Return all child nodes 71 | """ 72 | return self._children 73 | 74 | def node_height(self) -> int: 75 | """ 76 | Return the height of a node 77 | 78 | Returns 79 | ------- 80 | int 81 | The node height 82 | 83 | The node height is the number of edges from the current node to the 84 | "furthest" child node. 
85 | """ 86 | if len(self) == 0: 87 | return 0 88 | else: 89 | return 1 + max(child.node_height() for _, child in self.items()) 90 | 91 | def get_child(self: TNode, key: str) -> TNode: 92 | split_key = key.split("/", 1) 93 | parent_key, child_keys = split_key[0], split_key[1:] 94 | 95 | try: 96 | if len(child_keys) == 0: 97 | return self._get_child_nonrecursive(parent_key) 98 | else: 99 | return self.children[parent_key].get_child(child_keys[0]) 100 | except KeyError as err: 101 | raise KeyError(f"{key}") from err 102 | 103 | def _get_child_nonrecursive(self: TNode, key: str) -> TNode: 104 | return self.children[key] 105 | 106 | def add_child(self: TNode, key: str, child: TNode): 107 | """ 108 | Add a child node at the given key 109 | 110 | Raises an error if the key already exists. 111 | 112 | Parameters 113 | ---------- 114 | key: str 115 | A child node key 116 | 117 | see `__getitem__` 118 | node: Node 119 | The node to set 120 | """ 121 | split_key = key.split("/", 1) 122 | parent_key, child_keys = split_key[0], split_key[1:] 123 | 124 | try: 125 | if len(child_keys) == 0: 126 | self._add_child_nonrecursive(parent_key, child) 127 | else: 128 | self.children[parent_key].add_child(child_keys[0], child) 129 | 130 | except KeyError as err: 131 | raise KeyError(f"{key}") from err 132 | 133 | def _add_child_nonrecursive(self: TNode, key: str, child: TNode): 134 | """ 135 | Add a primitive indexed by a key 136 | 137 | Base case of recursive `add_child` 138 | """ 139 | if key in self.children: 140 | raise KeyError(f"{key}") 141 | else: 142 | self.children[key] = child 143 | 144 | def copy(self): 145 | def identity(value: TValue) -> TValue: 146 | return value 147 | return map(identity, self) 148 | 149 | ## String methods 150 | 151 | def __repr__(self) -> str: 152 | keys_repr = ", ".join(self.keys()) 153 | children_repr = ", ".join([node.__repr__() for _, node in self.children.items()]) 154 | return f"{type(self).__name__}({self.value}, ({children_repr}), ({keys_repr}))" 
155 | 156 | def __str__(self) -> str: 157 | return self.__repr__() 158 | 159 | ## Dict methods 160 | 161 | def __iter__(self): 162 | return self.children.__iter__() 163 | 164 | def __contains__(self, key: str) -> bool: 165 | split_keys = key.split("/", 1) 166 | parent_key = split_keys[0] 167 | child_key = "/".join(split_keys[1:]) 168 | 169 | if child_key == "": 170 | return parent_key in self.children 171 | else: 172 | return child_key in self[parent_key] 173 | 174 | def __len__(self) -> int: 175 | return len(self.children) 176 | 177 | def keys(self) -> list[str]: 178 | """ 179 | Return child keys 180 | """ 181 | return list(self.children.keys()) 182 | 183 | def values(self: TNode) -> list[TNode]: 184 | """ 185 | Return child nodes 186 | """ 187 | return list(self.children.values()) 188 | 189 | def items(self): 190 | """ 191 | Return an iterator of (child key, child node) pairs 192 | """ 193 | return self.children.items() 194 | 195 | def __setitem__(self: TNode, key: str, node: TNode): 196 | """ 197 | Set the child node at the given key 198 | 199 | Raises an error if the key doesn't exist. 
200 | 201 | Parameters 202 | ---------- 203 | key: str 204 | A child node key 205 | 206 | see `__getitem__` 207 | node: Node 208 | The node to set 209 | """ 210 | # This splits `key = 'a/b/c/d'` 211 | # into `parent_key = 'a/b/c'` and `child_key = 'd'` 212 | split_keys = key.split("/") 213 | parent_key = "/".join(split_keys[:-1]) 214 | child_key = split_keys[-1] 215 | 216 | if key not in self: 217 | raise KeyError(key) 218 | else: 219 | if parent_key == "": 220 | parent = self 221 | else: 222 | parent = self[parent_key] 223 | parent.children[child_key] = node 224 | 225 | def __getitem__(self: TNode, key: str | int | slice) -> TNode | list[TNode]: 226 | """ 227 | Return the value indexed by a slash-separated key 228 | 229 | Parameters 230 | ---------- 231 | key: str | int | slice 232 | A child node key 233 | 234 | The interpretation depends on the key type: 235 | - `str` keys indicate a child key and can be slash separated to denote 236 | child keys of child keys, 237 | for example, 'childa/grandchildb/greatgrandchildc'. 238 | - `int` keys indicate a child by integer index. 239 | - `slice` keys indicate a range of children. 240 | """ 241 | if isinstance(key, str): 242 | return self.get_child(key) 243 | elif isinstance(key, (int, slice)): 244 | return list(self.children.values())[key] 245 | else: 246 | raise TypeError("") 247 | 248 | 249 | TItem = TypeVar("TItem") 250 | 251 | class ItemCounter(Generic[TItem]): 252 | """ 253 | Count the number of added items by category (a string) 254 | 255 | This is used to generate unique string keys for objects 256 | (see `Layout.add_constraint`). 
257 | 258 | Parameters 259 | ---------- 260 | categorize: Callable[[TItem], str] 261 | A function that returns the category (string) of an item 262 | 263 | Attributes 264 | ---------- 265 | category_counts: dict[str, int] 266 | The number of items added to each category 267 | """ 268 | 269 | @staticmethod 270 | def categorize_by_classname(item: TItem) -> str: 271 | return type(item).__name__ 272 | 273 | def __init__(self, categorize: Callable[[TItem], str] = categorize_by_classname): 274 | self._category_counts = {} 275 | self._categorize = categorize 276 | 277 | @property 278 | def category_counts(self) -> dict[str, int]: 279 | return self._category_counts 280 | 281 | def __contains__(self, key): 282 | return key in self._p 283 | 284 | def categorize(self, item: TItem) -> str: 285 | """ 286 | Return the category string of an item 287 | """ 288 | return self._categorize(item) 289 | 290 | def add_item(self, item: TItem) -> str: 291 | """ 292 | Add an item 293 | 294 | Parameters 295 | ---------- 296 | item: TItem 297 | The item to add 298 | 299 | Returns 300 | ------- 301 | str 302 | A string identifying the added item's category and count 303 | """ 304 | category = self.categorize(item) 305 | if category in self.category_counts: 306 | self.category_counts[category] += 1 307 | else: 308 | self.category_counts[category] = 1 309 | 310 | count = self.category_counts[category] - 1 311 | return f"{category}{count}" 312 | 313 | def add_item_until_valid(self, item: TItem, valid: Callable[[str], bool]) -> str: 314 | """ 315 | Add an item until the return item key is valid 316 | 317 | This can be used to keep adding items until a unique key is generated for some 318 | existing collection. 319 | For example, if a dictionary of items already exists, this function can be used 320 | to generate new item keys until one that doesn't already exist in the dictionary 321 | is found. 
322 | 323 | Parameters 324 | ---------- 325 | item: TItem 326 | The item to add 327 | valid: Callable[[str], bool] 328 | The condition the generated item key must satisfy 329 | 330 | Returns 331 | ------- 332 | str 333 | A string identifying the added item's category and count 334 | """ 335 | key = self.add_item(item) 336 | while not valid(key): 337 | key = self.add_item(item) 338 | 339 | return key 340 | 341 | def add_item_to_nodes(self, item: TItem, *nodes: tuple[Node, ...]) -> str: 342 | """ 343 | Add an item until the item key is unique within a set of trees 344 | 345 | This is used generate a unique key for an item for a set of existing trees 346 | (`Node`). 347 | 348 | Parameters 349 | ---------- 350 | item: TItem 351 | The item to add 352 | *nodes: tuple[Node, ...] 353 | The set of trees 354 | 355 | The returned item key should exist in these trees. 356 | 357 | Returns 358 | ------- 359 | str 360 | A string identifying the added item's category and count 361 | """ 362 | def valid(key): 363 | key_notin_nodes = (key not in node for node in nodes) 364 | return all(key_notin_nodes) 365 | return self.add_item_until_valid(item, valid) 366 | 367 | 368 | ## Node functions 369 | 370 | U = TypeVar("U") 371 | 372 | def map( 373 | function: Callable[[TValue], U], 374 | node: Node[TValue] 375 | ) -> Node[U]: 376 | """ 377 | Return a node by applying a function to every value in an input node 378 | """ 379 | 380 | flat_node_structs = flatten('', node) 381 | 382 | flat_map_node_structs = [ 383 | (key, Node, function(value), child_keys) 384 | for (key, _NodeType, value, child_keys) in flat_node_structs 385 | ] 386 | return unflatten(flat_map_node_structs)[0] 387 | 388 | def accumulate( 389 | function: Callable[[TValue, TValue], TValue], 390 | node: Node[TValue], 391 | initial: TValue 392 | ) -> Node[TValue]: 393 | """ 394 | Return a node by accumulating all leaf node values into the root 395 | """ 396 | # Recursively create all accumulated child nodes 397 | cnodes = { 398 | 
ckey: accumulate(function, cnode, initial) 399 | for ckey, cnode in node.items() 400 | } 401 | 402 | if len(cnodes) == 0: 403 | value = function(node.value, initial) 404 | else: 405 | value = functools.reduce( 406 | function, [node.value]+[cnode.value for cnode in cnodes.values()] 407 | ) 408 | return Node(value, cnodes) 409 | 410 | ## Manual flattening/unflattening implementation 411 | 412 | def iter_flat(root_key: str, root_node: TNode) -> Iterable[tuple[str, TNode]]: 413 | """ 414 | Return an iterable over all nodes in the root node (recursively depth-first) 415 | 416 | Parameters 417 | ---------- 418 | root_key: str 419 | A key for the node 420 | 421 | All child node keys will be appended to this key with a '/' separator. 422 | root_node: TNode 423 | The root node 424 | 425 | Returns 426 | ------- 427 | Iterable[tuple[str, TNode]] 428 | An iterable over all nodes in the root node 429 | """ 430 | # TODO: Fix mypy typing errors here 431 | 432 | # The flattened node consists of the root node tuple... 433 | flat_root_node = [(root_key, root_node)] 434 | 435 | # then recursively appends all flattened child nodes 436 | flat_child_nodes = [ 437 | iter_flat(f"{root_key}/{ckey}", cnode) 438 | for ckey, cnode in root_node.items() 439 | ] 440 | return itertools.chain(flat_root_node, *flat_child_nodes) 441 | 442 | 443 | FlatNodeStructure = tuple[str, type[Node], TValue, list[str]] 444 | 445 | def flatten(root_key: str, root_node: TNode) -> list[FlatNodeStructure]: 446 | """ 447 | Return a flattened list of node structures for a root node (recursively depth-first) 448 | 449 | Parameters 450 | ---------- 451 | root_key: str 452 | A key for the node 453 | root_node: TNode 454 | The root node 455 | 456 | Returns 457 | ------- 458 | list[FlatNodeStructure] 459 | A list of node structures 460 | 461 | Each node structure is a tuple representing the node. 
462 | """ 463 | node_structs = [ 464 | (key, type(node), node.value, node.keys()) 465 | for key, node in iter_flat(root_key, root_node) 466 | ] 467 | return node_structs 468 | 469 | def unflatten( 470 | node_structs: list[FlatNodeStructure], 471 | ) -> tuple[TNode, list[FlatNodeStructure]]: 472 | """ 473 | Return the root node from a flat representation 474 | 475 | Parameters 476 | ---------- 477 | node_structs: list[FlatNodeStructure] 478 | The flat representation 479 | 480 | Returns 481 | ------- 482 | TNode 483 | The root node 484 | list[FlatNodeStructure] 485 | A "leftover" flat node representation 486 | 487 | This list should be empty if the flat node representation only contains 488 | nodes that belong to the root node. 489 | """ 490 | node_key, node_type, value, child_keys = node_structs[0] 491 | 492 | children = [] 493 | node_structs = node_structs[1:] 494 | for _key in child_keys: 495 | child, node_structs = unflatten(node_structs) 496 | children.append(child) 497 | 498 | node = node_type.from_tree( 499 | value, {key: child for key, child in zip(child_keys, children)} 500 | ) 501 | return node, node_structs 502 | 503 | 504 | ## pytree flattening/unflattening implementation 505 | # These functions register `Node` classes as a `jax.pytree` so jax can flatten/unflatten 506 | # them 507 | FlatNode = tuple[TValue, dict[str, TNode]] 508 | AuxData = Any 509 | 510 | def _make_flatten_unflatten(node_type: type[TNode]): 511 | 512 | def _flatten_node(node: TNode) -> tuple[FlatNode, AuxData]: 513 | flat_node = (node.value, node.children) 514 | aux_data = None 515 | return (flat_node, aux_data) 516 | 517 | def _unflatten_node(aux_data: AuxData, flat_node: FlatNode) -> TNode: 518 | value, children = flat_node 519 | return node_type.from_tree(value, children) 520 | 521 | return _flatten_node, _unflatten_node 522 | 523 | 524 | ## Register `Node` as `jax.pytree` 525 | for _NodeType in [Node]: 526 | _flatten, _unflatten = _make_flatten_unflatten(_NodeType) 527 | 
jax.tree_util.register_pytree_node(_NodeType, _flatten, _unflatten) 528 | -------------------------------------------------------------------------------- /src/mpllayout/geometry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Geometric primitives and constraints 3 | """ 4 | 5 | from .primitives import * 6 | from .constraints import * 7 | -------------------------------------------------------------------------------- /src/mpllayout/layout.py: -------------------------------------------------------------------------------- 1 | """ 2 | Layout class and associated utilities 3 | 4 | A `Layout` is the class used to represent an arrangement (layout) of figure 5 | elements. 6 | """ 7 | 8 | from typing import Optional, Any 9 | from numpy.typing import NDArray 10 | 11 | from matplotlib.axes import Axes 12 | import numpy as np 13 | 14 | from . import primitives as pr 15 | from . import constraints as cr 16 | from .containers import ItemCounter, iter_flat 17 | 18 | IntGraph = list[tuple[int, ...]] 19 | StrGraph = list[tuple[str, ...]] 20 | 21 | 22 | class Layout: 23 | """ 24 | A collection of geometric primitives and constraints 25 | 26 | A `Layout` stores a collection of geometric primitives and the constraints 27 | on those primitives to represent an arrangement of figure elements. 28 | 29 | Parameters 30 | ---------- 31 | root_prim: Optional[pr.PrimitiveNode] 32 | A root primitive to store all geometric primitives 33 | root_constraint: Optional[cr.ConstraintNode] 34 | A root constraint to store all geometric constraints 35 | root_prim_keys: Optional[cr.PrimKeysNode] 36 | A root primitive arguments tree for `root_constraint` 37 | 38 | Each value in `root_prim_keys` is a tuple of primitive keys 39 | for the corresponding constraint. 40 | The tuple of primitive keys indicate primitives from `root_prim`. 
41 | root_param: Optional[cr.ParamsNode] 42 | A root parameter tree for `root_constraint` 43 | 44 | Each value in `root_param` is a constraint residual parameter 45 | for the corresponding constraint in `root_constraint`. 46 | constraint_counter: Optional[ItemCounter] 47 | An item counter used to create automatic keys for constraints 48 | 49 | This does not have to be supplied. 50 | Supplying it will change any automatically generated constraint keys. 51 | 52 | Attributes 53 | ---------- 54 | root_prim: pr.PrimitiveNode 55 | See `Parameters` 56 | root_constraint: cr.ConstraintNode 57 | See `Parameters` 58 | root_prim_keys: cr.PrimKeysNode 59 | See `Parameters` 60 | root_param: cr.ParamsNode 61 | See `Parameters` 62 | """ 63 | 64 | def __init__( 65 | self, 66 | root_prim: Optional[pr.PrimitiveNode] = None, 67 | root_constraint: Optional[cr.ConstraintNode] = None, 68 | root_prim_keys: Optional[cr.PrimKeysNode] = None, 69 | root_param: Optional[cr.ParamsNode] = None, 70 | constraint_counter: Optional[ItemCounter] = None, 71 | ): 72 | 73 | if root_prim is None: 74 | root_prim = pr.PrimitiveNode.from_tree(np.array([]), {}) 75 | if root_constraint is None: 76 | root_constraint = cr.ConstraintNode.from_tree(None, {}) 77 | if root_prim_keys is None: 78 | root_prim_keys = cr.PrimKeysNode(None, {}) 79 | if root_param is None: 80 | root_param = cr.ParamsNode(None, {}) 81 | if constraint_counter is None: 82 | constraint_counter = ItemCounter() 83 | 84 | self._root_prim = root_prim 85 | self._root_constraint = root_constraint 86 | self._root_constraint_prim_keys = root_prim_keys 87 | self._root_constraint_param = root_param 88 | 89 | self._constraint_counter = constraint_counter 90 | 91 | @property 92 | def root_prim(self) -> pr.PrimitiveNode: 93 | return self._root_prim 94 | 95 | @property 96 | def root_constraint(self) -> cr.ConstraintNode: 97 | return self._root_constraint 98 | 99 | @property 100 | def root_prim_keys(self) -> cr.PrimKeysNode: 101 | return 
self._root_constraint_prim_keys 102 | 103 | @property 104 | def root_param(self) -> cr.ParamsNode: 105 | return self._root_constraint_param 106 | 107 | def flat_constraints( 108 | self 109 | ) -> tuple[list[cr.Constraint], list[cr.PrimKeys], list[cr.Params]]: 110 | """ 111 | Return flat constraints, primitive argument keys and parameters 112 | 113 | Returns 114 | ------- 115 | constraints: list[cr.Constraint] 116 | A flat list of constraints from the root constraint 117 | prim_keys: list[cr.PrimKeys] 118 | A list of primitive keys for each constraint 119 | params: list[cr.ResParams] 120 | A list of residual parameters for each constraint 121 | """ 122 | # The `[1:]` removes the 'root' constraint which is just a container 123 | constraints = [ 124 | node for _, node in iter_flat('', self.root_constraint) 125 | ][1:] 126 | prim_keys = [ 127 | node.value for _, node in iter_flat('', self.root_prim_keys) 128 | ][1:] 129 | params = [ 130 | node.value for _, node in iter_flat('', self.root_param) 131 | ][1:] 132 | 133 | constraints = [c.assem_atleast_1d for c in constraints] 134 | 135 | return constraints, prim_keys, params 136 | 137 | def add_prim(self, prim: pr.Primitive, key: str): 138 | """ 139 | Add a primitive to the layout 140 | 141 | The primitive will be added to `self.root_prim` using the given `key`. 
142 | 143 | Parameters 144 | ---------- 145 | prim: pr.Primitive 146 | The primitive to add 147 | key: str 148 | The key for the primitive 149 | """ 150 | self.root_prim.add_child(key, prim) 151 | 152 | def add_prims(self, prims: dict[str, pr.Primitive]): 153 | """ 154 | Add multiple primitives to the layout 155 | 156 | Parameters 157 | ---------- 158 | prims: dict[str, pr.Primitive] 159 | A dictionary of primitives to add 160 | """ 161 | for key, prim in prims.items(): 162 | self.add_prim(prim, key) 163 | 164 | def add_constraint( 165 | self, 166 | constraint: cr.Constraint, 167 | prim_keys: cr.PrimKeys, 168 | param: cr.Params, 169 | key: str = "" 170 | ): 171 | """ 172 | Add a constraint between primitives 173 | 174 | Parameters 175 | ---------- 176 | constraint: cr.Constraint 177 | The constraint to add 178 | prim_keys: cr.PrimKeys 179 | A tuple of primitive keys the constraint applies to 180 | 181 | The primitive keys refer to primitives in `self.root_prim`. 182 | param: cr.ResParams 183 | Parameters for the constraint 184 | key: str 185 | An optional key to identify the constraint 186 | 187 | If not supplied, a key will be automatically generated using 188 | `_constraint_counter`. 
189 | """ 190 | nodes = ( 191 | self.root_constraint, self.root_prim_keys, self.root_param 192 | ) 193 | if key == "": 194 | key = self._constraint_counter.add_item_to_nodes(constraint, *nodes) 195 | 196 | # Notify the user if the input primitives or parameters to the constraint 197 | # are incorrect 198 | try: 199 | constraint.validate_prims(tuple(self.root_prim[key] for key in prim_keys)) 200 | except TypeError as exc: 201 | raise RuntimeError(f"Wrong primitives {prim_keys} for constraint {type(constraint)}\n{exc}") from exc 202 | 203 | try: 204 | constraint.validate_params(param) 205 | except TypeError as exc: 206 | raise RuntimeError(f"Wrong parameters {param} for constraint {type(constraint)}\n{exc}") from exc 207 | 208 | self.root_constraint.add_child(key, constraint) 209 | self.root_prim_keys.add_child(key, constraint.root_prim_keys(prim_keys)) 210 | self.root_param.add_child(key, constraint.root_params(param)) 211 | 212 | def update_layout_constraints( 213 | layout: Layout, 214 | axs: dict[str, Axes] 215 | ) -> Layout: 216 | """ 217 | Update layout constraints that depend on `matplotlib` elements 218 | 219 | Some constraints have parameters that depend on `matplotlib`. 220 | This function shoud identify these constraints in a `Layout` and replace 221 | their parameters with the correct `matplotlib` element. 222 | Currently this is only implemented for `XAxisThickness` and `YAxisThickness`. 223 | 224 | Parameters 225 | ---------- 226 | layout: Layout 227 | The layout 228 | axs: dict[str, Axes] 229 | The `matplotlib` axes objects 230 | 231 | The key for every `matplotlib.Axes` should match a corresponding 232 | `pr.Axes` in the layout. 
233 | 234 | Returns 235 | ------- 236 | Layout 237 | The layout with updated constraint parameters 238 | """ 239 | constraintkey_to_param = {} 240 | for key, constraint in iter_flat('', layout.root_constraint): 241 | # `key[1:]` removes the initial "/" from the key 242 | key = key[1:] 243 | if key != "": 244 | prim_keys = layout.root_prim_keys[key] 245 | constraint_param = layout.root_param[key] 246 | 247 | if isinstance(constraint, cr.XAxisThickness): 248 | axis_key, = prim_keys.value 249 | axes_key = axis_key.split("/", 1)[0] 250 | constraintkey_to_param[key] = (axs[axes_key].xaxis,) 251 | 252 | if isinstance(constraint, cr.YAxisThickness): 253 | axis_key, = prim_keys.value 254 | axes_key = axis_key.split("/", 1)[0] 255 | constraintkey_to_param[key] = (axs[axes_key].yaxis,) 256 | 257 | new_root_param = update_root_param( 258 | layout.root_constraint, 259 | layout.root_param, 260 | constraintkey_to_param 261 | ) 262 | 263 | return Layout( 264 | layout.root_prim, 265 | layout.root_constraint, 266 | layout.root_prim_keys, 267 | new_root_param 268 | ) 269 | 270 | def update_root_param( 271 | root_constraint: cr.ConstraintNode, 272 | root_param: cr.ParamsNode, 273 | constraintkey_to_param: dict[str, cr.Params] 274 | ) -> cr.ParamsNode: 275 | """ 276 | Update the root constraint parameters node 277 | 278 | Parameters 279 | ---------- 280 | root_constraint: cr.ConstraintNode 281 | The root constraint 282 | root_param: cr.ParamsNode 283 | The corresponding root parameters node 284 | constraintkey_to_param: dict[str, cr.ResParams] 285 | A mapping of constraint keys to replacement constraint parameters 286 | 287 | Each constraint key should indicate a node in `root_constraint` and 288 | `root_param`. 
289 | """ 290 | new_root_param = root_param.copy() 291 | for key, param in constraintkey_to_param.items(): 292 | constraint = root_constraint[key] 293 | new_root_param[key] = constraint.root_params(param) 294 | 295 | return new_root_param 296 | -------------------------------------------------------------------------------- /src/mpllayout/matplotlibutils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for creating `matplotlib` elements from geometric primitives 3 | """ 4 | 5 | from typing import Optional 6 | from numpy.typing import NDArray 7 | import warnings 8 | 9 | import numpy as np 10 | 11 | from matplotlib import pyplot as plt 12 | from matplotlib.figure import Figure 13 | from matplotlib.axes import Axes 14 | 15 | from . import primitives as pr 16 | from . import constraints as cr 17 | 18 | # TODO: (not critical) Should special primitive classes indicate `matplotlib` figures and axes? 19 | # I'm not certain if that would be that beneficial here. 20 | # If so, this should be done for both `subplots` and `update_subplots`. 21 | def subplots( 22 | root_prim: pr.Primitive, 23 | fig_key: str = "Figure", 24 | axs_keys: Optional[list[str]] = None, 25 | ) -> tuple[Figure, dict[str, Axes]]: 26 | """ 27 | Create matplotlib `Figure` and `Axes` objects from geometric primitives 28 | 29 | The `Figure` and `Axes` objects are extracted based on labels in the primitive tree 30 | and have sizes and positions from their corresponding primitives. 31 | 32 | Parameters 33 | ---------- 34 | root_prim: pr.Primitive 35 | The root primitive 36 | fig_key: str 37 | The quadrilateral key corresponding to the figure 38 | 39 | The key is "Figure" by default. 40 | axs_keys: Optional[list[str]] 41 | Axes keys 42 | 43 | If supplied, only these axes keys will be used to generate `Axes` instances. 
44 | 45 | Returns 46 | ------- 47 | fig: Figure 48 | The matplotlib `Figure` 49 | axs: dict[str, Axes] 50 | The matplotlib `Axes` instances 51 | """ 52 | # Create the `Figure` instance 53 | fig = plt.figure(figsize=(1, 1)) 54 | 55 | # Assume all axes are prefixed by "Axes" if there are no keys provided 56 | if axs_keys is None: 57 | axs_keys = [key for key in root_prim.keys() if "Axes" in key] 58 | 59 | # Create all `Axes` instances 60 | unit_rect = (0, 0, 1, 1) 61 | key_to_ax = {key: fig.add_axes(unit_rect) for key in axs_keys} 62 | 63 | # Update positions figures and axes 64 | fig, key_to_ax = update_subplots(root_prim, fig_key, fig, key_to_ax) 65 | 66 | return fig, key_to_ax 67 | 68 | 69 | def update_subplots( 70 | root_prim: pr.Primitive, fig_key: str, fig: Figure, axs: dict[str, Axes], 71 | ): 72 | """ 73 | Update matplotlib `Figure` and `Axes` object positions from primitives 74 | 75 | The `Figure` and `Axes` objects are extracted based on labels in the primitive tree 76 | and have sizes and positions updated from their corresponding primitives. 
77 | 78 | Parameters 79 | ---------- 80 | root_prim: pr.Primitive 81 | The root primitive 82 | fig_key: str 83 | The quadrilateral key in `root_prim` corresponding to the figure 84 | fig: Figure 85 | The `Figure` to update 86 | axs: dict[str, Axes] 87 | The `Axes` objects to update 88 | 89 | Returns 90 | ------- 91 | fig: Figure 92 | The updated matplotlib `Figure` 93 | axs: dict[str, Axes] 94 | The updated matplotlib `Axes` instances 95 | """ 96 | # Set Figure position 97 | quad = root_prim[fig_key] 98 | fig_origin = quad['Line0/Point0'].value 99 | fig_size = np.array(width_and_height_from_quad(quad)) 100 | fig.set_size_inches(fig_size) 101 | 102 | # Set Axes properties/position 103 | for key, ax in axs.items(): 104 | # Set Axes dimensions 105 | quad = root_prim[f"{key}/Frame"] 106 | ax.set_position(rect_from_box(quad, fig_origin, fig_size)) 107 | 108 | # Set x/y axis properties 109 | axis_prefixes = ("X", "Y") 110 | axis_tuple = (ax.xaxis, ax.yaxis) 111 | 112 | for axis_prefix, axis in zip(axis_prefixes, axis_tuple): 113 | # Set the axis label position 114 | axis_label = f"{axis_prefix}AxisLabel" 115 | if axis_label in root_prim[key]: 116 | axis_label_point: pr.Point = root_prim[f"{key}/{axis_label}"] 117 | label_coords = axis_label_point.value 118 | axis.set_label_coords( 119 | *(label_coords / fig_size), transform=fig.transFigure 120 | ) 121 | 122 | # Set the axis tick position 123 | axis_bbox = f"{axis_prefix}Axis" 124 | if axis_bbox in root_prim[key]: 125 | axis_quad = root_prim[f"{key}/{axis_bbox}"] 126 | axis_tick_position = find_axis_position(quad, axis_quad) 127 | axis.set_ticks_position(axis_tick_position) 128 | 129 | return fig, axs 130 | 131 | def find_axis_position(axes_frame: pr.Quadrilateral, axis: pr.Quadrilateral) -> str: 132 | """ 133 | Return the axis position relative to a frame 134 | 135 | Parameters 136 | ---------- 137 | axes_frame: pr.Quadrilateral 138 | The axes frame 139 | axis: pr.Quadrilateral 140 | The axes axis 141 | 142 | This can be 
any x or y axis 143 | 144 | Returns 145 | ------- 146 | position: str 147 | One of ('bottom', 'top', 'left', 'right') indicating the axis position 148 | """ 149 | coincident_line = cr.CoincidentLines() 150 | params = {"reverse": True} 151 | bottom_res = coincident_line((axes_frame["Line0"], axis["Line2"]), params) 152 | top_res = coincident_line((axes_frame["Line2"], axis["Line0"]), params) 153 | left_res = coincident_line((axes_frame["Line3"], axis["Line1"]), params) 154 | right_res = coincident_line((axes_frame["Line1"], axis["Line3"]), params) 155 | 156 | residuals = tuple( 157 | np.linalg.norm(res) for res in (bottom_res, top_res, left_res, right_res) 158 | ) 159 | residual_positions = ("bottom", "top", "left", "right") 160 | 161 | if not np.isclose(np.min(residuals), 0): 162 | warnings.warn("The axis isn't closely aligned with any of the axes sides") 163 | position = residual_positions[np.argmin(residuals)] 164 | return position 165 | 166 | 167 | def width_and_height_from_quad(quad: pr.Quadrilateral) -> tuple[float, float]: 168 | """ 169 | Return the width and height of a quadrilateral 170 | 171 | Parameters 172 | ---------- 173 | quad: pr.Quadrilateral 174 | 175 | Returns 176 | ------- 177 | tuple[float, float] 178 | The width and height 179 | """ 180 | 181 | coord_botleft = quad["Line0/Point0"].value 182 | xmin, ymin = coord_botleft 183 | 184 | coord_topright = quad["Line1/Point1"].value 185 | xmax, ymax = coord_topright 186 | 187 | return (xmax - xmin), (ymax - ymin) 188 | 189 | 190 | def rect_from_box( 191 | quad: pr.Quadrilateral, 192 | fig_origin: NDArray, 193 | fig_size: NDArray = np.array((1, 1)) 194 | ) -> tuple[float, float, float, float]: 195 | """ 196 | Return a `rect' tuple, `(left, bottom, width, height)`, from a quadrilateral 197 | 198 | This tuple of quadrilateral information can be used to create a `Bbox` or `Axes` 199 | object in `matplotlib`. 
200 | 201 | Parameters 202 | ---------- 203 | quad: pr.Quadrilateral 204 | The quadrilateral 205 | fig_origin: NDArray 206 | Coordinates for the figure bottom left corner 207 | fig_size: NDArray 208 | The width and height of the figure 209 | 210 | This should be supplied so that the rect tuple has units relative to the figure. 211 | Some matplotlib `Axes` constructors accept the rect tuple in figure units by default. 212 | 213 | Returns 214 | ------- 215 | xmin, ymin, width, heigth: tuple[float, float, float, float] 216 | """ 217 | 218 | coord_botleft = quad["Line0/Point0"].value 219 | xmin, ymin = (coord_botleft-fig_origin) / fig_size 220 | 221 | coord_topright = quad["Line1/Point1"].value 222 | xmax, ymax = (coord_topright-fig_origin) / fig_size 223 | width = xmax - xmin 224 | height = ymax - ymin 225 | 226 | return (xmin, ymin, width, height) 227 | -------------------------------------------------------------------------------- /src/mpllayout/primitives.py: -------------------------------------------------------------------------------- 1 | """ 2 | Geometric primitive definitions 3 | """ 4 | 5 | from typing import Optional, TypeVar, Any 6 | from numpy.typing import NDArray 7 | 8 | import numpy as np 9 | import jax 10 | 11 | # from .containers import Node, _make_flatten_unflatten, iter_flat, unflatten, FlatNodeStructure 12 | import mpllayout.containers as cn 13 | 14 | 15 | ## Generic primitive class/interface 16 | # You can create specific primitive definitions by inheriting from these and 17 | # defining appropriate class attributes 18 | 19 | TPrim = TypeVar("TPrim", bound="PrimitiveNode") 20 | PrimValue = NDArray[np.float64] 21 | PrimTypes = tuple[type[TPrim] | None, ...] 
PrimTypesSignature = PrimTypes | set[PrimTypes]
PrimNodeSignature = tuple[int, PrimTypesSignature]


def validate_prims(
    prims: list[TPrim], prim_types: PrimTypes
) -> None:
    """
    Validate that input primitives match a tuple of primitive types

    The function returns nothing on success and raises on failure.

    Parameters
    ----------
    prims: list[TPrim]
        Primitives to validate
    prim_types: PrimTypes
        A tuple of primitive types that `prims` must match

        A `None` entry matches any `PrimitiveNode`.
        If the last entry is `...` (`Ellipsis`), the entries before it are
        treated as a unit that repeats to cover all of `prims`; for example,
        `(Point, ...)` matches any number of points.

    Raises
    ------
    ValueError
        If `prim_types` is just `(...,)` with no repeating unit, if the
        number of `prims` doesn't match `prim_types`, or if any primitive's
        type is not a subclass of its corresponding primitive type
    """
    ## Pre-process `prim_types` into a simple tuple of primitive types
    # Expand a trailing `...` by treating the types before it as a
    # repeating unit
    if prim_types and prim_types[-1] is Ellipsis:
        repeat_prim_types = tuple(prim_types[:-1])
        if not repeat_prim_types:
            # Previously this divided by zero; raise the documented error type
            # so callers catching `ValueError` handle it
            raise ValueError(
                "Invalid `prim_types`, `...` must follow at least one type"
            )
        n_repeat, n_leftover = divmod(len(prims), len(repeat_prim_types))
        if n_leftover != 0:
            raise ValueError(
                "Invalid `prims` length must be a multiple of "
                f"{len(repeat_prim_types)} not {len(prims)}"
            )
        prim_types = n_repeat * repeat_prim_types

    # `None` is a placeholder for `PrimitiveNode`, which class-level
    # signatures can't refer to by name inside the class body itself
    prim_types = [
        PrimitiveNode if prim_type is None else prim_type
        for prim_type in prim_types
    ]

    ## Check `prims` has the right length
    if len(prims) != len(prim_types):
        raise ValueError(
            f"Invalid `prims` must be {len(prim_types)} not {len(prims)}"
        )

    ## Check `prims` are the right types
    _prim_types = [type(prim) for prim in prims]
    if not all(
        issubclass(_prim_type, prim_type)
        for _prim_type, prim_type in zip(_prim_types, prim_types)
    ):
        raise ValueError(
            f"Invalid `prims` types must be {prim_types} not {_prim_types}"
        )


# TODO: Improve primitive array data structure?
# Each primitive node stores its own parameter vector but this doesn't
# handle connectivity information.
For example, the same point coordinate is 77 | # owned by two lines in a polygon 78 | 79 | class PrimitiveNode(cn.Node[NDArray[np.float64]]): 80 | """ 81 | Node representation of a geometric primitive 82 | 83 | A geometric primitive (prim for short) is represented by a parameter vector 84 | and child primitives. For example, in 2D, a point has a 2 element parameter 85 | vector representing (x, y) coordinates and no child prims. A straight 86 | line segment has an empty parameter vector with two point child prims 87 | representing the start and end points. 88 | 89 | Parameters 90 | ---------- 91 | value: PrimValue 92 | Parameter vector for the prim 93 | children: dict[str, TPrim] 94 | Child prims 95 | 96 | Attributes 97 | ---------- 98 | signature: PrimNodeSignature 99 | A tuple specifying valid parameter vector and child types 100 | 101 | A signature has two components 102 | ``(param_size, prim_types) = signature``, 103 | where `param_size` is the size of the parameter vector and `prim_types` 104 | indicates valid child primitives. 
105 | """ 106 | 107 | # NOTE: I used `None` to indicate `PrimitiveNode` because the name isn't 108 | # available within the class itself 109 | signature: PrimNodeSignature = (0, (None, ...)) 110 | 111 | def __init__(self, value: PrimValue, children: dict[str, TPrim]): 112 | 113 | # Type checks that value is an array and children are prims 114 | assert isinstance(value, np.ndarray) 115 | assert all( 116 | isinstance(cprim, PrimitiveNode) for _, cprim in children.items() 117 | ) 118 | param_size, prim_type_sig = self.signature 119 | 120 | assert len(value) == param_size 121 | 122 | prims = [cprim for _, cprim in children.items()] 123 | if isinstance(prim_type_sig, tuple): 124 | prim_types = prim_type_sig 125 | try: 126 | validate_prims(prims, prim_types) 127 | except ValueError as err: 128 | raise err 129 | elif isinstance(prim_type_sig, set): 130 | def match(prims, prim_types): 131 | try: 132 | validate_prims(prims, prim_types) 133 | except ValueError as err: 134 | return False 135 | else: 136 | return True 137 | 138 | if not any( 139 | match(prims, prim_types) for prim_types in prim_type_sig 140 | ): 141 | raise ValueError( 142 | f"No matching signatures for `prims` in {prim_type_sig}" 143 | ) 144 | 145 | super().__init__(value, children) 146 | 147 | 148 | class Primitive(PrimitiveNode): 149 | """ 150 | Primitive parameterized by a value, primitives and keyword arguments 151 | 152 | To define a primitive, subclass `Primitive` and define the class methods 153 | `init_children`, `default_value` and `default_prims`. 154 | 155 | Parameters 156 | ---------- 157 | value: Optional[PrimValue] 158 | Parameter vector for the primitive 159 | prims: Optional[list[TPrim]] 160 | Parameterizing primitives 161 | 162 | Note that these are often the same as child primitives but this isn't 163 | always the case. 
164 | **kwargs: dict[str, Any] 165 | Additional keyword arguments 166 | 167 | Attributes 168 | ---------- 169 | see `PrimitiveNode` 170 | """ 171 | 172 | def __init__( 173 | self, 174 | value: Optional[NDArray] = None, 175 | prims: Optional[list[TPrim]] = None, 176 | **kwargs: dict[str, Any] 177 | ): 178 | if value is None: 179 | value = self.default_value(**kwargs) 180 | elif isinstance(value, (list, tuple)): 181 | value = np.array(value) 182 | elif isinstance(value, (np.ndarray, jax.numpy.ndarray)): 183 | value = value 184 | else: 185 | raise TypeError() 186 | 187 | if prims is None: 188 | prims = self.default_prims(**kwargs) 189 | 190 | child_keys, child_prims = self.init_children(prims, **kwargs) 191 | 192 | super().__init__( 193 | value, {key: prim for key, prim in zip(child_keys, child_prims)} 194 | ) 195 | 196 | @classmethod 197 | def init_children( 198 | cls, prims: list[TPrim], **kwargs: dict[str, Any] 199 | ) -> tuple[list[str], list[TPrim]]: 200 | """ 201 | Return child primitives from parameterizing primitives 202 | 203 | Parameters 204 | ---------- 205 | prims: list[TPrim] 206 | Parameterizing primitives 207 | 208 | Returns 209 | ------- 210 | List[str] 211 | Child primitive keys 212 | List[TPrim] 213 | Child primitives 214 | """ 215 | raise NotImplementedError() 216 | 217 | @classmethod 218 | def default_value(cls, **kwargs: dict[str, Any]) -> NDArray: 219 | """ 220 | Return a default parameter vector 221 | """ 222 | raise NotImplementedError() 223 | 224 | @classmethod 225 | def default_prims(cls, **kwargs: dict[str, Any]) -> list[TPrim]: 226 | """ 227 | Return default parameterizing primitives 228 | """ 229 | raise NotImplementedError() 230 | 231 | 232 | class StaticPrimitive(Primitive): 233 | """ 234 | A `Primitive` with no additional keyword arguments 235 | 236 | Static primitives have a static parameter vector shape and child primitive 237 | types. 
238 | 239 | Parameters 240 | ---------- 241 | See `Primitive` 242 | """ 243 | 244 | def __init__( 245 | self, 246 | value: Optional[NDArray] = None, 247 | prims: Optional[list[TPrim]] = None, 248 | ): 249 | super().__init__(value=value, prims=prims) 250 | 251 | @classmethod 252 | def init_children( 253 | cls, prims: list[TPrim] 254 | ) -> tuple[list[str], list[TPrim]]: 255 | """ 256 | See `Primitive` 257 | """ 258 | raise NotImplementedError() 259 | 260 | @classmethod 261 | def default_value(cls) -> NDArray: 262 | """ 263 | See `Primitive` 264 | """ 265 | raise NotImplementedError() 266 | 267 | @classmethod 268 | def default_prims(cls) -> list[TPrim]: 269 | """ 270 | See `Primitive` 271 | """ 272 | raise NotImplementedError() 273 | 274 | ## Primitive definitions 275 | 276 | class Point(StaticPrimitive): 277 | """ 278 | A point 279 | 280 | Parameters 281 | ---------- 282 | value: Optional[NDArray] with shape (2,) 283 | The point coordinate 284 | prims: Optional[tuple[]] 285 | An empty tuple 286 | """ 287 | 288 | signature = (2, ()) 289 | 290 | @classmethod 291 | def init_children(cls, prims): 292 | return (), () 293 | 294 | @classmethod 295 | def default_value(cls): 296 | return np.array([0, 0]) 297 | 298 | @classmethod 299 | def default_prims(cls): 300 | return () 301 | 302 | 303 | class Line(StaticPrimitive): 304 | """ 305 | A straight line segment between two points 306 | 307 | Parameters 308 | ---------- 309 | value: Optional[NDArray] with shape () 310 | An empty vector 311 | prims: Optional[tuple[Point, Point]] 312 | A tuple containing the line start and end point 313 | """ 314 | 315 | signature = (0, (Point, Point)) 316 | 317 | @classmethod 318 | def init_children(cls, prims: tuple[Point, Point]): 319 | return ("Point0", "Point1"), prims 320 | 321 | @classmethod 322 | def default_value(cls): 323 | return np.array(()) 324 | 325 | @classmethod 326 | def default_prims(cls): 327 | return (Point([0, 0]), Point([0, 1])) 328 | 329 | 330 | class Polygon(Primitive): 
331 | """ 332 | A polygon with straight-line edges through a set of points 333 | 334 | The polygon tree structure contains a sequence of lines where ehe end point 335 | of each line joins the start point of the next line forming a closed loop. 336 | 337 | Parameters 338 | ---------- 339 | value: Optional[NDArray] with shape () 340 | An empty vector 341 | prims: Optional[tuple[Point, ...]] 342 | A list of points representing polygon vertices 343 | 344 | The number of resulting polygon edges (`Line` instances) is the same as 345 | the number of points. 346 | 347 | Child primitives are lines joining the given vertices. 348 | size: int 349 | The number of polygon vertices 350 | """ 351 | 352 | signature = (0, (Line, ...)) 353 | 354 | def __init__( 355 | self, 356 | value: Optional[NDArray] = None, 357 | prims: Optional[list[TPrim]] = None, 358 | size: int = 3 359 | ): 360 | return super().__init__(value=value, prims=prims, size=size) 361 | 362 | @classmethod 363 | def init_children( 364 | cls, prims: list[Point], size: int=3 365 | ): 366 | points = prims 367 | child_prims = [ 368 | Line(np.array([]), (pointa, pointb)) 369 | for pointa, pointb in zip(points[:], points[1:] + points[:1]) 370 | ] 371 | child_keys = [f"Line{n}" for n, _ in enumerate(child_prims)] 372 | return child_keys, child_prims 373 | 374 | @classmethod 375 | def default_value(cls, size: int=3): 376 | return np.array(()) 377 | 378 | @classmethod 379 | def default_prims(cls, size: int=3): 380 | # Generate points around circle 381 | ii = np.arange(size) 382 | xs = np.cos(2*np.pi/size * ii) 383 | ys = np.sin(2*np.pi/size * ii) 384 | return [Point((x, y)) for x, y in zip(xs, ys)] 385 | 386 | 387 | class Quadrilateral(Polygon): 388 | """ 389 | A quadrilateral (4 sided polygon) 390 | 391 | For modelling rectangles in matplotlib (`axes`, `bbox`, etc.) the lines 392 | treated as the bottom, right, top, and left of a box in a clockwise fasion. 
393 | Specifically, the lines correspond to: 394 | - 'Line0' : bottom 395 | - 'Line1' : right 396 | - 'Line2' : top 397 | - 'Line3' : left 398 | 399 | Parameters 400 | ---------- 401 | value: Optional[NDArray] with shape () 402 | An empty vector 403 | prims: Optional[tuple[Point, Point, Point, Point]] 404 | A tuple of 4 vertices for the quadrilateral 405 | """ 406 | 407 | signature = (0, (Line, Line, Line, Line)) 408 | 409 | def __init__( 410 | self, 411 | value: Optional[NDArray] = None, 412 | children: Optional[list[Point]] = None, 413 | ): 414 | super().__init__(value, children, size=4) 415 | 416 | @classmethod 417 | def default_value(cls, size: int=4): 418 | return np.array(()) 419 | 420 | @classmethod 421 | def default_prims(cls, size: int=4): 422 | # Generate a unit square 423 | xs = [0, 1, 1, 0] 424 | ys = [0, 0, 1, 1] 425 | return [Point((x, y)) for x, y in zip(xs, ys)] 426 | 427 | AxisPrims = tuple[Quadrilateral, Point] 428 | AxesChildPrims = ( 429 | tuple[Quadrilateral] 430 | | tuple[Quadrilateral, *AxisPrims] 431 | | tuple[Quadrilateral, *AxisPrims, *AxisPrims] 432 | | tuple[Quadrilateral, *AxisPrims, *AxisPrims, *AxisPrims] 433 | | tuple[Quadrilateral, *AxisPrims, *AxisPrims, *AxisPrims, *AxisPrims] 434 | ) 435 | 436 | class Axes(Primitive): 437 | """ 438 | A collection of quadrilaterals and points representing an axes 439 | 440 | Child primitives are: 441 | - `quad['Frame']` : A `Quadrilateral` representing the plotting area 442 | - `quad['XAxis']` : A `Quadrilateral` representing the x-axis 443 | - `quad['XAxisLabel']` : A `Point` representing the x-axis label anchor 444 | - `quad['YAxis']` : A `Quadrilateral` representing the y-axis 445 | - `quad['YAxisLabel']` : A `Point` representing the y-axis label anchor 446 | - `quad['TwinXAxis']` : A `Quadrilateral` representing the twin x-axis 447 | - `quad['TwinXAxisLabel']` : A `Point` representing the twin x-axis label anchor 448 | - `quad['TwinYAxis']` : A `Quadrilateral` representing the twin y-axis 449 
| - `quad['TwinYAxisLabel']` : A `Point` representing the twin y-axis label anchor 450 | 451 | Parameters 452 | ---------- 453 | value: Optional[NDArray] with shape () 454 | An empty vector 455 | prims: Optional[AxesChildPrims] 456 | A tuple of quadrilateral and points 457 | 458 | The number of quadrilaterals and points in `prims` depends on whether an 459 | x/y axis is included. 460 | xaxis, yaxis: bool 461 | Whether to include an x/y axis and the corresponding label 462 | 463 | If false for a given axis, the corresponding child primitives will not 464 | be present. 465 | xtwin, ytwin: bool 466 | Whether to include a twin x/y axis 467 | """ 468 | 469 | _AxisPrimClasses = (Quadrilateral, Point) 470 | signature = ( 471 | 0, 472 | { 473 | (Quadrilateral,), 474 | (Quadrilateral, *(1 * _AxisPrimClasses)), 475 | (Quadrilateral, *(2 * _AxisPrimClasses)), 476 | (Quadrilateral, *(3 * _AxisPrimClasses)), 477 | (Quadrilateral, *(4 * _AxisPrimClasses)), 478 | } 479 | ) 480 | 481 | def __init__( 482 | self, 483 | value: Optional[NDArray]=None, 484 | prims: Optional[AxesChildPrims]=None, 485 | xaxis: bool=False, 486 | yaxis: bool=False, 487 | twinx: bool=False, 488 | twiny: bool=False 489 | ): 490 | super().__init__(value, prims, xaxis=xaxis, yaxis=yaxis, twinx=twinx, twiny=twiny) 491 | 492 | @classmethod 493 | def init_children( 494 | cls, 495 | prims: AxesChildPrims, 496 | xaxis: bool=False, 497 | yaxis: bool=False, 498 | twinx: bool=False, 499 | twiny: bool=False 500 | ) -> tuple[list[str], AxesChildPrims]: 501 | 502 | xaxis_keys = () 503 | twin_xaxis_keys = () 504 | if xaxis: 505 | xaxis_keys = ("XAxis", "XAxisLabel") 506 | if twinx: 507 | twin_xaxis_keys = ("TwinXAxis", "TwinXAxisLabel") 508 | 509 | yaxis_keys = () 510 | twin_yaxis_keys = () 511 | if yaxis: 512 | yaxis_keys = ("YAxis", "YAxisLabel") 513 | if twiny: 514 | twin_yaxis_keys = ("TwinYAxis", "TwinYAxisLabel") 515 | 516 | keys = ( 517 | ("Frame",) 518 | + (xaxis_keys + yaxis_keys) 519 | + (twin_xaxis_keys + 
twin_yaxis_keys) 520 | ) 521 | return (keys, prims) 522 | 523 | @classmethod 524 | def default_value( 525 | cls, 526 | xaxis: bool=False, 527 | yaxis: bool=False, 528 | twinx: bool=False, 529 | twiny: bool=False 530 | ): 531 | return np.array(()) 532 | 533 | @classmethod 534 | def default_prims( 535 | cls, 536 | xaxis: bool=False, 537 | yaxis: bool=False, 538 | twinx: bool=False, 539 | twiny: bool=False 540 | ): 541 | xaxis_prims = () 542 | twin_xaxis_prims = () 543 | if xaxis: 544 | xaxis_prims = (Quadrilateral(), Point()) 545 | if twinx: 546 | twin_xaxis_prims = (Quadrilateral(), Point()) 547 | 548 | yaxis_prims = () 549 | twin_yaxis_prims = () 550 | if yaxis: 551 | yaxis_prims = (Quadrilateral(), Point()) 552 | if twiny: 553 | twin_yaxis_prims = (Quadrilateral(), Point()) 554 | 555 | return ( 556 | (Quadrilateral(),) 557 | + (xaxis_prims + yaxis_prims) 558 | + (twin_xaxis_prims + twin_yaxis_prims) 559 | ) 560 | 561 | 562 | ## Register `Primitive` classes as `jax.pytree` 563 | _PrimitiveClasses = [ 564 | Primitive, 565 | PrimitiveNode, 566 | Point, 567 | Line, 568 | Polygon, 569 | Quadrilateral, 570 | Axes, 571 | ] 572 | for _PrimitiveClass in _PrimitiveClasses: 573 | _flatten_primitive, _unflatten_primitive = cn._make_flatten_unflatten(_PrimitiveClass) 574 | jax.tree_util.register_pytree_node( 575 | _PrimitiveClass, _flatten_primitive, _unflatten_primitive 576 | ) 577 | 578 | 579 | ## Primitive value vector methods 580 | # These are used to get the primitive parameter vector from a primitive 581 | # and to update primitives with new parameters 582 | 583 | def filter_unique_values_from_prim( 584 | root_prim: Primitive, 585 | ) -> tuple[dict[str, int], list[Primitive]]: 586 | """ 587 | Return unique primitives from a root primitive and indicate their indices 588 | 589 | Note that primitives in a primitive node are not necessarily unique 590 | ; for example `Point`s are shared between lines in a polygon. 
591 | 592 | When solving a set of geometric constraints, the geometric constraint 593 | residual should be linked to a function of unique primitives only. 594 | 595 | Returns 596 | ------- 597 | prim_to_idx: dict[str, int] 598 | A mapping from each primitive key to its unique primitive index 599 | prims: list[Primitive] 600 | A list of unique primitives 601 | """ 602 | value_id_to_idx = {} 603 | values = [] 604 | prim_to_idx = {} 605 | 606 | for key, prim in cn.iter_flat("", root_prim): 607 | value_id = id(prim.value) 608 | 609 | if value_id not in value_id_to_idx: 610 | values.append(prim.value) 611 | value_idx = len(values) - 1 612 | value_id_to_idx[value_id] = value_idx 613 | else: 614 | value_idx = value_id_to_idx[value_id] 615 | 616 | prim_to_idx[key] = value_idx 617 | 618 | return prim_to_idx, values 619 | 620 | def build_prim_from_unique_values( 621 | flat_prim: list[cn.FlatNodeStructure], prim_to_idx: dict[str, int], values: list[NDArray] 622 | ) -> Primitive: 623 | """ 624 | Return a new primitive with values updated from unique values 625 | 626 | Parameters 627 | ---------- 628 | flat_prim: list[FlatNodeStructure] 629 | The flat primitive tree (see `flatten`) 630 | prim_to_idx: dict[str, int] 631 | A mapping from each primitive key to a unique primitive value in `values` 632 | values: list[NDArray] 633 | A list of primitive values for unique primitives in `root_prim` 634 | 635 | Returns 636 | ------- 637 | Primitive 638 | The new primitive with updated values 639 | """ 640 | prim_keys = (flat_struct[0] for flat_struct in flat_prim) 641 | new_prim_values = (values[prim_to_idx[key]] for key in prim_keys) 642 | 643 | new_prim_structs = [ 644 | (prim_key, PrimType, new_value, child_keys) 645 | for (prim_key, PrimType, _old_value, child_keys), new_value 646 | in zip(flat_prim, new_prim_values) 647 | ] 648 | return cn.unflatten(new_prim_structs)[0] 649 | -------------------------------------------------------------------------------- /src/mpllayout/solver.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Solvers for constrained geometric primitives 3 | """ 4 | 5 | from typing import Any 6 | from numpy.typing import NDArray 7 | 8 | import warnings 9 | 10 | import jax 11 | from jax import numpy as jnp 12 | import numpy as np 13 | from scipy.optimize import minimize, OptimizeResult 14 | 15 | from . import primitives as pr 16 | from . import constraints as cr 17 | from . import containers as cn 18 | from . import layout as lay 19 | 20 | IntGraph = list[tuple[int, ...]] 21 | StrGraph = list[tuple[str, ...]] 22 | 23 | SolverInfo = dict[str, Any] 24 | 25 | def solve( 26 | layout: lay.Layout, 27 | abs_tol: float = 1e-10, 28 | rel_tol: float = 1e-7, 29 | max_iter: int = 10, 30 | method: str='newton' 31 | ) -> tuple[pr.PrimitiveNode, SolverInfo]: 32 | """ 33 | Return geometric primitives that satisfy constraints 34 | 35 | This solves a set of, potentially non-linear, geometric constraints with an 36 | iterative method. 37 | 38 | Parameters 39 | ---------- 40 | layout: lay.Layout 41 | The layout of geometric primitives and constraints to solve 42 | abs_tol, rel_tol: float 43 | The absolute and relative tolerance for the iterative solution 44 | max_iter: int 45 | The maximum number of iterations for the iterative solution 46 | method: Optional[str] 47 | A solver method (one of 'newton', 'minimize') 48 | 49 | Returns 50 | ------- 51 | pr.PrimitiveNode 52 | A primitive tree satisfying the constraints 53 | SolverInfo 54 | Information about the iterative solution 55 | 56 | Keys are: 57 | 'abs_errs': 58 | A list of absolute errors for each solver iteration. 59 | This is the 2-norm of the constraint residual vector. 60 | 'rel_errs': 61 | A list of relative errors for each solver iteration. 62 | This is the absolute error at each iteration, relative to the 63 | initial absolute error. 
64 | """ 65 | if method == 'newton': 66 | return solve_newton(layout, abs_tol, rel_tol, max_iter) 67 | elif method == 'minimize': 68 | return solve_minimize(layout, abs_tol, rel_tol, max_iter) 69 | else: 70 | raise ValueError(f"Invalid `method` {method}") 71 | 72 | # TODO: Add sparse jacobian assembly and solving via sparse least squares 73 | 74 | def solve_newton( 75 | layout: lay.Layout, 76 | abs_tol: float = 1e-10, 77 | rel_tol: float = 1e-7, 78 | max_iter: int = 10, 79 | ) -> tuple[pr.PrimitiveNode, SolverInfo]: 80 | """ 81 | Return geometric primitives that satisfy constraints using a newton method 82 | 83 | Parameters 84 | ---------- 85 | Parameters match those for `solve` except for `method` 86 | 87 | See `solve` for more details. 88 | 89 | Returns 90 | ------- 91 | Returns match those for `solve` 92 | """ 93 | 94 | ## Set-up assembly function for the global residual as a function of a global 95 | ## parameter list 96 | 97 | # `prim_idx_bounds` stores the right/left indices for each primitive's 98 | # parameter vector in the global parameter vector array 99 | # For primitive with index `n`, for example, 100 | # `prim_idx_bounds[n], prim_idx_bounds[n+1]` are the indices between which 101 | # the parameter vectors are stored. 
102 | flat_prim = cn.flatten('', layout.root_prim) 103 | prim_graph, prim_values = pr.filter_unique_values_from_prim(layout.root_prim) 104 | prim_idx_bounds = np.cumsum([0] + [value.size for value in prim_values]) 105 | global_param_n = np.concatenate(prim_values) 106 | 107 | constraints, constraint_graph, constraint_params = layout.flat_constraints() 108 | 109 | @jax.jit 110 | def assem_global_res(global_param): 111 | new_prim_params = [ 112 | global_param[idx_start:idx_end] 113 | for idx_start, idx_end in zip(prim_idx_bounds[:-1], prim_idx_bounds[1:]) 114 | ] 115 | root_prim = pr.build_prim_from_unique_values(flat_prim, prim_graph, new_prim_params) 116 | residuals = assem_constraint_residual( 117 | root_prim, constraints, constraint_graph, constraint_params 118 | ) 119 | return jnp.concatenate(residuals) 120 | 121 | assem_global_jac = jax.jacfwd(assem_global_res) 122 | 123 | ## Iteratively minimize the global residual as function of the global parameter vector 124 | abs_errs = [] 125 | rel_errs = [] 126 | 127 | n = 0 128 | abs_err = np.inf 129 | rel_err = np.inf 130 | while (abs_err > abs_tol) and (rel_err > rel_tol) and (n < max_iter): 131 | 132 | global_res = assem_global_res(global_param_n) 133 | global_jac = assem_global_jac(global_param_n) 134 | 135 | dglobal_param, err, rank, s = np.linalg.lstsq( 136 | global_jac, -global_res, rcond=None 137 | ) 138 | global_param_n = global_param_n + dglobal_param 139 | 140 | n += 1 141 | abs_err = np.linalg.norm(global_res) 142 | abs_errs.append(abs_err) 143 | with warnings.catch_warnings(): 144 | warnings.simplefilter("ignore", RuntimeWarning) 145 | rel_err = abs_errs[-1] / abs_errs[0] 146 | rel_errs.append(rel_err) 147 | 148 | nonlinear_solve_info = {"abs_errs": abs_errs, "rel_errs": rel_errs} 149 | 150 | ## Build a new primitive tree from the global parameter vector 151 | prim_params_n = [ 152 | np.array(global_param_n[idx_start:idx_end]) 153 | for idx_start, idx_end in zip(prim_idx_bounds[:-1], prim_idx_bounds[1:]) 
154 | ] 155 | root_prim_n = pr.build_prim_from_unique_values(flat_prim, prim_graph, prim_params_n) 156 | 157 | return root_prim_n, nonlinear_solve_info 158 | 159 | 160 | def solve_minimize( 161 | layout: lay.Layout, 162 | abs_tol: float = 1e-10, 163 | rel_tol: float = 1e-7, 164 | max_iter: int = 10, 165 | ) -> tuple[pr.PrimitiveNode, SolverInfo]: 166 | """ 167 | Return geometric primitives that satisfy constraints using minimization (L-BFGS-B) 168 | 169 | The minimization strategies are from `scipy`. 170 | 171 | Parameters 172 | ---------- 173 | Parameters match those for `solve` except for `method` 174 | 175 | See `solve` for more details. 176 | 177 | Returns 178 | ------- 179 | Returns match those for `solve` 180 | """ 181 | 182 | ## Set-up assembly function for the global residual as a function of a global 183 | ## parameter list 184 | 185 | # `prim_idx_bounds` stores the right/left indices for each primitive's 186 | # parameter vector in the global parameter vector array 187 | # For primitive with index `n`, for example, 188 | # `prim_idx_bounds[n], prim_idx_bounds[n+1]` are the indices between which 189 | # the parameter vectors are stored. 
190 | flat_prim = cn.flatten('', layout.root_prim) 191 | prim_graph, prim_values = pr.filter_unique_values_from_prim(layout.root_prim) 192 | prim_idx_bounds = np.cumsum([0] + [value.size for value in prim_values]) 193 | global_param_n = np.concatenate(prim_values) 194 | 195 | constraints, constraint_graph, constraint_params = layout.flat_constraints() 196 | 197 | @jax.jit 198 | def assem_objective(global_param): 199 | new_prim_params = [ 200 | global_param[idx_start:idx_end] 201 | for idx_start, idx_end in zip(prim_idx_bounds[:-1], prim_idx_bounds[1:]) 202 | ] 203 | root_prim = pr.build_prim_from_unique_values(flat_prim, prim_graph, new_prim_params) 204 | residuals = assem_constraint_residual( 205 | root_prim, constraints, constraint_graph, constraint_params 206 | ) 207 | return jnp.sum(jnp.concatenate(residuals)**2) 208 | 209 | class MinHistory: 210 | 211 | def __init__(self): 212 | self.abs_errs = [] 213 | self.rel_errs = [] 214 | 215 | def callback(self, intermediate_result: OptimizeResult): 216 | abs_err = intermediate_result['fun'] 217 | self.abs_errs.append(abs_err) 218 | 219 | rel_err = abs_err / self.abs_errs[0] 220 | self.rel_errs.append(rel_err) 221 | 222 | if rel_err < rel_tol or abs_err < abs_tol: 223 | raise StopIteration() 224 | 225 | min_hist = MinHistory() 226 | 227 | ## Iteratively minimize the global residual as function of the global parameter vector 228 | 229 | # TODO: (not critical) Implement other optimization solvers besides 'L-BFGS-B' 230 | res = minimize( 231 | jax.value_and_grad(assem_objective), 232 | global_param_n, 233 | method='L-BFGS-B', 234 | jac=True, 235 | callback=min_hist.callback, 236 | options={'maxiter': max_iter} 237 | ) 238 | global_param_n = res['x'] 239 | 240 | prim_params_n = [ 241 | np.array(global_param_n[idx_start:idx_end]) 242 | for idx_start, idx_end in zip(prim_idx_bounds[:-1], prim_idx_bounds[1:]) 243 | ] 244 | root_prim_n = pr.build_prim_from_unique_values(flat_prim, prim_graph, prim_params_n) 245 | 246 | 
nonlinear_solve_info = { 247 | "abs_errs": min_hist.abs_errs, "rel_errs": min_hist.rel_errs 248 | } 249 | 250 | return root_prim_n, nonlinear_solve_info 251 | 252 | 253 | def assem_constraint_residual( 254 | root_prim: pr.Primitive, 255 | constraints: list[cr.Constraint], 256 | constraint_graph: list[cr.PrimKeys], 257 | constraint_params: list[cr.Params] 258 | ) -> list[NDArray]: 259 | """ 260 | Return a list of constraint residual vectors 261 | 262 | Parameters 263 | ---------- 264 | root_prim: pr.Primitive 265 | The primitive which constraints act on 266 | constraints: list[cr.Constraint] 267 | A list of constraints 268 | constraint_graph: list[cr.PrimKeys] 269 | A list of keys indicating primitives in `root_prim` for each constraint 270 | constraint_params: list[cr.ResParams] 271 | A list of parameters for each constraint 272 | 273 | Returns 274 | ------- 275 | residuals: list[NDArray] 276 | A list of constraint residual vectors 277 | 278 | Each residual vector corresponds to a constraint in `constraints` 279 | """ 280 | residuals = [ 281 | constraint(tuple(root_prim[key] for key in prim_keys), *param) 282 | for constraint, prim_keys, param in zip(constraints, constraint_graph, constraint_params) 283 | ] 284 | return residuals 285 | -------------------------------------------------------------------------------- /src/mpllayout/ui.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for visualizing primitives and constraints 3 | """ 4 | 5 | from typing import Callable, Optional 6 | from matplotlib.figure import Figure 7 | from matplotlib.axes import Axes 8 | 9 | import matplotlib as mpl 10 | from matplotlib.patches import Polygon 11 | from matplotlib import pyplot as plt 12 | from matplotlib.colors import Colormap 13 | import numpy as np 14 | 15 | from . import primitives as pr 16 | from . import constraints as cr 17 | from . 
import constructions as cn 18 | 19 | ## Functions for plotting geometric primitives 20 | 21 | 22 | def plot_point(ax: Axes, point: pr.Point, label: Optional[str]=None, **kwargs): 23 | """ 24 | Plot a point 25 | 26 | Parameters 27 | ---------- 28 | ax: Axes 29 | The axes to plot in 30 | point: pr.Point 31 | The point to plot 32 | label: Optional[str] 33 | A label 34 | **kwargs 35 | Additional keyword arguments for plotting 36 | """ 37 | x, y = point.value 38 | line, = ax.plot([x], [y], marker=".", **kwargs) 39 | ax.annotate(label, (x, y), ha='center', **kwargs) 40 | 41 | def rotation_from_line(line: pr.Line) -> float: 42 | """ 43 | Return the rotation of a line vector 44 | """ 45 | line_vec = cn.LineVector.assem((line,)) 46 | unit_vec = line_vec / np.linalg.norm(line_vec) 47 | 48 | # Since `unit_vec` has unit length, the x-component is the cosine 49 | theta = 180/np.pi * np.arccos(unit_vec[0]) 50 | if unit_vec[1] < 0: 51 | theta = theta + 180 52 | 53 | return theta 54 | 55 | def plot_line(ax: Axes, line: pr.Line, label: Optional[str]=None, **kwargs): 56 | """ 57 | Plot a line 58 | 59 | Parameters 60 | ---------- 61 | ax: Axes 62 | The axes to plot in 63 | line: pr.Line 64 | The line to plot 65 | label: Optional[str] 66 | A label 67 | **kwargs 68 | Additional keyword arguments for plotting 69 | """ 70 | # Don't plot zero length lines 71 | if cn.Length.assem((line,)) != 0: 72 | xs = np.array([point.value[0] for point in line.values()]) 73 | ys = np.array([point.value[1] for point in line.values()]) 74 | ax.plot(xs, ys, **kwargs) 75 | 76 | xmid = 1/2*xs.sum() 77 | ymid = 1/2*ys.sum() 78 | theta = rotation_from_line(line) 79 | ax.annotate(label, (xmid, ymid), ha='center', va='baseline', rotation=theta, **kwargs) 80 | 81 | def plot_polygon(ax: Axes, polygon: pr.Polygon, label: Optional[str]=None, **kwargs): 82 | """ 83 | Plot a `Polygon` 84 | 85 | Parameters 86 | ---------- 87 | ax: Axes 88 | The axes to plot in 89 | polygon: pr.Polygon 90 | The polygon to plot 91 | 
label: Optional[str] 92 | A label 93 | **kwargs 94 | Additional keyword arguments for plotting 95 | """ 96 | origin = polygon[f"Line0"]["Point0"].value 97 | points = [ 98 | polygon[f"Line{ii}"]["Point0"] for ii in range(len(polygon)) 99 | ] 100 | verts = np.array([point.value for point in points]) 101 | patch_kwargs = kwargs.copy() 102 | patch_kwargs['alpha'] = 0.1*kwargs['alpha'] 103 | poly_patch = Polygon(verts, closed=True, **patch_kwargs) 104 | ax.add_patch(poly_patch) 105 | 106 | # (line,) = ax.plot(xs, ys, **kwargs) 107 | if label is not None: 108 | # Place the label at the first point 109 | ax.annotate( 110 | label, 111 | origin, 112 | xycoords="data", 113 | xytext=(2.0, 2.0), 114 | textcoords="offset points", 115 | ha="left", 116 | va="bottom", 117 | **kwargs 118 | ) 119 | 120 | def plot_generic_prim(ax: Axes, prim: pr.Primitive, label: Optional[str]=None, **kwargs): 121 | pass 122 | 123 | def plot_prim( 124 | ax: Axes, 125 | prim: pr.Primitive, 126 | prim_key: str='', 127 | max_label_depth: int = 99, 128 | **kwargs 129 | ): 130 | """ 131 | Recursively plot all child primitives of a generic primitive 132 | 133 | Parameters 134 | ---------- 135 | ax: Axes 136 | The axes to plot in 137 | prim: pr.Primitive 138 | The primitive to plot 139 | label: Optional[str] 140 | A label 141 | **kwargs 142 | Additional keyword arguments for plotting 143 | """ 144 | split_key = prim_key.split("/") 145 | 146 | # The primitive heigth is the maximum primitive depth for all child 147 | # primitives. 148 | prim_height = prim.node_height() + len(split_key) - 1 149 | 150 | # The primitive depth is how far away from the root node the primitive is. 151 | depth = len(split_key) - 1 152 | 153 | # Add labels for primitives before a maximum depth 154 | if depth > max_label_depth or prim_key == '': 155 | label = None 156 | else: 157 | parent_key = depth*"." 158 | label = f"{parent_key}/{split_key[-1]}" 159 | 160 | # Use the prim height to control opacity. 
161 | # The root primitive has 100% opacity while deeper child primitives 162 | # have lower opacity. 163 | if prim_height == 0: 164 | s = 1 165 | else: 166 | s = (prim_height - depth)/prim_height 167 | alpha = 1*s + 0.2*(1-s) 168 | 169 | plot = make_plot(prim) 170 | plot(ax, prim, label=label, alpha=alpha, **kwargs) 171 | 172 | if isinstance(prim, pr.Line): 173 | pass 174 | # NOTE: I skipped plotting line start and end points because the labels 175 | # get crowded 176 | # TODO: Implement nice plots of points in lines 177 | # Should try to avoid label overlap 178 | else: 179 | for child_key, child_prim in prim.items(): 180 | plot_prim( 181 | ax, 182 | child_prim, 183 | prim_key=f'{prim_key}/{child_key}', 184 | max_label_depth=max_label_depth, 185 | **kwargs 186 | ) 187 | 188 | 189 | ## Functions for plotting arbitrary geometric primitives 190 | def make_plot( 191 | prim: pr.Primitive, 192 | ) -> Callable[[Axes, tuple[pr.Primitive, ...]], None]: 193 | """ 194 | Return a function that can plot a `pr.Primitive` object 195 | 196 | Parameters 197 | ---------- 198 | prim: pr.Primitive 199 | The primitive to plot 200 | 201 | Returns 202 | ------- 203 | Callable[[Axes, tuple[pr.Primitive, ...]], None] 204 | A function that can plot the primitive 205 | 206 | This function is one of the above `plot_...` function 207 | (see `plot_point`, `plot_line`, etc.). 
208 | """ 209 | if isinstance(prim, pr.Point): 210 | return plot_point 211 | elif isinstance(prim, pr.Line): 212 | return plot_line 213 | elif isinstance(prim, pr.Polygon): 214 | return plot_polygon 215 | else: 216 | return plot_generic_prim 217 | 218 | 219 | def plot_prims(ax: Axes, root_prim: pr.Primitive, cmap: Colormap=mpl.colormaps['viridis']): 220 | """ 221 | Plot all child primitives in a root primitive 222 | 223 | Parameters 224 | ---------- 225 | ax: Axes 226 | The axes to plot in 227 | root_prim: pr.Primitive 228 | The primitive to plot 229 | """ 230 | num_prims = len(root_prim) 231 | for ii, (key, prim) in enumerate(root_prim.items()): 232 | color = cmap(ii / num_prims) 233 | plot_prim(ax, prim, prim_key=key, color=color) 234 | 235 | 236 | def figure_prims( 237 | root_prim: pr.Primitive, 238 | fig_size: tuple[float, float] = (8, 8), 239 | major_tick_interval: float = 1.0, 240 | minor_tick_interval: float = 1/8 241 | ) -> tuple[Figure, Axes]: 242 | """ 243 | Return a figure of a primitive 244 | 245 | Parameters 246 | ---------- 247 | root_prim: pr.Primitive 248 | The primitive to plot 249 | fig_size: tuple[float, float] 250 | The figure size 251 | major_tick_interval, minor_tick_interval: float, float 252 | Major and minor tick intervals for grid lines 253 | 254 | By default these are 1 and 1/8 which is nice for inch dimensions. 
255 | 256 | Returns 257 | ------- 258 | fig: Figure 259 | The figure 260 | ax: Axes 261 | The axes 262 | """ 263 | 264 | fig, ax = plt.subplots(1, 1, figsize=fig_size) 265 | 266 | for axis in (ax.xaxis, ax.yaxis): 267 | axis.set_minor_locator(mpl.ticker.MultipleLocator(minor_tick_interval)) 268 | axis.set_major_locator(mpl.ticker.MultipleLocator(major_tick_interval)) 269 | 270 | ax.set_aspect(1) 271 | ax.grid() 272 | 273 | ax.set_xlabel("x [in]") 274 | ax.set_ylabel("y [in]") 275 | 276 | plot_prims(ax, root_prim) 277 | 278 | return (fig, ax) 279 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jon-deng/mpl-layout/6a10aeff3dfa6e989cc8d9beefa55b78527ebe6a/tests/__init__.py -------------------------------------------------------------------------------- /tests/fixture_primitives.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fixtures to create primitives 3 | """ 4 | 5 | from numpy.typing import NDArray 6 | 7 | import pytest 8 | import itertools 9 | 10 | import numpy as np 11 | 12 | from mpllayout import primitives as pr 13 | 14 | class GeometryFixtures: 15 | """ 16 | Utilities to help create primitives 17 | """ 18 | 19 | ## Coordinate axes 20 | 21 | @pytest.fixture( 22 | params=[ 23 | ('x', np.array([1, 0])), 24 | ('y', np.array([0, 1])) 25 | ] 26 | ) 27 | def axis_name_dir(self, request): 28 | return request.param 29 | 30 | @pytest.fixture() 31 | def axis_name(self, axis_name_dir): 32 | return axis_name_dir[0] 33 | 34 | @pytest.fixture() 35 | def axis_dir(self, axis_name_dir): 36 | return axis_name_dir[1] 37 | 38 | ## Point creation 39 | def make_point(self, coord): 40 | """ 41 | Return a `pr.Point` at the given coordinates 42 | """ 43 | return pr.Point(value=coord) 44 | 45 | def make_relative_point(self, point: pr.Point, displacement: NDArray): 46 | 
""" 47 | Return a `pr.Point` displaced from a given point 48 | """ 49 | return pr.Point(value=point.value + displacement) 50 | 51 | ## Line creation 52 | def make_rotation(self, theta: float): 53 | rot_mat = np.array( 54 | [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] 55 | ) 56 | return rot_mat 57 | 58 | def make_line(self, origin: NDArray, line_vec: NDArray): 59 | """ 60 | Return a `pr.Line` with given origin and line vector 61 | """ 62 | coords = (origin, origin + line_vec) 63 | return pr.Line(value=[], prims=tuple(pr.Point(x) for x in coords)) 64 | 65 | def make_relative_line( 66 | self, line: pr.Line, translation: NDArray, deformation: NDArray 67 | ): 68 | """ 69 | Return a `pr.Line` deformed about it's start point then translated 70 | """ 71 | lineb_vec = line[1].value - line[0].value 72 | lineb_vec = deformation @ lineb_vec 73 | 74 | lineb_start = line[0].value + translation 75 | return self.make_line(lineb_start, lineb_vec) 76 | 77 | def make_relline_about_mid( 78 | self, line: pr.Line, translation: NDArray, deformation: NDArray 79 | ): 80 | """ 81 | Return a `pr.Line` deformed about it's midpoint then translated 82 | """ 83 | lineb_vec = line[1].value - line[0].value 84 | lineb_vec = deformation @ lineb_vec 85 | 86 | lineb_mid = 1/2*(line[0].value + line[1].value) + translation 87 | lineb_start = lineb_mid - lineb_vec/2 88 | return self.make_line(lineb_start, lineb_vec) 89 | 90 | ## Quadrilateral creation 91 | def make_quad(self, displacement, deformation): 92 | """ 93 | Return a `pr.Quadrilateral` translated and deformed from a unit quadrilateral 94 | """ 95 | # Specify vertices of a unit square, then deform it and translate it 96 | verts = np.array([[0, 0], [1, 0], [1, 1], [0, 1]]) 97 | verts = np.tensordot(verts, deformation, axes=(-1, -1)) 98 | verts = verts + displacement 99 | 100 | return pr.Quadrilateral( 101 | value=[], children=tuple(pr.Point(vert) for vert in verts) 102 | ) 103 | 104 | def make_quad_grid( 105 | self, 106 | 
translation: NDArray, 107 | col_margins: NDArray, 108 | row_margins: NDArray, 109 | col_widths: NDArray, 110 | row_heights: NDArray, 111 | ): 112 | """ 113 | Return a grid of `Quadrilateral`s with shape (M, N) 114 | 115 | Rows and columns are numbered from top to bottom and left to right, respectively. 116 | 117 | Parameters 118 | ---------- 119 | col_margins, row_margins: NDArray (N-1,), (M-1,) 120 | Column and row margins 121 | col_widths, row_heights: NDArray (N,), (M,) 122 | Column and row dimensions 123 | """ 124 | # Determine translations/transformations needed for each quad 125 | col_defs = col_widths[:, None, None] * np.outer([1, 0], [1, 0]) + np.outer( 126 | [0, 1], [0, 1] 127 | ) 128 | row_defs = np.outer([1, 0], [1, 0]) + row_heights[:, None, None] * np.outer( 129 | [0, 1], [0, 1] 130 | ) 131 | 132 | cum_col_widths = np.cumsum( 133 | np.concatenate((translation[[0]], col_widths[:-1] + col_margins)) 134 | ) 135 | col_trans = np.stack([np.array([x, 0]) for x in cum_col_widths], axis=0) 136 | cum_row_heights = np.cumsum( 137 | np.concatenate((translation[[1]], -row_heights[1:] - row_margins)) 138 | ) 139 | row_trans = np.stack([np.array([0, y]) for y in cum_row_heights], axis=0) 140 | 141 | row_args = zip(row_trans, row_defs) 142 | col_args = zip(col_trans, col_defs) 143 | 144 | quads = tuple( 145 | self.make_quad(drow + dcol, row_def @ col_def) 146 | for (drow, row_def), (dcol, col_def) in itertools.product( 147 | row_args, col_args 148 | ) 149 | ) 150 | return quads -------------------------------------------------------------------------------- /tests/test_constructions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test geometric onstraints 3 | """ 4 | 5 | import pytest 6 | 7 | from numpy.typing import NDArray 8 | 9 | import numpy as np 10 | 11 | from mpllayout import primitives as pr 12 | from mpllayout import constructions as con 13 | from mpllayout.containers import Node, accumulate 14 | 15 | from 
tests.fixture_primitives import GeometryFixtures 16 | 17 | class TestConstructionValidation(GeometryFixtures): 18 | 19 | def test_validate_prims(self): 20 | midpoint = con.Midpoint() 21 | 22 | line = self.make_line(np.random.rand(2), np.random.rand(2)) 23 | point_a = self.make_point(np.random.rand(2)) 24 | point_b = self.make_point(np.random.rand(2)) 25 | 26 | # midpoint((point_a, point_b)) 27 | 28 | with pytest.raises(TypeError, match=r"Expected tuple of primitives.*") as exc: 29 | midpoint(line) 30 | print(exc.value) 31 | 32 | with pytest.raises(TypeError, match=r"Expected [0-9]* primitives.*") as exc: 33 | midpoint((point_a, point_b)) 34 | print(exc.value) 35 | 36 | with pytest.raises(TypeError, match=r"Expected primitive types.*") as exc: 37 | midpoint((point_a,)) 38 | print(exc.value) 39 | 40 | class TestConstructionFunctions(GeometryFixtures): 41 | 42 | def test_transform_map(self): 43 | num_point = 2 44 | coords = np.random.rand(num_point, 2) 45 | points = tuple(pr.Point(value=coord) for coord in coords) 46 | 47 | coords_ref = np.concatenate([con.Coordinate()((point,)) for point in points]) 48 | map_coordinate = con.transform_map(con.Coordinate(), num_point*[pr.Point]) 49 | coords_map = map_coordinate(points) 50 | 51 | assert np.all(np.isclose(coords_ref, coords_map)) 52 | 53 | def test_transform_constraint(self): 54 | 55 | construction = con.OuterMargin(side='right') 56 | quada = self.make_quad(np.zeros(2), np.diag(np.ones(2))) 57 | quadb = self.make_quad(np.array([1.5, 0]), np.diag(np.ones(2))) 58 | prims = (quada, quadb) 59 | 60 | value = construction(prims) 61 | 62 | constraint = con.transform_constraint(construction) 63 | res = constraint(prims, value) 64 | assert np.all(np.isclose(res, 0)) 65 | 66 | construction = con.Coordinate() 67 | point = self.make_point(np.random.rand(2)) 68 | prims = (point,) 69 | 70 | value = construction(prims) 71 | 72 | constraint = con.transform_constraint(construction) 73 | res = constraint(prims, value) 74 | assert 
np.all(np.isclose(res, 0)) 75 | 76 | def test_transform_sum(self): 77 | construction = con.Coordinate() 78 | sum_construction = con.transform_sum(construction, construction) 79 | 80 | prims_a = (self.make_point(np.random.rand(2)),) 81 | prims_b = (self.make_point(np.random.rand(2)),) 82 | params_a = () 83 | params_b = () 84 | 85 | res_a = construction(prims_a, *params_a) + construction(prims_b, *params_b) 86 | res_b = sum_construction(prims_a + prims_b, *(params_a + params_b)) 87 | 88 | assert np.all(np.isclose(res_a, res_b)) 89 | 90 | res_b = (construction + construction)(prims_a+prims_b, *(params_a + params_b)) 91 | 92 | assert np.all(np.isclose(res_a, res_b)) 93 | 94 | def test_transform_scalar_mul(self): 95 | construction = con.Coordinate() 96 | 97 | prims = (self.make_point(np.random.rand(2)),) 98 | cons_params = () 99 | scalar = np.random.rand() 100 | res_a = scalar * construction(prims, *cons_params) 101 | 102 | # Test constant version 103 | params = cons_params + () 104 | 105 | mul_construction = con.transform_scalar_mul(construction, scalar) 106 | res_b = mul_construction(prims, *params) 107 | assert np.all(np.isclose(res_a, res_b)) 108 | 109 | mul_construction = scalar * construction 110 | res_b = mul_construction(prims, *params) 111 | assert np.all(np.isclose(res_a, res_b)) 112 | 113 | # Test non constant version 114 | params = cons_params + (scalar,) 115 | 116 | mul_construction = con.transform_scalar_mul(construction, con.Scalar()) 117 | res_b = mul_construction(prims, *params) 118 | assert np.all(np.isclose(res_a, res_b)) 119 | 120 | mul_construction = con.Scalar() * construction 121 | res_b = mul_construction(prims, *params) 122 | assert np.all(np.isclose(res_a, res_b)) 123 | 124 | def test_transform_scalar_mul(self): 125 | construction = con.Coordinate() 126 | 127 | prims = (self.make_point(np.random.rand(2)),) 128 | cons_params = () 129 | scalar = np.random.rand() 130 | res_a = construction(prims, *cons_params) / scalar 131 | 132 | # Test constant 
version 133 | params = cons_params + () 134 | 135 | div_construction = con.transform_scalar_div(construction, scalar) 136 | res_b = div_construction(prims, *params) 137 | assert np.all(np.isclose(res_a, res_b)) 138 | 139 | div_construction = construction / scalar 140 | res_b = div_construction(prims, *params) 141 | assert np.all(np.isclose(res_a, res_b)) 142 | 143 | # Test non constant version 144 | params = cons_params + (scalar,) 145 | 146 | div_construction = con.transform_scalar_div(construction, con.Scalar()) 147 | res_b = div_construction(prims, *params) 148 | assert np.all(np.isclose(res_a, res_b)) 149 | 150 | div_construction = construction / con.Scalar() 151 | res_b = div_construction(prims, *params) 152 | assert np.all(np.isclose(res_a, res_b)) 153 | 154 | def test_transform_scalar_pow(self): 155 | construction = con.Coordinate() 156 | 157 | prims = (self.make_point(np.random.rand(2)),) 158 | cons_params = () 159 | scalar = np.random.rand() 160 | res_a = construction(prims, *cons_params) ** scalar 161 | 162 | # Test constant version 163 | params = cons_params + () 164 | 165 | pow_construction = con.transform_scalar_pow(construction, scalar) 166 | res_b = pow_construction(prims, *params) 167 | assert np.all(np.isclose(res_a, res_b)) 168 | 169 | pow_construction = construction ** scalar 170 | res_b = pow_construction(prims, *params) 171 | assert np.all(np.isclose(res_a, res_b)) 172 | 173 | # Test non constant version 174 | params = cons_params + (scalar,) 175 | 176 | pow_construction = con.transform_scalar_pow(construction, con.Scalar()) 177 | res_b = pow_construction(prims, *params) 178 | assert np.all(np.isclose(res_a, res_b)) 179 | 180 | pow_construction = construction ** con.Scalar() 181 | res_b = pow_construction(prims, *params) 182 | assert np.all(np.isclose(res_a, res_b)) 183 | 184 | def test_transform_partial(self): 185 | cons = con.DirectedLength() 186 | 187 | line = self.make_line(np.zeros(2), np.random.rand(2)) 188 | 189 | direction = 
np.random.rand(2) 190 | direction = direction / np.linalg.norm(direction) 191 | length_ref = cons((line,), direction) 192 | 193 | length_test = con.transform_partial(cons, direction)((line,)) 194 | 195 | assert np.all(np.isclose(length_test, length_ref)) 196 | 197 | cons = con.Width() 198 | 199 | quad = self.make_quad( 200 | np.zeros(2), np.diag(np.random.rand(2)) 201 | ) 202 | 203 | res_ref = cons((quad,)) 204 | 205 | res_test = con.transform_partial(cons)((quad,)) 206 | 207 | assert np.all(np.isclose(res_test, res_ref)) 208 | 209 | 210 | class TestNull: 211 | 212 | @pytest.fixture() 213 | def size_node(self): 214 | node = Node( 215 | 0, 216 | { 217 | 'Vec1': Node(1, {}), 218 | 'Vec2': Node(2, {}), 219 | 'Vec3': Node( 220 | 0, 221 | {'Vec4': Node(1, {})} 222 | ) 223 | } 224 | ) 225 | return node 226 | 227 | @pytest.fixture() 228 | def vector(self, size_node: Node[int]): 229 | cumsize_node = accumulate(lambda x, y: x + y, size_node, 0) 230 | return np.random.rand(cumsize_node.value) 231 | 232 | def test_Vector(self, size_node: Node[int], vector: NDArray): 233 | vec = con.Vector(size_node) 234 | vector_test = vec.assem((), vector) 235 | print(vector_test, vector) 236 | 237 | assert np.all(np.isclose(vector_test, vector)) 238 | 239 | 240 | class TestPoint(GeometryFixtures): 241 | """ 242 | Test constructions with signature `[Point]` 243 | """ 244 | 245 | @pytest.fixture() 246 | def coordinate(self): 247 | return np.random.rand(2) 248 | 249 | def test_Coordinate(self, coordinate): 250 | point = self.make_point(coordinate) 251 | res = con.Coordinate()((point,)) - coordinate 252 | assert np.all(np.isclose(res, 0)) 253 | 254 | 255 | class TestLine(GeometryFixtures): 256 | """ 257 | Test constraints with signature `[Line]` 258 | """ 259 | 260 | @pytest.fixture() 261 | def length(self): 262 | return np.random.rand() 263 | 264 | @pytest.fixture() 265 | def direction(self): 266 | unit_vec = np.random.rand(2) 267 | unit_vec = unit_vec / np.linalg.norm(unit_vec) 268 | return 
unit_vec 269 | 270 | @pytest.fixture() 271 | def linea(self, length, direction): 272 | origin = np.random.rand(2) 273 | return self.make_line(origin, direction * length) 274 | 275 | def test_Length(self, linea, length): 276 | res = con.Length()((linea,)) - length 277 | assert np.all(np.isclose(res, 0)) 278 | 279 | def test_DirectedLength(self, length, direction): 280 | line_dir = np.random.rand(2) 281 | line_vec = length*line_dir 282 | line = self.make_line((0, 0), line_vec) 283 | 284 | dlength = np.dot(line_vec, direction) 285 | res = con.DirectedLength()((line,), direction) - dlength 286 | assert np.all(np.isclose(res, 0)) 287 | 288 | @pytest.fixture() 289 | def XYLength(self, axis_name): 290 | if axis_name == 'x': 291 | return con.XLength 292 | else: 293 | return con.YLength 294 | 295 | def test_XYLength(self, XYLength, axis_dir, length): 296 | line_dir = np.random.rand(2) 297 | line_vec = length*line_dir 298 | line = self.make_line((0, 0), line_vec) 299 | 300 | dlength = np.dot(line_vec, axis_dir) 301 | 302 | res = XYLength()((line,)) - dlength 303 | 304 | assert np.all(np.isclose(res, 0)) 305 | 306 | 307 | class TestQuadrilateral(GeometryFixtures): 308 | """ 309 | Test constructions with signature `[Quadrilateral]` 310 | """ 311 | 312 | @pytest.fixture() 313 | def quada(self): 314 | return self.make_quad(np.random.rand(2), np.random.rand(2, 2)) 315 | 316 | @pytest.fixture() 317 | def aspect_ratio(self, quada): 318 | width = np.linalg.norm(con.LineVector.assem((quada['Line0'],))) 319 | height = np.linalg.norm(con.LineVector.assem((quada['Line1'],))) 320 | return width/height 321 | 322 | def test_AspectRatio(self, quada: pr.Quadrilateral, aspect_ratio: float): 323 | res = con.AspectRatio()((quada,)) - aspect_ratio 324 | assert np.all(np.isclose(res, 0)) 325 | 326 | 327 | class TestQuadrilateralQuadrilateral(GeometryFixtures): 328 | """ 329 | Test constraints with signature `[Quadrilateral, Quadrilateral]` 330 | """ 331 | 332 | @pytest.fixture() 333 | def 
margin(self): 334 | return np.random.rand() 335 | 336 | @pytest.fixture() 337 | def boxa(self): 338 | size = np.random.rand(2) 339 | origin = np.random.rand(2) 340 | return self.make_quad(origin, np.diag(size)) 341 | 342 | @pytest.fixture() 343 | def boxb(self): 344 | size = np.random.rand(2) 345 | origin = np.random.rand(2) 346 | return self.make_quad(origin, np.diag(size)) 347 | 348 | @pytest.fixture(params=('bottom', 'top', 'left', 'right')) 349 | def margin_side(self, request): 350 | return request.param 351 | 352 | @pytest.fixture() 353 | def outer_margin(self, boxa, boxb, margin_side): 354 | a_topright = boxa['Line1/Point1'].value 355 | b_topright = boxb['Line1/Point1'].value 356 | a_botleft = boxa['Line0/Point0'].value 357 | b_botleft = boxb['Line0/Point0'].value 358 | 359 | if margin_side == 'left': 360 | margin = (a_botleft - b_topright)[0] 361 | elif margin_side == 'right': 362 | margin = (b_botleft - a_topright)[0] 363 | elif margin_side == 'bottom': 364 | margin = (a_botleft - b_topright)[1] 365 | elif margin_side == 'top': 366 | margin = (b_botleft - a_topright)[1] 367 | return margin 368 | 369 | def test_OuterMargin(self, boxa, boxb, outer_margin, margin_side): 370 | res = con.OuterMargin(side=margin_side)((boxa, boxb)) - outer_margin 371 | assert np.all(np.isclose(res, 0)) 372 | 373 | @pytest.fixture() 374 | def inner_margin(self, boxa, boxb, margin_side): 375 | a_topright = boxa['Line1/Point1'].value 376 | b_topright = boxb['Line1/Point1'].value 377 | a_botleft = boxa['Line0/Point0'].value 378 | b_botleft = boxb['Line0/Point0'].value 379 | 380 | if margin_side == 'left': 381 | margin = (a_botleft - b_botleft)[0] 382 | elif margin_side == 'right': 383 | margin = (b_topright - a_topright)[0] 384 | elif margin_side == 'bottom': 385 | margin = (a_botleft - b_botleft)[1] 386 | elif margin_side == 'top': 387 | margin = (b_topright-a_topright)[1] 388 | return margin 389 | 390 | def test_InnerMargin(self, boxa, boxb, inner_margin, margin_side): 391 | res = 
con.InnerMargin(side=margin_side)((boxa, boxb)) - inner_margin 392 | assert np.all(np.isclose(res, 0)) 393 | 394 | -------------------------------------------------------------------------------- /tests/test_containers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for `mpllayout.containers` 3 | """ 4 | 5 | import pytest 6 | 7 | from timeit import timeit 8 | 9 | import numpy as np 10 | 11 | from mpllayout import containers as cn 12 | 13 | 14 | class NodeFixtures: 15 | 16 | def random_node( 17 | self, 18 | value: float, 19 | depth: int=0, 20 | min_children: int=1, 21 | max_children: int=1, 22 | min_depth: int=0, 23 | max_depth: int=0 24 | ): 25 | num_child = np.random.randint(min_children, max_children) 26 | child_values = np.random.rand(num_child) 27 | 28 | # Give a linear distribution for the probability of stopping 29 | def probability_stop(depth: int): 30 | _p = (depth - min_depth) / (max_depth - min_depth) 31 | return np.clip(_p, 0, 1) 32 | 33 | if np.random.rand() < probability_stop(depth): 34 | children = {} 35 | else: 36 | child_kwargs = { 37 | 'depth': depth+1, 38 | 'min_children': min_children, 39 | 'max_children': max_children, 40 | 'min_depth': min_depth, 41 | 'max_depth': max_depth 42 | } 43 | children = { 44 | f"Child{n:d}": self.random_node(child_value, **child_kwargs) 45 | for n, child_value in enumerate(child_values) 46 | } 47 | return cn.Node(value, children) 48 | 49 | @pytest.fixture(params=(0, 1, 2)) 50 | def num_children(self, request): 51 | return request.param 52 | 53 | @pytest.fixture() 54 | def root_value(self): 55 | return np.random.rand() 56 | 57 | @pytest.fixture() 58 | def child_values(self, num_children: int): 59 | return np.random.rand(num_children) 60 | 61 | @pytest.fixture() 62 | def node(self, root_value: float, child_values: list[float]): 63 | """ 64 | Return nodes with varying numbers of children 65 | 66 | This covers the case of a leaf node, a node with one child, and a node 
67 | with two children. 68 | """ 69 | children = { 70 | f'a{n}': cn.Node(child_value, {}) 71 | for n, child_value in enumerate(child_values) 72 | } 73 | return cn.Node(root_value, children) 74 | 75 | 76 | class TestNode(NodeFixtures): 77 | 78 | def test_node_height(self): 79 | 80 | # Check that a root node has height 0 81 | node = cn.Node(0, {}) 82 | assert node.node_height() == 0 83 | 84 | # Check nodes where one child has a grand child and the other child does not 85 | node = cn.Node(0, {'a1': cn.Node(0, {'b1': cn.Node(0, {})}), 'a2': cn.Node(0, {})}) 86 | assert node.node_height() == 2 87 | 88 | node = cn.Node(0, {'a1': cn.Node(0, {}), 'a2': cn.Node(0, {'b1': cn.Node(0, {})})}) 89 | assert node.node_height() == 2 90 | 91 | def test_repr(self, node: cn.Node): 92 | print(node) 93 | 94 | def test_iter_flat(self, node: cn.Node): 95 | print([{key: _node} for key, _node in cn.iter_flat("", node)]) 96 | 97 | def test_flatten_unflatten_python(self, node: cn.Node): 98 | fnode_structs = cn.flatten("root", node) 99 | reconstructed_node, _ = cn.unflatten(fnode_structs) 100 | 101 | print(fnode_structs) 102 | 103 | assert str(node) == str(reconstructed_node) 104 | print(node) 105 | print(reconstructed_node) 106 | 107 | N = int(1e5) 108 | timeit_kwargs = {"globals": {**globals(), **locals()}, "number": N} 109 | 110 | duration = timeit("cn.flatten('root', node)", **timeit_kwargs) 111 | print(f"Flattening duration: {duration/N: .2e} s") 112 | 113 | duration = timeit("cn.unflatten(fnode_structs)", **timeit_kwargs) 114 | print(f"Unflattening duration: {duration/N: .2e} s") 115 | 116 | def test_flatten_unflatten_jax(self, node: cn.Node): 117 | import jax 118 | 119 | flat_tree, flat_tree_def = jax.tree_util.tree_flatten(node) 120 | reconstructed_node = jax.tree_util.tree_unflatten(flat_tree_def, flat_tree) 121 | 122 | assert str(node) == str(reconstructed_node) 123 | print(node) 124 | print(reconstructed_node) 125 | 126 | N = int(1e5) 127 | timeit_kwargs = {"globals": {**globals(), 
**locals()}, "number": N} 128 | 129 | duration = timeit("jax.tree_util.tree_flatten(node)", **timeit_kwargs) 130 | print(f"Flattening duration: {duration/N: .2e} s") 131 | 132 | duration = timeit( 133 | "jax.tree_util.tree_unflatten(flat_tree_def, flat_tree)", **timeit_kwargs 134 | ) 135 | print(f"Unflattening duration: {duration/N: .2e} s") 136 | 137 | 138 | class TestFunctions(NodeFixtures): 139 | 140 | def test_map(self, node): 141 | def fun(x): 142 | return x+1 143 | 144 | node_test = cn.map(fun, node) 145 | 146 | values_ref = [fun(fnode.value) for _, fnode in cn.iter_flat('', node)] 147 | values_test = [fnode.value for _, fnode in cn.iter_flat('', node_test)] 148 | assert np.all(np.isclose(values_test, values_ref)) 149 | 150 | def test_accumulate(self, node, root_value, child_values): 151 | def fun(x, y): 152 | return x + y 153 | 154 | node_test = cn.accumulate(fun, node, initial=0) 155 | 156 | assert np.isclose(node_test.value, root_value + np.sum(child_values)) 157 | -------------------------------------------------------------------------------- /tests/test_layout.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test `layout` 3 | """ 4 | 5 | import pytest 6 | 7 | from pprint import pprint 8 | 9 | import numpy as np 10 | 11 | from mpllayout import primitives as pr 12 | from mpllayout import constraints as co 13 | from mpllayout import layout as lat 14 | from mpllayout import containers as cn 15 | 16 | 17 | class TestPrimitiveTree: 18 | 19 | @pytest.fixture() 20 | def prim_node(self): 21 | return cn.Node(np.array([]), {}) 22 | 23 | def test_set_prim(self, prim_node): 24 | prim_node.add_child("MyBox", pr.Quadrilateral()) 25 | 26 | pprint(f"Keys:") 27 | pprint(prim_node["MyBox"].keys()) 28 | 29 | def test_build_primtree(self, prim_node): 30 | point_a = pr.Point([0, 0]) 31 | point_b = pr.Point([1, 1]) 32 | prim_node.add_child("PointA", point_a) 33 | prim_node.add_child("LineA", pr.Line([], (point_a, point_b))) 34 | 
prim_node.add_child("MySpecialBox", pr.Quadrilateral()) 35 | 36 | prim_graph, prim_values = pr.filter_unique_values_from_prim(prim_node) 37 | 38 | params = prim_values 39 | 40 | new_params = [np.random.rand(*param.shape) for param in params] 41 | new_prim_node = pr.build_prim_from_unique_values(cn.flatten('', prim_node), prim_graph, new_params) 42 | # breakpoint() 43 | 44 | # rng = np.random.default_rng() 45 | 46 | # new_tree = lat.build_tree(prim_tree, prim_graph, new_params, {}) 47 | 48 | # print("Old primitive graph:") 49 | # pprint(prim_tree.prim_graph()) 50 | 51 | # print("Old primitive list") 52 | # pprint(prim_tree.prims()) 53 | 54 | # print("Old primitive keys") 55 | # pprint(prim_tree.keys(flat=True)) 56 | 57 | # print("New parameters") 58 | # pprint(new_params) 59 | 60 | # print("New primitive graph:") 61 | # pprint(new_tree.prim_graph()) 62 | 63 | # print("New primitive list") 64 | # pprint(new_tree.prims()) 65 | 66 | # print("New primitive keys") 67 | # pprint(new_tree.keys(flat=True)) 68 | 69 | 70 | class TestLayout: 71 | 72 | def test_layout(self): 73 | layout = lat.Layout() 74 | 75 | layout.add_prim(pr.Quadrilateral(), "MyBox") 76 | layout.add_constraint(co.Box(), ("MyBox",), ()) 77 | 78 | layout.add_constraint(co.Fix(), ("MyBox/Line0/Point0",), ([0, 0],)) 79 | 80 | pprint(layout.root_prim) 81 | constraints, constraints_argkeys, constraints_param = layout.flat_constraints() 82 | 83 | print("Flat constraints: ") 84 | print("Constraints:") 85 | pprint(constraints) 86 | print("Constraints argument keys:") 87 | pprint(constraints_argkeys) 88 | print("Constraints parameter vector:") 89 | pprint(constraints_param) 90 | 91 | -------------------------------------------------------------------------------- /tests/test_primitives.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test geometric primitive and constraints 3 | """ 4 | 5 | import pytest 6 | 7 | import typing as tp 8 | 9 | import itertools 10 | from 
pprint import pprint 11 | 12 | import numpy as np 13 | 14 | from mpllayout import primitives as pr 15 | 16 | 17 | class TestPrimitives: 18 | 19 | def test_Quadrilateral(self): 20 | quad = pr.Quadrilateral() 21 | 22 | def test_Line(self): 23 | line = pr.Line() 24 | 25 | def test_Point(self): 26 | point = pr.Point() 27 | 28 | def test_Polygon(self): 29 | poly = pr.Polygon() 30 | 31 | def test_Primitive_jax_pytree(self): 32 | # breakpoint() 33 | from jax import tree_util 34 | 35 | point = pr.Point() 36 | line = pr.Line() 37 | quad = pr.Quadrilateral() 38 | 39 | for prim in (point, line, quad): 40 | print(f"\nTesting primitive type {type(prim).__name__}") 41 | leaves = tree_util.tree_leaves(prim) 42 | print("Leaves:", leaves) 43 | 44 | value_flat, value_tree = tree_util.tree_flatten(prim) 45 | reconstructed_prim = tree_util.tree_unflatten(value_tree, value_flat) 46 | print("tree_util.tree_flatten:", value_flat, value_tree) 47 | print("tree_util.tree_unflatten:", reconstructed_prim) 48 | 49 | leaves = tree_util.tree_leaves([0, 1, 2, 3, [4, 5, [6, 7, [8]]]]) 50 | print(leaves) 51 | print([type(leaf) for leaf in leaves]) 52 | -------------------------------------------------------------------------------- /tests/test_solve.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test `solver` 3 | """ 4 | 5 | import pytest 6 | 7 | import time 8 | from pprint import pprint 9 | 10 | import numpy as np 11 | 12 | from mpllayout import primitives as pr 13 | from mpllayout import constraints as co 14 | from mpllayout import layout as lay 15 | from mpllayout import containers as cn 16 | from mpllayout import solver 17 | 18 | 19 | class TestPrimitiveTree: 20 | 21 | @pytest.fixture() 22 | def layout(self): 23 | layout = lay.Layout() 24 | 25 | verts = np.array([[0.1, 0.2], [1.0, 2.0], [2.0, 2.0], [3.0, 3.0]]) 26 | 27 | layout.add_prim( 28 | pr.Quadrilateral(children=[pr.Point(vert) for vert in verts]), 29 | "MyFavouriteBox", 30 | ) 31 | 
layout.add_constraint(co.Box(), ("MyFavouriteBox",), ()) 32 | layout.add_constraint( 33 | co.Fix(), ("MyFavouriteBox/Line0/Point0",), (np.array([0, 0]),) 34 | ) 35 | 36 | layout.add_constraint(co.Length(), ("MyFavouriteBox/Line0",), (5.0,)) 37 | layout.add_constraint(co.Length(), ("MyFavouriteBox/Line1",), (5.1,)) 38 | return layout 39 | 40 | @pytest.fixture(params=[(5, 6)]) 41 | def axes_shape(self, request): 42 | return request.param 43 | 44 | @pytest.fixture() 45 | def layout_grid(self, axes_shape): 46 | layout = lay.Layout() 47 | ## Create an origin point 48 | layout.add_prim(pr.Point(), "Origin") 49 | layout.add_constraint(co.Fix(), ("Origin",), (np.array([0, 0]),)) 50 | 51 | ## Create the figure box 52 | layout.add_prim(pr.Quadrilateral(), "Figure",) 53 | layout.add_constraint(co.Box(), ("Figure",), ()) 54 | 55 | ## Constrain the figure size and position 56 | fig_width, fig_height = 6, 3 57 | layout.add_constraint( 58 | co.Length(), ("Figure/Line0",), (fig_width,) 59 | ) 60 | layout.add_constraint( 61 | co.Coincident(), ("Figure/Line0/Point0", "Origin"), () 62 | ) 63 | 64 | ## Create the axes boxes 65 | # axes_shape = (3, 4) 66 | num_row, num_col = axes_shape 67 | num_axes = int(np.prod(axes_shape)) 68 | for n in range(num_axes): 69 | layout.add_prim(pr.Quadrilateral(), f"Axes{n}") 70 | layout.add_constraint(co.Box(), (f"Axes{n}",), ()) 71 | 72 | ## Constrain the axes in a grid 73 | num_row, num_col = axes_shape 74 | grid_param = ( 75 | (num_col - 1) * [1], 76 | (num_row - 1) * [1], 77 | (num_col - 1) * [1 / 16], 78 | (num_row - 1) * [1 / 16], 79 | ) 80 | layout.add_constraint( 81 | co.Grid(axes_shape), 82 | tuple(f"Axes{n}" for n in range(num_axes)), 83 | grid_param 84 | ) 85 | 86 | # Constrain the first axis aspect ratio 87 | layout.add_constraint( 88 | co.RelativeLength(), ("Axes0/Line0", "Axes0/Line1"), (2,) 89 | ) 90 | 91 | # Constrain top/bottom margins 92 | margin_top = 1.1 93 | margin_bottom = 0.5 94 | layout.add_constraint( 95 | 
co.DirectedDistance(), 96 | ("Axes0/Line1/Point1", "Figure/Line1/Point1"), 97 | (np.array([0, 1]), margin_top) 98 | ) 99 | layout.add_constraint( 100 | co.DirectedDistance(), 101 | (f"Axes{num_axes-1}/Line1/Point0", "Figure/Line1/Point0"), 102 | (np.array([0, -1]), margin_bottom) 103 | ) 104 | 105 | # Constrain left/right margins 106 | margin_left = 0.2 107 | margin_right = 0.3 108 | layout.add_constraint( 109 | co.DirectedDistance(), 110 | ("Axes0/Line0/Point0", "Figure/Line0/Point0"), 111 | (np.array([-1, 0]), margin_left) 112 | ) 113 | layout.add_constraint( 114 | co.DirectedDistance(), 115 | (f"Axes{num_col-1}/Line1/Point1", "Figure/Line1/Point1"), 116 | (np.array([1, 0]), margin_right) 117 | ) 118 | return layout 119 | 120 | def test_assem_constraint_residual(self, layout_grid: lay.Layout): 121 | layout = layout_grid 122 | 123 | root_prim = layout.root_prim 124 | flat_constraints = layout.flat_constraints() 125 | 126 | # Plain call 127 | t0 = time.time() 128 | for i in range(50): 129 | solver.assem_constraint_residual( 130 | root_prim, *flat_constraints 131 | ) 132 | t1 = time.time() 133 | print(f"Duration {t1-t0:.2e} s") 134 | 135 | # `jax.jit` individual constraint functions 136 | import jax 137 | 138 | constraints = flat_constraints[0] 139 | constraints_jit = [jax.jit(constraint) for constraint in constraints] 140 | flat_constraints_jit = (constraints_jit,) + flat_constraints[1:] 141 | solver.assem_constraint_residual( 142 | root_prim, *flat_constraints_jit 143 | ) 144 | 145 | t0 = time.time() 146 | for i in range(50): 147 | solver.assem_constraint_residual( 148 | root_prim, *flat_constraints_jit 149 | ) 150 | t1 = time.time() 151 | print(f"Duration {t1-t0:.2e} s") 152 | 153 | # `jax.jit` the overall function 154 | 155 | @jax.jit 156 | def assem_constraint_residual(root_prim): 157 | return solver.assem_constraint_residual( 158 | root_prim, *flat_constraints 159 | ) 160 | 161 | assem_constraint_residual(root_prim) 162 | t0 = time.time() 163 | for i in 
range(50): 164 | assem_constraint_residual(root_prim) 165 | t1 = time.time() 166 | print(f"Duration {t1-t0:.2e} s") 167 | 168 | @pytest.fixture( 169 | params=('newton', 'minimize') 170 | ) 171 | def method(self, request): 172 | return request.param 173 | 174 | def test_solve(self, layout: lay.Layout, method: str): 175 | t0 = time.time() 176 | prim_tree_n, solve_info = solver.solve( 177 | layout, method=method, max_iter=100 178 | ) 179 | t1 = time.time() 180 | print(f"Solve took {t1-t0:.2e} s") 181 | 182 | prim_keys_to_value = { 183 | key: prim.value for key, prim in cn.iter_flat("", prim_tree_n) 184 | } 185 | pprint(prim_keys_to_value) 186 | pprint(solve_info) 187 | --------------------------------------------------------------------------------