├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── hand.png ├── keypoint_classification.ipynb ├── model ├── __init__.py ├── hand_landmark │ ├── hand_landmark.py │ └── hand_landmark_sparse_Nx3x224x224.onnx ├── keypoint_classifier │ ├── keypoint.csv │ ├── keypoint_classifier.hdf5 │ ├── keypoint_classifier.onnx │ ├── keypoint_classifier.py │ ├── keypoint_classifier.tflite │ ├── keypoint_classifier_label.csv │ ├── make_argmax.py │ └── tflite_to_onnx.sh ├── palm_detection │ ├── palm_detection.py │ └── palm_detection_full_inf_post_192x192.onnx └── point_history_classifier │ ├── argmax.onnx │ ├── make_argmax.py │ ├── point_history.csv │ ├── point_history_classifier.hdf5 │ ├── point_history_classifier.onnx │ ├── point_history_classifier.py │ ├── point_history_classifier.tflite │ ├── point_history_classifier_label.csv │ ├── point_history_classifier_lstm.onnx │ └── tflite_to_onnx.sh ├── point_history_classification.ipynb ├── requirements.txt └── utils ├── __init__.py ├── cvfpscalc.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | __pycache__/ 3 | *.engine 4 | *.profile 5 | *.mp4 6 | test_distance_calc.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Katsuya Hyodo 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
[Japanese / [WIP] English]

---
# hand-gesture-recognition-using-onnx
This repository is based on [Kazuhito00/hand-gesture-recognition-using-mediapipe](https://github.com/Kazuhito00/hand-gesture-recognition-using-mediapipe) by [@Kazuhito00](https://github.com/Kazuhito00), with every MediaPipe component replaced by ONNX. It is a sample program that recognizes hand signs and finger gestures with a simple MLP. So far it has only been confirmed to run.

https://user-images.githubusercontent.com/33194443/189632510-1823cd54-bc36-4889-ac14-adc16deba9b8.mp4

This repository contains the following:
- Sample program
- Hand detection model (modified ONNX)
- Palm landmark detection model (modified ONNX)
- Hand sign recognition model (modified ONNX)
- Finger gesture recognition model (modified ONNX)
- Training data and training notebook for hand sign recognition
- Training data and training notebook for finger gesture recognition

# Requirements
- onnxruntime 1.12.0 or onnxruntime-gpu 1.12.0
- opencv-contrib-python 4.6.0.66 or later
- TensorFlow 2.10.0 (only when re-creating the ONNX files after training)
- PyTorch 1.12.0 (only when re-creating the ONNX files after training)
- tf2onnx 1.12.0 or later (only when re-creating the ONNX files after training)
- simple-onnx-processing-tools 1.0.54 or later (only when re-creating the ONNX files after training)
- scikit-learn 0.23.2 or later (only when displaying the confusion matrix during training)
- matplotlib 3.3.2 or later (only when displaying the confusion matrix during training)

# Demo
Run the demo with a webcam as follows.
```bash
python app.py
```

The following options can be specified when running the demo.
```
--device
    Camera device number (default: 0)

--width
    Capture width (default: 640)

--height
    Capture height (default: 480)

--min_detection_confidence
    Detection confidence threshold (default: 0.6)

--disable_image_flip
    Disable horizontal flipping of the input image
```

# Directory
```
.
│  app.py
│  keypoint_classification.ipynb
│  point_history_classification.ipynb
│  requirements.txt
│  README.md
│
├─model
│  ├─keypoint_classifier
│  │  │  tflite_to_onnx.sh
│  │  │  make_argmax.py
│  │  │  keypoint.csv
│  │  │  keypoint_classifier.hdf5
│  │  │  keypoint_classifier.py
│  │  │  keypoint_classifier.tflite
│  │  │  keypoint_classifier.onnx
│  │  └─ keypoint_classifier_label.csv
│  │
│  └─point_history_classifier
│      │  tflite_to_onnx.sh
│      │  make_argmax.py
│      │  point_history.csv
│      │  point_history_classifier.hdf5
│      │  point_history_classifier.py
│      │  point_history_classifier.tflite
│      │  point_history_classifier.onnx
│      └─ point_history_classifier_label.csv
│
└─utils
    │  cvfpscalc.py
    └─ utils.py
```

### app.py
This is the sample program for inference.
It can also collect training data (keypoints) for hand sign recognition
and training data (index finger coordinate history) for finger gesture recognition.

### keypoint_classification.ipynb
Training script for the hand sign recognition model.

### point_history_classification.ipynb
Training script for the finger gesture recognition model.

### model/keypoint_classifier
This directory holds the files related to hand sign recognition.
The following files are stored here.
- tflite-to-ONNX conversion script (tflite_to_onnx.sh)
- ONNX component generation script (make_argmax.py)
- Training data (keypoint.csv)
- Trained model (keypoint_classifier.tflite)
- Trained model (keypoint_classifier.onnx)
- Label data (keypoint_classifier_label.csv)
- Inference class (keypoint_classifier.py)

### model/point_history_classifier
This directory holds the files related to finger gesture recognition.
The following files are stored here.
- tflite-to-ONNX conversion script (tflite_to_onnx.sh)
- ONNX component generation script (make_argmax.py)
- Training data (point_history.csv)
- Trained model (point_history_classifier.tflite)
- Trained model (point_history_classifier.onnx)
- Label data (point_history_classifier_label.csv)
- Inference class (point_history_classifier.py)

### utils/cvfpscalc.py
Module for measuring FPS.

### utils/utils.py
Utility functions for image processing.

# Training
For both hand sign recognition and finger gesture recognition, you can add or modify the training data and retrain the models.

### Hand sign recognition training
#### 1. Collecting training data
Press "k" to enter the keypoint logging mode ("MODE:Logging Key Point" is displayed).
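For reference, these are the keys handled by `select_mode()` and the main loop in app.py:
```
ESC : quit
n   : normal mode (no logging)
k   : keypoint logging mode ("MODE:Logging Key Point")
h   : point-history logging mode ("MODE:Logging Point History")
0-9 : class ID to record while in a logging mode
a   : toggle auto mode (keeps logging the last selected class ID)
```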
130 |

Press "0" to "9" to append keypoints to "model/keypoint_classifier/keypoint.csv" as follows.
Column 1: pressed number (used as the class ID), column 2: trackid, columns 3 and onward: keypoint coordinates.
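Concretely, each appended row has the following layout; the CSV has no header row and the column names here are only descriptive:
```
class_id, track_id, x0, y0, x1, y1, ..., x20, y20   (21 keypoints -> 42 preprocessed values)
```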
133 |

The stored keypoint coordinates have already gone through the preprocessing steps up to ④ (see the sketch below).
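The figure that illustrated these preprocessing steps is not reproduced in this document. As a reference, the following is a condensed version of `pre_process_landmark()` from app.py, which produces the values written to the csv:
```python
import copy
import itertools

def pre_process_landmark(landmark_list):
    """Convert 21 (x, y) keypoints into 42 wrist-relative, normalized values."""
    if len(landmark_list) == 0:
        return []
    temp = copy.deepcopy(landmark_list)
    # Relative coordinates with the wrist (landmark 0) as the origin
    base_x, base_y = temp[0][0], temp[0][1]
    temp = [[p[0] - base_x, p[1] - base_y] for p in temp]
    # Flatten the [x, y] pairs into a single 1-D list
    temp = list(itertools.chain.from_iterable(temp))
    # Normalize by the largest absolute component so values fall in [-1, 1]
    max_value = max(map(abs, temp))
    return [v / max_value for v in temp]
```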
135 | 136 |

By default, the file contains training data for three classes: open hand (class ID: 0), fist (class ID: 1), and pointing (class ID: 2).
Add class IDs 3 and above, or delete the existing rows of the csv, as needed to prepare your own training data.
#### 2. Training the model
Open "[keypoint_classification.ipynb](keypoint_classification.ipynb)" in Jupyter Notebook and run the cells from top to bottom.
To change the number of classes in the training data, change the value of "NUM_CLASSES = 3"
and edit the labels in "model/keypoint_classifier/keypoint_classifier_label.csv" accordingly.
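For example, with the three default classes the label file contains one label per line, in class-ID order (the label strings below are illustrative and may differ from the ones shipped in the repository):
```
Open
Close
Pointer
```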

#### X. Model structure
The model prepared in "[keypoint_classification.ipynb](keypoint_classification.ipynb)" is structured as shown below.
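The structure diagram is not reproduced in this document; for reference, this is the model as defined in [keypoint_classification.ipynb](keypoint_classification.ipynb) (the input is 21 keypoints × 2 coordinates = 42 values):
```python
import tensorflow as tf

NUM_CLASSES = 3

model = tf.keras.models.Sequential([
    tf.keras.layers.Input((21 * 2, )),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(20, activation='relu'),
    tf.keras.layers.Dropout(0.4),
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])
```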

### Finger gesture recognition training
#### 1. Collecting training data
Press "h" to enter the fingertip point history logging mode ("MODE:Logging Point History" is displayed).
152 |

Press "0" to "9" to append the fingertip coordinate history to "model/point_history_classifier/point_history.csv" as follows.
Column 1: pressed number (used as the class ID), column 2: trackid, columns 3 and onward: coordinate history.
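Each appended row follows the same pattern as the keypoint csv, with 16 time steps of relative fingertip coordinates (history_length = 16 in app.py); the names below are only descriptive:
```
class_id, track_id, rx0, ry0, rx1, ry1, ..., rx15, ry15
```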
155 |

The stored coordinates have already gone through the preprocessing steps up to ④ (see the sketch below).
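As with the keypoints, the illustrating figure is not reproduced here. The values written to the csv are produced by `pre_process_point_history()` in app.py, condensed below: each trackid's fingertip history is made relative to its oldest point, normalized by the image size, and flattened:
```python
import copy
import itertools

def pre_process_point_history(image_width, image_height, point_history):
    """Convert {trackid: [[x, y], ...]} into flat lists of relative, normalized coordinates."""
    if len(point_history) == 0:
        return []
    relative_histories = []
    for trackid, points in copy.deepcopy(point_history).items():
        base_x, base_y = points[0][0], points[0][1]
        relative = [
            [(p[0] - base_x) / image_width, (p[1] - base_y) / image_height]
            for p in points
        ]
        relative_histories.append(list(itertools.chain.from_iterable(relative)))
    return relative_histories
```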
157 |

By default, the file contains training data for four classes:
stationary (class ID: 0), clockwise (class ID: 1), counterclockwise (class ID: 2), and moving (class ID: 4).
Add class IDs 5 and above, or delete the existing rows of the csv, as needed to prepare your own training data.
#### 2. Training the model
Open "[point_history_classification.ipynb](point_history_classification.ipynb)" in Jupyter Notebook and run the cells from top to bottom.
To change the number of classes in the training data, change the value of "NUM_CLASSES = 4"
and edit the labels in "model/point_history_classifier/point_history_classifier_label.csv" accordingly.

#### X. Model structure
The model prepared in "[point_history_classification.ipynb](point_history_classification.ipynb)" is structured as shown below.
A variant that uses "LSTM" is also available.
To use it, change "use_lstm = False" to "True" in the notebook (see the sketch below).
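The point_history_classification.ipynb notebook is not included in this document, so the following is only an illustrative sketch of such an LSTM variant, not the notebook's exact definition. The input size of 16 time steps × 2 relative coordinates is an assumption based on `history_length = 16` in app.py, and the layer sizes are placeholders:
```python
import tensorflow as tf

NUM_CLASSES = 4
TIME_STEPS = 16   # assumption: history_length = 16 in app.py
DIMENSION = 2     # relative (x, y) per time step

# Illustrative sketch only; layer sizes are assumptions, not the notebook's code.
model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(TIME_STEPS * DIMENSION, )),
    tf.keras.layers.Reshape((TIME_STEPS, DIMENSION)),
    tf.keras.layers.LSTM(16),
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'),
])
```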
170 | 171 | 172 | # Application example 173 | - None 174 | 175 | # Reference 176 | - [Kazuhito00/hand-gesture-recognition-using-mediapipe](https://github.com/Kazuhito00/hand-gesture-recognition-using-mediapipe) 177 | - [geaxgx/depthai_hand_tracker](https://github.com/geaxgx/depthai_hand_tracker) 178 | - [PINTO0309/simple-onnx-processing-tools](https://github.com/PINTO0309/simple-onnx-processing-tools) 179 | - https://github.com/PINTO0309/PINTO_model_zoo/tree/main/033_Hand_Detection_and_Tracking 180 | 181 | # Acknowledgments 182 | - hand-gesture-recognition-using-mediapipe: 高橋かずひと (https://twitter.com/KzhtTkhs) 183 | - [Apache-2.0 License](https://github.com/Kazuhito00/hand-gesture-recognition-using-mediapipe/blob/main/LICENSE) 184 | - depthai_hand_tracker: geaxgx (https://github.com/geaxgx/depthai_hand_tracker) 185 | - [MIT License](https://github.com/geaxgx/depthai_hand_tracker/blob/main/LICENSE.txt) 186 | 187 | # Author 188 | - hand-gesture-recognition-using-onnx: Katsuya Hyodo 189 | - [Apache-2.0 License](LICENSE) 190 | 191 | # ToDo 192 | - [x] LSTM ONNX 193 | - [ ] BPF (body_pre_focusing) の実装 194 | - [ ] 変数とループの整理 195 | - [ ] CSVロギングの見直し 196 | - [ ] DeepSort 197 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import csv 4 | import copy 5 | import argparse 6 | import itertools 7 | from math import degrees 8 | from collections import Counter 9 | from collections import deque 10 | 11 | import cv2 as cv 12 | import numpy as np 13 | 14 | from utils import CvFpsCalc 15 | from utils.utils import rotate_and_crop_rectangle 16 | from model import PalmDetection 17 | from model import HandLandmark 18 | from model import KeyPointClassifier 19 | from model import PointHistoryClassifier 20 | 21 | 22 | def get_args(): 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument( 26 | '-d', 27 | '--device', 28 | type=int, 29 | default=0, 30 | ) 31 | parser.add_argument( 32 | '-im', 33 | '--image', 34 | type=str, 35 | default='', 36 | ) 37 | parser.add_argument( 38 | '-wi', 39 | '--width', 40 | help='cap width', 41 | type=int, 42 | default=640, 43 | ) 44 | parser.add_argument( 45 | '-he', 46 | '--height', 47 | help='cap height', 48 | type=int, 49 | default=480, 50 | ) 51 | parser.add_argument( 52 | '-mdc', 53 | '--min_detection_confidence', 54 | help='min_detection_confidence', 55 | type=float, 56 | default=0.6, 57 | ) 58 | parser.add_argument( 59 | '-dif', 60 | '--disable_image_flip', 61 | help='disable image flip', 62 | action='store_true', 63 | ) 64 | 65 | 66 | args = parser.parse_args() 67 | 68 | return args 69 | 70 | 71 | def main(): 72 | # 引数解析 ################################################################# 73 | args = get_args() 74 | 75 | if not args.image: 76 | cap_device = args.device 77 | else: 78 | cap_device = args.image 79 | cap_width = args.width 80 | cap_height = args.height 81 | min_detection_confidence = args.min_detection_confidence 82 | 83 | lines_hand = [ 84 | [0,1],[1,2],[2,3],[3,4], 85 | [0,5],[5,6],[6,7],[7,8], 86 | [5,9],[9,10],[10,11],[11,12], 87 | [9,13],[13,14],[14,15],[15,16], 88 | [13,17],[17,18],[18,19],[19,20],[0,17], 89 | ] 90 | 91 | # カメラ準備 ############################################################### 92 | cap = cv.VideoCapture(cap_device) 93 | cap.set(cv.CAP_PROP_FRAME_WIDTH, cap_width) 94 | cap.set(cv.CAP_PROP_FRAME_HEIGHT, cap_height) 95 | cap_fps = cap.get(cv.CAP_PROP_FPS) 96 | fourcc = 
cv.VideoWriter_fourcc('m', 'p', '4', 'v') 97 | video_writer = cv.VideoWriter( 98 | filename='output.mp4', 99 | fourcc=fourcc, 100 | fps=cap_fps, 101 | frameSize=(cap_width, cap_height), 102 | ) 103 | 104 | # モデルロード ############################################################# 105 | palm_detection = PalmDetection(score_threshold=min_detection_confidence) 106 | hand_landmark = HandLandmark() 107 | 108 | keypoint_classifier = KeyPointClassifier() 109 | point_history_classifier = PointHistoryClassifier() 110 | 111 | # ラベル読み込み ########################################################### 112 | with open( 113 | 'model/keypoint_classifier/keypoint_classifier_label.csv', 114 | encoding='utf-8-sig', 115 | ) as f: 116 | keypoint_classifier_labels = csv.reader(f) 117 | keypoint_classifier_labels = [ 118 | row[0] for row in keypoint_classifier_labels 119 | ] 120 | with open( 121 | 'model/point_history_classifier/point_history_classifier_label.csv', 122 | encoding='utf-8-sig', 123 | ) as f: 124 | point_history_classifier_labels = csv.reader(f) 125 | point_history_classifier_labels = [ 126 | row[0] for row in point_history_classifier_labels 127 | ] 128 | 129 | # FPS計測モジュール ######################################################## 130 | cvFpsCalc = CvFpsCalc(buffer_len=10) 131 | 132 | # 座標履歴 ################################################################ 133 | history_length = 16 134 | # point_history = deque(maxlen=history_length) 135 | point_history = {} 136 | pre_point_history = {} 137 | 138 | # フィンガージェスチャー履歴 ################################################# 139 | gesture_history_length = 10 140 | finger_gesture_history = {} 141 | 142 | # 手のひらトラッキング用手のひら中心座標最新履歴 ################################# 143 | # { 144 | # int(trackid1): [cx, cy], 145 | # int(trackid2): [cx, cy], 146 | # int(trackid3): [cx, cy], 147 | # : 148 | # } 149 | palm_trackid_cxcy = {} 150 | 151 | # ####################################################################### 152 | mode = 0 153 | wh_ratio = cap_width / cap_height 154 | 155 | auto = False 156 | prev_number = -1 157 | image = None 158 | 159 | while True: 160 | fps = cvFpsCalc.get() 161 | 162 | # キー処理(ESC:終了) ################################################# 163 | key = cv.waitKey(1) if not args.image else cv.waitKey(0) if image is not None and args.image else cv.waitKey(1) 164 | if key == 27: # ESC 165 | break 166 | number, mode, auto, prev_number = select_mode(key, mode, auto, prev_number) 167 | 168 | # カメラキャプチャ ##################################################### 169 | ret, image = cap.read() 170 | if not ret: 171 | break 172 | image = image if args.disable_image_flip else cv.flip(image, 1) # ミラー表示 173 | debug_image = copy.deepcopy(image) 174 | 175 | # 検出実施 ############################################################# 176 | 177 | # ============================================================= PalmDetection 178 | # ハンドディテクション - シングルバッチ処理 179 | hands = palm_detection(image) 180 | # hand: sqn_rr_size, rotation, sqn_rr_center_x, sqn_rr_center_y 181 | 182 | rects = [] 183 | not_rotate_rects = [] 184 | rects_tuple = None 185 | cropted_rotated_hands_images = [] 186 | 187 | # 手の検出件数がゼロになったらトラッキング用手のひら中心座標最新履歴を初期化 188 | if len(hands) == 0: 189 | palm_trackid_cxcy = {} 190 | # トラッキング用手のひら中心座標最新履歴とバウンディングボックスの検出順序紐づけリスト 191 | palm_trackid_box_x1y1s = {} 192 | 193 | if len(hands) > 0: 194 | for hand in hands: 195 | # hand: sqn_rr_size, rotation, sqn_rr_center_x, sqn_rr_center_y 196 | sqn_rr_size = hand[0] 197 | rotation = hand[1] 198 | sqn_rr_center_x = hand[2] 199 | sqn_rr_center_y = 
hand[3] 200 | 201 | cx = int(sqn_rr_center_x * cap_width) 202 | cy = int(sqn_rr_center_y * cap_height) 203 | xmin = int((sqn_rr_center_x - (sqn_rr_size / 2)) * cap_width) 204 | xmax = int((sqn_rr_center_x + (sqn_rr_size / 2)) * cap_width) 205 | ymin = int((sqn_rr_center_y - (sqn_rr_size * wh_ratio / 2)) * cap_height) 206 | ymax = int((sqn_rr_center_y + (sqn_rr_size * wh_ratio / 2)) * cap_height) 207 | xmin = max(0, xmin) 208 | xmax = min(cap_width, xmax) 209 | ymin = max(0, ymin) 210 | ymax = min(cap_height, ymax) 211 | degree = degrees(rotation) 212 | # [boxcount, cx, cy, width, height, degree] 213 | rects.append([cx, cy, (xmax-xmin), (ymax-ymin), degree]) 214 | 215 | rects = np.asarray(rects, dtype=np.float32) 216 | 217 | # 回転角度をゼロ度に補正した手のひら画像の取得 218 | cropted_rotated_hands_images = rotate_and_crop_rectangle( 219 | image=image, 220 | rects_tmp=rects, 221 | operation_when_cropping_out_of_range='padding', 222 | ) 223 | 224 | # Debug =============================================================== 225 | for rect in rects: 226 | # 回転考慮の領域の描画, 赤色の枠 227 | rects_tuple = ((rect[0], rect[1]), (rect[2], rect[3]), rect[4]) 228 | box = cv.boxPoints(rects_tuple).astype(np.int0) 229 | cv.drawContours(debug_image, [box], 0,(0,0,255), 2, cv.LINE_AA) 230 | 231 | # 回転非考慮の領域の描画, オレンジ色の枠 232 | rcx = int(rect[0]) 233 | rcy = int(rect[1]) 234 | half_w = int(rect[2] // 2) 235 | half_h = int(rect[3] // 2) 236 | x1 = rcx - half_w 237 | y1 = rcy - half_h 238 | x2 = rcx + half_w 239 | y2 = rcy + half_h 240 | text_x = max(x1, 10) 241 | text_x = min(text_x, cap_width-120) 242 | text_y = max(y1-15, 45) 243 | text_y = min(text_y, cap_height-20) 244 | # [boxcount, rcx, rcy, x1, y1, x2, y2, height, degree] 245 | not_rotate_rects.append([rcx, rcy, x1, y1, x2, y2, 0]) 246 | # 検出枠のサイズ WxH 247 | cv.putText( 248 | debug_image, 249 | f'{y2-y1}x{x2-x1}', 250 | (text_x, text_y), 251 | cv.FONT_HERSHEY_SIMPLEX, 252 | 0.8, 253 | (0,0,0), 254 | 2, 255 | cv.LINE_AA, 256 | ) 257 | cv.putText( 258 | debug_image, 259 | f'{y2-y1}x{x2-x1}', 260 | (text_x, text_y), 261 | cv.FONT_HERSHEY_SIMPLEX, 262 | 0.8, 263 | (59,255,255), 264 | 1, 265 | cv.LINE_AA, 266 | ) 267 | # 検出枠の描画 268 | cv.rectangle( 269 | debug_image, 270 | (x1,y1), 271 | (x2,y2), 272 | (0,128,255), 273 | 2, 274 | cv.LINE_AA, 275 | ) 276 | # 検出領域の中心座標描画 277 | cv.circle( 278 | debug_image, 279 | (rcx, rcy), 280 | 3, 281 | (0, 255, 255), 282 | -1, 283 | ) 284 | """ 285 | 手のひらトラッキング用手のひら中心座標最新履歴の保存 286 | 1. 過去履歴の中から基準点との距離が一番近い中心座標を抽出 287 | 2. 距離が100pxより離れている場合は新たな手のひらと認識させる 288 | 3. 距離が100px以下の場合は該当のtrackidを割り当てて過去履歴の中心座標を上書きする 289 | """ 290 | # 1. 過去履歴の中から基準点との距離が一番近い中心座標を抽出 291 | base_point = np.asarray( 292 | [rcx, rcy], 293 | dtype=np.float32, 294 | ) 295 | points = np.asarray( 296 | list(palm_trackid_cxcy.values()), 297 | dtype=np.float32, 298 | ) 299 | if len(points) > 0: 300 | # 最近傍点探索 301 | diff_val = points - base_point 302 | all_points_distance = np.linalg.norm(diff_val, axis=1) 303 | nearest_trackid = np.argmin(all_points_distance) 304 | nearest_distance = all_points_distance[nearest_trackid] 305 | new_trackid = int(nearest_trackid) + 1 306 | # 2. 距離が100pxより離れている場合は新たな手のひらと認識させる 307 | # 3. 
距離が100px以下の場合は該当のtrackidを割り当てて過去履歴の中心座標を上書きする 308 | if nearest_distance > 100: 309 | # 現状のtrackid最大値+1を新規trackidとして生成 310 | new_trackid = next(iter(reversed(palm_trackid_cxcy))) + 1 311 | else: 312 | # trackid初期値 313 | new_trackid = 1 314 | 315 | # 手のひらトラッキング用手のひら中心座標最新履歴の最新座標を更新 または 新規追加 316 | palm_trackid_cxcy[new_trackid] = [rcx, rcy] 317 | # バウンディングボックスの検出順序とtrackidの順序を整合 318 | # box_x1y1x2y2_palm_trackids.append([x1, y1, x2, y2, new_trackid]) 319 | palm_trackid_box_x1y1s[new_trackid] = [x1, y1] 320 | # Debug =============================================================== 321 | 322 | # ============================================================= HandLandmark 323 | if len(cropted_rotated_hands_images) > 0: 324 | 325 | # Inference HandLandmark - バッチ処理 326 | hand_landmarks, rotated_image_size_leftrights = hand_landmark( 327 | images=cropted_rotated_hands_images, 328 | rects=rects, 329 | ) 330 | 331 | if len(hand_landmarks) > 0: 332 | # Draw 333 | pre_processed_landmarks = [] 334 | pre_processed_point_histories = [] 335 | for (trackid, x1y1), landmark, rotated_image_size_leftright, not_rotate_rect in \ 336 | zip(palm_trackid_box_x1y1s.items(), hand_landmarks, rotated_image_size_leftrights, not_rotate_rects): 337 | 338 | x1, y1 = x1y1 339 | rotated_image_width, _, left_hand_0_or_right_hand_1 = rotated_image_size_leftright 340 | thick_coef = rotated_image_width / 400 341 | lines = np.asarray( 342 | [ 343 | np.array([landmark[point] for point in line]).astype(np.int32) for line in lines_hand 344 | ] 345 | ) 346 | radius = int(1+thick_coef*5) 347 | cv.polylines( 348 | debug_image, 349 | lines, 350 | False, 351 | (255, 0, 0), 352 | int(radius), 353 | cv.LINE_AA, 354 | ) 355 | _ = [cv.circle(debug_image, (int(x), int(y)), radius, (0,128,255), -1) for x,y in landmark[:,:2]] 356 | left_hand_0_or_right_hand_1 = left_hand_0_or_right_hand_1 if args.disable_image_flip else (1 - left_hand_0_or_right_hand_1) 357 | handedness = 'Left ' if left_hand_0_or_right_hand_1 == 0 else 'Right' 358 | _, _, x1, y1, _, _, _ = not_rotate_rect 359 | text_x = max(x1, 10) 360 | text_x = min(text_x, cap_width-120) 361 | text_y = max(y1-70, 20) 362 | text_y = min(text_y, cap_height-70) 363 | cv.putText( 364 | debug_image, 365 | f'trackid:{trackid} {handedness}', 366 | (text_x, text_y), 367 | cv.FONT_HERSHEY_SIMPLEX, 368 | 0.8, 369 | (0,0,0), 370 | 2, 371 | cv.LINE_AA, 372 | ) 373 | cv.putText( 374 | debug_image, 375 | f'trackid:{trackid} {handedness}', 376 | (text_x, text_y), 377 | cv.FONT_HERSHEY_SIMPLEX, 378 | 0.8, 379 | (59,255,255), 380 | 1, 381 | cv.LINE_AA, 382 | ) 383 | 384 | # 相対座標・正規化座標への変換 385 | """ 386 | pre_processed_landmark: np.ndarray [42], [x,y]x21 387 | """ 388 | pre_processed_landmark = pre_process_landmark( 389 | landmark, 390 | ) 391 | pre_processed_landmarks.append(pre_processed_landmark) 392 | 393 | """ 394 | point_history: dict 395 | { 396 | int(trackid1): [[x, y],[x, y],[x, y],[x, y], ...], 397 | int(trackid2): [[x, y],[x, y], ...], 398 | int(trackid3): [[x, y],[x, y],[x, y], ...], 399 | : 400 | } 401 | ↓ 402 | pre_processed_point_histories: List 403 | [ 404 | [rx, ry, rx, ry, rx, ry, rx, ry, ...], 405 | [rx, ry, rx, ry, ...], 406 | [rx, ry, rx, ry, rx, ry, ...], 407 | : 408 | ] 409 | """ 410 | # 人差し指軌跡を相対座標へ変換 411 | pre_processed_point_histories = pre_process_point_history( 412 | image_width=debug_image.shape[1], 413 | image_height=debug_image.shape[0], 414 | point_history=point_history, 415 | ) 416 | 417 | # 学習データ保存 418 | logging_csv( 419 | number, 420 | mode, 421 | trackid, 422 | 
pre_processed_landmark, 423 | pre_processed_point_histories, 424 | ) 425 | 426 | # ハンドサイン分類 - バッチ処理 427 | hand_sign_ids = keypoint_classifier( 428 | np.asarray(pre_processed_landmarks, dtype=np.float32) 429 | ) 430 | for (trackid, x1y1), landmark, hand_sign_id in zip(palm_trackid_box_x1y1s.items(), hand_landmarks, hand_sign_ids): 431 | x1, y1 = x1y1 432 | point_history.setdefault(trackid, deque(maxlen=history_length)) 433 | if hand_sign_id == 2: # 指差しサイン 434 | point_history[trackid].append(list(landmark[8])) # 人差指座標 435 | else: 436 | point_history[trackid].append([0, 0]) 437 | 438 | """ 439 | 人差し指の軌跡が表示上に残り続けるのを割けるため 440 | トラッキング対象外になった(画角から外れた)手のひらがある場合は人差指XY座標の履歴をクリアする 441 | 今回の全ての軌跡座標と前回の全ての軌跡座標が完全に一致したtrackidの履歴情報は変化なしと断定して履歴から削除する 442 | point_history: 最新の軌跡16点 443 | pre_point_history: 前回の軌跡16点 444 | """ 445 | if len(pre_point_history) > 0: 446 | temp_point_history = copy.deepcopy(point_history) 447 | for track_id, points in temp_point_history.items(): 448 | if track_id in pre_point_history: 449 | pre_points = pre_point_history[track_id] 450 | if points == pre_points: 451 | _ = point_history.pop(track_id, None) 452 | pre_point_history = copy.deepcopy(point_history) 453 | 454 | # フィンガージェスチャー分類 - バッチ処理 455 | finger_gesture_ids = None 456 | temp_trackid_x1y1s = {} 457 | temp_pre_processed_point_history = [] 458 | for (trackid, x1y1), pre_processed_point_history in zip(palm_trackid_box_x1y1s.items(), pre_processed_point_histories): 459 | point_history_len = len(pre_processed_point_history) 460 | if point_history_len > 0 and point_history_len % (history_length * 2) == 0: 461 | temp_trackid_x1y1s[trackid] = x1y1 462 | temp_pre_processed_point_history.append(pre_processed_point_history) 463 | if len(temp_pre_processed_point_history) > 0: 464 | finger_gesture_ids = point_history_classifier( 465 | temp_pre_processed_point_history, 466 | ) 467 | 468 | # 直近検出の中で最多のジェスチャーIDを算出 469 | if finger_gesture_ids is not None: 470 | for (trackid, x1y1), finger_gesture_id in zip(temp_trackid_x1y1s.items(), finger_gesture_ids): 471 | x1, y1 = x1y1 472 | trackid_str = str(trackid) 473 | finger_gesture_history.setdefault(trackid_str, deque(maxlen=gesture_history_length)) 474 | finger_gesture_history[trackid_str].append(int(finger_gesture_id)) 475 | most_common_fg_id = Counter(finger_gesture_history[trackid_str]).most_common() 476 | text_x = max(x1, 10) 477 | text_x = min(text_x, cap_width-120) 478 | text_y = max(y1-45, 20) 479 | text_y = min(text_y, cap_height-45) 480 | classifier_label = point_history_classifier_labels[most_common_fg_id[0][0]] 481 | # print(f'trackid: {trackid} [x1,y1]: [{x1},{y1}] finger_gesture_id: {classifier_label}') 482 | cv.putText( 483 | debug_image, 484 | f'{classifier_label}', 485 | (text_x, text_y), 486 | cv.FONT_HERSHEY_SIMPLEX, 487 | 0.8, 488 | (0,0,0), 489 | 2, 490 | cv.LINE_AA, 491 | ) 492 | cv.putText( 493 | debug_image, 494 | f'{classifier_label}', 495 | (text_x, text_y), 496 | cv.FONT_HERSHEY_SIMPLEX, 497 | 0.8, 498 | (59,255,255), 499 | 1, 500 | cv.LINE_AA, 501 | ) 502 | 503 | else: 504 | point_history = {} 505 | 506 | else: 507 | point_history = {} 508 | 509 | debug_image = draw_point_history(debug_image, point_history) 510 | debug_image = draw_info(debug_image, fps, mode, number, auto) 511 | 512 | # 画面反映 ############################################################# 513 | cv.imshow('Hand Gesture Recognition', debug_image) 514 | video_writer.write(debug_image) 515 | 516 | if video_writer: 517 | video_writer.release() 518 | if cap: 519 | cap.release() 520 | 
cv.destroyAllWindows() 521 | 522 | 523 | def select_mode(key, mode, auto=False, prev_number=-1): 524 | number = -1 525 | if 48 <= key <= 57: # 0 ~ 9 526 | number = key - 48 527 | prev_number = number 528 | if key == 110: # n 529 | mode = 0 530 | if key == 107: # k 531 | mode = 1 532 | if key == 104: # h 533 | mode = 2 534 | if key == 97: # a 535 | auto = not auto 536 | if auto == True: 537 | if prev_number != -1: 538 | number = prev_number 539 | else: 540 | prev_number = -1 541 | 542 | return number, mode, auto, prev_number 543 | 544 | 545 | def pre_process_landmark(landmark_list): 546 | if len(landmark_list) == 0: 547 | return [] 548 | 549 | temp_landmark_list = copy.deepcopy(landmark_list) 550 | # 相対座標に変換 551 | base_x, base_y = temp_landmark_list[0][0], temp_landmark_list[0][1] 552 | temp_landmark_list = [ 553 | [temp_landmark[0] - base_x, temp_landmark[1] - base_y] for temp_landmark in temp_landmark_list 554 | ] 555 | # 1次元リストに変換 556 | temp_landmark_list = list( 557 | itertools.chain.from_iterable(temp_landmark_list) 558 | ) 559 | # 正規化 560 | max_value = max(list(map(abs, temp_landmark_list))) 561 | 562 | def normalize_(n): 563 | return n / max_value 564 | 565 | temp_landmark_list = list(map(normalize_, temp_landmark_list)) 566 | return temp_landmark_list 567 | 568 | 569 | def pre_process_point_history( 570 | image_width: int, 571 | image_height: int, 572 | point_history: dict, 573 | ): 574 | """pre_process_point_history 575 | 576 | Parameters 577 | ---------- 578 | image_width: int 579 | Input image width 580 | 581 | image_height: int 582 | Input image height 583 | 584 | point_history: dict 585 | Index finger XY coordinate history per trackid (detected palm) 586 | { 587 | int(trackid1): [[x, y],[x, y],[x, y],[x, y], ...], 588 | int(trackid2): [[x, y],[x, y], ...], 589 | int(trackid3): [[x, y],[x, y],[x, y], ...], 590 | : 591 | } 592 | 593 | Returns 594 | ------- 595 | relative_coordinate_list_by_trackid: List 596 | [ 597 | [rx, ry, rx, ry, rx, ry, rx, ry, ...], 598 | [rx, ry, rx, ry, ...], 599 | [rx, ry, rx, ry, rx, ry, ...], 600 | : 601 | ] 602 | """ 603 | if len(point_history) == 0: 604 | return [] 605 | 606 | temp_point_history = copy.deepcopy(point_history) 607 | relative_coordinate_list_by_trackid = [] 608 | 609 | # trackidごとに相対座標へ変換 610 | for trackid, points in temp_point_history.items(): 611 | base_x, base_y = points[0][0], points[0][1] 612 | relative_coordinate_list = [ 613 | [ 614 | (point[0] - base_x) / image_width, 615 | (point[1] - base_y) / image_height, 616 | ] for point in points 617 | ] 618 | # 1次元リストに変換 619 | relative_coordinate_list_1d = list( 620 | itertools.chain.from_iterable(relative_coordinate_list) 621 | ) 622 | relative_coordinate_list_by_trackid.append(relative_coordinate_list_1d) 623 | return relative_coordinate_list_by_trackid 624 | 625 | 626 | def logging_csv(number, mode, trackid, landmark_list, point_histories): 627 | if mode == 0: 628 | pass 629 | if mode == 1 and (0 <= number <= 9): 630 | csv_path = 'model/keypoint_classifier/keypoint.csv' 631 | with open(csv_path, 'a', newline="") as f: 632 | writer = csv.writer(f) 633 | writer.writerow([number, trackid, *landmark_list]) 634 | if mode == 2 and (0 <= number <= 9): 635 | csv_path = 'model/point_history_classifier/point_history.csv' 636 | with open(csv_path, 'a', newline="") as f: 637 | writer = csv.writer(f) 638 | for point_history in point_histories: 639 | writer.writerow([number, trackid, *point_history]) 640 | 641 | 642 | def draw_info_text( 643 | image, 644 | brect, 645 | handedness, 646 | 
hand_sign_text, 647 | finger_gesture_text 648 | ): 649 | info_text = handedness 650 | if hand_sign_text != "": 651 | info_text = f'{handedness}:{hand_sign_text}' 652 | cv.putText( 653 | image, 654 | info_text, 655 | (brect[0] + 5, brect[1] - 4), 656 | cv.FONT_HERSHEY_SIMPLEX, 657 | 0.6, 658 | (255, 255, 255), 659 | 1, 660 | cv.LINE_AA, 661 | ) 662 | 663 | if finger_gesture_text != "": 664 | cv.putText( 665 | image, 666 | f'Finger Gesture:{finger_gesture_text}', 667 | (10, 60), 668 | cv.FONT_HERSHEY_SIMPLEX, 669 | 1.0, 670 | (0, 0, 0), 671 | 4, 672 | cv.LINE_AA, 673 | ) 674 | cv.putText( 675 | image, 676 | f'Finger Gesture:{finger_gesture_text}', 677 | (10, 60), 678 | cv.FONT_HERSHEY_SIMPLEX, 679 | 1.0, 680 | (255, 255, 255), 681 | 2, 682 | cv.LINE_AA, 683 | ) 684 | 685 | return image 686 | 687 | 688 | def draw_point_history(image, point_history): 689 | _ = [ 690 | cv.circle(image, (point[0], point[1]), 1 + int(index / 2), (152, 251, 152), 2) \ 691 | for trackid, points in point_history.items() \ 692 | for index, point in enumerate(points) if point[0] != 0 and point[1] != 0 693 | ] 694 | return image 695 | 696 | 697 | def draw_info(image, fps, mode, number, auto): 698 | cv.putText( 699 | image, 700 | f'FPS:{str(fps)}', 701 | (10, 30), 702 | cv.FONT_HERSHEY_SIMPLEX, 703 | 1.0, 704 | (0, 0, 0), 705 | 4, 706 | cv.LINE_AA, 707 | ) 708 | cv.putText( 709 | image, 710 | f'FPS:{str(fps)}', 711 | (10, 30), 712 | cv.FONT_HERSHEY_SIMPLEX, 713 | 1.0, 714 | (255, 255, 255), 715 | 2, 716 | cv.LINE_AA, 717 | ) 718 | 719 | mode_string = ['Logging Key Point', 'Logging Point History'] 720 | if 1 <= mode <= 2: 721 | cv.putText( 722 | image, 723 | f'MODE:{mode_string[mode - 1]}', 724 | (10, 90), 725 | cv.FONT_HERSHEY_SIMPLEX, 726 | 0.6, 727 | (255, 255, 255), 728 | 1, 729 | cv.LINE_AA, 730 | ) 731 | if 0 <= number <= 9: 732 | cv.putText( 733 | image, 734 | f'NUM:{str(number)}', 735 | (10, 110), 736 | cv.FONT_HERSHEY_SIMPLEX, 737 | 0.6, 738 | (255, 255, 255), 739 | 1, 740 | cv.LINE_AA, 741 | ) 742 | cv.putText( 743 | image, 744 | f'AUTO:{str(auto)}', 745 | (10, 130), 746 | cv.FONT_HERSHEY_SIMPLEX, 747 | 0.6, 748 | (255, 255, 255), 749 | 1, 750 | cv.LINE_AA, 751 | ) 752 | return image 753 | 754 | 755 | if __name__ == '__main__': 756 | main() 757 | -------------------------------------------------------------------------------- /hand.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/hand.png -------------------------------------------------------------------------------- /keypoint_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import csv\n", 10 | "\n", 11 | "import numpy as np\n", 12 | "import tensorflow as tf\n", 13 | "from sklearn.model_selection import train_test_split\n", 14 | "\n", 15 | "RANDOM_SEED = 42" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "# 各パス指定" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "dataset = 'model/keypoint_classifier/keypoint.csv'\n", 32 | "model_save_path = 'model/keypoint_classifier/keypoint_classifier.hdf5'" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 
39 | "# 分類数設定" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "NUM_CLASSES = 3" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "# 学習データ読み込み" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "X_dataset = np.loadtxt(dataset, delimiter=',', dtype='float32', usecols=list(range(1, (21 * 2) + 1)))" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 5, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "y_dataset = np.loadtxt(dataset, delimiter=',', dtype='int32', usecols=(0))" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 6, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "X_train, X_test, y_train, y_test = train_test_split(X_dataset, y_dataset, train_size=0.75, random_state=RANDOM_SEED)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "# モデル構築" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 7, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "model = tf.keras.models.Sequential([\n", 99 | " tf.keras.layers.Input((21 * 2, )),\n", 100 | " tf.keras.layers.Dropout(0.2),\n", 101 | " tf.keras.layers.Dense(20, activation='relu'),\n", 102 | " tf.keras.layers.Dropout(0.4),\n", 103 | " tf.keras.layers.Dense(10, activation='relu'),\n", 104 | " tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')\n", 105 | "])" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 8, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "Model: \"sequential\"\n", 118 | "_________________________________________________________________\n", 119 | "Layer (type) Output Shape Param # \n", 120 | "=================================================================\n", 121 | "dropout (Dropout) (None, 42) 0 \n", 122 | "_________________________________________________________________\n", 123 | "dense (Dense) (None, 20) 860 \n", 124 | "_________________________________________________________________\n", 125 | "dropout_1 (Dropout) (None, 20) 0 \n", 126 | "_________________________________________________________________\n", 127 | "dense_1 (Dense) (None, 10) 210 \n", 128 | "_________________________________________________________________\n", 129 | "dense_2 (Dense) (None, 3) 33 \n", 130 | "=================================================================\n", 131 | "Total params: 1,103\n", 132 | "Trainable params: 1,103\n", 133 | "Non-trainable params: 0\n", 134 | "_________________________________________________________________\n" 135 | ] 136 | } 137 | ], 138 | "source": [ 139 | "model.summary() # tf.keras.utils.plot_model(model, show_shapes=True)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 9, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "# モデルチェックポイントのコールバック\n", 149 | "cp_callback = tf.keras.callbacks.ModelCheckpoint(\n", 150 | " model_save_path, verbose=1, save_weights_only=False)\n", 151 | "# 早期打ち切り用コールバック\n", 152 | "es_callback = tf.keras.callbacks.EarlyStopping(patience=20, verbose=1)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 10, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "# モデルコンパイル\n", 162 | "model.compile(\n", 163 | " 
optimizer='adam',\n", 164 | " loss='sparse_categorical_crossentropy',\n", 165 | " metrics=['accuracy']\n", 166 | ")" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "# モデル訓練" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 11, 179 | "metadata": { 180 | "scrolled": true 181 | }, 182 | "outputs": [ 183 | { 184 | "name": "stdout", 185 | "output_type": "stream", 186 | "text": [ 187 | "Epoch 1/1000\n", 188 | " 1/27 [>.............................] - ETA: 0s - loss: 1.1295 - accuracy: 0.3203\n", 189 | "Epoch 00001: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 190 | "27/27 [==============================] - 0s 11ms/step - loss: 1.1004 - accuracy: 0.3602 - val_loss: 1.0431 - val_accuracy: 0.5220\n", 191 | "Epoch 2/1000\n", 192 | " 1/27 [>.............................] - ETA: 0s - loss: 1.0440 - accuracy: 0.4844\n", 193 | "Epoch 00002: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 194 | "27/27 [==============================] - 0s 3ms/step - loss: 1.0503 - accuracy: 0.4297 - val_loss: 0.9953 - val_accuracy: 0.6397\n", 195 | "Epoch 3/1000\n", 196 | " 1/27 [>.............................] - ETA: 0s - loss: 1.0043 - accuracy: 0.5312\n", 197 | "Epoch 00003: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 198 | "27/27 [==============================] - 0s 4ms/step - loss: 1.0210 - accuracy: 0.4582 - val_loss: 0.9545 - val_accuracy: 0.6523\n", 199 | "Epoch 4/1000\n", 200 | " 1/27 [>.............................] - ETA: 0s - loss: 0.9503 - accuracy: 0.5625\n", 201 | "Epoch 00004: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 202 | "27/27 [==============================] - 0s 4ms/step - loss: 0.9906 - accuracy: 0.5022 - val_loss: 0.9168 - val_accuracy: 0.6721\n", 203 | "Epoch 5/1000\n", 204 | " 1/27 [>.............................] - ETA: 0s - loss: 0.9562 - accuracy: 0.5469\n", 205 | "Epoch 00005: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 206 | "27/27 [==============================] - 0s 3ms/step - loss: 0.9654 - accuracy: 0.5340 - val_loss: 0.8791 - val_accuracy: 0.7017\n", 207 | "Epoch 6/1000\n", 208 | " 1/27 [>.............................] - ETA: 0s - loss: 0.9184 - accuracy: 0.5938\n", 209 | "Epoch 00006: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 210 | "27/27 [==============================] - 0s 3ms/step - loss: 0.9256 - accuracy: 0.5577 - val_loss: 0.8344 - val_accuracy: 0.7269\n", 211 | "Epoch 7/1000\n", 212 | "27/27 [==============================] - ETA: 0s - loss: 0.9050 - accuracy: 0.5715\n", 213 | "Epoch 00007: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 214 | "27/27 [==============================] - 0s 4ms/step - loss: 0.9050 - accuracy: 0.5715 - val_loss: 0.7887 - val_accuracy: 0.7646\n", 215 | "Epoch 8/1000\n", 216 | " 1/27 [>.............................] - ETA: 0s - loss: 0.9135 - accuracy: 0.5547\n", 217 | "Epoch 00008: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 218 | "27/27 [==============================] - 0s 3ms/step - loss: 0.8642 - accuracy: 0.5993 - val_loss: 0.7414 - val_accuracy: 0.7996\n", 219 | "Epoch 9/1000\n", 220 | " 1/27 [>.............................] 
- ETA: 0s - loss: 0.8002 - accuracy: 0.6172\n", 221 | "Epoch 00009: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 222 | "27/27 [==============================] - 0s 3ms/step - loss: 0.8258 - accuracy: 0.6263 - val_loss: 0.6881 - val_accuracy: 0.8149\n", 223 | "Epoch 10/1000\n", 224 | " 1/27 [>.............................] - ETA: 0s - loss: 0.8056 - accuracy: 0.6328\n", 225 | "Epoch 00010: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 226 | "27/27 [==============================] - 0s 3ms/step - loss: 0.8008 - accuracy: 0.6341 - val_loss: 0.6461 - val_accuracy: 0.8239\n", 227 | "Epoch 11/1000\n", 228 | " 1/27 [>.............................] - ETA: 0s - loss: 0.7746 - accuracy: 0.6719\n", 229 | "Epoch 00011: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 230 | "27/27 [==============================] - 0s 3ms/step - loss: 0.7771 - accuracy: 0.6491 - val_loss: 0.6143 - val_accuracy: 0.8266\n", 231 | "Epoch 12/1000\n", 232 | " 1/27 [>.............................] - ETA: 0s - loss: 0.7242 - accuracy: 0.7109\n", 233 | "Epoch 00012: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 234 | "27/27 [==============================] - 0s 4ms/step - loss: 0.7490 - accuracy: 0.6650 - val_loss: 0.5740 - val_accuracy: 0.8320\n", 235 | "Epoch 13/1000\n", 236 | " 1/27 [>.............................] - ETA: 0s - loss: 0.8363 - accuracy: 0.6328\n", 237 | "Epoch 00013: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 238 | "27/27 [==============================] - 0s 4ms/step - loss: 0.7397 - accuracy: 0.6731 - val_loss: 0.5465 - val_accuracy: 0.8446\n", 239 | "Epoch 14/1000\n", 240 | " 1/27 [>.............................] - ETA: 0s - loss: 0.7634 - accuracy: 0.6172\n", 241 | "Epoch 00014: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 242 | "27/27 [==============================] - 0s 3ms/step - loss: 0.7190 - accuracy: 0.6883 - val_loss: 0.5202 - val_accuracy: 0.8589\n", 243 | "Epoch 15/1000\n", 244 | " 1/27 [>.............................] - ETA: 0s - loss: 0.6699 - accuracy: 0.6875\n", 245 | "Epoch 00015: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 246 | "27/27 [==============================] - 0s 3ms/step - loss: 0.7077 - accuracy: 0.6973 - val_loss: 0.4944 - val_accuracy: 0.8652\n", 247 | "Epoch 16/1000\n", 248 | " 1/27 [>.............................] - ETA: 0s - loss: 0.6447 - accuracy: 0.7500\n", 249 | "Epoch 00016: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 250 | "27/27 [==============================] - 0s 3ms/step - loss: 0.6903 - accuracy: 0.6928 - val_loss: 0.4781 - val_accuracy: 0.8805\n", 251 | "Epoch 17/1000\n", 252 | " 1/27 [>.............................] - ETA: 0s - loss: 0.7165 - accuracy: 0.6875\n", 253 | "Epoch 00017: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 254 | "27/27 [==============================] - 0s 3ms/step - loss: 0.6919 - accuracy: 0.6973 - val_loss: 0.4696 - val_accuracy: 0.8895\n", 255 | "Epoch 18/1000\n", 256 | " 1/27 [>.............................] 
- ETA: 0s - loss: 0.6268 - accuracy: 0.7422\n", 257 | "Epoch 00018: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 258 | "27/27 [==============================] - 0s 3ms/step - loss: 0.6498 - accuracy: 0.7303 - val_loss: 0.4440 - val_accuracy: 0.8967\n", 259 | "Epoch 19/1000\n", 260 | "27/27 [==============================] - ETA: 0s - loss: 0.6499 - accuracy: 0.7261\n", 261 | "Epoch 00019: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 262 | "27/27 [==============================] - 0s 4ms/step - loss: 0.6499 - accuracy: 0.7261 - val_loss: 0.4254 - val_accuracy: 0.9039\n", 263 | "Epoch 20/1000\n", 264 | "26/27 [===========================>..] - ETA: 0s - loss: 0.6386 - accuracy: 0.7236\n", 265 | "Epoch 00020: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 266 | "27/27 [==============================] - 0s 4ms/step - loss: 0.6415 - accuracy: 0.7228 - val_loss: 0.4082 - val_accuracy: 0.9093\n", 267 | "Epoch 21/1000\n", 268 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5631 - accuracy: 0.7500\n", 269 | "Epoch 00021: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 270 | "27/27 [==============================] - 0s 3ms/step - loss: 0.6484 - accuracy: 0.7180 - val_loss: 0.4114 - val_accuracy: 0.9173\n", 271 | "Epoch 22/1000\n", 272 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5775 - accuracy: 0.7812\n", 273 | "Epoch 00022: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 274 | "27/27 [==============================] - 0s 3ms/step - loss: 0.6154 - accuracy: 0.7480 - val_loss: 0.3907 - val_accuracy: 0.9218\n", 275 | "Epoch 23/1000\n", 276 | "25/27 [==========================>...] - ETA: 0s - loss: 0.5967 - accuracy: 0.7588\n", 277 | "Epoch 00023: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 278 | "27/27 [==============================] - 0s 4ms/step - loss: 0.5971 - accuracy: 0.7582 - val_loss: 0.3763 - val_accuracy: 0.9227\n", 279 | "Epoch 24/1000\n", 280 | "26/27 [===========================>..] - ETA: 0s - loss: 0.6064 - accuracy: 0.7569\n", 281 | "Epoch 00024: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 282 | "27/27 [==============================] - 0s 4ms/step - loss: 0.6066 - accuracy: 0.7567 - val_loss: 0.3714 - val_accuracy: 0.9254\n", 283 | "Epoch 25/1000\n", 284 | " 1/27 [>.............................] - ETA: 0s - loss: 0.6211 - accuracy: 0.7422\n", 285 | "Epoch 00025: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 286 | "27/27 [==============================] - 0s 7ms/step - loss: 0.5954 - accuracy: 0.7579 - val_loss: 0.3611 - val_accuracy: 0.9353\n", 287 | "Epoch 26/1000\n", 288 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5621 - accuracy: 0.7812\n", 289 | "Epoch 00026: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 290 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5818 - accuracy: 0.7737 - val_loss: 0.3498 - val_accuracy: 0.9380\n", 291 | "Epoch 27/1000\n", 292 | " 1/27 [>.............................] 
- ETA: 0s - loss: 0.6431 - accuracy: 0.7500\n", 293 | "Epoch 00027: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 294 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5882 - accuracy: 0.7648 - val_loss: 0.3355 - val_accuracy: 0.9416\n" 295 | ] 296 | }, 297 | { 298 | "name": "stdout", 299 | "output_type": "stream", 300 | "text": [ 301 | "Epoch 28/1000\n", 302 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5633 - accuracy: 0.8203\n", 303 | "Epoch 00028: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 304 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5746 - accuracy: 0.7702 - val_loss: 0.3273 - val_accuracy: 0.9425\n", 305 | "Epoch 29/1000\n", 306 | "27/27 [==============================] - ETA: 0s - loss: 0.5856 - accuracy: 0.7651\n", 307 | "Epoch 00029: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 308 | "27/27 [==============================] - 0s 4ms/step - loss: 0.5856 - accuracy: 0.7651 - val_loss: 0.3237 - val_accuracy: 0.9434\n", 309 | "Epoch 30/1000\n", 310 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5524 - accuracy: 0.7812\n", 311 | "Epoch 00030: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 312 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5576 - accuracy: 0.7788 - val_loss: 0.3203 - val_accuracy: 0.9452\n", 313 | "Epoch 31/1000\n", 314 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5643 - accuracy: 0.7578\n", 315 | "Epoch 00031: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 316 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5450 - accuracy: 0.7773 - val_loss: 0.3111 - val_accuracy: 0.9443\n", 317 | "Epoch 32/1000\n", 318 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5507 - accuracy: 0.7812\n", 319 | "Epoch 00032: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 320 | "27/27 [==============================] - 0s 6ms/step - loss: 0.5574 - accuracy: 0.7860 - val_loss: 0.3017 - val_accuracy: 0.9434\n", 321 | "Epoch 33/1000\n", 322 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5302 - accuracy: 0.8125\n", 323 | "Epoch 00033: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 324 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5444 - accuracy: 0.7905 - val_loss: 0.2917 - val_accuracy: 0.9479\n", 325 | "Epoch 34/1000\n", 326 | "27/27 [==============================] - ETA: 0s - loss: 0.5421 - accuracy: 0.7848\n", 327 | "Epoch 00034: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 328 | "27/27 [==============================] - 0s 4ms/step - loss: 0.5421 - accuracy: 0.7848 - val_loss: 0.2863 - val_accuracy: 0.9470\n", 329 | "Epoch 35/1000\n", 330 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4633 - accuracy: 0.8125\n", 331 | "Epoch 00035: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 332 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5331 - accuracy: 0.7980 - val_loss: 0.2804 - val_accuracy: 0.9506\n", 333 | "Epoch 36/1000\n", 334 | " 1/27 [>.............................] 
- ETA: 0s - loss: 0.5972 - accuracy: 0.7812\n", 335 | "Epoch 00036: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 336 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5255 - accuracy: 0.7968 - val_loss: 0.2774 - val_accuracy: 0.9479\n", 337 | "Epoch 37/1000\n", 338 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5851 - accuracy: 0.7578\n", 339 | "Epoch 00037: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 340 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5284 - accuracy: 0.7947 - val_loss: 0.2681 - val_accuracy: 0.9497\n", 341 | "Epoch 38/1000\n", 342 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4775 - accuracy: 0.7812\n", 343 | "Epoch 00038: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 344 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5143 - accuracy: 0.7995 - val_loss: 0.2678 - val_accuracy: 0.9479\n", 345 | "Epoch 39/1000\n", 346 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5495 - accuracy: 0.7812\n", 347 | "Epoch 00039: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 348 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5250 - accuracy: 0.7935 - val_loss: 0.2657 - val_accuracy: 0.9470\n", 349 | "Epoch 40/1000\n", 350 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5244 - accuracy: 0.8203\n", 351 | "Epoch 00040: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 352 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5195 - accuracy: 0.8001 - val_loss: 0.2606 - val_accuracy: 0.9524\n", 353 | "Epoch 41/1000\n", 354 | " 1/27 [>.............................] - ETA: 0s - loss: 0.6996 - accuracy: 0.6953\n", 355 | "Epoch 00041: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 356 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5304 - accuracy: 0.7956 - val_loss: 0.2572 - val_accuracy: 0.9515\n", 357 | "Epoch 42/1000\n", 358 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4843 - accuracy: 0.8281\n", 359 | "Epoch 00042: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 360 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5115 - accuracy: 0.8091 - val_loss: 0.2513 - val_accuracy: 0.9524\n", 361 | "Epoch 43/1000\n", 362 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3830 - accuracy: 0.8594\n", 363 | "Epoch 00043: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 364 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4939 - accuracy: 0.8133 - val_loss: 0.2423 - val_accuracy: 0.9551\n", 365 | "Epoch 44/1000\n", 366 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4531 - accuracy: 0.7969\n", 367 | "Epoch 00044: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 368 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4969 - accuracy: 0.8094 - val_loss: 0.2437 - val_accuracy: 0.9497\n", 369 | "Epoch 45/1000\n", 370 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5198 - accuracy: 0.7812\n", 371 | "Epoch 00045: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 372 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4961 - accuracy: 0.8106 - val_loss: 0.2441 - val_accuracy: 0.9533\n", 373 | "Epoch 46/1000\n", 374 | " 1/27 [>.............................] 
- ETA: 0s - loss: 0.5912 - accuracy: 0.7812\n", 375 | "Epoch 00046: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 376 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4816 - accuracy: 0.8271 - val_loss: 0.2482 - val_accuracy: 0.9542\n", 377 | "Epoch 47/1000\n", 378 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5292 - accuracy: 0.8047\n", 379 | "Epoch 00047: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 380 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4789 - accuracy: 0.8127 - val_loss: 0.2313 - val_accuracy: 0.9569\n", 381 | "Epoch 48/1000\n", 382 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4428 - accuracy: 0.8125\n", 383 | "Epoch 00048: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 384 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4901 - accuracy: 0.8079 - val_loss: 0.2319 - val_accuracy: 0.9560\n", 385 | "Epoch 49/1000\n", 386 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4565 - accuracy: 0.8281\n", 387 | "Epoch 00049: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 388 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4893 - accuracy: 0.8154 - val_loss: 0.2300 - val_accuracy: 0.9533\n", 389 | "Epoch 50/1000\n", 390 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5185 - accuracy: 0.7812\n", 391 | "Epoch 00050: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 392 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5079 - accuracy: 0.8112 - val_loss: 0.2386 - val_accuracy: 0.9524\n", 393 | "Epoch 51/1000\n", 394 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4280 - accuracy: 0.8203\n", 395 | "Epoch 00051: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 396 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4692 - accuracy: 0.8205 - val_loss: 0.2332 - val_accuracy: 0.9578\n", 397 | "Epoch 52/1000\n", 398 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5299 - accuracy: 0.8203\n", 399 | "Epoch 00052: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 400 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4907 - accuracy: 0.8145 - val_loss: 0.2336 - val_accuracy: 0.9569\n", 401 | "Epoch 53/1000\n", 402 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5856 - accuracy: 0.7969\n", 403 | "Epoch 00053: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 404 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4882 - accuracy: 0.8181 - val_loss: 0.2306 - val_accuracy: 0.9605\n", 405 | "Epoch 54/1000\n", 406 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4466 - accuracy: 0.8047\n", 407 | "Epoch 00054: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 408 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4641 - accuracy: 0.8235 - val_loss: 0.2218 - val_accuracy: 0.9596\n" 409 | ] 410 | }, 411 | { 412 | "name": "stdout", 413 | "output_type": "stream", 414 | "text": [ 415 | "Epoch 55/1000\n", 416 | " 1/27 [>.............................] 
- ETA: 0s - loss: 0.3760 - accuracy: 0.8672\n", 417 | "Epoch 00055: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 418 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4630 - accuracy: 0.8241 - val_loss: 0.2242 - val_accuracy: 0.9578\n", 419 | "Epoch 56/1000\n", 420 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4607 - accuracy: 0.7734\n", 421 | "Epoch 00056: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 422 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4546 - accuracy: 0.8277 - val_loss: 0.2168 - val_accuracy: 0.9605\n", 423 | "Epoch 57/1000\n", 424 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4946 - accuracy: 0.7969\n", 425 | "Epoch 00057: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 426 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4675 - accuracy: 0.8214 - val_loss: 0.2211 - val_accuracy: 0.9578\n", 427 | "Epoch 58/1000\n", 428 | "25/27 [==========================>...] - ETA: 0s - loss: 0.4393 - accuracy: 0.8334\n", 429 | "Epoch 00058: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 430 | "27/27 [==============================] - 0s 4ms/step - loss: 0.4418 - accuracy: 0.8325 - val_loss: 0.2115 - val_accuracy: 0.9632\n", 431 | "Epoch 59/1000\n", 432 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4987 - accuracy: 0.7969\n", 433 | "Epoch 00059: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 434 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4646 - accuracy: 0.8217 - val_loss: 0.2116 - val_accuracy: 0.9596\n", 435 | "Epoch 60/1000\n", 436 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4904 - accuracy: 0.7812\n", 437 | "Epoch 00060: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 438 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4576 - accuracy: 0.8232 - val_loss: 0.2108 - val_accuracy: 0.9569\n", 439 | "Epoch 61/1000\n", 440 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5147 - accuracy: 0.8281\n", 441 | "Epoch 00061: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 442 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4646 - accuracy: 0.8253 - val_loss: 0.2174 - val_accuracy: 0.9587\n", 443 | "Epoch 62/1000\n", 444 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3386 - accuracy: 0.8750\n", 445 | "Epoch 00062: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 446 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4433 - accuracy: 0.8310 - val_loss: 0.2145 - val_accuracy: 0.9560\n", 447 | "Epoch 63/1000\n", 448 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4222 - accuracy: 0.8594\n", 449 | "Epoch 00063: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 450 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4624 - accuracy: 0.8283 - val_loss: 0.2099 - val_accuracy: 0.9569\n", 451 | "Epoch 64/1000\n", 452 | " 1/27 [>.............................] 
- ETA: 0s - loss: 0.4376 - accuracy: 0.8203\n", 453 | "Epoch 00064: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 454 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4601 - accuracy: 0.8235 - val_loss: 0.2075 - val_accuracy: 0.9641\n", 455 | "Epoch 65/1000\n", 456 | "27/27 [==============================] - ETA: 0s - loss: 0.4676 - accuracy: 0.8265\n", 457 | "Epoch 00065: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 458 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4676 - accuracy: 0.8265 - val_loss: 0.2172 - val_accuracy: 0.9551\n", 459 | "Epoch 66/1000\n", 460 | "26/27 [===========================>..] - ETA: 0s - loss: 0.4434 - accuracy: 0.8368\n", 461 | "Epoch 00066: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 462 | "27/27 [==============================] - 1s 20ms/step - loss: 0.4429 - accuracy: 0.8370 - val_loss: 0.2154 - val_accuracy: 0.9578\n", 463 | "Epoch 67/1000\n", 464 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4827 - accuracy: 0.8125\n", 465 | "Epoch 00067: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 466 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4434 - accuracy: 0.8358 - val_loss: 0.2090 - val_accuracy: 0.9587\n", 467 | "Epoch 68/1000\n", 468 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5810 - accuracy: 0.7656\n", 469 | "Epoch 00068: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 470 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4892 - accuracy: 0.8049 - val_loss: 0.2160 - val_accuracy: 0.9578\n", 471 | "Epoch 69/1000\n", 472 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4375 - accuracy: 0.7812\n", 473 | "Epoch 00069: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 474 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4597 - accuracy: 0.8250 - val_loss: 0.2100 - val_accuracy: 0.9605\n", 475 | "Epoch 70/1000\n", 476 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3782 - accuracy: 0.8359\n", 477 | "Epoch 00070: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 478 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4408 - accuracy: 0.8349 - val_loss: 0.2087 - val_accuracy: 0.9596\n", 479 | "Epoch 71/1000\n", 480 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4578 - accuracy: 0.8438\n", 481 | "Epoch 00071: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 482 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4425 - accuracy: 0.8355 - val_loss: 0.2075 - val_accuracy: 0.9587\n", 483 | "Epoch 72/1000\n", 484 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4267 - accuracy: 0.8438\n", 485 | "Epoch 00072: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 486 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4581 - accuracy: 0.8292 - val_loss: 0.2059 - val_accuracy: 0.9623\n", 487 | "Epoch 73/1000\n", 488 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4450 - accuracy: 0.8750\n", 489 | "Epoch 00073: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 490 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4508 - accuracy: 0.8403 - val_loss: 0.2083 - val_accuracy: 0.9614\n", 491 | "Epoch 74/1000\n", 492 | " 1/27 [>.............................] 
- ETA: 0s - loss: 0.3858 - accuracy: 0.8906\n", 493 | "Epoch 00074: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 494 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4472 - accuracy: 0.8361 - val_loss: 0.2043 - val_accuracy: 0.9650\n", 495 | "Epoch 75/1000\n", 496 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4439 - accuracy: 0.8359\n", 497 | "Epoch 00075: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 498 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4515 - accuracy: 0.8325 - val_loss: 0.2138 - val_accuracy: 0.9632\n", 499 | "Epoch 76/1000\n", 500 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3983 - accuracy: 0.8203\n", 501 | "Epoch 00076: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 502 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4462 - accuracy: 0.8334 - val_loss: 0.2065 - val_accuracy: 0.9623\n", 503 | "Epoch 77/1000\n", 504 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5020 - accuracy: 0.8047\n", 505 | "Epoch 00077: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 506 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4325 - accuracy: 0.8388 - val_loss: 0.2061 - val_accuracy: 0.9605\n", 507 | "Epoch 78/1000\n", 508 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3215 - accuracy: 0.8672\n", 509 | "Epoch 00078: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 510 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4394 - accuracy: 0.8391 - val_loss: 0.2054 - val_accuracy: 0.9578\n", 511 | "Epoch 79/1000\n", 512 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4025 - accuracy: 0.8359\n", 513 | "Epoch 00079: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 514 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4370 - accuracy: 0.8310 - val_loss: 0.2031 - val_accuracy: 0.9605\n", 515 | "Epoch 80/1000\n", 516 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4505 - accuracy: 0.8125\n", 517 | "Epoch 00080: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 518 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4419 - accuracy: 0.8340 - val_loss: 0.2010 - val_accuracy: 0.9596\n", 519 | "Epoch 81/1000\n", 520 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5287 - accuracy: 0.7891\n", 521 | "Epoch 00081: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 522 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4369 - accuracy: 0.8304 - val_loss: 0.2081 - val_accuracy: 0.9578\n" 523 | ] 524 | }, 525 | { 526 | "name": "stdout", 527 | "output_type": "stream", 528 | "text": [ 529 | "Epoch 82/1000\n", 530 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5132 - accuracy: 0.8047\n", 531 | "Epoch 00082: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 532 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4360 - accuracy: 0.8460 - val_loss: 0.2045 - val_accuracy: 0.9605\n", 533 | "Epoch 83/1000\n", 534 | " 1/27 [>.............................] 
- ETA: 0s - loss: 0.4239 - accuracy: 0.8125\n", 535 | "Epoch 00083: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 536 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4511 - accuracy: 0.8313 - val_loss: 0.1984 - val_accuracy: 0.9605\n", 537 | "Epoch 84/1000\n", 538 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4853 - accuracy: 0.8203\n", 539 | "Epoch 00084: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 540 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4357 - accuracy: 0.8304 - val_loss: 0.2024 - val_accuracy: 0.9623\n", 541 | "Epoch 85/1000\n", 542 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4782 - accuracy: 0.8125\n", 543 | "Epoch 00085: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 544 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4320 - accuracy: 0.8424 - val_loss: 0.2015 - val_accuracy: 0.9587\n", 545 | "Epoch 86/1000\n", 546 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3985 - accuracy: 0.8828\n", 547 | "Epoch 00086: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 548 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4422 - accuracy: 0.8349 - val_loss: 0.2087 - val_accuracy: 0.9587\n", 549 | "Epoch 87/1000\n", 550 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4810 - accuracy: 0.8359\n", 551 | "Epoch 00087: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 552 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4393 - accuracy: 0.8316 - val_loss: 0.2105 - val_accuracy: 0.9605\n", 553 | "Epoch 88/1000\n", 554 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4819 - accuracy: 0.8125\n", 555 | "Epoch 00088: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 556 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4381 - accuracy: 0.8400 - val_loss: 0.2070 - val_accuracy: 0.9623\n", 557 | "Epoch 89/1000\n", 558 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5002 - accuracy: 0.8281\n", 559 | "Epoch 00089: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 560 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4428 - accuracy: 0.8343 - val_loss: 0.2044 - val_accuracy: 0.9605\n", 561 | "Epoch 90/1000\n", 562 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3314 - accuracy: 0.9062\n", 563 | "Epoch 00090: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 564 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4156 - accuracy: 0.8406 - val_loss: 0.2026 - val_accuracy: 0.9578\n", 565 | "Epoch 91/1000\n", 566 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3850 - accuracy: 0.8594\n", 567 | "Epoch 00091: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 568 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4125 - accuracy: 0.8439 - val_loss: 0.2058 - val_accuracy: 0.9551\n", 569 | "Epoch 92/1000\n", 570 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4976 - accuracy: 0.7734\n", 571 | "Epoch 00092: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 572 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4217 - accuracy: 0.8415 - val_loss: 0.1999 - val_accuracy: 0.9623\n", 573 | "Epoch 93/1000\n", 574 | " 1/27 [>.............................] 
- ETA: 0s - loss: 0.4155 - accuracy: 0.8516\n", 575 | "Epoch 00093: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 576 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4166 - accuracy: 0.8412 - val_loss: 0.1947 - val_accuracy: 0.9614\n", 577 | "Epoch 94/1000\n", 578 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3642 - accuracy: 0.8750\n", 579 | "Epoch 00094: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 580 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4204 - accuracy: 0.8418 - val_loss: 0.2008 - val_accuracy: 0.9569\n", 581 | "Epoch 95/1000\n", 582 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3773 - accuracy: 0.8594\n", 583 | "Epoch 00095: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 584 | "27/27 [==============================] - 0s 4ms/step - loss: 0.4171 - accuracy: 0.8421 - val_loss: 0.1945 - val_accuracy: 0.9596\n", 585 | "Epoch 96/1000\n", 586 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4086 - accuracy: 0.8672\n", 587 | "Epoch 00096: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 588 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4124 - accuracy: 0.8520 - val_loss: 0.1930 - val_accuracy: 0.9614\n", 589 | "Epoch 97/1000\n", 590 | " 1/27 [>.............................] - ETA: 0s - loss: 0.2914 - accuracy: 0.8906\n", 591 | "Epoch 00097: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 592 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4303 - accuracy: 0.8367 - val_loss: 0.1958 - val_accuracy: 0.9569\n", 593 | "Epoch 98/1000\n", 594 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4562 - accuracy: 0.8672\n", 595 | "Epoch 00098: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 596 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4191 - accuracy: 0.8400 - val_loss: 0.1950 - val_accuracy: 0.9596\n", 597 | "Epoch 99/1000\n", 598 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3967 - accuracy: 0.8438\n", 599 | "Epoch 00099: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 600 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4260 - accuracy: 0.8418 - val_loss: 0.2044 - val_accuracy: 0.9551\n", 601 | "Epoch 100/1000\n", 602 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4173 - accuracy: 0.8516\n", 603 | "Epoch 00100: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 604 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4200 - accuracy: 0.8442 - val_loss: 0.2066 - val_accuracy: 0.9560\n", 605 | "Epoch 101/1000\n", 606 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3892 - accuracy: 0.8438\n", 607 | "Epoch 00101: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 608 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4245 - accuracy: 0.8418 - val_loss: 0.2058 - val_accuracy: 0.9578\n", 609 | "Epoch 102/1000\n", 610 | " 1/27 [>.............................] - ETA: 0s - loss: 0.2965 - accuracy: 0.8984\n", 611 | "Epoch 00102: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 612 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4100 - accuracy: 0.8445 - val_loss: 0.2093 - val_accuracy: 0.9578\n", 613 | "Epoch 103/1000\n", 614 | " 1/27 [>.............................] 
- ETA: 0s - loss: 0.4241 - accuracy: 0.8125\n", 615 | "Epoch 00103: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 616 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4189 - accuracy: 0.8403 - val_loss: 0.1928 - val_accuracy: 0.9659\n", 617 | "Epoch 104/1000\n", 618 | " 1/27 [>.............................] - ETA: 0s - loss: 0.6486 - accuracy: 0.7891\n", 619 | "Epoch 00104: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 620 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4223 - accuracy: 0.8424 - val_loss: 0.1964 - val_accuracy: 0.9596\n", 621 | "Epoch 105/1000\n", 622 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4589 - accuracy: 0.8281\n", 623 | "Epoch 00105: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 624 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4199 - accuracy: 0.8418 - val_loss: 0.1971 - val_accuracy: 0.9623\n", 625 | "Epoch 106/1000\n", 626 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4395 - accuracy: 0.8203\n", 627 | "Epoch 00106: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 628 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4028 - accuracy: 0.8511 - val_loss: 0.1912 - val_accuracy: 0.9641\n", 629 | "Epoch 107/1000\n", 630 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4061 - accuracy: 0.8594\n", 631 | "Epoch 00107: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 632 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4093 - accuracy: 0.8556 - val_loss: 0.1854 - val_accuracy: 0.9668\n", 633 | "Epoch 108/1000\n", 634 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4142 - accuracy: 0.8438\n", 635 | "Epoch 00108: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 636 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4249 - accuracy: 0.8367 - val_loss: 0.2008 - val_accuracy: 0.9614\n" 637 | ] 638 | }, 639 | { 640 | "name": "stdout", 641 | "output_type": "stream", 642 | "text": [ 643 | "Epoch 109/1000\n", 644 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5373 - accuracy: 0.7969\n", 645 | "Epoch 00109: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 646 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4198 - accuracy: 0.8379 - val_loss: 0.1955 - val_accuracy: 0.9659\n", 647 | "Epoch 110/1000\n", 648 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3872 - accuracy: 0.8281\n", 649 | "Epoch 00110: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 650 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4077 - accuracy: 0.8505 - val_loss: 0.2020 - val_accuracy: 0.9614\n", 651 | "Epoch 111/1000\n", 652 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3474 - accuracy: 0.8828\n", 653 | "Epoch 00111: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 654 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4132 - accuracy: 0.8433 - val_loss: 0.1984 - val_accuracy: 0.9632\n", 655 | "Epoch 112/1000\n", 656 | " 1/27 [>.............................] 
- ETA: 0s - loss: 0.4205 - accuracy: 0.8672\n", 657 | "Epoch 00112: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 658 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4182 - accuracy: 0.8397 - val_loss: 0.1973 - val_accuracy: 0.9614\n", 659 | "Epoch 113/1000\n", 660 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4995 - accuracy: 0.8125\n", 661 | "Epoch 00113: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 662 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4042 - accuracy: 0.8478 - val_loss: 0.1922 - val_accuracy: 0.9650\n", 663 | "Epoch 114/1000\n", 664 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3640 - accuracy: 0.8750\n", 665 | "Epoch 00114: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 666 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4072 - accuracy: 0.8430 - val_loss: 0.1868 - val_accuracy: 0.9596\n", 667 | "Epoch 115/1000\n", 668 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5042 - accuracy: 0.8203\n", 669 | "Epoch 00115: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 670 | "27/27 [==============================] - 0s 3ms/step - loss: 0.3999 - accuracy: 0.8535 - val_loss: 0.1966 - val_accuracy: 0.9605\n", 671 | "Epoch 116/1000\n", 672 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5355 - accuracy: 0.7422\n", 673 | "Epoch 00116: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 674 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4174 - accuracy: 0.8433 - val_loss: 0.1919 - val_accuracy: 0.9659\n", 675 | "Epoch 117/1000\n", 676 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3729 - accuracy: 0.8750\n", 677 | "Epoch 00117: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 678 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4101 - accuracy: 0.8451 - val_loss: 0.1932 - val_accuracy: 0.9578\n", 679 | "Epoch 118/1000\n", 680 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3193 - accuracy: 0.8828\n", 681 | "Epoch 00118: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 682 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4260 - accuracy: 0.8379 - val_loss: 0.1865 - val_accuracy: 0.9641\n", 683 | "Epoch 119/1000\n", 684 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3852 - accuracy: 0.8438\n", 685 | "Epoch 00119: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 686 | "27/27 [==============================] - 0s 3ms/step - loss: 0.3797 - accuracy: 0.8622 - val_loss: 0.1900 - val_accuracy: 0.9677\n", 687 | "Epoch 120/1000\n", 688 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3636 - accuracy: 0.8594\n", 689 | "Epoch 00120: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 690 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4017 - accuracy: 0.8460 - val_loss: 0.1908 - val_accuracy: 0.9659\n", 691 | "Epoch 121/1000\n", 692 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4521 - accuracy: 0.8359\n", 693 | "Epoch 00121: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 694 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4023 - accuracy: 0.8538 - val_loss: 0.1935 - val_accuracy: 0.9659\n", 695 | "Epoch 122/1000\n", 696 | " 1/27 [>.............................] 
- ETA: 0s - loss: 0.4849 - accuracy: 0.8203\n", 697 | "Epoch 00122: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 698 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4186 - accuracy: 0.8457 - val_loss: 0.1937 - val_accuracy: 0.9659\n", 699 | "Epoch 123/1000\n", 700 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4021 - accuracy: 0.8516\n", 701 | "Epoch 00123: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 702 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4156 - accuracy: 0.8478 - val_loss: 0.1907 - val_accuracy: 0.9632\n", 703 | "Epoch 124/1000\n", 704 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3601 - accuracy: 0.8906\n", 705 | "Epoch 00124: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 706 | "27/27 [==============================] - 0s 3ms/step - loss: 0.3948 - accuracy: 0.8550 - val_loss: 0.1862 - val_accuracy: 0.9605\n", 707 | "Epoch 125/1000\n", 708 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4446 - accuracy: 0.7891\n", 709 | "Epoch 00125: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 710 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4152 - accuracy: 0.8520 - val_loss: 0.1888 - val_accuracy: 0.9623\n", 711 | "Epoch 126/1000\n", 712 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3733 - accuracy: 0.8438\n", 713 | "Epoch 00126: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 714 | "27/27 [==============================] - 0s 3ms/step - loss: 0.3913 - accuracy: 0.8550 - val_loss: 0.1937 - val_accuracy: 0.9632\n", 715 | "Epoch 127/1000\n", 716 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3000 - accuracy: 0.8828\n", 717 | "Epoch 00127: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n", 718 | "27/27 [==============================] - 0s 3ms/step - loss: 0.3820 - accuracy: 0.8583 - val_loss: 0.1867 - val_accuracy: 0.9632\n", 719 | "Epoch 00127: early stopping\n" 720 | ] 721 | }, 722 | { 723 | "data": { 724 | "text/plain": [ 725 | "" 726 | ] 727 | }, 728 | "execution_count": 11, 729 | "metadata": {}, 730 | "output_type": "execute_result" 731 | } 732 | ], 733 | "source": [ 734 | "model.fit(\n", 735 | " X_train,\n", 736 | " y_train,\n", 737 | " epochs=1000,\n", 738 | " batch_size=128,\n", 739 | " validation_data=(X_test, y_test),\n", 740 | " callbacks=[cp_callback, es_callback]\n", 741 | ")" 742 | ] 743 | }, 744 | { 745 | "cell_type": "code", 746 | "execution_count": 12, 747 | "metadata": {}, 748 | "outputs": [ 749 | { 750 | "name": "stdout", 751 | "output_type": "stream", 752 | "text": [ 753 | "9/9 [==============================] - 0s 1ms/step - loss: 0.1867 - accuracy: 0.9632\n" 754 | ] 755 | } 756 | ], 757 | "source": [ 758 | "# モデル評価\n", 759 | "val_loss, val_acc = model.evaluate(X_test, y_test, batch_size=128)" 760 | ] 761 | }, 762 | { 763 | "cell_type": "code", 764 | "execution_count": 13, 765 | "metadata": {}, 766 | "outputs": [], 767 | "source": [ 768 | "# 保存したモデルのロード\n", 769 | "model = tf.keras.models.load_model(model_save_path)" 770 | ] 771 | }, 772 | { 773 | "cell_type": "code", 774 | "execution_count": 14, 775 | "metadata": {}, 776 | "outputs": [ 777 | { 778 | "name": "stdout", 779 | "output_type": "stream", 780 | "text": [ 781 | "[0.77297777 0.1697358 0.05728642]\n", 782 | "0\n" 783 | ] 784 | } 785 | ], 786 | "source": [ 787 | "# 推論テスト\n", 788 | "predict_result = 
model.predict(np.array([X_test[0]]))\n", 789 | "print(np.squeeze(predict_result))\n", 790 | "print(np.argmax(np.squeeze(predict_result)))" 791 | ] 792 | }, 793 | { 794 | "cell_type": "markdown", 795 | "metadata": {}, 796 | "source": [ 797 | "# 混同行列" 798 | ] 799 | }, 800 | { 801 | "cell_type": "code", 802 | "execution_count": 15, 803 | "metadata": {}, 804 | "outputs": [ 805 | { 806 | "data": { 807 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZUAAAFmCAYAAAClXQeMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAfJklEQVR4nO3de5yWdZn48c81DKiA5llhYMWUPGQbtumalBnkibXQLQlfmSc22s3MyvLQK9d0s7VNM63NbVwPuJsoWaRrHlP8mZkgChKndPIQIIh4BExl5vn+/phbeyJmHmTumeeeez5vX/eL5/k+9+Ea58VcXN/re98TKSUkScpDQ70DkCSVh0lFkpQbk4okKTcmFUlSbkwqkqTcmFQkSbkxqUhSHxMR/SJiTkTckr3fNSJmRkRLRNwQEQOy8c2y9y3Z5yNqnbuxm2Nn3aonvBGmwAY1HVTvENQJ7yMrvnVvLIvczpXDz8v+279zY+I5DVgEbJW9/w5wSUrp+oj4L2AScHn254sppd0jYmK236c6O7GViiQVRaWt61sNETEM+Afgv7P3AYwBbsx2mQIclb0en70n+3xstn+HTCqS1Ld8HzgDqGTvtwNeSim1Zu+XAk3Z6yZgCUD2+cvZ/h0yqUhSUaRKl7eImBwRs6u2yW+ePiKOBFamlB7uri+h23sqkqSNVKnU3qeGlFIz0NzBx6OBj0fEOGBz2nsqlwJbR0RjVo0MA5Zl+y8DhgNLI6IReAfwfGfXt1KRpIJIqdLlrfPzp7NTSsNSSiOAicA9KaVPAzOAT2a7nQDclL2+OXtP9vk9qcbqEZOKJOlM4CsR0UJ7z+TKbPxKYLts/CvAWbVO5PSXJBVFDtNfGyuldC9wb/b6CWD/DezzGnDM2zmvSUWSiqLG9FVvYFKRpKLYiPtMis6kIklFUYJKxUa9JCk3ViqSVBQ92KjvLiYVSSqIWveZ9AYmFUkqCisVSVJuSlCp2KiXJOXGSkWSisL7VCRJuSnB9JdJRZKKogSNensqkqTcWKlIUlE4/SVJyk0Jpr9MKpJUECm5+kuSlJcSTH/ZqJck5cZKRZKKwp6KJCk3JZj+MqlIUlH4mBZJUm5KUKnYqJck5cZKRZKKwka9JCk3JZj+MqlIUlGUoFKxpyJJyo2ViiQVRQkqFZOKJBWED5SUJOXHSkWSlJsSrP6yUS9Jyo1JRZKKolLp+taJiNg8ImZFxKMRsSAizsvGr4mIJyNibraNysYjIi6LiJaImBcR76v1JTj9JUlF0f3TX68DY1JKayKiP3B/RNyWffa1lNKN6+1/BDAy2/4euDz7s0MmFUkqim5u1KeUErAme9s/21Inh4wHrs2OezAito6IISml5R0d4PSXJBVFqnR5i4jJETG7aptcfYmI6BcRc4GVwF0ppZnZRxdkU1yXRMRm2VgTsKTq8KXZWIesVCSpRFJKzUBzJ5+3AaMiYmtgekTsA5wNrAAGZMeeCZy/Kde3UpGkoujmRn21lNJLwAzg8JTS8tTudeBqYP9st2XA8KrDhmVjHTKpSFJRdP/qrx2yCoWI2AI4BFgcEUOysQCOAuZnh9wMHJ+tAjsAeLmzfgo4/SVJxdH9q7+GAFMioh/tRcW0lNItEXFPROwABDAX+Ods/1uBcUAL8CpwUq0LmFQkqY9IKc0D9t3A+JgO9k/AKW/nGk5/bUBbWxufPPEUPv+1cwFY+swKjv3slzhiwsmcfs6/s27dur/Y/64Z97PP6COYv+ixeoSrKg0NDcyaeTvTp19T71BUZdiwodx150959NEZzJ17D6d+YVK9QyqmHuypdBeTygb8709v4p0j/uat95dcfhWf+dRR3DbtKrbacjA/u+WOtz5bu/ZV/venN/G3e+9Rj1C1nlNPncTixS31DkPraW1t5YwzzuO97/0IH/zgx/jnfzmRvfYaWe+wiieHJcX1ZlJZz4qVz3HfA7P4xMcOAyClxMyHH+XQgz8EwPhxH+We+3771v4/uOJaTj7uGAZsNqAu8erPmpqGcMQRY7nq6uvqHYrWs2LFSubMbe/9rlmzlsWLH2fo0J3rHFUB9YVKJSL2jIgzs+e/XJa93qsngquH71z6Y77y+UlEtP+veenlV9hy8CAaG/sBsNMO27PyuecBWPj7FlasXMWHD9y/w/Op51x80Tc5++wLqFQ6u0FY9bbLLsMY9d59mDVrTr1DKZ6yVyoRcSZwPe0rAmZlWwBTI+Ks7g+vZ937m5lsu83WvHvP2mV5pVLhP37QzNdO/WwPRKZaxo0by8rnVjFnzu/qHYo6MWjQQKbdcAWnf/VcVq9eU/sA9Tq1Vn9NAt6dUvqLznREfA9YAFy4oYOyxwJMBvjRxd/in44/NodQu9+ceQu59/4H+fVvH+L1N9axdu2rXPj9/2L1mrW0trbR2NiPZ59bxY47bMfaV/9EyxNPc9IXzgBg1QsvcuqZ5/GD75zLPnu9q85fSd9z4Af248h/OJTDDxvD5ptvxlZbbck1V1/GiSd9sd6hKdPY2Mi0G65g6tTp/OIXt9U+oC8qwPRVV0X7irEOPoxYDByWUnp6vfFdgDtTSjW70+tWPdEr5yJmPTKPa6b+jB999zy+8o0L+OjBoxn30YM57z9+wB6778rEfzzyL/Y/8Qtn8NVT/qnXJZRBTQfVO4TcHXTQB/jylz/H0UefWO9Quqyzv5+9zdVXXcoLL7zE6V89t96h5GrdG8sir3P9adr5Xf6GbzHhX3OLZ1PU6ql8Cbg7Im6LiOZsux24Gzit26MriC//y8lce/10jphwMi+/8gr/eOSh9Q5J6lVGH7gfxx33ST7ykQOZ/dCdzH7oTg4/fIO3RvRtKXV9q7NOKxWAaO9Y78+fn0y5DHgoeyhZTb21UukryliplEmZKpWyyrVSmXpu1yuVY8+ra6VS8476lFIFeLAHYpEk9XI+pkWSiqIEjXqTiiQVRQHuM+kqk4okFUUJKhUf0yJJyo2ViiQVRQlW+5lUJKkoSjD9ZVKRpKIwqUiSclOC1V826iVJubFSkaSCSCX4XUAmFUkqCnsqkqTclKCnYlKRpKIowfSXjXpJUm6sVCSpKOypSJJyY1KRJOWmBM/+sqciScqNlYokFYXTX5Kk3JRgSbFJRZKKogQ3P9p
TkaSiqKSub52IiM0jYlZEPBoRCyLivGx814iYGREtEXFDRAzIxjfL3rdkn4+o9SWYVCSp73gdGJNSei8wCjg8Ig4AvgNcklLaHXgRmJTtPwl4MRu/JNuvUyYVSSqIVKl0eev0/O3WZG/7Z1sCxgA3ZuNTgKOy1+Oz92Sfj42I6OwaJhVJKopunv4CiIh+ETEXWAncBfwBeCml1JrtshRoyl43AUsAss9fBrbr7PwmFUkqilTp8hYRkyNidtU2+S8ukVJbSmkUMAzYH9gzzy/B1V+SVCIppWageSP2eykiZgAfALaOiMasGhkGLMt2WwYMB5ZGRCPwDuD5zs5rpSJJRdH9q792iIits9dbAIcAi4AZwCez3U4Abspe35y9J/v8npQ6f5aMlYokFUX331E/BJgSEf1oLyqmpZRuiYiFwPUR8S1gDnBltv+VwP9ERAvwAjCx1gVMKpJUFN18R31KaR6w7wbGn6C9v7L++GvAMW/nGiYVSSoK76iXJOnPrFQkqSh8oKQkKS+17ojvDUwqklQUViqSpNyUIKnYqJck5cZKRZKKogRLik0qklQUJZj+MqlIUkGkEiQVeyqSpNxYqUhSUZSgUjGpSFJRePOjJCk3ViqSpNyUIKnYqJck5cZKRZIKosZv6u0VTCqSVBQlmP4yqUhSUZhUahs49EPdfQl1wSu3nVvvENSJkcdcVu8Q1IO8o16SpCpOf0lSUZSgUjGpSFJR9P4b6k0qklQU9lQkSapipSJJRVGCSsWkIklFYU9FkpSXMvRUTCqSVBQlqFRs1EuScmOlIkkFUYbpLysVSSqKSg5bJyJieETMiIiFEbEgIk7Lxr8ZEcsiYm62jas65uyIaImI30fEYbW+BCsVSSqI1P09lVbg9JTSIxGxJfBwRNyVfXZJSumi6p0jYm9gIvBuYCjwq4h4V0qpraMLWKlIUlF0c6WSUlqeUnoke70aWAQ0dXLIeOD6lNLrKaUngRZg/86uYVKRpD4oIkYA+wIzs6EvRMS8iLgqIrbJxpqAJVWHLaXzJGRSkaSiSJWubxExOSJmV22T179ORAwGfgZ8KaX0CnA5sBswClgOXLypX4M9FUkqihx6KimlZqC5o88joj/tCeUnKaWfZ8c8W/X5FcAt2dtlwPCqw4dlYx2yUpGkgsijUulMRARwJbAopfS9qvEhVbsdDczPXt8MTIyIzSJiV2AkMKuza1ipSFLfMRr4DPC7iJibjX0dODYiRgEJeAr4HEBKaUFETAMW0r5y7JTOVn6BSUWSCqO7lxSnlO4HYgMf3drJMRcAF2zsNUwqklQQPXCfSrczqUhSUaQNFRG9i0lFkgqiDJWKq78kSbmxUpGkgkgVp78kSTkpw/SXSUWSCiLZqJck5aUMlYqNeklSbqxUJKkgbNRLknKTev+vqDepSFJRlKFSsaciScqNlYokFUQZKhWTiiQVhD0VSVJurFQkSbkpwx31NuolSbmxUpGkgijDY1pMKpJUEJUSTH+ZVCSpIMrQUzGpSFJBlGH1l416SVJurFQkqSC8+VGSlJsyTH+ZVCSpIMqw+sueiiQpN1YqklQQLimWJOXGRn0fMmzYUK6+6lJ23Gl7Ukpc+d8/4Qc/vLLeYfU5r69r5eTv3cC61jZaKxU+uu9IPn/kaM659nYefnwJg7fYDIDzP3M4ew7f8a3j5j+1ghMuuo4LTz6SQ973rnqF36cMadqZS3/0bbbfcTtSSlw35Uau/PH/ste79+DC753DoEEDWfLHZzj1c2eyZvXaeodbCGXoqZhUNlJraytnnHEec+bOZ/DgQcyceTu/uvs+Fi16vN6h9SkDGvtxxWnHMHDzAaxra+Oki6/ng+/eFYAvH/3hDSaMtkqFS39xHwfsNaKHo+3b2lpbOf+c7zJ/3iIGDR7IbfdM4757H+C7l57Ht/71Ih58YDaf+vTR/POpJ3HRt39Y73ALoQzTXzbqN9KKFSuZM3c+AGvWrGXx4scZOnTnOkfV90QEAzcfAEBrW4XWtgpB538Rp947h7H7jmTbLQf2RIjKrHx2FfPnLQJg7ZpXefyxJ9h5yE68c/ddePCB2QDcd+9vGfexQ+oZpnK2yUklIk7KM5DeZJddhjHqvfswa9aceofSJ7VVKkz49rWMOfNyDthzF96z6xAAfnjz/RzzrSl898YZvLGuFYBnX1rNjLktTPjQqDpGrGHDh7LP3+7FnIfn8djiP3DYuDEAHDn+UP9xViWlrm+diYjhETEjIhZGxIKIOC0b3zYi7oqIx7M/t8nGIyIui4iWiJgXEe+r9TV0pVI5rwvH9lqDBg1k2g1XcPpXz2X16jX1DqdP6tfQwLSvH88dF0xm/lMraHlmFV8c/0F+ce5J/OTMT/Py2te4+q6HAPjuT+/ltKM/REND759W6K0GDtqC5imX8M2vf4c1q9dy+qnncPykidx6zw0MHjyIdevW1TvEwqik6PJWQytwekppb+AA4JSI2Bs4C7g7pTQSuDt7D3AEMDLbJgOX17pApz2ViJjX0UfATp0cNzkLgIZ+76ChYVCtOHqFxsZGpt1wBVOnTucXv7it3uH0eVsN3Jz99hjObxY8yQmH7AfAgP6NjP/APlz7q/bplYV/XMGZV/4SgJfW/on75z9Bv4ZgzKiRdYu7L2lsbKR5yveZfuMvue2WXwHwh8ef5NOfmAzArrvtwthDDqpniIXS3T2VlNJyYHn2enVELAKagPHAwdluU4B7gTOz8WtTSgl4MCK2jogh2Xk2qFajfifgMODF9cYDeKCTwJuBZoD+A5pKsEiu3RXNF7N4cQvfv7S53qH0WS+sfpXGfg1sNXBzXntjHQ8uepqTDt2P515eww7vGExKiRmPtrD70O0AuPXfPvvWsedcezsH7fNOE0oPuuiy82l57Amu+NG1b41tt/22PL/qBSKC007/HP9zzbQ6Rlgseaz+qv5HfaY5+5m8/n4jgH2BmcBOVYliBX8uGpqAJVWHLc3GNjmp3AIMTinN3UBA99Y4tlRGH7gfxx33SX73u4XMfuhOAL5xzoXcfvs9dY6sb1n18lrOufY2KpVEJSUO/bs9OOg9u/HZ70/jxTV/IqXEHsN25BvHfrTeofZ5+/39vnxy4sdZtOAx7vh/NwLwnX+7lF1324UTJk0E4LZbfsUNP5lezzBLp/of9R2JiMHAz4AvpZReifhzMksppYjY5GIgUjffbVOmSqWMXrnt3HqHoE6MPOayeoegGpa+MD+3OasHh/5jl39eHvDMzzuNJyL6014w3JFS+l429nvg4JTS8ogYAtybUtojIn6cvZ66/n4dnd8lxZJUEN3dqI/2kuRKYNGbCSVzM3BC9voE4Kaq8eOzVWAHAC93llDAmx8lqTB64ObH0cBngN9FxNxs7OvAhcC0iJgEPA1MyD67FRgHtACvAjVvJTGpSFIfkVK6Hzq8W3jsBvZPwClv5xomFUkqiEq9A8iBSUWSCiLVeORQb2BSkaSCqJRgraxJRZIKolKCSsUlxZKk3FipSFJB2FORJOXG1V+SpNyUoVKxpyJJyo2ViiQVhNNfkq
TcmFQkSbkpQ0/FpCJJBVHp/TnFRr0kKT9WKpJUEGV4TItJRZIKogTPkzSpSFJRuPpLkpSbSvT+6S8b9ZKk3FipSFJB2FORJOXGnookKTfe/ChJUhUrFUkqCG9+lCTlxka9JCk3ZeipmFQkqSDKsPrLRr0kKTdWKpJUEPZUJEm5saciScqNPRVJUm4qOWy1RMRVEbEyIuZXjX0zIpZFxNxsG1f12dkR0RIRv4+Iw2qd36QiSX3LNcDhGxi/JKU0KttuBYiIvYGJwLuzY34UEf06O7lJRZIKIkXXt5rXSOk+4IWNDGk8cH1K6fWU0pNAC7B/ZweYVCSpIHpi+qsTX4iIedn02DbZWBOwpGqfpdlYh0wqklQQeSSViJgcEbOrtskbcenLgd2AUcBy4OJN/Rpc/SVJJZJSagaa3+Yxz775OiKuAG7J3i4DhlftOiwb65CViiQVRMph2xQRMaTq7dHAmyvDbgYmRsRmEbErMBKY1dm5rFQkqSB64ubHiJgKHAxsHxFLgXOBgyNiFO156SngcwAppQURMQ1YCLQCp6SU2jo7v0lFkgqiJ25+TCkdu4HhKzvZ/wLggo09v0lFkgrCO+olSapipSJJBeFTiiVJufEpxZKk3JShp2JSkaSCKMP0l416SVJuur1SKUPmLbN3jDu/3iGoEy9d9PF6h6AeVCnBT0ynvySpIOypSJJy0/vrFHsqkqQcWalIUkE4/SVJyo03P0qScuPqL0lSbnp/SrFRL0nKkZWKJBWEjXpJUm7sqUiSctP7U4pJRZIKowzTXzbqJUm5sVKRpIKwpyJJyk3vTykmFUkqDHsqkiRVsVKRpIJIJZgAM6lIUkGUYfrLpCJJBeHqL0lSbnp/SrFRL0nKkZWKJBVEGaa/rFQkqSAqOWy1RMRVEbEyIuZXjW0bEXdFxOPZn9tk4xERl0VES0TMi4j31Tq/SUWSCiLl8N9GuAY4fL2xs4C7U0ojgbuz9wBHACOzbTJwea2Tm1QkqSB6olJJKd0HvLDe8HhgSvZ6CnBU1fi1qd2DwNYRMaSz85tUJEk7pZSWZ69XADtlr5uAJVX7Lc3GOmRSkaSCyGP6KyImR8Tsqm3y24ohpUQXVje7+kuSCiKPO+pTSs1A89s87NmIGJJSWp5Nb63MxpcBw6v2G5aNdchKRZIKopJSl7dNdDNwQvb6BOCmqvHjs1VgBwAvV02TbZCViiT1IRExFTgY2D4ilgLnAhcC0yJiEvA0MCHb/VZgHNACvAqcVOv8JhVJKoieuPUxpXRsBx+N3cC+CTjl7ZzfpCJJBVGGO+pNKpJUEP4+FUlSbsrw+1Rc/SVJyo2ViiQVhD0VSVJu7KlIknJThp6KSUWSCiJt+h3xhWGjXpKUGysVSSoIG/WSpNzYU5Ek5aYMq7/sqUiScmOlIkkFYU9FkpSbMiwpNqlIUkHYqJck5cZGfR9z2KEHs2D+fSxeeD9nfO1t/TI09ZCGhgZmzbyd6dOvqXcofdLrrW0cN+0hJkydySeue5DLZz4BwPXzlvDx/3mAfX94Ny/+6Y2/Om7Bs6/w/v+8h7tanu3pkJUzK5WN1NDQwGWXXsDh445l6dLlPPjbW/m/W+5k0aLH6x2aqpx66iQWL25hy60G1zuUPmlAvwaaj9qXgQMaWddW4eSfP8zoXbZj1JCtOWjE9vzT9Ef+6pi2SuLSB1o44G+2rUPExVKGRn3NSiUi9oyIsRExeL3xw7svrOLZf799+cMfnuLJJ//IunXrmDbtJj7+scPqHZaqNDUN4YgjxnLV1dfVO5Q+KyIYOKD936qtlURrJRHAnjtsydCtttjgMdfPW8LY3XZg2y0G9GCkxZRS6vJWb50mlYj4InATcCowPyLGV3387e4MrGiGNu3MkqXPvPV+6bLlDB26cx0j0vouvuibnH32BVQq9f+L1Ze1VRKfun4mY6/6NQcM35b37PyODvddueY17nniOY55z7AejLC4KqQub/VWq1L5LPB3KaWjgIOBcyLitOyz6OigiJgcEbMjYnalsjaXQKXOjBs3lpXPrWLOnN/VO5Q+r19DcMPEv+eOE0cz/9mXaXl+TYf7fvfXj3PagbvTEB3+OFEvU6un0pBSWgOQUnoqIg4GboyIXegkqaSUmoFmgMYBTfVPnTl4ZtkKhg8b+tb7YU1DeOaZFXWMSNUO/MB+HPkPh3L4YWPYfPPN2GqrLbnm6ss48aQv1ju0PmvLzfrz/qZteODp59l9uw33uBaufIWz7pgPwEuvreP+p1fR2NDAR965Q0+GWhh9YfXXsxEx6s03WYI5EtgeeE83xlU4D82ey+6778qIEcPp378/EyaM5/9uubPeYSnzjXMu5J277ce79vgAx33mFGbc+xsTSh288Kc3WP36OgBea21j5pIXGLHNoA73/+UJo7k12z66246c/eE9+mxCAaik1OWt3mpVKscDrdUDKaVW4PiI+HG3RVVAbW1tnPalb3DrL6+jX0MD10y5gYULH6t3WFKhrFr7Ov/6q4VUUvsPyEN235GDdt2e6x5dwpRHnub5V99gwtSZfHDE9pw7Zq96h1s49U8JXRfdvVqgLNNfZeVcdrG9dNHH6x2Cahh46o9y+0s0umlMl39e/mbZPXX9S+3Nj5Kk3HjzoyQVRBGWBHeVSUWSCqIINy92lUlFkgrCSkWSlJu+cJ+KJEkbzUpFkgqiJ3oqEfEUsBpoA1pTSu+PiG2BG4ARwFPAhJTSi5tyfisVSSqIHnyg5EdSSqNSSu/P3p8F3J1SGgncnb3fJCYVSSqIOj76fjwwJXs9BThqU09kUpGkEql+Sny2TV5vlwTcGREPV322U0ppefZ6BbDTpl7fnookFUQeS4qrnxLfgQ+mlJZFxI7AXRGxeL3jU0RsciAmFUkqiJ5YUpxSWpb9uTIipgP70/5E+iEppeURMQRYuannd/pLkgqiux99HxGDImLLN18DhwLzgZuBE7LdTqD9N/5uEisVSSqIHqhUdgKmR/vTyRuB61JKt0fEQ8C0iJgEPA1M2NQLmFQkqY9IKT0BvHcD488DY/O4hklFkgqiCL+5satMKpJUEGV49pdJRZIKwkpFkpSbMlQqLimWJOXGSkWSCsLpL0lSbsow/WVSkaSCSKlS7xC6zJ6KJCk3ViqSVBB5PKW43kwqklQQPfHrhLubSUWSCsJKRZKUmzJUKjbqJUm5sVKRpILw5kdJUm68+VGSlJsy9FRMKpJUEGVY/WWjXpKUGysVSSoIp78kSblx9ZckKTdlqFTsqUiScmOlIkkFUYbVXyYVSSqIMkx/mVQkqSBs1EuSclOGx7TYqJck5cZKRZIKwukvSVJubNRLknJjT0WSlJuUUpe3WiLi8Ij4fUS0RMRZeX8NJhVJ6iMioh/wn8ARwN7AsRGxd57XcPpLkgqiB3oq+wMtKaUnACLiemA8sDCvC1ipSFJBpBy2GpqAJVXvl2Zjuen2SqX1jWXR3dfoSRExOaXUXO84tGF+f4rP71HH8vh5GRGTgclVQ809+f/bSuXtm1x7F9WR35/i83vUjVJKzSml91dt1QllG
TC86v2wbCw3JhVJ6jseAkZGxK4RMQCYCNyc5wVs1EtSH5FSao2ILwB3AP2Aq1JKC/K8hknl7XMuuNj8/hSf36M6SindCtzaXeePMjwWQJJUDPZUJEm5Mam8Dd39eANtuoi4KiJWRsT8eseivxYRwyNiRkQsjIgFEXFavWNS93D6ayNljzd4DDiE9huGHgKOTSnldieqNl1EHASsAa5NKe1T73j0lyJiCDAkpfRIRGwJPAwc5d+f8rFS2XhvPd4gpfQG8ObjDVQAKaX7gBfqHYc2LKW0PKX0SPZ6NbCInO/kVjGYVDZetz/eQOoLImIEsC8ws86hqBuYVCT1mIgYDPwM+FJK6ZV6x6P8mVQ2Xrc/3kAqs4joT3tC+UlK6ef1jkfdw6Sy8br98QZSWUVEAFcCi1JK36t3POo+JpWNlFJqBd58vMEiYFrejzfQpouIqcBvgT0iYmlETKp3TPoLo4HPAGMiYm62jat3UMqfS4olSbmxUpEk5cakIknKjUlFkpQbk4okKTcmFUlSbkwqkqTcmFQkSbkxqUiScvP/AVplmVKXNqIBAAAAAElFTkSuQmCC\n", 808 | "text/plain": [ 809 | "
" 810 | ] 811 | }, 812 | "metadata": { 813 | "needs_background": "light" 814 | }, 815 | "output_type": "display_data" 816 | }, 817 | { 818 | "name": "stdout", 819 | "output_type": "stream", 820 | "text": [ 821 | "Classification Report\n", 822 | " precision recall f1-score support\n", 823 | "\n", 824 | " 0 1.00 0.99 0.99 410\n", 825 | " 1 0.98 0.92 0.95 385\n", 826 | " 2 0.91 0.99 0.95 318\n", 827 | "\n", 828 | " accuracy 0.96 1113\n", 829 | " macro avg 0.96 0.96 0.96 1113\n", 830 | "weighted avg 0.96 0.96 0.96 1113\n", 831 | "\n" 832 | ] 833 | } 834 | ], 835 | "source": [ 836 | "import pandas as pd\n", 837 | "import seaborn as sns\n", 838 | "import matplotlib.pyplot as plt\n", 839 | "from sklearn.metrics import confusion_matrix, classification_report\n", 840 | "\n", 841 | "def print_confusion_matrix(y_true, y_pred, report=True):\n", 842 | " labels = sorted(list(set(y_true)))\n", 843 | " cmx_data = confusion_matrix(y_true, y_pred, labels=labels)\n", 844 | " \n", 845 | " df_cmx = pd.DataFrame(cmx_data, index=labels, columns=labels)\n", 846 | " \n", 847 | " fig, ax = plt.subplots(figsize=(7, 6))\n", 848 | " sns.heatmap(df_cmx, annot=True, fmt='g' ,square=False)\n", 849 | " ax.set_ylim(len(set(y_true)), 0)\n", 850 | " plt.show()\n", 851 | " \n", 852 | " if report:\n", 853 | " print('Classification Report')\n", 854 | " print(classification_report(y_test, y_pred))\n", 855 | "\n", 856 | "Y_pred = model.predict(X_test)\n", 857 | "y_pred = np.argmax(Y_pred, axis=1)\n", 858 | "\n", 859 | "print_confusion_matrix(y_test, y_pred)" 860 | ] 861 | }, 862 | { 863 | "cell_type": "markdown", 864 | "metadata": {}, 865 | "source": [ 866 | "# Tensorflow-Lite用のモデルへ変換" 867 | ] 868 | }, 869 | { 870 | "cell_type": "code", 871 | "execution_count": 16, 872 | "metadata": {}, 873 | "outputs": [], 874 | "source": [ 875 | "# 推論専用のモデルとして保存\n", 876 | "model.save(model_save_path, include_optimizer=False)" 877 | ] 878 | }, 879 | { 880 | "cell_type": "code", 881 | "execution_count": 17, 882 | "metadata": {}, 883 | "outputs": [ 884 | { 885 | "name": "stdout", 886 | "output_type": "stream", 887 | "text": [ 888 | "WARNING:tensorflow:From d:\\00.envs\\20201208_mediapipe\\lib\\site-packages\\tensorflow\\python\\training\\tracking\\tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n", 889 | "Instructions for updating:\n", 890 | "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n", 891 | "WARNING:tensorflow:From d:\\00.envs\\20201208_mediapipe\\lib\\site-packages\\tensorflow\\python\\training\\tracking\\tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n", 892 | "Instructions for updating:\n", 893 | "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n", 894 | "INFO:tensorflow:Assets written to: C:\\Users\\sihit\\AppData\\Local\\Temp\\tmpy2l6ipxu\\assets\n" 895 | ] 896 | }, 897 | { 898 | "data": { 899 | "text/plain": [ 900 | "6224" 901 | ] 902 | }, 903 | "execution_count": 17, 904 | "metadata": {}, 905 | "output_type": "execute_result" 906 | } 907 | ], 908 | "source": [ 909 | "# モデルを変換(量子化)\n", 910 | "tflite_save_path = 'model/keypoint_classifier/keypoint_classifier.tflite'\n", 911 | "\n", 912 | "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", 913 | "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", 914 | "tflite_quantized_model = converter.convert()\n", 915 | 
"\n", 916 | "open(tflite_save_path, 'wb').write(tflite_quantized_model)" 917 | ] 918 | }, 919 | { 920 | "cell_type": "markdown", 921 | "metadata": {}, 922 | "source": [ 923 | "# 推論テスト" 924 | ] 925 | }, 926 | { 927 | "cell_type": "code", 928 | "execution_count": 18, 929 | "metadata": {}, 930 | "outputs": [], 931 | "source": [ 932 | "interpreter = tf.lite.Interpreter(model_path=tflite_save_path)\n", 933 | "interpreter.allocate_tensors()" 934 | ] 935 | }, 936 | { 937 | "cell_type": "code", 938 | "execution_count": 19, 939 | "metadata": {}, 940 | "outputs": [], 941 | "source": [ 942 | "# 入出力テンソルを取得\n", 943 | "input_details = interpreter.get_input_details()\n", 944 | "output_details = interpreter.get_output_details()" 945 | ] 946 | }, 947 | { 948 | "cell_type": "code", 949 | "execution_count": 20, 950 | "metadata": {}, 951 | "outputs": [], 952 | "source": [ 953 | "interpreter.set_tensor(input_details[0]['index'], np.array([X_test[0]]))" 954 | ] 955 | }, 956 | { 957 | "cell_type": "code", 958 | "execution_count": 21, 959 | "metadata": { 960 | "scrolled": true 961 | }, 962 | "outputs": [ 963 | { 964 | "name": "stdout", 965 | "output_type": "stream", 966 | "text": [ 967 | "Wall time: 0 ns\n" 968 | ] 969 | } 970 | ], 971 | "source": [ 972 | "%%time\n", 973 | "# 推論実施\n", 974 | "interpreter.invoke()\n", 975 | "tflite_results = interpreter.get_tensor(output_details[0]['index'])" 976 | ] 977 | }, 978 | { 979 | "cell_type": "code", 980 | "execution_count": 22, 981 | "metadata": {}, 982 | "outputs": [ 983 | { 984 | "name": "stdout", 985 | "output_type": "stream", 986 | "text": [ 987 | "[0.7729778 0.16973573 0.05728643]\n", 988 | "0\n" 989 | ] 990 | } 991 | ], 992 | "source": [ 993 | "print(np.squeeze(tflite_results))\n", 994 | "print(np.argmax(np.squeeze(tflite_results)))" 995 | ] 996 | } 997 | ], 998 | "metadata": { 999 | "kernelspec": { 1000 | "display_name": "Python 3", 1001 | "language": "python", 1002 | "name": "python3" 1003 | }, 1004 | "language_info": { 1005 | "codemirror_mode": { 1006 | "name": "ipython", 1007 | "version": 3 1008 | }, 1009 | "file_extension": ".py", 1010 | "mimetype": "text/x-python", 1011 | "name": "python", 1012 | "nbconvert_exporter": "python", 1013 | "pygments_lexer": "ipython3", 1014 | "version": "3.8.5" 1015 | } 1016 | }, 1017 | "nbformat": 4, 1018 | "nbformat_minor": 4 1019 | } 1020 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from model.keypoint_classifier.keypoint_classifier import KeyPointClassifier 2 | from model.point_history_classifier.point_history_classifier import PointHistoryClassifier 3 | from model.palm_detection.palm_detection import PalmDetection 4 | from model.hand_landmark.hand_landmark import HandLandmark -------------------------------------------------------------------------------- /model/hand_landmark/hand_landmark.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import ( 3 | Tuple, 4 | Optional, 5 | List, 6 | ) 7 | import cv2 8 | import onnxruntime 9 | import numpy as np 10 | 11 | from utils.utils import keep_aspect_resize_and_pad 12 | 13 | 14 | class HandLandmark(object): 15 | def __init__( 16 | self, 17 | model_path: Optional[str] = 'model/hand_landmark/hand_landmark_sparse_Nx3x224x224.onnx', 18 | class_score_th: Optional[float] = 0.50, 19 | providers: Optional[List] = [ 20 | # ( 21 | # 'TensorrtExecutionProvider', { 22 | # 
'trt_engine_cache_enable': True, 23 | # 'trt_engine_cache_path': '.', 24 | # 'trt_fp16_enable': True, 25 | # } 26 | # ), 27 | 'CUDAExecutionProvider', 28 | 'CPUExecutionProvider', 29 | ], 30 | ): 31 | """HandLandmark 32 | 33 | Parameters 34 | ---------- 35 | model_path: Optional[str] 36 | ONNX file path for Hand Landmark 37 | 38 | class_score_th: Optional[float] 39 | Score threshold. Default: 0.50 40 | 41 | providers: Optional[List] 42 | Name of onnx execution providers 43 | Default: 44 | [ 45 | ( 46 | 'TensorrtExecutionProvider', { 47 | 'trt_engine_cache_enable': True, 48 | 'trt_engine_cache_path': '.', 49 | 'trt_fp16_enable': True, 50 | } 51 | ), 52 | 'CUDAExecutionProvider', 53 | 'CPUExecutionProvider', 54 | ] 55 | """ 56 | # Threshold 57 | self.class_score_th = class_score_th 58 | 59 | # Model loading 60 | session_option = onnxruntime.SessionOptions() 61 | session_option.log_severity_level = 3 62 | self.onnx_session = onnxruntime.InferenceSession( 63 | model_path, 64 | sess_options=session_option, 65 | providers=providers, 66 | ) 67 | self.providers = self.onnx_session.get_providers() 68 | 69 | self.input_shapes = [ 70 | input.shape for input in self.onnx_session.get_inputs() 71 | ] 72 | self.input_names = [ 73 | input.name for input in self.onnx_session.get_inputs() 74 | ] 75 | self.output_names = [ 76 | output.name for output in self.onnx_session.get_outputs() 77 | ] 78 | 79 | 80 | def __call__( 81 | self, 82 | images: List[np.ndarray], 83 | rects: np.ndarray, 84 | ) -> Tuple[np.ndarray, np.ndarray]: 85 | """HandLandmark 86 | 87 | Parameters 88 | ---------- 89 | images: List[np.ndarray] 90 | Multiple palm images. 91 | 92 | rects: np.ndarray 93 | Coordinates, size and angle of the cut palm. 94 | [boxcount, cx, cy, width, height, angle] 95 | 96 | Returns 97 | ------- 98 | hand_landmarks: np.ndarray 99 | Hand landmarks (X,Y) x 21 100 | 101 | rotated_image_size_leftrights: np.ndarray 102 | Maximum width and height of the perimeter of the rectangle around 103 | which the bounding box of the detected hand is rotated, 104 | and flags for left and right hand 105 | [rotated_image_width, rotated_image_height, left_hand_0_or_right_hand_1] 106 | """ 107 | temp_images = copy.deepcopy(images) 108 | 109 | # PreProcess 110 | inference_images, resized_images, resize_scales_224x224, half_pad_sizes_224x224 = self.__preprocess( 111 | images=temp_images, 112 | ) 113 | 114 | # Inference 115 | xyz_x21s, hand_scores, left_hand_0_or_right_hand_1s = self.onnx_session.run( 116 | self.output_names, 117 | {input_name: inference_images for input_name in self.input_names}, 118 | ) 119 | 120 | # PostProcess 121 | hand_landmarks, rotated_image_size_leftrights = self.__postprocess( 122 | resized_images=resized_images, 123 | resize_scales_224x224=resize_scales_224x224, 124 | half_pad_sizes_224x224=half_pad_sizes_224x224, 125 | rects=rects, 126 | xyz_x21s=xyz_x21s, 127 | hand_scores=hand_scores, 128 | left_hand_0_or_right_hand_1s=left_hand_0_or_right_hand_1s, 129 | ) 130 | 131 | return hand_landmarks, rotated_image_size_leftrights 132 | 133 | 134 | def __preprocess( 135 | self, 136 | images: List[np.ndarray], 137 | swap: Optional[Tuple[int,int,int]] = (2, 0, 1), 138 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 139 | """__preprocess 140 | 141 | Parameters 142 | ---------- 143 | images: List[np.ndarray] 144 | Multiple palm images. 
145 | 146 | swap: tuple 147 | HWC to CHW: (2,0,1) 148 | CHW to HWC: (1,2,0) 149 | HWC to HWC: (0,1,2) 150 | CHW to CHW: (0,1,2) 151 | 152 | Returns 153 | ------- 154 | padded_images: np.ndarray 155 | Resized and Padding and normalized image. [N,C,H,W] 156 | """ 157 | temp_images = copy.deepcopy(images) 158 | 159 | # Resize + Padding + Normalization + BGR->RGB 160 | input_h = self.input_shapes[0][2] 161 | input_w = self.input_shapes[0][3] 162 | 163 | padded_images = [] 164 | resized_images = [] 165 | resize_scales_224x224 = [] 166 | half_pad_sizes_224x224 = [] 167 | for image in temp_images: 168 | padded_image, resized_image = keep_aspect_resize_and_pad( 169 | image=image, 170 | resize_width=input_w, 171 | resize_height=input_h, 172 | ) 173 | # reduction_ratio_h = resized_h / original_h 174 | resize_224x224_scale_h = resized_image.shape[0] / image.shape[0] 175 | # reduction_ratio_w = resized_w / original_w 176 | resize_224x224_scale_w = resized_image.shape[1] / image.shape[1] 177 | resize_scales_224x224.append( 178 | [ 179 | resize_224x224_scale_w, 180 | resize_224x224_scale_h, 181 | ] 182 | ) 183 | 184 | pad_h = padded_image.shape[0] - resized_image.shape[0] 185 | pad_w = padded_image.shape[1] - resized_image.shape[1] 186 | half_pad_h_224x224 = pad_h // 2 187 | half_pad_h_224x224 = half_pad_h_224x224 if half_pad_h_224x224 >= 0 else 0 188 | half_pad_w_224x224 = pad_w // 2 189 | half_pad_w_224x224 = half_pad_w_224x224 if half_pad_w_224x224 >= 0 else 0 190 | half_pad_sizes_224x224.append([half_pad_w_224x224, half_pad_h_224x224]) 191 | 192 | padded_image = np.divide(padded_image, 255.0) 193 | padded_image = padded_image[..., ::-1] 194 | padded_image = padded_image.transpose(swap) 195 | padded_image = np.ascontiguousarray( 196 | padded_image, 197 | dtype=np.float32, 198 | ) 199 | 200 | padded_images.append(padded_image) 201 | resized_images.append(resized_image) 202 | 203 | return \ 204 | np.asarray(padded_images, dtype=np.float32), \ 205 | resized_images, \ 206 | np.asarray(resize_scales_224x224, dtype=np.float32), \ 207 | np.asarray(half_pad_sizes_224x224, dtype=np.int32) 208 | 209 | 210 | def __postprocess( 211 | self, 212 | resized_images: List[np.ndarray], 213 | resize_scales_224x224: np.ndarray, 214 | half_pad_sizes_224x224: np.ndarray, 215 | rects: np.ndarray, 216 | xyz_x21s: np.ndarray, 217 | hand_scores: np.ndarray, 218 | left_hand_0_or_right_hand_1s: np.ndarray, 219 | ) -> Tuple[np.ndarray, np.ndarray]: 220 | """__postprocess 221 | 222 | Parameters 223 | ---------- 224 | rects: np.ndarray 225 | [N, cx, cy, (xmax-xmin), (ymax-ymin), degree] 226 | 227 | xyz_x21s: np.ndarray 228 | float32[N, 63] 229 | XYZ coordinates. 21 points. 230 | 231 | hand_scores: np.ndarray 232 | float32[N, 1] 233 | Hand score. 
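                Hands whose score does not exceed class_score_th are filtered out here, before the remaining landmarks are rescaled and rotated back into the original frame.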
234 | 235 | left_hand_0_or_right_hand_1s: np.ndarray 236 | float32[N, 1] 237 | 0: Left hand 238 | 1: Right hand 239 | 240 | Returns 241 | ------- 242 | extracted_hands: np.ndarray 243 | Hand landmarks (X,Y) x 21 244 | 245 | rotated_image_size_leftrights: np.ndarray 246 | Maximum width and height of the perimeter of the rectangle around 247 | which the bounding box of the detected hand is rotated, 248 | and flags for left and right hand 249 | [rotated_image_width, rotated_image_height, left_hand_0_or_right_hand_1] 250 | """ 251 | hand_landmarks = np.asarray([], dtype=np.int32) 252 | extracted_hands = [] 253 | rotated_image_size_leftrights = [] 254 | 255 | keep = hand_scores[:, 0] > self.class_score_th # hand_score > self.class_score_th 256 | xyz_x21s = xyz_x21s[keep, :] 257 | hand_scores = hand_scores[keep, :] 258 | left_hand_0_or_right_hand_1s = left_hand_0_or_right_hand_1s[keep, :] 259 | resized_images = [i for (i, k) in zip(resized_images, keep) if k == True] 260 | 261 | for resized_image, resize_scale_224x224, half_pad_size_224x224, rect, xyz_x21, left_hand_0_or_right_hand_1 in \ 262 | zip(resized_images, resize_scales_224x224, half_pad_sizes_224x224, rects, xyz_x21s, left_hand_0_or_right_hand_1s): 263 | """ 264 | hands: sqn_rr_size, rotation, sqn_rr_center_x, sqn_rr_center_y 265 | cx = int(sqn_rr_center_x * frame_width) 266 | cy = int(sqn_rr_center_y * frame_height) 267 | xmin = int((sqn_rr_center_x - (sqn_rr_size / 2))*w) 268 | xmax = int((sqn_rr_center_x + (sqn_rr_size / 2))*w) 269 | ymin = int((sqn_rr_center_y - (sqn_rr_size * wh_ratio / 2))*h) 270 | ymax = int((sqn_rr_center_y + (sqn_rr_size * wh_ratio / 2))*h) 271 | degree = degrees(rotation) 272 | rect : cx, cy, (xmax-xmin), (ymax-ymin), degree 273 | rotation = radians(degree) 274 | sqn_rr_center_x = cx / frame_width 275 | sqn_rr_center_y = cy / frame_height 276 | """ 277 | rrn_lms = xyz_x21 278 | input_h = self.input_shapes[0][2] 279 | input_w = self.input_shapes[0][3] 280 | rrn_lms = rrn_lms / input_h 281 | 282 | rcx = rect[0] 283 | rcy = rect[1] 284 | angle = rect[4] 285 | 286 | view_image = copy.deepcopy(resized_image) 287 | view_image = cv2.resize( 288 | view_image, 289 | dsize=None, 290 | fx=1/resize_scale_224x224[0], 291 | fy=1/resize_scale_224x224[1], 292 | ) 293 | rescaled_xy = np.asarray([[v[0], v[1]] for v in zip(rrn_lms[0::3], rrn_lms[1::3])], dtype=np.float32) 294 | rescaled_xy[:, 0] = (rescaled_xy[:, 0] * input_w - half_pad_size_224x224[0]) / resize_scale_224x224[0] 295 | rescaled_xy[:, 1] = (rescaled_xy[:, 1] * input_h - half_pad_size_224x224[1]) / resize_scale_224x224[1] 296 | rescaled_xy = rescaled_xy.astype(np.int32) 297 | 298 | height, width = view_image.shape[:2] 299 | image_center = (width//2, height//2) 300 | rotation_matrix = cv2.getRotationMatrix2D(image_center, -int(angle), 1) 301 | abs_cos = abs(rotation_matrix[0,0]) 302 | abs_sin = abs(rotation_matrix[0,1]) 303 | bound_w = int(height * abs_sin + width * abs_cos) 304 | bound_h = int(height * abs_cos + width * abs_sin) 305 | rotation_matrix[0, 2] += bound_w/2 - image_center[0] 306 | rotation_matrix[1, 2] += bound_h/2 - image_center[1] 307 | rotated_image = cv2.warpAffine(view_image, rotation_matrix, (bound_w, bound_h)) 308 | 309 | keypoints = [] 310 | for x,y in rescaled_xy: 311 | coord_arr = np.array([ 312 | [x, y, 1], # Left-Top 313 | ]) 314 | new_coord = rotation_matrix.dot(coord_arr.T) 315 | x_ls = new_coord[0] 316 | y_ls = new_coord[1] 317 | keypoints.append([int(x_ls), int(y_ls)]) 318 | 319 | rotated_image_width = rotated_image.shape[1] 320 | 
rotated_image_height = rotated_image.shape[0] 321 | roatated_hand_half_width = rotated_image_width//2 322 | roatated_hand_half_height = rotated_image_height//2 323 | 324 | hand_landmarks = np.asarray(keypoints, dtype=np.int32).reshape(-1,2) 325 | hand_landmarks[..., 0] = hand_landmarks[..., 0] + rcx - roatated_hand_half_width 326 | hand_landmarks[..., 1] = hand_landmarks[..., 1] + rcy - roatated_hand_half_height 327 | extracted_hands.append(hand_landmarks) 328 | rotated_image_size_leftrights.append([rotated_image_width, rotated_image_height, left_hand_0_or_right_hand_1]) 329 | 330 | return np.asarray(extracted_hands, dtype=np.int32), np.asarray(rotated_image_size_leftrights) 331 | -------------------------------------------------------------------------------- /model/hand_landmark/hand_landmark_sparse_Nx3x224x224.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/hand_landmark/hand_landmark_sparse_Nx3x224x224.onnx -------------------------------------------------------------------------------- /model/keypoint_classifier/keypoint_classifier.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/keypoint_classifier/keypoint_classifier.hdf5 -------------------------------------------------------------------------------- /model/keypoint_classifier/keypoint_classifier.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/keypoint_classifier/keypoint_classifier.onnx -------------------------------------------------------------------------------- /model/keypoint_classifier/keypoint_classifier.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import onnxruntime 4 | import numpy as np 5 | from typing import ( 6 | Optional, 7 | List, 8 | ) 9 | 10 | class KeyPointClassifier(object): 11 | def __init__( 12 | self, 13 | model_path: Optional[str] = 'model/keypoint_classifier/keypoint_classifier.onnx', 14 | providers: Optional[List] = [ 15 | # ( 16 | # 'TensorrtExecutionProvider', { 17 | # 'trt_engine_cache_enable': True, 18 | # 'trt_engine_cache_path': '.', 19 | # 'trt_fp16_enable': True, 20 | # } 21 | # ), 22 | 'CUDAExecutionProvider', 23 | 'CPUExecutionProvider', 24 | ], 25 | ): 26 | """KeyPointClassifier 27 | 28 | Parameters 29 | ---------- 30 | model_path: Optional[str] 31 | ONNX file path for Palm Detection 32 | 33 | providers: Optional[List] 34 | Name of onnx execution providers 35 | Default: 36 | [ 37 | ( 38 | 'TensorrtExecutionProvider', { 39 | 'trt_engine_cache_enable': True, 40 | 'trt_engine_cache_path': '.', 41 | 'trt_fp16_enable': True, 42 | } 43 | ), 44 | 'CUDAExecutionProvider', 45 | 'CPUExecutionProvider', 46 | ] 47 | """ 48 | # Model loading 49 | session_option = onnxruntime.SessionOptions() 50 | session_option.log_severity_level = 3 51 | self.onnx_session = onnxruntime.InferenceSession( 52 | model_path, 53 | sess_options=session_option, 54 | providers=providers, 55 | ) 56 | self.providers = self.onnx_session.get_providers() 57 | 58 | self.input_shapes = [ 59 | input.shape for input in self.onnx_session.get_inputs() 60 | ] 61 | self.input_names = [ 
62 | input.name for input in self.onnx_session.get_inputs() 63 | ] 64 | self.output_names = [ 65 | output.name for output in self.onnx_session.get_outputs() 66 | ] 67 | 68 | 69 | def __call__( 70 | self, 71 | landmarks: np.ndarray, 72 | ) -> np.ndarray: 73 | """KeyPointClassifier 74 | 75 | Parameters 76 | ---------- 77 | landmarks: np.ndarray 78 | Landmarks [N, 42] 79 | 80 | Returns 81 | ------- 82 | class_ids: np.ndarray 83 | float32[N] 84 | ClassIDs of Hand Signatures 85 | """ 86 | class_ids = self.onnx_session.run( 87 | self.output_names, 88 | {input_name: landmarks for input_name in self.input_names}, 89 | )[0] 90 | 91 | return class_ids -------------------------------------------------------------------------------- /model/keypoint_classifier/keypoint_classifier.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/keypoint_classifier/keypoint_classifier.tflite -------------------------------------------------------------------------------- /model/keypoint_classifier/keypoint_classifier_label.csv: -------------------------------------------------------------------------------- 1 | Open 2 | Close 3 | Pointer 4 | -------------------------------------------------------------------------------- /model/keypoint_classifier/make_argmax.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import onnx 7 | from onnxsim import simplify 8 | from argparse import ArgumentParser 9 | 10 | class Model(nn.Module): 11 | def __init__(self): 12 | super(Model, self).__init__() 13 | 14 | def forward(self, scores): 15 | max_values, max_indices = torch.max(scores, dim=1) 16 | return max_indices 17 | 18 | 19 | if __name__ == "__main__": 20 | parser = ArgumentParser() 21 | parser.add_argument( 22 | '-o', 23 | '--opset', 24 | type=int, 25 | default=11, 26 | help='onnx opset' 27 | ) 28 | parser.add_argument( 29 | '-b', 30 | '--batches', 31 | type=int, 32 | default=1, 33 | help='batch size' 34 | ) 35 | parser.add_argument( 36 | '-c', 37 | '--classes', 38 | type=int, 39 | default=3, 40 | help='classes' 41 | ) 42 | args = parser.parse_args() 43 | 44 | model = Model() 45 | 46 | MODEL = f'argmax' 47 | OPSET=args.opset 48 | BATCHES = args.batches 49 | CLASSES = args.classes 50 | 51 | onnx_file = f"{MODEL}.onnx" 52 | scores = torch.randn(BATCHES, CLASSES) 53 | 54 | torch.onnx.export( 55 | model, 56 | args=(scores), 57 | f=onnx_file, 58 | opset_version=OPSET, 59 | input_names=['argmax_input'], 60 | output_names=['class_ids'], 61 | dynamic_axes={ 62 | 'argmax_input' : {0: 'batch'}, 63 | 'class_ids' : {0: 'batch'}, 64 | }, 65 | ) 66 | model_onnx1 = onnx.load(onnx_file) 67 | model_onnx1 = onnx.shape_inference.infer_shapes(model_onnx1) 68 | onnx.save(model_onnx1, onnx_file) 69 | 70 | model_onnx2 = onnx.load(onnx_file) 71 | model_simp, check = simplify(model_onnx2) 72 | onnx.save(model_simp, onnx_file) -------------------------------------------------------------------------------- /model/keypoint_classifier/tflite_to_onnx.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -gt 1 ]; then 4 | echo "The number of arguments specified is $#." 1>&2 5 | echo "Be sure to specify 0 or 1 (Number of classes) argument. 
(default:4)" 1>&2 6 | exit 1 7 | fi 8 | 9 | if [ $# -eq 0 ]; then 10 | CLASSES=3 11 | else 12 | CLASSES=$1 13 | fi 14 | 15 | python -m tf2onnx.convert \ 16 | --opset 11 \ 17 | --tflite keypoint_classifier.tflite \ 18 | --output keypoint_classifier.onnx 19 | 20 | onnxsim keypoint_classifier.onnx keypoint_classifier.onnx 21 | 22 | sbi4onnx \ 23 | --input_onnx_file_path keypoint_classifier.onnx \ 24 | --output_onnx_file_path keypoint_classifier.onnx \ 25 | --initialization_character_string batch 26 | 27 | sor4onnx \ 28 | --input_onnx_file_path keypoint_classifier.onnx \ 29 | --old_new "input_1" "input" \ 30 | --mode inputs \ 31 | --output_onnx_file_path keypoint_classifier.onnx 32 | 33 | sor4onnx \ 34 | --input_onnx_file_path keypoint_classifier.onnx \ 35 | --old_new "Identity" "base_scores" \ 36 | --mode outputs \ 37 | --output_onnx_file_path keypoint_classifier.onnx 38 | 39 | python make_argmax.py --classes ${CLASSES} 40 | 41 | snc4onnx \ 42 | -if keypoint_classifier.onnx argmax.onnx \ 43 | -of keypoint_classifier.onnx \ 44 | -sd base_scores argmax_input 45 | 46 | rm argmax.onnx 47 | -------------------------------------------------------------------------------- /model/palm_detection/palm_detection.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import ( 3 | Tuple, 4 | Optional, 5 | List, 6 | ) 7 | from math import ( 8 | sin, 9 | cos, 10 | atan2, 11 | pi, 12 | ) 13 | import onnxruntime 14 | import numpy as np 15 | 16 | from utils.utils import ( 17 | normalize_radians, 18 | keep_aspect_resize_and_pad, 19 | ) 20 | 21 | 22 | class PalmDetection(object): 23 | def __init__( 24 | self, 25 | model_path: Optional[str] = 'model/palm_detection/palm_detection_full_inf_post_192x192.onnx', 26 | score_threshold: Optional[float] = 0.60, 27 | providers: Optional[List] = [ 28 | # ( 29 | # 'TensorrtExecutionProvider', { 30 | # 'trt_engine_cache_enable': True, 31 | # 'trt_engine_cache_path': '.', 32 | # 'trt_fp16_enable': True, 33 | # } 34 | # ), 35 | 'CUDAExecutionProvider', 36 | 'CPUExecutionProvider', 37 | ], 38 | ): 39 | """PalmDetection 40 | 41 | Parameters 42 | ---------- 43 | model_path: Optional[str] 44 | ONNX file path for Palm Detection 45 | 46 | score_threshold: Optional[float] 47 | Detection score threshold. 
Default: 0.60 48 | 49 | providers: Optional[List] 50 | Name of onnx execution providers 51 | Default: 52 | [ 53 | ( 54 | 'TensorrtExecutionProvider', { 55 | 'trt_engine_cache_enable': True, 56 | 'trt_engine_cache_path': '.', 57 | 'trt_fp16_enable': True, 58 | } 59 | ), 60 | 'CUDAExecutionProvider', 61 | 'CPUExecutionProvider', 62 | ] 63 | """ 64 | # Threshold 65 | self.score_threshold = score_threshold 66 | 67 | # Model loading 68 | session_option = onnxruntime.SessionOptions() 69 | session_option.log_severity_level = 3 70 | self.onnx_session = onnxruntime.InferenceSession( 71 | model_path, 72 | sess_options=session_option, 73 | providers=providers, 74 | ) 75 | self.providers = self.onnx_session.get_providers() 76 | 77 | self.input_shapes = [ 78 | input.shape for input in self.onnx_session.get_inputs() 79 | ] 80 | self.input_names = [ 81 | input.name for input in self.onnx_session.get_inputs() 82 | ] 83 | self.output_names = [ 84 | output.name for output in self.onnx_session.get_outputs() 85 | ] 86 | self.square_standard_size = 0 87 | 88 | 89 | def __call__( 90 | self, 91 | image: np.ndarray, 92 | ) -> Tuple[np.ndarray, np.ndarray]: 93 | """PalmDetection 94 | 95 | Parameters 96 | ---------- 97 | image: np.ndarray 98 | Entire image 99 | 100 | Returns 101 | ------- 102 | hands: np.ndarray 103 | float32[N, 4] 104 | sqn_rr_size, rotation, sqn_rr_center_x, sqn_rr_center_y 105 | """ 106 | temp_image = copy.deepcopy(image) 107 | 108 | # PreProcess 109 | inference_image = self.__preprocess( 110 | temp_image, 111 | ) 112 | 113 | # Inference 114 | inferece_image = np.asarray([inference_image], dtype=np.float32) 115 | boxes = self.onnx_session.run( 116 | self.output_names, 117 | {input_name: inferece_image for input_name in self.input_names}, 118 | ) 119 | 120 | # PostProcess 121 | hands = self.__postprocess( 122 | image=temp_image, 123 | boxes=boxes[0], 124 | ) 125 | 126 | return hands 127 | 128 | 129 | def __preprocess( 130 | self, 131 | image: np.ndarray, 132 | swap: Optional[Tuple[int,int,int]] = (2, 0, 1), 133 | ) -> np.ndarray: 134 | """__preprocess 135 | 136 | Parameters 137 | ---------- 138 | image: np.ndarray 139 | Entire image 140 | 141 | swap: tuple 142 | HWC to CHW: (2,0,1) 143 | CHW to HWC: (1,2,0) 144 | HWC to HWC: (0,1,2) 145 | CHW to CHW: (0,1,2) 146 | 147 | Returns 148 | ------- 149 | padded_image: np.ndarray 150 | Resized and Padding and normalized image. 
151 | """ 152 | # Resize + Padding + Normalization + BGR->RGB 153 | input_h = self.input_shapes[0][2] 154 | input_w = self.input_shapes[0][3] 155 | image_height , image_width = image.shape[:2] 156 | 157 | self.square_standard_size = max(image_height, image_width) 158 | self.square_padding_half_size = abs(image_height - image_width) // 2 159 | 160 | padded_image, resized_image = keep_aspect_resize_and_pad( 161 | image=image, 162 | resize_width=input_w, 163 | resize_height=input_h, 164 | ) 165 | 166 | pad_size_half_h = max(0, (input_h - resized_image.shape[0]) // 2) 167 | pad_size_half_w = max(0, (input_w - resized_image.shape[1]) // 2) 168 | 169 | self.pad_size_scale_h = pad_size_half_h / input_h 170 | self.pad_size_scale_w = pad_size_half_w / input_w 171 | 172 | padded_image = np.divide(padded_image, 255.0) 173 | padded_image = padded_image[..., ::-1] 174 | padded_image = padded_image.transpose(swap) 175 | padded_image = np.ascontiguousarray( 176 | padded_image, 177 | dtype=np.float32, 178 | ) 179 | return padded_image 180 | 181 | 182 | def __postprocess( 183 | self, 184 | image: np.ndarray, 185 | boxes: np.ndarray, 186 | ) -> np.ndarray: 187 | """__postprocess 188 | 189 | Parameters 190 | ---------- 191 | image: np.ndarray 192 | Entire image. 193 | 194 | boxes: np.ndarray 195 | float32[N, 8] 196 | pd_score, box_x, box_y, box_size, kp0_x, kp0_y, kp2_x, kp2_y 197 | 198 | Returns 199 | ------- 200 | hands: np.ndarray 201 | float32[N, 4] 202 | sqn_rr_size, rotation, sqn_rr_center_x, sqn_rr_center_y 203 | """ 204 | image_height = image.shape[0] 205 | image_width = image.shape[1] 206 | 207 | hands = [] 208 | keep = boxes[:, 0] > self.score_threshold # pd_score > self.score_threshold 209 | boxes = boxes[keep, :] 210 | 211 | for box in boxes: 212 | pd_score, box_x, box_y, box_size, kp0_x, kp0_y, kp2_x, kp2_y = box 213 | if box_size > 0: 214 | kp02_x = kp2_x - kp0_x 215 | kp02_y = kp2_y - kp0_y 216 | sqn_rr_size = 2.9 * box_size 217 | rotation = 0.5 * pi - atan2(-kp02_y, kp02_x) 218 | rotation = normalize_radians(rotation) 219 | sqn_rr_center_x = box_x + 0.5*box_size*sin(rotation) 220 | sqn_rr_center_y = box_y - 0.5*box_size*cos(rotation) 221 | sqn_rr_center_y = (sqn_rr_center_y * self.square_standard_size - self.square_padding_half_size) / image_height 222 | hands.append( 223 | [ 224 | sqn_rr_size, 225 | rotation, 226 | sqn_rr_center_x, 227 | sqn_rr_center_y, 228 | ] 229 | ) 230 | 231 | return np.asarray(hands) -------------------------------------------------------------------------------- /model/palm_detection/palm_detection_full_inf_post_192x192.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/palm_detection/palm_detection_full_inf_post_192x192.onnx -------------------------------------------------------------------------------- /model/point_history_classifier/argmax.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/point_history_classifier/argmax.onnx -------------------------------------------------------------------------------- /model/point_history_classifier/make_argmax.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import onnx 7 | from onnxsim import simplify 8 | from argparse import ArgumentParser 9 | 10 | class Model(nn.Module): 11 | def __init__(self): 12 | super(Model, self).__init__() 13 | self.invalid_value = 0 14 | 15 | def forward(self, scores, score_threshold): 16 | max_scores, class_ids = torch.max(scores, dim=1) 17 | 18 | invalid_idxs = max_scores < score_threshold 19 | class_ids[invalid_idxs] = self.invalid_value 20 | 21 | return class_ids 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = ArgumentParser() 26 | parser.add_argument( 27 | '-o', 28 | '--opset', 29 | type=int, 30 | default=11, 31 | help='onnx opset' 32 | ) 33 | parser.add_argument( 34 | '-b', 35 | '--batches', 36 | type=int, 37 | default=1, 38 | help='batch size' 39 | ) 40 | parser.add_argument( 41 | '-c', 42 | '--classes', 43 | type=int, 44 | default=4, 45 | help='classes' 46 | ) 47 | args = parser.parse_args() 48 | 49 | model = Model() 50 | 51 | MODEL = f'argmax' 52 | OPSET=args.opset 53 | BATCHES = args.batches 54 | CLASSES = args.classes 55 | 56 | onnx_file = f"{MODEL}.onnx" 57 | scores = torch.randn(BATCHES, CLASSES) 58 | score_threshold = torch.tensor(0.5, dtype=torch.float32) 59 | 60 | torch.onnx.export( 61 | model, 62 | args=(scores, score_threshold), 63 | f=onnx_file, 64 | opset_version=OPSET, 65 | input_names=['argmax_input', 'score_threshold'], 66 | output_names=['class_ids'], 67 | dynamic_axes={ 68 | 'argmax_input' : {0: 'batch'}, 69 | 'class_ids' : {0: 'batch'}, 70 | }, 71 | ) 72 | model_onnx1 = onnx.load(onnx_file) 73 | model_onnx1 = onnx.shape_inference.infer_shapes(model_onnx1) 74 | onnx.save(model_onnx1, onnx_file) 75 | 76 | model_onnx2 = onnx.load(onnx_file) 77 | model_simp, check = simplify(model_onnx2) 78 | onnx.save(model_simp, onnx_file) -------------------------------------------------------------------------------- /model/point_history_classifier/point_history_classifier.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/point_history_classifier/point_history_classifier.hdf5 -------------------------------------------------------------------------------- /model/point_history_classifier/point_history_classifier.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/point_history_classifier/point_history_classifier.onnx -------------------------------------------------------------------------------- /model/point_history_classifier/point_history_classifier.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import onnxruntime 4 | import numpy as np 5 | from typing import ( 6 | Optional, 7 | List, 8 | ) 9 | 10 | 11 | class PointHistoryClassifier(object): 12 | def __init__( 13 | self, 14 | model_path: Optional[str] = 'model/point_history_classifier/point_history_classifier_lstm.onnx', 15 | providers: Optional[List] = [ 16 | # ( 17 | # 'TensorrtExecutionProvider', { 18 | # 'trt_engine_cache_enable': True, 19 | # 'trt_engine_cache_path': '.', 20 | # 'trt_fp16_enable': True, 21 | # } 22 | # ), 23 | 'CUDAExecutionProvider', 24 | 'CPUExecutionProvider', 25 | ], 26 | score_th=0.5, 27 | ): 28 | """PointHistoryClassifier 29 | 30 | 
Parameters 31 | ---------- 32 | model_path: Optional[str] 33 | ONNX file path for Palm Detection 34 | 35 | providers: Optional[List] 36 | Name of onnx execution providers 37 | Default: 38 | [ 39 | ( 40 | 'TensorrtExecutionProvider', { 41 | 'trt_engine_cache_enable': True, 42 | 'trt_engine_cache_path': '.', 43 | 'trt_fp16_enable': True, 44 | } 45 | ), 46 | 'CUDAExecutionProvider', 47 | 'CPUExecutionProvider', 48 | ] 49 | """ 50 | # Model loading 51 | session_option = onnxruntime.SessionOptions() 52 | session_option.log_severity_level = 3 53 | self.onnx_session = onnxruntime.InferenceSession( 54 | model_path, 55 | sess_options=session_option, 56 | providers=providers, 57 | ) 58 | self.providers = self.onnx_session.get_providers() 59 | 60 | self.input_shapes = [ 61 | input.shape for input in self.onnx_session.get_inputs() 62 | ] 63 | self.input_names = [ 64 | input.name for input in self.onnx_session.get_inputs() 65 | ] 66 | self.output_names = [ 67 | output.name for output in self.onnx_session.get_outputs() 68 | ] 69 | self.score_th = np.asarray(score_th, dtype=np.float32) 70 | 71 | 72 | def __call__( 73 | self, 74 | point_history: np.ndarray, 75 | ) -> np.ndarray: 76 | """PointHistoryClassifier 77 | 78 | Parameters 79 | ---------- 80 | point_history: np.ndarray 81 | Landmarks [N, 32] 82 | 83 | Returns 84 | ------- 85 | class_ids: np.ndarray 86 | int64[N] 87 | ClassIDs of Finger gesture 88 | """ 89 | class_ids = self.onnx_session.run( 90 | self.output_names, 91 | { 92 | self.input_names[0]: point_history, 93 | self.input_names[1]: self.score_th, 94 | }, 95 | )[0] 96 | 97 | return class_ids 98 | -------------------------------------------------------------------------------- /model/point_history_classifier/point_history_classifier.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/point_history_classifier/point_history_classifier.tflite -------------------------------------------------------------------------------- /model/point_history_classifier/point_history_classifier_label.csv: -------------------------------------------------------------------------------- 1 | Stop 2 | Clockwise 3 | Counter Clockwise 4 | Move 5 | -------------------------------------------------------------------------------- /model/point_history_classifier/point_history_classifier_lstm.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/point_history_classifier/point_history_classifier_lstm.onnx -------------------------------------------------------------------------------- /model/point_history_classifier/tflite_to_onnx.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -gt 1 ]; then 4 | echo "The number of arguments specified is $#." 1>&2 5 | echo "Be sure to specify 0 or 1 (Number of classes) argument. 
(default:4)" 1>&2 6 | exit 1 7 | fi 8 | 9 | if [ $# -eq 0 ]; then 10 | CLASSES=4 11 | else 12 | CLASSES=$1 13 | fi 14 | 15 | python -m tf2onnx.convert \ 16 | --opset 11 \ 17 | --tflite point_history_classifier.tflite \ 18 | --output point_history_classifier.onnx 19 | 20 | onnxsim point_history_classifier.onnx point_history_classifier.onnx 21 | 22 | sbi4onnx \ 23 | --input_onnx_file_path point_history_classifier.onnx \ 24 | --output_onnx_file_path point_history_classifier.onnx \ 25 | --initialization_character_string batch 26 | 27 | sor4onnx \ 28 | --input_onnx_file_path point_history_classifier.onnx \ 29 | --old_new "input_1" "input" \ 30 | --mode inputs \ 31 | --output_onnx_file_path point_history_classifier.onnx 32 | 33 | sor4onnx \ 34 | --input_onnx_file_path point_history_classifier.onnx \ 35 | --old_new "Identity" "base_scores" \ 36 | --mode outputs \ 37 | --output_onnx_file_path point_history_classifier.onnx 38 | 39 | python make_argmax.py --classes ${CLASSES} 40 | 41 | snc4onnx \ 42 | -if point_history_classifier.onnx argmax.onnx \ 43 | -of point_history_classifier.onnx \ 44 | -sd base_scores argmax_input 45 | 46 | rm argmax.onnx 47 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | onnxruntime >= 1.12.0 2 | opencv-contrib-python >= 4.6.0.66 3 | scikit-learn >= 0.23.2 4 | matplotlib >= 3.3.2 5 | tf2onnx >= 1.12.0 6 | simple-onnx-processing-tools >= 1.0.54 -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.cvfpscalc import CvFpsCalc -------------------------------------------------------------------------------- /utils/cvfpscalc.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import cv2 as cv 3 | 4 | 5 | class CvFpsCalc(object): 6 | def __init__(self, buffer_len=1): 7 | self._start_tick = cv.getTickCount() 8 | self._freq = 1000.0 / cv.getTickFrequency() 9 | self._difftimes = deque(maxlen=buffer_len) 10 | 11 | def get(self): 12 | current_tick = cv.getTickCount() 13 | different_time = (current_tick - self._start_tick) * self._freq 14 | self._start_tick = current_tick 15 | 16 | self._difftimes.append(different_time) 17 | 18 | fps = 1000.0 / (sum(self._difftimes) / len(self._difftimes)) 19 | fps_rounded = round(fps, 2) 20 | 21 | return fps_rounded 22 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import copy 4 | import numpy as np 5 | from typing import Tuple, List 6 | from math import sin, cos, pi, floor 7 | 8 | 9 | def normalize_radians( 10 | angle: float 11 | ) -> float: 12 | """__normalize_radians 13 | 14 | Parameters 15 | ---------- 16 | angle: float 17 | 18 | Returns 19 | ------- 20 | normalized_angle: float 21 | """ 22 | return angle - 2 * pi * floor((angle + pi) / (2 * pi)) 23 | 24 | 25 | def is_inside_rect( 26 | rects: np.ndarray, 27 | width_of_outer_rect: int, 28 | height_of_outer_rect: int, 29 | ) -> np.ndarray: 30 | """Determines whether rects is inside or outside the outer rectangle. 31 | 32 | Parameters 33 | ---------- 34 | rects: np.ndarray 35 | [boxcount, 5] = [boxcount, cx, cy, width, height, angle]\n 36 | Area to be verified. 
37 | 38 | cx: float 39 | Rectangle center X coordinate. 40 | 41 | cy: float 42 | Rectangle center Y coordinate. 43 | 44 | width: float 45 | Width of the rectangle. 46 | 47 | height: float 48 | Height of the rectangle. 49 | 50 | angle: float 51 | The rotation angle in a clockwise direction. 52 | When the angle is 0, 90, 180, 270, 360 etc., 53 | the rectangle becomes an up-right rectangle. 54 | 55 | width_of_outer_rect: int 56 | Width of outer rectangle. 57 | 58 | height_of_outer_rect: int 59 | Height of outer rectangle 60 | 61 | Returns 62 | ------- 63 | result: np.ndarray 64 | True: if the rotated sub rectangle is side the up-right rectange, False: else 65 | """ 66 | results = [] 67 | 68 | for rect in rects: 69 | cx = rect[0] 70 | cy = rect[1] 71 | width = rect[2] 72 | height = rect[3] 73 | angle = rect[4] 74 | 75 | if (cx < 0) or (cx > width_of_outer_rect): 76 | # Center X coordinate is outside the range of the outer rectangle 77 | results.append(False) 78 | 79 | elif (cy < 0) or (cy > height_of_outer_rect): 80 | # Center Y coordinate is outside the range of the outer rectangle 81 | results.append(False) 82 | 83 | else: 84 | # Coordinate acquisition of bounding rectangle considering rotation 85 | # http://labs.eecs.tottori-u.ac.jp/sd/Member/oyamada/OpenCV/html/py_tutorials/py_imgproc/py_contours/py_contour_features/py_contour_features.html#b 86 | rect_tuple = ((cx, cy), (width, height), angle) 87 | box = cv2.boxPoints(rect_tuple) 88 | 89 | x_max = int(np.max(box[:,0])) 90 | x_min = int(np.min(box[:,0])) 91 | y_max = int(np.max(box[:,1])) 92 | y_min = int(np.min(box[:,1])) 93 | 94 | if (x_min >= 0) and (x_max <= width_of_outer_rect) and \ 95 | (y_min >= 0) and (y_max <= height_of_outer_rect): 96 | # All 4 vertices are within the perimeter rectangle 97 | results.append(True) 98 | else: 99 | # Any of the 4 vertices is outside the perimeter rectangle 100 | results.append(False) 101 | 102 | return np.asarray(results, dtype=np.bool_) 103 | 104 | 105 | def bounding_box_from_rotated_rect( 106 | rects: np.ndarray, 107 | ) -> np.ndarray: 108 | """Conversion to bounding rectangle without rotation. 109 | 110 | Parameters 111 | ---------- 112 | rects: np.ndarray 113 | [boxcount, 5] = [boxcount, cx, cy, width, height, angle]\n 114 | Rotated rectangle. 115 | 116 | cx: float 117 | Rectangle center X coordinate. 118 | 119 | cy: float 120 | Rectangle center Y coordinate. 121 | 122 | width: float 123 | Width of the rectangle. 124 | 125 | height: float 126 | Height of the rectangle. 127 | 128 | angle: float 129 | The rotation angle in a clockwise direction. 130 | When the angle is 0, 90, 180, 270, 360 etc., 131 | the rectangle becomes an up-right rectangle. 
132 | 133 | Returns 134 | ------- 135 | result: np.ndarray 136 | e.g.:\n 137 | [input] rotated rectangle:\n 138 | [center:(10, 10), height:4, width:4, angle:45 degree]\n 139 | [output] bounding box for this rotated rectangle:\n 140 | [center:(10, 10), height:4*sqrt(2), width:4*sqrt(2), angle:0 degree] 141 | """ 142 | results = [] 143 | 144 | for rect in rects: 145 | cx = rect[0] 146 | cy = rect[1] 147 | width = rect[2] 148 | height = rect[3] 149 | angle = rect[4] 150 | 151 | rect_tuple = ((cx, cy), (width, height), angle) 152 | box = cv2.boxPoints(rect_tuple) 153 | 154 | x_max = int(np.max(box[:,0])) 155 | x_min = int(np.min(box[:,0])) 156 | y_max = int(np.max(box[:,1])) 157 | y_min = int(np.min(box[:,1])) 158 | 159 | cx = int((x_min + x_max) // 2) 160 | cy = int((y_min + y_max) // 2) 161 | width = int(x_max - x_min) 162 | height = int(y_max - y_min) 163 | angle = 0 164 | results.append([cx, cy, width, height, angle]) 165 | 166 | return np.asarray(results, dtype=np.float32) 167 | 168 | 169 | def image_rotation_without_crop( 170 | images: List[np.ndarray], 171 | angles: np.ndarray, 172 | ) -> List[np.ndarray]: 173 | """Conversion to bounding rectangle without rotation. 174 | 175 | Parameters 176 | ---------- 177 | images: List[np.ndarray] 178 | Image to be rotated. 179 | 180 | angles: np.ndarray 181 | Rotation degree. 182 | 183 | Returns 184 | ------- 185 | rotated_images: List[np.ndarray] 186 | Image after rotation. 187 | """ 188 | rotated_images = [] 189 | # https://stackoverflow.com/questions/22041699/rotate-an-image-without-cropping-in-opencv-in-c 190 | for image, angle in zip(images, angles): 191 | height, width = image.shape[:2] 192 | image_center = (width//2, height//2) 193 | rotation_matrix = cv2.getRotationMatrix2D(image_center, int(angle), 1) 194 | abs_cos = abs(rotation_matrix[0,0]) 195 | abs_sin = abs(rotation_matrix[0,1]) 196 | bound_w = int(height * abs_sin + width * abs_cos) 197 | bound_h = int(height * abs_cos + width * abs_sin) 198 | rotation_matrix[0, 2] += bound_w/2 - image_center[0] 199 | rotation_matrix[1, 2] += bound_h/2 - image_center[1] 200 | rotated_image = cv2.warpAffine(image, rotation_matrix, (bound_w, bound_h)) 201 | rotated_images.append(rotated_image) 202 | 203 | return rotated_images 204 | 205 | 206 | def crop_rectangle( 207 | image: np.ndarray, 208 | rects: np.ndarray, 209 | ) -> List[np.ndarray]: 210 | """rect has to be upright. 211 | 212 | Parameters 213 | ---------- 214 | image: np.ndarray 215 | Image to be rotate and crop. 216 | 217 | rects: np.ndarray 218 | [boxcount, 5] = [boxcount, cx, cy, width, height, angle]\n 219 | Rotat and crop rectangle. 220 | 221 | cx: float 222 | Rectangle center X coordinate. 223 | 224 | cy: float 225 | Rectangle center Y coordinate. 226 | 227 | width: float 228 | Width of the rectangle. 229 | 230 | height: float 231 | Height of the rectangle. 232 | 233 | angle: float 234 | The rotation angle in a clockwise direction. 235 | When the angle is 0, 90, 180, 270, 360 etc., 236 | the rectangle becomes an up-right rectangle. 237 | 238 | Returns 239 | ------- 240 | croped_images: List[np.ndarray] 241 | Image after cropping. 242 | """ 243 | croped_images = [] 244 | height = image.shape[0] 245 | width = image.shape[1] 246 | 247 | # Determine if rect is inside the entire image 248 | inside_or_outsides = is_inside_rect( 249 | rects=rects, 250 | width_of_outer_rect=width, 251 | height_of_outer_rect=height, 252 | ) 253 | 254 | rects = rects[inside_or_outsides, ...] 
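    # Rectangles with any corner outside the image were removed above; each
    # remaining upright rect is cropped below by slicing around its center
    # (cx, cy): width//2 columns to the left and width - width//2 to the
    # right, with the same split for height, so odd sizes keep their full extent.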
255 | 256 | for rect in rects: 257 | cx = int(rect[0]) 258 | cy = int(rect[1]) 259 | rect_width = int(rect[2]) 260 | rect_height = int(rect[3]) 261 | 262 | croped_image = image[ 263 | cy-rect_height//2:cy+rect_height-rect_height//2, 264 | cx-rect_width//2:cx+rect_width-rect_width//2, 265 | ] 266 | croped_images.append(croped_image) 267 | 268 | return croped_images 269 | 270 | 271 | def rotate_and_crop_rectangle( 272 | image: np.ndarray, 273 | rects_tmp: np.ndarray, 274 | operation_when_cropping_out_of_range: str, 275 | ) -> List[np.ndarray]: 276 | """Crop a rotated rectangle from a image. 277 | 278 | Parameters 279 | ---------- 280 | image: np.ndarray 281 | Image to be rotate and crop. 282 | 283 | rects: np.ndarray 284 | [boxcount, 5] = [boxcount, cx, cy, width, height, angle]\n 285 | Rotat and crop rectangle. 286 | 287 | cx: float 288 | Rectangle center X coordinate. 289 | 290 | cy: float 291 | Rectangle center Y coordinate. 292 | 293 | width: float 294 | Width of the rectangle. 295 | 296 | height: float 297 | Height of the rectangle. 298 | 299 | angle: float 300 | The rotation angle in a clockwise direction. 301 | When the angle is 0, 90, 180, 270, 360 etc., 302 | the rectangle becomes an up-right rectangle. 303 | 304 | operation_when_cropping_out_of_range: str 305 | 'padding' or 'ignore' 306 | 307 | Returns 308 | ------- 309 | rotated_croped_image: List[np.ndarray] 310 | Image after cropping and rotation. 311 | """ 312 | rects = copy.deepcopy(rects_tmp) 313 | rotated_croped_images = [] 314 | height = image.shape[0] 315 | width = image.shape[1] 316 | 317 | # Determine if rect is inside the entire image 318 | if operation_when_cropping_out_of_range == 'padding': 319 | size = (int(math.sqrt(width ** 2 + height ** 2)) + 2) * 2 320 | image = pad_image( 321 | image=image, 322 | resize_width=size, 323 | resize_height=size, 324 | ) 325 | rects[:, 0] = rects[:, 0] + abs(size-width) / 2 326 | rects[:, 1] = rects[:, 1] + abs(size-height) / 2 327 | 328 | elif operation_when_cropping_out_of_range == 'ignore': 329 | inside_or_outsides = is_inside_rect( 330 | rects=rects, 331 | width_of_outer_rect=width, 332 | height_of_outer_rect=height, 333 | ) 334 | rects = rects[inside_or_outsides, ...] 335 | 336 | rect_bbx_upright = bounding_box_from_rotated_rect( 337 | rects=rects, 338 | ) 339 | 340 | rect_bbx_upright_images = crop_rectangle( 341 | image=image, 342 | rects=rect_bbx_upright, 343 | ) 344 | 345 | rotated_rect_bbx_upright_images = image_rotation_without_crop( 346 | images=rect_bbx_upright_images, 347 | angles=rects[..., 4:5], 348 | ) 349 | 350 | for rotated_rect_bbx_upright_image, rect in zip(rotated_rect_bbx_upright_images, rects): 351 | crop_cx = rotated_rect_bbx_upright_image.shape[1]//2 352 | crop_cy = rotated_rect_bbx_upright_image.shape[0]//2 353 | rect_width = int(rect[2]) 354 | rect_height = int(rect[3]) 355 | 356 | rotated_croped_images.append( 357 | rotated_rect_bbx_upright_image[ 358 | crop_cy-rect_height//2:crop_cy+(rect_height-rect_height//2), 359 | crop_cx-rect_width//2:crop_cx+(rect_width-rect_width//2), 360 | ] 361 | ) 362 | 363 | return rotated_croped_images 364 | 365 | 366 | def keep_aspect_resize_and_pad( 367 | image: np.ndarray, 368 | resize_width: int, 369 | resize_height: int, 370 | ) -> Tuple[np.ndarray, np.ndarray]: 371 | """Resize the image based on the long side to fit inside 372 | the specified bounding rectangle while maintaining the aspect ratio, 373 | and padding the short side to the specified size. 
374 | 375 | Parameters 376 | ---------- 377 | image: np.ndarray 378 | Image to be resize and pad. 379 | 380 | resize_width: int 381 | Width of outer rectangle. 382 | 383 | resize_width: int 384 | Height of outer rectangle 385 | 386 | Returns 387 | ------- 388 | padded_image: np.ndarray 389 | Image after padding. 390 | 391 | resized_image: np.ndarray 392 | Image after resize. (Before padding) 393 | """ 394 | image_height = image.shape[0] 395 | image_width = image.shape[1] 396 | padded_image = np.zeros( 397 | (resize_height, resize_width, 3), 398 | np.uint8 399 | ) 400 | ash = resize_height / image_height 401 | asw = resize_width / image_width 402 | if asw < ash: 403 | sizeas = (int(image_width * asw), int(image_height * asw)) 404 | else: 405 | sizeas = (int(image_width * ash), int(image_height * ash)) 406 | resized_image = cv2.resize(image, dsize=sizeas) 407 | start_h = int(resize_height / 2 - sizeas[1] / 2) 408 | end_h = int(resize_height / 2 + sizeas[1] / 2) 409 | start_w = int(resize_width / 2 - sizeas[0] / 2) 410 | end_w = int(resize_width / 2 + sizeas[0] / 2) 411 | padded_image[start_h:end_h, start_w:end_w, :] = resized_image.copy() 412 | 413 | return padded_image, resized_image 414 | 415 | 416 | def pad_image( 417 | image: np.ndarray, 418 | resize_width: int, 419 | resize_height: int, 420 | ) -> np.ndarray: 421 | """Padding the perimeter of the image to the specified bounding rectangle size. 422 | 423 | Parameters 424 | ---------- 425 | image: np.ndarray 426 | Image to be resize and pad. 427 | 428 | resize_width: int 429 | Width of outer rectangle. 430 | 431 | resize_width: int 432 | Height of outer rectangle 433 | 434 | Returns 435 | ------- 436 | padded_image: np.ndarray 437 | Image after padding. 438 | """ 439 | image_height = image.shape[0] 440 | image_width = image.shape[1] 441 | 442 | if resize_width < image_width: 443 | resize_width = image_width 444 | if resize_height < image_height: 445 | resize_height = image_height 446 | 447 | padded_image = np.zeros( 448 | (resize_height, resize_width, 3), 449 | np.uint8 450 | ) 451 | start_h = int(resize_height / 2 - image_height / 2) 452 | end_h = int(resize_height / 2 + image_height / 2) 453 | start_w = int(resize_width / 2 - image_width / 2) 454 | end_w = int(resize_width / 2 + image_width / 2) 455 | padded_image[start_h:end_h, start_w:end_w, :] = image 456 | 457 | return padded_image 458 | 459 | 460 | def rotated_rect_to_points( 461 | cx: float, 462 | cy: float, 463 | width: float, 464 | height: float, 465 | rotation: float, 466 | ) -> List[List[int]]: 467 | """Obtain coordinates for drawing a rotated rectangle 468 | 469 | Parameters 470 | ---------- 471 | cx: float 472 | Rectangle center X coordinate. 473 | 474 | cy: float 475 | Rectangle center Y coordinate. 476 | 477 | width: float 478 | Width of the rectangle. 479 | 480 | height: float 481 | Height of the rectangle. 482 | 483 | rotation: float 484 | The rotation radians in a clockwise direction. 485 | 486 | Returns 487 | ------- 488 | rotated_croped_image: List[np.ndarray] 489 | 4 coordinates for drawing rotated rectangle. 
490 | 491 | e.g.\n 492 | cv2.polylines(\n 493 | img=frame,\n 494 | pts=[np.array([[p0x,p0y], [p1x,p1y], [p2x,p2y], [p3x,p3y]])],\n 495 | isClosed=True,\n 496 | color=(0,255,255),\n 497 | thickness=2,\n 498 | lineType=cv2.LINE_AA,\n 499 | ) 500 | """ 501 | b = cos(rotation) * 0.5 502 | a = sin(rotation) * 0.5 503 | p0x = cx - a*height - b*width 504 | p0y = cy + b*height - a*width 505 | p1x = cx + a*height - b*width 506 | p1y = cy - b*height - a*width 507 | p2x = int(2*cx - p0x) 508 | p2y = int(2*cy - p0y) 509 | p3x = int(2*cx - p1x) 510 | p3y = int(2*cy - p1y) 511 | p0x, p0y, p1x, p1y = int(p0x), int(p0y), int(p1x), int(p1y) 512 | 513 | return [[p0x,p0y], [p1x,p1y], [p2x,p2y], [p3x,p3y]] 514 | --------------------------------------------------------------------------------
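The modules above are wired together by app.py (not reproduced in this section). As a hedged, minimal sketch of how they compose, the snippet below chains PalmDetection -> rotate_and_crop_rectangle -> HandLandmark on a webcam stream. Assumptions not confirmed by the listings: it is run from the repository root (the default relative model paths resolve), a camera exists at index 0, the classes are imported straight from their file locations rather than via model/__init__.py (whose contents are not shown), the pixel rect follows the [cx, cy, width, height, degree] convention quoted in HandLandmark.__postprocess, and wh_ratio is frame_width / frame_height. The landmark normalization required by KeyPointClassifier is app.py-specific and is deliberately omitted here.

import cv2
import numpy as np
from math import degrees

from model.palm_detection.palm_detection import PalmDetection
from model.hand_landmark.hand_landmark import HandLandmark
from utils.utils import rotate_and_crop_rectangle
from utils import CvFpsCalc


def main():
    # Defaults: CUDA -> CPU provider fallback, thresholds as in the class definitions.
    palm_detection = PalmDetection()
    hand_landmark = HandLandmark()
    fps_calc = CvFpsCalc(buffer_len=10)

    cap = cv2.VideoCapture(0)  # assumption: webcam at index 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame_height, frame_width = frame.shape[:2]
        wh_ratio = frame_width / frame_height  # assumption: ratio used for the rect height

        # 1. Palm detection -> float32[N, 4]: sqn_rr_size, rotation, sqn_rr_center_x, sqn_rr_center_y
        hands = palm_detection(frame)

        if len(hands) > 0:
            # 2. Normalized squares -> pixel rects [cx, cy, width, height, degree],
            #    following the convention quoted in HandLandmark.__postprocess.
            rects = []
            for sqn_rr_size, rotation, sqn_rr_center_x, sqn_rr_center_y in hands:
                cx = int(sqn_rr_center_x * frame_width)
                cy = int(sqn_rr_center_y * frame_height)
                xmin = int((sqn_rr_center_x - sqn_rr_size / 2) * frame_width)
                xmax = int((sqn_rr_center_x + sqn_rr_size / 2) * frame_width)
                ymin = int((sqn_rr_center_y - sqn_rr_size * wh_ratio / 2) * frame_height)
                ymax = int((sqn_rr_center_y + sqn_rr_size * wh_ratio / 2) * frame_height)
                rects.append([cx, cy, xmax - xmin, ymax - ymin, degrees(rotation)])
            rects = np.asarray(rects, dtype=np.float32)

            # 3. Crop the rotated palm regions, then run the landmark model on the crops.
            cropped_palms = rotate_and_crop_rectangle(
                image=frame,
                rects_tmp=rects,
                operation_when_cropping_out_of_range='padding',
            )
            hand_landmarks, _rotated_size_leftrights = hand_landmark(
                images=cropped_palms,
                rects=rects,
            )

            # 4. Draw the 21 (x, y) keypoints per detected hand.
            for landmarks in hand_landmarks:
                for x, y in landmarks:
                    cv2.circle(frame, (int(x), int(y)), 3, (0, 255, 0), -1)

        cv2.putText(frame, f'FPS: {fps_calc.get()}', (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2, cv2.LINE_AA)
        cv2.imshow('sketch', frame)
        if cv2.waitKey(1) & 0xFF == 27:  # ESC quits
            break

    cap.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()

The only glue logic the sketch needs is step 2 (converting the palm detector's normalized squares into pixel rects); everything else uses the classes' documented call signatures.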
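For the two classifier heads produced by the tflite_to_onnx.sh scripts, a quick shape/sanity check can be run without a camera. The sketch below is illustrative only: it feeds random values with the shapes documented in the __call__ docstrings ([N, 42] for KeyPointClassifier, [N, 32] for PointHistoryClassifier), then maps the returned class ids onto the label CSVs. It assumes it is run from the repository root; the real feature preprocessing lives in app.py and is not reproduced.

import csv

import numpy as np

from model.keypoint_classifier.keypoint_classifier import KeyPointClassifier
from model.point_history_classifier.point_history_classifier import PointHistoryClassifier


def load_labels(csv_path):
    with open(csv_path, encoding='utf-8-sig') as f:
        return [row[0] for row in csv.reader(f)]


keypoint_labels = load_labels('model/keypoint_classifier/keypoint_classifier_label.csv')
history_labels = load_labels('model/point_history_classifier/point_history_classifier_label.csv')

keypoint_classifier = KeyPointClassifier()
point_history_classifier = PointHistoryClassifier()

# 21 landmarks x (x, y) = 42 values per hand; random values stand in for
# the normalized landmarks app.py would supply.
dummy_landmarks = np.random.rand(1, 42).astype(np.float32)
sign_ids = keypoint_classifier(dummy_landmarks)
print('hand sign:', [keypoint_labels[int(i)] for i in sign_ids])

# The documented input is [N, 32] -- presumably 16 (x, y) history points per
# hand. Scores below the 0.5 threshold baked into the call are mapped to
# class 0 ("Stop"), per make_argmax.py.
dummy_history = np.random.rand(1, 32).astype(np.float32)
gesture_ids = point_history_classifier(dummy_history)
print('finger gesture:', [history_labels[int(i)] for i in gesture_ids])

If either conversion script is re-run with a different class count, the argmax head and the corresponding label CSV need to stay in sync.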