├── .dockerignore ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ └── build-and-push-app.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── configs ├── advanced_det_seg.yaml ├── schemas │ ├── config_schema_export.json │ ├── config_schema_full.json │ ├── config_schema_train.json │ ├── config_schema_tune.json │ ├── schema_assumptions.txt │ └── test_schema.py ├── simple_class.yaml ├── simple_det.yaml ├── simple_keypoint.yaml ├── simple_seg.yaml └── simple_tuner.yaml ├── docker ├── Dockerfile ├── docker-compose.yaml └── entrypoint.sh ├── requirements-dev.txt ├── requirements.txt ├── setup.py ├── src ├── luxonis_train │ ├── __init__.py │ ├── config_db │ │ ├── config_all.yaml │ │ ├── yolov6-n.yaml │ │ ├── yolov6-s.yaml │ │ └── yolov6-t.yaml │ ├── core │ │ ├── __init__.py │ │ ├── exporter.py │ │ ├── inferer.py │ │ ├── trainer.py │ │ └── tuner.py │ ├── datasets │ │ └── __init__.py │ ├── models │ │ ├── README.md │ │ ├── __init__.py │ │ ├── backbones │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── context_spatial.py │ │ │ ├── efficient_rep.py │ │ │ ├── efficientnet.py │ │ │ ├── micronet.py │ │ │ ├── mobilenetv2.py │ │ │ ├── mobileone.py │ │ │ ├── rep_vgg.py │ │ │ ├── resnet18.py │ │ │ └── rexnetv1.py │ │ ├── heads │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── base_heads.py │ │ │ ├── bisenet_head.py │ │ │ ├── classification_head.py │ │ │ ├── effide_head.py │ │ │ ├── ikeypoint_head.py │ │ │ ├── multilabel_classification_head.py │ │ │ ├── segmentation_head.py │ │ │ └── yolov6_head.py │ │ ├── model.py │ │ ├── model_lightning_module.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ └── common.py │ │ └── necks │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ └── reppan_neck.py │ └── utils │ │ ├── __init__.py │ │ ├── assigners │ │ ├── __init__.py │ │ ├── anchor_generator.py │ │ ├── assigner_utils.py │ │ ├── atts_assigner.py │ │ ├── iou2d_calculator.py │ │ └── tal_assigner.py │ │ ├── boxutils.py │ │ ├── callbacks.py │ │ ├── config.py │ │ ├── constants.py │ │ ├── general.py │ │ ├── losses │ │ ├── README.md │ │ ├── __init__.py │ │ ├── common.py │ │ ├── utils.py │ │ ├── yolov6_loss.py │ │ └── yolov7_pose_loss.py │ │ ├── metrics │ │ ├── __init__.py │ │ ├── custom.py │ │ └── utils.py │ │ ├── optimizers.py │ │ ├── schedulers.py │ │ └── visualization.py └── tests │ ├── config_tests.py │ ├── test_config.yaml │ └── test_config_fail.yaml └── tools ├── export.py ├── infer.py ├── store_config.py ├── test_dataset.py ├── train.py └── tune.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # Git 2 | .git 3 | .gitignore 4 | 5 | # CI 6 | .codeclimate.yml 7 | .travis.yml 8 | .taskcluster.yml 9 | 10 | # Docker 11 | docker-compose.yml 12 | .docker 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | */__pycache__/ 17 | */*/__pycache__/ 18 | */*/*/__pycache__/ 19 | *.py[cod] 20 | */*.py[cod] 21 | */*/*.py[cod] 22 | */*/*/*.py[cod] 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | env/ 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .coverage 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Virtual environment 76 | .env/ 77 | .venv/ 78 | venv/ 79 | 80 | # PyCharm 81 | .idea 82 | 83 | # Python mode for VIM 84 | .ropeproject 85 | */.ropeproject 86 | */*/.ropeproject 87 | */*/*/.ropeproject 88 | 89 | # Vim swap files 90 | *.swp 91 | */*.swp 92 | */*/*.swp 93 | */*/*/*.swp -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG_TITLE]" 5 | labels: bug 6 | assignees: kozlov721, tersekmatija, conorsim 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **Expected behavior** 14 | A clear and concise description of what you expected to happen. 15 | 16 | **Screenshots** 17 | If applicable, add screenshots to help explain your problem. 18 | 19 | **Environment (please complete the following information):** 20 | - OS: [e.g. iOS] 21 | - Package version: [e.g. 0.1.0] 22 | 23 | **To Reproduce** 24 | Please provide a minimal reproducible example. This would include: 25 | - Config file 26 | - Script to generate the dataset in case the issue cannot be reproduced using one of our [example](../../examples) datasets 27 | - Sript to run the pipeline 28 | 29 | **Additional context** 30 | Add any other context about the problem here. 31 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[FEATURE REQUEST]" 5 | labels: enhancement 6 | assignees: conorsim, kozlov721, tersekmatija 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Additional context** 17 | Add any other context or screenshots about the feature request here. 18 | 19 | **Proposed solution** 20 | If you have a possible solution in mind, you can include more details about it here. 
21 | -------------------------------------------------------------------------------- /.github/workflows/build-and-push-app.yaml: -------------------------------------------------------------------------------- 1 | name: Build and deploy app image 2 | 3 | on: 4 | push: 5 | branches: 6 | - docker 7 | workflow_dispatch: 8 | 9 | jobs: 10 | push-store: 11 | name: Build and push image 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: 'Checkout GitHub Action' 15 | uses: actions/checkout@v2 16 | with: 17 | ref: ${{ github.ref }} 18 | 19 | - name: 'Login to GitHub Container Registry' 20 | uses: docker/login-action@v1 21 | with: 22 | registry: ghcr.io 23 | username: luxonis-ml 24 | password: ${{secrets.GHCR_PAT}} 25 | 26 | - name: Get commit name 27 | id: commit 28 | run: echo "sha_short=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT 29 | 30 | - name: Extract branch name 31 | shell: bash 32 | run: echo "branch=${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}" >> $GITHUB_OUTPUT 33 | id: extract_branch 34 | 35 | - name: 'Build Inventory Image' 36 | run: | 37 | docker build --build-arg GITHUB_TOKEN=${{secrets.GHCR_PAT}} -f docker/Dockerfile . --tag ghcr.io/luxonis/models:${{ steps.extract_branch.outputs.branch }}.${{ steps.commit.outputs.sha_short }} 38 | docker push ghcr.io/luxonis/models --all-tags -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | demo/* 2 | output/* 3 | output_export/* 4 | 5 | # database 6 | *.db 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | # Datasets 139 | cifar_ldf/* 140 | cifar_small_ldf/* 141 | 142 | # Venv 143 | models_venv/* -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: 23.3.0 4 | hooks: 5 | - id: black 6 | language_version: python3.8 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v4.4.0 9 | hooks: 10 | - id: no-commit-to-branch 11 | args: ['--branch', 'main', '--branch', 'dev'] 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Luxonis 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include src/luxonis_train/config_db/*.yaml -------------------------------------------------------------------------------- /configs/advanced_det_seg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: AdvancedDetectionSegmentation 3 | type: yolov6-n 4 | pretrained: 5 | params: 6 | n_classes: null 7 | 8 | additional_heads: 9 | - name: SegmentationHead 10 | params: 11 | n_classes: null 12 | loss: 13 | name: FocalLoss 14 | params: 15 | 16 | dataset: 17 | team_id: # TODO 18 | dataset_id: # TODO 19 | 20 | train: 21 | preprocessing: 22 | train_image_size: [512,512] 23 | 24 | train_metrics_interval: 100 25 | validation_interval: 10 26 | 27 | optimizers: 28 | optimizer: 29 | name: SGD 30 | params: 31 | lr: 0.02 32 | momentum: 0.937 33 | nesterov: True 34 | weight_decay: 0.0005 35 | scheduler: 36 | name: CosineAnnealingLR 37 | params: 38 | T_max: 100 39 | eta_min: 0 40 | 41 | losses: 42 | log_sub_losses: True 43 | weights: [0.1, 1] 44 | 45 | -------------------------------------------------------------------------------- /configs/schemas/schema_assumptions.txt: -------------------------------------------------------------------------------- 1 | num_sanity_val_steps: [0,2], default 2 2 | batch_size: [1,64], default 32 3 | accumulate_grad_batches: [1, 4], default: 1 4 | epochs: [1,100], default: 100 5 | num_workers: [0,16], default: 2 6 | train_metrics_interval: [-1, ], default: -1 7 | validation_interval: [1, ], default: 1 8 | num_log_images: [0, 4], default: 4 9 | model_checkpoint.save_top_k: [1,3], default: 3 10 | 11 | in exporter: 12 | export_weights required but only if you are exporting -> new schema? 
13 | shaves: [1,10], default 6 (https://docs.luxonis.com/en/latest/pages/faq/#what-are-the-shaves) 14 | 15 | tuner: 16 | n_trials: [1, ], default: 3 -------------------------------------------------------------------------------- /configs/schemas/test_schema.py: -------------------------------------------------------------------------------- 1 | import jsonschema 2 | import json 3 | from luxonis_train.utils.config import Config 4 | import yaml 5 | 6 | cfg = Config("../simple_det.yaml") 7 | data = cfg.get_data() 8 | 9 | # with open("../simple_det.yaml") as f: 10 | # data = yaml.load(f, Loader=yaml.SafeLoader) 11 | 12 | with open("config_schema_full.json") as f: 13 | schema = json.load(f) 14 | 15 | try: 16 | # Validate the data against the schema 17 | jsonschema.validate(data, schema) 18 | print("Validation successful!") 19 | except jsonschema.exceptions.ValidationError as e: 20 | print(f"Validation failed: {e}") -------------------------------------------------------------------------------- /configs/simple_class.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: SimpleClassification 3 | type: 4 | pretrained: 5 | 6 | backbone: 7 | name: MicroNet 8 | pretrained: 9 | 10 | heads: 11 | - name: ClassificationHead 12 | params: 13 | n_classes: null 14 | loss: 15 | name: CrossEntropyLoss 16 | params: 17 | 18 | dataset: 19 | team_id: # TODO 20 | dataset_id: # TODO 21 | -------------------------------------------------------------------------------- /configs/simple_det.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: SimpleDetection 3 | type: yolov6-n 4 | pretrained: 5 | params: 6 | n_classes: null 7 | is_4head: False 8 | 9 | dataset: 10 | team_id: # TODO 11 | dataset_id: # TODO -------------------------------------------------------------------------------- /configs/simple_keypoint.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: SimpleKeypoint 3 | type: 4 | pretrained: 5 | 6 | backbone: 7 | name: EfficientRep 8 | pretrained: 9 | params: 10 | channels_list: [64, 128, 256, 512, 1024] 11 | num_repeats: [1, 6, 12, 18, 6] 12 | depth_mul: 0.33 13 | width_mul: 0.25 14 | 15 | neck: 16 | name: RepPANNeck 17 | params: 18 | channels_list: [256, 128, 128, 256, 256, 512] 19 | num_repeats: [12, 12, 12, 12] 20 | depth_mul: 0.33 21 | width_mul: 0.25 22 | 23 | heads: 24 | - name: IKeypoint 25 | params: 26 | n_classes: 1 27 | n_keypoints: 17 28 | anchors: 29 | - [12,16, 19,36, 40,28] # P3/8 30 | - [36,75, 76,55, 72,146] # P4/16 31 | - [142,110, 192,243, 459,401] # P5/32 32 | loss: 33 | name: YoloV7PoseLoss 34 | params: 35 | 36 | dataset: 37 | team_id: # TODO 38 | dataset_id: # TODO -------------------------------------------------------------------------------- /configs/simple_seg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: SimpleSegmentation 3 | type: 4 | pretrained: 5 | 6 | backbone: 7 | name: ContextSpatial #ResNet18 8 | pretrained: 9 | 10 | heads: 11 | - name: BiSeNetHead #SegmentationHead 12 | params: 13 | n_classes: null 14 | loss: 15 | name: FocalLoss 16 | params: 17 | 18 | dataset: 19 | team_id: # TODO 20 | dataset_id: # TODO -------------------------------------------------------------------------------- /configs/simple_tuner.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: SimpleDetection-Tuner 3 | type: yolov6-n 
4 | pretrained: 5 | params: 6 | n_classes: null 7 | is_4head: False 8 | 9 | dataset: 10 | team_id: # TODO 11 | dataset_id: # TODO 12 | 13 | tuner: 14 | params: # (key, value) pairs for tuning 15 | train.optimizers.optimizer.name_categorical: ["Adam", "SGD"] 16 | train.optimizers.optimizer.params.lr_float: [0.0001, 0.001] 17 | train.batch_size_int: [4, 4, 16]
-------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:23.04-py3 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | 5 | RUN apt-get update && apt-get install -y git wget python3-dev gcc s3fs ffmpeg libsm6 libxext6 curl libfreetype6-dev libssl-dev libpng-dev 6 | 7 | RUN pip3 uninstall opencv PIL pillow -y 8 | 9 | RUN useradd -ms /bin/bash luxonis 10 | 11 | WORKDIR /home/luxonis 12 | 13 | # Copy the whole library and tool directory 14 | COPY ./src ./src 15 | COPY ./tools ./tools 16 | 17 | # Copy package related files 18 | COPY ["./requirements.txt", "./setup.py", "./README.md", "./MANIFEST.in", "docker/entrypoint.sh", "./"] 19 | 20 | # needed because of _imagingft C module error 21 | RUN pip install --no-cache-dir pillow 22 | 23 | # Install the library 24 | RUN pip install --upgrade pip 25 | RUN pip install --no-cache-dir . 26 | 27 | RUN pip install awscli 28 | 29 | # Set the entrypoint command for the container 30 | RUN chown -R luxonis:luxonis /home/luxonis 31 | RUN chmod +x ./entrypoint.sh 32 | 33 | USER luxonis 34 | ENTRYPOINT ["/home/luxonis/entrypoint.sh"]
-------------------------------------------------------------------------------- /docker/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | x-common: &common 4 | container_name: luxonis-train 5 | env_file: 6 | - ${DOTENV_PATH} 7 | 8 | x-gpu: &gpu 9 | <<: *common 10 | deploy: 11 | resources: 12 | reservations: 13 | devices: 14 | - driver: nvidia 15 | count: 1 16 | capabilities: [ gpu ] 17 | 18 | services: 19 | luxonis-train: 20 | <<: *common 21 | image: luxonis-train 22 | 23 | luxonis-train-gpu: 24 | <<: *gpu 25 | image: luxonis-train-gpu # TODO: create image that supports nvidia gpu
-------------------------------------------------------------------------------- /docker/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir /home/luxonis/.luxonis_mount 4 | echo $AWS_ACCESS_KEY_ID:$AWS_SECRET_ACCESS_KEY > /home/luxonis/.passwd-s3fs 5 | chmod 600 /home/luxonis/.passwd-s3fs 6 | s3fs $AWS_BUCKET /home/luxonis/.luxonis_mount \ 7 | -o passwd_file=/home/luxonis/.passwd-s3fs \ 8 | -o curldbg \ 9 | -o url=$S3_ENDPOINT \ 10 | -o use_path_request_style \ 11 | -o umask=0000 12 | 13 | 14 | 15 | # Check if no arguments are passed 16 | if [ "$#" -eq 0 ]; then 17 | echo "Choose an action that you want to perform from supported actions: ['train', 'tune']" 18 | exit 1 19 | fi 20 | 21 | # Check the first argument 22 | if [ "$1" = "train" ]; then 23 | echo "Starting training..." 24 | shift # Remove the first argument from the list 25 | python3 /home/luxonis/tools/train.py "$@" # Pass remaining arguments to train.py 26 | elif [ "$1" = "tune" ]; then 27 | echo "Starting tuning..." 28 | shift 29 | python3 /home/luxonis/tools/tune.py "$@" 30 | else 31 | echo "Argument $1 doesn't match any action. Supported actions are ['train', 'tune']."
32 | exit 1 33 | fi
-------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pre-commit==3.2.1 2 |
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations>=1.3.0 2 | numpy>=1.22 3 | opencv-python>=4.7.0.68 4 | python-dotenv>=0.21.1 5 | pytorch-lightning>=2.0.0 6 | PyYAML>=6.0 7 | torch>=1.12.1 8 | torchmetrics>=0.11.0 9 | torchvision>=0.13.1 10 | tensorboard>=2.10.1 11 | rich 12 | luxonis-ml[all] @ git+https://github.com/luxonis/luxonis-ml.git@fdaab1c 13 | # tuning related packages 14 | optuna>=3.2.0 15 | psycopg2-binary 16 | # export related packages (only base ones) 17 | onnx>=1.12.0 18 | onnxsim>=0.4.10 19 | onnxruntime>=1.13.1 20 | # pydantic==1.10.10
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | from setuptools import setup, find_packages 3 | 4 | with open('requirements.txt') as f: 5 | required = f.readlines() 6 | 7 | export_packages = [ 8 | "blobconverter>=1.3.0", 9 | # "openvino-dev==2022.1.0" # problematic because of numpy version 10 | ] 11 | 12 | setup( 13 | name="luxonis-train", 14 | version="0.0.1", 15 | description="Luxonis training library for training lightweight models that run fast on OAK products.", 16 | long_description=io.open("README.md", encoding="utf-8").read(), 17 | long_description_content_type="text/markdown", 18 | url="https://github.com/luxonis/models", 19 | keywords="ml training luxonis", 20 | author="Luxonis", 21 | author_email="support@luxonis.com", 22 | license="MIT", 23 | packages=find_packages(where="src"), 24 | package_dir={'': 'src'}, # https://stackoverflow.com/a/67238346/5494277 25 | include_package_data=True, 26 | install_requires=required, 27 | extras_require={ 28 | "export": export_packages, 29 | }, 30 | project_urls={ 31 | "Bug Tracker": "https://github.com/luxonis/models/issues", 32 | "Source Code": "https://github.com/luxonis/models/tree/dev", 33 | }, 34 | classifiers=[ 35 | "License :: MIT License", 36 | "Development Status :: 2 - Pre-Alpha", 37 | "Programming Language :: Python :: 3.8" 38 | ] 39 | )
-------------------------------------------------------------------------------- /src/luxonis_train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luxonis/models/bc1cf4b1ea3bdaf918d0f0af13e21ec64bae9e41/src/luxonis_train/__init__.py
-------------------------------------------------------------------------------- /src/luxonis_train/config_db/config_all.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | accelerator: auto # either "cpu" or "gpu"; if "auto" then the best available accelerator is auto-selected (string) 3 | devices: auto # either specify how many devices (int) or which specific devices (list) to use; if "auto" then selection is automatic (int|[int]|string)
4 | strategy: auto # either one of the PL strategies or "auto" (string) 5 | num_sanity_val_steps: 2 # number of sanity validation steps performed before training (int) 6 | profiler: null # use PL profiler for GPU/CPU/RAM utilization analysis (string|null) 7 | verbose: True # print all intermediate results in console (bool) 8 | 9 | logger: 10 | project_name: null # name of the project used for logging (string|null) 11 | project_id: null # id of the project used for logging (relevant if using MLFlow) (string|null) 12 | run_name: null # name of the run, if empty then auto-generate (string|null) 13 | run_id: null # id of an already created run (relevant if using MLFlow) (string|null) 14 | save_directory: output # path to the save directory (string) 15 | is_tensorboard: True # bool if use TensorBoard (bool) 16 | is_wandb: False # bool if use WandB (bool) 17 | wandb_entity: null # name of WandB entity (string|null) 18 | is_mlflow: False # bool if use MLFlow (bool) 19 | logged_hyperparams: ["train.epochs", "train.batch_size"] # list of hyperparameters to log (list) 20 | 21 | dataset: 22 | team_id: null # team under which you can find all datasets (string) 23 | dataset_id: null # id of the dataset (string) 24 | bucket_type: local # underlying storage for images, which can be local or an AWS bucket (local|aws) 25 | override_bucket_type: False # option to change underlying storage from saved setting in DB (bool) 26 | train_view: train # view to use for training (string) 27 | val_view: val # view to use for validation (string) 28 | test_view: test # view to use for testing (string) 29 | 30 | train: 31 | preprocessing: 32 | train_image_size: [256, 256] # image size used for training [height, width] (list) 33 | keep_aspect_ratio: True # bool if keep aspect ratio while resizing (bool) 34 | train_rgb: True # bool if train on rgb or bgr (bool) 35 | normalize: 36 | active: True # bool if use normalization (bool) 37 | params: # params for normalization (dict|null) 38 | augmentations: # list of Albumentations augmentations 39 | # - name: Rotate 40 | # params: 41 | # limit: 15 42 | 43 | batch_size: 32 # batch size used for training (int) 44 | accumulate_grad_batches: 1 # number of batches for gradient accumulation (int) 45 | use_weighted_sampler: False # bool if use WeightedRandomSampler for training, only works with classification tasks (bool) 46 | epochs: 100 # number of training epochs (int) 47 | num_workers: 2 # number of workers for data loading (int) 48 | train_metrics_interval: -1 # frequency of computing metrics on train data, -1 if don't perform (int) 49 | validation_interval: 1 # frequency of computing metrics on validation data (int) 50 | num_log_images: 4 # maximum number of images to visualize and log (int) 51 | skip_last_batch: True # bool if skip last batch while training (bool) 52 | main_head_index: 0 # index of the head which is used for checkpointing based on best metric (int) 53 | use_rich_text: True # bool if use rich text for console printing (bool) 54 | 55 | callbacks: # callback specific parameters (check PL docs) 56 | test_on_finish: False # bool if should run test when train loop finishes (bool) 57 | export_on_finish: False # bool if should run export when train loop finishes - should specify config block (bool) 58 | use_device_stats_monitor: False # bool if should use device stats monitor during training (bool) 59 | model_checkpoint: 60 | save_top_k: 3 61 | early_stopping: 62 | active: True 63 | monitor: val_loss/loss 64 | mode: min 65 | patience: 5
66 | verbose: True 67 | 68 | optimizers: # optimizer specific parameters (check PyTorch docs) 69 | optimizer: 70 | name: Adam 71 | params: 72 | scheduler: 73 | name: ConstantLR 74 | params: 75 | 76 | freeze_modules: # defines which modules you want to freeze (not train) 77 | backbone: False # bool if freeze backbone (bool) 78 | neck: False # bool if freeze neck (bool) 79 | heads: [False] # list of bools for specific head freeze (list[bool]) 80 | 81 | losses: # defines weights for losses in multi-head architecture 82 | log_sub_losses: False # bool if should also log sub-losses (bool) 83 | weights: [1,1] # list of weights for each loss (list[int|float]) 84 | # learn_weights: False # bool if weights should be learned (not implemented yet) (bool) 85 | 86 | inferer: 87 | dataset_view: val # view to use for inference (string) 88 | display: True # bool if should display inference results (bool) 89 | infer_save_directory: null # if this is not null then use this as save directory (string|null) 90 | 91 | exporter: 92 | export_weights: null # path to local weights used for export (string) 93 | export_save_directory: output_export # path to save directory of exported models (string) 94 | export_image_size: [256, 256] # image size used for export [height, width] (list) 95 | export_model_name: model # name of the exported model (string) 96 | data_type: FP16 # data type used for OpenVINO conversion (string) 97 | reverse_input_channels: True # bool if reverse input channels (bool) 98 | scale_values: [58.395, 57.120, 57.375] # list of scale values (list[int|float]) 99 | mean_values: [123.675, 116.28, 103.53] # list of mean values (list[int|float]) 100 | onnx: 101 | opset_version: 12 # opset version of onnx used (int) 102 | dynamic_axes: null # define if dynamic input shapes are used (dict) 103 | openvino: 104 | active: False # bool if export to OpenVINO (bool) 105 | blobconverter: 106 | active: False # bool if export to blob (bool) 107 | shaves: 6 # number of shaves used (int) 108 | s3_upload: 109 | active: False # bool if upload .ckpt, .onnx and config file to s3 bucket (bool) 110 | bucket: null # name of the s3 bucket (string) 111 | upload_directory: null # location of directory for upload (string) 112 | 113 | tuner: 114 | study_name: "test-study" # name of the study (string) 115 | use_pruner: True # if should use MedianPruner (bool) 116 | n_trials: 3 # number of trials for each process (int) 117 | timeout: null # stop study after the given number of seconds (null|int) 118 | storage: 119 | active: True # if should use storage to make study persistent (bool) 120 | type: local # type of storage, "local" or "remote" (string) 121 | params: # (key, value) pairs for tuning
-------------------------------------------------------------------------------- /src/luxonis_train/config_db/yolov6-n.yaml: -------------------------------------------------------------------------------- 1 | backbone: 2 | name: EfficientRep 3 | pretrained: 4 | params: 5 | channels_list: [64, 128, 256, 512, 1024] 6 | num_repeats: [1, 6, 12, 18, 6] 7 | depth_mul: 0.33 8 | width_mul: 0.25 9 | is_4head: False 10 | 11 | neck: 12 | name: RepPANNeck 13 | params: 14 | channels_list: [256, 128, 128, 256, 256, 512] 15 | num_repeats: [12, 12, 12, 12] 16 | depth_mul: 0.33 17 | width_mul: 0.25 18 | is_4head: False 19 | 20 | heads: 21 | - name: YoloV6Head 22 | params: 23 | n_classes: null 24 | is_4head: False 25 | reg_max: 0 26 | loss: 27 | name: YoloV6Loss 28 | params: 29 | iou_type: siou
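A predefined config such as yolov6-n.yaml above is not run directly; as in configs/simple_det.yaml and configs/advanced_det_seg.yaml earlier in this repo, it is selected through the model.type field of a user config, and the values under model.params (e.g. n_classes, is_4head) appear to be forwarded to the predefined architecture. A minimal hedged sketch of such a user config follows; the model name and dataset IDs are illustrative placeholders, not part of this file:

```yaml
model:
  name: MyDetection   # free-form experiment name (illustrative)
  type: yolov6-n      # selects src/luxonis_train/config_db/yolov6-n.yaml
  pretrained:         # optionally a path to a checkpoint to start from
  params:
    n_classes: null   # same override pattern as configs/simple_det.yaml
    is_4head: False

dataset:
  team_id:    # TODO, as in the example configs
  dataset_id: # TODO
```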
-------------------------------------------------------------------------------- /src/luxonis_train/config_db/yolov6-s.yaml: -------------------------------------------------------------------------------- 1 | backbone: 2 | name: EfficientRep 3 | pretrained: 4 | params: 5 | channels_list: [64, 128, 256, 512, 1024] 6 | num_repeats: [1, 6, 12, 18, 6] 7 | depth_mul: 0.33 8 | width_mul: 0.50 9 | is_4head: False 10 | 11 | neck: 12 | name: RepPANNeck 13 | params: 14 | channels_list: [256, 128, 128, 256, 256, 512] 15 | num_repeats: [12, 12, 12, 12] 16 | depth_mul: 0.33 17 | width_mul: 0.50 18 | is_4head: False 19 | 20 | heads: 21 | - name: YoloV6Head 22 | params: 23 | n_classes: null 24 | is_4head: False 25 | reg_max: 0 26 | loss: 27 | name: YoloV6Loss 28 | params: 29 | iou_type: giou -------------------------------------------------------------------------------- /src/luxonis_train/config_db/yolov6-t.yaml: -------------------------------------------------------------------------------- 1 | backbone: 2 | name: EfficientRep 3 | pretrained: 4 | params: 5 | channels_list: [64, 128, 256, 512, 1024] 6 | num_repeats: [1, 6, 12, 18, 6] 7 | depth_mul: 0.33 8 | width_mul: 0.375 9 | is_4head: False 10 | 11 | neck: 12 | name: RepPANNeck 13 | params: 14 | channels_list: [256, 128, 128, 256, 256, 512] 15 | num_repeats: [12, 12, 12, 12] 16 | depth_mul: 0.33 17 | width_mul: 0.375 18 | is_4head: False 19 | 20 | heads: 21 | - name: YoloV6Head 22 | params: 23 | n_classes: null 24 | is_4head: False 25 | reg_max: 0 26 | loss: 27 | name: YoloV6Loss 28 | params: 29 | iou_type: siou -------------------------------------------------------------------------------- /src/luxonis_train/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | from .exporter import Exporter 3 | from .inferer import Inferer 4 | from .tuner import Tuner 5 | 6 | __all__ = [ 7 | "Trainer", 8 | "Exporter", 9 | "Inferer", 10 | "Tuner" 11 | ] -------------------------------------------------------------------------------- /src/luxonis_train/core/exporter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import onnx 4 | import onnxsim 5 | import pytorch_lightning as pl 6 | import warnings 7 | from typing import Union, Optional 8 | from pathlib import Path 9 | from dotenv import load_dotenv 10 | 11 | from luxonis_train.utils.config import Config 12 | from luxonis_train.models import Model 13 | from luxonis_train.models.heads import * 14 | 15 | class Exporter(pl.LightningModule): 16 | def __init__(self, cfg: Union[str, dict], args: Optional[dict] = None): 17 | """ Main API which is used for exporting models trained with this library to .onnx, openVINO and .blob format. 
18 | 19 | Args: 20 | cfg (Union[str, dict]): path to config file or config dict used to setup training 21 | args (Optional[dict]): argument dict provided through command line, used for config overriding 22 | """ 23 | super().__init__() 24 | 25 | load_dotenv() 26 | 27 | self.cfg = Config(cfg) 28 | if args and args["override"]: 29 | self.cfg.override_config(args["override"]) 30 | self.cfg.validate_config_exporter() 31 | 32 | # ensure save directory 33 | Path(self.cfg.get("exporter.export_save_directory")).mkdir(parents=True, exist_ok=True) 34 | 35 | self.model = Model() 36 | self.model.build_model() 37 | 38 | self.load_checkpoint(self.cfg.get("exporter.export_weights")) 39 | self.model.eval() 40 | self.to_deploy() 41 | 42 | def load_checkpoint(self, path: str): 43 | """ Loads checkpoint weights from provided path """ 44 | print(f"Loading weights from: {path}") 45 | state_dict = torch.load(path)["state_dict"] 46 | # remove weights that are not part of the model 47 | removed = [] 48 | for key in list(state_dict.keys()): 49 | if not key.startswith("model"): 50 | removed.append(key) 51 | state_dict.pop(key) 52 | if len(removed): 53 | print(f"Following weights weren't loaded: {removed}") 54 | 55 | self.load_state_dict(state_dict) 56 | 57 | def to_deploy(self): 58 | """ Switch modules of the model to deploy""" 59 | for module in self.model.modules(): 60 | if hasattr(module, "to_deploy"): 61 | module.to_deploy() 62 | 63 | def forward(self, inputs: torch.Tensor): 64 | """ Forward function used in export """ 65 | outputs = self.model(inputs) 66 | return outputs 67 | 68 | def export(self): 69 | """ Exports model to onnx and optionally to openVINO and .blob format """ 70 | dummy_input = torch.rand(1,3,*self.cfg.get("exporter.export_image_size")) 71 | base_path = self.cfg.get("exporter.export_save_directory") 72 | output_names = self._get_output_names() 73 | 74 | print("Converting PyTorch model to ONNX") 75 | onnx_path = os.path.join(base_path, f"{self.cfg.get('exporter.export_model_name')}.onnx") 76 | self.to_onnx( 77 | onnx_path, 78 | dummy_input, 79 | opset_version=self.cfg.get("exporter.onnx.opset_version"), 80 | input_names=["input"], 81 | output_names=output_names, 82 | dynamic_axes=self.cfg.get("exporter.onnx.dynamic_axes") 83 | ) 84 | model_onnx = onnx.load(onnx_path) 85 | onnx_model, check = onnxsim.simplify(model_onnx) 86 | if not check: 87 | raise RuntimeError("Onnx simplify failed.") 88 | onnx.save(onnx_model, onnx_path) 89 | 90 | if self.cfg.get("exporter.openvino.active"): 91 | import subprocess 92 | print("Converting ONNX to openVINO") 93 | output_list = ",".join(output_names) 94 | 95 | cmd = f"mo --input_model {onnx_path} " \ 96 | f"--output_dir {base_path} " \ 97 | f"--model_name {self.cfg.get('exporter.export_model_name')} " \ 98 | f"--data_type {self.cfg.get('exporter.data_type')} " \ 99 | f"--scale_values '{self.cfg.get('exporter.scale_values')}' " \ 100 | f"--mean_values '{self.cfg.get('exporter.mean_values')}' " \ 101 | f"--output {output_list}" 102 | 103 | if self.cfg.get("exporter.reverse_input_channels"): 104 | cmd += " --reverse_input_channels " 105 | 106 | subprocess.check_output(cmd, shell=True) 107 | 108 | if self.cfg.get("exporter.blobconverter.active"): 109 | import blobconverter 110 | print("Converting ONNX to .blob") 111 | 112 | optimizer_params=[ 113 | f"--scale_values={self.cfg.get('exporter.scale_values')}", 114 | f"--mean_values={self.cfg.get('exporter.mean_values')}", 115 | ] 116 | if self.cfg.get("exporter.reverse_input_channels"): 117 | 
optimizer_params.append("--reverse_input_channels") 118 | 119 | blob_path = blobconverter.from_onnx( 120 | model=onnx_path, 121 | optimizer_params=optimizer_params, 122 | data_type=self.cfg.get("exporter.data_type"), 123 | shaves=self.cfg.get("exporter.blobconverter.shaves"), 124 | use_cache=False, 125 | output_dir=base_path 126 | ) 127 | 128 | print(f"Finished exporting. Files saved in: {base_path}") 129 | 130 | if self.cfg.get("exporter.s3_upload.active"): 131 | if None not in [self.cfg.get("logger.project_id"), self.cfg.get("logger.run_id")]: 132 | warnings.warn("Using current MLFlow run for upload instead of specified bucket.") 133 | bucket = os.getenv("MLFLOW_S3_BUCKET") 134 | base_key = f'{self.cfg.get("logger.project_id")}/{self.cfg.get("logger.run_id")}/artifacts' 135 | else: 136 | bucket = self.cfg.get("exporter.s3_upload.bucket") 137 | base_key = f"{self.cfg.get('exporter.s3_upload.upload_directory')}/{self.cfg.get('exporter.export_model_name')}" 138 | 139 | self._upload_to_s3(onnx_path, bucket, base_key) 140 | 141 | def _get_output_names(self): 142 | """ Gets output names for each head """ 143 | output_names = [] 144 | for i, head in enumerate(self.model.heads): 145 | curr_output = head.get_output_names(i) 146 | if isinstance(curr_output, str): 147 | output_names.append(curr_output) 148 | else: 149 | output_names.extend(curr_output) 150 | return output_names 151 | 152 | def _upload_to_s3(self, onnx_path, bucket, base_key): 153 | """ Uploads .pt, .onnx and current config.yaml to specified s3 bucket """ 154 | if None in [bucket, base_key]: 155 | raise KeyError("Bucket or base_key not specified. Check 's3_upload' in exporter.") 156 | 157 | import boto3 158 | import yaml 159 | 160 | print("Started upload to S3...") 161 | 162 | s3_client = boto3.client("s3", 163 | aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"), 164 | aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"), 165 | endpoint_url=os.getenv("AWS_S3_ENDPOINT_URL") 166 | ) 167 | 168 | # upload .ckpt file 169 | s3_client.upload_file( 170 | Filename=self.cfg.get("exporter.export_weights"), 171 | Bucket=bucket, 172 | Key=f"{base_key}/{self.cfg.get('exporter.export_model_name')}.ckpt") 173 | 174 | # upload .onnx file 175 | s3_client.upload_file( 176 | Filename=onnx_path, 177 | Bucket=bucket, 178 | Key=f"{base_key}/{self.cfg.get('exporter.export_model_name')}.onnx") 179 | 180 | # upload config.yaml 181 | self.cfg.save_data("config.yaml") # create temporary file 182 | s3_client.upload_file( 183 | Filename="config.yaml", 184 | Bucket=bucket, 185 | Key=f"{base_key}/config.yaml") 186 | os.remove("config.yaml") # delete temporary file 187 | 188 | # generate and upload export_config.yaml compatible with modelconverter 189 | onnx_path = f"s3://{bucket}/" + \ 190 | f"{base_key}/{self.cfg.get('exporter.export_model_name')}.onnx" 191 | modelconverter_config = self._get_modelconverter_config(onnx_path) 192 | 193 | with open("config_export.yaml", "w+") as f: 194 | yaml.dump(modelconverter_config, f, default_flow_style=False) 195 | 196 | s3_client.upload_file( 197 | Filename="config_export.yaml", 198 | Bucket=bucket, 199 | Key=f"{base_key}/config_export.yaml") 200 | os.remove("config_export.yaml") # delete temporary file 201 | 202 | print(f"Files uploaded to: s3://{bucket}/{base_key}") 203 | 204 | def _get_modelconverter_config(self, onnx_path: str): 205 | """ Generates export config from input config that is 206 | compatible with Luxonis modelconverter tool 207 | 208 | Args: 209 | onnx_path (str): Path to .onnx model 210 | """ 211 | 
out_config = { 212 | "input_model": onnx_path, 213 | "scale_values": self.cfg.get("exporter.scale_values"), 214 | "mean_values": self.cfg.get("exporter.mean_values"), 215 | "reverse_input_channels": self.cfg.get("exporter.reverse_input_channels"), 216 | "use_bgr": not self.cfg.get("train.preprocessing.train_rgb"), 217 | "input_shape": [1,3] + self.cfg.get("exporter.export_image_size"), 218 | "data_type": "f16", #self.cfg.get("exporter.data_type"), # NOTE: change this when modelconverter is updated 219 | "output": [{"name":name} for name in self._get_output_names()], 220 | "meta":{ 221 | "description": self.cfg.get("exporter.export_model_name") 222 | } 223 | } 224 | return out_config -------------------------------------------------------------------------------- /src/luxonis_train/core/inferer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import os 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import matplotlib.pyplot as plt 7 | from typing import Union, Optional 8 | from tqdm import tqdm 9 | from luxonis_ml.data import LuxonisDataset 10 | from luxonis_ml.loader import LuxonisLoader 11 | from luxonis_ml.loader import TrainAugmentations, ValAugmentations, Augmentations 12 | 13 | from luxonis_train.utils.config import Config 14 | from luxonis_train.models import Model 15 | from luxonis_train.models.heads import * 16 | from luxonis_train.utils.visualization import draw_outputs, draw_labels 17 | 18 | 19 | class Inferer(pl.LightningModule): 20 | def __init__(self, cfg: Union[str, dict], args: Optional[dict] = None): 21 | """Main API which is used for inference/visualization on the dataset 22 | 23 | Args: 24 | cfg (Union[str, dict]): path to config file or config dict used to setup training 25 | args (Optional[dict]): argument dict provided through command line, used for config overriding 26 | """ 27 | super().__init__() 28 | 29 | self.cfg = Config(cfg) 30 | if args and args["override"]: 31 | self.cfg.override_config(args["override"]) 32 | 33 | self.model = Model() 34 | self.model.build_model() 35 | 36 | self.load_checkpoint(self.cfg.get("model.pretrained")) 37 | self.model.eval() 38 | 39 | self.augmentations = None 40 | 41 | def load_checkpoint(self, path: str): 42 | """ Loads checkpoint weights from provided path """ 43 | print(f"Loading weights from: {path}") 44 | state_dict = torch.load(path)["state_dict"] 45 | # remove weights that are not part of the model 46 | removed = [] 47 | for key in state_dict.keys(): 48 | if not key.startswith("model"): 49 | removed.append(key) 50 | state_dict.pop(key) 51 | if len(removed): 52 | print(f"Following weights weren't loaded: {removed}") 53 | 54 | self.load_state_dict(state_dict) 55 | 56 | def override_augmentations(self, aug: object): 57 | """ Overrides augmentations used for validation dataset """ 58 | self.augmentations = aug 59 | 60 | def forward(self, inputs: torch.Tensor): 61 | """ Forward function used in inference """ 62 | outputs = self.model(inputs) 63 | return outputs 64 | 65 | def infer(self): 66 | """ Runs inference on all images in the dataset """ 67 | 68 | with LuxonisDataset( 69 | team_id=self.cfg.get("dataset.team_id"), 70 | dataset_id=self.cfg.get("dataset.dataset_id"), 71 | bucket_type=self.cfg.get("dataset.bucket_type"), 72 | override_bucket_type=self.cfg.get("dataset.override_bucket_type") 73 | ) as dataset: 74 | 75 | view = self.cfg.get("inferer.dataset_view") 76 | 77 | if self.augmentations == None: 78 | if view == "train": 79 | 
self.augmentations = TrainAugmentations( 80 | image_size=self.cfg.get("train.preprocessing.train_image_size"), 81 | augmentations=self.cfg.get("train.preprocessing.augmentations"), 82 | train_rgb=self.cfg.get("train.preprocessing.train_rgb"), 83 | keep_aspect_ratio=self.cfg.get("train.preprocessing.keep_aspect_ratio") 84 | ) 85 | else: 86 | self.augmentations = ValAugmentations( 87 | image_size=self.cfg.get("train.preprocessing.train_image_size"), 88 | augmentations=self.cfg.get("train.preprocessing.augmentations"), 89 | train_rgb=self.cfg.get("train.preprocessing.train_rgb"), 90 | keep_aspect_ratio=self.cfg.get("train.preprocessing.keep_aspect_ratio") 91 | ) 92 | 93 | loader_val = LuxonisLoader( 94 | dataset, 95 | view=view, 96 | augmentations=self.augmentations 97 | ) 98 | 99 | pytorch_loader_val = torch.utils.data.DataLoader( 100 | loader_val, 101 | batch_size=self.cfg.get("train.batch_size"), 102 | num_workers=self.cfg.get("train.num_workers"), 103 | collate_fn=loader_val.collate_fn 104 | ) 105 | 106 | display = self.cfg.get("inferer.display") 107 | save_dir = self.cfg.get("inferer.infer_save_directory") 108 | 109 | if save_dir is not None: 110 | os.makedirs(save_dir, exist_ok=True) 111 | 112 | unnormalize_img = self.cfg.get("train.preprocessing.normalize.active") 113 | cvt_color = not self.cfg.get("train.preprocessing.train_rgb") 114 | counter = 0 115 | with torch.no_grad(): 116 | for data in tqdm(pytorch_loader_val): 117 | inputs = data[0] 118 | label_dict = data[1] 119 | outputs = self.forward(inputs) 120 | 121 | for i, output in enumerate(outputs): 122 | curr_head = self.model.heads[i] 123 | curr_head_name = curr_head.get_name(i) 124 | 125 | label_imgs = draw_labels(imgs=inputs, label_dict=label_dict, label_keys=curr_head.label_types, 126 | unnormalize_img=unnormalize_img, cvt_color=cvt_color) 127 | output_imgs = draw_outputs(imgs=inputs, output=output, head=curr_head, 128 | unnormalize_img=unnormalize_img, cvt_color=cvt_color) 129 | merged_imgs = [cv2.hconcat([l_img, o_img]) for l_img, o_img in zip(label_imgs, output_imgs)] 130 | 131 | for img in merged_imgs: 132 | counter += 1 133 | plt.imshow(img) 134 | plt.title(curr_head_name+f"\n(labels left, outputs right)") 135 | if save_dir is not None: 136 | plt.savefig(os.path.join(save_dir, f"{counter}.png")) 137 | if display: 138 | plt.show() 139 | 140 | 141 | def infer_image(self, img: np.ndarray, augmentations: Optional[Augmentations] = None, 142 | display: bool = True, save_path: Optional[str] = None): 143 | """ Runs inference on single image 144 | 145 | Args: 146 | img (np.ndarray): Input image of shape (H x W x C) and dtype uint8. 147 | augmentations (Optional[Augmentations], optional): Instance of augmentation class. If None use ValAugmentations(). Defaults to None. 148 | display (bool, optional): Control if want to display output. Defaults to True. 149 | save_path (Optional[str], optional): Path for saving the output, will generate separate image for each model head. If None then don't save. Defaults to None. 
150 | """ 151 | 152 | if augmentations == None: 153 | augmentations = ValAugmentations( 154 | image_size=self.cfg.get("train.preprocessing.train_image_size"), 155 | augmentations=self.cfg.get("train.preprocessing.augmentations"), 156 | train_rgb=self.cfg.get("train.preprocessing.train_rgb"), 157 | keep_aspect_ratio=self.cfg.get("train.preprocessing.keep_aspect_ratio") 158 | ) 159 | 160 | # IMG IN RGB HWC 161 | transformed = augmentations.transform( 162 | image = img, 163 | bboxes = [], 164 | bbox_classes = [], 165 | keypoints = [], 166 | keypoints_classes = [] 167 | ) 168 | inputs = torch.unsqueeze(transformed["image"], dim=0) 169 | outputs = self.forward(inputs) 170 | 171 | unnormalize_img = self.cfg.get("train.preprocessing.normalize.active") 172 | cvt_color = not self.cfg.get("train.preprocessing.train_rgb") 173 | 174 | for i, output in enumerate(outputs): 175 | curr_head = self.model.heads[i] 176 | curr_head_name = curr_head.get_name(i) 177 | 178 | output_img = draw_outputs(imgs=inputs, output=output, head=curr_head, 179 | unnormalize_img=unnormalize_img, cvt_color=cvt_color)[0] 180 | 181 | if save_path is not None: 182 | path, save_type = save_path.rsplit(".", 1) # get desired save type (e.g. .png, .jpg, ...) 183 | # use cv2 for saving to avoid paddings 184 | save_img = cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR) 185 | cv2.imwrite(path+f"_{curr_head_name}.{save_type}", save_img) 186 | 187 | if display: 188 | plt.imshow(output_img) 189 | plt.title(curr_head_name) 190 | plt.show() 191 | -------------------------------------------------------------------------------- /src/luxonis_train/core/tuner.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import warnings 3 | import torch 4 | import os 5 | import optuna 6 | from typing import Union, Optional 7 | from dotenv import load_dotenv 8 | from copy import deepcopy 9 | from pytorch_lightning.utilities import rank_zero_only 10 | from optuna.integration import PyTorchLightningPruningCallback 11 | from luxonis_ml.tracker import LuxonisTrackerPL 12 | from luxonis_ml.data import LuxonisDataset 13 | from luxonis_ml.loader import LuxonisLoader, TrainAugmentations, ValAugmentations 14 | 15 | from luxonis_train.utils.config import Config 16 | from luxonis_train.utils.callbacks import LuxonisProgressBar 17 | from luxonis_train.models import ModelLightningModule 18 | 19 | 20 | class Tuner: 21 | def __init__(self, cfg: Union[str, dict], args: Optional[dict] = None): 22 | """Main API which is used to perform hyperparameter tunning 23 | 24 | Args: 25 | cfg (Union[str, dict]): path to config file or config dict used to setup training 26 | args (Optional[dict]): argument dict provided through command line, used for config overriding 27 | """ 28 | self.cfg_data = cfg 29 | self.args = args 30 | load_dotenv() 31 | 32 | def tune(self): 33 | """Runs Optuna tunning of hyperparameters""" 34 | self.cfg = Config(self.cfg_data) 35 | if self.args and self.args["override"]: 36 | self.cfg.override_config(self.args["override"]) 37 | self.cfg.validate_config_tuner() 38 | 39 | pruner = ( 40 | optuna.pruners.MedianPruner() 41 | if self.cfg.get("tuner.use_pruner") 42 | else optuna.pruners.NopPruner() 43 | ) 44 | 45 | storage = None 46 | if self.cfg.get("tuner.storage.active"): 47 | if self.cfg.get("tuner.storage.type") == "local": 48 | storage = "sqlite:///study_local.db" 49 | elif self.cfg.get("tuner.storage.type") == "remote": 50 | storage = "postgresql://{}:{}@{}:{}/{}".format( 51 | 
os.environ["POSTGRES_USER"], 52 | os.environ["POSTGRES_PASSWORD"], 53 | os.environ["POSTGRES_HOST"], 54 | os.environ["POSTGRES_PORT"], 55 | os.environ["POSTGRES_DB"], 56 | ) 57 | else: 58 | raise KeyError( 59 | f"Storage type '{self.cfg.get('tuner.storage.type')}'" 60 | + "not supported. Choose one of ['local', 'remote']" 61 | ) 62 | 63 | study = optuna.create_study( 64 | study_name=self.cfg.get("tuner.study_name"), 65 | storage=storage, 66 | direction="minimize", 67 | pruner=pruner, 68 | load_if_exists=True, 69 | ) 70 | 71 | study.optimize( 72 | self._objective, 73 | n_trials=self.cfg.get("tuner.n_trials"), 74 | timeout=self.cfg.get("tuner.timeout"), 75 | ) 76 | 77 | def _objective(self, trial: optuna.trial.Trial): 78 | """Objective function used to optimize Optuna study""" 79 | # TODO: check if this is even needed needed because config is singleton 80 | # Config.clear_instance() 81 | self.cfg = Config(self.cfg_data) 82 | if self.args and self.args["override"]: 83 | self.cfg.override_config(self.args["override"]) 84 | self.cfg.validate_config_tuner() 85 | 86 | rank = rank_zero_only.rank 87 | cfg_logger = self.cfg.get("logger") 88 | logger_params = deepcopy(cfg_logger.copy()) 89 | logger_params.pop("logged_hyperparams") 90 | logger = LuxonisTrackerPL( 91 | rank=rank, 92 | mlflow_tracking_uri=os.getenv( 93 | "MLFLOW_TRACKING_URI" 94 | ), # read seperately from env vars 95 | is_sweep=True, 96 | **logger_params, 97 | ) 98 | run_save_dir = os.path.join(cfg_logger["save_directory"], logger.run_name) 99 | 100 | # get curr trial params and update config 101 | curr_params = self._get_trial_params(trial) 102 | for key, value in curr_params.items(): 103 | self.cfg.override_config(f"{key} {value}") 104 | 105 | logger.log_hyperparams(curr_params) # log curr trial params 106 | 107 | # save current config to logger directory 108 | self.cfg.save_data(os.path.join(run_save_dir, "config.yaml")) 109 | 110 | lightning_module = ModelLightningModule(run_save_dir) 111 | pruner_callback = PyTorchLightningPruningCallback( 112 | trial, monitor="val_loss/loss" 113 | ) 114 | pl_trainer = pl.Trainer( 115 | accelerator=self.cfg.get("trainer.accelerator"), 116 | devices=self.cfg.get("trainer.devices"), 117 | strategy=self.cfg.get("trainer.strategy"), 118 | logger=logger, 119 | max_epochs=self.cfg.get("train.epochs"), 120 | accumulate_grad_batches=self.cfg.get("train.accumulate_grad_batches"), 121 | check_val_every_n_epoch=self.cfg.get("train.validation_interval"), 122 | num_sanity_val_steps=self.cfg.get("trainer.num_sanity_val_steps"), 123 | profiler=self.cfg.get("trainer.profiler"), # for debugging purposes, 124 | callbacks=[ 125 | LuxonisProgressBar() 126 | if self.cfg.get("train.use_rich_text") 127 | else None, # NOTE: this is likely PL bug, should be configurable inside configure_callbacks(), 128 | pruner_callback, 129 | ], 130 | ) 131 | 132 | with LuxonisDataset( 133 | team_id=self.cfg.get("dataset.team_id"), 134 | dataset_id=self.cfg.get("dataset.dataset_id"), 135 | bucket_type=self.cfg.get("dataset.bucket_type"), 136 | override_bucket_type=self.cfg.get("dataset.override_bucket_type"), 137 | ) as dataset: 138 | loader_train = LuxonisLoader( 139 | dataset, 140 | view=self.cfg.get("dataset.train_view"), 141 | augmentations=TrainAugmentations( 142 | image_size=self.cfg.get("train.preprocessing.train_image_size"), 143 | augmentations=self.cfg.get("train.preprocessing.augmentations"), 144 | train_rgb=self.cfg.get("train.preprocessing.train_rgb"), 145 | keep_aspect_ratio=self.cfg.get( 146 | 
"train.preprocessing.keep_aspect_ratio" 147 | ), 148 | ), 149 | ) 150 | 151 | sampler = None 152 | if self.cfg.get("train.use_weighted_sampler"): 153 | classes_count = dataset.get_classes_count() 154 | if len(classes_count) == 0: 155 | warnings.warn( 156 | "WeightedRandomSampler only available for classification tasks. Using default sampler instead." 157 | ) 158 | else: 159 | weights = [1 / i for i in classes_count.values()] 160 | num_samples = sum(classes_count.values()) 161 | sampler = torch.utils.data.WeightedRandomSampler( 162 | weights, num_samples 163 | ) 164 | 165 | pytorch_loader_train = torch.utils.data.DataLoader( 166 | loader_train, 167 | batch_size=self.cfg.get("train.batch_size"), 168 | num_workers=self.cfg.get("train.num_workers"), 169 | collate_fn=loader_train.collate_fn, 170 | drop_last=self.cfg.get("train.skip_last_batch"), 171 | sampler=sampler, 172 | ) 173 | 174 | loader_val = LuxonisLoader( 175 | dataset, 176 | view=self.cfg.get("dataset.val_view"), 177 | augmentations=ValAugmentations( 178 | image_size=self.cfg.get("train.preprocessing.train_image_size"), 179 | augmentations=self.cfg.get("train.preprocessing.augmentations"), 180 | train_rgb=self.cfg.get("train.preprocessing.train_rgb"), 181 | keep_aspect_ratio=self.cfg.get( 182 | "train.preprocessing.keep_aspect_ratio" 183 | ), 184 | ), 185 | ) 186 | pytorch_loader_val = torch.utils.data.DataLoader( 187 | loader_val, 188 | batch_size=self.cfg.get("train.batch_size"), 189 | num_workers=self.cfg.get("train.num_workers"), 190 | collate_fn=loader_val.collate_fn, 191 | ) 192 | 193 | pl_trainer.fit(lightning_module, pytorch_loader_train, pytorch_loader_val) 194 | pruner_callback.check_pruned() 195 | 196 | return pl_trainer.callback_metrics["val_loss/loss"].item() 197 | 198 | def _get_trial_params(self, trial: optuna.trial.Trial): 199 | """Get trial params based on specified config""" 200 | cfg_tuner = self.cfg.get("tuner.params") 201 | new_params = {} 202 | for key, value in cfg_tuner.items(): 203 | key_info = key.split("_") 204 | key_name = "_".join(key_info[:-1]) 205 | key_type = key_info[-1] 206 | 207 | if key_type == "categorical": 208 | # NOTE: might need to do some preprocessing if list doesn't only have strings 209 | new_value = trial.suggest_categorical(key_name, value) 210 | elif key_type == "float": 211 | new_value = trial.suggest_float(key_name, *value) 212 | elif key_type == "int": 213 | new_value = trial.suggest_int(key_name, *value) 214 | elif key_type == "loguniform": 215 | new_value = trial.suggest_loguniform(key_name, *value) 216 | elif key_type == "uniform": 217 | new_value = trial.suggest_uniform(key_name, *value) 218 | else: 219 | raise KeyError(f"Tunning type '{key_type}' not supported.") 220 | 221 | new_params[key_name] = new_value 222 | return new_params 223 | -------------------------------------------------------------------------------- /src/luxonis_train/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luxonis/models/bc1cf4b1ea3bdaf918d0f0af13e21ec64bae9e41/src/luxonis_train/datasets/__init__.py -------------------------------------------------------------------------------- /src/luxonis_train/models/README.md: -------------------------------------------------------------------------------- 1 | ## List of supported predefined models 2 | - YoloV6 ([source](https://github.com/meituan/YOLOv6/tree/725913050e15a31cd091dfd7795a1891b0524d35)) (YoloV6-n, YoloV6-t and YoloV6-s) 3 | - Params: 4 | - n_classes: int # used 
in YoloV6 head -------------------------------------------------------------------------------- /src/luxonis_train/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Model 2 | from .model_lightning_module import ModelLightningModule 3 | 4 | __all__ = [ 5 | "Model", 6 | "ModelLightningModule" 7 | ] 8 | -------------------------------------------------------------------------------- /src/luxonis_train/models/backbones/README.md: -------------------------------------------------------------------------------- 1 | ## List of supported backbones 2 | - ResNet18 ([source](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html)) 3 | - Params: 4 | - download_weights: bool # If True, download weights pretrained on ImageNet. Defaults to False. 5 | - MicroNet ([source](https://github.com/liyunsheng13/micronet)) 6 | - Params: 7 | - variant: str # Variant from ['M1', 'M2', 'M3']. Defaults to 'M1'. 8 | - RepVGG ([source](https://github.com/DingXiaoH/RepVGG)) 9 | - Params: 10 | - variant: str # Variant from ['A0', 'A1']. Defaults to "A0". 11 | - EfficientRep (adapted from [here](https://github.com/meituan/YOLOv6/blob/725913050e15a31cd091dfd7795a1891b0524d35/yolov6/models/efficientrep.py)) 12 | - Params: 13 | - channels_list: List[int] # List of number of channels for each block 14 | - num_repeats: List[int] # List of number of repeats of RepBlock 15 | - in_channels: int # Number of input channels, should be 3 in most cases. Defaults to 3. 16 | - depth_mul: float # Depth multiplier. Defaults to 0.33. 17 | - width_mul: float # Width multiplier. Defaults to 0.25. 18 | - is_4head: bool # Whether to build a 4-headed architecture instead of a 3-headed one (**Important: Should be the same also on neck and head**). Defaults to False. 19 | - ReXNetV1_lite ([source](https://github.com/clovaai/rexnet)) 20 | - Params: 21 | - fix_head_stem: bool # Whether to multiply head stem. Defaults to False. 22 | - divisible_value: int # Divisor used. Defaults to 8. 23 | - input_ch: int # Starting channel dimension. Defaults to 16. 24 | - final_ch: int # Final channel dimension. Defaults to 164. 25 | - multiplier: float # Channel dimension multiplier. Defaults to 1.0. 26 | - kernel_conf: str # Kernel sizes encoded as string. Defaults to '333333'. 27 | - MobileOne ([source](https://github.com/apple/ml-mobileone)) 28 | - Params: 29 | - variant: str # Variant from ['s0', 's1', 's2', 's3', 's4']. Defaults to "s0". 30 | - MobileNetV2 ([source](https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html)) 31 | - Params: 32 | - download_weights: bool # If True, download weights pretrained on ImageNet. Defaults to False. 33 | - EfficientNet ([source](https://github.com/rwightman/gen-efficientnet-pytorch)) 34 | - Params: 35 | - download_weights: bool # If True, download weights pretrained on ImageNet. Defaults to False. 36 | - ContextSpatial (adapted from [here](https://github.com/taveraantonio/BiseNetv1)) 37 | - Params: 38 | - context_backbone: str # Backbone used. Defaults to 'MobileNetV2'. 39 | - in_channels: int # Number of input channels, should be 3 in most cases. Defaults to 3.
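A backbone is normally selected in the training config by name, with the parameters listed above passed through as keyword arguments. The snippet below is only an illustrative sketch: the key names (`model`, `backbone`, `name`, `params`) are assumptions and should be checked against the example configs in `configs/` and the JSON schemas in `configs/schemas/` before use.

```yaml
# Illustrative sketch only -- key names are assumed, not taken from the actual schema.
model:
  backbone:
    name: EfficientRep
    params:
      channels_list: [64, 128, 256, 512, 1024]
      num_repeats: [1, 6, 12, 18, 6]
      depth_mul: 0.33
      width_mul: 0.25
      is_4head: False   # must match the is_4head setting of the neck and head
```

The parameter values shown here mirror the YoloV6-n defaults used in the `__main__` test block of `efficient_rep.py`.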
40 | 41 | 42 | - TODO: add DeepLabV3+ ([source](https://github.com/VainF/DeepLabV3Plus-Pytorch)) -------------------------------------------------------------------------------- /src/luxonis_train/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .micronet import MicroNet 2 | from .resnet18 import ResNet18 3 | from .efficient_rep import EfficientRep 4 | from .rep_vgg import RepVGG 5 | from .rexnetv1 import ReXNetV1_lite 6 | from .mobileone import MobileOne 7 | from .mobilenetv2 import MobileNetV2 8 | from .efficientnet import EfficientNet 9 | from .context_spatial import ContextSpatial 10 | 11 | __all__ = [ 12 | "MicroNet", 13 | "ResNet18", 14 | "EfficientRep", 15 | "RepVGG", 16 | "ReXNetV1_lite", 17 | "MobileOne", 18 | "MobileNetV2", 19 | "EfficientNet", 20 | "ContextSpatial" 21 | ] 22 | -------------------------------------------------------------------------------- /src/luxonis_train/models/backbones/context_spatial.py: -------------------------------------------------------------------------------- 1 | # 2 | # Source: https://github.com/taveraantonio/BiseNetv1 3 | # 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | 9 | from luxonis_train.models.backbones import * 10 | from luxonis_train.models.modules import ConvModule 11 | 12 | class SpatialPath(nn.Module): 13 | def __init__(self, c1, c2) -> None: 14 | super().__init__() 15 | ch = 64 16 | self.conv_7x7 = ConvModule(c1, ch, 7, 2, 3) 17 | self.conv_3x3_1 = ConvModule(ch, ch, 3, 2, 1) 18 | self.conv_3x3_2 = ConvModule(ch, ch, 3, 2, 1) 19 | self.conv_1x1 = ConvModule(ch, c2, 1, 1, 0) 20 | 21 | def forward(self, x): 22 | x = self.conv_7x7(x) 23 | x = self.conv_3x3_1(x) 24 | x = self.conv_3x3_2(x) 25 | return self.conv_1x1(x) 26 | 27 | 28 | class ContextPath(nn.Module): 29 | def __init__(self, backbone: nn.Module) -> None: 30 | super().__init__() 31 | self.backbone = backbone 32 | c3, c4 = self.backbone.channels[-2:] 33 | 34 | self.arm16 = AttentionRefinmentModule(c3, 128) 35 | self.arm32 = AttentionRefinmentModule(c4, 128) 36 | 37 | self.global_context = nn.Sequential( 38 | nn.AdaptiveAvgPool2d(1), 39 | ConvModule(c4, 128, 1, 1, 0) 40 | ) 41 | 42 | self.up16 = nn.Upsample(scale_factor=2.0, mode="bilinear", align_corners=True) 43 | self.up32 = nn.Upsample(scale_factor=2.0, mode="bilinear", align_corners=True) 44 | 45 | self.refine16 = ConvModule(128, 128, 3, 1, 1) 46 | self.refine32 = ConvModule(128, 128, 3, 1, 1) 47 | 48 | 49 | def forward(self, x): 50 | _, _, down16, down32 = self.backbone(x) # 4x256x64x128, 4x512x32x64 51 | 52 | arm_down16 = self.arm16(down16) # 4x128x64x128 53 | arm_down32 = self.arm32(down32) # 4x128x32x64 54 | 55 | global_down32 = self.global_context(down32) # 4x128x1x1 56 | global_down32 = F.interpolate(global_down32, size=down32.size()[2:], mode="bilinear", align_corners=True) # 4x128x32x64 57 | 58 | arm_down32 = arm_down32 + global_down32 # 4x128x32x64 59 | arm_down32 = self.up32(arm_down32) # 4x128x64x128 60 | arm_down32 = self.refine32(arm_down32) # 4x128x64x128 61 | 62 | arm_down16 = arm_down16 + arm_down32 # 4x128x64x128 63 | arm_down16 = self.up16(arm_down16) # 4x128x128x256 64 | arm_down16 = self.refine16(arm_down16) # 4x128x128x256 65 | 66 | return arm_down16, arm_down32 67 | 68 | class AttentionRefinmentModule(nn.Module): 69 | def __init__(self, c1, c2) -> None: 70 | super().__init__() 71 | self.conv_3x3 = ConvModule(c1, c2, 3, 1, 1) 72 | 73 | self.attention = nn.Sequential( 74 | 
nn.AdaptiveAvgPool2d(1), 75 | nn.Conv2d(c2, c2, 1, bias=False), 76 | nn.BatchNorm2d(c2), 77 | nn.Sigmoid() 78 | ) 79 | 80 | def forward(self, x): 81 | fm = self.conv_3x3(x) 82 | fm_se = self.attention(fm) 83 | return fm * fm_se 84 | 85 | class FeatureFusionModule(nn.Module): 86 | def __init__(self, c1, c2, reduction=1) -> None: 87 | super().__init__() 88 | self.conv_1x1 = ConvModule(c1, c2, 1, 1, 0) 89 | 90 | self.attention = nn.Sequential( 91 | nn.AdaptiveAvgPool2d(1), 92 | nn.Conv2d(c2, c2 // reduction, 1, bias=False), 93 | nn.ReLU(True), 94 | nn.Conv2d(c2 // reduction, c2, 1, bias=False), 95 | nn.Sigmoid() 96 | ) 97 | 98 | def forward(self, x1, x2): 99 | fm = torch.cat([x1, x2], dim=1) 100 | fm = self.conv_1x1(fm) 101 | fm_se = self.attention(fm) 102 | return fm + fm * fm_se 103 | 104 | class ContextSpatial(nn.Module): 105 | def __init__(self, context_backbone: str = 'MobileNetV2', in_channels: int = 3): 106 | """Context spatial backbone 107 | 108 | Args: 109 | context_backbone (str, optional): Backbone used. Defaults to 'MobileNetV2'. 110 | in_channels (int, optional): Number of input channels, should be 3 in most cases. Defaults to 3. 111 | """ 112 | super().__init__() 113 | self.context_path = ContextPath(eval(context_backbone)()) 114 | self.spatial_path = SpatialPath(3, 128) 115 | self.ffm = FeatureFusionModule(256, 256) 116 | 117 | def forward(self, x): 118 | spatial_out = self.spatial_path(x) 119 | context16, context32 = self.context_path(x) 120 | fm_fuse = self.ffm(spatial_out, context16) 121 | outs = [fm_fuse] 122 | return outs 123 | 124 | if __name__ == "__main__": 125 | 126 | model = ContextSpatial() 127 | model.eval() 128 | 129 | shapes = [224, 256, 384, 512] 130 | for shape in shapes: 131 | print("\nShape", shape) 132 | x = torch.zeros(1, 3, shape, shape) 133 | outs = model(x) 134 | for out in outs: 135 | print(out.shape) 136 | -------------------------------------------------------------------------------- /src/luxonis_train/models/backbones/efficient_rep.py: -------------------------------------------------------------------------------- 1 | # 2 | # Adapted from: https://github.com/meituan/YOLOv6/blob/725913050e15a31cd091dfd7795a1891b0524d35/yolov6/models/efficientrep.py 3 | # License: https://github.com/meituan/YOLOv6/blob/main/LICENSE 4 | # 5 | 6 | import torch.nn as nn 7 | import torch 8 | 9 | from luxonis_train.models.modules import RepVGGBlock, RepBlock, SimplifiedSPPF 10 | from luxonis_train.utils.general import make_divisible 11 | 12 | class EfficientRep(nn.Module): 13 | def __init__(self, channels_list: list, num_repeats: list, in_channels: int = 3, depth_mul: float = 0.33, 14 | width_mul: float = 0.25, is_4head: bool = False): 15 | """EfficientRep backbone, normally used with YoloV6 model. 16 | 17 | Args: 18 | channels_list (list): List of number of channels for each block 19 | num_repeats (list): List of number of repeats of RepBlock 20 | in_channels (int, optional): Number of input channels, should be 3 in most cases . Defaults to 3. 21 | depth_mul (float, optional): Depth multiplier. Defaults to 0.33. 22 | width_mul (float, optional): Width multiplier. Defaults to 0.25. 23 | is_4head (bool, optional): Either build 4 headed architecture or 3 headed one \ 24 | (**Important: Should be same also on neck and head**). Defaults to False. 
25 | """ 26 | super().__init__() 27 | 28 | channels_list = [make_divisible(i * width_mul, 8) for i in channels_list] 29 | num_repeats = [(max(round(i * depth_mul), 1) if i > 1 else i) for i in num_repeats] 30 | 31 | self.is_4head = is_4head 32 | 33 | self.stem = RepVGGBlock( 34 | in_channels=in_channels, 35 | out_channels=channels_list[0], 36 | kernel_size=3, 37 | stride=2 38 | ) 39 | 40 | self.blocks = nn.ModuleList() 41 | for i in range(4): 42 | curr_block = nn.Sequential( 43 | RepVGGBlock( 44 | in_channels=channels_list[i], 45 | out_channels=channels_list[i+1], 46 | kernel_size=3, 47 | stride=2 48 | ), 49 | RepBlock( 50 | in_channels=channels_list[i+1], 51 | out_channels=channels_list[i+1], 52 | n=num_repeats[i+1], 53 | ) 54 | ) 55 | if i == 3: 56 | curr_block.append( 57 | SimplifiedSPPF( 58 | in_channels=channels_list[i+1], 59 | out_channels=channels_list[i+1], 60 | kernel_size=5 61 | ) 62 | ) 63 | 64 | self.blocks.append(curr_block) 65 | 66 | def forward(self, x): 67 | outputs = [] 68 | x = self.stem(x) 69 | start_idx = 0 if self.is_4head else 1 # idx at which we start saving outputs 70 | for i, block in enumerate(self.blocks): 71 | x = block(x) 72 | if i >= start_idx: 73 | outputs.append(x) 74 | 75 | return outputs 76 | 77 | 78 | if __name__ == "__main__": 79 | num_repeats = [1, 6, 12, 18, 6] 80 | depth_mul = 0.33 81 | 82 | channels_list =[64, 128, 256, 512, 1024] 83 | width_mul = 0.25 84 | 85 | model = EfficientRep(in_channels=3, channels_list=channels_list, num_repeats=num_repeats, 86 | depth_mul=depth_mul, width_mul=width_mul, is_4head=False) 87 | model.eval() 88 | 89 | shapes = [224, 256, 384, 512] 90 | for shape in shapes: 91 | print("\n\nShape", shape) 92 | x = torch.zeros(1, 3, shape, shape) 93 | outs = model(x) 94 | for out in outs: 95 | print(out.shape) 96 | -------------------------------------------------------------------------------- /src/luxonis_train/models/backbones/efficientnet.py: -------------------------------------------------------------------------------- 1 | # 2 | # Soure: https://github.com/rwightman/gen-efficientnet-pytorch 3 | # License: https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/LICENSE 4 | # 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class EfficientNet(nn.Module): 12 | def __init__(self, download_weights: bool = False): 13 | """EfficientNet backbone 14 | 15 | Args: 16 | download_weights (bool, optional): If True download weights from imagenet. Defaults to False. 
17 | """ 18 | super().__init__() 19 | efficientnet_lite0_model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_lite0', pretrained=download_weights) 20 | self.out_indices = [1, 2, 4, 6] 21 | self.backbone = efficientnet_lite0_model 22 | 23 | def forward(self, X): 24 | outs = [] 25 | X = self.backbone.conv_stem(X) 26 | X = self.backbone.bn1(X) 27 | X = self.backbone.act1(X) 28 | for i, m in enumerate(self.backbone.blocks): 29 | X = m(X) 30 | if i in self.out_indices: 31 | outs.append(X) 32 | return outs 33 | 34 | 35 | if __name__ == '__main__': 36 | 37 | model = EfficientNet() 38 | model.eval() 39 | 40 | shapes = [224, 256, 384, 512] 41 | 42 | for shape in shapes: 43 | print("\n\nShape", shape) 44 | x = torch.zeros(1, 3, shape, shape) 45 | outs = model(x) 46 | if isinstance(outs, list): 47 | for out in outs: 48 | print(out.shape) 49 | else: 50 | print(outs.shape) -------------------------------------------------------------------------------- /src/luxonis_train/models/backbones/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | # 2 | # Soure: https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html 3 | # License: https://github.com/pytorch/pytorch/blob/master/LICENSE 4 | # 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torchvision 10 | 11 | 12 | class MobileNetV2(nn.Module): 13 | 14 | def __init__(self, download_weights: bool = False): 15 | """MobileNetV2 backbone 16 | 17 | Args: 18 | download_weights (bool, optional): If True download weights from imagenet. Defaults to False. 19 | """ 20 | super().__init__() 21 | mobilenet_v2 = torchvision.models.mobilenet_v2(weights="DEFAULT" if download_weights else None) 22 | self.out_indices = [3, 6, 13, 17] 23 | self.channels = [24, 32, 96, 320] 24 | self.backbone = mobilenet_v2 25 | 26 | def forward(self, X): 27 | outs = [] 28 | for i, m in enumerate(self.backbone.features): 29 | X = m(X) 30 | if i in self.out_indices: 31 | outs.append(X) 32 | return outs 33 | 34 | 35 | if __name__ == '__main__': 36 | 37 | model = MobileNetV2(download_weights=True) 38 | model.eval() 39 | shapes = [224, 256, 384, 512] 40 | for shape in shapes: 41 | print("\nShape", shape) 42 | x = torch.zeros(1, 3, shape, shape) 43 | outs = model(x) 44 | for out in outs: 45 | print(out.shape) 46 | -------------------------------------------------------------------------------- /src/luxonis_train/models/backbones/rep_vgg.py: -------------------------------------------------------------------------------- 1 | # 2 | # Soure: https://github.com/DingXiaoH/RepVGG 3 | # License: https://github.com/DingXiaoH/RepVGG/blob/main/LICENSE 4 | # 5 | 6 | 7 | import torch.nn as nn 8 | import torch 9 | import torch.utils.checkpoint as checkpoint 10 | 11 | from luxonis_train.models.modules import RepVGGBlock 12 | 13 | class RepVGG_(nn.Module): 14 | 15 | def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False, use_checkpoint=False): 16 | super(RepVGG_, self).__init__() 17 | assert len(width_multiplier) == 4 18 | self.deploy = deploy 19 | self.override_groups_map = override_groups_map or dict() 20 | assert 0 not in self.override_groups_map 21 | self.use_se = use_se 22 | self.use_checkpoint = use_checkpoint 23 | 24 | self.in_planes = min(64, int(64 * width_multiplier[0])) 25 | self.stage0 = RepVGGBlock(in_channels=3, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1, deploy=self.deploy, use_se=self.use_se) 26 | 
self.cur_layer_idx = 1 27 | self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=2) 28 | self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2) 29 | self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2) 30 | self.stage4 = self._make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride=2) 31 | self.gap = nn.AdaptiveAvgPool2d(output_size=1) 32 | self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes) 33 | 34 | def _make_stage(self, planes, num_blocks, stride): 35 | strides = [stride] + [1]*(num_blocks-1) 36 | blocks = [] 37 | for stride in strides: 38 | cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1) 39 | blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3, 40 | stride=stride, padding=1, groups=cur_groups, deploy=self.deploy, use_se=self.use_se)) 41 | self.in_planes = planes 42 | self.cur_layer_idx += 1 43 | return nn.ModuleList(blocks) 44 | 45 | def forward(self, x): 46 | outputs = [] 47 | out = self.stage0(x) 48 | for stage in (self.stage1, self.stage2, self.stage3, self.stage4): 49 | for block in stage: 50 | if self.use_checkpoint: 51 | out = checkpoint.checkpoint(block, out) 52 | else: 53 | out = block(out) 54 | outputs.append(out) 55 | return outputs 56 | 57 | 58 | optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26] 59 | g2_map = {l: 2 for l in optional_groupwise_layers} 60 | g4_map = {l: 4 for l in optional_groupwise_layers} 61 | 62 | def create_RepVGG_A0(deploy=False, use_checkpoint=False): 63 | return RepVGG_(num_blocks=[2, 4, 14, 1], num_classes=1000, 64 | width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint) 65 | 66 | def create_RepVGG_A1(deploy=False, use_checkpoint=False): 67 | return RepVGG_(num_blocks=[2, 4, 14, 1], num_classes=1000, 68 | width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint) 69 | 70 | def create_RepVGG_A2(deploy=False, use_checkpoint=False): 71 | return RepVGG_(num_blocks=[2, 4, 14, 1], num_classes=1000, 72 | width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint) 73 | 74 | def create_RepVGG_B0(deploy=False, use_checkpoint=False): 75 | return RepVGG_(num_blocks=[4, 6, 16, 1], num_classes=1000, 76 | width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint) 77 | 78 | def create_RepVGG_B1(deploy=False, use_checkpoint=False): 79 | return RepVGG_(num_blocks=[4, 6, 16, 1], num_classes=1000, 80 | width_multiplier=[2, 2, 2, 4], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint) 81 | 82 | def create_RepVGG_B1g2(deploy=False, use_checkpoint=False): 83 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000, 84 | width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint) 85 | 86 | def create_RepVGG_B1g4(deploy=False, use_checkpoint=False): 87 | return RepVGG_(num_blocks=[4, 6, 16, 1], num_classes=1000, 88 | width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint) 89 | 90 | 91 | def create_RepVGG_B2(deploy=False, use_checkpoint=False): 92 | return RepVGG_(num_blocks=[4, 6, 16, 1], num_classes=1000, 93 | width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint) 94 | 95 | def 
create_RepVGG_B2g2(deploy=False, use_checkpoint=False): 96 | return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000, 97 | width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint) 98 | 99 | def create_RepVGG_B2g4(deploy=False, use_checkpoint=False): 100 | return RepVGG_(num_blocks=[4, 6, 16, 1], num_classes=1000, 101 | width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint) 102 | 103 | 104 | def create_RepVGG_B3(deploy=False, use_checkpoint=False): 105 | return RepVGG_(num_blocks=[4, 6, 16, 1], num_classes=1000, 106 | width_multiplier=[3, 3, 3, 5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint) 107 | 108 | def create_RepVGG_B3g2(deploy=False, use_checkpoint=False): 109 | return RepVGG_(num_blocks=[4, 6, 16, 1], num_classes=1000, 110 | width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint) 111 | 112 | def create_RepVGG_B3g4(deploy=False, use_checkpoint=False): 113 | return RepVGG_(num_blocks=[4, 6, 16, 1], num_classes=1000, 114 | width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint) 115 | 116 | def create_RepVGG_D2se(deploy=False, use_checkpoint=False): 117 | return RepVGG_(num_blocks=[8, 14, 24, 1], num_classes=1000, 118 | width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_se=True, use_checkpoint=use_checkpoint) 119 | 120 | 121 | func_dict = { 122 | 'RepVGG-A0': create_RepVGG_A0, 123 | 'RepVGG-A1': create_RepVGG_A1, 124 | 'RepVGG-A2': create_RepVGG_A2, 125 | 'RepVGG-B0': create_RepVGG_B0, 126 | 'RepVGG-B1': create_RepVGG_B1, 127 | 'RepVGG-B1g2': create_RepVGG_B1g2, 128 | 'RepVGG-B1g4': create_RepVGG_B1g4, 129 | 'RepVGG-B2': create_RepVGG_B2, 130 | 'RepVGG-B2g2': create_RepVGG_B2g2, 131 | 'RepVGG-B2g4': create_RepVGG_B2g4, 132 | 'RepVGG-B3': create_RepVGG_B3, 133 | 'RepVGG-B3g2': create_RepVGG_B3g2, 134 | 'RepVGG-B3g4': create_RepVGG_B3g4, 135 | 'RepVGG-D2se': create_RepVGG_D2se, # Updated at April 25, 2021. This is not reported in the CVPR paper. 136 | } 137 | def get_RepVGG_func_by_name(name): 138 | return func_dict[name] 139 | 140 | 141 | 142 | # Use this for converting a RepVGG model or a bigger model with RepVGG as its component 143 | # Use like this 144 | # model = create_RepVGG_A0(deploy=False) 145 | # train model or load weights 146 | # repvgg_model_convert(model, save_path='repvgg_deploy.pth') 147 | # If you want to preserve the original model, call with do_copy=True 148 | 149 | # ====================== for using RepVGG as the backbone of a bigger model, e.g., PSPNet, the pseudo code will be like 150 | # train_backbone = create_RepVGG_B2(deploy=False) 151 | # train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth')) 152 | # train_pspnet = build_pspnet(backbone=train_backbone) 153 | # segmentation_train(train_pspnet) 154 | # deploy_pspnet = repvgg_model_convert(train_pspnet) 155 | # segmentation_test(deploy_pspnet) 156 | # ===================== example_pspnet.py shows an example 157 | 158 | 159 | class RepVGG(nn.Module): 160 | def __init__(self, variant: str = "A0"): 161 | """RepVGG baackbone 162 | 163 | Args: 164 | variant (str, optional): Variant from ['A0', 'A1']. Defaults to "A0". 
165 | """ 166 | super().__init__() 167 | assert variant in ["A0", "A1"] 168 | 169 | model_create = create_RepVGG_A0 if variant == "A0" else create_RepVGG_A1 170 | self.model = model_create(deploy=False) 171 | 172 | def forward(self, X): 173 | features = self.model(X) 174 | return features 175 | 176 | 177 | if __name__ == "__main__": 178 | 179 | for variant in ["A0", "A1"]: 180 | model = RepVGG(variant=variant) 181 | model.eval() 182 | print("Variant:", variant) 183 | shapes = [224, 256, 384, 512] 184 | for shape in shapes: 185 | print("\n\nShape", shape) 186 | x = torch.zeros(1, 3, shape, shape) 187 | outs = model(x) 188 | for out in outs: 189 | print(out.shape) -------------------------------------------------------------------------------- /src/luxonis_train/models/backbones/resnet18.py: -------------------------------------------------------------------------------- 1 | # 2 | # Soure: https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html 3 | # License: https://github.com/pytorch/pytorch/blob/master/LICENSE 4 | # 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torchvision 10 | 11 | 12 | class ResNet18(nn.Module): 13 | 14 | def __init__(self, download_weights: bool = False): 15 | """ResNet18 backbone 16 | 17 | Args: 18 | download_weights (bool, optional): If True download weights from imagenet. Defaults to False. 19 | """ 20 | super().__init__() 21 | 22 | resnet18 = torchvision.models.resnet18(weights="DEFAULT" if download_weights else None) 23 | self.channels = [64, 128, 256, 512] 24 | self.backbone = resnet18 25 | 26 | def forward(self, X): 27 | outs = [] 28 | X = self.backbone.conv1(X) 29 | X = self.backbone.bn1(X) 30 | X = self.backbone.relu(X) 31 | X = self.backbone.maxpool(X) 32 | 33 | X = self.backbone.layer1(X) 34 | outs.append(X) 35 | X = self.backbone.layer2(X) 36 | outs.append(X) 37 | X = self.backbone.layer3(X) 38 | outs.append(X) 39 | X = self.backbone.layer4(X) 40 | outs.append(X) 41 | 42 | return outs 43 | 44 | if __name__ == '__main__': 45 | 46 | model = ResNet18() 47 | model.eval() 48 | 49 | shapes = [224, 256, 384, 512] 50 | 51 | for shape in shapes: 52 | print("\nShape", shape) 53 | x = torch.zeros(1, 3, shape, shape) 54 | outs = model(x) 55 | for out in outs: 56 | print(out.shape) -------------------------------------------------------------------------------- /src/luxonis_train/models/backbones/rexnetv1.py: -------------------------------------------------------------------------------- 1 | # 2 | # Soure: https://github.com/clovaai/rexnet 3 | # License: https://github.com/clovaai/rexnet/blob/master/LICENSE 4 | # 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from luxonis_train.utils.general import make_divisible 11 | from luxonis_train.models.modules import ConvModule 12 | 13 | 14 | class LinearBottleneck(nn.Module): 15 | def __init__(self, in_channels, channels, t, kernel_size=3, stride=1, **kwargs): 16 | super(LinearBottleneck, self).__init__(**kwargs) 17 | self.conv_shortcut = None 18 | self.use_shortcut = stride == 1 and in_channels <= channels 19 | self.in_channels = in_channels 20 | self.out_channels = channels 21 | out = [] 22 | if t != 1: 23 | dw_channels = in_channels * t 24 | out.append( 25 | ConvModule(in_channels=in_channels, out_channels=dw_channels, kernel_size=1, activation=nn.ReLU6(inplace=True)) 26 | ) 27 | else: 28 | dw_channels = in_channels 29 | out.append( 30 | ConvModule(in_channels=dw_channels, out_channels=dw_channels*1, kernel_size=kernel_size, stride=stride, 31 | padding=(kernel_size//2), 
groups=dw_channels, activation=nn.ReLU6(inplace=True) 32 | ) 33 | ) 34 | out.append( 35 | ConvModule(in_channels=dw_channels, out_channels=channels, kernel_size=1, activation=nn.Identity()) 36 | ) 37 | 38 | self.out = nn.Sequential(*out) 39 | 40 | def forward(self, x): 41 | out = self.out(x) 42 | 43 | if self.use_shortcut: 44 | # this results in a ScatterND node which isn't supported yet in myriad 45 | # out[:, 0:self.in_channels] += x 46 | 47 | a = out[:, :self.in_channels] 48 | b = x 49 | a = a + b 50 | c = out[:, self.in_channels:] 51 | d = torch.concat([a, c], dim=1) 52 | return d 53 | return out 54 | 55 | 56 | class ReXNetV1_lite(nn.Module): 57 | def __init__(self, fix_head_stem: bool = False, divisible_value: int = 8, input_ch: int = 16, 58 | final_ch: int = 164, multiplier: float = 1.0, kernel_conf: str = '333333'): 59 | """ReXNetV1_lite backbone 60 | 61 | Args: 62 | fix_head_stem (bool, optional): Weather to multiply head stem. Defaults to False. 63 | divisible_value (int, optional): Divisor used. Defaults to 8. 64 | input_ch (int, optional): Starting channel dimension. Defaults to 16. 65 | final_ch (int, optional): Final channel dimension. Defaults to 164. 66 | multiplier (float, optional): Channel dimension multiplier. Defaults to 1.0. 67 | kernel_conf (str, optional): Kernel sizes encoded as string. Defaults to '333333'. 68 | """ 69 | super().__init__() 70 | 71 | self.out_indices = [1,4, 10, 16] 72 | self.channels = [16, 48, 112, 184] 73 | layers = [1, 2, 2, 3, 3, 5] 74 | strides = [1, 2, 2, 2, 1, 2] 75 | kernel_sizes = [int(element) for element in kernel_conf] 76 | 77 | strides = sum([[element] + [1] * (layers[idx] - 1) for idx, element in enumerate(strides)], []) 78 | ts = [1] * layers[0] + [6] * sum(layers[1:]) 79 | kernel_sizes = sum([[element] * layers[idx] for idx, element in enumerate(kernel_sizes)], []) 80 | self.num_convblocks = sum(layers[:]) 81 | 82 | features = [] 83 | inplanes = input_ch / multiplier if multiplier < 1.0 else input_ch 84 | first_channel = 32 / multiplier if multiplier < 1.0 or fix_head_stem else 32 85 | first_channel = make_divisible(int(round(first_channel * multiplier)), divisible_value) 86 | 87 | in_channels_group = [] 88 | channels_group = [] 89 | 90 | features.append( 91 | ConvModule(3, first_channel, kernel_size=3, stride=2, padding=1, activation=nn.ReLU6(inplace=True)) 92 | ) 93 | 94 | for i in range(self.num_convblocks): 95 | inplanes_divisible = make_divisible(int(round(inplanes * multiplier)), divisible_value) 96 | if i == 0: 97 | in_channels_group.append(first_channel) 98 | channels_group.append(inplanes_divisible) 99 | else: 100 | in_channels_group.append(inplanes_divisible) 101 | inplanes += final_ch / (self.num_convblocks - 1 * 1.0) 102 | inplanes_divisible = make_divisible(int(round(inplanes * multiplier)), divisible_value) 103 | channels_group.append(inplanes_divisible) 104 | 105 | for block_idx, (in_c, c, t, k, s) in enumerate( 106 | zip(in_channels_group, channels_group, ts, kernel_sizes, strides)): 107 | features.append(LinearBottleneck(in_channels=in_c, 108 | channels=c, 109 | t=t, 110 | kernel_size=k, 111 | stride=s )) 112 | 113 | pen_channels = int(1280 * multiplier) if multiplier > 1 and not fix_head_stem else 1280 114 | features.append( 115 | ConvModule(in_channels=c, out_channels=pen_channels, kernel_size=1, activation=nn.ReLU6(inplace=True)) 116 | ) 117 | self.features = nn.Sequential(*features) 118 | 119 | def forward(self, x): 120 | outs = [] 121 | for i, m in enumerate(self.features): 122 | x = m(x) 123 | if i in 
self.out_indices: 124 | outs.append(x) 125 | return outs 126 | 127 | 128 | if __name__ == '__main__': 129 | model = ReXNetV1_lite(multiplier=1.0) 130 | model.eval() 131 | 132 | shapes = [224, 256, 384, 512] 133 | 134 | for shape in shapes: 135 | print("\nShape", shape) 136 | x = torch.zeros(1, 3, shape, shape) 137 | outs = model(x) 138 | if isinstance(outs, list): 139 | for out in outs: 140 | print(out.shape) 141 | else: 142 | print(outs.shape) -------------------------------------------------------------------------------- /src/luxonis_train/models/heads/README.md: -------------------------------------------------------------------------------- 1 | ## List of supported heads 2 | Every head takes these parameters: 3 | - n_classes: int # Number of classes to predict 4 | - attach_index: int # Index of the backbone/neck output that the head attaches to. Defaults to -1. 5 | 6 | Here is a list of all supported heads and any additional parameters they take: 7 | - ClassificationHead 8 | - Params: 9 | - fc_dropout: float # Dropout rate before last layer, range [0,1]. Defaults to 0.2. 10 | - MultiLabelClassificationHead 11 | - Params: 12 | - fc_dropout: float # Dropout rate before last layer, range [0,1]. Defaults to 0.2. 13 | - SegmentationHead (adapted from [here](https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/fcn.py)) 14 | - BiSeNetHead (adapted from [here](https://github.com/taveraantonio/BiseNetv1)) 15 | - Params: 16 | - c1: int # Number of input channels. Defaults to 256. 17 | - upscale_factor: int # Factor used for upscaling input. Defaults to 8. 18 | - is_aux: bool # Whether to use 256 intermediate channels instead of 64. Defaults to False. 19 | - EffiDeHead (adapted from [here](https://github.com/meituan/YOLOv6/blob/725913050e15a31cd091dfd7795a1891b0524d35/yolov6/models/effidehead.py)) 20 | - Params: 21 | - n_anchors: int # Should stay default. Defaults to 1. 22 | - YoloV6Head (adapted from [here](https://github.com/meituan/YOLOv6/blob/725913050e15a31cd091dfd7795a1891b0524d35/yolov6/models/effidehead.py)) 23 | - Params: 24 | - is_4head: bool # Whether to build a 4-headed architecture instead of a 3-headed one (**Important: Should be the same also on backbone and neck**). Defaults to False. 25 | - IKeypoint (adapted from [here](https://github.com/WongKinYiu/yolov7)) 26 | - Params: 27 | - n_keypoints: int # Number of keypoints 28 | - anchors: list # Anchors used for object detection 29 | - connectivity: list # Connectivity mapping used in visualization. Defaults to None.
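Head parameters are passed the same way, as keyword arguments resolved from the config. The snippet below is only an illustrative sketch: the key names (`model`, `heads`, `name`, `params`) are assumptions and should be checked against the example configs in `configs/` and the JSON schemas in `configs/schemas/`.

```yaml
# Illustrative sketch only -- key names are assumed, not taken from the actual schema.
model:
  backbone:
    name: ContextSpatial
  heads:
    - name: BiSeNetHead
      params:
        n_classes: 2        # required by every head
        attach_index: -1    # which backbone/neck output to attach to
        upscale_factor: 8
        is_aux: False
```

Each entry under `params` maps one-to-one to a keyword argument of the corresponding head's `__init__` (e.g. `BiSeNetHead(n_classes=..., upscale_factor=..., is_aux=...)`); shape-related arguments such as `prev_out_shapes` and `original_in_shape` are presumably filled in by the model builder (`model.py`) rather than the config.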
-------------------------------------------------------------------------------- /src/luxonis_train/models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .classification_head import ClassificationHead 2 | from .multilabel_classification_head import MultiLabelClassificationHead 3 | from .segmentation_head import SegmentationHead 4 | from .bisenet_head import BiSeNetHead 5 | from .yolov6_head import YoloV6Head 6 | from .effide_head import EffiDeHead 7 | from .ikeypoint_head import IKeypoint 8 | 9 | __all__ = [ 10 | "ClassificationHead", 11 | "MultiLabelClassificationHead", 12 | "SegmentationHead", 13 | "BiSeNetHead", 14 | "YoloV6Head", 15 | "EffiDeHead", 16 | "IKeypoint" 17 | ] 18 | -------------------------------------------------------------------------------- /src/luxonis_train/models/heads/bisenet_head.py: -------------------------------------------------------------------------------- 1 | # 2 | # Source: https://github.com/taveraantonio/BiseNetv1 3 | # 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from luxonis_train.models.modules import ConvModule 9 | from luxonis_train.models.heads.base_heads import BaseSegmentationHead 10 | 11 | class BiSeNetHead(BaseSegmentationHead): 12 | def __init__(self, n_classes: int, prev_out_shapes: list, original_in_shape: list, attach_index: int = -1, 13 | c1: int = 256, upscale_factor: int = 8, is_aux: bool = False, **kwargs): 14 | """ BiSeNet segmentation head 15 | 16 | Args: 17 | n_classes (int): NUmber of classes 18 | prev_out_shapes (list): List of shapes of previous outputs 19 | original_in_shape (list): Original inpuut shape to the model 20 | attach_index (int, optional): Index of previous output that the head attaches to. Defaults to -1. 21 | c1 (int, optional): Number of input channels. Defaults to 256. 22 | upscale_factor (int, optional): Factor used for upscaling input. Defaults to 8. 23 | is_aux (bool, optional): Either use 256 for intermediate channels or 64. Defaults to False. 
24 | """ 25 | super().__init__(n_classes=n_classes, prev_out_shapes=prev_out_shapes, original_in_shape=original_in_shape, 26 | attach_index=attach_index) 27 | 28 | ch = 256 if is_aux else 64 29 | c2 = n_classes * upscale_factor * upscale_factor 30 | self.conv_3x3 = ConvModule(c1, ch, 3, 1, 1) 31 | self.conv_1x1 = nn.Conv2d(ch, c2, 1, 1, 0) 32 | self.upscale = nn.PixelShuffle(upscale_factor) 33 | 34 | def forward(self, x): 35 | x = self.conv_1x1(self.conv_3x3(x[self.attach_index])) 36 | return self.upscale(x) 37 | 38 | 39 | if __name__ == "__main__": 40 | from luxonis_train.models.backbones import ContextSpatial 41 | 42 | backbone = ContextSpatial() 43 | backbone.eval() 44 | 45 | head = BiSeNetHead(n_classes=2) 46 | head.eval() 47 | 48 | shapes = [224, 256, 384, 512] 49 | for shape in shapes: 50 | print("\nShape", shape) 51 | x = torch.zeros(1, 3, shape, shape) 52 | outs = backbone(x) 53 | outs = head(outs) 54 | for out in outs: 55 | print(out.shape) -------------------------------------------------------------------------------- /src/luxonis_train/models/heads/classification_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from luxonis_train.models.heads.base_heads import BaseClassificationHead 4 | 5 | class ClassificationHead(BaseClassificationHead): 6 | def __init__(self, n_classes: int, prev_out_shapes: list, original_in_shape: list, attach_index: int = -1, 7 | fc_dropout: float = 0.2, **kwargs): 8 | """Simple classification head 9 | 10 | Args: 11 | n_classes (int): Number of classes 12 | prev_out_shapes (list): List of shapes of previous outputs 13 | original_in_shape (list): Original input shape to the model 14 | attach_index (int, optional): Index of previous output that the head attaches to. Defaults to -1. 15 | fc_dropout (float, optional): Dropout rate before last layer, range [0,1]. Defaults to 0.2. 16 | """ 17 | super().__init__(n_classes=n_classes, prev_out_shapes=prev_out_shapes, 18 | original_in_shape=original_in_shape, attach_index=attach_index) 19 | 20 | in_channels = self.prev_out_shapes[self.attach_index][1] 21 | self.head = nn.Sequential( 22 | nn.AdaptiveAvgPool2d(1), 23 | nn.Flatten(), 24 | nn.Dropout(p=fc_dropout), 25 | nn.Linear(in_channels, n_classes) 26 | ) 27 | 28 | def forward(self, x): 29 | out = self.head(x[self.attach_index]) 30 | return out -------------------------------------------------------------------------------- /src/luxonis_train/models/heads/effide_head.py: -------------------------------------------------------------------------------- 1 | # 2 | # Adapted from: https://github.com/meituan/YOLOv6/blob/725913050e15a31cd091dfd7795a1891b0524d35/yolov6/models/effidehead.py 3 | # License: https://github.com/meituan/YOLOv6/blob/main/LICENSE 4 | # 5 | 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | 10 | from luxonis_train.models.modules import ConvModule 11 | from luxonis_train.models.heads.base_heads import BaseObjectDetection 12 | 13 | class EffiDeHead(BaseObjectDetection): 14 | def __init__(self, n_classes: int, prev_out_shapes: list, original_in_shape: list, attach_index: int = -1, 15 | n_anchors:int = 1): 16 | """EffieDeHead object detection head which is part of YoloV6 head 17 | 18 | Args: 19 | n_classes (int): Number of classes 20 | prev_out_shapes (list): List of shapes of previous outputs 21 | original_in_shape (list): Original input shape to the model 22 | attach_index (int, optional): Index of previous output that the head attaches to. Defaults to -1. 
23 | n_anchors (int, optional): Should stay default. Defaults to 1. 24 | """ 25 | super().__init__(n_classes=n_classes, prev_out_shapes=prev_out_shapes, 26 | original_in_shape=original_in_shape, attach_index=attach_index) 27 | 28 | self.n_anchors = n_anchors 29 | self.prior_prob = 1e-2 30 | 31 | in_channels = self.prev_out_shapes[self.attach_index][1] 32 | self.head = nn.Sequential(*[ 33 | # stem 34 | ConvModule( 35 | in_channels=in_channels, 36 | out_channels=in_channels, 37 | kernel_size=1, 38 | stride=1, 39 | activation=nn.SiLU() 40 | ), 41 | # cls_conv 42 | ConvModule( 43 | in_channels=in_channels, 44 | out_channels=in_channels, 45 | kernel_size=3, 46 | stride=1, 47 | padding= 3//2, 48 | activation=nn.SiLU() 49 | ), 50 | # reg_conv 51 | ConvModule( 52 | in_channels=in_channels, 53 | out_channels=in_channels, 54 | kernel_size=3, 55 | stride=1, 56 | padding=3//2, 57 | activation=nn.SiLU() 58 | ), 59 | # cls_pred 60 | nn.Conv2d( 61 | in_channels=in_channels, 62 | out_channels=self.n_classes * self.n_anchors, 63 | kernel_size=1 64 | ), 65 | # reg_pred 66 | nn.Conv2d( 67 | in_channels=in_channels, 68 | out_channels=4 * (self.n_anchors), 69 | kernel_size=1 70 | ) 71 | ]) 72 | self.initialize_biases() 73 | 74 | def initialize_biases(self): 75 | # cls_pred 76 | conv = self.head[3] 77 | b = conv.bias.view(-1, ) 78 | b.data.fill_(-math.log((1 - self.prior_prob) / self.prior_prob)) 79 | conv.bias = nn.Parameter(b.view(-1), requires_grad=True) 80 | w = conv.weight 81 | w.data.fill_(0.) 82 | conv.weight = nn.Parameter(w, requires_grad=True) 83 | 84 | # reg_pred 85 | conv = self.head[4] 86 | b = conv.bias.view(-1, ) 87 | b.data.fill_(1.0) 88 | conv.bias = nn.Parameter(b.view(-1), requires_grad=True) 89 | w = conv.weight 90 | w.data.fill_(0.) 91 | conv.weight = nn.Parameter(w, requires_grad=True) 92 | 93 | def forward(self, x): 94 | out = self.head[0](x[self.attach_index]) 95 | out_cls = self.head[1](out) 96 | out_cls = self.head[3](out_cls) 97 | out_reg = self.head[2](out) 98 | out_reg = self.head[4](out_reg) 99 | 100 | return [x[-1], out_cls, out_reg] 101 | 102 | if __name__ == "__main__": 103 | from luxonis_train.models.backbones import * 104 | from luxonis_train.utils.general import dummy_input_run 105 | 106 | backbone = ResNet18() 107 | backbone_out_shapes = dummy_input_run(backbone, [1,3,224,224]) 108 | backbone.eval() 109 | 110 | shapes = [224, 256, 384, 512] 111 | shapes = [512] 112 | for shape in shapes: 113 | print("\nShape", shape) 114 | x = torch.zeros(1, 3, shape, shape) 115 | outs = backbone(x) 116 | head = EffiDeHead(prev_out_shape=backbone_out_shapes, n_classes=10, original_in_shape=x.shape) 117 | head.eval() 118 | outs = head(outs) 119 | for i in range(len(outs)): 120 | print(f"Output {i}:") 121 | if isinstance(outs[i], list): 122 | for o in outs[i]: 123 | print(len(o) if isinstance(o, list) else o.shape) 124 | else: 125 | print(outs[i].shape) -------------------------------------------------------------------------------- /src/luxonis_train/models/heads/ikeypoint_head.py: -------------------------------------------------------------------------------- 1 | # 2 | # Adapted from: https://github.com/WongKinYiu/yolov7 3 | # License: https://github.com/WongKinYiu/yolov7/blob/main/LICENSE.md 4 | # 5 | 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | from typing import List 10 | from torchvision.ops import box_convert 11 | from torchvision.utils import draw_bounding_boxes, draw_keypoints 12 | 13 | from luxonis_train.models.heads.base_heads import BaseHead 14 | from 
luxonis_train.models.modules import ConvModule, autopad 15 | from luxonis_train.utils.constants import HeadType, LabelType 16 | from luxonis_train.utils.boxutils import non_max_suppression_kpts 17 | 18 | class IKeypoint(BaseHead): 19 | head_types: List[HeadType] = [HeadType.OBJECT_DETECTION, HeadType.KEYPOINT_DETECTION] 20 | label_types: List[LabelType] = [LabelType.BOUNDINGBOX, LabelType.KEYPOINT] 21 | 22 | def __init__(self, n_classes: int, prev_out_shapes: list, original_in_shape: list, n_keypoints: int, anchors: list, 23 | attach_index: int = -1, main_metric: str = "map", connectivity: list = None, **kwargs): 24 | """IKeypoint head which is used for object and keypoint detection 25 | 26 | Args: 27 | n_classes (int): Number of classes 28 | prev_out_shapes (list): List of shapes of previous outputs 29 | original_in_shape (list): Original input shape to the model 30 | n_keypoints (int): Number of keypoints 31 | anchors (list): Anchors used for object detection 32 | attach_index (int, optional): Index of previous output that the head attaches to. Defaults to -1. 33 | main_metric (str, optional): Name of the main metric which is used for tracking training process. Defaults to "map". 34 | connectivity (list, optional): Connectivity mapping used in visualization. Defaults to None. 35 | """ 36 | super().__init__(n_classes=n_classes, prev_out_shapes=prev_out_shapes, 37 | original_in_shape=original_in_shape, attach_index=attach_index) 38 | 39 | self.main_metric: str = main_metric 40 | 41 | self.n_keypoints = n_keypoints 42 | self.connectivity = connectivity 43 | 44 | ch = [prev[1] for prev in self.prev_out_shapes] 45 | self.gr = 1.0 # TODO: find out what this is 46 | self.no_det = n_classes + 5 # number of outputs per anchor for box and class 47 | self.no_kpt = 3 * self.n_keypoints # number of outputs per anchor for keypoints 48 | self.no = self.no_det + self.no_kpt 49 | self.nl = len(anchors) # number of detection layers 50 | self.na = len(anchors[0]) // 2 # number of anchors 51 | self.grid = [torch.zeros(1)] * self.nl # init grid 52 | self.flip_test = False 53 | 54 | a = torch.tensor(anchors).float().view(self.nl, -1, 2) 55 | self.anchors = a # shape(nl,na,2) 56 | self.anchor_grid = a.clone().view(self.nl, 1, -1, 1, 1, 2) 57 | self.m = nn.ModuleList(nn.Conv2d(x, self.no_det * self.na, 1) for x in ch) 58 | 59 | self.ia = nn.ModuleList(ImplicitA(x) for x in ch) 60 | self.im = nn.ModuleList(ImplicitM(self.no_det * self.na) for _ in ch) 61 | 62 | self.m_kpt = nn.ModuleList( 63 | nn.Sequential( 64 | ConvModule(x, x, kernel_size=3, padding=autopad(3), groups=math.gcd(x,x), 65 | activation=nn.SiLU()), 66 | ConvModule(x,x,kernel_size=1, padding=autopad(1), activation=nn.SiLU()), 67 | 68 | ConvModule(x, x, kernel_size=3, padding=autopad(3), groups=math.gcd(x,x), 69 | activation=nn.SiLU()), 70 | ConvModule(x,x,kernel_size=1, padding=autopad(1), activation=nn.SiLU()), 71 | 72 | ConvModule(x, x, kernel_size=3, padding=autopad(3), groups=math.gcd(x,x), 73 | activation=nn.SiLU()), 74 | ConvModule(x,x,kernel_size=1, padding=autopad(1), activation=nn.SiLU()), 75 | 76 | ConvModule(x, x, kernel_size=3, padding=autopad(3), groups=math.gcd(x,x), 77 | activation=nn.SiLU()), 78 | ConvModule(x,x,kernel_size=1, padding=autopad(1), activation=nn.SiLU()), 79 | 80 | ConvModule(x, x, kernel_size=3, padding=autopad(3), groups=math.gcd(x,x), 81 | activation=nn.SiLU()), 82 | ConvModule(x,x,kernel_size=1, padding=autopad(1), activation=nn.SiLU()), 83 | 84 | ConvModule(x, x, kernel_size=3, padding=autopad(3), 
groups=math.gcd(x,x), 85 | activation=nn.SiLU()), 86 | nn.Conv2d(x, self.no_kpt * self.na, 1) 87 | ) for x in ch 88 | ) 89 | 90 | self.stride = torch.tensor([self.original_in_shape[2] / x[2] 91 | for x in self.prev_out_shapes] 92 | ) 93 | self.anchors /= self.stride.view(-1, 1, 1) 94 | self._check_anchor_order() 95 | 96 | def forward(self, inputs): 97 | z = [] # inference output 98 | x = [] # layer outputs 99 | 100 | if self.anchor_grid.device != inputs[0].device: 101 | self.anchor_grid = self.anchor_grid.to(inputs[0].device) 102 | 103 | for i in range(self.nl): 104 | x.append(torch.cat( 105 | (self.im[i](self.m[i](self.ia[i](inputs[i]))), 106 | self.m_kpt[i](inputs[i])), axis=1)) # type: ignore 107 | 108 | bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) 109 | x[i] = x[i].view(bs, self.na, self.no, ny, nx 110 | ).permute(0, 1, 3, 4, 2).contiguous() 111 | x_det = x[i][..., :5 + self.n_classes] 112 | x_kpt = x[i][..., 5 + self.n_classes:] 113 | 114 | # from this point down only needed for inference 115 | if self.grid[i].shape[2:4] != x[i].shape[2:4]: 116 | self.grid[i] = self._make_grid(nx, ny).to(x[i].device) 117 | kpt_grid_x = self.grid[i][..., 0:1] 118 | kpt_grid_y = self.grid[i][..., 1:2] 119 | 120 | if self.n_keypoints == 0: 121 | y = x[i].sigmoid() 122 | else: 123 | y = x_det.sigmoid() 124 | 125 | xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy 126 | wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view( 127 | 1, self.na, 1, 1, 2) # wh 128 | x_kpt[..., 0::3] = (x_kpt[..., ::3] * 2. - 0.5 + kpt_grid_x.repeat( 129 | 1,1,1,1,self.n_keypoints)) * self.stride[i] # xy 130 | x_kpt[..., 1::3] = (x_kpt[..., 1::3] * 2. - 0.5 + kpt_grid_y.repeat( 131 | 1,1,1,1,self.n_keypoints)) * self.stride[i] # xy 132 | x_kpt[..., 2::3] = x_kpt[..., 2::3].sigmoid() 133 | 134 | y = torch.cat((xy, wh, y[..., 4:], x_kpt), dim = -1) 135 | z.append(y.view(bs, -1, self.no)) 136 | 137 | # returns Tuple[kpt, features] 138 | return torch.cat(z, 1), x 139 | 140 | def postprocess_for_loss(self, output: tuple, label_dict: dict): 141 | kpts = label_dict[LabelType.KEYPOINT] 142 | boxes = label_dict[LabelType.BOUNDINGBOX] 143 | nkpts = (kpts.shape[1] - 2) // 3 144 | label = torch.zeros((len(boxes), nkpts * 2 + 6)) 145 | label[:, :2] = boxes[:, :2] 146 | label[:, 2:6] = box_convert(boxes[:, 2:], "xywh", "cxcywh") 147 | label[:,6::2] = kpts[:,2::3] # insert kp x coordinates 148 | label[:,7::2] = kpts[:,3::3] # insert kp y coordinates 149 | return output, label 150 | 151 | def postprocess_for_metric(self, output: tuple, label_dict: dict): 152 | kpts = label_dict[LabelType.KEYPOINT] 153 | boxes = label_dict[LabelType.BOUNDINGBOX] 154 | nkpts = (kpts.shape[1] - 2) // 3 155 | label = torch.zeros((len(boxes), nkpts * 2 + 6)) 156 | label[:, :2] = boxes[:, :2] 157 | label[:, 2:6] = box_convert(boxes[:, 2:], "xywh", "cxcywh") 158 | label[:,6::2] = kpts[:,2::3] # insert kp x coordinates 159 | label[:,7::2] = kpts[:,3::3] # insert kp y coordinates 160 | 161 | nms = non_max_suppression_kpts(output[0]) 162 | output_list_map = [] 163 | label_list_map = [] 164 | image_size = self.original_in_shape[2:] 165 | for i in range(len(nms)): 166 | output_list_map.append({ 167 | "boxes": nms[i][:, :4], 168 | "scores": nms[i][:, 4], 169 | "labels": nms[i][:, 5].int(), 170 | }) 171 | 172 | curr_label = label[label[:, 0] == i].to(nms[i].device) 173 | curr_bboxs = box_convert(curr_label[:, 2: 6], "cxcywh", "xyxy") 174 | curr_bboxs[:, 0::2] *= image_size[1] 175 | curr_bboxs[:, 1::2] *= image_size[0] 176 | 
label_list_map.append({ 177 | "boxes": curr_bboxs, 178 | "labels": curr_label[:, 1].int(), 179 | }) 180 | 181 | output_list_oks, label_list_oks = [], [] # TODO: implement oks and add correct output and labels 182 | 183 | # metric mapping is needed here because each metrics requires different output/label format 184 | metric_mapping = {"map": 0, "oks": 1} 185 | return (output_list_map, output_list_oks), (label_list_map, label_list_oks), metric_mapping 186 | 187 | def draw_output_to_img(self, img: torch.Tensor, output: tuple, idx: int): 188 | curr_output = output[0][idx] 189 | nms = non_max_suppression_kpts(curr_output.unsqueeze(0), conf_thresh=0.25, iou_thresh=0.45)[0] 190 | bboxes = nms[:, :4] 191 | img = draw_bounding_boxes(img, bboxes) 192 | kpts = nms[:, 6:].reshape(-1, self.n_keypoints, 3) 193 | img = draw_keypoints(img, kpts, colors='red', connectivity=self.connectivity) 194 | return img 195 | 196 | def get_output_names(self, idx: int): 197 | # TODO: check if this is correct output name 198 | return f"output{idx}" 199 | 200 | def _make_grid(self, nx=20, ny=20): 201 | yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij") 202 | return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float() 203 | 204 | def _check_anchor_order(self): 205 | a = self.anchor_grid.prod(-1).view(-1) # anchor area 206 | da = a[-1] - a[0] # delta a 207 | ds = self.stride[-1] - self.stride[0] # delta s 208 | if da.sign() != ds.sign(): # same order 209 | print('Reversing anchor order') 210 | self.anchors[:] = self.anchors.flip(0) 211 | self.anchor_grid[:] = self.anchor_grid.flip(0) 212 | 213 | 214 | class ImplicitA(nn.Module): 215 | def __init__(self, channel): 216 | super(ImplicitA, self).__init__() 217 | self.channel = channel 218 | self.implicit = nn.Parameter(torch.zeros(1, channel, 1, 1)) 219 | nn.init.normal_(self.implicit, std=.02) 220 | 221 | def forward(self, x): 222 | return self.implicit.expand_as(x) + x 223 | 224 | class ImplicitM(nn.Module): 225 | def __init__(self, channel): 226 | super(ImplicitM, self).__init__() 227 | self.channel = channel 228 | self.implicit = nn.Parameter(torch.ones(1, channel, 1, 1)) 229 | nn.init.normal_(self.implicit, mean=1., std=.02) 230 | 231 | def forward(self, x): 232 | return self.implicit.expand_as(x) * x -------------------------------------------------------------------------------- /src/luxonis_train/models/heads/multilabel_classification_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from luxonis_train.models.heads.base_heads import BaseMultiLabelClassificationHead 4 | 5 | class MultiLabelClassificationHead(BaseMultiLabelClassificationHead): 6 | def __init__(self, n_classes: int, prev_out_shapes: list, original_in_shape: list, attach_index: int = -1, 7 | fc_dropout: float = 0.2, **kwargs): 8 | """Simple multi-label classification head 9 | 10 | Args: 11 | n_classes (int): Number of classes 12 | prev_out_shapes (list): List of shapes of previous outputs 13 | original_in_shape (list): Original input shape to the model 14 | attach_index (int, optional): Index of previous output that the head attaches to. Defaults to -1. 15 | fc_dropout (float, optional): Dropout rate before last layer, range [0,1]. Defaults to 0.2. 
16 | """ 17 | super().__init__(n_classes=n_classes, prev_out_shapes=prev_out_shapes, 18 | original_in_shape=original_in_shape, attach_index=attach_index) 19 | 20 | in_channels = self.prev_out_shapes[self.attach_index][1] 21 | self.head = nn.Sequential( 22 | nn.AdaptiveAvgPool2d(1), 23 | nn.Flatten(), 24 | nn.Dropout(p=fc_dropout), 25 | nn.Linear(in_channels, n_classes) 26 | ) 27 | 28 | def forward(self, x): 29 | out = self.head(x[self.attach_index]) 30 | return out -------------------------------------------------------------------------------- /src/luxonis_train/models/heads/segmentation_head.py: -------------------------------------------------------------------------------- 1 | # 2 | # Adapted from: https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/fcn.py 3 | # License: https://github.com/pytorch/vision/blob/main/LICENSE 4 | # 5 | 6 | 7 | import math 8 | import warnings 9 | import torch.nn as nn 10 | 11 | from luxonis_train.models.modules import Up 12 | from luxonis_train.models.heads.base_heads import BaseSegmentationHead 13 | 14 | class SegmentationHead(BaseSegmentationHead): 15 | def __init__(self, n_classes: int, prev_out_shapes: list, original_in_shape: list, attach_index: int = -1, **kwargs): 16 | """ Basic segmentation FCN head. Note that it doesn't ensure that ouptut is same size as input. 17 | 18 | Args: 19 | n_classes (int): NUmber of classes 20 | prev_out_shapes (list): List of shapes of previous outputs 21 | original_in_shape (list): Original inpuut shape to the model 22 | attach_index (int, optional): Index of previous output that the head attaches to. Defaults to -1. 23 | """ 24 | 25 | super().__init__(n_classes=n_classes, prev_out_shapes=prev_out_shapes, original_in_shape=original_in_shape, 26 | attach_index=attach_index) 27 | 28 | in_height = self.prev_out_shapes[self.attach_index][2] 29 | original_height = self.original_in_shape[2] 30 | num_up = math.log2(original_height) - math.log2(in_height) 31 | 32 | if not num_up.is_integer(): 33 | warnings.warn("Segmentation head's output shape not same as original input shape.") 34 | num_up = round(num_up) 35 | 36 | modules = [] 37 | in_channels = self.prev_out_shapes[self.attach_index][1] 38 | for _ in range(int(num_up)): 39 | modules.append(Up(in_channels=in_channels, out_channels=in_channels//2)) 40 | in_channels //= 2 41 | 42 | self.head = nn.Sequential( 43 | *modules, 44 | nn.Conv2d(in_channels, n_classes, kernel_size=1) 45 | ) 46 | 47 | def forward(self, x): 48 | out = self.head(x[self.attach_index]) 49 | return out -------------------------------------------------------------------------------- /src/luxonis_train/models/heads/yolov6_head.py: -------------------------------------------------------------------------------- 1 | # 2 | # Adapted from: https://github.com/meituan/YOLOv6/blob/725913050e15a31cd091dfd7795a1891b0524d35/yolov6/models/effidehead.py 3 | # License: https://github.com/meituan/YOLOv6/blob/main/LICENSE 4 | # 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torchvision.ops import box_convert 9 | from torchvision.utils import draw_bounding_boxes 10 | 11 | from .effide_head import EffiDeHead 12 | # from luxonis_train.models.heads.effide_head import EffiDeHead #import for unit testing 13 | from luxonis_train.models.heads.base_heads import BaseObjectDetection 14 | from luxonis_train.utils.assigners.anchor_generator import generate_anchors 15 | from luxonis_train.utils.boxutils import dist2bbox, non_max_suppression_bbox 16 | 17 | class YoloV6Head(BaseObjectDetection): 18 | def 
__init__(self, n_classes: int, prev_out_shapes: list, original_in_shape: list, attach_index: int = -1, 19 | is_4head: bool = False, **kwargs): 20 | """YoloV6 object detection head. With hardware-aware degisn, the decoupled head is optimized with 21 | hybridchannels methods. 22 | 23 | Args: 24 | n_classes (int): Number of classes 25 | prev_out_shapes (list): List of shapes of previous outputs 26 | original_in_shape (list): Original input shape to the model 27 | attach_index (int, optional): Index of previous output that the head attaches to. Defaults to -1. 28 | is_4head (bool, optional): Either build 4 headed architecture or 3 headed one 29 | (**Important: Should be same also on backbone and neck**). Defaults to False. 30 | """ 31 | super().__init__(n_classes=n_classes, prev_out_shapes=prev_out_shapes, 32 | original_in_shape=original_in_shape, attach_index=attach_index) 33 | 34 | self.no = n_classes + 5 # number of outputs per anchor 35 | self.is_4head = is_4head 36 | self.nl = 4 if self.is_4head else 3 # number of detection layers (support 3 and 4 heads) 37 | 38 | self.prior_prob = 1e-2 39 | 40 | self.n_anchors = 1 41 | stride = [4, 8, 16, 32] if self.is_4head else [8,16,32] # strides computed during build 42 | self.stride = torch.tensor(stride) 43 | self.grid = [torch.zeros(1)] * self.nl 44 | self.grid_cell_offset = 0.5 45 | self.grid_cell_size = 5.0 46 | 47 | self.head = nn.ModuleList() 48 | for i in range(self.nl): 49 | curr_head = EffiDeHead( 50 | prev_out_shapes=[self.prev_out_shapes[i]], 51 | original_in_shape=self.original_in_shape, 52 | n_classes=self.n_classes, 53 | n_anchors=self.n_anchors 54 | ) 55 | self.head.append(curr_head) 56 | 57 | def forward(self, x): 58 | cls_score_list = [] 59 | reg_distri_list = [] 60 | 61 | for i, module in enumerate(self.head): 62 | out_x, out_cls, out_reg = module([x[i]]) 63 | x[i] = out_x 64 | out_cls = torch.sigmoid(out_cls) 65 | cls_score_list.append(out_cls.flatten(2).permute((0, 2, 1))) 66 | reg_distri_list.append(out_reg.flatten(2).permute((0, 2, 1))) 67 | 68 | cls_score_list = torch.cat(cls_score_list, axis=1) 69 | reg_distri_list = torch.cat(reg_distri_list, axis=1) 70 | 71 | return [x, cls_score_list, reg_distri_list] 72 | 73 | def postprocess_for_loss(self, output: tuple, label_dict: dict): 74 | label = label_dict[self.label_types[0]] 75 | return output, label 76 | 77 | def postprocess_for_metric(self, output: tuple, label_dict: dict): 78 | label = label_dict[self.label_types[0]] 79 | 80 | output_nms = self._out2box(output) 81 | image_size = self.original_in_shape[2:] 82 | 83 | output_list = [] 84 | label_list = [] 85 | for i in range(len(output_nms)): 86 | output_list.append({ 87 | "boxes": output_nms[i][:,:4], 88 | "scores": output_nms[i][:,4], 89 | "labels": output_nms[i][:,5].int() 90 | }) 91 | 92 | curr_label = label[label[:,0]==i] 93 | curr_bboxs = box_convert(curr_label[:, 2:], "xywh", "xyxy") 94 | curr_bboxs[:, 0::2] *= image_size[1] 95 | curr_bboxs[:, 1::2] *= image_size[0] 96 | label_list.append({ 97 | "boxes": curr_bboxs, 98 | "labels": curr_label[:,1].int() 99 | }) 100 | 101 | return output_list, label_list, None 102 | 103 | def draw_output_to_img(self, img: torch.Tensor, output: tuple, idx: int): 104 | curr_output = self._out2box(output, conf_thres=0.3, iou_thres=0.6) 105 | curr_output = curr_output[idx] 106 | bboxs = curr_output[:,:4] 107 | img = draw_bounding_boxes(img, bboxs) 108 | return img 109 | 110 | def get_output_names(self, idx: int): 111 | output_names = ["output1_yolov6r2", "output2_yolov6r2", "output3_yolov6r2"] 
112 | if self.is_4head:
113 | output_names.append("output4_yolov6r2")
114 | return output_names
115 | 
116 | def to_deploy(self):
117 | # change definition of forward()
118 | def deploy_forward(x):
119 | outputs = []
120 | for i, module in enumerate(self.head):
121 | out_x, out_cls, out_reg = module([x[i]])
122 | out_cls = torch.sigmoid(out_cls)
123 | conf, _ = out_cls.max(1, keepdim=True)
124 | output = torch.cat([out_reg, conf, out_cls], axis=1)
125 | outputs.append(output)
126 | return outputs
127 | 
128 | self.forward = deploy_forward
129 | 
130 | def _out2box(self, output: tuple, **kwargs):
131 | """ Performs post-processing of the YoloV6 output and returns bounding boxes after NMS"""
132 | x, cls_score_list, reg_dist_list = output
133 | anchor_points, stride_tensor = generate_anchors(x, self.stride,
134 | self.grid_cell_size, self.grid_cell_offset, is_eval=True)
135 | pred_bboxes = dist2bbox(reg_dist_list, anchor_points, box_format="xywh")
136 | 
137 | pred_bboxes *= stride_tensor
138 | output_merged = torch.cat([
139 | pred_bboxes,
140 | torch.ones((x[-1].shape[0], pred_bboxes.shape[1], 1), dtype=pred_bboxes.dtype, device=pred_bboxes.device),
141 | cls_score_list
142 | ], axis=-1)
143 | 
144 | conf_thres = kwargs.get("conf_thres", 0.001)
145 | iou_thres = kwargs.get("iou_thres", 0.6)
146 | output_nms = non_max_suppression_bbox(output_merged, conf_thres=conf_thres, iou_thres=iou_thres)
147 | 
148 | return output_nms
149 | 
150 | 
151 | if __name__ == "__main__":
152 | # test yolov6-n config
153 | from luxonis_train.models.backbones import EfficientRep
154 | from luxonis_train.models.necks import RepPANNeck
155 | from luxonis_train.utils.general import dummy_input_run
156 | 
157 | num_repeats_backbone = [1, 6, 12, 18, 6]
158 | num_repeats_neck = [12, 12, 12, 12]
159 | depth_mul = 0.33
160 | 
161 | channels_list_backbone = [64, 128, 256, 512, 1024]
162 | channels_list_neck = [256, 128, 128, 256, 256, 512]
163 | width_mul = 0.25
164 | 
165 | backbone = EfficientRep(in_channels=3, channels_list=channels_list_backbone, num_repeats=num_repeats_backbone,
166 | depth_mul=depth_mul, width_mul=width_mul, is_4head=True)
167 | for module in backbone.modules():
168 | if hasattr(module, 'switch_to_deploy'):
169 | module.switch_to_deploy()
170 | backbone_out_shapes = dummy_input_run(backbone, [1,3,224,224])
171 | backbone.eval()
172 | 
173 | neck = RepPANNeck(prev_out_shapes=backbone_out_shapes, channels_list=channels_list_neck, num_repeats=num_repeats_neck,
174 | depth_mul=depth_mul, width_mul=width_mul, is_4head=True)
175 | neck_out_shapes = dummy_input_run(neck, backbone_out_shapes, multi_input=True)
176 | neck.eval()
177 | 
178 | shapes = [224, 256, 384, 512]
179 | for shape in shapes:
180 | print("\n\nShape", shape)
181 | x = torch.zeros(1, 3, shape, shape)
182 | head = YoloV6Head(prev_out_shapes=neck_out_shapes, n_classes=10,
183 | original_in_shape=x.shape, is_4head=True)
184 | head.eval()
185 | outs = backbone(x)
186 | outs = neck(outs)
187 | outs = head(outs)
188 | for i in range(len(outs)):
189 | print(f"Output {i}:")
190 | if isinstance(outs[i], list):
191 | for o in outs[i]:
192 | print(o.shape)
193 | else:
194 | print(outs[i].shape)
--------------------------------------------------------------------------------
/src/luxonis_train/models/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from .backbones import *
4 | from .necks import *
5 | from .heads import *
6 | 
7 | from luxonis_train.utils.config import Config
8 | from luxonis_train.utils.general import dummy_input_run
9 | 
10 | 
11 | class Model(nn.Module):
12 | def __init__(self):
13 | """ Model class for [backbone, Optional(neck), heads] architectures """
14 | super(Model, self).__init__()
15 | self.backbone = None
16 | self.neck = None
17 | self.heads = nn.ModuleList()
18 | 
19 | def build_model(self):
20 | """ Builds the model from the defined config """
21 | cfg = Config()
22 | modules_cfg = cfg.get("model")
23 | dummy_input_shape = [1,3,]+cfg.get("train.preprocessing.train_image_size") # NOTE: we assume 3 dimensional input shape
24 | 
25 | self.backbone = eval(modules_cfg["backbone"]["name"]) \
26 | (**modules_cfg["backbone"].get("params", {}))
27 | # load local backbone weights if available
28 | if modules_cfg["backbone"]["pretrained"]:
29 | self.backbone.load_state_dict(torch.load(modules_cfg["backbone"]["pretrained"])["state_dict"])
30 | 
31 | self.backbone_out_shapes = dummy_input_run(self.backbone, dummy_input_shape)
32 | 
33 | if "neck" in modules_cfg and modules_cfg["neck"]:
34 | self.neck = eval(modules_cfg["neck"]["name"])(
35 | prev_out_shapes = self.backbone_out_shapes,
36 | **modules_cfg["neck"].get("params", {})
37 | )
38 | self.neck_out_shapes = dummy_input_run(self.neck, self.backbone_out_shapes, multi_input=True)
39 | 
40 | for head in modules_cfg["heads"]:
41 | curr_head = eval(head["name"])(
42 | prev_out_shapes = self.neck_out_shapes if self.neck else self.backbone_out_shapes,
43 | original_in_shape = dummy_input_shape,
44 | **head["params"],
45 | )
46 | self.heads.append(curr_head)
47 | 
48 | def forward(self, x: torch.Tensor):
49 | """ Model's forward method
50 | 
51 | Args:
52 | x (torch.Tensor): Input batch
53 | 
54 | Returns:
55 | outs (list): List of outputs, one for each model head
56 | """
57 | out = self.backbone(x)
58 | if self.neck is not None:
59 | out = self.neck(out)
60 | outs = []
61 | for head in self.heads:
62 | curr_out = head(out)
63 | outs.append(curr_out)
64 | 
65 | return outs
--------------------------------------------------------------------------------
/src/luxonis_train/models/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .common import *
--------------------------------------------------------------------------------
/src/luxonis_train/models/modules/common.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from torch.nn.parameter import Parameter
6 | import warnings
7 | 
8 | class ConvModule(nn.Sequential):
9 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
10 | dilation=1, groups=1, bias=False, activation=nn.ReLU()):
11 | """Conv2d + BN + ReLU"""
12 | super().__init__(
13 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias),
14 | nn.BatchNorm2d(out_channels),
15 | activation
16 | )
17 | 
18 | def autopad(k, p=None):
19 | """ Compute padding based on kernel size
20 | Source: https://github.com/WongKinYiu/yolov7/blob/pose/models/common.py
21 | """
22 | if p is None:
23 | p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
24 | return p
25 | 
26 | 
27 | class Up(nn.Sequential):
28 | def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
29 | """ Upsampling with ConvTranspose2d (similar to U-Net Up block) """
30 | super().__init__(
31 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
32
| ConvModule(out_channels, out_channels, kernel_size=3, padding=1) 33 | ) 34 | 35 | class SEBlock(nn.Module): 36 | def __init__(self, in_channels, internal_channels): 37 | super(SEBlock, self).__init__() 38 | """ Squeeze and Excite module. 39 | Pytorch implementation of `Squeeze-and-Excitation Networks` - 40 | https://arxiv.org/pdf/1709.01507.pdf 41 | Source: https://github.com/apple/ml-mobileone/blob/main/mobileone.py 42 | """ 43 | self.down = nn.Conv2d(in_channels=in_channels, out_channels=internal_channels, kernel_size=1, stride=1, bias=True) 44 | self.up = nn.Conv2d(in_channels=internal_channels, out_channels=in_channels, kernel_size=1, stride=1, bias=True) 45 | self.in_channels = in_channels 46 | 47 | def forward(self, inputs): 48 | x = F.avg_pool2d(inputs, kernel_size=inputs.size(3)) 49 | x = self.down(x) 50 | x = F.relu(x) 51 | x = self.up(x) 52 | x = torch.sigmoid(x) 53 | x = x.view(-1, self.in_channels, 1, 1) 54 | return inputs * x 55 | 56 | def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1): 57 | """ Conv2d + BN 58 | Source: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py 59 | """ 60 | result = nn.Sequential() 61 | result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 62 | kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False)) 63 | result.add_module('bn', nn.BatchNorm2d(num_features=out_channels)) 64 | return result 65 | 66 | class RepVGGBlock(nn.Module): 67 | """Source:https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py""" 68 | 69 | def __init__(self, in_channels, out_channels, kernel_size=3, 70 | stride=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False): 71 | super(RepVGGBlock, self).__init__() 72 | """ RepVGGBlock is a basic rep-style block, including training and deploy status 73 | This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py 74 | """ 75 | self.deploy = deploy 76 | self.groups = groups 77 | self.in_channels = in_channels 78 | self.out_channels = out_channels 79 | 80 | assert kernel_size == 3 81 | assert padding == 1 82 | 83 | padding_11 = padding - kernel_size // 2 84 | 85 | self.nonlinearity = nn.ReLU() 86 | 87 | if use_se: 88 | # Note that RepVGG-D2se uses SE before nonlinearity. But RepVGGplus models uses SE after nonlinearity. 89 | self.se = SEBlock(out_channels, internal_channels=out_channels // 16) 90 | else: 91 | self.se = nn.Identity() 92 | 93 | if deploy: 94 | self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 95 | padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode) 96 | else: 97 | self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None 98 | self.rbr_dense = conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=groups) 99 | self.rbr_1x1 = conv_bn(in_channels, out_channels, 1, stride, padding_11, groups=groups) 100 | 101 | def forward(self, inputs): 102 | if hasattr(self, 'rbr_reparam'): 103 | return self.nonlinearity(self.se(self.rbr_reparam(inputs))) 104 | 105 | if self.rbr_identity is None: 106 | id_out = 0 107 | else: 108 | id_out = self.rbr_identity(inputs) 109 | 110 | return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)) 111 | 112 | 113 | # Optional. This may improve the accuracy and facilitates quantization in some cases. 114 | # 1. 
Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight. 115 | # 2. Use like this. 116 | # loss = criterion(....) 117 | # for every RepVGGBlock blk: 118 | # loss += weight_decay_coefficient * 0.5 * blk.get_cust_L2() 119 | # optimizer.zero_grad() 120 | # loss.backward() 121 | def get_custom_L2(self): 122 | K3 = self.rbr_dense.conv.weight 123 | K1 = self.rbr_1x1.conv.weight 124 | t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach() 125 | t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach() 126 | 127 | l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum() # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them. 128 | eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1 # The equivalent resultant central point of 3x3 kernel. 129 | l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum() # Normalize for an L2 coefficient comparable to regular L2. 130 | return l2_loss_eq_kernel + l2_loss_circle 131 | 132 | 133 | 134 | # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way. 135 | # You can get the equivalent kernel and bias at any time and do whatever you want, 136 | # for example, apply some penalties or constraints during training, just like you do to the other models. 137 | # May be useful for quantization or pruning. 138 | def get_equivalent_kernel_bias(self): 139 | kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) 140 | kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1) 141 | kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity) 142 | return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid 143 | 144 | def _pad_1x1_to_3x3_tensor(self, kernel1x1): 145 | if kernel1x1 is None: 146 | return 0 147 | else: 148 | return torch.nn.functional.pad(kernel1x1, [1,1,1,1]) 149 | 150 | def _fuse_bn_tensor(self, branch): 151 | if branch is None: 152 | return 0, 0 153 | if isinstance(branch, nn.Sequential): 154 | kernel = branch.conv.weight 155 | running_mean = branch.bn.running_mean 156 | running_var = branch.bn.running_var 157 | gamma = branch.bn.weight 158 | beta = branch.bn.bias 159 | eps = branch.bn.eps 160 | else: 161 | assert isinstance(branch, nn.BatchNorm2d) 162 | if not hasattr(self, 'id_tensor'): 163 | input_dim = self.in_channels // self.groups 164 | kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32) 165 | for i in range(self.in_channels): 166 | kernel_value[i, i % input_dim, 1, 1] = 1 167 | self.id_tensor = torch.from_numpy(kernel_value) 168 | kernel = self.id_tensor 169 | running_mean = branch.running_mean 170 | running_var = branch.running_var 171 | gamma = branch.weight 172 | beta = branch.bias 173 | eps = branch.eps 174 | std = (running_var + eps).sqrt() 175 | t = (gamma / std).reshape(-1, 1, 1, 1) 176 | return kernel * t, beta - running_mean * gamma / std 177 | 178 | def to_deploy(self): 179 | if hasattr(self, 'rbr_reparam'): 180 | return 181 | kernel, bias = self.get_equivalent_kernel_bias() 182 | self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels, 183 | kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride, 184 | padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True) 185 | self.rbr_reparam.weight.data = kernel 186 | 
self.rbr_reparam.bias.data = bias
187 | self.__delattr__('rbr_dense')
188 | self.__delattr__('rbr_1x1')
189 | if hasattr(self, 'rbr_identity'):
190 | self.__delattr__('rbr_identity')
191 | if hasattr(self, 'id_tensor'):
192 | self.__delattr__('id_tensor')
193 | self.deploy = True
194 | 
195 | class RepBlock(nn.Module):
196 | def __init__(self, in_channels, out_channels, n=1):
197 | super(RepBlock, self).__init__()
198 | """
199 | RepBlock is a stage block with rep-style basic block
200 | Adapted from: https://github.com/meituan/YOLOv6/blob/725913050e15a31cd091dfd7795a1891b0524d35/yolov6/layers/common.py
201 | """
202 | 
203 | self.conv1 = RepVGGBlock(in_channels, out_channels)
204 | self.block = nn.Sequential(*(RepVGGBlock(out_channels, out_channels) for _ in range(n - 1))) if n > 1 else None
205 | 
206 | def forward(self, x):
207 | x = self.conv1(x)
208 | if self.block is not None:
209 | x = self.block(x)
210 | return x
211 | 
212 | class SimplifiedSPPF(nn.Module):
213 | def __init__(self, in_channels, out_channels, kernel_size=5):
214 | super(SimplifiedSPPF, self).__init__()
215 | """ Simplified Spatial Pyramid Pooling with ReLU activation
216 | Adapted from: https://github.com/meituan/YOLOv6/blob/725913050e15a31cd091dfd7795a1891b0524d35/yolov6/layers/common.py
217 | """
218 | c_ = in_channels // 2 # hidden channels
219 | self.cv1 = ConvModule(in_channels, c_, 1, 1)
220 | self.cv2 = ConvModule(c_*4, out_channels, 1, 1)
221 | self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
222 | 
223 | def forward(self, x):
224 | # Pass the input feature map through the first convolutional layer
225 | x = self.cv1(x)
226 | 
227 | # apply max-pooling at three different scales
228 | y1 = self.m(x)
229 | y2 = self.m(y1)
230 | y3 = self.m(y2)
231 | 
232 | # Concatenate the original feature map and the three max-pooled versions
233 | # along the channel dimension and pass through the second convolutional layer
234 | out = self.cv2(torch.cat([x, y1, y2, y3], dim=1))
235 | return out
--------------------------------------------------------------------------------
/src/luxonis_train/models/necks/README.md:
--------------------------------------------------------------------------------
1 | ## List of supported necks
2 | - RepPANNeck (adapted from [here](https://github.com/meituan/YOLOv6/blob/725913050e15a31cd091dfd7795a1891b0524d35/yolov6/models/reppan.py))
3 | - Params (see the usage sketch below):
4 | - channels_list: List[int] # List of number of channels for each block
5 | - num_repeats: List[int] # List of number of repeats of RepBlock
6 | - depth_mul: float # Depth multiplier. Defaults to 0.33.
7 | - width_mul: float # Width multiplier. Defaults to 0.25.
8 | - is_4head: bool # Whether to build the 4-headed architecture instead of the 3-headed one (**Important: Should be the same on backbone and head**). Defaults to False.
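9 | 
10 | ### Example usage
11 | 
12 | A minimal usage sketch, assuming the `EfficientRep` backbone and the YoloV6-n style channel/repeat lists used in the self-test at the bottom of `reppan_neck.py`; adjust these values to match your own config:
13 | 
14 | ```python
15 | import torch
16 | from luxonis_train.models.backbones import EfficientRep
17 | from luxonis_train.models.necks import RepPANNeck
18 | from luxonis_train.utils.general import dummy_input_run
19 | 
20 | backbone = EfficientRep(in_channels=3, channels_list=[64, 128, 256, 512, 1024],
21 |     num_repeats=[1, 6, 12, 18, 6], depth_mul=0.33, width_mul=0.25)
22 | backbone_out_shapes = dummy_input_run(backbone, [1, 3, 224, 224])  # shapes of the backbone outputs
23 | backbone.eval()
24 | 
25 | neck = RepPANNeck(prev_out_shapes=backbone_out_shapes, channels_list=[256, 128, 128, 256, 256, 512],
26 |     num_repeats=[12, 12, 12, 12], depth_mul=0.33, width_mul=0.25)
27 | neck.eval()
28 | 
29 | features = backbone(torch.zeros(1, 3, 224, 224))
30 | for out in neck(features):
31 |     print(out.shape)  # 3 feature maps (4 when is_4head=True on backbone, neck and head)
32 | ```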
-------------------------------------------------------------------------------- /src/luxonis_train/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | from .reppan_neck import RepPANNeck 2 | 3 | __all__ = [ 4 | "RepPANNeck" 5 | ] -------------------------------------------------------------------------------- /src/luxonis_train/models/necks/reppan_neck.py: -------------------------------------------------------------------------------- 1 | # 2 | # Adapted from: https://github.com/meituan/YOLOv6/blob/725913050e15a31cd091dfd7795a1891b0524d35/yolov6/models/reppan.py 3 | # License: https://github.com/meituan/YOLOv6/blob/main/LICENSE 4 | # 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from luxonis_train.models.modules import RepBlock, ConvModule 11 | from luxonis_train.utils.general import make_divisible 12 | 13 | class RepPANNeck(nn.Module): 14 | def __init__(self, prev_out_shapes: list, channels_list: list, num_repeats: list, 15 | depth_mul: float = 0.33, width_mul: float = 0.25, is_4head: bool = False, **kwargs): 16 | """RepPANNeck normally used with YoloV6 model. It has the balance of feature fusion ability and hardware efficiency. 17 | 18 | Args: 19 | prev_out_shapes (list): List of shapes of previous outputs 20 | channels_list (list): List of number of channels for each block. 21 | num_repeats (list): List of number of repeats of RepBlock 22 | depth_mul (float, optional): Depth multiplier. Defaults to 0.33. 23 | width_mul (float, optional): Width multiplier. Defaults to 0.25. 24 | is_4head (bool, optional): Either build 4 headed architecture or 3 headed one \ 25 | (**Important: Should be same also on backbone and head**). Defaults to False. 26 | """ 27 | super().__init__() 28 | 29 | channels_list = [make_divisible(i * width_mul, 8) for i in channels_list] 30 | num_repeats = [(max(round(i * depth_mul), 1) if i > 1 else i) for i in num_repeats] 31 | 32 | self.is_4head = is_4head 33 | prev_out_start_idx = 1 if self.is_4head else 0 34 | 35 | self.Rep_p4 = RepBlock( 36 | in_channels=prev_out_shapes[prev_out_start_idx+1][1] + channels_list[0], 37 | out_channels=channels_list[0], 38 | n=num_repeats[0], 39 | ) 40 | 41 | self.Rep_p3 = RepBlock( 42 | in_channels=prev_out_shapes[prev_out_start_idx][1] + channels_list[1], 43 | out_channels=channels_list[1], 44 | n=num_repeats[1], 45 | ) 46 | 47 | self.Rep_n3 = RepBlock( 48 | in_channels=channels_list[1] + channels_list[2], 49 | out_channels=channels_list[3], 50 | n=num_repeats[2], 51 | ) 52 | 53 | self.Rep_n4 = RepBlock( 54 | in_channels=channels_list[0] + channels_list[4], 55 | out_channels=channels_list[5], 56 | n=num_repeats[3], 57 | ) 58 | 59 | self.reduce_layer0 = ConvModule( 60 | in_channels=prev_out_shapes[prev_out_start_idx+2][1], 61 | out_channels=channels_list[0], 62 | kernel_size=1, 63 | stride=1 64 | ) 65 | 66 | self.upsample0 = torch.nn.ConvTranspose2d( 67 | in_channels=channels_list[0], 68 | out_channels=channels_list[0], 69 | kernel_size=2, 70 | stride=2, 71 | bias=True 72 | ) 73 | 74 | self.reduce_layer1 = ConvModule( 75 | in_channels=channels_list[0], 76 | out_channels=channels_list[1], 77 | kernel_size=1, 78 | stride=1 79 | ) 80 | 81 | self.upsample1 = torch.nn.ConvTranspose2d( 82 | in_channels=channels_list[1], 83 | out_channels=channels_list[1], 84 | kernel_size=2, 85 | stride=2, 86 | bias=True 87 | ) 88 | 89 | self.downsample2 = ConvModule( 90 | in_channels=channels_list[1], 91 | out_channels=channels_list[2], 92 | kernel_size=3, 93 | stride=2, 94 | 
padding=3 // 2
95 | )
96 | 
97 | self.downsample1 = ConvModule(
98 | in_channels=channels_list[3],
99 | out_channels=channels_list[4],
100 | kernel_size=3,
101 | stride=2,
102 | padding=3 // 2
103 | )
104 | 
105 | if self.is_4head:
106 | self.reduce_layer2 = ConvModule(
107 | in_channels=channels_list[1],
108 | out_channels=channels_list[1]//2,
109 | kernel_size=1,
110 | stride=1
111 | )
112 | self.upsample2 = torch.nn.ConvTranspose2d(
113 | in_channels=channels_list[1]//2,
114 | out_channels=channels_list[1]//2,
115 | kernel_size=2,
116 | stride=2,
117 | bias=True
118 | )
119 | self.Rep_p2 = RepBlock(
120 | in_channels=prev_out_shapes[prev_out_start_idx-1][1] +
121 | channels_list[1]//2,
122 | out_channels=channels_list[1]//2,
123 | n=num_repeats[1],
124 | )
125 | self.downsample3 = ConvModule(
126 | in_channels=channels_list[1]//2,
127 | out_channels=channels_list[1],
128 | kernel_size=3,
129 | stride=2,
130 | padding=3 // 2
131 | )
132 | self.Rep_n2 = RepBlock(
133 | in_channels=channels_list[1] + channels_list[1]//2,
134 | out_channels=channels_list[1],
135 | n=num_repeats[1],
136 | )
137 | 
138 | def forward(self, x):
139 | if self.is_4head:
140 | x3, x2, x1, x0 = x
141 | else:
142 | x2, x1, x0 = x
143 | 
144 | fpn_out0 = self.reduce_layer0(x0)
145 | upsample_feat0 = self.upsample0(fpn_out0)
146 | f_concat_layer0 = torch.cat([upsample_feat0, x1], dim=1)
147 | f_out0 = self.Rep_p4(f_concat_layer0)
148 | 
149 | fpn_out1 = self.reduce_layer1(f_out0)
150 | upsample_feat1 = self.upsample1(fpn_out1)
151 | f_concat_layer1 = torch.cat([upsample_feat1, x2], dim=1)
152 | pan_out2 = self.Rep_p3(f_concat_layer1)
153 | 
154 | if self.is_4head:
155 | fpn_out2 = self.reduce_layer2(pan_out2)
156 | upsample_feat2 = self.upsample2(fpn_out2)
157 | f_concat_layer2 = torch.cat([upsample_feat2, x3], dim=1)
158 | pan_out3 = self.Rep_p2(f_concat_layer2)
159 | 
160 | down_feat2 = self.downsample3(pan_out3)
161 | p_concat_layer0 = torch.cat([down_feat2, fpn_out2], dim=1)
162 | pan_out2 = self.Rep_n2(p_concat_layer0)
163 | 
164 | down_feat1 = self.downsample2(pan_out2)
165 | p_concat_layer1 = torch.cat([down_feat1, fpn_out1], dim=1)
166 | pan_out1 = self.Rep_n3(p_concat_layer1)
167 | 
168 | down_feat0 = self.downsample1(pan_out1)
169 | p_concat_layer2 = torch.cat([down_feat0, fpn_out0], dim=1)
170 | pan_out0 = self.Rep_n4(p_concat_layer2)
171 | 
172 | if self.is_4head:
173 | outputs = [pan_out3, pan_out2, pan_out1, pan_out0]
174 | else:
175 | outputs = [pan_out2, pan_out1, pan_out0]
176 | 
177 | return outputs
178 | 
179 | 
180 | if __name__ == "__main__":
181 | # test together with EfficientRep backbone
182 | from luxonis_train.models.backbones import EfficientRep
183 | from luxonis_train.utils.general import dummy_input_run
184 | 
185 | num_repeats_backbone = [1, 6, 12, 18, 6]
186 | num_repeats_neck = [12, 12, 12, 12]
187 | depth_mul = 0.33
188 | 
189 | channels_list_backbone = [64, 128, 256, 512, 1024]
190 | channels_list_neck = [256, 128, 128, 256, 256, 512]
191 | width_mul = 0.25
192 | 
193 | backbone = EfficientRep(in_channels=3, channels_list=channels_list_backbone, num_repeats=num_repeats_backbone,
194 | depth_mul=depth_mul, width_mul=width_mul, is_4head=False)
195 | backbone_out_shapes = dummy_input_run(backbone, [1,3,224,224])
196 | backbone.eval()
197 | 
198 | neck = RepPANNeck(prev_out_shapes=backbone_out_shapes, channels_list=channels_list_neck, num_repeats=num_repeats_neck,
199 | depth_mul=depth_mul, width_mul=width_mul, is_4head=False)
200 | neck.eval()
201 | 
202 | shapes = [224, 256, 384, 512]
203 | for shape in
shapes: 204 | print("\n\nShape", shape) 205 | x = torch.zeros(1, 3, shape, shape) 206 | outs = backbone(x) 207 | outs = neck(outs) 208 | for out in outs: 209 | print(out.shape) -------------------------------------------------------------------------------- /src/luxonis_train/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luxonis/models/bc1cf4b1ea3bdaf918d0f0af13e21ec64bae9e41/src/luxonis_train/utils/__init__.py -------------------------------------------------------------------------------- /src/luxonis_train/utils/assigners/__init__.py: -------------------------------------------------------------------------------- 1 | from .atts_assigner import ATSSAssigner 2 | from .tal_assigner import TaskAlignedAssigner 3 | from .anchor_generator import generate_anchors -------------------------------------------------------------------------------- /src/luxonis_train/utils/assigners/anchor_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def generate_anchors(feats, fpn_strides, grid_cell_size=5.0, grid_cell_offset=0.5, is_eval=False): 5 | '''Generate anchors from features.''' 6 | device = feats[0].device 7 | anchors = [] 8 | anchor_points = [] 9 | stride_tensor = [] 10 | num_anchors_list = [] 11 | assert feats is not None 12 | if is_eval: 13 | for i, stride in enumerate(fpn_strides): 14 | _, _, h, w = feats[i].shape 15 | shift_x = torch.arange(end=w) + grid_cell_offset 16 | shift_y = torch.arange(end=h) + grid_cell_offset 17 | shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing="ij") 18 | anchor_point = torch.stack( 19 | [shift_x, shift_y], axis=-1).to(torch.float) 20 | anchor_points.append(anchor_point.reshape([-1, 2])) 21 | stride_tensor.append( 22 | torch.full( 23 | (h * w, 1), stride, dtype=torch.float)) 24 | anchor_points = torch.cat(anchor_points).to(device) 25 | stride_tensor = torch.cat(stride_tensor).to(device) 26 | return anchor_points, stride_tensor 27 | else: 28 | for i, stride in enumerate(fpn_strides): 29 | _, _, h, w = feats[i].shape 30 | cell_half_size = grid_cell_size * stride * 0.5 31 | shift_x = (torch.arange(end=w) + grid_cell_offset) * stride 32 | shift_y = (torch.arange(end=h) + grid_cell_offset) * stride 33 | shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing="ij") 34 | anchor = torch.stack( 35 | [ 36 | shift_x - cell_half_size, shift_y - cell_half_size, 37 | shift_x + cell_half_size, shift_y + cell_half_size 38 | ], 39 | axis=-1).clone().to(feats[0].dtype) 40 | anchor_point = torch.stack( 41 | [shift_x, shift_y], axis=-1).clone().to(feats[0].dtype) 42 | 43 | anchors.append(anchor.reshape([-1, 4])) 44 | anchor_points.append(anchor_point.reshape([-1, 2])) 45 | num_anchors_list.append(len(anchors[-1])) 46 | stride_tensor.append( 47 | torch.full( 48 | [num_anchors_list[-1], 1], stride, dtype=feats[0].dtype)) 49 | anchors = torch.cat(anchors).to(device) 50 | anchor_points = torch.cat(anchor_points).to(device) 51 | stride_tensor = torch.cat(stride_tensor).to(device) 52 | return anchors, anchor_points, num_anchors_list, stride_tensor 53 | -------------------------------------------------------------------------------- /src/luxonis_train/utils/assigners/assigner_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def dist_calculator(gt_bboxes, anchor_bboxes): 5 | """compute center distance between all bbox 
and gt 6 | 7 | Args: 8 | gt_bboxes (Tensor): shape(bs*n_max_boxes, 4) 9 | anchor_bboxes (Tensor): shape(num_total_anchors, 4) 10 | Return: 11 | distances (Tensor): shape(bs*n_max_boxes, num_total_anchors) 12 | ac_points (Tensor): shape(num_total_anchors, 2) 13 | """ 14 | gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0 15 | gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0 16 | gt_points = torch.stack([gt_cx, gt_cy], dim=1) 17 | ac_cx = (anchor_bboxes[:, 0] + anchor_bboxes[:, 2]) / 2.0 18 | ac_cy = (anchor_bboxes[:, 1] + anchor_bboxes[:, 3]) / 2.0 19 | ac_points = torch.stack([ac_cx, ac_cy], dim=1) 20 | 21 | distances = (gt_points[:, None, :] - ac_points[None, :, :]).pow(2).sum(-1).sqrt() 22 | 23 | return distances, ac_points 24 | 25 | def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9): 26 | """select the positive anchors's center in gt 27 | 28 | Args: 29 | xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4) 30 | gt_bboxes (Tensor): shape(bs, n_max_boxes, 4) 31 | Return: 32 | (Tensor): shape(bs, n_max_boxes, num_total_anchors) 33 | """ 34 | n_anchors = xy_centers.size(0) 35 | bs, n_max_boxes, _ = gt_bboxes.size() 36 | _gt_bboxes = gt_bboxes.reshape([-1, 4]) 37 | xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1) 38 | gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1) 39 | gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1) 40 | b_lt = xy_centers - gt_bboxes_lt 41 | b_rb = gt_bboxes_rb - xy_centers 42 | bbox_deltas = torch.cat([b_lt, b_rb], dim=-1) 43 | bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1]) 44 | return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype) 45 | 46 | def select_highest_overlaps(mask_pos, overlaps, n_max_boxes): 47 | """if an anchor box is assigned to multiple gts, 48 | the one with the highest iou will be selected. 
49 | 50 | Args: 51 | mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors) 52 | overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors) 53 | Return: 54 | target_gt_idx (Tensor): shape(bs, num_total_anchors) 55 | fg_mask (Tensor): shape(bs, num_total_anchors) 56 | mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors) 57 | """ 58 | fg_mask = mask_pos.sum(axis=-2) 59 | if fg_mask.max() > 1: 60 | mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1]) 61 | max_overlaps_idx = overlaps.argmax(axis=1) 62 | is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes) 63 | is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype) 64 | mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos) 65 | fg_mask = mask_pos.sum(axis=-2) 66 | target_gt_idx = mask_pos.argmax(axis=-2) 67 | return target_gt_idx, fg_mask , mask_pos 68 | 69 | def iou_calculator(box1, box2, eps=1e-9): 70 | """Calculate iou for batch 71 | 72 | Args: 73 | box1 (Tensor): shape(bs, n_max_boxes, 1, 4) 74 | box2 (Tensor): shape(bs, 1, num_total_anchors, 4) 75 | Return: 76 | (Tensor): shape(bs, n_max_boxes, num_total_anchors) 77 | """ 78 | box1 = box1.unsqueeze(2) # [N, M1, 4] -> [N, M1, 1, 4] 79 | box2 = box2.unsqueeze(1) # [N, M2, 4] -> [N, 1, M2, 4] 80 | px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4] 81 | gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4] 82 | x1y1 = torch.maximum(px1y1, gx1y1) 83 | x2y2 = torch.minimum(px2y2, gx2y2) 84 | overlap = (x2y2 - x1y1).clip(0).prod(-1) 85 | area1 = (px2y2 - px1y1).clip(0).prod(-1) 86 | area2 = (gx2y2 - gx1y1).clip(0).prod(-1) 87 | union = area1 + area2 - overlap + eps 88 | 89 | return overlap / union 90 | -------------------------------------------------------------------------------- /src/luxonis_train/utils/assigners/atts_assigner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .iou2d_calculator import iou2d_calculator 5 | from .assigner_utils import dist_calculator, select_candidates_in_gts, select_highest_overlaps, iou_calculator 6 | 7 | class ATSSAssigner(nn.Module): 8 | '''Adaptive Training Sample Selection Assigner''' 9 | def __init__(self, 10 | topk=9, 11 | num_classes=80): 12 | super(ATSSAssigner, self).__init__() 13 | self.topk = topk 14 | self.num_classes = num_classes 15 | self.bg_idx = num_classes 16 | 17 | @torch.no_grad() 18 | def forward(self, 19 | anc_bboxes, 20 | n_level_bboxes, 21 | gt_labels, 22 | gt_bboxes, 23 | mask_gt, 24 | pd_bboxes): 25 | r"""This code is based on 26 | https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/atss_assigner.py 27 | 28 | Args: 29 | anc_bboxes (Tensor): shape(num_total_anchors, 4) 30 | n_level_bboxes (List):len(3) 31 | gt_labels (Tensor): shape(bs, n_max_boxes, 1) 32 | gt_bboxes (Tensor): shape(bs, n_max_boxes, 4) 33 | mask_gt (Tensor): shape(bs, n_max_boxes, 1) 34 | pd_bboxes (Tensor): shape(bs, n_max_boxes, 4) 35 | Returns: 36 | target_labels (Tensor): shape(bs, num_total_anchors) 37 | target_bboxes (Tensor): shape(bs, num_total_anchors, 4) 38 | target_scores (Tensor): shape(bs, num_total_anchors, num_classes) 39 | fg_mask (Tensor): shape(bs, num_total_anchors) 40 | """ 41 | self.n_anchors = anc_bboxes.size(0) 42 | self.bs = gt_bboxes.size(0) 43 | self.n_max_boxes = gt_bboxes.size(1) 44 | 45 | if self.n_max_boxes == 0: 46 | device = gt_bboxes.device 47 | return torch.full( [self.bs, self.n_anchors], self.bg_idx).to(device), \ 48 | 
torch.zeros([self.bs, self.n_anchors, 4]).to(device), \ 49 | torch.zeros([self.bs, self.n_anchors, self.num_classes]).to(device), \ 50 | torch.zeros([self.bs, self.n_anchors]).to(device) 51 | 52 | overlaps = iou2d_calculator(gt_bboxes.reshape([-1, 4]), anc_bboxes) 53 | overlaps = overlaps.reshape([self.bs, -1, self.n_anchors]) 54 | 55 | distances, ac_points = dist_calculator(gt_bboxes.reshape([-1, 4]), anc_bboxes) 56 | distances = distances.reshape([self.bs, -1, self.n_anchors]) 57 | 58 | is_in_candidate, candidate_idxs = self.select_topk_candidates( 59 | distances, n_level_bboxes, mask_gt) 60 | 61 | overlaps_thr_per_gt, iou_candidates = self.thres_calculator( 62 | is_in_candidate, candidate_idxs, overlaps) 63 | 64 | # select candidates iou >= threshold as positive 65 | is_pos = torch.where( 66 | iou_candidates > overlaps_thr_per_gt.repeat([1, 1, self.n_anchors]), 67 | is_in_candidate, torch.zeros_like(is_in_candidate)) 68 | 69 | is_in_gts = select_candidates_in_gts(ac_points, gt_bboxes) 70 | mask_pos = is_pos * is_in_gts * mask_gt 71 | 72 | target_gt_idx, fg_mask, mask_pos = select_highest_overlaps( 73 | mask_pos, overlaps, self.n_max_boxes) 74 | 75 | # assigned target 76 | target_labels, target_bboxes, target_scores = self.get_targets( 77 | gt_labels, gt_bboxes, target_gt_idx, fg_mask) 78 | 79 | # soft label with iou 80 | if pd_bboxes is not None: 81 | ious = iou_calculator(gt_bboxes, pd_bboxes) * mask_pos 82 | ious = ious.max(axis=-2)[0].unsqueeze(-1) 83 | target_scores *= ious 84 | 85 | return target_labels.long(), target_bboxes, target_scores, fg_mask.bool() 86 | 87 | def select_topk_candidates(self, 88 | distances, 89 | n_level_bboxes, 90 | mask_gt): 91 | 92 | mask_gt = mask_gt.repeat(1, 1, self.topk).bool() 93 | level_distances = torch.split(distances, n_level_bboxes, dim=-1) 94 | is_in_candidate_list = [] 95 | candidate_idxs = [] 96 | start_idx = 0 97 | for per_level_distances, per_level_boxes in zip(level_distances, n_level_bboxes): 98 | 99 | end_idx = start_idx + per_level_boxes 100 | selected_k = min(self.topk, per_level_boxes) 101 | _, per_level_topk_idxs = per_level_distances.topk(selected_k, dim=-1, largest=False) 102 | candidate_idxs.append(per_level_topk_idxs + start_idx) 103 | per_level_topk_idxs = torch.where(mask_gt, 104 | per_level_topk_idxs, torch.zeros_like(per_level_topk_idxs)) 105 | is_in_candidate = F.one_hot(per_level_topk_idxs, per_level_boxes).sum(dim=-2) 106 | is_in_candidate = torch.where(is_in_candidate > 1, 107 | torch.zeros_like(is_in_candidate), is_in_candidate) 108 | is_in_candidate_list.append(is_in_candidate.to(distances.dtype)) 109 | start_idx = end_idx 110 | 111 | is_in_candidate_list = torch.cat(is_in_candidate_list, dim=-1) 112 | candidate_idxs = torch.cat(candidate_idxs, dim=-1) 113 | 114 | return is_in_candidate_list, candidate_idxs 115 | 116 | def thres_calculator(self, 117 | is_in_candidate, 118 | candidate_idxs, 119 | overlaps): 120 | 121 | n_bs_max_boxes = self.bs * self.n_max_boxes 122 | _candidate_overlaps = torch.where(is_in_candidate > 0, 123 | overlaps, torch.zeros_like(overlaps)) 124 | candidate_idxs = candidate_idxs.reshape([n_bs_max_boxes, -1]) 125 | assist_idxs = self.n_anchors * torch.arange(n_bs_max_boxes, device=candidate_idxs.device) 126 | assist_idxs = assist_idxs[:,None] 127 | faltten_idxs = candidate_idxs + assist_idxs 128 | candidate_overlaps = _candidate_overlaps.reshape(-1)[faltten_idxs] 129 | candidate_overlaps = candidate_overlaps.reshape([self.bs, self.n_max_boxes, -1]) 130 | 131 | overlaps_mean_per_gt = 
candidate_overlaps.mean(axis=-1, keepdim=True) 132 | overlaps_std_per_gt = candidate_overlaps.std(axis=-1, keepdim=True) 133 | overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt 134 | 135 | return overlaps_thr_per_gt, _candidate_overlaps 136 | 137 | def get_targets(self, 138 | gt_labels, 139 | gt_bboxes, 140 | target_gt_idx, 141 | fg_mask): 142 | 143 | # assigned target labels 144 | batch_idx = torch.arange(self.bs, dtype=gt_labels.dtype, device=gt_labels.device) 145 | batch_idx = batch_idx[...,None] 146 | target_gt_idx = (target_gt_idx + batch_idx * self.n_max_boxes).long() 147 | target_labels = gt_labels.flatten()[target_gt_idx.flatten()] 148 | target_labels = target_labels.reshape([self.bs, self.n_anchors]) 149 | target_labels = torch.where(fg_mask > 0, 150 | target_labels, torch.full_like(target_labels, self.bg_idx)) 151 | 152 | # assigned target boxes 153 | target_bboxes = gt_bboxes.reshape([-1, 4])[target_gt_idx.flatten()] 154 | target_bboxes = target_bboxes.reshape([self.bs, self.n_anchors, 4]) 155 | 156 | # assigned target scores 157 | target_scores = F.one_hot(target_labels.long(), self.num_classes + 1).float() 158 | target_scores = target_scores[:, :, :self.num_classes] 159 | 160 | return target_labels, target_bboxes, target_scores 161 | -------------------------------------------------------------------------------- /src/luxonis_train/utils/assigners/iou2d_calculator.py: -------------------------------------------------------------------------------- 1 | #This code is based on 2 | #https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/iou_calculators/iou2d_calculator.py 3 | 4 | import torch 5 | 6 | 7 | def cast_tensor_type(x, scale=1., dtype=None): 8 | if dtype == 'fp16': 9 | # scale is for preventing overflows 10 | x = (x / scale).half() 11 | return x 12 | 13 | 14 | def fp16_clamp(x, min=None, max=None): 15 | if not x.is_cuda and x.dtype == torch.float16: 16 | # clamp for cpu float16, tensor fp16 has no clamp implementation 17 | return x.float().clamp(min, max).half() 18 | 19 | return x.clamp(min, max) 20 | 21 | 22 | def iou2d_calculator(bboxes1, bboxes2, mode='iou', is_aligned=False, scale=1., dtype=None): 23 | """2D Overlaps (e.g. IoUs, GIoUs) Calculator.""" 24 | 25 | """Calculate IoU between 2D bboxes. 26 | 27 | Args: 28 | bboxes1 (Tensor): bboxes have shape (m, 4) in 29 | format, or shape (m, 5) in format. 30 | bboxes2 (Tensor): bboxes have shape (m, 4) in 31 | format, shape (m, 5) in format, or be 32 | empty. If ``is_aligned `` is ``True``, then m and n must be 33 | equal. 34 | mode (str): "iou" (intersection over union), "iof" (intersection 35 | over foreground), or "giou" (generalized intersection over 36 | union). 37 | is_aligned (bool, optional): If True, then m and n must be equal. 38 | Default False. 
39 | 40 | Returns: 41 | Tensor: shape (m, n) if ``is_aligned `` is False else shape (m,) 42 | """ 43 | assert bboxes1.size(-1) in [0, 4, 5] 44 | assert bboxes2.size(-1) in [0, 4, 5] 45 | if bboxes2.size(-1) == 5: 46 | bboxes2 = bboxes2[..., :4] 47 | if bboxes1.size(-1) == 5: 48 | bboxes1 = bboxes1[..., :4] 49 | 50 | if dtype == 'fp16': 51 | # change tensor type to save cpu and cuda memory and keep speed 52 | bboxes1 = cast_tensor_type(bboxes1, scale, dtype) 53 | bboxes2 = cast_tensor_type(bboxes2, scale, dtype) 54 | overlaps = bbox_overlaps(bboxes1, bboxes2, mode, is_aligned) 55 | if not overlaps.is_cuda and overlaps.dtype == torch.float16: 56 | # resume cpu float32 57 | overlaps = overlaps.float() 58 | return overlaps 59 | 60 | return bbox_overlaps(bboxes1, bboxes2, mode, is_aligned) 61 | 62 | 63 | def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6): 64 | """Calculate overlap between two set of bboxes. 65 | 66 | FP16 Contributed by https://github.com/open-mmlab/mmdetection/pull/4889 67 | Note: 68 | Assume bboxes1 is M x 4, bboxes2 is N x 4, when mode is 'iou', 69 | there are some new generated variable when calculating IOU 70 | using bbox_overlaps function: 71 | 72 | 1) is_aligned is False 73 | area1: M x 1 74 | area2: N x 1 75 | lt: M x N x 2 76 | rb: M x N x 2 77 | wh: M x N x 2 78 | overlap: M x N x 1 79 | union: M x N x 1 80 | ious: M x N x 1 81 | 82 | Total memory: 83 | S = (9 x N x M + N + M) * 4 Byte, 84 | 85 | When using FP16, we can reduce: 86 | R = (9 x N x M + N + M) * 4 / 2 Byte 87 | R large than (N + M) * 4 * 2 is always true when N and M >= 1. 88 | Obviously, N + M <= N * M < 3 * N * M, when N >=2 and M >=2, 89 | N + 1 < 3 * N, when N or M is 1. 90 | 91 | Given M = 40 (ground truth), N = 400000 (three anchor boxes 92 | in per grid, FPN, R-CNNs), 93 | R = 275 MB (one times) 94 | 95 | A special case (dense detection), M = 512 (ground truth), 96 | R = 3516 MB = 3.43 GB 97 | 98 | When the batch size is B, reduce: 99 | B x R 100 | 101 | Therefore, CUDA memory runs out frequently. 102 | 103 | Experiments on GeForce RTX 2080Ti (11019 MiB): 104 | 105 | | dtype | M | N | Use | Real | Ideal | 106 | |:----:|:----:|:----:|:----:|:----:|:----:| 107 | | FP32 | 512 | 400000 | 8020 MiB | -- | -- | 108 | | FP16 | 512 | 400000 | 4504 MiB | 3516 MiB | 3516 MiB | 109 | | FP32 | 40 | 400000 | 1540 MiB | -- | -- | 110 | | FP16 | 40 | 400000 | 1264 MiB | 276MiB | 275 MiB | 111 | 112 | 2) is_aligned is True 113 | area1: N x 1 114 | area2: N x 1 115 | lt: N x 2 116 | rb: N x 2 117 | wh: N x 2 118 | overlap: N x 1 119 | union: N x 1 120 | ious: N x 1 121 | 122 | Total memory: 123 | S = 11 x N * 4 Byte 124 | 125 | When using FP16, we can reduce: 126 | R = 11 x N * 4 / 2 Byte 127 | 128 | So do the 'giou' (large than 'iou'). 129 | 130 | Time-wise, FP16 is generally faster than FP32. 131 | 132 | When gpu_assign_thr is not -1, it takes more time on cpu 133 | but not reduce memory. 134 | There, we can reduce half the memory and keep the speed. 135 | 136 | If ``is_aligned`` is ``False``, then calculate the overlaps between each 137 | bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned 138 | pair of bboxes1 and bboxes2. 139 | 140 | Args: 141 | bboxes1 (Tensor): shape (B, m, 4) in format or empty. 142 | bboxes2 (Tensor): shape (B, n, 4) in format or empty. 143 | B indicates the batch dim, in shape (B1, B2, ..., Bn). 144 | If ``is_aligned`` is ``True``, then m and n must be equal. 
145 | mode (str): "iou" (intersection over union), "iof" (intersection over 146 | foreground) or "giou" (generalized intersection over union). 147 | Default "iou". 148 | is_aligned (bool, optional): If True, then m and n must be equal. 149 | Default False. 150 | eps (float, optional): A value added to the denominator for numerical 151 | stability. Default 1e-6. 152 | 153 | Returns: 154 | Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,) 155 | 156 | Example: 157 | >>> bboxes1 = torch.FloatTensor([ 158 | >>> [0, 0, 10, 10], 159 | >>> [10, 10, 20, 20], 160 | >>> [32, 32, 38, 42], 161 | >>> ]) 162 | >>> bboxes2 = torch.FloatTensor([ 163 | >>> [0, 0, 10, 20], 164 | >>> [0, 10, 10, 19], 165 | >>> [10, 10, 20, 20], 166 | >>> ]) 167 | >>> overlaps = bbox_overlaps(bboxes1, bboxes2) 168 | >>> assert overlaps.shape == (3, 3) 169 | >>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True) 170 | >>> assert overlaps.shape == (3, ) 171 | 172 | Example: 173 | >>> empty = torch.empty(0, 4) 174 | >>> nonempty = torch.FloatTensor([[0, 0, 10, 9]]) 175 | >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1) 176 | >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0) 177 | >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0) 178 | """ 179 | 180 | assert mode in ['iou', 'iof', 'giou'], f'Unsupported mode {mode}' 181 | # Either the boxes are empty or the length of boxes' last dimension is 4 182 | assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0) 183 | assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0) 184 | 185 | # Batch dim must be the same 186 | # Batch dim: (B1, B2, ... Bn) 187 | assert bboxes1.shape[:-2] == bboxes2.shape[:-2] 188 | batch_shape = bboxes1.shape[:-2] 189 | 190 | rows = bboxes1.size(-2) 191 | cols = bboxes2.size(-2) 192 | if is_aligned: 193 | assert rows == cols 194 | 195 | if rows * cols == 0: 196 | if is_aligned: 197 | return bboxes1.new(batch_shape + (rows, )) 198 | else: 199 | return bboxes1.new(batch_shape + (rows, cols)) 200 | 201 | area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * ( 202 | bboxes1[..., 3] - bboxes1[..., 1]) 203 | area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * ( 204 | bboxes2[..., 3] - bboxes2[..., 1]) 205 | 206 | if is_aligned: 207 | lt = torch.max(bboxes1[..., :2], bboxes2[..., :2]) # [B, rows, 2] 208 | rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:]) # [B, rows, 2] 209 | 210 | wh = fp16_clamp(rb - lt, min=0) 211 | overlap = wh[..., 0] * wh[..., 1] 212 | 213 | if mode in ['iou', 'giou']: 214 | union = area1 + area2 - overlap 215 | else: 216 | union = area1 217 | if mode == 'giou': 218 | enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2]) 219 | enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:]) 220 | else: 221 | lt = torch.max(bboxes1[..., :, None, :2], 222 | bboxes2[..., None, :, :2]) # [B, rows, cols, 2] 223 | rb = torch.min(bboxes1[..., :, None, 2:], 224 | bboxes2[..., None, :, 2:]) # [B, rows, cols, 2] 225 | 226 | wh = fp16_clamp(rb - lt, min=0) 227 | overlap = wh[..., 0] * wh[..., 1] 228 | 229 | if mode in ['iou', 'giou']: 230 | union = area1[..., None] + area2[..., None, :] - overlap 231 | else: 232 | union = area1[..., None] 233 | if mode == 'giou': 234 | enclosed_lt = torch.min(bboxes1[..., :, None, :2], 235 | bboxes2[..., None, :, :2]) 236 | enclosed_rb = torch.max(bboxes1[..., :, None, 2:], 237 | bboxes2[..., None, :, 2:]) 238 | 239 | eps = union.new_tensor([eps]) 240 | union = torch.max(union, eps) 241 | ious = overlap / union 242 | if mode in ['iou', 'iof']: 243 | return ious 244 | # 
calculate gious 245 | enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min=0) 246 | enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1] 247 | enclose_area = torch.max(enclose_area, eps) 248 | gious = ious - (enclose_area - union) / enclose_area 249 | return gious 250 | -------------------------------------------------------------------------------- /src/luxonis_train/utils/assigners/tal_assigner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .assigner_utils import select_candidates_in_gts, select_highest_overlaps, iou_calculator 5 | 6 | class TaskAlignedAssigner(nn.Module): 7 | def __init__(self, 8 | topk=13, 9 | num_classes=80, 10 | alpha=1.0, 11 | beta=6.0, 12 | eps=1e-9): 13 | super(TaskAlignedAssigner, self).__init__() 14 | self.topk = topk 15 | self.num_classes = num_classes 16 | self.bg_idx = num_classes 17 | self.alpha = alpha 18 | self.beta = beta 19 | self.eps = eps 20 | 21 | @torch.no_grad() 22 | def forward(self, 23 | pd_scores, 24 | pd_bboxes, 25 | anc_points, 26 | gt_labels, 27 | gt_bboxes, 28 | mask_gt): 29 | """This code referenced to 30 | https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py 31 | 32 | Args: 33 | pd_scores (Tensor): shape(bs, num_total_anchors, num_classes) 34 | pd_bboxes (Tensor): shape(bs, num_total_anchors, 4) 35 | anc_points (Tensor): shape(num_total_anchors, 2) 36 | gt_labels (Tensor): shape(bs, n_max_boxes, 1) 37 | gt_bboxes (Tensor): shape(bs, n_max_boxes, 4) 38 | mask_gt (Tensor): shape(bs, n_max_boxes, 1) 39 | Returns: 40 | target_labels (Tensor): shape(bs, num_total_anchors) 41 | target_bboxes (Tensor): shape(bs, num_total_anchors, 4) 42 | target_scores (Tensor): shape(bs, num_total_anchors, num_classes) 43 | fg_mask (Tensor): shape(bs, num_total_anchors) 44 | """ 45 | self.bs = pd_scores.size(0) 46 | self.n_max_boxes = gt_bboxes.size(1) 47 | 48 | if self.n_max_boxes == 0: 49 | device = gt_bboxes.device 50 | return torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), \ 51 | torch.zeros_like(pd_bboxes).to(device), \ 52 | torch.zeros_like(pd_scores).to(device), \ 53 | torch.zeros_like(pd_scores[..., 0]).to(device) 54 | 55 | 56 | mask_pos, align_metric, overlaps = self.get_pos_mask( 57 | pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt) 58 | 59 | target_gt_idx, fg_mask, mask_pos = select_highest_overlaps( 60 | mask_pos, overlaps, self.n_max_boxes) 61 | 62 | # assigned target 63 | target_labels, target_bboxes, target_scores = self.get_targets( 64 | gt_labels, gt_bboxes, target_gt_idx, fg_mask) 65 | 66 | # normalize 67 | align_metric *= mask_pos 68 | pos_align_metrics = align_metric.max(axis=-1, keepdim=True)[0] 69 | pos_overlaps = (overlaps * mask_pos).max(axis=-1, keepdim=True)[0] 70 | norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).max(-2)[0].unsqueeze(-1) 71 | target_scores = target_scores * norm_align_metric 72 | 73 | return target_labels, target_bboxes, target_scores, fg_mask.bool() 74 | 75 | def get_pos_mask(self, 76 | pd_scores, 77 | pd_bboxes, 78 | gt_labels, 79 | gt_bboxes, 80 | anc_points, 81 | mask_gt): 82 | 83 | # get anchor_align metric 84 | align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes) 85 | # get in_gts mask 86 | mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes) 87 | # get topk_metric mask 88 | mask_topk = self.select_topk_candidates( 89 | align_metric * 
mask_in_gts, topk_mask=mask_gt.repeat([1, 1, self.topk]).bool()) 90 | # merge all mask to a final mask 91 | mask_pos = mask_topk * mask_in_gts * mask_gt 92 | 93 | return mask_pos, align_metric, overlaps 94 | 95 | def get_box_metrics(self, 96 | pd_scores, 97 | pd_bboxes, 98 | gt_labels, 99 | gt_bboxes): 100 | 101 | pd_scores = pd_scores.permute(0, 2, 1) 102 | gt_labels = gt_labels.to(torch.long) 103 | ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) 104 | ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes) 105 | ind[1] = gt_labels.squeeze(-1) 106 | bbox_scores = pd_scores[ind[0], ind[1]] 107 | 108 | overlaps = iou_calculator(gt_bboxes, pd_bboxes) 109 | align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta) 110 | 111 | return align_metric, overlaps 112 | 113 | def select_topk_candidates(self, 114 | metrics, 115 | largest=True, 116 | topk_mask=None): 117 | 118 | num_anchors = metrics.shape[-1] 119 | topk_metrics, topk_idxs = torch.topk( 120 | metrics, self.topk, axis=-1, largest=largest) 121 | if topk_mask is None: 122 | topk_mask = (topk_metrics.max(axis=-1, keepdim=True) > self.eps).tile( 123 | [1, 1, self.topk]) 124 | topk_idxs = torch.where(topk_mask, topk_idxs, torch.zeros_like(topk_idxs)) 125 | is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2) 126 | is_in_topk = torch.where(is_in_topk > 1, 127 | torch.zeros_like(is_in_topk), is_in_topk) 128 | return is_in_topk.to(metrics.dtype) 129 | 130 | def get_targets(self, 131 | gt_labels, 132 | gt_bboxes, 133 | target_gt_idx, 134 | fg_mask): 135 | 136 | # assigned target labels 137 | batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[...,None] 138 | target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes 139 | target_labels = gt_labels.long().flatten()[target_gt_idx] 140 | 141 | # assigned target boxes 142 | target_bboxes = gt_bboxes.reshape([-1, 4])[target_gt_idx] 143 | 144 | # assigned target scores 145 | target_labels[target_labels<0] = 0 146 | target_scores = F.one_hot(target_labels, self.num_classes) 147 | fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) 148 | target_scores = torch.where(fg_scores_mask > 0, target_scores, 149 | torch.full_like(target_scores, 0)) 150 | 151 | return target_labels, target_bboxes, target_scores 152 | -------------------------------------------------------------------------------- /src/luxonis_train/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from pytorch_lightning.callbacks import RichProgressBar 3 | from rich.table import Table 4 | 5 | class LuxonisProgressBar(RichProgressBar): 6 | """ Custom rich text progress bar based on RichProgressBar from Pytorch Lightning""" 7 | def __init__(self): 8 | # TODO: play with values to create custom output 9 | # from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme 10 | # progress_bar = RichProgressBar( 11 | # theme = RichProgressBarTheme( 12 | # description="green_yellow", 13 | # progress_bar="green1", 14 | # progress_bar_finished="green1", 15 | # batch_progress="green_yellow", 16 | # time="gray82", 17 | # processing_speed="grey82", 18 | # metrics="yellow1" 19 | # ) 20 | # ) 21 | 22 | super().__init__(leave=True) 23 | 24 | def print_single_line(self, text:str): 25 | self._console.print(f"[magenta]{text}[/magenta]") 26 | 27 | def get_metrics(self, trainer, pl_module): 28 | # NOTE: there might be a cleaner way of doing this 29 | items = 
super().get_metrics(trainer, pl_module)
30 | if trainer.training:
31 | items["Loss"] = pl_module.training_step_outputs[-1]["loss"].item()
32 | return items
33 | 
34 | def print_results(self, stage: str, loss: float, metrics: dict):
35 | """ Prints results to the console using rich text"""
36 | 
37 | self._console.rule(stage, style="bold magenta")
38 | self._console.print(f"[bold magenta]Loss:[/bold magenta] [white]{loss}[/white]")
39 | self._console.print(f"[bold magenta]Metrics:[/bold magenta]")
40 | for head in metrics:
41 | table = Table(show_header=True, header_style="bold magenta")
42 | table.add_column("Metric name", style="magenta")
43 | table.add_column(head)
44 | for metric_name in metrics[head]:
45 | value = "{:.5f}".format(metrics[head][metric_name].cpu().item())
46 | table.add_row(metric_name, value)
47 | self._console.print(table)
48 | self._console.rule(style="bold magenta")
49 | 
50 | 
51 | class TestOnTrainEnd(pl.Callback):
52 | """ Callback that runs a test pass on pl_module when training ends """
53 | def on_train_end(self, trainer, pl_module):
54 | from torch.utils.data import DataLoader
55 | from luxonis_ml.data import LuxonisDataset
56 | from luxonis_ml.loader import LuxonisLoader, ValAugmentations
57 | from luxonis_train.utils.config import Config
58 | 
59 | cfg = Config()
60 | with LuxonisDataset(
61 | team_id=cfg.get("dataset.team_id"),
62 | dataset_id=cfg.get("dataset.dataset_id"),
63 | bucket_type=cfg.get("dataset.bucket_type"),
64 | override_bucket_type=cfg.get("dataset.override_bucket_type")
65 | ) as dataset:
66 | loader_test = LuxonisLoader(
67 | dataset,
68 | view=cfg.get("dataset.test_view"),
69 | augmentations=ValAugmentations(
70 | image_size=cfg.get("train.preprocessing.train_image_size"),
71 | augmentations=cfg.get("train.preprocessing.augmentations"),
72 | train_rgb=cfg.get("train.preprocessing.train_rgb"),
73 | keep_aspect_ratio=cfg.get("train.preprocessing.keep_aspect_ratio")
74 | )
75 | )
76 | pytorch_loader_test = DataLoader(
77 | loader_test,
78 | batch_size=cfg.get("train.batch_size"),
79 | num_workers=cfg.get("train.num_workers"),
80 | collate_fn=loader_test.collate_fn
81 | )
82 | trainer.test(pl_module, pytorch_loader_test)
83 | 
84 | 
85 | class ExportOnTrainEnd(pl.Callback):
86 | """ Callback that performs export on train end with best weights according to the validation loss """
87 | def on_train_end(self, trainer, pl_module):
88 | from luxonis_train.core import Exporter
89 | 
90 | model_checkpoint_callbacks = [
91 | c for c in trainer.callbacks if isinstance(c, pl.callbacks.ModelCheckpoint)
92 | ]
93 | # NOTE: assume that first checkpoint callback is based on val loss
94 | best_model_path = model_checkpoint_callbacks[0].best_model_path
95 | 
96 | # override export_weights path with path to currently best weights
97 | override = f"exporter.export_weights {best_model_path}"
98 | exporter = Exporter(
99 | cfg="", # singleton instance already present
100 | args={"override": override}
101 | )
102 | exporter.export()
--------------------------------------------------------------------------------
/src/luxonis_train/utils/constants.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | 
3 | class LabelType(str, Enum):
4 | CLASSIFICATION = "class"
5 | SEGMENTATION = "segmentation"
6 | BOUNDINGBOX = "bbox"
7 | KEYPOINT = "keypoints"
8 | 
9 | class HeadType(Enum):
10 | CLASSIFICATION = 1
11 | MULTI_LABEL_CLASSIFICATION = 2
12 | SEMANTIC_SEGMENTATION = 3
13 | INSTANCE_SEGMENTATION = 4
14 | 
OBJECT_DETECTION = 5 15 | KEYPOINT_DETECTION = 6 -------------------------------------------------------------------------------- /src/luxonis_train/utils/general.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | def make_divisible(x: int, divisor: int): 5 | """ Rounds x up to the nearest value that is evenly divisible by the divisor. """ 6 | return math.ceil(x / divisor) * divisor 7 | 8 | def dummy_input_run(module: torch.nn.Module, input_shape: list, multi_input: bool = False): 9 | """ Runs a dummy input through the module and returns the output shape(s)""" 10 | module.eval() 11 | if multi_input: 12 | input = [torch.zeros(i) for i in input_shape] 13 | else: 14 | input = torch.zeros(input_shape) 15 | 16 | out = module(input) 17 | module.train() 18 | if isinstance(out, list): 19 | shapes = [] 20 | for o in out: 21 | shapes.append(list(o.shape)) 22 | return shapes 23 | else: 24 | return [list(out.shape)] -------------------------------------------------------------------------------- /src/luxonis_train/utils/losses/README.md: -------------------------------------------------------------------------------- 1 | ## List of supported loss functions 2 | - CrossEntropyLoss ([source](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)) 3 | - Params: Can be seen in the source 4 | - BCEWithLogitsLoss ([source](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html)) 5 | - Params: Can be seen in the source 6 | - FocalLoss ([source](https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook)) 7 | - Params: 8 | - alpha: float 9 | - gamma: float 10 | - YoloV6Loss (adapted from [here](https://github.com/meituan/YOLOv6/blob/725913050e15a31cd091dfd7795a1891b0524d35/yolov6/models/loss.py)) 11 | - Params: 12 | - n_classes: int # should be the same as in the head 13 | - iou_type: str # giou, diou, ciou or siou 14 | - loss_weight: dict 15 | - other parameters should be left at their defaults -------------------------------------------------------------------------------- /src/luxonis_train/utils/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import init_loss 2 | 3 | __all__ = [ 4 | "init_loss", 5 | ] -------------------------------------------------------------------------------- /src/luxonis_train/utils/losses/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class CrossEntropyLoss(nn.Module): 6 | def __init__(self, **kwargs): 7 | super(CrossEntropyLoss, self).__init__() 8 | self.n_classes = kwargs.get("n_classes") 9 | loss_dict = kwargs 10 | loss_dict.pop("n_classes", None) 11 | loss_dict.pop("head_attributes", None) 12 | self.criterion = nn.CrossEntropyLoss( 13 | **loss_dict 14 | ) 15 | 16 | def forward(self, preds, labels, **kwargs): 17 | if labels.ndim == 4: 18 | # target should be of size (N,...)
19 | labels = labels.argmax(dim=1) 20 | return self.criterion(preds, labels) 21 | 22 | class BCEWithLogitsLoss(nn.Module): 23 | def __init__(self, **kwargs): 24 | super(BCEWithLogitsLoss, self).__init__() 25 | self.n_classes = kwargs.get("n_classes") 26 | loss_dict = kwargs 27 | loss_dict.pop("n_classes", None) 28 | loss_dict.pop("head_attributes", None) 29 | self.criterion = nn.BCEWithLogitsLoss( 30 | **loss_dict 31 | ) 32 | 33 | def forward(self, preds, labels, **kwargs): 34 | return self.criterion(preds, labels) 35 | 36 | class FocalLoss(nn.Module): 37 | # Source: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook 38 | def __init__(self, alpha=0.8, gamma=2, **kwargs): 39 | super(FocalLoss, self).__init__() 40 | self.alpha = alpha 41 | self.gamma = gamma 42 | self.use_sigmoid = kwargs.get("use_sigmoid", True) 43 | 44 | def forward(self, inputs, targets, **kwargs): 45 | 46 | if self.use_sigmoid: 47 | inputs = torch.sigmoid(inputs) 48 | 49 | #flatten label and prediction tensors 50 | inputs = inputs.view(-1).to(torch.float32) 51 | targets = targets.view(-1).to(torch.float32) 52 | 53 | #first compute binary cross-entropy 54 | BCE = F.binary_cross_entropy(inputs, targets, reduction='mean') 55 | BCE_EXP = torch.exp(-BCE) 56 | focal_loss = self.alpha * (1-BCE_EXP)**self.gamma * BCE 57 | 58 | return focal_loss 59 | 60 | class SegmentationLoss(nn.Module): 61 | 62 | def __init__(self, n_classes, alpha=4.0, gamma=2.0, **kwargs): 63 | super(SegmentationLoss, self).__init__() 64 | 65 | self.bce = nn.BCELoss(reduction="none") 66 | self.nc = n_classes 67 | self.alpha = alpha # currently not used 68 | self.gamma = gamma 69 | 70 | def focal_loss(self, logits, labels): 71 | 72 | epsilon = 1.e-9 73 | 74 | # Focal loss 75 | fl = - (labels * torch.log(logits + epsilon)) * (1. 
- logits) ** self.gamma 76 | fl = fl.sum(1) # Sum focal loss along channel dimension 77 | 78 | # Return mean of the focal loss along spatial 79 | return fl.mean([1,2]) 80 | 81 | def forward(self, predictions, targets, **kwargs): 82 | 83 | predictions = torch.nn.functional.softmax(predictions, dim=1) 84 | 85 | bs = predictions.shape[0] 86 | ps = predictions.view(bs, -1) 87 | ts = targets.view(bs, -1) 88 | 89 | lseg = self.bce(ps, ts.float()).mean(1) 90 | 91 | # focal 92 | fcl = self.focal_loss(predictions.clone(), targets.clone()) 93 | 94 | # iou 95 | preds = torch.argmax(predictions, dim = 1) 96 | preds = torch.unsqueeze(preds, 1) 97 | 98 | targets = torch.argmax(targets, dim = 1) 99 | masks = torch.unsqueeze(targets, 1) 100 | 101 | ious = torch.zeros(preds.shape[0], device=predictions.device) 102 | present_classes = torch.zeros(preds.shape[0], device=predictions.device) 103 | 104 | for cls in range(0,self.nc+1): 105 | masks_c = masks == cls 106 | outputs_c = preds == cls 107 | TP = torch.sum(torch.logical_and(masks_c, outputs_c), dim = [1, 2, 3])#.cpu() 108 | FP = torch.sum(torch.logical_and(torch.logical_not(masks_c), outputs_c), dim = [1, 2, 3])#.cpu() 109 | FN = torch.sum(torch.logical_and(masks_c, torch.logical_not(outputs_c)), dim = [1, 2, 3])#.cpu() 110 | ious += torch.nan_to_num(TP / (TP + FP + FN)) 111 | present_classes += (masks.view(preds.shape[0], -1) == cls).any(dim = 1)#.cpu() 112 | 113 | iou = ious / present_classes 114 | 115 | liou = 1 - iou 116 | 117 | return (lseg + liou + fcl).mean() 118 | -------------------------------------------------------------------------------- /src/luxonis_train/utils/losses/utils.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | from .yolov6_loss import YoloV6Loss 3 | from .yolov7_pose_loss import YoloV7PoseLoss 4 | 5 | def init_loss(name, **kwargs): 6 | """ Initializes and returns loss based on provided name and config""" 7 | return eval(name)(**kwargs) -------------------------------------------------------------------------------- /src/luxonis_train/utils/losses/yolov7_pose_loss.py: -------------------------------------------------------------------------------- 1 | # 2 | # Adapted from: https://github.com/WongKinYiu/yolov7 3 | # License: https://github.com/WongKinYiu/yolov7/blob/main/LICENSE.md 4 | # 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from luxonis_train.utils.boxutils import bbox_iou 10 | from luxonis_train.utils.losses.common import FocalLoss, BCEWithLogitsLoss 11 | 12 | 13 | def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441 14 | # return positive, negative label smoothing BCE targets 15 | return 1.0 - 0.5 * eps, 0.5 * eps 16 | 17 | class YoloV7PoseLoss(nn.Module): 18 | # Compute losses 19 | def __init__(self, n_classes, cls_pw=1.0, obj_pw=1.0, gamma=2, alpha=0.25, label_smoothing=0.0, 20 | box_weight=0.05, kpt_weight=0.10, kptv_weight=0.6, cls_weight=0.6, obj_weight=0.7, 21 | anchor_t=4.0, **kwargs): 22 | super().__init__() 23 | 24 | self.box_weight = box_weight 25 | self.kpt_weight = kpt_weight 26 | self.cls_weight = cls_weight 27 | self.obj_weight = obj_weight 28 | self.kptv_weight = kptv_weight 29 | self.anchor_t = anchor_t 30 | 31 | # Define criteria 32 | self.BCEcls = BCEWithLogitsLoss(pos_weight=torch.tensor([cls_pw])) 33 | self.BCEobj = BCEWithLogitsLoss(pos_weight=torch.tensor([obj_pw])) 34 | self.focal_loss = FocalLoss(alpha=alpha, gamma=gamma, use_sigmoid=False) 35 | 36 | # Class label smoothing 
https://arxiv.org/pdf/1902.04103.pdf eqn 3 37 | # positive, negative BCE targets 38 | self.cp, self.cn = smooth_BCE(eps=label_smoothing) 39 | 40 | head_attributes = kwargs.get("head_attributes") 41 | 42 | self.balance = { 43 | 3: [4.0, 1.0, 0.4] 44 | }.get(head_attributes.get("nl"), [4.0, 1.0, 0.25, 0.06, .02]) # P3-P7 45 | 46 | self.ssi = 0 # stride 16 index 47 | self.gr = head_attributes.get("gr") 48 | self.nkpt = head_attributes.get("n_keypoints") 49 | self.nc = n_classes 50 | self.na = head_attributes.get("na") 51 | self.nl = head_attributes.get("nl") 52 | self.anchors = head_attributes.get("anchors") 53 | 54 | def forward(self, kpt_pred, kpt, **kwargs): 55 | # model output is (kpt, features). The loss only needs features. 56 | kpt_pred = kpt_pred[1] 57 | kpt_pred[0].shape[0] # batch size 58 | device = kpt_pred[0].device 59 | lcls, lbox, lobj, lkpt, lkptv = [ 60 | torch.zeros(1, device=device) for _ in range(5) 61 | ] 62 | tcls, tbox, tkpt, indices, anchors = self.build_targets(kpt_pred, kpt) 63 | kpt = kpt.to(device) 64 | 65 | # Losses 66 | for i, pi in enumerate(kpt_pred): # layer index, layer predictions 67 | b, a, gj, gi = indices[i] # image, anchor, gridy, gridx 68 | tobj = torch.zeros_like(pi[..., 0], device=device) # target obj 69 | 70 | n = b.shape[0] # number of targets 71 | if n: 72 | ps = pi[b, a, gj, gi] # prediction subset corresponding to targets 73 | 74 | # Regression 75 | pxy = ps[:, :2].sigmoid() * 2. - 0.5 76 | pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i].to(device) 77 | pbox = torch.cat((pxy, pwh), 1) 78 | iou = bbox_iou(pbox.T, tbox[i].to(device), box_format="xywh", iou_type="ciou") 79 | lbox += (1.0 - iou).mean() # iou loss 80 | #Direct kpt prediction 81 | pkpt_x = ps[:, 5 + self.nc::3] * 2. - 0.5 82 | pkpt_y = ps[:, 6 + self.nc::3] * 2. 
- 0.5 83 | pkpt_score = ps[:, 7 + self.nc::3] 84 | #mask 85 | tkpt[i] = tkpt[i].to(device) 86 | kpt_mask = (tkpt[i][:, 0::2] != 0) 87 | lkptv += self.BCEcls(pkpt_score, kpt_mask.float()) 88 | d = (pkpt_x-tkpt[i][:,0::2]) ** 2 + (pkpt_y - tkpt[i][:, 1::2]) ** 2 89 | kpt_loss_factor = (torch.sum(kpt_mask != 0) + torch.sum(kpt_mask == 0) 90 | ) / torch.sum(kpt_mask != 0) 91 | lkpt += kpt_loss_factor * (torch.log(d + 1 + 1e-9) * kpt_mask).mean() 92 | # Objectness 93 | tobj[b, a, gj, gi] = ( 94 | (1.0 - self.gr) 95 | + self.gr * iou.detach().clamp(0).type(tobj.dtype) 96 | ) 97 | 98 | # Classification 99 | if self.nc > 1: # cls loss (only if multiple classes) 100 | t = torch.full_like(ps[:, 5:5+self.nc], self.cn, device=device) 101 | t[range(n), tcls[i]] = self.cp 102 | lcls += self.BCEcls(ps[:, 5:5+self.nc], t) # BCE 103 | 104 | 105 | obji = self.BCEobj(pi[..., 4], tobj) 106 | lobj += obji * self.balance[i] # obj loss 107 | lbox *= self.box_weight 108 | lobj *= self.obj_weight 109 | lcls *= self.cls_weight 110 | lkptv *= self.kptv_weight 111 | lkpt *= self.kpt_weight 112 | 113 | loss = (lbox + lobj + lcls + lkpt + lkptv).reshape([]) 114 | 115 | sub_losses = { 116 | "lbox": lbox.detach(), 117 | "lobj": lobj.detach(), 118 | "lcls": lcls.detach(), 119 | "lkptv": lkptv.detach(), 120 | "lkpt": lkpt.detach() 121 | } 122 | 123 | return loss, sub_losses 124 | 125 | def build_targets(self, p, targets): 126 | na, nt = self.na, targets.shape[0] # number of anchors, targets 127 | tcls, tbox, tkpt, indices, anch = [], [], [], [], [] 128 | gain_length = 7 + 2 * self.nkpt 129 | gain = torch.ones(gain_length, device=targets.device) 130 | ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) 131 | targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) 132 | 133 | g = 0.5 # bias 134 | off = torch.tensor([[0, 0], 135 | [1, 0], [0, 1], [-1, 0], [0, -1], # j,k,l,m 136 | # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm 137 | ], device=targets.device).float() * g # offsets 138 | 139 | for i in range(self.nl): 140 | anchors = self.anchors[i] 141 | gain[2:gain_length-1] = torch.tensor(p[i].shape)[(2 + self.nkpt)*[3, 2]] 142 | # Match targets to anchors 143 | t = targets * gain 144 | if nt: 145 | # Matches 146 | r = t[:, :, 4:6] / anchors[:, None] # wh ratio 147 | j = torch.max(r, 1. / r).max(2)[0] < self.anchor_t # compare 148 | t = t[j] # filter 149 | 150 | # Offsets 151 | gxy = t[:, 2:4] # grid xy 152 | gxi = gain[[2, 3]] - gxy # inverse 153 | j, k = ((gxy % 1. < g) & (gxy > 1.)).T 154 | l, m = ((gxi % 1. 
< g) & (gxi > 1.)).T 155 | j = torch.stack((torch.ones_like(j), j, k, l, m)) 156 | t = t.repeat((5, 1, 1))[j] 157 | offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j] 158 | else: 159 | t = targets[0] 160 | offsets = 0 161 | 162 | # Define 163 | b, c = t[:, :2].long().T # image, class 164 | gxy = t[:, 2:4] # grid xy 165 | gwh = t[:, 4:6] # grid wh 166 | gij = (gxy - offsets).long() 167 | gi, gj = gij.T # grid xy indices 168 | 169 | # Append 170 | a = t[:, -1].long() # anchor indices 171 | indices.append( 172 | ( 173 | b, a, gj.clamp_(0, gain[3].long() - 1), # type: ignore 174 | gi.clamp_(0, gain[2].long() - 1) # type: ignore 175 | ) 176 | ) 177 | tbox.append(torch.cat((gxy - gij, gwh), 1)) # box 178 | for kpt in range(self.nkpt): 179 | low = 6 + 2 * kpt 180 | high = 6 + 2 * (kpt + 1) 181 | t[:, low: high][t[:, low: high] != 0] -= gij[t[:, low: high] !=0] 182 | tkpt.append(t[:, 6:-1]) 183 | anch.append(anchors[a]) # anchors 184 | tcls.append(c) # class 185 | 186 | return tcls, tbox, tkpt, indices, anch 187 | -------------------------------------------------------------------------------- /src/luxonis_train/utils/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import init_metrics 2 | 3 | __all__ = [ 4 | "init_metrics", 5 | ] -------------------------------------------------------------------------------- /src/luxonis_train/utils/metrics/custom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import Metric 3 | import warnings 4 | 5 | 6 | class ObjectKeypointSimilarity(Metric): 7 | def __init__(self): 8 | super().__init__() 9 | # self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") 10 | # self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 11 | 12 | def update(self, preds: torch.Tensor, target: torch.Tensor): 13 | # preds, target = self._input_format(preds, target) 14 | # assert preds.shape == target.shape 15 | 16 | # self.correct += torch.sum(preds == target) 17 | # self.total += target.numel() 18 | pass 19 | 20 | def compute(self): 21 | # return self.correct.float() / self.total 22 | warnings.warn( 23 | "ObjectKeypointSimilarity metric not yet implemented. Returning default value 1." 
24 | ) 25 | return torch.ones(1) 26 | -------------------------------------------------------------------------------- /src/luxonis_train/utils/metrics/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchmetrics 3 | from torchmetrics.detection.mean_ap import MeanAveragePrecision 4 | 5 | from .custom import ObjectKeypointSimilarity 6 | from luxonis_train.utils.constants import HeadType 7 | 8 | """ 9 | Default average method for different metrics: 10 | Accuracy: micro 11 | Precision: micro 12 | Recall: micro 13 | F1Score: micro 14 | JaccardIndex: macro 15 | """ 16 | 17 | def init_metrics(head: nn.Module): 18 | """ Initializes specific metrics depending on the head type and returns nn.ModuleDict """ 19 | 20 | is_binary = head.n_classes == 1 21 | 22 | metrics = {} 23 | for head_type in head.head_types: 24 | if head_type == HeadType.CLASSIFICATION: 25 | metrics["accuracy"] = torchmetrics.Accuracy(task="binary" if is_binary else "multiclass", 26 | num_classes=head.n_classes) 27 | metrics["precision"] = torchmetrics.Precision(task="binary" if is_binary else "multiclass", 28 | num_classes=head.n_classes) 29 | metrics["recall"] = torchmetrics.Recall(task="binary" if is_binary else "multiclass", 30 | num_classes=head.n_classes) 31 | metrics["f1"] = torchmetrics.F1Score(task="binary" if is_binary else "multiclass", 32 | num_classes=head.n_classes) 33 | elif head_type == HeadType.MULTI_LABEL_CLASSIFICATION: 34 | metrics["accuracy"] = torchmetrics.Accuracy(task="multilabel", num_labels=head.n_classes) 35 | metrics["precision"] = torchmetrics.Precision(task="multilabel", num_labels=head.n_classes) 36 | metrics["recall"] = torchmetrics.Recall(task="multilabel", num_labels=head.n_classes) 37 | metrics["f1"] = torchmetrics.F1Score(task="multilabel", num_labels=head.n_classes) 38 | elif head_type == HeadType.SEMANTIC_SEGMENTATION: 39 | metrics["accuracy"] = torchmetrics.Accuracy(task="binary" if is_binary else "multiclass", 40 | num_classes=head.n_classes, ignore_index=0 if is_binary else None) 41 | metrics["mIoU"] = torchmetrics.JaccardIndex(task="binary" if is_binary else "multiclass", 42 | num_classes=head.n_classes, ignore_index=0 if is_binary else None) 43 | elif head_type == HeadType.OBJECT_DETECTION: 44 | metrics["map"] = MeanAveragePrecision(box_format="xyxy") 45 | elif head_type == HeadType.KEYPOINT_DETECTION: 46 | metrics["oks"] = ObjectKeypointSimilarity() 47 | else: 48 | raise KeyError(f"No metrics for head type = {head_type} are currently supported.") 49 | 50 | collection = torchmetrics.MetricCollection(metrics) 51 | 52 | return nn.ModuleDict({ 53 | "train_metrics": collection, 54 | "val_metrics": collection.clone(), 55 | "test_metrics": collection.clone(), 56 | }) 57 | -------------------------------------------------------------------------------- /src/luxonis_train/utils/optimizers.py: -------------------------------------------------------------------------------- 1 | from torch.optim import * 2 | 3 | def init_optimizer(model_params, name, **kwargs): 4 | """ Initializes and returns optimizer based on provided name and config""" 5 | return eval(name)(params=model_params, **kwargs) -------------------------------------------------------------------------------- /src/luxonis_train/utils/schedulers.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import * 2 | 3 | def init_scheduler(optimizer, name, **kwargs): 4 | """ Initializes and returns 
scheduler based on provided name and config""" 5 | return eval(name)(optimizer=optimizer, **kwargs) -------------------------------------------------------------------------------- /src/luxonis_train/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | import torchvision.transforms.functional as F 5 | from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks, draw_keypoints 6 | from torchvision.ops import box_convert 7 | 8 | from luxonis_train.utils.constants import LabelType 9 | 10 | 11 | def draw_outputs(imgs: torch.Tensor, output: torch.Tensor, head: torch.nn.Module, return_numpy: bool = True, 12 | unnormalize_img: bool = True, cvt_color: bool = False): 13 | """Draw model outputs on a batch of images 14 | 15 | Args: 16 | imgs (torch.Tensor): Batch of images (NCHW format) 17 | output (torch.Tensor): Model output 18 | head (torch.nn.Module): Model head used for drawing 19 | return_numpy (bool, optional): Whether to return images in numpy format (HWC). Defaults to True. 20 | unnormalize_img (bool, optional): Unnormalize the image before drawing on it. Defaults to True. 21 | cvt_color (bool, optional): Convert from BGR to RGB. Defaults to False. 22 | 23 | Returns: 24 | list[Union[torch.Tensor, np.ndarray]]: list of images with visualizations 25 | (either torch tensors in CHW or numpy arrays in HWC format) 26 | """ 27 | 28 | out_imgs = [] 29 | for i in range(imgs.shape[0]): 30 | curr_img = imgs[i] 31 | if unnormalize_img: 32 | curr_img = unnormalize(curr_img, to_uint8=True) 33 | 34 | curr_img = head.draw_output_to_img(curr_img, output, i) 35 | out_imgs.append(curr_img) 36 | 37 | if return_numpy: 38 | out_imgs = [torch_img_to_numpy(i, cvt_color=cvt_color) for i in out_imgs] 39 | 40 | return out_imgs 41 | 42 | def draw_labels(imgs: torch.Tensor, label_dict: dict, label_keys: list = None, return_numpy: bool = True, 43 | unnormalize_img: bool = True, cvt_color: bool = False, overlay: bool = False): 44 | """Draw all present labels on a batch of images 45 | 46 | Args: 47 | imgs (torch.Tensor): Batch of images (NCHW format) 48 | label_dict (dict): Dictionary of present labels 49 | label_keys (list, optional): List of keys for labels to draw; if None, all are used. Defaults to None. 50 | return_numpy (bool, optional): Whether to return images in numpy format (HWC). Defaults to True. 51 | unnormalize_img (bool, optional): Unnormalize the image before drawing on it. Defaults to True. 52 | cvt_color (bool, optional): Convert from BGR to RGB. Defaults to False. 53 | overlay (bool, optional): Draw all labels on the same image. Defaults to False. 
54 | 55 | Returns: 56 | list[Union[torch.Tensor, np.ndarray]]: list of images with visualizations 57 | (either torch tensors in CHW or numpy arrays in HWC format) 58 | """ 59 | 60 | _, _, ih, iw = imgs.shape 61 | out_imgs = [] 62 | 63 | if label_keys is None: 64 | label_keys = list(label_dict.keys()) 65 | 66 | for i in range(imgs.shape[0]): 67 | curr_img = imgs[i] 68 | curr_out_imgs = [] 69 | if unnormalize_img: 70 | curr_img = unnormalize(curr_img, to_uint8=True) 71 | 72 | for label_key in label_keys: 73 | if label_key == LabelType.CLASSIFICATION: 74 | curr_img_class = torch_img_to_numpy(curr_img) 75 | indices = torch.nonzero(label_dict[label_key][i]).flatten().tolist() 76 | curr_label_str = ",".join(str(i) for i in indices) 77 | curr_img_class = cv2.putText(curr_img_class, curr_label_str, (10,30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 2) 78 | curr_img_class = numpy_to_torch_img(curr_img_class) 79 | if overlay: 80 | curr_img = curr_img_class 81 | else: 82 | curr_out_imgs.append(curr_img_class) 83 | 84 | if label_key == LabelType.SEGMENTATION: 85 | curr_label = label_dict[label_key][i] 86 | masks = curr_label.bool() 87 | # NOTE: we have to push everything to cpu manually before draw_segmentation_masks (torchvision bug?) 88 | masks = masks.cpu() 89 | curr_img = curr_img.cpu() 90 | curr_img_seg = draw_segmentation_masks(curr_img, masks, alpha=0.4) 91 | if overlay: 92 | curr_img = curr_img_seg 93 | else: 94 | curr_out_imgs.append(curr_img_seg) 95 | 96 | if label_key == LabelType.BOUNDINGBOX: 97 | curr_label = label_dict[label_key] 98 | curr_label = curr_label[curr_label[:,0]==i] 99 | bboxs = box_convert(curr_label[:,2:], "xywh", "xyxy") 100 | bboxs[:, 0::2] *= iw 101 | bboxs[:, 1::2] *= ih 102 | curr_img_bbox = draw_bounding_boxes(curr_img, bboxs) 103 | if overlay: 104 | curr_img = curr_img_bbox 105 | else: 106 | curr_out_imgs.append(curr_img_bbox) 107 | 108 | if label_key == LabelType.KEYPOINT: 109 | curr_label = label_dict[label_key] 110 | curr_label = curr_label[curr_label[:,0]==i][:,1:] 111 | 112 | keypoints_flat = torch.reshape(curr_label[:,1:], (-1,3)) 113 | keypoints_points = keypoints_flat[:,:2] 114 | keypoints_points[:,0] *= iw 115 | keypoints_points[:,1] *= ih 116 | keypoints_visibility = keypoints_flat[:,2] 117 | 118 | # torchvision expects format [n_instances, K, 2] 119 | n_instances = curr_label.shape[0] 120 | out_keypoints = torch.reshape(keypoints_points, (n_instances, -1, 2)).int() 121 | curr_img_keypoints = draw_keypoints(curr_img, out_keypoints, colors="red") 122 | if overlay: 123 | curr_img = curr_img_keypoints 124 | else: 125 | curr_out_imgs.append(curr_img_keypoints) 126 | 127 | if overlay: 128 | curr_out_imgs = [curr_img] 129 | 130 | if return_numpy: 131 | curr_out_merged = cv2.hconcat( 132 | [torch_img_to_numpy(i, cvt_color=cvt_color) for i in curr_out_imgs] 133 | ) 134 | else: 135 | curr_out_merged = torch.cat(curr_out_imgs, dim=-1) # horizontal concat 136 | 137 | out_imgs.append(curr_out_merged) 138 | return out_imgs 139 | 140 | 141 | def seg_output_to_bool(data: torch.Tensor, binary_threshold: float = 0.5): 142 | """ Converts seg head output to 2D boolean mask for visualization""" 143 | masks = torch.empty_like(data, dtype=torch.bool, device=data.device) 144 | if data.shape[0] == 1: 145 | classes = torch.sigmoid(data) 146 | masks[0] = classes >= binary_threshold 147 | else: 148 | classes = torch.argmax(data, dim=0) 149 | for i in range(masks.shape[0]): 150 | masks[i] = classes == i 151 | return masks 152 | 153 | def unnormalize(img: torch.Tensor, 
original_mean: tuple = (0.485, 0.456, 0.406), 154 | original_std: tuple = (0.229, 0.224, 0.225), to_uint8: bool = False): 155 | """ Unnormalizes the image back to original values; optionally converts it to uint8""" 156 | mean = np.array(original_mean) 157 | std = np.array(original_std) 158 | new_mean = -mean/std 159 | new_std = 1/std 160 | out_img = F.normalize(img, mean=new_mean, std=new_std) 161 | if to_uint8: 162 | out_img = torch.clamp(out_img.mul(255), 0, 255).to(torch.uint8) 163 | return out_img 164 | 165 | def torch_img_to_numpy(img: torch.Tensor, cvt_color: bool = False): 166 | """ Converts torch image (CHW) to numpy array (HWC). Optionally also converts colors. """ 167 | if img.is_floating_point(): 168 | img = img.mul(255).int() 169 | img = torch.clamp(img, 0, 255) 170 | img = np.transpose(img.cpu().numpy().astype(np.uint8), (1, 2, 0)) 171 | img = np.ascontiguousarray(img) 172 | if cvt_color: 173 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 174 | return img 175 | 176 | def numpy_to_torch_img(img: np.ndarray): 177 | """ Converts numpy image (HWC) to torch image (CHW) """ 178 | img = torch.from_numpy(img).permute(2, 0, 1) 179 | return img 180 | -------------------------------------------------------------------------------- /src/tests/test_config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: SimpleClassification 3 | type: 4 | pretrained: 5 | 6 | backbone: 7 | name: MicroNet 8 | pretrained: 9 | 10 | heads: 11 | - name: ClassificationHead 12 | params: 13 | n_classes: null 14 | loss: 15 | name: CrossEntropyLoss 16 | params: 17 | 18 | dataset: 19 | team_id: 2af31474-a342-49c9-8fa4-786ac83a43a3 20 | dataset_id: 64a079d8028d6439d136495d 21 | bucket_type: aws 22 | -------------------------------------------------------------------------------- /src/tests/test_config_fail.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | team_id: 2af31474-a342-49c9-8fa4-786ac83a43a3 3 | dataset_id: 64a079d8028d6439d136495d 4 | bucket_type: aws -------------------------------------------------------------------------------- /tools/export.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from luxonis_train.core import Exporter 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("-cfg", "--config", type=str, required=True, help="Configuration file to use") 7 | parser.add_argument("--override", default=None, type=str, help="Manually override config parameter") 8 | args = parser.parse_args() 9 | args_dict = vars(args) 10 | 11 | exporter = Exporter(args.config, args_dict) 12 | exporter.export() -------------------------------------------------------------------------------- /tools/infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from luxonis_train.core import Inferer 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("-cfg", "--config", type=str, required=True, help="Configuration file to use") 7 | parser.add_argument("--override", default=None, type=str, help="Manually override config parameter") 8 | args = parser.parse_args() 9 | args_dict = vars(args) 10 | 11 | inferer = Inferer(args.config, args_dict) 12 | inferer.infer() -------------------------------------------------------------------------------- /tools/store_config.py: 
-------------------------------------------------------------------------------- 1 | from luxonis_train.utils.config import Config 2 | import mlflow 3 | import argparse 4 | from dotenv import load_dotenv 5 | import boto3 6 | import os 7 | import json 8 | 9 | if __name__ == "__main__": 10 | """ 11 | Stores config as MLFlow artifact and prints run_id 12 | Prerequisites: 13 | - MLFlow parameters configured in configs under `logger` 14 | - .env file with AWS and MLFlow variables 15 | """ 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("-cfg", "--config", type=str, required=True, help="Configuration file to use") 19 | parser.add_argument("--override", default=None, type=str, help="Manually override config parameter") 20 | parser.add_argument("--bucket", type=str, default="luxonis-mlflow", help="S3 bucket name for config upload") 21 | args = parser.parse_args() 22 | args_dict = vars(args) 23 | 24 | load_dotenv() 25 | 26 | cfg = Config(args.config) 27 | if args.override: 28 | cfg.override_config(args.override) 29 | 30 | mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI")) 31 | if cfg.get("logger.project_id") is not None: 32 | cfg.override_config("logger.project_name null") 33 | project_id = cfg.get("logger.project_id") 34 | mlflow.set_experiment( 35 | experiment_name=cfg.get("logger.project_name"), 36 | experiment_id=str(project_id) if project_id is not None else None 37 | ) 38 | 39 | with mlflow.start_run() as run: 40 | s3_client = boto3.client("s3", 41 | aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"), 42 | aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"), 43 | endpoint_url=os.getenv("AWS_S3_ENDPOINT_URL") 44 | ) 45 | tmp_path = "config.json" 46 | with open(tmp_path, "w+") as f: 47 | json.dump(cfg.get_data(), f, indent=4) 48 | 49 | key = run.info.artifact_uri.split(args.bucket+"/")[-1]+"/config.json" 50 | s3_client.upload_file( 51 | Filename="config.json", 52 | Bucket=args.bucket, 53 | Key=key 54 | ) 55 | os.remove(tmp_path) # delete temporary file 56 | 57 | run_id = run.info.run_id 58 | print(f"Config saved as MLFlow artifact. 
Run id: {run_id}") -------------------------------------------------------------------------------- /tools/test_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import matplotlib.pyplot as plt 5 | from dotenv import load_dotenv 6 | from luxonis_ml.data import LuxonisDataset 7 | from luxonis_ml.loader import LuxonisLoader, TrainAugmentations, ValAugmentations 8 | 9 | from luxonis_train.utils.config import Config 10 | from luxonis_train.utils.visualization import draw_labels 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("-cfg", "--config", type=str, required=True, help="Configuration file to use") 15 | parser.add_argument("--override", default=None, type=str, help="Manually override config parameter") 16 | parser.add_argument("--view", type=str, default="val", help="Dataset view to use") 17 | parser.add_argument("--no-display", action="store_true", help="Don't display images") 18 | parser.add_argument("--save-dir", type=str, default=None, help="Path to save directory, by default don't save") 19 | args = parser.parse_args() 20 | 21 | cfg = Config(args.config) 22 | if args.override: 23 | cfg.override_config(args.override) 24 | 25 | load_dotenv() 26 | 27 | image_size = cfg.get("train.preprocessing.train_image_size") 28 | 29 | with LuxonisDataset( 30 | team_id=cfg.get("dataset.team_id"), 31 | dataset_id=cfg.get("dataset.dataset_id"), 32 | bucket_type=cfg.get("dataset.bucket_type"), 33 | override_bucket_type=cfg.get("dataset.override_bucket_type") 34 | ) as dataset: 35 | 36 | augmentations = TrainAugmentations( 37 | image_size=image_size, 38 | augmentations=cfg.get("train.preprocessing.augmentations"), 39 | train_rgb=cfg.get("train.preprocessing.train_rgb"), 40 | keep_aspect_ratio=cfg.get("train.preprocessing.keep_aspect_ratio") 41 | ) if args.view == "train" else ValAugmentations( 42 | image_size=image_size, 43 | augmentations=cfg.get("train.preprocessing.augmentations"), 44 | train_rgb=cfg.get("train.preprocessing.train_rgb"), 45 | keep_aspect_ratio=cfg.get("train.preprocessing.keep_aspect_ratio") 46 | ) 47 | 48 | loader_train = LuxonisLoader( 49 | dataset, 50 | view=args.view, 51 | augmentations=augmentations 52 | ) 53 | pytorch_loader_train = torch.utils.data.DataLoader( 54 | loader_train, 55 | batch_size=4, 56 | num_workers=1, 57 | collate_fn=loader_train.collate_fn 58 | ) 59 | 60 | save_dir = args.save_dir 61 | if save_dir is not None: 62 | os.makedirs(save_dir, exist_ok=True) 63 | 64 | counter = 0 65 | for data in pytorch_loader_train: 66 | imgs, label_dict = data 67 | out_imgs = draw_labels( 68 | imgs = imgs, label_dict = label_dict, 69 | unnormalize_img = cfg.get("train.preprocessing.normalize.active"), 70 | cvt_color = not cfg.get("train.preprocessing.train_rgb") 71 | ) 72 | 73 | for i in out_imgs: 74 | plt.imshow(i) 75 | if save_dir is not None: 76 | counter += 1 77 | plt.savefig(os.path.join(save_dir, f"{counter}.png")) 78 | if not args.no_display: 79 | plt.show() -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from luxonis_train.core import Trainer 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("-cfg", "--config", type=str, required=True, help="Configuration file to use") 7 | parser.add_argument("--override", default=None, 
type=str, help="Manually override config parameter") 8 | args = parser.parse_args() 9 | args_dict = vars(args) 10 | 11 | trainer = Trainer(args.config, args_dict) 12 | trainer.train() 13 | 14 | # Example: train in new thread 15 | # import time 16 | # trainer.train(new_thread=True) 17 | # while True: 18 | # time.sleep(5) 19 | # print(trainer.get_status()) 20 | # print(trainer.get_status_percentage(), trainer.get_save_dir()) 21 | -------------------------------------------------------------------------------- /tools/tune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from luxonis_train.core import Tuner 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("-cfg", "--config", type=str, required=True, help="Configuration file to use") 7 | parser.add_argument("--override", default=None, type=str, help="Manually override config parameter") 8 | args = parser.parse_args() 9 | args_dict = vars(args) 10 | 11 | tuner = Tuner(args.config, args_dict) 12 | tuner.tune() --------------------------------------------------------------------------------
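Note: the init_loss, init_optimizer, and init_scheduler helpers shown above all follow the same pattern: the class is resolved by name with eval() against the star-imported module namespace, and the remaining config entries are forwarded as keyword arguments. A minimal standalone sketch of that pattern follows; the nn.Linear stand-in model and the hyperparameter values are illustrative assumptions, not values taken from the shipped configs.

import torch.nn as nn

from luxonis_train.utils.optimizers import init_optimizer
from luxonis_train.utils.schedulers import init_scheduler

# Hypothetical stand-in for a built luxonis_train model.
model = nn.Linear(10, 2)

# "SGD" and "CosineAnnealingLR" are resolved via eval() inside the helpers, so any
# class exposed by torch.optim / torch.optim.lr_scheduler can be named in the config;
# the keyword arguments mirror what a config's params section would pass.
optimizer = init_optimizer(model.parameters(), "SGD", lr=0.01, momentum=0.9)
scheduler = init_scheduler(optimizer, "CosineAnnealingLR", T_max=10)

The losses work the same way: for example, init_loss("FocalLoss", alpha=0.8, gamma=2) would construct the FocalLoss defined in losses/common.py with those parameters.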