├── .gitattributes ├── .github └── workflows │ ├── build.yml │ ├── test.yml │ └── update-perf-stats.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── example.gif ├── setup.cfg ├── setup.py ├── src └── torchcontentarea │ ├── __init__.py │ ├── _version.py │ ├── csrc │ ├── common.hpp │ ├── cpu_functions.hpp │ ├── cuda_functions.cuh │ ├── implementation.cpp │ ├── implementation.hpp │ ├── python_bindings.cpp │ └── source │ │ ├── find_points_cpu.cpp │ │ ├── find_points_cuda.cu │ │ ├── find_points_from_scores_cpu.cpp │ │ ├── find_points_from_scores_cuda.cu │ │ ├── fit_circle_cpu.cpp │ │ ├── fit_circle_cuda.cu │ │ ├── make_strips_cpu.cpp │ │ └── make_strips_cuda.cu │ ├── extension_wrapper.py │ ├── models │ ├── kernel_1_8.pt │ ├── kernel_2_8.pt │ └── kernel_3_8.pt │ ├── pythonimplementation │ ├── __init__.py │ ├── estimate_area.py │ ├── fit_area.py │ └── get_points.py │ └── utils.py ├── tests ├── __init__.py ├── test_api.py ├── test_performance.py ├── test_utils.py └── utils │ ├── data.py │ ├── profiling.py │ └── scoring.py └── versioneer.py /.gitattributes: -------------------------------------------------------------------------------- 1 | src/torchcontentarea/_version.py export-subst 2 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | push: 8 | branches: 9 | - main 10 | tags: 11 | - 'v*' 12 | 13 | jobs: 14 | build-wheels: 15 | name: Build Wheels 16 | uses: charliebudd/torch-extension-builder/.github/workflows/build-pytorch-extension-wheels.yml@main 17 | with: 18 | build-command: "git config --global --add safe.directory ${GITHUB_WORKSPACE} && python setup.py bdist_wheel" 19 | 20 | test-wheels-locally: 21 | name: Test Wheels Locally 22 | needs: build-wheels 23 | uses: ./.github/workflows/test.yml 24 | with: 25 | local-wheels: true 26 | 27 | 
update-performance-stats: 28 | if: ${{ github.base_ref == 'main' }} 29 | name: Update Performance Stats 30 | needs: test-wheels-locally 31 | uses: ./.github/workflows/update-perf-stats.yml 32 | 33 | publish-wheels-to-testpypi: 34 | if: startsWith(github.ref, 'refs/tags/v') 35 | name: Publish Wheels To TestPyPI 36 | runs-on: ubuntu-latest 37 | needs: test-wheels-locally 38 | steps: 39 | - name: Download Cached Wheels 40 | uses: actions/download-artifact@v3 41 | with: 42 | name: final-wheels 43 | path: dist 44 | 45 | - name: Checkout repo 46 | uses: actions/checkout@v3 47 | 48 | - name: Make Source Distribution 49 | run: python setup.py sdist 50 | 51 | - name: Publish Package to TestPyPI 52 | uses: pypa/gh-action-pypi-publish@release/v1 53 | with: 54 | user: __token__ 55 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 56 | repository_url: https://test.pypi.org/legacy/ 57 | 58 | test-testpypi-release: 59 | if: startsWith(github.ref, 'refs/tags/v') 60 | name: Test TestPyPi Release 61 | needs: publish-wheels-to-testpypi 62 | uses: ./.github/workflows/test.yml 63 | with: 64 | local-wheels: false 65 | wheel-location: https://test.pypi.org/simple/ 66 | 67 | publish-wheels-to-pypi: 68 | if: startsWith(github.ref, 'refs/tags/v') 69 | name: Publish Wheels To PyPI 70 | runs-on: ubuntu-latest 71 | needs: test-testpypi-release 72 | steps: 73 | - name: Download Cached Wheels 74 | uses: actions/download-artifact@v3 75 | with: 76 | name: final-wheels 77 | path: dist 78 | 79 | - name: Checkout repo 80 | uses: actions/checkout@v3 81 | 82 | - name: Make Source Distribution 83 | run: python setup.py sdist 84 | 85 | - name: Publish Package to PyPI 86 | uses: pypa/gh-action-pypi-publish@release/v1 87 | with: 88 | user: __token__ 89 | password: ${{ secrets.PYPI_API_TOKEN }} 90 | 91 | test-pypi-release: 92 | if: startsWith(github.ref, 'refs/tags/v') 93 | name: Test PyPi Release 94 | needs: publish-wheels-to-pypi 95 | uses: ./.github/workflows/test.yml 96 | with: 97 | local-wheels: false 
98 | wheel-location: https://pypi.org/simple/ 99 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test Release 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | local-wheels: 7 | type: boolean 8 | default: true 9 | wheel-location: 10 | type: string 11 | default: final-wheels 12 | python-versions: 13 | type: string 14 | default: "[3.6, 3.7, 3.8, 3.9]" 15 | pytorch-versions: 16 | type: string 17 | default: "[1.9, '1.10', 1.11, 1.12]" 18 | cuda-versions: 19 | type: string 20 | default: "[10.2, 11.3, 11.6]" 21 | 22 | jobs: 23 | test-release: 24 | name: Test Release 25 | runs-on: [self-hosted, Linux, X64, gpu] 26 | strategy: 27 | fail-fast: false 28 | matrix: 29 | python-version: ${{ fromJson(inputs.python-versions) }} 30 | pytorch-version: ${{ fromJson(inputs.pytorch-versions) }} 31 | cuda-version: ${{ fromJson(inputs.cuda-versions) }} 32 | exclude: 33 | - pytorch-version: 1.9 34 | cuda-version: 11.3 35 | - python-version: 3.6 36 | pytorch-version: 1.11 37 | - python-version: 3.6 38 | pytorch-version: 1.12 39 | - pytorch-version: 1.9 40 | cuda-version: 11.6 41 | - pytorch-version: '1.10' 42 | cuda-version: 11.6 43 | - pytorch-version: 1.11 44 | cuda-version: 11.6 45 | steps: 46 | - name: Checkout Repository 47 | uses: actions/checkout@v2 48 | 49 | - name: Setup Python 50 | uses: actions/setup-python@v3 51 | with: 52 | python-version: ${{ matrix.python-version }} 53 | 54 | - name: Install Test Requirements 55 | run: | 56 | ENV=../.venv-${{ matrix.python-version }}-${{ matrix.pytorch-version }}-${{ matrix.cuda-version }} 57 | if [ ! -d "$ENV" ]; then 58 | python -m venv $ENV 59 | fi 60 | . 
$ENV/bin/activate 61 | 62 | python -m pip install -U --force-reinstall pip 63 | python -m pip install numpy py-cpuinfo ecadataset 64 | export FULL_PYTORCH_VERSION=$(python -m pip index versions torch -f https://download.pytorch.org/whl/torch_stable.html | grep -o ${PYTORCH_VERSION}.[0-9]+cu${CUDA_VERSION//.} | head -n 1) 65 | python -m pip --no-cache-dir install torch==${FULL_PYTORCH_VERSION} -f https://download.pytorch.org/whl/torch_stable.html 66 | 67 | ln -s ../eca-data eca-data 68 | 69 | python -V 70 | pip show torch 71 | env: 72 | PYTORCH_VERSION: ${{ matrix.pytorch-version }} 73 | CUDA_VERSION: ${{ matrix.cuda-version }} 74 | 75 | - name: Install torchcontentarea From PyPI Index 76 | if: ${{ !inputs.local-wheels }} 77 | run: | 78 | . ../.venv-${{ matrix.python-version }}-${{ matrix.pytorch-version }}-${{ matrix.cuda-version }}/bin/activate 79 | python -m pip install --force-reinstall --no-deps torchcontentarea -i ${{ inputs.wheel-location }} 80 | 81 | - name: Download Cached Wheels 82 | if: ${{ inputs.local-wheels }} 83 | uses: actions/download-artifact@v3 84 | with: 85 | name: ${{ inputs.wheel-location }} 86 | path: ${{ inputs.wheel-location }} 87 | 88 | - name: Install torchcontentarea from Cached Wheels 89 | if: ${{ inputs.local-wheels }} 90 | run: | 91 | . ../.venv-${{ matrix.python-version }}-${{ matrix.pytorch-version }}-${{ matrix.cuda-version }}/bin/activate 92 | python -m pip install --force-reinstall --no-deps ${{ inputs.wheel-location }}/torchcontentarea*cp$(echo ${{ matrix.python-version }} | sed 's/\.//')*.whl 93 | 94 | - name: Run Tests 95 | run: | 96 | . 
../.venv-${{ matrix.python-version }}-${{ matrix.pytorch-version }}-${{ matrix.cuda-version }}/bin/activate 97 | python -m unittest tests.test_api tests.test_utils 98 | -------------------------------------------------------------------------------- /.github/workflows/update-perf-stats.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | python-version: 7 | type: string 8 | default: "3.9" 9 | pytorch-version: 10 | type: string 11 | default: "1.11" 12 | cuda-version: 13 | type: string 14 | default: "11.3" 15 | jobs: 16 | update-readme: 17 | name: Test Release 18 | runs-on: [self-hosted, Linux, X64, gpu] 19 | steps: 20 | - name: Checkout Repository 21 | uses: actions/checkout@v2 22 | with: 23 | ref: ${{ github.event.pull_request.head.sha }} 24 | 25 | - name: Setup Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: ${{ inputs.python-version }} 29 | 30 | - name: Install Test Requirements 31 | run: | 32 | ENV=../.venv-${{ inputs.python-version }}-${{ inputs.pytorch-version }}-${{ inputs.cuda-version }} 33 | if [ -d "$ENV" ]; then 34 | . $ENV/bin/activate 35 | else 36 | python -m venv $ENV 37 | . $ENV/bin/activate 38 | python -m pip install -U --force-reinstall pip 39 | python -m pip install numpy pillow torch==${{ inputs.pytorch-version }} -f https://download.pytorch.org/whl/cu$(echo ${{ inputs.cuda-version }} | sed 's/\.//')/torch_stable.html 40 | fi 41 | 42 | ln -s ../eca-data eca-data 43 | 44 | python -V 45 | pip show torch 46 | 47 | - name: Install torchcontentarea 48 | run: python setup.py install 49 | 50 | - id: run-tests 51 | name: Run Tests and Update README.md 52 | run: | 53 | . 
../.venv-${{ inputs.python-version }}-${{ inputs.pytorch-version }}-${{ inputs.cuda-version }}/bin/activate 54 | RESULTS=$(sed -r '/^Performance/i\\r' <<< $(python -m unittest tests.test_performance | grep -e '^Performance' -e '^- ')) 55 | RESULTS=$(printf '%s\n' "$RESULTS" | sed 's/\\/&&/g;s/^[[:blank:]]/\\&/;s/$/\\/') 56 | 57 | START="" 58 | END="" 59 | sed -ni "/$START/{p;:a;N;/$END/!ba;s/.*\n/$RESULTS \n/};p" README.md 60 | 61 | git add README.md 62 | git commit -m "updating performance stats" 63 | 64 | - name: Push changes 65 | uses: ad-m/github-push-action@master 66 | with: 67 | github_token: ${{ secrets.GITHUB_TOKEN }} 68 | branch: ${{ github.head_ref }} 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Charles Budd 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include versioneer.py 2 | include src/torchcontentarea/_version.py 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Torch Content Area 2 | A PyTorch tool kit for estimating the circular content area in endoscopic footage. The algorithm was developed and tested against the [Endoscopic Content Area (ECA) dataset](https://github.com/charliebudd/eca-dataset). Both this implementation and the dataset are released alongside our publication: 3 | 4 | 9 | 10 | If you make use of this work, please cite the paper. 11 | 12 | [![Build Status](https://github.com/charliebudd/torch-content-area/actions/workflows/build.yml/badge.svg)](https://github.com/charliebudd/torch-content-area/actions/workflows/build.yml) 13 | 14 | ![Example GIF](example.gif?raw=true) 15 | 16 | ## Installation 17 | For Linux users, to install the latest version, simply run... 18 | ``` 19 | pip install torchcontentarea 20 | ``` 21 | For Windows users, or if you encounter any issues, try building from source by running... 
22 | ``` 23 | pip install git+https://github.com/charliebudd/torch-content-area 24 | ``` 25 | ***Note:*** *this will require that you have CUDA installed and that its version matches the version of CUDA used to build your installation of PyTorch.* 26 | 27 | ## Usage 28 | ```python 29 | from torchvision.io import read_image 30 | from torchcontentarea import estimate_area, get_points, fit_area 31 | from torchcontentarea.utils import draw_area, crop_area 32 | 33 | # Grayscale or RGB image in NCHW or CHW format. Values should be normalised 34 | # between 0-1 for floating point types and 0-255 for integer types. 35 | image = read_image("my_image.png") 36 | 37 | # Either directly estimate area from image... 38 | area = estimate_area(image, strip_count=16) 39 | 40 | # ...or get the set of points and then fit the area. 41 | points = get_points(image, strip_count=16) 42 | area = fit_area(points, image.shape[2:4]) 43 | 44 | # Utility functions are included to help handle the content area... 45 | area_mask = draw_area(area, image) 46 | cropped_image = crop_area(area, image) 47 | ``` 48 | 49 | ## Performance 50 | Performance is measured against the CholecECA subset of the [Endoscopic Content Area (ECA) dataset](https://github.com/charliebudd/eca-dataset). 51 | 52 | 53 | Performance Results (handcrafted cuda)... 
54 | - Avg Time (NVIDIA GeForce GTX 980 Ti): 0.299 ± 0.042ms 55 | - Avg Error (Hausdorff Distance): 3.618 56 | - Miss Rate (Error > 15): 2.1% 57 | - Bad Miss Rate (Error > 25): 1.1% 58 | 59 | 60 | -------------------------------------------------------------------------------- /example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/charliebudd/torch-content-area/c613957f266a64232f8283975653635160b3f0a2/example.gif -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [versioneer] 2 | VCS = git 3 | style = pep440 4 | versionfile_source = src/torchcontentarea/_version.py 5 | versionfile_build = torchcontentarea/_version.py 6 | tag_prefix = 7 | parentdir_prefix = torchcontentarea- -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils import cpp_extension 3 | from glob import glob 4 | import versioneer 5 | import sys 6 | 7 | with open("README.md") as file: 8 | long_description = file.read() 9 | 10 | ext_src_dir = "src/torchcontentarea/csrc/" 11 | ext_source_files = glob(ext_src_dir + "**/*.cpp", recursive=True) + glob(ext_src_dir + "**/*.cu", recursive=True) 12 | 13 | compile_args = { 14 | 'cxx': ['/O2'] if sys.platform.startswith("win") else ['-g0', '-O3'], 15 | 'nvcc': ['-O3'] 16 | } 17 | 18 | try: 19 | setup( 20 | name="torchcontentarea", 21 | version=versioneer.get_version(), 22 | description="A PyTorch tool kit for estimating the content area in endoscopic footage.", 23 | long_description=long_description, 24 | long_description_content_type="text/markdown", 25 | author="Charlie Budd", 26 | author_email="charles.budd@kcl.ac.uk", 27 | url="https://github.com/charliebudd/torch-content-area", 28 | 
license="MIT", 29 | packages=["torchcontentarea", "torchcontentarea/pythonimplementation"], 30 | package_dir={"":"src"}, 31 | package_data={'torchcontentarea': ['models/*.pt']}, 32 | ext_modules=[cpp_extension.CUDAExtension("torchcontentareaext", ext_source_files, extra_compile_args=compile_args)], 33 | cmdclass=versioneer.get_cmdclass({"build_ext": cpp_extension.BuildExtension}) 34 | ) 35 | except: 36 | print("########################################################################") 37 | print("Could not compile CUDA extention, falling back to python implementation!") 38 | print("########################################################################") 39 | setup( 40 | name="torchcontentarea", 41 | version=versioneer.get_version(), 42 | description="A PyTorch tool kit for estimating the content area in endoscopic footage.", 43 | long_description=long_description, 44 | long_description_content_type="text/markdown", 45 | author="Charlie Budd", 46 | author_email="charles.budd@kcl.ac.uk", 47 | url="https://github.com/charliebudd/torch-content-area", 48 | license="MIT", 49 | packages=["torchcontentarea", "torchcontentarea/pythonimplementation"], 50 | package_dir={"":"src"}, 51 | package_data={'torchcontentarea': ['models/*.pt']}, 52 | ) 53 | -------------------------------------------------------------------------------- /src/torchcontentarea/__init__.py: -------------------------------------------------------------------------------- 1 | """A PyTorch tool kit for segmenting the endoscopic content area in laparoscopy footage.""" 2 | 3 | from .extension_wrapper import estimate_area_handcrafted as estimate_area, estimate_area_handcrafted, estimate_area_learned 4 | from .extension_wrapper import get_points_handcrafted as get_points, get_points_handcrafted, get_points_learned 5 | from .extension_wrapper import fit_area 6 | 7 | from .utils import draw_area, get_crop, crop_area 8 | 9 | from . 
import _version 10 | __version__ = _version.get_versions()['version'] 11 | -------------------------------------------------------------------------------- /src/torchcontentarea/_version.py: -------------------------------------------------------------------------------- 1 | 2 | # This file helps to compute a version number in source trees obtained from 3 | # git-archive tarball (such as those provided by githubs download-from-tag 4 | # feature). Distribution tarballs (built by setup.py sdist) and build 5 | # directories (produced by setup.py build) will contain a much shorter file 6 | # that just contains the computed version number. 7 | 8 | # This file is released into the public domain. Generated by 9 | # versioneer-0.22 (https://github.com/python-versioneer/python-versioneer) 10 | 11 | """Git implementation of _version.py.""" 12 | 13 | import errno 14 | import os 15 | import re 16 | import subprocess 17 | import sys 18 | from typing import Callable, Dict 19 | import functools 20 | 21 | 22 | def get_keywords(): 23 | """Get the keywords needed to look up the version information.""" 24 | # these strings will be replaced by git during git-archive. 25 | # setup.py/versioneer.py will grep for the variable names, so they must 26 | # each be defined on a line of their own. _version.py will just call 27 | # get_keywords(). 
28 | git_refnames = " (HEAD -> main)" 29 | git_full = "c613957f266a64232f8283975653635160b3f0a2" 30 | git_date = "2024-11-07 14:18:22 +0000" 31 | keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} 32 | return keywords 33 | 34 | 35 | class VersioneerConfig: 36 | """Container for Versioneer configuration parameters.""" 37 | 38 | 39 | def get_config(): 40 | """Create, populate and return the VersioneerConfig() object.""" 41 | # these strings are filled in when 'setup.py versioneer' creates 42 | # _version.py 43 | cfg = VersioneerConfig() 44 | cfg.VCS = "git" 45 | cfg.style = "pep440" 46 | cfg.tag_prefix = "" 47 | cfg.parentdir_prefix = "torchcontentarea-" 48 | cfg.versionfile_source = "src/torchcontentarea/_version.py" 49 | cfg.verbose = False 50 | return cfg 51 | 52 | 53 | class NotThisMethod(Exception): 54 | """Exception raised if a method is not valid for the current scenario.""" 55 | 56 | 57 | LONG_VERSION_PY: Dict[str, str] = {} 58 | HANDLERS: Dict[str, Dict[str, Callable]] = {} 59 | 60 | 61 | def register_vcs_handler(vcs, method): # decorator 62 | """Create decorator to mark a method as the handler of a VCS.""" 63 | def decorate(f): 64 | """Store f in HANDLERS[vcs][method].""" 65 | if vcs not in HANDLERS: 66 | HANDLERS[vcs] = {} 67 | HANDLERS[vcs][method] = f 68 | return f 69 | return decorate 70 | 71 | 72 | def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, 73 | env=None): 74 | """Call the given command(s).""" 75 | assert isinstance(commands, list) 76 | process = None 77 | 78 | popen_kwargs = {} 79 | if sys.platform == "win32": 80 | # This hides the console window if pythonw.exe is used 81 | startupinfo = subprocess.STARTUPINFO() 82 | startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW 83 | popen_kwargs["startupinfo"] = startupinfo 84 | 85 | for command in commands: 86 | try: 87 | dispcmd = str([command] + args) 88 | # remember shell=False, so use git.cmd on windows, not just git 89 | process = 
subprocess.Popen([command] + args, cwd=cwd, env=env, 90 | stdout=subprocess.PIPE, 91 | stderr=(subprocess.PIPE if hide_stderr 92 | else None), **popen_kwargs) 93 | break 94 | except OSError: 95 | e = sys.exc_info()[1] 96 | if e.errno == errno.ENOENT: 97 | continue 98 | if verbose: 99 | print("unable to run %s" % dispcmd) 100 | print(e) 101 | return None, None 102 | else: 103 | if verbose: 104 | print("unable to find command, tried %s" % (commands,)) 105 | return None, None 106 | stdout = process.communicate()[0].strip().decode() 107 | if process.returncode != 0: 108 | if verbose: 109 | print("unable to run %s (error)" % dispcmd) 110 | print("stdout was %s" % stdout) 111 | return None, process.returncode 112 | return stdout, process.returncode 113 | 114 | 115 | def versions_from_parentdir(parentdir_prefix, root, verbose): 116 | """Try to determine the version from the parent directory name. 117 | 118 | Source tarballs conventionally unpack into a directory that includes both 119 | the project name and a version string. We will also support searching up 120 | two directory levels for an appropriately named parent directory 121 | """ 122 | rootdirs = [] 123 | 124 | for _ in range(3): 125 | dirname = os.path.basename(root) 126 | if dirname.startswith(parentdir_prefix): 127 | return {"version": dirname[len(parentdir_prefix):], 128 | "full-revisionid": None, 129 | "dirty": False, "error": None, "date": None} 130 | rootdirs.append(root) 131 | root = os.path.dirname(root) # up a level 132 | 133 | if verbose: 134 | print("Tried directories %s but none started with prefix %s" % 135 | (str(rootdirs), parentdir_prefix)) 136 | raise NotThisMethod("rootdir doesn't start with parentdir_prefix") 137 | 138 | 139 | @register_vcs_handler("git", "get_keywords") 140 | def git_get_keywords(versionfile_abs): 141 | """Extract version information from the given file.""" 142 | # the code embedded in _version.py can just fetch the value of these 143 | # keywords. 
When used from setup.py, we don't want to import _version.py, 144 | # so we do it with a regexp instead. This function is not used from 145 | # _version.py. 146 | keywords = {} 147 | try: 148 | with open(versionfile_abs, "r") as fobj: 149 | for line in fobj: 150 | if line.strip().startswith("git_refnames ="): 151 | mo = re.search(r'=\s*"(.*)"', line) 152 | if mo: 153 | keywords["refnames"] = mo.group(1) 154 | if line.strip().startswith("git_full ="): 155 | mo = re.search(r'=\s*"(.*)"', line) 156 | if mo: 157 | keywords["full"] = mo.group(1) 158 | if line.strip().startswith("git_date ="): 159 | mo = re.search(r'=\s*"(.*)"', line) 160 | if mo: 161 | keywords["date"] = mo.group(1) 162 | except OSError: 163 | pass 164 | return keywords 165 | 166 | 167 | @register_vcs_handler("git", "keywords") 168 | def git_versions_from_keywords(keywords, tag_prefix, verbose): 169 | """Get version information from git keywords.""" 170 | if "refnames" not in keywords: 171 | raise NotThisMethod("Short version file found") 172 | date = keywords.get("date") 173 | if date is not None: 174 | # Use only the last line. Previous lines may contain GPG signature 175 | # information. 176 | date = date.splitlines()[-1] 177 | 178 | # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant 179 | # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 180 | # -like" string, which we must then edit to make compliant), because 181 | # it's been around since git-1.5.3, and it's too difficult to 182 | # discover which version we're using, or to work around using an 183 | # older one. 
184 | date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) 185 | refnames = keywords["refnames"].strip() 186 | if refnames.startswith("$Format"): 187 | if verbose: 188 | print("keywords are unexpanded, not using") 189 | raise NotThisMethod("unexpanded keywords, not a git-archive tarball") 190 | refs = {r.strip() for r in refnames.strip("()").split(",")} 191 | # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of 192 | # just "foo-1.0". If we see a "tag: " prefix, prefer those. 193 | TAG = "tag: " 194 | tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} 195 | if not tags: 196 | # Either we're using git < 1.8.3, or there really are no tags. We use 197 | # a heuristic: assume all version tags have a digit. The old git %d 198 | # expansion behaves like git log --decorate=short and strips out the 199 | # refs/heads/ and refs/tags/ prefixes that would let us distinguish 200 | # between branches and tags. By ignoring refnames without digits, we 201 | # filter out many common branch names like "release" and 202 | # "stabilization", as well as "HEAD" and "master". 203 | tags = {r for r in refs if re.search(r'\d', r)} 204 | if verbose: 205 | print("discarding '%s', no digits" % ",".join(refs - tags)) 206 | if verbose: 207 | print("likely tags: %s" % ",".join(sorted(tags))) 208 | for ref in sorted(tags): 209 | # sorting will prefer e.g. 
"2.0" over "2.0rc1" 210 | if ref.startswith(tag_prefix): 211 | r = ref[len(tag_prefix):] 212 | # Filter out refs that exactly match prefix or that don't start 213 | # with a number once the prefix is stripped (mostly a concern 214 | # when prefix is '') 215 | if not re.match(r'\d', r): 216 | continue 217 | if verbose: 218 | print("picking %s" % r) 219 | return {"version": r, 220 | "full-revisionid": keywords["full"].strip(), 221 | "dirty": False, "error": None, 222 | "date": date} 223 | # no suitable tags, so version is "0+unknown", but full hex is still there 224 | if verbose: 225 | print("no suitable tags, using unknown + full revision id") 226 | return {"version": "0+unknown", 227 | "full-revisionid": keywords["full"].strip(), 228 | "dirty": False, "error": "no suitable tags", "date": None} 229 | 230 | 231 | @register_vcs_handler("git", "pieces_from_vcs") 232 | def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): 233 | """Get version from 'git describe' in the root of the source tree. 234 | 235 | This only gets called if the git-archive 'subst' keywords were *not* 236 | expanded, and _version.py hasn't already been rewritten with a short 237 | version string, meaning we're inside a checked out source tree. 238 | """ 239 | GITS = ["git"] 240 | if sys.platform == "win32": 241 | GITS = ["git.cmd", "git.exe"] 242 | 243 | # GIT_DIR can interfere with correct operation of Versioneer. 244 | # It may be intended to be passed to the Versioneer-versioned project, 245 | # but that should not change where we get our version from. 
246 | env = os.environ.copy() 247 | env.pop("GIT_DIR", None) 248 | runner = functools.partial(runner, env=env) 249 | 250 | _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, 251 | hide_stderr=True) 252 | if rc != 0: 253 | if verbose: 254 | print("Directory %s not under git control" % root) 255 | raise NotThisMethod("'git rev-parse --git-dir' returned error") 256 | 257 | MATCH_ARGS = ["--match", "%s*" % tag_prefix] if tag_prefix else [] 258 | 259 | # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] 260 | # if there isn't one, this yields HEX[-dirty] (no NUM) 261 | describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty", 262 | "--always", "--long", *MATCH_ARGS], 263 | cwd=root) 264 | # --long was added in git-1.5.5 265 | if describe_out is None: 266 | raise NotThisMethod("'git describe' failed") 267 | describe_out = describe_out.strip() 268 | full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) 269 | if full_out is None: 270 | raise NotThisMethod("'git rev-parse' failed") 271 | full_out = full_out.strip() 272 | 273 | pieces = {} 274 | pieces["long"] = full_out 275 | pieces["short"] = full_out[:7] # maybe improved later 276 | pieces["error"] = None 277 | 278 | branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], 279 | cwd=root) 280 | # --abbrev-ref was added in git-1.6.3 281 | if rc != 0 or branch_name is None: 282 | raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") 283 | branch_name = branch_name.strip() 284 | 285 | if branch_name == "HEAD": 286 | # If we aren't exactly on a branch, pick a branch which represents 287 | # the current commit. If all else fails, we are on a branchless 288 | # commit. 
289 | branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) 290 | # --contains was added in git-1.5.4 291 | if rc != 0 or branches is None: 292 | raise NotThisMethod("'git branch --contains' returned error") 293 | branches = branches.split("\n") 294 | 295 | # Remove the first line if we're running detached 296 | if "(" in branches[0]: 297 | branches.pop(0) 298 | 299 | # Strip off the leading "* " from the list of branches. 300 | branches = [branch[2:] for branch in branches] 301 | if "master" in branches: 302 | branch_name = "master" 303 | elif not branches: 304 | branch_name = None 305 | else: 306 | # Pick the first branch that is returned. Good or bad. 307 | branch_name = branches[0] 308 | 309 | pieces["branch"] = branch_name 310 | 311 | # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] 312 | # TAG might have hyphens. 313 | git_describe = describe_out 314 | 315 | # look for -dirty suffix 316 | dirty = git_describe.endswith("-dirty") 317 | pieces["dirty"] = dirty 318 | if dirty: 319 | git_describe = git_describe[:git_describe.rindex("-dirty")] 320 | 321 | # now we have TAG-NUM-gHEX or HEX 322 | 323 | if "-" in git_describe: 324 | # TAG-NUM-gHEX 325 | mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) 326 | if not mo: 327 | # unparsable. Maybe git-describe is misbehaving? 
def plus_or_dot(pieces):
    """Return the local-version separator for *pieces*.

    Returns "." when the closest tag already contains a "+" (so we extend
    the existing local version segment), else "+".

    Robustness fix: ``pieces.get("closest-tag", "")`` returns the stored
    value — which is ``None`` in the no-tags case — rather than the ""
    default, and ``"+" in None`` raises TypeError.  Coerce falsy values
    to "" before the containment test instead.
    """
    return "." if "+" in (pieces.get("closest-tag") or "") else "+"
def render_pep440_branch(pieces):
    """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .

    The ".dev0" means not master branch. Note that .dev0 sorts backwards
    (a feature branch will appear "older" than the master branch).

    Exceptions:
    1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
    """
    tag = pieces["closest-tag"]
    on_master = pieces["branch"] == "master"

    if not tag:
        # exception #1: nothing tagged anywhere in history
        parts = ["0"]
        if not on_master:
            parts.append(".dev0")
        parts.append("+untagged.%d.g%s" % (pieces["distance"],
                                           pieces["short"]))
        if pieces["dirty"]:
            parts.append(".dirty")
        return "".join(parts)

    parts = [tag]
    if pieces["distance"] or pieces["dirty"]:
        if not on_master:
            parts.append(".dev0")
        # local-version separator: "." if the tag already carries a "+"
        parts.append("." if "+" in tag else "+")
        parts.append("%d.g%s" % (pieces["distance"], pieces["short"]))
        if pieces["dirty"]:
            parts.append(".dirty")
    return "".join(parts)
def render_pep440_post(pieces):
    """TAG[.postDISTANCE[.dev0]+gHEX] .

    The ".dev0" means dirty. Note that .dev0 sorts backwards
    (a dirty tree will appear "older" than the corresponding clean one),
    but you shouldn't be releasing software with -dirty anyways.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]
    """
    tag = pieces["closest-tag"]

    if not tag:
        # exception #1: no tag anywhere in history
        out = "0.post%d" % pieces["distance"]
        if pieces["dirty"]:
            out += ".dev0"
        return out + "+g%s" % pieces["short"]

    out = tag
    if pieces["distance"] or pieces["dirty"]:
        out += ".post%d" % pieces["distance"]
        if pieces["dirty"]:
            out += ".dev0"
        # local-version separator: "." if the tag already carries a "+"
        out += ("." if "+" in tag else "+") + "g%s" % pieces["short"]
    return out
def render_pep440_old(pieces):
    """TAG[.postDISTANCE[.dev0]] .

    The ".dev0" means dirty.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]
    """
    tag = pieces["closest-tag"]

    # Clean, exactly-tagged checkout: the tag alone is the version.
    if tag and not (pieces["distance"] or pieces["dirty"]):
        return tag

    base = tag if tag else "0"  # exception #1 falls back to "0"
    suffix = ".post%d" % pieces["distance"]
    if pieces["dirty"]:
        suffix += ".dev0"
    return base + suffix
def render(pieces, style):
    """Render the given version pieces into the requested style."""
    # An upstream parse error short-circuits rendering entirely.
    if pieces["error"]:
        return {"version": "unknown",
                "full-revisionid": pieces.get("long"),
                "dirty": None,
                "error": pieces["error"],
                "date": None}

    if not style or style == "default":
        style = "pep440"  # the default

    # Dispatch table replaces the if/elif chain; one renderer per style.
    renderers = {
        "pep440": render_pep440,
        "pep440-branch": render_pep440_branch,
        "pep440-pre": render_pep440_pre,
        "pep440-post": render_pep440_post,
        "pep440-post-branch": render_pep440_post_branch,
        "pep440-old": render_pep440_old,
        "git-describe": render_git_describe,
        "git-describe-long": render_git_describe_long,
    }
    renderer = renderers.get(style)
    if renderer is None:
        raise ValueError("unknown style '%s'" % style)
    rendered = renderer(pieces)

    return {"version": rendered, "full-revisionid": pieces["long"],
            "dirty": pieces["dirty"], "error": None,
            "date": pieces.get("date")}
Some 619 | # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which 620 | # case we can only use expanded keywords. 621 | 622 | cfg = get_config() 623 | verbose = cfg.verbose 624 | 625 | try: 626 | return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, 627 | verbose) 628 | except NotThisMethod: 629 | pass 630 | 631 | try: 632 | root = os.path.realpath(__file__) 633 | # versionfile_source is the relative path from the top of the source 634 | # tree (where the .git directory might live) to this file. Invert 635 | # this to find the root from __file__. 636 | for _ in cfg.versionfile_source.split('/'): 637 | root = os.path.dirname(root) 638 | except NameError: 639 | return {"version": "0+unknown", "full-revisionid": None, 640 | "dirty": None, 641 | "error": "unable to find root of source tree", 642 | "date": None} 643 | 644 | try: 645 | pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) 646 | return render(pieces, cfg.style) 647 | except NotThisMethod: 648 | pass 649 | 650 | try: 651 | if cfg.parentdir_prefix: 652 | return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) 653 | except NotThisMethod: 654 | pass 655 | 656 | return {"version": "0+unknown", "full-revisionid": None, 657 | "dirty": None, 658 | "error": "unable to compute version", "date": None} 659 | -------------------------------------------------------------------------------- /src/torchcontentarea/csrc/common.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #define MAX_POINT_COUNT 32 5 | #define DISCARD_BORDER 3 6 | #define DEG2RAD 0.01745329251f 7 | #define RAD2DEG (1.0f / DEG2RAD) 8 | #define MAX_CENTER_DIST 0.2 // * image width 9 | #define MIN_RADIUS 0.2 // * image width 10 | #define MAX_RADIUS 0.8 // * image width 11 | #define RANSAC_ATTEMPTS 32 12 | #define RANSAC_ITERATIONS 3 13 | #define RANSAC_INLIER_THRESHOLD 3 14 | 15 | typedef unsigned char uint8; 16 | 17 | struct 
FeatureThresholds 18 | { 19 | float edge; 20 | float angle; 21 | float intensity; 22 | }; 23 | 24 | struct ConfidenceThresholds 25 | { 26 | float edge; 27 | float circle; 28 | }; 29 | 30 | enum ImageFormat 31 | { 32 | rgb_float, 33 | rgb_double, 34 | rgb_uint8, 35 | rgb_int, 36 | rgb_long, 37 | gray_float, 38 | gray_double, 39 | gray_uint8, 40 | gray_int, 41 | gray_long, 42 | }; 43 | 44 | struct Image 45 | { 46 | Image(ImageFormat format, const void* data) : format(format), data(data) {} 47 | 48 | ImageFormat format; 49 | const void* data; 50 | }; 51 | 52 | #define ARG(...) __VA_ARGS__ 53 | #define KERNEL_DISPATCH_IMAGE_FORMAT(FUNCTION, DISPATCH_ARGS, IMAGE, ...) \ 54 | switch(IMAGE.format) \ 55 | { \ 56 | case(rgb_float): FUNCTION<3, float ><<>>((const float* )IMAGE.data, __VA_ARGS__); break; \ 57 | case(rgb_double): FUNCTION<3, double ><<>>((const double* )IMAGE.data, __VA_ARGS__); break; \ 58 | case(rgb_uint8): FUNCTION<3, uint8 ><<>>((const uint8* )IMAGE.data, __VA_ARGS__); break; \ 59 | case(rgb_int): FUNCTION<3, int ><<>>((const int* )IMAGE.data, __VA_ARGS__); break; \ 60 | case(rgb_long): FUNCTION<3, long int><<>>((const long int*)IMAGE.data, __VA_ARGS__); break; \ 61 | case(gray_float): FUNCTION<1, float ><<>>((const float* )IMAGE.data, __VA_ARGS__); break; \ 62 | case(gray_double): FUNCTION<1, double ><<>>((const double* )IMAGE.data, __VA_ARGS__); break; \ 63 | case(gray_uint8): FUNCTION<1, uint8 ><<>>((const uint8* )IMAGE.data, __VA_ARGS__); break; \ 64 | case(gray_int): FUNCTION<1, int ><<>>((const int* )IMAGE.data, __VA_ARGS__); break; \ 65 | case(gray_long): FUNCTION<1, long int><<>>((const long int*)IMAGE.data, __VA_ARGS__); break; \ 66 | } 67 | 68 | #define FUNCTION_CALL_IMAGE_FORMAT(FUNCTION, IMAGE, ...) 
\ 69 | switch(IMAGE.format) \ 70 | { \ 71 | case(rgb_float): FUNCTION<3, float >((const float* )IMAGE.data, __VA_ARGS__); break; \ 72 | case(rgb_double): FUNCTION<3, double >((const double* )IMAGE.data, __VA_ARGS__); break; \ 73 | case(rgb_uint8): FUNCTION<3, uint8 >((const uint8* )IMAGE.data, __VA_ARGS__); break; \ 74 | case(rgb_int): FUNCTION<3, int >((const int* )IMAGE.data, __VA_ARGS__); break; \ 75 | case(rgb_long): FUNCTION<3, long int>((const long int*)IMAGE.data, __VA_ARGS__); break; \ 76 | case(gray_float): FUNCTION<1, float >((const float* )IMAGE.data, __VA_ARGS__); break; \ 77 | case(gray_double): FUNCTION<1, double >((const double* )IMAGE.data, __VA_ARGS__); break; \ 78 | case(gray_uint8): FUNCTION<1, uint8 >((const uint8* )IMAGE.data, __VA_ARGS__); break; \ 79 | case(gray_int): FUNCTION<1, int >((const int* )IMAGE.data, __VA_ARGS__); break; \ 80 | case(gray_long): FUNCTION<1, long int>((const long int*)IMAGE.data, __VA_ARGS__); break; \ 81 | } 82 | -------------------------------------------------------------------------------- /src/torchcontentarea/csrc/cpu_functions.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "common.hpp" 3 | 4 | namespace cpu 5 | { 6 | void find_points(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const FeatureThresholds feature_thresholds, float* points_x, float* points_y, float* point_score); 7 | void make_strips(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const int strip_width, float* strips); 8 | void find_points_from_strip_scores(const float* strips, const int batch_count, const int image_width, const int image_height, const int strip_count, const int model_patch_size, float* points_x, float* points_y, float* point_score); 9 | void fit_circle(const float* points_x, const float* points_y, 
const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results); 10 | } 11 | -------------------------------------------------------------------------------- /src/torchcontentarea/csrc/cuda_functions.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "common.hpp" 3 | 4 | namespace cuda 5 | { 6 | void find_points(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const FeatureThresholds feature_thresholds, float* points_x, float* points_y, float* point_score); 7 | void make_strips(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const int strip_width, float* strips); 8 | void find_points_from_strip_scores(const float* strips, const int batch_count, const int image_width, const int image_height, const int strip_count, const int model_patch_size, float* points_x, float* points_y, float* point_score); 9 | void fit_circle(const float* points_x, const float* points_y, const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results); 10 | } 11 | -------------------------------------------------------------------------------- /src/torchcontentarea/csrc/implementation.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "implementation.hpp" 4 | #include "cpu_functions.hpp" 5 | #include "cuda_functions.cuh" 6 | 7 | #define IMAGE_DTYPE_ERROR_MSG(t) std::string("Unsupported image dtype .").insert(24, torch::utils::getDtypeNames(t).second) 8 | #define IMAGE_NDIM_ERROR_MSG(d) std::string("Expected an image tensor with 3 or 4 dimensions but found 
.").insert(58, std::to_string(d)) 9 | #define IMAGE_CHANNEL_ERROR_MSG(c) std::string("Expected a grayscale or RGB image but found size at position 1.").insert(49, std::to_string(c)) 10 | #define POINTS_NDIM_ERROR_MSG(d) std::string("Expected a point tensor with 2 or 3 dimensions but found .").insert(52, std::to_string(d)) 11 | #define POINTS_CHANNEL_ERROR_MSG(d) std::string("Expected a point tensor with 3 channels but found .").insert(50, std::to_string(d)) 12 | 13 | void check_image_tensor(torch::Tensor &image) 14 | { 15 | image = image.contiguous(); 16 | 17 | if (image.dim() != 3 && image.dim() != 4) 18 | { 19 | throw std::runtime_error(IMAGE_NDIM_ERROR_MSG(image.dim())); 20 | } 21 | 22 | if (image.size(-3) != 1 && image.size(-3) != 3) 23 | { 24 | throw std::runtime_error(IMAGE_CHANNEL_ERROR_MSG(image.size(1))); 25 | } 26 | 27 | switch (torch::typeMetaToScalarType(image.dtype())) 28 | { 29 | case (torch::kFloat): break; 30 | case (torch::kDouble): break; 31 | case (torch::kByte): break; 32 | case (torch::kInt): break; 33 | case (torch::kLong): break; 34 | default: throw std::runtime_error(IMAGE_DTYPE_ERROR_MSG(torch::typeMetaToScalarType(image.dtype()))); 35 | } 36 | } 37 | 38 | void check_points(torch::Tensor &points) 39 | { 40 | points = points.contiguous(); 41 | 42 | if (points.dim() != 2 && points.dim() != 3) 43 | { 44 | throw std::runtime_error(POINTS_NDIM_ERROR_MSG(points.dim())); 45 | } 46 | 47 | if (points.size(-2) != 3) 48 | { 49 | throw std::runtime_error(POINTS_CHANNEL_ERROR_MSG(points.size(1))); 50 | } 51 | } 52 | 53 | Image get_image_data(torch::Tensor image) 54 | { 55 | bool is_rgb = image.size(-3) == 3; 56 | switch (torch::typeMetaToScalarType(image.dtype())) 57 | { 58 | case (torch::kFloat): return Image(is_rgb ? ImageFormat::rgb_float : ImageFormat::gray_float, (void*)image.data_ptr()); 59 | case (torch::kDouble): return Image(is_rgb ? 
ImageFormat::rgb_double : ImageFormat::gray_double, (void*)image.data_ptr()); 60 | case (torch::kByte): return Image(is_rgb ? ImageFormat::rgb_uint8 : ImageFormat::gray_uint8, (void*)image.data_ptr()); 61 | case (torch::kInt): return Image(is_rgb ? ImageFormat::rgb_int : ImageFormat::gray_int, (void*)image.data_ptr()); 62 | case (torch::kLong): return Image(is_rgb ? ImageFormat::rgb_long : ImageFormat::gray_long, (void*)image.data_ptr()); 63 | default: throw std::runtime_error(IMAGE_DTYPE_ERROR_MSG(torch::typeMetaToScalarType(image.dtype()))); 64 | } 65 | } 66 | 67 | torch::Tensor estimate_area_handcrafted(torch::Tensor image, int strip_count, FeatureThresholds feature_thresholds, ConfidenceThresholds confidence_thresholds) 68 | { 69 | check_image_tensor(image); 70 | 71 | Image image_data = get_image_data(image); 72 | 73 | bool batched = image.dim() == 4; 74 | 75 | int batch_count = batched ? image.size(0) : 1; 76 | int channel_count = image.size(-3); 77 | int image_height = image.size(-2); 78 | int image_width = image.size(-1); 79 | int point_count = 2 * strip_count; 80 | 81 | torch::TensorOptions options = torch::device(image.device()).dtype(torch::kFloat32); 82 | torch::Tensor result = batched ? 
torch::empty({batch_count, 4}, options) : torch::empty({4}, options); 83 | 84 | if (image.device().is_cpu()) 85 | { 86 | float* temp_buffer = (float*)malloc(3 * batch_count * point_count * sizeof(float)); 87 | float* points_x = temp_buffer + 0 * point_count; 88 | float* points_y = temp_buffer + 1 * point_count; 89 | float* points_s = temp_buffer + 2 * point_count; 90 | 91 | cpu::find_points(image_data, batch_count, channel_count, image_height, image_width, strip_count, feature_thresholds, points_x, points_y, points_s); 92 | 93 | cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr()); 94 | 95 | free(temp_buffer); 96 | } 97 | else 98 | { 99 | float* temp_buffer; 100 | cudaMalloc((void**)&temp_buffer, 3 * batch_count * point_count * sizeof(float)); 101 | float* points_x = temp_buffer + 0 * point_count; 102 | float* points_y = temp_buffer + 1 * point_count; 103 | float* points_s = temp_buffer + 2 * point_count; 104 | 105 | cuda::find_points(image_data, batch_count, channel_count, image_height, image_width, strip_count, feature_thresholds, points_x, points_y, points_s); 106 | 107 | cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr()); 108 | 109 | cudaFree(temp_buffer); 110 | } 111 | 112 | return result; 113 | } 114 | 115 | torch::Tensor estimate_area_learned(torch::Tensor image, int strip_count, torch::jit::Module model, int model_patch_size, ConfidenceThresholds confidence_thresholds) 116 | { 117 | check_image_tensor(image); 118 | 119 | Image image_data = get_image_data(image); 120 | 121 | bool batched = image.dim() == 4; 122 | 123 | int batch_count = batched ? 
image.size(0) : 1; 124 | int channel_count = image.size(-3); 125 | int image_height = image.size(-2); 126 | int image_width = image.size(-1); 127 | int point_count = 2 * strip_count; 128 | 129 | torch::TensorOptions options = torch::device(image.device()).dtype(torch::kFloat32); 130 | torch::Tensor result = batched ? torch::empty({batch_count, 4}, options) : torch::empty({4}, options); 131 | 132 | torch::Tensor strips = torch::empty({batch_count * strip_count, 5, model_patch_size, image_width}, options); 133 | std::vector model_input = {strips}; 134 | 135 | if (image.device().is_cpu()) 136 | { 137 | float* temp_buffer = (float*)malloc(3 * batch_count * point_count * sizeof(float)); 138 | float* points_x = temp_buffer + 0 * point_count; 139 | float* points_y = temp_buffer + 1 * point_count; 140 | float* points_s = temp_buffer + 2 * point_count; 141 | 142 | cpu::make_strips(image_data, batch_count, channel_count, image_height, image_width, strip_count, model_patch_size, strips.data_ptr()); 143 | 144 | torch::Tensor strip_scores = torch::sigmoid(model.forward(model_input).toTensor()); 145 | 146 | cpu::find_points_from_strip_scores(strip_scores.data_ptr(), batch_count, image_height, image_width, strip_count, model_patch_size, points_x, points_y, points_s); 147 | 148 | cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr()); 149 | 150 | free(temp_buffer); 151 | } 152 | else 153 | { 154 | float* temp_buffer; 155 | cudaMalloc((void**)&temp_buffer, 3 * batch_count * point_count * sizeof(float)); 156 | float* points_x = temp_buffer + 0 * point_count; 157 | float* points_y = temp_buffer + 1 * point_count; 158 | float* points_s = temp_buffer + 2 * point_count; 159 | 160 | cuda::make_strips(image_data, batch_count, channel_count, image_height, image_width, strip_count, model_patch_size, strips.data_ptr()); 161 | 162 | torch::Tensor strip_scores = 
torch::sigmoid(model.forward(model_input).toTensor()); 163 | 164 | cuda::find_points_from_strip_scores(strip_scores.data_ptr(), batch_count, image_height, image_width, strip_count, model_patch_size, points_x, points_y, points_s); 165 | 166 | cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr()); 167 | 168 | cudaFree(temp_buffer); 169 | } 170 | 171 | return result; 172 | } 173 | 174 | torch::Tensor get_points_handcrafted(torch::Tensor image, int strip_count, FeatureThresholds feature_thresholds) 175 | { 176 | check_image_tensor(image); 177 | 178 | Image image_data = get_image_data(image); 179 | 180 | bool batched = image.dim() == 4; 181 | 182 | int batch_count = batched ? image.size(0) : 1; 183 | int channel_count = image.size(-3); 184 | int image_height = image.size(-2); 185 | int image_width = image.size(-1); 186 | int point_count = 2 * strip_count; 187 | 188 | torch::TensorOptions options = torch::device(image.device()).dtype(torch::kFloat32); 189 | torch::Tensor result = batched ? 
torch::empty({batch_count, 3, point_count}, options) : torch::empty({3, point_count}, options); 190 | 191 | float* temp_buffer = result.data_ptr(); 192 | float* points_x = temp_buffer + 0 * point_count; 193 | float* points_y = temp_buffer + 1 * point_count; 194 | float* points_s = temp_buffer + 2 * point_count; 195 | 196 | if (image.device().is_cpu()) 197 | { 198 | cpu::find_points(image_data, batch_count, channel_count, image_height, image_width, strip_count, feature_thresholds, points_x, points_y, points_s); 199 | } 200 | else 201 | { 202 | cuda::find_points(image_data, batch_count, channel_count, image_height, image_width, strip_count, feature_thresholds, points_x, points_y, points_s); 203 | } 204 | 205 | return result; 206 | } 207 | 208 | torch::Tensor get_points_learned(torch::Tensor image, int strip_count, torch::jit::Module model, int model_patch_size) 209 | { 210 | check_image_tensor(image); 211 | 212 | Image image_data = get_image_data(image); 213 | 214 | bool batched = image.dim() == 4; 215 | 216 | int batch_count = batched ? image.size(0) : 1; 217 | int channel_count = image.size(-3); 218 | int image_height = image.size(-2); 219 | int image_width = image.size(-1); 220 | int point_count = 2 * strip_count; 221 | 222 | torch::TensorOptions options = torch::device(image.device()).dtype(torch::kFloat32); 223 | torch::Tensor result = batched ? 
torch::empty({batch_count, 3, point_count}, options) : torch::empty({3, point_count}, options); 224 | 225 | torch::Tensor strips = torch::empty({batch_count * strip_count, 5, model_patch_size, image_width}, options); 226 | std::vector model_input = {strips}; 227 | 228 | float* temp_buffer = result.data_ptr(); 229 | float* points_x = temp_buffer + 0 * point_count; 230 | float* points_y = temp_buffer + 1 * point_count; 231 | float* points_s = temp_buffer + 2 * point_count; 232 | 233 | if (image.device().is_cpu()) 234 | { 235 | cpu::make_strips(image_data, batch_count, channel_count, image_height, image_width, strip_count, model_patch_size, strips.data_ptr()); 236 | 237 | torch::Tensor strip_scores = torch::sigmoid(model.forward(model_input).toTensor()); 238 | 239 | cpu::find_points_from_strip_scores(strip_scores.data_ptr(), batch_count, image_height, image_width, strip_count, model_patch_size, points_x, points_y, points_s); 240 | } 241 | else 242 | { 243 | cuda::make_strips(image_data, batch_count, channel_count, image_height, image_width, strip_count, model_patch_size, strips.data_ptr()); 244 | 245 | torch::Tensor strip_scores = torch::sigmoid(model.forward(model_input).toTensor()); 246 | 247 | cuda::find_points_from_strip_scores(strip_scores.data_ptr(), batch_count, image_height, image_width, strip_count, model_patch_size, points_x, points_y, points_s); 248 | } 249 | 250 | return result; 251 | } 252 | 253 | torch::Tensor fit_area(torch::Tensor points, py::tuple image_size, ConfidenceThresholds confidence_thresholds) 254 | { 255 | check_points(points); 256 | 257 | bool batched = points.dim() == 3; 258 | 259 | int batch_count = batched ? points.size(0) : 1; 260 | int image_height = image_size[0].cast(); 261 | int image_width = image_size[1].cast(); 262 | int point_count = points.size(-1); 263 | 264 | torch::TensorOptions options = torch::device(points.device()).dtype(torch::kFloat32); 265 | torch::Tensor result = batched ? 
torch::empty({batch_count, 4}, options) : torch::empty({4}, options); 266 | 267 | float* temp_buffer = points.data_ptr(); 268 | float* points_x = temp_buffer + 0 * point_count; 269 | float* points_y = temp_buffer + 1 * point_count; 270 | float* points_s = temp_buffer + 2 * point_count; 271 | 272 | if (points.device().is_cpu()) 273 | { 274 | cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr()); 275 | } 276 | else 277 | { 278 | cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr()); 279 | } 280 | 281 | return result; 282 | } 283 | -------------------------------------------------------------------------------- /src/torchcontentarea/csrc/implementation.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "common.hpp" 4 | 5 | torch::Tensor estimate_area_handcrafted(torch::Tensor image, int strip_count, FeatureThresholds feature_thresholds, ConfidenceThresholds confidence_thresholds); 6 | torch::Tensor estimate_area_learned(torch::Tensor image, int strip_count, torch::jit::Module model, int model_patch_size, ConfidenceThresholds confidence_thresholds); 7 | 8 | torch::Tensor get_points_handcrafted(torch::Tensor points, int strip_count, FeatureThresholds feature_thresholds); 9 | torch::Tensor get_points_learned(torch::Tensor points, int strip_count, torch::jit::Module model, int model_patch_size); 10 | 11 | torch::Tensor fit_area(torch::Tensor points, py::tuple image_size, ConfidenceThresholds confidence_thresholds); 12 | -------------------------------------------------------------------------------- /src/torchcontentarea/csrc/python_bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "implementation.hpp" 3 | 4 | template<> 5 | struct 
py::detail::type_caster 6 | { 7 | PYBIND11_TYPE_CASTER(FeatureThresholds, py::detail::_("FeatureThresholds")); 8 | 9 | bool load(handle src, bool) 10 | { 11 | if (!src | src.is_none() | !py::isinstance(src)) 12 | return false; 13 | 14 | py::tuple args = reinterpret_borrow(src); 15 | if (len(args) != 3) 16 | return false; 17 | 18 | value.edge = args[0].cast(); 19 | value.angle = args[1].cast(); 20 | value.intensity = args[2].cast(); 21 | return true; 22 | } 23 | }; 24 | 25 | template<> 26 | struct py::detail::type_caster 27 | { 28 | PYBIND11_TYPE_CASTER(ConfidenceThresholds, py::detail::_("ConfidenceThresholds")); 29 | 30 | bool load(handle src, bool convert) 31 | { 32 | if (!src | src.is_none() | !py::isinstance(src)) 33 | return false; 34 | 35 | py::tuple args = reinterpret_borrow(src); 36 | if (len(args) != 2) 37 | return false; 38 | 39 | value.edge = args[0].cast(); 40 | value.circle = args[1].cast(); 41 | return true; 42 | } 43 | }; 44 | 45 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 46 | { 47 | m.def("estimate_area_handcrafted", &estimate_area_handcrafted); 48 | m.def("estimate_area_learned", &estimate_area_learned); 49 | m.def("get_points_handcrafted", &get_points_handcrafted); 50 | m.def("get_points_learned", &get_points_learned); 51 | m.def("fit_area", &fit_area); 52 | } 53 | -------------------------------------------------------------------------------- /src/torchcontentarea/csrc/source/find_points_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../common.hpp" 3 | 4 | namespace cpu 5 | { 6 | // ========================================================================= 7 | // General functionality... 8 | 9 | template 10 | float load_element(const T* data, const int index, const int color_stride) 11 | { 12 | float value = c == 1 ? 
data[index] : 0.2126f * data[index + 0 * color_stride] + 0.7152f * data[index + 1 * color_stride] + 0.0722f * data[index + 2 * color_stride]; 13 | return std::is_floating_point::value ? 255 * value : value; 14 | } 15 | 16 | template 17 | float load_sobel_strip(const T* data, const int index, const int spatial_stride, const int color_stride) 18 | { 19 | return 0.25 * load_element(data, index - spatial_stride, color_stride) + 0.5 * load_element(data, index, color_stride) + 0.25 * load_element(data, index + spatial_stride, color_stride); 20 | } 21 | 22 | // ========================================================================= 23 | // Main function... 24 | 25 | template 26 | void find_points(const T* images, const int batch_count, const int image_height, const int image_width, const int strip_count, FeatureThresholds feature_thresholds, float* points_x, float* points_y, float* point_scores) 27 | { 28 | for (int batch_index = 0; batch_index < batch_count; ++batch_index) 29 | { 30 | const T* image = images + batch_index * c * image_height * image_width; 31 | 32 | for (int strip_index = 0; strip_index < strip_count; ++strip_index) 33 | { 34 | int image_y = 1 + (image_height - 2) / (1.0f + std::exp(-(strip_index - strip_count / 2.0f + 0.5f)/(strip_count / 8.0f))); 35 | 36 | for (int point_index = 0; point_index < 2; ++point_index) 37 | { 38 | bool flip = point_index > 0; 39 | 40 | float max_preceeding_intensity = 0.0f; 41 | float best_score = 0.0f; 42 | int best_index = 0; 43 | 44 | for (int x = 1; x < image_width / 2; ++x) 45 | { 46 | int image_x = flip ? image_width - 1 - x : x; 47 | 48 | float intensity = c == 3 ? load_element(image, image_x + image_y * image_width, image_width * image_height) : image[image_x + image_y * image_width]; 49 | max_preceeding_intensity = max_preceeding_intensity < intensity ? 
intensity : max_preceeding_intensity; 50 | 51 | float left = load_sobel_strip(image, (image_x - 1) + image_y * image_width, image_width, image_width * image_height); 52 | float right = load_sobel_strip(image, (image_x + 1) + image_y * image_width, image_width, image_width * image_height); 53 | float top = load_sobel_strip(image, image_x + (image_y - 1) * image_width, 1, image_width * image_height); 54 | float bot = load_sobel_strip(image, image_x + (image_y + 1) * image_width, 1, image_width * image_height); 55 | 56 | float grad_x = right - left; 57 | float grad_y = bot - top; 58 | float grad = sqrt(grad_x * grad_x + grad_y * grad_y); 59 | 60 | float center_dir_x = (0.5f * image_width) - (float)image_x; 61 | float center_dir_y = (0.5f * image_height) - (float)image_y; 62 | float center_dir_norm = sqrt(center_dir_x * center_dir_x + center_dir_y * center_dir_y); 63 | 64 | float dot = grad == 0 ? -1 : (center_dir_x * grad_x + center_dir_y * grad_y) / (center_dir_norm * grad); 65 | float angle = RAD2DEG * acos(dot); 66 | 67 | // ============================================================ 68 | // Final scoring... 
69 | 70 | float edge_score = tanh(grad / feature_thresholds.edge); 71 | float angle_score = 1.0f - tanh(angle / feature_thresholds.angle); 72 | float intensity_score = 1.0f - tanh(max_preceeding_intensity / feature_thresholds.intensity); 73 | 74 | float point_score = edge_score * angle_score * intensity_score; 75 | 76 | if (point_score > best_score) 77 | { 78 | best_score = point_score; 79 | best_index = image_x; 80 | } 81 | } 82 | 83 | if (best_index < DISCARD_BORDER || best_index >= image_width - DISCARD_BORDER) 84 | { 85 | best_score = 0.0f; 86 | } 87 | 88 | points_x[strip_index + point_index * strip_count + batch_index * 3 * 2 * strip_count] = best_index; 89 | points_y[strip_index + point_index * strip_count + batch_index * 3 * 2 * strip_count] = image_y; 90 | point_scores[strip_index + point_index * strip_count + batch_index * 3 * 2 * strip_count] = best_score; 91 | } 92 | } 93 | } 94 | } 95 | 96 | void find_points(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, FeatureThresholds feature_thresholds, float* points_x, float* points_y, float* point_scores) 97 | { 98 | FUNCTION_CALL_IMAGE_FORMAT(find_points, image, batch_count, image_height, image_width, strip_count, feature_thresholds, points_x, points_y, point_scores); 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /src/torchcontentarea/csrc/source/find_points_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../common.hpp" 3 | 4 | namespace cuda 5 | { 6 | // ========================================================================= 7 | // General functionality... 8 | 9 | template 10 | __device__ float load_element(const T* data, const int index, const int color_stride) 11 | { 12 | float value = c == 1 ? 
data[index] : 0.2126f * data[index + 0 * color_stride] + 0.7152f * data[index + 1 * color_stride] + 0.0722f * data[index + 2 * color_stride];
        return std::is_floating_point<T>::value ? 255 * value : value;
    }

    // 3x3 Sobel on a 3-row strip held in shared memory; returns the gradient
    // magnitude and writes the x/y components through the out-params.
    __device__ float sobel_filter(const float* data, const int index, const int x_stride, const int y_stride, float* x_grad, float* y_grad)
    {
        float left = 0.25f * data[index - x_stride - y_stride] + 0.5f * data[index - x_stride] + 0.25f * data[index - x_stride + y_stride];
        float right = 0.25f * data[index + x_stride - y_stride] + 0.5f * data[index + x_stride] + 0.25f * data[index + x_stride + y_stride];
        *x_grad = right - left;

        float top = 0.25f * data[index - x_stride - y_stride] + 0.5f * data[index - y_stride] + 0.25f * data[index + x_stride - y_stride];
        float bot = 0.25f * data[index - x_stride + y_stride] + 0.5f * data[index + y_stride] + 0.25f * data[index + x_stride + y_stride];
        *y_grad = bot - top;

        return sqrt(*x_grad * *x_grad + *y_grad * *y_grad);
    }

    // =========================================================================
    // Kernels...

    // One block per (side, strip, batch-item); each thread scores one column.
    template<int c, typename T>
    __global__ void find_points_kernel(const T* g_image_batch, float* g_edge_x_batch, float* g_edge_y_batch, float* g_edge_scores_batch, const int image_width, const int image_height, const int strip_count, const FeatureThresholds feature_thresholds)
    {
        constexpr int warp_size = 32;

        const T* g_image = g_image_batch + blockIdx.z * c * image_width * image_height;
        float* g_edge_x = g_edge_x_batch + blockIdx.z * 3 * 2 * strip_count;
        float* g_edge_y = g_edge_y_batch + blockIdx.z * 3 * 2 * strip_count;
        float* g_edge_scores = g_edge_scores_batch + blockIdx.z * 3 * 2 * strip_count;

        int thread_count = blockDim.x;
        int warp_count = 1 + (thread_count - 1) / warp_size;

        extern __shared__ int s_shared_buffer[];
        float* s_image_strip = (float*)s_shared_buffer;
        int* s_cross_warp_operation_buffer = s_shared_buffer + 3 * thread_count;
        float* s_cross_warp_operation_buffer_2 = (float*)(s_shared_buffer + 3 * thread_count + warp_count);

        int warp_index = threadIdx.x >> 5;
        int lane_index = threadIdx.x & 31;

        bool flip = blockIdx.x == 1;

        // ============================================================
        // Load strip into shared memory...

        int image_x = flip ? image_width - 1 - threadIdx.x : threadIdx.x;

        int strip_index = blockIdx.y;
        int strip_height = 1 + (image_height - 2) / (1.0f + exp(-(strip_index - strip_count / 2.0f + 0.5f)/(strip_count / 8.0f)));

        #pragma unroll
        for (int y = 0; y < 3; y++)
        {
            int image_element_index = image_x + (strip_height + (y - 1)) * image_width;
            s_image_strip[threadIdx.x + y * thread_count] = load_element<c>(g_image, image_element_index, image_width * image_height);
        }

        __syncthreads();

        // ============================================================
        // Calculate largest preceeding intensity (inclusive warp scan, then block combine)...

        float max_preceeding_intensity = s_image_strip[threadIdx.x + thread_count];

        #pragma unroll
        for (int d = 1; d < 32; d <<= 1)
        {
            float other_intensity = __shfl_up_sync(0xffffffff, max_preceeding_intensity, d);

            if (lane_index >= d && other_intensity > max_preceeding_intensity)
            {
                max_preceeding_intensity = other_intensity;
            }
        }

        if (lane_index == warp_size - 1)
        {
            // NOTE(review): the buffer is int*, so this float maximum is truncated
            // toward zero when stored — likely benign for 0-255 intensities, but confirm.
            s_cross_warp_operation_buffer[warp_index] = max_preceeding_intensity;
        }

        __syncthreads();

        if (warp_index == 0)
        {
            float warp_max = lane_index < warp_count ? s_cross_warp_operation_buffer[lane_index] : 0;

            #pragma unroll
            for (int d = 1; d < 32; d <<= 1)
            {
                float other_max = __shfl_up_sync(0xffffffff, warp_max, d);

                if (lane_index >= d && other_max > warp_max)
                {
                    warp_max = other_max;
                }
            }

            if (lane_index < warp_count)
            {
                s_cross_warp_operation_buffer[lane_index] = warp_max;
            }
        }

        __syncthreads();

        if (warp_index > 0)
        {
            float other_intensity = s_cross_warp_operation_buffer[warp_index - 1];
            max_preceeding_intensity = other_intensity > max_preceeding_intensity ? other_intensity : max_preceeding_intensity;
        }

        // ============================================================
        // Applying sobel kernel to image patch...

        float x_grad = 0;
        float y_grad = 0;
        float grad = 0;

        if (threadIdx.x > 0 && threadIdx.x < thread_count - 1)
        {
            grad = sobel_filter(s_image_strip, threadIdx.x + thread_count, 1, thread_count, &x_grad, &y_grad);
        }

        // ============================================================
        // Calculating angle between gradient vector and center vector...

        float center_dir_x = (0.5f * image_width) - (float)image_x;
        float center_dir_y = (0.5f * image_height) - (float)strip_height;
        float center_dir_norm = sqrt(center_dir_x * center_dir_x + center_dir_y * center_dir_y);

        // The strip was loaded mirrored when flip is set, so un-mirror the x gradient.
        x_grad = flip ? -x_grad : x_grad;

        float dot = grad == 0 ? -1 : (center_dir_x * x_grad + center_dir_y * y_grad) / (center_dir_norm * grad);
        float angle = RAD2DEG * acos(dot);

        // ============================================================
        // Final scoring...

        float edge_score = tanh(grad / feature_thresholds.edge);
        float angle_score = 1.0f - tanh(angle / feature_thresholds.angle);
        float intensity_score = 1.0f - tanh(max_preceeding_intensity / feature_thresholds.intensity);

        float point_score = edge_score * angle_score * intensity_score;

        // ============================================================
        // Reduction to find the best edge...

        int best_edge_x = image_x;
        float best_edge_score = point_score;

        // warp reduction....
        #pragma unroll
        for (int offset = warp_size >> 1; offset > 0; offset >>= 1)
        {
            int other_edge_x = __shfl_down_sync(0xffffffff, best_edge_x, offset);
            float other_edge_score = __shfl_down_sync(0xffffffff, best_edge_score, offset);

            if (other_edge_score > best_edge_score)
            {
                best_edge_x = other_edge_x;
                best_edge_score = other_edge_score;
            }
        }

        if (lane_index == 0)
        {
            s_cross_warp_operation_buffer[warp_index] = best_edge_x;
            s_cross_warp_operation_buffer_2[warp_index] = best_edge_score;
        }

        __syncthreads();

        // block reduction....
        if (warp_index == 0 && lane_index < warp_count)
        {
            best_edge_x = s_cross_warp_operation_buffer[lane_index];
            best_edge_score = s_cross_warp_operation_buffer_2[lane_index];

            int next_power_two = pow(2, ceil(log(warp_count)/log(2)));

            #pragma unroll
            for (int offset = next_power_two >> 1; offset > 0; offset >>= 1)
            {
                int other_edge_x = __shfl_down_sync(0xffffffff, best_edge_x, offset);
                float other_edge_score = __shfl_down_sync(0xffffffff, best_edge_score, offset);

                if (other_edge_score > best_edge_score)
                {
                    best_edge_x = other_edge_x;
                    best_edge_score = other_edge_score;
                }
            }

            if (lane_index == 0)
            {
                int point_index = flip ? strip_index : strip_index + strip_count;

                if (best_edge_x < DISCARD_BORDER || best_edge_x >= image_width - DISCARD_BORDER)
                {
                    best_edge_score = 0.0f;
                }

                g_edge_x[point_index] = best_edge_x;
                g_edge_y[point_index] = strip_height;
                g_edge_scores[point_index] = best_edge_score;
            }
        }
    }

    // =========================================================================
    // Main function...

    void find_points(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const FeatureThresholds feature_thresholds, float* points_x, float* points_y, float* point_scores)
    {
        // One thread per column in half the image, rounded up to whole warps.
        int half_width = image_width / 2;
        int warps = 1 + (half_width - 1) / 32;
        int threads = warps * 32;

        threads = threads > 1024 ? 1024 : threads;

        dim3 grid(2, strip_count, batch_count);
        dim3 block(threads);
        int shared_memory = (3 * threads + 2 * warps) * sizeof(int);

        KERNEL_DISPATCH_IMAGE_FORMAT(find_points_kernel, ARG(grid, block, shared_memory), image, points_x, points_y, point_scores, image_width, image_height, strip_count, feature_thresholds);
    }
}
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/source/find_points_from_scores_cpu.cpp:
--------------------------------------------------------------------------------
#include <cmath>  // NOTE(review): include target lost in dump
#include "../common.hpp"

namespace cpu
{
    // Converts per-strip model scores into one candidate border point per image
    // side per strip: the highest scoring column in each half of the strip.
    void find_points_from_strip_scores(const float* strips, const int batch_count, const int image_height, const int image_width, const int strip_count, const int model_patch_size, float* points_x, float* points_y, float* point_score)
    {
        for (int batch_index = 0; batch_index < batch_count; ++batch_index)
        {
            // NOTE(review): the CUDA counterpart uses (model_patch_size - 1) / 2 here,
            // which disagrees with this for even patch sizes — confirm which is intended.
            int half_patch_size = model_patch_size / 2;
            int strip_width = image_width - 2 * half_patch_size;

            for (int strip_index = 0; strip_index < strip_count; ++strip_index)
            {
                int image_y = 1 + (image_height - 2) / (1.0f + std::exp(-(strip_index - strip_count / 2.0f + 0.5f) / (strip_count / 8.0f)));

                float best_score = 0.0f;
                int best_index = 0;

                // Left half of the strip...
                for (int strip_x = 0; strip_x < strip_width / 2; ++strip_x)
                {
                    float score = strips[strip_x + strip_index * strip_width + batch_index * strip_count * strip_width];

                    if (score > best_score)
                    {
                        best_score = score;
                        best_index = strip_x + half_patch_size;
                    }
                }

                points_x[strip_index + batch_index * 3 * 2 * strip_count] = best_index;
                points_y[strip_index + batch_index * 3 * 2 * strip_count] = image_y;
                point_score[strip_index + batch_index * 3 * 2 * strip_count] = best_score;

                best_score = 0.0f;
                best_index = 0;

                // Right half of the strip...
                for (int strip_x = strip_width / 2; strip_x < strip_width; ++strip_x)
                {
                    float score = strips[strip_x + strip_index * strip_width + batch_index * strip_count * strip_width];

                    if (score > best_score)
                    {
                        best_score = score;
                        best_index = strip_x + half_patch_size;
                    }
                }

                points_x[strip_index + strip_count + batch_index * 3 * 2 * strip_count] = best_index;
                points_y[strip_index + strip_count + batch_index * 3 * 2 * strip_count] = image_y;
                point_score[strip_index + strip_count + batch_index * 3 * 2 * strip_count] = best_score;
            }
        }
    }
}
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/source/find_points_from_scores_cuda.cu:
--------------------------------------------------------------------------------
#include <cuda_runtime.h>  // NOTE(review): include target lost in dump
#include "../common.hpp"

namespace cuda
{
    // One block per (side, strip, batch-item); reduces to the highest scoring column.
    __global__ void find_best_edge(const float* g_score_strips_batch, float* g_edge_x_batch, float* g_edge_y_batch, float* g_edge_scores_batch, const int strip_width, const int image_height, const int strip_count, const int half_patch_size)
    {
        int thread_count = blockDim.x;
        int warp_count = 1 + (thread_count - 1) / 32;

        extern __shared__ float s_shared_buffer[];
        float* s_cross_warp_operation_buffer = s_shared_buffer;
        float* s_cross_warp_operation_buffer_2 = s_shared_buffer + warp_count;

        const float* g_score_strips = g_score_strips_batch + blockIdx.z * strip_count * strip_width;
        float* g_edge_x = g_edge_x_batch + blockIdx.z * 3 * 2 * strip_count;
        float* g_edge_y = g_edge_y_batch + blockIdx.z * 3 * 2 * strip_count;
        float* g_edge_scores = g_edge_scores_batch + blockIdx.z * 3 * 2 * strip_count;

        int warp_index = threadIdx.x >> 5;
        int lane_index = threadIdx.x & 31;

        bool flip = blockIdx.x == 1;

        // ============================================================
        // Load strip into shared memory...

        int image_x = flip ? strip_width - 1 - threadIdx.x : threadIdx.x;

        int strip_index = blockIdx.y;
        int strip_height = 1 + (image_height - 2) / (1.0f + exp(-(strip_index - strip_count / 2.0f + 0.5f)/(strip_count / 8.0f)));

        float point_score = g_score_strips[image_x + strip_index * strip_width];

        int best_edge_x = image_x;
        float best_edge_score = point_score;

        // warp reduction....
        #pragma unroll
        for (int offset = 32 >> 1; offset > 0; offset >>= 1)
        {
            int other_edge_x = __shfl_down_sync(0xffffffff, best_edge_x, offset);
            float other_edge_score = __shfl_down_sync(0xffffffff, best_edge_score, offset);

            if (other_edge_score > best_edge_score)
            {
                best_edge_x = other_edge_x;
                best_edge_score = other_edge_score;
            }
        }

        if (lane_index == 0)
        {
            s_cross_warp_operation_buffer[warp_index] = best_edge_x;
            s_cross_warp_operation_buffer_2[warp_index] = best_edge_score;
        }

        __syncthreads();

        // block reduction....
        if (warp_index == 0 && lane_index < warp_count)
        {
            best_edge_x = s_cross_warp_operation_buffer[lane_index];
            best_edge_score = s_cross_warp_operation_buffer_2[lane_index];

            int next_power_two = pow(2, ceil(log(warp_count)/log(2)));

            #pragma unroll
            for (int offset = next_power_two >> 1; offset > 0; offset >>= 1)
            {
                int other_edge_x = __shfl_down_sync(0xffffffff, best_edge_x, offset);
                float other_edge_score = __shfl_down_sync(0xffffffff, best_edge_score, offset);

                if (other_edge_score > best_edge_score)
                {
                    best_edge_x = other_edge_x;
                    best_edge_score = other_edge_score;
                }
            }

            if (lane_index == 0)
            {
                int point_index = flip ? strip_index : strip_index + strip_count;

                g_edge_x[point_index] = best_edge_x + half_patch_size;
                g_edge_y[point_index] = strip_height;
                g_edge_scores[point_index] = best_edge_score;
            }
        }
    }

    void find_points_from_strip_scores(const float* strips, const int batch_count, const int image_height, const int image_width, const int strip_count, const int model_patch_size, float* points_x, float* points_y, float* point_score)
    {
        // NOTE(review): the CPU version computes model_patch_size / 2 here; the two
        // disagree for even patch sizes — confirm which is intended.
        int half_patch_size = (model_patch_size - 1) / 2;
        int strip_width = image_width - 2 * half_patch_size;

        int half_width = strip_width / 2;
        int warps = 1 + (half_width - 1) / 32;
        int threads = warps * 32;

        threads = threads > 1024 ? 1024 : threads;

        dim3 grid(2, strip_count, batch_count);
        dim3 block(threads);
        int shared_memory = 2 * warps * sizeof(int);
        find_best_edge<<<grid, block, shared_memory>>>(strips, points_x, points_y, point_score, strip_width, image_height, strip_count, half_patch_size);
    }
}
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/source/fit_circle_cpu.cpp:
--------------------------------------------------------------------------------
#include <cmath>  // NOTE(review): include target lost in dump
#include "../common.hpp"

// =========================================================================
// General functionality...
6 | 7 | namespace cpu 8 | { 9 | 10 | int fast_rand(int& seed) 11 | { 12 | seed = 214013 * seed + 2531011; 13 | return (seed >> 16) & 0x7FFF; 14 | } 15 | 16 | void rand_triplet(int seed, int seed_stride, int max, int* triplet) 17 | { 18 | triplet[0] = fast_rand(seed) % max; 19 | do {triplet[1] = fast_rand(seed) % max;} while (triplet[1] == triplet[0]); 20 | do {triplet[2] = fast_rand(seed) % max;} while (triplet[2] == triplet[0] || triplet[2] == triplet[1]); 21 | } 22 | 23 | bool check_circle(float x, float y, float r, int image_width, int image_height) 24 | { 25 | float x_diff = x - 0.5 * image_width; 26 | float y_diff = y - 0.5 * image_height; 27 | float diff = sqrt(x_diff * x_diff + y_diff * y_diff); 28 | 29 | bool valid = true; 30 | valid &= diff < MAX_CENTER_DIST * image_width; 31 | valid &= r > MIN_RADIUS * image_width; 32 | valid &= r < MAX_RADIUS * image_width; 33 | 34 | return valid; 35 | } 36 | 37 | bool calculate_circle(float ax, float ay, float bx, float by, float cx, float cy, float* x, float* y, float* r) 38 | { 39 | float offset = bx * bx + by * by; 40 | 41 | float bc = 0.5f * (ax * ax + ay * ay - offset); 42 | float cd = 0.5f * (offset - cx * cx - cy * cy); 43 | 44 | float det = (ax - bx) * (by - cy) - (bx - cx) * (ay - by); 45 | 46 | bool valid = abs(det) > 1e-8; 47 | 48 | if (valid) 49 | { 50 | float idet = 1.0f / det; 51 | 52 | *x = (bc * (by - cy) - cd * (ay - by)) * idet; 53 | *y = (cd * (ax - bx) - bc * (bx - cx)) * idet; 54 | *r = sqrt((bx - *x) * (bx - *x) + (by - *y) * (by - *y)); 55 | } 56 | 57 | return valid; 58 | } 59 | 60 | bool Cholesky3x3(float lhs[3][3], float rhs[3]) 61 | { 62 | float sum; 63 | float diagonal[3]; 64 | 65 | sum = lhs[0][0]; 66 | 67 | if (sum <= 0.f) 68 | return false; 69 | 70 | diagonal[0] = sqrt(sum); 71 | 72 | sum = lhs[0][1]; 73 | lhs[1][0] = sum / diagonal[0]; 74 | 75 | sum = lhs[0][2]; 76 | lhs[2][0] = sum / diagonal[0]; 77 | 78 | sum = lhs[1][1] - lhs[1][0] * lhs[1][0]; 79 | 80 | if (sum <= 0.f) 81 | return 
false; 82 | 83 | diagonal[1] = sqrt(sum); 84 | 85 | sum = lhs[1][2] - lhs[1][0] * lhs[2][0]; 86 | lhs[2][1] = sum / diagonal[1]; 87 | 88 | sum = lhs[2][2] - lhs[2][1] * lhs[2][1] - lhs[2][0] * lhs[2][0]; 89 | 90 | if (sum <= 0.f) 91 | return false; 92 | 93 | diagonal[2] = sqrt(sum); 94 | 95 | sum = rhs[0]; 96 | rhs[0] = sum / diagonal[0]; 97 | 98 | sum = rhs[1] - lhs[1][0] * rhs[0]; 99 | rhs[1] = sum / diagonal[1]; 100 | 101 | sum = rhs[2] - lhs[2][1] * rhs[1] - lhs[2][0] * rhs[0]; 102 | rhs[2] = sum / diagonal[2]; 103 | 104 | sum = rhs[2]; 105 | rhs[2] = sum / diagonal[2]; 106 | 107 | sum = rhs[1] - lhs[2][1] * rhs[2]; 108 | rhs[1] = sum / diagonal[1]; 109 | 110 | sum = rhs[0] - lhs[1][0] * rhs[1] - lhs[2][0] * rhs[2]; 111 | rhs[0] = sum / diagonal[0]; 112 | 113 | return true; 114 | } 115 | 116 | void get_circle(int point_count, int* indices, int* points_x, int* points_y, float* circle_x, float* circle_y, float* circle_r) 117 | { 118 | if (point_count == 3) 119 | { 120 | int a = indices[0]; 121 | int b = indices[1]; 122 | int c = indices[2]; 123 | 124 | calculate_circle(points_x[a], points_y[a], points_x[b], points_y[b], points_x[c], points_y[c], circle_x, circle_y, circle_r); 125 | } 126 | else 127 | { 128 | 129 | float lhs[3][3] {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; 130 | float rhs[3] {0, 0, 0}; 131 | 132 | for (int i = 0; i < point_count; i++) 133 | { 134 | float p_x = points_x[indices[i]]; 135 | float p_y = points_y[indices[i]]; 136 | 137 | lhs[0][0] += p_x * p_x; 138 | lhs[0][1] += p_x * p_y; 139 | lhs[1][1] += p_y * p_y; 140 | lhs[0][2] += p_x; 141 | lhs[1][2] += p_y; 142 | lhs[2][2] += 1; 143 | 144 | rhs[0] += p_x * p_x * p_x + p_x * p_y * p_y; 145 | rhs[1] += p_x * p_x * p_y + p_y * p_y * p_y; 146 | rhs[2] += p_x * p_x + p_y * p_y; 147 | } 148 | 149 | Cholesky3x3(lhs, rhs); 150 | 151 | float A=rhs[0], B=rhs[1], C=rhs[2]; 152 | 153 | *circle_x = A / 2.0f; 154 | *circle_y = B / 2.0f; 155 | *circle_r = std::sqrt(4.0f * C + A * A + B * B) / 2.0f; 156 | } 157 | } 
158 | 159 | // ========================================================================= 160 | // Main function... 161 | 162 | void fit_circle(const float* points_x, const float* points_y, const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results) 163 | { 164 | int* compacted_points = (int*)malloc(3 * point_count * sizeof(int)); 165 | int* compacted_points_x = compacted_points + 0 * point_count; 166 | int* compacted_points_y = compacted_points + 1 * point_count; 167 | float* compacted_points_s = (float*)compacted_points + 2 * point_count; 168 | 169 | 170 | for (int batch_index = 0; batch_index < batch_count; ++batch_index) 171 | { 172 | // Point compaction... 173 | int real_point_count = 0; 174 | for (int i = 0; i < point_count; ++i) 175 | { 176 | if (points_score[i] > confidence_thresholds.edge) 177 | { 178 | compacted_points_x[real_point_count] = points_x[i + batch_index * 3 * point_count]; 179 | compacted_points_y[real_point_count] = points_y[i + batch_index * 3 * point_count]; 180 | compacted_points_s[real_point_count] = points_score[i + batch_index * 3 * point_count]; 181 | real_point_count += 1; 182 | } 183 | } 184 | 185 | results[0 + batch_index * 4] = 0.0f; 186 | results[1 + batch_index * 4] = 0.0f; 187 | results[2 + batch_index * 4] = 0.0f; 188 | results[3 + batch_index * 4] = 0.0f; 189 | 190 | // Early out... 191 | if (real_point_count < 3) 192 | { 193 | return; 194 | } 195 | 196 | // Ransac attempts... 
197 | for (int ransac_attempt = 0; ransac_attempt < RANSAC_ATTEMPTS; ++ransac_attempt) 198 | { 199 | int inlier_count = 3; 200 | int inliers[MAX_POINT_COUNT]; 201 | rand_triplet(ransac_attempt * 42342, RANSAC_ATTEMPTS, real_point_count, inliers); 202 | 203 | float circle_x, circle_y, circle_r; 204 | float circle_score = 0.0f; 205 | 206 | for (int i = 0; i < RANSAC_ITERATIONS; i++) 207 | { 208 | get_circle(inlier_count, inliers, compacted_points_x, compacted_points_y, &circle_x, &circle_y, &circle_r); 209 | 210 | inlier_count = 0; 211 | circle_score = 0.0f; 212 | 213 | for (int point_index = 0; point_index < real_point_count; point_index++) 214 | { 215 | int edge_x = compacted_points_x[point_index]; 216 | int edge_y = compacted_points_y[point_index]; 217 | float edge_score = compacted_points_s[point_index]; 218 | 219 | float delta_x = circle_x - edge_x; 220 | float delta_y = circle_y - edge_y; 221 | 222 | float delta = std::sqrt(delta_x * delta_x + delta_y * delta_y); 223 | float error = std::abs(circle_r - delta); 224 | 225 | if (error < RANSAC_INLIER_THRESHOLD) 226 | { 227 | circle_score += edge_score; 228 | 229 | inliers[inlier_count] = point_index; 230 | inlier_count++; 231 | } 232 | } 233 | 234 | circle_score /= point_count; 235 | } 236 | 237 | bool circle_valid = check_circle(circle_x, circle_y, circle_r, image_width, image_height); 238 | 239 | if (circle_valid && circle_score > results[3 + batch_index * 4]) 240 | { 241 | results[0 + batch_index * 4] = circle_x; 242 | results[1 + batch_index * 4] = circle_y; 243 | results[2 + batch_index * 4] = circle_r; 244 | results[3 + batch_index * 4] = circle_score; 245 | } 246 | } 247 | } 248 | 249 | free(compacted_points); 250 | } 251 | } 252 | -------------------------------------------------------------------------------- /src/torchcontentarea/csrc/source/fit_circle_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../common.hpp" 3 | 4 | namespace cuda 
{
    // =========================================================================
    // General functionality...

    // Cheap LCG (MSVC rand constants) returning 15-bit values; advances seed in place.
    __device__ int fast_rand(int& seed)
    {
        seed = 214013 * seed + 2531011;
        return (seed >> 16) & 0x7FFF;
    }

    // Picks three distinct indices in [0, max).
    __device__ void rand_triplet(int seed, int seed_stride, int max, int* triplet)
    {
        triplet[0] = fast_rand(seed) % max;
        do {triplet[1] = fast_rand(seed) % max;} while (triplet[1] == triplet[0]);
        do {triplet[2] = fast_rand(seed) % max;} while (triplet[2] == triplet[0] || triplet[2] == triplet[1]);
    }

    // Sanity check: center close to the image center, radius within expected bounds.
    __device__ bool check_circle(float x, float y, float r, int image_width, int image_height)
    {
        float x_diff = x - 0.5 * image_width;
        float y_diff = y - 0.5 * image_height;
        float diff = sqrt(x_diff * x_diff + y_diff * y_diff);

        bool valid = true;
        valid &= diff < MAX_CENTER_DIST * image_width;
        valid &= r > MIN_RADIUS * image_width;
        valid &= r < MAX_RADIUS * image_width;

        return valid;
    }

    // Circumcircle of three points; returns false for (near) collinear input,
    // in which case the outputs are left untouched.
    __device__ bool calculate_circle(float ax, float ay, float bx, float by, float cx, float cy, float* x, float* y, float* r)
    {
        float offset = bx * bx + by * by;

        float bc = 0.5f * (ax * ax + ay * ay - offset);
        float cd = 0.5f * (offset - cx * cx - cy * cy);

        float det = (ax - bx) * (by - cy) - (bx - cx) * (ay - by);

        bool valid = abs(det) > 1e-8;

        if (valid)
        {
            float idet = 1.0f / det;

            *x = (bc * (by - cy) - cd * (ay - by)) * idet;
            *y = (cd * (ax - bx) - bc * (bx - cx)) * idet;
            *r = sqrt((bx - *x) * (bx - *x) + (by - *y) * (by - *y));
        }

        return valid;
    }

    // Solves the 3x3 symmetric positive-definite system lhs * x = rhs in place
    // via Cholesky (LL^T); returns false if the matrix is not positive definite.
    __device__ bool Cholesky3x3(float lhs[3][3], float rhs[3])
    {
        float sum;
        float diagonal[3];

        sum = lhs[0][0];

        if (sum <= 0.f)
            return false;

        diagonal[0] = sqrt(sum);

        sum = lhs[0][1];
        lhs[1][0] = sum / diagonal[0];

        sum = lhs[0][2];
        lhs[2][0] = sum / diagonal[0];

        sum = lhs[1][1] - lhs[1][0] * lhs[1][0];

        if (sum <= 0.f)
            return false;

        diagonal[1] = sqrt(sum);

        sum = lhs[1][2] - lhs[1][0] * lhs[2][0];
        lhs[2][1] = sum / diagonal[1];

        sum = lhs[2][2] - lhs[2][1] * lhs[2][1] - lhs[2][0] * lhs[2][0];

        if (sum <= 0.f)
            return false;

        diagonal[2] = sqrt(sum);

        // Forward substitution (L y = rhs)...
        sum = rhs[0];
        rhs[0] = sum / diagonal[0];

        sum = rhs[1] - lhs[1][0] * rhs[0];
        rhs[1] = sum / diagonal[1];

        sum = rhs[2] - lhs[2][1] * rhs[1] - lhs[2][0] * rhs[0];
        rhs[2] = sum / diagonal[2];

        // Backward substitution (L^T x = y)...
        sum = rhs[2];
        rhs[2] = sum / diagonal[2];

        sum = rhs[1] - lhs[2][1] * rhs[2];
        rhs[1] = sum / diagonal[1];

        sum = rhs[0] - lhs[1][0] * rhs[1] - lhs[2][0] * rhs[2];
        rhs[0] = sum / diagonal[0];

        return true;
    }

    // Exact circumcircle for 3 points; otherwise an algebraic least-squares fit
    // (x^2 + y^2 = A x + B y + C, solved via Cholesky).
    __device__ void get_circle(int point_count, int* indices, int* points_x, int* points_y, float* circle_x, float* circle_y, float* circle_r)
    {
        if (point_count == 3)
        {
            int a = indices[0];
            int b = indices[1];
            int c = indices[2];

            calculate_circle(points_x[a], points_y[a], points_x[b], points_y[b], points_x[c], points_y[c], circle_x, circle_y, circle_r);
        }
        else
        {

            float lhs[3][3] {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
            float rhs[3] {0, 0, 0};

            for (int i = 0; i < point_count; i++)
            {
                float p_x = points_x[indices[i]];
                float p_y = points_y[indices[i]];

                lhs[0][0] += p_x * p_x;
                lhs[0][1] += p_x * p_y;
                lhs[1][1] += p_y * p_y;
                lhs[0][2] += p_x;
                lhs[1][2] += p_y;
                lhs[2][2] += 1;

                rhs[0] += p_x * p_x * p_x + p_x * p_y * p_y;
                rhs[1] += p_x * p_x * p_y + p_y * p_y * p_y;
                rhs[2] += p_x * p_x + p_y * p_y;
            }

            Cholesky3x3(lhs, rhs);

            float A=rhs[0], B=rhs[1], C=rhs[2];

            *circle_x = A / 2.0f;
            *circle_y = B / 2.0f;
            *circle_r = sqrt(4.0f * C + A * A + B * B) / 2.0f;
        }
    }

    // =========================================================================
    // Kernels...

    // One block per batch item; each thread runs one RANSAC attempt, then the
    // best circle is selected by reduction.
    template<int warp_count>
    __global__ void fit_circle_kernel(const float* g_edge_x_batch, const float* g_edge_y_batch, const float* g_edge_scores_batch, float* g_circle_batch, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width)
    {
        extern __shared__ int s_edge_info[];
        __shared__ int s_valid_point_count;
        __shared__ float s_score_reduction_buffer[warp_count];
        __shared__ float s_x_reduction_buffer[warp_count];
        __shared__ float s_y_reduction_buffer[warp_count];
        __shared__ float s_r_reduction_buffer[warp_count];

        const float* g_edge_x = g_edge_x_batch + blockIdx.x * 3 * point_count;
        const float* g_edge_y = g_edge_y_batch + blockIdx.x * 3 * point_count;
        const float* g_edge_scores = g_edge_scores_batch + blockIdx.x * 3 * point_count;
        float* g_circle = g_circle_batch + blockIdx.x * 4;

        int* s_edge_x = (int*)(s_edge_info + 0 * point_count);
        int* s_edge_y = (int*)(s_edge_info + 1 * point_count);
        float* s_edge_scores = (float*)(s_edge_info + 2 * point_count);

        const int warp_index = threadIdx.x >> 5;
        const int lane_index = threadIdx.x & 31;

        // Loading points to shared memory...
        if (threadIdx.x < point_count)
        {
            s_edge_x[threadIdx.x] = g_edge_x[threadIdx.x];
            s_edge_y[threadIdx.x] = g_edge_y[threadIdx.x];
            s_edge_scores[threadIdx.x] = g_edge_scores[threadIdx.x];
        }

        // Point compaction...
        bool has_point = threadIdx.x < point_count ?
// Fits a circle to one batch item's candidate edge points with RANSAC.
// One block per batch item; each thread refines one random-triplet proposal
// and the block-wide best (by inlier score) is written to g_circle as
// (x, y, r, score).
//
// Dynamic shared memory: 3 * point_count 32-bit words holding the compacted
// x, y and score arrays (see s_edge_x / s_edge_y / s_edge_scores below).
// NOTE(review): template parameter list reconstructed ("template" header was
// garbled in extraction) -- matches the use of warp_count below.
template <int warp_count>
__global__ void fit_circle_kernel(const float* g_edge_x_batch, const float* g_edge_y_batch, const float* g_edge_scores_batch, float* g_circle_batch, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width)
{
    extern __shared__ int s_edge_info[];
    __shared__ int s_valid_point_count;
    __shared__ float s_score_reduction_buffer[warp_count];
    __shared__ float s_x_reduction_buffer[warp_count];
    __shared__ float s_y_reduction_buffer[warp_count];
    __shared__ float s_r_reduction_buffer[warp_count];

    const float* g_edge_x = g_edge_x_batch + blockIdx.x * 3 * point_count;
    const float* g_edge_y = g_edge_y_batch + blockIdx.x * 3 * point_count;
    const float* g_edge_scores = g_edge_scores_batch + blockIdx.x * 3 * point_count;
    float* g_circle = g_circle_batch + blockIdx.x * 4;

    int* s_edge_x = (int*)(s_edge_info + 0 * point_count);
    int* s_edge_y = (int*)(s_edge_info + 1 * point_count);
    float* s_edge_scores = (float*)(s_edge_info + 2 * point_count);

    const int warp_index = threadIdx.x >> 5;
    const int lane_index = threadIdx.x & 31;

    // Load this thread's point into registers. Shared memory is only written
    // during compaction below: writing from registers fixes the race the old
    // in-place shuffle had (thread k could overwrite slot i while thread i
    // in another warp was still reading it).
    int my_x = 0, my_y = 0;
    float my_score = 0.0f;
    bool has_point = false;

    if (threadIdx.x < point_count)
    {
        my_x = g_edge_x[threadIdx.x];
        my_y = g_edge_y[threadIdx.x];
        my_score = g_edge_scores[threadIdx.x];
        has_point = my_score > confidence_thresholds.edge;
    }

    // Inclusive prefix sum of has_point within each warp (shuffle scan).
    // FIX: the counters are ints; they were previously shuffled through floats.
    int preceeding_count = has_point;

    #pragma unroll
    for (int d = 1; d < 32; d <<= 1)
    {
        int other_count = __shfl_up_sync(0xffffffff, preceeding_count, d);

        if (lane_index >= d)
        {
            preceeding_count += other_count;
        }
    }

    // Last lane of each warp publishes the warp's total.
    if (lane_index == 31)
    {
        s_score_reduction_buffer[warp_index] = preceeding_count;
    }

    __syncthreads();

    // Warp 0 turns the per-warp totals into an inclusive prefix sum.
    if (warp_index == 0)
    {
        int warp_sum = lane_index < warp_count ? (int)s_score_reduction_buffer[lane_index] : 0;

        #pragma unroll
        for (int d = 1; d < 32; d <<= 1)
        {
            int other_warp_sum = __shfl_up_sync(0xffffffff, warp_sum, d);

            // FIX: this scan must accumulate (+=); it previously took the
            // maximum, which under-counts the offsets of later warps.
            if (lane_index >= d)
            {
                warp_sum += other_warp_sum;
            }
        }

        if (lane_index < warp_count)
        {
            s_score_reduction_buffer[lane_index] = warp_sum;
        }
    }

    __syncthreads();

    if (warp_index > 0)
    {
        preceeding_count += s_score_reduction_buffer[warp_index - 1];
    }

    // Compact the confident points to the front of shared memory. Each
    // has_point thread owns a unique slot (its exclusive prefix count).
    if (has_point)
    {
        s_edge_x[preceeding_count - 1] = my_x;
        s_edge_y[preceeding_count - 1] = my_y;
        s_edge_scores[preceeding_count - 1] = my_score;
    }

    // The last thread's inclusive count is the total number of confident points.
    if (threadIdx.x == blockDim.x - 1)
    {
        s_valid_point_count = preceeding_count;
    }

    __syncthreads();

    // Too few confident points to define a circle: report an empty result.
    if (s_valid_point_count < 3)
    {
        if (threadIdx.x < 4)
        {
            g_circle[threadIdx.x] = 0.0f;
        }

        return;
    }

    // Each thread runs one RANSAC attempt seeded from its index.
    int inlier_count = 3;
    int inliers[MAX_POINT_COUNT];
    rand_triplet(threadIdx.x * 42342, RANSAC_ATTEMPTS, s_valid_point_count, inliers);

    float circle_x, circle_y, circle_r;
    float circle_score = 0.0f;

    for (int i = 0; i < RANSAC_ITERATIONS; i++)
    {
        // Fit to the current inlier set, then re-classify all points.
        get_circle(inlier_count, inliers, s_edge_x, s_edge_y, &circle_x, &circle_y, &circle_r);

        inlier_count = 0;
        circle_score = 0.0f;

        for (int point_index = 0; point_index < s_valid_point_count; point_index++)
        {
            int edge_x = s_edge_x[point_index];
            int edge_y = s_edge_y[point_index];
            float edge_score = s_edge_scores[point_index];

            float delta_x = circle_x - edge_x;
            float delta_y = circle_y - edge_y;

            float delta = sqrt(delta_x * delta_x + delta_y * delta_y);
            float error = abs(circle_r - delta);

            if (error < RANSAC_INLIER_THRESHOLD)
            {
                circle_score += edge_score;

                inliers[inlier_count] = point_index;
                inlier_count++;
            }
        }

        // NOTE(review): normalised by the total candidate count, while the
        // python implementation divides by the number of *confident* points
        // (s_valid_point_count) -- confirm which is intended before changing.
        circle_score /= point_count;
    }

    // Reject circles that are implausible for a content area.
    bool circle_valid = check_circle(circle_x, circle_y, circle_r, image_width, image_height);

    if (!circle_valid)
    {
        circle_score = 0;
    }

    //#################################
    // Reduction: block-wide arg-max of circle_score.

    // Stage 1: arg-max within each warp via shuffles.
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2)
    {
        float other_circle_score = __shfl_down_sync(0xffffffff, circle_score, offset);
        float other_circle_x = __shfl_down_sync(0xffffffff, circle_x, offset);
        float other_circle_y = __shfl_down_sync(0xffffffff, circle_y, offset);
        float other_circle_r = __shfl_down_sync(0xffffffff, circle_r, offset);

        if (other_circle_score > circle_score)
        {
            circle_score = other_circle_score;
            circle_x = other_circle_x;
            circle_y = other_circle_y;
            circle_r = other_circle_r;
        }
    }

    // Lane 0 of each warp publishes the warp's winner.
    if (lane_index == 0)
    {
        s_score_reduction_buffer[warp_index] = circle_score;
        s_x_reduction_buffer[warp_index] = circle_x;
        s_y_reduction_buffer[warp_index] = circle_y;
        s_r_reduction_buffer[warp_index] = circle_r;
    }

    __syncthreads();

    // Stage 2: warp 0 reduces the per-warp winners.
    if (warp_index == 0)
    {
        // FIX: each lane must read a *different* warp's winner (index by
        // lane_index, not warp_index, which is always 0 here), and every lane
        // of the warp must reach the shuffles below so the full 0xffffffff
        // sync mask is valid; out-of-range lanes contribute a zero score.
        bool in_range = lane_index < warp_count;
        circle_score = in_range ? s_score_reduction_buffer[lane_index] : 0.0f;
        circle_x = in_range ? s_x_reduction_buffer[lane_index] : 0.0f;
        circle_y = in_range ? s_y_reduction_buffer[lane_index] : 0.0f;
        circle_r = in_range ? s_r_reduction_buffer[lane_index] : 0.0f;

        #pragma unroll
        for (int offset = 16; offset > 0; offset /= 2)
        {
            float other_circle_score = __shfl_down_sync(0xffffffff, circle_score, offset);
            float other_circle_x = __shfl_down_sync(0xffffffff, circle_x, offset);
            float other_circle_y = __shfl_down_sync(0xffffffff, circle_y, offset);
            float other_circle_r = __shfl_down_sync(0xffffffff, circle_r, offset);

            if (other_circle_score > circle_score)
            {
                circle_score = other_circle_score;
                circle_x = other_circle_x;
                circle_y = other_circle_y;
                circle_r = other_circle_r;
            }
        }

        if (lane_index == 0)
        {
            g_circle[0] = circle_x;
            g_circle[1] = circle_y;
            g_circle[2] = circle_r;
            g_circle[3] = circle_score;
        }
    }
}
376 | 377 | #define ransac_threads RANSAC_ATTEMPTS 378 | #define ransac_warps (1 + (RANSAC_ATTEMPTS - 1) / 32) 379 | 380 | void fit_circle(const float* points_x, const float* points_y, const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results) 381 | { 382 | dim3 grid(batch_count); 383 | dim3 block(point_count); 384 | fit_circle_kernel<<>>(points_x, points_y, points_score, results, point_count, confidence_thresholds, image_height, image_width); 385 | } 386 | 387 | } 388 | -------------------------------------------------------------------------------- /src/torchcontentarea/csrc/source/make_strips_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../common.hpp" 3 | 4 | namespace cpu 5 | { 6 | template 7 | void make_strips(const T* image, const int batch_count, const int image_height, const int image_width, const int strip_count, const int strip_width, float* strips) 8 | { 9 | for (int batch_index = 0; batch_index < batch_count; ++batch_index) 10 | { 11 | for (int strip_index = 0; strip_index < strip_count; ++strip_index) 12 | { 13 | int strip_height = 1 + (image_height - 2) / (1.0f + std::exp(-(strip_index - strip_count / 2.0f + 0.5f)/(strip_count / 8.0f))); 14 | int strip_offset = strip_index * 5 * image_width * strip_width; 15 | 16 | for (int image_x = 0; image_x < image_width; ++image_x) 17 | { 18 | for (int strip_y = 0; strip_y < strip_width; ++strip_y) 19 | { 20 | int image_y = strip_height + strip_y - (strip_width - 1) / 2; 21 | 22 | int image_pixel_index = image_x + image_y * image_width; 23 | int strip_pixel_index = strip_offset + image_x + strip_y * image_width; 24 | 25 | float r, g, b; 26 | float norm = std::is_floating_point::value ? 
1.0f : 255.0f; 27 | 28 | if (c == 3) 29 | { 30 | r = (image[image_pixel_index + 0 * image_width * image_height + batch_index * 3 * image_width * image_height]/norm - 0.3441f) / 0.2381f; 31 | g = (image[image_pixel_index + 1 * image_width * image_height + batch_index * 3 * image_width * image_height]/norm - 0.2251f) / 0.1994f; 32 | b = (image[image_pixel_index + 2 * image_width * image_height + batch_index * 3 * image_width * image_height]/norm - 0.2203f) / 0.1939f; 33 | } 34 | else 35 | { 36 | r = (image[image_pixel_index + batch_index * 3 * image_width * image_height]/norm - 0.3441f) / 0.2381f; 37 | g = (image[image_pixel_index + batch_index * 3 * image_width * image_height]/norm - 0.2251f) / 0.1994f; 38 | b = (image[image_pixel_index + batch_index * 3 * image_width * image_height]/norm - 0.2203f) / 0.1939f; 39 | } 40 | 41 | float x = ((float)image_x / image_width) - 0.5f; 42 | float y = ((float)image_y / image_height) - 0.5f; 43 | 44 | strips[strip_pixel_index + 0 * image_width * strip_width + batch_index * strip_count * 5 * strip_width * image_width] = r; 45 | strips[strip_pixel_index + 1 * image_width * strip_width + batch_index * strip_count * 5 * strip_width * image_width] = g; 46 | strips[strip_pixel_index + 2 * image_width * strip_width + batch_index * strip_count * 5 * strip_width * image_width] = b; 47 | strips[strip_pixel_index + 3 * image_width * strip_width + batch_index * strip_count * 5 * strip_width * image_width] = x; 48 | strips[strip_pixel_index + 4 * image_width * strip_width + batch_index * strip_count * 5 * strip_width * image_width] = y; 49 | } 50 | } 51 | } 52 | } 53 | } 54 | 55 | 56 | void make_strips(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const int strip_width, float* strips) 57 | { 58 | FUNCTION_CALL_IMAGE_FORMAT(make_strips, image, batch_count, image_height, image_width, strip_count, strip_width, strips); 59 | } 60 | } 61 | 
namespace cuda
{
    // Sample height for a strip: follows a sigmoid over the strip index,
    // mapped into [1, image_height - 1) (mirrors the CPU implementation).
    __device__ int get_strip_height(int strip_index, int strip_count, int image_height)
    {
        return 1 + (image_height - 2) / (1.0f + exp(-(strip_index - strip_count / 2.0f + 0.5f) / (strip_count / 8.0f)));
    }

    // One thread per strip pixel: grid = (x tiles, strip, batch),
    // block = (128, strip_width). Each output strip has 5 planes: r/g/b
    // normalised with fixed per-channel mean/std, plus the pixel's x and y
    // position normalised to [-0.5, 0.5).
    //
    // NOTE(review): the template parameter list and the is_floating_point
    // argument were garbled in extraction and are reconstructed from the
    // uses of T and c (consistent with the CPU implementation) -- verify
    // against version control.
    template<typename T, int c>
    __global__ void make_strips_kernel(const T* g_image_batch, const int image_width, const int image_height, const int strip_count, const int strip_width, float* g_strips_batch)
    {
        const T* g_image = g_image_batch + blockIdx.z * c * image_width * image_height;
        float* g_strips = g_strips_batch + blockIdx.z * strip_count * 5 * image_width * strip_width;

        int strip_index = blockIdx.y;
        int strip_offset = strip_index * 5 * image_width * strip_width;
        int strip_height = get_strip_height(strip_index, strip_count, image_height);

        int image_x = threadIdx.x + blockIdx.x * blockDim.x;
        int strip_y = threadIdx.y;

        if (image_x >= image_width)
            return;

        // NOTE(review): image_y is not clamped; for strip_width > 3 the
        // outermost strips could index outside the image -- confirm the
        // strip widths callers actually use.
        int image_y = strip_height + strip_y - (strip_width - 1) / 2;

        int image_pixel_index = image_x + image_y * image_width;
        int strip_pixel_index = strip_offset + image_x + strip_y * image_width;

        float r, g, b;

        // Integer images are 0-255, floating point images 0-1.
        float norm = std::is_floating_point<T>::value ? 1.0f : 255.0f;

        if (c == 3)
        {
            r = (g_image[image_pixel_index + 0 * image_width * image_height]/norm - 0.3441f) / 0.2381f;
            g = (g_image[image_pixel_index + 1 * image_width * image_height]/norm - 0.2251f) / 0.1994f;
            b = (g_image[image_pixel_index + 2 * image_width * image_height]/norm - 0.2203f) / 0.1939f;
        }
        else
        {
            // Grayscale: replicate the single channel through all three
            // normalisations.
            r = (g_image[image_pixel_index]/norm - 0.3441f) / 0.2381f;
            g = (g_image[image_pixel_index]/norm - 0.2251f) / 0.1994f;
            b = (g_image[image_pixel_index]/norm - 0.2203f) / 0.1939f;
        }

        float x = ((float)image_x / image_width) - 0.5f;
        float y = ((float)image_y / image_height) - 0.5f;

        g_strips[strip_pixel_index + 0 * image_width * strip_width] = r;
        g_strips[strip_pixel_index + 1 * image_width * strip_width] = g;
        g_strips[strip_pixel_index + 2 * image_width * strip_width] = b;
        g_strips[strip_pixel_index + 3 * image_width * strip_width] = x;
        g_strips[strip_pixel_index + 4 * image_width * strip_width] = y;
    }

    // Dispatches on the image's scalar type and channel count
    // (KERNEL_DISPATCH_IMAGE_FORMAT expands to the matching instantiation).
    void make_strips(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const int strip_width, float* strips)
    {
        dim3 grid(((image_width - 1) / 128) + 1, strip_count, batch_count);
        dim3 block(128, strip_width);
        KERNEL_DISPATCH_IMAGE_FORMAT(make_strips_kernel, ARG(grid, block), image, image_width, image_height, strip_count, strip_width, strips);
    }
}
try:
    import torchcontentareaext as ext
    FALL_BACK = False
except ImportError:
    # Compiled extension missing or incompatible with this torch build;
    # fall back to the slower pure-python implementation.
    print("Falling back to python implementation...")
    from . import pythonimplementation as ext
    FALL_BACK = True

# Cache of loaded default models, keyed by the device they were loaded onto.
# NOTE(review): "cuda" (str) and torch.device("cuda") hash to different keys,
# so the same model may be loaded twice -- harmless, but worth confirming.
models = {}

def load_default_model(device):
    """Loads the bundled default TorchScript model (kernel_3_8) onto `device`."""
    import os
    # os.path keeps this portable; the previous "/"-split broke on Windows
    # paths and shadowed the builtin `dir`.
    models_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
    return torch.jit.load(os.path.join(models_dir, "kernel_3_8.pt"), map_location=device)

def get_default_model(device):
    """Returns the default model for `device`, loading and caching it on first use."""
    model = models.get(device)
    if model is None:
        model = load_default_model(device)
        models[device] = model
    return model


def estimate_area_handcrafted(image: torch.Tensor, strip_count: int=16, feature_thresholds: Sequence[float]=(20, 30, 25), confidence_thresholds: Sequence[float]=(0.03, 0.06)) -> torch.Tensor:
    """
    Estimates the content area for the given endoscopic image(s) using handcrafted feature extraction.
    """
    return ext.estimate_area_handcrafted(image, strip_count, feature_thresholds, confidence_thresholds)


def estimate_area_learned(image: torch.Tensor, strip_count: int=16, model: Optional[torch.jit.ScriptModule]=None, model_patch_size: int=7, confidence_thresholds: Sequence[float]=(0.03, 0.06)) -> torch.Tensor:
    """
    Estimates the content area for the given endoscopic image(s) using learned feature extraction.
    """
    if model is None:
        model = get_default_model(image.device)
    # The compiled extension takes the underlying C++ module, the python
    # fallback the ScriptModule itself.
    return ext.estimate_area_learned(image, strip_count, model if FALL_BACK else model._c, model_patch_size, confidence_thresholds)


def get_points_handcrafted(image: torch.Tensor, strip_count: int=16, feature_thresholds: Sequence[float]=(20, 30, 25)) -> torch.Tensor:
    """
    Finds candidate edge points and corresponding scores in the given image(s) using handcrafted feature extraction.
    """
    return ext.get_points_handcrafted(image, strip_count, feature_thresholds)


def get_points_learned(image: torch.Tensor, strip_count: int=16, model: Optional[torch.jit.ScriptModule]=None, model_patch_size: int=7) -> torch.Tensor:
    """
    Finds candidate edge points and corresponding scores in the given image(s) using learned feature extraction.
    """
    if model is None:
        model = get_default_model(image.device)
    return ext.get_points_learned(image, strip_count, model if FALL_BACK else model._c, model_patch_size)


def fit_area(points: torch.Tensor, image_size: Sequence[int], confidence_thresholds: Sequence[float]=(0.03, 0.06)) -> torch.Tensor:
    """
    Fits a circular content area to the given candidate edge points.
    """
    return ext.fit_area(points, image_size, confidence_thresholds)
-------------------------------------------------------------------------------- /src/torchcontentarea/pythonimplementation/__init__.py: -------------------------------------------------------------------------------- 1 | from .estimate_area import estimate_area_handcrafted, estimate_area_learned 2 | from .get_points import get_points_handcrafted, get_points_learned 3 | from .fit_area import fit_area 4 | -------------------------------------------------------------------------------- /src/torchcontentarea/pythonimplementation/estimate_area.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Sequence, Optional 3 | 4 | from .get_points import get_points_handcrafted, get_points_learned 5 | from .fit_area import fit_area 6 | 7 | def estimate_area_handcrafted(image: torch.Tensor, strip_count: int=16, feature_thresholds: Sequence[float]=(20, 30, 25), confidence_thresholds: Sequence[float]=(0.03, 0.06)) -> torch.Tensor: 8 | """ 9 | Estimates the content area for the given endoscopic image(s) using handcrafted feature extraction. 10 | """ 11 | points = get_points_handcrafted(image, strip_count, feature_thresholds) 12 | area = fit_area(points, image.shape[-2:], confidence_thresholds) 13 | return area 14 | 15 | 16 | def estimate_area_learned(image: torch.Tensor, strip_count: int=16, model: Optional[torch.jit.ScriptModule]=None, model_patch_size: int=7, confidence_thresholds: Sequence[float]=(0.03, 0.06)) -> torch.Tensor: 17 | """ 18 | Estimates the content area for the given endoscopic image(s) using learned feature extraction. 
19 | """ 20 | points = get_points_learned(image, strip_count, model, model_patch_size) 21 | area = fit_area(points, image.shape[-2:], confidence_thresholds) 22 | return area 23 | -------------------------------------------------------------------------------- /src/torchcontentarea/pythonimplementation/fit_area.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from math import sqrt 3 | from typing import Sequence 4 | 5 | MAX_CENTER_DIST=0.2 6 | MIN_RADIUS=0.2 7 | MAX_RADIUS=0.8 8 | RANSAC_ATTEMPTS=32 9 | RANSAC_ITERATIONS=3 10 | RANSAC_INLIER_THRESHOLD=3 11 | 12 | def check_circle(x, y, r, w, h): 13 | 14 | x_diff = x - 0.5 * w 15 | y_diff = y - 0.5 * h 16 | diff = sqrt(x_diff * x_diff + y_diff * y_diff) 17 | 18 | valid = True 19 | valid &= diff < MAX_CENTER_DIST * w 20 | valid &= r > MIN_RADIUS * w 21 | valid &= r < MAX_RADIUS * w 22 | 23 | return valid 24 | 25 | def calculate_circle(points): 26 | 27 | ax, ay = points[:2, 0] 28 | bx, by = points[:2, 1] 29 | cx, cy = points[:2, 2] 30 | 31 | offset = bx * bx + by * by 32 | 33 | bc = 0.5 * (ax * ax + ay * ay - offset) 34 | cd = 0.5 * (offset - cx * cx - cy * cy) 35 | 36 | det = (ax - bx) * (by - cy) - (bx - cx) * (ay - by) 37 | 38 | if abs(det) > 1e-8: 39 | idet = 1.0 / det 40 | x = (bc * (by - cy) - cd * (ay - by)) * idet 41 | y = (cd * (ax - bx) - bc * (bx - cx)) * idet 42 | r = sqrt((bx - x) * (bx - x) + (by - y) * (by - y)) 43 | return x, y, r 44 | else: 45 | return None 46 | 47 | def get_circle(points): 48 | 49 | if points.size(1) == 3: 50 | return calculate_circle(points) 51 | else: 52 | lhs = torch.zeros(3, 3).to(points.device) 53 | rhs = torch.zeros(3).to(points.device) 54 | 55 | for p_x, p_y, _ in points.T: 56 | lhs[0, 0] += p_x * p_x 57 | lhs[0, 1] += p_x * p_y 58 | lhs[1, 1] += p_y * p_y 59 | lhs[0, 2] += p_x 60 | lhs[1, 2] += p_y 61 | lhs[2, 2] += 1 62 | 63 | rhs[0] += p_x * p_x * p_x + p_x * p_y * p_y 64 | rhs[1] += p_x * p_x * p_y + p_y * p_y * p_y 
65 | rhs[2] += p_x * p_x + p_y * p_y 66 | 67 | lhs[1, 0] = lhs[0, 1] 68 | lhs[2, 0] = lhs[0, 2] 69 | lhs[2, 1] = lhs[1, 2] 70 | 71 | try: 72 | L = torch.linalg.cholesky(lhs) 73 | y = torch.linalg.solve(L, rhs) 74 | x = torch.linalg.solve(L.T, y) 75 | except: 76 | return None 77 | 78 | A, B, C = x[0], x[1], x[2] 79 | 80 | x = A / 2.0 81 | y = B / 2.0 82 | r = torch.sqrt(4.0 * C + A * A + B * B) / 2.0 83 | 84 | return x, y, r 85 | 86 | def fit_area(points: torch.Tensor, image_size: Sequence[int], confidence_thresholds: Sequence[float]=(0.03, 0.06)) -> torch.Tensor: 87 | """ 88 | Finds candidate edge points and corresponding scores in the given image(s) using handcrafted feature extraction. 89 | """ 90 | 91 | batched = len(points.shape) == 3 92 | 93 | if not batched: 94 | points = points.unsqueeze(0) 95 | 96 | areas = [] 97 | 98 | for point_batch in points: 99 | 100 | point_batch = point_batch[:, point_batch[2] > confidence_thresholds[0]] 101 | 102 | if point_batch.size(1) < 3: 103 | areas.append(torch.zeros(4)) 104 | continue 105 | 106 | best_circle = torch.zeros(4) 107 | 108 | for _ in range(RANSAC_ATTEMPTS): 109 | 110 | indices = torch.randperm(point_batch.size(1))[:3] 111 | inliers = point_batch[:, indices] 112 | 113 | for _ in range(RANSAC_ITERATIONS): 114 | circle = get_circle(inliers) 115 | 116 | if circle is None: 117 | x, y, r = 0, 0, 0 118 | circle_score = 0 119 | break 120 | 121 | x, y, r = circle 122 | 123 | dx = x - point_batch[0] 124 | dy = y - point_batch[1] 125 | 126 | error = torch.abs(torch.sqrt(dx**2 + dy**2) - r) 127 | 128 | inliers = point_batch[:, error < RANSAC_INLIER_THRESHOLD] 129 | circle_score = inliers[2].sum() 130 | 131 | circle_score /= point_batch.size(1) 132 | 133 | circle_valid = check_circle(x, y, r, image_size[1], image_size[0]) 134 | 135 | if circle_valid and circle_score > best_circle[3]: 136 | best_circle = torch.tensor([x, y, r, circle_score]) 137 | 138 | areas.append(best_circle) 139 | 140 | areas = torch.stack(areas) 141 | 142 
DEG2RAD = 0.01745329251
RAD2DEG = 1.0 / DEG2RAD

# Rec. 709 luma weights for RGB -> grayscale.
GRAY_SCALE_WEIGHTS = torch.tensor([
    0.2126, 0.7152, 0.0722
])

# Sobel x and y kernels as a (2, 1, 3, 3) conv weight.
SOBEL_KERNEL = torch.tensor([
    [[
        [-0.25, 0.0, 0.25],
        [-0.5, 0.0, 0.5],
        [-0.25, 0.0, 0.25],
    ]],
    [[
        [-0.25, -0.5, -0.25],
        [0.0, 0.0, 0.0],
        [0.25, 0.5, 0.25],
    ]],
])

def get_points_handcrafted(image: torch.Tensor, strip_count: int=16, feature_thresholds: Sequence[float]=(20, 30, 25)) -> torch.Tensor:
    """
    Finds candidate edge points and corresponding scores in the given image(s) using handcrafted feature extraction.

    image: (C, H, W) or (B, C, H, W), C in {1, 3}, integer 0-255 or float 0-1.
    Returns a (3, 2 * strip_count) or (B, 3, 2 * strip_count) tensor of
    x, y, score rows -- one point per image half per strip.
    """

    batched = len(image.shape) == 4
    device = image.device

    if not batched:
        image = image.unsqueeze(0)  # FIX: the result was previously discarded

    if image.dtype.is_floating_point:
        image = image * 255.0

    # Collapse to a single luminance plane -> (B, H, W).
    if image.size(1) != 1:
        image = (image * GRAY_SCALE_WEIGHTS[None, :, None, None].to(device)).sum(dim=1)
    else:
        # FIX: grayscale input kept its channel dim, breaking the unpack below.
        image = image[:, 0]

    image = image.float()
    B, H, W = image.shape

    # Strip rows at sigmoid-spaced heights (matches the compiled extension).
    strip_indices = torch.arange(strip_count, device=device)
    strip_heights = (1 + (H - 2) / (1.0 + torch.exp(-(strip_indices - strip_count / 2.0 + 0.5) / (strip_count / 8.0)))).long()

    # Three-row strips (above / centre / below) for the Sobel support.
    indices = torch.cat([strip_heights-1, strip_heights, strip_heights+1])
    strips = image[:, indices, :].reshape(B, 3, strip_count, W).permute(0, 2, 1, 3)
    # Left halves plus mirrored right halves, so content edges always enter
    # from the left of a strip.
    strips = torch.cat([strips[..., :W//2], strips.flip(-1)[..., :W//2]], dim=1)

    ys = torch.cat([strip_heights, strip_heights])[None, :, None].repeat(B, 1, W//2-2)
    xs = (torch.arange(W//2-2) + 1)[None, None, :].to(device).repeat(B, strip_count, 1)
    xs = torch.cat([xs, W - 1 - xs], dim=1)

    grad = torch.conv2d(strips.reshape(B*2*strip_count, 1, 3, -1), SOBEL_KERNEL.to(device)).reshape(B, 2*strip_count, 2, -1)
    strips = strips[:, :, 1, 1:-1]

    grad_x = grad[:, :, 0]
    grad_y = grad[:, :, 1]
    # Un-mirror the horizontal gradient of the flipped halves.
    grad_x[:, strip_count:] *= -1

    grad = torch.sqrt(grad_x**2 + grad_y**2)

    max_preceeding_intensity = torch.cummax(strips, dim=-1)[0]

    # Angle between the gradient and the direction towards the image centre.
    center_dir_x = (0.5 * W) - xs
    center_dir_y = (0.5 * H) - ys
    center_dir_norm = torch.sqrt(center_dir_x * center_dir_x + center_dir_y * center_dir_y)

    dot = torch.where(grad == 0, -1, (center_dir_x * grad_x + center_dir_y * grad_y) / (center_dir_norm * grad))
    dot = torch.clamp(dot, -0.99, 0.99)
    angle = RAD2DEG * torch.acos(dot)

    edge_score = torch.tanh(grad / feature_thresholds[0])
    angle_score = 1.0 - torch.tanh(angle / feature_thresholds[1])
    intensity_score = 1.0 - torch.tanh(max_preceeding_intensity / feature_thresholds[2])

    point_scores = edge_score * angle_score * intensity_score

    # Keep the best-scoring candidate of each strip half.
    point_scores, indices = torch.max(point_scores, dim=-1)
    ys = torch.gather(ys, 2, indices[:, :, None])[:, :, 0]
    xs = torch.gather(xs, 2, indices[:, :, None])[:, :, 0]

    result = torch.stack([xs, ys, point_scores], dim=2).reshape(B, strip_count*2, 3).permute(0, 2, 1)

    if not batched:
        result = result[0]

    return result

def get_points_learned(image: torch.Tensor, strip_count: int=16, model: Optional[torch.jit.ScriptModule]=None, model_patch_size: int=7) -> torch.Tensor:
    """
    Finds candidate edge points and corresponding scores in the given image(s) using learned feature extraction.
    """
    raise NotImplementedError("The learned method is not implemented in the python fallback. The handcrafted method should provide better performance, but to use the learned method you will need to install the compiled extension.")
def draw_area(area, image, bias=0):
    """
    Draws a binary mask with 1 inside the content area and 0 outside.

    area: (x, y, r, ...) circle; image supplies the output size and device.
    bias: added to the radius before thresholding.
    Returns a uint8 mask with the same spatial size as `image`.
    """
    image_size, device = image.shape[-2:], image.device

    # indexing="ij" is the historical default; stated explicitly to silence
    # the torch >= 1.10 deprecation warning.
    rows, cols = torch.meshgrid(torch.arange(0, image_size[0], device=device), torch.arange(0, image_size[1], device=device), indexing="ij")
    dist = torch.sqrt((rows - area[1])**2 + (cols - area[0])**2)
    mask = torch.where(dist < (area[2] + bias), 1, 0).to(dtype=torch.uint8)

    return mask

def get_crop(area, image_size, aspect_ratio=None, bias=0):
    """
    Returns the optimal crop (top, bottom, left, right) to remove areas outside the content area.

    area: (x, y, r, ...) circle; image_size: (height, width).
    aspect_ratio: width / height of the crop; defaults to the image's own.
    bias: adjusts the effective radius before inscribing.
    """
    i_h, i_w = image_size
    a_x, a_y, a_r = area[:3]

    if aspect_ratio is None:
        aspect_ratio = i_w / i_h

    # Largest aspect-ratio-preserving rectangle inscribed in the circle
    # (2 px margin), clipped to the image bounds.
    inscribed_height = 2 * (a_r + bias - 2) / sqrt(1 + aspect_ratio * aspect_ratio)
    inscribed_width = inscribed_height * aspect_ratio

    left = max(a_x - inscribed_width / 2, 0)
    right = min(a_x + inscribed_width / 2, i_w)
    top = max(a_y - inscribed_height / 2, 0)
    bottom = min(a_y + inscribed_height / 2, i_h)

    # Shrink to whichever axis the clipping constrained most.
    x_scale = (right - left)
    y_scale = (bottom - top) * aspect_ratio

    scale = min(x_scale, y_scale)

    w = int(floor(scale))
    h = int(floor(scale / aspect_ratio))

    # Centre the final integer crop in the clipped rectangle.
    x = int(left + (right - left) / 2 - w / 2)
    y = int(top + (bottom - top) / 2 - h / 2)

    crop = y, y+h, x, x+w

    return crop

def crop_area(area, image, aspect_ratio=None, bias=0):
    """
    Crops the image to remove areas outside the content area.

    Thin wrapper around get_crop applied to the image's trailing two dims.
    """
    crop = get_crop(area, image.shape[-2:], aspect_ratio, bias)

    cropped_image = image[..., crop[0]:crop[1], crop[2]:crop[3]]

    return cropped_image
# Every public estimation entry point, on every device, both as the
# one-call API and as the two-staged points + fit pipeline.
# (Renamed from the misspelt ESTIMATION_MEHTODS; the name is file-internal.)
ESTIMATION_METHODS = [
    ("handcrafted cpu", estimate_area_handcrafted, "cpu"),
    ("learned cpu", estimate_area_learned, "cpu"),
    ("handcrafted cuda", estimate_area_handcrafted, "cuda"),
    ("learned cuda", estimate_area_learned, "cuda"),
    ("handcrafted cpu two staged", lambda x: fit_area(get_points_handcrafted(x), x.shape[-2:]), "cpu"),
    ("learned cpu two staged", lambda x: fit_area(get_points_learned(x), x.shape[-2:]), "cpu"),
    ("handcrafted cuda two staged", lambda x: fit_area(get_points_handcrafted(x), x.shape[-2:]), "cuda"),
    ("learned cuda two staged", lambda x: fit_area(get_points_learned(x), x.shape[-2:]), "cuda"),
]

class TestAPI(unittest.TestCase):
    """End-to-end checks of the public API across batching, image sizes,
    channel layouts and dtypes."""

    def __init__(self, methodName: str = "runTest") -> None:
        super().__init__(methodName)
        self.dataset = DummyDataset()

    def _assert_methods_find_area(self, image, true_area):
        # Shared body of the single-image tests: every method on every device
        # must land within MISS_THRESHOLD of the ground-truth area.
        for name, method, device in ESTIMATION_METHODS:
            with self.subTest(name):
                result = method(image.to(device)).tolist()
                estimated_area = result[0:3]
                error, _ = content_area_hausdorff(true_area, estimated_area, image.shape[-2:])
                self.assertLess(error, MISS_THRESHOLD)

    def test_unbatched(self):
        image, _, true_area = self.dataset[0]
        self._assert_methods_find_area(image, true_area)

    def test_batched(self):
        image_a, _, true_area_a = self.dataset[0]
        image_b, _, true_area_b = self.dataset[1]
        images = torch.stack([image_a, image_b])
        true_areas = [true_area_a, true_area_b]
        for name, method, device in ESTIMATION_METHODS:
            with self.subTest(name):
                results = method(images.to(device)).tolist()
                for true_area, result in zip(true_areas, results):
                    estimated_area = result[0:3]
                    error, _ = content_area_hausdorff(true_area, estimated_area, images.shape[-2:])
                    self.assertLess(error, MISS_THRESHOLD)

    def test_rgb(self):
        image, _, true_area = self.dataset[0]
        self._assert_methods_find_area(image, true_area)

    def test_grayscale(self):
        image, _, true_area = self.dataset[0]
        image = (0.2126 * image[0:1] + 0.7152 * image[1:2] + 0.0722 * image[2:3]).to(torch.uint8)
        self._assert_methods_find_area(image, true_area)

    def test_large(self):
        image, _, true_area = self.dataset[0]
        image = torch.nn.functional.interpolate(image.unsqueeze(0).to(dtype=torch.float), scale_factor=4, mode='bilinear')[0].to(dtype=torch.uint8)
        true_area = tuple(map(lambda x: x*4, true_area))
        self._assert_methods_find_area(image, true_area)

    def test_small(self):
        image, _, true_area = self.dataset[0]
        image = image[:, ::4, ::4]
        true_area = tuple(map(lambda x: x/4, true_area))
        self._assert_methods_find_area(image, true_area)

    def test_byte(self):
        image, _, true_area = self.dataset[0]
        self._assert_methods_find_area(image.to(dtype=torch.uint8), true_area)

    def test_int(self):
        image, _, true_area = self.dataset[0]
        self._assert_methods_find_area(image.to(dtype=torch.int), true_area)

    def test_long(self):
        image, _, true_area = self.dataset[0]
        self._assert_methods_find_area(image.to(dtype=torch.int64), true_area)

    def test_float(self):
        image, _, true_area = self.dataset[0]
        self._assert_methods_find_area(image.to(dtype=torch.float) / 255, true_area)

    def test_double(self):
        image, _, true_area = self.dataset[0]
        self._assert_methods_find_area(image.to(dtype=torch.double) / 255, true_area)


if __name__ == '__main__':
    unittest.main()
-------------------------------------------------------------------------------- /tests/test_performance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import unittest 4 | import cpuinfo 5 | from time import sleep 6 | 7 | from .utils.data import TestDataset, TestDataLoader 8 | from .utils.scoring import content_area_hausdorff, MISS_THRESHOLD, BAD_MISS_THRESHOLD 9 | from .utils.profiling import Timer 10 | 11 | from torchcontentarea import estimate_area_handcrafted, estimate_area_learned 12 | 13 | TEST_LOG = "" 14 | 15 | TEST_CASES = [ 16 | # ("handcrafted cpu", estimate_area_handcrafted, "cpu"), 17 | # ("learned cpu", estimate_area_learned, "cpu"), 18 | ("handcrafted cuda", estimate_area_handcrafted, "cuda"), 19 | # ("learned cuda", estimate_area_learned, "cuda"), 20 | ] 21 | 22 | class TestPerformance(unittest.TestCase): 23 | 24 | def __init__(self, methodName: str = ...) -> None: 25 | super().__init__(methodName) 26 | self.dataset = TestDataset() 27 | self.dataloader = TestDataLoader(self.dataset) 28 | 29 | def test_performance(self): 30 | 31 | times = [[] for _ in range(len(TEST_CASES))] 32 | errors = [[] for _ in range(len(TEST_CASES))] 33 | 34 | for img, area in self.dataloader: 35 | 36 | for i, (name, method, device) in enumerate(TEST_CASES): 37 | 38 | img = img.to(device=device) 39 | 40 | with Timer() as timer: 41 | result = method(img) 42 | time = timer.time 43 | 44 | result = result.tolist() 45 | infered_area, confidence = result[:3], result[3] 46 | 47 | infered_area = tuple(map(int, infered_area)) 48 | 49 | if confidence < 0.06: 50 | infered_area = None 51 | 52 | error, _ = content_area_hausdorff(area, infered_area, img.shape[-2:]) 53 | 54 | errors[i].append(error) 55 | times[i].append(time) 56 | 57 | 58 | for (name, _, device), times, errors in zip(TEST_CASES, times, errors): 59 | device_name = torch.cuda.get_device_name() if device == "cuda" else 
cpuinfo.get_cpu_info()['brand_raw'] 60 | run_in_count = int(len(times) // 100) 61 | times = times[run_in_count:] 62 | avg_time = sum(times) / len(times) 63 | std_time = np.std(times) 64 | 65 | sample_count = len(self.dataset) 66 | average_error = sum(errors) / sample_count 67 | miss_percentage = 100 * sum(map(lambda x: x > MISS_THRESHOLD, errors)) / sample_count 68 | bad_miss_percentage = 100 * sum(map(lambda x: x > BAD_MISS_THRESHOLD, errors)) / sample_count 69 | 70 | global TEST_LOG 71 | TEST_LOG += "\n".join([ 72 | f"\n", 73 | f"Performance Results ({name})...", 74 | f"- Avg Time ({device_name}): {avg_time:.3f} ± {std_time:.3f}ms", 75 | f"- Avg Error (Hausdorff Distance): {average_error:.3f}", 76 | f"- Miss Rate (Error > {MISS_THRESHOLD}): {miss_percentage:.1f}%", 77 | f"- Bad Miss Rate (Error > {BAD_MISS_THRESHOLD}): {bad_miss_percentage:.1f}%" 78 | ]) 79 | 80 | self.assertTrue(avg_time < 10) 81 | self.assertTrue(average_error < 10) 82 | self.assertTrue(miss_percentage < 10.0) 83 | self.assertTrue(bad_miss_percentage < 5.0) 84 | 85 | @classmethod 86 | def tearDownClass(cls): 87 | if TEST_LOG != "": 88 | sleep(3) 89 | print(TEST_LOG) 90 | 91 | if __name__ == '__main__': 92 | unittest.main() 93 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | 4 | from .utils.data import DummyDataset 5 | from .utils.scoring import content_area_hausdorff, MISS_THRESHOLD 6 | 7 | from torchcontentarea.utils import draw_area, crop_area 8 | 9 | TESTS_CASES = ["cuda", "cpu"] 10 | 11 | class TestUtils(unittest.TestCase): 12 | 13 | def __init__(self, methodName: str = ...) 
-> None: 14 | super().__init__(methodName) 15 | self.dataset = DummyDataset() 16 | 17 | def test_draw_area(self): 18 | _, mask, area = self.dataset[0] 19 | for device in TESTS_CASES: 20 | with self.subTest(device): 21 | mask = mask.to(device) 22 | result = draw_area(area, mask) 23 | score = torch.where(1 - result == mask, 0, 1).sum() 24 | self.assertLess(score, 1) 25 | 26 | def test_crop_area(self): 27 | _, mask, area = self.dataset[0] 28 | for device in TESTS_CASES: 29 | with self.subTest(device): 30 | mask = mask.to(device) 31 | result = crop_area(area, mask) 32 | self.assertTrue(result.unique().numel() == 1) 33 | 34 | if __name__ == '__main__': 35 | unittest.main() -------------------------------------------------------------------------------- /tests/utils/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import Dataset, DataLoader 4 | from ecadataset import ECADataset, DataSource, AnnotationType 5 | 6 | def meshgrid(tensors): 7 | major, minor = map(int, torch.__version__.split(".")[:2]) 8 | if major >= 1 and minor > 9: 9 | return torch.meshgrid(tensors, indexing="ij") 10 | else: 11 | return torch.meshgrid(tensors) 12 | 13 | 14 | ######################## 15 | # Datasets... 
16 | 17 | class TestDataLoader(DataLoader): 18 | def __init__(self, dataset, shuffle=False) -> None: 19 | super().__init__(dataset=dataset, batch_size=None, pin_memory=True, shuffle=shuffle) 20 | 21 | 22 | class TestDataset(Dataset): 23 | def __init__(self) -> None: 24 | super().__init__() 25 | self.dataset = ECADataset("eca-data", DataSource.CHOLEC, AnnotationType.AREA, include_cropped=True) 26 | 27 | def __len__(self): 28 | return len(self.dataset) 29 | 30 | def __getitem__(self, index): 31 | image, area = self.dataset[index] 32 | img = torch.from_numpy(np.array(image)).permute(2, 0, 1) 33 | area = None if area == None else torch.from_numpy(np.array(area)) 34 | return img, area 35 | 36 | 37 | class DummyDataset(Dataset): 38 | def __init__(self) -> None: 39 | super().__init__() 40 | 41 | self.width = 854 42 | self.height = 480 43 | 44 | self.areas = [ 45 | (400, 250, 360), 46 | (340, 200, 370), 47 | (450, 230, 250), 48 | None, 49 | ] 50 | 51 | def __len__(self): 52 | return len(self.areas) 53 | 54 | def __getitem__(self, index): 55 | area = self.areas[index] 56 | 57 | if area != None: 58 | area_x, area_y, area_r = self.areas[index] 59 | coords = torch.stack(meshgrid([torch.arange(0, self.height), torch.arange(0, self.width)])) 60 | center = torch.Tensor([area_y, area_x]).reshape((2, 1, 1)) 61 | mask = torch.where(torch.linalg.norm(abs(coords - center), dim=0) < area_r, 0, 1).unsqueeze(0) 62 | else: 63 | mask = torch.zeros(1, self.height, self.width) 64 | 65 | img = 255 * (1 - mask).expand((3, self.height, self.width)) 66 | img = img.to(dtype=torch.uint8).contiguous() 67 | mask = mask.to(dtype=torch.uint8).contiguous() 68 | 69 | return img, mask, area 70 | -------------------------------------------------------------------------------- /tests/utils/profiling.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | 4 | UNITS = {'h': 1.0/60.0, 'm': 1.0/60.0, 's': 1.0, 'ms': 1e3, 'us': 1e6} 5 | 6 | def 
get_time(): 7 | torch.cuda.synchronize() 8 | return time.time() 9 | 10 | class Timer(): 11 | def __init__(self, units='ms'): 12 | assert units in UNITS, f'The given units {units} is not supported, please use h, m, s, ms, or us' 13 | self.units = units 14 | 15 | def __enter__(self): 16 | self.start = get_time() 17 | return self 18 | 19 | def __exit__(self, type, value, traceback): 20 | self.end = get_time() 21 | self.time = (self.end - self.start) * UNITS[self.units] 22 | -------------------------------------------------------------------------------- /tests/utils/scoring.py: -------------------------------------------------------------------------------- 1 | from ecadataset import content_area_hausdorff 2 | 3 | MISS_THRESHOLD=15 4 | BAD_MISS_THRESHOLD=25 5 | --------------------------------------------------------------------------------