├── requirements.txt ├── .gitattributes ├── example.png ├── __init__.py ├── .gitignore ├── pyproject.toml ├── .github └── workflows │ └── publish.yml ├── readme.md ├── LICENSE ├── utils.py ├── models └── onnx_models.py ├── onetoall ├── utils.py └── infer_function.py ├── nodes.py ├── retarget_pose.py └── pose_utils ├── pose2d_utils.py └── human_visualization.py /requirements.txt: -------------------------------------------------------------------------------- 1 | onnx 2 | onnxruntime-gpu 3 | opencv-python -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-WanAnimatePreprocess/HEAD/example.png -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | output/ 2 | *__pycache__/ 3 | samples*/ 4 | runs/ 5 | checkpoints/ 6 | master_ip 7 | logs/ 8 | *.DS_Store 9 | .idea 10 | tools/ 11 | .vscode/ 12 | convert_* 13 | *.pt -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "ComfyUI-WanAnimatePreprocess" 3 | description = "ComfyUI nodes for WanAnimate input processing" 4 | version = "1.0.2" 5 | license = {file = "LICENSE"} 6 | dependencies = ["opencv-python", "onnxruntime-gpu", "onnx"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/kijai/ComfyUI-WanAnimatePreprocess" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "kijai" 14 | DisplayName = "ComfyUI-WanAnimatePreprocess" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'kijai' }} 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | - name: Publish Custom Node 22 | uses: Comfy-Org/publish-node-action@v1 23 | with: 24 | ## Add your own personal access token to your Github Repository secrets and reference it here. 
25 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 26 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## ComfyUI helper nodes for [Wan video 2.2 Animate preprocessing](https://github.com/Wan-Video/Wan2.2/tree/main/wan/modules/animate/preprocess) 2 | 3 | 4 | Nodes to run the ViTPose model, get face crops, and produce the keypoint list for SAM2 segmentation. 5 | 6 | Models: 7 | 8 | Place the models in `ComfyUI/models/detection` (subject to change in the future) 9 | 10 | YOLO: 11 | 12 | https://huggingface.co/Wan-AI/Wan2.2-Animate-14B/blob/main/process_checkpoint/det/yolov10m.onnx 13 | 14 | ViTPose ONNX: 15 | 16 | Use either the Large model from here: 17 | 18 | https://huggingface.co/JunkyByte/easy_ViTPose/tree/main/onnx/wholebody 19 | 20 | Or the Huge model, as in the original code; it is split into two files due to the ONNX file size limit: 21 | 22 | Both files need to be in the same directory, and the .onnx file should be selected in the model loader: 23 | 24 | `vitpose_h_wholebody_data.bin` and `vitpose_h_wholebody_model.onnx` 25 | 26 | https://huggingface.co/Kijai/vitpose_comfy/tree/main/onnx 27 | 28 | 29 | ![example](example.png) 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship.
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
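# Overview of the helpers defined below: mask bounding boxes and block-wise mask
# augmentation, face bounding-box expansion from wholebody keypoints, area-targeted
# resizing to divisor-aligned (e.g. multiple-of-64) sizes with padding, crop-to-content
# resizing, and frame index sampling for retiming a clip to a target FPS.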
2 | import os 3 | import cv2 4 | import math 5 | import random 6 | import numpy as np 7 | 8 | def get_mask_boxes(mask): 9 | y_coords, x_coords = np.nonzero(mask) 10 | x_min = x_coords.min() 11 | x_max = x_coords.max() 12 | y_min = y_coords.min() 13 | y_max = y_coords.max() 14 | bbox = np.array([x_min, y_min, x_max, y_max]).astype(np.int32) 15 | return bbox 16 | 17 | 18 | def get_aug_mask(body_mask, w_len=10, h_len=20): 19 | body_bbox = get_mask_boxes(body_mask) 20 | 21 | bbox_wh = body_bbox[2:4] - body_bbox[0:2] 22 | w_slice = np.int32(bbox_wh[0] / w_len) 23 | h_slice = np.int32(bbox_wh[1] / h_len) 24 | 25 | for each_w in range(body_bbox[0], body_bbox[2], w_slice): 26 | w_start = min(each_w, body_bbox[2]) 27 | w_end = min((each_w + w_slice), body_bbox[2]) 28 | for each_h in range(body_bbox[1], body_bbox[3], h_slice): 29 | h_start = min(each_h, body_bbox[3]) 30 | h_end = min((each_h + h_slice), body_bbox[3]) 31 | if body_mask[h_start:h_end, w_start:w_end].sum() > 0: 32 | body_mask[h_start:h_end, w_start:w_end] = 1 33 | 34 | return body_mask 35 | 36 | def get_mask_body_img(img_copy, hand_mask, k=7, iterations=1): 37 | kernel = np.ones((k, k), np.uint8) 38 | dilation = cv2.dilate(hand_mask, kernel, iterations=iterations) 39 | mask_hand_img = img_copy * (1 - dilation[:, :, None]) 40 | 41 | return mask_hand_img, dilation 42 | 43 | 44 | def get_face_bboxes(kp2ds, scale, image_shape, ratio_aug): 45 | h, w = image_shape 46 | kp2ds_face = kp2ds.copy()[23:91, :2] 47 | 48 | min_x, min_y = np.min(kp2ds_face, axis=0) 49 | max_x, max_y = np.max(kp2ds_face, axis=0) 50 | 51 | 52 | initial_width = max_x - min_x 53 | initial_height = max_y - min_y 54 | 55 | initial_area = initial_width * initial_height 56 | 57 | expanded_area = initial_area * scale 58 | 59 | new_width = np.sqrt(expanded_area * (initial_width / initial_height)) 60 | new_height = np.sqrt(expanded_area * (initial_height / initial_width)) 61 | 62 | delta_width = (new_width - initial_width) / 2 63 | delta_height = (new_height - initial_height) / 4 64 | 65 | if ratio_aug: 66 | if random.random() > 0.5: 67 | delta_width += random.uniform(0, initial_width // 10) 68 | else: 69 | delta_height += random.uniform(0, initial_height // 10) 70 | 71 | expanded_min_x = max(min_x - delta_width, 0) 72 | expanded_max_x = min(max_x + delta_width, w) 73 | expanded_min_y = max(min_y - 3 * delta_height, 0) 74 | expanded_max_y = min(max_y + delta_height, h) 75 | 76 | return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)] 77 | 78 | 79 | def calculate_new_size(orig_w, orig_h, target_area, divisor=64): 80 | 81 | target_ratio = orig_w / orig_h 82 | 83 | def check_valid(w, h): 84 | 85 | if w <= 0 or h <= 0: 86 | return False 87 | return (w * h <= target_area and 88 | w % divisor == 0 and 89 | h % divisor == 0) 90 | 91 | def get_ratio_diff(w, h): 92 | 93 | return abs(w / h - target_ratio) 94 | 95 | def round_to_64(value, round_up=False, divisor=64): 96 | 97 | if round_up: 98 | return divisor * ((value + (divisor - 1)) // divisor) 99 | return divisor * (value // divisor) 100 | 101 | possible_sizes = [] 102 | 103 | max_area_h = int(np.sqrt(target_area / target_ratio)) 104 | max_area_w = int(max_area_h * target_ratio) 105 | 106 | max_h = round_to_64(max_area_h, round_up=True, divisor=divisor) 107 | max_w = round_to_64(max_area_w, round_up=True, divisor=divisor) 108 | 109 | for h in range(divisor, max_h + divisor, divisor): 110 | ideal_w = h * target_ratio 111 | 112 | w_down = round_to_64(ideal_w) 113 | w_up = round_to_64(ideal_w, 
round_up=True) 114 | 115 | for w in [w_down, w_up]: 116 | if check_valid(w, h): 117 | possible_sizes.append((w, h, get_ratio_diff(w, h))) 118 | 119 | if not possible_sizes: 120 | raise ValueError("Cannot find a suitable size") 121 | 122 | possible_sizes.sort(key=lambda x: (-x[0] * x[1], x[2])) 123 | 124 | best_w, best_h, _ = possible_sizes[0] 125 | return int(best_w), int(best_h) 126 | 127 | 128 | def resize_by_area(image, target_area, keep_aspect_ratio=True, divisor=64, padding_color=(0, 0, 0)): 129 | h, w = image.shape[:2] 130 | try: 131 | new_w, new_h = calculate_new_size(w, h, target_area, divisor) 132 | except ValueError: 133 | aspect_ratio = w / h 134 | 135 | if keep_aspect_ratio: 136 | new_h = math.sqrt(target_area / aspect_ratio) 137 | new_w = target_area / new_h 138 | else: 139 | new_w = new_h = math.sqrt(target_area) 140 | 141 | new_w, new_h = int((new_w // divisor) * divisor), int((new_h // divisor) * divisor) 142 | 143 | interpolation = cv2.INTER_AREA if (new_w * new_h < w * h) else cv2.INTER_LINEAR 144 | 145 | resized_image = padding_resize(image, height=new_h, width=new_w, padding_color=padding_color, 146 | interpolation=interpolation) 147 | return resized_image 148 | 149 | 150 | def padding_resize(img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR): 151 | ori_height = img_ori.shape[0] 152 | ori_width = img_ori.shape[1] 153 | channel = img_ori.shape[2] 154 | 155 | img_pad = np.zeros((height, width, channel), dtype=img_ori.dtype) 156 | if channel == 1: 157 | img_pad[:, :, 0] = padding_color[0] 158 | else: 159 | img_pad[:, :, 0] = padding_color[0] 160 | img_pad[:, :, 1] = padding_color[1] 161 | img_pad[:, :, 2] = padding_color[2] 162 | 163 | if (ori_height / ori_width) > (height / width): 164 | new_width = int(height / ori_height * ori_width) 165 | img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation) 166 | padding = int((width - new_width) / 2) 167 | if len(img.shape) == 2: 168 | img = img[:, :, np.newaxis] 169 | img_pad[:, padding: padding + new_width, :] = img 170 | else: 171 | new_height = int(width / ori_width * ori_height) 172 | img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation) 173 | padding = int((height - new_height) / 2) 174 | if len(img.shape) == 2: 175 | img = img[:, :, np.newaxis] 176 | img_pad[padding: padding + new_height, :, :] = img 177 | 178 | return img_pad 179 | 180 | def resize_to_bounds(img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR, extra_padding=64, crop_target_image=None): 181 | # Find non-black pixel bounds 182 | if crop_target_image is not None: 183 | ref = crop_target_image 184 | if ref.ndim == 2: 185 | mask = ref > 0 186 | else: 187 | mask = np.any(ref != 0, axis=2) 188 | coords = np.argwhere(mask) 189 | if coords.size == 0: 190 | # All black, fallback to full image 191 | y0, x0 = 0, 0 192 | y1, x1 = img_ori.shape[0], img_ori.shape[1] 193 | else: 194 | y0, x0 = coords.min(axis=0) 195 | y1, x1 = coords.max(axis=0) + 1 196 | # Intended crop bounds with padding 197 | pad_y0 = y0 - extra_padding 198 | pad_x0 = x0 - extra_padding 199 | pad_y1 = y1 + extra_padding 200 | pad_x1 = x1 + extra_padding 201 | # Actual crop bounds clipped to image 202 | crop_y0 = max(pad_y0, 0) 203 | crop_x0 = max(pad_x0, 0) 204 | crop_y1 = min(pad_y1, img_ori.shape[0]) 205 | crop_x1 = min(pad_x1, img_ori.shape[1]) 206 | crop_img = img_ori[crop_y0:crop_y1, crop_x0:crop_x1] 207 | # Pad if needed 208 | pad_top = crop_y0 - pad_y0 209 | pad_left = crop_x0 -
pad_x0 210 | pad_bottom = pad_y1 - crop_y1 211 | pad_right = pad_x1 - crop_x1 212 | if any([pad_top, pad_left, pad_bottom, pad_right]): 213 | channel = crop_img.shape[2] if crop_img.ndim == 3 else 1 214 | crop_img = np.pad( 215 | crop_img, 216 | ((pad_top, pad_bottom), (pad_left, pad_right)) + ((0, 0),) if channel > 1 else ((pad_top, pad_bottom), (pad_left, pad_right)), 217 | mode='constant', constant_values=0 218 | ) 219 | else: 220 | if img_ori.ndim == 2: 221 | mask = img_ori > 0 222 | else: 223 | mask = np.any(img_ori != 0, axis=2) 224 | coords = np.argwhere(mask) 225 | if coords.size == 0: 226 | # All black, fallback to original 227 | crop_img = img_ori 228 | else: 229 | y0, x0 = coords.min(axis=0) 230 | y1, x1 = coords.max(axis=0) + 1 231 | pad_y0 = y0 - extra_padding 232 | pad_x0 = x0 - extra_padding 233 | pad_y1 = y1 + extra_padding 234 | pad_x1 = x1 + extra_padding 235 | crop_y0 = max(pad_y0, 0) 236 | crop_x0 = max(pad_x0, 0) 237 | crop_y1 = min(pad_y1, img_ori.shape[0]) 238 | crop_x1 = min(pad_x1, img_ori.shape[1]) 239 | crop_img = img_ori[crop_y0:crop_y1, crop_x0:crop_x1] 240 | pad_top = crop_y0 - pad_y0 241 | pad_left = crop_x0 - pad_x0 242 | pad_bottom = pad_y1 - crop_y1 243 | pad_right = pad_x1 - crop_x1 244 | if any([pad_top, pad_left, pad_bottom, pad_right]): 245 | channel = crop_img.shape[2] if crop_img.ndim == 3 else 1 246 | crop_img = np.pad( 247 | crop_img, 248 | ((pad_top, pad_bottom), (pad_left, pad_right)) + ((0, 0),) if channel > 1 else ((pad_top, pad_bottom), (pad_left, pad_right)), 249 | mode='constant', constant_values=0 250 | ) 251 | 252 | ori_height = crop_img.shape[0] 253 | ori_width = crop_img.shape[1] 254 | channel = crop_img.shape[2] if crop_img.ndim == 3 else 1 255 | 256 | img_pad = np.zeros((height, width, channel), dtype=crop_img.dtype) 257 | if channel == 1: 258 | img_pad[:, :, 0] = padding_color[0] 259 | else: 260 | for c in range(channel): 261 | img_pad[:, :, c] = padding_color[c % len(padding_color)] 262 | 263 | # Resize cropped image to fit target size, preserving aspect ratio 264 | crop_aspect = ori_width / ori_height 265 | target_aspect = width / height 266 | if crop_aspect > target_aspect: 267 | new_width = width 268 | new_height = int(width / crop_aspect) 269 | else: 270 | new_height = height 271 | new_width = int(height * crop_aspect) 272 | img = cv2.resize(crop_img, (new_width, new_height), interpolation=interpolation) 273 | if img.ndim == 2: 274 | img = img[:, :, np.newaxis] 275 | y_pad = (height - new_height) // 2 276 | x_pad = (width - new_width) // 2 277 | img_pad[y_pad:y_pad + new_height, x_pad:x_pad + new_width, :] = img 278 | 279 | return img_pad 280 | 281 | 282 | def get_frame_indices(frame_num, video_fps, clip_length, train_fps): 283 | 284 | start_frame = 0 285 | times = np.arange(0, clip_length) / train_fps 286 | frame_indices = start_frame + np.round(times * video_fps).astype(int) 287 | frame_indices = np.clip(frame_indices, 0, frame_num - 1) 288 | 289 | return frame_indices.tolist() 290 | 291 | 292 | def get_face_bboxes(kp2ds, scale, image_shape): 293 | h, w = image_shape 294 | kp2ds_face = kp2ds.copy()[1:] * (w, h) 295 | 296 | min_x, min_y = np.min(kp2ds_face, axis=0) 297 | max_x, max_y = np.max(kp2ds_face, axis=0) 298 | 299 | initial_width = max_x - min_x 300 | initial_height = max_y - min_y 301 | 302 | initial_area = initial_width * initial_height 303 | 304 | expanded_area = initial_area * scale 305 | 306 | new_width = np.sqrt(expanded_area * (initial_width / initial_height)) 307 | new_height = np.sqrt(expanded_area * 
(initial_height / initial_width)) 308 | 309 | delta_width = (new_width - initial_width) / 2 310 | delta_height = (new_height - initial_height) / 4 311 | 312 | expanded_min_x = max(min_x - delta_width, 0) 313 | expanded_max_x = min(max_x + delta_width, w) 314 | expanded_min_y = max(min_y - 3 * delta_height, 0) 315 | expanded_max_y = min(max_y + delta_height, h) 316 | 317 | return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)] -------------------------------------------------------------------------------- /models/onnx_models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import onnxruntime 7 | 8 | from ..pose_utils.pose2d_utils import box_convert_simple, keypoints_from_heatmaps 9 | 10 | class SimpleOnnxInference(object): 11 | def __init__(self, checkpoint, device='CUDAExecutionProvider', **kwargs): 12 | # Store initialization parameters for potential reinit 13 | self.checkpoint = checkpoint 14 | self.init_kwargs = kwargs 15 | provider = [device, 'CPUExecutionProvider'] if device == 'CUDAExecutionProvider' else [device] 16 | 17 | self.provider = provider 18 | self.session = onnxruntime.InferenceSession(checkpoint, providers=provider) 19 | self.input_name = self.session.get_inputs()[0].name 20 | self.output_name = self.session.get_outputs()[0].name 21 | self.input_resolution = self.session.get_inputs()[0].shape[2:] 22 | self.input_resolution = np.array(self.input_resolution) 23 | 24 | def __call__(self, *args, **kwargs): 25 | return self.forward(*args, **kwargs) 26 | 27 | def get_output_names(self): 28 | output_names = [] 29 | for node in self.session.get_outputs(): 30 | output_names.append(node.name) 31 | return output_names 32 | 33 | def cleanup(self): 34 | if hasattr(self, 'session') and self.session is not None: 35 | # Close the ONNX Runtime session 36 | del self.session 37 | self.session = None 38 | 39 | def reinit(self, provider=None): 40 | # Use provided provider or fall back to original provider 41 | if provider is not None: 42 | self.provider = provider 43 | 44 | if self.session is None: 45 | checkpoint = self.checkpoint 46 | self.session = onnxruntime.InferenceSession(checkpoint, providers=self.provider) 47 | self.input_name = self.session.get_inputs()[0].name 48 | self.output_name = self.session.get_outputs()[0].name 49 | self.input_resolution = self.session.get_inputs()[0].shape[2:] 50 | self.input_resolution = np.array(self.input_resolution) 51 | 52 | class Yolo(SimpleOnnxInference): 53 | def __init__(self, checkpoint, device='cuda', threshold_conf=0.05, threshold_multi_persons=0.1, input_resolution=(640, 640), threshold_iou=0.5, threshold_bbox_shape_ratio=0.4, cat_id=[1], select_type='max', strict=True, sorted_func=None, **kwargs): 54 | super(Yolo, self).__init__(checkpoint, device=device, **kwargs) 55 | 56 | model_inputs = self.session.get_inputs() 57 | input_shape = model_inputs[0].shape 58 | 59 | self.input_width = 640 60 | self.input_height = 640 61 | 62 | self.threshold_multi_persons = threshold_multi_persons 63 | self.threshold_conf = threshold_conf 64 | self.threshold_iou = threshold_iou 65 | self.threshold_bbox_shape_ratio = threshold_bbox_shape_ratio 66 | self.input_resolution = input_resolution 67 | self.cat_id = cat_id 68 | self.select_type = select_type 69 | self.strict = strict 70 | self.sorted_func = sorted_func 71 | 72 | 73 | 74 | def postprocess(self, 
output, shape_raw, cat_id=[1]): 75 | """ 76 | Performs post-processing on the model's output to extract bounding boxes, scores, and class IDs. 77 | 78 | Args: 79 | input_image (numpy.ndarray): The input image. 80 | output (numpy.ndarray): The output of the model. 81 | 82 | Returns: 83 | numpy.ndarray: The input image with detections drawn on it. 84 | """ 85 | # Transpose and squeeze the output to match the expected shape 86 | 87 | outputs = np.squeeze(output) 88 | if len(outputs.shape) == 1: 89 | outputs = outputs[None] 90 | if output.shape[-1] != 6 and output.shape[1] == 84: 91 | outputs = np.transpose(outputs) 92 | 93 | # Get the number of rows in the outputs array 94 | rows = outputs.shape[0] 95 | 96 | # Calculate the scaling factors for the bounding box coordinates 97 | x_factor = shape_raw[1] / self.input_width 98 | y_factor = shape_raw[0] / self.input_height 99 | 100 | # Lists to store the bounding boxes, scores, and class IDs of the detections 101 | boxes = [] 102 | scores = [] 103 | class_ids = [] 104 | 105 | if outputs.shape[-1] == 6: 106 | max_scores = outputs[:, 4] 107 | classid = outputs[:, -1] 108 | 109 | threshold_conf_masks = max_scores >= self.threshold_conf 110 | classid_masks = classid[threshold_conf_masks] != 3.14159 111 | 112 | max_scores = max_scores[threshold_conf_masks][classid_masks] 113 | classid = classid[threshold_conf_masks][classid_masks] 114 | 115 | boxes = outputs[:, :4][threshold_conf_masks][classid_masks] 116 | boxes[:, [0, 2]] *= x_factor 117 | boxes[:, [1, 3]] *= y_factor 118 | boxes[:, 2] = boxes[:, 2] - boxes[:, 0] 119 | boxes[:, 3] = boxes[:, 3] - boxes[:, 1] 120 | boxes = boxes.astype(np.int32) 121 | 122 | else: 123 | classes_scores = outputs[:, 4:] 124 | max_scores = np.amax(classes_scores, -1) 125 | threshold_conf_masks = max_scores >= self.threshold_conf 126 | 127 | classid = np.argmax(classes_scores[threshold_conf_masks], -1) 128 | 129 | classid_masks = classid!=3.14159 130 | 131 | classes_scores = classes_scores[threshold_conf_masks][classid_masks] 132 | max_scores = max_scores[threshold_conf_masks][classid_masks] 133 | classid = classid[classid_masks] 134 | 135 | xywh = outputs[:, :4][threshold_conf_masks][classid_masks] 136 | 137 | x = xywh[:, 0:1] 138 | y = xywh[:, 1:2] 139 | w = xywh[:, 2:3] 140 | h = xywh[:, 3:4] 141 | 142 | left = ((x - w / 2) * x_factor) 143 | top = ((y - h / 2) * y_factor) 144 | width = (w * x_factor) 145 | height = (h * y_factor) 146 | boxes = np.concatenate([left, top, width, height], axis=-1).astype(np.int32) 147 | 148 | boxes = boxes.tolist() 149 | scores = max_scores.tolist() 150 | class_ids = classid.tolist() 151 | 152 | # Apply non-maximum suppression to filter out overlapping bounding boxes 153 | indices = cv2.dnn.NMSBoxes(boxes, scores, self.threshold_conf, self.threshold_iou) 154 | # Iterate over the selected indices after non-maximum suppression 155 | 156 | results = [] 157 | for i in indices: 158 | # Get the box, score, and class ID corresponding to the index 159 | box = box_convert_simple(boxes[i], 'xywh2xyxy') 160 | score = scores[i] 161 | class_id = class_ids[i] 162 | results.append(box + [score] + [class_id]) 163 | # # Draw the detection on the input image 164 | 165 | # Return the modified input image 166 | return np.array(results) 167 | 168 | 169 | def process_results(self, results, shape_raw, cat_id=[1], single_person=True): 170 | if isinstance(results, tuple): 171 | det_results = results[0] 172 | else: 173 | det_results = results 174 | 175 | person_results = [] 176 | person_count = 0 177 | if 
len(results): 178 | max_idx = -1 179 | max_bbox_size = shape_raw[0] * shape_raw[1] * -10 180 | max_bbox_shape = -1 181 | 182 | bboxes = [] 183 | idx_list = [] 184 | for i in range(results.shape[0]): 185 | bbox = results[i] 186 | if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf): 187 | idx_list.append(i) 188 | bbox_shape = max((bbox[2] - bbox[0]), ((bbox[3] - bbox[1]))) 189 | if bbox_shape > max_bbox_shape: 190 | max_bbox_shape = bbox_shape 191 | 192 | results = results[idx_list] 193 | 194 | for i in range(results.shape[0]): 195 | bbox = results[i] 196 | bboxes.append(bbox) 197 | if self.select_type == 'max': 198 | bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1])) 199 | elif self.select_type == 'center': 200 | bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1]/2)) * -1 201 | bbox_shape = max((bbox[2] - bbox[0]), ((bbox[3] - bbox[1]))) 202 | if bbox_size > max_bbox_size: 203 | if (self.strict or max_idx != -1) and bbox_shape < max_bbox_shape * self.threshold_bbox_shape_ratio: 204 | continue 205 | max_bbox_size = bbox_size 206 | max_bbox_shape = bbox_shape 207 | max_idx = i 208 | 209 | if self.sorted_func is not None and len(bboxes) > 0: 210 | max_idx = self.sorted_func(bboxes, shape_raw) 211 | bbox = bboxes[max_idx] 212 | if self.select_type == 'max': 213 | max_bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1])) 214 | elif self.select_type == 'center': 215 | max_bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1]/2)) * -1 216 | 217 | if max_idx != -1: 218 | person_count = 1 219 | 220 | if max_idx != -1: 221 | person = {} 222 | person['bbox'] = results[max_idx, :5] 223 | person['track_id'] = int(0) 224 | person_results.append(person) 225 | 226 | for i in range(results.shape[0]): 227 | bbox = results[i] 228 | if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf): 229 | if self.select_type == 'max': 230 | bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1])) 231 | elif self.select_type == 'center': 232 | bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1]/2)) * -1 233 | if i != max_idx and bbox_size > max_bbox_size * self.threshold_multi_persons and bbox_size < max_bbox_size: 234 | person_count += 1 235 | if not single_person: 236 | person = {} 237 | person['bbox'] = results[i, :5] 238 | person['track_id'] = int(person_count - 1) 239 | person_results.append(person) 240 | return person_results 241 | else: 242 | return None 243 | 244 | 245 | def postprocess_threading(self, outputs, shape_raw, person_results, i, single_person=True, **kwargs): 246 | result = self.postprocess(outputs[i], shape_raw[i], cat_id=self.cat_id) 247 | result = self.process_results(result, shape_raw[i], cat_id=self.cat_id, single_person=single_person) 248 | if result is not None and len(result) != 0: 249 | person_results[i] = result 250 | 251 | 252 | def forward(self, img, shape_raw, **kwargs): 253 | """ 254 | Performs inference using an ONNX model and returns the output image with drawn detections. 255 | 256 | Returns: 257 | output_img: The output image with drawn detections. 
258 | """ 259 | if isinstance(img, torch.Tensor): 260 | img = img.cpu().numpy() 261 | shape_raw = shape_raw.cpu().numpy() 262 | 263 | outputs = self.session.run(None, {self.session.get_inputs()[0].name: img})[0] 264 | person_results = [[{'bbox': np.array([0., 0., 1.*shape_raw[i][1], 1.*shape_raw[i][0], -1]), 'track_id': -1}] for i in range(len(outputs))] 265 | 266 | for i in range(len(outputs)): 267 | self.postprocess_threading(outputs, shape_raw, person_results, i, **kwargs) 268 | return person_results 269 | 270 | 271 | class ViTPose(SimpleOnnxInference): 272 | def __init__(self, checkpoint, device='cuda', **kwargs): 273 | super(ViTPose, self).__init__(checkpoint, device=device) 274 | 275 | def forward(self, img, center, scale, **kwargs): 276 | heatmaps = self.session.run([], {self.session.get_inputs()[0].name: img})[0] 277 | points, prob = keypoints_from_heatmaps(heatmaps=heatmaps, 278 | center=center, 279 | scale=scale*200, 280 | unbiased=True, 281 | use_udp=False) 282 | return np.concatenate([points, prob], axis=2) 283 | -------------------------------------------------------------------------------- /onetoall/utils.py: -------------------------------------------------------------------------------- 1 | # https://github.com/ssj9596/One-to-All-Animation 2 | 3 | import cv2 4 | import numpy as np 5 | import math 6 | import copy 7 | 8 | eps = 0.01 9 | 10 | DROP_FACE_POINTS = {0, 14, 15, 16, 17} 11 | DROP_UPPER_POINTS = {0, 14, 15, 16, 17, 2, 1, 5, 3, 6} 12 | DROP_LOWER_POINTS = {8, 9, 10, 11, 12, 13} 13 | 14 | def scale_and_translate_pose(tgt_pose, ref_pose, conf_th=0.9, return_ratio=False): 15 | aligned_pose = copy.deepcopy(tgt_pose) 16 | th = 1e-6 17 | ref_kpt = ref_pose['bodies']['candidate'].astype(np.float32) 18 | tgt_kpt = aligned_pose['bodies']['candidate'].astype(np.float32) 19 | 20 | ref_sc = ref_pose['bodies'].get('score', np.ones(ref_kpt.shape[0])).astype(np.float32).reshape(-1) 21 | tgt_sc = tgt_pose['bodies'].get('score', np.ones(tgt_kpt.shape[0])).astype(np.float32).reshape(-1) 22 | 23 | ref_shoulder_valid = (ref_sc[2] >= conf_th) and (ref_sc[5] >= conf_th) 24 | tgt_shoulder_valid = (tgt_sc[2] >= conf_th) and (tgt_sc[5] >= conf_th) 25 | shoulder_ok = ref_shoulder_valid and tgt_shoulder_valid 26 | 27 | ref_hip_valid = (ref_sc[8] >= conf_th) and (ref_sc[11] >= conf_th) 28 | tgt_hip_valid = (tgt_sc[8] >= conf_th) and (tgt_sc[11] >= conf_th) 29 | hip_ok = ref_hip_valid and tgt_hip_valid 30 | 31 | if shoulder_ok and hip_ok: 32 | ref_shoulder_w = abs(ref_kpt[5, 0] - ref_kpt[2, 0]) 33 | tgt_shoulder_w = abs(tgt_kpt[5, 0] - tgt_kpt[2, 0]) 34 | x_ratio = ref_shoulder_w / tgt_shoulder_w if tgt_shoulder_w > th else 1.0 35 | 36 | ref_torso_h = abs(np.mean(ref_kpt[[8, 11], 1]) - np.mean(ref_kpt[[2, 5], 1])) 37 | tgt_torso_h = abs(np.mean(tgt_kpt[[8, 11], 1]) - np.mean(tgt_kpt[[2, 5], 1])) 38 | y_ratio = ref_torso_h / tgt_torso_h if tgt_torso_h > th else 1.0 39 | scale_ratio = (x_ratio + y_ratio) / 2 40 | 41 | elif shoulder_ok: 42 | ref_sh_dist = np.linalg.norm(ref_kpt[2] - ref_kpt[5]) 43 | tgt_sh_dist = np.linalg.norm(tgt_kpt[2] - tgt_kpt[5]) 44 | scale_ratio = ref_sh_dist / tgt_sh_dist if tgt_sh_dist > th else 1.0 45 | 46 | else: 47 | ref_ear_dist = np.linalg.norm(ref_kpt[16] - ref_kpt[17]) 48 | tgt_ear_dist = np.linalg.norm(tgt_kpt[16] - tgt_kpt[17]) 49 | scale_ratio = ref_ear_dist / tgt_ear_dist if tgt_ear_dist > th else 1.0 50 | 51 | if return_ratio: 52 | return scale_ratio 53 | 54 | # scale 55 | anchor_idx = 1 56 | anchor_pt_before_scale = tgt_kpt[anchor_idx].copy() 57 | def 
scale(arr): 58 | if arr is not None and arr.size > 0: 59 | arr[..., 0] = anchor_pt_before_scale[0] + (arr[..., 0] - anchor_pt_before_scale[0]) * scale_ratio 60 | arr[..., 1] = anchor_pt_before_scale[1] + (arr[..., 1] - anchor_pt_before_scale[1]) * scale_ratio 61 | scale(tgt_kpt) 62 | scale(aligned_pose.get('faces')) 63 | scale(aligned_pose.get('hands')) 64 | 65 | # offset 66 | offset = ref_kpt[anchor_idx] - tgt_kpt[anchor_idx] 67 | def translate(arr): 68 | if arr is not None and arr.size > 0: 69 | arr += offset 70 | translate(tgt_kpt) 71 | translate(aligned_pose.get('faces')) 72 | translate(aligned_pose.get('hands')) 73 | aligned_pose['bodies']['candidate'] = tgt_kpt 74 | 75 | return aligned_pose, shoulder_ok, hip_ok 76 | 77 | 78 | def warp_ref_to_pose(tgt_img, 79 | ref_pose: dict, #driven pose 80 | tgt_pose: dict, 81 | bg_val=(0, 0, 0), 82 | conf_th=0.9, 83 | align_center=False): 84 | 85 | H, W = tgt_img.shape[:2] 86 | img_tgt_pose = draw_pose_aligned(tgt_pose, H, W, without_face=True) 87 | 88 | tgt_kpt = tgt_pose['bodies']['candidate'].astype(np.float32) 89 | ref_kpt = ref_pose['bodies']['candidate'].astype(np.float32) 90 | 91 | scale_ratio = scale_and_translate_pose(tgt_pose, ref_pose, conf_th=conf_th, return_ratio=True) 92 | 93 | anchor_idx = 1 94 | x0 = tgt_kpt[anchor_idx][0] * W 95 | y0 = tgt_kpt[anchor_idx][1] * H 96 | 97 | ref_x = ref_kpt[anchor_idx][0] * W if not align_center else W/2 98 | ref_y = ref_kpt[anchor_idx][1] * H 99 | 100 | dx = ref_x - x0 101 | dy = ref_y - y0 102 | 103 | # Affine transformation matrix 104 | M = np.array([[scale_ratio, 0, (1-scale_ratio)*x0 + dx], 105 | [0, scale_ratio, (1-scale_ratio)*y0 + dy]], 106 | dtype=np.float32) 107 | img_warp = cv2.warpAffine(tgt_img, M, (W, H), 108 | flags=cv2.INTER_LINEAR, 109 | borderValue=bg_val) 110 | img_tgt_pose_warp = cv2.warpAffine(img_tgt_pose, M, (W, H), 111 | flags=cv2.INTER_LINEAR, 112 | borderValue=bg_val) 113 | zeros = np.zeros((H, W), dtype=np.uint8) 114 | mask_warp = cv2.warpAffine(zeros, M, (W, H), 115 | flags=cv2.INTER_NEAREST, 116 | borderValue=255) 117 | return img_warp, img_tgt_pose_warp, mask_warp 118 | 119 | def hsv_to_rgb(hsv): 120 | hsv = np.asarray(hsv, dtype=np.float32) 121 | in_shape = hsv.shape 122 | hsv = hsv.reshape(-1, 3) 123 | 124 | h, s, v = hsv[:, 0], hsv[:, 1], hsv[:, 2] 125 | 126 | i = (h * 6.0).astype(int) 127 | f = (h * 6.0) - i 128 | i = i % 6 129 | 130 | p = v * (1.0 - s) 131 | q = v * (1.0 - s * f) 132 | t = v * (1.0 - s * (1.0 - f)) 133 | 134 | rgb = np.zeros_like(hsv) 135 | rgb[i == 0] = np.stack([v[i == 0], t[i == 0], p[i == 0]], axis=1) 136 | rgb[i == 1] = np.stack([q[i == 1], v[i == 1], p[i == 1]], axis=1) 137 | rgb[i == 2] = np.stack([p[i == 2], v[i == 2], t[i == 2]], axis=1) 138 | rgb[i == 3] = np.stack([p[i == 3], q[i == 3], v[i == 3]], axis=1) 139 | rgb[i == 4] = np.stack([t[i == 4], p[i == 4], v[i == 4]], axis=1) 140 | rgb[i == 5] = np.stack([v[i == 5], p[i == 5], q[i == 5]], axis=1) 141 | 142 | gray_mask = s == 0 143 | rgb[gray_mask] = np.stack([v[gray_mask]] * 3, axis=1) 144 | 145 | return (rgb.reshape(in_shape) * 255) 146 | 147 | def get_stickwidth(W, H, stickwidth=4): 148 | if max(W, H) < 512: 149 | ratio = 1.0 150 | elif max(W, H) < 1080: 151 | ratio = 1.5 152 | elif max(W, H) < 2160: 153 | ratio = 2.0 154 | elif max(W, H) < 3240: 155 | ratio = 2.5 156 | elif max(W, H) < 4320: 157 | ratio = 3.5 158 | elif max(W, H) < 5400: 159 | ratio = 4.5 160 | else: 161 | ratio = 4.0 162 | return int(stickwidth * ratio) 163 | 164 | 165 | def alpha_blend_color(color, alpha): 166 | 
return [int(c * alpha) for c in color] 167 | 168 | 169 | def draw_bodypose_aligned(canvas, candidate, subset, score, plan=None): 170 | H, W, C = canvas.shape 171 | candidate = np.array(candidate) 172 | subset = np.array(subset) 173 | stickwidth = get_stickwidth(W, H, stickwidth=3) 174 | 175 | limbSeq = [ 176 | [2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], 177 | [2, 9], [9, 10], [10, 11], [2, 12], [12, 13], [13, 14], 178 | [2, 1], [1, 15], [15, 17], [1, 16], [16, 18], [3, 17], [6, 18]] 179 | colors = [ 180 | [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], 181 | [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], 182 | [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 183 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 184 | 185 | HIDE_JOINTS = set() 186 | stretch_limb_idx = None 187 | stretch_scale = None 188 | if plan: 189 | if plan["mode"] == "drop_point": 190 | HIDE_JOINTS.add(plan["point_idx"]) 191 | elif plan["mode"] == "drop_region": 192 | HIDE_JOINTS |= set(plan["points"]) 193 | elif plan["mode"] == "stretch_limb": 194 | stretch_limb_idx = plan["limb_idx"] 195 | stretch_scale = plan["stretch_scale"] 196 | 197 | hide_joint = np.zeros_like(subset, dtype=bool) 198 | 199 | for i in range(17): 200 | for n in range(len(subset)): 201 | idx_pair = limbSeq[i] 202 | 203 | if any(j in HIDE_JOINTS for j in idx_pair): 204 | continue 205 | 206 | index = subset[n][np.array(idx_pair) - 1] 207 | conf = score[n][np.array(idx_pair) - 1] 208 | if -1 in index: 209 | continue 210 | # color lighten 211 | alpha = max(conf[0] * conf[1], 0) if conf[0]>0 and conf[1]>0 else 0.35 212 | if conf[0] == 0 or conf[1] == 0: 213 | alpha = 0 214 | 215 | Y = candidate[index.astype(int), 0] * float(W) 216 | X = candidate[index.astype(int), 1] * float(H) 217 | 218 | if stretch_limb_idx == i: 219 | vec_x = X[1] - X[0] 220 | vec_y = Y[1] - Y[0] 221 | X[1] = X[0] + vec_x * stretch_scale 222 | Y[1] = Y[0] + vec_y * stretch_scale 223 | hide_joint[n, idx_pair[1]-1] = True 224 | 225 | mX = np.mean(X) 226 | mY = np.mean(Y) 227 | length = ((X[0]-X[1])**2 + (Y[0]-Y[1])**2) ** 0.5 228 | angle = math.degrees(math.atan2(X[0]-X[1], Y[0]-Y[1])) 229 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), 230 | (int(length/2), stickwidth), int(angle), 0, 360, 1) 231 | cv2.fillConvexPoly(canvas, polygon, alpha_blend_color(colors[i], alpha)) 232 | 233 | canvas = (canvas * 0.6).astype(np.uint8) 234 | 235 | for i in range(18): 236 | if i in HIDE_JOINTS: 237 | continue 238 | for n in range(len(subset)): 239 | if hide_joint[n, i]: 240 | continue 241 | index = int(subset[n][i]) 242 | if index == -1: 243 | continue 244 | x, y = candidate[index][0:2] 245 | conf = score[n][i] 246 | 247 | alpha = 0 if conf==-2 else max(conf, 0) 248 | x = int(x * W) 249 | y = int(y * H) 250 | cv2.circle(canvas, (x, y), stickwidth, alpha_blend_color(colors[i], alpha), thickness=-1) 251 | 252 | return canvas 253 | 254 | 255 | def draw_handpose_aligned(canvas, all_hand_peaks, all_hand_scores, draw_th=0.3): 256 | H, W, C = canvas.shape 257 | stickwidth = get_stickwidth(W, H, stickwidth=2) 258 | line_thickness = get_stickwidth(W, H, stickwidth=2) 259 | 260 | edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ 261 | [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] 262 | 263 | for peaks, scores in zip(all_hand_peaks, all_hand_scores): 264 | for ie, e in enumerate(edges): 265 | if scores[e[0]] < draw_th or scores[e[1]] < draw_th: 
266 | continue 267 | x1, y1 = peaks[e[0]] 268 | x2, y2 = peaks[e[1]] 269 | x1 = int(x1 * W) 270 | y1 = int(y1 * H) 271 | x2 = int(x2 * W) 272 | y2 = int(y2 * H) 273 | 274 | score = int(scores[e[0]] * scores[e[1]] * 255) 275 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps: 276 | color = hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]).flatten() 277 | color = tuple(int(c * score / 255) for c in color) 278 | cv2.line(canvas, (x1, y1), (x2, y2), color, thickness=line_thickness) 279 | 280 | for i, keyponit in enumerate(peaks): 281 | if scores[i] < draw_th: 282 | continue 283 | 284 | x, y = keyponit 285 | x = int(x * W) 286 | y = int(y * H) 287 | score = int(scores[i] * 255) 288 | if x > eps and y > eps: 289 | cv2.circle(canvas, (x, y), stickwidth, (0, 0, score), thickness=-1) 290 | return canvas 291 | 292 | 293 | def draw_facepose_aligned(canvas, all_lmks, all_scores, draw_th=0.3,face_change=False): 294 | H, W, C = canvas.shape 295 | stickwidth = get_stickwidth(W, H, stickwidth=2) 296 | SKIP_IDX = set(range(0, 17)) 297 | SKIP_IDX |= set(range(27, 36)) 298 | 299 | for lmks, scores in zip(all_lmks, all_scores): 300 | for idx, (lmk, score) in enumerate(zip(lmks, scores)): 301 | # skip chin 302 | if idx in SKIP_IDX: 303 | continue 304 | if score < draw_th: 305 | continue 306 | x, y = lmk 307 | x = int(x * W) 308 | y = int(y * H) 309 | conf = int(score * 255) 310 | # color lighten 311 | if face_change: 312 | conf = int(conf * 0.35) 313 | 314 | if x > eps and y > eps: 315 | cv2.circle(canvas, (x, y), stickwidth, (conf, conf, conf), thickness=-1) 316 | return canvas 317 | 318 | 319 | def draw_pose_aligned(pose, H, W, ref_w=2160, without_face=False, pose_plan=None, head_strength="full", face_change=False): 320 | bodies = pose['bodies'] 321 | faces = pose['faces'] 322 | hands = pose['hands'] 323 | candidate = bodies['candidate'] 324 | subset = bodies['subset'] 325 | body_score = bodies['score'].copy() 326 | # control color 327 | if head_strength == "weak": 328 | target_joints = [0, 14, 15, 16, 17] 329 | body_score[:, target_joints] = -2 330 | elif head_strength == "none": 331 | target_joints = [0, 14, 15, 16, 17] 332 | body_score[:, target_joints] = 0 333 | 334 | sz = min(H, W) 335 | sr = (ref_w / sz) if sz != ref_w else 1 336 | canvas = np.zeros(shape=(int(H*sr), int(W*sr), 3), dtype=np.uint8) 337 | 338 | canvas = draw_bodypose_aligned(canvas, candidate, subset, 339 | score=body_score, 340 | plan=pose_plan,) 341 | 342 | canvas = draw_handpose_aligned(canvas, hands, pose['hands_score']) 343 | 344 | if not without_face: 345 | canvas = draw_facepose_aligned(canvas, faces, pose['faces_score'],face_change=face_change) 346 | 347 | return cv2.resize(canvas, (W, H)) 348 | -------------------------------------------------------------------------------- /onetoall/infer_function.py: -------------------------------------------------------------------------------- 1 | # https://github.com/ssj9596/One-to-All-Animation 2 | 3 | import numpy as np 4 | import copy 5 | from ..retarget_pose import get_retarget_pose 6 | 7 | L_EYE_IDXS = list(range(36, 42)) 8 | R_EYE_IDXS = list(range(42, 48)) 9 | NOSE_TIP = 30 10 | MOUTH_L = 48 11 | MOUTH_R = 54 12 | JAW_LINE = list(range(0, 17)) 13 | 14 | 15 | # ===========================Convert wanpose format into our dwpose-like format====================== 16 | def aaposemeta_to_dwpose(meta): 17 | candidate_body = meta['keypoints_body'][:-2][:, :2] 18 | score_body = meta['keypoints_body'][:-2][:, 2] 19 | subset_body = np.arange(len(candidate_body), dtype=float) 20 | 
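    # keypoints with non-positive confidence are flagged as -1 in the subset array,
    # mirroring the DWPose/OpenPose convention for missing joints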
subset_body[score_body <= 0] = -1 21 | bodies = { 22 | "candidate": candidate_body, 23 | "subset": np.expand_dims(subset_body, axis=0), # shape (1, N) 24 | "score": np.expand_dims(score_body, axis=0) # shape (1, N) 25 | } 26 | hands_coords = np.stack([ 27 | meta['keypoints_right_hand'][:, :2], 28 | meta['keypoints_left_hand'][:, :2] 29 | ]) 30 | hands_score = np.stack([ 31 | meta['keypoints_right_hand'][:, 2], 32 | meta['keypoints_left_hand'][:, 2] 33 | ]) 34 | faces_coords = np.expand_dims(meta['keypoints_face'][1:][:, :2], axis=0) 35 | faces_score = np.expand_dims(meta['keypoints_face'][1:][:, 2], axis=0) 36 | dwpose_format = { 37 | "bodies": bodies, 38 | "hands": hands_coords, 39 | "hands_score": hands_score, 40 | "faces": faces_coords, 41 | "faces_score": faces_score 42 | } 43 | return dwpose_format 44 | 45 | def aaposemeta_obj_to_dwpose(pose_meta): 46 | """ 47 | Convert an AAPoseMeta object into a dwpose-like data structure 48 | Restore coordinates to relative coordinates (divide by width, height) 49 | Only handle None -> fill with zeros 50 | """ 51 | w = pose_meta.width 52 | h = pose_meta.height 53 | 54 | # If None, fill with all zeros 55 | def safe(arr, like_shape): 56 | if arr is None: 57 | return np.zeros(like_shape, dtype=np.float32) 58 | arr_np = np.array(arr, dtype=np.float32) 59 | arr_np = np.nan_to_num(arr_np, nan=0.0) 60 | return arr_np 61 | # body 62 | kps_body = safe(pose_meta.kps_body, (pose_meta.kps_body_p.shape[0], 2)) 63 | candidate_body = kps_body / np.array([w, h]) 64 | score_body = safe(pose_meta.kps_body_p, (candidate_body.shape[0],)) 65 | subset_body = np.arange(len(candidate_body), dtype=float) 66 | subset_body[score_body <= 0] = -1 67 | bodies = { 68 | "candidate": candidate_body, 69 | "subset": np.expand_dims(subset_body, axis=0), 70 | "score": np.expand_dims(score_body, axis=0) 71 | } 72 | 73 | # hands 74 | kps_rhand = safe(pose_meta.kps_rhand, (pose_meta.kps_rhand_p.shape[0], 2)) 75 | kps_lhand = safe(pose_meta.kps_lhand, (pose_meta.kps_lhand_p.shape[0], 2)) 76 | hands_coords = np.stack([ 77 | kps_rhand / np.array([w, h]), 78 | kps_lhand / np.array([w, h]) 79 | ]) 80 | hands_score = np.stack([ 81 | safe(pose_meta.kps_rhand_p, (kps_rhand.shape[0],)), 82 | safe(pose_meta.kps_lhand_p, (kps_lhand.shape[0],)) 83 | ]) 84 | 85 | dwpose_format = { 86 | "bodies": bodies, 87 | "hands": hands_coords, 88 | "hands_score": hands_score, 89 | "faces": None, 90 | "faces_score": None 91 | } 92 | return dwpose_format 93 | 94 | # ===============================Face Rough alignment====================== 95 | 96 | def _to_68x2(arr): 97 | if arr.shape == (1, 68, 2): 98 | def to_orig(x): 99 | x = np.asarray(x, dtype=np.float64) 100 | if x.shape != (68, 2): 101 | raise ValueError("to_orig expects (68,2)") 102 | return x[np.newaxis, :, :] 103 | return arr[0].astype(np.float64), to_orig 104 | if arr.shape == (68, 2): 105 | def to_orig(x): 106 | x = np.asarray(x, dtype=np.float64) 107 | if x.shape != (68, 2): 108 | raise ValueError("to_orig expects (68,2)") 109 | return x 110 | return arr.astype(np.float64), to_orig 111 | if arr.shape == (2, 68): 112 | def to_orig(x): 113 | x = np.asarray(x, dtype=np.float64) 114 | if x.shape != (68, 2): 115 | raise ValueError("to_orig expects (68,2)") 116 | return x.T 117 | return arr.T.astype(np.float64), to_orig 118 | raise ValueError(f"faces shape {arr.shape} not supported; expected (1,68,2) or (68,2) or (2,68)") 119 | 120 | def _eye_center(face68, idxs): 121 | return face68[idxs].mean(axis=0) 122 | 123 | def _anchors(face68): 124 | le = 
_eye_center(face68, L_EYE_IDXS) 125 | re = _eye_center(face68, R_EYE_IDXS) 126 | nose = face68[NOSE_TIP] 127 | lm = face68[MOUTH_L] 128 | rm = face68[MOUTH_R] 129 | if re[0] < le[0]: 130 | le, re = re, le 131 | return np.stack([le, re, nose, lm, rm], axis=0) 132 | 133 | def _face_scale_only(src68, ref68, target_nose_pos, alpha=1.0, anchor_pairs=[[36, 45], [27, 8]]): 134 | """ 135 | Rough alignment - adjust the shape of the source face according to the proportions of the reference, and align the nose tip to target_nose_pos. 136 | anchor_pairs: 137 | - [36, 45] for x 138 | - [27, 8] for y 139 | """ 140 | src = np.asarray(src68, dtype=np.float64) 141 | ref = np.asarray(ref68, dtype=np.float64) 142 | 143 | center = _anchors(src).mean(axis=0) 144 | src_centered = src - center 145 | 146 | src_w = np.linalg.norm(src[anchor_pairs[0][0]] - src[anchor_pairs[0][1]]) 147 | ref_w = np.linalg.norm(ref[anchor_pairs[0][0]] - ref[anchor_pairs[0][1]]) 148 | 149 | src_h = np.linalg.norm(src[anchor_pairs[1][0]] - src[anchor_pairs[1][1]]) 150 | ref_h = np.linalg.norm(ref[anchor_pairs[1][0]] - ref[anchor_pairs[1][1]]) 151 | 152 | scale_x = ref_w / src_w if src_w > 1e-6 else 1.0 153 | scale_y = ref_h / src_h if src_h > 1e-6 else 1.0 154 | 155 | scaled_local = src_centered.copy() 156 | scaled_local[:, 0] *= (1 - alpha) + scale_x * alpha 157 | scaled_local[:, 1] *= (1 - alpha) + scale_y * alpha 158 | scaled_global = scaled_local + center 159 | 160 | nose_idx = NOSE_TIP 161 | current_nose = scaled_global[nose_idx] 162 | offset = target_nose_pos - current_nose 163 | scaled_global += offset 164 | 165 | return scaled_global 166 | 167 | # ===============================Reference Img Pre-Process====================== 168 | 169 | 170 | def scale_and_translate_pose(tgt_pose, ref_pose, conf_th=0.9, return_ratio=False): 171 | aligned_pose = copy.deepcopy(tgt_pose) 172 | th = 1e-6 173 | ref_kpt = ref_pose['bodies']['candidate'].astype(np.float32) 174 | tgt_kpt = aligned_pose['bodies']['candidate'].astype(np.float32) 175 | 176 | ref_sc = ref_pose['bodies'].get('score', np.ones(ref_kpt.shape[0])).astype(np.float32).reshape(-1) 177 | tgt_sc = tgt_pose['bodies'].get('score', np.ones(tgt_kpt.shape[0])).astype(np.float32).reshape(-1) 178 | 179 | ref_shoulder_valid = (ref_sc[2] >= conf_th) and (ref_sc[5] >= conf_th) 180 | tgt_shoulder_valid = (tgt_sc[2] >= conf_th) and (tgt_sc[5] >= conf_th) 181 | shoulder_ok = ref_shoulder_valid and tgt_shoulder_valid 182 | 183 | ref_hip_valid = (ref_sc[8] >= conf_th) and (ref_sc[11] >= conf_th) 184 | tgt_hip_valid = (tgt_sc[8] >= conf_th) and (tgt_sc[11] >= conf_th) 185 | hip_ok = ref_hip_valid and tgt_hip_valid 186 | 187 | if shoulder_ok and hip_ok: 188 | ref_shoulder_w = abs(ref_kpt[5, 0] - ref_kpt[2, 0]) 189 | tgt_shoulder_w = abs(tgt_kpt[5, 0] - tgt_kpt[2, 0]) 190 | x_ratio = ref_shoulder_w / tgt_shoulder_w if tgt_shoulder_w > th else 1.0 191 | 192 | ref_torso_h = abs(np.mean(ref_kpt[[8, 11], 1]) - np.mean(ref_kpt[[2, 5], 1])) 193 | tgt_torso_h = abs(np.mean(tgt_kpt[[8, 11], 1]) - np.mean(tgt_kpt[[2, 5], 1])) 194 | y_ratio = ref_torso_h / tgt_torso_h if tgt_torso_h > th else 1.0 195 | scale_ratio = (x_ratio + y_ratio) / 2 196 | 197 | elif shoulder_ok: 198 | ref_sh_dist = np.linalg.norm(ref_kpt[2] - ref_kpt[5]) 199 | tgt_sh_dist = np.linalg.norm(tgt_kpt[2] - tgt_kpt[5]) 200 | scale_ratio = ref_sh_dist / tgt_sh_dist if tgt_sh_dist > th else 1.0 201 | 202 | else: 203 | ref_ear_dist = np.linalg.norm(ref_kpt[16] - ref_kpt[17]) 204 | tgt_ear_dist = np.linalg.norm(tgt_kpt[16] - tgt_kpt[17]) 205 | 
scale_ratio = ref_ear_dist / tgt_ear_dist if tgt_ear_dist > th else 1.0 206 | 207 | if return_ratio: 208 | return scale_ratio 209 | 210 | # scale 211 | anchor_idx = 1 212 | anchor_pt_before_scale = tgt_kpt[anchor_idx].copy() 213 | def scale(arr): 214 | if arr is not None and arr.size > 0: 215 | arr[..., 0] = anchor_pt_before_scale[0] + (arr[..., 0] - anchor_pt_before_scale[0]) * scale_ratio 216 | arr[..., 1] = anchor_pt_before_scale[1] + (arr[..., 1] - anchor_pt_before_scale[1]) * scale_ratio 217 | scale(tgt_kpt) 218 | scale(aligned_pose.get('faces')) 219 | scale(aligned_pose.get('hands')) 220 | 221 | # offset 222 | offset = ref_kpt[anchor_idx] - tgt_kpt[anchor_idx] 223 | def translate(arr): 224 | if arr is not None and arr.size > 0: 225 | arr += offset 226 | translate(tgt_kpt) 227 | translate(aligned_pose.get('faces')) 228 | translate(aligned_pose.get('hands')) 229 | aligned_pose['bodies']['candidate'] = tgt_kpt 230 | 231 | return aligned_pose, shoulder_ok, hip_ok 232 | 233 | # ===============================Align to Ref Driven Pose Retarget ====================== 234 | 235 | def align_to_reference(ref_pose_meta, tpl_pose_metas, tpl_dwposes, anchor_idx=None): 236 | # pose retarget + face rough align 237 | 238 | ref_pose_dw = aaposemeta_to_dwpose(ref_pose_meta) 239 | best_idx = anchor_idx 240 | tpl_pose_meta_best = tpl_pose_metas[best_idx] 241 | 242 | tpl_retarget_pose_metas = get_retarget_pose( 243 | tpl_pose_meta_best, 244 | ref_pose_meta, 245 | tpl_pose_metas, 246 | None, None 247 | ) 248 | 249 | retarget_dwposes = [aaposemeta_obj_to_dwpose(pm) for pm in tpl_retarget_pose_metas] 250 | 251 | if ref_pose_dw['faces'] is not None: 252 | ref68, _ = _to_68x2(ref_pose_dw['faces']) 253 | for frame_idx, (tpl_dw, rt_dw) in enumerate(zip(tpl_dwposes, retarget_dwposes)): 254 | if tpl_dw['faces'] is None: 255 | continue 256 | src68, to_orig = _to_68x2(tpl_dw['faces']) 257 | target_nose_pos = rt_dw['bodies']['candidate'][0] 258 | scaled68 = _face_scale_only(src68, ref68, target_nose_pos, alpha=1.0) 259 | rt_dw['faces'] = to_orig(scaled68) 260 | rt_dw['faces_score'] = tpl_dw['faces_score'] 261 | 262 | return retarget_dwposes 263 | 264 | # ===============================Rescale-Ref && Change part of pose(Option)====================== 265 | 266 | 267 | def compute_ratios_stepwise(ref_scores, source_scores, ref_pts, src_pts, conf_th=0.9, th=1e-6): 268 | 269 | def keypoint_valid(idx): 270 | return ref_scores[0, idx] >= conf_th and source_scores[0, idx] >= conf_th 271 | 272 | def safe_ratio(p1, p2): 273 | len_ref = np.linalg.norm(ref_pts[p1] - ref_pts[p2]) 274 | len_src = np.linalg.norm(src_pts[p1] - src_pts[p2]) 275 | if len_src > th: 276 | return len_ref / len_src 277 | else: 278 | return 1.0 279 | 280 | ratio_pairs = [ 281 | (0,1),(1,2),(1,5),(2,3),(3,4),(5,6),(6,7), 282 | (0,14),(0,15),(14,16),(15,17), 283 | (8,9),(9,10),(11,12),(12,13), 284 | (1,8),(1,11) 285 | ] 286 | ratios = {p: 1.0 for p in ratio_pairs} 287 | 288 | parent_map = { 289 | (3, 4): (2, 3), 290 | (6, 7): (5, 6), 291 | (9, 10): (8, 9), 292 | (12, 13): (11, 12) 293 | } 294 | 295 | # Group 1 — head only 296 | if all(keypoint_valid(i) for i in [0,1,14,15,16,17]): 297 | ratios[(0,1)] = safe_ratio(0,1) 298 | ratios[(0,14)] = safe_ratio(0,14) 299 | ratios[(0,15)] = safe_ratio(0,15) 300 | ratios[(14,16)]= safe_ratio(14,16) 301 | ratios[(15,17)]= safe_ratio(15,17) 302 | 303 | # Group 2 — +shoulder 304 | if all(keypoint_valid(i) for i in [0,1,2,5,14,15,16,17]): 305 | ratios[(1,2)] = safe_ratio(1,2) 306 | ratios[(1,5)] = safe_ratio(1,5) 307 | 308 
| # Group 3 — +upper arm 309 | if all(keypoint_valid(i) for i in [0,1,2,5,14,15,16,17,3,6]): 310 | ratios[(2,3)] = safe_ratio(2,3) 311 | ratios[(5,6)] = safe_ratio(5,6) 312 | ratios[(3,4)] = ratios[parent_map[(3,4)]] 313 | ratios[(6,7)] = ratios[parent_map[(6,7)]] 314 | 315 | # Group 4 — +hips 316 | if all(keypoint_valid(i) for i in [0,1,2,5,14,15,16,17,3,6,8,11]): 317 | ratios[(1,8)] = safe_ratio(1,8) 318 | ratios[(1,11)] = safe_ratio(1,11) 319 | 320 | # Group 5 — forearms use their own ratio 321 | if all(keypoint_valid(i) for i in [0,1,2,5,14,15,16,17,3,6,8,11,4,7]): 322 | ratios[(3,4)] = safe_ratio(3,4) 323 | ratios[(6,7)] = safe_ratio(6,7) 324 | 325 | # Group 6 — knees 326 | if all(keypoint_valid(i) for i in [0,1,2,5,14,15,16,17,3,6,8,11,4,7,9,12]): 327 | ratios[(8,9)] = safe_ratio(8,9) 328 | ratios[(11,12)] = safe_ratio(11,12) 329 | ratios[(9,10)] = ratios[parent_map[(9,10)]] 330 | ratios[(12,13)]= ratios[parent_map[(12,13)]] 331 | 332 | # Full body — all ratios 333 | if all(keypoint_valid(i) for i in range(18)): 334 | for p in ratio_pairs: 335 | ratios[p] = safe_ratio(*p) 336 | 337 | symmetric_pairs = [ 338 | ((1, 2), (1, 5)), # shoulders 339 | ((2, 3), (5, 6)), # upper arms 340 | ((3, 4), (6, 7)), # forearms 341 | ((8, 9), (11, 12)), # thighs 342 | ((9, 10), (12, 13)) # calves 343 | ] 344 | for left_key, right_key in symmetric_pairs: 345 | left_val = ratios.get(left_key) 346 | right_val = ratios.get(right_key) 347 | if left_val is not None and right_val is not None: 348 | avg_val = (left_val + right_val) / 2.0 349 | ratios[left_key] = avg_val 350 | ratios[right_key] = avg_val 351 | 352 | eye_pairs = [ 353 | ((13, 15), (14, 16)) 354 | ] 355 | for left_key, right_key in eye_pairs: 356 | left_val = ratios.get(left_key) 357 | right_val = ratios.get(right_key) 358 | if left_val is not None and right_val is not None: 359 | avg_val = (left_val + right_val) / 2.0 360 | ratios[left_key] = avg_val 361 | ratios[right_key] = avg_val 362 | 363 | return ratios 364 | 365 | def align_to_pose(ref_dwpose, tpl_dwposes,anchor_idx=None,conf_th=0.9,): 366 | detected_poses = copy.deepcopy(tpl_dwposes) 367 | 368 | best_pose = tpl_dwposes[anchor_idx] 369 | ref_pose_scaled, _, _ = scale_and_translate_pose(ref_dwpose, best_pose, conf_th=conf_th) 370 | 371 | ref_candidate = ref_pose_scaled['bodies']['candidate'].astype(np.float32) 372 | ref_scores = ref_pose_scaled['bodies']['score'].astype(np.float32) 373 | 374 | source_candidate = best_pose['bodies']['candidate'].astype(np.float32) 375 | source_scores = best_pose['bodies']['score'].astype(np.float32) 376 | 377 | has_ref_face = 'faces' in ref_pose_scaled and ref_pose_scaled['faces'] is not None and ref_pose_scaled['faces'].size > 0 378 | if has_ref_face: 379 | try: 380 | ref68, _ = _to_68x2(ref_pose_scaled['faces']) 381 | except Exception as e: 382 | print("Reference face conversion failed:", e) 383 | has_ref_face = False 384 | 385 | ratios = compute_ratios_stepwise(ref_scores, source_scores, ref_candidate, source_candidate, conf_th=conf_th, th=1e-6) 386 | 387 | for pose in detected_poses: 388 | candidate = pose['bodies']['candidate'] 389 | hands = pose['hands'] 390 | 391 | # ===== Neck ===== 392 | ratio = ratios[(0, 1)] 393 | x_offset = (candidate[1][0] - candidate[0][0]) * (1. - ratio) 394 | y_offset = (candidate[1][1] - candidate[0][1]) * (1. - ratio) 395 | candidate[[0, 14, 15, 16, 17], 0] += x_offset 396 | candidate[[0, 14, 15, 16, 17], 1] += y_offset 397 | 398 | # ===== Shoulder Right ===== 399 | ratio = ratios[(1, 2)] 400 | x_offset = (candidate[1][0] - candidate[2][0]) * (1.
- ratio) 401 | y_offset = (candidate[1][1] - candidate[2][1]) * (1. - ratio) 402 | candidate[[2, 3, 4], 0] += x_offset 403 | candidate[[2, 3, 4], 1] += y_offset 404 | hands[1, :, 0] += x_offset 405 | hands[1, :, 1] += y_offset 406 | 407 | # ===== Shoulder Left ===== 408 | ratio = ratios[(1, 5)] 409 | x_offset = (candidate[1][0] - candidate[5][0]) * (1. - ratio) 410 | y_offset = (candidate[1][1] - candidate[5][1]) * (1. - ratio) 411 | candidate[[5, 6, 7], 0] += x_offset 412 | candidate[[5, 6, 7], 1] += y_offset 413 | hands[0, :, 0] += x_offset 414 | hands[0, :, 1] += y_offset 415 | 416 | # ===== Upper Arm Right ===== 417 | ratio = ratios[(2, 3)] 418 | x_offset = (candidate[2][0] - candidate[3][0]) * (1. - ratio) 419 | y_offset = (candidate[2][1] - candidate[3][1]) * (1. - ratio) 420 | candidate[[3, 4], 0] += x_offset 421 | candidate[[3, 4], 1] += y_offset 422 | hands[1, :, 0] += x_offset 423 | hands[1, :, 1] += y_offset 424 | 425 | # ===== Forearm Right ===== 426 | ratio = ratios[(3, 4)] 427 | x_offset = (candidate[3][0] - candidate[4][0]) * (1. - ratio) 428 | y_offset = (candidate[3][1] - candidate[4][1]) * (1. - ratio) 429 | candidate[4, 0] += x_offset 430 | candidate[4, 1] += y_offset 431 | hands[1, :, 0] += x_offset 432 | hands[1, :, 1] += y_offset 433 | 434 | # ===== Upper Arm Left ===== 435 | ratio = ratios[(5, 6)] 436 | x_offset = (candidate[5][0] - candidate[6][0]) * (1. - ratio) 437 | y_offset = (candidate[5][1] - candidate[6][1]) * (1. - ratio) 438 | candidate[[6, 7], 0] += x_offset 439 | candidate[[6, 7], 1] += y_offset 440 | hands[0, :, 0] += x_offset 441 | hands[0, :, 1] += y_offset 442 | 443 | # ===== Forearm Left ===== 444 | ratio = ratios[(6, 7)] 445 | x_offset = (candidate[6][0] - candidate[7][0]) * (1. - ratio) 446 | y_offset = (candidate[6][1] - candidate[7][1]) * (1. - ratio) 447 | candidate[7, 0] += x_offset 448 | candidate[7, 1] += y_offset 449 | hands[0, :, 0] += x_offset 450 | hands[0, :, 1] += y_offset 451 | 452 | # ===== Head parts ===== 453 | for (p1, p2) in [(0,14),(0,15),(14,16),(15,17)]: 454 | ratio = ratios[(p1,p2)] 455 | x_offset = (candidate[p1][0] - candidate[p2][0]) * (1. - ratio) 456 | y_offset = (candidate[p1][1] - candidate[p2][1]) * (1. - ratio) 457 | candidate[p2, 0] += x_offset 458 | candidate[p2, 1] += y_offset 459 | 460 | # ===== Hips (added) ===== 461 | ratio = ratios[(1, 8)] 462 | x_offset = (candidate[1][0] - candidate[8][0]) * (1. - ratio) 463 | y_offset = (candidate[1][1] - candidate[8][1]) * (1. - ratio) 464 | candidate[8, 0] += x_offset 465 | candidate[8, 1] += y_offset 466 | 467 | ratio = ratios[(1, 11)] 468 | x_offset = (candidate[1][0] - candidate[11][0]) * (1. - ratio) 469 | y_offset = (candidate[1][1] - candidate[11][1]) * (1. - ratio) 470 | candidate[11, 0] += x_offset 471 | candidate[11, 1] += y_offset 472 | 473 | # ===== Legs ===== 474 | ratio = ratios[(8, 9)] 475 | x_offset = (candidate[9][0] - candidate[8][0]) * (ratio - 1.) 476 | y_offset = (candidate[9][1] - candidate[8][1]) * (ratio - 1.) 477 | candidate[[9, 10], 0] += x_offset 478 | candidate[[9, 10], 1] += y_offset 479 | 480 | ratio = ratios[(9, 10)] 481 | x_offset = (candidate[10][0] - candidate[9][0]) * (ratio - 1.) 482 | y_offset = (candidate[10][1] - candidate[9][1]) * (ratio - 1.) 483 | candidate[10, 0] += x_offset 484 | candidate[10, 1] += y_offset 485 | 486 | ratio = ratios[(11, 12)] 487 | x_offset = (candidate[12][0] - candidate[11][0]) * (ratio - 1.) 488 | y_offset = (candidate[12][1] - candidate[11][1]) * (ratio - 1.) 
489 | candidate[[12, 13], 0] += x_offset 490 | candidate[[12, 13], 1] += y_offset 491 | 492 | ratio = ratios[(12, 13)] 493 | x_offset = (candidate[13][0] - candidate[12][0]) * (ratio - 1.) 494 | y_offset = (candidate[13][1] - candidate[12][1]) * (ratio - 1.) 495 | candidate[13, 0] += x_offset 496 | candidate[13, 1] += y_offset 497 | 498 | # rough align 499 | if has_ref_face and 'faces' in pose and pose['faces'] is not None and pose['faces'].size > 0: 500 | try: 501 | src68, to_orig = _to_68x2(pose['faces']) 502 | scaled68 = _face_scale_only(src68, ref68, candidate[0], alpha=1.0) 503 | pose['faces'] = to_orig(scaled68) 504 | except Exception as e: 505 | print("Reference face conversion failed:", e) 506 | continue 507 | 508 | return detected_poses 509 | -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm import tqdm 4 | import numpy as np 5 | import folder_paths 6 | import cv2 7 | import json 8 | import logging 9 | script_directory = os.path.dirname(os.path.abspath(__file__)) 10 | 11 | from comfy import model_management as mm 12 | from comfy.utils import ProgressBar 13 | device = mm.get_torch_device() 14 | offload_device = mm.unet_offload_device() 15 | 16 | folder_paths.add_model_folder_path("detection", os.path.join(folder_paths.models_dir, "detection")) 17 | 18 | from .models.onnx_models import ViTPose, Yolo 19 | from .pose_utils.pose2d_utils import load_pose_metas_from_kp2ds_seq, crop, bbox_from_detector 20 | from .utils import get_face_bboxes, padding_resize, resize_by_area, resize_to_bounds 21 | from .pose_utils.human_visualization import AAPoseMeta, draw_aapose_by_meta_new 22 | from .retarget_pose import get_retarget_pose 23 | 24 | class OnnxDetectionModelLoader: 25 | @classmethod 26 | def INPUT_TYPES(s): 27 | return { 28 | "required": { 29 | "vitpose_model": (folder_paths.get_filename_list("detection"), {"tooltip": "These models are loaded from the 'ComfyUI/models/detection' -folder",}), 30 | "yolo_model": (folder_paths.get_filename_list("detection"), {"tooltip": "These models are loaded from the 'ComfyUI/models/detection' -folder",}), 31 | "onnx_device": (["CUDAExecutionProvider", "CPUExecutionProvider"], {"default": "CUDAExecutionProvider", "tooltip": "Device to run the ONNX models on"}), 32 | }, 33 | } 34 | 35 | RETURN_TYPES = ("POSEMODEL",) 36 | RETURN_NAMES = ("model", ) 37 | FUNCTION = "loadmodel" 38 | CATEGORY = "WanAnimatePreprocess" 39 | DESCRIPTION = "Loads ONNX models for pose and face detection. ViTPose for pose estimation and YOLO for object detection." 
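# Note: the POSEMODEL output is a plain dict holding both ONNX sessions; downstream nodes fetch them by key, e.g. detector = model["yolo"] and pose_model = model["vitpose"] (see PoseAndFaceDetection.process below).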
40 | 41 | def loadmodel(self, vitpose_model, yolo_model, onnx_device): 42 | 43 | vitpose_model_path = folder_paths.get_full_path_or_raise("detection", vitpose_model) 44 | yolo_model_path = folder_paths.get_full_path_or_raise("detection", yolo_model) 45 | 46 | vitpose = ViTPose(vitpose_model_path, onnx_device) 47 | yolo = Yolo(yolo_model_path, onnx_device) 48 | 49 | model = { 50 | "vitpose": vitpose, 51 | "yolo": yolo, 52 | } 53 | 54 | return (model, ) 55 | 56 | class PoseAndFaceDetection: 57 | @classmethod 58 | def INPUT_TYPES(s): 59 | return { 60 | "required": { 61 | "model": ("POSEMODEL",), 62 | "images": ("IMAGE",), 63 | "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 1, "tooltip": "Width of the generation"}), 64 | "height": ("INT", {"default": 480, "min": 64, "max": 2048, "step": 1, "tooltip": "Height of the generation"}), 65 | }, 66 | "optional": { 67 | "retarget_image": ("IMAGE", {"default": None, "tooltip": "Optional reference image for pose retargeting"}), 68 | }, 69 | } 70 | 71 | RETURN_TYPES = ("POSEDATA", "IMAGE", "STRING", "BBOX", "BBOX,") 72 | RETURN_NAMES = ("pose_data", "face_images", "key_frame_body_points", "bboxes", "face_bboxes") 73 | FUNCTION = "process" 74 | CATEGORY = "WanAnimatePreprocess" 75 | DESCRIPTION = "Detects human poses and face images from input images. Optionally retargets poses based on a reference image." 76 | 77 | def process(self, model, images, width, height, retarget_image=None): 78 | detector = model["yolo"] 79 | pose_model = model["vitpose"] 80 | B, H, W, C = images.shape 81 | 82 | shape = np.array([H, W])[None] 83 | images_np = images.numpy() 84 | 85 | IMG_NORM_MEAN = np.array([0.485, 0.456, 0.406]) 86 | IMG_NORM_STD = np.array([0.229, 0.224, 0.225]) 87 | input_resolution=(256, 192) 88 | rescale = 1.25 89 | 90 | detector.reinit() 91 | pose_model.reinit() 92 | if retarget_image is not None: 93 | refer_img = resize_by_area(retarget_image[0].numpy() * 255, width * height, divisor=16) / 255.0 94 | ref_bbox = (detector( 95 | cv2.resize(refer_img.astype(np.float32), (640, 640)).transpose(2, 0, 1)[None], 96 | shape 97 | )[0][0]["bbox"]) 98 | 99 | if ref_bbox is None or ref_bbox[-1] <= 0 or (ref_bbox[2] - ref_bbox[0]) < 10 or (ref_bbox[3] - ref_bbox[1]) < 10: 100 | ref_bbox = np.array([0, 0, refer_img.shape[1], refer_img.shape[0]]) 101 | 102 | center, scale = bbox_from_detector(ref_bbox, input_resolution, rescale=rescale) 103 | refer_img = crop(refer_img, center, scale, (input_resolution[0], input_resolution[1]))[0] 104 | 105 | img_norm = (refer_img - IMG_NORM_MEAN) / IMG_NORM_STD 106 | img_norm = img_norm.transpose(2, 0, 1).astype(np.float32) 107 | 108 | ref_keypoints = pose_model(img_norm[None], np.array(center)[None], np.array(scale)[None]) 109 | refer_pose_meta = load_pose_metas_from_kp2ds_seq(ref_keypoints, width=retarget_image.shape[2], height=retarget_image.shape[1])[0] 110 | 111 | comfy_pbar = ProgressBar(B*2) 112 | progress = 0 113 | bboxes = [] 114 | for img in tqdm(images_np, total=len(images_np), desc="Detecting bboxes"): 115 | bboxes.append(detector( 116 | cv2.resize(img, (640, 640)).transpose(2, 0, 1)[None], 117 | shape 118 | )[0][0]["bbox"]) 119 | progress += 1 120 | if progress % 10 == 0: 121 | comfy_pbar.update_absolute(progress) 122 | 123 | detector.cleanup() 124 | 125 | kp2ds = [] 126 | for img, bbox in tqdm(zip(images_np, bboxes), total=len(images_np), desc="Extracting keypoints"): 127 | if bbox is None or bbox[-1] <= 0 or (bbox[2] - bbox[0]) < 10 or (bbox[3] - bbox[1]) < 10: 128 | bbox = np.array([0, 0, img.shape[1], 
img.shape[0]]) 129 | 130 | bbox_xywh = bbox 131 | center, scale = bbox_from_detector(bbox_xywh, input_resolution, rescale=rescale) 132 | img = crop(img, center, scale, (input_resolution[0], input_resolution[1]))[0] 133 | 134 | img_norm = (img - IMG_NORM_MEAN) / IMG_NORM_STD 135 | img_norm = img_norm.transpose(2, 0, 1).astype(np.float32) 136 | 137 | keypoints = pose_model(img_norm[None], np.array(center)[None], np.array(scale)[None]) 138 | kp2ds.append(keypoints) 139 | progress += 1 140 | if progress % 10 == 0: 141 | comfy_pbar.update_absolute(progress) 142 | 143 | pose_model.cleanup() 144 | 145 | kp2ds = np.concatenate(kp2ds, 0) 146 | pose_metas = load_pose_metas_from_kp2ds_seq(kp2ds, width=W, height=H) 147 | 148 | face_images = [] 149 | face_bboxes = [] 150 | for idx, meta in enumerate(pose_metas): 151 | face_bbox_for_image = get_face_bboxes(meta['keypoints_face'][:, :2], scale=1.3, image_shape=(H, W)) 152 | x1, x2, y1, y2 = face_bbox_for_image 153 | face_bboxes.append((x1, y1, x2, y2)) 154 | face_image = images_np[idx][y1:y2, x1:x2] 155 | # Check if face_image is valid before resizing 156 | if face_image.size == 0 or face_image.shape[0] == 0 or face_image.shape[1] == 0: 157 | logging.warning(f"Empty face crop on frame {idx}, creating fallback image.") 158 | # Create a fallback image (black or use center crop) 159 | fallback_size = int(min(H, W) * 0.3) 160 | fallback_x1 = (W - fallback_size) // 2 161 | fallback_x2 = fallback_x1 + fallback_size 162 | fallback_y1 = int(H * 0.1) 163 | fallback_y2 = fallback_y1 + fallback_size 164 | face_image = images_np[idx][fallback_y1:fallback_y2, fallback_x1:fallback_x2] 165 | 166 | # If still empty, create a black image 167 | if face_image.size == 0: 168 | face_image = np.zeros((fallback_size, fallback_size, C), dtype=images_np.dtype) 169 | face_image = cv2.resize(face_image, (512, 512)) 170 | face_images.append(face_image) 171 | 172 | face_images_np = np.stack(face_images, 0) 173 | face_images_tensor = torch.from_numpy(face_images_np) 174 | 175 | if retarget_image is not None and refer_pose_meta is not None: 176 | retarget_pose_metas = get_retarget_pose(pose_metas[0], refer_pose_meta, pose_metas, None, None) 177 | else: 178 | retarget_pose_metas = [AAPoseMeta.from_humanapi_meta(meta) for meta in pose_metas] 179 | 180 | bbox = np.array(bboxes[0]).flatten() 181 | if bbox.shape[0] >= 4: 182 | bbox_ints = tuple(int(v) for v in bbox[:4]) 183 | else: 184 | bbox_ints = (0, 0, 0, 0) 185 | 186 | key_frame_num = 4 if B >= 4 else 1 187 | key_frame_step = len(pose_metas) // key_frame_num 188 | key_frame_index_list = list(range(0, len(pose_metas), key_frame_step)) 189 | 190 | key_points_index = [0, 1, 2, 5, 8, 11, 10, 13] 191 | 192 | for key_frame_index in key_frame_index_list: 193 | keypoints_body_list = [] 194 | body_key_points = pose_metas[key_frame_index]['keypoints_body'] 195 | for each_index in key_points_index: 196 | each_keypoint = body_key_points[each_index] 197 | if None is each_keypoint: 198 | continue 199 | keypoints_body_list.append(each_keypoint) 200 | 201 | keypoints_body = np.array(keypoints_body_list)[:, :2] 202 | wh = np.array([[pose_metas[0]['width'], pose_metas[0]['height']]]) 203 | points = (keypoints_body * wh).astype(np.int32) 204 | points_dict_list = [] 205 | for point in points: 206 | points_dict_list.append({"x": int(point[0]), "y": int(point[1])}) 207 | 208 | pose_data = { 209 | "retarget_image": refer_img if retarget_image is not None else None, 210 | "pose_metas": retarget_pose_metas, 211 | "refer_pose_meta": refer_pose_meta if 
retarget_image is not None else None, 212 | "pose_metas_original": pose_metas, 213 | } 214 | 215 | return (pose_data, face_images_tensor, json.dumps(points_dict_list), [bbox_ints], face_bboxes) 216 | 217 | class DrawViTPose: 218 | @classmethod 219 | def INPUT_TYPES(s): 220 | return { 221 | "required": { 222 | "pose_data": ("POSEDATA",), 223 | "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 1, "tooltip": "Width of the generation"}), 224 | "height": ("INT", {"default": 480, "min": 64, "max": 2048, "step": 1, "tooltip": "Height of the generation"}), 225 | "retarget_padding": ("INT", {"default": 16, "min": 0, "max": 512, "step": 1, "tooltip": "When > 0, the retargeted pose image is padded and resized to the target size"}), 226 | "body_stick_width": ("INT", {"default": -1, "min": -1, "max": 20, "step": 1, "tooltip": "Width of the body sticks. Set to 0 to disable body drawing, -1 for auto"}), 227 | "hand_stick_width": ("INT", {"default": -1, "min": -1, "max": 20, "step": 1, "tooltip": "Width of the hand sticks. Set to 0 to disable hand drawing, -1 for auto"}), 228 | "draw_head": ("BOOLEAN", {"default": "True", "tooltip": "Whether to draw head keypoints"}), 229 | }, 230 | } 231 | 232 | RETURN_TYPES = ("IMAGE", ) 233 | RETURN_NAMES = ("pose_images", ) 234 | FUNCTION = "process" 235 | CATEGORY = "WanAnimatePreprocess" 236 | DESCRIPTION = "Draws pose images from pose data." 237 | 238 | def process(self, pose_data, width, height, body_stick_width, hand_stick_width, draw_head, retarget_padding=64): 239 | 240 | retarget_image = pose_data.get("retarget_image", None) 241 | pose_metas = pose_data["pose_metas"] 242 | 243 | draw_hand = hand_stick_width != 0 244 | use_retarget_resize = retarget_padding > 0 and retarget_image is not None 245 | 246 | comfy_pbar = ProgressBar(len(pose_metas)) 247 | progress = 0 248 | crop_target_image = None 249 | pose_images = [] 250 | 251 | for idx, meta in enumerate(tqdm(pose_metas, desc="Drawing pose images")): 252 | canvas = np.zeros((height, width, 3), dtype=np.uint8) 253 | pose_image = draw_aapose_by_meta_new(canvas, meta, draw_hand=draw_hand, draw_head=draw_head, body_stick_width=body_stick_width, hand_stick_width=hand_stick_width) 254 | 255 | if crop_target_image is None: 256 | crop_target_image = pose_image 257 | 258 | if use_retarget_resize: 259 | pose_image = resize_to_bounds(pose_image, height, width, crop_target_image=crop_target_image, extra_padding=retarget_padding) 260 | else: 261 | pose_image = padding_resize(pose_image, height, width) 262 | 263 | pose_images.append(pose_image) 264 | progress += 1 265 | if progress % 10 == 0: 266 | comfy_pbar.update_absolute(progress) 267 | 268 | pose_images_np = np.stack(pose_images, 0) 269 | pose_images_tensor = torch.from_numpy(pose_images_np).float() / 255.0 270 | 271 | return (pose_images_tensor, ) 272 | 273 | class PoseRetargetPromptHelper: 274 | @classmethod 275 | def INPUT_TYPES(s): 276 | return { 277 | "required": { 278 | "pose_data": ("POSEDATA",), 279 | }, 280 | } 281 | 282 | RETURN_TYPES = ("STRING", "STRING", ) 283 | RETURN_NAMES = ("prompt", "retarget_prompt", ) 284 | FUNCTION = "process" 285 | CATEGORY = "WanAnimatePreprocess" 286 | DESCRIPTION = "Generates text prompts for pose retargeting based on visibility of arms and legs in the template pose. 
Originally used for Flux Kontext" 287 | 288 | def process(self, pose_data): 289 | refer_pose_meta = pose_data.get("refer_pose_meta", None) 290 | if refer_pose_meta is None: 291 | return ("Change the person to face forward.", "Change the person to face forward.", ) 292 | tpl_pose_metas = pose_data["pose_metas_original"] 293 | arm_visible = False 294 | leg_visible = False 295 | 296 | for tpl_pose_meta in tpl_pose_metas: 297 | tpl_keypoints = tpl_pose_meta['keypoints_body'] 298 | tpl_keypoints = np.array(tpl_keypoints) 299 | if np.any(tpl_keypoints[3]) != 0 or np.any(tpl_keypoints[4]) != 0 or np.any(tpl_keypoints[6]) != 0 or np.any(tpl_keypoints[7]) != 0: 300 | if (tpl_keypoints[3][0] <= 1 and tpl_keypoints[3][1] <= 1 and tpl_keypoints[3][2] >= 0.75) or (tpl_keypoints[4][0] <= 1 and tpl_keypoints[4][1] <= 1 and tpl_keypoints[4][2] >= 0.75) or \ 301 | (tpl_keypoints[6][0] <= 1 and tpl_keypoints[6][1] <= 1 and tpl_keypoints[6][2] >= 0.75) or (tpl_keypoints[7][0] <= 1 and tpl_keypoints[7][1] <= 1 and tpl_keypoints[7][2] >= 0.75): 302 | arm_visible = True 303 | if np.any(tpl_keypoints[9]) != 0 or np.any(tpl_keypoints[12]) != 0 or np.any(tpl_keypoints[10]) != 0 or np.any(tpl_keypoints[13]) != 0: 304 | if (tpl_keypoints[9][0] <= 1 and tpl_keypoints[9][1] <= 1 and tpl_keypoints[9][2] >= 0.75) or (tpl_keypoints[12][0] <= 1 and tpl_keypoints[12][1] <= 1 and tpl_keypoints[12][2] >= 0.75) or \ 305 | (tpl_keypoints[10][0] <= 1 and tpl_keypoints[10][1] <= 1 and tpl_keypoints[10][2] >= 0.75) or (tpl_keypoints[13][0] <= 1 and tpl_keypoints[13][1] <= 1 and tpl_keypoints[13][2] >= 0.75): 306 | leg_visible = True 307 | if arm_visible and leg_visible: 308 | break 309 | 310 | if leg_visible: 311 | if tpl_pose_meta['width'] > tpl_pose_meta['height']: 312 | tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image." 313 | else: 314 | tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image." 315 | 316 | if refer_pose_meta['width'] > refer_pose_meta['height']: 317 | refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). The person is standing. Feet and Hands are visible in the image." 318 | else: 319 | refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. The person is standing. Feet and Hands are visible in the image." 320 | elif arm_visible: 321 | if tpl_pose_meta['width'] > tpl_pose_meta['height']: 322 | tpl_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image." 323 | else: 324 | tpl_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image." 325 | 326 | if refer_pose_meta['width'] > refer_pose_meta['height']: 327 | refer_prompt = "Change the person to a standard T-pose (facing forward with arms extended). Hands are visible in the image." 328 | else: 329 | refer_prompt = "Change the person to a standard pose with the face oriented forward and arms extending straight down by the sides. Hands are visible in the image." 330 | else: 331 | tpl_prompt = "Change the person to face forward." 332 | refer_prompt = "Change the person to face forward." 
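# Summary of the branching above: landscape inputs get a T-pose prompt and portrait inputs an arms-at-sides standing prompt; when neither arms nor legs are confidently visible, only a face-forward change is requested.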
333 | 334 | return (tpl_prompt, refer_prompt, ) 335 | 336 | class PoseDetectionOneToAllAnimation: 337 | @classmethod 338 | def INPUT_TYPES(s): 339 | return { 340 | "required": { 341 | "model": ("POSEMODEL",), 342 | "images": ("IMAGE",), 343 | "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 2, "tooltip": "Width of the generation"}), 344 | "height": ("INT", {"default": 480, "min": 64, "max": 2048, "step": 2, "tooltip": "Height of the generation"}), 345 | "align_to": (["ref", "pose", "none"], {"default": "ref", "tooltip": "Alignment mode for poses"}), 346 | "draw_face_points": (["full", "weak", "none"], {"default": "full", "tooltip": "Whether to draw face keypoints on the pose images"}), 347 | "draw_head": (["full", "weak", "none"], {"default": "full", "tooltip": "Whether to draw head keypoints on the pose images"}), 348 | }, 349 | "optional": { 350 | "ref_image": ("IMAGE", {"default": None, "tooltip": "Optional reference image for pose retargeting"}), 351 | }, 352 | } 353 | 354 | RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE", "MASK",) 355 | RETURN_NAMES = ("pose_images", "ref_pose_image", "ref_image", "ref_mask") 356 | FUNCTION = "process" 357 | CATEGORY = "WanAnimatePreprocess" 358 | DESCRIPTION = "Specialized pose detection and alignment for OneToAllAnimation model https://github.com/ssj9596/One-to-All-Animation. Detects poses from input images and aligns them based on a reference image if provided." 359 | 360 | def process(self, model, images, width, height, align_to, draw_face_points, draw_head, ref_image=None): 361 | from .onetoall.infer_function import aaposemeta_to_dwpose, align_to_reference, align_to_pose 362 | from .onetoall.utils import draw_pose_aligned, warp_ref_to_pose 363 | detector = model["yolo"] 364 | pose_model = model["vitpose"] 365 | B, H, W, C = images.shape 366 | 367 | shape = np.array([H, W])[None] 368 | images_np = images.numpy() 369 | 370 | IMG_NORM_MEAN = np.array([0.485, 0.456, 0.406]) 371 | IMG_NORM_STD = np.array([0.229, 0.224, 0.225]) 372 | input_resolution=(256, 192) 373 | rescale = 1.25 374 | 375 | detector.reinit() 376 | pose_model.reinit() 377 | 378 | if ref_image is not None: 379 | refer_img_np = ref_image[0].numpy() * 255 380 | refer_img = resize_by_area(refer_img_np, width * height, divisor=16) / 255.0 381 | ref_bbox = (detector( 382 | cv2.resize(refer_img.astype(np.float32), (640, 640)).transpose(2, 0, 1)[None], 383 | shape 384 | )[0][0]["bbox"]) 385 | 386 | if ref_bbox is None or ref_bbox[-1] <= 0 or (ref_bbox[2] - ref_bbox[0]) < 10 or (ref_bbox[3] - ref_bbox[1]) < 10: 387 | ref_bbox = np.array([0, 0, refer_img.shape[1], refer_img.shape[0]]) 388 | 389 | center, scale = bbox_from_detector(ref_bbox, input_resolution, rescale=rescale) 390 | refer_img = crop(refer_img, center, scale, (input_resolution[0], input_resolution[1]))[0] 391 | 392 | img_norm = (refer_img - IMG_NORM_MEAN) / IMG_NORM_STD 393 | img_norm = img_norm.transpose(2, 0, 1).astype(np.float32) 394 | 395 | ref_keypoints = pose_model(img_norm[None], np.array(center)[None], np.array(scale)[None]) 396 | refer_pose_meta = load_pose_metas_from_kp2ds_seq(ref_keypoints, width=ref_image.shape[2], height=ref_image.shape[1])[0] 397 | 398 | ref_dwpose = aaposemeta_to_dwpose(refer_pose_meta) 399 | 400 | comfy_pbar = ProgressBar(B*2) 401 | progress = 0 402 | bboxes = [] 403 | for img in tqdm(images_np, total=len(images_np), desc="Detecting bboxes"): 404 | bboxes.append(detector( 405 | cv2.resize(img, (640, 640)).transpose(2, 0, 1)[None], 406 | shape 407 | )[0][0]["bbox"]) 408 | progress += 1 
409 | if progress % 10 == 0: 410 | comfy_pbar.update_absolute(progress) 411 | 412 | detector.cleanup() 413 | 414 | kp2ds = [] 415 | for img, bbox in tqdm(zip(images_np, bboxes), total=len(images_np), desc="Extracting keypoints"): 416 | if bbox is None or bbox[-1] <= 0 or (bbox[2] - bbox[0]) < 10 or (bbox[3] - bbox[1]) < 10: 417 | bbox = np.array([0, 0, img.shape[1], img.shape[0]]) 418 | 419 | bbox_xywh = bbox 420 | center, scale = bbox_from_detector(bbox_xywh, input_resolution, rescale=rescale) 421 | img = crop(img, center, scale, (input_resolution[0], input_resolution[1]))[0] 422 | 423 | img_norm = (img - IMG_NORM_MEAN) / IMG_NORM_STD 424 | img_norm = img_norm.transpose(2, 0, 1).astype(np.float32) 425 | 426 | keypoints = pose_model(img_norm[None], np.array(center)[None], np.array(scale)[None]) 427 | kp2ds.append(keypoints) 428 | progress += 1 429 | if progress % 10 == 0: 430 | comfy_pbar.update_absolute(progress) 431 | 432 | pose_model.cleanup() 433 | 434 | kp2ds = np.concatenate(kp2ds, 0) 435 | pose_metas = load_pose_metas_from_kp2ds_seq(kp2ds, width=W, height=H) 436 | tpl_dwposes = [aaposemeta_to_dwpose(meta) for meta in pose_metas] 437 | 438 | ref_pose_image_tensor = None 439 | if ref_image is not None: 440 | if align_to == "ref": 441 | ref_pose_image = draw_pose_aligned(ref_dwpose, height, width, without_face=True) 442 | ref_pose_image_np = np.stack(ref_pose_image, 0) 443 | ref_pose_image_tensor = torch.from_numpy(ref_pose_image_np).unsqueeze(0).float() / 255.0 444 | tpl_dwposes = align_to_reference(refer_pose_meta, pose_metas, tpl_dwposes, anchor_idx=0) 445 | image_input_tensor = ref_image 446 | image_mask_tensor = torch.zeros(1, ref_image.shape[1], ref_image.shape[2], dtype=torch.float32, device="cpu") 447 | elif align_to == "pose": 448 | image_input, ref_pose_image_np, image_mask = warp_ref_to_pose(refer_img_np, tpl_dwposes[0], ref_dwpose) 449 | ref_pose_image_np = np.stack(ref_pose_image_np, 0) 450 | ref_pose_image_tensor = torch.from_numpy(ref_pose_image_np).unsqueeze(0).float() / 255.0 451 | tpl_dwposes = align_to_pose(ref_dwpose, tpl_dwposes, anchor_idx=0) 452 | image_input_tensor = torch.from_numpy(image_input).unsqueeze(0).float() / 255.0 453 | image_mask_tensor = torch.from_numpy(image_mask).unsqueeze(0).float() / 255.0 454 | elif align_to == "none": 455 | ref_pose_image = draw_pose_aligned(ref_dwpose, height, width, without_face=True) 456 | ref_pose_image_np = np.stack(ref_pose_image, 0) 457 | ref_pose_image_tensor = torch.from_numpy(ref_pose_image_np).unsqueeze(0).float() / 255.0 458 | image_input_tensor = ref_image 459 | image_mask_tensor = torch.zeros(1, ref_image.shape[1], ref_image.shape[2], dtype=torch.float32, device="cpu") 460 | else: 461 | ref_pose_image_tensor = torch.zeros((1, height, width, 3), dtype=torch.float32, device="cpu") 462 | image_input_tensor = torch.zeros((1, height, width, 3), dtype=torch.float32, device="cpu") 463 | image_mask_tensor = torch.zeros(1, height, width, dtype=torch.float32, device="cpu") 464 | 465 | pose_imgs = [] 466 | for pose_np in tpl_dwposes: 467 | pose_img = draw_pose_aligned(pose_np, height, width, without_face=(draw_face_points=="none"), face_change=(draw_face_points=="weak"), head_strength=draw_head) 468 | pose_img = torch.from_numpy(np.array(pose_img)) 469 | pose_imgs.append(pose_img) 470 | 471 | pose_tensor = torch.stack(pose_imgs).cpu().float() / 255.0 472 | 473 | return (pose_tensor, ref_pose_image_tensor, image_input_tensor, image_mask_tensor) 474 | 475 | NODE_CLASS_MAPPINGS = { 476 | "OnnxDetectionModelLoader": 
OnnxDetectionModelLoader, 477 | "PoseAndFaceDetection": PoseAndFaceDetection, 478 | "DrawViTPose": DrawViTPose, 479 | "PoseRetargetPromptHelper": PoseRetargetPromptHelper, 480 | "PoseDetectionOneToAllAnimation": PoseDetectionOneToAllAnimation, 481 | } 482 | NODE_DISPLAY_NAME_MAPPINGS = { 483 | "OnnxDetectionModelLoader": "ONNX Detection Model Loader", 484 | "PoseAndFaceDetection": "Pose and Face Detection", 485 | "DrawViTPose": "Draw ViT Pose", 486 | "PoseRetargetPromptHelper": "Pose Retarget Prompt Helper", 487 | "PoseDetectionOneToAllAnimation": "Pose Detection OneToAll Animation", 488 | } 489 | -------------------------------------------------------------------------------- /retarget_pose.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import numpy as np 3 | from tqdm import tqdm 4 | import math 5 | from typing import NamedTuple 6 | import copy 7 | from .pose_utils.pose2d_utils import AAPoseMeta 8 | 9 | # load skeleton name and bone lines 10 | keypoint_list = [ 11 | "Nose", 12 | "Neck", 13 | "RShoulder", 14 | "RElbow", 15 | "RWrist", # No.4 16 | "LShoulder", 17 | "LElbow", 18 | "LWrist", # No.7 19 | "RHip", 20 | "RKnee", 21 | "RAnkle", # No.10 22 | "LHip", 23 | "LKnee", 24 | "LAnkle", # No.13 25 | "REye", 26 | "LEye", 27 | "REar", 28 | "LEar", 29 | "LToe", 30 | "RToe", 31 | ] 32 | 33 | 34 | limbSeq = [ 35 | [2, 3], [2, 6], # shoulders 36 | [3, 4], [4, 5], # left arm 37 | [6, 7], [7, 8], # right arm 38 | [2, 9], [9, 10], [10, 11], # right leg 39 | [2, 12], [12, 13], [13, 14], # left leg 40 | [2, 1], [1, 15], [15, 17], [1, 16], [16, 18], # face (nose, eyes, ears) 41 | [14, 19], # left foot 42 | [11, 20] # right foot 43 | ] 44 | 45 | eps = 0.01 46 | 47 | class Keypoint(NamedTuple): 48 | x: float 49 | y: float 50 | score: float = 1.0 51 | id: int = -1 52 | 53 | 54 | # for each limb, calculate src & dst bone's length 55 | # and calculate their ratios 56 | def get_length(skeleton, limb): 57 | 58 | k1_index, k2_index = limb 59 | 60 | H, W = skeleton['height'], skeleton['width'] 61 | keypoints = skeleton['keypoints_body'] 62 | keypoint1 = keypoints[k1_index - 1] 63 | keypoint2 = keypoints[k2_index - 1] 64 | 65 | if keypoint1 is None or keypoint2 is None: 66 | return None, None, None 67 | 68 | X = np.array([keypoint1[0], keypoint2[0]]) * float(W) 69 | Y = np.array([keypoint1[1], keypoint2[1]]) * float(H) 70 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 71 | 72 | return X, Y, length 73 | 74 | 75 | 76 | def get_handpose_meta(keypoints, delta, src_H, src_W): 77 | 78 | new_keypoints = [] 79 | 80 | for idx, keypoint in enumerate(keypoints): 81 | if keypoint is None: 82 | new_keypoints.append(None) 83 | continue 84 | if keypoint.score == 0: 85 | new_keypoints.append(None) 86 | continue 87 | 88 | x, y = keypoint.x, keypoint.y 89 | x = int(x * src_W + delta[0]) 90 | y = int(y * src_H + delta[1]) 91 | 92 | new_keypoints.append( 93 | Keypoint( 94 | x=x, 95 | y=y, 96 | score=keypoint.score, 97 | )) 98 | 99 | return new_keypoints 100 | 101 | 102 | def deal_hand_keypoints(hand_res, r_ratio, l_ratio, hand_score_th = 0.5): 103 | 104 | left_hand = [] 105 | right_hand = [] 106 | 107 | left_delta_x = hand_res['left'][0][0] * (l_ratio - 1) 108 | left_delta_y = hand_res['left'][0][1] * (l_ratio - 1) 109 | 110 | right_delta_x = hand_res['right'][0][0] * (r_ratio - 1) 111 | right_delta_y = hand_res['right'][0][1] * (r_ratio - 1) 112 | 113 | length = len(hand_res['left']) 114 | 115 | for 
i in range(length): 116 | # left hand 117 | if hand_res['left'][i][2] < hand_score_th: 118 | left_hand.append( 119 | Keypoint( 120 | x=-1, 121 | y=-1, 122 | score=0, 123 | ) 124 | ) 125 | else: 126 | left_hand.append( 127 | Keypoint( 128 | x=hand_res['left'][i][0] * l_ratio - left_delta_x, 129 | y=hand_res['left'][i][1] * l_ratio - left_delta_y, 130 | score = hand_res['left'][i][2] 131 | ) 132 | ) 133 | 134 | # right hand 135 | if hand_res['right'][i][2] < hand_score_th: 136 | right_hand.append( 137 | Keypoint( 138 | x=-1, 139 | y=-1, 140 | score=0, 141 | ) 142 | ) 143 | else: 144 | right_hand.append( 145 | Keypoint( 146 | x=hand_res['right'][i][0] * r_ratio - right_delta_x, 147 | y=hand_res['right'][i][1] * r_ratio - right_delta_y, 148 | score = hand_res['right'][i][2] 149 | ) 150 | ) 151 | 152 | return right_hand, left_hand 153 | 154 | 155 | def get_scaled_pose(canvas, src_canvas, keypoints, keypoints_hand, bone_ratio_list, delta_ground_x, delta_ground_y, 156 | rescaled_src_ground_x, body_flag, id, scale_min, threshold = 0.4): 157 | 158 | H, W = canvas 159 | src_H, src_W = src_canvas 160 | 161 | new_length_list = [ ] 162 | angle_list = [ ] 163 | 164 | # keypoints from 0-1 to H/W range 165 | for idx in range(len(keypoints)): 166 | if keypoints[idx] is None or len(keypoints[idx]) == 0: 167 | continue 168 | 169 | keypoints[idx] = [keypoints[idx][0] * src_W, keypoints[idx][1] * src_H, keypoints[idx][2]] 170 | 171 | # first traverse, get new_length_list and angle_list 172 | for idx, (k1_index, k2_index) in enumerate(limbSeq): 173 | keypoint1 = keypoints[k1_index - 1] 174 | keypoint2 = keypoints[k2_index - 1] 175 | 176 | if keypoint1 is None or keypoint2 is None or len(keypoint1) == 0 or len(keypoint2) == 0: 177 | new_length_list.append(None) 178 | angle_list.append(None) 179 | continue 180 | 181 | Y = np.array([keypoint1[0], keypoint2[0]]) #* float(W) 182 | X = np.array([keypoint1[1], keypoint2[1]]) #* float(H) 183 | 184 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 185 | 186 | new_length = length * bone_ratio_list[idx] 187 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 188 | 189 | new_length_list.append(new_length) 190 | angle_list.append(angle) 191 | 192 | # Keep foot length within 0.5x calf length 193 | foot_lower_leg_ratio = 0.5 194 | if new_length_list[8] != None and new_length_list[18] != None: 195 | if new_length_list[18] > new_length_list[8] * foot_lower_leg_ratio: 196 | new_length_list[18] = new_length_list[8] * foot_lower_leg_ratio 197 | 198 | if new_length_list[11] != None and new_length_list[17] != None: 199 | if new_length_list[17] > new_length_list[11] * foot_lower_leg_ratio: 200 | new_length_list[17] = new_length_list[11] * foot_lower_leg_ratio 201 | 202 | # second traverse, calculate new keypoints 203 | rescale_keypoints = keypoints.copy() 204 | 205 | for idx, (k1_index, k2_index) in enumerate(limbSeq): 206 | # update dst_keypoints 207 | start_keypoint = rescale_keypoints[k1_index - 1] 208 | new_length = new_length_list[idx] 209 | angle = angle_list[idx] 210 | 211 | if rescale_keypoints[k1_index - 1] is None or rescale_keypoints[k2_index - 1] is None or \ 212 | len(rescale_keypoints[k1_index - 1]) == 0 or len(rescale_keypoints[k2_index - 1]) == 0: 213 | continue 214 | 215 | # calculate end_keypoint 216 | delta_x = new_length * math.cos(math.radians(angle)) 217 | delta_y = new_length * math.sin(math.radians(angle)) 218 | 219 | end_keypoint_x = start_keypoint[0] - delta_x 220 | end_keypoint_y = start_keypoint[1] - delta_y 221 | 222 | # update 
keypoints 223 | rescale_keypoints[k2_index - 1] = [end_keypoint_x, end_keypoint_y, rescale_keypoints[k2_index - 1][2]] 224 | 225 | if id == 0: 226 | if body_flag == 'full_body' and rescale_keypoints[8] != None and rescale_keypoints[11] != None: 227 | delta_ground_x_offset_first_frame = (rescale_keypoints[8][0] + rescale_keypoints[11][0]) / 2 - rescaled_src_ground_x 228 | delta_ground_x += delta_ground_x_offset_first_frame 229 | elif body_flag == 'half_body' and rescale_keypoints[1] != None: 230 | delta_ground_x_offset_first_frame = rescale_keypoints[1][0] - rescaled_src_ground_x 231 | delta_ground_x += delta_ground_x_offset_first_frame 232 | 233 | # offset all keypoints 234 | for idx in range(len(rescale_keypoints)): 235 | if rescale_keypoints[idx] is None or len(rescale_keypoints[idx]) == 0 : 236 | continue 237 | rescale_keypoints[idx][0] -= delta_ground_x 238 | rescale_keypoints[idx][1] -= delta_ground_y 239 | 240 | # rescale keypoints to original size 241 | rescale_keypoints[idx][0] /= scale_min 242 | rescale_keypoints[idx][1] /= scale_min 243 | 244 | # Scale hand proportions based on body skeletal ratios 245 | r_ratio = max(bone_ratio_list[0], bone_ratio_list[1]) / scale_min 246 | l_ratio = max(bone_ratio_list[0], bone_ratio_list[1]) / scale_min 247 | left_hand, right_hand = deal_hand_keypoints(keypoints_hand, r_ratio, l_ratio, hand_score_th = threshold) 248 | 249 | left_hand_new = left_hand.copy() 250 | right_hand_new = right_hand.copy() 251 | 252 | if rescale_keypoints[4] == None and rescale_keypoints[7] == None: 253 | pass 254 | 255 | elif rescale_keypoints[4] == None and rescale_keypoints[7] != None: 256 | right_hand_delta = np.array(rescale_keypoints[7][:2]) - np.array(keypoints[7][:2]) 257 | right_hand_new = get_handpose_meta(right_hand, right_hand_delta, src_H, src_W) 258 | 259 | elif rescale_keypoints[4] != None and rescale_keypoints[7] == None: 260 | left_hand_delta = np.array(rescale_keypoints[4][:2]) - np.array(keypoints[4][:2]) 261 | left_hand_new = get_handpose_meta(left_hand, left_hand_delta, src_H, src_W) 262 | 263 | else: 264 | # get left_hand and right_hand offset 265 | left_hand_delta = np.array(rescale_keypoints[4][:2]) - np.array(keypoints[4][:2]) 266 | right_hand_delta = np.array(rescale_keypoints[7][:2]) - np.array(keypoints[7][:2]) 267 | 268 | if keypoints[4][0] != None and left_hand[0].x != -1: 269 | left_hand_root_offset = np.array( ( keypoints[4][0] - left_hand[0].x * src_W, keypoints[4][1] - left_hand[0].y * src_H)) 270 | left_hand_delta += left_hand_root_offset 271 | 272 | if keypoints[7][0] != None and right_hand[0].x != -1: 273 | right_hand_root_offset = np.array( ( keypoints[7][0] - right_hand[0].x * src_W, keypoints[7][1] - right_hand[0].y * src_H)) 274 | right_hand_delta += right_hand_root_offset 275 | 276 | dis_left_hand = ((keypoints[4][0] - left_hand[0].x * src_W) ** 2 + (keypoints[4][1] - left_hand[0].y * src_H) ** 2) ** 0.5 277 | dis_right_hand = ((keypoints[7][0] - left_hand[0].x * src_W) ** 2 + (keypoints[7][1] - left_hand[0].y * src_H) ** 2) ** 0.5 278 | 279 | if dis_left_hand > dis_right_hand: 280 | right_hand_new = get_handpose_meta(left_hand, right_hand_delta, src_H, src_W) 281 | left_hand_new = get_handpose_meta(right_hand, left_hand_delta, src_H, src_W) 282 | else: 283 | left_hand_new = get_handpose_meta(left_hand, left_hand_delta, src_H, src_W) 284 | right_hand_new = get_handpose_meta(right_hand, right_hand_delta, src_H, src_W) 285 | 286 | # get normalized keypoints_body 287 | norm_body_keypoints = [ ] 288 | for body_keypoint in 
rescale_keypoints: 289 | if body_keypoint != None: 290 | norm_body_keypoints.append([body_keypoint[0] / W , body_keypoint[1] / H, body_keypoint[2]]) 291 | else: 292 | norm_body_keypoints.append(None) 293 | 294 | frame_info = { 295 | 'height': H, 296 | 'width': W, 297 | 'keypoints_body': norm_body_keypoints, 298 | 'keypoints_left_hand' : left_hand_new, 299 | 'keypoints_right_hand' : right_hand_new, 300 | } 301 | 302 | return frame_info 303 | 304 | 305 | def rescale_skeleton(H, W, keypoints, bone_ratio_list): 306 | 307 | rescale_keypoints = keypoints.copy() 308 | 309 | new_length_list = [ ] 310 | angle_list = [ ] 311 | 312 | # keypoints from 0-1 to H/W range 313 | for idx in range(len(rescale_keypoints)): 314 | if rescale_keypoints[idx] is None or len(rescale_keypoints[idx]) == 0: 315 | continue 316 | 317 | rescale_keypoints[idx] = [rescale_keypoints[idx][0] * W, rescale_keypoints[idx][1] * H] 318 | 319 | # first traverse, get new_length_list and angle_list 320 | for idx, (k1_index, k2_index) in enumerate(limbSeq): 321 | keypoint1 = rescale_keypoints[k1_index - 1] 322 | keypoint2 = rescale_keypoints[k2_index - 1] 323 | 324 | if keypoint1 is None or keypoint2 is None or len(keypoint1) == 0 or len(keypoint2) == 0: 325 | new_length_list.append(None) 326 | angle_list.append(None) 327 | continue 328 | 329 | Y = np.array([keypoint1[0], keypoint2[0]]) #* float(W) 330 | X = np.array([keypoint1[1], keypoint2[1]]) #* float(H) 331 | 332 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 333 | 334 | 335 | new_length = length * bone_ratio_list[idx] 336 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 337 | 338 | new_length_list.append(new_length) 339 | angle_list.append(angle) 340 | 341 | # # second traverse, calculate new keypoints 342 | for idx, (k1_index, k2_index) in enumerate(limbSeq): 343 | # update dst_keypoints 344 | start_keypoint = rescale_keypoints[k1_index - 1] 345 | new_length = new_length_list[idx] 346 | angle = angle_list[idx] 347 | 348 | if rescale_keypoints[k1_index - 1] is None or rescale_keypoints[k2_index - 1] is None or \ 349 | len(rescale_keypoints[k1_index - 1]) == 0 or len(rescale_keypoints[k2_index - 1]) == 0: 350 | continue 351 | 352 | # calculate end_keypoint 353 | delta_x = new_length * math.cos(math.radians(angle)) 354 | delta_y = new_length * math.sin(math.radians(angle)) 355 | 356 | end_keypoint_x = start_keypoint[0] - delta_x 357 | end_keypoint_y = start_keypoint[1] - delta_y 358 | 359 | # update keypoints 360 | rescale_keypoints[k2_index - 1] = [end_keypoint_x, end_keypoint_y] 361 | 362 | return rescale_keypoints 363 | 364 | 365 | def fix_lack_keypoints_use_sym(skeleton): 366 | 367 | keypoints = skeleton['keypoints_body'] 368 | H, W = skeleton['height'], skeleton['width'] 369 | 370 | limb_points_list = [ 371 | [3, 4, 5], 372 | [6, 7, 8], 373 | [12, 13, 14, 19], 374 | [9, 10, 11, 20], 375 | ] 376 | 377 | for limb_points in limb_points_list: 378 | miss_flag = False 379 | for point in limb_points: 380 | if keypoints[point - 1] is None: 381 | miss_flag = True 382 | continue 383 | if miss_flag: 384 | skeleton['keypoints_body'][point - 1] = None 385 | 386 | repair_limb_seq_left = [ 387 | [3, 4], [4, 5], # left arm 388 | [12, 13], [13, 14], # left leg 389 | [14, 19] # left foot 390 | ] 391 | 392 | repair_limb_seq_right = [ 393 | [6, 7], [7, 8], # right arm 394 | [9, 10], [10, 11], # right leg 395 | [11, 20] # right foot 396 | ] 397 | 398 | repair_limb_seq = [repair_limb_seq_left, repair_limb_seq_right] 399 | 400 | for idx_part, part in 
enumerate(repair_limb_seq): 401 | for idx, limb in enumerate(part): 402 | 403 | k1_index, k2_index = limb 404 | keypoint1 = keypoints[k1_index - 1] 405 | keypoint2 = keypoints[k2_index - 1] 406 | 407 | if keypoint1 != None and keypoint2 is None: 408 | # reference to symmetric limb 409 | sym_limb = repair_limb_seq[1-idx_part][idx] 410 | k1_index_sym, k2_index_sym = sym_limb 411 | keypoint1_sym = keypoints[k1_index_sym - 1] 412 | keypoint2_sym = keypoints[k2_index_sym - 1] 413 | ref_length = 0 414 | 415 | if keypoint1_sym != None and keypoint2_sym != None: 416 | X = np.array([keypoint1_sym[0], keypoint2_sym[0]]) * float(W) 417 | Y = np.array([keypoint1_sym[1], keypoint2_sym[1]]) * float(H) 418 | ref_length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 419 | else: 420 | ref_length_left, ref_length_right = 0, 0 421 | if keypoints[1] != None and keypoints[8] != None: 422 | X = np.array([keypoints[1][0], keypoints[8][0]]) * float(W) 423 | Y = np.array([keypoints[1][1], keypoints[8][1]]) * float(H) 424 | ref_length_left = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 425 | if idx <= 1: # arms 426 | ref_length_left /= 2 427 | 428 | if keypoints[1] != None and keypoints[11] != None: 429 | X = np.array([keypoints[1][0], keypoints[11][0]]) * float(W) 430 | Y = np.array([keypoints[1][1], keypoints[11][1]]) * float(H) 431 | ref_length_right = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 432 | if idx <= 1: # arms 433 | ref_length_right /= 2 434 | elif idx == 4: # foot 435 | ref_length_right /= 5 436 | 437 | ref_length = max(ref_length_left, ref_length_right) 438 | 439 | if ref_length != 0: 440 | skeleton['keypoints_body'][k2_index - 1] = [0, 0] #init 441 | skeleton['keypoints_body'][k2_index - 1][0] = skeleton['keypoints_body'][k1_index - 1][0] 442 | skeleton['keypoints_body'][k2_index - 1][1] = skeleton['keypoints_body'][k1_index - 1][1] + ref_length / H 443 | return skeleton 444 | 445 | 446 | def rescale_shorten_skeleton(ratio_list, src_length_list, dst_length_list): 447 | 448 | modify_bone_list = [ 449 | [0, 1], 450 | [2, 4], 451 | [3, 5], 452 | [6, 9], 453 | [7, 10], 454 | [8, 11], 455 | [17, 18] 456 | ] 457 | 458 | for modify_bone in modify_bone_list: 459 | new_ratio = max(ratio_list[modify_bone[0]], ratio_list[modify_bone[1]]) 460 | ratio_list[modify_bone[0]] = new_ratio 461 | ratio_list[modify_bone[1]] = new_ratio 462 | 463 | if ratio_list[13]!= None and ratio_list[15]!= None: 464 | ratio_eye_avg = (ratio_list[13] + ratio_list[15]) / 2 465 | ratio_list[13] = ratio_eye_avg 466 | ratio_list[15] = ratio_eye_avg 467 | 468 | if ratio_list[14]!= None and ratio_list[16]!= None: 469 | ratio_eye_avg = (ratio_list[14] + ratio_list[16]) / 2 470 | ratio_list[14] = ratio_eye_avg 471 | ratio_list[16] = ratio_eye_avg 472 | 473 | return ratio_list, src_length_list, dst_length_list 474 | 475 | 476 | 477 | def check_full_body(keypoints, threshold = 0.4): 478 | 479 | body_flag = 'half_body' 480 | 481 | # 1. If ankle points exist, confidence is greater than the threshold, and points do not exceed the frame, return full_body 482 | if keypoints[10] != None and keypoints[13] != None and keypoints[8] != None and keypoints[11] != None: 483 | if (keypoints[10][1] <= 1 and keypoints[13][1] <= 1) and (keypoints[10][2] >= threshold and keypoints[13][2] >= threshold) and \ 484 | (keypoints[8][1] <= 1 and keypoints[11][1] <= 1) and (keypoints[8][2] >= threshold and keypoints[11][2] >= threshold): 485 | body_flag = 'full_body' 486 | return body_flag 487 | 488 | # 2. 
If hip points exist, return three_quarter_body 489 | if (keypoints[8] != None and keypoints[11] != None): 490 | if (keypoints[8][1] <= 1 and keypoints[11][1] <= 1) and (keypoints[8][2] >= threshold and keypoints[11][2] >= threshold): 491 | body_flag = 'three_quarter_body' 492 | return body_flag 493 | 494 | return body_flag 495 | 496 | 497 | def check_full_body_both(flag1, flag2): 498 | body_flag_dict = { 499 | 'full_body': 2, 500 | 'three_quarter_body' : 1, 501 | 'half_body': 0 502 | } 503 | 504 | body_flag_dict_reverse = { 505 | 2: 'full_body', 506 | 1: 'three_quarter_body', 507 | 0: 'half_body' 508 | } 509 | 510 | flag1_num = body_flag_dict[flag1] 511 | flag2_num = body_flag_dict[flag2] 512 | flag_both_num = min(flag1_num, flag2_num) 513 | return body_flag_dict_reverse[flag_both_num] 514 | 515 | 516 | def write_to_poses(data_to_json, none_idx, dst_shape, bone_ratio_list, delta_ground_x, delta_ground_y, rescaled_src_ground_x, body_flag, scale_min): 517 | outputs = [] 518 | length = len(data_to_json) 519 | for id in tqdm(range(length)): 520 | 521 | src_height, src_width = data_to_json[id]['height'], data_to_json[id]['width'] 522 | width, height = dst_shape 523 | keypoints = data_to_json[id]['keypoints_body'] 524 | for idx in range(len(keypoints)): 525 | if idx in none_idx: 526 | keypoints[idx] = None 527 | new_keypoints = keypoints.copy() 528 | 529 | # get hand keypoints 530 | keypoints_hand = {'left' : data_to_json[id]['keypoints_left_hand'], 'right' : data_to_json[id]['keypoints_right_hand']} 531 | # Normalize hand coordinates to 0-1 range 532 | for hand_idx in range(len(data_to_json[id]['keypoints_left_hand'])): 533 | data_to_json[id]['keypoints_left_hand'][hand_idx][0] = data_to_json[id]['keypoints_left_hand'][hand_idx][0] / src_width 534 | data_to_json[id]['keypoints_left_hand'][hand_idx][1] = data_to_json[id]['keypoints_left_hand'][hand_idx][1] / src_height 535 | 536 | for hand_idx in range(len(data_to_json[id]['keypoints_right_hand'])): 537 | data_to_json[id]['keypoints_right_hand'][hand_idx][0] = data_to_json[id]['keypoints_right_hand'][hand_idx][0] / src_width 538 | data_to_json[id]['keypoints_right_hand'][hand_idx][1] = data_to_json[id]['keypoints_right_hand'][hand_idx][1] / src_height 539 | 540 | 541 | frame_info = get_scaled_pose((height, width), (src_height, src_width), new_keypoints, keypoints_hand, bone_ratio_list, delta_ground_x, delta_ground_y, rescaled_src_ground_x, body_flag, id, scale_min) 542 | outputs.append(frame_info) 543 | 544 | return outputs 545 | 546 | 547 | def calculate_scale_ratio(skeleton, skeleton_edit, scale_ratio_flag): 548 | if scale_ratio_flag: 549 | 550 | headw = max(skeleton['keypoints_body'][0][0], skeleton['keypoints_body'][14][0], skeleton['keypoints_body'][15][0], skeleton['keypoints_body'][16][0], skeleton['keypoints_body'][17][0]) - \ 551 | min(skeleton['keypoints_body'][0][0], skeleton['keypoints_body'][14][0], skeleton['keypoints_body'][15][0], skeleton['keypoints_body'][16][0], skeleton['keypoints_body'][17][0]) 552 | headw_edit = max(skeleton_edit['keypoints_body'][0][0], skeleton_edit['keypoints_body'][14][0], skeleton_edit['keypoints_body'][15][0], skeleton_edit['keypoints_body'][16][0], skeleton_edit['keypoints_body'][17][0]) - \ 553 | min(skeleton_edit['keypoints_body'][0][0], skeleton_edit['keypoints_body'][14][0], skeleton_edit['keypoints_body'][15][0], skeleton_edit['keypoints_body'][16][0], skeleton_edit['keypoints_body'][17][0]) 554 | headw_ratio = headw / headw_edit 555 | 556 | _, _, shoulder = get_length(skeleton, [6,3]) 557 | _, 
_, shoulder_edit = get_length(skeleton_edit, [6,3]) 558 | shoulder_ratio = shoulder / shoulder_edit 559 | 560 | return max(headw_ratio, shoulder_ratio) 561 | 562 | else: 563 | return 1 564 | 565 | 566 | 567 | def retarget_pose(src_skeleton, dst_skeleton, all_src_skeleton, src_skeleton_edit, dst_skeleton_edit, threshold=0.4): 568 | 569 | if src_skeleton_edit is not None and dst_skeleton_edit is not None: 570 | use_edit_for_base = True 571 | else: 572 | use_edit_for_base = False 573 | 574 | src_skeleton_ori = copy.deepcopy(src_skeleton) 575 | 576 | dst_skeleton_ori_h, dst_skeleton_ori_w = dst_skeleton['height'], dst_skeleton['width'] 577 | if src_skeleton['keypoints_body'][0] != None and src_skeleton['keypoints_body'][10] != None and src_skeleton['keypoints_body'][13] != None and \ 578 | dst_skeleton['keypoints_body'][0] != None and dst_skeleton['keypoints_body'][10] != None and dst_skeleton['keypoints_body'][13] != None and \ 579 | src_skeleton['keypoints_body'][0][2] > 0.5 and src_skeleton['keypoints_body'][10][2] > 0.5 and src_skeleton['keypoints_body'][13][2] > 0.5 and \ 580 | dst_skeleton['keypoints_body'][0][2] > 0.5 and dst_skeleton['keypoints_body'][10][2] > 0.5 and dst_skeleton['keypoints_body'][13][2] > 0.5: 581 | 582 | src_height = src_skeleton['height'] * abs( 583 | (src_skeleton['keypoints_body'][10][1] + src_skeleton['keypoints_body'][13][1]) / 2 - 584 | src_skeleton['keypoints_body'][0][1]) 585 | dst_height = dst_skeleton['height'] * abs( 586 | (dst_skeleton['keypoints_body'][10][1] + dst_skeleton['keypoints_body'][13][1]) / 2 - 587 | dst_skeleton['keypoints_body'][0][1]) 588 | scale_min = 1.0 * src_height / dst_height 589 | elif src_skeleton['keypoints_body'][0] != None and src_skeleton['keypoints_body'][8] != None and src_skeleton['keypoints_body'][11] != None and \ 590 | dst_skeleton['keypoints_body'][0] != None and dst_skeleton['keypoints_body'][8] != None and dst_skeleton['keypoints_body'][11] != None and \ 591 | src_skeleton['keypoints_body'][0][2] > 0.5 and src_skeleton['keypoints_body'][8][2] > 0.5 and src_skeleton['keypoints_body'][11][2] > 0.5 and \ 592 | dst_skeleton['keypoints_body'][0][2] > 0.5 and dst_skeleton['keypoints_body'][8][2] > 0.5 and dst_skeleton['keypoints_body'][11][2] > 0.5: 593 | 594 | src_height = src_skeleton['height'] * abs( 595 | (src_skeleton['keypoints_body'][8][1] + src_skeleton['keypoints_body'][11][1]) / 2 - 596 | src_skeleton['keypoints_body'][0][1]) 597 | dst_height = dst_skeleton['height'] * abs( 598 | (dst_skeleton['keypoints_body'][8][1] + dst_skeleton['keypoints_body'][11][1]) / 2 - 599 | dst_skeleton['keypoints_body'][0][1]) 600 | scale_min = 1.0 * src_height / dst_height 601 | else: 602 | scale_min = np.sqrt(src_skeleton['height'] * src_skeleton['width']) / np.sqrt(dst_skeleton['height'] * dst_skeleton['width']) 603 | 604 | if use_edit_for_base: 605 | scale_ratio_flag = False 606 | if src_skeleton_edit['keypoints_body'][0] != None and src_skeleton_edit['keypoints_body'][10] != None and src_skeleton_edit['keypoints_body'][13] != None and \ 607 | dst_skeleton_edit['keypoints_body'][0] != None and dst_skeleton_edit['keypoints_body'][10] != None and dst_skeleton_edit['keypoints_body'][13] != None and \ 608 | src_skeleton_edit['keypoints_body'][0][2] > 0.5 and src_skeleton_edit['keypoints_body'][10][2] > 0.5 and src_skeleton_edit['keypoints_body'][13][2] > 0.5 and \ 609 | dst_skeleton_edit['keypoints_body'][0][2] > 0.5 and dst_skeleton_edit['keypoints_body'][10][2] > 0.5 and dst_skeleton_edit['keypoints_body'][13][2] > 0.5: 610 | 611 | 
src_height_edit = src_skeleton_edit['height'] * abs( 612 | (src_skeleton_edit['keypoints_body'][10][1] + src_skeleton_edit['keypoints_body'][13][1]) / 2 - 613 | src_skeleton_edit['keypoints_body'][0][1]) 614 | dst_height_edit = dst_skeleton_edit['height'] * abs( 615 | (dst_skeleton_edit['keypoints_body'][10][1] + dst_skeleton_edit['keypoints_body'][13][1]) / 2 - 616 | dst_skeleton_edit['keypoints_body'][0][1]) 617 | scale_min_edit = 1.0 * src_height_edit / dst_height_edit 618 | elif src_skeleton_edit['keypoints_body'][0] != None and src_skeleton_edit['keypoints_body'][8] != None and src_skeleton_edit['keypoints_body'][11] != None and \ 619 | dst_skeleton_edit['keypoints_body'][0] != None and dst_skeleton_edit['keypoints_body'][8] != None and dst_skeleton_edit['keypoints_body'][11] != None and \ 620 | src_skeleton_edit['keypoints_body'][0][2] > 0.5 and src_skeleton_edit['keypoints_body'][8][2] > 0.5 and src_skeleton_edit['keypoints_body'][11][2] > 0.5 and \ 621 | dst_skeleton_edit['keypoints_body'][0][2] > 0.5 and dst_skeleton_edit['keypoints_body'][8][2] > 0.5 and dst_skeleton_edit['keypoints_body'][11][2] > 0.5: 622 | 623 | src_height_edit = src_skeleton_edit['height'] * abs( 624 | (src_skeleton_edit['keypoints_body'][8][1] + src_skeleton_edit['keypoints_body'][11][1]) / 2 - 625 | src_skeleton_edit['keypoints_body'][0][1]) 626 | dst_height_edit = dst_skeleton_edit['height'] * abs( 627 | (dst_skeleton_edit['keypoints_body'][8][1] + dst_skeleton_edit['keypoints_body'][11][1]) / 2 - 628 | dst_skeleton_edit['keypoints_body'][0][1]) 629 | scale_min_edit = 1.0 * src_height_edit / dst_height_edit 630 | else: 631 | scale_min_edit = np.sqrt(src_skeleton_edit['height'] * src_skeleton_edit['width']) / np.sqrt(dst_skeleton_edit['height'] * dst_skeleton_edit['width']) 632 | scale_ratio_flag = True 633 | 634 | # Flux may change the scale, compensate for it here 635 | ratio_src = calculate_scale_ratio(src_skeleton, src_skeleton_edit, scale_ratio_flag) 636 | ratio_dst = calculate_scale_ratio(dst_skeleton, dst_skeleton_edit, scale_ratio_flag) 637 | 638 | dst_skeleton_edit['height'] = int(dst_skeleton_edit['height'] * scale_min_edit) 639 | dst_skeleton_edit['width'] = int(dst_skeleton_edit['width'] * scale_min_edit) 640 | for idx in range(len(dst_skeleton_edit['keypoints_left_hand'])): 641 | dst_skeleton_edit['keypoints_left_hand'][idx][0] *= scale_min_edit 642 | dst_skeleton_edit['keypoints_left_hand'][idx][1] *= scale_min_edit 643 | for idx in range(len(dst_skeleton_edit['keypoints_right_hand'])): 644 | dst_skeleton_edit['keypoints_right_hand'][idx][0] *= scale_min_edit 645 | dst_skeleton_edit['keypoints_right_hand'][idx][1] *= scale_min_edit 646 | 647 | 648 | dst_skeleton['height'] = int(dst_skeleton['height'] * scale_min) 649 | dst_skeleton['width'] = int(dst_skeleton['width'] * scale_min) 650 | for idx in range(len(dst_skeleton['keypoints_left_hand'])): 651 | dst_skeleton['keypoints_left_hand'][idx][0] *= scale_min 652 | dst_skeleton['keypoints_left_hand'][idx][1] *= scale_min 653 | for idx in range(len(dst_skeleton['keypoints_right_hand'])): 654 | dst_skeleton['keypoints_right_hand'][idx][0] *= scale_min 655 | dst_skeleton['keypoints_right_hand'][idx][1] *= scale_min 656 | 657 | 658 | dst_body_flag = check_full_body(dst_skeleton['keypoints_body'], threshold) 659 | src_body_flag = check_full_body(src_skeleton_ori['keypoints_body'], threshold) 660 | body_flag = check_full_body_both(dst_body_flag, src_body_flag) 661 | #print('body_flag: ', body_flag) 662 | 663 | if use_edit_for_base: 664 | 
src_skeleton_edit = fix_lack_keypoints_use_sym(src_skeleton_edit) 665 | dst_skeleton_edit = fix_lack_keypoints_use_sym(dst_skeleton_edit) 666 | else: 667 | src_skeleton = fix_lack_keypoints_use_sym(src_skeleton) 668 | dst_skeleton = fix_lack_keypoints_use_sym(dst_skeleton) 669 | 670 | none_idx = [] 671 | for idx in range(len(dst_skeleton['keypoints_body'])): 672 | if dst_skeleton['keypoints_body'][idx] == None or src_skeleton['keypoints_body'][idx] == None: 673 | src_skeleton['keypoints_body'][idx] = None 674 | dst_skeleton['keypoints_body'][idx] = None 675 | none_idx.append(idx) 676 | 677 | # get bone ratio list 678 | ratio_list, src_length_list, dst_length_list = [], [], [] 679 | for idx, limb in enumerate(limbSeq): 680 | if use_edit_for_base: 681 | src_X, src_Y, src_length = get_length(src_skeleton_edit, limb) 682 | dst_X, dst_Y, dst_length = get_length(dst_skeleton_edit, limb) 683 | 684 | if src_X is None or src_Y is None or dst_X is None or dst_Y is None: 685 | ratio = -1 686 | else: 687 | ratio = 1.0 * dst_length * ratio_dst / src_length / ratio_src 688 | 689 | else: 690 | src_X, src_Y, src_length = get_length(src_skeleton, limb) 691 | dst_X, dst_Y, dst_length = get_length(dst_skeleton, limb) 692 | 693 | if src_X is None or src_Y is None or dst_X is None or dst_Y is None: 694 | ratio = -1 695 | else: 696 | ratio = 1.0 * dst_length / src_length 697 | 698 | ratio_list.append(ratio) 699 | src_length_list.append(src_length) 700 | dst_length_list.append(dst_length) 701 | 702 | for idx, ratio in enumerate(ratio_list): 703 | if ratio == -1: 704 | if ratio_list[0] != -1 and ratio_list[1] != -1: 705 | ratio_list[idx] = (ratio_list[0] + ratio_list[1]) / 2 706 | 707 | # Consider adding constraints when Flux fails to correct head pose, causing neck issues. 
708 | # if ratio_list[12] > (ratio_list[0]+ratio_list[1])/2*1.25: 709 | # ratio_list[12] = (ratio_list[0]+ratio_list[1])/2*1.25 710 | 711 | ratio_list, src_length_list, dst_length_list = rescale_shorten_skeleton(ratio_list, src_length_list, dst_length_list) 712 | 713 | rescaled_src_skeleton_ori = rescale_skeleton(src_skeleton_ori['height'], src_skeleton_ori['width'], 714 | src_skeleton_ori['keypoints_body'], ratio_list) 715 | 716 | # get global translation offset_x and offset_y 717 | if body_flag == 'full_body': 718 | #print('use foot mark.') 719 | dst_ground_y = max(dst_skeleton['keypoints_body'][10][1], dst_skeleton['keypoints_body'][13][1]) * dst_skeleton[ 720 | 'height'] 721 | # The midpoint between toe and ankle 722 | if dst_skeleton['keypoints_body'][18] != None and dst_skeleton['keypoints_body'][19] != None: 723 | right_foot_mid = (dst_skeleton['keypoints_body'][10][1] + dst_skeleton['keypoints_body'][19][1]) / 2 724 | left_foot_mid = (dst_skeleton['keypoints_body'][13][1] + dst_skeleton['keypoints_body'][18][1]) / 2 725 | dst_ground_y = max(left_foot_mid, right_foot_mid) * dst_skeleton['height'] 726 | 727 | rescaled_src_ground_y = max(rescaled_src_skeleton_ori[10][1], rescaled_src_skeleton_ori[13][1]) 728 | delta_ground_y = rescaled_src_ground_y - dst_ground_y 729 | 730 | dst_ground_x = (dst_skeleton['keypoints_body'][8][0] + dst_skeleton['keypoints_body'][11][0]) * dst_skeleton[ 731 | 'width'] / 2 732 | rescaled_src_ground_x = (rescaled_src_skeleton_ori[8][0] + rescaled_src_skeleton_ori[11][0]) / 2 733 | delta_ground_x = rescaled_src_ground_x - dst_ground_x 734 | delta_x, delta_y = delta_ground_x, delta_ground_y 735 | 736 | else: 737 | #print('use neck mark.') 738 | # use neck keypoint as mark 739 | src_neck_y = rescaled_src_skeleton_ori[1][1] 740 | dst_neck_y = dst_skeleton['keypoints_body'][1][1] 741 | delta_neck_y = src_neck_y - dst_neck_y * dst_skeleton['height'] 742 | 743 | src_neck_x = rescaled_src_skeleton_ori[1][0] 744 | dst_neck_x = dst_skeleton['keypoints_body'][1][0] 745 | delta_neck_x = src_neck_x - dst_neck_x * dst_skeleton['width'] 746 | delta_x, delta_y = delta_neck_x, delta_neck_y 747 | rescaled_src_ground_x = src_neck_x 748 | 749 | 750 | dst_shape = (dst_skeleton_ori_w, dst_skeleton_ori_h) 751 | output = write_to_poses(all_src_skeleton, none_idx, dst_shape, ratio_list, delta_x, delta_y, 752 | rescaled_src_ground_x, body_flag, scale_min) 753 | return output 754 | 755 | 756 | def get_retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas, tql_edit_pose_meta0, refer_edit_pose_meta): 757 | 758 | for key, value in tpl_pose_meta0.items(): 759 | if type(value) is np.ndarray: 760 | if key in ['keypoints_left_hand', 'keypoints_right_hand']: 761 | value = value * np.array([[tpl_pose_meta0["width"], tpl_pose_meta0["height"], 1.0]]) 762 | if not isinstance(value, list): 763 | value = value.tolist() 764 | tpl_pose_meta0[key] = value 765 | 766 | for key, value in refer_pose_meta.items(): 767 | if type(value) is np.ndarray: 768 | if key in ['keypoints_left_hand', 'keypoints_right_hand']: 769 | value = value * np.array([[refer_pose_meta["width"], refer_pose_meta["height"], 1.0]]) 770 | if not isinstance(value, list): 771 | value = value.tolist() 772 | refer_pose_meta[key] = value 773 | 774 | tpl_pose_metas_new = [] 775 | for meta in tpl_pose_metas: 776 | for key, value in meta.items(): 777 | if type(value) is np.ndarray: 778 | if key in ['keypoints_left_hand', 'keypoints_right_hand']: 779 | value = value * np.array([[meta["width"], meta["height"], 1.0]]) 780 | if not 
isinstance(value, list): 781 | value = value.tolist() 782 | meta[key] = value 783 | tpl_pose_metas_new.append(meta) 784 | 785 | if tql_edit_pose_meta0 is not None: 786 | for key, value in tql_edit_pose_meta0.items(): 787 | if type(value) is np.ndarray: 788 | if key in ['keypoints_left_hand', 'keypoints_right_hand']: 789 | value = value * np.array([[tql_edit_pose_meta0["width"], tql_edit_pose_meta0["height"], 1.0]]) 790 | if not isinstance(value, list): 791 | value = value.tolist() 792 | tql_edit_pose_meta0[key] = value 793 | 794 | if refer_edit_pose_meta is not None: 795 | for key, value in refer_edit_pose_meta.items(): 796 | if type(value) is np.ndarray: 797 | if key in ['keypoints_left_hand', 'keypoints_right_hand']: 798 | value = value * np.array([[refer_edit_pose_meta["width"], refer_edit_pose_meta["height"], 1.0]]) 799 | if not isinstance(value, list): 800 | value = value.tolist() 801 | refer_edit_pose_meta[key] = value 802 | 803 | retarget_tpl_pose_metas = retarget_pose(tpl_pose_meta0, refer_pose_meta, tpl_pose_metas_new, tql_edit_pose_meta0, refer_edit_pose_meta) 804 | 805 | pose_metas = [] 806 | for meta in retarget_tpl_pose_metas: 807 | pose_meta = AAPoseMeta() 808 | width, height = meta["width"], meta["height"] 809 | pose_meta.width = width 810 | pose_meta.height = height 811 | pose_meta.kps_body = np.array(meta["keypoints_body"])[:, :2] * (width, height) 812 | pose_meta.kps_body_p = np.array(meta["keypoints_body"])[:, 2] 813 | 814 | kps_lhand = [] 815 | kps_lhand_p = [] 816 | for each_kps_lhand in meta["keypoints_left_hand"]: 817 | if each_kps_lhand is not None: 818 | kps_lhand.append([each_kps_lhand.x, each_kps_lhand.y]) 819 | kps_lhand_p.append(each_kps_lhand.score) 820 | else: 821 | kps_lhand.append([None, None]) 822 | kps_lhand_p.append(0.0) 823 | 824 | pose_meta.kps_lhand = np.array(kps_lhand) 825 | pose_meta.kps_lhand_p = np.array(kps_lhand_p) 826 | 827 | kps_rhand = [] 828 | kps_rhand_p = [] 829 | for each_kps_rhand in meta["keypoints_right_hand"]: 830 | if each_kps_rhand is not None: 831 | kps_rhand.append([each_kps_rhand.x, each_kps_rhand.y]) 832 | kps_rhand_p.append(each_kps_rhand.score) 833 | else: 834 | kps_rhand.append([None, None]) 835 | kps_rhand_p.append(0.0) 836 | 837 | pose_meta.kps_rhand = np.array(kps_rhand) 838 | pose_meta.kps_rhand_p = np.array(kps_rhand_p) 839 | 840 | pose_metas.append(pose_meta) 841 | 842 | return pose_metas 843 | 844 | -------------------------------------------------------------------------------- /pose_utils/pose2d_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
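# --- Illustrative usage sketch (added comment, not part of the original file) ---
# Assuming `kp2ds` is a (133, 3) COCO-WholeBody keypoint array in pixel
# coordinates and `frame` is an (H, W, 3) uint8 image, a typical flow through
# this module might look like:
#
#     meta = AAPoseMeta.load_from_kp2ds(kp2ds, width=frame.shape[1], height=frame.shape[0])
#     canvas = meta.draw_aapose(np.zeros_like(frame), threshold=0.5)
#
# load_from_kp2ds() splits the 133 points into body/hand/face groups and
# draw_aapose() renders the skeleton via human_visualization.draw_aapose_by_meta.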
2 | import warnings 3 | import cv2 4 | import numpy as np 5 | from typing import List 6 | 7 | def box_convert_simple(box, convert_type='xyxy2xywh'): 8 | if convert_type == 'xyxy2xywh': 9 | return [box[0], box[1], box[2] - box[0], box[3] - box[1]] 10 | elif convert_type == 'xywh2xyxy': 11 | return [box[0], box[1], box[2] + box[0], box[3] + box[1]] 12 | elif convert_type == 'xyxy2ctwh': 13 | return [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2, box[2] - box[0], box[3] - box[1]] 14 | elif convert_type == 'ctwh2xyxy': 15 | return [box[0] - box[2] // 2, box[1] - box[3] // 2, box[0] + (box[2] - box[2] // 2), box[1] + (box[3] - box[3] // 2)] 16 | 17 | class AAPoseMeta: 18 | def __init__(self, meta=None, kp2ds=None): 19 | self.image_id = "" 20 | self.height = 0 21 | self.width = 0 22 | 23 | self.kps_body: np.ndarray = None 24 | self.kps_lhand: np.ndarray = None 25 | self.kps_rhand: np.ndarray = None 26 | self.kps_face: np.ndarray = None 27 | self.kps_body_p: np.ndarray = None 28 | self.kps_lhand_p: np.ndarray = None 29 | self.kps_rhand_p: np.ndarray = None 30 | self.kps_face_p: np.ndarray = None 31 | 32 | 33 | if meta is not None: 34 | self.load_from_meta(meta) 35 | elif kp2ds is not None: 36 | self.load_from_kp2ds(kp2ds) 37 | 38 | def is_valid(self, kp, p, threshold): 39 | x, y = kp 40 | if x < 0 or y < 0 or x > self.width or y > self.height or p < threshold: 41 | return False 42 | else: 43 | return True 44 | 45 | def get_bbox(self, kp, kp_p, threshold=0.5): 46 | kps = kp[kp_p > threshold] 47 | if kps.size == 0: 48 | return 0, 0, 0, 0 49 | x0, y0 = kps.min(axis=0) 50 | x1, y1 = kps.max(axis=0) 51 | return x0, y0, x1, y1 52 | 53 | def crop(self, x0, y0, x1, y1): 54 | all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face] 55 | for kps in all_kps: 56 | if kps is not None: 57 | kps[:, 0] -= x0 58 | kps[:, 1] -= y0 59 | self.width = x1 - x0 60 | self.height = y1 - y0 61 | return self 62 | 63 | def resize(self, width, height): 64 | scale_x = width / self.width 65 | scale_y = height / self.height 66 | all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face] 67 | for kps in all_kps: 68 | if kps is not None: 69 | kps[:, 0] *= scale_x 70 | kps[:, 1] *= scale_y 71 | self.width = width 72 | self.height = height 73 | return self 74 | 75 | 76 | def get_kps_body_with_p(self, normalize=False): 77 | kps_body = self.kps_body.copy() 78 | if normalize: 79 | kps_body = kps_body / np.array([self.width, self.height]) 80 | 81 | return np.concatenate([kps_body, self.kps_body_p[:, None]]) 82 | 83 | @staticmethod 84 | def from_kps_face(kps_face: np.ndarray, height: int, width: int): 85 | 86 | pose_meta = AAPoseMeta() 87 | pose_meta.kps_face = kps_face[:, :2] 88 | if kps_face.shape[1] == 3: 89 | pose_meta.kps_face_p = kps_face[:, 2] 90 | else: 91 | pose_meta.kps_face_p = kps_face[:, 0] * 0 + 1 92 | pose_meta.height = height 93 | pose_meta.width = width 94 | return pose_meta 95 | 96 | @staticmethod 97 | def from_kps_body(kps_body: np.ndarray, height: int, width: int): 98 | 99 | pose_meta = AAPoseMeta() 100 | pose_meta.kps_body = kps_body[:, :2] 101 | pose_meta.kps_body_p = kps_body[:, 2] 102 | pose_meta.height = height 103 | pose_meta.width = width 104 | return pose_meta 105 | @staticmethod 106 | def from_humanapi_meta(meta): 107 | pose_meta = AAPoseMeta() 108 | width, height = meta["width"], meta["height"] 109 | pose_meta.width = width 110 | pose_meta.height = height 111 | pose_meta.kps_body = meta["keypoints_body"][:, :2] * (width, height) 112 | pose_meta.kps_body_p = 
meta["keypoints_body"][:, 2] 113 | pose_meta.kps_lhand = meta["keypoints_left_hand"][:, :2] * (width, height) 114 | pose_meta.kps_lhand_p = meta["keypoints_left_hand"][:, 2] 115 | pose_meta.kps_rhand = meta["keypoints_right_hand"][:, :2] * (width, height) 116 | pose_meta.kps_rhand_p = meta["keypoints_right_hand"][:, 2] 117 | if 'keypoints_face' in meta: 118 | pose_meta.kps_face = meta["keypoints_face"][:, :2] * (width, height) 119 | pose_meta.kps_face_p = meta["keypoints_face"][:, 2] 120 | return pose_meta 121 | 122 | def load_from_meta(self, meta, norm_body=True, norm_hand=False): 123 | 124 | self.image_id = meta.get("image_id", "00000.png") 125 | self.height = meta["height"] 126 | self.width = meta["width"] 127 | kps_body_p = [] 128 | kps_body = [] 129 | for kp in meta["keypoints_body"]: 130 | if kp is None: 131 | kps_body.append([0, 0]) 132 | kps_body_p.append(0) 133 | else: 134 | kps_body.append(kp) 135 | kps_body_p.append(1) 136 | 137 | self.kps_body = np.array(kps_body) 138 | self.kps_body[:, 0] *= self.width 139 | self.kps_body[:, 1] *= self.height 140 | self.kps_body_p = np.array(kps_body_p) 141 | 142 | self.kps_lhand = np.array(meta["keypoints_left_hand"])[:, :2] 143 | self.kps_lhand_p = np.array(meta["keypoints_left_hand"])[:, 2] 144 | self.kps_rhand = np.array(meta["keypoints_right_hand"])[:, :2] 145 | self.kps_rhand_p = np.array(meta["keypoints_right_hand"])[:, 2] 146 | 147 | @staticmethod 148 | def load_from_kp2ds(kp2ds: List[np.ndarray], width: int, height: int): 149 | """input 133x3 numpy keypoints and output AAPoseMeta 150 | 151 | Args: 152 | kp2ds (List[np.ndarray]): _description_ 153 | width (int): _description_ 154 | height (int): _description_ 155 | 156 | Returns: 157 | _type_: _description_ 158 | """ 159 | pose_meta = AAPoseMeta() 160 | pose_meta.width = width 161 | pose_meta.height = height 162 | kps_body = (kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 163 | kps_lhand = kp2ds[91:112] 164 | kps_rhand = kp2ds[112:133] 165 | kps_face = np.concatenate([kp2ds[23:23+68], kp2ds[1:3]], axis=0) 166 | pose_meta.kps_body = kps_body[:, :2] 167 | pose_meta.kps_body_p = kps_body[:, 2] 168 | pose_meta.kps_lhand = kps_lhand[:, :2] 169 | pose_meta.kps_lhand_p = kps_lhand[:, 2] 170 | pose_meta.kps_rhand = kps_rhand[:, :2] 171 | pose_meta.kps_rhand_p = kps_rhand[:, 2] 172 | pose_meta.kps_face = kps_face[:, :2] 173 | pose_meta.kps_face_p = kps_face[:, 2] 174 | return pose_meta 175 | 176 | @staticmethod 177 | def from_dwpose(dwpose_det_res, height, width): 178 | pose_meta = AAPoseMeta() 179 | pose_meta.kps_body = dwpose_det_res["bodies"]["candidate"] 180 | pose_meta.kps_body_p = dwpose_det_res["bodies"]["score"] 181 | pose_meta.kps_body[:, 0] *= width 182 | pose_meta.kps_body[:, 1] *= height 183 | 184 | pose_meta.kps_lhand, pose_meta.kps_rhand = dwpose_det_res["hands"] 185 | pose_meta.kps_lhand[:, 0] *= width 186 | pose_meta.kps_lhand[:, 1] *= height 187 | pose_meta.kps_rhand[:, 0] *= width 188 | pose_meta.kps_rhand[:, 1] *= height 189 | pose_meta.kps_lhand_p, pose_meta.kps_rhand_p = dwpose_det_res["hands_score"] 190 | 191 | pose_meta.kps_face = dwpose_det_res["faces"][0] 192 | pose_meta.kps_face[:, 0] *= width 193 | pose_meta.kps_face[:, 1] *= height 194 | pose_meta.kps_face_p = dwpose_det_res["faces_score"][0] 195 | return pose_meta 196 | 197 | def save_json(self): 198 | pass 199 | 200 | def draw_aapose(self, img, threshold=0.5, stick_width_norm=200, draw_hand=True, 
draw_head=True): 201 | from .human_visualization import draw_aapose_by_meta 202 | return draw_aapose_by_meta(img, self, threshold, stick_width_norm, draw_hand, draw_head) 203 | 204 | 205 | def translate(self, x0, y0): 206 | all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face] 207 | for kps in all_kps: 208 | if kps is not None: 209 | kps[:, 0] -= x0 210 | kps[:, 1] -= y0 211 | 212 | def scale(self, sx, sy): 213 | all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face] 214 | for kps in all_kps: 215 | if kps is not None: 216 | kps[:, 0] *= sx 217 | kps[:, 1] *= sy 218 | 219 | def padding_resize2(self, height=512, width=512): 220 | """kps will be changed inplace 221 | 222 | """ 223 | 224 | all_kps = [self.kps_body, self.kps_lhand, self.kps_rhand, self.kps_face] 225 | 226 | ori_height, ori_width = self.height, self.width 227 | 228 | if (ori_height / ori_width) > (height / width): 229 | new_width = int(height / ori_height * ori_width) 230 | padding = int((width - new_width) / 2) 231 | padding_width = padding 232 | padding_height = 0 233 | scale = height / ori_height 234 | 235 | for kps in all_kps: 236 | if kps is not None: 237 | kps[:, 0] = kps[:, 0] * scale + padding 238 | kps[:, 1] = kps[:, 1] * scale 239 | 240 | else: 241 | new_height = int(width / ori_width * ori_height) 242 | padding = int((height - new_height) / 2) 243 | padding_width = 0 244 | padding_height = padding 245 | scale = width / ori_width 246 | for kps in all_kps: 247 | if kps is not None: 248 | kps[:, 1] = kps[:, 1] * scale + padding 249 | kps[:, 0] = kps[:, 0] * scale 250 | 251 | 252 | self.width = width 253 | self.height = height 254 | return self 255 | 256 | 257 | def transform_preds(coords, center, scale, output_size, use_udp=False): 258 | """Get final keypoint predictions from heatmaps and apply scaling and 259 | translation to map them back to the image. 260 | 261 | Note: 262 | num_keypoints: K 263 | 264 | Args: 265 | coords (np.ndarray[K, ndims]): 266 | 267 | * If ndims=2, corrds are predicted keypoint location. 268 | * If ndims=4, corrds are composed of (x, y, scores, tags) 269 | * If ndims=5, corrds are composed of (x, y, scores, tags, 270 | flipped_tags) 271 | 272 | center (np.ndarray[2, ]): Center of the bounding box (x, y). 273 | scale (np.ndarray[2, ]): Scale of the bounding box 274 | wrt [width, height]. 275 | output_size (np.ndarray[2, ] | list(2,)): Size of the 276 | destination heatmaps. 277 | use_udp (bool): Use unbiased data processing 278 | 279 | Returns: 280 | np.ndarray: Predicted coordinates in the images. 281 | """ 282 | assert coords.shape[1] in (2, 4, 5) 283 | assert len(center) == 2 284 | assert len(scale) == 2 285 | assert len(output_size) == 2 286 | 287 | # Recover the scale which is normalized by a factor of 200. 288 | # scale = scale * 200.0 289 | 290 | if use_udp: 291 | scale_x = scale[0] / (output_size[0] - 1.0) 292 | scale_y = scale[1] / (output_size[1] - 1.0) 293 | else: 294 | scale_x = scale[0] / output_size[0] 295 | scale_y = scale[1] / output_size[1] 296 | 297 | target_coords = np.ones_like(coords) 298 | target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5 299 | target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5 300 | 301 | return target_coords 302 | 303 | 304 | def _calc_distances(preds, targets, mask, normalize): 305 | """Calculate the normalized distances between preds and target. 
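Each distance is the Euclidean norm of (pred - target) / normalize for the corresponding sample, i.e. the error is measured relative to the per-sample normalization factor (typically the heatmap or bounding-box size).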
306 | 307 | Note: 308 | batch_size: N 309 | num_keypoints: K 310 | dimension of keypoints: D (normally, D=2 or D=3) 311 | 312 | Args: 313 | preds (np.ndarray[N, K, D]): Predicted keypoint location. 314 | targets (np.ndarray[N, K, D]): Groundtruth keypoint location. 315 | mask (np.ndarray[N, K]): Visibility of the target. False for invisible 316 | joints, and True for visible. Invisible joints will be ignored for 317 | accuracy calculation. 318 | normalize (np.ndarray[N, D]): Typical value is heatmap_size 319 | 320 | Returns: 321 | np.ndarray[K, N]: The normalized distances. \ 322 | If target keypoints are missing, the distance is -1. 323 | """ 324 | N, K, _ = preds.shape 325 | # set mask=0 when normalize==0 326 | _mask = mask.copy() 327 | _mask[np.where((normalize == 0).sum(1))[0], :] = False 328 | distances = np.full((N, K), -1, dtype=np.float32) 329 | # handle invalid values 330 | normalize[np.where(normalize <= 0)] = 1e6 331 | distances[_mask] = np.linalg.norm( 332 | ((preds - targets) / normalize[:, None, :])[_mask], axis=-1) 333 | return distances.T 334 | 335 | 336 | def _distance_acc(distances, thr=0.5): 337 | """Return the percentage below the distance threshold, while ignoring 338 | distances values with -1. 339 | 340 | Note: 341 | batch_size: N 342 | Args: 343 | distances (np.ndarray[N, ]): The normalized distances. 344 | thr (float): Threshold of the distances. 345 | 346 | Returns: 347 | float: Percentage of distances below the threshold. \ 348 | If all target keypoints are missing, return -1. 349 | """ 350 | distance_valid = distances != -1 351 | num_distance_valid = distance_valid.sum() 352 | if num_distance_valid > 0: 353 | return (distances[distance_valid] < thr).sum() / num_distance_valid 354 | return -1 355 | 356 | 357 | def _get_max_preds(heatmaps): 358 | """Get keypoint predictions from score maps. 359 | 360 | Note: 361 | batch_size: N 362 | num_keypoints: K 363 | heatmap height: H 364 | heatmap width: W 365 | 366 | Args: 367 | heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. 368 | 369 | Returns: 370 | tuple: A tuple containing aggregated results. 371 | 372 | - preds (np.ndarray[N, K, 2]): Predicted keypoint location. 373 | - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. 374 | """ 375 | assert isinstance(heatmaps, 376 | np.ndarray), ('heatmaps should be numpy.ndarray') 377 | assert heatmaps.ndim == 4, 'batch_images should be 4-ndim' 378 | 379 | N, K, _, W = heatmaps.shape 380 | heatmaps_reshaped = heatmaps.reshape((N, K, -1)) 381 | idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1)) 382 | maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1)) 383 | 384 | preds = np.tile(idx, (1, 1, 2)).astype(np.float32) 385 | preds[:, :, 0] = preds[:, :, 0] % W 386 | preds[:, :, 1] = preds[:, :, 1] // W 387 | 388 | preds = np.where(np.tile(maxvals, (1, 1, 2)) > 0.0, preds, -1) 389 | return preds, maxvals 390 | 391 | 392 | def _get_max_preds_3d(heatmaps): 393 | """Get keypoint predictions from 3D score maps. 394 | 395 | Note: 396 | batch size: N 397 | num keypoints: K 398 | heatmap depth size: D 399 | heatmap height: H 400 | heatmap width: W 401 | 402 | Args: 403 | heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps. 404 | 405 | Returns: 406 | tuple: A tuple containing aggregated results. 407 | 408 | - preds (np.ndarray[N, K, 3]): Predicted keypoint location. 409 | - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. 
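For reference, the flat argmax index i computed below is decoded as z = i // (H * W), y = (i // W) % H, x = i % W.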
410 | """ 411 | assert isinstance(heatmaps, np.ndarray), \ 412 | ('heatmaps should be numpy.ndarray') 413 | assert heatmaps.ndim == 5, 'heatmaps should be 5-ndim' 414 | 415 | N, K, D, H, W = heatmaps.shape 416 | heatmaps_reshaped = heatmaps.reshape((N, K, -1)) 417 | idx = np.argmax(heatmaps_reshaped, 2).reshape((N, K, 1)) 418 | maxvals = np.amax(heatmaps_reshaped, 2).reshape((N, K, 1)) 419 | 420 | preds = np.zeros((N, K, 3), dtype=np.float32) 421 | _idx = idx[..., 0] 422 | preds[..., 2] = _idx // (H * W) 423 | preds[..., 1] = (_idx // W) % H 424 | preds[..., 0] = _idx % W 425 | 426 | preds = np.where(maxvals > 0.0, preds, -1) 427 | return preds, maxvals 428 | 429 | 430 | def pose_pck_accuracy(output, target, mask, thr=0.05, normalize=None): 431 | """Calculate the pose accuracy of PCK for each individual keypoint and the 432 | averaged accuracy across all keypoints from heatmaps. 433 | 434 | Note: 435 | PCK metric measures accuracy of the localization of the body joints. 436 | The distances between predicted positions and the ground-truth ones 437 | are typically normalized by the bounding box size. 438 | The threshold (thr) of the normalized distance is commonly set 439 | as 0.05, 0.1 or 0.2 etc. 440 | 441 | - batch_size: N 442 | - num_keypoints: K 443 | - heatmap height: H 444 | - heatmap width: W 445 | 446 | Args: 447 | output (np.ndarray[N, K, H, W]): Model output heatmaps. 448 | target (np.ndarray[N, K, H, W]): Groundtruth heatmaps. 449 | mask (np.ndarray[N, K]): Visibility of the target. False for invisible 450 | joints, and True for visible. Invisible joints will be ignored for 451 | accuracy calculation. 452 | thr (float): Threshold of PCK calculation. Default 0.05. 453 | normalize (np.ndarray[N, 2]): Normalization factor for H&W. 454 | 455 | Returns: 456 | tuple: A tuple containing keypoint accuracy. 457 | 458 | - np.ndarray[K]: Accuracy of each keypoint. 459 | - float: Averaged accuracy across all keypoints. 460 | - int: Number of valid keypoints. 461 | """ 462 | N, K, H, W = output.shape 463 | if K == 0: 464 | return None, 0, 0 465 | if normalize is None: 466 | normalize = np.tile(np.array([[H, W]]), (N, 1)) 467 | 468 | pred, _ = _get_max_preds(output) 469 | gt, _ = _get_max_preds(target) 470 | return keypoint_pck_accuracy(pred, gt, mask, thr, normalize) 471 | 472 | 473 | def keypoint_pck_accuracy(pred, gt, mask, thr, normalize): 474 | """Calculate the pose accuracy of PCK for each individual keypoint and the 475 | averaged accuracy across all keypoints for coordinates. 476 | 477 | Note: 478 | PCK metric measures accuracy of the localization of the body joints. 479 | The distances between predicted positions and the ground-truth ones 480 | are typically normalized by the bounding box size. 481 | The threshold (thr) of the normalized distance is commonly set 482 | as 0.05, 0.1 or 0.2 etc. 483 | 484 | - batch_size: N 485 | - num_keypoints: K 486 | 487 | Args: 488 | pred (np.ndarray[N, K, 2]): Predicted keypoint location. 489 | gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. 490 | mask (np.ndarray[N, K]): Visibility of the target. False for invisible 491 | joints, and True for visible. Invisible joints will be ignored for 492 | accuracy calculation. 493 | thr (float): Threshold of PCK calculation. 494 | normalize (np.ndarray[N, 2]): Normalization factor for H&W. 495 | 496 | Returns: 497 | tuple: A tuple containing keypoint accuracy. 498 | 499 | - acc (np.ndarray[K]): Accuracy of each keypoint. 500 | - avg_acc (float): Averaged accuracy across all keypoints. 
501 | - cnt (int): Number of valid keypoints. 502 | """ 503 | distances = _calc_distances(pred, gt, mask, normalize) 504 | 505 | acc = np.array([_distance_acc(d, thr) for d in distances]) 506 | valid_acc = acc[acc >= 0] 507 | cnt = len(valid_acc) 508 | avg_acc = valid_acc.mean() if cnt > 0 else 0 509 | return acc, avg_acc, cnt 510 | 511 | 512 | def keypoint_auc(pred, gt, mask, normalize, num_step=20): 513 | """Calculate the pose accuracy of PCK for each individual keypoint and the 514 | averaged accuracy across all keypoints for coordinates. 515 | 516 | Note: 517 | - batch_size: N 518 | - num_keypoints: K 519 | 520 | Args: 521 | pred (np.ndarray[N, K, 2]): Predicted keypoint location. 522 | gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. 523 | mask (np.ndarray[N, K]): Visibility of the target. False for invisible 524 | joints, and True for visible. Invisible joints will be ignored for 525 | accuracy calculation. 526 | normalize (float): Normalization factor. 527 | 528 | Returns: 529 | float: Area under curve. 530 | """ 531 | nor = np.tile(np.array([[normalize, normalize]]), (pred.shape[0], 1)) 532 | x = [1.0 * i / num_step for i in range(num_step)] 533 | y = [] 534 | for thr in x: 535 | _, avg_acc, _ = keypoint_pck_accuracy(pred, gt, mask, thr, nor) 536 | y.append(avg_acc) 537 | 538 | auc = 0 539 | for i in range(num_step): 540 | auc += 1.0 / num_step * y[i] 541 | return auc 542 | 543 | 544 | def keypoint_nme(pred, gt, mask, normalize_factor): 545 | """Calculate the normalized mean error (NME). 546 | 547 | Note: 548 | - batch_size: N 549 | - num_keypoints: K 550 | 551 | Args: 552 | pred (np.ndarray[N, K, 2]): Predicted keypoint location. 553 | gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. 554 | mask (np.ndarray[N, K]): Visibility of the target. False for invisible 555 | joints, and True for visible. Invisible joints will be ignored for 556 | accuracy calculation. 557 | normalize_factor (np.ndarray[N, 2]): Normalization factor. 558 | 559 | Returns: 560 | float: normalized mean error 561 | """ 562 | distances = _calc_distances(pred, gt, mask, normalize_factor) 563 | distance_valid = distances[distances != -1] 564 | return distance_valid.sum() / max(1, len(distance_valid)) 565 | 566 | 567 | def keypoint_epe(pred, gt, mask): 568 | """Calculate the end-point error. 569 | 570 | Note: 571 | - batch_size: N 572 | - num_keypoints: K 573 | 574 | Args: 575 | pred (np.ndarray[N, K, 2]): Predicted keypoint location. 576 | gt (np.ndarray[N, K, 2]): Groundtruth keypoint location. 577 | mask (np.ndarray[N, K]): Visibility of the target. False for invisible 578 | joints, and True for visible. Invisible joints will be ignored for 579 | accuracy calculation. 580 | 581 | Returns: 582 | float: Average end-point error. 583 | """ 584 | 585 | distances = _calc_distances( 586 | pred, gt, mask, 587 | np.ones((pred.shape[0], pred.shape[2]), dtype=np.float32)) 588 | distance_valid = distances[distances != -1] 589 | return distance_valid.sum() / max(1, len(distance_valid)) 590 | 591 | 592 | def _taylor(heatmap, coord): 593 | """Distribution aware coordinate decoding method. 594 | 595 | Note: 596 | - heatmap height: H 597 | - heatmap width: W 598 | 599 | Args: 600 | heatmap (np.ndarray[H, W]): Heatmap of a particular joint type. 601 | coord (np.ndarray[2,]): Coordinates of the predicted keypoints. 602 | 603 | Returns: 604 | np.ndarray[2,]: Updated coordinates. 
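The update is a single Newton step, coord -= H^-1 @ g, where the gradient g and Hessian H are estimated with central finite differences; it is only applied when the point is at least 2 px from the border and the Hessian is non-singular (dxx * dyy - dxy**2 != 0).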
605 | """ 606 | H, W = heatmap.shape[:2] 607 | px, py = int(coord[0]), int(coord[1]) 608 | if 1 < px < W - 2 and 1 < py < H - 2: 609 | dx = 0.5 * (heatmap[py][px + 1] - heatmap[py][px - 1]) 610 | dy = 0.5 * (heatmap[py + 1][px] - heatmap[py - 1][px]) 611 | dxx = 0.25 * ( 612 | heatmap[py][px + 2] - 2 * heatmap[py][px] + heatmap[py][px - 2]) 613 | dxy = 0.25 * ( 614 | heatmap[py + 1][px + 1] - heatmap[py - 1][px + 1] - 615 | heatmap[py + 1][px - 1] + heatmap[py - 1][px - 1]) 616 | dyy = 0.25 * ( 617 | heatmap[py + 2 * 1][px] - 2 * heatmap[py][px] + 618 | heatmap[py - 2 * 1][px]) 619 | derivative = np.array([[dx], [dy]]) 620 | hessian = np.array([[dxx, dxy], [dxy, dyy]]) 621 | if dxx * dyy - dxy**2 != 0: 622 | hessianinv = np.linalg.inv(hessian) 623 | offset = -hessianinv @ derivative 624 | offset = np.squeeze(np.array(offset.T), axis=0) 625 | coord += offset 626 | return coord 627 | 628 | 629 | def post_dark_udp(coords, batch_heatmaps, kernel=3): 630 | """DARK post-processing. Implemented by udp. Paper ref: Huang et al. The 631 | Devil is in the Details: Delving into Unbiased Data Processing for Human 632 | Pose Estimation (CVPR 2020). Zhang et al. Distribution-Aware Coordinate 633 | Representation for Human Pose Estimation (CVPR 2020). 634 | 635 | Note: 636 | - batch size: B 637 | - num keypoints: K 638 | - num persons: N 639 | - height of heatmaps: H 640 | - width of heatmaps: W 641 | 642 | B=1 for bottom_up paradigm where all persons share the same heatmap. 643 | B=N for top_down paradigm where each person has its own heatmaps. 644 | 645 | Args: 646 | coords (np.ndarray[N, K, 2]): Initial coordinates of human pose. 647 | batch_heatmaps (np.ndarray[B, K, H, W]): Batch of predicted heatmaps. 648 | kernel (int): Gaussian kernel size (K) for modulation. 649 | 650 | Returns: 651 | np.ndarray([N, K, 2]): Refined coordinates.
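The refinement below modulates the heatmaps with a Gaussian blur, clips and log-transforms them, then takes one Newton step per keypoint using finite differences computed on an edge-padded copy of the flattened heatmaps.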
652 | """ 653 | if not isinstance(batch_heatmaps, np.ndarray): 654 | batch_heatmaps = batch_heatmaps.cpu().numpy() 655 | B, K, H, W = batch_heatmaps.shape 656 | N = coords.shape[0] 657 | assert (B == 1 or B == N) 658 | for heatmaps in batch_heatmaps: 659 | for heatmap in heatmaps: 660 | cv2.GaussianBlur(heatmap, (kernel, kernel), 0, heatmap) 661 | np.clip(batch_heatmaps, 0.001, 50, batch_heatmaps) 662 | np.log(batch_heatmaps, batch_heatmaps) 663 | 664 | batch_heatmaps_pad = np.pad( 665 | batch_heatmaps, ((0, 0), (0, 0), (1, 1), (1, 1)), 666 | mode='edge').flatten() 667 | 668 | index = coords[..., 0] + 1 + (coords[..., 1] + 1) * (W + 2) 669 | index += (W + 2) * (H + 2) * np.arange(0, B * K).reshape(-1, K) 670 | index = index.astype(int).reshape(-1, 1) 671 | i_ = batch_heatmaps_pad[index] 672 | ix1 = batch_heatmaps_pad[index + 1] 673 | iy1 = batch_heatmaps_pad[index + W + 2] 674 | ix1y1 = batch_heatmaps_pad[index + W + 3] 675 | ix1_y1_ = batch_heatmaps_pad[index - W - 3] 676 | ix1_ = batch_heatmaps_pad[index - 1] 677 | iy1_ = batch_heatmaps_pad[index - 2 - W] 678 | 679 | dx = 0.5 * (ix1 - ix1_) 680 | dy = 0.5 * (iy1 - iy1_) 681 | derivative = np.concatenate([dx, dy], axis=1) 682 | derivative = derivative.reshape(N, K, 2, 1) 683 | dxx = ix1 - 2 * i_ + ix1_ 684 | dyy = iy1 - 2 * i_ + iy1_ 685 | dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_) 686 | hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1) 687 | hessian = hessian.reshape(N, K, 2, 2) 688 | hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2)) 689 | coords -= np.einsum('ijmn,ijnk->ijmk', hessian, derivative).squeeze() 690 | return coords 691 | 692 | 693 | def _gaussian_blur(heatmaps, kernel=11): 694 | """Modulate heatmap distribution with Gaussian. 695 | sigma = 0.3*((kernel_size-1)*0.5-1)+0.8 696 | sigma~=3 if k=17 697 | sigma=2 if k=11; 698 | sigma~=1.5 if k=7; 699 | sigma~=1 if k=3; 700 | 701 | Note: 702 | - batch_size: N 703 | - num_keypoints: K 704 | - heatmap height: H 705 | - heatmap width: W 706 | 707 | Args: 708 | heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. 709 | kernel (int): Gaussian kernel size (K) for modulation, which should 710 | match the heatmap gaussian sigma when training. 711 | K=17 for sigma=3 and k=11 for sigma=2. 712 | 713 | Returns: 714 | np.ndarray ([N, K, H, W]): Modulated heatmap distribution. 715 | """ 716 | assert kernel % 2 == 1 717 | 718 | border = (kernel - 1) // 2 719 | batch_size = heatmaps.shape[0] 720 | num_joints = heatmaps.shape[1] 721 | height = heatmaps.shape[2] 722 | width = heatmaps.shape[3] 723 | for i in range(batch_size): 724 | for j in range(num_joints): 725 | origin_max = np.max(heatmaps[i, j]) 726 | dr = np.zeros((height + 2 * border, width + 2 * border), 727 | dtype=np.float32) 728 | dr[border:-border, border:-border] = heatmaps[i, j].copy() 729 | dr = cv2.GaussianBlur(dr, (kernel, kernel), 0) 730 | heatmaps[i, j] = dr[border:-border, border:-border].copy() 731 | heatmaps[i, j] *= origin_max / np.max(heatmaps[i, j]) 732 | return heatmaps 733 | 734 | 735 | def keypoints_from_regression(regression_preds, center, scale, img_size): 736 | """Get final keypoint predictions from regression vectors and transform 737 | them back to the image. 738 | 739 | Note: 740 | - batch_size: N 741 | - num_keypoints: K 742 | 743 | Args: 744 | regression_preds (np.ndarray[N, K, 2]): model prediction. 745 | center (np.ndarray[N, 2]): Center of the bounding box (x, y). 746 | scale (np.ndarray[N, 2]): Scale of the bounding box 747 | wrt height/width. 
748 | img_size (list(img_width, img_height)): model input image size. 749 | 750 | Returns: 751 | tuple: 752 | 753 | - preds (np.ndarray[N, K, 2]): Predicted keypoint location in images. 754 | - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. 755 | """ 756 | N, K, _ = regression_preds.shape 757 | preds, maxvals = regression_preds, np.ones((N, K, 1), dtype=np.float32) 758 | 759 | preds = preds * img_size 760 | 761 | # Transform back to the image 762 | for i in range(N): 763 | preds[i] = transform_preds(preds[i], center[i], scale[i], img_size) 764 | 765 | return preds, maxvals 766 | 767 | 768 | def keypoints_from_heatmaps(heatmaps, 769 | center, 770 | scale, 771 | unbiased=False, 772 | post_process='default', 773 | kernel=11, 774 | valid_radius_factor=0.0546875, 775 | use_udp=False, 776 | target_type='GaussianHeatmap'): 777 | """Get final keypoint predictions from heatmaps and transform them back to 778 | the image. 779 | 780 | Note: 781 | - batch size: N 782 | - num keypoints: K 783 | - heatmap height: H 784 | - heatmap width: W 785 | 786 | Args: 787 | heatmaps (np.ndarray[N, K, H, W]): model predicted heatmaps. 788 | center (np.ndarray[N, 2]): Center of the bounding box (x, y). 789 | scale (np.ndarray[N, 2]): Scale of the bounding box 790 | wrt height/width. 791 | post_process (str/None): Choice of methods to post-process 792 | heatmaps. Currently supported: None, 'default', 'unbiased', 793 | 'megvii'. 794 | unbiased (bool): Option to use unbiased decoding. Mutually 795 | exclusive with megvii. 796 | Note: this arg is deprecated and unbiased=True can be replaced 797 | by post_process='unbiased' 798 | Paper ref: Zhang et al. Distribution-Aware Coordinate 799 | Representation for Human Pose Estimation (CVPR 2020). 800 | kernel (int): Gaussian kernel size (K) for modulation, which should 801 | match the heatmap gaussian sigma when training. 802 | K=17 for sigma=3 and k=11 for sigma=2. 803 | valid_radius_factor (float): The radius factor of the positive area 804 | in classification heatmap for UDP. 805 | use_udp (bool): Use unbiased data processing. 806 | target_type (str): 'GaussianHeatmap' or 'CombinedTarget'. 807 | GaussianHeatmap: Classification target with gaussian distribution. 808 | CombinedTarget: The combination of classification target 809 | (response map) and regression target (offset map). 810 | Paper ref: Huang et al. The Devil is in the Details: Delving into 811 | Unbiased Data Processing for Human Pose Estimation (CVPR 2020). 812 | 813 | Returns: 814 | tuple: A tuple containing keypoint predictions and scores. 815 | 816 | - preds (np.ndarray[N, K, 2]): Predicted keypoint location in images. 817 | - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. 
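Example (illustrative sketch; shapes and values are assumed, not taken from the calling code):

>>> # heatmaps: (N, K, H, W) output of the pose head,
>>> # center, scale: (N, 2) arrays describing each person box
>>> preds, maxvals = keypoints_from_heatmaps(
...     heatmaps, center, scale, post_process='unbiased', kernel=11)
>>> # preds: (N, K, 2) keypoint locations in image space, maxvals: (N, K, 1)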
818 | """ 819 | # Avoid being affected 820 | heatmaps = heatmaps.copy() 821 | 822 | # detect conflicts 823 | if unbiased: 824 | assert post_process not in [False, None, 'megvii'] 825 | if post_process in ['megvii', 'unbiased']: 826 | assert kernel > 0 827 | if use_udp: 828 | assert not post_process == 'megvii' 829 | 830 | # normalize configs 831 | if post_process is False: 832 | warnings.warn( 833 | 'post_process=False is deprecated, ' 834 | 'please use post_process=None instead', DeprecationWarning) 835 | post_process = None 836 | elif post_process is True: 837 | if unbiased is True: 838 | warnings.warn( 839 | 'post_process=True, unbiased=True is deprecated,' 840 | " please use post_process='unbiased' instead", 841 | DeprecationWarning) 842 | post_process = 'unbiased' 843 | else: 844 | warnings.warn( 845 | 'post_process=True, unbiased=False is deprecated, ' 846 | "please use post_process='default' instead", 847 | DeprecationWarning) 848 | post_process = 'default' 849 | elif post_process == 'default': 850 | if unbiased is True: 851 | warnings.warn( 852 | 'unbiased=True is deprecated, please use ' 853 | "post_process='unbiased' instead", DeprecationWarning) 854 | post_process = 'unbiased' 855 | 856 | # start processing 857 | if post_process == 'megvii': 858 | heatmaps = _gaussian_blur(heatmaps, kernel=kernel) 859 | 860 | N, K, H, W = heatmaps.shape 861 | if use_udp: 862 | if target_type.lower() == 'GaussianHeatMap'.lower(): 863 | preds, maxvals = _get_max_preds(heatmaps) 864 | preds = post_dark_udp(preds, heatmaps, kernel=kernel) 865 | elif target_type.lower() == 'CombinedTarget'.lower(): 866 | for person_heatmaps in heatmaps: 867 | for i, heatmap in enumerate(person_heatmaps): 868 | kt = 2 * kernel + 1 if i % 3 == 0 else kernel 869 | cv2.GaussianBlur(heatmap, (kt, kt), 0, heatmap) 870 | # valid radius is in direct proportion to the height of heatmap. 871 | valid_radius = valid_radius_factor * H 872 | offset_x = heatmaps[:, 1::3, :].flatten() * valid_radius 873 | offset_y = heatmaps[:, 2::3, :].flatten() * valid_radius 874 | heatmaps = heatmaps[:, ::3, :] 875 | preds, maxvals = _get_max_preds(heatmaps) 876 | index = preds[..., 0] + preds[..., 1] * W 877 | index += W * H * np.arange(0, N * K / 3) 878 | index = index.astype(int).reshape(N, K // 3, 1) 879 | preds += np.concatenate((offset_x[index], offset_y[index]), axis=2) 880 | else: 881 | raise ValueError('target_type should be either ' 882 | "'GaussianHeatmap' or 'CombinedTarget'") 883 | else: 884 | preds, maxvals = _get_max_preds(heatmaps) 885 | if post_process == 'unbiased': # alleviate biased coordinate 886 | # apply Gaussian distribution modulation. 887 | heatmaps = np.log( 888 | np.maximum(_gaussian_blur(heatmaps, kernel), 1e-10)) 889 | for n in range(N): 890 | for k in range(K): 891 | preds[n][k] = _taylor(heatmaps[n][k], preds[n][k]) 892 | elif post_process is not None: 893 | # add +/-0.25 shift to the predicted locations for higher acc. 
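# For each keypoint the integer argmax is nudged by 0.25 px towards whichever
# neighboring heatmap value is larger in x and in y; the 'megvii' variant
# additionally adds a constant 0.5 px offset.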
894 | for n in range(N): 895 | for k in range(K): 896 | heatmap = heatmaps[n][k] 897 | px = int(preds[n][k][0]) 898 | py = int(preds[n][k][1]) 899 | if 1 < px < W - 1 and 1 < py < H - 1: 900 | diff = np.array([ 901 | heatmap[py][px + 1] - heatmap[py][px - 1], 902 | heatmap[py + 1][px] - heatmap[py - 1][px] 903 | ]) 904 | preds[n][k] += np.sign(diff) * .25 905 | if post_process == 'megvii': 906 | preds[n][k] += 0.5 907 | 908 | # Transform back to the image 909 | for i in range(N): 910 | preds[i] = transform_preds( 911 | preds[i], center[i], scale[i], [W, H], use_udp=use_udp) 912 | 913 | if post_process == 'megvii': 914 | maxvals = maxvals / 255.0 + 0.5 915 | 916 | return preds, maxvals 917 | 918 | 919 | def keypoints_from_heatmaps3d(heatmaps, center, scale): 920 | """Get final keypoint predictions from 3d heatmaps and transform them back 921 | to the image. 922 | 923 | Note: 924 | - batch size: N 925 | - num keypoints: K 926 | - heatmap depth size: D 927 | - heatmap height: H 928 | - heatmap width: W 929 | 930 | Args: 931 | heatmaps (np.ndarray[N, K, D, H, W]): model predicted heatmaps. 932 | center (np.ndarray[N, 2]): Center of the bounding box (x, y). 933 | scale (np.ndarray[N, 2]): Scale of the bounding box 934 | wrt height/width. 935 | 936 | Returns: 937 | tuple: A tuple containing keypoint predictions and scores. 938 | 939 | - preds (np.ndarray[N, K, 3]): Predicted 3d keypoint location \ 940 | in images. 941 | - maxvals (np.ndarray[N, K, 1]): Scores (confidence) of the keypoints. 942 | """ 943 | N, K, D, H, W = heatmaps.shape 944 | preds, maxvals = _get_max_preds_3d(heatmaps) 945 | # Transform back to the image 946 | for i in range(N): 947 | preds[i, :, :2] = transform_preds(preds[i, :, :2], center[i], scale[i], 948 | [W, H]) 949 | return preds, maxvals 950 | 951 | 952 | def multilabel_classification_accuracy(pred, gt, mask, thr=0.5): 953 | """Get multi-label classification accuracy. 954 | 955 | Note: 956 | - batch size: N 957 | - label number: L 958 | 959 | Args: 960 | pred (np.ndarray[N, L, 2]): model predicted labels. 961 | gt (np.ndarray[N, L, 2]): ground-truth labels. 962 | mask (np.ndarray[N, 1] or np.ndarray[N, L] ): reliability of 963 | ground-truth labels. 964 | 965 | Returns: 966 | float: multi-label classification accuracy. 967 | """ 968 | # we only compute accuracy on the samples with ground-truth of all labels. 969 | valid = (mask > 0).min(axis=1) if mask.ndim == 2 else (mask > 0) 970 | pred, gt = pred[valid], gt[valid] 971 | 972 | if pred.shape[0] == 0: 973 | acc = 0.0 # when no sample is with gt labels, set acc to 0. 974 | else: 975 | # The classification of a sample is regarded as correct 976 | # only if it's correct for all labels. 
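# (pred - thr) and (gt - thr) share the same sign exactly when prediction and
# ground truth fall on the same side of the threshold, so the product is
# positive only for correctly classified labels.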
977 | acc = (((pred - thr) * (gt - thr)) > 0).all(axis=1).mean() 978 | return acc 979 | 980 | 981 | 982 | def get_transform(center, scale, res, rot=0): 983 | """Generate transformation matrix.""" 984 | # res: (height, width), (rows, cols) 985 | crop_aspect_ratio = res[0] / float(res[1]) 986 | h = 200 * scale 987 | w = h / crop_aspect_ratio 988 | t = np.zeros((3, 3)) 989 | t[0, 0] = float(res[1]) / w 990 | t[1, 1] = float(res[0]) / h 991 | t[0, 2] = res[1] * (-float(center[0]) / w + .5) 992 | t[1, 2] = res[0] * (-float(center[1]) / h + .5) 993 | t[2, 2] = 1 994 | if not rot == 0: 995 | rot = -rot # To match direction of rotation from cropping 996 | rot_mat = np.zeros((3, 3)) 997 | rot_rad = rot * np.pi / 180 998 | sn, cs = np.sin(rot_rad), np.cos(rot_rad) 999 | rot_mat[0, :2] = [cs, -sn] 1000 | rot_mat[1, :2] = [sn, cs] 1001 | rot_mat[2, 2] = 1 1002 | # Need to rotate around center 1003 | t_mat = np.eye(3) 1004 | t_mat[0, 2] = -res[1] / 2 1005 | t_mat[1, 2] = -res[0] / 2 1006 | t_inv = t_mat.copy() 1007 | t_inv[:2, 2] *= -1 1008 | t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t))) 1009 | return t 1010 | 1011 | 1012 | def transform(pt, center, scale, res, invert=0, rot=0): 1013 | """Transform pixel location to different reference.""" 1014 | t = get_transform(center, scale, res, rot=rot) 1015 | if invert: 1016 | t = np.linalg.inv(t) 1017 | new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T 1018 | new_pt = np.dot(t, new_pt) 1019 | return np.array([round(new_pt[0]), round(new_pt[1])], dtype=int) + 1 1020 | 1021 | 1022 | def bbox_from_detector(bbox, input_resolution=(224, 224), rescale=1.25): 1023 | """ 1024 | Get center and scale of bounding box from bounding box. 1025 | The expected format is [min_x, min_y, max_x, max_y]. 1026 | """ 1027 | CROP_IMG_HEIGHT, CROP_IMG_WIDTH = input_resolution 1028 | CROP_ASPECT_RATIO = CROP_IMG_HEIGHT / float(CROP_IMG_WIDTH) 1029 | 1030 | # center 1031 | center_x = (bbox[0] + bbox[2]) / 2.0 1032 | center_y = (bbox[1] + bbox[3]) / 2.0 1033 | center = np.array([center_x, center_y]) 1034 | 1035 | # scale 1036 | bbox_w = bbox[2] - bbox[0] 1037 | bbox_h = bbox[3] - bbox[1] 1038 | bbox_size = max(bbox_w * CROP_ASPECT_RATIO, bbox_h) 1039 | 1040 | scale = np.array([bbox_size / CROP_ASPECT_RATIO, bbox_size]) / 200.0 1041 | # scale = bbox_size / 200.0 1042 | # adjust bounding box tightness 1043 | scale *= rescale 1044 | return center, scale 1045 | 1046 | 1047 | def crop(img, center, scale, res): 1048 | """ 1049 | Crop image according to the supplied bounding box. 
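The crop window is derived from (center, scale) via the inverse of transform(); any part of the window falling outside the source image is zero padded before the result is resized to res.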
1050 | res: [rows, cols] 1051 | """ 1052 | # Upper left point 1053 | ul = np.array(transform([1, 1], center, max(scale), res, invert=1)) - 1 1054 | # Bottom right point 1055 | br = np.array(transform([res[1] + 1, res[0] + 1], center, max(scale), res, invert=1)) - 1 1056 | 1057 | new_shape = [br[1] - ul[1], br[0] - ul[0]] 1058 | if len(img.shape) > 2: 1059 | new_shape += [img.shape[2]] 1060 | new_img = np.zeros(new_shape, dtype=np.float32) 1061 | 1062 | # Range to fill new array 1063 | new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] 1064 | new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] 1065 | # Range to sample from original image 1066 | old_x = max(0, ul[0]), min(len(img[0]), br[0]) 1067 | old_y = max(0, ul[1]), min(len(img), br[1]) 1068 | try: 1069 | new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] 1070 | except Exception as e: 1071 | print(e) 1072 | 1073 | new_img = cv2.resize(new_img, (res[1], res[0])) # (cols, rows) 1074 | return new_img, new_shape, (old_x, old_y), (new_x, new_y) # , ul, br 1075 | 1076 | 1077 | def split_kp2ds_for_aa(kp2ds, ret_face=False): 1078 | kp2ds_body = (kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 1079 | kp2ds_lhand = kp2ds[91:112] 1080 | kp2ds_rhand = kp2ds[112:133] 1081 | kp2ds_face = kp2ds[22:91] 1082 | if ret_face: 1083 | return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy(), kp2ds_face.copy() 1084 | return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy() 1085 | 1086 | 1087 | def load_pose_metas_from_kp2ds_seq(kp2ds_seq, width, height): 1088 | metas = [] 1089 | last_kp2ds_body = None 1090 | for kps in kp2ds_seq: 1091 | kps = kps.copy() 1092 | kps[:, 0] /= width 1093 | kps[:, 1] /= height 1094 | kp2ds_body, kp2ds_lhand, kp2ds_rhand, kp2ds_face = split_kp2ds_for_aa(kps, ret_face=True) 1095 | 1096 | # Exclude cases where all values are less than 0 1097 | if last_kp2ds_body is not None and kp2ds_body[:, :2].min(axis=1).max() < 0: 1098 | kp2ds_body = last_kp2ds_body 1099 | last_kp2ds_body = kp2ds_body 1100 | 1101 | meta = { 1102 | "width": width, 1103 | "height": height, 1104 | "keypoints_body": kp2ds_body, 1105 | "keypoints_left_hand": kp2ds_lhand, 1106 | "keypoints_right_hand": kp2ds_rhand, 1107 | "keypoints_face": kp2ds_face, 1108 | } 1109 | metas.append(meta) 1110 | return metas -------------------------------------------------------------------------------- /pose_utils/human_visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import os 3 | import cv2 4 | import time 5 | import math 6 | import matplotlib 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | from typing import Dict, List 10 | import random 11 | from .pose2d_utils import AAPoseMeta 12 | 13 | 14 | def draw_handpose(canvas, keypoints, hand_score_th=0.6): 15 | """ 16 | Draw keypoints and connections representing hand pose on a given canvas. 17 | 18 | Args: 19 | canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. 20 | keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn 21 | or None if no keypoints are present. 22 | 23 | Returns: 24 | np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. 
25 | 26 | Note: 27 | The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. 28 | """ 29 | eps = 0.01 30 | 31 | H, W, C = canvas.shape 32 | stickwidth = max(int(min(H, W) / 200), 1) 33 | 34 | edges = [ 35 | [0, 1], 36 | [1, 2], 37 | [2, 3], 38 | [3, 4], 39 | [0, 5], 40 | [5, 6], 41 | [6, 7], 42 | [7, 8], 43 | [0, 9], 44 | [9, 10], 45 | [10, 11], 46 | [11, 12], 47 | [0, 13], 48 | [13, 14], 49 | [14, 15], 50 | [15, 16], 51 | [0, 17], 52 | [17, 18], 53 | [18, 19], 54 | [19, 20], 55 | ] 56 | 57 | for ie, (e1, e2) in enumerate(edges): 58 | k1 = keypoints[e1] 59 | k2 = keypoints[e2] 60 | if k1 is None or k2 is None: 61 | continue 62 | if k1[2] < hand_score_th or k2[2] < hand_score_th: 63 | continue 64 | 65 | x1 = int(k1[0]) 66 | y1 = int(k1[1]) 67 | x2 = int(k2[0]) 68 | y2 = int(k2[1]) 69 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps: 70 | cv2.line( 71 | canvas, 72 | (x1, y1), 73 | (x2, y2), 74 | matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, 75 | thickness=stickwidth, 76 | ) 77 | 78 | for keypoint in keypoints: 79 | 80 | if keypoint is None: 81 | continue 82 | if keypoint[2] < hand_score_th: 83 | continue 84 | 85 | x, y = keypoint[0], keypoint[1] 86 | x = int(x) 87 | y = int(y) 88 | if x > eps and y > eps: 89 | cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1) 90 | return canvas 91 | 92 | 93 | def draw_handpose_new(canvas, keypoints, stickwidth_type='v2', hand_score_th=0.6, hand_stick_width=4): 94 | """ 95 | Draw keypoints and connections representing hand pose on a given canvas. 96 | 97 | Args: 98 | canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. 99 | keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn 100 | or None if no keypoints are present. 101 | 102 | Returns: 103 | np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. 104 | 105 | Note: 106 | The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. 
107 | """ 108 | eps = 0.01 109 | 110 | H, W, C = canvas.shape 111 | # if stickwidth_type == 'v1': 112 | # stickwidth = max(int(min(H, W) / 200), 1) 113 | # elif stickwidth_type == 'v2': 114 | # stickwidth = max(max(int(min(H, W) / 200) - 1, 1) // 2, 1) 115 | if hand_stick_width == -1: 116 | stickwidth = max(max(int(min(H, W) / 200) - 1, 1) // 2, 1) 117 | else: 118 | stickwidth = hand_stick_width 119 | 120 | edges = [ 121 | [0, 1], 122 | [1, 2], 123 | [2, 3], 124 | [3, 4], 125 | [0, 5], 126 | [5, 6], 127 | [6, 7], 128 | [7, 8], 129 | [0, 9], 130 | [9, 10], 131 | [10, 11], 132 | [11, 12], 133 | [0, 13], 134 | [13, 14], 135 | [14, 15], 136 | [15, 16], 137 | [0, 17], 138 | [17, 18], 139 | [18, 19], 140 | [19, 20], 141 | ] 142 | 143 | for ie, (e1, e2) in enumerate(edges): 144 | k1 = keypoints[e1] 145 | k2 = keypoints[e2] 146 | if k1 is None or k2 is None: 147 | continue 148 | if k1[2] < hand_score_th or k2[2] < hand_score_th: 149 | continue 150 | 151 | x1 = int(k1[0]) 152 | y1 = int(k1[1]) 153 | x2 = int(k2[0]) 154 | y2 = int(k2[1]) 155 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps: 156 | cv2.line( 157 | canvas, 158 | (x1, y1), 159 | (x2, y2), 160 | matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, 161 | thickness=stickwidth, 162 | ) 163 | 164 | for keypoint in keypoints: 165 | 166 | if keypoint is None: 167 | continue 168 | if keypoint[2] < hand_score_th: 169 | continue 170 | 171 | x, y = keypoint[0], keypoint[1] 172 | x = int(x) 173 | y = int(y) 174 | if x > eps and y > eps: 175 | cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1) 176 | return canvas 177 | 178 | 179 | def draw_ellipse_by_2kp(img, keypoint1, keypoint2, color, threshold=0.6): 180 | H, W, C = img.shape 181 | stickwidth = max(int(min(H, W) / 200), 1) 182 | 183 | if keypoint1[-1] < threshold or keypoint2[-1] < threshold: 184 | return img 185 | 186 | Y = np.array([keypoint1[0], keypoint2[0]]) 187 | X = np.array([keypoint1[1], keypoint2[1]]) 188 | mX = np.mean(X) 189 | mY = np.mean(Y) 190 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 191 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 192 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) 193 | cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color]) 194 | return img 195 | 196 | 197 | def split_pose2d_kps_to_aa(kp2ds: np.ndarray) -> List[np.ndarray]: 198 | """Convert the 133 keypoints from pose2d to body and hands keypoints. 
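The COCO-WholeBody points are reduced to a 20 point body skeleton (the neck is taken as the midpoint of the two shoulders) plus 21 left-hand and 21 right-hand keypoints (indices 91-111 and 112-132).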
199 | 200 | Args: 201 | kp2ds (np.ndarray): [133, 2] 202 | 203 | Returns: 204 | List[np.ndarray]: body keypoints [20, 2], left-hand keypoints [21, 2] and right-hand keypoints [21, 2]. 205 | """ 206 | kp2ds_body = ( 207 | kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] 208 | + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]] 209 | ) / 2 210 | kp2ds_lhand = kp2ds[91:112] 211 | kp2ds_rhand = kp2ds[112:133] 212 | return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy() 213 | 214 | 215 | def draw_aapose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200, draw_hand=True, draw_head=True): 216 | kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1) 217 | kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1) 218 | kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1) 219 | pose_img = draw_aapose(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand, stick_width_norm=stick_width_norm, draw_hand=draw_hand, draw_head=draw_head) 220 | return pose_img 221 | 222 | def draw_aapose_by_meta_new(img, meta: AAPoseMeta, threshold=0.5, stickwidth_type='v2', body_stick_width=-1, draw_hand=True, draw_head=True, hand_stick_width=4): 223 | kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1) 224 | kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1) 225 | kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1) 226 | pose_img = draw_aapose_new(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand, body_stick_width=body_stick_width, 227 | stickwidth_type=stickwidth_type, draw_hand=draw_hand, draw_head=draw_head, hand_stick_width=hand_stick_width) 228 | return pose_img 229 | 230 | def draw_hand_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200): 231 | kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None] * 0], axis=1) 232 | kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1) 233 | kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1) 234 | pose_img = draw_aapose(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand, stick_width_norm=stick_width_norm, draw_hand=True, draw_head=False) 235 | return pose_img 236 | 237 | 238 | def draw_aaface_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200, draw_hand=False, draw_head=True): 239 | kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1) 240 | # kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1) 241 | # kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1) 242 | pose_img = draw_M(img, kp2ds, threshold, kp2ds_lhand=None, kp2ds_rhand=None, stick_width_norm=stick_width_norm, draw_hand=draw_hand, draw_head=draw_head) 243 | return pose_img 244 | 245 | 246 | def draw_aanose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=100, draw_hand=False): 247 | kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1) 248 | # kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1) 249 | # kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1) 250 | pose_img = draw_nose(img, kp2ds, threshold, kp2ds_lhand=None, kp2ds_rhand=None, stick_width_norm=stick_width_norm, draw_hand=draw_hand) 251 | return pose_img 252 | 253 | 254 | def gen_face_motion_seq(img, metas: List[AAPoseMeta], threshold=0.5, stick_width_norm=200): 255 | 256 |
return 257 | 258 | 259 | def draw_M( 260 | img, 261 | kp2ds, 262 | threshold=0.6, 263 | data_to_json=None, 264 | idx=-1, 265 | kp2ds_lhand=None, 266 | kp2ds_rhand=None, 267 | draw_hand=False, 268 | stick_width_norm=200, 269 | draw_head=True 270 | ): 271 | """ 272 | Draw keypoints and connections representing hand pose on a given canvas. 273 | 274 | Args: 275 | canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. 276 | keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn 277 | or None if no keypoints are present. 278 | 279 | Returns: 280 | np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. 281 | 282 | Note: 283 | The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. 284 | """ 285 | 286 | new_kep_list = [ 287 | "Nose", 288 | "Neck", 289 | "RShoulder", 290 | "RElbow", 291 | "RWrist", # No.4 292 | "LShoulder", 293 | "LElbow", 294 | "LWrist", # No.7 295 | "RHip", 296 | "RKnee", 297 | "RAnkle", # No.10 298 | "LHip", 299 | "LKnee", 300 | "LAnkle", # No.13 301 | "REye", 302 | "LEye", 303 | "REar", 304 | "LEar", 305 | "LToe", 306 | "RToe", 307 | ] 308 | # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \ 309 | # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 310 | kp2ds = kp2ds.copy() 311 | # import ipdb; ipdb.set_trace() 312 | kp2ds[[1,2,3,4,5,6,7,8,9,10,11,12,13,18,19], 2] = 0 313 | if not draw_head: 314 | kp2ds[[0,14,15,16,17], 2] = 0 315 | kp2ds_body = kp2ds 316 | # kp2ds_body = kp2ds_body[:18] 317 | 318 | # kp2ds_lhand = kp2ds.copy()[91:112] 319 | # kp2ds_rhand = kp2ds.copy()[112:133] 320 | 321 | limbSeq = [ 322 | # [2, 3], 323 | # [2, 6], # shoulders 324 | # [3, 4], 325 | # [4, 5], # left arm 326 | # [6, 7], 327 | # [7, 8], # right arm 328 | # [2, 9], 329 | # [9, 10], 330 | # [10, 11], # right leg 331 | # [2, 12], 332 | # [12, 13], 333 | # [13, 14], # left leg 334 | # [2, 1], 335 | [1, 15], 336 | [15, 17], 337 | [1, 16], 338 | [16, 18], # face (nose, eyes, ears) 339 | # [14, 19], 340 | # [11, 20], # foot 341 | ] 342 | 343 | colors = [ 344 | # [255, 0, 0], 345 | # [255, 85, 0], 346 | # [255, 170, 0], 347 | # [255, 255, 0], 348 | # [170, 255, 0], 349 | # [85, 255, 0], 350 | # [0, 255, 0], 351 | # [0, 255, 85], 352 | # [0, 255, 170], 353 | # [0, 255, 255], 354 | # [0, 170, 255], 355 | # [0, 85, 255], 356 | # [0, 0, 255], 357 | # [85, 0, 255], 358 | [170, 0, 255], 359 | [255, 0, 255], 360 | [255, 0, 170], 361 | [255, 0, 85], 362 | # foot 363 | # [200, 200, 0], 364 | # [100, 100, 0], 365 | ] 366 | 367 | H, W, C = img.shape 368 | stickwidth = max(int(min(H, W) / stick_width_norm), 1) 369 | 370 | for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)): 371 | keypoint1 = kp2ds_body[k1_index - 1] 372 | keypoint2 = kp2ds_body[k2_index - 1] 373 | 374 | if keypoint1[-1] < threshold or keypoint2[-1] < threshold: 375 | continue 376 | 377 | Y = np.array([keypoint1[0], keypoint2[0]]) 378 | X = np.array([keypoint1[1], keypoint2[1]]) 379 | mX = np.mean(X) 380 | mY = np.mean(Y) 381 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 382 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 383 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) 384 | cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color]) 385 | 386 | for _idx, (keypoint, color) 
in enumerate(zip(kp2ds_body, colors)): 387 | if keypoint[-1] < threshold: 388 | continue 389 | x, y = keypoint[0], keypoint[1] 390 | # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) 391 | cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1) 392 | 393 | if draw_hand: 394 | img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold) 395 | img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold) 396 | 397 | kp2ds_body[:, 0] /= W 398 | kp2ds_body[:, 1] /= H 399 | 400 | if data_to_json is not None: 401 | if idx == -1: 402 | data_to_json.append( 403 | { 404 | "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1), 405 | "height": H, 406 | "width": W, 407 | "category_id": 1, 408 | "keypoints_body": kp2ds_body.tolist(), 409 | "keypoints_left_hand": kp2ds_lhand.tolist(), 410 | "keypoints_right_hand": kp2ds_rhand.tolist(), 411 | } 412 | ) 413 | else: 414 | data_to_json[idx] = { 415 | "image_id": "frame_{:05d}.jpg".format(idx + 1), 416 | "height": H, 417 | "width": W, 418 | "category_id": 1, 419 | "keypoints_body": kp2ds_body.tolist(), 420 | "keypoints_left_hand": kp2ds_lhand.tolist(), 421 | "keypoints_right_hand": kp2ds_rhand.tolist(), 422 | } 423 | return img 424 | 425 | 426 | def draw_nose( 427 | img, 428 | kp2ds, 429 | threshold=0.6, 430 | data_to_json=None, 431 | idx=-1, 432 | kp2ds_lhand=None, 433 | kp2ds_rhand=None, 434 | draw_hand=False, 435 | stick_width_norm=200, 436 | ): 437 | """ 438 | Draw keypoints and connections representing hand pose on a given canvas. 439 | 440 | Args: 441 | canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. 442 | keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn 443 | or None if no keypoints are present. 444 | 445 | Returns: 446 | np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. 447 | 448 | Note: 449 | The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. 
450 | """ 451 | 452 | new_kep_list = [ 453 | "Nose", 454 | "Neck", 455 | "RShoulder", 456 | "RElbow", 457 | "RWrist", # No.4 458 | "LShoulder", 459 | "LElbow", 460 | "LWrist", # No.7 461 | "RHip", 462 | "RKnee", 463 | "RAnkle", # No.10 464 | "LHip", 465 | "LKnee", 466 | "LAnkle", # No.13 467 | "REye", 468 | "LEye", 469 | "REar", 470 | "LEar", 471 | "LToe", 472 | "RToe", 473 | ] 474 | # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \ 475 | # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 476 | kp2ds = kp2ds.copy() 477 | kp2ds[1:, 2] = 0 478 | # kp2ds[0, 2] = 1 479 | kp2ds_body = kp2ds 480 | # kp2ds_body = kp2ds_body[:18] 481 | 482 | # kp2ds_lhand = kp2ds.copy()[91:112] 483 | # kp2ds_rhand = kp2ds.copy()[112:133] 484 | 485 | limbSeq = [ 486 | # [2, 3], 487 | # [2, 6], # shoulders 488 | # [3, 4], 489 | # [4, 5], # left arm 490 | # [6, 7], 491 | # [7, 8], # right arm 492 | # [2, 9], 493 | # [9, 10], 494 | # [10, 11], # right leg 495 | # [2, 12], 496 | # [12, 13], 497 | # [13, 14], # left leg 498 | # [2, 1], 499 | [1, 15], 500 | [15, 17], 501 | [1, 16], 502 | [16, 18], # face (nose, eyes, ears) 503 | # [14, 19], 504 | # [11, 20], # foot 505 | ] 506 | 507 | colors = [ 508 | # [255, 0, 0], 509 | # [255, 85, 0], 510 | # [255, 170, 0], 511 | # [255, 255, 0], 512 | # [170, 255, 0], 513 | # [85, 255, 0], 514 | # [0, 255, 0], 515 | # [0, 255, 85], 516 | # [0, 255, 170], 517 | # [0, 255, 255], 518 | # [0, 170, 255], 519 | # [0, 85, 255], 520 | # [0, 0, 255], 521 | # [85, 0, 255], 522 | [170, 0, 255], 523 | # [255, 0, 255], 524 | # [255, 0, 170], 525 | # [255, 0, 85], 526 | # foot 527 | # [200, 200, 0], 528 | # [100, 100, 0], 529 | ] 530 | 531 | H, W, C = img.shape 532 | stickwidth = max(int(min(H, W) / stick_width_norm), 1) 533 | 534 | # for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)): 535 | # keypoint1 = kp2ds_body[k1_index - 1] 536 | # keypoint2 = kp2ds_body[k2_index - 1] 537 | 538 | # if keypoint1[-1] < threshold or keypoint2[-1] < threshold: 539 | # continue 540 | 541 | # Y = np.array([keypoint1[0], keypoint2[0]]) 542 | # X = np.array([keypoint1[1], keypoint2[1]]) 543 | # mX = np.mean(X) 544 | # mY = np.mean(Y) 545 | # length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 546 | # angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 547 | # polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) 548 | # cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color]) 549 | 550 | for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)): 551 | if keypoint[-1] < threshold: 552 | continue 553 | x, y = keypoint[0], keypoint[1] 554 | # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) 555 | cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1) 556 | 557 | if draw_hand: 558 | img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold) 559 | img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold) 560 | 561 | kp2ds_body[:, 0] /= W 562 | kp2ds_body[:, 1] /= H 563 | 564 | if data_to_json is not None: 565 | if idx == -1: 566 | data_to_json.append( 567 | { 568 | "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1), 569 | "height": H, 570 | "width": W, 571 | "category_id": 1, 572 | "keypoints_body": kp2ds_body.tolist(), 573 | "keypoints_left_hand": kp2ds_lhand.tolist(), 574 | "keypoints_right_hand": kp2ds_rhand.tolist(), 575 | } 576 | ) 577 | else: 578 | data_to_json[idx] = { 579 | 
"image_id": "frame_{:05d}.jpg".format(idx + 1), 580 | "height": H, 581 | "width": W, 582 | "category_id": 1, 583 | "keypoints_body": kp2ds_body.tolist(), 584 | "keypoints_left_hand": kp2ds_lhand.tolist(), 585 | "keypoints_right_hand": kp2ds_rhand.tolist(), 586 | } 587 | return img 588 | 589 | 590 | def draw_aapose( 591 | img, 592 | kp2ds, 593 | threshold=0.6, 594 | data_to_json=None, 595 | idx=-1, 596 | kp2ds_lhand=None, 597 | kp2ds_rhand=None, 598 | draw_hand=False, 599 | stick_width_norm=200, 600 | draw_head=True 601 | ): 602 | """ 603 | Draw keypoints and connections representing hand pose on a given canvas. 604 | 605 | Args: 606 | canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. 607 | keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn 608 | or None if no keypoints are present. 609 | 610 | Returns: 611 | np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. 612 | 613 | Note: 614 | The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. 615 | """ 616 | 617 | new_kep_list = [ 618 | "Nose", 619 | "Neck", 620 | "RShoulder", 621 | "RElbow", 622 | "RWrist", # No.4 623 | "LShoulder", 624 | "LElbow", 625 | "LWrist", # No.7 626 | "RHip", 627 | "RKnee", 628 | "RAnkle", # No.10 629 | "LHip", 630 | "LKnee", 631 | "LAnkle", # No.13 632 | "REye", 633 | "LEye", 634 | "REar", 635 | "LEar", 636 | "LToe", 637 | "RToe", 638 | ] 639 | # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \ 640 | # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 641 | kp2ds = kp2ds.copy() 642 | if not draw_head: 643 | kp2ds[[0,14,15,16,17], 2] = 0 644 | kp2ds_body = kp2ds 645 | 646 | # kp2ds_lhand = kp2ds.copy()[91:112] 647 | # kp2ds_rhand = kp2ds.copy()[112:133] 648 | 649 | limbSeq = [ 650 | [2, 3], 651 | [2, 6], # shoulders 652 | [3, 4], 653 | [4, 5], # left arm 654 | [6, 7], 655 | [7, 8], # right arm 656 | [2, 9], 657 | [9, 10], 658 | [10, 11], # right leg 659 | [2, 12], 660 | [12, 13], 661 | [13, 14], # left leg 662 | [2, 1], 663 | [1, 15], 664 | [15, 17], 665 | [1, 16], 666 | [16, 18], # face (nose, eyes, ears) 667 | [14, 19], 668 | [11, 20], # foot 669 | ] 670 | 671 | colors = [ 672 | [255, 0, 0], 673 | [255, 85, 0], 674 | [255, 170, 0], 675 | [255, 255, 0], 676 | [170, 255, 0], 677 | [85, 255, 0], 678 | [0, 255, 0], 679 | [0, 255, 85], 680 | [0, 255, 170], 681 | [0, 255, 255], 682 | [0, 170, 255], 683 | [0, 85, 255], 684 | [0, 0, 255], 685 | [85, 0, 255], 686 | [170, 0, 255], 687 | [255, 0, 255], 688 | [255, 0, 170], 689 | [255, 0, 85], 690 | # foot 691 | [200, 200, 0], 692 | [100, 100, 0], 693 | ] 694 | 695 | H, W, C = img.shape 696 | stickwidth = max(int(min(H, W) / stick_width_norm), 1) 697 | 698 | for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)): 699 | keypoint1 = kp2ds_body[k1_index - 1] 700 | keypoint2 = kp2ds_body[k2_index - 1] 701 | 702 | if keypoint1[-1] < threshold or keypoint2[-1] < threshold: 703 | continue 704 | 705 | Y = np.array([keypoint1[0], keypoint2[0]]) 706 | X = np.array([keypoint1[1], keypoint2[1]]) 707 | mX = np.mean(X) 708 | mY = np.mean(Y) 709 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 710 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 711 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) 712 | cv2.fillConvexPoly(img, 
polygon, [int(float(c) * 0.6) for c in color]) 713 | 714 | for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)): 715 | if keypoint[-1] < threshold: 716 | continue 717 | x, y = keypoint[0], keypoint[1] 718 | # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) 719 | cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1) 720 | 721 | if draw_hand: 722 | img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold) 723 | img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold) 724 | 725 | kp2ds_body[:, 0] /= W 726 | kp2ds_body[:, 1] /= H 727 | 728 | if data_to_json is not None: 729 | if idx == -1: 730 | data_to_json.append( 731 | { 732 | "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1), 733 | "height": H, 734 | "width": W, 735 | "category_id": 1, 736 | "keypoints_body": kp2ds_body.tolist(), 737 | "keypoints_left_hand": kp2ds_lhand.tolist(), 738 | "keypoints_right_hand": kp2ds_rhand.tolist(), 739 | } 740 | ) 741 | else: 742 | data_to_json[idx] = { 743 | "image_id": "frame_{:05d}.jpg".format(idx + 1), 744 | "height": H, 745 | "width": W, 746 | "category_id": 1, 747 | "keypoints_body": kp2ds_body.tolist(), 748 | "keypoints_left_hand": kp2ds_lhand.tolist(), 749 | "keypoints_right_hand": kp2ds_rhand.tolist(), 750 | } 751 | return img 752 | 753 | 754 | def draw_aapose_new( 755 | img, 756 | kp2ds, 757 | threshold=0.6, 758 | data_to_json=None, 759 | idx=-1, 760 | kp2ds_lhand=None, 761 | kp2ds_rhand=None, 762 | draw_hand=False, 763 | stickwidth_type='v2', 764 | body_stick_width=-1, 765 | hand_stick_width=-1, 766 | draw_head=True 767 | ): 768 | """ 769 | Draw keypoints and connections representing hand pose on a given canvas. 770 | 771 | Args: 772 | canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. 773 | keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn 774 | or None if no keypoints are present. 775 | 776 | Returns: 777 | np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. 778 | 779 | Note: 780 | The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. 
781 | """ 782 | 783 | new_kep_list = [ 784 | "Nose", 785 | "Neck", 786 | "RShoulder", 787 | "RElbow", 788 | "RWrist", # No.4 789 | "LShoulder", 790 | "LElbow", 791 | "LWrist", # No.7 792 | "RHip", 793 | "RKnee", 794 | "RAnkle", # No.10 795 | "LHip", 796 | "LKnee", 797 | "LAnkle", # No.13 798 | "REye", 799 | "LEye", 800 | "REar", 801 | "LEar", 802 | "LToe", 803 | "RToe", 804 | ] 805 | # kp2ds_body = (kp2ds.copy()[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]] + \ 806 | # kp2ds.copy()[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]) / 2 807 | kp2ds = kp2ds.copy() 808 | if not draw_head: 809 | kp2ds[[0,14,15,16,17], 2] = 0 810 | kp2ds_body = kp2ds 811 | 812 | # kp2ds_lhand = kp2ds.copy()[91:112] 813 | # kp2ds_rhand = kp2ds.copy()[112:133] 814 | 815 | limbSeq = [ 816 | [2, 3], 817 | [2, 6], # shoulders 818 | [3, 4], 819 | [4, 5], # left arm 820 | [6, 7], 821 | [7, 8], # right arm 822 | [2, 9], 823 | [9, 10], 824 | [10, 11], # right leg 825 | [2, 12], 826 | [12, 13], 827 | [13, 14], # left leg 828 | [2, 1], 829 | [1, 15], 830 | [15, 17], 831 | [1, 16], 832 | [16, 18], # face (nose, eyes, ears) 833 | [14, 19], 834 | [11, 20], # foot 835 | ] 836 | 837 | colors = [ 838 | [255, 0, 0], 839 | [255, 85, 0], 840 | [255, 170, 0], 841 | [255, 255, 0], 842 | [170, 255, 0], 843 | [85, 255, 0], 844 | [0, 255, 0], 845 | [0, 255, 85], 846 | [0, 255, 170], 847 | [0, 255, 255], 848 | [0, 170, 255], 849 | [0, 85, 255], 850 | [0, 0, 255], 851 | [85, 0, 255], 852 | [170, 0, 255], 853 | [255, 0, 255], 854 | [255, 0, 170], 855 | [255, 0, 85], 856 | # foot 857 | [200, 200, 0], 858 | [100, 100, 0], 859 | ] 860 | 861 | H, W, C = img.shape 862 | H, W, C = img.shape 863 | 864 | #if stickwidth_type == 'v1': 865 | # stickwidth = max(int(min(H, W) / 200), 1) 866 | #elif stickwidth_type == 'v2': 867 | if body_stick_width == -1: 868 | stickwidth = max(int(min(H, W) / 200) - 1, 1) 869 | else: 870 | stickwidth = body_stick_width 871 | 872 | for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)): 873 | keypoint1 = kp2ds_body[k1_index - 1] 874 | keypoint2 = kp2ds_body[k2_index - 1] 875 | 876 | if keypoint1[-1] < threshold or keypoint2[-1] < threshold: 877 | continue 878 | 879 | Y = np.array([keypoint1[0], keypoint2[0]]) 880 | X = np.array([keypoint1[1], keypoint2[1]]) 881 | mX = np.mean(X) 882 | mY = np.mean(Y) 883 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 884 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 885 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) 886 | cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color]) 887 | 888 | for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)): 889 | if keypoint[-1] < threshold: 890 | continue 891 | x, y = keypoint[0], keypoint[1] 892 | # cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) 893 | cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1) 894 | 895 | if draw_hand: 896 | img = draw_handpose_new(img, kp2ds_lhand, stickwidth_type=stickwidth_type, hand_score_th=threshold, hand_stick_width=hand_stick_width) 897 | img = draw_handpose_new(img, kp2ds_rhand, stickwidth_type=stickwidth_type, hand_score_th=threshold, hand_stick_width=hand_stick_width) 898 | 899 | kp2ds_body[:, 0] /= W 900 | kp2ds_body[:, 1] /= H 901 | 902 | if data_to_json is not None: 903 | if idx == -1: 904 | data_to_json.append( 905 | { 906 | "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1), 907 | "height": H, 
908 | "width": W, 909 | "category_id": 1, 910 | "keypoints_body": kp2ds_body.tolist(), 911 | "keypoints_left_hand": kp2ds_lhand.tolist(), 912 | "keypoints_right_hand": kp2ds_rhand.tolist(), 913 | } 914 | ) 915 | else: 916 | data_to_json[idx] = { 917 | "image_id": "frame_{:05d}.jpg".format(idx + 1), 918 | "height": H, 919 | "width": W, 920 | "category_id": 1, 921 | "keypoints_body": kp2ds_body.tolist(), 922 | "keypoints_left_hand": kp2ds_lhand.tolist(), 923 | "keypoints_right_hand": kp2ds_rhand.tolist(), 924 | } 925 | return img 926 | 927 | 928 | def draw_bbox(img, bbox, color=(255, 0, 0)): 929 | img = load_image(img) 930 | bbox = [int(bbox_tmp) for bbox_tmp in bbox] 931 | cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2) 932 | return img 933 | 934 | 935 | def draw_kp2ds(img, kp2ds, threshold=0, color=(255, 0, 0), skeleton=None, reverse=False): 936 | img = load_image(img, reverse) 937 | 938 | if skeleton is not None: 939 | if skeleton == "coco17": 940 | skeleton_list = [ 941 | [6, 8], 942 | [8, 10], 943 | [5, 7], 944 | [7, 9], 945 | [11, 13], 946 | [13, 15], 947 | [12, 14], 948 | [14, 16], 949 | [5, 6], 950 | [6, 12], 951 | [12, 11], 952 | [11, 5], 953 | ] 954 | color_list = [ 955 | (255, 0, 0), 956 | (0, 255, 0), 957 | (0, 0, 255), 958 | (255, 255, 0), 959 | (255, 0, 255), 960 | (0, 255, 255), 961 | ] 962 | elif skeleton == "cocowholebody": 963 | skeleton_list = [ 964 | [6, 8], 965 | [8, 10], 966 | [5, 7], 967 | [7, 9], 968 | [11, 13], 969 | [13, 15], 970 | [12, 14], 971 | [14, 16], 972 | [5, 6], 973 | [6, 12], 974 | [12, 11], 975 | [11, 5], 976 | [15, 17], 977 | [15, 18], 978 | [15, 19], 979 | [16, 20], 980 | [16, 21], 981 | [16, 22], 982 | [91, 92, 93, 94, 95], 983 | [91, 96, 97, 98, 99], 984 | [91, 100, 101, 102, 103], 985 | [91, 104, 105, 106, 107], 986 | [91, 108, 109, 110, 111], 987 | [112, 113, 114, 115, 116], 988 | [112, 117, 118, 119, 120], 989 | [112, 121, 122, 123, 124], 990 | [112, 125, 126, 127, 128], 991 | [112, 129, 130, 131, 132], 992 | ] 993 | color_list = [ 994 | (255, 0, 0), 995 | (0, 255, 0), 996 | (0, 0, 255), 997 | (255, 255, 0), 998 | (255, 0, 255), 999 | (0, 255, 255), 1000 | ] 1001 | else: 1002 | color_list = [color] 1003 | for _idx, _skeleton in enumerate(skeleton_list): 1004 | for i in range(len(_skeleton) - 1): 1005 | cv2.line( 1006 | img, 1007 | (int(kp2ds[_skeleton[i], 0]), int(kp2ds[_skeleton[i], 1])), 1008 | (int(kp2ds[_skeleton[i + 1], 0]), int(kp2ds[_skeleton[i + 1], 1])), 1009 | color_list[_idx % len(color_list)], 1010 | 3, 1011 | ) 1012 | 1013 | for _idx, kp2d in enumerate(kp2ds): 1014 | if kp2d[2] > threshold: 1015 | cv2.circle(img, (int(kp2d[0]), int(kp2d[1])), 3, color, -1) 1016 | # cv2.putText(img, 1017 | # str(_idx), 1018 | # (int(kp2d[0, i, 0])*1, 1019 | # int(kp2d[0, i, 1])*1), 1020 | # cv2.FONT_HERSHEY_SIMPLEX, 1021 | # 0.75, 1022 | # color, 1023 | # 2 1024 | # ) 1025 | 1026 | return img 1027 | 1028 | 1029 | def draw_pcd(pcd_list, save_path=None): 1030 | fig = plt.figure() 1031 | ax = fig.add_subplot(111, projection="3d") 1032 | 1033 | color_list = ["r", "g", "b", "y", "p"] 1034 | 1035 | for _idx, _pcd in enumerate(pcd_list): 1036 | ax.scatter(_pcd[:, 0], _pcd[:, 1], _pcd[:, 2], c=color_list[_idx], marker="o") 1037 | 1038 | ax.set_xlabel("X") 1039 | ax.set_ylabel("Y") 1040 | ax.set_zlabel("Z") 1041 | 1042 | if save_path is not None: 1043 | plt.savefig(save_path) 1044 | else: 1045 | plt.savefig("tmp.png") 1046 | 1047 | 1048 | def load_image(img, reverse=False): 1049 | if type(img) == str: 1050 | img = cv2.imread(img) 1051 | if 
reverse: 1052 | img = img.astype(np.float32) 1053 | img = img[:, :, ::-1] 1054 | img = img.astype(np.uint8) 1055 | return img 1056 | 1057 | 1058 | def draw_skeleten(meta): 1059 | kps = [] 1060 | for i, kp in enumerate(meta["keypoints_body"]): 1061 | if kp is None: 1062 | # if kp is None: 1063 | kps.append([0, 0, 0]) 1064 | else: 1065 | kps.append([*kp, 1]) 1066 | kps = np.array(kps) 1067 | 1068 | kps[:, 0] *= meta["width"] 1069 | kps[:, 1] *= meta["height"] 1070 | pose_img = np.zeros([meta["height"], meta["width"], 3], dtype=np.uint8) 1071 | 1072 | pose_img = draw_aapose( 1073 | pose_img, 1074 | kps, 1075 | draw_hand=True, 1076 | kp2ds_lhand=meta["keypoints_left_hand"], 1077 | kp2ds_rhand=meta["keypoints_right_hand"], 1078 | ) 1079 | return pose_img 1080 | 1081 | 1082 | def draw_skeleten_with_pncc(pncc: np.ndarray, meta: Dict) -> np.ndarray: 1083 | """ 1084 | Args: 1085 | pncc: [H,W,3] 1086 | meta: required keys: keypoints_body: [N, 3] keypoints_left_hand, keypoints_right_hand 1087 | Return: 1088 | np.ndarray [H, W, 3] 1089 | """ 1090 | # preprocess keypoints 1091 | kps = [] 1092 | for i, kp in enumerate(meta["keypoints_body"]): 1093 | if kp is None: 1094 | # if kp is None: 1095 | kps.append([0, 0, 0]) 1096 | elif i in [14, 15, 16, 17]: 1097 | kps.append([0, 0, 0]) 1098 | else: 1099 | kps.append([*kp]) 1100 | kps = np.stack(kps) 1101 | 1102 | kps[:, 0] *= pncc.shape[1] 1103 | kps[:, 1] *= pncc.shape[0] 1104 | 1105 | # draw neck 1106 | canvas = np.zeros_like(pncc) 1107 | if kps[0][2] > 0.6 and kps[1][2] > 0.6: 1108 | canvas = draw_ellipse_by_2kp(canvas, kps[0], kps[1], [0, 0, 255]) 1109 | 1110 | # draw pncc 1111 | mask = (pncc > 0).max(axis=2) 1112 | canvas[mask] = pncc[mask] 1113 | pncc = canvas 1114 | 1115 | # draw other skeleten 1116 | kps[0] = 0 1117 | 1118 | meta["keypoints_left_hand"][:, 0] *= meta["width"] 1119 | meta["keypoints_left_hand"][:, 1] *= meta["height"] 1120 | 1121 | meta["keypoints_right_hand"][:, 0] *= meta["width"] 1122 | meta["keypoints_right_hand"][:, 1] *= meta["height"] 1123 | pose_img = draw_aapose( 1124 | pncc, 1125 | kps, 1126 | draw_hand=True, 1127 | kp2ds_lhand=meta["keypoints_left_hand"], 1128 | kp2ds_rhand=meta["keypoints_right_hand"], 1129 | ) 1130 | return pose_img 1131 | 1132 | 1133 | FACE_CUSTOM_STYLE = { 1134 | "eyeball": {"indexs": [68, 69], "color": [255, 255, 255], "connect": False}, 1135 | "left_eyebrow": {"indexs": [17, 18, 19, 20, 21], "color": [0, 255, 0]}, 1136 | "right_eyebrow": {"indexs": [22, 23, 24, 25, 26], "color": [0, 0, 255]}, 1137 | "left_eye": {"indexs": [36, 37, 38, 39, 40, 41], "color": [255, 255, 0], "close": True}, 1138 | "right_eye": {"indexs": [42, 43, 44, 45, 46, 47], "color": [255, 0, 255], "close": True}, 1139 | "mouth_outside": {"indexs": list(range(48, 60)), "color": [100, 255, 50], "close": True}, 1140 | "mouth_inside": {"indexs": [60, 61, 62, 63, 64, 65, 66, 67], "color": [255, 100, 50], "close": True}, 1141 | } 1142 | 1143 | 1144 | def draw_face_kp(img, kps, thickness=2, style=FACE_CUSTOM_STYLE): 1145 | """ 1146 | Args: 1147 | img: [H, W, 3] 1148 | kps: [70, 2] 1149 | """ 1150 | img = img.copy() 1151 | for key, item in style.items(): 1152 | pts = np.array(kps[item["indexs"]]).astype(np.int32) 1153 | connect = item.get("connect", True) 1154 | color = item["color"] 1155 | close = item.get("close", False) 1156 | if connect: 1157 | cv2.polylines(img, [pts], close, color, thickness=thickness) 1158 | else: 1159 | for kp in pts: 1160 | kp = np.array(kp).astype(np.int32) 1161 | cv2.circle(img, kp, thickness * 2, color=color, 
thickness=-1) 1162 | return img 1163 | 1164 | 1165 | def draw_traj(metas: List[AAPoseMeta], threshold=0.6): 1166 | 1167 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ 1168 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ 1169 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85], [100, 255, 50], [255, 100, 50], 1170 | # foot 1171 | [200, 200, 0], 1172 | [100, 100, 0] 1173 | ] 1174 | limbSeq = [ 1175 | [1, 2], [1, 5], # shoulders 1176 | [2, 3], [3, 4], # left arm 1177 | [5, 6], [6, 7], # right arm 1178 | [1, 8], [8, 9], [9, 10], # right leg 1179 | [1, 11], [11, 12], [12, 13], # left leg 1180 | # face (nose, eyes, ears) 1181 | [13, 18], [10, 19] # foot 1182 | ] 1183 | 1184 | face_seq = [[1, 0], [0, 14], [14, 16], [0, 15], [15, 17]] 1185 | kp_body = np.array([meta.kps_body for meta in metas]) 1186 | kp_body_p = np.array([meta.kps_body_p for meta in metas]) 1187 | 1188 | 1189 | face_seq = random.sample(face_seq, 2) 1190 | 1191 | kp_lh = np.array([meta.kps_lhand for meta in metas]) 1192 | kp_rh = np.array([meta.kps_rhand for meta in metas]) 1193 | 1194 | kp_lh_p = np.array([meta.kps_lhand_p for meta in metas]) 1195 | kp_rh_p = np.array([meta.kps_rhand_p for meta in metas]) 1196 | 1197 | # kp_lh = np.concatenate([kp_lh, kp_lh_p], axis=-1) 1198 | # kp_rh = np.concatenate([kp_rh, kp_rh_p], axis=-1) 1199 | 1200 | new_limbSeq = [] 1201 | key_point_list = [] 1202 | for _idx, ((k1_index, k2_index)) in enumerate(limbSeq): 1203 | 1204 | vis = (kp_body_p[:, k1_index] > threshold) * (kp_body_p[:, k2_index] > threshold) * 1 1205 | if vis.sum() * 1.0 / vis.shape[0] > 0.4: 1206 | new_limbSeq.append([k1_index, k2_index]) 1207 | 1208 | for _idx, ((k1_index, k2_index)) in enumerate(limbSeq): 1209 | 1210 | keypoint1 = kp_body[:, k1_index - 1] 1211 | keypoint2 = kp_body[:, k2_index - 1] 1212 | interleave = random.randint(4, 7) 1213 | randind = random.randint(0, interleave - 1) 1214 | # randind = random.rand(range(interleave), sampling_num) 1215 | 1216 | Y = np.array([keypoint1[:, 0], keypoint2[:, 0]]) 1217 | X = np.array([keypoint1[:, 1], keypoint2[:, 1]]) 1218 | 1219 | vis = (kp_body_p[:, k1_index - 1] > threshold) * (kp_body_p[:, k2_index - 1] > threshold) * 1 # visibility from the confidences of the same joints used above 1220 | 1221 | # for randidx in randind: 1222 | t = randind / interleave 1223 | x = (1-t)*Y[0, :] + t*Y[1, :] 1224 | y = (1-t)*X[0, :] + t*X[1, :] 1225 | 1226 | # np.array([1]) 1227 | x = x.astype(int) 1228 | y = y.astype(int) 1229 | 1230 | new_array = np.array([x, y, vis]).T 1231 | 1232 | key_point_list.append(new_array) 1233 | 1234 | indx_lh = random.randint(0, kp_lh.shape[1] - 1) 1235 | lh = kp_lh[:, indx_lh, :] 1236 | lh_p = kp_lh_p[:, indx_lh:indx_lh+1] 1237 | lh = np.concatenate([lh, lh_p], axis=-1) 1238 | 1239 | indx_rh = random.randint(0, kp_rh.shape[1] - 1) 1240 | rh = kp_rh[:, indx_rh, :] # use the same sampled joint as rh_p below 1241 | rh_p = kp_rh_p[:, indx_rh:indx_rh+1] 1242 | rh = np.concatenate([rh, rh_p], axis=-1) 1243 | 1244 | 1245 | # binarize the confidence column into a 0/1 visibility flag 1246 | lh[:, -1] = (lh[:, -1] > threshold) * 1 1247 | rh[:, -1] = (rh[:, -1] > threshold) * 1 1248 | 1249 | # print(rh.shape, new_array.shape) 1250 | # exit() 1251 | key_point_list.append(lh.astype(int)) 1252 | key_point_list.append(rh.astype(int)) 1253 | 1254 | 1255 | key_points_list = np.stack(key_point_list) 1256 | num_points = len(key_points_list) 1257 | sample_colors = random.sample(colors, num_points) 1258 | 1259 | stickwidth = max(int(min(metas[0].width, metas[0].height) / 150), 2) 1260 | 1261 | image_list_ori
= [] 1262 | for i in range(key_points_list.shape[-2]): 1263 | _image_vis = np.zeros((metas[0].height, metas[0].width, 3), dtype=np.uint8) # canvas laid out as (H, W, 3), matching draw_skeleten 1264 | points = key_points_list[:, i, :] 1265 | for idx, point in enumerate(points): 1266 | x, y, vis = point 1267 | if vis == 1: 1268 | cv2.circle(_image_vis, (x, y), stickwidth, sample_colors[idx], thickness=-1) 1269 | 1270 | image_list_ori.append(_image_vis) 1271 | 1272 | return image_list_ori 1273 | --------------------------------------------------------------------------------
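For orientation, a minimal, hypothetical usage sketch of the body-pose renderer defined above (assuming this file is importable as `pose_utils.human_visualization`; the keypoint values are invented purely for illustration):

import numpy as np
from pose_utils.human_visualization import draw_aapose

# blank canvas to draw on
H, W = 512, 512
canvas = np.zeros((H, W, 3), dtype=np.uint8)

# 20 body keypoints in the (x, y, score) layout draw_aapose expects;
# coordinates are in pixels, and here they are placed on a diagonal just to show the call
kp2ds = np.zeros((20, 3), dtype=np.float32)
kp2ds[:, 0] = np.linspace(100, 400, 20)  # x in pixels
kp2ds[:, 1] = np.linspace(100, 400, 20)  # y in pixels
kp2ds[:, 2] = 1.0                        # confidence above the 0.6 threshold

pose_img = draw_aapose(canvas, kp2ds, threshold=0.6, draw_hand=False)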