├── data
│   └── demo
│       ├── 0.jpg
│       ├── 1.jpg
│       ├── 2.jpg
│       ├── 3.jpg
│       ├── 4.jpg
│       └── 5.jpg
├── docs
│   ├── _static
│   │   ├── inv
│   │   │   └── tqdm.inv
│   │   ├── img
│   │   │   ├── banner.jpg
│   │   │   ├── logo-dark.png
│   │   │   ├── logo-light.png
│   │   │   ├── detection-eyes-0.jpg
│   │   │   ├── detection-eyes-1.jpg
│   │   │   ├── detection-eyes-2.jpg
│   │   │   ├── detection-solo-0.jpg
│   │   │   ├── detection-solo-1.jpg
│   │   │   ├── detection-solo-2.jpg
│   │   │   ├── detection-worn-0.jpg
│   │   │   ├── detection-worn-1.jpg
│   │   │   ├── detection-worn-2.jpg
│   │   │   ├── segmentation-full-0.jpg
│   │   │   ├── segmentation-full-1.jpg
│   │   │   ├── segmentation-full-2.jpg
│   │   │   ├── segmentation-legs-0.jpg
│   │   │   ├── segmentation-legs-1.jpg
│   │   │   ├── segmentation-legs-2.jpg
│   │   │   ├── segmentation-smart-0.jpg
│   │   │   ├── segmentation-smart-1.jpg
│   │   │   ├── segmentation-smart-2.jpg
│   │   │   ├── segmentation-frames-0.jpg
│   │   │   ├── segmentation-frames-1.jpg
│   │   │   ├── segmentation-frames-2.jpg
│   │   │   ├── segmentation-lenses-0.jpg
│   │   │   ├── segmentation-lenses-1.jpg
│   │   │   ├── segmentation-lenses-2.jpg
│   │   │   ├── segmentation-shadows-0.jpg
│   │   │   ├── segmentation-shadows-1.jpg
│   │   │   ├── segmentation-shadows-2.jpg
│   │   │   ├── classification-shadows-neg.jpg
│   │   │   ├── classification-shadows-pos.jpg
│   │   │   ├── classification-eyeglasses-neg.jpg
│   │   │   ├── classification-eyeglasses-pos.jpg
│   │   │   ├── classification-no-glasses-neg.jpg
│   │   │   ├── classification-sunglasses-neg.jpg
│   │   │   └── classification-sunglasses-pos.jpg
│   │   ├── css
│   │   │   ├── signatures.css
│   │   │   ├── custom.css
│   │   │   └── highlights.css
│   │   ├── js
│   │   │   ├── colab-icon.js
│   │   │   ├── pypi-icon.js
│   │   │   └── zenodo-icon.js
│   │   ├── bib
│   │   │   └── references.bib
│   │   └── svg
│   │       └── colab.svg
│   ├── helpers
│   │   ├── __init__.py
│   │   ├── custom_invs.py
│   │   ├── generate_examples.py
│   │   └── build_finished.py
│   ├── requirements.txt
│   ├── docs
│   │   ├── api
│   │   │   ├── glasses_detector.utils.rst
│   │   │   ├── glasses_detector.components.base_model.rst
│   │   │   ├── glasses_detector.components.pred_type.rst
│   │   │   ├── glasses_detector.architectures.tiny_binary_detector.rst
│   │   │   ├── glasses_detector.architectures.tiny_binary_segmenter.rst
│   │   │   ├── glasses_detector.components.pred_interface.rst
│   │   │   ├── glasses_detector.architectures.tiny_binary_classifier.rst
│   │   │   ├── glasses_detector.detector.rst
│   │   │   ├── glasses_detector.segmenter.rst
│   │   │   └── glasses_detector.classifier.rst
│   │   ├── api.rst
│   │   ├── credits.rst
│   │   ├── install.rst
│   │   ├── cli.rst
│   │   ├── examples.rst
│   │   └── features.rst
│   ├── Makefile
│   ├── make.bat
│   ├── index.rst
│   ├── conf.yaml
│   └── conf.py
├── src
│   └── glasses_detector
│       ├── components
│       │   ├── __init__.py
│       │   └── pred_type.py
│       ├── __init__.py
│       ├── _wrappers
│       │   ├── __init__.py
│       │   ├── binary_detector.py
│       │   ├── binary_segmenter.py
│       │   ├── binary_classifier.py
│       │   └── metrics.py
│       ├── architectures
│       │   ├── __init__.py
│       │   ├── tiny_binary_classifier.py
│       │   ├── tiny_binary_segmenter.py
│       │   └── tiny_binary_detector.py
│       ├── _data
│       │   ├── __init__.py
│       │   ├── binary_segmentation_dataset.py
│       │   ├── binary_classification_dataset.py
│       │   ├── binary_detection_dataset.py
│       │   ├── base_categorized_dataset.py
│       │   └── augmenter_mixin.py
│       ├── utils.py
│       └── __main__.py
├── .gitignore
├── requirements.txt
├── CITATION.cff
├── .github
│   └── workflows
│       ├── python-publish.yaml
│       └── sphinx.yaml
├── LICENSE
├── scripts
│   ├── analyse.py
│   └── run.py
└── CODE_OF_CONDUCT.md
/data/demo/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/data/demo/0.jpg
--------------------------------------------------------------------------------
/data/demo/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/data/demo/1.jpg
--------------------------------------------------------------------------------
/data/demo/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/data/demo/2.jpg
--------------------------------------------------------------------------------
/data/demo/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/data/demo/3.jpg
--------------------------------------------------------------------------------
/data/demo/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/data/demo/4.jpg
--------------------------------------------------------------------------------
/data/demo/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/data/demo/5.jpg
--------------------------------------------------------------------------------
/docs/_static/inv/tqdm.inv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/inv/tqdm.inv
--------------------------------------------------------------------------------
/docs/helpers/__init__.py:
--------------------------------------------------------------------------------
1 | from .build_finished import BuildFinished
2 | from .custom_invs import CustomInvs
3 |
--------------------------------------------------------------------------------
/docs/_static/img/banner.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/banner.jpg
--------------------------------------------------------------------------------
/docs/_static/img/logo-dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/logo-dark.png
--------------------------------------------------------------------------------
/docs/_static/img/logo-light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/logo-light.png
--------------------------------------------------------------------------------
/src/glasses_detector/components/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_model import BaseGlassesModel
2 | from .pred_type import PredType
3 |
--------------------------------------------------------------------------------
/docs/_static/img/detection-eyes-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-eyes-0.jpg
--------------------------------------------------------------------------------
/docs/_static/img/detection-eyes-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-eyes-1.jpg
--------------------------------------------------------------------------------
/docs/_static/img/detection-eyes-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-eyes-2.jpg
--------------------------------------------------------------------------------
/docs/_static/img/detection-solo-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-solo-0.jpg
--------------------------------------------------------------------------------
/docs/_static/img/detection-solo-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-solo-1.jpg
--------------------------------------------------------------------------------
/docs/_static/img/detection-solo-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-solo-2.jpg
--------------------------------------------------------------------------------
/docs/_static/img/detection-worn-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-worn-0.jpg
--------------------------------------------------------------------------------
/docs/_static/img/detection-worn-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-worn-1.jpg
--------------------------------------------------------------------------------
/docs/_static/img/detection-worn-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-worn-2.jpg
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx
2 | sphinx-design
3 | sphinx-copybutton
4 | sphinxcontrib-bibtex
5 | pydata-sphinx-theme
6 | sphobjinv
7 |
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-full-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-full-0.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-full-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-full-1.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-full-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-full-2.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-legs-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-legs-0.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-legs-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-legs-1.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-legs-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-legs-2.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-smart-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-smart-0.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-smart-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-smart-1.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-smart-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-smart-2.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-frames-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-frames-0.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-frames-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-frames-1.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-frames-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-frames-2.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-lenses-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-lenses-0.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-lenses-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-lenses-1.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-lenses-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-lenses-2.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-shadows-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-shadows-0.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-shadows-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-shadows-1.jpg
--------------------------------------------------------------------------------
/docs/_static/img/segmentation-shadows-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-shadows-2.jpg
--------------------------------------------------------------------------------
/docs/_static/img/classification-shadows-neg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/classification-shadows-neg.jpg
--------------------------------------------------------------------------------
/docs/_static/img/classification-shadows-pos.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/classification-shadows-pos.jpg
--------------------------------------------------------------------------------
/docs/_static/img/classification-eyeglasses-neg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/classification-eyeglasses-neg.jpg
--------------------------------------------------------------------------------
/docs/_static/img/classification-eyeglasses-pos.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/classification-eyeglasses-pos.jpg
--------------------------------------------------------------------------------
/docs/_static/img/classification-no-glasses-neg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/classification-no-glasses-neg.jpg
--------------------------------------------------------------------------------
/docs/_static/img/classification-sunglasses-neg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/classification-sunglasses-neg.jpg
--------------------------------------------------------------------------------
/docs/_static/img/classification-sunglasses-pos.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/classification-sunglasses-pos.jpg
--------------------------------------------------------------------------------
/src/glasses_detector/__init__.py:
--------------------------------------------------------------------------------
1 | from .classifier import GlassesClassifier
2 | from .detector import GlassesDetector
3 | from .segmenter import GlassesSegmenter
4 |
--------------------------------------------------------------------------------
/docs/docs/api/glasses_detector.utils.rst:
--------------------------------------------------------------------------------
1 | Utilities
2 | =========
3 |
4 | .. automodule:: glasses_detector.utils
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/docs/api/glasses_detector.components.base_model.rst:
--------------------------------------------------------------------------------
1 | Base Model
2 | ==========
3 |
4 | .. automodule:: glasses_detector.components.base_model
5 | :members:
6 | :show-inheritance:
7 |
--------------------------------------------------------------------------------
/src/glasses_detector/_wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | from .binary_classifier import BinaryClassifier
2 | from .binary_detector import BinaryDetector
3 | from .binary_segmenter import BinarySegmenter
4 |
--------------------------------------------------------------------------------
/src/glasses_detector/architectures/__init__.py:
--------------------------------------------------------------------------------
1 | from .tiny_binary_classifier import TinyBinaryClassifier
2 | from .tiny_binary_detector import TinyBinaryDetector
3 | from .tiny_binary_segmenter import TinyBinarySegmenter
4 |
--------------------------------------------------------------------------------
/docs/docs/api/glasses_detector.components.pred_type.rst:
--------------------------------------------------------------------------------
1 | Prediction Type
2 | ===============
3 |
4 | .. automodule:: glasses_detector.components.pred_type
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Data
2 | data/*
3 | !data/demo
4 |
5 | # Training
6 | checkpoints
7 | lightning_logs
8 |
9 | # IDE
10 | .vscode
11 |
12 | # Build
13 | dist
14 | build
15 | _build
16 | _templates
17 | __pycache__
18 | *.egg-info
--------------------------------------------------------------------------------
/docs/docs/api/glasses_detector.architectures.tiny_binary_detector.rst:
--------------------------------------------------------------------------------
1 | Tiny Binary Detector
2 | ====================
3 |
4 | .. automodule:: glasses_detector.architectures.tiny_binary_detector
5 | :members:
6 | :show-inheritance:
7 |
--------------------------------------------------------------------------------
/docs/docs/api/glasses_detector.architectures.tiny_binary_segmenter.rst:
--------------------------------------------------------------------------------
1 | Tiny Binary Segmenter
2 | =====================
3 |
4 | .. automodule:: glasses_detector.architectures.tiny_binary_segmenter
5 | :members:
6 | :show-inheritance:
7 |
--------------------------------------------------------------------------------
/docs/docs/api/glasses_detector.components.pred_interface.rst:
--------------------------------------------------------------------------------
1 | Prediction Interface
2 | ====================
3 |
4 | .. automodule:: glasses_detector.components.pred_interface
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/docs/api/glasses_detector.architectures.tiny_binary_classifier.rst:
--------------------------------------------------------------------------------
1 | Tiny Binary Classifier
2 | ======================
3 |
4 | .. automodule:: glasses_detector.architectures.tiny_binary_classifier
5 | :members:
6 | :show-inheritance:
7 |
8 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | torch
3 | scipy
4 | pyyaml
5 | fvcore
6 | rarfile
7 | ipykernel
8 | pycocotools
9 | torchvision
10 | prettytable
11 | albumentations
12 | pytorch_lightning
13 | tensorboard
14 | torch-tb-profiler
15 | jsonargparse[signatures]
--------------------------------------------------------------------------------
/docs/docs/api/glasses_detector.detector.rst:
--------------------------------------------------------------------------------
1 | Detector
2 | ========
3 |
4 | .. automodule:: glasses_detector.detector
5 | :members:
6 | :exclude-members: create_model, save, forward, model_info, BASE_WEIGHTS_URL, ALLOWED_SIZE_ALIASES, DEFAULT_SIZE_MAP, DEFAULT_KIND_MAP
7 | :inherited-members:
8 | :show-inheritance:
9 |
--------------------------------------------------------------------------------
/docs/docs/api/glasses_detector.segmenter.rst:
--------------------------------------------------------------------------------
1 | Segmenter
2 | =========
3 |
4 | .. automodule:: glasses_detector.segmenter
5 | :members:
6 | :exclude-members: create_model, save, forward, model_info, BASE_WEIGHTS_URL, ALLOWED_SIZE_ALIASES, DEFAULT_SIZE_MAP, DEFAULT_KIND_MAP
7 | :inherited-members:
8 | :show-inheritance:
9 |
--------------------------------------------------------------------------------
/docs/docs/api/glasses_detector.classifier.rst:
--------------------------------------------------------------------------------
1 | Classifier
2 | ==========
3 |
4 | .. automodule:: glasses_detector.classifier
5 | :members:
6 | :exclude-members: create_model, save, forward, model_info, BASE_WEIGHTS_URL, ALLOWED_SIZE_ALIASES, DEFAULT_SIZE_MAP, DEFAULT_KIND_MAP
7 | :inherited-members:
8 | :show-inheritance:
9 |
--------------------------------------------------------------------------------
/src/glasses_detector/_data/__init__.py:
--------------------------------------------------------------------------------
1 | from .augmenter_mixin import AugmenterMixin
2 | from .base_categorized_dataset import BaseCategorizedDataset
3 | from .binary_classification_dataset import BinaryClassificationDataset
4 | from .binary_detection_dataset import BinaryDetectionDataset
5 | from .binary_segmentation_dataset import BinarySegmentationDataset
6 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: >-
3 | If you find Glasses Detector useful in your work, please
4 | consider citing the following BibTeX entry.
5 | authors:
6 | - family-names: "Birškus"
7 | given-names: "Mantas"
8 | title: "Glasses Detector"
9 | journal: "GitHub repository"
10 | publisher: "GitHub"
11 | url: "https://github.com/mantasu/glasses-detector"
12 | license: "MIT"
13 | type: software
14 | date-released: 2024-03-06
15 | doi: 10.5281/zenodo.8126101
--------------------------------------------------------------------------------
/src/glasses_detector/_data/binary_segmentation_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import override
2 |
3 | import torch
4 |
5 | from .base_categorized_dataset import BaseCategorizedDataset
6 |
7 |
8 | class BinarySegmentationDataset(BaseCategorizedDataset):
9 | @override
10 | def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
11 | sample = self.data[index]
12 | image, masks = self.load_transform(
13 | image=sample[self.img_folder],
14 | masks=[sample["masks"]],
15 | transform=self.transform,
16 | )
17 |
18 | return image, torch.stack(masks)
19 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/css/signatures.css:
--------------------------------------------------------------------------------
1 | /* Newlines (\a) and spaces (\20) before each parameter */
2 | .long-sig.sig-param::before {
3 | content: "\a\20\20\20\20\20\20\20\20";
4 | white-space: pre;
5 | }
6 |
7 | /* Newlines (\a) and spaces (\20) before each parameter separator */
8 | span.long-sig::before {
9 | content: "\a\20\20\20\20\20\20\20\20";
10 | white-space: pre;
11 | }
12 |
13 | /* Newline after the last parameter (so the closing bracket is on a new line) */
14 | dt em.long-sig.sig-param:last-of-type::after {
15 | content: "\a";
16 | white-space: pre;
17 | }
18 |
19 | /* To have blue background of width of the block (instead of width of content) */
20 | dl.class > dt:first-of-type {
21 | display: block !important;
22 | }
23 |
--------------------------------------------------------------------------------
/src/glasses_detector/_data/binary_classification_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import override
2 |
3 | import torch
4 |
5 | from .base_categorized_dataset import BaseCategorizedDataset
6 |
7 |
8 | class BinaryClassificationDataset(BaseCategorizedDataset):
9 | def __post_init__(self):
10 | # Flatten (some image names may have been the same across cats)
11 | self.data = [dict([cat]) for d in self.data for cat in d.items()]
12 |
13 | @override
14 | def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
15 | cat, pth = next(iter(self.data[index].items()))
16 | label = self.cat2tensor(cat)
17 | image = self.load_transform(image=pth, transform=self.transform)
18 |
19 | return image, label
20 |
--------------------------------------------------------------------------------
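
Note on the __post_init__ flattening above: each multi-category record becomes one single-category record per entry. A minimal, self-contained sketch of the same comprehension (the category and file names below are made up purely for illustration):

# Sketch of the flattening performed in BinaryClassificationDataset.__post_init__;
# the category and file names are illustrative, not taken from the actual datasets.
data = [
    {"eyeglasses": "img_1.jpg", "no_eyeglasses": "img_1.jpg"},
    {"eyeglasses": "img_2.jpg"},
]
flat = [dict([cat]) for d in data for cat in d.items()]
print(flat)
# [{'eyeglasses': 'img_1.jpg'}, {'no_eyeglasses': 'img_1.jpg'}, {'eyeglasses': 'img_2.jpg'}]
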
/docs/docs/api.rst:
--------------------------------------------------------------------------------
1 | :fas:`book` API
2 | ===============
3 |
4 | Models
5 | ------
6 |
7 | .. toctree::
8 | :maxdepth: 1
9 |
10 | api/glasses_detector.classifier
11 | api/glasses_detector.detector
12 | api/glasses_detector.segmenter
13 |
14 |
15 | Components
16 | ----------
17 |
18 | .. toctree::
19 | :maxdepth: 1
20 |
21 | api/glasses_detector.components.base_model
22 | api/glasses_detector.components.pred_interface
23 | api/glasses_detector.components.pred_type
24 | api/glasses_detector.utils
25 |
26 |
27 | Architectures
28 | -------------
29 |
30 | .. toctree::
31 | :maxdepth: 1
32 |
33 | api/glasses_detector.architectures.tiny_binary_classifier
34 | api/glasses_detector.architectures.tiny_binary_detector
35 | api/glasses_detector.architectures.tiny_binary_segmenter
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yaml:
--------------------------------------------------------------------------------
1 | name: build
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | workflow_dispatch:
8 |
9 | permissions:
10 | contents: read
11 | id-token: write
12 |
13 | jobs:
14 | deploy:
15 | runs-on: ubuntu-latest
16 | environment:
17 | name: pypi
18 | url: https://pypi.org/p/glasses-detector
19 | steps:
20 | - name: Set up Python
21 | uses: actions/setup-python@v5
22 | with:
23 | python-version: '3.12'
24 | - name: Checkout
25 | uses: actions/checkout@v4
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@release/v1
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/docs/credits.rst:
--------------------------------------------------------------------------------
1 | :fas:`heart` Credits
2 | ====================
3 |
4 | Support
5 | -------
6 |
7 | The easiest way to get support is to open an issue on |issues_link|.
8 |
9 | .. |issues_link| raw:: html
 10 | 
 11 |    <a href="https://github.com/mantasu/glasses-detector/issues" target="_blank">
 12 |      GitHub
 13 |    </a>
 14 | 
 15 | 
 16 | 
17 | Citation
18 | --------
19 |
20 | .. code-block:: bibtex
21 |
22 | @software{Birskus_Glasses_Detector_2024,
23 | author = {Birškus, Mantas},
24 | title = {{Glasses Detector}},
25 | license = {MIT},
26 | url = {https://github.com/mantasu/glasses-detector},
27 | month = {3},
28 | year = {2024},
29 | doi = {10.5281/zenodo.8126101}
30 | }
31 |
32 |
33 | References
34 | ----------
35 |
36 | .. bibliography::
37 | :style: unsrt
--------------------------------------------------------------------------------
/docs/_static/js/colab-icon.js:
--------------------------------------------------------------------------------
1 | FontAwesome.library.add(
2 | (faListOldStyle = {
3 | prefix: "fa-custom",
4 | iconName: "colab",
5 | icon: [
6 | 24, // viewBox width
7 | 24, // viewBox height
8 | [], // ligature
9 | "e001", // unicode codepoint - private use area
10 | "M16.9414 4.9757a7.033 7.033 0 0 0-4.9308 2.0646 7.033 7.033 0 0 0-.1232 9.8068l2.395-2.395a3.6455 3.6455 0 0 1 5.1497-5.1478l2.397-2.3989a7.033 7.033 0 0 0-4.8877-1.9297zM7.07 4.9855a7.033 7.033 0 0 0-4.8878 1.9316l2.3911 2.3911a3.6434 3.6434 0 0 1 5.0227.1271l1.7341-2.9737-.0997-.0802A7.033 7.033 0 0 0 7.07 4.9855zm15.0093 2.1721l-2.3892 2.3911a3.6455 3.6455 0 0 1-5.1497 5.1497l-2.4067 2.4068a7.0362 7.0362 0 0 0 9.9456-9.9476zM1.932 7.1674a7.033 7.033 0 0 0-.002 9.6816l2.397-2.397a3.6434 3.6434 0 0 1-.004-4.8916zm7.664 7.4235c-1.38 1.3816-3.5863 1.411-5.0168.1134l-2.397 2.395c2.4693 2.3328 6.263 2.5753 9.0072.5455l.1368-.1115z", // svg path (https://simpleicons.org/?q=colab)
11 | ],
12 | })
13 | );
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Mantas Birškus
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docs/docs/install.rst:
--------------------------------------------------------------------------------
1 | :fas:`download` Installation
2 | ============================
3 |
  4 | The package requires at least `Python 3.12 `_. To install the package via `pip `_, simply run:
5 |
6 | .. code-block:: bash
7 |
8 | pip install glasses-detector
9 |
10 | Or, to install it from source, run:
11 |
12 | .. code-block:: bash
13 |
14 | git clone https://github.com/mantasu/glasses-detector
15 | cd glasses-detector && pip install .
16 |
17 | .. tip::
18 |
19 | You may want to set up `PyTorch `_ in advance to enable **GPU** support for your device. Note that *CUDA* is backwards compatible, thus even if you have the newest version of `CUDA Toolkit `_, *PyTorch* **GPU** acceleration should work just fine.
20 |
21 | .. note::
22 |
23 | By default, the required models will be automatically downloaded and saved under *Torch Hub* directory, which by default is ``~/.cache/torch/hub/checkpoints``. For more information and how to change it, see `Torch Hub documentation `_.
--------------------------------------------------------------------------------
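
The installation note above mentions that pre-trained weights are cached under the Torch Hub directory (by default ~/.cache/torch/hub/checkpoints). A minimal sketch of relocating that cache with the standard torch.hub API follows; the target path is only an example, not something the package requires:

# Sketch: inspect and redirect the Torch Hub cache referenced in the note above.
# torch.hub.get_dir/set_dir are standard PyTorch calls; the path is an example.
import torch

print(torch.hub.get_dir())           # e.g. ~/.cache/torch/hub
torch.hub.set_dir("/tmp/torch-hub")  # checkpoints will then land in /tmp/torch-hub/checkpoints
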
/.github/workflows/sphinx.yaml:
--------------------------------------------------------------------------------
1 | name: docs
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | workflow_dispatch:
8 |
9 | permissions:
10 | contents: read
11 | pages: write
12 | id-token: write
13 |
14 | concurrency:
15 | group: "pages"
16 | cancel-in-progress: false
17 |
18 | jobs:
19 | deploy:
20 | environment:
21 | name: github-pages
22 | url: ${{ steps.deployment.outputs.page_url }}
23 | runs-on: ubuntu-latest
24 | steps:
25 | - name: Set up Python
26 | uses: actions/setup-python@v5
27 | with:
28 | python-version: "3.12"
29 | - name: Checkout
30 | uses: actions/checkout@v4
31 | - name: Install dependencies
32 | run: |
33 | python -m pip install --upgrade pip
34 | pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
35 | pip install .
36 | pip install --upgrade setuptools
37 | pip install -r docs/requirements.txt
38 | - name: Build HTML
39 | run: |
40 | cd docs/
41 | make html
42 | - name: Setup Pages
43 | uses: actions/configure-pages@v4
44 | - name: Upload artifact
45 | uses: actions/upload-pages-artifact@v3
46 | with:
47 | path: 'docs/_build/html'
48 | - name: Deploy to GitHub Pages
49 | id: deployment
50 | uses: actions/deploy-pages@v4
51 |
--------------------------------------------------------------------------------
/src/glasses_detector/_data/binary_detection_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import override
2 |
3 | import albumentations as A
4 | import torch
5 | from torch.utils.data import DataLoader
6 |
7 | from .base_categorized_dataset import BaseCategorizedDataset
8 |
9 |
10 | class BinaryDetectionDataset(BaseCategorizedDataset):
11 | @staticmethod
12 | def collate_fn(
13 | batch: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
14 | ) -> tuple[list[torch.Tensor], list[dict[str, torch.Tensor]]]:
15 | images = [item[0] for item in batch]
16 | annots = [{"boxes": item[1], "labels": item[2]} for item in batch]
17 |
18 | return images, annots
19 |
20 | @override
21 | @classmethod
22 | def create_loader(cls, **kwargs) -> DataLoader:
23 | kwargs.setdefault("collate_fn", cls.collate_fn)
24 | return super().create_loader(**kwargs)
25 |
26 | @override
27 | def create_transform(self, is_train: bool, **kwargs) -> A.Compose:
28 | kwargs.setdefault("has_bbox", True)
29 | return super().create_transform(is_train, **kwargs)
30 |
31 | @override
32 | def __getitem__(
33 | self,
34 | index: int,
35 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
36 | sample = self.data[index]
37 | image, bboxes, bbcats = self.load_transform(
38 | image=sample[self.img_folder],
39 | boxes=[sample["annotations"]],
40 | bcats=[1], # 0 - background
41 | transform=self.transform,
42 | )
43 |
44 | if len(bboxes) == 0:
45 | bboxes = torch.tensor([[0, 0, 1, 1]], dtype=torch.float32)
46 | bbcats = torch.tensor([0], dtype=torch.int64)
47 |
48 | return image, bboxes, bbcats
49 |
--------------------------------------------------------------------------------
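
Because collate_fn above keeps images as a list and packs boxes and labels into per-sample dicts (the format torchvision-style detection models expect), its effect can be sanity-checked on a dummy batch. The sketch below assumes the package is installed from source so the private module is importable; the image size and box coordinates are arbitrary:

# Sketch: what BinaryDetectionDataset.collate_fn produces for a dummy batch.
import torch

from glasses_detector._data.binary_detection_dataset import BinaryDetectionDataset

batch = [
    (torch.rand(3, 256, 256), torch.tensor([[10.0, 20.0, 110.0, 120.0]]), torch.tensor([1])),
    (torch.rand(3, 256, 256), torch.tensor([[0.0, 0.0, 1.0, 1.0]]), torch.tensor([0])),
]

images, annots = BinaryDetectionDataset.collate_fn(batch)
print(len(images))               # 2 (images stay a list, they are not stacked)
print(annots[0]["boxes"].shape)  # torch.Size([1, 4])
print(annots[1]["labels"])       # tensor([0])
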
/docs/_static/js/pypi-icon.js:
--------------------------------------------------------------------------------
1 | FontAwesome.library.add(
2 | (faListOldStyle = {
3 | prefix: "fa-custom",
4 | iconName: "pypi",
5 | icon: [
6 | 17.313, // viewBox width
7 | 19.807, // viewBox height
8 | [], // ligature
9 | "e001", // unicode codepoint - private use area
10 | "m10.383 0.2-3.239 1.1769 3.1883 1.1614 3.239-1.1798zm-3.4152 1.2411-3.2362 1.1769 3.1855 1.1614 3.2369-1.1769zm6.7177 0.00281-3.2947 1.2009v3.8254l3.2947-1.1988zm-3.4145 1.2439-3.2926 1.1981v3.8254l0.17548-0.064132 3.1171-1.1347zm-6.6564 0.018325v3.8247l3.244 1.1805v-3.8254zm10.191 0.20931v2.3137l3.1777-1.1558zm3.2947 1.2425-3.2947 1.1988v3.8254l3.2947-1.1988zm-8.7058 0.45739c0.00929-1.931e-4 0.018327-2.977e-4 0.027485 0 0.25633 0.00851 0.4263 0.20713 0.42638 0.49826 1.953e-4 0.38532-0.29327 0.80469-0.65542 0.93662-0.36226 0.13215-0.65608-0.073306-0.65613-0.4588-6.28e-5 -0.38556 0.2938-0.80504 0.65613-0.93662 0.068422-0.024919 0.13655-0.038114 0.20156-0.039466zm5.2913 0.78369-3.2947 1.1988v3.8247l3.2947-1.1981zm-10.132 1.239-3.2362 1.1769 3.1883 1.1614 3.2362-1.1769zm6.7177 0.00213-3.2926 1.2016v3.8247l3.2926-1.2009zm-3.4124 1.2439-3.2947 1.1988v3.8254l3.2947-1.1988zm-6.6585 0.016195v3.8275l3.244 1.1805v-3.8254zm16.9 0.21143-3.2947 1.1988v3.8247l3.2947-1.1981zm-3.4145 1.2411-3.2926 1.2016v3.8247l3.2926-1.2009zm-3.4145 1.2411-3.2926 1.2016v3.8247l3.2926-1.2009zm-3.4124 1.2432-3.2947 1.1988v3.8254l3.2947-1.1988zm-6.6585 0.019027v3.8247l3.244 1.1805v-3.8254zm13.485 1.4497-3.2947 1.1988v3.8247l3.2947-1.1981zm-3.4145 1.2411-3.2926 1.2016v3.8247l3.2926-1.2009zm2.4018 0.38127c0.0093-1.83e-4 0.01833-3.16e-4 0.02749 0 0.25633 0.0085 0.4263 0.20713 0.42638 0.49826 1.97e-4 0.38532-0.29327 0.80469-0.65542 0.93662-0.36188 0.1316-0.65525-0.07375-0.65542-0.4588-1.95e-4 -0.38532 0.29328-0.80469 0.65542-0.93662 0.06842-0.02494 0.13655-0.03819 0.20156-0.03947zm-5.8142 0.86403-3.244 1.1805v1.4201l3.244 1.1805z", // svg path (https://simpleicons.org/icons/pypi.svg)
11 | ],
12 | })
13 | );
--------------------------------------------------------------------------------
/docs/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | /* Adjust icon size */
2 | .fa-zenodo {
3 | font-size: 1em;
4 | }
5 |
6 | /* Header setup for index.html */
7 | #glasses-detector h1 {
8 | display: flex;
9 | align-items: center;
10 | }
11 |
12 | /* Logo setup for index.html page header */
13 | #glasses-detector h1::before {
14 | content: "";
15 | display: inline-block;
16 | height: 1.5em;
17 | width: 1.5em;
18 | margin-right: 10px;
19 | }
20 |
21 | /* Logo setup for index.html page header for dark mode */
22 | html[data-theme="dark"] #glasses-detector h1::before {
23 | background: url("../img/logo-dark.png") no-repeat center/contain;
24 | }
25 |
26 | /* Logo setup for index.html page header for light mode */
27 | html[data-theme="light"] #glasses-detector h1::before {
28 | background: url("../img/logo-light.png") no-repeat center/contain;
29 | }
30 |
31 | /* Adjust navbar logo size */
32 | img.logo__image {
33 | width: 40px;
34 | height: 40px;
35 | }
36 |
37 | /* Align the table content in the center */
38 | #classification-kinds td, #detection-kinds td, #segmentation-kinds td {
39 | vertical-align: middle;
40 | }
41 |
42 | /* Adjust banner */
43 | #banner {
44 | margin-top: 2em;
45 | border-radius: 0.3em;
46 | }
47 |
48 | .toctree-wrapper {
49 | /* Padding + Width */
50 | padding: 1em 2em;
51 | width: 100%;
52 |
53 | /* Border Setup */
54 | border-radius: 1em;
55 | border-width: .25em;
56 | border-style: solid;
57 |
58 | /* Color setup (automatic for dark/light modes) */
59 | background-color: var(--pst-color-on-background);
60 | border: .25em solid var(--pst-color-border)
61 | }
62 |
63 | #index-toctree {
64 | /* Multi-column Layout */
65 | column-count: 2;
66 | }
67 |
68 | #index-toctree .toctree-l1 {
69 | /* List Style */
70 | padding: 0.5em 0;
71 | }
72 |
73 | #index-toctree .toctree-l2 {
74 | /* List Style */
75 | list-style-type: circle !important;
76 | margin-left: 1em !important;
77 | }
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | Glasses Detector
2 | ================
3 |
4 | .. image:: _static/img/banner.jpg
5 | :alt: Glasses Detector Banner
6 | :align: center
7 | :name: banner
8 |
9 | About
10 | -----
11 |
 12 | A package for processing images with different types of glasses and their parts. It provides a quick way to use the pre-trained models for **3** kinds of tasks, each divided into multiple categories, for instance, *classification of sunglasses* or *segmentation of glasses frames* (see :doc:`docs/features` for more details):
13 |
14 | .. list-table::
15 | :align: center
16 |
17 | * -
18 | -
19 | * - **Classification**
20 | - 👓 *transparent* 🕶️ *opaque* 🥽 *any* ➿ *shadows*
21 | * - **Detection**
22 | - 🤓 *worn* 👓 *standalone* 👀 *eye-area*
23 | * - **Segmentation**
24 | - 😎 *full* 🖼️ *frames* 🦿 *legs* 🔍 *lenses* 👥 *shadows*
25 |
26 | .. raw:: html
27 |
28 |
29 |
30 | The processing can be launched via the command line or written in a *Python* script. Based on the selected task, an image or a directory of images will be processed and corresponding predictions, e.g., labels or masks, will be generated.
31 |
32 | .. seealso::
33 |
 34 |     Refer to |repo_link| for information about the datasets used and how to train or test your own models. Model fitting and evaluation can simply be launched through the terminal with commands integrated from `PyTorch Lightning CLI `_.
35 |
36 | .. |repo_link| raw:: html
 37 | 
 38 |    <a href="https://github.com/mantasu/glasses-detector" target="_blank">
 39 |      GitHub repository
 40 |    </a>
 41 | 
 42 | 
 43 | 
44 | Contents
45 | --------
46 |
47 | .. toctree::
48 | :maxdepth: 2
49 | :name: index-toctree
50 |
51 | docs/install
52 | docs/features
53 | docs/examples
54 | docs/cli
55 | docs/api
56 | docs/credits
57 |
58 |
59 | Indices and Tables
60 | ------------------
61 |
62 | * :ref:`genindex`
63 | * :ref:`modindex`
64 | * :ref:`search`
65 |
--------------------------------------------------------------------------------
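
For reference, the three task models mentioned in the About section above are the package's public entry points (as exported by src/glasses_detector/__init__.py); how each model is configured and invoked is covered by the API pages under docs/docs/api rather than shown here:

# The confirmed public imports; constructor arguments and prediction helpers
# are documented in the API reference and are not reproduced in this sketch.
from glasses_detector import GlassesClassifier, GlassesDetector, GlassesSegmenter
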
/docs/_static/bib/references.bib:
--------------------------------------------------------------------------------
1 | @inproceedings{ma2018shufflenet,
2 | title={Shufflenet v2: Practical guidelines for efficient cnn architecture design},
3 | author={Ma, Ningning and Zhang, Xiangyu and Zheng, Hai-Tao and Sun, Jian},
4 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
5 | pages={116--131},
6 | year={2018}
7 | }
8 |
9 | @inproceedings{liu2016ssd,
10 | title={Ssd: Single shot multibox detector},
11 | author={Liu, Wei and Anguelov, Dragomir and Erhan, Dumitru and Szegedy, Christian and Reed, Scott and Fu, Cheng-Yang and Berg, Alexander C},
12 | booktitle={Computer Vision--ECCV 2016: 14th European Conference, Amsterdam, The Netherlands, October 11--14, 2016, Proceedings, Part I 14},
13 | pages={21--37},
14 | year={2016},
15 | organization={Springer}
16 | }
17 |
18 | @inproceedings{howard2019searching,
19 | title={Searching for mobilenetv3},
20 | author={Howard, Andrew and Sandler, Mark and Chu, Grace and Chen, Liang-Chieh and Chen, Bo and Tan, Mingxing and Wang, Weijun and Zhu, Yukun and Pang, Ruoming and Vasudevan, Vijay and others},
21 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
22 | pages={1314--1324},
23 | year={2019}
24 | }
25 |
26 | @article{ren2015faster,
27 | title={Faster r-cnn: Towards real-time object detection with region proposal networks},
28 | author={Ren, Shaoqing and He, Kaiming and Girshick, Ross and Sun, Jian},
29 | journal={Advances in Neural Information Processing Systems},
30 | volume={28},
31 | year={2015}
32 | }
33 |
34 | @inproceedings{long2015fully,
35 | title={Fully convolutional networks for semantic segmentation},
36 | author={Long, Jonathan and Shelhamer, Evan and Darrell, Trevor},
37 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
38 | pages={3431--3440},
39 | year={2015}
40 | }
41 |
42 | @inproceedings{he2016deep,
43 | title={Deep residual learning for image recognition},
44 | author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian},
45 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
46 | pages={770--778},
47 | year={2016}
48 | }
--------------------------------------------------------------------------------
/src/glasses_detector/architectures/tiny_binary_classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class TinyBinaryClassifier(nn.Module):
6 | """Tiny binary classifier.
7 |
8 | This is a custom classifier created with the aim to contain very few
9 | parameters while maintaining a reasonable accuracy. It only has
10 | several sequential convolutional and pooling blocks (with
11 | batch-norm in between).
12 | """
13 |
14 | def __init__(self):
15 | super().__init__()
16 |
17 | # Several convolutional blocks
18 | self.features = nn.Sequential(
19 | self._create_block(3, 5, 3),
20 | self._create_block(5, 10, 3),
21 | self._create_block(10, 15, 3),
22 | self._create_block(15, 20, 3),
23 | self._create_block(20, 25, 3),
24 | self._create_block(25, 80, 3),
25 | nn.AdaptiveAvgPool2d(1),
26 | nn.Flatten(),
27 | )
28 |
29 | # Fully connected layer
30 | self.fc = nn.Linear(80, 1)
31 |
32 | def _create_block(self, num_in, num_out, filter_size):
33 | return nn.Sequential(
34 | nn.Conv2d(num_in, num_out, filter_size, 1, "valid", bias=False),
35 | nn.ReLU(),
36 | nn.BatchNorm2d(num_out),
37 | nn.MaxPool2d(2, 2),
38 | )
39 |
40 | def forward(self, x: torch.Tensor) -> torch.Tensor:
41 | """Performs forward pass.
42 |
43 | Predicts raw scores for the given batch of inputs. Scores are
 44 |         unbounded: anything below 0 means the positive class is
 45 |         unlikely, and anything above 0 indicates that the positive
 46 |         class is likely.
47 |
48 | Args:
49 | x (torch.Tensor): Image batch of shape (N, C, H, W). Note
50 | that pixel values are normalized and squeezed between
51 | 0 and 1.
52 |
53 | Returns:
54 | torch.Tensor: An output tensor of shape (N,) indicating
55 | whether each nth image falls under the positive class or
56 | not. The scores are unbounded, thus, to convert to a
57 | probability, sigmoid function must be used.
58 | """
59 | return self.fc(self.features(x))
60 |
--------------------------------------------------------------------------------
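
Since the forward docstring above states that the scores are unbounded and a sigmoid is needed to turn them into probabilities, here is a minimal usage sketch (the 256x256 input size is an arbitrary choice; any size that survives the six conv/pool stages works):

# Sketch: raw scores from TinyBinaryClassifier converted to probabilities
# with a sigmoid, as the forward docstring above prescribes.
import torch

from glasses_detector.architectures import TinyBinaryClassifier

model = TinyBinaryClassifier().eval()
images = torch.rand(4, 3, 256, 256)  # pixel values already in [0, 1]

with torch.no_grad():
    scores = model(images)         # unbounded raw scores, one per image
    probs = torch.sigmoid(scores)  # probabilities of the positive class

print(probs.squeeze(-1))
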
/docs/_static/svg/colab.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/helpers/custom_invs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from importlib.metadata import version
4 |
5 | import sphobjinv as soi
6 |
7 |
8 | class CustomInvs:
9 | def __init__(self, static_path: str = "_static"):
10 | # Init inv directory and create it if not exists
11 | self.inv_dir = os.path.join(static_path, "inv")
12 | os.makedirs(self.inv_dir, exist_ok=True)
13 |
14 | def create_tqdm_inv(self) -> dict[str, tuple[str, str]]:
15 | # Init inv and module
16 | inv = soi.Inventory()
17 |
18 | # Define the Sphinx header information
19 | inv.project = "tqdm"
20 | inv.version = version("tqdm")
21 |
22 | # Define the tqdm class
23 | inv_obj = soi.DataObjStr(
24 | name="tqdm.tqdm",
25 | domain="py",
26 | role="class",
27 | priority="1",
28 | uri="tqdm/#tqdm-objects",
29 | dispname="tqdm",
30 | )
31 | inv.objects.append(inv_obj)
32 |
33 | # Write the inventory to a file
34 | path = os.path.join(self.inv_dir, "tqdm.inv")
35 | text = soi.compress(inv.data_file(contract=True))
36 | soi.writebytes(path, text)
37 |
38 | return {"tqdm": ("https://tqdm.github.io/docs/", path)}
39 |
40 | def create_builtin_constants_inv(self) -> dict[str, tuple[str, str]]:
41 | # Init inv and module
42 | inv = soi.Inventory()
43 |
44 | # Define the Sphinx header information
45 | inv.project = "builtin_constants"
46 | major, minor, micro = sys.version_info[:3]
47 | inv.version = f"{major}.{minor}.{micro}"
48 |
49 | for constant in ["None", "True", "False"]:
50 | # Define the constant as class
51 | inv_obj = soi.DataObjStr(
52 | name=f"{constant}",
53 | domain="py",
54 | role="class", # dummy class for linking
55 | priority="1",
56 | uri=f"library/constants.html#{constant}",
57 | dispname=constant,
58 | )
59 | inv.objects.append(inv_obj)
60 |
61 | # Write the inventory to a file
62 | path = os.path.join(self.inv_dir, "builtin_constants.inv")
63 | text = soi.compress(inv.data_file(contract=True))
64 | soi.writebytes(path, text)
65 |
66 | return {"builtin_constants": ("https://docs.python.org/3", path)}
67 |
68 | def __call__(self) -> dict[str, tuple[str, str]]:
69 | # Init custom invs
70 | custom_invs = {}
71 |
72 | for method_name in dir(self):
73 | if method_name == "create_builtin_constants_inv":
74 | continue
75 |
76 | if method_name.startswith("create_"):
77 | # Update custom invs dictionary with the new one
78 | custom_invs.update(getattr(self, method_name)())
79 |
80 | return custom_invs
81 |
--------------------------------------------------------------------------------
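
CustomInvs() above returns a mapping of inventory name to (base URL, local .inv path), which is exactly the shape Sphinx's intersphinx_mapping expects. Below is a sketch of how docs/conf.py could consume it; the actual conf.py is not reproduced here, so treat this as an assumed usage rather than a copy of it:

# Sketch (assumed usage, not the actual docs/conf.py): merge the generated
# local inventories into Sphinx's intersphinx_mapping.
from helpers import CustomInvs

extensions = ["sphinx.ext.intersphinx"]

intersphinx_mapping = {
    "python": ("https://docs.python.org/3", None),
}
intersphinx_mapping.update(CustomInvs(static_path="_static")())
# e.g. adds {"tqdm": ("https://tqdm.github.io/docs/", "_static/inv/tqdm.inv")}
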
/docs/_static/css/highlights.css:
--------------------------------------------------------------------------------
1 | div.highlight-python div.highlight pre .linenos {
2 | margin-right: 1em;
3 | }
4 |
5 | /* Default python code syntax highlights for dark mode*/
6 | html[data-theme="dark"] div.highlight-python div.highlight pre {
7 | .n { color: #9CDCFE; } /* Variable */
8 | .k, .ow, .kn { color: #C586C0; } /* Keyword */
9 | .gp { color: #af92ff; } /* Generic Prompt */
10 | .kc { color: #569CD6; } /* Built-in constant */
11 | .s2 { color: #CE9178; } /* String */
12 | .mi, .m { color: #B5CEA8; } /* Number */
13 | .o { color: #CCCCCC; } /* Operator */
14 | .go { color: #808080; } /* Generic Output */
15 | .p { color: #D4D4D4; } /* Punctuation */
16 | .c1 { color: #6A9955; } /* Comment */
17 | .nn, .nc { color: #4EC9B0; } /* Class/module name */
18 | .nf, .nb { color: #DCDCAA; } /* Function name */
19 | .err { color: #F44747; } /* Error */
20 | }
21 |
22 | /* Default python code syntax highlights for light mode*/
23 | html[data-theme="light"] div.highlight-python div.highlight pre {
24 | .n { color: #3B3B3B; } /* Variable */
25 | .k, .ow, .kn { color: #AF00DB; } /* Keyword */
26 | .gp { color: #ff9ef7; } /* Generic Prompt */
27 | .kc { color: #0000FF; } /* Built-in constant */
28 | .s2 { color: #A31515; } /* String */
29 | .mi, .m { color: #098658; } /* Number */
30 | .o { color: #3B3B3B; } /* Operator */
31 | .go { color: #8c8c8c; } /* Generic Output */
32 | .p { color: #3B3B3B; } /* Punctuation */
33 | .c1 { color: #008000; } /* Comment */
34 | .nn, .nc { color: #267F99; } /* Class/module name */
35 | .nf, .nb { color: #001080; } /* Function name */
36 | .err { color: #fe7272; } /* Error */
37 | }
38 |
39 | /* Default bash code syntax highlights for dark mode*/
40 | html[data-theme="dark"] div.highlight-bash div.highlight pre {
41 | .c1 { color: #6A9955; } /* Comment */
42 | }
43 |
44 | /* Default bash code syntax highlights for light mode*/
45 | html[data-theme="light"] div.highlight-bash div.highlight pre {
46 | .c1 { color: #008000; } /* Comment */
47 | }
48 |
49 | /* Custom python code syntax highlights for dark mode*/
50 | html[data-theme="dark"] div.highlight-python div.highlight pre {
51 | .custom-highlight-class { color: #4EC9B0; }
52 | .custom-highlight-function { color: #DCDCAA; }
53 | .custom-highlight-variable { color: #9CDCFE; }
54 | .custom-highlight-constant { color: #4FC1FF; }
55 | }
56 |
57 | /* Custom python code syntax highlights for light mode*/
58 | html[data-theme="light"] div.highlight-python div.highlight pre {
59 | .custom-highlight-class { color: #267F99; }
60 | .custom-highlight-function { color: #001080; }
61 | .custom-highlight-variable { color: #3B3B3B; }
62 | .custom-highlight-constant { color: #0070C1; }
63 | }
64 |
65 | /* Custom bash code syntax highlights for dark mode*/
66 | html[data-theme="dark"] div.highlight-bash div.highlight pre, html[data-theme="dark"] code.highlight-bash {
67 | .custom-highlight-default { color: #CE9178; }
68 | .custom-highlight-start { color: #DCDCAA; }
69 | .custom-highlight-flag { color: #569CD6; }
70 | .custom-highlight-op { color: #CCCCCC; }
71 | }
72 |
73 | /* Custom bash code syntax highlights for light mode*/
74 | html[data-theme="light"] div.highlight-bash div.highlight pre, html[data-theme="light"] code.highlight-bash {
75 | .custom-highlight-default { color: #A31515; }
76 | .custom-highlight-start { color: #001080; }
77 | .custom-highlight-flag { color: #0000FF; }
78 | .custom-highlight-op { color: #3B3B3B; }
79 | }
--------------------------------------------------------------------------------
/src/glasses_detector/_wrappers/binary_detector.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 | import torchmetrics
4 | from torch.optim import AdamW
5 | from torch.optim.lr_scheduler import ReduceLROnPlateau
6 |
7 | from .metrics import BoxClippedR2, BoxIoU, BoxMSLE
8 |
9 |
10 | class BinaryDetector(pl.LightningModule):
11 | def __init__(self, model, train_loader=None, val_loader=None, test_loader=None):
12 | super().__init__()
13 |
14 | # Assign attributes
15 | self.model = model
16 | self.train_loader = train_loader
17 | self.val_loader = val_loader
18 | self.test_loader = test_loader
19 |
20 | # Initialize val_loss metric (just the mean)
21 | self.val_loss = torchmetrics.MeanMetric()
22 |
23 | # Create F1 score to monitor average label performance
24 | self.label_metrics = torchmetrics.MetricCollection(
25 | [torchmetrics.F1Score(task="binary")]
26 | )
27 |
28 | # Initialize some metrics to monitor bbox performance
29 | self.boxes_metrics = torchmetrics.MetricCollection(
30 | [BoxMSLE(), BoxIoU(), BoxClippedR2()]
31 | )
32 |
33 | def forward(self, *args):
34 | return self.model(*args)
35 |
36 | def training_step(self, batch, batch_idx):
37 | # Forward propagate and compute loss
38 | loss_dict = self(batch[0], batch[1])
39 | loss = sum(loss for loss in loss_dict.values()) / len(batch[0])
40 | self.log("train_loss", loss, prog_bar=True)
41 | return loss
42 |
43 | def eval_step(self, batch):
44 | # Forward pass and compute loss
45 | with torch.inference_mode():
46 | self.train()
47 | loss = sum(loss for loss in self(batch[0], batch[1]).values())
48 | self.val_loss.update(loss / len(batch[0]))
49 | self.eval()
50 |
51 | # Update all the metrics
52 | self.boxes_metrics.update(self(batch[0]), batch[1], self.label_metrics)
53 |
54 | def on_eval_epoch_end(self, prefix=""):
55 | # Compute total loss and metrics
56 | loss = self.val_loss.compute()
57 | label_metrics = self.label_metrics.compute()
58 | boxes_metrics = self.boxes_metrics.compute()
59 |
60 | # Reset the metrics
61 | self.val_loss.reset()
62 | self.label_metrics.reset()
63 | self.boxes_metrics.reset()
64 |
65 | # Log the metrics and the learning rate
66 | self.log(f"{prefix}_loss", loss, prog_bar=True)
67 | self.log(f"{prefix}_f1", label_metrics["BinaryF1Score"], prog_bar=True)
68 | self.log(f"{prefix}_msle", boxes_metrics["BoxMSLE"], prog_bar=True)
69 | self.log(f"{prefix}_r2", boxes_metrics["BoxClippedR2"], prog_bar=True)
70 | self.log(f"{prefix}_iou", boxes_metrics["BoxIoU"], prog_bar=True)
71 |
72 | if not isinstance(opt := self.optimizers(), list):
73 | # Log the learning rate of a single optimizer
74 | self.log("lr", opt.param_groups[0]["lr"], prog_bar=True)
75 |
76 | def validation_step(self, batch, batch_idx):
77 | self.eval_step(batch)
78 |
79 | def on_validation_epoch_end(self):
80 | self.on_eval_epoch_end(prefix="val")
81 |
82 | def test_step(self, batch, batch_idx):
83 | self.eval_step(batch)
84 |
85 | def on_test_epoch_end(self):
86 | self.on_eval_epoch_end(prefix="test")
87 |
88 | def train_dataloader(self):
89 | return self.train_loader
90 |
91 | def val_dataloader(self):
92 | return self.val_loader
93 |
94 | def test_dataloader(self):
95 | return self.test_loader
96 |
97 | def configure_optimizers(self):
98 | # Initialize AdamW optimizer and Reduce On Plateau scheduler
99 | optimizer = AdamW(self.parameters(), lr=1e-3, weight_decay=1e-3)
100 | scheduler = ReduceLROnPlateau(
101 | optimizer=optimizer,
102 | factor=0.3,
103 | patience=15,
104 | threshold=0.01,
105 | min_lr=1e-6,
106 | )
107 |
108 | return {
109 | "optimizer": optimizer,
110 | "lr_scheduler": scheduler,
111 | "monitor": "val_loss",
112 | }
113 |
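A minimal sketch of how the wrapper defined above could be trained, assuming `TinyBinaryDetector` is importable from the architectures package and `make_detection_loaders()` is a hypothetical helper (not part of the repository) that yields `(images, targets)` batches in the torchvision detection format:

import pytorch_lightning as pl

from glasses_detector.architectures import TinyBinaryDetector

# Hypothetical helper: returns loaders of (list[Tensor], list[dict]) batches
train_loader, val_loader = make_detection_loaders()

model = BinaryDetector(
    TinyBinaryDetector(),
    train_loader=train_loader,
    val_loader=val_loader,
)

# The "val_loss" logged above is what the ReduceLROnPlateau scheduler monitors
trainer = pl.Trainer(max_epochs=10, accelerator="auto")
trainer.fit(model)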
--------------------------------------------------------------------------------
/src/glasses_detector/architectures/tiny_binary_segmenter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision.ops import Conv2dNormActivation
4 |
5 |
6 | class TinyBinarySegmenter(nn.Module):
7 | """Tiny binary segmenter.
8 |
 9 |     This is a custom segmenter created with the aim of containing very
10 |     few parameters while maintaining reasonable accuracy. It only has
11 |     several sequential down-convolution and up-convolution layers with
12 |     skip connections and is very similar to U-Net.
13 |
14 | Note:
15 | You can read more about U-Net architecture in the following
16 | paper by O. Ronneberger et al.:
17 |         `U-Net: Convolutional Networks for Biomedical Image Segmentation <https://arxiv.org/abs/1505.04597>`_
18 | """
19 |
20 | class _Down(nn.Module):
21 | def __init__(self, in_channels, out_channels):
22 | super().__init__()
23 |
24 | self.pool0 = nn.MaxPool2d(2)
25 | self.conv1 = Conv2dNormActivation(in_channels, out_channels)
26 | self.conv2 = Conv2dNormActivation(out_channels, out_channels)
27 |
28 | def forward(self, x):
29 | return self.conv2(self.conv1(self.pool0(x)))
30 |
31 | class _Up(nn.Module):
32 | def __init__(self, in_channels, out_channels):
33 | super().__init__()
34 |
35 | half_channels = in_channels // 2
36 | self.conv0 = nn.ConvTranspose2d(half_channels, half_channels, 2, 2)
37 | self.conv1 = Conv2dNormActivation(in_channels, out_channels)
38 | self.conv2 = Conv2dNormActivation(out_channels, out_channels)
39 |
40 | def forward(self, x1, x2):
41 | x1 = self.conv0(x1)
42 |
43 | diffY = x2.size()[2] - x1.size()[2]
44 | diffX = x2.size()[3] - x1.size()[3]
45 |
46 | x1 = nn.functional.pad(
47 | x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)
48 | )
49 |
50 | x = torch.cat([x2, x1], dim=1)
51 |
52 | return self.conv2(self.conv1(x))
53 |
54 | def __init__(self):
55 | super().__init__()
56 |
57 | # Feature extraction layer
58 | self.first = nn.Sequential(
59 | Conv2dNormActivation(3, 16),
60 | Conv2dNormActivation(16, 16),
61 | )
62 |
63 | # Down-sampling layers
64 | self.down1 = self._Down(16, 32)
65 | self.down2 = self._Down(32, 64)
66 | self.down3 = self._Down(64, 64)
67 |
68 | # Up-sampling layers
69 | self.up1 = self._Up(128, 32)
70 | self.up2 = self._Up(64, 16)
71 | self.up3 = self._Up(32, 16)
72 |
73 | # Pixel-wise classification layer
74 | self.last = nn.Conv2d(16, 1, 1)
75 |
76 | def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
77 | """Performs forward pass.
78 |
79 |         Predicts raw pixel scores for the given batch of inputs. Scores
80 |         are unbounded: anything below 0 means the pixel is unlikely to
81 |         belong to the positive class, and anything above 0 means the
82 |         pixel likely belongs to it.
83 |
84 | Args:
85 |             x (torch.Tensor): Image batch of shape (N, C, H, W). Note
86 |                 that pixel values are expected to be normalized to the
87 |                 range from 0 to 1.
88 |
89 | Returns:
90 | dict[str, torch.Tensor]: A dictionary with a single "out"
91 | entry (for compatibility). The value is an output tensor of
92 | shape (N, 1, H, W) indicating which pixels in the image fall
93 | under positive category. The scores are unbounded, thus, to
94 | convert to probabilities, sigmoid function must be used.
95 | """
96 | # Extract primary features
97 | x1 = self.first(x)
98 |
99 | # Downsample features
100 | x2 = self.down1(x1)
101 | x3 = self.down2(x2)
102 | x4 = self.down3(x3)
103 |
104 |         # Upsample features
105 | x = self.up1(x4, x3)
106 | x = self.up2(x, x2)
107 | x = self.up3(x, x1)
108 |
109 | # Predict one channel
110 | out = self.last(x)
111 |
112 | return {"out": out}
113 |
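As the docstring notes, the forward pass returns raw logits under the "out" key, so a sigmoid is needed to turn them into per-pixel probabilities. A minimal sketch with a random, untrained model:

import torch

model = TinyBinarySegmenter()
model.eval()

# Batch of 2 RGB images with values already scaled to the 0-1 range
x = torch.rand(2, 3, 256, 256)

with torch.no_grad():
    logits = model(x)["out"]    # shape (2, 1, 256, 256), unbounded scores

probs = logits.sigmoid()        # per-pixel probabilities
mask = probs > 0.5              # boolean segmentation mask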
--------------------------------------------------------------------------------
/src/glasses_detector/_wrappers/binary_segmenter.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch.nn as nn
3 | import torchmetrics
4 | from torch.optim import AdamW
5 | from torch.optim.lr_scheduler import ReduceLROnPlateau
6 |
7 |
8 | class BinarySegmenter(pl.LightningModule):
9 | def __init__(self, model, train_loader=None, val_loader=None, test_loader=None):
10 | super().__init__()
11 |
12 | # Assign attributes
13 | self.model = model
14 | self.train_loader = train_loader
15 | self.val_loader = val_loader
16 | self.test_loader = test_loader
17 |
18 | # Create loss function and account for imbalance of classes
19 | self.criterion = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)
20 | self.val_loss = torchmetrics.MeanMetric()
21 |
22 | # Initialize some metrics to monitor the performance
23 | self.metrics = torchmetrics.MetricCollection(
24 | [
25 | torchmetrics.MatthewsCorrCoef(task="binary"),
26 | torchmetrics.F1Score(task="binary"), # same as Dice
27 | torchmetrics.JaccardIndex(task="binary"), # IoU
28 | ]
29 | )
30 |
31 | @property
32 | def pos_weight(self):
33 | if self.train_loader is None:
34 | # Not known
35 | return None
36 |
37 | # Init counts
38 | pos, neg = 0, 0
39 |
40 | for _, mask in self.train_loader:
41 | # Update pos and neg sums
42 | pos += mask.sum()
43 | neg += (1 - mask).sum()
44 |
45 | return neg / pos
46 |
47 | def forward(self, x):
48 | return self.model(x)["out"]
49 |
50 | def training_step(self, batch, batch_idx):
51 | # Forward propagate and compute loss
52 | loss = self.criterion(self(batch[0]), batch[1])
53 | self.log("train_loss", loss, prog_bar=True)
54 | return loss
55 |
56 | def eval_step(self, batch):
57 | # Forward pass
58 | x, y = batch
59 | y_hat = self(x)
60 |
61 | # Compute the loss and the metrics
62 | self.val_loss.update(self.criterion(y_hat, y))
63 | self.metrics.update(y_hat.sigmoid(), y.long())
64 |
65 | def on_eval_epoch_end(self, prefix=""):
66 | # Compute total loss and metrics
67 | loss = self.val_loss.compute()
68 | metrics = self.metrics.compute()
69 |
70 | # Reset the loss and the metrics
71 | self.val_loss.reset()
72 | self.metrics.reset()
73 |
74 | # Log the loss and the metrics
75 | self.log(f"{prefix}_loss", loss, prog_bar=True)
76 | self.log(f"{prefix}_mcc", metrics["BinaryMatthewsCorrCoef"], prog_bar=True)
77 | self.log(f"{prefix}_f1", metrics["BinaryF1Score"], prog_bar=True)
78 | self.log(f"{prefix}_iou", metrics["BinaryJaccardIndex"], prog_bar=True)
79 |
80 | if not isinstance(opt := self.optimizers(), list):
81 | # Log the learning rate of a single optimizer
82 | self.log("lr", opt.param_groups[0]["lr"], prog_bar=True)
83 |
84 | def validation_step(self, batch, batch_idx):
85 | self.eval_step(batch)
86 |
87 | def on_validation_epoch_end(self):
88 | self.on_eval_epoch_end(prefix="val")
89 |
90 | def test_step(self, batch, batch_idx):
91 | self.eval_step(batch)
92 |
93 | def on_test_epoch_end(self):
94 | self.on_eval_epoch_end(prefix="test")
95 |
96 | def train_dataloader(self):
97 | return self.train_loader
98 |
99 | def val_dataloader(self):
100 | return self.val_loader
101 |
102 | def test_dataloader(self):
103 | return self.test_loader
104 |
105 | def configure_optimizers(self):
106 | # Initialize AdamW optimizer and Reduce On Plateau scheduler
107 | optimizer = AdamW(self.parameters(), lr=1e-3, weight_decay=1e-4)
108 | scheduler = ReduceLROnPlateau(
109 | optimizer=optimizer,
110 | factor=0.3,
111 | patience=15,
112 | threshold=0.01,
113 | min_lr=1e-6,
114 | )
115 |
116 | return {
117 | "optimizer": optimizer,
118 | "lr_scheduler": scheduler,
119 | "monitor": "val_loss",
120 | }
121 |
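The `pos_weight` property above weights the positive term of `BCEWithLogitsLoss` by the negative-to-positive pixel ratio of the training masks, compensating for the fact that glasses cover only a small fraction of each image. A toy illustration of the same computation on a single 8x8 mask (values are made up):

import torch
import torch.nn as nn

mask = torch.zeros(1, 1, 8, 8)
mask[..., 3:5, 2:6] = 1                   # 8 positive pixels out of 64

pos = mask.sum()                          # 8.0
neg = (1 - mask).sum()                    # 56.0
criterion = nn.BCEWithLogitsLoss(pos_weight=neg / pos)  # weights positives 7x

logits = torch.randn(1, 1, 8, 8)          # raw scores from the segmenter
loss = criterion(logits, mask)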
--------------------------------------------------------------------------------
/docs/conf.yaml:
--------------------------------------------------------------------------------
1 | build-finished:
2 | TYPE_ALIASES:
3 | FilePath: "glasses_detector.utils."
4 | Scalar: "glasses_detector.components.pred_type."
5 | Tensor: "glasses_detector.components.pred_type."
6 | Default: "glasses_detector.components.pred_type."
7 | StandardScalar: "glasses_detector.components.pred_type."
8 | StandardTensor: "glasses_detector.components.pred_type."
9 | StandardDefault: "glasses_detector.components.pred_type."
10 | NonDefault: "glasses_detector.components.pred_type."
11 | Either: "glasses_detector.components.pred_type."
12 |
13 | LONG_SIGNATURE_IDS:
14 | - "glasses_detector.components.pred_type.PredType"
15 | - "glasses_detector.components.pred_interface.PredInterface.process_file"
16 | - "glasses_detector.components.pred_interface.PredInterface.process_dir"
17 | - "glasses_detector.components.base_model.BaseGlassesModel"
18 | - "glasses_detector.components.base_model.BaseGlassesModel.predict"
19 | - "glasses_detector.components.base_model.BaseGlassesModel.process_dir"
20 | - "glasses_detector.components.base_model.BaseGlassesModel.process_file"
21 | - "glasses_detector.classifier.GlassesClassifier"
22 | - "glasses_detector.classifier.GlassesClassifier.draw_label"
23 | - "glasses_detector.classifier.GlassesClassifier.predict"
24 | - "glasses_detector.classifier.GlassesClassifier.process_dir"
25 | - "glasses_detector.classifier.GlassesClassifier.process_file"
26 | - "glasses_detector.detector.GlassesDetector"
27 | - "glasses_detector.detector.GlassesDetector.draw_boxes"
28 | - "glasses_detector.detector.GlassesDetector.predict"
29 | - "glasses_detector.detector.GlassesDetector.process_dir"
30 | - "glasses_detector.detector.GlassesDetector.process_file"
31 | - "glasses_detector.segmenter.GlassesSegmenter"
32 | - "glasses_detector.segmenter.GlassesSegmenter.draw_masks"
33 | - "glasses_detector.segmenter.GlassesSegmenter.predict"
34 | - "glasses_detector.segmenter.GlassesSegmenter.process_dir"
35 | - "glasses_detector.segmenter.GlassesSegmenter.process_file"
36 | - "glasses_detector.architectures.tiny_binary_detector.TinyBinaryDetector.forward"
37 | - "glasses_detector.architectures.tiny_binary_detector.TinyBinaryDetector.compute_loss"
38 |
39 | LONG_PARAMETER_IDS:
40 | glasses_detector.components.base_model.BaseGlassesModel.predict:
41 | - "format"
42 | glasses_detector.classifier.GlassesClassifier.predict:
43 | - "format"
44 | glasses_detector.detector.GlassesDetector.predict:
45 | - "format"
46 | glasses_detector.segmenter.GlassesSegmenter.predict:
47 | - "format"
48 | glasses_detector.detector.GlassesDetector.draw_boxes:
49 | - "colors"
50 | glasses_detector.segmenter.GlassesSegmenter.draw_masks:
51 | - "colors"
52 |
53 | CUSTOM_SYNTAX_COLORS_PYTHON:
54 | custom-highlight-class:
55 | - "GlassesClassifier"
56 | - "GlassesDetector"
57 | - "GlassesSegmenter"
58 | - "PredType"
59 | - "Image"
60 | - "type"
61 | - "np"
62 | - "str"
63 | - "int"
64 | - "bool"
65 | - "subprocess"
66 | - "random"
67 | - "@copy_signature"
68 | - "@eval_infer_mode"
69 | - "eval_infer_mode"
70 | custom-highlight-function:
71 | - "process_file"
72 | - "process_dir"
73 | - "run"
74 | - "load"
75 | - "array"
76 | - "zeros"
77 | - "fromarray"
78 | - "values"
79 | - "standardize"
80 | - "is_standard_scalar"
81 | - "is_default"
82 | - "reveal_type"
83 | - "full_signature"
84 | - "test_signature"
85 | custom-highlight-constant:
86 | - "StandardScalar"
87 | - "DEFAULT_SIZE_MAP"
88 | - "DEFAULT_KIND_MAP"
89 | custom-highlight-variable:
90 | - "format"
91 | - "self"
92 |
93 | CUSTOM_SYNTAX_COLORS_BASH:
94 | custom-highlight-start:
95 | - "cd"
96 | - "pip install"
97 | - "git clone"
98 | - "glasses-detector"
99 | custom-highlight-flag:
100 | - "-"
101 | custom-highlight-op:
102 | - "&&"
103 |
104 | SECTION_ICONS:
105 | Installation: "fas fa-download"
106 | Features: "fas fa-cogs"
107 | Examples: "fas fa-code"
108 | CLI: "fas fa-terminal"
109 | API: "fas fa-book"
110 | Credits: "fas fa-heart"
111 |
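This file is loaded by the documentation helpers (see `BuildFinished` in `docs/conf.py`). A rough sketch of reading it, assuming nothing about the actual implementation in `helpers/build_finished.py`:

import yaml

with open("docs/conf.yaml") as f:
    config = yaml.safe_load(f)["build-finished"]

# Resolve a documented type alias to its fully qualified name
prefix = config["TYPE_ALIASES"]["Scalar"]
print(prefix + "Scalar")  # glasses_detector.components.pred_type.Scalar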
--------------------------------------------------------------------------------
/src/glasses_detector/_wrappers/binary_classifier.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 | import torch.nn as nn
4 | import torchmetrics
5 | import tqdm
6 | from torch.optim import AdamW
7 | from torch.optim.lr_scheduler import ReduceLROnPlateau
8 |
9 |
10 | class BinaryClassifier(pl.LightningModule):
11 | def __init__(self, model, train_loader=None, val_loader=None, test_loader=None):
12 | super().__init__()
13 |
14 | # Assign attributes
15 | self.model = model
16 | self.train_loader = train_loader
17 | self.val_loader = val_loader
18 | self.test_loader = test_loader
19 |
20 | # Create loss function and account for imbalance of classes
21 | self.criterion = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)
22 | self.val_loss = torchmetrics.MeanMetric()
23 |
24 | # Create F1 score and ROC-AUC metrics to monitor
25 | self.metrics = torchmetrics.MetricCollection(
26 | [
27 | torchmetrics.F1Score(task="binary"),
28 | torchmetrics.AUROC(task="binary"), # ROC-AUC
29 | torchmetrics.AveragePrecision(task="binary"), # PR-AUC
30 | ]
31 | )
32 |
33 | @property
34 | def pos_weight(self):
35 | if self.train_loader is None:
36 | # Not known
37 | return None
38 |
39 | # Calculate the positive weight to account for class imbalance
40 | iterator = tqdm.tqdm(self.train_loader, desc="Computing pos_weight")
41 | pos_count = sum(y.sum().item() for _, y in iterator)
42 | neg_count = len(self.train_loader.dataset) - pos_count
43 |
44 | return torch.tensor(neg_count / pos_count)
45 |
46 | def forward(self, x):
47 | return self.model(x)
48 |
49 | def training_step(self, batch, batch_idx):
50 | # Forward propagate and compute loss
51 | loss = self.criterion(self(batch[0]), batch[1].to(torch.float32))
52 | self.log("train_loss", loss, prog_bar=True)
53 | return loss
54 |
55 | def eval_step(self, batch):
56 | # Forward pass
57 | x, y = batch
58 | y_hat = self(x)
59 |
60 | # Compute the loss and the metrics
61 | self.val_loss.update(self.criterion(y_hat, y.to(torch.float32)))
62 | self.metrics.update(y_hat.sigmoid(), y)
63 |
64 | def on_eval_epoch_end(self, prefix=""):
65 | # Compute total loss and metrics
66 | loss = self.val_loss.compute()
67 | metrics = self.metrics.compute()
68 |
69 | # Reset the metrics
70 | self.val_loss.reset()
71 | self.metrics.reset()
72 |
73 | # Log the loss and the metrics
74 | self.log(f"{prefix}_loss", loss, prog_bar=True)
75 | self.log(f"{prefix}_f1", metrics["BinaryF1Score"], prog_bar=True)
76 | self.log(f"{prefix}_roc_auc", metrics["BinaryAUROC"], prog_bar=True)
77 | self.log(f"{prefix}_pr_auc", metrics["BinaryAveragePrecision"], prog_bar=True)
78 |
79 | if not isinstance(opt := self.optimizers(), list):
80 | # Log the learning rate of a single optimizer
81 | self.log("lr", opt.param_groups[0]["lr"], prog_bar=True)
82 |
83 | def validation_step(self, batch, batch_idx):
84 | self.eval_step(batch)
85 |
86 | def on_validation_epoch_end(self):
87 | self.on_eval_epoch_end(prefix="val")
88 |
89 | def test_step(self, batch, batch_idx):
90 | self.eval_step(batch)
91 |
92 | def on_test_epoch_end(self):
93 | self.on_eval_epoch_end(prefix="test")
94 |
95 | def train_dataloader(self):
96 | return self.train_loader
97 |
98 | def val_dataloader(self):
99 | return self.val_loader
100 |
101 | def test_dataloader(self):
102 | return self.test_loader
103 |
104 | def configure_optimizers(self):
105 | # Initialize AdamW optimizer and Reduce On Plateau scheduler
106 | optimizer = AdamW(self.parameters(), lr=1e-3, weight_decay=1e-3)
107 | scheduler = ReduceLROnPlateau(
108 | optimizer=optimizer,
109 | factor=0.3,
110 | patience=15,
111 | threshold=0.01,
112 | min_lr=1e-6,
113 | )
114 |
115 | return {
116 | "optimizer": optimizer,
117 | "lr_scheduler": scheduler,
118 | "monitor": "val_loss",
119 | }
120 |
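The metric collection above is fed sigmoid probabilities and integer labels during evaluation, then computed and reset once per epoch. A standalone sketch of that update/compute/reset cycle on toy values:

import torch
import torchmetrics

metrics = torchmetrics.MetricCollection(
    [
        torchmetrics.F1Score(task="binary"),
        torchmetrics.AUROC(task="binary"),              # ROC-AUC
        torchmetrics.AveragePrecision(task="binary"),   # PR-AUC
    ]
)

probs = torch.tensor([0.9, 0.2, 0.7, 0.4])   # sigmoid outputs
labels = torch.tensor([1, 0, 1, 1])          # ground-truth labels

metrics.update(probs, labels)
results = metrics.compute()  # keys: BinaryF1Score, BinaryAUROC, BinaryAveragePrecision
metrics.reset()              # start accumulating for the next epoch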
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | import os
7 | import sys
8 | from pathlib import Path
9 |
10 | from sphinx.application import Sphinx
11 |
12 | sys.path += [
13 | str(Path(__file__).parent.parent / "src"),
14 | str(Path(__file__).parent),
15 | ]
16 |
17 | from helpers import BuildFinished, CustomInvs
18 |
19 | DOCS_DIR = Path(os.path.dirname(os.path.abspath(__file__)))
20 |
21 | # -- Project information -----------------------------------------------------
22 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
23 |
24 | project = "Glasses Detector"
25 | copyright = "2024, Mantas Birškus"
26 | author = "Mantas Birškus"
27 | release = "v1.0.4"
28 |
29 | # -- General configuration ---------------------------------------------------
30 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
31 |
32 | extensions = [
33 | "sphinx.ext.intersphinx",
34 | "sphinx.ext.todo",
35 | "sphinx.ext.napoleon",
36 | "sphinx.ext.viewcode",
37 | "sphinx.ext.autodoc",
38 | "sphinx.ext.autosectionlabel",
39 | "sphinx_copybutton",
40 | "sphinxcontrib.bibtex",
41 | "sphinx_design",
42 | ]
43 |
44 | intersphinx_mapping = {
45 | "python": ("https://docs.python.org/3/", None),
46 | "PIL": ("https://pillow.readthedocs.io/en/stable/", None),
47 | "numpy": ("https://numpy.org/doc/stable/", None),
48 | "torch": ("https://pytorch.org/docs/stable/", None),
49 | "torchvision": ("https://pytorch.org/vision/stable/", None),
50 | "matplotlib": ("https://matplotlib.org/stable/", None),
51 | "tqdm": ("https://tqdm.github.io/docs/", "_static/inv/tqdm.inv"),
52 | }
53 |
54 | # -- Options for napoleon/autosummary/autodoc output -------------------------
55 | napoleon_use_param = True
56 | autosummary_generate = True
57 | autodoc_typehints = "both"
58 | autodoc_member_order = "bysource"
59 |
60 | templates_path = ["_templates"]
61 | bibtex_bibfiles = ["_static/bib/references.bib"]
62 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
63 |
64 | # -- Options for HTML output -------------------------------------------------
65 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
66 |
67 | html_theme = "pydata_sphinx_theme"
68 | html_theme_options = {
69 | "logo": {
70 | "alt_text": "Glasses Detector - Home",
71 | "text": f"Glasses Detector {release}",
72 | "image_light": "_static/img/logo-light.png",
73 | "image_dark": "_static/img/logo-dark.png",
74 | },
75 | "icon_links": [
76 | {
77 | "name": "GitHub",
78 | "url": "https://github.com/mantasu/glasses-detector",
79 | "icon": "fa-brands fa-github",
80 | },
81 | {
82 | "name": "Colab",
83 | "url": (
84 | "https://colab.research.google.com/github/mantasu/glasses-detector/blob/main/notebooks/demo.ipynb"
85 | ),
86 | "icon": "fa-custom fa-colab",
87 | },
88 | {
89 | "name": "PyPI",
90 | "url": "https://pypi.org/project/glasses-detector",
91 | "icon": "fa-custom fa-pypi",
92 | },
93 | {
94 | "name": "Zenodo",
95 | "url": "https://zenodo.org/doi/10.5281/zenodo.8126101",
96 | "icon": "fa-custom fa-zenodo",
97 | },
98 | ],
99 | "show_toc_level": 2,
100 | "navigation_with_keys": False,
101 | "header_links_before_dropdown": 7,
102 | }
103 | html_context = {
104 | "github_user": "mantasu",
105 | "github_repo": "glasses-detector",
106 | "github_version": "main",
107 | "doc_path": "docs",
108 | }
109 | html_static_path = ["_static"]
110 | html_js_files = ["js/colab-icon.js", "js/pypi-icon.js", "js/zenodo-icon.js"]
111 | html_css_files = ["css/highlights.css", "css/signatures.css", "css/custom.css"]
112 | html_title = f"Glasses Detector {release}"
113 | html_favicon = "_static/img/logo-light.png"
114 |
115 | # -- Custom Template Functions -----------------------------------------------
116 | # https://www.sphinx-doc.org/en/master/development/theming.html#defining-custom-template-functions
117 |
118 |
119 | def setup(app: Sphinx):
120 | # Add local inventories to intersphinx_mapping
121 | custom_invs = CustomInvs(static_path=DOCS_DIR / "_static")
122 | app.config.intersphinx_mapping.update(custom_invs())
123 |
124 | # Add custom build-finished event
125 | build_finished = BuildFinished(DOCS_DIR / "_static", DOCS_DIR / "conf.yaml")
126 | app.connect("build-finished", build_finished)
127 |
--------------------------------------------------------------------------------
/docs/docs/cli.rst:
--------------------------------------------------------------------------------
1 | :fas:`terminal` CLI
2 | ===================
3 |
4 | .. role:: bash(code)
5 | :language: bash
6 | :class: highlight
7 |
8 | These flags allow you to define the kind of task and the model to process your image or a directory with images. Check out how to use them in :ref:`command-line`.
9 |
10 | .. option:: -i path/to/dir/or/file, --input path/to/dir/or/file
11 |
12 | Path to the input image or the directory with images.
13 |
14 | .. option:: -o path/to/dir/or/file, --output path/to/dir/or/file
15 |
16 |     Path to the output file or directory. If not provided and the input is a file, the prediction will be printed (or shown if it is an image); if the input is a directory, the predictions will be written to a directory with the same name and an added suffix ``_preds``. If provided as a file, the prediction(-s) will be saved to this file (supported extensions include ``.txt``, ``.csv``, ``.json``, ``.npy``, ``.pkl``, ``.jpg``, ``.png``). If provided as a directory, the predictions will be saved to this directory; use the :bash:`--extension` flag to specify the file extension of the saved predictions.
17 |
18 | **Default:** :py:data:`None`
19 |
20 | .. option:: -e , --extension
21 |
22 | Only used if :bash:`--output` is a directory. The extension to use to save the predictions as files. Common extensions include: ``.txt``, ``.csv``, ``.json``, ``.npy``, ``.pkl``, ``.jpg``, ``.png``. If not specified, it will be set automatically to ``.jpg`` for image predictions and to ``.txt`` for all other formats.
23 |
24 | **Default:** :py:data:`None`
25 |
26 | .. option:: -f , --format
27 |
28 | The format to use to map the raw prediction to.
29 |
30 | * For *classification*, common formats are ``bool``, ``proba``, ``str`` - check :meth:`GlassesClassifier.predict` for more details
31 | * For *detection*, common formats are ``bool``, ``int``, ``img`` - check :meth:`GlassesDetector.predict` for more details
32 | * For *segmentation*, common formats are ``proba``, ``img``, ``mask`` - check :meth:`GlassesSegmenter.predict` for more details
33 |
34 | If not specified, it will be set automatically to ``str``, ``img``, ``mask`` for *classification*, *detection*, *segmentation* respectively.
35 |
36 | **Default:** :py:data:`None`
37 |
38 | .. option:: -t , --task
39 |
40 | The kind of task the model should perform. One of
41 |
42 | * ``classification``
43 | * ``classification:anyglasses``
44 | * ``classification:sunglasses``
45 | * ``classification:eyeglasses``
46 | * ``classification:shadows``
47 | * ``detection``
48 | * ``detection:eyes``
49 | * ``detection:solo``
50 | * ``detection:worn``
51 | * ``segmentation``
52 | * ``segmentation:frames``
53 | * ``segmentation:full``
54 | * ``segmentation:legs``
55 | * ``segmentation:lenses``
56 | * ``segmentation:shadows``
57 | * ``segmentation:smart``
58 |
59 | If specified only as ``classification``, ``detection``, or ``segmentation``, the subcategories ``anyglasses``, ``worn``, and ``smart`` will be chosen, respectively.
60 |
61 | **Default:** ``classification:anyglasses``
62 |
63 | .. option:: -s , --size
64 |
65 | The model size which determines architecture type. One of ``small``, ``medium``, ``large`` (or ``s``, ``m``, ``l``).
66 |
67 | **Default:** ``medium``
68 |
69 | .. option:: -b , --batch-size
70 |
71 | Only used if :bash:`--input` is a directory. The batch size to use when processing the images. This groups the files in the input directory to batches of size ``batch_size`` before processing them. In some cases, larger batch sizes can speed up the processing at the cost of more memory usage.
72 |
73 | **Default:** ``1``
74 |
75 | .. option:: -p , --pbar
76 |
77 | Only used if :bash:`--input` is a directory. It is the description that is used for the progress bar. If specified as ``""`` (empty string), no progress bar is shown.
78 |
79 | **Default:** ``"Processing"``
80 |
81 | .. option:: -w path/to/weights.pth, --weights path/to/weights.pth
82 |
83 | Path to custom weights to load into the model. If not specified, weights will be loaded from the default location (and automatically downloaded there if needed).
84 |
85 | **Default:** :py:data:`None`
86 |
87 | .. option:: -d , --device
88 |
89 |     The device on which to perform inference. If not specified, it will be automatically checked whether CUDA or MPS is supported.
90 |
91 | **Default:** :py:data:`None`
92 |
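Putting a few of the flags above together, an illustrative invocation (paths are placeholders) could look like:

glasses-detector -i path/to/img.jpg -t classification:sunglasses -f proba
glasses-detector -i path/to/dir -t segmentation:full -f mask -e .png -b 16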
--------------------------------------------------------------------------------
/scripts/analyse.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import sys
4 | import tempfile
5 | from typing import Any
6 |
7 | import torch
8 | from fvcore.nn import FlopCountAnalysis
9 | from prettytable import PrettyTable
10 | from torch.profiler import ProfilerActivity, profile
11 |
12 | PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13 | sys.path.append(os.path.join(PROJECT_DIR, "src"))
14 | torch.set_float32_matmul_precision("medium")
15 |
16 | from glasses_detector import GlassesClassifier, GlassesDetector, GlassesSegmenter
17 |
18 |
19 | def check_filesize(model: torch.nn.Module):
20 | with tempfile.NamedTemporaryFile() as temp:
21 | torch.save(model.state_dict(), temp.name)
22 | return os.path.getsize(temp.name) / (1024**2)
23 |
24 |
25 | def check_num_params(model: torch.nn.Module):
26 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
27 |
28 |
29 | def check_flops(model: torch.nn.Module, input: Any):
30 | flops = FlopCountAnalysis(model, (input,))
31 | return flops.total()
32 |
33 |
34 | def check_ram(model: torch.nn.Module, input: Any):
35 | # Clean up
36 | gc.collect()
37 |
38 | with profile(
39 | activities=[ProfilerActivity.CPU],
40 | profile_memory=True,
41 | record_shapes=True,
42 | ) as prof:
43 | # Run model
44 | model(input)
45 |
46 | # Sum the memory usage of all the operations on the CPU
47 | memory_usage = sum(item.cpu_memory_usage for item in prof.key_averages())
48 |
49 | return memory_usage / (1024**2)
50 |
51 |
52 | def check_vram(model: torch.nn.Module, input: Any):
53 | # Clean up
54 | torch.cuda.empty_cache()
55 | torch.cuda.synchronize()
56 |
57 | # Check the allocated memory before
58 | mem_before = torch.cuda.max_memory_allocated()
59 |
60 | if isinstance(input, list):
61 | # Detection models require a list
62 | input = [torch.tensor(i, device="cuda") for i in input]
63 | else:
64 | # Otherwise a single input is enough
65 | input = torch.tensor(input, device="cuda")
66 |
67 | # Run the model on CUDA
68 | model = model.to("cuda")
69 | model(input)
70 | torch.cuda.synchronize()
71 |
72 | # Check the allocated memory after
73 | mem_after = torch.cuda.max_memory_allocated()
74 |
75 | # Clean up
76 | model.to("cpu")
77 | torch.cuda.empty_cache()
78 |
79 | return (mem_after - mem_before) / (1024**2)
80 |
81 |
82 | def analyse(model_cls, task):
83 | # Create a table
84 | table = PrettyTable()
85 | field_names = [
86 | "Model size",
87 | "Filesize (MB)",
88 | "Num params",
89 | "FLOPS",
90 | "RAM (MB)",
91 | ]
92 |
93 | if torch.cuda.is_available():
94 | # If CUDA is available, add VRAM
95 | field_names.append("VRAM (MB)")
96 |
97 | # Set the field names
98 | table.field_names = field_names
99 |
100 | if task == "detection":
101 | # Detection models require a list
102 | input = [*torch.randn(1, 3, 256, 256)]
103 | else:
104 | # Otherwise a single input is enough
105 | input = torch.randn(1, 3, 256, 256)
106 |
107 | for size in ["small", "medium", "large"]:
108 | # Clean up
109 | gc.collect()
110 |
111 | # Load the model without pre-trained weights on the CPU
112 | model = model_cls(size=size, weights=False, device="cpu").model
113 | model.eval()
114 |
115 | # Check the basic stats
116 | filesize = check_filesize(model)
117 | num_params = check_num_params(model)
118 | flops = check_flops(model, input)
119 | ram = check_ram(model, input)
120 |
121 | # Stats
122 | row = [
123 | size,
124 | f"{filesize:.2f}",
125 | f"{num_params:,}",
126 | f"{flops:,}",
127 | f"{ram:.2f}",
128 | ]
129 |
130 | if torch.cuda.is_available():
131 | # If CUDA is available, check VRAM
132 | vram = check_vram(model, input)
133 | row.append(f"{vram:.2f}")
134 |
135 | # Add the stats row
136 | table.add_row(row)
137 |
138 | # Print table
139 | print(table)
140 |
141 |
142 | def parse_args():
143 | import argparse
144 |
145 | parser = argparse.ArgumentParser()
146 | parser.add_argument(
147 | "--task",
148 | "-t",
149 | type=str,
150 | required=True,
151 | choices=["classification", "detection", "segmentation"],
152 | )
153 | return parser.parse_args()
154 |
155 |
156 | def main():
157 | args = parse_args()
158 |
159 | match args.task:
160 | case "classification":
161 | model_cls = GlassesClassifier
162 | case "detection":
163 | model_cls = GlassesDetector
164 | case "segmentation":
165 | model_cls = GlassesSegmenter
166 |
167 | analyse(model_cls, args.task)
168 |
169 |
170 | if __name__ == "__main__":
171 | main()
172 |
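The helper functions above can also be used on their own. A quick sketch on a throwaway model (numbers are illustrative, not benchmark results):

import torch

tiny = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.Flatten())
tiny.eval()
x = torch.randn(1, 3, 32, 32)

print(f"{check_filesize(tiny):.3f} MB on disk")        # saved state_dict size
print(f"{check_num_params(tiny):,} trainable params")
print(f"{check_flops(tiny, x):,} FLOPs")               # fvcore FlopCountAnalysis
print(f"{check_ram(tiny, x):.2f} MB of CPU memory")    # torch.profiler estimate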
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | mantix7@gmail.com.
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series
86 | of actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or
93 | permanent ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within
113 | the community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.0, available at
119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120 |
121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct
122 | enforcement ladder](https://github.com/mozilla/diversity).
123 |
124 | [homepage]: https://www.contributor-covenant.org
125 |
126 | For answers to common questions about this code of conduct, see the FAQ at
127 | https://www.contributor-covenant.org/faq. Translations are available at
128 | https://www.contributor-covenant.org/translations.
129 |
--------------------------------------------------------------------------------
/src/glasses_detector/_wrappers/metrics.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import override
3 |
4 | import torch
5 | import torchmetrics
6 | from scipy.optimize import linear_sum_assignment
7 | from torchmetrics.functional import mean_squared_log_error, r2_score
8 | from torchvision.ops import box_iou
9 |
10 |
11 | class BoxMetric(torchmetrics.Metric):
12 | def __init__(self, name="sum", is_min=False, dist_sync_on_step=False):
13 | super().__init__(dist_sync_on_step=dist_sync_on_step)
14 | self.add_state(name, default=torch.tensor(0.0), dist_reduce_fx="sum")
15 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
16 |
17 | # Select min/max mode
18 | self.is_min = is_min
19 | self.name = name
20 |
21 | @abstractmethod
22 | def compute_matrix(
23 | self,
24 | preds: torch.Tensor,
25 | targets: torch.Tensor,
26 | ) -> torch.Tensor: ...
27 |
28 | def update(
29 | self,
30 | preds: list[dict[str, torch.Tensor]],
31 | targets: list[dict[str, torch.Tensor]],
32 | *classification_metrics,
33 | ):
34 | # Initialize flat target labels and predicted labels
35 | pred_l = torch.empty(0, dtype=torch.long, device=preds[0]["labels"].device)
36 | target_l = torch.empty(0, dtype=torch.long, device=preds[0]["labels"].device)
37 |
38 | for pred, target in zip(preds, targets):
39 | if len(pred["boxes"]) == 0:
40 | pred["boxes"] = torch.tensor(
41 | [[0, 0, 0, 0]], dtype=torch.float32, device=pred["boxes"].device
42 | )
43 | pred["labels"] = torch.tensor(
44 | [0], dtype=torch.long, device=pred["labels"].device
45 | )
46 |
47 | if len(target["boxes"]) == 0:
48 | target["boxes"] = torch.tensor(
49 | [[0, 0, 0, 0]], dtype=torch.float32, device=target["boxes"].device
50 | )
51 | target["labels"] = torch.tensor(
52 | [0], dtype=torch.long, device=target["labels"].device
53 | )
54 |
55 | # Compute the matrix of similarities and select best
56 | similarities = self.compute_matrix(pred["boxes"], target["boxes"])
57 | cost_matrix = similarities if self.is_min else 1 - similarities
58 | pred_idx, target_idx = linear_sum_assignment(cost_matrix.cpu())
59 | best_sims = similarities[pred_idx, target_idx]
60 |
61 | # Add the labels of matched predictions
62 | target_l = torch.cat([target_l, target["labels"][target_idx]])
63 | pred_l = torch.cat([pred_l, pred["labels"][pred_idx]])
64 |
65 | if (remain := list(set(range(len(pred["boxes"]))) - set(pred_idx))) != []:
66 | # Add the labels of unmatched predictions, set as bg
67 | pred_l = torch.cat([pred_l, pred["labels"][remain]])
68 | padded = torch.tensor(len(remain) * [0], device=target_l.device)
69 | target_l = torch.cat([target_l, padded])
70 |
71 | if (
72 | remain := list(set(range(len(target["boxes"]))) - set(target_idx))
73 | ) != []:
74 |                 # Add the labels of unmatched targets, set preds as bg
75 |                 target_l = torch.cat([target_l, target["labels"][remain]])
76 |                 padded = torch.tensor(len(remain) * [0], device=pred_l.device)
77 |                 pred_l = torch.cat([pred_l, padded])
78 |
79 | # Update the bbox metric
80 | setattr(self, self.name, getattr(self, self.name) + best_sims.sum())
81 | self.total += max(len(pred["boxes"]), len(target["boxes"]))
82 |
83 | for metric in classification_metrics:
84 | # Update classification metrics
85 | metric.update(pred_l, target_l)
86 |
87 | def compute(self):
88 | return getattr(self, self.name) / self.total
89 |
90 |
91 | class BoxIoU(BoxMetric):
92 | def __init__(self, **kwargs):
93 | kwargs["name"] = "iou_sum"
94 | super().__init__(**kwargs)
95 |
96 | @override
97 | def compute_matrix(
98 | self,
99 | preds: torch.Tensor,
100 | targets: torch.Tensor,
101 | ) -> torch.Tensor:
102 | return box_iou(preds, targets)
103 |
104 |
105 | class BoxClippedR2(BoxMetric):
106 | def __init__(self, **kwargs):
107 | # Get r2 kwargs, remove them from kwargs
108 | r2_args = r2_score.__code__.co_varnames
109 | self.r2_kwargs = {k: v for k, v in kwargs.items() if k in r2_args}
110 | kwargs = {k: v for k, v in kwargs.items() if k not in r2_args}
111 | kwargs["name"] = "r2_sum"
112 | super().__init__(**kwargs)
113 |
114 | @override
115 | def compute_matrix(
116 | self,
117 | preds: torch.Tensor,
118 | targets: torch.Tensor,
119 | ) -> torch.Tensor:
120 | return torch.tensor(
121 | [
122 | [
123 | max(0, r2_score(p.view(-1), t.view(-1), **self.r2_kwargs))
124 | for t in targets
125 | ]
126 | for p in preds
127 | ]
128 | )
129 |
130 |
131 | class BoxMSLE(BoxMetric):
132 | def __init__(self, **kwargs):
133 | kwargs.setdefault("is_min", True)
134 | kwargs["name"] = "msle_sum"
135 | super().__init__(**kwargs)
136 |
137 | @override
138 | def compute_matrix(
139 | self,
140 | preds: torch.Tensor,
141 | targets: torch.Tensor,
142 | ) -> torch.Tensor:
143 | return torch.tensor(
144 | [
145 | [mean_squared_log_error(p.view(-1), t.view(-1)) for t in targets]
146 | for p in preds
147 | ]
148 | )
149 |
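These metrics first match predicted boxes to target boxes with the Hungarian algorithm (`linear_sum_assignment`) and then average the per-pair similarity. A toy update with a single image containing one predicted and one target box:

import torch

metric = BoxIoU()

preds = [
    {
        "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
        "labels": torch.tensor([1]),
    }
]
targets = [
    {
        "boxes": torch.tensor([[20.0, 20.0, 60.0, 60.0]]),
        "labels": torch.tensor([1]),
    }
]

metric.update(preds, targets)
print(metric.compute())  # mean IoU of matched pairs, about 0.39 here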
--------------------------------------------------------------------------------
/src/glasses_detector/_data/base_categorized_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from abc import ABC, abstractmethod
4 | from collections import defaultdict
5 | from functools import cached_property
6 | from os.path import basename, splitext
7 | from typing import Any, Callable
8 |
9 | import torch
10 | from torch.utils.data import DataLoader, Dataset
11 |
12 | from .augmenter_mixin import AugmenterMixin
13 |
14 |
15 | class BaseCategorizedDataset(ABC, Dataset, AugmenterMixin):
16 | def __init__(
17 | self,
18 | root: str = ".",
19 | split_type: str = "train",
20 | img_folder: str = "images",
21 | label_type: str = "enum",
22 |         cat2idx_fn: Callable[[str], int] | dict[str, int] | None = None,
23 | pth2idx_fn: Callable[[str], str] = lambda x: splitext(basename(x))[0],
24 | seed: int = 0,
25 | ):
26 | super().__init__()
27 |
28 | self.label_type = label_type
29 | self.img_folder = img_folder
30 | self.data = defaultdict(lambda: {})
31 | self.cats = []
32 |
33 | for dataset in os.listdir(root):
34 | if not os.path.isdir(p := os.path.join(root, dataset, split_type)):
35 | # No split
36 | continue
37 |
38 | for cat in os.scandir(p):
39 | if cat.name != img_folder and cat.name not in self.cats:
40 | # Expand category list
41 | self.cats.append(cat.name)
42 |
43 | for file in os.scandir(cat.path):
44 | # Add image/annotation path under file name as key
45 | self.data[pth2idx_fn(file.path)][cat.name] = file.path
46 |
47 | # Shuffle only values (sort first for reproducibility)
48 | self.data = [v for _, v in sorted(self.data.items())]
49 | random.seed(seed)
50 | random.shuffle(self.data)
51 |
52 | # Sort cats as well
53 | self.cats = sorted(
54 | self.cats,
55 | key=(
56 | None
57 | if cat2idx_fn is None
58 | else cat2idx_fn.get if isinstance(cat2idx_fn, dict) else cat2idx_fn
59 | ),
60 | )
61 |
62 | # Create a default transformation
63 | self.transform = self.create_transform(split_type == "train")
64 | self.__post_init__()
65 |
66 | @cached_property
67 | def cat2idx(self) -> dict[str, int]:
68 | return dict(zip(self.cats, range(len(self.cats))))
69 |
70 | @cached_property
71 | def idx2cat(self) -> dict[int, str]:
72 | return dict(zip(range(len(self.cats)), self.cats))
73 |
74 | @classmethod
75 | def create_loader(cls, **kwargs) -> DataLoader:
76 | # Get argument names from DataLoader
77 | fn_code = DataLoader.__init__.__code__
78 | init_arg_names = fn_code.co_varnames[: fn_code.co_argcount]
79 |
80 | # Split all the given kwargs to dataset (cls) and loader kwargs
81 | set_kwargs = {k: v for k, v in kwargs.items() if k not in init_arg_names}
82 | ldr_kwargs = {k: v for k, v in kwargs.items() if k in init_arg_names}
83 |
84 | # Define default loader kwargs
85 | default_loader_kwargs = {
86 | "dataset": cls(**set_kwargs),
87 | "batch_size": 64,
88 | "num_workers": 12,
89 | "pin_memory": True,
90 | "drop_last": True,
91 | "shuffle": set_kwargs.get("split_type", "train") == "train",
92 | }
93 |
94 | # Update default loader kwargs with custom
95 | default_loader_kwargs.update(ldr_kwargs)
96 |
97 | return DataLoader(**default_loader_kwargs)
98 |
99 | @classmethod
100 | def create_loaders(cls, **kwargs) -> tuple[DataLoader, DataLoader, DataLoader]:
101 |         # Create train, validation, and test loaders
102 | train_loader = cls.create_loader(split_type="train", **kwargs)
103 | val_loader = cls.create_loader(split_type="val", **kwargs)
104 | test_loader = cls.create_loader(split_type="test", **kwargs)
105 |
106 | return train_loader, val_loader, test_loader
107 |
108 | def cat2tensor(self, cat: str | list[str]) -> torch.Tensor:
109 | # Convert category name(-s) to the index list
110 | cat = cat if isinstance(cat, list) else [cat]
111 | indices = list(map(self.cat2idx.get, cat))
112 |
113 | match self.label_type:
114 | case "enum":
115 | label = torch.tensor(indices)
116 | case "onehot":
117 | label = torch.eye(len(self.cats))[indices]
118 | case "multihot":
119 | label = torch.any(torch.eye(len(self.cats))[indices], 0, True)
120 | case _:
121 | raise ValueError(f"Unknown label type: {self.label_type}")
122 |
123 | return label.to(torch.long)
124 |
125 | def tensor2cat(self, tensor: torch.Tensor) -> str | list[str] | list[list[str]]:
126 | match self.label_type:
127 | case "enum":
128 | # Add a batch dimension if tensor is a scalar, get cats
129 | ts = tensor.unsqueeze(0) if tensor.ndim == 0 else tensor
130 | cat = [self.idx2cat[i.item()] for i in ts]
131 | return cat[0] if tensor.ndim == 0 or len(tensor) == 0 else cat
132 | case "onehot":
133 | # Get cats directly (works for both 1D and 2D tensors)
134 | cat = [self.idx2cat[i.item()] for i in torch.where(tensor)[0]]
135 | return cat[0] if tensor.ndim == 1 else cat
136 | case "multihot":
137 | # Add a batch dimension if tensor is a 1D list, get cats
138 | ts = tensor if tensor.ndim > 1 else tensor.unsqueeze(0)
139 |                 cat = [[self.idx2cat[i.item()] for i in torch.where(t)[0]] for t in ts]
140 | return cat[0] if tensor.ndim == 1 else cat
141 | case _:
142 | raise ValueError(f"Unknown label type: {self.label_type}")
143 |
144 | def __len__(self) -> int:
145 | return len(self.data)
146 |
147 | @abstractmethod
148 | def __getitem__(self, index: int) -> Any: ...
149 |
150 | def __post_init__(self):
151 | pass
152 |
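The `label_type` option controls how `cat2tensor` encodes category names; the core of the encoding is just indexing into an identity matrix. A standalone illustration with three example categories:

import torch

cats = ["eyeglasses", "no_glasses", "sunglasses"]
cat2idx = {c: i for i, c in enumerate(cats)}

indices = [cat2idx["sunglasses"]]
enum = torch.tensor(indices)                                  # tensor([2])
onehot = torch.eye(len(cats))[indices]                        # [[0., 0., 1.]]

indices = [cat2idx["eyeglasses"], cat2idx["sunglasses"]]
multihot = torch.any(torch.eye(len(cats))[indices], 0, True)  # [[True, False, True]]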
--------------------------------------------------------------------------------
/docs/_static/js/zenodo-icon.js:
--------------------------------------------------------------------------------
1 | FontAwesome.library.add(
2 | (faListOldStyle = {
3 | prefix: "fa-custom",
4 | iconName: "zenodo",
5 | icon: [
6 | 146.355, // viewBox width
7 | 47.955, // viewBox height
8 | [], // ligature
9 | "e001", // unicode codepoint - private use area
10 | "M145.301,18.875c-0.705-1.602-1.656-2.997-2.846-4.19c-1.189-1.187-2.584-2.125-4.188-2.805 c-1.604-0.678-3.307-1.02-5.102-1.02c-1.848,0-3.564,0.342-5.139,1.02c-0.787,0.339-1.529,0.74-2.225,1.205 c-0.701,0.469-1.357,1.003-1.967,1.6c-0.377,0.37-0.727,0.761-1.051,1.17c-0.363,0.457-0.764,1.068-0.992,1.439 c-0.281,0.456-0.957,1.861-1.254,2.828c0.041-1.644,0.281-4.096,1.254-5.472V2.768c0-0.776-0.279-1.431-0.84-1.965 C120.396,0.268,119.75,0,119.021,0c-0.777,0-1.43,0.268-1.969,0.803c-0.531,0.534-0.801,1.189-0.801,1.965v10.569 c-1.117-0.778-2.322-1.386-3.605-1.824c-1.285-0.436-2.637-0.654-4.045-0.654c-1.799,0-3.496,0.342-5.1,1.02 c-1.605,0.679-3,1.618-4.195,2.805c-1.186,1.194-2.139,2.588-2.836,4.19c-0.053,0.12-0.1,0.242-0.15,0.364 c-0.047-0.122-0.094-0.244-0.146-0.364c-0.705-1.602-1.656-2.997-2.846-4.19c-1.189-1.187-2.586-2.125-4.188-2.805 c-1.604-0.678-3.307-1.02-5.102-1.02c-1.848,0-3.564,0.342-5.139,1.02c-1.584,0.679-2.979,1.618-4.191,2.805 c-1.213,1.194-2.164,2.588-2.842,4.19c-0.049,0.115-0.092,0.23-0.137,0.344c-0.047-0.114-0.092-0.229-0.141-0.344 c-0.701-1.602-1.65-2.997-2.84-4.19c-1.191-1.187-2.588-2.125-4.193-2.805c-1.604-0.678-3.301-1.02-5.104-1.02 c-1.842,0-3.557,0.342-5.137,1.02c-1.578,0.679-2.977,1.618-4.186,2.805c-1.221,1.194-2.166,2.588-2.848,4.19 c-0.043,0.106-0.082,0.214-0.125,0.32c-0.043-0.106-0.084-0.214-0.131-0.32c-0.707-1.602-1.656-2.997-2.848-4.19 c-1.188-1.187-2.582-2.125-4.184-2.805c-1.605-0.678-3.309-1.02-5.104-1.02c-1.85,0-3.564,0.342-5.137,1.02 c-1.467,0.628-2.764,1.488-3.91,2.552V13.99c0-1.557-1.262-2.822-2.82-2.822H3.246c-1.557,0-2.82,1.265-2.82,2.822 c0,1.559,1.264,2.82,2.82,2.82h15.541L0.557,41.356C0.195,41.843,0,42.433,0,43.038v1.841c0,1.558,1.264,2.822,2.822,2.822 h21.047c1.488,0,2.705-1.153,2.812-2.614c0.932,0.743,1.967,1.364,3.109,1.848c1.605,0.684,3.299,1.021,5.102,1.021 c2.723,0,5.15-0.726,7.287-2.187c1.727-1.176,3.092-2.639,4.084-4.389v3.805c0,0.778,0.264,1.436,0.805,1.968 c0.531,0.537,1.189,0.803,1.967,0.803c0.73,0,1.369-0.266,1.93-0.803c0.561-0.532,0.838-1.189,0.838-1.968v-9.879h-0.01 c0-0.002,0.01-0.013,0.01-0.013s-6.137,0-6.912,0c-0.58,0-1.109,0.154-1.566,0.472c-0.463,0.316-0.793,0.744-0.982,1.275 l-0.453,0.93c-0.631,1.365-1.566,2.443-2.809,3.244c-1.238,0.803-2.633,1.201-4.188,1.201c-1.023,0-2.004-0.191-2.955-0.579 c-0.941-0.39-1.758-0.935-2.439-1.64c-0.682-0.703-1.227-1.52-1.641-2.443c-0.41-0.924-0.617-1.893-0.617-2.916v-2.476h17.715 h1.309h5.539v-8.385c0-1.015,0.191-1.99,0.582-2.912c0.389-0.922,0.936-1.74,1.645-2.444c0.699-0.703,1.514-1.249,2.441-1.641 c0.918-0.388,1.92-0.581,2.982-0.581c1.023,0,2.01,0.193,2.955,0.581c0.945,0.393,1.762,0.938,2.439,1.641 c0.682,0.704,1.225,1.521,1.641,2.444c0.412,0.922,0.621,1.896,0.621,2.912v21.208c0,0.778,0.266,1.436,0.799,1.968 c0.535,0.537,1.191,0.803,1.971,0.803c0.729,0,1.371-0.266,1.934-0.803c0.553-0.532,0.834-1.189,0.834-1.968v-3.803 c0.588,1.01,1.283,1.932,2.1,2.749c1.189,1.189,2.586,2.124,4.191,2.804c1.602,0.684,3.303,1.021,5.102,1.021 c1.795,0,3.498-0.337,5.102-1.021c1.602-0.68,3.01-1.614,4.227-2.804c1.211-1.19,2.162-2.589,2.842-4.189 c0.037-0.095,0.074-0.19,0.109-0.286c0.039,0.096,0.074,0.191,0.113,0.286c0.678,1.601,1.625,2.999,2.842,4.189 c1.213,1.189,2.607,2.124,4.189,2.804c1.574,0.684,3.293,1.021,5.139,1.021c1.795,0,3.5-0.337,5.105-1.021 c1.6-0.68,2.994-1.614,4.184-2.804c1.191-1.19,2.141-2.589,2.848-4.189c0.051-0.12,0.098-0.239,0.146-0.36 c0.049,0.121,0.094,0.24,0.146,0.36c0.703,1.601,1.652,2.999,2.842,4.189c1.189,1.189,2.586,2.124,4.191,2.804 
c1.604,0.684,3.303,1.021,5.102,1.021c1.795,0,3.498-0.337,5.102-1.021c1.604-0.68,3.01-1.614,4.227-2.804 c1.211-1.19,2.16-2.589,2.842-4.189c0.678-1.606,1.02-3.306,1.02-5.104v-10.86C146.355,22.182,146.002,20.479,145.301,18.875z M7.064,42.06l14.758-19.874c-0.078,0.587-0.121,1.184-0.121,1.791v10.86c0,1.799,0.35,3.498,1.059,5.104 c0.328,0.752,0.719,1.458,1.156,2.119c-0.016,0-0.031-0.001-0.047-0.001H7.064z M42.541,26.817H27.24v-2.841 c0-1.015,0.189-1.99,0.58-2.912c0.391-0.922,0.936-1.74,1.645-2.444c0.697-0.703,1.516-1.249,2.438-1.641 c0.922-0.388,1.92-0.581,2.99-0.581c1.02,0,2.002,0.193,2.949,0.581c0.949,0.393,1.764,0.938,2.441,1.641 c0.682,0.704,1.225,1.521,1.641,2.444c0.414,0.922,0.617,1.896,0.617,2.912V26.817z M91.688,34.837 c0,1.023-0.189,1.992-0.582,2.916c-0.389,0.924-0.936,1.74-1.637,2.443c-0.705,0.705-1.523,1.25-2.445,1.64 c-0.92,0.388-1.92,0.579-2.984,0.579c-1.023,0-2.004-0.191-2.955-0.579c-0.945-0.39-1.758-0.935-2.439-1.64 c-0.682-0.703-1.229-1.52-1.641-2.443s-0.617-1.893-0.617-2.916v-10.86c0-1.015,0.191-1.99,0.582-2.912 c0.387-0.922,0.934-1.74,1.639-2.444c0.701-0.703,1.52-1.249,2.441-1.641c0.922-0.388,1.92-0.581,2.99-0.581 c1.018,0,2.004,0.193,2.947,0.581c0.951,0.393,1.764,0.938,2.443,1.641c0.68,0.704,1.223,1.521,1.641,2.444 c0.412,0.922,0.617,1.896,0.617,2.912V34.837z M116.252,34.837c0,1.023-0.203,1.992-0.617,2.916 c-0.412,0.924-0.961,1.74-1.641,2.443c-0.68,0.705-1.492,1.25-2.443,1.64c-0.943,0.388-1.93,0.579-2.949,0.579 c-1.07,0-2.066-0.191-2.988-0.579c-0.924-0.39-1.74-0.935-2.439-1.64c-0.707-0.703-1.252-1.52-1.643-2.443 s-0.584-1.893-0.584-2.916v-10.86c0-1.015,0.211-1.99,0.619-2.912c0.416-0.922,0.961-1.74,1.641-2.444 c0.682-0.703,1.496-1.249,2.439-1.641c0.951-0.388,1.934-0.581,2.955-0.581c1.068,0,2.062,0.193,2.986,0.581 c0.926,0.393,1.738,0.938,2.443,1.641c0.703,0.704,1.252,1.521,1.641,2.444c0.389,0.922,0.58,1.896,0.58,2.912V34.837z M140.816,34.837c0,1.023-0.193,1.992-0.58,2.916c-0.393,0.924-0.939,1.74-1.641,2.443c-0.705,0.705-1.523,1.25-2.443,1.64 c-0.922,0.388-1.92,0.579-2.986,0.579c-1.021,0-2.004-0.191-2.955-0.579c-0.943-0.39-1.758-0.935-2.438-1.64 c-0.682-0.703-1.23-1.52-1.643-2.443s-0.619-1.893-0.619-2.916v-10.86c0-1.015,0.193-1.99,0.584-2.912 c0.387-0.922,0.934-1.74,1.639-2.444c0.703-0.703,1.518-1.249,2.441-1.641c0.924-0.388,1.92-0.581,2.99-0.581 c1.02,0,2.004,0.193,2.949,0.581c0.949,0.393,1.764,0.938,2.441,1.641c0.682,0.704,1.225,1.521,1.643,2.444 c0.412,0.922,0.617,1.896,0.617,2.912V34.837z", // svg path (https://github.com/zenodo/zenodo/blob/482ee72ad501cbbd7f8ce8df9b393c130d1970f7/zenodo/modules/theme/static/img/zenodo.svg)
11 | ],
12 | })
13 | );
--------------------------------------------------------------------------------
/src/glasses_detector/architectures/tiny_binary_detector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class TinyBinaryDetector(nn.Module):
6 | """Tiny binary detector.
7 |
 8 |     This is a custom detector created with the aim of containing very
 9 |     few parameters while maintaining reasonable accuracy. It only has
10 | several sequential convolutional and pooling blocks (with
11 | batch-norm in between).
12 |
13 | Note:
14 | I tried varying the architecture, including activations,
15 | convolution behavior (groups and stride), pooling, and layer
16 | structure. This also includes residual and dense connections,
17 | as well as combinations. Turns out, they do not perform as well
18 | as the current architecture which is just a bunch of
19 | CONV-RELU-BN-MAXPOOL blocks with no paddings.
20 | """
21 |
22 | def __init__(self):
23 | super().__init__()
24 |
25 | # Several convolutional blocks
26 | self.features = nn.Sequential(
27 | self._create_block(3, 6, 15),
28 | self._create_block(6, 12, 7),
29 | self._create_block(12, 24, 5),
30 | self._create_block(24, 48, 3),
31 | self._create_block(48, 96, 3),
32 | self._create_block(96, 192, 3),
33 | nn.AdaptiveAvgPool2d((1, 1)),
34 | nn.Flatten(),
35 | )
36 |
37 | # Fully connected layer
38 | self.fc = nn.Linear(192, 4)
39 |
40 | def _create_block(self, num_in, num_out, filter_size):
41 | return nn.Sequential(
42 | nn.Conv2d(num_in, num_out, filter_size, 1, "valid", bias=False),
43 | nn.ReLU(),
44 | nn.BatchNorm2d(num_out),
45 | nn.MaxPool2d(2, 2),
46 | )
47 |
48 | def forward(
49 | self,
50 | imgs: list[torch.Tensor],
51 | targets: list[dict[str, torch.Tensor]] | None = None,
52 | ) -> dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]:
53 | """Forward pass through the network.
54 |
55 | This takes a list of images and returns a list of predictions,
56 | one for each image, or a loss dictionary if targets are provided.
57 | This is to match the API of the PyTorch *torchvision* models,
58 | which specify that:
59 |
60 | "During training, returns a dictionary containing the
61 | classification and regression losses for each image in the
62 | batch. During inference, returns a list of dictionaries, one
63 | for each input image. Each dictionary contains the predicted
64 | boxes, labels, and scores for all detections in the image."
65 |
66 | Args:
67 | imgs (list[torch.Tensor]): A list of images.
68 | targets (list[dict[str, torch.Tensor]], optional): A
69 | list of annotations for each image. Each annotation is a
70 | dictionary that contains:
71 | 
72 | 1. ``"boxes"``: the bounding boxes for each object
73 | 2. ``"labels"``: the labels for all objects in the image
74 | 
75 | If ``None``, the network is in inference mode.
76 |
77 |
78 | Returns:
79 | dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]:
80 | A dictionary mapping each image index to its regression loss if
81 | ``targets`` were specified. Otherwise, a list of
82 | dictionaries with the predicted bounding boxes, labels, and
83 | scores for all detections in each image.
84 | """
85 | # Forward pass: one normalized (x_min, y_min, w, h) box per image
86 | preds = self.fc(self.features(torch.stack(imgs)))
87 |
88 | # Get width and height
89 | h, w = imgs[0].shape[-2:]
90 |
91 | # Convert to (x_min, y_min, x_max, y_max)
92 | preds[:, 0] = preds[:, 0] * w
93 | preds[:, 1] = preds[:, 1] * h
94 | preds[:, 2] = preds[:, 0] + preds[:, 2] * w
95 | preds[:, 3] = preds[:, 1] + preds[:, 3] * h
96 |
97 | if targets is None:
98 | # Clamp the coordinates to the image size
99 | preds[:, 0] = torch.clamp(preds[:, 0], 0, w)
100 | preds[:, 1] = torch.clamp(preds[:, 1], 0, h)
101 | preds[:, 2] = torch.clamp(preds[:, 2], 0, w)
102 | preds[:, 3] = torch.clamp(preds[:, 3], 0, h)
103 |
104 | # Convert to a list of N tensors of shape (1, 4)
105 | preds = [*preds[:, None, :]]
106 |
107 | if targets is not None:
108 | return self.compute_loss(preds, targets, imgs[0].size()[-2:])
109 | else:
110 | return [
111 | {
112 | "boxes": pred,
113 | "labels": torch.ones(1, dtype=torch.int64, device=pred.device),
114 | "scores": torch.ones(1, device=pred.device),
115 | }
116 | for pred in preds
117 | ]
118 |
119 | def compute_loss(
120 | self,
121 | preds: list[torch.Tensor],
122 | targets: list[dict[str, torch.Tensor]],
123 | size: tuple[int, int],
124 | ) -> dict[str, torch.Tensor]:
125 | """Compute the loss for the predicted bounding boxes.
126 |
127 | This computes the MSE loss between the predicted bounding boxes
128 | and the target bounding boxes. The returned dictionary maps each
129 | image index to its regression MSE loss.
130 |
131 | Args:
132 | preds (list[torch.Tensor]): A list of predicted bounding boxes, one per image.
133 | targets (list[dict[str, torch.Tensor]]): A list of targets
134 | for each image.
135 | size (tuple[int, int]): The height and width of the images.
136 |
137 | Returns:
138 | dict[str, torch.Tensor]: A dictionary mapping each image
139 | index to its regression MSE loss.
140 | """
141 | # Initialize criterion, loss dictionary, and device
142 | criterion, loss_dict, device = nn.MSELoss(), {}, preds[0].device
143 |
144 | # Use to divide (x_min, y_min, x_max, y_max) by (w, h, w, h)
145 | size = torch.tensor([[*size[::-1], *size[::-1]]], device=device)
146 |
147 | for i, pred in enumerate(preds):
148 | # Compute the loss (normalize the coordinates before that)
149 | loss = criterion(pred / size, targets[i]["boxes"][:1] / size)
150 | loss_dict[i] = loss
151 |
152 | return loss_dict
153 |
--------------------------------------------------------------------------------
/docs/docs/examples.rst:
--------------------------------------------------------------------------------
1 | :fas:`code` Examples
2 | ====================
3 |
4 | .. role:: bash(code)
5 | :language: bash
6 | :class: highlight
7 |
8 | .. _command-line:
9 |
10 | Command Line
11 | ------------
12 |
13 | You can run predictions via the command line. For example, classification of a single image or multiple images can be performed via:
14 |
15 | .. code-block:: bash
16 |
17 | glasses-detector -i path/to/img.jpg --task classification # Prints "present" or "absent"
18 | glasses-detector -i path/to/dir --output path/to/output.csv # Creates CSV (default --task is classification)
19 |
20 |
21 | It is possible to specify the **kind** of the :bash:`--task` using the format :bash:`task:kind`; for example, we may want to classify only *sunglasses* (glasses with opaque lenses). Further options can also be specified, such as :bash:`--format`, :bash:`--size`, :bash:`--batch-size`, :bash:`--device`, etc.:
22 |
23 | .. code-block:: bash
24 |
25 | glasses-detector -i path/to/img.jpg -t classification:sunglasses -f proba # Prints probability of sunglasses
26 | glasses-detector -i path/to/dir -o preds.pkl -s large -b 64 -d cuda # Fast and accurate processing
27 |
28 | Running *detection* and *segmentation* is similar, though we may want to generate a folder of predictions when processing a directory (but we can also squeeze all the predictions into a single file, such as ``.npy``):
29 |
30 | .. code-block:: bash
31 |
32 | glasses-detector -i path/to/img.jpg -t detection # Shows image with bounding boxes
33 | glasses-detector -i path/to/dir -t segmentation -f mask -e .jpg # Generates dir with masks
34 |
35 | .. tip::
36 |
37 | For a more exhaustive explanation of the available options use :bash:`glasses-detector --help` or check :doc:`cli`.
38 |
39 |
40 | Python Script
41 | -------------
42 |
43 | The most straightforward way to perform a prediction on a single file (or a list of files) is to use :meth:`~glasses_detector.components.pred_interface.PredInterface.process_file`. Although the prediction(-s) can be saved to a file or a directory, in most cases this method is useful for immediately showing the prediction result(-s).
44 |
45 | .. code-block:: python
46 | :linenos:
47 |
48 | from glasses_detector import GlassesClassifier, GlassesDetector
49 |
50 | # Prints either '1' or '0'
51 | classifier = GlassesClassifier()
52 | classifier.process_file(
53 | input_path="path/to/img.jpg", # can be a list of paths
54 | format={True: "1", False: "0"}, # similar to format="int"
55 | show=True, # to print the prediction
56 | )
57 |
58 | # Opens a plot in a new window
59 | detector = GlassesDetector()
60 | detector.process_file(
61 | input_path="path/to/img.jpg", # can be a list of paths
62 | format="img", # to return the image with drawn bboxes
63 | show=True, # to show the image using matplotlib
64 | )
65 |
66 | A more useful method is :meth:`~glasses_detector.components.pred_interface.PredInterface.process_dir` which goes through all the images in a directory and saves the predictions to a single file or a directory of files. Also note how we can specify the task ``kind`` and the model ``size``:
67 |
68 | .. code-block:: python
69 | :linenos:
70 |
71 | from glasses_detector import GlassesClassifier, GlassesSegmenter
72 |
73 | # Generates a CSV file with image paths and labels
74 | classifier = GlassesClassifier(kind="sunglasses")
75 | classifier.process_dir(
76 | input_path="path/to/dir", # failed files will raise a warning
77 | output_path="path/to/output.csv", # img_name1.jpg,...
78 | format="proba", # is a probability of sunglasses
79 | pbar="Processing", # set to None to disable
80 | )
81 |
82 | # Generates a directory with masks
83 | segmenter = GlassesSegmenter(size="large", device="cuda")
84 | segmenter.process_dir(
85 | input_path="path/to/dir", # output dir defaults to path/to/dir_preds
86 | ext=".jpg", # saves each mask in JPG format
87 | format="mask", # output type will be a grayscale PIL image
88 | batch_size=32, # to speed up the processing
89 | output_size=(512, 512), # set to None to keep the same size as image
90 | )
91 |
92 |
93 | It is also possible to directly use :meth:`~glasses_detector.components.pred_interface.PredInterface.predict` which allows processing already loaded images. This is useful when you want to incorporate the prediction into a custom pipeline.
94 |
95 | .. code-block:: python
96 | :linenos:
97 |
98 | import numpy as np
99 | from glasses_detector import GlassesDetector
100 |
101 | # Predicts normalized bounding boxes
102 | detector = GlassesDetector()
103 | predictions = detector(
104 | image=np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8),
105 | format="float",
106 | )
107 | print(type(predictions), len(predictions))
108 |
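The same method works for the other models; below is a minimal sketch with the segmenter, following the input and ``format`` conventions shown above (the random array simply stands in for a real image):

.. code-block:: python
    :linenos:

    import numpy as np
    from glasses_detector import GlassesSegmenter

    # Predicts a grayscale mask (PIL image) for an already loaded image
    segmenter = GlassesSegmenter()
    mask = segmenter(
        image=np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8),
        format="mask",
    )
    print(type(mask))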
109 |
110 | .. admonition:: Refer to API documentation for model-specific examples
111 |
112 | * :class:`~glasses_detector.classifier.GlassesClassifier` and its :meth:`~glasses_detector.classifier.GlassesClassifier.predict`
113 | * :class:`~glasses_detector.detector.GlassesDetector` and its :meth:`~glasses_detector.detector.GlassesDetector.predict`
114 | * :class:`~glasses_detector.segmenter.GlassesSegmenter` and its :meth:`~glasses_detector.segmenter.GlassesSegmenter.predict`
115 |
116 | Demo
117 | ----
118 |
119 | Feel free to play around with some `demo image files `_. For example, after installing through `pip `_, you can run:
120 |
121 | .. code-block:: bash
122 |
123 | git clone https://github.com/mantasu/glasses-detector && cd glasses-detector/data
124 | glasses-detector -i demo -o demo_labels.csv --task classification:sunglasses -f proba
125 | glasses-detector -i demo -o demo_masks -t segmentation:full -f img -e .jpg
126 |
127 | Alternatively, you can check out the `demo notebook `_ which can also be accessed on `Google Colab `_.
128 |
--------------------------------------------------------------------------------
/src/glasses_detector/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | .. class:: FilePath
3 |
4 | .. data:: FilePath
5 | :noindex:
6 | :type: typing.TypeAliasType
7 | :value: str | bytes | os.PathLike
8 |
9 | Type alias for a file path.
10 |
11 | :class:`str` | :class:`bytes` | :class:`os.PathLike`
12 | """
13 |
14 | import functools
15 | import os
16 | import typing
17 | from typing import Any, Callable, Iterable, TypeGuard, overload
18 | from urllib.parse import urlparse
19 |
20 | import torch
21 | from PIL import Image
22 |
23 | type FilePath = str | bytes | os.PathLike
24 |
25 |
26 | class copy_signature[**P, T]:
27 | """Decorator to copy a function's or a method's signature.
28 |
29 | This decorator takes a callable and copies its signature to the
30 | decorated function or method.
31 |
32 | Example
33 | -------
34 |
35 | .. code-block:: python
36 |
37 | def full_signature(x: bool, *extra: int) -> str: ...
38 |
39 | @copy_signature(full_signature)
40 | def test_signature(*args, **kwargs):
41 | return full_signature(*args, **kwargs)
42 |
43 | reveal_type(test_signature) # 'def (x: bool, *extra: int) -> str'
44 |
45 | .. seealso::
46 |
47 | https://github.com/python/typing/issues/270#issuecomment-1344537820
48 |
49 | Args:
50 | source (typing.Callable[P, T]): The callable whose signature to
51 | copy.
52 | """
53 |
54 | def __init__(self, source: Callable[P, T]):
55 | # The source callable
56 | self.source = source
57 |
58 | def __call__(self, target: Callable[..., T]) -> Callable[P, T]:
59 | @functools.wraps(self.source)
60 | def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
61 | return target(*args, **kwargs)
62 |
63 | return wrapped
64 |
65 |
66 | class eval_infer_mode:
67 | """Context manager and decorator for evaluation and inference.
68 |
69 | This class can be used as a context manager or a decorator to set a
70 | PyTorch :class:`~torch.nn.Module` to evaluation mode via
71 | :meth:`~torch.nn.Module.eval` and enable
72 | :class:`~torch.no_grad` for the duration of a function
73 | or a ``with`` statement. After the function or the ``with``
74 | statement, the model's mode, i.e., :attr:`~torch.nn.Module.training`
75 | property, and :class:`~torch.no_grad` are restored to their
76 | original states.
77 |
78 | Example
79 | -------
80 |
81 | .. code-block:: python
82 |
83 | model = ... # Your PyTorch model
84 |
85 | @eval_infer_mode(model)
86 | def your_function():
87 | # E.g., forward pass
88 | pass
89 |
90 | # or
91 |
92 | with eval_infer_mode(model):
93 | # E.g., forward pass
94 | pass
95 |
96 | Args:
97 | model (torch.nn.Module): The PyTorch model to be set to
98 | evaluation mode.
99 | """
100 |
101 | def __init__(self, model: torch.nn.Module):
102 | self.model = model
103 | self.was_training = model.training
104 |
105 | def __call__[F](self, func: F) -> F:
106 | @functools.wraps(func)
107 | def wrapper(*args, **kwargs):
108 | with self:
109 | return func(*args, **kwargs)
110 |
111 | return wrapper
112 |
113 | def __enter__(self):
114 | self.model.eval()
115 | self._no_grad = torch.no_grad()
116 | self._no_grad.__enter__()
117 |
118 | def __exit__(self, type: Any, value: Any, traceback: Any):
119 | self._no_grad.__exit__(type, value, traceback)
120 | self.model.train(self.was_training)
121 |
122 |
123 | def is_url(x: str) -> bool:
124 | """Check if a string is a valid URL.
125 |
126 | Takes any string and checks if it is a valid URL.
127 |
128 | .. seealso::
129 |
130 | https://stackoverflow.com/a/38020041
131 |
132 | Args:
133 | x: The string to check.
134 |
135 | Returns:
136 | :data:`True` if the string is a valid URL, :data:`False`
137 | otherwise.
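
Example
-------

For instance, a URL needs both a scheme and a network location:

.. code-block:: python

    is_url("https://github.com/mantasu/glasses-detector")  # True
    is_url("not a url")                                     # False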
138 | """
139 | try:
140 | result = urlparse(x)
141 | return all([result.scheme, result.netloc])
142 | except Exception:
143 | return False
144 |
145 |
146 | @overload
147 | def flatten[T](items: T) -> T: ...
148 |
149 |
150 | @overload
151 | def flatten[T](items: typing.Iterable[T | typing.Iterable]) -> list[T]: ...
152 |
153 |
154 | def flatten[T](items: T | Iterable[T | Iterable]) -> T | list[T]:
155 | """Flatten a nested list.
156 |
157 | This function takes any nested iterable and returns a flat list.
158 |
159 | Args:
160 | items (T | typing.Iterable[T | typing.Iterable]): The nested
161 | iterable to flatten.
162 |
163 | Returns:
164 | T | list[T]: The flattened list or the original ``items`` value
165 | if it is not an iterable or is of type :class:`str`.
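
Example
-------

A few illustrative calls:

.. code-block:: python

    flatten([[1, 2], [3, [4, 5]]])  # [1, 2, 3, 4, 5]
    flatten("abc")                  # "abc" (strings are not flattened)
    flatten(42)                     # 42 (non-iterables are returned as-is)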
166 | """
167 | if not isinstance(items, Iterable) or isinstance(items, str):
168 | # Not iterable
169 | return items
170 |
171 | # Init flat list
172 | flattened = []
173 |
174 | for item in items:
175 | if isinstance(item, Iterable) and not isinstance(item, str):
176 | flattened.extend(flatten(item))
177 | else:
178 | flattened.append(item)
179 |
180 | return flattened
181 |
182 |
183 | def is_path_type(path: Any) -> TypeGuard[FilePath]:
184 | """Check if an object is a valid path type.
185 |
186 | This function takes any object and checks if it is a valid path
187 | type. A valid path type is either a :class:`str`, :class:`bytes` or
188 | :class:`os.PathLike` object.
189 |
190 | Args:
191 | path: The object to check.
192 |
193 | Returns:
194 | :data:`True` if the object is a valid path type, :data:`False`
195 | otherwise.
196 | """
197 | return isinstance(path, (str, bytes, os.PathLike))
198 |
199 |
200 | def is_image_file(path: FilePath) -> bool:
201 | """Check if a file is an image.
202 |
203 | This function takes a file path and checks if it is an image file.
204 | This is done by checking if the file exists and if it can be
205 | identified as an image
206 |
207 | Args:
208 | path: The path to the file.
209 |
210 | Returns:
211 | :data:`True` if the file is an image, :data:`False` otherwise.
212 | """
213 | if not os.path.isfile(path):
214 | return False
215 |
216 | try:
217 | with Image.open(path) as img:
218 | img.verify()
219 | return True
220 | except (IOError, SyntaxError):
221 | return False
222 |
--------------------------------------------------------------------------------
/scripts/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import pytorch_lightning as pl
5 | import torch
6 | from pytorch_lightning.callbacks import ModelCheckpoint
7 | from pytorch_lightning.cli import LightningCLI
8 | from pytorch_lightning.tuner import Tuner
9 |
10 | PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
11 | DEFAULT_KINDS = {
12 | "classification": "anyglasses",
13 | "detection": "worn",
14 | "segmentation": "smart",
15 | }
16 |
17 | sys.path.append(os.path.join(PROJECT_DIR, "src"))
18 | torch.set_float32_matmul_precision("medium")
19 |
20 | from glasses_detector import GlassesClassifier, GlassesDetector, GlassesSegmenter
21 | from glasses_detector._data import (
22 | BinaryClassificationDataset,
23 | BinaryDetectionDataset,
24 | BinarySegmentationDataset,
25 | )
26 | from glasses_detector._wrappers import BinaryClassifier, BinaryDetector, BinarySegmenter
27 |
28 |
29 | class RunCLI(LightningCLI):
30 | def add_arguments_to_parser(self, parser):
31 | # Add args for wrapper creation
32 | parser.add_argument(
33 | "-r",
34 | "--root",
35 | metavar="path/to/data/root",
36 | type=str,
37 | default=os.path.join(PROJECT_DIR, "data"),
38 | help="Path to the data directory with classification, detection, and segmentation subdirectories that contain datasets for different kinds of tasks. Defaults to 'data' under project root.",
39 | )
40 | parser.add_argument(
41 | "-t",
42 | "--task",
43 | metavar="",
44 | type=str,
45 | default="classification:anyglasses",
46 | choices=(ch := [
47 | "classification",
48 | "classification:anyglasses",
49 | "classification:sunglasses",
50 | "classification:eyeglasses",
51 | "classification:shadows",
52 | "detection",
53 | "detection:eyes",
54 | "detection:solo",
55 | "detection:worn",
56 | "segmentation",
57 | "segmentation:frames",
58 | "segmentation:full",
59 | "segmentation:legs",
60 | "segmentation:lenses",
61 | "segmentation:shadows",
62 | "segmentation:smart",
63 | ]),
64 | help=f"The kind of task to train/test the model for. One of {", ".join([f"'{c}'" for c in ch])}. If specified only as 'classification', 'detection', or 'segmentation', the subcategories 'anyglasses', 'worn', and 'smart' will be chosen, respectively. Defaults to 'classification:anyglasses'.",
65 | )
66 | parser.add_argument(
67 | "-s",
68 | "--size",
69 | metavar="",
70 | type=str,
71 | default="medium",
72 | choices=["small", "medium", "large"],
73 | help="The model size which determines architecture type. One of 'small', 'medium', 'large'. Defaults to 'medium'.",
74 | )
75 | parser.add_argument(
76 | "-b",
77 | "--batch-size",
78 | metavar="",
79 | type=int,
80 | default=64,
81 | help="The batch size used for training. Defaults to 64.",
82 | )
83 | parser.add_argument(
84 | "-n",
85 | "--num-workers",
86 | metavar="",
87 | type=int,
88 | default=8,
89 | help="The number of workers for the data loader. Defaults to 8.",
90 | )
91 | parser.add_argument(
92 | "-w",
93 | "--weights",
94 | metavar="path/to/weights",
95 | type=str | None,
96 | default=None,
97 | help="Path to weights to load into the model. Defaults to None.",
98 | )
99 | parser.add_argument(
100 | "-f",
101 | "--find-lr",
102 | action="store_true",
103 | help="Whether to run the learning rate finder before training. Defaults to False.",
104 | )
105 | parser.add_lightning_class_args(ModelCheckpoint, "checkpoint")
106 |
107 | # Checkpoint and trainer defaults
108 | parser.set_defaults(
109 | {
110 | "checkpoint.dirpath": "checkpoints",
111 | "checkpoint.save_last": False,
112 | "checkpoint.monitor": "val_loss",
113 | "checkpoint.mode": "min",
114 | "trainer.precision": "bf16-mixed",
115 | "trainer.max_epochs": 300,
116 | }
117 | )
118 |
119 | # Link argument with wrapper creation callback arguments
120 | parser.link_arguments("root", "model.root", apply_on="parse")
121 | parser.link_arguments("task", "model.task", apply_on="parse")
122 | parser.link_arguments("size", "model.size", apply_on="parse")
123 | parser.link_arguments("batch_size", "model.batch_size", apply_on="parse")
124 | parser.link_arguments("num_workers", "model.num_workers", apply_on="parse")
125 | parser.link_arguments("weights", "model.weights", apply_on="parse")
126 |
127 | def before_fit(self):
128 | if self.config.fit.checkpoint.filename is None:
129 | # Update default filename for checkpoint saver callback
130 | self.model_name = (
131 | self.config.fit.model.task.replace(":", "-") + "-" + self.config.fit.model.size
132 | )
133 | self.trainer.callbacks[-1].filename = (
134 | self.model_name + "-{epoch:02d}-{val_loss:.3f}"
135 | )
136 |
137 | if self.config.fit.find_lr:
138 | # Run learning rate finder
139 | tuner = Tuner(self.trainer)
140 | lr_finder = tuner.lr_find(self.model, min_lr=1e-5, max_lr=1e-1, num_training=500)
141 | self.model.lr = lr_finder.suggestion()
142 |
143 | def after_fit(self):
144 | # Get the best checkpoint path and load it
145 | ckpt_path = self.trainer.callbacks[-1].best_model_path
146 | ckpt_dir = os.path.dirname(ckpt_path)
147 | ckpt = torch.load(ckpt_path)
148 |
149 | # Load weights and save the inner model as .pth
150 | self.model.load_state_dict(ckpt["state_dict"])
151 | torch.save(
152 | self.model.model.state_dict(),
153 | os.path.join(ckpt_dir, self.model_name + ".pth"),
154 | )
155 |
156 | def create_wrapper_callback(
157 | root: str = "data",
158 | task: str = "classification",
159 | size: str = "medium",
160 | batch_size: int = 64,
161 | num_workers: int = 8,
162 | weights: str | None = None,
163 | ) -> pl.LightningModule:
164 |
165 | # Get task and kind
166 | task_and_kind = task.split(":")
167 | task = task_and_kind[0]
168 | kind = DEFAULT_KINDS[task] if len(task_and_kind) == 1 else task_and_kind[1]
169 |
170 | # Get model and dataset classes
171 | model_cls, data_cls = {
172 | "classification": (GlassesClassifier, BinaryClassificationDataset),
173 | "detection": (GlassesDetector, BinaryDetectionDataset),
174 | "segmentation": (GlassesSegmenter, BinarySegmentationDataset),
175 | }[task]
176 |
177 | # Set-up wrapper initialization kwargs
178 | kwargs = {
179 | "root": os.path.join(root, task, kind),
180 | "batch_size": batch_size,
181 | "num_workers": num_workers,
182 | }
183 |
184 | # Update wrapper initialization kwargs and set the initializer class
185 | if task == "classification":
186 | kwargs["cat2idx_fn"] = {"no_" + kind: 0, kind: 1}
187 | wrapper_cls = BinaryClassifier
188 | elif task == "detection":
189 | wrapper_cls = BinaryDetector
190 | elif task == "segmentation":
191 | wrapper_cls = BinarySegmenter
192 |
193 | # Initialize model architecture and load weights if needed
194 | model = model_cls(kind=kind, size=size, weights=weights).model
195 |
196 | return wrapper_cls(model, *data_cls.create_loaders(**kwargs))
197 |
198 |
199 | def cli_main():
200 | cli = RunCLI(create_wrapper_callback, seed_everything_default=0)
201 |
202 |
203 | if __name__ == "__main__":
204 | cli_main()
205 |
--------------------------------------------------------------------------------
/src/glasses_detector/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import inspect
3 | import os
4 |
5 | import torch
6 |
7 | from . import GlassesClassifier, GlassesDetector, GlassesSegmenter
8 | from .components import BaseGlassesModel
9 |
10 |
11 | def parse_kwargs():
12 | parser = argparse.ArgumentParser(
13 | prog="Glasses Detector",
14 | description=f"Classification, detection, and segmentation of glasses "
15 | f"in images.",
16 | )
17 |
18 | parser.add_argument(
19 | "-i",
20 | "--input",
21 | metavar="path/to/dir/or/file",
22 | type=str,
23 | required=True,
24 | help="Path to the input image or the directory with images.",
25 | )
26 | parser.add_argument(
27 | "-o",
28 | "--output",
29 | metavar="path/to/dir/or/file",
30 | type=str,
31 | default=None,
32 | help=f"Path to the output file or the directory. If not provided, "
33 | f"then, if input is a file, the prediction will be printed (or shown "
34 | f"if it is an image), otherwise, if input is a directory, the "
35 | f"predictions will be written to a directory with the same name with "
36 | f"an added suffix '_preds'. If provided as a file, then the "
37 | f"prediction(-s) will be saved to this file (supported extensions "
38 | f"include: .txt, .csv, .json, .npy, .pkl, .jpg, .png). If provided as "
39 | f"a directory, then the predictions will be saved to this directory "
40 | f"use `--extension` flag to specify the file extensions in that "
41 | f"directory. Defaults to None.",
42 | )
43 | parser.add_argument(
44 | "-e",
45 | "--extension",
46 | metavar="",
47 | type=str,
48 | default=None,
49 | choices=(ch := [".txt", ".csv", ".json", ".npy", ".pkl", ".jpg", ".png"]),
50 | help=f"Only used if `--output` is a directory. The extension to "
51 | f"use to save the predictions as files. Common extensions include: "
52 | f"{", ".join([f"{c}" for c in ch])}. If not specified, it will be set "
53 | f"automatically to .jpg for image predictions and to .txt for all "
54 | f"other formats. Defaults to None.",
55 | )
56 | parser.add_argument(
57 | "-f",
58 | "--format",
59 | metavar="",
60 | type=str,
61 | default=None,
62 | help=f"The format to use to map the raw prediction to. For "
63 | f"classification, common formats are bool, proba, str, for detection, "
64 | f"common formats are bool, int, img, for segmentation, common formats "
65 | f"are proba, img, mask. If not specified, it will be set "
66 | f"automatically to str, img, mask for classification, detection, "
67 | f"segmentation respectively. Check API documentation for more "
68 | f"details. Defaults to None.",
69 | )
70 | parser.add_argument(
71 | "-t",
72 | "--task",
73 | metavar="",
74 | type=str,
75 | default="classification:anyglasses",
76 | choices=(ch := [
77 | "classification",
78 | "classification:anyglasses",
79 | "classification:sunglasses",
80 | "classification:eyeglasses",
81 | "classification:shadows",
82 | "detection",
83 | "detection:eyes",
84 | "detection:solo",
85 | "detection:worn",
86 | "segmentation",
87 | "segmentation:frames",
88 | "segmentation:full",
89 | "segmentation:legs",
90 | "segmentation:lenses",
91 | "segmentation:shadows",
92 | "segmentation:smart",
93 | ]),
94 | help=f"The kind of task the model should perform. One of "
95 | f"{", ".join([f"{c}" for c in ch])}. If specified only as "
96 | f"classification, detection, or segmentation, the subcategories "
97 | f"anyglasses, worn, and smart will be chosen, respectively. Defaults "
98 | f"to classification:anyglasses.",
99 | )
100 | parser.add_argument(
101 | "-s",
102 | "--size",
103 | metavar="",
104 | type=str,
105 | default="medium",
106 | choices=["small", "medium", "large", "s", "m", "l"],
107 | help=f"The model size which determines architecture type. One of "
108 | f"'small', 'medium', 'large' (or 's', 'm', 'l'). Defaults to 'medium'.",
109 | )
110 | parser.add_argument(
111 | "-b",
112 | "--batch-size",
113 | metavar="",
114 | type=int,
115 | default=1,
116 | help=f"Only used if `--input` is a directory. The batch size to "
117 | f"use when processing the images. This groups the files in the input "
118 | f"directory to batches of size `batch_size` before processing them. "
119 | f"In some cases, larger batch sizes can speed up the processing at "
120 | f"the cost of more memory usage. Defaults to 1."
121 | )
122 | parser.add_argument(
123 | "-p",
124 | "--pbar",
125 | type=str,
126 | metavar="",
127 | default="Processing",
128 | help=f"Only used if `--input` is a directory. It is the "
129 | f"description that is used for the progress bar. If specified "
130 | f"as '' (empty string), no progress bar is shown. Defaults to "
131 | f"'Processing'.",
132 | )
133 | parser.add_argument(
134 | "-w",
135 | "--weights",
136 | metavar="path/to/weights.pth",
137 | type=str,
138 | default=None,
139 | help=f"Path to custom weights to load into the model. If not "
140 | f"specified, weights will be loaded from the default location (and "
141 | f"automatically downloaded there if needed). Defaults to None.",
142 | )
143 | parser.add_argument(
144 | "-d",
145 | "--device",
146 | type=str,
147 | metavar="",
148 | default=None,
149 | help=f"The device on which to perform inference. If not specified, it "
150 | f"will be automatically checked if CUDA or MPS is supported. "
151 | f"Defaults to None.",
152 | )
153 |
154 | return vars(parser.parse_args())
155 |
156 | def prepare_kwargs(kwargs: dict[str, str | int | None]):
157 | # Define the keys to use when calling process and init methods
158 | model_keys = inspect.getfullargspec(BaseGlassesModel.__init__).args
159 | process_keys = [key for key in kwargs.keys() if key not in model_keys]
160 | process_keys += ["ext", "show", "is_file", "input_path", "output_path"]
161 |
162 | # Add "is_file" key to check which process method to call
163 | kwargs["is_file"] = os.path.splitext(kwargs["input"])[-1] != ""
164 | kwargs["ext"] = kwargs.pop("extension")
165 | kwargs["input_path"] = kwargs.pop("input")
166 | kwargs["output_path"] = kwargs.pop("output")
167 |
168 | if not kwargs["is_file"] and kwargs["output_path"] is None:
169 | # Input is a directory but no output path is specified
170 | kwargs["output_path"] = os.path.splitext(kwargs["input_path"])[0] + "_preds"
171 |
172 | if kwargs["is_file"] and kwargs["output_path"] is None:
173 | # Input is a file and no output path is specified
174 | kwargs["show"] = True
175 |
176 | if kwargs["pbar"] == "":
177 | # No progress bar
178 | kwargs["pbar"] = None
179 |
180 | if kwargs["weights"] is None:
181 | # Use default weights
182 | kwargs["weights"] = True
183 |
184 | if len(splits := kwargs["task"].split(":")) == 2:
185 | # Task is specified as "task:kind"
186 | kwargs["task"] = splits[0]
187 | kwargs["kind"] = splits[1]
188 |
189 | if kwargs["format"] is None and kwargs["task"] == "classification":
190 | # Default format for classification
191 | kwargs["format"] = "str"
192 | elif kwargs["format"] is None and kwargs["task"] == "detection":
193 | # Default format for detection
194 | kwargs["format"] = "img"
195 | elif kwargs["format"] is None and kwargs["task"] == "segmentation":
196 | # Default format for segmentation
197 | kwargs["format"] = "mask"
198 |
199 | # Get the kwargs for the process and init methods
200 | process_kwargs = {k: kwargs[k] for k in process_keys if k in kwargs}
201 | model_kwargs = {k: kwargs[k] for k in model_keys if k in kwargs}
202 |
203 | return process_kwargs, model_kwargs
204 |
205 |
206 | def main():
207 | # Parse CLI args; prepare to create model and process images
208 | process_kwargs, model_kwargs = prepare_kwargs(parse_kwargs())
209 | is_file = process_kwargs.pop("is_file")
210 | task = model_kwargs.pop("task")
211 |
212 | # Create model
213 | match task:
214 | case "classification":
215 | model = GlassesClassifier(**model_kwargs)
216 | case "detection":
217 | model = GlassesDetector(**model_kwargs)
218 | case "segmentation":
219 | model = GlassesSegmenter(**model_kwargs)
220 | case _:
221 | raise ValueError(f"Unknown task '{task}'.")
222 |
223 | if is_file:
224 | # Process a single image file
225 | process_kwargs.pop("batch_size")
226 | process_kwargs.pop("pbar")
227 | model.process_file(**process_kwargs)
228 | else:
229 | # Process a directory of images
230 | model.process_dir(**process_kwargs)
231 |
232 | if __name__ == "__main__":
233 | main()
234 |
--------------------------------------------------------------------------------
/docs/helpers/generate_examples.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from PIL import Image
4 |
5 | sys.path.append("../src")
6 |
7 | from glasses_detector import GlassesDetector, GlassesSegmenter
8 |
9 | SAMPLES = {
10 | "classification": {
11 | "eyeglasses": "data/classification/eyeglasses/sunglasses-glasses-detect/test/eyeglasses/face-149_png.rf.cc420d484a00dd7158550510785f0d51.jpg",
12 | "sunglasses": "data/classification/sunglasses/face-attributes-grouped/test/sunglasses/2014-Colored-Mirror-Sunglasses-2014-Renkli-ve-Aynal-Caml-Gunes-Gozlugu-Modelleri-19.jpg",
13 | "no-glasses": "data/classification/eyeglasses/face-attributes-extra/test/no_eyeglasses/800px_COLOURBOX4590613.jpg",
14 | },
15 | "segmentation-frames": [
16 | {
17 | "img": "data/segmentation/frames/eyeglass/val/images/Woman-wearing-white-rimmed-thin-spectacles_jpg.rf.097bc1a8d872b6acddc3312e95a2d48e.jpg",
18 | "msk": "data/segmentation/frames/eyeglass/val/masks/Woman-wearing-white-rimmed-thin-spectacles_jpg.rf.097bc1a8d872b6acddc3312e95a2d48e.jpg",
19 | },
20 | {
21 | "img": "data/segmentation/frames/glasses-segmentation-synthetic/test/images/img-Glass001-386-2_mouth_open-3-missile_launch_facility_01-037.jpg",
22 | "msk": "data/segmentation/frames/glasses-segmentation-synthetic/test/masks/img-Glass001-386-2_mouth_open-3-missile_launch_facility_01-037.jpg",
23 | },
24 | {
25 | "img": "data/segmentation/frames/glasses-segmentation-synthetic/test/images/img-Glass021-450-4_brow_lower-2-entrance_hall-204.jpg",
26 | "msk": "data/segmentation/frames/glasses-segmentation-synthetic/test/masks/img-Glass021-450-4_brow_lower-2-entrance_hall-204.jpg",
27 | },
28 | ],
29 | "segmentation-full": [
30 | {
31 | "img": "data/segmentation/full/celeba-mask-hq/test/images/821.jpg",
32 | "msk": "data/segmentation/full/celeba-mask-hq/test/masks/821.jpg",
33 | },
34 | {
35 | "img": "data/segmentation/full/celeba-mask-hq/test/images/1034.jpg",
36 | "msk": "data/segmentation/full/celeba-mask-hq/test/masks/1034.jpg",
37 | },
38 | {
39 | "img": "data/segmentation/full/celeba-mask-hq/test/images/2442.jpg",
40 | "msk": "data/segmentation/full/celeba-mask-hq/test/masks/2442.jpg",
41 | },
42 | ],
43 | "segmentation-legs": [
44 | {
45 | "img": "data/segmentation/legs/capstone-mini-2/test/images/IMG20230325193452_0_jpg.rf.ea87e7fe943f39216cacc84b32848e28.jpg",
46 | "msk": "data/segmentation/legs/capstone-mini-2/test/masks/IMG20230325193452_0_jpg.rf.ea87e7fe943f39216cacc84b32848e28.jpg",
47 | },
48 | {
49 | "img": "data/segmentation/legs/sunglasses-color-detection/test/images/2004970PJPXT_P00_JPG_jpg.rf.e1cd193efd84dac31c027e3d3649ec7a.jpg",
50 | "msk": "data/segmentation/legs/sunglasses-color-detection/test/masks/2004970PJPXT_P00_JPG_jpg.rf.e1cd193efd84dac31c027e3d3649ec7a.jpg",
51 | },
52 | {
53 | "img": "data/segmentation/legs/sunglasses-color-detection/test/images/aug_57_203675009QQT_P00_JPG_jpg.rf.54bdfef21f854be18d9dcf13fa5a7ae7.jpg",
54 | "msk": "data/segmentation/legs/sunglasses-color-detection/test/masks/aug_57_203675009QQT_P00_JPG_jpg.rf.54bdfef21f854be18d9dcf13fa5a7ae7.jpg",
55 | },
56 | ],
57 | "segmentation-lenses": [
58 | {
59 | "img": "data/segmentation/lenses/glasses-lens/test/images/face-35_jpg.rf.f0a9a1d3b4f9e756488294d2db1720d5.jpg",
60 | "msk": "data/segmentation/lenses/glasses-lens/test/masks/face-35_jpg.rf.f0a9a1d3b4f9e756488294d2db1720d5.jpg",
61 | },
62 | {
63 | "img": "data/segmentation/lenses/glass-color/test/images/2025260PJPVP_P00_JPG_jpg.rf.aaa9e83edbfd8a3c107650b62ddf52ed.jpg",
64 | "msk": "data/segmentation/lenses/glass-color/test/masks/2025260PJPVP_P00_JPG_jpg.rf.aaa9e83edbfd8a3c107650b62ddf52ed.jpg",
65 | },
66 | {
67 | "img": "data/segmentation/lenses/glasses-segmentation-cropped-faces/test/images/face-1306_scaled_cropping_jpg.rf.b5f5b788fb75aa05a15e69b938704c12.jpg",
68 | "msk": "data/segmentation/lenses/glasses-segmentation-cropped-faces/test/masks/face-1306_scaled_cropping_jpg.rf.b5f5b788fb75aa05a15e69b938704c12.jpg",
69 | },
70 | ],
71 | "segmentation-shadows": [
72 | {
73 | "img": "data/segmentation/shadows/glasses-segmentation-synthetic/test/images/img-Glass021-435-16_sadness-2-versveldpas-105.jpg",
74 | "msk": "data/segmentation/shadows/glasses-segmentation-synthetic/test/masks/img-Glass021-435-16_sadness-2-versveldpas-105.jpg",
75 | },
76 | {
77 | "img": "data/segmentation/shadows/glasses-segmentation-synthetic/test/images/img-Glass001-379-4_brow_lower-1-simons_town_rocks-315.jpg",
78 | "msk": "data/segmentation/shadows/glasses-segmentation-synthetic/test/masks/img-Glass001-379-4_brow_lower-1-simons_town_rocks-315.jpg",
79 | },
80 | {
81 | "img": "data/segmentation/shadows/glasses-segmentation-synthetic/test/images/img-Glass018-422-7_jaw_left-1-urban_street_03-039.jpg",
82 | "msk": "data/segmentation/shadows/glasses-segmentation-synthetic/test/masks/img-Glass018-422-7_jaw_left-1-urban_street_03-039.jpg",
83 | },
84 | ],
85 | "segmentation-smart": [
86 | {
87 | "img": "data/segmentation/smart/face-synthetics-glasses/test/images/000410.jpg",
88 | "msk": "data/segmentation/smart/face-synthetics-glasses/test/masks/000410.jpg",
89 | },
90 | {
91 | "img": "data/segmentation/smart/face-synthetics-glasses/test/images/001229.jpg",
92 | "msk": "data/segmentation/smart/face-synthetics-glasses/test/masks/001229.jpg",
93 | },
94 | {
95 | "img": "data/segmentation/smart/face-synthetics-glasses/test/images/002315.jpg",
96 | "msk": "data/segmentation/smart/face-synthetics-glasses/test/masks/002315.jpg",
97 | },
98 | ],
99 | "detection-eyes": [
100 | {
101 | "img": "data/detection/eyes/ex07/test/images/face-16_jpg.rf.9554ce9ff29cca368918cb849806902f.jpg",
102 | "ann": "data/detection/eyes/ex07/test/annotations/face-16_jpg.rf.9554ce9ff29cca368918cb849806902f.txt",
103 | },
104 | {
105 | "img": "data/detection/eyes/glasses-detection/test/images/41d3e9440d1678109133_jpeg.rf.564dc61348a3986faf801d352a7ebe41.jpg",
106 | "ann": "data/detection/eyes/glasses-detection/test/annotations/41d3e9440d1678109133_jpeg.rf.564dc61348a3986faf801d352a7ebe41.txt",
107 | },
108 | {
109 | "img": "data/detection/eyes/glasses-detection/test/images/woman-face-eyes-feeling_jpg.rf.8c1547d76fe23936984db74a5507f188.jpg",
110 | "ann": "data/detection/eyes/glasses-detection/test/annotations/woman-face-eyes-feeling_jpg.rf.8c1547d76fe23936984db74a5507f188.txt",
111 | },
112 | ],
113 | "detection-solo": [
114 | {
115 | "img": "data/detection/solo/onlyglasses/test/images/8--52-_jpg.rf.cfc2d6dec8f46cd5b91c9c112fbb8bf3.jpg",
116 | "ann": "data/detection/solo/onlyglasses/test/annotations/8--52-_jpg.rf.cfc2d6dec8f46cd5b91c9c112fbb8bf3.txt",
117 | },
118 | {
119 | "img": "data/detection/solo/kacamata-membaca/test/images/85_jpg.rf.4c164fa95a20bebc7c888d34ed160e16.jpg",
120 | "ann": "data/detection/solo/kacamata-membaca/test/annotations/85_jpg.rf.4c164fa95a20bebc7c888d34ed160e16.txt",
121 | },
122 | {
123 | "img": "data/detection/worn/ai-pass/test/images/28f3d11c3465ce2e74d8a4d65861de51_jpg.rf.e119438402f655cec8032304c7603606.jpg",
124 | "ann": "data/detection/worn/ai-pass/test/annotations/28f3d11c3465ce2e74d8a4d65861de51_jpg.rf.e119438402f655cec8032304c7603606.txt",
125 | },
126 | ],
127 | "detection-worn": [
128 | {
129 | "img": "data/detection/worn/glasses-detection/test/images/425px-robert_downey_jr_avp_iron_man_3_paris_jpg.rf.998f29000b52081eb6ea4d25df75512c.jpg",
130 | "ann": "data/detection/worn/glasses-detection/test/annotations/425px-robert_downey_jr_avp_iron_man_3_paris_jpg.rf.998f29000b52081eb6ea4d25df75512c.txt",
131 | },
132 | {
133 | "img": "data/detection/worn/ai-pass/test/images/glasses120_png_jpg.rf.847610bd1230c85c8f81cbced18c38ea.jpg",
134 | "ann": "data/detection/worn/ai-pass/test/annotations/glasses120_png_jpg.rf.847610bd1230c85c8f81cbced18c38ea.txt",
135 | },
136 | {
137 | "img": "data/detection/worn/ai-pass/test/images/women-with-glass_81_jpg.rf.bef8096d89dd0805ff3bbf0f8d08b0c8.jpg",
138 | "ann": "data/detection/worn/ai-pass/test/annotations/women-with-glass_81_jpg.rf.bef8096d89dd0805ff3bbf0f8d08b0c8.txt",
139 | },
140 | ],
141 | }
142 |
143 |
144 | def generate_examples(data_dir: str = "..", out_dir: str = "_static/img"):
145 | for task, samples in SAMPLES.items():
146 | if task == "classification":
147 | for label, path in samples.items():
148 | # Load the image and save it
149 | img = Image.open(f"{data_dir}/{path}")
150 | img.save(f"{out_dir}/{task}-{label}.jpg")
151 | elif task.startswith("detection"):
152 | for i, sample in enumerate(samples):
153 | # Load the image
154 | img = Image.open(f"{data_dir}/{sample["img"]}")
155 |
156 | with open(f"{data_dir}/{sample["ann"]}", "r") as f:
157 | # Load annotations (single bbox per image)
158 | ann = [list(map(float, f.read().split()))]
159 |
160 | # Draw the bounding box and save the image
161 | out = GlassesDetector.draw_boxes(img, ann, colors="red", width=3)
162 | out.save(f"{out_dir}/{task}-{i}.jpg")
163 | elif task.startswith("segmentation"):
164 | for i, sample in enumerate(samples):
165 | # Load image and mask and overlay them
166 | img = Image.open(f"{data_dir}/{sample["img"]}")
167 | msk = Image.open(f"{data_dir}/{sample["msk"]}")
168 | out = GlassesSegmenter.draw_masks(img, msk, colors="red", alpha=0.5)
169 | out.save(f"{out_dir}/{task}-{i}.jpg")
170 |
171 |
172 | if __name__ == "__main__":
173 | # cd docs/
174 | # python helpers/generate_examples.py
175 | generate_examples()
176 |
--------------------------------------------------------------------------------
/docs/helpers/build_finished.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import requests
5 | import yaml
6 | from bs4 import BeautifulSoup, NavigableString, Tag
7 | from sphinx.application import Sphinx
8 |
9 |
10 | class BuildFinished:
11 | def __init__(self, static_path: str = "_static", conf_path: str = "conf.yaml"):
12 | # Init inv directory and create it if not exists
13 | self.inv_dir = os.path.join(static_path, "inv")
14 | os.makedirs(self.inv_dir, exist_ok=True)
15 |
16 | with open(conf_path) as f:
17 | # Load conf.yaml and get build_finished section
18 | self.conf = yaml.safe_load(f)["build-finished"]
19 |
20 | def align_rowspans(self, soup: BeautifulSoup):
21 | if tds := soup.find_all("td", rowspan=True):
22 | for td in tds:
23 | td["valign"] = "middle"
24 |
25 | def add_collapse_ids(self, soup: BeautifulSoup):
26 | if details := soup.find_all("details"):
27 | for detail in details:
28 | if detail.has_attr("name"):
29 | detail["id"] = "-".join(detail["name"].split())
30 |
31 | def keep_only_data(self, soup: BeautifulSoup):
32 | def has_children(tag: Tag, txt1: str, txt2: str):
33 | if tag.name != "dt":
34 | return False
35 |
36 | # Get the prename and name elements of the signature
37 | ch1 = tag.select_one("span.sig-prename.descclassname span.pre")
38 | ch2 = tag.select_one("span.sig-name.descname span.pre")
39 |
40 | return ch1 and ch2 and ch1.string == txt1 and ch2.string == txt2
41 |
42 | for alias, module in self.conf["TYPE_ALIASES"].items():
43 | if dt := soup.find("dt", id=f"{module}{alias}"):
44 | # Copy class directive's a
45 | a = dt.find("a").__copy__()
46 | dt.parent.decompose()
47 | else:
48 | continue
49 |
50 | if dt := soup.find(lambda tag: has_children(tag, module, alias)):
51 | # ID and a for data directive
52 | dt["id"] = f"{module}{alias}"
53 | dt.append(a)
54 | dt.find("span", class_="sig-prename descclassname").decompose()
55 |
56 | def process_in_page_toc(self, soup: BeautifulSoup):
57 | for li in soup.find_all("li", class_="toc-h3 nav-item toc-entry"):
58 | if span := li.find("span"):
59 | # Modify the toc-nav span element here
60 | span.string = span.string.split(".")[-1]
61 |
62 | def break_long_signatures(self, soup: BeautifulSoup):
63 | def break_long_params(id, sig_param):
64 | if (params := self.conf["LONG_PARAMETER_IDS"].get(id)) is None:
65 | return
66 |
67 | is_opened = False
68 |
69 | for span in sig_param.find_all("span", class_="pre"):
70 | if span.string == "[":
71 | is_opened = True
72 | elif span.string == "]":
73 | is_opened = False
74 |
75 | if (
76 | span.string == "|"
77 | and not is_opened
78 | and span.parent.parent.parent.find("span", class_="pre").string
79 | in params
80 | ):
81 | # Add long-sig to spans with |
82 | span["class"].append("long-sig")
83 |
84 | for id in self.conf["LONG_SIGNATURE_IDS"]:
85 | if not (dt := soup.find("dt", id=id)):
86 | continue
87 |
88 | for sig_param in dt.find_all("em", class_="sig-param"):
89 | # Add long-sig to the identified sig-param ems
90 | sig_param["class"].append("long-sig")
91 | break_long_params(id, sig_param)
92 |
93 | for dt_sibling in dt.find_next_siblings("dt"):
94 | for sig_param in dt_sibling.find_all("em", class_="sig-param"):
95 | # Add long-sig for overrides, i.e., sibling dts, too
96 | sig_param["class"].append("long-sig")
97 | break_long_params(id, sig_param)
98 |
99 | def customize_code_block_colors_python(self, soup: BeautifulSoup):
100 | for span in soup.select("div.highlight-python div.highlight pre span"):
101 | for name, keyword in self.conf["CUSTOM_SYNTAX_COLORS_PYTHON"].items():
102 | if span.get_text().strip() in keyword:
103 | # Add class of the syntax keyword
104 | span["class"].append(name)
105 |
106 | def customize_code_block_colors_bash(self, soup: BeautifulSoup):
107 | # Select content groups
108 | pres = soup.select("div.highlight-bash div.highlight pre")
109 | pres.extend(soup.select("code.highlight-bash"))
110 |
111 | # Define the constants
112 | KEEP_CLS = {"c1", "w"}
113 | OP_CLS = "custom-highlight-op"
114 | START_CLS = "custom-highlight-start"
115 | DEFAULT_CLS = "custom-highlight-default"
116 |
117 | # Get the starts and flatten the keywords
118 | starts = self.conf["CUSTOM_SYNTAX_COLORS_BASH"][START_CLS]
119 | ops = self.conf["CUSTOM_SYNTAX_COLORS_BASH"][OP_CLS] + ["\n"]
120 | flat_kwds = [
121 | (cls, kwd)
122 | for cls, kwds in self.conf["CUSTOM_SYNTAX_COLORS_BASH"].items()
123 | for kwd in kwds
124 | if cls not in [START_CLS, DEFAULT_CLS, OP_CLS]
125 | ]
126 |
127 | for pre in pres:
128 | for content in pre.contents:
129 | if (
130 | isinstance(content, Tag)
131 | and "class" in content.attrs.keys()
132 | and not any(cls in content["class"] for cls in KEEP_CLS)
133 | ):
134 | # Only keep the text part, i.e., remove the wrapping tags
135 | content.replace_with(NavigableString(content.get_text()))
136 | elif isinstance(content, NavigableString) and "\n" in content:
137 | # Init the splits
138 | sub_contents = []
139 |
140 | for sub_content in content.split("\n"):
141 | if sub_content != "":
142 | # No need to add borderline empty strings
143 | sub_contents.append(NavigableString(sub_content))
144 |
145 | # Also add the newline character as NS
146 | sub_contents.append(NavigableString("\n"))
147 |
148 | # Replace the original content with splits
149 | content.replace_with(*sub_contents[:-1])
150 |
151 | for pre in pres:
152 | # Init the starts
153 | start_idx = 0
154 |
155 | for content in pre.contents:
156 | if not isinstance(content, NavigableString):
157 | # Skip non-navigable strings
158 | continue
159 |
160 | if content in ops:
161 | # Reset start
162 | start_idx = 0
163 |
164 | # If keyword is an operator, wrap with OP_CLS
165 | new_content = f'<span class="{OP_CLS}">{content}</span>'
166 | content.replace_with(BeautifulSoup(new_content, "html.parser"))
167 | continue
168 |
169 | # Get the start keyword if it exists
170 | start = [
171 | sub_start
172 | for start in starts
173 | for sub_start_idx, sub_start in enumerate(start.split())
174 | if start_idx == sub_start_idx and content == sub_start
175 | ]
176 |
177 | # Increment start idx
178 | start_idx += 1
179 |
180 | if len(start) > 0:
181 | # If keyword is a start
182 | new_content = f'<span class="{START_CLS}">{start[0]}</span>'
183 | content.replace_with(BeautifulSoup(new_content, "html.parser"))
184 | continue
185 |
186 | # Check if any of the keywords from config matches
187 | is_kwd = [content.startswith(kwd) for _, kwd in flat_kwds]
188 |
189 | if any(is_kwd):
190 | # Add the corresponding keyword class
191 | cls, _ = flat_kwds[is_kwd.index(True)]
192 | new_content = f'<span class="{cls}">{content}</span>'
193 | else:
194 | # Add the default class if no keyword is found
195 | new_content = f'<span class="{DEFAULT_CLS}">{content}</span>'
196 |
197 | # Replace the original content with the new one
198 | content.replace_with(BeautifulSoup(new_content, "html.parser"))
199 |
200 | # Prettify soup
201 | soup.prettify()
202 |
203 | def configure_section_icons(self, soup: BeautifulSoup):
204 | # Select the navbar section links
205 | as_ = soup.select("nav.navbar-nav li.nav-item a")
206 |
207 | for a in as_:
208 | for name, icon_cls in self.conf["SECTION_ICONS"].items():
209 | if a.get_text().strip() != name:
210 | continue
211 |
212 | # Add an icon to the navbar section item
213 | a.parent["style"] = "margin-right: 0.5em;"
214 | icon = f'<i class="{icon_cls}"></i>'
215 | a_new = f"{icon}{name}"
216 | a.string.replace_with(BeautifulSoup(a_new, "html.parser"))
217 |
218 | as_ = soup.select("div#index-toctree li.toctree-l1 a")
219 |
220 | for a in as_:
221 | for content in a.contents:
222 | if isinstance(content, Tag) and content.name == "span":
223 | content["style"] = "margin-right: 0.5em;"
224 | continue
225 | elif (
226 | not isinstance(content, NavigableString)
227 | or content.strip() not in self.conf["SECTION_ICONS"].keys()
228 | ):
229 | continue
230 |
231 | # Modify the underline of the toctree section item
232 | content.parent["style"] = "text-decoration: none;"
233 | content.replace_with(
234 | BeautifulSoup(f'<span style="text-decoration: underline;">{content.get_text().strip()}</span>', "html.parser")
235 | )
236 |
237 | soup.prettify()
238 |
239 | def edit_html(self, app: Sphinx):
240 | if app.builder.format != "html":
241 | return
242 |
243 | for pagename in app.env.found_docs:
244 | if not isinstance(pagename, str):
245 | continue
246 |
247 | with (Path(app.outdir) / f"{pagename}.html").open("r") as f:
248 | # Parse HTML using BeautifulSoup html parser
249 | soup = BeautifulSoup(f.read(), "html.parser")
250 |
251 | self.align_rowspans(soup)
252 | self.keep_only_data(soup)
253 | self.add_collapse_ids(soup)
254 | self.process_in_page_toc(soup)
255 | self.break_long_signatures(soup)
256 | self.customize_code_block_colors_python(soup)
257 | self.customize_code_block_colors_bash(soup)
258 | self.configure_section_icons(soup)
259 |
260 | with (Path(app.outdir) / f"{pagename}.html").open("w") as f:
261 | # Write back HTML
262 | f.write(str(soup))
263 |
264 | def __call__(self, app, exception):
265 | self.edit_html(app)
266 |
--------------------------------------------------------------------------------
/docs/docs/features.rst:
--------------------------------------------------------------------------------
1 | :fas:`gears` Features
2 | =====================
3 |
4 | The following *tasks* are supported:
5 |
6 | * **Classification** - binary classification of the presence of glasses and their types.
7 | * **Detection** - binary detection of worn/standalone glasses and eye area.
8 | * **Segmentation** - binary segmentation of glasses and their parts.
9 |
10 | Each :attr:`task` has multiple :attr:`kinds` (task categories) and model :attr:`sizes` (architectures with pre-trained weights).
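
In the Python API, the :attr:`kind` and :attr:`size` are chosen when constructing a model; a minimal sketch (see :doc:`examples` for full usage):

.. code-block:: python

    from glasses_detector import GlassesClassifier, GlassesDetector, GlassesSegmenter

    # Each model takes a task-specific kind and a model size
    classifier = GlassesClassifier(kind="sunglasses", size="small")
    detector = GlassesDetector(kind="worn", size="medium")
    segmenter = GlassesSegmenter(kind="frames", size="large")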
11 |
12 | Classification
13 | --------------
14 |
15 | .. table:: Classification Kinds
16 | :widths: 15 31 18 18 18
17 | :name: classification-kinds
18 |
19 | +----------------+-------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
20 | | **Kind** | **Description** | **Examples** |
21 | +================+=====================================+=============================================================+=============================================================+=============================================================+
22 | | ``anyglasses`` | Identifies any kind of glasses, | .. image:: ../_static/img/classification-eyeglasses-pos.jpg | .. image:: ../_static/img/classification-sunglasses-pos.jpg | .. image:: ../_static/img/classification-no-glasses-neg.jpg |
23 | | | goggles, or spectacles. +-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+
24 | | | | .. centered:: Positive | .. centered:: Positive | .. centered:: Negative |
25 | +----------------+-------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+
26 | | ``eyeglasses`` | Identifies only transparent glasses | .. image:: ../_static/img/classification-eyeglasses-pos.jpg | .. image:: ../_static/img/classification-sunglasses-neg.jpg | .. image:: ../_static/img/classification-no-glasses-neg.jpg |
27 | | | (here referred as *eyeglasses*) +-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+
28 | | | | .. centered:: Positive | .. centered:: Negative | .. centered:: Negative |
29 | +----------------+-------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+
30 | | ``sunglasses`` | Identifies only opaque and | .. image:: ../_static/img/classification-eyeglasses-neg.jpg | .. image:: ../_static/img/classification-sunglasses-pos.jpg | .. image:: ../_static/img/classification-no-glasses-neg.jpg |
31 | | | semi-transparent glasses (here +-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+
32 | | | referred as *sunglasses*) | .. centered:: Negative | .. centered:: Positive | .. centered:: Negative |
33 | +----------------+-------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+
34 | | ``shadows`` | Identifies cast shadows (only | .. image:: ../_static/img/classification-shadows-pos.jpg | .. image:: ../_static/img/classification-shadows-neg.jpg | .. image:: ../_static/img/classification-no-glasses-neg.jpg |
35 | | | shadows of (any) glasses frames) +-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+
36 | | | | .. centered:: Positive | .. centered:: Negative | .. centered:: Negative |
37 | +----------------+-------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+
38 |
39 | .. admonition:: Check classifier performances
40 | :class: tip
41 |
42 | * `Performance Information of the Pre-trained Classifiers `_: performance of each :attr:`kind`.
43 | * `Size Information of the Pre-trained Classifiers `_: efficiency of each :attr:`size`.
44 |
45 | Detection
46 | ---------
47 |
48 | .. table:: Detection Kinds
49 | :widths: 15 31 18 18 18
50 | :name: detection-kinds
51 |
52 | +----------+--------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------+
53 | | **Kind** | **Description** | **Examples** |
54 | +==========+======================================================================================+================================================+================================================+================================================+
55 | | ``eyes`` | Detects only the eye region, no glasses. | .. image:: ../_static/img/detection-eyes-0.jpg | .. image:: ../_static/img/detection-eyes-1.jpg | .. image:: ../_static/img/detection-eyes-2.jpg |
56 | +----------+--------------------------------------------------------------------------------------+------------------------------------------------+------------------------------------------------+------------------------------------------------+
57 | | ``solo`` | Detects any glasses in the wild, i.e., standalone glasses that are placed somewhere. | .. image:: ../_static/img/detection-solo-0.jpg | .. image:: ../_static/img/detection-solo-1.jpg | .. image:: ../_static/img/detection-solo-2.jpg |
58 | +----------+--------------------------------------------------------------------------------------+------------------------------------------------+------------------------------------------------+------------------------------------------------+
59 | | ``worn`` | Detects any glasses worn by people but can also detect non-worn glasses. | .. image:: ../_static/img/detection-worn-0.jpg | .. image:: ../_static/img/detection-worn-1.jpg | .. image:: ../_static/img/detection-worn-2.jpg |
60 | +----------+--------------------------------------------------------------------------------------+------------------------------------------------+------------------------------------------------+------------------------------------------------+
61 |
62 | .. admonition:: Check detector performances
63 | :class: tip
64 |
65 | * `Performance Information of the Pre-trained Detectors `_: performance of each :attr:`kind`.
66 | * `Size Information of the Pre-trained Detectors `_: efficiency of each :attr:`size`.
67 |
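The same pattern applies here, for example (again assuming ``GlassesDetector``
and ``process_file`` as above):

.. code-block:: python

    from glasses_detector import GlassesDetector

    # Sketch: predict bounding boxes of glasses worn by people
    detector = GlassesDetector(kind="worn")
    detector.process_file("path/to/image.jpg")
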
68 | Segmentation
69 | ------------
70 |
71 | .. table:: Segmentation Kinds
72 | :widths: 15 31 18 18 18
73 | :name: segmentation-kinds
74 |
75 | +-------------+----------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------+
76 | | **Kind** | **Description** | **Examples** |
77 | +=============+====================================================================================================+======================================================+======================================================+======================================================+
78 | | ``frames`` | Segments frames (including legs) of any glasses | .. image:: ../_static/img/segmentation-frames-0.jpg | .. image:: ../_static/img/segmentation-frames-1.jpg | .. image:: ../_static/img/segmentation-frames-2.jpg |
79 | +-------------+----------------------------------------------------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+
80 | | ``full`` | Segments full glasses, i.e., lenses and the whole frame | .. image:: ../_static/img/segmentation-full-0.jpg | .. image:: ../_static/img/segmentation-full-1.jpg | .. image:: ../_static/img/segmentation-full-2.jpg |
81 | +-------------+----------------------------------------------------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+
82 | | ``legs`` | Segments only frame legs of standalone glasses | .. image:: ../_static/img/segmentation-legs-0.jpg | .. image:: ../_static/img/segmentation-legs-1.jpg | .. image:: ../_static/img/segmentation-legs-2.jpg |
83 | +-------------+----------------------------------------------------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+
84 | | ``lenses`` | Segments lenses of any glasses (both transparent and opaque). | .. image:: ../_static/img/segmentation-lenses-0.jpg | .. image:: ../_static/img/segmentation-lenses-1.jpg | .. image:: ../_static/img/segmentation-lenses-2.jpg |
85 | +-------------+----------------------------------------------------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+
86 | | ``shadows`` | Segments cast shadows on the skin by the glasses frames only (does not consider opaque lenses). | .. image:: ../_static/img/segmentation-shadows-0.jpg | .. image:: ../_static/img/segmentation-shadows-1.jpg | .. image:: ../_static/img/segmentation-shadows-2.jpg |
87 | +-------------+----------------------------------------------------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+
88 | | ``smart`` | Segments visible glasses parts: like ``full`` but does not segment lenses if they are transparent. | .. image:: ../_static/img/segmentation-smart-0.jpg | .. image:: ../_static/img/segmentation-smart-1.jpg | .. image:: ../_static/img/segmentation-smart-2.jpg |
89 | +-------------+----------------------------------------------------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+
90 |
91 | .. admonition:: Check segmenter performances
92 | :class: tip
93 |
94 | * `Performance Information of the Pre-trained Segmenters `_: performance of each :attr:`kind`.
95 | * `Size Information of the Pre-trained Segmenters `_: efficiency of each :attr:`size`.
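
Analogously, an illustrative sketch for generating a mask of the frames
(assuming ``GlassesSegmenter`` and ``process_file`` as above):

.. code-block:: python

    from glasses_detector import GlassesSegmenter

    # Sketch: predict a binary mask of the glasses frames (including legs)
    segmenter = GlassesSegmenter(kind="frames")
    segmenter.process_file("path/to/image.jpg")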
--------------------------------------------------------------------------------
/src/glasses_detector/_data/augmenter_mixin.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import albumentations as A
4 | import numpy as np
5 | import skimage.transform as st
6 | import torch
7 | from albumentations.pytorch import ToTensorV2
8 | from PIL import Image
9 |
10 |
11 | class ToTensor(ToTensorV2):
12 | def apply_to_mask(self, mask, **params):
13 | return torch.from_numpy((mask > 127).astype(np.float32))
14 |
15 |
16 | class AugmenterMixin:
17 | @staticmethod
18 | def default_augmentations() -> list[A.BasicTransform]:
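        """Returns the default list of training augmentations: geometric
        flips/distortions, color jitter, blur/noise, random crops, padding,
        coarse dropout, normalization and conversion to a tensor."""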
19 | return [
20 | A.OneOf(
21 | [
22 | A.VerticalFlip(),
23 | A.HorizontalFlip(),
24 | A.RandomRotate90(),
25 | A.Transpose(),
26 | ],
27 | p=0.75,
28 | ),
29 | A.OneOf(
30 | [
31 | A.PiecewiseAffine(),
32 | A.ShiftScaleRotate(),
33 | A.ElasticTransform(),
34 | A.OpticalDistortion(distort_limit=0.1, shift_limit=0.1),
35 | A.GridDistortion(distort_limit=0.5),
36 | ]
37 | ),
38 | A.OneOf(
39 | [
40 | A.RandomBrightnessContrast(),
41 | A.ColorJitter(),
42 | A.HueSaturationValue(),
43 | A.RandomGamma(),
44 | A.CLAHE(),
45 | A.RGBShift(),
46 | ]
47 | ),
48 | A.OneOf(
49 | [
50 | A.Blur(),
51 | A.GaussianBlur(),
52 | A.MedianBlur(),
53 | A.GaussNoise(),
54 | ]
55 | ),
56 | A.OneOf(
57 | [
58 | A.RandomResizedCrop(256, 256, p=0.4),
59 | A.RandomSizedCrop((10, 131), 256, 256, p=0.4),
60 | A.RandomCrop(height=200, width=200, p=0.2),
61 | ],
62 | p=0.25,
63 | ),
64 | A.PadIfNeeded(min_height=256, min_width=256, always_apply=True),
65 | A.CoarseDropout(max_holes=10, max_height=8, max_width=8, p=0.2),
66 | A.Normalize(),
67 | ToTensor(),
68 | ]
69 |
70 | @staticmethod
71 | def minimal_augmentations() -> list[A.BasicTransform]:
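        """Returns a lighter set of augmentations where each group is
        applied with low probability and small magnitudes."""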
72 | return [
73 | A.OneOf(
74 | [
75 | A.VerticalFlip(),
76 | A.HorizontalFlip(),
77 | A.RandomRotate90(),
78 | A.Transpose(),
79 | ],
80 | p=0.1,
81 | ),
82 | A.OneOf(
83 | [
84 | A.PiecewiseAffine((0.02, 0.03)),
85 | A.ShiftScaleRotate((-0.02, 0.02), 0.05),
86 | A.ElasticTransform(sigma=20, alpha_affine=20),
87 | A.OpticalDistortion(distort_limit=0.02, shift_limit=0.02),
88 | A.GridDistortion(num_steps=3, distort_limit=0.1),
89 | ],
90 | p=0.1,
91 | ),
92 | A.OneOf(
93 | [
94 | A.RandomBrightnessContrast(0.05, 0.05),
95 | A.ColorJitter(0.05, 0.05, 0.05),
96 | A.HueSaturationValue(5, 10, 5),
97 | A.RandomGamma((80, 100)),
98 | A.CLAHE(2, (3, 3)),
99 | A.RGBShift(5, 5, 5),
100 | ],
101 | p=0.1,
102 | ),
103 | A.OneOf(
104 | [
105 | A.Blur((3, 3)),
106 | A.GaussianBlur((3, 3)),
107 | A.MedianBlur((3, 3)),
108 | A.GaussNoise((5, 10)),
109 | ],
110 | p=0.1,
111 | ),
112 | A.OneOf(
113 | [
114 | A.RandomResizedCrop(256, 256),
115 | A.RandomSizedCrop((10, 131), 256, 256),
116 | ],
117 | p=0.1,
118 | ),
119 | A.Normalize(),
120 | ToTensor(),
121 | ]
122 |
123 | @classmethod
124 | def create_transform(
125 | cls,
126 | is_train: bool = False,
127 | **kwargs,
128 | ) -> A.Compose:
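        """Composes the augmentation pipeline.

        Builds an ``A.Compose`` from :meth:`default_augmentations`; if
        ``is_train`` is :data:`False`, only the last two transforms
        (normalization and conversion to a tensor) are kept. Passing
        ``has_bbox=True`` or ``has_keys=True`` attaches bounding box or
        keypoint parameters (extra options go via ``bbox_kwargs`` /
        ``keys_kwargs``); any remaining keyword arguments are forwarded
        to ``A.Compose``. Coarse dropout is dropped when bounding boxes
        are used because it is not supported together with ``bbox_params``.
        """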
129 | # Get the list of default augmentations
130 | transform = cls.default_augmentations()
131 | kwargs = deepcopy(kwargs)
132 |
133 | if kwargs.pop("has_bbox", False):
134 | # Add bbox params
135 | kwargs.setdefault(
136 | "bbox_params",
137 | A.BboxParams(
138 | format="pascal_voc",
139 | label_fields=["bbcats"],
140 | min_visibility=0.1,
141 | **kwargs.pop("bbox_kwargs", {}),
142 | ),
143 | )
144 | if isinstance(transform[-3], A.CoarseDropout):
145 | # CoarseDropout not supported with bbox_params
146 | transform.pop(-3)
147 |
148 | if kwargs.pop("has_keys", False):
149 | # Add keypoint params
150 | kwargs.setdefault(
151 | "keypoint_params",
152 | A.KeypointParams(
153 | format="xy",
154 | label_fields=["kpcats"],
155 | remove_invisible=False,
156 | **kwargs.pop("keys_kwargs", {}),
157 | ),
158 | )
159 |
160 | if not is_train:
161 |             # Only keep the last two (normalization and to-tensor)
162 | transform = transform[-2:]
163 |
164 | return A.Compose(transform, **kwargs)
165 |
166 | @staticmethod
167 | def load_image(
168 | image: str | Image.Image | np.ndarray,
169 | resize: tuple[int, int] | None = None,
170 | is_mask: bool = False,
171 | return_orig_size: bool = False,
172 | ) -> np.ndarray:
173 | if isinstance(image, str):
174 | # Image is given as path
175 | image = Image.open(image)
176 |
177 | if isinstance(image, Image.Image):
178 |             # Image is a PIL image, so convert it to a numpy array
179 | image = np.array(image)
180 |
181 | if is_mask:
182 | # Convert image to black & white and ensure only 1 channel
183 | image = ((image > 127).any(2) if image.ndim > 2 else (image > 127)) * 255
184 | elif image.ndim == 2:
185 | # Image isn't a mask, convert it to RGB
186 | image = np.stack([image] * 3, axis=-1)
187 |
188 | if resize is not None:
189 |             # Resize image to new (w, h), preserving the 0-255 range
190 | image = st.resize(image, resize[::-1], preserve_range=True)
191 |
192 | # Convert image to UINT8 type
193 | image = image.astype(np.uint8)
194 |
195 | if return_orig_size:
196 |             # Also return the current image size as (w, h)
197 | return image, image.shape[:2][::-1]
198 |
199 | return image
200 |
201 | @staticmethod
202 | def load_boxes(
203 | boxes: str | list[list[int | float | str]],
204 | resize: tuple[int, int] | None = None,
205 | img_size: tuple[int, int] | None = None,
206 | ) -> list[list[float]]:
207 | if isinstance(boxes, str):
208 | with open(boxes, "r") as f:
209 |                 # Each line is a bounding box: "x_min y_min x_max y_max"
210 | boxes = [xyxy.strip().split() for xyxy in f.readlines()]
211 |
212 | # Convert each coordinate in each bbox to float
213 | boxes = [list(map(float, xyxy)) for xyxy in boxes]
214 |
215 | if img_size is None:
216 | if resize is not None:
217 | raise ValueError("img_size must be provided if resize is not None")
218 |
219 | return boxes
220 |
221 | for i, box in enumerate(boxes):
222 | if box[2] <= box[0]:
223 | # Ensure x_min < x_max <= img_size[0]
224 | boxes[i][0] = min(box[0], img_size[0] - 1)
225 | boxes[i][2] = boxes[i][0] + 1
226 |
227 | if box[3] <= box[1]:
228 | # Ensure y_min < y_max <= img_size[1]
229 | boxes[i][1] = min(box[1], img_size[1] - 1)
230 | boxes[i][3] = boxes[i][1] + 1
231 |
232 | if resize is not None:
233 | # Convert boxes to new (w, h)
234 | boxes = [
235 | [
236 | box[0] * resize[0] / img_size[0],
237 | box[1] * resize[1] / img_size[1],
238 | box[2] * resize[0] / img_size[0],
239 | box[3] * resize[1] / img_size[1],
240 | ]
241 | for box in boxes
242 | ]
243 |
244 | return boxes
245 |
246 | @staticmethod
247 | def load_keypoints(
248 | keypoints: str | list[list[int | float | str]],
249 | resize: tuple[int, int] | None = None,
250 | img_size: tuple[int, int] | None = None,
251 | ) -> list[list[float]]:
252 | if isinstance(keypoints, str):
253 | with open(keypoints, "r") as f:
254 |                 # Each line is a keypoint: "x y"
255 | keypoints = [xy.strip().split() for xy in f.readlines()]
256 |
257 | # Convert each coordinate in each keypoint to float
258 | keypoints = [list(map(float, xy)) for xy in keypoints]
259 |
260 | if img_size is None:
261 | if resize is not None:
262 | raise ValueError("img_size must be provided if resize is not None")
263 |
264 | return keypoints
265 |
266 | if resize is not None:
267 | # Convert keypoints to new (w, h)
268 | keypoints = [
269 | [
270 | keypoint[0] * resize[0] / img_size[0],
271 | keypoint[1] * resize[1] / img_size[1],
272 | ]
273 | for keypoint in keypoints
274 | ]
275 |
276 | return keypoints
277 |
278 | @classmethod
279 | def load_transform(
280 | cls,
281 | image: str | Image.Image | np.ndarray,
282 | masks: list[str | Image.Image | np.ndarray] = [],
283 | boxes: list[str | list[list[int | float | str]]] = [],
284 | bcats: list[str] = [],
285 | keys: list[str | list[list[int | float | str]]] = [],
286 | kcats: list[str] = [],
287 | resize: tuple[int, int] | None = None,
288 | transform: A.Compose | bool = False, # False means test/val
289 | ) -> torch.Tensor | tuple[torch.Tensor]:
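        """Loads the inputs and applies a single transform to all of them.

        The image (and any masks) is loaded via :meth:`load_image`, boxes
        via :meth:`load_boxes` and keypoints via :meth:`load_keypoints`,
        with box/keypoint categories repeated once per loaded annotation
        file. If ``transform`` is a :class:`bool`, a default train
        (:data:`True`) or eval (:data:`False`) pipeline is created with
        :meth:`create_transform`. The transformed image tensor is returned
        first, followed by whichever of masks, boxes, box categories,
        keypoints and keypoint categories were provided.
        """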
290 | # Load the image and resize if needed (also return original size)
291 | image, orig_size = cls.load_image(image, resize, return_orig_size=True)
292 | transform_kwargs = {"image": image}
293 |
294 | if isinstance(transform, bool):
295 | # Load transform (train or test is based on bool val)
296 | transform = cls.create_transform(is_train=transform)
297 |
298 | if masks != []:
299 | # Load masks and add to transform kwargs
300 | masks = [cls.load_image(m, resize, is_mask=True) for m in masks]
301 | transform_kwargs.update({"masks": masks})
302 |
303 | if boxes != []:
304 | # Initialize flat boxes and cats
305 | flat_boxes, flat_bcats = [], []
306 |
307 | for i, b in enumerate(boxes):
308 | # Load boxes and add to transform kwargs
309 | b = cls.load_boxes(b, resize, orig_size)
310 | flat_boxes.extend(b)
311 |
312 | if bcats != []:
313 | # For each box, add corresponding cat
314 | flat_bcats.extend([bcats[i]] * len(b))
315 |
316 | # Add boxes to transform kwargs
317 | transform_kwargs.update({"bboxes": flat_boxes})
318 |
319 | if bcats != []:
320 | # Also add cats to transform kwargs
321 | transform_kwargs.update({"bbcats": flat_bcats})
322 |
323 | if keys != []:
324 | # Initialize flat keypoints and cats
325 | flat_keys, flat_kcats = [], []
326 |
327 | for i, k in enumerate(keys):
328 | # Load keypoints and add to transform kwargs
329 | k = cls.load_keypoints(k, resize, orig_size)
330 | flat_keys.extend(k)
331 |
332 | if kcats != []:
333 | # For each keypoint, add corresponding cat
334 | flat_kcats.extend([kcats[i]] * len(k))
335 |
336 | # Add keypoints to transform kwargs, update cats
337 | transform_kwargs.update({"keypoints": flat_keys})
338 |
339 | if kcats != []:
340 | # Also add cats to transform kwargs
341 | transform_kwargs.update({"kpcats": flat_kcats})
342 |
343 | # Transform everything, generate return list
344 | transformed = transform(**transform_kwargs)
345 | return_list = [transformed["image"]]
346 |
347 | for key in ["masks", "bboxes", "bbcats", "keypoints", "kpcats"]:
348 | if key not in transformed:
349 | continue
350 |
351 | if key in {"bboxes", "keypoints", "bbcats", "kpcats"}:
352 | # Convert to torch tensor if key is category or bbox/keypoint
353 | dtype = torch.long if key in ["bbcats", "kpcats"] else torch.float32
354 | transformed[key] = torch.tensor(transformed[key], dtype=dtype)
355 |
356 | # Add to the return list
357 | return_list.append(transformed[key])
358 |
359 | if len(return_list) == 1:
360 | return return_list[0]
361 |
362 | return tuple(return_list)
363 |
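# Example (illustrative sketch): how a dataset class might use this mixin.
# The dataset name, annotation path and category ids below are made up:
#
#     class DemoDetectionDataset(AugmenterMixin, torch.utils.data.Dataset):
#         def __getitem__(self, index):
#             transform = self.create_transform(is_train=True, has_bbox=True)
#             image, bboxes, bbcats = self.load_transform(
#                 image="data/demo/0.jpg",   # sample image shipped with the repo
#                 boxes=["path/to/0.txt"],   # "x_min y_min x_max y_max" per line
#                 bcats=[1],                 # one category id per annotation file
#                 transform=transform,
#             )
#             return image, {"boxes": bboxes, "labels": bbcats}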
--------------------------------------------------------------------------------
/src/glasses_detector/components/pred_type.py:
--------------------------------------------------------------------------------
1 | """
2 | .. class:: Scalar
3 |
4 | .. data:: Scalar
5 | :noindex:
6 | :type: typing.TypeAliasType
7 | :value: StandardScalar | np.generic | np.ndarray | torch.Tensor
8 |
9 | Type alias for a scalar prediction. For more information, see
10 | :attr:`~PredType.SCALAR`.
11 |
12 | :class:`bool` | :class:`int` | :class:`float` | :class:`str`
13 | | :class:`~numpy.generic` | :class:`~numpy.ndarray`
14 | | :class:`~torch.Tensor`
15 |
16 | .. class:: Tensor
17 |
18 | .. data:: Tensor
19 | :noindex:
20 | :type: typing.TypeAliasType
21 | :value: Iterable[Scalar | Tensor] | PIL.Image.Image
22 |
23 | Type alias for a tensor prediction. For more information, see
24 | :attr:`~PredType.TENSOR`.
25 |
26 | :class:`~typing.Iterable` | :class:`~PIL.Image.Image`
27 |
28 | .. class:: Default
29 |
30 | .. data:: Default
31 | :noindex:
32 | :type: typing.TypeAliasType
33 | :value: Scalar | Tensor
34 |
35 | Type alias for a default prediction. For more information, see
36 | :attr:`~PredType.DEFAULT`.
37 |
38 | :class:`bool` | :class:`int` | :class:`float` | :class:`str`
39 | | :class:`~numpy.generic` | :class:`~numpy.ndarray`
40 | | :class:`~torch.Tensor` | :class:`~typing.Iterable`
41 |
42 | .. class:: StandardScalar
43 |
44 | .. data:: StandardScalar
45 | :noindex:
46 | :type: typing.TypeAliasType
47 | :value: bool | int | float | str
48 |
49 | Type alias for a standard scalar prediction. For more information,
50 | see :attr:`~PredType.STANDARD_SCALAR`.
51 |
52 | :class:`bool` | :class:`int` | :class:`float` | :class:`str`
53 |
54 | .. class:: StandardTensor
55 |
56 | .. data:: StandardTensor
57 | :noindex:
58 | :type: typing.TypeAliasType
59 | :value: list[StandardScalar | StandardTensor]
60 |
61 | Type alias for a standard tensor prediction. For more information,
62 | see :attr:`~PredType.STANDARD_TENSOR`.
63 |
64 | :class:`list`
65 |
66 | .. class:: StandardDefault
67 |
68 | .. data:: StandardDefault
69 | :noindex:
70 | :type: typing.TypeAliasType
71 | :value: StandardScalar | StandardTensor
72 |
73 | Type alias for a standard default prediction. For more information,
74 | see :attr:`~PredType.STANDARD_DEFAULT`.
75 |
76 | :class:`bool` | :class:`int` | :class:`float` | :class:`str`
77 | | :class:`list`
78 |
79 | .. class:: NonDefault[T]
80 |
81 | .. data:: NonDefault[T]
82 | :noindex:
83 | :type: typing.TypeAliasType
84 | :value: T
85 |
86 | Type alias for a non-default prediction. For more information, see
87 | :attr:`~PredType.NON_DEFAULT`.
88 |
89 | :data:`~typing.Any`
90 |
91 | .. class:: Either
92 |
93 | .. data:: Either
94 | :noindex:
95 | :type: typing.TypeAliasType
96 | :value: Default | NonDefault
97 |
98 | Type alias for either default or non-default prediction, i.e., any
99 | prediction.
100 |
101 | :data:`~typing.Any`
102 | """
103 | from enum import Enum, auto
104 | from typing import Any, Iterable, Self, TypeGuard
105 |
106 | import numpy as np
107 | import torch
108 | from PIL import Image
109 |
110 | type Scalar = bool | int | float | str | np.generic | np.ndarray | torch.Tensor
111 | type Tensor = Iterable[Scalar | Tensor] | Image.Image
112 | type Default = Scalar | Tensor
113 | type StandardScalar = bool | int | float | str
114 | type StandardTensor = list[StandardScalar | StandardTensor]
115 | type StandardDefault = StandardScalar | StandardTensor
116 | type NonDefault[T] = T
117 | type Either = Default | NonDefault
118 |
119 |
120 | class PredType(Enum):
121 | """Enum class for expected prediction types.
122 |
123 | This class specifies the expected prediction types mainly for
124 | classification, detection and segmentation models that work with
125 | image data. The expected types are called **Default** and there are
126 | two categories of them:
127 |
128 | 1. **Standard**: these are the basic *Python* types, i.e.,
129 | :class:`bool`, :class:`int`, :class:`float`, :class:`str`, and
130 | :class:`list`. Standard types are easy to work with, e.g., they
131 |        can be serialized to JSON or YAML.
132 | 2. **Non-standard**: these additionally contain types like
133 | :class:`numpy.ndarray`, :class:`torch.Tensor`, and
134 |        :class:`PIL.Image.Image`. They are convenient because they allow
135 |        more flexible model prediction outputs. In most cases, they can
136 | be converted to standard types via :meth:`standardize`.
137 |
138 | Note:
139 | The constants defined in this class are only enums, not
140 | actual classes or :class:`type` objects. For type hints,
141 | corresponding type aliases are defined in the same file as this
142 | class.
143 |
144 | Warning:
145 |         The enum types are not mutually exclusive; for example, a
146 |         :attr:`SCALAR` is also a :attr:`DEFAULT`.
147 |
148 | Examples
149 | --------
150 |
151 | Type aliases (:class:`~typing.TypeAliasType` objects defined
152 | using :class:`type` keyword) corresponding to the enums of
153 | this class are defined in the same file as :class:`PredType`. They
154 | can be used to specify the expected types when defining the methods:
155 |
156 | .. code-block:: python
157 |
158 | >>> from glasses_detector.components.pred_type import StandardScalar
159 | >>> def predict_class(
160 | ... self,
161 | ... image: Image.Image,
162 | ... output_format: str = "score",
163 | ... ) -> StandardScalar:
164 | ... ...
165 |
166 | :class:`PredType` static and class methods can be used to check
167 | the type of the prediction:
168 |
169 | .. code-block:: python
170 |
171 | >>> PredType.is_standard_scalar(1)
172 | True
173 | >>> PredType.is_standard_scalar(np.array([1])[0])
174 | False
175 | >>> PredType.is_default(Image.fromarray(np.zeros((1, 1))))
176 | True
177 |
178 | Finally, :meth:`standardize` can be used to convert the
179 | prediction to a standard type:
180 |
181 | .. code-block:: python
182 |
183 | >>> PredType.standardize(np.array([1, 2, 3]))
184 | [1, 2, 3]
185 | >>> PredType.standardize(Image.fromarray(np.zeros((1, 1))))
186 | [[0.0]]
187 | """
188 |
191 | SCALAR = auto()
192 | """
193 | PredType: Scalar type. A prediction is considered to be a scalar if
194 | it is one of the following types:
195 |
196 | * :class:`bool`
197 | * :class:`int`
198 | * :class:`float`
199 | * :class:`str`
200 | * :class:`numpy.generic`
201 | * :class:`numpy.ndarray` with ``ndim == 0``
202 | * :class:`torch.Tensor` with ``ndim == 0``
203 |
204 | :meta hide-value:
205 | """
206 |
207 | TENSOR = auto()
208 | """
209 | PredType: Tensor type. A prediction is considered to be a tensor if
210 | it is one of the following types:
211 |
212 | * :class:`PIL.Image.Image`
213 | * :class:`~typing.Iterable` of scalars or tensors of any iterable
214 | type, including :class:`list`, :class:`tuple`,
215 | :class:`~typing.Collection`, :class:`numpy.ndarray` and
216 | :class:`torch.Tensor` objects, and any other iterables.
217 |
218 | :meta hide-value:
219 | """
220 |
221 | DEFAULT = auto()
222 | """
223 | PredType: Default type. A prediction is considered to be a default
224 | type if it is one of the following types:
225 |
226 | * Any of the types defined in :attr:`SCALAR`.
227 | * Any of the types defined in :attr:`TENSOR`.
228 |
229 | :meta hide-value:
230 | """
231 |
232 | STANDARD_SCALAR = auto()
233 | """
234 | PredType: Standard scalar type. A prediction is considered to be a
235 | standard scalar if it is one of the following types:
236 |
237 | * :class:`bool`
238 | * :class:`int`
239 | * :class:`float`
240 | * :class:`str`
241 |
242 | :meta hide-value:
243 | """
244 |
245 | STANDARD_TENSOR = auto()
246 | """
247 | PredType: Standard tensor type. A prediction is considered to be a
248 | standard tensor if it is one of the following types:
249 |
250 | * :class:`list` of standard scalars or standard tensors. No other
251 | iterables than lists are allowed.
252 |
253 | :meta hide-value:
254 | """
255 |
256 | STANDARD_DEFAULT = auto()
257 | """
258 | PredType: Standard default type. A prediction is considered to be a
259 | standard default type if it is one of the following types:
260 |
261 | * Any of the types defined in :attr:`STANDARD_SCALAR`.
262 | * Any of the types defined in :attr:`STANDARD_TENSOR`.
263 |
264 | :meta hide-value:
265 | """
266 |
267 | NON_DEFAULT = auto()
268 | """
269 | PredType: Non-default type. A prediction is considered to be a
270 | non-default type if it is not a default type, i.e., it is not any of
271 | the types defined in :attr:`DEFAULT`.
272 |
273 | :meta hide-value:
274 | """
275 |
276 | @staticmethod
277 | def is_scalar(pred: Any) -> TypeGuard[Scalar]:
278 | """Checks if the prediction is a **scalar**.
279 |
280 | .. seealso::
281 |
282 | :attr:`SCALAR`
283 |
284 | Args:
285 | pred: The value to check.
286 |
287 | Returns:
288 | :data:`True` if the value is a **scalar**, :data:`False`
289 | otherwise.
290 | """
291 | return isinstance(pred, (bool, int, float, str, np.generic)) or (
292 | isinstance(pred, (torch.Tensor, np.ndarray)) and pred.ndim == 0
293 | )
294 |
295 | @staticmethod
296 | def is_standard_scalar(pred: Any) -> TypeGuard[StandardScalar]:
297 | """Checks if the prediction is a **standard scalar**.
298 |
299 | .. seealso::
300 |
301 | :attr:`STANDARD_SCALAR`
302 |
303 | Args:
304 | pred: The value to check.
305 |
306 | Returns:
307 | :data:`True` if the value is a **standard scalar**,
308 | :data:`False` otherwise.
309 | """
310 | return isinstance(pred, (bool, int, float, str))
311 |
312 | @classmethod
313 | def is_tensor(cls, pred: Any) -> TypeGuard[Tensor]:
314 | """Checks if the prediction is a **tensor**.
315 |
316 | .. seealso::
317 |
318 | :attr:`TENSOR`
319 |
320 | Args:
321 | pred: The value to check.
322 |
323 | Returns:
324 | :data:`True` if the value is a **tensor**,
325 | :data:`False` otherwise.
326 | """
327 | return isinstance(pred, Image.Image) or (
328 | isinstance(pred, Iterable)
329 | and not cls.is_scalar(pred)
330 | and all([cls.is_scalar(p) or cls.is_tensor(p) for p in pred])
331 | )
332 |
333 | @classmethod
334 | def is_standard_tensor(cls, pred: Any) -> TypeGuard[StandardTensor]:
335 | """Checks if the prediction is a **standard tensor**.
336 |
337 | .. seealso::
338 |
339 | :attr:`STANDARD_TENSOR`
340 |
341 | Args:
342 | pred: The value to check.
343 |
344 | Returns:
345 | :data:`True` if the value is a **standard tensor**,
346 | :data:`False` otherwise.
347 | """
348 | return isinstance(pred, list) and all(
349 | [cls.is_standard_scalar(p) or cls.is_standard_tensor(p) for p in pred]
350 | )
351 |
352 | @classmethod
353 | def is_default(cls, pred: Any) -> TypeGuard[Default]:
354 | """Checks if the prediction is a **default** type.
355 |
356 | .. seealso::
357 |
358 | :attr:`DEFAULT`
359 |
360 | Args:
361 | pred: The value to check.
362 |
363 | Returns:
364 | :data:`True` if the type of the value is **default**,
365 | :data:`False` otherwise.
366 | """
367 | return cls.is_scalar(pred) or cls.is_tensor(pred)
368 |
369 | @classmethod
370 | def is_standard_default(cls, pred: Any) -> TypeGuard[StandardDefault]:
371 | """Checks if the prediction is a **standard default** type.
372 |
373 | .. seealso::
374 |
375 | :attr:`STANDARD_DEFAULT`
376 |
377 | Args:
378 | pred: The value to check.
379 |
380 | Returns:
381 | :data:`True` if the type of the value is **standard
382 | default**, :data:`False` otherwise.
383 | """
384 | return cls.is_standard_scalar(pred) or cls.is_standard_tensor(pred)
385 |
386 | @classmethod
387 | def check(cls, pred: Any) -> Self:
388 | """Checks the type of the prediction and returns its enum.
389 |
390 | Checks the type of the prediction and returns the corresponding
391 | enum of the lowest type category. First, it checks if the
392 | prediction is a **standard scalar** or a regular **scalar** (in
393 | that order). If not, it checks if the prediction is a **standard
394 | tensor** or a regular **tensor** (in that order). Finally, if
395 | none of the previous checks are successful, it returns
396 | :attr:`NON_DEFAULT`.
397 |
398 | Note:
399 | All four types, i.e., :attr:`STANDARD_SCALAR`,
400 | :attr:`SCALAR`, :attr:`STANDARD_TENSOR`, and :attr:`TENSOR`
401 | are subclasses of :attr:`DEFAULT`.
402 |
403 | Args:
404 | pred: The value to check.
405 |
406 | Returns:
407 | The corresponding enum of the lowest type category or
408 | :attr:`NON_DEFAULT` if no **default** category is
409 | applicable.
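
        Example:
            .. code-block:: python

                >>> PredType.check("glasses") is PredType.STANDARD_SCALAR
                True
                >>> PredType.check(torch.tensor([1, 0])) is PredType.TENSOR
                True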
410 | """
411 | if cls.is_standard_scalar(pred):
412 | return cls.STANDARD_SCALAR
413 | elif cls.is_scalar(pred):
414 | return cls.SCALAR
415 | elif cls.is_standard_tensor(pred):
416 | return cls.STANDARD_TENSOR
417 | elif cls.is_tensor(pred):
418 | return cls.TENSOR
419 | else:
420 | return cls.NON_DEFAULT
421 |
422 | @classmethod
423 | def standardize(cls, pred: Default) -> StandardDefault:
424 | """Standardize the prediction.
425 |
426 | Standardize the prediction to a **standard default** type. If
427 | the prediction is already a **standard default** type, it is
428 | returned as-is. Otherwise, it is converted to a **standard
429 | default** type using the following rules:
430 |
431 | * :class:`bool`, :class:`int`, :class:`float`, and :class:`str`
432 | are returned as-is.
433 | * :class:`numpy.generic` and :class:`numpy.ndarray` with
434 | :attr:`~numpy.ndarray.ndim` = ``0`` are converted to
435 | **standard scalars**.
436 |         * :class:`torch.Tensor` with :attr:`~torch.Tensor.ndim` = ``0``
437 | is converted to **standard scalars**.
438 | * :class:`numpy.ndarray` and :class:`torch.Tensor` with
439 | :attr:`~numpy.ndarray.ndim` > ``0`` are converted to
440 | **standard tensors**.
441 | * :class:`PIL.Image.Image` is converted to **standard tensors**
442 | by converting it to a :class:`numpy.ndarray` and then
443 | applying the previous rule.
444 | * All other iterables are converted to **standard tensors** by
445 | applying the previous rule to each element.
446 |
447 | Args:
448 | pred: The **default** prediction to standardize.
449 |
450 | Raises:
451 | ValueError: If the prediction cannot be standardized. This
452 | can happen if a prediction is not **default** or if a
453 | class, such as :class:`torch.Tensor` or
454 | :class:`numpy.ndarray` returns a scalar that is not of
455 | type defined in :attr:`SCALAR`.
456 |
457 | Returns:
458 | The standardized prediction.
459 | """
460 | if isinstance(pred, (bool, int, float, str)):
461 | return pred
462 | elif isinstance(pred, np.generic):
463 | return cls.standardize(pred.item())
464 | elif isinstance(pred, (np.ndarray, torch.Tensor)) and pred.ndim == 0:
465 | return cls.standardize(pred.item())
466 | elif isinstance(pred, Image.Image):
467 | return np.asarray(pred).tolist()
468 | elif isinstance(pred, Iterable):
469 | return [cls.standardize(item) for item in pred]
470 | else:
471 | raise ValueError(f"Cannot standardize {type(pred)}")
472 |
--------------------------------------------------------------------------------