├── plotannot ├── _version.py ├── __init__.py ├── functions.py └── code.py ├── readthedocs.yml ├── docs ├── source │ ├── examples │ │ ├── examples.nblink │ │ ├── customization.nblink │ │ └── index.rst │ ├── API │ │ └── index.rst │ ├── install.rst │ ├── index.rst │ └── conf.py ├── rtd-environment.yml ├── Makefile └── make.bat ├── examples ├── before_after.png └── simple_example.png ├── pyproject.toml ├── CHANGELOG.md ├── setup.cfg ├── LICENSE ├── .gitignore └── README.md /plotannot/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | conda: 2 | file: docs/rtd-environment.yml 3 | -------------------------------------------------------------------------------- /docs/source/examples/examples.nblink: -------------------------------------------------------------------------------- 1 | {"path": "../../../examples/examples.ipynb"} -------------------------------------------------------------------------------- /docs/source/examples/customization.nblink: -------------------------------------------------------------------------------- 1 | {"path": "../../../examples/customization.ipynb"} -------------------------------------------------------------------------------- /examples/before_after.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msbentsen/plotannot/HEAD/examples/before_after.png -------------------------------------------------------------------------------- /examples/simple_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msbentsen/plotannot/HEAD/examples/simple_example.png -------------------------------------------------------------------------------- 
/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /docs/source/examples/index.rst: -------------------------------------------------------------------------------- 1 | Example notebooks 2 | ------------------- 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | examples 8 | customization -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | 2 | # Version changes 3 | 4 | _plotannot_ follows semantic versioning guidelines (reference: [semver.org](https://semver.org/)). 5 | 6 | ## 0.1 7 | - Initial version 8 | -------------------------------------------------------------------------------- /docs/source/API/index.rst: -------------------------------------------------------------------------------- 1 | Module API 2 | -------------------------------- 3 | 4 | .. automodule:: plotannot.functions 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/rtd-environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | dependencies: 4 | - python 5 | - sphinx>=1.4 6 | - pandoc 7 | - nbconvert 8 | - ipykernel 9 | - pip: 10 | - nbsphinx 11 | - nbsphinx-link -------------------------------------------------------------------------------- /docs/source/install.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | --------------- 3 | 4 | To install from PyPI, simply run: 5 | 6 | .. code-block:: 7 | 8 | pip install plotannot 9 | 10 | To install from the GitHub repository, run: 11 | 12 | .. 
code-block:: 13 | 14 | git clone https://github.com/msbentsen/plotannot.git 15 | cd plotannot 16 | pip install . 17 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to plotannot's documentation! 2 | ===================================== 3 | 4 | plotannot is a Python package for highlighting and adjusting axis ticklabels and for removing overlapping labels. 5 | The project code lives on GitHub (`msbentsen/plotannot <https://github.com/msbentsen/plotannot>`_). 6 | 7 | Please submit any issues to the `GitHub issues page <https://github.com/msbentsen/plotannot/issues>`_. 8 | 9 | Content 10 | -------------- 11 | 12 | .. toctree:: 13 | :maxdepth: 2 14 | 15 | install 16 | API/index 17 | examples/index 18 | -------------------------------------------------------------------------------- /plotannot/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | 3 | from ._version import __version__ 4 | 5 | #Set functions to be available directly, i.e. "from plotannot import annotate_ticks" 6 | module = import_module("plotannot.functions") 7 | function_names = [func for func in dir(module) if callable(getattr(module, func)) and not func.startswith("__")] 8 | global_functions = ["plotannot.functions." + s for s in function_names] 9 | 10 | for f in global_functions: 11 | 12 | module_name = ".".join(f.split(".")[:-1]) 13 | attribute_name = f.split(".")[-1] 14 | 15 | module = import_module(module_name) 16 | attribute = getattr(module, attribute_name) 17 | 18 | globals()[attribute_name] = attribute -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two.
6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = plotannot 3 | version = attr: plotannot._version.__version__ 4 | long_description = file: README.md 5 | long_description_content_type = text/markdown 6 | url = https://github.com/msbentsen/plotannot 7 | author = Mette Bentsen 8 | author_email = mette.bentsen@mpi-bn.mpg.de 9 | license = MIT 10 | license_file = LICENSE 11 | platforms = Linux, Mac OS X, Windows 12 | classifiers = 13 | Intended Audience :: Developers 14 | Intended Audience :: Science/Research 15 | License :: OSI Approved :: MIT License 16 | Topic :: Scientific/Engineering :: Visualization 17 | Programming Language :: Python :: 3 18 | Framework :: Matplotlib 19 | 20 | [options] 21 | packages = find: 22 | install_requires = 23 | matplotlib 24 | numpy 25 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Mette Bentsen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # plotannot 2 | [![PyPI Version](https://img.shields.io/pypi/v/plotannot.svg?style=plastic)](https://pypi.org/project/plotannot/) 3 | 4 | # Introduction 5 | _plotannot_ is a Python package to automatically highlight and adjust overlapping ticklabels in matplotlib/seaborn plots. It is written with great inspiration and appreciation for the _statannot_ package ([webermarcolivier/statannot](https://github.com/webermarcolivier/statannot) - now maintained at [trevismd/statannotations](https://github.com/trevismd/statannotations)), as well as the _adjustText_ package ([Phlya/adjustText](https://github.com/Phlya/adjustText)). 6 | 7 | I originally created this package for myself, as I wanted to create ComplexHeatmap (R package) style annotations for Python plots - but maybe it is of use to you too? 8 | 9 | 10 | 11 | ## Features 12 | 13 | - Add annotation lines for certain row/column labels 14 | - Shift labels so they do not overlap 15 | - Add additional highlights such as color, fontsize, etc.
to certain row/column labels 16 | 17 | 18 | ## Getting started 19 | 20 | Install from PyPI: 21 | 22 | ```pip install plotannot``` 23 | 24 | Or directly from GitHub: 25 | 26 | ``` pip install git+https://github.com/msbentsen/plotannot ``` 27 | 28 | Requirements for package: 29 | - Python >= 3.6 30 | - matplotlib 31 | - numpy 32 | 33 | 34 | ## Simple example 35 | 36 | ``` 37 | import numpy as np 38 | import pandas as pd 39 | import seaborn as sns 40 | import plotannot 41 | 42 | #Plot heatmap 43 | table = pd.DataFrame(np.random.random((100,50))) 44 | ax = sns.heatmap(table, xticklabels=True, yticklabels=False) 45 | 46 | #Rotate all labels 47 | plotannot.format_ticklabels(ax, axis="xaxis", rotation=45) 48 | 49 | #Annotate labels 50 | to_label = range(20,35) 51 | plotannot.annotate_ticks(ax, axis="xaxis", labels=to_label) 52 | 53 | #Color individual labels 54 | plotannot.format_ticklabels(ax, axis="xaxis", labels=[25], color="red") 55 | ``` 56 | 57 | 58 | Additional examples are found in the [examples notebook](examples/examples.ipynb). 59 | 60 | ## Documentation and help 61 | 62 | Documentation of the main functions is available at [plotannot.readthedocs.io](https://plotannot.readthedocs.io/en/latest/). Examples of how to use these functions are in the examples notebook: [examples/examples.ipynb](examples/examples.ipynb). 63 | 64 | Issues and PRs are very welcome - please use the [repository issues](https://github.com/msbentsen/plotannot/issues) to raise an issue/contribute. 65 | 66 | 67 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options.
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../..')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'plotannot' 21 | copyright = '2022, Mette Bentsen' 22 | author = 'Mette Bentsen' 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = ['sphinx.ext.viewcode', 'sphinx.ext.autodoc', 'sphinx.ext.napoleon', 31 | 'nbsphinx', 'nbsphinx_link'] 32 | 33 | napoleon_numpy_docstring = True 34 | 35 | # Add any paths that contain templates here, relative to this directory. 36 | templates_path = ['_templates'] 37 | 38 | # List of patterns, relative to source directory, that match files and 39 | # directories to ignore when looking for source files. 40 | # This pattern also affects html_static_path and html_extra_path. 
41 | exclude_patterns = [] 42 | 43 | autodoc_mock_imports = ['numpy', 'matplotlib'] 44 | 45 | 46 | # -- Create nblink files ------------------------------------------------- 47 | 48 | import glob 49 | import json 50 | 51 | #Remove all previous .nblink files 52 | links = glob.glob("examples/*.nblink") 53 | for l in links: 54 | os.remove(l) 55 | 56 | #Create nblinks for current notebooks 57 | notebooks = glob.glob("../../examples/*.ipynb") 58 | for f in notebooks: 59 | f_name = os.path.basename(f).replace(".ipynb", "") 60 | 61 | d = {"path": "../" + f} 62 | with open("examples/" + f_name + ".nblink", 'w') as fp: 63 | json.dump(d, fp) 64 | 65 | nbsphinx_execute = 'never' 66 | 67 | # -- Options for HTML output ------------------------------------------------- 68 | 69 | # The theme to use for HTML and HTML Help pages. See the documentation for 70 | # a list of builtin themes. 71 | # 72 | html_theme = 'sphinx_rtd_theme' 73 | 74 | # Add any paths that contain custom static files (such as style sheets) here, 75 | # relative to this directory. They are copied after the builtin static files, 76 | # so a file named "default.css" will overwrite the builtin "default.css". 77 | html_static_path = ['_static'] -------------------------------------------------------------------------------- /plotannot/functions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | #Functions for annotating and formatting ticklabels (using the class PlotInfo) 4 | #@author: Mette Bentsen 5 | #@contact: mette.bentsen (at) mpi-bn.mpg.de 6 | #@license: MIT 7 | 8 | 9 | import plotannot.code 10 | 11 | def annotate_ticks(ax, axis, labels, 12 | expand_axis=0, 13 | rel_label_size=1.1, 14 | perp_shift=5, 15 | rel_tick_size=0.25, 16 | resolution=1000, 17 | speed=0.1, 18 | verbosity=1 19 | ): 20 | """ 21 | Annotate ticks with a subset of labels, and shift to overlapping labels. 
22 | 23 | Parameters 24 | -------------- 25 | ax : matplotlib.axes.Axes 26 | Axes object holding the plot and labels to annotate. 27 | axis : str 28 | Name of axis to annotate. Must be one of ["xaxis", "yaxis", "left", "right", "bottom", "top"]. 29 | labels : list of str 30 | A list of labels to annotate. Must be a list of strings (or values convertible to strings) corresponding to the labels to show in plot. 31 | expand_axis : float or tuple, optional 32 | Expand the annotation axis by this amount of the total axis width. Can be either float or tuple of floats. 33 | Corresponds to the relative size of axes to expand with, e.g. 0.1 extends with 5% of the axis size in both directions (total 10%). 34 | A tuple of (0.1,0.2) extends the axis with 10% in the beginning (left or bottom) and 20% (right or top). Default: 0. 35 | rel_label_size : float, optional 36 | Relative size of labels to use for measuring overlaps. Default: 1.1. 37 | perp_shift : float, optional 38 | Perpendicular shift of labels. Represents the relative length of ticks of the axis. Default: 5 (5 times the length of ticks). 39 | rel_tick_size : float, optional 40 | Relative size of the horizontal part of the annotation lines as a fraction of perp_shift. Default: 0.25. 41 | resolution : int, optional 42 | Resolution for finding overlapping labels. Default: 1000. 43 | speed : float, optional 44 | The speed with which the labels are moving when removing overlaps. A float value between 0-1. Default: 0.1. 45 | verbosity : int, optional 46 | The level of logging from the function. An integer between 0 and 3, corresponding to: 0: only errors, 1: minimal, 2: debug, 3: spam debug. Default: 1. 
47 | """ 48 | 49 | p = plotannot.code.PlotInfo(ax, verbosity=verbosity) 50 | 51 | p.check_axis(axis) 52 | p.subset_ticklabels(axis, labels) 53 | p.extend_axis(axis, expand_axis=expand_axis) 54 | p.shift_integer_labels(axis, resolution=resolution, rel_label_size=rel_label_size, speed=speed) 55 | p.apply_shift(axis, perp_shift=perp_shift) 56 | p.plot_annotation_lines(axis, rel_tick_size=rel_tick_size) 57 | 58 | 59 | def format_ticklabels(ax, axis, labels=None, format_ticks=False, verbosity=1, **kwargs): 60 | """ 61 | Format ticklabels of a given axis using attributes such as color, fontsize, fontweight, etc. 62 | 63 | Parameters 64 | -------------- 65 | ax : matplotlib.axes.Axes 66 | Axes object holding the plot and labels to annotate. 67 | axis : str 68 | Name of axis to annotate. Must be one of ["xaxis", "yaxis", "left", "right", "bottom", "top"]. 69 | labels : list of str, optional 70 | A list of labels to annotate. Must be a list of strings corresponding to the labels to show in plot. If None, all labels are used. Default: None. 71 | format_ticks : bool, optional 72 | If True, also format the ticklines of the axis. Default: False. 73 | verbosity : int, optional 74 | The level of logging from the function. An integer between 0 and 3, corresponding to: 0: only errors, 1: minimal, 2: debug, 3: spam debug. Default: 1. 75 | kwargs : args, optional 76 | Additional keyword arguments containing the attributes to set for labels. Each attribute is used as a function "set\_" + attribute for the label, 77 | e.g. "color='red'" will set the color of the label to red using the label-function 'set_color'. 
78 | """ 79 | 80 | p = plotannot.code.PlotInfo(ax, verbosity=verbosity) 81 | axis = p.format_axis(axis) 82 | 83 | #Check if kwargs were given 84 | if len(kwargs) == 0: 85 | raise ValueError("No attributes given to format labels.") 86 | 87 | #Apply to axis (can be more than one if axis is "xaxis" or "yaxis") 88 | for a in axis: 89 | 90 | #Get labels for applying functions 91 | label_objects = [d["object"] for d in p.label_info[a]] 92 | tick_objects = [d["object"] for d in p.tick_info[a]] 93 | 94 | #Subset to labels if chosen 95 | if labels is not None: 96 | labels = [str(l) for l in labels] 97 | 98 | p.check_labels(a, labels) 99 | indices = [i for i, o in enumerate(label_objects) if o._text in labels] 100 | label_objects = [label_objects[i] for i in indices] 101 | tick_objects = [tick_objects[i] for i in indices] 102 | 103 | #Apply attributes to ticklabels 104 | for attribute, value in kwargs.items(): 105 | 106 | func_name = "set_" + attribute 107 | 108 | #Format labels 109 | for label in label_objects: 110 | try: 111 | f = getattr(label, func_name) 112 | except: 113 | raise ValueError("{func_name}") 114 | f(value) 115 | 116 | #Format ticks 117 | if format_ticks == True: 118 | for tick in tick_objects: 119 | if hasattr(tick, func_name): 120 | f = getattr(tick, func_name) 121 | f(value) 122 | -------------------------------------------------------------------------------- /plotannot/code.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | #PlotInfo class for plotannot 4 | #@author: Mette Bentsen 5 | #@contact: mette.bentsen (at) mpi-bn.mpg.de 6 | #@license: MIT 7 | 8 | import sys 9 | import numpy as np 10 | import math 11 | import matplotlib 12 | import matplotlib.transforms 13 | import matplotlib.pyplot as plt 14 | 15 | #Functions for logging 16 | import logging 17 | from logging import ERROR, INFO, DEBUG 18 | 19 | #Create additional level below DEBUG 20 | SPAM = DEBUG - 1 21 | 
logging.addLevelName(SPAM, 'SPAM') 22 | 23 | class PlotInfo(): 24 | """ A class collecting information on plot elements to annotate """ 25 | 26 | def __init__(self, o, verbosity=1): 27 | 28 | self.set_logger(verbosity=verbosity) 29 | self.get_figure() 30 | self.get_axis(o) 31 | self.get_transform() 32 | self.get_axis_info() 33 | self.get_tick_info() 34 | 35 | def set_logger(self, verbosity=1): 36 | """ Set logger for class. 37 | 38 | Parameters 39 | ------------- 40 | verbosity : int 41 | Verbosity level. An integer between 0 and 3, corresponding to: 0: only errors, 1: minimal, 2: debug, 3: spam debug. Default: 1. 42 | """ 43 | 44 | verbosity_to_level = {0: ERROR, #silent 45 | 1: INFO, 46 | 2: DEBUG, 47 | 3: SPAM #extreme spam debugging 48 | } 49 | level = verbosity_to_level[verbosity] 50 | 51 | self.verbosity = verbosity 52 | self.logger = logging.getLogger(__name__) 53 | self.logger.setLevel(level) 54 | 55 | #Remove all existing handlers 56 | for handler in self.logger.handlers: 57 | self.logger.removeHandler(handler) 58 | self.logger.addHandler(logging.StreamHandler(sys.stdout)) 59 | 60 | #Create custom spam level 61 | def spam(self, message, *args, **kwargs): 62 | if self.isEnabledFor(SPAM): 63 | self._log(SPAM, message, args, **kwargs) 64 | self.logger.spam = lambda message, *args, **kwargs: spam(self.logger, message, *args, **kwargs) 65 | 66 | #Format 67 | formatter = logging.Formatter('[%(levelname)s] %(message)s') 68 | handler = self.logger.handlers[0] 69 | handler.setFormatter(formatter) 70 | 71 | def get_figure(self): 72 | """ Get the current figure """ 73 | 74 | fig = plt.gcf() 75 | fig.canvas.draw() 76 | 77 | self.fig = fig 78 | self.renderer = fig.canvas.get_renderer() 79 | 80 | 81 | def get_axis(self, o): 82 | """ Get xaxis and yaxis objects. 83 | 84 | Parameters 85 | ------------ 86 | o : matplotlib.axes.Axes, sns.ClusterGrid 87 | An object to find xaxis/yaxis objects from. 
88 | """ 89 | 90 | #For sns.heatmap; get axis directly from object 91 | if hasattr(o, "xaxis") and hasattr(o, "yaxis"): 92 | self.ax = o 93 | self.xaxis = o.xaxis 94 | self.yaxis = o.yaxis 95 | 96 | #For plt plots; get from _axes 97 | elif hasattr(o, "_axes"): 98 | self.ax = o._axes 99 | self.xaxis = self.ax.xaxis 100 | self.yaxis = self.ax.yaxis 101 | 102 | #for sns.clustermap; get from heatmap axes 103 | elif hasattr(o, "ax_heatmap"): 104 | self.ax = o.ax_heatmap 105 | self.xaxis = self.ax.xaxis 106 | self.yaxis = self.ax.yaxis 107 | 108 | #Todo: search for xaxis/yaxis in dict 109 | else: 110 | self.logger.error("Could not find xaxis/yaxis in object") 111 | sys.exit() 112 | 113 | def get_transform(self): 114 | """ 115 | Get transformation objects for figure and data axis. 116 | """ 117 | 118 | #Transformation between display and inches 119 | self.trans_fig = self.fig.dpi_scale_trans 120 | self.trans_fig_inv = self.trans_fig.inverted() 121 | 122 | self.trans_data = self.ax.transData 123 | self.trans_data_inv = self.trans_data.inverted() 124 | 125 | @staticmethod 126 | def format_axis(axis): 127 | """ Convert "xaxis" and "yaxis" names into "bottom", "top", "left", "right" """ 128 | 129 | #Establish axis 130 | if axis == "xaxis": 131 | axis = ["top", "bottom"] 132 | elif axis == "yaxis": 133 | axis = ["left", "right"] 134 | 135 | if isinstance(axis, str): 136 | axis = [axis] 137 | 138 | return axis 139 | 140 | @staticmethod 141 | def check_axis(axis): 142 | """ Check that axis is valid """ 143 | 144 | valid = ["xaxis", "yaxis", "top", "bottom", "left", "right"] 145 | if axis not in valid: 146 | raise ValueError(f"Given axis '{axis}' is not valid. 
Possible axis are: {valid}") 147 | 148 | @staticmethod 149 | def check_value(value, vmin=-math.inf, vmax=math.inf, integer=False, name=None): 150 | """ Check that value is valid based on vmin/vmax""" 151 | 152 | if vmin > vmax: 153 | raise ValueError("vmin must be smaller than vmax") 154 | 155 | error_msg = None 156 | if integer == True: 157 | if not isinstance(value, int): 158 | error_msg = "The {0} given ({1}) is not an integer, but integer is set to True.".format(name, value) 159 | else: 160 | #check if value is any value 161 | try: 162 | _ = int(value) 163 | except: 164 | error_msg = "The {0} given ({1}) is not a valid number".format(name, value) 165 | 166 | #If value is a number, check if it is within bounds 167 | if error_msg is None: 168 | if not ((value >= vmin) & (value <= vmax)): 169 | error_msg = "The {0} given ({1}) is not within the bounds of [{2};{3}]".format(name, value, vmin, vmax) 170 | 171 | #Finally, raise error if necessary: 172 | if error_msg is not None: 173 | raise ValueError(error_msg) 174 | 175 | 176 | def check_labels(self, axis, labels): 177 | """ Check whether given labels are valid for axis """ 178 | 179 | self.check_axis(axis) 180 | axis = self.format_axis(axis) 181 | 182 | for a in axis: 183 | all_label_texts = [d["object"]._text for d in self.label_info[a]] #all label texts found 184 | 185 | not_found = set(labels) - set(all_label_texts) 186 | 187 | #If there are no visible labels at all; pass 188 | if len(all_label_texts) == 0: 189 | pass 190 | 191 | #If no matches were found 192 | elif len(not_found) == len(labels): 193 | self.logger.warning(f"No match could be found between given 'labels' and the {a}-axis ticklabels.") 194 | self.logger.warning(f"Axis ticklabels are: {all_label_texts[:5]} (...). 
Given labels are: {labels[:5]}.") 195 | self.logger.warning("Please check input labels and axis.") 196 | raise ValueError("No match between given labels and axis ticklabels") 197 | 198 | #If only some matches were found 199 | elif len(not_found) > 0: 200 | self.logger.warning(f"{len(not_found)} string(s) from 'labels' were not found in axis ticklabels. These labels were: {list(not_found)}") 201 | 202 | #---------------- Get information of the axes and labels in the plot ----------------# 203 | 204 | def get_axis_info(self): 205 | """ Get the extent of all axis in inch coordinates """ 206 | 207 | self.axis_info = {} 208 | 209 | #Get extent of axis 210 | ax_bbox = self.ax.get_window_extent(renderer=self.renderer) 211 | ax_bbox = ax_bbox.transformed(self.trans_fig_inv) #transform from display to inches 212 | 213 | for axis in ["xaxis", "yaxis"]: 214 | ax_0, ax_1 = (ax_bbox.x0, ax_bbox.x1) if axis == "xaxis" else (ax_bbox.y0, ax_bbox.y1) # in inches 215 | self.logger.debug(f"Read info for axis: {axis}. 
Extent is: {ax_0}, {ax_1}") 216 | 217 | self.axis_info[axis] = {"bbox_inch": ax_bbox, "from_inch": ax_0, "to_inch": ax_1, "extent_inch": ax_1-ax_0} #positions in inches 218 | 219 | #Propagate to individual axis 220 | self.axis_info["bottom"] = self.axis_info["xaxis"] 221 | self.axis_info["top"] = self.axis_info["xaxis"] 222 | self.axis_info["left"] = self.axis_info["yaxis"] 223 | self.axis_info["right"] = self.axis_info["yaxis"] 224 | 225 | def get_tick_info(self): 226 | """ 227 | Save information on tick and labels 228 | """ 229 | 230 | axis_names = ["bottom", "top", "left", "right"] 231 | 232 | self.label_info = {a:[] for a in axis_names} 233 | self.tick_info = {a:[] for a in axis_names} 234 | 235 | #Collect labels and ticks from axis 236 | for axis in ["xaxis", "yaxis"]: 237 | 238 | ticks = getattr(self, axis).get_major_ticks() #list of tick objects 239 | for tick in ticks: 240 | if axis == "xaxis": 241 | self.label_info["bottom"].append({"object": tick.label1}) 242 | self.tick_info["bottom"].append({"object": tick.tick1line}) 243 | 244 | self.label_info["top"].append({"object": tick.label2}) 245 | self.tick_info["top"].append({"object": tick.tick2line}) 246 | 247 | elif axis == "yaxis": 248 | self.label_info["left"].append({"object": tick.label1}) 249 | self.tick_info["left"].append({"object": tick.tick1line}) 250 | 251 | self.label_info["right"].append({"object": tick.label2}) 252 | self.tick_info["right"].append({"object": tick.tick2line}) 253 | 254 | self.remove_invisible_labels() 255 | 256 | #Add additional information about sizes and positions 257 | for axis in axis_names: 258 | for l in [self.label_info[axis], self.tick_info[axis]]: #for each list 259 | for d in l: #for each dict in list 260 | 261 | bbox = d["object"].get_window_extent(self.renderer) 262 | bbox_inch = bbox.transformed(self.trans_fig_inv) #to inches 263 | bbox_data = bbox.transformed(self.trans_data_inv) #to data 264 | 265 | para_0, para_1 = (bbox_inch.x0, bbox_inch.x1) if axis in ["top", 
"bottom"] else (bbox_inch.y0, bbox_inch.y1) # in inches 266 | perp_0, perp_1 = (bbox_inch.y0, bbox_inch.y1) if axis in ["top", "bottom"] else (bbox_inch.x0, bbox_inch.x1) #perpendicular size in inches 267 | 268 | para_data_0, para_data_1 = (bbox_data.x0, bbox_data.x1) if axis in ["top", "bottom"] else (bbox_data.y0, bbox_data.y1) #parallel size in data coordinates 269 | perp_data_0, perp_data_1 = (bbox_data.y0, bbox_data.y1) if axis in ["top", "bottom"] else (bbox_data.x0, bbox_data.x1) #in data coordinates 270 | 271 | additional_info = {"bbox": bbox, "from_inch": para_0, "to_inch": para_1, 272 | "pos_inch_para": np.mean([para_0, para_1]), "extent_inch_para": para_1-para_0, 273 | "pos_data_para": np.mean([para_data_0, para_data_1]), "extent_data_para": para_data_1-para_data_0, 274 | 275 | "pos_inch_perp": np.mean([perp_0, perp_1]), "extent_inch_perp": perp_1-perp_0, 276 | "pos_data_perp": np.mean([perp_data_0, perp_data_1]), "extent_data_perp": perp_data_1-perp_data_0 277 | } 278 | 279 | d.update(additional_info) #update dict in place 280 | 281 | #Sort from lowest to highest inches positions on axis 282 | ind_to_sort = np.argsort([d["pos_inch_para"] for d in self.label_info[axis]]) 283 | self.label_info[axis] = [self.label_info[axis][i] for i in ind_to_sort] 284 | self.tick_info[axis] = [self.tick_info[axis][i] for i in ind_to_sort] 285 | 286 | 287 | def remove_invisible_labels(self): 288 | """ 289 | Remove invisible labels (and corresponding ticks) from internal info dicts 290 | """ 291 | 292 | #Remove any invisible labels from the lists 293 | for axis in self.label_info: 294 | visible_indices = [i for i, d in enumerate(self.label_info[axis]) if d["object"]._visible == True] 295 | self.label_info[axis] = [self.label_info[axis][i] for i in visible_indices] 296 | self.tick_info[axis] = [self.tick_info[axis][i] for i in visible_indices] 297 | 298 | 299 | #------------------- Subset and format labels -----------------# 300 | 301 | def subset_ticklabels(self, axis, 
labels): 302 | """ Hide any ticklabels not in 'labels'. 303 | 304 | Parameters 305 | ------------ 306 | axis : str 307 | Name of axis. One of: "xaxis", "yaxis", "top", "bottom", "left", "right". 308 | labels : list 309 | List of labels to keep. 310 | """ 311 | 312 | labels = [str(label) for label in labels] #convert labels to strings 313 | 314 | self.check_axis(axis) 315 | self.check_labels(axis, labels) 316 | 317 | axis = self.format_axis(axis) 318 | 319 | #Subset for each axis separately 320 | found = [] 321 | for a in axis: 322 | for i, d in enumerate(self.label_info[a]): 323 | label_text = d["object"]._text 324 | if label_text not in labels: 325 | self.label_info[a][i]["object"].set_visible(False) 326 | self.tick_info[a][i]["object"].set_visible(False) 327 | else: 328 | found.append(label_text) 329 | 330 | self.remove_invisible_labels() 331 | 332 | 333 | #-----------------------------------------------------------------------------------# 334 | #----------------- Functionality for shifting and annotating labels ----------------# 335 | #-----------------------------------------------------------------------------------# 336 | 337 | def extend_axis(self, axis, expand_axis=0): 338 | """ 339 | Extend the size of axis to make room for labels. 340 | 341 | Parameters 342 | ------------- 343 | axis : str 344 | Name of axis. One of: "xaxis", "yaxis", "top", "bottom", "left", "right". 345 | expand_axis : float or tuple 346 | Expand the axis by this amount. Corresponds to the relative size of axes to expand with, e.g. 0.1 extends with 5% of the axis size in 347 | both directions. Default: 0. 
348 | """ 349 | 350 | self.check_value(expand_axis, name="expand_axis") 351 | 352 | if isinstance(expand_axis, (int, float)): 353 | expand_axis = (expand_axis/2, expand_axis/2) 354 | 355 | axis = self.format_axis(axis) 356 | 357 | #Adjust axis 358 | for a in axis: 359 | self.axis_info[a]["from_inch"] -= self.axis_info[a]["extent_inch"] * expand_axis[0] #lower 360 | self.axis_info[a]["to_inch"] += self.axis_info[a]["extent_inch"] * expand_axis[1] #higher 361 | self.axis_info[a]["extent_inch"] = self.axis_info[a]["to_inch"] - self.axis_info[a]["from_inch"] 362 | 363 | 364 | def get_integer_positions(self, resolution=1000): 365 | """ Integer positions of ticks and ticklabels """ 366 | 367 | #Save integer positions for axis 368 | self.integer_positions = {} 369 | for axis in self.label_info: 370 | self.logger.debug(f"Getting integer positions for labels on {axis} axis") 371 | 372 | #Initialize tick positions in integer space 373 | n = len(self.label_info[axis]) 374 | text_positions_int = [] 375 | tick_positions_int = [] 376 | 377 | #Fill in information from each label 378 | for i in range(n): 379 | 380 | #Size and position of label 381 | label_position = self.label_info[axis][i]["pos_inch_para"] 382 | 383 | #Extent of axes 384 | ax_start = self.axis_info[axis]["from_inch"] 385 | ax_extent = self.axis_info[axis]["extent_inch"] 386 | 387 | #Calculate integer values for labels 388 | label_position_int = int(((label_position - ax_start) / ax_extent) * resolution) 389 | 390 | #make sure that the positions are not out of bounds 391 | label_position_int = min(label_position_int, resolution) 392 | label_position_int = max(label_position_int, 0) 393 | 394 | #Fill lists 395 | tick_positions_int.append(label_position_int) 396 | text_positions_int.append(label_position_int) 397 | 398 | 399 | #Save information for axis 400 | self.integer_positions[axis] = {"text_pos_int_arr": text_positions_int, 401 | "tick_pos_int_arr": tick_positions_int} 402 | 403 | 404 | def 
get_extent_matrix(self, rel_label_size=1, resolution=1000): 405 | """ Set extent of labels based on widths of labels """ 406 | 407 | for axis in self.label_info: 408 | 409 | n = len(self.label_info[axis]) 410 | extent_matrix_int = np.zeros((n, resolution)) 411 | 412 | for i in range(n): 413 | 414 | label_position_int = self.integer_positions[axis]["text_pos_int_arr"][i] #current position of label 415 | label_extent = self.label_info[axis][i]["extent_inch_para"] * rel_label_size 416 | ax_extent = self.axis_info[axis]["extent_inch"] 417 | 418 | #Get extent of label 419 | extent_half_int = int((label_extent / ax_extent / 2) * resolution) 420 | extent_start = label_position_int - extent_half_int 421 | extent_end = label_position_int + extent_half_int 422 | 423 | #Make sure that the positions are not out of bounds 424 | label_position_int = min(label_position_int, resolution) 425 | label_position_int = max(label_position_int, 0) 426 | extent_start = max(0, extent_start) 427 | extent_end = min(extent_end, resolution) 428 | 429 | #Fill matrix 430 | extent_matrix_int[i, extent_start:extent_end+1] = 1 431 | 432 | self.integer_positions[axis]["extent_matrix_int"] = extent_matrix_int 433 | 434 | @staticmethod 435 | def roll_matrix(matrix, i, shift): 436 | """ 437 | Roll a matrix along ith rows by shift positions. 438 | 439 | Parameters 440 | ------------- 441 | matrix : numpy.ndarray 442 | 2D array 443 | i : int 444 | Index of axis to roll along. 445 | shift : int 446 | Number of positions to shift. 447 | """ 448 | 449 | arr = matrix[i] 450 | rolled = np.roll(arr, shift) 451 | 452 | if shift < 0: 453 | rolled[shift:] = 0 #last values are 0 454 | else: 455 | rolled[:shift] = 0 #first values are 0 456 | 457 | matrix[i,:] = rolled 458 | 459 | return matrix 460 | 461 | def move_elements(self, current_pos_arr, target_pos_arr, extent_matrix, speed=0.1): 462 | """ Move elements from their current position closer to their target position, given that they must not overlap. 
463 | 464 | Parameters 465 | ------------- 466 | current_pos_arr : array 467 | Current positions of elements. 468 | target_pos_arr : array 469 | Target positions of elements. 470 | extent_matrix : matrix 471 | Matrix containing the extent of each element. 472 | speed : integer 473 | The speed with which the elements move to their target position. Default: 0.1. 474 | """ 475 | 476 | #Initialize counts 477 | failed_count = 0 #count of labels which failed to move 478 | iteration_count = 0 479 | n = len(current_pos_arr) #number of elements 480 | resolution = extent_matrix.shape[1] 481 | 482 | speed = max(int(speed * resolution), 1) #make sure that speed is at least 1 483 | self.logger.debug(f"speed is {speed} (resolution: {resolution})") 484 | 485 | #Shift elements until all elements fail to move 486 | while failed_count < n: 487 | 488 | failed_count = 0 #intialize failed_count for this iteration 489 | iteration_count += 1 #increment iteration count 490 | 491 | #calculate differences 492 | diff = current_pos_arr - target_pos_arr 493 | diff_argsort = np.argsort(diff) 494 | 495 | #Shift starting with closest 496 | for i in diff_argsort: 497 | 498 | this_label_pos = current_pos_arr[i] 499 | this_target_pos = target_pos_arr[i] 500 | this_diff = diff[i] 501 | self.logger.spam(f"Label with index {i} (pos: {this_label_pos}) is closest to target (pos: {this_target_pos}) with a diff of {this_diff}") 502 | 503 | #What direction is the difference? 
504 | shift = 0 505 | if this_diff > 0: #Difference is positive; shift should be negative (left) 506 | 507 | if i == 0: #if i is 0, there are no labels to the left, and 508 | shift = -1 * this_diff #label can move directly to the left 509 | 510 | else: 511 | #Get next label is to the left 512 | left_label_pos = current_pos_arr[i-1] #position of the label to the left 513 | arr = extent_matrix[[i,i-1], :].sum(axis=0) 514 | possible_shift = sum(arr[left_label_pos:this_label_pos] == 0) #space between left and current label 515 | self.logger.spam(f"Possible shift without overlap is {possible_shift}") 516 | 517 | shift = min(possible_shift, np.abs(this_diff)) 518 | shift = min(shift, int(np.ceil(shift * speed))) #cap shift at speed 519 | shift = -1 * shift #shift to the left (-) 520 | 521 | else: #difference is negative or 0; shift should be positive (right) 522 | 523 | if (i+1) == n: #if i+1 is n; there are no labels to the right 524 | shift = this_diff #label can move directly to the right 525 | 526 | else: 527 | #Get next label is to the right 528 | right_label_pos = current_pos_arr[i+1] 529 | arr = extent_matrix[[i,i+1]].sum(axis=0) 530 | 531 | possible_shift = sum(arr[this_label_pos:right_label_pos] == 0) 532 | shift = min(possible_shift, np.abs(this_diff)) #shift to the right (+) 533 | shift = min(shift, int(np.ceil(shift * speed))) #cap shift at speed 534 | 535 | self.logger.spam(f"Trying to shift label by {shift}") 536 | 537 | #If shift is 0, label was not moved (=failed to move) 538 | if shift == 0: 539 | failed_count += 1 540 | 541 | else: 542 | 543 | #Try to perform roll of index 544 | overlap_before = sum(extent_matrix.sum(axis=0) > 1) #positions with more than two labels 545 | extent_rolled = PlotInfo.roll_matrix(extent_matrix, i, shift) 546 | overlap_after = sum(extent_rolled.sum(axis=0) > 1) 547 | 548 | if overlap_before - overlap_after == 0: #no changes in overlaps; success 549 | extent_matrix = extent_rolled #update 550 | current_pos_arr[i] += shift 551 
| 552 | else: 553 | failed_count += 1 554 | #print("Failed to move label. Current failed count is: {0}/{1}".format(failed_count, n)) 555 | 556 | self.logger.debug(f"Finished iteration {iteration_count} of moves: Failed count is {failed_count}/{n}") 557 | 558 | #Make sure positions are within bounds of resolution 559 | current_pos_arr = np.clip(current_pos_arr, 0, resolution-1) 560 | 561 | return current_pos_arr 562 | 563 | 564 | def shift_integer_labels(self, axis, resolution=1000, rel_label_size=1.1, speed=0.1): 565 | """ Shift labels to not be overlapping . 566 | 567 | Parameters 568 | ------------- 569 | axis : str or list 570 | Name of axis. 571 | resolution : int, optional 572 | Number of bins in axis. Default: 1000. 573 | rel_label_size : float, optional 574 | Relative size of labels. Default: 1.1. 575 | speed : float, optional 576 | The speed with which labels move. Default: 0.1. 577 | """ 578 | 579 | self.check_axis(axis) 580 | self.check_value(resolution, vmin=1, integer=True, name="resolution") 581 | self.check_value(rel_label_size, vmin=0, name="rel_label_size") 582 | self.check_value(speed, vmin=0, vmax=1, name="speed") 583 | 584 | self.get_integer_positions() #get integer arrays for labels 585 | 586 | axis = self.format_axis(axis) 587 | 588 | #Shift labels across axis 589 | for a in axis: 590 | 591 | if len(self.label_info[a]) == 0: 592 | continue #no labels to shift for axis 593 | 594 | self.logger.debug("Shifting integer labels on axis: {0}".format(a)) 595 | self.logger.spam("Initial text positions: {0} (...)".format(self.integer_positions[a]["text_pos_int_arr"][:10])) 596 | 597 | #Start by distributing labels across whole axis 598 | n = len(self.integer_positions[a]["text_pos_int_arr"]) 599 | new_text_positions = np.linspace(0, resolution, n).astype(int) 600 | self.integer_positions[a]["text_pos_int_arr"] = new_text_positions #update positions array 601 | self.logger.spam(f"Initial distribution of labels across axis. 
New positions are: {new_text_positions[:10]} (...)") 602 | self.get_extent_matrix(resolution=resolution, rel_label_size=rel_label_size) #update label extent after shifting 603 | 604 | 605 | #--------- Shift labels closer to ticks without overlapping -------# 606 | extent_matrix = self.integer_positions[a]["extent_matrix_int"] 607 | text_positions = self.integer_positions[a]["text_pos_int_arr"] 608 | tick_positions = self.integer_positions[a]["tick_pos_int_arr"] 609 | 610 | #Check initial overlaps 611 | overlap = extent_matrix.sum(axis=0) 612 | if max(overlap) > 1: #if any position has more than one label 613 | self.logger.warning("The labels cannot be fit into the range without overlap.") 614 | 615 | #How much space would be needed? 616 | space_needed = sum(extent_matrix.sum(axis=1)) 617 | resolution = extent_matrix.shape[1] 618 | 619 | needed_extend = space_needed / resolution - 1 620 | self.logger.warning(f"Set 'expand_axis' to at least {needed_extend:.2f} in order to fit labels into the range.") 621 | 622 | 623 | ##### shift until no longer possible 624 | new_text_positions = self.move_elements(current_pos_arr=text_positions, 625 | target_pos_arr=tick_positions, 626 | extent_matrix=extent_matrix, 627 | speed=speed) 628 | 629 | self.logger.spam(f"Done shifting labels on axis {a}. New positions are: {new_text_positions[:10]}") 630 | 631 | #Update positions 632 | self.integer_positions[a]["text_pos_int_arr"] = new_text_positions 633 | 634 | 635 | #-------------------------------------------------------------# 636 | #------------------ Apply changes to plot --------------------# 637 | #-------------------------------------------------------------# 638 | 639 | def apply_shift(self, axis, perp_shift=5): 640 | """ Apply the integer shifted labels to plot. 641 | 642 | Parameters 643 | ----------- 644 | axis : str or list 645 | Axis to apply shift to. 646 | perp_shift : int, optional 647 | The amount label shift perpendicular to axis. Default: 5. 
648 | """ 649 | 650 | #Check input 651 | self.check_value(perp_shift, vmin=0, name="perp_shift") 652 | axis = self.format_axis(axis) 653 | 654 | for a in axis: 655 | 656 | if len(self.label_info[a]) == 0: 657 | continue #no labels to shift for axis 658 | 659 | self.logger.debug(f"Applying shift for axis: {a}") 660 | 661 | #Convert integers back to inches 662 | resolution = self.integer_positions[a]["extent_matrix_int"].shape[1] #width of matrix 663 | inches_arr = np.linspace(self.axis_info[a]["from_inch"], self.axis_info[a]["to_inch"], resolution) 664 | 665 | text_positions_int = self.integer_positions[a]["text_pos_int_arr"] #np.clip(text_positions_int, 0, resolution-1) 666 | text_positions_inch = inches_arr[text_positions_int] 667 | 668 | #Move labels to new positions in inches space 669 | for i, d in enumerate(self.label_info[a]): #loop over list of dicts 670 | 671 | self.logger.spam(f"Moving tick {i} ({d['object']._text})") 672 | 673 | #Find the perpendicular shift in relation to ticks 674 | tick_len_inch = self.tick_info[a][i]["extent_inch_perp"] #tick_bbox_inch.height if axis in ["top", "bottom"] else tick_bbox_inch.width # in inches 675 | this_perp_shift = tick_len_inch * perp_shift 676 | 677 | #Decide whether to shift left/right/up/down 678 | if a in ["bottom", "left"]: 679 | this_perp_shift = -1 * this_perp_shift 680 | 681 | self.logger.spam(f"Tick is {tick_len_inch:0.3f} inches wide; perpendicular shift will be {this_perp_shift:0.3f}.") 682 | 683 | #Get position of shifted label 684 | old_para = self.label_info[a][i]["pos_inch_para"] 685 | new_para = text_positions_inch[i] 686 | d_para = new_para - old_para 687 | d_perp = this_perp_shift 688 | self.logger.spam(f"Parallel location change: {d_para:.3f} ({old_para:.3f} -> {new_para:.3f})") 689 | 690 | #dx/dy in inches 691 | if a in ["left", "right"]: 692 | dx, dy = d_perp, d_para 693 | else: 694 | dx, dy = d_para, d_perp 695 | 696 | #Apply transformation 697 | offset = 
matplotlib.transforms.ScaledTranslation(dx, dy, self.trans_fig) #from inches into display 698 | label = self.label_info[a][i]["object"] 699 | label.set_transform(label.get_transform() + offset) 700 | 701 | #Save shifted box 702 | d["bbox_shifted"] = label.get_window_extent() #now shifted 703 | 704 | def plot_annotation_lines(self, axis, rel_tick_size=0.25): 705 | """ Plot lines from the original ticks to the newly shifted labels on the plot. """ 706 | 707 | self.check_value(rel_tick_size, vmin=0, vmax=1, name="rel_tick_size") 708 | axis = self.format_axis(axis) 709 | 710 | for a in axis: 711 | for i in range(len(self.label_info[a])): 712 | 713 | self.logger.spam("Plotting annotation line for tick {}".format(i)) 714 | 715 | #Find out how much labels were shifted in data space 716 | original_label_bbox = self.label_info[a][i]["bbox"].transformed(self.trans_data_inv) #from display to data 717 | shifted_label_bbox = self.label_info[a][i]["bbox_shifted"].transformed(self.trans_data_inv) #from display to data 718 | 719 | #perpendicular shift of labels; either positive or negative 720 | if a in ["top", "bottom"]: 721 | perp_shift_data = shifted_label_bbox.y0 - original_label_bbox.y0 722 | else: 723 | perp_shift_data = shifted_label_bbox.x0 - original_label_bbox.x0 724 | 725 | self.logger.spam("The perpendicular shift of label is: {}".format(perp_shift_data)) 726 | 727 | #Start position of ticks on the perpendicular axis 728 | perp_shift_start = self.tick_info[a][i]["pos_data_perp"] 729 | self.logger.spam(f"The start of tick in data coordinates is: {perp_shift_start:3f}") 730 | 731 | #Shift locations of each line segment in data coordinates (parallel) 732 | t1 = perp_shift_start + perp_shift_data #perp_shift is already negative if needed 733 | t2 = t1 - perp_shift_data * rel_tick_size / 2 734 | t3 = perp_shift_start + perp_shift_data * rel_tick_size / 2 735 | t4 = perp_shift_start #location of axis 736 | 737 | #New positions in data coordinates 738 | old_para = 
self.tick_info[a][i]["pos_data_para"] 739 | new_para = (shifted_label_bbox.x0 + shifted_label_bbox.x1)/2 if a in ["top", "bottom"] else (shifted_label_bbox.y0 + shifted_label_bbox.y1)/2 740 | 741 | #Plot annotation line 742 | perp_coord = [t1, t2, t3, t4] #perpendicular to axis; from label to axis 743 | para_coord = [new_para, new_para, old_para, old_para] #parallel to axis 744 | tick_lw = self.tick_info[a][i]["object"].get_markeredgewidth() 745 | color = self.tick_info[a][i]["object"].get_color() #carry over existing color of tick 746 | 747 | #Get axes limits before plotting 748 | orig_xlim = self.ax.get_xlim() 749 | orig_ylim = self.ax.get_ylim() 750 | 751 | if a in ["top", "bottom"]: 752 | self.ax.plot(para_coord, perp_coord, clip_on=False, lw=tick_lw, color=color) #plot from inches coordinates 753 | else: 754 | self.ax.plot(perp_coord, para_coord, clip_on=False, lw=tick_lw, color=color) #plot from inches coordinates 755 | 756 | #Set axes limit in case they were changed by plotting 757 | self.ax.set_xlim(orig_xlim) 758 | self.ax.set_ylim(orig_ylim) 759 | 760 | #Hide original tick 761 | self.tick_info[a][i]["object"].set_visible(False) --------------------------------------------------------------------------------