├── .binder ├── postBuild └── environment.yml ├── examples ├── im1.jpg ├── im2.jpg ├── im3.jpg ├── basic_labeller.py ├── multiclass_labeller.py ├── create_example.py ├── callbacks.py └── lazy_loading.py ├── setup.py ├── docs ├── _static │ ├── multi_class.gif │ ├── single_class.gif │ └── custom.css ├── api │ └── mpl_image_labeller.rst ├── Makefile ├── make.bat ├── index.md ├── examples │ ├── lazy-loading.ipynb │ ├── multi-class.ipynb │ ├── single-class.ipynb │ └── callbacks.ipynb ├── contributing.md └── conf.py ├── MANIFEST.in ├── readthedocs.yml ├── mpl_image_labeller ├── __init__.py ├── _util.py ├── _widgets.py └── _labeller.py ├── .github ├── ISSUE_TEMPLATE.md └── workflows │ ├── publish.yml │ └── test.yml ├── pyproject.toml ├── tox.ini ├── .pre-commit-config.yaml ├── LICENSE ├── .gitignore ├── setup.cfg ├── tests └── test_mpl_image_labeller.py └── README.md /.binder/postBuild: -------------------------------------------------------------------------------- 1 | python -m pip install -vvv -e . 
2 | -------------------------------------------------------------------------------- /examples/im1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mpl-extensions/mpl-image-labeller/HEAD/examples/im1.jpg -------------------------------------------------------------------------------- /examples/im2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mpl-extensions/mpl-image-labeller/HEAD/examples/im2.jpg -------------------------------------------------------------------------------- /examples/im3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mpl-extensions/mpl-image-labeller/HEAD/examples/im3.jpg -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools  # use_scm_version below: the version comes from setuptools_scm 2 | 3 | setuptools.setup(use_scm_version={"write_to": "mpl_image_labeller/_version.py"}) 4 | -------------------------------------------------------------------------------- /docs/_static/multi_class.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mpl-extensions/mpl-image-labeller/HEAD/docs/_static/multi_class.gif -------------------------------------------------------------------------------- /docs/_static/single_class.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mpl-extensions/mpl-image-labeller/HEAD/docs/_static/single_class.gif -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | 4 | recursive-include tests * 5 | recursive-exclude * __pycache__ 6 |
recursive-exclude * *.py[co] 7 | 8 | recursive-include docs *.md conf.py Makefile make.bat *.jpg *.png *.gif 9 | -------------------------------------------------------------------------------- /docs/api/mpl_image_labeller.rst: -------------------------------------------------------------------------------- 1 | mpl\_image\_labeller package 2 | ============================ 3 | 4 | Module contents 5 | --------------- 6 | 7 | .. automodule:: mpl_image_labeller 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | python: 3 | version: 3.8 4 | install: 5 | - method: pip 6 | path: . 7 | extra_requirements: 8 | - doc 9 | 10 | # Build documentation in the docs/ directory with Sphinx 11 | sphinx: 12 | configuration: docs/conf.py 13 | -------------------------------------------------------------------------------- /.binder/environment.yml: -------------------------------------------------------------------------------- 1 | # based on https://github.com/jupyterlab-contrib/jupyterlab-vim/blob/master/binder/environment.yml 2 | name: mpl-image-labller-demo 3 | 4 | channels: 5 | - conda-forge 6 | 7 | dependencies: 8 | # runtime dependencies 9 | - python >=3.8,<3.10.0a0 10 | - jupyterlab >=3,<4.0.0a0 11 | - pip 12 | - ipympl 13 | -------------------------------------------------------------------------------- /examples/basic_labeller.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | from mpl_image_labeller import image_labeller 5 | 6 | images = np.random.randn(5, 10, 10) 7 | labeller = image_labeller( 8 | images, classes=["good", "bad", "meh"], label_keymap=["a", "s", "d"] 9 | ) 10 | plt.show() 11 | print(labeller.labels) 12 | 
-------------------------------------------------------------------------------- /mpl_image_labeller/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from ._version import version as __version__ 3 | except ImportError: 4 | __version__ = "unknown"  # _version.py is generated by setuptools_scm at build time 5 | from ._labeller import image_labeller 6 | 7 | __author__ = "Ian Hunt-Isaak" 8 | __email__ = "ianhuntisaak@gmail.com" 9 | 10 | __all__ = [ 11 | "__version__", 12 | "__author__", 13 | "__email__", 14 | "image_labeller", 15 | ] 16 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | * mpl-image-labeller version: 2 | * Python version: 3 | * Operating System: 4 | 5 | ### Description 6 | 7 | Describe what you were trying to get done. 8 | Tell us what happened, what went wrong, and what you expected to happen. 9 | 10 | ### What I Did 11 | 12 | ``` 13 | Paste the command(s) you ran and the output. 14 | If there was a crash, please include the traceback here.
15 | ``` 16 | -------------------------------------------------------------------------------- /examples/multiclass_labeller.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | from mpl_image_labeller import image_labeller 5 | 6 | images = np.random.randn(5, 10, 10) 7 | labeller = image_labeller( 8 | images, 9 | classes=["good", "bad", "meh"], 10 | label_keymap=["a", "s", "d"], 11 | multiclass=True, 12 | ) 13 | plt.show() 14 | print(labeller.labels) 15 | print(labeller.labels_onehot) 16 | -------------------------------------------------------------------------------- /examples/create_example.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | from mpl_image_labeller import image_labeller 4 | 5 | # from PIL import imread 6 | im1 = plt.imread("im1.jpg") 7 | im2 = plt.imread("im2.jpg") 8 | im3 = plt.imread("im3.jpg") 9 | ims = [im1, im2, im3] 10 | labeller = image_labeller( 11 | ims, 12 | classes=["doggo", "cat", "car", "sofa"], 13 | # classes=["doggo", "cat", "other"], 14 | label_keymap=["a", "s", "d", "f"], 15 | multiclass=True, 16 | ) 17 | plt.show() 18 | print(labeller.labels) 19 | -------------------------------------------------------------------------------- /docs/_static/custom.css: -------------------------------------------------------------------------------- 1 | /* Fix numpydoc format delimiters */ 2 | .classifier:before { 3 | font-style: normal; 4 | margin: 0.5em; 5 | content: ":"; 6 | } 7 | 8 | /* override table no-wrap */ 9 | .wy-table-responsive table td, 10 | .wy-table-responsive table th { 11 | white-space: normal; 12 | } 13 | 14 | .text-align\:left, text-align\:left > p { 15 | text-align: left 16 | } 17 | .text-align\:center, text-align\:center > p { 18 | text-align: center 19 | } 20 | .text-align\:center, text-align\:right > p { 21 | text-align: right 22 | } 23 | 
-------------------------------------------------------------------------------- /examples/callbacks.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | from mpl_image_labeller import image_labeller 5 | 6 | images = np.random.randn(5, 10, 10) 7 | labeller = image_labeller(images, classes=["good", "bad", "blarg"]) 8 | 9 | 10 | def image_changed_callback(index, image): 11 | print(index) 12 | print(image.sum()) 13 | 14 | 15 | def label_assigned(index, label): 16 | print(f"label {label} assigned to image {index}") 17 | 18 | 19 | labeller.on_image_changed(image_changed_callback) 20 | labeller.on_label_assigned(image_changed_callback) 21 | plt.show() 22 | -------------------------------------------------------------------------------- /examples/lazy_loading.py: -------------------------------------------------------------------------------- 1 | # You can lazy load images by providing a function instead of a list for *images* 2 | # if you do this then you must also provide *N_images* in the labeller constructor 3 | 4 | 5 | import matplotlib.pyplot as plt 6 | from numpy.random import default_rng 7 | 8 | from mpl_image_labeller import image_labeller 9 | 10 | 11 | def lazy_image_generator(idx): 12 | rng = default_rng(idx) 13 | return rng.random((rng.integers(5, 15), rng.integers(5, 15))) 14 | 15 | 16 | labeller = image_labeller( 17 | lazy_image_generator, classes=["cool", "rad", "lame"], N_images=57 18 | ) 19 | plt.show() 20 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "wheel", "setuptools_scm>=6.2"] 3 | 4 | [tool.isort] 5 | profile = "black" 6 | src_paths = "mpl_image_labeller" 7 | multi_line_output = 3 8 | 9 | [tool.pydocstyle] 10 | match_dir = "mpl_image_labeller" 11 | convention = 
"numpy" 12 | add_select = ["D402","D415","D417"] 13 | 14 | [tool.pytest.ini_options] 15 | addopts = "-W error" 16 | 17 | [tool.mypy] 18 | files = "mpl_image_labeller" 19 | warn_unused_configs = true 20 | warn_unused_ignores = true 21 | check_untyped_defs = true 22 | implicit_reexport = false 23 | # this is strict! 24 | # disallow_untyped_defs = true 25 | show_column_numbers = true 26 | show_error_codes = true 27 | ignore_missing_imports = true 28 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Install Python 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: '3.x' 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install packaging setuptools twine wheel build 21 | - name: Publish the Python package 22 | env: 23 | TWINE_USERNAME: __token__ 24 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 25 | run: | 26 | python -m build -s -w 27 | twine upload dist/* 28 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= -T --color 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | watch: 23 | sphinx-autobuild . _build/html --open-browser --watch examples 24 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | schedule: 9 | - cron: "0 16 * * 1" # monday at noon est 10 | 11 | jobs: 12 | test: 13 | name: ${{ matrix.mpl-version}} 14 | runs-on: ubuntu-latest 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | mpl-version: ['3.4', 'latest'] 19 | python-version: [3.9] 20 | steps: 21 | - uses: actions/checkout@v2 22 | 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | 28 | 29 | - if: matrix.mpl-version=='latest' 30 | name: Install dev Matplotlib 31 | run: pip install git+https://github.com/matplotlib/matplotlib.git 32 | 33 | - if: matrix.mpl-version!='latest' 34 | name: Install matplotlib pinned 35 | run: pip install matplotlib~=${{matrix.mpl-version}} 36 | 37 | - name: Install 38 | run: | 39 | pip install -e ".[testing]" 40 | pip install pytest-cov 41 | 42 | - name: Run Tests 43 | run: pytest -v --color=yes --cov=mpl_image_labeller --cov-report=xml 44 | 45 | - name: Coverage 46 | uses: codecov/codecov-action@v2 47 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # https://github.com/ComPWA/ampform/blob/2ad1f7d14dc9bb58045a0fa83af4a353232dc5a6/tox.ini 2 | [tox] 3 | envlist = 4 | py, 5 | doc, 6 | sty, 7 | passenv = PYTHONPATH 8 | skip_install = True 
9 | skip_missing_interpreters = True 10 | skipsdist = True 11 | 12 | [testenv] 13 | description = 14 | Run all unit tests 15 | allowlist_externals = 16 | pytest 17 | commands = 18 | pytest {posargs} 19 | 20 | [testenv:doc] 21 | description = 22 | Build documentation and API through Sphinx 23 | changedir = docs 24 | allowlist_externals = 25 | make 26 | commands = 27 | make html 28 | 29 | [testenv:doclive] 30 | description = 31 | Set up a server to directly preview changes to the HTML pages 32 | allowlist_externals = 33 | sphinx-autobuild 34 | passenv = 35 | EXECUTE_NB 36 | TERM 37 | commands = 38 | sphinx-autobuild \ 39 | --watch docs \ 40 | --watch mpl_image_labeller \ 41 | --re-ignore .*/.ipynb_checkpoints/.* \ 42 | --re-ignore .*/__pycache__/.* \ 43 | --re-ignore docs/_build/.* \ 44 | --re-ignore docs/api/.* \ 45 | --re-ignore docs/examples/.*.gif \ 46 | --re-ignore docs/gallery/.* \ 47 | --open-browser \ 48 | docs/ docs/_build/html 49 | 50 | [testenv:docnb] 51 | description = 52 | Build documentation through Sphinx WITH output of Jupyter notebooks 53 | setenv = 54 | EXECUTE_NB = "yes" 55 | changedir = docs 56 | allowlist_externals = 57 | make 58 | commands = 59 | make html 60 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autoupdate_schedule: 'quarterly' 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.0.1 6 | hooks: 7 | - id: check-docstring-first 8 | - id: end-of-file-fixer 9 | - id: trailing-whitespace 10 | - repo: https://github.com/asottile/setup-cfg-fmt 11 | rev: v1.17.0 12 | hooks: 13 | - id: setup-cfg-fmt 14 | - repo: https://github.com/PyCQA/flake8 15 | rev: 3.9.2 16 | hooks: 17 | - id: flake8 18 | additional_dependencies: [flake8-typing-imports==1.7.0] 19 | - repo: https://github.com/myint/autoflake 20 | rev: v1.4 21 | hooks: 22 | - id: autoflake 23 | args: 
["--in-place", "--remove-all-unused-imports", "--ignore-init-module-imports", "--remove-unused-variables"] 24 | - repo: https://github.com/PyCQA/isort 25 | rev: 5.8.0 26 | hooks: 27 | - id: isort 28 | - repo: https://github.com/psf/black 29 | rev: 21.5b2 30 | hooks: 31 | - id: black 32 | - repo: https://github.com/asottile/pyupgrade 33 | rev: v2.19.0 34 | hooks: 35 | - id: pyupgrade 36 | args: [--py37-plus] 37 | - repo: https://github.com/pre-commit/mirrors-mypy 38 | rev: v0.812 39 | hooks: 40 | - id: mypy 41 | - repo: https://github.com/nbQA-dev/nbQA 42 | rev: 1.1.1 43 | hooks: 44 | - id: nbqa-black 45 | - id: nbqa-isort 46 | - repo: https://github.com/kynan/nbstripout 47 | rev: 0.5.0 48 | hooks: 49 | - id: nbstripout 50 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Ian Hunt-Isaak 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | # IDE settings 105 | .vscode/ 106 | _version.py 107 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = mpl_image_labeller 3 | description = Use interactive matplotlib to label images for classification 4 | long_description = file: README.md 5 | long_description_content_type = text/markdown 6 | url = https://mpl-image-labeller.rtfd.io 7 | author = Ian Hunt-Isaak 8 | author_email = ianhuntisaak@gmail.com 9 | license = BSD-3-Clause 10 | license_file = LICENSE 11 | classifiers = 12 | Development Status :: 5 - Production/Stable 13 | Framework :: Matplotlib 14 | License :: OSI Approved :: BSD License 15 | Natural Language :: English 16 | Programming Language :: Python :: 3 17 | Programming Language :: Python :: 3 :: Only 18 | Programming 
Language :: Python :: 3.7 19 | Programming Language :: Python :: 3.8 20 | Programming Language :: Python :: 3.9 21 | Programming Language :: Python :: Implementation :: CPython 22 | project_urls = 23 | Tracker = https://github.com/ianhi/mpl-image-labeller/issues 24 | Changelog = https://github.com/ianhi/mpl-image-labeller/releases 25 | Documentation = https://mpl-image-labeller.rtfd.io 26 | Source = https://github.com/ianhi/mpl-interactions 27 | 28 | [options] 29 | packages = find: 30 | install_requires = 31 | matplotlib 32 | python_requires = >=3.7 33 | zip_safe = False 34 | 35 | [options.extras_require] 36 | dev = 37 | black 38 | flake8 39 | flake8-docstrings 40 | ipython 41 | isort 42 | jedi<0.18.0 43 | mypy 44 | pre-commit 45 | pydocstyle 46 | pytest 47 | doc = 48 | Sphinx>=1.5 49 | jupyter-sphinx 50 | myst-nb 51 | numpydoc 52 | sphinx-book-theme 53 | sphinx-copybutton 54 | sphinx-panels 55 | sphinx-thebe 56 | sphinx-togglebutton 57 | testing = 58 | pytest 59 | 60 | [bdist_wheel] 61 | universal = 1 62 | 63 | [flake8] 64 | exclude = docs, _version.py, .eggs, example 65 | max-line-length = 88 66 | docstring-convention = "numpy" 67 | -------------------------------------------------------------------------------- /tests/test_mpl_image_labeller.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mpl_image_labeller import image_labeller 4 | 5 | N = 5 6 | M = 3 7 | im1 = np.ones([N, M]) 8 | im1[0, 0] = 0 9 | im2 = np.ones([M, N]) * 5 10 | im2[0, 0] = 0.5 11 | ims = [im1, im2] 12 | 13 | 14 | # norm and extent 15 | def test_norm_and_extent_updates(): 16 | labeller = image_labeller(ims, ["good", "bad"]) 17 | assert labeller._im.norm.vmin == 0 18 | assert labeller._im.norm.vmax == 1 19 | assert labeller._im.get_extent() == (-0.5, M - 0.5, N - 0.5, -0.5) 20 | assert labeller._image_ax.get_xlim() == (-0.5, M - 0.5) 21 | assert labeller._image_ax.get_ylim() == (N - 0.5, -0.5) 22 | 23 | labeller.image_index 
+= 1 24 | assert labeller._im.norm.vmin == 0.5 25 | assert labeller._im.norm.vmax == 5 26 | assert labeller._im.get_extent() == (-0.5, N - 0.5, M - 0.5, -0.5) 27 | assert labeller._image_ax.get_xlim() == (-0.5, N - 0.5) 28 | assert labeller._image_ax.get_ylim() == (M - 0.5, -0.5) 29 | 30 | 31 | def test_norm_with_explict_vmin_vmax(): 32 | labeller = image_labeller(ims, ["good", "bad"], vmin=0.3, vmax=4) 33 | assert labeller._im.norm.vmin == 0.3 34 | assert labeller._im.norm.vmax == 4 35 | labeller.image_index += 1 36 | assert labeller._im.norm.vmin == 0.3 37 | assert labeller._im.norm.vmax == 4 38 | 39 | labeller = image_labeller(ims, ["good", "bad"], vmin=0.3) 40 | assert labeller._im.norm.vmin == 0.3 41 | assert labeller._im.norm.vmax == im1.max() 42 | labeller.image_index += 1 43 | assert labeller._im.norm.vmin == 0.3 44 | assert labeller._im.norm.vmax == im2.max() 45 | 46 | labeller = image_labeller(ims, ["good", "bad"], vmax=4) 47 | assert labeller._im.norm.vmin == im1.min() 48 | assert labeller._im.norm.vmax == 4 49 | labeller.image_index += 1 50 | assert labeller._im.norm.vmin == im2.min() 51 | assert labeller._im.norm.vmax == 4 52 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | 2 | # mpl-image-labeller's Documentation 3 | 4 | Use Matplotlib to label images for classification. Works anywhere Matplotlib does - from the notebook to a standalone gui! 5 | 6 | 7 | ## Key features 8 | - Single or Multiclass interfaces! 
9 | - {doc}`examples/single-class` 10 | - {doc}`examples/multi-class` 11 | - Supports lists of classes or onehot encodings 12 | - Uses keys instead of mouse 13 | - Only depends on Matplotlib 14 | - Works anywhere - from inside Jupyter to any supported GUI framework 15 | - Displays images with correct aspect ratio 16 | - Easily configurable keymap 17 | - Smart interactions with default Matplotlib keymap 18 | - Callback System (see {doc}`examples/callbacks`) 19 | - Allows Lazy Loading of Images ({doc}`examples/lazy-loading`) 20 | 21 | ## Install 22 | ```bash 23 | pip install mpl-image-labeller 24 | ``` 25 | 26 | ## Example GIFs 27 | ```{table} 28 | :align: center 29 | 30 | | Single Class Interface | Multiclass Interface | 31 | | ----------------------- | --------------------| 32 | |![A gif of the single class interface. Showing keybindings to assign classes to images.](_static/single_class.gif) | ![A gif of the multi-class interface. Showing using both keybindings and mouse to assign classes to images.](_static/multi_class.gif)| 33 | ``` 34 | 35 | ## Getting help 36 | Please ask usage questions on the [Matplotlib 3rd Party Package Discourse](https://discourse.matplotlib.org/c/3rdparty/18). When you do so feel free 37 | to use `@ianhi` to ping me. 
38 | 39 | ## Reporting Issues 40 | Please report any issues on github at https://github.com/ianhi/mpl-image-labeller/issues/new/choose 41 | 42 | 43 | 44 | 45 | 46 | ```{toctree} 47 | :maxdepth: 2 48 | 49 | API 50 | contributing 51 | ``` 52 | 53 | ```{toctree} 54 | :caption: Examples 55 | :maxdepth: 1 56 | 57 | examples/single-class.ipynb 58 | examples/multi-class.ipynb 59 | examples/lazy-loading.ipynb 60 | examples/callbacks.ipynb 61 | ``` 62 | -------------------------------------------------------------------------------- /mpl_image_labeller/_util.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from collections.abc import Iterable 3 | 4 | import numpy as np 5 | from matplotlib.cbook import CallbackRegistry 6 | 7 | __all__ = [ 8 | "deactivatable_CallbackRegistry", 9 | "add_text_to_rect", 10 | "list_to_onehot", 11 | "onehot_to_list", 12 | "ConflictingArgumentsError", 13 | ] 14 | 15 | 16 | class deactivatable_CallbackRegistry(CallbackRegistry): 17 | def __init__(self, exception_handler=None): 18 | if exception_handler is not None: 19 | super().__init__(exception_handler) 20 | else: 21 | super().__init__() 22 | self._active = True 23 | 24 | def process(self, s, *args, **kwargs): 25 | """ 26 | Process signal *s*. 27 | 28 | All of the functions registered to receive callbacks on *s* will be 29 | called with ``*args`` and ``**kwargs``. 
30 | """ 31 | if self._active: 32 | super().process(s, *args, **kwargs) 33 | 34 | @contextlib.contextmanager 35 | def deactivate(self): 36 | self._active = False 37 | yield 38 | self._active = True 39 | 40 | 41 | def add_text_to_rect(text, rect, **text_kwargs): 42 | rx, ry = rect.get_xy() 43 | cx = rx + rect.get_width() / 2.0 44 | cy = ry + rect.get_height() / 2.0 45 | ha = text_kwargs.pop("ha", "center") 46 | va = text_kwargs.pop("va", "center") 47 | rect.axes.annotate(text, (cx, cy), ha=ha, va=va, **text_kwargs) 48 | 49 | 50 | def list_to_onehot(labels, classes): 51 | lookup = {c: i for i, c in enumerate(classes)} 52 | arr = np.zeros((len(labels), len(classes)), dtype=bool) 53 | for i, l in enumerate(labels): 54 | 55 | if l is None: 56 | continue 57 | elif isinstance(l, str) or not isinstance(l, Iterable): 58 | # str, or number, or something like that 59 | arr[i, lookup[l]] = True 60 | else: 61 | for j in l: 62 | arr[i, lookup[j]] = True 63 | return arr 64 | 65 | 66 | def onehot_to_list(onehot, classes): 67 | c_arr = np.asarray(classes) 68 | labels = [] 69 | for row in onehot: 70 | labels.append(list(c_arr[row])) 71 | return labels 72 | 73 | 74 | class ConflictingArgumentsError(ValueError): 75 | pass 76 | -------------------------------------------------------------------------------- /docs/examples/lazy-loading.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "7b2b704c-a2b1-4497-bd93-b6619804d848", 6 | "metadata": {}, 7 | "source": [ 8 | "# Images from a function (Lazy loading)\n", 9 | "\n", 10 | "You do not need to provide an array of images. Instead you can provide a function that returns an image given an index. This enables the follow:\n", 11 | "\n", 12 | "1. Lazy loading of images\n", 13 | "2. 
Easily have images of different sizes (the image will update limits automatically)" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "483bffb1-f291-40ad-80ba-e9e08b8f6e88", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# if in a notebook\n", 24 | "%matplotlib ipympl" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "b62a762f-26be-4dba-ab07-ec26648b970b", 31 | "metadata": { 32 | "tags": [] 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "# You can lazy load images by providing a function instead of a list for *images*\n", 37 | "# if you do this then you must also provide *N_images* in the labeller constructor\n", 38 | "\n", 39 | "\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "from numpy.random import default_rng\n", 42 | "\n", 43 | "from mpl_image_labeller import image_labeller\n", 44 | "\n", 45 | "\n", 46 | "def lazy_image_generator(idx):\n", 47 | " rng = default_rng(idx)\n", 48 | " return rng.random((rng.integers(5, 15), rng.integers(5, 15)))\n", 49 | "\n", 50 | "\n", 51 | "labeller = image_labeller(\n", 52 | " lazy_image_generator, classes=[\"cool\", \"rad\", \"lame\"], N_images=57\n", 53 | ")\n", 54 | "plt.show()" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "8000249c-babc-430d-bd29-7a55145ce0bc", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [] 64 | } 65 | ], 66 | "metadata": { 67 | "kernelspec": { 68 | "display_name": "Python 3 (ipykernel)", 69 | "language": "python", 70 | "name": "python3" 71 | }, 72 | "language_info": { 73 | "codemirror_mode": { 74 | "name": "ipython", 75 | "version": 3 76 | }, 77 | "file_extension": ".py", 78 | "mimetype": "text/x-python", 79 | "name": "python", 80 | "nbconvert_exporter": "python", 81 | "pygments_lexer": "ipython3", 82 | "version": "3.9.7" 83 | } 84 | }, 85 | "nbformat": 4, 86 | "nbformat_minor": 5 87 | } 88 | 
-------------------------------------------------------------------------------- /docs/examples/multi-class.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "10d8e4fe-11be-478a-bc33-a81bb0dce987", 6 | "metadata": {}, 7 | "source": [ 8 | "# Multi Class\n", 9 | "\n", 10 | "You can also allow each image to belong to multiple categories with the `multiclass` argument." 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "d7aa1379-5cbd-4284-8d99-373bfd02c807", 17 | "metadata": { 18 | "tags": [] 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "# If in a notebook:\n", 23 | "%matplotlib ipympl" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "id": "d7db1d50-2fe0-4fa5-9af9-1f0753acd34d", 30 | "metadata": { 31 | "tags": [] 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "import matplotlib.pyplot as plt\n", 36 | "import numpy as np\n", 37 | "\n", 38 | "from mpl_image_labeller import image_labeller\n", 39 | "\n", 40 | "images = np.random.randn(5, 10, 10)\n", 41 | "labeller = image_labeller(\n", 42 | " images,\n", 43 | " classes=[\"good\", \"bad\", \"meh\"],\n", 44 | " label_keymap=[\"a\", \"s\", \"d\"],\n", 45 | " multiclass=True,\n", 46 | ")\n", 47 | "plt.show()" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "id": "4471612b-e5f7-4a55-8232-536da875fb29", 53 | "metadata": {}, 54 | "source": [ 55 | "The natural representation of multiclass labels is a onehot encoding, accessible (and settable!) via the `labels_onehot` property."
56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "b724f277-2eb9-4d5c-b0a7-0820ce4b20d2", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "print(labeller.labels_onehot)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "id": "2cf1e17e-8a18-4a56-b9d8-4be61fe4bd47", 71 | "metadata": {}, 72 | "source": [ 73 | "You can also get the labels as a ragged list of lists via the `labels` property." 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "id": "700dcf86-7337-4b50-aab0-ddd7346a24d9", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "print(labeller.labels)" 84 | ] 85 | } 86 | ], 87 | "metadata": { 88 | "kernelspec": { 89 | "display_name": "Python 3 (ipykernel)", 90 | "language": "python", 91 | "name": "python3" 92 | }, 93 | "language_info": { 94 | "codemirror_mode": { 95 | "name": "ipython", 96 | "version": 3 97 | }, 98 | "file_extension": ".py", 99 | "mimetype": "text/x-python", 100 | "name": "python", 101 | "nbconvert_exporter": "python", 102 | "pygments_lexer": "ipython3", 103 | "version": "3.9.7" 104 | } 105 | }, 106 | "nbformat": 4, 107 | "nbformat_minor": 5 108 | } 109 | -------------------------------------------------------------------------------- /docs/examples/single-class.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "cdf469e7-0110-4b75-97d5-69c25aa290f5", 6 | "metadata": {}, 7 | "source": [ 8 | "# Single Class\n" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "4a694f4d-ba0d-4ea4-b15e-db17a2fa8e61", 15 | "metadata": { 16 | "tags": [] 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "# If running in a notebook make matplotlib interactive\n", 21 | "%matplotlib ipympl" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "8b0938f1-fab9-4127-b836-9c06a2fbc333", 27 | "metadata": {}, 28 |
"source": [ 29 | "```{note}\n", 30 | "In a notebook you need to make sure to click on the figure in order to give it keyboard focus.\n", 31 | "```" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "65afcfa0-1511-4721-a79a-8df969a1002f", 38 | "metadata": { 39 | "tags": [] 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "import matplotlib.pyplot as plt\n", 44 | "import numpy as np\n", 45 | "\n", 46 | "from mpl_image_labeller import image_labeller\n", 47 | "\n", 48 | "images = np.random.randn(5, 10, 10)\n", 49 | "labeller = image_labeller(\n", 50 | " images, classes=[\"good\", \"bad\", \"meh\"], label_keymap=[\"a\", \"s\", \"d\"]\n", 51 | ")\n", 52 | "plt.show()" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "f517f339-a7e1-4bd5-bff4-fc56d9948a1d", 58 | "metadata": {}, 59 | "source": [ 60 | "After you label the images then the labels will be available as a list:" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "id": "c79afd96-4f6d-46f8-ab7d-46b5dc6de38e", 67 | "metadata": { 68 | "tags": [] 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "print(labeller.labels)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "id": "7416b621-d9be-4e5d-9600-14f341ee419b", 78 | "metadata": {}, 79 | "source": [ 80 | "Or as a onehot encoding" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "cca10064-0f61-47c3-b526-5d1330170c61", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "print(labeller.labels_onehot)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "id": "253335a9-17ea-4dfd-96a4-580f8593ac1d", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [] 100 | } 101 | ], 102 | "metadata": { 103 | "kernelspec": { 104 | "display_name": "Python 3 (ipykernel)", 105 | "language": "python", 106 | "name": "python3" 107 | }, 108 | "language_info": { 109 | "codemirror_mode": { 110 | "name": "ipython", 
111 | "version": 3 112 | }, 113 | "file_extension": ".py", 114 | "mimetype": "text/x-python", 115 | "name": "python", 116 | "nbconvert_exporter": "python", 117 | "pygments_lexer": "ipython3", 118 | "version": "3.9.7" 119 | } 120 | }, 121 | "nbformat": 4, 122 | "nbformat_minor": 5 123 | } 124 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mpl-image-labeller 2 | 3 | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/ianhi/mpl-image-labeller/main?urlpath=lab/tree/docs/examples) 4 | [![Documentation Status](https://readthedocs.org/projects/mpl-image-labeller/badge/?version=stable)](https://mpl-image-labeller.readthedocs.io/en/stable/?badge=stable) 5 | 6 | 7 | [![License](https://img.shields.io/pypi/l/mpl-image-labeller.svg?color=green)](https://github.com/ianhi/mpl-image-labeller/raw/master/LICENSE) 8 | [![PyPI](https://img.shields.io/pypi/v/mpl-image-labeller.svg?color=green)](https://pypi.org/project/mpl-image-labeller) 9 | [![Python Version](https://img.shields.io/pypi/pyversions/mpl-image-labeller.svg?color=green)](https://python.org) 10 | 11 | Use Matplotlib to label images for classification. Works anywhere Matplotlib does - from the notebook to a standalone gui! 12 | 13 | For more see the [documentation](https://mpl-image-labeller.readthedocs.io/en/stable/?badge=stable). 
14 | 15 | ## Install 16 | 17 | ```bash 18 | pip install mpl-image-labeller 19 | ``` 20 | ## Key features 21 | - Simple interface 22 | - Uses keys instead of mouse 23 | - Only depends on Matplotlib 24 | - Works anywhere - from inside Jupyter to any supported GUI framework 25 | - Displays images with correct aspect ratio 26 | - Easily configurable keymap 27 | - Smart interactions with default Matplotlib keymap 28 | - Callback System (see `examples/callbacks.py`) 29 | 30 | **single class per image** 31 | 32 | ![gif of usage for labelling images of cats and dogs](docs/_static/single_class.gif) 33 | 34 | **multiple classes per image** 35 | 36 | ![gif of usage for labelling images of cats and dogs](docs/_static/multi_class.gif) 37 | 38 | ## Usage 39 | 40 | ```python 41 | import matplotlib.pyplot as plt 42 | import numpy as np 43 | 44 | from mpl_image_labeller import image_labeller 45 | 46 | images = np.random.randn(5, 10, 10) 47 | labeller = image_labeller( 48 | images, classes=["good", "bad", "meh"], label_keymap=["a", "s", "d"] 49 | ) 50 | plt.show() 51 | ``` 52 | 53 | **accessing the axis** 54 | You can further modify the images (e.g. add masks over them) by using the plotting methods on 55 | the axis object accessible by `labeller.ax`. 56 | 57 | **Lazy Loading Images** 58 | If you want to lazy load your images you can provide a function to give the images. This function should take 59 | the integer `idx` as an argument and return the image that corresponds to that index. If you do this then you 60 | must also provide `N_images` in the constructor to let the object know how many images it should expect. See `examples/lazy_loading.py` for an example. 61 | 62 | ### Controls 63 | 64 | - `<-` move one image back 65 | - `->` move one image forward 66 | 67 | To label images use the keys defined in the `label_keymap` argument - by default 1, 2, 3, ... (the tenth class uses 0). 68 | 69 | 70 | Get the labels by accessing the `labels` property.
71 | 72 | ### Overwriting default keymap 73 | Matplotlib applies default keybindings to all figures via the `keymap.*` entries of `rcParams`; these allow for actions such as `s` to save or `q` to quit. If you include one of these keys as a shortcut for labelling a class, then that default binding will be disabled for that figure. 74 | 75 | 76 | ## Related Projects 77 | 78 | This is not the first project to implement easy image labelling but seems to be the first to do so entirely in Matplotlib. The below 79 | projects implement varying degrees of complexity and/or additional features in different frameworks. 80 | 81 | - https://github.com/wbwvos/pidgey 82 | - https://github.com/agermanidis/pigeon 83 | - https://github.com/Serhiy-Shekhovtsov/tkteach 84 | - https://github.com/robertbrada/PyQt-image-annotation-tool 85 | - https://github.com/Cartucho/OpenLabeling 86 | -------------------------------------------------------------------------------- /docs/examples/callbacks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "3982f572-4151-47df-923d-3cd6794e7070", 6 | "metadata": {}, 7 | "source": [ 8 | "# Callbacks\n", 9 | "\n", 10 | "The image labeller implements a callback system that allows you to run arbitrary code whenever the displayed image changes or when an image has a label assigned."
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "eebbcc16-dd8a-4e9b-9d79-b7cb2d3dbee0", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "# If in a notebook\n", 21 | "%matplotlib ipympl" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "f576ccc0-d8e8-4d6c-b09c-5c3634c94bf6", 28 | "metadata": { 29 | "tags": [] 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "import matplotlib.pyplot as plt\n", 34 | "import numpy as np\n", 35 | "\n", 36 | "from mpl_image_labeller import image_labeller\n", 37 | "\n", 38 | "images = np.random.randn(5, 10, 10)\n", 39 | "labeller = image_labeller(images, classes=[\"good\", \"bad\", \"blarg\"])\n", 40 | "\n", 41 | "\n", 42 | "def image_changed_callback(index, image):\n", 43 | " print(index)\n", 44 | " print(image.sum())\n", 45 | "\n", 46 | "\n", 47 | "def label_assigned(index, label):\n", 48 | " print(f\"label {label} assigned to image {index}\")\n", 49 | "\n", 50 | "\n", 51 | "labeller.on_image_changed(image_changed_callback)\n", 52 | "labeller.on_label_assigned(label_assigned)\n", 53 | "plt.show()" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "id": "1e4c6ae8-c124-4e05-a9d2-6c4d987d9cf7", 59 | "metadata": {}, 60 | "source": [ 61 | "## Overlaying a mask\n", 62 | "\n", 63 | "One potential usage of this is to overlay a mask over the images which changes for each image. If the shape of the image is changing then you will also need to adjust the `extent` of the overlay. If this is the case then uncomment the line in the `update_mask` function."
64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "id": "352f6e5e-7b0b-44aa-bb19-da2bc6bca654", 70 | "metadata": { 71 | "tags": [] 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "labeller = image_labeller(images, classes=[\"good\", \"bad\", \"blarg\"])\n", 76 | "\n", 77 | "mask_threshold = 0.6\n", 78 | "\n", 79 | "from numpy.random import default_rng\n", 80 | "\n", 81 | "\n", 82 | "def gen_mask(idx, image):\n", 83 | " # here we return a random mask - but you could base this on your data\n", 84 | " rng = default_rng(idx)\n", 85 | " mask = rng.random(image.shape)\n", 86 | " return mask > mask_threshold\n", 87 | "\n", 88 | "\n", 89 | "overlay = labeller.ax.imshow(\n", 90 | " gen_mask(0, images[0]), cmap=\"gray\", vmin=0, vmax=mask_threshold, alpha=0.75\n", 91 | ")\n", 92 | "cmap = overlay.cmap.copy()\n", 93 | "cmap.set_over(alpha=0)\n", 94 | "overlay.set_cmap(cmap)\n", 95 | "\n", 96 | "\n", 97 | "def update_mask(idx, image):\n", 98 | " new_mask = gen_mask(idx, image)\n", 99 | " overlay.set_data(new_mask)\n", 100 | "\n", 101 | " # if your image is changing shape uncomment the next line\n", 102 | " # overlay.set_extent((-0.5, new_mask.shape[1] - 0.5, new_mask.shape[0] - 0.5, -0.5))\n", 103 | "\n", 104 | "\n", 105 | "labeller.on_image_changed(update_mask)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "9f1e35b8-5b0e-442e-ae91-910d4ff04017", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [] 115 | } 116 | ], 117 | "metadata": { 118 | "kernelspec": { 119 | "display_name": "Python 3 (ipykernel)", 120 | "language": "python", 121 | "name": "python3" 122 | }, 123 | "language_info": { 124 | "codemirror_mode": { 125 | "name": "ipython", 126 | "version": 3 127 | }, 128 | "file_extension": ".py", 129 | "mimetype": "text/x-python", 130 | "name": "python", 131 | "nbconvert_exporter": "python", 132 | "pygments_lexer": "ipython3", 133 | "version": "3.9.7" 134 | } 135 | }, 136 | "nbformat": 
4, 137 | "nbformat_minor": 5 138 | } 139 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Thanks for thinking of a way to help improve this library! Remember that contributions come in all shapes and sizes beyond writing bug fixes. Contributing to [documentation](#documentation), opening new [issues](https://github.com/ianhi/mpl-image-labeller/issues) for bugs, asking for clarification on things you find unclear, and requesting new features, are all super valuable contributions. 4 | 5 | ## Code Improvements 6 | 7 | All development for this library happens on GitHub [here](https://github.com/ianhi/mpl-image-labeller). We recommend you work with a [Conda](https://www.anaconda.com/products/individual) environment (or an alternative virtual environment like [`venv`](https://docs.python.org/3/library/venv.html)). 8 | 9 | The below instructions also use [Mamba](https://github.com/mamba-org/mamba#the-fast-cross-platform-package-manager) which is a very fast implementation of `conda`. 10 | 11 | ```bash 12 | git clone https://github.com/ianhi/mpl-image-labeller 13 | cd mpl-image-labeller 14 | mamba env create 15 | conda activate mpl-interactions 16 | pre-commit install 17 | ``` 18 | 19 | The `mamba env create` command installs all Python packages that are useful when working on the source code of `mpl_image_labeller` and its documentation. You can also install these packages separately: 20 | 21 | ```bash 22 | pip install -e ".[dev, doc]" 23 | ``` 24 | 25 | The {command}`-e .` flag installs the `mpl_image_labeller` folder in ["editable" mode](https://pip.pypa.io/en/stable/cli/pip_install/#editable-installs) and {command}`[dev, doc]` installs the [optional dependencies](https://setuptools.readthedocs.io/en/latest/userguide/dependency_management.html#optional-dependencies) you need for developing `mpl_image_labeller` and building its documentation.
26 | 27 | ### Seeing your changes 28 | 29 | If you are working in a Jupyter Notebook, then in order to see your code changes you will need to either: 30 | 31 | - Restart the Kernel every time you make a change to the code. 32 | - Make the function reload from the source file every time you run it by using [autoreload](https://ipython.readthedocs.io/en/stable/config/extensions/autoreload.html), e.g.: 33 | 34 | ```python 35 | %load_ext autoreload 36 | %autoreload 2 37 | 38 | from mpl_image_labeller import .... 39 | ``` 40 | 41 | ### Working with Git 42 | 43 | Using Git/GitHub can be confusing, so if you're new to Git, you may find it helpful to use a program like [GitHub Desktop](https://desktop.github.com) and to follow a [guide](https://github.com/firstcontributions/first-contributions#first-contributions). 44 | 45 | Also feel free to ask for help/advice on the relevant GitHub [issue](https://github.com/ianhi/mpl-image-labeller/issues). 46 | 47 | ## Documentation 48 | 49 | Our documentation on Read the Docs ([mpl-image-labeller.rtfd.io](https://mpl-image-labeller.readthedocs.io)) is built with [Sphinx](https://www.sphinx-doc.org) from the notebooks in the `docs` folder. It contains both Markdown files and Jupyter notebooks. 50 | 51 | Examples are best written as Jupyter notebooks. To write a new example, create a notebook in the `docs/examples` directory and list its path under one of the [`toctree`s](https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-toctree) in the `index.md` file. When the docs are generated, they will be rendered as static html pages by [myst-nb](https://myst-nb.readthedocs.io). 52 | 53 | If you have installed all developer dependencies (see [above](#contributing)), you can build the documentation, including your recent modifications to the source files, with the following tox command: 54 | 55 | ```bash 56 | tox -e doc 57 | ``` 58 | 59 | If you open the `index.html` file in your browser you should now be able to see the rendered documentation.
60 | 61 | Alternatively, you can use [sphinx-autobuild](https://github.com/executablebooks/sphinx-autobuild) to continuously watch source files for changes and rebuild the documentation for you. Sphinx-autobuild will be installed automatically by the above `pip` command, so all you need to do is run: 62 | 63 | ```bash 64 | tox -e doclive 65 | ``` 66 | 67 | In a few seconds your web browser should open up the documentation. Now whenever you save a file the documentation will automatically regenerate and the webpage will refresh for you! 68 | 69 | 70 | ### Making frontpage gifs 71 | The frontpage gifs are generated from the `examples/create_example.py` script. I used peek with a resolution of 638x653 and recorded the keystrokes using `screenkey -g screenkey -g 640x537+308+543`. 72 | 73 | Those numbers came from using `slop` which can be used with screenkey like so: `screenkey -g $(slop -n -f '%g')` 74 | -------------------------------------------------------------------------------- /mpl_image_labeller/_widgets.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | import numpy as np 4 | from matplotlib.patches import Rectangle 5 | 6 | from ._util import add_text_to_rect, deactivatable_CallbackRegistry 7 | 8 | 9 | class _array_button(Rectangle): 10 | def __init__( 11 | self, 12 | x, 13 | y, 14 | width, 15 | height, 16 | active_color="green", 17 | inactive_color="tab:red", 18 | **rect_kwargs 19 | ): 20 | self._state = False 21 | self.active_color = active_color 22 | self.inactive_color = inactive_color 23 | 24 | super().__init__( 25 | (x, y), width, height, facecolor=inactive_color, picker=True, **rect_kwargs 26 | ) 27 | 28 | @property 29 | def state(self): 30 | return self._state 31 | 32 | @state.setter 33 | def state(self, value): 34 | if not isinstance(value, (bool, np.bool_)): 35 | raise TypeError("Button state must be a bool") 36 | self._state = value 37 | self._update_color() 38 | 39 | def 
_update_color(self): 40 | col = self.active_color if self.state else self.inactive_color 41 | self.set_facecolor(col) 42 | self.set_edgecolor(col) 43 | self.stale = True 44 | 45 | 46 | class button_array: 47 | def __init__(self, options, ax, active_color="green", inactive_color="tab:red"): 48 | self._ax = ax 49 | self._fig = ax.figure 50 | ax.axis("off") 51 | # if len(options) <= 3: 52 | # nrow = 1 53 | ncol = 4 54 | len(options) 55 | 56 | gap = 0.05 57 | width = (1 - (ncol - 1) * gap) / ncol 58 | height = width 59 | self._buttons = [] 60 | self._active_color = active_color 61 | self._inactive_color = inactive_color 62 | nrow = np.ceil(len(options) / ncol) 63 | total_height = nrow * height 64 | top = 0.5 + (total_height / 2) 65 | for i, o in enumerate(options): 66 | vert_pos = top - ((i // ncol) * (height + gap)) - height 67 | horiz_pos = (i % ncol) * (width + gap) 68 | button = _array_button( 69 | horiz_pos, vert_pos, width, height, active_color, inactive_color 70 | ) 71 | self._ax.add_artist(button) 72 | add_text_to_rect(str(o), button) 73 | self._buttons.append(button) 74 | self._ax.figure.canvas.mpl_connect("pick_event", self._on_pick) 75 | # self._ax.figure.canvas.mpl_connect("key_press_event", self._on_pick) 76 | self.draw_on = True 77 | self._observers = deactivatable_CallbackRegistry() 78 | 79 | def on_state_change(self, func): 80 | """ 81 | Connect *func* to be called the state of the checked buttons changes. 82 | *func* will receive the updated state and the old state 83 | 84 | Maybe todo: also send the diff of the state. 85 | """ 86 | self._observers.connect( 87 | "state-changed", lambda new_state, old_state: func(new_state, old_state) 88 | ) 89 | 90 | def _on_pick(self, event): 91 | if event.artist in self._buttons: 92 | old_state = self.get_states() 93 | event.artist.state = not event.artist.state 94 | self._observers.process("state-changed", self.get_states(), old_state) 95 | # TODO: consider whether to draw here? 
96 | # maybe make it toggleable a la draw_on 97 | self._fig.canvas.draw() 98 | 99 | def activate_all(self): 100 | for b in self._buttons: 101 | b.state = True 102 | if self.draw_on: 103 | self._fig.canvas.draw() 104 | 105 | def set_states(self, states): 106 | """ 107 | Update the "buttons" to 108 | Parameters 109 | ---------- 110 | toggled : dict or list 111 | mapping i -> True/False. Does need to include states for every button. 112 | """ 113 | if isinstance(states, dict): 114 | enum = states.items() 115 | else: 116 | enum = enumerate(states) 117 | for i, s in enum: 118 | self._buttons[i].state = s 119 | 120 | def get_states(self): 121 | """ 122 | Get the states of all buttons as a list 123 | """ 124 | states = [] 125 | for button in self._buttons: 126 | states.append(button.state) 127 | return states 128 | 129 | @contextlib.contextmanager 130 | def no_callbacks(self): 131 | with self._observers.deactivate(): 132 | yield 133 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | import inspect 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
14 | # 15 | import os 16 | import shutil 17 | import subprocess 18 | import sys 19 | 20 | try: 21 | from mpl_image_labeller import __version__ as release 22 | except ImportError: 23 | release = "unknown" 24 | 25 | 26 | # -- Project information ----------------------------------------------------- 27 | 28 | project = "mpl-image-labeller" 29 | copyright = "2021, Ian Hunt-Isaak" 30 | author = "Ian Hunt-Isaak" 31 | 32 | 33 | # -- Generate API ------------------------------------------------------------ 34 | api_folder_name = "api" 35 | shutil.rmtree(api_folder_name, ignore_errors=True) # in case of new or renamed modules 36 | subprocess.call( 37 | " ".join( 38 | [ 39 | "sphinx-apidoc", 40 | f"-o {api_folder_name}/", 41 | "--force", 42 | "--no-toc", 43 | "--templatedir _templates", 44 | "--separate", 45 | "../mpl_image_labeller/", 46 | # excluded modules 47 | # nothing here for cookiecutter 48 | ] 49 | ), 50 | shell=True, 51 | ) 52 | 53 | 54 | # -- General configuration --------------------------------------------------- 55 | 56 | # Add any Sphinx extension module names here, as strings. They can be 57 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 58 | # ones. 
59 | extensions = [ 60 | "jupyter_sphinx", 61 | "myst_nb", 62 | "numpydoc", 63 | "sphinx.ext.autodoc", 64 | "sphinx.ext.intersphinx", 65 | "sphinx.ext.linkcode", 66 | "sphinx.ext.mathjax", 67 | "sphinx.ext.napoleon", 68 | "sphinx_copybutton", 69 | "sphinx_panels", 70 | "sphinx_thebe", 71 | "sphinx_togglebutton", 72 | ] 73 | # NOTE(review): numpydoc and sphinx.ext.napoleon are both enabled and both handle docstrings; confirm both are needed 74 | 75 | # API settings 76 | autodoc_default_options = { 77 | "members": True, 78 | "show-inheritance": True, 79 | "undoc-members": True, 80 | } 81 | add_module_names = False 82 | napoleon_google_docstring = False 83 | napoleon_include_private_with_doc = False 84 | napoleon_include_special_with_doc = False 85 | napoleon_numpy_docstring = True 86 | napoleon_use_admonition_for_examples = False 87 | napoleon_use_admonition_for_notes = False 88 | napoleon_use_admonition_for_references = False 89 | napoleon_use_ivar = False 90 | napoleon_use_param = False 91 | napoleon_use_rtype = False 92 | numpydoc_show_class_members = False 93 | 94 | # Cross-referencing configuration 95 | default_role = "py:obj" 96 | primary_domain = "py" 97 | nitpicky = True # warn if cross-references are missing 98 | 99 | # Intersphinx settings 100 | intersphinx_mapping = { 101 | "ipywidgets": ("https://ipywidgets.readthedocs.io/en/stable", None), 102 | "matplotlib": ("https://matplotlib.org/stable", None), 103 | "numpy": ("https://numpy.org/doc/stable", None), 104 | "python": ("https://docs.python.org/3", None), 105 | } 106 | 107 | # remove panels css to get wider main content 108 | panels_add_bootstrap_css = False 109 | 110 | # Settings for copybutton 111 | copybutton_prompt_is_regexp = True
112 | copybutton_prompt_text = r">>> |\.\.\. " # doctest 113 | 114 | # Settings for linkcheck 115 | linkcheck_anchors = False 116 | linkcheck_ignore = [] # type: ignore 117 | 118 | execution_timeout = -1 # presumably disables the myst-nb cell timeout; confirm for the installed version 119 | jupyter_execute_notebooks = "off" # switched to "force" below when EXECUTE_NB is set 120 | if "EXECUTE_NB" in os.environ: 121 | print("\033[93;1mWill run Jupyter notebooks!\033[0m") 122 | jupyter_execute_notebooks = "force" 123 | 124 | # Settings for myst-parser 125 | myst_enable_extensions = [ 126 | "amsmath", 127 | "colon_fence", 128 | "dollarmath", 129 | "smartquotes", 130 | "substitution", 131 | ] 132 | suppress_warnings = [ 133 | "myst.header", 134 | ] 135 | 136 | # Add any paths that contain templates here, relative to this directory. 137 | templates_path = ["_templates"] 138 | 139 | # List of patterns, relative to source directory, that match files and 140 | # directories to ignore when looking for source files. 141 | # This pattern also affects html_static_path and html_extra_path. 142 | exclude_patterns = [ 143 | "**ipynb_checkpoints", 144 | ".DS_Store", 145 | "Thumbs.db", 146 | "_build", 147 | ] 148 | 149 | 150 | # -- Options for HTML output ------------------------------------------------- 151 | 152 | # The theme to use for HTML and HTML Help pages. See the documentation for 153 | # a list of builtin themes.
154 | # 155 | html_copy_source = True # needed for download notebook button 156 | html_css_files = [ 157 | "custom.css", 158 | ] 159 | html_sourcelink_suffix = "" 160 | html_static_path = ["_static"] 161 | html_theme = "sphinx_book_theme" 162 | html_theme_options = { 163 | "launch_buttons": { 164 | "binderhub_url": "https://mybinder.org", 165 | "colab_url": "https://colab.research.google.com", 166 | "notebook_interface": "jupyterlab", 167 | "thebe": True, 168 | "thebelab": True, 169 | }, 170 | "path_to_docs": "docs", 171 | "repository_branch": "main", 172 | "repository_url": "https://github.com/ianhi/mpl-image-labeller", 173 | "use_download_button": True, 174 | "use_edit_page_button": True, 175 | "use_issues_button": True, 176 | "use_repository_button": True, 177 | } 178 | html_title = "mpl-image-labeller" 179 | 180 | master_doc = "index" 181 | thebe_config = { 182 | "repository_url": html_theme_options["repository_url"], 183 | "repository_branch": html_theme_options["repository_branch"], 184 | } 185 | 186 | 187 | # based on pandas/doc/source/conf.py 188 | def linkcode_resolve(domain, info): 189 | """ 190 | Determine the URL corresponding to Python object 191 | """ 192 | if domain != "py": 193 | return None 194 | 195 | modname = info["module"] 196 | fullname = info["fullname"] 197 | 198 | submod = sys.modules.get(modname) 199 | if submod is None: 200 | return None 201 | 202 | obj = submod 203 | for part in fullname.split("."): 204 | try: 205 | obj = getattr(obj, part) 206 | except AttributeError: 207 | return None 208 | 209 | try: 210 | fn = inspect.getsourcefile(inspect.unwrap(obj)) 211 | except TypeError: 212 | fn = None 213 | if not fn: 214 | return None 215 | 216 | try: 217 | source, lineno = inspect.getsourcelines(obj) 218 | except OSError: 219 | lineno = None 220 | 221 | if lineno: 222 | linespec = f"#L{lineno}-L{lineno + len(source) - 1}" 223 | else: 224 | linespec = "" 225 | 226 | fn = os.path.relpath(fn, start=os.path.dirname("../mpl_image_labeller")) 227 
from typing import List, Union

import numpy as np
from matplotlib.backend_bases import key_press_handler
from matplotlib.cbook import CallbackRegistry
from matplotlib.figure import Figure

from ._util import ConflictingArgumentsError, list_to_onehot, onehot_to_list
from ._widgets import button_array


def gen_key_press_handler(skip_keys):
    """
    Create a figure key handler that ignores *skip_keys*.

    Every other key is forwarded to matplotlib's default
    `~matplotlib.backend_bases.key_press_handler`, so the labelling keys do
    not also trigger the default keymap (e.g. *s* -> savefig).

    Parameters
    ----------
    skip_keys : iterable of str
        Keys to swallow instead of forwarding.

    Returns
    -------
    callable
        ``handler(event, canvas=None, toolbar=None)`` suitable for
        ``canvas.mpl_connect("key_press_event", ...)``.
    """
    # materialise once so a one-shot iterable still works on every key press
    skip_keys = frozenset(skip_keys)

    def handler(event, canvas=None, toolbar=None):
        if event.key in skip_keys:
            return
        key_press_handler(event, canvas, toolbar)

    return handler


class image_labeller:
    """
    Interactive matplotlib UI for assigning class labels to a set of images.

    The left/right arrow keys move between images and the keys of
    *label_keymap* assign classes to the displayed image.
    """

    def __init__(
        self,
        images,
        classes,
        multiclass=False,
        label_keymap: Union[List[str], str] = "1234",
        title=None,
        init_labels=None,
        init_labels_onehot=None,
        labelling_advances_image: bool = True,
        N_images=None,
        fig: Figure = None,
        **imshow_kwargs,
    ):
        """
        Parameters
        ----------
        images : (N, Y, X) ArrayLike or Callable
            The images to label. 3D images (e.g. RGB(A) with a trailing
            channel axis) are displayed without colormap autoscaling. If a
            Callable it is called with an image index and must return that
            image; *N_images* is then required.
        classes : (C,) ArrayLike
            The available classes for the images.
        multiclass : bool, default: False
            Whether to allow for an image to have multiple classes or just one.
        label_keymap : list of str, or str
            If a str must be one of the predefined values *1234* (1, 2, ..., 9,
            0 - at most 10 classes) or *qwerty* (q, w, e, r, t, y, u, i, o, p -
            at most 10 classes). Otherwise an iterable of distinct keys that
            are assigned in order to the classes. WARNING: These keys will be
            removed from the default keymap for that figure. So if *s* is
            included then *s* will no longer perform savefig.
        title : str or Callable, optional
            A {} style format string for the title of images or a function that
            returns a title string given the image index. If a format string
            then it will be used with ``title.format(image_index=...)``. If *None*
            then the default str of 'Image Index: {image_index}' will be used.
        init_labels : list of list or list of str, optional
            The initial labels for the images. If given it must be the same length as
            *images* and each entry should be either a single class or an iterable of
            classes. Incompatible with *init_labels_onehot*.
        init_labels_onehot : 2D ArrayLike, optional
            The initial labels for the images as a onehot encoding. If given it must
            have shape (N_images, N_classes) and be castable to a boolean array.
            Incompatible with *init_labels*.
        labelling_advances_image : bool, default: True
            Whether labelling an image should advance to the next image.
            Ignored if *multiclass* is True.
        N_images : int or None
            The number of images. Required if passing a Callable for images, otherwise
            ignored.
        fig : Figure, optional
            An empty figure to build the UI in. Use this to embed image_labeller into
            a gui framework. If *None* a new pyplot figure is created.
        **imshow_kwargs :
            kwargs to be passed to the imshow function that displays the images.

        Raises
        ------
        TypeError
            If *title* is neither a str nor a Callable, or *images* is a
            Callable and *N_images* is not an integer.
        ValueError
            If there are no images, *label_keymap* does not fit *classes*, or
            the initial labels have the wrong length/shape.
        ConflictingArgumentsError
            If both *init_labels* and *init_labels_onehot* are given.
        """
        # the label setters only refresh the widgets once the UI exists
        self._ui_built = False
        self._multi = multiclass
        self._images = images
        self._title = self._make_title_func(title)

        if callable(images):
            if not isinstance(N_images, (int, np.integer)):
                raise TypeError(
                    "If images is a callable then N_images must be provided"
                )
            self._N_images = int(N_images)

            def _get_image(i):
                return self._images(i)

        else:
            self._N_images = len(images)

            def _get_image(i):
                return self._images[i]

        self._get_image = _get_image
        if self._N_images < 1:
            # the first image is displayed immediately below
            raise ValueError("images must contain at least one image")

        self._label_advances = labelling_advances_image
        self._label_keymap = self._make_keymap(label_keymap, len(classes))

        # make array for easy indexing
        self._classes = np.asarray(classes)
        # placeholder; the label setters below install the real state
        self._onehot = np.zeros((self._N_images, len(classes)), dtype=bool)

        if init_labels is not None and init_labels_onehot is not None:
            raise ConflictingArgumentsError(
                "init_labels and init_labels_onehot cannot both be provided"
            )

        if init_labels is not None:
            # length errors are handled in the setter
            self.labels = init_labels
        elif init_labels_onehot is not None:
            self.labels_onehot = init_labels_onehot
        elif self._multi:
            self.labels = [[] for _ in range(self._N_images)]
        else:
            self.labels = [None] * self._N_images

        if fig is None:
            import matplotlib.pyplot as plt

            fig = plt.figure(constrained_layout=True)
        self._fig = fig
        self._replace_default_key_handler()

        self._image_index = 0
        if self._multi:
            self._image_ax, self._info_ax = self._fig.subplots(1, 2)
        else:
            # image gets 4/5 of the width, the keybinding help the rest
            gs = self._fig.add_gridspec(1, 5)
            self._image_ax = self._fig.add_subplot(gs[:, :-1])
            self._info_ax = self._fig.add_subplot(gs[:, -1])
        self._info_ax.axis("off")

        aspect = imshow_kwargs.pop("aspect", "equal")
        # user-fixed colour limits; only the unset ones are autoscaled per image
        self._vmin = imshow_kwargs.get("vmin", None)
        self._vmax = imshow_kwargs.get("vmax", None)
        self._im = self._image_ax.imshow(
            self._get_image(0), aspect=aspect, **imshow_kwargs
        )

        if self._multi:
            self._build_multiclass_panel(classes)
        else:
            self._build_singleclass_panel()

        self._update_title()

        self._fig.canvas.mpl_connect("key_press_event", self._key_press)
        self._observers = CallbackRegistry()
        self._ui_built = True

    @staticmethod
    def _make_title_func(title):
        """Normalize *title* (None, format str or callable) to ``f(image_index)``."""
        if title is None:
            title = "Image Index: {image_index}"
        if isinstance(title, str):

            def _title(image_index):
                return title.format(image_index=image_index)

            return _title
        if callable(title):
            return title
        raise TypeError("Title must be a str or a Callable")

    @staticmethod
    def _make_keymap(label_keymap, n_classes):
        """Return a ``{key: class_index}`` dict built from *label_keymap*."""
        if isinstance(label_keymap, str) and label_keymap == "1234":
            if n_classes > 10:
                raise ValueError(
                    "More classes than numbers on the keyboard, "
                    "please provide a custom keymap"
                )
            # keys 1, 2, ..., 9, 0
            return {f"{(i + 1) % 10}": i for i in range(n_classes)}
        if isinstance(label_keymap, str) and label_keymap == "qwerty":
            row = "qwertyuiop"
            if n_classes > len(row):
                raise ValueError(
                    "More classes than length of qwertyuiop, "
                    "please provide a custom keymap"
                )
            return {row[i]: i for i in range(n_classes)}
        if len(label_keymap) != n_classes:
            raise ValueError("label_keymap must have the same length as classes")
        keymap = {key: i for i, key in enumerate(label_keymap)}
        if len(keymap) != n_classes:
            # a repeated key would silently make a class unreachable
            raise ValueError("label_keymap must not contain duplicate keys")
        return keymap

    def _replace_default_key_handler(self):
        """
        Stop the figure's default keymap from reacting to the labelling keys.

        "Removes" keys from the default keymap by swapping the manager's key
        handler, see https://gitter.im/matplotlib/matplotlib?at=617988daee6c260cf743e9cb
        """
        canvas = self._fig.canvas
        manager = getattr(canvas, "manager", None)
        # NOTE(review): a canvas embedded directly in a GUI toolkit may have no
        # manager, and with the "toolmanager" toolbar no default handler id is
        # stored; in both cases there is no default handler to replace.
        if manager is None or getattr(manager, "key_press_handler_id", None) is None:
            return
        canvas.mpl_disconnect(manager.key_press_handler_id)
        manager.key_press_handler_id = canvas.mpl_connect(
            "key_press_event", gen_key_press_handler(list(self._label_keymap))
        )

    def _build_multiclass_panel(self, classes):
        """Create one toggle button per class in the info axes."""

        def on_state_change(new_state, old_state):
            self._onehot[self._image_index] = new_state

        texts = [
            f"[{key}]\n{str(klass)}"
            for key, klass in zip(self._label_keymap.keys(), classes)
        ]
        self._buttons = button_array(texts, self._info_ax)
        self._buttons.on_state_change(on_state_change)

    def _build_singleclass_panel(self):
        """Draw the keybinding help and the current-class readout."""
        # these are matplotlib.patch.Patch properties
        props = dict(boxstyle="round", facecolor="wheat", alpha=0.5)
        horiz_pos = 0.575
        text_kwargs = dict(fontsize=14, verticalalignment="top", bbox=props)

        nav_text = "Keybindings\n<- : Previous Image\n-> : Next Image"
        self._info_ax.text(
            horiz_pos, 0.75, nav_text, horizontalalignment="left", **text_kwargs
        )

        class_text = "Class Keybindings:\n" + "".join(
            f"{key} : {self._classes[idx]}\n"
            for key, idx in self._label_keymap.items()
        )
        self._info_ax.text(horiz_pos, 0.55, class_text, **text_kwargs)

        self._class_display = self._info_ax.text(
            horiz_pos, 0.25, self._current_class_text(), **text_kwargs
        )

    def _current_class_text(self):
        """Text of the single-class "Current Class" readout."""
        return f"Current Class:\n{str(self._labels[self._image_index])}"

    def _sync_label_widgets(self):
        """Make the label widgets show the labels of the current image."""
        if self._multi:
            # presumably suppresses on_state_change while mirroring state
            # into the buttons (TODO from original: verify no_callbacks works)
            with self._buttons.no_callbacks():
                self._buttons.set_states(self._onehot[self._image_index])
        else:
            self._class_display.set_text(self._current_class_text())

    def _refresh_labels(self):
        """Redraw the label widgets after the labels were replaced wholesale."""
        if self._ui_built:
            self._sync_label_widgets()
            self._fig.canvas.draw_idle()

    @property
    def ax(self):
        """
        **readonly** - The `~matplotlib.axes.Axes` object the image's are displayed on.
        """
        return self._image_ax

    @property
    def labels(self):
        """
        The current labels.

        A list of lists of classes if *multiclass*, otherwise a list with one
        class (or None) per image. Setting requires one entry per image.
        """
        if self._multi:
            return onehot_to_list(self._onehot, self._classes)
        return self._labels

    @labels.setter
    def labels(self, value):
        if len(value) != self._N_images:
            raise ValueError(
                "Length of labels must be the same as the number of images"
            )
        if self._multi:
            self._onehot = list_to_onehot(value, self._classes)
        else:
            # copy so key presses never mutate the caller's sequence
            self._labels = list(value)
        self._refresh_labels()

    @property
    def labels_onehot(self):
        """
        The current labels as a one hot encoding of shape (N_images, N_classes).
        """
        if self._multi:
            return self._onehot
        return list_to_onehot(self._labels, self._classes)

    @labels_onehot.setter
    def labels_onehot(self, value):
        value = np.asarray(value)
        expected_shape = (self._N_images, len(self._classes))
        if value.shape != expected_shape:
            raise ValueError(
                "One hot labels must have shape (N images, num classes). "
                f"Expected shape {expected_shape} but got {value.shape}"
            )
        # boolean copy: key presses toggle entries of this array in place
        onehot = value.astype(bool)
        if self._multi:
            self._onehot = onehot
        else:
            if np.any(onehot.sum(axis=1) > 1):
                raise ValueError(
                    "Each image can have at most one class when multiclass is False"
                )
            self._labels = [
                self._classes[row.argmax()] if row.any() else None for row in onehot
            ]
        self._refresh_labels()

    @property
    def image_index(self):
        """
        **int** the index of the currently displayed image. Values outside
        ``[0, N_images - 1]`` are clamped when set.
        """
        return self._image_index

    @image_index.setter
    def image_index(self, value):
        new_index = min(max(value, 0), self._N_images - 1)
        if new_index == self._image_index:
            # quick return to avoid unnecessary draw
            return
        self._image_index = new_index
        self._update_displayed()

    def _update_title(self):
        self._image_ax.set_title(self._title(self._image_index))

    def _update_displayed(self):
        """Show the image at ``image_index`` and refresh title and labels."""
        image = np.asarray(self._get_image(self._image_index))
        # for some reason this keeps getting turned off by something
        self._image_ax.set_autoscale_on(True)
        self._im.set_data(image)

        # autoscale the colormap limits the user did not fix (not for 3D images)
        if image.ndim != 3:
            if self._vmin is None:
                self._im.norm.vmin = image.min()
            if self._vmax is None:
                self._im.norm.vmax = image.max()

        self._im.set_extent((-0.5, image.shape[1] - 0.5, image.shape[0] - 0.5, -0.5))
        self._update_title()
        self._observers.process("image-changed", self._image_index, image)
        self._sync_label_widgets()
        self._fig.canvas.draw_idle()

    def _key_press(self, event):
        if event.key == "left":
            self.image_index -= 1
        elif event.key == "right":
            self.image_index += 1
        elif event.key in self._label_keymap:
            self._assign_label(self._label_keymap[event.key])

    def _assign_label(self, which_label):
        """
        Apply the class at index *which_label* to the current image.

        Multiclass mode toggles the class; single-class mode replaces the
        label and, if enabled, advances to the next image.
        """
        klass = self._classes[which_label]
        if self._multi:
            img_labels = self._onehot[self._image_index]
            img_labels[which_label] = not img_labels[which_label]
            self._buttons.set_states(img_labels)
        else:
            self._labels[self._image_index] = klass
        self._observers.process("label-assigned", self._image_index, klass)

        at_last_image = self._image_index == self._N_images - 1
        if self._label_advances and not self._multi and not at_last_image:
            # _update_displayed refreshes title, class readout and canvas
            self.image_index += 1
        else:
            self._update_title()
            if not self._multi:
                self._class_display.set_text(self._current_class_text())
            # TODO: blit just the text here
            self._fig.canvas.draw_idle()

    def on_label_assigned(self, func):
        """
        Connect *func* as a callback function for when a label is assigned
        to an image. *func* will receive the index of the image and the
        new class.

        Parameters
        ----------
        func : callable
            Function to call when a label is assigned.

        Returns
        -------
        int
            Connection id (which can be passed to `disconnect`).
        """
        # NOTE(review): the lambda presumably makes the registry hold a strong
        # reference (CallbackRegistry only weakly references bound methods).
        return self._observers.connect("label-assigned", lambda *args: func(*args))

    def on_image_changed(self, func):
        """
        Connect *func* as a callback function for when the displayed image
        is changed. *func* will receive the index of the new image and the
        image. `fig.canvas.draw_idle` will be called after the callback is
        executed so if you are modifying the figure then you do not need to
        explicitly call *draw* yourself.

        Parameters
        ----------
        func : callable
            Function to call when the displayed image changes.

        Returns
        -------
        int
            Connection id (which can be passed to `disconnect`).
        """
        return self._observers.connect("image-changed", lambda *args: func(*args))

    def disconnect(self, cid):
        """
        Remove a callback connected with `on_label_assigned` or `on_image_changed`.

        Parameters
        ----------
        cid : int
            The connection id returned when the callback was connected.
        """
        self._observers.disconnect(cid)