├── .conda └── meta.yaml ├── .github ├── CODEOWNERS ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ └── feature_request.yml ├── PULL_REQUEST_TEMPLATE.md ├── SECURITY.md ├── collect_env.py ├── dependabot.yml ├── labeler.yml ├── pull-labels.yml ├── release.yml ├── verify_deps_sync.py ├── verify_labels.py └── workflows │ ├── build.yml │ ├── page-build.yml │ ├── pr-merged.yml │ ├── push.yml │ ├── release.yml │ ├── style.yml │ ├── tests.yml │ └── triage.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── api ├── Dockerfile ├── Makefile ├── README.md ├── app │ ├── config.py │ ├── main.py │ ├── routes │ │ └── classification.py │ ├── schemas.py │ └── vision.py ├── docker-compose.yml ├── pyproject.toml ├── tests │ ├── conftest.py │ └── routes │ │ └── test_classification.py └── uv.lock ├── demo ├── app.py └── requirements.txt ├── docs ├── Makefile ├── README.md ├── build.sh ├── make.bat └── source │ ├── _static │ ├── css │ │ └── custom_theme.css │ ├── images │ │ ├── favicon.ico │ │ ├── logo.png │ │ └── logo_text.png │ └── js │ │ └── custom.js │ ├── _templates │ └── function.rst │ ├── changelog.rst │ ├── conf.py │ ├── docutils.conf │ ├── index.rst │ ├── installing.rst │ ├── models.rst │ ├── models │ ├── convnext.rst │ ├── darknet.rst │ ├── darknetv2.rst │ ├── darknetv3.rst │ ├── darknetv4.rst │ ├── mobileone.rst │ ├── pyconv_resnet.rst │ ├── repvgg.rst │ ├── res2net.rst │ ├── resnet.rst │ ├── resnext.rst │ ├── rexnet.rst │ ├── sknet.rst │ └── tridentnet.rst │ ├── nn.functional.rst │ ├── nn.rst │ ├── notebooks.md │ ├── ops.rst │ ├── optim.rst │ ├── trainer.rst │ ├── transforms.rst │ ├── utils.data.rst │ └── utils.rst ├── holocron ├── __init__.py ├── models │ ├── __init__.py │ ├── checkpoints.py │ ├── classification │ │ ├── __init__.py │ │ ├── convnext.py │ │ ├── darknet.py │ │ ├── darknetv2.py │ │ ├── darknetv3.py │ │ ├── darknetv4.py │ │ ├── mobileone.py │ │ ├── pyconv_resnet.py │ │ ├── repvgg.py │ │ ├── res2net.py │ │ ├── resnet.py │ │ ├── rexnet.py │ │ ├── sknet.py │ │ └── tridentnet.py │ ├── detection │ │ ├── __init__.py │ │ ├── yolo.py │ │ ├── yolov2.py │ │ └── yolov4.py │ ├── presets.py │ ├── segmentation │ │ ├── __init__.py │ │ ├── unet.py │ │ ├── unet3p.py │ │ └── unetpp.py │ └── utils.py ├── nn │ ├── __init__.py │ ├── functional.py │ ├── init.py │ └── modules │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── attention.py │ │ ├── conv.py │ │ ├── downsample.py │ │ ├── dropblock.py │ │ ├── lambda_layer.py │ │ └── loss.py ├── ops │ ├── __init__.py │ └── boxes.py ├── optim │ ├── __init__.py │ ├── adabelief.py │ ├── adamp.py │ ├── adan.py │ ├── ademamix.py │ ├── lamb.py │ ├── lars.py │ ├── ralars.py │ ├── tadam.py │ └── wrapper.py ├── trainer │ ├── __init__.py │ ├── classification.py │ ├── core.py │ ├── detection.py │ ├── segmentation.py │ └── utils.py ├── transforms │ ├── __init__.py │ └── interpolation.py └── utils │ ├── __init__.py │ ├── data │ ├── __init__.py │ └── collate.py │ └── misc.py ├── notebooks └── README.md ├── pyproject.toml ├── references ├── README.md ├── classification │ ├── README.md │ └── train.py ├── clean_checkpoint.py ├── detection │ ├── README.md │ ├── train.py │ └── transforms.py └── segmentation │ ├── README.md │ ├── train.py │ └── transforms.py ├── scripts ├── eval_latency.py └── export_to_onnx.py ├── setup.py └── tests ├── test_models.py ├── test_models_classification.py ├── test_models_detection.py ├── test_models_segmentation.py ├── test_nn.py 
├── test_nn_activation.py ├── test_nn_attention.py ├── test_nn_conv.py ├── test_nn_downsample.py ├── test_nn_init.py ├── test_nn_loss.py ├── test_ops.py ├── test_optim.py ├── test_optim_wrapper.py ├── test_trainer.py ├── test_trainer_utils.py ├── test_transforms.py └── test_utils.py /.conda/meta.yaml: -------------------------------------------------------------------------------- 1 | # https://docs.conda.io/projects/conda-build/en/latest/resources/define-metadata.html#loading-data-from-other-files 2 | # https://github.com/conda/conda-build/pull/4480 3 | # for conda-build > 3.21.9 4 | # {% set pyproject = load_file_data('../pyproject.toml', from_recipe_dir=True) %} 5 | # {% set project = pyproject.get('project') %} 6 | # {% set urls = pyproject.get('project', {}).get('urls') %} 7 | package: 8 | name: pylocron 9 | version: "{{ environ.get('BUILD_VERSION', '0.2.2.dev0') }}" 10 | 11 | source: 12 | fn: pylocron-{{ environ.get('BUILD_VERSION', '0.2.2.dev0') }}.tar.gz 13 | url: ../dist/pylocron-{{ environ.get('BUILD_VERSION', '0.2.2.dev0') }}.tar.gz 14 | 15 | build: 16 | number: 0 17 | noarch: python 18 | script: python setup.py install --single-version-externally-managed --record=record.txt 19 | 20 | requirements: 21 | host: 22 | - python>=3.8, <4.0 23 | - setuptools 24 | 25 | run: 26 | - python>=3.8, <4.0 27 | - pytorch >=2.0.0, <3.0.0 28 | - torchvision >=0.15.0, <1.0.0 29 | - tqdm >=4.1.0 30 | - numpy >=1.17.2, <2.0.0 31 | - fastprogress >=1.0.0, <2.0.0 32 | - matplotlib >=3.0.0, <4.0.0 33 | - pillow >=8.4.0, !=9.2.0 34 | - huggingface_hub >=0.4.0 35 | 36 | test: 37 | # Python imports 38 | imports: 39 | - holocron 40 | - holocron.models 41 | - holocron.nn 42 | - holocron.ops 43 | - holocron.optim 44 | - holocron.trainer 45 | - holocron.utils 46 | requires: 47 | - python 48 | 49 | about: 50 | home: https://github.com/frgfm/Holocron 51 | license: Apache 2.0 52 | license_file: LICENSE 53 | summary: 'Modules, operations and models for computer vision in PyTorch' 54 | # description: | 55 | # {{ data['long_description'] | replace("\n", "\n ") | replace("#", '\#')}} 56 | doc_url: https://frgfm.github.io/Holocron/ 57 | dev_url: https://github.com/frgfm/Holocron 58 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # This is a comment. 2 | # Each line is a file pattern followed by one or more owners. 3 | 4 | # These owners will be the default owners for everything in 5 | # the repo. Unless a later match takes precedence, 6 | # @global-owner1 and @global-owner2 will be requested for 7 | # review when someone opens a pull request. 
8 | * @frgfm 9 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: frgfm 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with an OpenCollective account 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: 🐛 Bug report 2 | description: Create a report to help us improve the library 3 | labels: bug 4 | assignees: frgfm 5 | 6 | body: 7 | - type: markdown 8 | attributes: 9 | value: > 10 | #### Before reporting a bug, please check that the issue hasn't already been addressed in [the existing and past issues](https://github.com/frgfm/Holocron/issues?q=is%3Aissue). 11 | - type: textarea 12 | attributes: 13 | label: Bug description 14 | description: | 15 | A clear and concise description of what the bug is. 16 | 17 | Please explain the result you observed and the behavior you were expecting. 18 | placeholder: | 19 | A clear and concise description of what the bug is. 20 | validations: 21 | required: true 22 | 23 | - type: textarea 24 | attributes: 25 | label: Code snippet to reproduce the bug 26 | description: | 27 | Sample code to reproduce the problem. 28 | 29 | Please wrap your code snippet with ```` ```triple quotes blocks``` ```` for readability. 30 | placeholder: | 31 | ```python 32 | Sample code to reproduce the problem 33 | ``` 34 | validations: 35 | required: true 36 | - type: textarea 37 | attributes: 38 | label: Error traceback 39 | description: | 40 | The error message you received running the code snippet, with the full traceback. 41 | 42 | Please wrap your error message with ```` ```triple quotes blocks``` ```` for readability. 43 | placeholder: | 44 | ``` 45 | The error message you got, with the full traceback. 46 | ``` 47 | validations: 48 | required: true 49 | - type: textarea 50 | attributes: 51 | label: Environment 52 | description: | 53 | Please run the following command and paste the output below. 54 | ```sh 55 | wget https://raw.githubusercontent.com/frgfm/Holocron/main/.github/collect_env.py 56 | # For security purposes, please check the contents of collect_env.py before running it. 57 | python collect_env.py 58 | ``` 59 | validations: 60 | required: true 61 | - type: markdown 62 | attributes: 63 | value: > 64 | Thanks for helping us improve the library! 
65 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: Usage questions 4 | url: https://github.com/frgfm/Holocron/discussions 5 | about: Ask questions and discuss with other Holocron community members 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: 🚀 Feature request 2 | description: Submit a proposal/request for a new feature for Holocron 3 | labels: enhancement 4 | assignees: frgfm 5 | 6 | body: 7 | - type: textarea 8 | attributes: 9 | label: 🚀 Feature 10 | description: > 11 | A clear and concise description of the feature proposal 12 | validations: 13 | required: true 14 | - type: textarea 15 | attributes: 16 | label: Motivation & pitch 17 | description: > 18 | Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. 19 | validations: 20 | required: true 21 | - type: textarea 22 | attributes: 23 | label: Alternatives 24 | description: > 25 | A description of any alternative solutions or features you've considered, if any. 26 | - type: textarea 27 | attributes: 28 | label: Additional context 29 | description: > 30 | Add any other context or screenshots about the feature request. 31 | - type: markdown 32 | attributes: 33 | value: > 34 | Thanks for contributing 🎉 35 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # What does this PR do? 2 | 3 | 9 | 10 | 11 | 12 | Closes # (issue) 13 | 14 | 15 | ## Before submitting 16 | - [ ] Was this discussed/approved in a Github [issue](https://github.com/frgfm/Holocron/issues?q=is%3Aissue) or a [discussion](https://github.com/frgfm/Holocron/discussions)? Please add a link to it if that's the case. 17 | - [ ] You have read the [contribution guidelines](https://github.com/frgfm/Holocron/blob/main/CONTRIBUTING.md#submitting-a-pull-request) and followed them in this PR. 18 | - [ ] Did you make sure to update the documentation with your changes? Here are the 19 | [documentation guidelines](https://github.com/frgfm/Holocron/tree/main/docs). 20 | - [ ] Did you write any new necessary tests? 21 | -------------------------------------------------------------------------------- /.github/SECURITY.md: -------------------------------------------------------------------------------- 1 | # Reporting security issues 2 | 3 | If you believe you have found a security vulnerability in Holocron, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem. 4 | 5 | Please report security issues using https://github.com/frgfm/Holocron/security/advisories/new 6 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located.
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "github-actions" 9 | directory: "/" 10 | schedule: 11 | interval: "weekly" 12 | - package-ecosystem: "pip" 13 | directory: "/" 14 | schedule: 15 | interval: "daily" 16 | allow: 17 | - dependency-name: "ruff" 18 | - dependency-name: "mypy" 19 | - dependency-name: "pre-commit" 20 | - dependency-name: "torch" 21 | - dependency-name: "torchvision" 22 | - package-ecosystem: "pip" 23 | directory: "api/" 24 | schedule: 25 | interval: "daily" 26 | allow: 27 | - dependency-name: "fastapi" 28 | - dependency-name: "onnxruntime" 29 | - package-ecosystem: "docker" 30 | directory: "api/" 31 | schedule: 32 | interval: "daily" 33 | allow: 34 | - dependency-name: "ghcr.io/astral-sh/uv" 35 | -------------------------------------------------------------------------------- /.github/labeler.yml: -------------------------------------------------------------------------------- 1 | 'module: models': 2 | - changed-files: 3 | - any-glob-to-any-file: holocron/models/* 4 | 5 | 'module: nn': 6 | - changed-files: 7 | - any-glob-to-any-file: holocron/nn/* 8 | 9 | 'module: ops': 10 | - changed-files: 11 | - any-glob-to-any-file: holocron/ops/* 12 | 13 | 'module: optim': 14 | - changed-files: 15 | - any-glob-to-any-file: holocron/optim/* 16 | 17 | 'module: trainer': 18 | - changed-files: 19 | - any-glob-to-any-file: holocron/trainer/* 20 | 21 | 'module: transforms': 22 | - changed-files: 23 | - any-glob-to-any-file: holocron/transforms/* 24 | 25 | 'module: utils': 26 | - changed-files: 27 | - any-glob-to-any-file: holocron/utils/* 28 | 29 | 'ext: api': 30 | - changed-files: 31 | - any-glob-to-any-file: api/* 32 | 33 | 'ext: demo': 34 | - changed-files: 35 | - any-glob-to-any-file: demo/* 36 | 37 | 'ext: docs': 38 | - changed-files: 39 | - any-glob-to-any-file: docs/* 40 | 41 | 'ext: references': 42 | - changed-files: 43 | - any-glob-to-any-file: references/* 44 | 45 | 'ext: scripts': 46 | - changed-files: 47 | - any-glob-to-any-file: scripts/* 48 | 49 | 'ext: tests': 50 | - changed-files: 51 | - any-glob-to-any-file: tests/* 52 | 53 | 'ext: ci': 54 | - changed-files: 55 | - any-glob-to-any-file: .github/* 56 | 57 | 'topic: docs': 58 | - changed-files: 59 | - any-glob-to-any-file: 60 | - README.md 61 | - CONTRIBUTING.md 62 | - CODE_OF_CONDUCT.md 63 | - Makefile 64 | - .env.example 65 | - SECURITY.md 66 | - notebooks/* 67 | 68 | 'func: build': 69 | - changed-files: 70 | - any-glob-to-any-file: 71 | - setup.py 72 | - pyproject.toml 73 | - api/Dockerfile 74 | - api/pyproject.toml 75 | - api/*.lock 76 | - api/docker-compose.* 77 | 78 | 'topic: style': 79 | - changed-files: 80 | - any-glob-to-any-file: .pre-commit-config.yaml 81 | -------------------------------------------------------------------------------- /.github/pull-labels.yml: -------------------------------------------------------------------------------- 1 | primary: 2 | - "type: feat" 3 | - "type: fix" 4 | - "type: improvement" 5 | - "type: misc" 6 | 7 | secondary: 8 | - "func: build" 9 | - "ext: api" 10 | - "ext: ci" 11 | - "ext: demo" 12 | - "ext: docs" 13 | - "ext: references" 14 | - "ext: scripts" 15 | - "ext: tests" 16 | - "topic: docs" 17 | - "topic: style" 18 | - "module: models" 19 | - "module: nn" 20 | - "module: ops" 21 | - "module: optim" 22 | - "module: trainer" 23 | - "module: transforms" 24 | - 
"module: utils" 25 | -------------------------------------------------------------------------------- /.github/release.yml: -------------------------------------------------------------------------------- 1 | changelog: 2 | exclude: 3 | labels: 4 | - ignore-for-release 5 | categories: 6 | - title: Breaking Changes 🛠 7 | labels: 8 | - "type: breaking change" 9 | # NEW FEATURES 10 | - title: New Features 🚀 11 | labels: 12 | - "type: feat" 13 | # BUG FIXES 14 | - title: Bug Fixes 🐛 15 | labels: 16 | - "type: fix" 17 | # IMPROVEMENTS 18 | - title: Improvements 19 | labels: 20 | - "type: improvement" 21 | # MISC 22 | - title: Miscellaneous 23 | labels: 24 | - "type: misc" 25 | -------------------------------------------------------------------------------- /.github/verify_deps_sync.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import re 7 | from pathlib import Path 8 | 9 | import tomllib 10 | import yaml 11 | 12 | DOCKERFILE_PATH = "api/Dockerfile" 13 | PRECOMMIT_PATH = ".pre-commit-config.yaml" 14 | PYPROJECT_PATH = "pyproject.toml" 15 | 16 | 17 | def main(): 18 | # Retrieve & parse all deps files 19 | deps_dict = {} 20 | # UV: Dockerfile, precommit, .github 21 | # Parse Dockerfile 22 | with Path(DOCKERFILE_PATH).open("r") as f: 23 | dockerfile = f.read() 24 | uv_version = re.search(r"ghcr\.io/astral-sh/uv:(\d+\.\d+\.\d+)", dockerfile) 25 | if uv_version: 26 | deps_dict["uv"] = [{"file": DOCKERFILE_PATH, "version": uv_version.group(1)}] 27 | 28 | # Parse precommit 29 | with Path(PRECOMMIT_PATH).open("r") as f: 30 | precommit = yaml.safe_load(f) 31 | 32 | for repo in precommit["repos"]: 33 | if repo["repo"] == "https://github.com/astral-sh/uv-pre-commit": 34 | if "uv" not in deps_dict: 35 | deps_dict["uv"] = [] 36 | deps_dict["uv"].append({"file": PRECOMMIT_PATH, "version": repo["rev"].lstrip("v")}) 37 | elif repo["repo"] == "https://github.com/charliermarsh/ruff-pre-commit": 38 | if "ruff" not in deps_dict: 39 | deps_dict["ruff"] = [] 40 | deps_dict["ruff"].append({"file": PRECOMMIT_PATH, "version": repo["rev"].lstrip("v")}) 41 | 42 | # Parse pyproject.toml 43 | with Path(PYPROJECT_PATH).open("rb") as f: 44 | pyproject = tomllib.load(f) 45 | 46 | for group_deps in pyproject["project"]["optional-dependencies"]: 47 | for dep in group_deps: 48 | if dep.startswith("ruff=="): 49 | if "ruff" not in deps_dict: 50 | deps_dict["ruff"] = [] 51 | deps_dict["ruff"].append({"file": PYPROJECT_PATH, "version": dep.split("==")[1]}) 52 | elif dep.startswith("mypy=="): 53 | if "mypy" not in deps_dict: 54 | deps_dict["mypy"] = [] 55 | deps_dict["mypy"].append({"file": PYPROJECT_PATH, "version": dep.split("==")[1]}) 56 | 57 | # Parse github/workflows/... 
58 | for workflow_file in Path(".github/workflows").glob("*.yml"): 59 | with workflow_file.open("r") as f: 60 | workflow = yaml.safe_load(f) 61 | if "env" in workflow and "UV_VERSION" in workflow["env"]: 62 | if "uv" not in deps_dict: 63 | deps_dict["uv"] = [] 64 | deps_dict["uv"].append({ 65 | "file": str(workflow_file), 66 | "version": workflow["env"]["UV_VERSION"].lstrip("v"), 67 | }) 68 | 69 | # Assert all deps are in sync 70 | troubles = [] 71 | for dep, versions in deps_dict.items(): 72 | versions_ = {v["version"] for v in versions} 73 | if len(versions_) != 1: 74 | inv_dict = {v: set() for v in versions_} 75 | for version in versions: 76 | inv_dict[version["version"]].add(version["file"]) 77 | troubles.extend([ 78 | f"{dep}:", 79 | "\n".join(f"- '{v}': {', '.join(files)}" for v, files in inv_dict.items()), 80 | ]) 81 | 82 | if len(troubles) > 0: 83 | raise AssertionError("Some dependencies are out of sync:\n\n" + "\n".join(troubles)) 84 | 85 | 86 | if __name__ == "__main__": 87 | main() 88 | -------------------------------------------------------------------------------- /.github/verify_labels.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | """ 7 | Borrowed & adapted from https://github.com/pytorch/vision/blob/main/.github/process_commit.py 8 | This script finds the merger responsible for labeling a PR, given the PR number. It is used by the workflow in 9 | '.github/workflows/pr-merged.yml'. If the PR is properly labeled, 10 | this script is a no-op. 11 | Note: we ping the merger only, not the reviewers, as the reviewers can sometimes be external to the project 12 | with no labeling responsibility, so we don't want to bother them.
13 | """ 14 | 15 | import os 16 | from pathlib import Path 17 | from typing import Any, Set, Tuple 18 | 19 | import requests 20 | import yaml 21 | 22 | 23 | def query_repo(cmd: str, *, accept) -> Any: 24 | auth = {"Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}"} if os.environ.get("GITHUB_TOKEN") else {} 25 | response = requests.get( 26 | f"https://api.github.com/repos/{cmd}", 27 | headers={"Accept": accept, **auth}, 28 | timeout=5, 29 | ) 30 | return response.json() 31 | 32 | 33 | def get_pr_merger_and_labels(repo: str, pr_number: int) -> Tuple[str, Set[str]]: 34 | # See https://docs.github.com/en/rest/reference/pulls#get-a-pull-request 35 | data = query_repo(f"{repo}/pulls/{pr_number}", accept="application/vnd.github.v3+json") 36 | merger = data.get("merged_by", {}).get("login") 37 | labels = {label["name"] for label in data["labels"]} 38 | return merger, labels 39 | 40 | 41 | def main(args): 42 | # Load the labels 43 | with Path(__file__).parent.joinpath(args.file).open("r") as f: 44 | labels = yaml.safe_load(f) 45 | primary = set(labels["primary"]) 46 | secondary = set(labels["secondary"]) 47 | # Retrieve the PR info 48 | merger, labels = get_pr_merger_and_labels(args.repo, args.pr) 49 | # Check if the PR is properly labeled 50 | # For a PR to be properly labeled it should have one primary label and one secondary label 51 | is_properly_labeled = bool(primary.intersection(labels) and secondary.intersection(labels)) 52 | # If the PR is not properly labeled, ping the merger 53 | if isinstance(merger, str) and not is_properly_labeled: 54 | print(f"@{merger}") 55 | 56 | 57 | def parse_args(): 58 | import argparse 59 | 60 | parser = argparse.ArgumentParser( 61 | description="PR label checker", formatter_class=argparse.ArgumentDefaultsHelpFormatter 62 | ) 63 | 64 | parser.add_argument("repo", type=str, help="Repo full name") 65 | parser.add_argument("pr", type=int, help="PR number") 66 | parser.add_argument("--file", type=str, help="Path to the labels file", default="pull-labels.yml") 67 | return parser.parse_args() 68 | 69 | 70 | if __name__ == "__main__": 71 | args = parse_args() 72 | main(args) 73 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: main 8 | 9 | env: 10 | PYTHON_VERSION: "3.11" 11 | UV_VERSION: "0.5.13" 12 | 13 | jobs: 14 | package: 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | os: [ubuntu-latest, macos-latest, windows-latest] 20 | python: [3.9, '3.10', 3.11, 3.12] 21 | exclude: 22 | - os: macos-latest 23 | python: '3.10' 24 | steps: 25 | - uses: actions/checkout@v4 26 | - uses: actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.python }} 29 | architecture: x64 30 | - uses: astral-sh/setup-uv@v5 31 | with: 32 | version: ${{ env.UV_VERSION }} 33 | - name: Install package 34 | run: | 35 | make install 36 | python -c "import holocron; print(holocron.__version__)" 37 | 38 | pypi: 39 | runs-on: ubuntu-latest 40 | steps: 41 | - uses: actions/checkout@v4 42 | - uses: actions/setup-python@v5 43 | with: 44 | python-version: ${{ env.PYTHON_VERSION }} 45 | architecture: x64 46 | - uses: astral-sh/setup-uv@v5 47 | with: 48 | version: ${{ env.UV_VERSION }} 49 | - name: Build package 50 | run: | 51 | uv pip install --system setuptools wheel twine --upgrade 52 | python setup.py sdist bdist_wheel 53 | 
twine check dist/* 54 | 55 | conda: 56 | runs-on: ubuntu-latest 57 | steps: 58 | - uses: actions/checkout@v4 59 | - uses: conda-incubator/setup-miniconda@v3 60 | with: 61 | auto-update-conda: true 62 | python-version: ${{ env.PYTHON_VERSION }} 63 | - name: Install dependencies 64 | shell: bash -el {0} 65 | run: conda install -y conda-build conda-verify 66 | - name: Build conda 67 | shell: bash -el {0} 68 | run: | 69 | python setup.py sdist 70 | mkdir conda-dist 71 | conda env list 72 | conda-build .conda/ -c pytorch -c fastai -c conda-forge --output-folder conda-dist 73 | ls -l conda-dist/noarch/*tar.bz2 74 | 75 | api: 76 | runs-on: ubuntu-latest 77 | steps: 78 | - uses: actions/checkout@v4 79 | - uses: astral-sh/setup-uv@v5 80 | with: 81 | version: ${{ env.UV_VERSION }} 82 | - name: Build, run & check docker 83 | run: make start-api 84 | 85 | demo: 86 | runs-on: ubuntu-latest 87 | steps: 88 | - uses: actions/checkout@v4 89 | - uses: actions/setup-python@v5 90 | with: 91 | python-version: ${{ env.PYTHON_VERSION }} 92 | architecture: x64 93 | - uses: astral-sh/setup-uv@v5 94 | with: 95 | version: ${{ env.UV_VERSION }} 96 | - name: Install & run demo app 97 | run: | 98 | make install-demo 99 | screen -dm make run-demo 100 | sleep 20 && nc -vz localhost 8080 101 | 102 | docs: 103 | runs-on: ubuntu-latest 104 | steps: 105 | - uses: actions/checkout@v4 106 | - uses: actions/setup-python@v5 107 | with: 108 | python-version: "3.9" 109 | architecture: x64 110 | - uses: astral-sh/setup-uv@v5 111 | with: 112 | version: ${{ env.UV_VERSION }} 113 | - name: Build documentation 114 | run: | 115 | make install-docs 116 | make docs-full 117 | - name: Documentation sanity check 118 | run: test -e docs/build/index.html || exit 119 | -------------------------------------------------------------------------------- /.github/workflows/page-build.yml: -------------------------------------------------------------------------------- 1 | name: page-build 2 | on: 3 | page_build 4 | 5 | env: 6 | PYTHON_VERSION: "3.11" 7 | 8 | jobs: 9 | check-gh-pages: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/setup-python@v5 13 | with: 14 | python-version: ${{ env.PYTHON_VERSION }} 15 | architecture: x64 16 | - name: check status 17 | run: | 18 | import os 19 | status, errormsg = os.getenv('STATUS'), os.getenv('ERROR') 20 | if status != 'built': raise AssertionError(f"There was an error building the page on GitHub pages.\n\nStatus: {status}\n\nError messsage: {errormsg}") 21 | shell: python 22 | env: 23 | STATUS: ${{ github.event.build.status }} 24 | ERROR: ${{ github.event.build.error.message }} 25 | -------------------------------------------------------------------------------- /.github/workflows/pr-merged.yml: -------------------------------------------------------------------------------- 1 | name: pr-merged 2 | 3 | on: 4 | pull_request: 5 | branches: main 6 | types: closed 7 | 8 | env: 9 | PYTHON_VERSION: "3.11" 10 | UV_VERSION: "0.5.13" 11 | 12 | jobs: 13 | is-properly-labeled: 14 | if: github.event.pull_request.merged == true 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | - uses: actions/setup-python@v5 19 | with: 20 | python-version: ${{ env.PYTHON_VERSION }} 21 | architecture: x64 22 | - uses: astral-sh/setup-uv@v5 23 | with: 24 | version: ${{ env.UV_VERSION }} 25 | - run: uv pip install requests 26 | - name: Process commit and find merger responsible for labeling 27 | id: commit 28 | run: echo "::set-output name=merger::$(python .github/verify_labels.py ${{ 
github.event.pull_request.number }})" 29 | - name: Comment PR 30 | uses: actions/github-script@v7.0.1 31 | if: ${{ steps.commit.outputs.merger != '' }} 32 | with: 33 | github-token: ${{ secrets.GITHUB_TOKEN }} 34 | script: | 35 | const { issue: { number: issue_number }, repo: { owner, repo } } = context; 36 | github.issues.createComment({ issue_number, owner, repo, body: 'Hey ${{ steps.commit.outputs.merger }} 👋\nYou merged this PR, but it is not correctly labeled. The list of valid labels is available at https://github.com/frgfm/Holocron/blob/main/.github/pull-labels.yml' }); 37 | -------------------------------------------------------------------------------- /.github/workflows/push.yml: -------------------------------------------------------------------------------- 1 | name: push 2 | on: 3 | push: 4 | branches: main 5 | 6 | env: 7 | PYTHON_VERSION: "3.11" 8 | UV_VERSION: "0.5.13" 9 | 10 | jobs: 11 | docs-deploy: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | with: 16 | persist-credentials: false 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version: "3.9" 20 | architecture: x64 21 | - uses: astral-sh/setup-uv@v5 22 | with: 23 | version: ${{ env.UV_VERSION }} 24 | - name: Build documentation 25 | run: | 26 | make install-docs 27 | make docs-full 28 | - name: Documentation sanity check 29 | run: test -e docs/build/index.html || exit 30 | - name: Install SSH Client 🔑 31 | uses: webfactory/ssh-agent@v0.9.0 32 | with: 33 | ssh-private-key: ${{ secrets.SSH_DEPLOY_KEY }} 34 | - name: Deploy to Github Pages 35 | uses: JamesIves/github-pages-deploy-action@v4.7.2 36 | with: 37 | branch: gh-pages 38 | folder: docs/build 39 | commit-message: '[skip ci] Documentation updates' 40 | clean: true 41 | 42 | dockerhub: 43 | runs-on: ubuntu-latest 44 | steps: 45 | - uses: actions/checkout@v4 46 | - uses: astral-sh/setup-uv@v5 47 | with: 48 | version: ${{ env.UV_VERSION }} 49 | - name: Build docker image 50 | run: make build-api 51 | - name: Login to GHCR 52 | uses: docker/login-action@v3 53 | with: 54 | registry: ghcr.io 55 | username: ${{ github.repository_owner }} 56 | password: ${{ secrets.GITHUB_TOKEN }} 57 | - name: Push to GitHub container registry 58 | run: | 59 | docker tag holocron/backend:latest ghcr.io/${{ github.repository_owner }}/holocron:latest 60 | docker push ghcr.io/${{ github.repository_owner }}/holocron:latest 61 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | on: 3 | release: 4 | types: [published] 5 | 6 | env: 7 | PYTHON_VERSION: "3.11" 8 | UV_VERSION: "0.5.13" 9 | 10 | jobs: 11 | pypi: 12 | if: "!github.event.release.prerelease" 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - uses: actions/setup-python@v5 17 | with: 18 | python-version: ${{ env.PYTHON_VERSION }} 19 | architecture: x64 20 | - uses: astral-sh/setup-uv@v5 21 | with: 22 | version: ${{ env.UV_VERSION }} 23 | - name: Install dependencies 24 | run: uv pip install --system setuptools wheel twine --upgrade 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | echo "BUILD_VERSION=${GITHUB_REF#refs/*/}" | cut -c 2- >> $GITHUB_ENV 31 | python setup.py sdist bdist_wheel 32 | twine check dist/* 33 | twine upload dist/* 34 | 35 | pypi-check: 36 | if: "!github.event.release.prerelease" 
37 | runs-on: ubuntu-latest 38 | needs: pypi 39 | steps: 40 | - uses: actions/checkout@v4 41 | - uses: actions/setup-python@v5 42 | with: 43 | python-version: ${{ env.PYTHON_VERSION }} 44 | architecture: x64 45 | - uses: astral-sh/setup-uv@v5 46 | with: 47 | version: ${{ env.UV_VERSION }} 48 | - name: Install package 49 | run: | 50 | uv pip install --system pylocron 51 | python -c "import holocron; print(holocron.__version__)" 52 | 53 | conda: 54 | if: "!github.event.release.prerelease" 55 | runs-on: ubuntu-latest 56 | steps: 57 | - uses: actions/checkout@v4 58 | - uses: conda-incubator/setup-miniconda@v3 59 | with: 60 | auto-update-conda: true 61 | python-version: ${{ env.PYTHON_VERSION }} 62 | - name: Install dependencies 63 | shell: bash -el {0} 64 | run: conda install -y conda-build conda-verify anaconda-client 65 | - name: Build and publish 66 | shell: bash -el {0} 67 | env: 68 | ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_TOKEN }} 69 | run: | 70 | echo "BUILD_VERSION=${GITHUB_REF#refs/*/}" | cut -c 2- >> $GITHUB_ENV 71 | python setup.py sdist 72 | mkdir conda-dist 73 | conda-build .conda/ -c pytorch -c fastai -c conda-forge --output-folder conda-dist 74 | ls -l conda-dist/noarch/*tar.bz2 75 | anaconda upload conda-dist/noarch/*tar.bz2 76 | 77 | conda-check: 78 | if: "!github.event.release.prerelease" 79 | runs-on: ubuntu-latest 80 | needs: conda 81 | steps: 82 | - uses: conda-incubator/setup-miniconda@v3 83 | with: 84 | auto-update-conda: true 85 | python-version: ${{ env.PYTHON_VERSION }} 86 | - name: Install package 87 | run: | 88 | conda install -c frgfm pylocron 89 | python -c "import holocron; print(holocron.__version__)" 90 | -------------------------------------------------------------------------------- /.github/workflows/style.yml: -------------------------------------------------------------------------------- 1 | name: style 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: main 8 | 9 | env: 10 | PYTHON_VERSION: "3.11" 11 | UV_VERSION: "0.5.13" 12 | 13 | jobs: 14 | ruff: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | - uses: actions/setup-python@v5 19 | with: 20 | python-version: ${{ env.PYTHON_VERSION }} 21 | architecture: x64 22 | - uses: astral-sh/setup-uv@v5 23 | with: 24 | version: ${{ env.UV_VERSION }} 25 | - name: Run ruff 26 | run: | 27 | make install-quality 28 | ruff --version 29 | make lint-check 30 | 31 | mypy: 32 | runs-on: ubuntu-latest 33 | steps: 34 | - uses: actions/checkout@v4 35 | - uses: actions/setup-python@v5 36 | with: 37 | python-version: ${{ env.PYTHON_VERSION }} 38 | architecture: x64 39 | - uses: astral-sh/setup-uv@v5 40 | with: 41 | version: ${{ env.UV_VERSION }} 42 | - name: Run mypy 43 | run: | 44 | uv export --no-hashes --locked -o api/requirements.txt --project api 45 | uv pip install --system -r api/requirements.txt 46 | uv pip install --system -e .[quality] 47 | mypy --version 48 | make typing-check 49 | 50 | precommit-hooks: 51 | runs-on: ubuntu-latest 52 | steps: 53 | - uses: actions/checkout@v4 54 | - uses: actions/setup-python@v5 55 | with: 56 | python-version: ${{ env.PYTHON_VERSION }} 57 | architecture: x64 58 | - uses: astral-sh/setup-uv@v5 59 | with: 60 | version: ${{ env.UV_VERSION }} 61 | - name: Run pre-commit hooks 62 | run: | 63 | make install-quality 64 | git checkout -b temp 65 | pre-commit install 66 | pre-commit --version 67 | make precommit 68 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: 
-------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: main 8 | 9 | env: 10 | PYTHON_VERSION: "3.11" 11 | UV_VERSION: "0.5.13" 12 | 13 | jobs: 14 | deps-sync: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | - uses: actions/setup-python@v5 19 | with: 20 | python-version: ${{ env.PYTHON_VERSION }} 21 | architecture: x64 22 | - uses: astral-sh/setup-uv@v5 23 | with: 24 | version: ${{ env.UV_VERSION }} 25 | - name: Run dependency sync checker 26 | run: | 27 | uv pip install --system PyYAML 28 | make deps-check 29 | 30 | pytest: 31 | runs-on: ubuntu-latest 32 | steps: 33 | - uses: actions/checkout@v4 34 | with: 35 | persist-credentials: false 36 | - uses: actions/setup-python@v5 37 | with: 38 | python-version: ${{ env.PYTHON_VERSION }} 39 | architecture: x64 40 | - uses: astral-sh/setup-uv@v5 41 | with: 42 | version: ${{ env.UV_VERSION }} 43 | - name: Run the tests 44 | run: | 45 | make install-test 46 | pytest --cov=holocron --cov-report xml tests/ 47 | - uses: actions/upload-artifact@v4 48 | with: 49 | name: coverage-reports 50 | path: ./coverage.xml 51 | 52 | codecov-upload: 53 | runs-on: ubuntu-latest 54 | needs: pytest 55 | steps: 56 | - uses: actions/checkout@v4 57 | - uses: actions/download-artifact@v4 58 | - uses: codecov/codecov-action@v5 59 | with: 60 | token: ${{ secrets.CODECOV_TOKEN }} 61 | flags: unittests 62 | directory: ./coverage-reports 63 | fail_ci_if_error: true 64 | 65 | api: 66 | runs-on: ubuntu-latest 67 | steps: 68 | - uses: actions/checkout@v4 69 | - uses: actions/setup-python@v5 70 | with: 71 | python-version: ${{ env.PYTHON_VERSION }} 72 | architecture: x64 73 | - uses: astral-sh/setup-uv@v5 74 | with: 75 | version: ${{ env.UV_VERSION }} 76 | - name: Run docker test 77 | run: make test-api 78 | 79 | headers: 80 | runs-on: ubuntu-latest 81 | steps: 82 | - uses: actions/checkout@v4 83 | with: 84 | persist-credentials: false 85 | - name: Check the headers 86 | uses: frgfm/validate-python-headers@main 87 | with: 88 | license: 'Apache-2.0' 89 | owner: 'François-Guillaume Fernandez' 90 | starting-year: 2019 91 | folders: 'holocron,scripts,references,api/app,demo,docs,.github' 92 | ignore-files: 'version.py,__init__.py' 93 | 94 | eval-latency: 95 | runs-on: ubuntu-latest 96 | steps: 97 | - uses: actions/checkout@v4 98 | - uses: actions/setup-python@v5 99 | with: 100 | python-version: ${{ env.PYTHON_VERSION }} 101 | architecture: x64 102 | - uses: astral-sh/setup-uv@v5 103 | with: 104 | version: ${{ env.UV_VERSION }} 105 | - name: Run script 106 | run: | 107 | make install 108 | uv pip install --system onnx onnxruntime 109 | python scripts/eval_latency.py rexnet1_0x 110 | -------------------------------------------------------------------------------- /.github/workflows/triage.yml: -------------------------------------------------------------------------------- 1 | name: triage 2 | 3 | on: 4 | pull_request: 5 | branches: main 6 | 7 | jobs: 8 | autolabel: 9 | permissions: 10 | contents: read 11 | pull-requests: write 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/labeler@v5 15 | with: 16 | repo-token: "${{ secrets.GITHUB_TOKEN }}" 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 
| *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | holocron/version.py 106 | conda-dist/ 107 | wandb/ 108 | **/*.pt 109 | **/*.pth 110 | **/*.onnx 111 | .codecarbon.config 112 | 113 | 114 | # API uses poetry 115 | api/requirements.txt 116 | api/requirements-dev.txt 117 | 118 | # Doc 119 | docs/source/generated/ 120 | docs/source/models/generated/ 121 | 122 | # Codecarbon 123 | emissions.csv 124 | 125 | # Gradio 126 | flagged/ 127 | 128 | # Training 129 | checkpoints/ 130 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.11 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.5.0 6 | hooks: 7 | - id: check-added-large-files 8 | - id: check-ast 9 | - id: check-case-conflict 10 | - id: check-json 11 | - id: check-merge-conflict 12 | - id: check-symlinks 13 | - id: check-toml 14 | - id: check-xml 15 | - id: check-yaml 16 | exclude: .conda 17 | - id: debug-statements 18 | language_version: python3 19 | - id: end-of-file-fixer 20 | - id: no-commit-to-branch 21 | args: ['--branch', 'main'] 22 | - id: trailing-whitespace 23 | - repo: https://github.com/compilerla/conventional-pre-commit 24 | rev: 'v3.6.0' 25 | hooks: 26 | - id: conventional-pre-commit 27 | stages: [commit-msg] 28 | - repo: https://github.com/charliermarsh/ruff-pre-commit 29 | rev: 'v0.8.4' 30 | hooks: 31 | - id: ruff 32 | args: ["--fix", "--config", "pyproject.toml"] 33 | - id: ruff-format 34 | args: ["--config", "pyproject.toml"] 35 | - repo: https://github.com/astral-sh/uv-pre-commit 36 | rev: '0.5.13' 37 | hooks: 38 | - id: uv-lock 39 | args: ["--locked", "--project", "api"] 40 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | type: software 3 | 
message: "If you use this software in your work, please cite it as below." 4 | authors: 5 | - family-names: "Fernandez" 6 | given-names: "François-Guillaume" 7 | title: "Holocron" 8 | date-released: 2020-05-11 9 | url: "https://github.com/frgfm/Holocron" 10 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | fg-feedback@protonmail.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 
68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to holocron 2 | 3 | Everything you need to know to contribute efficiently to the project! 4 | 5 | Whatever the way you wish to contribute to the project, please respect the [code of conduct](CODE_OF_CONDUCT.md). 
6 | 7 | 8 | 9 | ## Codebase structure 10 | 11 | - [`./holocron`](holocron) - The actual holocron library 12 | - [`./tests`](tests) - Python unit tests 13 | - [`./docs`](docs) - Sphinx documentation building 14 | - [`./scripts`](scripts) - Example and utilities scripts 15 | - [`./references`](references) - Reference training scripts 16 | - [`./api`](api) - A minimal FastAPI backend to run Holocron models 17 | - [`./demo`](demo) - A minimal Gradio demo 18 | 19 | 20 | 21 | ## Continuous Integration 22 | 23 | This project uses the following integrations to ensure proper codebase maintenance: 24 | 25 | - [GitHub Workflow](https://help.github.com/en/actions/configuring-and-managing-workflows/configuring-a-workflow) - runs jobs for package build and coverage 26 | - [Codacy](https://www.codacy.com/) - analyzes commits for code quality 27 | - [Codecov](https://codecov.io/) - reports back coverage results 28 | 29 | As a contributor, you will only have to ensure coverage of your code by adding appropriate unit tests. 30 | 31 | 32 | 33 | ## Feedback 34 | 35 | ### Feature requests & bug reports 36 | 37 | Whether you encountered a problem or have a feature suggestion, your input has value and can be referenced by other contributors in their developments. For this purpose, we advise you to use Github [issues](https://github.com/frgfm/Holocron/issues). 38 | 39 | First, check whether the topic wasn't already covered in an open / closed issue. If not, feel free to open a new one! When doing so, use issue templates whenever possible and provide enough information for other contributors to jump in. 40 | 41 | ### Questions 42 | 43 | If you are wondering how to do something with Holocron, or have a more general question, you should consider checking out Github [discussions](https://github.com/frgfm/Holocron/discussions). See it as a Q&A forum, or the Holocron-specific StackOverflow! 44 | 45 | 46 | 47 | ## Submitting a Pull Request 48 | 49 | ### Preparing your local branch 50 | 51 | 1 - Fork this [repository](https://github.com/frgfm/Holocron) by clicking on the "Fork" button at the top right of the page. This will create a copy of the project under your GitHub account (cf. [Fork a repo](https://docs.github.com/en/get-started/quickstart/fork-a-repo)). 52 | 53 | 2 - [Clone your fork](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository) to your local disk and set the upstream to this repo 54 | ```shell 55 | git clone git@github.com:<YOUR_GITHUB_USERNAME>/Holocron.git 56 | cd Holocron 57 | git remote add upstream https://github.com/frgfm/Holocron.git 58 | ``` 59 | 60 | 3 - You should not work on the `main` branch, so let's create a new one 61 | ```shell 62 | git checkout -b a-short-description 63 | ``` 64 | 65 | 4 - Now you only have to set up your development environment. First uninstall any existing installation of the library with `pip uninstall pylocron`, then: 66 | ```shell 67 | pip install -e ".[dev]" 68 | pre-commit install 69 | ``` 70 | 71 | ### Developing your feature 72 | 73 | #### Commits 74 | 75 | - **Code**: make sure to provide docstrings for your Python code. In doing so, please follow [Google-style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) to ease the documentation process later.
76 | - **Commit message**: please follow [Udacity guide](http://udacity.github.io/git-styleguide/) 77 | 78 | #### Unit tests 79 | 80 | In order to run the same unit tests as the CI workflows, you can run unittests locally: 81 | 82 | ```shell 83 | make test 84 | ``` 85 | 86 | #### Sanity checks 87 | 88 | The CI will also run some sanity checks (header format, dependency consistency, etc.), which you can run as follows: 89 | 90 | ```shell 91 | make style 92 | ``` 93 | 94 | #### Code quality 95 | 96 | To run all quality checks together 97 | 98 | ```shell 99 | make quality 100 | ``` 101 | 102 | ### Submit your modifications 103 | 104 | Push your last modifications to your remote branch 105 | ```shell 106 | git push -u origin a-short-description 107 | ``` 108 | 109 | Then [open a Pull Request](https://docs.github.com/en/github/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request) from your fork's branch. Follow the instructions of the Pull Request template and then click on "Create a pull request". 110 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | DOCKERFILE_PATH = ./api/Dockerfile 2 | API_DIR = ./api 3 | PKG_DIR = . 4 | DEMO_DIR = ./demo 5 | DOCS_DIR = ./docs 6 | PKG_CONFIG_FILE = ${PKG_DIR}/pyproject.toml 7 | PKG_TEST_DIR = ${PKG_DIR}/tests 8 | API_CONFIG_FILE = ${API_DIR}/pyproject.toml 9 | API_LOCK_FILE = ${API_DIR}/uv.lock 10 | API_REQ_FILE = ${API_DIR}/requirements.txt 11 | DEMO_REQ_FILE = ${DEMO_DIR}/requirements.txt 12 | DEMO_SCRIPT = ${DEMO_DIR}/app.py 13 | DOCKER_NAMESPACE ?= holocron 14 | DOCKER_REPO ?= backend 15 | DOCKER_TAG ?= latest 16 | 17 | ######################################################## 18 | # Code checks 19 | ######################################################## 20 | 21 | 22 | install-quality: ${PKG_CONFIG_FILE} 23 | uv pip install --system -e ".[quality]" 24 | pre-commit install 25 | 26 | lint-check: ${PKG_CONFIG_FILE} 27 | ruff format --check . --config ${PKG_CONFIG_FILE} 28 | ruff check . --config ${PKG_CONFIG_FILE} 29 | 30 | lint-format: ${PKG_CONFIG_FILE} 31 | ruff format . --config ${PKG_CONFIG_FILE} 32 | ruff check --fix . --config ${PKG_CONFIG_FILE} 33 | 34 | precommit: ${PKG_CONFIG_FILE} .pre-commit-config.yaml 35 | pre-commit run --all-files 36 | 37 | typing-check: ${PKG_CONFIG_FILE} 38 | mypy --config-file ${PKG_CONFIG_FILE} 39 | 40 | deps-check: .github/verify_deps_sync.py 41 | python .github/verify_deps_sync.py 42 | 43 | # this target runs checks on all files 44 | quality: lint-check typing-check deps-check 45 | 46 | style: lint-format precommit 47 | 48 | ######################################################## 49 | # Build 50 | ######################################################## 51 | 52 | # PACKAGE 53 | install: ${PKG_CONFIG_FILE} 54 | uv pip install --system -e . 
55 | 56 | # TESTS 57 | install-test: ${PKG_CONFIG_FILE} 58 | uv pip install --system -e ".[test]" 59 | 60 | test: install-test ${PKG_TEST_DIR} 61 | pytest --cov=holocron tests/ 62 | 63 | # DEMO 64 | install-demo: ${DEMO_REQ_FILE} 65 | uv pip install --system -r ${DEMO_REQ_FILE} 66 | 67 | run-demo: install-demo ${DEMO_SCRIPT} 68 | python ${DEMO_SCRIPT} --port 8080 69 | 70 | # DOCS 71 | install-docs: ${PKG_CONFIG_FILE} 72 | uv pip install --system -e ".[docs]" 73 | 74 | docs-latest: install-docs ${DOCS_DIR} 75 | sphinx-build ${DOCS_DIR}/source ${DOCS_DIR}/_build -a 76 | 77 | docs-full: install-docs ${DOCS_DIR} 78 | cd ${DOCS_DIR} && bash build.sh 79 | 80 | # API 81 | lock: ${API_CONFIG_FILE} 82 | uv lock --project ${API_DIR} 83 | 84 | req: ${API_CONFIG_FILE} ${PYTHON_LOCK_FILE} 85 | uv export --no-hashes --locked --no-dev -q -o ${API_REQ_FILE} --project ${API_DIR} 86 | 87 | build-api: req ${DOCKERFILE_PATH} 88 | docker build --platform linux/amd64 ${API_DIR} -t ${DOCKER_NAMESPACE}/${DOCKER_REPO}:${DOCKER_TAG} 89 | 90 | push-api: build-api 91 | docker push ${DOCKER_NAMESPACE}/${DOCKER_REPO}:${DOCKER_TAG} 92 | 93 | start-api: build-api ${API_DIR}/docker-compose.yml 94 | docker compose -f ${API_DIR}/docker-compose.yml up -d --wait 95 | 96 | stop-api: ${API_DIR}/docker-compose.yml 97 | docker compose -f ${API_DIR}/docker-compose.yml down 98 | 99 | test-api: ${API_CONFIG_FILE} ${PYTHON_LOCK_FILE} ${DOCKERFILE_PATH} ${API_DIR}/tests 100 | uv export --no-hashes --locked --extra test -q -o ${API_REQ_FILE} --project ${API_DIR} 101 | docker compose -f ${API_DIR}/docker-compose.yml up -d --wait --build 102 | - docker compose -f ${API_DIR}/docker-compose.yml exec -T backend pytest tests/ 103 | docker compose -f ${API_DIR}/docker-compose.yml down 104 | -------------------------------------------------------------------------------- /api/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim 2 | 3 | WORKDIR /app 4 | 5 | # set environment variables 6 | ENV PYTHONDONTWRITEBYTECODE=1 7 | ENV PYTHONUNBUFFERED=1 8 | ENV PYTHONPATH="/app" 9 | 10 | # Install curl 11 | RUN apt-get -y update \ 12 | && apt-get -y install curl \ 13 | && apt-get clean \ 14 | && rm -rf /var/lib/apt/lists/* 15 | 16 | # Install uv 17 | # Ref: https://docs.astral.sh/uv/guides/integration/docker/#installing-uv 18 | COPY --from=ghcr.io/astral-sh/uv:0.5.13 /uv /bin/uv 19 | 20 | # copy requirements file 21 | COPY requirements.txt /tmp/requirements.txt 22 | 23 | # install dependencies 24 | RUN uv pip install --no-cache --system -r /tmp/requirements.txt 25 | 26 | # copy project 27 | COPY app /app/app 28 | -------------------------------------------------------------------------------- /api/Makefile: -------------------------------------------------------------------------------- 1 | # Pin the dependencies 2 | lock: 3 | poetry lock 4 | 5 | build: 6 | poetry export -f requirements.txt --without-hashes --output requirements.txt 7 | docker build . 
-t frgfm/holocron:python3.9-slim 8 | 9 | # Run the docker 10 | run: 11 | poetry export -f requirements.txt --without-hashes --output requirements.txt 12 | docker compose up -d --build 13 | 14 | # Run the docker 15 | stop: 16 | docker compose down 17 | 18 | # Run tests for the library 19 | test: 20 | poetry export -f requirements.txt --without-hashes --with dev --output requirements.txt 21 | docker compose up -d --build 22 | docker compose exec -T backend pytest tests/ --cov=app 23 | docker compose down 24 | -------------------------------------------------------------------------------- /api/README.md: -------------------------------------------------------------------------------- 1 | # Template for your Vision API using Holocron 2 | 3 | ## Installation 4 | 5 | You will only need to install [Git](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git), [Docker](https://docs.docker.com/get-docker/) and [poetry](https://python-poetry.org/docs/#installation). The container environment will be self-sufficient and install the remaining dependencies on its own. 6 | 7 | ## Usage 8 | 9 | ### Starting your web server 10 | 11 | You will need to clone the repository first: 12 | ```shell 13 | git clone https://github.com/frgfm/Holocron.git 14 | ``` 15 | then from the repo root folder, you can start your container: 16 | 17 | ```shell 18 | make lock 19 | make run 20 | ``` 21 | Once completed, your [FastAPI](https://fastapi.tiangolo.com/) server should be running on port 8080. 22 | 23 | ### Documentation and swagger 24 | 25 | FastAPI comes with many advantages including speed and OpenAPI features. For instance, once your server is running, you can access the automatically built documentation and swagger in your browser at: http://api.localhost:8050/docs 26 | 27 | 28 | ### Using the routes 29 | 30 | You will find detailed instructions in the live documentation when your server is up, but here are some examples to use your available API routes: 31 | 32 | #### Image classification 33 | 34 | Using the following image: 35 | 36 | 37 | with this snippet: 38 | 39 | ```python 40 | import requests 41 | with open('/path/to/your/img.jpg', 'rb') as f: 42 | data = f.read() 43 | print(requests.post("http://api.localhost:8050/classification", files={'file': data}).json()) 44 | ``` 45 | 46 | should yield 47 | ``` 48 | {'value': 'French horn', 'confidence': 0.9685316681861877} 49 | ``` 50 | -------------------------------------------------------------------------------- /api/app/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | import os 7 | 8 | from pydantic import Field 9 | from pydantic_settings import BaseSettings 10 | 11 | __all__ = ["settings"] 12 | 13 | 14 | class Settings(BaseSettings): 15 | # State 16 | PROJECT_NAME: str = "Holocron API template" 17 | PROJECT_DESCRIPTION: str = "Template API for Computer Vision" 18 | VERSION: str = "0.2.2.dev0" 19 | DEBUG: bool = os.environ.get("DEBUG", "") != "False" 20 | CLF_HUB_REPO: str = Field( 21 | os.environ.get("CLF_HUB_REPO", "frgfm/rexnet1_5x"), 22 | json_schema_extra=[{"min_length": 2, "example": "frgfm/rexnet1_5x"}], 23 | ) 24 | 25 | 26 | settings = Settings() 27 | -------------------------------------------------------------------------------- /api/app/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import time 7 | 8 | from fastapi import FastAPI, Request, status 9 | from fastapi.openapi.utils import get_openapi 10 | from pydantic import BaseModel 11 | 12 | from app.config import settings 13 | from app.routes import classification 14 | 15 | app = FastAPI( 16 | title=settings.PROJECT_NAME, 17 | description=settings.PROJECT_DESCRIPTION, 18 | debug=settings.DEBUG, 19 | version=settings.VERSION, 20 | ) 21 | 22 | 23 | # Routing 24 | app.include_router(classification.router, prefix="/classification", tags=["classification"]) 25 | 26 | 27 | class Status(BaseModel): 28 | status: str 29 | 30 | 31 | # Healthcheck 32 | @app.get( 33 | "/status", 34 | status_code=status.HTTP_200_OK, 35 | summary="Healthcheck for the API", 36 | include_in_schema=False, 37 | ) 38 | def get_status() -> Status: 39 | return Status(status="ok") 40 | 41 | 42 | # Middleware 43 | @app.middleware("http") 44 | async def add_process_time_header(request: Request, call_next): 45 | start_time = time.time() 46 | response = await call_next(request) 47 | process_time = time.time() - start_time 48 | response.headers["X-Process-Time"] = str(process_time) 49 | return response 50 | 51 | 52 | # Docs 53 | def custom_openapi(): 54 | if app.openapi_schema: 55 | return app.openapi_schema 56 | openapi_schema = get_openapi( 57 | title=settings.PROJECT_NAME, 58 | version=settings.VERSION, 59 | description=settings.PROJECT_DESCRIPTION, 60 | routes=app.routes, 61 | ) 62 | app.openapi_schema = openapi_schema 63 | return app.openapi_schema 64 | 65 | 66 | app.openapi = custom_openapi 67 | -------------------------------------------------------------------------------- /api/app/routes/classification.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
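The middleware defined in `app/main.py` above stamps every response with an `X-Process-Time` header, and the hidden `/status` route is the healthcheck polled by the compose file. A minimal client-side sketch, assuming the stack from `api/docker-compose.yml` is running and reachable on `http://localhost:5050` (the port mapped there):

```python
import requests

# Assumption: the compose stack from api/docker-compose.yml is up locally on port 5050
API_URL = "http://localhost:5050"

# /status is the healthcheck route declared in app/main.py
response = requests.get(f"{API_URL}/status", timeout=5)
print(response.json())  # {'status': 'ok'}

# The custom HTTP middleware adds the server-side processing time (in seconds) to every response
print("processing time:", response.headers.get("X-Process-Time"))
```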
5 | 6 | 7 | from fastapi import APIRouter, File, UploadFile, status 8 | 9 | from app.schemas import ClsCandidate 10 | from app.vision import CLF_CFG, classify_image, decode_image 11 | 12 | router = APIRouter() 13 | 14 | 15 | @router.post("/", status_code=status.HTTP_200_OK, summary="Perform image classification") 16 | def classify(file: UploadFile = File(...)) -> ClsCandidate: 17 | """Runs holocron vision model to analyze the input image""" 18 | probs = classify_image(decode_image(file.file.read())) 19 | 20 | return ClsCandidate( 21 | value=CLF_CFG["classes"][probs.argmax()], 22 | confidence=float(probs.max()), 23 | ) 24 | -------------------------------------------------------------------------------- /api/app/schemas.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from pydantic import BaseModel, Field 7 | 8 | 9 | class ClsCandidate(BaseModel): 10 | """Classification result""" 11 | 12 | value: str = Field(..., json_schema_extra=[{"example": "Wookie"}]) 13 | confidence: float = Field(..., json_schema_extra=[{"gte": 0, "lte": 1}]) 14 | -------------------------------------------------------------------------------- /api/app/vision.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import io 7 | import json 8 | import logging 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | import onnxruntime 13 | from huggingface_hub import hf_hub_download 14 | from PIL import Image 15 | 16 | from app.config import settings 17 | 18 | __all__ = ["classify_image", "decode_image"] 19 | 20 | logger = logging.getLogger("uvicorn.warning") 21 | 22 | # Download model config & checkpoint 23 | with Path(hf_hub_download(settings.CLF_HUB_REPO, filename="config.json")).open("rb") as f: 24 | CLF_CFG = json.load(f) 25 | 26 | CLF_ORT = onnxruntime.InferenceSession(hf_hub_download(settings.CLF_HUB_REPO, filename="model.onnx")) 27 | 28 | logger.info(f"Model loading completed: {settings.CLF_HUB_REPO}") 29 | 30 | 31 | def decode_image(img_data: bytes) -> Image.Image: 32 | return Image.open(io.BytesIO(img_data)) 33 | 34 | 35 | def preprocess_image(pil_img: Image.Image) -> np.ndarray: 36 | """Preprocess an image for inference 37 | 38 | Args: 39 | pil_img: a valid pillow image 40 | 41 | Returns: 42 | the resized and normalized image of shape (1, C, H, W) 43 | """ 44 | # Resizing (PIL takes (W, H) order for resizing) 45 | img = pil_img.resize(CLF_CFG["input_shape"][-2:][::-1], Image.BILINEAR) 46 | # (H, W, C) --> (C, H, W) 47 | img = np.asarray(img).transpose((2, 0, 1)).astype(np.float32) / 255 48 | # Normalization 49 | img -= np.array(CLF_CFG["mean"])[:, None, None] 50 | img /= np.array(CLF_CFG["std"])[:, None, None] 51 | 52 | return img[None, ...] 
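The `preprocess_image` helper above performs four steps: resize to the checkpoint's input shape, scale to [0, 1], move channels first, then normalize with the config's mean/std. A standalone sketch of the same steps; the config values below are illustrative ImageNet-style numbers, not necessarily those shipped in the Hub repo's `config.json`:

```python
import numpy as np
from PIL import Image

# Illustrative config values: the real ones are read from the Hub repo's config.json
cfg = {"input_shape": [3, 224, 224], "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}

pil_img = Image.fromarray(np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8))

# Same steps as preprocess_image: resize (PIL expects (W, H)), scale, channels-first, normalize
img = pil_img.resize(cfg["input_shape"][-2:][::-1], Image.BILINEAR)
img = np.asarray(img).transpose((2, 0, 1)).astype(np.float32) / 255
img -= np.array(cfg["mean"])[:, None, None]
img /= np.array(cfg["std"])[:, None, None]
print(img[None, ...].shape)  # (1, 3, 224, 224)
```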
53 | 54 | 55 | def classify_image(pil_img: Image.Image) -> np.ndarray: 56 | np_img = preprocess_image(pil_img) 57 | ort_input = {CLF_ORT.get_inputs()[0].name: np_img} 58 | 59 | # Inference 60 | ort_out = CLF_ORT.run(None, ort_input) 61 | # sigmoid 62 | return 1 / (1 + np.exp(-ort_out[0][0])) 63 | -------------------------------------------------------------------------------- /api/docker-compose.yml: -------------------------------------------------------------------------------- 1 | name: holocron 2 | 3 | services: 4 | backend: 5 | image: holocron/backend:latest 6 | build: 7 | context: . 8 | ports: 9 | - "5050:5050" 10 | volumes: 11 | - ./:/app/ 12 | command: uvicorn app.main:app --reload --host 0.0.0.0 --port 5050 --proxy-headers --use-colors --log-level info 13 | healthcheck: 14 | test: ["CMD-SHELL", "curl http://localhost:5050/status"] 15 | interval: 10s 16 | timeout: 3s 17 | retries: 5 18 | -------------------------------------------------------------------------------- /api/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "holocron" 3 | version = "0.2.2.dev0" 4 | description = "Backend template for your Vision API with Holocron" 5 | requires-python = ">=3.11,<4.0" 6 | license = { text = "Apache-2.0" } 7 | authors = [{ name = "François-Guillaume Fernandez", email = "fg-feedback@protonmail.com" }] 8 | maintainers = [{ name = "François-Guillaume Fernandez", email = "fg-feedback@protonmail.com" }] 9 | readme = "README.md" 10 | keywords = ["backend", "api", "computer vision", "fastapi", "onnx"] 11 | dependencies = [ 12 | "uvicorn>=0.23.0,<1.0.0", 13 | "fastapi>=0.109.1,<1.0.0", 14 | "python-multipart>=0.0.9", 15 | "Pillow>=8.4.0,!=9.2.0", 16 | "onnxruntime>=1.16.3,<2.0.0", 17 | "huggingface-hub>=0.4.0,<1.0.0", 18 | "numpy>=1.19.5,<3.0.0", 19 | "pydantic-settings>=2.0.0,<3.0.0" 20 | 21 | ] 22 | 23 | [project.optional-dependencies] 24 | test = [ 25 | "pytest>=8.3.3,<9.0.0", 26 | "pytest-asyncio>=0.17.0,<1.0.0", 27 | "httpx>=0.23.0,<1.0.0", 28 | "pytest-cov>=4.0.0,<5.0.0", 29 | "pytest-pretty>=1.0.0,<2.0.0", 30 | "httpx>=0.23.0,<1.0.0", 31 | "requests>=2.32.0,<3.0.0", 32 | "asyncpg>=0.29.0,<1.0.0" 33 | ] 34 | -------------------------------------------------------------------------------- /api/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pytest_asyncio 3 | import requests 4 | from httpx import AsyncClient 5 | 6 | from app.main import app 7 | 8 | 9 | @pytest.fixture(scope="session") 10 | def mock_classification_image(tmpdir_factory): 11 | url = "https://m.media-amazon.com/images/I/517Nh08xqkL._AC_SX425_.jpg" 12 | return requests.get(url, timeout=5).content 13 | 14 | 15 | @pytest_asyncio.fixture(scope="function") 16 | async def test_app_asyncio(): 17 | # for httpx>=20, follow_redirects=True (cf. 
https://github.com/encode/httpx/releases/tag/0.20.0) 18 | async with AsyncClient(app=app, base_url="http://test", follow_redirects=True) as ac: 19 | yield ac # testing happens here 20 | -------------------------------------------------------------------------------- /api/tests/routes/test_classification.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.mark.asyncio 5 | async def test_classification(test_app_asyncio, mock_classification_image): 6 | response = await test_app_asyncio.post("/classification", files={"file": mock_classification_image}) 7 | assert response.status_code == 200 8 | json_response = response.json() 9 | 10 | # Check that IoU with GT if reasonable 11 | assert isinstance(json_response, dict) 12 | assert json_response["value"] == "French horn" 13 | assert json_response["confidence"] >= 0.8 14 | -------------------------------------------------------------------------------- /demo/app.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import argparse 7 | import json 8 | from pathlib import Path 9 | 10 | import gradio as gr 11 | import numpy as np 12 | import onnxruntime 13 | from huggingface_hub import hf_hub_download 14 | from PIL import Image 15 | 16 | 17 | def main(args): 18 | # Download model config & checkpoint 19 | with Path(hf_hub_download(args.repo, filename="config.json")).open("rb") as f: 20 | cfg = json.load(f) 21 | 22 | ort_session = onnxruntime.InferenceSession(hf_hub_download(args.repo, filename="model.onnx")) 23 | 24 | def preprocess_image(pil_img: Image.Image) -> np.ndarray: 25 | """Preprocess an image for inference 26 | 27 | Args: 28 | pil_img: a valid pillow image 29 | 30 | Returns: 31 | the resized and normalized image of shape (1, C, H, W) 32 | """ 33 | # Resizing (PIL takes (W, H) order for resizing) 34 | img = pil_img.resize(cfg["input_shape"][-2:][::-1], Image.BILINEAR) 35 | # (H, W, C) --> (C, H, W) 36 | img = np.asarray(img).transpose((2, 0, 1)).astype(np.float32) / 255 37 | # Normalization 38 | img -= np.array(cfg["mean"])[:, None, None] 39 | img /= np.array(cfg["std"])[:, None, None] 40 | 41 | return img[None, ...] 42 | 43 | def predict(image): 44 | # Preprocessing 45 | np_img = preprocess_image(image) 46 | ort_input = {ort_session.get_inputs()[0].name: np_img} 47 | 48 | # Inference 49 | ort_out = ort_session.run(None, ort_input) 50 | # Post-processing 51 | out_exp = np.exp(ort_out[0][0]) 52 | probs = out_exp / out_exp.sum() 53 | 54 | return {class_name: float(conf) for class_name, conf in zip(cfg["classes"], probs)} 55 | 56 | interface = gr.Interface( 57 | fn=predict, 58 | inputs=gr.Image(type="pil"), 59 | outputs=gr.Label(num_top_classes=3), 60 | title="Holocron: image classification demo", 61 | article=( 62 | "
" 63 | "Github Repo | " 64 | "Documentation
" 65 | ), 66 | live=True, 67 | ) 68 | 69 | interface.launch(server_port=args.port, show_error=True) 70 | 71 | 72 | if __name__ == "__main__": 73 | parser = argparse.ArgumentParser( 74 | description="Holocron image classification demo", formatter_class=argparse.ArgumentDefaultsHelpFormatter 75 | ) 76 | parser.add_argument("--repo", type=str, default="frgfm/rexnet1_0x", help="HF Hub repo to use") 77 | parser.add_argument("--port", type=int, default=8001, help="Port on which the webserver will be run") 78 | args = parser.parse_args() 79 | 80 | main(args) 81 | -------------------------------------------------------------------------------- /demo/requirements.txt: -------------------------------------------------------------------------------- 1 | gradio>=5.0.0,<6.0.0 2 | Pillow>=8.4.0,!=9.2.0 3 | onnxruntime>=1.16.3,<2.0.0 4 | huggingface-hub>=0.4.0,<1.0.0 5 | numpy>=1.19.5,<3.0.0 6 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Changing the documentation 2 | 3 | The documentation of this project is built using `sphinx`. In order to install all the build dependencies, run the following command from the root folder of the repository: 4 | ```shell 5 | pip install -e ".[docs]" 6 | ``` 7 | 8 | --- 9 | **NOTE** 10 | 11 | You are only generating the documentation to inspect it locally. Only the source files are pushed to the remote repository, the documentation will be built automatically by the CI. 12 | 13 | --- 14 | 15 | ## Build the documentation 16 | 17 | ### Latest version 18 | 19 | In most cases, you will only be changing the documentation of the latest version (dev version). In this case, you can build the documentation (the HTML files) with the following command: 20 | 21 | ```shell 22 | sphinx-build docs/source docs/_build -a 23 | ``` 24 | 25 | Then open `docs/_build/index.html` in your web browser to navigate in it. 26 | 27 | 28 | ### Multi-version documentation 29 | 30 | In rare cases, you might want to modify the documentation for other versions. You will then have to build the documentation for the multiple versions of the package, which you can do by running this command from the `docs` folder: 31 | ```shell 32 | bash build.sh 33 | ``` 34 | -------------------------------------------------------------------------------- /docs/build.sh: -------------------------------------------------------------------------------- 1 | function deploy_doc(){ 2 | if [ ! 
-z "$1" ] 3 | then 4 | git checkout $1 5 | fi 6 | COMMIT=$(git rev-parse --short HEAD) 7 | echo "Creating doc at commit" $COMMIT "and pushing to folder $2" 8 | # Hotfix 9 | if [ -d ../requirements.txt ]; then 10 | sed -i "s/^torchvision.*/&,<0.11.0/" ../requirements.txt 11 | fi 12 | sed -i "s/torchvision>=.*',/&,<0.11.0',/" ../setup.py 13 | sed -i "s/',,/,/" ../setup.py 14 | uv pip install --system --upgrade .. 15 | git checkout ../setup.py 16 | if [ -d ../requirements.txt ]; then 17 | git checkout ../requirements.txt 18 | fi 19 | if [ ! -z "$2" ] 20 | then 21 | if [ "$2" == "latest" ]; then 22 | echo "Pushing main" 23 | sphinx-build source build/$2 -a 24 | elif [ -d build/$2 ]; then 25 | echo "Directory" $2 "already exists" 26 | else 27 | echo "Pushing version" $2 28 | cp -r _static source/ && cp _conf.py source/conf.py 29 | sphinx-build source build/$2 -a 30 | fi 31 | else 32 | echo "Pushing stable" 33 | cp -r _static source/ && cp _conf.py source/conf.py 34 | sphinx-build source build -a 35 | fi 36 | git checkout source/ && git clean -f source/ 37 | } 38 | 39 | # exit when any command fails 40 | set -e 41 | # You can find the commit for each tag on https://github.com/frgfm/holocron/tags 42 | if [ -d build ]; then rm -Rf build; fi 43 | mkdir build 44 | cp -r source/_static . 45 | cp source/conf.py _conf.py 46 | git fetch --all --tags --unshallow 47 | deploy_doc "" latest 48 | deploy_doc "e9ca768" v0.1.0 49 | deploy_doc "9b3f927" v0.1.1 50 | deploy_doc "59c3124" v0.1.2 51 | deploy_doc "d41610b" v0.1.3 52 | deploy_doc "67a50c7" v0.2.0 53 | deploy_doc "bc0d972" # v0.2.1 Latest stable release 54 | rm -rf _build _static _conf.py 55 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/css/custom_theme.css: -------------------------------------------------------------------------------- 1 | h1 { 2 | font-size: 180%; 3 | } 4 | 5 | /* Github button */ 6 | 7 | .github-repo { 8 | display: flex; 9 | justify-content: center; 10 | } 11 | 12 | /* Version control */ 13 | 14 | .version-button { 15 | color: gray; 16 | border: none; 17 | padding: 5px; 18 | font-size: 15px; 19 | cursor: pointer; 20 | } 21 | 22 | .version-button:hover, .version-button:focus { 23 | color: white; 24 | background-color: gray; 25 | } 26 | 27 | .version-dropdown { 28 | display: none; 29 | min-width: 160px; 30 | overflow: auto; 31 | font-size: 15px; 32 | } 33 | 34 | .version-dropdown a { 35 | color: gray; 36 | padding: 3px 4px; 37 | text-decoration: none; 38 | display: block; 39 | } 40 | 41 | .version-dropdown a:hover { 42 | color: white; 43 | background-color: gray; 44 | } 45 | 46 | .version-show { 47 | display: block; 48 | } 49 | 50 | /* These 2 rules below are for the weight tables (generated in conf.py) to look 51 | * better. In particular we make their row height shorter */ 52 | .table-checkpoints td, .table-checkpoints th { 53 | margin-bottom: 0.2rem; 54 | padding: 0 !important; 55 | line-height: 1 !important; 56 | } 57 | .table-checkpoints p { 58 | margin-bottom: 0.2rem !important; 59 | } 60 | -------------------------------------------------------------------------------- /docs/source/_static/images/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frgfm/Holocron/cbe2ac60cf9824c116f6a0b282fa47af095179d6/docs/source/_static/images/favicon.ico -------------------------------------------------------------------------------- /docs/source/_static/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frgfm/Holocron/cbe2ac60cf9824c116f6a0b282fa47af095179d6/docs/source/_static/images/logo.png -------------------------------------------------------------------------------- /docs/source/_static/images/logo_text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frgfm/Holocron/cbe2ac60cf9824c116f6a0b282fa47af095179d6/docs/source/_static/images/logo_text.png -------------------------------------------------------------------------------- /docs/source/_templates/function.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline}} 7 | 8 | .. 
autofunction:: {{ name }} 9 | -------------------------------------------------------------------------------- /docs/source/changelog.rst: -------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | v0.2.1 (2022-07-16) 5 | ------------------- 6 | Release note: `v0.2.1 `_ 7 | 8 | v0.2.0 (2022-02-05) 9 | ------------------- 10 | Release note: `v0.2.0 `_ 11 | 12 | v0.1.3 (2020-10-27) 13 | ------------------- 14 | Release note: `v0.1.3 `_ 15 | 16 | v0.1.2 (2020-06-21) 17 | ------------------- 18 | Release note: `v0.1.2 `_ 19 | 20 | v0.1.1 (2020-05-12) 21 | ------------------- 22 | Release note: `v0.1.1 `_ 23 | 24 | v0.1.0 (2020-05-11) 25 | ------------------- 26 | Release note: `v0.1.0 `_ 27 | -------------------------------------------------------------------------------- /docs/source/docutils.conf: -------------------------------------------------------------------------------- 1 | # Necessary for the table generated by autosummary to look decent 2 | [html writers] 3 | table_style: colwidths-auto 4 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | ********************************************* 2 | Holocron: a Deep Learning toolbox for PyTorch 3 | ********************************************* 4 | 5 | .. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/holocron_logo_text.png 6 | :align: center 7 | 8 | Holocron is meant to bridge the gap between PyTorch and latest research papers. It brings training components that are not available yet in PyTorch with a similar interface. 9 | 10 | This project is meant for: 11 | 12 | * |:zap:| **speed**: architectures in this repo are picked for both pure performances and minimal latency 13 | * |:woman_scientist:| **research**: train your models easily to SOTA standards 14 | 15 | 16 | .. toctree:: 17 | :maxdepth: 2 18 | :caption: Getting Started 19 | :hidden: 20 | 21 | installing 22 | notebooks 23 | 24 | 25 | 26 | Model zoo 27 | ^^^^^^^^^ 28 | 29 | 30 | Image classification 31 | """""""""""""""""""" 32 | * TridentNet from `"Scale-Aware Trident Networks for Object Detection" `_ 33 | * SKNet from `"Selective Kernel Networks" `_ 34 | * PyConvResNet from `"Pyramidal Convolution: Rethinking Convolutional Neural Networks for Visual Recognition" `_ 35 | * ReXNet from `"ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network" `_ 36 | * RepVGG from `"RepVGG: Making VGG-style ConvNets Great Again" `_ 37 | 38 | Semantic segmentation 39 | """"""""""""""""""""" 40 | * U-Net from `"U-Net: Convolutional Networks for Biomedical Image Segmentation" `_ 41 | * U-Net++ from `"UNet++: Redesigning Skip Connections to Exploit Multiscale Features in Image Segmentation" `_ 42 | * UNet3+ from `"UNet 3+: A Full-Scale Connected UNet For Medical Image Segmentation" `_ 43 | 44 | Object detection 45 | """""""""""""""" 46 | * YOLO from `"ou Only Look Once: Unified, Real-Time Object Detection" `_ 47 | * YOLOv2 from `"YOLO9000: Better, Faster, Stronger" `_ 48 | * YOLOv4 from `"YOLOv4: Optimal Speed and Accuracy of Object Detection" `_ 49 | 50 | 51 | .. toctree:: 52 | :maxdepth: 2 53 | :caption: Package Reference 54 | :hidden: 55 | 56 | models 57 | nn 58 | nn.functional 59 | ops 60 | optim 61 | trainer 62 | transforms 63 | utils 64 | utils.data 65 | 66 | 67 | .. 
toctree:: 68 | :maxdepth: 2 69 | :caption: Notes 70 | :hidden: 71 | 72 | changelog 73 | -------------------------------------------------------------------------------- /docs/source/installing.rst: -------------------------------------------------------------------------------- 1 | 2 | ************ 3 | Installation 4 | ************ 5 | 6 | This library requires `Python `_ 3.9 or higher. 7 | 8 | Via Python Package 9 | ================== 10 | 11 | Install the last stable release of the package using `pip `_: 12 | 13 | .. code:: bash 14 | 15 | pip install pylocron 16 | 17 | 18 | Via Conda 19 | ========= 20 | 21 | Install the last stable release of the package using `conda `_: 22 | 23 | .. code:: bash 24 | 25 | conda install -c frgfm pylocron 26 | 27 | 28 | Via Git 29 | ======= 30 | 31 | Install the library in developer mode: 32 | 33 | .. code:: bash 34 | 35 | git clone https://github.com/frgfm/Holocron.git 36 | pip install -e Holocron/. 37 | -------------------------------------------------------------------------------- /docs/source/models.rst: -------------------------------------------------------------------------------- 1 | holocron.models 2 | ############### 3 | 4 | The models subpackage contains definitions of models for addressing 5 | different tasks, including: image classification, pixelwise semantic 6 | segmentation, object detection, instance segmentation, person 7 | keypoint detection and video classification. 8 | 9 | 10 | .. currentmodule:: holocron.models 11 | 12 | Classification 13 | ============== 14 | 15 | Classification models expect a 4D image tensor as an input (N x C x H x W) and returns a 2D output (N x K). 16 | The output represents the classification scores for each output classes. 17 | 18 | .. code:: python 19 | 20 | import holocron.models as models 21 | darknet19 = models.darknet19(num_classes=10) 22 | 23 | 24 | Supported architectures 25 | ----------------------- 26 | 27 | .. toctree:: 28 | :caption: Supported architectures 29 | :maxdepth: 1 30 | 31 | models/resnet 32 | models/resnext 33 | models/res2net 34 | models/tridentnet 35 | models/convnext 36 | models/pyconv_resnet 37 | models/rexnet 38 | models/sknet 39 | models/darknet 40 | models/darknetv2 41 | models/darknetv3 42 | models/darknetv4 43 | models/repvgg 44 | models/mobileone 45 | 46 | Available checkpoints 47 | --------------------- 48 | 49 | Here is the list of available checkpoints: 50 | 51 | .. include:: generated/classification_table.rst 52 | 53 | 54 | 55 | Object Detection 56 | ================ 57 | 58 | Object detection models expect a 4D image tensor as an input (N x C x H x W) and returns a list of dictionaries. 59 | Each dictionary has 3 keys: box coordinates, classification probability, classification label. 60 | 61 | .. code:: python 62 | 63 | import holocron.models as models 64 | yolov2 = models.yolov2(num_classes=10) 65 | 66 | 67 | .. currentmodule:: holocron.models.detection 68 | 69 | YOLO 70 | ---- 71 | 72 | .. autofunction:: yolov1 73 | 74 | .. autofunction:: yolov2 75 | 76 | .. autofunction:: yolov4 77 | 78 | 79 | Semantic Segmentation 80 | ===================== 81 | 82 | Semantic segmentation models expect a 4D image tensor as an input (N x C x H x W) and returns a classification score 83 | tensor of size (N x K x Ho x Wo). 84 | 85 | .. code:: python 86 | 87 | import holocron.models as models 88 | unet = models.unet(num_classes=10) 89 | 90 | 91 | .. currentmodule:: holocron.models.segmentation 92 | 93 | 94 | U-Net 95 | ----- 96 | 97 | .. autofunction:: unet 98 | 99 | .. 
autofunction:: unetp 100 | 101 | .. autofunction:: unetpp 102 | 103 | .. autofunction:: unet3p 104 | 105 | .. autofunction:: unet2 106 | 107 | .. autofunction:: unet_tvvgg11 108 | 109 | .. autofunction:: unet_tvresnet34 110 | 111 | .. autofunction:: unet_rexnet13 112 | -------------------------------------------------------------------------------- /docs/source/models/convnext.rst: -------------------------------------------------------------------------------- 1 | ConvNeXt 2 | ======== 3 | 4 | .. currentmodule:: holocron.models 5 | 6 | The ConvNeXt model is based on the `"A ConvNet for the 2020s" `_ paper. 7 | 8 | Architecture overview 9 | --------------------- 10 | 11 | This architecture compiles tricks from transformer-based vision models to improve a pure convolutional model. 12 | 13 | .. image:: https://github.com/frgfm/Holocron/releases/download/v0.2.1/convnext.png 14 | :align: center 15 | 16 | The key takeaways from the paper are the following: 17 | 18 | * update the stem convolution to act like a patchify layer of transformers 19 | * increase block kernel size to 7 20 | * switch to depth-wise convolutions 21 | * reduce the amount of activations and normalization layers 22 | 23 | 24 | Model builders 25 | -------------- 26 | 27 | The following model builders can be used to instantiate a ConvNeXt model, with or 28 | without pre-trained weights. All the model builders internally rely on the 29 | ``holocron.models.classification.convnext.ConvNeXt`` base class. Please refer to the `source 30 | code 31 | `_ for 32 | more details about this class. 33 | 34 | .. autosummary:: 35 | :toctree: generated/ 36 | :template: function.rst 37 | 38 | convnext_atto 39 | convnext_femto 40 | convnext_pico 41 | convnext_nano 42 | convnext_tiny 43 | convnext_small 44 | convnext_base 45 | convnext_large 46 | convnext_xl 47 | -------------------------------------------------------------------------------- /docs/source/models/darknet.rst: -------------------------------------------------------------------------------- 1 | DarkNet 2 | ======= 3 | 4 | .. currentmodule:: holocron.models 5 | 6 | The DarkNet model is based on the `"You Only Look Once: Unified, Real-Time Object Detection" `_ paper. 7 | 8 | Architecture overview 9 | --------------------- 10 | 11 | This paper introduces a highway network with powerful feature representation abilities. 12 | 13 | The key takeaways from the paper are the following: 14 | 15 | * improves the Inception architecture by using conv1x1 16 | * replaces ReLU by LeakyReLU 17 | 18 | 19 | Model builders 20 | -------------- 21 | 22 | The following model builders can be used to instantiate a DarknetV1 model, with or 23 | without pre-trained weights. All the model builders internally rely on the 24 | ``holocron.models.classification.darknet.DarknetV1`` base class. Please refer to the `source 25 | code 26 | `_ for 27 | more details about this class. 28 | 29 | .. autosummary:: 30 | :toctree: generated/ 31 | :template: function.rst 32 | 33 | darknet24 34 | -------------------------------------------------------------------------------- /docs/source/models/darknetv2.rst: -------------------------------------------------------------------------------- 1 | DarkNetV2 2 | ========= 3 | 4 | .. currentmodule:: holocron.models 5 | 6 | The DarkNetV2 model is based on the `"YOLO9000: Better, Faster, Stronger" `_ paper. 7 | 8 | Architecture overview 9 | --------------------- 10 | 11 | This paper improves its version version by adding more recent gradient flow facilitators. 
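As the classification section of ``models.rst`` above notes, these builders take an (N x C x H x W) tensor and return (N x K) classification scores. A quick sanity-check sketch with the ``darknet19`` builder documented on this page, using untrained weights and a hypothetical 10-class setting:

.. code:: python

    import torch
    import holocron.models as models

    # Untrained darknet19 with 10 output classes, as in the snippet above
    model = models.darknet19(num_classes=10).eval()

    with torch.inference_mode():
        logits = model(torch.rand(2, 3, 224, 224))  # N x C x H x W in

    print(logits.shape)  # torch.Size([2, 10]) -> N x K classification scores
    probs = logits.softmax(dim=-1)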
12 | 13 | The key takeaways from the paper are the following: 14 | 15 | * adds batch normalization layers compared to DarkNetV1 16 | 17 | 18 | Model builders 19 | -------------- 20 | 21 | The following model builders can be used to instantiate a DarknetV2 model, with or 22 | without pre-trained weights. All the model builders internally rely on the 23 | ``holocron.models.classification.darknetv2.DarknetV2`` base class. Please refer to the `source 24 | code 25 | `_ for 26 | more details about this class. 27 | 28 | .. autosummary:: 29 | :toctree: generated/ 30 | :template: function.rst 31 | 32 | darknet19 33 | -------------------------------------------------------------------------------- /docs/source/models/darknetv3.rst: -------------------------------------------------------------------------------- 1 | DarkNetV3 2 | ========= 3 | 4 | .. currentmodule:: holocron.models 5 | 6 | The DarkNetV3 model is based on the `"YOLOv3: An Incremental Improvement" `_ paper. 7 | 8 | Architecture overview 9 | --------------------- 10 | 11 | This paper builds a more powerful version of its predecessors by increasing depth and borrowing ResNet tricks. 12 | 13 | The key takeaways from the paper are the following: 14 | 15 | * adds residual connections compared to DarkNetV2 16 | 17 | 18 | Model builders 19 | -------------- 20 | 21 | The following model builders can be used to instantiate a DarknetV3 model, with or 22 | without pre-trained weights. All the model builders internally rely on the 23 | ``holocron.models.classification.darknetv3.DarknetV3`` base class. Please refer to the `source 24 | code 25 | `_ for 26 | more details about this class. 27 | 28 | .. autosummary:: 29 | :toctree: generated/ 30 | :template: function.rst 31 | 32 | darknet53 33 | -------------------------------------------------------------------------------- /docs/source/models/darknetv4.rst: -------------------------------------------------------------------------------- 1 | DarkNetV4 2 | ========= 3 | 4 | .. currentmodule:: holocron.models 5 | 6 | The DarkNetV4 model is based on the `"CSPNet: A New Backbone that can Enhance Learning Capability of CNN" `_ paper. 7 | 8 | Architecture overview 9 | --------------------- 10 | 11 | This paper improves on its predecessors by splitting and re-merging feature maps across stages, which reduces duplicated gradient information and computation. 12 | 13 | The key takeaways from the paper are the following: 14 | 15 | * adds cross-path connections to its predecessors 16 | * explores newer non-linearities 17 | 18 | 19 | Model builders 20 | -------------- 21 | 22 | The following model builders can be used to instantiate a DarknetV4 model, with or 23 | without pre-trained weights. All the model builders internally rely on the 24 | ``holocron.models.classification.darknetv4.DarknetV4`` base class. Please refer to the `source 25 | code 26 | `_ for 27 | more details about this class. 28 | 29 | .. autosummary:: 30 | :toctree: generated/ 31 | :template: function.rst 32 | 33 | cspdarknet53 34 | cspdarknet53_mish 35 | -------------------------------------------------------------------------------- /docs/source/models/mobileone.rst: -------------------------------------------------------------------------------- 1 | MobileOne 2 | ========= 3 | 4 | .. currentmodule:: holocron.models 5 | 6 | The MobileOne model is based on the `"An Improved One millisecond Mobile Backbone" `_ paper. 7 | 8 | Architecture overview 9 | --------------------- 10 | 11 | This architecture optimizes the model for low inference latency on mobile devices. 12 | 13 | ..
image:: https://github.com/frgfm/Holocron/releases/download/v0.2.1/mobileone.png 14 | :align: center 15 | 16 | The key takeaways from the paper are the following: 17 | 18 | * reuse the reparametrization concept of RepVGG while adding overparametrization in the block branches. 19 | * each block is composed of two consecutive reparametrizeable blocks (in a similar fashion than RepVGG): a depth-wise convolutional block, a point-wise convolutional block. 20 | 21 | 22 | Model builders 23 | -------------- 24 | 25 | The following model builders can be used to instantiate a MobileOne model, with or 26 | without pre-trained weights. All the model builders internally rely on the 27 | ``holocron.models.classification.mobileone.MobileOne`` base class. Please refer to the `source 28 | code 29 | `_ for 30 | more details about this class. 31 | 32 | .. autosummary:: 33 | :toctree: generated/ 34 | :template: function.rst 35 | 36 | mobileone_s0 37 | mobileone_s1 38 | mobileone_s2 39 | mobileone_s3 40 | -------------------------------------------------------------------------------- /docs/source/models/pyconv_resnet.rst: -------------------------------------------------------------------------------- 1 | PyConvResNet 2 | ============ 3 | 4 | .. currentmodule:: holocron.models 5 | 6 | The PyConvResNet model is based on the `"Pyramidal Convolution: Rethinking Convolutional Neural Networks for Visual Recognition" `_ paper. 7 | 8 | Architecture overview 9 | --------------------- 10 | 11 | This paper explores an alternative approach for convolutional block in a pyramidal fashion. 12 | 13 | .. image:: https://github.com/frgfm/Holocron/releases/download/v0.2.1/pyconv_resnet.png 14 | :align: center 15 | 16 | The key takeaways from the paper are the following: 17 | 18 | * replaces standard convolutions with pyramidal convolutions 19 | * extends kernel size while increasing group size to balance the number of operations 20 | 21 | 22 | Model builders 23 | -------------- 24 | 25 | The following model builders can be used to instantiate a PyConvResNet model, with or 26 | without pre-trained weights. All the model builders internally rely on the 27 | ``holocron.models.classification.resnet.ResNet`` base class. Please refer to the `source 28 | code 29 | `_ for 30 | more details about this class. 31 | 32 | .. autosummary:: 33 | :toctree: generated/ 34 | :template: function.rst 35 | 36 | pyconv_resnet50 37 | pyconvhg_resnet50 38 | -------------------------------------------------------------------------------- /docs/source/models/repvgg.rst: -------------------------------------------------------------------------------- 1 | RepVGG 2 | ====== 3 | 4 | .. currentmodule:: holocron.models 5 | 6 | The ResNet model is based on the `"RepVGG: Making VGG-style ConvNets Great Again" `_ paper. 7 | 8 | Architecture overview 9 | --------------------- 10 | 11 | This paper revisits the VGG architecture by adapting its parameter setting in training and inference mode to combine the original VGG speed and the block design of ResNet. 12 | 13 | .. 
image:: https://github.com/frgfm/Holocron/releases/download/v0.2.1/repvgg.png 14 | :align: center 15 | 16 | The key takeaways from the paper are the following: 17 | 18 | * have different block architectures between training and inference modes 19 | * the block is designed in a similar fashion as a ResNet bottleneck but in a way that all branches can be fused into a single one 20 | * The more complex training architecture improves gradient flow and overall optimization, while its inference counterpart is optimized for minimum latency and memory usage 21 | 22 | 23 | Model builders 24 | -------------- 25 | 26 | The following model builders can be used to instantiate a RepVGG model, with or 27 | without pre-trained weights. All the model builders internally rely on the 28 | ``holocron.models.classification.revpgg.RepVGG`` base class. Please refer to the `source 29 | code 30 | `_ for 31 | more details about this class. 32 | 33 | .. autosummary:: 34 | :toctree: generated/ 35 | :template: function.rst 36 | 37 | repvgg_a0 38 | repvgg_a1 39 | repvgg_a2 40 | repvgg_b0 41 | repvgg_b1 42 | repvgg_b2 43 | repvgg_b3 44 | -------------------------------------------------------------------------------- /docs/source/models/res2net.rst: -------------------------------------------------------------------------------- 1 | Res2Net 2 | ======= 3 | 4 | .. currentmodule:: holocron.models 5 | 6 | The Res2Net model is based on the `"Res2Net: A New Multi-scale Backbone Architecture" `_ paper. 7 | 8 | Architecture overview 9 | --------------------- 10 | 11 | This paper replaces the bottleneck block of ResNet architectures by a multi-scale version. 12 | 13 | .. image:: https://github.com/frgfm/Holocron/releases/download/v0.2.1/res2net.png 14 | :align: center 15 | 16 | The key takeaways from the paper are the following: 17 | 18 | * switch to efficient multi-scale convolutions using a cascade of conv 3x3 19 | * adapt the block for cardinality & SE blocks 20 | 21 | 22 | Model builders 23 | -------------- 24 | 25 | The following model builders can be used to instantiate a Res2Net model, with or 26 | without pre-trained weights. All the model builders internally rely on the 27 | ``holocron.models.classification.resnet.ResNet`` base class. Please refer to the `source 28 | code 29 | `_ for 30 | more details about this class. 31 | 32 | .. autosummary:: 33 | :toctree: generated/ 34 | :template: function.rst 35 | 36 | res2net50_26w_4s 37 | -------------------------------------------------------------------------------- /docs/source/models/resnet.rst: -------------------------------------------------------------------------------- 1 | ResNet 2 | ====== 3 | 4 | .. currentmodule:: holocron.models 5 | 6 | The ResNet model is based on the `"Deep Residual Learning for Image Recognition" `_ paper. 7 | 8 | Architecture overview 9 | --------------------- 10 | 11 | This paper introduces a few tricks to maximize the depth of convolutional architectures that can be trained. 12 | 13 | The key takeaways from the paper are the following: 14 | 15 | * add a shortcut connection in bottleneck blocks to ease the gradient flow 16 | * extensive use of batch normalization layers 17 | 18 | 19 | Model builders 20 | -------------- 21 | 22 | The following model builders can be used to instantiate a ResNeXt model, with or 23 | without pre-trained weights. All the model builders internally rely on the 24 | ``holocron.models.classification.resnet.ResNet`` base class. Please refer to the `source 25 | code 26 | `_ for 27 | more details about this class. 
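To make the first takeaway above concrete, here is a minimal, purely didactic residual block showing the shortcut pattern this page describes; ``TinyResidualBlock`` is an illustrative name and is not Holocron's actual ``ResNet`` bottleneck implementation:

.. code:: python

    import torch
    from torch import nn

    class TinyResidualBlock(nn.Module):
        """Didactic sketch of the identity shortcut described above."""

        def __init__(self, channels: int):
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(channels)
            self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(channels)
            self.act = nn.ReLU(inplace=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            out = self.act(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            # shortcut connection: gradients also flow through the identity path
            return self.act(out + x)

    print(TinyResidualBlock(16)(torch.rand(1, 16, 32, 32)).shape)  # torch.Size([1, 16, 32, 32])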
28 | 29 | .. autosummary:: 30 | :toctree: generated/ 31 | :template: function.rst 32 | 33 | resnet18 34 | resnet34 35 | resnet50 36 | resnet50d 37 | resnet101 38 | resnet152 39 | -------------------------------------------------------------------------------- /docs/source/models/resnext.rst: -------------------------------------------------------------------------------- 1 | ResNeXt 2 | ======= 3 | 4 | .. currentmodule:: holocron.models 5 | 6 | The ResNeXt model is based on the `"Aggregated Residual Transformations for Deep Neural Networks" `_ paper. 7 | 8 | Architecture overview 9 | --------------------- 10 | 11 | This paper improves the ResNet architecture by increasing the width of bottleneck blocks. 12 | 13 | The key takeaways from the paper are the following: 14 | 15 | * increases the number of channels in bottlenecks 16 | * switches to group convolutions to balance the number of operations 17 | 18 | 19 | Model builders 20 | -------------- 21 | 22 | The following model builders can be used to instantiate a ResNeXt model, with or 23 | without pre-trained weights. All the model builders internally rely on the 24 | ``holocron.models.classification.resnet.ResNet`` base class. Please refer to the `source 25 | code 26 | `_ for 27 | more details about this class. 28 | 29 | .. autosummary:: 30 | :toctree: generated/ 31 | :template: function.rst 32 | 33 | resnext50_32x4d 34 | resnext101_32x8d 35 | -------------------------------------------------------------------------------- /docs/source/models/rexnet.rst: -------------------------------------------------------------------------------- 1 | ReXNet 2 | ====== 3 | 4 | .. currentmodule:: holocron.models 5 | 6 | The ReXNet model is based on the `"ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network" `_ paper. 7 | 8 | Architecture overview 9 | --------------------- 10 | 11 | This paper investigates the effect of channel configuration in convolutional bottlenecks. 12 | 13 | The key takeaways from the paper are the following: 14 | 15 | * increases the depth ratio of conv 1x1 and inverted bottlenecks 16 | * replaces ReLU6 with SiLU 17 | 18 | 19 | Model builders 20 | -------------- 21 | 22 | The following model builders can be used to instantiate a ReXNet model, with or 23 | without pre-trained weights. All the model builders internally rely on the 24 | ``holocron.models.classification.rexnet.ReXNet`` base class. Please refer to the `source 25 | code 26 | `_ for 27 | more details about this class. 28 | 29 | .. autosummary:: 30 | :toctree: generated/ 31 | :template: function.rst 32 | 33 | rexnet1_0x 34 | rexnet1_3x 35 | rexnet1_5x 36 | rexnet2_0x 37 | rexnet2_2x 38 | -------------------------------------------------------------------------------- /docs/source/models/sknet.rst: -------------------------------------------------------------------------------- 1 | SKNet 2 | ===== 3 | 4 | .. currentmodule:: holocron.models 5 | 6 | The SKNet model is based on the `"Selective Kernel Networks" `_ paper. 7 | 8 | Architecture overview 9 | --------------------- 10 | 11 | This paper revisits the concept of dynamic receptive field selection in convolutional blocks. 12 | 13 | ..
image:: https://github.com/frgfm/Holocron/releases/download/v0.2.1/skconv.png 14 | :align: center 15 | 16 | The key takeaways from the paper are the following: 17 | 18 | * performs convolutions with multiple kernel sizes 19 | * implements a cross-channel attention mechanism 20 | 21 | 22 | Model builders 23 | -------------- 24 | 25 | The following model builders can be used to instantiate a SKNet model, with or 26 | without pre-trained weights. All the model builders internally rely on the 27 | ``holocron.models.classification.resnet.ResNet`` base class. Please refer to the `source 28 | code 29 | `_ for 30 | more details about this class. 31 | 32 | .. autosummary:: 33 | :toctree: generated/ 34 | :template: function.rst 35 | 36 | sknet50 37 | sknet101 38 | sknet152 39 | -------------------------------------------------------------------------------- /docs/source/models/tridentnet.rst: -------------------------------------------------------------------------------- 1 | TridentNet 2 | ========== 3 | 4 | .. currentmodule:: holocron.models 5 | 6 | The ResNeXt model is based on the `"Scale-Aware Trident Networks for Object Detection" `_ paper. 7 | 8 | Architecture overview 9 | --------------------- 10 | 11 | This paper replaces the bottleneck block of ResNet architectures by a multi-scale version. 12 | 13 | .. image:: https://github.com/frgfm/Holocron/releases/download/v0.2.1/tridentnet.png 14 | :align: center 15 | 16 | The key takeaways from the paper are the following: 17 | 18 | * switch bottleneck to a 3 branch system 19 | * all parallel branches share the same parameters but using different dilation values 20 | 21 | 22 | Model builders 23 | -------------- 24 | 25 | The following model builders can be used to instantiate a TridentNet model, with or 26 | without pre-trained weights. All the model builders internally rely on the 27 | ``holocron.models.classification.resnet.ResNet`` base class. Please refer to the `source 28 | code 29 | `_ for 30 | more details about this class. 31 | 32 | .. autosummary:: 33 | :toctree: generated/ 34 | :template: function.rst 35 | 36 | tridentnet50 37 | -------------------------------------------------------------------------------- /docs/source/nn.functional.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | 5 | holocron.nn.functional 6 | ====================== 7 | 8 | .. currentmodule:: holocron.nn.functional 9 | 10 | 11 | Non-linear activations 12 | ---------------------- 13 | 14 | .. autofunction:: hard_mish 15 | 16 | .. autofunction:: nl_relu 17 | 18 | 19 | 20 | Loss functions 21 | -------------- 22 | 23 | .. autofunction:: focal_loss 24 | 25 | .. autofunction:: multilabel_cross_entropy 26 | 27 | .. autofunction:: complement_cross_entropy 28 | 29 | .. autofunction:: dice_loss 30 | 31 | ..autofunction:: poly_loss 32 | 33 | Convolutions 34 | ------------ 35 | 36 | .. autofunction:: norm_conv2d 37 | 38 | .. autofunction:: add2d 39 | 40 | Regularization layers 41 | --------------------- 42 | 43 | .. autofunction:: dropblock2d 44 | 45 | 46 | Downsampling 47 | ------------ 48 | 49 | .. autofunction:: concat_downsample2d 50 | .. 
autofunction:: z_pool 51 | -------------------------------------------------------------------------------- /docs/source/nn.rst: -------------------------------------------------------------------------------- 1 | holocron.nn 2 | ============ 3 | 4 | An addition to the :mod:`torch.nn` module of Pytorch to extend the range of neural networks building blocks. 5 | 6 | 7 | .. currentmodule:: holocron.nn 8 | 9 | Non-linear activations 10 | ---------------------- 11 | 12 | .. autoclass:: HardMish 13 | 14 | .. autoclass:: NLReLU 15 | 16 | .. autoclass:: FReLU 17 | 18 | Loss functions 19 | -------------- 20 | 21 | .. autoclass:: FocalLoss 22 | 23 | .. autoclass:: MultiLabelCrossEntropy 24 | 25 | .. autoclass:: ComplementCrossEntropy 26 | 27 | .. autoclass:: MutualChannelLoss 28 | 29 | .. autoclass:: DiceLoss 30 | 31 | .. autoclass:: PolyLoss 32 | 33 | 34 | Loss wrappers 35 | -------------- 36 | 37 | .. autoclass:: ClassBalancedWrapper 38 | 39 | Convolution layers 40 | ------------------ 41 | 42 | .. autoclass:: NormConv2d 43 | 44 | .. autoclass:: Add2d 45 | 46 | .. autoclass:: SlimConv2d 47 | 48 | .. autoclass:: PyConv2d 49 | 50 | .. autoclass:: Involution2d 51 | 52 | Regularization layers 53 | --------------------- 54 | 55 | .. autoclass:: DropBlock2d 56 | 57 | 58 | Downsampling 59 | ------------ 60 | 61 | .. autoclass:: ConcatDownsample2d 62 | 63 | .. autoclass:: GlobalAvgPool2d 64 | 65 | .. autoclass:: GlobalMaxPool2d 66 | 67 | .. autoclass:: BlurPool2d 68 | 69 | .. autoclass:: SPP 70 | 71 | .. autoclass:: ZPool 72 | 73 | 74 | Attention 75 | --------- 76 | 77 | .. autoclass:: SAM 78 | 79 | .. autoclass:: LambdaLayer 80 | 81 | .. autoclass:: TripletAttention 82 | -------------------------------------------------------------------------------- /docs/source/notebooks.md: -------------------------------------------------------------------------------- 1 | ../../notebooks/README.md -------------------------------------------------------------------------------- /docs/source/ops.rst: -------------------------------------------------------------------------------- 1 | holocron.ops 2 | ============ 3 | 4 | .. currentmodule:: holocron.ops 5 | 6 | :mod:`holocron.ops` implements operators that are specific for Computer Vision. 7 | 8 | .. note:: 9 | Those operators currently do not support TorchScript. 10 | 11 | Boxes 12 | ----- 13 | 14 | .. autofunction:: box_giou 15 | .. autofunction:: diou_loss 16 | .. autofunction:: ciou_loss 17 | -------------------------------------------------------------------------------- /docs/source/optim.rst: -------------------------------------------------------------------------------- 1 | holocron.optim 2 | =============== 3 | 4 | .. automodule:: holocron.optim 5 | 6 | .. currentmodule:: holocron.optim 7 | 8 | To use :mod:`holocron.optim` you have to construct an optimizer object, that will hold 9 | the current state and will update the parameters based on the computed gradients. 10 | 11 | Optimizers 12 | ---------- 13 | 14 | Implementations of recent parameter optimizer for Pytorch modules. 15 | 16 | .. autoclass:: LARS 17 | 18 | .. autoclass:: LAMB 19 | 20 | .. autoclass:: RaLars 21 | 22 | .. autoclass:: TAdam 23 | 24 | .. autoclass:: AdaBelief 25 | 26 | .. autoclass:: AdamP 27 | 28 | .. autoclass:: Adan 29 | 30 | .. autoclass:: AdEMAMix 31 | 32 | 33 | Optimizer wrappers 34 | ------------------ 35 | 36 | :mod:`holocron.optim` also implements optimizer wrappers. 
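A slightly fuller sketch of the wrapper pattern shown just below, using ``Lookahead`` around a standard PyTorch optimizer; the wrapper's default hyper-parameters are assumed here:

.. code:: python

    import torch
    from torch import nn
    from holocron.optim.wrapper import Lookahead

    model = nn.Linear(10, 2)
    base_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    optimizer = Lookahead(base_optimizer)  # wraps the base optimizer (default settings assumed)

    # then use it like any other optimizer in the training loop
    loss = model(torch.rand(4, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()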
37 | 38 | A base optimizer should always be passed to the wrapper; e.g., you 39 | should write your code this way: 40 | 41 | >>> optimizer = ... 42 | >>> optimizer = wrapper(optimizer) 43 | 44 | .. autoclass:: holocron.optim.wrapper.Lookahead 45 | 46 | .. autoclass:: holocron.optim.wrapper.Scout 47 | -------------------------------------------------------------------------------- /docs/source/trainer.rst: -------------------------------------------------------------------------------- 1 | holocron.trainer 2 | ================ 3 | 4 | .. automodule:: holocron.trainer 5 | 6 | .. currentmodule:: holocron.trainer 7 | 8 | :mod:`holocron.trainer` provides some basic objects for training purposes. 9 | 10 | 11 | .. autoclass:: Trainer 12 | :members: 13 | 14 | 15 | Image classification 16 | -------------------- 17 | 18 | .. autoclass:: ClassificationTrainer 19 | :members: 20 | 21 | .. autoclass:: BinaryClassificationTrainer 22 | :members: 23 | 24 | 25 | Semantic segmentation 26 | --------------------- 27 | 28 | .. autoclass:: SegmentationTrainer 29 | :members: 30 | 31 | Object detection 32 | ---------------- 33 | 34 | .. autoclass:: DetectionTrainer 35 | :members: 36 | 37 | 38 | Miscellaneous 39 | ------------- 40 | 41 | .. autofunction:: freeze_bn 42 | 43 | .. autofunction:: freeze_model 44 | -------------------------------------------------------------------------------- /docs/source/transforms.rst: -------------------------------------------------------------------------------- 1 | holocron.transforms 2 | =================== 3 | 4 | .. automodule:: holocron.transforms 5 | 6 | .. currentmodule:: holocron.transforms 7 | 8 | :mod:`holocron.transforms` provides PIL and PyTorch tensor transformations. 9 | 10 | 11 | .. autoclass:: Resize 12 | :members: 13 | 14 | 15 | .. autoclass:: RandomZoomOut 16 | :members: 17 | -------------------------------------------------------------------------------- /docs/source/utils.data.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | holocron.utils.data 5 | =================== 6 | 7 | .. currentmodule:: holocron.utils.data 8 | 9 | 10 | Batch collate 11 | ------------- 12 | 13 | .. autofunction:: Mixup 14 | -------------------------------------------------------------------------------- /docs/source/utils.rst: -------------------------------------------------------------------------------- 1 | holocron.utils 2 | =============== 3 | 4 | .. automodule:: holocron.utils 5 | 6 | .. currentmodule:: holocron.utils 7 | 8 | :mod:`holocron.utils` provides some utilities for general usage. 9 | 10 | 11 | Miscellaneous 12 | ------------- 13 | 14 | .. autofunction:: find_image_size 15 | -------------------------------------------------------------------------------- /holocron/__init__.py: -------------------------------------------------------------------------------- 1 | from holocron import models, nn, ops, optim, trainer, transforms, utils 2 | 3 | from .version import __version__ 4 | -------------------------------------------------------------------------------- /holocron/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import detection, segmentation 2 | from .classification import * 3 | -------------------------------------------------------------------------------- /holocron/models/checkpoints.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import logging 7 | from dataclasses import dataclass 8 | from enum import Enum 9 | from typing import Dict, List, Tuple, Union 10 | 11 | from torchvision.transforms.functional import InterpolationMode 12 | 13 | __all__ = [ 14 | "Checkpoint", 15 | "Dataset", 16 | "Evaluation", 17 | "LoadingMeta", 18 | "Metric", 19 | "PreProcessing", 20 | "TrainingRecipe", 21 | ] 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | @dataclass 27 | class TrainingRecipe: 28 | """Implements a training recipe. 29 | 30 | Args: 31 | commit_hash: the commit that was used to train the model. 32 | args: the argument values that were passed to the reference script to train this. 33 | """ 34 | 35 | commit: Union[str, None] 36 | script: Union[str, None] 37 | args: Union[str, None] 38 | 39 | 40 | class Metric(str, Enum): 41 | """Evaluation metric""" 42 | 43 | TOP1_ACC = "top1-accuracy" 44 | TOP5_ACC = "top5-accuracy" 45 | 46 | 47 | class Dataset(str, Enum): 48 | """Training/evaluation dataset""" 49 | 50 | IMAGENET1K = "imagenet-1k" 51 | IMAGENETTE = "imagenette" 52 | CIFAR10 = "cifar10" 53 | 54 | 55 | @dataclass 56 | class Evaluation: 57 | """Results of model evaluation""" 58 | 59 | dataset: Dataset 60 | results: Dict[Metric, float] 61 | 62 | 63 | @dataclass 64 | class LoadingMeta: 65 | """Metadata to load the model""" 66 | 67 | url: str 68 | sha256: str 69 | size: int 70 | arch: str 71 | num_params: int 72 | categories: List[str] 73 | 74 | 75 | @dataclass 76 | class PreProcessing: 77 | """Preprocessing metadata for the model""" 78 | 79 | input_shape: Tuple[int, ...] 80 | mean: Tuple[float, ...] 81 | std: Tuple[float, ...] 
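    # Note (added for clarity): `mean`/`std` are the per-channel normalization statistics,
    # and `interpolation` below is the resize mode expected by this checkpoint.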
82 | interpolation: InterpolationMode = InterpolationMode.BILINEAR 83 | 84 | 85 | @dataclass 86 | class Checkpoint: 87 | """Data required to run a model in the exact same conditions as the checkpoint""" 88 | 89 | # What to expect 90 | evaluation: Evaluation 91 | # How to load it 92 | meta: LoadingMeta 93 | # How to use it 94 | pre_processing: PreProcessing 95 | # How to reproduce 96 | recipe: TrainingRecipe 97 | 98 | 99 | def _handle_legacy_pretrained( 100 | pretrained: bool = False, 101 | checkpoint: Union[Checkpoint, None] = None, 102 | default_checkpoint: Union[Checkpoint, None] = None, 103 | ) -> Union[Checkpoint, None]: 104 | checkpoint = checkpoint or (default_checkpoint if pretrained else None) 105 | 106 | if pretrained and checkpoint is None: 107 | logger.warning("Invalid model URL, using default initialization.") 108 | 109 | return checkpoint 110 | -------------------------------------------------------------------------------- /holocron/models/classification/__init__.py: -------------------------------------------------------------------------------- 1 | from .convnext import * 2 | from .darknet import * 3 | from .darknetv2 import * 4 | from .darknetv3 import * 5 | from .darknetv4 import * 6 | from .mobileone import * 7 | from .pyconv_resnet import * 8 | from .repvgg import * 9 | from .res2net import * 10 | from .resnet import * 11 | from .rexnet import * 12 | from .sknet import * 13 | from .tridentnet import * 14 | -------------------------------------------------------------------------------- /holocron/models/classification/darknet.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2020-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details.
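# Illustrative usage note (not part of the original module): the factory functions
# defined below follow the torchvision convention, so a typical call looks like
#   from holocron.models import darknet24
#   model = darknet24(pretrained=True).eval()
# where `pretrained=True` downloads the checkpoint referenced in `default_cfgs`.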
5 | 6 | from collections import OrderedDict 7 | from typing import Any, Callable, Dict, List, Optional 8 | 9 | import torch.nn as nn 10 | 11 | from holocron.nn import GlobalAvgPool2d 12 | from holocron.nn.init import init_module 13 | 14 | from ..presets import IMAGENETTE 15 | from ..utils import conv_sequence, load_pretrained_params 16 | 17 | __all__ = ["DarknetV1", "darknet24"] 18 | 19 | 20 | default_cfgs: Dict[str, Dict[str, Any]] = { 21 | "darknet24": { 22 | **IMAGENETTE.__dict__, 23 | "input_shape": (3, 224, 224), 24 | "url": "https://github.com/frgfm/Holocron/releases/download/v0.1.3/darknet24_224-816d72cb.pt", 25 | }, 26 | } 27 | 28 | 29 | class DarknetBodyV1(nn.Sequential): 30 | def __init__( 31 | self, 32 | layout: List[List[int]], 33 | in_channels: int = 3, 34 | stem_channels: int = 64, 35 | act_layer: Optional[nn.Module] = None, 36 | norm_layer: Optional[Callable[[int], nn.Module]] = None, 37 | drop_layer: Optional[Callable[..., nn.Module]] = None, 38 | conv_layer: Optional[Callable[..., nn.Module]] = None, 39 | ) -> None: 40 | if act_layer is None: 41 | act_layer = nn.LeakyReLU(0.1, inplace=True) 42 | 43 | in_chans = [stem_channels] + [_layout[-1] for _layout in layout[:-1]] 44 | 45 | super().__init__( 46 | OrderedDict([ 47 | ( 48 | "stem", 49 | nn.Sequential( 50 | *conv_sequence( 51 | in_channels, 52 | stem_channels, 53 | act_layer, 54 | norm_layer, 55 | drop_layer, 56 | conv_layer, 57 | kernel_size=7, 58 | padding=3, 59 | stride=2, 60 | bias=(norm_layer is None), 61 | ) 62 | ), 63 | ), 64 | ( 65 | "layers", 66 | nn.Sequential(*[ 67 | self._make_layer([_in_chans, *planes], act_layer, norm_layer, drop_layer, conv_layer) 68 | for _in_chans, planes in zip(in_chans, layout) 69 | ]), 70 | ), 71 | ]) 72 | ) 73 | init_module(self, "leaky_relu") 74 | 75 | @staticmethod 76 | def _make_layer( 77 | planes: List[int], 78 | act_layer: Optional[nn.Module] = None, 79 | norm_layer: Optional[Callable[[int], nn.Module]] = None, 80 | drop_layer: Optional[Callable[..., nn.Module]] = None, 81 | conv_layer: Optional[Callable[..., nn.Module]] = None, 82 | ) -> nn.Sequential: 83 | layers: List[nn.Module] = [nn.MaxPool2d(2)] 84 | k1 = True 85 | for in_planes, out_planes in zip(planes[:-1], planes[1:]): 86 | layers.extend( 87 | conv_sequence( 88 | in_planes, 89 | out_planes, 90 | act_layer, 91 | norm_layer, 92 | drop_layer, 93 | conv_layer, 94 | kernel_size=3 if out_planes > in_planes else 1, 95 | padding=1 if out_planes > in_planes else 0, 96 | bias=(norm_layer is None), 97 | ) 98 | ) 99 | k1 = not k1 100 | 101 | return nn.Sequential(*layers) 102 | 103 | 104 | class DarknetV1(nn.Sequential): 105 | def __init__( 106 | self, 107 | layout: List[List[int]], 108 | num_classes: int = 10, 109 | in_channels: int = 3, 110 | stem_channels: int = 64, 111 | act_layer: Optional[nn.Module] = None, 112 | norm_layer: Optional[Callable[[int], nn.Module]] = None, 113 | drop_layer: Optional[Callable[..., nn.Module]] = None, 114 | conv_layer: Optional[Callable[..., nn.Module]] = None, 115 | ) -> None: 116 | super().__init__( 117 | OrderedDict([ 118 | ( 119 | "features", 120 | DarknetBodyV1(layout, in_channels, stem_channels, act_layer, norm_layer, drop_layer, conv_layer), 121 | ), 122 | ("pool", GlobalAvgPool2d(flatten=True)), 123 | ("classifier", nn.Linear(layout[2][-1], num_classes)), 124 | ]) 125 | ) 126 | 127 | init_module(self, "leaky_relu") 128 | 129 | 130 | def _darknet(arch: str, pretrained: bool, progress: bool, layout: List[List[int]], **kwargs: Any) -> DarknetV1: 131 | # Build the model 132 | model = 
DarknetV1(layout, **kwargs) 133 | model.default_cfg = default_cfgs[arch] # type: ignore[assignment] 134 | # Load pretrained parameters 135 | if pretrained: 136 | load_pretrained_params(model, default_cfgs[arch]["url"], progress) 137 | 138 | return model 139 | 140 | 141 | def darknet24(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DarknetV1: 142 | """Darknet-24 from 143 | `"You Only Look Once: Unified, Real-Time Object Detection" `_ 144 | 145 | Args: 146 | pretrained (bool): If True, returns a model pre-trained on ImageNet 147 | progress (bool): If True, displays a progress bar of the download to stderr 148 | kwargs: keyword args of _darknet 149 | 150 | Returns: 151 | torch.nn.Module: classification model 152 | """ 153 | return _darknet( 154 | "darknet24", 155 | pretrained, 156 | progress, 157 | [[192], [128, 256, 256, 512], [*([256, 512] * 4), 512, 1024], [512, 1024] * 2], 158 | **kwargs, 159 | ) 160 | -------------------------------------------------------------------------------- /holocron/models/detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .yolo import * 2 | from .yolov2 import * 3 | from .yolov4 import * 4 | -------------------------------------------------------------------------------- /holocron/models/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet import * 2 | from .unet3p import * 3 | from .unetpp import * 4 | -------------------------------------------------------------------------------- /holocron/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from . import init 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /holocron/nn/init.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import torch.nn as nn 7 | from torch.nn.modules.conv import _ConvNd 8 | 9 | 10 | def init_module(module: nn.Module, nonlinearity: str = "relu") -> None: 11 | """Initializes pytorch modules. 12 | 13 | Args: 14 | module: module to initialize 15 | nonlinearity: linearity to initialize convolutions for 16 | """ 17 | for m in module.modules(): 18 | if isinstance(m, _ConvNd): 19 | nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity=nonlinearity) 20 | if m.bias is not None: 21 | m.bias.data.zero_() 22 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 23 | m.weight.data.fill_(1.0) 24 | m.bias.data.zero_() 25 | -------------------------------------------------------------------------------- /holocron/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .activation import * 2 | from .attention import * 3 | from .conv import * 4 | from .downsample import * 5 | from .dropblock import * 6 | from .lambda_layer import * 7 | from .loss import * 8 | -------------------------------------------------------------------------------- /holocron/nn/modules/activation.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
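# Illustrative usage note (shapes and values below are examples, not part of the original
# module): these activations are drop-in replacements for torch.nn modules, e.g.
#   act = HardMish()
#   out = act(torch.randn(2, 8, 32, 32))
# FReLU is the exception: it is parameterized, so it requires the channel count, e.g. FReLU(8).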
5 | 6 | from typing import ClassVar, List 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch import Tensor 11 | 12 | from .. import functional as F 13 | 14 | __all__ = ["FReLU", "HardMish", "NLReLU"] 15 | 16 | 17 | class _Activation(nn.Module): 18 | __constants__: ClassVar[List[str]] = ["inplace"] 19 | 20 | def __init__(self, inplace: bool = False) -> None: 21 | super().__init__() 22 | self.inplace = inplace 23 | 24 | def extra_repr(self) -> str: 25 | return "inplace=True" if self.inplace else "" 26 | 27 | 28 | class HardMish(_Activation): 29 | r"""Implements the Hard Mish activation module from `"H-Mish" `_. 30 | 31 | This activation is computed as follows: 32 | 33 | .. math:: 34 | f(x) = \frac{x}{2} \cdot \min(2, \max(0, x + 2)) 35 | """ 36 | 37 | def forward(self, x: Tensor) -> Tensor: 38 | return F.hard_mish(x, inplace=self.inplace) 39 | 40 | 41 | class NLReLU(_Activation): 42 | r"""Implements the Natural-Logarithm ReLU activation module from `"Natural-Logarithm-Rectified Activation 43 | Function in Convolutional Neural Networks" `_. 44 | 45 | This activation is computed as follows: 46 | 47 | .. math:: 48 | f(x) = ln(1 + \beta \cdot max(0, x)) 49 | 50 | Args: 51 | inplace (bool): should the operation be performed inplace 52 | """ 53 | 54 | def forward(self, x: Tensor) -> Tensor: 55 | return F.nl_relu(x, inplace=self.inplace) 56 | 57 | 58 | class FReLU(nn.Module): 59 | r"""Implements the Funnel activation module from `"Funnel Activation for Visual Recognition" 60 | `_. 61 | 62 | This activation is computed as follows: 63 | 64 | .. math:: 65 | f(x) = max(\mathbb{T}(x), x) 66 | 67 | where :math:`\mathbb{T}` is the spatial contextual feature extraction. It is a convolution filter of size 68 | `kernel_size`, same padding and groups equal to the number of input channels, followed by a batch normalization. 69 | 70 | Args: 71 | in_channels (int): number of input channels 72 | """ 73 | 74 | def __init__(self, in_channels: int, kernel_size: int = 3) -> None: 75 | super().__init__() 76 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding=kernel_size // 2, groups=in_channels) 77 | self.bn = nn.BatchNorm2d(in_channels) 78 | 79 | def forward(self, x: Tensor) -> Tensor: 80 | out = self.conv(x) 81 | out = self.bn(out) 82 | return torch.max(x, out) 83 | -------------------------------------------------------------------------------- /holocron/nn/modules/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import cast 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch import Tensor 11 | 12 | from .downsample import ZPool 13 | 14 | __all__ = ["SAM", "TripletAttention"] 15 | 16 | 17 | class SAM(nn.Module): 18 | """SAM layer from `"CBAM: Convolutional Block Attention Module" `_ 19 | modified in `"YOLOv4: Optimal Speed and Accuracy of Object Detection" `_.
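    A minimal usage sketch (the tensor shape is illustrative):

    >>> import torch
    >>> from holocron.nn import SAM
    >>> mod = SAM(64)
    >>> out = mod(torch.rand(2, 64, 32, 32))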
20 | 21 | Args: 22 | in_channels (int): input channels 23 | """ 24 | 25 | def __init__(self, in_channels: int) -> None: 26 | super().__init__() 27 | self.conv = nn.Conv2d(in_channels, 1, 1) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | return x * torch.sigmoid(self.conv(x)) 31 | 32 | 33 | class DimAttention(nn.Module): 34 | """Attention layer across a specific dimension 35 | 36 | Args: 37 | dim: dimension to compute attention on 38 | """ 39 | 40 | def __init__(self, dim: int) -> None: 41 | super().__init__() 42 | self.compress = nn.Sequential( 43 | ZPool(dim=1), 44 | nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3, bias=False), 45 | nn.BatchNorm2d(1, eps=1e-5, momentum=0.01), 46 | nn.Sigmoid(), 47 | ) 48 | self.dim = dim 49 | 50 | def forward(self, x: Tensor) -> Tensor: 51 | if self.dim != 1: 52 | x = x.transpose(self.dim, 1).contiguous() 53 | out = cast(Tensor, x * self.compress(x)) 54 | if self.dim != 1: 55 | out = out.transpose(self.dim, 1).contiguous() 56 | return out 57 | 58 | 59 | class TripletAttention(nn.Module): 60 | """Triplet attention layer from `"Rotate to Attend: Convolutional Triplet Attention Module" 61 | `_. This implementation is based on the 62 | `one `_ 63 | from the paper's authors. 64 | """ 65 | 66 | def __init__(self) -> None: 67 | super().__init__() 68 | self.c_branch = DimAttention(dim=1) 69 | self.h_branch = DimAttention(dim=2) 70 | self.w_branch = DimAttention(dim=3) 71 | 72 | def forward(self, x: Tensor) -> Tensor: 73 | x_c = cast(Tensor, self.c_branch(x)) 74 | x_h = cast(Tensor, self.h_branch(x)) 75 | x_w = cast(Tensor, self.w_branch(x)) 76 | 77 | return (x_c + x_h + x_w) / 3 78 | -------------------------------------------------------------------------------- /holocron/nn/modules/dropblock.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import torch.nn as nn 7 | from torch import Tensor 8 | 9 | from .. import functional as F 10 | 11 | __all__ = ["DropBlock2d"] 12 | 13 | 14 | class DropBlock2d(nn.Module): 15 | """Implements the DropBlock module from `"DropBlock: A regularization method for convolutional networks" 16 | `_ 17 | 18 | .. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/dropblock.png 19 | :align: center 20 | 21 | Args: 22 | p (float, optional): probability of dropping activation value 23 | block_size (int, optional): size of each block that is expended from the sampled mask 24 | inplace (bool, optional): whether the operation should be done inplace 25 | """ 26 | 27 | def __init__(self, p: float = 0.1, block_size: int = 7, inplace: bool = False) -> None: 28 | super().__init__() 29 | self.p = p 30 | self.block_size = block_size 31 | self.inplace = inplace 32 | 33 | @property 34 | def drop_prob(self) -> float: 35 | return self.p / self.block_size**2 36 | 37 | def forward(self, x: Tensor) -> Tensor: 38 | return F.dropblock2d(x, self.drop_prob, self.block_size, self.inplace, self.training) 39 | 40 | def extra_repr(self) -> str: 41 | return f"p={self.p}, block_size={self.block_size}, inplace={self.inplace}" 42 | -------------------------------------------------------------------------------- /holocron/nn/modules/lambda_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 
2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Optional 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import einsum, nn 11 | 12 | __all__ = ["LambdaLayer"] 13 | 14 | 15 | class LambdaLayer(nn.Module): 16 | """Lambda layer from `"LambdaNetworks: Modeling long-range interactions without attention" 17 | `_. The implementation was adapted from `lucidrains' 18 | `_. 19 | 20 | .. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/lambdalayer.png 21 | :align: center 22 | 23 | Args: 24 | in_channels (int): input channels 25 | out_channels (int, optional): output channels 26 | dim_k (int): key dimension 27 | n (int, optional): number of input pixels 28 | r (int, optional): receptive field for relative positional encoding 29 | num_heads (int, optional): number of attention heads 30 | dim_u (int, optional): intra-depth dimension 31 | """ 32 | 33 | def __init__( 34 | self, 35 | in_channels: int, 36 | out_channels: int, 37 | dim_k: int, 38 | n: Optional[int] = None, 39 | r: Optional[int] = None, 40 | num_heads: int = 4, 41 | dim_u: int = 1, 42 | ) -> None: 43 | super().__init__() 44 | self.u = dim_u 45 | self.num_heads = num_heads 46 | 47 | if out_channels % num_heads != 0: 48 | raise AssertionError("values dimension must be divisible by number of heads for multi-head query") 49 | dim_v = out_channels // num_heads 50 | 51 | # Project input and context to get queries, keys & values 52 | self.to_q = nn.Conv2d(in_channels, dim_k * num_heads, 1, bias=False) 53 | self.to_k = nn.Conv2d(in_channels, dim_k * dim_u, 1, bias=False) 54 | self.to_v = nn.Conv2d(in_channels, dim_v * dim_u, 1, bias=False) 55 | 56 | self.norm_q = nn.BatchNorm2d(dim_k * num_heads) 57 | self.norm_v = nn.BatchNorm2d(dim_v * dim_u) 58 | 59 | self.local_contexts = r is not None 60 | if r is not None: 61 | if r % 2 != 1: 62 | raise AssertionError("Receptive kernel size should be odd") 63 | self.padding = r // 2 64 | self.R = nn.Parameter(torch.randn(dim_k, dim_u, 1, r, r)) 65 | else: 66 | if n is None: 67 | raise AssertionError("You must specify the total sequence length (h x w)") 68 | self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u)) 69 | 70 | def forward(self, x: torch.Tensor) -> torch.Tensor: 71 | b, _, h, w = x.shape 72 | 73 | # Project inputs & context to retrieve queries, keys and values 74 | q = self.to_q(x) 75 | k = self.to_k(x) 76 | v = self.to_v(x) 77 | 78 | # Normalize queries & values 79 | q = self.norm_q(q) 80 | v = self.norm_v(v) 81 | 82 | # B x (num_heads * dim_k) * H * W -> B x num_heads x dim_k x (H * W) 83 | q = q.reshape(b, self.num_heads, -1, h * w) 84 | # B x (dim_k * dim_u) * H * W -> B x dim_u x dim_k x (H * W) 85 | k = k.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3) 86 | # B x (dim_v * dim_u) * H * W -> B x dim_u x dim_v x (H * W) 87 | v = v.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3) 88 | 89 | # Normalized keys 90 | k = k.softmax(dim=-1) 91 | 92 | # Content function 93 | λc = einsum("b u k m, b u v m -> b k v", k, v) 94 | Yc = einsum("b h k n, b k v -> b n h v", q, λc) 95 | 96 | # Position function 97 | if self.local_contexts: 98 | # B x dim_u x dim_v x (H * W) -> B x dim_u x dim_v x H x W 99 | v = v.reshape(b, self.u, v.shape[2], h, w) 100 | λp = F.conv3d(v, self.R, padding=(0, self.padding, self.padding)) 101 | Yp = einsum("b h k n, b k v n -> b n h v", q, λp.flatten(3)) 102 | else: 103 | λp = einsum("n m k u, b u v m -> b n k v", self.pos_emb, v) 104 | 
Yp = einsum("b h k n, b n k v -> b n h v", q, λp) 105 | 106 | Y = Yc + Yp 107 | # B x (H * W) x num_heads x dim_v -> B x (num_heads * dim_v) x H x W 108 | return Y.permute(0, 2, 3, 1).reshape(b, self.num_heads * v.shape[2], h, w) 109 | -------------------------------------------------------------------------------- /holocron/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .boxes import * 2 | -------------------------------------------------------------------------------- /holocron/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from . import wrapper 2 | from .adabelief import AdaBelief 3 | from .adamp import AdamP 4 | from .adan import Adan 5 | from .ademamix import AdEMAMix 6 | from .lamb import LAMB 7 | from .lars import LARS 8 | from .ralars import RaLars 9 | from .tadam import TAdam 10 | -------------------------------------------------------------------------------- /holocron/optim/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Callable, Dict, Iterable, Optional, Tuple 7 | 8 | import torch 9 | from torch.optim.optimizer import Optimizer 10 | 11 | __all__ = ["LARS"] 12 | 13 | 14 | class LARS(Optimizer): 15 | r"""Implements the LARS optimizer from `"Large batch training of convolutional networks" 16 | `_. 17 | 18 | The estimation of global and local learning rates is described as follows, :math:`\forall t \geq 1`: 19 | 20 | .. math:: 21 | \alpha_t \leftarrow \alpha (1 - t / T)^2 \\ 22 | \gamma_t \leftarrow \frac{\lVert \theta_t \rVert}{\lVert g_t \rVert + \lambda \lVert \theta_t \rVert} 23 | 24 | where :math:`\theta_t` is the parameter value at step :math:`t` (:math:`\theta_0` being the initialization value), 25 | :math:`g_t` is the gradient of :math:`\theta_t`, 26 | :math:`T` is the total number of steps, 27 | :math:`\alpha` is the learning rate 28 | :math:`\lambda \geq 0` is the weight decay. 29 | 30 | Then we estimate the momentum using: 31 | 32 | .. math:: 33 | v_t \leftarrow m v_{t-1} + \alpha_t \gamma_t (g_t + \lambda \theta_t) 34 | 35 | where :math:`m` is the momentum and :math:`v_0 = 0`. 36 | 37 | And finally the update step is performed using the following rule: 38 | 39 | .. 
math:: 40 | \theta_t \leftarrow \theta_{t-1} - v_t 41 | 42 | Args: 43 | params (iterable): iterable of parameters to optimize or dicts defining 44 | parameter groups 45 | lr (float, optional): learning rate 46 | momentum (float, optional): momentum factor (default: 0) 47 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 48 | dampening (float, optional): dampening for momentum (default: 0) 49 | nesterov (bool, optional): enables Nesterov momentum (default: False) 50 | scale_clip (tuple, optional): the lower and upper bounds for the weight norm in local LR of LARS 51 | """ 52 | 53 | def __init__( 54 | self, 55 | params: Iterable[torch.nn.Parameter], 56 | lr: float = 1e-3, 57 | momentum: float = 0.0, 58 | dampening: float = 0.0, 59 | weight_decay: float = 0.0, 60 | nesterov: bool = False, 61 | scale_clip: Optional[Tuple[float, float]] = None, 62 | ) -> None: 63 | if not isinstance(lr, float) or lr < 0.0: 64 | raise ValueError(f"Invalid learning rate: {lr}") 65 | if momentum < 0.0: 66 | raise ValueError(f"Invalid momentum value: {momentum}") 67 | if weight_decay < 0.0: 68 | raise ValueError(f"Invalid weight_decay value: {weight_decay}") 69 | 70 | defaults = { 71 | "lr": lr, 72 | "momentum": momentum, 73 | "dampening": dampening, 74 | "weight_decay": weight_decay, 75 | "nesterov": nesterov, 76 | } 77 | if nesterov and (momentum <= 0 or dampening != 0): 78 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 79 | super().__init__(params, defaults) 80 | # LARS arguments 81 | self.scale_clip = scale_clip 82 | if self.scale_clip is None: 83 | self.scale_clip = (0.0, 10.0) 84 | 85 | def __setstate__(self, state: Dict[str, torch.Tensor]) -> None: 86 | super().__setstate__(state) 87 | for group in self.param_groups: 88 | group.setdefault("nesterov", False) 89 | 90 | @torch.no_grad() 91 | def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: # type: ignore[override] 92 | """Performs a single optimization step. 93 | 94 | Arguments: 95 | closure (callable, optional): A closure that reevaluates the model and returns the loss. 
96 | """ 97 | loss = None 98 | if closure is not None: 99 | with torch.enable_grad(): 100 | loss = closure() 101 | 102 | for group in self.param_groups: 103 | weight_decay = group["weight_decay"] 104 | momentum = group["momentum"] 105 | dampening = group["dampening"] 106 | nesterov = group["nesterov"] 107 | 108 | for p in group["params"]: 109 | if p.grad is None: 110 | continue 111 | d_p = p.grad.data 112 | 113 | # LARS 114 | p_norm = torch.norm(p.data) 115 | denom = torch.norm(d_p) 116 | if weight_decay != 0: 117 | d_p.add_(p.data, alpha=weight_decay) 118 | denom.add_(p_norm, alpha=weight_decay) 119 | # Compute the local LR 120 | local_lr = 1 if p_norm == 0 or denom == 0 else p_norm / denom 121 | 122 | if momentum == 0: 123 | p.data.add_(d_p, alpha=-group["lr"] * local_lr) 124 | else: 125 | param_state = self.state[p] 126 | if "momentum_buffer" not in param_state: 127 | momentum_buffer = param_state["momentum_buffer"] = torch.clone(d_p).detach() 128 | else: 129 | momentum_buffer = param_state["momentum_buffer"] 130 | momentum_buffer.mul_(momentum).add_(d_p, alpha=1 - dampening) 131 | d_p = d_p.add(momentum_buffer, alpha=momentum) if nesterov else momentum_buffer 132 | p.data.add_(d_p, alpha=-group["lr"] * local_lr) 133 | self.state[p]["momentum_buffer"] = momentum_buffer 134 | 135 | return loss 136 | -------------------------------------------------------------------------------- /holocron/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | from .classification import * 3 | from .detection import * 4 | from .segmentation import * 5 | from .utils import * 6 | -------------------------------------------------------------------------------- /holocron/trainer/detection.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Dict, List, Optional, Tuple 7 | 8 | import torch 9 | from torch import Tensor 10 | from torchvision.ops.boxes import box_iou 11 | 12 | from .core import Trainer 13 | 14 | __all__ = ["DetectionTrainer"] 15 | 16 | 17 | def assign_iou(gt_boxes: Tensor, pred_boxes: Tensor, iou_threshold: float = 0.5) -> Tuple[List[int], List[int]]: 18 | """Assigns boxes by IoU""" 19 | iou = box_iou(gt_boxes, pred_boxes) 20 | iou = iou.max(dim=1) 21 | gt_kept = iou.values >= iou_threshold 22 | assign_unique = torch.unique(iou.indices[gt_kept]) 23 | # Filter 24 | if iou.indices[gt_kept].shape[0] == assign_unique.shape[0]: 25 | return torch.arange(gt_boxes.shape[0])[gt_kept], iou.indices[gt_kept] # type: ignore[return-value] 26 | 27 | gt_indices, pred_indices = [], [] 28 | for pred_idx in assign_unique: 29 | selection = iou.values[gt_kept][iou.indices[gt_kept] == pred_idx].argmax() 30 | gt_indices.append(torch.arange(gt_boxes.shape[0])[gt_kept][selection].item()) 31 | pred_indices.append(iou.indices[gt_kept][selection].item()) 32 | return gt_indices, pred_indices # type: ignore[return-value] 33 | 34 | 35 | class DetectionTrainer(Trainer): 36 | """Object detection trainer class. 
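    A construction sketch (illustrative values; the model, loaders, criterion and optimizer, as well as the
    `output_file` path, are assumptions for the example):

    >>> trainer = DetectionTrainer(model, train_loader, val_loader, criterion, optimizer, gpu=0, output_file="./checkpoint.pt")
    >>> trainer.evaluate(iou_threshold=0.5)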
37 | 38 | Args: 39 | model: model to train 40 | train_loader: training loader 41 | val_loader: validation loader 42 | criterion: loss criterion 43 | optimizer: parameter optimizer 44 | gpu: index of the GPU to use 45 | output_file: path where checkpoints will be saved 46 | amp: whether to use automatic mixed precision 47 | skip_nan_loss: whether the optimizer step should be skipped when the loss is NaN 48 | nan_tolerance: number of consecutive batches with NaN loss before stopping the training 49 | gradient_acc: number of batches to accumulate the gradient of before performing the update step 50 | gradient_clip: the gradient clip value 51 | on_epoch_end: callback triggered at the end of an epoch 52 | """ 53 | 54 | @staticmethod 55 | def _to_cuda( # type: ignore[override] 56 | x: List[Tensor], target: List[Dict[str, Tensor]] 57 | ) -> Tuple[List[Tensor], List[Dict[str, Tensor]]]: 58 | """Move input and target to GPU""" 59 | x = [_x.cuda(non_blocking=True) for _x in x] 60 | target = [{k: v.cuda(non_blocking=True) for k, v in t.items()} for t in target] 61 | return x, target 62 | 63 | def _get_loss(self, x: List[Tensor], target: List[Dict[str, Tensor]]) -> Tensor: # type: ignore[override] 64 | # AMP 65 | if self.amp: 66 | with torch.cuda.amp.autocast(): 67 | # Forward & loss computation 68 | loss_dict = self.model(x, target) 69 | return sum(loss_dict.values()) 70 | # Forward & loss computation 71 | loss_dict = self.model(x, target) 72 | return sum(loss_dict.values()) 73 | 74 | @staticmethod 75 | def _eval_metrics_str(eval_metrics: Dict[str, Optional[float]]) -> str: 76 | loc_str = f"{eval_metrics['loc_err']:.2%}" if isinstance(eval_metrics["loc_err"], float) else "N/A" 77 | clf_str = f"{eval_metrics['clf_err']:.2%}" if isinstance(eval_metrics["clf_err"], float) else "N/A" 78 | det_str = f"{eval_metrics['det_err']:.2%}" if isinstance(eval_metrics["det_err"], float) else "N/A" 79 | return f"Loc error: {loc_str} | Clf error: {clf_str} | Det error: {det_str}" 80 | 81 | @torch.inference_mode() 82 | def evaluate(self, iou_threshold: float = 0.5) -> Dict[str, Optional[float]]: 83 | """Evaluate the model on the validation set. 
84 | 85 | Args: 86 | iou_threshold (float, optional): IoU threshold for pair assignment 87 | 88 | Returns: 89 | dict: evaluation metrics 90 | """ 91 | self.model.eval() 92 | 93 | loc_assigns = 0 94 | correct, clf_error, loc_fn, loc_fp, num_samples = 0, 0, 0, 0, 0 95 | 96 | for x, target in self.val_loader: 97 | x, target = self.to_cuda(x, target) 98 | 99 | if self.amp: 100 | with torch.cuda.amp.autocast(): 101 | detections = self.model(x) 102 | else: 103 | detections = self.model(x) 104 | 105 | for dets, t in zip(detections, target): 106 | if t["boxes"].shape[0] > 0 and dets["boxes"].shape[0] > 0: 107 | gt_indices, pred_indices = assign_iou(t["boxes"], dets["boxes"], iou_threshold) 108 | loc_assigns += len(gt_indices) 109 | correct_ = (t["labels"][gt_indices] == dets["labels"][pred_indices]).sum().item() 110 | else: 111 | gt_indices, pred_indices = [], [] 112 | correct_ = 0 113 | correct += correct_ 114 | clf_error += len(gt_indices) - correct_ 115 | loc_fn += t["boxes"].shape[0] - len(gt_indices) 116 | loc_fp += dets["boxes"].shape[0] - len(pred_indices) 117 | num_samples += sum(t["boxes"].shape[0] for t in target) 118 | 119 | nb_preds = num_samples - loc_fn + loc_fp 120 | # Localization 121 | loc_err = 1 - 2 * loc_assigns / (nb_preds + num_samples) if nb_preds + num_samples > 0 else None 122 | # Classification 123 | clf_err = 1 - correct / loc_assigns if loc_assigns > 0 else None 124 | # End-to-end 125 | det_err = 1 - 2 * correct / (nb_preds + num_samples) if nb_preds + num_samples > 0 else None 126 | return {"loc_err": loc_err, "clf_err": clf_err, "det_err": det_err, "val_loss": loc_err} 127 | -------------------------------------------------------------------------------- /holocron/trainer/segmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Any, Dict 7 | 8 | import torch 9 | 10 | from .core import Trainer 11 | 12 | __all__ = ["SegmentationTrainer"] 13 | 14 | 15 | class SegmentationTrainer(Trainer): 16 | """Semantic segmentation trainer class. 
17 | 18 | Args: 19 | model: model to train 20 | train_loader: training loader 21 | val_loader: validation loader 22 | criterion: loss criterion 23 | optimizer: parameter optimizer 24 | gpu: index of the GPU to use 25 | output_file: path where checkpoints will be saved 26 | amp: whether to use automatic mixed precision 27 | skip_nan_loss: whether the optimizer step should be skipped when the loss is NaN 28 | nan_tolerance: number of consecutive batches with NaN loss before stopping the training 29 | gradient_acc: number of batches to accumulate the gradient of before performing the update step 30 | gradient_clip: the gradient clip value 31 | on_epoch_end: callback triggered at the end of an epoch 32 | """ 33 | 34 | def __init__(self, *args: Any, num_classes: int = 10, **kwargs: Any) -> None: 35 | super().__init__(*args, **kwargs) 36 | self.num_classes = num_classes 37 | 38 | @torch.inference_mode() 39 | def evaluate(self, ignore_index: int = 255) -> Dict[str, float]: 40 | """Evaluate the model on the validation set 41 | 42 | Args: 43 | ignore_index (int, optional): index of the class to ignore in evaluation 44 | 45 | Returns: 46 | dict: evaluation metrics 47 | """ 48 | self.model.eval() 49 | 50 | val_loss, mean_iou, num_valid_batches = 0.0, 0.0, 0 51 | conf_mat = torch.zeros( 52 | (self.num_classes, self.num_classes), dtype=torch.int64, device=next(self.model.parameters()).device 53 | ) 54 | for x, target in self.val_loader: 55 | x, target = self.to_cuda(x, target) 56 | 57 | loss, out = self._get_loss(x, target, return_logits=True) 58 | 59 | # Safeguard for NaN loss 60 | if not torch.isnan(loss) and not torch.isinf(loss): 61 | val_loss += loss.item() 62 | num_valid_batches += 1 63 | 64 | # borrowed from https://github.com/pytorch/vision/blob/master/references/segmentation/train.py 65 | pred = out.argmax(dim=1).flatten() 66 | target = target.flatten() 67 | k = (target >= 0) & (target < self.num_classes) 68 | inds = self.num_classes * target[k].to(torch.int64) + pred[k] 69 | nc = self.num_classes 70 | conf_mat += torch.bincount(inds, minlength=nc**2).reshape(nc, nc) 71 | 72 | val_loss /= num_valid_batches 73 | acc_global = (torch.diag(conf_mat).sum() / conf_mat.sum()).item() 74 | mean_iou = (torch.diag(conf_mat) / (conf_mat.sum(1) + conf_mat.sum(0) - torch.diag(conf_mat))).mean().item() 75 | 76 | return {"val_loss": val_loss, "acc_global": acc_global, "mean_iou": mean_iou} 77 | 78 | @staticmethod 79 | def _eval_metrics_str(eval_metrics: Dict[str, float]) -> str: 80 | return ( 81 | f"Validation loss: {eval_metrics['val_loss']:.4} " 82 | f"(Acc: {eval_metrics['acc_global']:.2%} | Mean IoU: {eval_metrics['mean_iou']:.2%})" 83 | ) 84 | -------------------------------------------------------------------------------- /holocron/trainer/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
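# Illustrative note (not part of the original module): `split_normalization_params` below is
# typically used to disable weight decay on normalization layers, e.g.
#   norm_params, other_params = split_normalization_params(model)
#   optimizer = torch.optim.SGD(
#       [{"params": norm_params, "weight_decay": 0.0}, {"params": other_params, "weight_decay": 1e-4}],
#       lr=0.1,
#   )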
5 | 6 | from typing import List, Optional, Tuple 7 | 8 | from torch import nn 9 | from torch.nn.modules.batchnorm import _BatchNorm 10 | 11 | __all__ = ["freeze_bn", "freeze_model", "split_normalization_params"] 12 | 13 | 14 | def freeze_bn(mod: nn.Module) -> None: 15 | """Prevents parameter and stats from updating in Batchnorm layers that are frozen 16 | 17 | >>> from holocron.models import rexnet1_0x 18 | >>> from holocron.trainer.utils import freeze_bn 19 | >>> model = rexnet1_0x() 20 | >>> freeze_bn(model) 21 | 22 | Args: 23 | mod (torch.nn.Module): model to train 24 | """ 25 | # Loop on modules 26 | for m in mod.modules(): 27 | if isinstance(m, _BatchNorm) and m.affine and all(not p.requires_grad for p in m.parameters()): 28 | # Switch back to commented code when https://github.com/pytorch/pytorch/issues/37823 is resolved 29 | m.track_running_stats = False 30 | m.eval() 31 | 32 | 33 | def freeze_model( 34 | model: nn.Module, 35 | last_frozen_layer: Optional[str] = None, 36 | frozen_bn_stat_update: bool = False, 37 | ) -> None: 38 | """Freeze a specific range of model layers. 39 | 40 | >>> from holocron.models import rexnet1_0x 41 | >>> from holocron.trainer.utils import freeze_model 42 | >>> model = rexnet1_0x() 43 | >>> freeze_model(model) 44 | 45 | Args: 46 | model (torch.nn.Module): model to train 47 | last_frozen_layer (str, optional): last layer to freeze. Assumes layers have been registered in forward order 48 | frozen_bn_stat_update (bool, optional): force stats update in BN layers that are frozen 49 | """ 50 | # Unfreeze everything 51 | for p in model.parameters(): 52 | p.requires_grad_(True) 53 | 54 | # Loop on parameters 55 | if isinstance(last_frozen_layer, str): 56 | layer_reached = False 57 | for n, p in model.named_parameters(): 58 | if not layer_reached or n.startswith(last_frozen_layer): 59 | p.requires_grad_(False) 60 | if n.startswith(last_frozen_layer): 61 | layer_reached = True 62 | # Once the last param of the layer is frozen, we break 63 | elif layer_reached: 64 | break 65 | if not layer_reached: 66 | raise ValueError(f"Unable to locate child module {last_frozen_layer}") 67 | 68 | # Loop on modules 69 | if not frozen_bn_stat_update: 70 | freeze_bn(model) 71 | 72 | 73 | def split_normalization_params( 74 | model: nn.Module, 75 | norm_classes: Optional[List[type]] = None, 76 | ) -> Tuple[List[nn.Parameter], List[nn.Parameter]]: 77 | """Split the param groups by normalization schemes""" 78 | # Borrowed from https://github.com/pytorch/vision/blob/main/torchvision/ops/_utils.py 79 | # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501 80 | if not norm_classes: 81 | norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm] 82 | 83 | for t in norm_classes: 84 | if not issubclass(t, nn.Module): 85 | raise ValueError(f"Class {t} is not a subclass of nn.Module.") 86 | 87 | classes = tuple(norm_classes) 88 | 89 | norm_params: List[nn.Parameter] = [] 90 | other_params: List[nn.Parameter] = [] 91 | for module in model.modules(): 92 | if next(module.children(), None): 93 | other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad) 94 | elif isinstance(module, classes): 95 | norm_params.extend(p for p in module.parameters() if p.requires_grad) 96 | else: 97 | other_params.extend(p for p in module.parameters() if p.requires_grad) 98 | return norm_params, other_params 99 | -------------------------------------------------------------------------------- 
/holocron/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .interpolation import * 2 | -------------------------------------------------------------------------------- /holocron/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data 2 | from .misc import * 3 | -------------------------------------------------------------------------------- /holocron/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .collate import * 2 | -------------------------------------------------------------------------------- /holocron/utils/data/collate.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Tuple 7 | 8 | import torch 9 | from torch import Tensor 10 | from torch.distributions.beta import Beta 11 | from torch.nn.functional import one_hot 12 | 13 | __all__ = ["Mixup"] 14 | 15 | 16 | class Mixup(torch.nn.Module): 17 | """Implements a batch collate function with MixUp strategy from 18 | `"mixup: Beyond Empirical Risk Minimization" `_. 19 | 20 | >>> import torch 21 | >>> from torch.utils.data._utils.collate import default_collate 22 | >>> from holocron.utils.data import Mixup 23 | >>> mix = Mixup(num_classes=10, alpha=0.4) 24 | >>> loader = torch.utils.data.DataLoader(dataset, batch_size, collate_fn=lambda b: mix(*default_collate(b))) 25 | 26 | Args: 27 | num_classes: number of expected classes 28 | alpha: mixup factor 29 | """ 30 | 31 | def __init__(self, num_classes: int, alpha: float = 0.2) -> None: 32 | super().__init__() 33 | self.num_classes = num_classes 34 | if alpha < 0: 35 | raise ValueError("`alpha` only takes positive values") 36 | self.alpha = alpha 37 | 38 | def forward(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: 39 | # Convert target to one-hot 40 | if targets.ndim == 1: 41 | # (N,) --> (N, C) 42 | if self.num_classes > 1: 43 | targets = one_hot(targets, num_classes=self.num_classes) 44 | elif self.num_classes == 1: 45 | targets = targets.unsqueeze(1) 46 | targets = targets.to(dtype=inputs.dtype) 47 | 48 | # Sample lambda 49 | if self.alpha == 0: 50 | return inputs, targets 51 | lam = Beta(self.alpha, self.alpha).sample() 52 | 53 | # Mix batch indices 54 | batch_size = inputs.size()[0] 55 | index = torch.randperm(batch_size) 56 | 57 | # Create the new input and targets 58 | mixed_input, mixed_target = inputs[index, :], targets[index] 59 | mixed_input.mul_(1 - lam) 60 | inputs.mul_(lam).add_(mixed_input) 61 | mixed_target.mul_(1 - lam) 62 | targets.mul_(lam).add_(mixed_target) 63 | 64 | return inputs, targets 65 | -------------------------------------------------------------------------------- /holocron/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
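# Worked example (added for clarity; the numbers are illustrative): `find_image_size` below picks a
# target size from the median aspect ratio (h / w) and the median side sqrt(h * w). With a median
# side of 400 and a median ratio of 0.75, it suggests height = round(400 * sqrt(0.75)) = 346 and
# width = round(400 / sqrt(0.75)) = 462, which preserves both the median area and the median ratio.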
5 | 6 | import multiprocessing as mp 7 | from math import sqrt 8 | from multiprocessing.pool import ThreadPool 9 | from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, TypeVar 10 | 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | from PIL import Image 14 | from tqdm.auto import tqdm 15 | 16 | Inp = TypeVar("Inp") 17 | Out = TypeVar("Out") 18 | 19 | 20 | __all__ = ["find_image_size", "parallel"] 21 | 22 | 23 | def parallel( 24 | func: Callable[[Inp], Out], 25 | arr: Sequence[Inp], 26 | num_threads: Optional[int] = None, 27 | progress: bool = False, 28 | **kwargs: Any, 29 | ) -> Iterable[Out]: 30 | """Performs parallel tasks by leveraging multi-threading. 31 | 32 | >>> from holocron.utils.misc import parallel 33 | >>> parallel(lambda x: x ** 2, list(range(10))) 34 | 35 | Args: 36 | func: function to be executed on multiple workers 37 | arr: function argument's values 38 | num_threads: number of workers to be used for multiprocessing 39 | progress: whether the progress bar should be displayed 40 | kwargs: keyword arguments of tqdm 41 | 42 | Returns: 43 | list: list of function's results 44 | """ 45 | num_threads = num_threads if isinstance(num_threads, int) else min(16, mp.cpu_count()) 46 | if num_threads < 2: 47 | results = list(map(func, tqdm(arr, total=len(arr), **kwargs))) if progress else map(func, arr) 48 | else: 49 | with ThreadPool(num_threads) as tp: 50 | results = list(tqdm(tp.imap(func, arr), total=len(arr), **kwargs)) if progress else tp.map(func, arr) 51 | 52 | return results 53 | 54 | 55 | def find_image_size(dataset: Sequence[Tuple[Image.Image, Any]], **kwargs: Any) -> None: 56 | """Computes the best image size target for a given set of images 57 | 58 | Args: 59 | dataset: an iterator yielding a PIL Image and a target object 60 | kwargs: keyword args of matplotlib.pyplot.show 61 | 62 | Returns: 63 | the suggested height and width to be used 64 | """ 65 | # Record height & width 66 | shapes_ = parallel(lambda x: x[0].size, dataset, progress=True) 67 | 68 | shapes = np.asarray(shapes_)[:, ::-1] 69 | ratios = shapes[:, 0] / shapes[:, 1] 70 | sides = np.sqrt(shapes[:, 0] * shapes[:, 1]) 71 | 72 | # Compute median aspect ratio & side 73 | median_ratio = np.median(ratios) 74 | median_side = np.median(sides) 75 | 76 | height = round(median_side * sqrt(median_ratio)) 77 | width = round(median_side / sqrt(median_ratio)) 78 | 79 | # Double histogram 80 | fig, axes = plt.subplots(1, 2) 81 | axes[0].hist(ratios, bins=30, alpha=0.7) 82 | axes[0].title.set_text(f"Aspect ratio (median: {median_ratio:.2})") 83 | axes[0].grid(True, linestyle="--", axis="x") 84 | axes[0].axvline(median_ratio, color="r") 85 | axes[1].hist(sides, bins=30, alpha=0.7) 86 | axes[1].title.set_text(f"Side (median: {int(median_side)})") 87 | axes[1].grid(True, linestyle="--", axis="x") 88 | axes[1].axvline(median_side, color="r") 89 | fig.suptitle(f"Median image size: ({height}, {width})") 90 | plt.show(**kwargs) 91 | -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | # Holocron Notebooks 2 | 3 | Here are some notebooks compiled for users to better leverage the library capabilities: 4 | 5 | | Notebook | Description | | 6 | |:----------|:-------------|------:| 7 | | [Quicktour](https://github.com/frgfm/notebooks/blob/main/holocron/quicktour.ipynb) | A presentation of the main features of Holocron | [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/frgfm/notebooks/blob/main/holocron/quicktour.ipynb) | 8 | | [HuggingFace Hub integration](https://github.com/frgfm/notebooks/blob/main/holocron/hf_hub.ipynb) | Use HuggingFace model hub with Holocron | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/frgfm/notebooks/blob/main/holocron/hf_hub.ipynb) | 9 | | [Image classification](https://github.com/frgfm/notebooks/blob/main/holocron/classification_training.ipynb) | How to train your own image classifier | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/frgfm/notebooks/blob/main/holocron/classification_training.ipynb) | 10 | -------------------------------------------------------------------------------- /references/README.md: -------------------------------------------------------------------------------- 1 | # Holocron training scripts 2 | 3 | This section is specific to train computer vision models. 4 | 5 | 6 | ## Installation 7 | 8 | ### Prerequisites 9 | 10 | Python 3.8 (or higher) and [pip](https://pip.pypa.io/en/stable/) & [Git](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git) are required to install Holocron. 11 | 12 | 13 | ### Developer mode 14 | 15 | In order to install the specific dependencies for training, you will have to install the package from source *(install [Git](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git) first)*: 16 | 17 | ```shell 18 | git clone https://github.com/frgfm/Holocron.git 19 | pip install -e "Holocron/.[training]" 20 | ``` 21 | 22 | ## Available tasks 23 | 24 | ### Image classification 25 | 26 | Refer to the [`./classification`](classification) folder 27 | 28 | ### Semantic segmentation 29 | 30 | Refer to the [`./segmentation`](segmentation) folder 31 | 32 | ### Object detection 33 | 34 | Refer to the [`./detection`](detection) folder 35 | -------------------------------------------------------------------------------- /references/classification/README.md: -------------------------------------------------------------------------------- 1 | # Image classification 2 | 3 | Since I do not own enough computing power to iterate over ImageNet full training, this section involves training on a subset of ImageNet, called [Imagenette](https://github.com/fastai/imagenette). 4 | 5 | ## Getting started 6 | 7 | Ensure that you have holocron installed 8 | 9 | ```bash 10 | git clone https://github.com/frgfm/Holocron.git 11 | pip install -e "Holocron/.[training]" 12 | ``` 13 | 14 | Download [Imagenette](https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz) and extract it where you want 15 | 16 | ```bash 17 | wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz 18 | tar -xvzf imagenette2-320.tgz 19 | ``` 20 | 21 | From there, you can run your training with the following command 22 | 23 | ``` 24 | python train.py imagenette2-320/ --arch darknet53 --lr 5e-3 -b 32 -j 16 --epochs 40 --opt adamp --sched onecycle 25 | ``` 26 | 27 | 28 | 29 | ## Personal leaderboard 30 | 31 | The updated list of available checkpoints can be found in the [documentation](https://frgfm.github.io/Holocron/latest/models.html#classification). 
32 | 33 | 34 | ## Imagenette 35 | 36 | | Model | Accuracy@1 (Err) | Param # | MACs | Interpolation | Image size | 37 | | ---------------- | ---------------- | ------- | ----- | ------------- | ---------- | 38 | | cspdarknet53 | 92.54 (7.46) | 26.63M | 5.03G | bilinear | 224 | 39 | | cspdarknet53_mish| 94.14 (5.86) | 26.63M | 5.03G | bilinear | 256 | 40 | | rexnet2_2x | 91.75 (8.25) | 19.49M | 1.88G | bilinear | 224 | 41 | | rexnet50d | 92.18 (7.82) | 23.55M | 4.35G | bilinear | 224 | 42 | | darknet53 | 91.46 (8.54) | 40.60M | 9.31G | bilinear | 256 | 43 | | repvgg_a2 | 91.26 (8.74) | 48.63M | | bilinear | 224 | 44 | | darknet19 | 91.87 (8.13) | 19.83M | 2.75G | bilinear | 224 | 45 | | tridentresnet50 | 91.01 (8.99) | 45.83M | 35.9G | bilinear | 224 | 46 | | sknet50 | 90.42 (9.58) | 35.22M | 5.96G | bilinear | 224 | 47 | | rexnet1_3x | 94.06 (5.94) | 7.56M | 0.68G | bilinear | 224 | 48 | | repvgg_a1 | 90.97 (9.03) | 30.12M | | bilinear | 224 | 49 | | rexnet1_0x | 92.99 (7.01) | 4.80M | 0.42G | bilinear | 224 | 50 | | repvgg_a0 | 91.18 (8.82) | 24.74M | | bilinear | 224 | 51 | | repvgg_b0 | 89.61 (10.39) | 31.85M | | bilinear | 224 | 52 | | res2net50_26w_4s | 89.58 (10.42) | 23.67M | 4.28G | bilinear | 224 | 53 | | darknet24 | 91.57 (8.43) | 22.40M | 4.21G | bilinear | 224 | 54 | | resnet50 | 84.36 (15.64) | 23.53M | 4.11G | bilinear | 224 | 55 | -------------------------------------------------------------------------------- /references/clean_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import hashlib 7 | from pathlib import Path 8 | 9 | import torch 10 | 11 | 12 | def main(args): 13 | checkpoint = torch.load(args.checkpoint, map_location="cpu")["model"] 14 | torch.save(checkpoint, args.outfile, _use_new_zipfile_serialization=False) 15 | 16 | with Path(args.outfile).open("rb") as f: 17 | sha_hash = hashlib.sha256(f.read()).hexdigest() 18 | print(f"Checkpoint saved to {args.outfile} with hash: {sha_hash[:8]}") 19 | 20 | 21 | def parse_args(): 22 | import argparse 23 | 24 | parser = argparse.ArgumentParser( 25 | description="Training checkpoint cleanup", formatter_class=argparse.ArgumentDefaultsHelpFormatter 26 | ) 27 | 28 | parser.add_argument("checkpoint", type=str, help="path to the training checkpoint") 29 | parser.add_argument("outfile", type=str, help="path to the output model file") 30 | return parser.parse_args() 31 | 32 | 33 | if __name__ == "__main__": 34 | args = parse_args() 35 | main(args) 36 | -------------------------------------------------------------------------------- /references/detection/README.md: -------------------------------------------------------------------------------- 1 | # Object detection 2 | 3 | The sample training script was made to train object detection models on [PASCAL VOC 2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/). 4 | 5 | ## Getting started 6 | 7 | Ensure that you have holocron installed 8 | 9 | ```bash 10 | git clone https://github.com/frgfm/Holocron.git 11 | pip install -e "Holocron/.[training]" 12 | ``` 13 | 14 | No need to download the dataset, torchvision will handle [this](https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.VOCDetection) for you!
From there, you can run your training with the following command 15 | 16 | ```bash 17 | python train.py VOC2012 --arch yolov2 --lr 1e-5 -b 32 -j 16 --epochs 20 --opt radam --sched onecycle 18 | ``` 19 | 20 | 21 | 22 | ## Personal leaderboard 23 | 24 | ### PASCAL VOC 2012 25 | 26 | Performances are evaluated on the validation set of the dataset. Since the mAP does not allow easy interpretation by humans, the performance metrics have been changed here. 27 | 28 | A prediction is considered correct if it meets two criteria: 29 | 30 | - Localization: it is the best acceptable localization candidate (highest IoU among predictions with the GT, and IoU >= 0.5) 31 | - Classification: the top predicted probability corresponds to the class label of the matched ground truth object. 32 | 33 | Then we define: 34 | 35 | - **Localization error rate**: with loc_recall being the matching rate of ground truth boxes, and loc_precision being the matching rate of predicted boxes, we define the localization error as 1 - (harmonic mean of loc_recall & loc_precision) 36 | - **Classification error rate**: classification error rate of matched predictions. 37 | - **Detection error rate**: with det_recall being the correctness rate of ground truth boxes, and det_precision being the correctness rate of predicted boxes, we define the detection error as 1 - (harmonic mean of det_recall & det_precision) 38 | 39 | Here, recall is the ratio of correctly predicted ground truth objects to the total number of ground truth objects, and precision is the ratio of correct predictions to the total number of predicted boxes. 40 | 41 | | Size (px) | Epochs | args | Loc@.5 | Clf@.5 | Det@.5 | # Runs | 42 | | --------- | ------ | ------------------------------------------------------------ | ------ | ------ | ------ | ------ | 43 | | 416 | 40 | VOC2012 --arch yolov2 --img-size 416 --lr 5e-4 -b 64 -j 16 --epochs 40 --opt tadam --freeze-backbone --sched onecycle | 83.09 | 52.82 | 92.02 | 1 | 44 | 45 | 46 | 47 | ## Model zoo 48 | 49 | | Model | Loc@.5 | Clf@.5 | Det@.5 | Param # | MACs | Interpolation | Image size | 50 | | ------ | ------ | ------ | ------ | ------- | ---- | ------------- | ---------- | 51 | | yolov2 | 83.09 | 52.82 | 92.02 | 50.65M | | bilinear | 416 | 52 | -------------------------------------------------------------------------------- /references/detection/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details.
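# Illustrative usage sketch (assumed, not part of the original script): these joint transforms
# are meant to be passed to torchvision's VOC dataset through its `transforms` argument, e.g.
#   dataset = VOCDetection(root, image_set="train", download=True,
#                          transforms=Compose([VOCTargetTransform(classes),
#                                              Resize((416, 416)), ImageTransform(transforms.ToTensor())]))
# where `classes` is the list of VOC category names.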
5 | 6 | """ 7 | Transformation for object detection 8 | """ 9 | 10 | import torch 11 | from torchvision.transforms import functional as F 12 | from torchvision.transforms import transforms 13 | 14 | 15 | class VOCTargetTransform: 16 | def __init__(self, classes): 17 | self.class_map = {label: idx for idx, label in enumerate(classes)} 18 | 19 | def __call__(self, image, target): 20 | # Format boxes properly 21 | boxes = torch.tensor( 22 | [ 23 | [ 24 | int(obj["bndbox"]["xmin"]), 25 | int(obj["bndbox"]["ymin"]), 26 | int(obj["bndbox"]["xmax"]), 27 | int(obj["bndbox"]["ymax"]), 28 | ] 29 | for obj in target["annotation"]["object"] 30 | ], 31 | dtype=torch.float32, 32 | ) 33 | # Encode class labels 34 | labels = torch.tensor([self.class_map[obj["name"]] for obj in target["annotation"]["object"]], dtype=torch.long) 35 | 36 | return image, {"boxes": boxes, "labels": labels} 37 | 38 | 39 | class Compose(transforms.Compose): 40 | def __call__(self, image, target): 41 | for t in self.transforms: 42 | image, target = t(image, target) 43 | return image, target 44 | 45 | 46 | class ImageTransform(object): 47 | def __init__(self, transform): 48 | self.transform = transform 49 | 50 | def __call__(self, image, target): 51 | image = self.transform.__call__(image) 52 | return image, target 53 | 54 | def __repr__(self): 55 | return self.transform.__repr__() 56 | 57 | 58 | class CenterCrop(transforms.CenterCrop): 59 | def __call__(self, image, target): 60 | image = F.center_crop(image, self.size) 61 | x = int(image.size[0] / 2 - self.size[0] / 2) 62 | y = int(image.size[1] / 2 - self.size[1] / 2) 63 | # Crop 64 | target["boxes"][:, [0, 2]] = target["boxes"][:, [0, 2]].clamp_(x, x + self.size[0]) 65 | target["boxes"][:, [1, 3]] = target["boxes"][:, [1, 3]].clamp_(y, y + self.size[1]) 66 | target["boxes"][:, [0, 2]] -= x 67 | target["boxes"][:, [1, 3]] -= y 68 | 69 | return image, target 70 | 71 | 72 | class Resize(transforms.Resize): 73 | def __call__(self, image, target): 74 | if isinstance(self.size, int): 75 | if image.size[1] < image.size[0]: 76 | target["boxes"] *= self.size / image.size[1] 77 | else: 78 | target["boxes"] *= self.size / image.size[0] 79 | elif isinstance(self.size, tuple): 80 | target["boxes"][:, [0, 2]] *= self.size[0] / image.size[0] 81 | target["boxes"][:, [1, 3]] *= self.size[1] / image.size[1] 82 | return F.resize(image, self.size, self.interpolation), target 83 | 84 | 85 | class RandomResizedCrop(transforms.RandomResizedCrop): 86 | def __call__(self, image, target): 87 | i, j, h, w = self.get_params(image, self.scale, self.ratio) 88 | image = F.resized_crop(image, i, j, h, w, self.size, self.interpolation) 89 | # Crop 90 | target["boxes"][:, [0, 2]] = target["boxes"][:, [0, 2]].clamp_(j, j + w) 91 | target["boxes"][:, [1, 3]] = target["boxes"][:, [1, 3]].clamp_(i, i + h) 92 | # Reset origin 93 | target["boxes"][:, [0, 2]] -= j 94 | target["boxes"][:, [1, 3]] -= i 95 | # Remove targets that are out of crop 96 | target_filter = (target["boxes"][:, 0] != target["boxes"][:, 2]) & ( 97 | target["boxes"][:, 1] != target["boxes"][:, 3] 98 | ) 99 | target["boxes"] = target["boxes"][target_filter] 100 | target["labels"] = target["labels"][target_filter] 101 | # Resize 102 | target["boxes"][:, [0, 2]] *= self.size[0] / w 103 | target["boxes"][:, [1, 3]] *= self.size[1] / h 104 | 105 | return image, target 106 | 107 | 108 | def convert_to_relative(image, target): 109 | target["boxes"][:, [0, 2]] /= image.size[0] 110 | target["boxes"][:, [1, 3]] /= image.size[1] 111 | 112 | # Clip 113 | 
target["boxes"][:, [0, 2]] = target["boxes"][:, [0, 2]].clamp_(0, 1) 114 | target["boxes"][:, [1, 3]] = target["boxes"][:, [1, 3]].clamp_(0, 1) 115 | 116 | return image, target 117 | 118 | 119 | class RandomHorizontalFlip(transforms.RandomHorizontalFlip): 120 | def __call__(self, image, target): 121 | if torch.rand(1).item() < self.p: 122 | width, _ = image.size  # PIL size is (width, height) 123 | image = F.hflip(image) 124 | target["boxes"][:, [0, 2]] = width - target["boxes"][:, [0, 2]] 125 | # Reorder them correctly 126 | target["boxes"] = target["boxes"][:, [2, 1, 0, 3]] 127 | return image, target 128 | -------------------------------------------------------------------------------- /references/segmentation/README.md: -------------------------------------------------------------------------------- 1 | # Semantic segmentation 2 | 3 | The sample training script was made to train semantic segmentation models on [PASCAL VOC 2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/). 4 | 5 | ## Getting started 6 | 7 | Ensure that you have holocron installed 8 | 9 | ```bash 10 | git clone https://github.com/frgfm/Holocron.git 11 | pip install -e "Holocron/.[training]" 12 | ``` 13 | 14 | No need to download the dataset, torchvision will handle [this](https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.VOCSegmentation) for you! From there, you can run your training with the following command: 15 | 16 | ```bash 17 | python train.py VOC2012 --arch unet3p -b 4 -j 16 --opt radam --lr 1e-5 --sched onecycle --epochs 20 18 | ``` 19 | 20 | 21 | 22 | ## Personal leaderboard 23 | 24 | ### PASCAL VOC 2012 25 | 26 | Performance is evaluated on the validation set of the dataset using the mean IoU metric (a minimal sketch of this metric is given below). 27 | 28 | | Size (px) | Epochs | args | mean IoU | # Runs | 29 | | --------- | ------ | ------------------------------------------------------------ | -------- | ------ | 30 | | 256 | 200 | VOC2012 --arch unet_rexnet13 -b 16 --loss label_smoothing --opt adamp --device 0 --lr 2e-3 --epochs 200 | 32.14 | 1 | 31 | | 256 | 20 | VOC2012 --arch unet3p -b 4 -j 16 --opt radam --lr 1e-5 --sched onecycle --epochs 20 | 14.17 | 1 | 32 | 33 | 34 | 35 | ## Model zoo 36 | 37 | | Model | mean IoU | Param # | MACs | Interpolation | Image size | 38 | | ------------- | -------- | ------- | ---- | ------------- | ---------- | 39 | | unet | | 18.11M | | bilinear | 256 | 40 | | unetp | | 28.28M | | bilinear | 256 | 41 | | unetpp | | 29.54M | | bilinear | 256 | 42 | | unet3p | | 26.93M | | bilinear | 256 | 43 | | unet_tvvgg11 | | 32.17M | | bilinear | 256 | 44 | | unet_tvresnet34 | | 36.25M | | bilinear | 256 | 45 | | unet_rexnet13 | 32.14 | 9.34M | | bilinear | 256 | 46 | -------------------------------------------------------------------------------- /references/segmentation/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details.
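Since the leaderboard above is reported in mean IoU, here is a minimal, self-contained sketch of that metric for integer label maps. It is not the implementation used by the training script; the function name `mean_iou` is an illustrative choice, and `ignore_index=255` simply follows the usual PASCAL VOC convention of marking void pixels with 255.

```python
import torch


def mean_iou(pred, target, num_classes, ignore_index=255):
    """Mean IoU over classes, from integer prediction and target label maps (illustrative sketch)."""
    # Discard void pixels, then flatten both maps
    keep = target != ignore_index
    pred, target = pred[keep], target[keep]
    # Confusion matrix: rows are ground truth classes, columns are predicted classes
    cm = torch.bincount(num_classes * target + pred, minlength=num_classes**2).reshape(num_classes, num_classes)
    inter = cm.diag().float()
    union = cm.sum(dim=0).float() + cm.sum(dim=1).float() - inter
    # Average IoU over the classes that actually appear
    return (inter[union > 0] / union[union > 0]).mean().item()
```

With two `(N, H, W)` tensors of class indices (e.g. `num_classes=21` for VOC), this returns a value in [0, 1]; the table above presumably reports it as a percentage.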
5 | 6 | """ 7 | Transformation for semantic segmentation 8 | """ 9 | 10 | import numpy as np 11 | import torch 12 | from torchvision.transforms import InterpolationMode, transforms 13 | from torchvision.transforms import functional as F 14 | 15 | 16 | def pad_if_smaller(img, size, fill=0): 17 | min_size = min(img.size) 18 | if min_size < size: 19 | ow, oh = img.size 20 | padh = size - oh if oh < size else 0 21 | padw = size - ow if ow < size else 0 22 | img = F.pad(img, (0, 0, padw, padh), fill=fill) 23 | return img 24 | 25 | 26 | class Compose(transforms.Compose): 27 | def __init__(self, transforms): 28 | super(Compose, self).__init__(transforms) 29 | 30 | def __call__(self, image, target): 31 | for t in self.transforms: 32 | image, target = t(image, target) 33 | return image, target 34 | 35 | 36 | class Resize(object): 37 | def __init__(self, output_size, interpolation=InterpolationMode.BILINEAR): 38 | self.output_size = output_size 39 | self.interpolation = interpolation 40 | 41 | def __call__(self, image, target): 42 | image = F.resize(image, self.output_size, interpolation=self.interpolation) 43 | target = F.resize(target, self.output_size, interpolation=InterpolationMode.NEAREST) 44 | return image, target 45 | 46 | def __repr__(self): 47 | return f"{self.__class__.__name__}(output_size={self.output_size})" 48 | 49 | 50 | class RandomResize(object): 51 | def __init__(self, min_size, max_size=None, interpolation=InterpolationMode.BILINEAR): 52 | self.min_size = min_size 53 | if max_size is None: 54 | max_size = min_size 55 | self.max_size = max_size 56 | self.interpolation = interpolation 57 | 58 | def __call__(self, image, target): 59 | if self.min_size == self.max_size: 60 | size = self.min_size 61 | else: 62 | size = torch.randint(self.min_size, self.max_size, (1,)).item() 63 | image = F.resize(image, size, interpolation=self.interpolation) 64 | target = F.resize(target, size, interpolation=InterpolationMode.NEAREST) 65 | return image, target 66 | 67 | def __repr__(self): 68 | return f"{self.__class__.__name__}(min_size={self.min_size}, max_size={self.max_size})" 69 | 70 | 71 | class RandomHorizontalFlip(object): 72 | def __init__(self, prob): 73 | self.prob = prob 74 | 75 | def __call__(self, image, target): 76 | if torch.rand(1).item() < self.prob: 77 | image = F.hflip(image) 78 | # Flip the segmentation 79 | target = F.hflip(target) 80 | 81 | return image, target 82 | 83 | def __repr__(self): 84 | return f"{self.__class__.__name__}(p={self.prob})" 85 | 86 | 87 | class RandomCrop(object): 88 | def __init__(self, size): 89 | self.size = size 90 | 91 | def __call__(self, image, target): 92 | image = pad_if_smaller(image, self.size) 93 | target = pad_if_smaller(target, self.size, fill=255) 94 | crop_params = transforms.RandomCrop.get_params(image, (self.size, self.size)) 95 | image = F.crop(image, *crop_params) 96 | target = F.crop(target, *crop_params) 97 | return image, target 98 | 99 | def __repr__(self): 100 | return f"{self.__class__.__name__}(size={self.size})" 101 | 102 | 103 | class ToTensor(transforms.ToTensor): 104 | def __call__(self, img, target): 105 | img = super(ToTensor, self).__call__(img) 106 | target = torch.as_tensor(np.array(target), dtype=torch.int64) 107 | 108 | return img, target 109 | 110 | 111 | class ImageTransform(object): 112 | def __init__(self, transform): 113 | self.transform = transform 114 | 115 | def __call__(self, image, target): 116 | image = self.transform.__call__(image) 117 | return image, target 118 | 119 | def __repr__(self): 120 | return 
self.transform.__repr__() 121 | -------------------------------------------------------------------------------- /scripts/eval_latency.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | """ 7 | Holocron model latency benchmark 8 | """ 9 | 10 | import argparse 11 | import time 12 | 13 | import numpy as np 14 | import onnxruntime 15 | import torch 16 | 17 | from holocron import models 18 | 19 | 20 | @torch.inference_mode() 21 | def run_evaluation( 22 | model: torch.nn.Module, img_tensor: torch.Tensor, num_it: int = 100, warmup_it: int = 10 23 | ) -> np.array: 24 | # Warmup 25 | for _ in range(warmup_it): 26 | _ = model(img_tensor) 27 | 28 | timings = [] 29 | 30 | # Evaluation runs 31 | for _ in range(num_it): 32 | start_ts = time.perf_counter() 33 | _ = model(img_tensor) 34 | timings.append(time.perf_counter() - start_ts) 35 | 36 | return np.array(timings) 37 | 38 | 39 | def run_onnx_evaluation( 40 | model: onnxruntime.InferenceSession, img_tensor: np.array, num_it: int = 100, warmup_it: int = 10 41 | ) -> np.array: 42 | # Set input 43 | ort_input = {model.get_inputs()[0].name: img_tensor} 44 | # Warmup 45 | for _ in range(warmup_it): 46 | _ = model.run(None, ort_input) 47 | 48 | timings = [] 49 | 50 | # Evaluation runs 51 | for _ in range(num_it): 52 | start_ts = time.perf_counter() 53 | _ = model.run(None, ort_input) 54 | timings.append(time.perf_counter() - start_ts) 55 | 56 | return np.array(timings) 57 | 58 | 59 | @torch.inference_mode() 60 | def main(args): 61 | # Pretrained imagenet model 62 | model = models.__dict__[args.arch](pretrained=args.pretrained).eval() 63 | # Reparametrizable models 64 | if args.arch.startswith("repvgg") or args.arch.startswith("mobileone"): 65 | model.reparametrize() 66 | 67 | # Input 68 | img_tensor = torch.rand((1, 3, args.size, args.size)) 69 | 70 | timings = run_evaluation(model, img_tensor, args.it) 71 | cpu_str = f"mean {1000 * timings.mean():.2f}ms, std {1000 * timings.std():.2f}ms" 72 | 73 | # ONNX 74 | torch.onnx.export( 75 | model, 76 | img_tensor, 77 | "tmp.onnx", 78 | export_params=True, 79 | opset_version=14, 80 | ) 81 | onnx_session = onnxruntime.InferenceSession("tmp.onnx") 82 | npy_tensor = img_tensor.numpy() 83 | timings = run_onnx_evaluation(onnx_session, npy_tensor, args.it) 84 | onnx_str = f"mean {1000 * timings.mean():.2f}ms, std {1000 * timings.std():.2f}ms" 85 | 86 | # GPU 87 | if args.device is None: 88 | args.device = "cuda:0" if torch.cuda.is_available() else "cpu" 89 | if args.device == "cpu": 90 | gpu_str = "N/A" 91 | else: 92 | device = torch.device(args.device) 93 | model = model.to(device=device) 94 | 95 | # Input 96 | img_tensor = img_tensor.to(device=device) 97 | timings = run_evaluation(model, img_tensor, args.it) 98 | gpu_str = f"mean {1000 * timings.mean():.2f}ms, std {1000 * timings.std():.2f}ms" 99 | 100 | print(f"{args.arch} ({args.it} runs on ({args.size}, {args.size}) inputs)") 101 | print(f"CPU - {cpu_str}\nONNX - {onnx_str}\nGPU - {gpu_str}") 102 | 103 | 104 | if __name__ == "__main__": 105 | parser = argparse.ArgumentParser( 106 | description="Holocron model latency benchmark", formatter_class=argparse.ArgumentDefaultsHelpFormatter 107 | ) 108 | parser.add_argument("arch", type=str, help="Architecture to use") 109 | parser.add_argument("--size", type=int, default=224, help="The image input size") 
110 | parser.add_argument("--device", type=str, default=None, help="Default device to perform computation on") 111 | parser.add_argument("--it", type=int, default=100, help="Number of iterations to run") 112 | parser.add_argument("--warmup", type=int, default=10, help="Number of iterations for warmup") 113 | parser.add_argument( 114 | "--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", action="store_true" 115 | ) 116 | args = parser.parse_args() 117 | 118 | main(args) 119 | -------------------------------------------------------------------------------- /scripts/export_to_onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | """ 7 | Holocron model ONNX export 8 | """ 9 | 10 | import argparse 11 | 12 | import torch 13 | 14 | from holocron import models 15 | 16 | 17 | @torch.inference_mode() 18 | def main(args): 19 | is_pretrained = args.pretrained and not isinstance(args.checkpoint, str) 20 | # Pretrained imagenet model 21 | model = models.__dict__[args.arch](pretrained=is_pretrained).eval() 22 | 23 | # Load the checkpoint 24 | if isinstance(args.checkpoint, str): 25 | state_dict = torch.load(args.checkpoint, map_location="cpu") 26 | model.load_state_dict(state_dict, strict=True) 27 | 28 | # RepVGG 29 | if args.arch.startswith("repvgg") or args.arch.startswith("mobileone"): 30 | model.reparametrize() 31 | 32 | # Input 33 | img_tensor = torch.rand((args.batch_size, args.in_channels, args.height, args.width)) 34 | 35 | # ONNX export 36 | torch.onnx.export( 37 | model, 38 | img_tensor, 39 | args.path, 40 | export_params=True, 41 | opset_version=14, 42 | ) 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser( 47 | description="Holocron model ONNX export", formatter_class=argparse.ArgumentDefaultsHelpFormatter 48 | ) 49 | parser.add_argument("arch", type=str, help="Architecture to use") 50 | parser.add_argument("--height", type=int, default=224, help="The height of the input image") 51 | parser.add_argument("--width", type=int, default=224, help="The width of the input image") 52 | parser.add_argument("--in-channels", type=int, default=3, help="The number of channels of the input image") 53 | parser.add_argument("--batch-size", type=int, default=1, help="The batch size used for the model") 54 | parser.add_argument("--path", type=str, default="./model.onnx", help="The path of the output file") 55 | parser.add_argument("--checkpoint", type=str, default=None, help="The checkpoint to restore") 56 | parser.add_argument( 57 | "--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", action="store_true" 58 | ) 59 | args = parser.parse_args() 60 | 61 | main(args) 62 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2019-2024, François-Guillaume Fernandez. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | 7 | import os 8 | from pathlib import Path 9 | 10 | from setuptools import setup 11 | 12 | PKG_NAME = "pylocron" 13 | VERSION = os.getenv("BUILD_VERSION", "0.2.2.dev0") 14 | 15 | 16 | if __name__ == "__main__": 17 | print(f"Building wheel {PKG_NAME}-{VERSION}") 18 | 19 | # Dynamically set the __version__ attribute 20 | cwd = Path(__file__).parent.absolute() 21 | with cwd.joinpath("holocron", "version.py").open("w", encoding="utf-8") as f: 22 | f.write(f"__version__ = '{VERSION}'\n") 23 | 24 | setup(name=PKG_NAME, version=VERSION) 25 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | from holocron.models import utils 6 | from holocron.models.classification.repvgg import RepVGG 7 | from holocron.nn import SAM, BlurPool2d, DropBlock2d 8 | 9 | 10 | def _test_conv_seq(conv_seq, expected_classes, expected_channels): 11 | assert len(conv_seq) == len(expected_classes) 12 | for _layer, mod_class in zip(conv_seq, expected_classes, strict=False): 13 | assert isinstance(_layer, mod_class) 14 | 15 | input_t = torch.rand(1, conv_seq[0].in_channels, 224, 224) 16 | out = torch.nn.Sequential(*conv_seq)(input_t) 17 | assert out.shape[:2] == (1, expected_channels) 18 | out.sum().backward() 19 | 20 | 21 | def test_conv_sequence(): 22 | mod = utils.conv_sequence( 23 | 3, 24 | 32, 25 | kernel_size=3, 26 | act_layer=nn.ReLU(inplace=True), 27 | norm_layer=nn.BatchNorm2d, 28 | drop_layer=DropBlock2d, 29 | attention_layer=SAM, 30 | ) 31 | 32 | _test_conv_seq(mod, [nn.Conv2d, nn.BatchNorm2d, nn.ReLU, SAM, DropBlock2d], 32) 33 | assert mod[0].kernel_size == (3, 3) 34 | 35 | mod = utils.conv_sequence( 36 | 3, 37 | 32, 38 | kernel_size=3, 39 | stride=2, 40 | act_layer=nn.ReLU(inplace=True), 41 | norm_layer=nn.BatchNorm2d, 42 | drop_layer=DropBlock2d, 43 | blurpool=True, 44 | ) 45 | _test_conv_seq(mod, [nn.Conv2d, nn.BatchNorm2d, nn.ReLU, BlurPool2d, DropBlock2d], 32) 46 | assert mod[0].kernel_size == (3, 3) 47 | assert mod[0].stride == (1, 1) 48 | assert mod[3].stride == 2 49 | assert mod[0].bias is None 50 | # Ensures that bias is added when there is no BN 51 | mod = utils.conv_sequence(3, 32, kernel_size=3, stride=2, act_layer=nn.ReLU(inplace=True)) 52 | assert isinstance(mod[0].bias, nn.Parameter) 53 | 54 | 55 | def test_fuse_conv_bn(): 56 | # Check the channel verification 57 | with pytest.raises(AssertionError): 58 | utils.fuse_conv_bn(nn.Conv2d(3, 5, 3), nn.BatchNorm2d(3)) 59 | 60 | # Prepare candidate modules 61 | conv = nn.Conv2d(3, 8, 3, padding=1, bias=False).eval() 62 | bn = nn.BatchNorm2d(8).eval() 63 | bn.weight.data = torch.rand(8) 64 | 65 | # Create the fused version 66 | fused_conv = nn.Conv2d(3, 8, 3, padding=1, bias=True).eval() 67 | k, b = utils.fuse_conv_bn(conv, bn) 68 | fused_conv.weight.data = k 69 | fused_conv.bias.data = b 70 | 71 | # Check values 72 | batch_size = 2 73 | x = torch.rand((batch_size, 3, 32, 32)) 74 | with torch.no_grad(): 75 | assert torch.allclose(bn(conv(x)), fused_conv(x), atol=1e-6) 76 | 77 | # Check the warning when there is already a bias 78 | conv = nn.Conv2d(3, 8, 3, padding=1, bias=True).eval() 79 | k, b = utils.fuse_conv_bn(conv, bn) 80 | fused_conv.weight.data = k 81 | fused_conv.bias.data = b 82 | with torch.no_grad(): 83 | assert torch.allclose(bn(conv(x)), fused_conv(x), atol=1e-6) 84 | 85 | 86 | def test_model_from_hf_hub(): 87 | model = 
utils.model_from_hf_hub("frgfm/repvgg_a0") 88 | # Check model type 89 | assert isinstance(model, RepVGG) 90 | 91 | # Check num of params 92 | assert sum(p.data.numel() for p in model.parameters()) == 24741642 93 | -------------------------------------------------------------------------------- /tests/test_models_classification.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | from torch import nn 6 | 7 | from holocron.models import classification 8 | 9 | 10 | def _test_classification_model(name, num_classes, pretrained): 11 | batch_size = 2 12 | x = torch.rand((batch_size, 3, 224, 224)) 13 | model = classification.__dict__[name](pretrained=pretrained, num_classes=num_classes).eval() 14 | with torch.no_grad(): 15 | out = model(x) 16 | 17 | assert out.shape[0] == x.shape[0] 18 | assert out.shape[-1] == num_classes 19 | 20 | # Check backprop is OK 21 | target = torch.zeros(batch_size, dtype=torch.long) 22 | model.train() 23 | out = model(x) 24 | loss = torch.nn.functional.cross_entropy(out, target) 25 | loss.backward() 26 | 27 | 28 | def test_repvgg_reparametrize(): 29 | num_classes = 10 30 | batch_size = 2 31 | x = torch.rand((batch_size, 3, 224, 224)) 32 | model = classification.repvgg_a0(pretrained=False, num_classes=num_classes).eval() 33 | with torch.no_grad(): 34 | out = model(x) 35 | 36 | # Reparametrize 37 | model.reparametrize() 38 | # Check that there is no longer any Conv1x1 or BatchNorm 39 | for mod in model.modules(): 40 | assert not isinstance(mod, nn.BatchNorm2d) 41 | if isinstance(mod, nn.Conv2d): 42 | assert mod.weight.data.shape[2:] == (3, 3) 43 | # Check that values are still matching 44 | with torch.no_grad(): 45 | assert torch.allclose(out, model(x), atol=1e-3) 46 | 47 | 48 | def test_mobileone_reparametrize(): 49 | num_classes = 10 50 | batch_size = 2 51 | x = torch.rand((batch_size, 3, 224, 224)) 52 | model = classification.mobileone_s0(pretrained=False, num_classes=num_classes).eval() 53 | with torch.no_grad(): 54 | out = model(x) 55 | 56 | # Reparametrize 57 | model.reparametrize() 58 | # Check that there is no longer any Conv1x1 or BatchNorm 59 | for mod in model.modules(): 60 | assert not isinstance(mod, nn.BatchNorm2d) 61 | # Check that values are still matching 62 | with torch.no_grad(): 63 | assert torch.allclose(out, model(x), atol=1e-3) 64 | 65 | 66 | @pytest.mark.parametrize( 67 | ("arch", "pretrained"), 68 | [ 69 | ("darknet24", True), 70 | ("darknet19", True), 71 | ("darknet53", True), 72 | ("cspdarknet53", True), 73 | ("cspdarknet53_mish", True), 74 | ("resnet18", True), 75 | ("resnet34", True), 76 | ("resnet50", True), 77 | ("resnet101", True), 78 | ("resnet152", True), 79 | ("resnext50_32x4d", True), 80 | ("resnext101_32x8d", True), 81 | ("resnet50d", True), 82 | ("res2net50_26w_4s", True), 83 | ("tridentnet50", True), 84 | ("pyconv_resnet50", True), 85 | ("pyconvhg_resnet50", True), 86 | ("rexnet1_0x", True), 87 | ("rexnet1_3x", False), 88 | ("rexnet1_5x", False), 89 | ("rexnet2_0x", False), 90 | ("rexnet2_2x", False), 91 | ("sknet50", True), 92 | ("sknet101", True), 93 | ("sknet152", True), 94 | ("repvgg_a0", True), 95 | ("repvgg_b0", False), 96 | ("convnext_atto", True), 97 | ("convnext_femto", False), 98 | ("convnext_pico", False), 99 | ("convnext_nano", False), 100 | ("convnext_tiny", False), 101 | ("convnext_small", False), 102 | ("convnext_base", False), 103 | ("convnext_large", False), 104 | ("convnext_xl", False), 105 | ("mobileone_s0", 
True), 106 | ("mobileone_s1", False), 107 | ("mobileone_s2", False), 108 | ("mobileone_s3", False), 109 | ], 110 | ) 111 | def test_classification_model(arch, pretrained): 112 | num_classes = 1000 if arch.startswith("rexnet") else 10 113 | _test_classification_model(arch, num_classes, pretrained) 114 | 115 | 116 | @pytest.mark.parametrize( 117 | "arch", 118 | [ 119 | "darknet24", 120 | "darknet19", 121 | "darknet53", 122 | "cspdarknet53", 123 | "resnet18", 124 | "res2net50_26w_4s", 125 | "tridentnet50", 126 | "pyconv_resnet50", 127 | "rexnet1_0x", 128 | "sknet50", 129 | "repvgg_a0", 130 | "convnext_atto", 131 | "mobileone_s0", 132 | ], 133 | ) 134 | def test_classification_onnx_export(arch, tmpdir_factory): 135 | model = classification.__dict__[arch](pretrained=False, num_classes=10).eval() 136 | tmp_path = Path(str(tmpdir_factory.mktemp("onnx"))).joinpath(f"{arch}.onnx") 137 | img_tensor = torch.rand((1, 3, 224, 224)) 138 | with torch.no_grad(): 139 | torch.onnx.export(model, img_tensor, tmp_path, export_params=True, opset_version=14) 140 | -------------------------------------------------------------------------------- /tests/test_models_segmentation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | 6 | from holocron.models import segmentation 7 | 8 | 9 | def _test_segmentation_model(name, input_shape): 10 | num_classes = 10 11 | batch_size = 2 12 | num_channels = 3 13 | x = torch.rand((batch_size, num_channels, *input_shape)) 14 | # Check pretrained version 15 | model = segmentation.__dict__[name](pretrained=True).eval() 16 | # Check custom number of output classes 17 | model = segmentation.__dict__[name](pretrained=False, num_classes=num_classes).eval() 18 | with torch.no_grad(): 19 | out = model(x) 20 | 21 | assert isinstance(out, torch.Tensor) 22 | assert out.shape == (batch_size, num_classes, *input_shape) 23 | 24 | 25 | @pytest.mark.parametrize( 26 | ("arch", "input_shape"), 27 | [ 28 | ("unet", (256, 256)), 29 | ("unet2", (256, 256)), 30 | ("unet_rexnet13", (256, 256)), 31 | ("unet_tvvgg11", (256, 256)), 32 | ("unet_tvresnet34", (256, 256)), 33 | ("unetp", (256, 256)), 34 | ("unetpp", (256, 256)), 35 | ("unet3p", (320, 320)), 36 | ], 37 | ) 38 | def test_segmentation_model(arch, input_shape): 39 | _test_segmentation_model(arch, input_shape) 40 | 41 | 42 | @pytest.mark.parametrize( 43 | ("arch", "input_shape"), 44 | [ 45 | ("unet", (256, 256)), 46 | ("unet2", (256, 256)), 47 | ("unetp", (256, 256)), 48 | ("unetpp", (256, 256)), 49 | ("unet3p", (320, 320)), 50 | ], 51 | ) 52 | def test_segmentation_onnx_export(arch, input_shape, tmpdir_factory): 53 | model = segmentation.__dict__[arch](pretrained=False, num_classes=10).eval() 54 | tmp_path = Path(str(tmpdir_factory.mktemp("onnx"))).joinpath(f"{arch}.onnx") 55 | img_tensor = torch.rand((1, 3, *input_shape)) 56 | with torch.no_grad(): 57 | torch.onnx.export(model, img_tensor, tmp_path, export_params=True, opset_version=14) 58 | -------------------------------------------------------------------------------- /tests/test_nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from holocron.nn.modules import dropblock 4 | 5 | 6 | def test_dropblock2d(): 7 | x = torch.rand(2, 4, 16, 16) 8 | 9 | # Drop probability of 1 10 | mod = dropblock.DropBlock2d(1.0, 1, inplace=False) 11 | 12 | with torch.no_grad(): 13 | out = mod(x) 14 | assert torch.equal(out, torch.zeros_like(x)) 15 | 
16 | # Drop probability of 0 17 | mod = dropblock.DropBlock2d(0.0, 3, inplace=False) 18 | 19 | with torch.no_grad(): 20 | out = mod(x) 21 | assert torch.equal(out, x) 22 | assert out.data_ptr == x.data_ptr 23 | 24 | # Check inference mode 25 | mod = dropblock.DropBlock2d(1.0, 3, inplace=False).eval() 26 | 27 | with torch.no_grad(): 28 | out = mod(x) 29 | assert torch.equal(out, x) 30 | 31 | # Check inplace 32 | mod = dropblock.DropBlock2d(1.0, 3, inplace=True) 33 | 34 | with torch.no_grad(): 35 | out = mod(x) 36 | assert out.data_ptr == x.data_ptr 37 | -------------------------------------------------------------------------------- /tests/test_nn_activation.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | import torch 4 | 5 | from holocron.nn import functional as F 6 | from holocron.nn.modules import activation 7 | 8 | 9 | def _test_activation_function(fn, input_shape): 10 | # Optional testing 11 | fn_args = inspect.signature(fn).parameters.keys() 12 | cfg = {} 13 | if "inplace" in fn_args: 14 | cfg["inplace"] = [False, True] 15 | 16 | # Generate inputs 17 | x = torch.rand(input_shape) 18 | 19 | # Optional argument testing 20 | kwargs = {} 21 | for inplace in cfg.get("inplace", [None]): 22 | if isinstance(inplace, bool): 23 | kwargs["inplace"] = inplace 24 | out = fn(x, **kwargs) 25 | assert out.shape == x.shape 26 | if kwargs.get("inplace", False): 27 | assert x.data_ptr() == out.data_ptr() 28 | 29 | 30 | def test_hard_mish(): 31 | _test_activation_function(F.hard_mish, (4, 3, 32, 32)) 32 | assert repr(activation.HardMish()) == "HardMish()" 33 | 34 | 35 | def test_nl_relu(): 36 | _test_activation_function(F.nl_relu, (4, 3, 32, 32)) 37 | assert repr(activation.NLReLU()) == "NLReLU()" 38 | 39 | 40 | def test_frelu(): 41 | mod = activation.FReLU(8).eval() 42 | with torch.no_grad(): 43 | _test_activation_function(mod.forward, (4, 8, 32, 32)) 44 | assert len(repr(mod).split("\n")) == 4 45 | -------------------------------------------------------------------------------- /tests/test_nn_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from holocron import nn 4 | 5 | 6 | def _test_attention_mod(mod): 7 | x = torch.rand(2, 4, 8, 8) 8 | # Check that attention preserves shape 9 | mod = mod.eval() 10 | with torch.no_grad(): 11 | out = mod(x) 12 | assert x.shape == out.shape 13 | # Check that it doesn't break backprop 14 | mod = mod.train() 15 | out = mod(x) 16 | out.sum().backward() 17 | assert isinstance(next(mod.parameters()).grad, torch.Tensor) 18 | 19 | 20 | def test_sam(): 21 | _test_attention_mod(nn.SAM(4)) 22 | 23 | 24 | def test_triplet_attention(): 25 | _test_attention_mod(nn.TripletAttention()) 26 | -------------------------------------------------------------------------------- /tests/test_nn_conv.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from holocron import nn 5 | 6 | 7 | def _test_conv2d(mod, input_shape, output_shape): 8 | x = torch.rand(*input_shape) 9 | 10 | out = mod(x) 11 | assert out.shape == output_shape 12 | # Check that backprop works 13 | out.sum().backward() 14 | 15 | 16 | def test_normconv2d(): 17 | _test_conv2d(nn.NormConv2d(8, 16, 3, padding=1), (2, 8, 16, 16), (2, 16, 16, 16)) 18 | _test_conv2d(nn.NormConv2d(8, 16, 3, padding=1, padding_mode="reflect"), (2, 8, 16, 16), (2, 16, 16, 16)) 19 | 20 | 21 | def test_add2d(): 22 | _test_conv2d(nn.Add2d(8, 16, 3, 
padding=1), (2, 8, 16, 16), (2, 16, 16, 16)) 23 | _test_conv2d(nn.Add2d(8, 16, 3, padding=1, padding_mode="reflect"), (2, 8, 16, 16), (2, 16, 16, 16)) 24 | 25 | 26 | def test_slimconv2d(): 27 | _test_conv2d(nn.SlimConv2d(8, 3, padding=1, r=32, L=2), (2, 8, 16, 16), (2, 6, 16, 16)) 28 | 29 | 30 | def test_pyconv2d(): 31 | for num_levels in range(1, 5): 32 | _test_conv2d(nn.PyConv2d(8, 16, 3, num_levels, padding=1), (2, 8, 16, 16), (2, 16, 16, 16)) 33 | 34 | 35 | def test_lambdalayer(): 36 | with pytest.raises(AssertionError): 37 | nn.LambdaLayer(3, 31, 16) 38 | with pytest.raises(AssertionError): 39 | nn.LambdaLayer(3, 32, 16, r=2) 40 | with pytest.raises(AssertionError): 41 | nn.LambdaLayer(3, 32, 16, r=None, n=None) 42 | 43 | _test_conv2d(nn.LambdaLayer(8, 32, 16, r=13), (2, 8, 32, 32), (2, 32, 32, 32)) 44 | 45 | 46 | def test_involution2d(): 47 | _test_conv2d(nn.Involution2d(8, 3, 1, reduction_ratio=2), (2, 8, 16, 16), (2, 8, 16, 16)) 48 | _test_conv2d(nn.Involution2d(8, 3, 1, 2, reduction_ratio=2), (2, 8, 16, 16), (2, 8, 8, 8)) 49 | -------------------------------------------------------------------------------- /tests/test_nn_downsample.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | from holocron.nn import functional as F 6 | from holocron.nn.modules import downsample 7 | 8 | 9 | def test_concatdownsample2d(): 10 | num_batches = 2 11 | num_chan = 4 12 | scale_factor = 2 13 | x = torch.arange(num_batches * num_chan * 4**2).view(num_batches, num_chan, 4, 4) 14 | 15 | # Test functional API 16 | with pytest.raises(AssertionError): 17 | F.concat_downsample2d(x, 3) 18 | out = F.concat_downsample2d(x, scale_factor) 19 | assert out.shape == ( 20 | num_batches, 21 | num_chan * scale_factor**2, 22 | x.shape[2] // scale_factor, 23 | x.shape[3] // scale_factor, 24 | ) 25 | 26 | # Check first and last values 27 | assert torch.equal(out[0][0], torch.tensor([[0, 2], [8, 10]])) 28 | assert torch.equal(out[0][-num_chan], torch.tensor([[5, 7], [13, 15]])) 29 | # Test module 30 | mod = downsample.ConcatDownsample2d(scale_factor) 31 | assert torch.equal(mod(x), out) 32 | # Test JIT module 33 | mod = downsample.ConcatDownsample2dJit(scale_factor) 34 | assert torch.equal(mod(x), out) 35 | 36 | 37 | def test_globalavgpool2d(): 38 | x = torch.rand(2, 4, 16, 16) 39 | 40 | # Check that ops are doing the same thing 41 | ref = nn.AdaptiveAvgPool2d(1) 42 | mod = downsample.GlobalAvgPool2d(flatten=False) 43 | out = mod(x) 44 | assert torch.equal(out, ref(x)) 45 | assert out.data_ptr != x.data_ptr 46 | 47 | # Check that flatten works 48 | x = torch.rand(2, 4, 16, 16) 49 | mod = downsample.GlobalAvgPool2d(flatten=True) 50 | assert torch.equal(mod(x), ref(x).view(*x.shape[:2])) 51 | 52 | 53 | def test_globalmaxpool2d(): 54 | x = torch.rand(2, 4, 16, 16) 55 | 56 | # Check that ops are doing the same thing 57 | ref = nn.AdaptiveMaxPool2d(1) 58 | mod = downsample.GlobalMaxPool2d(flatten=False) 59 | out = mod(x) 60 | assert torch.equal(out, ref(x)) 61 | assert out.data_ptr != x.data_ptr 62 | 63 | # Check that flatten works 64 | x = torch.rand(2, 4, 16, 16) 65 | mod = downsample.GlobalMaxPool2d(flatten=True) 66 | assert torch.equal(mod(x), ref(x).view(*x.shape[:2])) 67 | 68 | 69 | def test_blurpool2d(): 70 | with pytest.raises(AssertionError): 71 | downsample.BlurPool2d(1, 0) 72 | 73 | # Generate inputs 74 | num_batches = 2 75 | num_chan = 8 76 | x = torch.rand((num_batches, num_chan, 5, 5)) 77 | mod = 
downsample.BlurPool2d(num_chan, stride=2) 78 | 79 | # Optional argument testing 80 | with torch.no_grad(): 81 | out = mod(x) 82 | assert out.shape == (num_batches, num_chan, 3, 3) 83 | 84 | k = torch.tensor([[0.0625, 0.125, 0.0625], [0.125, 0.25, 0.125], [0.0625, 0.125, 0.0625]]) 85 | assert torch.allclose(out[..., 1, 1], (x[..., 1:-1, 1:-1] * k[None, None, ...]).sum(dim=(2, 3)), atol=1e-7) 86 | 87 | 88 | def test_zpool(): 89 | num_batches = 2 90 | num_chan = 4 91 | x = torch.rand((num_batches, num_chan, 32, 32)) 92 | 93 | # Test functional API 94 | out = F.z_pool(x, 1) 95 | assert out.shape == (num_batches, 2, 32, 32) 96 | assert out[0, 0, 0, 0].item() == x[0, :, 0, 0].max().item() 97 | assert out[0, 1, 0, 0].item() == x[0, :, 0, 0].mean().item() 98 | 99 | # Test module 100 | mod = downsample.ZPool(1) 101 | assert torch.equal(mod(x), out) 102 | -------------------------------------------------------------------------------- /tests/test_nn_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from holocron.nn import init 5 | 6 | 7 | def test_init(): 8 | module = nn.Sequential(nn.Conv2d(3, 32, 3), nn.BatchNorm2d(32), nn.LeakyReLU(inplace=True)) 9 | 10 | # Check that each layer was initialized correctly 11 | init.init_module(module, "leaky_relu") 12 | assert torch.all(module[0].bias.data == 0) 13 | assert torch.all(module[1].weight.data == 1) 14 | assert torch.all(module[1].bias.data == 0) 15 | -------------------------------------------------------------------------------- /tests/test_ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import pytest 4 | import torch 5 | 6 | from holocron import ops 7 | 8 | 9 | @pytest.fixture 10 | def boxes(): 11 | return torch.tensor( 12 | [[0, 0, 100, 100], [50, 50, 100, 100], [50, 50, 150, 150], [100, 100, 200, 200]], dtype=torch.float32 13 | ) 14 | 15 | 16 | def test_iou_penalty(boxes): 17 | penalty = ops.boxes.iou_penalty(boxes, boxes) 18 | 19 | # Check shape 20 | assert penalty.shape == (boxes.shape[0], boxes.shape[0]) 21 | # Unit tests 22 | for idx in range(boxes.shape[0]): 23 | assert penalty[idx, idx].item() == 0 24 | 25 | assert penalty[0, 1].item() == 25**2 / 100**2 26 | assert penalty[0, 3].item() == 100**2 / 200**2 27 | assert penalty[0, 2].item() == penalty[2, 3].item() 28 | 29 | 30 | def test_diou_loss(boxes): 31 | diou = ops.boxes.diou_loss(boxes, boxes) 32 | 33 | # Check shape 34 | assert diou.shape == (boxes.shape[0], boxes.shape[0]) 35 | # Unit tests 36 | for idx in range(boxes.shape[0]): 37 | assert diou[idx, idx].item() == 0.0 38 | 39 | assert diou[0, 1].item() == 1 - 0.25 + 25**2 / 100**2 40 | assert diou[0, 3].item() == 1 + 100**2 / 200**2 41 | assert diou[0, 2].item() == diou[2, 3].item() 42 | 43 | 44 | def test_box_giou(boxes): 45 | giou = ops.boxes.box_giou(boxes, boxes) 46 | 47 | # Check shape 48 | assert giou.shape == (boxes.shape[0], boxes.shape[0]) 49 | # Unit tests 50 | for idx in range(boxes.shape[0]): 51 | assert giou[idx, idx].item() == 1.0 52 | 53 | assert giou[0, 1].item() == 0.25 54 | assert giou[0, 3].item() == -(200**2 - 2 * 100**2) / 200**2 55 | assert giou[0, 2].item() == giou[2, 3].item() 56 | 57 | 58 | def test_aspect_ratio(boxes): 59 | # All boxes are squares so arctan should yield Pi / 4 60 | assert torch.equal(ops.boxes.aspect_ratio(boxes), math.pi / 4 * torch.ones(boxes.shape[0])) 61 | 62 | 63 | def test_aspect_ratio_consistency(boxes): 64 | # All boxes have the same 
aspect ratio 65 | assert torch.equal(ops.boxes.aspect_ratio_consistency(boxes, boxes), torch.zeros(boxes.shape[0], boxes.shape[0])) 66 | 67 | 68 | def test_ciou_loss(boxes): 69 | ciou = ops.boxes.ciou_loss(boxes, boxes) 70 | 71 | # Check shape 72 | assert ciou.shape == (boxes.shape[0], boxes.shape[0]) 73 | # Unit tests 74 | for idx in range(boxes.shape[0]): 75 | assert ciou[idx, idx].item() == 0.0 76 | assert ciou[0, 2].item() == ciou[2, 3].item() 77 | -------------------------------------------------------------------------------- /tests/test_optim.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torchvision.models import mobilenet_v3_small 6 | 7 | from holocron import optim 8 | 9 | 10 | def _test_optimizer(name: str, **kwargs: Any) -> None: 11 | lr = 1e-4 12 | input_shape = (3, 224, 224) 13 | num_batches = 4 14 | # Get model and optimizer 15 | model = mobilenet_v3_small(num_classes=10) 16 | for p in model.parameters(): 17 | p.requires_grad_(False) 18 | for p in model.classifier[3].parameters(): 19 | p.requires_grad_(True) 20 | optimizer = optim.__dict__[name](model.classifier[3].parameters(), lr=lr, **kwargs) 21 | 22 | # Save param value 23 | p_ = model.classifier[3].weight 24 | p_val = p_.data.clone() 25 | 26 | # Random inputs 27 | input_t = torch.rand((num_batches, *input_shape), dtype=torch.float32) 28 | target = torch.zeros(num_batches, dtype=torch.long) 29 | 30 | # Update 31 | optimizer.zero_grad() 32 | output = model(input_t) 33 | loss = F.cross_entropy(output, target) 34 | loss.backward() 35 | optimizer.step() 36 | 37 | # Test 38 | assert p_.grad is not None 39 | assert not torch.equal(p_.data, p_val) 40 | 41 | 42 | def test_lars(): 43 | _test_optimizer("LARS", momentum=0.9, weight_decay=2e-5) 44 | 45 | 46 | def test_lamb(): 47 | _test_optimizer("LAMB", weight_decay=2e-5) 48 | 49 | 50 | def test_ralars(): 51 | _test_optimizer("RaLars", weight_decay=2e-5) 52 | 53 | 54 | def test_tadam(): 55 | _test_optimizer("TAdam") 56 | 57 | 58 | def test_adabelief(): 59 | _test_optimizer("AdaBelief") 60 | 61 | 62 | def test_adamp(): 63 | _test_optimizer("AdamP") 64 | 65 | 66 | def test_adan(): 67 | _test_optimizer("Adan") 68 | 69 | 70 | def test_ademamix(): 71 | _test_optimizer("AdEMAMix") 72 | -------------------------------------------------------------------------------- /tests/test_optim_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from torch.optim import SGD 4 | from torchvision.models import mobilenet_v3_small 5 | 6 | from holocron.optim import wrapper 7 | 8 | 9 | def _test_wrapper(name: str) -> None: 10 | lr = 1e-4 11 | input_shape = (3, 224, 224) 12 | num_batches = 4 13 | # Get model, optimizer and criterion 14 | model = mobilenet_v3_small(num_classes=10) 15 | for p in model.parameters(): 16 | p.requires_grad_(False) 17 | for p in model.classifier[3].parameters(): 18 | p.requires_grad_(True) 19 | # Pick an optimizer whose update is easy to verify 20 | optimizer = SGD(model.classifier[3].parameters(), lr=lr) 21 | 22 | # Wrap the optimizer 23 | opt_wrapper = wrapper.__dict__[name](optimizer) 24 | 25 | # Check gradient reset 26 | opt_wrapper.zero_grad() 27 | for group in optimizer.param_groups: 28 | for p in group["params"]: 29 | if p.grad is not None: 30 | assert torch.all(p.grad == 0.0) 31 | 32 | # Check update step 33 | p_ = 
model.classifier[3].weight 34 | p_val = p_.data.clone() 35 | 36 | # Random inputs 37 | input_t = torch.rand((num_batches, *input_shape), dtype=torch.float32) 38 | target = torch.zeros(num_batches, dtype=torch.long) 39 | 40 | # Update 41 | for _ in range(10): 42 | output = model(input_t) 43 | loss = F.cross_entropy(output, target) 44 | loss.backward() 45 | opt_wrapper.step() 46 | # Check update rule 47 | assert not torch.equal(p_.data, p_val) 48 | assert not torch.equal(p_.data, p_val - lr * p_.grad) 49 | 50 | # Repr 51 | assert len(repr(opt_wrapper).split("\n")) == len(repr(optimizer).split("\n")) + 4 52 | 53 | 54 | def test_lookahead(): 55 | _test_wrapper("Lookahead") 56 | 57 | 58 | def test_scout(): 59 | _test_wrapper("Scout") 60 | -------------------------------------------------------------------------------- /tests/test_trainer_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | from holocron import trainer 6 | 7 | 8 | def test_freeze_bn(): 9 | # Simple module with BN 10 | mod = nn.Sequential(nn.Conv2d(3, 32, 3), nn.BatchNorm2d(32), nn.ReLU(inplace=True)) 11 | nb = mod[1].num_batches_tracked.clone() 12 | rm = mod[1].running_mean.clone() 13 | rv = mod[1].running_var.clone() 14 | # Freeze & forward 15 | for p in mod.parameters(): 16 | p.requires_grad_(False) 17 | trainer.freeze_bn(mod) 18 | for _ in range(10): 19 | _ = mod(torch.rand((1, 3, 32, 32))) 20 | # Check that stats were not updated 21 | assert torch.equal(mod[1].num_batches_tracked, nb) 22 | assert torch.equal(mod[1].running_mean, rm) 23 | assert torch.equal(mod[1].running_var, rv) 24 | 25 | 26 | def test_freeze_model(): 27 | # Simple model 28 | mod = nn.Sequential(nn.Conv2d(3, 32, 3), nn.ReLU(inplace=True), nn.Conv2d(32, 64, 3), nn.ReLU(inplace=True)) 29 | trainer.freeze_model(mod, "0") 30 | # Check that the correct layers were frozen 31 | assert not any(p.requires_grad for p in mod[0].parameters()) 32 | assert all(p.requires_grad for p in mod[2].parameters()) 33 | with pytest.raises(ValueError): 34 | trainer.freeze_model(mod, "wrong_layer") 35 | 36 | # Freeze last layer 37 | for p in mod[-1].parameters(): 38 | p.requires_grad_(False) 39 | trainer.freeze_model(mod, "0") 40 | # Ensure the last layer is now unfrozen 41 | assert all(p.requires_grad for p in mod[-1].parameters()) 42 | -------------------------------------------------------------------------------- /tests/test_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from PIL import Image 5 | from torch import nn 6 | 7 | from holocron import transforms as T 8 | from holocron.transforms.interpolation import ResizeMethod 9 | 10 | 11 | def test_resize(): 12 | # Arg check 13 | with pytest.raises(ValueError): 14 | T.Resize(16) 15 | 16 | with pytest.raises(ValueError): 17 | T.Resize((16, 16), mode="stretch") 18 | 19 | with pytest.raises(ValueError): 20 | T.Resize((16, 16), mode="pad") 21 | 22 | img1 = np.full((16, 32, 3), 255, dtype=np.uint8) 23 | img2 = np.full((32, 16, 3), 255, dtype=np.uint8) 24 | tf = T.Resize((32, 32), mode=ResizeMethod.PAD) 25 | assert isinstance(tf, nn.Module) 26 | 27 | # PIL Image 28 | out = tf(Image.fromarray(img1)) 29 | assert isinstance(out, Image.Image) 30 | assert out.size == (32, 32) 31 | np_out = np.asarray(out) 32 | assert np.all(np_out[8:-8] == 255) 33 | assert np.all(np_out[:8] == 0) 34 | assert np.all(np_out[-8:]) == 0 35 | out = 
tf(Image.fromarray(img2)) 36 | assert isinstance(out, Image.Image) 37 | assert out.size == (32, 32) 38 | np_out = np.asarray(out) 39 | assert np.all(np_out[:, 8:-8] == 255) 40 | assert np.all(np_out[:, :8] == 0) 41 | assert np.all(np_out[:, -8:]) == 0 42 | # Squish 43 | out = T.Resize((32, 32), mode=ResizeMethod.SQUISH)(Image.fromarray(img1)) 44 | assert np.all(np.asarray(out) == 255) 45 | 46 | # Tensor 47 | out = tf(torch.from_numpy(img1).to(dtype=torch.float32).permute(2, 0, 1) / 255) 48 | assert isinstance(out, torch.Tensor) 49 | assert out.shape == (3, 32, 32) 50 | np_out = out.numpy() 51 | assert np.all(np_out[:, 8:-8] == 1) 52 | assert np.all(np_out[:, :8] == 0) 53 | assert np.all(np_out[:, -8:]) == 0 54 | out = tf(torch.from_numpy(img2).to(dtype=torch.float32).permute(2, 0, 1) / 255) 55 | assert isinstance(out, torch.Tensor) 56 | assert out.shape == (3, 32, 32) 57 | np_out = out.numpy() 58 | assert np.all(np_out[:, :, 8:-8] == 1) 59 | assert np.all(np_out[:, :, :8] == 0) 60 | assert np.all(np_out[:, :, -8:]) == 0 61 | 62 | 63 | def test_randomzoomout(): 64 | # Arg check 65 | with pytest.raises(ValueError): 66 | T.RandomZoomOut(224) 67 | 68 | with pytest.raises(ValueError): 69 | T.Resize((16, 16), (1, 0.5)) 70 | 71 | pil_img = Image.fromarray(np.full((64, 64, 3), 255, dtype=np.uint8)) 72 | torch_img = torch.ones((3, 64, 64), dtype=torch.float32) 73 | tf = T.RandomZoomOut((32, 32), scale=(0.5, 0.99)) 74 | assert isinstance(tf, nn.Module) 75 | 76 | # PIL Image 77 | out = tf(pil_img) 78 | assert isinstance(out, Image.Image) 79 | assert out.size == (32, 32) 80 | np_out = np.asarray(out) 81 | assert np.all(np_out[16, 16] == 255) 82 | assert np_out.mean() < 255 83 | 84 | # Tensor 85 | out = tf(torch_img) 86 | assert isinstance(out, torch.Tensor) 87 | assert out.shape == (3, 32, 32) 88 | np_out = np.asarray(out) 89 | assert np.all(np_out[:, 16, 16] == 1) 90 | assert np_out.mean() < 1 91 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from PIL import Image 5 | 6 | from holocron import utils 7 | 8 | 9 | def test_mixup(): 10 | batch_size = 8 11 | num_classes = 10 12 | shape = (3, 32, 32) 13 | with pytest.raises(ValueError): 14 | utils.data.Mixup(num_classes, alpha=-1.0) 15 | # Generate all dependencies 16 | mix = utils.data.Mixup(num_classes, alpha=0.2) 17 | img, target = torch.rand((batch_size, *shape)), torch.arange(num_classes)[:batch_size] 18 | mix_img, mix_target = mix(img.clone(), target.clone()) 19 | assert img.shape == (batch_size, *shape) 20 | assert not torch.equal(img, mix_img) 21 | assert mix_target.dtype == torch.float32 22 | assert mix_target.shape == (batch_size, num_classes) 23 | assert torch.all(mix_target.sum(dim=1) == 1.0) 24 | count = (mix_target > 0).sum(dim=1) 25 | assert torch.all((count == 2.0) | (count == 1.0)) 26 | 27 | # Alpha = 0 case 28 | mix = utils.data.Mixup(num_classes, alpha=0.0) 29 | mix_img, mix_target = mix(img.clone(), target.clone()) 30 | assert torch.equal(img, mix_img) 31 | assert mix_target.dtype == torch.float32 32 | assert mix_target.shape == (batch_size, num_classes) 33 | assert torch.all(mix_target.sum(dim=1) == 1.0) 34 | assert torch.all((mix_target > 0).sum(dim=1) == 1.0) 35 | 36 | # Binary target 37 | mix = utils.data.Mixup(1, alpha=0.5) 38 | img = torch.rand((batch_size, *shape)) 39 | target = torch.concat((torch.zeros(batch_size // 2), 
torch.ones(batch_size - batch_size // 2))) 40 | mix_img, mix_target = mix(img.clone(), target.clone()) 41 | assert img.shape == (batch_size, *shape) 42 | assert not torch.equal(img, mix_img) 43 | assert mix_target.dtype == torch.float32 44 | assert mix_target.shape == (batch_size, 1) 45 | 46 | # Already in one-hot 47 | mix = utils.data.Mixup(num_classes, alpha=0.2) 48 | img, target = torch.rand((batch_size, *shape)), torch.rand((batch_size, num_classes)) 49 | mix_img, mix_target = mix(img.clone(), target.clone()) 50 | assert img.shape == (batch_size, *shape) 51 | assert not torch.equal(img, mix_img) 52 | assert mix_target.dtype == torch.float32 53 | assert mix_target.shape == (batch_size, num_classes) 54 | 55 | 56 | @pytest.mark.parametrize( 57 | ("arr", "fn", "expected", "progress", "num_threads"), 58 | [ 59 | ([1, 2, 3], lambda x: x**2, [1, 4, 9], False, 3), 60 | ([1, 2, 3], lambda x: x**2, [1, 4, 9], True, 1), 61 | ("hello", lambda x: x.upper(), list("HELLO"), True, None), 62 | ("hello", lambda x: x.upper(), list("HELLO"), False, None), 63 | ], 64 | ) 65 | def test_parallel(arr, fn, expected, progress, num_threads): 66 | assert utils.parallel(fn, arr, progress=progress, num_threads=num_threads) == expected 67 | 68 | 69 | def test_find_image_size(): 70 | ds = [(Image.fromarray(np.full((16, 16, 3), 255, dtype=np.uint8)), 0) for _ in range(100)] 71 | utils.find_image_size(ds, block=False) 72 | --------------------------------------------------------------------------------