├── GenMM.py
├── LICENSE
├── README.md
├── __init__.py
├── configs
│   ├── default.yaml
│   └── ganimator.yaml
├── dataset
│   ├── blender_motion.py
│   ├── bvh
│   │   ├── Quaternions.py
│   │   ├── __pycache__
│   │   │   ├── Quaternions.cpython-37.pyc
│   │   │   ├── bvh_io.cpython-37.pyc
│   │   │   ├── bvh_parser.cpython-37.pyc
│   │   │   └── bvh_writer.cpython-37.pyc
│   │   ├── bvh_io.py
│   │   ├── bvh_parser.py
│   │   └── bvh_writer.py
│   ├── bvh_motion.py
│   ├── motion.py
│   └── tracks_motion.py
├── demo.blend
├── docker
│   ├── Dockerfile
│   ├── README.md
│   ├── apt-sources.list
│   ├── requirements.txt
│   └── requirements_blender.txt
├── fix_contact.py
├── nearest_neighbor
│   ├── losses.py
│   └── utils.py
├── run_random_generation.py
├── run_web_server.py
└── utils
    ├── base.py
    ├── contact.py
    ├── kinematics.py
    ├── skeleton.py
    └── transforms.py
/GenMM.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from utils.base import logger
8 |
9 | class GenMM:
10 |     def __init__(self, mode='random_synthesis', noise_sigma=1.0, coarse_ratio=0.2, coarse_ratio_factor=6, pyr_factor=0.75, num_stages_limit=-1, device='cuda:0', silent=False):
11 |         '''
12 |         GenMM main constructor. Only `device` and `silent` are used here; the remaining arguments mirror the defaults documented in `run()`.
13 |         Args:
14 |             device : str = 'cuda:0', default device.
15 |             silent : bool = False, whether to mute the output.
16 |         '''
17 | self.device = torch.device(device)
18 | self.silent = silent
19 |
20 | def _get_pyramid_lengths(self, final_len, coarse_ratio, pyr_factor):
21 | '''
22 |         Get a list of pyramid lengths for the given target length, coarse ratio, and pyramid factor
23 | '''
24 | lengths = [int(np.round(final_len * coarse_ratio))]
25 | while lengths[-1] < final_len:
26 | lengths.append(int(np.round(lengths[-1] / pyr_factor)))
27 | if lengths[-1] == lengths[-2]:
28 | lengths[-1] += 1
29 | lengths[-1] = final_len
30 |
31 | return lengths
32 |
33 | def _get_target_pyramid(self, target, coarse_ratio, pyr_factor, num_stages_limit=-1):
34 | '''
35 |         Reads the target motion(s) and creates a pyramid out of them, ordered by increasing size
36 | '''
37 | self.num_target = len(target)
38 | lengths = []
39 | min_len = 10000
40 | for i in range(len(target)):
41 | new_length = self._get_pyramid_lengths(len(target[i].motion_data), coarse_ratio, pyr_factor)
42 | min_len = min(min_len, len(new_length))
43 | if num_stages_limit != -1:
44 | new_length = new_length[:num_stages_limit]
45 | lengths.append(new_length)
46 | for i in range(len(target)):
47 | lengths[i] = lengths[i][-min_len:]
48 |         self.pyramid_lengths = lengths
49 |
50 | target_pyramid = [[] for _ in range(len(lengths[0]))]
51 | for step in range(len(lengths[0])):
52 | for i in range(len(target)):
53 | length = lengths[i][step]
54 | target_pyramid[step].append(target[i].sample(size=length).to(self.device))
55 |
56 | if not self.silent:
57 | print('Levels:', lengths)
58 | for i in range(len(target_pyramid)):
59 | print(f'Number of clips in target pyramid {i} is {len(target_pyramid[i])}, ranging {[[tgt.min(), tgt.max()] for tgt in target_pyramid[i]]}')
60 |
61 | return target_pyramid
62 |
63 | def _get_initial_motion(self, init_length, noise_sigma):
64 | '''
65 | Prepare the initial motion for optimization
66 | '''
67 | initial_motion = F.interpolate(torch.cat([self.target_pyramid[0][i] for i in range(self.num_target)], dim=-1),
68 | size=init_length, mode='linear', align_corners=True)
69 | if noise_sigma > 0:
70 | initial_motion_w_noise = initial_motion + torch.randn_like(initial_motion) * noise_sigma
71 | initial_motion_w_noise = torch.fmod(initial_motion_w_noise, 1.0)
72 | else:
73 | initial_motion_w_noise = initial_motion
74 |
75 | if not self.silent:
76 | print('Initial motion:', initial_motion.min(), initial_motion.max())
77 | print('Initial motion with noise:', initial_motion_w_noise.min(), initial_motion_w_noise.max())
78 |
79 | return initial_motion_w_noise
80 |
81 | def run(self, target, criteria, num_frames, num_steps, noise_sigma, patch_size, coarse_ratio, pyr_factor, ext=None, debug_dir=None):
82 |         '''
83 |         Generation function
84 |         Args:
85 |             target / criteria : exemplar motion(s) and the matching criterion minimized at every pyramid level.
86 |             num_frames        : str, a digit string (absolute frame count) or 'Nx_nframes' (N times the total exemplar length).
87 |             num_steps         : int, number of optimization steps per pyramid level.
88 |             noise_sigma       : float = 1.0, sigma of the noise added to the initial coarse motion.
89 |             patch_size / coarse_ratio : patch size and coarsest-level specifier ('Nx_patchsize' or 'Nx_nframes').
90 |             pyr_factor / ext / debug_dir : pyramid factor (e.g. 0.75), optional extra constraints, and optional TensorBoard log directory.
91 |         '''
92 | if debug_dir is not None:
93 | from tensorboardX import SummaryWriter
94 | writer = SummaryWriter(log_dir=debug_dir)
95 |
96 | # build target pyramid
97 | if 'patchsize' in coarse_ratio:
98 | coarse_ratio = patch_size * float(coarse_ratio.split('x_')[0]) / max([len(t.motion_data) for t in target])
99 | elif 'nframes' in coarse_ratio:
100 | coarse_ratio = float(coarse_ratio.split('x_')[0])
101 | else:
102 | raise ValueError('Unsupported coarse ratio specified')
103 | self.target_pyramid = self._get_target_pyramid(target, coarse_ratio, pyr_factor)
104 |
105 | # get the initial motion data
106 | if 'nframes' in num_frames:
107 |             syn_length = int(sum([i[-1] for i in self.pyramid_lengths]) * float(num_frames.split('x_')[0]))
108 | elif num_frames.isdigit():
109 | syn_length = int(num_frames)
110 | else:
111 |             raise ValueError(f'Unsupported num_frames {num_frames}')
112 | self.synthesized_lengths = self._get_pyramid_lengths(syn_length, coarse_ratio, pyr_factor)
113 | if not self.silent:
114 | print('Synthesized lengths:', self.synthesized_lengths)
115 | self.synthesized = self._get_initial_motion(self.synthesized_lengths[0], noise_sigma)
116 |
117 | # perform the optimization
118 | self.synthesized.requires_grad_(False)
119 | self.pbar = logger(num_steps, len(self.target_pyramid))
120 | for lvl, lvl_target in enumerate(self.target_pyramid):
121 | self.pbar.new_lvl()
122 | if lvl > 0:
123 | with torch.no_grad():
124 | self.synthesized = F.interpolate(self.synthesized.detach(), size=self.synthesized_lengths[lvl], mode='linear')
125 |
126 | self.synthesized, losses = GenMM.match_and_blend(self.synthesized, lvl_target, criteria, num_steps, self.pbar, ext=ext)
127 |
128 | criteria.clean_cache()
129 | if debug_dir is not None:
130 | for itr in range(len(losses)):
131 | writer.add_scalar(f'optimize/losses_lvl{lvl}', losses[itr], itr)
132 | self.pbar.pbar.close()
133 |
134 | return self.synthesized.detach()
135 |
136 |
137 | @staticmethod
138 | @torch.no_grad()
139 | def match_and_blend(synthesized, targets, criteria, n_steps, pbar, ext=None):
140 | '''
141 |         Minimizes criteria between synthesized and targets
142 | Args:
143 | synthesized : torch.Tensor, optimized motion data
144 | targets : torch.Tensor, target motion data
145 |             criteria    : objective function to minimize
146 | n_steps : int, number of steps to optimize
147 | pbar : logger
148 | ext : extra configurations or constraints (optional)
149 | '''
150 | losses = []
151 | for _i in range(n_steps):
152 | synthesized, loss = criteria(synthesized, targets, ext=ext, return_blended_results=True)
153 |
154 |             # Update status
155 | losses.append(loss.item())
156 | pbar.step()
157 | pbar.print()
158 |
159 | return synthesized, losses
160 |
161 |
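As a side note on the schedule above, here is a standalone sketch of how `_get_pyramid_lengths` grows the levels. The function body is copied from the file; the input numbers are made up for illustration:

```python
import numpy as np

def get_pyramid_lengths(final_len, coarse_ratio, pyr_factor):
    # Start at the coarsest length, then grow by 1/pyr_factor until final_len is reached.
    lengths = [int(np.round(final_len * coarse_ratio))]
    while lengths[-1] < final_len:
        lengths.append(int(np.round(lengths[-1] / pyr_factor)))
        if lengths[-1] == lengths[-2]:
            lengths[-1] += 1  # guarantee strictly increasing lengths
    lengths[-1] = final_len   # snap the finest level to the exact target
    return lengths

# For a 120-frame exemplar with the defaults coarse_ratio=0.2 and pyr_factor=0.75:
print(get_pyramid_lengths(120, 0.2, 0.75))  # [24, 32, 43, 57, 76, 101, 120]
```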
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Example-based Motion Synthesis via Generative Motion Matching, ACM Transactions on Graphics (Proceedings of SIGGRAPH 2023)
2 |
3 | ##### [Weiyu Li*](https://wyysf-98.github.io/), [Xuelin Chen*†](https://xuelin-chen.github.io/), [Peizhuo Li](https://peizhuoli.github.io/), [Olga Sorkine-Hornung](https://igl.ethz.ch/people/sorkine/), [Baoquan Chen](https://cfcs.pku.edu.cn/baoquan/)
4 |
5 | #### [Project Page](https://wyysf-98.github.io/GenMM) | [ArXiv](https://arxiv.org/abs/2306.00378) | [Paper](https://wyysf-98.github.io/GenMM/paper/Paper_high_res.pdf) | [Video](https://youtu.be/lehnxcade4I)
6 |
7 |
8 |
9 |
10 |
11 | All code and demos will be released this week (still ongoing...) 🏗️ 🚧 🔨
12 |
13 | - [x] Release main code
14 | - [x] Release blender addon
15 | - [x] Detailed README and installation guide
16 | - [ ] Release skeleton-aware component (WIP, as we need to split the joints into groups manually)
17 | - [ ] Release codes for evaluation
18 |
19 | ## Prerequisites
20 |
21 | Set up the environment:
22 |
23 | :smiley: We also provide a Dockerfile for easy installation, see [Setup using Docker](./docker/README.md).
24 |
25 | - Python 3.8
26 | - PyTorch 1.12.1
27 | - [unfoldNd](https://github.com/f-dangel/unfoldNd)
28 |
29 | Clone this repository.
30 |
31 | ```sh
32 | git clone git@github.com:wyysf-98/GenMM.git
33 | ```
34 |
35 | Install the required packages.
36 |
37 | ```sh
38 | conda create -n GenMM python=3.8
39 | conda activate GenMM
40 | conda install -c pytorch pytorch=1.12.1 torchvision=0.13.1 cudatoolkit=11.3 && \
41 | pip install -r docker/requirements.txt
42 | pip install torch-scatter==2.1.1
43 | ```
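As a quick sanity check of the environment, an import test along these lines should succeed (assuming `unfoldNd` is pulled in by `docker/requirements.txt`):

```sh
python -c "import torch, unfoldNd; print(torch.__version__, torch.cuda.is_available())"
```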
44 |
45 |
46 |
47 | ## Quick inference demo
48 | For a quick local inference demo using a `.bvh` file, run:
49 |
50 | ```sh
51 | python run_random_generation.py -i './data/Malcolm/Gangnam-Style.bvh'
52 | ```
53 | More configuration options can be found in `run_random_generation.py`.
54 | We used an Apple M1 and an NVIDIA Tesla V100 with 32 GB RAM; generating each motion takes about 0.2s and 0.05s respectively, as reported in our paper.
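If you prefer driving the synthesizer from Python instead of the CLI, a rough sketch follows. It is assembled from the imports in `__init__.py` and the `run()` signature in `GenMM.py`; the `BVHMotion` loader name and the `PatchCoherentLoss` constructor arguments are assumptions, not verified against `run_random_generation.py`:

```python
import torch
from GenMM import GenMM
from nearest_neighbor.losses import PatchCoherentLoss
from dataset.bvh_motion import BVHMotion  # assumed loader name

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
motion = BVHMotion('./data/Malcolm/Gangnam-Style.bvh')  # one exemplar
criteria = PatchCoherentLoss(patch_size=11)             # assumed constructor arguments

model = GenMM(device=device, silent=True)
synthesized = model.run([motion], criteria,
                        num_frames='2x_nframes',      # 2x the exemplar length
                        num_steps=3,
                        noise_sigma=1.0,
                        patch_size=11,
                        coarse_ratio='0.2x_nframes',  # coarsest level at 20% of the frames
                        pyr_factor=0.75)
```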
55 |
56 | ## Blender add-on
57 | The Blender add-on is easy to install: our method is efficient, so you do not need to install the CUDA Toolkit.
58 | We tested our code with Blender 3.2.0; support for 2.8.0 is planned.
59 |
60 | Step 1: Find your Blender Python path. Common paths are as follows:
61 | ```sh
62 | (Windows) 'C:\Program Files\Blender Foundation\Blender 3.2\3.2\python\bin'
63 | (Linux)   '/path/to/blender/blender-path/3.2/python/bin'
64 | (macOS)   '/Applications/Blender.app/Contents/Resources/3.2/python/bin'
65 | ```
66 |
67 | Step 2: Install the required packages. Open your shell (Linux/macOS) or PowerShell (Windows):
68 | ```sh
69 | cd {your python path} && pip3 install -r docker/requirements.txt && pip3 install torch-scatter==2.1.0 -f https://data.pyg.org/whl/torch-1.12.0+${CUDA}.html
70 | ```
71 | where `${CUDA}` should be replaced by `cpu`, `cu117`, or `cu118`, depending on your PyTorch installation.
72 | For example, on macOS with an M1 CPU:
73 |
74 | ```sh
75 | cd /Applications/Blender.app/Contents/Resources/3.2/python/bin && pip3 install -r docker/requirements_blender.txt && pip3 install torch-scatter==2.1.0 -f https://data.pyg.org/whl/torch-1.12.0+cpu.html
76 | ```
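To confirm the packages landed in Blender's bundled interpreter rather than your system Python, a quick check along these lines helps (Blender 3.2 ships Python 3.10; the binary name may differ for other versions):

```sh
./python3.10 -c "import torch, torch_scatter; print(torch.__version__)"
```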
77 |
78 | Step 3: Install the add-on in Blender ([Blender Add-ons Official Tutorial](https://docs.blender.org/manual/en/latest/editors/preferences/addons.html)): `Edit -> Preferences -> Add-ons -> Install -> Select the downloaded .zip file`.
79 |
80 | Step 4: Have fun! Click the armature and you will find a `GenMM` tab.
81 |
82 | (GPU support) If you have a GPU and the CUDA Toolkit installed, the running device is detected automatically.
83 |
84 | Feel free to open an issue if you run into any problems during installation :)
85 |
86 | ## Acknowledgement
87 |
88 | We thank [@stefanonuvoli](https://github.com/stefanonuvoli/skinmixer) for helpful discussions on the implementation of the `Motion Reassembly` part (we eventually merged the meshes of different characters manually), and [@Radamés Ajna](https://github.com/radames) for helping build a better Hugging Face demo.
89 |
90 |
91 | ## Citation
92 |
93 | If you find our work useful for your research, please consider citing using the following BibTeX entry.
94 |
95 | ```BibTeX
96 | @article{10.1145/weiyu23GenMM,
97 | author = {Li, Weiyu and Chen, Xuelin and Li, Peizhuo and Sorkine-Hornung, Olga and Chen, Baoquan},
98 | title = {Example-Based Motion Synthesis via Generative Motion Matching},
99 | journal = {ACM Transactions on Graphics (TOG)},
100 | volume = {42},
101 | number = {4},
102 | year = {2023},
103 | articleno = {94},
104 | doi = {10.1145/3592395},
105 | publisher = {Association for Computing Machinery},
106 | }
107 | ```
108 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # This program is free software; you can redistribute it and/or modify
2 | # it under the terms of the GNU General Public License as published by
3 | # the Free Software Foundation; either version 3 of the License, or
4 | # (at your option) any later version.
5 | #
6 | # This program is distributed in the hope that it will be useful, but
7 | # WITHOUT ANY WARRANTY; without even the implied warranty of
8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
9 | # General Public License for more details.
10 | #
11 | # You should have received a copy of the GNU General Public License
12 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
13 | import os
14 | import sys
15 | import bpy
16 | import torch
17 | import mathutils
18 | import numpy as np
19 | from math import degrees, radians, ceil
20 | from mathutils import Vector, Matrix, Euler
21 | from typing import List, Iterable, Tuple, Any, Dict
22 |
23 | abs_path = os.path.abspath(__file__)
24 | sys.path.append(os.path.dirname(abs_path))
25 | from GenMM import GenMM
26 | from nearest_neighbor.losses import PatchCoherentLoss
27 | from dataset.blender_motion import BlenderMotion
28 |
29 | bl_info = {
30 |     "name" : "GenMM",
31 |     "author" : "Weiyu Li",
32 |     "description" : "Blender add-on for the SIGGRAPH paper 'Example-Based Motion Synthesis via "
33 |                     "Generative Motion Matching': synthesize novel motions from a few exemplars.",
34 |     "blender" : (3, 2, 0),
35 |     "version" : (0, 0, 1),
36 |     "location" : "3D View",
37 |     "support" : "TESTING",
38 |     "warning" : "",
39 |     "category" : "Generic"
40 | }
41 |
42 |
43 |
44 | # This function is modified from
45 | # https://github.com/bwrsandman/blender-addons/blob/master/io_anim_bvh
46 | def get_bvh_data(context,
47 | frame_end,
48 | frame_start,
49 | global_scale=1.0,
50 | rotate_mode='NATIVE',
51 | root_transform_only=False,
52 | ):
53 |
54 | def ensure_rot_order(rot_order_str):
55 | if set(rot_order_str) != {'X', 'Y', 'Z'}:
56 | rot_order_str = "XYZ"
57 | return rot_order_str
58 |
59 | file_str = []
60 |
61 | obj = context.object
62 | arm = obj.data
63 |
64 | # Build a dictionary of children.
65 | # None for parentless
66 | children = {None: []}
67 |
68 | # initialize with blank lists
69 | for bone in arm.bones:
70 | children[bone.name] = []
71 |
72 |     # keep bone order from the armature, no sorting; not essential, but it means
73 |     # we can maintain the order from import -> export, which Second Life incorrectly expects.
74 | for bone in arm.bones:
75 | children[getattr(bone.parent, "name", None)].append(bone.name)
76 |
77 | # bone name list in the order that the bones are written
78 | serialized_names = []
79 |
80 | node_locations = {}
81 |
82 | file_str.append("HIERARCHY\n")
83 |
84 | def write_recursive_nodes(bone_name, indent):
85 | my_children = children[bone_name]
86 |
87 | indent_str = "\t" * indent
88 |
89 | bone = arm.bones[bone_name]
90 | pose_bone = obj.pose.bones[bone_name]
91 | loc = bone.head_local
92 | node_locations[bone_name] = loc
93 |
94 | if rotate_mode == "NATIVE":
95 | rot_order_str = ensure_rot_order(pose_bone.rotation_mode)
96 | else:
97 | rot_order_str = rotate_mode
98 |
99 | # make relative if we can
100 | if bone.parent:
101 | loc = loc - node_locations[bone.parent.name]
102 |
103 | if indent:
104 | file_str.append("%sJOINT %s\n" % (indent_str, bone_name))
105 | else:
106 | file_str.append("%sROOT %s\n" % (indent_str, bone_name))
107 |
108 | file_str.append("%s{\n" % indent_str)
109 | file_str.append("%s\tOFFSET %.6f %.6f %.6f\n" % (indent_str, loc.x * global_scale, loc.y * global_scale, loc.z * global_scale))
110 | if (bone.use_connect or root_transform_only) and bone.parent:
111 | file_str.append("%s\tCHANNELS 3 %srotation %srotation %srotation\n" % (indent_str, rot_order_str[0], rot_order_str[1], rot_order_str[2]))
112 | else:
113 | file_str.append("%s\tCHANNELS 6 Xposition Yposition Zposition %srotation %srotation %srotation\n" % (indent_str, rot_order_str[0], rot_order_str[1], rot_order_str[2]))
114 |
115 | if my_children:
116 | # store the location for the children
117 | # to get their relative offset
118 |
119 | # Write children
120 | for child_bone in my_children:
121 | serialized_names.append(child_bone)
122 | write_recursive_nodes(child_bone, indent + 1)
123 |
124 | else:
125 | # Write the bone end.
126 | file_str.append("%s\tEnd Site\n" % indent_str)
127 | file_str.append("%s\t{\n" % indent_str)
128 | loc = bone.tail_local - node_locations[bone_name]
129 | file_str.append("%s\t\tOFFSET %.6f %.6f %.6f\n" % (indent_str, loc.x * global_scale, loc.y * global_scale, loc.z * global_scale))
130 | file_str.append("%s\t}\n" % indent_str)
131 |
132 | file_str.append("%s}\n" % indent_str)
133 |
134 | if len(children[None]) == 1:
135 | key = children[None][0]
136 | serialized_names.append(key)
137 | indent = 0
138 |
139 | write_recursive_nodes(key, indent)
140 |
141 | else:
142 | # Write a dummy parent node, with a dummy key name
143 | # Just be sure it's not used by another bone!
144 | i = 0
145 | key = "__%d" % i
146 | while key in children:
147 | i += 1
148 | key = "__%d" % i
149 | file_str.append("ROOT %s\n" % key)
150 | file_str.append("{\n")
151 | file_str.append("\tOFFSET 0.0 0.0 0.0\n")
152 | file_str.append("\tCHANNELS 0\n") # Xposition Yposition Zposition Xrotation Yrotation Zrotation
153 | indent = 1
154 |
155 | # Write children
156 | for child_bone in children[None]:
157 | serialized_names.append(child_bone)
158 | write_recursive_nodes(child_bone, indent)
159 |
160 | file_str.append("}\n")
161 | file_str = ''.join(file_str)
162 | # redefine bones as sorted by serialized_names
163 | # so we can write motion
164 |
165 | class DecoratedBone:
166 | __slots__ = (
167 | # Bone name, used as key in many places.
168 | "name",
169 | "parent", # decorated bone parent, set in a later loop
170 | # Blender armature bone.
171 | "rest_bone",
172 | # Blender pose bone.
173 | "pose_bone",
174 | # Blender pose matrix.
175 | "pose_mat",
176 | # Blender rest matrix (armature space).
177 | "rest_arm_mat",
178 | # Blender rest matrix (local space).
179 | "rest_local_mat",
180 | # Pose_mat inverted.
181 | "pose_imat",
182 | # Rest_arm_mat inverted.
183 | "rest_arm_imat",
184 | # Rest_local_mat inverted.
185 | "rest_local_imat",
186 | # Last used euler to preserve euler compatibility in between keyframes.
187 | "prev_euler",
188 | # Is the bone disconnected to the parent bone?
189 | "skip_position",
190 | "rot_order",
191 | "rot_order_str",
192 | # Needed for the euler order when converting from a matrix.
193 | "rot_order_str_reverse",
194 | )
195 |
196 | _eul_order_lookup = {
197 | 'XYZ': (0, 1, 2),
198 | 'XZY': (0, 2, 1),
199 | 'YXZ': (1, 0, 2),
200 | 'YZX': (1, 2, 0),
201 | 'ZXY': (2, 0, 1),
202 | 'ZYX': (2, 1, 0),
203 | }
204 |
205 | def __init__(self, bone_name):
206 | self.name = bone_name
207 | self.rest_bone = arm.bones[bone_name]
208 | self.pose_bone = obj.pose.bones[bone_name]
209 |
210 | if rotate_mode == "NATIVE":
211 | self.rot_order_str = ensure_rot_order(self.pose_bone.rotation_mode)
212 | else:
213 | self.rot_order_str = rotate_mode
214 | self.rot_order_str_reverse = self.rot_order_str[::-1]
215 |
216 | self.rot_order = DecoratedBone._eul_order_lookup[self.rot_order_str]
217 |
218 | self.pose_mat = self.pose_bone.matrix
219 |
220 | # mat = self.rest_bone.matrix # UNUSED
221 | self.rest_arm_mat = self.rest_bone.matrix_local
222 | self.rest_local_mat = self.rest_bone.matrix
223 |
224 | # inverted mats
225 | self.pose_imat = self.pose_mat.inverted()
226 | self.rest_arm_imat = self.rest_arm_mat.inverted()
227 | self.rest_local_imat = self.rest_local_mat.inverted()
228 |
229 | self.parent = None
230 | self.prev_euler = Euler((0.0, 0.0, 0.0), self.rot_order_str_reverse)
231 | self.skip_position = ((self.rest_bone.use_connect or root_transform_only) and self.rest_bone.parent)
232 |
233 | def update_posedata(self):
234 | self.pose_mat = self.pose_bone.matrix
235 | self.pose_imat = self.pose_mat.inverted()
236 |
237 | def __repr__(self):
238 | if self.parent:
239 | return "[\"%s\" child on \"%s\"]\n" % (self.name, self.parent.name)
240 | else:
241 | return "[\"%s\" root bone]\n" % (self.name)
242 |
243 | bones_decorated = [DecoratedBone(bone_name) for bone_name in serialized_names]
244 |
245 | # Assign parents
246 | bones_decorated_dict = {dbone.name: dbone for dbone in bones_decorated}
247 | for dbone in bones_decorated:
248 | parent = dbone.rest_bone.parent
249 | if parent:
250 | dbone.parent = bones_decorated_dict[parent.name]
251 | del bones_decorated_dict
252 | # finish assigning parents
253 |
254 | scene = context.scene
255 | frame_current = scene.frame_current
256 |
257 | file_str += "MOTION\n"
258 | file_str += "Frames: %d\n" % (frame_end - frame_start + 1)
259 | file_str += "Frame Time: %.6f\n" % (1.0 / (scene.render.fps / scene.render.fps_base))
260 |
261 | for frame in range(frame_start, frame_end + 1):
262 | scene.frame_set(frame)
263 |
264 | for dbone in bones_decorated:
265 | dbone.update_posedata()
266 |
267 | for dbone in bones_decorated:
268 | trans = Matrix.Translation(dbone.rest_bone.head_local)
269 | itrans = Matrix.Translation(-dbone.rest_bone.head_local)
270 |
271 | if dbone.parent:
272 | mat_final = dbone.parent.rest_arm_mat @ dbone.parent.pose_imat @ dbone.pose_mat @ dbone.rest_arm_imat
273 | mat_final = itrans @ mat_final @ trans
274 | loc = mat_final.to_translation() + (dbone.rest_bone.head_local - dbone.parent.rest_bone.head_local)
275 | else:
276 | mat_final = dbone.pose_mat @ dbone.rest_arm_imat
277 | mat_final = itrans @ mat_final @ trans
278 | loc = mat_final.to_translation() + dbone.rest_bone.head
279 |
280 | # keep eulers compatible, no jumping on interpolation.
281 | rot = mat_final.to_euler(dbone.rot_order_str_reverse, dbone.prev_euler)
282 |
283 | if not dbone.skip_position:
284 | file_str += "%.6f %.6f %.6f " % (loc * global_scale)[:]
285 |
286 | file_str += "%.6f %.6f %.6f " % (degrees(rot[dbone.rot_order[0]]), degrees(rot[dbone.rot_order[1]]), degrees(rot[dbone.rot_order[2]]))
287 |
288 | dbone.prev_euler = rot
289 |
290 | file_str += "\n"
291 |
292 | scene.frame_set(frame_current)
293 |
294 | return file_str
295 |
296 |
297 | class BVH_Node:
298 | __slots__ = (
299 | # Bvh joint name.
300 | 'name',
301 | # BVH_Node type or None for no parent.
302 | 'parent',
303 | # A list of children of this type..
304 | 'children',
305 | # Worldspace rest location for the head of this node.
306 | 'rest_head_world',
307 | # Localspace rest location for the head of this node.
308 | 'rest_head_local',
309 | # Worldspace rest location for the tail of this node.
310 | 'rest_tail_world',
311 | # Worldspace rest location for the tail of this node.
312 | 'rest_tail_local',
313 | # List of 6 ints, -1 for an unused channel,
314 | # otherwise an index for the BVH motion data lines,
315 | # loc triple then rot triple.
316 | 'channels',
317 | # A triple of indices as to the order rotation is applied.
318 | # [0,1,2] is x/y/z - [None, None, None] if no rotation..
319 | 'rot_order',
320 | # Same as above but a string 'XYZ' format..
321 | 'rot_order_str',
322 | # A list one tuple's one for each frame: (locx, locy, locz, rotx, roty, rotz),
323 | # euler rotation ALWAYS stored xyz order, even when native used.
324 | 'anim_data',
325 | # Convenience function, bool, same as: (channels[0] != -1 or channels[1] != -1 or channels[2] != -1).
326 | 'has_loc',
327 | # Convenience function, bool, same as: (channels[3] != -1 or channels[4] != -1 or channels[5] != -1).
328 | 'has_rot',
329 | # Index from the file, not strictly needed but nice to maintain order.
330 | 'index',
331 | # Use this for whatever you want.
332 | 'temp',
333 | )
334 |
335 | _eul_order_lookup = {
336 | (None, None, None): 'XYZ', # XXX Dummy one, no rotation anyway!
337 | (0, 1, 2): 'XYZ',
338 | (0, 2, 1): 'XZY',
339 | (1, 0, 2): 'YXZ',
340 | (1, 2, 0): 'YZX',
341 | (2, 0, 1): 'ZXY',
342 | (2, 1, 0): 'ZYX',
343 | }
344 |
345 | def __init__(self, name, rest_head_world, rest_head_local, parent, channels, rot_order, index):
346 | self.name = name
347 | self.rest_head_world = rest_head_world
348 | self.rest_head_local = rest_head_local
349 | self.rest_tail_world = None
350 | self.rest_tail_local = None
351 | self.parent = parent
352 | self.channels = channels
353 | self.rot_order = tuple(rot_order)
354 | self.rot_order_str = BVH_Node._eul_order_lookup[self.rot_order]
355 | self.index = index
356 |
357 | # convenience functions
358 | self.has_loc = channels[0] != -1 or channels[1] != -1 or channels[2] != -1
359 | self.has_rot = channels[3] != -1 or channels[4] != -1 or channels[5] != -1
360 |
361 | self.children = []
362 |
363 | # List of 6 length tuples: (lx, ly, lz, rx, ry, rz)
364 | # even if the channels aren't used they will just be zero.
365 | self.anim_data = [(0, 0, 0, 0, 0, 0)]
366 |
367 | def __repr__(self):
368 | return (
369 | "BVH name: '%s', rest_loc:(%.3f,%.3f,%.3f), rest_tail:(%.3f,%.3f,%.3f)" % (
370 | self.name,
371 | *self.rest_head_world,
372 | *self.rest_head_world,
373 | )
374 | )
375 |
376 |
377 | def sorted_nodes(bvh_nodes):
378 | bvh_nodes_list = list(bvh_nodes.values())
379 | bvh_nodes_list.sort(key=lambda bvh_node: bvh_node.index)
380 | return bvh_nodes_list
381 |
382 |
383 | def read_bvh(context, bvh_str, rotate_mode='XYZ', global_scale=1.0):
384 | # Separate into a list of lists, each line a list of words.
385 | file_lines = bvh_str
386 | # Non standard carriage returns?
387 | if len(file_lines) == 1:
388 | file_lines = file_lines[0].split('\r')
389 |
390 | # Split by whitespace.
391 | file_lines = [ll for ll in [l.split() for l in file_lines] if ll]
392 |
393 | # Create hierarchy as empties
394 | if file_lines[0][0].lower() == 'hierarchy':
395 | # print 'Importing the BVH Hierarchy for:', file_path
396 | pass
397 | else:
398 | raise Exception("This is not a BVH file")
399 |
400 | bvh_nodes = {None: None}
401 | bvh_nodes_serial = [None]
402 | bvh_frame_count = None
403 | bvh_frame_time = None
404 |
405 | channelIndex = -1
406 |
407 | lineIdx = 0 # An index for the file.
408 | while lineIdx < len(file_lines) - 1:
409 | if file_lines[lineIdx][0].lower() in {'root', 'joint'}:
410 |
411 | # Join spaces into 1 word with underscores joining it.
412 | if len(file_lines[lineIdx]) > 2:
413 | file_lines[lineIdx][1] = '_'.join(file_lines[lineIdx][1:])
414 | file_lines[lineIdx] = file_lines[lineIdx][:2]
415 |
416 |             # MAY NEED TO SUPPORT MULTIPLE ROOTS HERE! Still unsure whether multiple roots are possible?
417 |
418 | # Make sure the names are unique - Object names will match joint names exactly and both will be unique.
419 | name = file_lines[lineIdx][1]
420 |
421 | # print '%snode: %s, parent: %s' % (len(bvh_nodes_serial) * ' ', name, bvh_nodes_serial[-1])
422 |
423 | lineIdx += 2 # Increment to the next line (Offset)
424 | rest_head_local = global_scale * Vector((
425 | float(file_lines[lineIdx][1]),
426 | float(file_lines[lineIdx][2]),
427 | float(file_lines[lineIdx][3]),
428 | ))
429 | lineIdx += 1 # Increment to the next line (Channels)
430 |
431 | # newChannel[Xposition, Yposition, Zposition, Xrotation, Yrotation, Zrotation]
432 |             # newChannel references indices into the motion data;
433 |             # unassigned channels stay -1, which points at a zero value that is
434 |             # appended to the end of the motion data on loading, so every index resolves.
435 | my_channel = [-1, -1, -1, -1, -1, -1]
436 | my_rot_order = [None, None, None]
437 | rot_count = 0
438 | for channel in file_lines[lineIdx][2:]:
439 | channel = channel.lower()
440 | channelIndex += 1 # So the index points to the right channel
441 | if channel == 'xposition':
442 | my_channel[0] = channelIndex
443 | elif channel == 'yposition':
444 | my_channel[1] = channelIndex
445 | elif channel == 'zposition':
446 | my_channel[2] = channelIndex
447 |
448 | elif channel == 'xrotation':
449 | my_channel[3] = channelIndex
450 | my_rot_order[rot_count] = 0
451 | rot_count += 1
452 | elif channel == 'yrotation':
453 | my_channel[4] = channelIndex
454 | my_rot_order[rot_count] = 1
455 | rot_count += 1
456 | elif channel == 'zrotation':
457 | my_channel[5] = channelIndex
458 | my_rot_order[rot_count] = 2
459 | rot_count += 1
460 |
461 | channels = file_lines[lineIdx][2:]
462 |
463 | my_parent = bvh_nodes_serial[-1] # account for none
464 |
465 | # Apply the parents offset accumulatively
466 | if my_parent is None:
467 | rest_head_world = Vector(rest_head_local)
468 | else:
469 | rest_head_world = my_parent.rest_head_world + rest_head_local
470 |
471 | bvh_node = bvh_nodes[name] = BVH_Node(
472 | name,
473 | rest_head_world,
474 | rest_head_local,
475 | my_parent,
476 | my_channel,
477 | my_rot_order,
478 | len(bvh_nodes) - 1,
479 | )
480 |
481 | # If we have another child then we can call ourselves a parent, else
482 | bvh_nodes_serial.append(bvh_node)
483 |
484 | # Account for an end node.
485 | # There is sometimes a name after 'End Site' but we will ignore it.
486 | if file_lines[lineIdx][0].lower() == 'end' and file_lines[lineIdx][1].lower() == 'site':
487 | # Increment to the next line (Offset)
488 | lineIdx += 2
489 | rest_tail = global_scale * Vector((
490 | float(file_lines[lineIdx][1]),
491 | float(file_lines[lineIdx][2]),
492 | float(file_lines[lineIdx][3]),
493 | ))
494 |
495 | bvh_nodes_serial[-1].rest_tail_world = bvh_nodes_serial[-1].rest_head_world + rest_tail
496 | bvh_nodes_serial[-1].rest_tail_local = bvh_nodes_serial[-1].rest_head_local + rest_tail
497 |
498 | # Just so we can remove the parents in a uniform way,
499 | # the end has kids so this is a placeholder.
500 | bvh_nodes_serial.append(None)
501 |
502 | if len(file_lines[lineIdx]) == 1 and file_lines[lineIdx][0] == '}': # == ['}']
503 | bvh_nodes_serial.pop() # Remove the last item
504 |
505 | # End of the hierarchy. Begin the animation section of the file with
506 | # the following header.
507 | # MOTION
508 | # Frames: n
509 | # Frame Time: dt
510 | if len(file_lines[lineIdx]) == 1 and file_lines[lineIdx][0].lower() == 'motion':
511 | lineIdx += 1 # Read frame count.
512 | if (
513 | len(file_lines[lineIdx]) == 2 and
514 | file_lines[lineIdx][0].lower() == 'frames:'
515 | ):
516 | bvh_frame_count = int(file_lines[lineIdx][1])
517 |
518 | lineIdx += 1 # Read frame rate.
519 | if (
520 | len(file_lines[lineIdx]) == 3 and
521 | file_lines[lineIdx][0].lower() == 'frame' and
522 | file_lines[lineIdx][1].lower() == 'time:'
523 | ):
524 | bvh_frame_time = float(file_lines[lineIdx][2])
525 |
526 | lineIdx += 1 # Set the cursor to the first frame
527 |
528 | break
529 |
530 | lineIdx += 1
531 |
532 | # Remove the None value used for easy parent reference
533 | del bvh_nodes[None]
534 | # Don't use anymore
535 | del bvh_nodes_serial
536 |
537 |     # Importing would work with any order, but it's nicer to maintain the order
538 |     # Second Life expects, which isn't to spec.
539 | bvh_nodes_list = sorted_nodes(bvh_nodes)
540 |
541 | while lineIdx < len(file_lines):
542 | line = file_lines[lineIdx]
543 | for bvh_node in bvh_nodes_list:
544 | # for bvh_node in bvh_nodes_serial:
545 | lx = ly = lz = rx = ry = rz = 0.0
546 | channels = bvh_node.channels
547 | anim_data = bvh_node.anim_data
548 | if channels[0] != -1:
549 | lx = global_scale * float(line[channels[0]])
550 |
551 | if channels[1] != -1:
552 | ly = global_scale * float(line[channels[1]])
553 |
554 | if channels[2] != -1:
555 | lz = global_scale * float(line[channels[2]])
556 |
557 | if channels[3] != -1 or channels[4] != -1 or channels[5] != -1:
558 |
559 | rx = radians(float(line[channels[3]]))
560 | ry = radians(float(line[channels[4]]))
561 | rz = radians(float(line[channels[5]]))
562 |
563 | # Done importing motion data #
564 | anim_data.append((lx, ly, lz, rx, ry, rz))
565 | lineIdx += 1
566 |
567 | # Assign children
568 | for bvh_node in bvh_nodes_list:
569 | bvh_node_parent = bvh_node.parent
570 | if bvh_node_parent:
571 | bvh_node_parent.children.append(bvh_node)
572 |
573 | # Now set the tip of each bvh_node
574 | for bvh_node in bvh_nodes_list:
575 |
576 | if not bvh_node.rest_tail_world:
577 | if len(bvh_node.children) == 0:
578 | # could just fail here, but rare BVH files have childless nodes
579 | bvh_node.rest_tail_world = Vector(bvh_node.rest_head_world)
580 | bvh_node.rest_tail_local = Vector(bvh_node.rest_head_local)
581 | elif len(bvh_node.children) == 1:
582 | bvh_node.rest_tail_world = Vector(bvh_node.children[0].rest_head_world)
583 | bvh_node.rest_tail_local = bvh_node.rest_head_local + bvh_node.children[0].rest_head_local
584 | else:
585 | # allow this, see above
586 | # if not bvh_node.children:
587 | # raise Exception("bvh node has no end and no children. bad file")
588 |
589 | # Removed temp for now
590 | rest_tail_world = Vector((0.0, 0.0, 0.0))
591 | rest_tail_local = Vector((0.0, 0.0, 0.0))
592 | for bvh_node_child in bvh_node.children:
593 | rest_tail_world += bvh_node_child.rest_head_world
594 | rest_tail_local += bvh_node_child.rest_head_local
595 |
596 | bvh_node.rest_tail_world = rest_tail_world * (1.0 / len(bvh_node.children))
597 | bvh_node.rest_tail_local = rest_tail_local * (1.0 / len(bvh_node.children))
598 |
599 | # Make sure tail isn't the same location as the head.
600 | if (bvh_node.rest_tail_local - bvh_node.rest_head_local).length <= 0.001 * global_scale:
601 | print("\tzero length node found:", bvh_node.name)
602 | bvh_node.rest_tail_local.y = bvh_node.rest_tail_local.y + global_scale / 10
603 | bvh_node.rest_tail_world.y = bvh_node.rest_tail_world.y + global_scale / 10
604 |
605 | return bvh_nodes, bvh_frame_time, bvh_frame_count
606 |
607 |
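For orientation, `read_bvh` above takes the BVH text split into lines; a minimal single-joint file such as the following parses into one `BVH_Node` with two frames (a made-up example, not from the repository):

```python
# A minimal BVH document that read_bvh() can parse; the values are arbitrary.
minimal_bvh = """HIERARCHY
ROOT Hips
{
    OFFSET 0.0 0.0 0.0
    CHANNELS 6 Xposition Yposition Zposition Zrotation Xrotation Yrotation
    End Site
    {
        OFFSET 0.0 1.0 0.0
    }
}
MOTION
Frames: 2
Frame Time: 0.033333
0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.1 0.0 10.0 0.0 0.0
"""

# read_bvh expects a list of lines and splits each line into words itself:
# bvh_nodes, frame_time, frame_count = read_bvh(context, minimal_bvh.splitlines())
```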
608 | def bvh_node_dict2objects(context, bvh_name, bvh_nodes, rotate_mode='NATIVE', frame_start=1, IMPORT_LOOP=False):
609 |
610 | if frame_start < 1:
611 | frame_start = 1
612 |
613 | scene = context.scene
614 | for obj in scene.objects:
615 | obj.select_set(False)
616 |
617 | objects = []
618 |
619 | def add_ob(name):
620 | obj = bpy.data.objects.new(name, None)
621 | context.collection.objects.link(obj)
622 | objects.append(obj)
623 | obj.select_set(True)
624 |
625 | # nicer drawing.
626 | obj.empty_display_type = 'CUBE'
627 | obj.empty_display_size = 0.1
628 |
629 | return obj
630 |
631 | # Add objects
632 | for name, bvh_node in bvh_nodes.items():
633 | bvh_node.temp = add_ob(name)
634 | bvh_node.temp.rotation_mode = bvh_node.rot_order_str[::-1]
635 |
636 | # Parent the objects
637 | for bvh_node in bvh_nodes.values():
638 | for bvh_node_child in bvh_node.children:
639 | bvh_node_child.temp.parent = bvh_node.temp
640 |
641 | # Offset
642 | for bvh_node in bvh_nodes.values():
643 | # Make relative to parents offset
644 | bvh_node.temp.location = bvh_node.rest_head_local
645 |
646 | # Add tail objects
647 | for name, bvh_node in bvh_nodes.items():
648 | if not bvh_node.children:
649 | ob_end = add_ob(name + '_end')
650 | ob_end.parent = bvh_node.temp
651 | ob_end.location = bvh_node.rest_tail_world - bvh_node.rest_head_world
652 |
653 | for name, bvh_node in bvh_nodes.items():
654 | obj = bvh_node.temp
655 |
656 | for frame_current in range(len(bvh_node.anim_data)):
657 |
658 | lx, ly, lz, rx, ry, rz = bvh_node.anim_data[frame_current]
659 |
660 | if bvh_node.has_loc:
661 | obj.delta_location = Vector((lx, ly, lz)) - bvh_node.rest_head_world
662 | obj.keyframe_insert("delta_location", index=-1, frame=frame_start + frame_current)
663 |
664 | if bvh_node.has_rot:
665 | obj.delta_rotation_euler = rx, ry, rz
666 | obj.keyframe_insert("delta_rotation_euler", index=-1, frame=frame_start + frame_current)
667 |
668 | return objects
669 |
670 |
671 | def bvh_node_dict2armature(
672 | context,
673 | bvh_name,
674 | bvh_nodes,
675 | bvh_frame_time,
676 | rotate_mode='XYZ',
677 | frame_start=1,
678 | IMPORT_LOOP=False,
679 | global_matrix=None,
680 | use_fps_scale=False,
681 | ):
682 |
683 | if frame_start < 1:
684 | frame_start = 1
685 |
686 | # Add the new armature,
687 | scene = context.scene
688 | for obj in scene.objects:
689 | obj.select_set(False)
690 |
691 | arm_data = bpy.data.armatures.new(bvh_name)
692 | arm_ob = bpy.data.objects.new(bvh_name, arm_data)
693 |
694 | context.collection.objects.link(arm_ob)
695 |
696 | arm_ob.select_set(True)
697 | context.view_layer.objects.active = arm_ob
698 |
699 | bpy.ops.object.mode_set(mode='OBJECT', toggle=False)
700 | bpy.ops.object.mode_set(mode='EDIT', toggle=False)
701 |
702 | bvh_nodes_list = sorted_nodes(bvh_nodes)
703 |
704 | # Get the average bone length for zero length bones, we may not use this.
705 | average_bone_length = 0.0
706 | nonzero_count = 0
707 | for bvh_node in bvh_nodes_list:
708 | l = (bvh_node.rest_head_local - bvh_node.rest_tail_local).length
709 | if l:
710 | average_bone_length += l
711 | nonzero_count += 1
712 |
713 | # Very rare cases all bones could be zero length???
714 | if not average_bone_length:
715 | average_bone_length = 0.1
716 | else:
717 | # Normal operation
718 | average_bone_length = average_bone_length / nonzero_count
719 |
720 | # XXX, annoying, remove bone.
721 | while arm_data.edit_bones:
722 |         arm_data.edit_bones.remove(arm_data.edit_bones[-1])
723 |
724 | ZERO_AREA_BONES = []
725 | for bvh_node in bvh_nodes_list:
726 |
727 | # New editbone
728 | bone = bvh_node.temp = arm_data.edit_bones.new(bvh_node.name)
729 |
730 | bone.head = bvh_node.rest_head_world
731 | bone.tail = bvh_node.rest_tail_world
732 |
733 | # Zero Length Bones! (an exceptional case)
734 | if (bone.head - bone.tail).length < 0.001:
735 | print("\tzero length bone found:", bone.name)
736 | if bvh_node.parent:
737 | ofs = bvh_node.parent.rest_head_local - bvh_node.parent.rest_tail_local
738 | if ofs.length: # is our parent zero length also?? unlikely
739 | bone.tail = bone.tail - ofs
740 | else:
741 | bone.tail.y = bone.tail.y + average_bone_length
742 | else:
743 | bone.tail.y = bone.tail.y + average_bone_length
744 |
745 | ZERO_AREA_BONES.append(bone.name)
746 |
747 | for bvh_node in bvh_nodes_list:
748 | if bvh_node.parent:
749 | # bvh_node.temp is the Editbone
750 |
751 | # Set the bone parent
752 | bvh_node.temp.parent = bvh_node.parent.temp
753 |
754 | # Set the connection state
755 | if(
756 | (not bvh_node.has_loc) and
757 | (bvh_node.parent.temp.name not in ZERO_AREA_BONES) and
758 | (bvh_node.parent.rest_tail_local == bvh_node.rest_head_local)
759 | ):
760 | bvh_node.temp.use_connect = True
761 |
762 | # Replace the editbone with the editbone name,
763 | # to avoid memory errors accessing the editbone outside editmode
764 | for bvh_node in bvh_nodes_list:
765 | bvh_node.temp = bvh_node.temp.name
766 |
767 | # Now Apply the animation to the armature
768 |
769 | # Get armature animation data
770 | bpy.ops.object.mode_set(mode='OBJECT', toggle=False)
771 |
772 | pose = arm_ob.pose
773 | pose_bones = pose.bones
774 |
775 | if rotate_mode == 'NATIVE':
776 | for bvh_node in bvh_nodes_list:
777 | bone_name = bvh_node.temp # may not be the same name as the bvh_node, could have been shortened.
778 | pose_bone = pose_bones[bone_name]
779 | pose_bone.rotation_mode = bvh_node.rot_order_str
780 |
781 | elif rotate_mode != 'QUATERNION':
782 | for pose_bone in pose_bones:
783 | pose_bone.rotation_mode = rotate_mode
784 | else:
785 | # Quats default
786 | pass
787 |
788 | context.view_layer.update()
789 |
790 | arm_ob.animation_data_create()
791 | action = bpy.data.actions.new(name=bvh_name)
792 | arm_ob.animation_data.action = action
793 |
794 | # Replace the bvh_node.temp (currently an editbone)
795 | # With a tuple (pose_bone, armature_bone, bone_rest_matrix, bone_rest_matrix_inv)
796 | num_frame = 0
797 | for bvh_node in bvh_nodes_list:
798 | bone_name = bvh_node.temp # may not be the same name as the bvh_node, could have been shortened.
799 | pose_bone = pose_bones[bone_name]
800 | rest_bone = arm_data.bones[bone_name]
801 | bone_rest_matrix = rest_bone.matrix_local.to_3x3()
802 |
803 | bone_rest_matrix_inv = Matrix(bone_rest_matrix)
804 | bone_rest_matrix_inv.invert()
805 |
806 | bone_rest_matrix_inv.resize_4x4()
807 | bone_rest_matrix.resize_4x4()
808 |         bvh_node.temp = (pose_bone, rest_bone, bone_rest_matrix, bone_rest_matrix_inv)
809 |
810 | if 0 == num_frame:
811 | num_frame = len(bvh_node.anim_data)
812 |
813 | # Choose to skip some frames at the beginning. Frame 0 is the rest pose
814 | # used internally by this importer. Frame 1, by convention, is also often
815 | # the rest pose of the skeleton exported by the motion capture system.
816 | skip_frame = 1
817 | if num_frame > skip_frame:
818 | num_frame = num_frame - skip_frame
819 |
820 | # Create a shared time axis for all animation curves.
821 | time = [float(frame_start)] * num_frame
822 | if use_fps_scale:
823 | dt = scene.render.fps * bvh_frame_time
824 | for frame_i in range(1, num_frame):
825 | time[frame_i] += float(frame_i) * dt
826 | else:
827 | for frame_i in range(1, num_frame):
828 | time[frame_i] += float(frame_i)
829 |
830 |     # print("bvh_frame_time = %f, dt = %f, num_frame = %d"
831 |     #       % (bvh_frame_time, dt, num_frame))
832 |
833 | for i, bvh_node in enumerate(bvh_nodes_list):
834 | pose_bone, bone, bone_rest_matrix, bone_rest_matrix_inv = bvh_node.temp
835 |
836 | if bvh_node.has_loc:
837 | # Not sure if there is a way to query this or access it in the
838 | # PoseBone structure.
839 | data_path = 'pose.bones["%s"].location' % pose_bone.name
840 |
841 | location = [(0.0, 0.0, 0.0)] * num_frame
842 | for frame_i in range(num_frame):
843 | bvh_loc = bvh_node.anim_data[frame_i + skip_frame][:3]
844 |
845 | bone_translate_matrix = Matrix.Translation(
846 | Vector(bvh_loc) - bvh_node.rest_head_local)
847 | location[frame_i] = (bone_rest_matrix_inv @
848 | bone_translate_matrix).to_translation()
849 |
850 | # For each location x, y, z.
851 | for axis_i in range(3):
852 | curve = action.fcurves.new(data_path=data_path, index=axis_i, action_group=bvh_node.name)
853 | keyframe_points = curve.keyframe_points
854 | keyframe_points.add(num_frame)
855 |
856 | for frame_i in range(num_frame):
857 | keyframe_points[frame_i].co = (
858 | time[frame_i],
859 | location[frame_i][axis_i],
860 | )
861 |
862 | if bvh_node.has_rot:
863 | data_path = None
864 | rotate = None
865 |
866 | if 'QUATERNION' == rotate_mode:
867 | rotate = [(1.0, 0.0, 0.0, 0.0)] * num_frame
868 | data_path = ('pose.bones["%s"].rotation_quaternion'
869 | % pose_bone.name)
870 | else:
871 | rotate = [(0.0, 0.0, 0.0)] * num_frame
872 | data_path = ('pose.bones["%s"].rotation_euler' %
873 | pose_bone.name)
874 |
875 | prev_euler = Euler((0.0, 0.0, 0.0))
876 | for frame_i in range(num_frame):
877 | bvh_rot = bvh_node.anim_data[frame_i + skip_frame][3:]
878 |
879 | # apply rotation order and convert to XYZ
880 | # note that the rot_order_str is reversed.
881 | euler = Euler(bvh_rot, bvh_node.rot_order_str[::-1])
882 | bone_rotation_matrix = euler.to_matrix().to_4x4()
883 | bone_rotation_matrix = (
884 | bone_rest_matrix_inv @
885 | bone_rotation_matrix @
886 | bone_rest_matrix
887 | )
888 |
889 | if len(rotate[frame_i]) == 4:
890 | rotate[frame_i] = bone_rotation_matrix.to_quaternion()
891 | else:
892 | rotate[frame_i] = bone_rotation_matrix.to_euler(
893 | pose_bone.rotation_mode, prev_euler)
894 | prev_euler = rotate[frame_i]
895 |
896 | # For each euler angle x, y, z (or quaternion w, x, y, z).
897 | for axis_i in range(len(rotate[0])):
898 | curve = action.fcurves.new(data_path=data_path, index=axis_i, action_group=bvh_node.name)
899 | keyframe_points = curve.keyframe_points
900 | keyframe_points.add(num_frame)
901 |
902 | for frame_i in range(num_frame):
903 | keyframe_points[frame_i].co = (
904 | time[frame_i],
905 | rotate[frame_i][axis_i],
906 | )
907 |
908 | for cu in action.fcurves:
909 | if IMPORT_LOOP:
910 |             pass  # 2.5 doesn't have cyclic now?
911 |
912 | for bez in cu.keyframe_points:
913 | bez.interpolation = 'LINEAR'
914 |
915 | # finally apply matrix
916 |     # Apply the global matrix if one was provided.
917 |     if global_matrix is not None:
918 |         arm_ob.matrix_world = global_matrix
919 |
920 | bpy.ops.object.transform_apply(location=False, rotation=True, scale=False)
921 |
922 | return arm_ob
923 |
924 |
925 | def load(
926 | context,
927 | bvh_str,
928 | *,
929 | target='ARMATURE',
930 | rotate_mode='NATIVE',
931 | global_scale=1.0,
932 | use_cyclic=False,
933 | frame_start=1,
934 | global_matrix=None,
935 | use_fps_scale=False,
936 | update_scene_fps=False,
937 | update_scene_duration=False,
938 | report=print,
939 | ):
940 | import time
941 | t1 = time.time()
942 |
943 | bvh_nodes, bvh_frame_time, bvh_frame_count = read_bvh(
944 | context, bvh_str,
945 | rotate_mode=rotate_mode,
946 | global_scale=global_scale,
947 | )
948 |
949 | print("%.4f" % (time.time() - t1))
950 |
951 | scene = context.scene
952 | frame_orig = scene.frame_current
953 |
954 | # Broken BVH handling: guess frame rate when it is not contained in the file.
955 | if bvh_frame_time is None:
956 | report(
957 | {'WARNING'},
958 | "The BVH file does not contain frame duration in its MOTION "
959 | "section, assuming the BVH and Blender scene have the same "
960 | "frame rate"
961 | )
962 | bvh_frame_time = scene.render.fps_base / scene.render.fps
963 | # No need to scale the frame rate, as they're equal now anyway.
964 | use_fps_scale = False
965 |
966 | if update_scene_fps:
967 | _update_scene_fps(context, report, bvh_frame_time)
968 |
969 | # Now that we have a 1-to-1 mapping of Blender frames and BVH frames, there is no need
970 | # to scale the FPS any more. It's even better not to, to prevent roundoff errors.
971 | use_fps_scale = False
972 |
973 | if update_scene_duration:
974 | _update_scene_duration(context, report, bvh_frame_count, bvh_frame_time, frame_start, use_fps_scale)
975 |
976 | t1 = time.time()
977 | print("\timporting to blender...", end="")
978 |
979 | bvh_name = bpy.path.display_name_from_filepath('synthesized')
980 |
981 | if target == 'ARMATURE':
982 | bvh_node_dict2armature(
983 | context, bvh_name, bvh_nodes, bvh_frame_time,
984 | rotate_mode=rotate_mode,
985 | frame_start=frame_start,
986 | IMPORT_LOOP=use_cyclic,
987 | global_matrix=global_matrix,
988 | use_fps_scale=use_fps_scale,
989 | )
990 |
991 | elif target == 'OBJECT':
992 | bvh_node_dict2objects(
993 | context, bvh_name, bvh_nodes,
994 | rotate_mode=rotate_mode,
995 | frame_start=frame_start,
996 | IMPORT_LOOP=use_cyclic,
997 | # global_matrix=global_matrix, # TODO
998 | )
999 |
1000 | else:
1001 | report({'ERROR'}, tip_("Invalid target %r (must be 'ARMATURE' or 'OBJECT')") % target)
1002 | return {'CANCELLED'}
1003 |
1004 | print('Done in %.4f\n' % (time.time() - t1))
1005 |
1006 | context.scene.frame_set(frame_orig)
1007 |
1008 | return {'FINISHED'}
1009 |
1010 |
1011 | def _update_scene_fps(context, report, bvh_frame_time):
1012 | """Update the scene's FPS settings from the BVH, but only if the BVH contains enough info."""
1013 |
1014 | # Broken BVH handling: prevent division by zero.
1015 | if bvh_frame_time == 0.0:
1016 | report(
1017 | {'WARNING'},
1018 | "Unable to update scene frame rate, as the BVH file "
1019 | "contains a zero frame duration in its MOTION section",
1020 | )
1021 | return
1022 |
1023 | scene = context.scene
1024 | scene_fps = scene.render.fps / scene.render.fps_base
1025 | new_fps = 1.0 / bvh_frame_time
1026 |
1027 | if scene.render.fps != new_fps or scene.render.fps_base != 1.0:
1028 | print("\tupdating scene FPS (was %f) to BVH FPS (%f)" % (scene_fps, new_fps))
1029 | scene.render.fps = int(round(new_fps))
1030 | scene.render.fps_base = scene.render.fps / new_fps
1031 |
1032 |
1033 | def _update_scene_duration(
1034 | context, report, bvh_frame_count, bvh_frame_time, frame_start,
1035 | use_fps_scale):
1036 | """Extend the scene's duration so that the BVH file fits in its entirety."""
1037 |
1038 | if bvh_frame_count is None:
1039 | report(
1040 | {'WARNING'},
1041 | "Unable to extend the scene duration, as the BVH file does not "
1042 | "contain the number of frames in its MOTION section",
1043 | )
1044 | return
1045 |
1046 | # Not likely, but it can happen when a BVH is just used to store an armature.
1047 | if bvh_frame_count == 0:
1048 | return
1049 |
1050 | if use_fps_scale:
1051 | scene_fps = context.scene.render.fps / context.scene.render.fps_base
1052 | scaled_frame_count = int(ceil(bvh_frame_count * bvh_frame_time * scene_fps))
1053 | bvh_last_frame = frame_start + scaled_frame_count
1054 | else:
1055 | bvh_last_frame = frame_start + bvh_frame_count
1056 |
1057 | # Only extend the scene, never shorten it.
1058 | if context.scene.frame_end < bvh_last_frame:
1059 | context.scene.frame_end = bvh_last_frame
1060 |
1061 |
1062 | # This function is from
1063 | # https://github.com/yuki-koyama/blender-cli-rendering
1064 | def set_smooth_shading(mesh: bpy.types.Mesh) -> None:
1065 | for polygon in mesh.polygons:
1066 | polygon.use_smooth = True
1067 |
1068 |
1069 | # This function is from
1070 | # https://github.com/yuki-koyama/blender-cli-rendering
1071 | def create_mesh_from_pydata(scene: bpy.types.Scene,
1072 | vertices: Iterable[Iterable[float]],
1073 | faces: Iterable[Iterable[int]],
1074 | mesh_name: str,
1075 | object_name: str,
1076 | use_smooth: bool = True) -> bpy.types.Object:
1077 | # Add a new mesh and set vertices and faces
1078 | # Note: In this case, it does not require to set edges.
1079 | # Note: After manipulating mesh data, update() needs to be called.
1080 | new_mesh: bpy.types.Mesh = bpy.data.meshes.new(mesh_name)
1081 | new_mesh.from_pydata(vertices, [], faces)
1082 | new_mesh.update()
1083 | if use_smooth:
1084 | set_smooth_shading(new_mesh)
1085 |
1086 | new_object: bpy.types.Object = bpy.data.objects.new(object_name, new_mesh)
1087 | scene.collection.objects.link(new_object)
1088 |
1089 | return new_object
1090 |
1091 |
1092 | # This function is from
1093 | # https://github.com/yuki-koyama/blender-cli-rendering
1094 | def add_subdivision_surface_modifier(mesh_object: bpy.types.Object, level: int, is_simple: bool = False) -> None:
1095 | '''
1096 | https://docs.blender.org/api/current/bpy.types.SubsurfModifier.html
1097 | '''
1098 |
1099 | modifier: bpy.types.SubsurfModifier = mesh_object.modifiers.new(name="Subsurf", type='SUBSURF')
1100 |
1101 | modifier.levels = level
1102 | modifier.render_levels = level
1103 | modifier.subdivision_type = 'SIMPLE' if is_simple else 'CATMULL_CLARK'
1104 |
1105 |
1106 | # This function is from
1107 | # https://github.com/yuki-koyama/blender-cli-rendering
1108 | def create_armature_mesh(scene: bpy.types.Scene, armature_object: bpy.types.Object, mesh_name: str) -> bpy.types.Object:
1109 | assert armature_object.type == 'ARMATURE', 'Error'
1110 | assert len(armature_object.data.bones) != 0, 'Error'
1111 |
1112 | def add_rigid_vertex_group(target_object: bpy.types.Object, name: str, vertex_indices: Iterable[int]) -> None:
1113 | new_vertex_group = target_object.vertex_groups.new(name=name)
1114 | for vertex_index in vertex_indices:
1115 | new_vertex_group.add([vertex_index], 1.0, 'REPLACE')
1116 |
1117 | def generate_bone_mesh_pydata(radius: float, length: float) -> Tuple[List[mathutils.Vector], List[List[int]]]:
1118 | base_radius = radius
1119 | top_radius = 0.5 * radius
1120 |
1121 | vertices = [
1122 | # Cross section of the base part
1123 | mathutils.Vector((-base_radius, 0.0, +base_radius)),
1124 | mathutils.Vector((+base_radius, 0.0, +base_radius)),
1125 | mathutils.Vector((+base_radius, 0.0, -base_radius)),
1126 | mathutils.Vector((-base_radius, 0.0, -base_radius)),
1127 |
1128 | # Cross section of the top part
1129 | mathutils.Vector((-top_radius, length, +top_radius)),
1130 | mathutils.Vector((+top_radius, length, +top_radius)),
1131 | mathutils.Vector((+top_radius, length, -top_radius)),
1132 | mathutils.Vector((-top_radius, length, -top_radius)),
1133 |
1134 | # End points
1135 | mathutils.Vector((0.0, -base_radius, 0.0)),
1136 | mathutils.Vector((0.0, length + top_radius, 0.0))
1137 | ]
1138 |
1139 | faces = [
1140 | # End point for the base part
1141 | [8, 1, 0],
1142 | [8, 2, 1],
1143 | [8, 3, 2],
1144 | [8, 0, 3],
1145 |
1146 | # End point for the top part
1147 | [9, 4, 5],
1148 | [9, 5, 6],
1149 | [9, 6, 7],
1150 | [9, 7, 4],
1151 |
1152 | # Side faces
1153 | [0, 1, 5, 4],
1154 | [1, 2, 6, 5],
1155 | [2, 3, 7, 6],
1156 | [3, 0, 4, 7],
1157 | ]
1158 |
1159 | return vertices, faces
1160 |
1161 | armature_data: bpy.types.Armature = armature_object.data
1162 |
1163 | vertices: List[mathutils.Vector] = []
1164 | faces: List[List[int]] = []
1165 | vertex_groups: List[Dict[str, Any]] = []
1166 |
1167 | for bone in armature_data.bones:
1168 | radius = 0.10 * (0.10 + bone.length)
1169 | temp_vertices, temp_faces = generate_bone_mesh_pydata(radius, bone.length)
1170 |
1171 | vertex_index_offset = len(vertices)
1172 |
1173 | temp_vertex_group = {'name': bone.name, 'vertex_indices': []}
1174 | for local_index, vertex in enumerate(temp_vertices):
1175 | vertices.append(bone.matrix_local @ vertex)
1176 | temp_vertex_group['vertex_indices'].append(local_index + vertex_index_offset)
1177 | vertex_groups.append(temp_vertex_group)
1178 |
1179 | for face in temp_faces:
1180 | if len(face) == 3:
1181 | faces.append([
1182 | face[0] + vertex_index_offset,
1183 | face[1] + vertex_index_offset,
1184 | face[2] + vertex_index_offset,
1185 | ])
1186 | else:
1187 | faces.append([
1188 | face[0] + vertex_index_offset,
1189 | face[1] + vertex_index_offset,
1190 | face[2] + vertex_index_offset,
1191 | face[3] + vertex_index_offset,
1192 | ])
1193 |
1194 | new_object = create_mesh_from_pydata(scene, vertices, faces, mesh_name, mesh_name)
1195 | new_object.matrix_world = armature_object.matrix_world
1196 |
1197 | for vertex_group in vertex_groups:
1198 | add_rigid_vertex_group(new_object, vertex_group['name'], vertex_group['vertex_indices'])
1199 |
1200 | armature_modifier = new_object.modifiers.new('Armature', 'ARMATURE')
1201 | armature_modifier.object = armature_object
1202 | armature_modifier.use_vertex_groups = True
1203 |
1204 | add_subdivision_surface_modifier(new_object, 1, is_simple=True)
1205 | add_subdivision_surface_modifier(new_object, 2, is_simple=False)
1206 |
1207 | # Set the armature as the parent of the new object
1208 | bpy.ops.object.select_all(action='DESELECT')
1209 | new_object.select_set(True)
1210 | armature_object.select_set(True)
1211 | bpy.context.view_layer.objects.active = armature_object
1212 | bpy.ops.object.parent_set(type='OBJECT')
1213 |
1214 | return new_object
1215 |
1216 |
1217 | class OP_AddMesh(bpy.types.Operator):
1218 |
1219 | bl_idname = "genmm.add_mesh"
1220 | bl_label = "Add mesh"
1221 | bl_description = "Add a proxy mesh bound to the active armature"
1222 | bl_options = {"REGISTER", "UNDO"}
1223 |
1224 | def __init__(self) -> None:
1225 | super().__init__()
1226 |
1227 | def execute(self, context: bpy.types.Context):
1228 | name = bpy.context.object.name + "_proxy"
1229 | create_armature_mesh(bpy.context.scene, bpy.context.object, name)
1230 |
1231 | return {'FINISHED'}
1232 |
1233 |
1234 | class OP_RunSynthesis(bpy.types.Operator):
1235 |
1236 | bl_idname = "genmm.run_synthesis"
1237 | bl_label = "Run synthesis"
1238 | bl_description = "Run GenMM motion synthesis on the active armature's action"
1239 | bl_options = {"REGISTER", "UNDO"}
1240 |
1241 | def __init__(self) -> None:
1242 | super().__init__()
1243 |
1244 | def execute(self, context: bpy.types.Context):
1245 | setting = context.scene.setting
1246 |
1247 | anim = context.object.animation_data.action
1248 | start_frame, end_frame = map(int, anim.frame_range)
1249 | start_frame = start_frame if setting.start_frame == -1 else setting.start_frame
1250 | end_frame = end_frame if setting.end_frame == -1 else setting.end_frame
1251 |
1252 | bvh_str = get_bvh_data(context, frame_start=start_frame, frame_end=end_frame)
1253 | frames_str, frame_time_str = bvh_str.split('MOTION\n')[1].split('\n')[:2]
1254 | motion_data_str = bvh_str.split('MOTION\n')[1].split('\n')[2:-1]
1255 | motion_data = np.array([item.strip().split(' ') for item in motion_data_str], dtype=np.float32)
1256 |
1257 | motion = [BlenderMotion(motion_data, repr='repr6d', use_velo=True, keep_up_pos=True, up_axis=setting.up_axis, padding_last=False)]
1258 | model = GenMM(device='cuda' if torch.cuda.is_available() else 'cpu', silent=True)
1259 | criteria = PatchCoherentLoss(patch_size=setting.patch_size,
1260 | alpha=setting.alpha,
1261 | loop=setting.loop, cache=True)
1262 |
1263 | syn = model.run(motion, criteria,
1264 | num_frames=str(setting.num_syn_frames),
1265 | num_steps=setting.num_steps,
1266 | noise_sigma=setting.noise,
1267 | patch_size=setting.patch_size,
1268 | coarse_ratio=f'{setting.coarse_ratio}x_nframes',
1269 | pyr_factor=setting.pyr_factor)
1270 | motion_data_str = [' '.join(str(x) for x in item) for item in motion[0].parse(syn)]
1271 |
1272 | load(context, bvh_str.split('MOTION\n')[0].split('\n')+['MOTION']+[frames_str]+[frame_time_str]+motion_data_str)
1273 | # name = bpy.context.object.name + "_proxy"
1274 | # create_armature_mesh(bpy.context.scene, bpy.context.object, name)
1275 |
1276 | return {'FINISHED'}
1277 |
1278 |
1279 | class GENMM_PT_ControlPanel(bpy.types.Panel):
1280 |
1281 | bl_label = "GenMM"
1282 | bl_space_type = 'VIEW_3D'
1283 | bl_region_type = 'UI'
1284 | bl_category = "GenMM"
1285 |
1286 | @classmethod
1287 | def poll(cls, context: bpy.types.Context):
1288 | return True
1289 |
1290 | def draw_header(self, context: bpy.types.Context):
1291 | layout = self.layout
1292 | layout.label(text="", icon='PLUGIN')
1293 |
1294 | def draw(self, context: bpy.types.Context):
1295 | layout = self.layout
1296 | scene = bpy.context.scene
1297 |
1298 | ops: List[bpy.types.Operator] = [
1299 | OP_AddMesh,
1300 | ]
1301 | for op in ops:
1302 | layout.operator(op.bl_idname, text=op.bl_label)
1303 |
1304 | box = layout.box()
1305 | box.label(text="Exemplar config:")
1306 | exemplar_row = box.row()
1307 | exemplar_row.prop(scene.setting, "start_frame")
1308 | exemplar_row.prop(scene.setting, "end_frame")
1309 | exemplar_row = box.row()
1310 | exemplar_row.prop(scene.setting, "up_axis")
1311 |
1312 | box = layout.box()
1313 | box.label(text="Synthesis config:")
1314 | box.prop(scene.setting, "loop")
1315 | box.prop(scene.setting, "noise")
1316 | box.prop(scene.setting, "num_syn_frames")
1317 | box.prop(scene.setting, "patch_size")
1318 | box.prop(scene.setting, "coarse_ratio")
1319 | box.prop(scene.setting, "pyr_factor")
1320 | box.prop(scene.setting, "alpha")
1321 | box.prop(scene.setting, "num_steps")
1322 |
1323 | ops: List[bpy.types.Operator] = [
1324 | OP_RunSynthesis,
1325 | ]
1326 | for op in ops:
1327 | layout.operator(op.bl_idname, text=op.bl_label)
1328 |
1329 |
1330 | class PropertyGroup(bpy.types.PropertyGroup):
1331 | '''Property container for options and paths of GenMM'''
1332 | start_frame: bpy.props.IntProperty(
1333 | name="Start Frame",
1334 | description="Start Frame of the Exemplar Motion.",
1335 | default=1)
1336 | end_frame: bpy.props.IntProperty(
1337 | name="End Frame",
1338 | description="End Frame of the Exemplar Motion.",
1339 | default=-1)
1340 | up_axis: bpy.props.EnumProperty(
1341 | name="Up Axis",
1342 | default='Z_UP',
1343 | description="Up Axis of the Exemplar Motion",
1344 | items=[('Z_UP', "Z-Up", 'Z Up'),
1345 | ('Y_UP', "Y-Up", 'Y Up'),
1346 | ('X_UP', "X-Up", 'X Up'),
1347 | ]
1348 | )
1349 | noise: bpy.props.FloatProperty(
1350 | name="Noise Intensity",
1351 | description="Intensity of Noise Added to the Synthesized Motion.",
1352 | default=10)
1353 | num_syn_frames: bpy.props.IntProperty(
1354 | name="Num. of Frames",
1355 | description="Number of Frames of the Synthesized Motion.",
1356 | default=600)
1357 | patch_size: bpy.props.IntProperty(
1358 | name="Patch Size",
1359 | description="Size for Patch Extraction.",
1360 | min=7,
1361 | default=15)
1362 | coarse_ratio: bpy.props.FloatProperty(
1363 | name="Coarse Ratio",
1364 | description="Ratio of the Coarsest Pyramid Level.",
1365 | min=0.0,
1366 | default=0.2)
1367 | pyr_factor: bpy.props.FloatProperty(
1368 | name="Pyramid Factor",
1369 | description="Pyramid Downsample Factor.",
1370 | min=0.1,
1371 | default=0.75)
1372 | alpha: bpy.props.FloatProperty(
1373 | name="Completeness Alpha",
1374 | description="Alpha Value for Completeness/Diversity Trade-off.",
1375 | default=0.05)
1376 | loop: bpy.props.BoolProperty(
1377 | name="Endless Loop",
1378 | description="Whether to Use the Loop Constraint.",
1379 | default=False)
1380 | num_steps: bpy.props.IntProperty(
1381 | name="Num of Steps",
1382 | description="Number of Optimization Steps.",
1383 | default=5)
1384 |
1385 |
1386 | classes = [
1387 | OP_AddMesh,
1388 | OP_RunSynthesis,
1389 | GENMM_PT_ControlPanel,
1390 | ]
1391 |
1392 |
1393 | def register():
1394 | bpy.utils.register_class(PropertyGroup)
1395 | bpy.types.Scene.setting = bpy.props.PointerProperty(type=PropertyGroup)
1396 |
1397 | for cls in classes:
1398 | bpy.utils.register_class(cls)
1399 |
1400 |
1401 | def unregister():
1402 | bpy.utils.unregister_class(PropertyGroup)
1403 |
1404 | for cls in classes:
1405 | bpy.utils.unregister_class(cls)
1406 |
1407 |
1408 | if __name__ == "__main__":
1409 | register()
1410 |
--------------------------------------------------------------------------------
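A minimal sketch of driving the panel's functionality from Blender's Python console, assuming the add-on module above has been registered (via register()) and that the active object is an armature with an action on it; the property values shown are arbitrary:

import bpy

# Assumes the GenMM add-on above is registered and an animated armature is active.
setting = bpy.context.scene.setting
setting.patch_size = 15        # temporal patch size used by PatchCoherentLoss
setting.num_syn_frames = 600   # length of the synthesized clip
setting.loop = False           # no endless-loop constraint

bpy.ops.genmm.add_mesh()       # optional: build the proxy mesh for the armature
bpy.ops.genmm.run_synthesis()  # same as pressing "Run synthesis" in the panel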
/configs/default.yaml:
--------------------------------------------------------------------------------
1 | # motion data config
2 | repr: 'repr6d'
3 | skeleton_name: null
4 | use_velo: true
5 | keep_up_pos: true
6 | up_axis: 'Y_UP'
7 | padding_last: false
8 | requires_contact: false
9 | joint_reduction: false
10 | skeleton_aware: false
11 | joints_group: null
12 |
13 | # generate parameters
14 | num_frames: '2x_nframes'
15 | alpha: 0.01
16 | num_steps: 3
17 | noise_sigma: 10.0
18 | coarse_ratio: '5x_patchsize'
19 | # coarse_ratio: '0.2x_nframes'
20 | pyr_factor: 0.75
21 | num_stages_limit: -1
22 | patch_size: 11
23 | loop: false
--------------------------------------------------------------------------------
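A sketch of how these keys line up with the calls made in the Blender operator above, assuming a PyYAML loader and the import paths suggested by the repository layout (the actual entry point, run_random_generation.py, is not shown in this dump):

import yaml
import torch
from GenMM import GenMM
from nearest_neighbor.losses import PatchCoherentLoss  # path assumed from the repo layout

with open('configs/default.yaml') as f:
    cfg = yaml.safe_load(f)

criteria = PatchCoherentLoss(patch_size=cfg['patch_size'], alpha=cfg['alpha'],
                             loop=cfg['loop'], cache=True)
model = GenMM(device='cuda' if torch.cuda.is_available() else 'cpu')
# `motions` stands in for a list of exemplar motions (e.g. BlenderMotion objects):
# syn = model.run(motions, criteria,
#                 num_frames=cfg['num_frames'], num_steps=cfg['num_steps'],
#                 noise_sigma=cfg['noise_sigma'], patch_size=cfg['patch_size'],
#                 coarse_ratio=cfg['coarse_ratio'], pyr_factor=cfg['pyr_factor'])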
/configs/ganimator.yaml:
--------------------------------------------------------------------------------
1 | ################################################################
2 | # This configuration uses the same input format as GANimator for generation
3 | ################################################################
4 | output_dir: './output/ganimator_format'
5 |
6 | # for GANimator BVH data
7 | repr: 'repr6d'
8 | skeleton_name: 'mixamo'
9 | use_velo: true
10 | keep_up_pos: true
11 | up_axis: 'Y_UP'
12 | padding_last: true
13 | requires_contact: true
14 | joint_reduction: true
15 | skeleton_aware: false
16 | joints_group: null
17 |
18 | # generate parameters
19 | num_frames: '2x_nframes'
20 | alpha: 0.01
21 | num_steps: 3
22 | noise_sigma: 10.0
23 | coarse_ratio: '3x_patchsize'
24 | # coarse_ratio: '0.1x_nframes'
25 | pyr_factor: 0.75
26 | num_stages_limit: -1
27 | patch_size: 11
28 | loop: false
--------------------------------------------------------------------------------
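Both configs express `coarse_ratio` (and `num_frames`) as `<factor>x_<reference>` strings, where the reference is either `patchsize` or `nframes`. One plausible reading of that grammar, as a sketch (the repository's actual parser is not included in this dump):

def resolve_ratio(spec, n_frames, patch_size):
    # '3x_patchsize' -> 3 * patch_size / n_frames; '0.1x_nframes' -> 0.1.
    factor, ref = spec.split('x_')
    base = patch_size if ref == 'patchsize' else n_frames
    return float(factor) * base / n_frames

# resolve_ratio('3x_patchsize', n_frames=300, patch_size=11) == 0.11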
/dataset/blender_motion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import torch
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from .motion import MotionData
7 | from utils.transforms import quat2repr6d, euler2mat, mat2quat, repr6d2quat, quat2euler
8 |
9 | class BlenderMotion:
10 | def __init__(self, motion_data, repr='quat', use_velo=True, keep_up_pos=True, up_axis=None, padding_last=False):
11 | '''
12 | BlenderMotion constructor
13 | Args:
14 | motion_data : np.array, bvh format data to load from
15 | repr : string, rotation representation, support ['quat', 'repr6d', 'euler']
16 | use_velo : bool, whether to transform the joints positions to velocities
17 | keep_up_pos : bool, whether to keep y position when converting to velocity
18 | up_axis : string, up axis of the motion data
19 | padding_last : bool, whether to pad the last position
21 | '''
22 | self.motion_data = motion_data
23 |
24 | def to_tensor(motion_data, repr='euler', rot_only=False):
25 | if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:
26 | raise Exception('Unknown rotation representation')
27 | if repr == 'quaternion' or repr == 'quat' or repr == 'repr6d': # default is euler for blender data
28 | rotations = torch.tensor(motion_data[:, 3:], dtype=torch.float).view(motion_data.shape[0], -1, 3)
29 | if repr == 'quat':
30 | rotations = euler2mat(rotations)
31 | rotations = mat2quat(rotations)
32 | if repr == 'repr6d':
33 | rotations = euler2mat(rotations)
34 | rotations = mat2quat(rotations)
35 | rotations = quat2repr6d(rotations)
36 |
37 | positions = torch.tensor(motion_data[:, :3], dtype=torch.float32)
38 |
39 | if rot_only:
40 | return rotations.reshape(rotations.shape[0], -1)
41 |
42 | rotations = rotations.reshape(rotations.shape[0], -1)
43 | return torch.cat((rotations, positions), dim=-1)
44 |
45 | self.motion_data = MotionData(to_tensor(motion_data, repr=repr).permute(1, 0).unsqueeze(0), repr=repr, use_velo=use_velo,
46 | keep_up_pos=keep_up_pos, up_axis=up_axis, padding_last=padding_last, contact_id=None)
47 | @property
48 | def repr(self):
49 | return self.motion_data.repr
50 |
51 | @property
52 | def use_velo(self):
53 | return self.motion_data.use_velo
54 |
55 | @property
56 | def keep_up_pos(self):
57 | return self.motion_data.keep_up_pos
58 |
59 | @property
60 | def padding_last(self):
61 | return self.motion_data.padding_last
62 |
63 | @property
64 | def contact_id(self):
65 | return self.motion_data.contact_id
66 |
67 | @property
68 | def n_pad(self):
69 | return self.motion_data.n_pad
70 |
71 | @property
72 | def n_contact(self):
73 | return self.motion_data.n_contact
74 |
75 | @property
76 | def n_rot(self):
77 | return self.motion_data.n_rot
78 |
79 | def sample(self, size=None, slerp=False):
80 | '''
81 | Sample motion data, support slerp
82 | '''
83 | return self.motion_data.sample(size, slerp)
84 |
85 | def parse(self, motion, keep_velo=False):
86 | """
87 | No batch support here!
88 | :returns: np.ndarray of shape (frames, 3 + 3*joints): root position followed by per-joint euler angles
89 | """
90 | motion = motion.clone()
91 |
92 | if self.use_velo and not keep_velo:
93 | motion = self.motion_data.to_position(motion)
94 | if self.n_pad:
95 | motion = motion[:, :-self.n_pad]
96 |
97 | motion = motion.squeeze().permute(1, 0)
98 | pos = motion[..., -3:]
99 | rot = motion[..., :-3].reshape(motion.shape[0], -1, self.n_rot)
100 | if self.repr == 'quat':
101 | rot = quat2euler(rot)
102 | elif self.repr == 'repr6d':
103 | rot = repr6d2quat(rot)
104 | rot = quat2euler(rot)
105 |
106 | return torch.cat([pos, rot.view(motion.shape[0], -1)], dim=-1).cpu().numpy()
107 |
--------------------------------------------------------------------------------
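A round-trip sketch for BlenderMotion, assuming `raw` is the (frames, 3 + 3*joints) array scraped from a BVH MOTION block as in the operator earlier, and that `MotionData.sample` returns the usual (1, channels, frames) tensor:

import numpy as np
from dataset.blender_motion import BlenderMotion

raw = np.zeros((120, 3 + 3 * 24), dtype=np.float32)  # placeholder: 120 frames, 24 joints
motion = BlenderMotion(raw, repr='repr6d', use_velo=True,
                       keep_up_pos=True, up_axis='Z_UP', padding_last=False)

sampled = motion.sample(size=240)   # temporally resampled (1, C, 240) tensor
recovered = motion.parse(sampled)   # back to (240, 3 + 3*24) position + euler layout
print(recovered.shape)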
/dataset/bvh/Quaternions.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is modified from:
3 | http://theorangeduck.com/page/deep-learning-framework-character-motion-synthesis-and-editing
4 |
5 | by Daniel Holden et al
6 | """
7 |
8 |
9 | import numpy as np
10 |
11 | class Quaternions:
12 | """
13 | Quaternions is a wrapper around a numpy ndarray
14 | that allows it to act as if it were an ndarray of
15 | a quaternion data type.
16 |
17 | Therefore addition, subtraction, multiplication,
18 | division, negation, absolute, are all defined
19 | in terms of quaternion operations such as quaternion
20 | multiplication.
21 |
22 | This allows for much neater code and many routines
23 | which conceptually do the same thing to be written
24 | in the same way for point data and for rotation data.
25 |
26 | The Quaternions class has been designed such that it
27 | should support broadcasting and slicing in all of the
28 | usual ways.
29 | """
30 |
31 | def __init__(self, qs):
32 | if isinstance(qs, np.ndarray):
33 | if len(qs.shape) == 1: qs = np.array([qs])
34 | self.qs = qs
35 | return
36 |
37 | if isinstance(qs, Quaternions):
38 | self.qs = qs.qs
39 | return
40 |
41 | raise TypeError('Quaternions must be constructed from iterable, numpy array, or Quaternions, not %s' % type(qs))
42 |
43 | def __str__(self): return "Quaternions("+ str(self.qs) + ")"
44 | def __repr__(self): return "Quaternions("+ repr(self.qs) + ")"
45 |
46 | """ Helper Methods for Broadcasting and Data extraction """
47 |
48 | @classmethod
49 | def _broadcast(cls, sqs, oqs, scalar=False):
50 | if isinstance(oqs, float): return sqs, oqs * np.ones(sqs.shape[:-1])
51 |
52 | ss = np.array(sqs.shape) if not scalar else np.array(sqs.shape[:-1])
53 | os = np.array(oqs.shape)
54 |
55 | if len(ss) != len(os):
56 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))
57 |
58 | if np.all(ss == os): return sqs, oqs
59 |
60 | if not np.all((ss == os) | (os == np.ones(len(os))) | (ss == np.ones(len(ss)))):
61 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))
62 |
63 | sqsn, oqsn = sqs.copy(), oqs.copy()
64 |
65 | for a in np.where(ss == 1)[0]: sqsn = sqsn.repeat(os[a], axis=a)
66 | for a in np.where(os == 1)[0]: oqsn = oqsn.repeat(ss[a], axis=a)
67 |
68 | return sqsn, oqsn
69 |
70 | """ Adding Quaternions is just Defined as Multiplication """
71 |
72 | def __add__(self, other): return self * other
73 | def __sub__(self, other): return self / other
74 |
75 | """ Quaternion Multiplication """
76 |
77 | def __mul__(self, other):
78 | """
79 | Quaternion multiplication has three main methods.
80 |
81 | When multiplying a Quaternions array by Quaternions
82 | normal quaternion multiplication is performed.
83 |
84 | When multiplying a Quaternions array by a vector
85 | array of the same shape, where the last axis is 3,
86 | it is assumed to be a Quaternion by 3D-Vector
87 | multiplication and the 3D-Vectors are rotated
88 | in space by the Quaternions.
89 |
90 | When multiplying a Quaternions array by a scalar
91 | or vector of different shape it is assumed to be
92 | a Quaternions by Scalars multiplication and the
93 | Quaternions are scaled using Slerp and the identity
94 | quaternions.
95 | """
96 |
97 | """ If Quaternions type do Quaternions * Quaternions """
98 | if isinstance(other, Quaternions):
99 | sqs, oqs = Quaternions._broadcast(self.qs, other.qs)
100 |
101 | q0 = sqs[...,0]; q1 = sqs[...,1];
102 | q2 = sqs[...,2]; q3 = sqs[...,3];
103 | r0 = oqs[...,0]; r1 = oqs[...,1];
104 | r2 = oqs[...,2]; r3 = oqs[...,3];
105 |
106 | qs = np.empty(sqs.shape)
107 | qs[...,0] = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3
108 | qs[...,1] = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2
109 | qs[...,2] = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1
110 | qs[...,3] = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0
111 |
112 | return Quaternions(qs)
113 |
114 | """ If array type do Quaternions * Vectors """
115 | if isinstance(other, np.ndarray) and other.shape[-1] == 3:
116 | vs = Quaternions(np.concatenate([np.zeros(other.shape[:-1] + (1,)), other], axis=-1))
117 |
118 | return (self * (vs * -self)).imaginaries
119 |
120 | """ If float do Quaternions * Scalars """
121 | if isinstance(other, np.ndarray) or isinstance(other, float):
122 | return Quaternions.slerp(Quaternions.id_like(self), self, other)
123 |
124 | raise TypeError('Cannot multiply/add Quaternions with type %s' % str(type(other)))
125 |
126 | def __truediv__(self, other):  # was __div__ (Python 2); needed for '/' in Python 3
127 | """
128 | When a Quaternion type is supplied, division is defined
129 | as multiplication by the inverse of that Quaternion.
130 |
131 | When a scalar or vector is supplied it is defined
132 | as multiplication by one over the supplied value.
133 | Essentially a scaling.
134 | """
135 |
136 | if isinstance(other, Quaternions): return self * (-other)
137 | if isinstance(other, np.ndarray): return self * (1.0 / other)
138 | if isinstance(other, float): return self * (1.0 / other)
139 | raise TypeError('Cannot divide/subtract Quaternions with type %s' + str(type(other)))
140 |
141 | def __eq__(self, other): return self.qs == other.qs
142 | def __ne__(self, other): return self.qs != other.qs
143 |
144 | def __neg__(self):
145 | """ Invert Quaternions """
146 | return Quaternions(self.qs * np.array([[1, -1, -1, -1]]))
147 |
148 | def __abs__(self):
149 | """ Unify Quaternions To Single Pole """
150 | qabs = self.normalized().copy()
151 | top = np.sum(( qabs.qs) * np.array([1,0,0,0]), axis=-1)
152 | bot = np.sum((-qabs.qs) * np.array([1,0,0,0]), axis=-1)
153 | qabs.qs[top < bot] = -qabs.qs[top < bot]
154 | return qabs
155 |
156 | def __iter__(self): return iter(self.qs)
157 | def __len__(self): return len(self.qs)
158 |
159 | def __getitem__(self, k): return Quaternions(self.qs[k])
160 | def __setitem__(self, k, v): self.qs[k] = v.qs
161 |
162 | @property
163 | def lengths(self):
164 | return np.sum(self.qs**2.0, axis=-1)**0.5
165 |
166 | @property
167 | def reals(self):
168 | return self.qs[...,0]
169 |
170 | @property
171 | def imaginaries(self):
172 | return self.qs[...,1:4]
173 |
174 | @property
175 | def shape(self): return self.qs.shape[:-1]
176 |
177 | def repeat(self, n, **kwargs):
178 | return Quaternions(self.qs.repeat(n, **kwargs))
179 |
180 | def normalized(self):
181 | return Quaternions(self.qs / self.lengths[...,np.newaxis])
182 |
183 | def log(self):
184 | norm = abs(self.normalized())
185 | imgs = norm.imaginaries
186 | lens = np.sqrt(np.sum(imgs**2, axis=-1))
187 | lens = np.arctan2(lens, norm.reals) / (lens + 1e-10)
188 | return imgs * lens[...,np.newaxis]
189 |
190 | def constrained(self, axis):
191 |
192 | rl = self.reals
193 | im = np.sum(axis * self.imaginaries, axis=-1)
194 |
195 | t1 = -2 * np.arctan2(rl, im) + np.pi
196 | t2 = -2 * np.arctan2(rl, im) - np.pi
197 |
198 | top = Quaternions.exp(axis[np.newaxis] * (t1[:,np.newaxis] / 2.0))
199 | bot = Quaternions.exp(axis[np.newaxis] * (t2[:,np.newaxis] / 2.0))
200 | img = self.dot(top) > self.dot(bot)
201 |
202 | ret = top.copy()
203 | ret[ img] = top[ img]
204 | ret[~img] = bot[~img]
205 | return ret
206 |
207 | def constrained_x(self): return self.constrained(np.array([1,0,0]))
208 | def constrained_y(self): return self.constrained(np.array([0,1,0]))
209 | def constrained_z(self): return self.constrained(np.array([0,0,1]))
210 |
211 | def dot(self, q): return np.sum(self.qs * q.qs, axis=-1)
212 |
213 | def copy(self): return Quaternions(np.copy(self.qs))
214 |
215 | def reshape(self, s):
216 | self.qs = self.qs.reshape(s)
217 | return self
218 |
219 | def interpolate(self, ws):
220 | return Quaternions.exp(np.average(abs(self).log(), axis=0, weights=ws))
221 |
222 | def euler(self, order='xyz'):
223 |
224 | q = self.normalized().qs
225 | q0 = q[...,0]
226 | q1 = q[...,1]
227 | q2 = q[...,2]
228 | q3 = q[...,3]
229 | es = np.zeros(self.shape + (3,))
230 |
231 | # This version of the conversion is wrong; kept for reference.
232 | '''
233 | if order == 'xyz':
234 | es[...,0] = np.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
235 | es[...,1] = np.arcsin((2 * (q0 * q2 - q3 * q1)).clip(-1,1))
236 | es[...,2] = np.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
237 | elif order == 'yzx':
238 | es[...,0] = np.arctan2(2 * (q1 * q0 - q2 * q3), -q1 * q1 + q2 * q2 - q3 * q3 + q0 * q0)
239 | es[...,1] = np.arctan2(2 * (q2 * q0 - q1 * q3), q1 * q1 - q2 * q2 - q3 * q3 + q0 * q0)
240 | es[...,2] = np.arcsin((2 * (q1 * q2 + q3 * q0)).clip(-1,1))
241 | else:
242 | raise NotImplementedError('Cannot convert from ordering %s' % order)
243 |
244 | '''
245 |
246 | if order == 'xyz':
247 | es[..., 2] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
248 | es[..., 1] = np.arcsin((2 * (q1 * q3 + q0 * q2)).clip(-1,1))
249 | es[..., 0] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
250 | else:
251 | raise NotImplementedError('Cannot convert from ordering %s' % order)
252 |
253 | # These conversions don't appear to work correctly for Maya.
254 | # http://bediyap.com/programming/convert-quaternion-to-euler-rotations/
255 | '''
256 | if order == 'xyz':
257 | es[..., 0] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
258 | es[..., 1] = np.arcsin((2 * (q1 * q3 + q0 * q2)).clip(-1,1))
259 | es[..., 2] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
260 | elif order == 'yzx':
261 | es[fa + (0,)] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
262 | es[fa + (1,)] = np.arcsin((2 * (q1 * q2 + q0 * q3)).clip(-1,1))
263 | es[fa + (2,)] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
264 | elif order == 'zxy':
265 | es[fa + (0,)] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
266 | es[fa + (1,)] = np.arcsin((2 * (q0 * q1 + q2 * q3)).clip(-1,1))
267 | es[fa + (2,)] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
268 | elif order == 'xzy':
269 | es[fa + (0,)] = np.arctan2(2 * (q0 * q2 + q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
270 | es[fa + (1,)] = np.arcsin((2 * (q0 * q3 - q1 * q2)).clip(-1,1))
271 | es[fa + (2,)] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
272 | elif order == 'yxz':
273 | es[fa + (0,)] = np.arctan2(2 * (q1 * q2 + q0 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
274 | es[fa + (1,)] = np.arcsin((2 * (q0 * q1 - q2 * q3)).clip(-1,1))
275 | es[fa + (2,)] = np.arctan2(2 * (q1 * q3 + q0 * q2), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
276 | elif order == 'zyx':
277 | es[fa + (0,)] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
278 | es[fa + (1,)] = np.arcsin((2 * (q0 * q2 - q1 * q3)).clip(-1,1))
279 | es[fa + (2,)] = np.arctan2(2 * (q0 * q3 + q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
280 |
281 | else:
282 | raise KeyError('Unknown ordering %s' % order)
283 | '''
284 |
285 |
286 | # https://github.com/ehsan/ogre/blob/master/OgreMain/src/OgreMatrix3.cpp
287 | # Use this class and convert from matrix
288 |
289 | return es
290 |
291 |
292 | def average(self):
293 |
294 | if len(self.shape) == 1:
295 |
296 | # Sum of outer products q_i q_i^T (avoids the removed numpy.core.umath_tests).
297 | system = np.einsum('ni,nj->ij', self.qs, self.qs)
298 | w, v = np.linalg.eigh(system)
299 | qiT_dot_qref = (self.qs[:,:,np.newaxis] * v[np.newaxis,:,:]).sum(axis=1)
300 | return Quaternions(v[:,np.argmin((1.-qiT_dot_qref**2).sum(axis=0))])
301 |
302 | else:
303 |
304 | raise NotImplementedError('Cannot average multi-dimensionsal Quaternions')
305 |
306 | def angle_axis(self):
307 |
308 | norm = self.normalized()
309 | s = np.sqrt(1 - (norm.reals**2.0))
310 | s[s == 0] = 0.001
311 |
312 | angles = 2.0 * np.arccos(norm.reals)
313 | axis = norm.imaginaries / s[...,np.newaxis]
314 |
315 | return angles, axis
316 |
317 |
318 | def transforms(self):
319 |
320 | qw = self.qs[...,0]
321 | qx = self.qs[...,1]
322 | qy = self.qs[...,2]
323 | qz = self.qs[...,3]
324 |
325 | x2 = qx + qx; y2 = qy + qy; z2 = qz + qz;
326 | xx = qx * x2; yy = qy * y2; wx = qw * x2;
327 | xy = qx * y2; yz = qy * z2; wy = qw * y2;
328 | xz = qx * z2; zz = qz * z2; wz = qw * z2;
329 |
330 | m = np.empty(self.shape + (3,3))
331 | m[...,0,0] = 1.0 - (yy + zz)
332 | m[...,0,1] = xy - wz
333 | m[...,0,2] = xz + wy
334 | m[...,1,0] = xy + wz
335 | m[...,1,1] = 1.0 - (xx + zz)
336 | m[...,1,2] = yz - wx
337 | m[...,2,0] = xz - wy
338 | m[...,2,1] = yz + wx
339 | m[...,2,2] = 1.0 - (xx + yy)
340 |
341 | return m
342 |
343 | def ravel(self):
344 | return self.qs.ravel()
345 |
346 | @classmethod
347 | def id(cls, n):
348 |
349 | if isinstance(n, tuple):
350 | qs = np.zeros(n + (4,))
351 | qs[...,0] = 1.0
352 | return Quaternions(qs)
353 |
354 | if isinstance(n, int):
355 | qs = np.zeros((n,4))
356 | qs[:,0] = 1.0
357 | return Quaternions(qs)
358 |
359 | raise TypeError('Cannot Construct Quaternion from %s type' % str(type(n)))
360 |
361 | @classmethod
362 | def id_like(cls, a):
363 | qs = np.zeros(a.shape + (4,))
364 | qs[...,0] = 1.0
365 | return Quaternions(qs)
366 |
367 | @classmethod
368 | def exp(cls, ws):
369 |
370 | ts = np.sum(ws**2.0, axis=-1)**0.5
371 | ts[ts == 0] = 0.001
372 | ls = np.sin(ts) / ts
373 |
374 | qs = np.empty(ws.shape[:-1] + (4,))
375 | qs[...,0] = np.cos(ts)
376 | qs[...,1] = ws[...,0] * ls
377 | qs[...,2] = ws[...,1] * ls
378 | qs[...,3] = ws[...,2] * ls
379 |
380 | return Quaternions(qs).normalized()
381 |
382 | @classmethod
383 | def slerp(cls, q0s, q1s, a):
384 |
385 | fst, snd = cls._broadcast(q0s.qs, q1s.qs)
386 | fst, a = cls._broadcast(fst, a, scalar=True)
387 | snd, a = cls._broadcast(snd, a, scalar=True)
388 |
389 | dots = np.sum(fst * snd, axis=-1)
390 |
391 | neg = dots < 0.0
392 | dots[neg] = -dots[neg]
393 | snd[neg] = -snd[neg]
394 |
395 | amount0 = np.zeros(a.shape)
396 | amount1 = np.zeros(a.shape)
397 |
398 | linear = (1.0 - dots) < 0.01
399 | omegas = np.arccos(dots[~linear])
400 | sinoms = np.sin(omegas)
401 |
402 | amount0[ linear] = 1.0 - a[linear]
403 | amount1[ linear] = a[linear]
404 | amount0[~linear] = np.sin((1.0 - a[~linear]) * omegas) / sinoms
405 | amount1[~linear] = np.sin( a[~linear] * omegas) / sinoms
406 |
407 | return Quaternions(
408 | amount0[...,np.newaxis] * fst +
409 | amount1[...,np.newaxis] * snd)
410 |
411 | @classmethod
412 | def between(cls, v0s, v1s):
413 | a = np.cross(v0s, v1s)
414 | w = np.sqrt((v0s**2).sum(axis=-1) * (v1s**2).sum(axis=-1)) + (v0s * v1s).sum(axis=-1)
415 | return Quaternions(np.concatenate([w[...,np.newaxis], a], axis=-1)).normalized()
416 |
417 | @classmethod
418 | def from_angle_axis(cls, angles, axis):
419 | axis = axis / (np.sqrt(np.sum(axis**2, axis=-1)) + 1e-10)[...,np.newaxis]
420 | sines = np.sin(angles / 2.0)[...,np.newaxis]
421 | cosines = np.cos(angles / 2.0)[...,np.newaxis]
422 | return Quaternions(np.concatenate([cosines, axis * sines], axis=-1))
423 |
424 | @classmethod
425 | def from_euler(cls, es, order='xyz', world=False):
426 |
427 | axis = {
428 | 'x' : np.array([1,0,0]),
429 | 'y' : np.array([0,1,0]),
430 | 'z' : np.array([0,0,1]),
431 | }
432 |
433 | q0s = Quaternions.from_angle_axis(es[...,0], axis[order[0]])
434 | q1s = Quaternions.from_angle_axis(es[...,1], axis[order[1]])
435 | q2s = Quaternions.from_angle_axis(es[...,2], axis[order[2]])
436 |
437 | return (q2s * (q1s * q0s)) if world else (q0s * (q1s * q2s))
438 |
439 | @classmethod
440 | def from_transforms(cls, ts):
441 |
442 | d0, d1, d2 = ts[...,0,0], ts[...,1,1], ts[...,2,2]
443 |
444 | q0 = ( d0 + d1 + d2 + 1.0) / 4.0
445 | q1 = ( d0 - d1 - d2 + 1.0) / 4.0
446 | q2 = (-d0 + d1 - d2 + 1.0) / 4.0
447 | q3 = (-d0 - d1 + d2 + 1.0) / 4.0
448 |
449 | q0 = np.sqrt(q0.clip(0,None))
450 | q1 = np.sqrt(q1.clip(0,None))
451 | q2 = np.sqrt(q2.clip(0,None))
452 | q3 = np.sqrt(q3.clip(0,None))
453 |
454 | c0 = (q0 >= q1) & (q0 >= q2) & (q0 >= q3)
455 | c1 = (q1 >= q0) & (q1 >= q2) & (q1 >= q3)
456 | c2 = (q2 >= q0) & (q2 >= q1) & (q2 >= q3)
457 | c3 = (q3 >= q0) & (q3 >= q1) & (q3 >= q2)
458 |
459 | q1[c0] *= np.sign(ts[c0,2,1] - ts[c0,1,2])
460 | q2[c0] *= np.sign(ts[c0,0,2] - ts[c0,2,0])
461 | q3[c0] *= np.sign(ts[c0,1,0] - ts[c0,0,1])
462 |
463 | q0[c1] *= np.sign(ts[c1,2,1] - ts[c1,1,2])
464 | q2[c1] *= np.sign(ts[c1,1,0] + ts[c1,0,1])
465 | q3[c1] *= np.sign(ts[c1,0,2] + ts[c1,2,0])
466 |
467 | q0[c2] *= np.sign(ts[c2,0,2] - ts[c2,2,0])
468 | q1[c2] *= np.sign(ts[c2,1,0] + ts[c2,0,1])
469 | q3[c2] *= np.sign(ts[c2,2,1] + ts[c2,1,2])
470 |
471 | q0[c3] *= np.sign(ts[c3,1,0] - ts[c3,0,1])
472 | q1[c3] *= np.sign(ts[c3,2,0] + ts[c3,0,2])
473 | q2[c3] *= np.sign(ts[c3,2,1] + ts[c3,1,2])
474 |
475 | qs = np.empty(ts.shape[:-2] + (4,))
476 | qs[...,0] = q0
477 | qs[...,1] = q1
478 | qs[...,2] = q2
479 | qs[...,3] = q3
480 |
481 | return cls(qs)
482 |
--------------------------------------------------------------------------------
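A quick usage sketch for the wrapper above: build quaternions from Euler angles, slerp them halfway towards identity, and read the result back as Euler angles (note that euler() only implements the 'xyz' ordering):

import numpy as np
from dataset.bvh.Quaternions import Quaternions

eulers = np.radians(np.array([[90.0, 0.0, 0.0],
                              [ 0.0, 45.0, 0.0]]))
q = Quaternions.from_euler(eulers, order='xyz')        # two rotations
halfway = Quaternions.slerp(Quaternions.id_like(q), q, 0.5)
print(np.degrees(halfway.euler(order='xyz')))          # ~[[45, 0, 0], [0, 22.5, 0]]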
/dataset/bvh/__pycache__/Quaternions.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wyysf-98/GenMM/fae65cdf199da8a25c4b28ceef20636b534269aa/dataset/bvh/__pycache__/Quaternions.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/bvh/__pycache__/bvh_io.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wyysf-98/GenMM/fae65cdf199da8a25c4b28ceef20636b534269aa/dataset/bvh/__pycache__/bvh_io.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/bvh/__pycache__/bvh_parser.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wyysf-98/GenMM/fae65cdf199da8a25c4b28ceef20636b534269aa/dataset/bvh/__pycache__/bvh_parser.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/bvh/__pycache__/bvh_writer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wyysf-98/GenMM/fae65cdf199da8a25c4b28ceef20636b534269aa/dataset/bvh/__pycache__/bvh_writer.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/bvh/bvh_io.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is modified from:
3 | http://theorangeduck.com/page/deep-learning-framework-character-motion-synthesis-and-editing
4 |
5 | by Daniel Holden et al
6 | """
7 |
8 |
9 | import re
10 | import numpy as np
11 | from dataset.bvh.Quaternions import Quaternions
12 |
13 | channelmap = {
14 | 'Xrotation' : 'x',
15 | 'Yrotation' : 'y',
16 | 'Zrotation' : 'z'
17 | }
18 |
19 | channelmap_inv = {
20 | 'x': 'Xrotation',
21 | 'y': 'Yrotation',
22 | 'z': 'Zrotation',
23 | }
24 |
25 | ordermap = {
26 | 'x': 0,
27 | 'y': 1,
28 | 'z': 2,
29 | }
30 |
31 |
32 | class Animation:
33 | def __init__(self, rotations, positions, orients, offsets, parents, names, frametime):
34 | self.rotations = rotations
35 | self.positions = positions
36 | self.orients = orients
37 | self.offsets = offsets
38 | self.parent = parents
39 | self.names = names
40 | self.frametime = frametime
41 |
42 | @property
43 | def shape(self):
44 | return self.rotations.shape
45 |
46 |
47 | def load(filename, start=None, end=None, order=None, world=False, need_quater=False) -> Animation:
48 | """
49 | Reads a BVH file and constructs an animation
50 |
51 | Parameters
52 | ----------
53 | filename: str
54 | File to be opened
55 |
56 | start : int
57 | Optional Starting Frame
58 |
59 | end : int
60 | Optional Ending Frame
61 |
62 | order : str
63 | Optional Specifier for joint order.
64 | Given as string, e.g. 'xyz', 'zxy'
65 |
66 | world : bool
67 | If set to true euler angles are applied
68 | together in world space rather than local
69 | space
70 | Returns
71 | -------
72 |
73 | animation : Animation
74 | The loaded animation (rotations, positions, orients, offsets, parents, names, frametime)
75 | """
76 |
77 | f = open(filename, "r")
78 |
79 | i = 0
80 | active = -1
81 | end_site = False
82 |
83 | names = []
84 | orients = Quaternions.id(0)
85 | offsets = np.array([]).reshape((0, 3))
86 | parents = np.array([], dtype=int)
87 | orders = []
88 |
89 | for line in f:
90 |
91 | if "HIERARCHY" in line: continue
92 | if "MOTION" in line: continue
93 |
94 | """ Modified line read to handle mixamo data """
95 | # rmatch = re.match(r"ROOT (\w+)", line)
96 | rmatch = re.match(r"ROOT (\w+:?\w+)", line)
97 | if rmatch:
98 | names.append(rmatch.group(1))
99 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)
100 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0)
101 | parents = np.append(parents, active)
102 | active = (len(parents) - 1)
103 | continue
104 |
105 | if "{" in line: continue
106 |
107 | if "}" in line:
108 | if end_site:
109 | end_site = False
110 | else:
111 | active = parents[active]
112 | continue
113 |
114 | offmatch = re.match(r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line)
115 | if offmatch:
116 | if not end_site:
117 | offsets[active] = np.array([list(map(float, offmatch.groups()))])
118 | continue
119 |
120 | chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line)
121 | if chanmatch:
122 | channels = int(chanmatch.group(1))
123 |
124 | channelis = 0 if channels == 3 else 3
125 | channelie = 3 if channels == 3 else 6
126 | parts = line.split()[2 + channelis:2 + channelie]
127 | if any([p not in channelmap for p in parts]):
128 | continue
129 | order = "".join([channelmap[p] for p in parts])
130 | orders.append(order)
131 | continue
132 |
133 | """ Modified line read to handle mixamo data """
134 | # jmatch = re.match("\s*JOINT\s+(\w+)", line)
135 | jmatch = re.match("\s*JOINT\s+(\w+:?\w+)", line)
136 | if jmatch:
137 | names.append(jmatch.group(1))
138 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)
139 | orients.qs = np.append(orients.qs, np.array([[1, 0, 0, 0]]), axis=0)
140 | parents = np.append(parents, active)
141 | active = (len(parents) - 1)
142 | continue
143 |
144 | if "End Site" in line:
145 | end_site = True
146 | continue
147 |
148 | fmatch = re.match("\s*Frames:\s+(\d+)", line)
149 | if fmatch:
150 | if start and end:
151 | fnum = (end - start) - 1
152 | else:
153 | fnum = int(fmatch.group(1))
154 | jnum = len(parents)
155 | positions = offsets[np.newaxis].repeat(fnum, axis=0)
156 | rotations = np.zeros((fnum, len(orients), 3))
157 | continue
158 |
159 | fmatch = re.match("\s*Frame Time:\s+([\d\.]+)", line)
160 | if fmatch:
161 | frametime = float(fmatch.group(1))
162 | continue
163 |
164 | if (start and end) and (i < start or i >= end - 1):
165 | i += 1
166 | continue
167 |
168 | # dmatch = line.strip().split(' ')
169 | dmatch = line.strip().split()
170 | if dmatch:
171 | data_block = np.array(list(map(float, dmatch)))
172 | N = len(parents)
173 | fi = i - start if start else i
174 | if channels == 3:
175 | positions[fi, 0:1] = data_block[0:3]
176 | rotations[fi, :] = data_block[3:].reshape(N, 3)
177 | elif channels == 6:
178 | data_block = data_block.reshape(N, 6)
179 | positions[fi, :] = data_block[:, 0:3]
180 | rotations[fi, :] = data_block[:, 3:6]
181 | elif channels == 9:
182 | positions[fi, 0] = data_block[0:3]
183 | data_block = data_block[3:].reshape(N - 1, 9)
184 | rotations[fi, 1:] = data_block[:, 3:6]
185 | positions[fi, 1:] += data_block[:, 0:3] * data_block[:, 6:9]
186 | else:
187 | raise Exception("Too many channels! %i" % channels)
188 |
189 | i += 1
190 |
191 | f.close()
192 |
193 | all_rotations = []
194 | canonical_order = 'xyz'
195 | for i, order in enumerate(orders):
196 | rot = rotations[:, i:i + 1]
197 | if need_quater:
198 | quat = Quaternions.from_euler(np.radians(rot), order=order, world=world)
199 | all_rotations.append(quat)
200 | continue
201 | elif order != canonical_order:
202 | quat = Quaternions.from_euler(np.radians(rot), order=order, world=world)
203 | rot = np.degrees(quat.euler(order=canonical_order))
204 | all_rotations.append(rot)
205 | rotations = np.concatenate(all_rotations, axis=1)
206 |
207 | return Animation(rotations, positions, orients, offsets, parents, names, frametime)
208 |
209 |
210 | def save(filename, anim, names=None, frametime=1.0/24.0, order='zyx', positions=False, orients=True):
211 | """
212 | Saves an Animation to file as BVH
213 |
214 | Parameters
215 | ----------
216 | filename: str
217 | File to be saved to
218 |
219 | anim : Animation
220 | Animation to save
221 |
222 | names : [str]
223 | List of joint names
224 |
225 | order : str
226 | Optional Specifier for joint order.
227 | Given as string, e.g. 'xyz', 'zxy'
228 |
229 | frametime : float
230 | Optional Animation Frame time
231 |
232 | positions : bool
233 | Optional specifier to save bone
234 | positions for each frame
235 |
236 | orients : bool
237 | Multiply joint orients to the rotations
238 | before saving.
239 |
240 | """
241 |
242 | if names is None:
243 | names = ["joint_" + str(i) for i in range(len(anim.parent))]
244 |
245 | with open(filename, 'w') as f:
246 |
247 | t = ""
248 | f.write("%sHIERARCHY\n" % t)
249 | f.write("%sROOT %s\n" % (t, names[0]))
250 | f.write("%s{\n" % t)
251 | t += '\t'
252 |
253 | f.write("%sOFFSET %f %f %f\n" % (t, anim.offsets[0,0], anim.offsets[0,1], anim.offsets[0,2]) )
254 | f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" %
255 | (t, channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]]))
256 |
257 | for i in range(anim.shape[1]):
258 | if anim.parent[i] == 0:
259 | t = save_joint(f, anim, names, t, i, order=order, positions=positions)
260 |
261 | t = t[:-1]
262 | f.write("%s}\n" % t)
263 |
264 | f.write("MOTION\n")
265 | f.write("Frames: %i\n" % anim.shape[0])
266 | f.write("Frame Time: %f\n" % frametime)
267 |
268 | #if orients:
269 | # rots = np.degrees((-anim.orients[np.newaxis] * anim.rotations).euler(order=order[::-1]))
270 | #else:
271 | # rots = np.degrees(anim.rotations.euler(order=order[::-1]))
272 | rots = np.degrees(anim.rotations.euler(order=order[::-1]))
273 | poss = anim.positions
274 |
275 | for i in range(anim.shape[0]):
276 | for j in range(anim.shape[1]):
277 |
278 | if positions or j == 0:
279 |
280 | f.write("%f %f %f %f %f %f " % (
281 | poss[i,j,0], poss[i,j,1], poss[i,j,2],
282 | rots[i,j,ordermap[order[0]]], rots[i,j,ordermap[order[1]]], rots[i,j,ordermap[order[2]]]))
283 |
284 | else:
285 |
286 | f.write("%f %f %f " % (
287 | rots[i,j,ordermap[order[0]]], rots[i,j,ordermap[order[1]]], rots[i,j,ordermap[order[2]]]))
288 |
289 | f.write("\n")
290 |
291 |
292 | def save_joint(f, anim, names, t, i, order='zyx', positions=False):
293 |
294 | f.write("%sJOINT %s\n" % (t, names[i]))
295 | f.write("%s{\n" % t)
296 | t += '\t'
297 |
298 | f.write("%sOFFSET %f %f %f\n" % (t, anim.offsets[i,0], anim.offsets[i,1], anim.offsets[i,2]))
299 |
300 | if positions:
301 | f.write("%sCHANNELS 6 Xposition Yposition Zposition %s %s %s \n" % (t,
302 | channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]]))
303 | else:
304 | f.write("%sCHANNELS 3 %s %s %s\n" % (t,
305 | channelmap_inv[order[0]], channelmap_inv[order[1]], channelmap_inv[order[2]]))
306 |
307 | end_site = True
308 |
309 | for j in range(anim.shape[1]):
310 | if anim.parent[j] == i:
311 | t = save_joint(f, anim, names, t, j, order=order, positions=positions)
312 | end_site = False
313 |
314 | if end_site:
315 | f.write("%sEnd Site\n" % t)
316 | f.write("%s{\n" % t)
317 | t += '\t'
318 | f.write("%sOFFSET %f %f %f\n" % (t, 0.0, 0.0, 0.0))
319 | t = t[:-1]
320 | f.write("%s}\n" % t)
321 |
322 | t = t[:-1]
323 | f.write("%s}\n" % t)
324 |
325 | return t
326 |
--------------------------------------------------------------------------------
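A load/save round-trip sketch (paths are placeholders). With the default need_quater=False, load() canonicalizes rotations to Euler degrees in 'xyz' order, so they are converted back into a Quaternions object before save(), which calls .euler() on them:

import numpy as np
import dataset.bvh.bvh_io as bvh_io
from dataset.bvh.Quaternions import Quaternions

anim = bvh_io.load('exemplar.bvh')  # placeholder path
anim.rotations = Quaternions.from_euler(np.radians(anim.rotations), order='xyz')
bvh_io.save('roundtrip.bvh', anim, names=anim.names,
            frametime=anim.frametime, order='zyx', positions=False)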
/dataset/bvh/bvh_parser.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import dataset.bvh.bvh_io as bvh_io
4 | from utils.kinematics import ForwardKinematicsJoint
5 | from utils.transforms import quat2repr6d
6 | from utils.contact import foot_contact
7 | from dataset.bvh.Quaternions import Quaternions
8 | from dataset.bvh.bvh_writer import WriterWrapper
9 |
10 |
11 | class Skeleton:
12 | def __init__(self, names, parent, offsets, joint_reduction=True, skeleton_conf=None):
13 | self._names = names
14 | self.original_parent = parent
15 | self._offsets = offsets
16 | self._parent = None
17 | self._ee_id = None
18 | self.contact_names = []
19 |
20 | for i, name in enumerate(self._names):
21 | if ':' in name:
22 | self._names[i] = name[name.find(':')+1:]
23 |
24 | if joint_reduction or skeleton_conf is not None:
25 | assert skeleton_conf is not None, 'skeleton_conf cannot be None when joint reduction is enabled'
26 | corps_names = skeleton_conf['corps_names']
27 | self.contact_names = skeleton_conf['contact_names']
28 | self.contact_threshold = skeleton_conf['contact_threshold']
29 |
30 | self.contact_id = []
31 | for i in self.contact_names:
32 | self.contact_id.append(corps_names.index(i))
33 | else:
34 | self.skeleton_type = -1
35 | corps_names = self._names
36 |
37 | self.details = []  # joints that do not belong to the corps (we are not interested in them)
38 | for i, name in enumerate(self._names):
39 | if name not in corps_names: self.details.append(i)
40 |
41 | self.corps = []
42 | self.simplified_name = []
43 | self.simplify_map = {}
44 | self.inverse_simplify_map = {}
45 |
46 | # Re-permute the joint ids according to the database
47 | for name in corps_names:
48 | for j in range(len(self._names)):
49 | if name in self._names[j]:
50 | self.corps.append(j)
51 | break
52 | if len(self.corps) != len(corps_names):
53 | for i in self.corps:
54 | print(self._names[i], end=' ')
55 | print(self.corps, self.skeleton_type, len(self.corps), sep='\n')
56 | raise Exception('Problem in this skeleton')
57 |
58 | self.joint_num_simplify = len(self.corps)
59 | for i, j in enumerate(self.corps):
60 | self.simplify_map[j] = i
61 | self.inverse_simplify_map[i] = j
62 | self.simplified_name.append(self._names[j])
63 | self.inverse_simplify_map[0] = -1
64 | for i in range(len(self._names)):
65 | if i in self.details:
66 | self.simplify_map[i] = -1
67 |
68 | @property
69 | def parent(self):
70 | if self._parent is None:
71 | self._parent = self.original_parent[self.corps].copy()
72 | for i in range(self._parent.shape[0]):
73 | if i >= 1: self._parent[i] = self.simplify_map[self._parent[i]]
74 | self._parent = tuple(self._parent)
75 | return self._parent
76 |
77 | @property
78 | def offsets(self):
79 | return torch.tensor(self._offsets[self.corps], dtype=torch.float)
80 |
81 | @property
82 | def names(self):
83 | return self.simplified_name
84 |
85 | @property
86 | def ee_id(self):
87 | raise Exception('Abandoned')
88 | # if self._ee_id is None:
89 | # self._ee_id = []
90 | # for i in SkeletonDatabase.ee_names[self.skeleton_type]:
91 | # self.ee_id._ee_id(corps_names[self.skeleton_type].index(i))
92 |
93 |
94 | class BVH_file:
95 | def __init__(self, file_path, skeleton_conf=None, requires_contact=False, joint_reduction=True, auto_scale=True):
96 | self.anim = bvh_io.load(file_path)
97 | self._names = self.anim.names
98 | self.frametime = self.anim.frametime
99 | if requires_contact or joint_reduction:
100 | assert skeleton_conf is not None, 'Please provide a skeleton configuration for contact or joint reduction'
101 | self.skeleton = Skeleton(self.anim.names, self.anim.parent, self.anim.offsets, joint_reduction, skeleton_conf)
102 |
103 | # Downsample to 30 fps for our application
104 | if self.frametime < 0.0084:
105 | self.frametime *= 2
106 | self.anim.positions = self.anim.positions[::2]
107 | self.anim.rotations = self.anim.rotations[::2]
108 | if self.frametime < 0.017:
109 | self.frametime *= 2
110 | self.anim.positions = self.anim.positions[::2]
111 | self.anim.rotations = self.anim.rotations[::2]
112 |
113 | self.requires_contact = requires_contact
114 |
115 | if requires_contact:
116 | self.contact_names = self.skeleton.contact_names
117 | else:
118 | self.contact_names = []
119 |
120 | self.fk = ForwardKinematicsJoint(self.skeleton.parent, self.skeleton.offsets)
121 | self.writer = WriterWrapper(self.skeleton.parent, self.skeleton.offsets)
122 |
123 | self.auto_scale = auto_scale
124 | if auto_scale:
125 | self.scale = 1. / np.ceil(self.skeleton.offsets.max().cpu().numpy())
126 | print(f'rescale the skeleton with scale: {self.scale}')
127 | self.rescale(self.scale)
128 | else:
129 | self.scale = 1.0
130 |
131 | if self.requires_contact:
132 | gl_pos = self.joint_position()
133 | self.contact_label = foot_contact(gl_pos[:, self.skeleton.contact_id],
134 | threshold=self.skeleton.contact_threshold)
135 | self.gl_pos = gl_pos
136 |
137 | def local_pos(self):
138 | gl_pos = self.joint_position()
139 | local_pos = gl_pos - gl_pos[:, 0:1, :]
140 | return local_pos[:, 1:]
141 |
142 | def rescale(self, ratio):
143 | self.anim.offsets *= ratio
144 | self.anim.positions *= ratio
145 |
146 | def to_tensor(self, repr='euler', rot_only=False):
147 | if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:
148 | raise Exception('Unknown rotation representation')
149 | positions = self.get_position()
150 | rotations = self.get_rotation(repr=repr)
151 |
152 | if rot_only:
153 | return rotations.reshape(rotations.shape[0], -1)
154 |
155 | if self.requires_contact:
156 | virtual_contact = torch.zeros_like(rotations[:, :len(self.skeleton.contact_id)])
157 | virtual_contact[..., 0] = self.contact_label
158 | rotations = torch.cat([rotations, virtual_contact], dim=1)
159 |
160 | rotations = rotations.reshape(rotations.shape[0], -1)
161 | return torch.cat((rotations, positions), dim=-1)
162 |
163 | def joint_position(self):
164 | positions = torch.tensor(self.anim.positions[:, 0, :], dtype=torch.float)
165 | rotations = self.anim.rotations[:, self.skeleton.corps, :]
166 | rotations = Quaternions.from_euler(np.radians(rotations)).qs
167 | rotations = torch.tensor(rotations, dtype=torch.float)
168 | j_loc = self.fk.forward(rotations, positions)
169 | return j_loc
170 |
171 | def get_rotation(self, repr='quat'):
172 | rotations = self.anim.rotations[:, self.skeleton.corps, :]
173 | if repr == 'quaternion' or repr == 'quat' or repr == 'repr6d':
174 | rotations = Quaternions.from_euler(np.radians(rotations)).qs
175 | rotations = torch.tensor(rotations, dtype=torch.float)
176 | if repr == 'repr6d':
177 | rotations = quat2repr6d(rotations)
178 | if repr == 'euler':
179 | rotations = torch.tensor(rotations, dtype=torch.float)
180 | return rotations
181 |
182 | def get_position(self):
183 | return torch.tensor(self.anim.positions[:, 0, :], dtype=torch.float)
184 |
185 | def dfs(self, x, vis, dist):
186 | fa = self.skeleton.parent
187 | vis[x] = 1
188 | for y in range(len(fa)):
189 | if (fa[y] == x or fa[x] == y) and vis[y] == 0:
190 | dist[y] = dist[x] + 1
191 | self.dfs(y, vis, dist)
192 |
193 | def get_neighbor(self, threshold, enforce_contact=False):
194 | fa = self.skeleton.parent
195 | neighbor_list = []
196 | for x in range(0, len(fa)):
197 | vis = [0 for _ in range(len(fa))]
198 | dist = [0 for _ in range(len(fa))]
199 | self.dfs(x, vis, dist)
200 | neighbor = []
201 | for j in range(0, len(fa)):
202 | if dist[j] <= threshold:
203 | neighbor.append(j)
204 | neighbor_list.append(neighbor)
205 |
206 | contact_list = []
207 | if self.requires_contact:
208 | for i, p_id in enumerate(self.skeleton.contact_id):
209 | v_id = len(neighbor_list)
210 | neighbor_list[p_id].append(v_id)
211 | neighbor_list.append(neighbor_list[p_id])
212 | contact_list.append(v_id)
213 |
214 | root_neighbor = neighbor_list[0]
215 | id_root = len(neighbor_list)
216 |
217 | if enforce_contact:
218 | root_neighbor = root_neighbor + contact_list
219 | for j in contact_list:
220 | neighbor_list[j] = list(set(neighbor_list[j]))
221 |
222 | root_neighbor = list(set(root_neighbor))
223 | for j in root_neighbor:
224 | neighbor_list[j].append(id_root)
225 | root_neighbor.append(id_root)
226 | neighbor_list.append(root_neighbor) # Neighbor for root position
227 | return neighbor_list
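228 | 
229 | 
230 | # Illustrative sketch (not part of the original file): `get_neighbor` above
231 | # groups joints whose hop distance in the kinematic tree is at most
232 | # `threshold`. A standalone BFS variant of the same hop-distance idea,
233 | # usable on a toy parent array:
234 | def toy_neighbors(parents, threshold):
235 |     n = len(parents)
236 |     out = []
237 |     for x in range(n):
238 |         dist = {x: 0}
239 |         frontier = [x]
240 |         while frontier:  # BFS over the undirected kinematic tree
241 |             nxt = []
242 |             for u in frontier:
243 |                 for v in range(n):
244 |                     if (parents[v] == u or parents[u] == v) and v not in dist:
245 |                         dist[v] = dist[u] + 1
246 |                         nxt.append(v)
247 |             frontier = nxt
248 |         out.append([j for j in range(n) if dist.get(j, n) <= threshold])
249 |     return out
250 | 
251 | # toy_neighbors([-1, 0, 1, 2, 3], 1) -> [[0, 1], [0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4]]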
--------------------------------------------------------------------------------
/dataset/bvh/bvh_writer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.transforms import quat2euler, repr6d2quat
3 |
4 |
5 | # rotation with shape frame * J * 3
6 | def write_bvh(parent, offset, rotation, position, names, frametime, order, path, endsite=None):
7 | file = open(path, 'w')
8 | frame = rotation.shape[0]
9 | joint_num = rotation.shape[1]
10 | order = order.upper()
11 |
12 | file_string = 'HIERARCHY\n'
13 |
14 | seq = []
15 |
16 | def write_static(idx, prefix):
17 | nonlocal parent, offset, rotation, names, order, endsite, file_string, seq
18 | seq.append(idx)
19 | if idx == 0:
20 | name_label = 'ROOT ' + names[idx]
21 | channel_label = 'CHANNELS 6 Xposition Yposition Zposition {}rotation {}rotation {}rotation'.format(*order)
22 | else:
23 | name_label = 'JOINT ' + names[idx]
24 | channel_label = 'CHANNELS 3 {}rotation {}rotation {}rotation'.format(*order)
25 | offset_label = 'OFFSET %.6f %.6f %.6f' % (offset[idx][0], offset[idx][1], offset[idx][2])
26 |
27 | file_string += prefix + name_label + '\n'
28 | file_string += prefix + '{\n'
29 | file_string += prefix + '\t' + offset_label + '\n'
30 | file_string += prefix + '\t' + channel_label + '\n'
31 |
32 | has_child = False
33 | for y in range(idx+1, rotation.shape[1]):
34 | if parent[y] == idx:
35 | has_child = True
36 | write_static(y, prefix + '\t')
37 | if not has_child:
38 | file_string += prefix + '\t' + 'End Site\n'
39 | file_string += prefix + '\t' + '{\n'
40 | file_string += prefix + '\t\t' + 'OFFSET 0 0 0\n'
41 | file_string += prefix + '\t' + '}\n'
42 |
43 | file_string += prefix + '}\n'
44 |
45 | write_static(0, '')
46 |
47 | file_string += 'MOTION\n' + 'Frames: {}\n'.format(frame) + 'Frame Time: %.8f\n' % frametime
48 | for i in range(frame):
49 | file_string += '%.6f %.6f %.6f ' % (position[i][0], position[i][1], position[i][2])
50 | for j in range(joint_num):
51 | idx = seq[j]
52 | file_string += '%.6f %.6f %.6f ' % (rotation[i][idx][0], rotation[i][idx][1], rotation[i][idx][2])
53 | file_string += '\n'
54 |
55 | file.write(file_string)
56 | return file_string
57 |
58 |
59 | class WriterWrapper:
60 | def __init__(self, parents, offset=None):
61 | self.parents = parents
62 | self.offset = offset
63 |
64 | def write(self, filename, rot, pos, offset=None, names=None, repr='quat'):
65 | """
66 | Write animation to bvh file
67 | :param filename:
68 | :param rot: Quaternion as (w, x, y, z)
69 | :param pos:
70 | :param offset:
71 | :return:
72 | """
73 | if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:
74 | raise Exception('Unknown rotation representation')
75 | if offset is None:
76 | offset = self.offset
77 | if not isinstance(offset, torch.Tensor):
78 | offset = torch.tensor(offset)
79 | n_bone = offset.shape[0]
80 |
81 | if repr == 'repr6d':
82 | rot = rot.reshape(rot.shape[0], -1, 6)
83 | rot = repr6d2quat(rot)
84 | if repr == 'repr6d' or repr == 'quat' or repr == 'quaternion':
85 | rot = rot.reshape(rot.shape[0], -1, 4)
86 |             rot /= rot.norm(dim=-1, keepdim=True)  # normalize quaternions
87 | euler = quat2euler(rot, order='xyz')
88 | rot = euler
89 |
90 | if names is None:
91 | names = ['%02d' % i for i in range(n_bone)]
92 | write_bvh(self.parents, offset, rot, pos, names, 1, 'xyz', filename)
93 |
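94 | 
95 | # Illustrative usage sketch (not part of the original file): writes a trivial
96 | # two-joint, two-frame clip with Euler rotations. 'toy.bvh' is a hypothetical
97 | # output path; run from the repo root so that utils.transforms resolves.
98 | if __name__ == '__main__':
99 |     parents = [-1, 0]                                    # root and one child
100 |     offsets = torch.tensor([[0., 0., 0.], [0., 1., 0.]])
101 |     rot = torch.zeros(2, 2, 3)                           # frames x joints x euler angles
102 |     pos = torch.zeros(2, 3)                              # root position per frame
103 |     WriterWrapper(parents, offsets).write('toy.bvh', rot, pos, repr='euler')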
--------------------------------------------------------------------------------
/dataset/bvh_motion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import torch
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from .motion import MotionData
7 | from .bvh.bvh_parser import BVH_file
8 |
9 |
10 | ## Some skeleton configurations
11 | crab_dance_corps_names = ['ORG_Hips', 'ORG_BN_Bip01_Pelvis', 'DEF_BN_Eye_L_01', 'DEF_BN_Eye_L_02', 'DEF_BN_Eye_L_03', 'DEF_BN_Eye_L_03_end', 'DEF_BN_Eye_R_01', 'DEF_BN_Eye_R_02', 'DEF_BN_Eye_R_03', 'DEF_BN_Eye_R_03_end', 'DEF_BN_Leg_L_11', 'DEF_BN_Leg_L_12', 'DEF_BN_Leg_L_13', 'DEF_BN_Leg_L_14', 'DEF_BN_Leg_L_15', 'DEF_BN_Leg_L_15_end', 'DEF_BN_Leg_R_11', 'DEF_BN_Leg_R_12', 'DEF_BN_Leg_R_13', 'DEF_BN_Leg_R_14', 'DEF_BN_Leg_R_15', 'DEF_BN_Leg_R_15_end', 'DEF_BN_leg_L_01', 'DEF_BN_leg_L_02', 'DEF_BN_leg_L_03', 'DEF_BN_leg_L_04', 'DEF_BN_leg_L_05', 'DEF_BN_leg_L_05_end',
12 | 'DEF_BN_leg_L_06', 'DEF_BN_Leg_L_07', 'DEF_BN_Leg_L_08', 'DEF_BN_Leg_L_09', 'DEF_BN_Leg_L_10', 'DEF_BN_Leg_L_10_end', 'DEF_BN_leg_R_01', 'DEF_BN_leg_R_02', 'DEF_BN_leg_R_03', 'DEF_BN_leg_R_04', 'DEF_BN_leg_R_05', 'DEF_BN_leg_R_05_end', 'DEF_BN_leg_R_06', 'DEF_BN_Leg_R_07', 'DEF_BN_Leg_R_08', 'DEF_BN_Leg_R_09', 'DEF_BN_Leg_R_10', 'DEF_BN_Leg_R_10_end', 'DEF_BN_Bip01_Pelvis', 'DEF_BN_Bip01_Pelvis_end', 'DEF_BN_Arm_L_01', 'DEF_BN_Arm_L_02', 'DEF_BN_Arm_L_03', 'DEF_BN_Arm_L_03_end', 'DEF_BN_Arm_R_01', 'DEF_BN_Arm_R_02', 'DEF_BN_Arm_R_03', 'DEF_BN_Arm_R_03_end']
13 | skeleton_confs = {
14 | 'mixamo': {
15 | 'corps_names': ['Hips', 'LeftUpLeg', 'LeftLeg', 'LeftFoot', 'LeftToeBase', 'LeftToe_End', 'RightUpLeg', 'RightLeg', 'RightFoot', 'RightToeBase', 'RightToe_End', 'Spine', 'Spine1', 'Spine2', 'Neck', 'Head', 'LeftShoulder', 'LeftArm', 'LeftForeArm', 'LeftHand', 'RightShoulder', 'RightArm', 'RightForeArm', 'RightHand'],
16 | 'contact_names': ['LeftToe_End', 'RightToe_End', 'LeftToeBase', 'RightToeBase'],
17 | 'contact_threshold': 0.018
18 | },
19 | 'crab_dance': {
20 | 'corps_names': crab_dance_corps_names,
21 | 'contact_names': [name for name in crab_dance_corps_names if 'end' in name and ('05' in name or '10' in name or '15' in name)],
22 | 'contact_threshold': 0.006
23 | },
24 | 'xia': {
25 | 'corps_names': ['Hips', 'LHipJoint', 'LeftUpLeg', 'LeftLeg', 'LeftFoot', 'LeftToeBase', 'RHipJoint', 'RightUpLeg', 'RightLeg', 'RightFoot', 'RightToeBase', 'LowerBack', 'Spine', 'Spine1', 'Neck', 'Neck1', 'Head', 'LeftShoulder', 'LeftArm', 'LeftForeArm', 'LeftHand', 'LeftFingerBase', 'LeftHandIndex1', 'LThumb', 'RightShoulder', 'RightArm', 'RightForeArm', 'RightHand', 'RightFingerBase', 'RightHandIndex1', 'RThumb'],
26 | 'contact_names': ['LeftToeBase', 'RightToeBase'],
27 | 'contact_threshold': 0.006
28 | }
29 | }
30 |
31 | class BVHMotion:
32 | def __init__(self, bvh_file, skeleton_name=None, repr='quat', use_velo=True, keep_up_pos=False, up_axis='Y_UP', padding_last=False, requires_contact=False, joint_reduction=False):
33 | '''
34 | BVHMotion constructor
35 | Args:
36 | bvh_file : string, bvh_file path to load from
37 |             skeleton_name : string, name of a predefined skeleton, used when joint_reduction==True or requires_contact==True
38 |             repr : string, rotation representation, support ['quat', 'repr6d', 'euler']
39 |             use_velo : bool, whether to transform the joint positions to velocities
40 |             keep_up_pos : bool, whether to keep the up-axis position when converting to velocity
41 |             up_axis : string, up axis of the motion data
42 | padding_last : bool, whether to pad the last position
43 | requires_contact : bool, whether to concatenate contact information
44 | joint_reduction : bool, whether to reduce the joint number
45 | '''
46 | self.bvh_file = bvh_file
47 | self.skeleton_name = skeleton_name
48 | if skeleton_name is not None:
49 | assert skeleton_name in skeleton_confs, f'{skeleton_name} not found, please add a skeleton configuration.'
50 | self.requires_contact = requires_contact
51 | self.joint_reduction = joint_reduction
52 |
53 | self.raw_data = BVH_file(bvh_file, skeleton_confs[skeleton_name] if skeleton_name is not None else None, requires_contact, joint_reduction, auto_scale=True)
54 | self.motion_data = MotionData(self.raw_data.to_tensor(repr=repr).permute(1, 0).unsqueeze(0), repr=repr, use_velo=use_velo, keep_up_pos=keep_up_pos, up_axis=up_axis,
55 | padding_last=padding_last, contact_id=self.raw_data.skeleton.contact_id if requires_contact else None)
56 | @property
57 | def repr(self):
58 | return self.motion_data.repr
59 |
60 | @property
61 | def use_velo(self):
62 | return self.motion_data.use_velo
63 |
64 | @property
65 | def keep_up_pos(self):
66 | return self.motion_data.keep_up_pos
67 |
68 | @property
69 | def padding_last(self):
70 | return self.motion_data.padding_last
71 |
72 | @property
73 |     def contact_id(self):
74 |         return self.motion_data.contact_id
75 |
76 | @property
77 | def n_pad(self):
78 | return self.motion_data.n_pad
79 |
80 | @property
81 | def n_contact(self):
82 | return self.motion_data.n_contact
83 |
84 | @property
85 | def n_rot(self):
86 | return self.motion_data.n_rot
87 |
88 | def sample(self, size=None, slerp=False):
89 | '''
90 | Sample motion data, support slerp
91 | '''
92 | return self.motion_data.sample(size, slerp)
93 |
94 |
95 | def write(self, filename, data):
96 | '''
97 |         Parse motion data into rotation, position and contact (if present), then write it to a bvh file
98 |         data should be of shape [batch_size x n_channels x n_frames]
99 |         No batch support here (batch_size must be 1)!
100 | '''
101 | assert len(data.shape) == 3, 'The data format should be [batch_size x n_channels x n_frames]'
102 |
103 | if self.n_pad:
104 | data = data.clone()[:, :-self.n_pad]
105 | if self.use_velo:
106 | data = self.motion_data.to_position(data)
107 | data = data.squeeze().permute(1, 0)
108 | pos = data[..., -3:]
109 | rot = data[..., :-3].reshape(data.shape[0], -1, self.n_rot)
110 | if self.requires_contact:
111 | contact = rot[..., -self.n_contact:, 0]
112 | rot = rot[..., :-self.n_contact, :]
113 | else:
114 | contact = None
115 |
116 | if contact is not None:
117 | np.save(filename + '.contact', contact.detach().cpu().numpy())
118 |
119 | # rescale the output
120 | self.raw_data.rescale(1. / self.raw_data.scale)
121 | pos *= 1. / self.raw_data.scale
122 | self.raw_data.writer.write(filename, rot, pos, names=self.raw_data.skeleton.names, repr=self.repr)
123 |
124 |
125 | def load_multiple_dataset(name_list, **kwargs):
126 | with open(name_list, 'r') as f:
127 | names = [line.strip() for line in f.readlines()]
128 | datasets = []
129 | for f in names:
130 |         kwargs['bvh_file'] = osp.join(osp.dirname(name_list), f)
131 |         datasets.append(BVHMotion(**kwargs))
132 | return datasets
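133 | 
134 | 
135 | # Illustrative usage sketch (not part of the original file). 'walk.bvh' is a
136 | # hypothetical exemplar path; with a Mixamo-skeleton clip one could load,
137 | # resample and re-export it like this (run from the repo root, e.g.
138 | # `python -m dataset.bvh_motion`):
139 | if __name__ == '__main__':
140 |     m = BVHMotion('walk.bvh', skeleton_name='mixamo', repr='repr6d',
141 |                   requires_contact=True, joint_reduction=True)
142 |     clip = m.sample(size=64)        # [1 x n_channels x 64] tensor
143 |     m.write('resampled.bvh', clip)  # writes resampled.bvh (+ .contact.npy)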
--------------------------------------------------------------------------------
/dataset/motion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | class MotionData:
6 |     def __init__(self, data, repr='quat', use_velo=True, keep_up_pos=True, up_axis='Y_UP', padding_last=False, contact_id=None):
7 | '''
8 |         MotionData constructor
9 | Args:
10 | data : torch.Tensor, [batch_size x n_channels x n_frames] input motion data,
11 |                    the channels dim should be [n_joints x n_dim_of_rotation + 3 (global position)]
12 | repr : string, rotation representation, support ['quat', 'repr6d', 'euler']
13 |             use_velo : bool, whether to transform the joint positions to velocities
14 |             keep_up_pos : bool, whether to keep the up-axis position when converting to velocity
15 |             up_axis : string, up axis of the motion data
16 | padding_last : bool, whether to pad the last position
17 | contact_id : list, contact joints id
18 | '''
19 | self.data = data
20 | self.repr = repr
21 | self.use_velo = use_velo
22 | self.keep_up_pos = keep_up_pos
23 | self.up_axis = up_axis
24 | self.padding_last = padding_last
25 | self.contact_id = contact_id
26 | self.begin_pos = None
27 |
28 | # assert the rotation representation
29 | if self.repr == 'quat':
30 | self.n_rot = 4
31 | assert (self.data.shape[1] - 3) % 4 == 0, 'rotation is not "quaternion" representation'
32 | elif self.repr == 'repr6d':
33 | self.n_rot = 6
34 | assert (self.data.shape[1] - 3) % 6 == 0, 'rotation is not "repr6d" representation'
35 |         elif self.repr == 'euler':
36 | self.n_rot = 3
37 | assert (self.data.shape[1] - 3) % 3 == 0, 'rotation is not "euler" representation'
38 |
39 | # whether to pad the position data with zero
40 | if self.padding_last:
41 | self.n_pad = self.data.shape[1] - 3 # pad position channels to match the n_channels of rotation
42 | paddings = torch.zeros_like(self.data[:, :self.n_pad])
43 | self.data = torch.cat((self.data, paddings), dim=1)
44 | else:
45 | self.n_pad = 0
46 |
47 | # get the contact information
48 | if self.contact_id is not None:
49 | self.n_contact = len(contact_id)
50 | else:
51 | self.n_contact = 0
52 |
53 | # whether to keep y position when converting to velocity
54 | if self.keep_up_pos:
55 | if self.up_axis == 'X_UP':
56 | self.velo_mask = [-2, -1]
57 | elif self.up_axis == 'Y_UP':
58 | self.velo_mask = [-3, -1]
59 | elif self.up_axis == 'Z_UP':
60 | self.velo_mask = [-3, -2]
61 | else:
62 | self.velo_mask = [-3, -2, -1]
63 |
64 | # whether to convert global position to velocity
65 | if self.use_velo:
66 | self.data = self.to_velocity(self.data)
67 |
68 |
69 | def __len__(self):
70 | '''
71 | return the number of motion frames
72 | '''
73 | return self.data.shape[-1]
74 |
75 |
76 | def sample(self, size=None, slerp=False):
77 | '''
78 | sample the motion data using given size
79 | '''
80 | if size is None:
81 | return self.data
82 | else:
83 | if slerp:
84 | motion = self.slerp(self.data, size=size)
85 | else:
86 | motion = F.interpolate(self.data, size=size, mode='linear', align_corners=False)
87 | return motion
88 |
89 |
90 | def to_velocity(self, pos):
91 | '''
92 | convert motion data to velocity
93 | '''
94 |         assert self.begin_pos is None, 'the motion data has already been converted to velocity'
95 | msk = [i - self.n_pad for i in self.velo_mask]
96 | velo = pos.detach().clone().to(pos.device)
97 | velo[:, msk, 1:] = pos[:, msk, 1:] - pos[:, msk, :-1]
98 | self.begin_pos = pos[:, msk, 0].clone()
99 | velo[:, msk, 0] = pos[:, msk, 1]
100 | return velo
101 |
102 | def to_position(self, velo):
103 | '''
104 | convert motion data to position
105 | '''
106 | assert self.begin_pos is not None, 'the motion data is already position'
107 | msk = [i - self.n_pad for i in self.velo_mask]
108 | pos = velo.detach().clone().to(velo.device)
109 | pos[:, msk, 0] = self.begin_pos.to(velo.device)
110 | pos[:, msk] = torch.cumsum(pos[:, msk], dim=-1)
111 | self.begin_pos = None
112 | return pos
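113 | 
114 | 
115 | # Illustrative round-trip sketch (not part of the original file): with
116 | # quaternion data for 4 joints over 16 frames, the constructor converts the
117 | # root trajectory to velocities and to_position recovers it exactly.
118 | if __name__ == '__main__':
119 |     raw = torch.randn(1, 4 * 4 + 3, 16)     # [batch x channels x frames]
120 |     md = MotionData(raw, repr='quat', use_velo=True,
121 |                     keep_up_pos=False, up_axis='Y_UP')
122 |     restored = md.to_position(md.data)
123 |     print(torch.allclose(restored[:, -3:], raw[:, -3:], atol=1e-5))  # True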
--------------------------------------------------------------------------------
/dataset/tracks_motion.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join as pjoin
3 | import numpy as np
4 | import copy
5 | import torch
6 | from .motion import MotionData
7 | from ..utils.transforms import quat2repr6d, quat2euler, repr6d2quat
8 |
9 | class TracksParser():
10 | def __init__(self, tracks_json, scale):
11 | self.tracks_json = tracks_json
12 | self.scale = scale
13 |
14 | self.skeleton_names = []
15 | self.rotations = []
16 | for i, track in enumerate(self.tracks_json):
17 | self.skeleton_names.append(track['name'])
18 | if i == 0:
19 | assert track['type'] == 'vector'
20 | self.position = np.array(track['values']).reshape(-1, 3) * self.scale
21 | self.num_frames = self.position.shape[0]
22 | else:
23 |                 assert track['type'] == 'quaternion' # DEFAULT: quaternion
24 | rotation = np.array(track['values']).reshape(-1, 4)
25 | if rotation.shape[0] == 0:
26 | rotation = np.zeros((self.num_frames, 4))
27 | elif rotation.shape[0] < self.num_frames:
28 | rotation = np.repeat(rotation, self.num_frames // rotation.shape[0], axis=0)
29 | elif rotation.shape[0] > self.num_frames:
30 | rotation = rotation[:self.num_frames]
31 | self.rotations += [rotation]
32 | self.rotations = np.array(self.rotations, dtype=np.float32)
33 |
34 | def to_tensor(self, repr='euler', rot_only=False):
35 | if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:
36 | raise Exception('Unknown rotation representation')
37 | rotations = self.get_rotation(repr=repr)
38 | positions = self.get_position()
39 |
40 | if rot_only:
41 | return rotations.reshape(rotations.shape[0], -1)
42 |
43 | rotations = rotations.reshape(rotations.shape[0], -1)
44 | return torch.cat((rotations, positions), dim=-1)
45 |
46 | def get_rotation(self, repr='quat'):
47 | if repr == 'quaternion' or repr == 'quat' or repr == 'repr6d':
48 | rotations = torch.tensor(self.rotations, dtype=torch.float).transpose(0, 1)
49 | if repr == 'repr6d':
50 | rotations = quat2repr6d(rotations)
51 | if repr == 'euler':
52 | rotations = quat2euler(rotations)
53 | return rotations
54 |
55 | def get_position(self):
56 | return torch.tensor(self.position, dtype=torch.float32)
57 |
58 | class TracksMotion:
59 | def __init__(self, tracks_json, scale=1.0, repr='quat', use_velo=True, keep_up_pos=True, up_axis='Y_UP', padding_last=False):
60 | '''
61 | TracksMotion constructor
62 | Args:
63 | tracks_json : dict, json format tracks data to load from
64 | scale : float, scale of the tracks motion data
65 | repr : string, rotation representation, support ['quat', 'repr6d', 'euler']
66 |             use_velo : bool, whether to transform the joint positions to velocities
67 |             keep_up_pos : bool, whether to keep the up-axis position when converting to velocity
68 |             up_axis : string, up axis of the motion data
69 | padding_last : bool, whether to pad the last position
70 | '''
71 | self.tracks_json = tracks_json
72 |
73 | self.raw_data = TracksParser(tracks_json, scale)
74 | self.motion_data = MotionData(self.raw_data.to_tensor(repr=repr).permute(1, 0).unsqueeze(0), repr=repr, use_velo=use_velo, keep_up_pos=keep_up_pos, up_axis=up_axis,
75 | padding_last=padding_last, contact_id=None)
76 | @property
77 | def repr(self):
78 | return self.motion_data.repr
79 |
80 | @property
81 | def use_velo(self):
82 | return self.motion_data.use_velo
83 |
84 | @property
85 | def keep_up_pos(self):
86 | return self.motion_data.keep_up_pos
87 |
88 | @property
89 | def padding_last(self):
90 | return self.motion_data.padding_last
91 |
92 | @property
93 | def n_pad(self):
94 | return self.motion_data.n_pad
95 |
96 | @property
97 | def n_rot(self):
98 | return self.motion_data.n_rot
99 |
100 | def sample(self, size=None, slerp=False):
101 | '''
102 | Sample motion data, support slerp
103 | '''
104 | return self.motion_data.sample(size, slerp)
105 |
106 |
107 |     def parse(self, motion, keep_velo=False):
108 | """
109 | No batch support here!!!
110 | :returns tracks_json
111 | """
112 | motion = motion.clone()
113 |
114 | if self.use_velo and not keep_velo:
115 | motion = self.motion_data.to_position(motion)
116 | if self.n_pad:
117 | motion = motion[:, :-self.n_pad]
118 |
119 | motion = motion.squeeze().permute(1, 0)
120 | pos = motion[..., -3:] / self.raw_data.scale
121 | rot = motion[..., :-3].reshape(motion.shape[0], -1, self.n_rot)
122 | if self.repr == 'repr6d':
123 | rot = repr6d2quat(rot)
124 | elif self.repr == 'euler':
125 |             raise NotImplementedError('parse with "euler" representation is not implemented yet!')
126 |
127 | times = []
128 | out_tracks_json = copy.deepcopy(self.tracks_json)
129 | for i, _track in enumerate(out_tracks_json):
130 | if i == 0:
131 | times = [ j * out_tracks_json[i]['times'][1] for j in range(motion.shape[0])]
132 | out_tracks_json[i]['values'] = pos.flatten().detach().cpu().numpy().tolist()
133 | else:
134 | out_tracks_json[i]['values'] = rot[:, i-1, :].flatten().detach().cpu().numpy().tolist()
135 | out_tracks_json[i]['times'] = times
136 |
137 | return out_tracks_json
138 |
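139 | # Illustrative usage sketch (not part of the original file): a minimal
140 | # hand-built tracks list with one position track and one quaternion track,
141 | # 4 frames long (field names follow TracksParser above). Sketch only: this
142 | # module's relative imports mean it must be imported as part of the package
143 | # rather than executed directly.
144 | if __name__ == '__main__':
145 |     tracks = [
146 |         {'name': 'root.position', 'type': 'vector', 'times': [0.0, 1 / 30],
147 |          'values': [0.0] * (4 * 3)},
148 |         {'name': 'root.quaternion', 'type': 'quaternion', 'times': [0.0, 1 / 30],
149 |          'values': [1.0, 0.0, 0.0, 0.0] * 4},
150 |     ]
151 |     tm = TracksMotion(tracks, repr='repr6d')
152 |     print(tm.sample(size=8).shape)  # torch.Size([1, 9, 8])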
--------------------------------------------------------------------------------
/demo.blend:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wyysf-98/GenMM/fae65cdf199da8a25c4b28ceef20636b534269aa/demo.blend
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel
2 |
3 | # For the convenience for users in China mainland
4 | COPY apt-sources.list /etc/apt/sources.list
5 | # Install some basic utilities
6 | RUN rm /etc/apt/sources.list.d/cuda.list
7 | RUN rm /etc/apt/sources.list.d/nvidia-ml.list
8 | RUN apt-get update && apt-get install -y \
9 | curl \
10 | ca-certificates \
11 | sudo \
12 | git \
13 | bzip2 \
14 | libx11-6 \
15 | gcc \
16 | g++ \
17 | libusb-1.0-0 \
18 | libgl1-mesa-glx \
19 | libglib2.0-dev \
20 | openssh-server \
21 | openssh-client \
22 | iputils-ping \
23 | unzip \
24 | cmake \
25 | libssl-dev \
26 | libosmesa6-dev \
27 | freeglut3-dev \
28 | ffmpeg \
29 | iputils-ping \
30 | && rm -rf /var/lib/apt/lists/*
31 |
32 | # For the convenience for users in China mainland
33 | RUN pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple \
34 | && export PATH="/usr/local/bin:$PATH" \
35 | && /bin/bash -c "source ~/.bashrc"
36 | RUN conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ \
37 | && conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ \
38 | && conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ \
39 | && conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ \
40 | && conda config --set show_channel_urls yes
41 |
42 | # Install dependencies
43 | COPY requirements.txt requirements.txt
44 | RUN pip install -r requirements.txt --user
45 |
46 | CMD ["python3"]
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | ## Build the Docker Environment and Use It with GPU Support
2 |
3 | Before you can use this Docker environment, you need to have the following:
4 |
5 | - Docker installed on your system
6 | - NVIDIA drivers installed on your system
7 | - NVIDIA Container Toolkit installed on your system
8 |
9 |
10 | ### Build and Run
11 | 1. Build docker image:
12 | ```sh
13 | docker build -t genmm:latest .
14 | ```
15 | 2. Start the docker container:
16 | ```sh
17 | docker run --gpus all -it genmm:latest /bin/bash
18 | ```
19 | 3. Clone the repository:
20 | ```sh
21 | git clone git@github.com:wyysf-98/GenMM.git
22 | ```
23 |
24 | ## Troubleshooting
25 |
26 | If you encounter any issues with the Docker environment with GPU support, please check the following:
27 |
28 | - Make sure that you have installed the NVIDIA drivers and NVIDIA Container Toolkit on your system.
29 | - Make sure that you have specified the --gpus all option when starting the Docker container.
30 | - Make sure that your deep learning application is configured to use the GPU.
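31 | 
32 | A quick way to confirm the container actually sees the GPU (a minimal sanity check, assuming the `genmm:latest` image built above):
33 | 
34 | ```sh
35 | docker run --gpus all genmm:latest python3 -c "import torch; print(torch.cuda.is_available())"
36 | ```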
--------------------------------------------------------------------------------
/docker/apt-sources.list:
--------------------------------------------------------------------------------
1 | deb https://mirrors.ustc.edu.cn/ubuntu/ bionic main restricted universe multiverse
2 | deb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic main restricted universe multiverse
3 | deb https://mirrors.ustc.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse
4 | deb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse
5 | deb https://mirrors.ustc.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse
6 | deb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse
7 | deb https://mirrors.ustc.edu.cn/ubuntu/ bionic-security main restricted universe multiverse
8 | deb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-security main restricted universe multiverse
9 | deb https://mirrors.ustc.edu.cn/ubuntu/ bionic-proposed main restricted universe multiverse
10 | deb-src https://mirrors.ustc.edu.cn/ubuntu/ bionic-proposed main restricted universe multiverse
--------------------------------------------------------------------------------
/docker/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.12.1
2 | torchvision==0.13.1
3 | tensorboardX==2.5
4 | tqdm==4.62.3
5 | unfoldNd==0.2.0
6 | pyyaml>=5.3.1
7 | gradio==3.34.0
8 | matplotlib==3.3.2
--------------------------------------------------------------------------------
/docker/requirements_blender.txt:
--------------------------------------------------------------------------------
1 | torch==1.12.1
2 | torchvision==0.13.1
3 | tqdm==4.62.3
4 | unfoldNd==0.2.0
5 | pyyaml>=5.3.1
--------------------------------------------------------------------------------
/fix_contact.py:
--------------------------------------------------------------------------------
1 | from dataset.bvh.bvh_parser import BVH_file
2 | from os.path import join as pjoin
3 | import numpy as np
4 | import torch
5 | from utils.contact import constrain_from_contact
6 | from utils.kinematics import InverseKinematicsJoint2
7 | from utils.transforms import repr6d2quat
8 | from tqdm import tqdm
9 | import argparse
10 | import matplotlib.pyplot as plt
11 | from dataset.bvh_motion import skeleton_confs
12 |
13 | def continuous_filter(contact, length=2):
14 | contact = contact.copy()
15 | for j in range(contact.shape[1]):
16 | c = contact[:, j]
17 | t_len = 0
18 | prev = c[0]
19 | for i in range(contact.shape[0]):
20 | if prev == c[i]:
21 | t_len += 1
22 | else:
23 | if t_len <= length:
24 | c[i - t_len:i] = c[i]
25 | t_len = 1
26 | prev = c[i]
27 | return contact
28 |
29 |
30 | def fix_negative_height(contact, constrain, cid):
31 | floor = -1
32 | constrain = constrain.clone()
33 | for i in range(constrain.shape[0]):
34 | for j in range(constrain.shape[1]):
35 | if constrain[i, j, 1] < floor:
36 | constrain[i, j, 1] = floor
37 | return constrain
38 |
39 |
40 | def fix_contact(bvh_file, contact):
41 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
42 | cid = bvh_file.skeleton.contact_id
43 | glb = bvh_file.joint_position()
44 | rotation = bvh_file.get_rotation(repr='repr6d').to(device)
45 | position = bvh_file.get_position().to(device)
46 | contact = contact > 0.5
47 | # contact = continuous_filter(contact)
48 | constrain = constrain_from_contact(contact, glb, cid)
49 | constrain = fix_negative_height(contact, constrain, cid).to(device)
50 | cid = list(range(glb.shape[1]))
51 | ik_solver = InverseKinematicsJoint2(rotation, position, bvh_file.skeleton.offsets.to(device), bvh_file.skeleton.parent,
52 | constrain[:, cid], cid, 0.1, 0.01, use_velo=True)
53 |
54 | loop = tqdm(range(500))
55 | losses = []
56 | for i in loop:
57 | loss = ik_solver.step()
58 | loop.set_description(f'loss = {loss:.07f}')
59 | losses += [loss]
60 | plt.plot(losses)
61 |
62 |
63 | return repr6d2quat(ik_solver.rotations.detach()), ik_solver.get_position()
64 |
65 |
66 | def fix_contact_on_file(prefix, name):
67 | try:
68 | contact = np.load(pjoin(prefix, name + '.bvh.contact.npy'))
69 |     except FileNotFoundError:
70 | print(f'{name} not found')
71 | return
72 |     bvh_file = BVH_file(pjoin(prefix, name + '.bvh'), auto_scale=False, requires_contact=True)
73 | print('Fixing foot contact with IK...')
74 | res = fix_contact(bvh_file, contact)
75 | bvh_file.writer.write(pjoin(prefix, name + '_fixed.bvh'), res[0], res[1], names=bvh_file.skeleton.names, repr='quat')
76 |
77 |
78 | if __name__ == '__main__':
79 | parser = argparse.ArgumentParser()
80 | parser.add_argument('--prefix', type=str, required=True)
81 | parser.add_argument('--name', type=str, required=True)
82 | parser.add_argument('--skeleton_name', type=str, required=True)
83 | args = parser.parse_args()
84 | if args.prefix[0] == '/':
85 | prefix = args.prefix
86 | else:
87 | prefix = f'./results/{args.prefix}'
88 | name = args.name
89 | contact = np.load(pjoin(prefix, name + '.bvh.contact.npy'))
90 | bvh_file = BVH_file(pjoin(prefix, name + '.bvh'), skeleton_confs[args.skeleton_name], auto_scale=False, requires_contact=True)
91 |
92 | res = fix_contact(bvh_file, contact)
93 | plt.savefig(f'{prefix}/losses.png')
94 |
95 | bvh_file.writer.write(pjoin(prefix, name + '_fixed.bvh'), res[0], res[1], names=bvh_file.skeleton.names, repr='quat')
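96 | 
97 | 
98 | # Example invocation (a sketch; 'my_exp' and 'syn' are hypothetical and must
99 | # match an existing results folder containing syn.bvh and syn.bvh.contact.npy):
100 | #   python fix_contact.py --prefix my_exp --name syn --skeleton_name mixamo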
--------------------------------------------------------------------------------
/nearest_neighbor/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .utils import extract_patches, combine_patches, efficient_cdist, get_NNs_Dists
5 |
6 | class PatchCoherentLoss(torch.nn.Module):
7 | def __init__(self, patch_size=7, stride=1, alpha=None, loop=False, cache=False):
8 | super(PatchCoherentLoss, self).__init__()
9 | self.patch_size = patch_size
10 | assert self.patch_size % 2 == 1, "Only support odd patch size"
11 | self.stride = stride
12 | assert self.stride == 1, "Only support stride of 1"
13 | self.alpha = alpha
14 | self.loop = loop
15 | self.cache = cache
16 | if cache:
17 | self.cached_data = None
18 |
19 | def forward(self, X, Ys, dist_wrapper=None, ext=None, return_blended_results=False):
20 |         """For each patch in input X, find its NN in target Y and sum their distances"""
21 | assert X.shape[0] == 1, "Only support batch size of 1 for X"
22 | dist_fn = lambda X, Y: dist_wrapper(efficient_cdist, X, Y) if dist_wrapper is not None else efficient_cdist(X, Y)
23 |
24 | x_patches = extract_patches(X, self.patch_size, self.stride, loop=self.loop)
25 |
26 | if not self.cache or self.cached_data is None:
27 | y_patches = []
28 | for y in Ys:
29 | y_patches += [extract_patches(y, self.patch_size, self.stride, loop=False)]
30 | y_patches = torch.cat(y_patches, dim=1)
31 | self.cached_data = y_patches
32 | else:
33 | y_patches = self.cached_data
34 |
35 | nnf, dist = get_NNs_Dists(dist_fn, x_patches.squeeze(0), y_patches.squeeze(0), self.alpha)
36 |
37 | if return_blended_results:
38 | return combine_patches(X.shape, y_patches[:, nnf, :], self.patch_size, self.stride, loop=self.loop), dist.mean()
39 | else:
40 | return dist.mean()
41 |
42 | def clean_cache(self):
43 | self.cached_data = None
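44 | 
45 | 
46 | # Illustrative usage sketch (not part of the original file): the loss compares
47 | # every temporal patch of a synthesized sequence X against all patches of the
48 | # exemplar(s) Ys. Channel/frame counts are arbitrary; run from the repo root,
49 | # e.g. `python -m nearest_neighbor.losses`.
50 | if __name__ == '__main__':
51 |     X = torch.randn(1, 32, 60)          # [batch x channels x frames]
52 |     Ys = [torch.randn(1, 32, 120)]      # list of exemplar sequences
53 |     crit = PatchCoherentLoss(patch_size=7, alpha=0.01)
54 |     print(crit(X, Ys).item())           # mean nearest-neighbor patch distance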
--------------------------------------------------------------------------------
/nearest_neighbor/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | this file borrows some codes from https://github.com/ariel415el/Efficient-GPNN/blob/main/utils/NN.py.
3 | """
4 | import torch
5 | import torch.nn.functional as F
6 | import unfoldNd
7 |
8 | def extract_patches(x, patch_size, stride, loop=False):
9 | """Extract patches from a motion sequence"""
10 | b, c, _t = x.shape
11 |
12 | # manually padding to loop
13 | if loop:
14 | half = patch_size // 2
15 | front, tail = x[:,:,:half], x[:,:,-half:]
16 | x = torch.concat([tail, x, front], dim=-1)
17 |
18 | x_patches = unfoldNd.unfoldNd(x, kernel_size=patch_size, stride=stride).transpose(1, 2).reshape(b, -1, c, patch_size)
19 |
20 | return x_patches.view(b, -1, c * patch_size)
21 |
22 | def combine_patches(x_shape, ys, patch_size, stride, loop=False):
23 | """Combine motion patches"""
24 |
25 | # manually handle the loop situation
26 | out_shape = [*x_shape]
27 | if loop:
28 | padding = patch_size // 2
29 | out_shape[-1] = out_shape[-1] + padding * 2
30 |
31 | combined = unfoldNd.foldNd(ys.permute(0, 2, 1), output_size=tuple(out_shape[-1:]), kernel_size=patch_size, stride=stride)
32 |
33 | # normal fold matrix
34 | input_ones = torch.ones(tuple(out_shape), dtype=ys.dtype, device=ys.device)
35 | divisor = unfoldNd.unfoldNd(input_ones, kernel_size=patch_size, stride=stride)
36 | divisor = unfoldNd.foldNd(divisor, output_size=tuple(out_shape[-1:]), kernel_size=patch_size, stride=stride)
37 | combined = (combined / divisor).squeeze(dim=0).unsqueeze(0)
38 |
39 | if loop:
40 | half = patch_size // 2
41 | front, tail = combined[:,:,:half], combined[:,:,-half:]
42 | combined[:, :, half:2 * half] = (combined[:, :, half:2 * half] + tail) / 2
43 | combined[:, :, - 2 * half:-half] = (front + combined[:, :, - 2 * half:-half]) / 2
44 | combined = combined[:, :, half:-half]
45 |
46 | return combined
47 |
48 |
49 | def efficient_cdist(X, Y):
50 | """
51 | borrowed from https://github.com/ariel415el/Efficient-GPNN/blob/main/utils/NN.py
52 |     PyTorch-efficient way of computing squared distances between all vectors
53 |     in X and Y, i.e. ((X[:, None] - Y[None, :]) ** 2).sum(-1)
54 |     :param X: (n1, d) tensor
55 |     :param Y: (n2, d) tensor
56 |     Returns an (n1, n2) tensor of squared distances, normalized by d
57 | """
58 | dist = (X * X).sum(1)[:, None] + (Y * Y).sum(1)[None, :] - 2.0 * torch.mm(X, torch.transpose(Y, 0, 1))
59 | d = X.shape[1]
60 |     dist /= d # normalize by the vector size to make dists independent of d (use the same alpha for all patch sizes)
61 | return dist # DO NOT use torch.sqrt
62 |
63 |
64 | def get_col_mins_efficient(dist_fn, X, Y, b=1024):
65 | """
66 | borrowed from https://github.com/ariel415el/Efficient-GPNN/blob/main/utils/NN.py
67 |     Computes the distance to the closest x for each y.
68 |     :param X: (n1, d) tensor
69 |     :param Y: (n2, d) tensor
70 |     Returns an n2-long array of distances
71 | """
72 | n_batches = len(Y) // b
73 | mins = torch.zeros(Y.shape[0], dtype=X.dtype, device=X.device)
74 | for i in range(n_batches):
75 | mins[i * b:(i + 1) * b] = dist_fn(X, Y[i * b:(i + 1) * b]).min(0)[0]
76 | if len(Y) % b != 0:
77 | mins[n_batches * b:] = dist_fn(X, Y[n_batches * b:]).min(0)[0]
78 |
79 | return mins
80 |
81 |
82 | def get_NNs_Dists(dist_fn, X, Y, alpha=None, b=1024):
83 | """
84 | borrowed from https://github.com/ariel415el/Efficient-GPNN/blob/main/utils/NN.py
85 | Get the nearest neighbor index from Y for each X.
86 |     Avoids holding a (n1 * n2) matrix in order to reduce the memory footprint to (b * max(n1, n2)).
87 |     :param X: (n1, d) tensor
88 |     :param Y: (n2, d) tensor
89 |     Returns n1-long arrays of indices and distances
90 | """
91 | if alpha is not None:
92 | normalizing_row = get_col_mins_efficient(dist_fn, X, Y, b=b)
93 | normalizing_row = alpha + normalizing_row[None, :]
94 | else:
95 | normalizing_row = 1
96 |
97 | NNs = torch.zeros(X.shape[0], dtype=torch.long, device=X.device)
98 | Dists = torch.zeros(X.shape[0], dtype=torch.float, device=X.device)
99 |
100 | n_batches = len(X) // b
101 | for i in range(n_batches):
102 | dists = dist_fn(X[i * b:(i + 1) * b], Y) / normalizing_row
103 | NNs[i * b:(i + 1) * b] = dists.min(1)[1]
104 | Dists[i * b:(i + 1) * b] = dists.min(1)[0]
105 | if len(X) % b != 0:
106 | dists = dist_fn(X[n_batches * b:], Y) / normalizing_row
107 | NNs[n_batches * b:] = dists.min(1)[1]
108 | Dists[n_batches * b: ] = dists.min(1)[0]
109 |
110 | return NNs, Dists
111 |
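112 | 
113 | # Illustrative sanity check (not part of the original file): efficient_cdist
114 | # matches the squared pairwise L2 distance divided by the feature dimension.
115 | if __name__ == '__main__':
116 |     X, Y = torch.randn(5, 8), torch.randn(7, 8)
117 |     ref = torch.cdist(X, Y) ** 2 / X.shape[1]
118 |     print(torch.allclose(efficient_cdist(X, Y), ref, atol=1e-4))  # True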
--------------------------------------------------------------------------------
/run_random_generation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import argparse
4 | from GenMM import GenMM
5 | from nearest_neighbor.losses import PatchCoherentLoss
6 | from dataset.bvh_motion import BVHMotion, load_multiple_dataset
7 | from utils.base import ConfigParser, set_seed
8 |
9 | args = argparse.ArgumentParser(
10 | description='Random shuffle the input motion sequence')
11 | args.add_argument('-m', '--mode', default='run',
12 | choices=['run', 'eval', 'debug'], type=str, help='current run mode.')
13 | args.add_argument('-i', '--input', required=True,
14 | type=str, help='exemplar motion path.')
15 | args.add_argument('-o', '--output_dir', default='./output',
16 | type=str, help='output folder path for saving results.')
17 | args.add_argument('-c', '--config', default='./configs/default.yaml',
18 | type=str, help='config file path.')
19 | args.add_argument('-s', '--seed', default=None,
20 | type=int, help='random seed used.')
21 | args.add_argument('-d', '--device', default="cuda:0",
22 | type=str, help='device to use.')
23 | args.add_argument('--post_process', action='store_true',
24 | help='whether to use IK post-process to fix foot contact.')
25 |
26 | # Use argsparser to overwrite the configuration
27 | # for dataset
28 | args.add_argument('--skeleton_name', type=str,
29 | help='(used when joint_reduction==True or contact==True) skeleton name to load pre-defined joints configuration.')
30 | args.add_argument('--use_velo', type=int,
31 | help='whether to use velocity rather than global position of each joint.')
32 | args.add_argument('--repr', choices=['repr6d', 'quat', 'euler'], type=str,
33 |                   help='rotation representation, support [repr6d, quat, euler].')
34 | args.add_argument('--requires_contact', type=int,
35 | help='whether to use contact label.')
36 | args.add_argument('--keep_up_pos', type=int,
37 | help='whether to do not use velocity and keep the y(up) position.')
38 | args.add_argument('--up_axis', type=str, choices=['X_UP', 'Y_UP', 'Z_UP'],
39 | help='up axis of the motion.')
40 | args.add_argument('--padding_last', type=int,
41 | help='whether to pad the last position channel to match the rotation dimension.')
42 | args.add_argument('--joint_reduction', type=int,
43 | help='whether to simplify the skeleton using provided skeleton config.')
44 | args.add_argument('--skeleton_aware', type=int,
45 | help='whether to enable skeleton-aware component.')
46 | args.add_argument('--joints_group', type=str,
47 |                   help='joint splitting groups for the skeleton-aware component.')
48 | # for synthesis
49 | args.add_argument('--num_frames', type=str,
50 | help='number of synthesized frames, supported Nx(N times) and int input.')
51 | args.add_argument('--alpha', type=float,
52 | help='completeness/diversity trade-off alpha value.')
53 | args.add_argument('--num_steps', type=int,
54 | help='number of optimization steps at each pyramid level.')
55 | args.add_argument('--noise_sigma', type=float,
56 | help='standard deviation of the zero mean normal noise added to the initialization.')
57 | args.add_argument('--coarse_ratio', type=float,
58 | help='downscale ratio of the coarse level.')
59 | args.add_argument('--coarse_ratio_factor', type=float,
60 |                   help='factor used to derive the coarse ratio when it is specified relative to the patch size.')
61 | args.add_argument('--pyr_factor', type=float,
62 | help='upsample ratio of each pyramid level.')
63 | args.add_argument('--num_stages_limit', type=int,
64 | help='limit of the number of stages.')
65 | args.add_argument('--patch_size', type=int, help='patch size for generation.')
66 | args.add_argument('--loop', type=int, help='whether to loop the sequence.')
67 | cfg = ConfigParser(args)
68 |
69 |
70 | def generate(cfg):
71 |     # set seed for reproducibility
72 | set_seed(cfg.seed)
73 |
74 | # set save path and prepare data for generation
75 | if cfg.input.endswith('.bvh'):
76 | base_dir = osp.join(
77 | cfg.output_dir, cfg.input.split('/')[-1].split('.')[0])
78 | motion_data = [BVHMotion(cfg.input, skeleton_name=cfg.skeleton_name, repr=cfg.repr,
79 | use_velo=cfg.use_velo, keep_up_pos=cfg.keep_up_pos, up_axis=cfg.up_axis, padding_last=cfg.padding_last,
80 | requires_contact=cfg.requires_contact, joint_reduction=cfg.joint_reduction)]
81 | elif cfg.input.endswith('.txt'):
82 | base_dir = osp.join(cfg.output_dir, cfg.input.split(
83 | '/')[-2], cfg.input.split('/')[-1].split('.')[0])
84 | motion_data = load_multiple_dataset(name_list=cfg.input, skeleton_name=cfg.skeleton_name, repr=cfg.repr,
85 | use_velo=cfg.use_velo, keep_up_pos=cfg.keep_up_pos, up_axis=cfg.up_axis, padding_last=cfg.padding_last,
86 | requires_contact=cfg.requires_contact, joint_reduction=cfg.joint_reduction)
87 | else:
88 | raise ValueError('exemplar must be a bvh file or a txt file')
89 | prefix = f"s{cfg.seed}+{cfg.num_frames}+{cfg.repr}+use_velo_{cfg.use_velo}+kypose_{cfg.keep_up_pos}+padding_{cfg.padding_last}" \
90 | f"+contact_{cfg.requires_contact}+jredu_{cfg.joint_reduction}+n{cfg.noise_sigma}+pyr{cfg.pyr_factor}" \
91 | f"+r{cfg.coarse_ratio}_{cfg.coarse_ratio_factor}+itr{cfg.num_steps}+ps_{cfg.patch_size}+alpha_{cfg.alpha}" \
92 | f"+loop_{cfg.loop}"
93 |
94 | # perform the generation
95 | model = GenMM(device=cfg.device, silent=True if cfg.mode == 'eval' else False)
96 | criteria = PatchCoherentLoss(patch_size=cfg.patch_size, alpha=cfg.alpha, loop=cfg.loop, cache=True)
97 | syn = model.run(motion_data, criteria,
98 | num_frames=cfg.num_frames,
99 | num_steps=cfg.num_steps,
100 | noise_sigma=cfg.noise_sigma,
101 | patch_size=cfg.patch_size,
102 | coarse_ratio=cfg.coarse_ratio,
103 | pyr_factor=cfg.pyr_factor,
104 |                     debug_dir=osp.join(base_dir, prefix) if cfg.mode == 'debug' else None)
105 |
106 | # save the generated results
107 | save_dir = osp.join(base_dir, prefix)
108 | os.makedirs(save_dir, exist_ok=True)
109 | motion_data[0].write(f"{save_dir}/syn.bvh", syn)
110 |
111 |     if cfg.post_process:
112 | cmd = f"python fix_contact.py --prefix {osp.abspath(save_dir)} --name syn --skeleton_name={cfg.skeleton_name}"
113 | os.system(cmd)
114 |
115 | if __name__ == '__main__':
116 | generate(cfg)
117 |
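118 | # Example invocation (a sketch; './data/walk.bvh' is a hypothetical exemplar path):
119 | #   python run_random_generation.py -i ./data/walk.bvh --skeleton_name mixamo --num_frames 2x -s 42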
--------------------------------------------------------------------------------
/run_web_server.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 | import torch
4 | import argparse
5 | import gradio as gr
6 |
7 | from GenMM import GenMM
8 | from nearest_neighbor.losses import PatchCoherentLoss
9 | from dataset.tracks_motion import TracksMotion
10 |
11 | args = argparse.ArgumentParser(description='Web server for GenMM')
12 | args.add_argument('-d', '--device', default="cuda:0", type=str, help='device to use.')
13 | args.add_argument('--ip', default="0.0.0.0", type=str, help='interface url to host.')
14 | args.add_argument('--port', default=8000, type=int, help='interface port to serve.')
15 | args.add_argument('--debug', action='store_true', help='debug mode.')
16 | args = args.parse_args()
17 |
18 | def generate(data):
19 | data = json.loads(data)
20 |
21 | # create track object
22 |     motion_data = [TracksMotion(data['tracks'], repr='repr6d', use_velo=True, keep_up_pos=True, padding_last=False)]
23 | model = GenMM(device=args.device, silent=True)
24 | criteria = PatchCoherentLoss(patch_size=data['setting']['patch_size'],
25 | alpha=data['setting']['alpha'] if data['setting']['completeness'] else None,
26 | loop=data['setting']['loop'], cache=True)
27 |
28 | # start generation
29 | start = time.time()
30 | syn = model.run(motion_data, criteria,
31 | num_frames=str(data['setting']['frames']),
32 | num_steps=data['setting']['num_steps'],
33 | noise_sigma=data['setting']['noise_sigma'],
34 | patch_size=data['setting']['patch_size'],
35 | coarse_ratio=f'{data["setting"]["coarse_ratio"]}x_nframes',
36 | # coarse_ratio=f'3x_patchsize',
37 | pyr_factor=data['setting']['pyr_factor'])
38 | end = time.time()
39 |
40 | data['time'] = end - start
41 | data['tracks'] = motion_data[0].parse(syn)
42 |
43 | return data
44 |
45 | if __name__ == '__main__':
46 | demo = gr.Interface(fn=generate, inputs="json", outputs="json")
47 | demo.launch(debug=args.debug, server_name=args.ip, server_port=args.port)
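48 | 
49 | # The JSON payload consumed by generate() has this shape (a sketch; the values
50 | # are arbitrary examples, and 'tracks' holds three.js-style animation tracks
51 | # as parsed by dataset.tracks_motion.TracksParser):
52 | #   {
53 | #     "tracks": [...],
54 | #     "setting": {"frames": 300, "num_steps": 3, "noise_sigma": 10.0,
55 | #                 "patch_size": 11, "alpha": 0.01, "completeness": true,
56 | #                 "loop": false, "coarse_ratio": 0.2, "pyr_factor": 0.75}
57 | #   }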
--------------------------------------------------------------------------------
/utils/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import sys
4 | import time
5 | import yaml
6 | import imageio
7 | import random
8 | import shutil
9 | import random
10 | import numpy as np
11 | import torch
12 | from tqdm import tqdm
13 |
14 | # configuration
15 | class ConfigParser():
16 | def __init__(self, args):
17 | """
18 | class to parse configuration.
19 | """
20 | args = args.parse_args()
21 | self.cfg = self.merge_config_file(args)
22 |
23 | # set random seed
24 | self.set_seed()
25 |
26 | def __str__(self):
27 | return str(self.cfg.__dict__)
28 |
29 | def __getattr__(self, name):
30 | """
31 |         Access items using dot notation.
32 | """
33 | return self.cfg.__dict__[name]
34 |
35 | def __getitem__(self, name):
36 | """
37 |         Access items like an ordinary dict.
38 | """
39 | return self.cfg.__dict__[name]
40 |
41 | def merge_config_file(self, args, allow_invalid=True):
42 | """
43 |         Load the YAML config file and merge it with the command-line arguments
44 | """
45 | assert args.config is not None
46 | with open(args.config, 'r') as f:
47 | cfg = yaml.safe_load(f)
48 | if 'config' in cfg.keys():
49 | del cfg['config']
50 | f.close()
51 | invalid_args = list(set(cfg.keys()) - set(dir(args)))
52 | if invalid_args and not allow_invalid:
53 | raise ValueError(f"Invalid args {invalid_args} in {args.config}.")
54 |
55 | for k in list(cfg.keys()):
56 | if k in args.__dict__.keys() and args.__dict__[k] is not None:
57 | print('=========> overwrite config: {} = {}'.format(k, args.__dict__[k]))
58 | del cfg[k]
59 |
60 | args.__dict__.update(cfg)
61 |
62 | return args
63 |
64 | def set_seed(self):
65 | ''' set random seed for random, numpy and torch. '''
66 | if 'seed' not in self.cfg.__dict__.keys():
67 | return
68 | if self.cfg.seed is None:
69 | self.cfg.seed = int(time.time()) % 1000000
70 | print('=========> set random seed: {}'.format(self.cfg.seed))
71 | # fix random seeds for reproducibility
72 | random.seed(self.cfg.seed)
73 | np.random.seed(self.cfg.seed)
74 | torch.manual_seed(self.cfg.seed)
75 | torch.cuda.manual_seed(self.cfg.seed)
76 |
77 | def save_codes_and_config(self, save_path):
78 | """
79 | save codes and config to $save_path.
80 | """
81 | cur_codes_path = osp.dirname(osp.dirname(os.path.abspath(__file__)))
82 | if os.path.exists(save_path):
83 | shutil.rmtree(save_path)
84 | shutil.copytree(cur_codes_path, osp.join(save_path, 'codes'), \
85 | ignore=shutil.ignore_patterns('*debug*', '*data*', '*output*', '*exps*', '*.txt', '*.json', '*.mp4', '*.png', '*.jpg', '*.bvh', '*.csv', '*.pth', '*.tar', '*.npz'))
86 |
87 | with open(osp.join(save_path, 'config.yaml'), 'w') as f:
88 | f.write(yaml.dump(self.cfg.__dict__))
89 | f.close()
90 |
91 |
92 | # logger util
93 | class logger:
94 | """
95 | Keeps track of the levels and steps of optimization. Logs it via TQDM
96 | """
97 | def __init__(self, n_steps, n_lvls):
98 | self.n_steps = n_steps
99 | self.n_lvls = n_lvls
100 | self.lvl = -1
101 | self.lvl_step = 0
102 | self.steps = 0
103 | self.pbar = tqdm(total=self.n_lvls * self.n_steps, desc='Starting')
104 |
105 | def step(self):
106 | self.pbar.update(1)
107 | self.steps += 1
108 | self.lvl_step += 1
109 |
110 | def new_lvl(self):
111 | self.lvl += 1
112 | self.lvl_step = 0
113 |
114 | def print(self):
115 | self.pbar.set_description(f'Lvl {self.lvl}/{self.n_lvls-1}, step {self.lvl_step}/{self.n_steps}')
116 |
117 |
118 | # other utils
119 | def set_seed(seed=None):
120 | """
121 |     Set all random seeds for reproducibility
122 | """
123 | if seed is not None:
124 | random.seed(seed)
125 | np.random.seed(seed)
126 | torch.manual_seed(seed)
127 | torch.cuda.manual_seed(seed)
--------------------------------------------------------------------------------
/utils/contact.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def foot_contact_by_height(pos):
5 | eps = 0.25
6 | return (-eps < pos[..., 1]) * (pos[..., 1] < eps)
7 |
8 |
9 | def velocity(pos, padding=False):
10 | velo = pos[1:, ...] - pos[:-1, ...]
11 | velo_norm = torch.norm(velo, dim=-1)
12 | if padding:
13 | pad = torch.zeros_like(velo_norm[:1, :])
14 | velo_norm = torch.cat([pad, velo_norm], dim=0)
15 | return velo_norm
16 |
17 |
18 | def foot_contact(pos, ref_height=1., threshold=0.018):
19 | velo_norm = velocity(pos)
20 | contact = velo_norm < threshold
21 | contact = contact.int()
22 | padding = torch.zeros_like(contact)
23 | contact = torch.cat([padding[:1, :], contact], dim=0)
24 | return contact
25 |
26 |
27 | def alpha(t):
28 | return 2.0 * t * t * t - 3.0 * t * t + 1
29 |
30 |
31 | def lerp(a, l, r):
32 | return (1 - a) * l + a * r
33 |
34 |
35 | def constrain_from_contact(contact, glb, fid='TBD', L=5):
36 | """
37 | :param contact: contact label
38 | :param glb: original global position
39 | :param fid: joint id to fix, corresponding to the order in contact
40 | :param L: frame to look forward/backward
41 | :return:
42 | """
43 | T = glb.shape[0]
44 |
45 | for i, fidx in enumerate(fid): # fidx: index of the foot joint
46 | fixed = contact[:, i] # [T]
47 | s = 0
48 | while s < T:
49 | while s < T and fixed[s] == 0:
50 | s += 1
51 | if s >= T:
52 | break
53 | t = s
54 | avg = glb[t, fidx].clone()
55 | while t + 1 < T and fixed[t + 1] == 1:
56 | t += 1
57 | avg += glb[t, fidx].clone()
58 | avg /= (t - s + 1)
59 |
60 | for j in range(s, t + 1):
61 | glb[j, fidx] = avg.clone()
62 | s = t + 1
63 |
64 | for s in range(T):
65 | if fixed[s] == 1:
66 | continue
67 | l, r = None, None
68 | consl, consr = False, False
69 | for k in range(L):
70 | if s - k - 1 < 0:
71 | break
72 | if fixed[s - k - 1]:
73 | l = s - k - 1
74 | consl = True
75 | break
76 | for k in range(L):
77 | if s + k + 1 >= T:
78 | break
79 | if fixed[s + k + 1]:
80 | r = s + k + 1
81 | consr = True
82 | break
83 | if not consl and not consr:
84 | continue
85 | if consl and consr:
86 | litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)),
87 | glb[s, fidx], glb[l, fidx])
88 | ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)),
89 | glb[s, fidx], glb[r, fidx])
90 | itp = lerp(alpha(1.0 * (s - l + 1) / (r - l + 1)),
91 | ritp, litp)
92 | glb[s, fidx] = itp.clone()
93 | continue
94 | if consl:
95 | litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)),
96 | glb[s, fidx], glb[l, fidx])
97 | glb[s, fidx] = litp.clone()
98 | continue
99 | if consr:
100 | ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)),
101 | glb[s, fidx], glb[r, fidx])
102 | glb[s, fidx] = ritp.clone()
103 | return glb
104 |
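105 | 
106 | # Illustrative sketch (not part of the original file): a stationary joint is
107 | # labelled as in contact, a fast-moving one is not.
108 | if __name__ == '__main__':
109 |     pos = torch.zeros(10, 2, 3)                # frames x joints x xyz
110 |     pos[:, 1, 0] = torch.arange(10.)           # joint 1 slides along x
111 |     print(foot_contact(pos, threshold=0.018))  # column 0 -> 1s, column 1 -> 0s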
--------------------------------------------------------------------------------
/utils/kinematics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.transforms import quat2mat, repr6d2mat, euler2mat
3 |
4 |
5 | class ForwardKinematics:
6 | def __init__(self, parents, offsets=None):
7 | self.parents = parents
8 | if offsets is not None and len(offsets.shape) == 2:
9 | offsets = offsets.unsqueeze(0)
10 | self.offsets = offsets
11 |
12 | def forward(self, rots, offsets=None, global_pos=None):
13 | """
14 | Forward Kinematics: returns a per-bone transformation
15 | @param rots: local joint rotations (batch_size, bone_num, 3, 3)
16 | @param offsets: (batch_size, bone_num, 3) or None
17 | @param global_pos: global_position: (batch_size, 3) or keep it as in offsets (default)
18 |         @return: (batch_size, bone_num, 3, 4)
19 | """
20 | rots = rots.clone()
21 | if offsets is None:
22 | offsets = self.offsets.to(rots.device)
23 | if global_pos is None:
24 | global_pos = offsets[:, 0]
25 |
26 | pos = torch.zeros((rots.shape[0], rots.shape[1], 3), device=rots.device)
27 | rest_pos = torch.zeros_like(pos)
28 | res = torch.zeros((rots.shape[0], rots.shape[1], 3, 4), device=rots.device)
29 |
30 | pos[:, 0] = global_pos
31 | rest_pos[:, 0] = offsets[:, 0]
32 |
33 | for i, p in enumerate(self.parents):
34 | if i != 0:
35 | rots[:, i] = torch.matmul(rots[:, p], rots[:, i])
36 | pos[:, i] = torch.matmul(rots[:, p], offsets[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, p]
37 | rest_pos[:, i] = rest_pos[:, p] + offsets[:, i]
38 |
39 | res[:, i, :3, :3] = rots[:, i]
40 | res[:, i, :, 3] = torch.matmul(rots[:, i], -rest_pos[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, i]
41 |
42 | return res
43 |
44 | def accumulate(self, local_rots):
45 | """
46 | Get global joint rotation from local rotations
47 | @param local_rots: (batch_size, n_bone, 3, 3)
48 | @return: global_rotations
49 | """
50 | res = torch.empty_like(local_rots)
51 | for i, p in enumerate(self.parents):
52 | if i == 0:
53 | res[:, i] = local_rots[:, i]
54 | else:
55 | res[:, i] = torch.matmul(res[:, p], local_rots[:, i])
56 | return res
57 |
58 | def unaccumulate(self, global_rots):
59 | """
60 | Get local joint rotation from global rotations
61 | @param global_rots: (batch_size, n_bone, 3, 3)
62 | @return: local_rotations
63 | """
64 | res = torch.empty_like(global_rots)
65 | inv = torch.empty_like(global_rots)
66 |
67 | for i, p in enumerate(self.parents):
68 | if i == 0:
69 | inv[:, i] = global_rots[:, i].transpose(-2, -1)
70 | res[:, i] = global_rots[:, i]
71 | continue
72 | res[:, i] = torch.matmul(inv[:, p], global_rots[:, i])
73 | inv[:, i] = torch.matmul(res[:, i].transpose(-2, -1), inv[:, p])
74 |
75 | return res
76 |
77 |
78 | class ForwardKinematicsJoint:
79 | def __init__(self, parents, offset):
80 | self.parents = parents
81 | self.offset = offset
82 |
83 | '''
84 | rotation should have shape batch_size * Joint_num * (3/4) * Time
85 | position should have shape batch_size * 3 * Time
86 | offset should have shape batch_size * Joint_num * 3
87 |     output has shape batch_size * Time * Joint_num * 3
88 | '''
89 |
90 | def forward(self, rotation: torch.Tensor, position: torch.Tensor, offset=None,
91 | world=True):
92 | '''
93 | if not quater and rotation.shape[-2] != 3: raise Exception('Unexpected shape of rotation')
94 | if quater and rotation.shape[-2] != 4: raise Exception('Unexpected shape of rotation')
95 | rotation = rotation.permute(0, 3, 1, 2)
96 | position = position.permute(0, 2, 1)
97 | '''
98 | if rotation.shape[-1] == 6:
99 | transform = repr6d2mat(rotation)
100 | elif rotation.shape[-1] == 4:
101 | norm = torch.norm(rotation, dim=-1, keepdim=True)
102 | rotation = rotation / norm
103 | transform = quat2mat(rotation)
104 | elif rotation.shape[-1] == 3:
105 | transform = euler2mat(rotation)
106 | else:
107 |             raise Exception('Rotation must have last dimension 3 (euler), 4 (quaternion) or 6 (repr6d)')
108 | result = torch.empty(transform.shape[:-2] + (3,), device=position.device)
109 |
110 | if offset is None:
111 | offset = self.offset
112 | offset = offset.reshape((-1, 1, offset.shape[-2], offset.shape[-1], 1))
113 |
114 | result[..., 0, :] = position
115 | for i, pi in enumerate(self.parents):
116 | if pi == -1:
117 | assert i == 0
118 | continue
119 |
120 | result[..., i, :] = torch.matmul(transform[..., pi, :, :], offset[..., i, :, :]).squeeze()
121 | transform[..., i, :, :] = torch.matmul(transform[..., pi, :, :].clone(), transform[..., i, :, :].clone())
122 | if world: result[..., i, :] += result[..., pi, :]
123 | return result
124 |
125 |
126 | class InverseKinematicsJoint:
127 | def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains):
128 | self.rotations = rotations.detach().clone()
129 | self.rotations.requires_grad_(True)
130 | self.position = positions.detach().clone()
131 | self.position.requires_grad_(True)
132 |
133 | self.parents = parents
134 | self.offset = offset
135 | self.constrains = constrains
136 |
137 | self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999))
138 | self.criteria = torch.nn.MSELoss()
139 |
140 | self.fk = ForwardKinematicsJoint(parents, offset)
141 |
142 | self.glb = None
143 |
144 | def step(self):
145 | self.optimizer.zero_grad()
146 | glb = self.fk.forward(self.rotations, self.position)
147 | loss = self.criteria(glb, self.constrains)
148 | loss.backward()
149 | self.optimizer.step()
150 | self.glb = glb
151 | return loss.item()
152 |
153 |
154 | class InverseKinematicsJoint2:
155 | def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains, cid,
156 | lambda_rec_rot=1., lambda_rec_pos=1., use_velo=False):
157 | self.use_velo = use_velo
158 | self.rotations_ori = rotations.detach().clone()
159 | self.rotations = rotations.detach().clone()
160 | self.rotations.requires_grad_(True)
161 | self.position_ori = positions.detach().clone()
162 | self.position = positions.detach().clone()
163 | if self.use_velo:
164 | self.position[1:] = self.position[1:] - self.position[:-1]
165 | self.position.requires_grad_(True)
166 |
167 | self.parents = parents
168 | self.offset = offset
169 | self.constrains = constrains.detach().clone()
170 | self.cid = cid
171 |
172 | self.lambda_rec_rot = lambda_rec_rot
173 | self.lambda_rec_pos = lambda_rec_pos
174 |
175 | self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999))
176 | self.criteria = torch.nn.MSELoss()
177 |
178 | self.fk = ForwardKinematicsJoint(parents, offset)
179 |
180 | self.glb = None
181 |
182 | def step(self):
183 | self.optimizer.zero_grad()
184 | if self.use_velo:
185 | position = torch.cumsum(self.position, dim=0)
186 | else:
187 | position = self.position
188 | glb = self.fk.forward(self.rotations, position)
189 | self.constrain_loss = self.criteria(glb[:, self.cid], self.constrains)
190 | self.rec_loss_rot = self.criteria(self.rotations, self.rotations_ori)
191 | self.rec_loss_pos = self.criteria(self.position, self.position_ori)
192 | loss = self.constrain_loss + self.rec_loss_rot * self.lambda_rec_rot + self.rec_loss_pos * self.lambda_rec_pos
193 | loss.backward()
194 | self.optimizer.step()
195 | self.glb = glb
196 | return loss.item()
197 |
198 | def get_position(self):
199 | if self.use_velo:
200 | position = torch.cumsum(self.position.detach(), dim=0)
201 | else:
202 | position = self.position.detach()
203 | return position
204 |
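# --- editor's sketch (illustrative; not part of the original file) ------------
# InverseKinematicsJoint2 differs from the class above in two ways: it adds
# reconstruction terms that keep the solution close to the input motion
# (lambda_rec_rot / lambda_rec_pos), and with use_velo=True it optimizes root
# *velocities* (frame deltas, integrated back with cumsum), which tends to keep
# the root trajectory smooth. `cid` lists the constrained joints, so
# `constrains` has shape (T, len(cid), 3). Reusing the toy chain from above:
cid = [2]                                    # constrain the chain tip only
targets = torch.rand(T, len(cid), 3)
ik2 = InverseKinematicsJoint2(rotations, positions, offset, parents, targets,
                              cid, use_velo=True)
for _ in range(300):
    ik2.step()
new_positions = ik2.get_position()           # velocities integrated back
# ------------------------------------------------------------------------------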
--------------------------------------------------------------------------------
/utils/skeleton.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | import numpy as np
6 |
7 |
8 | class SkeletonConv(nn.Module):
9 | def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num, stride=1, padding=0,
10 | bias=True, padding_mode='zeros', add_offset=False, in_offset_channel=0):
11 | super(SkeletonConv, self).__init__()
12 |
13 | if in_channels % joint_num != 0 or out_channels % joint_num != 0:
14 | raise Exception('in/out channels must be divisible by joint_num')
15 | self.in_channels_per_joint = in_channels // joint_num
16 | self.out_channels_per_joint = out_channels // joint_num
17 |
18 | if padding_mode == 'zeros': padding_mode = 'constant'
19 |
20 | self.expanded_neighbour_list = []
21 | self.expanded_neighbour_list_offset = []
22 | self.neighbour_list = neighbour_list
23 | self.add_offset = add_offset
24 | self.joint_num = joint_num
25 |
26 | self.stride = stride
27 | self.dilation = 1
28 | self.groups = 1
29 | self.padding = padding
30 | self.padding_mode = padding_mode
31 | self._padding_repeated_twice = (padding, padding)
32 |
33 | for neighbour in neighbour_list:
34 | expanded = []
35 | for k in neighbour:
36 | for i in range(self.in_channels_per_joint):
37 | expanded.append(k * self.in_channels_per_joint + i)
38 | self.expanded_neighbour_list.append(expanded)
39 |
40 | if self.add_offset:
41 | self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels)
42 |
43 | for neighbour in neighbour_list:
44 | expanded = []
45 | for k in neighbour:
46 | for i in range(add_offset):
47 | expanded.append(k * in_offset_channel + i)
48 | self.expanded_neighbour_list_offset.append(expanded)
49 |
50 | self.weight = torch.zeros(out_channels, in_channels, kernel_size)
51 | if bias:
52 | self.bias = torch.zeros(out_channels)
53 | else:
54 | self.register_parameter('bias', None)
55 |
56 | self.mask = torch.zeros_like(self.weight)
57 | for i, neighbour in enumerate(self.expanded_neighbour_list):
58 | self.mask[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...] = 1
59 | self.mask = nn.Parameter(self.mask, requires_grad=False)
60 |
61 | self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \
62 | 'joint_num={}, stride={}, padding={}, bias={})'.format(
63 | in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias
64 | )
65 |
66 | self.reset_parameters()
67 |
68 | def reset_parameters(self):
69 | for i, neighbour in enumerate(self.expanded_neighbour_list):
70 | """ Use temporary variable to avoid assign to copy of slice, which might lead to un expected result """
71 | tmp = torch.zeros_like(self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),
72 | neighbour, ...])
73 | nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
74 | self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),
75 | neighbour, ...] = tmp
76 | if self.bias is not None:
77 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
78 | self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...])
79 | bound = 1 / math.sqrt(fan_in)
80 | tmp = torch.zeros_like(
81 | self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)])
82 | nn.init.uniform_(tmp, -bound, bound)
83 | self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)] = tmp
84 |
85 | self.weight = nn.Parameter(self.weight)
86 | if self.bias is not None:
87 | self.bias = nn.Parameter(self.bias)
88 |
89 | def set_offset(self, offset):
90 | if not self.add_offset: raise Exception('set_offset requires add_offset=True')
91 | self.offset = offset.reshape(offset.shape[0], -1)
92 |
93 | def forward(self, input):
94 | weight_masked = self.weight * self.mask
95 | res = F.conv1d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
96 | weight_masked, self.bias, self.stride,
97 | 0, self.dilation, self.groups)
98 |
99 | if self.add_offset:
100 | offset_res = self.offset_enc(self.offset)
101 | offset_res = offset_res.reshape(offset_res.shape + (1, ))
102 | res += offset_res / 100
103 | return res
104 |
105 | def __repr__(self):
106 | return self.description
107 |
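# --- editor's sketch (illustrative; not part of the original file) ------------
# SkeletonConv is an ordinary 1-D convolution whose weight is multiplied by a
# fixed binary mask, so each joint's output channels only see input channels of
# its skeletal neighbours. Smoke test with 4 channels per joint on 2 joints
# (here every joint neighbours both, so the mask is all ones; real skeletons
# give sparse masks):
neighbours = [[0, 1], [0, 1]]
conv = SkeletonConv(neighbours, in_channels=8, out_channels=8,
                    kernel_size=3, joint_num=2, padding=1)
x = torch.randn(1, 8, 16)                    # (batch, joint_num * channels, time)
assert conv(x).shape == (1, 8, 16)
# ------------------------------------------------------------------------------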
108 |
109 | class SkeletonLinear(nn.Module):
110 | def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False):
111 | super(SkeletonLinear, self).__init__()
112 | self.neighbour_list = neighbour_list
113 | self.in_channels = in_channels
114 | self.out_channels = out_channels
115 | self.in_channels_per_joint = in_channels // len(neighbour_list)
116 | self.out_channels_per_joint = out_channels // len(neighbour_list)
117 | self.extra_dim1 = extra_dim1
118 | self.expanded_neighbour_list = []
119 |
120 | for neighbour in neighbour_list:
121 | expanded = []
122 | for k in neighbour:
123 | for i in range(self.in_channels_per_joint):
124 | expanded.append(k * self.in_channels_per_joint + i)
125 | self.expanded_neighbour_list.append(expanded)
126 |
127 | self.weight = torch.zeros(out_channels, in_channels)
128 | self.mask = torch.zeros(out_channels, in_channels)
129 | self.bias = nn.Parameter(torch.Tensor(out_channels))
130 |
131 | self.reset_parameters()
132 |
133 | def reset_parameters(self):
134 | for i, neighbour in enumerate(self.expanded_neighbour_list):
135 | tmp = torch.zeros_like(
136 | self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour]
137 | )
138 | self.mask[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = 1
139 | nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
140 | self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = tmp
141 |
142 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
143 | bound = 1 / math.sqrt(fan_in)
144 | nn.init.uniform_(self.bias, -bound, bound)
145 |
146 | self.weight = nn.Parameter(self.weight)
147 | self.mask = nn.Parameter(self.mask, requires_grad=False)
148 |
149 | def forward(self, input):
150 | input = input.reshape(input.shape[0], -1)
151 | weight_masked = self.weight * self.mask
152 | res = F.linear(input, weight_masked, self.bias)
153 | if self.extra_dim1: res = res.reshape(res.shape + (1,))
154 | return res
155 |
156 |
157 | class SkeletonPoolJoint(nn.Module):
158 | def __init__(self, topology, pooling_mode, channels_per_joint, last_pool=False):
159 | super(SkeletonPoolJoint, self).__init__()
160 |
161 | if pooling_mode != 'mean':
162 | raise Exception('Unimplemented pooling mode in matrix_implementation')
163 |
164 | self.joint_num = len(topology)
165 | self.parent = topology
166 | self.pooling_list = []
167 | self.pooling_mode = pooling_mode
168 |
169 | self.pooling_map = [-1 for _ in range(len(self.parent))]
170 | self.child = [-1 for _ in range(len(self.parent))]
171 | children_cnt = [0 for _ in range(len(self.parent))]
172 | for x, pa in enumerate(self.parent):
173 | if pa < 0: continue
174 | children_cnt[pa] += 1
175 | self.child[pa] = x
176 | self.pooling_map[0] = 0
177 | for x in range(len(self.parent)):
178 | if children_cnt[x] == 0 or (children_cnt[x] == 1 and children_cnt[self.child[x]] > 1):
179 | while children_cnt[x] <= 1:
180 | pa = self.parent[x]
181 | if last_pool:
182 | seq = [x]
183 | while pa != -1 and children_cnt[pa] == 1:
184 | seq = [pa] + seq
185 | x = pa
186 | pa = self.parent[x]
187 | self.pooling_list.append(seq)
188 | break
189 | else:
190 | if pa != -1 and children_cnt[pa] == 1:
191 | self.pooling_list.append([pa, x])
192 | x = self.parent[pa]
193 | else:
194 | self.pooling_list.append([x, ])
195 | break
196 | elif children_cnt[x] > 1:
197 | self.pooling_list.append([x, ])
198 |
199 | self.description = 'SkeletonPool(in_joint_num={}, out_joint_num={})'.format(
200 | len(topology), len(self.pooling_list),
201 | )
202 |
203 | self.pooling_list.sort(key=lambda x:x[0])
204 | for i, a in enumerate(self.pooling_list):
205 | for j in a:
206 | self.pooling_map[j] = i
207 |
208 | self.output_joint_num = len(self.pooling_list)
209 | self.new_topology = [-1 for _ in range(len(self.pooling_list))]
210 | for i, x in enumerate(self.pooling_list):
211 | if i < 1: continue
212 | self.new_topology[i] = self.pooling_map[self.parent[x[0]]]
213 |
214 | self.weight = torch.zeros(len(self.pooling_list) * channels_per_joint, self.joint_num * channels_per_joint)
215 |
216 | for i, pair in enumerate(self.pooling_list):
217 | for j in pair:
218 | for c in range(channels_per_joint):
219 | self.weight[i * channels_per_joint + c, j * channels_per_joint + c] = 1.0 / len(pair)
220 |
221 | self.weight = nn.Parameter(self.weight, requires_grad=False)
222 |
223 | def forward(self, input: torch.Tensor):
224 | return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)
225 |
226 |
227 | class SkeletonPool(nn.Module):
228 | def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):
229 | super(SkeletonPool, self).__init__()
230 |
231 | if pooling_mode != 'mean':
232 | raise Exception('Unimplemented pooling mode in matrix_implementation')
233 |
234 | self.channels_per_edge = channels_per_edge
235 | self.pooling_mode = pooling_mode
236 | self.edge_num = len(edges) + 1  # +1 for the global-position pseudo-edge
237 | self.seq_list = []
238 | self.pooling_list = []
239 | self.new_edges = []
240 | degree = [0] * 100  # assumes at most 100 joints
241 |
242 | for edge in edges:
243 | degree[edge[0]] += 1
244 | degree[edge[1]] += 1
245 |
246 | def find_seq(j, seq):
247 | nonlocal self, degree, edges
248 |
249 | if degree[j] > 2 and j != 0:
250 | self.seq_list.append(seq)
251 | seq = []
252 |
253 | if degree[j] == 1:
254 | self.seq_list.append(seq)
255 | return
256 |
257 | for idx, edge in enumerate(edges):
258 | if edge[0] == j:
259 | find_seq(edge[1], seq + [idx])
260 |
261 | find_seq(0, [])
262 | for seq in self.seq_list:
263 | if last_pool:
264 | self.pooling_list.append(seq)
265 | continue
266 | if len(seq) % 2 == 1:
267 | self.pooling_list.append([seq[0]])
268 | self.new_edges.append(edges[seq[0]])
269 | seq = seq[1:]
270 | for i in range(0, len(seq), 2):
271 | self.pooling_list.append([seq[i], seq[i + 1]])
272 | self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])
273 |
274 | # add global position
275 | self.pooling_list.append([self.edge_num - 1])
276 |
277 | self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format(
278 | len(edges), len(self.pooling_list)
279 | )
280 |
281 | self.weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)
282 |
283 | for i, pair in enumerate(self.pooling_list):
284 | for j in pair:
285 | for c in range(channels_per_edge):
286 | self.weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair)
287 |
288 | self.weight = nn.Parameter(self.weight, requires_grad=False)
289 |
290 | def forward(self, input: torch.Tensor):
291 | return torch.matmul(self.weight, input)
292 |
293 |
294 | class SkeletonUnpool(nn.Module):
295 | def __init__(self, pooling_list, channels_per_edge):
296 | super(SkeletonUnpool, self).__init__()
297 | self.pooling_list = pooling_list
298 | self.input_joint_num = len(pooling_list)
299 | self.output_joint_num = 0
300 | self.channels_per_edge = channels_per_edge
301 | for t in self.pooling_list:
302 | self.output_joint_num += len(t)
303 |
304 | self.description = 'SkeletonUnpool(in_joint_num={}, out_joint_num={})'.format(
305 | self.input_joint_num, self.output_joint_num,
306 | )
307 |
308 | self.weight = torch.zeros(self.output_joint_num * channels_per_edge, self.input_joint_num * channels_per_edge)
309 |
310 | for i, pair in enumerate(self.pooling_list):
311 | for j in pair:
312 | for c in range(channels_per_edge):
313 | self.weight[j * channels_per_edge + c, i * channels_per_edge + c] = 1
314 |
315 | self.weight = nn.Parameter(self.weight)
316 | self.weight.requires_grad_(False)
317 |
318 | def forward(self, input: torch.Tensor):
319 | return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)
320 |
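# --- editor's sketch (illustrative; not part of the original file) ------------
# SkeletonPool merges runs of degree-2 edges by averaging their channels;
# SkeletonUnpool copies each pooled channel back to every edge it came from, so
# unpool(pool(x)) restores the original channel count. Note the two forwards
# expect different layouts (this pool multiplies channels-first, this unpool
# channels-last), so a flat vector is the safe common case:
edges = [[0, 1], [1, 2], [0, 3], [3, 4]]     # two 2-edge chains from the root
pool = SkeletonPool(edges, pooling_mode='mean', channels_per_edge=3)
unpool = SkeletonUnpool(pool.pooling_list, channels_per_edge=3)
x = torch.randn((len(edges) + 1) * 3)        # +1 slot for the global position
assert unpool(pool(x)).shape == x.shape      # each chain pooled into one edge
# ------------------------------------------------------------------------------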
321 |
322 | def find_neighbor_joint(parents, threshold):
323 | n_joint = len(parents)
324 | dist_mat = np.empty((n_joint, n_joint), dtype=np.int64)  # np.int was removed in NumPy 1.24
325 | dist_mat[:, :] = 100000
326 | for i, p in enumerate(parents):
327 | dist_mat[i, i] = 0
328 | if i != 0:
329 | dist_mat[i, p] = dist_mat[p, i] = 1
330 |
331 | """
332 | Floyd's algorithm
333 | """
334 | for k in range(n_joint):
335 | for i in range(n_joint):
336 | for j in range(n_joint):
337 | dist_mat[i, j] = min(dist_mat[i, j], dist_mat[i, k] + dist_mat[k, j])
338 |
339 | neighbor_list = []
340 | for i in range(n_joint):
341 | neighbor = []
342 | for j in range(n_joint):
343 | if dist_mat[i, j] <= threshold:
344 | neighbor.append(j)
345 | neighbor_list.append(neighbor)
346 |
347 | return neighbor_list
348 |
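# --- editor's sketch (illustrative; not part of the original file) ------------
# find_neighbor_joint runs Floyd-Warshall over the kinematic tree and keeps,
# for each joint, the joints within `threshold` hops:
chain = [-1, 0, 1, 2]                        # a 4-joint chain
assert find_neighbor_joint(chain, threshold=1) == [[0, 1], [0, 1, 2], [1, 2, 3], [2, 3]]
# ------------------------------------------------------------------------------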
--------------------------------------------------------------------------------
/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def batch_mm(matrix, matrix_batch):
6 | """
7 | https://github.com/pytorch/pytorch/issues/14489#issuecomment-607730242
8 | :param matrix: Sparse or dense matrix, size (m, n).
9 | :param matrix_batch: Batched dense matrices, size (b, n, k).
10 | :return: The batched matrix-matrix product, size (m, n) x (b, n, k) = (b, m, k).
11 | """
12 | batch_size = matrix_batch.shape[0]
13 | # Stack the vector batch into columns. (b, n, k) -> (n, b, k) -> (n, b*k)
14 | vectors = matrix_batch.transpose(0, 1).reshape(matrix.shape[1], -1)
15 |
16 | # A matrix-matrix product is a batched matrix-vector product of the columns.
17 | # And then reverse the reshaping. (m, n) x (n, b*k) = (m, b*k) -> (m, b, k) -> (b, m, k)
18 | return matrix.mm(vectors).reshape(matrix.shape[0], batch_size, -1).transpose(1, 0)
19 |
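# --- editor's sketch (illustrative; not part of the original file) ------------
# batch_mm matches broadcast matmul for a dense matrix; its point is that
# `matrix` may also be sparse (see the linked issue above):
M = torch.randn(5, 4)                        # (m, n)
B = torch.randn(7, 4, 3)                     # (b, n, k)
assert torch.allclose(batch_mm(M, B), M @ B, atol=1e-5)
# ------------------------------------------------------------------------------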
20 |
21 | def aa2quat(rots, form='wxyz', unified_orient=True):
22 | """
23 | Convert angle-axis representation to a wxyz quaternion, mapped to the hemisphere w >= 0
24 | @param rots: angle-axis rotations, (*, 3)
25 | @param form: quaternion format, either 'wxyz' or 'xyzw'
26 | @param unified_orient: Use a unified orientation for the quaternion (unit quaternions double-cover SO(3), so q and -q encode the same rotation)
27 | :return:
28 | """
29 | angles = rots.norm(dim=-1, keepdim=True)
30 | norm = angles.clone()
31 | norm[norm < 1e-8] = 1
32 | axis = rots / norm
33 | quats = torch.empty(rots.shape[:-1] + (4,), device=rots.device, dtype=rots.dtype)
34 | angles = angles * 0.5
35 | if form == 'wxyz':
36 | quats[..., 0] = torch.cos(angles.squeeze(-1))
37 | quats[..., 1:] = torch.sin(angles) * axis
38 | elif form == 'xyzw':
39 | quats[..., :3] = torch.sin(angles) * axis
40 | quats[..., 3] = torch.cos(angles.squeeze(-1))
41 |
42 | if unified_orient:
43 | idx = quats[..., 0] < 0
44 | quats[idx, :] *= -1
45 |
46 | return quats
47 |
48 |
49 | def quat2aa(quats):
50 | """
51 | Convert wxyz quaternions to angle-axis representation
52 | :param quats:
53 | :return:
54 | """
55 | _cos = quats[..., 0]
56 | xyz = quats[..., 1:]
57 | _sin = xyz.norm(dim=-1)
58 | norm = _sin.clone()
59 | norm[norm < 1e-7] = 1
60 | axis = xyz / norm.unsqueeze(-1)
61 | angle = torch.atan2(_sin, _cos) * 2
62 | return axis * angle.unsqueeze(-1)
63 |
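# --- editor's sketch (illustrative; not part of the original file) ------------
# Round trip: aa2quat halves the angle into (w, xyz) and quat2aa doubles it
# back, so angle-axis vectors inside the ball of radius pi survive unchanged:
aa = torch.randn(10, 3).clamp(-1, 1)         # angles stay well below pi
assert torch.allclose(quat2aa(aa2quat(aa)), aa, atol=1e-5)
# ------------------------------------------------------------------------------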
64 |
65 | def quat2mat(quats: torch.Tensor):
66 | """
67 | Convert (w, x, y, z) quaternions to 3x3 rotation matrix
68 | :param quats: quaternions of shape (..., 4)
69 | :return: rotation matrices of shape (..., 3, 3)
70 | """
71 | qw = quats[..., 0]
72 | qx = quats[..., 1]
73 | qy = quats[..., 2]
74 | qz = quats[..., 3]
75 |
76 | x2 = qx + qx
77 | y2 = qy + qy
78 | z2 = qz + qz
79 | xx = qx * x2
80 | yy = qy * y2
81 | wx = qw * x2
82 | xy = qx * y2
83 | yz = qy * z2
84 | wy = qw * y2
85 | xz = qx * z2
86 | zz = qz * z2
87 | wz = qw * z2
88 |
89 | m = torch.empty(quats.shape[:-1] + (3, 3), device=quats.device, dtype=quats.dtype)
90 | m[..., 0, 0] = 1.0 - (yy + zz)
91 | m[..., 0, 1] = xy - wz
92 | m[..., 0, 2] = xz + wy
93 | m[..., 1, 0] = xy + wz
94 | m[..., 1, 1] = 1.0 - (xx + zz)
95 | m[..., 1, 2] = yz - wx
96 | m[..., 2, 0] = xz - wy
97 | m[..., 2, 1] = yz + wx
98 | m[..., 2, 2] = 1.0 - (xx + yy)
99 |
100 | return m
101 |
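# --- editor's sketch (illustrative; not part of the original file) ------------
# Sanity check: the identity wxyz quaternion maps to the identity matrix.
assert torch.allclose(quat2mat(torch.tensor([1., 0., 0., 0.])), torch.eye(3))
# ------------------------------------------------------------------------------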
102 |
103 | def quat2euler(q, order='xyz', degrees=True):
104 | """
105 | Convert (w, x, y, z) quaternions to xyz euler angles. This is used for bvh output.
106 | """
107 | q0 = q[..., 0]
108 | q1 = q[..., 1]
109 | q2 = q[..., 2]
110 | q3 = q[..., 3]
111 | es = torch.empty(q0.shape + (3,), device=q.device, dtype=q.dtype)
112 |
113 | if order == 'xyz':
114 | es[..., 2] = torch.atan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
115 | es[..., 1] = torch.asin((2 * (q1 * q3 + q0 * q2)).clip(-1, 1))
116 | es[..., 0] = torch.atan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
117 | else:
118 | raise NotImplementedError('Cannot convert to ordering %s' % order)
119 |
120 | if degrees:
121 | es = es * 180 / np.pi
122 |
123 | return es
124 |
125 |
126 | def euler2mat(rots, order='xyz'):
127 | axis = {'x': torch.tensor((1, 0, 0), device=rots.device),
128 | 'y': torch.tensor((0, 1, 0), device=rots.device),
129 | 'z': torch.tensor((0, 0, 1), device=rots.device)}
130 |
131 | rots = rots / 180 * np.pi
132 | mats = []
133 | for i in range(3):
134 | aa = axis[order[i]] * rots[..., i].unsqueeze(-1)
135 | mats.append(aa2mat(aa))
136 | return mats[0] @ (mats[1] @ mats[2])
137 |
138 |
139 | def aa2mat(rots):
140 | """
141 | Convert angle-axis representation to rotation matrix
142 | :param rots: angle-axis representation
143 | :return:
144 | """
145 | quat = aa2quat(rots)
146 | mat = quat2mat(quat)
147 | return mat
148 |
149 |
150 | def mat2quat(R) -> torch.Tensor:
151 | '''
152 | https://github.com/duolu/pyrotation/blob/master/pyrotation/pyrotation.py
153 | Convert a rotation matrix to a unit quaternion.
154 |
155 | This uses Shepperd's method for numerical stability.
156 | '''
157 |
158 | # The rotation matrix must be orthonormal
159 |
160 | w2 = (1 + R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2])
161 | x2 = (1 + R[..., 0, 0] - R[..., 1, 1] - R[..., 2, 2])
162 | y2 = (1 - R[..., 0, 0] + R[..., 1, 1] - R[..., 2, 2])
163 | z2 = (1 - R[..., 0, 0] - R[..., 1, 1] + R[..., 2, 2])
164 |
165 | yz = (R[..., 1, 2] + R[..., 2, 1])
166 | xz = (R[..., 2, 0] + R[..., 0, 2])
167 | xy = (R[..., 0, 1] + R[..., 1, 0])
168 |
169 | wx = (R[..., 2, 1] - R[..., 1, 2])
170 | wy = (R[..., 0, 2] - R[..., 2, 0])
171 | wz = (R[..., 1, 0] - R[..., 0, 1])
172 |
173 | w = torch.empty_like(x2)
174 | x = torch.empty_like(x2)
175 | y = torch.empty_like(x2)
176 | z = torch.empty_like(x2)
177 |
178 | flagA = (R[..., 2, 2] < 0) * (R[..., 0, 0] > R[..., 1, 1])
179 | flagB = (R[..., 2, 2] < 0) * (R[..., 0, 0] <= R[..., 1, 1])
180 | flagC = (R[..., 2, 2] >= 0) * (R[..., 0, 0] < -R[..., 1, 1])
181 | flagD = (R[..., 2, 2] >= 0) * (R[..., 0, 0] >= -R[..., 1, 1])
182 |
183 | x[flagA] = torch.sqrt(x2[flagA])
184 | w[flagA] = wx[flagA] / x[flagA]
185 | y[flagA] = xy[flagA] / x[flagA]
186 | z[flagA] = xz[flagA] / x[flagA]
187 |
188 | y[flagB] = torch.sqrt(y2[flagB])
189 | w[flagB] = wy[flagB] / y[flagB]
190 | x[flagB] = xy[flagB] / y[flagB]
191 | z[flagB] = yz[flagB] / y[flagB]
192 |
193 | z[flagC] = torch.sqrt(z2[flagC])
194 | w[flagC] = wz[flagC] / z[flagC]
195 | x[flagC] = xz[flagC] / z[flagC]
196 | y[flagC] = yz[flagC] / z[flagC]
197 |
198 | w[flagD] = torch.sqrt(w2[flagD])
199 | x[flagD] = wx[flagD] / w[flagD]
200 | y[flagD] = wy[flagD] / w[flagD]
201 | z[flagD] = wz[flagD] / w[flagD]
202 |
203 | # if R[..., 2, 2] < 0:
204 | #
205 | # if R[..., 0, 0] > R[..., 1, 1]:
206 | #
207 | # x = torch.sqrt(x2)
208 | # w = wx / x
209 | # y = xy / x
210 | # z = xz / x
211 | #
212 | # else:
213 | #
214 | # y = torch.sqrt(y2)
215 | # w = wy / y
216 | # x = xy / y
217 | # z = yz / y
218 | #
219 | # else:
220 | #
221 | # if R[..., 0, 0] < -R[..., 1, 1]:
222 | #
223 | # z = torch.sqrt(z2)
224 | # w = wz / z
225 | # x = xz / z
226 | # y = yz / z
227 | #
228 | # else:
229 | #
230 | # w = torch.sqrt(w2)
231 | # x = wx / w
232 | # y = wy / w
233 | # z = wz / w
234 |
235 | res = [w, x, y, z]
236 | res = [v.unsqueeze(-1) for v in res]
237 |
238 | return torch.cat(res, dim=-1) / 2
239 |
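# --- editor's sketch (illustrative; not part of the original file) ------------
# mat2quat only fixes the quaternion sign per Shepperd's case split, so test the
# round trip on rotation matrices, which are insensitive to the q vs -q choice:
q = aa2quat(torch.randn(10, 3).clamp(-1, 1))
R = quat2mat(q)
assert torch.allclose(quat2mat(mat2quat(R)), R, atol=1e-5)
# ------------------------------------------------------------------------------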
240 |
241 | def quat2repr6d(quat):
242 | mat = quat2mat(quat)
243 | res = mat[..., :2, :]
244 | res = res.reshape(res.shape[:-2] + (6, ))
245 | return res
246 |
247 |
248 | def repr6d2mat(repr):
249 | x = repr[..., :3]
250 | y = repr[..., 3:]
251 | x = x / x.norm(dim=-1, keepdim=True)
252 | z = torch.cross(x, y, dim=-1)  # pass dim explicitly; the default picks the first dim of size 3
253 | z = z / z.norm(dim=-1, keepdim=True)
254 | y = torch.cross(z, x, dim=-1)
255 | res = [x, y, z]
256 | res = [v.unsqueeze(-2) for v in res]
257 | mat = torch.cat(res, dim=-2)
258 | return mat
259 |
260 |
261 | def repr6d2quat(repr) -> torch.Tensor:
262 | x = repr[..., :3]
263 | y = repr[..., 3:]
264 | x = x / x.norm(dim=-1, keepdim=True)
265 | z = torch.cross(x, y, dim=-1)  # pass dim explicitly; the default picks the first dim of size 3
266 | z = z / z.norm(dim=-1, keepdim=True)
267 | y = torch.cross(z, x, dim=-1)
268 | res = [x, y, z]
269 | res = [v.unsqueeze(-2) for v in res]
270 | mat = torch.cat(res, dim=-2)
271 | return mat2quat(mat)
272 |
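# --- editor's sketch (illustrative; not part of the original file) ------------
# The 6-D representation stores the first two rows of the rotation matrix;
# repr6d2mat rebuilds an orthonormal frame from them via two cross products
# (Gram-Schmidt), so the round trip recovers the matrix up to float noise:
q = aa2quat(torch.randn(10, 3).clamp(-1, 1))
assert torch.allclose(repr6d2mat(quat2repr6d(q)), quat2mat(q), atol=1e-5)
# ------------------------------------------------------------------------------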
273 |
274 | def inv_affine(mat):
275 | """
276 | Calculate the inverse of any affine transformation
277 | """
278 | affine = torch.zeros(mat.shape[:2] + (1, 4), device=mat.device)  # match the input device
279 | affine[..., 3] = 1
280 | vert_mat = torch.cat((mat, affine), dim=2)
281 | vert_mat_inv = torch.inverse(vert_mat)
282 | return vert_mat_inv[..., :3, :]
283 |
284 |
285 | def inv_rigid_affine(mat):
286 | """
287 | Calculate the inverse of a rigid affine transformation
288 | """
289 | res = mat.clone()
290 | res[..., :3] = mat[..., :3].transpose(-2, -1)
291 | res[..., 3] = -torch.matmul(res[..., :3], mat[..., 3].unsqueeze(-1)).squeeze(-1)
292 | return res
293 |
294 |
295 | def generate_pose(batch_size, device, uniform=False, factor=1, root_rot=False, n_bone=None, ee=None):
296 | if n_bone is None: n_bone = 24
297 | if ee is not None:
298 | if root_rot:
299 | ee.append(0)
300 | n_bone_ = n_bone
301 | n_bone = len(ee)
302 | axis = torch.randn((batch_size, n_bone, 3), device=device)
303 | axis /= axis.norm(dim=-1, keepdim=True)
304 | if uniform:
305 | angle = torch.rand((batch_size, n_bone, 1), device=device) * np.pi
306 | else:
307 | angle = torch.randn((batch_size, n_bone, 1), device=device) * np.pi / 6 * factor
308 | angle = angle.clamp(-np.pi, np.pi)  # clamp is out-of-place; keep the result
309 | poses = axis * angle
310 | if ee is not None:
311 | res = torch.zeros((batch_size, n_bone_, 3), device=device)
312 | for i, id in enumerate(ee):
313 | res[:, id] = poses[:, i]
314 | poses = res
315 | poses = poses.reshape(batch_size, -1)
316 | if not root_rot:
317 | poses[..., :3] = 0
318 | return poses
319 |
320 |
321 | def slerp(l, r, t, unit=True):
322 | """
323 | :param l: shape = (*, n)
324 | :param r: shape = (*, n)
325 | :param t: shape = (*)
326 | :param unit: If l and r are unit vectors
327 | :return:
328 | """
329 | eps = 1e-8
330 | if not unit:
331 | l_n = l / torch.norm(l, dim=-1, keepdim=True)
332 | r_n = r / torch.norm(r, dim=-1, keepdim=True)
333 | else:
334 | l_n = l
335 | r_n = r
336 | omega = torch.acos((l_n * r_n).sum(dim=-1).clamp(-1, 1))
337 | dom = torch.sin(omega)
338 |
339 | flag = dom < eps
340 |
341 | res = torch.empty_like(l_n)
342 | t_t = t[flag].unsqueeze(-1)
343 | res[flag] = (1 - t_t) * l_n[flag] + t_t * r_n[flag]
344 |
345 | flag = ~ flag
346 |
347 | t_t = t[flag]
348 | d_t = dom[flag]
349 | va = torch.sin((1 - t_t) * omega[flag]) / d_t
350 | vb = torch.sin(t_t * omega[flag]) / d_t
351 | res[flag] = (va.unsqueeze(-1) * l_n[flag] + vb.unsqueeze(-1) * r_n[flag])
352 | return res
353 |
354 |
355 | def slerp_quat(l, r, t):
356 | """
357 | slerp for unit quaternions
358 | :param l: (*, 4) unit quaternion
359 | :param r: (*, 4) unit quaternion
360 | :param t: (*) scalar between 0 and 1
361 | """
362 | t = t.expand(l.shape[:-1])
363 | flag = (l * r).sum(dim=-1) >= 0
364 | res = torch.empty_like(l)
365 | res[flag] = slerp(l[flag], r[flag], t[flag])
366 | flag = ~ flag
367 | res[flag] = slerp(-l[flag], r[flag], t[flag])
368 | return res
369 |
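# --- editor's sketch (illustrative; not part of the original file) ------------
# slerp_quat flips one endpoint when the dot product is negative, so it always
# interpolates along the shorter arc; at t=0 it reproduces the left rotation
# (possibly as -q, the same rotation), and the result stays unit-norm:
l = aa2quat(torch.randn(10, 3).clamp(-1, 1))
r = aa2quat(torch.randn(10, 3).clamp(-1, 1))
assert torch.allclose(quat2mat(slerp_quat(l, r, torch.zeros(10))), quat2mat(l), atol=1e-5)
mid = slerp_quat(l, r, torch.full((10,), 0.5))
assert torch.allclose(mid.norm(dim=-1), torch.ones(10), atol=1e-5)
# ------------------------------------------------------------------------------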
370 |
371 | # def slerp_6d(l, r, t):
372 | # l_q = repr6d2quat(l)
373 | # r_q = repr6d2quat(r)
374 | # res_q = slerp_quat(l_q, r_q, t)
375 | # return quat2repr6d(res_q)
376 |
377 |
378 | def interpolate_6d(input, size):
379 | """
380 | :param input: (batch_size, n_channels, length)
381 | :param size: required output size for temporal axis
382 | :return:
383 | """
384 | batch = input.shape[0]
385 | length = input.shape[-1]
386 | input = input.reshape((batch, -1, 6, length))
387 | input = input.permute(0, 1, 3, 2) # (batch_size, n_joint, length, 6)
388 | input_q = repr6d2quat(input)
389 | idx = torch.arange(size, device=input_q.device, dtype=torch.float) / size * (length - 1)
390 | idx_l = torch.floor(idx)
391 | t = idx - idx_l
392 | idx_l = idx_l.long()
393 | idx_r = idx_l + 1
394 | t = t.reshape((1, 1, -1))
395 | res_q = slerp_quat(input_q[..., idx_l, :], input_q[..., idx_r, :], t)
396 | res = quat2repr6d(res_q) # shape = (batch_size, n_joint, t, 6)
397 | res = res.permute(0, 1, 3, 2)
398 | res = res.reshape((batch, -1, size))
399 | return res
400 |
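# --- editor's sketch (illustrative; not part of the original file) ------------
# interpolate_6d upsamples a packed 6-D rotation signal along time with
# quaternion slerp rather than per-channel linear interpolation. Shapes below
# are made up: 4 joints, 8 frames, upsampled to 16:
q = aa2quat(torch.randn(1, 4, 8, 3).clamp(-1, 1))         # (batch, joint, time, 4)
x = quat2repr6d(q).permute(0, 1, 3, 2).reshape(1, 24, 8)  # (batch, joint*6, time)
assert interpolate_6d(x, size=16).shape == (1, 24, 16)
# ------------------------------------------------------------------------------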
--------------------------------------------------------------------------------