├── README.md ├── install.sh ├── net_texture.png ├── setup.py └── smpl_visualizer ├── konia_transform.py ├── smpl.py ├── torch_transform.py ├── vis.py ├── vis_pyvista.py ├── vis_scenepic.py └── vis_sport.py /README.md: -------------------------------------------------------------------------------- 1 | # SMPL human model visualizer in Python 2 | Note: the code is mostly adapted from [GLAMR](https://github.com/NVlabs/GLAMR). Please follow their license in using the code. 3 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | python -m pip install -e . 2 | -------------------------------------------------------------------------------- /net_texture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haotianz94/smpl_visualizer/773b54c67d1703ec4dc228fc67a03de9a619e399/net_texture.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import time 4 | 5 | from setuptools import find_packages, setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | MAJOR = 0 9 | MINOR = 1 10 | PATCH = 0 11 | SUFFIX = '' 12 | SHORT_VERSION = '{}.{}.{}{}'.format(MAJOR, MINOR, PATCH, SUFFIX) 13 | 14 | version_file = 'smpl_visualizer/version.py' 15 | 16 | 17 | def get_git_hash(): 18 | 19 | def _minimal_ext_cmd(cmd): 20 | # construct minimal environment 21 | env = {} 22 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 23 | v = os.environ.get(k) 24 | if v is not None: 25 | env[k] = v 26 | # LANGUAGE is used on win32 27 | env['LANGUAGE'] = 'C' 28 | env['LANG'] = 'C' 29 | env['LC_ALL'] = 'C' 30 | out = subprocess.Popen( 31 | cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 32 | return out 33 | 34 | try: 35 | out = 
_minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 36 | sha = out.strip().decode('ascii') 37 | except OSError: 38 | sha = 'unknown' 39 | 40 | return sha 41 | 42 | 43 | def get_hash(): 44 | if os.path.exists('.git'): 45 | sha = get_git_hash()[:7] 46 | elif os.path.exists(version_file): 47 | try: 48 | from smpl_visualizer.version import __version__ 49 | sha = __version__.split('+')[-1] 50 | except ImportError: 51 | raise ImportError('Unable to get git version') 52 | else: 53 | sha = 'unknown' 54 | 55 | return sha 56 | 57 | 58 | def write_version_py(): 59 | content = """# GENERATED VERSION FILE 60 | # TIME: {} 61 | 62 | __version__ = '{}' 63 | short_version = '{}' 64 | """ 65 | sha = get_hash() 66 | VERSION = SHORT_VERSION + '+' + sha 67 | 68 | with open(version_file, 'w') as f: 69 | f.write(content.format(time.asctime(), VERSION, SHORT_VERSION)) 70 | 71 | 72 | def make_cuda_ext(name, module, sources): 73 | 74 | return CUDAExtension( 75 | name='{}.{}'.format(module, name), 76 | sources=[os.path.join(*module.split('.'), p) for p in sources], 77 | extra_compile_args={ 78 | 'cxx': [], 79 | 'nvcc': [ 80 | '-D__CUDA_NO_HALF_OPERATORS__', 81 | '-D__CUDA_NO_HALF_CONVERSIONS__', 82 | '-D__CUDA_NO_HALF2_OPERATORS__', 83 | ] 84 | }) 85 | 86 | 87 | def get_version(): 88 | with open(version_file, 'r') as f: 89 | exec(compile(f.read(), version_file, 'exec')) 90 | return locals()['__version__'] 91 | 92 | 93 | if __name__ == '__main__': 94 | write_version_py() 95 | setup( 96 | name='smpl_visualizer', 97 | version=get_version(), 98 | description='Library for visualizing 3D humans', 99 | long_description='Library for visualizing 3D humans', 100 | keywords='human pose and shape visualization', 101 | packages=find_packages(exclude=('data', 'exp',)), 102 | package_data={'': ['*.json', '*.txt']}, 103 | classifiers=[ 104 | 'Development Status :: 4 - Beta', 105 | 'License :: OSI Approved :: Apache Software License', 106 | 'Operating System :: OS Independent', 107 | 'Programming Language :: 
Python :: 2', 108 | 'Programming Language :: Python :: 2.7', 109 | 'Programming Language :: Python :: 3', 110 | 'Programming Language :: Python :: 3.4', 111 | 'Programming Language :: Python :: 3.5', 112 | 'Programming Language :: Python :: 3.6', 113 | 'Programming Language :: Python :: 3.7', 114 | ], 115 | license='GPLv3', 116 | setup_requires=['pytest-runner'], 117 | tests_require=['pytest'], 118 | install_requires=[ 119 | 'smplx[all]', 'numpy', 'opencv-python', 'pyvista', 'pyrender', 120 | 'chumpy', 'scenepic' 121 | ], 122 | cmdclass={'build_ext': BuildExtension}, 123 | zip_safe=False) -------------------------------------------------------------------------------- /smpl_visualizer/konia_transform.py: -------------------------------------------------------------------------------- 1 | # This script is borrowed and extended from https://github.com/kornia/kornia/blob/master/kornia/geometry/conversions.py 2 | # Adhere to their licence to use this script 3 | 4 | import enum 5 | import warnings 6 | from typing import Tuple 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from numpy import pi 12 | 13 | __all__ = [ 14 | # functional api 15 | "rad2deg", 16 | "deg2rad", 17 | "pol2cart", 18 | "cart2pol", 19 | "convert_points_from_homogeneous", 20 | "convert_points_to_homogeneous", 21 | "convert_affinematrix_to_homography", 22 | "convert_affinematrix_to_homography3d", 23 | "angle_axis_to_rotation_matrix", 24 | "angle_axis_to_quaternion", 25 | "rotation_matrix_to_angle_axis", 26 | "rotation_matrix_to_quaternion", 27 | "quaternion_to_angle_axis", 28 | "quaternion_to_rotation_matrix", 29 | "quaternion_log_to_exp", 30 | "quaternion_exp_to_log", 31 | "denormalize_pixel_coordinates", 32 | "normalize_pixel_coordinates", 33 | "normalize_quaternion", 34 | "denormalize_pixel_coordinates3d", 35 | "normalize_pixel_coordinates3d", 36 | ] 37 | 38 | 39 | class QuaternionCoeffOrder(enum.Enum): 40 | XYZW = 'xyzw' 41 | WXYZ = 'wxyz' 42 | 43 | 44 | def torch_safe_atan2(y, 
x, eps: float = 1e-6): 45 | y = y.clone() 46 | y[(y.abs() < eps) & (x.abs() < eps)] += eps 47 | return torch.atan2(y, x) 48 | 49 | 50 | def rad2deg(tensor: torch.Tensor) -> torch.Tensor: 51 | r"""Function that converts angles from radians to degrees. 52 | 53 | Args: 54 | tensor: Tensor of arbitrary shape. 55 | 56 | Returns: 57 | Tensor with same shape as input. 58 | 59 | Example: 60 | >>> input = torch.tensor(3.1415926535) * torch.rand(1, 3, 3) 61 | >>> output = rad2deg(input) 62 | """ 63 | if not isinstance(tensor, torch.Tensor): 64 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(tensor))) 65 | 66 | return 180.0 * tensor / pi.to(tensor.device).type(tensor.dtype) 67 | 68 | 69 | def deg2rad(tensor: torch.Tensor) -> torch.Tensor: 70 | r"""Function that converts angles from degrees to radians. 71 | 72 | Args: 73 | tensor: Tensor of arbitrary shape. 74 | 75 | Returns: 76 | tensor with same shape as input. 77 | 78 | Examples: 79 | >>> input = 360. * torch.rand(1, 3, 3) 80 | >>> output = deg2rad(input) 81 | """ 82 | if not isinstance(tensor, torch.Tensor): 83 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(tensor))) 84 | 85 | return tensor * pi.to(tensor.device).type(tensor.dtype) / 180.0 86 | 87 | 88 | def pol2cart(rho: torch.Tensor, phi: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 89 | r"""Function that converts polar coordinates to cartesian coordinates. 90 | 91 | Args: 92 | rho: Tensor of arbitrary shape. 93 | phi: Tensor of same arbitrary shape. 94 | 95 | Returns: 96 | Tensor with same shape as input. 97 | 98 | Example: 99 | >>> rho = torch.rand(1, 3, 3) 100 | >>> phi = torch.rand(1, 3, 3) 101 | >>> x, y = pol2cart(rho, phi) 102 | """ 103 | if not (isinstance(rho, torch.Tensor) & isinstance(phi, torch.Tensor)): 104 | raise TypeError("Input type is not a torch.Tensor. 
Got {}, {}".format(type(rho), type(phi))) 105 | 106 | x = rho * torch.cos(phi) 107 | y = rho * torch.sin(phi) 108 | return x, y 109 | 110 | 111 | def cart2pol(x: torch.Tensor, y: torch.Tensor, eps: float = 1.0e-8) -> Tuple[torch.Tensor, torch.Tensor]: 112 | """Function that converts cartesian coordinates to polar coordinates. 113 | 114 | Args: 115 | rho: Tensor of arbitrary shape. 116 | phi: Tensor of same arbitrary shape. 117 | eps: To avoid division by zero. 118 | 119 | Returns: 120 | Tensor with same shape as input. 121 | 122 | Example: 123 | >>> x = torch.rand(1, 3, 3) 124 | >>> y = torch.rand(1, 3, 3) 125 | >>> rho, phi = cart2pol(x, y) 126 | """ 127 | if not (isinstance(x, torch.Tensor) & isinstance(y, torch.Tensor)): 128 | raise TypeError("Input type is not a torch.Tensor. Got {}, {}".format(type(x), type(y))) 129 | 130 | rho = torch.sqrt((x ** 2 + y ** 2).clamp_min(eps)) 131 | phi = torch_safe_atan2(y, x) 132 | return rho, phi 133 | 134 | 135 | def convert_points_from_homogeneous(points: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: 136 | r"""Function that converts points from homogeneous to Euclidean space. 137 | 138 | Args: 139 | points: the points to be transformed. 140 | eps: to avoid division by zero. 141 | 142 | Returns: 143 | the points in Euclidean space. 144 | 145 | Examples: 146 | >>> input = torch.rand(2, 4, 3) # BxNx3 147 | >>> output = convert_points_from_homogeneous(input) # BxNx2 148 | """ 149 | if not isinstance(points, torch.Tensor): 150 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(points))) 151 | 152 | if len(points.shape) < 2: 153 | raise ValueError("Input must be at least a 2D tensor. 
Got {}".format(points.shape)) 154 | 155 | # we check for points at max_val 156 | z_vec: torch.Tensor = points[..., -1:] 157 | 158 | # set the results of division by zeror/near-zero to 1.0 159 | # follow the convention of opencv: 160 | # https://github.com/opencv/opencv/pull/14411/files 161 | mask: torch.Tensor = torch.abs(z_vec) > eps 162 | scale = torch.where(mask, 1.0 / (z_vec + eps), torch.ones_like(z_vec)) 163 | 164 | return scale * points[..., :-1] 165 | 166 | 167 | def convert_points_to_homogeneous(points: torch.Tensor) -> torch.Tensor: 168 | r"""Function that converts points from Euclidean to homogeneous space. 169 | 170 | Args: 171 | points: the points to be transformed. 172 | 173 | Returns: 174 | the points in homogeneous coordinates. 175 | 176 | Examples: 177 | >>> input = torch.rand(2, 4, 3) # BxNx3 178 | >>> output = convert_points_to_homogeneous(input) # BxNx4 179 | """ 180 | if not isinstance(points, torch.Tensor): 181 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(points))) 182 | if len(points.shape) < 2: 183 | raise ValueError("Input must be at least a 2D tensor. Got {}".format(points.shape)) 184 | 185 | return torch.nn.functional.pad(points, [0, 1], "constant", 1.0) 186 | 187 | 188 | def _convert_affinematrix_to_homography_impl(A: torch.Tensor) -> torch.Tensor: 189 | H: torch.Tensor = torch.nn.functional.pad(A, [0, 0, 0, 1], "constant", value=0.0) 190 | H[..., -1, -1] += 1.0 191 | return H 192 | 193 | 194 | def convert_affinematrix_to_homography(A: torch.Tensor) -> torch.Tensor: 195 | r"""Function that converts batch of affine matrices. 196 | 197 | Args: 198 | A: the affine matrix with shape :math:`(B,2,3)`. 199 | 200 | Returns: 201 | the homography matrix with shape of :math:`(B,3,3)`. 
202 | 203 | Examples: 204 | >>> input = torch.rand(2, 2, 3) # Bx2x3 205 | >>> output = convert_affinematrix_to_homography(input) # Bx3x3 206 | """ 207 | if not isinstance(A, torch.Tensor): 208 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(A))) 209 | if not (len(A.shape) == 3 and A.shape[-2:] == (2, 3)): 210 | raise ValueError("Input matrix must be a Bx2x3 tensor. Got {}".format(A.shape)) 211 | return _convert_affinematrix_to_homography_impl(A) 212 | 213 | 214 | def convert_affinematrix_to_homography3d(A: torch.Tensor) -> torch.Tensor: 215 | r"""Function that converts batch of 3d affine matrices. 216 | 217 | Args: 218 | A: the affine matrix with shape :math:`(B,3,4)`. 219 | 220 | Returns: 221 | the homography matrix with shape of :math:`(B,4,4)`. 222 | 223 | Examples: 224 | >>> input = torch.rand(2, 3, 4) # Bx3x4 225 | >>> output = convert_affinematrix_to_homography3d(input) # Bx4x4 226 | """ 227 | if not isinstance(A, torch.Tensor): 228 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(A))) 229 | if not (len(A.shape) == 3 and A.shape[-2:] == (3, 4)): 230 | raise ValueError("Input matrix must be a Bx3x4 tensor. Got {}".format(A.shape)) 231 | return _convert_affinematrix_to_homography_impl(A) 232 | 233 | 234 | def angle_axis_to_rotation_matrix(angle_axis: torch.Tensor) -> torch.Tensor: 235 | r"""Convert 3d vector of axis-angle rotation to 3x3 rotation matrix. 236 | 237 | Args: 238 | angle_axis: tensor of 3d vector of axis-angle rotations. 239 | 240 | Returns: 241 | tensor of 3x3 rotation matrices. 242 | 243 | Shape: 244 | - Input: :math:`(N, 3)` 245 | - Output: :math:`(N, 3, 3)` 246 | 247 | Example: 248 | >>> input = torch.rand(1, 3) # Nx3 249 | >>> output = angle_axis_to_rotation_matrix(input) # Nx3x3 250 | """ 251 | if not isinstance(angle_axis, torch.Tensor): 252 | raise TypeError("Input type is not a torch.Tensor. 
Got {}".format(type(angle_axis))) 253 | 254 | if not angle_axis.shape[-1] == 3: 255 | raise ValueError("Input size must be a (*, 3) tensor. Got {}".format(angle_axis.shape)) 256 | 257 | orig_shape = angle_axis.shape 258 | angle_axis = angle_axis.reshape(-1, 3) 259 | 260 | def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6): 261 | # We want to be careful to only evaluate the square root if the 262 | # norm of the angle_axis vector is greater than zero. Otherwise 263 | # we get a division by zero. 264 | k_one = 1.0 265 | theta = torch.sqrt(theta2.clamp_min(eps)) 266 | wxyz = angle_axis / (theta + eps) 267 | wx, wy, wz = torch.chunk(wxyz, 3, dim=1) 268 | cos_theta = torch.cos(theta) 269 | sin_theta = torch.sin(theta) 270 | 271 | r00 = cos_theta + wx * wx * (k_one - cos_theta) 272 | r10 = wz * sin_theta + wx * wy * (k_one - cos_theta) 273 | r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta) 274 | r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta 275 | r11 = cos_theta + wy * wy * (k_one - cos_theta) 276 | r21 = wx * sin_theta + wy * wz * (k_one - cos_theta) 277 | r02 = wy * sin_theta + wx * wz * (k_one - cos_theta) 278 | r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta) 279 | r22 = cos_theta + wz * wz * (k_one - cos_theta) 280 | rotation_matrix = torch.cat([r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1) 281 | return rotation_matrix.view(-1, 3, 3) 282 | 283 | def _compute_rotation_matrix_taylor(angle_axis): 284 | rx, ry, rz = torch.chunk(angle_axis, 3, dim=1) 285 | k_one = torch.ones_like(rx) 286 | rotation_matrix = torch.cat([k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1) 287 | return rotation_matrix.view(-1, 3, 3) 288 | 289 | # stolen from ceres/rotation.h 290 | 291 | _angle_axis = torch.unsqueeze(angle_axis, dim=1) 292 | theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2)) 293 | theta2 = torch.squeeze(theta2, dim=1) 294 | 295 | # compute rotation matrices 296 | rotation_matrix_normal = _compute_rotation_matrix(angle_axis, 
theta2) 297 | rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis) 298 | 299 | # create mask to handle both cases 300 | eps = 1e-6 301 | mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device) 302 | mask_pos = (mask).type_as(theta2) 303 | mask_neg = (mask == False).type_as(theta2) # noqa 304 | 305 | # create output pose matrix 306 | batch_size = angle_axis.shape[0] 307 | rotation_matrix = torch.eye(3).to(angle_axis.device).type_as(angle_axis) 308 | rotation_matrix = rotation_matrix.view(1, 3, 3).repeat(batch_size, 1, 1) 309 | # fill output matrix with masked values 310 | rotation_matrix[..., :3, :3] = mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor 311 | 312 | rotation_matrix = rotation_matrix.view(orig_shape[:-1] + (3, 3)) 313 | return rotation_matrix # Nx3x3 314 | 315 | 316 | def rotation_matrix_to_angle_axis(rotation_matrix: torch.Tensor) -> torch.Tensor: 317 | r"""Convert 3x3 rotation matrix to Rodrigues vector. 318 | 319 | Args: 320 | rotation_matrix: rotation matrix. 321 | 322 | Returns: 323 | Rodrigues vector transformation. 324 | 325 | Shape: 326 | - Input: :math:`(N, 3, 3)` 327 | - Output: :math:`(N, 3)` 328 | 329 | Example: 330 | >>> input = torch.rand(2, 3, 3) # Nx3x3 331 | >>> output = rotation_matrix_to_angle_axis(input) # Nx3 332 | """ 333 | if not isinstance(rotation_matrix, torch.Tensor): 334 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(rotation_matrix)}") 335 | 336 | if not rotation_matrix.shape[-2:] == (3, 3): 337 | raise ValueError(f"Input size must be a (*, 3, 3) tensor. 
Got {rotation_matrix.shape}") 338 | quaternion: torch.Tensor = rotation_matrix_to_quaternion(rotation_matrix, order=QuaternionCoeffOrder.WXYZ) 339 | return quaternion_to_angle_axis(quaternion, order=QuaternionCoeffOrder.WXYZ) 340 | 341 | 342 | 343 | def safe_zero_division(numerator: torch.Tensor, denominator: torch.Tensor, eps: float = 1.0e-6) -> torch.Tensor: 344 | denominator = denominator.clone() 345 | denominator[denominator.abs() < eps] += eps 346 | return numerator / denominator 347 | 348 | 349 | def rotation_matrix_to_quaternion( 350 | rotation_matrix: torch.Tensor, eps: float = 1.0e-6, order: QuaternionCoeffOrder = QuaternionCoeffOrder.WXYZ 351 | ) -> torch.Tensor: 352 | r"""Convert 3x3 rotation matrix to 4d quaternion vector. 353 | 354 | The quaternion vector has components in (w, x, y, z) or (x, y, z, w) format. 355 | 356 | .. note:: 357 | The (x, y, z, w) order is going to be deprecated in favor of efficiency. 358 | 359 | Args: 360 | rotation_matrix: the rotation matrix to convert. 361 | eps: small value to avoid zero division. 362 | order: quaternion coefficient order. Note: 'xyzw' will be deprecated in favor of 'wxyz'. 363 | 364 | Return: 365 | the rotation in quaternion. 366 | 367 | Shape: 368 | - Input: :math:`(*, 3, 3)` 369 | - Output: :math:`(*, 4)` 370 | 371 | Example: 372 | >>> input = torch.rand(4, 3, 3) # Nx3x3 373 | >>> output = rotation_matrix_to_quaternion(input, eps=torch.finfo(input.dtype).eps, 374 | ... order=QuaternionCoeffOrder.WXYZ) # Nx4 375 | """ 376 | if not isinstance(rotation_matrix, torch.Tensor): 377 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(rotation_matrix)}") 378 | 379 | if not rotation_matrix.shape[-2:] == (3, 3): 380 | raise ValueError(f"Input size must be a (*, 3, 3) tensor. 
Got {rotation_matrix.shape}") 381 | 382 | if not torch.jit.is_scripting(): 383 | if order.name not in QuaternionCoeffOrder.__members__.keys(): 384 | raise ValueError(f"order must be one of {QuaternionCoeffOrder.__members__.keys()}") 385 | 386 | if order == QuaternionCoeffOrder.XYZW: 387 | warnings.warn( 388 | "`XYZW` quaternion coefficient order is deprecated and" 389 | " will be removed after > 0.6. " 390 | "Please use `QuaternionCoeffOrder.WXYZ` instead." 391 | ) 392 | 393 | rotation_matrix_vec: torch.Tensor = rotation_matrix.view(*rotation_matrix.shape[:-2], 9) 394 | 395 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.chunk(rotation_matrix_vec, chunks=9, dim=-1) 396 | 397 | trace: torch.Tensor = m00 + m11 + m22 398 | 399 | def trace_positive_cond(): 400 | sq = torch.sqrt((trace + 1.0).clamp_min(eps)) * 2.0 # sq = 4 * qw. 401 | qw = 0.25 * sq 402 | qx = safe_zero_division(m21 - m12, sq) 403 | qy = safe_zero_division(m02 - m20, sq) 404 | qz = safe_zero_division(m10 - m01, sq) 405 | if order == QuaternionCoeffOrder.XYZW: 406 | return torch.cat((qx, qy, qz, qw), dim=-1) 407 | return torch.cat((qw, qx, qy, qz), dim=-1) 408 | 409 | def cond_1(): 410 | sq = torch.sqrt((1.0 + m00 - m11 - m22).clamp_min(eps)) * 2.0 # sq = 4 * qx. 411 | qw = safe_zero_division(m21 - m12, sq) 412 | qx = 0.25 * sq 413 | qy = safe_zero_division(m01 + m10, sq) 414 | qz = safe_zero_division(m02 + m20, sq) 415 | if order == QuaternionCoeffOrder.XYZW: 416 | return torch.cat((qx, qy, qz, qw), dim=-1) 417 | return torch.cat((qw, qx, qy, qz), dim=-1) 418 | 419 | def cond_2(): 420 | sq = torch.sqrt((1.0 + m11 - m00 - m22).clamp_min(eps)) * 2.0 # sq = 4 * qy. 
        qw = safe_zero_division(m02 - m20, sq)
        qx = safe_zero_division(m01 + m10, sq)
        qy = 0.25 * sq
        qz = safe_zero_division(m12 + m21, sq)
        if order == QuaternionCoeffOrder.XYZW:
            return torch.cat((qx, qy, qz, qw), dim=-1)
        return torch.cat((qw, qx, qy, qz), dim=-1)

    def cond_3():
        sq = torch.sqrt((1.0 + m22 - m00 - m11).clamp_min(eps)) * 2.0  # sq = 4 * qz.
        qw = safe_zero_division(m10 - m01, sq)
        qx = safe_zero_division(m02 + m20, sq)
        qy = safe_zero_division(m12 + m21, sq)
        qz = 0.25 * sq
        if order == QuaternionCoeffOrder.XYZW:
            return torch.cat((qx, qy, qz, qw), dim=-1)
        return torch.cat((qw, qx, qy, qz), dim=-1)

    # All four candidate formulas are evaluated for every element; torch.where
    # then selects per element: the w-based formula when the trace is positive,
    # otherwise the formula built around the largest diagonal entry.
    where_2 = torch.where(m11 > m22, cond_2(), cond_3())
    where_1 = torch.where((m00 > m11) & (m00 > m22), cond_1(), where_2)

    quaternion: torch.Tensor = torch.where(trace > 0.0, trace_positive_cond(), where_1)
    return quaternion


def normalize_quaternion(quaternion: torch.Tensor, eps: float = 1.0e-12) -> torch.Tensor:
    r"""Normalizes a quaternion to unit L2 norm over its last dimension.

    Normalization only uses the last dimension, so the coefficient order,
    (x, y, z, w) or (w, x, y, z), does not matter.

    Args:
        quaternion: a tensor containing a quaternion to be normalized.
          The tensor can be of shape :math:`(*, 4)`.
        eps: small value to avoid division by zero.

    Return:
        the normalized quaternion of shape :math:`(*, 4)`.

    Raises:
        TypeError: if ``quaternion`` is not a ``torch.Tensor``.
        ValueError: if the last dimension is not 4.

    Example:
        >>> quaternion = torch.tensor((1., 0., 1., 0.))
        >>> normalize_quaternion(quaternion)
        tensor([0.7071, 0.0000, 0.7071, 0.0000])
    """
    if not isinstance(quaternion, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion)))

    if not quaternion.shape[-1] == 4:
        raise ValueError("Input must be a tensor of shape (*, 4). Got {}".format(quaternion.shape))
    return F.normalize(quaternion, p=2.0, dim=-1, eps=eps)


# based on:
# https://github.com/matthew-brett/transforms3d/blob/8965c48401d9e8e66b6a8c37c65f2fc200a076fa/transforms3d/quaternions.py#L101
# https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py#L247


def quaternion_to_rotation_matrix(
    quaternion: torch.Tensor, order: QuaternionCoeffOrder = QuaternionCoeffOrder.WXYZ
) -> torch.Tensor:
    r"""Converts a quaternion to a rotation matrix.

    The quaternion should be in (x, y, z, w) or (w, x, y, z) format. It is
    normalized before conversion, so non-unit inputs are accepted.

    Args:
        quaternion: a tensor containing a quaternion to be converted.
          The tensor can be of shape :math:`(*, 4)`.
        order: quaternion coefficient order. Note: 'xyzw' will be deprecated in favor of 'wxyz'.

    Return:
        the rotation matrix of shape :math:`(*, 3, 3)`.

    Example:
        >>> quaternion = torch.tensor((0., 0., 0., 1.))
        >>> quaternion_to_rotation_matrix(quaternion, order=QuaternionCoeffOrder.WXYZ)
        tensor([[-1., 0., 0.],
                [ 0., -1., 0.],
                [ 0., 0., 1.]])
    """
    if not isinstance(quaternion, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(quaternion)}")

    if not quaternion.shape[-1] == 4:
        raise ValueError(f"Input must be a tensor of shape (*, 4). Got {quaternion.shape}")

    if not torch.jit.is_scripting():
        if order.name not in QuaternionCoeffOrder.__members__.keys():
            raise ValueError(f"order must be one of {QuaternionCoeffOrder.__members__.keys()}")

    if order == QuaternionCoeffOrder.XYZW:
        warnings.warn(
            "`XYZW` quaternion coefficient order is deprecated and"
            " will be removed after > 0.6. "
            "Please use `QuaternionCoeffOrder.WXYZ` instead."
514 | ) 515 | 516 | # normalize the input quaternion 517 | quaternion_norm: torch.Tensor = normalize_quaternion(quaternion) 518 | 519 | # unpack the normalized quaternion components 520 | if order == QuaternionCoeffOrder.XYZW: 521 | x, y, z, w = torch.chunk(quaternion_norm, chunks=4, dim=-1) 522 | else: 523 | w, x, y, z = torch.chunk(quaternion_norm, chunks=4, dim=-1) 524 | 525 | # compute the actual conversion 526 | tx: torch.Tensor = 2.0 * x 527 | ty: torch.Tensor = 2.0 * y 528 | tz: torch.Tensor = 2.0 * z 529 | twx: torch.Tensor = tx * w 530 | twy: torch.Tensor = ty * w 531 | twz: torch.Tensor = tz * w 532 | txx: torch.Tensor = tx * x 533 | txy: torch.Tensor = ty * x 534 | txz: torch.Tensor = tz * x 535 | tyy: torch.Tensor = ty * y 536 | tyz: torch.Tensor = tz * y 537 | tzz: torch.Tensor = tz * z 538 | one: torch.Tensor = torch.tensor(1.0) 539 | 540 | matrix: torch.Tensor = torch.stack( 541 | ( 542 | one - (tyy + tzz), 543 | txy - twz, 544 | txz + twy, 545 | txy + twz, 546 | one - (txx + tzz), 547 | tyz - twx, 548 | txz - twy, 549 | tyz + twx, 550 | one - (txx + tyy), 551 | ), 552 | dim=-1, 553 | ).view(quaternion.shape[:-1] + (3, 3)) 554 | 555 | # if len(quaternion.shape) == 1: 556 | # matrix = torch.squeeze(matrix, dim=0) 557 | return matrix 558 | 559 | 560 | def quaternion_to_angle_axis( 561 | quaternion: torch.Tensor, eps: float = 1.0e-6, order: QuaternionCoeffOrder = QuaternionCoeffOrder.WXYZ 562 | ) -> torch.Tensor: 563 | """Convert quaternion vector to angle axis of rotation. 564 | 565 | The quaternion should be in (x, y, z, w) or (w, x, y, z) format. 566 | 567 | Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h 568 | 569 | Args: 570 | quaternion: tensor with quaternions. 571 | order: quaternion coefficient order. Note: 'xyzw' will be deprecated in favor of 'wxyz'. 572 | 573 | Return: 574 | tensor with angle axis of rotation. 
575 | 576 | Shape: 577 | - Input: :math:`(*, 4)` where `*` means, any number of dimensions 578 | - Output: :math:`(*, 3)` 579 | 580 | Example: 581 | >>> quaternion = torch.rand(2, 4) # Nx4 582 | >>> angle_axis = quaternion_to_angle_axis(quaternion) # Nx3 583 | """ 584 | 585 | if not quaternion.shape[-1] == 4: 586 | raise ValueError(f"Input must be a tensor of shape Nx4 or 4. Got {quaternion.shape}") 587 | 588 | if not torch.jit.is_scripting(): 589 | if order.name not in QuaternionCoeffOrder.__members__.keys(): 590 | raise ValueError(f"order must be one of {QuaternionCoeffOrder.__members__.keys()}") 591 | 592 | if order == QuaternionCoeffOrder.XYZW: 593 | warnings.warn( 594 | "`XYZW` quaternion coefficient order is deprecated and" 595 | " will be removed after > 0.6. " 596 | "Please use `QuaternionCoeffOrder.WXYZ` instead." 597 | ) 598 | # unpack input and compute conversion 599 | q1: torch.Tensor = torch.tensor([]) 600 | q2: torch.Tensor = torch.tensor([]) 601 | q3: torch.Tensor = torch.tensor([]) 602 | cos_theta: torch.Tensor = torch.tensor([]) 603 | 604 | if order == QuaternionCoeffOrder.XYZW: 605 | q1 = quaternion[..., 0] 606 | q2 = quaternion[..., 1] 607 | q3 = quaternion[..., 2] 608 | cos_theta = quaternion[..., 3] 609 | else: 610 | cos_theta = quaternion[..., 0] 611 | q1 = quaternion[..., 1] 612 | q2 = quaternion[..., 2] 613 | q3 = quaternion[..., 3] 614 | 615 | sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 616 | 617 | sin_theta: torch.Tensor = torch.sqrt((sin_squared_theta).clamp_min(eps)) 618 | two_theta: torch.Tensor = 2.0 * torch.where( 619 | cos_theta < 0.0, torch_safe_atan2(-sin_theta, -cos_theta), torch_safe_atan2(sin_theta, cos_theta) 620 | ) 621 | 622 | k_pos: torch.Tensor = safe_zero_division(two_theta, sin_theta, eps) 623 | k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta) 624 | k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) 625 | 626 | angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3] 627 | 
    # in-place writes into the [..., :3] view of a freshly allocated zeros tensor
    angle_axis[..., 0] += q1 * k
    angle_axis[..., 1] += q2 * k
    angle_axis[..., 2] += q3 * k
    return angle_axis


def quaternion_log_to_exp(
    quaternion: torch.Tensor, eps: float = 1.0e-6, order: QuaternionCoeffOrder = QuaternionCoeffOrder.WXYZ
) -> torch.Tensor:
    r"""Applies exponential map to log quaternion.

    The input is the 3-vector (log) part. ``order`` decides whether the
    scalar part of the output goes last (x, y, z, w) or first (w, x, y, z).

    Args:
        quaternion: a tensor containing a quaternion to be converted.
          The tensor can be of shape :math:`(*, 3)`.
        eps: lower clamp on the vector norm, used as the divisor.
        order: quaternion coefficient order. Note: 'xyzw' will be deprecated in favor of 'wxyz'.

    Return:
        the quaternion exponential map of shape :math:`(*, 4)`.

    Example:
        >>> quaternion = torch.tensor((0., 0., 0.))
        >>> quaternion_log_to_exp(quaternion, eps=torch.finfo(quaternion.dtype).eps,
        ...                       order=QuaternionCoeffOrder.WXYZ)
        tensor([1., 0., 0., 0.])
    """
    if not isinstance(quaternion, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(quaternion)}")

    if not quaternion.shape[-1] == 3:
        raise ValueError(f"Input must be a tensor of shape (*, 3). Got {quaternion.shape}")

    if not torch.jit.is_scripting():
        if order.name not in QuaternionCoeffOrder.__members__.keys():
            raise ValueError(f"order must be one of {QuaternionCoeffOrder.__members__.keys()}")

    if order == QuaternionCoeffOrder.XYZW:
        warnings.warn(
            "`XYZW` quaternion coefficient order is deprecated and"
            " will be removed after > 0.6. "
            "Please use `QuaternionCoeffOrder.WXYZ` instead."
        )

    # compute quaternion norm (clamped so the division below stays finite)
    norm_q: torch.Tensor = torch.norm(quaternion, p=2, dim=-1, keepdim=True).clamp(min=eps)

    # compute scalar and vector
    quaternion_vector: torch.Tensor = quaternion * torch.sin(norm_q) / norm_q
    quaternion_scalar: torch.Tensor = torch.cos(norm_q)

    # compose quaternion and return
    # placeholder so both branches assign a Tensor (presumably for TorchScript typing -- confirm)
    quaternion_exp: torch.Tensor = torch.tensor([])
    if order == QuaternionCoeffOrder.XYZW:
        quaternion_exp = torch.cat((quaternion_vector, quaternion_scalar), dim=-1)
    else:
        quaternion_exp = torch.cat((quaternion_scalar, quaternion_vector), dim=-1)

    return quaternion_exp


def quaternion_exp_to_log(
    quaternion: torch.Tensor, eps: float = 1.0e-6, order: QuaternionCoeffOrder = QuaternionCoeffOrder.WXYZ
) -> torch.Tensor:
    r"""Applies the log map to a quaternion.

    The quaternion should be in (x, y, z, w) or (w, x, y, z) format, selected
    by ``order``. The input is not normalized here; the result equals the usual
    log map only for unit quaternions.

    Args:
        quaternion: a tensor containing a quaternion to be converted.
          The tensor can be of shape :math:`(*, 4)`.
        eps: A small number for clamping.
        order: quaternion coefficient order. Note: 'xyzw' will be deprecated in favor of 'wxyz'.

    Return:
        the quaternion log map of shape :math:`(*, 3)`.

    Example:
        >>> quaternion = torch.tensor((1., 0., 0., 0.))
        >>> quaternion_exp_to_log(quaternion, eps=torch.finfo(quaternion.dtype).eps,
        ...                       order=QuaternionCoeffOrder.WXYZ)
        tensor([0., 0., 0.])
    """
    if not isinstance(quaternion, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(quaternion)}")

    if not quaternion.shape[-1] == 4:
        raise ValueError(f"Input must be a tensor of shape (*, 4). Got {quaternion.shape}")

    if not torch.jit.is_scripting():
        if order.name not in QuaternionCoeffOrder.__members__.keys():
            raise ValueError(f"order must be one of {QuaternionCoeffOrder.__members__.keys()}")

    if order == QuaternionCoeffOrder.XYZW:
        warnings.warn(
            "`XYZW` quaternion coefficient order is deprecated and"
            " will be removed after > 0.6. "
            "Please use `QuaternionCoeffOrder.WXYZ` instead."
        )

    # unpack quaternion vector and scalar
    quaternion_vector: torch.Tensor = torch.tensor([])
    quaternion_scalar: torch.Tensor = torch.tensor([])

    if order == QuaternionCoeffOrder.XYZW:
        quaternion_vector = quaternion[..., 0:3]
        quaternion_scalar = quaternion[..., 3:4]
    else:
        quaternion_scalar = quaternion[..., 0:1]
        quaternion_vector = quaternion[..., 1:4]

    # compute quaternion norm
    norm_q: torch.Tensor = torch.norm(quaternion_vector, p=2, dim=-1, keepdim=True).clamp(min=eps)

    # apply log map; the scalar is clamped strictly inside [-1, 1] so acos stays finite
    quaternion_log: torch.Tensor = (
        quaternion_vector * torch.acos(torch.clamp(quaternion_scalar, min=-1.0 + eps, max=1.0 - eps)) / norm_q
    )

    return quaternion_log


# based on:
# https://github.com/facebookresearch/QuaterNet/blob/master/common/quaternion.py#L138


def angle_axis_to_quaternion(
    angle_axis: torch.Tensor, eps: float = 1.0e-6, order: QuaternionCoeffOrder = QuaternionCoeffOrder.WXYZ
) -> torch.Tensor:
    r"""Convert an angle axis to a quaternion.

    The quaternion vector has components in (x, y, z, w) or (w, x, y, z) format.

    Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

    Args:
        angle_axis: tensor with angle axis.
        order: quaternion coefficient order. Note: 'xyzw' will be deprecated in favor of 'wxyz'.

    Return:
        tensor with quaternion.
768 | 769 | Shape: 770 | - Input: :math:`(*, 3)` where `*` means, any number of dimensions 771 | - Output: :math:`(*, 4)` 772 | 773 | Example: 774 | >>> angle_axis = torch.rand(2, 3) # Nx3 775 | >>> quaternion = angle_axis_to_quaternion(angle_axis, order=QuaternionCoeffOrder.WXYZ) # Nx4 776 | """ 777 | 778 | if not angle_axis.shape[-1] == 3: 779 | raise ValueError(f"Input must be a tensor of shape Nx3 or 3. Got {angle_axis.shape}") 780 | 781 | if not torch.jit.is_scripting(): 782 | if order.name not in QuaternionCoeffOrder.__members__.keys(): 783 | raise ValueError(f"order must be one of {QuaternionCoeffOrder.__members__.keys()}") 784 | 785 | if order == QuaternionCoeffOrder.XYZW: 786 | warnings.warn( 787 | "`XYZW` quaternion coefficient order is deprecated and" 788 | " will be removed after > 0.6. " 789 | "Please use `QuaternionCoeffOrder.WXYZ` instead." 790 | ) 791 | 792 | # unpack input and compute conversion 793 | a0: torch.Tensor = angle_axis[..., 0:1] 794 | a1: torch.Tensor = angle_axis[..., 1:2] 795 | a2: torch.Tensor = angle_axis[..., 2:3] 796 | theta_squared: torch.Tensor = a0 * a0 + a1 * a1 + a2 * a2 797 | 798 | theta: torch.Tensor = torch.sqrt((theta_squared).clamp_min(eps)) 799 | half_theta: torch.Tensor = theta * 0.5 800 | 801 | mask: torch.Tensor = theta_squared > 0.0 802 | ones: torch.Tensor = torch.ones_like(half_theta) 803 | 804 | k_neg: torch.Tensor = 0.5 * ones 805 | k_pos: torch.Tensor = safe_zero_division(torch.sin(half_theta), theta, eps) 806 | k: torch.Tensor = torch.where(mask, k_pos, k_neg) 807 | w: torch.Tensor = torch.where(mask, torch.cos(half_theta), ones) 808 | 809 | quaternion: torch.Tensor = torch.zeros( 810 | size=angle_axis.shape[:-1] + (4,), dtype=angle_axis.dtype, device=angle_axis.device 811 | ) 812 | if order == QuaternionCoeffOrder.XYZW: 813 | quaternion[..., 0:1] = a0 * k 814 | quaternion[..., 1:2] = a1 * k 815 | quaternion[..., 2:3] = a2 * k 816 | quaternion[..., 3:4] = w 817 | else: 818 | quaternion[..., 1:2] = a0 * k 819 | 
quaternion[..., 2:3] = a1 * k 820 | quaternion[..., 3:4] = a2 * k 821 | quaternion[..., 0:1] = w 822 | return quaternion 823 | 824 | 825 | # based on: 826 | # https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py#L65-L71 827 | 828 | 829 | def normalize_pixel_coordinates( 830 | pixel_coordinates: torch.Tensor, height: int, width: int, eps: float = 1e-8 831 | ) -> torch.Tensor: 832 | r"""Normalize pixel coordinates between -1 and 1. 833 | 834 | Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1). 835 | 836 | Args: 837 | pixel_coordinates: the grid with pixel coordinates. Shape can be :math:`(*, 2)`. 838 | width: the maximum width in the x-axis. 839 | height: the maximum height in the y-axis. 840 | eps: safe division by zero. 841 | 842 | Return: 843 | the normalized pixel coordinates. 844 | """ 845 | if pixel_coordinates.shape[-1] != 2: 846 | raise ValueError("Input pixel_coordinates must be of shape (*, 2). " "Got {}".format(pixel_coordinates.shape)) 847 | # compute normalization factor 848 | hw: torch.Tensor = torch.stack( 849 | [ 850 | torch.tensor(width, device=pixel_coordinates.device, dtype=pixel_coordinates.dtype), 851 | torch.tensor(height, device=pixel_coordinates.device, dtype=pixel_coordinates.dtype), 852 | ] 853 | ) 854 | 855 | factor: torch.Tensor = torch.tensor(2.0, device=pixel_coordinates.device, dtype=pixel_coordinates.dtype) / ( 856 | hw - 1 857 | ).clamp(eps) 858 | 859 | return factor * pixel_coordinates - 1 860 | 861 | 862 | def denormalize_pixel_coordinates( 863 | pixel_coordinates: torch.Tensor, height: int, width: int, eps: float = 1e-8 864 | ) -> torch.Tensor: 865 | r"""Denormalize pixel coordinates. 866 | 867 | The input is assumed to be -1 if on extreme left, 1 if on extreme right (x = w-1). 868 | 869 | Args: 870 | pixel_coordinates: the normalized grid coordinates. Shape can be :math:`(*, 2)`. 871 | width: the maximum width in the x-axis. 872 | height: the maximum height in the y-axis. 
873 | eps: safe division by zero. 874 | 875 | Return: 876 | the denormalized pixel coordinates. 877 | """ 878 | if pixel_coordinates.shape[-1] != 2: 879 | raise ValueError("Input pixel_coordinates must be of shape (*, 2). " "Got {}".format(pixel_coordinates.shape)) 880 | # compute normalization factor 881 | hw: torch.Tensor = ( 882 | torch.stack([torch.tensor(width), torch.tensor(height)]) 883 | .to(pixel_coordinates.device) 884 | .to(pixel_coordinates.dtype) 885 | ) 886 | 887 | factor: torch.Tensor = torch.tensor(2.0) / (hw - 1).clamp(eps) 888 | 889 | return torch.tensor(1.0) / factor * (pixel_coordinates + 1) 890 | 891 | 892 | def normalize_pixel_coordinates3d( 893 | pixel_coordinates: torch.Tensor, depth: int, height: int, width: int, eps: float = 1e-8 894 | ) -> torch.Tensor: 895 | r"""Normalize pixel coordinates between -1 and 1. 896 | 897 | Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1). 898 | 899 | Args: 900 | pixel_coordinates: the grid with pixel coordinates. Shape can be :math:`(*, 3)`. 901 | depth: the maximum depth in the z-axis. 902 | height: the maximum height in the y-axis. 903 | width: the maximum width in the x-axis. 904 | eps: safe division by zero. 905 | 906 | Return: 907 | the normalized pixel coordinates. 908 | """ 909 | if pixel_coordinates.shape[-1] != 3: 910 | raise ValueError("Input pixel_coordinates must be of shape (*, 3). 
" "Got {}".format(pixel_coordinates.shape)) 911 | # compute normalization factor 912 | dhw: torch.Tensor = ( 913 | torch.stack([torch.tensor(depth), torch.tensor(width), torch.tensor(height)]) 914 | .to(pixel_coordinates.device) 915 | .to(pixel_coordinates.dtype) 916 | ) 917 | 918 | factor: torch.Tensor = torch.tensor(2.0) / (dhw - 1).clamp(eps) 919 | 920 | return factor * pixel_coordinates - 1 921 | 922 | 923 | def denormalize_pixel_coordinates3d( 924 | pixel_coordinates: torch.Tensor, depth: int, height: int, width: int, eps: float = 1e-8 925 | ) -> torch.Tensor: 926 | r"""Denormalize pixel coordinates. 927 | 928 | The input is assumed to be -1 if on extreme left, 1 if on extreme right (x = w-1). 929 | 930 | Args: 931 | pixel_coordinates: the normalized grid coordinates. Shape can be :math:`(*, 3)`. 932 | depth: the maximum depth in the x-axis. 933 | height: the maximum height in the y-axis. 934 | width: the maximum width in the x-axis. 935 | eps: safe division by zero. 936 | 937 | Return: 938 | the denormalized pixel coordinates. 939 | """ 940 | if pixel_coordinates.shape[-1] != 3: 941 | raise ValueError("Input pixel_coordinates must be of shape (*, 3). 
" "Got {}".format(pixel_coordinates.shape)) 942 | # compute normalization factor 943 | dhw: torch.Tensor = ( 944 | torch.stack([torch.tensor(depth), torch.tensor(width), torch.tensor(height)]) 945 | .to(pixel_coordinates.device) 946 | .to(pixel_coordinates.dtype) 947 | ) 948 | 949 | factor: torch.Tensor = torch.tensor(2.0) / (dhw - 1).clamp(eps) 950 | 951 | return torch.tensor(1.0) / factor * (pixel_coordinates + 1) -------------------------------------------------------------------------------- /smpl_visualizer/smpl.py: -------------------------------------------------------------------------------- 1 | # This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/models/hmr.py 2 | # Adhere to their licence to use this script 3 | 4 | import torch 5 | import numpy as np 6 | from collections import namedtuple 7 | from smplx import SMPL as _SMPL 8 | import os 9 | from smplx.lbs import vertices2joints, batch_rigid_transform, batch_rodrigues, transform_mat, blend_shapes 10 | 11 | ModelOutput = namedtuple('ModelOutput', 12 | ['vertices', 13 | 'joints', 'full_pose', 'betas', 14 | 'global_orient', 15 | 'body_pose', 'expression', 16 | 'left_hand_pose', 'right_hand_pose', 17 | 'jaw_pose', 18 | 'global_trans', 19 | 'scale']) 20 | ModelOutput.__new__.__defaults__ = (None,) * len(ModelOutput._fields) 21 | 22 | 23 | H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9] 24 | H36M_TO_J14 = H36M_TO_J17[:14] 25 | H36M_TO_J15 = [H36M_TO_J17[14]] + H36M_TO_J17[:14] 26 | H36M_TO_J16 = H36M_TO_J17[14:16] + H36M_TO_J17[:14] 27 | 28 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 29 | SMPL_MODEL_DIR = os.path.join(BASE_DIR, 'data/smpl') 30 | 31 | 32 | # Map joints to SMPL joints 33 | JOINT_MAP = { 34 | 'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17, 35 | 'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16, 36 | 'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0, 37 | 'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8, 38 | 
'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7, 39 | 'OP REye': 25, 'OP LEye': 26, 'OP REar': 27, 40 | 'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30, 41 | 'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34, 42 | 'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45, 43 | 'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7, 44 | 'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17, 45 | 'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20, 46 | 'Neck (LSP)': 47, 'Top of Head (LSP)': 48, 47 | 'Pelvis (MPII)': 49, 'Thorax (MPII)': 50, 48 | 'Spine (H36M)': 51, 'Jaw (H36M)': 52, 49 | 'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26, 50 | 'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27, 51 | 'Left Thumb Tip': 35, 'Left Index Tip': 36, 'Left Middle Tip': 37, 52 | 'Left Ring Tip': 38, 'Left Pinky Tip': 39, 53 | 'Right Thumb Tip': 40, 'Right Index Tip': 41, 'Right Middle Tip': 42, 54 | 'Right Ring Tip': 43, 'Right Pinky Tip': 44 55 | } 56 | 57 | JOINT_NAMES = [ 58 | 'OP Nose', 'OP Neck', 'OP RShoulder', 59 | 'OP RElbow', 'OP RWrist', 'OP LShoulder', 60 | 'OP LElbow', 'OP LWrist', 'OP MidHip', 61 | 'OP RHip', 'OP RKnee', 'OP RAnkle', 62 | 'OP LHip', 'OP LKnee', 'OP LAnkle', 63 | 'OP REye', 'OP LEye', 'OP REar', 64 | 'OP LEar', 'OP LBigToe', 'OP LSmallToe', 65 | 'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel', 66 | 'Right Ankle', 'Right Knee', 'Right Hip', 67 | 'Left Hip', 'Left Knee', 'Left Ankle', 68 | 'Right Wrist', 'Right Elbow', 'Right Shoulder', 69 | 'Left Shoulder', 'Left Elbow', 'Left Wrist', 70 | 'Neck (LSP)', 'Top of Head (LSP)', 71 | 'Pelvis (MPII)', 'Thorax (MPII)', 72 | 'Spine (H36M)', 'Jaw (H36M)', 73 | 'Head (H36M)', 'Nose', 'Left Eye', 74 | 'Right Eye', 'Left Ear', 'Right Ear' 75 | ] 76 | 77 | SMPL_JOINT_NAMES = [ 78 | 'pelvis', 79 | 'left_hip', 80 | 'right_hip', 81 | 'spine1', 82 | 'left_knee', 83 | 'right_knee', 84 | 'spine2', 85 | 'left_ankle', 86 | 'right_ankle', 87 | 'spine3', 88 | 'left_foot', 89 | 'right_foot', 90 | 'neck', 91 
| 'left_collar', 92 | 'right_collar', 93 | 'head', 94 | 'left_shoulder', 95 | 'right_shoulder', 96 | 'left_elbow', 97 | 'right_elbow', 98 | 'left_wrist', 99 | 'right_wrist', 100 | 'left_index1', 101 | 'right_index1' 102 | ] 103 | 104 | 105 | def print_smpl_joint_val(x_dict, include_root=False): 106 | sind = 0 if include_root else 1 107 | jnum = 24 - sind 108 | for i in range(jnum): 109 | jstr = SMPL_JOINT_NAMES[i + sind] 110 | val_str = ' --- '.join([name + ': ' + np.array2string(x[..., 3 * i: 3* (i + 1)], precision=3, suppress_small=True, sign=' ') for name, x in x_dict.items()]) 111 | print(f' {jstr:20} --- {val_str}') 112 | return 113 | 114 | 115 | 116 | def get_ordered_joint_names(pose_type): 117 | joint_names = None 118 | if pose_type == 'body26': 119 | # joints order according to 120 | joint_names = [ 121 | 'Pelvis (MPII)', # 0 122 | 'OP LHip', # 1 123 | 'OP RHip', # 2 124 | 'Spine (H36M)', # 3 125 | 'OP LKnee', # 4 126 | 'OP RKnee', # 5 127 | 'OP Neck', # 6 128 | 'OP LAnkle', # 7 129 | 'OP RAnkle', # 8 130 | 'OP LBigToe', # 9 131 | 'OP RBigToe', # 10 132 | 'OP LSmallToe', # 11 133 | 'OP RSmallToe', # 12 134 | 'OP LHeel', # 13 135 | 'OP RHeel', # 14 136 | 'OP Nose', # 15 137 | 'OP LEye', # 16 138 | 'OP REye', # 17 139 | 'OP LEar', # 18 140 | 'OP REar', # 19 141 | 'OP LShoulder', # 20 142 | 'OP RShoulder', # 21 143 | 'OP LElbow', # 22 144 | 'OP RElbow', # 23 145 | 'OP LWrist', # 24 146 | 'OP RWrist', # 25 147 | ] 148 | elif pose_type == 'body34': 149 | joint_names = [ 150 | 'Pelvis (MPII)', # 0 151 | 'OP LHip', # 1 152 | 'OP RHip', # 2 153 | 'Spine (H36M)', # 3 154 | 'OP LKnee', # 4 155 | 'OP RKnee', # 5 156 | 'OP Neck', # 6 157 | 'OP LAnkle', # 7 158 | 'OP RAnkle', # 8 159 | 'OP LBigToe', # 9 160 | 'OP RBigToe', # 10 161 | 'OP LSmallToe', # 11 162 | 'OP RSmallToe', # 12 163 | 'OP LHeel', # 13 164 | 'OP RHeel', # 14 165 | 'OP Nose', # 15 166 | 'OP LEye', # 16 167 | 'OP REye', # 17 168 | 'OP LEar', # 18 169 | 'OP REar', # 19 170 | 'OP LShoulder', # 20 171 | 'OP 
RShoulder', # 21 172 | 'OP LElbow', # 22 173 | 'OP RElbow', # 23 174 | 'OP LWrist', # 24 175 | 'OP RWrist', # 25 176 | 'Left Pinky Tip', # 26 FIXME Using Tip instead of Knuckle. 177 | 'Right Pinky Tip', # 27 FIXME Using Tip instead of Knuckle. 178 | 'Left Middle Tip', #28 179 | 'Right Middle Tip', #29 180 | 'Left Index Tip', # 30 FIXME: Using Tip instead of knuckle 181 | 'Right Index Tip', # 31 FIXME: Using Tip instead of knuckle 182 | 'Left Thumb Tip', # 32 183 | 'Right Thumb Tip' #33 184 | ] 185 | elif pose_type == 'body30': 186 | joint_names = [ 187 | 'Pelvis (MPII)', # 0 188 | 'OP LHip', # 1 189 | 'OP RHip', # 2 190 | 'Spine (H36M)', # 3 191 | 'OP LKnee', # 4 192 | 'OP RKnee', # 5 193 | 'OP Neck', # 6 194 | 'OP LAnkle', # 7 195 | 'OP RAnkle', # 8 196 | 'OP LBigToe', # 9 197 | 'OP RBigToe', # 10 198 | 'OP LSmallToe', # 11 199 | 'OP RSmallToe', # 12 200 | 'OP LHeel', # 13 201 | 'OP RHeel', # 14 202 | 'OP Nose', # 15 203 | 'OP LEye', # 16 204 | 'OP REye', # 17 205 | 'OP LEar', # 18 206 | 'OP REar', # 19 207 | 'OP LShoulder', # 20 208 | 'OP RShoulder', # 21 209 | 'OP LElbow', # 22 210 | 'OP RElbow', # 23 211 | 'OP LWrist', # 24 212 | 'OP RWrist', # 25 213 | 'Left Pinky Tip', # 26 FIXME Using Tip instead of Knuckle. 214 | 'Right Pinky Tip', # 27 FIXME Using Tip instead of Knuckle. 
215 | 'Left Index Tip', # 30 FIXME: Using Tip instead of knuckle 216 | 'Right Index Tip', # 31 FIXME: Using Tip instead of knuckle 217 | ] 218 | 219 | elif pose_type == "body26fk": 220 | joint_names = [ 221 | #'Pelvis (MPII)', # 0 222 | 'Pelvis (MPII)', # 0 223 | 'OP LHip', # 1 224 | 'OP RHip', # 2 225 | 'Spine (H36M)', # 3 226 | 'OP LKnee', # 4 227 | 'OP RKnee', # 5 228 | 'OP Neck', # 6 229 | 'OP LAnkle', # 7 230 | 'OP RAnkle', # 8 231 | 'OP LBigToe', # 9 232 | 'OP RBigToe', # 10 233 | 'OP LSmallToe', # 11 234 | 'OP RSmallToe', # 12 235 | 'OP LHeel', # 13 236 | 'OP RHeel', # 14 237 | 'OP Nose', # 15 238 | 'OP LEye', # 16 239 | 'OP REye', # 17 240 | 'OP LEar', # 18 241 | 'OP REar', # 19 242 | 'OP LShoulder', # 20 243 | 'OP RShoulder', # 21 244 | 'OP LElbow', # 22 245 | 'OP RElbow', # 23 246 | 'OP LWrist', # 24 247 | 'OP RWrist', # 25 248 | ] 249 | elif pose_type == "body15": 250 | 251 | joint_names = [ 252 | 'Pelvis (MPII)', # 0 253 | 'OP RAnkle', # 1 254 | 'OP RKnee', # 2 255 | 'OP RHip', # 3 256 | 'OP LHip', # 4 257 | 'OP LKnee', # 5 258 | 'OP LAnkle', # 6 259 | 'OP RWrist', # 7 260 | 'OP RElbow', # 8 261 | 'OP RShoulder', # 9 262 | 'OP LShoulder', # 10 263 | 'OP LElbow', # 11 264 | 'OP LWrist', # 12 265 | 'Neck (LSP)', # 13 266 | 'Top of Head (LSP)' # 14 267 | ] 268 | 269 | return joint_names 270 | 271 | 272 | class SMPL(_SMPL): 273 | """ Extension of the official SMPL implementation to support more joints """ 274 | 275 | def __init__(self, *args, **kwargs): 276 | super(SMPL, self).__init__(*args, **kwargs) 277 | if 'pose_type' in kwargs.keys(): 278 | self.joint_names = get_ordered_joint_names(kwargs['pose_type'] ) 279 | else: 280 | self.joint_names = JOINT_NAMES 281 | if 'device' in kwargs.keys(): 282 | self.device = kwargs['device'] 283 | else: 284 | self.device = torch.device('cpu') 285 | 286 | joints = [JOINT_MAP[i] for i in self.joint_names] 287 | self.joint_map = torch.tensor(joints, dtype=torch.long).to(self.device) 288 | 289 | if 'betas' in 
kwargs.keys(): 290 | v_shaped = self.v_template + blend_shapes(torch.FloatTensor(kwargs['betas']), self.shapedirs) 291 | else: 292 | v_shaped = self.v_template.unsqueeze(0) 293 | joint_pos_bind = torch.matmul(self.J_regressor, v_shaped).to(self.device) 294 | joint_pos_bind_rel = joint_pos_bind.clone() 295 | joint_pos_bind_rel[:, 1:] -= joint_pos_bind[:, self.parents[1:]] 296 | joint_pos_bind_rel[:, 0] = 0 297 | if 'batch_size' in kwargs.keys(): 298 | batch_size = kwargs['batch_size'] // len(kwargs['betas']) 299 | joint_pos_bind = joint_pos_bind.repeat(batch_size, 1, 1) 300 | joint_pos_bind_rel = joint_pos_bind_rel.repeat(batch_size, 1, 1) 301 | self.joint_pos_bind = joint_pos_bind 302 | self.joint_pos_bind_rel = joint_pos_bind_rel 303 | 304 | def forward(self, *args, root_trans=None, root_scale=None, orig_joints=False, **kwargs): 305 | """ 306 | root_trans: B x 3, root translation 307 | root_scale: B, scale factor w.r.t root 308 | """ 309 | kwargs['get_skin'] = True 310 | smpl_output = super(SMPL, self).forward(*args, **kwargs) 311 | if orig_joints: 312 | joints = smpl_output.joints[:, :24] 313 | else: 314 | extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices) 315 | joints = torch.cat([smpl_output.joints, extra_joints], dim=1) 316 | joints = joints[:, self.joint_map, :] 317 | 318 | output = ModelOutput(vertices=smpl_output.vertices, 319 | global_orient=smpl_output.global_orient, 320 | body_pose=smpl_output.body_pose, 321 | joints=joints, 322 | betas=smpl_output.betas, 323 | full_pose=smpl_output.full_pose) 324 | if root_trans is not None: 325 | if root_scale is None: 326 | root_scale = torch.ones_like(root_trans[:, 0]) 327 | cur_root_trans = joints[:, [0], :] 328 | # rel_trans = (root_trans - joints[:, 0, :]).unsqueeze(1) 329 | output.vertices[:] = (output.vertices - cur_root_trans) * root_scale[:, None, None] + root_trans[:, None, :] 330 | output.joints[:] = (output.joints - cur_root_trans) * root_scale[:, None, None] + root_trans[:, None, 
:] 331 | return output 332 | 333 | def get_joints(self, betas=None, body_pose=None, global_orient=None, transl=None, 334 | pose2rot=True, root_trans=None, root_scale=None, no_orient=False, dtype=torch.float32): 335 | # If no shape and lib parameters are passed along, then use the 336 | # ones from the module 337 | if no_orient: 338 | pose = body_pose 339 | else: 340 | pose = torch.cat([global_orient, body_pose], dim=1) 341 | 342 | """ LBS """ 343 | batch_size = pose.shape[0] 344 | # J = torch.matmul(self.J_regressor, self.v_template).repeat((batch_size, 1, 1)) 345 | J = self.joint_pos_bind[:batch_size, ...].unsqueeze(-1) 346 | if pose2rot: 347 | rot_mats = batch_rodrigues(pose.view(-1, 3)).view([batch_size, -1, 3, 3]) 348 | else: 349 | rot_mats = pose.view(batch_size, -1, 3, 3) 350 | joints, A = batch_rigid_transform(rot_mats, J, self.parents, dtype=torch.float32) 351 | 352 | if transl is not None: 353 | joints += transl.unsqueeze(dim=1) 354 | 355 | if root_trans is not None: 356 | if root_scale is None: 357 | root_scale = torch.ones_like(root_trans[:, 0]) 358 | cur_root_trans = joints[:, [0], :] 359 | joints[:] = (joints - cur_root_trans) * root_scale[:, None, None] + root_trans[:, None, :] 360 | 361 | return joints 362 | 363 | def get_joints_fast(self, pose=None, root_trans=None): 364 | # If no shape and lib parameters are passed along, then use the 365 | # ones from the module 366 | 367 | """ LBS """ 368 | batch_size = pose.shape[0] 369 | rot_mats = pose.view(batch_size, -1, 3, 3) 370 | 371 | rel_joints = self.joint_pos_bind_rel[:batch_size] 372 | 373 | transforms_mat = transform_mat( 374 | rot_mats.reshape(-1, 3, 3), 375 | rel_joints.reshape(-1, 3, 1)).reshape(-1, rel_joints.shape[1], 4, 4) 376 | 377 | transform_chain = [transforms_mat[:, 0]] 378 | for i in range(1, self.parents.shape[0]): 379 | # Subtract the joint location at the rest pose 380 | # No need for rotation, since it's identity when at rest 381 | curr_res = 
torch.matmul(transform_chain[self.parents[i]], 382 | transforms_mat[:, i]) 383 | transform_chain.append(curr_res) 384 | 385 | transforms = torch.stack(transform_chain, dim=1) 386 | 387 | # The last column of the transformations contains the posed joints 388 | posed_joints = transforms[:, :, :3, 3] 389 | 390 | if root_trans is not None: 391 | posed_joints = posed_joints + root_trans[:, None, :] 392 | 393 | return posed_joints 394 | 395 | 396 | def get_smpl_faces(): 397 | smpl = SMPL(SMPL_MODEL_DIR, batch_size=1, create_transl=False) 398 | return smpl.faces 399 | 400 | -------------------------------------------------------------------------------- /smpl_visualizer/torch_transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .konia_transform import quaternion_to_angle_axis, angle_axis_to_quaternion, quaternion_to_rotation_matrix, rotation_matrix_to_quaternion, rotation_matrix_to_angle_axis, angle_axis_to_rotation_matrix 4 | 5 | 6 | def normalize(x, eps: float = 1e-9): 7 | return x / x.norm(p=2, dim=-1).clamp(min=eps, max=None).unsqueeze(-1) 8 | 9 | 10 | @torch.jit.script 11 | def quat_mul(a, b): 12 | assert a.shape == b.shape 13 | shape = a.shape 14 | a = a.reshape(-1, 4) 15 | b = b.reshape(-1, 4) 16 | 17 | w1, x1, y1, z1 = a[:, 0], a[:, 1], a[:, 2], a[:, 3] 18 | w2, x2, y2, z2 = b[:, 0], b[:, 1], b[:, 2], b[:, 3] 19 | ww = (z1 + x1) * (x2 + y2) 20 | yy = (w1 - y1) * (w2 + z2) 21 | zz = (w1 + y1) * (w2 - z2) 22 | xx = ww + yy + zz 23 | qq = 0.5 * (xx + (z1 - x1) * (x2 - y2)) 24 | w = qq - ww + (z1 - y1) * (y2 - z2) 25 | x = qq - xx + (x1 + w1) * (x2 + w2) 26 | y = qq - yy + (w1 - x1) * (y2 + z2) 27 | z = qq - zz + (z1 + y1) * (w2 - x2) 28 | return torch.stack([w, x, y, z], dim=-1).view(shape) 29 | 30 | 31 | @torch.jit.script 32 | def quat_conjugate(a): 33 | shape = a.shape 34 | a = a.reshape(-1, 4) 35 | return torch.cat((a[:, 0:1], -a[:, 1:]), dim=-1).view(shape) 36 | 37 | 38 | 
@torch.jit.script
def quat_apply(a, b):
    """Rotate 3-vectors ``b`` by quaternions ``a``.

    ``a`` is read in (w, x, y, z) layout (scalar in column 0) with one quaternion
    per vector of ``b``; the result has the shape of ``b``. The update
    ``b + w * t + xyz x t`` with ``t = 2 * (xyz x b)`` is the unit-quaternion
    rotation formula, so ``a`` is presumably expected to be normalized.
    """
    shape = b.shape
    a = a.reshape(-1, 4)
    b = b.reshape(-1, 3)
    xyz = a[:, 1:].clone()
    t = xyz.cross(b, dim=-1) * 2
    return (b + a[:, 0:1].clone() * t + xyz.cross(t, dim=-1)).view(shape)


@torch.jit.script
def quat_angle(a, eps: float = 1e-6):
    """Rotation angle of each (w, x, y, z) quaternion in ``a``.

    Computes ``acos(2 * w**2 - 1)`` (valid for unit quaternions); the cosine is
    clamped to ``[-1 + eps, 1 - eps]`` to keep ``acos`` finite. The last
    dimension of ``a`` is dropped from the output shape.
    """
    shape = a.shape
    a = a.reshape(-1, 4)
    s = 2 * (a[:, 0] ** 2) - 1
    s = s.clamp(-1 + eps, 1 - eps)
    s = s.acos()
    return s.view(shape[:-1])


@torch.jit.script
def quat_angle_diff(quat1, quat2):
    """Angle of the relative rotation ``quat1 * conj(quat2)``."""
    return quat_angle(quat_mul(quat1, quat_conjugate(quat2)))


@torch.jit.script
def torch_safe_atan2(y, x, eps: float = 1e-6):
    """``atan2(y, x)`` with ``y`` nudged by ``eps`` where both ``|y|`` and ``|x|`` are below ``eps``.

    ``y`` is cloned first, so the caller's tensor is not modified.
    """
    y = y.clone()
    y[(y.abs() < eps) & (x.abs() < eps)] += eps
    return torch.atan2(y, x)


@torch.jit.script
def ypr_euler_from_quat(q, handle_singularity: bool = False, eps: float = 1e-6, singular_eps: float = 1e-6):
    """
    convert quaternion to yaw-pitch-roll euler angles

    ``q`` is presumably in (w, x, y, z) layout along the last dimension (matching
    the formulas below). The result is stacked as ``[roll, pitch, yaw]`` along a
    new last dimension -- note the order. With ``handle_singularity``, entries
    whose ``w*y - x*z`` is within ``singular_eps`` of +/-0.5 (pitch near +/-90
    degrees) get roll set to 0 and yaw recomputed as ``2 * atan2(z, w)``.
    """
    yaw_atany = 2 * (q[..., 0] * q[..., 3] + q[..., 1] * q[..., 2])
    yaw_atanx = 1 - 2 * (q[..., 2] * q[..., 2] + q[..., 3] * q[..., 3])
    roll_atany = 2 * (q[..., 0] * q[..., 1] + q[..., 2] * q[..., 3])
    roll_atanx = 1 - 2 * (q[..., 1] * q[..., 1] + q[..., 2] * q[..., 2])
    yaw = torch_safe_atan2(yaw_atany, yaw_atanx, eps)
    pitch = torch.asin(torch.clamp(2 * (q[..., 0] * q[..., 2] - q[..., 1] * q[..., 3]), min=-1 + eps, max=1 - eps))
    roll = torch_safe_atan2(roll_atany, roll_atanx, eps)

    if handle_singularity:
        # handle two special cases
        test = q[..., 0] * q[..., 2] - q[..., 1] * q[..., 3]
        # north pole, pitch ~= 90 degrees
        np_ind = test > 0.5 - singular_eps
        if torch.any(np_ind):
            # print('ypr_euler_from_quat singularity -- north pole!')
            roll[np_ind] = 0.0
            # Boolean indexing returns a copy, so the former in-place
            # `pitch[np_ind].clamp_max_(...)` never touched `pitch`; assign instead.
            pitch[np_ind] = pitch[np_ind].clamp_max(0.5 * np.pi)
            yaw_atany = q[..., 3][np_ind]
            yaw_atanx = q[..., 0][np_ind]
            yaw[np_ind] = 2 * torch_safe_atan2(yaw_atany, yaw_atanx, eps)
        # south pole, pitch ~= -90 degrees
        sp_ind = test < -0.5 + singular_eps
        if torch.any(sp_ind):
            # print('ypr_euler_from_quat singularity -- south pole!')
            roll[sp_ind] = 0.0
            pitch[sp_ind] = pitch[sp_ind].clamp_min(-0.5 * np.pi)
            yaw_atany = q[..., 3][sp_ind]
            yaw_atanx = q[..., 0][sp_ind]
            yaw[sp_ind] = 2 * torch_safe_atan2(yaw_atany, yaw_atanx, eps)

    return torch.stack([roll, pitch, yaw], dim=-1)


@torch.jit.script
def quat_from_ypr_euler(angles):
    """
    convert yaw-pitch-roll euler angles to quaternion

    ``angles[..., 0]``, ``[..., 1]``, ``[..., 2]`` are read as roll, pitch and yaw,
    i.e. the same order that ``ypr_euler_from_quat`` returns. The output is a
    (w, x, y, z) quaternion stacked along a new last dimension.
    """
    half_ang = angles * 0.5
    sin = torch.sin(half_ang)
    cos = torch.cos(half_ang)
    q = torch.stack([
        cos[..., 0] * cos[..., 1] * cos[..., 2] + sin[..., 0] * sin[..., 1] * sin[..., 2],
        sin[..., 0] * cos[..., 1] * cos[..., 2] - cos[..., 0] * sin[..., 1] * sin[..., 2],
        cos[..., 0] * sin[..., 1] * cos[..., 2] + sin[..., 0] * cos[..., 1] * sin[..., 2],
        cos[..., 0] * cos[..., 1] * sin[..., 2] - sin[..., 0] * sin[..., 1] * cos[..., 2]
    ], dim=-1)
    return q


def quat_between_two_vec(v1, v2, eps: float = 1e-6):
    """
    quaternion for rotating v1 to v2

    ``v1`` and ``v2`` have shape (*, 3); the result is a normalized (w, x, y, z)
    quaternion of shape (*, 4). The same/opposite-direction tests compare the raw
    dot product against +/-1, so the inputs are presumably unit vectors.
    Constant tensors are created with the input dtype so float64 inputs work
    (boolean-mask assignment requires matching dtypes).
    """
    orig_shape = v1.shape
    v1 = v1.reshape(-1, 3)
    v2 = v2.reshape(-1, 3)
    dot = (v1 * v2).sum(-1)
    cross = torch.cross(v1, v2, dim=-1)
    # (1 + dot, v1 x v2) normalizes to the half-angle quaternion for unit vectors
    out = torch.cat([(1 + dot).unsqueeze(-1), cross], dim=-1)
    # handle v1 & v2 with same direction
    sind = dot > 1 - eps
    out[sind] = torch.tensor([1., 0., 0., 0.], dtype=out.dtype, device=v1.device)
    # handle v1 & v2 with opposite direction: rotate by pi about an axis orthogonal to v1
    nind = dot < -1 + eps
    if torch.any(nind):
        vx = torch.tensor([1., 0., 0.], dtype=v1.dtype, device=v1.device)
        vxdot = (v1 * vx).sum(-1).abs()
        nxind = nind & (vxdot < 1 - eps)
        if torch.any(nxind):
            out[nxind] = angle_axis_to_quaternion(normalize(torch.cross(vx.expand_as(v1[nxind]), v1[nxind], dim=-1)) * np.pi)
        # handle v1 & v2 with opposite direction and they are parallel to x axis
        pind = nind & (vxdot >= 1 - eps)
        if torch.any(pind):
            vy = torch.tensor([0., 1., 0.], dtype=v1.dtype, device=v1.device)
            out[pind] = angle_axis_to_quaternion(normalize(torch.cross(vy.expand_as(v1[pind]), v1[pind], dim=-1)) * np.pi)
    # normalize and reshape
    out = normalize(out).view(orig_shape[:-1] + (4,))
    return out


@torch.jit.script
def get_yaw(q, eps: float = 1e-6):
    """Yaw angle of (w, x, y, z) quaternions (same formula as ``ypr_euler_from_quat``)."""
    yaw_atany = 2 * (q[..., 0] * q[..., 3] + q[..., 1] * q[..., 2])
    yaw_atanx = 1 - 2 * (q[..., 2] * q[..., 2] + q[..., 3] * q[..., 3])
    yaw = torch_safe_atan2(yaw_atany, yaw_atanx, eps)
    return yaw


@torch.jit.script
def get_yaw_q(q):
    """Quaternion of the pure yaw rotation of ``q`` (angle-axis ``(0, 0, yaw)``).

    The zero padding uses ``q``'s dtype so the concatenation keeps the input precision.
    """
    yaw = get_yaw(q)
    angle_axis = torch.cat([torch.zeros(yaw.shape + (2,), dtype=q.dtype, device=q.device), yaw.unsqueeze(-1)], dim=-1)
    heading_q = angle_axis_to_quaternion(angle_axis)
    return heading_q


@torch.jit.script
def get_heading(q, eps: float = 1e-6):
    """Heading angle ``2 * atan2(z, w)`` of (w, x, y, z) quaternions."""
    heading_atany = q[..., 3]
    heading_atanx = q[..., 0]
    heading = 2 * torch_safe_atan2(heading_atany, heading_atanx, eps)
    return heading


def get_heading_q(q):
    """Heading-only quaternion: copy ``q``, zero its x and y components, renormalize."""
    q_new = q.clone()
    q_new[..., 1] = 0
    q_new[..., 2] = 0
    q_new = normalize(q_new)
    return q_new


@torch.jit.script
def heading_to_vec(h_theta):
    """Unit 2D direction ``(cos h, sin h)`` for heading angles ``h_theta``."""
    v = torch.stack([torch.cos(h_theta), torch.sin(h_theta)], dim=-1)
    return v


@torch.jit.script
def vec_to_heading(h_vec):
    """Heading angle ``atan2(v[..., 1], v[..., 0])`` of 2D direction vectors (inverse of ``heading_to_vec``)."""
    h_theta = torch_safe_atan2(h_vec[..., 1], h_vec[..., 0])
    return h_theta


@torch.jit.script
def heading_to_quat(h_theta):
    """Quaternion for a rotation by ``h_theta`` about the z axis (angle-axis ``(0, 0, h)``).

    The zero padding uses ``h_theta``'s dtype so the concatenation keeps the input precision.
    """
    angle_axis = torch.cat([torch.zeros(h_theta.shape + (2,), dtype=h_theta.dtype, device=h_theta.device), h_theta.unsqueeze(-1)], dim=-1)
    heading_q = angle_axis_to_quaternion(angle_axis)
    return heading_q

206 | 207 | def deheading_quat(q, heading_q=None): 208 | if heading_q is None: 209 | heading_q = get_heading_q(q) 210 | dq = quat_mul(quat_conjugate(heading_q), q) 211 | return dq 212 | 213 | 214 | @torch.jit.script 215 | def rotmat_to_rot6d(mat): 216 | rot6d = torch.cat([mat[..., 0], mat[..., 1]], dim=-1) 217 | return rot6d 218 | 219 | 220 | def rot6d_to_rotmat(rot6d): 221 | a1 = rot6d[..., :3] 222 | a2 = rot6d[..., 3:] 223 | b1 = normalize(a1) 224 | b2 = normalize(a2 - (b1 * a2).sum(-1, keepdims=True) * b1) 225 | b3 = torch.cross(b1, b2, dim=-1) 226 | mat = torch.stack([b1, b2, b3], dim=-1) 227 | return mat 228 | 229 | 230 | def angle_axis_to_rot6d(aa): 231 | return rotmat_to_rot6d(angle_axis_to_rotation_matrix(aa)) 232 | 233 | 234 | def rot6d_to_angle_axis(rot6d): 235 | return rotation_matrix_to_angle_axis(rot6d_to_rotmat(rot6d)) 236 | 237 | 238 | def quat_to_rot6d(q): 239 | return rotmat_to_rot6d(quaternion_to_rotation_matrix(q)) 240 | 241 | 242 | def rot6d_to_quat(rot6d): 243 | return rotation_matrix_to_quaternion(rot6d_to_rotmat(rot6d)) 244 | 245 | 246 | def make_transform(rot, trans, rot_type=None): 247 | if rot_type == 'axis_angle': 248 | rot = angle_axis_to_rotation_matrix(rot) 249 | elif rot_type == '6d': 250 | rot = rot6d_to_rotmat(rot) 251 | transform = torch.eye(4).to(trans.device).repeat(rot.shape[:-2] + (1, 1)) 252 | transform[..., :3, :3] = rot 253 | transform[..., :3, 3] = trans 254 | return transform 255 | 256 | 257 | def transform_trans(transform_mat, trans): 258 | trans = torch.cat((trans, torch.ones_like(trans[..., [0]])), dim=-1)[..., None, :] 259 | while len(transform_mat.shape) < len(trans.shape): 260 | transform_mat = transform_mat.unsqueeze(-3) 261 | trans_new = torch.matmul(trans, transform_mat.transpose(-2, -1))[..., 0, :3] 262 | return trans_new 263 | 264 | 265 | def transform_rot(transform_mat, rot): 266 | rot_qmat = angle_axis_to_rotation_matrix(rot) 267 | while len(transform_mat.shape) < len(rot_qmat.shape): 268 | transform_mat = 
transform_mat.unsqueeze(-3) 269 | rot_qmat_new = torch.matmul(transform_mat[..., :3, :3], rot_qmat) 270 | rot_new = rotation_matrix_to_angle_axis(rot_qmat_new) 271 | return rot_new 272 | 273 | 274 | def inverse_transform(transform_mat): 275 | transform_inv = torch.zeros_like(transform_mat) 276 | transform_inv[..., :3, :3] = transform_mat[..., :3, :3].transpose(-2, -1) 277 | transform_inv[..., :3, 3] = -torch.matmul(transform_mat[..., :3, 3].unsqueeze(-2), transform_mat[..., :3, :3]).squeeze(-2) 278 | transform_inv[..., 3, 3] = 1.0 279 | return transform_inv 280 | 281 | 282 | def batch_compute_similarity_transform_torch(S1, S2): 283 | ''' 284 | This function is borrowed from https://github.com/mkocabas/VIBE/blob/c0c3f77d587351c806e901221a9dc05d1ffade4b/lib/utils/eval_utils.py#L199 285 | 286 | Computes a similarity transform (sR, t) that takes 287 | a set of 3D points S1 (3 x N) closest to a set of 3D points S2, 288 | where R is an 3x3 rotation matrix, t 3x1 translation, s scale. 289 | i.e. solves the orthogonal Procrutes problem. 290 | ''' 291 | if len(S1.shape) > 3: 292 | orig_shape = S1.shape 293 | S1 = S1.reshape(-1, *S1.shape[-2:]) 294 | S2 = S2.reshape(-1, *S2.shape[-2:]) 295 | else: 296 | orig_shape = None 297 | 298 | transposed = False 299 | if S1.shape[0] != 3 and S1.shape[0] != 2: 300 | S1 = S1.permute(0,2,1) 301 | S2 = S2.permute(0,2,1) 302 | transposed = True 303 | assert(S2.shape[1] == S1.shape[1]) 304 | 305 | # 1. Remove mean. 306 | mu1 = S1.mean(axis=-1, keepdims=True) 307 | mu2 = S2.mean(axis=-1, keepdims=True) 308 | 309 | X1 = S1 - mu1 310 | X2 = S2 - mu2 311 | 312 | # 2. Compute variance of X1 used for scale. 313 | var1 = torch.sum(X1**2, dim=1).sum(dim=1) 314 | 315 | # 3. The outer product of X1 and X2. 316 | K = X1.bmm(X2.permute(0,2,1)) 317 | 318 | # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are 319 | # singular vectors of K. 320 | U, s, V = torch.svd(K) 321 | 322 | # Construct Z that fixes the orientation of R to get det(R)=1. 
323 | Z = torch.eye(U.shape[1], device=S1.device).unsqueeze(0) 324 | Z = Z.repeat(U.shape[0],1,1) 325 | Z[:,-1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0,2,1)))) 326 | 327 | # Construct R. 328 | R = V.bmm(Z.bmm(U.permute(0,2,1))) 329 | 330 | # 5. Recover scale. 331 | scale = torch.cat([torch.trace(x).unsqueeze(0) for x in R.bmm(K)]) / var1 332 | 333 | # 6. Recover translation. 334 | t = mu2 - (scale.unsqueeze(-1).unsqueeze(-1) * (R.bmm(mu1))) 335 | 336 | # 7. Error: 337 | S1_hat = scale.unsqueeze(-1).unsqueeze(-1) * R.bmm(S1) + t 338 | 339 | if transposed: 340 | S1_hat = S1_hat.permute(0,2,1) 341 | 342 | if orig_shape is not None: 343 | S1_hat = S1_hat.reshape(orig_shape) 344 | 345 | return S1_hat 346 | 347 | -------------------------------------------------------------------------------- /smpl_visualizer/vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import os.path as osp 4 | import subprocess 5 | import platform 6 | import vtk 7 | import cv2 as cv 8 | from PIL import ImageColor 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | FFMPEG_PATH = '/usr/bin/ffmpeg' if osp.exists('/usr/bin/ffmpeg') else 'ffmpeg' 13 | font_files = { 14 | 'Windows': 'C:/Windows/Fonts/arial.ttf', 15 | 'Linux': '/usr/share/fonts/truetype/lato/Lato-Regular.ttf', 16 | 'Darwin': '/System/Library/Fonts/Supplemental/Arial.ttf' 17 | } 18 | 19 | 20 | def images_to_video(img_dir, out_path, img_fmt="%06d.png", fps=30, crf=10, verbose=True): 21 | os.makedirs(osp.dirname(out_path), exist_ok=True) 22 | cmd = [FFMPEG_PATH, '-y', '-r', f'{fps}', '-f', 'image2', '-start_number', '0', 23 | '-i', f'{img_dir}/{img_fmt}', '-vcodec', 'libx264', '-crf', f'{crf}', '-pix_fmt', 'yuv420p', out_path] 24 | if not verbose: 25 | cmd += ['-hide_banner', '-loglevel', 'error'] 26 | subprocess.run(cmd) 27 | 28 | 29 | def video_to_images(video_path, out_path, img_fmt="%06d.png", fps=30, verbose=True): 30 | os.makedirs(out_path, 
exist_ok=True) 31 | cmd = [FFMPEG_PATH, '-i', video_path, '-r', f'{fps}', f'{out_path}/{img_fmt}'] 32 | if not verbose: 33 | cmd += ['-hide_banner', '-loglevel', 'error'] 34 | subprocess.run(cmd) 35 | 36 | 37 | def hstack_videos(video1_path, video2_path, out_path, crf=10, verbose=True, text1=None, text2=None, text_color='white', text_size=60): 38 | if not (text1 is None or text2 is None): 39 | write_text = True 40 | tmp_file = f'{osp.splitext(out_path)[0]}_tmp.mp4' 41 | else: 42 | write_text = False 43 | 44 | os.makedirs(osp.dirname(out_path), exist_ok=True) 45 | cmd = [FFMPEG_PATH, '-y', '-i', video1_path, '-i', video2_path, '-filter_complex', 'hstack,format=yuv420p', 46 | '-vcodec', 'libx264', '-crf', f'{crf}', tmp_file if write_text else out_path] 47 | if not verbose: 48 | cmd += ['-hide_banner', '-loglevel', 'error'] 49 | subprocess.run(cmd) 50 | 51 | if write_text: 52 | font_file = font_files[platform.system()] 53 | draw_str = f"drawtext=fontsize={text_size}:fontfile={font_file}:fontcolor={text_color}:text='{text1}':x=(w-text_w)/4:y=20,"\ 54 | f"drawtext=fontsize={text_size}:fontfile={font_file}:fontcolor={text_color}:text='{text2}':x=3*(w-text_w)/4:y=20" 55 | cmd = [FFMPEG_PATH, '-i', tmp_file, '-y', '-vf', draw_str, '-c:a', 'copy', out_path] 56 | if not verbose: 57 | cmd += ['-hide_banner', '-loglevel', 'error'] 58 | subprocess.run(cmd) 59 | os.remove(tmp_file) 60 | 61 | 62 | def vstack_videos(video1_path, video2_path, out_path, crf=10, verbose=True, text1=None, text2=None, text_color='white', text_size=60): 63 | if not (text1 is None or text2 is None): 64 | write_text = True 65 | tmp_file = f'{osp.splitext(out_path)[0]}_tmp.mp4' 66 | else: 67 | write_text = False 68 | 69 | os.makedirs(osp.dirname(out_path), exist_ok=True) 70 | cmd = [FFMPEG_PATH, '-y', '-i', video1_path, '-i', video2_path, '-filter_complex', 'vstack,format=yuv420p', 71 | '-vcodec', 'libx264', '-crf', f'{crf}', tmp_file if write_text else out_path] 72 | if not verbose: 73 | cmd += 
['-hide_banner', '-loglevel', 'error'] 74 | subprocess.run(cmd) 75 | 76 | if write_text: 77 | font_file = font_files[platform.system()] 78 | draw_str = f"drawtext=fontsize={text_size}:fontfile={font_file}:fontcolor={text_color}:text='{text1}':x=10:y=20,"\ 79 | f"drawtext=fontsize={text_size}:fontfile={font_file}:fontcolor={text_color}:text='{text2}':x=10:y=h/2+20" 80 | cmd = [FFMPEG_PATH, '-i', tmp_file, '-y', '-vf', draw_str, '-c:a', 'copy', out_path] 81 | if not verbose: 82 | cmd += ['-hide_banner', '-loglevel', 'error'] 83 | subprocess.run(cmd) 84 | os.remove(tmp_file) 85 | 86 | 87 | 88 | def make_checker_board_texture(color1='black', color2='white', width=1000, alpha=None): 89 | c1 = np.asarray(ImageColor.getcolor(color1, 'RGB')).astype(np.uint8) 90 | c2 = np.asarray(ImageColor.getcolor(color2, 'RGB')).astype(np.uint8) 91 | if alpha is not None: 92 | c1 = np.append(c1, int(alpha*255)) 93 | c2 = np.append(c2, int(alpha*255)) 94 | hw = hh = width // 2 95 | c1_block = np.tile(c1, (hh, hw, 1)) 96 | c2_block = np.tile(c2, (hh, hw, 1)) 97 | tex = np.block([ 98 | [[c1_block], [c2_block]], 99 | [[c2_block], [c1_block]] 100 | ]) 101 | return tex 102 | 103 | 104 | def resize_bbox(bbox, scale): 105 | x1, y1, x2, y2 = bbox[..., 0], bbox[..., 1], bbox[..., 2], bbox[..., 3] 106 | h, w = y2 - y1, x2 - x1 107 | cx, cy = x1 + 0.5 * w, y1 + 0.5 * h 108 | h_new, w_new = h * scale, w * scale 109 | x1_new, x2_new = cx - 0.5 * w_new, cx + 0.5 * w_new 110 | y1_new, y2_new = cy - 0.5 * h_new, cy + 0.5 * h_new 111 | bbox_new = np.stack([x1_new, y1_new, x2_new, y2_new], axis=-1) 112 | return bbox_new 113 | 114 | 115 | def nparray_to_vtk_matrix(array): 116 | """Convert a numpy.ndarray to a vtk.vtkMatrix4x4 """ 117 | matrix = vtk.vtkMatrix4x4() 118 | for i in range(array.shape[0]): 119 | for j in range(array.shape[1]): 120 | matrix.SetElement(i, j, array[i, j]) 121 | return matrix 122 | 123 | 124 | def vtk_matrix_to_nparray(matrix): 125 | """Convert a numpy.ndarray to a vtk.vtkMatrix4x4 
""" 126 | array = np.zeros([4, 4]) 127 | for i in range(array.shape[0]): 128 | for j in range(array.shape[1]): 129 | array[i, j] = matrix.GetElement(i, j) 130 | return array 131 | 132 | 133 | def draw_keypoints(img, keypoints, confidence, size=4, color=(255, 0, 255)): 134 | for kp, conf in zip(keypoints, confidence): 135 | if conf > 0.2: 136 | cv.circle(img, np.round(kp).astype(int).tolist(), size, color=color, thickness=-1) 137 | return img 138 | 139 | 140 | def get_color_palette(n, colormap='rainbow', use_float=False): 141 | cmap = plt.get_cmap(colormap) 142 | colors = [cmap(i) for i in np.linspace(0, 1, n)] 143 | unit = 1 if use_float else 255 144 | colors = [[int(c[0] * unit), int(c[1] * unit), int(c[2] * unit)] for c in colors] 145 | return colors -------------------------------------------------------------------------------- /smpl_visualizer/vis_pyvista.py: -------------------------------------------------------------------------------- 1 | import pyvista 2 | import time 3 | import math 4 | import numpy as np 5 | import os 6 | import os.path as osp 7 | import shutil 8 | import platform 9 | import tempfile 10 | import pyrender 11 | from tqdm import tqdm 12 | from .vis import images_to_video, make_checker_board_texture, nparray_to_vtk_matrix 13 | 14 | 15 | class PyvistaVisualizer: 16 | 17 | def __init__(self, init_T=6, enable_shadow=False, anti_aliasing=True, use_floor=False, 18 | add_cube=False, distance=5, elevation=20, azimuth=0, verbose=True): 19 | self.enable_shadow = enable_shadow 20 | self.anti_aliasing = anti_aliasing 21 | self.use_floor = use_floor 22 | self.add_cube = add_cube 23 | self.verbose = verbose 24 | self.pl = None 25 | # animation control 26 | self.fr = 0 27 | self.num_fr = 1 28 | self.fps_arr = [1, 2, 5, 10, 30, 40, 50, 60] 29 | self.T_arr = [1, 2, 4, 6, 8, 10, 15, 20, 30, 40, 50, 60] 30 | self.T = init_T 31 | self.paused = False 32 | self.reverse = False 33 | self.repeat = False 34 | # camera 35 | self.distance = distance 36 | 
self.elevation = elevation 37 | self.azimuth = azimuth 38 | 39 | def init_camera(self): 40 | pass 41 | # self.pl.camera_position = 'yz' 42 | # self.pl.camera.focal_point = (0, 0, 0) 43 | # self.pl.camera.position = (self.distance, 0, 0) 44 | # self.pl.camera.elevation = self.elevation 45 | # self.pl.camera.azimuth = self.azimuth 46 | # self.pl.camera.zoom(1.0) 47 | 48 | def set_camera_instrinsics(self, fx=None, fy=None, cx=None, cy=None, z_near=0.1, zfar=1000): 49 | if None in (fx, fy, cx, cy): 50 | wsize = np.array(self.pl.window_size) 51 | if platform.system() == 'Darwin': 52 | wsize //= 2 53 | fx = fy = wsize.max() 54 | cx, cy = 0.5 * wsize 55 | 56 | intrinsic_cam = pyrender.IntrinsicsCamera(fx, fy, cx, cy, z_near, zfar) 57 | proj_transform = intrinsic_cam.get_projection_matrix(*self.pl.window_size) 58 | self.pl.camera.SetExplicitProjectionTransformMatrix(nparray_to_vtk_matrix(proj_transform)) 59 | self.pl.camera.SetUseExplicitProjectionTransformMatrix(1) 60 | 61 | def init_scene(self, init_args): 62 | # self.pl.set_background('#DBDAD9') 63 | self.pl.set_background('#FCC2EB', top='#C9DFFF') # Classic Rose -> Lavender Blue 64 | # shadow 65 | if self.enable_shadow: 66 | self.pl.enable_shadows() 67 | if self.anti_aliasing: 68 | self.pl.enable_anti_aliasing() 69 | # floor 70 | if self.use_floor: 71 | wlh = (20.0, 20.0, 0.05) 72 | center = np.array([0, 0, -wlh[2] * 0.5]) 73 | self.floor_mesh = pyvista.Cube(center, *wlh) 74 | self.floor_mesh.t_coords *= 2 / self.floor_mesh.t_coords.max() 75 | tex = pyvista.numpy_to_texture(make_checker_board_texture('#81C6EB', '#D4F1F7')) 76 | self.pl.add_mesh(self.floor_mesh, texture=tex, ambient=0.2, diffuse=0.8, specular=0.8, specular_power=5, smooth_shading=True) 77 | else: 78 | self.floor_mesh = None 79 | # cube 80 | if self.add_cube: 81 | self.cube_mesh = pyvista.Box() 82 | self.cube_mesh.points *= 0.1 83 | self.cube_mesh.translate((-5, 0.0, 0.1)) 84 | self.pl.add_mesh(self.cube_mesh, color='orange', ambient=0.2, diffuse=0.8, 
specular=0.8, specular_power=10, smooth_shading=True) 85 | 86 | def update_camera(self, interactive): 87 | pass 88 | 89 | def update_scene(self): 90 | pass 91 | 92 | def setup_key_callback(self): 93 | 94 | def close(): 95 | exit(0) 96 | 97 | def slowdown(): 98 | if self.frame_mode == 'fps': 99 | self.fps = self.fps_arr[(self.fps_arr.index(self.fps) - 1) % len(self.fps_arr)] 100 | print(f'Setting fps to {self.fps}') 101 | else: 102 | self.T = self.T_arr[(self.T_arr.index(self.T) + 1) % len(self.T_arr)] 103 | 104 | def speedup(): 105 | if self.frame_mode == 'fps': 106 | self.fps = self.fps_arr[(self.fps_arr.index(self.fps) + 1) % len(self.fps_arr)] 107 | print(f'Setting fps to {self.fps}') 108 | else: 109 | self.T = self.T_arr[(self.T_arr.index(self.T) - 1) % len(self.T_arr)] 110 | 111 | def reverse(): 112 | self.reverse = not self.reverse 113 | 114 | def repeat(): 115 | self.repeat = not self.repeat 116 | 117 | def pause(): 118 | self.paused = not self.paused 119 | 120 | def next_frame(): 121 | if self.fr < self.num_fr - 1: 122 | self.fr += 1 123 | self.update_scene() 124 | 125 | def prev_frame(): 126 | if self.fr > 0: 127 | self.fr -= 1 128 | self.update_scene() 129 | 130 | def go_to_start(): 131 | self.fr = 0 132 | self.update_scene() 133 | 134 | def go_to_end(): 135 | self.fr = self.num_fr - 1 136 | self.update_scene() 137 | 138 | self.pl.add_key_event('q', close) 139 | self.pl.add_key_event('s', slowdown) 140 | self.pl.add_key_event('d', speedup) 141 | self.pl.add_key_event('a', reverse) 142 | self.pl.add_key_event('g', repeat) 143 | self.pl.add_key_event('Up', go_to_start) 144 | self.pl.add_key_event('Down', go_to_end) 145 | self.pl.add_key_event('space', pause) 146 | self.pl.add_key_event('Left', prev_frame) 147 | self.pl.add_key_event('Right', next_frame) 148 | 149 | def render(self, interactive): 150 | self.update_camera(interactive) 151 | self.pl.update() 152 | 153 | def tframe_animation_loop(self): 154 | t = 0 155 | while True: 156 | 
self.render(interactive=True) 157 | if t >= math.floor(self.T): 158 | if not self.reverse: 159 | if self.fr < self.num_fr - 1: 160 | self.fr += 1 161 | elif self.repeat: 162 | self.fr = 0 163 | elif self.reverse and self.fr > 0: 164 | self.fr -= 1 165 | self.update_scene() 166 | t = 0 167 | if not self.paused: 168 | t += 1 169 | 170 | def fps_animation_loop(self): 171 | last_render_time = time.time() 172 | self.update_scene() 173 | while True: 174 | while True: 175 | self.render(interactive=True) 176 | if time.time() - last_render_time >= (1 / self.fps - 0.002): 177 | break 178 | if not self.paused: 179 | if not self.reverse: 180 | if self.fr < self.num_fr - 1: 181 | self.fr += 1 182 | self.update_scene() 183 | elif self.repeat: 184 | self.fr = 0 185 | self.update_scene() 186 | else: 187 | return 188 | elif self.reverse and self.fr > 0: 189 | self.fr -= 1 190 | self.update_scene() 191 | # print('fps', 1 / (time.time() - last_render_time)) 192 | last_render_time = time.time() 193 | 194 | def show_animation(self, window_size=(800, 800), init_args=None, enable_shadow=None, frame_mode='fps', fps=30, repeat=False, show_axes=True): 195 | self.frame_mode = frame_mode 196 | self.fps = fps 197 | self.repeat = repeat 198 | if enable_shadow is not None: 199 | self.enable_shadow = enable_shadow 200 | if self.enable_shadow: 201 | self.pl = pyvista.Plotter(window_size=window_size, lighting='none') 202 | else: 203 | self.pl = pyvista.Plotter(window_size=window_size) 204 | self.init_camera(init_args) 205 | self.init_scene(init_args) 206 | self.update_scene() 207 | self.setup_key_callback() 208 | if show_axes: 209 | self.pl.show_axes() 210 | self.pl.show(interactive_update=True) 211 | if self.frame_mode == 'fps': 212 | self.fps_animation_loop() 213 | else: 214 | self.tframe_animation_loop() 215 | self.pl.close() 216 | 217 | def save_frame(self, fr, img_path): 218 | self.fr = fr 219 | self.update_scene() 220 | self.render(interactive=False) 221 | self.pl.screenshot(img_path) 222 | 
223 | def save_animation_as_video(self, video_path, init_args=None, window_size=(800, 800), enable_shadow=None, fps=30, frame_dir=None, cleanup=True): 224 | if platform.system() == 'Linux': 225 | pyvista.start_xvfb() 226 | if enable_shadow is not None: 227 | self.enable_shadow = enable_shadow 228 | if self.enable_shadow: 229 | window_size = (1000, 1000) 230 | self.pl = pyvista.Plotter(window_size=window_size, off_screen=True, lighting='none') 231 | else: 232 | self.pl = pyvista.Plotter(window_size=window_size, off_screen=True) 233 | self.init_camera(init_args) 234 | self.init_scene(init_args) 235 | self.pl.show(interactive_update=True) 236 | if frame_dir is None: 237 | frame_dir = tempfile.mkdtemp(prefix="visualizer3d-") 238 | else: 239 | if osp.exists(frame_dir): 240 | shutil.rmtree(frame_dir) 241 | os.makedirs(frame_dir) 242 | os.makedirs(osp.dirname(video_path), exist_ok=True) 243 | for fr in tqdm(range(self.num_fr)): 244 | self.save_frame(fr, f'{frame_dir}/{fr:06d}.png') 245 | images_to_video(frame_dir, video_path, fps=fps, verbose=self.verbose) 246 | print(f'Animation saved to {video_path}') 247 | if cleanup: 248 | shutil.rmtree(frame_dir) 249 | 250 | 251 | 252 | if __name__ == '__main__': 253 | 254 | visualizer = PyvistaVisualizer(add_cube=True, enable_shadow=True) 255 | visualizer.show_animation(show_axes=True) -------------------------------------------------------------------------------- /smpl_visualizer/vis_scenepic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import scenepic as sp 4 | from .smpl import SMPL, SMPL_MODEL_DIR, BASE_DIR 5 | from .torch_transform import quat_between_two_vec, quaternion_to_angle_axis 6 | from .vis import make_checker_board_texture, get_color_palette 7 | from tqdm import tqdm 8 | from PIL import Image 9 | import os 10 | 11 | NET_TEXTURE_PATH = os.path.join(BASE_DIR, 'net_texture.png') 12 | 13 | def get_transform(scale, trans, new_dir, 
old_dir=[1., 0., 0.]): 14 | trans = sp.Transforms.Translate(trans) 15 | scale = sp.Transforms.Scale(scale) 16 | new_dir = torch.from_numpy(new_dir).float() 17 | aa = quaternion_to_angle_axis(quat_between_two_vec(torch.tensor(old_dir).expand_as(new_dir), new_dir)).numpy() 18 | angle = np.linalg.norm(aa, axis=-1, keepdims=True) 19 | axis = aa / (angle + 1e-6) 20 | rotation = sp.Transforms.rotation_matrix_from_axis_angle(axis, angle) 21 | return trans.dot(rotation.dot(scale)) 22 | 23 | class SportVisualizerHTML(): 24 | def __init__(self, smpl=None, device=torch.device('cpu'), gender='male', show_ball=True, show_ball_target=True): 25 | self.show_ball = show_ball 26 | self.show_ball_target = show_ball_target 27 | if smpl is not None: 28 | self.smpl = smpl 29 | else: 30 | self.smpl = SMPL(SMPL_MODEL_DIR, create_transl=False, gender=gender).to(device) 31 | self.smpl_faces = self.smpl.faces 32 | self.smpl_seq = None 33 | self.racket_params = None 34 | self.ball_params = None 35 | 36 | def get_camera_intrinsics(self, width=1920, height=1080): 37 | fx = fy = max(width, height) 38 | K = np.diag([fx, fy, 1.0]) 39 | K[0, 2] = width / 2 40 | K[1, 2] = height / 2 41 | return K 42 | 43 | def load_default_camera(self, intrinsics): 44 | # this function loads an "OpenCV"-style camera representation 45 | # and converts it to a GL style for use in ScenePic 46 | # location = np.array(camera_info["location"], np.float32) 47 | # euler_angles = np.array(camera_info["rotation"], np.float32) 48 | # rotation = sp.Transforms.euler_angles_to_matrix(euler_angles, "XYZ") 49 | # translation = sp.Transforms.translate(location) 50 | # extrinsics = translation @ rotation 51 | 52 | img_width = intrinsics[0,2] * 2 53 | img_height = intrinsics[1,2] * 2 54 | 55 | aspect_ratio = img_width / img_height 56 | 57 | return sp.Camera(center=(0, -25, 3) if self.sport=='tennis' else (0, -13, 3), 58 | look_at=(0, 0, 0), up_dir=(0, 0, 1), 59 | fov_y_degrees=45.0, aspect_ratio=aspect_ratio, far_crop_distance=80) 60 
| 61 | def load_camera_from_ext_int(self, extrinsics, intrinsics): 62 | # this function loads an "OpenCV"-style camera representation 63 | # and converts it to a GL style for use in ScenePic 64 | # location = np.array(camera_info["location"], np.float32) 65 | # euler_angles = np.array(camera_info["rotation"], np.float32) 66 | # rotation = sp.Transforms.euler_angles_to_matrix(euler_angles, "XYZ") 67 | # translation = sp.Transforms.translate(location) 68 | # extrinsics = translation @ rotation 69 | 70 | img_width = intrinsics[0,2] * 2 71 | img_height = intrinsics[1,2] * 2 72 | fx = intrinsics[0,0] 73 | fy = intrinsics[1,1] 74 | 75 | # pred_f_pix = orig_img_h / 2. / np.tan(pred_vfov / 2.) 76 | 77 | vfov = 2 * np.arctan(2 * fx / img_height) 78 | world_to_camera = sp.Transforms.gl_world_to_camera(extrinsics) 79 | aspect_ratio = img_width / img_height 80 | projection = sp.Transforms.gl_projection(45, aspect_ratio, 0.01, 100) 81 | 82 | return sp.Camera(world_to_camera, projection) 83 | 84 | def init_players_and_rackets(self, smpl_seq=None, racket_seq=None): 85 | self.smpl_seq = smpl_seq 86 | self.smpl_verts = None 87 | 88 | if 'joint_rot' in smpl_seq: 89 | joint_rot = smpl_seq['joint_rot'] # num_actor x num_frames x (num_joints x 3) 90 | trans = smpl_seq['trans'] # num_actor x num_frames x 3 91 | 92 | self.smpl_motion = self.smpl( 93 | global_orient=joint_rot[..., :3].view(-1, 3), 94 | body_pose=joint_rot[..., 3:].view(-1, 69), 95 | betas=smpl_seq['betas'].view(-1, 1, 10).expand(-1, joint_rot.shape[1], 10).reshape(-1, 10), 96 | root_trans = trans.view(-1, 3), 97 | return_full_pose=True, 98 | orig_joints=True 99 | ) 100 | 101 | self.smpl_verts = self.smpl_motion.vertices.reshape(*joint_rot.shape[:-1], -1, 3) 102 | self.smpl_joints = self.smpl_motion.joints.reshape(*joint_rot.shape[:-1], -1, 3) 103 | 104 | if self.correct_root_height: 105 | diff_root_height = torch.min(self.smpl_joints[:, :, 10:12, 2], dim=2)[0].view(*trans.shape[:2], 1) 106 | self.smpl_joints[:, :, :, 2] 
-= diff_root_height 107 | if self.smpl_verts is not None: 108 | self.smpl_verts[:, :, :, 2] -= diff_root_height 109 | trans[:, :, 2:] -= diff_root_height 110 | 111 | if racket_seq is not None: 112 | num_actors, num_frames = trans.shape[:2] 113 | for i in range(num_actors): 114 | for j in range(num_frames): 115 | if racket_seq[i][j] is not None: 116 | racket_seq[i][j]['root'] = trans[i, j].cpu().numpy() 117 | self.racket_params = racket_seq 118 | 119 | self.num_actors = self.smpl_joints.shape[0] 120 | self.num_fr = self.smpl_joints.shape[1] 121 | 122 | def create_court(self, scene, frame): 123 | if self.sport == 'tennis': 124 | # Court 125 | court_mesh = scene.create_mesh() 126 | 127 | wlh = (10.97, 11.89*2, 0.05) 128 | center = np.array([0, 0, -wlh[2] * 0.5]) 129 | court_mesh.add_cube( 130 | color=np.array([74, 96, 157]) / 255., 131 | transform=np.dot(sp.Transforms.Translate(center), sp.Transforms.Scale(wlh))) 132 | 133 | # Court lines (vertical) 134 | for x, l in zip([-10.97/2, -8.23/2, 0, 8.23/2, 10.97/2], [23.77, 23.77, 12.8, 23.77, 23.77]): 135 | wlh = (0.05, l, 0.05) 136 | center = np.array([x, 0, -wlh[2] * 0.5 + 0.01]) 137 | court_mesh.add_cube( 138 | color=(1, 1, 1), 139 | transform=np.dot(sp.Transforms.Translate(center), sp.Transforms.Scale(wlh))) 140 | 141 | # Court lines (horizontal) 142 | for y, w in zip([-11.89, -6.4, 0, 6.4, 11.89], [10.97, 8.23, 10.97, 8.23, 10.97]): 143 | wlh = (w, 0.05, 0.05) 144 | center = np.array([0, y, -wlh[2] * 0.5 + 0.01]) 145 | court_mesh.add_cube( 146 | color=(1, 1, 1), 147 | transform=np.dot(sp.Transforms.Translate(center), sp.Transforms.Scale(wlh))) 148 | 149 | frame.add_mesh(court_mesh) 150 | 151 | # Post 152 | post_mesh = scene.create_mesh() 153 | for x in [-0.91-10.97/2, 0.91+10.97/2]: 154 | wlh = (0.05, 0.05, 1.2) 155 | center = np.array([x, 0, wlh[2] * 0.5]) 156 | post_mesh.add_cube( 157 | color=np.array([189, 116, 39]) / 255., 158 | transform=np.dot(sp.Transforms.Translate(center), sp.Transforms.Scale(wlh))) 159 | 
frame.add_mesh(post_mesh) 160 | 161 | # Net 162 | wlh = (10.97+0.91*2, 0.01, 1.07) 163 | center = np.array([0, 0, 1.07/2]) 164 | net_texture = scene.create_image(image_id="net") 165 | # Failed to use checker_board texture because scenepic does not support uv > 1 166 | # checker_board = make_checker_board_texture('#FFFFFF', '#AAAAAA', width=10) 167 | # net_texture.from_numpy(checker_board) 168 | net_texture.load(NET_TEXTURE_PATH) 169 | net_mesh = scene.create_mesh(double_sided=True, use_texture_alpha=True, texture_id='net') 170 | net_mesh.add_image( 171 | origin=(-wlh[0]/2, 0, 0), 172 | x_axis=(wlh[0], 0, 0), 173 | y_axis=(0, 0, wlh[2]), 174 | uv_1=(1, 0), 175 | uv_2=(1, wlh[2]/wlh[0]), 176 | uv_3=(0, wlh[2]/wlh[0]), 177 | ) 178 | frame.add_mesh(net_mesh) 179 | 180 | elif self.sport == 'badminton': 181 | # Court 182 | court_mesh = scene.create_mesh() 183 | wlh = [6.1, 13.41, 0.05] 184 | center = np.array([0, 0, -wlh[2] * 0.5]) 185 | court_mesh.add_cube( 186 | color=np.array([74, 96, 157]) / 255., 187 | transform=np.dot(sp.Transforms.Translate(center), sp.Transforms.Scale(wlh))) 188 | 189 | # Court lines (vertical) 190 | for x in [-3.05, -2.6, 2.6, 3.05]: 191 | wlh = (0.05, 13.41, 0.05) 192 | center = np.array([x, 0, -wlh[2] * 0.5 + 0.01]) 193 | court_mesh.add_cube( 194 | color=(1, 1, 1), 195 | transform=np.dot(sp.Transforms.Translate(center), sp.Transforms.Scale(wlh))) 196 | for x, y, l in zip([0, 0], [-(1.98+6.71)/2, (1.98+6.71)/2], [3.96+0.76, 3.96+0.76]): 197 | wlh = (0.05, l, 0.05) 198 | center = np.array([x, y, -wlh[2] * 0.5 + 0.01]) 199 | court_mesh.add_cube( 200 | color=(1, 1, 1), 201 | transform=np.dot(sp.Transforms.Translate(center), sp.Transforms.Scale(wlh))) 202 | 203 | # Court lines (horizontal) 204 | for y in [-6.71, -3.96-1.98, -1.98, 1.98, 1.98+3.96, 6.71]: 205 | wlh = (6.1, 0.05, 0.05) 206 | center = np.array([0, y, -wlh[2] * 0.5 + 0.01]) 207 | court_mesh.add_cube( 208 | color=(1, 1, 1), 209 | transform=np.dot(sp.Transforms.Translate(center), 
sp.Transforms.Scale(wlh))) 210 | 211 | frame.add_mesh(court_mesh) 212 | 213 | # Post 214 | post_mesh = scene.create_mesh() 215 | for x in [-3.05, 3.05]: 216 | wlh = (0.05, 0.05, 1.55) 217 | center = np.array([x, 0, wlh[2] * 0.5]) 218 | post_mesh.add_cube( 219 | color=np.array([189, 116, 39]) / 255., 220 | transform=np.dot(sp.Transforms.Translate(center), sp.Transforms.Scale(wlh))) 221 | frame.add_mesh(post_mesh) 222 | 223 | # Net 224 | wlh = (6.1, 0.01, 0.79) 225 | center = np.array([0, 0, (0.76+1.55)/2]) 226 | net_texture = scene.create_image(image_id="net") 227 | net_texture.load(NET_TEXTURE_PATH) 228 | net_mesh = scene.create_mesh(double_sided=True, use_texture_alpha=True, texture_id='net') 229 | net_mesh.add_image( 230 | origin=(-wlh[0]/2, 0, 0.76), 231 | x_axis=(wlh[0], 0, 0), 232 | y_axis=(0, 0, wlh[2]), 233 | uv_1=(1, 0), 234 | uv_2=(1, wlh[2]/wlh[0]), 235 | uv_3=(0, wlh[2]/wlh[0]), 236 | ) 237 | frame.add_mesh(net_mesh) 238 | 239 | # Floor 240 | floor_mesh = scene.create_mesh() 241 | wlh = (20, 40, 0.05) 242 | center = np.array([0, 0, -wlh[2] * 0.5 - 0.01]) 243 | floor_mesh.add_cube( 244 | color=np.array([118, 151, 113]) / 255., 245 | transform=np.dot(sp.Transforms.Translate([center]), sp.Transforms.Scale(wlh))) 246 | frame.add_mesh(floor_mesh) 247 | 248 | def create_racket(self, scene, params): 249 | if params is None: return None 250 | 251 | if self.sport == 'tennis': 252 | racket_mesh = scene.create_mesh() 253 | # Head 254 | # racket_mesh.add_disc( 255 | # color=np.array([50, 50, 50]) / 255., 256 | # segment_count=50, 257 | # fill_triangles=False, 258 | # add_wireframe=True, 259 | # transform=get_transform( 260 | # scale=[1, 0.3, 0.26], 261 | # trans=params['head_center'] + params['root'], 262 | # new_dir=params['racket_normal'] 263 | # )) 264 | racket_mesh.add_cylinder( 265 | color=np.array([50, 50, 50]) / 255., 266 | transform=get_transform( 267 | scale=[0.02, 0.3, 0.3], 268 | trans=params['head_center'] + params['root'], 269 | 
new_dir=params['racket_normal'] 270 | )) 271 | # Shaft 272 | racket_mesh.add_cylinder( 273 | color=np.array([50, 50, 50]) / 255., 274 | transform=get_transform( 275 | scale=[0.15/np.cos(np.pi/10), 0.015, 0.015], 276 | trans=params['shaft_left_center'] + params['root'], 277 | new_dir=params['shaft_left_dir'] 278 | )) 279 | racket_mesh.add_cylinder( 280 | color=np.array([50, 50, 50]) / 255., 281 | transform=get_transform( 282 | scale=[0.15/np.cos(np.pi/10), 0.015, 0.015], 283 | trans=params['shaft_right_center'] + params['root'], 284 | new_dir=params['shaft_right_dir'] 285 | )) 286 | # Handle 287 | racket_mesh.add_cylinder( 288 | color=np.array([50, 50, 50]) / 255., 289 | transform=get_transform( 290 | scale=[0.2, 0.03, 0.03], 291 | trans=params['handle_center'] + params['root'], 292 | new_dir=params['racket_dir'] 293 | )) 294 | elif self.sport == 'badminton': 295 | racket_mesh = scene.create_mesh() 296 | 297 | # Head 298 | racket_mesh.add_disc( 299 | color=np.array([50, 50, 50]) / 255., 300 | segment_count=50, 301 | fill_triangles=False, 302 | add_wireframe=True, 303 | transform=get_transform( 304 | scale=[1, 0.28, 0.25], 305 | trans=params['head_center'] + params['root'], 306 | new_dir=params['racket_normal'] 307 | )) 308 | # Shaft 309 | racket_mesh.add_cylinder( 310 | color=np.array([50, 50, 50]) / 255., 311 | transform=get_transform( 312 | scale=[0.25, 0.01, 0.01], 313 | trans=params['shaft_center'] + params['root'], 314 | new_dir=params['racket_dir'] 315 | )) 316 | # Handle 317 | racket_mesh.add_cylinder( 318 | color=np.array([50, 50, 50]) / 255., 319 | transform=get_transform( 320 | scale=[0.15, 0.0254, 0.0254], 321 | trans=params['handle_center'] + params['root'], 322 | new_dir=params['racket_dir'] 323 | )) 324 | return racket_mesh 325 | 326 | 327 | def create_canvas(self, scene): 328 | canvas = scene.create_canvas_3d(width=self.image_width, height=self.image_height) 329 | cam_intrinsics = self.get_camera_intrinsics(self.image_width, self.image_height) 330 | 
331 | for i in tqdm(range(self.num_fr)): 332 | frame = canvas.create_frame() 333 | 334 | # Add coordinate axis 335 | # coord_ax = scene.create_mesh() 336 | # coord_ax.add_coordinate_axes() 337 | # frame.add_mesh(coord_ax) 338 | 339 | self.create_court(scene, frame) 340 | 341 | colors = get_color_palette(self.num_actors, 'rainbow', use_float=True) 342 | for j in range(self.num_actors): 343 | if self.num_actors == 1: 344 | smpl_mesh = scene.create_mesh(shared_color=(0.7, 0.7, 0.7)) 345 | elif self.num_actors == 2: 346 | smpl_mesh = scene.create_mesh(shared_color=(0.7, 0.7, 0.7) if j == 0 else (0.5, 0.5, 0.5)) 347 | else: 348 | smpl_mesh = scene.create_mesh(shared_color=colors[j]) 349 | if self.smpl_seq is None or self.smpl_seq['joint_rot'][j, i].sum() != 0: 350 | smpl_mesh.add_mesh_without_normals(self.smpl_verts[j, i].contiguous().cpu().numpy(), self.smpl_faces) 351 | frame.add_mesh(smpl_mesh) 352 | 353 | if self.racket_params is not None: 354 | racket_mesh = self.create_racket(scene, self.racket_params[j][i]) 355 | if racket_mesh is not None: frame.add_mesh(racket_mesh) 356 | 357 | if self.show_ball and self.ball_params is not None: 358 | ball_mesh = scene.create_mesh() 359 | wlh = (0.1, 0.1, 0.1) 360 | center = self.ball_params[j][i] 361 | ball_mesh.add_sphere( 362 | color=np.array([223,255,79]) / 255., 363 | transform=np.dot(sp.Transforms.Translate(center), sp.Transforms.Scale(wlh))) 364 | frame.add_mesh(ball_mesh) 365 | 366 | if self.show_ball_target and self.ball_targets is not None: 367 | ball_target_mesh = scene.create_mesh() 368 | ball_target_mesh.add_cylinder( 369 | color=np.array([1., 0, 0]), 370 | transform=get_transform( 371 | scale=(0.05, 1.0, 1.0), 372 | trans=self.ball_targets[j][i], 373 | new_dir=np.array([0, 0, 1.0]))) 374 | frame.add_mesh(ball_target_mesh) 375 | 376 | frame.camera = self.load_default_camera(cam_intrinsics) 377 | return canvas 378 | 379 | def save_animation_as_html(self, init_args, html_path="demo.html"): 380 | scene = sp.Scene() 
381 | 382 | self.image_width = init_args.get('image_width', 1920) 383 | self.image_height = init_args.get('image_height', 1080) 384 | self.sport = init_args.get('sport', 'tennis') 385 | self.correct_root_height = init_args.get('correct_root_height') 386 | 387 | if init_args.get('smpl_verts') is not None: 388 | self.smpl_verts = init_args['smpl_verts'].cpu() 389 | self.racket_params = init_args.get('racket_params') 390 | self.ball_params = init_args.get('ball_params') 391 | self.ball_targets = init_args.get('ball_targets') 392 | self.num_actors, self.num_fr = self.smpl_verts.shape[:2] 393 | else: 394 | self.init_players_and_rackets(smpl_seq=init_args.get('smpl_seq'), 395 | racket_seq=init_args.get('racket_seq')) 396 | 397 | self.create_canvas(scene) 398 | 399 | scene.save_as_html(html_path, title="sport visualizer") 400 | print(f"Saved animation as html into {html_path}") -------------------------------------------------------------------------------- /smpl_visualizer/vis_sport.py: -------------------------------------------------------------------------------- 1 | import pyvista 2 | import torch 3 | import numpy as np 4 | import platform 5 | from pyvista.plotting.tools import parse_color 6 | from vtk import vtkTransform 7 | import pdb 8 | from math import pi 9 | import time 10 | from .torch_transform import quat_apply, quat_between_two_vec, quaternion_to_angle_axis, angle_axis_to_quaternion 11 | from .vis_pyvista import PyvistaVisualizer 12 | from .smpl import SMPL, SMPL_MODEL_DIR 13 | from .vis import make_checker_board_texture, get_color_palette 14 | 15 | 16 | class SMPLActor(): 17 | 18 | def __init__(self, pl, verts, faces, color='#FF8A82', visible=True): 19 | self.pl = pl 20 | self.verts = verts 21 | self.face = faces 22 | self.mesh = pyvista.PolyData(verts, faces) 23 | self.actor = self.pl.add_mesh(self.mesh, color=color, pbr=True, metallic=0.0, roughness=0.3, diffuse=1) 24 | self.set_visibility(visible) 25 | 26 | def update_verts(self, new_verts): 27 | 
self.mesh.points[...] = new_verts 28 | self.mesh.compute_normals(inplace=True) 29 | 30 | def set_opacity(self, opacity): 31 | self.actor.GetProperty().SetOpacity(opacity) 32 | 33 | def set_visibility(self, flag): 34 | self.actor.SetVisibility(flag) 35 | 36 | def set_color(self, color): 37 | rgb_color = parse_color(color) 38 | self.actor.GetProperty().SetColor(rgb_color) 39 | 40 | 41 | class SkeletonActor(): 42 | 43 | def __init__(self, pl, joint_parents, joint_color='green', bone_color='yellow', joint_radius=0.03, bone_radius=0.02, visible=True): 44 | self.pl = pl 45 | self.joint_parents = joint_parents 46 | self.joint_meshes = [] 47 | self.joint_actors = [] 48 | self.bone_meshes = [] 49 | self.bone_actors = [] 50 | self.bone_pairs = [] 51 | for j, pa in enumerate(self.joint_parents): 52 | # joint 53 | joint_mesh = pyvista.Sphere(radius=joint_radius, center=(0, 0, 0), theta_resolution=10, phi_resolution=10) 54 | # joint_actor = self.pl.add_mesh(joint_mesh, color=joint_color, pbr=True, metallic=0.0, roughness=0.3, diffuse=1) 55 | joint_actor = self.pl.add_mesh(joint_mesh, color=joint_color, ambient=0.3, diffuse=0.5, specular=0.8, specular_power=5, smooth_shading=True) 56 | self.joint_meshes.append(joint_mesh) 57 | self.joint_actors.append(joint_actor) 58 | # bone 59 | if pa >= 0: 60 | bone_mesh = pyvista.Cylinder(radius=bone_radius, center=(0, 0, 0), direction=(0, 0, 1), resolution=30) 61 | # bone_actor = self.pl.add_mesh(bone_mesh, color=bone_color, pbr=True, metallic=0.0, roughness=0.3, diffuse=1) 62 | bone_actor = self.pl.add_mesh(bone_mesh, color=bone_color, ambient=0.3, diffuse=0.5, specular=0.8, specular_power=5, smooth_shading=True) 63 | self.bone_meshes.append(bone_mesh) 64 | self.bone_actors.append(bone_actor) 65 | self.bone_pairs.append((j, pa)) 66 | self.set_visibility(visible) 67 | 68 | def update_joints(self, jpos): 69 | # joint 70 | for actor, pos in zip(self.joint_actors, jpos): 71 | trans = vtkTransform() 72 | trans.Translate(*pos) 73 | 
actor.SetUserTransform(trans) 74 | # bone 75 | vec = [] 76 | for actor, (j, pa) in zip(self.bone_actors, self.bone_pairs): 77 | vec.append((jpos[j] - jpos[pa])) 78 | vec = np.stack(vec) 79 | dist = np.linalg.norm(vec, axis=-1) 80 | vec = torch.tensor(vec / dist[..., None]) 81 | aa = quaternion_to_angle_axis(quat_between_two_vec(torch.tensor([0., 0., 1.]).expand_as(vec), vec)).numpy() 82 | angle = np.linalg.norm(aa, axis=-1, keepdims=True) 83 | axis = aa / (angle + 1e-6) 84 | 85 | for actor, (j, pa), angle_i, axis_i, dist_i in zip(self.bone_actors, self.bone_pairs, angle, axis, dist): 86 | trans = vtkTransform() 87 | trans.Translate(*(jpos[pa] + jpos[j]) * 0.5) 88 | trans.RotateWXYZ(np.rad2deg(angle_i), *axis_i) 89 | trans.Scale(1, 1, dist_i) 90 | actor.SetUserTransform(trans) 91 | 92 | def set_opacity(self, opacity): 93 | for actor in self.joint_actors: 94 | actor.GetProperty().SetOpacity(opacity) 95 | for actor in self.bone_actors: 96 | actor.GetProperty().SetOpacity(opacity) 97 | 98 | def set_visibility(self, flag): 99 | for actor in self.joint_actors: 100 | actor.SetVisibility(flag) 101 | for actor in self.bone_actors: 102 | actor.SetVisibility(flag) 103 | 104 | def set_color(self, color): 105 | rgb_color = parse_color(color) 106 | for actor in self.joint_actors: 107 | actor.GetProperty().SetColor(rgb_color) 108 | for actor in self.jbone_actors: 109 | actor.GetProperty().SetColor(rgb_color) 110 | 111 | 112 | def get_transform(new_pos, new_dir): 113 | trans = vtkTransform() 114 | trans.Translate(new_pos) 115 | new_dir = torch.from_numpy(new_dir).float() 116 | aa = quaternion_to_angle_axis(quat_between_two_vec(torch.tensor([0., 0., 1.]).expand_as(new_dir), new_dir)).numpy() 117 | angle = np.linalg.norm(aa, axis=-1, keepdims=True) 118 | axis = aa / (angle + 1e-6) 119 | trans.RotateWXYZ(np.rad2deg(angle), *axis) 120 | return trans 121 | 122 | 123 | class RacketActor(): 124 | 125 | def __init__(self, pl, sport='tennis', debug=True): 126 | self.pl = pl 127 | 
self.sport = sport 128 | self.debug = debug 129 | if self.sport == 'badminton': 130 | self.net_mesh = pyvista.Cylinder(center=(0, 0, 0), radius=0.25/2, height=0.01, direction=(0, 0, 1)) 131 | self.net_mesh.active_t_coords *= 1000 132 | tex = pyvista.numpy_to_texture(make_checker_board_texture('#FFFFFF', '#AAAAAA', width=10)) 133 | self.net_actor = self.pl.add_mesh(self.net_mesh, texture=tex, ambient=0.2, diffuse=0.8, opacity=0.1, smooth_shading=True) 134 | 135 | self.head_mesh = pyvista.Tube(pointa=(0, 0, -0.005), pointb=(0, 0, 0.005), radius=0.25/2) 136 | self.head_actor = self.pl.add_mesh(self.head_mesh, color='black', ambient=0.3, diffuse=0.5, smooth_shading=True) 137 | 138 | self.shaft_mesh = pyvista.Cylinder(center=(0, 0, 0), radius=0.005, height=0.25, direction=(0, 0, 1)) 139 | self.shaft_actor = self.pl.add_mesh(self.shaft_mesh, color='black', ambient=0.3, diffuse=0.5, smooth_shading=True) 140 | 141 | self.handle_mesh = pyvista.Cylinder(center=(0, 0, 0), radius=0.0254/2, height=0.15, direction=(0, 0, 1)) 142 | self.handle_actor = self.pl.add_mesh(self.handle_mesh, color='#AAAAAA', ambient=0.3, diffuse=0.5, smooth_shading=True) 143 | 144 | self.actors = [self.head_actor, self.net_actor, self.shaft_actor, self.handle_actor] 145 | elif self.sport == 'tennis': 146 | self.net_mesh = pyvista.Cylinder(center=(0, 0, 0), radius=0.15, height=0.01, direction=(0, 0, 1)) 147 | self.net_mesh.active_t_coords *= 1000 148 | tex = pyvista.numpy_to_texture(make_checker_board_texture('#FFFFFF', '#AAAAAA', width=10)) 149 | self.net_actor = self.pl.add_mesh(self.net_mesh, texture=tex, ambient=0.2, diffuse=0.8, opacity=0.1, smooth_shading=True) 150 | 151 | self.head_mesh = pyvista.Tube(pointa=(0, 0, -0.01), pointb=(0, 0, 0.01), radius=0.15) 152 | self.head_actor = self.pl.add_mesh(self.head_mesh, color='black', ambient=0.3, diffuse=0.5, smooth_shading=True) 153 | 154 | self.shaft_left_mesh = pyvista.Cylinder(center=(0, 0, 0), radius=0.015/2, height=0.15/np.cos(np.pi/10), 
direction=(0, 0, 1)) 155 | self.shaft_left_actor = self.pl.add_mesh(self.shaft_left_mesh, color='black', ambient=0.3, diffuse=0.5, smooth_shading=True) 156 | 157 | self.shaft_right_mesh = pyvista.Cylinder(center=(0, 0, 0), radius=0.015/2, height=0.15/np.cos(np.pi/10), direction=(0, 0, 1)) 158 | self.shaft_right_actor = self.pl.add_mesh(self.shaft_right_mesh, color='black', ambient=0.3, diffuse=0.5, smooth_shading=True) 159 | 160 | self.handle_mesh = pyvista.Cylinder(center=(0, 0, 0), radius=0.03/2, height=0.2, direction=(0, 0, 1)) 161 | self.handle_actor = self.pl.add_mesh(self.handle_mesh, color='black', ambient=0.3, diffuse=0.5, smooth_shading=True) 162 | 163 | self.actors = [self.head_actor, self.net_actor, self.shaft_left_actor, self.shaft_right_actor, self.handle_actor] 164 | 165 | if self.debug: 166 | self.normal_mesh = pyvista.Cylinder(center=(0, 0, 0.1), radius=0.01, height=0.2, direction=(0, 0, 1)) 167 | self.normal_actor = self.pl.add_mesh(self.normal_mesh, color='red', diffuse=1, smooth_shading=True) 168 | self.actors += [self.normal_actor] 169 | 170 | def update_racket(self, params): 171 | if self.sport == 'badminton': 172 | self.head_actor.SetUserTransform(get_transform(params['head_center'] + params['root'], params['racket_normal'])) 173 | self.net_actor.SetUserTransform(get_transform(params['head_center'] + params['root'], params['racket_normal'])) 174 | self.shaft_actor.SetUserTransform(get_transform(params['shaft_center'] + params['root'], params['racket_dir'])) 175 | self.handle_actor.SetUserTransform(get_transform(params['handle_center'] + params['root'], params['racket_dir'])) 176 | elif self.sport == 'tennis': 177 | self.head_actor.SetUserTransform(get_transform(params['head_center'] + params['root'], params['racket_normal'])) 178 | self.net_actor.SetUserTransform(get_transform(params['head_center'] + params['root'], params['racket_normal'])) 179 | self.shaft_left_actor.SetUserTransform(get_transform(params['shaft_left_center'] + 
params['root'], params['shaft_left_dir'])) 180 | self.shaft_right_actor.SetUserTransform(get_transform(params['shaft_right_center'] + params['root'], params['shaft_right_dir'])) 181 | self.handle_actor.SetUserTransform(get_transform(params['handle_center'] + params['root'], params['racket_dir'])) 182 | if self.debug: 183 | self.normal_actor.SetUserTransform(get_transform(params['head_center'] + params['root'], params['racket_normal'])) 184 | 185 | def set_visibility(self, flag): 186 | for actor in self.actors: 187 | actor.SetVisibility(flag) 188 | 189 | 190 | class TargetRecoveryActor(): 191 | 192 | def __init__(self, pl): 193 | self.pl = pl 194 | self.marker_mesh = pyvista.Disc(center=[0,0,0], inner=0.1, outer=0.2) 195 | self.actor = self.pl.add_mesh(self.marker_mesh, color='red', ambient=0.3, diffuse=0.5, smooth_shading=True) 196 | 197 | def update_target(self, pos): 198 | trans = vtkTransform() 199 | trans.Translate([pos[0], pos[1], 0.01]) 200 | self.actor.SetUserTransform(trans) 201 | 202 | def set_visibility(self, flag): 203 | self.actor.SetVisibility(flag) 204 | 205 | 206 | class TargetReactionActor(): 207 | 208 | def __init__(self, pl): 209 | self.pl = pl 210 | self.outer_mesh = pyvista.Sphere(center=[0,0,0], radius=0.2) 211 | self.actor_outer = self.pl.add_mesh(self.outer_mesh, color='orange', ambient=0.3, diffuse=0.5, smooth_shading=True, opacity=0.5) 212 | self.inner_mesh = pyvista.Sphere(center=[0,0,0], radius=0.2) 213 | self.actor_inner = self.pl.add_mesh(self.inner_mesh, color='red', ambient=0.3, diffuse=0.5, smooth_shading=True) 214 | 215 | def update_target(self, pos, time=None): 216 | trans = vtkTransform() 217 | trans.Translate([pos[0], pos[1], pos[2]]) 218 | self.actor_outer.SetUserTransform(trans) 219 | trans = vtkTransform() 220 | trans.Translate([pos[0], pos[1], pos[2]]) 221 | trans.Scale([time, time, time]) 222 | self.actor_inner.SetUserTransform(trans) 223 | 224 | def set_visibility(self, flag): 225 | self.actor_inner.SetVisibility(flag) 226 
| self.actor_outer.SetVisibility(flag) 227 | 228 | 229 | class BallActor(): 230 | 231 | def __init__(self, pl, sport='tennis', color='red', blur=False, num_exposure=30, real_shadow=False): 232 | self.pl = pl 233 | self.sport = sport 234 | self.blur = blur 235 | self.num_exposure = num_exposure 236 | self.real_shadow = real_shadow 237 | if sport == 'tennis': 238 | ball_mesh = pyvista.Sphere(center=[0,0,0], radius=0.05) 239 | self.actor = self.pl.add_mesh(ball_mesh, color=color, 240 | ambient=0.3, diffuse=1, smooth_shading=True) 241 | 242 | # for motion blur 243 | if self.blur: 244 | self.actors = [] 245 | for i in range(self.num_exposure): 246 | ball_mesh = pyvista.Sphere(center=[0,0,0], radius=0.05) 247 | self.actors += [self.pl.add_mesh(ball_mesh, color=color, 248 | ambient=0.3, diffuse=0.8, smooth_shading=True, 249 | opacity=5./num_exposure if not real_shadow else 1)] 250 | 251 | if not self.real_shadow: 252 | shadow_mesh = pyvista.Circle(radius=0.05) 253 | self.shadow_actor = self.pl.add_mesh(shadow_mesh, color='#101010', 254 | diffuse=1, smooth_shading=True) 255 | 256 | if self.blur: 257 | self.shadow_actors = [] 258 | for i in range(self.num_exposure): 259 | shadow_mesh = pyvista.Circle(radius=0.05) 260 | self.shadow_actors += [self.pl.add_mesh(shadow_mesh, color='#101010', 261 | diffuse=1, smooth_shading=True)] 262 | else: 263 | NotImplemented 264 | self.set_visibility(False) 265 | 266 | def update_ball(self, params): 267 | if params is None: 268 | self.set_visibility(False) 269 | return 270 | 271 | if self.blur and params.get('pos_blur') is not None: 272 | for i in range(self.num_exposure): 273 | trans = vtkTransform() 274 | pos = params['pos_blur'][i] 275 | trans.Translate([pos[0], pos[1], pos[2]]) 276 | self.actors[i].SetUserTransform(trans) 277 | self.actors[i].SetVisibility(True) 278 | 279 | if not self.real_shadow: 280 | trans = vtkTransform() 281 | trans.Translate([pos[0], pos[1], 0]) 282 | self.shadow_actors[i].SetUserTransform(trans) 283 | 
self.shadow_actors[i].SetVisibility(True) 284 | else: 285 | trans = vtkTransform() 286 | pos = params['pos'] 287 | trans.Translate([pos[0], pos[1], pos[2]]) 288 | self.actor.SetUserTransform(trans) 289 | self.actor.SetVisibility(True) 290 | 291 | # self.ang_vel_actor.SetUserTransform(get_transform(params['pos'].cpu().numpy(), params['ang_vel'].cpu().numpy())) 292 | # self.ang_vel_actor.SetVisibility(True) 293 | 294 | if not self.real_shadow: 295 | trans = vtkTransform() 296 | trans.Translate([pos[0], pos[1], 0.01]) 297 | self.shadow_actor.SetUserTransform(trans) 298 | self.shadow_actor.SetVisibility(True) 299 | 300 | def set_visibility(self, flag): 301 | self.actor.SetVisibility(flag) 302 | # self.ang_vel_actor.SetVisibility(flag) 303 | if self.blur: 304 | for actor in self.actors: 305 | actor.SetVisibility(flag) 306 | if not self.real_shadow: 307 | self.shadow_actor.SetVisibility(flag) 308 | if self.blur: 309 | for actor in self.shadow_actors: 310 | actor.SetVisibility(flag) 311 | 312 | 313 | class TargetBounceActor(): 314 | 315 | def __init__(self, pl): 316 | self.pl = pl 317 | self.marker_mesh = pyvista.Disc(center=[0,0,0], inner=0.1, outer=0.2) 318 | self.actor = self.pl.add_mesh(self.marker_mesh, color='red', ambient=0.3, diffuse=0.5, smooth_shading=True) 319 | self.set_visibility(False) 320 | 321 | def update_target(self, pos): 322 | trans = vtkTransform() 323 | trans.Translate([pos[0], pos[1], 0.01]) 324 | self.actor.SetUserTransform(trans) 325 | self.set_visibility(True) 326 | 327 | def set_visibility(self, flag): 328 | self.actor.SetVisibility(flag) 329 | 330 | 331 | class SportVisualizer(PyvistaVisualizer): 332 | 333 | def __init__(self, show_smpl=False, show_skeleton=True, show_racket=False, 334 | show_target=False, show_ball=False, show_ball_target=False, show_stats=True, 335 | track_first_actor=False, track_ball=False, 336 | enable_shadow=False, 337 | gender='male', 338 | correct_root_height=False, device=torch.device('cpu'), **kwargs): 339 | 340 | 
super().__init__(**kwargs) 341 | self.show_smpl = show_smpl 342 | self.show_skeleton = show_skeleton 343 | self.show_racket = show_racket 344 | self.show_target = show_target 345 | self.show_ball = show_ball 346 | self.show_ball_target = show_ball_target 347 | self.show_stats = show_stats 348 | self.track_first_actor = track_first_actor 349 | self.track_ball = track_ball 350 | self.enable_shadow = enable_shadow 351 | self.correct_root_height = correct_root_height 352 | self.camera = 'front' 353 | self.sport = 'tennis' 354 | 355 | self.smpl = SMPL(SMPL_MODEL_DIR, create_transl=False, gender=gender).to(device) 356 | faces = self.smpl.faces.copy() 357 | self.smpl_faces = faces = np.hstack([np.ones_like(faces[:, [0]]) * 3, faces]) 358 | self.smpl_joint_parents = self.smpl.parents.cpu().numpy() 359 | self.smpl_verts = None 360 | self.smpl_joints = None 361 | self.racket_params = None 362 | self.device = device 363 | 364 | self.forward = False 365 | 366 | def setup_animation(self, smpl_seq=None, racket_seq=None, ball_seq=None): 367 | self.smpl_seq = smpl_seq 368 | self.smpl_verts = None 369 | 370 | if 'joint_rot' in smpl_seq: 371 | joint_rot = smpl_seq['joint_rot'] # num_actor x num_frames x (num_joints x 3) 372 | trans = smpl_seq['trans'] # num_actor x num_frames x 3 373 | 374 | self.smpl_motion = self.smpl( 375 | global_orient=joint_rot[..., :3].view(-1, 3), 376 | body_pose=joint_rot[..., 3:].view(-1, 69), 377 | betas=smpl_seq['betas'].view(-1, 1, 10).expand(-1, joint_rot.shape[1], 10).reshape(-1, 10), 378 | root_trans = trans.view(-1, 3), 379 | return_full_pose=True, 380 | orig_joints=True 381 | ) 382 | 383 | self.smpl_verts = self.smpl_motion.vertices.reshape(*joint_rot.shape[:-1], -1, 3) 384 | if 'joint_pos' not in smpl_seq: 385 | self.smpl_joints = self.smpl_motion.joints.reshape(*joint_rot.shape[:-1], -1, 3) 386 | # set all 0 if joint rot is all 0 (invalid pose) 387 | num_actors, num_frames = self.smpl_joints.shape[:2] 388 | for i in range(num_actors): 389 | for j 
in range(num_frames): 390 | if joint_rot[i, j].sum() == 0: 391 | self.smpl_joints[i, j, :, :] = 0 392 | 393 | if 'joint_pos' in smpl_seq: 394 | joints = smpl_seq['joint_pos'] # num_actor x num_frames x num_joints x 3 395 | trans = smpl_seq['trans'] # num_actor x num_frames x 3 396 | 397 | # orient is None for hybrIK since joints already has global orentation 398 | orient = smpl_seq['orient'] 399 | 400 | joints_world = joints 401 | if orient is not None: 402 | joints_world = torch.cat([torch.zeros_like(joints[..., :3]), joints], dim=-1).view(*joints.shape[:-1], -1, 3) 403 | orient_q = angle_axis_to_quaternion(orient).unsqueeze(-2).expand(joints.shape[:-1] + (4,)) 404 | joints_world = quat_apply(orient_q, joints_world) 405 | if trans is not None: 406 | joints_world = joints_world + trans.unsqueeze(-2) 407 | self.smpl_joints = joints_world 408 | 409 | if self.correct_root_height: 410 | diff_root_height = torch.min(self.smpl_joints[:, :, 10:12, 2], dim=2)[0].view(*trans.shape[:2], 1) 411 | self.smpl_joints[:, :, :, 2] -= diff_root_height 412 | if self.smpl_verts is not None: 413 | self.smpl_verts[:, :, :, 2] -= diff_root_height 414 | trans[:, :, 2:] -= diff_root_height 415 | 416 | if racket_seq is not None: 417 | num_actors, num_frames = trans.shape[:2] 418 | for i in range(num_actors): 419 | for j in range(num_frames): 420 | if racket_seq[i][j] is not None: 421 | racket_seq[i][j]['root'] = trans[i, j].numpy() 422 | self.racket_params = racket_seq 423 | 424 | self.ball_params = ball_seq 425 | 426 | self.fr = 0 427 | self.num_fr = self.smpl_joints.shape[1] 428 | 429 | def init_camera(self, init_args): 430 | self.camera = init_args.get('camera', self.camera) 431 | self.sport = init_args.get('sport', self.sport) 432 | 433 | if self.sport == 'tennis': 434 | if self.camera == 'front': 435 | self.pl.camera.up = (0, 0, 1) 436 | self.pl.camera.focal_point = [0, 0, 0] if self.enable_shadow else [0, -0.66, -1.78] 437 | self.pl.camera.position = [0, -30, 5] if self.enable_shadow 
else [0, -25.5, 8.4] 438 | elif self.camera == 'front_right': 439 | self.pl.camera.up = (0, 0, 1) 440 | self.pl.camera.focal_point = [2, 0, 0] if self.enable_shadow else [0, -0.66, -1.78] 441 | self.pl.camera.position = [2, -30, 5] if self.enable_shadow else [0, -25.5, 8.4] 442 | elif self.camera == 'front_left': 443 | self.pl.camera.up = (0, 0, 1) 444 | self.pl.camera.focal_point = [-2, 0, 0] if self.enable_shadow else [0, -0.66, -1.78] 445 | self.pl.camera.position = [-2, -30, 5] if self.enable_shadow else [0, -25.5, 8.4] 446 | elif self.camera == 'back': 447 | self.pl.camera.up = (0, 0, 1) 448 | self.pl.camera.focal_point = [0, 0, 0] 449 | self.pl.camera.position = [0, 25, 5] 450 | elif self.camera == 'top_both': 451 | self.pl.camera.up = (-1, 0, 0) 452 | self.pl.camera.focal_point = [0, 0, 0] 453 | self.pl.camera.position = [0, 0, 35] 454 | elif self.camera == 'top_near': 455 | self.pl.camera.up = (-1, 0, 0) 456 | self.pl.camera.focal_point = [0, -12, 0] 457 | self.pl.camera.position = [0, -12, 20] 458 | elif self.camera == 'top_far': 459 | self.pl.camera.up = (-1, 0, 0) 460 | self.pl.camera.focal_point = [0, 12, 0] 461 | self.pl.camera.position = [0, 12, 20] 462 | elif self.camera == 'side_both': 463 | self.pl.camera.up = (0, 0, 1) 464 | self.pl.camera.focal_point = [0, 0, 0] 465 | self.pl.camera.position = [35, 0, 3] 466 | elif self.camera == 'near_left': 467 | self.pl.camera.elevation = 0 468 | self.pl.camera.up = (0, 0, 1) 469 | self.pl.camera.focal_point = [0, -13, 0] 470 | self.pl.camera.position = [-12, -13, 3] 471 | elif self.camera == 'near_right': 472 | self.pl.camera.elevation = 0 473 | self.pl.camera.up = (0, 0, 1) 474 | self.pl.camera.focal_point = [0, -13, 0] 475 | self.pl.camera.position = [12, -13, 3] 476 | elif self.sport == 'badminton': 477 | if self.camera == 'front': 478 | self.pl.camera.up = (0, 0, 1) 479 | self.pl.camera.focal_point = [0, 0, 0] 480 | self.pl.camera.position = [0, -13, 3] 481 | elif self.camera == 'side_both': 482 | 
self.pl.camera.up = (0, 0, 1) 483 | self.pl.camera.focal_point = [0, 0, 0] 484 | self.pl.camera.position = [15, 0, 3] 485 | elif self.camera == 'side_near': 486 | self.pl.camera.elevation = 0 487 | self.pl.camera.up = (0, 0, 1) 488 | self.pl.camera.focal_point = [0, -3.5, 0] 489 | self.pl.camera.position = [15, -3.5, 0] 490 | elif self.camera == 'side_far': 491 | self.pl.camera.elevation = 0 492 | self.pl.camera.up = (0, 0, 1) 493 | self.pl.camera.focal_point = [0, 3.5, 0] 494 | self.pl.camera.position = [15, 3.5, 0] 495 | 496 | def init_scene(self, init_args): 497 | if init_args is None: 498 | init_args = dict() 499 | super().init_scene(init_args) 500 | 501 | # Init tennis court 502 | if init_args.get('sport') == 'tennis' and not init_args.get('no_court', False): 503 | # Court 504 | wlh = (10.97, 11.89*2, 0.05) 505 | center = np.array([0, 0, -wlh[2] * 0.5]) 506 | court_mesh = pyvista.Cube(center, *wlh) 507 | self.pl.add_mesh(court_mesh, color='#4A609D', ambient=0.2, diffuse=0.8, specular=0.2, specular_power=5, smooth_shading=True) 508 | 509 | # Court lines (vertical) 510 | for x, l in zip([-10.97/2, -8.23/2, 0, 8.23/2, 10.97/2], [23.77, 23.77, 12.8, 23.77, 23.77]): 511 | wlh = (0.05, l, 0.05) 512 | center = np.array([x, 0, -wlh[2] * 0.5]) 513 | court_line_mesh = pyvista.Cube(center, *wlh) 514 | court_line_mesh.points[:, 2] += 0.01 515 | self.pl.add_mesh(court_line_mesh, color='#FFFFFF', smooth_shading=True) 516 | 517 | # Court lines (horizontal) 518 | for y, w in zip([-11.89, -6.4, 0, 6.4, 11.89], [10.97, 8.23, 10.97, 8.23, 10.97]): 519 | wlh = (w, 0.05, 0.05) 520 | center = np.array([0, y, -wlh[2] * 0.5]) 521 | court_line_mesh = pyvista.Cube(center, *wlh) 522 | court_line_mesh.points[:, 2] += 0.01 523 | self.pl.add_mesh(court_line_mesh, color='#FFFFFF', smooth_shading=True) 524 | 525 | # Post 526 | for x in [-0.91-10.97/2, 0.91+10.97/2]: 527 | wlh = (0.05, 0.05, 1.2) 528 | center = np.array([x, 0, wlh[2] * 0.5]) 529 | post_mesh = pyvista.Cube(center, *wlh) 530 | 
self.pl.add_mesh(post_mesh, color='#BD7427', ambient=0.2, diffuse=0.8, specular=0.8, specular_power=5, smooth_shading=True) 531 | 532 | # Net 533 | wlh = (10.97+0.91*2, 0.01, 1.07) 534 | center = np.array([0, 0, 1.07/2]) 535 | net_mesh = pyvista.Cube(center, *wlh) 536 | if not self.enable_shadow: 537 | net_mesh.active_t_coords *= 1000 538 | tex = pyvista.numpy_to_texture(make_checker_board_texture('#FFFFFF', '#AAAAAA', width=10)) 539 | self.pl.add_mesh(net_mesh, texture=tex, ambient=0.2, diffuse=0.8, opacity=0.1, smooth_shading=True) 540 | 541 | # Lighting 542 | if self.enable_shadow: 543 | for x, y in [(-5, -5), (5, -5), (-5, 5), (5, 5)]: 544 | light = pyvista.Light( 545 | position=(x, y, 10), 546 | focal_point=(0, 0, 0), 547 | color=[1.0, 1.0, 1.0, 1.0], # Color temp. 5400 K 548 | intensity=0.4, 549 | ) 550 | self.pl.add_light(light) 551 | 552 | elif init_args.get('sport') == 'badminton' and not init_args.get('no_court', False): 553 | # Court 554 | wlh = (6.1, 13.41, 0.05) 555 | center = np.array([0, 0, -wlh[2] * 0.5]) 556 | court_mesh = pyvista.Cube(center, *wlh) 557 | self.pl.add_mesh(court_mesh, color='#4A609D', ambient=0.2, diffuse=0.8, specular=0.2, specular_power=5, smooth_shading=True) 558 | 559 | # Court lines (vertical) 560 | for x in [-3.05, -2.6, 2.6, 3.05]: 561 | wlh = (0.05, 13.41, 0.05) 562 | center = np.array([x, 0, -wlh[2] * 0.5]) 563 | court_line_mesh = pyvista.Cube(center, *wlh) 564 | court_line_mesh.points[:, 2] += 0.01 565 | self.pl.add_mesh(court_line_mesh, color='#FFFFFF', smooth_shading=True) 566 | for x, y, l in zip([0, 0], [-(1.98+6.71)/2, (1.98+6.71)/2], [3.96+0.76, 3.96+0.76]): 567 | wlh = (0.05, l, 0.05) 568 | center = np.array([x, y, -wlh[2] * 0.5]) 569 | court_line_mesh = pyvista.Cube(center, *wlh) 570 | court_line_mesh.points[:, 2] += 0.01 571 | self.pl.add_mesh(court_line_mesh, color='#FFFFFF', smooth_shading=True) 572 | 573 | # Court lines (horizontal) 574 | for y in [-6.71, -3.96-1.98, -1.98, 1.98, 1.98+3.96, 6.71]: 575 | wlh = 
(6.1, 0.05, 0.05) 576 | center = np.array([0, y, -wlh[2] * 0.5]) 577 | court_line_mesh = pyvista.Cube(center, *wlh) 578 | court_line_mesh.points[:, 2] += 0.01 579 | self.pl.add_mesh(court_line_mesh, color='#FFFFFF', smooth_shading=True) 580 | 581 | # Post 582 | for x in [-3.05, 3.05]: 583 | wlh = (0.05, 0.05, 1.55) 584 | center = np.array([x, 0, wlh[2] * 0.5]) 585 | post_mesh = pyvista.Cube(center, *wlh) 586 | self.pl.add_mesh(post_mesh, color='#BD7427', ambient=0.2, diffuse=0.8, specular=0.8, specular_power=5, smooth_shading=True) 587 | 588 | # Net 589 | wlh = (6.1, 0.01, 0.79) 590 | center = np.array([0, 0, (0.76+1.55)/2]) 591 | net_mesh = pyvista.Cube(center, *wlh) 592 | net_mesh.active_t_coords *= 1000 593 | tex = pyvista.numpy_to_texture(make_checker_board_texture('#FFFFFF', '#AAAAAA', width=10)) 594 | self.pl.add_mesh(net_mesh, texture=tex, ambient=0.2, diffuse=0.8, opacity=0.1, smooth_shading=True) 595 | 596 | if not init_args.get('no_court', False): 597 | # floor 598 | wlh = (100, 100, 0.05) 599 | center = np.array([0, 0, -wlh[2] * 0.5]) 600 | floor_mesh = pyvista.Cube(center, *wlh) 601 | floor_mesh.points[:, 2] -= 0.01 602 | self.pl.add_mesh(floor_mesh, color='#769771', ambient=0.2, diffuse=0.8, specular=0.2, specular_power=5, smooth_shading=True) 603 | else: 604 | wlh = (20.0, 40.0, 0.05) 605 | center = np.array([0, 0, -wlh[2] * 0.5]) 606 | self.floor_mesh = pyvista.Cube(center, *wlh) 607 | self.floor_mesh.t_coords *= 10 / self.floor_mesh.t_coords.max() 608 | tex = pyvista.numpy_to_texture(make_checker_board_texture('#81C6EB', '#D4F1F7')) 609 | self.pl.add_mesh(self.floor_mesh, texture=tex, ambient=0.2, diffuse=0.8, specular=0.8, specular_power=5, smooth_shading=True) 610 | 611 | smpl_seq, racket_seq, ball_seq = init_args.get('smpl_seq'), init_args.get('racket_seq'), init_args.get('ball_seq') 612 | if smpl_seq is not None: 613 | self.setup_animation(smpl_seq, racket_seq, ball_seq) 614 | elif ball_seq is not None: 615 | self.ball_params = ball_seq 616 | 
self.fr = 0 617 | self.num_fr = len(ball_seq[0]) 618 | self.num_actors = init_args['num_actors'] 619 | 620 | if self.show_smpl: 621 | if init_args.get('vis_mvae') and init_args.get('vis_pd_target'): 622 | colors_smpl = ['#ffca3a'] * (self.num_actors // 3) + ['#9d0208'] * (self.num_actors // 3) + ['green'] * (self.num_actors // 3) 623 | elif init_args.get('vis_mvae') or init_args.get('vis_pd_target'): 624 | colors_smpl = ['#ffca3a'] * (self.num_actors // 2) + ['#9d0208'] * (self.num_actors // 2) 625 | elif self.num_actors <= 2: 626 | colors_smpl = ['#ffca3a'] * self.num_actors 627 | else: 628 | colors_smpl = get_color_palette(self.num_actors, 'Wistia') 629 | # colors_smpl = ['#ffca3a'] * self.num_actors 630 | # HACK: get vertices from fake smpl joint 631 | smpl_motion = self.smpl( 632 | global_orient=torch.zeros((1, 3)).float(), 633 | body_pose=torch.zeros((1, 69)).float(), 634 | betas=torch.zeros((1, 10)).float(), 635 | root_trans = torch.zeros((1, 3)).float(), 636 | return_full_pose=True, 637 | orig_joints=True 638 | ) 639 | vertices = smpl_motion.vertices.reshape(-1, 3).numpy() 640 | if init_args.get('debug_root'): 641 | # Odd actors are final result, even actors are old result 642 | self.smpl_actors = [SMPLActor(self.pl, vertices, self.smpl_faces, color='#d00000' if i%2==0 else '#ffca3a') 643 | for i in range(self.num_actors)] 644 | else: 645 | self.smpl_actors = [SMPLActor(self.pl, vertices, self.smpl_faces, color=colors_smpl[a]) 646 | for a in range(self.num_actors)] 647 | 648 | if self.show_skeleton: 649 | if not self.show_smpl: 650 | colors_skeleton = get_color_palette(self.num_actors, colormap='autumn') 651 | else: 652 | colors_skeleton = ['yellow'] * self.num_actors 653 | self.skeleton_actors = [SkeletonActor(self.pl, self.smpl_joint_parents, bone_color=colors_skeleton[a]) 654 | for a in range(self.num_actors)] 655 | 656 | if self.show_racket: 657 | self.racket_actors = [RacketActor(self.pl, init_args.get('sport'), debug=False) 658 | for _ in 
range(self.num_actors)] 659 | 660 | if self.show_target: 661 | self.tar_reaction_actors = [TargetReactionActor(self.pl) for _ in range(self.num_actors)] 662 | self.tar_recover_actors = [TargetRecoveryActor(self.pl) for _ in range(self.num_actors)] 663 | 664 | if self.show_ball: 665 | if init_args.get('debug_ball'): 666 | if init_args.get('debug_ball') == 'comparison': 667 | self.ball_actors = [BallActor(self.pl, init_args.get('sport'), color='yellow' if a < self.num_actors//2 else 'red', 668 | real_shadow=self.enable_shadow) for a in range(self.num_actors)] 669 | else: 670 | colors = get_color_palette(self.num_actors, 'rainbow') 671 | self.ball_actors = [BallActor(self.pl, init_args.get('sport'), color=colors[a], 672 | real_shadow=self.enable_shadow) for a in range(self.num_actors)] 673 | elif init_args.get('add_second_ball'): 674 | self.ball_actors = [BallActor(self.pl, init_args.get('sport'), blur=init_args.get('ball_blur'), 675 | real_shadow=self.enable_shadow) for _ in range(self.num_actors * 2)] 676 | elif self.num_actors <= 2 or init_args.get('vis_mvae') or init_args.get('vis_pd_target'): 677 | self.ball_actors = [BallActor(self.pl, init_args.get('sport'), blur=init_args.get('ball_blur'), 678 | real_shadow=self.enable_shadow) for _ in range(self.num_actors)] 679 | else: 680 | self.ball_actors = [BallActor(self.pl, init_args.get('sport'), blur=init_args.get('ball_blur'), 681 | color=colors_smpl[a], real_shadow=self.enable_shadow) 682 | for a in range(self.num_actors)] 683 | 684 | if self.show_ball_target: 685 | self.ball_tar_actors = [TargetBounceActor(self.pl) for _ in range(self.num_actors)] 686 | 687 | if self.show_stats: 688 | self.text_actor_tar = self.pl.add_text('', position=(30, 1050), color='black', font_size=12) 689 | self.text_actor_reward = self.pl.add_text('', position=(30, 1020), color='black', font_size=12) 690 | self.text_actor_residual = self.pl.add_text('', position=(30, 990), color='black', font_size=12) 691 | self.text_actor_pose = 
self.pl.add_text('', position=(30, 960), color='black', font_size=12) 692 | self.text_actor_pose_tar = self.pl.add_text('', position=(30, 930), color='black', font_size=12) 693 | self.text_actor_racket = self.pl.add_text('', position=(30, 900), color='black', font_size=12) 694 | self.text_actor_ball = self.pl.add_text('', position=(30, 870), color='black', font_size=12) 695 | self.text_actor_contact = self.pl.add_text('', position=(30, 840), color='black', font_size=12) 696 | 697 | def update_camera(self, interactive): 698 | if self.track_first_actor: 699 | root_pos = self.smpl_joints[0, self.fr, 0].cpu().numpy() 700 | if self.camera == 'front': 701 | new_pos = [root_pos[0], -30, 5] if self.enable_shadow else [root_pos[0], -25, 5] 702 | self.pl.camera.up = (0, 0, 1) 703 | self.pl.camera.focal_point = [0, 0, 0] 704 | self.pl.camera.position = new_pos 705 | 706 | elif self.camera == 'near_left': 707 | self.pl.camera.up = (0, 0, 1) 708 | self.pl.camera.focal_point = [root_pos[0], root_pos[1], 1] 709 | self.pl.camera.position = [root_pos[0] - 5, root_pos[1], 1] 710 | 711 | elif self.camera == 'near_right': 712 | self.pl.camera.up = (0, 0, 1) 713 | self.pl.camera.focal_point = [root_pos[0], root_pos[1], 1] 714 | self.pl.camera.position = [root_pos[0] + 5, root_pos[1], 1] 715 | 716 | if self.track_ball: 717 | self.pl.camera.up = (0, 0, 1) 718 | self.pl.camera.focal_point = self.ball_params[0][self.fr].cpu().numpy() 719 | self.pl.camera.position = self.pl.camera.focal_point + np.array([3, 0, 0]) 720 | 721 | def update_scene(self): 722 | super().update_scene() 723 | 724 | if self.show_smpl and self.smpl_verts is not None: 725 | for i, actor in enumerate(self.smpl_actors): 726 | if self.smpl_joints[i, self.fr].sum() == 0: 727 | actor.set_visibility(False) 728 | else: 729 | actor.update_verts(self.smpl_verts[i, self.fr].cpu().numpy()) 730 | actor.set_visibility(True) 731 | if self.enable_shadow: 732 | actor.set_opacity(1.0) 733 | elif self.show_skeleton: 734 | 
actor.set_opacity(0.8) 735 | else: 736 | actor.set_opacity(1.0) 737 | 738 | if self.show_skeleton and self.smpl_joints is not None: 739 | for i, actor in enumerate(self.skeleton_actors): 740 | if self.smpl_joints[i, self.fr].sum() == 0: 741 | actor.set_visibility(False) 742 | else: 743 | actor.update_joints(self.smpl_joints[i, self.fr].cpu().numpy()) 744 | if self.enable_shadow: 745 | actor.set_visibility(False) 746 | else: 747 | actor.set_visibility(True) 748 | actor.set_opacity(0.2) 749 | 750 | if self.show_racket and self.racket_params is not None: 751 | for i, actor in enumerate(self.racket_actors): 752 | if self.smpl_joints[i, self.fr].sum() == 0: 753 | actor.set_visibility(False) 754 | else: 755 | actor.update_racket(self.racket_params[i][self.fr]) 756 | actor.set_visibility(True) 757 | 758 | if self.show_ball and self.ball_params is not None: 759 | for i, actor in enumerate(self.ball_actors): 760 | if i >= len(self.ball_params): actor.set_visibility(False) 761 | else: 762 | actor.update_ball(self.ball_params[i][self.fr]) 763 | 764 | def setup_key_callback(self): 765 | super().setup_key_callback() 766 | 767 | def track_first_actor(): 768 | self.track_first_actor = not self.track_first_actor 769 | 770 | def track_ball(): 771 | self.track_ball = not self.track_ball 772 | if not self.track_ball: 773 | self.init_camera({'camera': 'front'}) 774 | 775 | def forward(): 776 | self.forward = True 777 | 778 | def reset_camera_front(): 779 | self.init_camera({'camera': 'front'}) 780 | 781 | def reset_camera_back(): 782 | self.init_camera({'camera': 'back'}) 783 | 784 | def reset_camera_side_both(): 785 | self.init_camera({'camera': 'side_both'}) 786 | 787 | def reset_camera_top_both(): 788 | self.init_camera({'camera': 'top_both'}) 789 | 790 | def reset_camera_top_near(): 791 | self.init_camera({'camera': 'top_near'}) 792 | 793 | def reset_camera_top_far(): 794 | self.init_camera({'camera': 'top_far'}) 795 | 796 | def reset_camera_near_left(): 797 | 
self.init_camera({'camera': 'near_left'}) 798 | 799 | def reset_camera_near_right(): 800 | self.init_camera({'camera': 'near_right'}) 801 | 802 | self.pl.add_key_event('t', track_first_actor) 803 | self.pl.add_key_event('b', track_ball) 804 | self.pl.add_key_event('n', forward) 805 | self.pl.add_key_event('1', reset_camera_front) 806 | self.pl.add_key_event('2', reset_camera_back) 807 | self.pl.add_key_event('3', reset_camera_side_both) 808 | self.pl.add_key_event('4', reset_camera_near_left) 809 | self.pl.add_key_event('5', reset_camera_near_right) 810 | self.pl.add_key_event('6', reset_camera_top_both) 811 | self.pl.add_key_event('7', reset_camera_top_near) 812 | self.pl.add_key_event('8', reset_camera_top_far) 813 | 814 | def show_animation_online(self, window_size=(800, 800), init_args=None, enable_shadow=None, 815 | show_axes=True, off_screen=False, fps=30): 816 | self.fps = fps 817 | self.frame_mode = 'fps' 818 | if off_screen: 819 | if platform.system() == 'Linux': 820 | pyvista.start_xvfb() 821 | if enable_shadow is not None: 822 | self.enable_shadow = enable_shadow 823 | if self.enable_shadow: 824 | window_size = (1000, 1000) 825 | self.pl = pyvista.Plotter(window_size=window_size, off_screen=off_screen, lighting='none') 826 | else: 827 | self.pl = pyvista.Plotter(window_size=window_size, off_screen=off_screen) 828 | self.init_camera(init_args) 829 | self.init_scene(init_args) 830 | self.setup_key_callback() 831 | if show_axes: 832 | self.pl.show_axes() 833 | self.pl.show(interactive_update=True) 834 | 835 | def update_scene_online(self, joint_pos=None, smpl_verts=None, racket_params=None, ball_params=None, ball_targets=None, 836 | tar_pos=None, tar_action=None, tar_time=None, stats=None): 837 | 838 | if self.show_smpl and smpl_verts is not None: 839 | for i, actor in enumerate(self.smpl_actors): 840 | actor.update_verts(smpl_verts[i].cpu().numpy()) 841 | actor.set_visibility(True) 842 | if self.show_skeleton: 843 | actor.set_opacity(0.8 if not 
self.enable_shadow else 1.0) 844 | else: 845 | actor.set_opacity(1.0) 846 | 847 | if self.show_skeleton and joint_pos is not None: 848 | for i, actor in enumerate(self.skeleton_actors): 849 | actor.update_joints(joint_pos[i].cpu().numpy()) 850 | if not self.enable_shadow: 851 | actor.set_visibility(True) 852 | actor.set_opacity(1.0) 853 | else: 854 | actor.set_visibility(False) 855 | 856 | if self.show_racket and racket_params is not None: 857 | for i, actor in enumerate(self.racket_actors): 858 | actor.update_racket(racket_params[i]) 859 | actor.set_visibility(True) 860 | 861 | if self.show_target and tar_pos is not None: 862 | for i in range(self.num_actors): 863 | rea_actor = self.tar_reaction_actors[i] 864 | rec_actor = self.tar_recover_actors[i] 865 | if tar_action is None: 866 | rec_actor.update_target(tar_pos[i].cpu().numpy()) 867 | rec_actor.set_visibility(True) 868 | rea_actor.set_visibility(False) 869 | else: 870 | rec_actor.set_visibility(False) 871 | if tar_action[i] == 1: 872 | rea_actor.update_target(tar_pos[i].cpu().numpy(), tar_time[i].cpu().numpy()) 873 | rea_actor.set_visibility(True) 874 | 875 | if self.show_ball and ball_params is not None: 876 | for i in range(min(self.num_actors, len(ball_params))): 877 | self.ball_actors[i].update_ball(ball_params[i]) 878 | 879 | if self.show_ball_target and ball_targets is not None: 880 | for i in range(min(self.num_actors, len(ball_targets))): 881 | self.ball_tar_actors[i].update_target(ball_targets[i]) 882 | 883 | if self.show_stats and stats is not None: 884 | self.text_actor_tar.SetInput('Target time, action, phase, swing, recovery, atnet: {:02d}, {}, {:.2f}, {}, {}, {}'.format( 885 | stats['tar_time'].cpu().numpy(), 886 | stats['tar_action'].cpu().numpy(), 887 | stats['phase'].cpu().numpy(), 888 | stats['swing_type'].cpu().numpy(), 889 | stats['target_recovery'].cpu().numpy(), 890 | stats['at_net'].cpu().numpy(), 891 | )) 892 | if stats['sub_reward_names'] is not None: 893 | 
self.text_actor_reward.SetInput('Reward: ({}) - {}'.format( 894 | stats['sub_reward_names'].replace('_reward', ''), 895 | np.array2string(stats['sub_rewards'].cpu().numpy(), formatter={'all': lambda x: '{:>7.4f}'.format(x)}, separator=','), 896 | )) 897 | if stats.get('res_dof_actions') is not None and stats.get('mvae_actions_norm') is not None: 898 | self.text_actor_residual.SetInput('VAE action norm, residual: {} {}'.format( 899 | np.array2string(stats['mvae_actions_norm'].cpu().numpy(), formatter={'all': lambda x: '{:>5.1f}'.format(x)}, separator=','), 900 | np.array2string(stats['res_dof_actions'].cpu().numpy(), formatter={'all': lambda x: '{:>5.1f}'.format(x * 180)}, separator=','), 901 | )) 902 | if stats.get('wrist_angle') is not None: 903 | self.text_actor_pose.SetInput('Physics Wrist, elbow, shoulder: {} {} {}'.format( 904 | # np.array2string(stats['wrist_angle_glb'].cpu().numpy(), formatter={'all': lambda x: '{:04.1f}'.format(x * 180 / pi)}, separator=','), 905 | np.array2string(stats['wrist_angle'].cpu().numpy(), formatter={'all': lambda x: '{:>5.1f}'.format(x * 180 / pi)}, separator=','), 906 | np.array2string(stats['elbow_angle'].cpu().numpy(), formatter={'all': lambda x: '{:>5.1f}'.format(x * 180 / pi)}, separator=','), 907 | np.array2string(stats['shoulder_angle'].cpu().numpy(), formatter={'all': lambda x: '{:>5.1f}'.format(x * 180 / pi)}, separator=','), 908 | )) 909 | if stats.get('wrist_angle_tar') is not None: 910 | self.text_actor_pose_tar.SetInput('Target Wrist, elbow, shoulder: {} {} {}'.format( 911 | # np.array2string(stats['wrist_angle_glb'].cpu().numpy(), formatter={'all': lambda x: '{:04.1f}'.format(x * 180 / pi)}, separator=','), 912 | np.array2string(stats['wrist_angle_tar'].cpu().numpy(), formatter={'all': lambda x: '{:>5.1f}'.format(x * 180 / pi)}, separator=','), 913 | np.array2string(stats['elbow_angle_tar'].cpu().numpy(), formatter={'all': lambda x: '{:>5.1f}'.format(x * 180 / pi)}, separator=','), 914 | 
np.array2string(stats['shoulder_angle_tar'].cpu().numpy(), formatter={'all': lambda x: '{:>5.1f}'.format(x * 180 / pi)}, separator=','), 915 | )) 916 | if stats.get('racket_pos') is not None: 917 | self.text_actor_racket.SetInput('Racket pos, vel, norm: {} {} {}'.format( 918 | np.array2string(stats['racket_pos'].cpu().numpy(), formatter={'all': lambda x: '{:>6.2f}'.format(x)}, separator=','), 919 | np.array2string(stats['racket_vel'].cpu().numpy(), formatter={'all': lambda x: '{:>6.2f}'.format(x)}, separator=','), 920 | np.array2string(stats['racket_normal'].cpu().numpy(), formatter={'all': lambda x: '{:>6.2f}'.format(x)}, separator=','), 921 | )) 922 | if stats.get('ball_pos') is not None: 923 | self.text_actor_ball.SetInput('Ball pos, vel, ang_vel, vspin, bounce, target: {} {} {} {} {} {} {}'.format( 924 | np.array2string(stats['ball_pos'].cpu().numpy(), formatter={'all': lambda x: '{:>6.2f}'.format(x)}, separator=','), 925 | np.array2string(stats['ball_vel'].cpu().numpy(), formatter={'all': lambda x: '{:>6.2f}'.format(x)}, separator=','), 926 | np.array2string(stats['ball_ang_vel'].cpu().numpy(), formatter={'all': lambda x: '{:>6.2f}'.format(x)}, separator=','), 927 | np.array2string(stats['ball_vspin'].cpu().numpy(), formatter={'all': lambda x: '{:>6.2f}'.format(x)}, separator=','), 928 | np.array2string(stats['est_ball_bounce'].cpu().numpy(), formatter={'all': lambda x: '{:>6.2f}'.format(x)}, separator=','), 929 | np.array2string(stats['ball_target_pos'].cpu().numpy(), formatter={'all': lambda x: '{:>6.2f}'.format(x)}, separator=','), 930 | np.array2string(stats['ball_target_spin'].cpu().numpy(), formatter={'all': lambda x: '{}'.format(x)}, separator=','), 931 | )) 932 | if stats.get('contact_force') is not None: 933 | self.text_actor_contact.SetInput('Contact force racket, ball: {} {}'.format( 934 | np.array2string(stats['contact_force'][-2].cpu().numpy(), formatter={'all': lambda x: '{:>5.1f}'.format(x)}, separator=','), 935 | 
np.array2string(stats['contact_force'][-1].cpu().numpy(), formatter={'all': lambda x: '{:>5.1f}'.format(x)}, separator=','), 936 | )) 937 | 938 | self.smpl_joints = joint_pos.unsqueeze(0) 939 | if ball_params[0] is not None: 940 | self.ball_params = [[ball_params[0]['pos']]] 941 | self.fr = 0 942 | 943 | def render_online(self, interactive): 944 | last_render_time = time.time() 945 | if interactive: 946 | while True: 947 | self.render(interactive=True) 948 | if self.forward: 949 | self.paused = True 950 | self.forward = False 951 | break 952 | if self.paused: continue 953 | if time.time() - last_render_time >= (1 / self.fps - 0.002): 954 | break 955 | else: 956 | self.render(interactive=False) --------------------------------------------------------------------------------