├── data
│   └── demo
│       ├── 0.jpg
│       ├── 1.jpg
│       ├── 2.jpg
│       ├── 3.jpg
│       ├── 4.jpg
│       └── 5.jpg
├── docs
│   ├── _static
│   │   ├── inv
│   │   │   └── tqdm.inv
│   │   ├── img
│   │   │   ├── banner.jpg
│   │   │   ├── logo-dark.png
│   │   │   ├── logo-light.png
│   │   │   ├── detection-eyes-0.jpg
│   │   │   ├── detection-eyes-1.jpg
│   │   │   ├── detection-eyes-2.jpg
│   │   │   ├── detection-solo-0.jpg
│   │   │   ├── detection-solo-1.jpg
│   │   │   ├── detection-solo-2.jpg
│   │   │   ├── detection-worn-0.jpg
│   │   │   ├── detection-worn-1.jpg
│   │   │   ├── detection-worn-2.jpg
│   │   │   ├── segmentation-full-0.jpg
│   │   │   ├── segmentation-full-1.jpg
│   │   │   ├── segmentation-full-2.jpg
│   │   │   ├── segmentation-legs-0.jpg
│   │   │   ├── segmentation-legs-1.jpg
│   │   │   ├── segmentation-legs-2.jpg
│   │   │   ├── segmentation-smart-0.jpg
│   │   │   ├── segmentation-smart-1.jpg
│   │   │   ├── segmentation-smart-2.jpg
│   │   │   ├── segmentation-frames-0.jpg
│   │   │   ├── segmentation-frames-1.jpg
│   │   │   ├── segmentation-frames-2.jpg
│   │   │   ├── segmentation-lenses-0.jpg
│   │   │   ├── segmentation-lenses-1.jpg
│   │   │   ├── segmentation-lenses-2.jpg
│   │   │   ├── segmentation-shadows-0.jpg
│   │   │   ├── segmentation-shadows-1.jpg
│   │   │   ├── segmentation-shadows-2.jpg
│   │   │   ├── classification-shadows-neg.jpg
│   │   │   ├── classification-shadows-pos.jpg
│   │   │   ├── classification-eyeglasses-neg.jpg
│   │   │   ├── classification-eyeglasses-pos.jpg
│   │   │   ├── classification-no-glasses-neg.jpg
│   │   │   ├── classification-sunglasses-neg.jpg
│   │   │   └── classification-sunglasses-pos.jpg
│   │   ├── css
│   │   │   ├── signatures.css
│   │   │   ├── custom.css
│   │   │   └── highlights.css
│   │   ├── js
│   │   │   ├── colab-icon.js
│   │   │   ├── pypi-icon.js
│   │   │   └── zenodo-icon.js
│   │   ├── bib
│   │   │   └── references.bib
│   │   └── svg
│   │       └── colab.svg
│   ├── helpers
│   │   ├── __init__.py
│   │   ├── custom_invs.py
│   │   ├── generate_examples.py
│   │   └── build_finished.py
│   ├── requirements.txt
│   ├── docs
│   │   ├── api
│   │   │   ├── glasses_detector.utils.rst
│   │   │   ├── glasses_detector.components.base_model.rst
│   │   │   ├── glasses_detector.components.pred_type.rst
│   │   │   ├── glasses_detector.architectures.tiny_binary_detector.rst
│   │   │   ├── glasses_detector.architectures.tiny_binary_segmenter.rst
│   │   │   ├── glasses_detector.components.pred_interface.rst
│   │   │   ├── glasses_detector.architectures.tiny_binary_classifier.rst
│   │   │   ├── glasses_detector.detector.rst
│   │   │   ├── glasses_detector.segmenter.rst
│   │   │   └── glasses_detector.classifier.rst
│   │   ├── api.rst
│   │   ├── credits.rst
│   │   ├── install.rst
│   │   ├── cli.rst
│   │   ├── examples.rst
│   │   └── features.rst
│   ├── Makefile
│   ├── make.bat
│   ├── index.rst
│   ├── conf.yaml
│   └── conf.py
├── src
│   └── glasses_detector
│       ├── components
│       │   ├── __init__.py
│       │   └── pred_type.py
│       ├── __init__.py
│       ├── _wrappers
│       │   ├── __init__.py
│       │   ├── binary_detector.py
│       │   ├── binary_segmenter.py
│       │   ├── binary_classifier.py
│       │   └── metrics.py
│       ├── architectures
│       │   ├── __init__.py
│       │   ├── tiny_binary_classifier.py
│       │   ├── tiny_binary_segmenter.py
│       │   └── tiny_binary_detector.py
│       ├── _data
│       │   ├── __init__.py
│       │   ├── binary_segmentation_dataset.py
│       │   ├── binary_classification_dataset.py
│       │   ├── binary_detection_dataset.py
│       │   ├── base_categorized_dataset.py
│       │   └── augmenter_mixin.py
│       ├── utils.py
│       └── __main__.py
├── .gitignore
├── requirements.txt
├── CITATION.cff
├── .github
│   └── workflows
│       ├── python-publish.yaml
│       └── sphinx.yaml
├── LICENSE
├── scripts
│   ├── analyse.py
│   └── run.py
└── CODE_OF_CONDUCT.md
/data/demo/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/data/demo/0.jpg
--------------------------------------------------------------------------------
/data/demo/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/data/demo/1.jpg -------------------------------------------------------------------------------- /data/demo/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/data/demo/2.jpg -------------------------------------------------------------------------------- /data/demo/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/data/demo/3.jpg -------------------------------------------------------------------------------- /data/demo/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/data/demo/4.jpg -------------------------------------------------------------------------------- /data/demo/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/data/demo/5.jpg -------------------------------------------------------------------------------- /docs/_static/inv/tqdm.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/inv/tqdm.inv -------------------------------------------------------------------------------- /docs/helpers/__init__.py: -------------------------------------------------------------------------------- 1 | from .build_finished import BuildFinished 2 | from .custom_invs import CustomInvs 3 | -------------------------------------------------------------------------------- /docs/_static/img/banner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/banner.jpg -------------------------------------------------------------------------------- /docs/_static/img/logo-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/logo-dark.png -------------------------------------------------------------------------------- /docs/_static/img/logo-light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/logo-light.png -------------------------------------------------------------------------------- /src/glasses_detector/components/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseGlassesModel 2 | from .pred_type import PredType 3 | -------------------------------------------------------------------------------- /docs/_static/img/detection-eyes-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-eyes-0.jpg -------------------------------------------------------------------------------- /docs/_static/img/detection-eyes-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-eyes-1.jpg -------------------------------------------------------------------------------- 
/docs/_static/img/detection-eyes-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-eyes-2.jpg -------------------------------------------------------------------------------- /docs/_static/img/detection-solo-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-solo-0.jpg -------------------------------------------------------------------------------- /docs/_static/img/detection-solo-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-solo-1.jpg -------------------------------------------------------------------------------- /docs/_static/img/detection-solo-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-solo-2.jpg -------------------------------------------------------------------------------- /docs/_static/img/detection-worn-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-worn-0.jpg -------------------------------------------------------------------------------- /docs/_static/img/detection-worn-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-worn-1.jpg -------------------------------------------------------------------------------- /docs/_static/img/detection-worn-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/detection-worn-2.jpg -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx-design 3 | sphinx-copybutton 4 | sphinxcontrib-bibtex 5 | pydata-sphinx-theme 6 | sphobjinv 7 | -------------------------------------------------------------------------------- /docs/_static/img/segmentation-full-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-full-0.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-full-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-full-1.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-full-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-full-2.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-legs-0.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-legs-0.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-legs-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-legs-1.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-legs-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-legs-2.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-smart-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-smart-0.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-smart-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-smart-1.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-smart-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-smart-2.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-frames-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-frames-0.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-frames-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-frames-1.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-frames-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-frames-2.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-lenses-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-lenses-0.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-lenses-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-lenses-1.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-lenses-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-lenses-2.jpg 
-------------------------------------------------------------------------------- /docs/_static/img/segmentation-shadows-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-shadows-0.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-shadows-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-shadows-1.jpg -------------------------------------------------------------------------------- /docs/_static/img/segmentation-shadows-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/segmentation-shadows-2.jpg -------------------------------------------------------------------------------- /docs/_static/img/classification-shadows-neg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/classification-shadows-neg.jpg -------------------------------------------------------------------------------- /docs/_static/img/classification-shadows-pos.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/classification-shadows-pos.jpg -------------------------------------------------------------------------------- /docs/_static/img/classification-eyeglasses-neg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/classification-eyeglasses-neg.jpg -------------------------------------------------------------------------------- /docs/_static/img/classification-eyeglasses-pos.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/classification-eyeglasses-pos.jpg -------------------------------------------------------------------------------- /docs/_static/img/classification-no-glasses-neg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/classification-no-glasses-neg.jpg -------------------------------------------------------------------------------- /docs/_static/img/classification-sunglasses-neg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/classification-sunglasses-neg.jpg -------------------------------------------------------------------------------- /docs/_static/img/classification-sunglasses-pos.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mantasu/glasses-detector/HEAD/docs/_static/img/classification-sunglasses-pos.jpg -------------------------------------------------------------------------------- /src/glasses_detector/__init__.py: -------------------------------------------------------------------------------- 1 | from .classifier import GlassesClassifier 2 | from .detector import GlassesDetector 3 | from .segmenter 
import GlassesSegmenter 4 | -------------------------------------------------------------------------------- /docs/docs/api/glasses_detector.utils.rst: -------------------------------------------------------------------------------- 1 | Utilities 2 | ========= 3 | 4 | .. automodule:: glasses_detector.utils 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/docs/api/glasses_detector.components.base_model.rst: -------------------------------------------------------------------------------- 1 | Base Model 2 | ========== 3 | 4 | .. automodule:: glasses_detector.components.base_model 5 | :members: 6 | :show-inheritance: 7 | -------------------------------------------------------------------------------- /src/glasses_detector/_wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .binary_classifier import BinaryClassifier 2 | from .binary_detector import BinaryDetector 3 | from .binary_segmenter import BinarySegmenter 4 | -------------------------------------------------------------------------------- /src/glasses_detector/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | from .tiny_binary_classifier import TinyBinaryClassifier 2 | from .tiny_binary_detector import TinyBinaryDetector 3 | from .tiny_binary_segmenter import TinyBinarySegmenter 4 | -------------------------------------------------------------------------------- /docs/docs/api/glasses_detector.components.pred_type.rst: -------------------------------------------------------------------------------- 1 | Prediction Type 2 | =============== 3 | 4 | .. automodule:: glasses_detector.components.pred_type 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Data 2 | data/* 3 | !data/demo 4 | 5 | # Training 6 | checkpoints 7 | lightning_logs 8 | 9 | # IDE 10 | .vscode 11 | 12 | # Build 13 | dist 14 | build 15 | _build 16 | _templates 17 | __pycache__ 18 | *.egg-info -------------------------------------------------------------------------------- /docs/docs/api/glasses_detector.architectures.tiny_binary_detector.rst: -------------------------------------------------------------------------------- 1 | Tiny Binary Detector 2 | ==================== 3 | 4 | .. automodule:: glasses_detector.architectures.tiny_binary_detector 5 | :members: 6 | :show-inheritance: 7 | -------------------------------------------------------------------------------- /docs/docs/api/glasses_detector.architectures.tiny_binary_segmenter.rst: -------------------------------------------------------------------------------- 1 | Tiny Binary Segmenter 2 | ===================== 3 | 4 | .. automodule:: glasses_detector.architectures.tiny_binary_segmenter 5 | :members: 6 | :show-inheritance: 7 | -------------------------------------------------------------------------------- /docs/docs/api/glasses_detector.components.pred_interface.rst: -------------------------------------------------------------------------------- 1 | Prediction Interface 2 | ==================== 3 | 4 | .. 
automodule:: glasses_detector.components.pred_interface 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/docs/api/glasses_detector.architectures.tiny_binary_classifier.rst: -------------------------------------------------------------------------------- 1 | Tiny Binary Classifier 2 | ====================== 3 | 4 | .. automodule:: glasses_detector.architectures.tiny_binary_classifier 5 | :members: 6 | :show-inheritance: 7 | 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | torch 3 | scipy 4 | pyyaml 5 | fvcore 6 | rarfile 7 | ipykernel 8 | pycocotools 9 | torchvision 10 | prettytable 11 | albumentations 12 | pytorch_lightning 13 | tensorboard 14 | torch-tb-profiler 15 | jsonargparse[signatures] -------------------------------------------------------------------------------- /docs/docs/api/glasses_detector.detector.rst: -------------------------------------------------------------------------------- 1 | Detector 2 | ======== 3 | 4 | .. automodule:: glasses_detector.detector 5 | :members: 6 | :exclude-members: create_model, save, forward, model_info, BASE_WEIGHTS_URL, ALLOWED_SIZE_ALIASES, DEFAULT_SIZE_MAP, DEFAULT_KIND_MAP 7 | :inherited-members: 8 | :show-inheritance: 9 | -------------------------------------------------------------------------------- /docs/docs/api/glasses_detector.segmenter.rst: -------------------------------------------------------------------------------- 1 | Segmenter 2 | ========= 3 | 4 | .. automodule:: glasses_detector.segmenter 5 | :members: 6 | :exclude-members: create_model, save, forward, model_info, BASE_WEIGHTS_URL, ALLOWED_SIZE_ALIASES, DEFAULT_SIZE_MAP, DEFAULT_KIND_MAP 7 | :inherited-members: 8 | :show-inheritance: 9 | -------------------------------------------------------------------------------- /docs/docs/api/glasses_detector.classifier.rst: -------------------------------------------------------------------------------- 1 | Classifier 2 | ========== 3 | 4 | .. automodule:: glasses_detector.classifier 5 | :members: 6 | :exclude-members: create_model, save, forward, model_info, BASE_WEIGHTS_URL, ALLOWED_SIZE_ALIASES, DEFAULT_SIZE_MAP, DEFAULT_KIND_MAP 7 | :inherited-members: 8 | :show-inheritance: 9 | -------------------------------------------------------------------------------- /src/glasses_detector/_data/__init__.py: -------------------------------------------------------------------------------- 1 | from .augmenter_mixin import AugmenterMixin 2 | from .base_categorized_dataset import BaseCategorizedDataset 3 | from .binary_classification_dataset import BinaryClassificationDataset 4 | from .binary_detection_dataset import BinaryDetectionDataset 5 | from .binary_segmentation_dataset import BinarySegmentationDataset 6 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: >- 3 | If you find Glasses Detector useful in your work, please 4 | consider citing the following BibTeX entry. 
5 | authors: 6 | - family-names: "Birškus" 7 | given-names: "Mantas" 8 | title: "Glasses Detector" 9 | journal: "GitHub repository" 10 | publisher: "GitHub" 11 | url: "https://github.com/mantasu/glasses-detector" 12 | license: "MIT" 13 | type: software 14 | date-released: 2024-03-06 15 | doi: 10.5281/zenodo.8126101 -------------------------------------------------------------------------------- /src/glasses_detector/_data/binary_segmentation_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import override 2 | 3 | import torch 4 | 5 | from .base_categorized_dataset import BaseCategorizedDataset 6 | 7 | 8 | class BinarySegmentationDataset(BaseCategorizedDataset): 9 | @override 10 | def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: 11 | sample = self.data[index] 12 | image, masks = self.load_transform( 13 | image=sample[self.img_folder], 14 | masks=[sample["masks"]], 15 | transform=self.transform, 16 | ) 17 | 18 | return image, torch.stack(masks) 19 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/css/signatures.css: -------------------------------------------------------------------------------- 1 | /* Newlines (\a) and spaces (\20) before each parameter */ 2 | .long-sig.sig-param::before { 3 | content: "\a\20\20\20\20\20\20\20\20"; 4 | white-space: pre; 5 | } 6 | 7 | /* Newlines (\a) and spaces (\20) before each parameter separator */ 8 | span.long-sig::before { 9 | content: "\a\20\20\20\20\20\20\20\20"; 10 | white-space: pre; 11 | } 12 | 13 | /* Newline after the last parameter (so the closing bracket is on a new line) */ 14 | dt em.long-sig.sig-param:last-of-type::after { 15 | content: "\a"; 16 | white-space: pre; 17 | } 18 | 19 | /* To have blue background of width of the block (instead of width of content) */ 20 | dl.class > dt:first-of-type { 21 | display: block !important; 22 | } 23 | -------------------------------------------------------------------------------- /src/glasses_detector/_data/binary_classification_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import override 2 | 3 | import torch 4 | 5 | from .base_categorized_dataset import BaseCategorizedDataset 6 | 7 | 8 | class BinaryClassificationDataset(BaseCategorizedDataset): 9 | def __post_init__(self): 10 | # Flatten (some image names may have been the same across cats) 11 | self.data = [dict([cat]) for d in self.data for cat in d.items()] 12 | 13 | @override 14 | def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: 15 | cat, pth = next(iter(self.data[index].items())) 16 | label = self.cat2tensor(cat) 17 | image = self.load_transform(image=pth, transform=self.transform) 18 | 19 | return image, label 20 | -------------------------------------------------------------------------------- /docs/docs/api.rst: -------------------------------------------------------------------------------- 1 | :fas:`book` API 2 | =============== 3 | 4 | Models 5 | ------ 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | 10 | api/glasses_detector.classifier 11 | api/glasses_detector.detector 12 | api/glasses_detector.segmenter 13 | 14 | 15 | Components 16 | ---------- 17 | 18 | .. toctree:: 19 | :maxdepth: 1 20 | 21 | api/glasses_detector.components.base_model 22 | api/glasses_detector.components.pred_interface 23 | api/glasses_detector.components.pred_type 24 | api/glasses_detector.utils 25 | 26 | 27 | Architectures 28 | ------------- 29 | 30 | .. 
toctree:: 31 | :maxdepth: 1 32 | 33 | api/glasses_detector.architectures.tiny_binary_classifier 34 | api/glasses_detector.architectures.tiny_binary_detector 35 | api/glasses_detector.architectures.tiny_binary_segmenter -------------------------------------------------------------------------------- /.github/workflows/python-publish.yaml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | workflow_dispatch: 8 | 9 | permissions: 10 | contents: read 11 | id-token: write 12 | 13 | jobs: 14 | deploy: 15 | runs-on: ubuntu-latest 16 | environment: 17 | name: pypi 18 | url: https://pypi.org/p/glasses-detector 19 | steps: 20 | - name: Set up Python 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: '3.12' 24 | - name: Checkout 25 | uses: actions/checkout@v4 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@release/v1 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/docs/credits.rst: -------------------------------------------------------------------------------- 1 | :fas:`heart` Credits 2 | ==================== 3 | 4 | Support 5 | ------- 6 | 7 | The easiest way to get support is to open an issue on |issues_link|. 8 | 9 | .. |issues_link| raw:: html 10 | 11 | 12 | 13 | GitHub 14 | 15 | 16 | 17 | Citation 18 | -------- 19 | 20 | .. code-block:: bibtex 21 | 22 | @software{Birskus_Glasses_Detector_2024, 23 | author = {Birškus, Mantas}, 24 | title = {{Glasses Detector}}, 25 | license = {MIT}, 26 | url = {https://github.com/mantasu/glasses-detector}, 27 | month = {3}, 28 | year = {2024}, 29 | doi = {10.5281/zenodo.8126101} 30 | } 31 | 32 | 33 | References 34 | ---------- 35 | 36 | .. 
bibliography:: 37 | :style: unsrt -------------------------------------------------------------------------------- /docs/_static/js/colab-icon.js: -------------------------------------------------------------------------------- 1 | FontAwesome.library.add( 2 | (faListOldStyle = { 3 | prefix: "fa-custom", 4 | iconName: "colab", 5 | icon: [ 6 | 24, // viewBox width 7 | 24, // viewBox height 8 | [], // ligature 9 | "e001", // unicode codepoint - private use area 10 | "M16.9414 4.9757a7.033 7.033 0 0 0-4.9308 2.0646 7.033 7.033 0 0 0-.1232 9.8068l2.395-2.395a3.6455 3.6455 0 0 1 5.1497-5.1478l2.397-2.3989a7.033 7.033 0 0 0-4.8877-1.9297zM7.07 4.9855a7.033 7.033 0 0 0-4.8878 1.9316l2.3911 2.3911a3.6434 3.6434 0 0 1 5.0227.1271l1.7341-2.9737-.0997-.0802A7.033 7.033 0 0 0 7.07 4.9855zm15.0093 2.1721l-2.3892 2.3911a3.6455 3.6455 0 0 1-5.1497 5.1497l-2.4067 2.4068a7.0362 7.0362 0 0 0 9.9456-9.9476zM1.932 7.1674a7.033 7.033 0 0 0-.002 9.6816l2.397-2.397a3.6434 3.6434 0 0 1-.004-4.8916zm7.664 7.4235c-1.38 1.3816-3.5863 1.411-5.0168.1134l-2.397 2.395c2.4693 2.3328 6.263 2.5753 9.0072.5455l.1368-.1115z", // svg path (https://simpleicons.org/?q=colab) 11 | ], 12 | }) 13 | ); -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Mantas Birškus 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/docs/install.rst: -------------------------------------------------------------------------------- 1 | :fas:`download` Installation 2 | ============================ 3 | 4 | The packages requires at least `Python 3.12 `_. To install the package via `pip `_, simply run: 5 | 6 | .. code-block:: bash 7 | 8 | pip install glasses-detector 9 | 10 | Or, to install it from source, run: 11 | 12 | .. code-block:: bash 13 | 14 | git clone https://github.com/mantasu/glasses-detector 15 | cd glasses-detector && pip install . 16 | 17 | .. tip:: 18 | 19 | You may want to set up `PyTorch `_ in advance to enable **GPU** support for your device. Note that *CUDA* is backwards compatible, thus even if you have the newest version of `CUDA Toolkit `_, *PyTorch* **GPU** acceleration should work just fine. 20 | 21 | .. 
note:: 22 | 23 | By default, the required models will be automatically downloaded and saved under *Torch Hub* directory, which by default is ``~/.cache/torch/hub/checkpoints``. For more information and how to change it, see `Torch Hub documentation `_. -------------------------------------------------------------------------------- /.github/workflows/sphinx.yaml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | workflow_dispatch: 8 | 9 | permissions: 10 | contents: read 11 | pages: write 12 | id-token: write 13 | 14 | concurrency: 15 | group: "pages" 16 | cancel-in-progress: false 17 | 18 | jobs: 19 | deploy: 20 | environment: 21 | name: github-pages 22 | url: ${{ steps.deployment.outputs.page_url }} 23 | runs-on: ubuntu-latest 24 | steps: 25 | - name: Set up Python 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: "3.12" 29 | - name: Checkout 30 | uses: actions/checkout@v4 31 | - name: Install dependencies 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu 35 | pip install . 36 | pip install --upgrade setuptools 37 | pip install -r docs/requirements.txt 38 | - name: Build HTML 39 | run: | 40 | cd docs/ 41 | make html 42 | - name: Setup Pages 43 | uses: actions/configure-pages@v4 44 | - name: Upload artifact 45 | uses: actions/upload-pages-artifact@v3 46 | with: 47 | path: 'docs/_build/html' 48 | - name: Deploy to GitHub Pages 49 | id: deployment 50 | uses: actions/deploy-pages@v4 51 | -------------------------------------------------------------------------------- /src/glasses_detector/_data/binary_detection_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import override 2 | 3 | import albumentations as A 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | from .base_categorized_dataset import BaseCategorizedDataset 8 | 9 | 10 | class BinaryDetectionDataset(BaseCategorizedDataset): 11 | @staticmethod 12 | def collate_fn( 13 | batch: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] 14 | ) -> tuple[list[torch.Tensor], list[dict[str, torch.Tensor]]]: 15 | images = [item[0] for item in batch] 16 | annots = [{"boxes": item[1], "labels": item[2]} for item in batch] 17 | 18 | return images, annots 19 | 20 | @override 21 | @classmethod 22 | def create_loader(cls, **kwargs) -> DataLoader: 23 | kwargs.setdefault("collate_fn", cls.collate_fn) 24 | return super().create_loader(**kwargs) 25 | 26 | @override 27 | def create_transform(self, is_train: bool, **kwargs) -> A.Compose: 28 | kwargs.setdefault("has_bbox", True) 29 | return super().create_transform(is_train, **kwargs) 30 | 31 | @override 32 | def __getitem__( 33 | self, 34 | index: int, 35 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 36 | sample = self.data[index] 37 | image, bboxes, bbcats = self.load_transform( 38 | image=sample[self.img_folder], 39 | boxes=[sample["annotations"]], 40 | bcats=[1], # 0 - background 41 | transform=self.transform, 42 | ) 43 | 44 | if len(bboxes) == 0: 45 | bboxes = torch.tensor([[0, 0, 1, 1]], dtype=torch.float32) 46 | bbcats = torch.tensor([0], dtype=torch.int64) 47 | 48 | return image, bboxes, bbcats 49 | -------------------------------------------------------------------------------- /docs/_static/js/pypi-icon.js: -------------------------------------------------------------------------------- 1 | 
FontAwesome.library.add( 2 | (faListOldStyle = { 3 | prefix: "fa-custom", 4 | iconName: "pypi", 5 | icon: [ 6 | 17.313, // viewBox width 7 | 19.807, // viewBox height 8 | [], // ligature 9 | "e001", // unicode codepoint - private use area 10 | "m10.383 0.2-3.239 1.1769 3.1883 1.1614 3.239-1.1798zm-3.4152 1.2411-3.2362 1.1769 3.1855 1.1614 3.2369-1.1769zm6.7177 0.00281-3.2947 1.2009v3.8254l3.2947-1.1988zm-3.4145 1.2439-3.2926 1.1981v3.8254l0.17548-0.064132 3.1171-1.1347zm-6.6564 0.018325v3.8247l3.244 1.1805v-3.8254zm10.191 0.20931v2.3137l3.1777-1.1558zm3.2947 1.2425-3.2947 1.1988v3.8254l3.2947-1.1988zm-8.7058 0.45739c0.00929-1.931e-4 0.018327-2.977e-4 0.027485 0 0.25633 0.00851 0.4263 0.20713 0.42638 0.49826 1.953e-4 0.38532-0.29327 0.80469-0.65542 0.93662-0.36226 0.13215-0.65608-0.073306-0.65613-0.4588-6.28e-5 -0.38556 0.2938-0.80504 0.65613-0.93662 0.068422-0.024919 0.13655-0.038114 0.20156-0.039466zm5.2913 0.78369-3.2947 1.1988v3.8247l3.2947-1.1981zm-10.132 1.239-3.2362 1.1769 3.1883 1.1614 3.2362-1.1769zm6.7177 0.00213-3.2926 1.2016v3.8247l3.2926-1.2009zm-3.4124 1.2439-3.2947 1.1988v3.8254l3.2947-1.1988zm-6.6585 0.016195v3.8275l3.244 1.1805v-3.8254zm16.9 0.21143-3.2947 1.1988v3.8247l3.2947-1.1981zm-3.4145 1.2411-3.2926 1.2016v3.8247l3.2926-1.2009zm-3.4145 1.2411-3.2926 1.2016v3.8247l3.2926-1.2009zm-3.4124 1.2432-3.2947 1.1988v3.8254l3.2947-1.1988zm-6.6585 0.019027v3.8247l3.244 1.1805v-3.8254zm13.485 1.4497-3.2947 1.1988v3.8247l3.2947-1.1981zm-3.4145 1.2411-3.2926 1.2016v3.8247l3.2926-1.2009zm2.4018 0.38127c0.0093-1.83e-4 0.01833-3.16e-4 0.02749 0 0.25633 0.0085 0.4263 0.20713 0.42638 0.49826 1.97e-4 0.38532-0.29327 0.80469-0.65542 0.93662-0.36188 0.1316-0.65525-0.07375-0.65542-0.4588-1.95e-4 -0.38532 0.29328-0.80469 0.65542-0.93662 0.06842-0.02494 0.13655-0.03819 0.20156-0.03947zm-5.8142 0.86403-3.244 1.1805v1.4201l3.244 1.1805z", // svg path (https://simpleicons.org/icons/pypi.svg) 11 | ], 12 | }) 13 | ); -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Adjust icon size */ 2 | .fa-zenodo { 3 | font-size: 1em; 4 | } 5 | 6 | /* Header setup for index.html */ 7 | #glasses-detector h1 { 8 | display: flex; 9 | align-items: center; 10 | } 11 | 12 | /* Logo setup for index.html page header */ 13 | #glasses-detector h1::before { 14 | content: ""; 15 | display: inline-block; 16 | height: 1.5em; 17 | width: 1.5em; 18 | margin-right: 10px; 19 | } 20 | 21 | /* Logo setup for index.html page header for dark mode */ 22 | html[data-theme="dark"] #glasses-detector h1::before { 23 | background: url("../img/logo-dark.png") no-repeat center/contain; 24 | } 25 | 26 | /* Logo setup for index.html page header for light mode */ 27 | html[data-theme="light"] #glasses-detector h1::before { 28 | background: url("../img/logo-light.png") no-repeat center/contain; 29 | } 30 | 31 | /* Adjust navbar logo size */ 32 | img.logo__image { 33 | width: 40px; 34 | height: 40px; 35 | } 36 | 37 | /* Align the table content in the center */ 38 | #classification-kinds td, #detection-kinds td, #segmentation-kinds td { 39 | vertical-align: middle; 40 | } 41 | 42 | /* Adjust banner */ 43 | #banner { 44 | margin-top: 2em; 45 | border-radius: 0.3em; 46 | } 47 | 48 | .toctree-wrapper { 49 | /* Padding + Width */ 50 | padding: 1em 2em; 51 | width: 100%; 52 | 53 | /* Border Setup */ 54 | border-radius: 1em; 55 | border-width: .25em; 56 | border-style: solid; 57 | 58 | /* Color setup 
(automatic for dark/light modes) */ 59 | background-color: var(--pst-color-on-background); 60 | border: .25em solid var(--pst-color-border) 61 | } 62 | 63 | #index-toctree { 64 | /* Multi-column Layout */ 65 | column-count: 2; 66 | } 67 | 68 | #index-toctree .toctree-l1 { 69 | /* List Style */ 70 | padding: 0.5em 0; 71 | } 72 | 73 | #index-toctree .toctree-l2 { 74 | /* List Style */ 75 | list-style-type: circle !important; 76 | margin-left: 1em !important; 77 | } -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Glasses Detector 2 | ================ 3 | 4 | .. image:: _static/img/banner.jpg 5 | :alt: Glasses Detector Banner 6 | :align: center 7 | :name: banner 8 | 9 | About 10 | ----- 11 | 12 | A package for processing images with different types of glasses and their parts. It provides a quick way to use the pre-trained models for **3** kinds of tasks, each divided into multiple categories, for instance, *classification of sunglasses* or *segmentation of glasses frames* (see :doc:`docs/features` for more details): 13 | 14 | .. list-table:: 15 | :align: center 16 | 17 | * - 18 | - 19 | * - **Classification** 20 | - 👓 *transparent* 🕶️ *opaque* 🥽 *any* ➿ *shadows* 21 | * - **Detection** 22 | - 🤓 *worn* 👓 *standalone* 👀 *eye-area* 23 | * - **Segmentation** 24 | - 😎 *full* 🖼️ *frames* 🦿 *legs* 🔍 *lenses* 👥 *shadows* 25 | 26 | .. raw:: html 27 | 28 |
29 | 30 | The processing can be launched via the command line or written in a *Python* script. Based on the selected task, an image or a directory of images will be processed and corresponding predictions, e.g., labels or masks, will be generated. 31 | 32 | .. seealso:: 33 | 34 | Refer to |repo_link| for information about the datasets used and how to train or test your own models. Model fitting and evaluation can be simply launched through terminal with commands integrated from `PyTorch Lightning CLI `_. 35 | 36 | .. |repo_link| raw:: html 37 | 38 | 39 | 40 | GitHub repository 41 | 42 | 43 | 44 | Contents 45 | -------- 46 | 47 | .. toctree:: 48 | :maxdepth: 2 49 | :name: index-toctree 50 | 51 | docs/install 52 | docs/features 53 | docs/examples 54 | docs/cli 55 | docs/api 56 | docs/credits 57 | 58 | 59 | Indices and Tables 60 | ------------------ 61 | 62 | * :ref:`genindex` 63 | * :ref:`modindex` 64 | * :ref:`search` 65 | -------------------------------------------------------------------------------- /docs/_static/bib/references.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{ma2018shufflenet, 2 | title={Shufflenet v2: Practical guidelines for efficient cnn architecture design}, 3 | author={Ma, Ningning and Zhang, Xiangyu and Zheng, Hai-Tao and Sun, Jian}, 4 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, 5 | pages={116--131}, 6 | year={2018} 7 | } 8 | 9 | @inproceedings{liu2016ssd, 10 | title={Ssd: Single shot multibox detector}, 11 | author={Liu, Wei and Anguelov, Dragomir and Erhan, Dumitru and Szegedy, Christian and Reed, Scott and Fu, Cheng-Yang and Berg, Alexander C}, 12 | booktitle={Computer Vision--ECCV 2016: 14th European Conference, Amsterdam, The Netherlands, October 11--14, 2016, Proceedings, Part I 14}, 13 | pages={21--37}, 14 | year={2016}, 15 | organization={Springer} 16 | } 17 | 18 | @inproceedings{howard2019searching, 19 | title={Searching for mobilenetv3}, 20 | author={Howard, Andrew and Sandler, Mark and Chu, Grace and Chen, Liang-Chieh and Chen, Bo and Tan, Mingxing and Wang, Weijun and Zhu, Yukun and Pang, Ruoming and Vasudevan, Vijay and others}, 21 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 22 | pages={1314--1324}, 23 | year={2019} 24 | } 25 | 26 | @article{ren2015faster, 27 | title={Faster r-cnn: Towards real-time object detection with region proposal networks}, 28 | author={Ren, Shaoqing and He, Kaiming and Girshick, Ross and Sun, Jian}, 29 | journal={Advances in Neural Information Processing Systems}, 30 | volume={28}, 31 | year={2015} 32 | } 33 | 34 | @inproceedings{long2015fully, 35 | title={Fully convolutional networks for semantic segmentation}, 36 | author={Long, Jonathan and Shelhamer, Evan and Darrell, Trevor}, 37 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 38 | pages={3431--3440}, 39 | year={2015} 40 | } 41 | 42 | @inproceedings{he2016deep, 43 | title={Deep residual learning for image recognition}, 44 | author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian}, 45 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 46 | pages={770--778}, 47 | year={2016} 48 | } -------------------------------------------------------------------------------- /src/glasses_detector/architectures/tiny_binary_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as 
nn 3 | 4 | 5 | class TinyBinaryClassifier(nn.Module): 6 | """Tiny binary classifier. 7 | 8 | This is a custom classifier created with the aim to contain very few 9 | parameters while maintaining a reasonable accuracy. It only has 10 | several sequential convolutional and pooling blocks (with 11 | batch-norm in between). 12 | """ 13 | 14 | def __init__(self): 15 | super().__init__() 16 | 17 | # Several convolutional blocks 18 | self.features = nn.Sequential( 19 | self._create_block(3, 5, 3), 20 | self._create_block(5, 10, 3), 21 | self._create_block(10, 15, 3), 22 | self._create_block(15, 20, 3), 23 | self._create_block(20, 25, 3), 24 | self._create_block(25, 80, 3), 25 | nn.AdaptiveAvgPool2d(1), 26 | nn.Flatten(), 27 | ) 28 | 29 | # Fully connected layer 30 | self.fc = nn.Linear(80, 1) 31 | 32 | def _create_block(self, num_in, num_out, filter_size): 33 | return nn.Sequential( 34 | nn.Conv2d(num_in, num_out, filter_size, 1, "valid", bias=False), 35 | nn.ReLU(), 36 | nn.BatchNorm2d(num_out), 37 | nn.MaxPool2d(2, 2), 38 | ) 39 | 40 | def forward(self, x: torch.Tensor) -> torch.Tensor: 41 | """Performs forward pass. 42 | 43 | Predicts raw scores for the given batch of inputs. Scores are 44 | unbounded, anything that's less than 0, means positive class is 45 | unlikely and anything that's above 0 indicates that the positive 46 | class is likely 47 | 48 | Args: 49 | x (torch.Tensor): Image batch of shape (N, C, H, W). Note 50 | that pixel values are normalized and squeezed between 51 | 0 and 1. 52 | 53 | Returns: 54 | torch.Tensor: An output tensor of shape (N,) indicating 55 | whether each nth image falls under the positive class or 56 | not. The scores are unbounded, thus, to convert to a 57 | probability, sigmoid function must be used. 58 | """ 59 | return self.fc(self.features(x)) 60 | -------------------------------------------------------------------------------- /docs/_static/svg/colab.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | colab 22 | colab 23 | open 24 | open 25 | 26 | -------------------------------------------------------------------------------- /docs/helpers/custom_invs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from importlib.metadata import version 4 | 5 | import sphobjinv as soi 6 | 7 | 8 | class CustomInvs: 9 | def __init__(self, static_path: str = "_static"): 10 | # Init inv directory and create it if not exists 11 | self.inv_dir = os.path.join(static_path, "inv") 12 | os.makedirs(self.inv_dir, exist_ok=True) 13 | 14 | def create_tqdm_inv(self) -> dict[str, tuple[str, str]]: 15 | # Init inv and module 16 | inv = soi.Inventory() 17 | 18 | # Define the Sphinx header information 19 | inv.project = "tqdm" 20 | inv.version = version("tqdm") 21 | 22 | # Define the tqdm class 23 | inv_obj = soi.DataObjStr( 24 | name="tqdm.tqdm", 25 | domain="py", 26 | role="class", 27 | priority="1", 28 | uri="tqdm/#tqdm-objects", 29 | dispname="tqdm", 30 | ) 31 | inv.objects.append(inv_obj) 32 | 33 | # Write the inventory to a file 34 | path = os.path.join(self.inv_dir, "tqdm.inv") 35 | text = soi.compress(inv.data_file(contract=True)) 36 | soi.writebytes(path, text) 37 | 38 | return {"tqdm": ("https://tqdm.github.io/docs/", path)} 39 | 40 | def create_builtin_constants_inv(self) -> dict[str, tuple[str, str]]: 41 | # Init inv and module 42 | inv = soi.Inventory() 43 | 44 | # Define the 
Sphinx header information 45 | inv.project = "builtin_constants" 46 | major, minor, micro = sys.version_info[:3] 47 | inv.version = f"{major}.{minor}.{micro}" 48 | 49 | for constant in ["None", "True", "False"]: 50 | # Define the constant as class 51 | inv_obj = soi.DataObjStr( 52 | name=f"{constant}", 53 | domain="py", 54 | role="class", # dummy class for linking 55 | priority="1", 56 | uri=f"library/constants.html#{constant}", 57 | dispname=constant, 58 | ) 59 | inv.objects.append(inv_obj) 60 | 61 | # Write the inventory to a file 62 | path = os.path.join(self.inv_dir, "builtin_constants.inv") 63 | text = soi.compress(inv.data_file(contract=True)) 64 | soi.writebytes(path, text) 65 | 66 | return {"builtin_constants": ("https://docs.python.org/3", path)} 67 | 68 | def __call__(self) -> dict[str, tuple[str, str]]: 69 | # Init custom invs 70 | custom_invs = {} 71 | 72 | for method_name in dir(self): 73 | if method_name == "create_builtin_constants_inv": 74 | continue 75 | 76 | if method_name.startswith("create_"): 77 | # Update custom invs dictionary with the new one 78 | custom_invs.update(getattr(self, method_name)()) 79 | 80 | return custom_invs 81 | -------------------------------------------------------------------------------- /docs/_static/css/highlights.css: -------------------------------------------------------------------------------- 1 | div.highlight-python div.highlight pre .linenos { 2 | margin-right: 1em; 3 | } 4 | 5 | /* Default python code syntax highlights for dark mode*/ 6 | html[data-theme="dark"] div.highlight-python div.highlight pre { 7 | .n { color: #9CDCFE; } /* Variable */ 8 | .k, .ow, .kn { color: #C586C0; } /* Keyword */ 9 | .gp { color: #af92ff; } /* Generic Prompt */ 10 | .kc { color: #569CD6; } /* Built-in constant */ 11 | .s2 { color: #CE9178; } /* String */ 12 | .mi, .m { color: #B5CEA8; } /* Number */ 13 | .o { color: #CCCCCC; } /* Operator */ 14 | .go { color: #808080; } /* Generic Output */ 15 | .p { color: #D4D4D4; } /* Punctuation */ 16 | .c1 { color: #6A9955; } /* Comment */ 17 | .nn, .nc { color: #4EC9B0; } /* Class/module name */ 18 | .nf, .nb { color: #DCDCAA; } /* Function name */ 19 | .err { color: #F44747; } /* Error */ 20 | } 21 | 22 | /* Default python code syntax highlights for light mode*/ 23 | html[data-theme="light"] div.highlight-python div.highlight pre { 24 | .n { color: #3B3B3B; } /* Variable */ 25 | .k, .ow, .kn { color: #AF00DB; } /* Keyword */ 26 | .gp { color: #ff9ef7; } /* Generic Prompt */ 27 | .kc { color: #0000FF; } /* Built-in constant */ 28 | .s2 { color: #A31515; } /* String */ 29 | .mi, .m { color: #098658; } /* Number */ 30 | .o { color: #3B3B3B; } /* Operator */ 31 | .go { color: #8c8c8c; } /* Generic Output */ 32 | .p { color: #3B3B3B; } /* Punctuation */ 33 | .c1 { color: #008000; } /* Comment */ 34 | .nn, .nc { color: #267F99; } /* Class/module name */ 35 | .nf, .nb { color: #001080; } /* Function name */ 36 | .err { color: #fe7272; } /* Error */ 37 | } 38 | 39 | /* Default bash code syntax highlights for dark mode*/ 40 | html[data-theme="dark"] div.highlight-bash div.highlight pre { 41 | .c1 { color: #6A9955; } /* Comment */ 42 | } 43 | 44 | /* Default bash code syntax highlights for light mode*/ 45 | html[data-theme="light"] div.highlight-bash div.highlight pre { 46 | .c1 { color: #008000; } /* Comment */ 47 | } 48 | 49 | /* Custom python code syntax highlights for dark mode*/ 50 | html[data-theme="dark"] div.highlight-python div.highlight pre { 51 | .custom-highlight-class { color: #4EC9B0; } 52 | 
.custom-highlight-function { color: #DCDCAA; } 53 | .custom-highlight-variable { color: #9CDCFE; } 54 | .custom-highlight-constant { color: #4FC1FF; } 55 | } 56 | 57 | /* Custom python code syntax highlights for light mode*/ 58 | html[data-theme="light"] div.highlight-python div.highlight pre { 59 | .custom-highlight-class { color: #267F99; } 60 | .custom-highlight-function { color: #001080; } 61 | .custom-highlight-variable { color: #3B3B3B; } 62 | .custom-highlight-constant { color: #0070C1; } 63 | } 64 | 65 | /* Custom bash code syntax highlights for dark mode*/ 66 | html[data-theme="dark"] div.highlight-bash div.highlight pre, html[data-theme="dark"] code.highlight-bash { 67 | .custom-highlight-default { color: #CE9178; } 68 | .custom-highlight-start { color: #DCDCAA; } 69 | .custom-highlight-flag { color: #569CD6; } 70 | .custom-highlight-op { color: #CCCCCC; } 71 | } 72 | 73 | /* Custom bash code syntax highlights for light mode*/ 74 | html[data-theme="light"] div.highlight-bash div.highlight pre, html[data-theme="light"] code.highlight-bash { 75 | .custom-highlight-default { color: #A31515; } 76 | .custom-highlight-start { color: #001080; } 77 | .custom-highlight-flag { color: #0000FF; } 78 | .custom-highlight-op { color: #3B3B3B; } 79 | } -------------------------------------------------------------------------------- /src/glasses_detector/_wrappers/binary_detector.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torchmetrics 4 | from torch.optim import AdamW 5 | from torch.optim.lr_scheduler import ReduceLROnPlateau 6 | 7 | from .metrics import BoxClippedR2, BoxIoU, BoxMSLE 8 | 9 | 10 | class BinaryDetector(pl.LightningModule): 11 | def __init__(self, model, train_loader=None, val_loader=None, test_loader=None): 12 | super().__init__() 13 | 14 | # Assign attributes 15 | self.model = model 16 | self.train_loader = train_loader 17 | self.val_loader = val_loader 18 | self.test_loader = test_loader 19 | 20 | # Initialize val_loss metric (just the mean) 21 | self.val_loss = torchmetrics.MeanMetric() 22 | 23 | # Create F1 score to monitor average label performance 24 | self.label_metrics = torchmetrics.MetricCollection( 25 | [torchmetrics.F1Score(task="binary")] 26 | ) 27 | 28 | # Initialize some metrics to monitor bbox performance 29 | self.boxes_metrics = torchmetrics.MetricCollection( 30 | [BoxMSLE(), BoxIoU(), BoxClippedR2()] 31 | ) 32 | 33 | def forward(self, *args): 34 | return self.model(*args) 35 | 36 | def training_step(self, batch, batch_idx): 37 | # Forward propagate and compute loss 38 | loss_dict = self(batch[0], batch[1]) 39 | loss = sum(loss for loss in loss_dict.values()) / len(batch[0]) 40 | self.log("train_loss", loss, prog_bar=True) 41 | return loss 42 | 43 | def eval_step(self, batch): 44 | # Forward pass and compute loss 45 | with torch.inference_mode(): 46 | self.train() 47 | loss = sum(loss for loss in self(batch[0], batch[1]).values()) 48 | self.val_loss.update(loss / len(batch[0])) 49 | self.eval() 50 | 51 | # Update all the metrics 52 | self.boxes_metrics.update(self(batch[0]), batch[1], self.label_metrics) 53 | 54 | def on_eval_epoch_end(self, prefix=""): 55 | # Compute total loss and metrics 56 | loss = self.val_loss.compute() 57 | label_metrics = self.label_metrics.compute() 58 | boxes_metrics = self.boxes_metrics.compute() 59 | 60 | # Reset the metrics 61 | self.val_loss.reset() 62 | self.label_metrics.reset() 63 | self.boxes_metrics.reset() 64 | 65 | # Log 
the metrics and the learning rate 66 | self.log(f"{prefix}_loss", loss, prog_bar=True) 67 | self.log(f"{prefix}_f1", label_metrics["BinaryF1Score"], prog_bar=True) 68 | self.log(f"{prefix}_msle", boxes_metrics["BoxMSLE"], prog_bar=True) 69 | self.log(f"{prefix}_r2", boxes_metrics["BoxClippedR2"], prog_bar=True) 70 | self.log(f"{prefix}_iou", boxes_metrics["BoxIoU"], prog_bar=True) 71 | 72 | if not isinstance(opt := self.optimizers(), list): 73 | # Log the learning rate of a single optimizer 74 | self.log("lr", opt.param_groups[0]["lr"], prog_bar=True) 75 | 76 | def validation_step(self, batch, batch_idx): 77 | self.eval_step(batch) 78 | 79 | def on_validation_epoch_end(self): 80 | self.on_eval_epoch_end(prefix="val") 81 | 82 | def test_step(self, batch, batch_idx): 83 | self.eval_step(batch) 84 | 85 | def on_test_epoch_end(self): 86 | self.on_eval_epoch_end(prefix="test") 87 | 88 | def train_dataloader(self): 89 | return self.train_loader 90 | 91 | def val_dataloader(self): 92 | return self.val_loader 93 | 94 | def test_dataloader(self): 95 | return self.test_loader 96 | 97 | def configure_optimizers(self): 98 | # Initialize AdamW optimizer and Reduce On Plateau scheduler 99 | optimizer = AdamW(self.parameters(), lr=1e-3, weight_decay=1e-3) 100 | scheduler = ReduceLROnPlateau( 101 | optimizer=optimizer, 102 | factor=0.3, 103 | patience=15, 104 | threshold=0.01, 105 | min_lr=1e-6, 106 | ) 107 | 108 | return { 109 | "optimizer": optimizer, 110 | "lr_scheduler": scheduler, 111 | "monitor": "val_loss", 112 | } 113 | -------------------------------------------------------------------------------- /src/glasses_detector/architectures/tiny_binary_segmenter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.ops import Conv2dNormActivation 4 | 5 | 6 | class TinyBinarySegmenter(nn.Module): 7 | """Tiny binary segmenter. 8 | 9 | This is a custom segmenter created with the aim to contain very few 10 | parameters while maintaining a reasonable accuracy. It only has 11 | several sequential up-convolution and down-convolution layers with 12 | residual connections and is very similar to U-Net. 13 | 14 | Note: 15 | You can read more about U-Net architecture in the following 16 | paper by O. 
Ronneberger et al.: 17 | `U-Net: Convolutional Networks for Biomedical Image Segmentation `_ 18 | """ 19 | 20 | class _Down(nn.Module): 21 | def __init__(self, in_channels, out_channels): 22 | super().__init__() 23 | 24 | self.pool0 = nn.MaxPool2d(2) 25 | self.conv1 = Conv2dNormActivation(in_channels, out_channels) 26 | self.conv2 = Conv2dNormActivation(out_channels, out_channels) 27 | 28 | def forward(self, x): 29 | return self.conv2(self.conv1(self.pool0(x))) 30 | 31 | class _Up(nn.Module): 32 | def __init__(self, in_channels, out_channels): 33 | super().__init__() 34 | 35 | half_channels = in_channels // 2 36 | self.conv0 = nn.ConvTranspose2d(half_channels, half_channels, 2, 2) 37 | self.conv1 = Conv2dNormActivation(in_channels, out_channels) 38 | self.conv2 = Conv2dNormActivation(out_channels, out_channels) 39 | 40 | def forward(self, x1, x2): 41 | x1 = self.conv0(x1) 42 | 43 | diffY = x2.size()[2] - x1.size()[2] 44 | diffX = x2.size()[3] - x1.size()[3] 45 | 46 | x1 = nn.functional.pad( 47 | x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2) 48 | ) 49 | 50 | x = torch.cat([x2, x1], dim=1) 51 | 52 | return self.conv2(self.conv1(x)) 53 | 54 | def __init__(self): 55 | super().__init__() 56 | 57 | # Feature extraction layer 58 | self.first = nn.Sequential( 59 | Conv2dNormActivation(3, 16), 60 | Conv2dNormActivation(16, 16), 61 | ) 62 | 63 | # Down-sampling layers 64 | self.down1 = self._Down(16, 32) 65 | self.down2 = self._Down(32, 64) 66 | self.down3 = self._Down(64, 64) 67 | 68 | # Up-sampling layers 69 | self.up1 = self._Up(128, 32) 70 | self.up2 = self._Up(64, 16) 71 | self.up3 = self._Up(32, 16) 72 | 73 | # Pixel-wise classification layer 74 | self.last = nn.Conv2d(16, 1, 1) 75 | 76 | def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: 77 | """Performs forward pass. 78 | 79 | Predicts raw pixel scores for the given batch of inputs. Scores 80 | are unbounded - anything that's less than 0 means positive class 81 | belonging to the pixel is unlikely and anything that's above 0 82 | indicates that positive class for a particular pixel is likely. 83 | 84 | Args: 85 | x (torch.Tensor): Image batch of shape (N, C, H, W). Note 86 | that pixel values are normalized and squeezed between 87 | 0 and 1. 88 | 89 | Returns: 90 | dict[str, torch.Tensor]: A dictionary with a single "out" 91 | entry (for compatibility). The value is an output tensor of 92 | shape (N, 1, H, W) indicating which pixels in the image fall 93 | under positive category. The scores are unbounded, thus, to 94 | convert to probabilities, sigmoid function must be used. 
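        Example:
            A minimal inference sketch (the input size and the 0.5
            threshold below are illustrative, not requirements)::

                model = TinyBinarySegmenter().eval()
                x = torch.rand(1, 3, 256, 256)  # values already in [0, 1]

                with torch.no_grad():
                    logits = model(x)["out"]       # (1, 1, 256, 256) raw scores
                    mask = logits.sigmoid() > 0.5  # boolean per-pixel mask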
95 | """ 96 | # Extract primary features 97 | x1 = self.first(x) 98 | 99 | # Downsample features 100 | x2 = self.down1(x1) 101 | x3 = self.down2(x2) 102 | x4 = self.down3(x3) 103 | 104 | # Updample features 105 | x = self.up1(x4, x3) 106 | x = self.up2(x, x2) 107 | x = self.up3(x, x1) 108 | 109 | # Predict one channel 110 | out = self.last(x) 111 | 112 | return {"out": out} 113 | -------------------------------------------------------------------------------- /src/glasses_detector/_wrappers/binary_segmenter.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch.nn as nn 3 | import torchmetrics 4 | from torch.optim import AdamW 5 | from torch.optim.lr_scheduler import ReduceLROnPlateau 6 | 7 | 8 | class BinarySegmenter(pl.LightningModule): 9 | def __init__(self, model, train_loader=None, val_loader=None, test_loader=None): 10 | super().__init__() 11 | 12 | # Assign attributes 13 | self.model = model 14 | self.train_loader = train_loader 15 | self.val_loader = val_loader 16 | self.test_loader = test_loader 17 | 18 | # Create loss function and account for imbalance of classes 19 | self.criterion = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight) 20 | self.val_loss = torchmetrics.MeanMetric() 21 | 22 | # Initialize some metrics to monitor the performance 23 | self.metrics = torchmetrics.MetricCollection( 24 | [ 25 | torchmetrics.MatthewsCorrCoef(task="binary"), 26 | torchmetrics.F1Score(task="binary"), # same as Dice 27 | torchmetrics.JaccardIndex(task="binary"), # IoU 28 | ] 29 | ) 30 | 31 | @property 32 | def pos_weight(self): 33 | if self.train_loader is None: 34 | # Not known 35 | return None 36 | 37 | # Init counts 38 | pos, neg = 0, 0 39 | 40 | for _, mask in self.train_loader: 41 | # Update pos and neg sums 42 | pos += mask.sum() 43 | neg += (1 - mask).sum() 44 | 45 | return neg / pos 46 | 47 | def forward(self, x): 48 | return self.model(x)["out"] 49 | 50 | def training_step(self, batch, batch_idx): 51 | # Forward propagate and compute loss 52 | loss = self.criterion(self(batch[0]), batch[1]) 53 | self.log("train_loss", loss, prog_bar=True) 54 | return loss 55 | 56 | def eval_step(self, batch): 57 | # Forward pass 58 | x, y = batch 59 | y_hat = self(x) 60 | 61 | # Compute the loss and the metrics 62 | self.val_loss.update(self.criterion(y_hat, y)) 63 | self.metrics.update(y_hat.sigmoid(), y.long()) 64 | 65 | def on_eval_epoch_end(self, prefix=""): 66 | # Compute total loss and metrics 67 | loss = self.val_loss.compute() 68 | metrics = self.metrics.compute() 69 | 70 | # Reset the loss and the metrics 71 | self.val_loss.reset() 72 | self.metrics.reset() 73 | 74 | # Log the loss and the metrics 75 | self.log(f"{prefix}_loss", loss, prog_bar=True) 76 | self.log(f"{prefix}_mcc", metrics["BinaryMatthewsCorrCoef"], prog_bar=True) 77 | self.log(f"{prefix}_f1", metrics["BinaryF1Score"], prog_bar=True) 78 | self.log(f"{prefix}_iou", metrics["BinaryJaccardIndex"], prog_bar=True) 79 | 80 | if not isinstance(opt := self.optimizers(), list): 81 | # Log the learning rate of a single optimizer 82 | self.log("lr", opt.param_groups[0]["lr"], prog_bar=True) 83 | 84 | def validation_step(self, batch, batch_idx): 85 | self.eval_step(batch) 86 | 87 | def on_validation_epoch_end(self): 88 | self.on_eval_epoch_end(prefix="val") 89 | 90 | def test_step(self, batch, batch_idx): 91 | self.eval_step(batch) 92 | 93 | def on_test_epoch_end(self): 94 | self.on_eval_epoch_end(prefix="test") 95 | 96 | def train_dataloader(self): 97 | 
return self.train_loader 98 | 99 | def val_dataloader(self): 100 | return self.val_loader 101 | 102 | def test_dataloader(self): 103 | return self.test_loader 104 | 105 | def configure_optimizers(self): 106 | # Initialize AdamW optimizer and Reduce On Plateau scheduler 107 | optimizer = AdamW(self.parameters(), lr=1e-3, weight_decay=1e-4) 108 | scheduler = ReduceLROnPlateau( 109 | optimizer=optimizer, 110 | factor=0.3, 111 | patience=15, 112 | threshold=0.01, 113 | min_lr=1e-6, 114 | ) 115 | 116 | return { 117 | "optimizer": optimizer, 118 | "lr_scheduler": scheduler, 119 | "monitor": "val_loss", 120 | } 121 | -------------------------------------------------------------------------------- /docs/conf.yaml: -------------------------------------------------------------------------------- 1 | build-finished: 2 | TYPE_ALIASES: 3 | FilePath: "glasses_detector.utils." 4 | Scalar: "glasses_detector.components.pred_type." 5 | Tensor: "glasses_detector.components.pred_type." 6 | Default: "glasses_detector.components.pred_type." 7 | StandardScalar: "glasses_detector.components.pred_type." 8 | StandardTensor: "glasses_detector.components.pred_type." 9 | StandardDefault: "glasses_detector.components.pred_type." 10 | NonDefault: "glasses_detector.components.pred_type." 11 | Either: "glasses_detector.components.pred_type." 12 | 13 | LONG_SIGNATURE_IDS: 14 | - "glasses_detector.components.pred_type.PredType" 15 | - "glasses_detector.components.pred_interface.PredInterface.process_file" 16 | - "glasses_detector.components.pred_interface.PredInterface.process_dir" 17 | - "glasses_detector.components.base_model.BaseGlassesModel" 18 | - "glasses_detector.components.base_model.BaseGlassesModel.predict" 19 | - "glasses_detector.components.base_model.BaseGlassesModel.process_dir" 20 | - "glasses_detector.components.base_model.BaseGlassesModel.process_file" 21 | - "glasses_detector.classifier.GlassesClassifier" 22 | - "glasses_detector.classifier.GlassesClassifier.draw_label" 23 | - "glasses_detector.classifier.GlassesClassifier.predict" 24 | - "glasses_detector.classifier.GlassesClassifier.process_dir" 25 | - "glasses_detector.classifier.GlassesClassifier.process_file" 26 | - "glasses_detector.detector.GlassesDetector" 27 | - "glasses_detector.detector.GlassesDetector.draw_boxes" 28 | - "glasses_detector.detector.GlassesDetector.predict" 29 | - "glasses_detector.detector.GlassesDetector.process_dir" 30 | - "glasses_detector.detector.GlassesDetector.process_file" 31 | - "glasses_detector.segmenter.GlassesSegmenter" 32 | - "glasses_detector.segmenter.GlassesSegmenter.draw_masks" 33 | - "glasses_detector.segmenter.GlassesSegmenter.predict" 34 | - "glasses_detector.segmenter.GlassesSegmenter.process_dir" 35 | - "glasses_detector.segmenter.GlassesSegmenter.process_file" 36 | - "glasses_detector.architectures.tiny_binary_detector.TinyBinaryDetector.forward" 37 | - "glasses_detector.architectures.tiny_binary_detector.TinyBinaryDetector.compute_loss" 38 | 39 | LONG_PARAMETER_IDS: 40 | glasses_detector.components.base_model.BaseGlassesModel.predict: 41 | - "format" 42 | glasses_detector.classifier.GlassesClassifier.predict: 43 | - "format" 44 | glasses_detector.detector.GlassesDetector.predict: 45 | - "format" 46 | glasses_detector.segmenter.GlassesSegmenter.predict: 47 | - "format" 48 | glasses_detector.detector.GlassesDetector.draw_boxes: 49 | - "colors" 50 | glasses_detector.segmenter.GlassesSegmenter.draw_masks: 51 | - "colors" 52 | 53 | CUSTOM_SYNTAX_COLORS_PYTHON: 54 | custom-highlight-class: 55 | - 
"GlassesClassifier" 56 | - "GlassesDetector" 57 | - "GlassesSegmenter" 58 | - "PredType" 59 | - "Image" 60 | - "type" 61 | - "np" 62 | - "str" 63 | - "int" 64 | - "bool" 65 | - "subprocess" 66 | - "random" 67 | - "@copy_signature" 68 | - "@eval_infer_mode" 69 | - "eval_infer_mode" 70 | custom-highlight-function: 71 | - "process_file" 72 | - "process_dir" 73 | - "run" 74 | - "load" 75 | - "array" 76 | - "zeros" 77 | - "fromarray" 78 | - "values" 79 | - "standardize" 80 | - "is_standard_scalar" 81 | - "is_default" 82 | - "reveal_type" 83 | - "full_signature" 84 | - "test_signature" 85 | custom-highlight-constant: 86 | - "StandardScalar" 87 | - "DEFAULT_SIZE_MAP" 88 | - "DEFAULT_KIND_MAP" 89 | custom-highlight-variable: 90 | - "format" 91 | - "self" 92 | 93 | CUSTOM_SYNTAX_COLORS_BASH: 94 | custom-highlight-start: 95 | - "cd" 96 | - "pip install" 97 | - "git clone" 98 | - "glasses-detector" 99 | custom-highlight-flag: 100 | - "-" 101 | custom-highlight-op: 102 | - "&&" 103 | 104 | SECTION_ICONS: 105 | Installation: "fas fa-download" 106 | Features: "fas fa-cogs" 107 | Examples: "fas fa-code" 108 | CLI: "fas fa-terminal" 109 | API: "fas fa-book" 110 | Credits: "fas fa-heart" 111 | -------------------------------------------------------------------------------- /src/glasses_detector/_wrappers/binary_classifier.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torch.nn as nn 4 | import torchmetrics 5 | import tqdm 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import ReduceLROnPlateau 8 | 9 | 10 | class BinaryClassifier(pl.LightningModule): 11 | def __init__(self, model, train_loader=None, val_loader=None, test_loader=None): 12 | super().__init__() 13 | 14 | # Assign attributes 15 | self.model = model 16 | self.train_loader = train_loader 17 | self.val_loader = val_loader 18 | self.test_loader = test_loader 19 | 20 | # Create loss function and account for imbalance of classes 21 | self.criterion = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight) 22 | self.val_loss = torchmetrics.MeanMetric() 23 | 24 | # Create F1 score and ROC-AUC metrics to monitor 25 | self.metrics = torchmetrics.MetricCollection( 26 | [ 27 | torchmetrics.F1Score(task="binary"), 28 | torchmetrics.AUROC(task="binary"), # ROC-AUC 29 | torchmetrics.AveragePrecision(task="binary"), # PR-AUC 30 | ] 31 | ) 32 | 33 | @property 34 | def pos_weight(self): 35 | if self.train_loader is None: 36 | # Not known 37 | return None 38 | 39 | # Calculate the positive weight to account for class imbalance 40 | iterator = tqdm.tqdm(self.train_loader, desc="Computing pos_weight") 41 | pos_count = sum(y.sum().item() for _, y in iterator) 42 | neg_count = len(self.train_loader.dataset) - pos_count 43 | 44 | return torch.tensor(neg_count / pos_count) 45 | 46 | def forward(self, x): 47 | return self.model(x) 48 | 49 | def training_step(self, batch, batch_idx): 50 | # Forward propagate and compute loss 51 | loss = self.criterion(self(batch[0]), batch[1].to(torch.float32)) 52 | self.log("train_loss", loss, prog_bar=True) 53 | return loss 54 | 55 | def eval_step(self, batch): 56 | # Forward pass 57 | x, y = batch 58 | y_hat = self(x) 59 | 60 | # Compute the loss and the metrics 61 | self.val_loss.update(self.criterion(y_hat, y.to(torch.float32))) 62 | self.metrics.update(y_hat.sigmoid(), y) 63 | 64 | def on_eval_epoch_end(self, prefix=""): 65 | # Compute total loss and metrics 66 | loss = self.val_loss.compute() 67 | metrics = 
self.metrics.compute() 68 | 69 | # Reset the metrics 70 | self.val_loss.reset() 71 | self.metrics.reset() 72 | 73 | # Log the loss and the metrics 74 | self.log(f"{prefix}_loss", loss, prog_bar=True) 75 | self.log(f"{prefix}_f1", metrics["BinaryF1Score"], prog_bar=True) 76 | self.log(f"{prefix}_roc_auc", metrics["BinaryAUROC"], prog_bar=True) 77 | self.log(f"{prefix}_pr_auc", metrics["BinaryAveragePrecision"], prog_bar=True) 78 | 79 | if not isinstance(opt := self.optimizers(), list): 80 | # Log the learning rate of a single optimizer 81 | self.log("lr", opt.param_groups[0]["lr"], prog_bar=True) 82 | 83 | def validation_step(self, batch, batch_idx): 84 | self.eval_step(batch) 85 | 86 | def on_validation_epoch_end(self): 87 | self.on_eval_epoch_end(prefix="val") 88 | 89 | def test_step(self, batch, batch_idx): 90 | self.eval_step(batch) 91 | 92 | def on_test_epoch_end(self): 93 | self.on_eval_epoch_end(prefix="test") 94 | 95 | def train_dataloader(self): 96 | return self.train_loader 97 | 98 | def val_dataloader(self): 99 | return self.val_loader 100 | 101 | def test_dataloader(self): 102 | return self.test_loader 103 | 104 | def configure_optimizers(self): 105 | # Initialize AdamW optimizer and Reduce On Plateau scheduler 106 | optimizer = AdamW(self.parameters(), lr=1e-3, weight_decay=1e-3) 107 | scheduler = ReduceLROnPlateau( 108 | optimizer=optimizer, 109 | factor=0.3, 110 | patience=15, 111 | threshold=0.01, 112 | min_lr=1e-6, 113 | ) 114 | 115 | return { 116 | "optimizer": optimizer, 117 | "lr_scheduler": scheduler, 118 | "monitor": "val_loss", 119 | } 120 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
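# In addition to the standard Sphinx options below, this configuration
# registers project-local intersphinx inventories (CustomInvs) and a
# build-finished post-processing hook (BuildFinished) in setup() at the
# bottom of this file.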
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | import os 7 | import sys 8 | from pathlib import Path 9 | 10 | from sphinx.application import Sphinx 11 | 12 | sys.path += [ 13 | str(Path(__file__).parent.parent / "src"), 14 | str(Path(__file__).parent), 15 | ] 16 | 17 | from helpers import BuildFinished, CustomInvs 18 | 19 | DOCS_DIR = Path(os.path.dirname(os.path.abspath(__file__))) 20 | 21 | # -- Project information ----------------------------------------------------- 22 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 23 | 24 | project = "Glasses Detector" 25 | copyright = "2024, Mantas Birškus" 26 | author = "Mantas Birškus" 27 | release = "v1.0.4" 28 | 29 | # -- General configuration --------------------------------------------------- 30 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 31 | 32 | extensions = [ 33 | "sphinx.ext.intersphinx", 34 | "sphinx.ext.todo", 35 | "sphinx.ext.napoleon", 36 | "sphinx.ext.viewcode", 37 | "sphinx.ext.autodoc", 38 | "sphinx.ext.autosectionlabel", 39 | "sphinx_copybutton", 40 | "sphinxcontrib.bibtex", 41 | "sphinx_design", 42 | ] 43 | 44 | intersphinx_mapping = { 45 | "python": ("https://docs.python.org/3/", None), 46 | "PIL": ("https://pillow.readthedocs.io/en/stable/", None), 47 | "numpy": ("https://numpy.org/doc/stable/", None), 48 | "torch": ("https://pytorch.org/docs/stable/", None), 49 | "torchvision": ("https://pytorch.org/vision/stable/", None), 50 | "matplotlib": ("https://matplotlib.org/stable/", None), 51 | "tqdm": ("https://tqdm.github.io/docs/", "_static/inv/tqdm.inv"), 52 | } 53 | 54 | # -- Options for napaleon/autosummary/autodoc output ------------------------- 55 | napoleon_use_param = True 56 | autosummary_generate = True 57 | autodoc_typehints = "both" 58 | autodoc_member_order = "bysource" 59 | 60 | templates_path = ["_templates"] 61 | bibtex_bibfiles = ["_static/bib/references.bib"] 62 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 63 | 64 | # -- Options for HTML output ------------------------------------------------- 65 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 66 | 67 | html_theme = "pydata_sphinx_theme" 68 | html_theme_options = { 69 | "logo": { 70 | "alt_text": "Glasses Detector - Home", 71 | "text": f"Glasses Detector {release}", 72 | "image_light": "_static/img/logo-light.png", 73 | "image_dark": "_static/img/logo-dark.png", 74 | }, 75 | "icon_links": [ 76 | { 77 | "name": "GitHub", 78 | "url": "https://github.com/mantasu/glasses-detector", 79 | "icon": "fa-brands fa-github", 80 | }, 81 | { 82 | "name": "Colab", 83 | "url": ( 84 | "https://colab.research.google.com/github/mantasu/glasses-detector/blob/main/notebooks/demo.ipynb" 85 | ), 86 | "icon": "fa-custom fa-colab", 87 | }, 88 | { 89 | "name": "PyPI", 90 | "url": "https://pypi.org/project/glasses-detector", 91 | "icon": "fa-custom fa-pypi", 92 | }, 93 | { 94 | "name": "Zenodo", 95 | "url": "https://zenodo.org/doi/10.5281/zenodo.8126101", 96 | "icon": "fa-custom fa-zenodo", 97 | }, 98 | ], 99 | "show_toc_level": 2, 100 | "navigation_with_keys": False, 101 | "header_links_before_dropdown": 7, 102 | } 103 | html_context = { 104 | "github_user": "mantasu", 105 | "github_repo": "glasses-detector", 106 | "github_version": "main", 107 | "doc_path": "docs", 108 | } 109 | html_static_path = ["_static"] 110 | html_js_files = 
["js/colab-icon.js", "js/pypi-icon.js", "js/zenodo-icon.js"] 111 | html_css_files = ["css/highlights.css", "css/signatures.css", "css/custom.css"] 112 | html_title = f"Glasses Detector {release}" 113 | html_favicon = "_static/img/logo-light.png" 114 | 115 | # -- Custom Template Functions ----------------------------------------------- 116 | # https://www.sphinx-doc.org/en/master/development/theming.html#defining-custom-template-functions 117 | 118 | 119 | def setup(app: Sphinx): 120 | # Add local inventories to intersphinx_mapping 121 | custom_invs = CustomInvs(static_path=DOCS_DIR / "_static") 122 | app.config.intersphinx_mapping.update(custom_invs()) 123 | 124 | # Add custom build-finished event 125 | build_finished = BuildFinished(DOCS_DIR / "_static", DOCS_DIR / "conf.yaml") 126 | app.connect("build-finished", build_finished) 127 | -------------------------------------------------------------------------------- /docs/docs/cli.rst: -------------------------------------------------------------------------------- 1 | :fas:`terminal` CLI 2 | =================== 3 | 4 | .. role:: bash(code) 5 | :language: bash 6 | :class: highlight 7 | 8 | These flags allow you to define the kind of task and the model to process your image or a directory with images. Check out how to use them in :ref:`command-line`. 9 | 10 | .. option:: -i path/to/dir/or/file, --input path/to/dir/or/file 11 | 12 | Path to the input image or the directory with images. 13 | 14 | .. option:: -o path/to/dir/or/file, --output path/to/dir/or/file 15 | 16 | Path to the output file or the directory. If not provided, then, if input is a file, the prediction will be printed (or shown if it is an image), otherwise, if input is a directory, the predictions will be written to a directory with the same name with an added suffix ``_preds``. If provided as a file, then the prediction(-s) will be saved to this file (supported extensions include: ``.txt``, ``.csv``, ``.json``, ``.npy``, ``.pkl``, ``.jpg``, ``.png``). If provided as a directory, then the predictions will be saved to this directory use :bash:`--extension` flag to specify the file extensions in that directory. 17 | 18 | **Default:** :py:data:`None` 19 | 20 | .. option:: -e , --extension 21 | 22 | Only used if :bash:`--output` is a directory. The extension to use to save the predictions as files. Common extensions include: ``.txt``, ``.csv``, ``.json``, ``.npy``, ``.pkl``, ``.jpg``, ``.png``. If not specified, it will be set automatically to ``.jpg`` for image predictions and to ``.txt`` for all other formats. 23 | 24 | **Default:** :py:data:`None` 25 | 26 | .. option:: -f , --format 27 | 28 | The format to use to map the raw prediction to. 29 | 30 | * For *classification*, common formats are ``bool``, ``proba``, ``str`` - check :meth:`GlassesClassifier.predict` for more details 31 | * For *detection*, common formats are ``bool``, ``int``, ``img`` - check :meth:`GlassesDetector.predict` for more details 32 | * For *segmentation*, common formats are ``proba``, ``img``, ``mask`` - check :meth:`GlassesSegmenter.predict` for more details 33 | 34 | If not specified, it will be set automatically to ``str``, ``img``, ``mask`` for *classification*, *detection*, *segmentation* respectively. 35 | 36 | **Default:** :py:data:`None` 37 | 38 | .. option:: -t , --task 39 | 40 | The kind of task the model should perform. 
One of 41 | 42 | * ``classification`` 43 | * ``classification:anyglasses`` 44 | * ``classification:sunglasses`` 45 | * ``classification:eyeglasses`` 46 | * ``classification:shadows`` 47 | * ``detection`` 48 | * ``detection:eyes`` 49 | * ``detection:solo`` 50 | * ``detection:worn`` 51 | * ``segmentation`` 52 | * ``segmentation:frames`` 53 | * ``segmentation:full`` 54 | * ``segmentation:legs`` 55 | * ``segmentation:lenses`` 56 | * ``segmentation:shadows`` 57 | * ``segmentation:smart`` 58 | 59 | If specified only as ``classification``, ``detection``, or ``segmentation``, the subcategories ``anyglasses``, ``worn``, and ``smart`` will be chosen, respectively. 60 | 61 | **Default:** ``classification:anyglasses`` 62 | 63 | .. option:: -s , --size 64 | 65 | The model size which determines architecture type. One of ``small``, ``medium``, ``large`` (or ``s``, ``m``, ``l``). 66 | 67 | **Default:** ``medium`` 68 | 69 | .. option:: -b , --batch-size 70 | 71 | Only used if :bash:`--input` is a directory. The batch size to use when processing the images. This groups the files in the input directory to batches of size ``batch_size`` before processing them. In some cases, larger batch sizes can speed up the processing at the cost of more memory usage. 72 | 73 | **Default:** ``1`` 74 | 75 | .. option:: -p , --pbar 76 | 77 | Only used if :bash:`--input` is a directory. It is the description that is used for the progress bar. If specified as ``""`` (empty string), no progress bar is shown. 78 | 79 | **Default:** ``"Processing"`` 80 | 81 | .. option:: -w path/to/weights.pth, --weights path/to/weights.pth 82 | 83 | Path to custom weights to load into the model. If not specified, weights will be loaded from the default location (and automatically downloaded there if needed). 84 | 85 | **Default:** :py:data:`None` 86 | 87 | .. option:: -d , --device 88 | 89 | The device on which to perform inference. If not specified, it will be automatically checked if `CUDA `_ or `MPS `_ is supported. 
90 | 91 | **Default:** :py:data:`None` 92 | -------------------------------------------------------------------------------- /scripts/analyse.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import sys 4 | import tempfile 5 | from typing import Any 6 | 7 | import torch 8 | from fvcore.nn import FlopCountAnalysis 9 | from prettytable import PrettyTable 10 | from torch.profiler import ProfilerActivity, profile 11 | 12 | PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 13 | sys.path.append(os.path.join(PROJECT_DIR, "src")) 14 | torch.set_float32_matmul_precision("medium") 15 | 16 | from glasses_detector import GlassesClassifier, GlassesDetector, GlassesSegmenter 17 | 18 | 19 | def check_filesize(model: torch.nn.Module): 20 | with tempfile.NamedTemporaryFile() as temp: 21 | torch.save(model.state_dict(), temp.name) 22 | return os.path.getsize(temp.name) / (1024**2) 23 | 24 | 25 | def check_num_params(model: torch.nn.Module): 26 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 27 | 28 | 29 | def check_flops(model: torch.nn.Module, input: Any): 30 | flops = FlopCountAnalysis(model, (input,)) 31 | return flops.total() 32 | 33 | 34 | def check_ram(model: torch.nn.Module, input: Any): 35 | # Clean up 36 | gc.collect() 37 | 38 | with profile( 39 | activities=[ProfilerActivity.CPU], 40 | profile_memory=True, 41 | record_shapes=True, 42 | ) as prof: 43 | # Run model 44 | model(input) 45 | 46 | # Sum the memory usage of all the operations on the CPU 47 | memory_usage = sum(item.cpu_memory_usage for item in prof.key_averages()) 48 | 49 | return memory_usage / (1024**2) 50 | 51 | 52 | def check_vram(model: torch.nn.Module, input: Any): 53 | # Clean up 54 | torch.cuda.empty_cache() 55 | torch.cuda.synchronize() 56 | 57 | # Check the allocated memory before 58 | mem_before = torch.cuda.max_memory_allocated() 59 | 60 | if isinstance(input, list): 61 | # Detection models require a list 62 | input = [torch.tensor(i, device="cuda") for i in input] 63 | else: 64 | # Otherwise a single input is enough 65 | input = torch.tensor(input, device="cuda") 66 | 67 | # Run the model on CUDA 68 | model = model.to("cuda") 69 | model(input) 70 | torch.cuda.synchronize() 71 | 72 | # Check the allocated memory after 73 | mem_after = torch.cuda.max_memory_allocated() 74 | 75 | # Clean up 76 | model.to("cpu") 77 | torch.cuda.empty_cache() 78 | 79 | return (mem_after - mem_before) / (1024**2) 80 | 81 | 82 | def analyse(model_cls, task): 83 | # Create a table 84 | table = PrettyTable() 85 | field_names = [ 86 | "Model size", 87 | "Filesize (MB)", 88 | "Num params", 89 | "FLOPS", 90 | "RAM (MB)", 91 | ] 92 | 93 | if torch.cuda.is_available(): 94 | # If CUDA is available, add VRAM 95 | field_names.append("VRAM (MB)") 96 | 97 | # Set the field names 98 | table.field_names = field_names 99 | 100 | if task == "detection": 101 | # Detection models require a list 102 | input = [*torch.randn(1, 3, 256, 256)] 103 | else: 104 | # Otherwise a single input is enough 105 | input = torch.randn(1, 3, 256, 256) 106 | 107 | for size in ["small", "medium", "large"]: 108 | # Clean up 109 | gc.collect() 110 | 111 | # Load the model without pre-trained weights on the CPU 112 | model = model_cls(size=size, weights=False, device="cpu").model 113 | model.eval() 114 | 115 | # Check the basic stats 116 | filesize = check_filesize(model) 117 | num_params = check_num_params(model) 118 | flops = check_flops(model, input) 119 | ram = 
check_ram(model, input) 120 | 121 | # Stats 122 | row = [ 123 | size, 124 | f"{filesize:.2f}", 125 | f"{num_params:,}", 126 | f"{flops:,}", 127 | f"{ram:.2f}", 128 | ] 129 | 130 | if torch.cuda.is_available(): 131 | # If CUDA is available, check VRAM 132 | vram = check_vram(model, input) 133 | row.append(f"{vram:.2f}") 134 | 135 | # Add the stats row 136 | table.add_row(row) 137 | 138 | # Print table 139 | print(table) 140 | 141 | 142 | def parse_args(): 143 | import argparse 144 | 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument( 147 | "--task", 148 | "-t", 149 | type=str, 150 | required=True, 151 | choices=["classification", "detection", "segmentation"], 152 | ) 153 | return parser.parse_args() 154 | 155 | 156 | def main(): 157 | args = parse_args() 158 | 159 | match args.task: 160 | case "classification": 161 | model_cls = GlassesClassifier 162 | case "detection": 163 | model_cls = GlassesDetector 164 | case "segmentation": 165 | model_cls = GlassesSegmenter 166 | 167 | analyse(model_cls, args.task) 168 | 169 | 170 | if __name__ == "__main__": 171 | main() 172 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 
45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | mantix7@gmail.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 
123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /src/glasses_detector/_wrappers/metrics.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import override 3 | 4 | import torch 5 | import torchmetrics 6 | from scipy.optimize import linear_sum_assignment 7 | from torchmetrics.functional import mean_squared_log_error, r2_score 8 | from torchvision.ops import box_iou 9 | 10 | 11 | class BoxMetric(torchmetrics.Metric): 12 | def __init__(self, name="sum", is_min=False, dist_sync_on_step=False): 13 | super().__init__(dist_sync_on_step=dist_sync_on_step) 14 | self.add_state(name, default=torch.tensor(0.0), dist_reduce_fx="sum") 15 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 16 | 17 | # Select min/max mode 18 | self.is_min = is_min 19 | self.name = name 20 | 21 | @abstractmethod 22 | def compute_matrix( 23 | self, 24 | preds: torch.Tensor, 25 | targets: torch.Tensor, 26 | ) -> torch.Tensor: ... 27 | 28 | def update( 29 | self, 30 | preds: list[dict[str, torch.Tensor]], 31 | targets: list[dict[str, torch.Tensor]], 32 | *classification_metrics, 33 | ): 34 | # Initialize flat target labels and predicted labels 35 | pred_l = torch.empty(0, dtype=torch.long, device=preds[0]["labels"].device) 36 | target_l = torch.empty(0, dtype=torch.long, device=preds[0]["labels"].device) 37 | 38 | for pred, target in zip(preds, targets): 39 | if len(pred["boxes"]) == 0: 40 | pred["boxes"] = torch.tensor( 41 | [[0, 0, 0, 0]], dtype=torch.float32, device=pred["boxes"].device 42 | ) 43 | pred["labels"] = torch.tensor( 44 | [0], dtype=torch.long, device=pred["labels"].device 45 | ) 46 | 47 | if len(target["boxes"]) == 0: 48 | target["boxes"] = torch.tensor( 49 | [[0, 0, 0, 0]], dtype=torch.float32, device=target["boxes"].device 50 | ) 51 | target["labels"] = torch.tensor( 52 | [0], dtype=torch.long, device=target["labels"].device 53 | ) 54 | 55 | # Compute the matrix of similarities and select best 56 | similarities = self.compute_matrix(pred["boxes"], target["boxes"]) 57 | cost_matrix = similarities if self.is_min else 1 - similarities 58 | pred_idx, target_idx = linear_sum_assignment(cost_matrix.cpu()) 59 | best_sims = similarities[pred_idx, target_idx] 60 | 61 | # Add the labels of matched predictions 62 | target_l = torch.cat([target_l, target["labels"][target_idx]]) 63 | pred_l = torch.cat([pred_l, pred["labels"][pred_idx]]) 64 | 65 | if (remain := list(set(range(len(pred["boxes"]))) - set(pred_idx))) != []: 66 | # Add the labels of unmatched predictions, set as bg 67 | pred_l = torch.cat([pred_l, pred["labels"][remain]]) 68 | padded = torch.tensor(len(remain) * [0], device=target_l.device) 69 | target_l = torch.cat([target_l, padded]) 70 | 71 | if ( 72 | remain := list(set(range(len(target["boxes"]))) - set(target_idx)) 73 | ) != []: 74 | # Add the labels of unmatched predictions, set as bg 75 | target_l = torch.cat([target_l, target["labels"][remain]]) 76 | padded = torch.tensor(len(remain) * [0], device=pred_l.device) 77 | pred_l = torch.cat([pred_l, torch.tensor(len(remain) * [0])]) 78 | 79 | # Update the bbox metric 80 | setattr(self, self.name, getattr(self, self.name) + 
best_sims.sum()) 81 | self.total += max(len(pred["boxes"]), len(target["boxes"])) 82 | 83 | for metric in classification_metrics: 84 | # Update classification metrics 85 | metric.update(pred_l, target_l) 86 | 87 | def compute(self): 88 | return getattr(self, self.name) / self.total 89 | 90 | 91 | class BoxIoU(BoxMetric): 92 | def __init__(self, **kwargs): 93 | kwargs["name"] = "iou_sum" 94 | super().__init__(**kwargs) 95 | 96 | @override 97 | def compute_matrix( 98 | self, 99 | preds: torch.Tensor, 100 | targets: torch.Tensor, 101 | ) -> torch.Tensor: 102 | return box_iou(preds, targets) 103 | 104 | 105 | class BoxClippedR2(BoxMetric): 106 | def __init__(self, **kwargs): 107 | # Get r2 kwargs, remove them from kwargs 108 | r2_args = r2_score.__code__.co_varnames 109 | self.r2_kwargs = {k: v for k, v in kwargs.items() if k in r2_args} 110 | kwargs = {k: v for k, v in kwargs.items() if k not in r2_args} 111 | kwargs["name"] = "r2_sum" 112 | super().__init__(**kwargs) 113 | 114 | @override 115 | def compute_matrix( 116 | self, 117 | preds: torch.Tensor, 118 | targets: torch.Tensor, 119 | ) -> torch.Tensor: 120 | return torch.tensor( 121 | [ 122 | [ 123 | max(0, r2_score(p.view(-1), t.view(-1), **self.r2_kwargs)) 124 | for t in targets 125 | ] 126 | for p in preds 127 | ] 128 | ) 129 | 130 | 131 | class BoxMSLE(BoxMetric): 132 | def __init__(self, **kwargs): 133 | kwargs.setdefault("is_min", True) 134 | kwargs["name"] = "msle_sum" 135 | super().__init__(**kwargs) 136 | 137 | @override 138 | def compute_matrix( 139 | self, 140 | preds: torch.Tensor, 141 | targets: torch.Tensor, 142 | ) -> torch.Tensor: 143 | return torch.tensor( 144 | [ 145 | [mean_squared_log_error(p.view(-1), t.view(-1)) for t in targets] 146 | for p in preds 147 | ] 148 | ) 149 | -------------------------------------------------------------------------------- /src/glasses_detector/_data/base_categorized_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from abc import ABC, abstractmethod 4 | from collections import defaultdict 5 | from functools import cached_property 6 | from os.path import basename, splitext 7 | from typing import Any, Callable 8 | 9 | import torch 10 | from torch.utils.data import DataLoader, Dataset 11 | 12 | from .augmenter_mixin import AugmenterMixin 13 | 14 | 15 | class BaseCategorizedDataset(ABC, Dataset, AugmenterMixin): 16 | def __init__( 17 | self, 18 | root: str = ".", 19 | split_type: str = "train", 20 | img_folder: str = "images", 21 | label_type: str = "enum", 22 | cat2idx_fn: Callable[[str], int] | dict[str | int] | None = None, 23 | pth2idx_fn: Callable[[str], str] = lambda x: splitext(basename(x))[0], 24 | seed: int = 0, 25 | ): 26 | super().__init__() 27 | 28 | self.label_type = label_type 29 | self.img_folder = img_folder 30 | self.data = defaultdict(lambda: {}) 31 | self.cats = [] 32 | 33 | for dataset in os.listdir(root): 34 | if not os.path.isdir(p := os.path.join(root, dataset, split_type)): 35 | # No split 36 | continue 37 | 38 | for cat in os.scandir(p): 39 | if cat.name != img_folder and cat.name not in self.cats: 40 | # Expand category list 41 | self.cats.append(cat.name) 42 | 43 | for file in os.scandir(cat.path): 44 | # Add image/annotation path under file name as key 45 | self.data[pth2idx_fn(file.path)][cat.name] = file.path 46 | 47 | # Shuffle only values (sort first for reproducibility) 48 | self.data = [v for _, v in sorted(self.data.items())] 49 | random.seed(seed) 50 | random.shuffle(self.data) 
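        # NOTE: at this point ``self.data`` is a shuffled list of per-sample
        # dicts that map each category folder name (including the image folder)
        # to a file path, e.g. {"images": ".../0.jpg", "masks": ".../0.png"}
        # (the key names in this example are illustrative; the actual keys
        # depend on the dataset's folder layout). Subclasses read these
        # entries in ``__getitem__``.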
51 | 
52 |         # Sort cats as well
53 |         self.cats = sorted(
54 |             self.cats,
55 |             key=(
56 |                 None
57 |                 if cat2idx_fn is None
58 |                 else cat2idx_fn.get if isinstance(cat2idx_fn, dict) else cat2idx_fn
59 |             ),
60 |         )
61 | 
62 |         # Create a default transformation
63 |         self.transform = self.create_transform(split_type == "train")
64 |         self.__post_init__()
65 | 
66 |     @cached_property
67 |     def cat2idx(self) -> dict[str, int]:
68 |         return dict(zip(self.cats, range(len(self.cats))))
69 | 
70 |     @cached_property
71 |     def idx2cat(self) -> dict[int, str]:
72 |         return dict(zip(range(len(self.cats)), self.cats))
73 | 
74 |     @classmethod
75 |     def create_loader(cls, **kwargs) -> DataLoader:
76 |         # Get argument names from DataLoader
77 |         fn_code = DataLoader.__init__.__code__
78 |         init_arg_names = fn_code.co_varnames[: fn_code.co_argcount]
79 | 
80 |         # Split all the given kwargs to dataset (cls) and loader kwargs
81 |         set_kwargs = {k: v for k, v in kwargs.items() if k not in init_arg_names}
82 |         ldr_kwargs = {k: v for k, v in kwargs.items() if k in init_arg_names}
83 | 
84 |         # Define default loader kwargs
85 |         default_loader_kwargs = {
86 |             "dataset": cls(**set_kwargs),
87 |             "batch_size": 64,
88 |             "num_workers": 12,
89 |             "pin_memory": True,
90 |             "drop_last": True,
91 |             "shuffle": set_kwargs.get("split_type", "train") == "train",
92 |         }
93 | 
94 |         # Update default loader kwargs with custom
95 |         default_loader_kwargs.update(ldr_kwargs)
96 | 
97 |         return DataLoader(**default_loader_kwargs)
98 | 
99 |     @classmethod
100 |     def create_loaders(cls, **kwargs) -> tuple[DataLoader, DataLoader, DataLoader]:
101 |         # Create train, validation, and test loaders
102 |         train_loader = cls.create_loader(split_type="train", **kwargs)
103 |         val_loader = cls.create_loader(split_type="val", **kwargs)
104 |         test_loader = cls.create_loader(split_type="test", **kwargs)
105 | 
106 |         return train_loader, val_loader, test_loader
107 | 
108 |     def cat2tensor(self, cat: str | list[str]) -> torch.Tensor:
109 |         # Convert category name(-s) to the index list
110 |         cat = cat if isinstance(cat, list) else [cat]
111 |         indices = list(map(self.cat2idx.get, cat))
112 | 
113 |         match self.label_type:
114 |             case "enum":
115 |                 label = torch.tensor(indices)
116 |             case "onehot":
117 |                 label = torch.eye(len(self.cats))[indices]
118 |             case "multihot":
119 |                 label = torch.any(torch.eye(len(self.cats))[indices], 0, True)
120 |             case _:
121 |                 raise ValueError(f"Unknown label type: {self.label_type}")
122 | 
123 |         return label.to(torch.long)
124 | 
125 |     def tensor2cat(self, tensor: torch.Tensor) -> str | list[str] | list[list[str]]:
126 |         match self.label_type:
127 |             case "enum":
128 |                 # Add a batch dimension if tensor is a scalar, get cats
129 |                 ts = tensor.unsqueeze(0) if tensor.ndim == 0 else tensor
130 |                 cat = [self.idx2cat[i.item()] for i in ts]
131 |                 return cat[0] if tensor.ndim == 0 or len(tensor) == 0 else cat
132 |             case "onehot":
133 |                 # Get cats directly (works for both 1D and 2D tensors)
134 |                 cat = [self.idx2cat[i.item()] for i in torch.where(tensor)[0]]
135 |                 return cat[0] if tensor.ndim == 1 else cat
136 |             case "multihot":
137 |                 # Add a batch dimension if tensor is a 1D list, get cats
138 |                 ts = tensor if tensor.ndim > 1 else tensor.unsqueeze(0)
139 |                 cat = [[self.idx2cat[i.item()] for i in torch.where(t)[0]] for t in ts]
140 |                 return cat[0] if tensor.ndim == 1 else cat
141 |             case _:
142 |                 raise ValueError(f"Unknown label type: {self.label_type}")
143 | 
144 |     def __len__(self) -> int:
145 |         return len(self.data)
146 | 
147 |     @abstractmethod
148 |     def __getitem__(self,
index: int) -> Any: ... 149 | 150 | def __post_init__(self): 151 | pass 152 | -------------------------------------------------------------------------------- /docs/_static/js/zenodo-icon.js: -------------------------------------------------------------------------------- 1 | FontAwesome.library.add( 2 | (faListOldStyle = { 3 | prefix: "fa-custom", 4 | iconName: "zenodo", 5 | icon: [ 6 | 146.355, // viewBox width 7 | 47.955, // viewBox height 8 | [], // ligature 9 | "e001", // unicode codepoint - private use area 10 | "M145.301,18.875c-0.705-1.602-1.656-2.997-2.846-4.19c-1.189-1.187-2.584-2.125-4.188-2.805 c-1.604-0.678-3.307-1.02-5.102-1.02c-1.848,0-3.564,0.342-5.139,1.02c-0.787,0.339-1.529,0.74-2.225,1.205 c-0.701,0.469-1.357,1.003-1.967,1.6c-0.377,0.37-0.727,0.761-1.051,1.17c-0.363,0.457-0.764,1.068-0.992,1.439 c-0.281,0.456-0.957,1.861-1.254,2.828c0.041-1.644,0.281-4.096,1.254-5.472V2.768c0-0.776-0.279-1.431-0.84-1.965 C120.396,0.268,119.75,0,119.021,0c-0.777,0-1.43,0.268-1.969,0.803c-0.531,0.534-0.801,1.189-0.801,1.965v10.569 c-1.117-0.778-2.322-1.386-3.605-1.824c-1.285-0.436-2.637-0.654-4.045-0.654c-1.799,0-3.496,0.342-5.1,1.02 c-1.605,0.679-3,1.618-4.195,2.805c-1.186,1.194-2.139,2.588-2.836,4.19c-0.053,0.12-0.1,0.242-0.15,0.364 c-0.047-0.122-0.094-0.244-0.146-0.364c-0.705-1.602-1.656-2.997-2.846-4.19c-1.189-1.187-2.586-2.125-4.188-2.805 c-1.604-0.678-3.307-1.02-5.102-1.02c-1.848,0-3.564,0.342-5.139,1.02c-1.584,0.679-2.979,1.618-4.191,2.805 c-1.213,1.194-2.164,2.588-2.842,4.19c-0.049,0.115-0.092,0.23-0.137,0.344c-0.047-0.114-0.092-0.229-0.141-0.344 c-0.701-1.602-1.65-2.997-2.84-4.19c-1.191-1.187-2.588-2.125-4.193-2.805c-1.604-0.678-3.301-1.02-5.104-1.02 c-1.842,0-3.557,0.342-5.137,1.02c-1.578,0.679-2.977,1.618-4.186,2.805c-1.221,1.194-2.166,2.588-2.848,4.19 c-0.043,0.106-0.082,0.214-0.125,0.32c-0.043-0.106-0.084-0.214-0.131-0.32c-0.707-1.602-1.656-2.997-2.848-4.19 c-1.188-1.187-2.582-2.125-4.184-2.805c-1.605-0.678-3.309-1.02-5.104-1.02c-1.85,0-3.564,0.342-5.137,1.02 c-1.467,0.628-2.764,1.488-3.91,2.552V13.99c0-1.557-1.262-2.822-2.82-2.822H3.246c-1.557,0-2.82,1.265-2.82,2.822 c0,1.559,1.264,2.82,2.82,2.82h15.541L0.557,41.356C0.195,41.843,0,42.433,0,43.038v1.841c0,1.558,1.264,2.822,2.822,2.822 h21.047c1.488,0,2.705-1.153,2.812-2.614c0.932,0.743,1.967,1.364,3.109,1.848c1.605,0.684,3.299,1.021,5.102,1.021 c2.723,0,5.15-0.726,7.287-2.187c1.727-1.176,3.092-2.639,4.084-4.389v3.805c0,0.778,0.264,1.436,0.805,1.968 c0.531,0.537,1.189,0.803,1.967,0.803c0.73,0,1.369-0.266,1.93-0.803c0.561-0.532,0.838-1.189,0.838-1.968v-9.879h-0.01 c0-0.002,0.01-0.013,0.01-0.013s-6.137,0-6.912,0c-0.58,0-1.109,0.154-1.566,0.472c-0.463,0.316-0.793,0.744-0.982,1.275 l-0.453,0.93c-0.631,1.365-1.566,2.443-2.809,3.244c-1.238,0.803-2.633,1.201-4.188,1.201c-1.023,0-2.004-0.191-2.955-0.579 c-0.941-0.39-1.758-0.935-2.439-1.64c-0.682-0.703-1.227-1.52-1.641-2.443c-0.41-0.924-0.617-1.893-0.617-2.916v-2.476h17.715 h1.309h5.539v-8.385c0-1.015,0.191-1.99,0.582-2.912c0.389-0.922,0.936-1.74,1.645-2.444c0.699-0.703,1.514-1.249,2.441-1.641 c0.918-0.388,1.92-0.581,2.982-0.581c1.023,0,2.01,0.193,2.955,0.581c0.945,0.393,1.762,0.938,2.439,1.641 c0.682,0.704,1.225,1.521,1.641,2.444c0.412,0.922,0.621,1.896,0.621,2.912v21.208c0,0.778,0.266,1.436,0.799,1.968 c0.535,0.537,1.191,0.803,1.971,0.803c0.729,0,1.371-0.266,1.934-0.803c0.553-0.532,0.834-1.189,0.834-1.968v-3.803 c0.588,1.01,1.283,1.932,2.1,2.749c1.189,1.189,2.586,2.124,4.191,2.804c1.602,0.684,3.303,1.021,5.102,1.021 
c1.795,0,3.498-0.337,5.102-1.021c1.602-0.68,3.01-1.614,4.227-2.804c1.211-1.19,2.162-2.589,2.842-4.189 c0.037-0.095,0.074-0.19,0.109-0.286c0.039,0.096,0.074,0.191,0.113,0.286c0.678,1.601,1.625,2.999,2.842,4.189 c1.213,1.189,2.607,2.124,4.189,2.804c1.574,0.684,3.293,1.021,5.139,1.021c1.795,0,3.5-0.337,5.105-1.021 c1.6-0.68,2.994-1.614,4.184-2.804c1.191-1.19,2.141-2.589,2.848-4.189c0.051-0.12,0.098-0.239,0.146-0.36 c0.049,0.121,0.094,0.24,0.146,0.36c0.703,1.601,1.652,2.999,2.842,4.189c1.189,1.189,2.586,2.124,4.191,2.804 c1.604,0.684,3.303,1.021,5.102,1.021c1.795,0,3.498-0.337,5.102-1.021c1.604-0.68,3.01-1.614,4.227-2.804 c1.211-1.19,2.16-2.589,2.842-4.189c0.678-1.606,1.02-3.306,1.02-5.104v-10.86C146.355,22.182,146.002,20.479,145.301,18.875z M7.064,42.06l14.758-19.874c-0.078,0.587-0.121,1.184-0.121,1.791v10.86c0,1.799,0.35,3.498,1.059,5.104 c0.328,0.752,0.719,1.458,1.156,2.119c-0.016,0-0.031-0.001-0.047-0.001H7.064z M42.541,26.817H27.24v-2.841 c0-1.015,0.189-1.99,0.58-2.912c0.391-0.922,0.936-1.74,1.645-2.444c0.697-0.703,1.516-1.249,2.438-1.641 c0.922-0.388,1.92-0.581,2.99-0.581c1.02,0,2.002,0.193,2.949,0.581c0.949,0.393,1.764,0.938,2.441,1.641 c0.682,0.704,1.225,1.521,1.641,2.444c0.414,0.922,0.617,1.896,0.617,2.912V26.817z M91.688,34.837 c0,1.023-0.189,1.992-0.582,2.916c-0.389,0.924-0.936,1.74-1.637,2.443c-0.705,0.705-1.523,1.25-2.445,1.64 c-0.92,0.388-1.92,0.579-2.984,0.579c-1.023,0-2.004-0.191-2.955-0.579c-0.945-0.39-1.758-0.935-2.439-1.64 c-0.682-0.703-1.229-1.52-1.641-2.443s-0.617-1.893-0.617-2.916v-10.86c0-1.015,0.191-1.99,0.582-2.912 c0.387-0.922,0.934-1.74,1.639-2.444c0.701-0.703,1.52-1.249,2.441-1.641c0.922-0.388,1.92-0.581,2.99-0.581 c1.018,0,2.004,0.193,2.947,0.581c0.951,0.393,1.764,0.938,2.443,1.641c0.68,0.704,1.223,1.521,1.641,2.444 c0.412,0.922,0.617,1.896,0.617,2.912V34.837z M116.252,34.837c0,1.023-0.203,1.992-0.617,2.916 c-0.412,0.924-0.961,1.74-1.641,2.443c-0.68,0.705-1.492,1.25-2.443,1.64c-0.943,0.388-1.93,0.579-2.949,0.579 c-1.07,0-2.066-0.191-2.988-0.579c-0.924-0.39-1.74-0.935-2.439-1.64c-0.707-0.703-1.252-1.52-1.643-2.443 s-0.584-1.893-0.584-2.916v-10.86c0-1.015,0.211-1.99,0.619-2.912c0.416-0.922,0.961-1.74,1.641-2.444 c0.682-0.703,1.496-1.249,2.439-1.641c0.951-0.388,1.934-0.581,2.955-0.581c1.068,0,2.062,0.193,2.986,0.581 c0.926,0.393,1.738,0.938,2.443,1.641c0.703,0.704,1.252,1.521,1.641,2.444c0.389,0.922,0.58,1.896,0.58,2.912V34.837z M140.816,34.837c0,1.023-0.193,1.992-0.58,2.916c-0.393,0.924-0.939,1.74-1.641,2.443c-0.705,0.705-1.523,1.25-2.443,1.64 c-0.922,0.388-1.92,0.579-2.986,0.579c-1.021,0-2.004-0.191-2.955-0.579c-0.943-0.39-1.758-0.935-2.438-1.64 c-0.682-0.703-1.23-1.52-1.643-2.443s-0.619-1.893-0.619-2.916v-10.86c0-1.015,0.193-1.99,0.584-2.912 c0.387-0.922,0.934-1.74,1.639-2.444c0.703-0.703,1.518-1.249,2.441-1.641c0.924-0.388,1.92-0.581,2.99-0.581 c1.02,0,2.004,0.193,2.949,0.581c0.949,0.393,1.764,0.938,2.441,1.641c0.682,0.704,1.225,1.521,1.643,2.444 c0.412,0.922,0.617,1.896,0.617,2.912V34.837z", // svg path (https://github.com/zenodo/zenodo/blob/482ee72ad501cbbd7f8ce8df9b393c130d1970f7/zenodo/modules/theme/static/img/zenodo.svg) 11 | ], 12 | }) 13 | ); -------------------------------------------------------------------------------- /src/glasses_detector/architectures/tiny_binary_detector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class TinyBinaryDetector(nn.Module): 6 | """Tiny binary detector. 
7 | 8 | This is a custom detector created with the aim to contain very few 9 | parameters while maintaining a reasonable accuracy. It only has 10 | several sequential convolutional and pooling blocks (with 11 | batch-norm in between). 12 | 13 | Note: 14 | I tried varying the architecture, including activations, 15 | convolution behavior (groups and stride), pooling, and layer 16 | structure. This also includes residual and dense connections, 17 | as well as combinations. Turns out, they do not perform as well 18 | as the current architecture which is just a bunch of 19 | CONV-RELU-BN-MAXPOOL blocks with no paddings. 20 | """ 21 | 22 | def __init__(self): 23 | super().__init__() 24 | 25 | # Several convolutional blocks 26 | self.features = nn.Sequential( 27 | self._create_block(3, 6, 15), 28 | self._create_block(6, 12, 7), 29 | self._create_block(12, 24, 5), 30 | self._create_block(24, 48, 3), 31 | self._create_block(48, 96, 3), 32 | self._create_block(96, 192, 3), 33 | nn.AdaptiveAvgPool2d((1, 1)), 34 | nn.Flatten(), 35 | ) 36 | 37 | # Fully connected layer 38 | self.fc = nn.Linear(192, 4) 39 | 40 | def _create_block(self, num_in, num_out, filter_size): 41 | return nn.Sequential( 42 | nn.Conv2d(num_in, num_out, filter_size, 1, "valid", bias=False), 43 | nn.ReLU(), 44 | nn.BatchNorm2d(num_out), 45 | nn.MaxPool2d(2, 2), 46 | ) 47 | 48 | def forward( 49 | self, 50 | imgs: list[torch.Tensor], 51 | targets: list[dict[str, torch.Tensor]] | None = None, 52 | ) -> dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]: 53 | """Forward pass through the network. 54 | 55 | This takes a list of images and returns a list of predictions 56 | for each image or a loss dictionary if the targets are provided. 57 | This is to match the API of the PyTorch *torchvision* models, 58 | which specify that: 59 | 60 | "During training, returns a dictionary containing the 61 | classification and regression losses for each image in the 62 | batch. During inference, returns a list of dictionaries, one 63 | for each input image. Each dictionary contains the predicted 64 | boxes, labels, and scores for all detections in the image." 65 | 66 | Args: 67 | imgs (list[torch.Tensor]): A list of images. 68 | annotations (list[dict[str, torch.Tensor]], optional): A 69 | list of annotations for each image. Each annotation is a 70 | dictionary that contains: 71 | 72 | 1. ``"boxes"``: the bounding boxes for each object 73 | 2. ``"labels"``: labels 74 | for all objects in the image. If ``None``, the 75 | network is in inference mode. 76 | 77 | 78 | Returns: 79 | dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]: 80 | A dictionary with only a single "regression" loss entry if 81 | ``targets`` were specified. Otherwise, a list of 82 | dictionaries with the predicted bounding boxes, labels, and 83 | scores for all detections in each image. 
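        Example:
            A minimal inference sketch (the image size is illustrative;
            any sufficiently large input should work)::

                model = TinyBinaryDetector().eval()
                imgs = [torch.rand(3, 256, 256)]

                with torch.no_grad():
                    preds = model(imgs)  # no targets -> inference mode

                print(preds[0]["boxes"])   # tensor of shape (1, 4)
                print(preds[0]["labels"])  # tensor([1])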
84 | """ 85 | # Forward pass to get raw normalized (x_min, y_min, w, h) predictions 86 | preds = self.fc(self.features(torch.stack(imgs))) 87 | 88 | # Get width and height 89 | h, w = imgs[0].shape[-2:] 90 | 91 | # Convert to (x_min, y_min, x_max, y_max) 92 | preds[:, 0] = preds[:, 0] * w 93 | preds[:, 1] = preds[:, 1] * h 94 | preds[:, 2] = preds[:, 0] + preds[:, 2] * w 95 | preds[:, 3] = preds[:, 1] + preds[:, 3] * h 96 | 97 | if targets is None: 98 | # Clamp the coordinates to the image size 99 | preds[:, 0] = torch.clamp(preds[:, 0], 0, w) 100 | preds[:, 1] = torch.clamp(preds[:, 1], 0, h) 101 | preds[:, 2] = torch.clamp(preds[:, 2], 0, w) 102 | preds[:, 3] = torch.clamp(preds[:, 3], 0, h) 103 | 104 | # Convert to shape (N, 1, 4) 105 | preds = [*preds[:, None, :]] 106 | 107 | if targets is not None: 108 | return self.compute_loss(preds, targets, imgs[0].size()[-2:]) 109 | else: 110 | return [ 111 | { 112 | "boxes": pred, 113 | "labels": torch.ones(1, dtype=torch.int64, device=pred.device), 114 | "scores": torch.ones(1, device=pred.device), 115 | } 116 | for pred in preds 117 | ] 118 | 119 | def compute_loss( 120 | self, 121 | preds: list[torch.Tensor], 122 | targets: list[dict[str, torch.Tensor]], 123 | size: tuple[int, int], 124 | ) -> dict[str, torch.Tensor]: 125 | """Compute the loss for the predicted bounding boxes. 126 | 127 | This computes the MSE loss between the predicted bounding boxes 128 | and the target bounding boxes. The returned dictionary contains 129 | one regression loss entry per image. 130 | 131 | Args: 132 | preds (list[torch.Tensor]): A list of predicted bounding 133 | boxes for each image. 134 | targets (list[dict[str, torch.Tensor]]): A list of targets 135 | for each image. size (tuple[int, int]): The ``(height, width)`` of the input images, used to normalize the box coordinates. 136 | 137 | Returns: 138 | dict[str, torch.Tensor]: A dictionary with one loss entry per 139 | image, i.e., the MSE regression loss for its predicted box. 140 | """ 141 | # Initialize criterion, loss dictionary, and device 142 | criterion, loss_dict, device = nn.MSELoss(), {}, preds[0].device 143 | 144 | # Use to divide (x_min, y_min, x_max, y_max) by (w, h, w, h) 145 | size = torch.tensor([[*size[::-1], *size[::-1]]], device=device) 146 | 147 | for i, pred in enumerate(preds): 148 | # Compute the loss (normalize the coordinates before that) 149 | loss = criterion(pred / size, targets[i]["boxes"][:1] / size) 150 | loss_dict[i] = loss 151 | 152 | return loss_dict 153 | -------------------------------------------------------------------------------- /docs/docs/examples.rst: -------------------------------------------------------------------------------- 1 | :fas:`code` Examples 2 | ==================== 3 | 4 | .. role:: bash(code) 5 | :language: bash 6 | :class: highlight 7 | 8 | .. _command-line: 9 | 10 | Command Line 11 | ------------ 12 | 13 | You can run predictions via the command line. For example, classification of a single image or of multiple images can be performed via: 14 | 15 | .. code-block:: bash 16 | 17 | glasses-detector -i path/to/img.jpg --task classification # Prints "present" or "absent" 18 | glasses-detector -i path/to/dir --output path/to/output.csv # Creates CSV (default --task is classification) 19 | 20 | 21 | It is possible to specify the **kind** of :bash:`--task` in the format :bash:`task:kind`; for example, we may want to classify only *sunglasses* (glasses with opaque lenses). Further options can be specified, such as :bash:`--format`, :bash:`--size`, :bash:`--batch-size`, :bash:`--device`, etc.: 22 | 23 | ..
code-block:: bash 24 | 25 | glasses-detector -i path/to/img.jpg -t classification:sunglasses -f proba # Prints probability of sunglasses 26 | glasses-detector -i path/to/dir -o preds.pkl -s large -b 64 -d cuda # Fast and accurate processing 27 | 28 | Running *detection* and *segmentation* is similar, though we may want to generate a folder of predictions when processing a directory (but we can also squeeze all the predictions into a single file, such as ``.npy``): 29 | 30 | .. code-block:: bash 31 | 32 | glasses-detector -i path/to/img.jpg -t detection # Shows image with bounding boxes 33 | glasses-detector -i path/to/dir -t segmentation -f mask -e .jpg # Generates dir with masks 34 | 35 | .. tip:: 36 | 37 | For a more exhaustive explanation of the available options use :bash:`glasses-detector --help` or check :doc:`cli`. 38 | 39 | 40 | Python Script 41 | ------------- 42 | 43 | The most straightforward way to perform a prediction on a single file (or a list of files) is to use :meth:`~glasses_detector.components.pred_interface.PredInterface.process_file`. Although the prediction(-s) can be saved to a file or a directory, in most cases, this is useful to immediately show the prediction result(-s). 44 | 45 | .. code-block:: python 46 | :linenos: 47 | 48 | from glasses_detector import GlassesClassifier, GlassesDetector 49 | 50 | # Prints either '1' or '0' 51 | classifier = GlassesClassifier() 52 | classifier.process_file( 53 | input_path="path/to/img.jpg", # can be a list of paths 54 | format={True: "1", False: "0"}, # similar to format="int" 55 | show=True, # to print the prediction 56 | ) 57 | 58 | # Opens a plot in a new window 59 | detector = GlassesDetector() 60 | detector.process_file( 61 | image="path/to/img.jpg", # can be a list of paths 62 | format="img", # to return the image with drawn bboxes 63 | show=True, # to show the image using matplotlib 64 | ) 65 | 66 | A more useful method is :meth:`~glasses_detector.components.pred_interface.PredInterface.process_dir` which goes through all the images in the directory and generates the predictions into a single file or a directory of files. Also note how we can specify task ``kind`` and model ``size``: 67 | 68 | .. code-block:: python 69 | :linenos: 70 | 71 | from glasses_detector import GlassesClassifier, GlassesSegmenter 72 | 73 | # Generates a CSV file with image paths and labels 74 | classifier = GlassesClassifier(kind="sunglasses") 75 | classifier.process_dir( 76 | input_path="path/to/dir", # failed files will raise a warning 77 | output_path="path/to/output.csv", # img_name1.jpg,... 78 | format="proba", # is a probability of sunglasses 79 | pbar="Processing", # set to None to disable 80 | ) 81 | 82 | # Generates a directory with masks 83 | segmenter = GlassesSegmenter(size="large", device="cuda") 84 | segmenter.process_dir( 85 | input_path="path/to/dir", # output dir defaults to path/to/dir_preds 86 | ext=".jpg", # saves each mask in JPG format 87 | format="mask", # output type will be a grayscale PIL image 88 | batch_size=32, # to speed up the processing 89 | output_size=(512, 512), # set to None to keep the same size as image 90 | ) 91 | 92 | 93 | It is also possible to directly use :meth:`~glasses_detector.components.pred_interface.PredInterface.predict` which allows to process already loaded images. This is useful when you want to incorporate the prediction into a custom pipeline. 94 | 95 | .. 
code-block:: python 96 | :linenos: 97 | 98 | import numpy as np 99 | from glasses_detector import GlassesDetector 100 | 101 | # Predicts normalized bounding boxes 102 | detector = GlassesDetector() 103 | predictions = detector( 104 | image=np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8), 105 | format="float", 106 | ) 107 | print(type(predictions), len(predictions)) # 10 108 | 109 | 110 | .. admonition:: Refer to API documentation for model-specific examples 111 | 112 | * :class:`~glasses_detector.classifier.GlassesClassifier` and its :meth:`~glasses_detector.classifier.GlassesClassifier.predict` 113 | * :class:`~glasses_detector.detector.GlassesDetector` and its :meth:`~glasses_detector.detector.GlassesDetector.predict` 114 | * :class:`~glasses_detector.segmenter.GlassesSegmenter` and its :meth:`~glasses_detector.segmenter.GlassesSegmenter.predict` 115 | 116 | Demo 117 | ---- 118 | 119 | Feel free to play around with some `demo image files `_. For example, after installing through `pip `_, you can run: 120 | 121 | .. code-block:: bash 122 | 123 | git clone https://github.com/mantasu/glasses-detector && cd glasses-detector/data 124 | glasses-detector -i demo -o demo_labels.csv --task classification:sunglasses -f proba 125 | glasses-detector -i demo -o demo_masks -t segmentation:full -f img -e .jpg 126 | 127 | Alternatively, you can check out the `demo notebook `_ which can also be accessed on `Google Colab `_. 128 | -------------------------------------------------------------------------------- /src/glasses_detector/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. class:: FilePath 3 | 4 | .. data:: FilePath 5 | :noindex: 6 | :type: typing.TypeAliasType 7 | :value: str | bytes | os.PathLike 8 | 9 | Type alias for a file path. 10 | 11 | :class:`str` | :class:`bytes` | :class:`os.PathLike` 12 | """ 13 | 14 | import functools 15 | import os 16 | import typing 17 | from typing import Any, Callable, Iterable, TypeGuard, overload 18 | from urllib.parse import urlparse 19 | 20 | import torch 21 | from PIL import Image 22 | 23 | type FilePath = str | bytes | os.PathLike 24 | 25 | 26 | class copy_signature[**P, T]: 27 | """Decorator to copy a function's or a method's signature. 28 | 29 | This decorator takes a callable and copies its signature to the 30 | decorated function or method. 31 | 32 | Example 33 | ------- 34 | 35 | .. code-block:: python 36 | 37 | def full_signature(x: bool, *extra: int) -> str: ... 38 | 39 | @copy_signature(full_signature) 40 | def test_signature(*args, **kwargs): 41 | return full_signature(*args, **kwargs) 42 | 43 | reveal_type(test_signature) # 'def (x: bool, *extra: int) -> str' 44 | 45 | .. seealso:: 46 | 47 | https://github.com/python/typing/issues/270#issuecomment-1344537820 48 | 49 | Args: 50 | source (typing.Callable[P, T]): The callable whose signature to 51 | copy. 52 | """ 53 | 54 | def __init__(self, source: Callable[P, T]): 55 | # The source callable 56 | self.source = source 57 | 58 | def __call__(self, target: Callable[..., T]) -> Callable[P, T]: 59 | @functools.wraps(self.source) 60 | def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: 61 | return target(*args, **kwargs) 62 | 63 | return wrapped 64 | 65 | 66 | class eval_infer_mode: 67 | """Context manager and decorator for evaluation and inference.
68 | 69 | This class can be used as a context manager or a decorator to set a 70 | PyTorch :class:`~torch.nn.Module` to evaluation mode via 71 | :meth:`~torch.nn.Module.eval` and enable 72 | :class:`~torch.no_grad` for the duration of a function 73 | or a ``with`` statement. After the function or the ``with`` 74 | statement, the model's mode, i.e., :attr:`~torch.nn.Module.training` 75 | property, and :class:`~torch.no_grad` are restored to their 76 | original states. 77 | 78 | Example 79 | ------- 80 | 81 | .. code-block:: python 82 | 83 | model = ... # Your PyTorch model 84 | 85 | @eval_infer_mode(model) 86 | def your_function(): 87 | # E.g., forward pass 88 | pass 89 | 90 | # or 91 | 92 | with eval_infer_mode(model): 93 | # E.g., forward pass 94 | pass 95 | 96 | Args: 97 | model (torch.nn.Module): The PyTorch model to be set to 98 | evaluation mode. 99 | """ 100 | 101 | def __init__(self, model: torch.nn.Module): 102 | self.model = model 103 | self.was_training = model.training 104 | 105 | def __call__[F](self, func: F) -> F: 106 | @functools.wraps(func) 107 | def wrapper(*args, **kwargs): 108 | with self: 109 | return func(*args, **kwargs) 110 | 111 | return wrapper 112 | 113 | def __enter__(self): 114 | self.model.eval() 115 | self._no_grad = torch.no_grad() 116 | self._no_grad.__enter__() 117 | 118 | def __exit__(self, type: Any, value: Any, traceback: Any): 119 | self._no_grad.__exit__(type, value, traceback) 120 | self.model.train(self.was_training) 121 | 122 | 123 | def is_url(x: str) -> bool: 124 | """Check if a string is a valid URL. 125 | 126 | Takes any string and checks if it is a valid URL. 127 | 128 | .. seealso:: 129 | 130 | https://stackoverflow.com/a/38020041 131 | 132 | Args: 133 | x: The string to check. 134 | 135 | Returns: 136 | :data:`True` if the string is a valid URL, :data:`False` 137 | otherwise. 138 | """ 139 | try: 140 | result = urlparse(x) 141 | return all([result.scheme, result.netloc]) 142 | except: 143 | return False 144 | 145 | 146 | @overload 147 | def flatten[T](items: T) -> T: ... 148 | 149 | 150 | @overload 151 | def flatten[T](items: typing.Iterable[T | typing.Iterable]) -> list[T]: ... 152 | 153 | 154 | def flatten[T](items: T | Iterable[T | Iterable]) -> T | list[T]: 155 | """Flatten a nested list. 156 | 157 | This function takes any nested iterable and returns a flat list. 158 | 159 | Args: 160 | items (T | typing.Iterable[T | typing.Iterable]): The nested 161 | iterable to flatten. 162 | 163 | Returns: 164 | T | list[T]: The flattened list or the original ``items`` value 165 | if it is not an iterable or is of type :class:`str`. 166 | """ 167 | if not isinstance(items, Iterable) or isinstance(items, str): 168 | # Not iterable 169 | return items 170 | 171 | # Init flat list 172 | flattened = [] 173 | 174 | for item in items: 175 | if isinstance(item, Iterable) and not isinstance(item, str): 176 | flattened.extend(flatten(item)) 177 | else: 178 | flattened.append(item) 179 | 180 | return flattened 181 | 182 | 183 | def is_path_type(path: Any) -> TypeGuard[FilePath]: 184 | """Check if an object is a valid path type. 185 | 186 | This function takes any object and checks if it is a valid path 187 | type. A valid path type is either a :class:`str`, :class:`bytes` or 188 | :class:`os.PathLike` object. 189 | 190 | Args: 191 | path: The object to check. 192 | 193 | Returns: 194 | :data:`True` if the object is a valid path type, :data:`False` 195 | otherwise. 
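Example
-------

A small illustrative sketch (not part of the original module), showing the
path-type check on a few literal values:

.. code-block:: python

    is_path_type("path/to/img.jpg")   # True
    is_path_type(b"path/to/img.jpg")  # True
    is_path_type(123)                 # False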
196 | """ 197 | return isinstance(path, (str, bytes, os.PathLike)) 198 | 199 | 200 | def is_image_file(path: FilePath) -> bool: 201 | """Check if a file is an image. 202 | 203 | This function takes a file path and checks if it is an image file. 204 | This is done by checking if the file exists and if it can be 205 | identified as an image 206 | 207 | Args: 208 | path: The path to the file. 209 | 210 | Returns: 211 | :data:`True` if the file is an image, :data:`False` otherwise. 212 | """ 213 | if not os.path.isfile(path): 214 | return False 215 | 216 | try: 217 | with Image.open(path) as img: 218 | img.verify() 219 | return True 220 | except (IOError, SyntaxError): 221 | return False 222 | -------------------------------------------------------------------------------- /scripts/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | from pytorch_lightning.callbacks import ModelCheckpoint 7 | from pytorch_lightning.cli import LightningCLI 8 | from pytorch_lightning.tuner import Tuner 9 | 10 | PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 11 | DEFAULT_KINDS = { 12 | "classification": "anyglasses", 13 | "detection": "worn", 14 | "segmentation": "smart", 15 | } 16 | 17 | sys.path.append(os.path.join(PROJECT_DIR, "src")) 18 | torch.set_float32_matmul_precision("medium") 19 | 20 | from glasses_detector import GlassesClassifier, GlassesDetector, GlassesSegmenter 21 | from glasses_detector._data import ( 22 | BinaryClassificationDataset, 23 | BinaryDetectionDataset, 24 | BinarySegmentationDataset, 25 | ) 26 | from glasses_detector._wrappers import BinaryClassifier, BinaryDetector, BinarySegmenter 27 | 28 | 29 | class RunCLI(LightningCLI): 30 | def add_arguments_to_parser(self, parser): 31 | # Add args for wrapper creation 32 | parser.add_argument( 33 | "-r", 34 | "--root", 35 | metavar="path/to/data/root", 36 | type=str, 37 | default=os.path.join(PROJECT_DIR, "data"), 38 | help="Path to the data directory with classification and segmentation subdirectories that contain datasets for different kinds of tasks. Defaults to 'data' under project root.", 39 | ) 40 | parser.add_argument( 41 | "-t", 42 | "--task", 43 | metavar="", 44 | type=str, 45 | default="classification:anyglasses", 46 | choices=(ch := [ 47 | "classification", 48 | "classification:anyglasses", 49 | "classification:sunglasses", 50 | "classification:eyeglasses", 51 | "classification:shadows", 52 | "detection", 53 | "detection:eyes", 54 | "detection:solo", 55 | "detection:worn", 56 | "segmentation", 57 | "segmentation:frames", 58 | "segmentation:full", 59 | "segmentation:legs", 60 | "segmentation:lenses", 61 | "segmentation:shadows", 62 | "segmentation:smart", 63 | ]), 64 | help=f"The kind of task to train/test the model for. One of {", ".join([f"'{c}'" for c in ch])}. If specified only as 'classification', 'detection', or 'segmentation', the subcategories 'anyglasses', 'worn', and 'smart' will be chosen, respectively. Defaults to 'classification:anyglasses'.", 65 | ) 66 | parser.add_argument( 67 | "-s", 68 | "--size", 69 | metavar="", 70 | type=str, 71 | default="medium", 72 | choices=["small", "medium", "large"], 73 | help="The model size which determines architecture type. One of 'small', 'medium', 'large'. Defaults to 'medium'.", 74 | ) 75 | parser.add_argument( 76 | "-b", 77 | "--batch-size", 78 | metavar="", 79 | type=int, 80 | default=64, 81 | help="The batch size used for training. 
Defaults to 64.", 82 | ) 83 | parser.add_argument( 84 | "-n", 85 | "--num-workers", 86 | metavar="", 87 | type=int, 88 | default=8, 89 | help="The number of workers for the data loader. Defaults to 8.", 90 | ) 91 | parser.add_argument( 92 | "-w", 93 | "--weights", 94 | metavar="path/to/weights", 95 | type=str | None, 96 | default=None, 97 | help="Path to weights to load into the model. Defaults to None.", 98 | ) 99 | parser.add_argument( 100 | "-f", 101 | "--find-lr", 102 | action="store_true", 103 | help="Whether to run the learning rate finder before training. Defaults to False.", 104 | ) 105 | parser.add_lightning_class_args(ModelCheckpoint, "checkpoint") 106 | 107 | # Checkpoint and trainer defaults 108 | parser.set_defaults( 109 | { 110 | "checkpoint.dirpath": "checkpoints", 111 | "checkpoint.save_last": False, 112 | "checkpoint.monitor": "val_loss", 113 | "checkpoint.mode": "min", 114 | "trainer.precision": "bf16-mixed", 115 | "trainer.max_epochs": 300, 116 | } 117 | ) 118 | 119 | # Link argument with wrapper creation callback arguments 120 | parser.link_arguments("root", "model.root", apply_on="parse") 121 | parser.link_arguments("task", "model.task", apply_on="parse") 122 | parser.link_arguments("size", "model.size", apply_on="parse") 123 | parser.link_arguments("batch_size", "model.batch_size", apply_on="parse") 124 | parser.link_arguments("num_workers", "model.num_workers", apply_on="parse") 125 | parser.link_arguments("weights", "model.weights", apply_on="parse") 126 | 127 | def before_fit(self): 128 | if self.config.fit.checkpoint.filename is None: 129 | # Update default filename for checkpoint saver callback 130 | self.model_name = ( 131 | self.config.fit.model.task.replace(":", "-") + "-" + self.config.fit.model.size 132 | ) 133 | self.trainer.callbacks[-1].filename = ( 134 | self.model_name + "-{epoch:02d}-{val_loss:.3f}" 135 | ) 136 | 137 | if self.config.fit.find_lr: 138 | # Run learning rate finder 139 | tuner = Tuner(self.trainer) 140 | lr_finder = tuner.lr_find(self.model, min_lr=1e-5, max_lr=1e-1, num_training=500) 141 | self.model.lr = lr_finder.suggestion() 142 | 143 | def after_fit(self): 144 | # Get the best checkpoint path and load it 145 | ckpt_path = self.trainer.callbacks[-1].best_model_path 146 | ckpt_dir = os.path.dirname(ckpt_path) 147 | ckpt = torch.load(ckpt_path) 148 | 149 | # Load wights and save the inner model as pth 150 | self.model.load_state_dict(ckpt["state_dict"]) 151 | torch.save( 152 | self.model.model.state_dict(), 153 | os.path.join(ckpt_dir, self.model_name + ".pth"), 154 | ) 155 | 156 | def create_wrapper_callback( 157 | root: str = "data", 158 | task: str = "classification", 159 | size: str = "medium", 160 | batch_size: int = 64, 161 | num_workers: int = 8, 162 | weights: str | None = None, 163 | ) -> pl.LightningModule: 164 | 165 | # Get task and kind 166 | task_and_kind = task.split(":") 167 | task = task_and_kind[0] 168 | kind = DEFAULT_KINDS[task] if len(task_and_kind) == 1 else task_and_kind[1] 169 | 170 | # Get model and dataset classes 171 | model_cls, data_cls = { 172 | "classification": (GlassesClassifier, BinaryClassificationDataset), 173 | "detection": (GlassesDetector, BinaryDetectionDataset), 174 | "segmentation": (GlassesSegmenter, BinarySegmentationDataset), 175 | }[task] 176 | 177 | # Set-up wrapper initialization kwargs 178 | kwargs = { 179 | "root": os.path.join(root, task, kind), 180 | "batch_size": batch_size, 181 | "num_workers": num_workers, 182 | } 183 | 184 | # Update wrapper initialization kwargs and set the 
initializer class 185 | if task == "classification": 186 | kwargs["cat2idx_fn"] = {"no_" + kind: 0, kind: 1} 187 | wrapper_cls = BinaryClassifier 188 | elif task == "detection": 189 | wrapper_cls = BinaryDetector 190 | elif task == "segmentation": 191 | wrapper_cls = BinarySegmenter 192 | 193 | # Initialize model architecture and load weights if needed 194 | model = model_cls(kind=kind, size=size, weights=weights).model 195 | 196 | return wrapper_cls(model, *data_cls.create_loaders(**kwargs)) 197 | 198 | 199 | def cli_main(): 200 | cli = RunCLI(create_wrapper_callback, seed_everything_default=0) 201 | 202 | 203 | if __name__ == "__main__": 204 | cli_main() 205 | -------------------------------------------------------------------------------- /src/glasses_detector/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | import os 4 | 5 | import torch 6 | 7 | from . import GlassesClassifier, GlassesDetector, GlassesSegmenter 8 | from .components import BaseGlassesModel 9 | 10 | 11 | def parse_kwargs(): 12 | parser = argparse.ArgumentParser( 13 | prog="Glasses Detector", 14 | description=f"Classification, detection, and segmentation of glasses " 15 | f"in images.", 16 | ) 17 | 18 | parser.add_argument( 19 | "-i", 20 | "--input", 21 | metavar="path/to/dir/or/file", 22 | type=str, 23 | required=True, 24 | help="Path to the input image or the directory with images.", 25 | ) 26 | parser.add_argument( 27 | "-o", 28 | "--output", 29 | metavar="path/to/dir/or/file", 30 | type=str, 31 | default=None, 32 | help=f"Path to the output file or the directory. If not provided, " 33 | f"then, if input is a file, the prediction will be printed (or shown " 34 | f"if it is an image), otherwise, if input is a directory, the " 35 | f"predictions will be written to a directory with the same name with " 36 | f"an added suffix '_preds'. If provided as a file, then the " 37 | f"prediction(-s) will be saved to this file (supported extensions " 38 | f"include: .txt, .csv, .json, .npy, .pkl, .jpg, .png). If provided as " 39 | f"a directory, then the predictions will be saved to this directory " 40 | f"use `--extension` flag to specify the file extensions in that " 41 | f"directory. Defaults to None.", 42 | ) 43 | parser.add_argument( 44 | "-e", 45 | "--extension", 46 | metavar="", 47 | type=str, 48 | default=None, 49 | choices=(ch := [".txt", ".csv", ".json", ".npy", ".pkl", ".jpg", ".png"]), 50 | help=f"Only used if `--output` is a directory. The extension to " 51 | f"use to save the predictions as files. Common extensions include: " 52 | f"{", ".join([f"{c}" for c in ch])}. If not specified, it will be set " 53 | f"automatically to .jpg for image predictions and to .txt for all " 54 | f"other formats. Defaults to None.", 55 | ) 56 | parser.add_argument( 57 | "-f", 58 | "--format", 59 | metavar="", 60 | type=str, 61 | default=None, 62 | help=f"The format to use to map the raw prediction to. For " 63 | f"classification, common formats are bool, proba, str, for detection, " 64 | f"common formats are bool, int, img, for segmentation, common formats " 65 | f"are proba, img, mask. If not specified, it will be set " 66 | f"automatically to str, img, mask for classification, detection, " 67 | f"segmentation respectively. Check API documentation for more " 68 | f"details. 
Defaults to None.", 69 | ) 70 | parser.add_argument( 71 | "-t", 72 | "--task", 73 | metavar="", 74 | type=str, 75 | default="classification:anyglasses", 76 | choices=(ch := [ 77 | "classification", 78 | "classification:anyglasses", 79 | "classification:sunglasses", 80 | "classification:eyeglasses", 81 | "classification:shadows", 82 | "detection", 83 | "detection:eyes", 84 | "detection:solo", 85 | "detection:worn", 86 | "segmentation", 87 | "segmentation:frames", 88 | "segmentation:full", 89 | "segmentation:legs", 90 | "segmentation:lenses", 91 | "segmentation:shadows", 92 | "segmentation:smart", 93 | ]), 94 | help=f"The kind of task the model should perform. One of " 95 | f"{", ".join([f"{c}" for c in ch])}. If specified only as " 96 | f"classification, detection, or segmentation, the subcategories " 97 | f"anyglasses, worn, and smart will be chosen, respectively. Defaults " 98 | f"to classification:anyglasses.", 99 | ) 100 | parser.add_argument( 101 | "-s", 102 | "--size", 103 | metavar="", 104 | type=str, 105 | default="medium", 106 | choices=["small", "medium", "large", "s", "m", "l"], 107 | help=f"The model size which determines architecture type. One of " 108 | f"'small', 'medium', 'large' (or 's', 'm', 'l'). Defaults to 'medium'.", 109 | ) 110 | parser.add_argument( 111 | "-b", 112 | "--batch-size", 113 | metavar="", 114 | type=int, 115 | default=1, 116 | help=f"Only used if `--input` is a directory. The batch size to " 117 | f"use when processing the images. This groups the files in the input " 118 | f"directory to batches of size `batch_size` before processing them. " 119 | f"In some cases, larger batch sizes can speed up the processing at " 120 | f"the cost of more memory usage. Defaults to 1." 121 | ) 122 | parser.add_argument( 123 | "-p", 124 | "--pbar", 125 | type=str, 126 | metavar="", 127 | default="Processing", 128 | help=f"Only used if `--input` is a directory. It is the " 129 | f"description that is used for the progress bar. If specified " 130 | f"as '' (empty string), no progress bar is shown. Defaults to " 131 | f"'Processing'.", 132 | ) 133 | parser.add_argument( 134 | "-w", 135 | "--weights", 136 | metavar="path/to/weights.pth", 137 | type=str, 138 | default=None, 139 | help=f"Path to custom weights to load into the model. If not " 140 | f"specified, weights will be loaded from the default location (and " 141 | f"automatically downloaded there if needed). Defaults to None.", 142 | ) 143 | parser.add_argument( 144 | "-d", 145 | "--device", 146 | type=str, 147 | metavar="", 148 | default=None, 149 | help=f"The device on which to perform inference. If not specified, it " 150 | f"will be automatically checked if CUDA or MPS is supported. 
" 151 | f"Defaults to None.", 152 | ) 153 | 154 | return vars(parser.parse_args()) 155 | 156 | def prepare_kwargs(kwargs: dict[str, str | int | None]): 157 | # Define the keys to use when calling process and init methods 158 | model_keys = inspect.getfullargspec(BaseGlassesModel.__init__).args 159 | process_keys = [key for key in kwargs.keys() if key not in model_keys] 160 | process_keys += ["ext", "show", "is_file", "input_path", "output_path"] 161 | 162 | # Add "is_file" key to check which process method to call 163 | kwargs["is_file"] = os.path.splitext(kwargs["input"])[-1] != "" 164 | kwargs["ext"] = kwargs.pop("extension") 165 | kwargs["input_path"] = kwargs.pop("input") 166 | kwargs["output_path"] = kwargs.pop("output") 167 | 168 | if not kwargs["is_file"] and kwargs["output_path"] is None: 169 | # Input is a directory but no output path is specified 170 | kwargs["output_path"] = os.path.splitext(kwargs["input_path"])[0] + "_preds" 171 | 172 | if kwargs["is_file"] and kwargs["output_path"] is None: 173 | # Input is a file and no output path is specified 174 | kwargs["show"] = True 175 | 176 | if kwargs["pbar"] == "": 177 | # No progress bar 178 | kwargs["pbar"] = None 179 | 180 | if kwargs["weights"] is None: 181 | # Use default weights 182 | kwargs["weights"] = True 183 | 184 | if len(splits := kwargs["task"].split(":")) == 2: 185 | # Task is specified as "task:kind" 186 | kwargs["task"] = splits[0] 187 | kwargs["kind"] = splits[1] 188 | 189 | if kwargs["format"] is None and kwargs["task"] == "classification": 190 | # Default format for classification 191 | kwargs["format"] = "str" 192 | elif kwargs["format"] is None and kwargs["task"] == "detection": 193 | # Default format for detection 194 | kwargs["format"] = "img" 195 | elif kwargs["format"] is None and kwargs["task"] == "segmentation": 196 | # Default format for segmentation 197 | kwargs["format"] = "mask" 198 | 199 | # Get the kwargs for the process and init methods 200 | process_kwargs = {k: kwargs[k] for k in process_keys if k in kwargs} 201 | model_kwargs = {k: kwargs[k] for k in model_keys if k in kwargs} 202 | 203 | return process_kwargs, model_kwargs 204 | 205 | 206 | def main(): 207 | # Parse CLI args; prepare to create model and process images 208 | process_kwargs, model_kwargs = prepare_kwargs(parse_kwargs()) 209 | is_file = process_kwargs.pop("is_file") 210 | task = model_kwargs.pop("task") 211 | 212 | # Create model 213 | match task: 214 | case "classification": 215 | model = GlassesClassifier(**model_kwargs) 216 | case "detection": 217 | model = GlassesDetector(**model_kwargs) 218 | case "segmentation": 219 | model = GlassesSegmenter(**model_kwargs) 220 | case _: 221 | raise ValueError(f"Unknown task '{task}'.") 222 | 223 | if is_file: 224 | # Process a single image file 225 | process_kwargs.pop("batch_size") 226 | process_kwargs.pop("pbar") 227 | model.process_file(**process_kwargs) 228 | else: 229 | # Process a directory of images 230 | model.process_dir(**process_kwargs) 231 | 232 | if __name__ == "__main__": 233 | main() 234 | -------------------------------------------------------------------------------- /docs/helpers/generate_examples.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from PIL import Image 4 | 5 | sys.path.append("../src") 6 | 7 | from glasses_detector import GlassesDetector, GlassesSegmenter 8 | 9 | SAMPLES = { 10 | "classification": { 11 | "eyeglasses": 
"data/classification/eyeglasses/sunglasses-glasses-detect/test/eyeglasses/face-149_png.rf.cc420d484a00dd7158550510785f0d51.jpg", 12 | "sunglasses": "data/classification/sunglasses/face-attributes-grouped/test/sunglasses/2014-Colored-Mirror-Sunglasses-2014-Renkli-ve-Aynal-Caml-Gunes-Gozlugu-Modelleri-19.jpg", 13 | "no-glasses": "data/classification/eyeglasses/face-attributes-extra/test/no_eyeglasses/800px_COLOURBOX4590613.jpg", 14 | }, 15 | "segmentation-frames": [ 16 | { 17 | "img": "data/segmentation/frames/eyeglass/val/images/Woman-wearing-white-rimmed-thin-spectacles_jpg.rf.097bc1a8d872b6acddc3312e95a2d48e.jpg", 18 | "msk": "data/segmentation/frames/eyeglass/val/masks/Woman-wearing-white-rimmed-thin-spectacles_jpg.rf.097bc1a8d872b6acddc3312e95a2d48e.jpg", 19 | }, 20 | { 21 | "img": "data/segmentation/frames/glasses-segmentation-synthetic/test/images/img-Glass001-386-2_mouth_open-3-missile_launch_facility_01-037.jpg", 22 | "msk": "data/segmentation/frames/glasses-segmentation-synthetic/test/masks/img-Glass001-386-2_mouth_open-3-missile_launch_facility_01-037.jpg", 23 | }, 24 | { 25 | "img": "data/segmentation/frames/glasses-segmentation-synthetic/test/images/img-Glass021-450-4_brow_lower-2-entrance_hall-204.jpg", 26 | "msk": "data/segmentation/frames/glasses-segmentation-synthetic/test/masks/img-Glass021-450-4_brow_lower-2-entrance_hall-204.jpg", 27 | }, 28 | ], 29 | "segmentation-full": [ 30 | { 31 | "img": "data/segmentation/full/celeba-mask-hq/test/images/821.jpg", 32 | "msk": "data/segmentation/full/celeba-mask-hq/test/masks/821.jpg", 33 | }, 34 | { 35 | "img": "data/segmentation/full/celeba-mask-hq/test/images/1034.jpg", 36 | "msk": "data/segmentation/full/celeba-mask-hq/test/masks/1034.jpg", 37 | }, 38 | { 39 | "img": "data/segmentation/full/celeba-mask-hq/test/images/2442.jpg", 40 | "msk": "data/segmentation/full/celeba-mask-hq/test/masks/2442.jpg", 41 | }, 42 | ], 43 | "segmentation-legs": [ 44 | { 45 | "img": "data/segmentation/legs/capstone-mini-2/test/images/IMG20230325193452_0_jpg.rf.ea87e7fe943f39216cacc84b32848e28.jpg", 46 | "msk": "data/segmentation/legs/capstone-mini-2/test/masks/IMG20230325193452_0_jpg.rf.ea87e7fe943f39216cacc84b32848e28.jpg", 47 | }, 48 | { 49 | "img": "data/segmentation/legs/sunglasses-color-detection/test/images/2004970PJPXT_P00_JPG_jpg.rf.e1cd193efd84dac31c027e3d3649ec7a.jpg", 50 | "msk": "data/segmentation/legs/sunglasses-color-detection/test/masks/2004970PJPXT_P00_JPG_jpg.rf.e1cd193efd84dac31c027e3d3649ec7a.jpg", 51 | }, 52 | { 53 | "img": "data/segmentation/legs/sunglasses-color-detection/test/images/aug_57_203675009QQT_P00_JPG_jpg.rf.54bdfef21f854be18d9dcf13fa5a7ae7.jpg", 54 | "msk": "data/segmentation/legs/sunglasses-color-detection/test/masks/aug_57_203675009QQT_P00_JPG_jpg.rf.54bdfef21f854be18d9dcf13fa5a7ae7.jpg", 55 | }, 56 | ], 57 | "segmentation-lenses": [ 58 | { 59 | "img": "data/segmentation/lenses/glasses-lens/test/images/face-35_jpg.rf.f0a9a1d3b4f9e756488294d2db1720d5.jpg", 60 | "msk": "data/segmentation/lenses/glasses-lens/test/masks/face-35_jpg.rf.f0a9a1d3b4f9e756488294d2db1720d5.jpg", 61 | }, 62 | { 63 | "img": "data/segmentation/lenses/glass-color/test/images/2025260PJPVP_P00_JPG_jpg.rf.aaa9e83edbfd8a3c107650b62ddf52ed.jpg", 64 | "msk": "data/segmentation/lenses/glass-color/test/masks/2025260PJPVP_P00_JPG_jpg.rf.aaa9e83edbfd8a3c107650b62ddf52ed.jpg", 65 | }, 66 | { 67 | "img": "data/segmentation/lenses/glasses-segmentation-cropped-faces/test/images/face-1306_scaled_cropping_jpg.rf.b5f5b788fb75aa05a15e69b938704c12.jpg", 68 | "msk": 
"data/segmentation/lenses/glasses-segmentation-cropped-faces/test/masks/face-1306_scaled_cropping_jpg.rf.b5f5b788fb75aa05a15e69b938704c12.jpg", 69 | }, 70 | ], 71 | "segmentation-shadows": [ 72 | { 73 | "img": "data/segmentation/shadows/glasses-segmentation-synthetic/test/images/img-Glass021-435-16_sadness-2-versveldpas-105.jpg", 74 | "msk": "data/segmentation/shadows/glasses-segmentation-synthetic/test/masks/img-Glass021-435-16_sadness-2-versveldpas-105.jpg", 75 | }, 76 | { 77 | "img": "data/segmentation/shadows/glasses-segmentation-synthetic/test/images/img-Glass001-379-4_brow_lower-1-simons_town_rocks-315.jpg", 78 | "msk": "data/segmentation/shadows/glasses-segmentation-synthetic/test/masks/img-Glass001-379-4_brow_lower-1-simons_town_rocks-315.jpg", 79 | }, 80 | { 81 | "img": "data/segmentation/shadows/glasses-segmentation-synthetic/test/images/img-Glass018-422-7_jaw_left-1-urban_street_03-039.jpg", 82 | "msk": "data/segmentation/shadows/glasses-segmentation-synthetic/test/masks/img-Glass018-422-7_jaw_left-1-urban_street_03-039.jpg", 83 | }, 84 | ], 85 | "segmentation-smart": [ 86 | { 87 | "img": "data/segmentation/smart/face-synthetics-glasses/test/images/000410.jpg", 88 | "msk": "data/segmentation/smart/face-synthetics-glasses/test/masks/000410.jpg", 89 | }, 90 | { 91 | "img": "data/segmentation/smart/face-synthetics-glasses/test/images/001229.jpg", 92 | "msk": "data/segmentation/smart/face-synthetics-glasses/test/masks/001229.jpg", 93 | }, 94 | { 95 | "img": "data/segmentation/smart/face-synthetics-glasses/test/images/002315.jpg", 96 | "msk": "data/segmentation/smart/face-synthetics-glasses/test/masks/002315.jpg", 97 | }, 98 | ], 99 | "detection-eyes": [ 100 | { 101 | "img": "data/detection/eyes/ex07/test/images/face-16_jpg.rf.9554ce9ff29cca368918cb849806902f.jpg", 102 | "ann": "data/detection/eyes/ex07/test/annotations/face-16_jpg.rf.9554ce9ff29cca368918cb849806902f.txt", 103 | }, 104 | { 105 | "img": "data/detection/eyes/glasses-detection/test/images/41d3e9440d1678109133_jpeg.rf.564dc61348a3986faf801d352a7ebe41.jpg", 106 | "ann": "data/detection/eyes/glasses-detection/test/annotations/41d3e9440d1678109133_jpeg.rf.564dc61348a3986faf801d352a7ebe41.txt", 107 | }, 108 | { 109 | "img": "data/detection/eyes/glasses-detection/test/images/woman-face-eyes-feeling_jpg.rf.8c1547d76fe23936984db74a5507f188.jpg", 110 | "ann": "data/detection/eyes/glasses-detection/test/annotations/woman-face-eyes-feeling_jpg.rf.8c1547d76fe23936984db74a5507f188.txt", 111 | }, 112 | ], 113 | "detection-solo": [ 114 | { 115 | "img": "data/detection/solo/onlyglasses/test/images/8--52-_jpg.rf.cfc2d6dec8f46cd5b91c9c112fbb8bf3.jpg", 116 | "ann": "data/detection/solo/onlyglasses/test/annotations/8--52-_jpg.rf.cfc2d6dec8f46cd5b91c9c112fbb8bf3.txt", 117 | }, 118 | { 119 | "img": "data/detection/solo/kacamata-membaca/test/images/85_jpg.rf.4c164fa95a20bebc7c888d34ed160e16.jpg", 120 | "ann": "data/detection/solo/kacamata-membaca/test/annotations/85_jpg.rf.4c164fa95a20bebc7c888d34ed160e16.txt", 121 | }, 122 | { 123 | "img": "data/detection/worn/ai-pass/test/images/28f3d11c3465ce2e74d8a4d65861de51_jpg.rf.e119438402f655cec8032304c7603606.jpg", 124 | "ann": "data/detection/worn/ai-pass/test/annotations/28f3d11c3465ce2e74d8a4d65861de51_jpg.rf.e119438402f655cec8032304c7603606.txt", 125 | }, 126 | ], 127 | "detection-worn": [ 128 | { 129 | "img": "data/detection/worn/glasses-detection/test/images/425px-robert_downey_jr_avp_iron_man_3_paris_jpg.rf.998f29000b52081eb6ea4d25df75512c.jpg", 130 | "ann": 
"data/detection/worn/glasses-detection/test/annotations/425px-robert_downey_jr_avp_iron_man_3_paris_jpg.rf.998f29000b52081eb6ea4d25df75512c.txt", 131 | }, 132 | { 133 | "img": "data/detection/worn/ai-pass/test/images/glasses120_png_jpg.rf.847610bd1230c85c8f81cbced18c38ea.jpg", 134 | "ann": "data/detection/worn/ai-pass/test/annotations/glasses120_png_jpg.rf.847610bd1230c85c8f81cbced18c38ea.txt", 135 | }, 136 | { 137 | "img": "data/detection/worn/ai-pass/test/images/women-with-glass_81_jpg.rf.bef8096d89dd0805ff3bbf0f8d08b0c8.jpg", 138 | "ann": "data/detection/worn/ai-pass/test/annotations/women-with-glass_81_jpg.rf.bef8096d89dd0805ff3bbf0f8d08b0c8.txt", 139 | }, 140 | ], 141 | } 142 | 143 | 144 | def generate_examples(data_dir: str = "..", out_dir: str = "_static/img"): 145 | for task, samples in SAMPLES.items(): 146 | if task == "classification": 147 | for label, path in samples.items(): 148 | # Load the image and save it 149 | img = Image.open(f"{data_dir}/{path}") 150 | img.save(f"{out_dir}/{task}-{label}.jpg") 151 | elif task.startswith("detection"): 152 | for i, sample in enumerate(samples): 153 | # Load the image 154 | img = Image.open(f"{data_dir}/{sample["img"]}") 155 | 156 | with open(f"{data_dir}/{sample["ann"]}", "r") as f: 157 | # Load annotations (single bbox per image) 158 | ann = [list(map(float, f.read().split()))] 159 | 160 | # Draw the bounding box and save the image 161 | out = GlassesDetector.draw_boxes(img, ann, colors="red", width=3) 162 | out.save(f"{out_dir}/{task}-{i}.jpg") 163 | elif task.startswith("segmentation"): 164 | for i, sample in enumerate(samples): 165 | # Load image and mask and overlay them 166 | img = Image.open(f"{data_dir}/{sample["img"]}") 167 | msk = Image.open(f"{data_dir}/{sample["msk"]}") 168 | out = GlassesSegmenter.draw_masks(img, msk, colors="red", alpha=0.5) 169 | out.save(f"{out_dir}/{task}-{i}.jpg") 170 | 171 | 172 | if __name__ == "__main__": 173 | # cd docs/ 174 | # python helpers/generate_examples.py 175 | generate_examples() 176 | -------------------------------------------------------------------------------- /docs/helpers/build_finished.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import requests 5 | import yaml 6 | from bs4 import BeautifulSoup, NavigableString, Tag 7 | from sphinx.application import Sphinx 8 | 9 | 10 | class BuildFinished: 11 | def __init__(self, static_path: str = "_static", conf_path: str = "conf.yaml"): 12 | # Init inv directory and create it if not exists 13 | self.inv_dir = os.path.join(static_path, "inv") 14 | os.makedirs(self.inv_dir, exist_ok=True) 15 | 16 | with open(conf_path) as f: 17 | # Load conf.yaml and get build_finished section 18 | self.conf = yaml.safe_load(f)["build-finished"] 19 | 20 | def align_rowspans(self, soup: BeautifulSoup): 21 | if tds := soup.find_all("td", rowspan=True): 22 | for td in tds: 23 | td["valign"] = "middle" 24 | 25 | def add_collapse_ids(self, soup: BeautifulSoup): 26 | if details := soup.find_all("details"): 27 | for detail in details: 28 | if detail.has_attr("name"): 29 | detail["id"] = "-".join(detail["name"].split()) 30 | 31 | def keep_only_data(self, soup: BeautifulSoup): 32 | def has_children(tag: Tag, txt1: str, txt2: str): 33 | if tag.name != "dt": 34 | return False 35 | 36 | # Get the prename and name elements of the signature 37 | ch1 = tag.select_one("span.sig-prename.descclassname span.pre") 38 | ch2 = tag.select_one("span.sig-name.descname span.pre") 39 | 40 | return ch1 
and ch2 and ch1.string == txt1 and ch2.string == txt2 41 | 42 | for alias, module in self.conf["TYPE_ALIASES"].items(): 43 | if dt := soup.find("dt", id=f"{module}{alias}"): 44 | # Copy class directive's a 45 | a = dt.find("a").__copy__() 46 | dt.parent.decompose() 47 | else: 48 | continue 49 | 50 | if dt := soup.find(lambda tag: has_children(tag, module, alias)): 51 | # ID and a for data directive 52 | dt["id"] = f"{module}{alias}" 53 | dt.append(a) 54 | dt.find("span", class_="sig-prename descclassname").decompose() 55 | 56 | def process_in_page_toc(self, soup: BeautifulSoup): 57 | for li in soup.find_all("li", class_="toc-h3 nav-item toc-entry"): 58 | if span := li.find("span"): 59 | # Modify the toc-nav span element here 60 | span.string = span.string.split(".")[-1] 61 | 62 | def break_long_signatures(self, soup: BeautifulSoup): 63 | def break_long_params(id, sig_param): 64 | if (params := self.conf["LONG_PARAMETER_IDS"].get(id)) is None: 65 | return 66 | 67 | is_opened = False 68 | 69 | for span in sig_param.find_all("span", class_="pre"): 70 | if span.string == "[": 71 | is_opened = True 72 | elif span.string == "]": 73 | is_opened = False 74 | 75 | if ( 76 | span.string == "|" 77 | and not is_opened 78 | and span.parent.parent.parent.find("span", class_="pre").string 79 | in params 80 | ): 81 | # Add long-sig to spans with | 82 | span["class"].append("long-sig") 83 | 84 | for id in self.conf["LONG_SIGNATURE_IDS"]: 85 | if not (dt := soup.find("dt", id=id)): 86 | continue 87 | 88 | for sig_param in dt.find_all("em", class_="sig-param"): 89 | # Add long-sig to the identified sig-param ems 90 | sig_param["class"].append("long-sig") 91 | break_long_params(id, sig_param) 92 | 93 | for dt_sibling in dt.find_next_siblings("dt"): 94 | for sig_param in dt_sibling.find_all("em", class_="sig-param"): 95 | # Add long-sig for overrides, i.e., sibling dts, too 96 | sig_param["class"].append("long-sig") 97 | break_long_params(id, sig_param) 98 | 99 | def customize_code_block_colors_python(self, soup: BeautifulSoup): 100 | for span in soup.select("div.highlight-python div.highlight pre span"): 101 | for name, keyword in self.conf["CUSTOM_SYNTAX_COLORS_PYTHON"].items(): 102 | if span.get_text().strip() in keyword: 103 | # Add class of the syntax keyword 104 | span["class"].append(name) 105 | 106 | def customize_code_block_colors_bash(self, soup: BeautifulSoup): 107 | # Select content groups 108 | pres = soup.select("div.highlight-bash div.highlight pre") 109 | pres.extend(soup.select("code.highlight-bash")) 110 | 111 | # Define the constants 112 | KEEP_CLS = {"c1", "w"} 113 | OP_CLS = "custom-highlight-op" 114 | START_CLS = "custom-highlight-start" 115 | DEFAULT_CLS = "custom-highlight-default" 116 | 117 | # Get the starts and flatten the keywords 118 | starts = self.conf["CUSTOM_SYNTAX_COLORS_BASH"][START_CLS] 119 | ops = self.conf["CUSTOM_SYNTAX_COLORS_BASH"][OP_CLS] + ["\n"] 120 | flat_kwds = [ 121 | (cls, kwd) 122 | for cls, kwds in self.conf["CUSTOM_SYNTAX_COLORS_BASH"].items() 123 | for kwd in kwds 124 | if cls not in [START_CLS, DEFAULT_CLS, OP_CLS] 125 | ] 126 | 127 | for pre in pres: 128 | for content in pre.contents: 129 | if ( 130 | isinstance(content, Tag) 131 | and "class" in content.attrs.keys() 132 | and not any(cls in content["class"] for cls in KEEP_CLS) 133 | ): 134 | # Only keep the text part, i.e., remove 135 | content.replace_with(NavigableString(content.get_text())) 136 | elif isinstance(content, NavigableString) and "\n" in content: 137 | # Init the splits 138 | sub_contents = 
[] 139 | 140 | for sub_content in content.split("\n"): 141 | if sub_content != "": 142 | # No need to add borderline empty strings 143 | sub_contents.append(NavigableString(sub_content)) 144 | 145 | # Also add the newline character as NS 146 | sub_contents.append(NavigableString("\n")) 147 | 148 | # Replace the original content with splits 149 | content.replace_with(*sub_contents[:-1]) 150 | 151 | for pre in pres: 152 | # Init the starts 153 | start_idx = 0 154 | 155 | for content in pre.contents: 156 | if not isinstance(content, NavigableString): 157 | # Skip non-navigable strings 158 | continue 159 | 160 | if content in ops: 161 | # Reset start 162 | start_idx = 0 163 | 164 | # If keyword is an operator, wrap with OP_CLS 165 | new_content = f'{content}' 166 | content.replace_with(BeautifulSoup(new_content, "html.parser")) 167 | continue 168 | 169 | # Get the start keyword if it exists 170 | start = [ 171 | sub_start 172 | for start in starts 173 | for sub_start_idx, sub_start in enumerate(start.split()) 174 | if start_idx == sub_start_idx and content == sub_start 175 | ] 176 | 177 | # Increment start idx 178 | start_idx += 1 179 | 180 | if len(start) > 0: 181 | # If keyword is a start 182 | new_content = f'{start[0]}' 183 | content.replace_with(BeautifulSoup(new_content, "html.parser")) 184 | continue 185 | 186 | # Check if any of the keywords from config matches 187 | is_kwd = [content.startswith(kwd) for _, kwd in flat_kwds] 188 | 189 | if any(is_kwd): 190 | # Add the corresponding keyword class 191 | cls, _ = flat_kwds[is_kwd.index(True)] 192 | new_content = f'{content}' 193 | else: 194 | # Add the default class if no keyword is found 195 | new_content = f'{content}' 196 | 197 | # Replace the original content with the new one 198 | content.replace_with(BeautifulSoup(new_content, "html.parser")) 199 | 200 | # Prettify soup 201 | soup.prettify() 202 | 203 | def configure_section_icons(self, soup: BeautifulSoup): 204 | # 205 | as_ = soup.select("nav.navbar-nav li.nav-item a") 206 | 207 | for a in as_: 208 | for name, icon_cls in self.conf["SECTION_ICONS"].items(): 209 | if a.get_text().strip() != name: 210 | continue 211 | 212 | # Add an icon to the navbar section item 213 | a.parent["style"] = "margin-right: 0.5em;" 214 | icon = f"" 215 | a_new = f"{icon}{name}" 216 | a.string.replace_with(BeautifulSoup(a_new, "html.parser")) 217 | 218 | as_ = soup.select("div#index-toctree li.toctree-l1 a") 219 | 220 | for a in as_: 221 | for content in a.contents: 222 | if isinstance(content, Tag) and content.name == "span": 223 | content["style"] = "margin-right: 0.5em;" 224 | continue 225 | elif ( 226 | not isinstance(content, NavigableString) 227 | or content.strip() not in self.conf["SECTION_ICONS"].keys() 228 | ): 229 | continue 230 | 231 | # Modify the underline of the toctree section item 232 | content.parent["style"] = "text-decoration: none;" 233 | content.replace_with( 234 | BeautifulSoup(f"{content.get_text().strip()}", "html.parser") 235 | ) 236 | 237 | soup.prettify() 238 | 239 | def edit_html(self, app: Sphinx): 240 | if app.builder.format != "html": 241 | return 242 | 243 | for pagename in app.env.found_docs: 244 | if not isinstance(pagename, str): 245 | continue 246 | 247 | with (Path(app.outdir) / f"{pagename}.html").open("r") as f: 248 | # Parse HTML using BeautifulSoup html parser 249 | soup = BeautifulSoup(f.read(), "html.parser") 250 | 251 | self.align_rowspans(soup) 252 | self.keep_only_data(soup) 253 | self.add_collapse_ids(soup) 254 | self.process_in_page_toc(soup) 255 | 
self.break_long_signatures(soup) 256 | self.customize_code_block_colors_python(soup) 257 | self.customize_code_block_colors_bash(soup) 258 | self.configure_section_icons(soup) 259 | 260 | with (Path(app.outdir) / f"{pagename}.html").open("w") as f: 261 | # Write back HTML 262 | f.write(str(soup)) 263 | 264 | def __call__(self, app, exception): 265 | self.edit_html(app) 266 | -------------------------------------------------------------------------------- /docs/docs/features.rst: -------------------------------------------------------------------------------- 1 | :fas:`gears` Features 2 | ===================== 3 | 4 | The following *tasks* are supported: 5 | 6 | * **Classification** - binary classification of the presence of glasses and their types. 7 | * **Detection** - binary detection of worn/standalone glasses and eye area. 8 | * **Segmentation** - binary segmentation of glasses and their parts. 9 | 10 | Each :attr:`task` has multiple :attr:`kinds` (task categories) and model :attr:`sizes` (architectures with pre-trained weights). 11 | 12 | Classification 13 | -------------- 14 | 15 | .. table:: Classification Kinds 16 | :widths: 15 31 18 18 18 17 | :name: classification-kinds 18 | 19 | +----------------+-------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 20 | | **Kind** | **Description** | **Examples** | 21 | +================+=====================================+=============================================================+=============================================================+=============================================================+ 22 | | ``anyglasses`` | Identifies any kind of glasses, | .. image:: ../_static/img/classification-eyeglasses-pos.jpg | .. image:: ../_static/img/classification-sunglasses-pos.jpg | .. image:: ../_static/img/classification-no-glasses-neg.jpg | 23 | | | googles, or spectacles. +-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+ 24 | | | | .. centered:: Positive | .. centered:: Positive | .. centered:: Negative | 25 | +----------------+-------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+ 26 | | ``eyeglasses`` | Identifies only transparent glasses | .. image:: ../_static/img/classification-eyeglasses-pos.jpg | .. image:: ../_static/img/classification-sunglasses-neg.jpg | .. image:: ../_static/img/classification-no-glasses-neg.jpg | 27 | | | (here referred as *eyeglasses*) +-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+ 28 | | | | .. centered:: Positive | .. centered:: Negative | .. centered:: Negative | 29 | +----------------+-------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+ 30 | | ``sunglasses`` | Identifies only opaque and | .. image:: ../_static/img/classification-eyeglasses-neg.jpg | .. image:: ../_static/img/classification-sunglasses-pos.jpg | .. 
image:: ../_static/img/classification-no-glasses-neg.jpg | 31 | | | semi-transparent glasses (here +-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+ 32 | | | referred as *sunglasses*) | .. centered:: Negative | .. centered:: Positive | .. centered:: Negative | 33 | +----------------+-------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+ 34 | | ``shadows`` | Identifies cast shadows (only | .. image:: ../_static/img/classification-shadows-pos.jpg | .. image:: ../_static/img/classification-shadows-neg.jpg | .. image:: ../_static/img/classification-no-glasses-neg.jpg | 35 | | | shadows of (any) glasses frames) +-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+ 36 | | | | .. centered:: Positive | .. centered:: Negative | .. centered:: Negative | 37 | +----------------+-------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+-------------------------------------------------------------+ 38 | 39 | .. admonition:: Check classifier performances 40 | :class: tip 41 | 42 | * `Performance Information of the Pre-trained Classifiers `_: performance of each :attr:`kind`. 43 | * `Size Information of the Pre-trained Classifiers `_: efficiency of each :attr:`size`. 44 | 45 | Detection 46 | --------- 47 | 48 | .. table:: Detection Kinds 49 | :widths: 15 31 18 18 18 50 | :name: detection-kinds 51 | 52 | +----------+--------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------+ 53 | | **Kind** | **Description** | **Examples** | 54 | +==========+======================================================================================+================================================+================================================+================================================+ 55 | | ``eyes`` | Detects only the eye region, no glasses. | .. image:: ../_static/img/detection-eyes-0.jpg | .. image:: ../_static/img/detection-eyes-1.jpg | .. image:: ../_static/img/detection-eyes-2.jpg | 56 | +----------+--------------------------------------------------------------------------------------+------------------------------------------------+------------------------------------------------+------------------------------------------------+ 57 | | ``solo`` | Detects any glasses in the wild, i.e., standalone glasses that are placed somewhere. | .. image:: ../_static/img/detection-solo-0.jpg | .. image:: ../_static/img/detection-solo-1.jpg | .. image:: ../_static/img/detection-solo-2.jpg | 58 | +----------+--------------------------------------------------------------------------------------+------------------------------------------------+------------------------------------------------+------------------------------------------------+ 59 | | ``worn`` | Detects any glasses worn by people but can also detect non-worn glasses. | .. image:: ../_static/img/detection-worn-0.jpg | .. 
image:: ../_static/img/detection-worn-1.jpg | .. image:: ../_static/img/detection-worn-2.jpg | 60 | +----------+--------------------------------------------------------------------------------------+------------------------------------------------+------------------------------------------------+------------------------------------------------+ 61 | 62 | .. admonition:: Check detector performances 63 | :class: tip 64 | 65 | * `Performance Information of the Pre-trained Detectors `_: performance of each :attr:`kind`. 66 | * `Size Information of the Pre-trained Detectors `_: efficiency of each :attr:`size`. 67 | 68 | Segmentation 69 | ------------ 70 | 71 | .. table:: Segmentation Kinds 72 | :widths: 15 31 18 18 18 73 | :name: segmentation-kinds 74 | 75 | +-------------+----------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------+ 76 | | **Kind** | **Description** | **Examples** | 77 | +=============+====================================================================================================+======================================================+======================================================+======================================================+ 78 | | ``frames`` | Segments frames (including legs) of any glasses | .. image:: ../_static/img/segmentation-frames-0.jpg | .. image:: ../_static/img/segmentation-frames-1.jpg | .. image:: ../_static/img/segmentation-frames-2.jpg | 79 | +-------------+----------------------------------------------------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+ 80 | | ``full`` | Segments full glasses, i.e., lenses and the whole frame | .. image:: ../_static/img/segmentation-full-0.jpg | .. image:: ../_static/img/segmentation-full-1.jpg | .. image:: ../_static/img/segmentation-full-2.jpg | 81 | +-------------+----------------------------------------------------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+ 82 | | ``legs`` | Segments only frame legs of standalone glasses | .. image:: ../_static/img/segmentation-legs-0.jpg | .. image:: ../_static/img/segmentation-legs-1.jpg | .. image:: ../_static/img/segmentation-legs-2.jpg | 83 | +-------------+----------------------------------------------------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+ 84 | | ``lenses`` | Segments lenses of any glasses (both transparent and opaque). | .. image:: ../_static/img/segmentation-lenses-0.jpg | .. image:: ../_static/img/segmentation-lenses-1.jpg | .. 
image:: ../_static/img/segmentation-lenses-2.jpg | 85 | +-------------+----------------------------------------------------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+ 86 | | ``shadows`` | Segments cast shadows on the skin by the glasses frames only (does not consider opaque lenses). | .. image:: ../_static/img/segmentation-shadows-0.jpg | .. image:: ../_static/img/segmentation-shadows-1.jpg | .. image:: ../_static/img/segmentation-shadows-2.jpg | 87 | +-------------+----------------------------------------------------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+ 88 | | ``smart`` | Segments visible glasses parts: like ``full`` but does not segment lenses if they are transparent. | .. image:: ../_static/img/segmentation-smart-0.jpg | .. image:: ../_static/img/segmentation-smart-1.jpg | .. image:: ../_static/img/segmentation-smart-2.jpg | 89 | +-------------+----------------------------------------------------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+------------------------------------------------------+ 90 | 91 | .. admonition:: Check segmenter performances 92 | :class: tip 93 | 94 | * `Performance Information of the Pre-trained Segmenters `_: performance of each :attr:`kind`. 95 | * `Size Information of the Pre-trained Segmenters `_: efficiency of each :attr:`size`. -------------------------------------------------------------------------------- /src/glasses_detector/_data/augmenter_mixin.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import albumentations as A 4 | import numpy as np 5 | import skimage.transform as st 6 | import torch 7 | from albumentations.pytorch import ToTensorV2 8 | from PIL import Image 9 | 10 | 11 | class ToTensor(ToTensorV2): 12 | def apply_to_mask(self, mask, **params): 13 | return torch.from_numpy((mask > 127).astype(np.float32)) 14 | 15 | 16 | class AugmenterMixin: 17 | @staticmethod 18 | def default_augmentations() -> list[A.BasicTransform]: 19 | return [ 20 | A.OneOf( 21 | [ 22 | A.VerticalFlip(), 23 | A.HorizontalFlip(), 24 | A.RandomRotate90(), 25 | A.Transpose(), 26 | ], 27 | p=0.75, 28 | ), 29 | A.OneOf( 30 | [ 31 | A.PiecewiseAffine(), 32 | A.ShiftScaleRotate(), 33 | A.ElasticTransform(), 34 | A.OpticalDistortion(distort_limit=0.1, shift_limit=0.1), 35 | A.GridDistortion(distort_limit=0.5), 36 | ] 37 | ), 38 | A.OneOf( 39 | [ 40 | A.RandomBrightnessContrast(), 41 | A.ColorJitter(), 42 | A.HueSaturationValue(), 43 | A.RandomGamma(), 44 | A.CLAHE(), 45 | A.RGBShift(), 46 | ] 47 | ), 48 | A.OneOf( 49 | [ 50 | A.Blur(), 51 | A.GaussianBlur(), 52 | A.MedianBlur(), 53 | A.GaussNoise(), 54 | ] 55 | ), 56 | A.OneOf( 57 | [ 58 | A.RandomResizedCrop(256, 256, p=0.4), 59 | A.RandomSizedCrop((10, 131), 256, 256, p=0.4), 60 | A.RandomCrop(height=200, width=200, p=0.2), 61 | ], 62 | p=0.25, 63 | ), 64 | A.PadIfNeeded(min_height=256, min_width=256, always_apply=True), 65 | A.CoarseDropout(max_holes=10, max_height=8, max_width=8, p=0.2), 66 | A.Normalize(), 67 | ToTensor(), 68 | ] 69 | 70 | @staticmethod 71 | def minimal_augmentations() -> 
list[A.BasicTransform]: 72 | return [ 73 | A.OneOf( 74 | [ 75 | A.VerticalFlip(), 76 | A.HorizontalFlip(), 77 | A.RandomRotate90(), 78 | A.Transpose(), 79 | ], 80 | p=0.1, 81 | ), 82 | A.OneOf( 83 | [ 84 | A.PiecewiseAffine((0.02, 0.03)), 85 | A.ShiftScaleRotate((-0.02, 0.02), 0.05), 86 | A.ElasticTransform(sigma=20, alpha_affine=20), 87 | A.OpticalDistortion(distort_limit=0.02, shift_limit=0.02), 88 | A.GridDistortion(num_steps=3, distort_limit=0.1), 89 | ], 90 | p=0.1, 91 | ), 92 | A.OneOf( 93 | [ 94 | A.RandomBrightnessContrast(0.05, 0.05), 95 | A.ColorJitter(0.05, 0.05, 0.05), 96 | A.HueSaturationValue(5, 10, 5), 97 | A.RandomGamma((80, 100)), 98 | A.CLAHE(2, (3, 3)), 99 | A.RGBShift(5, 5, 5), 100 | ], 101 | p=0.1, 102 | ), 103 | A.OneOf( 104 | [ 105 | A.Blur((3, 3)), 106 | A.GaussianBlur((3, 3)), 107 | A.MedianBlur((3, 3)), 108 | A.GaussNoise((5, 10)), 109 | ], 110 | p=0.1, 111 | ), 112 | A.OneOf( 113 | [ 114 | A.RandomResizedCrop(256, 256), 115 | A.RandomSizedCrop((10, 131), 256, 256), 116 | ], 117 | p=0.1, 118 | ), 119 | A.Normalize(), 120 | ToTensor(), 121 | ] 122 | 123 | @classmethod 124 | def create_transform( 125 | cls, 126 | is_train: bool = False, 127 | **kwargs, 128 | ) -> A.Compose: 129 | # Get the list of default augmentations 130 | transform = cls.default_augmentations() 131 | kwargs = deepcopy(kwargs) 132 | 133 | if kwargs.pop("has_bbox", False): 134 | # Add bbox params 135 | kwargs.setdefault( 136 | "bbox_params", 137 | A.BboxParams( 138 | format="pascal_voc", 139 | label_fields=["bbcats"], 140 | min_visibility=0.1, 141 | **kwargs.pop("bbox_kwargs", {}), 142 | ), 143 | ) 144 | if isinstance(transform[-3], A.CoarseDropout): 145 | # CoarseDropout not supported with bbox_params 146 | transform.pop(-3) 147 | 148 | if kwargs.pop("has_keys", False): 149 | # Add keypoint params 150 | kwargs.setdefault( 151 | "keypoint_params", 152 | A.KeypointParams( 153 | format="xy", 154 | label_fields=["kpcats"], 155 | remove_invisible=False, 156 | **kwargs.pop("keys_kwargs", {}), 157 | ), 158 | ) 159 | 160 | if not is_train: 161 | # Only keep the last two 162 | transform = transform[-2:] 163 | 164 | return A.Compose(transform, **kwargs) 165 | 166 | @staticmethod 167 | def load_image( 168 | image: str | Image.Image | np.ndarray, 169 | resize: tuple[int, int] | None = None, 170 | is_mask: bool = False, 171 | return_orig_size: bool = False, 172 | ) -> np.ndarray: 173 | if isinstance(image, str): 174 | # Image is given as path 175 | image = Image.open(image) 176 | 177 | if isinstance(image, Image.Image): 178 | # Image is not a numpy array 179 | image = np.array(image) 180 | 181 | if is_mask: 182 | # Convert image to black & white and ensure only 1 channel 183 | image = ((image > 127).any(2) if image.ndim > 2 else (image > 127)) * 255 184 | elif image.ndim == 2: 185 | # Image isn't a mask, convert it to RGB 186 | image = np.stack([image] * 3, axis=-1) 187 | 188 | if resize is not None: 189 | # Resize image to new (w, h), preserv range from 0 to 255 190 | image = st.resize(image, resize[::-1], preserve_range=True) 191 | 192 | # Convert image to UINT8 type 193 | image = image.astype(np.uint8) 194 | 195 | if return_orig_size: 196 | # Original size as well 197 | return image, image.shape[:2][::-1] 198 | 199 | return image 200 | 201 | @staticmethod 202 | def load_boxes( 203 | boxes: str | list[list[int | float | str]], 204 | resize: tuple[int, int] | None = None, 205 | img_size: tuple[int, int] | None = None, 206 | ) -> list[list[float]]: 207 | if isinstance(boxes, str): 208 | with open(boxes, "r") 
as f: 209 | # Each line is bounding box: "x_min y_min x_max y_max" 210 | boxes = [xyxy.strip().split() for xyxy in f.readlines()] 211 | 212 | # Convert each coordinate in each bbox to float 213 | boxes = [list(map(float, xyxy)) for xyxy in boxes] 214 | 215 | if img_size is None: 216 | if resize is not None: 217 | raise ValueError("img_size must be provided if resize is not None") 218 | 219 | return boxes 220 | 221 | for i, box in enumerate(boxes): 222 | if box[2] <= box[0]: 223 | # Ensure x_min < x_max <= img_size[0] 224 | boxes[i][0] = min(box[0], img_size[0] - 1) 225 | boxes[i][2] = boxes[i][0] + 1 226 | 227 | if box[3] <= box[1]: 228 | # Ensure y_min < y_max <= img_size[1] 229 | boxes[i][1] = min(box[1], img_size[1] - 1) 230 | boxes[i][3] = boxes[i][1] + 1 231 | 232 | if resize is not None: 233 | # Convert boxes to new (w, h) 234 | boxes = [ 235 | [ 236 | box[0] * resize[0] / img_size[0], 237 | box[1] * resize[1] / img_size[1], 238 | box[2] * resize[0] / img_size[0], 239 | box[3] * resize[1] / img_size[1], 240 | ] 241 | for box in boxes 242 | ] 243 | 244 | return boxes 245 | 246 | @staticmethod 247 | def load_keypoints( 248 | keypoints: str | list[list[int | float | str]], 249 | resize: tuple[int, int] | None = None, 250 | img_size: tuple[int, int] | None = None, 251 | ) -> list[list[float]]: 252 | if isinstance(keypoints, str): 253 | with open(keypoints, "r") as f: 254 | # Each line is keypoint: "x y" 255 | keypoints = [xy.strip().split() for xy in f.readlines()] 256 | 257 | # Convert each coordinate in each keypoint to float 258 | keypoints = [list(map(float, xy)) for xy in keypoints] 259 | 260 | if img_size is None: 261 | if resize is not None: 262 | raise ValueError("img_size must be provided if resize is not None") 263 | 264 | return keypoints 265 | 266 | if resize is not None: 267 | # Convert keypoints to new (w, h) 268 | keypoints = [ 269 | [ 270 | keypoint[0] * resize[0] / img_size[0], 271 | keypoint[1] * resize[1] / img_size[1], 272 | ] 273 | for keypoint in keypoints 274 | ] 275 | 276 | return keypoints 277 | 278 | @classmethod 279 | def load_transform( 280 | cls, 281 | image: str | Image.Image | np.ndarray, 282 | masks: list[str | Image.Image | np.ndarray] = [], 283 | boxes: list[str | list[list[int | float | str]]] = [], 284 | bcats: list[str] = [], 285 | keys: list[str | list[list[int | float | str]]] = [], 286 | kcats: list[str] = [], 287 | resize: tuple[int, int] | None = None, 288 | transform: A.Compose | bool = False, # False means test/val 289 | ) -> torch.Tensor | tuple[torch.Tensor]: 290 | # Load the image and resize if needed (also return original size) 291 | image, orig_size = cls.load_image(image, resize, return_orig_size=True) 292 | transform_kwargs = {"image": image} 293 | 294 | if isinstance(transform, bool): 295 | # Load transform (train or test is based on bool val) 296 | transform = cls.create_transform(is_train=transform) 297 | 298 | if masks != []: 299 | # Load masks and add to transform kwargs 300 | masks = [cls.load_image(m, resize, is_mask=True) for m in masks] 301 | transform_kwargs.update({"masks": masks}) 302 | 303 | if boxes != []: 304 | # Initialize flat boxes and cats 305 | flat_boxes, flat_bcats = [], [] 306 | 307 | for i, b in enumerate(boxes): 308 | # Load boxes and add to transform kwargs 309 | b = cls.load_boxes(b, resize, orig_size) 310 | flat_boxes.extend(b) 311 | 312 | if bcats != []: 313 | # For each box, add corresponding cat 314 | flat_bcats.extend([bcats[i]] * len(b)) 315 | 316 | # Add boxes to transform kwargs 317 | 
transform_kwargs.update({"bboxes": flat_boxes}) 318 | 319 | if bcats != []: 320 | # Also add cats to transform kwargs 321 | transform_kwargs.update({"bbcats": flat_bcats}) 322 | 323 | if keys != []: 324 | # Initialize flat keypoints and cats 325 | flat_keys, flat_kcats = [], [] 326 | 327 | for i, k in enumerate(keys): 328 | # Load keypoints and add to transform kwargs 329 | k = cls.load_keypoints(k, resize, orig_size) 330 | flat_keys.extend(k) 331 | 332 | if kcats != []: 333 | # For each keypoint, add corresponding cat 334 | flat_kcats.extend([kcats[i]] * len(k)) 335 | 336 | # Add keypoints to transform kwargs, update cats 337 | transform_kwargs.update({"keypoints": flat_keys}) 338 | 339 | if kcats != []: 340 | # Also add cats to transform kwargs 341 | transform_kwargs.update({"kpcats": flat_kcats}) 342 | 343 | # Transform everything, generate return list 344 | transformed = transform(**transform_kwargs) 345 | return_list = [transformed["image"]] 346 | 347 | for key in ["masks", "bboxes", "bbcats", "keypoints", "kpcats"]: 348 | if key not in transformed: 349 | continue 350 | 351 | if key in {"bboxes", "keypoints", "bbcats", "kpcats"}: 352 | # Convert to torch tensor if key is category or bbox/keypoint 353 | dtype = torch.long if key in ["bbcats", "kpcats"] else torch.float32 354 | transformed[key] = torch.tensor(transformed[key], dtype=dtype) 355 | 356 | # Add to the return list 357 | return_list.append(transformed[key]) 358 | 359 | if len(return_list) == 1: 360 | return return_list[0] 361 | 362 | return tuple(return_list) 363 | -------------------------------------------------------------------------------- /src/glasses_detector/components/pred_type.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. class:: Scalar 3 | 4 | .. data:: Scalar 5 | :noindex: 6 | :type: typing.TypeAliasType 7 | :value: StandardScalar | np.generic | np.ndarray | torch.Tensor 8 | 9 | Type alias for a scalar prediction. For more information, see 10 | :attr:`~PredType.SCALAR`. 11 | 12 | :class:`bool` | :class:`int` | :class:`float` | :class:`str` 13 | | :class:`~numpy.generic` | :class:`~numpy.ndarray` 14 | | :class:`~torch.Tensor` 15 | 16 | .. class:: Tensor 17 | 18 | .. data:: Tensor 19 | :noindex: 20 | :type: typing.TypeAliasType 21 | :value: Iterable[Scalar | Tensor] | PIL.Image.Image 22 | 23 | Type alias for a tensor prediction. For more information, see 24 | :attr:`~PredType.TENSOR`. 25 | 26 | :class:`~typing.Iterable` | :class:`~PIL.Image.Image` 27 | 28 | .. class:: Default 29 | 30 | .. data:: Default 31 | :noindex: 32 | :type: typing.TypeAliasType 33 | :value: Scalar | Tensor 34 | 35 | Type alias for a default prediction. For more information, see 36 | :attr:`~PredType.DEFAULT`. 37 | 38 | :class:`bool` | :class:`int` | :class:`float` | :class:`str` 39 | | :class:`~numpy.generic` | :class:`~numpy.ndarray` 40 | | :class:`~torch.Tensor` | :class:`~typing.Iterable` 41 | 42 | .. class:: StandardScalar 43 | 44 | .. data:: StandardScalar 45 | :noindex: 46 | :type: typing.TypeAliasType 47 | :value: bool | int | float | str 48 | 49 | Type alias for a standard scalar prediction. For more information, 50 | see :attr:`~PredType.STANDARD_SCALAR`. 51 | 52 | :class:`bool` | :class:`int` | :class:`float` | :class:`str` 53 | 54 | .. class:: StandardTensor 55 | 56 | .. data:: StandardTensor 57 | :noindex: 58 | :type: typing.TypeAliasType 59 | :value: list[StandardScalar | StandardTensor] 60 | 61 | Type alias for a standard tensor prediction. 
For more information, 62 | see :attr:`~PredType.STANDARD_TENSOR`. 63 | 64 | :class:`list` 65 | 66 | .. class:: StandardDefault 67 | 68 | .. data:: StandardDefault 69 | :noindex: 70 | :type: typing.TypeAliasType 71 | :value: StandardScalar | StandardTensor 72 | 73 | Type alias for a standard default prediction. For more information, 74 | see :attr:`~PredType.STANDARD_DEFAULT`. 75 | 76 | :class:`bool` | :class:`int` | :class:`float` | :class:`str` 77 | | :class:`list` 78 | 79 | .. class:: NonDefault[T] 80 | 81 | .. data:: NonDefault[T] 82 | :noindex: 83 | :type: typing.TypeAliasType 84 | :value: T 85 | 86 | Type alias for a non-default prediction. For more information, see 87 | :attr:`~PredType.NON_DEFAULT`. 88 | 89 | :data:`~typing.Any` 90 | 91 | .. class:: Either 92 | 93 | .. data:: Either 94 | :noindex: 95 | :type: typing.TypeAliasType 96 | :value: Default | NonDefault 97 | 98 | Type alias for either default or non-default prediction, i.e., any 99 | prediction. 100 | 101 | :data:`~typing.Any` 102 | """ 103 | from enum import Enum, auto 104 | from typing import Any, Iterable, Self, TypeGuard 105 | 106 | import numpy as np 107 | import torch 108 | from PIL import Image 109 | 110 | type Scalar = bool | int | float | str | np.generic | np.ndarray | torch.Tensor 111 | type Tensor = Iterable[Scalar | Tensor] | Image.Image 112 | type Default = Scalar | Tensor 113 | type StandardScalar = bool | int | float | str 114 | type StandardTensor = list[StandardScalar | StandardTensor] 115 | type StandardDefault = StandardScalar | StandardTensor 116 | type NonDefault[T] = T 117 | type Anything = Default | NonDefault 118 | 119 | 120 | class PredType(Enum): 121 | """Enum class for expected prediction types. 122 | 123 | This class specifies the expected prediction types mainly for 124 | classification, detection and segmentation models that work with 125 | image data. The expected types are called **Default** and there are 126 | two categories of them: 127 | 128 | 1. **Standard**: these are the basic *Python* types, i.e., 129 | :class:`bool`, :class:`int`, :class:`float`, :class:`str`, and 130 | :class:`list`. Standard types are easy to work with, e.g., they 131 | can be parsed by JSON and YAML formats. 132 | 2. **Non-standard**: these additionally contain types like 133 | :class:`numpy.ndarray`, :class:`torch.Tensor`, and 134 | :class:`PIL.Image.Image`. They are convenient due to more 135 | flexibility for model prediction outputs. In most cases, they can 136 | be converted to standard types via :meth:`standardize`. 137 | 138 | Note: 139 | The constants defined in this class are only enums, not 140 | actual classes or :class:`type` objects. For type hints, 141 | corresponding type aliases are defined in the same file as this 142 | class. 143 | 144 | Warning: 145 | The enum types are not exclusive, for example, :attr:`SCALAR` 146 | is also a :attr:`DEFAULT`. 147 | 148 | Examples 149 | -------- 150 | 151 | Type aliases (:class:`~typing.TypeAliasType` objects defined 152 | using :class:`type` keyword) corresponding to the enums of 153 | this class are defined in the same file as :class:`PredType`. They 154 | can be used to specify the expected types when defining the methods: 155 | 156 | .. code-block:: python 157 | 158 | >>> from glasses_detector.components.pred_type import StandardScalar 159 | >>> def predict_class( 160 | ... self, 161 | ... image: Image.Image, 162 | ... output_format: str = "score", 163 | ... ) -> StandardScalar: 164 | ... ... 
165 | 166 | :class:`PredType` static and class methods can be used to check 167 | the type of the prediction: 168 | 169 | .. code-block:: python 170 | 171 | >>> PredType.is_standard_scalar(1) 172 | True 173 | >>> PredType.is_standard_scalar(np.array([1])[0]) 174 | False 175 | >>> PredType.is_default(Image.fromarray(np.zeros((1, 1)))) 176 | True 177 | 178 | Finally, :meth:`standardize` can be used to convert the 179 | prediction to a standard type: 180 | 181 | .. code-block:: python 182 | 183 | >>> PredType.standardize(np.array([1, 2, 3])) 184 | [1, 2, 3] 185 | >>> PredType.standardize(Image.fromarray(np.zeros((1, 1)))) 186 | [[0.0]] 187 | """ 188 | 189 | """Enum: Scalar type. A prediction is considered to be a scalar if""" 190 | 191 | SCALAR = auto() 192 | """ 193 | PredType: Scalar type. A prediction is considered to be a scalar if 194 | it is one of the following types: 195 | 196 | * :class:`bool` 197 | * :class:`int` 198 | * :class:`float` 199 | * :class:`str` 200 | * :class:`numpy.generic` 201 | * :class:`numpy.ndarray` with ``ndim == 0`` 202 | * :class:`torch.Tensor` with ``ndim == 0`` 203 | 204 | :meta hide-value: 205 | """ 206 | 207 | TENSOR = auto() 208 | """ 209 | PredType: Tensor type. A prediction is considered to be a tensor if 210 | it is one of the following types: 211 | 212 | * :class:`PIL.Image.Image` 213 | * :class:`~typing.Iterable` of scalars or tensors of any iterable 214 | type, including :class:`list`, :class:`tuple`, 215 | :class:`~typing.Collection`, :class:`numpy.ndarray` and 216 | :class:`torch.Tensor` objects, and any other iterables. 217 | 218 | :meta hide-value: 219 | """ 220 | 221 | DEFAULT = auto() 222 | """ 223 | PredType: Default type. A prediction is considered to be a default 224 | type if it is one of the following types: 225 | 226 | * Any of the types defined in :attr:`SCALAR`. 227 | * Any of the types defined in :attr:`TENSOR`. 228 | 229 | :meta hide-value: 230 | """ 231 | 232 | STANDARD_SCALAR = auto() 233 | """ 234 | PredType: Standard scalar type. A prediction is considered to be a 235 | standard scalar if it is one of the following types: 236 | 237 | * :class:`bool` 238 | * :class:`int` 239 | * :class:`float` 240 | * :class:`str` 241 | 242 | :meta hide-value: 243 | """ 244 | 245 | STANDARD_TENSOR = auto() 246 | """ 247 | PredType: Standard tensor type. A prediction is considered to be a 248 | standard tensor if it is one of the following types: 249 | 250 | * :class:`list` of standard scalars or standard tensors. No other 251 | iterables than lists are allowed. 252 | 253 | :meta hide-value: 254 | """ 255 | 256 | STANDARD_DEFAULT = auto() 257 | """ 258 | PredType: Standard default type. A prediction is considered to be a 259 | standard default type if it is one of the following types: 260 | 261 | * Any of the types defined in :attr:`STANDARD_SCALAR`. 262 | * Any of the types defined in :attr:`STANDARD_TENSOR`. 263 | 264 | :meta hide-value: 265 | """ 266 | 267 | NON_DEFAULT = auto() 268 | """ 269 | PredType: Non-default type. A prediction is considered to be a 270 | non-default type if it is not a default type, i.e., it is not any of 271 | the types defined in :attr:`DEFAULT`. 272 | 273 | :meta hide-value: 274 | """ 275 | 276 | @staticmethod 277 | def is_scalar(pred: Any) -> TypeGuard[Scalar]: 278 | """Checks if the prediction is a **scalar**. 279 | 280 | .. seealso:: 281 | 282 | :attr:`SCALAR` 283 | 284 | Args: 285 | pred: The value to check. 286 | 287 | Returns: 288 | :data:`True` if the value is a **scalar**, :data:`False` 289 | otherwise. 
290 | """ 291 | return isinstance(pred, (bool, int, float, str, np.generic)) or ( 292 | isinstance(pred, (torch.Tensor, np.ndarray)) and pred.ndim == 0 293 | ) 294 | 295 | @staticmethod 296 | def is_standard_scalar(pred: Any) -> TypeGuard[StandardScalar]: 297 | """Checks if the prediction is a **standard scalar**. 298 | 299 | .. seealso:: 300 | 301 | :attr:`STANDARD_SCALAR` 302 | 303 | Args: 304 | pred: The value to check. 305 | 306 | Returns: 307 | :data:`True` if the value is a **standard scalar**, 308 | :data:`False` otherwise. 309 | """ 310 | return isinstance(pred, (bool, int, float, str)) 311 | 312 | @classmethod 313 | def is_tensor(cls, pred: Any) -> TypeGuard[Tensor]: 314 | """Checks if the prediction is a **tensor**. 315 | 316 | .. seealso:: 317 | 318 | :attr:`TENSOR` 319 | 320 | Args: 321 | pred: The value to check. 322 | 323 | Returns: 324 | :data:`True` if the value is a **tensor**, 325 | :data:`False` otherwise. 326 | """ 327 | return isinstance(pred, Image.Image) or ( 328 | isinstance(pred, Iterable) 329 | and not cls.is_scalar(pred) 330 | and all([cls.is_scalar(p) or cls.is_tensor(p) for p in pred]) 331 | ) 332 | 333 | @classmethod 334 | def is_standard_tensor(cls, pred: Any) -> TypeGuard[StandardTensor]: 335 | """Checks if the prediction is a **standard tensor**. 336 | 337 | .. seealso:: 338 | 339 | :attr:`STANDARD_TENSOR` 340 | 341 | Args: 342 | pred: The value to check. 343 | 344 | Returns: 345 | :data:`True` if the value is a **standard tensor**, 346 | :data:`False` otherwise. 347 | """ 348 | return isinstance(pred, list) and all( 349 | [cls.is_standard_scalar(p) or cls.is_standard_tensor(p) for p in pred] 350 | ) 351 | 352 | @classmethod 353 | def is_default(cls, pred: Any) -> TypeGuard[Default]: 354 | """Checks if the prediction is a **default** type. 355 | 356 | .. seealso:: 357 | 358 | :attr:`DEFAULT` 359 | 360 | Args: 361 | pred: The value to check. 362 | 363 | Returns: 364 | :data:`True` if the type of the value is **default**, 365 | :data:`False` otherwise. 366 | """ 367 | return cls.is_scalar(pred) or cls.is_tensor(pred) 368 | 369 | @classmethod 370 | def is_standard_default(cls, pred: Any) -> TypeGuard[StandardDefault]: 371 | """Checks if the prediction is a **standard default** type. 372 | 373 | .. seealso:: 374 | 375 | :attr:`STANDARD_DEFAULT` 376 | 377 | Args: 378 | pred: The value to check. 379 | 380 | Returns: 381 | :data:`True` if the type of the value is **standard 382 | default**, :data:`False` otherwise. 383 | """ 384 | return cls.is_standard_scalar(pred) or cls.is_standard_tensor(pred) 385 | 386 | @classmethod 387 | def check(cls, pred: Any) -> Self: 388 | """Checks the type of the prediction and returns its enum. 389 | 390 | Checks the type of the prediction and returns the corresponding 391 | enum of the lowest type category. First, it checks if the 392 | prediction is a **standard scalar** or a regular **scalar** (in 393 | that order). If not, it checks if the prediction is a **standard 394 | tensor** or a regular **tensor** (in that order). Finally, if 395 | none of the previous checks are successful, it returns 396 | :attr:`NON_DEFAULT`. 397 | 398 | Note: 399 | All four types, i.e., :attr:`STANDARD_SCALAR`, 400 | :attr:`SCALAR`, :attr:`STANDARD_TENSOR`, and :attr:`TENSOR` 401 | are subclasses of :attr:`DEFAULT`. 402 | 403 | Args: 404 | pred: The value to check. 405 | 406 | Returns: 407 | The corresponding enum of the lowest type category or 408 | :attr:`NON_DEFAULT` if no **default** category is 409 | applicable. 
410 | """ 411 | if cls.is_standard_scalar(pred): 412 | return cls.STANDARD_SCALAR 413 | elif cls.is_scalar(pred): 414 | return cls.SCALAR 415 | elif cls.is_standard_tensor(pred): 416 | return cls.STANDARD_TENSOR 417 | elif cls.is_tensor(pred): 418 | return cls.TENSOR 419 | else: 420 | return cls.NON_DEFAULT 421 | 422 | @classmethod 423 | def standardize(cls, pred: Default) -> StandardDefault: 424 | """Standardize the prediction. 425 | 426 | Standardize the prediction to a **standard default** type. If 427 | the prediction is already a **standard default** type, it is 428 | returned as-is. Otherwise, it is converted to a **standard 429 | default** type using the following rules: 430 | 431 | * :class:`bool`, :class:`int`, :class:`float`, and :class:`str` 432 | are returned as-is. 433 | * :class:`numpy.generic` and :class:`numpy.ndarray` with 434 | :attr:`~numpy.ndarray.ndim` = ``0`` are converted to 435 | **standard scalars**. 436 | * :class:`torch.Tensor` with :attr:`~numpy.ndarray.ndim` = ``0`` 437 | is converted to **standard scalars**. 438 | * :class:`numpy.ndarray` and :class:`torch.Tensor` with 439 | :attr:`~numpy.ndarray.ndim` > ``0`` are converted to 440 | **standard tensors**. 441 | * :class:`PIL.Image.Image` is converted to **standard tensors** 442 | by converting it to a :class:`numpy.ndarray` and then 443 | applying the previous rule. 444 | * All other iterables are converted to **standard tensors** by 445 | applying the previous rule to each element. 446 | 447 | Args: 448 | pred: The **default** prediction to standardize. 449 | 450 | Raises: 451 | ValueError: If the prediction cannot be standardized. This 452 | can happen if a prediction is not **default** or if a 453 | class, such as :class:`torch.Tensor` or 454 | :class:`numpy.ndarray` returns a scalar that is not of 455 | type defined in :attr:`SCALAR`. 456 | 457 | Returns: 458 | The standardized prediction. 459 | """ 460 | if isinstance(pred, (bool, int, float, str)): 461 | return pred 462 | elif isinstance(pred, np.generic): 463 | return cls.standardize(pred.item()) 464 | elif isinstance(pred, (np.ndarray, torch.Tensor)) and pred.ndim == 0: 465 | return cls.standardize(pred.item()) 466 | elif isinstance(pred, Image.Image): 467 | return np.asarray(pred).tolist() 468 | elif isinstance(pred, Iterable): 469 | return [cls.standardize(item) for item in pred] 470 | else: 471 | raise ValueError(f"Cannot standardize {type(pred)}") 472 | --------------------------------------------------------------------------------
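A short usage sketch complementing the docstring examples in ``pred_type.py`` above: it shows how ``PredType.check`` buckets a prediction into the lowest matching category before ``standardize`` converts it into JSON-friendly types. The import path follows the module listed above; the sample values are illustrative only::

    import numpy as np
    import torch

    from glasses_detector.components.pred_type import PredType

    # check() returns the lowest matching enum category
    assert PredType.check(True) is PredType.STANDARD_SCALAR              # plain Python bool
    assert PredType.check(torch.tensor(0.9)) is PredType.SCALAR          # 0-dim tensor
    assert PredType.check([[0, 1], [1, 0]]) is PredType.STANDARD_TENSOR  # nested lists of ints
    assert PredType.check(np.zeros((2, 2))) is PredType.TENSOR           # ndarray with ndim > 0

    # standardize() recursively converts tensors to plain Python scalars/lists
    assert PredType.standardize(torch.tensor([1, 2, 3])) == [1, 2, 3]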
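A minimal sketch of how the ``AugmenterMixin`` listed above is typically driven: ``load_transform`` loads the image (and optional masks), resizes everything, and pushes it through the albumentations pipeline built by ``create_transform``, with ``transform=True`` selecting the training augmentations and ``False`` the validation ones. The mask path below is a placeholder, and importing the underscore-prefixed ``_data`` module directly is shown only for illustration; inside the package the mixin is consumed by the dataset classes::

    from glasses_detector._data.augmenter_mixin import AugmenterMixin  # internal module

    image_path = "data/demo/0.jpg"  # any RGB image
    mask_path = "path/to/mask.png"  # placeholder: black-and-white mask of the same scene

    # Returns (image, masks): image is a normalized CHW float tensor, masks is a
    # list of HxW float tensors thresholded at 127 by the custom ToTensor above.
    image, masks = AugmenterMixin.load_transform(
        image_path,
        masks=[mask_path],
        resize=(256, 256),  # (width, height), applied before augmentation
        transform=True,     # builds the training pipeline via create_transform(is_train=True)
    )
    print(image.shape)     # torch.Size([3, 256, 256])
    print(masks[0].shape)  # torch.Size([256, 256])

When bounding boxes are supplied instead, ``load_boxes`` rescales them linearly to the new size: a box ``[10, 20, 110, 220]`` in a 200x400 image resized to ``(100, 100)`` becomes ``[5.0, 5.0, 55.0, 55.0]``. Note also that ``create_transform(has_bbox=True)`` drops ``CoarseDropout`` from the pipeline because it is not supported together with ``bbox_params``.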
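To connect the ``kind`` tables from ``features.rst`` above to code, here is a hypothetical selection sketch. The ``GlassesSegmenter`` class name, its ``kind`` argument, and the ``process_file`` method are assumptions inferred from the API reference layout (``glasses_detector.segmenter`` and friends), not from code listed here::

    # Hypothetical usage: the class name, argument and method are assumed, not shown above.
    from glasses_detector import GlassesSegmenter

    # Pick any kind from the segmentation table, e.g. "frames", "full", "legs",
    # "lenses", "shadows" or "smart".
    segmenter = GlassesSegmenter(kind="frames")
    segmenter.process_file("data/demo/0.jpg", "frames-mask.jpg")  # assumed single-image helper

The same pattern would presumably apply to the classification kinds (e.g. ``shadows``) and the detection kinds (``eyes``, ``solo``, ``worn``) through the corresponding classifier and detector classes.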