├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── hand.png
├── keypoint_classification.ipynb
├── model
│   ├── __init__.py
│   ├── hand_landmark
│   │   ├── hand_landmark.py
│   │   └── hand_landmark_sparse_Nx3x224x224.onnx
│   ├── keypoint_classifier
│   │   ├── keypoint.csv
│   │   ├── keypoint_classifier.hdf5
│   │   ├── keypoint_classifier.onnx
│   │   ├── keypoint_classifier.py
│   │   ├── keypoint_classifier.tflite
│   │   ├── keypoint_classifier_label.csv
│   │   ├── make_argmax.py
│   │   └── tflite_to_onnx.sh
│   ├── palm_detection
│   │   ├── palm_detection.py
│   │   └── palm_detection_full_inf_post_192x192.onnx
│   └── point_history_classifier
│       ├── argmax.onnx
│       ├── make_argmax.py
│       ├── point_history.csv
│       ├── point_history_classifier.hdf5
│       ├── point_history_classifier.onnx
│       ├── point_history_classifier.py
│       ├── point_history_classifier.tflite
│       ├── point_history_classifier_label.csv
│       ├── point_history_classifier_lstm.onnx
│       └── tflite_to_onnx.sh
├── point_history_classification.ipynb
├── requirements.txt
└── utils
    ├── __init__.py
    ├── cvfpscalc.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | __pycache__/
3 | *.engine
4 | *.profile
5 | *.mp4
6 | test_distance_calc.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2022 Katsuya Hyodo
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [Japanese / [WIP] English]
2 |
3 | ---
4 | # hand-gesture-recognition-using-onnx
5 | This repository is based on [Kazuhito00/hand-gesture-recognition-using-mediapipe](https://github.com/Kazuhito00/hand-gesture-recognition-using-mediapipe) by [@Kazuhito00](https://github.com/Kazuhito00), with the entire MediaPipe implementation replaced by ONNX. It is a sample program that recognizes hand signs and finger gestures with a simple MLP. So far it has only been confirmed to run; nothing more has been verified.
6 |
7 | https://user-images.githubusercontent.com/33194443/189632510-1823cd54-bc36-4889-ac14-adc16deba9b8.mp4
8 |
9 | This repository contains the following:
10 | - Sample program
11 | - Hand detection model (modified ONNX)
12 | - Hand landmark detection model (modified ONNX)
13 | - Hand sign recognition model (modified ONNX)
14 | - Finger gesture recognition model (modified ONNX)
15 | - Training data and a training notebook for hand sign recognition
16 | - Training data and a training notebook for finger gesture recognition
17 |
18 | # Requirements
19 | - onnxruntime 1.12.0 or onnxruntime-gpu 1.12.0
20 | - opencv-contrib-python 4.6.0.66 or Later
21 | - Tensorflow 2.10.0 (only if you want to regenerate the ONNX files after training)
22 | - PyTorch 1.12.0 (only if you want to regenerate the ONNX files after training)
23 | - tf2onnx 1.12.0 or Later (only if you want to regenerate the ONNX files after training)
24 | - simple-onnx-processing-tools 1.0.54 or Later (only if you want to regenerate the ONNX files after training)
25 | - scikit-learn 0.23.2 or Later (only if you want to display the confusion matrix during training)
26 | - matplotlib 3.3.2 or Later (only if you want to display the confusion matrix during training)
27 |
28 | # Demo
29 | Run the demo using your webcam as follows.
30 | ```bash
31 | python app.py
32 | ```
33 |
34 | The following options can be specified when running the demo.
35 | ```
36 | --device
37 | Camera device number (default: 0)
38 |
39 | --width
40 | Capture width (default: 640)
41 |
42 | --height
43 | Capture height (default: 480)
44 |
45 | --min_detection_confidence
46 | Detection confidence threshold (default: 0.6)
47 |
48 | --disable_image_flip
49 | Disable horizontal flipping of the input image
50 | ```
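For example, to capture from camera device 1 at 1280x720 with image flipping disabled (a combination of the options listed above):

```bash
python app.py --device 1 --width 1280 --height 720 --disable_image_flip
```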
51 |
52 | # Directory
53 | ```
54 | .
55 | │ app.py
56 | │ keypoint_classification.ipynb
57 | │ point_history_classification.ipynb
58 | │ requirements.txt
59 | │ README.md
60 | │
61 | ├─model
62 | │ ├─keypoint_classifier
63 | │ │ │ tflite_to_onnx.sh
64 | │ │ │ make_argmax.py
65 | │ │ │ keypoint.csv
66 | │ │ │ keypoint_classifier.hdf5
67 | │ │ │ keypoint_classifier.py
68 | │ │ │ keypoint_classifier.tflite
69 | │ │ │ keypoint_classifier.onnx
70 | │ │ └─ keypoint_classifier_label.csv
71 | │ │
72 | │ └─point_history_classifier
73 | │ │ tflite_to_onnx.sh
74 | │ │ make_argmax.py
75 | │ │ point_history.csv
76 | │ │ point_history_classifier.hdf5
77 | │ │ point_history_classifier.py
78 | │ │ point_history_classifier.tflite
79 | │ │ point_history_classifier.onnx
80 | │ └─ point_history_classifier_label.csv
81 | │
82 | └─utils
83 | │ cvfpscalc.py
84 | └─ utils.py
85 | ```
86 | ### app.py
87 | This is the sample program for inference. It can also collect training data (keypoints) for hand sign recognition and
88 | training data (index finger coordinate history) for finger gesture recognition.
89 |
90 | ### keypoint_classification.ipynb
91 | Model training script for hand sign recognition.
92 |
93 | ### point_history_classification.ipynb
94 | Model training script for finger gesture recognition.
95 |
96 | ### model/keypoint_classifier
97 | This directory stores the files related to hand sign recognition.
98 | It contains the following files:
99 | - tflite-to-ONNX conversion script (tflite_to_onnx.sh)
100 | - ONNX component generation program (make_argmax.py)
101 | - Training data (keypoint.csv)
102 | - Trained model (keypoint_classifier.tflite)
103 | - Trained model (keypoint_classifier.onnx)
104 | - Label data (keypoint_classifier_label.csv)
105 | - Inference class (keypoint_classifier.py)
106 |
107 | ### model/point_history_classifier
108 | This directory stores the files related to finger gesture recognition.
109 | It contains the following files:
110 | - tflite-to-ONNX conversion script (tflite_to_onnx.sh)
111 | - ONNX component generation program (make_argmax.py)
112 | - Training data (point_history.csv)
113 | - Trained model (point_history_classifier.tflite)
114 | - Trained model (point_history_classifier.onnx)
115 | - Label data (point_history_classifier_label.csv)
116 | - Inference class (point_history_classifier.py)
117 |
118 | ### utils/cvfpscalc.py
119 | Module for FPS measurement.
120 |
121 | ### utils/utils.py
122 | A collection of image-processing helper functions.
123 |
124 | # Training
125 | For both hand sign recognition and finger gesture recognition, you can add or modify the training data and retrain the models.
126 |
127 | ### How to train hand sign recognition
128 | #### 1. Collect training data
129 | Press "k" to enter the keypoint logging mode ("MODE:Logging Key Point" is displayed).
130 |
131 | Press "0" to "9" to append keypoints to "model/keypoint_classifier/keypoint.csv" in the following format:
132 | 1st column: pressed number (used as the class ID), 2nd column: trackid, 3rd column onwards: keypoint coordinates
133 |
134 | The saved keypoint coordinates are the values after the preprocessing steps (up to step (4)) have been applied.
135 |
136 |
137 | By default, three classes of training data are included: paper (class ID: 0), rock (class ID: 1), and pointing (class ID: 2).
138 | Add class 3 and above, or delete the existing data in the CSV, to prepare your own training data as needed.
139 |
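For reference, the logged rows can be inspected with a few lines of Python (a minimal sketch, not part of the repository; the path and column layout follow this README and `logging_csv()` in app.py):

```python
import csv

# Each row: class ID, trackid, then 21 keypoints x (x, y) = 42 preprocessed values
with open('model/keypoint_classifier/keypoint.csv', newline='') as f:
    for class_id, trackid, *coords in csv.reader(f):
        print(f'class={class_id} trackid={trackid} values={len(coords)}')
        break  # inspect only the first logged row
```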
140 |
141 | #### 2. Train the model
142 | Open "[keypoint_classification.ipynb](keypoint_classification.ipynb)" in Jupyter Notebook and run the cells from top to bottom.
143 | If you change the number of classes in the training data, change the value of "NUM_CLASSES = 3" and update the labels in "model/keypoint_classifier/keypoint_classifier_label.csv" accordingly.
144 |
145 | #### X. Model structure
146 | The model prepared in "[keypoint_classification.ipynb](keypoint_classification.ipynb)" has the following structure.
147 |
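The corresponding Keras definition (reproduced from keypoint_classification.ipynb in this repository; about 1.1k parameters in total, per the notebook's model summary):

```python
import tensorflow as tf

NUM_CLASSES = 3

model = tf.keras.models.Sequential([
    tf.keras.layers.Input((21 * 2, )),              # 42 preprocessed keypoint values
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(20, activation='relu'),
    tf.keras.layers.Dropout(0.4),
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'),
])
```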
148 |
149 | ### How to train finger gesture recognition
150 | #### 1. Collect training data
151 | Press "h" to enter the fingertip coordinate history logging mode ("MODE:Logging Point History" is displayed).
152 |
153 | Press "0" to "9" to append coordinates to "model/point_history_classifier/point_history.csv" in the following format:
154 | 1st column: pressed number (used as the class ID), 2nd column: trackid, 3rd column onwards: coordinate history
155 |
156 | The saved coordinates are the values after the preprocessing steps (up to step (4)) have been applied.
157 |
158 | By default, four classes of training data are included: stationary (class ID: 0), clockwise (class ID: 1), counterclockwise (class ID: 2), and moving (class ID: 4).
159 | Add class 5 and above, or delete the existing data in the CSV, to prepare your own training data as needed.
160 |
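As with the hand sign data, the logged rows can be checked quickly (a minimal sketch, not part of the repository; the history covers up to 16 points per `history_length` in app.py):

```python
import csv

# Each row: class ID, trackid, then up to 16 history points x (x, y) relative coordinates
with open('model/point_history_classifier/point_history.csv', newline='') as f:
    for class_id, trackid, *coords in csv.reader(f):
        print(f'class={class_id} trackid={trackid} values={len(coords)}')
        break  # inspect only the first logged row
```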
161 |
162 | #### 2. Train the model
163 | Open "[point_history_classification.ipynb](point_history_classification.ipynb)" in Jupyter Notebook and run the cells from top to bottom.
164 | If you change the number of classes in the training data, change the value of "NUM_CLASSES = 4" and update the labels in "model/point_history_classifier/point_history_classifier_label.csv" accordingly.
165 |
166 | #### X. Model structure
167 | The model prepared in "[point_history_classification.ipynb](point_history_classification.ipynb)" has the following structure.
168 |
169 | A model using an LSTM is also provided. To use it, change "use_lstm = False" to "True".
170 |
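A minimal sketch of that switch (the flag name is taken from this README; everything else in the notebook stays unchanged):

```python
# In point_history_classification.ipynb
use_lstm = True  # False: plain MLP model (default), True: LSTM-based model
```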
171 |
172 | # Application example
173 | - None
174 |
175 | # Reference
176 | - [Kazuhito00/hand-gesture-recognition-using-mediapipe](https://github.com/Kazuhito00/hand-gesture-recognition-using-mediapipe)
177 | - [geaxgx/depthai_hand_tracker](https://github.com/geaxgx/depthai_hand_tracker)
178 | - [PINTO0309/simple-onnx-processing-tools](https://github.com/PINTO0309/simple-onnx-processing-tools)
179 | - https://github.com/PINTO0309/PINTO_model_zoo/tree/main/033_Hand_Detection_and_Tracking
180 |
181 | # Acknowledgments
182 | - hand-gesture-recognition-using-mediapipe: Kazuhito Takahashi (高橋かずひと) (https://twitter.com/KzhtTkhs)
183 | - [Apache-2.0 License](https://github.com/Kazuhito00/hand-gesture-recognition-using-mediapipe/blob/main/LICENSE)
184 | - depthai_hand_tracker: geaxgx (https://github.com/geaxgx/depthai_hand_tracker)
185 | - [MIT License](https://github.com/geaxgx/depthai_hand_tracker/blob/main/LICENSE.txt)
186 |
187 | # Author
188 | - hand-gesture-recognition-using-onnx: Katsuya Hyodo
189 | - [Apache-2.0 License](LICENSE)
190 |
191 | # ToDo
192 | - [x] LSTM ONNX
193 | - [ ] Implement BPF (body_pre_focusing)
194 | - [ ] Clean up variables and loops
195 | - [ ] Review CSV logging
196 | - [ ] DeepSort
197 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import csv
4 | import copy
5 | import argparse
6 | import itertools
7 | from math import degrees
8 | from collections import Counter
9 | from collections import deque
10 |
11 | import cv2 as cv
12 | import numpy as np
13 |
14 | from utils import CvFpsCalc
15 | from utils.utils import rotate_and_crop_rectangle
16 | from model import PalmDetection
17 | from model import HandLandmark
18 | from model import KeyPointClassifier
19 | from model import PointHistoryClassifier
20 |
21 |
22 | def get_args():
23 | parser = argparse.ArgumentParser()
24 |
25 | parser.add_argument(
26 | '-d',
27 | '--device',
28 | type=int,
29 | default=0,
30 | )
31 | parser.add_argument(
32 | '-im',
33 | '--image',
34 | type=str,
35 | default='',
36 | )
37 | parser.add_argument(
38 | '-wi',
39 | '--width',
40 | help='cap width',
41 | type=int,
42 | default=640,
43 | )
44 | parser.add_argument(
45 | '-he',
46 | '--height',
47 | help='cap height',
48 | type=int,
49 | default=480,
50 | )
51 | parser.add_argument(
52 | '-mdc',
53 | '--min_detection_confidence',
54 | help='min_detection_confidence',
55 | type=float,
56 | default=0.6,
57 | )
58 | parser.add_argument(
59 | '-dif',
60 | '--disable_image_flip',
61 | help='disable image flip',
62 | action='store_true',
63 | )
64 |
65 |
66 | args = parser.parse_args()
67 |
68 | return args
69 |
70 |
71 | def main():
72 | # Parse arguments #########################################################
73 | args = get_args()
74 |
75 | if not args.image:
76 | cap_device = args.device
77 | else:
78 | cap_device = args.image
79 | cap_width = args.width
80 | cap_height = args.height
81 | min_detection_confidence = args.min_detection_confidence
82 |
83 | lines_hand = [
84 | [0,1],[1,2],[2,3],[3,4],
85 | [0,5],[5,6],[6,7],[7,8],
86 | [5,9],[9,10],[10,11],[11,12],
87 | [9,13],[13,14],[14,15],[15,16],
88 | [13,17],[17,18],[18,19],[19,20],[0,17],
89 | ]
90 |
91 | # Prepare camera ###########################################################
92 | cap = cv.VideoCapture(cap_device)
93 | cap.set(cv.CAP_PROP_FRAME_WIDTH, cap_width)
94 | cap.set(cv.CAP_PROP_FRAME_HEIGHT, cap_height)
95 | cap_fps = cap.get(cv.CAP_PROP_FPS)
96 | fourcc = cv.VideoWriter_fourcc('m', 'p', '4', 'v')
97 | video_writer = cv.VideoWriter(
98 | filename='output.mp4',
99 | fourcc=fourcc,
100 | fps=cap_fps,
101 | frameSize=(cap_width, cap_height),
102 | )
103 |
104 | # Load models ##############################################################
105 | palm_detection = PalmDetection(score_threshold=min_detection_confidence)
106 | hand_landmark = HandLandmark()
107 |
108 | keypoint_classifier = KeyPointClassifier()
109 | point_history_classifier = PointHistoryClassifier()
110 |
111 | # Load labels ##############################################################
112 | with open(
113 | 'model/keypoint_classifier/keypoint_classifier_label.csv',
114 | encoding='utf-8-sig',
115 | ) as f:
116 | keypoint_classifier_labels = csv.reader(f)
117 | keypoint_classifier_labels = [
118 | row[0] for row in keypoint_classifier_labels
119 | ]
120 | with open(
121 | 'model/point_history_classifier/point_history_classifier_label.csv',
122 | encoding='utf-8-sig',
123 | ) as f:
124 | point_history_classifier_labels = csv.reader(f)
125 | point_history_classifier_labels = [
126 | row[0] for row in point_history_classifier_labels
127 | ]
128 |
129 | # FPS measurement module ###################################################
130 | cvFpsCalc = CvFpsCalc(buffer_len=10)
131 |
132 | # Coordinate history #######################################################
133 | history_length = 16
134 | # point_history = deque(maxlen=history_length)
135 | point_history = {}
136 | pre_point_history = {}
137 |
138 | # Finger gesture history ###################################################
139 | gesture_history_length = 10
140 | finger_gesture_history = {}
141 |
142 | # Latest palm-center coordinate history used for palm tracking #############
143 | # {
144 | # int(trackid1): [cx, cy],
145 | # int(trackid2): [cx, cy],
146 | # int(trackid3): [cx, cy],
147 | # :
148 | # }
149 | palm_trackid_cxcy = {}
150 |
151 | # #######################################################################
152 | mode = 0
153 | wh_ratio = cap_width / cap_height
154 |
155 | auto = False
156 | prev_number = -1
157 | image = None
158 |
159 | while True:
160 | fps = cvFpsCalc.get()
161 |
162 | # Key handling (ESC: quit) ############################################
163 | key = cv.waitKey(0) if args.image and image is not None else cv.waitKey(1)
164 | if key == 27: # ESC
165 | break
166 | number, mode, auto, prev_number = select_mode(key, mode, auto, prev_number)
167 |
168 | # Camera capture ######################################################
169 | ret, image = cap.read()
170 | if not ret:
171 | break
172 | image = image if args.disable_image_flip else cv.flip(image, 1) # Mirror the image
173 | debug_image = copy.deepcopy(image)
174 |
175 | # Run detection #######################################################
176 |
177 | # ============================================================= PalmDetection
178 | # Hand detection - single-batch processing
179 | hands = palm_detection(image)
180 | # hand: sqn_rr_size, rotation, sqn_rr_center_x, sqn_rr_center_y
181 |
182 | rects = []
183 | not_rotate_rects = []
184 | rects_tuple = None
185 | cropted_rotated_hands_images = []
186 |
187 | # If no hands are detected, reset the palm-center history used for tracking
188 | if len(hands) == 0:
189 | palm_trackid_cxcy = {}
190 | # Mapping between the palm-center tracking history and the bounding-box detection order
191 | palm_trackid_box_x1y1s = {}
192 |
193 | if len(hands) > 0:
194 | for hand in hands:
195 | # hand: sqn_rr_size, rotation, sqn_rr_center_x, sqn_rr_center_y
196 | sqn_rr_size = hand[0]
197 | rotation = hand[1]
198 | sqn_rr_center_x = hand[2]
199 | sqn_rr_center_y = hand[3]
200 |
201 | cx = int(sqn_rr_center_x * cap_width)
202 | cy = int(sqn_rr_center_y * cap_height)
203 | xmin = int((sqn_rr_center_x - (sqn_rr_size / 2)) * cap_width)
204 | xmax = int((sqn_rr_center_x + (sqn_rr_size / 2)) * cap_width)
205 | ymin = int((sqn_rr_center_y - (sqn_rr_size * wh_ratio / 2)) * cap_height)
206 | ymax = int((sqn_rr_center_y + (sqn_rr_size * wh_ratio / 2)) * cap_height)
207 | xmin = max(0, xmin)
208 | xmax = min(cap_width, xmax)
209 | ymin = max(0, ymin)
210 | ymax = min(cap_height, ymax)
211 | degree = degrees(rotation)
212 | # [boxcount, cx, cy, width, height, degree]
213 | rects.append([cx, cy, (xmax-xmin), (ymax-ymin), degree])
214 |
215 | rects = np.asarray(rects, dtype=np.float32)
216 |
217 | # Crop palm images with the rotation angle corrected to zero degrees
218 | cropted_rotated_hands_images = rotate_and_crop_rectangle(
219 | image=image,
220 | rects_tmp=rects,
221 | operation_when_cropping_out_of_range='padding',
222 | )
223 |
224 | # Debug ===============================================================
225 | for rect in rects:
226 | # Draw the rotation-aware region, red frame
227 | rects_tuple = ((rect[0], rect[1]), (rect[2], rect[3]), rect[4])
228 | box = cv.boxPoints(rects_tuple).astype(np.int0)
229 | cv.drawContours(debug_image, [box], 0,(0,0,255), 2, cv.LINE_AA)
230 |
231 | # Draw the axis-aligned (non-rotated) region, orange frame
232 | rcx = int(rect[0])
233 | rcy = int(rect[1])
234 | half_w = int(rect[2] // 2)
235 | half_h = int(rect[3] // 2)
236 | x1 = rcx - half_w
237 | y1 = rcy - half_h
238 | x2 = rcx + half_w
239 | y2 = rcy + half_h
240 | text_x = max(x1, 10)
241 | text_x = min(text_x, cap_width-120)
242 | text_y = max(y1-15, 45)
243 | text_y = min(text_y, cap_height-20)
244 | # [boxcount, rcx, rcy, x1, y1, x2, y2, height, degree]
245 | not_rotate_rects.append([rcx, rcy, x1, y1, x2, y2, 0])
246 | # Detection box size HxW
247 | cv.putText(
248 | debug_image,
249 | f'{y2-y1}x{x2-x1}',
250 | (text_x, text_y),
251 | cv.FONT_HERSHEY_SIMPLEX,
252 | 0.8,
253 | (0,0,0),
254 | 2,
255 | cv.LINE_AA,
256 | )
257 | cv.putText(
258 | debug_image,
259 | f'{y2-y1}x{x2-x1}',
260 | (text_x, text_y),
261 | cv.FONT_HERSHEY_SIMPLEX,
262 | 0.8,
263 | (59,255,255),
264 | 1,
265 | cv.LINE_AA,
266 | )
267 | # Draw the detection box
268 | cv.rectangle(
269 | debug_image,
270 | (x1,y1),
271 | (x2,y2),
272 | (0,128,255),
273 | 2,
274 | cv.LINE_AA,
275 | )
276 | # Draw the center point of the detection region
277 | cv.circle(
278 | debug_image,
279 | (rcx, rcy),
280 | 3,
281 | (0, 255, 255),
282 | -1,
283 | )
284 | """
285 | 手のひらトラッキング用手のひら中心座標最新履歴の保存
286 | 1. 過去履歴の中から基準点との距離が一番近い中心座標を抽出
287 | 2. 距離が100pxより離れている場合は新たな手のひらと認識させる
288 | 3. 距離が100px以下の場合は該当のtrackidを割り当てて過去履歴の中心座標を上書きする
289 | """
290 | # 1. From the past history, pick the center coordinate closest to the reference point
291 | base_point = np.asarray(
292 | [rcx, rcy],
293 | dtype=np.float32,
294 | )
295 | points = np.asarray(
296 | list(palm_trackid_cxcy.values()),
297 | dtype=np.float32,
298 | )
299 | if len(points) > 0:
300 | # Nearest-neighbor search
301 | diff_val = points - base_point
302 | all_points_distance = np.linalg.norm(diff_val, axis=1)
303 | nearest_trackid = np.argmin(all_points_distance)
304 | nearest_distance = all_points_distance[nearest_trackid]
305 | new_trackid = int(nearest_trackid) + 1
306 | # 2. If the distance is greater than 100 px, treat it as a new palm
307 | # 3. If the distance is 100 px or less, assign the matching trackid and overwrite the center coordinate in the history
308 | if nearest_distance > 100:
309 | # Generate a new trackid as the current maximum trackid + 1
310 | new_trackid = next(iter(reversed(palm_trackid_cxcy))) + 1
311 | else:
312 | # Initial trackid
313 | new_trackid = 1
314 |
315 | # Update or add the latest coordinate in the palm-center tracking history
316 | palm_trackid_cxcy[new_trackid] = [rcx, rcy]
317 | # Align the bounding-box detection order with the trackid order
318 | # box_x1y1x2y2_palm_trackids.append([x1, y1, x2, y2, new_trackid])
319 | palm_trackid_box_x1y1s[new_trackid] = [x1, y1]
320 | # Debug ===============================================================
321 |
322 | # ============================================================= HandLandmark
323 | if len(cropted_rotated_hands_images) > 0:
324 |
325 | # Inference HandLandmark - batch processing
326 | hand_landmarks, rotated_image_size_leftrights = hand_landmark(
327 | images=cropted_rotated_hands_images,
328 | rects=rects,
329 | )
330 |
331 | if len(hand_landmarks) > 0:
332 | # Draw
333 | pre_processed_landmarks = []
334 | pre_processed_point_histories = []
335 | for (trackid, x1y1), landmark, rotated_image_size_leftright, not_rotate_rect in \
336 | zip(palm_trackid_box_x1y1s.items(), hand_landmarks, rotated_image_size_leftrights, not_rotate_rects):
337 |
338 | x1, y1 = x1y1
339 | rotated_image_width, _, left_hand_0_or_right_hand_1 = rotated_image_size_leftright
340 | thick_coef = rotated_image_width / 400
341 | lines = np.asarray(
342 | [
343 | np.array([landmark[point] for point in line]).astype(np.int32) for line in lines_hand
344 | ]
345 | )
346 | radius = int(1+thick_coef*5)
347 | cv.polylines(
348 | debug_image,
349 | lines,
350 | False,
351 | (255, 0, 0),
352 | int(radius),
353 | cv.LINE_AA,
354 | )
355 | _ = [cv.circle(debug_image, (int(x), int(y)), radius, (0,128,255), -1) for x,y in landmark[:,:2]]
356 | left_hand_0_or_right_hand_1 = left_hand_0_or_right_hand_1 if args.disable_image_flip else (1 - left_hand_0_or_right_hand_1)
357 | handedness = 'Left ' if left_hand_0_or_right_hand_1 == 0 else 'Right'
358 | _, _, x1, y1, _, _, _ = not_rotate_rect
359 | text_x = max(x1, 10)
360 | text_x = min(text_x, cap_width-120)
361 | text_y = max(y1-70, 20)
362 | text_y = min(text_y, cap_height-70)
363 | cv.putText(
364 | debug_image,
365 | f'trackid:{trackid} {handedness}',
366 | (text_x, text_y),
367 | cv.FONT_HERSHEY_SIMPLEX,
368 | 0.8,
369 | (0,0,0),
370 | 2,
371 | cv.LINE_AA,
372 | )
373 | cv.putText(
374 | debug_image,
375 | f'trackid:{trackid} {handedness}',
376 | (text_x, text_y),
377 | cv.FONT_HERSHEY_SIMPLEX,
378 | 0.8,
379 | (59,255,255),
380 | 1,
381 | cv.LINE_AA,
382 | )
383 |
384 | # Convert to relative / normalized coordinates
385 | """
386 | pre_processed_landmark: np.ndarray [42], [x,y]x21
387 | """
388 | pre_processed_landmark = pre_process_landmark(
389 | landmark,
390 | )
391 | pre_processed_landmarks.append(pre_processed_landmark)
392 |
393 | """
394 | point_history: dict
395 | {
396 | int(trackid1): [[x, y],[x, y],[x, y],[x, y], ...],
397 | int(trackid2): [[x, y],[x, y], ...],
398 | int(trackid3): [[x, y],[x, y],[x, y], ...],
399 | :
400 | }
401 | ↓
402 | pre_processed_point_histories: List
403 | [
404 | [rx, ry, rx, ry, rx, ry, rx, ry, ...],
405 | [rx, ry, rx, ry, ...],
406 | [rx, ry, rx, ry, rx, ry, ...],
407 | :
408 | ]
409 | """
410 | # Convert the index finger trajectory to relative coordinates
411 | pre_processed_point_histories = pre_process_point_history(
412 | image_width=debug_image.shape[1],
413 | image_height=debug_image.shape[0],
414 | point_history=point_history,
415 | )
416 |
417 | # Save training data
418 | logging_csv(
419 | number,
420 | mode,
421 | trackid,
422 | pre_processed_landmark,
423 | pre_processed_point_histories,
424 | )
425 |
426 | # Hand sign classification - batch processing
427 | hand_sign_ids = keypoint_classifier(
428 | np.asarray(pre_processed_landmarks, dtype=np.float32)
429 | )
430 | for (trackid, x1y1), landmark, hand_sign_id in zip(palm_trackid_box_x1y1s.items(), hand_landmarks, hand_sign_ids):
431 | x1, y1 = x1y1
432 | point_history.setdefault(trackid, deque(maxlen=history_length))
433 | if hand_sign_id == 2: # Pointing sign
434 | point_history[trackid].append(list(landmark[8])) # Index finger coordinates
435 | else:
436 | point_history[trackid].append([0, 0])
437 |
438 | """
439 | To avoid the index finger trajectory lingering on the display,
440 | clear the index finger XY coordinate history when a palm is no longer tracked (has left the field of view).
441 | A trackid whose entire current trajectory exactly matches its previous trajectory is judged unchanged and is removed from the history.
442 | point_history: latest 16 trajectory points
443 | pre_point_history: previous 16 trajectory points
444 | """
445 | if len(pre_point_history) > 0:
446 | temp_point_history = copy.deepcopy(point_history)
447 | for track_id, points in temp_point_history.items():
448 | if track_id in pre_point_history:
449 | pre_points = pre_point_history[track_id]
450 | if points == pre_points:
451 | _ = point_history.pop(track_id, None)
452 | pre_point_history = copy.deepcopy(point_history)
453 |
454 | # Finger gesture classification - batch processing
455 | finger_gesture_ids = None
456 | temp_trackid_x1y1s = {}
457 | temp_pre_processed_point_history = []
458 | for (trackid, x1y1), pre_processed_point_history in zip(palm_trackid_box_x1y1s.items(), pre_processed_point_histories):
459 | point_history_len = len(pre_processed_point_history)
460 | if point_history_len > 0 and point_history_len % (history_length * 2) == 0:
461 | temp_trackid_x1y1s[trackid] = x1y1
462 | temp_pre_processed_point_history.append(pre_processed_point_history)
463 | if len(temp_pre_processed_point_history) > 0:
464 | finger_gesture_ids = point_history_classifier(
465 | temp_pre_processed_point_history,
466 | )
467 |
468 | # Compute the most frequent gesture ID among recent detections
469 | if finger_gesture_ids is not None:
470 | for (trackid, x1y1), finger_gesture_id in zip(temp_trackid_x1y1s.items(), finger_gesture_ids):
471 | x1, y1 = x1y1
472 | trackid_str = str(trackid)
473 | finger_gesture_history.setdefault(trackid_str, deque(maxlen=gesture_history_length))
474 | finger_gesture_history[trackid_str].append(int(finger_gesture_id))
475 | most_common_fg_id = Counter(finger_gesture_history[trackid_str]).most_common()
476 | text_x = max(x1, 10)
477 | text_x = min(text_x, cap_width-120)
478 | text_y = max(y1-45, 20)
479 | text_y = min(text_y, cap_height-45)
480 | classifier_label = point_history_classifier_labels[most_common_fg_id[0][0]]
481 | # print(f'trackid: {trackid} [x1,y1]: [{x1},{y1}] finger_gesture_id: {classifier_label}')
482 | cv.putText(
483 | debug_image,
484 | f'{classifier_label}',
485 | (text_x, text_y),
486 | cv.FONT_HERSHEY_SIMPLEX,
487 | 0.8,
488 | (0,0,0),
489 | 2,
490 | cv.LINE_AA,
491 | )
492 | cv.putText(
493 | debug_image,
494 | f'{classifier_label}',
495 | (text_x, text_y),
496 | cv.FONT_HERSHEY_SIMPLEX,
497 | 0.8,
498 | (59,255,255),
499 | 1,
500 | cv.LINE_AA,
501 | )
502 |
503 | else:
504 | point_history = {}
505 |
506 | else:
507 | point_history = {}
508 |
509 | debug_image = draw_point_history(debug_image, point_history)
510 | debug_image = draw_info(debug_image, fps, mode, number, auto)
511 |
512 | # Update the display ##################################################
513 | cv.imshow('Hand Gesture Recognition', debug_image)
514 | video_writer.write(debug_image)
515 |
516 | if video_writer:
517 | video_writer.release()
518 | if cap:
519 | cap.release()
520 | cv.destroyAllWindows()
521 |
522 |
523 | def select_mode(key, mode, auto=False, prev_number=-1):
524 | number = -1
525 | if 48 <= key <= 57: # 0 ~ 9
526 | number = key - 48
527 | prev_number = number
528 | if key == 110: # n
529 | mode = 0
530 | if key == 107: # k
531 | mode = 1
532 | if key == 104: # h
533 | mode = 2
534 | if key == 97: # a
535 | auto = not auto
536 | if auto == True:
537 | if prev_number != -1:
538 | number = prev_number
539 | else:
540 | prev_number = -1
541 |
542 | return number, mode, auto, prev_number
543 |
544 |
545 | def pre_process_landmark(landmark_list):
546 | if len(landmark_list) == 0:
547 | return []
548 |
549 | temp_landmark_list = copy.deepcopy(landmark_list)
550 | # Convert to relative coordinates
551 | base_x, base_y = temp_landmark_list[0][0], temp_landmark_list[0][1]
552 | temp_landmark_list = [
553 | [temp_landmark[0] - base_x, temp_landmark[1] - base_y] for temp_landmark in temp_landmark_list
554 | ]
555 | # Flatten to a 1-D list
556 | temp_landmark_list = list(
557 | itertools.chain.from_iterable(temp_landmark_list)
558 | )
559 | # Normalize
560 | max_value = max(list(map(abs, temp_landmark_list)))
561 |
562 | def normalize_(n):
563 | return n / max_value
564 |
565 | temp_landmark_list = list(map(normalize_, temp_landmark_list))
566 | return temp_landmark_list
567 |
568 |
569 | def pre_process_point_history(
570 | image_width: int,
571 | image_height: int,
572 | point_history: dict,
573 | ):
574 | """pre_process_point_history
575 |
576 | Parameters
577 | ----------
578 | image_width: int
579 | Input image width
580 |
581 | image_height: int
582 | Input image height
583 |
584 | point_history: dict
585 | Index finger XY coordinate history per trackid (detected palm)
586 | {
587 | int(trackid1): [[x, y],[x, y],[x, y],[x, y], ...],
588 | int(trackid2): [[x, y],[x, y], ...],
589 | int(trackid3): [[x, y],[x, y],[x, y], ...],
590 | :
591 | }
592 |
593 | Returns
594 | -------
595 | relative_coordinate_list_by_trackid: List
596 | [
597 | [rx, ry, rx, ry, rx, ry, rx, ry, ...],
598 | [rx, ry, rx, ry, ...],
599 | [rx, ry, rx, ry, rx, ry, ...],
600 | :
601 | ]
602 | """
603 | if len(point_history) == 0:
604 | return []
605 |
606 | temp_point_history = copy.deepcopy(point_history)
607 | relative_coordinate_list_by_trackid = []
608 |
609 | # Convert to relative coordinates per trackid
610 | for trackid, points in temp_point_history.items():
611 | base_x, base_y = points[0][0], points[0][1]
612 | relative_coordinate_list = [
613 | [
614 | (point[0] - base_x) / image_width,
615 | (point[1] - base_y) / image_height,
616 | ] for point in points
617 | ]
618 | # Flatten to a 1-D list
619 | relative_coordinate_list_1d = list(
620 | itertools.chain.from_iterable(relative_coordinate_list)
621 | )
622 | relative_coordinate_list_by_trackid.append(relative_coordinate_list_1d)
623 | return relative_coordinate_list_by_trackid
624 |
625 |
626 | def logging_csv(number, mode, trackid, landmark_list, point_histories):
627 | if mode == 0:
628 | pass
629 | if mode == 1 and (0 <= number <= 9):
630 | csv_path = 'model/keypoint_classifier/keypoint.csv'
631 | with open(csv_path, 'a', newline="") as f:
632 | writer = csv.writer(f)
633 | writer.writerow([number, trackid, *landmark_list])
634 | if mode == 2 and (0 <= number <= 9):
635 | csv_path = 'model/point_history_classifier/point_history.csv'
636 | with open(csv_path, 'a', newline="") as f:
637 | writer = csv.writer(f)
638 | for point_history in point_histories:
639 | writer.writerow([number, trackid, *point_history])
640 |
641 |
642 | def draw_info_text(
643 | image,
644 | brect,
645 | handedness,
646 | hand_sign_text,
647 | finger_gesture_text
648 | ):
649 | info_text = handedness
650 | if hand_sign_text != "":
651 | info_text = f'{handedness}:{hand_sign_text}'
652 | cv.putText(
653 | image,
654 | info_text,
655 | (brect[0] + 5, brect[1] - 4),
656 | cv.FONT_HERSHEY_SIMPLEX,
657 | 0.6,
658 | (255, 255, 255),
659 | 1,
660 | cv.LINE_AA,
661 | )
662 |
663 | if finger_gesture_text != "":
664 | cv.putText(
665 | image,
666 | f'Finger Gesture:{finger_gesture_text}',
667 | (10, 60),
668 | cv.FONT_HERSHEY_SIMPLEX,
669 | 1.0,
670 | (0, 0, 0),
671 | 4,
672 | cv.LINE_AA,
673 | )
674 | cv.putText(
675 | image,
676 | f'Finger Gesture:{finger_gesture_text}',
677 | (10, 60),
678 | cv.FONT_HERSHEY_SIMPLEX,
679 | 1.0,
680 | (255, 255, 255),
681 | 2,
682 | cv.LINE_AA,
683 | )
684 |
685 | return image
686 |
687 |
688 | def draw_point_history(image, point_history):
689 | _ = [
690 | cv.circle(image, (point[0], point[1]), 1 + int(index / 2), (152, 251, 152), 2) \
691 | for trackid, points in point_history.items() \
692 | for index, point in enumerate(points) if point[0] != 0 and point[1] != 0
693 | ]
694 | return image
695 |
696 |
697 | def draw_info(image, fps, mode, number, auto):
698 | cv.putText(
699 | image,
700 | f'FPS:{str(fps)}',
701 | (10, 30),
702 | cv.FONT_HERSHEY_SIMPLEX,
703 | 1.0,
704 | (0, 0, 0),
705 | 4,
706 | cv.LINE_AA,
707 | )
708 | cv.putText(
709 | image,
710 | f'FPS:{str(fps)}',
711 | (10, 30),
712 | cv.FONT_HERSHEY_SIMPLEX,
713 | 1.0,
714 | (255, 255, 255),
715 | 2,
716 | cv.LINE_AA,
717 | )
718 |
719 | mode_string = ['Logging Key Point', 'Logging Point History']
720 | if 1 <= mode <= 2:
721 | cv.putText(
722 | image,
723 | f'MODE:{mode_string[mode - 1]}',
724 | (10, 90),
725 | cv.FONT_HERSHEY_SIMPLEX,
726 | 0.6,
727 | (255, 255, 255),
728 | 1,
729 | cv.LINE_AA,
730 | )
731 | if 0 <= number <= 9:
732 | cv.putText(
733 | image,
734 | f'NUM:{str(number)}',
735 | (10, 110),
736 | cv.FONT_HERSHEY_SIMPLEX,
737 | 0.6,
738 | (255, 255, 255),
739 | 1,
740 | cv.LINE_AA,
741 | )
742 | cv.putText(
743 | image,
744 | f'AUTO:{str(auto)}',
745 | (10, 130),
746 | cv.FONT_HERSHEY_SIMPLEX,
747 | 0.6,
748 | (255, 255, 255),
749 | 1,
750 | cv.LINE_AA,
751 | )
752 | return image
753 |
754 |
755 | if __name__ == '__main__':
756 | main()
757 |
--------------------------------------------------------------------------------
/hand.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/hand.png
--------------------------------------------------------------------------------
/keypoint_classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import csv\n",
10 | "\n",
11 | "import numpy as np\n",
12 | "import tensorflow as tf\n",
13 | "from sklearn.model_selection import train_test_split\n",
14 | "\n",
15 | "RANDOM_SEED = 42"
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {},
21 | "source": [
22 | "# 各パス指定"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 2,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "dataset = 'model/keypoint_classifier/keypoint.csv'\n",
32 | "model_save_path = 'model/keypoint_classifier/keypoint_classifier.hdf5'"
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {},
38 | "source": [
39 | "# 分類数設定"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "NUM_CLASSES = 3"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "# 学習データ読み込み"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 4,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "X_dataset = np.loadtxt(dataset, delimiter=',', dtype='float32', usecols=list(range(1, (21 * 2) + 1)))"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 5,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "y_dataset = np.loadtxt(dataset, delimiter=',', dtype='int32', usecols=(0))"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 6,
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "X_train, X_test, y_train, y_test = train_test_split(X_dataset, y_dataset, train_size=0.75, random_state=RANDOM_SEED)"
83 | ]
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "metadata": {},
88 | "source": [
89 | "# モデル構築"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 7,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "model = tf.keras.models.Sequential([\n",
99 | " tf.keras.layers.Input((21 * 2, )),\n",
100 | " tf.keras.layers.Dropout(0.2),\n",
101 | " tf.keras.layers.Dense(20, activation='relu'),\n",
102 | " tf.keras.layers.Dropout(0.4),\n",
103 | " tf.keras.layers.Dense(10, activation='relu'),\n",
104 | " tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')\n",
105 | "])"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 8,
111 | "metadata": {},
112 | "outputs": [
113 | {
114 | "name": "stdout",
115 | "output_type": "stream",
116 | "text": [
117 | "Model: \"sequential\"\n",
118 | "_________________________________________________________________\n",
119 | "Layer (type) Output Shape Param # \n",
120 | "=================================================================\n",
121 | "dropout (Dropout) (None, 42) 0 \n",
122 | "_________________________________________________________________\n",
123 | "dense (Dense) (None, 20) 860 \n",
124 | "_________________________________________________________________\n",
125 | "dropout_1 (Dropout) (None, 20) 0 \n",
126 | "_________________________________________________________________\n",
127 | "dense_1 (Dense) (None, 10) 210 \n",
128 | "_________________________________________________________________\n",
129 | "dense_2 (Dense) (None, 3) 33 \n",
130 | "=================================================================\n",
131 | "Total params: 1,103\n",
132 | "Trainable params: 1,103\n",
133 | "Non-trainable params: 0\n",
134 | "_________________________________________________________________\n"
135 | ]
136 | }
137 | ],
138 | "source": [
139 | "model.summary() # tf.keras.utils.plot_model(model, show_shapes=True)"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": 9,
145 | "metadata": {},
146 | "outputs": [],
147 | "source": [
148 | "# モデルチェックポイントのコールバック\n",
149 | "cp_callback = tf.keras.callbacks.ModelCheckpoint(\n",
150 | " model_save_path, verbose=1, save_weights_only=False)\n",
151 | "# 早期打ち切り用コールバック\n",
152 | "es_callback = tf.keras.callbacks.EarlyStopping(patience=20, verbose=1)"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": 10,
158 | "metadata": {},
159 | "outputs": [],
160 | "source": [
161 | "# モデルコンパイル\n",
162 | "model.compile(\n",
163 | " optimizer='adam',\n",
164 | " loss='sparse_categorical_crossentropy',\n",
165 | " metrics=['accuracy']\n",
166 | ")"
167 | ]
168 | },
169 | {
170 | "cell_type": "markdown",
171 | "metadata": {},
172 | "source": [
173 | "# モデル訓練"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": 11,
179 | "metadata": {
180 | "scrolled": true
181 | },
182 | "outputs": [
183 | {
184 | "name": "stdout",
185 | "output_type": "stream",
186 | "text": [
187 | "Epoch 1/1000\n",
188 | " 1/27 [>.............................] - ETA: 0s - loss: 1.1295 - accuracy: 0.3203\n",
189 | "Epoch 00001: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
190 | "27/27 [==============================] - 0s 11ms/step - loss: 1.1004 - accuracy: 0.3602 - val_loss: 1.0431 - val_accuracy: 0.5220\n",
191 | "Epoch 2/1000\n",
192 | " 1/27 [>.............................] - ETA: 0s - loss: 1.0440 - accuracy: 0.4844\n",
193 | "Epoch 00002: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
194 | "27/27 [==============================] - 0s 3ms/step - loss: 1.0503 - accuracy: 0.4297 - val_loss: 0.9953 - val_accuracy: 0.6397\n",
195 | "Epoch 3/1000\n",
196 | " 1/27 [>.............................] - ETA: 0s - loss: 1.0043 - accuracy: 0.5312\n",
197 | "Epoch 00003: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
198 | "27/27 [==============================] - 0s 4ms/step - loss: 1.0210 - accuracy: 0.4582 - val_loss: 0.9545 - val_accuracy: 0.6523\n",
199 | "Epoch 4/1000\n",
200 | " 1/27 [>.............................] - ETA: 0s - loss: 0.9503 - accuracy: 0.5625\n",
201 | "Epoch 00004: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
202 | "27/27 [==============================] - 0s 4ms/step - loss: 0.9906 - accuracy: 0.5022 - val_loss: 0.9168 - val_accuracy: 0.6721\n",
203 | "Epoch 5/1000\n",
204 | " 1/27 [>.............................] - ETA: 0s - loss: 0.9562 - accuracy: 0.5469\n",
205 | "Epoch 00005: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
206 | "27/27 [==============================] - 0s 3ms/step - loss: 0.9654 - accuracy: 0.5340 - val_loss: 0.8791 - val_accuracy: 0.7017\n",
207 | "Epoch 6/1000\n",
208 | " 1/27 [>.............................] - ETA: 0s - loss: 0.9184 - accuracy: 0.5938\n",
209 | "Epoch 00006: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
210 | "27/27 [==============================] - 0s 3ms/step - loss: 0.9256 - accuracy: 0.5577 - val_loss: 0.8344 - val_accuracy: 0.7269\n",
211 | "Epoch 7/1000\n",
212 | "27/27 [==============================] - ETA: 0s - loss: 0.9050 - accuracy: 0.5715\n",
213 | "Epoch 00007: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
214 | "27/27 [==============================] - 0s 4ms/step - loss: 0.9050 - accuracy: 0.5715 - val_loss: 0.7887 - val_accuracy: 0.7646\n",
215 | "Epoch 8/1000\n",
216 | " 1/27 [>.............................] - ETA: 0s - loss: 0.9135 - accuracy: 0.5547\n",
217 | "Epoch 00008: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
218 | "27/27 [==============================] - 0s 3ms/step - loss: 0.8642 - accuracy: 0.5993 - val_loss: 0.7414 - val_accuracy: 0.7996\n",
219 | "Epoch 9/1000\n",
220 | " 1/27 [>.............................] - ETA: 0s - loss: 0.8002 - accuracy: 0.6172\n",
221 | "Epoch 00009: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
222 | "27/27 [==============================] - 0s 3ms/step - loss: 0.8258 - accuracy: 0.6263 - val_loss: 0.6881 - val_accuracy: 0.8149\n",
223 | "Epoch 10/1000\n",
224 | " 1/27 [>.............................] - ETA: 0s - loss: 0.8056 - accuracy: 0.6328\n",
225 | "Epoch 00010: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
226 | "27/27 [==============================] - 0s 3ms/step - loss: 0.8008 - accuracy: 0.6341 - val_loss: 0.6461 - val_accuracy: 0.8239\n",
227 | "Epoch 11/1000\n",
228 | " 1/27 [>.............................] - ETA: 0s - loss: 0.7746 - accuracy: 0.6719\n",
229 | "Epoch 00011: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
230 | "27/27 [==============================] - 0s 3ms/step - loss: 0.7771 - accuracy: 0.6491 - val_loss: 0.6143 - val_accuracy: 0.8266\n",
231 | "Epoch 12/1000\n",
232 | " 1/27 [>.............................] - ETA: 0s - loss: 0.7242 - accuracy: 0.7109\n",
233 | "Epoch 00012: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
234 | "27/27 [==============================] - 0s 4ms/step - loss: 0.7490 - accuracy: 0.6650 - val_loss: 0.5740 - val_accuracy: 0.8320\n",
235 | "Epoch 13/1000\n",
236 | " 1/27 [>.............................] - ETA: 0s - loss: 0.8363 - accuracy: 0.6328\n",
237 | "Epoch 00013: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
238 | "27/27 [==============================] - 0s 4ms/step - loss: 0.7397 - accuracy: 0.6731 - val_loss: 0.5465 - val_accuracy: 0.8446\n",
239 | "Epoch 14/1000\n",
240 | " 1/27 [>.............................] - ETA: 0s - loss: 0.7634 - accuracy: 0.6172\n",
241 | "Epoch 00014: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
242 | "27/27 [==============================] - 0s 3ms/step - loss: 0.7190 - accuracy: 0.6883 - val_loss: 0.5202 - val_accuracy: 0.8589\n",
243 | "Epoch 15/1000\n",
244 | " 1/27 [>.............................] - ETA: 0s - loss: 0.6699 - accuracy: 0.6875\n",
245 | "Epoch 00015: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
246 | "27/27 [==============================] - 0s 3ms/step - loss: 0.7077 - accuracy: 0.6973 - val_loss: 0.4944 - val_accuracy: 0.8652\n",
247 | "Epoch 16/1000\n",
248 | " 1/27 [>.............................] - ETA: 0s - loss: 0.6447 - accuracy: 0.7500\n",
249 | "Epoch 00016: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
250 | "27/27 [==============================] - 0s 3ms/step - loss: 0.6903 - accuracy: 0.6928 - val_loss: 0.4781 - val_accuracy: 0.8805\n",
251 | "Epoch 17/1000\n",
252 | " 1/27 [>.............................] - ETA: 0s - loss: 0.7165 - accuracy: 0.6875\n",
253 | "Epoch 00017: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
254 | "27/27 [==============================] - 0s 3ms/step - loss: 0.6919 - accuracy: 0.6973 - val_loss: 0.4696 - val_accuracy: 0.8895\n",
255 | "Epoch 18/1000\n",
256 | " 1/27 [>.............................] - ETA: 0s - loss: 0.6268 - accuracy: 0.7422\n",
257 | "Epoch 00018: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
258 | "27/27 [==============================] - 0s 3ms/step - loss: 0.6498 - accuracy: 0.7303 - val_loss: 0.4440 - val_accuracy: 0.8967\n",
259 | "Epoch 19/1000\n",
260 | "27/27 [==============================] - ETA: 0s - loss: 0.6499 - accuracy: 0.7261\n",
261 | "Epoch 00019: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
262 | "27/27 [==============================] - 0s 4ms/step - loss: 0.6499 - accuracy: 0.7261 - val_loss: 0.4254 - val_accuracy: 0.9039\n",
263 | "Epoch 20/1000\n",
264 | "26/27 [===========================>..] - ETA: 0s - loss: 0.6386 - accuracy: 0.7236\n",
265 | "Epoch 00020: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
266 | "27/27 [==============================] - 0s 4ms/step - loss: 0.6415 - accuracy: 0.7228 - val_loss: 0.4082 - val_accuracy: 0.9093\n",
267 | "Epoch 21/1000\n",
268 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5631 - accuracy: 0.7500\n",
269 | "Epoch 00021: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
270 | "27/27 [==============================] - 0s 3ms/step - loss: 0.6484 - accuracy: 0.7180 - val_loss: 0.4114 - val_accuracy: 0.9173\n",
271 | "Epoch 22/1000\n",
272 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5775 - accuracy: 0.7812\n",
273 | "Epoch 00022: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
274 | "27/27 [==============================] - 0s 3ms/step - loss: 0.6154 - accuracy: 0.7480 - val_loss: 0.3907 - val_accuracy: 0.9218\n",
275 | "Epoch 23/1000\n",
276 | "25/27 [==========================>...] - ETA: 0s - loss: 0.5967 - accuracy: 0.7588\n",
277 | "Epoch 00023: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
278 | "27/27 [==============================] - 0s 4ms/step - loss: 0.5971 - accuracy: 0.7582 - val_loss: 0.3763 - val_accuracy: 0.9227\n",
279 | "Epoch 24/1000\n",
280 | "26/27 [===========================>..] - ETA: 0s - loss: 0.6064 - accuracy: 0.7569\n",
281 | "Epoch 00024: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
282 | "27/27 [==============================] - 0s 4ms/step - loss: 0.6066 - accuracy: 0.7567 - val_loss: 0.3714 - val_accuracy: 0.9254\n",
283 | "Epoch 25/1000\n",
284 | " 1/27 [>.............................] - ETA: 0s - loss: 0.6211 - accuracy: 0.7422\n",
285 | "Epoch 00025: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
286 | "27/27 [==============================] - 0s 7ms/step - loss: 0.5954 - accuracy: 0.7579 - val_loss: 0.3611 - val_accuracy: 0.9353\n",
287 | "Epoch 26/1000\n",
288 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5621 - accuracy: 0.7812\n",
289 | "Epoch 00026: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
290 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5818 - accuracy: 0.7737 - val_loss: 0.3498 - val_accuracy: 0.9380\n",
291 | "Epoch 27/1000\n",
292 | " 1/27 [>.............................] - ETA: 0s - loss: 0.6431 - accuracy: 0.7500\n",
293 | "Epoch 00027: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
294 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5882 - accuracy: 0.7648 - val_loss: 0.3355 - val_accuracy: 0.9416\n"
295 | ]
296 | },
297 | {
298 | "name": "stdout",
299 | "output_type": "stream",
300 | "text": [
301 | "Epoch 28/1000\n",
302 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5633 - accuracy: 0.8203\n",
303 | "Epoch 00028: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
304 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5746 - accuracy: 0.7702 - val_loss: 0.3273 - val_accuracy: 0.9425\n",
305 | "Epoch 29/1000\n",
306 | "27/27 [==============================] - ETA: 0s - loss: 0.5856 - accuracy: 0.7651\n",
307 | "Epoch 00029: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
308 | "27/27 [==============================] - 0s 4ms/step - loss: 0.5856 - accuracy: 0.7651 - val_loss: 0.3237 - val_accuracy: 0.9434\n",
309 | "Epoch 30/1000\n",
310 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5524 - accuracy: 0.7812\n",
311 | "Epoch 00030: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
312 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5576 - accuracy: 0.7788 - val_loss: 0.3203 - val_accuracy: 0.9452\n",
313 | "Epoch 31/1000\n",
314 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5643 - accuracy: 0.7578\n",
315 | "Epoch 00031: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
316 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5450 - accuracy: 0.7773 - val_loss: 0.3111 - val_accuracy: 0.9443\n",
317 | "Epoch 32/1000\n",
318 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5507 - accuracy: 0.7812\n",
319 | "Epoch 00032: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
320 | "27/27 [==============================] - 0s 6ms/step - loss: 0.5574 - accuracy: 0.7860 - val_loss: 0.3017 - val_accuracy: 0.9434\n",
321 | "Epoch 33/1000\n",
322 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5302 - accuracy: 0.8125\n",
323 | "Epoch 00033: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
324 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5444 - accuracy: 0.7905 - val_loss: 0.2917 - val_accuracy: 0.9479\n",
325 | "Epoch 34/1000\n",
326 | "27/27 [==============================] - ETA: 0s - loss: 0.5421 - accuracy: 0.7848\n",
327 | "Epoch 00034: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
328 | "27/27 [==============================] - 0s 4ms/step - loss: 0.5421 - accuracy: 0.7848 - val_loss: 0.2863 - val_accuracy: 0.9470\n",
329 | "Epoch 35/1000\n",
330 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4633 - accuracy: 0.8125\n",
331 | "Epoch 00035: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
332 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5331 - accuracy: 0.7980 - val_loss: 0.2804 - val_accuracy: 0.9506\n",
333 | "Epoch 36/1000\n",
334 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5972 - accuracy: 0.7812\n",
335 | "Epoch 00036: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
336 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5255 - accuracy: 0.7968 - val_loss: 0.2774 - val_accuracy: 0.9479\n",
337 | "Epoch 37/1000\n",
338 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5851 - accuracy: 0.7578\n",
339 | "Epoch 00037: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
340 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5284 - accuracy: 0.7947 - val_loss: 0.2681 - val_accuracy: 0.9497\n",
341 | "Epoch 38/1000\n",
342 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4775 - accuracy: 0.7812\n",
343 | "Epoch 00038: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
344 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5143 - accuracy: 0.7995 - val_loss: 0.2678 - val_accuracy: 0.9479\n",
345 | "Epoch 39/1000\n",
346 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5495 - accuracy: 0.7812\n",
347 | "Epoch 00039: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
348 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5250 - accuracy: 0.7935 - val_loss: 0.2657 - val_accuracy: 0.9470\n",
349 | "Epoch 40/1000\n",
350 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5244 - accuracy: 0.8203\n",
351 | "Epoch 00040: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
352 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5195 - accuracy: 0.8001 - val_loss: 0.2606 - val_accuracy: 0.9524\n",
353 | "Epoch 41/1000\n",
354 | " 1/27 [>.............................] - ETA: 0s - loss: 0.6996 - accuracy: 0.6953\n",
355 | "Epoch 00041: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
356 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5304 - accuracy: 0.7956 - val_loss: 0.2572 - val_accuracy: 0.9515\n",
357 | "Epoch 42/1000\n",
358 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4843 - accuracy: 0.8281\n",
359 | "Epoch 00042: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
360 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5115 - accuracy: 0.8091 - val_loss: 0.2513 - val_accuracy: 0.9524\n",
361 | "Epoch 43/1000\n",
362 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3830 - accuracy: 0.8594\n",
363 | "Epoch 00043: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
364 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4939 - accuracy: 0.8133 - val_loss: 0.2423 - val_accuracy: 0.9551\n",
365 | "Epoch 44/1000\n",
366 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4531 - accuracy: 0.7969\n",
367 | "Epoch 00044: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
368 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4969 - accuracy: 0.8094 - val_loss: 0.2437 - val_accuracy: 0.9497\n",
369 | "Epoch 45/1000\n",
370 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5198 - accuracy: 0.7812\n",
371 | "Epoch 00045: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
372 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4961 - accuracy: 0.8106 - val_loss: 0.2441 - val_accuracy: 0.9533\n",
373 | "Epoch 46/1000\n",
374 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5912 - accuracy: 0.7812\n",
375 | "Epoch 00046: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
376 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4816 - accuracy: 0.8271 - val_loss: 0.2482 - val_accuracy: 0.9542\n",
377 | "Epoch 47/1000\n",
378 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5292 - accuracy: 0.8047\n",
379 | "Epoch 00047: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
380 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4789 - accuracy: 0.8127 - val_loss: 0.2313 - val_accuracy: 0.9569\n",
381 | "Epoch 48/1000\n",
382 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4428 - accuracy: 0.8125\n",
383 | "Epoch 00048: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
384 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4901 - accuracy: 0.8079 - val_loss: 0.2319 - val_accuracy: 0.9560\n",
385 | "Epoch 49/1000\n",
386 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4565 - accuracy: 0.8281\n",
387 | "Epoch 00049: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
388 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4893 - accuracy: 0.8154 - val_loss: 0.2300 - val_accuracy: 0.9533\n",
389 | "Epoch 50/1000\n",
390 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5185 - accuracy: 0.7812\n",
391 | "Epoch 00050: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
392 | "27/27 [==============================] - 0s 3ms/step - loss: 0.5079 - accuracy: 0.8112 - val_loss: 0.2386 - val_accuracy: 0.9524\n",
393 | "Epoch 51/1000\n",
394 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4280 - accuracy: 0.8203\n",
395 | "Epoch 00051: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
396 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4692 - accuracy: 0.8205 - val_loss: 0.2332 - val_accuracy: 0.9578\n",
397 | "Epoch 52/1000\n",
398 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5299 - accuracy: 0.8203\n",
399 | "Epoch 00052: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
400 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4907 - accuracy: 0.8145 - val_loss: 0.2336 - val_accuracy: 0.9569\n",
401 | "Epoch 53/1000\n",
402 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5856 - accuracy: 0.7969\n",
403 | "Epoch 00053: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
404 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4882 - accuracy: 0.8181 - val_loss: 0.2306 - val_accuracy: 0.9605\n",
405 | "Epoch 54/1000\n",
406 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4466 - accuracy: 0.8047\n",
407 | "Epoch 00054: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
408 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4641 - accuracy: 0.8235 - val_loss: 0.2218 - val_accuracy: 0.9596\n"
409 | ]
410 | },
411 | {
412 | "name": "stdout",
413 | "output_type": "stream",
414 | "text": [
415 | "Epoch 55/1000\n",
416 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3760 - accuracy: 0.8672\n",
417 | "Epoch 00055: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
418 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4630 - accuracy: 0.8241 - val_loss: 0.2242 - val_accuracy: 0.9578\n",
419 | "Epoch 56/1000\n",
420 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4607 - accuracy: 0.7734\n",
421 | "Epoch 00056: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
422 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4546 - accuracy: 0.8277 - val_loss: 0.2168 - val_accuracy: 0.9605\n",
423 | "Epoch 57/1000\n",
424 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4946 - accuracy: 0.7969\n",
425 | "Epoch 00057: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
426 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4675 - accuracy: 0.8214 - val_loss: 0.2211 - val_accuracy: 0.9578\n",
427 | "Epoch 58/1000\n",
428 | "25/27 [==========================>...] - ETA: 0s - loss: 0.4393 - accuracy: 0.8334\n",
429 | "Epoch 00058: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
430 | "27/27 [==============================] - 0s 4ms/step - loss: 0.4418 - accuracy: 0.8325 - val_loss: 0.2115 - val_accuracy: 0.9632\n",
431 | "Epoch 59/1000\n",
432 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4987 - accuracy: 0.7969\n",
433 | "Epoch 00059: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
434 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4646 - accuracy: 0.8217 - val_loss: 0.2116 - val_accuracy: 0.9596\n",
435 | "Epoch 60/1000\n",
436 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4904 - accuracy: 0.7812\n",
437 | "Epoch 00060: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
438 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4576 - accuracy: 0.8232 - val_loss: 0.2108 - val_accuracy: 0.9569\n",
439 | "Epoch 61/1000\n",
440 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5147 - accuracy: 0.8281\n",
441 | "Epoch 00061: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
442 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4646 - accuracy: 0.8253 - val_loss: 0.2174 - val_accuracy: 0.9587\n",
443 | "Epoch 62/1000\n",
444 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3386 - accuracy: 0.8750\n",
445 | "Epoch 00062: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
446 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4433 - accuracy: 0.8310 - val_loss: 0.2145 - val_accuracy: 0.9560\n",
447 | "Epoch 63/1000\n",
448 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4222 - accuracy: 0.8594\n",
449 | "Epoch 00063: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
450 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4624 - accuracy: 0.8283 - val_loss: 0.2099 - val_accuracy: 0.9569\n",
451 | "Epoch 64/1000\n",
452 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4376 - accuracy: 0.8203\n",
453 | "Epoch 00064: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
454 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4601 - accuracy: 0.8235 - val_loss: 0.2075 - val_accuracy: 0.9641\n",
455 | "Epoch 65/1000\n",
456 | "27/27 [==============================] - ETA: 0s - loss: 0.4676 - accuracy: 0.8265\n",
457 | "Epoch 00065: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
458 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4676 - accuracy: 0.8265 - val_loss: 0.2172 - val_accuracy: 0.9551\n",
459 | "Epoch 66/1000\n",
460 | "26/27 [===========================>..] - ETA: 0s - loss: 0.4434 - accuracy: 0.8368\n",
461 | "Epoch 00066: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
462 | "27/27 [==============================] - 1s 20ms/step - loss: 0.4429 - accuracy: 0.8370 - val_loss: 0.2154 - val_accuracy: 0.9578\n",
463 | "Epoch 67/1000\n",
464 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4827 - accuracy: 0.8125\n",
465 | "Epoch 00067: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
466 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4434 - accuracy: 0.8358 - val_loss: 0.2090 - val_accuracy: 0.9587\n",
467 | "Epoch 68/1000\n",
468 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5810 - accuracy: 0.7656\n",
469 | "Epoch 00068: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
470 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4892 - accuracy: 0.8049 - val_loss: 0.2160 - val_accuracy: 0.9578\n",
471 | "Epoch 69/1000\n",
472 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4375 - accuracy: 0.7812\n",
473 | "Epoch 00069: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
474 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4597 - accuracy: 0.8250 - val_loss: 0.2100 - val_accuracy: 0.9605\n",
475 | "Epoch 70/1000\n",
476 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3782 - accuracy: 0.8359\n",
477 | "Epoch 00070: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
478 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4408 - accuracy: 0.8349 - val_loss: 0.2087 - val_accuracy: 0.9596\n",
479 | "Epoch 71/1000\n",
480 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4578 - accuracy: 0.8438\n",
481 | "Epoch 00071: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
482 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4425 - accuracy: 0.8355 - val_loss: 0.2075 - val_accuracy: 0.9587\n",
483 | "Epoch 72/1000\n",
484 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4267 - accuracy: 0.8438\n",
485 | "Epoch 00072: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
486 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4581 - accuracy: 0.8292 - val_loss: 0.2059 - val_accuracy: 0.9623\n",
487 | "Epoch 73/1000\n",
488 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4450 - accuracy: 0.8750\n",
489 | "Epoch 00073: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
490 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4508 - accuracy: 0.8403 - val_loss: 0.2083 - val_accuracy: 0.9614\n",
491 | "Epoch 74/1000\n",
492 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3858 - accuracy: 0.8906\n",
493 | "Epoch 00074: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
494 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4472 - accuracy: 0.8361 - val_loss: 0.2043 - val_accuracy: 0.9650\n",
495 | "Epoch 75/1000\n",
496 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4439 - accuracy: 0.8359\n",
497 | "Epoch 00075: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
498 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4515 - accuracy: 0.8325 - val_loss: 0.2138 - val_accuracy: 0.9632\n",
499 | "Epoch 76/1000\n",
500 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3983 - accuracy: 0.8203\n",
501 | "Epoch 00076: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
502 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4462 - accuracy: 0.8334 - val_loss: 0.2065 - val_accuracy: 0.9623\n",
503 | "Epoch 77/1000\n",
504 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5020 - accuracy: 0.8047\n",
505 | "Epoch 00077: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
506 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4325 - accuracy: 0.8388 - val_loss: 0.2061 - val_accuracy: 0.9605\n",
507 | "Epoch 78/1000\n",
508 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3215 - accuracy: 0.8672\n",
509 | "Epoch 00078: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
510 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4394 - accuracy: 0.8391 - val_loss: 0.2054 - val_accuracy: 0.9578\n",
511 | "Epoch 79/1000\n",
512 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4025 - accuracy: 0.8359\n",
513 | "Epoch 00079: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
514 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4370 - accuracy: 0.8310 - val_loss: 0.2031 - val_accuracy: 0.9605\n",
515 | "Epoch 80/1000\n",
516 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4505 - accuracy: 0.8125\n",
517 | "Epoch 00080: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
518 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4419 - accuracy: 0.8340 - val_loss: 0.2010 - val_accuracy: 0.9596\n",
519 | "Epoch 81/1000\n",
520 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5287 - accuracy: 0.7891\n",
521 | "Epoch 00081: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
522 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4369 - accuracy: 0.8304 - val_loss: 0.2081 - val_accuracy: 0.9578\n"
523 | ]
524 | },
525 | {
526 | "name": "stdout",
527 | "output_type": "stream",
528 | "text": [
529 | "Epoch 82/1000\n",
530 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5132 - accuracy: 0.8047\n",
531 | "Epoch 00082: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
532 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4360 - accuracy: 0.8460 - val_loss: 0.2045 - val_accuracy: 0.9605\n",
533 | "Epoch 83/1000\n",
534 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4239 - accuracy: 0.8125\n",
535 | "Epoch 00083: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
536 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4511 - accuracy: 0.8313 - val_loss: 0.1984 - val_accuracy: 0.9605\n",
537 | "Epoch 84/1000\n",
538 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4853 - accuracy: 0.8203\n",
539 | "Epoch 00084: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
540 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4357 - accuracy: 0.8304 - val_loss: 0.2024 - val_accuracy: 0.9623\n",
541 | "Epoch 85/1000\n",
542 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4782 - accuracy: 0.8125\n",
543 | "Epoch 00085: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
544 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4320 - accuracy: 0.8424 - val_loss: 0.2015 - val_accuracy: 0.9587\n",
545 | "Epoch 86/1000\n",
546 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3985 - accuracy: 0.8828\n",
547 | "Epoch 00086: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
548 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4422 - accuracy: 0.8349 - val_loss: 0.2087 - val_accuracy: 0.9587\n",
549 | "Epoch 87/1000\n",
550 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4810 - accuracy: 0.8359\n",
551 | "Epoch 00087: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
552 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4393 - accuracy: 0.8316 - val_loss: 0.2105 - val_accuracy: 0.9605\n",
553 | "Epoch 88/1000\n",
554 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4819 - accuracy: 0.8125\n",
555 | "Epoch 00088: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
556 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4381 - accuracy: 0.8400 - val_loss: 0.2070 - val_accuracy: 0.9623\n",
557 | "Epoch 89/1000\n",
558 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5002 - accuracy: 0.8281\n",
559 | "Epoch 00089: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
560 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4428 - accuracy: 0.8343 - val_loss: 0.2044 - val_accuracy: 0.9605\n",
561 | "Epoch 90/1000\n",
562 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3314 - accuracy: 0.9062\n",
563 | "Epoch 00090: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
564 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4156 - accuracy: 0.8406 - val_loss: 0.2026 - val_accuracy: 0.9578\n",
565 | "Epoch 91/1000\n",
566 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3850 - accuracy: 0.8594\n",
567 | "Epoch 00091: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
568 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4125 - accuracy: 0.8439 - val_loss: 0.2058 - val_accuracy: 0.9551\n",
569 | "Epoch 92/1000\n",
570 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4976 - accuracy: 0.7734\n",
571 | "Epoch 00092: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
572 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4217 - accuracy: 0.8415 - val_loss: 0.1999 - val_accuracy: 0.9623\n",
573 | "Epoch 93/1000\n",
574 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4155 - accuracy: 0.8516\n",
575 | "Epoch 00093: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
576 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4166 - accuracy: 0.8412 - val_loss: 0.1947 - val_accuracy: 0.9614\n",
577 | "Epoch 94/1000\n",
578 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3642 - accuracy: 0.8750\n",
579 | "Epoch 00094: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
580 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4204 - accuracy: 0.8418 - val_loss: 0.2008 - val_accuracy: 0.9569\n",
581 | "Epoch 95/1000\n",
582 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3773 - accuracy: 0.8594\n",
583 | "Epoch 00095: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
584 | "27/27 [==============================] - 0s 4ms/step - loss: 0.4171 - accuracy: 0.8421 - val_loss: 0.1945 - val_accuracy: 0.9596\n",
585 | "Epoch 96/1000\n",
586 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4086 - accuracy: 0.8672\n",
587 | "Epoch 00096: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
588 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4124 - accuracy: 0.8520 - val_loss: 0.1930 - val_accuracy: 0.9614\n",
589 | "Epoch 97/1000\n",
590 | " 1/27 [>.............................] - ETA: 0s - loss: 0.2914 - accuracy: 0.8906\n",
591 | "Epoch 00097: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
592 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4303 - accuracy: 0.8367 - val_loss: 0.1958 - val_accuracy: 0.9569\n",
593 | "Epoch 98/1000\n",
594 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4562 - accuracy: 0.8672\n",
595 | "Epoch 00098: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
596 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4191 - accuracy: 0.8400 - val_loss: 0.1950 - val_accuracy: 0.9596\n",
597 | "Epoch 99/1000\n",
598 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3967 - accuracy: 0.8438\n",
599 | "Epoch 00099: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
600 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4260 - accuracy: 0.8418 - val_loss: 0.2044 - val_accuracy: 0.9551\n",
601 | "Epoch 100/1000\n",
602 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4173 - accuracy: 0.8516\n",
603 | "Epoch 00100: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
604 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4200 - accuracy: 0.8442 - val_loss: 0.2066 - val_accuracy: 0.9560\n",
605 | "Epoch 101/1000\n",
606 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3892 - accuracy: 0.8438\n",
607 | "Epoch 00101: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
608 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4245 - accuracy: 0.8418 - val_loss: 0.2058 - val_accuracy: 0.9578\n",
609 | "Epoch 102/1000\n",
610 | " 1/27 [>.............................] - ETA: 0s - loss: 0.2965 - accuracy: 0.8984\n",
611 | "Epoch 00102: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
612 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4100 - accuracy: 0.8445 - val_loss: 0.2093 - val_accuracy: 0.9578\n",
613 | "Epoch 103/1000\n",
614 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4241 - accuracy: 0.8125\n",
615 | "Epoch 00103: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
616 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4189 - accuracy: 0.8403 - val_loss: 0.1928 - val_accuracy: 0.9659\n",
617 | "Epoch 104/1000\n",
618 | " 1/27 [>.............................] - ETA: 0s - loss: 0.6486 - accuracy: 0.7891\n",
619 | "Epoch 00104: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
620 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4223 - accuracy: 0.8424 - val_loss: 0.1964 - val_accuracy: 0.9596\n",
621 | "Epoch 105/1000\n",
622 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4589 - accuracy: 0.8281\n",
623 | "Epoch 00105: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
624 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4199 - accuracy: 0.8418 - val_loss: 0.1971 - val_accuracy: 0.9623\n",
625 | "Epoch 106/1000\n",
626 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4395 - accuracy: 0.8203\n",
627 | "Epoch 00106: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
628 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4028 - accuracy: 0.8511 - val_loss: 0.1912 - val_accuracy: 0.9641\n",
629 | "Epoch 107/1000\n",
630 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4061 - accuracy: 0.8594\n",
631 | "Epoch 00107: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
632 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4093 - accuracy: 0.8556 - val_loss: 0.1854 - val_accuracy: 0.9668\n",
633 | "Epoch 108/1000\n",
634 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4142 - accuracy: 0.8438\n",
635 | "Epoch 00108: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
636 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4249 - accuracy: 0.8367 - val_loss: 0.2008 - val_accuracy: 0.9614\n"
637 | ]
638 | },
639 | {
640 | "name": "stdout",
641 | "output_type": "stream",
642 | "text": [
643 | "Epoch 109/1000\n",
644 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5373 - accuracy: 0.7969\n",
645 | "Epoch 00109: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
646 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4198 - accuracy: 0.8379 - val_loss: 0.1955 - val_accuracy: 0.9659\n",
647 | "Epoch 110/1000\n",
648 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3872 - accuracy: 0.8281\n",
649 | "Epoch 00110: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
650 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4077 - accuracy: 0.8505 - val_loss: 0.2020 - val_accuracy: 0.9614\n",
651 | "Epoch 111/1000\n",
652 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3474 - accuracy: 0.8828\n",
653 | "Epoch 00111: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
654 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4132 - accuracy: 0.8433 - val_loss: 0.1984 - val_accuracy: 0.9632\n",
655 | "Epoch 112/1000\n",
656 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4205 - accuracy: 0.8672\n",
657 | "Epoch 00112: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
658 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4182 - accuracy: 0.8397 - val_loss: 0.1973 - val_accuracy: 0.9614\n",
659 | "Epoch 113/1000\n",
660 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4995 - accuracy: 0.8125\n",
661 | "Epoch 00113: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
662 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4042 - accuracy: 0.8478 - val_loss: 0.1922 - val_accuracy: 0.9650\n",
663 | "Epoch 114/1000\n",
664 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3640 - accuracy: 0.8750\n",
665 | "Epoch 00114: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
666 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4072 - accuracy: 0.8430 - val_loss: 0.1868 - val_accuracy: 0.9596\n",
667 | "Epoch 115/1000\n",
668 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5042 - accuracy: 0.8203\n",
669 | "Epoch 00115: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
670 | "27/27 [==============================] - 0s 3ms/step - loss: 0.3999 - accuracy: 0.8535 - val_loss: 0.1966 - val_accuracy: 0.9605\n",
671 | "Epoch 116/1000\n",
672 | " 1/27 [>.............................] - ETA: 0s - loss: 0.5355 - accuracy: 0.7422\n",
673 | "Epoch 00116: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
674 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4174 - accuracy: 0.8433 - val_loss: 0.1919 - val_accuracy: 0.9659\n",
675 | "Epoch 117/1000\n",
676 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3729 - accuracy: 0.8750\n",
677 | "Epoch 00117: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
678 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4101 - accuracy: 0.8451 - val_loss: 0.1932 - val_accuracy: 0.9578\n",
679 | "Epoch 118/1000\n",
680 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3193 - accuracy: 0.8828\n",
681 | "Epoch 00118: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
682 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4260 - accuracy: 0.8379 - val_loss: 0.1865 - val_accuracy: 0.9641\n",
683 | "Epoch 119/1000\n",
684 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3852 - accuracy: 0.8438\n",
685 | "Epoch 00119: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
686 | "27/27 [==============================] - 0s 3ms/step - loss: 0.3797 - accuracy: 0.8622 - val_loss: 0.1900 - val_accuracy: 0.9677\n",
687 | "Epoch 120/1000\n",
688 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3636 - accuracy: 0.8594\n",
689 | "Epoch 00120: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
690 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4017 - accuracy: 0.8460 - val_loss: 0.1908 - val_accuracy: 0.9659\n",
691 | "Epoch 121/1000\n",
692 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4521 - accuracy: 0.8359\n",
693 | "Epoch 00121: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
694 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4023 - accuracy: 0.8538 - val_loss: 0.1935 - val_accuracy: 0.9659\n",
695 | "Epoch 122/1000\n",
696 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4849 - accuracy: 0.8203\n",
697 | "Epoch 00122: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
698 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4186 - accuracy: 0.8457 - val_loss: 0.1937 - val_accuracy: 0.9659\n",
699 | "Epoch 123/1000\n",
700 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4021 - accuracy: 0.8516\n",
701 | "Epoch 00123: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
702 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4156 - accuracy: 0.8478 - val_loss: 0.1907 - val_accuracy: 0.9632\n",
703 | "Epoch 124/1000\n",
704 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3601 - accuracy: 0.8906\n",
705 | "Epoch 00124: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
706 | "27/27 [==============================] - 0s 3ms/step - loss: 0.3948 - accuracy: 0.8550 - val_loss: 0.1862 - val_accuracy: 0.9605\n",
707 | "Epoch 125/1000\n",
708 | " 1/27 [>.............................] - ETA: 0s - loss: 0.4446 - accuracy: 0.7891\n",
709 | "Epoch 00125: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
710 | "27/27 [==============================] - 0s 3ms/step - loss: 0.4152 - accuracy: 0.8520 - val_loss: 0.1888 - val_accuracy: 0.9623\n",
711 | "Epoch 126/1000\n",
712 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3733 - accuracy: 0.8438\n",
713 | "Epoch 00126: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
714 | "27/27 [==============================] - 0s 3ms/step - loss: 0.3913 - accuracy: 0.8550 - val_loss: 0.1937 - val_accuracy: 0.9632\n",
715 | "Epoch 127/1000\n",
716 | " 1/27 [>.............................] - ETA: 0s - loss: 0.3000 - accuracy: 0.8828\n",
717 | "Epoch 00127: saving model to model/keypoint_classifier\\keypoint_classifier.hdf5\n",
718 | "27/27 [==============================] - 0s 3ms/step - loss: 0.3820 - accuracy: 0.8583 - val_loss: 0.1867 - val_accuracy: 0.9632\n",
719 | "Epoch 00127: early stopping\n"
720 | ]
721 | },
722 | {
723 | "data": {
724 | "text/plain": [
725 | ""
726 | ]
727 | },
728 | "execution_count": 11,
729 | "metadata": {},
730 | "output_type": "execute_result"
731 | }
732 | ],
733 | "source": [
734 | "model.fit(\n",
735 | " X_train,\n",
736 | " y_train,\n",
737 | " epochs=1000,\n",
738 | " batch_size=128,\n",
739 | " validation_data=(X_test, y_test),\n",
740 | " callbacks=[cp_callback, es_callback]\n",
741 | ")"
742 | ]
743 | },
744 | {
745 | "cell_type": "code",
746 | "execution_count": 12,
747 | "metadata": {},
748 | "outputs": [
749 | {
750 | "name": "stdout",
751 | "output_type": "stream",
752 | "text": [
753 | "9/9 [==============================] - 0s 1ms/step - loss: 0.1867 - accuracy: 0.9632\n"
754 | ]
755 | }
756 | ],
757 | "source": [
758 | "# モデル評価\n",
759 | "val_loss, val_acc = model.evaluate(X_test, y_test, batch_size=128)"
760 | ]
761 | },
762 | {
763 | "cell_type": "code",
764 | "execution_count": 13,
765 | "metadata": {},
766 | "outputs": [],
767 | "source": [
768 | "# 保存したモデルのロード\n",
769 | "model = tf.keras.models.load_model(model_save_path)"
770 | ]
771 | },
772 | {
773 | "cell_type": "code",
774 | "execution_count": 14,
775 | "metadata": {},
776 | "outputs": [
777 | {
778 | "name": "stdout",
779 | "output_type": "stream",
780 | "text": [
781 | "[0.77297777 0.1697358 0.05728642]\n",
782 | "0\n"
783 | ]
784 | }
785 | ],
786 | "source": [
787 | "# 推論テスト\n",
788 | "predict_result = model.predict(np.array([X_test[0]]))\n",
789 | "print(np.squeeze(predict_result))\n",
790 | "print(np.argmax(np.squeeze(predict_result)))"
791 | ]
792 | },
793 | {
794 | "cell_type": "markdown",
795 | "metadata": {},
796 | "source": [
797 | "# 混同行列"
798 | ]
799 | },
800 | {
801 | "cell_type": "code",
802 | "execution_count": 15,
803 | "metadata": {},
804 | "outputs": [
805 | {
806 | "data": {
807 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZUAAAFmCAYAAAClXQeMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAfJklEQVR4nO3de5yWdZn48c81DKiA5llhYMWUPGQbtumalBnkibXQLQlfmSc22s3MyvLQK9d0s7VNM63NbVwPuJsoWaRrHlP8mZkgChKndPIQIIh4BExl5vn+/phbeyJmHmTumeeeez5vX/eL5/k+9+Ea58VcXN/re98TKSUkScpDQ70DkCSVh0lFkpQbk4okKTcmFUlSbkwqkqTcmFQkSbkxqUhSHxMR/SJiTkTckr3fNSJmRkRLRNwQEQOy8c2y9y3Z5yNqnbuxm2Nn3aonvBGmwAY1HVTvENQJ7yMrvnVvLIvczpXDz8v+279zY+I5DVgEbJW9/w5wSUrp+oj4L2AScHn254sppd0jYmK236c6O7GViiQVRaWt61sNETEM+Afgv7P3AYwBbsx2mQIclb0en70n+3xstn+HTCqS1Ld8HzgDqGTvtwNeSim1Zu+XAk3Z6yZgCUD2+cvZ/h0yqUhSUaRKl7eImBwRs6u2yW+ePiKOBFamlB7uri+h23sqkqSNVKnU3qeGlFIz0NzBx6OBj0fEOGBz2nsqlwJbR0RjVo0MA5Zl+y8DhgNLI6IReAfwfGfXt1KRpIJIqdLlrfPzp7NTSsNSSiOAicA9KaVPAzOAT2a7nQDclL2+OXtP9vk9qcbqEZOKJOlM4CsR0UJ7z+TKbPxKYLts/CvAWbVO5PSXJBVFDtNfGyuldC9wb/b6CWD/DezzGnDM2zmvSUWSiqLG9FVvYFKRpKLYiPtMis6kIklFUYJKxUa9JCk3ViqSVBQ92KjvLiYVSSqIWveZ9AYmFUkqCisVSVJuSlCp2KiXJOXGSkWSisL7VCRJuSnB9JdJRZKKogSNensqkqTcWKlIUlE4/SVJyk0Jpr9MKpJUECm5+kuSlJcSTH/ZqJck5cZKRZKKwp6KJCk3JZj+MqlIUlH4mBZJUm5KUKnYqJck5cZKRZKKwka9JCk3JZj+MqlIUlGUoFKxpyJJyo2ViiQVRQkqFZOKJBWED5SUJOXHSkWSlJsSrP6yUS9Jyo1JRZKKolLp+taJiNg8ImZFxKMRsSAizsvGr4mIJyNibraNysYjIi6LiJaImBcR76v1JTj9JUlF0f3TX68DY1JKayKiP3B/RNyWffa1lNKN6+1/BDAy2/4euDz7s0MmFUkqim5u1KeUErAme9s/21Inh4wHrs2OezAito6IISml5R0d4PSXJBVFqnR5i4jJETG7aptcfYmI6BcRc4GVwF0ppZnZRxdkU1yXRMRm2VgTsKTq8KXZWIesVCSpRFJKzUBzJ5+3AaMiYmtgekTsA5wNrAAGZMeeCZy/Kde3UpGkoujmRn21lNJLwAzg8JTS8tTudeBqYP9st2XA8KrDhmVjHTKpSFJRdP/qrx2yCoWI2AI4BFgcEUOysQCOAuZnh9wMHJ+tAjsAeLmzfgo4/SVJxdH9q7+GAFMioh/tRcW0lNItEXFPROwABDAX+Ods/1uBcUAL8CpwUq0LmFQkqY9IKc0D9t3A+JgO9k/AKW/nGk5/bUBbWxufPPEUPv+1cwFY+swKjv3slzhiwsmcfs6/s27dur/Y/64Z97PP6COYv+ixeoSrKg0NDcyaeTvTp19T71BUZdiwodx150959NEZzJ17D6d+YVK9QyqmHuypdBeTygb8709v4p0j/uat95dcfhWf+dRR3DbtKrbacjA/u+WOtz5bu/ZV/venN/G3e+9Rj1C1nlNPncTixS31DkPraW1t5YwzzuO97/0IH/zgx/jnfzmRvfYaWe+wiieHJcX1ZlJZz4qVz3HfA7P4xMcOAyClxMyHH+XQgz8EwPhxH+We+3771v4/uOJaTj7uGAZsNqAu8erPmpqGcMQRY7nq6uvqHYrWs2LFSubMbe/9rlmzlsWLH2fo0J3rHFUB9YVKJSL2jIgzs+e/XJa93qsngquH71z6Y77y+UlEtP+veenlV9hy8CAaG/sBsNMO27PyuecBWPj7FlasXMWHD9y/w/Op51x80Tc5++wLqFQ6u0FY9bbLLsMY9d59mDVrTr1DKZ6yVyoRcSZwPe0rAmZlWwBTI+Ks7g+vZ937m5lsu83WvHvP2mV5pVLhP37QzNdO/WwPRKZaxo0by8rnVjFnzu/qHYo6MWjQQKbdcAWnf/VcVq9eU/sA9Tq1Vn9NAt6dUvqLznREfA9YAFy4oYOyxwJMBvjRxd/in44/NodQu9+ceQu59/4H+fVvH+L1N9axdu2rXPj9/2L1mrW0trbR2NiPZ59bxY47bMfaV/9EyxNPc9IXzgBg1QsvcuqZ5/GD75zLPnu9q85fSd9z4Af248h/OJTDDxvD5ptvxlZbbck1V1/GiSd9sd6hKdPY2Mi0G65g6tTp/OIXt9U+oC8qwPRVV0X7irEOPoxYDByWUnp6vfFdgDtTSjW70+tWPdEr5yJmPTKPa6b+jB999zy+8o0L+OjBoxn30YM57z9+wB6778rEfzzyL/Y/8Qtn8NVT/qnXJZRBTQfVO4TcHXTQB/jylz/H0UefWO9Quqyzv5+9zdVXXcoLL7zE6V89t96h5GrdG8sir3P9adr5Xf6GbzHhX3OLZ1PU6ql8Cbg7Im6LiOZsux24Gzit26MriC//y8lce/10jphwMi+/8gr/eOSh9Q5J6lVGH7gfxx33ST7ykQOZ/dCdzH7oTg4/fIO3RvRtKXV9q7NOKxWAaO9Y78+fn0y5DHgoeyhZTb21UukryliplEmZKpWyyrVSmXpu1yuVY8+ra6VS8476lFIFeLAHYpEk9XI+pkWSiqIEjXqTiiQVRQHuM+kqk4okFUUJKhUf0yJJyo2ViiQVRQlW+5lUJKkoSjD9ZVKRpKIwqUiSclOC1V826iVJubFSkaSCSCX4XUAmFUkqCnsqkqTclKCnYlKRpKIowfSXjXpJUm6sVCSpKOypSJJyY1KRJOWmBM/+sqciScqNlYokFYXTX5Kk3JRgSbFJRZKKogQ3P9pTkaSiqKSub52IiM0jYlZEPBoRCyLivGx814iYGREtEXFDRAzIxjfL3rdkn4+o9SWYVCSp73gdGJNSei8wCjg8Ig4AvgNcklLaHXgRmJTtPwl4MRu/JNuvUyYVSSqIVKl0eev0/O3WZG/7Z1sCxgA3ZuNTgKOy1+Oz92Sfj42I6OwaJhVJKopunv4CiIh+ETEXWAncBfwBeCml1JrtshRoyl43AUsAss9fBrbr7PwmFUkqilTp8hYRkyNidtU2+S8ukVJbSmkUMAzYH9gzzy/B1V+SVCIppWageSP2eykiZgAfALaOiMasGhkGLMt2WwYMB5ZGRCPwDuD5zs5rpSJJRdH9q792iIits9dbAIcAi4AZwCez3U4Abspe35y9J/v8npQ6f5aMlYo
kFUX331E/BJgSEf1oLyqmpZRuiYiFwPUR8S1gDnBltv+VwP9ERAvwAjCx1gVMKpJUFN18R31KaR6w7wbGn6C9v7L++GvAMW/nGiYVSSoK76iXJOnPrFQkqSh8oKQkKS+17ojvDUwqklQUViqSpNyUIKnYqJck5cZKRZKKogRLik0qklQUJZj+MqlIUkGkEiQVeyqSpNxYqUhSUZSgUjGpSFJRePOjJCk3ViqSpNyUIKnYqJck5cZKRZIKosZv6u0VTCqSVBQlmP4yqUhSUZhUahs49EPdfQl1wSu3nVvvENSJkcdcVu8Q1IO8o16SpCpOf0lSUZSgUjGpSFJR9P4b6k0qklQU9lQkSapipSJJRVGCSsWkIklFYU9FkpSXMvRUTCqSVBQlqFRs1EuScmOlIkkFUYbpLysVSSqKSg5bJyJieETMiIiFEbEgIk7Lxr8ZEcsiYm62jas65uyIaImI30fEYbW+BCsVSSqI1P09lVbg9JTSIxGxJfBwRNyVfXZJSumi6p0jYm9gIvBuYCjwq4h4V0qpraMLWKlIUlF0c6WSUlqeUnoke70aWAQ0dXLIeOD6lNLrKaUngRZg/86uYVKRpD4oIkYA+wIzs6EvRMS8iLgqIrbJxpqAJVWHLaXzJGRSkaSiSJWubxExOSJmV22T179ORAwGfgZ8KaX0CnA5sBswClgOXLypX4M9FUkqihx6KimlZqC5o88joj/tCeUnKaWfZ8c8W/X5FcAt2dtlwPCqw4dlYx2yUpGkgsijUulMRARwJbAopfS9qvEhVbsdDczPXt8MTIyIzSJiV2AkMKuza1ipSFLfMRr4DPC7iJibjX0dODYiRgEJeAr4HEBKaUFETAMW0r5y7JTOVn6BSUWSCqO7lxSnlO4HYgMf3drJMRcAF2zsNUwqklQQPXCfSrczqUhSUaQNFRG9i0lFkgqiDJWKq78kSbmxUpGkgkgVp78kSTkpw/SXSUWSCiLZqJck5aUMlYqNeklSbqxUJKkgbNRLknKTev+vqDepSFJRlKFSsaciScqNlYokFUQZKhWTiiQVhD0VSVJurFQkSbkpwx31NuolSbmxUpGkgijDY1pMKpJUEJUSTH+ZVCSpIMrQUzGpSFJBlGH1l416SVJurFQkqSC8+VGSlJsyTH+ZVCSpIMqw+sueiiQpN1YqklQQLimWJOXGRn0fMmzYUK6+6lJ23Gl7Ukpc+d8/4Qc/vLLeYfU5r69r5eTv3cC61jZaKxU+uu9IPn/kaM659nYefnwJg7fYDIDzP3M4ew7f8a3j5j+1ghMuuo4LTz6SQ973rnqF36cMadqZS3/0bbbfcTtSSlw35Uau/PH/ste79+DC753DoEEDWfLHZzj1c2eyZvXaeodbCGXoqZhUNlJraytnnHEec+bOZ/DgQcyceTu/uvs+Fi16vN6h9SkDGvtxxWnHMHDzAaxra+Oki6/ng+/eFYAvH/3hDSaMtkqFS39xHwfsNaKHo+3b2lpbOf+c7zJ/3iIGDR7IbfdM4757H+C7l57Ht/71Ih58YDaf+vTR/POpJ3HRt39Y73ALoQzTXzbqN9KKFSuZM3c+AGvWrGXx4scZOnTnOkfV90QEAzcfAEBrW4XWtgpB538Rp947h7H7jmTbLQf2RIjKrHx2FfPnLQJg7ZpXefyxJ9h5yE68c/ddePCB2QDcd+9vGfexQ+oZpnK2yUklIk7KM5DeZJddhjHqvfswa9aceofSJ7VVKkz49rWMOfNyDthzF96z6xAAfnjz/RzzrSl898YZvLGuFYBnX1rNjLktTPjQqDpGrGHDh7LP3+7FnIfn8djiP3DYuDEAHDn+UP9xViWlrm+diYjhETEjIhZGxIKIOC0b3zYi7oqIx7M/t8nGIyIui4iWiJgXEe+r9TV0pVI5rwvH9lqDBg1k2g1XcPpXz2X16jX1DqdP6tfQwLSvH88dF0xm/lMraHlmFV8c/0F+ce5J/OTMT/Py2te4+q6HAPjuT+/ltKM/REND759W6K0GDtqC5imX8M2vf4c1q9dy+qnncPykidx6zw0MHjyIdevW1TvEwqik6PJWQytwekppb+AA4JSI2Bs4C7g7pTQSuDt7D3AEMDLbJgOX17pApz2ViJjX0UfATp0cNzkLgIZ+76ChYVCtOHqFxsZGpt1wBVOnTucXv7it3uH0eVsN3Jz99hjObxY8yQmH7AfAgP6NjP/APlz7q/bplYV/XMGZV/4SgJfW/on75z9Bv4ZgzKiRdYu7L2lsbKR5yveZfuMvue2WXwHwh8ef5NOfmAzArrvtwthDDqpniIXS3T2VlNJyYHn2enVELAKagPHAwdluU4B7gTOz8WtTSgl4MCK2jogh2Xk2qFajfifgMODF9cYDeKCTwJuBZoD+A5pKsEiu3RXNF7N4cQvfv7S53qH0WS+sfpXGfg1sNXBzXntjHQ8uepqTDt2P515eww7vGExKiRmPtrD70O0AuPXfPvvWsedcezsH7fNOE0oPuuiy82l57Amu+NG1b41tt/22PL/qBSKC007/HP9zzbQ6Rlgseaz+qv5HfaY5+5m8/n4jgH2BmcBOVYliBX8uGpqAJVWHLc3GNjmp3AIMTinN3UBA99Y4tlRGH7gfxx33SX73u4XMfuhOAL5xzoXcfvs9dY6sb1n18lrOufY2KpVEJSUO/bs9OOg9u/HZ70/jxTV/IqXEHsN25BvHfrTeofZ5+/39vnxy4sdZtOAx7vh/NwLwnX+7lF1324UTJk0E4LZbfsUNP5lezzBLp/of9R2JiMHAz4AvpZReifhzMksppYjY5GIgUjffbVOmSqWMXrnt3HqHoE6MPOayeoegGpa+MD+3OasHh/5jl39eHvDMzzuNJyL6014w3JFS+l429nvg4JTS8ogYAtybUtojIn6cvZ66/n4dnd8lxZJUEN3dqI/2kuRKYNGbCSVzM3BC9voE4Kaq8eOzVWAHAC93llDAmx8lqTB64ObH0cBngN9FxNxs7OvAhcC0iJgEPA1MyD67FRgHtACvAjVvJTGpSFIfkVK6Hzq8W3jsBvZPwClv5xomFUkqiEq9A8iBSUWSCiLVeORQb2BSkaSCqJRgraxJRZIKolKCSsUlxZKk3FipSFJB2FORJOXG1V+SpNyUoVKxpyJJyo2ViiQVhNNfkqTcmFQkSbkpQ0/FpCJJBVHp/TnFRr0kKT9WKpJUEGV4TItJRZIKogTPkzSpSFJRuPpLkpSbSvT+6S8b9ZKk3FipSFJB2FORJOXGnookKTfe/ChJUhUrFUkqCG9+lCTlxka9JCk3ZeipmFQkqSDKsPrLRr0kKTdWKpJUEPZUJEm5saciScqNPRVJUm4qOWy1RMRVEbEyIuZXjX0zIpZFxNxsG1f12dkR0RIRv4+Iw2qd36QiSX3LNcDhGxi/JKU0KttuBYiIvYGJwLuzY34UEf06O7lJRZIKIkXXt5rXSOk+4IWNDGk8cH1K6fWU0pNAC7B/ZweYVCSpIHpi+qsTX4iIedn02DbZWBOwpGqfpdlYh0wqklQQeSSViJgcEbOrtskbcenLgd2AUc
By4OJN/Rpc/SVJJZJSagaa3+Yxz775OiKuAG7J3i4DhlftOiwb65CViiQVRMph2xQRMaTq7dHAmyvDbgYmRsRmEbErMBKY1dm5rFQkqSB64ubHiJgKHAxsHxFLgXOBgyNiFO156SngcwAppQURMQ1YCLQCp6SU2jo7v0lFkgqiJ25+TCkdu4HhKzvZ/wLggo09v0lFkgrCO+olSapipSJJBeFTiiVJufEpxZKk3JShp2JSkaSCKMP0l416SVJuur1SKUPmLbN3jDu/3iGoEy9d9PF6h6AeVCnBT0ynvySpIOypSJJy0/vrFHsqkqQcWalIUkE4/SVJyo03P0qScuPqL0lSbnp/SrFRL0nKkZWKJBWEjXpJUm7sqUiSctP7U4pJRZIKowzTXzbqJUm5sVKRpIKwpyJJyk3vTykmFUkqDHsqkiRVsVKRpIJIJZgAM6lIUkGUYfrLpCJJBeHqL0lSbnp/SrFRL0nKkZWKJBVEGaa/rFQkqSAqOWy1RMRVEbEyIuZXjW0bEXdFxOPZn9tk4xERl0VES0TMi4j31Tq/SUWSCiLl8N9GuAY4fL2xs4C7U0ojgbuz9wBHACOzbTJwea2Tm1QkqSB6olJJKd0HvLDe8HhgSvZ6CnBU1fi1qd2DwNYRMaSz85tUJEk7pZSWZ69XADtlr5uAJVX7Lc3GOmRSkaSCyGP6KyImR8Tsqm3y24ohpUQXVje7+kuSCiKPO+pTSs1A89s87NmIGJJSWp5Nb63MxpcBw6v2G5aNdchKRZIKopJSl7dNdDNwQvb6BOCmqvHjs1VgBwAvV02TbZCViiT1IRExFTgY2D4ilgLnAhcC0yJiEvA0MCHb/VZgHNACvAqcVOv8JhVJKoieuPUxpXRsBx+N3cC+CTjl7ZzfpCJJBVGGO+pNKpJUEP4+FUlSbsrw+1Rc/SVJyo2ViiQVhD0VSVJu7KlIknJThp6KSUWSCiJt+h3xhWGjXpKUGysVSSoIG/WSpNzYU5Ek5aYMq7/sqUiScmOlIkkFYU9FkpSbMiwpNqlIUkHYqJck5cZGfR9z2KEHs2D+fSxeeD9nfO1t/TI09ZCGhgZmzbyd6dOvqXcofdLrrW0cN+0hJkydySeue5DLZz4BwPXzlvDx/3mAfX94Ny/+6Y2/Om7Bs6/w/v+8h7tanu3pkJUzK5WN1NDQwGWXXsDh445l6dLlPPjbW/m/W+5k0aLH6x2aqpx66iQWL25hy60G1zuUPmlAvwaaj9qXgQMaWddW4eSfP8zoXbZj1JCtOWjE9vzT9Ef+6pi2SuLSB1o44G+2rUPExVKGRn3NSiUi9oyIsRExeL3xw7svrOLZf799+cMfnuLJJ//IunXrmDbtJj7+scPqHZaqNDUN4YgjxnLV1dfVO5Q+KyIYOKD936qtlURrJRHAnjtsydCtttjgMdfPW8LY3XZg2y0G9GCkxZRS6vJWb50mlYj4InATcCowPyLGV3387e4MrGiGNu3MkqXPvPV+6bLlDB26cx0j0vouvuibnH32BVQq9f+L1Ze1VRKfun4mY6/6NQcM35b37PyODvddueY17nniOY55z7AejLC4KqQub/VWq1L5LPB3KaWjgIOBcyLitOyz6OigiJgcEbMjYnalsjaXQKXOjBs3lpXPrWLOnN/VO5Q+r19DcMPEv+eOE0cz/9mXaXl+TYf7fvfXj3PagbvTEB3+OFEvU6un0pBSWgOQUnoqIg4GboyIXegkqaSUmoFmgMYBTfVPnTl4ZtkKhg8b+tb7YU1DeOaZFXWMSNUO/MB+HPkPh3L4YWPYfPPN2GqrLbnm6ss48aQv1ju0PmvLzfrz/qZteODp59l9uw33uBaufIWz7pgPwEuvreP+p1fR2NDAR965Q0+GWhh9YfXXsxEx6s03WYI5EtgeeE83xlU4D82ey+6778qIEcPp378/EyaM5/9uubPeYSnzjXMu5J277ce79vgAx33mFGbc+xsTSh288Kc3WP36OgBea21j5pIXGLHNoA73/+UJo7k12z66246c/eE9+mxCAaik1OWt3mpVKscDrdUDKaVW4PiI+HG3RVVAbW1tnPalb3DrL6+jX0MD10y5gYULH6t3WFKhrFr7Ov/6q4VUUvsPyEN235GDdt2e6x5dwpRHnub5V99gwtSZfHDE9pw7Zq96h1s49U8JXRfdvVqgLNNfZeVcdrG9dNHH6x2Cahh46o9y+0s0umlMl39e/mbZPXX9S+3Nj5Kk3HjzoyQVRBGWBHeVSUWSCqIINy92lUlFkgrCSkWSlJu+cJ+KJEkbzUpFkgqiJ3oqEfEUsBpoA1pTSu+PiG2BG4ARwFPAhJTSi5tyfisVSSqIHnyg5EdSSqNSSu/P3p8F3J1SGgncnb3fJCYVSSqIOj76fjwwJXs9BThqU09kUpGkEql+Sny2TV5vlwTcGREPV322U0ppefZ6BbDTpl7fnookFUQeS4qrnxLfgQ+mlJZFxI7AXRGxeL3jU0RsciAmFUkqiJ5YUpxSWpb9uTIipgP70/5E+iEppeURMQRYuannd/pLkgqiux99HxGDImLLN18DhwLzgZuBE7LdTqD9N/5uEisVSSqIHqhUdgKmR/vTyRuB61JKt0fEQ8C0iJgEPA1M2NQLmFQkqY9IKT0BvHcD488DY/O4hklFkgqiCL+5satMKpJUEGV49pdJRZIKwkpFkpSbMlQqLimWJOXGSkWSCsLpL0lSbsow/WVSkaSCSKlS7xC6zJ6KJCk3ViqSVBB5PKW43kwqklQQPfHrhLubSUWSCsJKRZKUmzJUKjbqJUm5sVKRpILw5kdJUm68+VGSlJsy9FRMKpJUEGVY/WWjXpKUGysVSSoIp78kSblx9ZckKTdlqFTsqUiScmOlIkkFUYbVXyYVSSqIMkx/mVQkqSBs1EuSclOGx7TYqJck5cZKRZIKwukvSVJubNRLknJjT0WSlJuUUpe3WiLi8Ij4fUS0RMRZeX8NJhVJ6iMioh/wn8ARwN7AsRGxd57XcPpLkgqiB3oq+wMtKaUnACLiemA8sDCvC1ipSFJBpBy2GpqAJVXvl2Zjuen2SqX1jWXR3dfoSRExOaXUXO84tGF+f4rP71HH8vh5GRGTgclVQ809+f/bSuXtm1x7F9WR35/i83vUjVJKzSml91dt1QllGTC86v2wbCw3JhVJ6jseAkZGxK4RMQCYCNyc5wVs1EtSH5FSao2ILwB3AP2Aq1JKC/K8hknl7XMuuNj8/hSf36M6SindCtzaXeePMjwWQJJUDPZUJEm5Mam8Dd39eANtuoi4KiJWRsT8eseivxYRwyNiRkQsjIgFEXFavWNS93D6ayNljzd4DDiE9huGHgKOTSnldieqNl1EHASsAa5NKe1T73j0lyJiCDAkpfRIRGwJPAwc5d+f8rFS2XhvPd4gpfQG8ObjDVQAKaX7gBfqHYc2LKW0PKX0SPZ6NbCInO/kVjGYVDZetz/eQOoLImIEsC8ws86hqBuYVCT1mIgYDPwM+FJK6ZV6x6P8mVQ2Xrc/3kAqs4joT3tC+UlK6ef1jkfdw6Sy8br98
QZSWUVEAFcCi1JK36t3POo+JpWNlFJqBd58vMEiYFrejzfQpouIqcBvgT0iYmlETKp3TPoLo4HPAGMiYm62jat3UMqfS4olSbmxUpEk5cakIknKjUlFkpQbk4okKTcmFUlSbkwqkqTcmFQkSbkxqUiScvP/AVplmVKXNqIBAAAAAElFTkSuQmCC\n",
808 | "text/plain": [
809 | ""
810 | ]
811 | },
812 | "metadata": {
813 | "needs_background": "light"
814 | },
815 | "output_type": "display_data"
816 | },
817 | {
818 | "name": "stdout",
819 | "output_type": "stream",
820 | "text": [
821 | "Classification Report\n",
822 | " precision recall f1-score support\n",
823 | "\n",
824 | " 0 1.00 0.99 0.99 410\n",
825 | " 1 0.98 0.92 0.95 385\n",
826 | " 2 0.91 0.99 0.95 318\n",
827 | "\n",
828 | " accuracy 0.96 1113\n",
829 | " macro avg 0.96 0.96 0.96 1113\n",
830 | "weighted avg 0.96 0.96 0.96 1113\n",
831 | "\n"
832 | ]
833 | }
834 | ],
835 | "source": [
836 | "import pandas as pd\n",
837 | "import seaborn as sns\n",
838 | "import matplotlib.pyplot as plt\n",
839 | "from sklearn.metrics import confusion_matrix, classification_report\n",
840 | "\n",
841 | "def print_confusion_matrix(y_true, y_pred, report=True):\n",
842 | " labels = sorted(list(set(y_true)))\n",
843 | " cmx_data = confusion_matrix(y_true, y_pred, labels=labels)\n",
844 | " \n",
845 | " df_cmx = pd.DataFrame(cmx_data, index=labels, columns=labels)\n",
846 | " \n",
847 | " fig, ax = plt.subplots(figsize=(7, 6))\n",
848 | " sns.heatmap(df_cmx, annot=True, fmt='g' ,square=False)\n",
849 | " ax.set_ylim(len(set(y_true)), 0)\n",
850 | " plt.show()\n",
851 | " \n",
852 | " if report:\n",
853 | " print('Classification Report')\n",
854 | " print(classification_report(y_test, y_pred))\n",
855 | "\n",
856 | "Y_pred = model.predict(X_test)\n",
857 | "y_pred = np.argmax(Y_pred, axis=1)\n",
858 | "\n",
859 | "print_confusion_matrix(y_test, y_pred)"
860 | ]
861 | },
862 | {
863 | "cell_type": "markdown",
864 | "metadata": {},
865 | "source": [
866 | "# Tensorflow-Lite用のモデルへ変換"
867 | ]
868 | },
869 | {
870 | "cell_type": "code",
871 | "execution_count": 16,
872 | "metadata": {},
873 | "outputs": [],
874 | "source": [
875 | "# 推論専用のモデルとして保存\n",
876 | "model.save(model_save_path, include_optimizer=False)"
877 | ]
878 | },
879 | {
880 | "cell_type": "code",
881 | "execution_count": 17,
882 | "metadata": {},
883 | "outputs": [
884 | {
885 | "name": "stdout",
886 | "output_type": "stream",
887 | "text": [
888 | "WARNING:tensorflow:From d:\\00.envs\\20201208_mediapipe\\lib\\site-packages\\tensorflow\\python\\training\\tracking\\tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
889 | "Instructions for updating:\n",
890 | "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
891 | "WARNING:tensorflow:From d:\\00.envs\\20201208_mediapipe\\lib\\site-packages\\tensorflow\\python\\training\\tracking\\tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
892 | "Instructions for updating:\n",
893 | "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
894 | "INFO:tensorflow:Assets written to: C:\\Users\\sihit\\AppData\\Local\\Temp\\tmpy2l6ipxu\\assets\n"
895 | ]
896 | },
897 | {
898 | "data": {
899 | "text/plain": [
900 | "6224"
901 | ]
902 | },
903 | "execution_count": 17,
904 | "metadata": {},
905 | "output_type": "execute_result"
906 | }
907 | ],
908 | "source": [
909 | "# モデルを変換(量子化)\n",
910 | "tflite_save_path = 'model/keypoint_classifier/keypoint_classifier.tflite'\n",
911 | "\n",
912 | "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
913 | "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
914 | "tflite_quantized_model = converter.convert()\n",
915 | "\n",
916 | "open(tflite_save_path, 'wb').write(tflite_quantized_model)"
917 | ]
918 | },
919 | {
920 | "cell_type": "markdown",
921 | "metadata": {},
922 | "source": [
923 | "# 推論テスト"
924 | ]
925 | },
926 | {
927 | "cell_type": "code",
928 | "execution_count": 18,
929 | "metadata": {},
930 | "outputs": [],
931 | "source": [
932 | "interpreter = tf.lite.Interpreter(model_path=tflite_save_path)\n",
933 | "interpreter.allocate_tensors()"
934 | ]
935 | },
936 | {
937 | "cell_type": "code",
938 | "execution_count": 19,
939 | "metadata": {},
940 | "outputs": [],
941 | "source": [
942 | "# 入出力テンソルを取得\n",
943 | "input_details = interpreter.get_input_details()\n",
944 | "output_details = interpreter.get_output_details()"
945 | ]
946 | },
947 | {
948 | "cell_type": "code",
949 | "execution_count": 20,
950 | "metadata": {},
951 | "outputs": [],
952 | "source": [
953 | "interpreter.set_tensor(input_details[0]['index'], np.array([X_test[0]]))"
954 | ]
955 | },
956 | {
957 | "cell_type": "code",
958 | "execution_count": 21,
959 | "metadata": {
960 | "scrolled": true
961 | },
962 | "outputs": [
963 | {
964 | "name": "stdout",
965 | "output_type": "stream",
966 | "text": [
967 | "Wall time: 0 ns\n"
968 | ]
969 | }
970 | ],
971 | "source": [
972 | "%%time\n",
973 | "# 推論実施\n",
974 | "interpreter.invoke()\n",
975 | "tflite_results = interpreter.get_tensor(output_details[0]['index'])"
976 | ]
977 | },
978 | {
979 | "cell_type": "code",
980 | "execution_count": 22,
981 | "metadata": {},
982 | "outputs": [
983 | {
984 | "name": "stdout",
985 | "output_type": "stream",
986 | "text": [
987 | "[0.7729778 0.16973573 0.05728643]\n",
988 | "0\n"
989 | ]
990 | }
991 | ],
992 | "source": [
993 | "print(np.squeeze(tflite_results))\n",
994 | "print(np.argmax(np.squeeze(tflite_results)))"
995 | ]
996 | }
997 | ],
998 | "metadata": {
999 | "kernelspec": {
1000 | "display_name": "Python 3",
1001 | "language": "python",
1002 | "name": "python3"
1003 | },
1004 | "language_info": {
1005 | "codemirror_mode": {
1006 | "name": "ipython",
1007 | "version": 3
1008 | },
1009 | "file_extension": ".py",
1010 | "mimetype": "text/x-python",
1011 | "name": "python",
1012 | "nbconvert_exporter": "python",
1013 | "pygments_lexer": "ipython3",
1014 | "version": "3.8.5"
1015 | }
1016 | },
1017 | "nbformat": 4,
1018 | "nbformat_minor": 4
1019 | }
1020 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from model.keypoint_classifier.keypoint_classifier import KeyPointClassifier
2 | from model.point_history_classifier.point_history_classifier import PointHistoryClassifier
3 | from model.palm_detection.palm_detection import PalmDetection
4 | from model.hand_landmark.hand_landmark import HandLandmark
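
# Note (illustrative comment, not part of the original file): the re-exports
# above let application code import the detectors and classifiers directly
# from the package root, e.g.:
#   from model import PalmDetection, HandLandmark, KeyPointClassifier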
--------------------------------------------------------------------------------
/model/hand_landmark/hand_landmark.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import (
3 | Tuple,
4 | Optional,
5 | List,
6 | )
7 | import cv2
8 | import onnxruntime
9 | import numpy as np
10 |
11 | from utils.utils import keep_aspect_resize_and_pad
12 |
13 |
14 | class HandLandmark(object):
15 | def __init__(
16 | self,
17 | model_path: Optional[str] = 'model/hand_landmark/hand_landmark_sparse_Nx3x224x224.onnx',
18 | class_score_th: Optional[float] = 0.50,
19 | providers: Optional[List] = [
20 | # (
21 | # 'TensorrtExecutionProvider', {
22 | # 'trt_engine_cache_enable': True,
23 | # 'trt_engine_cache_path': '.',
24 | # 'trt_fp16_enable': True,
25 | # }
26 | # ),
27 | 'CUDAExecutionProvider',
28 | 'CPUExecutionProvider',
29 | ],
30 | ):
31 | """HandLandmark
32 |
33 | Parameters
34 | ----------
35 | model_path: Optional[str]
36 | ONNX file path for Hand Landmark
37 |
38 | class_score_th: Optional[float]
39 | Score threshold. Default: 0.50
40 |
41 | providers: Optional[List]
42 | Name of onnx execution providers
43 | Default:
44 | [
45 | (
46 | 'TensorrtExecutionProvider', {
47 | 'trt_engine_cache_enable': True,
48 | 'trt_engine_cache_path': '.',
49 | 'trt_fp16_enable': True,
50 | }
51 | ),
52 | 'CUDAExecutionProvider',
53 | 'CPUExecutionProvider',
54 | ]
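
Example
-------
Illustrative sketch only (not part of the original source); it keeps the
default model path and forces execution onto the CPU provider:

    hand_landmark = HandLandmark(
        providers=['CPUExecutionProvider'],
    )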
55 | """
56 | # Threshold
57 | self.class_score_th = class_score_th
58 |
59 | # Model loading
60 | session_option = onnxruntime.SessionOptions()
61 | session_option.log_severity_level = 3
62 | self.onnx_session = onnxruntime.InferenceSession(
63 | model_path,
64 | sess_options=session_option,
65 | providers=providers,
66 | )
67 | self.providers = self.onnx_session.get_providers()
68 |
69 | self.input_shapes = [
70 | input.shape for input in self.onnx_session.get_inputs()
71 | ]
72 | self.input_names = [
73 | input.name for input in self.onnx_session.get_inputs()
74 | ]
75 | self.output_names = [
76 | output.name for output in self.onnx_session.get_outputs()
77 | ]
78 |
79 |
80 | def __call__(
81 | self,
82 | images: List[np.ndarray],
83 | rects: np.ndarray,
84 | ) -> Tuple[np.ndarray, np.ndarray]:
85 | """HandLandmark
86 |
87 | Parameters
88 | ----------
89 | images: List[np.ndarray]
90 | Multiple palm images.
91 |
92 | rects: np.ndarray
93 | Coordinates, size, and angle of the cropped palm region.
94 | [boxcount, cx, cy, width, height, angle]
95 |
96 | Returns
97 | -------
98 | hand_landmarks: np.ndarray
99 | Hand landmarks (X,Y) x 21
100 |
101 | rotated_image_size_leftrights: np.ndarray
102 | Maximum width and height of the rotated rectangle that encloses the
103 | detected hand's bounding box, together with a flag indicating
104 | left hand (0) or right hand (1)
105 | [rotated_image_width, rotated_image_height, left_hand_0_or_right_hand_1]
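
Example
-------
Illustrative sketch only (not part of the original source); it assumes the
cropped palm images and rects come from an upstream PalmDetection step, and
the variable names are placeholders:

    hand_landmark = HandLandmark()
    hand_landmarks, rotated_image_size_leftrights = hand_landmark(
        images=cropped_palm_images,  # List[np.ndarray] of BGR palm crops
        rects=palm_rects,            # np.ndarray [boxcount, cx, cy, w, h, angle]
    )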
106 | """
107 | temp_images = copy.deepcopy(images)
108 |
109 | # PreProcess
110 | inference_images, resized_images, resize_scales_224x224, half_pad_sizes_224x224 = self.__preprocess(
111 | images=temp_images,
112 | )
113 |
114 | # Inference
115 | xyz_x21s, hand_scores, left_hand_0_or_right_hand_1s = self.onnx_session.run(
116 | self.output_names,
117 | {input_name: inference_images for input_name in self.input_names},
118 | )
119 |
120 | # PostProcess
121 | hand_landmarks, rotated_image_size_leftrights = self.__postprocess(
122 | resized_images=resized_images,
123 | resize_scales_224x224=resize_scales_224x224,
124 | half_pad_sizes_224x224=half_pad_sizes_224x224,
125 | rects=rects,
126 | xyz_x21s=xyz_x21s,
127 | hand_scores=hand_scores,
128 | left_hand_0_or_right_hand_1s=left_hand_0_or_right_hand_1s,
129 | )
130 |
131 | return hand_landmarks, rotated_image_size_leftrights
132 |
133 |
134 | def __preprocess(
135 | self,
136 | images: List[np.ndarray],
137 | swap: Optional[Tuple[int,int,int]] = (2, 0, 1),
138 |     ) -> Tuple[np.ndarray, List[np.ndarray], np.ndarray, np.ndarray]:
139 | """__preprocess
140 |
141 | Parameters
142 | ----------
143 | images: List[np.ndarray]
144 | Multiple palm images.
145 |
146 | swap: tuple
147 | HWC to CHW: (2,0,1)
148 | CHW to HWC: (1,2,0)
149 | HWC to HWC: (0,1,2)
150 | CHW to CHW: (0,1,2)
151 |
152 | Returns
153 | -------
154 | padded_images: np.ndarray
155 |             Resized, padded and normalized images. [N,C,H,W]. (resized_images, resize_scales_224x224 and half_pad_sizes_224x224 are also returned.)
156 | """
157 | temp_images = copy.deepcopy(images)
158 |
159 | # Resize + Padding + Normalization + BGR->RGB
160 | input_h = self.input_shapes[0][2]
161 | input_w = self.input_shapes[0][3]
162 |
163 | padded_images = []
164 | resized_images = []
165 | resize_scales_224x224 = []
166 | half_pad_sizes_224x224 = []
167 | for image in temp_images:
168 | padded_image, resized_image = keep_aspect_resize_and_pad(
169 | image=image,
170 | resize_width=input_w,
171 | resize_height=input_h,
172 | )
173 | # reduction_ratio_h = resized_h / original_h
174 | resize_224x224_scale_h = resized_image.shape[0] / image.shape[0]
175 | # reduction_ratio_w = resized_w / original_w
176 | resize_224x224_scale_w = resized_image.shape[1] / image.shape[1]
177 | resize_scales_224x224.append(
178 | [
179 | resize_224x224_scale_w,
180 | resize_224x224_scale_h,
181 | ]
182 | )
183 |
184 | pad_h = padded_image.shape[0] - resized_image.shape[0]
185 | pad_w = padded_image.shape[1] - resized_image.shape[1]
186 | half_pad_h_224x224 = pad_h // 2
187 | half_pad_h_224x224 = half_pad_h_224x224 if half_pad_h_224x224 >= 0 else 0
188 | half_pad_w_224x224 = pad_w // 2
189 | half_pad_w_224x224 = half_pad_w_224x224 if half_pad_w_224x224 >= 0 else 0
190 | half_pad_sizes_224x224.append([half_pad_w_224x224, half_pad_h_224x224])
191 |
192 | padded_image = np.divide(padded_image, 255.0)
193 | padded_image = padded_image[..., ::-1]
194 | padded_image = padded_image.transpose(swap)
195 | padded_image = np.ascontiguousarray(
196 | padded_image,
197 | dtype=np.float32,
198 | )
199 |
200 | padded_images.append(padded_image)
201 | resized_images.append(resized_image)
202 |
203 | return \
204 | np.asarray(padded_images, dtype=np.float32), \
205 | resized_images, \
206 | np.asarray(resize_scales_224x224, dtype=np.float32), \
207 | np.asarray(half_pad_sizes_224x224, dtype=np.int32)
208 |
209 |
210 | def __postprocess(
211 | self,
212 | resized_images: List[np.ndarray],
213 | resize_scales_224x224: np.ndarray,
214 | half_pad_sizes_224x224: np.ndarray,
215 | rects: np.ndarray,
216 | xyz_x21s: np.ndarray,
217 | hand_scores: np.ndarray,
218 | left_hand_0_or_right_hand_1s: np.ndarray,
219 | ) -> Tuple[np.ndarray, np.ndarray]:
220 | """__postprocess
221 |
222 | Parameters
223 | ----------
224 | rects: np.ndarray
225 | [N, cx, cy, (xmax-xmin), (ymax-ymin), degree]
226 |
227 | xyz_x21s: np.ndarray
228 | float32[N, 63]
229 | XYZ coordinates. 21 points.
230 |
231 | hand_scores: np.ndarray
232 | float32[N, 1]
233 | Hand score.
234 |
235 | left_hand_0_or_right_hand_1s: np.ndarray
236 | float32[N, 1]
237 | 0: Left hand
238 | 1: Right hand
239 |
240 | Returns
241 | -------
242 | extracted_hands: np.ndarray
243 | Hand landmarks (X,Y) x 21
244 |
245 | rotated_image_size_leftrights: np.ndarray
246 |             Width and height of the bounding rectangle that encloses the
247 |             rotated (uncropped) image of the detected hand,
248 |             and a flag for left hand (0) or right hand (1)
249 | [rotated_image_width, rotated_image_height, left_hand_0_or_right_hand_1]
250 | """
251 | hand_landmarks = np.asarray([], dtype=np.int32)
252 | extracted_hands = []
253 | rotated_image_size_leftrights = []
254 |
255 | keep = hand_scores[:, 0] > self.class_score_th # hand_score > self.class_score_th
256 | xyz_x21s = xyz_x21s[keep, :]
257 | hand_scores = hand_scores[keep, :]
258 | left_hand_0_or_right_hand_1s = left_hand_0_or_right_hand_1s[keep, :]
259 |         resized_images = [i for (i, k) in zip(resized_images, keep) if k]
260 |
261 | for resized_image, resize_scale_224x224, half_pad_size_224x224, rect, xyz_x21, left_hand_0_or_right_hand_1 in \
262 | zip(resized_images, resize_scales_224x224, half_pad_sizes_224x224, rects, xyz_x21s, left_hand_0_or_right_hand_1s):
263 | """
264 | hands: sqn_rr_size, rotation, sqn_rr_center_x, sqn_rr_center_y
265 | cx = int(sqn_rr_center_x * frame_width)
266 | cy = int(sqn_rr_center_y * frame_height)
267 | xmin = int((sqn_rr_center_x - (sqn_rr_size / 2))*w)
268 | xmax = int((sqn_rr_center_x + (sqn_rr_size / 2))*w)
269 | ymin = int((sqn_rr_center_y - (sqn_rr_size * wh_ratio / 2))*h)
270 | ymax = int((sqn_rr_center_y + (sqn_rr_size * wh_ratio / 2))*h)
271 | degree = degrees(rotation)
272 | rect : cx, cy, (xmax-xmin), (ymax-ymin), degree
273 | rotation = radians(degree)
274 | sqn_rr_center_x = cx / frame_width
275 | sqn_rr_center_y = cy / frame_height
276 | """
277 | rrn_lms = xyz_x21
278 | input_h = self.input_shapes[0][2]
279 | input_w = self.input_shapes[0][3]
280 | rrn_lms = rrn_lms / input_h
281 |
282 | rcx = rect[0]
283 | rcy = rect[1]
284 | angle = rect[4]
285 |
286 | view_image = copy.deepcopy(resized_image)
287 | view_image = cv2.resize(
288 | view_image,
289 | dsize=None,
290 | fx=1/resize_scale_224x224[0],
291 | fy=1/resize_scale_224x224[1],
292 | )
293 | rescaled_xy = np.asarray([[v[0], v[1]] for v in zip(rrn_lms[0::3], rrn_lms[1::3])], dtype=np.float32)
294 | rescaled_xy[:, 0] = (rescaled_xy[:, 0] * input_w - half_pad_size_224x224[0]) / resize_scale_224x224[0]
295 | rescaled_xy[:, 1] = (rescaled_xy[:, 1] * input_h - half_pad_size_224x224[1]) / resize_scale_224x224[1]
296 | rescaled_xy = rescaled_xy.astype(np.int32)
297 |
298 | height, width = view_image.shape[:2]
299 | image_center = (width//2, height//2)
300 | rotation_matrix = cv2.getRotationMatrix2D(image_center, -int(angle), 1)
301 | abs_cos = abs(rotation_matrix[0,0])
302 | abs_sin = abs(rotation_matrix[0,1])
303 | bound_w = int(height * abs_sin + width * abs_cos)
304 | bound_h = int(height * abs_cos + width * abs_sin)
305 | rotation_matrix[0, 2] += bound_w/2 - image_center[0]
306 | rotation_matrix[1, 2] += bound_h/2 - image_center[1]
307 | rotated_image = cv2.warpAffine(view_image, rotation_matrix, (bound_w, bound_h))
308 |
309 | keypoints = []
310 | for x,y in rescaled_xy:
311 | coord_arr = np.array([
312 | [x, y, 1], # Left-Top
313 | ])
314 | new_coord = rotation_matrix.dot(coord_arr.T)
315 | x_ls = new_coord[0]
316 | y_ls = new_coord[1]
317 | keypoints.append([int(x_ls), int(y_ls)])
318 |
319 | rotated_image_width = rotated_image.shape[1]
320 | rotated_image_height = rotated_image.shape[0]
321 |             rotated_hand_half_width = rotated_image_width//2
322 |             rotated_hand_half_height = rotated_image_height//2
323 |
324 | hand_landmarks = np.asarray(keypoints, dtype=np.int32).reshape(-1,2)
325 |             hand_landmarks[..., 0] = hand_landmarks[..., 0] + rcx - rotated_hand_half_width
326 |             hand_landmarks[..., 1] = hand_landmarks[..., 1] + rcy - rotated_hand_half_height
327 | extracted_hands.append(hand_landmarks)
328 | rotated_image_size_leftrights.append([rotated_image_width, rotated_image_height, left_hand_0_or_right_hand_1])
329 |
330 | return np.asarray(extracted_hands, dtype=np.int32), np.asarray(rotated_image_size_leftrights)
331 |
--------------------------------------------------------------------------------
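
Usage sketch for HandLandmark (not a file in this repository; a minimal example that assumes the script runs from the repository root so the default ONNX path resolves, and that palm rects in [cx, cy, width, height, angle] pixel format have already been obtained, e.g. from PalmDetection; the rect values here are hypothetical):

import cv2
import numpy as np
from model import HandLandmark
from utils.utils import rotate_and_crop_rectangle

hand_landmark = HandLandmark()  # default model path and execution providers

frame = cv2.imread('hand.png')  # image bundled with the repository
frame_height, frame_width = frame.shape[:2]
palm_size = min(frame_width, frame_height) * 0.5
rects = np.asarray(
    [[frame_width / 2, frame_height / 2, palm_size, palm_size, 0.0]],  # hypothetical palm rect
    dtype=np.float32,
)

# Crop the (possibly rotated) palm regions out of the frame.
cropped_palms = rotate_and_crop_rectangle(
    image=frame,
    rects_tmp=rects,
    operation_when_cropping_out_of_range='padding',
)

# 21 (x, y) landmarks per hand in frame coordinates, plus
# [rotated_width, rotated_height, left(0)/right(1)] per detected hand.
landmarks, rotated_image_size_leftrights = hand_landmark(
    images=cropped_palms,
    rects=rects,
)
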
/model/hand_landmark/hand_landmark_sparse_Nx3x224x224.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/hand_landmark/hand_landmark_sparse_Nx3x224x224.onnx
--------------------------------------------------------------------------------
/model/keypoint_classifier/keypoint_classifier.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/keypoint_classifier/keypoint_classifier.hdf5
--------------------------------------------------------------------------------
/model/keypoint_classifier/keypoint_classifier.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/keypoint_classifier/keypoint_classifier.onnx
--------------------------------------------------------------------------------
/model/keypoint_classifier/keypoint_classifier.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import onnxruntime
4 | import numpy as np
5 | from typing import (
6 | Optional,
7 | List,
8 | )
9 |
10 | class KeyPointClassifier(object):
11 | def __init__(
12 | self,
13 | model_path: Optional[str] = 'model/keypoint_classifier/keypoint_classifier.onnx',
14 | providers: Optional[List] = [
15 | # (
16 | # 'TensorrtExecutionProvider', {
17 | # 'trt_engine_cache_enable': True,
18 | # 'trt_engine_cache_path': '.',
19 | # 'trt_fp16_enable': True,
20 | # }
21 | # ),
22 | 'CUDAExecutionProvider',
23 | 'CPUExecutionProvider',
24 | ],
25 | ):
26 | """KeyPointClassifier
27 |
28 | Parameters
29 | ----------
30 | model_path: Optional[str]
31 |             ONNX file path for Keypoint Classification
32 |
33 | providers: Optional[List]
34 | Name of onnx execution providers
35 | Default:
36 | [
37 | (
38 | 'TensorrtExecutionProvider', {
39 | 'trt_engine_cache_enable': True,
40 | 'trt_engine_cache_path': '.',
41 | 'trt_fp16_enable': True,
42 | }
43 | ),
44 | 'CUDAExecutionProvider',
45 | 'CPUExecutionProvider',
46 | ]
47 | """
48 | # Model loading
49 | session_option = onnxruntime.SessionOptions()
50 | session_option.log_severity_level = 3
51 | self.onnx_session = onnxruntime.InferenceSession(
52 | model_path,
53 | sess_options=session_option,
54 | providers=providers,
55 | )
56 | self.providers = self.onnx_session.get_providers()
57 |
58 | self.input_shapes = [
59 | input.shape for input in self.onnx_session.get_inputs()
60 | ]
61 | self.input_names = [
62 | input.name for input in self.onnx_session.get_inputs()
63 | ]
64 | self.output_names = [
65 | output.name for output in self.onnx_session.get_outputs()
66 | ]
67 |
68 |
69 | def __call__(
70 | self,
71 | landmarks: np.ndarray,
72 | ) -> np.ndarray:
73 | """KeyPointClassifier
74 |
75 | Parameters
76 | ----------
77 | landmarks: np.ndarray
78 | Landmarks [N, 42]
79 |
80 | Returns
81 | -------
82 | class_ids: np.ndarray
83 |             int64[N]
84 |             ClassIDs of hand signs
85 | """
86 | class_ids = self.onnx_session.run(
87 | self.output_names,
88 | {input_name: landmarks for input_name in self.input_names},
89 | )[0]
90 |
91 | return class_ids
--------------------------------------------------------------------------------
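
Usage sketch for KeyPointClassifier (not a file in this repository; a minimal example that assumes a float32 [N, 42] input of 21 (x, y) keypoints preprocessed the same way as the rows in model/keypoint_classifier/keypoint.csv, and that the script runs from the repository root):

import csv
import numpy as np
from model import KeyPointClassifier

keypoint_classifier = KeyPointClassifier()  # default ONNX path and providers

# Labels shipped with the repository: Open, Close, Pointer.
with open('model/keypoint_classifier/keypoint_classifier_label.csv', encoding='utf-8-sig') as f:
    labels = [row[0] for row in csv.reader(f) if row]

normalized_landmarks = np.zeros((1, 42), dtype=np.float32)  # placeholder input
class_ids = keypoint_classifier(normalized_landmarks)       # one class id per row after the fused ArgMax
print(labels[int(class_ids[0])])
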
/model/keypoint_classifier/keypoint_classifier.tflite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/keypoint_classifier/keypoint_classifier.tflite
--------------------------------------------------------------------------------
/model/keypoint_classifier/keypoint_classifier_label.csv:
--------------------------------------------------------------------------------
1 | Open
2 | Close
3 | Pointer
4 |
--------------------------------------------------------------------------------
/model/keypoint_classifier/make_argmax.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import onnx
7 | from onnxsim import simplify
8 | from argparse import ArgumentParser
9 |
10 | class Model(nn.Module):
11 | def __init__(self):
12 | super(Model, self).__init__()
13 |
14 | def forward(self, scores):
15 | max_values, max_indices = torch.max(scores, dim=1)
16 | return max_indices
17 |
18 |
19 | if __name__ == "__main__":
20 | parser = ArgumentParser()
21 | parser.add_argument(
22 | '-o',
23 | '--opset',
24 | type=int,
25 | default=11,
26 | help='onnx opset'
27 | )
28 | parser.add_argument(
29 | '-b',
30 | '--batches',
31 | type=int,
32 | default=1,
33 | help='batch size'
34 | )
35 | parser.add_argument(
36 | '-c',
37 | '--classes',
38 | type=int,
39 | default=3,
40 | help='classes'
41 | )
42 | args = parser.parse_args()
43 |
44 | model = Model()
45 |
46 |     MODEL = 'argmax'
47 |     OPSET = args.opset
48 | BATCHES = args.batches
49 | CLASSES = args.classes
50 |
51 | onnx_file = f"{MODEL}.onnx"
52 | scores = torch.randn(BATCHES, CLASSES)
53 |
54 | torch.onnx.export(
55 | model,
56 |         args=(scores,),
57 | f=onnx_file,
58 | opset_version=OPSET,
59 | input_names=['argmax_input'],
60 | output_names=['class_ids'],
61 | dynamic_axes={
62 | 'argmax_input' : {0: 'batch'},
63 | 'class_ids' : {0: 'batch'},
64 | },
65 | )
66 | model_onnx1 = onnx.load(onnx_file)
67 | model_onnx1 = onnx.shape_inference.infer_shapes(model_onnx1)
68 | onnx.save(model_onnx1, onnx_file)
69 |
70 | model_onnx2 = onnx.load(onnx_file)
71 | model_simp, check = simplify(model_onnx2)
72 | onnx.save(model_simp, onnx_file)
--------------------------------------------------------------------------------
/model/keypoint_classifier/tflite_to_onnx.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ $# -gt 1 ]; then
4 | echo "The number of arguments specified is $#." 1>&2
5 |     echo "Specify either no argument or one argument (number of classes). (default:3)" 1>&2
6 | exit 1
7 | fi
8 |
9 | if [ $# -eq 0 ]; then
10 | CLASSES=3
11 | else
12 | CLASSES=$1
13 | fi
14 |
15 | python -m tf2onnx.convert \
16 | --opset 11 \
17 | --tflite keypoint_classifier.tflite \
18 | --output keypoint_classifier.onnx
19 |
20 | onnxsim keypoint_classifier.onnx keypoint_classifier.onnx
21 |
22 | sbi4onnx \
23 | --input_onnx_file_path keypoint_classifier.onnx \
24 | --output_onnx_file_path keypoint_classifier.onnx \
25 | --initialization_character_string batch
26 |
27 | sor4onnx \
28 | --input_onnx_file_path keypoint_classifier.onnx \
29 | --old_new "input_1" "input" \
30 | --mode inputs \
31 | --output_onnx_file_path keypoint_classifier.onnx
32 |
33 | sor4onnx \
34 | --input_onnx_file_path keypoint_classifier.onnx \
35 | --old_new "Identity" "base_scores" \
36 | --mode outputs \
37 | --output_onnx_file_path keypoint_classifier.onnx
38 |
39 | python make_argmax.py --classes ${CLASSES}
40 |
41 | snc4onnx \
42 | -if keypoint_classifier.onnx argmax.onnx \
43 | -of keypoint_classifier.onnx \
44 | -sd base_scores argmax_input
45 |
46 | rm argmax.onnx
47 |
--------------------------------------------------------------------------------
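
A quick sanity check for the ONNX file produced by tflite_to_onnx.sh (my own sketch, not part of the repository): after conversion the graph should expose the renamed 'input' tensor and the fused ArgMax output 'class_ids'. The [batch, 42] float32 input shape is an assumption based on the KeyPointClassifier docstring.

import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession(
    'model/keypoint_classifier/keypoint_classifier.onnx',
    providers=['CPUExecutionProvider'],
)
# Expected names after the sor4onnx renames and the snc4onnx fusion.
print([(i.name, i.shape) for i in session.get_inputs()])   # e.g. [('input', ['batch', 42])]
print([(o.name, o.shape) for o in session.get_outputs()])  # e.g. [('class_ids', ['batch'])]

dummy = np.zeros((1, 42), dtype=np.float32)
print(session.run(None, {'input': dummy})[0])  # one class id per batch row
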
/model/palm_detection/palm_detection.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import (
3 | Tuple,
4 | Optional,
5 | List,
6 | )
7 | from math import (
8 | sin,
9 | cos,
10 | atan2,
11 | pi,
12 | )
13 | import onnxruntime
14 | import numpy as np
15 |
16 | from utils.utils import (
17 | normalize_radians,
18 | keep_aspect_resize_and_pad,
19 | )
20 |
21 |
22 | class PalmDetection(object):
23 | def __init__(
24 | self,
25 | model_path: Optional[str] = 'model/palm_detection/palm_detection_full_inf_post_192x192.onnx',
26 | score_threshold: Optional[float] = 0.60,
27 | providers: Optional[List] = [
28 | # (
29 | # 'TensorrtExecutionProvider', {
30 | # 'trt_engine_cache_enable': True,
31 | # 'trt_engine_cache_path': '.',
32 | # 'trt_fp16_enable': True,
33 | # }
34 | # ),
35 | 'CUDAExecutionProvider',
36 | 'CPUExecutionProvider',
37 | ],
38 | ):
39 | """PalmDetection
40 |
41 | Parameters
42 | ----------
43 | model_path: Optional[str]
44 | ONNX file path for Palm Detection
45 |
46 | score_threshold: Optional[float]
47 | Detection score threshold. Default: 0.60
48 |
49 | providers: Optional[List]
50 | Name of onnx execution providers
51 | Default:
52 | [
53 | (
54 | 'TensorrtExecutionProvider', {
55 | 'trt_engine_cache_enable': True,
56 | 'trt_engine_cache_path': '.',
57 | 'trt_fp16_enable': True,
58 | }
59 | ),
60 | 'CUDAExecutionProvider',
61 | 'CPUExecutionProvider',
62 | ]
63 | """
64 | # Threshold
65 | self.score_threshold = score_threshold
66 |
67 | # Model loading
68 | session_option = onnxruntime.SessionOptions()
69 | session_option.log_severity_level = 3
70 | self.onnx_session = onnxruntime.InferenceSession(
71 | model_path,
72 | sess_options=session_option,
73 | providers=providers,
74 | )
75 | self.providers = self.onnx_session.get_providers()
76 |
77 | self.input_shapes = [
78 | input.shape for input in self.onnx_session.get_inputs()
79 | ]
80 | self.input_names = [
81 | input.name for input in self.onnx_session.get_inputs()
82 | ]
83 | self.output_names = [
84 | output.name for output in self.onnx_session.get_outputs()
85 | ]
86 | self.square_standard_size = 0
87 |
88 |
89 | def __call__(
90 | self,
91 | image: np.ndarray,
92 |     ) -> np.ndarray:
93 | """PalmDetection
94 |
95 | Parameters
96 | ----------
97 | image: np.ndarray
98 | Entire image
99 |
100 | Returns
101 | -------
102 | hands: np.ndarray
103 | float32[N, 4]
104 | sqn_rr_size, rotation, sqn_rr_center_x, sqn_rr_center_y
105 | """
106 | temp_image = copy.deepcopy(image)
107 |
108 | # PreProcess
109 | inference_image = self.__preprocess(
110 | temp_image,
111 | )
112 |
113 | # Inference
114 |         inference_images = np.asarray([inference_image], dtype=np.float32)
115 | boxes = self.onnx_session.run(
116 | self.output_names,
117 |             {input_name: inference_images for input_name in self.input_names},
118 | )
119 |
120 | # PostProcess
121 | hands = self.__postprocess(
122 | image=temp_image,
123 | boxes=boxes[0],
124 | )
125 |
126 | return hands
127 |
128 |
129 | def __preprocess(
130 | self,
131 | image: np.ndarray,
132 | swap: Optional[Tuple[int,int,int]] = (2, 0, 1),
133 | ) -> np.ndarray:
134 | """__preprocess
135 |
136 | Parameters
137 | ----------
138 | image: np.ndarray
139 | Entire image
140 |
141 | swap: tuple
142 | HWC to CHW: (2,0,1)
143 | CHW to HWC: (1,2,0)
144 | HWC to HWC: (0,1,2)
145 | CHW to CHW: (0,1,2)
146 |
147 | Returns
148 | -------
149 | padded_image: np.ndarray
150 |             Resized, padded and normalized image.
151 | """
152 | # Resize + Padding + Normalization + BGR->RGB
153 | input_h = self.input_shapes[0][2]
154 | input_w = self.input_shapes[0][3]
155 | image_height , image_width = image.shape[:2]
156 |
157 | self.square_standard_size = max(image_height, image_width)
158 | self.square_padding_half_size = abs(image_height - image_width) // 2
159 |
160 | padded_image, resized_image = keep_aspect_resize_and_pad(
161 | image=image,
162 | resize_width=input_w,
163 | resize_height=input_h,
164 | )
165 |
166 | pad_size_half_h = max(0, (input_h - resized_image.shape[0]) // 2)
167 | pad_size_half_w = max(0, (input_w - resized_image.shape[1]) // 2)
168 |
169 | self.pad_size_scale_h = pad_size_half_h / input_h
170 | self.pad_size_scale_w = pad_size_half_w / input_w
171 |
172 | padded_image = np.divide(padded_image, 255.0)
173 | padded_image = padded_image[..., ::-1]
174 | padded_image = padded_image.transpose(swap)
175 | padded_image = np.ascontiguousarray(
176 | padded_image,
177 | dtype=np.float32,
178 | )
179 | return padded_image
180 |
181 |
182 | def __postprocess(
183 | self,
184 | image: np.ndarray,
185 | boxes: np.ndarray,
186 | ) -> np.ndarray:
187 | """__postprocess
188 |
189 | Parameters
190 | ----------
191 | image: np.ndarray
192 | Entire image.
193 |
194 | boxes: np.ndarray
195 | float32[N, 8]
196 | pd_score, box_x, box_y, box_size, kp0_x, kp0_y, kp2_x, kp2_y
197 |
198 | Returns
199 | -------
200 | hands: np.ndarray
201 | float32[N, 4]
202 | sqn_rr_size, rotation, sqn_rr_center_x, sqn_rr_center_y
203 | """
204 | image_height = image.shape[0]
205 | image_width = image.shape[1]
206 |
207 | hands = []
208 | keep = boxes[:, 0] > self.score_threshold # pd_score > self.score_threshold
209 | boxes = boxes[keep, :]
210 |
211 | for box in boxes:
212 | pd_score, box_x, box_y, box_size, kp0_x, kp0_y, kp2_x, kp2_y = box
213 | if box_size > 0:
214 | kp02_x = kp2_x - kp0_x
215 | kp02_y = kp2_y - kp0_y
216 | sqn_rr_size = 2.9 * box_size
217 | rotation = 0.5 * pi - atan2(-kp02_y, kp02_x)
218 | rotation = normalize_radians(rotation)
219 | sqn_rr_center_x = box_x + 0.5*box_size*sin(rotation)
220 | sqn_rr_center_y = box_y - 0.5*box_size*cos(rotation)
221 | sqn_rr_center_y = (sqn_rr_center_y * self.square_standard_size - self.square_padding_half_size) / image_height
222 | hands.append(
223 | [
224 | sqn_rr_size,
225 | rotation,
226 | sqn_rr_center_x,
227 | sqn_rr_center_y,
228 | ]
229 | )
230 |
231 | return np.asarray(hands)
--------------------------------------------------------------------------------
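
Usage sketch for PalmDetection (not a file in this repository): run detection on a single BGR frame and convert the normalized outputs into pixel-space rects, following the conversion outlined in the comment block inside hand_landmark.py. Treating wh_ratio as frame_width / frame_height is an assumption.

from math import degrees
import cv2
import numpy as np
from model import PalmDetection

palm_detection = PalmDetection(score_threshold=0.60)

frame = cv2.imread('hand.png')
frame_height, frame_width = frame.shape[:2]
wh_ratio = frame_width / frame_height  # assumption

# float32 [N, 4]: sqn_rr_size, rotation, sqn_rr_center_x, sqn_rr_center_y
hands = palm_detection(frame)

rects = []
for sqn_rr_size, rotation, sqn_rr_center_x, sqn_rr_center_y in hands:
    cx = int(sqn_rr_center_x * frame_width)
    cy = int(sqn_rr_center_y * frame_height)
    xmin = int((sqn_rr_center_x - sqn_rr_size / 2) * frame_width)
    xmax = int((sqn_rr_center_x + sqn_rr_size / 2) * frame_width)
    ymin = int((sqn_rr_center_y - sqn_rr_size * wh_ratio / 2) * frame_height)
    ymax = int((sqn_rr_center_y + sqn_rr_size * wh_ratio / 2) * frame_height)
    rects.append([cx, cy, xmax - xmin, ymax - ymin, degrees(rotation)])
rects = np.asarray(rects, dtype=np.float32)  # [N, cx, cy, width, height, degree]
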
/model/palm_detection/palm_detection_full_inf_post_192x192.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/palm_detection/palm_detection_full_inf_post_192x192.onnx
--------------------------------------------------------------------------------
/model/point_history_classifier/argmax.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/point_history_classifier/argmax.onnx
--------------------------------------------------------------------------------
/model/point_history_classifier/make_argmax.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import onnx
7 | from onnxsim import simplify
8 | from argparse import ArgumentParser
9 |
10 | class Model(nn.Module):
11 | def __init__(self):
12 | super(Model, self).__init__()
13 | self.invalid_value = 0
14 |
15 | def forward(self, scores, score_threshold):
16 | max_scores, class_ids = torch.max(scores, dim=1)
17 |
18 | invalid_idxs = max_scores < score_threshold
19 | class_ids[invalid_idxs] = self.invalid_value
20 |
21 | return class_ids
22 |
23 |
24 | if __name__ == "__main__":
25 | parser = ArgumentParser()
26 | parser.add_argument(
27 | '-o',
28 | '--opset',
29 | type=int,
30 | default=11,
31 | help='onnx opset'
32 | )
33 | parser.add_argument(
34 | '-b',
35 | '--batches',
36 | type=int,
37 | default=1,
38 | help='batch size'
39 | )
40 | parser.add_argument(
41 | '-c',
42 | '--classes',
43 | type=int,
44 | default=4,
45 | help='classes'
46 | )
47 | args = parser.parse_args()
48 |
49 | model = Model()
50 |
51 |     MODEL = 'argmax'
52 |     OPSET = args.opset
53 | BATCHES = args.batches
54 | CLASSES = args.classes
55 |
56 | onnx_file = f"{MODEL}.onnx"
57 | scores = torch.randn(BATCHES, CLASSES)
58 | score_threshold = torch.tensor(0.5, dtype=torch.float32)
59 |
60 | torch.onnx.export(
61 | model,
62 | args=(scores, score_threshold),
63 | f=onnx_file,
64 | opset_version=OPSET,
65 | input_names=['argmax_input', 'score_threshold'],
66 | output_names=['class_ids'],
67 | dynamic_axes={
68 | 'argmax_input' : {0: 'batch'},
69 | 'class_ids' : {0: 'batch'},
70 | },
71 | )
72 | model_onnx1 = onnx.load(onnx_file)
73 | model_onnx1 = onnx.shape_inference.infer_shapes(model_onnx1)
74 | onnx.save(model_onnx1, onnx_file)
75 |
76 | model_onnx2 = onnx.load(onnx_file)
77 | model_simp, check = simplify(model_onnx2)
78 | onnx.save(model_simp, onnx_file)
--------------------------------------------------------------------------------
/model/point_history_classifier/point_history_classifier.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/point_history_classifier/point_history_classifier.hdf5
--------------------------------------------------------------------------------
/model/point_history_classifier/point_history_classifier.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/point_history_classifier/point_history_classifier.onnx
--------------------------------------------------------------------------------
/model/point_history_classifier/point_history_classifier.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import onnxruntime
4 | import numpy as np
5 | from typing import (
6 | Optional,
7 | List,
8 | )
9 |
10 |
11 | class PointHistoryClassifier(object):
12 | def __init__(
13 | self,
14 | model_path: Optional[str] = 'model/point_history_classifier/point_history_classifier_lstm.onnx',
15 | providers: Optional[List] = [
16 | # (
17 | # 'TensorrtExecutionProvider', {
18 | # 'trt_engine_cache_enable': True,
19 | # 'trt_engine_cache_path': '.',
20 | # 'trt_fp16_enable': True,
21 | # }
22 | # ),
23 | 'CUDAExecutionProvider',
24 | 'CPUExecutionProvider',
25 | ],
26 |         score_th: Optional[float] = 0.5,
27 | ):
28 | """PointHistoryClassifier
29 |
30 | Parameters
31 | ----------
32 | model_path: Optional[str]
33 |             ONNX file path for Point History Classification
34 |
35 | providers: Optional[List]
36 | Name of onnx execution providers
37 | Default:
38 | [
39 | (
40 | 'TensorrtExecutionProvider', {
41 | 'trt_engine_cache_enable': True,
42 | 'trt_engine_cache_path': '.',
43 | 'trt_fp16_enable': True,
44 | }
45 | ),
46 | 'CUDAExecutionProvider',
47 | 'CPUExecutionProvider',
48 | ]
49 | """
50 | # Model loading
51 | session_option = onnxruntime.SessionOptions()
52 | session_option.log_severity_level = 3
53 | self.onnx_session = onnxruntime.InferenceSession(
54 | model_path,
55 | sess_options=session_option,
56 | providers=providers,
57 | )
58 | self.providers = self.onnx_session.get_providers()
59 |
60 | self.input_shapes = [
61 | input.shape for input in self.onnx_session.get_inputs()
62 | ]
63 | self.input_names = [
64 | input.name for input in self.onnx_session.get_inputs()
65 | ]
66 | self.output_names = [
67 | output.name for output in self.onnx_session.get_outputs()
68 | ]
69 | self.score_th = np.asarray(score_th, dtype=np.float32)
70 |
71 |
72 | def __call__(
73 | self,
74 | point_history: np.ndarray,
75 | ) -> np.ndarray:
76 | """PointHistoryClassifier
77 |
78 | Parameters
79 | ----------
80 | point_history: np.ndarray
81 |             Point history [N, 32]
82 |
83 | Returns
84 | -------
85 | class_ids: np.ndarray
86 | int64[N]
87 |             ClassIDs of finger gestures
88 | """
89 | class_ids = self.onnx_session.run(
90 | self.output_names,
91 | {
92 | self.input_names[0]: point_history,
93 | self.input_names[1]: self.score_th,
94 | },
95 | )[0]
96 |
97 | return class_ids
98 |
--------------------------------------------------------------------------------
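
Usage sketch for PointHistoryClassifier (not a file in this repository): the [N, 32] input shape follows the docstring above, and the coordinate normalization is assumed to match the rows in model/point_history_classifier/point_history.csv.

import csv
import numpy as np
from model import PointHistoryClassifier

point_history_classifier = PointHistoryClassifier(score_th=0.5)

# Labels shipped with the repository: Stop, Clockwise, Counter Clockwise, Move.
with open('model/point_history_classifier/point_history_classifier_label.csv', encoding='utf-8-sig') as f:
    labels = [row[0] for row in csv.reader(f) if row]

point_history = np.zeros((1, 32), dtype=np.float32)  # placeholder history
class_ids = point_history_classifier(point_history)
# Frames whose best score falls below score_th are mapped to class 0 by the fused ArgMax model.
print(labels[int(class_ids[0])])
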
/model/point_history_classifier/point_history_classifier.tflite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/point_history_classifier/point_history_classifier.tflite
--------------------------------------------------------------------------------
/model/point_history_classifier/point_history_classifier_label.csv:
--------------------------------------------------------------------------------
1 | Stop
2 | Clockwise
3 | Counter Clockwise
4 | Move
5 |
--------------------------------------------------------------------------------
/model/point_history_classifier/point_history_classifier_lstm.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PINTO0309/hand-gesture-recognition-using-onnx/baf475d8015183dbff17e56771c0f623c8b5c816/model/point_history_classifier/point_history_classifier_lstm.onnx
--------------------------------------------------------------------------------
/model/point_history_classifier/tflite_to_onnx.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ $# -gt 1 ]; then
4 | echo "The number of arguments specified is $#." 1>&2
5 |     echo "Specify either no argument or one argument (number of classes). (default:4)" 1>&2
6 | exit 1
7 | fi
8 |
9 | if [ $# -eq 0 ]; then
10 | CLASSES=4
11 | else
12 | CLASSES=$1
13 | fi
14 |
15 | python -m tf2onnx.convert \
16 | --opset 11 \
17 | --tflite point_history_classifier.tflite \
18 | --output point_history_classifier.onnx
19 |
20 | onnxsim point_history_classifier.onnx point_history_classifier.onnx
21 |
22 | sbi4onnx \
23 | --input_onnx_file_path point_history_classifier.onnx \
24 | --output_onnx_file_path point_history_classifier.onnx \
25 | --initialization_character_string batch
26 |
27 | sor4onnx \
28 | --input_onnx_file_path point_history_classifier.onnx \
29 | --old_new "input_1" "input" \
30 | --mode inputs \
31 | --output_onnx_file_path point_history_classifier.onnx
32 |
33 | sor4onnx \
34 | --input_onnx_file_path point_history_classifier.onnx \
35 | --old_new "Identity" "base_scores" \
36 | --mode outputs \
37 | --output_onnx_file_path point_history_classifier.onnx
38 |
39 | python make_argmax.py --classes ${CLASSES}
40 |
41 | snc4onnx \
42 | -if point_history_classifier.onnx argmax.onnx \
43 | -of point_history_classifier.onnx \
44 | -sd base_scores argmax_input
45 |
46 | rm argmax.onnx
47 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | onnxruntime >= 1.12.0
2 | opencv-contrib-python >= 4.6.0.66
3 | scikit-learn >= 0.23.2
4 | matplotlib >= 3.3.2
5 | tf2onnx >= 1.12.0
6 | simple-onnx-processing-tools >= 1.0.54
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from utils.cvfpscalc import CvFpsCalc
--------------------------------------------------------------------------------
/utils/cvfpscalc.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | import cv2 as cv
3 |
4 |
5 | class CvFpsCalc(object):
6 | def __init__(self, buffer_len=1):
7 | self._start_tick = cv.getTickCount()
8 | self._freq = 1000.0 / cv.getTickFrequency()
9 | self._difftimes = deque(maxlen=buffer_len)
10 |
11 | def get(self):
12 | current_tick = cv.getTickCount()
13 | different_time = (current_tick - self._start_tick) * self._freq
14 | self._start_tick = current_tick
15 |
16 | self._difftimes.append(different_time)
17 |
18 | fps = 1000.0 / (sum(self._difftimes) / len(self._difftimes))
19 | fps_rounded = round(fps, 2)
20 |
21 | return fps_rounded
22 |
--------------------------------------------------------------------------------
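
Usage sketch for CvFpsCalc (not a file in this repository): a minimal webcam loop that overlays the measured FPS on each frame.

import cv2 as cv
from utils import CvFpsCalc

cvFpsCalc = CvFpsCalc(buffer_len=10)
cap = cv.VideoCapture(0)

while True:
    fps = cvFpsCalc.get()
    ret, frame = cap.read()
    if not ret:
        break
    cv.putText(frame, f'FPS: {fps}', (10, 30),
               cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 2, cv.LINE_AA)
    cv.imshow('CvFpsCalc demo', frame)
    if cv.waitKey(1) == 27:  # ESC to quit
        break

cap.release()
cv.destroyAllWindows()
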
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import copy
4 | import numpy as np
5 | from typing import Tuple, List
6 | from math import sin, cos, pi, floor
7 |
8 |
9 | def normalize_radians(
10 | angle: float
11 | ) -> float:
12 |     """normalize_radians
13 |
14 | Parameters
15 | ----------
16 | angle: float
17 |
18 | Returns
19 | -------
20 | normalized_angle: float
21 | """
22 | return angle - 2 * pi * floor((angle + pi) / (2 * pi))
23 |
24 |
25 | def is_inside_rect(
26 | rects: np.ndarray,
27 | width_of_outer_rect: int,
28 | height_of_outer_rect: int,
29 | ) -> np.ndarray:
30 |     """Determines whether each rect is inside or outside the outer rectangle.
31 |
32 | Parameters
33 | ----------
34 | rects: np.ndarray
35 | [boxcount, 5] = [boxcount, cx, cy, width, height, angle]\n
36 | Area to be verified.
37 |
38 | cx: float
39 | Rectangle center X coordinate.
40 |
41 | cy: float
42 | Rectangle center Y coordinate.
43 |
44 | width: float
45 | Width of the rectangle.
46 |
47 | height: float
48 | Height of the rectangle.
49 |
50 | angle: float
51 | The rotation angle in a clockwise direction.
52 | When the angle is 0, 90, 180, 270, 360 etc.,
53 | the rectangle becomes an up-right rectangle.
54 |
55 | width_of_outer_rect: int
56 | Width of outer rectangle.
57 |
58 | height_of_outer_rect: int
59 | Height of outer rectangle
60 |
61 | Returns
62 | -------
63 | result: np.ndarray
64 |         True if the rotated sub rectangle is inside the up-right rectangle, False otherwise.
65 | """
66 | results = []
67 |
68 | for rect in rects:
69 | cx = rect[0]
70 | cy = rect[1]
71 | width = rect[2]
72 | height = rect[3]
73 | angle = rect[4]
74 |
75 | if (cx < 0) or (cx > width_of_outer_rect):
76 | # Center X coordinate is outside the range of the outer rectangle
77 | results.append(False)
78 |
79 | elif (cy < 0) or (cy > height_of_outer_rect):
80 | # Center Y coordinate is outside the range of the outer rectangle
81 | results.append(False)
82 |
83 | else:
84 | # Coordinate acquisition of bounding rectangle considering rotation
85 | # http://labs.eecs.tottori-u.ac.jp/sd/Member/oyamada/OpenCV/html/py_tutorials/py_imgproc/py_contours/py_contour_features/py_contour_features.html#b
86 | rect_tuple = ((cx, cy), (width, height), angle)
87 | box = cv2.boxPoints(rect_tuple)
88 |
89 | x_max = int(np.max(box[:,0]))
90 | x_min = int(np.min(box[:,0]))
91 | y_max = int(np.max(box[:,1]))
92 | y_min = int(np.min(box[:,1]))
93 |
94 | if (x_min >= 0) and (x_max <= width_of_outer_rect) and \
95 | (y_min >= 0) and (y_max <= height_of_outer_rect):
96 | # All 4 vertices are within the perimeter rectangle
97 | results.append(True)
98 | else:
99 | # Any of the 4 vertices is outside the perimeter rectangle
100 | results.append(False)
101 |
102 | return np.asarray(results, dtype=np.bool_)
103 |
104 |
105 | def bounding_box_from_rotated_rect(
106 | rects: np.ndarray,
107 | ) -> np.ndarray:
108 | """Conversion to bounding rectangle without rotation.
109 |
110 | Parameters
111 | ----------
112 | rects: np.ndarray
113 | [boxcount, 5] = [boxcount, cx, cy, width, height, angle]\n
114 | Rotated rectangle.
115 |
116 | cx: float
117 | Rectangle center X coordinate.
118 |
119 | cy: float
120 | Rectangle center Y coordinate.
121 |
122 | width: float
123 | Width of the rectangle.
124 |
125 | height: float
126 | Height of the rectangle.
127 |
128 | angle: float
129 | The rotation angle in a clockwise direction.
130 | When the angle is 0, 90, 180, 270, 360 etc.,
131 | the rectangle becomes an up-right rectangle.
132 |
133 | Returns
134 | -------
135 | result: np.ndarray
136 | e.g.:\n
137 | [input] rotated rectangle:\n
138 | [center:(10, 10), height:4, width:4, angle:45 degree]\n
139 | [output] bounding box for this rotated rectangle:\n
140 | [center:(10, 10), height:4*sqrt(2), width:4*sqrt(2), angle:0 degree]
141 | """
142 | results = []
143 |
144 | for rect in rects:
145 | cx = rect[0]
146 | cy = rect[1]
147 | width = rect[2]
148 | height = rect[3]
149 | angle = rect[4]
150 |
151 | rect_tuple = ((cx, cy), (width, height), angle)
152 | box = cv2.boxPoints(rect_tuple)
153 |
154 | x_max = int(np.max(box[:,0]))
155 | x_min = int(np.min(box[:,0]))
156 | y_max = int(np.max(box[:,1]))
157 | y_min = int(np.min(box[:,1]))
158 |
159 | cx = int((x_min + x_max) // 2)
160 | cy = int((y_min + y_max) // 2)
161 | width = int(x_max - x_min)
162 | height = int(y_max - y_min)
163 | angle = 0
164 | results.append([cx, cy, width, height, angle])
165 |
166 | return np.asarray(results, dtype=np.float32)
167 |
168 |
169 | def image_rotation_without_crop(
170 | images: List[np.ndarray],
171 | angles: np.ndarray,
172 | ) -> List[np.ndarray]:
173 |     """Rotate images without cropping.
174 |
175 | Parameters
176 | ----------
177 | images: List[np.ndarray]
178 | Image to be rotated.
179 |
180 | angles: np.ndarray
181 |         Rotation angles in degrees.
182 |
183 | Returns
184 | -------
185 | rotated_images: List[np.ndarray]
186 | Image after rotation.
187 | """
188 | rotated_images = []
189 | # https://stackoverflow.com/questions/22041699/rotate-an-image-without-cropping-in-opencv-in-c
190 | for image, angle in zip(images, angles):
191 | height, width = image.shape[:2]
192 | image_center = (width//2, height//2)
193 | rotation_matrix = cv2.getRotationMatrix2D(image_center, int(angle), 1)
194 | abs_cos = abs(rotation_matrix[0,0])
195 | abs_sin = abs(rotation_matrix[0,1])
196 | bound_w = int(height * abs_sin + width * abs_cos)
197 | bound_h = int(height * abs_cos + width * abs_sin)
198 | rotation_matrix[0, 2] += bound_w/2 - image_center[0]
199 | rotation_matrix[1, 2] += bound_h/2 - image_center[1]
200 | rotated_image = cv2.warpAffine(image, rotation_matrix, (bound_w, bound_h))
201 | rotated_images.append(rotated_image)
202 |
203 | return rotated_images
204 |
205 |
206 | def crop_rectangle(
207 | image: np.ndarray,
208 | rects: np.ndarray,
209 | ) -> List[np.ndarray]:
210 |     """Crop upright rectangles from an image (rects have to be upright).
211 |
212 | Parameters
213 | ----------
214 | image: np.ndarray
215 |         Image to be cropped.
216 |
217 | rects: np.ndarray
218 | [boxcount, 5] = [boxcount, cx, cy, width, height, angle]\n
219 |         Upright rectangles to crop.
220 |
221 | cx: float
222 | Rectangle center X coordinate.
223 |
224 | cy: float
225 | Rectangle center Y coordinate.
226 |
227 | width: float
228 | Width of the rectangle.
229 |
230 | height: float
231 | Height of the rectangle.
232 |
233 | angle: float
234 | The rotation angle in a clockwise direction.
235 | When the angle is 0, 90, 180, 270, 360 etc.,
236 | the rectangle becomes an up-right rectangle.
237 |
238 | Returns
239 | -------
240 | croped_images: List[np.ndarray]
241 | Image after cropping.
242 | """
243 | croped_images = []
244 | height = image.shape[0]
245 | width = image.shape[1]
246 |
247 | # Determine if rect is inside the entire image
248 | inside_or_outsides = is_inside_rect(
249 | rects=rects,
250 | width_of_outer_rect=width,
251 | height_of_outer_rect=height,
252 | )
253 |
254 | rects = rects[inside_or_outsides, ...]
255 |
256 | for rect in rects:
257 | cx = int(rect[0])
258 | cy = int(rect[1])
259 | rect_width = int(rect[2])
260 | rect_height = int(rect[3])
261 |
262 | croped_image = image[
263 | cy-rect_height//2:cy+rect_height-rect_height//2,
264 | cx-rect_width//2:cx+rect_width-rect_width//2,
265 | ]
266 | croped_images.append(croped_image)
267 |
268 | return croped_images
269 |
270 |
271 | def rotate_and_crop_rectangle(
272 | image: np.ndarray,
273 | rects_tmp: np.ndarray,
274 | operation_when_cropping_out_of_range: str,
275 | ) -> List[np.ndarray]:
276 | """Crop a rotated rectangle from a image.
277 |
278 | Parameters
279 | ----------
280 | image: np.ndarray
281 |         Image to be rotated and cropped.
282 |
283 |     rects_tmp: np.ndarray
284 | [boxcount, 5] = [boxcount, cx, cy, width, height, angle]\n
285 |         Rotated rectangles to crop.
286 |
287 | cx: float
288 | Rectangle center X coordinate.
289 |
290 | cy: float
291 | Rectangle center Y coordinate.
292 |
293 | width: float
294 | Width of the rectangle.
295 |
296 | height: float
297 | Height of the rectangle.
298 |
299 | angle: float
300 | The rotation angle in a clockwise direction.
301 | When the angle is 0, 90, 180, 270, 360 etc.,
302 | the rectangle becomes an up-right rectangle.
303 |
304 | operation_when_cropping_out_of_range: str
305 | 'padding' or 'ignore'
306 |
307 | Returns
308 | -------
309 | rotated_croped_image: List[np.ndarray]
310 | Image after cropping and rotation.
311 | """
312 | rects = copy.deepcopy(rects_tmp)
313 | rotated_croped_images = []
314 | height = image.shape[0]
315 | width = image.shape[1]
316 |
317 | # Determine if rect is inside the entire image
318 | if operation_when_cropping_out_of_range == 'padding':
319 | size = (int(math.sqrt(width ** 2 + height ** 2)) + 2) * 2
320 | image = pad_image(
321 | image=image,
322 | resize_width=size,
323 | resize_height=size,
324 | )
325 | rects[:, 0] = rects[:, 0] + abs(size-width) / 2
326 | rects[:, 1] = rects[:, 1] + abs(size-height) / 2
327 |
328 | elif operation_when_cropping_out_of_range == 'ignore':
329 | inside_or_outsides = is_inside_rect(
330 | rects=rects,
331 | width_of_outer_rect=width,
332 | height_of_outer_rect=height,
333 | )
334 | rects = rects[inside_or_outsides, ...]
335 |
336 | rect_bbx_upright = bounding_box_from_rotated_rect(
337 | rects=rects,
338 | )
339 |
340 | rect_bbx_upright_images = crop_rectangle(
341 | image=image,
342 | rects=rect_bbx_upright,
343 | )
344 |
345 | rotated_rect_bbx_upright_images = image_rotation_without_crop(
346 | images=rect_bbx_upright_images,
347 | angles=rects[..., 4:5],
348 | )
349 |
350 | for rotated_rect_bbx_upright_image, rect in zip(rotated_rect_bbx_upright_images, rects):
351 | crop_cx = rotated_rect_bbx_upright_image.shape[1]//2
352 | crop_cy = rotated_rect_bbx_upright_image.shape[0]//2
353 | rect_width = int(rect[2])
354 | rect_height = int(rect[3])
355 |
356 | rotated_croped_images.append(
357 | rotated_rect_bbx_upright_image[
358 | crop_cy-rect_height//2:crop_cy+(rect_height-rect_height//2),
359 | crop_cx-rect_width//2:crop_cx+(rect_width-rect_width//2),
360 | ]
361 | )
362 |
363 | return rotated_croped_images
364 |
365 |
366 | def keep_aspect_resize_and_pad(
367 | image: np.ndarray,
368 | resize_width: int,
369 | resize_height: int,
370 | ) -> Tuple[np.ndarray, np.ndarray]:
371 | """Resize the image based on the long side to fit inside
372 | the specified bounding rectangle while maintaining the aspect ratio,
373 |     and pad the short side to the specified size.
374 |
375 | Parameters
376 | ----------
377 | image: np.ndarray
378 | Image to be resize and pad.
379 |
380 | resize_width: int
381 | Width of outer rectangle.
382 |
383 |     resize_height: int
384 |         Height of outer rectangle.
385 |
386 | Returns
387 | -------
388 | padded_image: np.ndarray
389 | Image after padding.
390 |
391 | resized_image: np.ndarray
392 | Image after resize. (Before padding)
393 | """
394 | image_height = image.shape[0]
395 | image_width = image.shape[1]
396 | padded_image = np.zeros(
397 | (resize_height, resize_width, 3),
398 | np.uint8
399 | )
400 | ash = resize_height / image_height
401 | asw = resize_width / image_width
402 | if asw < ash:
403 | sizeas = (int(image_width * asw), int(image_height * asw))
404 | else:
405 | sizeas = (int(image_width * ash), int(image_height * ash))
406 | resized_image = cv2.resize(image, dsize=sizeas)
407 | start_h = int(resize_height / 2 - sizeas[1] / 2)
408 | end_h = int(resize_height / 2 + sizeas[1] / 2)
409 | start_w = int(resize_width / 2 - sizeas[0] / 2)
410 | end_w = int(resize_width / 2 + sizeas[0] / 2)
411 | padded_image[start_h:end_h, start_w:end_w, :] = resized_image.copy()
412 |
413 | return padded_image, resized_image
414 |
415 |
416 | def pad_image(
417 | image: np.ndarray,
418 | resize_width: int,
419 | resize_height: int,
420 | ) -> np.ndarray:
421 | """Padding the perimeter of the image to the specified bounding rectangle size.
422 |
423 | Parameters
424 | ----------
425 | image: np.ndarray
426 | Image to be resize and pad.
427 |
428 | resize_width: int
429 | Width of outer rectangle.
430 |
431 |     resize_height: int
432 |         Height of outer rectangle.
433 |
434 | Returns
435 | -------
436 | padded_image: np.ndarray
437 | Image after padding.
438 | """
439 | image_height = image.shape[0]
440 | image_width = image.shape[1]
441 |
442 | if resize_width < image_width:
443 | resize_width = image_width
444 | if resize_height < image_height:
445 | resize_height = image_height
446 |
447 | padded_image = np.zeros(
448 | (resize_height, resize_width, 3),
449 | np.uint8
450 | )
451 | start_h = int(resize_height / 2 - image_height / 2)
452 | end_h = int(resize_height / 2 + image_height / 2)
453 | start_w = int(resize_width / 2 - image_width / 2)
454 | end_w = int(resize_width / 2 + image_width / 2)
455 | padded_image[start_h:end_h, start_w:end_w, :] = image
456 |
457 | return padded_image
458 |
459 |
460 | def rotated_rect_to_points(
461 | cx: float,
462 | cy: float,
463 | width: float,
464 | height: float,
465 | rotation: float,
466 | ) -> List[List[int]]:
467 | """Obtain coordinates for drawing a rotated rectangle
468 |
469 | Parameters
470 | ----------
471 | cx: float
472 | Rectangle center X coordinate.
473 |
474 | cy: float
475 | Rectangle center Y coordinate.
476 |
477 | width: float
478 | Width of the rectangle.
479 |
480 | height: float
481 | Height of the rectangle.
482 |
483 | rotation: float
484 | The rotation radians in a clockwise direction.
485 |
486 | Returns
487 | -------
488 |     points: List[List[int]]
489 |         4 corner coordinates for drawing the rotated rectangle.
490 |
491 | e.g.\n
492 | cv2.polylines(\n
493 | img=frame,\n
494 | pts=[np.array([[p0x,p0y], [p1x,p1y], [p2x,p2y], [p3x,p3y]])],\n
495 | isClosed=True,\n
496 | color=(0,255,255),\n
497 | thickness=2,\n
498 | lineType=cv2.LINE_AA,\n
499 | )
500 | """
501 | b = cos(rotation) * 0.5
502 | a = sin(rotation) * 0.5
503 | p0x = cx - a*height - b*width
504 | p0y = cy + b*height - a*width
505 | p1x = cx + a*height - b*width
506 | p1y = cy - b*height - a*width
507 | p2x = int(2*cx - p0x)
508 | p2y = int(2*cy - p0y)
509 | p3x = int(2*cx - p1x)
510 | p3y = int(2*cy - p1y)
511 | p0x, p0y, p1x, p1y = int(p0x), int(p0y), int(p1x), int(p1y)
512 |
513 | return [[p0x,p0y], [p1x,p1y], [p2x,p2y], [p3x,p3y]]
514 |
--------------------------------------------------------------------------------
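
A small sketch exercising the helpers in utils/utils.py (not a file in this repository): letterbox an image to the 224x224 input size used by the hand landmark model, crop a hypothetical rotated rectangle, and compute corner points for drawing it.

import cv2
import numpy as np
from utils.utils import (
    keep_aspect_resize_and_pad,
    rotate_and_crop_rectangle,
    rotated_rect_to_points,
)

frame = cv2.imread('hand.png')
frame_height, frame_width = frame.shape[:2]

# Aspect-preserving resize with centered zero padding, as used by the ONNX preprocessors.
padded, resized = keep_aspect_resize_and_pad(image=frame, resize_width=224, resize_height=224)
print(padded.shape)  # (224, 224, 3)

# Crop a hypothetical box rotated by 45 degrees around the frame center;
# out-of-range pixels are handled by padding the source frame first.
box_size = min(frame_width, frame_height) / 2
rects = np.asarray(
    [[frame_width / 2, frame_height / 2, box_size, box_size, 45]],
    dtype=np.float32,
)
crops = rotate_and_crop_rectangle(
    image=frame,
    rects_tmp=rects,
    operation_when_cropping_out_of_range='padding',
)
print(crops[0].shape)  # approximately (box_size, box_size, 3)

# Corner points for drawing the same rotated rect (rotation is given in radians).
points = rotated_rect_to_points(
    frame_width / 2, frame_height / 2, box_size, box_size, np.deg2rad(45),
)
cv2.polylines(frame, [np.asarray(points, dtype=np.int32)], isClosed=True,
              color=(0, 255, 255), thickness=2, lineType=cv2.LINE_AA)
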