# ├── GenMM.py ├── LICENSE ├── README.md ├── __init__.py ├── configs ├── default.yaml └── ganimator.yaml ├── dataset ├── blender_motion.py ├── bvh │ ├── Quaternions.py │ ├── __pycache__ │ │ ├── Quaternions.cpython-37.pyc │ │ ├── bvh_io.cpython-37.pyc │ │ ├── bvh_parser.cpython-37.pyc │ │ └── bvh_writer.cpython-37.pyc │ ├── bvh_io.py │ ├── bvh_parser.py │ └── bvh_writer.py ├── bvh_motion.py ├── motion.py └── tracks_motion.py ├── demo.blend ├── docker ├── Dockerfile ├── README.md ├── apt-sources.list ├── requirements.txt └── requirements_blender.txt ├── fix_contact.py ├── nearest_neighbor ├── losses.py └── utils.py ├── run_random_generation.py ├── run_web_server.py └── utils ├── base.py ├── contact.py ├── kinematics.py ├── skeleton.py └── transforms.py
# /GenMM.py: --------------------------------------------------------------------------------
import os
import os.path as osp
import numpy as np
import torch
import torch.nn.functional as F

from utils.base import logger

class GenMM:
    '''
    Example-based motion synthesizer.

    Synthesis runs coarse-to-fine. A pyramid of increasing lengths is built for
    the target motion(s) and for the output. At every level the current output
    is repeatedly replaced by the blended result of a user-supplied criterion
    (see `match_and_blend`), and it is then linearly upsampled to the next level.
    '''

    def __init__(self, mode = 'random_synthesis', noise_sigma = 1.0, coarse_ratio = 0.2, coarse_ratio_factor = 6, pyr_factor = 0.75, num_stages_limit = -1, device = 'cuda:0', silent = False):
        '''
        GenMM main constructor
        Args:
            device : str = 'cuda:0', device the target pyramid is moved to.
            silent : bool = False, whether to mute the progress/debug prints.

        NOTE: `mode`, `noise_sigma`, `coarse_ratio`, `coarse_ratio_factor`,
        `pyr_factor` and `num_stages_limit` are accepted for backward
        compatibility but are not used by the constructor. The generation
        settings are passed to `run` instead.
        '''
        self.device = torch.device(device)
        self.silent = silent

    def _get_pyramid_lengths(self, final_len, coarse_ratio, pyr_factor):
        '''
        Get a list of pyramid lengths from a target length and ratios.

        The first entry is round(final_len * coarse_ratio). Each next entry is
        the previous one divided by `pyr_factor` (rounded) and is bumped by one
        if rounding left it unchanged. The last entry is always `final_len`.
        Assumes 0 < pyr_factor <= 1, so that the lengths grow towards `final_len`.
        Returns:
            list of int, ordered from coarsest to finest.
        '''
        lengths = [int(np.round(final_len * coarse_ratio))]
        while lengths[-1] < final_len:
            lengths.append(int(np.round(lengths[-1] / pyr_factor)))
            if lengths[-1] == lengths[-2]:
                # guarantee progress when rounding stalls
                lengths[-1] += 1
        # the finest level is exactly the requested length (the loop may overshoot)
        lengths[-1] = final_len

        return lengths

    def _get_target_pyramid(self, target, coarse_ratio, pyr_factor, num_stages_limit=-1):
        '''
        Read the target motion(s) and build a pyramid from them, ordered by increasing size.
        Args:
            target           : non-empty list of motions; each exposes `motion_data` (its
                               length is the frame count) and `sample(size=...)`, which
                               returns a tensor.
            coarse_ratio     : float, ratio of the coarsest level to the full length.
            pyr_factor       : float, ratio between consecutive levels.
            num_stages_limit : int, -1 for no limit, otherwise keep at most this many
                               levels per target.
        Returns:
            list (one entry per level) of lists (one tensor per target), on `self.device`.
        Side effects:
            sets `self.num_target` and `self.pyraimd_lengths` (name kept as-is because
            `run` and possibly external code read it).
        Raises:
            ValueError: if `target` is empty.
        '''
        if len(target) == 0:
            raise ValueError('target must contain at least one motion')

        self.num_target = len(target)
        lengths = []
        min_len = 10000
        for tgt in target:
            new_length = self._get_pyramid_lengths(len(tgt.motion_data), coarse_ratio, pyr_factor)
            min_len = min(min_len, len(new_length))
            if num_stages_limit != -1:
                new_length = new_length[:num_stages_limit]
            lengths.append(new_length)
        # keep only the finest levels so that every target has the same number of levels
        lengths = [tgt_lengths[-min_len:] for tgt_lengths in lengths]
        self.pyraimd_lengths = lengths

        # level-major order, sampling targets in the same order as the inputs
        target_pyramid = [
            [tgt.sample(size=tgt_lengths[step]).to(self.device) for tgt, tgt_lengths in zip(target, lengths)]
            for step in range(len(lengths[0]))
        ]

        if not self.silent:
            print('Levels:', lengths)
            for i in range(len(target_pyramid)):
                print(f'Number of clips in target pyramid {i} is {len(target_pyramid[i])}, ranging {[[tgt.min(), tgt.max()] for tgt in target_pyramid[i]]}')

        return target_pyramid

    def _get_initial_motion(self, init_length, noise_sigma):
        '''
        Prepare the initial motion for optimization.

        The coarsest-level targets are concatenated along the last dimension and
        linearly resampled to `init_length`. If `noise_sigma > 0`, Gaussian noise
        with that standard deviation is added and the result is wrapped with
        `torch.fmod(x, 1.0)`.
        Returns:
            torch.Tensor, the initial (optionally noisy) motion.
        '''
        coarsest_targets = [self.target_pyramid[0][i] for i in range(self.num_target)]
        # NOTE(review): F.interpolate with mode='linear' expects (N, C, L) tensors;
        # presumably `sample()` returns that layout - confirm.
        initial_motion = F.interpolate(torch.cat(coarsest_targets, dim=-1),
                                       size=init_length, mode='linear', align_corners=True)
        if noise_sigma > 0:
            initial_motion_w_noise = initial_motion + torch.randn_like(initial_motion) * noise_sigma
            # fmod keeps values in (-1, 1); presumably the motion data is normalized to
            # that range - confirm.
            initial_motion_w_noise = torch.fmod(initial_motion_w_noise, 1.0)
        else:
            initial_motion_w_noise = initial_motion

        if not self.silent:
            print('Initial motion:', initial_motion.min(), initial_motion.max())
            print('Initial motion with noise:', initial_motion_w_noise.min(), initial_motion_w_noise.max())

        return initial_motion_w_noise

    def run(self, target, criteria, num_frames, num_steps, noise_sigma, patch_size, coarse_ratio, pyr_factor, ext=None, debug_dir=None):
        '''
        Generation function.
        Args:
            target       : non-empty list of target motions; each exposes `motion_data`
                           and `sample(size=...)`.
            criteria     : callable invoked as
                           `criteria(synthesized, targets, ext=ext, return_blended_results=True)`
                           and returning `(blended_motion, loss)`; must also provide `clean_cache()`.
            num_frames   : str '<x>x_nframes' (x times the summed length of the targets),
                           or a digit string / int giving the exact number of frames.
            num_steps    : int, number of match-and-blend steps per pyramid level.
            noise_sigma  : float, std of the noise added to the initial motion (<= 0 disables it).
            patch_size   : int, only used to resolve a '<x>x_patchsize' coarse ratio.
            coarse_ratio : str '<x>x_patchsize' (the coarsest level of the longest target is
                           x * patch_size frames), str '<r>x_nframes' (the coarsest level is r
                           times the full length), or a number used directly as that ratio.
            pyr_factor   : float, ratio between consecutive pyramid lengths.
            ext          : extra configurations or constraints (optional), forwarded to `criteria`.
            debug_dir    : optional directory; when given, per-level losses are written
                           with tensorboardX.
        Returns:
            torch.Tensor, the synthesized motion (detached).
        Raises:
            ValueError: if `coarse_ratio` or `num_frames` has an unsupported format.
        '''
        writer = None
        if debug_dir is not None:
            from tensorboardX import SummaryWriter
            writer = SummaryWriter(log_dir=debug_dir)

        try:
            # build target pyramid
            if isinstance(coarse_ratio, (int, float)):
                coarse_ratio = float(coarse_ratio)
            elif 'patchsize' in coarse_ratio:
                coarse_ratio = patch_size * float(coarse_ratio.split('x_')[0]) / max(len(t.motion_data) for t in target)
            elif 'nframes' in coarse_ratio:
                coarse_ratio = float(coarse_ratio.split('x_')[0])
            else:
                raise ValueError('Unsupported coarse ratio specified')
            self.target_pyramid = self._get_target_pyramid(target, coarse_ratio, pyr_factor)

            # get the initial motion data
            num_frames = str(num_frames)  # also accept a plain int frame count
            if 'nframes' in num_frames:
                syn_length = int(sum(tgt_lengths[-1] for tgt_lengths in self.pyraimd_lengths) * float(num_frames.split('x_')[0]))
            elif num_frames.isdigit():
                syn_length = int(num_frames)
            else:
                # report the offending value (this class stores no `mode` attribute)
                raise ValueError(f'Unsupported num_frames {num_frames!r}')
            self.synthesized_lengths = self._get_pyramid_lengths(syn_length, coarse_ratio, pyr_factor)
            if not self.silent:
                print('Synthesized lengths:', self.synthesized_lengths)
            self.synthesized = self._get_initial_motion(self.synthesized_lengths[0], noise_sigma)

            # perform the optimization
            self.synthesized.requires_grad_(False)
            self.pbar = logger(num_steps, len(self.target_pyramid))
            for lvl, lvl_target in enumerate(self.target_pyramid):
                self.pbar.new_lvl()
                if lvl > 0:
                    # NOTE(review): synthesized_lengths is indexed by the target level, which
                    # assumes both pyramids have the same number of levels. Rounding can make
                    # them differ - confirm.
                    with torch.no_grad():
                        self.synthesized = F.interpolate(self.synthesized.detach(), size=self.synthesized_lengths[lvl], mode='linear')

                self.synthesized, losses = GenMM.match_and_blend(self.synthesized, lvl_target, criteria, num_steps, self.pbar, ext=ext)

                criteria.clean_cache()
                if writer is not None:
                    for itr, loss in enumerate(losses):
                        writer.add_scalar(f'optimize/losses_lvl{lvl}', loss, itr)
            self.pbar.pbar.close()
        finally:
            # flush and release the log writer even if generation fails
            if writer is not None:
                writer.close()

        return self.synthesized.detach()


    @staticmethod
    @torch.no_grad()
    def match_and_blend(synthesized, targets, criteria, n_steps, pbar, ext=None):
        '''
        Minimize `criteria` between the synthesized motion and the targets.
        Args:
            synthesized : torch.Tensor, motion data being optimized
            targets     : list of torch.Tensor, target motion data at the current level
            criteria    : callable returning (blended_motion, loss); the blended output
                          replaces `synthesized` at every step
            n_steps     : int, number of steps to optimize
            pbar        : logger; `step()` and `print()` are called once per step
            ext         : extra configurations or constraints (optional), forwarded to `criteria`
        Returns:
            (torch.Tensor, list of float): the final motion and the per-step losses.
        '''
        losses = []
        for _ in range(n_steps):
            synthesized, loss = criteria(synthesized, targets, ext=ext, return_blended_results=True)

            # Update status
            losses.append(loss.item())
            pbar.step()
            pbar.print()

        return synthesized, losses

# -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License.
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Example-based Motion Synthesis via Generative Motion Matching, ACM Transactions on Graphics (Proceedings of SIGGRAPH 2023) 2 | 3 | #####

[Weiyu Li*](https://wyysf-98.github.io/), [Xuelin Chen*†](https://xuelin-chen.github.io/), [Peizhuo Li](https://peizhuoli.github.io/), [Olga Sorkine-Hornung](https://igl.ethz.ch/people/sorkine/), [Baoquan Chen](https://cfcs.pku.edu.cn/baoquan/)

4 | 5 | ####

[Project Page](https://wyysf-98.github.io/GenMM) | [ArXiv](https://arxiv.org/abs/2306.00378) | [Paper](https://wyysf-98.github.io/GenMM/paper/Paper_high_res.pdf) | [Video](https://youtu.be/lehnxcade4I)

6 | 7 |

8 | 9 |

10 | 11 |

All code and demos will be released this week (still ongoing...) 🏗️ 🚧 🔨

12 | 13 | - [x] Release main code 14 | - [x] Release Blender add-on 15 | - [x] Detailed README and installation guide 16 | - [ ] Release skeleton-aware component (WIP, as we need to split the joints into groups manually). 17 | - [ ] Release evaluation code 18 | 19 | ## Prerequisite 20 | 21 |
Setup environment 22 | 23 | :smiley: We also provide a Dockerfile for easy installation, see [Setup using Docker](./docker/README.md). 24 | 25 | - Python 3.8 26 | - PyTorch 1.12.1 27 | - [unfoldNd](https://github.com/f-dangel/unfoldNd) 28 | 29 | Clone this repository. 30 | 31 | ```sh 32 | git clone git@github.com:wyysf-98/GenMM.git 33 | ``` 34 | 35 | Install the required packages. 36 | 37 | ```sh 38 | conda create -n GenMM python=3.8 39 | conda activate GenMM 40 | conda install -c pytorch pytorch=1.12.1 torchvision=0.13.1 cudatoolkit=11.3 && \ 41 | pip install -r docker/requirements.txt 42 | pip install torch-scatter==2.1.1 43 | ``` 44 | 45 |
46 | 47 | ## Quick inference demo 48 | For a quick local inference demo using a .bvh file, run 49 | 50 | ```sh 51 | python run_random_generation.py -i './data/Malcolm/Gangnam-Style.bvh' 52 | ``` 53 | More configuration options can be found in `run_random_generation.py`. 54 | We use an Apple M1 and an NVIDIA Tesla V100 with 32 GB RAM to generate each motion, which takes about 0.2s and 0.05s respectively, as reported in our paper. 55 | 56 | ## Blender add-on 57 | The Blender add-on is easy to install and use: our method is efficient, so you do not need to install the CUDA Toolkit. 58 | We tested our code with Blender 3.2, and will support 2.8.0 in the future. 59 | 60 | Step 1: Find your Blender Python path. Common paths are as follows: 61 | ```sh 62 | (Windows) 'C:\Program Files\Blender Foundation\Blender 3.2\3.2\python\bin' 63 | (Linux) '/path/to/blender/blender-path/3.2/python/bin' 64 | (macOS) '/Applications/Blender.app/Contents/Resources/3.2/python/bin' 65 | ``` 66 | 67 | Step 2: Install the required packages. Open your shell (Linux/macOS) or PowerShell (Windows) and run 68 | ```sh 69 | cd {your python path} && pip3 install -r docker/requirements.txt && pip3 install torch-scatter==2.1.0 -f https://data.pyg.org/whl/torch-1.12.0+${CUDA}.html 70 | ``` 71 | where ${CUDA} should be replaced by cpu, cu117, or cu118 depending on your PyTorch installation. 72 | For example, on macOS with an Apple M1 CPU: 73 | 74 | ```sh 75 | cd /Applications/Blender.app/Contents/Resources/3.2/python/bin && pip3 install -r docker/requirements_blender.txt && pip3 install torch-scatter==2.1.0 -f https://data.pyg.org/whl/torch-1.12.0+cpu.html 76 | ``` 77 | 78 | Step 3: Install the add-on in Blender (see the [Blender Add-ons Official Tutorial](https://docs.blender.org/manual/en/latest/editors/preferences/addons.html)): `edit -> Preferences -> Add-ons -> Install -> Select the downloaded .zip file` 79 | 80 | Step 4: Have fun! Select the armature and you will find a `GenMM` tab.
81 | 82 | (GPU support) If you have a GPU and the CUDA Toolkit installed, the running device is detected automatically. 83 | 84 | Feel free to submit an issue if you run into any problems during installation :) 85 | 86 | ## Acknowledgement 87 | 88 | We thank [@stefanonuvoli](https://github.com/stefanonuvoli/skinmixer) for the discussion on implementing the `Motion Reassembly` part (we eventually merged the meshes of different characters manually), and [@Radamés Ajna](https://github.com/radames) for help with a better Hugging Face demo. 89 | 90 | 91 | ## Citation 92 | 93 | If you find our work useful for your research, please consider citing it using the following BibTeX entry. 94 | 95 | ```BibTeX 96 | @article{10.1145/weiyu23GenMM, 97 | author = {Li, Weiyu and Chen, Xuelin and Li, Peizhuo and Sorkine-Hornung, Olga and Chen, Baoquan}, 98 | title = {Example-Based Motion Synthesis via Generative Motion Matching}, 99 | journal = {ACM Transactions on Graphics (TOG)}, 100 | volume = {42}, 101 | number = {4}, 102 | year = {2023}, 103 | articleno = {94}, 104 | doi = {10.1145/3592395}, 105 | publisher = {Association for Computing Machinery}, 106 | } 107 | ``` 108 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # This program is free software; you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation; either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # This program is distributed in the hope that it will be useful, but 7 | # WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 9 | # General Public License for more details.
10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with this program. If not, see . 13 | import os 14 | import sys 15 | import bpy 16 | import torch 17 | import mathutils 18 | import numpy as np 19 | from math import degrees, radians, ceil 20 | from mathutils import Vector, Matrix, Euler 21 | from typing import List, Iterable, Tuple, Any, Dict 22 | 23 | abs_path = os.path.abspath(__file__) 24 | sys.path.append(os.path.dirname(abs_path)) 25 | from GenMM import GenMM 26 | from nearest_neighbor.losses import PatchCoherentLoss 27 | from dataset.blender_motion import BlenderMotion 28 | 29 | bl_info = { 30 | "name" : "GenMM", 31 | "author" : "Weiyu Li", 32 | "description" : "Blender addon for SIGGRAPH paper 'Example-Based Motion Synthesis via Generative Motion Matching'", 33 | "blender" : (3, 2, 0), 34 | "version" : (0, 0, 1), 35 | "location": "3D View", 36 | "description": "Synthesis novel motions form a few exemplars.", 37 | "location" : "", 38 | "support": "TESTING", 39 | "warning" : "", 40 | "category" : "Generic" 41 | } 42 | 43 | 44 | # This function is modified from 45 | # https://github.com/bwrsandman/blender-addons/blob/master/io_anim_bvh 46 | def get_bvh_data(context, 47 | frame_end, 48 | frame_start, 49 | global_scale=1.0, 50 | rotate_mode='NATIVE', 51 | root_transform_only=False, 52 | ): 53 | 54 | def ensure_rot_order(rot_order_str): 55 | if set(rot_order_str) != {'X', 'Y', 'Z'}: 56 | rot_order_str = "XYZ" 57 | return rot_order_str 58 | 59 | file_str = [] 60 | 61 | obj = context.object 62 | arm = obj.data 63 | 64 | # Build a dictionary of children. 65 | # None for parentless 66 | children = {None: []} 67 | 68 | # initialize with blank lists 69 | for bone in arm.bones: 70 | children[bone.name] = [] 71 | 72 | # keep bone order from armature, no sorting, not esspential but means 73 | # we can maintain order from import -> export which secondlife incorrectly expects. 
74 | for bone in arm.bones: 75 | children[getattr(bone.parent, "name", None)].append(bone.name) 76 | 77 | # bone name list in the order that the bones are written 78 | serialized_names = [] 79 | 80 | node_locations = {} 81 | 82 | file_str.append("HIERARCHY\n") 83 | 84 | def write_recursive_nodes(bone_name, indent): 85 | my_children = children[bone_name] 86 | 87 | indent_str = "\t" * indent 88 | 89 | bone = arm.bones[bone_name] 90 | pose_bone = obj.pose.bones[bone_name] 91 | loc = bone.head_local 92 | node_locations[bone_name] = loc 93 | 94 | if rotate_mode == "NATIVE": 95 | rot_order_str = ensure_rot_order(pose_bone.rotation_mode) 96 | else: 97 | rot_order_str = rotate_mode 98 | 99 | # make relative if we can 100 | if bone.parent: 101 | loc = loc - node_locations[bone.parent.name] 102 | 103 | if indent: 104 | file_str.append("%sJOINT %s\n" % (indent_str, bone_name)) 105 | else: 106 | file_str.append("%sROOT %s\n" % (indent_str, bone_name)) 107 | 108 | file_str.append("%s{\n" % indent_str) 109 | file_str.append("%s\tOFFSET %.6f %.6f %.6f\n" % (indent_str, loc.x * global_scale, loc.y * global_scale, loc.z * global_scale)) 110 | if (bone.use_connect or root_transform_only) and bone.parent: 111 | file_str.append("%s\tCHANNELS 3 %srotation %srotation %srotation\n" % (indent_str, rot_order_str[0], rot_order_str[1], rot_order_str[2])) 112 | else: 113 | file_str.append("%s\tCHANNELS 6 Xposition Yposition Zposition %srotation %srotation %srotation\n" % (indent_str, rot_order_str[0], rot_order_str[1], rot_order_str[2])) 114 | 115 | if my_children: 116 | # store the location for the children 117 | # to get their relative offset 118 | 119 | # Write children 120 | for child_bone in my_children: 121 | serialized_names.append(child_bone) 122 | write_recursive_nodes(child_bone, indent + 1) 123 | 124 | else: 125 | # Write the bone end. 
126 | file_str.append("%s\tEnd Site\n" % indent_str) 127 | file_str.append("%s\t{\n" % indent_str) 128 | loc = bone.tail_local - node_locations[bone_name] 129 | file_str.append("%s\t\tOFFSET %.6f %.6f %.6f\n" % (indent_str, loc.x * global_scale, loc.y * global_scale, loc.z * global_scale)) 130 | file_str.append("%s\t}\n" % indent_str) 131 | 132 | file_str.append("%s}\n" % indent_str) 133 | 134 | if len(children[None]) == 1: 135 | key = children[None][0] 136 | serialized_names.append(key) 137 | indent = 0 138 | 139 | write_recursive_nodes(key, indent) 140 | 141 | else: 142 | # Write a dummy parent node, with a dummy key name 143 | # Just be sure it's not used by another bone! 144 | i = 0 145 | key = "__%d" % i 146 | while key in children: 147 | i += 1 148 | key = "__%d" % i 149 | file_str.append("ROOT %s\n" % key) 150 | file_str.append("{\n") 151 | file_str.append("\tOFFSET 0.0 0.0 0.0\n") 152 | file_str.append("\tCHANNELS 0\n") # Xposition Yposition Zposition Xrotation Yrotation Zrotation 153 | indent = 1 154 | 155 | # Write children 156 | for child_bone in children[None]: 157 | serialized_names.append(child_bone) 158 | write_recursive_nodes(child_bone, indent) 159 | 160 | file_str.append("}\n") 161 | file_str = ''.join(file_str) 162 | # redefine bones as sorted by serialized_names 163 | # so we can write motion 164 | 165 | class DecoratedBone: 166 | __slots__ = ( 167 | # Bone name, used as key in many places. 168 | "name", 169 | "parent", # decorated bone parent, set in a later loop 170 | # Blender armature bone. 171 | "rest_bone", 172 | # Blender pose bone. 173 | "pose_bone", 174 | # Blender pose matrix. 175 | "pose_mat", 176 | # Blender rest matrix (armature space). 177 | "rest_arm_mat", 178 | # Blender rest matrix (local space). 179 | "rest_local_mat", 180 | # Pose_mat inverted. 181 | "pose_imat", 182 | # Rest_arm_mat inverted. 183 | "rest_arm_imat", 184 | # Rest_local_mat inverted. 
185 | "rest_local_imat", 186 | # Last used euler to preserve euler compatibility in between keyframes. 187 | "prev_euler", 188 | # Is the bone disconnected to the parent bone? 189 | "skip_position", 190 | "rot_order", 191 | "rot_order_str", 192 | # Needed for the euler order when converting from a matrix. 193 | "rot_order_str_reverse", 194 | ) 195 | 196 | _eul_order_lookup = { 197 | 'XYZ': (0, 1, 2), 198 | 'XZY': (0, 2, 1), 199 | 'YXZ': (1, 0, 2), 200 | 'YZX': (1, 2, 0), 201 | 'ZXY': (2, 0, 1), 202 | 'ZYX': (2, 1, 0), 203 | } 204 | 205 | def __init__(self, bone_name): 206 | self.name = bone_name 207 | self.rest_bone = arm.bones[bone_name] 208 | self.pose_bone = obj.pose.bones[bone_name] 209 | 210 | if rotate_mode == "NATIVE": 211 | self.rot_order_str = ensure_rot_order(self.pose_bone.rotation_mode) 212 | else: 213 | self.rot_order_str = rotate_mode 214 | self.rot_order_str_reverse = self.rot_order_str[::-1] 215 | 216 | self.rot_order = DecoratedBone._eul_order_lookup[self.rot_order_str] 217 | 218 | self.pose_mat = self.pose_bone.matrix 219 | 220 | # mat = self.rest_bone.matrix # UNUSED 221 | self.rest_arm_mat = self.rest_bone.matrix_local 222 | self.rest_local_mat = self.rest_bone.matrix 223 | 224 | # inverted mats 225 | self.pose_imat = self.pose_mat.inverted() 226 | self.rest_arm_imat = self.rest_arm_mat.inverted() 227 | self.rest_local_imat = self.rest_local_mat.inverted() 228 | 229 | self.parent = None 230 | self.prev_euler = Euler((0.0, 0.0, 0.0), self.rot_order_str_reverse) 231 | self.skip_position = ((self.rest_bone.use_connect or root_transform_only) and self.rest_bone.parent) 232 | 233 | def update_posedata(self): 234 | self.pose_mat = self.pose_bone.matrix 235 | self.pose_imat = self.pose_mat.inverted() 236 | 237 | def __repr__(self): 238 | if self.parent: 239 | return "[\"%s\" child on \"%s\"]\n" % (self.name, self.parent.name) 240 | else: 241 | return "[\"%s\" root bone]\n" % (self.name) 242 | 243 | bones_decorated = [DecoratedBone(bone_name) for 
bone_name in serialized_names] 244 | 245 | # Assign parents 246 | bones_decorated_dict = {dbone.name: dbone for dbone in bones_decorated} 247 | for dbone in bones_decorated: 248 | parent = dbone.rest_bone.parent 249 | if parent: 250 | dbone.parent = bones_decorated_dict[parent.name] 251 | del bones_decorated_dict 252 | # finish assigning parents 253 | 254 | scene = context.scene 255 | frame_current = scene.frame_current 256 | 257 | file_str += "MOTION\n" 258 | file_str += "Frames: %d\n" % (frame_end - frame_start + 1) 259 | file_str += "Frame Time: %.6f\n" % (1.0 / (scene.render.fps / scene.render.fps_base)) 260 | 261 | for frame in range(frame_start, frame_end + 1): 262 | scene.frame_set(frame) 263 | 264 | for dbone in bones_decorated: 265 | dbone.update_posedata() 266 | 267 | for dbone in bones_decorated: 268 | trans = Matrix.Translation(dbone.rest_bone.head_local) 269 | itrans = Matrix.Translation(-dbone.rest_bone.head_local) 270 | 271 | if dbone.parent: 272 | mat_final = dbone.parent.rest_arm_mat @ dbone.parent.pose_imat @ dbone.pose_mat @ dbone.rest_arm_imat 273 | mat_final = itrans @ mat_final @ trans 274 | loc = mat_final.to_translation() + (dbone.rest_bone.head_local - dbone.parent.rest_bone.head_local) 275 | else: 276 | mat_final = dbone.pose_mat @ dbone.rest_arm_imat 277 | mat_final = itrans @ mat_final @ trans 278 | loc = mat_final.to_translation() + dbone.rest_bone.head 279 | 280 | # keep eulers compatible, no jumping on interpolation. 
281 | rot = mat_final.to_euler(dbone.rot_order_str_reverse, dbone.prev_euler) 282 | 283 | if not dbone.skip_position: 284 | file_str += "%.6f %.6f %.6f " % (loc * global_scale)[:] 285 | 286 | file_str += "%.6f %.6f %.6f " % (degrees(rot[dbone.rot_order[0]]), degrees(rot[dbone.rot_order[1]]), degrees(rot[dbone.rot_order[2]])) 287 | 288 | dbone.prev_euler = rot 289 | 290 | file_str += "\n" 291 | 292 | scene.frame_set(frame_current) 293 | 294 | return file_str 295 | 296 | 297 | class BVH_Node: 298 | __slots__ = ( 299 | # Bvh joint name. 300 | 'name', 301 | # BVH_Node type or None for no parent. 302 | 'parent', 303 | # A list of children of this type.. 304 | 'children', 305 | # Worldspace rest location for the head of this node. 306 | 'rest_head_world', 307 | # Localspace rest location for the head of this node. 308 | 'rest_head_local', 309 | # Worldspace rest location for the tail of this node. 310 | 'rest_tail_world', 311 | # Worldspace rest location for the tail of this node. 312 | 'rest_tail_local', 313 | # List of 6 ints, -1 for an unused channel, 314 | # otherwise an index for the BVH motion data lines, 315 | # loc triple then rot triple. 316 | 'channels', 317 | # A triple of indices as to the order rotation is applied. 318 | # [0,1,2] is x/y/z - [None, None, None] if no rotation.. 319 | 'rot_order', 320 | # Same as above but a string 'XYZ' format.. 321 | 'rot_order_str', 322 | # A list one tuple's one for each frame: (locx, locy, locz, rotx, roty, rotz), 323 | # euler rotation ALWAYS stored xyz order, even when native used. 324 | 'anim_data', 325 | # Convenience function, bool, same as: (channels[0] != -1 or channels[1] != -1 or channels[2] != -1). 326 | 'has_loc', 327 | # Convenience function, bool, same as: (channels[3] != -1 or channels[4] != -1 or channels[5] != -1). 328 | 'has_rot', 329 | # Index from the file, not strictly needed but nice to maintain order. 330 | 'index', 331 | # Use this for whatever you want. 
332 | 'temp', 333 | ) 334 | 335 | _eul_order_lookup = { 336 | (None, None, None): 'XYZ', # XXX Dummy one, no rotation anyway! 337 | (0, 1, 2): 'XYZ', 338 | (0, 2, 1): 'XZY', 339 | (1, 0, 2): 'YXZ', 340 | (1, 2, 0): 'YZX', 341 | (2, 0, 1): 'ZXY', 342 | (2, 1, 0): 'ZYX', 343 | } 344 | 345 | def __init__(self, name, rest_head_world, rest_head_local, parent, channels, rot_order, index): 346 | self.name = name 347 | self.rest_head_world = rest_head_world 348 | self.rest_head_local = rest_head_local 349 | self.rest_tail_world = None 350 | self.rest_tail_local = None 351 | self.parent = parent 352 | self.channels = channels 353 | self.rot_order = tuple(rot_order) 354 | self.rot_order_str = BVH_Node._eul_order_lookup[self.rot_order] 355 | self.index = index 356 | 357 | # convenience functions 358 | self.has_loc = channels[0] != -1 or channels[1] != -1 or channels[2] != -1 359 | self.has_rot = channels[3] != -1 or channels[4] != -1 or channels[5] != -1 360 | 361 | self.children = [] 362 | 363 | # List of 6 length tuples: (lx, ly, lz, rx, ry, rz) 364 | # even if the channels aren't used they will just be zero. 365 | self.anim_data = [(0, 0, 0, 0, 0, 0)] 366 | 367 | def __repr__(self): 368 | return ( 369 | "BVH name: '%s', rest_loc:(%.3f,%.3f,%.3f), rest_tail:(%.3f,%.3f,%.3f)" % ( 370 | self.name, 371 | *self.rest_head_world, 372 | *self.rest_head_world, 373 | ) 374 | ) 375 | 376 | 377 | def sorted_nodes(bvh_nodes): 378 | bvh_nodes_list = list(bvh_nodes.values()) 379 | bvh_nodes_list.sort(key=lambda bvh_node: bvh_node.index) 380 | return bvh_nodes_list 381 | 382 | 383 | def read_bvh(context, bvh_str, rotate_mode='XYZ', global_scale=1.0): 384 | # Separate into a list of lists, each line a list of words. 385 | file_lines = bvh_str 386 | # Non standard carriage returns? 387 | if len(file_lines) == 1: 388 | file_lines = file_lines[0].split('\r') 389 | 390 | # Split by whitespace. 
391 | file_lines = [ll for ll in [l.split() for l in file_lines] if ll] 392 | 393 | # Create hierarchy as empties 394 | if file_lines[0][0].lower() == 'hierarchy': 395 | # print 'Importing the BVH Hierarchy for:', file_path 396 | pass 397 | else: 398 | raise Exception("This is not a BVH file") 399 | 400 | bvh_nodes = {None: None} 401 | bvh_nodes_serial = [None] 402 | bvh_frame_count = None 403 | bvh_frame_time = None 404 | 405 | channelIndex = -1 406 | 407 | lineIdx = 0 # An index for the file. 408 | while lineIdx < len(file_lines) - 1: 409 | if file_lines[lineIdx][0].lower() in {'root', 'joint'}: 410 | 411 | # Join spaces into 1 word with underscores joining it. 412 | if len(file_lines[lineIdx]) > 2: 413 | file_lines[lineIdx][1] = '_'.join(file_lines[lineIdx][1:]) 414 | file_lines[lineIdx] = file_lines[lineIdx][:2] 415 | 416 | # MAY NEED TO SUPPORT MULTIPLE ROOTS HERE! Still unsure weather multiple roots are possible? 417 | 418 | # Make sure the names are unique - Object names will match joint names exactly and both will be unique. 419 | name = file_lines[lineIdx][1] 420 | 421 | # print '%snode: %s, parent: %s' % (len(bvh_nodes_serial) * ' ', name, bvh_nodes_serial[-1]) 422 | 423 | lineIdx += 2 # Increment to the next line (Offset) 424 | rest_head_local = global_scale * Vector(( 425 | float(file_lines[lineIdx][1]), 426 | float(file_lines[lineIdx][2]), 427 | float(file_lines[lineIdx][3]), 428 | )) 429 | lineIdx += 1 # Increment to the next line (Channels) 430 | 431 | # newChannel[Xposition, Yposition, Zposition, Xrotation, Yrotation, Zrotation] 432 | # newChannel references indices to the motiondata, 433 | # if not assigned then -1 refers to the last value that will be added on loading at a value of zero, this is appended 434 | # We'll add a zero value onto the end of the MotionDATA so this always refers to a value. 
435 | my_channel = [-1, -1, -1, -1, -1, -1] 436 | my_rot_order = [None, None, None] 437 | rot_count = 0 438 | for channel in file_lines[lineIdx][2:]: 439 | channel = channel.lower() 440 | channelIndex += 1 # So the index points to the right channel 441 | if channel == 'xposition': 442 | my_channel[0] = channelIndex 443 | elif channel == 'yposition': 444 | my_channel[1] = channelIndex 445 | elif channel == 'zposition': 446 | my_channel[2] = channelIndex 447 | 448 | elif channel == 'xrotation': 449 | my_channel[3] = channelIndex 450 | my_rot_order[rot_count] = 0 451 | rot_count += 1 452 | elif channel == 'yrotation': 453 | my_channel[4] = channelIndex 454 | my_rot_order[rot_count] = 1 455 | rot_count += 1 456 | elif channel == 'zrotation': 457 | my_channel[5] = channelIndex 458 | my_rot_order[rot_count] = 2 459 | rot_count += 1 460 | 461 | channels = file_lines[lineIdx][2:] 462 | 463 | my_parent = bvh_nodes_serial[-1] # account for none 464 | 465 | # Apply the parents offset accumulatively 466 | if my_parent is None: 467 | rest_head_world = Vector(rest_head_local) 468 | else: 469 | rest_head_world = my_parent.rest_head_world + rest_head_local 470 | 471 | bvh_node = bvh_nodes[name] = BVH_Node( 472 | name, 473 | rest_head_world, 474 | rest_head_local, 475 | my_parent, 476 | my_channel, 477 | my_rot_order, 478 | len(bvh_nodes) - 1, 479 | ) 480 | 481 | # If we have another child then we can call ourselves a parent, else 482 | bvh_nodes_serial.append(bvh_node) 483 | 484 | # Account for an end node. 485 | # There is sometimes a name after 'End Site' but we will ignore it. 
486 | if file_lines[lineIdx][0].lower() == 'end' and file_lines[lineIdx][1].lower() == 'site': 487 | # Increment to the next line (Offset) 488 | lineIdx += 2 489 | rest_tail = global_scale * Vector(( 490 | float(file_lines[lineIdx][1]), 491 | float(file_lines[lineIdx][2]), 492 | float(file_lines[lineIdx][3]), 493 | )) 494 | 495 | bvh_nodes_serial[-1].rest_tail_world = bvh_nodes_serial[-1].rest_head_world + rest_tail 496 | bvh_nodes_serial[-1].rest_tail_local = bvh_nodes_serial[-1].rest_head_local + rest_tail 497 | 498 | # Just so we can remove the parents in a uniform way, 499 | # the end has kids so this is a placeholder. 500 | bvh_nodes_serial.append(None) 501 | 502 | if len(file_lines[lineIdx]) == 1 and file_lines[lineIdx][0] == '}': # == ['}'] 503 | bvh_nodes_serial.pop() # Remove the last item 504 | 505 | # End of the hierarchy. Begin the animation section of the file with 506 | # the following header. 507 | # MOTION 508 | # Frames: n 509 | # Frame Time: dt 510 | if len(file_lines[lineIdx]) == 1 and file_lines[lineIdx][0].lower() == 'motion': 511 | lineIdx += 1 # Read frame count. 512 | if ( 513 | len(file_lines[lineIdx]) == 2 and 514 | file_lines[lineIdx][0].lower() == 'frames:' 515 | ): 516 | bvh_frame_count = int(file_lines[lineIdx][1]) 517 | 518 | lineIdx += 1 # Read frame rate. 519 | if ( 520 | len(file_lines[lineIdx]) == 3 and 521 | file_lines[lineIdx][0].lower() == 'frame' and 522 | file_lines[lineIdx][1].lower() == 'time:' 523 | ): 524 | bvh_frame_time = float(file_lines[lineIdx][2]) 525 | 526 | lineIdx += 1 # Set the cursor to the first frame 527 | 528 | break 529 | 530 | lineIdx += 1 531 | 532 | # Remove the None value used for easy parent reference 533 | del bvh_nodes[None] 534 | # Don't use anymore 535 | del bvh_nodes_serial 536 | 537 | # importing world with any order but nicer to maintain order 538 | # second life expects it, which isn't to spec. 
539 | bvh_nodes_list = sorted_nodes(bvh_nodes) 540 | 541 | while lineIdx < len(file_lines): 542 | line = file_lines[lineIdx] 543 | for bvh_node in bvh_nodes_list: 544 | # for bvh_node in bvh_nodes_serial: 545 | lx = ly = lz = rx = ry = rz = 0.0 546 | channels = bvh_node.channels 547 | anim_data = bvh_node.anim_data 548 | if channels[0] != -1: 549 | lx = global_scale * float(line[channels[0]]) 550 | 551 | if channels[1] != -1: 552 | ly = global_scale * float(line[channels[1]]) 553 | 554 | if channels[2] != -1: 555 | lz = global_scale * float(line[channels[2]]) 556 | 557 | if channels[3] != -1 or channels[4] != -1 or channels[5] != -1: 558 | 559 | rx = radians(float(line[channels[3]])) 560 | ry = radians(float(line[channels[4]])) 561 | rz = radians(float(line[channels[5]])) 562 | 563 | # Done importing motion data # 564 | anim_data.append((lx, ly, lz, rx, ry, rz)) 565 | lineIdx += 1 566 | 567 | # Assign children 568 | for bvh_node in bvh_nodes_list: 569 | bvh_node_parent = bvh_node.parent 570 | if bvh_node_parent: 571 | bvh_node_parent.children.append(bvh_node) 572 | 573 | # Now set the tip of each bvh_node 574 | for bvh_node in bvh_nodes_list: 575 | 576 | if not bvh_node.rest_tail_world: 577 | if len(bvh_node.children) == 0: 578 | # could just fail here, but rare BVH files have childless nodes 579 | bvh_node.rest_tail_world = Vector(bvh_node.rest_head_world) 580 | bvh_node.rest_tail_local = Vector(bvh_node.rest_head_local) 581 | elif len(bvh_node.children) == 1: 582 | bvh_node.rest_tail_world = Vector(bvh_node.children[0].rest_head_world) 583 | bvh_node.rest_tail_local = bvh_node.rest_head_local + bvh_node.children[0].rest_head_local 584 | else: 585 | # allow this, see above 586 | # if not bvh_node.children: 587 | # raise Exception("bvh node has no end and no children. 
bad file") 588 | 589 | # Removed temp for now 590 | rest_tail_world = Vector((0.0, 0.0, 0.0)) 591 | rest_tail_local = Vector((0.0, 0.0, 0.0)) 592 | for bvh_node_child in bvh_node.children: 593 | rest_tail_world += bvh_node_child.rest_head_world 594 | rest_tail_local += bvh_node_child.rest_head_local 595 | 596 | bvh_node.rest_tail_world = rest_tail_world * (1.0 / len(bvh_node.children)) 597 | bvh_node.rest_tail_local = rest_tail_local * (1.0 / len(bvh_node.children)) 598 | 599 | # Make sure tail isn't the same location as the head. 600 | if (bvh_node.rest_tail_local - bvh_node.rest_head_local).length <= 0.001 * global_scale: 601 | print("\tzero length node found:", bvh_node.name) 602 | bvh_node.rest_tail_local.y = bvh_node.rest_tail_local.y + global_scale / 10 603 | bvh_node.rest_tail_world.y = bvh_node.rest_tail_world.y + global_scale / 10 604 | 605 | return bvh_nodes, bvh_frame_time, bvh_frame_count 606 | 607 | 608 | def bvh_node_dict2objects(context, bvh_name, bvh_nodes, rotate_mode='NATIVE', frame_start=1, IMPORT_LOOP=False): 609 | 610 | if frame_start < 1: 611 | frame_start = 1 612 | 613 | scene = context.scene 614 | for obj in scene.objects: 615 | obj.select_set(False) 616 | 617 | objects = [] 618 | 619 | def add_ob(name): 620 | obj = bpy.data.objects.new(name, None) 621 | context.collection.objects.link(obj) 622 | objects.append(obj) 623 | obj.select_set(True) 624 | 625 | # nicer drawing. 
626 | obj.empty_display_type = 'CUBE' 627 | obj.empty_display_size = 0.1 628 | 629 | return obj 630 | 631 | # Add objects 632 | for name, bvh_node in bvh_nodes.items(): 633 | bvh_node.temp = add_ob(name) 634 | bvh_node.temp.rotation_mode = bvh_node.rot_order_str[::-1] 635 | 636 | # Parent the objects 637 | for bvh_node in bvh_nodes.values(): 638 | for bvh_node_child in bvh_node.children: 639 | bvh_node_child.temp.parent = bvh_node.temp 640 | 641 | # Offset 642 | for bvh_node in bvh_nodes.values(): 643 | # Make relative to parents offset 644 | bvh_node.temp.location = bvh_node.rest_head_local 645 | 646 | # Add tail objects 647 | for name, bvh_node in bvh_nodes.items(): 648 | if not bvh_node.children: 649 | ob_end = add_ob(name + '_end') 650 | ob_end.parent = bvh_node.temp 651 | ob_end.location = bvh_node.rest_tail_world - bvh_node.rest_head_world 652 | 653 | for name, bvh_node in bvh_nodes.items(): 654 | obj = bvh_node.temp 655 | 656 | for frame_current in range(len(bvh_node.anim_data)): 657 | 658 | lx, ly, lz, rx, ry, rz = bvh_node.anim_data[frame_current] 659 | 660 | if bvh_node.has_loc: 661 | obj.delta_location = Vector((lx, ly, lz)) - bvh_node.rest_head_world 662 | obj.keyframe_insert("delta_location", index=-1, frame=frame_start + frame_current) 663 | 664 | if bvh_node.has_rot: 665 | obj.delta_rotation_euler = rx, ry, rz 666 | obj.keyframe_insert("delta_rotation_euler", index=-1, frame=frame_start + frame_current) 667 | 668 | return objects 669 | 670 | 671 | def bvh_node_dict2armature( 672 | context, 673 | bvh_name, 674 | bvh_nodes, 675 | bvh_frame_time, 676 | rotate_mode='XYZ', 677 | frame_start=1, 678 | IMPORT_LOOP=False, 679 | global_matrix=None, 680 | use_fps_scale=False, 681 | ): 682 | 683 | if frame_start < 1: 684 | frame_start = 1 685 | 686 | # Add the new armature, 687 | scene = context.scene 688 | for obj in scene.objects: 689 | obj.select_set(False) 690 | 691 | arm_data = bpy.data.armatures.new(bvh_name) 692 | arm_ob = bpy.data.objects.new(bvh_name, 
arm_data) 693 | 694 | context.collection.objects.link(arm_ob) 695 | 696 | arm_ob.select_set(True) 697 | context.view_layer.objects.active = arm_ob 698 | 699 | bpy.ops.object.mode_set(mode='OBJECT', toggle=False) 700 | bpy.ops.object.mode_set(mode='EDIT', toggle=False) 701 | 702 | bvh_nodes_list = sorted_nodes(bvh_nodes) 703 | 704 | # Get the average bone length for zero length bones, we may not use this. 705 | average_bone_length = 0.0 706 | nonzero_count = 0 707 | for bvh_node in bvh_nodes_list: 708 | l = (bvh_node.rest_head_local - bvh_node.rest_tail_local).length 709 | if l: 710 | average_bone_length += l 711 | nonzero_count += 1 712 | 713 | # Very rare cases all bones could be zero length??? 714 | if not average_bone_length: 715 | average_bone_length = 0.1 716 | else: 717 | # Normal operation 718 | average_bone_length = average_bone_length / nonzero_count 719 | 720 | # XXX, annoying, remove bone. 721 | while arm_data.edit_bones: 722 | arm_ob.edit_bones.remove(arm_data.edit_bones[-1]) 723 | 724 | ZERO_AREA_BONES = [] 725 | for bvh_node in bvh_nodes_list: 726 | 727 | # New editbone 728 | bone = bvh_node.temp = arm_data.edit_bones.new(bvh_node.name) 729 | 730 | bone.head = bvh_node.rest_head_world 731 | bone.tail = bvh_node.rest_tail_world 732 | 733 | # Zero Length Bones! (an exceptional case) 734 | if (bone.head - bone.tail).length < 0.001: 735 | print("\tzero length bone found:", bone.name) 736 | if bvh_node.parent: 737 | ofs = bvh_node.parent.rest_head_local - bvh_node.parent.rest_tail_local 738 | if ofs.length: # is our parent zero length also?? 
unlikely 739 | bone.tail = bone.tail - ofs 740 | else: 741 | bone.tail.y = bone.tail.y + average_bone_length 742 | else: 743 | bone.tail.y = bone.tail.y + average_bone_length 744 | 745 | ZERO_AREA_BONES.append(bone.name) 746 | 747 | for bvh_node in bvh_nodes_list: 748 | if bvh_node.parent: 749 | # bvh_node.temp is the Editbone 750 | 751 | # Set the bone parent 752 | bvh_node.temp.parent = bvh_node.parent.temp 753 | 754 | # Set the connection state 755 | if( 756 | (not bvh_node.has_loc) and 757 | (bvh_node.parent.temp.name not in ZERO_AREA_BONES) and 758 | (bvh_node.parent.rest_tail_local == bvh_node.rest_head_local) 759 | ): 760 | bvh_node.temp.use_connect = True 761 | 762 | # Replace the editbone with the editbone name, 763 | # to avoid memory errors accessing the editbone outside editmode 764 | for bvh_node in bvh_nodes_list: 765 | bvh_node.temp = bvh_node.temp.name 766 | 767 | # Now Apply the animation to the armature 768 | 769 | # Get armature animation data 770 | bpy.ops.object.mode_set(mode='OBJECT', toggle=False) 771 | 772 | pose = arm_ob.pose 773 | pose_bones = pose.bones 774 | 775 | if rotate_mode == 'NATIVE': 776 | for bvh_node in bvh_nodes_list: 777 | bone_name = bvh_node.temp # may not be the same name as the bvh_node, could have been shortened. 
778 | pose_bone = pose_bones[bone_name] 779 | pose_bone.rotation_mode = bvh_node.rot_order_str 780 | 781 | elif rotate_mode != 'QUATERNION': 782 | for pose_bone in pose_bones: 783 | pose_bone.rotation_mode = rotate_mode 784 | else: 785 | # Quats default 786 | pass 787 | 788 | context.view_layer.update() 789 | 790 | arm_ob.animation_data_create() 791 | action = bpy.data.actions.new(name=bvh_name) 792 | arm_ob.animation_data.action = action 793 | 794 | # Replace the bvh_node.temp (currently an editbone) 795 | # With a tuple (pose_bone, armature_bone, bone_rest_matrix, bone_rest_matrix_inv) 796 | num_frame = 0 797 | for bvh_node in bvh_nodes_list: 798 | bone_name = bvh_node.temp # may not be the same name as the bvh_node, could have been shortened. 799 | pose_bone = pose_bones[bone_name] 800 | rest_bone = arm_data.bones[bone_name] 801 | bone_rest_matrix = rest_bone.matrix_local.to_3x3() 802 | 803 | bone_rest_matrix_inv = Matrix(bone_rest_matrix) 804 | bone_rest_matrix_inv.invert() 805 | 806 | bone_rest_matrix_inv.resize_4x4() 807 | bone_rest_matrix.resize_4x4() 808 | bvh_node.temp = (pose_bone, bone, bone_rest_matrix, bone_rest_matrix_inv) 809 | 810 | if 0 == num_frame: 811 | num_frame = len(bvh_node.anim_data) 812 | 813 | # Choose to skip some frames at the beginning. Frame 0 is the rest pose 814 | # used internally by this importer. Frame 1, by convention, is also often 815 | # the rest pose of the skeleton exported by the motion capture system. 816 | skip_frame = 1 817 | if num_frame > skip_frame: 818 | num_frame = num_frame - skip_frame 819 | 820 | # Create a shared time axis for all animation curves. 
821 | time = [float(frame_start)] * num_frame 822 | if use_fps_scale: 823 | dt = scene.render.fps * bvh_frame_time 824 | for frame_i in range(1, num_frame): 825 | time[frame_i] += float(frame_i) * dt 826 | else: 827 | for frame_i in range(1, num_frame): 828 | time[frame_i] += float(frame_i) 829 | 830 | # print("bvh_frame_time = %f, dt = %f, num_frame = %d" 831 | # % (bvh_frame_time, dt, num_frame])) 832 | 833 | for i, bvh_node in enumerate(bvh_nodes_list): 834 | pose_bone, bone, bone_rest_matrix, bone_rest_matrix_inv = bvh_node.temp 835 | 836 | if bvh_node.has_loc: 837 | # Not sure if there is a way to query this or access it in the 838 | # PoseBone structure. 839 | data_path = 'pose.bones["%s"].location' % pose_bone.name 840 | 841 | location = [(0.0, 0.0, 0.0)] * num_frame 842 | for frame_i in range(num_frame): 843 | bvh_loc = bvh_node.anim_data[frame_i + skip_frame][:3] 844 | 845 | bone_translate_matrix = Matrix.Translation( 846 | Vector(bvh_loc) - bvh_node.rest_head_local) 847 | location[frame_i] = (bone_rest_matrix_inv @ 848 | bone_translate_matrix).to_translation() 849 | 850 | # For each location x, y, z. 
851 | for axis_i in range(3): 852 | curve = action.fcurves.new(data_path=data_path, index=axis_i, action_group=bvh_node.name) 853 | keyframe_points = curve.keyframe_points 854 | keyframe_points.add(num_frame) 855 | 856 | for frame_i in range(num_frame): 857 | keyframe_points[frame_i].co = ( 858 | time[frame_i], 859 | location[frame_i][axis_i], 860 | ) 861 | 862 | if bvh_node.has_rot: 863 | data_path = None 864 | rotate = None 865 | 866 | if 'QUATERNION' == rotate_mode: 867 | rotate = [(1.0, 0.0, 0.0, 0.0)] * num_frame 868 | data_path = ('pose.bones["%s"].rotation_quaternion' 869 | % pose_bone.name) 870 | else: 871 | rotate = [(0.0, 0.0, 0.0)] * num_frame 872 | data_path = ('pose.bones["%s"].rotation_euler' % 873 | pose_bone.name) 874 | 875 | prev_euler = Euler((0.0, 0.0, 0.0)) 876 | for frame_i in range(num_frame): 877 | bvh_rot = bvh_node.anim_data[frame_i + skip_frame][3:] 878 | 879 | # apply rotation order and convert to XYZ 880 | # note that the rot_order_str is reversed. 881 | euler = Euler(bvh_rot, bvh_node.rot_order_str[::-1]) 882 | bone_rotation_matrix = euler.to_matrix().to_4x4() 883 | bone_rotation_matrix = ( 884 | bone_rest_matrix_inv @ 885 | bone_rotation_matrix @ 886 | bone_rest_matrix 887 | ) 888 | 889 | if len(rotate[frame_i]) == 4: 890 | rotate[frame_i] = bone_rotation_matrix.to_quaternion() 891 | else: 892 | rotate[frame_i] = bone_rotation_matrix.to_euler( 893 | pose_bone.rotation_mode, prev_euler) 894 | prev_euler = rotate[frame_i] 895 | 896 | # For each euler angle x, y, z (or quaternion w, x, y, z). 
897 | for axis_i in range(len(rotate[0])): 898 | curve = action.fcurves.new(data_path=data_path, index=axis_i, action_group=bvh_node.name) 899 | keyframe_points = curve.keyframe_points 900 | keyframe_points.add(num_frame) 901 | 902 | for frame_i in range(num_frame): 903 | keyframe_points[frame_i].co = ( 904 | time[frame_i], 905 | rotate[frame_i][axis_i], 906 | ) 907 | 908 | for cu in action.fcurves: 909 | if IMPORT_LOOP: 910 | pass # 2.5 doenst have cyclic now? 911 | 912 | for bez in cu.keyframe_points: 913 | bez.interpolation = 'LINEAR' 914 | 915 | # finally apply matrix 916 | try: 917 | arm_ob.matrix_world = global_matrix 918 | except: 919 | pass 920 | bpy.ops.object.transform_apply(location=False, rotation=True, scale=False) 921 | 922 | return arm_ob 923 | 924 | 925 | def load( 926 | context, 927 | bvh_str, 928 | *, 929 | target='ARMATURE', 930 | rotate_mode='NATIVE', 931 | global_scale=1.0, 932 | use_cyclic=False, 933 | frame_start=1, 934 | global_matrix=None, 935 | use_fps_scale=False, 936 | update_scene_fps=False, 937 | update_scene_duration=False, 938 | report=print, 939 | ): 940 | import time 941 | t1 = time.time() 942 | 943 | bvh_nodes, bvh_frame_time, bvh_frame_count = read_bvh( 944 | context, bvh_str, 945 | rotate_mode=rotate_mode, 946 | global_scale=global_scale, 947 | ) 948 | 949 | print("%.4f" % (time.time() - t1)) 950 | 951 | scene = context.scene 952 | frame_orig = scene.frame_current 953 | 954 | # Broken BVH handling: guess frame rate when it is not contained in the file. 955 | if bvh_frame_time is None: 956 | report( 957 | {'WARNING'}, 958 | "The BVH file does not contain frame duration in its MOTION " 959 | "section, assuming the BVH and Blender scene have the same " 960 | "frame rate" 961 | ) 962 | bvh_frame_time = scene.render.fps_base / scene.render.fps 963 | # No need to scale the frame rate, as they're equal now anyway. 
964 | use_fps_scale = False 965 | 966 | if update_scene_fps: 967 | _update_scene_fps(context, report, bvh_frame_time) 968 | 969 | # Now that we have a 1-to-1 mapping of Blender frames and BVH frames, there is no need 970 | # to scale the FPS any more. It's even better not to, to prevent roundoff errors. 971 | use_fps_scale = False 972 | 973 | if update_scene_duration: 974 | _update_scene_duration(context, report, bvh_frame_count, bvh_frame_time, frame_start, use_fps_scale) 975 | 976 | t1 = time.time() 977 | print("\timporting to blender...", end="") 978 | 979 | bvh_name = bpy.path.display_name_from_filepath('synsized') 980 | 981 | if target == 'ARMATURE': 982 | bvh_node_dict2armature( 983 | context, bvh_name, bvh_nodes, bvh_frame_time, 984 | rotate_mode=rotate_mode, 985 | frame_start=frame_start, 986 | IMPORT_LOOP=use_cyclic, 987 | global_matrix=global_matrix, 988 | use_fps_scale=use_fps_scale, 989 | ) 990 | 991 | elif target == 'OBJECT': 992 | bvh_node_dict2objects( 993 | context, bvh_name, bvh_nodes, 994 | rotate_mode=rotate_mode, 995 | frame_start=frame_start, 996 | IMPORT_LOOP=use_cyclic, 997 | # global_matrix=global_matrix, # TODO 998 | ) 999 | 1000 | else: 1001 | report({'ERROR'}, tip_("Invalid target %r (must be 'ARMATURE' or 'OBJECT')") % target) 1002 | return {'CANCELLED'} 1003 | 1004 | print('Done in %.4f\n' % (time.time() - t1)) 1005 | 1006 | context.scene.frame_set(frame_orig) 1007 | 1008 | return {'FINISHED'} 1009 | 1010 | 1011 | def _update_scene_fps(context, report, bvh_frame_time): 1012 | """Update the scene's FPS settings from the BVH, but only if the BVH contains enough info.""" 1013 | 1014 | # Broken BVH handling: prevent division by zero. 
1015 | if bvh_frame_time == 0.0: 1016 | report( 1017 | {'WARNING'}, 1018 | "Unable to update scene frame rate, as the BVH file " 1019 | "contains a zero frame duration in its MOTION section", 1020 | ) 1021 | return 1022 | 1023 | scene = context.scene 1024 | scene_fps = scene.render.fps / scene.render.fps_base 1025 | new_fps = 1.0 / bvh_frame_time 1026 | 1027 | if scene.render.fps != new_fps or scene.render.fps_base != 1.0: 1028 | print("\tupdating scene FPS (was %f) to BVH FPS (%f)" % (scene_fps, new_fps)) 1029 | scene.render.fps = int(round(new_fps)) 1030 | scene.render.fps_base = scene.render.fps / new_fps 1031 | 1032 | 1033 | def _update_scene_duration( 1034 | context, report, bvh_frame_count, bvh_frame_time, frame_start, 1035 | use_fps_scale): 1036 | """Extend the scene's duration so that the BVH file fits in its entirety.""" 1037 | 1038 | if bvh_frame_count is None: 1039 | report( 1040 | {'WARNING'}, 1041 | "Unable to extend the scene duration, as the BVH file does not " 1042 | "contain the number of frames in its MOTION section", 1043 | ) 1044 | return 1045 | 1046 | # Not likely, but it can happen when a BVH is just used to store an armature. 1047 | if bvh_frame_count == 0: 1048 | return 1049 | 1050 | if use_fps_scale: 1051 | scene_fps = context.scene.render.fps / context.scene.render.fps_base 1052 | scaled_frame_count = int(ceil(bvh_frame_count * bvh_frame_time * scene_fps)) 1053 | bvh_last_frame = frame_start + scaled_frame_count 1054 | else: 1055 | bvh_last_frame = frame_start + bvh_frame_count 1056 | 1057 | # Only extend the scene, never shorten it. 
1058 | if context.scene.frame_end < bvh_last_frame: 1059 | context.scene.frame_end = bvh_last_frame 1060 | 1061 | 1062 | # This function is from 1063 | # https://github.com/yuki-koyama/blender-cli-rendering 1064 | def set_smooth_shading(mesh: bpy.types.Mesh) -> None: 1065 | for polygon in mesh.polygons: 1066 | polygon.use_smooth = True 1067 | 1068 | 1069 | # This function is from 1070 | # https://github.com/yuki-koyama/blender-cli-rendering 1071 | def create_mesh_from_pydata(scene: bpy.types.Scene, 1072 | vertices: Iterable[Iterable[float]], 1073 | faces: Iterable[Iterable[int]], 1074 | mesh_name: str, 1075 | object_name: str, 1076 | use_smooth: bool = True) -> bpy.types.Object: 1077 | # Add a new mesh and set vertices and faces 1078 | # Note: In this case, it does not require to set edges. 1079 | # Note: After manipulating mesh data, update() needs to be called. 1080 | new_mesh: bpy.types.Mesh = bpy.data.meshes.new(mesh_name) 1081 | new_mesh.from_pydata(vertices, [], faces) 1082 | new_mesh.update() 1083 | if use_smooth: 1084 | set_smooth_shading(new_mesh) 1085 | 1086 | new_object: bpy.types.Object = bpy.data.objects.new(object_name, new_mesh) 1087 | scene.collection.objects.link(new_object) 1088 | 1089 | return new_object 1090 | 1091 | 1092 | # This function is from 1093 | # https://github.com/yuki-koyama/blender-cli-rendering 1094 | def add_subdivision_surface_modifier(mesh_object: bpy.types.Object, level: int, is_simple: bool = False) -> None: 1095 | ''' 1096 | https://docs.blender.org/api/current/bpy.types.SubsurfModifier.html 1097 | ''' 1098 | 1099 | modifier: bpy.types.SubsurfModifier = mesh_object.modifiers.new(name="Subsurf", type='SUBSURF') 1100 | 1101 | modifier.levels = level 1102 | modifier.render_levels = level 1103 | modifier.subdivision_type = 'SIMPLE' if is_simple else 'CATMULL_CLARK' 1104 | 1105 | 1106 | # This function is from 1107 | # https://github.com/yuki-koyama/blender-cli-rendering 1108 | def create_armature_mesh(scene: bpy.types.Scene, 
armature_object: bpy.types.Object, mesh_name: str) -> bpy.types.Object: 1109 | assert armature_object.type == 'ARMATURE', 'Error' 1110 | assert len(armature_object.data.bones) != 0, 'Error' 1111 | 1112 | def add_rigid_vertex_group(target_object: bpy.types.Object, name: str, vertex_indices: Iterable[int]) -> None: 1113 | new_vertex_group = target_object.vertex_groups.new(name=name) 1114 | for vertex_index in vertex_indices: 1115 | new_vertex_group.add([vertex_index], 1.0, 'REPLACE') 1116 | 1117 | def generate_bone_mesh_pydata(radius: float, length: float) -> Tuple[List[mathutils.Vector], List[List[int]]]: 1118 | base_radius = radius 1119 | top_radius = 0.5 * radius 1120 | 1121 | vertices = [ 1122 | # Cross section of the base part 1123 | mathutils.Vector((-base_radius, 0.0, +base_radius)), 1124 | mathutils.Vector((+base_radius, 0.0, +base_radius)), 1125 | mathutils.Vector((+base_radius, 0.0, -base_radius)), 1126 | mathutils.Vector((-base_radius, 0.0, -base_radius)), 1127 | 1128 | # Cross section of the top part 1129 | mathutils.Vector((-top_radius, length, +top_radius)), 1130 | mathutils.Vector((+top_radius, length, +top_radius)), 1131 | mathutils.Vector((+top_radius, length, -top_radius)), 1132 | mathutils.Vector((-top_radius, length, -top_radius)), 1133 | 1134 | # End points 1135 | mathutils.Vector((0.0, -base_radius, 0.0)), 1136 | mathutils.Vector((0.0, length + top_radius, 0.0)) 1137 | ] 1138 | 1139 | faces = [ 1140 | # End point for the base part 1141 | [8, 1, 0], 1142 | [8, 2, 1], 1143 | [8, 3, 2], 1144 | [8, 0, 3], 1145 | 1146 | # End point for the top part 1147 | [9, 4, 5], 1148 | [9, 5, 6], 1149 | [9, 6, 7], 1150 | [9, 7, 4], 1151 | 1152 | # Side faces 1153 | [0, 1, 5, 4], 1154 | [1, 2, 6, 5], 1155 | [2, 3, 7, 6], 1156 | [3, 0, 4, 7], 1157 | ] 1158 | 1159 | return vertices, faces 1160 | 1161 | armature_data: bpy.types.Armature = armature_object.data 1162 | 1163 | vertices: List[mathutils.Vector] = [] 1164 | faces: List[List[int]] = [] 1165 | vertex_groups: 
List[Dict[str, Any]] = [] 1166 | 1167 | for bone in armature_data.bones: 1168 | radius = 0.10 * (0.10 + bone.length) 1169 | temp_vertices, temp_faces = generate_bone_mesh_pydata(radius, bone.length) 1170 | 1171 | vertex_index_offset = len(vertices) 1172 | 1173 | temp_vertex_group = {'name': bone.name, 'vertex_indices': []} 1174 | for local_index, vertex in enumerate(temp_vertices): 1175 | vertices.append(bone.matrix_local @ vertex) 1176 | temp_vertex_group['vertex_indices'].append(local_index + vertex_index_offset) 1177 | vertex_groups.append(temp_vertex_group) 1178 | 1179 | for face in temp_faces: 1180 | if len(face) == 3: 1181 | faces.append([ 1182 | face[0] + vertex_index_offset, 1183 | face[1] + vertex_index_offset, 1184 | face[2] + vertex_index_offset, 1185 | ]) 1186 | else: 1187 | faces.append([ 1188 | face[0] + vertex_index_offset, 1189 | face[1] + vertex_index_offset, 1190 | face[2] + vertex_index_offset, 1191 | face[3] + vertex_index_offset, 1192 | ]) 1193 | 1194 | new_object = create_mesh_from_pydata(scene, vertices, faces, mesh_name, mesh_name) 1195 | new_object.matrix_world = armature_object.matrix_world 1196 | 1197 | for vertex_group in vertex_groups: 1198 | add_rigid_vertex_group(new_object, vertex_group['name'], vertex_group['vertex_indices']) 1199 | 1200 | armature_modifier = new_object.modifiers.new('Armature', 'ARMATURE') 1201 | armature_modifier.object = armature_object 1202 | armature_modifier.use_vertex_groups = True 1203 | 1204 | add_subdivision_surface_modifier(new_object, 1, is_simple=True) 1205 | add_subdivision_surface_modifier(new_object, 2, is_simple=False) 1206 | 1207 | # Set the armature as the parent of the new object 1208 | bpy.ops.object.select_all(action='DESELECT') 1209 | new_object.select_set(True) 1210 | armature_object.select_set(True) 1211 | bpy.context.view_layer.objects.active = armature_object 1212 | bpy.ops.object.parent_set(type='OBJECT') 1213 | 1214 | return new_object 1215 | 1216 | 1217 | class 
OP_AddMesh(bpy.types.Operator): 1218 | 1219 | bl_idname = "genmm.add_mesh" 1220 | bl_label = "Add mesh" 1221 | bl_description = "" 1222 | bl_options = {"REGISTER", "UNDO"} 1223 | 1224 | def __init__(self) -> None: 1225 | super().__init__() 1226 | 1227 | def execute(self, context: bpy.types.Context): 1228 | name = bpy.context.object.name + "_proxy" 1229 | create_armature_mesh(bpy.context.scene, bpy.context.object, name) 1230 | 1231 | return {'FINISHED'} 1232 | 1233 | 1234 | class OP_RunSynthesis(bpy.types.Operator): 1235 | 1236 | bl_idname = "genmm.run_synthesis" 1237 | bl_label = "Run synthesis" 1238 | bl_description = "" 1239 | bl_options = {"REGISTER", "UNDO"} 1240 | 1241 | def __init__(self) -> None: 1242 | super().__init__() 1243 | 1244 | def execute(self, context: bpy.types.Context): 1245 | setting = context.scene.setting 1246 | 1247 | anim = context.object.animation_data.action 1248 | start_frame, end_frame = map(int, anim.frame_range) 1249 | start_frame = start_frame if setting.start_frame == -1 else start_frame 1250 | end_frame = end_frame if setting.end_frame == -1 else end_frame 1251 | 1252 | bvh_str = get_bvh_data(context, frame_start=start_frame, frame_end=end_frame) 1253 | frames_str, frame_time_str = bvh_str.split('MOTION\n')[1].split('\n')[:2] 1254 | motion_data_str = bvh_str.split('MOTION\n')[1].split('\n')[2:-1] 1255 | motion_data = np.array([item.strip().split(' ') for item in motion_data_str], dtype=np.float32) 1256 | 1257 | motion = [BlenderMotion(motion_data, repr='repr6d', use_velo=True, keep_up_pos=True, up_axis=setting.up_axis, padding_last=False)] 1258 | model = GenMM(device='cuda' if torch.cuda.is_available() else 'cpu', silent=True) 1259 | criteria = PatchCoherentLoss(patch_size=setting.patch_size, 1260 | alpha=setting.alpha, 1261 | loop=setting.loop, cache=True) 1262 | 1263 | syn = model.run(motion, criteria, 1264 | num_frames=str(setting.num_syn_frames), 1265 | num_steps=setting.num_steps, 1266 | noise_sigma=setting.noise, 1267 | 
patch_size=setting.patch_size, 1268 | coarse_ratio=f'{setting.coarse_ratio}x_nframes', 1269 | pyr_factor=setting.pyr_factor) 1270 | motion_data_str = [' '.join(str(x) for x in item) for item in motion[0].parse(syn)] 1271 | 1272 | load(context, bvh_str.split('MOTION\n')[0].split('\n')+['MOTION']+[frames_str]+[frame_time_str]+motion_data_str) 1273 | # name = bpy.context.object.name + "_proxy" 1274 | # create_armature_mesh(bpy.context.scene, bpy.context.object, name) 1275 | 1276 | return {'FINISHED'} 1277 | 1278 | 1279 | class GENMM_PT_ControlPanel(bpy.types.Panel): 1280 | 1281 | bl_label = "GenMM" 1282 | bl_space_type = 'VIEW_3D' 1283 | bl_region_type = 'UI' 1284 | bl_category = "GenMM" 1285 | 1286 | @classmethod 1287 | def poll(cls, context: bpy.types.Context): 1288 | return True 1289 | 1290 | def draw_header(self, context: bpy.types.Context): 1291 | layout = self.layout 1292 | layout.label(text="", icon='PLUGIN') 1293 | 1294 | def draw(self, context: bpy.types.Context): 1295 | layout = self.layout 1296 | scene = bpy.context.scene 1297 | 1298 | ops: List[bpy.type.Operator] = [ 1299 | OP_AddMesh, 1300 | ] 1301 | for op in ops: 1302 | layout.operator(op.bl_idname, text=op.bl_label) 1303 | 1304 | box = layout.box() 1305 | box.label(text="Exemplar config:") 1306 | exemplar_row = box.row() 1307 | exemplar_row.prop(scene.setting, "start_frame") 1308 | exemplar_row.prop(scene.setting, "end_frame") 1309 | exemplar_row = box.row() 1310 | exemplar_row.prop(scene.setting, "up_axis") 1311 | 1312 | box = layout.box() 1313 | box.label(text="Synthesis config:") 1314 | box.prop(scene.setting, "loop") 1315 | box.prop(scene.setting, "noise") 1316 | box.prop(scene.setting, "num_syn_frames") 1317 | box.prop(scene.setting, "patch_size") 1318 | box.prop(scene.setting, "coarse_ratio") 1319 | box.prop(scene.setting, "pyr_factor") 1320 | box.prop(scene.setting, "alpha") 1321 | box.prop(scene.setting, "num_steps") 1322 | 1323 | ops: List[bpy.type.Operator] = [ 1324 | OP_RunSynthesis, 1325 | 
] 1326 | for op in ops: 1327 | layout.operator(op.bl_idname, text=op.bl_label) 1328 | 1329 | 1330 | class PropertyGroup(bpy.types.PropertyGroup): 1331 | '''Property container for options and paths of GenMM''' 1332 | start_frame: bpy.props.IntProperty( 1333 | name="Start Frame", 1334 | description="Start Frame of the Exemplar Moition.", 1335 | default=1) 1336 | end_frame: bpy.props.IntProperty( 1337 | name="End Frame", 1338 | description="End Frame of the Exemplar Moition.", 1339 | default=-1) 1340 | up_axis: bpy.props.EnumProperty( 1341 | name="Up Axis", 1342 | default='Z_UP', 1343 | description="Up axis of the Exemplar Moition", 1344 | items=[('Z_UP', "Z-Up", 'Z Up'), 1345 | ('Y_UP', "Y-Up", 'Y Up'), 1346 | ('X_UP', "X-Up", 'X Up'), 1347 | ] 1348 | ) 1349 | noise: bpy.props.FloatProperty( 1350 | name="Noise Intensity", 1351 | description="Intensity of Noise Added to the Synthesized Motion.", 1352 | default=10) 1353 | num_syn_frames: bpy.props.IntProperty( 1354 | name="Num. of Frames", 1355 | description="Number of the Synthesized Motion.", 1356 | default=600) 1357 | patch_size: bpy.props.IntProperty( 1358 | name="Patch Size", 1359 | description="Size for Patch Extraction.", 1360 | min=7, 1361 | default=15) 1362 | coarse_ratio: bpy.props.FloatProperty( 1363 | name="Coarse Ratio", 1364 | description="Ratio of the Coarest Pyramid.", 1365 | min=0.0, 1366 | default=0.2) 1367 | pyr_factor: bpy.props.FloatProperty( 1368 | name="Pyramid Factor", 1369 | description="Pyramid Downsample Factor.", 1370 | min=0.1, 1371 | default=0.75) 1372 | alpha: bpy.props.FloatProperty( 1373 | name="Completeness Alpha", 1374 | description="Alpha Value for Completeness/Diversity Trade-off.", 1375 | default=0.05) 1376 | loop: bpy.props.BoolProperty( 1377 | name="Endless Loop", 1378 | description="Whether to Use Loop Constrain.", 1379 | default=False) 1380 | num_steps: bpy.props.IntProperty( 1381 | name="Num of Steps", 1382 | description="Number of Optimized Steps.", 1383 | default=5) 1384 | 
1385 | 1386 | classes = [ 1387 | OP_AddMesh, 1388 | OP_RunSynthesis, 1389 | GENMM_PT_ControlPanel, 1390 | ] 1391 | 1392 | 1393 | def register(): 1394 | bpy.utils.register_class(PropertyGroup) 1395 | bpy.types.Scene.setting = bpy.props.PointerProperty(type=PropertyGroup) 1396 | 1397 | for cls in classes: 1398 | bpy.utils.register_class(cls) 1399 | 1400 | 1401 | def unregister(): 1402 | bpy.utils.unregister_class(PropertyGroup) 1403 | 1404 | for cls in classes: 1405 | bpy.utils.unregister_class(cls) 1406 | 1407 | 1408 | if __name__ == "__main__": 1409 | register() 1410 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | # motion data config 2 | repr: 'repr6d' 3 | skeleton_name: null 4 | use_velo: true 5 | keep_up_pos: true 6 | up_axis: 'Y_UP' 7 | padding_last: false 8 | requires_contact: false 9 | joint_reduction: false 10 | skeleton_aware: false 11 | joints_group: null 12 | 13 | # generate parameters 14 | num_frames: '2x_nframes' 15 | alpha: 0.01 16 | num_steps: 3 17 | noise_sigma: 10.0 18 | coarse_ratio: '5x_patchsize' 19 | # coarse_ratio: '0.2x_nframes' 20 | pyr_factor: 0.75 21 | num_stages_limit: -1 22 | patch_size: 11 23 | loop: false -------------------------------------------------------------------------------- /configs/ganimator.yaml: -------------------------------------------------------------------------------- 1 | ################################################################ 2 | # This configuration uses the same input format of GANimmator for generation 3 | ################################################################ 4 | outout_dir: './output/ganimator_format' 5 | 6 | # for GANimator BVH data 7 | repr: 'repr6d' 8 | skeleton_name: 'mixamo' 9 | use_velo: true 10 | keep_up_pos: true 11 | up_axis: 'Y_UP' 12 | padding_last: true 13 | requires_contact: true 14 | joint_reduction: true 15 | skeleton_aware: false 16 | 
joints_group: null 17 | 18 | # generate parameters 19 | num_frames: '2x_nframes' 20 | alpha: 0.01 21 | num_steps: 3 22 | noise_sigma: 10.0 23 | coarse_ratio: '3x_patchsize' 24 | # coarse_ratio: '0.1x_nframes' 25 | pyr_factor: 0.75 26 | num_stages_limit: -1 27 | patch_size: 11 28 | loop: false -------------------------------------------------------------------------------- /dataset/blender_motion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import torch 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from .motion import MotionData 7 | from utils.transforms import quat2repr6d, euler2mat, mat2quat, repr6d2quat, quat2euler 8 | 9 | class BlenderMotion: 10 | def __init__(self, motion_data, repr='quat', use_velo=True, keep_up_pos=True, up_axis=None, padding_last=False): 11 | ''' 12 | BVHMotion constructor 13 | Args: 14 | motion_data : np.array, bvh format data to load from 15 | repr : string, rotation representation, support ['quat', 'repr6d', 'euler'] 16 | use_velo : book, whether to transform the joints positions to velocities 17 | keep_up_pos : bool, whether to keep y position when converting to velocity 18 | up_axis : string, up axis of the motion data 19 | padding_last : bool, whether to pad the last position 20 | requires_contact : bool, whether to concatenate contact information 21 | ''' 22 | self.motion_data = motion_data 23 | 24 | def to_tensor(motion_data, repr='euler', rot_only=False): 25 | if repr not in ['euler', 'quat', 'quaternion', 'repr6d']: 26 | raise Exception('Unknown rotation representation') 27 | if repr == 'quaternion' or repr == 'quat' or repr == 'repr6d': # default is euler for blender data 28 | rotations = torch.tensor(motion_data[:, 3:], dtype=torch.float).view(motion_data.shape[0], -1, 3) 29 | if repr == 'quat': 30 | rotations = euler2mat(rotations) 31 | rotations = mat2quat(rotations) 32 | if repr == 'repr6d': 33 | rotations = 
euler2mat(rotations) 34 | rotations = mat2quat(rotations) 35 | rotations = quat2repr6d(rotations) 36 | 37 | positions = torch.tensor(motion_data[:, :3], dtype=torch.float32) 38 | 39 | if rot_only: 40 | return rotations.reshape(rotations.shape[0], -1) 41 | 42 | rotations = rotations.reshape(rotations.shape[0], -1) 43 | return torch.cat((rotations, positions), dim=-1) 44 | 45 | self.motion_data = MotionData(to_tensor(motion_data, repr=repr).permute(1, 0).unsqueeze(0), repr=repr, use_velo=use_velo, 46 | keep_up_pos=keep_up_pos, up_axis=up_axis, padding_last=padding_last, contact_id=None) 47 | @property 48 | def repr(self): 49 | return self.motion_data.repr 50 | 51 | @property 52 | def use_velo(self): 53 | return self.motion_data.use_velo 54 | 55 | @property 56 | def keep_up_pos(self): 57 | return self.motion_data.keep_up_pos 58 | 59 | @property 60 | def padding_last(self): 61 | return self.motion_data.padding_last 62 | 63 | @property 64 | def concat_id(self): 65 | return self.motion_data.contact_id 66 | 67 | @property 68 | def n_pad(self): 69 | return self.motion_data.n_pad 70 | 71 | @property 72 | def n_contact(self): 73 | return self.motion_data.n_contact 74 | 75 | @property 76 | def n_rot(self): 77 | return self.motion_data.n_rot 78 | 79 | def sample(self, size=None, slerp=False): 80 | ''' 81 | Sample motion data, support slerp 82 | ''' 83 | return self.motion_data.sample(size, slerp) 84 | 85 | def parse(self, motion, keep_velo=False,): 86 | """ 87 | No batch support here!!! 
88 | :returns tracks_json 89 | """ 90 | motion = motion.clone() 91 | 92 | if self.use_velo and not keep_velo: 93 | motion = self.motion_data.to_position(motion) 94 | if self.n_pad: 95 | motion = motion[:, :-self.n_pad] 96 | 97 | motion = motion.squeeze().permute(1, 0) 98 | pos = motion[..., -3:] 99 | rot = motion[..., :-3].reshape(motion.shape[0], -1, self.n_rot) 100 | if self.repr == 'quat': 101 | rot = quat2euler(rot) 102 | elif self.repr == 'repr6d': 103 | rot = repr6d2quat(rot) 104 | rot = quat2euler(rot) 105 | 106 | return torch.cat([pos, rot.view(motion.shape[0], -1)], dim=-1).cpu().numpy() 107 | -------------------------------------------------------------------------------- /dataset/bvh/Quaternions.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from: 3 | http://theorangeduck.com/page/deep-learning-framework-character-motion-synthesis-and-editing 4 | 5 | by Daniel Holden et al 6 | """ 7 | 8 | 9 | import numpy as np 10 | 11 | class Quaternions: 12 | """ 13 | Quaternions is a wrapper around a numpy ndarray 14 | that allows it to act as if it were an narray of 15 | a quater data type. 16 | 17 | Therefore addition, subtraction, multiplication, 18 | division, negation, absolute, are all defined 19 | in terms of quater operations such as quater 20 | multiplication. 21 | 22 | This allows for much neater code and many routines 23 | which conceptually do the same thing to be written 24 | in the same way for point data and for rotation data. 25 | 26 | The Quaternions class has been desgined such that it 27 | should support broadcasting and slicing in all of the 28 | usual ways. 
29 | """ 30 | 31 | def __init__(self, qs): 32 | if isinstance(qs, np.ndarray): 33 | if len(qs.shape) == 1: qs = np.array([qs]) 34 | self.qs = qs 35 | return 36 | 37 | if isinstance(qs, Quaternions): 38 | self.qs = qs 39 | return 40 | 41 | raise TypeError('Quaternions must be constructed from iterable, numpy array, or Quaternions, not %s' % type(qs)) 42 | 43 | def __str__(self): return "Quaternions("+ str(self.qs) + ")" 44 | def __repr__(self): return "Quaternions("+ repr(self.qs) + ")" 45 | 46 | """ Helper Methods for Broadcasting and Data extraction """ 47 | 48 | @classmethod 49 | def _broadcast(cls, sqs, oqs, scalar=False): 50 | if isinstance(oqs, float): return sqs, oqs * np.ones(sqs.shape[:-1]) 51 | 52 | ss = np.array(sqs.shape) if not scalar else np.array(sqs.shape[:-1]) 53 | os = np.array(oqs.shape) 54 | 55 | if len(ss) != len(os): 56 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape)) 57 | 58 | if np.all(ss == os): return sqs, oqs 59 | 60 | if not np.all((ss == os) | (os == np.ones(len(os))) | (ss == np.ones(len(ss)))): 61 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape)) 62 | 63 | sqsn, oqsn = sqs.copy(), oqs.copy() 64 | 65 | for a in np.where(ss == 1)[0]: sqsn = sqsn.repeat(os[a], axis=a) 66 | for a in np.where(os == 1)[0]: oqsn = oqsn.repeat(ss[a], axis=a) 67 | 68 | return sqsn, oqsn 69 | 70 | """ Adding Quaterions is just Defined as Multiplication """ 71 | 72 | def __add__(self, other): return self * other 73 | def __sub__(self, other): return self / other 74 | 75 | """ Quaterion Multiplication """ 76 | 77 | def __mul__(self, other): 78 | """ 79 | Quaternion multiplication has three main methods. 80 | 81 | When multiplying a Quaternions array by Quaternions 82 | normal quater multiplication is performed. 
83 | 84 | When multiplying a Quaternions array by a vector 85 | array of the same shape, where the last axis is 3, 86 | it is assumed to be a Quaternion by 3D-Vector 87 | multiplication and the 3D-Vectors are rotated 88 | in space by the Quaternions. 89 | 90 | When multipplying a Quaternions array by a scalar 91 | or vector of different shape it is assumed to be 92 | a Quaternions by Scalars multiplication and the 93 | Quaternions are scaled using Slerp and the identity 94 | quaternions. 95 | """ 96 | 97 | """ If Quaternions type do Quaternions * Quaternions """ 98 | if isinstance(other, Quaternions): 99 | sqs, oqs = Quaternions._broadcast(self.qs, other.qs) 100 | 101 | q0 = sqs[...,0]; q1 = sqs[...,1]; 102 | q2 = sqs[...,2]; q3 = sqs[...,3]; 103 | r0 = oqs[...,0]; r1 = oqs[...,1]; 104 | r2 = oqs[...,2]; r3 = oqs[...,3]; 105 | 106 | qs = np.empty(sqs.shape) 107 | qs[...,0] = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3 108 | qs[...,1] = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2 109 | qs[...,2] = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1 110 | qs[...,3] = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0 111 | 112 | return Quaternions(qs) 113 | 114 | """ If array type do Quaternions * Vectors """ 115 | if isinstance(other, np.ndarray) and other.shape[-1] == 3: 116 | vs = Quaternions(np.concatenate([np.zeros(other.shape[:-1] + (1,)), other], axis=-1)) 117 | 118 | return (self * (vs * -self)).imaginaries 119 | 120 | """ If float do Quaternions * Scalars """ 121 | if isinstance(other, np.ndarray) or isinstance(other, float): 122 | return Quaternions.slerp(Quaternions.id_like(self), self, other) 123 | 124 | raise TypeError('Cannot multiply/add Quaternions with type %s' % str(type(other))) 125 | 126 | def __div__(self, other): 127 | """ 128 | When a Quaternion type is supplied, division is defined 129 | as multiplication by the inverse of that Quaternion. 130 | 131 | When a scalar or vector is supplied it is defined 132 | as multiplicaion of one over the supplied value. 
133 | Essentially a scaling. 134 | """ 135 | 136 | if isinstance(other, Quaternions): return self * (-other) 137 | if isinstance(other, np.ndarray): return self * (1.0 / other) 138 | if isinstance(other, float): return self * (1.0 / other) 139 | raise TypeError('Cannot divide/subtract Quaternions with type %s' + str(type(other))) 140 | 141 | def __eq__(self, other): return self.qs == other.qs 142 | def __ne__(self, other): return self.qs != other.qs 143 | 144 | def __neg__(self): 145 | """ Invert Quaternions """ 146 | return Quaternions(self.qs * np.array([[1, -1, -1, -1]])) 147 | 148 | def __abs__(self): 149 | """ Unify Quaternions To Single Pole """ 150 | qabs = self.normalized().copy() 151 | top = np.sum(( qabs.qs) * np.array([1,0,0,0]), axis=-1) 152 | bot = np.sum((-qabs.qs) * np.array([1,0,0,0]), axis=-1) 153 | qabs.qs[top < bot] = -qabs.qs[top < bot] 154 | return qabs 155 | 156 | def __iter__(self): return iter(self.qs) 157 | def __len__(self): return len(self.qs) 158 | 159 | def __getitem__(self, k): return Quaternions(self.qs[k]) 160 | def __setitem__(self, k, v): self.qs[k] = v.qs 161 | 162 | @property 163 | def lengths(self): 164 | return np.sum(self.qs**2.0, axis=-1)**0.5 165 | 166 | @property 167 | def reals(self): 168 | return self.qs[...,0] 169 | 170 | @property 171 | def imaginaries(self): 172 | return self.qs[...,1:4] 173 | 174 | @property 175 | def shape(self): return self.qs.shape[:-1] 176 | 177 | def repeat(self, n, **kwargs): 178 | return Quaternions(self.qs.repeat(n, **kwargs)) 179 | 180 | def normalized(self): 181 | return Quaternions(self.qs / self.lengths[...,np.newaxis]) 182 | 183 | def log(self): 184 | norm = abs(self.normalized()) 185 | imgs = norm.imaginaries 186 | lens = np.sqrt(np.sum(imgs**2, axis=-1)) 187 | lens = np.arctan2(lens, norm.reals) / (lens + 1e-10) 188 | return imgs * lens[...,np.newaxis] 189 | 190 | def constrained(self, axis): 191 | 192 | rl = self.reals 193 | im = np.sum(axis * self.imaginaries, axis=-1) 194 | 195 | t1 
= -2 * np.arctan2(rl, im) + np.pi 196 | t2 = -2 * np.arctan2(rl, im) - np.pi 197 | 198 | top = Quaternions.exp(axis[np.newaxis] * (t1[:,np.newaxis] / 2.0)) 199 | bot = Quaternions.exp(axis[np.newaxis] * (t2[:,np.newaxis] / 2.0)) 200 | img = self.dot(top) > self.dot(bot) 201 | 202 | ret = top.copy() 203 | ret[ img] = top[ img] 204 | ret[~img] = bot[~img] 205 | return ret 206 | 207 | def constrained_x(self): return self.constrained(np.array([1,0,0])) 208 | def constrained_y(self): return self.constrained(np.array([0,1,0])) 209 | def constrained_z(self): return self.constrained(np.array([0,0,1])) 210 | 211 | def dot(self, q): return np.sum(self.qs * q.qs, axis=-1) 212 | 213 | def copy(self): return Quaternions(np.copy(self.qs)) 214 | 215 | def reshape(self, s): 216 | self.qs.reshape(s) 217 | return self 218 | 219 | def interpolate(self, ws): 220 | return Quaternions.exp(np.average(abs(self).log, axis=0, weights=ws)) 221 | 222 | def euler(self, order='xyz'): 223 | 224 | q = self.normalized().qs 225 | q0 = q[...,0] 226 | q1 = q[...,1] 227 | q2 = q[...,2] 228 | q3 = q[...,3] 229 | es = np.zeros(self.shape + (3,)) 230 | 231 | # These version is wrong on converting 232 | ''' 233 | if order == 'xyz': 234 | es[...,0] = np.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 235 | es[...,1] = np.arcsin((2 * (q0 * q2 - q3 * q1)).clip(-1,1)) 236 | es[...,2] = np.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 237 | elif order == 'yzx': 238 | es[...,0] = np.arctan2(2 * (q1 * q0 - q2 * q3), -q1 * q1 + q2 * q2 - q3 * q3 + q0 * q0) 239 | es[...,1] = np.arctan2(2 * (q2 * q0 - q1 * q3), q1 * q1 - q2 * q2 - q3 * q3 + q0 * q0) 240 | es[...,2] = np.arcsin((2 * (q1 * q2 + q3 * q0)).clip(-1,1)) 241 | else: 242 | raise NotImplementedError('Cannot convert from ordering %s' % order) 243 | 244 | ''' 245 | 246 | if order == 'xyz': 247 | es[..., 2] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) 248 | es[..., 1] = np.arcsin((2 * (q1 * q3 + q0 
* q2)).clip(-1,1)) 249 | es[..., 0] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) 250 | else: 251 | raise NotImplementedError('Cannot convert from ordering %s' % order) 252 | 253 | # These conversion don't appear to work correctly for Maya. 254 | # http://bediyap.com/programming/convert-quaternion-to-euler-rotations/ 255 | ''' 256 | if order == 'xyz': 257 | es[..., 0] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) 258 | es[..., 1] = np.arcsin((2 * (q1 * q3 + q0 * q2)).clip(-1,1)) 259 | es[..., 2] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) 260 | elif order == 'yzx': 261 | es[fa + (0,)] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) 262 | es[fa + (1,)] = np.arcsin((2 * (q1 * q2 + q0 * q3)).clip(-1,1)) 263 | es[fa + (2,)] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) 264 | elif order == 'zxy': 265 | es[fa + (0,)] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) 266 | es[fa + (1,)] = np.arcsin((2 * (q0 * q1 + q2 * q3)).clip(-1,1)) 267 | es[fa + (2,)] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) 268 | elif order == 'xzy': 269 | es[fa + (0,)] = np.arctan2(2 * (q0 * q2 + q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) 270 | es[fa + (1,)] = np.arcsin((2 * (q0 * q3 - q1 * q2)).clip(-1,1)) 271 | es[fa + (2,)] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) 272 | elif order == 'yxz': 273 | es[fa + (0,)] = np.arctan2(2 * (q1 * q2 + q0 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) 274 | es[fa + (1,)] = np.arcsin((2 * (q0 * q1 - q2 * q3)).clip(-1,1)) 275 | es[fa + (2,)] = np.arctan2(2 * (q1 * q3 + q0 * q2), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) 276 | elif order == 'zyx': 277 | es[fa + (0,)] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) 278 | es[fa + (1,)] = np.arcsin((2 * (q0 * q2 - q1 * q3)).clip(-1,1)) 
279 | es[fa + (2,)] = np.arctan2(2 * (q0 * q3 + q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) 280 | 281 | else: 282 | raise KeyError('Unknown ordering %s' % order) 283 | ''' 284 | 285 | 286 | # https://github.com/ehsan/ogre/blob/master/OgreMain/src/OgreMatrix3.cpp 287 | # Use this class and convert from matrix 288 | 289 | return es 290 | 291 | 292 | def average(self): 293 | 294 | if len(self.shape) == 1: 295 | 296 | import numpy.core.umath_tests as ut 297 | system = ut.matrix_multiply(self.qs[:,:,np.newaxis], self.qs[:,np.newaxis,:]).sum(axis=0) 298 | w, v = np.linalg.eigh(system) 299 | qiT_dot_qref = (self.qs[:,:,np.newaxis] * v[np.newaxis,:,:]).sum(axis=1) 300 | return Quaternions(v[:,np.argmin((1.-qiT_dot_qref**2).sum(axis=0))]) 301 | 302 | else: 303 | 304 | raise NotImplementedError('Cannot average multi-dimensionsal Quaternions') 305 | 306 | def angle_axis(self): 307 | 308 | norm = self.normalized() 309 | s = np.sqrt(1 - (norm.reals**2.0)) 310 | s[s == 0] = 0.001 311 | 312 | angles = 2.0 * np.arccos(norm.reals) 313 | axis = norm.imaginaries / s[...,np.newaxis] 314 | 315 | return angles, axis 316 | 317 | 318 | def transforms(self): 319 | 320 | qw = self.qs[...,0] 321 | qx = self.qs[...,1] 322 | qy = self.qs[...,2] 323 | qz = self.qs[...,3] 324 | 325 | x2 = qx + qx; y2 = qy + qy; z2 = qz + qz; 326 | xx = qx * x2; yy = qy * y2; wx = qw * x2; 327 | xy = qx * y2; yz = qy * z2; wy = qw * y2; 328 | xz = qx * z2; zz = qz * z2; wz = qw * z2; 329 | 330 | m = np.empty(self.shape + (3,3)) 331 | m[...,0,0] = 1.0 - (yy + zz) 332 | m[...,0,1] = xy - wz 333 | m[...,0,2] = xz + wy 334 | m[...,1,0] = xy + wz 335 | m[...,1,1] = 1.0 - (xx + zz) 336 | m[...,1,2] = yz - wx 337 | m[...,2,0] = xz - wy 338 | m[...,2,1] = yz + wx 339 | m[...,2,2] = 1.0 - (xx + yy) 340 | 341 | return m 342 | 343 | def ravel(self): 344 | return self.qs.ravel() 345 | 346 | @classmethod 347 | def id(cls, n): 348 | 349 | if isinstance(n, tuple): 350 | qs = np.zeros(n + (4,)) 351 | qs[...,0] = 1.0 352 | 
return Quaternions(qs) 353 | 354 | if isinstance(n, int) or isinstance(n, long): 355 | qs = np.zeros((n,4)) 356 | qs[:,0] = 1.0 357 | return Quaternions(qs) 358 | 359 | raise TypeError('Cannot Construct Quaternion from %s type' % str(type(n))) 360 | 361 | @classmethod 362 | def id_like(cls, a): 363 | qs = np.zeros(a.shape + (4,)) 364 | qs[...,0] = 1.0 365 | return Quaternions(qs) 366 | 367 | @classmethod 368 | def exp(cls, ws): 369 | 370 | ts = np.sum(ws**2.0, axis=-1)**0.5 371 | ts[ts == 0] = 0.001 372 | ls = np.sin(ts) / ts 373 | 374 | qs = np.empty(ws.shape[:-1] + (4,)) 375 | qs[...,0] = np.cos(ts) 376 | qs[...,1] = ws[...,0] * ls 377 | qs[...,2] = ws[...,1] * ls 378 | qs[...,3] = ws[...,2] * ls 379 | 380 | return Quaternions(qs).normalized() 381 | 382 | @classmethod 383 | def slerp(cls, q0s, q1s, a): 384 | 385 | fst, snd = cls._broadcast(q0s.qs, q1s.qs) 386 | fst, a = cls._broadcast(fst, a, scalar=True) 387 | snd, a = cls._broadcast(snd, a, scalar=True) 388 | 389 | len = np.sum(fst * snd, axis=-1) 390 | 391 | neg = len < 0.0 392 | len[neg] = -len[neg] 393 | snd[neg] = -snd[neg] 394 | 395 | amount0 = np.zeros(a.shape) 396 | amount1 = np.zeros(a.shape) 397 | 398 | linear = (1.0 - len) < 0.01 399 | omegas = np.arccos(len[~linear]) 400 | sinoms = np.sin(omegas) 401 | 402 | amount0[ linear] = 1.0 - a[linear] 403 | amount1[ linear] = a[linear] 404 | amount0[~linear] = np.sin((1.0 - a[~linear]) * omegas) / sinoms 405 | amount1[~linear] = np.sin( a[~linear] * omegas) / sinoms 406 | 407 | return Quaternions( 408 | amount0[...,np.newaxis] * fst + 409 | amount1[...,np.newaxis] * snd) 410 | 411 | @classmethod 412 | def between(cls, v0s, v1s): 413 | a = np.cross(v0s, v1s) 414 | w = np.sqrt((v0s**2).sum(axis=-1) * (v1s**2).sum(axis=-1)) + (v0s * v1s).sum(axis=-1) 415 | return Quaternions(np.concatenate([w[...,np.newaxis], a], axis=-1)).normalized() 416 | 417 | @classmethod 418 | def from_angle_axis(cls, angles, axis): 419 | axis = axis / (np.sqrt(np.sum(axis**2, axis=-1)) + 
1e-10)[...,np.newaxis] 420 | sines = np.sin(angles / 2.0)[...,np.newaxis] 421 | cosines = np.cos(angles / 2.0)[...,np.newaxis] 422 | return Quaternions(np.concatenate([cosines, axis * sines], axis=-1)) 423 | 424 | @classmethod 425 | def from_euler(cls, es, order='xyz', world=False): 426 | 427 | axis = { 428 | 'x' : np.array([1,0,0]), 429 | 'y' : np.array([0,1,0]), 430 | 'z' : np.array([0,0,1]), 431 | } 432 | 433 | q0s = Quaternions.from_angle_axis(es[...,0], axis[order[0]]) 434 | q1s = Quaternions.from_angle_axis(es[...,1], axis[order[1]]) 435 | q2s = Quaternions.from_angle_axis(es[...,2], axis[order[2]]) 436 | 437 | return (q2s * (q1s * q0s)) if world else (q0s * (q1s * q2s)) 438 | 439 | @classmethod 440 | def from_transforms(cls, ts): 441 | 442 | d0, d1, d2 = ts[...,0,0], ts[...,1,1], ts[...,2,2] 443 | 444 | q0 = ( d0 + d1 + d2 + 1.0) / 4.0 445 | q1 = ( d0 - d1 - d2 + 1.0) / 4.0 446 | q2 = (-d0 + d1 - d2 + 1.0) / 4.0 447 | q3 = (-d0 - d1 + d2 + 1.0) / 4.0 448 | 449 | q0 = np.sqrt(q0.clip(0,None)) 450 | q1 = np.sqrt(q1.clip(0,None)) 451 | q2 = np.sqrt(q2.clip(0,None)) 452 | q3 = np.sqrt(q3.clip(0,None)) 453 | 454 | c0 = (q0 >= q1) & (q0 >= q2) & (q0 >= q3) 455 | c1 = (q1 >= q0) & (q1 >= q2) & (q1 >= q3) 456 | c2 = (q2 >= q0) & (q2 >= q1) & (q2 >= q3) 457 | c3 = (q3 >= q0) & (q3 >= q1) & (q3 >= q2) 458 | 459 | q1[c0] *= np.sign(ts[c0,2,1] - ts[c0,1,2]) 460 | q2[c0] *= np.sign(ts[c0,0,2] - ts[c0,2,0]) 461 | q3[c0] *= np.sign(ts[c0,1,0] - ts[c0,0,1]) 462 | 463 | q0[c1] *= np.sign(ts[c1,2,1] - ts[c1,1,2]) 464 | q2[c1] *= np.sign(ts[c1,1,0] + ts[c1,0,1]) 465 | q3[c1] *= np.sign(ts[c1,0,2] + ts[c1,2,0]) 466 | 467 | q0[c2] *= np.sign(ts[c2,0,2] - ts[c2,2,0]) 468 | q1[c2] *= np.sign(ts[c2,1,0] + ts[c2,0,1]) 469 | q3[c2] *= np.sign(ts[c2,2,1] + ts[c2,1,2]) 470 | 471 | q0[c3] *= np.sign(ts[c3,1,0] - ts[c3,0,1]) 472 | q1[c3] *= np.sign(ts[c3,2,0] + ts[c3,0,2]) 473 | q2[c3] *= np.sign(ts[c3,2,1] + ts[c3,1,2]) 474 | 475 | qs = np.empty(ts.shape[:-2] + (4,)) 476 | qs[...,0] = 
q0 477 | qs[...,1] = q1 478 | qs[...,2] = q2 479 | qs[...,3] = q3 480 | 481 | return cls(qs) 482 | -------------------------------------------------------------------------------- /dataset/bvh/__pycache__/Quaternions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyysf-98/GenMM/fae65cdf199da8a25c4b28ceef20636b534269aa/dataset/bvh/__pycache__/Quaternions.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/bvh/__pycache__/bvh_io.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyysf-98/GenMM/fae65cdf199da8a25c4b28ceef20636b534269aa/dataset/bvh/__pycache__/bvh_io.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/bvh/__pycache__/bvh_parser.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyysf-98/GenMM/fae65cdf199da8a25c4b28ceef20636b534269aa/dataset/bvh/__pycache__/bvh_parser.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/bvh/__pycache__/bvh_writer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyysf-98/GenMM/fae65cdf199da8a25c4b28ceef20636b534269aa/dataset/bvh/__pycache__/bvh_writer.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/bvh/bvh_io.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from: 3 | http://theorangeduck.com/page/deep-learning-framework-character-motion-synthesis-and-editing 4 | 5 | by Daniel Holden et al 6 | """ 7 | 8 | 9 | import re 10 | import numpy as np 11 | from dataset.bvh.Quaternions import Quaternions 12 | 13 | channelmap = { 14 | 
'Xrotation' : 'x', 15 | 'Yrotation' : 'y', 16 | 'Zrotation' : 'z' 17 | } 18 | 19 | channelmap_inv = { 20 | 'x': 'Xrotation', 21 | 'y': 'Yrotation', 22 | 'z': 'Zrotation', 23 | } 24 | 25 | ordermap = { 26 | 'x': 0, 27 | 'y': 1, 28 | 'z': 2, 29 | } 30 | 31 | 32 | class Animation: 33 | def __init__(self, rotations, positions, orients, offsets, parents, names, frametime): 34 | self.rotations = rotations 35 | self.positions = positions 36 | self.orients = orients 37 | self.offsets = offsets 38 | self.parent = parents 39 | self.names = names 40 | self.frametime = frametime 41 | 42 | @property 43 | def shape(self): 44 | return self.rotations.shape 45 | 46 | 47 | def load(filename, start=None, end=None, order=None, world=False, need_quater=False) -> Animation: 48 | """ 49 | Reads a BVH file and constructs an animation 50 | 51 | Parameters 52 | ---------- 53 | filename: str 54 | File to be opened 55 | 56 | start : int 57 | Optional Starting Frame 58 | 59 | end : int 60 | Optional Ending Frame 61 | 62 | order : str 63 | Optional Specifier for joint order. 
64 | Given as string E.G 'xyz', 'zxy' 65 | 66 | world : bool 67 | If set to true euler angles are applied 68 | together in world space rather than local 69 | space 70 | Returns 71 | ------- 72 | 73 | (animation, joint_names, frametime) 74 | Tuple of loaded animation and joint names 75 | """ 76 | 77 | f = open(filename, "r") 78 | 79 | i = 0 80 | active = -1 81 | end_site = False 82 | 83 | names = [] 84 | orients = Quaternions.id(0) 85 | offsets = np.array([]).reshape((0, 3)) 86 | parents = np.array([], dtype=int) 87 | orders = [] 88 | 89 | for line in f: 90 | 91 | if "HIERARCHY" in line: continue 92 | if "MOTION" in line: continue 93 | 94 | """ Modified line read to handle mixamo data """ 95 | # rmatch = re.match(r"ROOT (\w+)", line) 96 | rmatch = re.match(r"ROOT (\w+:?\w+)", line) 97 | if rmatch: 98 | names.append(rmatch.group(1)) 99 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 100 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0) 101 | parents = np.append(parents, active) 102 | active = (len(parents) - 1) 103 | continue 104 | 105 | if "{" in line: continue 106 | 107 | if "}" in line: 108 | if end_site: 109 | end_site = False 110 | else: 111 | active = parents[active] 112 | continue 113 | 114 | offmatch = re.match(r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line) 115 | if offmatch: 116 | if not end_site: 117 | offsets[active] = np.array([list(map(float, offmatch.groups()))]) 118 | continue 119 | 120 | chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line) 121 | if chanmatch: 122 | channels = int(chanmatch.group(1)) 123 | 124 | channelis = 0 if channels == 3 else 3 125 | channelie = 3 if channels == 3 else 6 126 | parts = line.split()[2 + channelis:2 + channelie] 127 | if any([p not in channelmap for p in parts]): 128 | continue 129 | order = "".join([channelmap[p] for p in parts]) 130 | orders.append(order) 131 | continue 132 | 133 | """ Modified line read to handle mixamo data """ 134 | # jmatch = 
re.match("\s*JOINT\s+(\w+)", line) 135 | jmatch = re.match("\s*JOINT\s+(\w+:?\w+)", line) 136 | if jmatch: 137 | names.append(jmatch.group(1)) 138 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 139 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0) 140 | parents = np.append(parents, active) 141 | active = (len(parents) - 1) 142 | continue 143 | 144 | if "End Site" in line: 145 | end_site = True 146 | continue 147 | 148 | fmatch = re.match("\s*Frames:\s+(\d+)", line) 149 | if fmatch: 150 | if start and end: 151 | fnum = (end - start) - 1 152 | else: 153 | fnum = int(fmatch.group(1)) 154 | jnum = len(parents) 155 | positions = offsets[np.newaxis].repeat(fnum, axis=0) 156 | rotations = np.zeros((fnum, len(orients), 3)) 157 | continue 158 | 159 | fmatch = re.match("\s*Frame Time:\s+([\d\.]+)", line) 160 | if fmatch: 161 | frametime = float(fmatch.group(1)) 162 | continue 163 | 164 | if (start and end) and (i < start or i >= end - 1): 165 | i += 1 166 | continue 167 | 168 | # dmatch = line.strip().split(' ') 169 | dmatch = line.strip().split() 170 | if dmatch: 171 | data_block = np.array(list(map(float, dmatch))) 172 | N = len(parents) 173 | fi = i - start if start else i 174 | if channels == 3: 175 | positions[fi, 0:1] = data_block[0:3] 176 | rotations[fi, :] = data_block[3:].reshape(N, 3) 177 | elif channels == 6: 178 | data_block = data_block.reshape(N, 6) 179 | positions[fi, :] = data_block[:, 0:3] 180 | rotations[fi, :] = data_block[:, 3:6] 181 | elif channels == 9: 182 | positions[fi, 0] = data_block[0:3] 183 | data_block = data_block[3:].reshape(N - 1, 9) 184 | rotations[fi, 1:] = data_block[:, 3:6] 185 | positions[fi, 1:] += data_block[:, 0:3] * data_block[:, 6:9] 186 | else: 187 | raise Exception("Too many channels! 
%i" % channels) 188 | 189 | i += 1 190 | 191 | f.close() 192 | 193 | all_rotations = [] 194 | canonical_order = 'xyz' 195 | for i, order in enumerate(orders): 196 | rot = rotations[:, i:i + 1] 197 | if need_quater: 198 | quat = Quaternions.from_euler(np.radians(rot), order=order, world=world) 199 | all_rotations.append(quat) 200 | continue 201 | elif order != canonical_order: 202 | quat = Quaternions.from_euler(np.radians(rot), order=order, world=world) 203 | rot = np.degrees(quat.euler(order=canonical_order)) 204 | all_rotations.append(rot) 205 | rotations = np.concatenate(all_rotations, axis=1) 206 | 207 | return Animation(rotations, positions, orients, offsets, parents, names, frametime) 208 | 209 | 210 | def save(filename, anim, names=None, frametime=1.0/24.0, order='zyx', positions=False, orients=True): 211 | """ 212 | Saves an Animation to file as BVH 213 | 214 | Parameters 215 | ---------- 216 | filename: str 217 | File to be saved to 218 | 219 | anim : Animation 220 | Animation to save 221 | 222 | names : [str] 223 | List of joint names 224 | 225 | order : str 226 | Optional Specifier for joint order. 227 | Given as string E.G 'xyz', 'zxy' 228 | 229 | frametime : float 230 | Optional Animation Frame time 231 | 232 | positions : bool 233 | Optional specfier to save bone 234 | positions for each frame 235 | 236 | orients : bool 237 | Multiply joint orients to the rotations 238 | before saving. 
239 | 240 | """ 241 | 242 | if names is None: 243 | names = ["joint_" + str(i) for i in range(len(anim.parents))] 244 | 245 | with open(filename, 'w') as f: 246 | 247 | t = "" 248 | f.write("%sHIERARCHY\n" % t) 249 | f.write("%sROOT %s\n" % (t, names[0])) 250 | f.write("%s{\n" % t) 251 | t += '\t' 252 | 253 | f.write("%sOFFSET %f %f %f\n" % (t, anim.offsets[0,0], anim.offsets[0,1], anim.offsets[0,2]) ) 254 | f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" % 255 | (t, channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 256 | 257 | for i in range(anim.shape[1]): 258 | if anim.parents[i] == 0: 259 | t = save_joint(f, anim, names, t, i, order=order, positions=positions) 260 | 261 | t = t[:-1] 262 | f.write("%s}\n" % t) 263 | 264 | f.write("MOTION\n") 265 | f.write("Frames: %i\n" % anim.shape[0]); 266 | f.write("Frame Time: %f\n" % frametime); 267 | 268 | #if orients: 269 | # rots = np.degrees((-anim.orients[np.newaxis] * anim.rotations).euler(order=order[::-1])) 270 | #else: 271 | # rots = np.degrees(anim.rotations.euler(order=order[::-1])) 272 | rots = np.degrees(anim.rotations.euler(order=order[::-1])) 273 | poss = anim.positions 274 | 275 | for i in range(anim.shape[0]): 276 | for j in range(anim.shape[1]): 277 | 278 | if positions or j == 0: 279 | 280 | f.write("%f %f %f %f %f %f " % ( 281 | poss[i,j,0], poss[i,j,1], poss[i,j,2], 282 | rots[i,j,ordermap[order[0]]], rots[i,j,ordermap[order[1]]], rots[i,j,ordermap[order[2]]])) 283 | 284 | else: 285 | 286 | f.write("%f %f %f " % ( 287 | rots[i,j,ordermap[order[0]]], rots[i,j,ordermap[order[1]]], rots[i,j,ordermap[order[2]]])) 288 | 289 | f.write("\n") 290 | 291 | 292 | def save_joint(f, anim, names, t, i, order='zyx', positions=False): 293 | 294 | f.write("%sJOINT %s\n" % (t, names[i])) 295 | f.write("%s{\n" % t) 296 | t += '\t' 297 | 298 | f.write("%sOFFSET %f %f %f\n" % (t, anim.offsets[i,0], anim.offsets[i,1], anim.offsets[i,2])) 299 | 300 | if positions: 301 | 
f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" % (t, 302 | channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 303 | else: 304 | f.write("%sCHANNELS 3 %s %s %s\n" % (t, 305 | channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]])) 306 | 307 | end_site = True 308 | 309 | for j in range(anim.shape[1]): 310 | if anim.parents[j] == i: 311 | t = save_joint(f, anim, names, t, j, order=order, positions=positions) 312 | end_site = False 313 | 314 | if end_site: 315 | f.write("%sEnd Site\n" % t) 316 | f.write("%s{\n" % t) 317 | t += '\t' 318 | f.write("%sOFFSET %f %f %f\n" % (t, 0.0, 0.0, 0.0)) 319 | t = t[:-1] 320 | f.write("%s}\n" % t) 321 | 322 | t = t[:-1] 323 | f.write("%s}\n" % t) 324 | 325 | return t 326 | -------------------------------------------------------------------------------- /dataset/bvh/bvh_parser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import dataset.bvh.bvh_io as bvh_io 4 | from utils.kinematics import ForwardKinematicsJoint 5 | from utils.transforms import quat2repr6d 6 | from utils.contact import foot_contact 7 | from dataset.bvh.Quaternions import Quaternions 8 | from dataset.bvh.bvh_writer import WriterWrapper 9 | 10 | 11 | class Skeleton: 12 | def __init__(self, names, parent, offsets, joint_reduction=True, skeleton_conf=None): 13 | self._names = names 14 | self.original_parent = parent 15 | self._offsets = offsets 16 | self._parent = None 17 | self._ee_id = None 18 | self.contact_names = [] 19 | 20 | for i, name in enumerate(self._names): 21 | if ':' in name: 22 | self._names[i] = name[name.find(':')+1:] 23 | 24 | if joint_reduction or skeleton_conf is not None: 25 | assert skeleton_conf is not None, 'skeleton_conf can not be None if you use joint reduction' 26 | corps_names = skeleton_conf['corps_names'] 27 | self.contact_names = skeleton_conf['corps_names'] 28 | self.contact_threshold = 
skeleton_conf['contact_threshold'] 29 | 30 | self.contact_id = [] 31 | for i in self.contact_names: 32 | self.contact_id.append(corps_names.index(i)) 33 | else: 34 | self.skeleton_type = -1 35 | corps_names = self._names 36 | 37 | self.details = [] # joints that does not belong to the corps (we are not interested in them) 38 | for i, name in enumerate(self._names): 39 | if name not in corps_names: self.details.append(i) 40 | 41 | self.corps = [] 42 | self.simplified_name = [] 43 | self.simplify_map = {} 44 | self.inverse_simplify_map = {} 45 | 46 | # Repermute the skeleton id according to the databse 47 | for name in corps_names: 48 | for j in range(len(self._names)): 49 | if name in self._names[j]: 50 | self.corps.append(j) 51 | break 52 | if len(self.corps) != len(corps_names): 53 | for i in self.corps: 54 | print(self._names[i], end=' ') 55 | print(self.corps, self.skeleton_type, len(self.corps), sep='\n') 56 | raise Exception('Problem in this skeleton') 57 | 58 | self.joint_num_simplify = len(self.corps) 59 | for i, j in enumerate(self.corps): 60 | self.simplify_map[j] = i 61 | self.inverse_simplify_map[i] = j 62 | self.simplified_name.append(self._names[j]) 63 | self.inverse_simplify_map[0] = -1 64 | for i in range(len(self._names)): 65 | if i in self.details: 66 | self.simplify_map[i] = -1 67 | 68 | @property 69 | def parent(self): 70 | if self._parent is None: 71 | self._parent = self.original_parent[self.corps].copy() 72 | for i in range(self._parent.shape[0]): 73 | if i >= 1: self._parent[i] = self.simplify_map[self._parent[i]] 74 | self._parent = tuple(self._parent) 75 | return self._parent 76 | 77 | @property 78 | def offsets(self): 79 | return torch.tensor(self._offsets[self.corps], dtype=torch.float) 80 | 81 | @property 82 | def names(self): 83 | return self.simplified_name 84 | 85 | @property 86 | def ee_id(self): 87 | raise Exception('Abaddoned') 88 | # if self._ee_id is None: 89 | # self._ee_id = [] 90 | # for i in 
SkeletonDatabase.ee_names[self.skeleton_type]: 91 | # self.ee_id._ee_id(corps_names[self.skeleton_type].index(i)) 92 | 93 | 94 | class BVH_file: 95 | def __init__(self, file_path, skeleton_conf=None, requires_contact=False, joint_reduction=True, auto_scale=True): 96 | self.anim = bvh_io.load(file_path) 97 | self._names = self.anim.names 98 | self.frametime = self.anim.frametime 99 | if requires_contact or joint_reduction: 100 | assert skeleton_conf is not None, 'Please provide a skeleton configuration for contact or joint reduction' 101 | self.skeleton = Skeleton(self.anim.names, self.anim.parent, self.anim.offsets, joint_reduction, skeleton_conf) 102 | 103 | # Downsample to 30 fps for our application 104 | if self.frametime < 0.0084: 105 | self.frametime *= 2 106 | self.anim.positions = self.anim.positions[::2] 107 | self.anim.rotations = self.anim.rotations[::2] 108 | if self.frametime < 0.017: 109 | self.frametime *= 2 110 | self.anim.positions = self.anim.positions[::2] 111 | self.anim.rotations = self.anim.rotations[::2] 112 | 113 | self.requires_contact = requires_contact 114 | 115 | if requires_contact: 116 | self.contact_names = self.skeleton.contact_names 117 | else: 118 | self.contact_names = [] 119 | 120 | self.fk = ForwardKinematicsJoint(self.skeleton.parent, self.skeleton.offsets) 121 | self.writer = WriterWrapper(self.skeleton.parent, self.skeleton.offsets) 122 | 123 | self.auto_scale = auto_scale 124 | if auto_scale: 125 | self.scale = 1. 
/ np.ceil(self.skeleton.offsets.max().cpu().numpy()) 126 | print(f'rescale the skeleton with scale: {self.scale}') 127 | self.rescale(self.scale) 128 | else: 129 | self.scale = 1.0 130 | 131 | if self.requires_contact: 132 | gl_pos = self.joint_position() 133 | self.contact_label = foot_contact(gl_pos[:, self.skeleton.contact_id], 134 | threshold=self.skeleton.contact_threshold) 135 | self.gl_pos = gl_pos 136 | 137 | def local_pos(self): 138 | gl_pos = self.joint_position() 139 | local_pos = gl_pos - gl_pos[:, 0:1, :] 140 | return local_pos[:, 1:] 141 | 142 | def rescale(self, ratio): 143 | self.anim.offsets *= ratio 144 | self.anim.positions *= ratio 145 | 146 | def to_tensor(self, repr='euler', rot_only=False): 147 | if repr not in ['euler', 'quat', 'quaternion', 'repr6d']: 148 | raise Exception('Unknown rotation representation') 149 | positions = self.get_position() 150 | rotations = self.get_rotation(repr=repr) 151 | 152 | if rot_only: 153 | return rotations.reshape(rotations.shape[0], -1) 154 | 155 | if self.requires_contact: 156 | virtual_contact = torch.zeros_like(rotations[:, :len(self.skeleton.contact_id)]) 157 | virtual_contact[..., 0] = self.contact_label 158 | rotations = torch.cat([rotations, virtual_contact], dim=1) 159 | 160 | rotations = rotations.reshape(rotations.shape[0], -1) 161 | return torch.cat((rotations, positions), dim=-1) 162 | 163 | def joint_position(self): 164 | positions = torch.tensor(self.anim.positions[:, 0, :], dtype=torch.float) 165 | rotations = self.anim.rotations[:, self.skeleton.corps, :] 166 | rotations = Quaternions.from_euler(np.radians(rotations)).qs 167 | rotations = torch.tensor(rotations, dtype=torch.float) 168 | j_loc = self.fk.forward(rotations, positions) 169 | return j_loc 170 | 171 | def get_rotation(self, repr='quat'): 172 | rotations = self.anim.rotations[:, self.skeleton.corps, :] 173 | if repr == 'quaternion' or repr == 'quat' or repr == 'repr6d': 174 | rotations = 
Quaternions.from_euler(np.radians(rotations)).qs 175 | rotations = torch.tensor(rotations, dtype=torch.float) 176 | if repr == 'repr6d': 177 | rotations = quat2repr6d(rotations) 178 | if repr == 'euler': 179 | rotations = torch.tensor(rotations, dtype=torch.float) 180 | return rotations 181 | 182 | def get_position(self): 183 | return torch.tensor(self.anim.positions[:, 0, :], dtype=torch.float) 184 | 185 | def dfs(self, x, vis, dist): 186 | fa = self.skeleton.parent 187 | vis[x] = 1 188 | for y in range(len(fa)): 189 | if (fa[y] == x or fa[x] == y) and vis[y] == 0: 190 | dist[y] = dist[x] + 1 191 | self.dfs(y, vis, dist) 192 | 193 | def get_neighbor(self, threshold, enforce_contact=False): 194 | fa = self.skeleton.parent 195 | neighbor_list = [] 196 | for x in range(0, len(fa)): 197 | vis = [0 for _ in range(len(fa))] 198 | dist = [0 for _ in range(len(fa))] 199 | self.dfs(x, vis, dist) 200 | neighbor = [] 201 | for j in range(0, len(fa)): 202 | if dist[j] <= threshold: 203 | neighbor.append(j) 204 | neighbor_list.append(neighbor) 205 | 206 | contact_list = [] 207 | if self.requires_contact: 208 | for i, p_id in enumerate(self.skeleton.contact_id): 209 | v_id = len(neighbor_list) 210 | neighbor_list[p_id].append(v_id) 211 | neighbor_list.append(neighbor_list[p_id]) 212 | contact_list.append(v_id) 213 | 214 | root_neighbor = neighbor_list[0] 215 | id_root = len(neighbor_list) 216 | 217 | if enforce_contact: 218 | root_neighbor = root_neighbor + contact_list 219 | for j in contact_list: 220 | neighbor_list[j] = list(set(neighbor_list[j])) 221 | 222 | root_neighbor = list(set(root_neighbor)) 223 | for j in root_neighbor: 224 | neighbor_list[j].append(id_root) 225 | root_neighbor.append(id_root) 226 | neighbor_list.append(root_neighbor) # Neighbor for root position 227 | return neighbor_list -------------------------------------------------------------------------------- /dataset/bvh/bvh_writer.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from utils.transforms import quat2euler, repr6d2quat 3 | 4 | 5 | # rotation with shape frame * J * 3 6 | def write_bvh(parent, offset, rotation, position, names, frametime, order, path, endsite=None): 7 | file = open(path, 'w') 8 | frame = rotation.shape[0] 9 | joint_num = rotation.shape[1] 10 | order = order.upper() 11 | 12 | file_string = 'HIERARCHY\n' 13 | 14 | seq = [] 15 | 16 | def write_static(idx, prefix): 17 | nonlocal parent, offset, rotation, names, order, endsite, file_string, seq 18 | seq.append(idx) 19 | if idx == 0: 20 | name_label = 'ROOT ' + names[idx] 21 | channel_label = 'CHANNELS 6 Xposition Yposition Zposition {}rotation {}rotation {}rotation'.format(*order) 22 | else: 23 | name_label = 'JOINT ' + names[idx] 24 | channel_label = 'CHANNELS 3 {}rotation {}rotation {}rotation'.format(*order) 25 | offset_label = 'OFFSET %.6f %.6f %.6f' % (offset[idx][0], offset[idx][1], offset[idx][2]) 26 | 27 | file_string += prefix + name_label + '\n' 28 | file_string += prefix + '{\n' 29 | file_string += prefix + '\t' + offset_label + '\n' 30 | file_string += prefix + '\t' + channel_label + '\n' 31 | 32 | has_child = False 33 | for y in range(idx+1, rotation.shape[1]): 34 | if parent[y] == idx: 35 | has_child = True 36 | write_static(y, prefix + '\t') 37 | if not has_child: 38 | file_string += prefix + '\t' + 'End Site\n' 39 | file_string += prefix + '\t' + '{\n' 40 | file_string += prefix + '\t\t' + 'OFFSET 0 0 0\n' 41 | file_string += prefix + '\t' + '}\n' 42 | 43 | file_string += prefix + '}\n' 44 | 45 | write_static(0, '') 46 | 47 | file_string += 'MOTION\n' + 'Frames: {}\n'.format(frame) + 'Frame Time: %.8f\n' % frametime 48 | for i in range(frame): 49 | file_string += '%.6f %.6f %.6f ' % (position[i][0], position[i][1], position[i][2]) 50 | for j in range(joint_num): 51 | idx = seq[j] 52 | file_string += '%.6f %.6f %.6f ' % (rotation[i][idx][0], 
rotation[i][idx][1], rotation[i][idx][2]) 53 | file_string += '\n' 54 | 55 | file.write(file_string) 56 | return file_string 57 | 58 | 59 | class WriterWrapper: 60 | def __init__(self, parents, offset=None): 61 | self.parents = parents 62 | self.offset = offset 63 | 64 | def write(self, filename, rot, pos, offset=None, names=None, repr='quat'): 65 | """ 66 | Write animation to bvh file 67 | :param filename: 68 | :param rot: Quaternion as (w, x, y, z) 69 | :param pos: 70 | :param offset: 71 | :return: 72 | """ 73 | if repr not in ['euler', 'quat', 'quaternion', 'repr6d']: 74 | raise Exception('Unknown rotation representation') 75 | if offset is None: 76 | offset = self.offset 77 | if not isinstance(offset, torch.Tensor): 78 | offset = torch.tensor(offset) 79 | n_bone = offset.shape[0] 80 | 81 | if repr == 'repr6d': 82 | rot = rot.reshape(rot.shape[0], -1, 6) 83 | rot = repr6d2quat(rot) 84 | if repr == 'repr6d' or repr == 'quat' or repr == 'quaternion': 85 | rot = rot.reshape(rot.shape[0], -1, 4) 86 | rot /= rot.norm(dim=-1, keepdim=True) ** 0.5 87 | euler = quat2euler(rot, order='xyz') 88 | rot = euler 89 | 90 | if names is None: 91 | names = ['%02d' % i for i in range(n_bone)] 92 | write_bvh(self.parents, offset, rot, pos, names, 1, 'xyz', filename) 93 | -------------------------------------------------------------------------------- /dataset/bvh_motion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import torch 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from .motion import MotionData 7 | from .bvh.bvh_parser import BVH_file 8 | 9 | 10 | ## Some skeleton configurations 11 | crab_dance_corps_names = ['ORG_Hips', 'ORG_BN_Bip01_Pelvis', 'DEF_BN_Eye_L_01', 'DEF_BN_Eye_L_02', 'DEF_BN_Eye_L_03', 'DEF_BN_Eye_L_03_end', 'DEF_BN_Eye_R_01', 'DEF_BN_Eye_R_02', 'DEF_BN_Eye_R_03', 'DEF_BN_Eye_R_03_end', 'DEF_BN_Leg_L_11', 'DEF_BN_Leg_L_12', 'DEF_BN_Leg_L_13', 'DEF_BN_Leg_L_14', 
'DEF_BN_Leg_L_15', 'DEF_BN_Leg_L_15_end', 'DEF_BN_Leg_R_11', 'DEF_BN_Leg_R_12', 'DEF_BN_Leg_R_13', 'DEF_BN_Leg_R_14', 'DEF_BN_Leg_R_15', 'DEF_BN_Leg_R_15_end', 'DEF_BN_leg_L_01', 'DEF_BN_leg_L_02', 'DEF_BN_leg_L_03', 'DEF_BN_leg_L_04', 'DEF_BN_leg_L_05', 'DEF_BN_leg_L_05_end', 12 | 'DEF_BN_leg_L_06', 'DEF_BN_Leg_L_07', 'DEF_BN_Leg_L_08', 'DEF_BN_Leg_L_09', 'DEF_BN_Leg_L_10', 'DEF_BN_Leg_L_10_end', 'DEF_BN_leg_R_01', 'DEF_BN_leg_R_02', 'DEF_BN_leg_R_03', 'DEF_BN_leg_R_04', 'DEF_BN_leg_R_05', 'DEF_BN_leg_R_05_end', 'DEF_BN_leg_R_06', 'DEF_BN_Leg_R_07', 'DEF_BN_Leg_R_08', 'DEF_BN_Leg_R_09', 'DEF_BN_Leg_R_10', 'DEF_BN_Leg_R_10_end', 'DEF_BN_Bip01_Pelvis', 'DEF_BN_Bip01_Pelvis_end', 'DEF_BN_Arm_L_01', 'DEF_BN_Arm_L_02', 'DEF_BN_Arm_L_03', 'DEF_BN_Arm_L_03_end', 'DEF_BN_Arm_R_01', 'DEF_BN_Arm_R_02', 'DEF_BN_Arm_R_03', 'DEF_BN_Arm_R_03_end'] 13 | skeleton_confs = { 14 | 'mixamo': { 15 | 'corps_names': ['Hips', 'LeftUpLeg', 'LeftLeg', 'LeftFoot', 'LeftToeBase', 'LeftToe_End', 'RightUpLeg', 'RightLeg', 'RightFoot', 'RightToeBase', 'RightToe_End', 'Spine', 'Spine1', 'Spine2', 'Neck', 'Head', 'LeftShoulder', 'LeftArm', 'LeftForeArm', 'LeftHand', 'RightShoulder', 'RightArm', 'RightForeArm', 'RightHand'], 16 | 'contact_names': ['LeftToe_End', 'RightToe_End', 'LeftToeBase', 'RightToeBase'], 17 | 'contact_threshold': 0.018 18 | }, 19 | 'crab_dance': { 20 | 'corps_names': crab_dance_corps_names, 21 | 'contact_names': [name for name in crab_dance_corps_names if 'end' in name and ('05' in name or '10' in name or '15' in name)], 22 | 'contact_threshold': 0.006 23 | }, 24 | 'xia': { 25 | 'corps_names': ['Hips', 'LHipJoint', 'LeftUpLeg', 'LeftLeg', 'LeftFoot', 'LeftToeBase', 'RHipJoint', 'RightUpLeg', 'RightLeg', 'RightFoot', 'RightToeBase', 'LowerBack', 'Spine', 'Spine1', 'Neck', 'Neck1', 'Head', 'LeftShoulder', 'LeftArm', 'LeftForeArm', 'LeftHand', 'LeftFingerBase', 'LeftHandIndex1', 'LThumb', 'RightShoulder', 'RightArm', 'RightForeArm', 'RightHand', 'RightFingerBase', 
'RightHandIndex1', 'RThumb'], 26 | 'contact_names': ['LeftToeBase', 'RightToeBase'], 27 | 'contact_threshold': 0.006 28 | } 29 | } 30 | 31 | class BVHMotion: 32 | def __init__(self, bvh_file, skeleton_name=None, repr='quat', use_velo=True, keep_up_pos=False, up_axis='Y_UP', padding_last=False, requires_contact=False, joint_reduction=False): 33 | ''' 34 | BVHMotion constructor 35 | Args: 36 | bvh_file : string, bvh_file path to load from 37 | skelton_name : string, name of predefined skeleton, used when joint_reduction==True or contact==True 38 | repr : string, rotation representation, support ['quat', 'repr6d', 'euler'] 39 | use_velo : book, whether to transform the joints positions to velocities 40 | keep_up_pos : bool, whether to keep y position when converting to velocity 41 | up_axis : string, string, up axis of the motion data 42 | padding_last : bool, whether to pad the last position 43 | requires_contact : bool, whether to concatenate contact information 44 | joint_reduction : bool, whether to reduce the joint number 45 | ''' 46 | self.bvh_file = bvh_file 47 | self.skeleton_name = skeleton_name 48 | if skeleton_name is not None: 49 | assert skeleton_name in skeleton_confs, f'{skeleton_name} not found, please add a skeleton configuration.' 
50 | self.requires_contact = requires_contact 51 | self.joint_reduction = joint_reduction 52 | 53 | self.raw_data = BVH_file(bvh_file, skeleton_confs[skeleton_name] if skeleton_name is not None else None, requires_contact, joint_reduction, auto_scale=True) 54 | self.motion_data = MotionData(self.raw_data.to_tensor(repr=repr).permute(1, 0).unsqueeze(0), repr=repr, use_velo=use_velo, keep_up_pos=keep_up_pos, up_axis=up_axis, 55 | padding_last=padding_last, contact_id=self.raw_data.skeleton.contact_id if requires_contact else None) 56 | @property 57 | def repr(self): 58 | return self.motion_data.repr 59 | 60 | @property 61 | def use_velo(self): 62 | return self.motion_data.use_velo 63 | 64 | @property 65 | def keep_up_pos(self): 66 | return self.motion_data.keep_up_pos 67 | 68 | @property 69 | def padding_last(self): 70 | return self.motion_data.padding_last 71 | 72 | @property 73 | def concat_id(self): 74 | return self.motion_data.contact_id 75 | 76 | @property 77 | def n_pad(self): 78 | return self.motion_data.n_pad 79 | 80 | @property 81 | def n_contact(self): 82 | return self.motion_data.n_contact 83 | 84 | @property 85 | def n_rot(self): 86 | return self.motion_data.n_rot 87 | 88 | def sample(self, size=None, slerp=False): 89 | ''' 90 | Sample motion data, support slerp 91 | ''' 92 | return self.motion_data.sample(size, slerp) 93 | 94 | 95 | def write(self, filename, data): 96 | ''' 97 | Parse motion data into position, velocity and contact(if exists) 98 | data should be [] 99 | No batch support here!!! 
100 | ''' 101 | assert len(data.shape) == 3, 'The data format should be [batch_size x n_channels x n_frames]' 102 | 103 | if self.n_pad: 104 | data = data.clone()[:, :-self.n_pad] 105 | if self.use_velo: 106 | data = self.motion_data.to_position(data) 107 | data = data.squeeze().permute(1, 0) 108 | pos = data[..., -3:] 109 | rot = data[..., :-3].reshape(data.shape[0], -1, self.n_rot) 110 | if self.requires_contact: 111 | contact = rot[..., -self.n_contact:, 0] 112 | rot = rot[..., :-self.n_contact, :] 113 | else: 114 | contact = None 115 | 116 | if contact is not None: 117 | np.save(filename + '.contact', contact.detach().cpu().numpy()) 118 | 119 | # rescale the output 120 | self.raw_data.rescale(1. / self.raw_data.scale) 121 | pos *= 1. / self.raw_data.scale 122 | self.raw_data.writer.write(filename, rot, pos, names=self.raw_data.skeleton.names, repr=self.repr) 123 | 124 | 125 | def load_multiple_dataset(name_list, **kargs): 126 | with open(name_list, 'r') as f: 127 | names = [line.strip() for line in f.readlines()] 128 | datasets = [] 129 | for f in names: 130 | kargs['bvh_file'] = osp.join(osp.dirname(name_list), f) 131 | datasets.append(BVHMotion(**kargs)) 132 | return datasets -------------------------------------------------------------------------------- /dataset/motion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class MotionData: 6 | def __init__(self, data, repr='quat', use_velo=True, keep_up_pos=True, up_axis='Y', padding_last=False, contact_id=None): 7 | ''' 8 | BaseMotionData constructor 9 | Args: 10 | data : torch.Tensor, [batch_size x n_channels x n_frames] input motion data, 11 | the channels dim shoud be [n_joints x n_dim_of_rotation + 3(global position)] 12 | repr : string, rotation representation, support ['quat', 'repr6d', 'euler'] 13 | use_velo : book, whether to transform the joints positions to velocities 14 | keep_up_pos : bool, whether 
to keep up position when converting to velocity 15 | up_axis : string, string, up axis of the motion data 16 | padding_last : bool, whether to pad the last position 17 | contact_id : list, contact joints id 18 | ''' 19 | self.data = data 20 | self.repr = repr 21 | self.use_velo = use_velo 22 | self.keep_up_pos = keep_up_pos 23 | self.up_axis = up_axis 24 | self.padding_last = padding_last 25 | self.contact_id = contact_id 26 | self.begin_pos = None 27 | 28 | # assert the rotation representation 29 | if self.repr == 'quat': 30 | self.n_rot = 4 31 | assert (self.data.shape[1] - 3) % 4 == 0, 'rotation is not "quaternion" representation' 32 | elif self.repr == 'repr6d': 33 | self.n_rot = 6 34 | assert (self.data.shape[1] - 3) % 6 == 0, 'rotation is not "repr6d" representation' 35 | elif self.repr == 'eluer': 36 | self.n_rot = 3 37 | assert (self.data.shape[1] - 3) % 3 == 0, 'rotation is not "euler" representation' 38 | 39 | # whether to pad the position data with zero 40 | if self.padding_last: 41 | self.n_pad = self.data.shape[1] - 3 # pad position channels to match the n_channels of rotation 42 | paddings = torch.zeros_like(self.data[:, :self.n_pad]) 43 | self.data = torch.cat((self.data, paddings), dim=1) 44 | else: 45 | self.n_pad = 0 46 | 47 | # get the contact information 48 | if self.contact_id is not None: 49 | self.n_contact = len(contact_id) 50 | else: 51 | self.n_contact = 0 52 | 53 | # whether to keep y position when converting to velocity 54 | if self.keep_up_pos: 55 | if self.up_axis == 'X_UP': 56 | self.velo_mask = [-2, -1] 57 | elif self.up_axis == 'Y_UP': 58 | self.velo_mask = [-3, -1] 59 | elif self.up_axis == 'Z_UP': 60 | self.velo_mask = [-3, -2] 61 | else: 62 | self.velo_mask = [-3, -2, -1] 63 | 64 | # whether to convert global position to velocity 65 | if self.use_velo: 66 | self.data = self.to_velocity(self.data) 67 | 68 | 69 | def __len__(self): 70 | ''' 71 | return the number of motion frames 72 | ''' 73 | return self.data.shape[-1] 74 | 75 | 
76 | def sample(self, size=None, slerp=False): 77 | ''' 78 | sample the motion data using given size 79 | ''' 80 | if size is None: 81 | return self.data 82 | else: 83 | if slerp: 84 | motion = self.slerp(self.data, size=size) 85 | else: 86 | motion = F.interpolate(self.data, size=size, mode='linear', align_corners=False) 87 | return motion 88 | 89 | 90 | def to_velocity(self, pos): 91 | ''' 92 | convert motion data to velocity 93 | ''' 94 | assert self.begin_pos is None, 'the motion data had been converted to velocity' 95 | msk = [i - self.n_pad for i in self.velo_mask] 96 | velo = pos.detach().clone().to(pos.device) 97 | velo[:, msk, 1:] = pos[:, msk, 1:] - pos[:, msk, :-1] 98 | self.begin_pos = pos[:, msk, 0].clone() 99 | velo[:, msk, 0] = pos[:, msk, 1] 100 | return velo 101 | 102 | def to_position(self, velo): 103 | ''' 104 | convert motion data to position 105 | ''' 106 | assert self.begin_pos is not None, 'the motion data is already position' 107 | msk = [i - self.n_pad for i in self.velo_mask] 108 | pos = velo.detach().clone().to(velo.device) 109 | pos[:, msk, 0] = self.begin_pos.to(velo.device) 110 | pos[:, msk] = torch.cumsum(pos[:, msk], dim=-1) 111 | self.begin_pos = None 112 | return pos -------------------------------------------------------------------------------- /dataset/tracks_motion.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as pjoin 3 | import numpy as np 4 | import copy 5 | import torch 6 | from .motion import MotionData 7 | from ..utils.transforms import quat2repr6d, quat2euler, repr6d2quat 8 | 9 | class TracksParser(): 10 | def __init__(self, tracks_json, scale): 11 | self.tracks_json = tracks_json 12 | self.scale = scale 13 | 14 | self.skeleton_names = [] 15 | self.rotations = [] 16 | for i, track in enumerate(self.tracks_json): 17 | self.skeleton_names.append(track['name']) 18 | if i == 0: 19 | assert track['type'] == 'vector' 20 | self.position = 
np.array(track['values']).reshape(-1, 3) * self.scale 21 | self.num_frames = self.position.shape[0] 22 | else: 23 | assert track['type'] == 'quaternion' # DEAFULT: quaternion 24 | rotation = np.array(track['values']).reshape(-1, 4) 25 | if rotation.shape[0] == 0: 26 | rotation = np.zeros((self.num_frames, 4)) 27 | elif rotation.shape[0] < self.num_frames: 28 | rotation = np.repeat(rotation, self.num_frames // rotation.shape[0], axis=0) 29 | elif rotation.shape[0] > self.num_frames: 30 | rotation = rotation[:self.num_frames] 31 | self.rotations += [rotation] 32 | self.rotations = np.array(self.rotations, dtype=np.float32) 33 | 34 | def to_tensor(self, repr='euler', rot_only=False): 35 | if repr not in ['euler', 'quat', 'quaternion', 'repr6d']: 36 | raise Exception('Unknown rotation representation') 37 | rotations = self.get_rotation(repr=repr) 38 | positions = self.get_position() 39 | 40 | if rot_only: 41 | return rotations.reshape(rotations.shape[0], -1) 42 | 43 | rotations = rotations.reshape(rotations.shape[0], -1) 44 | return torch.cat((rotations, positions), dim=-1) 45 | 46 | def get_rotation(self, repr='quat'): 47 | if repr == 'quaternion' or repr == 'quat' or repr == 'repr6d': 48 | rotations = torch.tensor(self.rotations, dtype=torch.float).transpose(0, 1) 49 | if repr == 'repr6d': 50 | rotations = quat2repr6d(rotations) 51 | if repr == 'euler': 52 | rotations = quat2euler(rotations) 53 | return rotations 54 | 55 | def get_position(self): 56 | return torch.tensor(self.position, dtype=torch.float32) 57 | 58 | class TracksMotion: 59 | def __init__(self, tracks_json, scale=1.0, repr='quat', use_velo=True, keep_up_pos=True, up_axis='Y_UP', padding_last=False): 60 | ''' 61 | TracksMotion constructor 62 | Args: 63 | tracks_json : dict, json format tracks data to load from 64 | scale : float, scale of the tracks motion data 65 | repr : string, rotation representation, support ['quat', 'repr6d', 'euler'] 66 | use_velo : book, whether to transform the joints positions 
to velocities 67 | keep_up_pos : bool, whether to keep y position when converting to velocity 68 | up_axis : string, string, up axis of the motion data 69 | padding_last : bool, whether to pad the last position 70 | ''' 71 | self.tracks_json = tracks_json 72 | 73 | self.raw_data = TracksParser(tracks_json, scale) 74 | self.motion_data = MotionData(self.raw_data.to_tensor(repr=repr).permute(1, 0).unsqueeze(0), repr=repr, use_velo=use_velo, keep_up_pos=keep_up_pos, up_axis=up_axis, 75 | padding_last=padding_last, contact_id=None) 76 | @property 77 | def repr(self): 78 | return self.motion_data.repr 79 | 80 | @property 81 | def use_velo(self): 82 | return self.motion_data.use_velo 83 | 84 | @property 85 | def keep_up_pos(self): 86 | return self.motion_data.keep_up_pos 87 | 88 | @property 89 | def padding_last(self): 90 | return self.motion_data.padding_last 91 | 92 | @property 93 | def n_pad(self): 94 | return self.motion_data.n_pad 95 | 96 | @property 97 | def n_rot(self): 98 | return self.motion_data.n_rot 99 | 100 | def sample(self, size=None, slerp=False): 101 | ''' 102 | Sample motion data, support slerp 103 | ''' 104 | return self.motion_data.sample(size, slerp) 105 | 106 | 107 | def parse(self, motion, keep_velo=False,): 108 | """ 109 | No batch support here!!! 
110 | :returns tracks_json 111 | """ 112 | motion = motion.clone() 113 | 114 | if self.use_velo and not keep_velo: 115 | motion = self.motion_data.to_position(motion) 116 | if self.n_pad: 117 | motion = motion[:, :-self.n_pad] 118 | 119 | motion = motion.squeeze().permute(1, 0) 120 | pos = motion[..., -3:] / self.raw_data.scale 121 | rot = motion[..., :-3].reshape(motion.shape[0], -1, self.n_rot) 122 | if self.repr == 'repr6d': 123 | rot = repr6d2quat(rot) 124 | elif self.repr == 'euler': 125 | raise NotImplementedError('parse "euler is not implemented yet!!!') 126 | 127 | times = [] 128 | out_tracks_json = copy.deepcopy(self.tracks_json) 129 | for i, _track in enumerate(out_tracks_json): 130 | if i == 0: 131 | times = [ j * out_tracks_json[i]['times'][1] for j in range(motion.shape[0])] 132 | out_tracks_json[i]['values'] = pos.flatten().detach().cpu().numpy().tolist() 133 | else: 134 | out_tracks_json[i]['values'] = rot[:, i-1, :].flatten().detach().cpu().numpy().tolist() 135 | out_tracks_json[i]['times'] = times 136 | 137 | return out_tracks_json 138 | -------------------------------------------------------------------------------- /demo.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wyysf-98/GenMM/fae65cdf199da8a25c4b28ceef20636b534269aa/demo.blend -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel 2 | 3 | # For the convenience for users in China mainland 4 | COPY apt-sources.list /etc/apt/sources.list 5 | # Install some basic utilities 6 | RUN rm /etc/apt/sources.list.d/cuda.list 7 | RUN rm /etc/apt/sources.list.d/nvidia-ml.list 8 | RUN apt-get update && apt-get install -y \ 9 | curl \ 10 | ca-certificates \ 11 | sudo \ 12 | git \ 13 | bzip2 \ 14 | libx11-6 \ 15 | gcc \ 16 | g++ \ 17 | libusb-1.0-0 \ 18 | 
libgl1-mesa-glx \ 19 | libglib2.0-dev \ 20 | openssh-server \ 21 | openssh-client \ 22 | iputils-ping \ 23 | unzip \ 24 | cmake \ 25 | libssl-dev \ 26 | libosmesa6-dev \ 27 | freeglut3-dev \ 28 | ffmpeg \ 29 | iputils-ping \ 30 | && rm -rf /var/lib/apt/lists/* 31 | 32 | # For the convenience for users in China mainland 33 | RUN pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple \ 34 | && export PATH="/usr/local/bin:$PATH" \ 35 | && /bin/bash -c "source ~/.bashrc" 36 | RUN conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ \ 37 | && conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ \ 38 | && conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ \ 39 | && conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ \ 40 | && conda config --set show_channel_urls yes 41 | 42 | # Install dependencies 43 | COPY requirements.txt requirements.txt 44 | RUN pip install -r requirements.txt --user 45 | 46 | CMD ["python3"] -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | ## Build Docker Environment and use with GPU Support 2 | 3 | Before you can use this Docker environment, you need to have the following: 4 | 5 | - Docker installed on your system 6 | - NVIDIA drivers installed on your system 7 | - NVIDIA Container Toolkit installed on your system 8 | 9 | 10 | ### Build and Run 11 | 1. Build docker image: 12 | ```sh 13 | docker build -t GenMM:latest . 14 | ``` 15 | 2. Start the docker container: 16 | ```sh 17 | docker run --gpus all -it GenMM:latest /bin/bash 18 | ``` 19 | 3. 
Clone the repository: 20 | ```sh 21 | git clone git@github.com:wyysf-98/GenMM.git 22 | ``` 23 | 24 | ## Troubleshooting 25 | 26 | If you encounter any issues with the Docker environment with GPU support, please check the following: 27 | 28 | - Make sure that you have installed the NVIDIA drivers and NVIDIA Container Toolkit on your system. 29 | - Make sure that you have specified the --gpus all option when starting the Docker container. 30 | - Make sure that your deep learning application is configured to use the GPU. -------------------------------------------------------------------------------- /docker/apt-sources.list: -------------------------------------------------------------------------------- 1 | deb https://mirrors.ustc.edu.cn/ubuntu/ bionic main restricted universe multiverse 2 | deb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic main restricted universe multiverse 3 | deb https://mirrors.ustc.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse 4 | deb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse 5 | deb https://mirrors.ustc.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse 6 | deb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse 7 | deb https://mirrors.ustc.edu.cn/ubuntu/ bionic-security main restricted universe multiverse 8 | deb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-security main restricted universe multiverse 9 | deb https://mirrors.ustc.edu.cn/ubuntu/ bionic-proposed main restricted universe multiverse 10 | deb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-proposed main restricted universe multiverse -------------------------------------------------------------------------------- /docker/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.12.1 2 | torchvision==0.13.1 3 | tensorboardX==2.5 4 | tqdm==4.62.3 5 | unfoldNd==0.2.0 6 | pyyaml>=5.3.1 7 | 
gradio==3.34.0 8 | matplotlib==3.3.2 -------------------------------------------------------------------------------- /docker/requirements_blender.txt: -------------------------------------------------------------------------------- 1 | torch==1.12.1 2 | torchvision==0.13.1 3 | tqdm==4.62.3 4 | unfoldNd==0.2.0 5 | pyyaml>=5.3.1 -------------------------------------------------------------------------------- /fix_contact.py: -------------------------------------------------------------------------------- 1 | from dataset.bvh.bvh_parser import BVH_file 2 | from os.path import join as pjoin 3 | import numpy as np 4 | import torch 5 | from utils.contact import constrain_from_contact 6 | from utils.kinematics import InverseKinematicsJoint2 7 | from utils.transforms import repr6d2quat 8 | from tqdm import tqdm 9 | import argparse 10 | import matplotlib.pyplot as plt 11 | from dataset.bvh_motion import skeleton_confs 12 | 13 | def continuous_filter(contact, length=2): 14 | contact = contact.copy() 15 | for j in range(contact.shape[1]): 16 | c = contact[:, j] 17 | t_len = 0 18 | prev = c[0] 19 | for i in range(contact.shape[0]): 20 | if prev == c[i]: 21 | t_len += 1 22 | else: 23 | if t_len <= length: 24 | c[i - t_len:i] = c[i] 25 | t_len = 1 26 | prev = c[i] 27 | return contact 28 | 29 | 30 | def fix_negative_height(contact, constrain, cid): 31 | floor = -1 32 | constrain = constrain.clone() 33 | for i in range(constrain.shape[0]): 34 | for j in range(constrain.shape[1]): 35 | if constrain[i, j, 1] < floor: 36 | constrain[i, j, 1] = floor 37 | return constrain 38 | 39 | 40 | def fix_contact(bvh_file, contact): 41 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 42 | cid = bvh_file.skeleton.contact_id 43 | glb = bvh_file.joint_position() 44 | rotation = bvh_file.get_rotation(repr='repr6d').to(device) 45 | position = bvh_file.get_position().to(device) 46 | contact = contact > 0.5 47 | # contact = continuous_filter(contact) 48 | constrain = 
constrain_from_contact(contact, glb, cid) 49 | constrain = fix_negative_height(contact, constrain, cid).to(device) 50 | cid = list(range(glb.shape[1])) 51 | ik_solver = InverseKinematicsJoint2(rotation, position, bvh_file.skeleton.offsets.to(device), bvh_file.skeleton.parent, 52 | constrain[:, cid], cid, 0.1, 0.01, use_velo=True) 53 | 54 | loop = tqdm(range(500)) 55 | losses = [] 56 | for i in loop: 57 | loss = ik_solver.step() 58 | loop.set_description(f'loss = {loss:.07f}') 59 | losses += [loss] 60 | plt.plot(losses) 61 | 62 | 63 | return repr6d2quat(ik_solver.rotations.detach()), ik_solver.get_position() 64 | 65 | 66 | def fix_contact_on_file(prefix, name): 67 | try: 68 | contact = np.load(pjoin(prefix, name + '.bvh.contact.npy')) 69 | except: 70 | print(f'{name} not found') 71 | return 72 | bvh_file = BVH_file(pjoin(prefix, name + '.bvh'), no_scale=True, requires_contact=True) 73 | print('Fixing foot contact with IK...') 74 | res = fix_contact(bvh_file, contact) 75 | bvh_file.writer.write(pjoin(prefix, name + '_fixed.bvh'), res[0], res[1], names=bvh_file.skeleton.names, repr='quat') 76 | 77 | 78 | if __name__ == '__main__': 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument('--prefix', type=str, required=True) 81 | parser.add_argument('--name', type=str, required=True) 82 | parser.add_argument('--skeleton_name', type=str, required=True) 83 | args = parser.parse_args() 84 | if args.prefix[0] == '/': 85 | prefix = args.prefix 86 | else: 87 | prefix = f'./results/{args.prefix}' 88 | name = args.name 89 | contact = np.load(pjoin(prefix, name + '.bvh.contact.npy')) 90 | bvh_file = BVH_file(pjoin(prefix, name + '.bvh'), skeleton_confs[args.skeleton_name], auto_scale=False, requires_contact=True) 91 | 92 | res = fix_contact(bvh_file, contact) 93 | plt.savefig(f'{prefix}/losses.png') 94 | 95 | bvh_file.writer.write(pjoin(prefix, name + '_fixed.bvh'), res[0], res[1], names=bvh_file.skeleton.names, repr='quat') 
# ---------------------------------------------------------------------------
# /nearest_neighbor/losses.py
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn

from .utils import extract_patches, combine_patches, efficient_cdist, get_NNs_Dists


class PatchCoherentLoss(torch.nn.Module):
    """Patch nearest-neighbour loss between a synthesized motion and exemplar motions.

    Every temporal patch of the input ``X`` is matched to its nearest patch in the
    pool of patches extracted from all exemplars ``Ys``; the loss is the mean of
    those nearest-neighbour distances.

    Args:
        patch_size: temporal length of a patch; must be odd (asserted).
        stride: patch extraction stride; only 1 is supported (asserted).
        alpha: completeness/diversity trade-off forwarded to ``get_NNs_Dists``;
            ``None`` disables that normalization.
        loop: whether ``X`` is treated as a looping sequence when extracting and
            re-combining its patches (exemplars are never looped).
        cache: if True, exemplar patches are extracted on the first ``forward``
            call and reused until ``clean_cache`` is called. Callers must call
            ``clean_cache`` whenever ``Ys`` changes.
    """

    def __init__(self, patch_size=7, stride=1, alpha=None, loop=False, cache=False):
        super().__init__()
        self.patch_size = patch_size
        assert self.patch_size % 2 == 1, "Only support odd patch size"
        self.stride = stride
        assert self.stride == 1, "Only support stride of 1"
        self.alpha = alpha
        self.loop = loop
        self.cache = cache
        # Always defined (even with cache=False) so reading it or calling
        # clean_cache() never raises AttributeError.
        self.cached_data = None

    def _get_target_patches(self, Ys):
        """Return the exemplar patches concatenated along dim 1, honouring the cache."""
        if self.cache and self.cached_data is not None:
            return self.cached_data
        # Exemplars are never treated as looping sequences.
        y_patches = torch.cat(
            [extract_patches(y, self.patch_size, self.stride, loop=False) for y in Ys],
            dim=1,
        )
        # Only keep a reference when caching is enabled; otherwise the (possibly
        # large) patch tensor would be retained between calls for no reason.
        if self.cache:
            self.cached_data = y_patches
        return y_patches

    def forward(self, X, Ys, dist_wrapper=None, ext=None, return_blended_results=False):
        """Match every patch of ``X`` to its nearest exemplar patch.

        Args:
            X: synthesized motion; its batch dimension must be 1 (asserted).
            Ys: iterable of exemplar motions whose patches form the search pool.
            dist_wrapper: optional callable invoked as
                ``dist_wrapper(efficient_cdist, a, b)`` instead of calling
                ``efficient_cdist(a, b)`` directly.
            ext: unused; kept for interface compatibility.
            return_blended_results: if True, also return the motion rebuilt by
                folding the matched exemplar patches back into ``X``'s shape.

        Returns:
            The mean nearest-neighbour distance, or the tuple
            ``(blended_motion, mean_distance)`` when ``return_blended_results``.
        """
        assert X.shape[0] == 1, "Only support batch size of 1 for X"

        def dist_fn(a, b):
            if dist_wrapper is not None:
                return dist_wrapper(efficient_cdist, a, b)
            return efficient_cdist(a, b)

        x_patches = extract_patches(X, self.patch_size, self.stride, loop=self.loop)
        y_patches = self._get_target_patches(Ys)

        nnf, dist = get_NNs_Dists(dist_fn, x_patches.squeeze(0), y_patches.squeeze(0), self.alpha)

        if return_blended_results:
            blended = combine_patches(X.shape, y_patches[:, nnf, :], self.patch_size, self.stride, loop=self.loop)
            return blended, dist.mean()
        return dist.mean()

    def clean_cache(self):
        """Drop cached exemplar patches so the next ``forward`` re-extracts them."""
        self.cached_data = None
# ---------------------------------------------------------------------------
/nearest_neighbor/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | this file borrows some codes from https://github.com/ariel415el/Efficient-GPNN/blob/main/utils/NN.py. 3 | """ 4 | import torch 5 | import torch.nn.functional as F 6 | import unfoldNd 7 | 8 | def extract_patches(x, patch_size, stride, loop=False): 9 | """Extract patches from a motion sequence""" 10 | b, c, _t = x.shape 11 | 12 | # manually padding to loop 13 | if loop: 14 | half = patch_size // 2 15 | front, tail = x[:,:,:half], x[:,:,-half:] 16 | x = torch.concat([tail, x, front], dim=-1) 17 | 18 | x_patches = unfoldNd.unfoldNd(x, kernel_size=patch_size, stride=stride).transpose(1, 2).reshape(b, -1, c, patch_size) 19 | 20 | return x_patches.view(b, -1, c * patch_size) 21 | 22 | def combine_patches(x_shape, ys, patch_size, stride, loop=False): 23 | """Combine motion patches""" 24 | 25 | # manually handle the loop situation 26 | out_shape = [*x_shape] 27 | if loop: 28 | padding = patch_size // 2 29 | out_shape[-1] = out_shape[-1] + padding * 2 30 | 31 | combined = unfoldNd.foldNd(ys.permute(0, 2, 1), output_size=tuple(out_shape[-1:]), kernel_size=patch_size, stride=stride) 32 | 33 | # normal fold matrix 34 | input_ones = torch.ones(tuple(out_shape), dtype=ys.dtype, device=ys.device) 35 | divisor = unfoldNd.unfoldNd(input_ones, kernel_size=patch_size, stride=stride) 36 | divisor = unfoldNd.foldNd(divisor, output_size=tuple(out_shape[-1:]), kernel_size=patch_size, stride=stride) 37 | combined = (combined / divisor).squeeze(dim=0).unsqueeze(0) 38 | 39 | if loop: 40 | half = patch_size // 2 41 | front, tail = combined[:,:,:half], combined[:,:,-half:] 42 | combined[:, :, half:2 * half] = (combined[:, :, half:2 * half] + tail) / 2 43 | combined[:, :, - 2 * half:-half] = (front + combined[:, :, - 2 * half:-half]) / 2 44 | combined = combined[:, :, half:-half] 45 | 46 | return combined 47 | 48 | 49 | def efficient_cdist(X, Y): 50 | """ 51 | borrowed from 
https://github.com/ariel415el/Efficient-GPNN/blob/main/utils/NN.py 52 | Pytorch efficient way of computing distances between all vectors in X and Y, i.e (X[:, None] - Y[None, :])**2 53 | Get the nearest neighbor index from Y for each X 54 | :param X: (n1, d) tensor 55 | :param Y: (n2, d) tensor 56 | Returns a n2 n1 of indices 57 | """ 58 | dist = (X * X).sum(1)[:, None] + (Y * Y).sum(1)[None, :] - 2.0 * torch.mm(X, torch.transpose(Y, 0, 1)) 59 | d = X.shape[1] 60 | dist /= d # normalize by size of vector to make dists independent of the size of d ( use same alpha for all patche-sizes) 61 | return dist # DO NOT use torch.sqrt 62 | 63 | 64 | def get_col_mins_efficient(dist_fn, X, Y, b=1024): 65 | """ 66 | borrowed from https://github.com/ariel415el/Efficient-GPNN/blob/main/utils/NN.py 67 | Computes the l2 distance to the closest x or each y. 68 | :param X: (n1, d) tensor 69 | :param Y: (n2, d) tensor 70 | Returns n1 long array of L2 distances 71 | """ 72 | n_batches = len(Y) // b 73 | mins = torch.zeros(Y.shape[0], dtype=X.dtype, device=X.device) 74 | for i in range(n_batches): 75 | mins[i * b:(i + 1) * b] = dist_fn(X, Y[i * b:(i + 1) * b]).min(0)[0] 76 | if len(Y) % b != 0: 77 | mins[n_batches * b:] = dist_fn(X, Y[n_batches * b:]).min(0)[0] 78 | 79 | return mins 80 | 81 | 82 | def get_NNs_Dists(dist_fn, X, Y, alpha=None, b=1024): 83 | """ 84 | borrowed from https://github.com/ariel415el/Efficient-GPNN/blob/main/utils/NN.py 85 | Get the nearest neighbor index from Y for each X. 86 | Avoids holding a (n1 * n2) amtrix in order to reducing memory footprint to (b * max(n1,n2)). 
87 | :param X: (n1, d) tensor 88 | :param Y: (n2, d) tensor 89 | Returns a n2 n1 of indices amd distances 90 | """ 91 | if alpha is not None: 92 | normalizing_row = get_col_mins_efficient(dist_fn, X, Y, b=b) 93 | normalizing_row = alpha + normalizing_row[None, :] 94 | else: 95 | normalizing_row = 1 96 | 97 | NNs = torch.zeros(X.shape[0], dtype=torch.long, device=X.device) 98 | Dists = torch.zeros(X.shape[0], dtype=torch.float, device=X.device) 99 | 100 | n_batches = len(X) // b 101 | for i in range(n_batches): 102 | dists = dist_fn(X[i * b:(i + 1) * b], Y) / normalizing_row 103 | NNs[i * b:(i + 1) * b] = dists.min(1)[1] 104 | Dists[i * b:(i + 1) * b] = dists.min(1)[0] 105 | if len(X) % b != 0: 106 | dists = dist_fn(X[n_batches * b:], Y) / normalizing_row 107 | NNs[n_batches * b:] = dists.min(1)[1] 108 | Dists[n_batches * b: ] = dists.min(1)[0] 109 | 110 | return NNs, Dists 111 | -------------------------------------------------------------------------------- /run_random_generation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import argparse 4 | from GenMM import GenMM 5 | from nearest_neighbor.losses import PatchCoherentLoss 6 | from dataset.bvh_motion import BVHMotion, load_multiple_dataset 7 | from utils.base import ConfigParser, set_seed 8 | 9 | args = argparse.ArgumentParser( 10 | description='Random shuffle the input motion sequence') 11 | args.add_argument('-m', '--mode', default='run', 12 | choices=['run', 'eval', 'debug'], type=str, help='current run mode.') 13 | args.add_argument('-i', '--input', required=True, 14 | type=str, help='exemplar motion path.') 15 | args.add_argument('-o', '--output_dir', default='./output', 16 | type=str, help='output folder path for saving results.') 17 | args.add_argument('-c', '--config', default='./configs/default.yaml', 18 | type=str, help='config file path.') 19 | args.add_argument('-s', '--seed', default=None, 20 | type=int, help='random 
seed used.') 21 | args.add_argument('-d', '--device', default="cuda:0", 22 | type=str, help='device to use.') 23 | args.add_argument('--post_precess', action='store_true', 24 | help='whether to use IK post-process to fix foot contact.') 25 | 26 | # Use argsparser to overwrite the configuration 27 | # for dataset 28 | args.add_argument('--skeleton_name', type=str, 29 | help='(used when joint_reduction==True or contact==True) skeleton name to load pre-defined joints configuration.') 30 | args.add_argument('--use_velo', type=int, 31 | help='whether to use velocity rather than global position of each joint.') 32 | args.add_argument('--repr', choices=['repr6d', 'quat', 'euler'], type=str, 33 | help='rotation representation, support [epr6d, quat, reuler].') 34 | args.add_argument('--requires_contact', type=int, 35 | help='whether to use contact label.') 36 | args.add_argument('--keep_up_pos', type=int, 37 | help='whether to do not use velocity and keep the y(up) position.') 38 | args.add_argument('--up_axis', type=str, choices=['X_UP', 'Y_UP', 'Z_UP'], 39 | help='up axis of the motion.') 40 | args.add_argument('--padding_last', type=int, 41 | help='whether to pad the last position channel to match the rotation dimension.') 42 | args.add_argument('--joint_reduction', type=int, 43 | help='whether to simplify the skeleton using provided skeleton config.') 44 | args.add_argument('--skeleton_aware', type=int, 45 | help='whether to enable skeleton-aware component.') 46 | args.add_argument('--joints_group', type=str, 47 | help='joints spliting group for using skeleton-aware component.') 48 | # for synthesis 49 | args.add_argument('--num_frames', type=str, 50 | help='number of synthesized frames, supported Nx(N times) and int input.') 51 | args.add_argument('--alpha', type=float, 52 | help='completeness/diversity trade-off alpha value.') 53 | args.add_argument('--num_steps', type=int, 54 | help='number of optimization steps at each pyramid level.') 55 | 
args.add_argument('--noise_sigma', type=float, 56 | help='standard deviation of the zero mean normal noise added to the initialization.') 57 | args.add_argument('--coarse_ratio', type=float, 58 | help='downscale ratio of the coarse level.') 59 | args.add_argument('--coarse_ratio_factor', type=float, 60 | help='downscale ratio of the coarse level.') 61 | args.add_argument('--pyr_factor', type=float, 62 | help='upsample ratio of each pyramid level.') 63 | args.add_argument('--num_stages_limit', type=int, 64 | help='limit of the number of stages.') 65 | args.add_argument('--patch_size', type=int, help='patch size for generation.') 66 | args.add_argument('--loop', type=int, help='whether to loop the sequence.') 67 | cfg = ConfigParser(args) 68 | 69 | 70 | def generate(cfg): 71 | # seet seed for reproducible 72 | set_seed(cfg.seed) 73 | 74 | # set save path and prepare data for generation 75 | if cfg.input.endswith('.bvh'): 76 | base_dir = osp.join( 77 | cfg.output_dir, cfg.input.split('/')[-1].split('.')[0]) 78 | motion_data = [BVHMotion(cfg.input, skeleton_name=cfg.skeleton_name, repr=cfg.repr, 79 | use_velo=cfg.use_velo, keep_up_pos=cfg.keep_up_pos, up_axis=cfg.up_axis, padding_last=cfg.padding_last, 80 | requires_contact=cfg.requires_contact, joint_reduction=cfg.joint_reduction)] 81 | elif cfg.input.endswith('.txt'): 82 | base_dir = osp.join(cfg.output_dir, cfg.input.split( 83 | '/')[-2], cfg.input.split('/')[-1].split('.')[0]) 84 | motion_data = load_multiple_dataset(name_list=cfg.input, skeleton_name=cfg.skeleton_name, repr=cfg.repr, 85 | use_velo=cfg.use_velo, keep_up_pos=cfg.keep_up_pos, up_axis=cfg.up_axis, padding_last=cfg.padding_last, 86 | requires_contact=cfg.requires_contact, joint_reduction=cfg.joint_reduction) 87 | else: 88 | raise ValueError('exemplar must be a bvh file or a txt file') 89 | prefix = f"s{cfg.seed}+{cfg.num_frames}+{cfg.repr}+use_velo_{cfg.use_velo}+kypose_{cfg.keep_up_pos}+padding_{cfg.padding_last}" \ 90 | 
f"+contact_{cfg.requires_contact}+jredu_{cfg.joint_reduction}+n{cfg.noise_sigma}+pyr{cfg.pyr_factor}" \ 91 | f"+r{cfg.coarse_ratio}_{cfg.coarse_ratio_factor}+itr{cfg.num_steps}+ps_{cfg.patch_size}+alpha_{cfg.alpha}" \ 92 | f"+loop_{cfg.loop}" 93 | 94 | # perform the generation 95 | model = GenMM(device=cfg.device, silent=True if cfg.mode == 'eval' else False) 96 | criteria = PatchCoherentLoss(patch_size=cfg.patch_size, alpha=cfg.alpha, loop=cfg.loop, cache=True) 97 | syn = model.run(motion_data, criteria, 98 | num_frames=cfg.num_frames, 99 | num_steps=cfg.num_steps, 100 | noise_sigma=cfg.noise_sigma, 101 | patch_size=cfg.patch_size, 102 | coarse_ratio=cfg.coarse_ratio, 103 | pyr_factor=cfg.pyr_factor, 104 | debug_dir=save_dir if cfg.mode == 'debug' else None) 105 | 106 | # save the generated results 107 | save_dir = osp.join(base_dir, prefix) 108 | os.makedirs(save_dir, exist_ok=True) 109 | motion_data[0].write(f"{save_dir}/syn.bvh", syn) 110 | 111 | if cfg.post_precess: 112 | cmd = f"python fix_contact.py --prefix {osp.abspath(save_dir)} --name syn --skeleton_name={cfg.skeleton_name}" 113 | os.system(cmd) 114 | 115 | if __name__ == '__main__': 116 | generate(cfg) 117 | -------------------------------------------------------------------------------- /run_web_server.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import torch 4 | import argparse 5 | import gradio as gr 6 | 7 | from GenMM import GenMM 8 | from nearest_neighbor.losses import PatchCoherentLoss 9 | from dataset.tracks_motion import TracksMotion 10 | 11 | args = argparse.ArgumentParser(description='Web server for GenMM') 12 | args.add_argument('-d', '--device', default="cuda:0", type=str, help='device to use.') 13 | args.add_argument('--ip', default="0.0.0.0", type=str, help='interface url to host.') 14 | args.add_argument('--port', default=8000, type=int, help='interface port to serve.') 15 | args.add_argument('--debug', 
action='store_true', help='debug mode.') 16 | args = args.parse_args() 17 | 18 | def generate(data): 19 | data = json.loads(data) 20 | 21 | # create track object 22 | motion_data = [TracksMotion(data['tracks'], repr='repr6d', use_velo=True, keep_y_pos=True, padding_last=False)] 23 | model = GenMM(device=args.device, silent=True) 24 | criteria = PatchCoherentLoss(patch_size=data['setting']['patch_size'], 25 | alpha=data['setting']['alpha'] if data['setting']['completeness'] else None, 26 | loop=data['setting']['loop'], cache=True) 27 | 28 | # start generation 29 | start = time.time() 30 | syn = model.run(motion_data, criteria, 31 | num_frames=str(data['setting']['frames']), 32 | num_steps=data['setting']['num_steps'], 33 | noise_sigma=data['setting']['noise_sigma'], 34 | patch_size=data['setting']['patch_size'], 35 | coarse_ratio=f'{data["setting"]["coarse_ratio"]}x_nframes', 36 | # coarse_ratio=f'3x_patchsize', 37 | pyr_factor=data['setting']['pyr_factor']) 38 | end = time.time() 39 | 40 | data['time'] = end - start 41 | data['tracks'] = motion_data[0].parse(syn) 42 | 43 | return data 44 | 45 | if __name__ == '__main__': 46 | demo = gr.Interface(fn=generate, inputs="json", outputs="json") 47 | demo.launch(debug=args.debug, server_name=args.ip, server_port=args.port) -------------------------------------------------------------------------------- /utils/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | import time 5 | import yaml 6 | import imageio 7 | import random 8 | import shutil 9 | import random 10 | import numpy as np 11 | import torch 12 | from tqdm import tqdm 13 | 14 | # configuration 15 | class ConfigParser(): 16 | def __init__(self, args): 17 | """ 18 | class to parse configuration. 
19 | """ 20 | args = args.parse_args() 21 | self.cfg = self.merge_config_file(args) 22 | 23 | # set random seed 24 | self.set_seed() 25 | 26 | def __str__(self): 27 | return str(self.cfg.__dict__) 28 | 29 | def __getattr__(self, name): 30 | """ 31 | Access items use dot.notation. 32 | """ 33 | return self.cfg.__dict__[name] 34 | 35 | def __getitem__(self, name): 36 | """ 37 | Access items like ordinary dict. 38 | """ 39 | return self.cfg.__dict__[name] 40 | 41 | def merge_config_file(self, args, allow_invalid=True): 42 | """ 43 | Load json config file and merge the arguments 44 | """ 45 | assert args.config is not None 46 | with open(args.config, 'r') as f: 47 | cfg = yaml.safe_load(f) 48 | if 'config' in cfg.keys(): 49 | del cfg['config'] 50 | f.close() 51 | invalid_args = list(set(cfg.keys()) - set(dir(args))) 52 | if invalid_args and not allow_invalid: 53 | raise ValueError(f"Invalid args {invalid_args} in {args.config}.") 54 | 55 | for k in list(cfg.keys()): 56 | if k in args.__dict__.keys() and args.__dict__[k] is not None: 57 | print('=========> overwrite config: {} = {}'.format(k, args.__dict__[k])) 58 | del cfg[k] 59 | 60 | args.__dict__.update(cfg) 61 | 62 | return args 63 | 64 | def set_seed(self): 65 | ''' set random seed for random, numpy and torch. ''' 66 | if 'seed' not in self.cfg.__dict__.keys(): 67 | return 68 | if self.cfg.seed is None: 69 | self.cfg.seed = int(time.time()) % 1000000 70 | print('=========> set random seed: {}'.format(self.cfg.seed)) 71 | # fix random seeds for reproducibility 72 | random.seed(self.cfg.seed) 73 | np.random.seed(self.cfg.seed) 74 | torch.manual_seed(self.cfg.seed) 75 | torch.cuda.manual_seed(self.cfg.seed) 76 | 77 | def save_codes_and_config(self, save_path): 78 | """ 79 | save codes and config to $save_path. 
80 | """ 81 | cur_codes_path = osp.dirname(osp.dirname(os.path.abspath(__file__))) 82 | if os.path.exists(save_path): 83 | shutil.rmtree(save_path) 84 | shutil.copytree(cur_codes_path, osp.join(save_path, 'codes'), \ 85 | ignore=shutil.ignore_patterns('*debug*', '*data*', '*output*', '*exps*', '*.txt', '*.json', '*.mp4', '*.png', '*.jpg', '*.bvh', '*.csv', '*.pth', '*.tar', '*.npz')) 86 | 87 | with open(osp.join(save_path, 'config.yaml'), 'w') as f: 88 | f.write(yaml.dump(self.cfg.__dict__)) 89 | f.close() 90 | 91 | 92 | # logger util 93 | class logger: 94 | """ 95 | Keeps track of the levels and steps of optimization. Logs it via TQDM 96 | """ 97 | def __init__(self, n_steps, n_lvls): 98 | self.n_steps = n_steps 99 | self.n_lvls = n_lvls 100 | self.lvl = -1 101 | self.lvl_step = 0 102 | self.steps = 0 103 | self.pbar = tqdm(total=self.n_lvls * self.n_steps, desc='Starting') 104 | 105 | def step(self): 106 | self.pbar.update(1) 107 | self.steps += 1 108 | self.lvl_step += 1 109 | 110 | def new_lvl(self): 111 | self.lvl += 1 112 | self.lvl_step = 0 113 | 114 | def print(self): 115 | self.pbar.set_description(f'Lvl {self.lvl}/{self.n_lvls-1}, step {self.lvl_step}/{self.n_steps}') 116 | 117 | 118 | # other utils 119 | def set_seed(seed=None): 120 | """ 121 | Set all the seed for the reproducible 122 | """ 123 | if seed is not None: 124 | random.seed(seed) 125 | np.random.seed(seed) 126 | torch.manual_seed(seed) 127 | torch.cuda.manual_seed(seed) -------------------------------------------------------------------------------- /utils/contact.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def foot_contact_by_height(pos): 5 | eps = 0.25 6 | return (-eps < pos[..., 1]) * (pos[..., 1] < eps) 7 | 8 | 9 | def velocity(pos, padding=False): 10 | velo = pos[1:, ...] - pos[:-1, ...] 
11 | velo_norm = torch.norm(velo, dim=-1) 12 | if padding: 13 | pad = torch.zeros_like(velo_norm[:1, :]) 14 | velo_norm = torch.cat([pad, velo_norm], dim=0) 15 | return velo_norm 16 | 17 | 18 | def foot_contact(pos, ref_height=1., threshold=0.018): 19 | velo_norm = velocity(pos) 20 | contact = velo_norm < threshold 21 | contact = contact.int() 22 | padding = torch.zeros_like(contact) 23 | contact = torch.cat([padding[:1, :], contact], dim=0) 24 | return contact 25 | 26 | 27 | def alpha(t): 28 | return 2.0 * t * t * t - 3.0 * t * t + 1 29 | 30 | 31 | def lerp(a, l, r): 32 | return (1 - a) * l + a * r 33 | 34 | 35 | def constrain_from_contact(contact, glb, fid='TBD', L=5): 36 | """ 37 | :param contact: contact label 38 | :param glb: original global position 39 | :param fid: joint id to fix, corresponding to the order in contact 40 | :param L: frame to look forward/backward 41 | :return: 42 | """ 43 | T = glb.shape[0] 44 | 45 | for i, fidx in enumerate(fid): # fidx: index of the foot joint 46 | fixed = contact[:, i] # [T] 47 | s = 0 48 | while s < T: 49 | while s < T and fixed[s] == 0: 50 | s += 1 51 | if s >= T: 52 | break 53 | t = s 54 | avg = glb[t, fidx].clone() 55 | while t + 1 < T and fixed[t + 1] == 1: 56 | t += 1 57 | avg += glb[t, fidx].clone() 58 | avg /= (t - s + 1) 59 | 60 | for j in range(s, t + 1): 61 | glb[j, fidx] = avg.clone() 62 | s = t + 1 63 | 64 | for s in range(T): 65 | if fixed[s] == 1: 66 | continue 67 | l, r = None, None 68 | consl, consr = False, False 69 | for k in range(L): 70 | if s - k - 1 < 0: 71 | break 72 | if fixed[s - k - 1]: 73 | l = s - k - 1 74 | consl = True 75 | break 76 | for k in range(L): 77 | if s + k + 1 >= T: 78 | break 79 | if fixed[s + k + 1]: 80 | r = s + k + 1 81 | consr = True 82 | break 83 | if not consl and not consr: 84 | continue 85 | if consl and consr: 86 | litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)), 87 | glb[s, fidx], glb[l, fidx]) 88 | ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)), 89 | glb[s, fidx], 
glb[r, fidx]) 90 | itp = lerp(alpha(1.0 * (s - l + 1) / (r - l + 1)), 91 | ritp, litp) 92 | glb[s, fidx] = itp.clone() 93 | continue 94 | if consl: 95 | litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)), 96 | glb[s, fidx], glb[l, fidx]) 97 | glb[s, fidx] = litp.clone() 98 | continue 99 | if consr: 100 | ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)), 101 | glb[s, fidx], glb[r, fidx]) 102 | glb[s, fidx] = ritp.clone() 103 | return glb 104 | -------------------------------------------------------------------------------- /utils/kinematics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.transforms import quat2mat, repr6d2mat, euler2mat 3 | 4 | 5 | class ForwardKinematics: 6 | def __init__(self, parents, offsets=None): 7 | self.parents = parents 8 | if offsets is not None and len(offsets.shape) == 2: 9 | offsets = offsets.unsqueeze(0) 10 | self.offsets = offsets 11 | 12 | def forward(self, rots, offsets=None, global_pos=None): 13 | """ 14 | Forward Kinematics: returns a per-bone transformation 15 | @param rots: local joint rotations (batch_size, bone_num, 3, 3) 16 | @param offsets: (batch_size, bone_num, 3) or None 17 | @param global_pos: global_position: (batch_size, 3) or keep it as in offsets (default) 18 | @return: (batch_szie, bone_num, 3, 4) 19 | """ 20 | rots = rots.clone() 21 | if offsets is None: 22 | offsets = self.offsets.to(rots.device) 23 | if global_pos is None: 24 | global_pos = offsets[:, 0] 25 | 26 | pos = torch.zeros((rots.shape[0], rots.shape[1], 3), device=rots.device) 27 | rest_pos = torch.zeros_like(pos) 28 | res = torch.zeros((rots.shape[0], rots.shape[1], 3, 4), device=rots.device) 29 | 30 | pos[:, 0] = global_pos 31 | rest_pos[:, 0] = offsets[:, 0] 32 | 33 | for i, p in enumerate(self.parents): 34 | if i != 0: 35 | rots[:, i] = torch.matmul(rots[:, p], rots[:, i]) 36 | pos[:, i] = torch.matmul(rots[:, p], offsets[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, p] 37 | rest_pos[:, i] = 
rest_pos[:, p] + offsets[:, i] 38 | 39 | res[:, i, :3, :3] = rots[:, i] 40 | res[:, i, :, 3] = torch.matmul(rots[:, i], -rest_pos[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, i] 41 | 42 | return res 43 | 44 | def accumulate(self, local_rots): 45 | """ 46 | Get global joint rotation from local rotations 47 | @param local_rots: (batch_size, n_bone, 3, 3) 48 | @return: global_rotations 49 | """ 50 | res = torch.empty_like(local_rots) 51 | for i, p in enumerate(self.parents): 52 | if i == 0: 53 | res[:, i] = local_rots[:, i] 54 | else: 55 | res[:, i] = torch.matmul(res[:, p], local_rots[:, i]) 56 | return res 57 | 58 | def unaccumulate(self, global_rots): 59 | """ 60 | Get local joint rotation from global rotations 61 | @param global_rots: (batch_size, n_bone, 3, 3) 62 | @return: local_rotations 63 | """ 64 | res = torch.empty_like(global_rots) 65 | inv = torch.empty_like(global_rots) 66 | 67 | for i, p in enumerate(self.parents): 68 | if i == 0: 69 | inv[:, i] = global_rots[:, i].transpose(-2, -1) 70 | res[:, i] = global_rots[:, i] 71 | continue 72 | res[:, i] = torch.matmul(inv[:, p], global_rots[:, i]) 73 | inv[:, i] = torch.matmul(res[:, i].transpose(-2, -1), inv[:, p]) 74 | 75 | return res 76 | 77 | 78 | class ForwardKinematicsJoint: 79 | def __init__(self, parents, offset): 80 | self.parents = parents 81 | self.offset = offset 82 | 83 | ''' 84 | rotation should have shape batch_size * Joint_num * (3/4) * Time 85 | position should have shape batch_size * 3 * Time 86 | offset should have shape batch_size * Joint_num * 3 87 | output have shape batch_size * Time * Joint_num * 3 88 | ''' 89 | 90 | def forward(self, rotation: torch.Tensor, position: torch.Tensor, offset=None, 91 | world=True): 92 | ''' 93 | if not quater and rotation.shape[-2] != 3: raise Exception('Unexpected shape of rotation') 94 | if quater and rotation.shape[-2] != 4: raise Exception('Unexpected shape of rotation') 95 | rotation = rotation.permute(0, 3, 1, 2) 96 | position = position.permute(0, 2, 1) 
97 | ''' 98 | if rotation.shape[-1] == 6: 99 | transform = repr6d2mat(rotation) 100 | elif rotation.shape[-1] == 4: 101 | norm = torch.norm(rotation, dim=-1, keepdim=True) 102 | rotation = rotation / norm 103 | transform = quat2mat(rotation) 104 | elif rotation.shape[-1] == 3: 105 | transform = euler2mat(rotation) 106 | else: 107 | raise Exception('Only accept quaternion rotation input') 108 | result = torch.empty(transform.shape[:-2] + (3,), device=position.device) 109 | 110 | if offset is None: 111 | offset = self.offset 112 | offset = offset.reshape((-1, 1, offset.shape[-2], offset.shape[-1], 1)) 113 | 114 | result[..., 0, :] = position 115 | for i, pi in enumerate(self.parents): 116 | if pi == -1: 117 | assert i == 0 118 | continue 119 | 120 | result[..., i, :] = torch.matmul(transform[..., pi, :, :], offset[..., i, :, :]).squeeze() 121 | transform[..., i, :, :] = torch.matmul(transform[..., pi, :, :].clone(), transform[..., i, :, :].clone()) 122 | if world: result[..., i, :] += result[..., pi, :] 123 | return result 124 | 125 | 126 | class InverseKinematicsJoint: 127 | def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains): 128 | self.rotations = rotations.detach().clone() 129 | self.rotations.requires_grad_(True) 130 | self.position = positions.detach().clone() 131 | self.position.requires_grad_(True) 132 | 133 | self.parents = parents 134 | self.offset = offset 135 | self.constrains = constrains 136 | 137 | self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999)) 138 | self.criteria = torch.nn.MSELoss() 139 | 140 | self.fk = ForwardKinematicsJoint(parents, offset) 141 | 142 | self.glb = None 143 | 144 | def step(self): 145 | self.optimizer.zero_grad() 146 | glb = self.fk.forward(self.rotations, self.position) 147 | loss = self.criteria(glb, self.constrains) 148 | loss.backward() 149 | self.optimizer.step() 150 | self.glb = glb 151 | return loss.item() 152 | 153 | 154 | class 
InverseKinematicsJoint2: 155 | def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains, cid, 156 | lambda_rec_rot=1., lambda_rec_pos=1., use_velo=False): 157 | self.use_velo = use_velo 158 | self.rotations_ori = rotations.detach().clone() 159 | self.rotations = rotations.detach().clone() 160 | self.rotations.requires_grad_(True) 161 | self.position_ori = positions.detach().clone() 162 | self.position = positions.detach().clone() 163 | if self.use_velo: 164 | self.position[1:] = self.position[1:] - self.position[:-1] 165 | self.position.requires_grad_(True) 166 | 167 | self.parents = parents 168 | self.offset = offset 169 | self.constrains = constrains.detach().clone() 170 | self.cid = cid 171 | 172 | self.lambda_rec_rot = lambda_rec_rot 173 | self.lambda_rec_pos = lambda_rec_pos 174 | 175 | self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999)) 176 | self.criteria = torch.nn.MSELoss() 177 | 178 | self.fk = ForwardKinematicsJoint(parents, offset) 179 | 180 | self.glb = None 181 | 182 | def step(self): 183 | self.optimizer.zero_grad() 184 | if self.use_velo: 185 | position = torch.cumsum(self.position, dim=0) 186 | else: 187 | position = self.position 188 | glb = self.fk.forward(self.rotations, position) 189 | self.constrain_loss = self.criteria(glb[:, self.cid], self.constrains) 190 | self.rec_loss_rot = self.criteria(self.rotations, self.rotations_ori) 191 | self.rec_loss_pos = self.criteria(self.position, self.position_ori) 192 | loss = self.constrain_loss + self.rec_loss_rot * self.lambda_rec_rot + self.rec_loss_pos * self.lambda_rec_pos 193 | loss.backward() 194 | self.optimizer.step() 195 | self.glb = glb 196 | return loss.item() 197 | 198 | def get_position(self): 199 | if self.use_velo: 200 | position = torch.cumsum(self.position.detach(), dim=0) 201 | else: 202 | position = self.position.detach() 203 | return position 204 | 
# --------------------------------------------------------------------------------
# /utils/skeleton.py
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np


def _expand_channels(neighbour, channels_per_joint):
    """Return the flat channel indices owned by the joints in ``neighbour`` (joint-major layout)."""
    return [k * channels_per_joint + i for k in neighbour for i in range(channels_per_joint)]


class SkeletonConv(nn.Module):
    """
    Temporal 1-D convolution whose kernel is masked by the skeleton topology.

    Output channels of joint ``i`` only receive the input channels of the joints
    listed in ``neighbour_list[i]``. Joint ``k`` owns channels ``[k * c, (k + 1) * c)``
    where ``c`` is the per-joint channel count.

    Args:
        neighbour_list: for each output joint, the input joint indices it may see.
        in_channels, out_channels: total channel counts; both must be divisible by ``joint_num``.
        kernel_size, stride: as in ``F.conv1d``.
        joint_num: number of joints the channels are split into.
        padding: temporal padding applied on both sides via ``F.pad``.
        bias: whether to learn an additive bias.
        padding_mode: ``'zeros'`` (mapped to ``F.pad``'s ``'constant'``) or another ``F.pad`` mode.
        add_offset: if True, the tensor passed to ``set_offset`` is encoded by a
            ``SkeletonLinear`` and added (divided by 100) to the output.
        in_offset_channel: per-joint channel count of the offset tensor.
    """

    def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num, stride=1, padding=0,
                 bias=True, padding_mode='zeros', add_offset=False, in_offset_channel=0):
        super(SkeletonConv, self).__init__()

        if in_channels % joint_num != 0 or out_channels % joint_num != 0:
            raise Exception('in/out channels should be divided by joint_num')
        self.in_channels_per_joint = in_channels // joint_num
        self.out_channels_per_joint = out_channels // joint_num

        # F.pad names zero padding 'constant'.
        if padding_mode == 'zeros':
            padding_mode = 'constant'

        self.neighbour_list = neighbour_list
        self.add_offset = add_offset
        self.joint_num = joint_num

        self.stride = stride
        self.dilation = 1
        self.groups = 1
        self.padding = padding
        self.padding_mode = padding_mode
        self._padding_repeated_twice = (padding, padding)

        self.expanded_neighbour_list = [
            _expand_channels(neighbour, self.in_channels_per_joint) for neighbour in neighbour_list
        ]
        self.expanded_neighbour_list_offset = []

        if self.add_offset:
            self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels)
            # Expand over every offset channel of each joint (previously this iterated
            # range(add_offset), i.e. only a single channel).
            self.expanded_neighbour_list_offset = [
                _expand_channels(neighbour, in_offset_channel) for neighbour in neighbour_list
            ]

        self.weight = torch.zeros(out_channels, in_channels, kernel_size)
        if bias:
            self.bias = torch.zeros(out_channels)
        else:
            self.register_parameter('bias', None)

        # Binary mask: 1 where an output joint is allowed to see an input channel.
        self.mask = torch.zeros_like(self.weight)
        for i, neighbour in enumerate(self.expanded_neighbour_list):
            self.mask[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...] = 1
        self.mask = nn.Parameter(self.mask, requires_grad=False)

        self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \
                           'joint_num={}, stride={}, padding={}, bias={})'.format(
            in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias
        )

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise each joint's weight block (Kaiming-uniform) and bias block, then wrap them as Parameters."""
        step = self.out_channels_per_joint
        # no_grad makes the in-place slice writes legal when weight/bias are already Parameters.
        with torch.no_grad():
            for i, neighbour in enumerate(self.expanded_neighbour_list):
                rows = slice(step * i, step * (i + 1))
                # Indexing with a list returns a copy, so initialise a temporary and write it back.
                tmp = torch.zeros_like(self.weight[rows, neighbour, ...])
                nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
                self.weight[rows, neighbour, ...] = tmp
                if self.bias is not None:
                    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[rows, neighbour, ...])
                    bound = 1 / math.sqrt(fan_in)
                    tmp = torch.zeros_like(self.bias[rows])
                    nn.init.uniform_(tmp, -bound, bound)
                    self.bias[rows] = tmp

        self.weight = nn.Parameter(self.weight)
        if self.bias is not None:
            self.bias = nn.Parameter(self.bias)

    def set_offset(self, offset):
        """Store the offset tensor (flattened per batch item) consumed by ``forward`` when ``add_offset`` is set."""
        if not self.add_offset:
            raise Exception('Wrong Combination of Parameters')
        self.offset = offset.reshape(offset.shape[0], -1)

    def forward(self, input):
        weight_masked = self.weight * self.mask
        res = F.conv1d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                       weight_masked, self.bias, self.stride,
                       0, self.dilation, self.groups)

        if self.add_offset:
            offset_res = self.offset_enc(self.offset)
            # Broadcast the per-channel offset encoding over the temporal axis.
            offset_res = offset_res.reshape(offset_res.shape + (1, ))
            res += offset_res / 100
        return res

    def __repr__(self):
        return self.description


class SkeletonLinear(nn.Module):
    """
    Fully connected layer whose weight is masked by skeleton topology.

    ``in_channels`` and ``out_channels`` are split evenly over ``len(neighbour_list)``
    joints; output joint ``i`` only sees the input channels of ``neighbour_list[i]``.
    Inputs are flattened to ``(batch, -1)``; with ``extra_dim1`` a trailing axis of
    size 1 is appended to the output.
    """

    def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False):
        super(SkeletonLinear, self).__init__()
        self.neighbour_list = neighbour_list
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_channels_per_joint = in_channels // len(neighbour_list)
        self.out_channels_per_joint = out_channels // len(neighbour_list)
        self.extra_dim1 = extra_dim1
        self.expanded_neighbour_list = [
            _expand_channels(neighbour, self.in_channels_per_joint) for neighbour in neighbour_list
        ]

        self.weight = torch.zeros(out_channels, in_channels)
        self.mask = torch.zeros(out_channels, in_channels)
        self.bias = nn.Parameter(torch.Tensor(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Build the mask and (re)initialise the per-joint weight blocks and the bias."""
        step = self.out_channels_per_joint
        with torch.no_grad():
            for i, neighbour in enumerate(self.expanded_neighbour_list):
                rows = slice(i * step, (i + 1) * step)
                tmp = torch.zeros_like(self.weight[rows, neighbour])
                self.mask[rows, neighbour] = 1
                nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
                self.weight[rows, neighbour] = tmp

        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)

        self.weight = nn.Parameter(self.weight)
        self.mask = nn.Parameter(self.mask, requires_grad=False)

    def forward(self, input):
        input = input.reshape(input.shape[0], -1)
        weight_masked = self.weight * self.mask
        res = F.linear(input, weight_masked, self.bias)
        if self.extra_dim1:
            res = res.reshape(res.shape + (1,))
        return res


class SkeletonPoolJoint(nn.Module):
    """
    Mean-pool joints of a parent-array topology by merging chain segments.

    ``topology[i]`` is the parent of joint ``i`` (negative for the root). After
    construction, ``pooling_list`` holds the groups of merged joints (sorted by
    first joint), ``pooling_map`` maps each old joint to its group and
    ``new_topology`` is the parent array of the pooled skeleton.
    """

    def __init__(self, topology, pooling_mode, channels_per_joint, last_pool=False):
        super(SkeletonPoolJoint, self).__init__()

        if pooling_mode != 'mean':
            raise Exception('Unimplemented pooling mode in matrix_implementation')

        self.joint_num = len(topology)
        self.parent = topology
        self.pooling_list = []
        self.pooling_mode = pooling_mode

        self.pooling_map = [-1 for _ in range(len(self.parent))]
        self.child = [-1 for _ in range(len(self.parent))]
        children_cnt = [0 for _ in range(len(self.parent))]
        for x, pa in enumerate(self.parent):
            if pa < 0:
                continue
            children_cnt[pa] += 1
            self.child[pa] = x
        self.pooling_map[0] = 0
        for x in range(len(self.parent)):
            # Start walking upwards from leaves and from single-child joints whose child branches.
            if children_cnt[x] == 0 or (children_cnt[x] == 1 and children_cnt[self.child[x]] > 1):
                while children_cnt[x] <= 1:
                    pa = self.parent[x]
                    if last_pool:
                        # Collapse the whole unbranched chain into one group.
                        seq = [x]
                        while pa != -1 and children_cnt[pa] == 1:
                            seq = [pa] + seq
                            x = pa
                            pa = self.parent[x]
                        self.pooling_list.append(seq)
                        break
                    else:
                        # Merge pairs of consecutive joints along the chain.
                        if pa != -1 and children_cnt[pa] == 1:
                            self.pooling_list.append([pa, x])
                            # NOTE(review): if parent[pa] is -1, x becomes -1 and indexes from the end
                            # of children_cnt — confirm the root always branches.
                            x = self.parent[pa]
                        else:
                            self.pooling_list.append([x, ])
                            break
            elif children_cnt[x] > 1:
                self.pooling_list.append([x, ])

        self.description = 'SkeletonPool(in_joint_num={}, out_joint_num={})'.format(
            len(topology), len(self.pooling_list),
        )

        self.pooling_list.sort(key=lambda group: group[0])
        for i, group in enumerate(self.pooling_list):
            for j in group:
                self.pooling_map[j] = i

        self.output_joint_num = len(self.pooling_list)
        self.new_topology = [-1 for _ in range(len(self.pooling_list))]
        for i, group in enumerate(self.pooling_list):
            if i < 1:
                continue
            self.new_topology[i] = self.pooling_map[self.parent[group[0]]]

        self.weight = torch.zeros(len(self.pooling_list) * channels_per_joint, self.joint_num * channels_per_joint)

        for i, group in enumerate(self.pooling_list):
            for j in group:
                for c in range(channels_per_joint):
                    self.weight[i * channels_per_joint + c, j * channels_per_joint + c] = 1.0 / len(group)

        self.weight = nn.Parameter(self.weight, requires_grad=False)

    def forward(self, input: torch.Tensor):
        return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)


class SkeletonPool(nn.Module):
    """
    Mean-pool edges of an edge-list skeleton by merging consecutive edges of unbranched chains.

    ``edges[i][0]`` / ``edges[i][1]`` are the two joint indices of edge ``i``; joint 0
    is the start of the traversal. One extra slot (index ``len(edges)``) is appended
    for the global position and is never merged.
    """

    def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):
        super(SkeletonPool, self).__init__()

        if pooling_mode != 'mean':
            raise Exception('Unimplemented pooling mode in matrix_implementation')

        self.channels_per_edge = channels_per_edge
        self.pooling_mode = pooling_mode
        self.edge_num = len(edges) + 1
        self.seq_list = []
        self.pooling_list = []
        self.new_edges = []
        # Size the degree table from the edges instead of a fixed 100 slots, which raised
        # IndexError for joint indices >= 100 (keep at least 100 slots as before).
        n_slots = max([100] + [max(edge[0], edge[1]) + 1 for edge in edges])
        degree = [0] * n_slots

        for edge in edges:
            degree[edge[0]] += 1
            degree[edge[1]] += 1

        def find_seq(j, seq):
            # Depth-first walk from joint j; `seq` accumulates edge indices and is cut
            # at branching joints (degree > 2) and emitted at leaves (degree == 1).
            if degree[j] > 2 and j != 0:
                self.seq_list.append(seq)
                seq = []

            if degree[j] == 1:
                self.seq_list.append(seq)
                return

            for idx, edge in enumerate(edges):
                if edge[0] == j:
                    find_seq(edge[1], seq + [idx])

        find_seq(0, [])
        for seq in self.seq_list:
            if last_pool:
                self.pooling_list.append(seq)
                continue
            # An odd-length chain keeps its first edge unmerged, the rest are merged in pairs.
            if len(seq) % 2 == 1:
                self.pooling_list.append([seq[0]])
                self.new_edges.append(edges[seq[0]])
                seq = seq[1:]
            for i in range(0, len(seq), 2):
                self.pooling_list.append([seq[i], seq[i + 1]])
                self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])

        # add global position
        self.pooling_list.append([self.edge_num - 1])

        self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format(
            len(edges), len(self.pooling_list)
        )

        self.weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)

        for i, group in enumerate(self.pooling_list):
            for j in group:
                for c in range(channels_per_edge):
                    self.weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(group)

        self.weight = nn.Parameter(self.weight, requires_grad=False)

    def forward(self, input: torch.Tensor):
        return torch.matmul(self.weight, input)


class SkeletonUnpool(nn.Module):
    """Inverse of a pooling step: copy each pooled slot back to every slot it was built from."""

    def __init__(self, pooling_list, channels_per_edge):
        super(SkeletonUnpool, self).__init__()
        self.pooling_list = pooling_list
        self.input_joint_num = len(pooling_list)
        self.output_joint_num = sum(len(group) for group in self.pooling_list)
        self.channels_per_edge = channels_per_edge

        self.description = 'SkeletonUnpool(in_joint_num={}, out_joint_num={})'.format(
            self.input_joint_num, self.output_joint_num,
        )

        self.weight = torch.zeros(self.output_joint_num * channels_per_edge, self.input_joint_num * channels_per_edge)

        for i, group in enumerate(self.pooling_list):
            for j in group:
                for c in range(channels_per_edge):
                    self.weight[j * channels_per_edge + c, i * channels_per_edge + c] = 1

        self.weight = nn.Parameter(self.weight)
        self.weight.requires_grad_(False)

    def forward(self, input: torch.Tensor):
        return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)


def find_neighbor_joint(parents, threshold):
    """
    For every joint, list the joints within ``threshold`` hops in the kinematic tree.

    Args:
        parents: parent index per joint; entry 0 (the root) is ignored.
        threshold: maximum graph distance (number of bones) to count as a neighbour.

    Returns:
        list of lists; element ``i`` holds the sorted joint indices within ``threshold`` of joint ``i``.
    """
    n_joint = len(parents)
    # np.int was removed in NumPy 1.24; use a concrete integer dtype.
    dist_mat = np.full((n_joint, n_joint), 100000, dtype=np.int64)
    for i, p in enumerate(parents):
        dist_mat[i, i] = 0
        if i != 0:
            dist_mat[i, p] = dist_mat[p, i] = 1

    # Floyd-Warshall all-pairs shortest paths; row/column k are unchanged during step k,
    # so the (i, j) relaxation can be done for all pairs at once.
    for k in range(n_joint):
        dist_mat = np.minimum(dist_mat, dist_mat[:, k:k + 1] + dist_mat[k:k + 1, :])

    return [[j for j in range(n_joint) if dist_mat[i, j] <= threshold] for i in range(n_joint)]


# --------------------------------------------------------------------------------
# /utils/transforms.py
# --------------------------------------------------------------------------------
import numpy as np
import torch


def batch_mm(matrix, matrix_batch):
    """
    https://github.com/pytorch/pytorch/issues/14489#issuecomment-607730242
    :param matrix: Sparse or dense matrix, size (m, n).
    :param matrix_batch: Batched dense matrices, size (b, n, k).
    :return: The batched matrix-matrix product, size (m, n) x (b, n, k) = (b, m, k).
    """
    batch_size = matrix_batch.shape[0]
    # Stack the vector batch into columns. (b, n, k) -> (n, b, k) -> (n, b*k)
    vectors = matrix_batch.transpose(0, 1).reshape(matrix.shape[1], -1)

    # A matrix-matrix product is a batched matrix-vector product of the columns.
    # And then reverse the reshaping. (m, n) x (n, b*k) = (m, b*k) -> (m, b, k) -> (b, m, k)
    return matrix.mm(vectors).reshape(matrix.shape[0], batch_size, -1).transpose(1, 0)


def aa2quat(rots, form='wxyz', unified_orient=True):
    """
    Convert angle-axis rotations to quaternions.

    @param rots: angle-axis rotations, (*, 3)
    @param form: quaternion layout, either 'wxyz' or 'xyzw'
    @param unified_orient: flip quaternions so the scalar part w is >= 0
        (quaternions double-cover SO(3))
    @return: quaternions, (*, 4), in the requested layout
    @raise ValueError: if ``form`` is not 'wxyz' or 'xyzw'
    """
    if form not in ('wxyz', 'xyzw'):
        raise ValueError("form must be 'wxyz' or 'xyzw', got %r" % (form,))
    angles = rots.norm(dim=-1, keepdim=True)
    norm = angles.clone()
    # Avoid dividing by zero for (near-)identity rotations; the axis is irrelevant there.
    norm[norm < 1e-8] = 1
    axis = rots / norm
    quats = torch.empty(rots.shape[:-1] + (4,), device=rots.device, dtype=rots.dtype)
    angles = angles * 0.5
    if form == 'wxyz':
        w_idx = 0
        quats[..., 0] = torch.cos(angles.squeeze(-1))
        quats[..., 1:] = torch.sin(angles) * axis
    else:
        w_idx = 3
        quats[..., :3] = torch.sin(angles) * axis
        quats[..., 3] = torch.cos(angles.squeeze(-1))

    if unified_orient:
        # Check the scalar part in its actual position (index 3 for 'xyzw').
        idx = quats[..., w_idx] < 0
        quats[idx, :] *= -1

    return quats


def quat2aa(quats):
    """
    Convert wxyz quaternions (*, 4) to angle-axis rotations (*, 3).
    """
    _cos = quats[..., 0]
    xyz = quats[..., 1:]
    _sin = xyz.norm(dim=-1)
    norm = _sin.clone()
    norm[norm < 1e-7] = 1
    axis = xyz / norm.unsqueeze(-1)
    angle = torch.atan2(_sin, _cos) * 2
    return axis * angle.unsqueeze(-1)


def quat2mat(quats: torch.Tensor):
    """
    Convert (w, x, y, z) quaternions to 3x3 rotation matrix
    :param quats: quaternions of shape (..., 4)
    :return: rotation matrices of shape (..., 3, 3)
    """
    qw = quats[..., 0]
    qx = quats[..., 1]
    qy = quats[..., 2]
    qz = quats[..., 3]

    x2 = qx + qx
    y2 = qy + qy
    z2 = qz + qz
    xx = qx * x2
    yy = qy * y2
    wx = qw * x2
    xy = qx * y2
    yz = qy * z2
    wy = qw * y2
    xz = qx * z2
    zz = qz * z2
    wz = qw * z2

    m = torch.empty(quats.shape[:-1] + (3, 3), device=quats.device, dtype=quats.dtype)
    m[..., 0, 0] = 1.0 - (yy + zz)
    m[..., 0, 1] = xy - wz
    m[..., 0, 2] = xz + wy
    m[..., 1, 0] = xy + wz
    m[..., 1, 1] = 1.0 - (xx + zz)
    m[..., 1, 2] = yz - wx
    m[..., 2, 0] = xz - wy
    m[..., 2, 1] = yz + wx
    m[..., 2, 2] = 1.0 - (xx + yy)

    return m


def quat2euler(q, order='xyz', degrees=True):
    """
    Convert (w, x, y, z) quaternions (*, 4) to euler angles (*, 3). Used for bvh output.
    Only order 'xyz' is implemented; angles are in degrees unless ``degrees`` is False.
    """
    q0 = q[..., 0]
    q1 = q[..., 1]
    q2 = q[..., 2]
    q3 = q[..., 3]
    es = torch.empty(q0.shape + (3,), device=q.device, dtype=q.dtype)

    if order == 'xyz':
        es[..., 2] = torch.atan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
        # Clip guards asin against values marginally outside [-1, 1] from rounding.
        es[..., 1] = torch.asin((2 * (q1 * q3 + q0 * q2)).clip(-1, 1))
        es[..., 0] = torch.atan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
    else:
        raise NotImplementedError('Cannot convert to ordering %s' % order)

    if degrees:
        es = es * 180 / np.pi

    return es


def euler2mat(rots, order='xyz'):
    """
    Convert euler angles in degrees (*, 3) to rotation matrices (*, 3, 3).
    The result is R(order[0]) @ R(order[1]) @ R(order[2]).
    """
    axis = {'x': torch.tensor((1, 0, 0), device=rots.device),
            'y': torch.tensor((0, 1, 0), device=rots.device),
            'z': torch.tensor((0, 0, 1), device=rots.device)}

    rots = rots / 180 * np.pi
    mats = []
    for i in range(3):
        aa = axis[order[i]] * rots[..., i].unsqueeze(-1)
        mats.append(aa2mat(aa))
    return mats[0] @ (mats[1] @ mats[2])


def aa2mat(rots):
    """
    Convert angle-axis rotations (*, 3) to rotation matrices (*, 3, 3).
    """
    quat = aa2quat(rots)
    mat = quat2mat(quat)
    return mat


def mat2quat(R) -> torch.Tensor:
    '''
    https://github.com/duolu/pyrotation/blob/master/pyrotation/pyrotation.py
    Convert rotation matrices (*, 3, 3) to unit (w, x, y, z) quaternions (*, 4).

    This uses Shepperd's method for numerical stability: each element is computed
    from whichever of w, x, y, z has the largest magnitude. The matrices must be
    orthonormal.
    '''
    w2 = (1 + R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2])
    x2 = (1 + R[..., 0, 0] - R[..., 1, 1] - R[..., 2, 2])
    y2 = (1 - R[..., 0, 0] + R[..., 1, 1] - R[..., 2, 2])
    z2 = (1 - R[..., 0, 0] - R[..., 1, 1] + R[..., 2, 2])

    yz = (R[..., 1, 2] + R[..., 2, 1])
    xz = (R[..., 2, 0] + R[..., 0, 2])
    xy = (R[..., 0, 1] + R[..., 1, 0])

    wx = (R[..., 2, 1] - R[..., 1, 2])
    wy = (R[..., 0, 2] - R[..., 2, 0])
    wz = (R[..., 1, 0] - R[..., 0, 1])

    w = torch.empty_like(x2)
    x = torch.empty_like(x2)
    y = torch.empty_like(x2)
    z = torch.empty_like(x2)

    # The four flags partition the batch, so every element is written exactly once.
    flagA = (R[..., 2, 2] < 0) & (R[..., 0, 0] > R[..., 1, 1])
    flagB = (R[..., 2, 2] < 0) & (R[..., 0, 0] <= R[..., 1, 1])
    flagC = (R[..., 2, 2] >= 0) & (R[..., 0, 0] < -R[..., 1, 1])
    flagD = (R[..., 2, 2] >= 0) & (R[..., 0, 0] >= -R[..., 1, 1])

    x[flagA] = torch.sqrt(x2[flagA])
    w[flagA] = wx[flagA] / x[flagA]
    y[flagA] = xy[flagA] / x[flagA]
    z[flagA] = xz[flagA] / x[flagA]

    y[flagB] = torch.sqrt(y2[flagB])
    w[flagB] = wy[flagB] / y[flagB]
    x[flagB] = xy[flagB] / y[flagB]
    z[flagB] = yz[flagB] / y[flagB]

    z[flagC] = torch.sqrt(z2[flagC])
    w[flagC] = wz[flagC] / z[flagC]
    x[flagC] = xz[flagC] / z[flagC]
    y[flagC] = yz[flagC] / z[flagC]

    w[flagD] = torch.sqrt(w2[flagD])
    x[flagD] = wx[flagD] / w[flagD]
    y[flagD] = wy[flagD] / w[flagD]
    z[flagD] = wz[flagD] / w[flagD]

    # All components above are twice their true value.
    return torch.stack([w, x, y, z], dim=-1) / 2


def quat2repr6d(quat):
    """Convert (w, x, y, z) quaternions (*, 4) to the 6-D representation (*, 6): the first two matrix rows."""
    mat = quat2mat(quat)
    res = mat[..., :2, :]
    res = res.reshape(res.shape[:-2] + (6, ))
    return res


def repr6d2mat(repr):
    """
    Convert the 6-D rotation representation (*, 6) to rotation matrices (*, 3, 3)
    by Gram-Schmidt orthonormalisation of its two 3-vectors (used as matrix rows).
    """
    x = repr[..., :3]
    y = repr[..., 3:]
    x = x / x.norm(dim=-1, keepdim=True)
    # dim=-1 is explicit: without it torch.cross uses the first axis of size 3,
    # which is wrong whenever a leading (batch/joint) axis happens to have size 3.
    z = torch.cross(x, y, dim=-1)
    z = z / z.norm(dim=-1, keepdim=True)
    y = torch.cross(z, x, dim=-1)
    return torch.stack([x, y, z], dim=-2)


def repr6d2quat(repr) -> torch.Tensor:
    """Convert the 6-D rotation representation (*, 6) to (w, x, y, z) quaternions (*, 4)."""
    return mat2quat(repr6d2mat(repr))


def inv_affine(mat):
    """
    Calculate the inverse of any affine transformation given as (..., 3, 4) matrices.
    Works for any number of leading dimensions and keeps the input's dtype and device.
    """
    affine = torch.zeros(mat.shape[:-2] + (1, 4), dtype=mat.dtype, device=mat.device)
    affine[..., 3] = 1
    vert_mat = torch.cat((mat, affine), dim=-2)
    vert_mat_inv = torch.inverse(vert_mat)
    return vert_mat_inv[..., :3, :]


def inv_rigid_affine(mat):
    """
    Calculate the inverse of a rigid affine transformation (..., 3, 4): [R | t] -> [R^T | -R^T t].
    """
    res = mat.clone()
    res[..., :3] = mat[..., :3].transpose(-2, -1)
    res[..., 3] = -torch.matmul(res[..., :3], mat[..., 3].unsqueeze(-1)).squeeze(-1)
    return res


def generate_pose(batch_size, device, uniform=False, factor=1, root_rot=False, n_bone=None, ee=None):
    """
    Sample random angle-axis poses.

    @param batch_size: number of poses
    @param device: torch device of the result
    @param uniform: if True, angles are uniform in [0, pi); otherwise normal with
        std ``factor * pi / 6``; angles are clamped to [-pi, pi]
    @param factor: scale of the normal angle distribution
    @param root_rot: if False, joint 0's rotation (first three values) is zeroed;
        if True and ``ee`` is given, joint 0 is randomised as well
    @param n_bone: number of joints (default 24)
    @param ee: optional joint indices to randomise, all others stay zero;
        the caller's list is not modified
    @return: tensor of shape (batch_size, 3 * n_bone)
    """
    if n_bone is None:
        n_bone = 24
    if ee is not None:
        # Copy instead of appending in place so the caller's list is not mutated.
        ee = list(ee) + [0] if root_rot else list(ee)
        n_bone_ = n_bone
        n_bone = len(ee)
    axis = torch.randn((batch_size, n_bone, 3), device=device)
    axis /= axis.norm(dim=-1, keepdim=True)
    if uniform:
        angle = torch.rand((batch_size, n_bone, 1), device=device) * np.pi
    else:
        angle = torch.randn((batch_size, n_bone, 1), device=device) * np.pi / 6 * factor
    # clamp is out-of-place; its result was previously discarded.
    angle = angle.clamp(-np.pi, np.pi)
    poses = axis * angle
    if ee is not None:
        res = torch.zeros((batch_size, n_bone_, 3), device=device)
        for i, joint in enumerate(ee):
            res[:, joint] = poses[:, i]
        poses = res
    poses = poses.reshape(batch_size, -1)
    if not root_rot:
        poses[..., :3] = 0
    return poses


def slerp(l, r, t, unit=True):
    """
    Spherical linear interpolation between vectors.

    :param l: shape = (*, n)
    :param r: shape = (*, n)
    :param t: shape = (*)
    :param unit: If l and r are already unit vectors (otherwise they are normalised first)
    :return: interpolated vectors, shape = (*, n)
    """
    eps = 1e-8
    if not unit:
        l_n = l / torch.norm(l, dim=-1, keepdim=True)
        r_n = r / torch.norm(r, dim=-1, keepdim=True)
    else:
        l_n = l
        r_n = r
    omega = torch.acos((l_n * r_n).sum(dim=-1).clamp(-1, 1))
    dom = torch.sin(omega)

    # Nearly (anti)parallel inputs: fall back to linear interpolation to avoid dividing by ~0.
    flag = dom < eps

    res = torch.empty_like(l_n)
    t_t = t[flag].unsqueeze(-1)
    res[flag] = (1 - t_t) * l_n[flag] + t_t * r_n[flag]

    flag = ~flag

    t_t = t[flag]
    d_t = dom[flag]
    va = torch.sin((1 - t_t) * omega[flag]) / d_t
    vb = torch.sin(t_t * omega[flag]) / d_t
    res[flag] = (va.unsqueeze(-1) * l_n[flag] + vb.unsqueeze(-1) * r_n[flag])
    return res


def slerp_quat(l, r, t):
    """
    slerp for unit quaternions, taking the shorter arc
    :param l: (*, 4) unit quaternion
    :param r: (*, 4) unit quaternion
    :param t: (*) scalar between 0 and 1 (broadcast to l.shape[:-1])
    """
    t = t.expand(l.shape[:-1])
    # q and -q are the same rotation; negate l where needed so the dot product is >= 0.
    flag = (l * r).sum(dim=-1) >= 0
    res = torch.empty_like(l)
    res[flag] = slerp(l[flag], r[flag], t[flag])
    flag = ~flag
    res[flag] = slerp(-l[flag], r[flag], t[flag])
    return res


def interpolate_6d(input, size):
    """
    Resample 6-D rotations along time using quaternion slerp.

    :param input: (batch_size, n_joint * 6, length)
    :param size: required output size for temporal axis
    :return: (batch_size, n_joint * 6, size)
    """
    batch = input.shape[0]
    length = input.shape[-1]
    input = input.reshape((batch, -1, 6, length))
    input = input.permute(0, 1, 3, 2)  # (batch_size, n_joint, length, 6)
    input_q = repr6d2quat(input)
    # Output frame i samples input position i / size * (length - 1); its floor is at most length - 2.
    idx = torch.arange(size, device=input_q.device, dtype=torch.float) / size * (length - 1)
    idx_l = torch.floor(idx)
    t = idx - idx_l
    idx_l = idx_l.long()
    idx_r = idx_l + 1
    t = t.reshape((1, 1, -1))
    res_q = slerp_quat(input_q[..., idx_l, :], input_q[..., idx_r, :], t)
    res = quat2repr6d(res_q)  # shape = (batch_size, n_joint, t, 6)
    res = res.permute(0, 1, 3, 2)
    res = res.reshape((batch, -1, size))
    return res
# --------------------------------------------------------------------------------