├── utils
│   ├── __init__.py
│   └── geometry.py
├── sample
│   ├── motion.pkl
│   ├── phrase.pkl
│   ├── meta_info.pkl
│   └── read_sample.py
├── KP
│   ├── lap.txt
│   ├── pp.txt
│   ├── lop.txt
│   ├── pdp.txt
│   └── prpp.txt
├── README.md
├── LICENSE
└── kp.py
/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sample/motion.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Foruck/Kinematic-Phrases/HEAD/sample/motion.pkl -------------------------------------------------------------------------------- /sample/phrase.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Foruck/Kinematic-Phrases/HEAD/sample/phrase.pkl -------------------------------------------------------------------------------- /sample/meta_info.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Foruck/Kinematic-Phrases/HEAD/sample/meta_info.pkl -------------------------------------------------------------------------------- /KP/lap.txt: -------------------------------------------------------------------------------- 1 | left arm 2 | left leg 3 | left upper arm 4 | left thigh 5 | right arm 6 | right leg 7 | right upper arm 8 | right thigh 9 | -------------------------------------------------------------------------------- /sample/read_sample.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | import numpy as np 3 | meta = joblib.load('meta_info.pkl') 4 | phrase = joblib.load('phrase.pkl') 5 | motion = joblib.load('motion.pkl') 6 | print(phrase.shape) 7 | for key in meta['META2IDX'].keys(): 8 | print(key, len(meta['META2IDX'][key])) 9 | for i, info in enumerate(meta['IDX2META']): 10 | if i < 20: 11 | print(info, np.unique(phrase[:, i])) 12 | print(motion.keys()) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bridging the Gap between Human Motion and Action Semantics via Kinematic Phrases (ECCV 2024) 2 | 3 | ## KP Extraction 4 | 5 | 1. Download SMPL(X) models and put them in ``models``. 6 | 2. Run ``python kp.py`` to extract the KP sequence corresponding to the sample motion sequence. 7 | 8 | ## TODOs 9 | 10 | - [x] Kinematic Phrases list in `KP`. 11 | - [x] Kinematic Prompt Generation (KPG) prompts in `kpg.txt`. 12 | - [x] A data sample is available in `sample` (see the decoding sketch below). 13 | - [x] Code for KP extraction. 14 | - [ ] Code for KPG benchmarking. 15 | - [ ] Code for Kinematic Phrases Base aggregation.
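## Reading the sample

A minimal decoding sketch (illustrative, mirroring `sample/read_sample.py`), assuming the pickles under `sample` and the `(type, parts)` layout of `meta['IDX2META']` used in `kp.py`; per-frame phrase values are the signs produced by `kp.py` (+1, 0, -1):

```python
# Run from the repository root; follows the (type, parts) convention used
# when building phrase indices in kp.py.
import joblib
import numpy as np

meta = joblib.load('sample/meta_info.pkl')
phrase = joblib.load('sample/phrase.pkl')   # (num_frames, num_phrases)

print('phrase matrix:', phrase.shape)
for col, info in enumerate(meta['IDX2META']):
    if col >= 5:
        break
    kp_type, parts = info[0], info[1]       # e.g. a 'pp' entry pairs a joint with an axis
    values, counts = np.unique(phrase[:, col], return_counts=True)
    print(kp_type, parts, dict(zip(values.tolist(), counts.tolist())))
```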
16 | -------------------------------------------------------------------------------- /KP/pp.txt: -------------------------------------------------------------------------------- 1 | left wrist, rl 2 | left wrist, fb 3 | left wrist, ud 4 | right wrist, rl 5 | right wrist, fb 6 | right wrist, ud 7 | left elbow, rl 8 | left elbow, fb 9 | left elbow, ud 10 | right elbow, rl 11 | right elbow, fb 12 | right elbow, ud 13 | left knee, rl 14 | left knee, fb 15 | left knee, ud 16 | right knee, rl 17 | right knee, fb 18 | right knee, ud 19 | left shoulder, fb 20 | left shoulder, ud 21 | right shoulder, fb 22 | right shoulder, ud 23 | head, rl 24 | head, fb 25 | head, ud 26 | left foot, rl 27 | left foot, fb 28 | left foot, ud 29 | right foot, rl 30 | right foot, fb 31 | right foot, ud 32 | neck, rl 33 | neck, fb 34 | neck, ud 35 | -------------------------------------------------------------------------------- /KP/lop.txt: -------------------------------------------------------------------------------- 1 | left wrist, left elbow, rl 2 | left wrist, left elbow, fb 3 | left wrist, left elbow, ud 4 | right wrist, right elbow, rl 5 | right wrist, right elbow, fb 6 | right wrist, right elbow, ud 7 | left elbow, left shoulder, rl 8 | left elbow, left shoulder, fb 9 | left elbow, left shoulder, ud 10 | right elbow, right shoulder, rl 11 | right elbow, right shoulder, fb 12 | right elbow, right shoulder, ud 13 | left knee, left foot, rl 14 | left knee, left foot, fb 15 | left knee, left hip, rl 16 | left knee, left hip, ud 17 | right knee, right foot, rl 18 | right knee, right foot, fb 19 | right knee, right hip, rl 20 | right knee, right hip, ud 21 | head, neck, rl 22 | head, neck, fb 23 | head, neck, ud 24 | neck, pelvis, fb -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Xinpeng Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /KP/pdp.txt: -------------------------------------------------------------------------------- 1 | left wrist, right wrist 2 | left wrist, right elbow 3 | left wrist, left knee 4 | left wrist, right knee 5 | left wrist, left shoulder 6 | left wrist, right shoulder 7 | left wrist, head 8 | left wrist, left foot 9 | left wrist, right foot 10 | left wrist, neck 11 | left wrist, left hip 12 | left wrist, right hip 13 | left wrist, pelvis 14 | right wrist, left elbow 15 | right wrist, left knee 16 | right wrist, right knee 17 | right wrist, left shoulder 18 | right wrist, right shoulder 19 | right wrist, head 20 | right wrist, left foot 21 | right wrist, right foot 22 | right wrist, neck 23 | right wrist, left hip 24 | right wrist, right hip 25 | right wrist, pelvis 26 | left elbow, right elbow 27 | left elbow, left knee 28 | left elbow, right knee 29 | left elbow, right shoulder 30 | left elbow, head 31 | left elbow, left foot 32 | left elbow, right foot 33 | left elbow, neck 34 | left elbow, left hip 35 | left elbow, right hip 36 | left elbow, pelvis 37 | right elbow, left knee 38 | right elbow, right knee 39 | right elbow, left shoulder 40 | right elbow, head 41 | right elbow, left foot 42 | right elbow, right foot 43 | right elbow, neck 44 | right elbow, left hip 45 | right elbow, right hip 46 | right elbow, pelvis 47 | left knee, right knee 48 | left knee, left shoulder 49 | left knee, right shoulder 50 | left knee, head 51 | left knee, right foot 52 | left knee, neck 53 | left knee, right hip 54 | left knee, pelvis 55 | right knee, left shoulder 56 | right knee, right shoulder 57 | right knee, head 58 | right knee, left foot 59 | right knee, neck 60 | right knee, left hip 61 | right knee, pelvis 62 | left shoulder, right shoulder 63 | left shoulder, head 64 | left shoulder, left foot 65 | left shoulder, right foot 66 | right shoulder, head 67 | right shoulder, left foot 68 | right shoulder, right foot 69 | head, left foot 70 | head, right foot 71 | head, left hip 72 | head, right hip 73 | left foot, right foot 74 | left foot, neck 75 | left foot, left hip 76 | left foot, right hip 77 | left foot, pelvis 78 | right foot, neck 79 | right foot, left hip 80 | right foot, right hip 81 | right foot, pelvis 82 | -------------------------------------------------------------------------------- /KP/prpp.txt: -------------------------------------------------------------------------------- 1 | left wrist, right wrist, rl 2 | left wrist, right wrist, fb 3 | left wrist, right wrist, ud 4 | left wrist, right elbow, rl 5 | left wrist, right elbow, fb 6 | left wrist, right elbow, ud 7 | left wrist, left knee, rl 8 | left wrist, left knee, fb 9 | left wrist, left knee, ud 10 | left wrist, right knee, rl 11 | left wrist, right knee, fb 12 | left wrist, right knee, ud 13 | left wrist, left shoulder, rl 14 | left wrist, left shoulder, fb 15 | left wrist, left shoulder, ud 16 | left wrist, right shoulder, rl 17 | left wrist, right shoulder, fb 18 | left wrist, right shoulder, ud 19 | left wrist, head, rl 20 | left wrist, head, fb 21 | left wrist, head, ud 22 | left wrist, left foot, rl 23 | left wrist, left foot, fb 24 | left wrist, right foot, rl 25 | left wrist, right foot, fb 26 | left wrist, neck, fb 27 | left wrist, neck, ud 28 | left wrist, left hip, rl 29 | left wrist, left hip, ud 30 | left wrist, right hip, rl 31 | left wrist, right hip, ud 32 | left wrist, pelvis, rl 33 | left wrist, pelvis, fb 34 | left wrist, 
pelvis, ud 35 | right wrist, left elbow, rl 36 | right wrist, left elbow, fb 37 | right wrist, left elbow, ud 38 | right wrist, left knee, rl 39 | right wrist, left knee, fb 40 | right wrist, left knee, ud 41 | right wrist, right knee, rl 42 | right wrist, right knee, fb 43 | right wrist, right knee, ud 44 | right wrist, left shoulder, rl 45 | right wrist, left shoulder, fb 46 | right wrist, left shoulder, ud 47 | right wrist, right shoulder, rl 48 | right wrist, right shoulder, fb 49 | right wrist, right shoulder, ud 50 | right wrist, head, rl 51 | right wrist, head, fb 52 | right wrist, head, ud 53 | right wrist, left foot, rl 54 | right wrist, left foot, fb 55 | right wrist, right foot, rl 56 | right wrist, right foot, fb 57 | right wrist, neck, fb 58 | right wrist, neck, ud 59 | right wrist, left hip, rl 60 | right wrist, left hip, ud 61 | right wrist, right hip, rl 62 | right wrist, right hip, ud 63 | right wrist, pelvis, rl 64 | right wrist, pelvis, fb 65 | right wrist, pelvis, ud 66 | left elbow, right elbow, rl 67 | left elbow, right elbow, fb 68 | left elbow, right elbow, ud 69 | left elbow, left knee, rl 70 | left elbow, left knee, fb 71 | left elbow, left knee, ud 72 | left elbow, right knee, rl 73 | left elbow, right knee, fb 74 | left elbow, right knee, ud 75 | left elbow, right shoulder, fb 76 | left elbow, right shoulder, ud 77 | left elbow, head, rl 78 | left elbow, head, fb 79 | left elbow, head, ud 80 | left elbow, left foot, rl 81 | left elbow, left foot, fb 82 | left elbow, right foot, rl 83 | left elbow, right foot, fb 84 | left elbow, neck, fb 85 | left elbow, neck, ud 86 | left elbow, left hip, rl 87 | left elbow, left hip, ud 88 | left elbow, right hip, rl 89 | left elbow, right hip, ud 90 | left elbow, pelvis, rl 91 | left elbow, pelvis, fb 92 | left elbow, pelvis, ud 93 | left elbow, left eye, rl 94 | left elbow, left eye, fb 95 | left elbow, right eye, rl 96 | left elbow, right eye, fb 97 | right elbow, left knee, rl 98 | right elbow, left knee, fb 99 | right elbow, left knee, ud 100 | right elbow, right knee, rl 101 | right elbow, right knee, fb 102 | right elbow, right knee, ud 103 | right elbow, left shoulder, fb 104 | right elbow, left shoulder, ud 105 | right elbow, head, rl 106 | right elbow, head, fb 107 | right elbow, head, ud 108 | right elbow, left foot, rl 109 | right elbow, left foot, fb 110 | right elbow, right foot, rl 111 | right elbow, right foot, fb 112 | right elbow, neck, fb 113 | right elbow, neck, ud 114 | right elbow, left hip, rl 115 | right elbow, left hip, ud 116 | right elbow, right hip, rl 117 | right elbow, right hip, ud 118 | right elbow, pelvis, rl 119 | right elbow, pelvis, fb 120 | right elbow, pelvis, ud 121 | right elbow, left eye, rl 122 | right elbow, left eye, fb 123 | right elbow, right eye, rl 124 | right elbow, right eye, fb 125 | left knee, right knee, rl 126 | left knee, right knee, fb 127 | left knee, right knee, ud 128 | left knee, left shoulder, rl 129 | left knee, left shoulder, fb 130 | left knee, left shoulder, ud 131 | left knee, right shoulder, fb 132 | left knee, right shoulder, ud 133 | left knee, head, rl 134 | left knee, head, fb 135 | left knee, head, ud 136 | left knee, right foot, rl 137 | left knee, right foot, fb 138 | left knee, right foot, ud 139 | left knee, left foot, ud 140 | left knee, neck, fb 141 | left knee, neck, ud 142 | left knee, right hip, ud 143 | left knee, pelvis, rl 144 | left knee, pelvis, fb 145 | left knee, pelvis, ud 146 | left knee, left eye, rl 147 | left knee, left eye, fb 148 | 
left knee, right eye, rl 149 | left knee, right eye, fb 150 | right knee, left shoulder, fb 151 | right knee, left shoulder, ud 152 | right knee, right shoulder, rl 153 | right knee, right shoulder, fb 154 | right knee, right shoulder, ud 155 | right knee, head, rl 156 | right knee, head, fb 157 | right knee, head, ud 158 | right knee, left foot, rl 159 | right knee, left foot, fb 160 | right knee, left foot, ud 161 | right knee, right foot, ud 162 | right knee, neck, fb 163 | right knee, neck, ud 164 | right knee, left hip, ud 165 | right knee, pelvis, rl 166 | right knee, pelvis, fb 167 | right knee, pelvis, ud 168 | right knee, left eye, rl 169 | right knee, left eye, fb 170 | right knee, right eye, rl 171 | right knee, right eye, fb 172 | left shoulder, right shoulder, fb 173 | left shoulder, right shoulder, ud 174 | left shoulder, head, fb 175 | left shoulder, head, ud 176 | left shoulder, left foot, rl 177 | left shoulder, left foot, fb 178 | left shoulder, right foot, rl 179 | left shoulder, right foot, fb 180 | right shoulder, head, fb 181 | right shoulder, head, ud 182 | right shoulder, left foot, rl 183 | right shoulder, left foot, fb 184 | right shoulder, right foot, rl 185 | right shoulder, right foot, fb 186 | head, left foot, rl 187 | head, left foot, fb 188 | head, left foot, ud 189 | head, right foot, rl 190 | head, right foot, fb 191 | head, right foot, ud 192 | head, left hip, rl 193 | head, left hip, fb 194 | head, left hip, ud 195 | head, right hip, rl 196 | head, right hip, fb 197 | head, right hip, ud 198 | left foot, right foot, rl 199 | left foot, right foot, fb 200 | left foot, right foot, ud 201 | left foot, left hip, rl 202 | left foot, right hip, rl 203 | left foot, pelvis, rl 204 | left foot, pelvis, fb 205 | left foot, pelvis, ud 206 | left foot, left eye, rl 207 | left foot, left eye, fb 208 | left foot, right eye, rl 209 | left foot, right eye, fb 210 | right foot, left hip, rl 211 | right foot, right hip, rl 212 | right foot, pelvis, rl 213 | right foot, pelvis, fb 214 | right foot, pelvis, ud 215 | right foot, left eye, rl 216 | right foot, left eye, fb 217 | right foot, right eye, rl 218 | right foot, right eye, fb 219 | left eye, right eye, fb 220 | left eye, right eye, ud 221 | left hip, right hip, fb 222 | left hip, right hip, ud 223 | left hip, pelvis, fb 224 | left hip, pelvis, ud 225 | right hip, pelvis, fb 226 | right hip, pelvis, ud 227 | neck, left hip, fb 228 | neck, left hip, ud 229 | neck, right hip, fb 230 | neck, right hip, ud 231 | neck, pelvis, ud 232 | head, pelvis, rl 233 | head, pelvis, fb 234 | head, pelvis, ud 235 | left shoulder, left hip, ud 236 | left shoulder, right hip, ud 237 | left shoulder, pelvis, fb 238 | left shoulder, pelvis, ud 239 | right shoulder, left hip, ud 240 | right shoulder, right hip, ud 241 | right shoulder, pelvis, fb 242 | right shoulder, pelvis, ud 243 | -------------------------------------------------------------------------------- /kp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import joblib 4 | import torch 5 | import torch.nn as nn 6 | import smplx 7 | from utils.geometry import canonicalize_smplx 8 | 9 | class JOI2KP(nn.Module): 10 | 11 | def __init__(self, input_type='smplx'): 12 | super().__init__() 13 | self.input_type = input_type.lower() 14 | meta = joblib.load('sample/meta_info.pkl') 15 | 16 | if self.input_type in ['smplx']: 17 | JOINT_NAMES = [ 18 | "pelvis", 19 | "left_hip", 20 | "right_hip", 21 | "spine1", 22 | 
"left_knee", 23 | "right_knee", 24 | "spine2", 25 | "left_ankle", 26 | "right_ankle", 27 | "spine3", 28 | "left_foot", 29 | "right_foot", 30 | "neck", 31 | "left_collar", 32 | "right_collar", 33 | "head", 34 | "left_shoulder", 35 | "right_shoulder", 36 | "left_elbow", 37 | "right_elbow", 38 | "left_wrist", 39 | "right_wrist", 40 | "jaw", 41 | "left_eye", 42 | "right_eye", 43 | ] 44 | joint_idx = {' '.join(k.split('_')): v for v, k in enumerate(JOINT_NAMES)} 45 | else: 46 | raise NotImplementedError 47 | 48 | axis_idx = { 49 | 'ud': 0, 50 | 'rl': 1, 51 | 'fb': 2, 52 | 'none': 3, 53 | 'left upper arm': 4, 54 | 'left thigh': 5, 55 | 'right upper arm': 6, 56 | 'right thigh': 7, 57 | } 58 | limbs = { 59 | 'left lower arm': ('left wrist', 'left elbow'), 60 | 'left upper arm': ('left shoulder', 'left elbow'), 61 | 'left shank': ('left foot', 'left knee'), 62 | 'left thigh': ('left hip', 'left knee'), 63 | 'left body': ('left shoulder', 'left hip'), 64 | 'right lower arm': ('right wrist', 'right elbow'), 65 | 'right upper arm': ('right shoulder', 'right elbow'), 66 | 'right shank': ('right foot', 'right knee'), 67 | 'right thigh': ('right hip', 'right knee'), 68 | 'right body': ('right shoulder', 'right hip'), 69 | 'upper body': ('pelvis', 'neck'), 70 | } 71 | fullLimbs = { 72 | 'left arm': ('left lower arm', 'left upper arm'), 73 | 'left leg': ('left shank', 'left thigh'), 74 | 'left upper arm': ('left body', 'left upper arm'), 75 | 'left thigh': ('upper body', 'left thigh'), 76 | 'right arm': ('right lower arm', 'right upper arm'), 77 | 'right leg': ('right shank', 'right thigh'), 78 | 'right upper arm': ('right body', 'right upper arm'), 79 | 'right thigh': ('upper body', 'right thigh'), 80 | } 81 | idx = [] # each item, j1 index, j2 index, axis index 82 | for i, info in enumerate(meta['IDX2META']): 83 | if info[0] == 'pp': 84 | # e.g. 85 | # (left hand, ud) 86 | # 1: left hand moves upwards 87 | # -1: left hand moves downwards 88 | part, ax = info[1] # 1: part moving upwards () 89 | idx.append([joint_idx[part], 0, axis_idx[ax]]) 90 | elif info[0] == 'pdp': 91 | # e.g. 92 | # (left hand, right hand) 93 | # 1: lhand and rhand moves away from each other 94 | # -1: lhand and rhand moves closer 95 | ja, jb = info[1] 96 | idx.append([joint_idx[ja], joint_idx[jb], axis_idx['none']]) 97 | elif info[0] in ['prpp', 'lop']: 98 | # e.g. 99 | # (left hand, right hand, ud) 100 | # 1: lhand above rhand 101 | # -1: lhand below rhand 102 | ja, jb, ax = info[1] 103 | idx.append([joint_idx[ja], joint_idx[jb], axis_idx[ax]]) 104 | elif info[0] == 'lap': 105 | # e.g. 106 | # (left arm) 107 | # 1: left arm unbends 108 | # -1: left arm bends 109 | fLimb = fullLimbs[info[1]] 110 | idx.append([joint_idx[limbs[fLimb[0]][0]], joint_idx[limbs[fLimb[0]][1]], axis_idx[fLimb[1]]]) 111 | print(len(idx)) 112 | 113 | self.idx = np.array(idx) 114 | self.joint_idx = joint_idx 115 | 116 | def forward(self, joi, index=None): 117 | axis = torch.zeros(joi.shape[0], 8, 3, device=joi.device) 118 | axis[:, 0, -1] = 1. # ud 119 | axis[:, 1] = joi[:, self.joint_idx['right hip']] - joi[:, self.joint_idx['left hip']] # rl 120 | axis[:, 2] = torch.cross(axis[:, 0], axis[:, 1]) # fb 121 | axis[:, 3] = 1. 
# none 122 | axis[:, 4] = joi[:, self.joint_idx['left shoulder']] - joi[:, self.joint_idx['left elbow']] # left upper arm 123 | axis[:, 5] = joi[:, self.joint_idx['left hip']] - joi[:, self.joint_idx['left knee']] # left thigh 124 | axis[:, 6] = joi[:, self.joint_idx['right shoulder']] - joi[:, self.joint_idx['right elbow']] # right upper arm 125 | axis[:, 7] = joi[:, self.joint_idx['right hip']] - joi[:, self.joint_idx['right knee']] # right thigh 126 | axis = axis / torch.norm(axis, p=2, dim=2, keepdim=True) 127 | if index is None: 128 | ind1 = torch.sum((joi[:, self.idx[:381, 0]] - joi[:, self.idx[:381, 1]]) * axis[:, self.idx[:381, 2]], axis=-1) 129 | ind2 = torch.arccos(torch.sum((joi[:, self.idx[381:, 0]] - joi[:, self.idx[381:, 1]]) * axis[:, self.idx[381:, 2]] / (torch.norm((joi[:, self.idx[381:, 0]] - joi[:, self.idx[381:, 1]]), dim=2, p=2, keepdim=True) + 1e-8), dim=-1)) 130 | ind3 = torch.sum((joi[1:, [0, 0, 0]] - joi[:-1, [0, 0, 0]]) * axis[:-1, :3], dim=-1) 131 | ind3 = torch.cat([ind3, ind3[-1:]]) 132 | indicators = torch.cat([ind1, ind2, ind3], axis=1) 133 | indicators[1:, :115] = torch.diff(indicators[:, :115], axis=0) 134 | indicators[:1, :115] = indicators[1:2, :115] 135 | indicators[1:, 381:389] = torch.diff(indicators[:, 381:389], axis=0) 136 | indicators[:1, 381:389] = indicators[1:2, 381:389] 137 | gvp_indicators = torch.sum((joi[1:, :1] - joi[:-1, :1]).expand(-1, 3, -1) * axis[:-1, :3], dim=-1) 138 | indicators = torch.cat((indicators, gvp_indicators), dim=-1) 139 | indicators[torch.abs(indicators) < 1e-3] = 0 140 | indicators = torch.sign(indicators) 141 | return indicators 142 | else: 143 | if index < 381: 144 | indicator = torch.sum((joi[:, self.idx[index, 0]] - joi[:, self.idx[index, 1]]) * axis[:, self.idx[index, 2]], axis=-1) 145 | elif index < 389: 146 | x1 = joi[:, self.idx[index, 0]] - joi[:, self.idx[index, 1]] 147 | x2 = axis[:, self.idx[index, 2]] 148 | cos = torch.clip(torch.nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8), -1, 1) 149 | indicator = torch.arccos(cos) 150 | else: 151 | indicator = torch.sum((joi[1:, 0] - joi[:-1, 0]) * axis[:-1, index - 389], dim=-1) 152 | if index < 115 or 389 > index > 381: 153 | indicator = torch.diff(indicator) 154 | indicator[torch.abs(indicator) < 1e-3] = 0 155 | indicator = torch.sign(indicator) 156 | return indicator 157 | 158 | if __name__ == '__main__': 159 | body_model = smplx.create('models', model_type='smplx', gender='neutral', use_face_contour=True, num_betas=16, num_expression_coeffs=10, ext='npz', use_pca=False, create_global_orient=False, create_body_pose=False, create_left_hand_pose=False, create_right_hand_pose=False, create_jaw_pose=False, create_leye_pose=False, create_reye_pose=False, create_betas=False, create_expression=False, create_transl=False,).double() 160 | 161 | data = joblib.load('sample/motion.pkl') 162 | nf = len(data['poses']) 163 | betas = np.concatenate([data['betas'], np.zeros(16 - len(data['betas']))]) 164 | betas = torch.from_numpy(betas).expand(nf, -1).double() 165 | expression = torch.zeros(nf, 10).double() 166 | pose = torch.from_numpy(data['poses']).double() 167 | trans = torch.from_numpy(data['trans']).double() 168 | trans[1:] = trans[1:] - trans[:-1] 169 | trans[0] = 0 170 | pose, trans = canonicalize_smplx(pose.reshape(1, nf, 55, 3), 'aa', trans[None], 'aa') 171 | pose = pose[0].flatten(1) 172 | trans = trans[0].cumsum(0) 173 | smplx_data = body_model(betas=betas, expression=expression, transl=trans, 174 | global_orient=pose[..., :3], body_pose=pose[..., 3:66], 
jaw_pose=pose[..., 66:69], 175 | leye_pose=pose[..., 69:72], reye_pose=pose[..., 72:75], 176 | left_hand_pose=pose[..., 75:120], right_hand_pose=pose[..., 120:165], 177 | return_verts=False, return_shaped=False, dense_verts=False) 178 | joints = smplx_data.joints.detach().cpu().squeeze() 179 | kp = JOI2KP() 180 | print(kp(joints).shape) 181 | -------------------------------------------------------------------------------- /utils/geometry.py: -------------------------------------------------------------------------------- 1 | # Mostly ported from Pytorch3D 2 | # Some functions are adapted from https://github.com/qazwsxal/diffusion-extensions 3 | # Some are self-defined 4 | 5 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 6 | # Check PYTORCH3D_LICENCE before use 7 | 8 | import functools 9 | from typing import Optional 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | 15 | """ 16 | The transformation matrices returned from the functions in this file assume 17 | the points on which the transformation will be applied are column vectors. 18 | i.e. the R matrix is structured as 19 | 20 | R = [ 21 | [Rxx, Rxy, Rxz], 22 | [Ryx, Ryy, Ryz], 23 | [Rzx, Rzy, Rzz], 24 | ] # (3, 3) 25 | 26 | This matrix can be applied to column vectors by post multiplication 27 | by the points e.g. 28 | 29 | points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point 30 | transformed_points = R * points 31 | 32 | To apply the same matrix to points which are row vectors, the R matrix 33 | can be transposed and pre multiplied by the points: 34 | 35 | e.g. 36 | points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point 37 | transformed_points = points * R.transpose(1, 0) 38 | """ 39 | 40 | 41 | # Added 42 | def matrix_of_angles(cos, sin, inv=False, dim=2): 43 | assert dim in [2, 3] 44 | sin = -sin if inv else sin 45 | if dim == 2: 46 | row1 = torch.stack((cos, -sin), axis=-1) 47 | row2 = torch.stack((sin, cos), axis=-1) 48 | return torch.stack((row1, row2), axis=-2) 49 | elif dim == 3: 50 | row1 = torch.stack((cos, -sin, 0*cos), axis=-1) 51 | row2 = torch.stack((sin, cos, 0*cos), axis=-1) 52 | row3 = torch.stack((0*sin, 0*cos, 1+0*cos), axis=-1) 53 | return torch.stack((row1, row2, row3),axis=-2) 54 | 55 | 56 | def quaternion_to_matrix(quaternions): 57 | """ 58 | Convert rotations given as quaternions to rotation matrices. 59 | 60 | Args: 61 | quaternions: quaternions with real part first, 62 | as tensor of shape (..., 4). 63 | 64 | Returns: 65 | Rotation matrices as tensor of shape (..., 3, 3). 66 | """ 67 | r, i, j, k = torch.unbind(quaternions, -1) 68 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 69 | 70 | o = torch.stack( 71 | ( 72 | 1 - two_s * (j * j + k * k), 73 | two_s * (i * j - k * r), 74 | two_s * (i * k + j * r), 75 | two_s * (i * j + k * r), 76 | 1 - two_s * (i * i + k * k), 77 | two_s * (j * k - i * r), 78 | two_s * (i * k - j * r), 79 | two_s * (j * k + i * r), 80 | 1 - two_s * (i * i + j * j), 81 | ), 82 | -1, 83 | ) 84 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 85 | 86 | 87 | def _copysign(a, b): 88 | """ 89 | Return a tensor where each element has the absolute value taken from the, 90 | corresponding element of a, with sign taken from the corresponding 91 | element of b. This is like the standard copysign floating-point operation, 92 | but is not careful about negative 0 and NaN. 93 | 94 | Args: 95 | a: source tensor. 96 | b: tensor whose signs will be used, of the same shape as a. 
97 | 98 | Returns: 99 | Tensor of the same shape as a with the signs of b. 100 | """ 101 | signs_differ = (a < 0) != (b < 0) 102 | return torch.where(signs_differ, -a, a) 103 | 104 | 105 | def _sqrt_positive_part(x): 106 | """ 107 | Returns torch.sqrt(torch.max(0, x)) 108 | but with a zero subgradient where x is 0. 109 | """ 110 | ret = torch.zeros_like(x) 111 | positive_mask = x > 0 112 | ret[positive_mask] = torch.sqrt(x[positive_mask]) 113 | return ret 114 | 115 | 116 | def matrix_to_quaternion(matrix): 117 | """ 118 | Convert rotations given as rotation matrices to quaternions. 119 | 120 | Args: 121 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 122 | 123 | Returns: 124 | quaternions with real part first, as tensor of shape (..., 4). 125 | """ 126 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 127 | raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") 128 | m00 = matrix[..., 0, 0] 129 | m11 = matrix[..., 1, 1] 130 | m22 = matrix[..., 2, 2] 131 | o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) 132 | x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) 133 | y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) 134 | z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) 135 | o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) 136 | o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) 137 | o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) 138 | return torch.stack((o0, o1, o2, o3), -1) 139 | 140 | 141 | def _axis_angle_rotation(axis: str, angle): 142 | """ 143 | Return the rotation matrices for one of the rotations about an axis 144 | of which Euler angles describe, for each value of the angle given. 145 | 146 | Args: 147 | axis: Axis label "X" or "Y or "Z". 148 | angle: any shape tensor of Euler angles in radians 149 | 150 | Returns: 151 | Rotation matrices as tensor of shape (..., 3, 3). 152 | """ 153 | 154 | cos = torch.cos(angle) 155 | sin = torch.sin(angle) 156 | one = torch.ones_like(angle) 157 | zero = torch.zeros_like(angle) 158 | 159 | if axis == "X": 160 | R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) 161 | if axis == "Y": 162 | R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) 163 | if axis == "Z": 164 | R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) 165 | 166 | return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) 167 | 168 | 169 | def euler_angles_to_matrix(euler_angles, convention: str): 170 | """ 171 | Convert rotations given as Euler angles in radians to rotation matrices. 172 | 173 | Args: 174 | euler_angles: Euler angles in radians as tensor of shape (..., 3). 175 | convention: Convention string of three uppercase letters from 176 | {"X", "Y", and "Z"}. 177 | 178 | Returns: 179 | Rotation matrices as tensor of shape (..., 3, 3). 
180 | """ 181 | if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: 182 | raise ValueError("Invalid input euler angles.") 183 | if len(convention) != 3: 184 | raise ValueError("Convention must have 3 letters.") 185 | if convention[1] in (convention[0], convention[2]): 186 | raise ValueError(f"Invalid convention {convention}.") 187 | for letter in convention: 188 | if letter not in ("X", "Y", "Z"): 189 | raise ValueError(f"Invalid letter {letter} in convention string.") 190 | matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) 191 | return functools.reduce(torch.matmul, matrices) 192 | 193 | 194 | def _angle_from_tan( 195 | axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool 196 | ): 197 | """ 198 | Extract the first or third Euler angle from the two members of 199 | the matrix which are positive constant times its sine and cosine. 200 | 201 | Args: 202 | axis: Axis label "X" or "Y or "Z" for the angle we are finding. 203 | other_axis: Axis label "X" or "Y or "Z" for the middle axis in the 204 | convention. 205 | data: Rotation matrices as tensor of shape (..., 3, 3). 206 | horizontal: Whether we are looking for the angle for the third axis, 207 | which means the relevant entries are in the same row of the 208 | rotation matrix. If not, they are in the same column. 209 | tait_bryan: Whether the first and third axes in the convention differ. 210 | 211 | Returns: 212 | Euler Angles in radians for each matrix in data as a tensor 213 | of shape (...). 214 | """ 215 | 216 | i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] 217 | if horizontal: 218 | i2, i1 = i1, i2 219 | even = (axis + other_axis) in ["XY", "YZ", "ZX"] 220 | if horizontal == even: 221 | return torch.atan2(data[..., i1], data[..., i2]) 222 | if tait_bryan: 223 | return torch.atan2(-data[..., i2], data[..., i1]) 224 | return torch.atan2(data[..., i2], -data[..., i1]) 225 | 226 | 227 | def _index_from_letter(letter: str): 228 | if letter == "X": 229 | return 0 230 | if letter == "Y": 231 | return 1 232 | if letter == "Z": 233 | return 2 234 | 235 | 236 | def matrix_to_euler_angles(matrix, convention: str): 237 | """ 238 | Convert rotations given as rotation matrices to Euler angles in radians. 239 | 240 | Args: 241 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 242 | convention: Convention string of three uppercase letters. 243 | 244 | Returns: 245 | Euler angles in radians as tensor of shape (..., 3). 
246 | """ 247 | if len(convention) != 3: 248 | raise ValueError("Convention must have 3 letters.") 249 | if convention[1] in (convention[0], convention[2]): 250 | raise ValueError(f"Invalid convention {convention}.") 251 | for letter in convention: 252 | if letter not in ("X", "Y", "Z"): 253 | raise ValueError(f"Invalid letter {letter} in convention string.") 254 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 255 | raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") 256 | i0 = _index_from_letter(convention[0]) 257 | i2 = _index_from_letter(convention[2]) 258 | tait_bryan = i0 != i2 259 | if tait_bryan: 260 | central_angle = torch.asin( 261 | matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) 262 | ) 263 | else: 264 | central_angle = torch.acos(matrix[..., i0, i0]) 265 | 266 | o = ( 267 | _angle_from_tan( 268 | convention[0], convention[1], matrix[..., i2], False, tait_bryan 269 | ), 270 | central_angle, 271 | _angle_from_tan( 272 | convention[2], convention[1], matrix[..., i0, :], True, tait_bryan 273 | ), 274 | ) 275 | return torch.stack(o, -1) 276 | 277 | 278 | def random_quaternions( 279 | n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False 280 | ): 281 | """ 282 | Generate random quaternions representing rotations, 283 | i.e. versors with nonnegative real part. 284 | 285 | Args: 286 | n: Number of quaternions in a batch to return. 287 | dtype: Type to return. 288 | device: Desired device of returned tensor. Default: 289 | uses the current device for the default tensor type. 290 | requires_grad: Whether the resulting tensor should have the gradient 291 | flag set. 292 | 293 | Returns: 294 | Quaternions as tensor of shape (N, 4). 295 | """ 296 | o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) 297 | s = (o * o).sum(1) 298 | o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] 299 | return o 300 | 301 | 302 | def random_rotations( 303 | n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False 304 | ): 305 | """ 306 | Generate random rotations as 3x3 rotation matrices. 307 | 308 | Args: 309 | n: Number of rotation matrices in a batch to return. 310 | dtype: Type to return. 311 | device: Device of returned tensor. Default: if None, 312 | uses the current device for the default tensor type. 313 | requires_grad: Whether the resulting tensor should have the gradient 314 | flag set. 315 | 316 | Returns: 317 | Rotation matrices as tensor of shape (n, 3, 3). 318 | """ 319 | quaternions = random_quaternions( 320 | n, dtype=dtype, device=device, requires_grad=requires_grad 321 | ) 322 | return quaternion_to_matrix(quaternions) 323 | 324 | 325 | def random_rotation( 326 | dtype: Optional[torch.dtype] = None, device=None, requires_grad=False 327 | ): 328 | """ 329 | Generate a single random 3x3 rotation matrix. 330 | 331 | Args: 332 | dtype: Type to return 333 | device: Device of returned tensor. Default: if None, 334 | uses the current device for the default tensor type 335 | requires_grad: Whether the resulting tensor should have the gradient 336 | flag set 337 | 338 | Returns: 339 | Rotation matrix as tensor of shape (3, 3). 340 | """ 341 | return random_rotations(1, dtype, device, requires_grad)[0] 342 | 343 | 344 | def standardize_quaternion(quaternions): 345 | """ 346 | Convert a unit quaternion to a standard form: one in which the real 347 | part is non negative. 348 | 349 | Args: 350 | quaternions: Quaternions with real part first, 351 | as tensor of shape (..., 4). 
352 | 353 | Returns: 354 | Standardized quaternions as tensor of shape (..., 4). 355 | """ 356 | return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) 357 | 358 | 359 | def quaternion_raw_multiply(a, b): 360 | """ 361 | Multiply two quaternions. 362 | Usual torch rules for broadcasting apply. 363 | 364 | Args: 365 | a: Quaternions as tensor of shape (..., 4), real part first. 366 | b: Quaternions as tensor of shape (..., 4), real part first. 367 | 368 | Returns: 369 | The product of a and b, a tensor of quaternions shape (..., 4). 370 | """ 371 | aw, ax, ay, az = torch.unbind(a, -1) 372 | bw, bx, by, bz = torch.unbind(b, -1) 373 | ow = aw * bw - ax * bx - ay * by - az * bz 374 | ox = aw * bx + ax * bw + ay * bz - az * by 375 | oy = aw * by - ax * bz + ay * bw + az * bx 376 | oz = aw * bz + ax * by - ay * bx + az * bw 377 | return torch.stack((ow, ox, oy, oz), -1) 378 | 379 | 380 | def quaternion_multiply(a, b): 381 | """ 382 | Multiply two quaternions representing rotations, returning the quaternion 383 | representing their composition, i.e. the versor with nonnegative real part. 384 | Usual torch rules for broadcasting apply. 385 | 386 | Args: 387 | a: Quaternions as tensor of shape (..., 4), real part first. 388 | b: Quaternions as tensor of shape (..., 4), real part first. 389 | 390 | Returns: 391 | The product of a and b, a tensor of quaternions of shape (..., 4). 392 | """ 393 | ab = quaternion_raw_multiply(a, b) 394 | return standardize_quaternion(ab) 395 | 396 | 397 | def quaternion_invert(quaternion): 398 | """ 399 | Given a quaternion representing rotation, get the quaternion representing 400 | its inverse. 401 | 402 | Args: 403 | quaternion: Quaternions as tensor of shape (..., 4), with real part 404 | first, which must be versors (unit quaternions). 405 | 406 | Returns: 407 | The inverse, a tensor of quaternions of shape (..., 4). 408 | """ 409 | 410 | return quaternion * quaternion.new_tensor([1, -1, -1, -1]) 411 | 412 | 413 | def quaternion_apply(quaternion, point): 414 | """ 415 | Apply the rotation given by a quaternion to a 3D point. 416 | Usual torch rules for broadcasting apply. 417 | 418 | Args: 419 | quaternion: Tensor of quaternions, real part first, of shape (..., 4). 420 | point: Tensor of 3D points of shape (..., 3). 421 | 422 | Returns: 423 | Tensor of rotated points of shape (..., 3). 424 | """ 425 | if point.size(-1) != 3: 426 | raise ValueError(f"Points are not in 3D, f{point.shape}.") 427 | real_parts = point.new_zeros(point.shape[:-1] + (1,)) 428 | point_as_quaternion = torch.cat((real_parts, point), -1) 429 | out = quaternion_raw_multiply( 430 | quaternion_raw_multiply(quaternion, point_as_quaternion), 431 | quaternion_invert(quaternion), 432 | ) 433 | return out[..., 1:] 434 | 435 | 436 | def axis_angle_to_matrix(axis_angle): 437 | """ 438 | Convert rotations given as axis/angle to rotation matrices. 439 | 440 | Args: 441 | axis_angle: Rotations given as a vector in axis angle form, 442 | as a tensor of shape (..., 3), where the magnitude is 443 | the angle turned anticlockwise in radians around the 444 | vector's direction. 445 | 446 | Returns: 447 | Rotation matrices as tensor of shape (..., 3, 3). 448 | """ 449 | return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) 450 | 451 | 452 | def matrix_to_axis_angle(matrix): 453 | """ 454 | Convert rotations given as rotation matrices to axis/angle. 455 | 456 | Args: 457 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 
458 | 459 | Returns: 460 | Rotations given as a vector in axis angle form, as a tensor 461 | of shape (..., 3), where the magnitude is the angle 462 | turned anticlockwise in radians around the vector's 463 | direction. 464 | """ 465 | return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) 466 | 467 | 468 | def axis_angle_to_quaternion(axis_angle): 469 | """ 470 | Convert rotations given as axis/angle to quaternions. 471 | 472 | Args: 473 | axis_angle: Rotations given as a vector in axis angle form, 474 | as a tensor of shape (..., 3), where the magnitude is 475 | the angle turned anticlockwise in radians around the 476 | vector's direction. 477 | 478 | Returns: 479 | quaternions with real part first, as tensor of shape (..., 4). 480 | """ 481 | angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) 482 | half_angles = 0.5 * angles 483 | eps = 1e-6 484 | small_angles = angles.abs() < eps 485 | sin_half_angles_over_angles = torch.empty_like(angles) 486 | sin_half_angles_over_angles[~small_angles] = ( 487 | torch.sin(half_angles[~small_angles]) / angles[~small_angles] 488 | ) 489 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6 490 | # so sin(x/2)/x is about 1/2 - (x*x)/48 491 | sin_half_angles_over_angles[small_angles] = ( 492 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48 493 | ) 494 | quaternions = torch.cat( 495 | [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 496 | ) 497 | return quaternions 498 | 499 | 500 | def quaternion_to_axis_angle(quaternions): 501 | """ 502 | Convert rotations given as quaternions to axis/angle. 503 | 504 | Args: 505 | quaternions: quaternions with real part first, 506 | as tensor of shape (..., 4). 507 | 508 | Returns: 509 | Rotations given as a vector in axis angle form, as a tensor 510 | of shape (..., 3), where the magnitude is the angle 511 | turned anticlockwise in radians around the vector's 512 | direction. 513 | """ 514 | norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) 515 | half_angles = torch.atan2(norms, quaternions[..., :1]) 516 | angles = 2 * half_angles 517 | eps = 1e-6 518 | small_angles = angles.abs() < eps 519 | sin_half_angles_over_angles = torch.empty_like(angles) 520 | sin_half_angles_over_angles[~small_angles] = ( 521 | torch.sin(half_angles[~small_angles]) / angles[~small_angles] 522 | ) 523 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6 524 | # so sin(x/2)/x is about 1/2 - (x*x)/48 525 | sin_half_angles_over_angles[small_angles] = ( 526 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48 527 | ) 528 | return quaternions[..., 1:] / sin_half_angles_over_angles 529 | 530 | # Self-defined 531 | def decompose_axis_angle(axis_angle): 532 | """ 533 | Decompose axis/angle representation. 534 | 535 | Args: 536 | axis_angle: Rotations given as a vector in axis angle form, 537 | as a tensor of shape (..., 3), where the magnitude is 538 | the angle turned anticlockwise in radians around the 539 | vector's direction. 540 | 541 | Returns: 542 | Decomposed axis angle represention with axis of shape (..., 3) 543 | and angle of shape (...) 
544 | """ 545 | angle = torch.norm(axis_angle, p=2, dim=-1) 546 | bottom = 1 / angle 547 | bottom[bottom.isnan()] = 1.0 548 | bottom[bottom.isinf()] = 1.0 549 | axis = axis_angle * bottom[..., None] 550 | return axis, angle 551 | 552 | # Self-defined 553 | def compose_axis_angle(axis, angle): 554 | """ 555 | Compose axis angle representation 556 | 557 | Args: 558 | axis: Axis in axis angle form, as a tensor of shape (..., 3) 559 | angle: Angle in axis angle form, as a tensor of shape (...) 560 | 561 | Returns: 562 | Composed axis angle represention of shape (..., 3) 563 | """ 564 | return axis * angle[..., None] 565 | 566 | # Self-defined 567 | def standardize_rotation_6d(d6: torch.Tensor) -> torch.Tensor: 568 | """ 569 | Standardize 6D rotation representation by Zhou et al. [1] 570 | using Gram--Schmidt orthogonalisation per Section B of [1]. 571 | Args: 572 | d6: 6D rotation representation, of size (*, 6) 573 | 574 | Returns: 575 | batch of standardized 6D rotation representation of size (*, 6) 576 | 577 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 578 | On the Continuity of Rotation Representations in Neural Networks. 579 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 580 | Retrieved from http://arxiv.org/abs/1812.07035 581 | """ 582 | 583 | a1, a2 = d6[..., :3], d6[..., 3:] 584 | b1 = F.normalize(a1, dim=-1) 585 | b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 586 | b2 = F.normalize(b2, dim=-1) 587 | 588 | return torch.cat([b1, b2], dim=-1) 589 | 590 | 591 | def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: 592 | """ 593 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix 594 | using Gram--Schmidt orthogonalisation per Section B of [1]. 595 | Args: 596 | d6: 6D rotation representation, of size (*, 6) 597 | 598 | Returns: 599 | batch of rotation matrices of size (*, 3, 3) 600 | 601 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 602 | On the Continuity of Rotation Representations in Neural Networks. 603 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 604 | Retrieved from http://arxiv.org/abs/1812.07035 605 | """ 606 | 607 | a1, a2 = d6[..., :3], d6[..., 3:] 608 | b1 = F.normalize(a1, dim=-1) 609 | b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 610 | b2 = F.normalize(b2, dim=-1) 611 | b3 = torch.cross(b1, b2, dim=-1) 612 | return torch.stack((b1, b2, b3), dim=-2) 613 | 614 | 615 | def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: 616 | """ 617 | Converts rotation matrices to 6D rotation representation by Zhou et al. [1] 618 | by dropping the last row. Note that 6D representation is not unique. 619 | Args: 620 | matrix: batch of rotation matrices of size (*, 3, 3) 621 | 622 | Returns: 623 | 6D rotation representation, of size (*, 6) 624 | 625 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 626 | On the Continuity of Rotation Representations in Neural Networks. 627 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
628 | Retrieved from http://arxiv.org/abs/1812.07035 629 | """ 630 | return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) 631 | 632 | # Self-defined 633 | def rot_from_to(input: torch.Tensor, src: str, tgt: str) -> torch.Tensor: 634 | if src == 'aa': 635 | if tgt == 'aa': 636 | return input 637 | elif tgt == 'quat': 638 | return axis_angle_to_quaternion(input) 639 | elif tgt == 'rot6d': 640 | return matrix_to_rotation_6d(axis_angle_to_matrix(input)) 641 | elif tgt == 'matrix': 642 | return axis_angle_to_matrix(input) 643 | else: 644 | raise NotImplementedError 645 | elif src == 'quat': 646 | if tgt == 'aa': 647 | return quaternion_to_axis_angle(input) 648 | elif tgt == 'quat': 649 | return input 650 | elif tgt == 'rot6d': 651 | return matrix_to_rotation_6d(quaternion_to_matrix(input)) 652 | elif tgt == 'matrix': 653 | return quaternion_to_matrix(input) 654 | else: 655 | raise NotImplementedError 656 | elif src == 'rot6d': 657 | if tgt == 'aa': 658 | return matrix_to_axis_angle(rotation_6d_to_matrix(input)) 659 | elif tgt == 'quat': 660 | return matrix_to_quaternion(rotation_6d_to_matrix(input)) 661 | elif tgt == 'rot6d': 662 | return input 663 | elif tgt == 'matrix': 664 | return rotation_6d_to_matrix(input) 665 | else: 666 | raise NotImplementedError 667 | elif src == 'matrix': 668 | if tgt == 'aa': 669 | return matrix_to_axis_angle(input) 670 | elif tgt == 'quat': 671 | return matrix_to_quaternion(input) 672 | elif tgt == 'rot6d': 673 | return matrix_to_rotation_6d(input) 674 | elif tgt == 'matrix': 675 | return input 676 | else: 677 | raise NotImplementedError 678 | else: 679 | raise NotImplementedError 680 | 681 | # Ported from https://github.com/qazwsxal/diffusion-extensions 682 | def skew2vec(skew: torch.Tensor) -> torch.Tensor: 683 | vec = torch.zeros_like(skew[..., 0]) 684 | vec[..., 0] = skew[..., 2, 1] 685 | vec[..., 1] = -skew[..., 2, 0] 686 | vec[..., 2] = skew[..., 1, 0] 687 | return vec 688 | 689 | # Ported from https://github.com/qazwsxal/diffusion-extensions 690 | def vec2skew(vec: torch.Tensor) -> torch.Tensor: 691 | skew = torch.repeat_interleave(torch.zeros_like(vec).unsqueeze(-1), 3, dim=-1) 692 | skew[..., 2, 1] = vec[..., 0] 693 | skew[..., 2, 0] = -vec[..., 1] 694 | skew[..., 1, 0] = vec[..., 2] 695 | return skew - skew.transpose(-1, -2) 696 | 697 | def log_rmat(r_mat: torch.Tensor) -> torch.Tensor: 698 | ''' 699 | See paper 700 | Exponentials of skew-symmetric matrices and logarithms of orthogonal matrices 701 | https://doi.org/10.1016/j.cam.2009.11.032 702 | For most of the derivatons here 703 | We use atan2 instead of acos here dut to better numerical stability. 704 | it means we get nicer behaviour around 0 degrees 705 | More effort to derive sin terms 706 | but as we're dealing with small angles a lot, 707 | the tradeoff is worth it. 708 | ''' 709 | skew_mat = r_mat - r_mat.transpose(-1, -2) 710 | sk_vec = skew2vec(skew_mat) 711 | s_angle = sk_vec.norm(p=2, dim=-1) / 2 712 | c_angle = (torch.einsum('...ii', r_mat) - 1) / 2 713 | angle = torch.atan2(s_angle, c_angle) 714 | scale = angle / (2 * s_angle) 715 | # if s_angle = 0, i.e. rotation by 0 or pi (180), we get NaNs 716 | # by definition, scale values are 0 if rotating by 0. 717 | # This also breaks down if rotating by pi, fix further down 718 | scale[angle == 0.0] = 0.0 719 | log_r_mat = scale[..., None, None] * skew_mat 720 | 721 | # Check for NaNs caused by 180deg rotations. 
722 | nanlocs = log_r_mat[..., 0, 0].isnan() 723 | nanmats = r_mat[nanlocs] 724 | # We need to use an alternative way of finding the logarithm for nanmats, 725 | # Use eigendecomposition to discover axis of rotation. 726 | # By definition, these are symmetric, so use eigh. 727 | # NOTE: linalg.eig() isn't in torch 1.8, 728 | # and torch.eig() doesn't do batched matrices 729 | eigval, eigvec = torch.linalg.eigh(nanmats) 730 | # Final eigenvalue == 1, might be slightly off because floats, but other two are -ve. 731 | # this *should* just be the last column if the docs for eigh are true. 732 | nan_axes = eigvec[..., -1, :] 733 | nan_angle = angle[nanlocs] 734 | nan_skew = vec2skew(nan_angle[..., None] * nan_axes) 735 | log_r_mat[nanlocs] = nan_skew 736 | return log_r_mat 737 | 738 | # Adapted from https://github.com/qazwsxal/diffusion-extensions 739 | def rot_lerp(rot_a: torch.Tensor, rot_b: torch.Tensor, weight: torch.Tensor, src: str = 'matrix', tgt: str = 'matrix') -> torch.Tensor: 740 | ''' Weighted interpolation between rot_a and rot_b 741 | ''' 742 | # Treat rot_b = rot_a @ rot_c 743 | # rot_a^-1 @ rot_a = I 744 | # rot_a^-1 @ rot_b = rot_a^-1 @ rot_a @ rot_c = I @ rot_c 745 | # once we have rot_c, use axis-angle forms to lerp angle 746 | rot_a = rot_from_to(rot_a, src, 'matrix') 747 | rot_b = rot_from_to(rot_b, src, 'matrix') 748 | rot_c = rot_a.transpose(-1, -2) @ rot_b 749 | axis, angle = decompose_axis_angle(rot_from_to(rot_c, 'matrix', 'aa'))# rmat_to_aa(rot_c) 750 | # once we have axis-angle forms, determine intermediate angles. 751 | # print(weight.shape, angle.shape) 752 | i_angle = weight * angle 753 | aa = compose_axis_angle(axis, i_angle) 754 | rot_c_i = rot_from_to(aa, 'aa', 'matrix') 755 | res = rot_from_to(rot_a @ rot_c_i, 'matrix', tgt) 756 | return res 757 | 758 | # Adapted from https://github.com/qazwsxal/diffusion-extensions 759 | def rot_scale(input, scalars, src='matrix', tgt='matrix'): 760 | '''Scale the magnitude of a rotation, 761 | e.g. a 45 degree rotation scaled by a factor of 2 gives a 90 degree rotation. 762 | 763 | This is the same as taking matrix powers, but pytorch only supports integer exponents 764 | 765 | So instead, we take advantage of the properties of rotation matrices 766 | to calculate logarithms easily. and multiply instead. 
767 | ''' 768 | rmat = rot_from_to(input, src, 'matrix') 769 | logs = log_rmat(rmat) 770 | scaled_logs = logs * scalars[..., None, None] 771 | out = torch.matrix_exp(scaled_logs) 772 | out = rot_from_to(out, 'matrix', tgt) 773 | return out 774 | 775 | def canonicalize_smplx(poses: torch.Tensor, repr: str, trans: Optional[torch.Tensor] = None, tgt: str = None, return_mat: bool = False): 776 | ''' 777 | Input: [bs, nframes, njoints, 3/4/6/9] 778 | ''' 779 | bs, nframes, njoints = poses.shape[:3] 780 | if tgt is None: 781 | tgt = repr 782 | 783 | global_orient = rot_from_to(poses[:, :, 0], repr, 'matrix') 784 | 785 | # first global rotations 786 | rot2d = rot_from_to(global_orient[:, 0], 'matrix', 'aa') 787 | rot2d[:, :2] *= 0 # Remove the rotation along the vertical axis 788 | rot2d = rot_from_to(rot2d, 'aa', 'matrix') 789 | 790 | # Rotate the global rotation to eliminate Z rotations 791 | global_orient = torch.einsum("ikj,imkl->imjl", rot2d, global_orient) 792 | global_orient = rot_from_to(global_orient, 'matrix', tgt) 793 | 794 | # Construct canonicalized version of x 795 | xc = torch.cat((global_orient[:, :, None], rot_from_to(poses[:, :, 1:], repr, tgt)), dim=2) 796 | 797 | if trans is not None: 798 | # vel = trans[:, 1:] 799 | # Turn the translation as well 800 | trans = torch.einsum("ikj,ilk->ilj", rot2d, trans) 801 | # trans = torch.cat((torch.zeros(bs, 1, 3, device=vel.device), vel), 1) 802 | if return_mat: 803 | return xc, trans, rot2d 804 | return xc, trans 805 | else: 806 | if return_mat: 807 | return xc, rot2d 808 | return xc 809 | 810 | def rotate_smplx(poses: torch.Tensor, rot2d: torch.Tensor, repr: str, trans: Optional[torch.Tensor] = None, tgt: str = None): 811 | ''' 812 | Input: [bs, nframes, njoints, 3/4/6/9] 813 | trans: [bs, nframes, njoints, 3] velocity 814 | ''' 815 | bs, nframes, njoints = poses.shape[:3] 816 | if tgt is None: 817 | tgt = repr 818 | 819 | global_orient = rot_from_to(poses[:, :, 0], repr, 'matrix') 820 | 821 | # first global rotations 822 | # rot2d = rot_from_to(global_orient[:, 0], 'matrix', 'aa') 823 | # rot2d[:, :2] *= 0 # Remove the rotation along the vertical axis 824 | # rot2d = rot_from_to(rot2d, 'aa', 'matrix') 825 | rot2d = rot_from_to(rot2d, repr, 'matrix') # bz, 3, 3 826 | global_orient = torch.einsum("ikj,imkl->imjl", rot2d, global_orient) 827 | global_orient = rot_from_to(global_orient, 'matrix', tgt) 828 | 829 | # Construct canonicalized version of x 830 | xc = torch.cat((global_orient[:, :, None], rot_from_to(poses[:, :, 1:], repr, tgt)), dim=2) 831 | 832 | if trans is not None: 833 | # vel = trans[:, 1:] 834 | # Turn the translation as well 835 | trans = torch.einsum("ikj,ilk->ilj", rot2d, trans) 836 | # trans = torch.cat((torch.zeros(bs, 1, 3, device=vel.device), vel), 1) 837 | return xc, trans 838 | else: 839 | return xc 840 | 841 | # Adapted from https://github.com/zju3dv/EasyMocap/blob/64e0e48d2970b352cfc60ffd95922495083ef306/easymocap/dataset/mirror.py 842 | # TODO 843 | _PERMUTATION = { 844 | 'smpl': [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 23, 22], 845 | 'smplh': [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 24, 25, 23, 24], 846 | 'smplx': [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 24, 25, 23, 24, 26, 28, 27], 847 | 'smplhfull': [ 848 | 0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, # body 849 | 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 850 | 22, 23, 24, 25, 26, 27, 28, 29, 30, 
31, 32, 33, 34, 35, 36 851 | ], 852 | 'smplxfull': [ 853 | 0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, # body 854 | 22, 24, 23, # jaw, left eye, right eye 855 | 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, # right hand 856 | 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, # left hand 857 | ] 858 | } 859 | def flipSMPLXPoses(poses, trans, ptype='smplx'): 860 | """ 861 | poses: L, 24, 3 862 | trans: L, 3 863 | """ 864 | assert ptype in ['smplx', 'smpl'], '{} is not implemented'.format(ptype) 865 | # flip pose 866 | poses = poses[:, _PERMUTATION[ptype + 'full']] 867 | poses[..., 1:] = -poses[..., 1:] 868 | trans[:, 0] *= -1 869 | 870 | return poses, trans 871 | 872 | # ported from https://github.com/c-he/NeMF/blob/79918430970fd138ae730510459c8f34893a3f86/src/utils.py#L155C1-L171C19 873 | def estimate_linear_velocity(data_seq, dt): 874 | ''' 875 | Given some batched data sequences of T timesteps in the shape (B, T, ...), estimates 876 | the velocity for the middle T-2 steps using a second order central difference scheme. 877 | The first and last frames are with forward and backward first-order 878 | differences, respectively 879 | - h : step size 880 | ''' 881 | # first steps is forward diff (t+1 - t) / dt 882 | init_vel = (data_seq[:, 1:2] - data_seq[:, :1]) / dt 883 | # middle steps are second order (t+1 - t-1) / 2dt 884 | middle_vel = (data_seq[:, 2:] - data_seq[:, 0:-2]) / (2 * dt) 885 | # last step is backward diff (t - t-1) / dt 886 | final_vel = (data_seq[:, -1:] - data_seq[:, -2:-1]) / dt 887 | 888 | vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1) 889 | return vel_seq 890 | 891 | # ported from https://github.com/c-he/NeMF/blob/main/src/utils.py#L174 892 | def estimate_angular_velocity(rot_seq, dt, repr='matrix'): 893 | ''' 894 | Given a batch of sequences of T rotation matrices, estimates angular velocity at T-2 steps. 895 | Input sequence should be of shape (B, T, ..., 3, 3) 896 | ''' 897 | # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix 898 | rot_seq = rot_from_to(rot_seq, repr, 'matrix') 899 | dRdt = estimate_linear_velocity(rot_seq, dt) 900 | R = rot_seq 901 | RT = R.transpose(-1, -2) 902 | # compute skew-symmetric angular velocity tensor 903 | w_mat = torch.matmul(dRdt, RT) 904 | # pull out angular velocity vector by averaging symmetric entries 905 | w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0 906 | w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0 907 | w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0 908 | w = torch.stack([w_x, w_y, w_z], axis=-1) 909 | return w --------------------------------------------------------------------------------
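A minimal usage sketch for the rotation utilities in `utils/geometry.py` above (illustrative; assumes the repository root is on the Python path so that `utils.geometry` is importable):

```python
# Illustrative only: round-trip rotations through the 6D representation and
# estimate angular velocity for a batch of rotation sequences.
import torch
from utils.geometry import rot_from_to, random_rotations, estimate_angular_velocity

R = random_rotations(4)                                  # (4, 3, 3) proper rotation matrices
R6 = rot_from_to(R, 'matrix', 'rot6d')                   # (4, 6)
R_back = rot_from_to(R6, 'rot6d', 'matrix')              # (4, 3, 3)
print(torch.allclose(R, R_back, atol=1e-5))              # True: round-trip is lossless for rotation matrices

rot_seq = random_rotations(2 * 8).reshape(2, 8, 3, 3)    # (B, T, 3, 3)
w = estimate_angular_velocity(rot_seq, dt=1.0 / 30)      # (B, T, 3) angular velocity vectors
print(w.shape)
```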