├── .gitattributes
├── .github
└── workflows
│ ├── build.yml
│ ├── test.yml
│ └── update-perf-stats.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── example.gif
├── setup.cfg
├── setup.py
├── src
└── torchcontentarea
│ ├── __init__.py
│ ├── _version.py
│ ├── csrc
│ ├── common.hpp
│ ├── cpu_functions.hpp
│ ├── cuda_functions.cuh
│ ├── implementation.cpp
│ ├── implementation.hpp
│ ├── python_bindings.cpp
│ └── source
│ │ ├── find_points_cpu.cpp
│ │ ├── find_points_cuda.cu
│ │ ├── find_points_from_scores_cpu.cpp
│ │ ├── find_points_from_scores_cuda.cu
│ │ ├── fit_circle_cpu.cpp
│ │ ├── fit_circle_cuda.cu
│ │ ├── make_strips_cpu.cpp
│ │ └── make_strips_cuda.cu
│ ├── extension_wrapper.py
│ ├── models
│ ├── kernel_1_8.pt
│ ├── kernel_2_8.pt
│ └── kernel_3_8.pt
│ ├── pythonimplementation
│ ├── __init__.py
│ ├── estimate_area.py
│ ├── fit_area.py
│ └── get_points.py
│ └── utils.py
├── tests
├── __init__.py
├── test_api.py
├── test_performance.py
├── test_utils.py
└── utils
│ ├── data.py
│ ├── profiling.py
│ └── scoring.py
└── versioneer.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | src/torchcontentarea/_version.py export-subst
2 |
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Build
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - main
7 | push:
8 | branches:
9 | - main
10 | tags:
11 | - 'v*'
12 |
13 | jobs:
14 | build-wheels:
15 | name: Build Wheels
16 | uses: charliebudd/torch-extension-builder/.github/workflows/build-pytorch-extension-wheels.yml@main
17 | with:
18 | build-command: "git config --global --add safe.directory ${GITHUB_WORKSPACE} && python setup.py bdist_wheel"
19 |
20 | test-wheels-locally:
21 | name: Test Wheels Locally
22 | needs: build-wheels
23 | uses: ./.github/workflows/test.yml
24 | with:
25 | local-wheels: true
26 |
27 |   update-performance-stats:
28 | if: ${{ github.base_ref == 'main' }}
29 | name: Update Performance Stats
30 | needs: test-wheels-locally
31 | uses: ./.github/workflows/update-perf-stats.yml
32 |
33 | publish-wheels-to-testpypi:
34 | if: startsWith(github.ref, 'refs/tags/v')
35 | name: Publish Wheels To TestPyPI
36 | runs-on: ubuntu-latest
37 | needs: test-wheels-locally
38 | steps:
39 | - name: Download Cached Wheels
40 | uses: actions/download-artifact@v3
41 | with:
42 | name: final-wheels
43 | path: dist
44 |
45 | - name: Checkout repo
46 | uses: actions/checkout@v3
47 |
48 | - name: Make Source Distribution
49 | run: python setup.py sdist
50 |
51 | - name: Publish Package to TestPyPI
52 | uses: pypa/gh-action-pypi-publish@release/v1
53 | with:
54 | user: __token__
55 | password: ${{ secrets.TEST_PYPI_API_TOKEN }}
56 | repository_url: https://test.pypi.org/legacy/
57 |
58 | test-testpypi-release:
59 | if: startsWith(github.ref, 'refs/tags/v')
60 |     name: Test TestPyPI Release
61 | needs: publish-wheels-to-testpypi
62 | uses: ./.github/workflows/test.yml
63 | with:
64 | local-wheels: false
65 | wheel-location: https://test.pypi.org/simple/
66 |
67 | publish-wheels-to-pypi:
68 | if: startsWith(github.ref, 'refs/tags/v')
69 | name: Publish Wheels To PyPI
70 | runs-on: ubuntu-latest
71 | needs: test-testpypi-release
72 | steps:
73 | - name: Download Cached Wheels
74 | uses: actions/download-artifact@v3
75 | with:
76 | name: final-wheels
77 | path: dist
78 |
79 | - name: Checkout repo
80 | uses: actions/checkout@v3
81 |
82 | - name: Make Source Distribution
83 | run: python setup.py sdist
84 |
85 |       - name: Publish Package to PyPI
86 | uses: pypa/gh-action-pypi-publish@release/v1
87 | with:
88 | user: __token__
89 | password: ${{ secrets.PYPI_API_TOKEN }}
90 |
91 | test-pypi-release:
92 | if: startsWith(github.ref, 'refs/tags/v')
93 |     name: Test PyPI Release
94 | needs: publish-wheels-to-pypi
95 | uses: ./.github/workflows/test.yml
96 | with:
97 | local-wheels: false
98 | wheel-location: https://pypi.org/simple/
99 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Test Release
2 |
3 | on:
4 | workflow_call:
5 | inputs:
6 | local-wheels:
7 | type: boolean
8 | default: true
9 | wheel-location:
10 | type: string
11 | default: final-wheels
12 | python-versions:
13 | type: string
14 | default: "[3.6, 3.7, 3.8, 3.9]"
15 | pytorch-versions:
16 | type: string
17 | default: "[1.9, '1.10', 1.11, 1.12]"
18 | cuda-versions:
19 | type: string
20 | default: "[10.2, 11.3, 11.6]"
21 |
22 | jobs:
23 | test-release:
24 | name: Test Release
25 | runs-on: [self-hosted, Linux, X64, gpu]
26 | strategy:
27 | fail-fast: false
28 | matrix:
29 | python-version: ${{ fromJson(inputs.python-versions) }}
30 | pytorch-version: ${{ fromJson(inputs.pytorch-versions) }}
31 | cuda-version: ${{ fromJson(inputs.cuda-versions) }}
32 | exclude:
33 | - pytorch-version: 1.9
34 | cuda-version: 11.3
35 | - python-version: 3.6
36 | pytorch-version: 1.11
37 | - python-version: 3.6
38 | pytorch-version: 1.12
39 | - pytorch-version: 1.9
40 | cuda-version: 11.6
41 | - pytorch-version: '1.10'
42 | cuda-version: 11.6
43 | - pytorch-version: 1.11
44 | cuda-version: 11.6
45 | steps:
46 | - name: Checkout Repository
47 | uses: actions/checkout@v2
48 |
49 | - name: Setup Python
50 | uses: actions/setup-python@v3
51 | with:
52 | python-version: ${{ matrix.python-version }}
53 |
54 | - name: Install Test Requirements
55 | run: |
56 | ENV=../.venv-${{ matrix.python-version }}-${{ matrix.pytorch-version }}-${{ matrix.cuda-version }}
57 | if [ ! -d "$ENV" ]; then
58 | python -m venv $ENV
59 | fi
60 | . $ENV/bin/activate
61 |
62 | python -m pip install -U --force-reinstall pip
63 | python -m pip install numpy py-cpuinfo ecadataset
64 | export FULL_PYTORCH_VERSION=$(python -m pip index versions torch -f https://download.pytorch.org/whl/torch_stable.html | grep -o ${PYTORCH_VERSION}.[0-9]+cu${CUDA_VERSION//.} | head -n 1)
65 | python -m pip --no-cache-dir install torch==${FULL_PYTORCH_VERSION} -f https://download.pytorch.org/whl/torch_stable.html
66 |
67 | ln -s ../eca-data eca-data
68 |
69 | python -V
70 | pip show torch
71 | env:
72 | PYTORCH_VERSION: ${{ matrix.pytorch-version }}
73 | CUDA_VERSION: ${{ matrix.cuda-version }}
74 |
75 | - name: Install torchcontentarea From PyPI Index
76 | if: ${{ !inputs.local-wheels }}
77 | run: |
78 | . ../.venv-${{ matrix.python-version }}-${{ matrix.pytorch-version }}-${{ matrix.cuda-version }}/bin/activate
79 | python -m pip install --force-reinstall --no-deps torchcontentarea -i ${{ inputs.wheel-location }}
80 |
81 | - name: Download Cached Wheels
82 | if: ${{ inputs.local-wheels }}
83 | uses: actions/download-artifact@v3
84 | with:
85 | name: ${{ inputs.wheel-location }}
86 | path: ${{ inputs.wheel-location }}
87 |
88 | - name: Install torchcontentarea from Cached Wheels
89 | if: ${{ inputs.local-wheels }}
90 | run: |
91 | . ../.venv-${{ matrix.python-version }}-${{ matrix.pytorch-version }}-${{ matrix.cuda-version }}/bin/activate
92 | python -m pip install --force-reinstall --no-deps ${{ inputs.wheel-location }}/torchcontentarea*cp$(echo ${{ matrix.python-version }} | sed 's/\.//')*.whl
93 |
94 | - name: Run Tests
95 | run: |
96 | . ../.venv-${{ matrix.python-version }}-${{ matrix.pytorch-version }}-${{ matrix.cuda-version }}/bin/activate
97 | python -m unittest tests.test_api tests.test_utils
98 |
--------------------------------------------------------------------------------
/.github/workflows/update-perf-stats.yml:
--------------------------------------------------------------------------------
1 | name: Update Performance Stats
2 |
3 | on:
4 | workflow_call:
5 | inputs:
6 | python-version:
7 | type: string
8 | default: "3.9"
9 | pytorch-version:
10 | type: string
11 | default: "1.11"
12 | cuda-version:
13 | type: string
14 | default: "11.3"
15 | jobs:
16 | update-readme:
17 |     name: Update Performance Stats
18 | runs-on: [self-hosted, Linux, X64, gpu]
19 | steps:
20 | - name: Checkout Repository
21 | uses: actions/checkout@v2
22 | with:
23 | ref: ${{ github.event.pull_request.head.sha }}
24 |
25 | - name: Setup Python
26 | uses: actions/setup-python@v3
27 | with:
28 | python-version: ${{ inputs.python-version }}
29 |
30 | - name: Install Test Requirements
31 | run: |
32 | ENV=../.venv-${{ inputs.python-version }}-${{ inputs.pytorch-version }}-${{ inputs.cuda-version }}
33 | if [ -d "$ENV" ]; then
34 | . $ENV/bin/activate
35 | else
36 | python -m venv $ENV
37 | . $ENV/bin/activate
38 | python -m pip install -U --force-reinstall pip
39 | python -m pip install numpy pillow torch==${{ inputs.pytorch-version }} -f https://download.pytorch.org/whl/cu$(echo ${{ inputs.cuda-version }} | sed 's/\.//')/torch_stable.html
40 | fi
41 |
42 | ln -s ../eca-data eca-data
43 |
44 | python -V
45 | pip show torch
46 |
47 | - name: Install torchcontentarea
48 | run: python setup.py install
49 |
50 | - id: run-tests
51 | name: Run Tests and Update README.md
52 | run: |
53 | . ../.venv-${{ inputs.python-version }}-${{ inputs.pytorch-version }}-${{ inputs.cuda-version }}/bin/activate
54 | RESULTS=$(sed -r '/^Performance/i\\r' <<< $(python -m unittest tests.test_performance | grep -e '^Performance' -e '^- '))
55 | RESULTS=$(printf '%s\n' "$RESULTS" | sed 's/\\/&&/g;s/^[[:blank:]]/\\&/;s/$/\\/')
56 |
57 | START=""
58 | END=""
59 | sed -ni "/$START/{p;:a;N;/$END/!ba;s/.*\n/$RESULTS \n/};p" README.md
60 |
61 | git add README.md
62 | git commit -m "updating performance stats"
63 |
64 | - name: Push changes
65 | uses: ad-m/github-push-action@master
66 | with:
67 | github_token: ${{ secrets.GITHUB_TOKEN }}
68 | branch: ${{ github.head_ref }}
69 |
70 |
71 |
72 |
73 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Charles Budd
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include versioneer.py
2 | include src/torchcontentarea/_version.py
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Torch Content Area
2 | A PyTorch tool kit for estimating the circular content area in endoscopic footage. The algorithm was developed and tested against the [Endoscopic Content Area (ECA) dataset](https://github.com/charliebudd/eca-dataset). Both this implementation and the dataset are released alongside our publication:
3 |
4 |
Rapid and robust endoscopic content area estimation: A lean GPU-based pipeline and curated benchmark dataset
5 | Charlie Budd, Luis C. Garcia-Peraza-Herrera, Martin Huber, Sebastien Ourselin, Tom Vercauteren.
6 | [ arXiv ]
7 | [ publication ]
8 |
9 |
10 | If you make use of this work, please cite the paper.
11 |
12 | [](https://github.com/charliebudd/torch-content-area/actions/workflows/build.yml)
13 |
14 | 
15 |
16 | ## Installation
17 | For Linux users, to install the latest version, simply run...
18 | ```
19 | pip install torchcontentarea
20 | ```
21 | For Windows users, or if you encounter any issues, try building from source by running...
22 | ```
23 | pip install git+https://github.com/charliebudd/torch-content-area
24 | ```
25 | ***Note:*** *this will require that you have CUDA installed and that its version matches the version of CUDA used to build your installation of PyTorch.*
26 |
27 | ## Usage
28 | ```python
29 | from torchvision.io import read_image
30 | from torchcontentarea import estimate_area, get_points, fit_area
31 | from torchcontentarea.utils import draw_area, crop_area
32 |
33 | # Grayscale or RGB image in NCHW or CHW format. Values should be normalised
34 | # between 0-1 for floating point types and 0-255 for integer types.
35 | image = read_image("my_image.png")
36 |
37 | # Either directly estimate area from image...
38 | area = estimate_area(image, strip_count=16)
39 |
40 | # ...or get the set of points and then fit the area.
41 | points = get_points(image, strip_count=16)
42 | area = fit_area(points, image.shape[2:4])
43 |
44 | # Utility functions are included to help handle the content area...
45 | area_mask = draw_area(area, image)
46 | cropped_image = crop_area(area, image)
47 | ```
48 |
49 | ## Performance
50 | Performance is measured against the CholecECA subset of the [Endoscopic Content Area (ECA) dataset](https://github.com/charliebudd/eca-dataset).
51 |
52 |
53 | Performance Results (handcrafted cuda)...
54 | - Avg Time (NVIDIA GeForce GTX 980 Ti): 0.299 ± 0.042ms
55 | - Avg Error (Hausdorff Distance): 3.618
56 | - Miss Rate (Error > 15): 2.1%
57 | - Bad Miss Rate (Error > 25): 1.1%
58 |
59 |
60 |
--------------------------------------------------------------------------------
/example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/charliebudd/torch-content-area/c613957f266a64232f8283975653635160b3f0a2/example.gif
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [versioneer]
2 | VCS = git
3 | style = pep440
4 | versionfile_source = src/torchcontentarea/_version.py
5 | versionfile_build = torchcontentarea/_version.py
6 | tag_prefix =
7 | parentdir_prefix = torchcontentarea-
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils import cpp_extension
3 | from glob import glob
4 | import versioneer
5 | import sys
6 |
7 | with open("README.md") as file:
8 | long_description = file.read()
9 |
10 | ext_src_dir = "src/torchcontentarea/csrc/"
11 | ext_source_files = glob(ext_src_dir + "**/*.cpp", recursive=True) + glob(ext_src_dir + "**/*.cu", recursive=True)
12 |
13 | compile_args = {
14 | 'cxx': ['/O2'] if sys.platform.startswith("win") else ['-g0', '-O3'],
15 | 'nvcc': ['-O3']
16 | }
17 |
18 | try:
19 | setup(
20 | name="torchcontentarea",
21 | version=versioneer.get_version(),
22 | description="A PyTorch tool kit for estimating the content area in endoscopic footage.",
23 | long_description=long_description,
24 | long_description_content_type="text/markdown",
25 | author="Charlie Budd",
26 | author_email="charles.budd@kcl.ac.uk",
27 | url="https://github.com/charliebudd/torch-content-area",
28 | license="MIT",
29 | packages=["torchcontentarea", "torchcontentarea/pythonimplementation"],
30 | package_dir={"":"src"},
31 | package_data={'torchcontentarea': ['models/*.pt']},
32 | ext_modules=[cpp_extension.CUDAExtension("torchcontentareaext", ext_source_files, extra_compile_args=compile_args)],
33 | cmdclass=versioneer.get_cmdclass({"build_ext": cpp_extension.BuildExtension})
34 | )
35 | except:
36 | print("########################################################################")
37 | print("Could not compile CUDA extention, falling back to python implementation!")
38 | print("########################################################################")
39 | setup(
40 | name="torchcontentarea",
41 | version=versioneer.get_version(),
42 | description="A PyTorch tool kit for estimating the content area in endoscopic footage.",
43 | long_description=long_description,
44 | long_description_content_type="text/markdown",
45 | author="Charlie Budd",
46 | author_email="charles.budd@kcl.ac.uk",
47 | url="https://github.com/charliebudd/torch-content-area",
48 | license="MIT",
49 | packages=["torchcontentarea", "torchcontentarea/pythonimplementation"],
50 | package_dir={"":"src"},
51 | package_data={'torchcontentarea': ['models/*.pt']},
52 | )
53 |
--------------------------------------------------------------------------------
/src/torchcontentarea/__init__.py:
--------------------------------------------------------------------------------
1 | """A PyTorch tool kit for segmenting the endoscopic content area in laparoscopy footage."""
2 |
3 | from .extension_wrapper import estimate_area_handcrafted as estimate_area, estimate_area_handcrafted, estimate_area_learned
4 | from .extension_wrapper import get_points_handcrafted as get_points, get_points_handcrafted, get_points_learned
5 | from .extension_wrapper import fit_area
6 |
7 | from .utils import draw_area, get_crop, crop_area
8 |
9 | from . import _version
10 | __version__ = _version.get_versions()['version']
11 |
--------------------------------------------------------------------------------
/src/torchcontentarea/_version.py:
--------------------------------------------------------------------------------
1 |
2 | # This file helps to compute a version number in source trees obtained from
3 | # git-archive tarball (such as those provided by githubs download-from-tag
4 | # feature). Distribution tarballs (built by setup.py sdist) and build
5 | # directories (produced by setup.py build) will contain a much shorter file
6 | # that just contains the computed version number.
7 |
8 | # This file is released into the public domain. Generated by
9 | # versioneer-0.22 (https://github.com/python-versioneer/python-versioneer)
10 |
11 | """Git implementation of _version.py."""
12 |
13 | import errno
14 | import os
15 | import re
16 | import subprocess
17 | import sys
18 | from typing import Callable, Dict
19 | import functools
20 |
21 |
22 | def get_keywords():
23 | """Get the keywords needed to look up the version information."""
24 | # these strings will be replaced by git during git-archive.
25 | # setup.py/versioneer.py will grep for the variable names, so they must
26 | # each be defined on a line of their own. _version.py will just call
27 | # get_keywords().
28 | git_refnames = " (HEAD -> main)"
29 | git_full = "c613957f266a64232f8283975653635160b3f0a2"
30 | git_date = "2024-11-07 14:18:22 +0000"
31 | keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
32 | return keywords
33 |
34 |
35 | class VersioneerConfig:
36 | """Container for Versioneer configuration parameters."""
37 |
38 |
39 | def get_config():
40 | """Create, populate and return the VersioneerConfig() object."""
41 | # these strings are filled in when 'setup.py versioneer' creates
42 | # _version.py
43 | cfg = VersioneerConfig()
44 | cfg.VCS = "git"
45 | cfg.style = "pep440"
46 | cfg.tag_prefix = ""
47 | cfg.parentdir_prefix = "torchcontentarea-"
48 | cfg.versionfile_source = "src/torchcontentarea/_version.py"
49 | cfg.verbose = False
50 | return cfg
51 |
52 |
53 | class NotThisMethod(Exception):
54 | """Exception raised if a method is not valid for the current scenario."""
55 |
56 |
57 | LONG_VERSION_PY: Dict[str, str] = {}
58 | HANDLERS: Dict[str, Dict[str, Callable]] = {}
59 |
60 |
61 | def register_vcs_handler(vcs, method): # decorator
62 | """Create decorator to mark a method as the handler of a VCS."""
63 | def decorate(f):
64 | """Store f in HANDLERS[vcs][method]."""
65 | if vcs not in HANDLERS:
66 | HANDLERS[vcs] = {}
67 | HANDLERS[vcs][method] = f
68 | return f
69 | return decorate
70 |
71 |
72 | def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
73 | env=None):
74 | """Call the given command(s)."""
75 | assert isinstance(commands, list)
76 | process = None
77 |
78 | popen_kwargs = {}
79 | if sys.platform == "win32":
80 | # This hides the console window if pythonw.exe is used
81 | startupinfo = subprocess.STARTUPINFO()
82 | startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
83 | popen_kwargs["startupinfo"] = startupinfo
84 |
85 | for command in commands:
86 | try:
87 | dispcmd = str([command] + args)
88 | # remember shell=False, so use git.cmd on windows, not just git
89 | process = subprocess.Popen([command] + args, cwd=cwd, env=env,
90 | stdout=subprocess.PIPE,
91 | stderr=(subprocess.PIPE if hide_stderr
92 | else None), **popen_kwargs)
93 | break
94 | except OSError:
95 | e = sys.exc_info()[1]
96 | if e.errno == errno.ENOENT:
97 | continue
98 | if verbose:
99 | print("unable to run %s" % dispcmd)
100 | print(e)
101 | return None, None
102 | else:
103 | if verbose:
104 | print("unable to find command, tried %s" % (commands,))
105 | return None, None
106 | stdout = process.communicate()[0].strip().decode()
107 | if process.returncode != 0:
108 | if verbose:
109 | print("unable to run %s (error)" % dispcmd)
110 | print("stdout was %s" % stdout)
111 | return None, process.returncode
112 | return stdout, process.returncode
113 |
114 |
115 | def versions_from_parentdir(parentdir_prefix, root, verbose):
116 | """Try to determine the version from the parent directory name.
117 |
118 | Source tarballs conventionally unpack into a directory that includes both
119 | the project name and a version string. We will also support searching up
120 | two directory levels for an appropriately named parent directory
121 | """
122 | rootdirs = []
123 |
124 | for _ in range(3):
125 | dirname = os.path.basename(root)
126 | if dirname.startswith(parentdir_prefix):
127 | return {"version": dirname[len(parentdir_prefix):],
128 | "full-revisionid": None,
129 | "dirty": False, "error": None, "date": None}
130 | rootdirs.append(root)
131 | root = os.path.dirname(root) # up a level
132 |
133 | if verbose:
134 | print("Tried directories %s but none started with prefix %s" %
135 | (str(rootdirs), parentdir_prefix))
136 | raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
137 |
138 |
139 | @register_vcs_handler("git", "get_keywords")
140 | def git_get_keywords(versionfile_abs):
141 | """Extract version information from the given file."""
142 | # the code embedded in _version.py can just fetch the value of these
143 | # keywords. When used from setup.py, we don't want to import _version.py,
144 | # so we do it with a regexp instead. This function is not used from
145 | # _version.py.
146 | keywords = {}
147 | try:
148 | with open(versionfile_abs, "r") as fobj:
149 | for line in fobj:
150 | if line.strip().startswith("git_refnames ="):
151 | mo = re.search(r'=\s*"(.*)"', line)
152 | if mo:
153 | keywords["refnames"] = mo.group(1)
154 | if line.strip().startswith("git_full ="):
155 | mo = re.search(r'=\s*"(.*)"', line)
156 | if mo:
157 | keywords["full"] = mo.group(1)
158 | if line.strip().startswith("git_date ="):
159 | mo = re.search(r'=\s*"(.*)"', line)
160 | if mo:
161 | keywords["date"] = mo.group(1)
162 | except OSError:
163 | pass
164 | return keywords
165 |
166 |
167 | @register_vcs_handler("git", "keywords")
168 | def git_versions_from_keywords(keywords, tag_prefix, verbose):
169 | """Get version information from git keywords."""
170 | if "refnames" not in keywords:
171 | raise NotThisMethod("Short version file found")
172 | date = keywords.get("date")
173 | if date is not None:
174 | # Use only the last line. Previous lines may contain GPG signature
175 | # information.
176 | date = date.splitlines()[-1]
177 |
178 | # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
179 | # datestamp. However we prefer "%ci" (which expands to an "ISO-8601
180 | # -like" string, which we must then edit to make compliant), because
181 | # it's been around since git-1.5.3, and it's too difficult to
182 | # discover which version we're using, or to work around using an
183 | # older one.
184 | date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
185 | refnames = keywords["refnames"].strip()
186 | if refnames.startswith("$Format"):
187 | if verbose:
188 | print("keywords are unexpanded, not using")
189 | raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
190 | refs = {r.strip() for r in refnames.strip("()").split(",")}
191 | # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
192 | # just "foo-1.0". If we see a "tag: " prefix, prefer those.
193 | TAG = "tag: "
194 | tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
195 | if not tags:
196 | # Either we're using git < 1.8.3, or there really are no tags. We use
197 | # a heuristic: assume all version tags have a digit. The old git %d
198 | # expansion behaves like git log --decorate=short and strips out the
199 | # refs/heads/ and refs/tags/ prefixes that would let us distinguish
200 | # between branches and tags. By ignoring refnames without digits, we
201 | # filter out many common branch names like "release" and
202 | # "stabilization", as well as "HEAD" and "master".
203 | tags = {r for r in refs if re.search(r'\d', r)}
204 | if verbose:
205 | print("discarding '%s', no digits" % ",".join(refs - tags))
206 | if verbose:
207 | print("likely tags: %s" % ",".join(sorted(tags)))
208 | for ref in sorted(tags):
209 | # sorting will prefer e.g. "2.0" over "2.0rc1"
210 | if ref.startswith(tag_prefix):
211 | r = ref[len(tag_prefix):]
212 | # Filter out refs that exactly match prefix or that don't start
213 | # with a number once the prefix is stripped (mostly a concern
214 | # when prefix is '')
215 | if not re.match(r'\d', r):
216 | continue
217 | if verbose:
218 | print("picking %s" % r)
219 | return {"version": r,
220 | "full-revisionid": keywords["full"].strip(),
221 | "dirty": False, "error": None,
222 | "date": date}
223 | # no suitable tags, so version is "0+unknown", but full hex is still there
224 | if verbose:
225 | print("no suitable tags, using unknown + full revision id")
226 | return {"version": "0+unknown",
227 | "full-revisionid": keywords["full"].strip(),
228 | "dirty": False, "error": "no suitable tags", "date": None}
229 |
230 |
231 | @register_vcs_handler("git", "pieces_from_vcs")
232 | def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
233 | """Get version from 'git describe' in the root of the source tree.
234 |
235 | This only gets called if the git-archive 'subst' keywords were *not*
236 | expanded, and _version.py hasn't already been rewritten with a short
237 | version string, meaning we're inside a checked out source tree.
238 | """
239 | GITS = ["git"]
240 | if sys.platform == "win32":
241 | GITS = ["git.cmd", "git.exe"]
242 |
243 | # GIT_DIR can interfere with correct operation of Versioneer.
244 | # It may be intended to be passed to the Versioneer-versioned project,
245 | # but that should not change where we get our version from.
246 | env = os.environ.copy()
247 | env.pop("GIT_DIR", None)
248 | runner = functools.partial(runner, env=env)
249 |
250 | _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
251 | hide_stderr=True)
252 | if rc != 0:
253 | if verbose:
254 | print("Directory %s not under git control" % root)
255 | raise NotThisMethod("'git rev-parse --git-dir' returned error")
256 |
257 | MATCH_ARGS = ["--match", "%s*" % tag_prefix] if tag_prefix else []
258 |
259 | # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
260 | # if there isn't one, this yields HEX[-dirty] (no NUM)
261 | describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty",
262 | "--always", "--long", *MATCH_ARGS],
263 | cwd=root)
264 | # --long was added in git-1.5.5
265 | if describe_out is None:
266 | raise NotThisMethod("'git describe' failed")
267 | describe_out = describe_out.strip()
268 | full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
269 | if full_out is None:
270 | raise NotThisMethod("'git rev-parse' failed")
271 | full_out = full_out.strip()
272 |
273 | pieces = {}
274 | pieces["long"] = full_out
275 | pieces["short"] = full_out[:7] # maybe improved later
276 | pieces["error"] = None
277 |
278 | branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
279 | cwd=root)
280 | # --abbrev-ref was added in git-1.6.3
281 | if rc != 0 or branch_name is None:
282 | raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
283 | branch_name = branch_name.strip()
284 |
285 | if branch_name == "HEAD":
286 | # If we aren't exactly on a branch, pick a branch which represents
287 | # the current commit. If all else fails, we are on a branchless
288 | # commit.
289 | branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
290 | # --contains was added in git-1.5.4
291 | if rc != 0 or branches is None:
292 | raise NotThisMethod("'git branch --contains' returned error")
293 | branches = branches.split("\n")
294 |
295 | # Remove the first line if we're running detached
296 | if "(" in branches[0]:
297 | branches.pop(0)
298 |
299 | # Strip off the leading "* " from the list of branches.
300 | branches = [branch[2:] for branch in branches]
301 | if "master" in branches:
302 | branch_name = "master"
303 | elif not branches:
304 | branch_name = None
305 | else:
306 | # Pick the first branch that is returned. Good or bad.
307 | branch_name = branches[0]
308 |
309 | pieces["branch"] = branch_name
310 |
311 | # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
312 | # TAG might have hyphens.
313 | git_describe = describe_out
314 |
315 | # look for -dirty suffix
316 | dirty = git_describe.endswith("-dirty")
317 | pieces["dirty"] = dirty
318 | if dirty:
319 | git_describe = git_describe[:git_describe.rindex("-dirty")]
320 |
321 | # now we have TAG-NUM-gHEX or HEX
322 |
323 | if "-" in git_describe:
324 | # TAG-NUM-gHEX
325 | mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
326 | if not mo:
327 | # unparsable. Maybe git-describe is misbehaving?
328 | pieces["error"] = ("unable to parse git-describe output: '%s'"
329 | % describe_out)
330 | return pieces
331 |
332 | # tag
333 | full_tag = mo.group(1)
334 | if not full_tag.startswith(tag_prefix):
335 | if verbose:
336 | fmt = "tag '%s' doesn't start with prefix '%s'"
337 | print(fmt % (full_tag, tag_prefix))
338 | pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
339 | % (full_tag, tag_prefix))
340 | return pieces
341 | pieces["closest-tag"] = full_tag[len(tag_prefix):]
342 |
343 | # distance: number of commits since tag
344 | pieces["distance"] = int(mo.group(2))
345 |
346 | # commit: short hex revision ID
347 | pieces["short"] = mo.group(3)
348 |
349 | else:
350 | # HEX: no tags
351 | pieces["closest-tag"] = None
352 | count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
353 | pieces["distance"] = int(count_out) # total number of commits
354 |
355 | # commit date: see ISO-8601 comment in git_versions_from_keywords()
356 | date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
357 | # Use only the last line. Previous lines may contain GPG signature
358 | # information.
359 | date = date.splitlines()[-1]
360 | pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
361 |
362 | return pieces
363 |
364 |
def plus_or_dot(pieces):
    """Return the local-version separator for `pieces`.

    PEP 440 allows only one "+" in a local version identifier, so return
    "." when the closest tag already contains a "+", else "+".

    Robustness fix: git_pieces_from_vcs can set "closest-tag" explicitly
    to None, in which case dict.get's default is never used and the old
    `"+" in pieces.get("closest-tag", "")` raised TypeError.  Coalesce
    None to "" before the membership test.
    """
    if "+" in (pieces.get("closest-tag") or ""):
        return "."
    return "+"
370 |
371 |
def render_pep440(pieces):
    """Build up version string, with post-release "local version identifier".

    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty

    Exceptions:
    1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
    """
    tag = pieces["closest-tag"]
    if not tag:
        # exception #1: nothing tagged yet
        rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
        if pieces["dirty"]:
            rendered += ".dirty"
        return rendered
    rendered = tag
    if pieces["distance"] or pieces["dirty"]:
        # Only one "+" is allowed in a PEP 440 local version; fall back
        # to "." when the tag already carries one.
        rendered += "." if "+" in tag else "+"
        rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
        if pieces["dirty"]:
            rendered += ".dirty"
    return rendered
395 |
396 |
def render_pep440_branch(pieces):
    """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .

    The ".dev0" means not master branch. Note that .dev0 sorts backwards
    (a feature branch will appear "older" than the master branch).

    Exceptions:
    1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
    """
    tag = pieces["closest-tag"]
    off_master = pieces["branch"] != "master"
    if not tag:
        # exception #1: no tag reachable from this commit
        rendered = "0"
        if off_master:
            rendered += ".dev0"
        rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
        if pieces["dirty"]:
            rendered += ".dirty"
        return rendered
    rendered = tag
    if pieces["distance"] or pieces["dirty"]:
        if off_master:
            rendered += ".dev0"
        # Only one "+" is allowed; use "." if the tag already has one.
        rendered += "." if "+" in tag else "+"
        rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
        if pieces["dirty"]:
            rendered += ".dirty"
    return rendered
425 |
426 |
def pep440_split_post(ver):
    """Split a pep440 version string at its post-release segment.

    Returns a 2-tuple: the release part before ".post", and the
    post-release number (0 when ".post" has no digits, None when the
    string does not contain exactly one ".post" segment).
    """
    parts = ver.split(".post")
    if len(parts) != 2:
        # Zero, or more than one, ".post" segment: no usable post number.
        return parts[0], None
    head, post = parts
    return head, int(post or 0)
435 |
436 |
def render_pep440_pre(pieces):
    """TAG[.postN.devDISTANCE] -- No -dirty.

    Exceptions:
    1: no tags. 0.post0.devDISTANCE
    """
    tag = pieces["closest-tag"]
    if not tag:
        # exception #1: nothing tagged yet
        return "0.post0.dev%d" % pieces["distance"]
    if not pieces["distance"]:
        # Sitting exactly on the tag: use it verbatim.
        return tag
    # Ahead of the tag: bump an existing post-release segment, or start
    # one at .post0, then attach the commit distance as a dev release.
    segments = tag.split(".post")
    if len(segments) == 2:
        base = segments[0]
        post = int(segments[1] or 0) + 1
    else:
        base = segments[0]
        post = 0
    return "%s.post%d.dev%d" % (base, post, pieces["distance"])
459 |
460 |
def render_pep440_post(pieces):
    """TAG[.postDISTANCE[.dev0]+gHEX] .

    The ".dev0" means dirty. Note that .dev0 sorts backwards
    (a dirty tree will appear "older" than the corresponding clean one),
    but you shouldn't be releasing software with -dirty anyways.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]
    """
    tag = pieces["closest-tag"]
    if not tag:
        # exception #1: no reachable tag
        rendered = "0.post%d" % pieces["distance"]
        if pieces["dirty"]:
            rendered += ".dev0"
        rendered += "+g%s" % pieces["short"]
        return rendered
    rendered = tag
    if pieces["distance"] or pieces["dirty"]:
        rendered += ".post%d" % pieces["distance"]
        if pieces["dirty"]:
            rendered += ".dev0"
        # Only one "+" allowed; use "." if the tag already has one.
        rendered += "." if "+" in tag else "+"
        rendered += "g%s" % pieces["short"]
    return rendered
486 |
487 |
def render_pep440_post_branch(pieces):
    """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .

    The ".dev0" means not master branch.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
    """
    tag = pieces["closest-tag"]
    off_master = pieces["branch"] != "master"
    if not tag:
        # exception #1: no reachable tag
        rendered = "0.post%d" % pieces["distance"]
        if off_master:
            rendered += ".dev0"
        rendered += "+g%s" % pieces["short"]
        if pieces["dirty"]:
            rendered += ".dirty"
        return rendered
    rendered = tag
    if pieces["distance"] or pieces["dirty"]:
        rendered += ".post%d" % pieces["distance"]
        if off_master:
            rendered += ".dev0"
        # Only one "+" allowed; use "." if the tag already has one.
        rendered += "." if "+" in tag else "+"
        rendered += "g%s" % pieces["short"]
        if pieces["dirty"]:
            rendered += ".dirty"
    return rendered
515 |
516 |
def render_pep440_old(pieces):
    """TAG[.postDISTANCE[.dev0]] .

    The ".dev0" means dirty.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]
    """
    tag = pieces["closest-tag"]
    # A clean build sitting exactly on a tag renders as the bare tag.
    if tag and not (pieces["distance"] or pieces["dirty"]):
        return tag
    rendered = (tag or "0") + ".post%d" % pieces["distance"]
    if pieces["dirty"]:
        rendered += ".dev0"
    return rendered
537 |
538 |
def render_git_describe(pieces):
    """TAG[-DISTANCE-gHEX][-dirty].

    Like 'git describe --tags --dirty --always'.

    Exceptions:
    1: no tags. HEX[-dirty] (note: no 'g' prefix)
    """
    tag = pieces["closest-tag"]
    if not tag:
        # exception #1: fall back to the bare short hash
        rendered = pieces["short"]
    else:
        rendered = tag
        if pieces["distance"]:
            rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
    if pieces["dirty"]:
        rendered += "-dirty"
    return rendered
557 |
558 |
def render_git_describe_long(pieces):
    """TAG-DISTANCE-gHEX[-dirty].

    Like 'git describe --tags --dirty --always -long'.
    The distance/hash is unconditional.

    Exceptions:
    1: no tags. HEX[-dirty] (note: no 'g' prefix)
    """
    tag = pieces["closest-tag"]
    if not tag:
        # exception #1: fall back to the bare short hash
        rendered = pieces["short"]
    else:
        rendered = tag + "-%d-g%s" % (pieces["distance"], pieces["short"])
    if pieces["dirty"]:
        rendered += "-dirty"
    return rendered
577 |
578 |
def render(pieces, style):
    """Render the given version pieces into the requested style.

    Returns a dict with keys: version, full-revisionid, dirty, error,
    date.  Raises ValueError for an unrecognised style name.
    """
    if pieces["error"]:
        # An earlier stage failed; surface the error instead of a version.
        return {"version": "unknown",
                "full-revisionid": pieces.get("long"),
                "dirty": None,
                "error": pieces["error"],
                "date": None}

    if not style or style == "default":
        style = "pep440"  # the default

    # Dispatch table: style name -> renderer.
    renderers = {
        "pep440": render_pep440,
        "pep440-branch": render_pep440_branch,
        "pep440-pre": render_pep440_pre,
        "pep440-post": render_pep440_post,
        "pep440-post-branch": render_pep440_post_branch,
        "pep440-old": render_pep440_old,
        "git-describe": render_git_describe,
        "git-describe-long": render_git_describe_long,
    }
    renderer = renderers.get(style)
    if renderer is None:
        raise ValueError("unknown style '%s'" % style)

    return {"version": renderer(pieces),
            "full-revisionid": pieces["long"],
            "dirty": pieces["dirty"],
            "error": None,
            "date": pieces.get("date")}
613 |
614 |
def get_versions():
    """Get version information or return default if unable to do so.

    Tries, in order: (1) git-archive keyword expansion, (2) asking git
    directly from a source checkout, (3) the parent-directory name.
    Each strategy signals "not applicable" by raising NotThisMethod.
    Returns a dict: version, full-revisionid, dirty, error, date.
    """
    # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
    # __file__, we can work backwards from there to the root. Some
    # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
    # case we can only use expanded keywords.

    cfg = get_config()
    verbose = cfg.verbose

    # Strategy 1: keywords substituted by 'git archive' (export-subst).
    try:
        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
                                          verbose)
    except NotThisMethod:
        pass

    try:
        root = os.path.realpath(__file__)
        # versionfile_source is the relative path from the top of the source
        # tree (where the .git directory might live) to this file. Invert
        # this to find the root from __file__.
        for _ in cfg.versionfile_source.split('/'):
            root = os.path.dirname(root)
    except NameError:
        # No __file__ (frozen interpreter): cannot locate the source tree.
        return {"version": "0+unknown", "full-revisionid": None,
                "dirty": None,
                "error": "unable to find root of source tree",
                "date": None}

    # Strategy 2: interrogate git in a checked-out source tree.
    try:
        pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
        return render(pieces, cfg.style)
    except NotThisMethod:
        pass

    # Strategy 3: parse the version out of the unpacked tarball dir name.
    try:
        if cfg.parentdir_prefix:
            return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
    except NotThisMethod:
        pass

    # All strategies failed: report an explicit placeholder version.
    return {"version": "0+unknown", "full-revisionid": None,
            "dirty": None,
            "error": "unable to compute version", "date": None}
659 |
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/common.hpp:
--------------------------------------------------------------------------------
#pragma once
// NOTE(review): the include target below was lost in text extraction
// (angle-bracket contents are stripped throughout this dump) — confirm
// against the repository.
#include

// Tuning constants shared by the CPU and CUDA implementations.
#define MAX_POINT_COUNT 32
#define DISCARD_BORDER 3
#define DEG2RAD 0.01745329251f
#define RAD2DEG (1.0f / DEG2RAD)
#define MAX_CENTER_DIST 0.2 // * image width
#define MIN_RADIUS 0.2 // * image width
#define MAX_RADIUS 0.8 // * image width
#define RANSAC_ATTEMPTS 32
#define RANSAC_ITERATIONS 3
#define RANSAC_INLIER_THRESHOLD 3

typedef unsigned char uint8;

// Per-feature cut-offs used when scoring candidate border points.
struct FeatureThresholds
{
float edge;
float angle;
float intensity;
};

// Minimum confidences for accepting edge points / the fitted circle.
struct ConfidenceThresholds
{
float edge;
float circle;
};

// Supported (channel-layout, dtype) combinations for input images.
enum ImageFormat
{
rgb_float,
rgb_double,
rgb_uint8,
rgb_int,
rgb_long,
gray_float,
gray_double,
gray_uint8,
gray_int,
gray_long,
};

// Type-erased image: raw data pointer plus a runtime format tag, so a
// single dispatch macro can select the right template instantiation.
struct Image
{
Image(ImageFormat format, const void* data) : format(format), data(data) {}

ImageFormat format;
const void* data;
};

// Dispatch helpers: expand a templated CUDA kernel launch (first macro)
// or plain function call (second macro) for every supported format.
// NOTE(review): the CUDA launch configuration was mangled in extraction —
// '<<>>' below is missing the grid/block arguments (presumably built from
// DISPATCH_ARGS); confirm against the repository before editing.
#define ARG(...) __VA_ARGS__
#define KERNEL_DISPATCH_IMAGE_FORMAT(FUNCTION, DISPATCH_ARGS, IMAGE, ...) \
switch(IMAGE.format) \
{ \
case(rgb_float): FUNCTION<3, float ><<>>((const float* )IMAGE.data, __VA_ARGS__); break; \
case(rgb_double): FUNCTION<3, double ><<>>((const double* )IMAGE.data, __VA_ARGS__); break; \
case(rgb_uint8): FUNCTION<3, uint8 ><<>>((const uint8* )IMAGE.data, __VA_ARGS__); break; \
case(rgb_int): FUNCTION<3, int ><<>>((const int* )IMAGE.data, __VA_ARGS__); break; \
case(rgb_long): FUNCTION<3, long int><<>>((const long int*)IMAGE.data, __VA_ARGS__); break; \
case(gray_float): FUNCTION<1, float ><<>>((const float* )IMAGE.data, __VA_ARGS__); break; \
case(gray_double): FUNCTION<1, double ><<>>((const double* )IMAGE.data, __VA_ARGS__); break; \
case(gray_uint8): FUNCTION<1, uint8 ><<>>((const uint8* )IMAGE.data, __VA_ARGS__); break; \
case(gray_int): FUNCTION<1, int ><<>>((const int* )IMAGE.data, __VA_ARGS__); break; \
case(gray_long): FUNCTION<1, long int><<>>((const long int*)IMAGE.data, __VA_ARGS__); break; \
}

#define FUNCTION_CALL_IMAGE_FORMAT(FUNCTION, IMAGE, ...) \
switch(IMAGE.format) \
{ \
case(rgb_float): FUNCTION<3, float >((const float* )IMAGE.data, __VA_ARGS__); break; \
case(rgb_double): FUNCTION<3, double >((const double* )IMAGE.data, __VA_ARGS__); break; \
case(rgb_uint8): FUNCTION<3, uint8 >((const uint8* )IMAGE.data, __VA_ARGS__); break; \
case(rgb_int): FUNCTION<3, int >((const int* )IMAGE.data, __VA_ARGS__); break; \
case(rgb_long): FUNCTION<3, long int>((const long int*)IMAGE.data, __VA_ARGS__); break; \
case(gray_float): FUNCTION<1, float >((const float* )IMAGE.data, __VA_ARGS__); break; \
case(gray_double): FUNCTION<1, double >((const double* )IMAGE.data, __VA_ARGS__); break; \
case(gray_uint8): FUNCTION<1, uint8 >((const uint8* )IMAGE.data, __VA_ARGS__); break; \
case(gray_int): FUNCTION<1, int >((const int* )IMAGE.data, __VA_ARGS__); break; \
case(gray_long): FUNCTION<1, long int>((const long int*)IMAGE.data, __VA_ARGS__); break; \
}
82 |
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/cpu_functions.hpp:
--------------------------------------------------------------------------------
#pragma once
#include "common.hpp"

// CPU implementations of the content-area detection stages.
// Mirrors the cuda namespace in cuda_functions.cuh; keep signatures in sync.
namespace cpu
{
// Scans strip_count horizontal strips of the image and writes one candidate
// border point (x, y, score) per strip side into the three output arrays.
void find_points(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const FeatureThresholds feature_thresholds, float* points_x, float* points_y, float* point_score);
// Extracts strip_width-tall strips from the image into 'strips' as input
// for the learned scoring model.
void make_strips(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const int strip_width, float* strips);
// Converts model strip scores back into border points (x, y, score).
void find_points_from_strip_scores(const float* strips, const int batch_count, const int image_width, const int image_height, const int strip_count, const int model_patch_size, float* points_x, float* points_y, float* point_score);
// RANSAC-fits a circle to the scored points; writes 4 floats per batch
// element into 'results'.
void fit_circle(const float* points_x, const float* points_y, const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results);
}
11 |
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/cuda_functions.cuh:
--------------------------------------------------------------------------------
#pragma once
#include "common.hpp"

// CUDA implementations of the content-area detection stages.
// Mirrors the cpu namespace in cpu_functions.hpp; keep signatures in sync.
// All pointer arguments are device pointers.
namespace cuda
{
// Scans strip_count horizontal strips and writes candidate border points
// (x, y, score) into the output arrays.
void find_points(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const FeatureThresholds feature_thresholds, float* points_x, float* points_y, float* point_score);
// Extracts strips from the image as input for the learned scoring model.
void make_strips(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const int strip_width, float* strips);
// Converts model strip scores back into border points (x, y, score).
void find_points_from_strip_scores(const float* strips, const int batch_count, const int image_width, const int image_height, const int strip_count, const int model_patch_size, float* points_x, float* points_y, float* point_score);
// RANSAC-fits a circle to the scored points; writes 4 floats per batch
// element into 'results'.
void fit_circle(const float* points_x, const float* points_y, const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results);
}
11 |
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/implementation.cpp:
--------------------------------------------------------------------------------
// NOTE(review): the two include targets below were lost in text extraction
// (angle-bracket contents stripped) — likely <torch/extension.h> and a
// CUDA runtime header; confirm against the repository.
#include
#include
#include "implementation.hpp"
#include "cpu_functions.hpp"
#include "cuda_functions.cuh"

// Error-message builders: std::string::insert splices the dynamic detail
// into a fixed template string at a hard-coded character offset.
#define IMAGE_DTYPE_ERROR_MSG(t) std::string("Unsupported image dtype .").insert(24, torch::utils::getDtypeNames(t).second)
#define IMAGE_NDIM_ERROR_MSG(d) std::string("Expected an image tensor with 3 or 4 dimensions but found .").insert(58, std::to_string(d))
#define IMAGE_CHANNEL_ERROR_MSG(c) std::string("Expected a grayscale or RGB image but found size at position 1.").insert(49, std::to_string(c))
#define POINTS_NDIM_ERROR_MSG(d) std::string("Expected a point tensor with 2 or 3 dimensions but found .").insert(52, std::to_string(d))
#define POINTS_CHANNEL_ERROR_MSG(d) std::string("Expected a point tensor with 3 channels but found .").insert(50, std::to_string(d))
12 |
// Validates an image tensor and makes it contiguous (the caller's tensor
// is re-assigned through the reference). Accepts CHW (3-dim) or NCHW
// (4-dim) layouts with 1 (gray) or 3 (RGB) channels and dtype
// float/double/uint8/int/long; throws std::runtime_error otherwise.
void check_image_tensor(torch::Tensor &image)
{
image = image.contiguous();

if (image.dim() != 3 && image.dim() != 4)
{
throw std::runtime_error(IMAGE_NDIM_ERROR_MSG(image.dim()));
}

// size(-3) is the channel dimension for both CHW and NCHW layouts.
if (image.size(-3) != 1 && image.size(-3) != 3)
{
// NOTE(review): the message reports size(1), which differs from the
// size(-3) actually checked when the tensor is 3-dimensional.
throw std::runtime_error(IMAGE_CHANNEL_ERROR_MSG(image.size(1)));
}

switch (torch::typeMetaToScalarType(image.dtype()))
{
case (torch::kFloat): break;
case (torch::kDouble): break;
case (torch::kByte): break;
case (torch::kInt): break;
case (torch::kLong): break;
default: throw std::runtime_error(IMAGE_DTYPE_ERROR_MSG(torch::typeMetaToScalarType(image.dtype())));
}
}
37 |
// Validates a points tensor and makes it contiguous (re-assigned through
// the reference). Expects shape [3, N] or [batch, 3, N], where the rows
// of the size-3 dimension are (x, y, score); throws std::runtime_error
// on mismatch.
void check_points(torch::Tensor &points)
{
points = points.contiguous();

if (points.dim() != 2 && points.dim() != 3)
{
throw std::runtime_error(POINTS_NDIM_ERROR_MSG(points.dim()));
}

if (points.size(-2) != 3)
{
// NOTE(review): the message reports size(1); for a 2-dim tensor the
// checked dimension is size(-2) == size(0).
throw std::runtime_error(POINTS_CHANNEL_ERROR_MSG(points.size(1)));
}
}
52 |
// Wraps a validated tensor in the type-erased Image struct, choosing the
// ImageFormat tag from the channel count (size(-3)) and dtype. Assumes
// check_image_tensor has already run (default branch throws regardless).
Image get_image_data(torch::Tensor image)
{
bool is_rgb = image.size(-3) == 3;
switch (torch::typeMetaToScalarType(image.dtype()))
{
case (torch::kFloat): return Image(is_rgb ? ImageFormat::rgb_float : ImageFormat::gray_float, (void*)image.data_ptr());
case (torch::kDouble): return Image(is_rgb ? ImageFormat::rgb_double : ImageFormat::gray_double, (void*)image.data_ptr());
case (torch::kByte): return Image(is_rgb ? ImageFormat::rgb_uint8 : ImageFormat::gray_uint8, (void*)image.data_ptr());
case (torch::kInt): return Image(is_rgb ? ImageFormat::rgb_int : ImageFormat::gray_int, (void*)image.data_ptr());
case (torch::kLong): return Image(is_rgb ? ImageFormat::rgb_long : ImageFormat::gray_long, (void*)image.data_ptr());
default: throw std::runtime_error(IMAGE_DTYPE_ERROR_MSG(torch::typeMetaToScalarType(image.dtype())));
}
}
66 |
// Handcrafted pipeline: find candidate border points per strip, then
// RANSAC-fit a circle. Returns a float32 tensor of 4 values per image
// ([4] or [batch, 4]); see fit_circle for the meaning of the 4 values.
// NOTE(review): data_ptr() calls presumably had template arguments
// (e.g. data_ptr<float>) stripped by text extraction.
torch::Tensor estimate_area_handcrafted(torch::Tensor image, int strip_count, FeatureThresholds feature_thresholds, ConfidenceThresholds confidence_thresholds)
{
check_image_tensor(image);

Image image_data = get_image_data(image);

bool batched = image.dim() == 4;

int batch_count = batched ? image.size(0) : 1;
int channel_count = image.size(-3);
int image_height = image.size(-2);
int image_width = image.size(-1);
// Two candidate points per strip (one per side).
int point_count = 2 * strip_count;

torch::TensorOptions options = torch::device(image.device()).dtype(torch::kFloat32);
torch::Tensor result = batched ? torch::empty({batch_count, 4}, options) : torch::empty({4}, options);

if (image.device().is_cpu())
{
// Scratch buffer laid out as [x | y | score].
// NOTE(review): the section offsets advance by point_count, not
// batch_count * point_count — verify the intended layout when
// batch_count > 1.
float* temp_buffer = (float*)malloc(3 * batch_count * point_count * sizeof(float));
float* points_x = temp_buffer + 0 * point_count;
float* points_y = temp_buffer + 1 * point_count;
float* points_s = temp_buffer + 2 * point_count;

cpu::find_points(image_data, batch_count, channel_count, image_height, image_width, strip_count, feature_thresholds, points_x, points_y, points_s);

cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr());

free(temp_buffer);
}
else
{
// Same pipeline with device-side scratch memory.
// NOTE(review): cudaMalloc's return code is not checked.
float* temp_buffer;
cudaMalloc((void**)&temp_buffer, 3 * batch_count * point_count * sizeof(float));
float* points_x = temp_buffer + 0 * point_count;
float* points_y = temp_buffer + 1 * point_count;
float* points_s = temp_buffer + 2 * point_count;

cuda::find_points(image_data, batch_count, channel_count, image_height, image_width, strip_count, feature_thresholds, points_x, points_y, points_s);

cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr());

cudaFree(temp_buffer);
}

return result;
}
114 |
// Learned pipeline: extract strips, score them with the TorchScript model
// (sigmoid over logits), convert scores to border points, RANSAC-fit a
// circle. Returns float32 [4] (or [batch, 4]).
// NOTE(review): template arguments (std::vector<...>, data_ptr<...>) were
// stripped by text extraction throughout this function.
torch::Tensor estimate_area_learned(torch::Tensor image, int strip_count, torch::jit::Module model, int model_patch_size, ConfidenceThresholds confidence_thresholds)
{
check_image_tensor(image);

Image image_data = get_image_data(image);

bool batched = image.dim() == 4;

int batch_count = batched ? image.size(0) : 1;
int channel_count = image.size(-3);
int image_height = image.size(-2);
int image_width = image.size(-1);
// Two candidate points per strip (one per side).
int point_count = 2 * strip_count;

torch::TensorOptions options = torch::device(image.device()).dtype(torch::kFloat32);
torch::Tensor result = batched ? torch::empty({batch_count, 4}, options) : torch::empty({4}, options);

// Model input: one 5-channel patch stack per strip per batch element.
torch::Tensor strips = torch::empty({batch_count * strip_count, 5, model_patch_size, image_width}, options);
std::vector model_input = {strips};

if (image.device().is_cpu())
{
// Scratch buffer laid out as [x | y | score].
// NOTE(review): offsets advance by point_count, not
// batch_count * point_count — verify layout for batch_count > 1.
float* temp_buffer = (float*)malloc(3 * batch_count * point_count * sizeof(float));
float* points_x = temp_buffer + 0 * point_count;
float* points_y = temp_buffer + 1 * point_count;
float* points_s = temp_buffer + 2 * point_count;

cpu::make_strips(image_data, batch_count, channel_count, image_height, image_width, strip_count, model_patch_size, strips.data_ptr());

torch::Tensor strip_scores = torch::sigmoid(model.forward(model_input).toTensor());

cpu::find_points_from_strip_scores(strip_scores.data_ptr(), batch_count, image_height, image_width, strip_count, model_patch_size, points_x, points_y, points_s);

cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr());

free(temp_buffer);
}
else
{
// Same pipeline with device-side scratch memory.
// NOTE(review): cudaMalloc's return code is not checked.
float* temp_buffer;
cudaMalloc((void**)&temp_buffer, 3 * batch_count * point_count * sizeof(float));
float* points_x = temp_buffer + 0 * point_count;
float* points_y = temp_buffer + 1 * point_count;
float* points_s = temp_buffer + 2 * point_count;

cuda::make_strips(image_data, batch_count, channel_count, image_height, image_width, strip_count, model_patch_size, strips.data_ptr());

torch::Tensor strip_scores = torch::sigmoid(model.forward(model_input).toTensor());

cuda::find_points_from_strip_scores(strip_scores.data_ptr(), batch_count, image_height, image_width, strip_count, model_patch_size, points_x, points_y, points_s);

cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr());

cudaFree(temp_buffer);
}

return result;
}
173 |
// Returns the raw handcrafted candidate points without circle fitting.
// Output: float32 [3, 2*strip_count] (or [batch, 3, ...]); the rows are
// (x, y, score), written directly into the result tensor's storage.
// NOTE(review): for CUDA inputs the same offsets are used on device
// memory; offsets advance by point_count only — verify batched layout.
torch::Tensor get_points_handcrafted(torch::Tensor image, int strip_count, FeatureThresholds feature_thresholds)
{
check_image_tensor(image);

Image image_data = get_image_data(image);

bool batched = image.dim() == 4;

int batch_count = batched ? image.size(0) : 1;
int channel_count = image.size(-3);
int image_height = image.size(-2);
int image_width = image.size(-1);
int point_count = 2 * strip_count;

torch::TensorOptions options = torch::device(image.device()).dtype(torch::kFloat32);
torch::Tensor result = batched ? torch::empty({batch_count, 3, point_count}, options) : torch::empty({3, point_count}, options);

// Write outputs straight into the result tensor: row 0 = x, 1 = y,
// 2 = score.
float* temp_buffer = result.data_ptr();
float* points_x = temp_buffer + 0 * point_count;
float* points_y = temp_buffer + 1 * point_count;
float* points_s = temp_buffer + 2 * point_count;

if (image.device().is_cpu())
{
cpu::find_points(image_data, batch_count, channel_count, image_height, image_width, strip_count, feature_thresholds, points_x, points_y, points_s);
}
else
{
cuda::find_points(image_data, batch_count, channel_count, image_height, image_width, strip_count, feature_thresholds, points_x, points_y, points_s);
}

return result;
}
207 |
// Returns the raw learned-model candidate points without circle fitting.
// Output: float32 [3, 2*strip_count] (or [batch, 3, ...]); rows are
// (x, y, score), written directly into the result tensor's storage.
// NOTE(review): template arguments (std::vector<...>, data_ptr<...>) were
// stripped by text extraction throughout this function.
torch::Tensor get_points_learned(torch::Tensor image, int strip_count, torch::jit::Module model, int model_patch_size)
{
check_image_tensor(image);

Image image_data = get_image_data(image);

bool batched = image.dim() == 4;

int batch_count = batched ? image.size(0) : 1;
int channel_count = image.size(-3);
int image_height = image.size(-2);
int image_width = image.size(-1);
int point_count = 2 * strip_count;

torch::TensorOptions options = torch::device(image.device()).dtype(torch::kFloat32);
torch::Tensor result = batched ? torch::empty({batch_count, 3, point_count}, options) : torch::empty({3, point_count}, options);

// Model input: one 5-channel patch stack per strip per batch element.
torch::Tensor strips = torch::empty({batch_count * strip_count, 5, model_patch_size, image_width}, options);
std::vector model_input = {strips};

// Write outputs straight into the result tensor: row 0 = x, 1 = y,
// 2 = score.
float* temp_buffer = result.data_ptr();
float* points_x = temp_buffer + 0 * point_count;
float* points_y = temp_buffer + 1 * point_count;
float* points_s = temp_buffer + 2 * point_count;

if (image.device().is_cpu())
{
cpu::make_strips(image_data, batch_count, channel_count, image_height, image_width, strip_count, model_patch_size, strips.data_ptr());

torch::Tensor strip_scores = torch::sigmoid(model.forward(model_input).toTensor());

cpu::find_points_from_strip_scores(strip_scores.data_ptr(), batch_count, image_height, image_width, strip_count, model_patch_size, points_x, points_y, points_s);
}
else
{
cuda::make_strips(image_data, batch_count, channel_count, image_height, image_width, strip_count, model_patch_size, strips.data_ptr());

torch::Tensor strip_scores = torch::sigmoid(model.forward(model_input).toTensor());

cuda::find_points_from_strip_scores(strip_scores.data_ptr(), batch_count, image_height, image_width, strip_count, model_patch_size, points_x, points_y, points_s);
}

return result;
}
252 |
// Fits the content-area circle to a precomputed points tensor
// ([3, N] or [batch, 3, N], rows x/y/score). image_size is a Python
// (height, width) tuple. Returns float32 [4] (or [batch, 4]).
// NOTE(review): cast() / data_ptr() calls presumably had template
// arguments (cast<int>, data_ptr<float>) stripped by text extraction.
torch::Tensor fit_area(torch::Tensor points, py::tuple image_size, ConfidenceThresholds confidence_thresholds)
{
check_points(points);

bool batched = points.dim() == 3;

int batch_count = batched ? points.size(0) : 1;
int image_height = image_size[0].cast();
int image_width = image_size[1].cast();
int point_count = points.size(-1);

torch::TensorOptions options = torch::device(points.device()).dtype(torch::kFloat32);
torch::Tensor result = batched ? torch::empty({batch_count, 4}, options) : torch::empty({4}, options);

// Read x/y/score rows in place from the (contiguous) points tensor.
// NOTE(review): offsets advance by point_count only — verify the
// intended layout for batched input.
float* temp_buffer = points.data_ptr();
float* points_x = temp_buffer + 0 * point_count;
float* points_y = temp_buffer + 1 * point_count;
float* points_s = temp_buffer + 2 * point_count;

if (points.device().is_cpu())
{
cpu::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr());
}
else
{
cuda::fit_circle(points_x, points_y, points_s, batch_count, point_count, confidence_thresholds, image_height, image_width, result.data_ptr());
}

return result;
}
283 |
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/implementation.hpp:
--------------------------------------------------------------------------------
#pragma once
// NOTE(review): include target lost in text extraction — likely a torch
// extension header; confirm against the repository.
#include
#include "common.hpp"

// Public entry points implemented in implementation.cpp.

// Full pipeline: image -> fitted content-area circle (4 floats per image).
torch::Tensor estimate_area_handcrafted(torch::Tensor image, int strip_count, FeatureThresholds feature_thresholds, ConfidenceThresholds confidence_thresholds);
torch::Tensor estimate_area_learned(torch::Tensor image, int strip_count, torch::jit::Module model, int model_patch_size, ConfidenceThresholds confidence_thresholds);

// Point extraction only: image -> 3 x (2 * strip_count) points.
// NOTE(review): the first parameter here is an image (named 'image' in
// implementation.cpp) despite being declared as 'points' — consider
// renaming for consistency.
torch::Tensor get_points_handcrafted(torch::Tensor points, int strip_count, FeatureThresholds feature_thresholds);
torch::Tensor get_points_learned(torch::Tensor points, int strip_count, torch::jit::Module model, int model_patch_size);

// Circle fit only: points tensor + (height, width) tuple -> 4 floats.
torch::Tensor fit_area(torch::Tensor points, py::tuple image_size, ConfidenceThresholds confidence_thresholds);
12 |
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/python_bindings.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include "implementation.hpp"
3 |
4 | template<>
5 | struct py::detail::type_caster
6 | {
7 | PYBIND11_TYPE_CASTER(FeatureThresholds, py::detail::_("FeatureThresholds"));
8 |
9 | // Converts a 3-tuple (edge, angle, intensity) from Python into a FeatureThresholds value.
10 | bool load(handle src, bool)
11 | {
12 | if (!src || src.is_none() || !py::isinstance(src))  // fix: logical || short-circuits; bitwise | evaluated is_none() even on an invalid handle
13 | return false;
14 |
15 | py::tuple args = reinterpret_borrow(src);
16 | if (len(args) != 3)
17 | return false;
18 |
19 | value.edge = args[0].cast();
20 | value.angle = args[1].cast();
21 | value.intensity = args[2].cast();
22 | return true;
23 | }
24 | };
24 |
25 | template<>
26 | struct py::detail::type_caster
27 | {
28 | PYBIND11_TYPE_CASTER(ConfidenceThresholds, py::detail::_("ConfidenceThresholds"));
29 |
30 | // Converts a 2-tuple (edge, circle) from Python into a ConfidenceThresholds value.
31 | bool load(handle src, bool convert)
32 | {
33 | if (!src || src.is_none() || !py::isinstance(src))  // fix: logical || short-circuits; bitwise | evaluated is_none() even on an invalid handle
34 | return false;
35 |
36 | py::tuple args = reinterpret_borrow(src);
37 | if (len(args) != 2)
38 | return false;
39 |
40 | value.edge = args[0].cast();
41 | value.circle = args[1].cast();
42 | return true;
43 | }
44 | };
44 |
45 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
46 | {
47 | m.def("estimate_area_handcrafted", &estimate_area_handcrafted);
48 | m.def("estimate_area_learned", &estimate_area_learned);
49 | m.def("get_points_handcrafted", &get_points_handcrafted);
50 | m.def("get_points_learned", &get_points_learned);
51 | m.def("fit_area", &fit_area);
52 | }
53 |
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/source/find_points_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include "../common.hpp"
3 |
4 | namespace cpu
5 | {
6 | // =========================================================================
7 | // General functionality...
8 |
9 | template
10 | float load_element(const T* data, const int index, const int color_stride)
11 | {
12 | float value = c == 1 ? data[index] : 0.2126f * data[index + 0 * color_stride] + 0.7152f * data[index + 1 * color_stride] + 0.0722f * data[index + 2 * color_stride];
13 | return std::is_floating_point::value ? 255 * value : value;
14 | }
15 |
16 | template
17 | float load_sobel_strip(const T* data, const int index, const int spatial_stride, const int color_stride)
18 | {
19 | return 0.25 * load_element(data, index - spatial_stride, color_stride) + 0.5 * load_element(data, index, color_stride) + 0.25 * load_element(data, index + spatial_stride, color_stride);
20 | }
21 |
22 | // =========================================================================
23 | // Main function...
24 |
25 | template
26 | void find_points(const T* images, const int batch_count, const int image_height, const int image_width, const int strip_count, FeatureThresholds feature_thresholds, float* points_x, float* points_y, float* point_scores)
27 | {
28 | for (int batch_index = 0; batch_index < batch_count; ++batch_index)
29 | {
30 | const T* image = images + batch_index * c * image_height * image_width;
31 |
32 | for (int strip_index = 0; strip_index < strip_count; ++strip_index)
33 | {
34 | int image_y = 1 + (image_height - 2) / (1.0f + std::exp(-(strip_index - strip_count / 2.0f + 0.5f)/(strip_count / 8.0f)));
35 |
36 | for (int point_index = 0; point_index < 2; ++point_index)
37 | {
38 | bool flip = point_index > 0;
39 |
40 | float max_preceeding_intensity = 0.0f;
41 | float best_score = 0.0f;
42 | int best_index = 0;
43 |
44 | for (int x = 1; x < image_width / 2; ++x)
45 | {
46 | int image_x = flip ? image_width - 1 - x : x;
47 |
48 | float intensity = c == 3 ? load_element(image, image_x + image_y * image_width, image_width * image_height) : image[image_x + image_y * image_width];
49 | max_preceeding_intensity = max_preceeding_intensity < intensity ? intensity : max_preceeding_intensity;
50 |
51 | float left = load_sobel_strip(image, (image_x - 1) + image_y * image_width, image_width, image_width * image_height);
52 | float right = load_sobel_strip(image, (image_x + 1) + image_y * image_width, image_width, image_width * image_height);
53 | float top = load_sobel_strip(image, image_x + (image_y - 1) * image_width, 1, image_width * image_height);
54 | float bot = load_sobel_strip(image, image_x + (image_y + 1) * image_width, 1, image_width * image_height);
55 |
56 | float grad_x = right - left;
57 | float grad_y = bot - top;
58 | float grad = sqrt(grad_x * grad_x + grad_y * grad_y);
59 |
60 | float center_dir_x = (0.5f * image_width) - (float)image_x;
61 | float center_dir_y = (0.5f * image_height) - (float)image_y;
62 | float center_dir_norm = sqrt(center_dir_x * center_dir_x + center_dir_y * center_dir_y);
63 |
64 | float dot = grad == 0 ? -1 : (center_dir_x * grad_x + center_dir_y * grad_y) / (center_dir_norm * grad);
65 | float angle = RAD2DEG * acos(dot);
66 |
67 | // ============================================================
68 | // Final scoring...
69 |
70 | float edge_score = tanh(grad / feature_thresholds.edge);
71 | float angle_score = 1.0f - tanh(angle / feature_thresholds.angle);
72 | float intensity_score = 1.0f - tanh(max_preceeding_intensity / feature_thresholds.intensity);
73 |
74 | float point_score = edge_score * angle_score * intensity_score;
75 |
76 | if (point_score > best_score)
77 | {
78 | best_score = point_score;
79 | best_index = image_x;
80 | }
81 | }
82 |
83 | if (best_index < DISCARD_BORDER || best_index >= image_width - DISCARD_BORDER)
84 | {
85 | best_score = 0.0f;
86 | }
87 |
88 | points_x[strip_index + point_index * strip_count + batch_index * 3 * 2 * strip_count] = best_index;
89 | points_y[strip_index + point_index * strip_count + batch_index * 3 * 2 * strip_count] = image_y;
90 | point_scores[strip_index + point_index * strip_count + batch_index * 3 * 2 * strip_count] = best_score;
91 | }
92 | }
93 | }
94 | }
95 |
96 | void find_points(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, FeatureThresholds feature_thresholds, float* points_x, float* points_y, float* point_scores)
97 | {
98 | FUNCTION_CALL_IMAGE_FORMAT(find_points, image, batch_count, image_height, image_width, strip_count, feature_thresholds, points_x, points_y, point_scores);
99 | }
100 | }
101 |
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/source/find_points_cuda.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include "../common.hpp"
3 |
4 | namespace cuda
5 | {
6 | // =========================================================================
7 | // General functionality...
8 |
9 | template
10 | __device__ float load_element(const T* data, const int index, const int color_stride)
11 | {
12 | float value = c == 1 ? data[index] : 0.2126f * data[index + 0 * color_stride] + 0.7152f * data[index + 1 * color_stride] + 0.0722f * data[index + 2 * color_stride];
13 | return std::is_floating_point::value ? 255 * value : value;
14 | }
15 |
16 | __device__ float sobel_filter(const float* data, const int index, const int x_stride, const int y_stride, float* x_grad, float* y_grad)
17 | {
18 | float left = 0.25f * data[index - x_stride - y_stride] + 0.5f * data[index - x_stride] + 0.25f * data[index - x_stride + y_stride];
19 | float right = 0.25f * data[index + x_stride - y_stride] + 0.5f * data[index + x_stride] + 0.25f * data[index + x_stride + y_stride];
20 | *x_grad = right - left;
21 |
22 | float top = 0.25f * data[index - x_stride - y_stride] + 0.5f * data[index - y_stride] + 0.25f * data[index + x_stride - y_stride];
23 | float bot = 0.25f * data[index - x_stride + y_stride] + 0.5f * data[index + y_stride] + 0.25f * data[index + x_stride + y_stride];
24 | *y_grad = bot - top;
25 |
26 | return sqrt(*x_grad * *x_grad + *y_grad * *y_grad);
27 | }
28 |
29 | // =========================================================================
30 | // Kernels...
31 |
32 | template
33 | __global__ void find_points_kernel(const T* g_image_batch, float* g_edge_x_batch, float* g_edge_y_batch, float* g_edge_scores_batch, const int image_width, const int image_height, const int strip_count, const FeatureThresholds feature_thresholds)
34 | {
35 | constexpr int warp_size = 32;
36 |
37 | const T* g_image = g_image_batch + blockIdx.z * c * image_width * image_height;
38 | float* g_edge_x = g_edge_x_batch + blockIdx.z * 3 * 2 * strip_count;
39 | float* g_edge_y = g_edge_y_batch + blockIdx.z * 3 * 2 * strip_count;
40 | float* g_edge_scores = g_edge_scores_batch + blockIdx.z * 3 * 2 * strip_count;
41 |
42 | int thread_count = blockDim.x;
43 | int warp_count = 1 + (thread_count - 1) / warp_size;
44 |
45 | extern __shared__ int s_shared_buffer[];
46 | float* s_image_strip = (float*)s_shared_buffer;
47 | int* s_cross_warp_operation_buffer = s_shared_buffer + 3 * thread_count;
48 | float* s_cross_warp_operation_buffer_2 = (float*)(s_shared_buffer + 3 * thread_count + warp_count);
49 |
50 | int warp_index = threadIdx.x >> 5;
51 | int lane_index = threadIdx.x & 31;
52 |
53 | bool flip = blockIdx.x == 1;
54 |
55 | // ============================================================
56 | // Load strip into shared memory...
57 |
58 | int image_x = flip ? image_width - 1 - threadIdx.x : threadIdx.x;
59 |
60 | int strip_index = blockIdx.y;
61 | int strip_height = 1 + (image_height - 2) / (1.0f + exp(-(strip_index - strip_count / 2.0f + 0.5f)/(strip_count / 8.0f)));
62 |
63 | #pragma unroll
64 | for (int y = 0; y < 3; y++)
65 | {
66 | int image_element_index = image_x + (strip_height + (y - 1)) * image_width;
67 | s_image_strip[threadIdx.x + y * thread_count] = load_element(g_image, image_element_index, image_width * image_height);
68 | }
69 |
70 | __syncthreads();
71 |
72 | // ============================================================
73 | // Calculate largest preceeding intensity...
74 |
75 | float max_preceeding_intensity = s_image_strip[threadIdx.x + thread_count];
76 |
77 | #pragma unroll
78 | for (int d=1; d < 32; d<<=1)
79 | {
80 | float other_intensity = __shfl_up_sync(0xffffffff, max_preceeding_intensity, d);
81 |
82 | if (lane_index >= d && other_intensity > max_preceeding_intensity)
83 | {
84 | max_preceeding_intensity = other_intensity;
85 | }
86 | }
87 |
88 | if (lane_index == warp_size - 1)
89 | {
90 | s_cross_warp_operation_buffer[warp_index] = max_preceeding_intensity;
91 | }
92 |
93 | __syncthreads();
94 |
95 | if (warp_index == 0)
96 | {
97 | float warp_max = lane_index < warp_count ? s_cross_warp_operation_buffer[lane_index] : 0;
98 |
99 | #pragma unroll
100 | for (int d=1; d < 32; d<<=1)
101 | {
102 | float other_max = __shfl_up_sync(0xffffffff, warp_max, d);
103 |
104 | if (lane_index >= d && other_max > warp_max)
105 | {
106 | warp_max = other_max;
107 | }
108 | }
109 |
110 | if (lane_index < warp_count)
111 | {
112 | s_cross_warp_operation_buffer[lane_index] = warp_max;
113 | }
114 | }
115 |
116 | __syncthreads();
117 |
118 | if (warp_index > 0)
119 | {
120 | float other_intensity = s_cross_warp_operation_buffer[warp_index-1];
121 | max_preceeding_intensity = other_intensity > max_preceeding_intensity ? other_intensity : max_preceeding_intensity;
122 | }
123 |
124 | // ============================================================
125 | // Applying sobel kernel to image patch...
126 |
127 | float x_grad = 0;
128 | float y_grad = 0;
129 | float grad = 0;
130 |
131 | if (threadIdx.x > 0 && threadIdx.x < thread_count - 1)
132 | {
133 | grad = sobel_filter(s_image_strip, threadIdx.x + thread_count, 1, thread_count, &x_grad, &y_grad);
134 | }
135 |
136 | // ============================================================
137 | // Calculating angle between gradient vector and center vector...
138 |
139 | float center_dir_x = (0.5f * image_width) - (float)image_x;
140 | float center_dir_y = (0.5f * image_height) - (float)strip_height;
141 | float center_dir_norm = sqrt(center_dir_x * center_dir_x + center_dir_y * center_dir_y);
142 |
143 | x_grad = flip ? -x_grad : x_grad;
144 |
145 | float dot = grad == 0 ? -1 : (center_dir_x * x_grad + center_dir_y * y_grad) / (center_dir_norm * grad);
146 | float angle = RAD2DEG * acos(dot);
147 |
148 | // ============================================================
149 | // Final scoring...
150 |
151 | float edge_score = tanh(grad / feature_thresholds.edge);
152 | float angle_score = 1.0f - tanh(angle / feature_thresholds.angle);
153 | float intensity_score = 1.0f - tanh(max_preceeding_intensity / feature_thresholds.intensity);
154 |
155 | float point_score = edge_score * angle_score * intensity_score;
156 |
157 | // ============================================================
158 | // Reduction to find the best edge...
159 |
160 | int best_edge_x = image_x;
161 | float best_edge_score = point_score;
162 |
163 | // warp reduction....
164 | #pragma unroll
165 | for (int offset = warp_size >> 1; offset > 0; offset >>= 1)
166 | {
167 | int other_edge_x = __shfl_down_sync(0xffffffff, best_edge_x, offset);
168 | float other_edge_score = __shfl_down_sync(0xffffffff, best_edge_score, offset);
169 |
170 | if (other_edge_score > best_edge_score)
171 | {
172 | best_edge_x = other_edge_x;
173 | best_edge_score = other_edge_score;
174 | }
175 | }
176 |
177 | if (lane_index == 0)
178 | {
179 | s_cross_warp_operation_buffer[warp_index] = best_edge_x;
180 | s_cross_warp_operation_buffer_2[warp_index] = best_edge_score;
181 | }
182 |
183 | __syncthreads();
184 |
185 | // block reduction....
186 | if (warp_index == 0 && lane_index < warp_count)
187 | {
188 | best_edge_x = s_cross_warp_operation_buffer[lane_index];
189 | best_edge_score = s_cross_warp_operation_buffer_2[lane_index];
190 |
191 | int next_power_two = pow(2, ceil(log(warp_count)/log(2)));
192 |
193 | #pragma unroll
194 | for (int offset = next_power_two >> 1 ; offset > 0; offset >>= 1)
195 | {
196 | int other_edge_x = __shfl_down_sync(0xffffffff, best_edge_x, offset);
197 | float other_edge_score = __shfl_down_sync(0xffffffff, best_edge_score, offset);
198 |
199 | if (other_edge_score > best_edge_score)
200 | {
201 | best_edge_x = other_edge_x;
202 | best_edge_score = other_edge_score;
203 | }
204 | }
205 |
206 | if (lane_index == 0)
207 | {
208 | int point_index = flip ? strip_index : strip_index + strip_count;
209 |
210 | if (best_edge_x < DISCARD_BORDER || best_edge_x >= image_width - DISCARD_BORDER)
211 | {
212 | best_edge_score = 0.0f;
213 | }
214 |
215 | g_edge_x[point_index] = best_edge_x;
216 | g_edge_y[point_index] = strip_height;
217 | g_edge_scores[point_index] = best_edge_score;
218 | }
219 | }
220 | }
221 |
222 | // =========================================================================
223 | // Main function...
224 |
225 | void find_points(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const FeatureThresholds feature_thresholds, float* points_x, float* points_y, float* point_scores)
226 | {
227 | int half_width = image_width / 2;
228 | int warps = 1 + (half_width - 1) / 32;
229 | int threads = warps * 32;
230 |
231 | threads = threads > 1024 ? 1024 : threads;
232 |
233 | dim3 grid(2, strip_count, batch_count);
234 | dim3 block(threads);
235 | int shared_memory = (3 * threads + 2 * warps) * sizeof(int);
236 |
237 | KERNEL_DISPATCH_IMAGE_FORMAT(find_points_kernel, ARG(grid, block, shared_memory), image, points_x, points_y, point_scores, image_width, image_height, strip_count, feature_thresholds);
238 | }
239 | }
240 |
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/source/find_points_from_scores_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include "../common.hpp"
3 |
4 | namespace cpu
5 | {
6 | void find_points_from_strip_scores(const float* strips, const int batch_count, const int image_height, const int image_width, const int strip_count, const int model_patch_size, float* points_x, float* points_y, float* point_score)
7 | {
8 | for (int batch_index = 0; batch_index < batch_count; ++batch_index)
9 | {
10 | int half_patch_size = model_patch_size / 2;
11 | int strip_width = image_width - 2 * half_patch_size;
12 |
13 | for (int strip_index = 0; strip_index < strip_count; ++strip_index)
14 | {
15 | int image_y = 1 + (image_height - 2) / (1.0f + std::exp(-(strip_index - strip_count / 2.0f + 0.5f) / (strip_count / 8.0f)));
16 |
17 | float best_score = 0.0f;
18 | int best_index = 0;
19 |
20 | for (int strip_x = 0; strip_x < strip_width / 2; ++strip_x)
21 | {
22 | float point_score = strips[strip_x + strip_index * strip_width + batch_index * strip_count * strip_width];
23 |
24 | if (point_score > best_score)
25 | {
26 | best_score = point_score;
27 | best_index = strip_x + half_patch_size;
28 | }
29 | }
30 |
31 | points_x[strip_index + batch_index * 3 * 2 * strip_count] = best_index;
32 | points_y[strip_index + batch_index * 3 * 2 * strip_count] = image_y;
33 | point_score[strip_index + batch_index * 3 * 2 * strip_count] = best_score;
34 |
35 | best_score = 0.0f;
36 | best_index = 0;
37 |
38 | for (int strip_x = strip_width / 2; strip_x < strip_width; ++strip_x)
39 | {
40 | float point_score = strips[strip_x + strip_index * strip_width + batch_index * strip_count * strip_width];
41 |
42 | if (point_score > best_score)
43 | {
44 | best_score = point_score;
45 | best_index = strip_x + half_patch_size;
46 | }
47 | }
48 |
49 | points_x[strip_index + strip_count + batch_index * 3 * 2 * strip_count] = best_index;
50 | points_y[strip_index + strip_count + batch_index * 3 * 2 * strip_count] = image_y;
51 | point_score[strip_index + strip_count + batch_index * 3 * 2 * strip_count] = best_score;
52 | }
53 | }
54 | }
55 | }
56 |
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/source/find_points_from_scores_cuda.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include "../common.hpp"
3 |
4 | namespace cuda
5 | {
6 | __global__ void find_best_edge(const float* g_score_strips_batch, float* g_edge_x_batch, float* g_edge_y_batch, float* g_edge_scores_batch, const int strip_width, const int image_height, const int strip_count, const int half_patch_size)
7 | {
8 | int thread_count = blockDim.x;
9 | int warp_count = 1 + (thread_count - 1) / 32;
10 |
11 | extern __shared__ float s_shared_buffer[];
12 | float* s_cross_warp_operation_buffer = s_shared_buffer;
13 | float* s_cross_warp_operation_buffer_2 = s_shared_buffer + warp_count;
14 |
15 | const float* g_score_strips = g_score_strips_batch + blockIdx.z * strip_count * strip_width;
16 | float* g_edge_x = g_edge_x_batch + blockIdx.z * 3 * 2 * strip_count;
17 | float* g_edge_y = g_edge_y_batch + blockIdx.z * 3 * 2 * strip_count;
18 | float* g_edge_scores = g_edge_scores_batch + blockIdx.z * 3 * 2 * strip_count;
19 |
20 | int warp_index = threadIdx.x >> 5;
21 | int lane_index = threadIdx.x & 31;
22 |
23 | bool flip = blockIdx.x == 1;
24 |
25 | // ============================================================
26 | // Load strip into shared memory...
27 |
28 | int image_x = flip ? strip_width - 1 - threadIdx.x : threadIdx.x;
29 |
30 | int strip_index = blockIdx.y;
31 | int strip_height = 1 + (image_height - 2) / (1.0f + exp(-(strip_index - strip_count / 2.0f + 0.5f)/(strip_count / 8.0f)));
32 |
33 | float point_score = g_score_strips[image_x + strip_index * strip_width];
34 |
35 | int best_edge_x = image_x;
36 | float best_edge_score = point_score;
37 |
38 | // warp reduction....
39 | #pragma unroll
40 | for (int offset = 32 >> 1; offset > 0; offset >>= 1)
41 | {
42 | int other_edge_x = __shfl_down_sync(0xffffffff, best_edge_x, offset);
43 | float other_edge_score = __shfl_down_sync(0xffffffff, best_edge_score, offset);
44 |
45 | if (other_edge_score > best_edge_score)
46 | {
47 | best_edge_x = other_edge_x;
48 | best_edge_score = other_edge_score;
49 | }
50 | }
51 |
52 | if (lane_index == 0)
53 | {
54 | s_cross_warp_operation_buffer[warp_index] = best_edge_x;
55 | s_cross_warp_operation_buffer_2[warp_index] = best_edge_score;
56 | }
57 |
58 | __syncthreads();
59 |
60 | // block reduction....
61 | if (warp_index == 0 && lane_index < warp_count)
62 | {
63 | best_edge_x = s_cross_warp_operation_buffer[lane_index];
64 | best_edge_score = s_cross_warp_operation_buffer_2[lane_index];
65 |
66 | int next_power_two = pow(2, ceil(log(warp_count)/log(2)));
67 |
68 | #pragma unroll
69 | for (int offset = next_power_two >> 1 ; offset > 0; offset >>= 1)
70 | {
71 | int other_edge_x = __shfl_down_sync(0xffffffff, best_edge_x, offset);
72 | float other_edge_score = __shfl_down_sync(0xffffffff, best_edge_score, offset);
73 |
74 | if (other_edge_score > best_edge_score)
75 | {
76 | best_edge_x = other_edge_x;
77 | best_edge_score = other_edge_score;
78 | }
79 | }
80 |
81 | if (lane_index == 0)
82 | {
83 | int point_index = flip ? strip_index : strip_index + strip_count;
84 |
85 | g_edge_x[point_index] = best_edge_x + half_patch_size;
86 | g_edge_y[point_index] = strip_height;
87 | g_edge_scores[point_index] = best_edge_score;
88 | }
89 | }
90 | }
91 |
92 | void find_points_from_strip_scores(const float* strips, const int batch_count, const int image_height, const int image_width, const int strip_count, const int model_patch_size, float* points_x, float* points_y, float* point_score)
93 | {
94 | int half_patch_size = (model_patch_size - 1) / 2;
95 | int strip_width = image_width - 2 * half_patch_size;
96 |
97 | int half_width = strip_width / 2;
98 | int warps = 1 + (half_width - 1) / 32;
99 | int threads = warps * 32;
100 |
101 | threads = threads > 1024 ? 1024 : threads;
102 |
103 | dim3 grid(2, strip_count, batch_count);
104 | dim3 block(threads);
105 | int shared_memory = 2 * warps * sizeof(int);
106 | find_best_edge<<>>(strips, points_x, points_y, point_score, strip_width, image_height, strip_count, half_patch_size);
107 | }
108 | }
109 |
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/source/fit_circle_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include "../common.hpp"
3 |
4 | // =========================================================================
5 | // General functionality...
6 |
7 | namespace cpu
8 | {
9 |
10 | int fast_rand(int& seed)
11 | {
12 | seed = 214013 * seed + 2531011;
13 | return (seed >> 16) & 0x7FFF;
14 | }
15 |
16 | void rand_triplet(int seed, int seed_stride, int max, int* triplet)
17 | {
18 | triplet[0] = fast_rand(seed) % max;
19 | do {triplet[1] = fast_rand(seed) % max;} while (triplet[1] == triplet[0]);
20 | do {triplet[2] = fast_rand(seed) % max;} while (triplet[2] == triplet[0] || triplet[2] == triplet[1]);
21 | }
22 |
23 | bool check_circle(float x, float y, float r, int image_width, int image_height)
24 | {
25 | float x_diff = x - 0.5 * image_width;
26 | float y_diff = y - 0.5 * image_height;
27 | float diff = sqrt(x_diff * x_diff + y_diff * y_diff);
28 |
29 | bool valid = true;
30 | valid &= diff < MAX_CENTER_DIST * image_width;
31 | valid &= r > MIN_RADIUS * image_width;
32 | valid &= r < MAX_RADIUS * image_width;
33 |
34 | return valid;
35 | }
36 |
37 | bool calculate_circle(float ax, float ay, float bx, float by, float cx, float cy, float* x, float* y, float* r)
38 | {
39 | float offset = bx * bx + by * by;
40 |
41 | float bc = 0.5f * (ax * ax + ay * ay - offset);
42 | float cd = 0.5f * (offset - cx * cx - cy * cy);
43 |
44 | float det = (ax - bx) * (by - cy) - (bx - cx) * (ay - by);
45 |
46 | bool valid = abs(det) > 1e-8;
47 |
48 | if (valid)
49 | {
50 | float idet = 1.0f / det;
51 |
52 | *x = (bc * (by - cy) - cd * (ay - by)) * idet;
53 | *y = (cd * (ax - bx) - bc * (bx - cx)) * idet;
54 | *r = sqrt((bx - *x) * (bx - *x) + (by - *y) * (by - *y));
55 | }
56 |
57 | return valid;
58 | }
59 |
60 | bool Cholesky3x3(float lhs[3][3], float rhs[3])
61 | {
62 | float sum;
63 | float diagonal[3];
64 |
65 | sum = lhs[0][0];
66 |
67 | if (sum <= 0.f)
68 | return false;
69 |
70 | diagonal[0] = sqrt(sum);
71 |
72 | sum = lhs[0][1];
73 | lhs[1][0] = sum / diagonal[0];
74 |
75 | sum = lhs[0][2];
76 | lhs[2][0] = sum / diagonal[0];
77 |
78 | sum = lhs[1][1] - lhs[1][0] * lhs[1][0];
79 |
80 | if (sum <= 0.f)
81 | return false;
82 |
83 | diagonal[1] = sqrt(sum);
84 |
85 | sum = lhs[1][2] - lhs[1][0] * lhs[2][0];
86 | lhs[2][1] = sum / diagonal[1];
87 |
88 | sum = lhs[2][2] - lhs[2][1] * lhs[2][1] - lhs[2][0] * lhs[2][0];
89 |
90 | if (sum <= 0.f)
91 | return false;
92 |
93 | diagonal[2] = sqrt(sum);
94 |
95 | sum = rhs[0];
96 | rhs[0] = sum / diagonal[0];
97 |
98 | sum = rhs[1] - lhs[1][0] * rhs[0];
99 | rhs[1] = sum / diagonal[1];
100 |
101 | sum = rhs[2] - lhs[2][1] * rhs[1] - lhs[2][0] * rhs[0];
102 | rhs[2] = sum / diagonal[2];
103 |
104 | sum = rhs[2];
105 | rhs[2] = sum / diagonal[2];
106 |
107 | sum = rhs[1] - lhs[2][1] * rhs[2];
108 | rhs[1] = sum / diagonal[1];
109 |
110 | sum = rhs[0] - lhs[1][0] * rhs[1] - lhs[2][0] * rhs[2];
111 | rhs[0] = sum / diagonal[0];
112 |
113 | return true;
114 | }
115 |
116 | void get_circle(int point_count, int* indices, int* points_x, int* points_y, float* circle_x, float* circle_y, float* circle_r)
117 | {
118 | if (point_count == 3)
119 | {
120 | int a = indices[0];
121 | int b = indices[1];
122 | int c = indices[2];
123 |
124 | calculate_circle(points_x[a], points_y[a], points_x[b], points_y[b], points_x[c], points_y[c], circle_x, circle_y, circle_r);
125 | }
126 | else
127 | {
128 |
129 | float lhs[3][3] {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
130 | float rhs[3] {0, 0, 0};
131 |
132 | for (int i = 0; i < point_count; i++)
133 | {
134 | float p_x = points_x[indices[i]];
135 | float p_y = points_y[indices[i]];
136 |
137 | lhs[0][0] += p_x * p_x;
138 | lhs[0][1] += p_x * p_y;
139 | lhs[1][1] += p_y * p_y;
140 | lhs[0][2] += p_x;
141 | lhs[1][2] += p_y;
142 | lhs[2][2] += 1;
143 |
144 | rhs[0] += p_x * p_x * p_x + p_x * p_y * p_y;
145 | rhs[1] += p_x * p_x * p_y + p_y * p_y * p_y;
146 | rhs[2] += p_x * p_x + p_y * p_y;
147 | }
148 |
149 | Cholesky3x3(lhs, rhs);
150 |
151 | float A=rhs[0], B=rhs[1], C=rhs[2];
152 |
153 | *circle_x = A / 2.0f;
154 | *circle_y = B / 2.0f;
155 | *circle_r = std::sqrt(4.0f * C + A * A + B * B) / 2.0f;
156 | }
157 | }
158 |
159 | // =========================================================================
160 | // Main function...
161 |
162 | void fit_circle(const float* points_x, const float* points_y, const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results)
163 | {
164 | int* compacted_points = (int*)malloc(3 * point_count * sizeof(int));
165 | int* compacted_points_x = compacted_points + 0 * point_count;
166 | int* compacted_points_y = compacted_points + 1 * point_count;
167 | float* compacted_points_s = (float*)compacted_points + 2 * point_count;
168 |
169 |
170 | for (int batch_index = 0; batch_index < batch_count; ++batch_index)
171 | {
172 | // Point compaction: keep only points whose score clears the edge-confidence threshold...
173 | int real_point_count = 0;
174 | for (int i = 0; i < point_count; ++i)
175 | {
176 | if (points_score[i + batch_index * 3 * point_count] > confidence_thresholds.edge)  // fix: score must use the same per-batch offset as x/y/s below, else every batch is filtered by batch 0's scores
177 | {
178 | compacted_points_x[real_point_count] = points_x[i + batch_index * 3 * point_count];
179 | compacted_points_y[real_point_count] = points_y[i + batch_index * 3 * point_count];
180 | compacted_points_s[real_point_count] = points_score[i + batch_index * 3 * point_count];
181 | real_point_count += 1;
182 | }
183 | }
184 |
185 | results[0 + batch_index * 4] = 0.0f;
186 | results[1 + batch_index * 4] = 0.0f;
187 | results[2 + batch_index * 4] = 0.0f;
188 | results[3 + batch_index * 4] = 0.0f;
189 |
190 | // Early out: too few confident points to define a circle...
191 | if (real_point_count < 3)
192 | {
193 | continue;  // fix: was `return`, which leaked compacted_points and silently dropped all remaining batches
194 | }
195 |
196 | // Ransac attempts...
197 | for (int ransac_attempt = 0; ransac_attempt < RANSAC_ATTEMPTS; ++ransac_attempt)
198 | {
199 | int inlier_count = 3;
200 | int inliers[MAX_POINT_COUNT];
201 | rand_triplet(ransac_attempt * 42342, RANSAC_ATTEMPTS, real_point_count, inliers);
202 |
203 | float circle_x = 0.0f, circle_y = 0.0f, circle_r = 0.0f;  // fix: initialise — calculate_circle() leaves outputs untouched for a degenerate (near-collinear) triplet, so check_circle() below could read garbage
204 | float circle_score = 0.0f;
205 |
206 | for (int i = 0; i < RANSAC_ITERATIONS; i++)
207 | {
208 | get_circle(inlier_count, inliers, compacted_points_x, compacted_points_y, &circle_x, &circle_y, &circle_r);
209 |
210 | inlier_count = 0;
211 | circle_score = 0.0f;
212 |
213 | for (int point_index = 0; point_index < real_point_count; point_index++)
214 | {
215 | int edge_x = compacted_points_x[point_index];
216 | int edge_y = compacted_points_y[point_index];
217 | float edge_score = compacted_points_s[point_index];
218 |
219 | float delta_x = circle_x - edge_x;
220 | float delta_y = circle_y - edge_y;
221 |
222 | float delta = std::sqrt(delta_x * delta_x + delta_y * delta_y);
223 | float error = std::abs(circle_r - delta);
224 |
225 | if (error < RANSAC_INLIER_THRESHOLD)
226 | {
227 | circle_score += edge_score;
228 |
229 | inliers[inlier_count] = point_index;
230 | inlier_count++;
231 | }
232 | }
233 |
234 | circle_score /= point_count;
235 | }
236 |
237 | bool circle_valid = check_circle(circle_x, circle_y, circle_r, image_width, image_height);
238 |
239 | if (circle_valid && circle_score > results[3 + batch_index * 4])
240 | {
241 | results[0 + batch_index * 4] = circle_x;
242 | results[1 + batch_index * 4] = circle_y;
243 | results[2 + batch_index * 4] = circle_r;
244 | results[3 + batch_index * 4] = circle_score;
245 | }
246 | }
247 | }
248 |
249 | free(compacted_points);
250 | }
251 | }
252 |
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/source/fit_circle_cuda.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include "../common.hpp"
3 |
4 | namespace cuda
5 | {
6 | // =========================================================================
7 | // General functionality...
8 |
9 | __device__ int fast_rand(int& seed)
10 | {
11 | seed = 214013 * seed + 2531011;
12 | return (seed >> 16) & 0x7FFF;
13 | }
14 |
15 | __device__ void rand_triplet(int seed, int seed_stride, int max, int* triplet)
16 | {
17 | triplet[0] = fast_rand(seed) % max;
18 | do {triplet[1] = fast_rand(seed) % max;} while (triplet[1] == triplet[0]);
19 | do {triplet[2] = fast_rand(seed) % max;} while (triplet[2] == triplet[0] || triplet[2] == triplet[1]);
20 | }
21 |
22 | __device__ bool check_circle(float x, float y, float r, int image_width, int image_height)
23 | {
24 | float x_diff = x - 0.5 * image_width;
25 | float y_diff = y - 0.5 * image_height;
26 | float diff = sqrt(x_diff * x_diff + y_diff * y_diff);
27 |
28 | bool valid = true;
29 | valid &= diff < MAX_CENTER_DIST * image_width;
30 | valid &= r > MIN_RADIUS * image_width;
31 | valid &= r < MAX_RADIUS * image_width;
32 |
33 | return valid;
34 | }
35 |
36 | __device__ bool calculate_circle(float ax, float ay, float bx, float by, float cx, float cy, float* x, float* y, float* r)
37 | {
38 | float offset = bx * bx + by * by;
39 |
40 | float bc = 0.5f * (ax * ax + ay * ay - offset);
41 | float cd = 0.5f * (offset - cx * cx - cy * cy);
42 |
43 | float det = (ax - bx) * (by - cy) - (bx - cx) * (ay - by);
44 |
45 | bool valid = abs(det) > 1e-8;
46 |
47 | if (valid)
48 | {
49 | float idet = 1.0f / det;
50 |
51 | *x = (bc * (by - cy) - cd * (ay - by)) * idet;
52 | *y = (cd * (ax - bx) - bc * (bx - cx)) * idet;
53 | *r = sqrt((bx - *x) * (bx - *x) + (by - *y) * (by - *y));
54 | }
55 |
56 | return valid;
57 | }
58 |
// Solves the symmetric positive-definite 3x3 system lhs * x = rhs in place
// using an unpivoted Cholesky factorisation. Only the upper triangle of lhs
// is read as input; the strict lower triangle is overwritten with the factor
// and rhs is overwritten with the solution. Returns false (leaving lhs/rhs
// partially modified) if the matrix is not positive definite.
__device__ bool Cholesky3x3(float lhs[3][3], float rhs[3])
{
    float sum;
    float diagonal[3];

    // Factorisation: column 0.
    sum = lhs[0][0];

    if (sum <= 0.f)
        return false;

    diagonal[0] = sqrt(sum);

    sum = lhs[0][1];
    lhs[1][0] = sum / diagonal[0];

    sum = lhs[0][2];
    lhs[2][0] = sum / diagonal[0];

    // Column 1.
    sum = lhs[1][1] - lhs[1][0] * lhs[1][0];

    if (sum <= 0.f)
        return false;

    diagonal[1] = sqrt(sum);

    sum = lhs[1][2] - lhs[1][0] * lhs[2][0];
    lhs[2][1] = sum / diagonal[1];

    // Column 2.
    sum = lhs[2][2] - lhs[2][1] * lhs[2][1] - lhs[2][0] * lhs[2][0];

    if (sum <= 0.f)
        return false;

    diagonal[2] = sqrt(sum);

    // Forward substitution: solve L * y = rhs.
    sum = rhs[0];
    rhs[0] = sum / diagonal[0];

    sum = rhs[1] - lhs[1][0] * rhs[0];
    rhs[1] = sum / diagonal[1];

    sum = rhs[2] - lhs[2][1] * rhs[1] - lhs[2][0] * rhs[0];
    rhs[2] = sum / diagonal[2];

    // Back substitution: solve L^T * x = y (note rhs[2] is divided by
    // diagonal[2] a second time here — that is the back-substitution step,
    // not a duplicate).
    sum = rhs[2];
    rhs[2] = sum / diagonal[2];

    sum = rhs[1] - lhs[2][1] * rhs[2];
    rhs[1] = sum / diagonal[1];

    sum = rhs[0] - lhs[1][0] * rhs[1] - lhs[2][0] * rhs[2];
    rhs[0] = sum / diagonal[0];

    return true;
}
114 |
// Fits a circle to the indexed points. With exactly three indices the
// circumscribed circle is computed directly; otherwise an algebraic
// least-squares fit of x^2 + y^2 = A*x + B*y + C is solved via Cholesky.
// On a degenerate system the outputs are left untouched, matching the
// three-point path's behaviour for collinear points.
__device__ void get_circle(int point_count, int* indices, int* points_x, int* points_y, float* circle_x, float* circle_y, float* circle_r)
{
    if (point_count == 3)
    {
        int a = indices[0];
        int b = indices[1];
        int c = indices[2];

        calculate_circle(points_x[a], points_y[a], points_x[b], points_y[b], points_x[c], points_y[c], circle_x, circle_y, circle_r);
    }
    else
    {
        // Accumulate the normal equations over the selected points.
        float lhs[3][3] {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
        float rhs[3] {0, 0, 0};

        for (int i = 0; i < point_count; i++)
        {
            float p_x = points_x[indices[i]];
            float p_y = points_y[indices[i]];

            lhs[0][0] += p_x * p_x;
            lhs[0][1] += p_x * p_y;
            lhs[1][1] += p_y * p_y;
            lhs[0][2] += p_x;
            lhs[1][2] += p_y;
            lhs[2][2] += 1;

            rhs[0] += p_x * p_x * p_x + p_x * p_y * p_y;
            rhs[1] += p_x * p_x * p_y + p_y * p_y * p_y;
            rhs[2] += p_x * p_x + p_y * p_y;
        }

        // Bug fix: the return value was previously ignored, so a failed
        // decomposition produced a circle from partially-solved garbage.
        if (Cholesky3x3(lhs, rhs))
        {
            float A=rhs[0], B=rhs[1], C=rhs[2];

            // Centre (A/2, B/2), radius sqrt(C + A^2/4 + B^2/4).
            *circle_x = A / 2.0f;
            *circle_y = B / 2.0f;
            *circle_r = sqrt(4.0f * C + A * A + B * B) / 2.0f;
        }
    }
}
157 |
158 | // =========================================================================
159 | // Kernels...
160 |
// One block per batch element. Each thread runs an independent RANSAC
// attempt: seed a random triplet, fit a circle, iteratively refit to the
// inliers, then the block reduces out the best-scoring circle and writes it
// to g_circle as (x, y, r, score).
// NOTE(review): the template parameter list (presumably <int warp_count>)
// appears to have been stripped from this copy of the file.
template
__global__ void fit_circle_kernel(const float* g_edge_x_batch, const float* g_edge_y_batch, const float* g_edge_scores_batch, float* g_circle_batch, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width)
{
    // Dynamic shared memory: x, y and score for point_count points.
    extern __shared__ int s_edge_info[];
    __shared__ int s_valid_point_count;
    __shared__ float s_score_reduction_buffer[warp_count];
    __shared__ float s_x_reduction_buffer[warp_count];
    __shared__ float s_y_reduction_buffer[warp_count];
    __shared__ float s_r_reduction_buffer[warp_count];

    // Per-batch pointers (batch stride is 3 * point_count for each array).
    const float* g_edge_x = g_edge_x_batch + blockIdx.x * 3 * point_count;
    const float* g_edge_y = g_edge_y_batch + blockIdx.x * 3 * point_count;
    const float* g_edge_scores = g_edge_scores_batch + blockIdx.x * 3 * point_count;
    float* g_circle = g_circle_batch + blockIdx.x * 4;

    // Coordinates are stored as ints (truncating the float inputs); scores
    // keep their float representation in the same shared buffer.
    int* s_edge_x = (int*)(s_edge_info + 0 * point_count);
    int* s_edge_y = (int*)(s_edge_info + 1 * point_count);
    float* s_edge_scores = (float*)(s_edge_info + 2 * point_count);

    const int warp_index = threadIdx.x >> 5;
    const int lane_index = threadIdx.x & 31;

    // Loading points to shared memory...
    if (threadIdx.x < point_count)
    {
        s_edge_x[threadIdx.x] = g_edge_x[threadIdx.x];
        s_edge_y[threadIdx.x] = g_edge_y[threadIdx.x];
        s_edge_scores[threadIdx.x] = g_edge_scores[threadIdx.x];
    }

    // Point compaction: keep only points above the edge confidence
    // threshold, packed to the front of the shared buffers. First an
    // inclusive prefix sum within each warp...
    bool has_point = threadIdx.x < point_count ? s_edge_scores[threadIdx.x] > confidence_thresholds.edge : false;
    int preceeding_count = has_point;

    #pragma unroll
    for (int d=1; d < 32; d<<=1)
    {
        float other_count = __shfl_up_sync(0xffffffff, preceeding_count, d);

        if (lane_index >= d)
        {
            preceeding_count += other_count;
        }
    }

    // ...each warp's total (lane 31's inclusive count) goes to shared memory.
    if (lane_index == 31)
    {
        s_score_reduction_buffer[warp_index] = preceeding_count;
    }

    __syncthreads();

    // Warp 0 scans the per-warp totals into inclusive prefix sums.
    if (warp_index == 0)
    {
        int warp_sum = lane_index < warp_count ? s_score_reduction_buffer[lane_index] : 0;

        #pragma unroll
        for (int d=1; d < 32; d<<=1)
        {
            float other_warp_sum = __shfl_up_sync(0xffffffff, warp_sum, d);

            // Bug fix: this scan previously took a running max of the
            // per-warp totals instead of summing them, which corrupted the
            // compaction offsets whenever more than one warp was active.
            if (lane_index >= d)
            {
                warp_sum += other_warp_sum;
            }
        }

        if (lane_index < warp_count)
        {
            s_score_reduction_buffer[lane_index] = warp_sum;
        }
    }

    __syncthreads();

    // Offset each warp's local scan by the preceding warps' totals.
    if (warp_index > 0)
    {
        preceeding_count += s_score_reduction_buffer[warp_index-1];
    }

    // NOTE(review): this in-place compaction reads s_edge_*[threadIdx.x]
    // while other warps may already be overwriting lower slots; it looks
    // safe only when a single warp is active — confirm for point_count > 32.
    if (has_point)
    {
        s_edge_x[preceeding_count - 1] = s_edge_x[threadIdx.x];
        s_edge_y[preceeding_count - 1] = s_edge_y[threadIdx.x];
        s_edge_scores[preceeding_count - 1] = s_edge_scores[threadIdx.x];
    }

    // The last thread holds the total number of surviving points.
    if (threadIdx.x == blockDim.x - 1)
    {
        s_valid_point_count = preceeding_count;
    }

    __syncthreads();

    // Fewer than three points cannot define a circle: report a zero result.
    if (s_valid_point_count < 3)
    {
        if (threadIdx.x < 4)
        {
            g_circle[threadIdx.x] = 0.0;
        }

        return;
    }

    // RANSAC: start from a random triplet, then alternately fit and
    // re-select inliers.
    int inlier_count = 3;
    int inliers[MAX_POINT_COUNT];
    rand_triplet(threadIdx.x * 42342, RANSAC_ATTEMPTS, s_valid_point_count, inliers);

    float circle_x, circle_y, circle_r;
    float circle_score = 0.0f;

    for (int i = 0; i < RANSAC_ITERATIONS; i++)
    {
        get_circle(inlier_count, inliers, s_edge_x, s_edge_y, &circle_x, &circle_y, &circle_r);

        inlier_count = 0;
        circle_score = 0.0f;

        for (int point_index = 0; point_index < s_valid_point_count; point_index++)
        {
            int edge_x = s_edge_x[point_index];
            int edge_y = s_edge_y[point_index];
            float edge_score = s_edge_scores[point_index];

            float delta_x = circle_x - edge_x;
            float delta_y = circle_y - edge_y;

            // Radial distance from the circle edge.
            float delta = sqrt(delta_x * delta_x + delta_y * delta_y);
            float error = abs(circle_r - delta);

            if (error < RANSAC_INLIER_THRESHOLD)
            {
                circle_score += edge_score;

                inliers[inlier_count] = point_index;
                inlier_count++;
            }
        }

        // Normalise by the original candidate count.
        circle_score /= point_count;
    }

    bool circle_valid = check_circle(circle_x, circle_y, circle_r, image_width, image_height);

    if (!circle_valid)
    {
        circle_score = 0;
    }

    //#################################
    // Reduction: arg-max over circle_score, first within each warp...

    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2)
    {
        float other_circle_score = __shfl_down_sync(0xffffffff, circle_score, offset);
        float other_circle_x = __shfl_down_sync(0xffffffff, circle_x, offset);
        float other_circle_y = __shfl_down_sync(0xffffffff, circle_y, offset);
        float other_circle_r = __shfl_down_sync(0xffffffff, circle_r, offset);

        if (other_circle_score > circle_score)
        {
            circle_score = other_circle_score;
            circle_x = other_circle_x;
            circle_y = other_circle_y;
            circle_r = other_circle_r;
        }
    }

    // ...then across warps via shared memory.
    if (lane_index == 0)
    {
        s_score_reduction_buffer[warp_index] = circle_score;
        s_x_reduction_buffer[warp_index] = circle_x;
        s_y_reduction_buffer[warp_index] = circle_y;
        s_r_reduction_buffer[warp_index] = circle_r;
    }

    __syncthreads();

    if (warp_index == 0 && lane_index < warp_count)
    {
        // Bug fix: these loads previously indexed with warp_index (always 0
        // inside this branch), so every lane read warp 0's candidate and
        // the other warps' results were discarded.
        circle_score = s_score_reduction_buffer[lane_index];
        circle_x = s_x_reduction_buffer[lane_index];
        circle_y = s_y_reduction_buffer[lane_index];
        circle_r = s_r_reduction_buffer[lane_index];

        // NOTE(review): this reduction assumes warp_count is a power of two;
        // confirm for non-power-of-two block sizes.
        #pragma unroll
        for (int offset = warp_count / 2; offset > 0; offset /= 2)
        {
            float other_circle_score = __shfl_down_sync(0xffffffff, circle_score, offset);
            float other_circle_x = __shfl_down_sync(0xffffffff, circle_x, offset);
            float other_circle_y = __shfl_down_sync(0xffffffff, circle_y, offset);
            float other_circle_r = __shfl_down_sync(0xffffffff, circle_r, offset);

            if (other_circle_score > circle_score)
            {
                circle_score = other_circle_score;
                circle_x = other_circle_x;
                circle_y = other_circle_y;
                circle_r = other_circle_r;
            }
        }

        if (lane_index == 0)
        {
            g_circle[0] = circle_x;
            g_circle[1] = circle_y;
            g_circle[2] = circle_r;
            g_circle[3] = circle_score;
        }
    }
}
373 |
374 | // =========================================================================
375 | // Main function...
376 |
// Thread/warp counts for the RANSAC kernel: one thread per attempt.
#define ransac_threads RANSAC_ATTEMPTS
#define ransac_warps (1 + (RANSAC_ATTEMPTS - 1) / 32)

// Host entry point: launches one block per batch element to fit a circle to
// each element's candidate points, writing (x, y, r, score) per batch into
// `results`.
// NOTE(review): the launch configuration between the angle brackets appears
// to have been stripped from this copy of the file (the kernel needs a
// template argument, grid/block dims and a dynamic shared-memory size) —
// restore from the original source before building.
void fit_circle(const float* points_x, const float* points_y, const float* points_score, const int batch_count, const int point_count, const ConfidenceThresholds confidence_thresholds, const int image_height, const int image_width, float* results)
{
    dim3 grid(batch_count);
    dim3 block(point_count);
    fit_circle_kernel<<>>(points_x, points_y, points_score, results, point_count, confidence_thresholds, image_height, image_width);
}
386 |
387 | }
388 |
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/source/make_strips_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include "../common.hpp"
3 |
4 | namespace cpu
5 | {
// Samples `strip_count` horizontal strips (each strip_width rows tall) from
// each image in the batch and writes, per strip, 5 float channels: the
// normalised r, g, b values plus centred x and y coordinates.
// T is the pixel type, c the channel count (reconstructed template header —
// the original parameter list was stripped in this copy).
template<typename T, int c>
void make_strips(const T* image, const int batch_count, const int image_height, const int image_width, const int strip_count, const int strip_width, float* strips)
{
    // Integer images are assumed to span 0-255, float images 0-1.
    const float norm = std::is_floating_point<T>::value ? 1.0f : 255.0f;

    for (int batch_index = 0; batch_index < batch_count; ++batch_index)
    {
        // Bug fix: the input batch stride previously hard-coded 3 channels,
        // mis-addressing batches of single-channel images (the CUDA
        // implementation strides by c).
        const T* batch_image = image + batch_index * c * image_width * image_height;
        float* batch_strips = strips + batch_index * strip_count * 5 * strip_width * image_width;

        for (int strip_index = 0; strip_index < strip_count; ++strip_index)
        {
            // Sigmoid spacing clusters strip rows toward the vertical centre.
            int strip_height = 1 + (image_height - 2) / (1.0f + std::exp(-(strip_index - strip_count / 2.0f + 0.5f)/(strip_count / 8.0f)));
            int strip_offset = strip_index * 5 * image_width * strip_width;

            for (int image_x = 0; image_x < image_width; ++image_x)
            {
                for (int strip_y = 0; strip_y < strip_width; ++strip_y)
                {
                    int image_y = strip_height + strip_y - (strip_width - 1) / 2;

                    int image_pixel_index = image_x + image_y * image_width;
                    int strip_pixel_index = strip_offset + image_x + strip_y * image_width;

                    float r, g, b;

                    if (c == 3)
                    {
                        // Per-channel normalisation with fixed dataset statistics.
                        r = (batch_image[image_pixel_index + 0 * image_width * image_height]/norm - 0.3441f) / 0.2381f;
                        g = (batch_image[image_pixel_index + 1 * image_width * image_height]/norm - 0.2251f) / 0.1994f;
                        b = (batch_image[image_pixel_index + 2 * image_width * image_height]/norm - 0.2203f) / 0.1939f;
                    }
                    else
                    {
                        // Grayscale: the same value normalised with each
                        // channel's statistics.
                        r = (batch_image[image_pixel_index]/norm - 0.3441f) / 0.2381f;
                        g = (batch_image[image_pixel_index]/norm - 0.2251f) / 0.1994f;
                        b = (batch_image[image_pixel_index]/norm - 0.2203f) / 0.1939f;
                    }

                    // Pixel coordinates centred on the image midpoint, in [-0.5, 0.5).
                    float x = ((float)image_x / image_width) - 0.5f;
                    float y = ((float)image_y / image_height) - 0.5f;

                    batch_strips[strip_pixel_index + 0 * image_width * strip_width] = r;
                    batch_strips[strip_pixel_index + 1 * image_width * strip_width] = g;
                    batch_strips[strip_pixel_index + 2 * image_width * strip_width] = b;
                    batch_strips[strip_pixel_index + 3 * image_width * strip_width] = x;
                    batch_strips[strip_pixel_index + 4 * image_width * strip_width] = y;
                }
            }
        }
    }
}
54 |
55 |
// Public entry point: dispatches on the image's element type and channel
// count (via the FUNCTION_CALL_IMAGE_FORMAT macro) to the templated
// implementation above.
void make_strips(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const int strip_width, float* strips)
{
    FUNCTION_CALL_IMAGE_FORMAT(make_strips, image, batch_count, image_height, image_width, strip_count, strip_width, strips);
}
60 | }
61 |
--------------------------------------------------------------------------------
/src/torchcontentarea/csrc/source/make_strips_cuda.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include "../common.hpp"
3 |
4 | namespace cuda
5 | {
// Row index (in [1, image_height-1]) for the given strip: a sigmoid over the
// strip index clusters strips toward the vertical centre of the image.
__device__ int get_strip_height(int strip_index, int strip_count, int image_height)
{
    float centered_index = strip_index - strip_count / 2.0f + 0.5f;
    float falloff = 1.0f + exp(-centered_index / (strip_count / 8.0f));
    return 1 + (image_height - 2) / falloff;
}
10 |
// Writes one pixel of one strip per thread: 5 float channels per strip
// (normalised r, g, b plus centred x, y coordinates). Grid: x covers image
// columns, y selects the strip, z the batch element.
// NOTE(review): the template parameter list (presumably <typename T, int c>)
// appears to have been stripped from this copy of the file.
template
__global__ void make_strips_kernel(const T* g_image_batch, const int image_width, const int image_height, const int strip_count, const int strip_width, float* g_strips_batch)
{
    // Per-batch pointers; the input strides by c channels per element.
    const T* g_image = g_image_batch + blockIdx.z * c * image_width * image_height;
    float* g_strips = g_strips_batch + blockIdx.z * strip_count * 5 * image_width * strip_width;

    int strip_index = blockIdx.y;
    int strip_offset = strip_index * 5 * image_width * strip_width;
    int strip_height = get_strip_height(strip_index, strip_count, image_height);

    int image_x = threadIdx.x + blockIdx.x * blockDim.x;
    int strip_y = threadIdx.y;

    // Guard against the partial block at the right image edge.
    if (image_x >= image_width)
        return;

    // NOTE(review): image_y is not clamped — confirm strip_width is small
    // enough that rows never fall outside [0, image_height).
    int image_y = strip_height + strip_y - (strip_width - 1) / 2;

    int image_pixel_index = image_x + image_y * image_width;
    int strip_pixel_index = strip_offset + image_x + strip_y * image_width;

    float r, g, b;

    // Integer images are assumed to span 0-255, float images 0-1.
    float norm = std::is_floating_point::value ? 1.0f : 255.0f;

    if (c == 3)
    {
        // Per-channel normalisation with fixed dataset statistics.
        r = (g_image[image_pixel_index + 0 * image_width * image_height]/norm - 0.3441f) / 0.2381f;
        g = (g_image[image_pixel_index + 1 * image_width * image_height]/norm - 0.2251f) / 0.1994f;
        b = (g_image[image_pixel_index + 2 * image_width * image_height]/norm - 0.2203f) / 0.1939f;
    }
    else
    {
        // Grayscale: the same value normalised with each channel's statistics.
        r = (g_image[image_pixel_index]/norm - 0.3441f) / 0.2381f;
        g = (g_image[image_pixel_index]/norm - 0.2251f) / 0.1994f;
        b = (g_image[image_pixel_index]/norm - 0.2203f) / 0.1939f;
    }

    // Pixel coordinates centred on the image midpoint, in [-0.5, 0.5).
    float x = ((float)image_x / image_width) - 0.5f;
    float y = ((float)image_y / image_height) - 0.5f;

    g_strips[strip_pixel_index + 0 * image_width * strip_width] = r;
    g_strips[strip_pixel_index + 1 * image_width * strip_width] = g;
    g_strips[strip_pixel_index + 2 * image_width * strip_width] = b;
    g_strips[strip_pixel_index + 3 * image_width * strip_width] = x;
    g_strips[strip_pixel_index + 4 * image_width * strip_width] = y;
}
58 |
// Host entry point: one thread per strip pixel (128 columns per block,
// strip_width rows), one grid row per strip, one grid slice per batch
// element. Dispatches on element type / channel count via the macro.
void make_strips(Image image, const int batch_count, const int channel_count, const int image_height, const int image_width, const int strip_count, const int strip_width, float* strips)
{
    dim3 grid(((image_width - 1) / 128) + 1, strip_count, batch_count);
    dim3 block(128, strip_width);
    KERNEL_DISPATCH_IMAGE_FORMAT(make_strips_kernel, ARG(grid, block), image, image_width, image_height, strip_count, strip_width, strips);
}
65 | }
66 |
--------------------------------------------------------------------------------
/src/torchcontentarea/extension_wrapper.py:
--------------------------------------------------------------------------------
import torch
from typing import Sequence, Optional

# Prefer the compiled extension; fall back to the pure-python implementation
# when it is not installed. Only ImportError triggers the fallback (a bare
# except previously hid genuine errors raised inside the extension and even
# caught KeyboardInterrupt/SystemExit).
try:
    import torchcontentareaext as ext
    FALL_BACK = False
except ImportError:
    print("Falling back to python implementation...")
    from . import pythonimplementation as ext
    FALL_BACK = True
11 |
# Cache of default models keyed by torch device, populated lazily.
models = {}

def load_default_model(device):
    """Loads the bundled default model (models/kernel_3_8.pt) onto the given device."""
    import os
    # os.path handles platform separators; the previous string splitting on
    # "/" broke on Windows paths.
    directory = os.path.dirname(os.path.abspath(__file__))
    return torch.jit.load(os.path.join(directory, "models", "kernel_3_8.pt"), map_location=device)

def get_default_model(device):
    """Returns the default model for the device, loading and caching it on first use."""
    global models
    model = models.get(device)
    if model is None:
        model = load_default_model(device)
        models[device] = model
    return model
25 |
26 |
def estimate_area_handcrafted(image: torch.Tensor, strip_count: int=16, feature_thresholds: Sequence[float]=(20, 30, 25), confidence_thresholds: Sequence[float]=(0.03, 0.06)) -> torch.Tensor:
    """
    Estimates the content area of the given endoscopic image(s) via
    handcrafted feature extraction, delegating to the active backend.
    """
    area = ext.estimate_area_handcrafted(image, strip_count, feature_thresholds, confidence_thresholds)
    return area
32 |
33 |
def estimate_area_learned(image: torch.Tensor, strip_count: int=16, model: Optional[torch.jit.ScriptModule]=None, model_patch_size: int=7, confidence_thresholds: Sequence[float]=(0.03, 0.06)) -> torch.Tensor:
    """
    Estimates the content area for the given endoscopic image(s) using learned feature extraction.

    When no model is given, the bundled default model is loaded for the
    image's device and cached for reuse.
    """
    # `is None` rather than `== None`: equality may be overloaded on modules.
    if model is None:
        model = get_default_model(image.device)
    return ext.estimate_area_learned(image, strip_count, model if FALL_BACK else model._c, model_patch_size, confidence_thresholds)
41 |
42 |
def get_points_handcrafted(image: torch.Tensor, strip_count: int=16, feature_thresholds: Sequence[float]=(20, 30, 25)) -> torch.Tensor:
    """
    Finds candidate content-area edge points and their scores in the given
    image(s) via handcrafted feature extraction, delegating to the backend.
    """
    points = ext.get_points_handcrafted(image, strip_count, feature_thresholds)
    return points
48 |
49 |
def get_points_learned(image: torch.Tensor, strip_count: int=16, model: Optional[torch.jit.ScriptModule]=None, model_patch_size: int=7) -> torch.Tensor:
    """
    Finds candidate edge points and corresponding scores in the given image(s) using learned feature extraction.

    When no model is given, the bundled default model is loaded for the
    image's device and cached for reuse.
    """
    # `is None` rather than `== None`: equality may be overloaded on modules.
    if model is None:
        model = get_default_model(image.device)
    return ext.get_points_learned(image, strip_count, model if FALL_BACK else model._c, model_patch_size)
57 |
58 |
def fit_area(points: torch.Tensor, image_size: Sequence[int], confidence_thresholds: Sequence[float]=(0.03, 0.06)) -> torch.Tensor:
    """
    Fits a circular content area to the given candidate edge points and scores.

    The previous docstring was a copy-paste of the get_points_* description.
    """
    return ext.fit_area(points, image_size, confidence_thresholds)
64 |
--------------------------------------------------------------------------------
/src/torchcontentarea/models/kernel_1_8.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/charliebudd/torch-content-area/c613957f266a64232f8283975653635160b3f0a2/src/torchcontentarea/models/kernel_1_8.pt
--------------------------------------------------------------------------------
/src/torchcontentarea/models/kernel_2_8.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/charliebudd/torch-content-area/c613957f266a64232f8283975653635160b3f0a2/src/torchcontentarea/models/kernel_2_8.pt
--------------------------------------------------------------------------------
/src/torchcontentarea/models/kernel_3_8.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/charliebudd/torch-content-area/c613957f266a64232f8283975653635160b3f0a2/src/torchcontentarea/models/kernel_3_8.pt
--------------------------------------------------------------------------------
/src/torchcontentarea/pythonimplementation/__init__.py:
--------------------------------------------------------------------------------
1 | from .estimate_area import estimate_area_handcrafted, estimate_area_learned
2 | from .get_points import get_points_handcrafted, get_points_learned
3 | from .fit_area import fit_area
4 |
--------------------------------------------------------------------------------
/src/torchcontentarea/pythonimplementation/estimate_area.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Sequence, Optional
3 |
4 | from .get_points import get_points_handcrafted, get_points_learned
5 | from .fit_area import fit_area
6 |
def estimate_area_handcrafted(image: torch.Tensor, strip_count: int=16, feature_thresholds: Sequence[float]=(20, 30, 25), confidence_thresholds: Sequence[float]=(0.03, 0.06)) -> torch.Tensor:
    """
    Estimates the content area of the given endoscopic image(s) by detecting
    edge points with handcrafted features and fitting a circle to them.
    """
    edge_points = get_points_handcrafted(image, strip_count, feature_thresholds)
    return fit_area(edge_points, image.shape[-2:], confidence_thresholds)
14 |
15 |
def estimate_area_learned(image: torch.Tensor, strip_count: int=16, model: Optional[torch.jit.ScriptModule]=None, model_patch_size: int=7, confidence_thresholds: Sequence[float]=(0.03, 0.06)) -> torch.Tensor:
    """
    Estimates the content area of the given endoscopic image(s) by detecting
    edge points with the learned model and fitting a circle to them.
    """
    edge_points = get_points_learned(image, strip_count, model, model_patch_size)
    return fit_area(edge_points, image.shape[-2:], confidence_thresholds)
23 |
--------------------------------------------------------------------------------
/src/torchcontentarea/pythonimplementation/fit_area.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from math import sqrt
3 | from typing import Sequence
4 |
# Tuning constants for the RANSAC circle fit. The first three are expressed
# as fractions of the image width; the inlier threshold is in pixels.
MAX_CENTER_DIST=0.2          # max distance of the circle centre from the image centre
MIN_RADIUS=0.2               # smallest plausible content-area radius
MAX_RADIUS=0.8               # largest plausible content-area radius
RANSAC_ATTEMPTS=32           # independent random initialisations
RANSAC_ITERATIONS=3          # fit/re-select refinement rounds per attempt
RANSAC_INLIER_THRESHOLD=3    # max distance (pixels) from the circle edge to count as inlier
11 |
def check_circle(x, y, r, w, h):
    """
    Returns whether circle (x, y, r) is a plausible content area for a
    w-by-h image: centre near the image centre and radius within the
    expected range (all bounds relative to the image width).
    """
    center_offset = sqrt((x - 0.5 * w) ** 2 + (y - 0.5 * h) ** 2)

    return (
        (center_offset < MAX_CENTER_DIST * w)
        and (r > MIN_RADIUS * w)
        and (r < MAX_RADIUS * w)
    )
24 |
def calculate_circle(points):
    """
    Returns (x, y, r) for the circle through the first three columns of
    `points` (rows are x, y, score), or None when the points are close to
    collinear.
    """
    ax, ay = points[0, 0], points[1, 0]
    bx, by = points[0, 1], points[1, 1]
    cx, cy = points[0, 2], points[1, 2]

    det = (ax - bx) * (by - cy) - (bx - cx) * (ay - by)

    # A near-zero determinant means no unique circumscribed circle exists.
    if not abs(det) > 1e-8:
        return None

    b_norm = bx * bx + by * by
    bc = 0.5 * (ax * ax + ay * ay - b_norm)
    cd = 0.5 * (b_norm - cx * cx - cy * cy)

    inv_det = 1.0 / det
    x = (bc * (by - cy) - cd * (ay - by)) * inv_det
    y = (cd * (ax - bx) - bc * (bx - cx)) * inv_det
    r = sqrt((bx - x) * (bx - x) + (by - y) * (by - y))
    return x, y, r
46 |
def get_circle(points):
    """
    Fits a circle to the given points (a 3xN tensor with rows x, y, score).

    With exactly three points the circumscribed circle is computed directly;
    otherwise an algebraic least-squares fit of x^2 + y^2 = A*x + B*y + C is
    solved via a Cholesky factorisation.

    Returns:
        (x, y, r), or None when the system is degenerate.
    """
    if points.size(1) == 3:
        return calculate_circle(points)
    else:
        lhs = torch.zeros(3, 3).to(points.device)
        rhs = torch.zeros(3).to(points.device)

        # Accumulate the normal equations (upper triangle of lhs).
        for p_x, p_y, _ in points.T:
            lhs[0, 0] += p_x * p_x
            lhs[0, 1] += p_x * p_y
            lhs[1, 1] += p_y * p_y
            lhs[0, 2] += p_x
            lhs[1, 2] += p_y
            lhs[2, 2] += 1

            rhs[0] += p_x * p_x * p_x + p_x * p_y * p_y
            rhs[1] += p_x * p_x * p_y + p_y * p_y * p_y
            rhs[2] += p_x * p_x + p_y * p_y

        # Mirror into the lower triangle to make lhs symmetric.
        lhs[1, 0] = lhs[0, 1]
        lhs[2, 0] = lhs[0, 2]
        lhs[2, 1] = lhs[1, 2]

        try:
            L = torch.linalg.cholesky(lhs)
            y = torch.linalg.solve(L, rhs)
            x = torch.linalg.solve(L.T, y)
        except Exception:
            # Not positive definite (degenerate point set). Was a bare
            # `except:`, which also swallowed KeyboardInterrupt/SystemExit.
            return None

        A, B, C = x[0], x[1], x[2]

        # Centre (A/2, B/2), radius sqrt(C + A^2/4 + B^2/4).
        x = A / 2.0
        y = B / 2.0
        r = torch.sqrt(4.0 * C + A * A + B * B) / 2.0

        return x, y, r
85 |
def fit_area(points: torch.Tensor, image_size: Sequence[int], confidence_thresholds: Sequence[float]=(0.03, 0.06)) -> torch.Tensor:
    """
    Fits a circular content area to the given candidate edge points using RANSAC.

    Args:
        points: tensor of shape (B, 3, N) or (3, N) with rows x, y, score.
        image_size: (height, width) of the source image.
        confidence_thresholds: only the first (edge) threshold is used here,
            to discard low-confidence points.

    Returns:
        Tensor of shape (B, 4) — or (4,) when unbatched — holding x, y, r and
        the fit score; all zeros when no acceptable circle was found.
    """

    batched = len(points.shape) == 3

    if not batched:
        points = points.unsqueeze(0)

    areas = []

    for point_batch in points:

        # Keep only points whose score exceeds the edge threshold.
        point_batch = point_batch[:, point_batch[2] > confidence_thresholds[0]]

        # Fewer than three points cannot define a circle.
        if point_batch.size(1) < 3:
            areas.append(torch.zeros(4))
            continue

        # (x, y, r, score); score 0 means "no fit yet".
        # NOTE(review): created on the CPU — confirm behaviour for CUDA inputs.
        best_circle = torch.zeros(4)

        for _ in range(RANSAC_ATTEMPTS):

            # Seed each attempt with three distinct random points.
            indices = torch.randperm(point_batch.size(1))[:3]
            inliers = point_batch[:, indices]

            # Alternate fitting and inlier re-selection.
            for _ in range(RANSAC_ITERATIONS):
                circle = get_circle(inliers)

                if circle is None:
                    x, y, r = 0, 0, 0
                    circle_score = 0
                    break

                x, y, r = circle

                # Radial distance of every candidate point from the circle edge.
                dx = x - point_batch[0]
                dy = y - point_batch[1]

                error = torch.abs(torch.sqrt(dx**2 + dy**2) - r)

                inliers = point_batch[:, error < RANSAC_INLIER_THRESHOLD]
                circle_score = inliers[2].sum()

            # Normalise the inlier-score sum by the candidate count.
            circle_score /= point_batch.size(1)

            # image_size is (height, width); check_circle takes (w, h).
            circle_valid = check_circle(x, y, r, image_size[1], image_size[0])

            if circle_valid and circle_score > best_circle[3]:
                best_circle = torch.tensor([x, y, r, circle_score])

        areas.append(best_circle)

    areas = torch.stack(areas)

    if not batched:
        areas = areas[0]

    return areas
146 |
--------------------------------------------------------------------------------
/src/torchcontentarea/pythonimplementation/get_points.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Sequence, Optional
3 |
# Degrees <-> radians conversion factors.
DEG2RAD=0.01745329251
RAD2DEG=1.0/DEG2RAD

# RGB -> grayscale luma weights.
GRAY_SCALE_WEIGHTS = torch.tensor([
    0.2126, 0.7152, 0.0722
])

# Scaled Sobel kernels producing horizontal and vertical gradients as two
# output channels of a single-channel conv2d.
SOBEL_KERNEL = torch.tensor([
    [[
        [-0.25, 0.0, 0.25],
        [-0.5, 0.0, 0.5],
        [-0.25, 0.0, 0.25],
    ]],
    [[
        [-0.25, -0.5, -0.25],
        [0.0, 0.0, 0.0],
        [0.25, 0.5, 0.25],
    ]],
])
23 |
def get_points_handcrafted(image: torch.Tensor, strip_count: int=16, feature_thresholds: Sequence[float]=(20, 30, 25)) -> torch.Tensor:
    """
    Finds candidate edge points and corresponding scores in the given image(s) using handcrafted feature extraction.

    Args:
        image: image tensor of shape (B, C, H, W) or (C, H, W); uint8 images
            are treated as 0-255, float images as 0-1.
        strip_count: number of horizontal strips sampled from the image.
        feature_thresholds: (edge, angle, intensity) score scaling factors.

    Returns:
        Tensor of shape (B, 3, 2 * strip_count) — or (3, 2 * strip_count)
        when unbatched — holding x, y and score per candidate point.
    """

    batched = len(image.shape) == 4
    device = image.device
    if not batched:
        # Bug fix: the unsqueeze result was previously discarded, leaving the
        # unbatched path to rely on accidental broadcasting (and crashing for
        # single-channel inputs).
        image = image.unsqueeze(0)
    if image.dtype.is_floating_point:
        image = image * 255.0
    if image.size(1) != 1:
        image = (image * GRAY_SCALE_WEIGHTS[None, :, None, None].to(device)).sum(dim=1)
    else:
        # Already single channel: just drop the channel dimension.
        image = image[:, 0]
    image = image.float()
    B, H, W = image.shape

    # Strip rows follow a sigmoid spacing, clustering toward the vertical centre.
    strip_indices = torch.arange(strip_count, device=device)
    strip_heights = (1 + (H - 2) / (1.0 + torch.exp(-(strip_indices - strip_count / 2.0 + 0.5) / (strip_count / 8.0)))).long()

    # Three consecutive rows per strip; split into left/right halves with the
    # right half mirrored so edges always face the same direction.
    indices = torch.cat([strip_heights-1, strip_heights, strip_heights+1])
    strips = image[:, indices, :].reshape(B, 3, strip_count, W).permute(0, 2, 1, 3)
    strips = torch.cat([strips[..., :W//2], strips.flip(-1)[..., :W//2]], dim=1)

    ys = torch.cat([strip_heights, strip_heights])[None, :, None].repeat(B, 1, W//2-2)
    xs = (torch.arange(W//2-2) + 1)[None, None, :].to(device).repeat(B, strip_count, 1)
    xs = torch.cat([xs, W - 1 - xs], dim=1)

    grad = torch.conv2d(strips.reshape(B*2*strip_count, 1, 3, -1), SOBEL_KERNEL.to(device)).reshape(B, 2*strip_count, 2, -1)
    strips = strips[:, :, 1, 1:-1]

    grad_x = grad[:, :, 0]
    grad_y = grad[:, :, 1]
    grad_x[:, strip_count:] *= -1  # undo the horizontal mirror for the right half

    grad = torch.sqrt(grad_x**2 + grad_y**2)

    # Running max of intensity from the image edge inward.
    max_preceeding_intensity = torch.cummax(strips, dim=-1)[0]

    center_dir_x = (0.5 * W) - xs
    center_dir_y = (0.5 * H) - ys
    center_dir_norm = torch.sqrt(center_dir_x * center_dir_x + center_dir_y * center_dir_y)

    # Angle between the gradient and the direction toward the image centre.
    dot = torch.where(grad == 0, -1, (center_dir_x * grad_x + center_dir_y * grad_y) / (center_dir_norm * grad))
    dot = torch.clamp(dot, -0.99, 0.99)
    angle = RAD2DEG * torch.acos(dot)

    edge_score = torch.tanh(grad / feature_thresholds[0])
    angle_score = 1.0 - torch.tanh(angle / feature_thresholds[1])
    intensity_score = 1.0 - torch.tanh(max_preceeding_intensity / feature_thresholds[2])

    point_scores = edge_score * angle_score * intensity_score

    # Best candidate per strip half.
    point_scores, indices = torch.max(point_scores, dim=-1)

    ys = torch.gather(ys, 2, indices[:, :, None])[:, :, 0]
    xs = torch.gather(xs, 2, indices[:, :, None])[:, :, 0]

    result = torch.stack([xs, ys, point_scores], dim=2).reshape(B, strip_count*2, 3).permute(0, 2, 1)

    if not batched:
        result = result[0]

    return result
88 |
def get_points_learned(image: torch.Tensor, strip_count: int=16, model: Optional[torch.jit.ScriptModule]=None, model_patch_size: int=7) -> torch.Tensor:
    """
    Finds candidate edge points and corresponding scores in the given image(s) using learned feature extraction.

    Raises:
        NotImplementedError: always; the learned method requires the compiled
            extension and has no python fallback.
    """
    raise NotImplementedError("The learned method is not implemented in the python fallback. The handcrafted method should provide better performance, but to use the learned method you will need to install the compiled extension.")
94 |
--------------------------------------------------------------------------------
/src/torchcontentarea/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from math import sqrt, floor
3 |
def draw_area(area, image, bias=0):
    """
    Draws a binary mask with 1 inside the content area and 0 outside.

    Args:
        area: (x, y, r, ...) circle parameters (float values).
        image: tensor whose trailing two dims give the mask size and whose
            device the mask is created on.
        bias: amount added to the radius before thresholding.

    Returns:
        uint8 tensor of shape (H, W).
    """
    image_size, device = image.shape[-2:], image.device

    # indexing="ij" keeps (row, column) ordering — the pre-1.10 default — and
    # silences the deprecation warning for calling meshgrid without it.
    mesh = torch.meshgrid(torch.arange(0, image_size[0], device=device), torch.arange(0, image_size[1], device=device), indexing="ij")
    # mesh[0] is the row (y) index, compared against area[1]; mesh[1] the column (x).
    dist = torch.sqrt((mesh[0] - area[1])**2 + (mesh[1] - area[0])**2)
    mask = torch.where(dist < (area[2] + bias), 1, 0).to(dtype=torch.uint8)

    return mask
15 |
def get_crop(area, image_size, aspect_ratio=None, bias=0):
    """
    Returns the optimal crop (top, bottom, left, right) to remove areas outside the content area.

    Args:
        area: (x, y, r, ...) circle parameters.
        image_size: (height, width) of the image.
        aspect_ratio: target w/h ratio of the crop; defaults to the image's.
        bias: amount added to the radius before computing the inscribed box.
    """
    i_h, i_w = image_size
    a_x, a_y, a_r = area[:3]

    # `is None` rather than `== None` (aspect_ratio could be a tensor scalar).
    if aspect_ratio is None:
        aspect_ratio = i_w / i_h

    # Largest axis-aligned box of the target ratio inscribed in the circle
    # (with a 2px safety margin on the radius).
    inscribed_height = 2 * (a_r + bias - 2) / sqrt(1 + aspect_ratio * aspect_ratio)
    inscribed_width = inscribed_height * aspect_ratio

    # Clamp the box to the image bounds.
    left = max(a_x - inscribed_width / 2, 0)
    right = min(a_x + inscribed_width / 2, i_w)
    top = max(a_y - inscribed_height / 2, 0)
    bottom = min(a_y + inscribed_height / 2, i_h)

    # Shrink to the limiting dimension while keeping the aspect ratio.
    x_scale = (right - left)
    y_scale = (bottom - top) * aspect_ratio

    scale = min(x_scale, y_scale)

    w = int(floor(scale))
    h = int(floor(scale / aspect_ratio))

    # Centre the final integer-sized crop within the clamped box.
    x = int(left + (right - left) / 2 - w / 2)
    y = int(top + (bottom - top) / 2 - h / 2)

    crop = y, y+h, x, x+w

    return crop
48 |
def crop_area(area, image, aspect_ratio=None, bias=0):
    """
    Crops the image to remove areas outside the content area, delegating the
    crop computation to get_crop.
    """
    top, bottom, left, right = get_crop(area, image.shape[-2:], aspect_ratio, bias)
    return image[..., top:bottom, left:right]
58 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # from .test_performance import TestPerformance
2 | from .test_api import TestAPI
3 |
--------------------------------------------------------------------------------
/tests/test_api.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 |
4 | from .utils.data import DummyDataset
5 | from .utils.scoring import content_area_hausdorff, MISS_THRESHOLD
6 |
7 | from torchcontentarea import estimate_area_handcrafted, estimate_area_learned, get_points_handcrafted, get_points_learned, fit_area
8 |
# (label, callable, device) triples covering every public estimation path,
# both single-call and two-staged (points then fit).
# NOTE(review): "MEHTODS" is a typo, kept because every test below refers to
# this name.
ESTIMATION_MEHTODS = [
    ("handcrafted cpu", estimate_area_handcrafted, "cpu"),
    ("learned cpu", estimate_area_learned, "cpu"),
    ("handcrafted cuda", estimate_area_handcrafted, "cuda"),
    ("learned cuda", estimate_area_learned, "cuda"),
    ("handcrafted cpu two staged", lambda x: fit_area(get_points_handcrafted(x), x.shape[-2:]), "cpu"),
    ("learned cpu two staged", lambda x: fit_area(get_points_learned(x), x.shape[-2:]), "cpu"),
    ("handcrafted cuda two staged", lambda x: fit_area(get_points_handcrafted(x), x.shape[-2:]), "cuda"),
    ("learned cuda two staged", lambda x: fit_area(get_points_learned(x), x.shape[-2:]), "cuda"),
]
19 |
class TestAPI(unittest.TestCase):
    """Checks every estimation method against the dummy dataset for a range
    of input variations: batching, channel count, image size and dtype."""

    def __init__(self, methodName: str = "runTest") -> None:
        # "runTest" matches unittest.TestCase's own default; the previous
        # default of Ellipsis is not a valid test method name and would
        # crash if no method name were supplied.
        super().__init__(methodName)
        self.dataset = DummyDataset()

    def _check_single(self, image, true_area):
        """Runs every estimation method on one unbatched image and asserts the
        estimated area is within MISS_THRESHOLD of the ground truth."""
        for name, method, device in ESTIMATION_MEHTODS:
            with self.subTest(name):
                image = image.to(device)
                result = method(image).tolist()
                estimated_area = result[0:3]
                error, _ = content_area_hausdorff(true_area, estimated_area, image.shape[-2:])
                self.assertLess(error, MISS_THRESHOLD)

    def test_unbatched(self):
        image, _, true_area = self.dataset[0]
        self._check_single(image, true_area)

    def test_batched(self):
        image_a, _, true_area_a = self.dataset[0]
        image_b, _, true_area_b = self.dataset[1]
        images = torch.stack([image_a, image_b])
        true_areas = [true_area_a, true_area_b]
        for name, method, device in ESTIMATION_MEHTODS:
            with self.subTest(name):
                images = images.to(device)
                results = method(images).tolist()
                for true_area, result in zip(true_areas, results):
                    estimated_area = result[0:3]
                    error, _ = content_area_hausdorff(true_area, estimated_area, images.shape[-2:])
                    self.assertLess(error, MISS_THRESHOLD)

    def test_rgb(self):
        # NOTE(review): identical to test_unbatched (the dataset images are
        # already RGB); kept so the test names and count are unchanged.
        image, _, true_area = self.dataset[0]
        self._check_single(image, true_area)

    def test_grayscale(self):
        image, _, true_area = self.dataset[0]
        # ITU-R BT.709 luma weights collapse the three channels to one.
        image = (0.2126 * image[0:1] + 0.7152 * image[1:2] + 0.0722 * image[2:3]).to(torch.uint8)
        self._check_single(image, true_area)

    def test_large(self):
        image, _, true_area = self.dataset[0]
        # Upsample 4x; the ground-truth circle scales by the same factor.
        image = torch.nn.functional.interpolate(image.unsqueeze(0).to(dtype=torch.float), scale_factor=4, mode='bilinear')[0].to(dtype=torch.uint8)
        true_area = tuple(map(lambda x: x * 4, true_area))
        self._check_single(image, true_area)

    def test_small(self):
        image, _, true_area = self.dataset[0]
        # Downsample 4x by striding; the ground-truth circle shrinks to match.
        image = image[:, ::4, ::4]
        true_area = tuple(map(lambda x: x / 4, true_area))
        self._check_single(image, true_area)

    def test_byte(self):
        image, _, true_area = self.dataset[0]
        self._check_single(image.to(dtype=torch.uint8), true_area)

    def test_int(self):
        image, _, true_area = self.dataset[0]
        self._check_single(image.to(dtype=torch.int), true_area)

    def test_long(self):
        image, _, true_area = self.dataset[0]
        self._check_single(image.to(dtype=torch.int64), true_area)

    def test_float(self):
        image, _, true_area = self.dataset[0]
        # Float images are expected in the [0, 1] range.
        self._check_single(image.to(dtype=torch.float) / 255, true_area)

    def test_double(self):
        image, _, true_area = self.dataset[0]
        # Double images are expected in the [0, 1] range.
        self._check_single(image.to(dtype=torch.double) / 255, true_area)
149 |
150 |
# Allows running this test module directly, e.g. `python -m tests.test_api`.
if __name__ == '__main__':
    unittest.main()
153 |
--------------------------------------------------------------------------------
/tests/test_performance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import unittest
4 | import cpuinfo
5 | from time import sleep
6 |
7 | from .utils.data import TestDataset, TestDataLoader
8 | from .utils.scoring import content_area_hausdorff, MISS_THRESHOLD, BAD_MISS_THRESHOLD
9 | from .utils.profiling import Timer
10 |
11 | from torchcontentarea import estimate_area_handcrafted, estimate_area_learned
12 |
# Accumulates human-readable results across cases; printed in tearDownClass.
TEST_LOG = ""

# Benchmarked configurations: (name, estimation function, device).
# The cpu and learned variants are currently commented out — presumably to
# keep benchmark runtime down; re-enable as needed.
TEST_CASES = [
    # ("handcrafted cpu", estimate_area_handcrafted, "cpu"),
    # ("learned cpu", estimate_area_learned, "cpu"),
    ("handcrafted cuda", estimate_area_handcrafted, "cuda"),
    # ("learned cuda", estimate_area_learned, "cuda"),
]
21 |
class TestPerformance(unittest.TestCase):
    """Benchmarks the estimation methods over the full dataset, logging timing
    and accuracy statistics and asserting they stay within acceptable bounds."""

    def __init__(self, methodName: str = "runTest") -> None:
        # "runTest" matches unittest.TestCase's own default; the previous
        # default of Ellipsis is not a valid test method name.
        super().__init__(methodName)
        self.dataset = TestDataset()
        self.dataloader = TestDataLoader(self.dataset)

    def test_performance(self):

        # One time/error series per configured test case.
        times = [[] for _ in range(len(TEST_CASES))]
        errors = [[] for _ in range(len(TEST_CASES))]

        for img, area in self.dataloader:

            for i, (name, method, device) in enumerate(TEST_CASES):

                img = img.to(device=device)

                with Timer() as timer:
                    result = method(img)
                time = timer.time

                result = result.tolist()
                inferred_area, confidence = result[:3], result[3]

                inferred_area = tuple(map(int, inferred_area))

                # Low-confidence estimates count as "no area detected".
                if confidence < 0.06:
                    inferred_area = None

                error, _ = content_area_hausdorff(area, inferred_area, img.shape[-2:])

                errors[i].append(error)
                times[i].append(time)

        # Fresh names for the per-case series so the accumulator lists above
        # are not shadowed/rebound mid-loop.
        for (name, _, device), case_times, case_errors in zip(TEST_CASES, times, errors):
            device_name = torch.cuda.get_device_name() if device == "cuda" else cpuinfo.get_cpu_info()['brand_raw']

            # Discard the first ~1% of samples as warm-up before timing stats.
            run_in_count = len(case_times) // 100
            case_times = case_times[run_in_count:]
            avg_time = sum(case_times) / len(case_times)
            std_time = np.std(case_times)

            sample_count = len(self.dataset)
            average_error = sum(case_errors) / sample_count
            miss_percentage = 100 * sum(map(lambda x: x > MISS_THRESHOLD, case_errors)) / sample_count
            bad_miss_percentage = 100 * sum(map(lambda x: x > BAD_MISS_THRESHOLD, case_errors)) / sample_count

            global TEST_LOG
            TEST_LOG += "\n".join([
                f"\n",
                f"Performance Results ({name})...",
                f"- Avg Time ({device_name}): {avg_time:.3f} ± {std_time:.3f}ms",
                f"- Avg Error (Hausdorff Distance): {average_error:.3f}",
                f"- Miss Rate (Error > {MISS_THRESHOLD}): {miss_percentage:.1f}%",
                f"- Bad Miss Rate (Error > {BAD_MISS_THRESHOLD}): {bad_miss_percentage:.1f}%"
            ])

            # assertLess gives informative failure messages (was assertTrue).
            self.assertLess(avg_time, 10)
            self.assertLess(average_error, 10)
            self.assertLess(miss_percentage, 10.0)
            self.assertLess(bad_miss_percentage, 5.0)

    @classmethod
    def tearDownClass(cls):
        # Brief pause so the unittest runner's own output flushes before the
        # accumulated log is printed.
        if TEST_LOG != "":
            sleep(3)
            print(TEST_LOG)
90 |
# Allows running this test module directly, e.g. `python -m tests.test_performance`.
if __name__ == '__main__':
    unittest.main()
93 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 |
4 | from .utils.data import DummyDataset
5 | from .utils.scoring import content_area_hausdorff, MISS_THRESHOLD
6 |
7 | from torchcontentarea.utils import draw_area, crop_area
8 |
9 | TESTS_CASES = ["cuda", "cpu"]
10 |
class TestUtils(unittest.TestCase):
    """Checks the draw_area and crop_area utilities against the dummy dataset."""

    def __init__(self, methodName: str = "runTest") -> None:
        # "runTest" matches unittest.TestCase's own default; the previous
        # default of Ellipsis is not a valid test method name.
        super().__init__(methodName)
        self.dataset = DummyDataset()

    def test_draw_area(self):
        _, mask, area = self.dataset[0]
        for device in TESTS_CASES:
            with self.subTest(device):
                mask = mask.to(device)
                result = draw_area(area, mask)
                # The drawn area should be the exact inverse of the mask,
                # i.e. zero mismatching pixels.
                mismatch_count = torch.where(1 - result == mask, 0, 1).sum()
                self.assertLess(mismatch_count, 1)

    def test_crop_area(self):
        _, mask, area = self.dataset[0]
        for device in TESTS_CASES:
            with self.subTest(device):
                mask = mask.to(device)
                result = crop_area(area, mask)
                # A correct crop lies entirely inside the content area, so the
                # cropped mask should contain a single unique value.
                self.assertTrue(result.unique().numel() == 1)
33 |
# Allows running this test module directly, e.g. `python -m tests.test_utils`.
if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/tests/utils/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import Dataset, DataLoader
4 | from ecadataset import ECADataset, DataSource, AnnotationType
5 |
def meshgrid(tensors):
    """Calls torch.meshgrid with explicit "ij" indexing when supported.

    The `indexing` argument was added in torch 1.10; passing it on older
    versions raises a TypeError, while omitting it on newer versions emits a
    deprecation warning. "ij" matches the historical default behaviour.

    Args:
        tensors: sequence of 1D tensors to mesh.

    Returns:
        Tuple of coordinate grid tensors, one per input tensor.
    """
    major, minor = map(int, torch.__version__.split(".")[:2])
    # Tuple comparison fixes the previous `major >= 1 and minor > 9` check,
    # which wrongly sent torch 2.x (minor < 10) down the legacy path.
    if (major, minor) >= (1, 10):
        return torch.meshgrid(tensors, indexing="ij")
    else:
        return torch.meshgrid(tensors)
12 |
13 |
14 | ########################
15 | # Datasets...
16 |
class TestDataLoader(DataLoader):
    """DataLoader preset for the test datasets: unbatched samples with pinned
    host memory for fast host-to-device transfer."""

    def __init__(self, dataset, shuffle=False) -> None:
        super().__init__(
            dataset=dataset,
            batch_size=None,
            pin_memory=True,
            shuffle=shuffle,
        )
20 |
21 |
class TestDataset(Dataset):
    """Wraps the ECA dataset, converting images and area annotations to
    torch tensors."""

    def __init__(self) -> None:
        super().__init__()
        self.dataset = ECADataset("eca-data", DataSource.CHOLEC, AnnotationType.AREA, include_cropped=True)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Returns (img, area): the image as a channel-first tensor, and the
        area annotation as a tensor or None when the sample has none."""
        image, area = self.dataset[index]
        img = torch.from_numpy(np.array(image)).permute(2, 0, 1)
        # `is None` avoids invoking __eq__ on the annotation object
        # (was `area == None`).
        area = None if area is None else torch.from_numpy(np.array(area))
        return img, area
35 |
36 |
class DummyDataset(Dataset):
    """Synthetic dataset of circular content areas on an 854x480 canvas.

    Each sample is (img, mask, area): a 3-channel uint8 image that is white
    inside the circle and black outside, a uint8 mask that is 0 inside the
    content area and 1 outside, and the generating (x, y, radius) tuple
    (None for the sample where the whole image is content).
    """

    def __init__(self) -> None:
        super().__init__()

        self.width = 854
        self.height = 480

        # (x, y, radius) circle parameters; None means no circle.
        self.areas = [
            (400, 250, 360),
            (340, 200, 370),
            (450, 230, 250),
            None,
        ]

    def __len__(self):
        return len(self.areas)

    def __getitem__(self, index):
        area = self.areas[index]

        if area is not None:  # `is not None` (was `!= None`)
            area_x, area_y, area_r = area
            coords = torch.stack(meshgrid([torch.arange(0, self.height), torch.arange(0, self.width)]))
            center = torch.Tensor([area_y, area_x]).reshape((2, 1, 1))
            # Pixels within the radius are content (0), everything else is 1.
            # The previous abs() before the L2 norm was redundant and has
            # been dropped; the norm is unchanged.
            mask = torch.where(torch.linalg.norm(coords - center, dim=0) < area_r, 0, 1).unsqueeze(0)
        else:
            mask = torch.zeros(1, self.height, self.width)

        img = 255 * (1 - mask).expand((3, self.height, self.width))
        img = img.to(dtype=torch.uint8).contiguous()
        mask = mask.to(dtype=torch.uint8).contiguous()

        return img, mask, area
70 |
--------------------------------------------------------------------------------
/tests/utils/profiling.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 |
4 | UNITS = {'h': 1.0/60.0, 'm': 1.0/60.0, 's': 1.0, 'ms': 1e3, 'us': 1e6}
5 |
def get_time():
    """Returns the current wall-clock time in seconds, first synchronizing any
    pending CUDA work so asynchronous GPU kernels are included in timings.

    Returns:
        Current time as a float (seconds since the epoch).
    """
    # Guard: calling torch.cuda.synchronize() unconditionally raises on
    # CPU-only machines.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
9 |
class Timer():
    """Context manager that measures the wall-clock time spent inside a
    `with` block, exposing the result as `self.time` in the requested units."""

    def __init__(self, units='ms'):
        # Fail fast on unsupported units rather than at exit time.
        assert units in UNITS, f'The given units {units} is not supported, please use h, m, s, ms, or us'
        self.units = units

    def __enter__(self):
        # Record the starting timestamp and hand the timer back to the caller.
        self.start = get_time()
        return self

    def __exit__(self, type, value, traceback):
        # Record the end timestamp and convert elapsed seconds to the
        # configured units.
        self.end = get_time()
        elapsed_seconds = self.end - self.start
        self.time = elapsed_seconds * UNITS[self.units]
22 |
--------------------------------------------------------------------------------
/tests/utils/scoring.py:
--------------------------------------------------------------------------------
1 | from ecadataset import content_area_hausdorff
2 |
# Hausdorff-distance thresholds above which an estimated content area counts
# as a miss / bad miss respectively (units per content_area_hausdorff —
# presumably pixels; confirm against ecadataset).
MISS_THRESHOLD=15
BAD_MISS_THRESHOLD=25
5 |
--------------------------------------------------------------------------------