├── .gitattributes ├── LICENSE ├── README.md ├── data_reader.py ├── imagenet_val ├── clock.jpg └── input.jpg ├── models ├── LICENSE_MOBILENETV2 ├── LICENSE_YOLOV3.txt ├── LICENSE_YOLOX ├── mobilenetv2_1.0.opt.onnx ├── mobilenev2_quantized.onnx ├── yolov3-tiny.opt.onnx ├── yolov3-tiny_quantized_per_tensor.onnx ├── yolox_tiny.opt.onnx ├── yolox_tiny_quantized_per_channel.onnx └── yolox_tiny_quantized_per_tensor.onnx ├── quantize.py └── test ├── input.jpg ├── util ├── __init__.py ├── classifier_utils.py ├── detector_utils.py ├── image_utils.py ├── log_init.py ├── math_utils.py ├── microphone_utils.py ├── model_utils.py ├── nms_utils.py ├── params.py ├── utils.py └── webcamera_utils.py ├── yolox.py └── yolox_utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.onnx filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ax Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # onnx-quantization 2 | 3 | This is an example of quantizing an ONNX model: the input is a float ONNX model, quantization is performed with ONNX Runtime, and the output is an int8 ONNX model. 4 | 5 | ## Requirements 6 | 7 | - onnxruntime 1.13.1 8 | - onnx 1.13.0 9 | 10 | ## Architecture 11 | 12 | ```mermaid 13 | classDiagram 14 | `ONNX (int8)` <|-- `ONNX Runtime` : Quantized model 15 | `ONNX Runtime` <|-- `ONNX (float)` : Input model 16 | `ONNX Runtime` <|-- `Images` : Calibration images 17 | `ONNX Runtime` : quantize_static API 18 | `ONNX (float)` : Float model 19 | `ONNX (int8)` : Int8 model 20 | `Images` : Images 21 | ``` 22 | 23 | ## Calibration images 24 | 25 | By default only the 2 bundled images are used for calibration, which limits quantization accuracy. 26 | To improve accuracy, place ImageNet validation images in the imagenet_val folder, or point --calibrate_dataset at a folder of coco2017 images 27 | (e.g. --calibrate_dataset E:/git/ailia-models-measurement/object_detection/data/coco2017/images). 28 | 29 | ## Quantization command 30 | 31 | Quantization can be performed with the commands below; a minimal sketch of the underlying Python API call comes first.
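Internally, `quantize.py` drives ONNX Runtime's `quantize_static` API with the calibration `DataReader` defined in `data_reader.py`. A minimal sketch of that call (the model paths here are only examples; the per-model commands below set the real ones):

```python
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static

import data_reader

# Example paths -- substitute the model you want to quantize.
input_model = "./models/yolox_tiny.opt.onnx"
output_model = "./models/yolox_tiny_quantized_per_tensor.onnx"

# DataReader loads and preprocesses the calibration images for this model.
dr = data_reader.DataReader("imagenet_val", input_model)

quantize_static(
    input_model,
    output_model,
    dr,
    quant_format=QuantFormat.QDQ,   # quantize/dequantize (QDQ) format
    per_channel=False,              # set True for per-channel quantization
    weight_type=QuantType.QInt8,
    optimize_model=False,           # keep graph optimization off during quantization
)
```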
32 | 33 | MobileNetV2 (1-input model) 34 | 35 | ``` 36 | python3 quantize.py --input_model ./models/mobilenetv2_1.0.opt.onnx --output_model ./models/mobilenet_quantized.onnx --calibrate_dataset imagenet_val --per_channel True 37 | ``` 38 | 39 | YOLOv3 Tiny (4-input model) 40 | 41 | ``` 42 | python3 quantize.py --input_model ./models/yolov3-tiny.opt.onnx --output_model ./models/yolov3-tiny_quantized_per_tensor.onnx --calibrate_dataset imagenet_val 43 | ``` 44 | 45 | YOLOX Tiny (1-input model) 46 | 47 | ``` 48 | python3 quantize.py --input_model ./models/yolox_tiny.opt.onnx --output_model ./models/yolox_tiny_quantized_per_tensor.onnx --calibrate_dataset imagenet_val 49 | python3 quantize.py --input_model ./models/yolox_tiny.opt.onnx --output_model ./models/yolox_tiny_quantized_per_channel.onnx --calibrate_dataset imagenet_val --per_channel True 50 | ``` 51 | 52 | ## Test 53 | 54 | Inference with the quantized YOLOX model can be run with the following commands. 55 | 56 | ``` 57 | cd test 58 | python3 yolox.py 59 | ``` 60 | 61 | ## Output 62 | 63 | - [mobilenev2_quantized.onnx](./models/mobilenev2_quantized.onnx) 64 | - [yolov3-tiny_quantized_per_tensor.onnx](./models/yolov3-tiny_quantized_per_tensor.onnx) 65 | - [yolox_tiny_quantized_per_channel.onnx](./models/yolox_tiny_quantized_per_channel.onnx) 66 | - [yolox_tiny_quantized_per_tensor.onnx](./models/yolox_tiny_quantized_per_tensor.onnx) 67 | 68 | ## Limitation 69 | 70 | Per-Channel support with QDQ format requires onnx opset version 13 or above. 71 | 72 | ## Reference 73 | 74 | - [Official sample](https://onnxruntime.ai/docs/performance/quantization.html) 75 | - [Official document](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/quantization/image_classification/cpu) 76 | -------------------------------------------------------------------------------- /data_reader.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import onnxruntime 3 | import os 4 | from onnxruntime.quantization import CalibrationDataReader 5 | from PIL import Image 6 | 7 | 8 | def _preprocess_images(images_folder: str, height: int, width: int, size_limit=0, model_path: str=None): 9 | """ 10 | Loads a batch of images and preprocesses them 11 | parameter images_folder: path to folder storing images 12 | parameter height: image height in pixels 13 | parameter width: image width in pixels 14 | parameter size_limit: number of images to load. Default is 0 which means all images are picked.
15 | return: list of matrices characterizing multiple images 16 | """ 17 | image_names = os.listdir(images_folder) 18 | if size_limit > 0 and len(image_names) >= size_limit: 19 | batch_filenames = [image_names[i] for i in range(size_limit)] 20 | else: 21 | batch_filenames = image_names 22 | unconcatenated_batch_data = [] 23 | 24 | for image_name in batch_filenames: 25 | image_filepath = images_folder + "/" + image_name 26 | if "yolox" in model_path: 27 | pillow_img = Image.new("RGB", (width, height)) 28 | pillow_img.paste(Image.open(image_filepath).resize((width, height))) 29 | input_data = numpy.float32(pillow_img) # 0 - 255 30 | input_data = input_data[:,:,::-1] # RGB -> BGR 31 | elif "yolov3" in model_path: 32 | pillow_img = Image.new("RGB", (width, height)) # RGB 33 | pillow_img.paste(Image.open(image_filepath).resize((width, height))) 34 | input_data = numpy.float32(pillow_img) / 255.0 # 0 - 1 35 | else: 36 | pillow_img = Image.new("RGB", (width, height)) # RGB 37 | pillow_img.paste(Image.open(image_filepath).resize((width, height))) 38 | input_data = numpy.float32(pillow_img) - numpy.array( 39 | [123.68, 116.78, 103.94], dtype=numpy.float32 40 | ) # -128 - 127 41 | nhwc_data = numpy.expand_dims(input_data, axis=0) 42 | nchw_data = nhwc_data.transpose(0, 3, 1, 2) # ONNX Runtime standard 43 | unconcatenated_batch_data.append(nchw_data) 44 | batch_data = numpy.concatenate( 45 | numpy.expand_dims(unconcatenated_batch_data, axis=0), axis=0 46 | ) 47 | return batch_data 48 | 49 | 50 | class DataReader(CalibrationDataReader): 51 | def __init__(self, calibration_image_folder: str, model_path: str): 52 | self.enum_data = None 53 | 54 | # Use inference session to get input shape. 55 | session = onnxruntime.InferenceSession(model_path, None) 56 | (_, _, height, width) = session.get_inputs()[0].shape 57 | if "yolov3" in model_path: 58 | width = 416 59 | height = 416 60 | 61 | # Convert image to input data 62 | self.nhwc_data_list = _preprocess_images( 63 | calibration_image_folder, height, width, size_limit=0, model_path = model_path 64 | ) 65 | self.input_name = session.get_inputs()[0].name 66 | self.shape_name = None 67 | if "yolov3" in model_path: 68 | self.shape_name = session.get_inputs()[1].name 69 | self.width = width 70 | self.height = height 71 | self.iou_name = session.get_inputs()[2].name 72 | self.threshold_name = session.get_inputs()[3].name 73 | self.datasize = len(self.nhwc_data_list) 74 | 75 | def get_next(self): 76 | if self.enum_data is None: 77 | if self.shape_name: 78 | shape_data = numpy.array([self.height, self.width], dtype='float32').reshape(1, 2) 79 | iou_data = numpy.array([numpy.random.rand()], dtype='float32').reshape(1) 80 | threshold_data = numpy.array([numpy.random.rand()], dtype='float32').reshape(1) 81 | self.enum_data = iter( 82 | [{self.input_name: nhwc_data, self.shape_name: shape_data, self.iou_name: iou_data, self.threshold_name: threshold_data} for nhwc_data in self.nhwc_data_list] 83 | ) 84 | else: 85 | self.enum_data = iter( 86 | [{self.input_name: nhwc_data} for nhwc_data in self.nhwc_data_list] 87 | ) 88 | return next(self.enum_data, None) 89 | 90 | def rewind(self): 91 | self.enum_data = None 92 | -------------------------------------------------------------------------------- /imagenet_val/clock.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/axinc-ai/onnx-quantization/b3f422836eac25c5e046d80e56bcad6a26f870cd/imagenet_val/clock.jpg
-------------------------------------------------------------------------------- /imagenet_val/input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/axinc-ai/onnx-quantization/b3f422836eac25c5e046d80e56bcad6a26f870cd/imagenet_val/input.jpg -------------------------------------------------------------------------------- /models/LICENSE_MOBILENETV2: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /models/LICENSE_YOLOV3.txt: -------------------------------------------------------------------------------- 1 | YOLO LICENSE 2 | Version 2, July 29 2016 3 | 4 | THIS SOFTWARE LICENSE IS PROVIDED "ALL CAPS" SO THAT YOU KNOW IT IS SUPER 5 | SERIOUS AND YOU DON'T MESS AROUND WITH COPYRIGHT LAW BECAUSE YOU WILL GET IN 6 | TROUBLE HERE ARE SOME OTHER BUZZWORDS COMMONLY IN THESE THINGS WARRANTIES 7 | LIABILITY CONTRACT TORT LIABLE CLAIMS RESTRICTION MERCHANTABILITY. 
NOW HERE'S 8 | THE REAL LICENSE: 9 | 10 | 0. Darknet is public domain. 11 | 1. Do whatever you want with it. 12 | 2. Stop emailing me about it! 13 | -------------------------------------------------------------------------------- /models/LICENSE_YOLOX: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021 Megvii, Base Detection 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /models/mobilenetv2_1.0.opt.onnx: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:63b4317933eb388d5b7ba9f910e338c298981516e2ce48bd8b640debfd711f65 3 | size 13967959 4 | -------------------------------------------------------------------------------- /models/mobilenev2_quantized.onnx: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:11adcfec18f5cda5d57c11f3ff592376484a4ca7714c40e03e3b94d19750f813 3 | size 3608592 4 | -------------------------------------------------------------------------------- /models/yolov3-tiny.opt.onnx: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d64a1ef36a2bdb3595b65a66ec55c6a0d6f8c12bb4e90cf5d3f52d61f389775e 3 | size 35509443 4 | -------------------------------------------------------------------------------- /models/yolov3-tiny_quantized_per_tensor.onnx: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:feaff4f914e15b5e5c5831976ab061d4931148e076836845e802986607d7cf6b 3 | size 9072322 4 | -------------------------------------------------------------------------------- /models/yolox_tiny.opt.onnx: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e177616cd978c70ff98bc0fa42723636aee5c604da29c7e113570faf2f492aaf 3 | size 20219792 4 | -------------------------------------------------------------------------------- /models/yolox_tiny_quantized_per_channel.onnx: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:68b8330d05341bca8e948af703ea42e73cc8e3a49d76a14d59b6cc2daccc2aac 3 | size 5347868 4 | -------------------------------------------------------------------------------- /models/yolox_tiny_quantized_per_tensor.onnx: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:99db29e05cb74e7af24e06e5567eb22ddcf7d4e1698cfad66297a9f4ce98dce8 3 | size 5230126 4 | -------------------------------------------------------------------------------- /quantize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import onnxruntime 4 | import time 5 | from onnxruntime.quantization import QuantFormat, QuantType, quantize_static 6 | 7 | import data_reader 8 | 9 | 10 | def benchmark(model_path): 11 | session = onnxruntime.InferenceSession(model_path) 12 | input_name = session.get_inputs()[0].name 13 | 14 | total = 0.0 15 | runs = 10 16 | input_data = np.zeros((1, 3, 416, 416), np.float32) 17 | # Warming up 18 | _ = session.run([], {input_name: input_data}) 19 | for i in range(runs): 20 | start = time.perf_counter() 21 | _ = session.run([], {input_name: input_data}) 22 | end = (time.perf_counter() - start) * 1000 23 | total += end 24 | print(f"{end:.2f}ms") 25 | total /= runs 26 | print(f"Avg: {total:.2f}ms") 27 | 28 | 29 | def get_args(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--input_model", required=True, help="input model") 32 | 
parser.add_argument("--output_model", required=True, help="output model") 33 | parser.add_argument( 34 | "--calibrate_dataset", default="./test_images", help="calibration data set" 35 | ) 36 | parser.add_argument( 37 | "--quant_format", 38 | default=QuantFormat.QDQ, 39 | type=QuantFormat.from_string, 40 | choices=list(QuantFormat), 41 | ) 42 | parser.add_argument("--per_channel", default=False, type=bool) 43 | args = parser.parse_args() 44 | return args 45 | 46 | 47 | def main(): 48 | args = get_args() 49 | input_model_path = args.input_model 50 | output_model_path = args.output_model 51 | calibration_dataset_path = args.calibrate_dataset 52 | dr = data_reader.DataReader( 53 | calibration_dataset_path, input_model_path 54 | ) 55 | 56 | if args.per_channel: 57 | import onnx 58 | model = onnx.load(input_model_path) 59 | print("Per-Channel support with QDQ format requires onnx opset version 13 or above. So automatically convert opset version.") 60 | if model.opset_import[0].version < 13: 61 | op = onnx.OperatorSetIdProto() 62 | op.version = 13 63 | update_model = onnx.helper.make_model(model.graph, opset_imports=[op]) 64 | input_model_path = './models/temp.onnx' 65 | onnx.save(update_model, './models/temp.onnx') 66 | 67 | # Calibrate and quantize model 68 | # Turn off model optimization during quantization 69 | quantize_static( 70 | input_model_path, 71 | output_model_path, 72 | dr, 73 | quant_format=args.quant_format, 74 | per_channel=args.per_channel, 75 | weight_type=QuantType.QInt8, 76 | optimize_model=False, 77 | ) 78 | print("Calibrated and quantized model saved.") 79 | 80 | #print("benchmarking fp32 model...") 81 | #benchmark(input_model_path) 82 | 83 | #print("benchmarking int8 model...") 84 | #benchmark(output_model_path) 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /test/input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/axinc-ai/onnx-quantization/b3f422836eac25c5e046d80e56bcad6a26f870cd/test/input.jpg -------------------------------------------------------------------------------- /test/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/axinc-ai/onnx-quantization/b3f422836eac25c5e046d80e56bcad6a26f870cd/test/util/__init__.py -------------------------------------------------------------------------------- /test/util/classifier_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import cv2 5 | 6 | MAX_CLASS_COUNT = 3 7 | 8 | RECT_WIDTH = 640 9 | RECT_HEIGHT = 20 10 | RECT_MARGIN = 2 11 | 12 | def get_top_scores(classifier, top_k=MAX_CLASS_COUNT): 13 | if hasattr(classifier, 'get_class_count'): 14 | # ailia classifier API 15 | count = classifier.get_class_count() 16 | scores = {} 17 | top_scores = [] 18 | for idx in range(count): 19 | obj = classifier.get_class(idx) 20 | top_scores.append(obj.category) 21 | scores[obj.category] = obj.prob 22 | else: 23 | # ailia predict API 24 | classifier = classifier[0] 25 | top_scores = classifier.argsort()[-1 * top_k:][::-1] 26 | scores = classifier 27 | return top_scores, scores 28 | 29 | 30 | def print_results(classifier, labels, top_k=MAX_CLASS_COUNT): 31 | top_scores, scores = get_top_scores(classifier, top_k) 32 | top_k = min(len(top_scores),top_k) 33 | 34 | 
print('==============================================================') 35 | print(f'class_count={top_k}') 36 | for idx in range(top_k): 37 | print(f'+ idx={idx}') 38 | print(f' category={top_scores[idx]}[' 39 | f'{labels[top_scores[idx]]} ]') 40 | print(f' prob={scores[top_scores[idx]]}') 41 | 42 | 43 | def hsv_to_rgb(h, s, v): 44 | bgr = cv2.cvtColor( 45 | np.array([[[h, s, v]]], dtype=np.uint8), cv2.COLOR_HSV2BGR)[0][0] 46 | return (int(bgr[0]), int(bgr[1]), int(bgr[2]), 255) 47 | 48 | 49 | def plot_results(input_image, classifier, labels, top_k=MAX_CLASS_COUNT, logging=True): 50 | x = RECT_MARGIN 51 | y = RECT_MARGIN 52 | w = RECT_WIDTH 53 | h = RECT_HEIGHT 54 | 55 | top_scores, scores = get_top_scores(classifier, top_k) 56 | top_k = min(len(top_scores),top_k) 57 | 58 | if logging: 59 | print('==============================================================') 60 | print(f'class_count={top_k}') 61 | for idx in range(top_k): 62 | if logging: 63 | print(f'+ idx={idx}') 64 | print(f' category={top_scores[idx]}[' 65 | f'{labels[top_scores[idx]]} ]') 66 | print(f' prob={scores[top_scores[idx]]}') 67 | 68 | text = f'category={top_scores[idx]}[{labels[top_scores[idx]]} ] prob={scores[top_scores[idx]]}' 69 | 70 | color = hsv_to_rgb(256 * top_scores[idx] / (len(labels)+1), 128, 255) 71 | 72 | cv2.rectangle(input_image, (x, y), (x + w, y + h), color, thickness=-1) 73 | text_position = (x+4, y+int(RECT_HEIGHT/2)+4) 74 | 75 | color = (0,0,0) 76 | fontScale = 0.5 77 | 78 | cv2.putText( 79 | input_image, 80 | text, 81 | text_position, 82 | cv2.FONT_HERSHEY_SIMPLEX, 83 | fontScale, 84 | color, 85 | 1 86 | ) 87 | 88 | y=y + h + RECT_MARGIN 89 | 90 | 91 | def write_predictions(file_name, classifier, labels): 92 | top_k = 5 93 | top_scores, scores = get_top_scores(classifier, top_k) 94 | top_k = min(len(top_scores),top_k) 95 | with open(file_name, 'w') as f: 96 | for idx in range(top_k): 97 | f.write('%s %d %f\n' % ( 98 | labels[top_scores[idx]].replace(' ', '_'), 99 | top_scores[idx], 100 | scores[top_scores[idx]] 101 | )) 102 | -------------------------------------------------------------------------------- /test/util/detector_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from logging import getLogger 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | import ctypes 9 | 10 | class DetectorObject(ctypes.Structure): 11 | _fields_ = [ 12 | ("category", ctypes.c_uint), 13 | ("prob", ctypes.c_float), 14 | ("x", ctypes.c_float), 15 | ("y", ctypes.c_float), 16 | ("w", ctypes.c_float), 17 | ("h", ctypes.c_float)] 18 | VERSION = ctypes.c_uint(1) 19 | 20 | logger = getLogger(__name__) 21 | 22 | sys.path.append(os.path.dirname(__file__)) 23 | from image_utils import imread # noqa: E402 24 | 25 | 26 | def preprocessing_img(img): 27 | if len(img.shape) < 3: 28 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGRA) 29 | elif img.shape[2] == 3: 30 | img = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA) 31 | elif img.shape[2] == 1: 32 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGRA) 33 | return img 34 | 35 | 36 | def load_image(image_path): 37 | if os.path.isfile(image_path): 38 | img = imread(image_path, cv2.IMREAD_UNCHANGED) 39 | else: 40 | logger.error(f'{image_path} not found.') 41 | sys.exit() 42 | return preprocessing_img(img) 43 | 44 | 45 | def hsv_to_rgb(h, s, v): 46 | bgr = cv2.cvtColor( 47 | np.array([[[h, s, v]]], dtype=np.uint8), cv2.COLOR_HSV2BGR)[0][0] 48 | return (int(bgr[0]), int(bgr[1]), int(bgr[2]), 255) 49 | 50 | 51 | def letterbox_convert(frame, 
det_shape): 52 | """ 53 | Adjust the size of the frame from the webcam to the ailia input shape. 54 | 55 | Parameters 56 | ---------- 57 | frame: numpy array 58 | det_shape: tuple 59 | ailia model input (height,width) 60 | 61 | Returns 62 | ------- 63 | resized_img: numpy array 64 | Resized `img` as well as adapt the scale 65 | """ 66 | height, width = det_shape[0], det_shape[1] 67 | f_height, f_width = frame.shape[0], frame.shape[1] 68 | scale = np.max((f_height / height, f_width / width)) 69 | 70 | # padding base 71 | img = np.zeros( 72 | (int(round(scale * height)), int(round(scale * width)), 3), 73 | np.uint8 74 | ) 75 | start = (np.array(img.shape) - np.array(frame.shape)) // 2 76 | img[ 77 | start[0]: start[0] + f_height, 78 | start[1]: start[1] + f_width 79 | ] = frame 80 | resized_img = cv2.resize(img, (width, height)) 81 | return resized_img 82 | 83 | 84 | def reverse_letterbox(detections, img, det_shape): 85 | h, w = img.shape[0], img.shape[1] 86 | 87 | pad_x = pad_y = 0 88 | if det_shape != None: 89 | scale = np.max((h / det_shape[0], w / det_shape[1])) 90 | start = (det_shape[0:2] - np.array(img.shape[0:2]) / scale) // 2 91 | pad_x = start[1] * scale 92 | pad_y = start[0] * scale 93 | 94 | new_detections = [] 95 | for detection in detections: 96 | logger.debug(detection) 97 | r = DetectorObject( 98 | category=detection.category, 99 | prob=detection.prob, 100 | x=(detection.x * (w + pad_x * 2) - pad_x) / w, 101 | y=(detection.y * (h + pad_y * 2) - pad_y) / h, 102 | w=(detection.w * (w + pad_x * 2)) / w, 103 | h=(detection.h * (h + pad_y * 2)) / h, 104 | ) 105 | new_detections.append(r) 106 | 107 | return new_detections 108 | 109 | 110 | def plot_results(detector, img, category=None, segm_masks=None, logging=True): 111 | """ 112 | :param detector: ailia.Detector, or list of ailia.DetectorObject 113 | :param img: ndarray data of image 114 | :param category: list of category_name 115 | :param segm_masks: 116 | :param logging: output log flg 117 | :return: 118 | """ 119 | h, w = img.shape[0], img.shape[1] 120 | 121 | count = detector.get_object_count() if hasattr(detector, 'get_object_count') else len(detector) 122 | if logging: 123 | print(f'object_count={count}') 124 | 125 | # prepare color data 126 | colors = [] 127 | for idx in range(count): 128 | obj = detector.get_object(idx) if hasattr(detector, 'get_object') else detector[idx] 129 | 130 | # print result 131 | if logging: 132 | print(f'+ idx={idx}') 133 | print( 134 | f' category={obj.category}[ {category[int(obj.category)]} ]' 135 | if not isinstance(obj.category, str) and category is not None 136 | else f' category=[ {obj.category} ]' 137 | ) 138 | print(f' prob={obj.prob}') 139 | print(f' x={obj.x}') 140 | print(f' y={obj.y}') 141 | print(f' w={obj.w}') 142 | print(f' h={obj.h}') 143 | 144 | if isinstance(obj.category, int) and category is not None: 145 | color = hsv_to_rgb(256 * obj.category / (len(category) + 1), 255, 255) 146 | else: 147 | color = hsv_to_rgb(256 * idx / (len(detector) + 1), 255, 255) 148 | colors.append(color) 149 | 150 | # draw segmentation area 151 | if segm_masks: 152 | for idx in range(count): 153 | mask = np.repeat(np.expand_dims(segm_masks[idx], 2), 3, 2).astype(np.bool) 154 | color = colors[idx][:3] 155 | fill = np.repeat(np.repeat([[color]], img.shape[0], 0), img.shape[1], 1) 156 | img[:, :, :3][mask] = img[:, :, :3][mask] * 0.7 + fill[mask] * 0.3 157 | 158 | # draw bounding box 159 | for idx in range(count): 160 | obj = detector.get_object(idx) if hasattr(detector, 'get_object') else 
detector[idx] 161 | top_left = (int(w * obj.x), int(h * obj.y)) 162 | bottom_right = (int(w * (obj.x + obj.w)), int(h * (obj.y + obj.h))) 163 | 164 | color = colors[idx] 165 | cv2.rectangle(img, top_left, bottom_right, color, 4) 166 | 167 | # draw label 168 | for idx in range(count): 169 | obj = detector.get_object(idx) if hasattr(detector, 'get_object') else detector[idx] 170 | fontScale = w / 2048 171 | 172 | text = category[int(obj.category)] \ 173 | if not isinstance(obj.category, str) and category is not None \ 174 | else obj.category 175 | text = "{} {}".format(text, int(obj.prob * 100) / 100) 176 | textsize = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, fontScale, 1)[0] 177 | tw = textsize[0] 178 | th = textsize[1] 179 | 180 | margin = 3 181 | 182 | x1 = int(w * obj.x) 183 | y1 = int(h * obj.y) 184 | x2 = x1 + tw + margin 185 | y2 = y1 + th + margin 186 | 187 | # check the x,y x2,y2 are inside the image 188 | if x1 < 0: 189 | x1 = 0 190 | elif x2 > w: 191 | x1 = w - (tw + margin) 192 | 193 | if y1 < 0: 194 | y1 = 0 195 | elif y2 > h: 196 | y1 = h - (th + margin) 197 | 198 | # recompute x2, y2 if shift occured 199 | x2 = x1 + tw + margin 200 | y2 = y1 + th + margin 201 | 202 | top_left = (x1, y1) 203 | bottom_right = (x2, y2) 204 | 205 | color = colors[idx] 206 | cv2.rectangle(img, top_left, bottom_right, color, thickness=-1) 207 | 208 | text_color = (255, 255, 255, 255) 209 | cv2.putText( 210 | img, 211 | text, 212 | (top_left[0], top_left[1] + th), 213 | cv2.FONT_HERSHEY_SIMPLEX, 214 | fontScale, 215 | text_color, 216 | 1 217 | ) 218 | return img 219 | 220 | 221 | def write_predictions(file_name, detector, img=None, category=None): 222 | h, w = (img.shape[0], img.shape[1]) if img is not None else (1, 1) 223 | 224 | count = detector.get_object_count() if hasattr(detector, 'get_object_count') else len(detector) 225 | 226 | with open(file_name, 'w') as f: 227 | for idx in range(count): 228 | obj = detector.get_object(idx) if hasattr(detector, 'get_object') else detector[idx] 229 | label = category[obj.category] if category else obj.category 230 | f.write('%s %f %d %d %d %d\n' % ( 231 | label.replace(' ', '_'), 232 | obj.prob, 233 | int(w * obj.x), int(h * obj.y), 234 | int(w * obj.w), int(h * obj.h), 235 | )) 236 | -------------------------------------------------------------------------------- /test/util/image_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from logging import getLogger 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | logger = getLogger(__name__) 9 | 10 | 11 | def imread(filename, flags=cv2.IMREAD_COLOR): 12 | if not os.path.isfile(filename): 13 | logger.error(f"File does not exist: {filename}") 14 | sys.exit() 15 | data = np.fromfile(filename, np.int8) 16 | img = cv2.imdecode(data, flags) 17 | return img 18 | 19 | 20 | def normalize_image(image, normalize_type='255'): 21 | """ 22 | Normalize image 23 | 24 | Parameters 25 | ---------- 26 | image: numpy array 27 | The image you want to normalize 28 | normalize_type: string 29 | Normalize type should be chosen from the type below. 
30 | - '255': simply dividing by 255.0 31 | - '127.5': output range : -1 and 1 32 | - 'ImageNet': normalize by mean and std of ImageNet 33 | - 'None': no normalization 34 | 35 | Returns 36 | ------- 37 | normalized_image: numpy array 38 | """ 39 | if normalize_type == 'None': 40 | return image 41 | elif normalize_type == '255': 42 | return image / 255.0 43 | elif normalize_type == '127.5': 44 | return image / 127.5 - 1.0 45 | elif normalize_type == 'ImageNet': 46 | mean = np.array([0.485, 0.456, 0.406]) 47 | std = np.array([0.229, 0.224, 0.225]) 48 | image = image / 255.0 49 | for i in range(3): 50 | image[:, :, i] = (image[:, :, i] - mean[i]) / std[i] 51 | return image 52 | else: 53 | logger.error(f'Unknown normalize_type is given: {normalize_type}') 54 | sys.exit() 55 | 56 | 57 | def load_image( 58 | image_path, 59 | image_shape, 60 | rgb=True, 61 | normalize_type='255', 62 | gen_input_ailia=False, 63 | ): 64 | """ 65 | Loads the image of the given path, performs the necessary preprocessing, 66 | and returns it. 67 | 68 | Parameters 69 | ---------- 70 | image_path: string 71 | The path of image which you want to load. 72 | image_shape: (int, int) (height, width) 73 | Resizes the loaded image to the size required by the model. 74 | rgb: bool, default=True 75 | Load as rgb image when True, as gray scale image when False. 76 | normalize_type: string 77 | Normalize type should be chosen from the type below. 78 | - '255': output range: 0 and 1 79 | - '127.5': output range : -1 and 1 80 | - 'ImageNet': normalize by mean and std of ImageNet. 81 | - 'None': no normalization 82 | gen_input_ailia: bool, default=False 83 | If True, convert the image to the form corresponding to the ailia. 84 | 85 | Returns 86 | ------- 87 | image: numpy array 88 | """ 89 | # rgb == True --> cv2.IMREAD_COLOR 90 | # rbg == False --> cv2.IMREAD_GRAYSCALE 91 | image = imread(image_path, int(rgb)) 92 | if rgb: 93 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 94 | image = normalize_image(image, normalize_type) 95 | image = cv2.resize(image, (image_shape[1], image_shape[0])) 96 | 97 | if gen_input_ailia: 98 | if rgb: 99 | image = image.transpose((2, 0, 1)) # channel first 100 | image = image[np.newaxis, :, :, :] # (batch_size, channel, h, w) 101 | else: 102 | image = image[np.newaxis, np.newaxis, :, :] 103 | return image 104 | 105 | 106 | def get_image_shape(image_path): 107 | tmp = imread(image_path) 108 | height, width = tmp.shape[0], tmp.shape[1] 109 | return height, width 110 | 111 | 112 | # (ref: https://qiita.com/yasudadesu/items/dd3e74dcc7e8f72bc680) 113 | def draw_texts(img, texts, font_scale=0.7, thickness=2): 114 | h, w, c = img.shape 115 | offset_x = 10 116 | initial_y = 0 117 | dy = int(img.shape[1] / 15) 118 | color = (0, 0, 0) # black 119 | 120 | texts = [texts] if type(texts) == str else texts 121 | 122 | for i, text in enumerate(texts): 123 | offset_y = initial_y + (i+1)*dy 124 | cv2.putText(img, text, (offset_x, offset_y), cv2.FONT_HERSHEY_SIMPLEX, 125 | font_scale, color, thickness, cv2.LINE_AA) 126 | 127 | 128 | def draw_result_on_img(img, texts, w_ratio=0.35, h_ratio=0.2, alpha=0.4): 129 | overlay = img.copy() 130 | pt1 = (0, 0) 131 | pt2 = (int(img.shape[1] * w_ratio), int(img.shape[0] * h_ratio)) 132 | 133 | mat_color = (200, 200, 200) 134 | fill = -1 135 | cv2.rectangle(overlay, pt1, pt2, mat_color, fill) 136 | 137 | mat_img = cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0) 138 | 139 | draw_texts(mat_img, texts) 140 | return mat_img 141 | 
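A short usage sketch for the `load_image` helper above (assuming it is run from `test/util` with the bundled `test/input.jpg`; the 224x224 target shape is only an example):

```python
from image_utils import load_image, get_image_shape

# Original size of the bundled test image.
height, width = get_image_shape("../input.jpg")

# Load as RGB, normalize with ImageNet mean/std, and return an NCHW tensor
# ready to feed to an inference session.
img = load_image(
    "../input.jpg",
    (224, 224),              # (height, width) expected by the model
    rgb=True,
    normalize_type="ImageNet",
    gen_input_ailia=True,
)
print(img.shape)  # (1, 3, 224, 224)
```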
-------------------------------------------------------------------------------- /test/util/log_init.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import datetime 3 | 4 | from logging import getLogger, StreamHandler, FileHandler, Formatter 5 | from logging import INFO, DEBUG 6 | 7 | 8 | # ===== User Configuration ==================================================== 9 | 10 | # Log file name (if disable_file_handler is set to False) 11 | now = datetime.datetime.now() 12 | save_filename = now.strftime('%Y%m%d') + '.log' 13 | 14 | # level: CRITICAL > ERROR > WARNING > INFO > DEBUG 15 | log_level = INFO 16 | 17 | # params 18 | disable_stream_handler = False 19 | disable_file_handler = True # set False if you want to save text log file 20 | display_date = False 21 | 22 | # ============================================================================= 23 | 24 | # default logging format 25 | if display_date: 26 | datefmt = '%Y/%m/%d %H:%M:%S' 27 | default_fmt = Formatter( 28 | '[%(asctime)s.%(msecs)03d] %(levelname)5s ' 29 | '(%(process)d) %(filename)s: %(message)s', 30 | datefmt=datefmt 31 | ) 32 | else: 33 | default_fmt = Formatter( 34 | '%(levelname)5s %(filename)s (%(lineno)d) : %(message)s' 35 | ) 36 | 37 | 38 | logger = getLogger() 39 | 40 | # remove duplicate handlers 41 | if (logger.hasHandlers()): 42 | logger.handlers.clear() 43 | 44 | # the level of logging passed to the Handler 45 | logger.setLevel(log_level) 46 | 47 | # set up stream handler 48 | if not disable_stream_handler: 49 | try: 50 | # Rainbow Logging 51 | from rainbow_logging_handler import RainbowLoggingHandler # noqa: E402 52 | color_msecs = ('green', None, True) 53 | stream_handler = RainbowLoggingHandler( 54 | sys.stdout, color_msecs=color_msecs, datefmt=datefmt 55 | ) 56 | # msecs color 57 | stream_handler._column_color['.'] = color_msecs 58 | stream_handler._column_color['%(asctime)s'] = color_msecs 59 | stream_handler._column_color['%(msecs)03d'] = color_msecs 60 | except Exception: 61 | stream_handler = StreamHandler() 62 | 63 | # the level of output logging 64 | stream_handler.setLevel(DEBUG) 65 | stream_handler.setFormatter(default_fmt) 66 | logger.addHandler(stream_handler) 67 | 68 | if not disable_file_handler: 69 | file_handler = FileHandler(filename=save_filename) 70 | 71 | # the level of output logging 72 | file_handler.setLevel(DEBUG) 73 | file_handler.setFormatter(default_fmt) 74 | logger.addHandler(file_handler) 75 | -------------------------------------------------------------------------------- /test/util/math_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | 5 | 6 | def softmax(x, axis=None): 7 | max = np.max(x, axis=axis, keepdims=True) 8 | e_x = np.exp(x - max) 9 | sum = np.sum(e_x, axis=axis, keepdims=True) 10 | f_x = e_x / sum 11 | return f_x 12 | 13 | 14 | def sigmoid(x): 15 | with warnings.catch_warnings(): 16 | warnings.simplefilter('ignore') 17 | 18 | return 1.0 / (1.0 + np.exp(-x)) 19 | -------------------------------------------------------------------------------- /test/util/microphone_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import threading 4 | import multiprocessing as mp 5 | from logging import getLogger 6 | 7 | import numpy as np 8 | 9 | logger = getLogger(__name__) 10 | 11 | M1_SAMPLE_RATE = 48000 12 | 13 | 14 | def capture_microphone(que, ready, pause, fin, 
                       sample_rate, sc=False, speaker=False):
15 |     if sc:
16 |         import soundcard as sc
17 |     else:
18 |         import pyaudio
19 |     import librosa
20 | 
21 |     # On M1 Macs, pyaudio can only capture audio at 48000 Hz
22 |     SAMPLE_RATE = M1_SAMPLE_RATE if sc is False else sample_rate
23 |     THRES_SPEECH_POW = 0.001
24 |     THRES_SILENCE_POW = 0.0001
25 |     INTERVAL = SAMPLE_RATE * 3
26 |     INTERVAL_MIN = SAMPLE_RATE * 1.5
27 |     BUFFER_MAX = SAMPLE_RATE * 10
28 | 
29 |     def send(audio, n):
30 |         if INTERVAL_MIN < n:
31 |             if sc is False and SAMPLE_RATE != sample_rate:
32 |                 audio = librosa.resample(audio, orig_sr=SAMPLE_RATE, target_sr=sample_rate)
33 |             que.put_nowait(audio[:n])
34 | 
35 |     def read(src):
36 |         ready.set()
37 | 
38 |         v = np.ones(100) / 100
39 |         buf = np.array([], dtype=np.float32)
40 |         while not fin.is_set():
41 |             if pause.is_set():
42 |                 buf = buf[:0]
43 |                 time.sleep(0.1)
44 |                 continue
45 | 
46 |             if sc:
47 |                 audio = src.record(INTERVAL)
48 |             else:
49 |                 audio = np.frombuffer(src.read(INTERVAL, exception_on_overflow=False), dtype=np.int16) / 32768.0
50 | 
51 |             audio = audio.reshape(-1)
52 |             square = audio ** 2
53 |             if np.max(square) >= THRES_SPEECH_POW:
54 |                 sys.stdout.write(".")
55 |                 sys.stdout.flush()
56 | 
57 |             # Smoothing
58 |             conv = np.convolve(square, v, 'valid')
59 |             conv = np.pad(conv, (0, len(v) - 1), mode='edge')
60 |             # Check 0.5 s windows in 0.1 s steps
61 |             s = SAMPLE_RATE // 10
62 |             x = [(min(i + s * 5, INTERVAL), np.any(conv[i:i + s * 5] >= THRES_SILENCE_POW))
63 |                  for i in range(0, INTERVAL - 5 * s + 1, s)]
64 | 
65 |             # Speech section
66 |             speech = [a[0] for a in x if a[1]]
67 |             if speech:
68 |                 if len(buf) == 0 and 1 < len(speech):
69 |                     i = max(speech[0] - 3 * s, 0)
70 |                     audio = audio[i:]
71 |                 i = speech[-1]
72 |                 audio = audio[:i]
73 |             else:
74 |                 i = 0
75 | 
76 |             buf = np.concatenate([buf, audio])
77 |             if i < INTERVAL:
78 |                 send(buf, len(buf))
79 |                 buf = buf[:0]
80 |             elif BUFFER_MAX < len(buf):
81 |                 i = np.argmin(buf[::-1])
82 |                 i = len(buf) - i
83 |                 if 0 < i:
84 |                     send(buf, i)
85 |                     buf = buf[i:]
86 |                 else:
87 |                     send(buf, len(buf))
88 |                     buf = buf[:0]
89 |             elif 0 < len(buf):
90 |                 send(buf, len(buf))
91 |                 buf = buf[:0]
92 | 
93 |     try:
94 |         # start recording
95 |         if sc:
96 |             mic_id = str(sc.default_speaker().name) if speaker else str(sc.default_microphone().name)
97 |             with sc.get_microphone(id=mic_id, include_loopback=speaker).recorder(
98 |                     samplerate=SAMPLE_RATE, channels=1) as mic:
99 |                 read(mic)
100 |         else:
101 |             p = pyaudio.PyAudio()
102 |             stream = p.open(
103 |                 format=pyaudio.paInt16,
104 |                 channels=1,
105 |                 rate=SAMPLE_RATE,
106 |                 input=True,
107 |                 frames_per_buffer=1024,
108 |             )
109 |             stream.start_stream()
110 |             read(stream)
111 | 
112 |             stream.stop_stream()
113 |             stream.close()
114 |             p.terminate()
115 |             pass
116 |     except KeyboardInterrupt:
117 |         pass
118 |     except Exception as e:
119 |         logger.exception(e)
120 | 
121 | 
122 | def start_microphone_input(sample_rate, sc=False, speaker=False, thread=False, queue_size=2):
123 |     que = mp.Queue(maxsize=queue_size)
124 |     ready = mp.Event()
125 |     pause = mp.Event()
126 |     fin = mp.Event()
127 | 
128 |     if thread:
129 |         p = threading.Thread(
130 |             target=capture_microphone,
131 |             args=(que, ready, pause, fin, sample_rate, sc, speaker),
132 |             daemon=True)
133 |     else:
134 |         p = mp.Process(
135 |             target=capture_microphone,
136 |             args=(que, ready, pause, fin, sample_rate, sc, speaker),
137 |             daemon=True)
138 |     p.start()
139 | 
140 |     # Wait for the capture thread / process to start
141 |     while p.is_alive():
142 |         if ready.is_set():
143 |             break
144 | 
145 |     if not p.is_alive():
146 |         raise Exception('Failed to start microphone capture.')
147 | 
148 |     
params = dict( 149 | p=p, 150 | que=que, 151 | pause=pause, 152 | fin=fin, 153 | ) 154 | 155 | return params 156 | -------------------------------------------------------------------------------- /test/util/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib.request 3 | import ssl 4 | 5 | # logger 6 | from logging import getLogger 7 | logger = getLogger(__name__) 8 | 9 | 10 | def progress_print(block_count, block_size, total_size): 11 | """ 12 | Callback function to display the progress 13 | (ref: https://qiita.com/jesus_isao/items/ffa63778e7d3952537db) 14 | 15 | Parameters 16 | ---------- 17 | block_count: 18 | block_size: 19 | total_size: 20 | """ 21 | percentage = 100.0 * block_count * block_size / total_size 22 | if percentage > 100: 23 | # Bigger than 100 does not look good, so... 24 | percentage = 100 25 | max_bar = 50 26 | bar_num = int(percentage / (100 / max_bar)) 27 | progress_element = '=' * bar_num 28 | if bar_num != max_bar: 29 | progress_element += '>' 30 | bar_fill = ' ' # fill the blanks 31 | bar = progress_element.ljust(max_bar, bar_fill) 32 | total_size_kb = total_size / 1024 33 | print(f'[{bar} {percentage:.2f}% ( {total_size_kb:.0f}KB )]', end='\r') 34 | 35 | def urlretrieve(remote_path,weight_path,progress_print): 36 | try: 37 | #raise ssl.SSLError # test 38 | urllib.request.urlretrieve( 39 | remote_path, 40 | weight_path, 41 | progress_print, 42 | ) 43 | except ssl.SSLError as e: 44 | logger.info(f'SSLError detected, so try to download without ssl') 45 | remote_path = remote_path.replace("https","http") 46 | urllib.request.urlretrieve( 47 | remote_path, 48 | weight_path, 49 | progress_print, 50 | ) 51 | 52 | def check_and_download_models(weight_path, model_path, remote_path): 53 | """ 54 | Check if the onnx file and prototxt file exists, 55 | and if necessary, download the files to the given path. 56 | 57 | Parameters 58 | ---------- 59 | weight_path: string 60 | The path of onnx file. 61 | model_path: string 62 | The path of prototxt file for ailia. 63 | remote_path: string 64 | The url where the onnx file and prototxt file are saved. 65 | ex. "https://storage.googleapis.com/ailia-models/mobilenetv2/" 66 | """ 67 | 68 | if not os.path.exists(weight_path): 69 | logger.info(f'Downloading onnx file... (save path: {weight_path})') 70 | urlretrieve( 71 | remote_path + os.path.basename(weight_path), 72 | weight_path, 73 | progress_print, 74 | ) 75 | logger.info('\n') 76 | if model_path!=None and not os.path.exists(model_path): 77 | logger.info(f'Downloading prototxt file... 
(save path: {model_path})') 78 | urlretrieve( 79 | remote_path + os.path.basename(model_path), 80 | model_path, 81 | progress_print, 82 | ) 83 | logger.info('\n') 84 | logger.info('ONNX file and Prototxt file are prepared!') 85 | -------------------------------------------------------------------------------- /test/util/nms_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def bb_intersection_over_union(boxA, boxB): 5 | # determine the (x, y)-coordinates of the intersection rectangle 6 | xA = max(boxA[0], boxB[0]) 7 | yA = max(boxA[1], boxB[1]) 8 | xB = min(boxA[2], boxB[2]) 9 | yB = min(boxA[3], boxB[3]) 10 | # compute the area of intersection rectangle 11 | interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) 12 | # compute the area of both the prediction and ground-truth 13 | # rectangles 14 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) 15 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) 16 | # compute the intersection over union by taking the intersection 17 | # area and dividing it by the sum of prediction + ground-truth 18 | # areas - the interesection area 19 | iou = interArea / float(boxAArea + boxBArea - interArea) 20 | # return the intersection over union value 21 | return iou 22 | 23 | 24 | def nms_between_categories(detections, w, h, categories=None, iou_threshold=0.25): 25 | # Normally darknet use per class nms 26 | # But some cases need between class nms 27 | # https://github.com/opencv/opencv/issues/17111 28 | 29 | # remove overwrapped detection 30 | det = [] 31 | keep = [] 32 | for idx in range(len(detections)): 33 | obj = detections[idx] 34 | is_keep = True 35 | for idx2 in range(len(det)): 36 | if not keep[idx2]: 37 | continue 38 | box_a = [w * det[idx2].x, h * det[idx2].y, w * (det[idx2].x + det[idx2].w), h * (det[idx2].y + det[idx2].h)] 39 | box_b = [w * obj.x, h * obj.y, w * (obj.x + obj.w), h * (obj.y + obj.h)] 40 | iou = bb_intersection_over_union(box_a, box_b) 41 | if iou >= iou_threshold and ( 42 | categories == None or ((det[idx2].category in categories) and (obj.category in categories))): 43 | if det[idx2].prob <= obj.prob: 44 | keep[idx2] = False 45 | else: 46 | is_keep = False 47 | det.append(obj) 48 | keep.append(is_keep) 49 | 50 | det = [] 51 | for idx in range(len(detections)): 52 | if keep[idx]: 53 | det.append(detections[idx]) 54 | 55 | return det 56 | 57 | 58 | def nms_boxes(boxes, scores, iou_thres): 59 | # Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU). 
60 | 61 | keep = [] 62 | for i, box_a in enumerate(boxes): 63 | is_keep = True 64 | for j in range(i): 65 | if not keep[j]: 66 | continue 67 | box_b = boxes[j] 68 | iou = bb_intersection_over_union(box_a, box_b) 69 | if iou >= iou_thres: 70 | if scores[i] > scores[j]: 71 | keep[j] = False 72 | else: 73 | is_keep = False 74 | break 75 | 76 | keep.append(is_keep) 77 | 78 | return np.array(keep).nonzero()[0] 79 | 80 | 81 | def batched_nms(boxes, scores, labels, iou_thres): 82 | a = [] 83 | for i in np.unique(labels): 84 | idx = (labels == i) 85 | idx = np.nonzero(idx)[0] 86 | i = nms_boxes(boxes[idx], scores[idx], iou_thres) 87 | idx = idx[i] 88 | a.append(idx) 89 | 90 | keep = np.concatenate(a) 91 | scores = scores[keep] 92 | idxs = np.argsort(-scores) 93 | keep = keep[idxs] 94 | 95 | return keep 96 | 97 | 98 | def packed_nms(boxes, scores, iou_thres): 99 | packed_idx = [] 100 | remained = np.argsort(-scores) 101 | while 0 < len(remained): 102 | idx = remained 103 | i = idx[0] 104 | candidates = [i] 105 | remained = [] 106 | for j in idx[1:]: 107 | similarity = bb_intersection_over_union(boxes[i], boxes[j]) 108 | if similarity > iou_thres: 109 | candidates.append(j) 110 | else: 111 | remained.append(j) 112 | 113 | packed_idx.append(candidates) 114 | 115 | return packed_idx 116 | -------------------------------------------------------------------------------- /test/util/params.py: -------------------------------------------------------------------------------- 1 | # ============================================================================= 2 | # Available input data modalities 3 | MODALITIES = ['image', 'video', 'audio'] 4 | 5 | # recognized extension (for glob.glob) 6 | EXTENSIONS = { 7 | 'image': ['*.png', '*.jpg', '*.[jJ][pP][eE][gG]', '*.bmp'], 8 | 'video': ['*.mp4'], 9 | 'audio': ['*.mp3', '*.wav'], 10 | } 11 | 12 | # ============================================================================= 13 | -------------------------------------------------------------------------------- /test/util/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import glob 5 | from logging import DEBUG 6 | 7 | from params import MODALITIES, EXTENSIONS 8 | import log_init 9 | 10 | # FIXME: Next two lines should be better to call from the main script 11 | # once we prepared one. For now, we do the initialization of logger here. 
12 | logger = log_init.logger 13 | logger.info('Start!') 14 | 15 | # TODO: better to use them (first, fix above) 16 | # from logging import getLogger 17 | # logger = getLogger(__name__) 18 | 19 | 20 | # TODO: yaml config file and yaml loader 21 | 22 | try: 23 | import ailia 24 | AILIA_EXIST = True 25 | except ImportError: 26 | logger.warning('ailia package cannot be found under `sys.path`') 27 | logger.warning('default env_id is set to 0, you can change the id by ' 28 | '[--env_id N]') 29 | AILIA_EXIST = False 30 | 31 | 32 | def check_file_existance(filename): 33 | if os.path.isfile(filename): 34 | return True 35 | else: 36 | logger.error(f'{filename} not found') 37 | sys.exit() 38 | 39 | 40 | def get_base_parser( 41 | description, default_input, default_save, input_ftype='image', 42 | ): 43 | """ 44 | Get ailia default argument parser 45 | 46 | Parameters 47 | ---------- 48 | description : str 49 | default_input : str 50 | default input data (image / video) path 51 | default_save : str 52 | default save path 53 | input_ftype : str 54 | 55 | Returns 56 | ------- 57 | out : ArgumentParser() 58 | 59 | """ 60 | parser = argparse.ArgumentParser( 61 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 62 | description=description, 63 | conflict_handler='resolve', # allow to overwrite default argument 64 | ) 65 | parser.add_argument( 66 | '-i', '--input', metavar='IMAGE/VIDEO', default=default_input, 67 | help=('The default (model-dependent) input data (image / video) path. ' 68 | 'If a directory name is specified, the model will be run for ' 69 | 'the files inside. File type is specified by --ftype argument') 70 | ) 71 | parser.add_argument( 72 | '-v', '--video', metavar='VIDEO', default=None, 73 | help=('You can convert the input video by entering style image.' 74 | 'If the int variable is given, ' 75 | 'corresponding webcam input will be used.') 76 | ) 77 | parser.add_argument( 78 | '-s', '--savepath', metavar='SAVE_PATH', default=default_save, 79 | help='Save path for the output (image / video / text).' 80 | ) 81 | parser.add_argument( 82 | '-b', '--benchmark', action='store_true', 83 | help=('Running the inference on the same input 5 times to measure ' 84 | 'execution performance. (Cannot be used in video mode)') 85 | ) 86 | parser.add_argument( 87 | '-e', '--env_id', type=int, 88 | default=ailia.get_gpu_environment_id() if AILIA_EXIST else 0, 89 | help=('A specific environment id can be specified. 
By default, ' 90 | 'the return value of ailia.get_gpu_environment_id will be used') 91 | ) 92 | parser.add_argument( 93 | '--env_list', action='store_true', 94 | help='display environment list' 95 | ) 96 | parser.add_argument( 97 | '--ftype', metavar='FILE_TYPE', default=input_ftype, 98 | choices=MODALITIES, 99 | help='file type list: ' + ' | '.join(MODALITIES) 100 | ) 101 | parser.add_argument( 102 | '--debug', action='store_true', 103 | help='set default logger level to DEBUG (enable to show DEBUG logs)' 104 | ) 105 | parser.add_argument( 106 | '--profile', action='store_true', 107 | help='set profile mode (enable to show PROFILE logs)' 108 | ) 109 | parser.add_argument( 110 | '-bc', '--benchmark_count', metavar='BENCHMARK_COUNT', 111 | default=5, type=int, 112 | help='set iteration count of benchmark' 113 | ) 114 | return parser 115 | 116 | 117 | def update_parser(parser, check_input_type=True, large_model=False): 118 | """Default check or update configurations should be placed here 119 | 120 | Parameters 121 | ---------- 122 | parser : ArgumentParser() 123 | 124 | Returns 125 | ------- 126 | args : ArgumentParser() 127 | (parse_args() will be done here) 128 | """ 129 | args = parser.parse_args() 130 | 131 | # ------------------------------------------------------------------------- 132 | # 0. logger level update 133 | if args.debug: 134 | logger.setLevel(DEBUG) 135 | 136 | # ------------------------------------------------------------------------- 137 | # 1. check env_id count 138 | if AILIA_EXIST: 139 | count = ailia.get_environment_count() 140 | if count <= args.env_id: 141 | logger.error(f'specified env_id: {args.env_id} cannot found. ') 142 | logger.info('env_id updated to 0') 143 | args.env_id = 0 144 | 145 | if large_model: 146 | if args.env_id == ailia.get_gpu_environment_id() and ailia.get_environment(args.env_id).props == "LOWPOWER": 147 | args.env_id = 0 # cpu 148 | logger.warning('This model requires fuge gpu memory so fallback to cpu mode') 149 | 150 | if args.env_list: 151 | for idx in range(count) : 152 | env = ailia.get_environment(idx) 153 | logger.info(" env[" + str(idx) + "]=" + str(env)) 154 | 155 | if args.env_id == ailia.ENVIRONMENT_AUTO: 156 | args.env_id = ailia.get_gpu_environment_id() 157 | if args.env_id == ailia.ENVIRONMENT_AUTO: 158 | logger.info('env_id updated to 0') 159 | args.env_id = 0 160 | else: 161 | logger.info('env_id updated to ' + str(args.env_id) + '(from get_gpu_environment_id())') 162 | 163 | logger.info(f'env_id: {args.env_id}') 164 | 165 | env = ailia.get_environment(args.env_id) 166 | logger.info(f'{env.name}') 167 | 168 | # ------------------------------------------------------------------------- 169 | # 2. update input 170 | if args.video is not None: 171 | args.ftype = 'video' 172 | args.input = None # force video mode 173 | 174 | if args.input is None: 175 | # TODO: args.video, args.input is vague... 176 | # input is None --> video mode maybe? 177 | pass 178 | elif isinstance(args.input, list): 179 | # LIST --> nothing will be changed here. 180 | pass 181 | elif os.path.isdir(args.input): 182 | # Directory Path --> generate list of inputs 183 | files_grapped = [] 184 | in_dir = args.input 185 | for extension in EXTENSIONS[args.ftype]: 186 | files_grapped.extend(glob.glob(os.path.join(in_dir, extension))) 187 | logger.info(f'{len(files_grapped)} {args.ftype} files found!') 188 | 189 | args.input = sorted(files_grapped) 190 | 191 | # create save directory 192 | if args.savepath is None: 193 | pass 194 | else: 195 | if '.' 
in args.savepath: 196 | logger.warning('Please specify save directory as --savepath ' 197 | 'if you specified a direcotry for --input') 198 | logger.info(f'[{in_dir}_results] directory will be created') 199 | if in_dir[-1] == '/': 200 | in_dir = in_dir[:-1] 201 | args.savepath = in_dir + '_results' 202 | os.makedirs(args.savepath, exist_ok=True) 203 | logger.info(f'output saving directory: {args.savepath}') 204 | 205 | elif os.path.isfile(args.input): 206 | args.input = [args.input] 207 | else: 208 | if check_input_type: 209 | logger.error('specified input is not file path nor directory path') 210 | sys.exit(0) 211 | 212 | # ------------------------------------------------------------------------- 213 | return args 214 | 215 | 216 | def get_savepath(arg_path, src_path, prefix='', post_fix='_res', ext=None): 217 | """Get savepath 218 | NOTE: we may have better option... 219 | TODO: args.save_dir & args.save_path ? 220 | 221 | Parameters 222 | ---------- 223 | arg_path : str 224 | argument parser's savepath 225 | src_path : str 226 | the path of source path 227 | prefix : str, default is '' 228 | postfix : str, default is '_res' 229 | ext : str, default is None 230 | if you need to specify the extension, use this argument 231 | the argument has to start with '.' like '.png' or '.jpg' 232 | 233 | Returns 234 | ------- 235 | new_path : str 236 | """ 237 | 238 | if '.' in arg_path: 239 | # 1. args.savepath is actually the image path 240 | arg_base, arg_ext = os.path.splitext(arg_path) 241 | new_ext = arg_ext if ext is None else ext 242 | new_path = arg_base + new_ext 243 | else: 244 | # 2. args.savepath is save directory path 245 | src_base, src_ext = os.path.splitext(os.path.basename(src_path)) 246 | new_ext = src_ext if ext is None else ext 247 | new_path = os.path.join( 248 | arg_path, prefix + src_base + post_fix + new_ext 249 | ) 250 | return new_path 251 | -------------------------------------------------------------------------------- /test/util/webcamera_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | import cv2 6 | 7 | from utils import check_file_existance 8 | from image_utils import normalize_image 9 | 10 | from logging import getLogger 11 | logger = getLogger(__name__) 12 | 13 | 14 | def calc_adjust_fsize(f_height, f_width, height, width): 15 | # calculate the image size of the output('img') of adjust_frame_size 16 | # This function is supposed to be used to declare 'cv2.writer' 17 | scale = np.max((f_height / height, f_width / width)) 18 | return int(scale * height), int(scale * width) 19 | 20 | 21 | def adjust_frame_size(frame, height, width): 22 | """ 23 | Adjust the size of the frame from the webcam to the ailia input shape. 24 | 25 | Parameters 26 | ---------- 27 | frame: numpy array 28 | height: int 29 | ailia model input height 30 | width: int 31 | ailia model input width 32 | 33 | Returns 34 | ------- 35 | img: numpy array 36 | Image with the propotions of height and width 37 | adjusted by padding for ailia model input. 
38 | resized_img: numpy array 39 | Resized `img` as well as adapt the scale 40 | """ 41 | f_height, f_width = frame.shape[0], frame.shape[1] 42 | scale = np.max((f_height / height, f_width / width)) 43 | 44 | # padding base 45 | img = np.zeros( 46 | (int(round(scale * height)), int(round(scale * width)), 3), 47 | np.uint8 48 | ) 49 | start = (np.array(img.shape) - np.array(frame.shape)) // 2 50 | img[ 51 | start[0]: start[0] + f_height, 52 | start[1]: start[1] + f_width 53 | ] = frame 54 | resized_img = cv2.resize(img, (width, height)) 55 | return img, resized_img 56 | 57 | 58 | def cut_max_square(frame: np.array) -> np.array: 59 | """ 60 | Cut out a maximum square area from the center of given frame (np.array). 61 | Parameters 62 | ---------- 63 | frame: numpy array 64 | 65 | Returns 66 | ------- 67 | frame_square: numpy array 68 | Maximum square area of the frame at its center 69 | """ 70 | frame_height, frame_width, _ = frame.shape 71 | frame_size_min = min(frame_width, frame_height) 72 | if frame_width >= frame_height: 73 | x, y = frame_width // 2 - frame_height // 2, 0 74 | else: 75 | x, y = 0, frame_height // 2 - frame_width // 2 76 | 77 | frame_square = frame[y: (y + frame_size_min), x: (x + frame_size_min)] 78 | return frame_square 79 | 80 | 81 | def preprocess_frame( 82 | frame, input_height, input_width, data_rgb=True, normalize_type='255' 83 | ): 84 | """ 85 | Pre-process the frames taken from the webcam to input to ailia. 86 | 87 | Parameters 88 | ---------- 89 | frame: numpy array 90 | input_height: int 91 | ailia model input height 92 | input_width: int 93 | ailia model input width 94 | data_rgb: bool (default: True) 95 | Convert as rgb image when True, as gray scale image when False. 96 | Only `data` will be influenced by this configuration. 97 | normalize_type: string (default: 255) 98 | Normalize type should be chosen from the type below. 99 | - '255': simply dividing by 255.0 100 | - '127.5': output range : -1 and 1 101 | - 'ImageNet': normalize by mean and std of ImageNet 102 | - 'None': no normalization 103 | 104 | Returns 105 | ------- 106 | img: numpy array 107 | Image with the propotions of height and width 108 | adjusted by padding for ailia model input. 
109 | data: numpy array 110 | Input data for ailia 111 | """ 112 | img, resized_img = adjust_frame_size(frame, input_height, input_width) 113 | 114 | if data_rgb: 115 | resized_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB) 116 | 117 | data = normalize_image(resized_img, normalize_type) 118 | 119 | if data_rgb: 120 | data = np.rollaxis(data, 2, 0) 121 | data = np.expand_dims(data, axis=0).astype(np.float32) 122 | else: 123 | data = cv2.cvtColor(data.astype(np.float32), cv2.COLOR_BGR2GRAY) 124 | data = data[np.newaxis, np.newaxis, :, :] 125 | return img, data 126 | 127 | 128 | def get_writer(savepath, height, width, fps=20, rgb=True): 129 | """get cv2.VideoWriter 130 | 131 | Parameters 132 | ---------- 133 | save_path : str 134 | height : int 135 | width : int 136 | fps : int 137 | rgb : bool, default is True 138 | 139 | Returns 140 | ------- 141 | writer : cv2.VideoWriter() 142 | """ 143 | if os.path.isdir(savepath): 144 | savepath = savepath + "/out.mp4" 145 | 146 | writer = cv2.VideoWriter( 147 | savepath, 148 | # cv2.VideoWriter_fourcc(*'MJPG'), # avi mode 149 | cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), # mp4 mode 150 | fps, 151 | (width, height), 152 | isColor=rgb 153 | ) 154 | return writer 155 | 156 | 157 | def get_capture(video): 158 | """ 159 | Get cv2.VideoCapture 160 | 161 | * TODO: maybe get capture & writer at the same time? 162 | * then, you can use capture frame size directory 163 | 164 | Parameters 165 | ---------- 166 | video : str 167 | webcamera-id or video path 168 | 169 | Returns 170 | ------- 171 | capture : cv2.VideoCapture 172 | """ 173 | try: 174 | video_id = int(video) 175 | 176 | # webcamera-mode 177 | capture = cv2.VideoCapture(video_id) 178 | if not capture.isOpened(): 179 | logger.error(f"webcamera (ID - {video_id}) not found") 180 | sys.exit(0) 181 | 182 | except ValueError: 183 | # if file path is given, open video file 184 | if check_file_existance(video): 185 | capture = cv2.VideoCapture(video) 186 | 187 | return capture 188 | -------------------------------------------------------------------------------- /test/yolox.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | import time 5 | 6 | import cv2 7 | import numpy as np 8 | import onnxruntime 9 | 10 | from yolox_utils import multiclass_nms, postprocess, predictions_to_object 11 | from yolox_utils import preproc as preprocess 12 | 13 | # import original modules 14 | sys.path.append('./util') 15 | # logger 16 | from logging import getLogger 17 | 18 | import webcamera_utils 19 | from detector_utils import (load_image, plot_results, reverse_letterbox, 20 | write_predictions) 21 | from image_utils import imread # noqa: E402 22 | from model_utils import check_and_download_models 23 | from utils import get_base_parser, get_savepath, update_parser 24 | 25 | logger = getLogger(__name__) 26 | 27 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 28 | 29 | # ====================== 30 | # Parameters 31 | # ====================== 32 | MODEL_PARAMS = {'yolox_nano': {'input_shape': [416, 416]}, 33 | 'yolox_tiny': {'input_shape': [416, 416]}, 34 | 'yolox_s': {'input_shape': [640, 640]}, 35 | 'yolox_m': {'input_shape': [640, 640]}, 36 | 'yolox_l': {'input_shape': [640, 640]}, 37 | 'yolox_darknet': {'input_shape': [640, 640]}, 38 | 'yolox_x': {'input_shape': [640, 640]}} 39 | 40 | REMOTE_PATH = 'https://storage.googleapis.com/ailia-models/yolox/' 41 | 42 | IMAGE_PATH = 'input.jpg' 43 | SAVE_IMAGE_PATH = 'output.jpg' 44 | 45 | COCO_CATEGORY = 
[ 46 | "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", 47 | "truck", "boat", "traffic light", "fire hydrant", "stop sign", 48 | "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", 49 | "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", 50 | "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", 51 | "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", 52 | "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", 53 | "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", 54 | "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", 55 | "couch", "potted plant", "bed", "dining table", "toilet", "tv", 56 | "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", 57 | "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", 58 | "scissors", "teddy bear", "hair drier", "toothbrush" 59 | ] 60 | 61 | SCORE_THR = 0.4 62 | NMS_THR = 0.45 63 | 64 | # ====================== 65 | # Arguemnt Parser Config 66 | # ====================== 67 | parser = get_base_parser('yolox model', IMAGE_PATH, SAVE_IMAGE_PATH) 68 | parser.add_argument( 69 | '-m', '--model_name', 70 | default='yolox_tiny', 71 | help='[yolox_nano, yolox_tiny, yolox_s, yolox_m, yolox_l,' 72 | 'yolox_darknet, yolox_x]' 73 | ) 74 | parser.add_argument( 75 | '-w', '--write_prediction', 76 | action='store_true', 77 | help='Flag to output the prediction file.' 78 | ) 79 | parser.add_argument( 80 | '-th', '--threshold', 81 | default=SCORE_THR, type=float, 82 | help='The detection threshold for yolo. (default: '+str(SCORE_THR)+')' 83 | ) 84 | parser.add_argument( 85 | '-iou', '--iou', 86 | default=NMS_THR, type=float, 87 | help='The detection iou for yolo. (default: '+str(NMS_THR)+')' 88 | ) 89 | args = update_parser(parser) 90 | 91 | MODEL_NAME = args.model_name 92 | MODEL_PATH = "../models/yolox_tiny_quantized_per_tensor.onnx" 93 | 94 | HEIGHT = MODEL_PARAMS[MODEL_NAME]['input_shape'][0] 95 | WIDTH = MODEL_PARAMS[MODEL_NAME]['input_shape'][1] 96 | 97 | # ====================== 98 | # Main functions 99 | # ====================== 100 | def recognize_from_image(detector): 101 | # input image loop 102 | for image_path in args.input: 103 | # prepare input data 104 | logger.debug(f'input image: {image_path}') 105 | raw_img = imread(image_path, cv2.IMREAD_COLOR) 106 | img, ratio = preprocess(raw_img, (HEIGHT, WIDTH)) 107 | logger.debug(f'input image shape: {raw_img.shape}') 108 | 109 | def compute(): 110 | input_name = detector.get_inputs()[0].name 111 | print(input_name) 112 | return detector.run([], {input_name:img[None, :, :, :]}) 113 | 114 | # inference 115 | logger.info('Start inference...') 116 | if args.benchmark: 117 | logger.info('BENCHMARK mode') 118 | total_time = 0 119 | for i in range(args.benchmark_count): 120 | start = int(round(time.time() * 1000)) 121 | output = compute() 122 | end = int(round(time.time() * 1000)) 123 | if i != 0: 124 | total_time = total_time + (end - start) 125 | logger.info(f'\tailia processing time {end - start} ms') 126 | logger.info(f'\taverage time {total_time / (args.benchmark_count-1)} ms') 127 | else: 128 | output = compute() 129 | 130 | predictions = postprocess(output[0], (HEIGHT, WIDTH))[0] 131 | detect_object = predictions_to_object(predictions, raw_img, ratio, args.iou, args.threshold) 132 | detect_object = reverse_letterbox(detect_object, raw_img, (raw_img.shape[0], raw_img.shape[1])) 133 | res_img = plot_results(detect_object, raw_img, COCO_CATEGORY) 134 | 135 | # 
plot result
136 |         savepath = get_savepath(args.savepath, image_path)
137 |         logger.info(f'saved at : {savepath}')
138 |         cv2.imwrite(savepath, res_img)
139 | 
140 |         # write prediction
141 |         if args.write_prediction:
142 |             pred_file = '%s.txt' % savepath.rsplit('.', 1)[0]
143 |             write_predictions(pred_file, detect_object, raw_img, COCO_CATEGORY)
144 | 
145 |     logger.info('Script finished successfully.')
146 | 
147 | def recognize_from_video(detector):
148 |     capture = webcamera_utils.get_capture(args.video)
149 | 
150 |     # create video writer if savepath is specified as video format
151 |     if args.savepath != SAVE_IMAGE_PATH:
152 |         f_h = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
153 |         f_w = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
154 |         save_h, save_w = f_h, f_w
155 |         writer = webcamera_utils.get_writer(args.savepath, save_h, save_w)
156 |     else:
157 |         writer = None
158 | 
159 |     frame_shown = False
160 |     while True:
161 |         ret, frame = capture.read()
162 |         if (cv2.waitKey(1) & 0xFF == ord('q')) or not ret:
163 |             break
164 |         if frame_shown and cv2.getWindowProperty('frame', cv2.WND_PROP_VISIBLE) == 0:
165 |             break
166 | 
167 |         raw_img = frame
168 |         img, ratio = preprocess(raw_img, (HEIGHT, WIDTH))
169 |         output = detector.run([], {detector.get_inputs()[0].name: img[None, :, :, :]})
170 |         predictions = postprocess(output[0], (HEIGHT, WIDTH))[0]
171 |         detect_object = predictions_to_object(predictions, raw_img, ratio, args.iou, args.threshold)
172 |         detect_object = reverse_letterbox(detect_object, raw_img, (raw_img.shape[0], raw_img.shape[1]))
173 |         res_img = plot_results(detect_object, raw_img, COCO_CATEGORY)
174 |         cv2.imshow('frame', res_img)
175 |         frame_shown = True
176 | 
177 |         # save results
178 |         if writer is not None:
179 |             writer.write(res_img)
180 | 
181 |     capture.release()
182 |     cv2.destroyAllWindows()
183 |     if writer is not None:
184 |         writer.release()
185 |     logger.info('Script finished successfully.')
186 | 
187 | def main():
188 |     detector = onnxruntime.InferenceSession(MODEL_PATH)
189 | 
190 |     if args.video is not None:
191 |         # video mode
192 |         recognize_from_video(detector)
193 |     else:
194 |         # image mode
195 |         recognize_from_image(detector)
196 | 
197 | 
198 | if __name__ == '__main__':
199 |     main()
200 | 
--------------------------------------------------------------------------------
/test/yolox_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import numpy as np
4 | import math
5 | import sys
6 | sys.path.append('./util')
7 | from detector_utils import DetectorObject
8 | 
9 | import cv2
10 | 
11 | def preproc(img, input_size, swap=(2, 0, 1)):
12 |     if len(img.shape) == 3:
13 |         padded_img = np.ones((input_size[0], input_size[1], img.shape[2]), dtype=np.uint8) * 114
14 |     else:
15 |         padded_img = np.ones(input_size, dtype=np.uint8) * 114
16 | 
17 |     r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
18 |     resized_img = cv2.resize(
19 |         img,
20 |         (int(img.shape[1] * r), int(img.shape[0] * r)),
21 |         interpolation=cv2.INTER_LINEAR,
22 |     ).astype(np.uint8)
23 |     padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
24 | 
25 |     padded_img = padded_img.transpose(swap)
26 |     padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
27 |     return padded_img, r
28 | 
29 | 
30 | def nms(boxes, scores, nms_thr):
31 |     """Single class NMS implemented in Numpy."""
32 |     x1 = boxes[:, 0]
33 |     y1 = boxes[:, 1]
34 |     x2 = boxes[:, 2]
35 |     y2 = boxes[:, 3]
36 | 
37 |     areas = (x2 - x1 + 1) * (y2 - y1 + 1)
38 |     order = scores.argsort()[::-1]
39 | 
40 |     keep = []
41 |     
while order.size > 0: 42 | i = order[0] 43 | keep.append(i) 44 | xx1 = np.maximum(x1[i], x1[order[1:]]) 45 | yy1 = np.maximum(y1[i], y1[order[1:]]) 46 | xx2 = np.minimum(x2[i], x2[order[1:]]) 47 | yy2 = np.minimum(y2[i], y2[order[1:]]) 48 | 49 | w = np.maximum(0.0, xx2 - xx1 + 1) 50 | h = np.maximum(0.0, yy2 - yy1 + 1) 51 | inter = w * h 52 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 53 | 54 | inds = np.where(ovr <= nms_thr)[0] 55 | order = order[inds + 1] 56 | 57 | return keep 58 | 59 | 60 | def multiclass_nms(boxes, scores, nms_thr, score_thr, class_agnostic=True): 61 | """Multiclass NMS implemented in Numpy""" 62 | if class_agnostic: 63 | nms_method = multiclass_nms_class_agnostic 64 | else: 65 | nms_method = multiclass_nms_class_aware 66 | return nms_method(boxes, scores, nms_thr, score_thr) 67 | 68 | 69 | def multiclass_nms_class_aware(boxes, scores, nms_thr, score_thr): 70 | """Multiclass NMS implemented in Numpy. Class-aware version.""" 71 | final_dets = [] 72 | num_classes = scores.shape[1] 73 | for cls_ind in range(num_classes): 74 | cls_scores = scores[:, cls_ind] 75 | valid_score_mask = cls_scores > score_thr 76 | if valid_score_mask.sum() == 0: 77 | continue 78 | else: 79 | valid_scores = cls_scores[valid_score_mask] 80 | valid_boxes = boxes[valid_score_mask] 81 | keep = nms(valid_boxes, valid_scores, nms_thr) 82 | if len(keep) > 0: 83 | cls_inds = np.ones((len(keep), 1)) * cls_ind 84 | dets = np.concatenate( 85 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 86 | ) 87 | final_dets.append(dets) 88 | if len(final_dets) == 0: 89 | return None 90 | return np.concatenate(final_dets, 0) 91 | 92 | 93 | def multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr): 94 | """Multiclass NMS implemented in Numpy. Class-agnostic version.""" 95 | cls_inds = scores.argmax(1) 96 | cls_scores = scores[np.arange(len(cls_inds)), cls_inds] 97 | 98 | valid_score_mask = cls_scores > score_thr 99 | if valid_score_mask.sum() == 0: 100 | return None 101 | valid_scores = cls_scores[valid_score_mask] 102 | valid_boxes = boxes[valid_score_mask] 103 | valid_cls_inds = cls_inds[valid_score_mask] 104 | keep = nms(valid_boxes, valid_scores, nms_thr) 105 | if keep: 106 | dets = np.concatenate( 107 | [valid_boxes[keep], valid_scores[keep, None], valid_cls_inds[keep, None]], 1 108 | ) 109 | return dets 110 | 111 | 112 | def postprocess(outputs, img_size, p6=False): 113 | 114 | grids = [] 115 | expanded_strides = [] 116 | 117 | if not p6: 118 | strides = [8, 16, 32] 119 | else: 120 | strides = [8, 16, 32, 64] 121 | 122 | hsizes = [img_size[0] // stride for stride in strides] 123 | wsizes = [img_size[1] // stride for stride in strides] 124 | 125 | for hsize, wsize, stride in zip(hsizes, wsizes, strides): 126 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) 127 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2) 128 | grids.append(grid) 129 | shape = grid.shape[:2] 130 | expanded_strides.append(np.full((*shape, 1), stride)) 131 | 132 | grids = np.concatenate(grids, 1) 133 | expanded_strides = np.concatenate(expanded_strides, 1) 134 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides 135 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides 136 | 137 | return outputs 138 | 139 | 140 | def predictions_to_object(predictions,raw_img,ratio,nms_thr,score_thr): 141 | boxes = predictions[:, :4] 142 | scores = predictions[:, 4:5] * predictions[:, 5:] 143 | 144 | boxes_xyxy = np.ones_like(boxes) 145 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2. 
146 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2. 147 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2. 148 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2. 149 | boxes_xyxy /= ratio 150 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr, score_thr) 151 | 152 | detect_object = [] 153 | if dets is not None: 154 | img_size_h, img_size_w = raw_img.shape[:2] 155 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] 156 | for i, box in enumerate(final_boxes): 157 | x1, y1, x2, y2 = box 158 | c = int(final_cls_inds[i]) 159 | r = DetectorObject( 160 | category=c, 161 | prob=final_scores[i], 162 | x=x1 / img_size_w, 163 | y=y1 / img_size_h, 164 | w=(x2 - x1) / img_size_w, 165 | h=(y2 - y1) / img_size_h, 166 | ) 167 | detect_object.append(r) 168 | 169 | return detect_object 170 | --------------------------------------------------------------------------------
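
For reference, a minimal sketch (not part of the repository) of how these helpers tie together with onnxruntime to run one of the quantized YOLOX models. The model path, thresholds, and the 416x416 input size match `test/yolox.py`; the image path is a placeholder, and the script is assumed to be run from the `test` directory so that `yolox_utils` and `util` are importable.

```python
# Illustrative sketch: quantized YOLOX inference with the helpers above.
import cv2
import onnxruntime

from yolox_utils import preproc, postprocess, predictions_to_object

MODEL_PATH = '../models/yolox_tiny_quantized_per_tensor.onnx'
IMAGE_PATH = 'input.jpg'           # any test image (placeholder)
HEIGHT, WIDTH = 416, 416           # yolox_tiny input shape

session = onnxruntime.InferenceSession(MODEL_PATH)
raw_img = cv2.imread(IMAGE_PATH, cv2.IMREAD_COLOR)

# letterbox resize; YOLOX expects raw 0-255 pixel values, no normalization
img, ratio = preproc(raw_img, (HEIGHT, WIDTH))

input_name = session.get_inputs()[0].name
output = session.run(None, {input_name: img[None, :, :, :]})

# decode grid offsets / strides, then NMS into DetectorObject instances
predictions = postprocess(output[0], (HEIGHT, WIDTH))[0]
objects = predictions_to_object(predictions, raw_img, ratio, 0.45, 0.4)
for obj in objects:
    print(obj.category, round(float(obj.prob), 3), obj.x, obj.y, obj.w, obj.h)
```
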