├── .github └── workflows │ └── test_and_deploy.yml ├── .gitignore ├── .napari └── DESCRIPTION.md ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs └── index.md ├── mkdocs.yml ├── pyproject.toml ├── setup.cfg ├── src └── napari_deeplabcut │ ├── __init__.py │ ├── _reader.py │ ├── _tests │ ├── __init__.py │ ├── conftest.py │ ├── test_keypoints.py │ ├── test_misc.py │ ├── test_reader.py │ └── test_widgets.py │ ├── _widgets.py │ ├── _writer.py │ ├── assets │ ├── black │ │ ├── Back.png │ │ ├── Customize.png │ │ ├── Forward.png │ │ ├── Home.png │ │ ├── Pan.png │ │ ├── Pan_checked.png │ │ ├── Save.png │ │ ├── Subplots.png │ │ ├── Zoom.png │ │ └── Zoom_checked.png │ ├── napari_shortcuts.svg │ ├── superanimal_quadruped.jpg │ ├── superanimal_quadruped.json │ ├── superanimal_topviewmouse.jpg │ ├── superanimal_topviewmouse.json │ └── white │ │ ├── Back.png │ │ ├── Customize.png │ │ ├── Forward.png │ │ ├── Home.png │ │ ├── Pan.png │ │ ├── Pan_checked.png │ │ ├── Save.png │ │ ├── Subplots.png │ │ ├── Zoom.png │ │ └── Zoom_checked.png │ ├── keypoints.py │ ├── misc.py │ ├── napari.yaml │ └── styles │ ├── dark.mplstyle │ └── light.mplstyle └── tox.ini /.github/workflows/test_and_deploy.yml: -------------------------------------------------------------------------------- 1 | # This workflow runs the tests and uploads the Python package using Twine when a version tag is pushed 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: tests 5 | 6 | on: 7 | push: 8 | branches: 9 | - main 10 | - npe2 11 | tags: 12 | - "v*" # Push events to matching v*, e.g. 
v1.0, v20.15.10 13 | pull_request: 14 | branches: 15 | - main 16 | - npe2 17 | workflow_dispatch: 18 | inputs: 19 | force_deploy: 20 | description: 'Force deployment even if tests fail' 21 | required: true 22 | type: boolean 23 | 24 | jobs: 25 | test: 26 | name: ${{ matrix.platform }} py${{ matrix.python-version }} 27 | runs-on: ${{ matrix.platform }} 28 | strategy: 29 | matrix: 30 | platform: [ubuntu-latest, windows-latest, macos-latest] 31 | python-version: [3.9, "3.10"] 32 | 33 | steps: 34 | - uses: actions/checkout@v4 35 | 36 | - name: Set up Python ${{ matrix.python-version }} 37 | uses: actions/setup-python@v5 38 | with: 39 | python-version: ${{ matrix.python-version }} 40 | 41 | # these libraries enable testing Qt on Linux 42 | - uses: tlambert03/setup-qt-libs@v1 43 | 44 | # strategy borrowed from vispy for installing OpenGL libs on Windows 45 | - name: Install Windows OpenGL 46 | if: runner.os == 'Windows' 47 | run: | 48 | git clone --depth 1 https://github.com/pyvista/gl-ci-helpers.git 49 | powershell gl-ci-helpers/appveyor/install_opengl.ps1 50 | if (Test-Path -Path "C:\Windows\system32\opengl32.dll" -PathType Leaf) {Exit 0} else {Exit 1} 51 | 52 | # note: if you need dependencies from conda, consider using 53 | # setup-miniconda: https://github.com/conda-incubator/setup-miniconda 54 | # and 55 | # tox-conda: https://github.com/tox-dev/tox-conda 56 | - name: Install dependencies 57 | run: | 58 | python -m pip install --upgrade pip 59 | python -m pip install setuptools tox tox-gh-actions 60 | 61 | # this runs the platform-specific tests declared in tox.ini 62 | - name: Test with tox 63 | uses: aganders3/headless-gui@v2 64 | with: 65 | run: python -m tox 66 | env: 67 | PLATFORM: ${{ matrix.platform }} 68 | 69 | - name: Coverage 70 | uses: codecov/codecov-action@v4 71 | 72 | deploy: 73 | # this will run when you have tagged a commit, starting with "v*" 74 | # and requires that you have put your Twine API key in your 75 | # GitHub secrets (see readme 
for details) 76 | needs: [test] 77 | runs-on: ubuntu-latest 78 | if: | 79 | always() && 80 | inputs.force_deploy || 81 | (contains(github.ref, 'tags') && contains(needs.test.result, 'success')) 82 | steps: 83 | - uses: actions/checkout@v4 84 | - name: Set up Python 85 | uses: actions/setup-python@v4 86 | with: 87 | python-version: "3.x" 88 | - name: Install dependencies 89 | run: | 90 | python -m pip install --upgrade pip 91 | pip install -U setuptools setuptools_scm wheel twine build 92 | - name: Build and publish 93 | env: 94 | TWINE_USERNAME: __token__ 95 | TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} 96 | run: | 97 | git tag 98 | python -m build . 99 | twine upload dist/* 100 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | .napari_cache 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask instance folder 58 | instance/ 59 | 60 | # Sphinx documentation 61 | docs/_build/ 62 | 63 | # MkDocs documentation 64 | /site/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Pycharm and VSCode 70 | .idea/ 71 | venv/ 72 | .vscode/ 73 | 74 | # IPython Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # OS 81 | .DS_Store 82 | 83 | # written by setuptools_scm 84 | **/_version.py 85 | 86 | -------------------------------------------------------------------------------- /.napari/DESCRIPTION.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 91 | 92 | The developer has not yet provided a napari-hub specific description. 
93 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.1.0 4 | hooks: 5 | - id: check-docstring-first 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/asottile/setup-cfg-fmt 9 | rev: v1.20.0 10 | hooks: 11 | - id: setup-cfg-fmt 12 | - repo: https://github.com/PyCQA/flake8 13 | rev: 4.0.1 14 | hooks: 15 | - id: flake8 16 | additional_dependencies: [flake8-typing-imports==1.7.0] 17 | - repo: https://github.com/myint/autoflake 18 | rev: v1.4 19 | hooks: 20 | - id: autoflake 21 | args: ["--in-place", "--remove-all-unused-imports"] 22 | - repo: https://github.com/PyCQA/isort 23 | rev: 5.10.1 24 | hooks: 25 | - id: isort 26 | - repo: https://github.com/psf/black 27 | rev: 22.1.0 28 | hooks: 29 | - id: black 30 | - repo: https://github.com/asottile/pyupgrade 31 | rev: v2.31.0 32 | hooks: 33 | - id: pyupgrade 34 | args: [--py37-plus, --keep-runtime-typing] 35 | - repo: https://github.com/tlambert03/napari-plugin-checks 36 | rev: v0.2.0 37 | hooks: 38 | - id: napari-plugin-checks 39 | # https://mypy.readthedocs.io/en/stable/introduction.html 40 | # you may wish to add this as well! 
41 | # - repo: https://github.com/pre-commit/mirrors-mypy 42 | # rev: v0.910-1 43 | # hooks: 44 | # - id: mypy 45 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of 
unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at mackenzie.mathis@epfl.ch. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 
31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. 
You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 
108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 
145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 
166 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | include src/napari_deeplabcut/styles/*.mplstyle 4 | include src/napari_deeplabcut/assets/napari_shortcuts.svg 5 | include src/napari_deeplabcut/assets/superanimal_topviewmouse.jpg 6 | include src/napari_deeplabcut/assets/superanimal_topviewmouse.json 7 | include src/napari_deeplabcut/assets/superanimal_quadruped.jpg 8 | include src/napari_deeplabcut/assets/superanimal_quadruped.json 9 | 10 | recursive-exclude * __pycache__ 11 | recursive-exclude * *.py[co] 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # napari-deeplabcut: keypoint annotation for pose estimation 2 | 3 | 4 | 5 | napari+deeplabcut 6 | 7 | [📚Documentation](https://deeplabcut.github.io/DeepLabCut/README.html) | 8 | [🛠️ DeepLabCut Installation](https://deeplabcut.github.io/DeepLabCut/docs/installation.html) | 9 | [🌎 Home Page](https://www.deeplabcut.org) | 10 | 11 | [![License: BSD-3](https://img.shields.io/badge/License-BSD3-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) 12 | [![PyPI](https://img.shields.io/pypi/v/napari-deeplabcut.svg?color=green)](https://pypi.org/project/napari-deeplabcut) 13 | [![Python Version](https://img.shields.io/pypi/pyversions/napari-deeplabcut.svg?color=green)](https://python.org) 14 | [![tests](https://github.com/DeepLabCut/napari-deeplabcut/workflows/tests/badge.svg)](https://github.com/DeepLabCut/napari-deeplabcut/actions) 15 | [![codecov](https://codecov.io/gh/DeepLabCut/napari-deeplabcut/branch/main/graph/badge.svg)](https://codecov.io/gh/DeepLabCut/napari-deeplabcut) 16 | [![napari 
hub](https://img.shields.io/endpoint?url=https://api.napari-hub.org/shields/napari-deeplabcut)](https://napari-hub.org/plugins/napari-deeplabcut) 17 | 18 | A napari plugin for keypoint annotation, also used within DeepLabCut! 19 | 20 | 21 | ## Installation 22 | 23 | If you installed DeepLabCut[gui], this plugin is already installed. However, you can also use this as a stand-alone keypoint annotator without using DeepLabCut. Instructions below! 24 | 25 | Start by installing PySide6 with `pip install "pyside6==6.4.2"`; this is the library we now use to build GUIs. 26 | 27 | You can then install `napari-deeplabcut` via [pip]: 28 | 29 | pip install napari-deeplabcut 30 | 31 | 32 | 33 | Alternatively, to install the latest development version, run: 34 | 35 | pip install git+https://github.com/DeepLabCut/napari-deeplabcut.git 36 | 37 | 38 | ## Usage 39 | 40 | To use the plugin, please run: 41 | 42 | napari 43 | 44 | Then, activate the plugin in Plugins > napari-deeplabcut: Keypoint controls. 45 | 46 | All accepted files (config.yaml, images, h5 data files) can be loaded 47 | either by dropping them directly onto the canvas or via the File menu. 48 | 49 | The easiest way to get started is to drop a folder (typically a folder from within a DeepLabCut project's `labeled-data` directory), and, if labeling from scratch, drop the corresponding `config.yaml` to automatically add a `Points layer` and populate the dropdown menus. 
50 | 51 | [🎥 DEMO 52 | ](https://youtu.be/hsA9IB5r73E) 53 | 54 | **Tools & shortcuts are:** 55 | 56 | - `2` and `3`, to easily switch between labeling and selection mode 57 | - `4`, to enable pan & zoom (which is achieved using the mouse wheel or finger scrolling on the trackpad) 58 | - `M`, to cycle through regular (sequential), quick, and cycle annotation modes (see the description [here](https://github.com/DeepLabCut/napari-deeplabcut/blob/5a5709dd38868341568d66eab548ae8abf37cd63/src/napari_deeplabcut/keypoints.py#L25-L34)) 59 | - `E`, to enable edge coloring (by default, if using this in refinement GUI mode, points with a confidence lower than 0.6 are marked 60 | in red) 61 | - `F`, to toggle between animal and body part color scheme. 62 | - `V`, to toggle visibility of the selected layer. 63 | - `backspace` to delete a point. 64 | - Check the box "display text" to show the label names on the canvas. 65 | - To move to another folder, be sure to save (`Ctrl+S`), then delete the layers, and re-drag/drop the next folder. 66 | - One can jump to a specific image by double-clicking and editing the current frame number (located to the right of the slider). 67 | - Selected points can be copied with `Ctrl+C`, and pasted onto other images with `Ctrl+V`. 68 | 69 | 70 | ### Save Layers 71 | 72 | Annotations and segmentations are saved with `File > Save Selected Layer(s)...` (or its shortcut `Ctrl+S`). 73 | Only when saving segmentation masks does a save file dialog pop up to name the destination folder; 74 | keypoint annotations are otherwise automatically saved in the corresponding folder as `CollectedData_<ScorerName>.h5`. 75 | - As a reminder, DLC only uses the H5 file, so if you open already-labeled images, be sure to save (and thereby overwrite) the H5 file. 76 | - Note, before saving a layer, make sure the points layer is selected. If the user clicked on the image(s) layer first, does `Save As`, then closes the window, any labeling work during that session will be lost! 
77 | - Modifying and then saving points in a `machinelabels...` layer will add to or overwrite the existing `CollectedData` layer and will **not** save to the `machinelabels` file. 78 | 79 | 80 | ### Video frame extraction and prediction refinement 81 | 82 | Since v0.0.4, videos can be viewed in the GUI. 83 | 84 | Since v0.0.5, trailing points can be visualized, e.g., to help identify 85 | swaps or jittery, outlier predictions. 86 | 87 | Loading a video (and its corresponding output h5 file) will enable the video actions 88 | at the top of the dock widget: they offer the option to manually extract video 89 | frames from the GUI, or to define cropping coordinates. 90 | Note that keypoints can be moved and saved, as when annotating individual frames. 91 | 92 | 93 | ## Workflow 94 | 95 | Suggested workflows, depending on the image folder contents: 96 | 97 | 1. **Labeling from scratch** – the image folder does not contain a `CollectedData_<ScorerName>.h5` file. 98 | 99 | Open *napari* as described in [Usage](#usage) and open an image folder together with the DeepLabCut project's `config.yaml`. 100 | The image folder creates an *image layer* with the images to label. 101 | Supported image formats are: `jpg`, `jpeg`, `png`. 102 | The `config.yaml` file creates a *Points layer*, which holds metadata (such as keypoints read from the config file) necessary for labeling. 103 | Select the *Points layer* in the layer list (lower left pane on the GUI) and click on the *+*-symbol in the layer controls menu (upper left pane) to start labeling. 104 | The current keypoint can be viewed/selected in the keypoints dropdown menu (right pane). 105 | The slider below the displayed image (or the left/right arrow keys) allows selecting the image to label. 106 | 107 | To save the labeling progress, refer to [Save Layers](#save-layers). 108 | `Data successfully saved` should be shown in the status bar, and the image folder should now contain a `CollectedData_<ScorerName>.h5` file. 
109 | (Note: For convenience, a CSV file with the same name is also saved.) 110 | 111 | 2. **Resuming labeling** – the image folder contains a `CollectedData_<ScorerName>.h5` file. 112 | 113 | Open *napari* and open an image folder (which needs to contain a `CollectedData_<ScorerName>.h5` file). 114 | In this case, it is not necessary to open the DLC project's `config.yaml` file, as all necessary metadata is read from the `h5` data file. 115 | 116 | Saving works as described in *1*. 117 | 118 | ***Note that if a new body part has been added to the `config.yaml` file after having started to label, loading the config in the GUI is necessary to update the dropdown menus and other metadata.*** 119 | 120 | ***As `viridis` is the default colormap of `napari-deeplabcut`, loading the config in the GUI is also needed to update the color scheme.*** 121 | 122 | 3. **Refining labels** – the image folder contains a `machinelabels-iter<#>.h5` file. 123 | 124 | The process is analogous to *2*. 125 | Open *napari* and open an image folder. 126 | If the video was originally labeled *and* had outliers extracted, it will contain a `CollectedData_<ScorerName>.h5` file and a `machinelabels-iter<#>.h5` file. In this case, select the `machinelabels` layer in the GUI, and type `e` to show edges. Red indicates likelihood < 0.6. As you navigate through frames, images with labels with edges will need to be refined (moved, deleted, etc.). Images with labels without edges will be on the `CollectedData` (previous manual annotations) layer and shouldn't need refining. However, you can switch to that layer and fix errors. You can also right-click on the `CollectedData` layer and select `toggle visibility` to hide that layer. Select the `machinelabels` layer before saving; this will append your refined annotations to `CollectedData`. 127 | 128 | If the folder only had outliers extracted and wasn't originally labeled, it will not have a `CollectedData` layer. 
Work with the `machinelabels` layer selected to refine annotation positions, then save. 129 | 130 | In this case, it is not necessary to open the DLC project's `config.yaml` file, as all necessary metadata is read from the `h5` data file. 131 | 132 | Saving works as described in *1*. 133 | 134 | 4. **Drawing segmentation masks** 135 | 136 | Drop an image folder as in *1*, manually add a *shapes layer*. Then select the *rectangle* in the layer controls (top left pane), 137 | and start drawing rectangles over the images. Masks and rectangle vertices are saved as described in [Save Layers](#save-layers). 138 | Note that masks can be reloaded and edited at a later stage by dropping the `vertices.csv` file onto the canvas. 139 | 140 | ### Workflow flowchart 141 | 142 | ```mermaid 143 | %%{init: {"flowchart": {"htmlLabels": false}} }%% 144 | graph TD 145 | id1[What stage of labeling?] 146 | id2[deeplabcut.label_frames] 147 | id3[deeplabcut.refine_labels] 148 | id4[Add labels to, or modify in, \n `CollectedData...` layer and save that layer] 149 | id5[Modify labels in `machinelabels` layer and save \n which will create a `CollectedData...` file] 150 | id6[Have you refined some labels from the most recent iteration and saved already?] 151 | id7["All extracted frames are already saved in `CollectedData...`. 152 | 1. Hide or trash all `machinelabels` layers. 153 | 2. Then modify in and save `CollectedData`"] 154 | id8[" 155 | 1. hide or trash all `machinelabels` layers except for the most recent. 156 | 2. Select most recent `machinelabels` and hit `e` to show edges. 157 | 3. Modify only in `machinelabels` and skip frames with labels without edges shown. 158 | 4. Save `machinelabels` layer, which will add data to `CollectedData`. 
159 | - If you need to revisit this video later, ignore `machinelabels` and work only in `CollectedData`"] 160 | 161 | id1 -->|I need to manually label new frames \n or fix my labels|id2 162 | id1 ---->|I need to refine outlier frames \nfrom analyzed videos|id3 163 | id2 -->id4 164 | id3 -->|I only have a `machinelabels...` file|id5 165 | id3 ---->|I have both `machinelabels` and `CollectedData` files|id6 166 | id6 -->|yes|id7 167 | id6 ---->|no, I just extracted outliers|id8 168 | ``` 169 | 170 | ### Labeling multiple image folders 171 | 172 | Labeling multiple image folders has to be done in sequence; i.e., only one image folder can be opened at a time. 173 | Once the images of a particular folder have been labeled and the associated *Points layer* has been saved, *all* layers should be removed from the layers list (lower left pane on the GUI) by selecting them and clicking on the trashcan icon. 174 | Now, another image folder can be labeled, following the process described in *1*, *2*, or *3*, depending on the particular image folder. 175 | 176 | 177 | ### Defining cropping coordinates 178 | 179 | Prior to defining cropping coordinates, two elements should be loaded in the GUI: 180 | a video and the DLC project's `config.yaml` file (into which the crop dimensions will be stored). 181 | Then it suffices to add a `Shapes layer`, draw a `rectangle` in it with the desired area, 182 | and hit the button `Store crop coordinates`; coordinates are automatically written to the configuration file. 183 | 184 | 185 | ## Contributing 186 | 187 | Contributions are very welcome. Tests can be run with [tox]; please ensure 188 | the coverage at least stays the same before you submit a pull request. 189 | 190 | To locally install the code, please git clone the repo and then run `pip install -e .` 191 | 192 | ## License 193 | 194 | Distributed under the terms of the [BSD-3] license, 195 | "napari-deeplabcut" is free and open source software. 
196 | 197 | ## Issues 198 | 199 | If you encounter any problems, please [file an issue] along with a detailed description. 200 | 201 | [file an issue]: https://github.com/DeepLabCut/napari-deeplabcut/issues 202 | 203 | 204 | ## Acknowledgements 205 | 206 | 207 | This [napari] plugin was generated with [Cookiecutter] using [@napari]'s [cookiecutter-napari-plugin] template. We thank the Chan Zuckerberg Initiative (CZI) for funding the initial development of this work! 208 | 209 | 216 | 217 | 218 | [napari]: https://github.com/napari/napari 219 | [Cookiecutter]: https://github.com/audreyr/cookiecutter 220 | [@napari]: https://github.com/napari 221 | [cookiecutter-napari-plugin]: https://github.com/napari/cookiecutter-napari-plugin 222 | [BSD-3]: http://opensource.org/licenses/BSD-3-Clause 223 | [tox]: https://tox.readthedocs.io/en/latest/ 224 | [pip]: https://pypi.org/project/pip/ 225 | [PyPI]: https://pypi.org/ 226 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to napari-deeplabcut 2 | 3 | For documentation, please see: https://deeplabcut.github.io/DeepLabCut/docs/napari_GUI.html 4 | 5 | ## Maintenance 6 | 7 | ### Testing 8 | 9 | The package can be tested locally by running `python -m pytest`. Note that these tests are also triggered automatically by push events and by pull requests targeting the main branch. 10 | 11 | ### Deployment 12 | 13 | Versioning and deployment build upon napari's plugin cookiecutter setup 14 | https://github.com/napari/cookiecutter-napari-plugin/blob/main/README.md#automatic-deployment-and-version-management. 15 | In short, it suffices to create a tagged commit and push it to GitHub; once the `deploy` step of GitHub's CI workflow has completed, the package is automatically published to PyPI. 
16 | 17 | ```bash 18 | # the tag will be used as the version string for the package 19 | git tag -a v0.1.0 -m "v0.1.0" 20 | 21 | # make sure to use follow-tags so that the tag also gets pushed to github 22 | git push --follow-tags 23 | # alternatively, you can set follow-tags as default with git config --global push.followTags true 24 | ``` 25 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: napari-deeplabcut 2 | site_description: napari + DeepLabCut annotation tool 3 | site_author: Jessy Lauer for the DeeplabCut organization 4 | 5 | theme: readthedocs 6 | 7 | 8 | repo_url: https://github.com/DeepLabCut/napari-deeplabcut 9 | 10 | pages: 11 | - Home: index.md 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "setuptools_scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | 6 | [tool.setuptools_scm] 7 | write_to = "src/napari_deeplabcut/_version.py" 8 | 9 | [tool.pytest.ini_options] 10 | qt_api = "pyside6" 11 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = napari-deeplabcut 3 | description =napari + DeepLabCut annotation tool 4 | long_description = file: README.md 5 | long_description_content_type = text/markdown 6 | url = https://github.com/DeepLabCut/napari-deeplabcut 7 | author = Team DeepLabCut, Lead by Jessy Lauer 8 | author_email = admin@deeplabcut.org 9 | license = BSD-3-Clause 10 | license_file = LICENSE 11 | classifiers = 12 | Development Status :: 2 - Pre-Alpha 13 | Framework :: napari 14 | Intended Audience :: Developers 15 | License :: OSI Approved :: BSD License 16 | 
Operating System :: OS Independent 17 | Programming Language :: Python :: 3 18 | Programming Language :: Python :: 3 :: Only 19 | Programming Language :: Python :: 3.9 20 | Programming Language :: Python :: 3.10 21 | Topic :: Scientific/Engineering :: Artificial Intelligence 22 | Topic :: Scientific/Engineering :: Image Processing 23 | Topic :: Scientific/Engineering :: Visualization 24 | project_urls = 25 | Bug Tracker = https://github.com/DeepLabCut/napari-deeplabcut/issues 26 | Documentation = https://github.com/DeepLabCut/napari-deeplabcut#README.md 27 | Source Code = https://github.com/DeepLabCut/napari-deeplabcut 28 | User Support = https://github.com/DeepLabCut/napari-deeplabcut/issues 29 | 30 | [options] 31 | packages = find: 32 | install_requires = 33 | dask-image 34 | matplotlib>=3.3 35 | napari==0.4.18 36 | natsort 37 | numpy 38 | opencv-python-headless 39 | pandas 40 | pyyaml 41 | qtpy>=2.4 42 | scikit-image 43 | scipy 44 | tables 45 | python_requires = >=3.9 46 | include_package_data = True 47 | package_dir = 48 | =src 49 | setup_requires = 50 | setuptools-scm 51 | 52 | [options.packages.find] 53 | where = src 54 | 55 | [options.entry_points] 56 | napari.manifest = 57 | napari-deeplabcut = napari_deeplabcut:napari.yaml 58 | 59 | [options.extras_require] 60 | testing = 61 | pyside6==6.4.2 62 | pytest 63 | pytest-cov 64 | pytest-qt 65 | tox 66 | 67 | [options.package_data] 68 | napari_deeplabcut = 69 | napari.yaml 70 | -------------------------------------------------------------------------------- /src/napari_deeplabcut/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | 4 | # FIXME: Circumvent the need to access window.qt_viewer 5 | warnings.filterwarnings("ignore", category=FutureWarning) 6 | 7 | 8 | class VispyWarningFilter(logging.Filter): 9 | def filter(self, record): 10 | ignore_messages = [ 11 | "delivering touch release to same window QWindow(0x0) not 
QWidgetWindow", 12 | "skipping QEventPoint", 13 | ] 14 | return not any(msg in record.getMessage() for msg in ignore_messages) 15 | 16 | 17 | vispy_logger = logging.getLogger("vispy") 18 | vispy_logger.addFilter(VispyWarningFilter()) 19 | 20 | try: 21 | from ._version import version as __version__ 22 | except ImportError: # pragma: no cover 23 | __version__ = "unknown" 24 | 25 | 26 | from ._reader import ( 27 | get_hdf_reader, 28 | get_image_reader, 29 | get_video_reader, 30 | get_folder_parser, 31 | get_config_reader, 32 | ) 33 | from ._writer import write_hdf, write_masks 34 | -------------------------------------------------------------------------------- /src/napari_deeplabcut/_reader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | from pathlib import Path 5 | from typing import Dict, List, Optional, Sequence 6 | 7 | import cv2 8 | import dask.array as da 9 | import numpy as np 10 | import pandas as pd 11 | import yaml 12 | from dask import delayed 13 | from dask_image.imread import imread 14 | from napari.types import LayerData 15 | from natsort import natsorted 16 | 17 | from napari_deeplabcut import misc 18 | 19 | SUPPORTED_IMAGES = ".jpg", ".jpeg", ".png" 20 | SUPPORTED_VIDEOS = ".mp4", ".mov", ".avi" 21 | 22 | 23 | def is_video(filename: str): 24 | return any(filename.lower().endswith(ext) for ext in SUPPORTED_VIDEOS) 25 | 26 | 27 | def get_hdf_reader(path): 28 | if isinstance(path, list): 29 | path = path[0] 30 | 31 | if not path.endswith(".h5"): 32 | return None 33 | 34 | return read_hdf 35 | 36 | 37 | def get_image_reader(path): 38 | if isinstance(path, list): 39 | path = path[0] 40 | 41 | if not any(path.lower().endswith(ext) for ext in SUPPORTED_IMAGES): 42 | return None 43 | 44 | return read_images 45 | 46 | 47 | def get_video_reader(path): 48 | if isinstance(path, str) and any( 49 | path.lower().endswith(ext) for ext in SUPPORTED_VIDEOS 50 | ): 51 | return 
read_video 52 | return None 53 | 54 | 55 | def get_config_reader(path): 56 | if isinstance(path, list): 57 | path = path[0] 58 | 59 | if not path.endswith(".yaml"): 60 | return None 61 | 62 | return read_config 63 | 64 | 65 | def get_folder_parser(path): 66 | if not os.path.isdir(path): 67 | return None 68 | 69 | layers = [] 70 | files = os.listdir(path) 71 | images = "" 72 | for file in files: 73 | if any(file.lower().endswith(ext) for ext in SUPPORTED_IMAGES): 74 | images = os.path.join(path, f"*{os.path.splitext(file)[1]}") 75 | break 76 | if not images: 77 | raise OSError(f"No supported images were found in {path}.") 78 | 79 | layers.extend(read_images(images)) 80 | datafile = "" 81 | for file in os.listdir(path): 82 | if file.endswith(".h5"): 83 | datafile = os.path.join(path, "*.h5") 84 | break 85 | if datafile: 86 | layers.extend(read_hdf(datafile)) 87 | 88 | return lambda _: layers 89 | 90 | 91 | def read_images(path): 92 | if isinstance(path, list): 93 | root, ext = os.path.splitext(path[0]) 94 | path = os.path.join(os.path.dirname(root), f"*{ext}") 95 | # Retrieve filepaths exactly as parsed by pims 96 | filepaths = [] 97 | for filepath in glob.iglob(path): 98 | relpath = Path(filepath).parts[-3:] 99 | filepaths.append(os.path.join(*relpath)) 100 | params = { 101 | "name": "images", 102 | "metadata": { 103 | "paths": natsorted(filepaths), 104 | "root": os.path.split(path)[0], 105 | }, 106 | } 107 | 108 | # https://github.com/soft-matter/pims/issues/452 109 | if len(filepaths) == 1: 110 | path = glob.glob(path)[0] 111 | 112 | return [(imread(path), params, "image")] 113 | 114 | 115 | def _populate_metadata( 116 | header: misc.DLCHeader, 117 | *, 118 | labels: Optional[Sequence[str]] = None, 119 | ids: Optional[Sequence[str]] = None, 120 | likelihood: Optional[Sequence[float]] = None, 121 | paths: Optional[List[str]] = None, 122 | size: Optional[int] = 8, 123 | pcutoff: Optional[float] = 0.6, 124 | colormap: Optional[str] = "viridis", 125 | ) -> Dict: 126 | 
if labels is None: 127 | labels = header.bodyparts 128 | if ids is None: 129 | ids = header.individuals 130 | if likelihood is None: 131 | likelihood = np.ones(len(labels)) 132 | face_color_cycle_maps = misc.build_color_cycles(header, colormap) 133 | face_color_prop = "id" if ids[0] else "label" 134 | return { 135 | "name": "keypoints", 136 | "text": "{id}–{label}" if ids[0] else "label", 137 | "properties": { 138 | "label": list(labels), 139 | "id": list(ids), 140 | "likelihood": likelihood, 141 | "valid": likelihood > pcutoff, 142 | }, 143 | "face_color_cycle": face_color_cycle_maps[face_color_prop], 144 | "face_color": face_color_prop, 145 | "face_colormap": colormap, 146 | "edge_color": "valid", 147 | "edge_color_cycle": ["black", "red"], 148 | "edge_width": 0, 149 | "edge_width_is_relative": False, 150 | "size": size, 151 | "metadata": { 152 | "header": header, 153 | "face_color_cycles": face_color_cycle_maps, 154 | "colormap_name": colormap, 155 | "paths": paths or [], 156 | }, 157 | } 158 | 159 | 160 | def _load_superkeypoints_diagram(super_animal: str): 161 | path = str(Path(__file__).parent / "assets" / f"{super_animal}.jpg") 162 | return imread(path), {"root": ""}, "images" 163 | 164 | 165 | def _load_superkeypoints(super_animal: str): 166 | path = str(Path(__file__).parent / "assets" / f"{super_animal}.json") 167 | with open(path) as f: 168 | return json.load(f) 169 | 170 | 171 | def _load_config(config_path: str): 172 | with open(config_path) as file: 173 | return yaml.safe_load(file) 174 | 175 | 176 | def read_config(configname: str) -> List[LayerData]: 177 | config = _load_config(configname) 178 | header = misc.DLCHeader.from_config(config) 179 | metadata = _populate_metadata( 180 | header, 181 | size=config["dotsize"], 182 | pcutoff=config["pcutoff"], 183 | colormap=config["colormap"], 184 | likelihood=np.array([1]), 185 | ) 186 | metadata["name"] = f"CollectedData_{config['scorer']}" 187 | metadata["ndim"] = 3 188 | metadata["property_choices"] = 
metadata.pop("properties") 189 | metadata["metadata"]["project"] = os.path.dirname(configname) 190 | conversion_tables = config.get("SuperAnimalConversionTables") 191 | if conversion_tables is not None: 192 | super_animal, table = conversion_tables.popitem() 193 | metadata["metadata"]["tables"] = {super_animal: table} 194 | return [(None, metadata, "points")] 195 | 196 | 197 | def read_hdf(filename: str) -> List[LayerData]: 198 | config_path = misc.find_project_config_path(filename) 199 | layers = [] 200 | for filename in glob.iglob(filename): 201 | temp = pd.read_hdf(filename) 202 | temp = misc.merge_multiple_scorers(temp) 203 | header = misc.DLCHeader(temp.columns) 204 | temp = temp.droplevel("scorer", axis=1) 205 | if "individuals" not in temp.columns.names: 206 | # Append a fake level to the MultiIndex 207 | # to make it look like a multi-animal DataFrame 208 | old_idx = temp.columns.to_frame() 209 | old_idx.insert(0, "individuals", "") 210 | temp.columns = pd.MultiIndex.from_frame(old_idx) 211 | try: 212 | cfg = _load_config(config_path) 213 | colormap = cfg["colormap"] 214 | except FileNotFoundError: 215 | colormap = "rainbow" 216 | else: 217 | colormap = "Set3" 218 | if isinstance(temp.index, pd.MultiIndex): 219 | temp.index = [os.path.join(*row) for row in temp.index] 220 | df = ( 221 | temp.stack(["individuals", "bodyparts"]) 222 | .reindex(header.individuals, level="individuals") 223 | .reindex(header.bodyparts, level="bodyparts") 224 | .reset_index() 225 | ) 226 | nrows = df.shape[0] 227 | data = np.empty((nrows, 3)) 228 | image_paths = df["level_0"] 229 | if np.issubdtype(image_paths.dtype, np.number): 230 | image_inds = image_paths.values 231 | paths2inds = [] 232 | else: 233 | image_inds, paths2inds = misc.encode_categories( 234 | image_paths, 235 | return_map=True, 236 | ) 237 | data[:, 0] = image_inds 238 | data[:, 1:] = df[["y", "x"]].to_numpy() 239 | metadata = _populate_metadata( 240 | header, 241 | labels=df["bodyparts"], 242 | 
ids=df["individuals"], 243 | likelihood=df.get("likelihood"), 244 | paths=list(paths2inds), 245 | colormap=colormap, 246 | ) 247 | metadata["name"] = os.path.split(filename)[1].split(".")[0] 248 | metadata["metadata"]["root"] = os.path.split(filename)[0] 249 | # Store file name in case the layer's name is edited by the user 250 | metadata["metadata"]["name"] = metadata["name"] 251 | layers.append((data, metadata, "points")) 252 | return layers 253 | 254 | 255 | class Video: 256 | def __init__(self, video_path): 257 | if not os.path.isfile(video_path): 258 | raise ValueError(f'Video path "{video_path}" does not point to a file.') 259 | 260 | self.path = video_path 261 | self.stream = cv2.VideoCapture(video_path) 262 | if not self.stream.isOpened(): 263 | raise OSError("Video could not be opened.") 264 | 265 | self._n_frames = int(self.stream.get(cv2.CAP_PROP_FRAME_COUNT)) 266 | self._width = int(self.stream.get(cv2.CAP_PROP_FRAME_WIDTH)) 267 | self._height = int(self.stream.get(cv2.CAP_PROP_FRAME_HEIGHT)) 268 | self._frame = cv2.UMat(self._height, self._width, cv2.CV_8UC3) 269 | 270 | def __len__(self): 271 | return self._n_frames 272 | 273 | @property 274 | def width(self): 275 | return self._width 276 | 277 | @property 278 | def height(self): 279 | return self._height 280 | 281 | def set_to_frame(self, ind): 282 | ind = min(ind, len(self) - 1) 283 | ind += 1 # Unclear why this is needed at all 284 | self.stream.set(cv2.CAP_PROP_POS_FRAMES, ind) 285 | 286 | def read_frame(self): 287 | self.stream.retrieve(self._frame) 288 | cv2.cvtColor(self._frame, cv2.COLOR_BGR2RGB, self._frame, 3) 289 | return self._frame.get() 290 | 291 | def close(self): 292 | self.stream.release() 293 | 294 | 295 | def read_video(filename: str, opencv: bool = True): 296 | if opencv: 297 | stream = Video(filename) 298 | shape = stream.width, stream.height, 3 299 | 300 | def _read_frame(ind): 301 | stream.set_to_frame(ind) 302 | return stream.read_frame() 303 | 304 | lazy_imread = 
delayed(_read_frame) 305 | else: # pragma: no cover 306 | from pims import PyAVReaderIndexed 307 | 308 | try: 309 | stream = PyAVReaderIndexed(filename) 310 | except ImportError: 311 | raise ImportError("`pip install av` to use the PyAV video reader.") 312 | 313 | shape = stream.frame_shape 314 | lazy_imread = delayed(stream.get_frame) 315 | 316 | movie = da.stack( 317 | [ 318 | da.from_delayed(lazy_imread(i), shape=shape, dtype=np.uint8) 319 | for i in range(len(stream)) 320 | ] 321 | ) 322 | elems = list(Path(filename).parts) 323 | elems[-2] = "labeled-data" 324 | elems[-1] = elems[-1].split(".")[0] 325 | root = os.path.join(*elems) 326 | params = { 327 | "name": filename, 328 | "metadata": { 329 | "root": root, 330 | }, 331 | } 332 | return [(movie, params)] 333 | -------------------------------------------------------------------------------- /src/napari_deeplabcut/_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/_tests/__init__.py -------------------------------------------------------------------------------- /src/napari_deeplabcut/_tests/conftest.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | 5 | os.environ["hide_tutorial"] = "True" 6 | import pandas as pd 7 | import pytest 8 | from napari_deeplabcut import keypoints, _writer 9 | from skimage.io import imsave 10 | 11 | 12 | @pytest.fixture 13 | def viewer(make_napari_viewer): 14 | viewer = make_napari_viewer() 15 | for action in viewer.window.plugins_menu.actions(): 16 | if "deeplabcut" in action.text(): 17 | action.trigger() 18 | break 19 | return viewer 20 | 21 | 22 | @pytest.fixture 23 | def fake_keypoints(): 24 | n_rows = 10 25 | n_animals = 2 26 | n_kpts = 3 27 | data = np.random.rand(n_rows, n_animals * n_kpts * 2) 28 | cols = 
pd.MultiIndex.from_product( 29 | [ 30 | ["me"], 31 | [f"animal_{i}" for i in range(n_animals)], 32 | [f"kpt_{i}" for i in range(n_kpts)], 33 | ["x", "y"], 34 | ], 35 | names=["scorer", "individuals", "bodyparts", "coords"], 36 | ) 37 | df = pd.DataFrame(data, columns=cols, index=range(n_rows)) 38 | return df 39 | 40 | 41 | @pytest.fixture 42 | def points(tmp_path_factory, viewer, fake_keypoints): 43 | output_path = str(tmp_path_factory.mktemp("folder") / "fake_data.h5") 44 | fake_keypoints.to_hdf(output_path, key="data") 45 | layer = viewer.open(output_path, plugin="napari-deeplabcut")[0] 46 | return layer 47 | 48 | 49 | @pytest.fixture 50 | def fake_image(): 51 | return (np.random.rand(10, 10) * 255).astype(np.uint8) 52 | 53 | 54 | @pytest.fixture 55 | def images(tmp_path_factory, viewer, fake_image): 56 | output_path = str(tmp_path_factory.mktemp("folder") / "img.png") 57 | imsave(output_path, fake_image) 58 | layer = viewer.open(output_path, plugin="napari-deeplabcut")[0] 59 | return layer 60 | 61 | 62 | @pytest.fixture 63 | def store(viewer, points): 64 | return keypoints.KeypointStore(viewer, points) 65 | 66 | 67 | @pytest.fixture(scope="session") 68 | def config_path(tmp_path_factory): 69 | cfg = { 70 | "scorer": "me", 71 | "bodyparts": list("abc"), 72 | "dotsize": 0, 73 | "pcutoff": 0, 74 | "colormap": "viridis", 75 | "video_sets": { 76 | "fake_video": [], 77 | }, 78 | } 79 | path = str(tmp_path_factory.mktemp("configs") / "config.yaml") 80 | _writer._write_config( 81 | path, 82 | params=cfg, 83 | ) 84 | return path 85 | 86 | 87 | @pytest.fixture(scope="session") 88 | def video_path(tmp_path_factory): 89 | output_path = str(tmp_path_factory.mktemp("data") / "fake_video.avi") 90 | h = w = 50 91 | writer = cv2.VideoWriter( 92 | output_path, 93 | cv2.VideoWriter_fourcc(*"MJPG"), 94 | 2, 95 | (w, h), 96 | ) 97 | for _ in range(5): 98 | frame = np.random.randint(0, 255, (h, w, 3)).astype(np.uint8) 99 | writer.write(frame) 100 | writer.release() 101 | return 
output_path 102 | -------------------------------------------------------------------------------- /src/napari_deeplabcut/_tests/test_keypoints.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from napari_deeplabcut import keypoints 3 | 4 | 5 | def test_store_advance_step(store): 6 | assert store.current_step == 0 7 | store._advance_step(event=None) 8 | assert store.current_step == 1 9 | 10 | 11 | def test_store_labels(store, fake_keypoints): 12 | assert store.n_steps == fake_keypoints.shape[0] 13 | assert store.labels == list( 14 | fake_keypoints.columns.get_level_values("bodyparts").unique() 15 | ) 16 | 17 | 18 | def test_store_find_first_unlabeled_frame(store, fake_keypoints): 19 | store._find_first_unlabeled_frame(event=None) 20 | assert store.current_step == store.n_steps - 1 21 | # Remove a frame to test whether it is correctly found 22 | ind_to_remove = 2 23 | data = store.layer.data 24 | store.layer.data = data[data[:, 0] != ind_to_remove] 25 | store._find_first_unlabeled_frame(event=None) 26 | assert store.current_step == ind_to_remove 27 | 28 | 29 | def test_store_keypoints(store, fake_keypoints): 30 | annotated_keypoints = store.annotated_keypoints 31 | assert len(annotated_keypoints) == fake_keypoints.shape[1] // 2 32 | assert annotated_keypoints[0].id == "animal_0" 33 | assert annotated_keypoints[-1].id == "animal_1" 34 | kpt = keypoints.Keypoint(label="kpt_0", id="animal_0") 35 | next_kpt = keypoints.Keypoint(label="kpt_1", id="animal_0") 36 | store.current_keypoint = kpt 37 | assert store.current_keypoint == kpt 38 | store.next_keypoint() 39 | assert store.current_keypoint == next_kpt 40 | store.prev_keypoint() 41 | assert store.current_keypoint == kpt 42 | store.next_keypoint() 43 | 44 | 45 | def test_point_resize(viewer, points): 46 | viewer.layers.selection.add(points) 47 | layer = viewer.layers[0] 48 | controls = keypoints.QtPointsControls(layer) 49 | new_size = 10 50 | 
controls.changeCurrentSize(new_size) 51 | np.testing.assert_array_equal(points.size, new_size) 52 | 53 | 54 | def test_add_unnanotated(store): 55 | store.layer.metadata["controls"].label_mode = "loop" 56 | ind_to_remove = 0 57 | data = store.layer.data 58 | store.layer.data = data[data[:, 0] != ind_to_remove] 59 | store.viewer.dims.set_current_step(0, ind_to_remove) 60 | assert not store.annotated_keypoints 61 | n_points = store.layer.data.shape[0] 62 | keypoints._add(store, coord=(0, 1, 1)) 63 | assert store.layer.data.shape[0] == n_points + 1 64 | assert store.current_step == ind_to_remove + 1 65 | 66 | 67 | def test_add_quick(store): 68 | store.layer.metadata["controls"].label_mode = "quick" 69 | store.current_keypoint = store._keypoints[0] 70 | coord = store.current_step, -1, -1 71 | keypoints._add(store, coord=coord) 72 | np.testing.assert_array_equal( 73 | store.layer.data[store.current_step], 74 | coord, 75 | ) 76 | -------------------------------------------------------------------------------- /src/napari_deeplabcut/_tests/test_misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pandas as pd 4 | import pytest 5 | from napari_deeplabcut import misc, _reader 6 | 7 | 8 | def test_unsorted_unique_numeric(): 9 | seq = [4, 3, 2, 1, 0, 1, 2, 3, 4, 5] 10 | out = misc.unsorted_unique(seq) 11 | assert list(out) == [4, 3, 2, 1, 0, 5] 12 | 13 | 14 | def test_unsorted_unique_string(): 15 | seq = ["c", "b", "d", "b", "a", "b"] 16 | out = misc.unsorted_unique(seq) 17 | assert list(out) == ["c", "b", "d", "a"] 18 | 19 | 20 | def test_encode_categories(): 21 | categories = list("abcdabcd") 22 | inds, map_ = misc.encode_categories(categories, return_map=True) 23 | assert list(inds) == [0, 1, 2, 3, 0, 1, 2, 3] 24 | assert map_ == dict(zip(list("abcd"), range(4))) 25 | inds = misc.encode_categories(categories, return_map=False) 26 | 27 | 28 | def 
test_merge_multiple_scorers_no_likelihood(fake_keypoints): 29 | temp = fake_keypoints.copy(deep=True) 30 | temp.columns = temp.columns.set_levels(["you"], level="scorer") 31 | df = fake_keypoints.merge(temp, left_index=True, right_index=True) 32 | df = misc.merge_multiple_scorers(df) 33 | pd.testing.assert_frame_equal(df, fake_keypoints) 34 | 35 | 36 | def test_merge_multiple_scorers(fake_keypoints): 37 | new_columns = pd.MultiIndex.from_product( 38 | fake_keypoints.columns.levels[:-1] + [["x", "y", "likelihood"]], 39 | names=fake_keypoints.columns.names, 40 | ) 41 | fake_keypoints = fake_keypoints.reindex(new_columns, axis=1) 42 | fake_keypoints.loc(axis=1)[:, :, :, "likelihood"] = 1 43 | temp = fake_keypoints.copy(deep=True) 44 | temp.columns = temp.columns.set_levels(["you"], level="scorer") 45 | fake_keypoints.iloc[:5] = np.nan 46 | temp.iloc[5:] = np.nan 47 | df = fake_keypoints.merge(temp, left_index=True, right_index=True) 48 | df = misc.merge_multiple_scorers(df) 49 | pd.testing.assert_index_equal(df.columns, fake_keypoints.columns) 50 | assert not df.isna().any(axis=None) 51 | 52 | 53 | @pytest.mark.parametrize( 54 | "path", 55 | ["/home/to/fake/path", "C:\\Users\\with\\fake\\name"], 56 | ) 57 | def test_to_os_dir_sep(path): 58 | sep_wrong = "\\" if os.path.sep == "/" else "/" 59 | assert sep_wrong not in misc.to_os_dir_sep(path) 60 | 61 | 62 | def test_to_os_dir_sep_invalid(): 63 | with pytest.raises(ValueError): 64 | misc.to_os_dir_sep("/home\\home") 65 | 66 | 67 | def test_guarantee_multiindex_rows(): 68 | fake_index = [ 69 | f"labeled-data/subfolder_{i}/image_{j}" for i in range(3) for j in range(10) 70 | ] 71 | df = pd.DataFrame(index=fake_index) 72 | misc.guarantee_multiindex_rows(df) 73 | assert isinstance(df.index, pd.MultiIndex) 74 | 75 | # Substitute index with frame numbers 76 | frame_numbers = list(range(df.shape[0])) 77 | df.index = frame_numbers 78 | misc.guarantee_multiindex_rows(df) 79 | assert df.index.to_list() == frame_numbers 80 | 81 | 
82 | @pytest.mark.parametrize("n_colors", range(1, 11)) 83 | def test_build_color_cycle(n_colors): 84 | color_cycle = misc.build_color_cycle(n_colors) 85 | assert color_cycle.shape[0] == n_colors 86 | # Test whether all colors are different 87 | assert len(set(map(tuple, color_cycle))) == n_colors 88 | 89 | 90 | def test_dlc_header(): 91 | n_animals = 2 92 | n_keypoints = 3 93 | scorer = "me" 94 | animals = [f"animal_{n}" for n in range(n_animals)] 95 | keypoints = [f"kpt_{n}" for n in range(n_keypoints)] 96 | fake_columns = pd.MultiIndex.from_product( 97 | [ 98 | [scorer], 99 | animals, 100 | keypoints, 101 | ["x", "y", "likelihood"], 102 | ], 103 | names=["scorer", "individuals", "bodyparts", "coords"], 104 | ) 105 | header = misc.DLCHeader(fake_columns) 106 | assert header.scorer == scorer 107 | header.scorer = "you" 108 | assert header.scorer == "you" 109 | assert header.individuals == animals 110 | assert header.bodyparts == keypoints 111 | assert header.coords == ["x", "y", "likelihood"] 112 | 113 | 114 | def test_dlc_header_from_config_multi(config_path): 115 | config = _reader._load_config(config_path) 116 | config["multianimalproject"] = True 117 | config["individuals"] = ["animal"] 118 | config["multianimalbodyparts"] = list("abc") 119 | config["uniquebodyparts"] = list("de") 120 | header = misc.DLCHeader.from_config(config) 121 | assert header.individuals != [""] 122 | 123 | 124 | def test_cycle_enum(): 125 | enum = misc.CycleEnum("Test", list("AB")) 126 | assert next(enum).value == "a" 127 | assert next(enum).value == "b" 128 | assert next(enum).value == "a" 129 | assert next(enum).value == "b" 130 | assert enum["a"] == enum.A 131 | -------------------------------------------------------------------------------- /src/napari_deeplabcut/_tests/test_reader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | from napari_deeplabcut import _reader 5 | from 
skimage.io import imsave 6 | 7 | 8 | @pytest.mark.parametrize("ext", _reader.SUPPORTED_IMAGES) 9 | def test_get_image_reader(ext): 10 | path = f"fake_path/img{ext}" 11 | assert _reader.get_image_reader(path) is not None 12 | assert _reader.get_image_reader([path]) is not None 13 | 14 | 15 | def test_get_config_reader(config_path): 16 | assert _reader.get_config_reader(str(config_path)) is not None 17 | 18 | 19 | def test_get_config_reader_invalid_path(): 20 | assert _reader.get_config_reader("fake_data.h5") is None 21 | 22 | 23 | def test_get_folder_parser(tmp_path_factory, fake_keypoints): 24 | folder = tmp_path_factory.mktemp("folder") 25 | frame = (np.random.rand(10, 10) * 255).astype(np.uint8) 26 | imsave(folder / "img1.png", frame) 27 | imsave(folder / "img2.png", frame) 28 | layers = _reader.get_folder_parser(folder)(None) 29 | # There should be only an Image layer 30 | assert len(layers) == 1 31 | assert layers[0][-1] == "image" 32 | 33 | # Add an annotation data file 34 | fake_keypoints.to_hdf(folder / "data.h5", key="data") 35 | layers = _reader.get_folder_parser(folder)(None) 36 | # There should now be an additional Points layer 37 | assert len(layers) == 2 38 | assert layers[-1][-1] == "points" 39 | 40 | 41 | def test_get_folder_parser_wrong_input(): 42 | assert _reader.get_folder_parser("") is None 43 | 44 | 45 | def test_get_folder_parser_no_images(tmp_path_factory): 46 | folder = str(tmp_path_factory.mktemp("images")) 47 | with pytest.raises(OSError): 48 | _reader.get_folder_parser(folder) 49 | 50 | 51 | def test_read_images(tmp_path_factory, fake_image): 52 | folder = tmp_path_factory.mktemp("folder") 53 | path = str(folder / "img.png") 54 | imsave(path, fake_image) 55 | _ = _reader.read_images(path)[0] 56 | 57 | 58 | def test_read_config(config_path): 59 | dict_ = _reader.read_config(config_path)[0][1] 60 | assert dict_["name"].startswith("CollectedData_") 61 | assert config_path.startswith(dict_["metadata"]["project"]) 62 | 63 | 64 | def 
test_read_hdf_old_index(tmp_path_factory, fake_keypoints): 65 | path = str(tmp_path_factory.mktemp("folder") / "data.h5") 66 | old_index = [ 67 | f"labeled-data/video/img{i}.png" for i in range(fake_keypoints.shape[0]) 68 | ] 69 | fake_keypoints.index = old_index 70 | fake_keypoints.to_hdf(path, key="data") 71 | layers = _reader.read_hdf(path) 72 | assert len(layers) == 1 73 | image_paths = layers[0][1]["metadata"]["paths"] 74 | assert len(image_paths) == len(fake_keypoints) 75 | assert isinstance(image_paths[0], str) 76 | assert "labeled-data" in image_paths[0] 77 | 78 | 79 | def test_read_hdf_new_index(tmp_path_factory, fake_keypoints): 80 | path = str(tmp_path_factory.mktemp("folder") / "data.h5") 81 | new_index = pd.MultiIndex.from_product( 82 | [ 83 | ["labeled-data"], 84 | ["video"], 85 | [f"img{i}.png" for i in range(fake_keypoints.shape[0])], 86 | ] 87 | ) 88 | fake_keypoints.index = new_index 89 | fake_keypoints.to_hdf(path, key="data") 90 | layers = _reader.read_hdf(path) 91 | assert len(layers) == 1 92 | image_paths = layers[0][1]["metadata"]["paths"] 93 | assert len(image_paths) == len(fake_keypoints) 94 | assert isinstance(image_paths[0], str) 95 | assert "labeled-data" in image_paths[0] 96 | 97 | 98 | def test_video(video_path): 99 | video = _reader.Video(video_path) 100 | video.close() 101 | assert not video.stream.isOpened() 102 | 103 | 104 | def test_video_wrong_path(): 105 | with pytest.raises(ValueError): 106 | _ = _reader.Video("") 107 | 108 | 109 | def test_read_video(video_path): 110 | array, dict_ = _reader.read_video(video_path)[0] 111 | assert dict_["metadata"].get("root") 112 | assert array.shape[0] == 5 113 | assert array[0].compute().shape == (50, 50, 3) 114 | -------------------------------------------------------------------------------- /src/napari_deeplabcut/_tests/test_widgets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from napari_deeplabcut import 
_widgets 4 | from vispy import keys 5 | 6 | 7 | def test_guess_continuous(): 8 | assert _widgets.guess_continuous(np.array([0.0])) 9 | assert not _widgets.guess_continuous(np.array(list("abc"))) 10 | 11 | 12 | def test_keypoint_controls(viewer): 13 | controls = _widgets.KeypointControls(viewer) 14 | controls.label_mode = "loop" 15 | assert controls._radio_group.checkedButton().text() == "loop" 16 | controls.cycle_through_label_modes() 17 | assert controls._radio_group.checkedButton().text() == "sequential" 18 | 19 | 20 | def test_save_layers(viewer, points): 21 | controls = _widgets.KeypointControls(viewer) 22 | viewer.layers.selection.add(points) 23 | _widgets._save_layers_dialog(controls) 24 | 25 | 26 | def test_show_trails(viewer, store): 27 | controls = _widgets.KeypointControls(viewer) 28 | controls._stores["temp"] = store 29 | controls._is_saved = True 30 | controls._show_trails(state=2) 31 | 32 | 33 | def test_extract_single_frame(viewer, images): 34 | viewer.layers.selection.add(images) 35 | controls = _widgets.KeypointControls(viewer) 36 | controls._extract_single_frame() 37 | 38 | 39 | def test_store_crop_coordinates(viewer, images, config_path): 40 | viewer.layers.selection.add(images) 41 | _ = viewer.add_shapes( 42 | np.random.random((4, 3)), 43 | shape_type="rectangle", 44 | ) 45 | controls = _widgets.KeypointControls(viewer) 46 | controls._images_meta = { 47 | "name": "fake_video", 48 | "project": os.path.dirname(config_path), 49 | } 50 | controls._store_crop_coordinates() 51 | 52 | 53 | def test_toggle_face_color(viewer, points): 54 | viewer.layers.selection.add(points) 55 | view = viewer.window._qt_viewer 56 | # By default, points are colored by individual with multi-animal data 57 | assert points._face.color_properties.name == "id" 58 | view.canvas.events.key_press(key=keys.Key("F")) 59 | assert points._face.color_properties.name == "label" 60 | view.canvas.events.key_press(key=keys.Key("F")) 61 | assert points._face.color_properties.name == "id" 
62 | 63 | 64 | def test_toggle_edge_color(viewer, points): 65 | viewer.layers.selection.add(points) 66 | view = viewer.window._qt_viewer 67 | np.testing.assert_array_equal(points.edge_width, 0) 68 | view.canvas.events.key_press(key=keys.Key("E")) 69 | np.testing.assert_array_equal(points.edge_width, 2) 70 | 71 | 72 | def test_dropdown_menu(qtbot): 73 | widget = _widgets.DropdownMenu(list("abc")) 74 | widget.update_to("c") 75 | assert widget.currentText() == "c" 76 | widget.reset() 77 | assert widget.currentText() == "a" 78 | qtbot.add_widget(widget) 79 | 80 | 81 | def test_keypoints_dropdown_menu(store): 82 | widget = _widgets.KeypointsDropdownMenu(store) 83 | assert "id" in widget.menus 84 | assert "label" in widget.menus 85 | label_menu = widget.menus["label"] 86 | label_menu.currentText() == "kpt_0" 87 | widget.update_menus(event=None) 88 | label_menu.currentText() == "kpt_2" 89 | widget.refresh_label_menu("id_0") 90 | assert label_menu.count() == 0 91 | 92 | 93 | def test_keypoints_dropdown_menu_smart_reset(store): 94 | widget = _widgets.KeypointsDropdownMenu(store) 95 | label_menu = widget.menus["label"] 96 | label_menu.update_to("kpt_2") 97 | widget._locked = True 98 | widget.smart_reset(event=None) 99 | assert label_menu.currentText() == "kpt_2" 100 | widget._locked = False 101 | widget.smart_reset(event=None) 102 | assert label_menu.currentText() == "kpt_0" 103 | 104 | 105 | def test_color_pair(): 106 | pair = _widgets.LabelPair(color="pink", name="kpt", parent=None) 107 | assert pair.part_name == "kpt" 108 | assert pair.color == "pink" 109 | pair.color = "orange" 110 | pair.part_name = "kpt2" 111 | assert pair.color_label.toolTip() == "kpt2" 112 | 113 | 114 | def test_color_scheme_display(qtbot): 115 | widget = _widgets.ColorSchemeDisplay(None) 116 | widget._build() 117 | assert not widget.scheme_dict 118 | assert not widget._container.layout().count() 119 | widget.add_entry("keypoint", "red") 120 | assert widget.scheme_dict["keypoint"] == "red" 121 | 
assert widget._container.layout().count() == 1 122 | qtbot.add_widget(widget) 123 | -------------------------------------------------------------------------------- /src/napari_deeplabcut/_widgets.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from collections import defaultdict, namedtuple 4 | from copy import deepcopy 5 | from datetime import datetime 6 | from functools import partial, cached_property 7 | from math import ceil, log10 8 | import matplotlib.pyplot as plt 9 | import matplotlib.style as mplstyle 10 | import napari 11 | import pandas as pd 12 | from pathlib import Path 13 | from types import MethodType 14 | from typing import Optional, Sequence, Union 15 | 16 | from matplotlib.backends.backend_qtagg import FigureCanvas, NavigationToolbar2QT 17 | 18 | import numpy as np 19 | from napari._qt.widgets.qt_welcome import QtWelcomeLabel 20 | from napari.layers import Image, Points, Shapes, Tracks 21 | from napari.layers.points._points_key_bindings import register_points_action 22 | from napari.layers.utils import color_manager 23 | from napari.layers.utils.layer_utils import _features_to_properties 24 | from napari.utils.events import Event 25 | from napari.utils.history import get_save_history, update_save_history 26 | from qtpy.QtCore import Qt, QTimer, Signal, QPoint, QSettings, QSize 27 | from qtpy.QtGui import QPainter, QAction, QCursor, QIcon 28 | from qtpy.QtSvgWidgets import QSvgWidget 29 | from qtpy.QtWidgets import ( 30 | QButtonGroup, 31 | QCheckBox, 32 | QComboBox, 33 | QDialog, 34 | QFileDialog, 35 | QGroupBox, 36 | QGridLayout, 37 | QHBoxLayout, 38 | QLabel, 39 | QMessageBox, 40 | QPushButton, 41 | QRadioButton, 42 | QScrollArea, 43 | QSizePolicy, 44 | QSlider, 45 | QStyle, 46 | QStyleOption, 47 | QVBoxLayout, 48 | QWidget, 49 | ) 50 | 51 | from napari_deeplabcut import keypoints 52 | from napari_deeplabcut._reader import ( 53 | _load_config, 54 | 
_load_superkeypoints_diagram, 55 | _load_superkeypoints, 56 | is_video, 57 | ) 58 | from napari_deeplabcut._writer import _write_config, _write_image, _form_df 59 | from napari_deeplabcut.misc import ( 60 | encode_categories, 61 | to_os_dir_sep, 62 | guarantee_multiindex_rows, 63 | build_color_cycles, 64 | ) 65 | 66 | Tip = namedtuple("Tip", ["msg", "pos"]) 67 | 68 | 69 | class Shortcuts(QDialog): 70 | """Opens a window displaying available napari-deeplabcut shortcuts""" 71 | 72 | def __init__(self, parent): 73 | super().__init__(parent=parent) 74 | self.setParent(parent) 75 | self.setWindowTitle("Shortcuts") 76 | 77 | image_path = str(Path(__file__).parent / "assets" / "napari_shortcuts.svg") 78 | 79 | vlayout = QVBoxLayout() 80 | svg_widget = QSvgWidget(image_path) 81 | svg_widget.setStyleSheet("background-color: white;") 82 | vlayout.addWidget(svg_widget) 83 | self.setLayout(vlayout) 84 | 85 | 86 | class Tutorial(QDialog): 87 | def __init__(self, parent): 88 | super().__init__(parent=parent) 89 | self.setParent(parent) 90 | self.setWindowTitle("Tutorial") 91 | self.setModal(True) 92 | self.setStyleSheet("background:#361AE5") 93 | self.setAttribute(Qt.WA_DeleteOnClose) 94 | self.setWindowOpacity(0.95) 95 | self.setWindowFlags(self.windowFlags() | Qt.WindowCloseButtonHint) 96 | 97 | self._current_tip = -1 98 | self._tips = [ 99 | Tip( 100 | "Load a folder of annotated data\n(and optionally a config file if labeling from scratch)\nfrom the menu File > Open File or Open Folder.\nAlternatively, files and folders of images can be dragged\nand dropped onto the main window.", 101 | (0.35, 0.15), 102 | ), 103 | Tip( 104 | "Data layers will be listed at the bottom left;\ntheir visibility can be toggled by clicking on the small eye icon.", 105 | (0.1, 0.65), 106 | ), 107 | Tip( 108 | "Corresponding layer controls can be found at the top left.\nSwitch between labeling and selection mode using the numeric keys 2 and 3,\nor clicking on the + or -> icons.", 109 | (0.1, 0.2), 
110 | ), 111 | Tip( 112 | "There are three keypoint labeling modes:\nthe key M can be used to cycle between them.", 113 | (0.65, 0.05), 114 | ), 115 | Tip( 116 | "When done labeling, save your data by selecting the Points layer\nand hitting Ctrl+S (or File > Save Selected Layer(s)...).", 117 | (0.1, 0.65), 118 | ), 119 | Tip( 120 | "Read more at napari-deeplabcut", 121 | (0.4, 0.4), 122 | ), 123 | ] 124 | 125 | vlayout = QVBoxLayout() 126 | self.message = QLabel("💡\n\nLet's get started with a quick walkthrough!") 127 | self.message.setTextInteractionFlags(Qt.LinksAccessibleByMouse) 128 | self.message.setOpenExternalLinks(True) 129 | vlayout.addWidget(self.message) 130 | 131 | nav_layout = QHBoxLayout() 132 | self.prev_button = QPushButton("<") 133 | self.prev_button.clicked.connect(self.prev_tip) 134 | nav_layout.addWidget(self.prev_button) 135 | self.next_button = QPushButton(">") 136 | self.next_button.clicked.connect(self.next_tip) 137 | nav_layout.addWidget(self.next_button) 138 | 139 | self.update_nav_buttons() 140 | 141 | hlayout = QHBoxLayout() 142 | self.count = QLabel("") 143 | hlayout.addWidget(self.count) 144 | hlayout.addLayout(nav_layout) 145 | vlayout.addLayout(hlayout) 146 | self.setLayout(vlayout) 147 | 148 | def prev_tip(self): 149 | self._current_tip = (self._current_tip - 1) % len(self._tips) 150 | self.update_tip() 151 | self.update_nav_buttons() 152 | 153 | def next_tip(self): 154 | self._current_tip = (self._current_tip + 1) % len(self._tips) 155 | self.update_tip() 156 | self.update_nav_buttons() 157 | 158 | def update_tip(self): 159 | tip = self._tips[self._current_tip] 160 | msg = tip.msg 161 | if ( 162 | self._current_tip < len(self._tips) - 1 163 | ): # No emoji in the last tip otherwise the hyperlink breaks 164 | msg = "💡\n\n" + msg 165 | self.message.setText(msg) 166 | self.count.setText(f"Tip {self._current_tip + 1}|{len(self._tips)}") 167 | self.adjustSize() 168 | xrel, yrel = tip.pos 169 | geom = self.parent().geometry() 170 | p = 
QPoint( 171 | int(geom.left() + geom.width() * xrel), 172 | int(geom.top() + geom.height() * yrel), 173 | ) 174 | self.move(p) 175 | 176 | def update_nav_buttons(self): 177 | self.prev_button.setEnabled(self._current_tip > 0) 178 | self.next_button.setEnabled(self._current_tip < len(self._tips) - 1) 179 | 180 | 181 | def _get_and_try_preferred_reader( 182 | self, 183 | dialog, 184 | *args, 185 | ): 186 | try: 187 | self.viewer.open( 188 | dialog._current_file, 189 | plugin="napari-deeplabcut", 190 | ) 191 | except ValueError: 192 | self.viewer.open( 193 | dialog._current_file, 194 | plugin="builtins", 195 | ) 196 | 197 | 198 | # Hack to avoid napari's silly variable type guess, 199 | # where property is understood as continuous if 200 | # there are more than 16 unique categories... 201 | def guess_continuous(property): 202 | if issubclass(property.dtype.type, np.floating): 203 | return True 204 | else: 205 | return False 206 | 207 | 208 | color_manager.guess_continuous = guess_continuous 209 | 210 | 211 | def _paste_data(self, store): 212 | """Paste only currently unannotated data.""" 213 | features = self._clipboard.pop("features", None) 214 | if features is None: 215 | return 216 | 217 | unannotated = [ 218 | keypoints.Keypoint(label, id_) not in store.annotated_keypoints 219 | for label, id_ in zip(features["label"], features["id"]) 220 | ] 221 | if not any(unannotated): 222 | return 223 | 224 | new_features = features.iloc[unannotated] 225 | indices_ = self._clipboard.pop("indices") 226 | text_ = self._clipboard.pop("text") 227 | self._clipboard = {k: v[unannotated] for k, v in self._clipboard.items()} 228 | self._clipboard["features"] = new_features 229 | self._clipboard["indices"] = indices_ 230 | if text_ is not None: 231 | new_text = { 232 | "string": text_["string"][unannotated], 233 | "color": text_["color"], 234 | } 235 | self._clipboard["text"] = new_text 236 | 237 | npoints = len(self._view_data) 238 | totpoints = len(self.data) 239 | 240 | if 
len(self._clipboard.keys()) > 0: 241 | not_disp = self._slice_input.not_displayed 242 | data = deepcopy(self._clipboard["data"]) 243 | offset = [ 244 | self._slice_indices[i] - self._clipboard["indices"][i] for i in not_disp 245 | ] 246 | data[:, not_disp] = data[:, not_disp] + np.array(offset) 247 | self._data = np.append(self.data, data, axis=0) 248 | self._shown = np.append(self.shown, deepcopy(self._clipboard["shown"]), axis=0) 249 | self._size = np.append(self.size, deepcopy(self._clipboard["size"]), axis=0) 250 | self._symbol = np.append( 251 | self.symbol, deepcopy(self._clipboard["symbol"]), axis=0 252 | ) 253 | 254 | self._feature_table.append(self._clipboard["features"]) 255 | 256 | self.text._paste(**self._clipboard["text"]) 257 | 258 | self._edge_width = np.append( 259 | self.edge_width, 260 | deepcopy(self._clipboard["edge_width"]), 261 | axis=0, 262 | ) 263 | self._edge._paste( 264 | colors=self._clipboard["edge_color"], 265 | properties=_features_to_properties(self._clipboard["features"]), 266 | ) 267 | self._face._paste( 268 | colors=self._clipboard["face_color"], 269 | properties=_features_to_properties(self._clipboard["features"]), 270 | ) 271 | 272 | self._selected_view = list( 273 | range(npoints, npoints + len(self._clipboard["data"])) 274 | ) 275 | self._selected_data = set( 276 | range(totpoints, totpoints + len(self._clipboard["data"])) 277 | ) 278 | self.refresh() 279 | 280 | 281 | # Hack to save a KeyPoints layer without showing the Save dialog 282 | def _save_layers_dialog(self, selected=False): 283 | """Save layers (all or selected) to disk, using ``LayerList.save()``. 284 | Parameters 285 | ---------- 286 | selected : bool 287 | If True, only layers that are selected in the viewer will be saved. 288 | By default, all layers are saved. 289 | """ 290 | selected_layers = list(self.viewer.layers.selection) 291 | msg = "" 292 | if not len(self.viewer.layers): 293 | msg = "There are no layers in the viewer to save." 
294 | elif selected and not len(selected_layers): 295 | msg = "Please select a Points layer to save." 296 | if msg: 297 | QMessageBox.warning(self, "Nothing to save", msg, QMessageBox.Ok) 298 | return 299 | if len(selected_layers) == 1 and isinstance(selected_layers[0], Points): 300 | self.viewer.layers.save("", selected=True, plugin="napari-deeplabcut") 301 | self.viewer.status = "Data successfully saved" 302 | else: 303 | dlg = QFileDialog() 304 | hist = get_save_history() 305 | dlg.setHistory(hist) 306 | filename, _ = dlg.getSaveFileName( 307 | caption=f'Save {"selected" if selected else "all"} layers', 308 | dir=hist[0], # home dir by default 309 | ) 310 | if filename: 311 | self.viewer.layers.save(filename, selected=selected) 312 | else: 313 | return 314 | self._is_saved = True 315 | self.last_saved_label.setText( 316 | f'Last saved at {str(datetime.now().time()).split(".")[0]}' 317 | ) 318 | self.last_saved_label.show() 319 | 320 | 321 | def on_close(self, event, widget): 322 | if widget._stores and not widget._is_saved: 323 | choice = QMessageBox.warning( 324 | widget, 325 | "Warning", 326 | "Data were not saved. 
Are you certain you want to leave?", 327 | QMessageBox.Yes | QMessageBox.No, 328 | ) 329 | if choice == QMessageBox.Yes: 330 | event.accept() 331 | else: 332 | event.ignore() 333 | else: 334 | event.accept() 335 | 336 | 337 | # Class taken from https://github.com/matplotlib/napari-matplotlib/blob/53aa5ec95c1f3901e21dedce8347d3f95efe1f79/src/napari_matplotlib/base.py#L309 338 | class NapariNavigationToolbar(NavigationToolbar2QT): 339 | """Custom Toolbar style for Napari.""" 340 | 341 | def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def] 342 | super().__init__(*args, **kwargs) 343 | self.setIconSize(QSize(28, 28)) 344 | 345 | def _update_buttons_checked(self) -> None: 346 | """Update toggle tool icons when selected/unselected.""" 347 | super()._update_buttons_checked() 348 | icon_dir = self.parentWidget()._get_path_to_icon() 349 | 350 | # changes pan/zoom icons depending on state (checked or not) 351 | if "pan" in self._actions: 352 | if self._actions["pan"].isChecked(): 353 | self._actions["pan"].setIcon( 354 | QIcon(os.path.join(icon_dir, "Pan_checked.png")) 355 | ) 356 | else: 357 | self._actions["pan"].setIcon(QIcon(os.path.join(icon_dir, "Pan.png"))) 358 | if "zoom" in self._actions: 359 | if self._actions["zoom"].isChecked(): 360 | self._actions["zoom"].setIcon( 361 | QIcon(os.path.join(icon_dir, "Zoom_checked.png")) 362 | ) 363 | else: 364 | self._actions["zoom"].setIcon(QIcon(os.path.join(icon_dir, "Zoom.png"))) 365 | 366 | 367 | class KeypointMatplotlibCanvas(QWidget): 368 | """ 369 | Class about matplotlib canvas in which I will draw the keypoints over a range of frames 370 | It will be at the bottom of the screen and will use the keypoints from the range of frames to plot them on a x-y time series. 
371 | """ 372 | 373 | def __init__(self, napari_viewer, parent=None): 374 | super().__init__(parent=parent) 375 | 376 | self.viewer = napari_viewer 377 | with mplstyle.context(self.mpl_style_sheet_path): 378 | self.canvas = FigureCanvas() 379 | self.canvas.figure.set_layout_engine("constrained") 380 | self.ax = self.canvas.figure.subplots() 381 | self.toolbar = NapariNavigationToolbar(self.canvas, parent=self) 382 | self._replace_toolbar_icons() 383 | self.canvas.mpl_connect("button_press_event", self.on_doubleclick) 384 | self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--") 385 | self.ax.set_xlabel("Frame") 386 | self.ax.set_ylabel("Y position") 387 | # Add a slot to specify the range of frames to plot 388 | self.slider = QSlider(Qt.Horizontal) 389 | self.slider.setMinimum(50) 390 | self.slider.setMaximum(10000) 391 | self.slider.setValue(50) 392 | self.slider.setTickPosition(QSlider.TicksBelow) 393 | self.slider.setTickInterval(50) 394 | self.slider_value = QLabel(str(self.slider.value())) 395 | self._window = self.slider.value() 396 | # Connect slider to window setter 397 | self.slider.valueChanged.connect(self.set_window) 398 | 399 | layout = QVBoxLayout() 400 | layout.addWidget(self.canvas) 401 | layout.addWidget(self.toolbar) 402 | layout2 = QHBoxLayout() 403 | layout2.addWidget(self.slider) 404 | layout2.addWidget(self.slider_value) 405 | 406 | layout.addLayout(layout2) 407 | self.setLayout(layout) 408 | 409 | self.frames = [] 410 | self.keypoints = [] 411 | self.df = None 412 | # Make widget larger 413 | self.setMinimumHeight(300) 414 | # connect sliders to update plot 415 | self.viewer.dims.events.current_step.connect(self.update_plot_range) 416 | 417 | # Run update plot range once to initialize the plot 418 | self._n = 0 419 | self.update_plot_range( 420 | Event(type_name="", value=[self.viewer.dims.current_step[0]]) 421 | ) 422 | 423 | self.viewer.layers.events.inserted.connect(self._load_dataframe) 424 | self._lines = {} 425 | 426 | def 
on_doubleclick(self, event): 427 | if event.dblclick: 428 | show = list(self._lines.values())[0][0].get_visible() 429 | for lines in self._lines.values(): 430 | for l in lines: 431 | l.set_visible(not show) 432 | self._refresh_canvas(value=self._n) 433 | 434 | def _napari_theme_has_light_bg(self) -> bool: 435 | """ 436 | Does this theme have a light background? 437 | 438 | Returns 439 | ------- 440 | bool 441 | True if theme's background colour has hsl lighter than 50%, False if darker. 442 | """ 443 | theme = napari.utils.theme.get_theme(self.viewer.theme, as_dict=False) 444 | _, _, bg_lightness = theme.background.as_hsl_tuple() 445 | return bg_lightness > 0.5 446 | 447 | @property 448 | def mpl_style_sheet_path(self) -> Path: 449 | """ 450 | Path to the set Matplotlib style sheet. 451 | """ 452 | if self._napari_theme_has_light_bg(): 453 | return Path(__file__).parent / "styles" / "light.mplstyle" 454 | else: 455 | return Path(__file__).parent / "styles" / "dark.mplstyle" 456 | 457 | def _get_path_to_icon(self) -> Path: 458 | """ 459 | Get the icons directory (which is theme-dependent). 460 | 461 | Icons modified from 462 | https://github.com/matplotlib/matplotlib/tree/main/lib/matplotlib/mpl-data/images 463 | """ 464 | icon_root = Path(__file__).parent / "assets" 465 | if self._napari_theme_has_light_bg(): 466 | return icon_root / "black" 467 | else: 468 | return icon_root / "white" 469 | 470 | def _replace_toolbar_icons(self) -> None: 471 | """ 472 | Modifies toolbar icons to match the napari theme, and add some tooltips. 
473 | """ 474 | icon_dir = self._get_path_to_icon() 475 | for action in self.toolbar.actions(): 476 | text = action.text() 477 | if text == "Pan": 478 | action.setToolTip( 479 | "Pan/Zoom: Left button pans; Right button zooms; " 480 | "Click once to activate; Click again to deactivate" 481 | ) 482 | if text == "Zoom": 483 | action.setToolTip( 484 | "Zoom to rectangle; Click once to activate; " 485 | "Click again to deactivate" 486 | ) 487 | if len(text) > 0: # i.e. not a separator item 488 | icon_path = os.path.join(icon_dir, text + ".png") 489 | action.setIcon(QIcon(icon_path)) 490 | 491 | def _load_dataframe(self): 492 | points_layer = None 493 | for layer in self.viewer.layers: 494 | if isinstance(layer, Points): 495 | points_layer = layer 496 | break 497 | 498 | if points_layer is None or ~np.any(points_layer.data): 499 | return 500 | 501 | self.show() # Silly hack so the window does not hang the first time it is shown 502 | self.hide() 503 | 504 | self.df = _form_df( 505 | points_layer.data, 506 | { 507 | "metadata": points_layer.metadata, 508 | "properties": points_layer.properties, 509 | }, 510 | ) 511 | for keypoint in self.df.columns.get_level_values("bodyparts").unique(): 512 | y = self.df.xs((keypoint, "y"), axis=1, level=["bodyparts", "coords"]) 513 | x = np.arange(len(y)) 514 | color = points_layer.metadata["face_color_cycles"]["label"][keypoint] 515 | lines = self.ax.plot(x, y, color=color, label=keypoint) 516 | self._lines[keypoint] = lines 517 | 518 | self._refresh_canvas(value=self._n) 519 | 520 | def _toggle_line_visibility(self, keypoint): 521 | for artist in self._lines[keypoint]: 522 | artist.set_visible(not artist.get_visible()) 523 | self._refresh_canvas(value=self._n) 524 | 525 | def _refresh_canvas(self, value): 526 | start = max(0, value - self._window // 2) 527 | end = min(value + self._window // 2, len(self.df)) 528 | 529 | self.ax.set_xlim(start, end) 530 | self.vline.set_xdata([value]) 531 | self.canvas.draw() 532 | 533 | def 
set_window(self, value): 534 | self._window = value 535 | self.slider_value.setText(str(value)) 536 | self.update_plot_range(Event(type_name="", value=[self._n])) 537 | 538 | def update_plot_range(self, event): 539 | value = event.value[0] 540 | self._n = value 541 | 542 | if self.df is None: 543 | return 544 | 545 | self._refresh_canvas(value) 546 | 547 | 548 | class KeypointControls(QWidget): 549 | def __init__(self, napari_viewer): 550 | super().__init__() 551 | self._is_saved = False 552 | 553 | self.viewer = napari_viewer 554 | self.viewer.layers.events.inserted.connect(self.on_insert) 555 | self.viewer.layers.events.removed.connect(self.on_remove) 556 | 557 | self.viewer.window.qt_viewer._get_and_try_preferred_reader = MethodType( 558 | _get_and_try_preferred_reader, 559 | self.viewer.window.qt_viewer, 560 | ) 561 | 562 | status_bar = self.viewer.window._qt_window.statusBar() 563 | self.last_saved_label = QLabel("") 564 | self.last_saved_label.hide() 565 | status_bar.addPermanentWidget(self.last_saved_label) 566 | 567 | # Hack napari's Welcome overlay to show more relevant instructions 568 | overlay = self.viewer.window._qt_viewer._welcome_widget 569 | welcome_widget = overlay.layout().itemAt(1).widget() 570 | welcome_widget.deleteLater() 571 | w = QtWelcomeWidget(None) 572 | overlay._overlay = w 573 | overlay.addWidget(w) 574 | overlay._overlay.sig_dropped.connect(overlay.sig_dropped) 575 | 576 | self._color_mode = keypoints.ColorMode.default() 577 | self._label_mode = keypoints.LabelMode.default() 578 | 579 | # Hold references to the KeypointStores 580 | self._stores = {} 581 | # Intercept close event if data were not saved 582 | self.viewer.window._qt_window.closeEvent = partial( 583 | on_close, 584 | self.viewer.window._qt_window, 585 | widget=self, 586 | ) 587 | 588 | # Storage for extra image metadata that are relevant to other layers. 
589 | # These are updated anytime images are added to the Viewer 590 | # and passed on to the other layers upon creation. 591 | self._images_meta = dict() 592 | 593 | # Add some more controls 594 | self._layout = QVBoxLayout(self) 595 | self._menus = [] 596 | self._layer_to_menu = {} 597 | self.viewer.layers.selection.events.active.connect(self.on_active_layer_change) 598 | 599 | self._video_group = self._form_video_action_menu() 600 | self.video_widget = self.viewer.window.add_dock_widget( 601 | self._video_group, name="video", area="right" 602 | ) 603 | self.video_widget.setVisible(False) 604 | 605 | # form helper display 606 | self._keypoint_mapping_button = None 607 | self._func_id = None 608 | help_buttons = self._form_help_buttons() 609 | self._layout.addLayout(help_buttons) 610 | 611 | grid = QGridLayout() 612 | self._trail_cb = QCheckBox("Show trails", parent=self) 613 | self._trail_cb.setToolTip("toggle trails visibility") 614 | self._trail_cb.setChecked(False) 615 | self._trail_cb.setEnabled(False) 616 | self._trail_cb.stateChanged.connect(self._show_trails) 617 | self._trails = None 618 | 619 | self._matplotlib_canvas = KeypointMatplotlibCanvas(self.viewer) 620 | self._matplotlib_cb = QCheckBox("Show matplotlib canvas", parent=self) 621 | self._matplotlib_cb.setToolTip("toggle matplotlib canvas visibility") 622 | self._matplotlib_cb.stateChanged.connect(self._show_matplotlib_canvas) 623 | self._matplotlib_cb.setChecked(False) 624 | self._matplotlib_cb.setEnabled(False) 625 | self._view_scheme_cb = QCheckBox("Show color scheme", parent=self) 626 | 627 | grid.addWidget(self._matplotlib_cb, 0, 0) 628 | grid.addWidget(self._trail_cb, 1, 0) 629 | grid.addWidget(self._view_scheme_cb, 2, 0) 630 | 631 | self._layout.addLayout(grid) 632 | 633 | # form buttons for selection of annotation mode 634 | self._radio_box, self._radio_group = self._form_mode_radio_buttons() 635 | self._radio_box.setEnabled(False) 636 | 637 | # form color scheme display + color mode 
selector 638 | self._color_grp, self._color_mode_selector = self._form_color_mode_selector() 639 | self._color_grp.setEnabled(False) 640 | self._display = ColorSchemeDisplay(parent=self) 641 | self._color_scheme_display = self._form_color_scheme_display(self.viewer) 642 | self._view_scheme_cb.toggled.connect(self._show_color_scheme) 643 | self._view_scheme_cb.toggle() 644 | self._display.added.connect( 645 | lambda w: w.part_label.clicked.connect( 646 | self._matplotlib_canvas._toggle_line_visibility 647 | ), 648 | ) 649 | 650 | # Substitute default menu action with custom one 651 | for action in self.viewer.window.file_menu.actions()[::-1]: 652 | action_name = action.text().lower() 653 | if "save selected layer" in action_name: 654 | action.triggered.disconnect() 655 | action.triggered.connect( 656 | lambda: _save_layers_dialog( 657 | self, 658 | selected=True, 659 | ) 660 | ) 661 | elif "save all layers" in action_name: 662 | self.viewer.window.file_menu.removeAction(action) 663 | 664 | # Add action to show the walkthrough again 665 | launch_tutorial = QAction("&Launch Tutorial", self) 666 | launch_tutorial.triggered.connect(self.start_tutorial) 667 | self.viewer.window.view_menu.addAction(launch_tutorial) 668 | 669 | # Add action to view keyboard shortcuts 670 | display_shortcuts_action = QAction("&Shortcuts", self) 671 | display_shortcuts_action.triggered.connect(self.display_shortcuts) 672 | self.viewer.window.help_menu.addAction(display_shortcuts_action) 673 | 674 | # Hide some unused viewer buttons 675 | self.viewer.window._qt_viewer.viewerButtons.gridViewButton.hide() 676 | self.viewer.window._qt_viewer.viewerButtons.rollDimsButton.hide() 677 | self.viewer.window._qt_viewer.viewerButtons.transposeDimsButton.hide() 678 | self.viewer.window._qt_viewer.layerButtons.newPointsButton.setDisabled(True) 679 | self.viewer.window._qt_viewer.layerButtons.newLabelsButton.setDisabled(True) 680 | 681 | if self.settings.value("first_launch", True) and not os.environ.get( 
682 | "hide_tutorial", False 683 | ): 684 | QTimer.singleShot(10, self.start_tutorial) 685 | self.settings.setValue("first_launch", False) 686 | 687 | # Slightly delay docking so it is shown underneath the KeypointsControls widget 688 | QTimer.singleShot(10, self.silently_dock_matplotlib_canvas) 689 | 690 | def silently_dock_matplotlib_canvas(self): 691 | self.viewer.window.add_dock_widget(self._matplotlib_canvas, name="Trajectory plot", area="right") 692 | self._matplotlib_canvas.hide() 693 | 694 | @cached_property 695 | def settings(self): 696 | return QSettings() 697 | 698 | def load_superkeypoints_diagram(self): 699 | points_layer = None 700 | for layer in self.viewer.layers: 701 | if isinstance(layer, Points): 702 | points_layer = layer 703 | break 704 | 705 | if points_layer is None: 706 | return 707 | 708 | tables = deepcopy(points_layer.metadata.get("tables", {})) 709 | if not tables: 710 | return 711 | 712 | super_animal, table = tables.popitem() 713 | layer_data = _load_superkeypoints_diagram(super_animal) 714 | self.viewer.add_image(layer_data[0], metadata=layer_data[1]) 715 | superkpts_dict = _load_superkeypoints(super_animal) 716 | xy = [] 717 | labels = [] 718 | for kpt_ref, kpt_super in table.items(): 719 | xy.append([0.0, *superkpts_dict[kpt_super]]) 720 | labels.append(kpt_ref) 721 | points_layer.data = np.array(xy) 722 | properties = deepcopy(points_layer.properties) 723 | properties["label"] = np.array(labels) 724 | points_layer.properties = properties 725 | self._keypoint_mapping_button.setText("Map keypoints") 726 | self._keypoint_mapping_button.clicked.disconnect(self._func_id) 727 | self._keypoint_mapping_button.clicked.connect( 728 | lambda: self._map_keypoints(super_animal) 729 | ) 730 | 731 | def _map_keypoints(self, super_animal: str): 732 | points_layer = None 733 | for layer in self.viewer.layers: 734 | if isinstance(layer, Points) and layer.metadata.get("tables"): 735 | points_layer = layer 736 | break 737 | 738 | if points_layer is 
None or ~np.any(points_layer.data): 739 | return 740 | 741 | xy = points_layer.data[:, 1:3] 742 | superkpts_dict = _load_superkeypoints(super_animal) 743 | xy_ref = np.c_[[val for val in superkpts_dict.values()]] 744 | neighbors = keypoints._find_nearest_neighbors(xy, xy_ref) 745 | found = neighbors != -1 746 | if ~np.any(found): 747 | return 748 | 749 | project_path = points_layer.metadata["project"] 750 | config_path = str(Path(project_path) / "config.yaml") 751 | cfg = _load_config(config_path) 752 | conversion_tables = cfg.get("SuperAnimalConversionTables", {}) 753 | conversion_tables[super_animal] = dict( 754 | zip( 755 | map( 756 | str, points_layer.metadata["header"].bodyparts 757 | ), # Needed to fix an ugly yaml RepresenterError 758 | map(str, list(np.array(list(superkpts_dict))[neighbors[found]])), 759 | ) 760 | ) 761 | _write_config(config_path, cfg) 762 | self.viewer.status = "Mapping to superkeypoint set successfully saved" 763 | 764 | def start_tutorial(self): 765 | Tutorial(self.viewer.window._qt_window.current()).show() 766 | 767 | def display_shortcuts(self): 768 | Shortcuts(self.viewer.window._qt_window.current()).show() 769 | 770 | def _move_image_layer_to_bottom(self, index): 771 | if (ind := index) != 0: 772 | self.viewer.layers.move_selected(ind, 0) 773 | self.viewer.layers.select_next() # Auto-select the Points layer 774 | 775 | def _show_color_scheme(self): 776 | show = self._view_scheme_cb.isChecked() 777 | self._color_scheme_display.setVisible(show) 778 | 779 | def _show_trails(self, state): 780 | if Qt.CheckState(state) == Qt.CheckState.Checked: 781 | if self._trails is None: 782 | store = list(self._stores.values())[0] 783 | categories = store.layer.properties["id"] 784 | if not categories[0]: # Single animal data 785 | categories = store.layer.properties["label"] 786 | inds = encode_categories(categories) 787 | temp = np.c_[inds, store.layer.data] 788 | cmap = "viridis" 789 | for layer in self.viewer.layers: 790 | if isinstance(layer, 
Points) and layer.metadata: 791 | cmap = layer.metadata["colormap_name"] 792 | self._trails = self.viewer.add_tracks( 793 | temp, 794 | tail_length=50, 795 | head_length=50, 796 | tail_width=6, 797 | name="trails", 798 | colormap=cmap, 799 | ) 800 | self._trails.visible = True 801 | elif self._trails is not None: 802 | self._trails.visible = False 803 | 804 | def _show_matplotlib_canvas(self, state): 805 | if Qt.CheckState(state) == Qt.CheckState.Checked: 806 | self._matplotlib_canvas.show() 807 | self.viewer.window._qt_window.update() 808 | else: 809 | self._matplotlib_canvas.hide() 810 | 811 | def _form_video_action_menu(self): 812 | group_box = QGroupBox("Video") 813 | layout = QVBoxLayout() 814 | extract_button = QPushButton("Extract frame") 815 | extract_button.clicked.connect(self._extract_single_frame) 816 | layout.addWidget(extract_button) 817 | crop_button = QPushButton("Store crop coordinates") 818 | crop_button.clicked.connect(self._store_crop_coordinates) 819 | layout.addWidget(crop_button) 820 | group_box.setLayout(layout) 821 | return group_box 822 | 823 | def _form_help_buttons(self): 824 | layout = QVBoxLayout() 825 | show_shortcuts = QPushButton("View shortcuts") 826 | show_shortcuts.clicked.connect(self.display_shortcuts) 827 | layout.addWidget(show_shortcuts) 828 | tutorial = QPushButton("Start tutorial") 829 | tutorial.clicked.connect(self.start_tutorial) 830 | layout.addWidget(tutorial) 831 | self._keypoint_mapping_button = QPushButton("Load superkeypoints diagram") 832 | self._func_id = self._keypoint_mapping_button.clicked.connect( 833 | self.load_superkeypoints_diagram 834 | ) 835 | self._keypoint_mapping_button.hide() 836 | layout.addWidget(self._keypoint_mapping_button) 837 | return layout 838 | 839 | def _extract_single_frame(self, *args): 840 | image_layer = None 841 | points_layer = None 842 | for layer in self.viewer.layers: 843 | if isinstance(layer, Image): 844 | image_layer = layer 845 | elif isinstance(layer, Points): 846 | 
points_layer = layer 847 | if image_layer is not None: 848 | ind = self.viewer.dims.current_step[0] 849 | frame = image_layer.data[ind] 850 | n_frames = image_layer.data.shape[0] 851 | name = f"img{str(ind).zfill(int(ceil(log10(n_frames))))}.png" 852 | output_path = os.path.join(image_layer.metadata["root"], name) 853 | _write_image(frame, str(output_path)) 854 | 855 | # If annotations were loaded, they should be written to a machinefile.h5 file 856 | if points_layer is not None: 857 | df = _form_df( 858 | points_layer.data, 859 | { 860 | "metadata": points_layer.metadata, 861 | "properties": points_layer.properties, 862 | }, 863 | ) 864 | df = df.iloc[ind : ind + 1] 865 | df.index = pd.MultiIndex.from_tuples([Path(output_path).parts[-3:]]) 866 | filepath = os.path.join( 867 | image_layer.metadata["root"], "machinelabels-iter0.h5" 868 | ) 869 | if Path(filepath).is_file(): 870 | df_prev = pd.read_hdf(filepath) 871 | guarantee_multiindex_rows(df_prev) 872 | df = pd.concat([df_prev, df]) 873 | df = df[~df.index.duplicated(keep="first")] 874 | df.to_hdf(filepath, key="machinelabels") 875 | 876 | def _store_crop_coordinates(self, *args): 877 | if not (project_path := self._images_meta.get("project")): 878 | return 879 | for layer in self.viewer.layers: 880 | if isinstance(layer, Shapes): 881 | try: 882 | ind = layer.shape_type.index("rectangle") 883 | except ValueError: 884 | return 885 | bbox = layer.data[ind][:, 1:] 886 | h = self.viewer.dims.range[2][1] 887 | bbox[:, 0] = h - bbox[:, 0] 888 | bbox = np.clip(bbox, 0, a_max=None).astype(int) 889 | y1, x1 = bbox.min(axis=0) 890 | y2, x2 = bbox.max(axis=0) 891 | temp = {"crop": ", ".join(map(str, [x1, x2, y1, y2]))} 892 | config_path = os.path.join(project_path, "config.yaml") 893 | cfg = _load_config(config_path) 894 | cfg["video_sets"][ 895 | os.path.join(project_path, "videos", self._images_meta["name"]) 896 | ] = temp 897 | _write_config(config_path, cfg) 898 | break 899 | 900 | def _form_dropdown_menus(self, 
store): 901 | menu = KeypointsDropdownMenu(store) 902 | self.viewer.dims.events.current_step.connect( 903 | menu.smart_reset, 904 | position="last", 905 | ) 906 | menu.smart_reset(event=None) 907 | self._menus.append(menu) 908 | self._layer_to_menu[store.layer] = len(self._menus) - 1 909 | layout = QVBoxLayout() 910 | layout.addWidget(menu) 911 | self._layout.addLayout(layout) 912 | 913 | def _form_mode_radio_buttons(self): 914 | group_box = QGroupBox("Labeling mode") 915 | layout = QHBoxLayout() 916 | group = QButtonGroup(self) 917 | for i, mode in enumerate(keypoints.LabelMode.__members__, start=1): 918 | btn = QRadioButton(mode.lower()) 919 | btn.setToolTip(keypoints.TOOLTIPS[mode]) 920 | group.addButton(btn, i) 921 | layout.addWidget(btn) 922 | group.button(1).setChecked(True) 923 | group_box.setLayout(layout) 924 | self._layout.addWidget(group_box) 925 | 926 | def _func(): 927 | self.label_mode = group.checkedButton().text() 928 | 929 | group.buttonClicked.connect(_func) 930 | return group_box, group 931 | 932 | def _form_color_mode_selector(self): 933 | group_box = QGroupBox("Keypoint coloring mode") 934 | layout = QHBoxLayout() 935 | group = QButtonGroup(self) 936 | for i, mode in enumerate(keypoints.ColorMode.__members__, start=1): 937 | btn = QRadioButton(mode.lower()) 938 | group.addButton(btn, i) 939 | layout.addWidget(btn) 940 | group.button(1).setChecked(True) 941 | group_box.setLayout(layout) 942 | self._layout.addWidget(group_box) 943 | 944 | def _func(): 945 | self.color_mode = group.checkedButton().text() 946 | 947 | group.buttonClicked.connect(_func) 948 | return group_box, group 949 | 950 | def _form_color_scheme_display(self, viewer): 951 | self.viewer.layers.events.inserted.connect(self._update_color_scheme) 952 | return viewer.window.add_dock_widget( 953 | self._display, name="Color scheme reference", area="left" 954 | ) 955 | 956 | def _update_color_scheme(self): 957 | def to_hex(nparray): 958 | a = np.array(nparray * 255, dtype=int) 959 | 
rgb2hex = lambda r, g, b, _: f"#{r:02x}{g:02x}{b:02x}" 960 | res = rgb2hex(*a) 961 | return res 962 | 963 | self._display.reset() 964 | mode = "label" 965 | if self.color_mode == str(keypoints.ColorMode.INDIVIDUAL): 966 | mode = "id" 967 | 968 | for layer in self.viewer.layers: 969 | if isinstance(layer, Points) and layer.metadata: 970 | self._display.update_color_scheme( 971 | { 972 | name: to_hex(color) 973 | for name, color in layer.metadata["face_color_cycles"][ 974 | mode 975 | ].items() 976 | } 977 | ) 978 | 979 | def _remap_frame_indices(self, layer): 980 | if not self._images_meta.get("paths"): 981 | return 982 | 983 | new_paths = [to_os_dir_sep(p) for p in self._images_meta["paths"]] 984 | paths = layer.metadata.get("paths") 985 | if paths is not None and np.any(layer.data): 986 | paths_map = dict(zip(range(len(paths)), map(to_os_dir_sep, paths))) 987 | # Discard data if there are missing frames 988 | missing = [i for i, path in paths_map.items() if path not in new_paths] 989 | if missing: 990 | if isinstance(layer.data, list): 991 | inds_to_remove = [ 992 | i 993 | for i, verts in enumerate(layer.data) 994 | if verts[0, 0] in missing 995 | ] 996 | else: 997 | inds_to_remove = np.flatnonzero(np.isin(layer.data[:, 0], missing)) 998 | layer.selected_data = inds_to_remove 999 | layer.remove_selected() 1000 | for i in missing: 1001 | paths_map.pop(i) 1002 | 1003 | # Check now whether there are new frames 1004 | temp = {k: new_paths.index(v) for k, v in paths_map.items()} 1005 | data = layer.data 1006 | if isinstance(data, list): 1007 | for verts in data: 1008 | verts[:, 0] = np.vectorize(temp.get)(verts[:, 0]) 1009 | else: 1010 | data[:, 0] = np.vectorize(temp.get)(data[:, 0]) 1011 | layer.data = data 1012 | layer.metadata.update(self._images_meta) 1013 | 1014 | def on_insert(self, event): 1015 | layer = event.source[-1] 1016 | logging.debug(f"Inserting Layer {layer}") 1017 | if isinstance(layer, Image): 1018 | paths = layer.metadata.get("paths") 1019 | if 
paths is None and is_video(layer.name): 1020 | self.video_widget.setVisible(True) 1021 | # Store the metadata and pass them on to the other layers 1022 | self._images_meta.update( 1023 | { 1024 | "paths": paths, 1025 | "shape": layer.level_shapes[0], 1026 | "root": layer.metadata["root"], 1027 | "name": layer.name, 1028 | } 1029 | ) 1030 | # Delay layer sorting 1031 | QTimer.singleShot( 1032 | 10, partial(self._move_image_layer_to_bottom, event.index) 1033 | ) 1034 | elif isinstance(layer, Points): 1035 | # If the current Points layer comes from a config file, some have already 1036 | # been added and the body part names are different from the existing ones, 1037 | # then we update store's metadata and menus. 1038 | if layer.metadata.get("project", "") and self._stores: 1039 | new_metadata = layer.metadata.copy() 1040 | 1041 | keypoints_menu = self._menus[0].menus["label"] 1042 | current_keypoint_set = set( 1043 | keypoints_menu.itemText(i) for i in range(keypoints_menu.count()) 1044 | ) 1045 | new_keypoint_set = set(new_metadata["header"].bodyparts) 1046 | diff = new_keypoint_set.difference(current_keypoint_set) 1047 | if diff: 1048 | answer = QMessageBox.question( 1049 | self, "", "Do you want to display the new keypoints only?" 1050 | ) 1051 | if answer == QMessageBox.Yes: 1052 | self.viewer.layers[-2].shown = False 1053 | 1054 | self.viewer.status = f"New keypoint{'s' if len(diff) > 1 else ''} {', '.join(diff)} found." 
1055 | for _layer, store in self._stores.items(): 1056 | _layer.metadata["header"] = new_metadata["header"] 1057 | store.layer = _layer 1058 | 1059 | for menu in self._menus: 1060 | menu._map_individuals_to_bodyparts() 1061 | menu._update_items() 1062 | 1063 | # Remove the unnecessary layer newly added 1064 | QTimer.singleShot(10, self.viewer.layers.pop) 1065 | 1066 | # Always update the colormap to reflect the one in the config.yaml file 1067 | for _layer, store in self._stores.items(): 1068 | _layer.metadata["face_color_cycles"] = new_metadata[ 1069 | "face_color_cycles" 1070 | ] 1071 | _layer.face_color = "label" 1072 | _layer.face_color_cycle = new_metadata["face_color_cycles"]["label"] 1073 | _layer.events.face_color() 1074 | store.layer = _layer 1075 | self._update_color_scheme() 1076 | 1077 | return 1078 | 1079 | if layer.metadata.get("tables", ""): 1080 | self._keypoint_mapping_button.show() 1081 | 1082 | store = keypoints.KeypointStore(self.viewer, layer) 1083 | self._stores[layer] = store 1084 | # TODO Set default dir of the save file dialog 1085 | if root := layer.metadata.get("root"): 1086 | update_save_history(root) 1087 | layer.metadata["controls"] = self 1088 | layer.text.visible = False 1089 | layer.bind_key("M", self.cycle_through_label_modes) 1090 | layer.bind_key("F", self.cycle_through_color_modes) 1091 | func = partial(_paste_data, store=store) 1092 | layer._paste_data = MethodType(func, layer) 1093 | layer.add = MethodType(keypoints._add, store) 1094 | layer.events.add(query_next_frame=Event) 1095 | layer.events.query_next_frame.connect(store._advance_step) 1096 | layer.bind_key("Shift-Right", store._find_first_unlabeled_frame) 1097 | layer.bind_key("Shift-Left", store._find_first_unlabeled_frame) 1098 | 1099 | layer.bind_key("Down", store.next_keypoint, overwrite=True) 1100 | layer.bind_key("Up", store.prev_keypoint, overwrite=True) 1101 | layer.face_color_mode = "cycle" 1102 | self._form_dropdown_menus(store) 1103 | 1104 | 
self._images_meta.update( 1105 | { 1106 | "project": layer.metadata.get("project"), 1107 | } 1108 | ) 1109 | self._radio_box.setEnabled(True) 1110 | self._color_grp.setEnabled(True) 1111 | self._trail_cb.setEnabled(True) 1112 | self._matplotlib_cb.setEnabled(True) 1113 | 1114 | # Hide the color pickers, as colormaps are strictly defined by users 1115 | controls = self.viewer.window.qt_viewer.dockLayerControls 1116 | point_controls = controls.widget().widgets[layer] 1117 | point_controls.faceColorEdit.hide() 1118 | point_controls.edgeColorEdit.hide() 1119 | point_controls.layout().itemAt(9).widget().hide() 1120 | point_controls.layout().itemAt(11).widget().hide() 1121 | # Hide out of slice checkbox 1122 | point_controls.outOfSliceCheckBox.hide() 1123 | point_controls.layout().itemAt(15).widget().hide() 1124 | # Add dropdown menu for colormap picking 1125 | colormap_selector = DropdownMenu(plt.colormaps, self) 1126 | colormap_selector.update_to(layer.metadata["colormap_name"]) 1127 | colormap_selector.currentTextChanged.connect(self._update_colormap) 1128 | point_controls.layout().addRow("colormap", colormap_selector) 1129 | 1130 | for layer_ in self.viewer.layers: 1131 | if not isinstance(layer_, Image): 1132 | self._remap_frame_indices(layer_) 1133 | 1134 | def on_remove(self, event): 1135 | layer = event.value 1136 | n_points_layer = sum(isinstance(l, Points) for l in self.viewer.layers) 1137 | if isinstance(layer, Points) and n_points_layer == 0: 1138 | if self._color_scheme_display is not None: 1139 | self._display.reset() 1140 | self._stores.pop(layer, None) 1141 | while self._menus: 1142 | menu = self._menus.pop() 1143 | self._layout.removeWidget(menu) 1144 | menu.deleteLater() 1145 | menu.destroy() 1146 | self._layer_to_menu = {} 1147 | self._trail_cb.setEnabled(False) 1148 | self._matplotlib_cb.setEnabled(False) 1149 | self.last_saved_label.hide() 1150 | elif isinstance(layer, Image): 1151 | self._images_meta = dict() 1152 | paths = 
layer.metadata.get("paths") 1153 | if paths is None: 1154 | self.video_widget.setVisible(False) 1155 | elif isinstance(layer, Tracks): 1156 | self._trail_cb.setChecked(False) 1157 | self._matplotlib_cb.setChecked(False) 1158 | self._trails = None 1159 | 1160 | def on_active_layer_change(self, event) -> None: 1161 | """Updates the GUI when the active layer changes 1162 | * Hides all KeypointsDropdownMenu that aren't for the selected layer 1163 | * Sets the visibility of the "Color mode" box to True if the selected layer 1164 | is a multi-animal one, or False otherwise 1165 | """ 1166 | self._color_grp.setVisible(self._is_multianimal(event.value)) 1167 | menu_idx = -1 1168 | if event.value is not None and isinstance(event.value, Points): 1169 | menu_idx = self._layer_to_menu.get(event.value, -1) 1170 | 1171 | for idx, menu in enumerate(self._menus): 1172 | if idx == menu_idx: 1173 | menu.setHidden(False) 1174 | else: 1175 | menu.setHidden(True) 1176 | 1177 | def _update_colormap(self, colormap_name): 1178 | for layer in self.viewer.layers.selection: 1179 | if isinstance(layer, Points) and layer.metadata: 1180 | face_color_cycle_maps = build_color_cycles( 1181 | layer.metadata["header"], 1182 | colormap_name, 1183 | ) 1184 | layer.metadata["face_color_cycles"] = face_color_cycle_maps 1185 | face_color_prop = "label" 1186 | if self.color_mode == str(keypoints.ColorMode.INDIVIDUAL): 1187 | face_color_prop = "id" 1188 | 1189 | layer.face_color = face_color_prop 1190 | layer.face_color_cycle = face_color_cycle_maps[face_color_prop] 1191 | layer.events.face_color() 1192 | self._update_color_scheme() 1193 | 1194 | @register_points_action("Change labeling mode") 1195 | def cycle_through_label_modes(self, *args): 1196 | self.label_mode = next(keypoints.LabelMode) 1197 | 1198 | @register_points_action("Change color mode") 1199 | def cycle_through_color_modes(self, *args): 1200 | if self._active_layer_is_multianimal() or self.color_mode != str( 1201 | 
keypoints.ColorMode.BODYPART 1202 | ): 1203 | self.color_mode = next(keypoints.ColorMode) 1204 | 1205 | @property 1206 | def label_mode(self): 1207 | return str(self._label_mode) 1208 | 1209 | @label_mode.setter 1210 | def label_mode(self, mode: Union[str, keypoints.LabelMode]): 1211 | self._label_mode = keypoints.LabelMode(mode) 1212 | self.viewer.status = self.label_mode 1213 | mode_ = str(mode) 1214 | if mode_ == "loop": 1215 | for menu in self._menus: 1216 | menu._locked = True 1217 | else: 1218 | for menu in self._menus: 1219 | menu._locked = False 1220 | for btn in self._radio_group.buttons(): 1221 | if btn.text() == mode_: 1222 | btn.setChecked(True) 1223 | break 1224 | 1225 | @property 1226 | def color_mode(self): 1227 | return str(self._color_mode) 1228 | 1229 | @color_mode.setter 1230 | def color_mode(self, mode: Union[str, keypoints.ColorMode]): 1231 | self._color_mode = keypoints.ColorMode(mode) 1232 | if self._color_mode == keypoints.ColorMode.BODYPART: 1233 | face_color_mode = "label" 1234 | else: 1235 | face_color_mode = "id" 1236 | 1237 | for layer in self.viewer.layers: 1238 | if isinstance(layer, Points) and layer.metadata: 1239 | layer.face_color = face_color_mode 1240 | layer.face_color_cycle = layer.metadata["face_color_cycles"][ 1241 | face_color_mode 1242 | ] 1243 | layer.events.face_color() 1244 | 1245 | for btn in self._color_mode_selector.buttons(): 1246 | if btn.text() == str(mode): 1247 | btn.setChecked(True) 1248 | break 1249 | 1250 | self._update_color_scheme() 1251 | 1252 | def _is_multianimal(self, layer) -> bool: 1253 | is_multi = False 1254 | if layer is not None and isinstance(layer, Points): 1255 | try: 1256 | header = layer.metadata.get("header") 1257 | if header is not None: 1258 | ids = header.individuals 1259 | is_multi = len(ids) > 0 and ids[0] != "" 1260 | except AttributeError: 1261 | pass 1262 | 1263 | return is_multi 1264 | 1265 | def _active_layer_is_multianimal(self) -> bool: 1266 | """Returns: whether the active layer 
is a multi-animal points layer""" 1267 | for layer in self.viewer.layers.selection: 1268 | if self._is_multianimal(layer): 1269 | return True 1270 | 1271 | return False 1272 | 1273 | 1274 | @Points.bind_key("E") 1275 | def toggle_edge_color(layer): 1276 | # Trick to toggle between 0 and 2 1277 | layer.edge_width = np.bitwise_xor(layer.edge_width, 2) 1278 | 1279 | 1280 | class DropdownMenu(QComboBox): 1281 | def __init__(self, labels: Sequence[str], parent: Optional[QWidget] = None): 1282 | super().__init__(parent) 1283 | self.update_items(labels) 1284 | 1285 | def update_to(self, text: str): 1286 | index = self.findText(text) 1287 | if index >= 0: 1288 | self.setCurrentIndex(index) 1289 | 1290 | def reset(self): 1291 | self.setCurrentIndex(0) 1292 | 1293 | def update_items(self, items): 1294 | self.clear() 1295 | self.addItems(items) 1296 | 1297 | 1298 | class KeypointsDropdownMenu(QWidget): 1299 | def __init__( 1300 | self, 1301 | store: keypoints.KeypointStore, 1302 | parent: Optional[QWidget] = None, 1303 | ): 1304 | super().__init__(parent) 1305 | self.store = store 1306 | self.store.layer.events.current_properties.connect(self.update_menus) 1307 | self._locked = False 1308 | 1309 | self.id2label = defaultdict(list) 1310 | self.menus = dict() 1311 | self._map_individuals_to_bodyparts() 1312 | self._populate_menus() 1313 | 1314 | layout1 = QVBoxLayout() 1315 | layout1.addStretch(1) 1316 | group_box = QGroupBox("Keypoint selection") 1317 | layout2 = QVBoxLayout() 1318 | for menu in self.menus.values(): 1319 | layout2.addWidget(menu) 1320 | group_box.setLayout(layout2) 1321 | layout1.addWidget(group_box) 1322 | self.setLayout(layout1) 1323 | 1324 | def _map_individuals_to_bodyparts(self): 1325 | self.id2label.clear() # Empty dict so entries are ordered as in the config 1326 | for keypoint in self.store._keypoints: 1327 | label = keypoint.label 1328 | id_ = keypoint.id 1329 | if label not in self.id2label[id_]: 1330 | self.id2label[id_].append(label) 1331 | 1332 | 
def _populate_menus(self): 1333 | id_ = self.store.ids[0] 1334 | if id_: 1335 | menu = create_dropdown_menu(self.store, list(self.id2label), "id") 1336 | menu.currentTextChanged.connect(self.refresh_label_menu) 1337 | self.menus["id"] = menu 1338 | self.menus["label"] = create_dropdown_menu( 1339 | self.store, 1340 | self.id2label[id_], 1341 | "label", 1342 | ) 1343 | 1344 | def _update_items(self): 1345 | id_ = self.store.ids[0] 1346 | if id_: 1347 | self.menus["id"].update_items(list(self.id2label)) 1348 | self.menus["label"].update_items(self.id2label[id_]) 1349 | 1350 | def update_menus(self, event): 1351 | keypoint = self.store.current_keypoint 1352 | for attr, menu in self.menus.items(): 1353 | val = getattr(keypoint, attr) 1354 | if menu.currentText() != val: 1355 | menu.update_to(val) 1356 | 1357 | def refresh_label_menu(self, text: str): 1358 | menu = self.menus["label"] 1359 | menu.blockSignals(True) 1360 | menu.clear() 1361 | menu.blockSignals(False) 1362 | menu.addItems(self.id2label[text]) 1363 | 1364 | def smart_reset(self, event): 1365 | """Set current keypoint to the first unlabeled one.""" 1366 | if self._locked: # The currently selected point is not updated 1367 | return 1368 | unannotated = "" 1369 | already_annotated = self.store.annotated_keypoints 1370 | for keypoint in self.store._keypoints: 1371 | if keypoint not in already_annotated: 1372 | unannotated = keypoint 1373 | break 1374 | self.store.current_keypoint = ( 1375 | unannotated if unannotated else self.store._keypoints[0] 1376 | ) 1377 | 1378 | 1379 | def create_dropdown_menu(store, items, attr): 1380 | menu = DropdownMenu(items) 1381 | 1382 | def item_changed(ind): 1383 | current_item = menu.itemText(ind) 1384 | if current_item is not None: 1385 | setattr(store, f"current_{attr}", current_item) 1386 | 1387 | menu.currentIndexChanged.connect(item_changed) 1388 | return menu 1389 | 1390 | 1391 | # WelcomeWidget modified from: 1392 | # 
https://github.com/napari/napari/blob/a72d512972a274380645dae16b9aa93de38c3ba2/napari/_qt/widgets/qt_welcome.py#L28 1393 | class QtWelcomeWidget(QWidget): 1394 | """Welcome widget to display initial information and shortcuts to user.""" 1395 | 1396 | sig_dropped = Signal("QEvent") 1397 | 1398 | def __init__(self, parent): 1399 | super().__init__(parent) 1400 | 1401 | # Create colored icon using theme 1402 | self._image = QLabel() 1403 | self._image.setObjectName("logo_silhouette") 1404 | self._image.setMinimumSize(300, 300) 1405 | self._label = QtWelcomeLabel( 1406 | """ 1407 | Drop a folder from within a DeepLabCut's labeled-data directory, 1408 | and, if labeling from scratch, 1409 | the corresponding project's config.yaml file. 1410 | """ 1411 | ) 1412 | 1413 | # Widget setup 1414 | self.setAutoFillBackground(True) 1415 | self.setAcceptDrops(True) 1416 | self._image.setAlignment(Qt.AlignCenter) 1417 | self._label.setAlignment(Qt.AlignCenter) 1418 | 1419 | # Layout 1420 | text_layout = QVBoxLayout() 1421 | text_layout.addWidget(self._label) 1422 | 1423 | layout = QVBoxLayout() 1424 | layout.addStretch() 1425 | layout.setSpacing(30) 1426 | layout.addWidget(self._image) 1427 | layout.addLayout(text_layout) 1428 | layout.addStretch() 1429 | 1430 | self.setLayout(layout) 1431 | 1432 | def paintEvent(self, event): 1433 | """Override Qt method. 1434 | 1435 | Parameters 1436 | ---------- 1437 | event : qtpy.QtCore.QEvent 1438 | Event from the Qt context. 1439 | """ 1440 | option = QStyleOption() 1441 | option.initFrom(self) 1442 | p = QPainter(self) 1443 | self.style().drawPrimitive(QStyle.PE_Widget, option, p, self) 1444 | 1445 | def _update_property(self, prop, value): 1446 | """Update properties of widget to update style. 1447 | 1448 | Parameters 1449 | ---------- 1450 | prop : str 1451 | Property name to update. 1452 | value : bool 1453 | Property value to update. 
1454 | """ 1455 | self.setProperty(prop, value) 1456 | self.style().unpolish(self) 1457 | self.style().polish(self) 1458 | 1459 | def dragEnterEvent(self, event): 1460 | """Override Qt method. 1461 | 1462 | Provide style updates on event. 1463 | 1464 | Parameters 1465 | ---------- 1466 | event : qtpy.QtCore.QEvent 1467 | Event from the Qt context. 1468 | """ 1469 | self._update_property("drag", True) 1470 | if event.mimeData().hasUrls(): 1471 | event.accept() 1472 | else: 1473 | event.ignore() 1474 | 1475 | def dragLeaveEvent(self, event): 1476 | """Override Qt method. 1477 | 1478 | Provide style updates on event. 1479 | 1480 | Parameters 1481 | ---------- 1482 | event : qtpy.QtCore.QEvent 1483 | Event from the Qt context. 1484 | """ 1485 | self._update_property("drag", False) 1486 | 1487 | def dropEvent(self, event): 1488 | """Override Qt method. 1489 | 1490 | Provide style updates on event and emit the drop event. 1491 | 1492 | Parameters 1493 | ---------- 1494 | event : qtpy.QtCore.QEvent 1495 | Event from the Qt context. 
1496 | """ 1497 | self._update_property("drag", False) 1498 | self.sig_dropped.emit(event) 1499 | 1500 | 1501 | class ClickableLabel(QLabel): 1502 | clicked = Signal(str) 1503 | 1504 | def __init__(self, text="", color="turquoise", parent=None): 1505 | super().__init__(text, parent) 1506 | self._default_style = self.styleSheet() 1507 | self.color = color 1508 | 1509 | def mousePressEvent(self, event): 1510 | self.clicked.emit(self.text()) 1511 | 1512 | def enterEvent(self, event): 1513 | self.setCursor(QCursor(Qt.PointingHandCursor)) 1514 | self.setStyleSheet(f"color: {self.color}") 1515 | 1516 | def leaveEvent(self, event): 1517 | self.unsetCursor() 1518 | self.setStyleSheet(self._default_style) 1519 | 1520 | 1521 | class LabelPair(QWidget): 1522 | def __init__(self, color: str, name: str, parent: QWidget): 1523 | super().__init__(parent) 1524 | 1525 | self._color = color 1526 | self._part_name = name 1527 | 1528 | self.color_label = QLabel("", parent=self) 1529 | self.part_label = ClickableLabel(name, color=color, parent=self) 1530 | 1531 | self.color_label.setToolTip(name) 1532 | self.part_label.setToolTip(name) 1533 | 1534 | self._format_label(self.color_label, 10, 10) 1535 | self._format_label(self.part_label) 1536 | 1537 | self.color_label.setStyleSheet(f"background-color: {color};") 1538 | 1539 | self._build() 1540 | 1541 | @staticmethod 1542 | def _format_label(label: QLabel, height: int = None, width: int = None): 1543 | label.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) 1544 | if height is not None: 1545 | label.setMaximumHeight(height) 1546 | if width is not None: 1547 | label.setMaximumWidth(width) 1548 | 1549 | def _build(self): 1550 | layout = QHBoxLayout() 1551 | layout.addWidget(self.color_label, alignment=Qt.AlignmentFlag.AlignLeft) 1552 | layout.addWidget(self.part_label, alignment=Qt.AlignmentFlag.AlignLeft) 1553 | self.setLayout(layout) 1554 | 1555 | @property 1556 | def color(self): 1557 | return self._color 1558 | 1559 | @color.setter 
1560 | def color(self, color: str): 1561 | self._color = color 1562 | self.color_label.setStyleSheet(f"background-color: {color};") 1563 | 1564 | @property 1565 | def part_name(self): 1566 | return self._part_name 1567 | 1568 | @part_name.setter 1569 | def part_name(self, part_name: str): 1570 | self._part_name = part_name 1571 | self.part_label.setText(part_name) 1572 | self.part_label.setToolTip(part_name) 1573 | self.color_label.setToolTip(part_name) 1574 | 1575 | 1576 | class ColorSchemeDisplay(QScrollArea): 1577 | added = Signal(object) 1578 | 1579 | def __init__(self, parent): 1580 | super().__init__(parent) 1581 | 1582 | self.scheme_dict = {} # {name: color} mapping 1583 | self._layout = QVBoxLayout() 1584 | self._layout.setSpacing(0) 1585 | self._container = QWidget( 1586 | parent=self 1587 | ) # workaround to use setWidget, let me know if there's a better option 1588 | 1589 | self._build() 1590 | 1591 | @property 1592 | def labels(self): 1593 | labels = [] 1594 | for i in range(self._layout.count()): 1595 | item = self._layout.itemAt(i) 1596 | if w := item.widget(): 1597 | labels.append(w) 1598 | return labels 1599 | 1600 | def _build(self): 1601 | self._container.setSizePolicy( 1602 | QSizePolicy.Fixed, QSizePolicy.Maximum 1603 | ) # feel free to change those 1604 | self._container.setLayout(self._layout) 1605 | self._container.adjustSize() 1606 | 1607 | self.setWidget(self._container) 1608 | 1609 | self.setWidgetResizable(True) 1610 | self.setSizePolicy( 1611 | QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding 1612 | ) # feel free to change those 1613 | # self.setMaximumHeight(150) 1614 | self.setBaseSize(100, 200) 1615 | 1616 | self.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOn) 1617 | self.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) 1618 | 1619 | def add_entry(self, name, color): 1620 | self.scheme_dict.update({name: color}) 1621 | 1622 | widget = LabelPair(color, name, self) 1623 | 
self._layout.addWidget(widget, alignment=Qt.AlignmentFlag.AlignLeft) 1624 | self.added.emit(widget) 1625 | 1626 | def update_color_scheme(self, new_color_scheme) -> None: 1627 | logging.debug(f"Updating color scheme: {self._layout.count()} widgets") 1628 | self.scheme_dict = {name: color for name, color in new_color_scheme.items()} 1629 | names = list(new_color_scheme.keys()) 1630 | existing_widgets = self._layout.count() 1631 | required_widgets = len(self.scheme_dict) 1632 | 1633 | # update existing widgets 1634 | for idx in range(min(existing_widgets, required_widgets)): 1635 | logging.debug(f" updating {idx}") 1636 | w = self._layout.itemAt(idx).widget() 1637 | w.setVisible(True) 1638 | w.part_name = names[idx] 1639 | w.color = self.scheme_dict[names[idx]] 1640 | 1641 | # remove extra widgets 1642 | for i in range(max(existing_widgets - required_widgets, 0)): 1643 | logging.debug(f" hiding {required_widgets + i}") 1644 | if w := self._layout.itemAt(required_widgets + i).widget(): 1645 | logging.debug(f" done!") 1646 | w.setVisible(False) 1647 | 1648 | # add missing widgets 1649 | for i in range(max(required_widgets - existing_widgets, 0)): 1650 | logging.debug(f" adding {existing_widgets + i}") 1651 | name = names[existing_widgets + i] 1652 | self.add_entry(name, self.scheme_dict[name]) 1653 | logging.debug(f" done!") 1654 | 1655 | def reset(self): 1656 | self.scheme_dict = {} 1657 | for i in range(self._layout.count()): 1658 | w = self._layout.itemAt(i).widget() 1659 | logging.debug(f"making {w} invisible") 1660 | w.setVisible(False) 1661 | -------------------------------------------------------------------------------- /src/napari_deeplabcut/_writer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import groupby 3 | from pathlib import Path 4 | 5 | import pandas as pd 6 | import yaml 7 | from napari.layers import Shapes 8 | from napari_builtins.io import napari_write_shapes 9 | from 
skimage.io import imsave 10 | from skimage.util import img_as_ubyte 11 | 12 | from napari_deeplabcut import misc 13 | from napari_deeplabcut._reader import _load_config 14 | 15 | 16 | def _write_config(config_path: str, params: dict): 17 | with open(config_path, "w") as file: 18 | yaml.safe_dump(params, file) 19 | 20 | 21 | def _form_df(points_data, metadata): 22 | temp = pd.DataFrame(points_data[:, -1:0:-1], columns=["x", "y"]) 23 | properties = metadata["properties"] 24 | meta = metadata["metadata"] 25 | temp["bodyparts"] = properties["label"] 26 | temp["individuals"] = properties["id"] 27 | temp["inds"] = points_data[:, 0].astype(int) 28 | temp["likelihood"] = properties["likelihood"] 29 | temp["scorer"] = meta["header"].scorer 30 | df = temp.set_index(["scorer", "individuals", "bodyparts", "inds"]).stack() 31 | df.index.set_names("coords", level=-1, inplace=True) 32 | df = df.unstack(["scorer", "individuals", "bodyparts", "coords"]) 33 | df.index.name = None 34 | if not properties["id"][0]: 35 | df = df.droplevel("individuals", axis=1) 36 | df = df.reindex(meta["header"].columns, axis=1) 37 | # Fill unannotated rows with NaNs 38 | # df = df.reindex(range(len(meta['paths']))) 39 | # df.index = meta['paths'] 40 | if meta["paths"]: 41 | df.index = [meta["paths"][i] for i in df.index] 42 | misc.guarantee_multiindex_rows(df) 43 | return df 44 | 45 | 46 | def write_hdf(filename, data, metadata): 47 | file, _ = os.path.splitext(filename) # FIXME Unused currently 48 | df = _form_df(data, metadata) 49 | meta = metadata["metadata"] 50 | name = metadata["name"] 51 | root = meta["root"] 52 | if "machine" in name: # We are attempting to save refined model predictions 53 | df.drop("likelihood", axis=1, level="coords", inplace=True, errors="ignore") 54 | header = misc.DLCHeader(df.columns) 55 | gt_file = "" 56 | for file in os.listdir(root): 57 | if file.startswith("CollectedData") and file.endswith("h5"): 58 | gt_file = file 59 | break 60 | if gt_file: # Refined predictions 
must be merged into the existing data 61 | df_gt = pd.read_hdf(os.path.join(root, gt_file)) 62 | new_scorer = df_gt.columns.get_level_values("scorer")[0] 63 | header.scorer = new_scorer 64 | df.columns = header.columns 65 | df = pd.concat((df, df_gt)) 66 | df = df[~df.index.duplicated(keep="first")] 67 | name = os.path.splitext(gt_file)[0] 68 | else: 69 | # Let us fetch the config.yaml file to get the scorer name... 70 | project_folder = Path(root).parents[1] 71 | config = _load_config(str(project_folder / "config.yaml")) 72 | new_scorer = config["scorer"] 73 | header.scorer = new_scorer 74 | df.columns = header.columns 75 | name = f"CollectedData_{new_scorer}" 76 | df.sort_index(inplace=True) 77 | filename = name + ".h5" 78 | path = os.path.join(root, filename) 79 | df.to_hdf(path, key="keypoints", mode="w") 80 | df.to_csv(path.replace(".h5", ".csv")) 81 | return filename 82 | 83 | 84 | def _write_image(data, output_path, plugin=None): 85 | Path(output_path).parent.mkdir(parents=True, exist_ok=True) 86 | imsave( 87 | output_path, 88 | img_as_ubyte(data).squeeze(), 89 | plugin=plugin, 90 | check_contrast=False, 91 | ) 92 | 93 | 94 | def write_masks(foldername, data, metadata): 95 | folder, _ = os.path.splitext(foldername) 96 | os.makedirs(folder, exist_ok=True) 97 | filename = os.path.join(folder, "{}_obj_{}.png") 98 | shapes = Shapes(data, shape_type="polygon") 99 | meta = metadata["metadata"] 100 | frame_inds = [int(array[0, 0]) for array in data] 101 | shape_inds = [] 102 | for _, group in groupby(frame_inds): 103 | shape_inds += range(sum(1 for _ in group)) 104 | masks = shapes.to_masks(mask_shape=meta["shape"][1:]) 105 | for n, mask in enumerate(masks): 106 | image_name = os.path.basename(meta["paths"][frame_inds[n]]) 107 | output_path = filename.format(os.path.splitext(image_name)[0], shape_inds[n]) 108 | _write_image(mask, output_path) 109 | napari_write_shapes(os.path.join(folder, "vertices.csv"), data, metadata) 110 | return folder 111 | 
-------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/black/Back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/black/Back.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/black/Customize.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/black/Customize.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/black/Forward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/black/Forward.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/black/Home.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/black/Home.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/black/Pan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/black/Pan.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/black/Pan_checked.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/black/Pan_checked.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/black/Save.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/black/Save.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/black/Subplots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/black/Subplots.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/black/Zoom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/black/Zoom.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/black/Zoom_checked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/black/Zoom_checked.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/superanimal_quadruped.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/superanimal_quadruped.jpg -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/superanimal_quadruped.json: -------------------------------------------------------------------------------- 1 | { 2 | "nose": [ 3 | 155.19781918649664, 4 | 458.4043683559347 5 | ], 6 | "upper_jaw": [ 7 | 97.40751997750615, 8 | 458.4043683559347 9 | ], 10 | "lower_jaw": [ 11 | 35.585339428353535, 12 | 457.06040790921395 13 | ], 14 | "mouth_end_right": [ 15 | 79.93603417013694, 16 | 551.1376391796636 17 | ], 18 | "mouth_end_left": [ 19 | 81.27999461685764, 20 | 376.42278110597147 21 | ], 22 | "right_eye": [ 23 | 223.7398019692528, 24 | 518.8825884583666 25 | ], 26 | "right_earbase": [ 27 | 289.5938638585676, 28 | 555.1695205198257 29 | ], 30 | "right_earend": [ 31 | 332.60059815363024, 32 | 630.4313055361854 33 | ], 34 | "right_antler_base": [ 35 | 321.8489145798645, 36 | 489.315458630511 37 | ], 38 | "right_antler_end": [ 39 | 364.8556488749273, 40 | 513.5067466714837 41 | ], 42 | "left_eye": [ 43 | 225.08376241597355, 44 | 407.33387138054775 45 | ], 46 | "left_earbase": [ 47 | 290.93782430528825, 48 | 373.73486021252995 49 | ], 50 | "left_earend": [ 51 | 332.60059815363024, 52 | 302.50495653633243 53 | ], 54 | "left_antler_base": [ 55 | 320.5049541331439, 56 | 430.1811989747998 57 | ], 58 | "left_antler_end": [ 59 | 362.1677279814859, 60 | 405.989910933827 61 | ], 62 | "neck_base": [ 63 | 413.23822495687284, 64 | 458.4043683559347 65 | ], 66 | "neck_end": [ 67 | 505.9714957806017, 68 | 458.4043683559347 69 | ], 70 | "throat_base": [ 71 | 472.372484612584, 72 | 275.62574760191825 73 | ], 74 | "throat_end": [ 75 | 503.2835748871603, 76 | 350.887532618278 77 | ], 78 | "back_base": [ 79 | 583.9212016904029, 80 | 459.74832880265535 81 | ], 82 | "back_end": [ 83 | 776.1075455714642, 84 | 458.4043683559347 
85 | ], 86 | "back_middle": [ 87 | 718.3172463624737, 88 | 458.4043683559347 89 | ], 90 | "tail_base": [ 91 | 831.2099238870134, 92 | 457.06040790921395 93 | ], 94 | "tail_end": [ 95 | 933.3509178377872, 96 | 459.74832880265535 97 | ], 98 | "front_left_thai": [ 99 | 582.5772412436821, 100 | 324.00832368386375 101 | ], 102 | "front_left_knee": [ 103 | 579.8893203502407, 104 | 244.7146573273419 105 | ], 106 | "front_left_paw": [ 107 | 581.2332807969614, 108 | 160.04514918393727 109 | ], 110 | "front_right_thai": [ 111 | 581.2332807969614, 112 | 590.1124921345641 113 | ], 114 | "front_right_knee": [ 115 | 581.2332807969614, 116 | 672.0940793845274 117 | ], 118 | "front_right_paw": [ 119 | 582.5772412436821, 120 | 755.4196270812114 121 | ], 122 | "back_left_paw": [ 123 | 829.8659634402926, 124 | 161.38910963065797 125 | ], 126 | "back_left_thai": [ 127 | 831.2099238870134, 128 | 334.76000725762947 129 | ], 130 | "back_right_thai": [ 131 | 831.2099238870134, 132 | 582.04872945424 133 | ], 134 | "back_left_knee": [ 135 | 829.8659634402926, 136 | 247.40257822078337 137 | ], 138 | "back_right_knee": [ 139 | 829.8659634402926, 140 | 668.0621980443653 141 | ], 142 | "back_right_paw": [ 143 | 831.2099238870134, 144 | 756.7635875279321 145 | ], 146 | "belly_bottom": [ 147 | 782.8273478050678, 148 | 537.6980347124565 149 | ], 150 | "body_middle_right": [ 151 | 710.2534836821495, 152 | 561.8893227534293 153 | ], 154 | "body_middle_left": [ 155 | 707.5655627887081, 156 | 375.0788206592507 157 | ] 158 | } -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/superanimal_topviewmouse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/superanimal_topviewmouse.jpg -------------------------------------------------------------------------------- 
/src/napari_deeplabcut/assets/superanimal_topviewmouse.json: -------------------------------------------------------------------------------- 1 | { 2 | "nose": [ 3 | 21.659348770669567, 4 | 201.9571852019767 5 | ], 6 | "left_ear": [ 7 | 140.92800399864308, 8 | 213.7160667033262 9 | ], 10 | "right_ear": [ 11 | 103.97151928011608, 12 | 301.0677578562082 13 | ], 14 | "left_ear_tip": [ 15 | 177.8844887171701, 16 | 185.15878305719173 17 | ], 18 | "right_ear_tip": [ 19 | 119.09008121042258, 20 | 356.5024849339987 21 | ], 22 | "left_eye": [ 23 | 92.21263777876659, 24 | 196.91766455854122 25 | ], 26 | "right_eye": [ 27 | 65.33519434711059, 28 | 274.19031442455224 29 | ], 30 | "neck": [ 31 | 194.68289086195512, 32 | 279.2298350679877 33 | ], 34 | "mid_back": [ 35 | 251.7974581542241, 36 | 265.7911133521597 37 | ], 38 | "mouse_center": [ 39 | 298.83298415962213, 40 | 252.35239163633167 41 | ], 42 | "mid_backend": [ 43 | 344.1886699505416, 44 | 235.5539894915467 45 | ], 46 | "mid_backend2": [ 47 | 404.6629176717676, 48 | 215.39590691780472 49 | ], 50 | "mid_backend3": [ 51 | 460.0976447495581, 52 | 195.2378243440627 53 | ], 54 | "tail_base": [ 55 | 518.8920522563056, 56 | 166.68054069792822 57 | ], 58 | "tail1": [ 59 | 592.8050216933597, 60 | 144.8426179097077 61 | ], 62 | "tail2": [ 63 | 665.0381509159353, 64 | 129.72405597940121 65 | ], 66 | "tail3": [ 67 | 732.2317594950753, 68 | 148.20229833866472 69 | ], 70 | "tail4": [ 71 | 777.5874452859947, 72 | 173.39990155584218 73 | ], 74 | "tail5": [ 75 | 836.3818527927423, 76 | 185.15878305719173 77 | ], 78 | "left_shoulder": [ 79 | 214.8409734356971, 80 | 193.5579841295842 81 | ], 82 | "left_midside": [ 83 | 288.7539428727511, 84 | 166.68054069792822 85 | ], 86 | "left_hip": [ 87 | 374.4257938111546, 88 | 87.72805061743867 89 | ], 90 | "right_shoulder": [ 91 | 213.16113322121862, 92 | 351.4629642905632 93 | ], 94 | "right_midside": [ 95 | 335.78946887814914, 96 | 332.9847219312997 97 | ], 98 | "right_hip": [ 99 | 
510.49285118391316, 100 | 264.11127313768117 101 | ], 102 | "tail_end": [ 103 | 883.4173787981402, 104 | 173.39990155584218 105 | ], 106 | "head_midpoint": [ 107 | 117.41024099594408, 108 | 252.35239163633167 109 | ] 110 | } -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/white/Back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/white/Back.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/white/Customize.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/white/Customize.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/white/Forward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/white/Forward.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/white/Home.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/white/Home.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/white/Pan.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/white/Pan.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/white/Pan_checked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/white/Pan_checked.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/white/Save.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/white/Save.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/white/Subplots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/white/Subplots.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/white/Zoom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/white/Zoom.png -------------------------------------------------------------------------------- /src/napari_deeplabcut/assets/white/Zoom_checked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepLabCut/napari-deeplabcut/0ceba4f203135aeeeaf8edfb8a8689d0517a5e1e/src/napari_deeplabcut/assets/white/Zoom_checked.png 
-------------------------------------------------------------------------------- /src/napari_deeplabcut/keypoints.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from enum import auto 3 | from typing import List, Sequence 4 | 5 | import numpy as np 6 | from napari._qt.layer_controls.qt_points_controls import QtPointsControls 7 | from napari.layers import Points 8 | from napari.layers.points._points_constants import SYMBOL_TRANSLATION_INVERTED 9 | from napari.layers.points._points_utils import coerce_symbols 10 | from scipy.spatial import cKDTree 11 | 12 | from napari_deeplabcut.misc import CycleEnum 13 | 14 | 15 | # Monkeypatch the point size slider 16 | def _change_size(self, value): 17 | """Resize all points at once regardless of the current selection.""" 18 | self.layer._current_size = value 19 | if self.layer._update_properties: 20 | self.layer.size = (self.layer.size > 0) * value 21 | self.layer.refresh() 22 | self.layer.events.size() 23 | 24 | 25 | def _change_symbol(self, text): 26 | symbol = coerce_symbols(np.array([SYMBOL_TRANSLATION_INVERTED[text]]))[0] 27 | self.layer._current_symbol = symbol 28 | if self.layer._update_properties: 29 | self.layer.symbol = symbol 30 | self.layer.events.symbol() 31 | self.layer.events.current_symbol() 32 | 33 | 34 | QtPointsControls.changeCurrentSize = _change_size 35 | QtPointsControls.changeCurrentSymbol = _change_symbol 36 | 37 | 38 | class ColorMode(CycleEnum): 39 | """Modes in which keypoints can be colored 40 | 41 | BODYPART: the keypoints are grouped by bodypart (all bodyparts have the same color) 42 | INDIVIDUAL: the keypoints are grouped by individual (all keypoints for the same 43 | individual have the same color) 44 | """ 45 | 46 | BODYPART = auto() 47 | INDIVIDUAL = auto() 48 | 49 | @classmethod 50 | def default(cls): 51 | return cls.BODYPART 52 | 53 | 54 | class LabelMode(CycleEnum): 55 | """ 56 | Labeling modes. 
57 | SEQUENTIAL: points are placed in sequence, then frame after frame; 58 | clicking to add an already annotated point has no effect. 59 | QUICK: similar to SEQUENTIAL, but trying to add an already 60 | annotated point actually moves it to the cursor location. 61 | LOOP: the currently selected point is placed frame after frame, 62 | before wrapping at the end to frame 1, etc. 63 | """ 64 | 65 | SEQUENTIAL = auto() 66 | QUICK = auto() 67 | LOOP = auto() 68 | 69 | @classmethod 70 | def default(cls): 71 | return cls.SEQUENTIAL 72 | 73 | 74 | # Description tooltips for the labeling modes radio buttons. 75 | TOOLTIPS = { 76 | "SEQUENTIAL": "Points are placed in sequence, then frame after frame;\n" 77 | "clicking to add an already annotated point has no effect.", 78 | "QUICK": "Similar to SEQUENTIAL, but trying to add an already\n" 79 | "annotated point actually moves it to the cursor location.", 80 | "LOOP": "The currently selected point is placed frame after frame,\n" 81 | "before wrapping at the end to frame 1, etc.", 82 | } 83 | 84 | 85 | Keypoint = namedtuple("Keypoint", ["label", "id"]) 86 | 87 | 88 | class KeypointStore: 89 | def __init__(self, viewer, layer: Points): 90 | self.viewer = viewer 91 | self._keypoints = [] 92 | self.layer = layer 93 | self.viewer.dims.set_current_step(0, 0) 94 | 95 | @property 96 | def layer(self): 97 | return self._layer 98 | 99 | @layer.setter 100 | def layer(self, layer): 101 | self._layer = layer 102 | all_pairs = self.layer.metadata["header"].form_individual_bodypart_pairs() 103 | self._keypoints = [ 104 | Keypoint(label, id_) for id_, label in all_pairs 105 | ] # Ordered references to all possible keypoints 106 | 107 | @property 108 | def current_step(self): 109 | return self.viewer.dims.current_step[0] 110 | 111 | @property 112 | def n_steps(self): 113 | return self.viewer.dims.nsteps[0] 114 | 115 | @property 116 | def annotated_keypoints(self) -> List[Keypoint]: 117 | mask = self.current_mask 118 | labels = 
self.layer.properties["label"][mask] 119 | ids = self.layer.properties["id"][mask] 120 | return [Keypoint(label, id_) for label, id_ in zip(labels, ids)] 121 | 122 | @property 123 | def current_mask(self) -> Sequence[bool]: 124 | return np.asarray(self.layer.data[:, 0] == self.current_step) 125 | 126 | @property 127 | def current_keypoint(self) -> Keypoint: 128 | props = self.layer.current_properties 129 | return Keypoint(label=props["label"][0], id=props["id"][0]) 130 | 131 | @current_keypoint.setter 132 | def current_keypoint(self, keypoint: Keypoint): 133 | # Avoid changing the properties of a selected point 134 | if not len(self.layer.selected_data): 135 | current_properties = self.layer.current_properties 136 | current_properties["label"] = np.asarray([keypoint.label]) 137 | current_properties["id"] = np.asarray([keypoint.id]) 138 | self.layer.current_properties = current_properties 139 | 140 | def next_keypoint(self, *args): 141 | ind = self._keypoints.index(self.current_keypoint) + 1 142 | if ind <= len(self._keypoints) - 1: 143 | self.current_keypoint = self._keypoints[ind] 144 | 145 | def prev_keypoint(self, *args): 146 | ind = self._keypoints.index(self.current_keypoint) - 1 147 | if ind >= 0: 148 | self.current_keypoint = self._keypoints[ind] 149 | 150 | @property 151 | def labels(self) -> List[str]: 152 | return self.layer.metadata["header"].bodyparts 153 | 154 | @property 155 | def current_label(self) -> str: 156 | return self.layer.current_properties["label"][0] 157 | 158 | @current_label.setter 159 | def current_label(self, label: str): 160 | if not len(self.layer.selected_data): 161 | current_properties = self.layer.current_properties 162 | current_properties["label"] = np.asarray([label]) 163 | self.layer.current_properties = current_properties 164 | 165 | @property 166 | def ids(self) -> List[str]: 167 | return self.layer.metadata["header"].individuals 168 | 169 | @property 170 | def current_id(self) -> str: 171 | return 
self.layer.current_properties["id"][0] 172 | 173 | @current_id.setter 174 | def current_id(self, id_: str): 175 | if not len(self.layer.selected_data): 176 | current_properties = self.layer.current_properties 177 | current_properties["id"] = np.asarray([id_]) 178 | self.layer.current_properties = current_properties 179 | 180 | def _advance_step(self, event): 181 | ind = (self.current_step + 1) % self.n_steps 182 | self.viewer.dims.set_current_step(0, ind) 183 | 184 | def _find_first_unlabeled_frame(self, event): 185 | inds = set(range(self.n_steps)) 186 | unlabeled_inds = inds.difference(self.layer.data[:, 0].astype(int)) 187 | if not unlabeled_inds: 188 | self.viewer.dims.set_current_step(0, self.n_steps - 1) 189 | else: 190 | self.viewer.dims.set_current_step(0, min(unlabeled_inds)) 191 | 192 | 193 | def _add(store, coord): 194 | if store.current_keypoint not in store.annotated_keypoints: 195 | store.layer.data = np.append( 196 | store.layer.data, 197 | np.atleast_2d(coord), 198 | axis=0, 199 | ) 200 | elif store.layer.metadata["controls"]._label_mode is LabelMode.QUICK: 201 | ind = store.annotated_keypoints.index(store.current_keypoint) 202 | data = store.layer.data 203 | data[np.flatnonzero(store.current_mask)[ind]] = coord 204 | store.layer.data = data 205 | store.layer.selected_data = set() 206 | if store.layer.metadata["controls"]._label_mode is LabelMode.LOOP: 207 | store.layer.events.query_next_frame() 208 | else: 209 | store.next_keypoint() 210 | 211 | 212 | def _find_nearest_neighbors(xy_true, xy_pred, k=5): 213 | n_preds = xy_pred.shape[0] 214 | tree = cKDTree(xy_pred) 215 | dist, inds = tree.query(xy_true, k=k) 216 | idx = np.argsort(dist[:, 0]) 217 | neighbors = np.full(len(xy_true), -1, dtype=int) 218 | picked = set() 219 | for i, ind in enumerate(inds[idx]): 220 | for j in ind: 221 | if j not in picked: 222 | picked.add(j) 223 | neighbors[idx[i]] = j 224 | break 225 | if len(picked) == n_preds: 226 | break 227 | return neighbors 228 | 
-------------------------------------------------------------------------------- /src/napari_deeplabcut/misc.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from enum import Enum, EnumMeta 5 | from itertools import cycle 6 | from pathlib import Path 7 | from typing import Dict, List, Optional, Sequence, Tuple, Union 8 | 9 | import numpy as np 10 | import pandas as pd 11 | from napari.utils import colormaps 12 | 13 | 14 | def find_project_config_path(labeled_data_path: str) -> str: 15 | return str(Path(labeled_data_path).parents[2] / "config.yaml") 16 | 17 | 18 | def is_latest_version(): 19 | import json 20 | import urllib.request 21 | from napari_deeplabcut import __version__ 22 | 23 | url = "https://pypi.org/pypi/napari-deeplabcut/json" 24 | contents = urllib.request.urlopen(url).read() 25 | latest_version = json.loads(contents)["info"]["version"] 26 | return __version__ == latest_version, latest_version 27 | 28 | 29 | 30 | def unsorted_unique(array: Sequence) -> np.ndarray: 31 | """Return the unsorted unique elements of an array.""" 32 | _, inds = np.unique(array, return_index=True) 33 | return np.asarray(array)[np.sort(inds)] 34 | 35 | 36 | def encode_categories( 37 | categories: List[str], return_map: bool = False 38 | ) -> Union[List[int], Tuple[List[int], Dict]]: 39 | unique_cat = unsorted_unique(categories) 40 | map_ = dict(zip(unique_cat, range(len(unique_cat)))) 41 | inds = np.vectorize(map_.get)(categories) 42 | if return_map: 43 | return inds, map_ 44 | return inds 45 | 46 | 47 | def merge_multiple_scorers( 48 | df: pd.DataFrame, 49 | ) -> pd.DataFrame: 50 | n_frames = df.shape[0] 51 | header = DLCHeader(df.columns) 52 | n_scorers = len(header._get_unique("scorer")) 53 | if n_scorers == 1: 54 | return df 55 | 56 | if "likelihood" in header.coords: 57 | # Merge annotations from multiple scorers to keep 58 | # detections with highest confidence 59 | data = 
df.to_numpy().reshape((n_frames, n_scorers, -1, 3)) 60 | try: 61 | idx = np.nanargmax(data[..., 2], axis=1) 62 | except ValueError: # All-NaN slice encountered 63 | mask = np.isnan(data[..., 2]).all(axis=1, keepdims=True) 64 | mask = np.broadcast_to(mask[..., None], data.shape) 65 | data[mask] = -1 66 | idx = np.nanargmax(data[..., 2], axis=1) 67 | data[mask] = np.nan 68 | data_best = data[ 69 | np.arange(n_frames)[:, None], idx, np.arange(data.shape[2]) 70 | ].reshape((n_frames, -1)) 71 | df = pd.DataFrame( 72 | data_best, 73 | index=df.index, 74 | columns=header.columns[: data_best.shape[1]], 75 | ) 76 | else: # Arbitrarily pick data from the first scorer 77 | df = df.loc(axis=1)[: header.scorer] 78 | return df 79 | 80 | 81 | def to_os_dir_sep(path: str) -> str: 82 | """ 83 | Replace all directory separators in `path` with `os.path.sep`. 84 | Function originally written by @pyzun: 85 | https://github.com/DeepLabCut/napari-DeepLabCut/pull/13 86 | 87 | Raises 88 | ------ 89 | ValueError: if `path` contains both UNIX and Windows directory separators. 90 | 91 | """ 92 | win_sep, unix_sep = "\\", "/" 93 | 94 | # On UNIX systems, `win_sep` is a valid character in directory and file 95 | # names. This function fails if both are present. 
96 | if win_sep in path and unix_sep in path: 97 | raise ValueError(f'"{path}" may not contain both "{win_sep}" and "{unix_sep}"!') 98 | 99 | sep = win_sep if win_sep in path else unix_sep 100 | 101 | return os.path.sep.join(path.split(sep)) 102 | 103 | 104 | def guarantee_multiindex_rows(df): 105 | # Make paths platform-agnostic if they are not already 106 | if not isinstance(df.index, pd.MultiIndex): # Backwards compatibility 107 | path = df.index[0] 108 | try: 109 | sep = "/" if "/" in path else "\\" 110 | splits = tuple(df.index.str.split(sep)) 111 | df.index = pd.MultiIndex.from_tuples(splits) 112 | except TypeError: # Ignore numerical index of frame indices 113 | pass 114 | 115 | 116 | def build_color_cycle(n_colors: int, colormap: Optional[str] = "viridis") -> np.ndarray: 117 | cmap = colormaps.ensure_colormap(colormap) 118 | return cmap.map(np.linspace(0, 1, n_colors)) 119 | 120 | 121 | def build_color_cycles(header: DLCHeader, colormap: Optional[str] = "viridis"): 122 | label_colors = build_color_cycle(len(header.bodyparts), colormap) 123 | id_colors = build_color_cycle(len(header.individuals), colormap) 124 | return { 125 | "label": dict(zip(header.bodyparts, label_colors)), 126 | "id": dict(zip(header.individuals, id_colors)), 127 | } 128 | 129 | 130 | class DLCHeader: 131 | def __init__(self, columns: pd.MultiIndex): 132 | self.columns = columns 133 | 134 | @classmethod 135 | def from_config(cls, config: Dict) -> DLCHeader: 136 | multi = config.get("multianimalproject", False) 137 | scorer = [config["scorer"]] 138 | if multi: 139 | columns = pd.MultiIndex.from_product( 140 | [ 141 | scorer, 142 | config["individuals"], 143 | config["multianimalbodyparts"], 144 | ["x", "y"], 145 | ] 146 | ) 147 | if len(config["uniquebodyparts"]): 148 | temp = pd.MultiIndex.from_product( 149 | [scorer, ["single"], config["uniquebodyparts"], ["x", "y"]] 150 | ) 151 | columns = columns.append(temp) 152 | columns.set_names( 153 | ["scorer", "individuals", "bodyparts", 
"coords"], inplace=True 154 | ) 155 | else: 156 | columns = pd.MultiIndex.from_product( 157 | [scorer, config["bodyparts"], ["x", "y"]], 158 | names=["scorer", "bodyparts", "coords"], 159 | ) 160 | return cls(columns) 161 | 162 | def form_individual_bodypart_pairs(self) -> List[Tuple[str]]: 163 | to_drop = [ 164 | name 165 | for name in self.columns.names 166 | if name not in ("individuals", "bodyparts") 167 | ] 168 | temp = self.columns.droplevel(to_drop).unique() 169 | if "individuals" not in temp.names: 170 | temp = pd.MultiIndex.from_product([self.individuals, temp]) 171 | return temp.to_list() 172 | 173 | @property 174 | def scorer(self) -> str: 175 | return self._get_unique("scorer")[0] 176 | 177 | @scorer.setter 178 | def scorer(self, scorer: str): 179 | self.columns = self.columns.set_levels([scorer], level="scorer") 180 | 181 | @property 182 | def individuals(self) -> List[str]: 183 | individuals = self._get_unique("individuals") 184 | if individuals is None: 185 | return [""] 186 | return individuals 187 | 188 | @property 189 | def bodyparts(self) -> List[str]: 190 | return self._get_unique("bodyparts") 191 | 192 | @property 193 | def coords(self) -> List[str]: 194 | return self._get_unique("coords") 195 | 196 | def _get_unique(self, name: str) -> Optional[List]: 197 | if name in self.columns.names: 198 | return list(unsorted_unique(self.columns.get_level_values(name))) 199 | return None 200 | 201 | 202 | class CycleEnumMeta(EnumMeta): 203 | def __new__(metacls, cls, bases, classdict): 204 | enum_ = super().__new__(metacls, cls, bases, classdict) 205 | enum_._cycle = cycle(enum_._member_map_[name] for name in enum_._member_names_) 206 | return enum_ 207 | 208 | def __iter__(cls): 209 | return cls._cycle 210 | 211 | def __next__(cls): 212 | return next(cls.__iter__()) 213 | 214 | def __getitem__(self, item): 215 | if isinstance(item, str): 216 | item = item.upper() 217 | return super().__getitem__(item) 218 | 219 | 220 | class CycleEnum(Enum, 
metaclass=CycleEnumMeta): 221 | def _generate_next_value_(name, start, count, last_values): 222 | return name.lower() 223 | 224 | def __str__(self): 225 | return self.value 226 | -------------------------------------------------------------------------------- /src/napari_deeplabcut/napari.yaml: -------------------------------------------------------------------------------- 1 | name: napari-deeplabcut 2 | display_name: napari DeepLabCut 3 | contributions: 4 | commands: 5 | - id: napari-deeplabcut.get_hdf_reader 6 | python_name: napari_deeplabcut._reader:get_hdf_reader 7 | title: Open data with napari DeepLabCut 8 | - id: napari-deeplabcut.get_image_reader 9 | python_name: napari_deeplabcut._reader:get_image_reader 10 | title: Open images with napari DeepLabCut 11 | - id: napari-deeplabcut.get_video_reader 12 | python_name: napari_deeplabcut._reader:get_video_reader 13 | title: Open videos with napari DeepLabCut 14 | - id: napari-deeplabcut.get_folder_parser 15 | python_name: napari_deeplabcut._reader:get_folder_parser 16 | title: Open folder with napari DeepLabCut 17 | - id: napari-deeplabcut.get_config_reader 18 | python_name: napari_deeplabcut._reader:get_config_reader 19 | title: Open config with napari DeepLabCut 20 | - id: napari-deeplabcut.write_hdf 21 | python_name: napari_deeplabcut._writer:write_hdf 22 | title: Save keypoint annotations with napari DeepLabCut 23 | - id: napari-deeplabcut.write_masks 24 | python_name: napari_deeplabcut._writer:write_masks 25 | title: Save segmentation masks with napari DeepLabCut 26 | - id: napari-deeplabcut.make_keypoint_controls 27 | python_name: napari_deeplabcut._widgets:KeypointControls 28 | title: Make keypoint controls 29 | readers: 30 | - command: napari-deeplabcut.get_hdf_reader 31 | accepts_directories: false 32 | filename_patterns: ['*.h5'] 33 | - command: napari-deeplabcut.get_image_reader 34 | accepts_directories: true 35 | filename_patterns: ['*.png', '*.jpg', '*.jpeg'] 36 | - command: 
napari-deeplabcut.get_video_reader 37 | accepts_directories: true 38 | filename_patterns: ['*.mp4', '*.mov', '*.avi'] 39 | - command: napari-deeplabcut.get_config_reader 40 | accepts_directories: false 41 | filename_patterns: ['*.yaml'] 42 | - command: napari-deeplabcut.get_folder_parser 43 | accepts_directories: true 44 | filename_patterns: ['*'] 45 | writers: 46 | - command: napari-deeplabcut.write_hdf 47 | layer_types: ["points{1}"] 48 | filename_extensions: [".h5"] 49 | - command: napari-deeplabcut.write_masks 50 | layer_types: ["shapes{1}"] 51 | filename_extensions: [".csv"] 52 | widgets: 53 | - command: napari-deeplabcut.make_keypoint_controls 54 | display_name: Keypoint controls 55 | -------------------------------------------------------------------------------- /src/napari_deeplabcut/styles/dark.mplstyle: -------------------------------------------------------------------------------- 1 | # Dark-theme napari colour scheme for matplotlib plots 2 | 3 | # text (very light grey - almost white): #f0f1f2 4 | # foreground (mid grey): #414851 5 | # background (dark blue-gray): #262930 6 | 7 | figure.facecolor : none 8 | axes.labelcolor : f0f1f2 9 | axes.facecolor : none 10 | axes.edgecolor : 414851 11 | xtick.color : f0f1f2 12 | ytick.color : f0f1f2 -------------------------------------------------------------------------------- /src/napari_deeplabcut/styles/light.mplstyle: -------------------------------------------------------------------------------- 1 | # Light-theme napari colour scheme for matplotlib plots 2 | 3 | # text (very dark grey - almost black): #3b3a39 4 | # foreground (mid grey): #d6d0ce 5 | # background (brownish beige): #efebe9 6 | 7 | figure.facecolor : none 8 | axes.labelcolor : 3b3a39 9 | axes.facecolor : none 10 | axes.edgecolor : d6d0ce 11 | xtick.color : 3b3a39 12 | ytick.color : 3b3a39 -------------------------------------------------------------------------------- /tox.ini: 
-------------------------------------------------------------------------------- 1 | # For more information about tox, see https://tox.readthedocs.io/en/latest/ 2 | [tox] 3 | envlist = py{39,310}-{linux,macos,windows} 4 | isolated_build=true 5 | 6 | [gh-actions] 7 | python = 8 | 3.9: py39 9 | 3.10: py310 10 | 11 | [gh-actions:env] 12 | PLATFORM = 13 | ubuntu-latest: linux 14 | macos-latest: macos 15 | windows-latest: windows 16 | 17 | [testenv] 18 | platform = 19 | macos: darwin 20 | linux: linux 21 | windows: win32 22 | passenv = 23 | CI 24 | GITHUB_ACTIONS 25 | DISPLAY 26 | XAUTHORITY 27 | NUMPY_EXPERIMENTAL_ARRAY_FUNCTION 28 | PYTEST_QT_API 29 | PYVISTA_OFF_SCREEN 30 | extras = 31 | testing 32 | commands = pytest -v --color=yes --cov=napari_deeplabcut --cov-report=xml 33 | --------------------------------------------------------------------------------