├── .DS_Store
├── .gitignore
├── .vscode
│   ├── c_cpp_properties.json
│   ├── launch.json
│   ├── settings.json
│   └── tasks.json
├── CMakeLists.txt
├── Readme.md
├── build.gradle.kts
├── main.out
└── src
    ├── .DS_Store
    ├── base_optmized.onnx
    ├── best.onnx
    ├── best.ort
    ├── best_int8.onnx
    ├── best_int8.ort
    ├── best_optmized.onnx
    ├── best_quant.onnx
    ├── best_saved_model
    │   ├── best_float16.tflite
    │   ├── best_float32.tflite
    │   ├── fingerprint.pb
    │   ├── metadata.yaml
    │   ├── saved_model.pb
    │   └── variables
    │       ├── variables.data-00000-of-00001
    │       └── variables.index
    ├── camera_inference.cpp
    ├── camera_inference.out
    ├── classes.txt
    ├── ia
    │   ├── YOLO11.hpp
    │   └── tools
    │       ├── Config.hpp
    │       ├── Debug.hpp
    │       └── ScopedTimer.hpp
    ├── image_2.jpg
    ├── input.mov
    ├── kotlin
    │   ├── AndroidManifest.xml
    │   ├── Application.kt
    │   ├── BuildConfig.kt
    │   ├── DebugUtils.kt
    │   ├── MainActivity.kt
    │   ├── ModelParseActivity.kt
    │   ├── ScopedTimer.kt
    │   ├── TFLiteModelManager.kt
    │   ├── YOLO11Detector.kt
    │   ├── activity_main.xml
    │   ├── build.gradle
    │   ├── build.gradle.kts
    │   └── res
    │       └── layout
    │           └── activity_model_parse.xml
    ├── output.mp4
    ├── output
    │   ├── base_simplify.onnx
    │   ├── t1.mp4
    │   ├── yolo_cli_pt.mp4
    │   └── yolov11_cpp_onnx.mp4
    ├── runs
    │   └── detect
    │       ├── predict2
    │       │   └── t1.mp4
    │       └── predict3
    │           └── t1.mp4
    ├── t1.mp4
    ├── viewer.cpp
    └── viewer.out
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 |
163 | # Prerequisites
164 | *.d
165 |
166 | # Object files
167 | *.o
168 | *.ko
169 | *.obj
170 | *.elf
171 |
172 | # Linker output
173 | *.ilk
174 | *.map
175 | *.exp
176 |
177 | # Precompiled Headers
178 | *.gch
179 | *.pch
180 |
181 | # Libraries
182 | *.lib
183 | *.a
184 | *.la
185 | *.lo
186 |
187 | # Shared objects (inc. Windows DLLs)
188 | *.dll
189 | *.so
190 | *.so.*
191 | *.dylib
192 |
193 | # Executables
194 | *.exe
195 | *.out
196 | *.app
197 | *.i*86
198 | *.x86_64
199 | *.hex
200 |
201 | # Debug files
202 | *.dSYM/
203 | *.su
204 | *.idb
205 | *.pdb
206 |
207 | # Kernel Module Compile Results
208 | *.mod*
209 | *.cmd
210 | .tmp_versions/
211 | modules.order
212 | Module.symvers
213 | Mkfile.old
214 | dkms.conf
215 |
216 | *.onnx
217 |
218 | *.pt
219 |
220 | *.DS_Store
--------------------------------------------------------------------------------
/.vscode/c_cpp_properties.json:
--------------------------------------------------------------------------------
1 | {
2 | "configurations": [
3 | {
4 | "name": "Mac",
5 | "includePath": [
6 | "${workspaceFolder}/**",
7 | "/opt/homebrew/Cellar/opencv/4.11.0/include/opencv4/opencv2",
8 | "/opt/homebrew/Cellar/opencv/4.11.0/include/opencv4",
9 | "/opt/homebrew/Cellar/onnxruntime/1.17.1/include/onnxruntime",
10 | "/Users/danielsarmiento/Desktop/hobby/yolov11cpp/src/ia/"
11 | ],
12 | "defines": [],
13 | "macFrameworkPath": [],
14 | "compilerPath": "/usr/bin/g++",
15 | "cStandard": "c17",
16 | "cppStandard": "c++17",
17 | "intelliSenseMode": "clang-x64",
18 | "browse": {
19 | "path": [
20 | "/opt/homebrew/Cellar/opencv/4.11.0/include/opencv4",
21 | "/opt/homebrew/Cellar/onnxruntime/1.17.1/include/onnxruntime",
22 | "/Users/danielsarmiento/Desktop/hobby/yolov11cpp/src/ia/"
23 | ],
24 | "limitSymbolsToIncludedHeaders": true,
25 | "databaseFilename": ""
26 | }
27 | }
28 | ],
29 | "version": 4
30 | }
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "0.2.0",
3 | "configurations": [
4 | {
5 | "name": "(lldb) Launch",
6 | "type": "cppdbg",
7 | "request": "launch",
8 | "program": "${fileDirname}/${fileBasenameNoExtension}.out",
9 | "args": [],
10 | "stopAtEntry": true,
11 | "cwd": "${workspaceFolder}",
12 | "environment": [],
13 | "externalConsole": true,
14 | "MIMode": "lldb",
15 | "preLaunchTask": "Build"
16 | }
17 | ]
18 | }
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "files.associations": {
3 | "__availability": "cpp",
4 | "charconv": "cpp",
5 | "string": "cpp",
6 | "vector": "cpp",
7 | "__config": "cpp",
8 | "__split_buffer": "cpp",
9 | "deque": "cpp",
10 | "list": "cpp",
11 | "__bit_reference": "cpp",
12 | "__debug": "cpp",
13 | "__errc": "cpp",
14 | "__hash_table": "cpp",
15 | "__locale": "cpp",
16 | "__mutex_base": "cpp",
17 | "__node_handle": "cpp",
18 | "__threading_support": "cpp",
19 | "__tree": "cpp",
20 | "__verbose_abort": "cpp",
21 | "array": "cpp",
22 | "atomic": "cpp",
23 | "bitset": "cpp",
24 | "cctype": "cpp",
25 | "clocale": "cpp",
26 | "cmath": "cpp",
27 | "complex": "cpp",
28 | "cstdarg": "cpp",
29 | "cstddef": "cpp",
30 | "cstdint": "cpp",
31 | "cstdio": "cpp",
32 | "cstdlib": "cpp",
33 | "cstring": "cpp",
34 | "ctime": "cpp",
35 | "cwchar": "cpp",
36 | "cwctype": "cpp",
37 | "exception": "cpp",
38 | "fstream": "cpp",
39 | "initializer_list": "cpp",
40 | "iomanip": "cpp",
41 | "ios": "cpp",
42 | "iosfwd": "cpp",
43 | "iostream": "cpp",
44 | "istream": "cpp",
45 | "limits": "cpp",
46 | "locale": "cpp",
47 | "map": "cpp",
48 | "mutex": "cpp",
49 | "new": "cpp",
50 | "optional": "cpp",
51 | "ostream": "cpp",
52 | "queue": "cpp",
53 | "ratio": "cpp",
54 | "set": "cpp",
55 | "sstream": "cpp",
56 | "stdexcept": "cpp",
57 | "streambuf": "cpp",
58 | "string_view": "cpp",
59 | "system_error": "cpp",
60 | "tuple": "cpp",
61 | "typeinfo": "cpp",
62 | "unordered_map": "cpp",
63 | "variant": "cpp",
64 | "algorithm": "cpp",
65 | "execution": "cpp",
66 | "regex": "cpp"
67 | }
68 | }
--------------------------------------------------------------------------------
/.vscode/tasks.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "2.0.0",
3 | "tasks": [
4 | {
5 | "label": "Build",
6 | "type": "shell",
7 | "command": "clang++",
8 | "args": [
9 | "-std=c++17",
10 | "${file}",
11 | "-o",
12 | "${fileDirname}/${fileBasenameNoExtension}.out",
13 | "-I",
14 | "/opt/homebrew/Cellar/opencv/4.11.0/include/opencv4/opencv2",
15 | "-I",
16 | "/opt/homebrew/Cellar/onnxruntime/1.17.1/include/onnxruntime",
17 | "-I",
18 | "/opt/homebrew/Cellar/opencv/4.11.0/include/opencv4",
19 | "-I",
20 | "/Users/danielsarmiento/Desktop/hobby/yolov11cpp/src/ia/",
21 | "-L",
22 | "/opt/homebrew/Cellar/opencv/4.11.0/lib",
23 | "-L",
24 | "/opt/homebrew/Cellar/onnxruntime/1.17.1/lib",
25 | "-l",
26 | "onnxruntime",
27 | "-l",
28 | "opencv_stitching",
29 | "-l",
30 | "opencv_superres",
31 | "-l",
32 | "opencv_videostab",
33 | "-l",
34 | "opencv_aruco",
35 | "-l",
36 | "opencv_bgsegm",
37 | "-l",
38 | "opencv_bioinspired",
39 | "-l",
40 | "opencv_ccalib",
41 | "-l",
42 | "opencv_dnn_objdetect",
43 | "-l",
44 | "opencv_dpm",
45 | "-l",
46 | "opencv_face",
47 | "-l",
48 | "opencv_fuzzy",
49 | "-l",
50 | "opencv_hfs",
51 | "-l",
52 | "opencv_img_hash",
53 | "-l",
54 | "opencv_line_descriptor",
55 | "-l",
56 | "opencv_optflow",
57 | "-l",
58 | "opencv_reg",
59 | "-l",
60 | "opencv_rgbd",
61 | "-l",
62 | "opencv_saliency",
63 | "-l",
64 | "opencv_stereo",
65 | "-l",
66 | "opencv_structured_light",
67 | "-l",
68 | "opencv_phase_unwrapping",
69 | "-l",
70 | "opencv_surface_matching",
71 | "-l",
72 | "opencv_tracking",
73 | "-l",
74 | "opencv_datasets",
75 | "-l",
76 | "opencv_dnn",
77 | "-l",
78 | "opencv_plot",
79 | "-l",
80 | "opencv_xfeatures2d",
81 | "-l",
82 | "opencv_shape",
83 | "-l",
84 | "opencv_video",
85 | "-l",
86 | "opencv_ml",
87 | "-l",
88 | "opencv_ximgproc",
89 | "-l",
90 | "opencv_xobjdetect",
91 | "-l",
92 | "opencv_objdetect",
93 | "-l",
94 | "opencv_calib3d",
95 | "-l",
96 | "opencv_features2d",
97 | "-l",
98 | "opencv_highgui",
99 | "-l",
100 | "opencv_videoio",
101 | "-l",
102 | "opencv_imgcodecs",
103 | "-l",
104 | "opencv_flann",
105 | "-l",
106 | "opencv_xphoto",
107 | "-l",
108 | "opencv_photo",
109 | "-l",
110 | "opencv_imgproc",
111 | "-l",
112 | "opencv_core",
113 | // "-g"
114 | ],
115 | "group": {
116 | "kind": "build",
117 | "isDefault": true
118 | },
119 | "problemMatcher": [
120 | "$gcc"
121 | ]
122 | },
123 | {
124 | "type": "cppbuild",
125 | "label": "C/C++: clang++ build active file",
126 | "command": "/usr/bin/clang++",
127 | "args": [
128 | "-fcolor-diagnostics",
129 | "-fansi-escape-codes",
130 | "-g",
131 | "${file}",
132 | "-o",
133 | "${fileDirname}/${fileBasenameNoExtension}"
134 | ],
135 | "options": {
136 | "cwd": "${fileDirname}"
137 | },
138 | "problemMatcher": [
139 | "$gcc"
140 | ],
141 | "group": "build",
142 | "detail": "compiler: /usr/bin/clang++"
143 | }
144 | ]
145 | }
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.10)
2 |
3 | # Set the project name in a variable
4 | set(project_name yolov10_cpp)
5 | project(${project_name})
6 | set(CMAKE_CXX_STANDARD 17)
7 |
8 | find_package(OpenCV REQUIRED)
9 |
10 | # Find ONNX Runtime package
11 | find_path(ONNXRUNTIME_INCLUDE_DIR onnxruntime_c_api.h
12 | HINTS /opt/homebrew/Cellar/onnxruntime/1.17.1/include/onnxruntime
13 | )
14 | find_library(ONNXRUNTIME_LIBRARY onnxruntime
15 | HINTS /opt/homebrew/Cellar/onnxruntime/1.17.1/lib
16 | src/ia/
17 | )
18 |
19 | if(NOT ONNXRUNTIME_INCLUDE_DIR)
20 | message(FATAL_ERROR "ONNX Runtime include directory not found")
21 | endif()
22 | if(NOT ONNXRUNTIME_LIBRARY)
23 | message(FATAL_ERROR "ONNX Runtime library not found")
24 | endif()
25 |
26 | add_library(${project_name}-lib
27 |
28 | )
29 |
30 | target_include_directories(${project_name}-lib PUBLIC src)
31 | target_include_directories(${project_name}-lib PUBLIC ${ONNXRUNTIME_INCLUDE_DIR})
32 |
33 | target_link_libraries(${project_name}-lib
34 | PUBLIC ${OpenCV_LIBS}
35 | PUBLIC ${ONNXRUNTIME_LIBRARY}
36 | )
37 |
38 | # Add the main executable
39 | add_executable(${project_name}
40 | ./src/camera_inference.cpp
41 | )
42 | # target_include_directories(${project_name} PUBLIC ${ONNXRUNTIME_INCLUDE_DIR})
43 | # target_link_libraries(${project_name} ${project_name}-lib)
44 |
45 | # # Add the video executable
46 | # add_executable(${project_name}_video
47 | # ./src/video.cpp
48 | # )
49 | # target_include_directories(${project_name}_video PUBLIC ${ONNXRUNTIME_INCLUDE_DIR})
50 | # target_link_libraries(${project_name}_video ${project_name}-lib)
51 |
52 |
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # YOLOv11 C++ Implementation
2 |
3 | A high-performance C++ implementation of YOLOv11 object detection using ONNX Runtime and OpenCV.
4 |
5 | 
6 |
7 | ## Features
8 |
9 | - Fast and efficient object detection using YOLOv11
10 | - Support for both CPU and GPU inference (CUDA)
11 | - Video processing capabilities
12 | - Dynamic confidence and IoU thresholds
13 | - Visual performance metrics (FPS counter)
14 | - Semi-transparent bounding box masks for cleaner visualization
15 |
16 | ## Prerequisites
17 |
18 | - CMake 3.12+
19 | - C++17 compatible compiler
20 | - OpenCV 4.x
21 | - ONNX Runtime 1.17+
22 | - CUDA Toolkit (optional, for GPU acceleration)
23 |
24 | ## Installation
25 |
26 | ### Clone the Repository
27 |
28 | ```bash
29 | git clone https://github.com/yourusername/yolov11cpp.git
30 | cd yolov11cpp
31 | ```
32 |
33 | ### Building with CMake
34 |
35 | ```bash
36 | mkdir build
37 | cd build
38 | cmake ..
39 | make -j$(nproc)
40 | ```
41 |
42 | ### Prepare the Model
43 |
44 | 1. Export your YOLOv11 model to ONNX format using Ultralytics:
45 |
46 | ```bash
47 | # If using Python/Ultralytics
48 | yolo export model=yolov11s.pt format=onnx opset=12 simplify=True
49 | ```
50 |
51 | 2. Place your ONNX model and class names file in the project directory:
52 |
53 | ```bash
54 | cp path/to/best.onnx ./
55 | cp path/to/classes.txt ./
56 | ```
57 |
58 | ## Usage
59 |
60 | ### Basic Command
61 |
62 | ```bash
63 | ./yolov11_detector [options]
64 | ```
65 |
66 | ### Options
67 |
68 | - `--model`: Path to the ONNX model file (default: "./best.onnx")
69 | - `--classes`: Path to the class names file (default: "./classes.txt")
70 | - `--input`: Path to input video file or camera device index (default: "./input.mov")
71 | - `--output`: Path for output video file (default: "./output.mp4")
72 | - `--gpu`: Use GPU acceleration if available (default: false)
73 | - `--conf`: Confidence threshold (default: 0.25)
74 | - `--iou`: IoU threshold for NMS (default: 0.45)
75 |
76 | ### Example
77 |
78 | ```bash
79 | # Process a video file with custom thresholds
80 | ./yolov11_detector --input=test_video.mp4 --output=result.mp4 --conf=0.3 --iou=0.4
81 |
82 | # Use webcam (device 0) with GPU acceleration
83 | ./yolov11_detector --input=0 --gpu=true
84 | ```
85 |
86 | ## Configuration
87 |
88 | You can modify the default settings by editing the constants in:
89 |
90 | - `src/camera_inference.cpp` - Main application settings
91 | - `src/ia/YOLO11.hpp` - Detection parameters and algorithms
92 | - `src/ia/tools/Config.hpp` - Debug and timing configurations
93 |
94 | ## Debugging
95 |
96 | Enable debugging by uncommenting these lines in `src/ia/tools/Config.hpp`:
97 |
98 | ```cpp
99 | // Enable debug messages
100 | #define DEBUG_MODE
101 |
102 | // Enable performance timing
103 | #define TIMING_MODE
104 | ```
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 | ## Troubleshooting
113 |
114 | ### Accuracy Issues
115 |
116 | If you notice differences in detection accuracy compared to the Python implementation:
117 |
118 | 1. Verify your ONNX model is exported correctly with proper settings
119 | 2. Check that preprocessing matches Ultralytics implementation (RGB conversion, normalization)
120 | 3. Confirm your class names file is correct and in the expected format
121 | 4. Try adjusting the confidence and IoU thresholds to match Ultralytics defaults (0.25 and 0.45)
122 |
123 | ### Performance Issues
124 |
125 | - For CPU optimization, ensure `ORT_ENABLE_ALL` optimization is enabled
126 | - For GPU usage, verify CUDA toolkit and ONNX Runtime with CUDA support are installed
127 | - Reduce input image resolution for better performance
128 |
129 | Reference implementation: [Geekgineer/YOLOs-CPP](https://github.com/Geekgineer/YOLOs-CPP)
--------------------------------------------------------------------------------
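
A minimal preprocessing sketch to accompany the Readme's Troubleshooting checklist (item 2 on matching the Ultralytics pipeline). It is not part of the repository: it assumes a fixed 640×640 input and uses OpenCV's `cv::dnn::blobFromImage` helper instead of the letterbox routine implemented in `src/ia/YOLO11.hpp`.

```cpp
#include <opencv2/opencv.hpp>

// Sketch only: build a 1x3x640x640 float blob the way Ultralytics expects,
// i.e. RGB channel order and pixel values scaled to [0, 1].
// Letterbox padding (handled by utils::letterBox in src/ia/YOLO11.hpp) is omitted.
cv::Mat makeInputBlob(const cv::Mat &bgrFrame, int inputSize = 640)
{
    return cv::dnn::blobFromImage(
        bgrFrame,
        1.0 / 255.0,                    // scale pixels to [0, 1]
        cv::Size(inputSize, inputSize), // model input resolution
        cv::Scalar(),                   // no mean subtraction
        /*swapRB=*/true,                // OpenCV BGR -> RGB expected by the model
        /*crop=*/false);                // plain resize, no center crop
}
```

If detections drift from the Python results, comparing the blob produced here against the project's own preprocessing output is a quick way to isolate normalization or channel-order mismatches.
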
/build.gradle.kts:
--------------------------------------------------------------------------------
1 | import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
2 |
3 | plugins {
4 | kotlin("jvm") version "1.5.1"
5 | application
6 | }
7 |
8 | group = "com.yolov11kotlin"
9 | version = "1.0-SNAPSHOT"
10 |
11 | repositories {
12 | mavenCentral()
13 | maven { url = uri("https://oss.sonatype.org/content/repositories/snapshots") }
14 | }
15 |
16 | dependencies {
17 | // ONNX Runtime
18 | implementation("com.microsoft.onnxruntime:onnxruntime-mobile:latest.release")
19 |
20 | // OpenCV
21 | implementation("org.openpnp:opencv:4.5.1-2")
22 |
23 | // Kotlin standard library
24 | implementation(kotlin("stdlib"))
25 |
26 | // Coroutines for async operations
27 | implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.5.0")
28 |
29 | // Testing
30 | testImplementation(kotlin("test"))
31 | }
32 |
33 | tasks.test {
34 | useJUnit()
35 | }
36 |
37 | tasks.withType<KotlinCompile> {
38 | kotlinOptions.jvmTarget = "11"
39 | }
40 |
41 | application {
42 | mainClass.set("com.yolov11kotlin.MainKt")
43 | }
44 |
45 | // Task to copy native libraries to the build directory
46 | tasks.register<Copy>("copyNativeLibs") {
47 | from("libs")
48 | into("${buildDir}/libs")
49 | include("**/*.so", "**/*.dll", "**/*.dylib")
50 | }
51 |
52 | tasks.named("run") {
53 | dependsOn("copyNativeLibs")
54 | }
55 |
--------------------------------------------------------------------------------
/main.out:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/main.out
--------------------------------------------------------------------------------
/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/.DS_Store
--------------------------------------------------------------------------------
/src/base_optmized.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/base_optmized.onnx
--------------------------------------------------------------------------------
/src/best.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/best.onnx
--------------------------------------------------------------------------------
/src/best.ort:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/best.ort
--------------------------------------------------------------------------------
/src/best_int8.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/best_int8.onnx
--------------------------------------------------------------------------------
/src/best_int8.ort:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/best_int8.ort
--------------------------------------------------------------------------------
/src/best_optmized.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/best_optmized.onnx
--------------------------------------------------------------------------------
/src/best_quant.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/best_quant.onnx
--------------------------------------------------------------------------------
/src/best_saved_model/best_float16.tflite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/best_saved_model/best_float16.tflite
--------------------------------------------------------------------------------
/src/best_saved_model/best_float32.tflite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/best_saved_model/best_float32.tflite
--------------------------------------------------------------------------------
/src/best_saved_model/fingerprint.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/best_saved_model/fingerprint.pb
--------------------------------------------------------------------------------
/src/best_saved_model/metadata.yaml:
--------------------------------------------------------------------------------
1 | description: Ultralytics YOLO11n model trained on ./custom.yml
2 | author: Ultralytics
3 | date: '2025-03-04T06:23:12.362203'
4 | version: 8.3.82
5 | license: AGPL-3.0 License (https://ultralytics.com/license)
6 | docs: https://docs.ultralytics.com
7 | stride: 32
8 | task: detect
9 | batch: 1
10 | imgsz:
11 | - 640
12 | - 640
13 | names:
14 | 0: person
15 | 1: bicycle
16 | 2: car
17 | 3: motorcycle
18 | 4: airplane
19 | 5: bus
20 | 6: train
21 | 7: truck
22 | 8: boat
23 | 9: traffic light
24 | 10: fire hydrant
25 | 11: stop sign
26 | 12: parking meter
27 | 13: bench
28 | 14: bird
29 | 15: cat
30 | 16: dog
31 | 17: horse
32 | 18: sheep
33 | 19: cow
34 | 20: elephant
35 | 21: bear
36 | 22: zebra
37 | 23: giraffe
38 | 24: backpack
39 | 25: umbrella
40 | 26: handbag
41 | 27: tie
42 | 28: suitcase
43 | 29: frisbee
44 | 30: skis
45 | 31: snowboard
46 | 32: sports ball
47 | 33: kite
48 | 34: baseball bat
49 | 35: baseball glove
50 | 36: skateboard
51 | 37: surfboard
52 | 38: tennis racket
53 | 39: bottle
54 | 40: wine glass
55 | 41: cup
56 | 42: fork
57 | 43: knife
58 | 44: spoon
59 | 45: bowl
60 | 46: banana
61 | 47: apple
62 | 48: sandwich
63 | 49: orange
64 | 50: broccoli
65 | 51: carrot
66 | 52: hot dog
67 | 53: pizza
68 | 54: donut
69 | 55: cake
70 | 56: chair
71 | 57: couch
72 | 58: potted plant
73 | 59: bed
74 | 60: dining table
75 | 61: toilet
76 | 62: tv
77 | 63: laptop
78 | 64: mouse
79 | 65: remote
80 | 66: keyboard
81 | 67: cell phone
82 | 68: microwave
83 | 69: oven
84 | 70: toaster
85 | 71: sink
86 | 72: refrigerator
87 | 73: book
88 | 74: clock
89 | 75: vase
90 | 76: scissors
91 | 77: teddy bear
92 | 78: hair drier
93 | 79: toothbrush
94 | 80: pump
95 | 81: pipe
96 | 82: steel pipe
97 | 83: electric cable
98 | args:
99 | batch: 1
100 | half: false
101 | int8: false
102 | nms: false
103 |
--------------------------------------------------------------------------------
/src/best_saved_model/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/best_saved_model/saved_model.pb
--------------------------------------------------------------------------------
/src/best_saved_model/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/best_saved_model/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/src/best_saved_model/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/best_saved_model/variables/variables.index
--------------------------------------------------------------------------------
/src/camera_inference.cpp:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 | #include <string>
3 | #include <vector>
4 | #include <chrono>
5 |
6 | #include <opencv2/opencv.hpp>
7 | #include <opencv2/highgui.hpp>
8 |
9 | #include "./ia/YOLO11.hpp"
10 |
11 | int main()
12 | {
13 |
14 | // Configuration parameters
15 | const bool isGPU = false;
16 | const std::string labelsPath = "./classes.txt";
17 | const std::string modelPath = "./best_optmized.onnx";
18 | const std::string videoSource = "./input.mov"; // input video file (or camera device index)
19 | const std::string outputPath = "./output.mp4"; // path for output video file
20 |
21 | // Use the same default thresholds as Ultralytics CLI
22 | const float confThreshold = 0.25f; // Match Ultralytics default confidence threshold
23 | const float iouThreshold = 0.45f; // Match Ultralytics default IoU threshold
24 |
25 | std::cout << "Initializing YOLOv11 detector with model: " << modelPath << std::endl;
26 | std::cout << "Using confidence threshold: " << confThreshold << ", IoU threshold: " << iouThreshold << std::endl;
27 |
28 | // read model
29 | std::cout << "Loading model and labels..." << std::endl;
30 |
31 | // Initialize YOLO detector
32 | YOLO11Detector detector(modelPath, labelsPath, isGPU);
33 |
34 | // Open video capture
35 | cv::VideoCapture cap;
36 |
37 | // Open the video source (e.g. an iPhone 11 camera stream or a video file) via the FFmpeg backend
38 | cap.open(videoSource, cv::CAP_FFMPEG);
39 | if (!cap.isOpened())
40 | {
41 | std::cerr << "Error: Could not open the camera!\n";
42 | return -1;
43 | }
44 |
45 | // Get video properties for the writer
46 | double fps = cap.get(cv::CAP_PROP_FPS);
47 | int width = static_cast<int>(cap.get(cv::CAP_PROP_FRAME_WIDTH));
48 | int height = static_cast<int>(cap.get(cv::CAP_PROP_FRAME_HEIGHT));
49 |
50 | // Initialize video writer
51 | cv::VideoWriter videoWriter;
52 | int fourcc = cv::VideoWriter::fourcc('a', 'v', 'c', '1'); // H.264 codec
53 |
54 | // Open the video writer
55 | bool isWriterOpened = videoWriter.open(outputPath, fourcc, fps, cv::Size(width, height), true);
56 | if (!isWriterOpened) {
57 | std::cerr << "Error: Could not open video writer!\n";
58 | return -1;
59 | }
60 |
61 | std::cout << "Recording output to: " << outputPath << std::endl;
62 | std::cout << "Press 'q' to stop recording and exit" << std::endl;
63 |
64 | int frame_count = 0;
65 | double total_time = 0.0;
66 |
67 | for (;;)
68 | {
69 | cv::Mat frame;
70 | cap >> frame;
71 | if (frame.empty())
72 | {
73 | std::cerr << "Error: Could not read a frame!\n";
74 | break;
75 | }
76 |
77 | // Display the frame
78 | cv::imshow("input", frame);
79 |
80 | // Measure detection time
81 | auto start_time = std::chrono::high_resolution_clock::now();
82 |
83 | // Perform detection with the updated thresholds
84 | std::vector<Detection> detections = detector.detect(frame, confThreshold, iouThreshold);
85 |
86 | auto end_time = std::chrono::high_resolution_clock::now();
87 | auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time).count();
88 | total_time += duration;
89 | frame_count++;
90 |
91 | // Create a copy for output with detections drawn
92 | cv::Mat outputFrame = frame.clone();
93 |
94 | // Draw bounding boxes and masks on the frame
95 | detector.drawBoundingBoxMask(outputFrame, detections);
96 |
97 | // Add FPS info
98 | double fps = 1000.0 / (total_time / frame_count);
99 | cv::putText(outputFrame, "FPS: " + std::to_string(static_cast<int>(fps)),
100 | cv::Point(20, 40), cv::FONT_HERSHEY_SIMPLEX, 1.0, cv::Scalar(0, 255, 0), 2);
101 |
102 | // Write the processed frame to the output video
103 | videoWriter.write(outputFrame);
104 |
105 | // Display the frame
106 | cv::imshow("Detections", outputFrame);
107 |
108 | // Use a small delay and check for 'q' key press to quit
109 | if (cv::waitKey(1) == 'q')
110 | {
111 | break;
112 | }
113 | }
114 |
115 | // Release resources
116 | cap.release();
117 | videoWriter.release();
118 | cv::destroyAllWindows();
119 |
120 | std::cout << "Video processing completed. Output saved to: " << outputPath << std::endl;
121 | std::cout << "Average FPS: " << (1000.0 / (total_time / frame_count)) << std::endl;
122 |
123 | return 0;
124 | }
125 |
--------------------------------------------------------------------------------
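
The Readme documents `--model`, `--classes`, `--input`, `--output`, `--gpu`, `--conf` and `--iou` options, while `camera_inference.cpp` above hard-codes the same values as constants. The sketch below is a hypothetical `key=value` parser showing one way those flags could be wired into `main(int argc, char **argv)`; it is not the repository's actual CLI handling.

```cpp
#include <string>
#include <unordered_map>

// Illustrative parser for flags of the form --name=value (e.g. --conf=0.3).
// A bare flag such as --gpu is stored as "true".
std::unordered_map<std::string, std::string> parseArgs(int argc, char **argv)
{
    std::unordered_map<std::string, std::string> opts;
    for (int i = 1; i < argc; ++i) {
        std::string arg = argv[i];
        if (arg.rfind("--", 0) != 0)
            continue;                                 // only accept --key[=value]
        auto eq = arg.find('=');
        if (eq == std::string::npos)
            opts[arg.substr(2)] = "true";             // e.g. --gpu
        else
            opts[arg.substr(2, eq - 2)] = arg.substr(eq + 1);
    }
    return opts;
}

// Possible use inside main():
//   auto opts = parseArgs(argc, argv);
//   const std::string modelPath = opts.count("model") ? opts["model"] : "./best.onnx";
//   const float confThreshold   = opts.count("conf") ? std::stof(opts["conf"]) : 0.25f;
```
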
/src/camera_inference.out:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/camera_inference.out
--------------------------------------------------------------------------------
/src/classes.txt:
--------------------------------------------------------------------------------
1 | person
2 | bicycle
3 | car
4 | motorcycle
5 | airplane
6 | bus
7 | train
8 | truck
9 | boat
10 | traffic light
11 | fire hydrant
12 | stop sign
13 | parking meter
14 | bench
15 | bird
16 | cat
17 | dog
18 | horse
19 | sheep
20 | cow
21 | elephant
22 | bear
23 | zebra
24 | giraffe
25 | backpack
26 | umbrella
27 | handbag
28 | tie
29 | suitcase
30 | frisbee
31 | skis
32 | snowboard
33 | sports ball
34 | kite
35 | baseball bat
36 | baseball glove
37 | skateboard
38 | surfboard
39 | tennis racket
40 | bottle
41 | wine glass
42 | cup
43 | fork
44 | knife
45 | spoon
46 | bowl
47 | banana
48 | apple
49 | sandwich
50 | orange
51 | broccoli
52 | carrot
53 | hot dog
54 | pizza
55 | donut
56 | cake
57 | chair
58 | couch
59 | potted plant
60 | bed
61 | dining table
62 | toilet
63 | tv
64 | laptop
65 | mouse
66 | remote
67 | keyboard
68 | cell phone
69 | microwave
70 | oven
71 | toaster
72 | sink
73 | refrigerator
74 | book
75 | clock
76 | vase
77 | scissors
78 | teddy bear
79 | hair drier
80 | toothbrush
81 | pump
82 | pipe
83 | steel pipe
84 | electric cable
85 |
--------------------------------------------------------------------------------
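
`YOLO11.hpp` below relies on `DEBUG_PRINT` and `ScopedTimer` from `src/ia/tools/Debug.hpp` and `src/ia/tools/ScopedTimer.hpp`, toggled through `DEBUG_MODE` / `TIMING_MODE` in `src/ia/tools/Config.hpp` (see the Readme's Debugging section). Those headers are not reproduced in this excerpt; the following is a minimal reconstruction inferred from how the macros and the timer are used, not a copy of the actual files.

```cpp
#include <chrono>
#include <iostream>
#include <string>

// Hypothetical equivalents of tools/Config.hpp, tools/Debug.hpp and tools/ScopedTimer.hpp.
// #define DEBUG_MODE   // uncomment to enable debug messages (see Readme)
// #define TIMING_MODE  // uncomment to enable performance timing (see Readme)

#ifdef DEBUG_MODE
    #define DEBUG_PRINT(msg) do { std::cout << "[DEBUG] " << msg << std::endl; } while (0)
#else
    #define DEBUG_PRINT(msg) do { } while (0)
#endif

// RAII timer: reports the lifetime of a named scope when TIMING_MODE is defined.
class ScopedTimer {
public:
    explicit ScopedTimer(const std::string &name)
        : name_(name), start_(std::chrono::steady_clock::now()) {}
    ~ScopedTimer() {
#ifdef TIMING_MODE
        const auto elapsedMs = std::chrono::duration_cast<std::chrono::milliseconds>(
            std::chrono::steady_clock::now() - start_).count();
        std::cout << "[TIMER] " << name_ << ": " << elapsedMs << " ms" << std::endl;
#endif
    }
private:
    std::string name_;
    std::chrono::steady_clock::time_point start_;
};
```
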
/src/ia/YOLO11.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | // ===================================
4 | // Single YOLOv11 Detector Header File
5 | // ===================================
6 | //
7 | // This header defines the YOLO11Detector class for performing object detection using the YOLOv11 model.
8 | // It includes necessary libraries, utility structures, and helper functions to facilitate model inference
9 | // and result postprocessing.
10 | //
11 | // Author: Abdalrahman M. Amer, www.linkedin.com/in/abdalrahman-m-amer
12 | // Date: 29.09.2024
13 | //
14 | // ================================
15 |
16 | /**
17 | * @file YOLO11Detector.hpp
18 | * @brief Header file for the YOLO11Detector class, responsible for object detection
19 | * using the YOLOv11 model with optimized performance for minimal latency.
20 | */
21 |
22 | // Include necessary ONNX Runtime and OpenCV headers
23 | #include <onnxruntime_cxx_api.h>
24 | #include <opencv2/opencv.hpp>
25 |
26 | #include <algorithm>
27 | #include <chrono>
28 | #include <fstream>
29 | #include <iostream>
30 | #include <memory>
31 | #include <numeric>
32 | #include <random>
33 | #include <string>
34 | #include <thread>
35 | #include <unordered_map>
36 | #include <vector>
37 |
38 | // Include debug and custom ScopedTimer tools for performance measurement
39 | #include "tools/Debug.hpp"
40 | #include "tools/ScopedTimer.hpp"
41 |
42 |
43 | /**
44 | * @brief Confidence threshold for filtering detections.
45 | */
46 | const float CONFIDENCE_THRESHOLD = 0.4f;
47 |
48 | /**
49 | * @brief IoU threshold for filtering detections.
50 | */
51 | const float IOU_THRESHOLD = 0.3f;
52 |
53 |
54 | /**
55 | * @brief Struct to represent a bounding box.
56 | */
57 |
58 | // Struct to represent a bounding box
59 | struct BoundingBox {
60 | int x;
61 | int y;
62 | int width;
63 | int height;
64 |
65 | BoundingBox() : x(0), y(0), width(0), height(0) {}
66 | BoundingBox(int x_, int y_, int width_, int height_)
67 | : x(x_), y(y_), width(width_), height(height_) {}
68 | };
69 |
70 | /**
71 | * @brief Struct to represent a detection.
72 | */
73 | struct Detection {
74 | BoundingBox box;
75 | float conf{};
76 | int classId{};
77 | };
78 |
79 | /**
80 | * @namespace utils
81 | * @brief Namespace containing utility functions for the YOLO11Detector.
82 | */
83 | namespace utils {
84 |
85 | /**
86 | * @brief A robust implementation of a clamp function.
87 | * Restricts a value to lie within a specified range [low, high].
88 | *
89 | * @tparam T The type of the value to clamp. Should be an arithmetic type (int, float, etc.).
90 | * @param value The value to clamp.
91 | * @param low The lower bound of the range.
92 | * @param high The upper bound of the range.
93 | * @return const T& The clamped value, constrained to the range [low, high].
94 | *
95 | * @note If low > high, the function swaps the bounds automatically to ensure valid behavior.
96 | */
97 | template <typename T>
98 | typename std::enable_if<std::is_arithmetic<T>::value, T>::type
99 | inline clamp(const T &value, const T &low, const T &high)
100 | {
101 | // Ensure the range [low, high] is valid; swap if necessary
102 | T validLow = low < high ? low : high;
103 | T validHigh = low < high ? high : low;
104 |
105 | // Clamp the value to the range [validLow, validHigh]
106 | if (value < validLow)
107 | return validLow;
108 | if (value > validHigh)
109 | return validHigh;
110 | return value;
111 | }
112 |
113 |
114 | /**
115 | * @brief Loads class names from a given file path.
116 | *
117 | * @param path Path to the file containing class names.
118 | * @return std::vector<std::string> Vector of class names.
119 | */
120 | std::vector<std::string> getClassNames(const std::string &path) {
121 | std::vector<std::string> classNames;
122 | std::ifstream infile(path);
123 |
124 | if (infile) {
125 | std::string line;
126 | while (getline(infile, line)) {
127 | // Remove carriage return if present (for Windows compatibility)
128 | if (!line.empty() && line.back() == '\r')
129 | line.pop_back();
130 | classNames.emplace_back(line);
131 | }
132 | } else {
133 | std::cerr << "ERROR: Failed to access class name path: " << path << std::endl;
134 | }
135 |
136 | DEBUG_PRINT("Loaded " << classNames.size() << " class names from " + path);
137 | return classNames;
138 | }
139 |
140 | /**
141 | * @brief Computes the product of elements in a vector.
142 | *
143 | * @param vector Vector of integers.
144 | * @return size_t Product of all elements.
145 | */
146 | size_t vectorProduct(const std::vector<int64_t> &vector) {
147 | return std::accumulate(vector.begin(), vector.end(), 1ull, std::multiplies<size_t>());
148 | }
149 |
150 |
151 | /**
152 | * @brief Resizes an image with letterboxing to maintain aspect ratio.
153 | *
154 | * @param image Input image.
155 | * @param outImage Output resized and padded image.
156 | * @param newShape Desired output size.
157 | * @param color Padding color (default is gray).
158 | * @param auto_ Automatically adjust padding to be multiple of stride.
159 | * @param scaleFill Whether to scale to fill the new shape without keeping aspect ratio.
160 | * @param scaleUp Whether to allow scaling up of the image.
161 | * @param stride Stride size for padding alignment.
162 | */
163 | inline void letterBox(const cv::Mat& image, cv::Mat& outImage,
164 | const cv::Size& newShape,
165 | const cv::Scalar& color = cv::Scalar(114, 114, 114),
166 | bool auto_ = true,
167 | bool scaleFill = false,
168 | bool scaleUp = true,
169 | int stride = 32) {
170 | // Calculate the scaling ratio to fit the image within the new shape
171 | float ratio = std::min(static_cast<float>(newShape.height) / image.rows,
172 | static_cast<float>(newShape.width) / image.cols);
173 |
174 | // Prevent scaling up if not allowed
175 | if (!scaleUp) {
176 | ratio = std::min(ratio, 1.0f);
177 | }
178 |
179 | // Calculate new dimensions after scaling
180 | int newUnpadW = static_cast<int>(std::round(image.cols * ratio));
181 | int newUnpadH = static_cast<int>(std::round(image.rows * ratio));
182 |
183 | // Calculate padding needed to reach the desired shape
184 | int dw = newShape.width - newUnpadW;
185 | int dh = newShape.height - newUnpadH;
186 |
187 | if (auto_) {
188 | // Ensure padding is a multiple of stride for model compatibility
189 | dw = (dw % stride) / 2;
190 | dh = (dh % stride) / 2;
191 | } else if (scaleFill) {
192 | // Scale to fill without maintaining aspect ratio
193 | newUnpadW = newShape.width;
194 | newUnpadH = newShape.height;
195 | ratio = std::min(static_cast<float>(newShape.width) / image.cols,
196 | static_cast<float>(newShape.height) / image.rows);
197 | dw = 0;
198 | dh = 0;
199 | } else {
200 | // Evenly distribute padding on both sides
201 | // Calculate separate padding for left/right and top/bottom to handle odd padding
202 | int padLeft = dw / 2;
203 | int padRight = dw - padLeft;
204 | int padTop = dh / 2;
205 | int padBottom = dh - padTop;
206 |
207 | // Resize the image if the new dimensions differ
208 | if (image.cols != newUnpadW || image.rows != newUnpadH) {
209 | cv::resize(image, outImage, cv::Size(newUnpadW, newUnpadH), 0, 0, cv::INTER_LINEAR);
210 | } else {
211 | // Avoid unnecessary copying if dimensions are the same
212 | outImage = image;
213 | }
214 |
215 | // Apply padding to reach the desired shape
216 | cv::copyMakeBorder(outImage, outImage, padTop, padBottom, padLeft, padRight, cv::BORDER_CONSTANT, color);
217 | return; // Exit early since padding is already applied
218 | }
219 |
220 | // Resize the image if the new dimensions differ
221 | if (image.cols != newUnpadW || image.rows != newUnpadH) {
222 | cv::resize(image, outImage, cv::Size(newUnpadW, newUnpadH), 0, 0, cv::INTER_LINEAR);
223 | } else {
224 | // Avoid unnecessary copying if dimensions are the same
225 | outImage = image;
226 | }
227 |
228 | // Calculate separate padding for left/right and top/bottom to handle odd padding
229 | int padLeft = dw / 2;
230 | int padRight = dw - padLeft;
231 | int padTop = dh / 2;
232 | int padBottom = dh - padTop;
233 |
234 | // Apply padding to reach the desired shape
235 | cv::copyMakeBorder(outImage, outImage, padTop, padBottom, padLeft, padRight, cv::BORDER_CONSTANT, color);
236 | }
237 |
238 | /**
239 | * @brief Scales detection coordinates back to the original image size.
240 | *
241 | * @param imageShape Shape of the resized image used for inference.
242 | * @param coords Detection bounding box to be scaled.
243 | * @param imageOriginalShape Original image size before resizing.
244 | * @param p_Clip Whether to clip the coordinates to the image boundaries.
245 | * @return BoundingBox Scaled bounding box.
246 | */
247 | BoundingBox scaleCoords(const cv::Size &imageShape, BoundingBox coords,
248 | const cv::Size &imageOriginalShape, bool p_Clip) {
249 | BoundingBox result;
250 | float gain = std::min(static_cast<float>(imageShape.height) / static_cast<float>(imageOriginalShape.height),
251 | static_cast<float>(imageShape.width) / static_cast<float>(imageOriginalShape.width));
252 |
253 | int padX = static_cast<int>(std::round((imageShape.width - imageOriginalShape.width * gain) / 2.0f));
254 | int padY = static_cast<int>(std::round((imageShape.height - imageOriginalShape.height * gain) / 2.0f));
255 |
256 | result.x = static_cast<int>(std::round((coords.x - padX) / gain));
257 | result.y = static_cast<int>(std::round((coords.y - padY) / gain));
258 | result.width = static_cast<int>(std::round(coords.width / gain));
259 | result.height = static_cast<int>(std::round(coords.height / gain));
260 |
261 | if (p_Clip) {
262 | result.x = utils::clamp(result.x, 0, imageOriginalShape.width);
263 | result.y = utils::clamp(result.y, 0, imageOriginalShape.height);
264 | result.width = utils::clamp(result.width, 0, imageOriginalShape.width - result.x);
265 | result.height = utils::clamp(result.height, 0, imageOriginalShape.height - result.y);
266 | }
267 | return result;
268 | }
269 |
270 | /**
271 | * @brief Performs Non-Maximum Suppression (NMS) on the bounding boxes.
272 | *
273 | * @param boundingBoxes Vector of bounding boxes.
274 | * @param scores Vector of confidence scores corresponding to each bounding box.
275 | * @param scoreThreshold Confidence threshold to filter boxes.
276 | * @param nmsThreshold IoU threshold for NMS.
277 | * @param indices Output vector of indices that survive NMS.
278 | */
279 | // Optimized Non-Maximum Suppression Function
280 | void NMSBoxes(const std::vector<BoundingBox>& boundingBoxes,
281 | const std::vector<float>& scores,
282 | float scoreThreshold,
283 | float nmsThreshold,
284 | std::vector<int>& indices)
285 | {
286 | indices.clear();
287 |
288 | const size_t numBoxes = boundingBoxes.size();
289 | if (numBoxes == 0) {
290 | DEBUG_PRINT("No bounding boxes to process in NMS");
291 | return;
292 | }
293 |
294 | // Step 1: Filter out boxes with scores below the threshold
295 | // and create a list of indices sorted by descending scores
296 | std::vector<int> sortedIndices;
297 | sortedIndices.reserve(numBoxes);
298 | for (size_t i = 0; i < numBoxes; ++i) {
299 | if (scores[i] >= scoreThreshold) {
300 | sortedIndices.push_back(static_cast<int>(i));
301 | }
302 | }
303 |
304 | // If no boxes remain after thresholding
305 | if (sortedIndices.empty()) {
306 | DEBUG_PRINT("No bounding boxes above score threshold");
307 | return;
308 | }
309 |
310 | // Sort the indices based on scores in descending order
311 | std::sort(sortedIndices.begin(), sortedIndices.end(),
312 | [&scores](int idx1, int idx2) {
313 | return scores[idx1] > scores[idx2];
314 | });
315 |
316 | // Step 2: Precompute the areas of all boxes
317 | std::vector<float> areas(numBoxes, 0.0f);
318 | for (size_t i = 0; i < numBoxes; ++i) {
319 | areas[i] = boundingBoxes[i].width * boundingBoxes[i].height;
320 | }
321 |
322 | // Step 3: Suppression mask to mark boxes that are suppressed
323 | std::vector<bool> suppressed(numBoxes, false);
324 |
325 | // Step 4: Iterate through the sorted list and suppress boxes with high IoU
326 | for (size_t i = 0; i < sortedIndices.size(); ++i) {
327 | int currentIdx = sortedIndices[i];
328 | if (suppressed[currentIdx]) {
329 | continue;
330 | }
331 |
332 | // Select the current box as a valid detection
333 | indices.push_back(currentIdx);
334 |
335 | const BoundingBox& currentBox = boundingBoxes[currentIdx];
336 | const float x1_max = currentBox.x;
337 | const float y1_max = currentBox.y;
338 | const float x2_max = currentBox.x + currentBox.width;
339 | const float y2_max = currentBox.y + currentBox.height;
340 | const float area_current = areas[currentIdx];
341 |
342 | // Compare IoU of the current box with the rest
343 | for (size_t j = i + 1; j < sortedIndices.size(); ++j) {
344 | int compareIdx = sortedIndices[j];
345 | if (suppressed[compareIdx]) {
346 | continue;
347 | }
348 |
349 | const BoundingBox& compareBox = boundingBoxes[compareIdx];
350 | const float x1 = std::max(x1_max, static_cast<float>(compareBox.x));
351 | const float y1 = std::max(y1_max, static_cast<float>(compareBox.y));
352 | const float x2 = std::min(x2_max, static_cast<float>(compareBox.x + compareBox.width));
353 | const float y2 = std::min(y2_max, static_cast<float>(compareBox.y + compareBox.height));
354 |
355 | const float interWidth = x2 - x1;
356 | const float interHeight = y2 - y1;
357 |
358 | if (interWidth <= 0 || interHeight <= 0) {
359 | continue;
360 | }
361 |
362 | const float intersection = interWidth * interHeight;
363 | const float unionArea = area_current + areas[compareIdx] - intersection;
364 | const float iou = (unionArea > 0.0f) ? (intersection / unionArea) : 0.0f;
365 |
366 | if (iou > nmsThreshold) {
367 | suppressed[compareIdx] = true;
368 | }
369 | }
370 | }
371 |
372 | DEBUG_PRINT("NMS completed with " + std::to_string(indices.size()) + " indices remaining");
373 | }
374 |
375 |
376 | /**
377 | * @brief Generates a vector of colors for each class name.
378 | *
379 | * @param classNames Vector of class names.
380 | * @param seed Seed for random color generation to ensure reproducibility.
381 | * @return std::vector<cv::Scalar> Vector of colors.
382 | */
383 | inline std::vector<cv::Scalar> generateColors(const std::vector<std::string> &classNames, int seed = 42) {
384 | // Static cache to store colors based on class names to avoid regenerating
385 | static std::unordered_map<size_t, std::vector<cv::Scalar>> colorCache;
386 |
387 | // Compute a hash key based on class names to identify unique class configurations
388 | size_t hashKey = 0;
389 | for (const auto& name : classNames) {
390 | hashKey ^= std::hash<std::string>{}(name) + 0x9e3779b9 + (hashKey << 6) + (hashKey >> 2);
391 | }
392 |
393 | // Check if colors for this class configuration are already cached
394 | auto it = colorCache.find(hashKey);
395 | if (it != colorCache.end()) {
396 | return it->second;
397 | }
398 |
399 | // Generate unique random colors for each class
400 | std::vector<cv::Scalar> colors;
401 | colors.reserve(classNames.size());
402 |
403 | std::mt19937 rng(seed); // Initialize random number generator with fixed seed
404 | std::uniform_int_distribution<int> uni(0, 255); // Define distribution for color values
405 |
406 | for (size_t i = 0; i < classNames.size(); ++i) {
407 | colors.emplace_back(cv::Scalar(uni(rng), uni(rng), uni(rng))); // Generate random BGR color
408 | }
409 |
410 | // Cache the generated colors for future use
411 | colorCache.emplace(hashKey, colors);
412 |
413 | return colorCache[hashKey];
414 | }
415 |
416 | /**
417 | * @brief Draws bounding boxes and labels on the image based on detections.
418 | *
419 | * @param image Image on which to draw.
420 | * @param detections Vector of detections.
421 | * @param classNames Vector of class names corresponding to object IDs.
422 | * @param colors Vector of colors for each class.
423 | */
424 | inline void drawBoundingBox(cv::Mat &image, const std::vector<Detection> &detections,
425 | const std::vector<std::string> &classNames, const std::vector<cv::Scalar> &colors) {
426 | // Iterate through each detection to draw bounding boxes and labels
427 | for (const auto& detection : detections) {
428 | // Skip detections below the confidence threshold
429 | if (detection.conf <= CONFIDENCE_THRESHOLD)
430 | continue;
431 |
432 | // Ensure the object ID is within valid range
433 | if (detection.classId < 0 || static_cast<size_t>(detection.classId) >= classNames.size())
434 | continue;
435 |
436 | // Select color based on object ID for consistent coloring
437 | const cv::Scalar& color = colors[detection.classId % colors.size()];
438 |
439 | // Draw the bounding box rectangle
440 | cv::rectangle(image, cv::Point(detection.box.x, detection.box.y),
441 | cv::Point(detection.box.x + detection.box.width, detection.box.y + detection.box.height),
442 | color, 2, cv::LINE_AA);
443 |
444 | // Prepare label text with class name and confidence percentage
445 | std::string label = classNames[detection.classId] + ": " + std::to_string(static_cast<int>(detection.conf * 100)) + "%";
446 |
447 | // Define text properties for labels
448 | int fontFace = cv::FONT_HERSHEY_SIMPLEX;
449 | double fontScale = std::min(image.rows, image.cols) * 0.0008;
450 | const int thickness = std::max(1, static_cast<int>(std::min(image.rows, image.cols) * 0.002));
451 | int baseline = 0;
452 |
453 | // Calculate text size for background rectangles
454 | cv::Size textSize = cv::getTextSize(label, fontFace, fontScale, thickness, &baseline);
455 |
456 | // Define positions for the label
457 | int labelY = std::max(detection.box.y, textSize.height + 5);
458 | cv::Point labelTopLeft(detection.box.x, labelY - textSize.height - 5);
459 | cv::Point labelBottomRight(detection.box.x + textSize.width + 5, labelY + baseline - 5);
460 |
461 | // Draw background rectangle for label
462 | cv::rectangle(image, labelTopLeft, labelBottomRight, color, cv::FILLED);
463 |
464 | // Put label text
465 | cv::putText(image, label, cv::Point(detection.box.x + 2, labelY - 2), fontFace, fontScale, cv::Scalar(255, 255, 255), thickness, cv::LINE_AA);
466 | }
467 | }
468 |
469 | /**
470 | * @brief Draws bounding boxes and semi-transparent masks on the image based on detections.
471 | *
472 | * @param image Image on which to draw.
473 | * @param detections Vector of detections.
474 | * @param classNames Vector of class names corresponding to object IDs.
475 | * @param classColors Vector of colors for each class.
476 | * @param maskAlpha Alpha value for the mask transparency.
477 | */
478 | inline void drawBoundingBoxMask(cv::Mat &image, const std::vector<Detection> &detections,
479 | const std::vector<std::string> &classNames, const std::vector<cv::Scalar> &classColors,
480 | float maskAlpha = 0.4f) {
481 | // Validate input image
482 | if (image.empty()) {
483 | std::cerr << "ERROR: Empty image provided to drawBoundingBoxMask." << std::endl;
484 | return;
485 | }
486 |
487 | const int imgHeight = image.rows;
488 | const int imgWidth = image.cols;
489 |
490 | // Precompute dynamic font size and thickness based on image dimensions
491 | const double fontSize = std::min(imgHeight, imgWidth) * 0.0006;
492 | const int textThickness = std::max(1, static_cast<int>(std::min(imgHeight, imgWidth) * 0.001));
493 |
494 | // Create a mask image for blending (initialized to zero)
495 | cv::Mat maskImage(image.size(), image.type(), cv::Scalar::all(0));
496 |
497 | // Pre-filter detections to include only those above the confidence threshold and with valid class IDs
498 | std::vector<const Detection*> filteredDetections;
499 | for (const auto& detection : detections) {
500 | if (detection.conf > CONFIDENCE_THRESHOLD &&
501 | detection.classId >= 0 &&
502 | static_cast<size_t>(detection.classId) < classNames.size()) {
503 | filteredDetections.emplace_back(&detection);
504 | }
505 | }
506 |
507 | // Draw filled rectangles on the mask image for the semi-transparent overlay
508 | for (const auto* detection : filteredDetections) {
509 | cv::Rect box(detection->box.x, detection->box.y, detection->box.width, detection->box.height);
510 | const cv::Scalar &color = classColors[detection->classId];
511 | cv::rectangle(maskImage, box, color, cv::FILLED);
512 | }
513 |
514 | // Blend the maskImage with the original image to apply the semi-transparent masks
515 | cv::addWeighted(maskImage, maskAlpha, image, 1.0f, 0, image);
516 |
517 | // Draw bounding boxes and labels on the original image
518 | for (const auto* detection : filteredDetections) {
519 | cv::Rect box(detection->box.x, detection->box.y, detection->box.width, detection->box.height);
520 | const cv::Scalar &color = classColors[detection->classId];
521 | cv::rectangle(image, box, color, 2, cv::LINE_AA);
522 |
523 | std::string label = classNames[detection->classId] + ": " + std::to_string(static_cast<int>(detection->conf * 100)) + "%";
524 | int baseLine = 0;
525 | cv::Size labelSize = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, fontSize, textThickness, &baseLine);
526 |
527 | int labelY = std::max(detection->box.y, labelSize.height + 5);
528 | cv::Point labelTopLeft(detection->box.x, labelY - labelSize.height - 5);
529 | cv::Point labelBottomRight(detection->box.x + labelSize.width + 5, labelY + baseLine - 5);
530 |
531 | // Draw background rectangle for label
532 | cv::rectangle(image, labelTopLeft, labelBottomRight, color, cv::FILLED);
533 |
534 | // Put label text
535 | cv::putText(image, label, cv::Point(detection->box.x + 2, labelY - 2), cv::FONT_HERSHEY_SIMPLEX, fontSize, cv::Scalar(255, 255, 255), textThickness, cv::LINE_AA);
536 | }
537 |
538 | DEBUG_PRINT("Bounding boxes and masks drawn on image.");
539 | }
540 |
541 | };
542 |
543 | /**
544 | * @brief YOLO11Detector class handles loading the YOLO model, preprocessing images, running inference, and postprocessing results.
545 | */
546 | class YOLO11Detector {
547 | public:
548 | /**
549 | * @brief Constructor to initialize the YOLO detector with model and label paths.
550 | *
551 | * @param modelPath Path to the ONNX model file.
552 | * @param labelsPath Path to the file containing class labels.
553 | * @param useGPU Whether to use GPU for inference (default is false).
554 | */
555 | YOLO11Detector(const std::string &modelPath, const std::string &labelsPath, bool useGPU = false);
556 |
557 | /**
558 | * @brief Runs detection on the provided image.
559 | *
560 | * @param image Input image for detection.
561 | * @param confThreshold Confidence threshold to filter detections (default is 0.4).
562 | * @param iouThreshold IoU threshold for Non-Maximum Suppression (default is 0.45).
563 | * @return std::vector<Detection> Vector of detections.
564 | */
565 | std::vector<Detection> detect(const cv::Mat &image, float confThreshold = 0.4f, float iouThreshold = 0.45f);
566 |
567 | /**
568 | * @brief Draws bounding boxes on the image based on detections.
569 | *
570 | * @param image Image on which to draw.
571 | * @param detections Vector of detections.
572 | */
573 | void drawBoundingBox(cv::Mat &image, const std::vector<Detection> &detections) const {
574 | utils::drawBoundingBox(image, detections, classNames, classColors);
575 | }
576 |
577 | /**
578 | * @brief Draws bounding boxes and semi-transparent masks on the image based on detections.
579 | *
580 | * @param image Image on which to draw.
581 | * @param detections Vector of detections.
582 | * @param maskAlpha Alpha value for mask transparency (default is 0.4).
583 | */
584 | void drawBoundingBoxMask(cv::Mat &image, const std::vector<Detection> &detections, float maskAlpha = 0.4f) const {
585 | utils::drawBoundingBoxMask(image, detections, classNames, classColors, maskAlpha);
586 | }
587 |
588 | private:
589 | Ort::Env env{nullptr}; // ONNX Runtime environment
590 | Ort::SessionOptions sessionOptions{nullptr}; // Session options for ONNX Runtime
591 | Ort::Session session{nullptr}; // ONNX Runtime session for running inference
592 | bool isDynamicInputShape{}; // Flag indicating if input shape is dynamic
593 | cv::Size inputImageShape; // Expected input image shape for the model
594 |
595 | // Vectors to hold allocated input and output node names
596 | std::vector<Ort::AllocatedStringPtr> inputNodeNameAllocatedStrings;
597 | std::vector<const char *> inputNames;
598 | std::vector<Ort::AllocatedStringPtr> outputNodeNameAllocatedStrings;
599 | std::vector<const char *> outputNames;
600 |
601 | size_t numInputNodes, numOutputNodes; // Number of input and output nodes in the model
602 |
603 | std::vector<std::string> classNames; // Vector of class names loaded from file
604 | std::vector<cv::Scalar> classColors; // Vector of colors for each class
605 |
606 | /**
607 | * @brief Preprocesses the input image for model inference.
608 | *
609 | * @param image Input image.
610 | * @param blob Reference to pointer where preprocessed data will be stored.
611 | * @param inputTensorShape Reference to vector representing input tensor shape.
612 | * @return cv::Mat Resized image after preprocessing.
613 | */
614 |     cv::Mat preprocess(const cv::Mat &image, float *&blob, std::vector<int64_t> &inputTensorShape);
615 |
616 | /**
617 | * @brief Postprocesses the model output to extract detections.
618 | *
619 | * @param originalImageSize Size of the original input image.
620 | * @param resizedImageShape Size of the image after preprocessing.
621 | * @param outputTensors Vector of output tensors from the model.
622 | * @param confThreshold Confidence threshold to filter detections.
623 | * @param iouThreshold IoU threshold for Non-Maximum Suppression.
624 |      * @return std::vector<Detection> Vector of detections.
625 | */
626 |     std::vector<Detection> postprocess(const cv::Size &originalImageSize, const cv::Size &resizedImageShape,
627 |                                         const std::vector<Ort::Value> &outputTensors,
628 | float confThreshold, float iouThreshold);
629 |
630 | };
631 |
632 | // Implementation of YOLO11Detector constructor
633 | YOLO11Detector::YOLO11Detector(const std::string &modelPath, const std::string &labelsPath, bool useGPU) {
634 | // Initialize ONNX Runtime environment with warning level
635 | env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "ONNX_DETECTION");
636 | sessionOptions = Ort::SessionOptions();
637 |
638 | // Set number of intra-op threads for parallelism
639 |     sessionOptions.SetIntraOpNumThreads(std::min(6, static_cast<int>(std::thread::hardware_concurrency())));
640 | sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
641 |
642 | // Retrieve available execution providers (e.g., CPU, CUDA)
643 |     std::vector<std::string> availableProviders = Ort::GetAvailableProviders();
644 | auto cudaAvailable = std::find(availableProviders.begin(), availableProviders.end(), "CUDAExecutionProvider");
645 | OrtCUDAProviderOptions cudaOption;
646 |
647 | // Configure session options based on whether GPU is to be used and available
648 | if (useGPU && cudaAvailable != availableProviders.end()) {
649 | std::cout << "Inference device: GPU" << std::endl;
650 | sessionOptions.AppendExecutionProvider_CUDA(cudaOption); // Append CUDA execution provider
651 | } else {
652 | if (useGPU) {
653 | std::cout << "GPU is not supported by your ONNXRuntime build. Fallback to CPU." << std::endl;
654 | }
655 | std::cout << "Inference device: CPU" << std::endl;
656 | }
657 |
658 | // Load the ONNX model into the session
659 | #ifdef _WIN32
660 | std::wstring w_modelPath(modelPath.begin(), modelPath.end());
661 | session = Ort::Session(env, w_modelPath.c_str(), sessionOptions);
662 | #else
663 | session = Ort::Session(env, modelPath.c_str(), sessionOptions);
664 | #endif
665 |
666 | Ort::AllocatorWithDefaultOptions allocator;
667 |
668 | // Retrieve input tensor shape information
669 | Ort::TypeInfo inputTypeInfo = session.GetInputTypeInfo(0);
670 |     std::vector<int64_t> inputTensorShapeVec = inputTypeInfo.GetTensorTypeAndShapeInfo().GetShape();
671 | isDynamicInputShape = (inputTensorShapeVec.size() >= 4) && (inputTensorShapeVec[2] == -1 && inputTensorShapeVec[3] == -1); // Check for dynamic dimensions
672 |
673 | // Allocate and store input node names
674 | auto input_name = session.GetInputNameAllocated(0, allocator);
675 | inputNodeNameAllocatedStrings.push_back(std::move(input_name));
676 | inputNames.push_back(inputNodeNameAllocatedStrings.back().get());
677 |
678 | // Allocate and store output node names
679 | auto output_name = session.GetOutputNameAllocated(0, allocator);
680 | outputNodeNameAllocatedStrings.push_back(std::move(output_name));
681 | outputNames.push_back(outputNodeNameAllocatedStrings.back().get());
682 |
683 | // Set the expected input image shape based on the model's input tensor
684 | if (inputTensorShapeVec.size() >= 4) {
685 |         inputImageShape = cv::Size(static_cast<int>(inputTensorShapeVec[3]), static_cast<int>(inputTensorShapeVec[2]));
686 | } else {
687 | throw std::runtime_error("Invalid input tensor shape.");
688 | }
689 |
690 | // Get the number of input and output nodes
691 | numInputNodes = session.GetInputCount();
692 | numOutputNodes = session.GetOutputCount();
693 |
694 | // Load class names and generate corresponding colors
695 | classNames = utils::getClassNames(labelsPath);
696 | classColors = utils::generateColors(classNames);
697 |
698 | std::cout << "Model loaded successfully with " << numInputNodes << " input nodes and " << numOutputNodes << " output nodes." << std::endl;
699 | }
700 |
701 | // Preprocess function implementation
702 | cv::Mat YOLO11Detector::preprocess(const cv::Mat &image, float *&blob, std::vector<int64_t> &inputTensorShape) {
703 | ScopedTimer timer("preprocessing");
704 |
705 | cv::Mat resizedImage;
706 | // Resize and pad the image using letterBox utility
707 | utils::letterBox(image, resizedImage, inputImageShape, cv::Scalar(114, 114, 114), isDynamicInputShape, false, true, 32);
708 |
709 | // Convert BGR to RGB (YOLOv11 expects RGB input)
710 | cv::Mat rgbImage;
711 | cv::cvtColor(resizedImage, rgbImage, cv::COLOR_BGR2RGB);
712 |
713 | // YOLOv11 normalization: Convert to float, normalize to [0, 1]
714 | rgbImage.convertTo(rgbImage, CV_32FC3, 1.0f/255.0f);
715 |
716 | // Allocate memory for the image blob in CHW format
717 | blob = new float[rgbImage.cols * rgbImage.rows * rgbImage.channels()];
718 |
719 | // Split the image into separate channels and store in the blob
720 |     std::vector<cv::Mat> chw(rgbImage.channels());
721 | for (int i = 0; i < rgbImage.channels(); ++i) {
722 | chw[i] = cv::Mat(rgbImage.rows, rgbImage.cols, CV_32FC1, blob + i * rgbImage.cols * rgbImage.rows);
723 | }
724 | cv::split(rgbImage, chw); // Split channels into the blob
725 |
726 | DEBUG_PRINT("Preprocessing completed with RGB conversion");
727 |
728 | return rgbImage;
729 | }
730 |
731 | // Postprocess function to convert raw model output into detections
732 | std::vector<Detection> YOLO11Detector::postprocess(
733 | const cv::Size &originalImageSize,
734 | const cv::Size &resizedImageShape,
735 |     const std::vector<Ort::Value> &outputTensors,
736 | float confThreshold,
737 | float iouThreshold
738 | ) {
739 | ScopedTimer timer("postprocessing"); // Measure postprocessing time
740 |
741 |     std::vector<Detection> detections;
742 |     const float* rawOutput = outputTensors[0].GetTensorData<float>(); // Extract raw output data from the first output tensor
743 |     const std::vector<int64_t> outputShape = outputTensors[0].GetTensorTypeAndShapeInfo().GetShape();
744 |
745 | // Determine the number of features and detections
746 | const size_t num_features = outputShape[1];
747 | const size_t num_detections = outputShape[2];
748 |
749 | // Early exit if no detections
750 | if (num_detections == 0) {
751 | return detections;
752 | }
753 |
754 | // Calculate number of classes based on output shape
755 |     const int numClasses = static_cast<int>(num_features) - 4;
756 | if (numClasses <= 0) {
757 | // Invalid number of classes
758 | return detections;
759 | }
760 |
761 | // Reserve memory for efficient appending
762 |     std::vector<BoundingBox> boxes;
763 |     boxes.reserve(num_detections);
764 |     std::vector<float> confs;
765 |     confs.reserve(num_detections);
766 |     std::vector<int> classIds;
767 |     classIds.reserve(num_detections);
768 |     std::vector<BoundingBox> nms_boxes;
769 | nms_boxes.reserve(num_detections);
770 |
771 | // Constants for indexing
772 | const float* ptr = rawOutput;
773 |
774 | for (size_t d = 0; d < num_detections; ++d) {
775 | // Extract bounding box coordinates (center x, center y, width, height)
776 | float centerX = ptr[0 * num_detections + d];
777 | float centerY = ptr[1 * num_detections + d];
778 | float width = ptr[2 * num_detections + d];
779 | float height = ptr[3 * num_detections + d];
780 |
781 | // Find class with the highest confidence score
782 | int classId = -1;
783 | float maxScore = -FLT_MAX;
784 | for (int c = 0; c < numClasses; ++c) {
785 | const float score = ptr[d + (4 + c) * num_detections];
786 | if (score > maxScore) {
787 | maxScore = score;
788 | classId = c;
789 | }
790 | }
791 |
792 | // Proceed only if confidence exceeds threshold
793 | if (maxScore > confThreshold) {
794 | // Convert center coordinates to top-left (x1, y1)
795 | float left = centerX - width / 2.0f;
796 | float top = centerY - height / 2.0f;
797 |
798 | // Scale to original image size
799 | BoundingBox scaledBox = utils::scaleCoords(
800 | resizedImageShape,
801 | BoundingBox(left, top, width, height),
802 | originalImageSize,
803 | true
804 | );
805 |
806 | // Round coordinates for integer pixel positions
807 | BoundingBox roundedBox;
808 | roundedBox.x = std::round(scaledBox.x);
809 | roundedBox.y = std::round(scaledBox.y);
810 | roundedBox.width = std::round(scaledBox.width);
811 | roundedBox.height = std::round(scaledBox.height);
812 |
813 | // Adjust NMS box coordinates to prevent overlap between classes
814 | BoundingBox nmsBox = roundedBox;
815 | nmsBox.x += classId * 7680; // Arbitrary offset to differentiate classes
816 | nmsBox.y += classId * 7680;
817 |
818 | // Add to respective containers
819 | nms_boxes.emplace_back(nmsBox);
820 | boxes.emplace_back(roundedBox);
821 | confs.emplace_back(maxScore);
822 | classIds.emplace_back(classId);
823 | }
824 | }
825 |
826 | // Apply Non-Maximum Suppression (NMS) to eliminate redundant detections
827 |     std::vector<int> indices;
828 | utils::NMSBoxes(nms_boxes, confs, confThreshold, iouThreshold, indices);
829 |
830 | // Collect filtered detections into the result vector
831 | detections.reserve(indices.size());
832 | for (const int idx : indices) {
833 | detections.emplace_back(Detection{
834 | boxes[idx], // Bounding box
835 | confs[idx], // Confidence score
836 | classIds[idx] // Class ID
837 | });
838 | }
839 |
840 | DEBUG_PRINT("Postprocessing completed") // Debug log for completion
841 |
842 | return detections;
843 | }
844 |
845 | // Detect function implementation
846 | std::vector<Detection> YOLO11Detector::detect(const cv::Mat& image, float confThreshold, float iouThreshold) {
847 | ScopedTimer timer("Overall detection");
848 |
849 | // Check for empty images
850 | if (image.empty()) {
851 | std::cerr << "Error: Empty image provided to detector" << std::endl;
852 | return {};
853 | }
854 |
855 | float* blobPtr = nullptr; // Pointer to hold preprocessed image data
856 | // Define the shape of the input tensor (batch size, channels, height, width)
857 |     std::vector<int64_t> inputTensorShape = {1, 3, inputImageShape.height, inputImageShape.width};
858 |
859 | // Preprocess the image and obtain a pointer to the blob
860 | cv::Mat preprocessedImage = preprocess(image, blobPtr, inputTensorShape);
861 |
862 | // Compute the total number of elements in the input tensor
863 | size_t inputTensorSize = utils::vectorProduct(inputTensorShape);
864 |
865 | // Create a vector from the blob data for ONNX Runtime input
866 |     std::vector<float> inputTensorValues(blobPtr, blobPtr + inputTensorSize);
867 |
868 | delete[] blobPtr; // Free the allocated memory for the blob
869 |
870 | // Create an Ort memory info object (can be cached if used repeatedly)
871 | static Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
872 |
873 | // Create input tensor object using the preprocessed data
874 |     Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
875 | memoryInfo,
876 | inputTensorValues.data(),
877 | inputTensorSize,
878 | inputTensorShape.data(),
879 | inputTensorShape.size()
880 | );
881 |
882 | // Run the inference session with the input tensor and retrieve output tensors
883 |     std::vector<Ort::Value> outputTensors = session.Run(
884 | Ort::RunOptions{nullptr},
885 | inputNames.data(),
886 | &inputTensor,
887 | numInputNodes,
888 | outputNames.data(),
889 | numOutputNodes
890 | );
891 |
892 | // Determine the resized image shape based on input tensor shape
893 |     cv::Size resizedImageShape(static_cast<int>(inputTensorShape[3]), static_cast<int>(inputTensorShape[2]));
894 |
895 | // Postprocess the output tensors to obtain detections
896 |     std::vector<Detection> detections = postprocess(image.size(), resizedImageShape, outputTensors, confThreshold, iouThreshold);
897 |
898 | return detections; // Return the vector of detections
899 | }
--------------------------------------------------------------------------------
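Note: the following is a minimal usage sketch of the YOLO11Detector API declared above. It is not a file from this repository; the include path, model/label/image file names, and output path are placeholders, and the thresholds simply repeat the defaults in the header.

    // usage_sketch.cpp (illustrative only)
    #include <opencv2/opencv.hpp>
    #include "ia/YOLO11.hpp"   // assumes src/ is on the include path

    int main() {
        // Placeholder paths; substitute your own ONNX model, labels file, and image.
        YOLO11Detector detector("best.onnx", "classes.txt", /*useGPU=*/false);

        cv::Mat image = cv::imread("image_2.jpg");
        if (image.empty()) return 1;

        // Default thresholds from the header: confidence 0.4, IoU 0.45.
        std::vector<Detection> detections = detector.detect(image, 0.4f, 0.45f);

        // Draw plain boxes; drawBoundingBoxMask() adds semi-transparent fills instead.
        detector.drawBoundingBox(image, detections);
        cv::imwrite("output.jpg", image);
        return 0;
    }
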
/src/ia/tools/Config.hpp:
--------------------------------------------------------------------------------
1 | // Config.hpp
2 | #ifndef CONFIG_HPP
3 | #define CONFIG_HPP
4 |
5 | // Enable debug messages to help troubleshoot
6 | #define DEBUG_MODE
7 |
8 | // Enable performance timing
9 | #define TIMING_MODE
10 |
11 | #endif // CONFIG_HPP
--------------------------------------------------------------------------------
/src/ia/tools/Debug.hpp:
--------------------------------------------------------------------------------
1 | // Debug.hpp
2 | #ifndef DEBUG_HPP
3 | #define DEBUG_HPP
4 |
5 |
6 | // Include necessary libraries
7 | #include <iostream>
8 | #include "./tools/Config.hpp" // Include the config file to access the flags
9 |
10 | #ifdef DEBUG_MODE
11 | #define DEBUG_PRINT(x) std::cout << x << std::endl;
12 | #else
13 | #define DEBUG_PRINT(x)
14 | #endif
15 |
16 | #endif // DEBUG_HPP
17 |
--------------------------------------------------------------------------------
/src/ia/tools/ScopedTimer.hpp:
--------------------------------------------------------------------------------
1 | // ScopedTimer.hpp
2 | #ifndef SCOPEDTIMER_HPP
3 | #define SCOPEDTIMER_HPP
4 |
5 | #include <chrono>
6 | #include <string>
7 | #include <iostream>
8 | #include "./tools/Config.hpp" // Include the config file to access the flags
9 |
10 | #ifdef TIMING_MODE
11 | class ScopedTimer {
12 | public:
13 | /**
14 | * @brief Constructs a ScopedTimer to measure the duration of a named code block.
15 | * @param name The name of the code block being timed.
16 | */
17 | ScopedTimer(const std::string &name)
18 | : func_name(name), start(std::chrono::high_resolution_clock::now()) {}
19 |
20 | /**
21 | * @brief Destructor that calculates and prints the elapsed time.
22 | */
23 | ~ScopedTimer() {
24 | auto stop = std::chrono::high_resolution_clock::now();
25 |         std::chrono::duration<double, std::milli> duration = stop - start;
26 | std::cout << func_name << " took " << duration.count() << " milliseconds." << std::endl;
27 | }
28 |
29 | private:
30 | std::string func_name; ///< The name of the timed function.
31 |     std::chrono::time_point<std::chrono::high_resolution_clock> start; ///< Start time point.
32 | };
33 | #else
34 | class ScopedTimer {
35 | public:
36 | ScopedTimer(const std::string &name) {}
37 | ~ScopedTimer() {}
38 | };
39 | #endif // TIMING_MODE
40 |
41 | #endif // SCOPEDTIMER_HPP
--------------------------------------------------------------------------------
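Note: the three headers above (Config.hpp, Debug.hpp, ScopedTimer.hpp) give compile-time-switchable logging and timing. The sketch below shows the intended usage pattern; it assumes DEBUG_MODE and TIMING_MODE are defined as in Config.hpp and that src/ia is on the include path, as the relative includes in these headers imply.

    #include "ia/tools/Debug.hpp"
    #include "ia/tools/ScopedTimer.hpp"

    void processFrame() {
        ScopedTimer timer("processFrame");   // RAII: prints elapsed milliseconds when the scope exits (TIMING_MODE)
        DEBUG_PRINT("frame received");       // compiles to nothing when DEBUG_MODE is undefined
        // ... actual work ...
    }   // destructor logs e.g. "processFrame took 12.3 milliseconds."
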
/src/image_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/image_2.jpg
--------------------------------------------------------------------------------
/src/input.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/input.mov
--------------------------------------------------------------------------------
/src/kotlin/AndroidManifest.xml:
--------------------------------------------------------------------------------
1 |
3 |
4 |
12 |
13 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
--------------------------------------------------------------------------------
/src/kotlin/Application.kt:
--------------------------------------------------------------------------------
1 | package com.yolov11kotlin
2 |
3 | import android.app.Application
4 | import android.util.Log
5 | import org.opencv.android.OpenCVLoader
6 |
7 | /**
8 | * Application class that initializes OpenCV at app startup
9 | */
10 | class YoloApplication : Application() {
11 |
12 | override fun onCreate() {
13 | super.onCreate()
14 |
15 | // Initialize OpenCV with static initialization
16 | try {
17 | if (!OpenCVLoader.initDebug()) {
18 | Log.e(TAG, "OpenCV initialization failed")
19 | } else {
20 | Log.i(TAG, "OpenCV initialization succeeded")
21 | // Load the native library
22 | System.loadLibrary("opencv_java4")
23 | Log.i(TAG, "OpenCV native library loaded")
24 | }
25 | } catch (e: UnsatisfiedLinkError) {
26 | Log.e(TAG, "Failed to load OpenCV native library", e)
27 | } catch (e: Exception) {
28 | Log.e(TAG, "Error during OpenCV initialization", e)
29 | }
30 | }
31 |
32 | companion object {
33 | private const val TAG = "YoloApplication"
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/src/kotlin/BuildConfig.kt:
--------------------------------------------------------------------------------
1 | package com.yolov11kotlin
2 |
3 | /**
4 | * Build configuration flags for debugging and performance measurement
5 | * Matches the C++ configuration in Config.hpp
6 | */
7 | object BuildConfig {
8 | // Whether to enable debug logging (matches DEBUG_MODE in C++)
9 | const val DEBUG = true
10 |
11 | // Whether to enable performance timing measurements (matches TIMING_MODE in C++)
12 | const val TIMING_MODE = true
13 | }
14 |
--------------------------------------------------------------------------------
/src/kotlin/DebugUtils.kt:
--------------------------------------------------------------------------------
1 | package com.yolov11kotlin
2 |
3 | import android.util.Log
4 |
5 | /**
6 | * Debug utility functions that match the functionality from C++ implementation
7 | */
8 | object DebugUtils {
9 | private const val TAG = "YOLO11Debug"
10 |
11 | /**
12 | * Prints a debug message if DEBUG mode is enabled in BuildConfig
13 | */
14 | fun debug(message: String) {
15 | if (BuildConfig.DEBUG) {
16 | Log.d(TAG, message)
17 | }
18 | }
19 |
20 | /**
21 | * Prints an error message regardless of debug mode
22 | */
23 | fun error(message: String, throwable: Throwable? = null) {
24 | if (throwable != null) {
25 | Log.e(TAG, message, throwable)
26 | } else {
27 | Log.e(TAG, message)
28 | }
29 | }
30 |
31 | /**
32 | * Prints verbose information about model and inference
33 | */
34 | fun logModelInfo(modelPath: String, inputWidth: Int, inputHeight: Int, isQuantized: Boolean, numClasses: Int) {
35 | if (BuildConfig.DEBUG) {
36 | Log.d(TAG, "Model: $modelPath")
37 | Log.d(TAG, "Input dimensions: ${inputWidth}x${inputHeight}")
38 | Log.d(TAG, "Quantized: $isQuantized")
39 | Log.d(TAG, "Number of classes: $numClasses")
40 | }
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/src/kotlin/MainActivity.kt:
--------------------------------------------------------------------------------
1 | package com.example.opencv_tutorial
2 |
3 | import android.app.ActivityManager
4 | import android.content.Context
5 | import android.graphics.Bitmap
6 | import android.graphics.BitmapFactory
7 | import android.os.Bundle
8 | import android.util.Log
9 | import android.widget.ImageView
10 | import android.widget.TextView
11 | import androidx.appcompat.app.AppCompatActivity
12 | import org.opencv.android.OpenCVLoader
13 | import java.io.IOException
14 | import java.util.concurrent.Executors
15 | import android.os.SystemClock
16 | import android.graphics.Matrix
17 | import android.os.Build
18 | import androidx.core.content.ContextCompat
19 | import org.tensorflow.lite.gpu.CompatibilityList
20 | import java.util.Locale
21 |
22 | class MainActivity : AppCompatActivity() {
23 |
24 | // Views for UI
25 | private lateinit var imageView: ImageView
26 | private lateinit var resultText: TextView
27 |
28 | // YOLOv11 detector instance
29 | private lateinit var yoloDetector: YOLO11Detector
30 |
31 | // Background thread for async loading
32 | private val backgroundExecutor = Executors.newSingleThreadExecutor()
33 |
34 | override fun onCreate(savedInstanceState: Bundle?) {
35 | super.onCreate(savedInstanceState)
36 | setContentView(R.layout.activity_main)
37 |
38 | // Initialize UI components
39 | imageView = findViewById(R.id.imageView)
40 | resultText = findViewById(R.id.resultText)
41 |
42 | // Initialize OpenCV and proceed with detection in background
43 | initializeOpenCVAndDetector()
44 | }
45 |
46 | private fun initializeOpenCVAndDetector() {
47 | resultText.text = "Initializing OpenCV..."
48 |
49 | backgroundExecutor.execute {
50 | try {
51 | // Use static initialization for OpenCV
52 | if (!OpenCVLoader.initDebug()) {
53 | Log.e(TAG, "Unable to load OpenCV")
54 | runOnUiThread {
55 | resultText.text = "Error: OpenCV initialization failed."
56 | }
57 | return@execute
58 | }
59 |
60 | // Load native OpenCV library
61 | try {
62 | System.loadLibrary("opencv_java4")
63 | Log.i(TAG, "OpenCV loaded successfully")
64 |
65 | // Now proceed with detector initialization
66 | initializeDetectorAndProcess()
67 | } catch (e: UnsatisfiedLinkError) {
68 | Log.e(TAG, "Unable to load OpenCV native library", e)
69 | runOnUiThread {
70 | resultText.text = "Error: OpenCV native library failed to load.\nError: ${e.message}"
71 | }
72 | } catch (e: Exception) {
73 | Log.e(TAG, "Error during OpenCV initialization", e)
74 | runOnUiThread {
75 | resultText.text = "Error: ${e.message}"
76 | }
77 | }
78 | } catch (e: Exception) {
79 | Log.e(TAG, "Unexpected error during initialization", e)
80 | runOnUiThread {
81 | resultText.text = "Unexpected error: ${e.message}"
82 | }
83 | }
84 | }
85 | }
86 |
87 | override fun onResume() {
88 | super.onResume()
89 | // Reinitialize if necessary but avoid duplicate initialization
90 | if (!::yoloDetector.isInitialized && !backgroundExecutor.isShutdown) {
91 | initializeOpenCVAndDetector()
92 | }
93 | }
94 |
95 | private fun initializeDetectorAndProcess() {
96 | runOnUiThread {
97 | resultText.text = "Loading model and preparing detection..."
98 | }
99 |
100 | try {
101 | // Initialize the YOLO11 detector with model and labels from assets
102 | // Try alternative model formats if the default fails
103 | val modelVariants = listOf(
104 | "best_float16.tflite", // Try float16 first (smaller, works on many devices)
105 | "best_float32.tflite", // Try float32 as fallback (more compatible but larger)
106 | "best.tflite" // Try default naming as last resort
107 | )
108 |
109 | val labelsPath = "classes.txt"
110 |
111 | // Check device compatibility first with more accurate detection
112 | val useGPU = checkGpuCompatibility()
113 | Log.d(TAG, "GPU acceleration decision: $useGPU")
114 |
115 | // Try model variants in sequence until one works
116 | var lastException: Exception? = null
117 | var detector: YOLO11Detector? = null
118 |
119 | for (modelFile in modelVariants) {
120 | try {
121 | Log.d(TAG, "Attempting to load model: $modelFile")
122 |
123 | // Check if file exists in assets
124 | try {
125 | assets.open(modelFile).close()
126 | } catch (e: IOException) {
127 | Log.d(TAG, "Model file $modelFile not found in assets, skipping")
128 | continue
129 | }
130 |
131 | runOnUiThread {
132 | resultText.text = "Loading model: $modelFile..."
133 | }
134 |
135 | // Create detector with current model variant
136 | detector = YOLO11Detector(
137 | context = this,
138 | modelPath = modelFile,
139 | labelsPath = labelsPath,
140 | useGPU = useGPU
141 | )
142 |
143 | // If we get here, initialization succeeded
144 | yoloDetector = detector
145 | Log.d(TAG, "Successfully initialized detector with model: $modelFile")
146 | break
147 |
148 | } catch (e: Exception) {
149 | Log.e(TAG, "Failed to initialize with model $modelFile: ${e.message}")
150 | e.printStackTrace()
151 | lastException = e
152 |
153 | // If this is GPU mode and failed, try again with CPU
154 | if (useGPU) {
155 | try {
156 | Log.d(TAG, "Retrying model $modelFile with CPU only")
157 | detector = YOLO11Detector(
158 | context = this,
159 | modelPath = modelFile,
160 | labelsPath = labelsPath,
161 | useGPU = false
162 | )
163 |
164 | yoloDetector = detector
165 | Log.d(TAG, "Successfully initialized detector with CPU and model: $modelFile")
166 | break
167 | } catch (cpuEx: Exception) {
168 | Log.e(TAG, "CPU fallback also failed for $modelFile: ${cpuEx.message}")
169 | cpuEx.printStackTrace()
170 | }
171 | }
172 | }
173 | }
174 |
175 | // Check if any model variant worked
176 | if (detector == null) {
177 | throw RuntimeException("Failed to initialize detector with any available model", lastException)
178 | }
179 |
180 | runOnUiThread {
181 | resultText.text = "Model loaded successfully, preparing image..."
182 | }
183 |
184 | // Load test image from assets
185 | val imageBitmap = loadImageFromAssets("image_2.jpg")
186 |
187 | if (imageBitmap != null) {
188 | Log.d(TAG, "Image loaded with dimensions: ${imageBitmap.width}x${imageBitmap.height}")
189 |
190 | runOnUiThread {
191 | resultText.text = "Running detection..."
192 | }
193 |
194 | try {
195 | val startTime = SystemClock.elapsedRealtime()
196 |
197 | // Use exactly the same thresholds as in C++
198 | val confThreshold = 0.25f
199 | val iouThreshold = 0.45f
200 |
201 | Log.d(TAG, "Starting detection with conf=$confThreshold, iou=$iouThreshold")
202 |
203 | // Run detection
204 | val detections = yoloDetector.detect(
205 | bitmap = imageBitmap,
206 | confidenceThreshold = confThreshold,
207 | iouThreshold = iouThreshold
208 | )
209 |
210 | val inferenceTime = SystemClock.elapsedRealtime() - startTime
211 | Log.d(TAG, "Detection completed in $inferenceTime ms, found ${detections.size} objects")
212 |
213 | // More detailed logging for debugging
214 | if (detections.isEmpty()) {
215 | Log.d(TAG, "WARNING: No detections found! Check confidence threshold.")
216 | } else {
217 | // Log first few detections in more detail
218 | detections.take(5).forEachIndexed { index, detection ->
219 | val className = yoloDetector.getClassName(detection.classId)
220 | val box = detection.box
221 | Log.d(TAG, "Top detection #$index: $className (${detection.conf}), " +
222 | "box=${box.x},${box.y},${box.width},${box.height}, " +
223 | "area=${box.width * box.height}")
224 | }
225 | }
226 |
227 | // Filter by confidence for display purposes
228 | val displayThreshold = 0.30f // Higher threshold just for display
229 | val qualityDetections = detections.filter { it.conf > displayThreshold }
230 | Log.d(TAG, "After filtering with threshold $displayThreshold: ${qualityDetections.size} detections")
231 |
232 | // Draw detections with mask overlay for better visualization
233 | val resultBitmap = yoloDetector.drawDetectionsMask(imageBitmap, qualityDetections)
234 |
235 | // Show results in UI
236 | runOnUiThread {
237 | // Display the image with detections
238 | imageView.setImageBitmap(resultBitmap)
239 |
240 | // Format and display detection results
241 | val resultInfo = StringBuilder()
242 | resultInfo.append("Detection completed in $inferenceTime ms\n")
243 | resultInfo.append("Found ${detections.size} objects (${qualityDetections.size} shown)\n\n")
244 |
245 | // Display top detections with highest confidence
246 | qualityDetections.sortedByDescending { it.conf }
247 | .take(5)
248 | .forEach { detection ->
249 | val className = yoloDetector.getClassName(detection.classId)
250 | val confidence = (detection.conf * 100).toInt()
251 | resultInfo.append("• $className: ${confidence}%\n")
252 | }
253 |
254 | resultText.text = resultInfo.toString()
255 | }
256 | } catch (e: Exception) {
257 | Log.e(TAG, "Error during detection", e)
258 | // Show original image at least
259 | val finalImageBitmap = imageBitmap
260 | runOnUiThread {
261 | resultText.text = "Detection error: ${e.message}\n${e.stackTraceToString().take(200)}..."
262 | imageView.setImageBitmap(finalImageBitmap)
263 | }
264 | }
265 | } else {
266 | runOnUiThread {
267 | resultText.text = "Error: Failed to load image from assets. Please check that image_2.jpg exists in the assets folder."
268 | }
269 | }
270 | } catch (e: Exception) {
271 | Log.e(TAG, "Error in detection process", e)
272 | runOnUiThread {
273 | resultText.text = "Error: ${e.message}\n${e.stackTraceToString().take(300)}..."
274 | }
275 | }
276 | }
277 |
278 | /**
279 | * Check if the device is compatible with GPU acceleration with enhanced detection
280 | */
281 | private fun checkGpuCompatibility(): Boolean {
282 | Log.d(TAG, "Checking GPU compatibility...")
283 |
284 | // Check if GPU delegation is supported
285 | val compatList = CompatibilityList()
286 | val isGpuSupported = compatList.isDelegateSupportedOnThisDevice
287 | Log.d(TAG, "GPU supported according to compatibility list: $isGpuSupported")
288 |
289 | // Check if running on emulator
290 | val isEmulator = Build.FINGERPRINT.contains("generic") ||
291 | Build.FINGERPRINT.startsWith("unknown") ||
292 | Build.MODEL.contains("google_sdk") ||
293 | Build.MODEL.contains("Emulator") ||
294 | Build.MODEL.contains("Android SDK")
295 | Log.d(TAG, "Is emulator: $isEmulator")
296 |
297 | // Check known problematic device models and manufacturers
298 | val deviceModel = Build.MODEL.toLowerCase(Locale.ROOT)
299 | val manufacturer = Build.MANUFACTURER.toLowerCase(Locale.ROOT)
300 |
301 | // List of known problematic device patterns
302 | val problematicPatterns = listOf(
303 | "mali-g57", "mali-g72", "mali-g52", "mali-g76", // Some Mali GPUs have TFLite issues
304 | "adreno 6", "adreno 5", // Some older Adreno GPUs
305 | "mediatek", "mt6", "helio" // Some MediaTek chips
306 | )
307 |
308 | val isProblematicDevice = problematicPatterns.any { pattern ->
309 | deviceModel.contains(pattern) || manufacturer.contains(pattern)
310 | }
311 |
312 | Log.d(TAG, "Device details: manufacturer=$manufacturer, model=$deviceModel")
313 | Log.d(TAG, "Is problematic device: $isProblematicDevice")
314 |
315 | // Check Android version - some versions have known TFLite GPU issues
316 | val androidVersion = Build.VERSION.SDK_INT
317 | val isProblematicAndroidVersion = androidVersion < Build.VERSION_CODES.P // Android 9-
318 |
319 | Log.d(TAG, "Android version: $androidVersion, problematic: $isProblematicAndroidVersion")
320 |
321 | // Check available memory - GPU acceleration needs sufficient memory
322 | val memoryInfo = ActivityManager.MemoryInfo()
323 | val activityManager = getSystemService(Context.ACTIVITY_SERVICE) as ActivityManager
324 | activityManager.getMemoryInfo(memoryInfo)
325 |
326 | val availableMem = memoryInfo.availMem / (1024 * 1024) // Convert to MB
327 | val lowMemory = availableMem < 200 // Less than 200MB available
328 |
329 | Log.d(TAG, "Available memory: $availableMem MB, low memory: $lowMemory")
330 |
331 | // Final decision based on all factors
332 | val shouldUseGpu = isGpuSupported &&
333 | !isEmulator &&
334 | !isProblematicDevice &&
335 | !isProblematicAndroidVersion &&
336 | !lowMemory
337 |
338 | Log.d(TAG, "Final GPU acceleration decision: $shouldUseGpu")
339 |
340 | return shouldUseGpu
341 | }
342 |
343 | /**
344 | * Load an image from the assets folder with proper orientation and error handling
345 | */
346 | private fun loadImageFromAssets(fileName: String): Bitmap? {
347 | return try {
348 | val startTime = SystemClock.elapsedRealtime()
349 |
350 | assets.open(fileName).use { inputStream ->
351 | // Load image size first to check dimensions
352 | val options = BitmapFactory.Options().apply {
353 | inJustDecodeBounds = true
354 | }
355 | BitmapFactory.decodeStream(inputStream, null, options)
356 | inputStream.reset()
357 |
358 | // If image is very large, scale it down to avoid memory issues
359 | val maxDimension = 1920 // Reasonable max size for detection
360 | val sampleSize = calculateSampleSize(options.outWidth, options.outHeight, maxDimension)
361 |
362 | // Decode with appropriate sample size
363 | val decodeOptions = BitmapFactory.Options().apply {
364 | inPreferredConfig = Bitmap.Config.ARGB_8888
365 | inScaled = false
366 | inSampleSize = sampleSize
367 | }
368 |
369 | val bitmap = BitmapFactory.decodeStream(inputStream, null, decodeOptions)
370 |
371 | val loadTime = SystemClock.elapsedRealtime() - startTime
372 | Log.d(TAG, "Image loaded: ${bitmap?.width}x${bitmap?.height} " +
373 | "(original: ${options.outWidth}x${options.outHeight}, " +
374 | "sample size: $sampleSize), took $loadTime ms")
375 | bitmap
376 | }
377 | } catch (e: Exception) {
378 | Log.e(TAG, "Failed to load image '$fileName'", e)
379 | null
380 | }
381 | }
382 |
383 | /**
384 | * Calculate appropriate sample size for large images
385 | */
386 | private fun calculateSampleSize(width: Int, height: Int, maxDimension: Int): Int {
387 | var sampleSize = 1
388 | while (width / sampleSize > maxDimension || height / sampleSize > maxDimension) {
389 | sampleSize *= 2
390 | }
391 | return sampleSize
392 | }
393 |
394 | override fun onDestroy() {
395 | super.onDestroy()
396 | // Clean up resources
397 | if (::yoloDetector.isInitialized) {
398 | yoloDetector.close()
399 | }
400 | // Shutdown executor service
401 | backgroundExecutor.shutdown()
402 | }
403 |
404 | companion object {
405 | private const val TAG = "YOLO11MainActivity"
406 | }
407 | }
408 |
--------------------------------------------------------------------------------
/src/kotlin/ModelParseActivity.kt:
--------------------------------------------------------------------------------
1 | package com.example.opencv_tutorial
2 |
3 | import android.os.Bundle
4 | import android.util.Log
5 | import android.widget.TextView
6 | import androidx.appcompat.app.AppCompatActivity
7 | import kotlinx.coroutines.CoroutineScope
8 | import kotlinx.coroutines.Dispatchers
9 | import kotlinx.coroutines.launch
10 | import kotlinx.coroutines.withContext
11 | import org.tensorflow.lite.support.metadata.MetadataExtractor
12 | import java.io.File
13 | import java.io.FileOutputStream
14 | import java.nio.ByteBuffer
15 | import java.nio.channels.FileChannel
16 |
17 | /**
18 | * Diagnostic activity for detailed model inspection
19 | * This helps identify issues with model loading on physical devices
20 | */
21 | class ModelParseActivity : AppCompatActivity() {
22 | private lateinit var resultText: TextView
23 | private val scope = CoroutineScope(Dispatchers.Main)
24 |
25 | companion object {
26 | private const val TAG = "ModelParse"
27 | }
28 |
29 | override fun onCreate(savedInstanceState: Bundle?) {
30 | super.onCreate(savedInstanceState)
31 | setContentView(R.layout.activity_model_parse)
32 |
33 | resultText = findViewById(R.id.modelParseResultText)
34 | resultText.text = "Analyzing TFLite model..."
35 |
36 | // Run model inspection in background
37 | scope.launch {
38 | try {
39 | val results = withContext(Dispatchers.IO) {
40 | analyzeModels()
41 | }
42 | resultText.text = results
43 | } catch (e: Exception) {
44 | Log.e(TAG, "Error during model analysis", e)
45 | resultText.text = "Error analyzing models:\n${e.message}\n\n${e.stackTraceToString()}"
46 | }
47 | }
48 | }
49 |
50 | private fun analyzeModels(): String {
51 | val result = StringBuilder()
52 | result.append("TFLite Model Analysis\n")
53 | result.append("====================\n\n")
54 |
55 | val modelFiles = listOf(
56 | "best_float16.tflite",
57 | "best_float32.tflite",
58 | "best.tflite"
59 | )
60 |
61 | for (modelFile in modelFiles) {
62 | try {
63 | result.append("Model: $modelFile\n")
64 | result.append("-----------------\n")
65 |
66 | // Check if file exists
67 | try {
68 | assets.open(modelFile).close()
69 | result.append("File exists in assets: Yes\n")
70 | } catch (e: Exception) {
71 | result.append("File exists in assets: No\n")
72 | result.append("\n")
73 | continue
74 | }
75 |
76 | // Extract model to temp file for analysis
77 | val tempFile = extractModelToTemp(modelFile)
78 |
79 | result.append("File size: ${tempFile.length()} bytes\n")
80 |
81 | // Basic header verification
82 | val isValidFlatBuffer = checkFlatBufferHeader(tempFile)
83 | result.append("Valid FlatBuffer header: $isValidFlatBuffer\n")
84 |
85 | // Try to parse model metadata
86 | try {
87 | val metadata = parseModelMetadata(tempFile)
88 | result.append(metadata)
89 | } catch (e: Exception) {
90 | result.append("Metadata extraction failed: ${e.message}\n")
91 | }
92 |
93 | // Try basic TFLite interpreter creation
94 | try {
95 | testInterpreterCreation(modelFile)
96 | result.append("Interpreter creation: Success\n")
97 | } catch (e: Exception) {
98 | result.append("Interpreter creation failed: ${e.message}\n")
99 | }
100 |
101 | result.append("\n")
102 |
103 | } catch (e: Exception) {
104 | result.append("Error analyzing $modelFile: ${e.message}\n\n")
105 | }
106 | }
107 |
108 | // Add device information
109 | result.append("Device Information\n")
110 | result.append("-----------------\n")
111 | result.append("Manufacturer: ${android.os.Build.MANUFACTURER}\n")
112 | result.append("Model: ${android.os.Build.MODEL}\n")
113 | result.append("Android version: ${android.os.Build.VERSION.RELEASE} (SDK ${android.os.Build.VERSION.SDK_INT})\n")
114 | result.append("ABI: ${android.os.Build.SUPPORTED_ABIS.joinToString()}\n")
115 |
116 | return result.toString()
117 | }
118 |
119 | private fun extractModelToTemp(modelFile: String): File {
120 | val file = File(cacheDir, "temp_$modelFile")
121 |
122 | assets.open(modelFile).use { input ->
123 | FileOutputStream(file).use { output ->
124 | val buffer = ByteArray(4 * 1024)
125 | var read: Int
126 | while (input.read(buffer).also { read = it } != -1) {
127 | output.write(buffer, 0, read)
128 | }
129 | output.flush()
130 | }
131 | }
132 |
133 | return file
134 | }
135 |
136 | private fun checkFlatBufferHeader(file: File): Boolean {
137 | return file.inputStream().use { input ->
138 | val header = ByteArray(8)
139 | val bytesRead = input.read(header)
140 |
141 | // Check standard FlatBuffer header
142 | (bytesRead == 8) &&
143 | header[0].toInt() == 0x18 &&
144 | header[1].toInt() == 0x00 &&
145 | header[2].toInt() == 0x00 &&
146 | header[3].toInt() == 0x00
147 | }
148 | }
149 |
150 | private fun parseModelMetadata(file: File): String {
151 | val result = StringBuilder()
152 |
153 | try {
154 | val mappedBuffer = file.inputStream().channel.map(
155 | FileChannel.MapMode.READ_ONLY, 0, file.length()
156 | )
157 |
158 | val metadataExtractor = MetadataExtractor(mappedBuffer)
159 |
160 | // Check if model has metadata
161 | if (metadataExtractor.hasMetadata()) {
162 | result.append("Has metadata: Yes\n")
163 |
164 | // Get model description
165 | val modelMetadata = metadataExtractor.modelMetadata
166 | if (modelMetadata != null) {
167 | result.append("Model name: ${modelMetadata.name()}\n")
168 | result.append("Model description: ${modelMetadata.description()}\n")
169 | result.append("Model version: ${modelMetadata.version()}\n")
170 | }
171 |
172 | // Get input/output tensors
173 | val inputTensorCount = metadataExtractor.inputTensorCount
174 | val outputTensorCount = metadataExtractor.outputTensorCount
175 |
176 | result.append("Input tensors: $inputTensorCount\n")
177 | result.append("Output tensors: $outputTensorCount\n")
178 |
179 | for (i in 0 until inputTensorCount) {
180 | val tensorMetadata = metadataExtractor.getInputTensorMetadata(i)
181 | result.append("Input #$i: ${tensorMetadata.name()}, ")
182 | result.append("type: ${tensorMetadata.tensorType().name}\n")
183 | }
184 | } else {
185 | result.append("Has metadata: No\n")
186 | }
187 |
188 | // Get basic model info directly from the buffer
189 | try {
190 | mappedBuffer.rewind()
191 | val model = org.tensorflow.lite.schema.Model.getRootAsModel(mappedBuffer)
192 | result.append("Model version: ${model.version()}\n")
193 | result.append("Operator codes: ${model.operatorCodesLength()}\n")
194 | result.append("Subgraphs: ${model.subgraphsLength()}\n")
195 |
196 | if (model.subgraphsLength() > 0) {
197 | val subgraph = model.subgraphs(0)
198 | if (subgraph != null) {
199 | result.append("Inputs: ${subgraph.inputsLength()}, ")
200 | result.append("Outputs: ${subgraph.outputsLength()}\n")
201 | }
202 | }
203 | } catch (e: Exception) {
204 | result.append("Schema parse error: ${e.message}\n")
205 | }
206 |
207 | } catch (e: Exception) {
208 | result.append("Metadata extraction error: ${e.message}\n")
209 | }
210 |
211 | return result.toString()
212 | }
213 |
214 | private fun testInterpreterCreation(modelFile: String) {
215 | val assetFd = assets.openFd(modelFile)
216 |         val fileChannel = java.io.FileInputStream(assetFd.fileDescriptor).channel
217 | val mappedBuffer = fileChannel.map(
218 | FileChannel.MapMode.READ_ONLY,
219 | assetFd.startOffset,
220 | assetFd.declaredLength
221 | )
222 |
223 | // Test creating interpreter with basic options
224 | val options = org.tensorflow.lite.Interpreter.Options()
225 | val interpreter = org.tensorflow.lite.Interpreter(mappedBuffer, options)
226 |
227 | // Log the model info
228 | val inputs = interpreter.inputTensorCount
229 | val outputs = interpreter.outputTensorCount
230 | Log.d(TAG, "Model has $inputs inputs and $outputs outputs")
231 |
232 | // Clean up
233 | interpreter.close()
234 | fileChannel.close()
235 | assetFd.close()
236 | }
237 | }
238 |
--------------------------------------------------------------------------------
/src/kotlin/ScopedTimer.kt:
--------------------------------------------------------------------------------
1 | package com.yolov11kotlin
2 |
3 | import android.os.SystemClock
4 | import android.util.Log
5 |
6 | /**
7 | * Utility class for measuring execution time of code blocks.
8 | * Only logs times when TIMING_MODE is enabled in the BuildConfig.
9 | */
10 | class ScopedTimer(private val name: String) {
11 | private val startTime: Long = SystemClock.elapsedRealtime()
12 | private var stopped = false
13 |
14 | /**
15 | * Stops the timer and logs the elapsed time.
16 | */
17 | fun stop() {
18 | if (stopped) return
19 | stopped = true
20 |
21 | if (BuildConfig.TIMING_MODE) {
22 | val endTime = SystemClock.elapsedRealtime()
23 | val duration = endTime - startTime
24 | Log.d("ScopedTimer", "$name took $duration milliseconds.")
25 | }
26 | }
27 |
28 | /**
29 | * Automatically stops the timer when the object is garbage collected.
30 | */
31 | protected fun finalize() {
32 | if (!stopped) {
33 | stop()
34 | }
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/src/kotlin/TFLiteModelManager.kt:
--------------------------------------------------------------------------------
1 | package com.example.opencv_tutorial
2 |
3 | import android.content.Context
4 | import android.os.Build
5 | import android.util.Log
6 | import java.io.File
7 | import java.io.FileOutputStream
8 | import java.io.IOException
9 | import java.nio.MappedByteBuffer
10 | import java.nio.channels.FileChannel
11 |
12 | /**
13 | * Utility class for TFLite model management
14 | * Handles model extraction, validation, and adaptation
15 | */
16 | class TFLiteModelManager(private val context: Context) {
17 | companion object {
18 | private const val TAG = "TFLiteModelManager"
19 | }
20 |
21 | /**
22 | * Extracts and validates a TFLite model from assets
23 | * May convert the model format to ensure compatibility with the device
24 | * @return Path to the optimized model file
25 | */
26 | fun prepareModelForDevice(assetModelPath: String): String {
27 | Log.d(TAG, "Preparing model: $assetModelPath")
28 |
29 | try {
30 | // First check if the model exists
31 | val assets = context.assets
32 | assets.open(assetModelPath).use { inStream ->
33 | // Read some header bytes to validate the file
34 | val header = ByteArray(8)
35 | val bytesRead = inStream.read(header)
36 |
37 | if (bytesRead != 8) {
38 | throw IOException("Could not read model header bytes")
39 | }
40 |
41 | // Verify this is a valid FlatBuffer file (basic check)
42 | // TFLite models should have the first 4 bytes as the FlatBuffer header
43 | if (header[0].toInt() != 0x18 || header[1].toInt() != 0x00 ||
44 | header[2].toInt() != 0x00 || header[3].toInt() != 0x00) {
45 | Log.w(TAG, "Model may not be a valid FlatBuffer file")
46 | }
47 |
48 | Log.d(TAG, "Model header verified")
49 | }
50 |
51 | // Extract to local storage for potential modification
52 | val modelFile = extractAssetToCache(assetModelPath)
53 | Log.d(TAG, "Model extracted to: ${modelFile.absolutePath}")
54 |
55 | return modelFile.absolutePath
56 |
57 | } catch (e: Exception) {
58 | Log.e(TAG, "Error preparing model: ${e.message}")
59 | throw e
60 | }
61 | }
62 |
63 | /**
64 | * Extract an asset file to the app's cache directory
65 | */
66 | private fun extractAssetToCache(assetPath: String): File {
67 | val fileName = assetPath.substringAfterLast("/")
68 | val outputFile = File(context.cacheDir, "models_${Build.VERSION.SDK_INT}_$fileName")
69 |
70 | // Only extract if the file doesn't exist or is outdated
71 | if (!outputFile.exists() || outputFile.length() == 0L) {
72 | Log.d(TAG, "Extracting asset to: ${outputFile.absolutePath}")
73 |
74 | context.assets.open(assetPath).use { inputStream ->
75 | FileOutputStream(outputFile).use { outputStream ->
76 | val buffer = ByteArray(4 * 1024)
77 | var read: Int
78 | while (inputStream.read(buffer).also { read = it } != -1) {
79 | outputStream.write(buffer, 0, read)
80 | }
81 | outputStream.flush()
82 | }
83 | }
84 | } else {
85 | Log.d(TAG, "Using cached model: ${outputFile.absolutePath}")
86 | }
87 |
88 | return outputFile
89 | }
90 |
91 | /**
92 | * Load a TFLite model from a file with enhanced error handling
93 | */
94 | fun loadModelFile(modelPath: String): MappedByteBuffer {
95 | Log.d(TAG, "Loading model file: $modelPath")
96 |
97 | val file = File(modelPath)
98 | if (!file.exists()) {
99 | throw IOException("Model file not found: $modelPath")
100 | }
101 |
102 | return file.inputStream().channel.map(
103 | FileChannel.MapMode.READ_ONLY, 0, file.length()
104 | ).also {
105 | Log.d(TAG, "Model loaded, capacity: ${it.capacity()} bytes")
106 | }
107 | }
108 |
109 | /**
110 | * Check if a model file appears to be valid
111 | */
112 | fun validateModelFile(modelPath: String): Boolean {
113 | try {
114 | val file = File(modelPath)
115 | if (!file.exists() || file.length() < 8) {
116 | return false
117 | }
118 |
119 | // Basic header check
120 | file.inputStream().use { input ->
121 | val header = ByteArray(8)
122 | input.read(header)
123 |
124 | // Check for FlatBuffer header
125 | return header[0].toInt() == 0x18 && header[1].toInt() == 0x00 &&
126 | header[2].toInt() == 0x00 && header[3].toInt() == 0x00
127 | }
128 | } catch (e: Exception) {
129 | Log.e(TAG, "Error validating model file: ${e.message}")
130 | return false
131 | }
132 | }
133 | }
134 |
--------------------------------------------------------------------------------
/src/kotlin/YOLO11Detector.kt:
--------------------------------------------------------------------------------
1 | package com.example.opencv_tutorial
2 |
3 | import android.content.Context
4 | import android.graphics.Bitmap
5 | import android.graphics.Canvas
6 | import android.graphics.Color
7 | import android.graphics.Paint
8 | import android.graphics.RectF
9 | import android.os.Build
10 | import android.os.SystemClock
11 | import android.util.Log
12 | import org.opencv.android.Utils
13 | import org.opencv.core.*
14 | import org.opencv.imgproc.Imgproc
15 | import org.tensorflow.lite.Interpreter
16 | import org.tensorflow.lite.gpu.CompatibilityList
17 | import org.tensorflow.lite.gpu.GpuDelegate
18 | import java.io.FileInputStream
19 | import java.nio.ByteBuffer
20 | import java.nio.ByteOrder
21 | import java.nio.MappedByteBuffer
22 | import java.nio.channels.FileChannel
23 | import java.util.*
24 | import kotlin.math.max
25 | import kotlin.math.min
26 | import kotlin.math.round
27 | //import android.util.Log
28 |
29 | /**
30 | * YOLOv11Detector for Android using TFLite and OpenCV
31 | *
32 | * This class handles object detection using the YOLOv11 model with TensorFlow Lite
33 | * for inference and OpenCV for image processing.
34 | */
35 | class YOLO11Detector(
36 | private val context: Context,
37 | private val modelPath: String,
38 | private val labelsPath: String,
39 | useGPU: Boolean = true
40 | ) {
41 | // Detection parameters - matching C++ implementation
42 | companion object {
43 | // Match the C++ implementation thresholds
44 | const val CONFIDENCE_THRESHOLD = 0.25f // Changed from 0.4f to match C++ code
45 | const val IOU_THRESHOLD = 0.45f // Changed from 0.3f to match C++ code
46 | private const val TAG = "YOLO11Detector"
47 | }
48 |
49 | // Data structures for model and inference
50 | private var interpreter: Interpreter
51 |     private val classNames: List<String>
52 |     private val classColors: List<Int>
53 | private var gpuDelegate: GpuDelegate? = null
54 |
55 | // Input shape info
56 | private var inputWidth: Int = 640
57 | private var inputHeight: Int = 640
58 | private var isQuantized: Boolean = false
59 | private var numClasses: Int = 0
60 |
61 | init {
62 | try {
63 | // Log starting initialization for debugging purposes
64 | debug("Initializing YOLO11Detector with model: $modelPath, useGPU: $useGPU")
65 | debug("Device: ${Build.MANUFACTURER} ${Build.MODEL}, Android ${Build.VERSION.SDK_INT}")
66 |
67 | // Load model with proper options
68 | val tfliteOptions = Interpreter.Options()
69 |
70 | // GPU Delegate setup with improved validation and error recovery
71 | if (useGPU) {
72 | try {
73 | val compatList = CompatibilityList()
74 | debug("GPU delegate supported on device: ${compatList.isDelegateSupportedOnThisDevice}")
75 |
76 | if (compatList.isDelegateSupportedOnThisDevice) {
77 | // First try to create GPU delegate without configuring options
78 | // This can help detect early incompatibilities
79 | try {
80 | val tempDelegate = GpuDelegate()
81 | tempDelegate.close() // Just testing creation
82 | debug("Basic GPU delegate creation successful")
83 | } catch (e: Exception) {
84 | debug("Basic GPU delegate test failed: ${e.message}")
85 | throw Exception("Device reports GPU compatible but fails basic delegate test")
86 | }
87 |
88 | debug("Configuring GPU acceleration with safe defaults")
89 |
90 | // Use conservative GPU delegation options
91 | val delegateOptions = GpuDelegate.Options().apply {
92 | setPrecisionLossAllowed(true) // Allow precision loss for better compatibility
93 | setQuantizedModelsAllowed(true) // Allow quantized models
94 | }
95 |
96 | gpuDelegate = GpuDelegate(delegateOptions)
97 | tfliteOptions.addDelegate(gpuDelegate)
98 | debug("GPU delegate successfully created and added")
99 |
100 | // Always configure CPU fallback options
101 | configureCpuOptions(tfliteOptions)
102 | } else {
103 | debug("GPU acceleration not supported on this device, using CPU only")
104 | configureCpuOptions(tfliteOptions)
105 | }
106 | } catch (e: Exception) {
107 | debug("Error setting up GPU acceleration: ${e.message}, stack: ${e.stackTraceToString()}")
108 | debug("Falling back to CPU execution")
109 | // Clean up any GPU resources
110 | try {
111 | gpuDelegate?.close()
112 | } catch (closeEx: Exception) {
113 | debug("Error closing GPU delegate: ${closeEx.message}")
114 | }
115 | gpuDelegate = null
116 | configureCpuOptions(tfliteOptions)
117 | }
118 | } else {
119 | debug("GPU acceleration disabled, using CPU only")
120 | configureCpuOptions(tfliteOptions)
121 | }
122 |
123 | // Enhanced model loading with diagnostics
124 | val modelBuffer: MappedByteBuffer
125 | try {
126 | debug("Loading model from assets: $modelPath")
127 | modelBuffer = loadModelFile(modelPath)
128 | debug("Model loaded successfully, size: ${modelBuffer.capacity() / 1024} KB")
129 |
130 | // Simple validation - check if buffer size is reasonable
131 | if (modelBuffer.capacity() < 10000) {
132 | throw RuntimeException("Model file appears too small (${modelBuffer.capacity()} bytes)")
133 | }
134 | } catch (e: Exception) {
135 | debug("Failed to load model: ${e.message}")
136 | throw RuntimeException("Model loading failed: ${e.message}", e)
137 | }
138 |
139 | // Initialize interpreter with more controlled error handling
140 | try {
141 | debug("Creating TFLite interpreter")
142 |
143 | // Add memory management options for large models
144 | tfliteOptions.setAllowFp16PrecisionForFp32(true) // Reduce memory requirements
145 |
146 | interpreter = Interpreter(modelBuffer, tfliteOptions)
147 | debug("TFLite interpreter created successfully")
148 |
149 | // Log interpreter details for diagnostics
150 | val inputTensor = interpreter.getInputTensor(0)
151 | val inputShape = inputTensor.shape()
152 | val outputTensor = interpreter.getOutputTensor(0)
153 | val outputShape = outputTensor.shape()
154 |
155 | debug("Model input shape: ${inputShape.joinToString()}")
156 | debug("Model output shape: ${outputShape.joinToString()}")
157 | debug("Input tensor type: ${inputTensor.dataType()}")
158 |
159 | // Capture model input properties
160 | inputHeight = inputShape[1]
161 | inputWidth = inputShape[2]
162 | isQuantized = inputTensor.dataType() == org.tensorflow.lite.DataType.UINT8
163 | numClasses = outputShape[1] - 4
164 |
165 | debug("Model setup: inputSize=${inputWidth}x${inputHeight}, isQuantized=$isQuantized, numClasses=$numClasses")
166 | } catch (e: Exception) {
167 | debug("Failed to initialize interpreter: ${e.message}, stack: ${e.stackTraceToString()}")
168 | // Clean up resources
169 | try {
170 | gpuDelegate?.close()
171 | } catch (closeEx: Exception) {
172 | debug("Error closing GPU delegate during cleanup: ${closeEx.message}")
173 | }
174 | throw RuntimeException("TFLite initialization failed: ${e.message}", e)
175 | }
176 |
177 | // Load class names
178 | try {
179 | classNames = loadClassNames(labelsPath)
180 | debug("Loaded ${classNames.size} classes from $labelsPath")
181 | classColors = generateColors(classNames.size)
182 |
183 | if (classNames.size != numClasses) {
184 | debug("Warning: Number of classes in label file (${classNames.size}) differs from model output ($numClasses)")
185 | }
186 | } catch (e: Exception) {
187 | debug("Failed to load class names: ${e.message}")
188 | throw RuntimeException("Failed to load class names", e)
189 | }
190 |
191 | debug("YOLO11Detector initialization completed successfully")
192 | } catch (e: Exception) {
193 | debug("FATAL: Detector initialization failed: ${e.message}")
194 | debug("Stack trace: ${e.stackTraceToString()}")
195 | throw e // Re-throw to ensure caller sees the failure
196 | }
197 | }
198 |
199 | /**
200 | * Configure CPU-specific options for the TFLite interpreter with safer defaults
201 | */
202 | private fun configureCpuOptions(options: Interpreter.Options) {
203 | try {
204 | // Determine optimal thread count based on device
205 | val cpuCores = Runtime.getRuntime().availableProcessors()
206 | // For lower-end devices, use fewer threads to avoid overwhelming the CPU
207 | val optimalThreads = when {
208 | cpuCores <= 2 -> 1
209 | cpuCores <= 4 -> 2
210 | else -> cpuCores - 2
211 | }
212 |
213 | options.setNumThreads(optimalThreads)
214 | options.setUseXNNPACK(true) // Use XNNPACK for CPU acceleration
215 |
216 | // Add FlatBuffer-related options
217 | options.setAllowFp16PrecisionForFp32(true)
218 | options.setAllowBufferHandleOutput(true)
219 |
220 | debug("CPU options configured with $optimalThreads threads")
221 | } catch (e: Exception) {
222 | debug("Error configuring CPU options: ${e.message}")
223 | // Use safe defaults
224 | options.setNumThreads(1)
225 | }
226 | }
227 |
228 | /**
229 | * Loads the TFLite model file with enhanced error checking
230 | */
231 | private fun loadModelFile(modelPath: String): MappedByteBuffer {
232 | try {
233 | val assetManager = context.assets
234 |
235 | // First check if file exists
236 | val assetList = assetManager.list("") ?: emptyArray()
237 | debug("Available assets: ${assetList.joinToString()}")
238 |
239 | if (!assetList.contains(modelPath)) {
240 | throw IOException("Model file not found in assets: $modelPath")
241 | }
242 |
243 | val assetFileDescriptor = assetManager.openFd(modelPath)
244 | val modelSize = assetFileDescriptor.length
245 | debug("Model file size: $modelSize bytes")
246 |
247 | // Check if model size is reasonable
248 | if (modelSize <= 0) {
249 | throw IOException("Invalid model file size: $modelSize")
250 | }
251 |
252 | val fileInputStream = FileInputStream(assetFileDescriptor.fileDescriptor)
253 | val fileChannel = fileInputStream.channel
254 | val startOffset = assetFileDescriptor.startOffset
255 | val declaredLength = assetFileDescriptor.declaredLength
256 |
257 | debug("Mapping model file: offset=$startOffset, length=$declaredLength")
258 |
259 | return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength).also {
260 | debug("Model buffer capacity: ${it.capacity()} bytes")
261 | }
262 | } catch (e: Exception) {
263 | debug("Error loading model file: $modelPath - ${e.message}")
264 | e.printStackTrace()
265 | throw e
266 | }
267 | }
268 |
269 | /**
270 | * Main detection function that processes an image and returns detected objects
271 | */
272 | fun detect(bitmap: Bitmap, confidenceThreshold: Float = CONFIDENCE_THRESHOLD,
273 |                iouThreshold: Float = IOU_THRESHOLD): List<Detection> {
274 | val startTime = SystemClock.elapsedRealtime()
275 | debug("Starting detection with conf=$confidenceThreshold, iou=$iouThreshold")
276 |
277 | try {
278 | // Add debug for input dimensions
279 | debug("Input image dimensions: ${bitmap.width}x${bitmap.height}")
280 |
281 | // Convert Bitmap to Mat for OpenCV processing
282 | val inputMat = Mat()
283 | Utils.bitmapToMat(bitmap, inputMat)
284 | Imgproc.cvtColor(inputMat, inputMat, Imgproc.COLOR_RGBA2BGR)
285 |
286 | // Prepare input for TFLite
287 | val originalSize = Size(bitmap.width.toDouble(), bitmap.height.toDouble())
288 | val resizedImgMat = Mat() // Will hold the resized image
289 |
290 | // Input shape for model
291 | val modelInputShape = Size(inputWidth.toDouble(), inputHeight.toDouble())
292 | debug("Model input shape: ${modelInputShape.width.toInt()}x${modelInputShape.height.toInt()}")
293 |
294 | // First preprocess using OpenCV
295 | val inputTensor = preprocessImageOpenCV(
296 | inputMat,
297 | resizedImgMat,
298 | modelInputShape
299 | )
300 |
301 | // Run inference
302 | return try {
303 | val outputs = runInference(inputTensor)
304 |
305 | // Process outputs to get detections
306 | val detections = postprocess(
307 | outputs,
308 | originalSize,
309 | Size(inputWidth.toDouble(), inputHeight.toDouble()),
310 | confidenceThreshold,
311 | iouThreshold
312 | )
313 |
314 | val inferenceTime = SystemClock.elapsedRealtime() - startTime
315 | debug("Detection completed in $inferenceTime ms with ${detections.size} objects")
316 |
317 | detections
318 | } catch (e: Exception) {
319 | debug("Error during inference: ${e.message}")
320 | e.printStackTrace()
321 | emptyList() // Return empty list on error
322 | } finally {
323 | // Ensure we clean up resources
324 | inputMat.release()
325 | resizedImgMat.release()
326 | }
327 | } catch (e: Exception) {
328 | debug("Error preparing input: ${e.message}")
329 | e.printStackTrace()
330 | return emptyList()
331 | }
332 | }
333 |
334 | /**
335 | * Preprocess the input image using OpenCV to match the C++ implementation exactly
336 | */
337 | private fun preprocessImageOpenCV(image: Mat, outImage: Mat, newShape: Size): ByteBuffer {
338 | val scopedTimer = ScopedTimer("preprocessing")
339 |
340 | // Track original dimensions before any processing
341 | debug("Original image dimensions: ${image.width()}x${image.height()}")
342 |
343 | // Resize with letterboxing to maintain aspect ratio
344 | letterBox(image, outImage, newShape, Scalar(114.0, 114.0, 114.0))
345 |
346 | // Log resized dimensions with letterboxing
347 | debug("After letterbox: ${outImage.width()}x${outImage.height()}")
348 |
349 | // Convert BGR to RGB (YOLOv11 expects RGB input)
350 | val rgbMat = Mat()
351 | Imgproc.cvtColor(outImage, rgbMat, Imgproc.COLOR_BGR2RGB)
352 |
353 | // DEBUG: Output dimensions for verification
354 | debug("Preprocessed image dimensions: ${rgbMat.width()}x${rgbMat.height()}")
355 |
356 | // Prepare the ByteBuffer to store the model input data
357 | val bytesPerChannel = if (isQuantized) 1 else 4
358 | val inputBuffer = ByteBuffer.allocateDirect(1 * inputWidth * inputHeight * 3 * bytesPerChannel)
359 | inputBuffer.order(ByteOrder.nativeOrder())
360 |
361 | try {
362 | // Convert to proper format for TFLite
363 | if (isQuantized) {
364 | // For quantized models, prepare as bytes
365 | val pixels = ByteArray(rgbMat.width() * rgbMat.height() * rgbMat.channels())
366 | rgbMat.get(0, 0, pixels)
367 |
368 | for (i in pixels.indices) {
369 | inputBuffer.put(pixels[i])
370 | }
371 | } else {
372 | // For float models, normalize to [0,1]
373 | // CRITICAL: Create a normalized float Mat directly using OpenCV for better precision
374 | val normalizedMat = Mat()
375 | rgbMat.convertTo(normalizedMat, CvType.CV_32FC3, 1.0/255.0)
376 |
377 | // Now copy the normalized float values to TFLite input buffer
378 | val floatValues = FloatArray(normalizedMat.width() * normalizedMat.height() * normalizedMat.channels())
379 | normalizedMat.get(0, 0, floatValues)
380 |
381 | for (value in floatValues) {
382 | inputBuffer.putFloat(value)
383 | }
384 |
385 | normalizedMat.release()
386 | }
387 | } catch (e: Exception) {
388 | debug("Error during preprocessing: ${e.message}")
389 | e.printStackTrace()
390 | }
391 |
392 | inputBuffer.rewind()
393 | rgbMat.release()
394 |
395 | scopedTimer.stop()
396 | return inputBuffer
397 | }
398 |
399 | /**
400 | * Runs inference with TFLite and returns the raw output
401 | */
402 |     private fun runInference(inputBuffer: ByteBuffer): Map<Int, Any> {
403 | val scopedTimer = ScopedTimer("inference")
404 |
405 |         val outputs: MutableMap<Int, Any> = HashMap()
406 |
407 | try {
408 | // YOLOv11 with TFLite typically outputs a single tensor
409 | val outputShape = interpreter.getOutputTensor(0).shape()
410 | debug("Output tensor shape: ${outputShape.joinToString()}")
411 |
412 | // Correctly allocate output buffer based on the shape
413 | if (isQuantized) {
414 | val outputSize = outputShape.reduce { acc, i -> acc * i }
415 | val outputBuffer = ByteBuffer.allocateDirect(4 * outputSize)
416 | .order(ByteOrder.nativeOrder())
417 | outputs[0] = outputBuffer
418 |
419 | // Run inference with quantized model
420 | interpreter.run(inputBuffer, outputBuffer)
421 | } else {
422 | val outputSize = outputShape.reduce { acc, i -> acc * i }
423 | val outputBuffer = ByteBuffer.allocateDirect(4 * outputSize)
424 | .order(ByteOrder.nativeOrder())
425 | outputs[0] = outputBuffer
426 |
427 | // Run inference with float model
428 | interpreter.run(inputBuffer, outputBuffer)
429 |
430 | // Debug: Peek at some values to verify output format
431 | outputBuffer.rewind()
432 | val values = FloatArray(min(10, outputSize))
433 | for (i in values.indices) {
434 | values[i] = outputBuffer.float
435 | }
436 | debug("First few output values: ${values.joinToString()}")
437 | outputBuffer.rewind()
438 | }
439 | } catch (e: Exception) {
440 | debug("Error during inference: ${e.message}")
441 | e.printStackTrace()
442 | }
443 |
444 | scopedTimer.stop()
445 | return outputs
446 | }
447 |
448 | /**
449 | * Post-processes the model outputs to extract detections
450 | * Modified to correctly handle normalized coordinates
451 | */
452 | private fun postprocess(
453 |         outputMap: Map<Int, Any>,
454 | originalImageSize: Size,
455 | resizedImageShape: Size,
456 | confThreshold: Float,
457 | iouThreshold: Float
458 |     ): List<Detection> {
459 | val scopedTimer = ScopedTimer("postprocessing")
460 |
461 |         val detections = mutableListOf<Detection>()
462 |
463 | try {
464 | // Get output buffer
465 | val outputBuffer = outputMap[0] as ByteBuffer
466 | outputBuffer.rewind()
467 |
468 | // Get output dimensions
469 | val outputShapes = interpreter.getOutputTensor(0).shape()
470 | debug("Output tensor shape: ${outputShapes.joinToString()}")
471 |
472 |             // YOLOv11 output tensor shape is [1, 4+num_classes, 8400] = [batch, xywh+classes, predictions]
473 |             // The layout is transposed: features run along dim 1, predictions along dim 2
474 |             val num_classes = outputShapes[1] - 4 // e.g. 80 classes for an 84-feature output
475 | val num_predictions = outputShapes[2] // 8400 predictions
476 |
477 | debug("Processing output tensor: features=${outputShapes[1]}, predictions=$num_predictions, classes=$num_classes")
478 |
479 | // Extract boxes, confidences, and class ids
480 |             val boxes = mutableListOf<RectF>()
481 |             val confidences = mutableListOf<Float>()
482 |             val classIds = mutableListOf<Int>()
483 |             val nmsBoxes = mutableListOf<RectF>() // For class-separated NMS
484 |
485 | // Create a float array from the buffer for more efficient access
486 | val outputArray = FloatArray(outputShapes[0] * outputShapes[1] * outputShapes[2])
487 | outputBuffer.rewind()
488 | for (i in outputArray.indices) {
489 | outputArray[i] = outputBuffer.float
490 | }
491 |
492 | // Process each prediction
493 | for (i in 0 until num_predictions) {
494 | // Find class with maximum score and its index
495 | var maxScore = -Float.MAX_VALUE
496 | var classId = -1
497 |
498 | // Scan through all classes (start at index 4, after x,y,w,h)
499 | for (c in 0 until num_classes) {
500 | // Class scores are after the 4 box coordinates
501 | val score = outputArray[(4 + c) * num_predictions + i]
502 | if (score > maxScore) {
503 | maxScore = score
504 | classId = c
505 | }
506 | }
507 |
508 | // Filter by confidence threshold
509 | if (maxScore >= confThreshold) {
510 | // Extract bounding box coordinates (normalized between 0-1)
511 | val x = outputArray[0 * num_predictions + i] // center_x
512 | val y = outputArray[1 * num_predictions + i] // center_y
513 | val w = outputArray[2 * num_predictions + i] // width
514 | val h = outputArray[3 * num_predictions + i] // height
515 |
516 | // Convert from center format (xywh) to corner format (xyxy) - all normalized
517 | val left = x - w / 2
518 | val top = y - h / 2
519 | val right = x + w / 2
520 | val bottom = y + h / 2
521 |
522 | debug("Detection found: center=($x,$y), wh=($w,$h), score=$maxScore, class=$classId")
523 | debug(" box normalized: ($left,$top,$right,$bottom)")
524 |
525 | // Scale coordinates to original image size
526 | val scaledBox = scaleCoords(
527 | resizedImageShape,
528 | RectF(left, top, right, bottom),
529 | originalImageSize
530 | )
531 |
532 | // Additional debug for scaled box
533 | debug(" box in original image: (${scaledBox.left},${scaledBox.top},${scaledBox.right},${scaledBox.bottom})")
534 |
535 | // Validate dimensions before adding
536 | val boxWidth = scaledBox.right - scaledBox.left
537 | val boxHeight = scaledBox.bottom - scaledBox.top
538 |
539 | if (boxWidth > 1 && boxHeight > 1) { // Ensure reasonable size
540 | // Round coordinates to integer precision
541 | val roundedBox = RectF(
542 | round(scaledBox.left),
543 | round(scaledBox.top),
544 | round(scaledBox.right),
545 | round(scaledBox.bottom)
546 | )
547 |
548 | // Create offset box for NMS with class separation
549 | val nmsBox = RectF(
550 | roundedBox.left + classId * 7680f,
551 | roundedBox.top + classId * 7680f,
552 | roundedBox.right + classId * 7680f,
553 | roundedBox.bottom + classId * 7680f
554 | )
555 |
556 | nmsBoxes.add(nmsBox)
557 | boxes.add(roundedBox)
558 | confidences.add(maxScore)
559 | classIds.add(classId)
560 | } else {
561 | debug("Skipped detection with invalid dimensions: ${boxWidth}x${boxHeight}")
562 | }
563 | }
564 | }
565 |
566 | debug("Found ${boxes.size} raw detections before NMS")
567 |
568 | // Run NMS to eliminate redundant boxes
569 |             val selectedIndices = mutableListOf<Int>()
570 | nonMaxSuppression(nmsBoxes, confidences, confThreshold, iouThreshold, selectedIndices)
571 |
572 | debug("After NMS: ${selectedIndices.size} detections remaining")
573 |
574 | // Create final detection objects
575 | for (idx in selectedIndices) {
576 | val box = boxes[idx]
577 |
578 | // Calculate width and height from corners
579 | val width = box.right - box.left
580 | val height = box.bottom - box.top
581 |
582 | // Create detection object with proper dimensions
583 | val detection = Detection(
584 | BoundingBox(
585 | box.left.toInt(),
586 | box.top.toInt(),
587 | width.toInt(),
588 | height.toInt()
589 | ),
590 | confidences[idx],
591 | classIds[idx]
592 | )
593 |
594 | detections.add(detection)
595 | debug("Added detection: box=${detection.box.x},${detection.box.y},${detection.box.width},${detection.box.height}, " +
596 | "conf=${detection.conf}, class=${classIds[idx]}")
597 | }
598 | } catch (e: Exception) {
599 | debug("Error during postprocessing: ${e.message}")
600 | e.printStackTrace()
601 | }
602 |
603 | scopedTimer.stop()
604 | return detections
605 | }
606 |
607 | /**
608 | * Draws bounding boxes on the provided bitmap
609 | */
610 |     fun drawDetections(bitmap: Bitmap, detections: List<Detection>): Bitmap {
611 | val mutableBitmap = bitmap.copy(Bitmap.Config.ARGB_8888, true)
612 | val canvas = Canvas(mutableBitmap)
613 | val paint = Paint()
614 | paint.style = Paint.Style.STROKE
615 | paint.strokeWidth = max(bitmap.width, bitmap.height) * 0.004f
616 |
617 | val textPaint = Paint()
618 | textPaint.style = Paint.Style.FILL
619 | textPaint.textSize = max(bitmap.width, bitmap.height) * 0.02f
620 |
621 | // Filter detections to ensure quality results
622 | val filteredDetections = detections.filter {
623 | it.conf > CONFIDENCE_THRESHOLD &&
624 | it.classId >= 0 &&
625 | it.classId < classNames.size
626 | }
627 |
628 | for (detection in filteredDetections) {
629 | // Get color for this class
630 | val color = classColors[detection.classId % classColors.size]
631 | paint.color = Color.rgb(color[0], color[1], color[2])
632 |
633 | // Draw bounding box
634 | canvas.drawRect(
635 | detection.box.x.toFloat(),
636 | detection.box.y.toFloat(),
637 | (detection.box.x + detection.box.width).toFloat(),
638 | (detection.box.y + detection.box.height).toFloat(),
639 | paint
640 | )
641 |
642 | // Create label text
643 | val label = "${classNames[detection.classId]}: ${(detection.conf * 100).toInt()}%"
644 |
645 | // Measure text for background rectangle
646 | val textWidth = textPaint.measureText(label)
647 | val textHeight = textPaint.textSize
648 |
649 | // Define label position
650 | val labelY = max(detection.box.y.toFloat(), textHeight + 5f)
651 |
652 | // Draw background rectangle for text
653 | val bgPaint = Paint()
654 | bgPaint.color = Color.rgb(color[0], color[1], color[2])
655 | bgPaint.style = Paint.Style.FILL
656 |
657 | canvas.drawRect(
658 | detection.box.x.toFloat(),
659 | labelY - textHeight - 5f,
660 | detection.box.x.toFloat() + textWidth + 10f,
661 | labelY + 5f,
662 | bgPaint
663 | )
664 |
665 | // Draw text
666 | textPaint.color = Color.WHITE
667 | canvas.drawText(
668 | label,
669 | detection.box.x.toFloat() + 5f,
670 | labelY - 5f,
671 | textPaint
672 | )
673 | }
674 |
675 | return mutableBitmap
676 | }
677 |
678 | /**
679 | * Draws bounding boxes and semi-transparent masks on the provided bitmap
680 | */
681 |     fun drawDetectionsMask(bitmap: Bitmap, detections: List<Detection>, maskAlpha: Float = 0.4f): Bitmap {
682 | val mutableBitmap = bitmap.copy(Bitmap.Config.ARGB_8888, true)
683 | val width = bitmap.width
684 | val height = bitmap.height
685 |
686 | // Create a mask bitmap for overlay
687 | val maskBitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
688 | val maskCanvas = Canvas(maskBitmap)
689 |
690 | // Filter detections to ensure quality results
691 | val filteredDetections = detections.filter {
692 | it.conf > CONFIDENCE_THRESHOLD &&
693 | it.classId >= 0 &&
694 | it.classId < classNames.size
695 | }
696 |
697 | // Draw filled rectangles on mask bitmap
698 | for (detection in filteredDetections) {
699 | val color = classColors[detection.classId % classColors.size]
700 | val paint = Paint()
701 | paint.color = Color.argb(
702 | (255 * maskAlpha).toInt(),
703 | color[0],
704 | color[1],
705 | color[2]
706 | )
707 | paint.style = Paint.Style.FILL
708 |
709 | maskCanvas.drawRect(
710 | detection.box.x.toFloat(),
711 | detection.box.y.toFloat(),
712 | (detection.box.x + detection.box.width).toFloat(),
713 | (detection.box.y + detection.box.height).toFloat(),
714 | paint
715 | )
716 | }
717 |
718 | // Overlay mask on original image
719 | val canvas = Canvas(mutableBitmap)
720 | val paint = Paint()
721 | paint.alpha = (255 * maskAlpha).toInt()
722 | canvas.drawBitmap(maskBitmap, 0f, 0f, paint)
723 |
724 | // Draw bounding boxes and labels (reusing existing method but with full opacity)
725 | val mainCanvas = Canvas(mutableBitmap)
726 | val boxPaint = Paint()
727 | boxPaint.style = Paint.Style.STROKE
728 | boxPaint.strokeWidth = max(width, height) * 0.004f
729 |
730 | val textPaint = Paint()
731 | textPaint.textSize = max(width, height) * 0.02f
732 |
733 | for (detection in filteredDetections) {
734 | val color = classColors[detection.classId % classColors.size]
735 | boxPaint.color = Color.rgb(color[0], color[1], color[2])
736 |
737 | // Draw bounding box
738 | mainCanvas.drawRect(
739 | detection.box.x.toFloat(),
740 | detection.box.y.toFloat(),
741 | (detection.box.x + detection.box.width).toFloat(),
742 | (detection.box.y + detection.box.height).toFloat(),
743 | boxPaint
744 | )
745 |
746 | // Create and draw label
747 | val label = "${classNames[detection.classId]}: ${(detection.conf * 100).toInt()}%"
748 | val textWidth = textPaint.measureText(label)
749 | val textHeight = textPaint.textSize
750 |
751 | val labelY = max(detection.box.y.toFloat(), textHeight + 5f)
752 |
753 | val bgPaint = Paint()
754 | bgPaint.color = Color.rgb(color[0], color[1], color[2])
755 | bgPaint.style = Paint.Style.FILL
756 |
757 | mainCanvas.drawRect(
758 | detection.box.x.toFloat(),
759 | labelY - textHeight - 5f,
760 | detection.box.x.toFloat() + textWidth + 10f,
761 | labelY + 5f,
762 | bgPaint
763 | )
764 |
765 | textPaint.color = Color.WHITE
766 | mainCanvas.drawText(
767 | label,
768 | detection.box.x.toFloat() + 5f,
769 | labelY - 5f,
770 | textPaint
771 | )
772 | }
773 |
774 | // Clean up
775 | maskBitmap.recycle()
776 |
777 | return mutableBitmap
778 | }
779 |
780 | /**
781 | * Loads class names from a file
782 | */
783 |     private fun loadClassNames(labelsPath: String): List<String> {
784 | return context.assets.open(labelsPath).bufferedReader().useLines {
785 | it.map { line -> line.trim() }.filter { it.isNotEmpty() }.toList()
786 | }
787 | }
788 |
789 | /**
790 | * Generate colors for visualization
791 | */
792 |     private fun generateColors(numClasses: Int): List<IntArray> {
793 |         val colors = mutableListOf<IntArray>()
794 | val random = Random(42) // Fixed seed for reproducibility
795 |
796 | for (i in 0 until numClasses) {
797 | val color = intArrayOf(
798 | random.nextInt(256), // R
799 | random.nextInt(256), // G
800 | random.nextInt(256) // B
801 | )
802 | colors.add(color)
803 | }
804 |
805 | return colors
806 | }
807 |
808 | /**
809 | * Get class name for a given class ID
810 | * @param classId The class ID to get the name for
811 | * @return The class name or "Unknown" if the ID is invalid
812 | */
813 | fun getClassName(classId: Int): String {
814 | return if (classId >= 0 && classId < classNames.size) {
815 | classNames[classId]
816 | } else {
817 | "Unknown"
818 | }
819 | }
820 |
821 | /**
822 | * Get details about the model's input requirements
823 | * @return String containing shape and data type information
824 | */
825 | fun getInputDetails(): String {
826 | val inputTensor = interpreter.getInputTensor(0)
827 | val shape = inputTensor.shape()
828 | val type = when(inputTensor.dataType()) {
829 | org.tensorflow.lite.DataType.FLOAT32 -> "FLOAT32"
830 | org.tensorflow.lite.DataType.UINT8 -> "UINT8"
831 | else -> "OTHER"
832 | }
833 | return "Shape: ${shape.joinToString()}, Type: $type"
834 | }
835 |
836 | /**
837 | * Cleanup resources when no longer needed
838 | */
839 | fun close() {
840 | try {
841 | interpreter.close()
842 | debug("TFLite interpreter closed")
843 | } catch (e: Exception) {
844 | debug("Error closing interpreter: ${e.message}")
845 | }
846 |
847 | try {
848 | gpuDelegate?.close()
849 | debug("GPU delegate resources released")
850 | } catch (e: Exception) {
851 | debug("Error closing GPU delegate: ${e.message}")
852 | }
853 |
854 | gpuDelegate = null
855 | }
856 |
857 | /**
858 | * Data classes for detections and bounding boxes
859 | */
860 | data class BoundingBox(val x: Int, val y: Int, val width: Int, val height: Int)
861 |
862 | data class Detection(val box: BoundingBox, val conf: Float, val classId: Int)
863 |
864 | /**
865 | * Helper functions
866 | */
867 |
868 | /**
869 | * Letterbox an image to fit a specific size while maintaining aspect ratio
870 | * Fixed padding calculation to ensure consistent vertical alignment
871 | */
872 | private fun letterBox(
873 | image: Mat,
874 | outImage: Mat,
875 | newShape: Size,
876 | color: Scalar = Scalar(114.0, 114.0, 114.0),
877 | auto: Boolean = true,
878 | scaleFill: Boolean = false,
879 | scaleUp: Boolean = true,
880 | stride: Int = 32
881 | ) {
882 | val originalShape = Size(image.cols().toDouble(), image.rows().toDouble())
883 |
884 | // Calculate ratio to fit the image within new shape
885 | var ratio = min(
886 | newShape.height / originalShape.height,
887 | newShape.width / originalShape.width
888 | ).toFloat()
889 |
890 | // Prevent scaling up if not allowed
891 | if (!scaleUp) {
892 | ratio = min(ratio, 1.0f)
893 | }
894 |
895 | // Calculate new unpadded dimensions
896 | val newUnpadW = round(originalShape.width * ratio).toInt()
897 | val newUnpadH = round(originalShape.height * ratio).toInt()
898 |
899 | // Calculate padding
900 | val dw = (newShape.width - newUnpadW).toFloat()
901 | val dh = (newShape.height - newUnpadH).toFloat()
902 |
903 | // Calculate padding distribution
904 | val padLeft: Int
905 | val padRight: Int
906 | val padTop: Int
907 | val padBottom: Int
908 |
909 | if (auto) {
910 | // Auto padding aligned to stride
911 | val dwHalf = ((dw % stride) / 2).toFloat()
912 | val dhHalf = ((dh % stride) / 2).toFloat()
913 |
914 | padLeft = (dw / 2 - dwHalf).toInt()
915 | padRight = (dw / 2 + dwHalf).toInt()
916 | padTop = (dh / 2 - dhHalf).toInt()
917 | padBottom = (dh / 2 + dhHalf).toInt()
918 | } else if (scaleFill) {
919 | // Scale to fill without maintaining aspect ratio
920 | padLeft = 0
921 | padRight = 0
922 | padTop = 0
923 | padBottom = 0
924 | Imgproc.resize(image, outImage, newShape)
925 | return
926 | } else {
927 | // Even padding on all sides
928 | padLeft = (dw / 2).toInt()
929 | padRight = (dw - padLeft).toInt()
930 | padTop = (dh / 2).toInt()
931 | padBottom = (dh - padTop).toInt()
932 | }
933 |
934 | // Log detailed padding information
935 | debug("Letterbox: original=${originalShape.width}x${originalShape.height}, " +
936 | "new=${newUnpadW}x${newUnpadH}, ratio=$ratio")
937 | debug("Letterbox: padding left=$padLeft, right=$padRight, top=$padTop, bottom=$padBottom")
938 |
939 | // Resize the image to fit within the new dimensions
940 | Imgproc.resize(
941 | image,
942 | outImage,
943 | Size(newUnpadW.toDouble(), newUnpadH.toDouble()),
944 | 0.0, 0.0,
945 | Imgproc.INTER_LINEAR
946 | )
947 |
948 | // Apply padding to create letterboxed image
949 | Core.copyMakeBorder(
950 | outImage,
951 | outImage,
952 | padTop,
953 | padBottom,
954 | padLeft,
955 | padRight,
956 | Core.BORDER_CONSTANT,
957 | color
958 | )
959 | }
960 |
961 | /**
962 | * Scale coordinates from model input size to original image size
963 | * Fixed vertical positioning issue with letterboxed images
964 | */
965 | private fun scaleCoords(
966 | imageShape: Size,
967 | coords: RectF,
968 | imageOriginalShape: Size,
969 | clip: Boolean = true
970 | ): RectF {
971 | // Get dimensions in pixels
972 | val inputWidth = imageShape.width.toFloat()
973 | val inputHeight = imageShape.height.toFloat()
974 | val originalWidth = imageOriginalShape.width.toFloat()
975 | val originalHeight = imageOriginalShape.height.toFloat()
976 |
977 | // Calculate scaling factor (ratio) between original and input sizes
978 | val gain = min(inputWidth / originalWidth, inputHeight / originalHeight)
979 |
980 | // Calculate padding needed for letterboxing
981 | val padX = (inputWidth - originalWidth * gain) / 2.0f
982 | val padY = (inputHeight - originalHeight * gain) / 2.0f
983 |
984 | // Debug dimensions
985 | debug("Scale coords: input=${inputWidth}x${inputHeight}, original=${originalWidth}x${originalHeight}")
986 | debug("Scale coords: gain=$gain, padding=($padX, $padY)")
987 | debug("Scale coords: input normalized=(${coords.left}, ${coords.top}, ${coords.right}, ${coords.bottom})")
988 |
989 | // Convert normalized coordinates [0-1] to absolute pixel coordinates
990 | val absLeft = coords.left * inputWidth
991 | val absTop = coords.top * inputHeight
992 | val absRight = coords.right * inputWidth
993 | val absBottom = coords.bottom * inputHeight
994 |
995 | debug("Scale coords: absolute pixels=($absLeft, $absTop, $absRight, $absBottom)")
996 |
997 | // Remove padding and scale back to original image dimensions
998 | val x1 = (absLeft - padX) / gain
999 | val y1 = (absTop - padY) / gain
1000 | val x2 = (absRight - padX) / gain
1001 | val y2 = (absBottom - padY) / gain
1002 |
1003 | debug("Scale coords: output original=($x1, $y1, $x2, $y2)")
1004 |
1005 | // Create result rectangle
1006 | val result = RectF(x1, y1, x2, y2)
1007 |
1008 | // Clip to image boundaries if requested
1009 | if (clip) {
1010 | result.left = max(0f, min(result.left, originalWidth))
1011 | result.top = max(0f, min(result.top, originalHeight))
1012 | result.right = max(0f, min(result.right, originalWidth))
1013 | result.bottom = max(0f, min(result.bottom, originalHeight))
1014 | }
1015 |
1016 | return result
1017 | }
1018 |
1019 | /**
1020 | * Clamp a value between min and max
1021 | */
1022 | private fun clamp(value: Float, min: Float, max: Float): Float {
1023 | return when {
1024 | value < min -> min
1025 | value > max -> max
1026 | else -> value
1027 | }
1028 | }
1029 |
1030 | /**
1031 | * Non-Maximum Suppression implementation to filter redundant boxes
1032 | * Updated to exactly match the C++ implementation
1033 | */
1034 | private fun nonMaxSuppression(
1035 |         boxes: List<RectF>,
1036 |         scores: List<Float>,
1037 | scoreThreshold: Float,
1038 | iouThreshold: Float,
1039 |         indices: MutableList<Int>
1040 | ) {
1041 | indices.clear()
1042 |
1043 | // Early return if no boxes
1044 | if (boxes.isEmpty()) {
1045 | return
1046 | }
1047 |
1048 | // Create list of indices sorted by score (highest first)
1049 | val sortedIndices = boxes.indices
1050 | .filter { scores[it] >= scoreThreshold }
1051 | .sortedByDescending { scores[it] }
1052 |
1053 | if (sortedIndices.isEmpty()) {
1054 | return
1055 | }
1056 |
1057 | // Calculate areas once
1058 | val areas = boxes.map { (it.right - it.left) * (it.bottom - it.top) }
1059 |
1060 | // Suppression mask
1061 | val suppressed = BooleanArray(boxes.size) { false }
1062 |
1063 | // Process boxes in order of decreasing score
1064 | for (i in sortedIndices.indices) {
1065 | val currentIdx = sortedIndices[i]
1066 |
1067 | if (suppressed[currentIdx]) {
1068 | continue
1069 | }
1070 |
1071 | // Add current box to valid detections
1072 | indices.add(currentIdx)
1073 |
1074 | // Get current box coordinates
1075 | val currentBox = boxes[currentIdx]
1076 | val x1Max = currentBox.left
1077 | val y1Max = currentBox.top
1078 | val x2Max = currentBox.right
1079 | val y2Max = currentBox.bottom
1080 | val areaCurrent = areas[currentIdx]
1081 |
1082 | // Compare with remaining boxes
1083 | for (j in i + 1 until sortedIndices.size) {
1084 | val compareIdx = sortedIndices[j]
1085 |
1086 | if (suppressed[compareIdx]) {
1087 | continue
1088 | }
1089 |
1090 | // Calculate intersection
1091 | val compareBox = boxes[compareIdx]
1092 | val x1 = max(x1Max, compareBox.left)
1093 | val y1 = max(y1Max, compareBox.top)
1094 | val x2 = min(x2Max, compareBox.right)
1095 | val y2 = min(y2Max, compareBox.bottom)
1096 |
1097 | val interWidth = max(0f, x2 - x1)
1098 | val interHeight = max(0f, y2 - y1)
1099 |
1100 | if (interWidth <= 0 || interHeight <= 0) {
1101 | continue
1102 | }
1103 |
1104 | val intersection = interWidth * interHeight
1105 | val unionArea = areaCurrent + areas[compareIdx] - intersection
1106 | val iou = if (unionArea > 0) intersection / unionArea else 0f
1107 |
1108 | // Suppress if IoU exceeds threshold
1109 | if (iou > iouThreshold) {
1110 | suppressed[compareIdx] = true
1111 | }
1112 | }
1113 | }
1114 | }
1115 |
1116 | /**
1117 | * Debug print function with enhanced logging
1118 | */
1119 | private fun debug(message: String) {
1120 | Log.d(TAG, message)
1121 | if (BuildConfig.DEBUG) {
1122 | println("YOLO11Detector: $message")
1123 | }
1124 | }
1125 |
1126 |     // Lightweight scoped timer used for coarse performance measurements
1127 | private class ScopedTimer(private val name: String) {
1128 | private val startTime = SystemClock.elapsedRealtime()
1129 |
1130 | fun stop() {
1131 | val endTime = SystemClock.elapsedRealtime()
1132 | // debug("$name took ${endTime - startTime} ms")
1133 | }
1134 | }
1135 | }
1136 |
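
For reference, a minimal usage sketch of the detector above. Only detect(), drawDetections() and close() appear in this file; the constructor arguments shown are assumptions based on the initialization code, and OpenCV's native library must already be loaded before calling detect() (see the OpenCV module notes further below).

// Hypothetical usage sketch -- constructor parameters are assumed, not verified against the full class.
import android.app.Activity
import android.graphics.Bitmap

fun runDetectorOnce(activity: Activity, input: Bitmap): Bitmap {
    val detector = YOLO11Detector(
        /* context    = */ activity,
        /* modelPath  = */ "best_float32.tflite",   // assumed asset name
        /* labelsPath = */ "classes.txt",           // assumed asset name
        /* useGpu     = */ false                    // assumed flag; CPU path is configured in configureCpuOptions()
    )
    return try {
        val detections = detector.detect(input)          // default confidence/IoU thresholds
        detector.drawDetections(input, detections)       // returns an annotated copy of the bitmap
    } finally {
        detector.close()                                 // releases the interpreter and any GPU delegate
    }
}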
--------------------------------------------------------------------------------
/src/kotlin/activity_main.xml:
--------------------------------------------------------------------------------
1 |
2 |
9 |
10 |
20 |
21 |
29 |
30 |
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/src/kotlin/build.gradle:
--------------------------------------------------------------------------------
1 | // This file is part of OpenCV project.
2 | // It is subject to the license terms in the LICENSE file found in the top-level directory
3 | // of this distribution and at http://opencv.org/license.html.
4 |
5 | //
6 | // Notes about integrating OpenCV into an existing Android Studio application project are below (an application 'app' module should exist).
7 | //
8 | // This file is located in /sdk directory (near 'etc', 'java', 'native' subdirectories)
9 | //
10 | // Add module into Android Studio application project:
11 | //
12 | // - Android Studio way:
13 | // (will copy almost all OpenCV Android SDK into your project, ~200Mb)
14 | //
15 | // Import module: Menu -> "File" -> "New" -> "Module" -> "Import Gradle project":
16 | // Source directory: select this "sdk" directory
17 | // Module name: ":opencv"
18 | //
19 | // - or attach library module from OpenCV Android SDK
20 | //     (without copying into the application project directory; allows sharing the same module between projects)
21 | //
22 | // Edit "settings.gradle" and add these lines:
23 | //
24 | //   def opencvsdk='<path_to_opencv_android_sdk_rootdir>'
25 | //   // You can put the declaration above into a gradle.properties file instead (including a file in your HOME directory),
26 | //   // but without 'def' and apostrophe symbols ('): opencvsdk=<path_to_opencv_android_sdk_rootdir>
27 | // include ':opencv'
28 | // project(':opencv').projectDir = new File(opencvsdk + '/sdk')
29 | //
30 | //
31 | //
32 | // Add dependency into application module:
33 | //
34 | // - Android Studio way:
35 | // "Open Module Settings" (F4) -> "Dependencies" tab
36 | //
37 | // - or add "project(':opencv')" dependency into app/build.gradle:
38 | //
39 | // dependencies {
40 | // implementation fileTree(dir: 'libs', include: ['*.jar'])
41 | // ...
42 | // implementation project(':opencv')
43 | // }
44 | //
45 | //
46 | //
47 | // Load OpenCV native library before using:
48 | //
49 | //  - avoid the "OpenCVLoader.initAsync()" approach - it is deprecated
50 | //    It may load a library with a different version (from OpenCV Android Manager, which is installed separately on the device)
51 | //
52 | // - use "System.loadLibrary("opencv_java4")" or "OpenCVLoader.initDebug()"
53 | // TODO: Add accurate API to load OpenCV native library
54 | //
55 | //
56 | //
57 | // Native C++ support (necessary to use OpenCV in native code of application only):
58 | //
59 | // - Use find_package() in app/CMakeLists.txt:
60 | //
61 | // find_package(OpenCV 4.11 REQUIRED java)
62 | // ...
63 | // target_link_libraries(native-lib ${OpenCV_LIBRARIES})
64 | //
65 | // - Add "OpenCV_DIR" and enable C++ exceptions/RTTI support via app/build.gradle
66 | // Documentation about CMake options: https://developer.android.com/ndk/guides/cmake.html
67 | //
68 | // defaultConfig {
69 | // ...
70 | // externalNativeBuild {
71 | // cmake {
72 | // cppFlags "-std=c++11 -frtti -fexceptions"
73 | // arguments "-DOpenCV_DIR=" + opencvsdk + "/sdk/native/jni" // , "-DANDROID_ARM_NEON=TRUE"
74 | // }
75 | // }
76 | // }
77 | //
78 | // - (optional) Limit/filter ABIs to build ('android' scope of 'app/build.gradle'):
79 | // Useful information: https://developer.android.com/studio/build/gradle-tips.html (Configure separate APKs per ABI)
80 | //
81 | // splits {
82 | // abi {
83 | // enable true
84 | // universalApk false
85 | // reset()
86 | // include 'armeabi-v7a' // , 'x86', 'x86_64', 'arm64-v8a'
87 | // }
88 | // }
89 | //
90 |
91 | apply plugin: 'com.android.library'
92 | apply plugin: 'maven-publish'
93 | try {
94 | // apply plugin: 'kotlin-android'
95 | println "Configure OpenCV with Kotlin"
96 | } catch (Exception e) {
97 | println "Configure OpenCV without Kotlin"
98 | }
99 |
100 | def openCVersionName = "4.11.0"
101 | def openCVersionCode = ((4 * 100 + 11) * 100 + 0) * 10 + 0
102 |
103 | println "OpenCV: " +openCVersionName + " " + project.buildscript.sourceFile
104 |
105 | android {
106 | namespace 'org.opencv'
107 | compileSdkVersion 34
108 |
109 | defaultConfig {
110 | minSdkVersion 21
111 | targetSdkVersion 34
112 |
113 | versionCode openCVersionCode
114 | versionName openCVersionName
115 |
116 | externalNativeBuild {
117 | cmake {
118 | arguments "-DANDROID_STL=c++_shared"
119 | targets "opencv_jni_shared"
120 | }
121 | }
122 | }
123 |
124 | android {
125 | buildFeatures {
126 | buildConfig true
127 | }
128 | }
129 | compileOptions {
130 | sourceCompatibility JavaVersion.VERSION_17
131 | targetCompatibility JavaVersion.VERSION_17
132 | }
133 |
134 | buildTypes {
135 | debug {
136 | packagingOptions {
137 | doNotStrip '**/*.so' // controlled by OpenCV CMake scripts
138 | }
139 | }
140 | release {
141 | packagingOptions {
142 | doNotStrip '**/*.so' // controlled by OpenCV CMake scripts
143 | }
144 | minifyEnabled false
145 | proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.txt'
146 | }
147 | }
148 |
149 | sourceSets {
150 | main {
151 | jniLibs.srcDirs = ['native/libs']
152 | java.srcDirs = ['java/src']
153 | res.srcDirs = ['java/res']
154 | manifest.srcFile 'java/AndroidManifest.xml'
155 | }
156 | }
157 |
158 | externalNativeBuild {
159 | cmake {
160 | path (project.projectDir.toString() + '/libcxx_helper/CMakeLists.txt')
161 | }
162 | }
163 |
164 | buildFeatures {
165 | prefabPublishing true
166 | buildConfig true
167 | }
168 |
169 | prefab {
170 | opencv_jni_shared {
171 | headers 'native/jni/include'
172 | }
173 | }
174 |
175 | publishing {
176 | singleVariant('release') {
177 | withSourcesJar()
178 | withJavadocJar()
179 | }
180 | }
181 |
182 | }
183 |
184 | publishing {
185 | publications {
186 | release(MavenPublication) {
187 | groupId = 'org.opencv'
188 | artifactId = 'opencv'
189 | version = '4.11.0'
190 |
191 | afterEvaluate {
192 | from components.release
193 | }
194 | }
195 | }
196 | repositories {
197 | maven {
198 | name = 'myrepo'
199 | url = "${project.buildDir}/repo"
200 | }
201 | }
202 | }
203 |
204 | dependencies {
205 | }
206 |
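
The notes at the top of this file recommend loading OpenCV's native library explicitly (System.loadLibrary("opencv_java4") or OpenCVLoader.initDebug()) rather than the deprecated initAsync(). A minimal Kotlin sketch of doing this once at application start; the Application subclass name here is a placeholder:

import android.app.Application
import android.util.Log
import org.opencv.android.OpenCVLoader

class YoloApplication : Application() {   // hypothetical Application subclass
    override fun onCreate() {
        super.onCreate()
        // Either approach from the notes above; initDebug() loads the library bundled with the module.
        val loaded = OpenCVLoader.initDebug()
        // Alternative direct load: System.loadLibrary("opencv_java4")
        Log.d("YoloApplication", "OpenCV native library loaded: $loaded")
    }
}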
--------------------------------------------------------------------------------
/src/kotlin/build.gradle.kts:
--------------------------------------------------------------------------------
1 | plugins {
2 | alias(libs.plugins.android.application)
3 | alias(libs.plugins.kotlin.android)
4 | }
5 |
6 | android {
7 | namespace = "com.example.opencv_tutorial"
8 | compileSdk = 35
9 |
10 | defaultConfig {
11 | applicationId = "com.example.opencv_tutorial"
12 | minSdk = 24
13 | targetSdk = 35
14 | versionCode = 1
15 | versionName = "1.0"
16 |
17 | testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
18 |
19 | // Add NDK ABI filters to ensure compatibility
20 | ndk {
21 | abiFilters.addAll(listOf("armeabi-v7a", "arm64-v8a", "x86", "x86_64"))
22 | }
23 | }
24 |
25 | buildTypes {
26 | release {
27 | isMinifyEnabled = false
28 | proguardFiles(
29 | getDefaultProguardFile("proguard-android-optimize.txt"),
30 | "proguard-rules.pro"
31 | )
32 | }
33 | debug {
34 | isDebuggable = true
35 | // Enable more detailed native logging for debugging
36 | buildConfigField("boolean", "ENABLE_DETAILED_LOGGING", "true")
37 | }
38 | }
39 | compileOptions {
40 | sourceCompatibility = JavaVersion.VERSION_11
41 | targetCompatibility = JavaVersion.VERSION_11
42 | }
43 | kotlinOptions {
44 | jvmTarget = "11"
45 | }
46 | buildFeatures {
47 | compose = true
48 | }
49 | composeOptions {
50 | kotlinCompilerExtensionVersion = "1.5.1"
51 | }
52 |
53 | packaging {
54 | resources {
55 | excludes += "/META-INF/{AL2.0,LGPL2.1}"
56 | // Avoid duplicate library files
57 | pickFirst("**/libc++_shared.so")
58 | pickFirst("**/libOpenCL.so")
59 | }
60 | jniLibs {
61 | useLegacyPackaging = true // Helps with native lib compatibility
62 | }
63 | }
64 |
65 | // Add for better compatibility with native libraries
66 | ndkVersion = "21.4.7075529" // Use a stable NDK version
67 | }
68 |
69 | dependencies {
70 |
71 | implementation(libs.androidx.core.ktx)
72 | implementation(libs.androidx.appcompat)
73 | implementation(libs.material)
74 | implementation(libs.androidx.activity)
75 | implementation(libs.androidx.constraintlayout)
76 | implementation(project(":sdk"))
77 | testImplementation(libs.junit)
78 | androidTestImplementation(libs.androidx.junit)
79 | androidTestImplementation(libs.androidx.espresso.core)
80 |
81 |     // TODO: pin to a specific onnxruntime-android version instead of latest.release
82 | implementation("com.microsoft.onnxruntime:onnxruntime-android:latest.release")
83 |
84 | // Other dependencies...
85 | implementation("androidx.compose.ui:ui:1.5.1")
86 | implementation("androidx.compose.material:material:1.5.1")
87 | implementation("androidx.compose.ui:ui-tooling-preview:1.5.1")
88 | implementation("androidx.activity:activity-compose:1.7.2")
89 | debugImplementation("androidx.compose.ui:ui-tooling:1.5.1")
90 | implementation("org.tensorflow:tensorflow-lite:2.9.0")
91 | implementation("org.tensorflow:tensorflow-lite-task-vision:0.4.2")
92 | implementation("org.tensorflow:tensorflow-lite-gpu:2.9.0")
93 | implementation("org.tensorflow:tensorflow-lite-support:0.4.2")
94 |
95 | // Add metadata extractor for better model information
96 | implementation("org.tensorflow:tensorflow-lite-metadata:0.4.2")
97 | }
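
The dependency block above pulls in project(":sdk"), i.e. the OpenCV library module whose build.gradle is shown earlier. Following the integration notes in that file, the settings script needs to include that module and point it at the OpenCV Android SDK; a sketch in Gradle Kotlin DSL, where the module name ":sdk" matches the dependency above and the SDK path is a placeholder:

// settings.gradle.kts (sketch; point opencvsdk at your OpenCV Android SDK root)
import java.io.File

val opencvsdk = "/path/to/OpenCV-android-sdk"    // placeholder path
include(":sdk")
project(":sdk").projectDir = File("$opencvsdk/sdk")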
--------------------------------------------------------------------------------
/src/kotlin/res/layout/activity_model_parse.xml:
--------------------------------------------------------------------------------
1 |
2 |
9 |
10 |
14 |
15 |
22 |
23 |
24 |
25 |
--------------------------------------------------------------------------------
/src/output.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/output.mp4
--------------------------------------------------------------------------------
/src/output/base_simplify.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/output/base_simplify.onnx
--------------------------------------------------------------------------------
/src/output/t1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/output/t1.mp4
--------------------------------------------------------------------------------
/src/output/yolo_cli_pt.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/output/yolo_cli_pt.mp4
--------------------------------------------------------------------------------
/src/output/yolov11_cpp_onnx.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/output/yolov11_cpp_onnx.mp4
--------------------------------------------------------------------------------
/src/runs/detect/predict2/t1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/runs/detect/predict2/t1.mp4
--------------------------------------------------------------------------------
/src/runs/detect/predict3/t1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/runs/detect/predict3/t1.mp4
--------------------------------------------------------------------------------
/src/t1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/t1.mp4
--------------------------------------------------------------------------------
/src/viewer.cpp:
--------------------------------------------------------------------------------
1 | #include <opencv2/opencv.hpp>
2 | #include <iostream>
3 |
4 | using namespace cv;
5 |
6 | int main(int argc, char const *argv[])
7 | {
8 |     const std::string videoSource = "./input.mov"; // video file path (replace with a camera index/device to read from a USB cam)
9 |
10 |
11 | cv::VideoCapture cap;
12 |
13 |     // open the video source with the FFmpeg backend
14 | cap.open(videoSource, cv::CAP_FFMPEG);
15 | if (!cap.isOpened())
16 | {
17 | std::cerr << "Error: Could not open the camera!\n";
18 | return -1;
19 | }
20 |
21 | for(;;)
22 | {
23 | cv::Mat frame;
24 | cap >> frame;
25 | if (frame.empty())
26 | {
27 | std::cerr << "Error: Could not read a frame!\n";
28 | break;
29 | }
30 |
31 | // Display the frame
32 | cv::imshow("input", frame);
33 |
34 | if (cv::waitKey(1) >= 0)
35 | {
36 | break;
37 | }
38 |
39 | }
40 |
41 |
42 | return 0;
43 | }
44 |
--------------------------------------------------------------------------------
/src/viewer.out:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielSarmiento04/yolov11cpp/c0690429b302c0b8a283a900ee30b89152019909/src/viewer.out
--------------------------------------------------------------------------------