├── .gitignore ├── LICENSE ├── README.md ├── assets ├── example_conditional_sparse_T.gif ├── example_text_only.gif └── teaser.png ├── configs ├── __init__.py ├── card.py ├── data.py └── model.py ├── data_loaders ├── a2m │ ├── dataset.py │ ├── humanact12poses.py │ └── uestc.py ├── amass │ ├── data │ │ └── dataset.py │ └── utils │ │ ├── fk.py │ │ ├── helper_functions.py │ │ ├── rotations.py │ │ ├── smpl.yaml │ │ └── utils.py ├── amass_utils.py ├── get_data.py ├── humanml │ ├── README.md │ ├── common │ │ ├── quaternion.py │ │ └── skeleton.py │ ├── data │ │ ├── __init__.py │ │ └── dataset.py │ ├── motion_loaders │ │ ├── __init__.py │ │ ├── comp_v6_model_dataset.py │ │ ├── comp_v6_model_dataset_condmdi.py │ │ ├── dataset_motion_loader.py │ │ └── model_motion_loaders.py │ ├── networks │ │ ├── __init__.py │ │ ├── evaluator_wrapper.py │ │ ├── modules.py │ │ └── trainers.py │ ├── scripts │ │ └── motion_process.py │ └── utils │ │ ├── get_opt.py │ │ ├── metrics.py │ │ ├── paramUtil.py │ │ ├── plot_script.py │ │ ├── plotting.py │ │ ├── utils.py │ │ └── word_vectorizer.py ├── humanml_utils.py └── tensors.py ├── dataset ├── 000021.npy ├── HumanML3D_abs │ ├── Mean_abs_3d.npy │ ├── Std_abs_3d.npy │ ├── cal_mean_variance.ipynb │ └── motion_representation.ipynb ├── README.md ├── humanml_opt.txt ├── inv_rand_proj.npy ├── kit_mean.npy ├── kit_opt.txt ├── kit_std.npy ├── rand_proj.npy ├── t2m_mean.npy └── t2m_std.npy ├── diffusion ├── fp16_util.py ├── gaussian_diffusion.py ├── logger.py ├── losses.py ├── nn.py ├── resample.py └── respace.py ├── eval ├── a2m │ ├── __init__.py │ ├── action2motion │ │ ├── accuracy.py │ │ ├── diversity.py │ │ ├── evaluate.py │ │ ├── fid.py │ │ └── models.py │ ├── gru_eval.py │ ├── recognition │ │ └── models │ │ │ ├── stgcn.py │ │ │ └── stgcnutils │ │ │ ├── graph.py │ │ │ └── tgcn.py │ ├── stgcn │ │ ├── accuracy.py │ │ ├── diversity.py │ │ ├── evaluate.py │ │ └── fid.py │ ├── stgcn_eval.py │ └── tools.py ├── eval_humanact12_uestc.py ├── eval_humanml.py ├── eval_humanml_condition.py ├── eval_humanml_condmdi.py └── unconstrained │ ├── evaluate.py │ ├── metrics │ ├── kid.py │ └── precision_recall.py │ └── models │ ├── stgcn.py │ └── stgcnutils │ └── graph.py ├── model ├── cfg_sampler.py ├── mdm.py ├── mdm_dit.py ├── mdm_unet.py ├── rotation2xyz.py └── smpl.py ├── prepare ├── download_a2m_datasets.sh ├── download_glove.sh ├── download_recognition_models.sh ├── download_recognition_unconstrained_models.sh ├── download_smpl_files.sh ├── download_t2m_evaluators.sh └── download_unconstrained_datasets.sh ├── requirements.txt ├── sample ├── conditional_synthesis.py ├── edit.py ├── gmd │ ├── condition.py │ ├── generate.py │ └── keyframe_pattern.py └── synthesize.py ├── train ├── train_condmdi.py └── training_loop.py ├── utils ├── PYTORCH3D_LICENSE ├── config.py ├── dist_util.py ├── editing_util.py ├── fixseed.py ├── generation_template.py ├── hfargparse.py ├── misc.py ├── model_util.py ├── output_util.py ├── parser_util.py └── rotation_conversions.py └── visualize ├── joints2smpl ├── README.md ├── environment.yaml ├── fit_seq.py ├── smpl_models │ ├── SMPL_downsample_index.pkl │ ├── gmm_08.pkl │ ├── neutral_smpl_mean_params.h5 │ └── smplx_parts_segm.pkl └── src │ ├── config.py │ ├── customloss.py │ ├── prior.py │ └── smplify.py ├── render_mesh.py ├── simplify_loc2rot.py └── vis_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | 
*.so 8 | 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | ENV 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | # experiments 134 | save/ 135 | dataset/ 136 | kit/ 137 | t2m/ 138 | glove/ 139 | body_models/ 140 | .vscode/ 141 | wandb/ 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 4 | 2024 Setareh Cohan 5 | 2023 Korrawe Karunratanakul 6 | 2022 Guy Tevet 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 
17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 25 | -------------------------------------------------------------------------------- /assets/example_conditional_sparse_T.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/assets/example_conditional_sparse_T.gif -------------------------------------------------------------------------------- /assets/example_text_only.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/assets/example_text_only.gif -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/assets/teaser.png -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/configs/__init__.py -------------------------------------------------------------------------------- /configs/card.py: -------------------------------------------------------------------------------- 1 | from utils.parser_util import * 2 | from dataclasses import dataclass 3 | from configs import data, model 4 | 5 | ########################### 6 | # MOTION MODELS 7 | ########################### 8 | 9 | 10 | @dataclass 11 | class motion_rel_mdm( 12 | data.humanml_motion_rel, 13 | model.motion_mdm, 14 | ): 15 | save_dir: str = 'save/my_humanml_trans_enc_512_test' 16 | 17 | 18 | @dataclass 19 | class motion_abs_mdm( 20 | data.humanml_motion_abs, 21 | model.motion_mdm, 22 | ): 23 | save_dir: str = 'save/my_abs3d_2' 24 | 25 | 26 | @dataclass 27 | class motion_abs_mdm_proj1( 28 | data.humanml_motion_proj1, 29 | model.motion_mdm, 30 | ): 31 | save_dir: str = 'save/my_abs3d_proj_1' 32 | 33 | 34 | @dataclass 35 | class motion_abs_mdm_proj2( 36 | data.humanml_motion_proj2, 37 | model.motion_mdm, 38 | ): 39 | save_dir: str = 'save/my_abs3d_proj_2' 40 | 41 | 42 | @dataclass 43 | class motion_abs_mdm_proj5( 44 | data.humanml_motion_proj5, 45 | model.motion_mdm, 46 | ): 47 | save_dir: str = 'save/my_abs3d_proj_5' 48 | 49 | 50 | @dataclass 51 | class motion_abs_mdm_proj10( 52 | data.humanml_motion_proj10, 53 | model.motion_mdm, 54 | ): 55 | save_dir: str = 'save/my_abs3d_proj_10_2' 56 | 57 | 58 | @dataclass 59 | class motion_rel_unet_adagn_xl( 60 | data.humanml_motion_rel, 61 | model.motion_unet_adagn_xl, 62 | ): 63 | save_dir: str = 'save/unet_adazero_xl_x0_rel_loss1_fp16_clipwd_224' 64 | 65 | 66 | ########################### 67 | # UNET XL 68 | ########################### 69 | 70 | 71 | @dataclass 72 | class 
motion_abs_unet_adagn_xl( 73 | data.humanml_motion_abs, 74 | model.motion_unet_adagn_xl, 75 | ): 76 | save_dir: str = 'save/unet_adazero_xl_x0_abs_loss1_fp16_clipwd_224' 77 | 78 | 79 | @dataclass 80 | class motion_abs_unet_adagn_xl_loss2( 81 | data.humanml_motion_abs, 82 | model.motion_unet_adagn_xl_loss2, 83 | ): 84 | save_dir: str = 'save/unet_adazero_xl_x0_abs_loss2_fp16_clipwd_224' 85 | 86 | 87 | @dataclass 88 | class motion_abs_unet_adagn_xl_loss5( 89 | data.humanml_motion_abs, 90 | model.motion_unet_adagn_xl_loss5, 91 | ): 92 | save_dir: str = 'save/unet_adazero_xl_x0_abs_loss5_fp16_clipwd_224' 93 | 94 | 95 | @dataclass 96 | class motion_abs_unet_adagn_xl_loss10( 97 | data.humanml_motion_abs, 98 | model.motion_unet_adagn_xl_loss10, 99 | ): 100 | save_dir: str = 'save/unet_adazero_xl_x0_abs_loss10_fp16_clipwd_224' 101 | 102 | 103 | ########################### 104 | # UNET XL + PROJ 105 | ########################### 106 | 107 | @dataclass 108 | class motion_abs_proj1_unet_adagn_xl( 109 | data.humanml_motion_proj1, 110 | model.motion_unet_adagn_xl, 111 | ): 112 | save_dir: str = 'save/unet_adazero_xl_x0_abs_proj1_fp16_clipwd_224' 113 | 114 | 115 | @dataclass 116 | class motion_abs_proj2_unet_adagn_xl( 117 | data.humanml_motion_proj2, 118 | model.motion_unet_adagn_xl, 119 | ): 120 | save_dir: str = 'save/unet_adazero_xl_x0_abs_proj2_fp16_clipwd_224' 121 | 122 | 123 | @dataclass 124 | class motion_abs_proj5_unet_adagn_xl( 125 | data.humanml_motion_proj5, 126 | model.motion_unet_adagn_xl, 127 | ): 128 | save_dir: str = 'save/unet_adazero_xl_x0_abs_proj5_fp16_clipwd_224' 129 | 130 | 131 | @dataclass 132 | class motion_abs_proj10_unet_adagn_xl( 133 | data.humanml_motion_proj10, 134 | model.motion_unet_adagn_xl, 135 | ): 136 | save_dir: str = 'save/unet_adazero_xl_x0_abs_proj10_fp16_clipwd_224' 137 | 138 | ########################### 139 | # TRAJ MODELS 140 | ########################### 141 | @dataclass 142 | class traj_unet_adagn_swx( 143 | data.humanml_traj, 144 | model.traj_unet_adagn_swx, 145 | ): 146 | save_dir: str = 'save/traj_unet_adazero_swxs_eps_abs_fp16_clipwd_224' 147 | 148 | 149 | @dataclass 150 | class traj_unet_xxs( 151 | data.humanml_traj, 152 | model.traj_unet_xxs, 153 | ): 154 | save_dir: str = 'save/traj_unet_xxs_eps_abs_fp16_clipwd_224' -------------------------------------------------------------------------------- /configs/data.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from utils.parser_util import BaseOptions, DataOptions 3 | 4 | 5 | @dataclass 6 | class humanml_motion_rel(BaseOptions, DataOptions): 7 | dataset: str = 'humanml' 8 | data_dir: str = '' 9 | abs_3d: bool = False 10 | 11 | 12 | @dataclass 13 | class humanml_motion_abs(BaseOptions, DataOptions): 14 | dataset: str = 'humanml' 15 | data_dir: str = '' 16 | abs_3d: bool = True 17 | 18 | 19 | @dataclass 20 | class humanml_motion_proj1(humanml_motion_abs): 21 | use_random_proj: bool = True 22 | random_proj_scale: float = 1 23 | 24 | 25 | @dataclass 26 | class humanml_motion_proj2(humanml_motion_abs): 27 | use_random_proj: bool = True 28 | random_proj_scale: float = 2 29 | 30 | 31 | @dataclass 32 | class humanml_motion_proj5(humanml_motion_abs): 33 | use_random_proj: bool = True 34 | random_proj_scale: float = 5 35 | 36 | 37 | @dataclass 38 | class humanml_motion_proj10(humanml_motion_abs): 39 | use_random_proj: bool = True 40 | random_proj_scale: float = 10 41 | 42 | 43 | @dataclass 44 | class humanml_traj(humanml_motion_abs): 45 | 
traj_only: bool = True -------------------------------------------------------------------------------- /configs/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Tuple 3 | from utils.parser_util import DataOptions, ModelOptions, DiffusionOptions, TrainingOptions, EvaluationOptions 4 | 5 | 6 | @dataclass 7 | class _motion(ModelOptions, DataOptions, DiffusionOptions, TrainingOptions, 8 | EvaluationOptions): 9 | num_frames: int = 196 10 | predict_xstart: bool = True 11 | grad_clip: float = 1. 12 | avg_model_beta: float = 0.9999 13 | 14 | 15 | @dataclass 16 | class _traj(ModelOptions, DataOptions, DiffusionOptions, TrainingOptions, 17 | EvaluationOptions): 18 | num_frames: int = 196 19 | predict_xstart: bool = False 20 | grad_clip: float = 1. 21 | avg_model_beta: float = 0.9999 22 | batch_size: int = 64 23 | save_interval: int = 12_500 24 | num_steps: int = 100_000 25 | 26 | 27 | @dataclass 28 | class _motion_unet(_motion): 29 | # all UNETs use 224 as the training max length 30 | num_frames: int = 224 31 | weight_decay: float = 0.01 32 | use_fp16: bool = True 33 | arch: str = 'unet' 34 | latent_dim: int = 512 35 | unet_adagn: bool = True 36 | unet_zero: bool = True 37 | 38 | 39 | @dataclass 40 | class _traj_unet(_traj): 41 | # all UNETs use 224 as the training max length 42 | num_frames: int = 224 43 | weight_decay: float = 0.01 44 | use_fp16: bool = True 45 | arch: str = 'unet' 46 | latent_dim: int = 512 47 | unet_adagn: bool = True 48 | unet_zero: bool = True 49 | 50 | 51 | @dataclass 52 | class motion_mdm(_motion): 53 | arch: str = 'trans_enc' 54 | latent_dim: int = 512 55 | ff_size: int = 1024 56 | weight_decay: float = 0 57 | eval_use_avg: bool = False # MDM doesn't use avg model during inference 58 | 59 | 60 | @dataclass 61 | class traj_mdm(_traj): 62 | pass 63 | 64 | 65 | @dataclass 66 | class motion_unet_adagn_xl(_motion_unet): 67 | dim_mults: Tuple[float] = (2, 2, 2, 2) 68 | 69 | 70 | @dataclass 71 | class motion_unet_adagn_xl_loss2(_motion_unet): 72 | dim_mults: Tuple[float] = (2, 2, 2, 2) 73 | traj_extra_weight: float = 2 74 | 75 | 76 | @dataclass 77 | class motion_unet_adagn_xl_loss5(_motion_unet): 78 | dim_mults: Tuple[float] = (2, 2, 2, 2) 79 | traj_extra_weight: float = 5 80 | 81 | 82 | @dataclass 83 | class motion_unet_adagn_xl_loss10(_motion_unet): 84 | dim_mults: Tuple[float] = (2, 2, 2, 2) 85 | traj_extra_weight: float = 10 86 | 87 | 88 | @dataclass 89 | class traj_unet_adagn_swx(_traj_unet): 90 | dim_mults: Tuple[float] = (0.125, 0.25, 0.5) 91 | 92 | 93 | @dataclass 94 | class traj_unet_xxs(_traj_unet): 95 | dim_mults: Tuple[float] = (0.0625, 0.125, 0.25, 0.5) 96 | unet_adagn: bool = False 97 | unet_zero: bool = False 98 | -------------------------------------------------------------------------------- /data_loaders/a2m/humanact12poses.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import numpy as np 3 | import os 4 | from .dataset import Dataset 5 | 6 | 7 | class HumanAct12Poses(Dataset): 8 | dataname = "humanact12" 9 | 10 | def __init__(self, datapath="dataset/HumanAct12Poses", split="train", **kargs): 11 | self.datapath = datapath 12 | 13 | super().__init__(**kargs) 14 | 15 | pkldatafilepath = os.path.join(datapath, "humanact12poses.pkl") 16 | data = pkl.load(open(pkldatafilepath, "rb")) 17 | 18 | self._pose = [x for x in data["poses"]] 19 | self._num_frames_in_video = [p.shape[0] for p in self._pose] 20 
| self._joints = [x for x in data["joints3D"]] 21 | 22 | self._actions = [x for x in data["y"]] 23 | 24 | total_num_actions = 12 25 | self.num_actions = total_num_actions 26 | 27 | self._train = list(range(len(self._pose))) 28 | 29 | keep_actions = np.arange(0, total_num_actions) 30 | 31 | self._action_to_label = {x: i for i, x in enumerate(keep_actions)} 32 | self._label_to_action = {i: x for i, x in enumerate(keep_actions)} 33 | 34 | self._action_classes = humanact12_coarse_action_enumerator 35 | 36 | def _load_joints3D(self, ind, frame_ix): 37 | return self._joints[ind][frame_ix] 38 | 39 | def _load_rotvec(self, ind, frame_ix): 40 | pose = self._pose[ind][frame_ix].reshape(-1, 24, 3) 41 | return pose 42 | 43 | 44 | humanact12_coarse_action_enumerator = { 45 | 0: "warm_up", 46 | 1: "walk", 47 | 2: "run", 48 | 3: "jump", 49 | 4: "drink", 50 | 5: "lift_dumbbell", 51 | 6: "sit", 52 | 7: "eat", 53 | 8: "turn steering wheel", 54 | 9: "phone", 55 | 10: "boxing", 56 | 11: "throw", 57 | } 58 | -------------------------------------------------------------------------------- /data_loaders/amass/data/dataset.py: -------------------------------------------------------------------------------- 1 | """ Code adapted from https://github.com/c-he/NeMF""" 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2021 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 5 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 6 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 7 | # 8 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 9 | # on this computer program. You can only use this computer program if you have closed a license agreement 10 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 11 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 12 | # Contact: ps-license@tuebingen.mpg.de 13 | # 14 | # 15 | # If you use this code in a research publication please consider citing the following: 16 | # 17 | # AMASS: Archive of Motion Capture as Surface Shapes 18 | # 19 | # 20 | # Code Developed by: 21 | # Nima Ghorbani 22 | # 23 | # 2019.08.09 24 | 25 | # len(train_dataset) = 11642, len(test_dataset) = 164, len(valid_dataset) = 1668 26 | # train_dataset.ds is a dictionary with keys = (['trans', 'rotmat', 'pos', 'angular', 'contacts', 'height', 'root_vel', 'velocity', 'global_xform', 'root_orient', 'rot6d']) 27 | # train_dataset.ds['pos'] is a tensor of shape [11642, 128, 24, 3] 28 | 29 | import os 30 | import sys 31 | 32 | file_path = os.path.dirname(os.path.realpath(__file__)) 33 | sys.path.append(os.path.join(file_path, '..')) 34 | 35 | import glob 36 | import torch 37 | from torch.utils.data import Dataset 38 | 39 | class AMASS(Dataset): 40 | """AMASSN: a pytorch loader for unified human motion capture dataset. 
http://amass.is.tue.mpg.de/""" 41 | """Adopted from NeMF codebase: https://github.com/c-he/NeMF """ 42 | 43 | def __init__(self, split='train'): 44 | self.root_dir = 'dataset/amass/generative' 45 | self.dataset_dir = os.path.join(self.root_dir, split) 46 | self.ds = {} 47 | for data_fname in glob.glob(os.path.join(self.dataset_dir, '*.pt')): 48 | k = os.path.basename(data_fname).split('-')[0] 49 | self.ds[k] = torch.load(data_fname) 50 | self.clip_length = 128 51 | self.mean = torch.load(os.path.join(self.root_dir, 'mean-male-128-30fps.pt')) # [1, 1, clip_length, dim] 52 | self.std = torch.load(os.path.join(self.root_dir, 'std-male-128-30fps.pt')) # [1, 1, clip_length, dim] 53 | 54 | def __len__(self): 55 | return len(self.ds['trans']) 56 | 57 | def _normalize_field(self, value, key): 58 | return (value - self.mean[key][0]) / self.std[key][0] 59 | 60 | def _denormalize_field(self, value, key): 61 | return value * self.std[key][0] + self.mean[key][0] 62 | 63 | def normalize(self, data): 64 | data = data.copy() 65 | for key in data.keys(): 66 | data[key] = self._normalize_field(data[key], key) 67 | return data 68 | 69 | def denormalize(self, data): 70 | data = data.copy() 71 | for key in data.keys(): 72 | data[key] = self._denormalize_field(data[key], key) 73 | return data 74 | 75 | def __getitem__(self, idx): 76 | data = [] 77 | for key in self.ds.keys(): 78 | # value is of shape [datset_size, clip_length, dim] 79 | normalized_item = self._normalize_field(self.ds[key][idx], key) 80 | data.append(normalized_item.reshape(self.clip_length, -1)) 81 | data = torch.cat(data, dim=-1) 82 | return data -------------------------------------------------------------------------------- /data_loaders/amass/utils/fk.py: -------------------------------------------------------------------------------- 1 | """Based on Daniel Holden code from: 2 | A Deep Learning Framework for Character Motion Synthesis and Editing 3 | (http://www.ipab.inf.ed.ac.uk/cgvu/motionsynthesis.pdf) 4 | """ 5 | 6 | import os 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from data_loaders.amass.utils.rotations import euler_angles_to_matrix, quaternion_to_matrix, rotation_6d_to_matrix 12 | import yaml 13 | 14 | class ForwardKinematicsLayer(nn.Module): 15 | """ Forward Kinematics Layer Class """ 16 | 17 | def __init__(self, parents=None, positions=None, device=None): 18 | with open('data_loaders/amass/utils/smpl.yaml', 'r') as f: 19 | smpl = yaml.safe_load(f) 20 | super().__init__() 21 | self.b_idxs = None 22 | if device is None: 23 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | else: 25 | self.device = device 26 | 27 | if parents is None and positions is None: 28 | # Load SMPL skeleton (their joint order is different from the one we use for bvh export) 29 | smpl_fname = os.path.join(smpl['smpl_body_model'], 'male', 'model.npz') 30 | smpl_data = np.load(smpl_fname, encoding='latin1') 31 | self.parents = torch.from_numpy(smpl_data['kintree_table'][0].astype(np.int32)).to(self.device) 32 | self.parents = self.parents.long() 33 | self.positions = torch.from_numpy(smpl_data['J'].astype(np.float32)).to(self.device) 34 | self.positions[1:] -= self.positions[self.parents[1:]] 35 | else: 36 | self.parents = torch.from_numpy(parents).to(self.device) 37 | self.parents = self.parents.long() 38 | self.positions = torch.from_numpy(positions).to(self.device) 39 | self.positions = self.positions.float() 40 | self.positions[0] = 0 41 | 42 | def rotate(self, t0s, t1s): 43 | return 
torch.matmul(t0s, t1s) 44 | 45 | def identity_rotation(self, rotations): 46 | diagonal = torch.diag(torch.tensor([1.0, 1.0, 1.0, 1.0])).to(self.device) 47 | diagonal = torch.reshape( 48 | diagonal, torch.Size([1] * len(rotations.shape[:2]) + [4, 4])) 49 | ts = diagonal.repeat(rotations.shape[:2] + torch.Size([1, 1])) 50 | return ts 51 | 52 | def make_fast_rotation_matrices(self, positions, rotations): 53 | if len(rotations.shape) == 4 and rotations.shape[-2:] == torch.Size([3, 3]): 54 | rot_matrices = rotations 55 | elif rotations.shape[-1] == 3: 56 | rot_matrices = euler_angles_to_matrix(rotations, convention='XYZ') 57 | elif rotations.shape[-1] == 4: 58 | rot_matrices = quaternion_to_matrix(rotations) 59 | elif rotations.shape[-1] == 6: 60 | rot_matrices = rotation_6d_to_matrix(rotations) 61 | else: 62 | raise NotImplementedError(f'Unimplemented rotation representation in FK layer, shape of {rotations.shape}') 63 | 64 | rot_matrices = torch.cat([rot_matrices, positions[..., None]], dim=-1) 65 | zeros = torch.zeros(rot_matrices.shape[:-2] + torch.Size([1, 3])).to(self.device) 66 | ones = torch.ones(rot_matrices.shape[:-2] + torch.Size([1, 1])).to(self.device) 67 | zerosones = torch.cat([zeros, ones], dim=-1) 68 | rot_matrices = torch.cat([rot_matrices, zerosones], dim=-2) 69 | return rot_matrices 70 | 71 | def rotate_global(self, parents, positions, rotations): 72 | locals = self.make_fast_rotation_matrices(positions, rotations) 73 | globals = self.identity_rotation(rotations) 74 | 75 | globals = torch.cat([locals[:, 0:1], globals[:, 1:]], dim=1) 76 | b_size = positions.shape[0] 77 | if self.b_idxs is None: 78 | self.b_idxs = torch.LongTensor(np.arange(b_size)).to(self.device) 79 | elif self.b_idxs.shape[-1] != b_size: 80 | self.b_idxs = torch.LongTensor(np.arange(b_size)).to(self.device) 81 | 82 | for i in range(1, positions.shape[1]): 83 | globals[:, i] = self.rotate( 84 | globals[self.b_idxs, parents[i]], locals[:, i]) 85 | 86 | return globals 87 | 88 | def get_tpose_joints(self, offsets, parents): 89 | num_joints = len(parents) 90 | joints = [offsets[:, 0]] 91 | for j in range(1, len(parents)): 92 | joints.append(joints[parents[j]] + offsets[:, j]) 93 | 94 | return torch.stack(joints, dim=1) 95 | 96 | def canonical_to_local(self, canonical_xform, global_orient=None): 97 | """ 98 | Args: 99 | canonical_xform: (B, J, 3, 3) 100 | global_orient: (B, 3, 3) 101 | 102 | Returns: 103 | local_xform: (B, J, 3, 3) 104 | """ 105 | local_xform = torch.zeros_like(canonical_xform) 106 | 107 | if global_orient is None: 108 | global_xform = canonical_xform 109 | else: 110 | global_xform = torch.matmul(global_orient.unsqueeze(1), canonical_xform) 111 | for i in range(global_xform.shape[1]): 112 | if i == 0: 113 | local_xform[:, i] = global_xform[:, i] 114 | else: 115 | local_xform[:, i] = torch.bmm(torch.linalg.inv(global_xform[:, self.parents[i]]), global_xform[:, i]) 116 | 117 | return local_xform 118 | 119 | def global_to_local(self, global_xform): 120 | """ 121 | Args: 122 | global_xform: (B, J, 3, 3) 123 | 124 | Returns: 125 | local_xform: (B, J, 3, 3) 126 | """ 127 | local_xform = torch.zeros_like(global_xform) 128 | 129 | for i in range(global_xform.shape[1]): 130 | if i == 0: 131 | local_xform[:, i] = global_xform[:, i] 132 | else: 133 | local_xform[:, i] = torch.bmm(torch.linalg.inv(global_xform[:, self.parents[i]]), global_xform[:, i]) 134 | 135 | return local_xform 136 | 137 | def forward(self, rotations, positions=None): 138 | """ 139 | Args: 140 | rotations (B, J, D) 141 | 142 | Returns: 
143 | The global position of each joint after FK (B, J, 3) 144 | """ 145 | # Get the full transform with rotations for skinning 146 | b_size = rotations.shape[0] 147 | if positions is None: 148 | positions = self.positions.repeat(b_size, 1, 1) 149 | transforms = self.rotate_global(self.parents, positions, rotations) 150 | coordinates = transforms[:, :, :3, 3] / transforms[:, :, 3:, 3] 151 | 152 | return coordinates, transforms 153 | -------------------------------------------------------------------------------- /data_loaders/amass/utils/helper_functions.py: -------------------------------------------------------------------------------- 1 | ### Helper functions adopted from src/utils.py in NeMF 2 | 3 | import torch 4 | 5 | def estimate_linear_velocity(data_seq, dt): 6 | ''' 7 | Given some batched data sequences of T timesteps in the shape (B, T, ...), estimates 8 | the velocity for the middle T-2 steps using a second order central difference scheme. 9 | The first and last frames are with forward and backward first-order 10 | differences, respectively 11 | - h : step size 12 | ''' 13 | # first steps is forward diff (t+1 - t) / dt 14 | init_vel = (data_seq[:, 1:2] - data_seq[:, :1]) / dt 15 | # middle steps are second order (t+1 - t-1) / 2dt 16 | middle_vel = (data_seq[:, 2:] - data_seq[:, 0:-2]) / (2 * dt) 17 | # last step is backward diff (t - t-1) / dt 18 | final_vel = (data_seq[:, -1:] - data_seq[:, -2:-1]) / dt 19 | 20 | vel_seq = torch.cat([init_vel, middle_vel, final_vel], dim=1) 21 | return vel_seq 22 | 23 | 24 | def estimate_angular_velocity(rot_seq, dt): 25 | ''' 26 | Given a batch of sequences of T rotation matrices, estimates angular velocity at T-2 steps. 27 | Input sequence should be of shape (B, T, ..., 3, 3) 28 | ''' 29 | # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix 30 | dRdt = estimate_linear_velocity(rot_seq, dt) 31 | R = rot_seq 32 | RT = R.transpose(-1, -2) 33 | # compute skew-symmetric angular velocity tensor 34 | w_mat = torch.matmul(dRdt, RT) 35 | # pull out angular velocity vector by averaging symmetric entries 36 | w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0 37 | w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0 38 | w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0 39 | w = torch.stack([w_x, w_y, w_z], axis=-1) 40 | return w -------------------------------------------------------------------------------- /data_loaders/amass/utils/smpl.yaml: -------------------------------------------------------------------------------- 1 | smpl_body_model: ./body_models/smpl 2 | 3 | offsets: 4 | male: [ 5 | [-0.002174, 0.972724, 0.028584], 6 | [ 0.058581, -0.082280, -0.017664], 7 | [ 0.043451, -0.386469, 0.008037], 8 | [-0.014790, -0.426874, -0.037428], 9 | [ 0.041054, -0.060286, 0.122042], 10 | [-0.060310, -0.090513, -0.013543], 11 | [-0.043257, -0.383688, -0.004843], 12 | [ 0.019056, -0.420046, -0.034562], 13 | [-0.034840, -0.062106, 0.130323], 14 | [ 0.004439, 0.124404, -0.038385], 15 | [ 0.004488, 0.137956, 0.026820], 16 | [-0.002265, 0.056032, 0.002855], 17 | [-0.013390, 0.211635, -0.033468], 18 | [ 0.010113, 0.088937, 0.050410], 19 | [ 0.071702, 0.114000, -0.018898], 20 | [ 0.122921, 0.045205, -0.019046], 21 | [ 0.255332, -0.015649, -0.022946], 22 | [ 0.265709, 0.012698, -0.007375], 23 | [ 0.086691, -0.010636, -0.015594], 24 | [-0.082954, 0.112472, -0.023707], 25 | [-0.113228, 0.046853, -0.008472], 26 | [-0.260127, -0.014369, -0.031269], 27 | [-0.269108, 0.006794, -0.006027], 28 | [-0.088754, -0.008652, -0.010107] 29 | 
] 30 | female: [ 31 | [-0.000876, 0.909315, 0.027821], 32 | [ 0.071361, -0.089584, -0.008046], 33 | [ 0.030669, -0.364209, -0.006689], 34 | [-0.011554, -0.383348, -0.043502], 35 | [ 0.023338, -0.054645, 0.114370], 36 | [-0.069012, -0.088960, -0.004796], 37 | [-0.036152, -0.370650, -0.009185], 38 | [ 0.014029, -0.383638, -0.041892], 39 | [-0.022043, -0.046410, 0.117900], 40 | [-0.002508, 0.103257, -0.022185], 41 | [ 0.003581, 0.127658, -0.001713], 42 | [ 0.002027, 0.049072, 0.027867], 43 | [-0.001963, 0.208243, -0.049765], 44 | [ 0.003517, 0.062322, 0.050205], 45 | [ 0.075298, 0.117780, -0.036875], 46 | [ 0.085317, 0.031739, -0.007293], 47 | [ 0.251247, -0.011967, -0.027518], 48 | [ 0.238019, 0.009007, 0.000044], 49 | [ 0.079668, -0.009683, -0.013206], 50 | [-0.077033, 0.115606, -0.041811], 51 | [-0.089203, 0.032785, -0.009802], 52 | [-0.245990, -0.013152, -0.020162], 53 | [-0.245177, 0.008622, -0.003532], 54 | [-0.080400, -0.007248, -0.010419] 55 | ] 56 | 57 | parents: [-1, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11, 12, 11, 14, 15, 16, 17, 11, 19, 20, 21, 22] 58 | 59 | joint_names: ['Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Foot', 'R_Hip', 'R_Knee', 'R_Ankle', 'R_Foot', 'Spine1', 'Spine2', 'Spine3', 'Neck', 'Head', 'L_Collar', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand', 'R_Collar', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'] 60 | 61 | joints_to_use: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 37] 62 | 63 | leaf_joints: [10, 11, 15, 22, 23] 64 | 65 | lfoot_index: [7, 10] 66 | 67 | rfoot_index: [8, 11] -------------------------------------------------------------------------------- /data_loaders/amass_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # NeMF uses target = [pos, rotmat ,trans, root_orient] but we use everything except velocities and contacts 4 | 5 | 6 | # Matrix that shows joint correspondces to SMPL features 7 | MAT_POS = np.zeros((24, 764), dtype=np.bool) 8 | MAT_POS[0, :3] = True # root position = trans 9 | for joint_idx in range(24): 10 | ub = 3 + 24*3*3 + 3 * (joint_idx + 1) 11 | lb = ub - 3 12 | MAT_POS[joint_idx, lb:ub] = True # joint position = pos 13 | 14 | MAT_ROTMAT = np.zeros((24, 764), dtype=np.bool) # rotmat = 24,3,3 wrp to the parent joint 15 | for joint_idx in range(24): 16 | ub = 3 + 3*3 * (joint_idx + 1) 17 | lb = ub - 9 18 | MAT_ROTMAT[joint_idx, lb:ub] = True # joint rotation = rotmat 19 | 20 | MAT_HEIGHT = np.zeros((24, 764), dtype=np.bool) # height = 24 21 | for joint_idx in range(24): 22 | ub = 3 + 24*3*3 + 24*3 + 24*3 + 8 + (joint_idx + 1) 23 | lb = ub - 1 24 | MAT_HEIGHT[joint_idx, lb:ub] = True # joint rotation = rotmat 25 | 26 | MAT_ROT6D = np.zeros((24, 764), dtype=np.bool) # rot2d = 24,2 wrp to the parent joint 27 | for joint_idx in range(24): 28 | ub = 3 + 24*3*3 + 24*3 + 24*3 + 8 + 24 + 3 + 24*3 + 24*6 + 6 + 6 * (joint_idx + 1) 29 | lb = ub - 6 30 | MAT_ROT6D[joint_idx, lb:ub] = True # joint rotation = rotmat 31 | 32 | MAT_ROT = np.zeros((24, 764), dtype=np.bool) # global_xform = 24, 6 wrp to the root 33 | lb = 3 + 24*3*3 + 24*3 + 24*3 + 8 + 24 + 3 + 24*3 + 24*6 34 | MAT_ROT[0, lb:lb+6] = True # root rotation = root_orient 35 | for joint_idx in range(24): 36 | ub = 3 + 24*3*3 + 24*3 + 24*3 + 8 + 24 + 3 + 24*3 + (joint_idx + 1) * 6 37 | lb = ub - 6 38 | MAT_ROT[joint_idx, lb:ub] = True # joint rotation = global_xform -------------------------------------------------------------------------------- /data_loaders/get_data.py: 
-------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from data_loaders.tensors import collate as all_collate 3 | from data_loaders.tensors import t2m_collate, amass_collate 4 | from typing import Tuple 5 | from dataclasses import dataclass 6 | 7 | 8 | def get_dataset_class(name): 9 | if name == "amass": 10 | from data_loaders.amass.data.dataset import AMASS 11 | return AMASS 12 | elif name == "uestc": 13 | from .a2m.uestc import UESTC 14 | return UESTC 15 | elif name == "humanact12": 16 | from .a2m.humanact12poses import HumanAct12Poses 17 | return HumanAct12Poses 18 | elif name == "humanml": 19 | from data_loaders.humanml.data.dataset import HumanML3D 20 | return HumanML3D 21 | elif name == "kit": 22 | from data_loaders.humanml.data.dataset import KIT 23 | return KIT 24 | else: 25 | raise ValueError(f'Unsupported dataset name [{name}]') 26 | 27 | 28 | def get_collate_fn(name, hml_mode='train'): 29 | if hml_mode == 'gt': 30 | from data_loaders.humanml.data.dataset import collate_fn as t2m_eval_collate 31 | return t2m_eval_collate 32 | if name in ["humanml", "kit"]: 33 | return t2m_collate 34 | elif name == 'amass': 35 | return amass_collate 36 | else: 37 | return all_collate 38 | 39 | 40 | @dataclass 41 | class DatasetConfig: 42 | name: str 43 | batch_size: int 44 | num_frames: int 45 | split: str = 'train' 46 | hml_mode: str = 'train' 47 | use_abs3d: bool = False 48 | traject_only: bool = False 49 | use_random_projection: bool = False 50 | random_projection_scale: float = None 51 | augment_type: str = 'none' 52 | std_scale_shift: Tuple[float] = (1.0, 0.0) 53 | drop_redundant: bool = False 54 | 55 | 56 | def get_dataset(conf: DatasetConfig): 57 | DATA = get_dataset_class(conf.name) 58 | if conf.name in ["humanml", "kit"]: 59 | dataset = DATA(split=conf.split, 60 | num_frames=conf.num_frames, 61 | mode=conf.hml_mode, 62 | use_abs3d=conf.use_abs3d, 63 | traject_only=conf.traject_only, 64 | use_random_projection=conf.use_random_projection, 65 | random_projection_scale=conf.random_projection_scale, 66 | augment_type=conf.augment_type, 67 | std_scale_shift=conf.std_scale_shift, 68 | drop_redundant=conf.drop_redundant) 69 | elif conf.name == "amass": 70 | dataset = DATA(split=conf.split) 71 | else: 72 | raise NotImplementedError() 73 | dataset = DATA(split=split, num_frames=num_frames) 74 | return dataset 75 | 76 | 77 | def get_dataset_loader(conf: DatasetConfig, shuffle=True, num_workers=8, drop_last=True): 78 | dataset = get_dataset(conf) 79 | collate = get_collate_fn(conf.name, conf.hml_mode) 80 | 81 | # return dataset 82 | loader = DataLoader(dataset, 83 | batch_size=conf.batch_size, 84 | shuffle=shuffle, 85 | num_workers=num_workers, 86 | drop_last=drop_last, 87 | collate_fn=collate,) 88 | #pin_memory=True) # Remove if out of memory occurs 89 | 90 | return loader -------------------------------------------------------------------------------- /data_loaders/humanml/README.md: -------------------------------------------------------------------------------- 1 | This code is based on https://github.com/EricGuo5513/text-to-motion.git -------------------------------------------------------------------------------- /data_loaders/humanml/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/data_loaders/humanml/data/__init__.py 
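The loader entry point above (`data_loaders/get_data.py`) is driven entirely by the `DatasetConfig` dataclass. A minimal usage sketch, assuming the HumanML3D data has been prepared as described in `dataset/README.md`; the field values chosen below are illustrative, not defaults taken from the repo:

```python
from data_loaders.get_data import DatasetConfig, get_dataset_loader

# Build a HumanML3D training loader with the absolute-3D representation enabled.
conf = DatasetConfig(
    name='humanml',      # resolved by get_dataset_class -> humanml.data.dataset.HumanML3D
    batch_size=64,
    num_frames=224,      # UNet configs in configs/model.py train with max length 224
    split='train',
    hml_mode='train',
    use_abs3d=True,      # load new_joint_vecs_abs_3d instead of the relative features
)

loader = get_dataset_loader(conf, shuffle=True, num_workers=8)
motion, cond = next(iter(loader))
# motion: [batch, n_feats, 1, n_frames] after t2m_collate; cond['y'] carries
# 'mask', 'lengths', 'text' and 'tokens' for conditioning.
```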
-------------------------------------------------------------------------------- /data_loaders/humanml/motion_loaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/data_loaders/humanml/motion_loaders/__init__.py -------------------------------------------------------------------------------- /data_loaders/humanml/motion_loaders/dataset_motion_loader.py: -------------------------------------------------------------------------------- 1 | from t2m.data.dataset import Text2MotionDatasetV2, collate_fn 2 | from t2m.utils.word_vectorizer import WordVectorizer 3 | import numpy as np 4 | from os.path import join as pjoin 5 | from torch.utils.data import DataLoader 6 | from t2m.utils.get_opt import get_opt 7 | 8 | def get_dataset_motion_loader(opt_path, batch_size, device): 9 | opt = get_opt(opt_path, device) 10 | 11 | # Configurations of T2M dataset and KIT dataset is almost the same 12 | if opt.dataset_name == 't2m' or opt.dataset_name == 'kit': 13 | print('Loading dataset %s ...' % opt.dataset_name) 14 | 15 | mean = np.load(pjoin(opt.meta_dir, 'mean.npy')) 16 | std = np.load(pjoin(opt.meta_dir, 'std.npy')) 17 | 18 | w_vectorizer = WordVectorizer('./glove', 'our_vab') 19 | split_file = pjoin(opt.data_root, 'test.txt') 20 | dataset = Text2MotionDatasetV2(opt, mean, std, split_file, w_vectorizer) 21 | dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, drop_last=True, 22 | collate_fn=collate_fn, shuffle=True) 23 | else: 24 | raise KeyError('Dataset not Recognized !!') 25 | 26 | print('Ground Truth Dataset Loading Completed!!!') 27 | return dataloader, dataset -------------------------------------------------------------------------------- /data_loaders/humanml/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/data_loaders/humanml/networks/__init__.py -------------------------------------------------------------------------------- /data_loaders/humanml/utils/get_opt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import Namespace 3 | import re 4 | from os.path import join as pjoin 5 | from data_loaders.humanml.utils.word_vectorizer import POS_enumerator 6 | 7 | 8 | def is_float(numStr): 9 | flag = False 10 | numStr = str(numStr).strip().lstrip('-').lstrip('+') # 去除正数(+)、负数(-)符号 11 | try: 12 | reg = re.compile(r'^[-+]?[0-9]+\.[0-9]+$') 13 | res = reg.match(str(numStr)) 14 | if res: 15 | flag = True 16 | except Exception as ex: 17 | print("is_float() - error: " + str(ex)) 18 | return flag 19 | 20 | 21 | def is_number(numStr): 22 | flag = False 23 | numStr = str(numStr).strip().lstrip('-').lstrip('+') # 去除正数(+)、负数(-)符号 24 | if str(numStr).isdigit(): 25 | flag = True 26 | return flag 27 | 28 | 29 | def get_opt(opt_path, device, mode, max_motion_length, use_abs3d=False): 30 | opt = Namespace() 31 | opt_dict = vars(opt) 32 | 33 | skip = ('-------------- End ----------------', 34 | '------------ Options -------------', 35 | '\n') 36 | print('Reading', opt_path) 37 | with open(opt_path) as f: 38 | for line in f: 39 | if line.strip() not in skip: 40 | # print(line.strip()) 41 | key, value = line.strip().split(': ') 42 | if value in ('True', 'False'): 43 | opt_dict[key] = bool(value) 44 | elif 
is_float(value): 45 | opt_dict[key] = float(value) 46 | elif is_number(value): 47 | opt_dict[key] = int(value) 48 | else: 49 | opt_dict[key] = str(value) 50 | 51 | # print(opt) 52 | opt_dict['which_epoch'] = 'latest' 53 | opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) 54 | opt.model_dir = pjoin(opt.save_root, 'model') 55 | opt.meta_dir = pjoin(opt.save_root, 'meta') 56 | 57 | if opt.dataset_name == 't2m': 58 | opt.data_root = './dataset/HumanML3D' 59 | # Set directory based on type of dataset representation: 60 | # Will load the original dataset (relative) if in 'eval' or 'gt' mode 61 | data_dir = 'new_joint_vecs_abs_3d' if use_abs3d and mode not in ['eval', 'gt'] else 'new_joint_vecs' 62 | if "DATA_ROOT" in os.environ: 63 | local_data_root = pjoin(os.environ["DATA_ROOT"], opt.data_root) 64 | opt.motion_dir = pjoin(local_data_root, data_dir) 65 | opt.text_dir = pjoin(local_data_root, 'texts') 66 | else: 67 | opt.motion_dir = pjoin(opt.data_root, data_dir) 68 | opt.text_dir = pjoin(opt.data_root, 'texts') 69 | 70 | opt.joints_num = 22 71 | opt.dim_pose = 263 72 | # NOTE: UNET needs to uses multiples of 16 73 | opt.max_motion_length = max_motion_length 74 | print(f'WARNING: max_motion_length is set to {max_motion_length}') 75 | elif opt.dataset_name == 'kit': 76 | raise NotImplementedError() 77 | opt.data_root = './dataset/KIT-ML' 78 | opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') 79 | opt.text_dir = pjoin(opt.data_root, 'texts') 80 | opt.joints_num = 21 81 | opt.dim_pose = 251 82 | opt.max_motion_length = 196 83 | else: 84 | raise KeyError('Dataset not recognized') 85 | 86 | opt.dim_word = 300 87 | opt.num_classes = 200 // opt.unit_length 88 | opt.dim_pos_ohot = len(POS_enumerator) 89 | opt.is_train = False 90 | opt.is_continue = False 91 | opt.device = device 92 | 93 | return opt -------------------------------------------------------------------------------- /data_loaders/humanml/utils/paramUtil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Define a kinematic tree for the skeletal struture 4 | kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] 5 | 6 | kit_raw_offsets = np.array( 7 | [ 8 | [0, 0, 0], 9 | [0, 1, 0], 10 | [0, 1, 0], 11 | [0, 1, 0], 12 | [0, 1, 0], 13 | [1, 0, 0], 14 | [0, -1, 0], 15 | [0, -1, 0], 16 | [-1, 0, 0], 17 | [0, -1, 0], 18 | [0, -1, 0], 19 | [1, 0, 0], 20 | [0, -1, 0], 21 | [0, -1, 0], 22 | [0, 0, 1], 23 | [0, 0, 1], 24 | [-1, 0, 0], 25 | [0, -1, 0], 26 | [0, -1, 0], 27 | [0, 0, 1], 28 | [0, 0, 1] 29 | ] 30 | ) 31 | 32 | t2m_raw_offsets = np.array([[0,0,0], 33 | [1,0,0], 34 | [-1,0,0], 35 | [0,1,0], 36 | [0,-1,0], 37 | [0,-1,0], 38 | [0,1,0], 39 | [0,-1,0], 40 | [0,-1,0], 41 | [0,1,0], 42 | [0,0,1], 43 | [0,0,1], 44 | [0,1,0], 45 | [1,0,0], 46 | [-1,0,0], 47 | [0,0,1], 48 | [0,-1,0], 49 | [0,-1,0], 50 | [0,-1,0], 51 | [0,-1,0], 52 | [0,-1,0], 53 | [0,-1,0]]) 54 | 55 | t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] 56 | t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] 57 | t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] 58 | 59 | 60 | kit_tgt_skel_id = '03950' 61 | 62 | t2m_tgt_skel_id = '000021' 63 | 64 | -------------------------------------------------------------------------------- 
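`paramUtil.py` above encodes the skeleton topology as kinematic chains, i.e. lists of joint indices walked from the pelvis (or a spine joint) outwards. A small illustrative sketch, not part of the repo, showing how the HumanML3D chains unroll into the 21 parent-child bones of the 22-joint skeleton:

```python
from data_loaders.humanml.utils.paramUtil import t2m_kinematic_chain

# Flatten the chains into (parent, child) bone pairs, e.g. for drawing a stick
# figure from a [n_frames, 22, 3] joint-position array.
bones = []
for chain in t2m_kinematic_chain:
    for parent, child in zip(chain[:-1], chain[1:]):
        bones.append((parent, child))

assert len(bones) == 21  # a tree over 22 joints has 21 edges
```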
/data_loaders/humanml/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | # import cv2 4 | from PIL import Image 5 | from data_loaders.humanml.utils import paramUtil 6 | import math 7 | import time 8 | import matplotlib.pyplot as plt 9 | from scipy.ndimage import gaussian_filter 10 | 11 | 12 | def mkdir(path): 13 | if not os.path.exists(path): 14 | os.makedirs(path) 15 | 16 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 17 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 18 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 19 | 20 | MISSING_VALUE = -1 21 | 22 | def save_image(image_numpy, image_path): 23 | img_pil = Image.fromarray(image_numpy) 24 | img_pil.save(image_path) 25 | 26 | 27 | def save_logfile(log_loss, save_path): 28 | with open(save_path, 'wt') as f: 29 | for k, v in log_loss.items(): 30 | w_line = k 31 | for digit in v: 32 | w_line += ' %.3f' % digit 33 | f.write(w_line + '\n') 34 | 35 | 36 | def print_current_loss(start_time, niter_state, losses, epoch=None, sub_epoch=None, 37 | inner_iter=None, tf_ratio=None, sl_steps=None): 38 | 39 | def as_minutes(s): 40 | m = math.floor(s / 60) 41 | s -= m * 60 42 | return '%dm %ds' % (m, s) 43 | 44 | def time_since(since, percent): 45 | now = time.time() 46 | s = now - since 47 | es = s / percent 48 | rs = es - s 49 | return '%s (- %s)' % (as_minutes(s), as_minutes(rs)) 50 | 51 | if epoch is not None: 52 | print('epoch: %3d niter: %6d sub_epoch: %2d inner_iter: %4d' % (epoch, niter_state, sub_epoch, inner_iter), end=" ") 53 | 54 | # message = '%s niter: %d completed: %3d%%)' % (time_since(start_time, niter_state / total_niters), 55 | # niter_state, niter_state / total_niters * 100) 56 | now = time.time() 57 | message = '%s'%(as_minutes(now - start_time)) 58 | 59 | for k, v in losses.items(): 60 | message += ' %s: %.4f ' % (k, v) 61 | message += ' sl_length:%2d tf_ratio:%.2f'%(sl_steps, tf_ratio) 62 | print(message) 63 | 64 | def print_current_loss_decomp(start_time, niter_state, total_niters, losses, epoch=None, inner_iter=None): 65 | 66 | def as_minutes(s): 67 | m = math.floor(s / 60) 68 | s -= m * 60 69 | return '%dm %ds' % (m, s) 70 | 71 | def time_since(since, percent): 72 | now = time.time() 73 | s = now - since 74 | es = s / percent 75 | rs = es - s 76 | return '%s (- %s)' % (as_minutes(s), as_minutes(rs)) 77 | 78 | print('epoch: %03d inner_iter: %5d' % (epoch, inner_iter), end=" ") 79 | # now = time.time() 80 | message = '%s niter: %07d completed: %3d%%)'%(time_since(start_time, niter_state / total_niters), niter_state, niter_state / total_niters * 100) 81 | for k, v in losses.items(): 82 | message += ' %s: %.4f ' % (k, v) 83 | print(message) 84 | 85 | 86 | def compose_gif_img_list(img_list, fp_out, duration): 87 | img, *imgs = [Image.fromarray(np.array(image)) for image in img_list] 88 | img.save(fp=fp_out, format='GIF', append_images=imgs, optimize=False, 89 | save_all=True, loop=0, duration=duration) 90 | 91 | 92 | def save_images(visuals, image_path): 93 | if not os.path.exists(image_path): 94 | os.makedirs(image_path) 95 | 96 | for i, (label, img_numpy) in enumerate(visuals.items()): 97 | img_name = '%d_%s.jpg' % (i, label) 98 | save_path = os.path.join(image_path, img_name) 99 | save_image(img_numpy, save_path) 100 | 101 | 102 | def save_images_test(visuals, image_path, from_name, to_name): 103 | if not 
os.path.exists(image_path): 104 | os.makedirs(image_path) 105 | 106 | for i, (label, img_numpy) in enumerate(visuals.items()): 107 | img_name = "%s_%s_%s" % (from_name, to_name, label) 108 | save_path = os.path.join(image_path, img_name) 109 | save_image(img_numpy, save_path) 110 | 111 | 112 | def compose_and_save_img(img_list, save_dir, img_name, col=4, row=1, img_size=(256, 200)): 113 | # print(col, row) 114 | compose_img = compose_image(img_list, col, row, img_size) 115 | if not os.path.exists(save_dir): 116 | os.makedirs(save_dir) 117 | img_path = os.path.join(save_dir, img_name) 118 | # print(img_path) 119 | compose_img.save(img_path) 120 | 121 | 122 | def compose_image(img_list, col, row, img_size): 123 | to_image = Image.new('RGB', (col * img_size[0], row * img_size[1])) 124 | for y in range(0, row): 125 | for x in range(0, col): 126 | from_img = Image.fromarray(img_list[y * col + x]) 127 | # print((x * img_size[0], y*img_size[1], 128 | # (x + 1) * img_size[0], (y + 1) * img_size[1])) 129 | paste_area = (x * img_size[0], y*img_size[1], 130 | (x + 1) * img_size[0], (y + 1) * img_size[1]) 131 | to_image.paste(from_img, paste_area) 132 | # to_image[y*img_size[1]:(y + 1) * img_size[1], x * img_size[0] :(x + 1) * img_size[0]] = from_img 133 | return to_image 134 | 135 | 136 | def plot_loss_curve(losses, save_path, intervals=500): 137 | plt.figure(figsize=(10, 5)) 138 | plt.title("Loss During Training") 139 | for key in losses.keys(): 140 | plt.plot(list_cut_average(losses[key], intervals), label=key) 141 | plt.xlabel("Iterations/" + str(intervals)) 142 | plt.ylabel("Loss") 143 | plt.legend() 144 | plt.savefig(save_path) 145 | plt.show() 146 | 147 | 148 | def list_cut_average(ll, intervals): 149 | if intervals == 1: 150 | return ll 151 | 152 | bins = math.ceil(len(ll) * 1.0 / intervals) 153 | ll_new = [] 154 | for i in range(bins): 155 | l_low = intervals * i 156 | l_high = l_low + intervals 157 | l_high = l_high if l_high < len(ll) else len(ll) 158 | ll_new.append(np.mean(ll[l_low:l_high])) 159 | return ll_new 160 | 161 | 162 | def motion_temporal_filter(motion, sigma=1): 163 | motion = motion.reshape(motion.shape[0], -1) 164 | # print(motion.shape)
 165 | for i in range(motion.shape[1]): 166 | motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest") 167 | return motion.reshape(motion.shape[0], -1, 3) 168 | 169 | -------------------------------------------------------------------------------- /data_loaders/humanml/utils/word_vectorizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | from os.path import join as pjoin 4 | 5 | POS_enumerator = { 6 | 'VERB': 0, 7 | 'NOUN': 1, 8 | 'DET': 2, 9 | 'ADP': 3, 10 | 'NUM': 4, 11 | 'AUX': 5, 12 | 'PRON': 6, 13 | 'ADJ': 7, 14 | 'ADV': 8, 15 | 'Loc_VIP': 9, 16 | 'Body_VIP': 10, 17 | 'Obj_VIP': 11, 18 | 'Act_VIP': 12, 19 | 'Desc_VIP': 13, 20 | 'OTHER': 14, 21 | } 22 | 23 | Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward', 24 | 'up', 'down', 'straight', 'curve') 25 | 26 | Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh') 27 | 28 | Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball') 29 | 30 | Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn', 31 | 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll', 32 | 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb') 33 | 34 | Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily', 35 | 'angrily', 'sadly') 36 | 37 | VIP_dict = { 38 | 'Loc_VIP': Loc_list, 39 | 'Body_VIP': Body_list, 40 | 'Obj_VIP': Obj_List, 41 | 'Act_VIP': Act_list, 42 | 'Desc_VIP': Desc_list, 43 | } 44 | 45 | 46 | class WordVectorizer(object): 47 | def __init__(self, meta_root, prefix): 48 | vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix)) 49 | words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb')) 50 | word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb')) 51 | self.word2vec = {w: vectors[word2idx[w]] for w in words} 52 | 53 | def _get_pos_ohot(self, pos): 54 | pos_vec = np.zeros(len(POS_enumerator)) 55 | if pos in POS_enumerator: 56 | pos_vec[POS_enumerator[pos]] = 1 57 | else: 58 | pos_vec[POS_enumerator['OTHER']] = 1 59 | return pos_vec 60 | 61 | def __len__(self): 62 | return len(self.word2vec) 63 | 64 | def __getitem__(self, item): 65 | word, pos = item.split('/') 66 | if word in self.word2vec: 67 | word_vec = self.word2vec[word] 68 | vip_pos = None 69 | for key, values in VIP_dict.items(): 70 | if word in values: 71 | vip_pos = key 72 | break 73 | if vip_pos is not None: 74 | pos_vec = self._get_pos_ohot(vip_pos) 75 | else: 76 | pos_vec = self._get_pos_ohot(pos) 77 | else: 78 | word_vec = self.word2vec['unk'] 79 | pos_vec = self._get_pos_ohot('OTHER') 80 | return word_vec, pos_vec -------------------------------------------------------------------------------- /data_loaders/humanml_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | HML_JOINT_NAMES = [ 4 | 'pelvis', 5 | 'left_hip', 6 | 'right_hip', 7 | 'spine1', 8 | 'left_knee', 9 | 'right_knee', 10 | 'spine2', 11 | 'left_ankle', 12 | 'right_ankle', 13 | 'spine3', 14 | 'left_foot', 15 | 'right_foot', 16 | 'neck', 17 | 'left_collar', 18 | 'right_collar', 19 | 'head', 20 | 'left_shoulder', 21 | 'right_shoulder', 22 | 'left_elbow', 23 | 
'right_elbow', 24 | 'left_wrist', 25 | 'right_wrist', 26 | ] 27 | 28 | NUM_HML_JOINTS = len(HML_JOINT_NAMES) # 22 SMPLH body joints 29 | 30 | HML_LOWER_BODY_JOINTS = [HML_JOINT_NAMES.index(name) for name in ['pelvis', 'left_hip', 'right_hip', 'left_knee', 'right_knee', 'left_ankle', 'right_ankle', 'left_foot', 'right_foot',]] 31 | SMPL_UPPER_BODY_JOINTS = [i for i in range(len(HML_JOINT_NAMES)) if i not in HML_LOWER_BODY_JOINTS] 32 | HML_LOWER_BODY_RIGHT_JOINTS = [HML_JOINT_NAMES.index(name) for name in ['pelvis', 'right_hip', 'right_knee', 'right_ankle', 'right_foot',]] 33 | HML_PELVIS_FEET = [HML_JOINT_NAMES.index(name) for name in ['pelvis', 'left_foot', 'right_foot']] 34 | HML_PELVIS_HANDS = [HML_JOINT_NAMES.index(name) for name in ['pelvis', 'left_wrist', 'right_wrist']] 35 | HML_PELVIS_VR = [HML_JOINT_NAMES.index(name) for name in ['pelvis', 'left_wrist', 'right_wrist', 'head']] 36 | 37 | # Recover global angle and positions for rotation data 38 | # root_rot_velocity (B, seq_len, 1) 39 | # root_linear_velocity (B, seq_len, 2) 40 | # root_y (B, seq_len, 1) 41 | # ric_data (B, seq_len, (joint_num - 1)*3) 42 | # rot_data (B, seq_len, (joint_num - 1)*6) 43 | # local_velocity (B, seq_len, joint_num*3) 44 | # foot contact (B, seq_len, 4) 45 | HML_ROOT_BINARY = np.array([True] + [False] * (NUM_HML_JOINTS-1)) 46 | HML_ROOT_MASK = np.concatenate(([True]*(1+2+1), 47 | HML_ROOT_BINARY[1:].repeat(3), 48 | HML_ROOT_BINARY[1:].repeat(6), 49 | HML_ROOT_BINARY.repeat(3), 50 | [False] * 4)) 51 | HML_LOWER_BODY_JOINTS_BINARY = np.array([i in HML_LOWER_BODY_JOINTS for i in range(NUM_HML_JOINTS)]) 52 | HML_LOWER_BODY_MASK = np.concatenate(([True]*(1+2+1), 53 | HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(3), 54 | HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(6), 55 | HML_LOWER_BODY_JOINTS_BINARY.repeat(3), 56 | [True]*4)) 57 | HML_UPPER_BODY_MASK = ~HML_LOWER_BODY_MASK 58 | 59 | HML_LOWER_BODY_RIGHT_JOINTS_BINARY = np.array([i in HML_LOWER_BODY_RIGHT_JOINTS for i in range(NUM_HML_JOINTS)]) 60 | HML_LOWER_BODY_RIGHT_MASK = np.concatenate(([True]*(1+2+1), 61 | HML_LOWER_BODY_RIGHT_JOINTS_BINARY[1:].repeat(3), 62 | HML_LOWER_BODY_RIGHT_JOINTS_BINARY[1:].repeat(6), 63 | HML_LOWER_BODY_RIGHT_JOINTS_BINARY.repeat(3), 64 | [True]*4)) 65 | 66 | 67 | # Matrix that shows joint correspondces to SMPL features 68 | MAT_POS = np.zeros((22, 263), dtype=np.bool) 69 | MAT_POS[0, 1:4] = True 70 | for joint_idx in range(1, 22): 71 | ub = 4 + 3 * joint_idx 72 | lb = ub - 3 73 | MAT_POS[joint_idx, lb:ub] = True 74 | 75 | MAT_ROT = np.zeros((22, 263), dtype=np.bool) 76 | MAT_ROT[0, 0] = True 77 | for joint_idx in range(1, 22): 78 | ub = 4 + 21*3 + 6 * joint_idx 79 | lb = ub - 6 80 | MAT_ROT[joint_idx, lb:ub] = True 81 | 82 | MAT_VEL = np.zeros((22, 263), dtype=np.bool) 83 | for joint_idx in range(0, 22): 84 | ub = 4 + 21*3 + 21*6 + 3 * (joint_idx + 1) 85 | lb = ub - 3 86 | MAT_VEL[joint_idx, lb:ub] = True 87 | 88 | MAT_CNT = np.zeros((22, 263), dtype=np.bool) 89 | MAT_CNT[7, -4] = True 90 | MAT_CNT[10, -3] = True 91 | MAT_CNT[8, -2] = True 92 | MAT_CNT[11, -1] = True -------------------------------------------------------------------------------- /data_loaders/tensors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def lengths_to_mask(lengths, max_len): 4 | # max_len = max(lengths) 5 | mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1) 6 | return mask 7 | 8 | 9 | def collate_tensors(batch): 10 | dims = batch[0].dim() 11 | 
max_size = [max([b.size(i) for b in batch]) for i in range(dims)] 12 | size = (len(batch),) + tuple(max_size) 13 | canvas = batch[0].new_zeros(size=size) 14 | for i, b in enumerate(batch): 15 | sub_tensor = canvas[i] 16 | for d in range(dims): 17 | sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) 18 | sub_tensor.add_(b) 19 | return canvas 20 | 21 | 22 | def collate(batch): 23 | notnone_batches = [b for b in batch if b is not None] 24 | databatch = [b['inp'] for b in notnone_batches] 25 | if 'lengths' in notnone_batches[0]: 26 | lenbatch = [b['lengths'] for b in notnone_batches] 27 | else: 28 | lenbatch = [len(b['inp'][0][0]) for b in notnone_batches] 29 | 30 | 31 | databatchTensor = collate_tensors(databatch) 32 | lenbatchTensor = torch.as_tensor(lenbatch) 33 | maskbatchTensor = lengths_to_mask(lenbatchTensor, databatchTensor.shape[-1]).unsqueeze(1).unsqueeze(1) # unsqueeze for broadcasting 34 | 35 | motion = databatchTensor 36 | cond = {'y': {'mask': maskbatchTensor, 'lengths': lenbatchTensor}} 37 | 38 | if 'text' in notnone_batches[0]: 39 | textbatch = [b['text'] for b in notnone_batches] 40 | cond['y'].update({'text': textbatch}) 41 | 42 | if 'tokens' in notnone_batches[0]: 43 | textbatch = [b['tokens'] for b in notnone_batches] 44 | cond['y'].update({'tokens': textbatch}) 45 | 46 | if 'action' in notnone_batches[0]: 47 | actionbatch = [b['action'] for b in notnone_batches] 48 | cond['y'].update({'action': torch.as_tensor(actionbatch).unsqueeze(1)}) 49 | 50 | # collate action textual names 51 | if 'action_text' in notnone_batches[0]: 52 | action_text = [b['action_text'] for b in notnone_batches] 53 | cond['y'].update({'action_text': action_text}) 54 | 55 | return motion, cond 56 | 57 | # an adapter to our collate func 58 | def t2m_collate(batch): 59 | # batch.sort(key=lambda x: x[3], reverse=True) 60 | adapted_batch = [{ 61 | 'inp': torch.tensor(b[4].T).float().unsqueeze(1), # [seqlen, J] -> [J, 1, seqlen] 62 | 'text': b[2], #b[0]['caption'] 63 | 'tokens': b[6], 64 | 'lengths': b[5], 65 | } for b in batch] 66 | return collate(adapted_batch) 67 | 68 | def amass_collate(batch): 69 | adapted_batch = [{ 70 | 'inp': (b.clone().detach().T).float().unsqueeze(1), # [seqlen, J] -> [J, 1, seqlen] 71 | } for b in batch] 72 | return collate(adapted_batch) 73 | -------------------------------------------------------------------------------- /dataset/000021.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/dataset/000021.npy -------------------------------------------------------------------------------- /dataset/HumanML3D_abs/Mean_abs_3d.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/dataset/HumanML3D_abs/Mean_abs_3d.npy -------------------------------------------------------------------------------- /dataset/HumanML3D_abs/Std_abs_3d.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/dataset/HumanML3D_abs/Std_abs_3d.npy -------------------------------------------------------------------------------- /dataset/HumanML3D_abs/cal_mean_variance.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 |
"cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import sys\n", 11 | "import os\n", 12 | "from os.path import join as pjoin\n", 13 | "\n", 14 | "\n", 15 | "# root_rot_velocity (B, seq_len, 1)\n", 16 | "# root_linear_velocity (B, seq_len, 2)\n", 17 | "# root_y (B, seq_len, 1)\n", 18 | "# ric_data (B, seq_len, (joint_num - 1)*3)\n", 19 | "# rot_data (B, seq_len, (joint_num - 1)*6)\n", 20 | "# local_velocity (B, seq_len, joint_num*3)\n", 21 | "# foot contact (B, seq_len, 4)\n", 22 | "def mean_variance(data_dir, save_dir, joints_num):\n", 23 | " file_list = os.listdir(data_dir)\n", 24 | " data_list = []\n", 25 | "\n", 26 | " for file in file_list:\n", 27 | " data = np.load(pjoin(data_dir, file))\n", 28 | " if np.isnan(data).any():\n", 29 | " print(file)\n", 30 | " continue\n", 31 | " data_list.append(data)\n", 32 | "\n", 33 | " data = np.concatenate(data_list, axis=0)\n", 34 | " print(data.shape)\n", 35 | " Mean = data.mean(axis=0)\n", 36 | " Std = data.std(axis=0)\n", 37 | " Std[0:1] = Std[0:1].mean() / 1.0\n", 38 | " Std[1:3] = Std[1:3].mean() / 1.0\n", 39 | " Std[3:4] = Std[3:4].mean() / 1.0\n", 40 | " Std[4: 4+(joints_num - 1) * 3] = Std[4: 4+(joints_num - 1) * 3].mean() / 1.0\n", 41 | " Std[4+(joints_num - 1) * 3: 4+(joints_num - 1) * 9] = Std[4+(joints_num - 1) * 3: 4+(joints_num - 1) * 9].mean() / 1.0\n", 42 | " Std[4+(joints_num - 1) * 9: 4+(joints_num - 1) * 9 + joints_num*3] = Std[4+(joints_num - 1) * 9: 4+(joints_num - 1) * 9 + joints_num*3].mean() / 1.0\n", 43 | " Std[4 + (joints_num - 1) * 9 + joints_num * 3: ] = Std[4 + (joints_num - 1) * 9 + joints_num * 3: ].mean() / 1.0\n", 44 | "\n", 45 | " assert 8 + (joints_num - 1) * 9 + joints_num * 3 == Std.shape[-1]\n", 46 | "\n", 47 | "# np.save(pjoin(save_dir, 'Mean.npy'), Mean)\n", 48 | "# np.save(pjoin(save_dir, 'Std.npy'), Std)\n", 49 | " \n", 50 | " np.save(pjoin(save_dir, 'Mean_abs_3d.npy'), Mean)\n", 51 | " np.save(pjoin(save_dir, 'Std_abs_3d.npy'), Std)\n", 52 | "\n", 53 | " return Mean, Std" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 2, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "# The given data is used to double check if you are on the right track.\n", 63 | "reference1 = np.load('./HumanML3D/Mean.npy')\n", 64 | "reference2 = np.load('./HumanML3D/Std.npy')" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "name": "stdout", 74 | "output_type": "stream", 75 | "text": [ 76 | "(4117392, 263)\n" 77 | ] 78 | } 79 | ], 80 | "source": [ 81 | "if __name__ == '__main__':\n", 82 | "# data_dir = './HumanML3D/new_joint_vecs/'\n", 83 | " data_dir = './HumanML3D/new_joint_vecs_abs_3d/'\n", 84 | " save_dir = './HumanML3D/'\n", 85 | " mean, std = mean_variance(data_dir, save_dir, 22)\n", 86 | "# print(mean)\n", 87 | "# print(Std)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "### Check if your data is correct. 
If it's aligned with the given reference, then it is right" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 4, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "data": { 104 | "text/plain": [ 105 | "0.32431233" 106 | ] 107 | }, 108 | "execution_count": 4, 109 | "metadata": {}, 110 | "output_type": "execute_result" 111 | } 112 | ], 113 | "source": [ 114 | "abs(mean-reference1).sum()" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 5, 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "data": { 124 | "text/plain": [ 125 | "2.052561" 126 | ] 127 | }, 128 | "execution_count": 5, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | "abs(std-reference2).sum()" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [] 143 | } 144 | ], 145 | "metadata": { 146 | "kernelspec": { 147 | "display_name": "Python 3", 148 | "language": "python", 149 | "name": "python3" 150 | }, 151 | "language_info": { 152 | "codemirror_mode": { 153 | "name": "ipython", 154 | "version": 3 155 | }, 156 | "file_extension": ".py", 157 | "mimetype": "text/x-python", 158 | "name": "python", 159 | "nbconvert_exporter": "python", 160 | "pygments_lexer": "ipython3", 161 | "version": "3.7.10" 162 | } 163 | }, 164 | "nbformat": 4, 165 | "nbformat_minor": 4 166 | } 167 | -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | ## Data 2 | 3 | * Data dirs should be placed here. 4 | 5 | * The `opt` files are configurations for how to read the data according to [text-to-motion](https://github.com/EricGuo5513/text-to-motion). 6 | 7 | * The `*_mean.npy` and `*_std.npy` files are stats used for evaluation only, according to [text-to-motion](https://github.com/EricGuo5513/text-to-motion).
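
* As a rough illustration of how these stats are typically applied (the data loaders and evaluator wrappers in this repo load them internally; the paths and the stand-in array below are only for the sketch), they act as a per-feature Z-normalization:

```python
# Minimal sketch, run from the repo root; not the repo's own loading code.
import numpy as np

mean = np.load('dataset/t2m_mean.npy')  # per-feature mean used for evaluation
std = np.load('dataset/t2m_std.npy')    # per-feature std used for evaluation

features = np.random.randn(120, mean.shape[-1])  # stand-in for a (seq_len, n_feats) motion
normalized = (features - mean) / std             # normalize before evaluation
recovered = normalized * std + mean              # inverse transform
```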
-------------------------------------------------------------------------------- /dataset/humanml_opt.txt: -------------------------------------------------------------------------------- 1 | ------------ Options ------------- 2 | batch_size: 32 3 | checkpoints_dir: ./checkpoints 4 | dataset_name: t2m 5 | decomp_name: Decomp_SP001_SM001_H512 6 | dim_att_vec: 512 7 | dim_dec_hidden: 1024 8 | dim_movement2_dec_hidden: 512 9 | dim_movement_dec_hidden: 512 10 | dim_movement_enc_hidden: 512 11 | dim_movement_latent: 512 12 | dim_msd_hidden: 512 13 | dim_pos_hidden: 1024 14 | dim_pri_hidden: 1024 15 | dim_seq_de_hidden: 512 16 | dim_seq_en_hidden: 512 17 | dim_text_hidden: 512 18 | dim_z: 128 19 | early_stop_count: 3 20 | estimator_mod: bigru 21 | eval_every_e: 5 22 | feat_bias: 5 23 | fixed_steps: 5 24 | gpu_id: 3 25 | input_z: False 26 | is_continue: True 27 | is_train: True 28 | lambda_fake: 10 29 | lambda_gan_l: 0.1 30 | lambda_gan_mt: 0.1 31 | lambda_gan_mv: 0.1 32 | lambda_kld: 0.01 33 | lambda_rec: 1 34 | lambda_rec_init: 1 35 | lambda_rec_mot: 1 36 | lambda_rec_mov: 1 37 | log_every: 50 38 | lr: 0.0002 39 | max_sub_epoch: 50 40 | max_text_len: 20 41 | n_layers_dec: 1 42 | n_layers_msd: 2 43 | n_layers_pos: 1 44 | n_layers_pri: 1 45 | n_layers_seq_de: 2 46 | n_layers_seq_en: 1 47 | name: Comp_v6_KLD01 48 | num_experts: 4 49 | save_every_e: 10 50 | save_latest: 500 51 | text_enc_mod: bigru 52 | tf_ratio: 0.4 53 | unit_length: 4 54 | -------------- End ---------------- 55 | -------------------------------------------------------------------------------- /dataset/inv_rand_proj.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/dataset/inv_rand_proj.npy -------------------------------------------------------------------------------- /dataset/kit_mean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/dataset/kit_mean.npy -------------------------------------------------------------------------------- /dataset/kit_opt.txt: -------------------------------------------------------------------------------- 1 | ------------ Options ------------- 2 | batch_size: 32 3 | checkpoints_dir: ./checkpoints 4 | dataset_name: kit 5 | decomp_name: Decomp_SP001_SM001_H512 6 | dim_att_vec: 512 7 | dim_dec_hidden: 1024 8 | dim_movement2_dec_hidden: 512 9 | dim_movement_dec_hidden: 512 10 | dim_movement_enc_hidden: 512 11 | dim_movement_latent: 512 12 | dim_msd_hidden: 512 13 | dim_pos_hidden: 1024 14 | dim_pri_hidden: 1024 15 | dim_seq_de_hidden: 512 16 | dim_seq_en_hidden: 512 17 | dim_text_hidden: 512 18 | dim_z: 128 19 | early_stop_count: 3 20 | estimator_mod: bigru 21 | eval_every_e: 5 22 | feat_bias: 5 23 | fixed_steps: 5 24 | gpu_id: 2 25 | input_z: False 26 | is_continue: True 27 | is_train: True 28 | lambda_fake: 10 29 | lambda_gan_l: 0.1 30 | lambda_gan_mt: 0.1 31 | lambda_gan_mv: 0.1 32 | lambda_kld: 0.005 33 | lambda_rec: 1 34 | lambda_rec_init: 1 35 | lambda_rec_mot: 1 36 | lambda_rec_mov: 1 37 | log_every: 50 38 | lr: 0.0002 39 | max_sub_epoch: 50 40 | max_text_len: 20 41 | n_layers_dec: 1 42 | n_layers_msd: 2 43 | n_layers_pos: 1 44 | n_layers_pri: 1 45 | n_layers_seq_de: 2 46 | n_layers_seq_en: 1 47 | name: Comp_v6_KLD005 48 | num_experts: 4 49 | save_every_e: 10 50 | save_latest: 500 51 | text_enc_mod: 
bigru 52 | tf_ratio: 0.4 53 | unit_length: 4 54 | -------------- End ---------------- 55 | -------------------------------------------------------------------------------- /dataset/kit_std.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/dataset/kit_std.npy -------------------------------------------------------------------------------- /dataset/rand_proj.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/dataset/rand_proj.npy -------------------------------------------------------------------------------- /dataset/t2m_mean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/dataset/t2m_mean.npy -------------------------------------------------------------------------------- /dataset/t2m_std.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/dataset/t2m_std.npy -------------------------------------------------------------------------------- /diffusion/losses.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/openai/guided-diffusion 2 | """ 3 | Helpers for various likelihood-based losses. These are ported from the original 4 | Ho et al. diffusion models codebase: 5 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 6 | """ 7 | 8 | import numpy as np 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 
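    Each element is the log of the probability mass the Gaussian assigns to the
    discretization bin [x - 1/255, x + 1/255]; the two extreme bins (x < -0.999
    and x > 0.999) are extended to cover the open tails, which is what the
    comparisons against +/- 0.999 below implement.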
60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /diffusion/nn.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/openai/guided-diffusion 2 | """ 3 | Various utilities for neural networks. 4 | """ 5 | 6 | import math 7 | 8 | import torch as th 9 | import torch.nn as nn 10 | 11 | 12 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 13 | class SiLU(nn.Module): 14 | def forward(self, x): 15 | return x * th.sigmoid(x) 16 | 17 | 18 | class GroupNorm32(nn.GroupNorm): 19 | def forward(self, x): 20 | return super().forward(x.float()).type(x.dtype) 21 | 22 | 23 | def conv_nd(dims, *args, **kwargs): 24 | """ 25 | Create a 1D, 2D, or 3D convolution module. 26 | """ 27 | if dims == 1: 28 | return nn.Conv1d(*args, **kwargs) 29 | elif dims == 2: 30 | return nn.Conv2d(*args, **kwargs) 31 | elif dims == 3: 32 | return nn.Conv3d(*args, **kwargs) 33 | raise ValueError(f"unsupported dimensions: {dims}") 34 | 35 | 36 | def linear(*args, **kwargs): 37 | """ 38 | Create a linear module. 39 | """ 40 | return nn.Linear(*args, **kwargs) 41 | 42 | 43 | def avg_pool_nd(dims, *args, **kwargs): 44 | """ 45 | Create a 1D, 2D, or 3D average pooling module. 46 | """ 47 | if dims == 1: 48 | return nn.AvgPool1d(*args, **kwargs) 49 | elif dims == 2: 50 | return nn.AvgPool2d(*args, **kwargs) 51 | elif dims == 3: 52 | return nn.AvgPool3d(*args, **kwargs) 53 | raise ValueError(f"unsupported dimensions: {dims}") 54 | 55 | 56 | def update_ema(target_params, source_params, rate=0.99): 57 | """ 58 | Update target parameters to be closer to those of source parameters using 59 | an exponential moving average. 60 | 61 | :param target_params: the target parameter sequence. 62 | :param source_params: the source parameter sequence. 63 | :param rate: the EMA rate (closer to 1 means slower). 64 | """ 65 | for targ, src in zip(target_params, source_params): 66 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 67 | 68 | 69 | def zero_module(module): 70 | """ 71 | Zero out the parameters of a module and return it. 72 | """ 73 | for p in module.parameters(): 74 | p.detach().zero_() 75 | return module 76 | 77 | 78 | def scale_module(module, scale): 79 | """ 80 | Scale the parameters of a module and return it. 81 | """ 82 | for p in module.parameters(): 83 | p.detach().mul_(scale) 84 | return module 85 | 86 | 87 | def mean_flat(tensor): 88 | """ 89 | Take the mean over all non-batch dimensions. 90 | """ 91 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 92 | 93 | def sum_flat(tensor): 94 | """ 95 | Take the sum over all non-batch dimensions. 96 | """ 97 | return tensor.sum(dim=list(range(1, len(tensor.shape)))) 98 | 99 | 100 | def normalization(channels): 101 | """ 102 | Make a standard normalization layer. 
103 | 104 | :param channels: number of input channels. 105 | :return: an nn.Module for normalization. 106 | """ 107 | return GroupNorm32(32, channels) 108 | 109 | 110 | def timestep_embedding(timesteps, dim, max_period=10000): 111 | """ 112 | Create sinusoidal timestep embeddings. 113 | 114 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 115 | These may be fractional. 116 | :param dim: the dimension of the output. 117 | :param max_period: controls the minimum frequency of the embeddings. 118 | :return: an [N x dim] Tensor of positional embeddings. 119 | """ 120 | half = dim // 2 121 | freqs = th.exp( 122 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 123 | ).to(device=timesteps.device) 124 | args = timesteps[:, None].float() * freqs[None] 125 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 126 | if dim % 2: 127 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 128 | return embedding 129 | 130 | 131 | def checkpoint(func, inputs, params, flag): 132 | """ 133 | Evaluate a function without caching intermediate activations, allowing for 134 | reduced memory at the expense of extra compute in the backward pass. 135 | :param func: the function to evaluate. 136 | :param inputs: the argument sequence to pass to `func`. 137 | :param params: a sequence of parameters `func` depends on but does not 138 | explicitly take as arguments. 139 | :param flag: if False, disable gradient checkpointing. 140 | """ 141 | if flag: 142 | args = tuple(inputs) + tuple(params) 143 | return CheckpointFunction.apply(func, len(inputs), *args) 144 | else: 145 | return func(*inputs) 146 | 147 | 148 | class CheckpointFunction(th.autograd.Function): 149 | @staticmethod 150 | @th.cuda.amp.custom_fwd 151 | def forward(ctx, run_function, length, *args): 152 | ctx.run_function = run_function 153 | ctx.input_length = length 154 | ctx.save_for_backward(*args) 155 | with th.no_grad(): 156 | output_tensors = ctx.run_function(*args[:length]) 157 | return output_tensors 158 | 159 | @staticmethod 160 | @th.cuda.amp.custom_bwd 161 | def backward(ctx, *output_grads): 162 | args = list(ctx.saved_tensors) 163 | 164 | # Filter for inputs that require grad. If none, exit early. 165 | input_indices = [i for (i, x) in enumerate(args) if x.requires_grad] 166 | if not input_indices: 167 | return (None, None) + tuple(None for _ in args) 168 | 169 | with th.enable_grad(): 170 | for i in input_indices: 171 | if i < ctx.input_length: 172 | # Not sure why the OAI code does this little 173 | # dance. It might not be necessary. 174 | args[i] = args[i].detach().requires_grad_() 175 | args[i] = args[i].view_as(args[i]) 176 | output_tensors = ctx.run_function(*args[:ctx.input_length]) 177 | 178 | if isinstance(output_tensors, th.Tensor): 179 | output_tensors = [output_tensors] 180 | 181 | # Filter for outputs that require grad. If none, exit early. 182 | out_and_grads = [(o, g) for (o, g) in zip(output_tensors, output_grads) if o.requires_grad] 183 | if not out_and_grads: 184 | return (None, None) + tuple(None for _ in args) 185 | 186 | # Compute gradients on the filtered tensors. 187 | computed_grads = th.autograd.grad( 188 | [o for (o, g) in out_and_grads], 189 | [args[i] for i in input_indices], 190 | [g for (o, g) in out_and_grads] 191 | ) 192 | 193 | # Reassemble the complete gradient tuple. 
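        # Inputs that never required grad keep a gradient of None; the two extra
        # leading Nones below correspond to the `run_function` and `length`
        # arguments of forward(), which take no gradient.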
194 | input_grads = [None for _ in args] 195 | for (i, g) in zip(input_indices, computed_grads): 196 | input_grads[i] = g 197 | return (None, None) + tuple(input_grads) 198 | -------------------------------------------------------------------------------- /diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 
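        # all_gather requires identically-shaped tensors on every rank, so the
        # gather buffers are allocated at the maximum batch size and the true
        # sizes collected above are used afterwards to strip the padding.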
93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/openai/guided-diffusion 2 | from copy import deepcopy 3 | import numpy as np 4 | import torch as th 5 | 6 | from .gaussian_diffusion import GaussianDiffusion, DiffusionConfig 7 | 8 | 9 | def space_timesteps(num_timesteps, section_counts): 10 | """ 11 | Create a list of timesteps to use from an original diffusion process, 12 | given the number of timesteps we want to take from equally-sized portions 13 | of the original process. 14 | 15 | For example, if there are 300 timesteps and the section counts are [10,15,20] 16 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 17 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
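    Concretely, that call keeps 10 + 15 + 20 = 45 of the 300 steps; e.g. the
    first section uses a fractional stride of (100 - 1) / (10 - 1) = 11 and so
    retains steps 0, 11, 22, ..., 99.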
18 | 19 | If the stride is a string starting with "ddim", then the fixed striding 20 | from the DDIM paper is used, and only one section is allowed. 21 | 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {desired_count} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | 69 | :param use_timesteps: a collection (sequence or set) of timesteps from the 70 | original diffusion process to retain. 71 | :param conf: the DiffusionConfig used to create the base diffusion process.
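
    A minimal usage sketch (assuming `conf` is an already-constructed
    DiffusionConfig whose betas define a 1000-step base process):

        diffusion = SpacedDiffusion(space_timesteps(1000, "ddim50"), conf)

    __init__ recomputes the betas so that the retained steps reproduce the
    cumulative alphas of the base process.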
72 | """ 73 | 74 | def __init__(self, use_timesteps, conf: DiffusionConfig): 75 | self.use_timesteps = set(use_timesteps) 76 | self.timestep_map = [] 77 | self.original_num_steps = len(conf.betas) 78 | 79 | base_diffusion = GaussianDiffusion(conf) # pylint: disable=missing-kwoa 80 | last_alpha_cumprod = 1.0 81 | new_betas = [] 82 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 83 | if i in self.use_timesteps: 84 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 85 | last_alpha_cumprod = alpha_cumprod 86 | self.timestep_map.append(i) 87 | 88 | # use the new conf to create a new diffusion process 89 | new_conf = deepcopy(conf) 90 | new_conf.betas = np.array(new_betas) 91 | super().__init__(new_conf) 92 | 93 | def p_mean_variance( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 97 | 98 | def training_losses( 99 | self, model, *args, **kwargs 100 | ): # pylint: disable=signature-differs 101 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 102 | 103 | def condition_mean(self, cond_fn, *args, **kwargs): 104 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 105 | 106 | def condition_score(self, cond_fn, *args, **kwargs): 107 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 108 | 109 | def _wrap_model(self, model): 110 | if isinstance(model, _WrappedModel): 111 | return model 112 | return _WrappedModel( 113 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 114 | ) 115 | 116 | def _scale_timesteps(self, t): 117 | # Scaling is done by the wrapped model. 118 | return t 119 | 120 | 121 | class _WrappedModel: 122 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 123 | self.model = model 124 | self.timestep_map = timestep_map 125 | self.rescale_timesteps = rescale_timesteps 126 | self.original_num_steps = original_num_steps 127 | 128 | def __call__(self, x, ts, **kwargs): 129 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 130 | new_ts = map_tensor[ts] 131 | if self.rescale_timesteps: 132 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 133 | return self.model(x, new_ts, **kwargs) 134 | -------------------------------------------------------------------------------- /eval/a2m/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/eval/a2m/__init__.py -------------------------------------------------------------------------------- /eval/a2m/action2motion/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def calculate_accuracy(model, motion_loader, num_labels, classifier, device): 5 | confusion = torch.zeros(num_labels, num_labels, dtype=torch.long) 6 | with torch.no_grad(): 7 | for batch in motion_loader: 8 | batch_prob = classifier(batch["output_xyz"], lengths=batch["lengths"]) 9 | batch_pred = batch_prob.max(dim=1).indices 10 | for label, pred in zip(batch["y"], batch_pred): 11 | confusion[label][pred] += 1 12 | 13 | accuracy = torch.trace(confusion)/torch.sum(confusion) 14 | return accuracy.item(), confusion 15 | -------------------------------------------------------------------------------- /eval/a2m/action2motion/diversity.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | #adapted from action2motion 6 | def calculate_diversity(activations): 7 | diversity_times = 200 8 | num_motions = len(activations) 9 | 10 | diversity = 0 11 | 12 | first_indices = np.random.randint(0, num_motions, diversity_times) 13 | second_indices = np.random.randint(0, num_motions, diversity_times) 14 | for first_idx, second_idx in zip(first_indices, second_indices): 15 | diversity += torch.dist(activations[first_idx, :], 16 | activations[second_idx, :]) 17 | diversity /= diversity_times 18 | return diversity 19 | 20 | # from action2motion 21 | def calculate_diversity_multimodality(activations, labels, num_labels, unconstrained = False): 22 | diversity_times = 200 23 | multimodality_times = 20 24 | if not unconstrained: 25 | labels = labels.long() 26 | num_motions = activations.shape[0] # len(labels) 27 | 28 | diversity = 0 29 | 30 | first_indices = np.random.randint(0, num_motions, diversity_times) 31 | second_indices = np.random.randint(0, num_motions, diversity_times) 32 | for first_idx, second_idx in zip(first_indices, second_indices): 33 | diversity += torch.dist(activations[first_idx, :], 34 | activations[second_idx, :]) 35 | diversity /= diversity_times 36 | 37 | if not unconstrained: 38 | multimodality = 0 39 | label_quotas = np.zeros(num_labels) 40 | label_quotas[labels.unique()] = multimodality_times # if a label does not appear in batch, its quota remains zero 41 | while np.any(label_quotas > 0): 42 | # print(label_quotas) 43 | first_idx = np.random.randint(0, num_motions) 44 | first_label = labels[first_idx] 45 | if not label_quotas[first_label]: 46 | continue 47 | 48 | second_idx = np.random.randint(0, num_motions) 49 | second_label = labels[second_idx] 50 | while first_label != second_label: 51 | second_idx = np.random.randint(0, num_motions) 52 | second_label = labels[second_idx] 53 | 54 | label_quotas[first_label] -= 1 55 | 56 | first_activation = activations[first_idx, :] 57 | second_activation = activations[second_idx, :] 58 | multimodality += torch.dist(first_activation, 59 | second_activation) 60 | 61 | multimodality /= (multimodality_times * num_labels) 62 | else: 63 | multimodality = torch.tensor(np.nan) 64 | 65 | return diversity.item(), multimodality.item() 66 | 67 | -------------------------------------------------------------------------------- /eval/a2m/action2motion/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .models import load_classifier, load_classifier_for_fid 4 | from .accuracy import calculate_accuracy 5 | from .fid import calculate_fid 6 | from .diversity import calculate_diversity_multimodality 7 | 8 | 9 | class A2MEvaluation: 10 | def __init__(self, device): 11 | dataset_opt = {"input_size_raw": 72, "joints_num": 24, "num_classes": 12} 12 | 13 | self.input_size_raw = dataset_opt["input_size_raw"] 14 | self.num_classes = dataset_opt["num_classes"] 15 | self.device = device 16 | 17 | self.gru_classifier_for_fid = load_classifier_for_fid(self.input_size_raw, self.num_classes, device).eval() 18 | self.gru_classifier = load_classifier(self.input_size_raw, self.num_classes, device).eval() 19 | 20 | def compute_features(self, model, motionloader): 21 | # calculate_activations_labels function from action2motion 22 | activations = [] 23 | labels = [] 24 | with torch.no_grad(): 25 | for idx, batch in enumerate(motionloader): 26 | 
activations.append(self.gru_classifier_for_fid(batch["output_xyz"], lengths=batch["lengths"])) 27 | if model.cond_mode != 'no_cond': 28 | labels.append(batch["y"]) 29 | activations = torch.cat(activations, dim=0) 30 | if model.cond_mode != 'no_cond': 31 | labels = torch.cat(labels, dim=0) 32 | return activations, labels 33 | 34 | @staticmethod 35 | def calculate_activation_statistics(activations): 36 | activations = activations.cpu().numpy() 37 | mu = np.mean(activations, axis=0) 38 | sigma = np.cov(activations, rowvar=False) 39 | return mu, sigma 40 | 41 | def evaluate(self, model, loaders): 42 | 43 | def print_logs(metric, key): 44 | print(f"Computing action2motion {metric} on the {key} loader ...") 45 | 46 | metrics = {} 47 | 48 | computedfeats = {} 49 | for key, loader in loaders.items(): 50 | metric = "accuracy" 51 | print_logs(metric, key) 52 | mkey = f"{metric}_{key}" 53 | if model.cond_mode != 'no_cond': 54 | metrics[mkey], _ = calculate_accuracy(model, loader, 55 | self.num_classes, 56 | self.gru_classifier, self.device) 57 | else: 58 | metrics[mkey] = np.nan 59 | 60 | # features for diversity 61 | print_logs("features", key) 62 | feats, labels = self.compute_features(model, loader) 63 | print_logs("stats", key) 64 | stats = self.calculate_activation_statistics(feats) 65 | 66 | computedfeats[key] = {"feats": feats, 67 | "labels": labels, 68 | "stats": stats} 69 | 70 | print_logs("diversity", key) 71 | ret = calculate_diversity_multimodality(feats, labels, self.num_classes, unconstrained=(model.cond_mode=='no_cond')) 72 | metrics[f"diversity_{key}"], metrics[f"multimodality_{key}"] = ret 73 | 74 | # taking the stats of the ground truth and remove it from the computed feats 75 | gtstats = computedfeats["gt"]["stats"] 76 | # computing fid 77 | for key, loader in computedfeats.items(): 78 | metric = "fid" 79 | mkey = f"{metric}_{key}" 80 | 81 | stats = computedfeats[key]["stats"] 82 | metrics[mkey] = float(calculate_fid(gtstats, stats)) 83 | 84 | return metrics 85 | -------------------------------------------------------------------------------- /eval/a2m/action2motion/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | 4 | 5 | # from action2motion 6 | def calculate_fid(statistics_1, statistics_2): 7 | return calculate_frechet_distance(statistics_1[0], statistics_1[1], 8 | statistics_2[0], statistics_2[1]) 9 | 10 | 11 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 12 | """Numpy implementation of the Frechet Distance. 13 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 14 | and X_2 ~ N(mu_2, C_2) is 15 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 16 | Stable version by Dougal J. Sutherland. 17 | Params: 18 | -- mu1 : Numpy array containing the activations of a layer of the 19 | inception net (like returned by the function 'get_predictions') 20 | for generated samples. 21 | -- mu2 : The sample mean over activations, precalculated on an 22 | representative data set. 23 | -- sigma1: The covariance matrix over activations for generated samples. 24 | -- sigma2: The covariance matrix over activations, precalculated on an 25 | representative data set. 26 | Returns: 27 | -- : The Frechet Distance. 
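    In this codebase the two statistics tuples are produced by
    calculate_activation_statistics in eval/a2m/action2motion/evaluate.py,
    i.e. mu is the mean (shape [D]) and sigma the covariance (shape [D, D])
    of a set of classifier activations.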
28 | """ 29 | 30 | mu1 = np.atleast_1d(mu1) 31 | mu2 = np.atleast_1d(mu2) 32 | 33 | sigma1 = np.atleast_2d(sigma1) 34 | sigma2 = np.atleast_2d(sigma2) 35 | 36 | assert mu1.shape == mu2.shape, \ 37 | 'Training and test mean vectors have different lengths' 38 | assert sigma1.shape == sigma2.shape, \ 39 | 'Training and test covariances have different dimensions' 40 | 41 | diff = mu1 - mu2 42 | 43 | # Product might be almost singular 44 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 45 | if not np.isfinite(covmean).all(): 46 | msg = ('fid calculation produces singular product; ' 47 | 'adding %s to diagonal of cov estimates') % eps 48 | print(msg) 49 | offset = np.eye(sigma1.shape[0]) * eps 50 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 51 | 52 | # Numerical error might give slight imaginary component 53 | if np.iscomplexobj(covmean): 54 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 55 | m = np.max(np.abs(covmean.imag)) 56 | raise ValueError('Imaginary component {}'.format(m)) 57 | covmean = covmean.real 58 | 59 | tr_covmean = np.trace(covmean) 60 | 61 | return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) 62 | -------------------------------------------------------------------------------- /eval/a2m/action2motion/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # adapted from action2motion to take inputs of different lengths 6 | class MotionDiscriminator(nn.Module): 7 | def __init__(self, input_size, hidden_size, hidden_layer, device, output_size=12, use_noise=None): 8 | super(MotionDiscriminator, self).__init__() 9 | self.device = device 10 | 11 | self.input_size = input_size 12 | self.hidden_size = hidden_size 13 | self.hidden_layer = hidden_layer 14 | self.use_noise = use_noise 15 | 16 | self.recurrent = nn.GRU(input_size, hidden_size, hidden_layer) 17 | self.linear1 = nn.Linear(hidden_size, 30) 18 | self.linear2 = nn.Linear(30, output_size) 19 | 20 | def forward(self, motion_sequence, lengths=None, hidden_unit=None): 21 | # dim (motion_length, num_samples, hidden_size) 22 | bs, njoints, nfeats, num_frames = motion_sequence.shape 23 | motion_sequence = motion_sequence.reshape(bs, njoints*nfeats, num_frames) 24 | motion_sequence = motion_sequence.permute(2, 0, 1) 25 | if hidden_unit is None: 26 | # motion_sequence = motion_sequence.permute(1, 0, 2) 27 | hidden_unit = self.initHidden(motion_sequence.size(1), self.hidden_layer) 28 | gru_o, _ = self.recurrent(motion_sequence.float(), hidden_unit) 29 | 30 | # select the last valid, instead of: gru_o[-1, :, :] 31 | out = gru_o[tuple(torch.stack((lengths-1, torch.arange(bs, device=self.device))))] 32 | 33 | # dim (num_samples, 30) 34 | lin1 = self.linear1(out) 35 | lin1 = torch.tanh(lin1) 36 | # dim (num_samples, output_size) 37 | lin2 = self.linear2(lin1) 38 | return lin2 39 | 40 | def initHidden(self, num_samples, layer): 41 | return torch.randn(layer, num_samples, self.hidden_size, device=self.device, requires_grad=False) 42 | 43 | 44 | class MotionDiscriminatorForFID(MotionDiscriminator): 45 | def forward(self, motion_sequence, lengths=None, hidden_unit=None): 46 | # dim (motion_length, num_samples, hidden_size) 47 | bs, njoints, nfeats, num_frames = motion_sequence.shape 48 | motion_sequence = motion_sequence.reshape(bs, njoints*nfeats, num_frames) 49 | motion_sequence = motion_sequence.permute(2, 0, 1) 50 | if hidden_unit is None: 51 | # motion_sequence = 
motion_sequence.permute(1, 0, 2) 52 | hidden_unit = self.initHidden(motion_sequence.size(1), self.hidden_layer) 53 | gru_o, _ = self.recurrent(motion_sequence.float(), hidden_unit) 54 | 55 | # select the last valid, instead of: gru_o[-1, :, :] 56 | out = gru_o[tuple(torch.stack((lengths-1, torch.arange(bs, device=self.device))))] 57 | 58 | # dim (num_samples, 30) 59 | lin1 = self.linear1(out) 60 | lin1 = torch.tanh(lin1) 61 | return lin1 62 | 63 | 64 | model_path = "./assets/actionrecognition/humanact12_gru.tar" 65 | 66 | 67 | def load_classifier(input_size_raw, num_classes, device): 68 | model = torch.load(model_path, map_location=device) 69 | classifier = MotionDiscriminator(input_size_raw, 128, 2, device=device, output_size=num_classes).to(device) 70 | classifier.load_state_dict(model["model"]) 71 | classifier.eval() 72 | return classifier 73 | 74 | 75 | def load_classifier_for_fid(input_size_raw, num_classes, device): 76 | model = torch.load(model_path, map_location=device) 77 | classifier = MotionDiscriminatorForFID(input_size_raw, 128, 2, device=device, output_size=num_classes).to(device) 78 | classifier.load_state_dict(model["model"]) 79 | classifier.eval() 80 | return classifier 81 | 82 | 83 | def test(): 84 | from src.datasets.ntu13 import NTU13 85 | import src.utils.fixseed # noqa 86 | 87 | classifier = load_classifier("ntu13", input_size_raw=54, num_classes=13, device="cuda").eval() 88 | params = {"pose_rep": "rot6d", 89 | "translation": True, 90 | "glob": True, 91 | "jointstype": "a2m", 92 | "vertstrans": True, 93 | "num_frames": 60, 94 | "sampling": "conseq", 95 | "sampling_step": 1} 96 | dataset = NTU13(**params) 97 | 98 | from src.models.rotation2xyz import Rotation2xyz 99 | rot2xyz = Rotation2xyz(device="cuda") 100 | confusion_xyz = torch.zeros(13, 13, dtype=torch.long) 101 | confusion = torch.zeros(13, 13, dtype=torch.long) 102 | 103 | for i in range(1000): 104 | dataset.pose_rep = "xyz" 105 | data = dataset[i][0].to("cuda") 106 | data = data[None] 107 | 108 | dataset.pose_rep = params["pose_rep"] 109 | x = dataset[i][0].to("cuda")[None] 110 | mask = torch.ones(1, x.shape[-1], dtype=bool, device="cuda") 111 | lengths = mask.sum(1) 112 | 113 | xyz_t = rot2xyz(x, mask, **params) 114 | 115 | predicted_cls_xyz = classifier(data, lengths=lengths).argmax().item() 116 | predicted_cls = classifier(xyz_t, lengths=lengths).argmax().item() 117 | 118 | gt_cls = dataset[i][1] 119 | 120 | confusion_xyz[gt_cls][predicted_cls_xyz] += 1 121 | confusion[gt_cls][predicted_cls] += 1 122 | 123 | accuracy_xyz = torch.trace(confusion_xyz)/torch.sum(confusion_xyz).item() 124 | accuracy = torch.trace(confusion)/torch.sum(confusion).item() 125 | 126 | print(f"accuracy: {accuracy:.1%}, accuracy_xyz: {accuracy_xyz:.1%}") 127 | 128 | 129 | if __name__ == "__main__": 130 | test() 131 | -------------------------------------------------------------------------------- /eval/a2m/gru_eval.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | import torch 7 | import functools 8 | from torch.utils.data import DataLoader 9 | 10 | from utils.fixseed import fixseed 11 | from data_loaders.tensors import collate 12 | from eval.a2m.action2motion.evaluate import A2MEvaluation 13 | from eval.unconstrained.evaluate import evaluate_unconstrained_metrics 14 | from .tools import save_metrics, format_metrics 15 | from utils import dist_util 16 | 17 | num_samples_unconstrained = 1000 18 | 19 | class 
NewDataloader: 20 | def __init__(self, mode, model, diffusion, dataiterator, device, unconstrained, num_samples: int=-1): 21 | assert mode in ["gen", "gt"] 22 | self.batches = [] 23 | sample_fn = diffusion.p_sample_loop 24 | with torch.no_grad(): 25 | for motions, model_kwargs in tqdm(dataiterator, desc=f"Construct dataloader: {mode}.."): 26 | motions = motions.to(device) 27 | if num_samples != -1 and len(self.batches) * dataiterator.batch_size > num_samples: 28 | continue # do not break because it confuses the multiple loaders 29 | batch = dict() 30 | if mode == "gen": 31 | sample = sample_fn(model, motions.shape, clip_denoised=False, model_kwargs=model_kwargs) 32 | batch['output'] = sample 33 | elif mode == "gt": 34 | batch["output"] = motions 35 | 36 | # mask = torch.ones([batch["output"].shape[0], batch["output"].shape[-1]], dtype=bool).to(device) # batch_size x num_frames 37 | max_n_frames = model_kwargs['y']['lengths'].max() 38 | mask = model_kwargs['y']['mask'].reshape(dataiterator.batch_size, max_n_frames).bool() 39 | batch["output_xyz"] = model.rot2xyz(x=batch["output"], mask=mask, pose_rep='rot6d', glob=True, 40 | translation=True, jointstype='smpl', vertstrans=True, betas=None, 41 | beta=0, glob_rot=None, get_rotations_back=False) 42 | batch["lengths"] = model_kwargs['y']['lengths'].to(device) 43 | if unconstrained: # proceed only if not running unconstrained 44 | batch["y"] = model_kwargs['y']['action'].squeeze().long().cpu() # using torch.long so lengths/action will be used as indices 45 | self.batches.append(batch) 46 | 47 | num_samples_last_batch = num_samples % dataiterator.batch_size 48 | if num_samples_last_batch > 0: 49 | for k, v in self.batches[-1].items(): 50 | self.batches[-1][k] = v[:num_samples_last_batch] 51 | 52 | def __iter__(self): 53 | return iter(self.batches) 54 | 55 | def evaluate(args, model, diffusion, data): 56 | num_frames = 60 57 | 58 | # fix parameters for action2motion evaluation 59 | args.num_frames = num_frames 60 | args.jointstype = "smpl" 61 | args.vertstrans = True 62 | 63 | device = dist_util.dev() 64 | 65 | model.eval() 66 | 67 | a2mevaluation = A2MEvaluation(device=device) 68 | a2mmetrics = {} 69 | 70 | datasetGT1 = copy.deepcopy(data) 71 | datasetGT2 = copy.deepcopy(data) 72 | 73 | allseeds = list(range(args.num_seeds)) 74 | 75 | try: 76 | for index, seed in enumerate(allseeds): 77 | print(f"Evaluation number: {index+1}/{args.num_seeds}") 78 | fixseed(seed) 79 | 80 | datasetGT1.reset_shuffle() 81 | datasetGT1.shuffle() 82 | 83 | datasetGT2.reset_shuffle() 84 | datasetGT2.shuffle() 85 | 86 | dataiterator = DataLoader(datasetGT1, batch_size=args.batch_size, 87 | shuffle=False, num_workers=8, collate_fn=collate) 88 | dataiterator2 = DataLoader(datasetGT2, batch_size=args.batch_size, 89 | shuffle=False, num_workers=8, collate_fn=collate) 90 | 91 | new_data_loader = functools.partial(NewDataloader, model=model, diffusion=diffusion, device=device, 92 | unconstrained=args.unconstrained, num_samples=args.num_samples) 93 | motionloader = new_data_loader(mode="gen", dataiterator=dataiterator) 94 | gt_motionloader = new_data_loader("gt", dataiterator=dataiterator) 95 | gt_motionloader2 = new_data_loader("gt", dataiterator=dataiterator2) 96 | 97 | # Action2motionEvaluation 98 | loaders = {"gen": motionloader, 99 | "gt": gt_motionloader, 100 | "gt2": gt_motionloader2} 101 | 102 | a2mmetrics[seed] = a2mevaluation.evaluate(model, loaders) 103 | 104 | del loaders 105 | 106 | if args.unconstrained: # unconstrained 107 | dataset_unconstrained = 
copy.deepcopy(data) 108 | dataset_unconstrained.reset_shuffle() 109 | dataset_unconstrained.shuffle() 110 | dataiterator_unconstrained = DataLoader(dataset_unconstrained, batch_size=args.batch_size, 111 | shuffle=False, num_workers=8, collate_fn=collate) 112 | motionloader_unconstrained = new_data_loader(mode="gen", dataiterator=dataiterator_unconstrained, num_samples=num_samples_unconstrained) 113 | 114 | generated_motions = [] 115 | for motion in motionloader_unconstrained: 116 | idx = [15, 12, 16, 18, 20, 17, 19, 21, 0, 1, 4, 7, 2, 5, 8] 117 | motion = motion['output_xyz'][:, idx, :, :] 118 | generated_motions.append(motion.cpu().numpy()) 119 | generated_motions = np.concatenate(generated_motions) 120 | unconstrained_metrics = evaluate_unconstrained_metrics(generated_motions, device, fast=True) 121 | unconstrained_metrics = {k+'_unconstrained': v for k, v in unconstrained_metrics.items()} 122 | 123 | except KeyboardInterrupt: 124 | string = "Saving the evaluation before exiting.." 125 | print(string) 126 | 127 | metrics = {"feats": {key: [format_metrics(a2mmetrics[seed])[key] for seed in a2mmetrics.keys()] for key in a2mmetrics[allseeds[0]]}} 128 | if args.unconstrained: 129 | metrics["feats"] = {**metrics["feats"], **unconstrained_metrics} 130 | 131 | return metrics 132 | -------------------------------------------------------------------------------- /eval/a2m/recognition/models/stgcnutils/graph.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle as pkl 3 | 4 | from utils.config import SMPL_KINTREE_PATH 5 | 6 | 7 | class Graph: 8 | """ The Graph to model the skeletons extracted by the openpose 9 | Args: 10 | strategy (string): must be one of the follow candidates 11 | - uniform: Uniform Labeling 12 | - distance: Distance Partitioning 13 | - spatial: Spatial Configuration 14 | For more information, please refer to the section 'Partition Strategies' 15 | in our paper (https://arxiv.org/abs/1801.07455). 16 | layout (string): must be one of the follow candidates 17 | - openpose: Is consists of 18 joints. For more information, please 18 | refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose#output 19 | - ntu-rgb+d: Is consists of 25 joints. For more information, please 20 | refer to https://github.com/shahroudy/NTURGB-D 21 | - smpl: Consists of 24/23 joints with without global rotation. 
22 | max_hop (int): the maximal distance between two connected nodes 23 | dilation (int): controls the spacing between the kernel points 24 | """ 25 | 26 | def __init__(self, 27 | layout='openpose', 28 | strategy='uniform', 29 | kintree_path=SMPL_KINTREE_PATH, 30 | max_hop=1, 31 | dilation=1): 32 | self.max_hop = max_hop 33 | self.dilation = dilation 34 | 35 | self.kintree_path = kintree_path 36 | 37 | self.get_edge(layout) 38 | self.hop_dis = get_hop_distance( 39 | self.num_node, self.edge, max_hop=max_hop) 40 | self.get_adjacency(strategy) 41 | 42 | def __str__(self): 43 | return self.A 44 | 45 | def get_edge(self, layout): 46 | if layout == 'openpose': 47 | self.num_node = 18 48 | self_link = [(i, i) for i in range(self.num_node)] 49 | neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12, 50 | 11), 51 | (10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1), 52 | (0, 1), (15, 0), (14, 0), (17, 15), (16, 14)] 53 | self.edge = self_link + neighbor_link 54 | self.center = 1 55 | elif layout == 'smpl': 56 | self.num_node = 24 57 | self_link = [(i, i) for i in range(self.num_node)] 58 | kt = pkl.load(open(self.kintree_path, "rb")) 59 | neighbor_link = [(k, kt[1][i + 1]) for i, k in enumerate(kt[0][1:])] 60 | self.edge = self_link + neighbor_link 61 | self.center = 0 62 | elif layout == 'smpl_noglobal': 63 | self.num_node = 23 64 | self_link = [(i, i) for i in range(self.num_node)] 65 | kt = pkl.load(open(self.kintree_path, "rb")) 66 | neighbor_link = [(k, kt[1][i + 1]) for i, k in enumerate(kt[0][1:])] 67 | # remove the root joint 68 | neighbor_1base = [n for n in neighbor_link if n[0] != 0 and n[1] != 0] 69 | neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] 70 | self.edge = self_link + neighbor_link 71 | self.center = 0 72 | elif layout == 'ntu-rgb+d': 73 | self.num_node = 25 74 | self_link = [(i, i) for i in range(self.num_node)] 75 | neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), 76 | (6, 5), (7, 6), (8, 7), (9, 21), (10, 9), 77 | (11, 10), (12, 11), (13, 1), (14, 13), (15, 14), 78 | (16, 15), (17, 1), (18, 17), (19, 18), (20, 19), 79 | (22, 23), (23, 8), (24, 25), (25, 12)] 80 | neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] 81 | self.edge = self_link + neighbor_link 82 | self.center = 21 - 1 83 | elif layout == 'ntu_edge': 84 | self.num_node = 24 85 | self_link = [(i, i) for i in range(self.num_node)] 86 | neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6), 87 | (8, 7), (9, 2), (10, 9), (11, 10), (12, 11), 88 | (13, 1), (14, 13), (15, 14), (16, 15), (17, 1), 89 | (18, 17), (19, 18), (20, 19), (21, 22), (22, 8), 90 | (23, 24), (24, 12)] 91 | neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] 92 | self.edge = self_link + neighbor_link 93 | self.center = 2 94 | # elif layout=='customer settings' 95 | # pass 96 | else: 97 | raise NotImplementedError("This Layout is not supported") 98 | 99 | def get_adjacency(self, strategy): 100 | valid_hop = range(0, self.max_hop + 1, self.dilation) 101 | adjacency = np.zeros((self.num_node, self.num_node)) 102 | for hop in valid_hop: 103 | adjacency[self.hop_dis == hop] = 1 104 | normalize_adjacency = normalize_digraph(adjacency) 105 | 106 | if strategy == 'uniform': 107 | A = np.zeros((1, self.num_node, self.num_node)) 108 | A[0] = normalize_adjacency 109 | self.A = A 110 | elif strategy == 'distance': 111 | A = np.zeros((len(valid_hop), self.num_node, self.num_node)) 112 | for i, hop in enumerate(valid_hop): 113 | A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == hop] 
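                # Each slice A[i] keeps only the entries of the normalized
                # adjacency whose hop distance equals this hop, i.e. one
                # distance partition per kernel slice.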
114 | self.A = A 115 | elif strategy == 'spatial': 116 | A = [] 117 | for hop in valid_hop: 118 | a_root = np.zeros((self.num_node, self.num_node)) 119 | a_close = np.zeros((self.num_node, self.num_node)) 120 | a_further = np.zeros((self.num_node, self.num_node)) 121 | for i in range(self.num_node): 122 | for j in range(self.num_node): 123 | if self.hop_dis[j, i] == hop: 124 | if self.hop_dis[j, self.center] == self.hop_dis[ 125 | i, self.center]: 126 | a_root[j, i] = normalize_adjacency[j, i] 127 | elif self.hop_dis[j, self. 128 | center] > self.hop_dis[i, self. 129 | center]: 130 | a_close[j, i] = normalize_adjacency[j, i] 131 | else: 132 | a_further[j, i] = normalize_adjacency[j, i] 133 | if hop == 0: 134 | A.append(a_root) 135 | else: 136 | A.append(a_root + a_close) 137 | A.append(a_further) 138 | A = np.stack(A) 139 | self.A = A 140 | else: 141 | raise NotImplementedError("This Strategy is not supported") 142 | 143 | 144 | def get_hop_distance(num_node, edge, max_hop=1): 145 | A = np.zeros((num_node, num_node)) 146 | for i, j in edge: 147 | A[j, i] = 1 148 | A[i, j] = 1 149 | 150 | # compute hop steps 151 | hop_dis = np.zeros((num_node, num_node)) + np.inf 152 | transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)] 153 | arrive_mat = (np.stack(transfer_mat) > 0) 154 | for d in range(max_hop, -1, -1): 155 | hop_dis[arrive_mat[d]] = d 156 | return hop_dis 157 | 158 | 159 | def normalize_digraph(A): 160 | Dl = np.sum(A, 0) 161 | num_node = A.shape[0] 162 | Dn = np.zeros((num_node, num_node)) 163 | for i in range(num_node): 164 | if Dl[i] > 0: 165 | Dn[i, i] = Dl[i]**(-1) 166 | AD = np.dot(A, Dn) 167 | return AD 168 | 169 | 170 | def normalize_undigraph(A): 171 | Dl = np.sum(A, 0) 172 | num_node = A.shape[0] 173 | Dn = np.zeros((num_node, num_node)) 174 | for i in range(num_node): 175 | if Dl[i] > 0: 176 | Dn[i, i] = Dl[i]**(-0.5) 177 | DAD = np.dot(np.dot(Dn, A), Dn) 178 | return DAD 179 | -------------------------------------------------------------------------------- /eval/a2m/recognition/models/stgcnutils/tgcn.py: -------------------------------------------------------------------------------- 1 | # The based unit of graph convolutional networks. 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ConvTemporalGraphical(nn.Module): 8 | 9 | r"""The basic module for applying a graph convolution. 10 | Args: 11 | in_channels (int): Number of channels in the input sequence data 12 | out_channels (int): Number of channels produced by the convolution 13 | kernel_size (int): Size of the graph convolving kernel 14 | t_kernel_size (int): Size of the temporal convolving kernel 15 | t_stride (int, optional): Stride of the temporal convolution. Default: 1 16 | t_padding (int, optional): Temporal zero-padding added to both sides of 17 | the input. Default: 0 18 | t_dilation (int, optional): Spacing between temporal kernel elements. 19 | Default: 1 20 | bias (bool, optional): If ``True``, adds a learnable bias to the output. 
21 | Default: ``True`` 22 | Shape: 23 | - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format 24 | - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format 25 | - Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format 26 | - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format 27 | where 28 | :math:`N` is a batch size, 29 | :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`, 30 | :math:`T_{in}/T_{out}` is a length of input/output sequence, 31 | :math:`V` is the number of graph nodes. 32 | """ 33 | 34 | def __init__(self, 35 | in_channels, 36 | out_channels, 37 | kernel_size, 38 | t_kernel_size=1, 39 | t_stride=1, 40 | t_padding=0, 41 | t_dilation=1, 42 | bias=True): 43 | super().__init__() 44 | 45 | self.kernel_size = kernel_size 46 | self.conv = nn.Conv2d( 47 | in_channels, 48 | out_channels * kernel_size, 49 | kernel_size=(t_kernel_size, 1), 50 | padding=(t_padding, 0), 51 | stride=(t_stride, 1), 52 | dilation=(t_dilation, 1), 53 | bias=bias) 54 | 55 | def forward(self, x, A): 56 | assert A.size(0) == self.kernel_size 57 | 58 | x = self.conv(x) 59 | 60 | n, kc, t, v = x.size() 61 | x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v) 62 | x = torch.einsum('nkctv,kvw->nctw', (x, A)) 63 | 64 | return x.contiguous(), A 65 | -------------------------------------------------------------------------------- /eval/a2m/stgcn/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def calculate_accuracy(model, motion_loader, num_labels, classifier, device): 5 | confusion = torch.zeros(num_labels, num_labels, dtype=torch.long) 6 | with torch.no_grad(): 7 | for batch in motion_loader: 8 | batch_prob = classifier(batch)["yhat"] 9 | batch_pred = batch_prob.max(dim=1).indices 10 | for label, pred in zip(batch["y"], batch_pred): 11 | confusion[label][pred] += 1 12 | 13 | accuracy = torch.trace(confusion)/torch.sum(confusion) 14 | return accuracy.item(), confusion 15 | -------------------------------------------------------------------------------- /eval/a2m/stgcn/diversity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | # from action2motion 6 | def calculate_diversity_multimodality(activations, labels, num_labels, seed=None, unconstrained = False): 7 | diversity_times = 200 8 | multimodality_times = 20 9 | if not unconstrained: 10 | labels = labels.long() 11 | num_motions = activations.shape[0] # len(labels) 12 | 13 | diversity = 0 14 | 15 | if seed is not None: 16 | np.random.seed(seed) 17 | 18 | first_indices = np.random.randint(0, num_motions, diversity_times) 19 | second_indices = np.random.randint(0, num_motions, diversity_times) 20 | for first_idx, second_idx in zip(first_indices, second_indices): 21 | diversity += torch.dist(activations[first_idx, :], 22 | activations[second_idx, :]) 23 | diversity /= diversity_times 24 | 25 | if not unconstrained: 26 | multimodality = 0 27 | label_quotas = np.zeros(num_labels) 28 | label_quotas[labels.unique()] = multimodality_times # if a label does not appear in batch, its quota remains zero 29 | while np.any(label_quotas > 0): 30 | # print(label_quotas) 31 | first_idx = np.random.randint(0, num_motions) 32 | first_label = labels[first_idx] 33 | if not label_quotas[first_label]: 34 | continue 35 | 36 | second_idx = np.random.randint(0, num_motions) 37 | second_label = labels[second_idx] 38 | while 
first_label != second_label: 39 | second_idx = np.random.randint(0, num_motions) 40 | second_label = labels[second_idx] 41 | 42 | label_quotas[first_label] -= 1 43 | 44 | first_activation = activations[first_idx, :] 45 | second_activation = activations[second_idx, :] 46 | multimodality += torch.dist(first_activation, 47 | second_activation) 48 | 49 | multimodality /= (multimodality_times * num_labels) 50 | else: 51 | multimodality = torch.tensor(np.nan) 52 | 53 | return diversity.item(), multimodality.item() 54 | 55 | -------------------------------------------------------------------------------- /eval/a2m/stgcn/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .accuracy import calculate_accuracy 4 | from .fid import calculate_fid 5 | from .diversity import calculate_diversity_multimodality 6 | 7 | from eval.a2m.recognition.models.stgcn import STGCN 8 | 9 | 10 | class Evaluation: 11 | def __init__(self, dataname, parameters, device, seed=None): 12 | layout = "smpl" # if parameters["glob"] else "smpl_noglobal" 13 | model = STGCN(in_channels=parameters["nfeats"], 14 | num_class=parameters["num_classes"], 15 | graph_args={"layout": layout, "strategy": "spatial"}, 16 | edge_importance_weighting=True, 17 | device=device) 18 | 19 | model = model.to(device) 20 | 21 | model_path = "./assets/actionrecognition/uestc_rot6d_stgcn.tar" 22 | 23 | state_dict = torch.load(model_path, map_location=device) 24 | model.load_state_dict(state_dict) 25 | model.eval() 26 | 27 | self.num_classes = parameters["num_classes"] 28 | self.model = model 29 | 30 | self.dataname = dataname 31 | self.device = device 32 | 33 | self.seed = seed 34 | 35 | def compute_features(self, model, motionloader): 36 | # calculate_activations_labels function from action2motion 37 | activations = [] 38 | labels = [] 39 | with torch.no_grad(): 40 | for idx, batch in enumerate(motionloader): 41 | activations.append(self.model(batch)["features"]) 42 | if model.cond_mode != 'no_cond': 43 | labels.append(batch["y"]) 44 | activations = torch.cat(activations, dim=0) 45 | if model.cond_mode != 'no_cond': 46 | labels = torch.cat(labels, dim=0) 47 | return activations, labels 48 | 49 | @staticmethod 50 | def calculate_activation_statistics(activations): 51 | activations = activations.cpu().numpy() 52 | mu = np.mean(activations, axis=0) 53 | sigma = np.cov(activations, rowvar=False) 54 | return mu, sigma 55 | 56 | def evaluate(self, model, loaders): 57 | def print_logs(metric, key): 58 | print(f"Computing stgcn {metric} on the {key} loader ...") 59 | 60 | metrics_all = {} 61 | for sets in ["train", "test"]: 62 | computedfeats = {} 63 | metrics = {} 64 | for key, loaderSets in loaders.items(): 65 | loader = loaderSets[sets] 66 | 67 | metric = "accuracy" 68 | mkey = f"{metric}_{key}" 69 | if model.cond_mode != 'no_cond': 70 | print_logs(metric, key) 71 | metrics[mkey], _ = calculate_accuracy(model, loader, 72 | self.num_classes, 73 | self.model, self.device) 74 | else: 75 | metrics[mkey] = np.nan 76 | 77 | # features for diversity 78 | print_logs("features", key) 79 | feats, labels = self.compute_features(model, loader) 80 | print_logs("stats", key) 81 | stats = self.calculate_activation_statistics(feats) 82 | 83 | computedfeats[key] = {"feats": feats, 84 | "labels": labels, 85 | "stats": stats} 86 | 87 | print_logs("diversity", key) 88 | ret = calculate_diversity_multimodality(feats, labels, self.num_classes, 89 | seed=self.seed, 
unconstrained=(model.cond_mode=='no_cond')) 90 | metrics[f"diversity_{key}"], metrics[f"multimodality_{key}"] = ret 91 | 92 | # take the ground-truth stats as the reference for the fid computation below 93 | gtstats = computedfeats["gt"]["stats"] 94 | # computing fid 95 | for key, loader in computedfeats.items(): 96 | metric = "fid" 97 | mkey = f"{metric}_{key}" 98 | 99 | stats = computedfeats[key]["stats"] 100 | metrics[mkey] = float(calculate_fid(gtstats, stats)) 101 | 102 | metrics_all[sets] = metrics 103 | 104 | metrics = {} 105 | for sets in ["train", "test"]: 106 | for key in metrics_all[sets]: 107 | metrics[f"{key}_{sets}"] = metrics_all[sets][key] 108 | return metrics 109 | -------------------------------------------------------------------------------- /eval/a2m/stgcn/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | 4 | 5 | # from action2motion 6 | def calculate_fid(statistics_1, statistics_2): 7 | return calculate_frechet_distance(statistics_1[0], statistics_1[1], 8 | statistics_2[0], statistics_2[1]) 9 | 10 | 11 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 12 | """Numpy implementation of the Frechet Distance. 13 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 14 | and X_2 ~ N(mu_2, C_2) is 15 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 16 | Stable version by Dougal J. Sutherland. 17 | Params: 18 | -- mu1 : Numpy array containing the activations of a layer of the 19 | inception net (like returned by the function 'get_predictions') 20 | for generated samples. 21 | -- mu2 : The sample mean over activations, precalculated on a 22 | representative data set. 23 | -- sigma1: The covariance matrix over activations for generated samples. 24 | -- sigma2: The covariance matrix over activations, precalculated on a 25 | representative data set. 26 | Returns: 27 | -- : The Frechet Distance.
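Illustrative sanity check (a rough worked example, not part of the original
interface): with mu1 = np.zeros(3), mu2 = np.ones(3) and sigma1 = sigma2 =
np.eye(3), the covariance terms cancel and the distance reduces to
||mu1 - mu2||^2 = 3.0 (up to small numerical noise from linalg.sqrtm).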
28 | """ 29 | 30 | mu1 = np.atleast_1d(mu1) 31 | mu2 = np.atleast_1d(mu2) 32 | 33 | sigma1 = np.atleast_2d(sigma1) 34 | sigma2 = np.atleast_2d(sigma2) 35 | 36 | assert mu1.shape == mu2.shape, \ 37 | 'Training and test mean vectors have different lengths' 38 | assert sigma1.shape == sigma2.shape, \ 39 | 'Training and test covariances have different dimensions' 40 | 41 | diff = mu1 - mu2 42 | 43 | # Product might be almost singular 44 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 45 | if not np.isfinite(covmean).all(): 46 | msg = ('fid calculation produces singular product; ' 47 | 'adding %s to diagonal of cov estimates') % eps 48 | print(msg) 49 | offset = np.eye(sigma1.shape[0]) * eps 50 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 51 | 52 | # Numerical error might give slight imaginary component 53 | if np.iscomplexobj(covmean): 54 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 55 | m = np.max(np.abs(covmean.imag)) 56 | raise ValueError('Imaginary component {}'.format(m)) 57 | covmean = covmean.real 58 | 59 | tr_covmean = np.trace(covmean) 60 | 61 | return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) 62 | -------------------------------------------------------------------------------- /eval/a2m/stgcn_eval.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from tqdm import tqdm 4 | import functools 5 | 6 | from utils.fixseed import fixseed 7 | 8 | from eval.a2m.stgcn.evaluate import Evaluation as STGCNEvaluation 9 | from torch.utils.data import DataLoader 10 | from data_loaders.tensors import collate 11 | 12 | 13 | from .tools import format_metrics 14 | import utils.rotation_conversions as geometry 15 | from utils import dist_util 16 | 17 | 18 | def convert_x_to_rot6d(x, pose_rep): 19 | # convert rotation to rot6d 20 | if pose_rep == "rotvec": 21 | x = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(x)) 22 | elif pose_rep == "rotmat": 23 | x = x.reshape(*x.shape[:-1], 3, 3) 24 | x = geometry.matrix_to_rotation_6d(x) 25 | elif pose_rep == "rotquat": 26 | x = geometry.matrix_to_rotation_6d(geometry.quaternion_to_matrix(x)) 27 | elif pose_rep == "rot6d": 28 | x = x 29 | else: 30 | raise NotImplementedError("No geometry for this one.") 31 | return x 32 | 33 | 34 | class NewDataloader: 35 | def __init__(self, mode, model, diffusion, dataiterator, device, cond_mode, dataset, num_samples): 36 | assert mode in ["gen", "gt"] 37 | 38 | self.batches = [] 39 | sample_fn = diffusion.p_sample_loop 40 | 41 | with torch.no_grad(): 42 | for motions, model_kwargs in tqdm(dataiterator, desc=f"Construct dataloader: {mode}.."): 43 | motions = motions.to(device) 44 | if num_samples != -1 and len(self.batches) * dataiterator.batch_size > num_samples: 45 | continue # do not break because it confuses the multiple loaders 46 | batch = dict() 47 | if mode == "gen": 48 | sample = sample_fn(model, motions.shape, clip_denoised=False, model_kwargs=model_kwargs) 49 | batch['output'] = sample 50 | elif mode == "gt": 51 | batch['output'] = motions 52 | 53 | max_n_frames = model_kwargs['y']['lengths'].max() 54 | mask = model_kwargs['y']['mask'].reshape(dataiterator.batch_size, max_n_frames).bool() 55 | batch["output_xyz"] = model.rot2xyz(x=batch["output"], mask=mask, pose_rep='rot6d', glob=True, 56 | translation=True, jointstype='smpl', vertstrans=True, betas=None, 57 | beta=0, glob_rot=None, get_rotations_back=False) 58 | if model.translation: 59 | # the stgcn 
model expects rotations only 60 | batch["output"] = batch["output"][:, :-1] 61 | 62 | batch["lengths"] = model_kwargs['y']['lengths'].to(device) 63 | # using torch.long so lengths/action will be used as indices 64 | if cond_mode != 'no_cond': # proceed only if not running unconstrained 65 | batch["y"] = model_kwargs['y']['action'].squeeze().long().cpu() # using torch.long so lengths/action will be used as indices 66 | self.batches.append(batch) 67 | 68 | num_samples_last_batch = num_samples % dataiterator.batch_size 69 | if num_samples_last_batch > 0: 70 | for k, v in self.batches[-1].items(): 71 | self.batches[-1][k] = v[:num_samples_last_batch] 72 | 73 | 74 | def __iter__(self): 75 | return iter(self.batches) 76 | 77 | 78 | def evaluate(args, model, diffusion, data): 79 | torch.multiprocessing.set_sharing_strategy('file_system') 80 | 81 | bs = args.batch_size 82 | args.num_classes = 40 83 | args.nfeats = 6 84 | args.njoint = 25 85 | device = dist_util.dev() 86 | 87 | 88 | recogparameters = args.__dict__.copy() 89 | recogparameters["pose_rep"] = "rot6d" 90 | recogparameters["nfeats"] = 6 91 | 92 | # Action2motionEvaluation 93 | stgcnevaluation = STGCNEvaluation(args.dataset, recogparameters, device) 94 | 95 | stgcn_metrics = {} 96 | 97 | data_types = ['train', 'test'] 98 | datasetGT = {'train': [data], 'test': [copy.deepcopy(data)]} 99 | 100 | for key in data_types: 101 | datasetGT[key][0].split = key 102 | 103 | compute_gt_gt = False 104 | if compute_gt_gt: 105 | for key in data_types: 106 | datasetGT[key].append(copy.deepcopy(datasetGT[key][0])) 107 | 108 | model.eval() 109 | 110 | allseeds = list(range(args.num_seeds)) 111 | 112 | for index, seed in enumerate(allseeds): 113 | print(f"Evaluation number: {index + 1}/{args.num_seeds}") 114 | fixseed(seed) 115 | for key in data_types: 116 | for data in datasetGT[key]: 117 | data.reset_shuffle() 118 | data.shuffle() 119 | 120 | dataiterator = {key: [DataLoader(data, batch_size=bs, shuffle=False, num_workers=8, collate_fn=collate) 121 | for data in datasetGT[key]] 122 | for key in data_types} 123 | 124 | new_data_loader = functools.partial(NewDataloader, model=model, diffusion=diffusion, device=device, 125 | cond_mode=args.cond_mode, dataset=args.dataset, num_samples=args.num_samples) 126 | gtLoaders = {key: new_data_loader(mode="gt", dataiterator=dataiterator[key][0]) 127 | for key in ["train", "test"]} 128 | 129 | if compute_gt_gt: 130 | gtLoaders2 = {key: new_data_loader(mode="gt", dataiterator=dataiterator[key][0]) 131 | for key in ["train", "test"]} 132 | 133 | genLoaders = {key: new_data_loader(mode="gen", dataiterator=dataiterator[key][0]) 134 | for key in ["train", "test"]} 135 | 136 | loaders = {"gen": genLoaders, 137 | "gt": gtLoaders} 138 | 139 | if compute_gt_gt: 140 | loaders["gt2"] = gtLoaders2 141 | 142 | stgcn_metrics[seed] = stgcnevaluation.evaluate(model, loaders) 143 | del loaders 144 | 145 | metrics = {"feats": {key: [format_metrics(stgcn_metrics[seed])[key] for seed in allseeds] for key in stgcn_metrics[allseeds[0]]}} 146 | 147 | return metrics 148 | -------------------------------------------------------------------------------- /eval/a2m/tools.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | 4 | def format_metrics(metrics, formatter="{:.6}"): 5 | newmetrics = {} 6 | for key, val in metrics.items(): 7 | newmetrics[key] = formatter.format(val) 8 | return newmetrics 9 | 10 | 11 | def save_metrics(path, metrics): 12 | with open(path, "w") as yfile: 13 | 
yaml.dump(metrics, yfile) 14 | 15 | 16 | def load_metrics(path): 17 | with open(path, "r") as yfile: 18 | string = yfile.read() 19 | return yaml.load(string, yaml.loader.BaseLoader) 20 | -------------------------------------------------------------------------------- /eval/eval_humanact12_uestc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 4 | """ 5 | import os 6 | import torch 7 | import re 8 | 9 | from utils import dist_util 10 | from model.cfg_sampler import ClassifierFreeSampleModel 11 | from data_loaders.get_data import get_dataset_loader 12 | from eval.a2m.tools import save_metrics 13 | from utils.parser_util import evaluation_parser 14 | from utils.fixseed import fixseed 15 | from utils.model_util import create_model_and_diffusion, load_model_wo_clip 16 | 17 | 18 | def evaluate(args, model, diffusion, data): 19 | scale = None 20 | if args.guidance_param != 1: 21 | model = ClassifierFreeSampleModel(model) # wrapping model with the classifier-free sampler 22 | scale = { 23 | 'action': torch.ones(args.batch_size) * args.guidance_param, 24 | } 25 | model.to(dist_util.dev()) 26 | model.eval() # disable random masking 27 | 28 | 29 | folder, ckpt_name = os.path.split(args.model_path) 30 | if args.dataset == "humanact12": 31 | from eval.a2m.gru_eval import evaluate 32 | eval_results = evaluate(args, model, diffusion, data) 33 | elif args.dataset == "uestc": 34 | from eval.a2m.stgcn_eval import evaluate 35 | eval_results = evaluate(args, model, diffusion, data) 36 | else: 37 | raise NotImplementedError("This dataset is not supported.") 38 | 39 | # save results 40 | iter = int(re.findall('\d+', ckpt_name)[0]) 41 | scale = 1 if scale is None else scale['action'][0].item() 42 | scale = str(scale).replace('.', 'p') 43 | metricname = "evaluation_results_iter{}_samp{}_scale{}_a2m.yaml".format(iter, args.num_samples, scale) 44 | evalpath = os.path.join(folder, metricname) 45 | print(f"Saving evaluation: {evalpath}") 46 | save_metrics(evalpath, eval_results) 47 | 48 | return eval_results 49 | 50 | 51 | def main(): 52 | args = evaluation_parser() 53 | fixseed(args.seed) 54 | dist_util.setup_dist(args.device) 55 | 56 | print(f'Eval mode [{args.eval_mode}]') 57 | assert args.eval_mode in ['debug', 'full'], f'eval_mode {args.eval_mode} is not supported for dataset {args.dataset}' 58 | if args.eval_mode == 'debug': 59 | args.num_samples = 10 60 | args.num_seeds = 2 61 | else: 62 | args.num_samples = 1000 63 | args.num_seeds = 20 64 | 65 | data_loader = get_dataset_loader(name=args.dataset, num_frames=60, batch_size=args.batch_size,) 66 | 67 | print("creating model and diffusion...") 68 | model, diffusion = create_model_and_diffusion(args, data_loader) 69 | 70 | print(f"Loading checkpoints from [{args.model_path}]...") 71 | state_dict = torch.load(args.model_path, map_location='cpu') 72 | load_model_wo_clip(model, state_dict) 73 | 74 | eval_results = evaluate(args, model, diffusion, data_loader.dataset) 75 | 76 | fid_to_print = {k : sum([float(vv) for vv in v])/len(v) for k, v in eval_results['feats'].items() if 'fid' in k and 'gen' in k} 77 | print(fid_to_print) 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /eval/unconstrained/evaluate.py: -------------------------------------------------------------------------------- 
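A minimal, hypothetical usage sketch for the evaluation entry point defined in the file below. The (num_motions, 15 joints, xyz, num_frames) layout mirrors how eval/a2m/gru_eval.py calls this module; the random stand-in data, the sample count and the frame count are illustrative assumptions, and the checkpoint and dataset files referenced in the comments must already exist on disk.

    import numpy as np
    import torch
    from eval.unconstrained.evaluate import evaluate_unconstrained_metrics

    # Joint positions with the root joint at index 8; random values stand in
    # for real generated motions here (assumed shapes, for illustration only).
    generated_motions = np.random.randn(1024, 15, 3, 60).astype(np.float32)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Requires ./assets/actionrecognition/humanact12_gru_modi_struct.pth.tar
    # and ./dataset/HumanAct12Poses/humanact12_modi_struct.npy (see below).
    metrics = evaluate_unconstrained_metrics(generated_motions, device, fast=True)
    print(metrics)  # fid, kid, diversity_gen, diversity_gt, precision, recall

With fast=True the (slow) precision/recall computation is skipped and those two entries come back as None, which matches the behaviour of the function below.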
1 | from eval.unconstrained.models.stgcn import STGCN 2 | import pandas as pd 3 | import os.path as osp 4 | import os 5 | import datetime 6 | 7 | import torch 8 | 9 | from torch.utils.data import DataLoader 10 | import numpy as np 11 | import sys as _sys 12 | from eval.a2m.action2motion.fid import calculate_fid 13 | from eval.a2m.action2motion.diversity import calculate_diversity 14 | from eval.unconstrained.metrics.kid import calculate_kid 15 | from eval.unconstrained.metrics.precision_recall import precision_and_recall 16 | from matplotlib import pyplot as plt 17 | 18 | TEST = False 19 | 20 | 21 | def initialize_model(device, modelpath): 22 | num_classes = 12 23 | model = STGCN(in_channels=3, 24 | num_class=num_classes, 25 | graph_args={"layout": 'openpose', "strategy": "spatial"}, 26 | edge_importance_weighting=True, 27 | device=device) 28 | model = model.to(device) 29 | state_dict = torch.load(modelpath, map_location=device) 30 | model.load_state_dict(state_dict) 31 | model.eval() 32 | return model 33 | 34 | def calculate_activation_statistics(activations): 35 | activations = activations.cpu().detach().numpy() 36 | mu = np.mean(activations, axis=0) 37 | sigma = np.cov(activations, rowvar=False) 38 | return mu, sigma 39 | 40 | 41 | def compute_features(model, iterator, device): 42 | activations = [] 43 | predictions = [] 44 | with torch.no_grad(): 45 | for i, batch in enumerate(iterator): 46 | batch_for_model = {} 47 | batch_for_model['x'] = batch.to(device).float() 48 | model(batch_for_model) 49 | activations.append(batch_for_model['features']) 50 | predictions.append(batch_for_model['yhat']) 51 | # labels.append(batch_for_model['y']) 52 | activations = torch.cat(activations, dim=0) 53 | predictions = torch.cat(predictions, dim=0) 54 | return activations, predictions 55 | 56 | 57 | def evaluate_unconstrained_metrics(generated_motions, device, fast): 58 | 59 | act_rec_model_path = './assets/actionrecognition/humanact12_gru_modi_struct.pth.tar' 60 | dataset_path = './dataset/HumanAct12Poses/humanact12_modi_struct.npy' 61 | 62 | # initialize model 63 | act_rec_model = initialize_model(device, act_rec_model_path) 64 | 65 | generated_motions -= generated_motions[:, 8:9, :, :] # locate root joint of all frames at origin 66 | 67 | iterator_generated = DataLoader(generated_motions, batch_size=64, shuffle=False, num_workers=8) 68 | 69 | # compute features of generated motions 70 | generated_features, generated_predictions = compute_features(act_rec_model, iterator_generated, device=device) 71 | generated_stats = calculate_activation_statistics(generated_features) 72 | 73 | 74 | # dataset motions 75 | motion_data_raw = np.load(dataset_path, allow_pickle=True) 76 | motion_data = motion_data_raw[:, :15] # data has 16 joints for back compitability with older formats 77 | motion_data -= motion_data[:, 8:9, :, :] # locate root joint of all frames at origin 78 | iterator_dataset = DataLoader(motion_data, batch_size=64, shuffle=False, num_workers=8) 79 | 80 | # compute features of dataset motions 81 | dataset_features, dataset_predictions = compute_features(act_rec_model, iterator_dataset, device=device) 82 | real_stats = calculate_activation_statistics(dataset_features) 83 | 84 | print("evaluation resutls:\n") 85 | 86 | fid = calculate_fid(generated_stats, real_stats) 87 | print(f"FID score: {fid}\n") 88 | 89 | print("calculating KID...") 90 | kid = calculate_kid(dataset_features.cpu(), generated_features.cpu()) 91 | (m, s) = kid 92 | print('KID : %.3f (%.3f)\n' % (m, s)) 93 | 94 | dataset_diversity = 
calculate_diversity(dataset_features) 95 | generated_diversity = calculate_diversity(generated_features) 96 | print(f"Diversity of generated motions: {generated_diversity}") 97 | print(f"Diversity of dataset motions: {dataset_diversity}\n") 98 | 99 | if fast: 100 | print("Skipping precision-recall calculation\n") 101 | precision = recall = None 102 | else: 103 | print("calculating precision recall...") 104 | precision, recall = precision_and_recall(generated_features, dataset_features) 105 | print(f"precision: {precision}") 106 | print(f"recall: {recall}\n") 107 | 108 | metrics = {'fid': fid, 'kid': kid[0], 'diversity_gen': generated_diversity.cpu().item(), 'diversity_gt': dataset_diversity.cpu().item(), 109 | 'precision': precision, 'recall':recall} 110 | return metrics 111 | 112 | -------------------------------------------------------------------------------- /eval/unconstrained/metrics/kid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | from sklearn.metrics.pairwise import polynomial_kernel 5 | import sys 6 | 7 | # from: https://github.com/abdulfatir/gan-metrics-pytorch/blob/master/kid_score.py 8 | def polynomial_mmd_averages(codes_g, codes_r, n_subsets=50, subset_size=1000, 9 | ret_var=True, output=sys.stdout, **kernel_args): 10 | m = min(codes_g.shape[0], codes_r.shape[0]) 11 | mmds = np.zeros(n_subsets) 12 | if ret_var: 13 | vars = np.zeros(n_subsets) 14 | choice = np.random.choice 15 | 16 | replace = subset_size < len(codes_g) 17 | with tqdm(range(n_subsets), desc='MMD', file=output, disable=True) as bar: 18 | for i in bar: 19 | g = codes_g[choice(len(codes_g), subset_size, replace=replace)] 20 | r = codes_r[choice(len(codes_r), subset_size, replace=replace)] 21 | o = polynomial_mmd(g, r, **kernel_args, var_at_m=m, ret_var=ret_var) 22 | if ret_var: 23 | mmds[i], vars[i] = o 24 | else: 25 | mmds[i] = o 26 | bar.set_postfix({'mean': mmds[:i+1].mean()}) 27 | return (mmds, vars) if ret_var else mmds 28 | 29 | def polynomial_mmd(codes_g, codes_r, degree=3, gamma=None, coef0=1, 30 | var_at_m=None, ret_var=True): 31 | # use k(x, y) = (gamma + coef0)^degree 32 | # default gamma is 1 / dim 33 | X = codes_g 34 | Y = codes_r 35 | 36 | K_XX = polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0) 37 | K_YY = polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0) 38 | K_XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0) 39 | 40 | return _mmd2_and_variance(K_XX, K_XY, K_YY, 41 | var_at_m=var_at_m, ret_var=ret_var) 42 | 43 | def _mmd2_and_variance(K_XX, K_XY, K_YY, unit_diagonal=False, 44 | mmd_est='unbiased', block_size=1024, 45 | var_at_m=None, ret_var=True): 46 | # based on 47 | # https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py 48 | # but changed to not compute the full kernel matrix at once 49 | m = K_XX.shape[0] 50 | assert K_XX.shape == (m, m) 51 | assert K_XY.shape == (m, m) 52 | assert K_YY.shape == (m, m) 53 | if var_at_m is None: 54 | var_at_m = m 55 | 56 | # Get the various sums of kernels that we'll use 57 | # Kts drop the diagonal, but we don't need to compute them explicitly 58 | if unit_diagonal: 59 | diag_X = diag_Y = 1 60 | sum_diag_X = sum_diag_Y = m 61 | sum_diag2_X = sum_diag2_Y = m 62 | else: 63 | diag_X = np.diagonal(K_XX) 64 | diag_Y = np.diagonal(K_YY) 65 | 66 | sum_diag_X = diag_X.sum() 67 | sum_diag_Y = diag_Y.sum() 68 | 69 | sum_diag2_X = _sqn(diag_X) 70 | sum_diag2_Y = _sqn(diag_Y) 71 | 72 | Kt_XX_sums = 
K_XX.sum(axis=1) - diag_X 73 | Kt_YY_sums = K_YY.sum(axis=1) - diag_Y 74 | K_XY_sums_0 = K_XY.sum(axis=0) 75 | K_XY_sums_1 = K_XY.sum(axis=1) 76 | 77 | Kt_XX_sum = Kt_XX_sums.sum() 78 | Kt_YY_sum = Kt_YY_sums.sum() 79 | K_XY_sum = K_XY_sums_0.sum() 80 | 81 | if mmd_est == 'biased': 82 | mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) 83 | + (Kt_YY_sum + sum_diag_Y) / (m * m) 84 | - 2 * K_XY_sum / (m * m)) 85 | else: 86 | assert mmd_est in {'unbiased', 'u-statistic'} 87 | mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m-1)) 88 | if mmd_est == 'unbiased': 89 | mmd2 -= 2 * K_XY_sum / (m * m) 90 | else: 91 | mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m-1)) 92 | 93 | if not ret_var: 94 | return mmd2 95 | 96 | Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X 97 | Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y 98 | K_XY_2_sum = _sqn(K_XY) 99 | 100 | dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1) 101 | dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0) 102 | 103 | m1 = m - 1 104 | m2 = m - 2 105 | zeta1_est = ( 106 | 1 / (m * m1 * m2) * ( 107 | _sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum) 108 | - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2) 109 | + 1 / (m * m * m1) * ( 110 | _sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum) 111 | - 2 / m**4 * K_XY_sum**2 112 | - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX) 113 | + 2 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum 114 | ) 115 | zeta2_est = ( 116 | 1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum) 117 | - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2) 118 | + 2 / (m * m) * K_XY_2_sum 119 | - 2 / m**4 * K_XY_sum**2 120 | - 4 / (m * m * m1) * (dot_XX_XY + dot_YY_YX) 121 | + 4 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum 122 | ) 123 | var_est = (4 * (var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est 124 | + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est) 125 | 126 | return mmd2, var_est 127 | 128 | 129 | def _sqn(arr): 130 | flat = np.ravel(arr) 131 | return flat.dot(flat) 132 | 133 | def calculate_kid(real_activations, generated_activations): 134 | kid_values = polynomial_mmd_averages(real_activations, generated_activations, n_subsets=100) 135 | results = (kid_values[0].mean(), kid_values[0].std()) 136 | return results 137 | -------------------------------------------------------------------------------- /eval/unconstrained/metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/blandocs/improved-precision-and-recall-metric-pytorch/blob/master/functions.py 2 | import os, torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from tqdm import tqdm 7 | 8 | # self.batch_size = args.batch_size 9 | # self.cpu = args.cpu 10 | # self.data_size = args.data_size 11 | 12 | def precision_and_recall(generated_features, real_features): 13 | k = 3 14 | 15 | data_num = min(len(generated_features), len(real_features)) 16 | print(f'data num: {data_num}') 17 | 18 | if data_num <= 0: 19 | print("there is no data") 20 | return 21 | generated_features = generated_features[:data_num] 22 | real_features = real_features[:data_num] 23 | 24 | # get precision and recall 25 | precision = manifold_estimate(real_features, generated_features, k) 26 | recall = manifold_estimate(generated_features, real_features, k) 27 | 28 | return precision, recall 29 | 30 | def manifold_estimate( A_features, B_features, k): 31 | A_features = list(A_features) 32 | B_features = list(B_features) 33 | KNN_list_in_A = {} 34 | for A in tqdm(A_features, ncols=80): 35 | pairwise_distances = 
np.zeros(shape=(len(A_features))) 36 | 37 | for i, A_prime in enumerate(A_features): 38 | d = torch.norm((A - A_prime), 2) 39 | pairwise_distances[i] = d 40 | 41 | v = np.partition(pairwise_distances, k)[k] 42 | KNN_list_in_A[A] = v 43 | 44 | n = 0 45 | 46 | for B in tqdm(B_features, ncols=80): 47 | for A_prime in A_features: 48 | d = torch.norm((B - A_prime), 2) 49 | if d <= KNN_list_in_A[A_prime]: 50 | n += 1 51 | break 52 | 53 | return n / len(B_features) 54 | 55 | 56 | -------------------------------------------------------------------------------- /model/cfg_sampler.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from copy import deepcopy 3 | 4 | 5 | class ClassifierFreeSampleModel(nn.Module): 6 | 7 | def __init__(self, model): 8 | super().__init__() 9 | self.model = model # model is the actual model to run 10 | 11 | assert self.model.cond_mask_prob > 0, 'Cannot run a guided diffusion on a model that has not been trained with no conditions' 12 | 13 | # pointers to inner model 14 | self.rot2xyz = self.model.rot2xyz 15 | self.translation = self.model.translation 16 | self.njoints = self.model.njoints 17 | self.nfeats = self.model.nfeats 18 | self.data_rep = self.model.data_rep 19 | self.cond_mode = self.model.cond_mode 20 | self.keyframe_conditioned = self.model.keyframe_conditioned 21 | 22 | # self.mask_value = 0.0 23 | self.mask_value = -2.0 24 | 25 | def forward(self, x, timesteps, y=None, obs_x0=None, obs_mask=None, **kwargs): 26 | cond_mode = self.model.cond_mode 27 | assert cond_mode in ['text', 'action'] 28 | y_uncond = deepcopy(y) 29 | y_uncond['uncond'] = True 30 | # If there is condition, the uncond_model will take in the spatial condition as well 31 | out = self.model(x, timesteps, y, obs_x0, obs_mask, **kwargs) 32 | out_uncond = self.model(x, timesteps, y_uncond, obs_x0, obs_mask, **kwargs) 33 | # return out 34 | # return out_uncond + (1.5 * (out - out_uncond)) 35 | return out_uncond + (y['text_scale'].view(-1, 1, 1, 1) * (out - out_uncond)) 36 | 37 | def forward_smd_final(self, x, timesteps, y=None, **kwargs): 38 | '''_ori''' 39 | cond_mode = self.model.cond_mode 40 | assert cond_mode in ['text', 'action'] 41 | y_uncond = deepcopy(y) 42 | y_uncond['uncond'] = True 43 | # If there is condition, the uncond_model will take in the spatial condition as well 44 | out = self.model(x, timesteps, y, **kwargs) 45 | out_uncond = self.model(x, timesteps, y_uncond, **kwargs) 46 | # return out 47 | return out_uncond + (1.5 * (out - out_uncond)) 48 | # return out_uncond + (y['scale'].view(-1, 1, 1, 1) * (out - out_uncond)) 49 | 50 | def forward_correct(self, x, timesteps, y=None): 51 | '''_correct''' 52 | kps_to_text_ratio = 0.8 53 | cond_mode = self.model.cond_mode 54 | assert cond_mode in ['text', 'action'] 55 | y_uncond = deepcopy(y) 56 | y_uncond['uncond'] = True 57 | # out = self.model(x, timesteps, y) 58 | # out_spatial = self.model(x, timesteps, y_uncond) 59 | 60 | x_without_spatial = x + 0.0 61 | x_without_spatial[:, 263, :, :] *= 0.0 62 | x_without_spatial[:, 264:, :, :] = self.mask_value 63 | # out_uncond = self.model(x_without_spatial, timesteps, y_uncond) 64 | out_uncond = self.model(x, timesteps, y_uncond) 65 | # return out_uncond + (y['scale'].view(-1, 1, 1, 1) * (out - out_uncond)) 66 | # out = out_uncond + (0.9 * (out_spatial - out_uncond)) 67 | # out_text = self.model(x_without_spatial, timesteps, y) 68 | out_text = self.model(x, timesteps, y) 69 | 70 | out = out_uncond + (1.8 * (out_text - out_uncond)) 71 
| # out = out_uncond + (0.5 * (out_text_scale - out_uncond)) 72 | # out[:, :3, :, :] = out_text[:, :3, :, :] 73 | return out # out_uncond + (0.8 * (out_patial - out_uncond)) 74 | # return out_uncond + (2.5 * (out - out_uncond)) 75 | 76 | 77 | def forward_average(self, x, timesteps, y=None): 78 | '''_average''' 79 | kps_to_text_ratio = 1.5 80 | cond_mode = self.model.cond_mode 81 | assert cond_mode in ['text', 'action'] 82 | y_uncond = deepcopy(y) 83 | y_uncond['uncond'] = True 84 | out_spatial = self.model(x, timesteps, y_uncond) 85 | 86 | x_without_spatial = x + 0.0 87 | x_without_spatial[:, 263, :, :] *= 0.0 88 | x_without_spatial[:, 264:, :, :] = self.mask_value 89 | out_uncond = self.model(x_without_spatial, timesteps, y_uncond) 90 | # out_uncond = self.model(x, timesteps, y_uncond) 91 | # out_with_spatial = out_uncond + (y['scale'].view(-1, 1, 1, 1) * (out - out_uncond)) 92 | 93 | out_text = self.model(x_without_spatial, timesteps, y) 94 | # out_text = self.model(x, timesteps, y) 95 | 96 | combined_out_spatial = out_uncond + (1.0 * (out_spatial - out_uncond)) 97 | combined_out_text = out_uncond + (y['scale'].view(-1, 1, 1, 1) * (out_text - out_uncond)) 98 | 99 | # return 1.0 * combined_out_text + 1.0 * combined_out_spatial - 1.0 * out_uncond 100 | return combined_out_text + (kps_to_text_ratio) * (combined_out_spatial - combined_out_text) 101 | 102 | 103 | # x_without_spatial = x + 0.0 104 | # x_without_spatial[:, 263:, :, :] *= 0.0 105 | # out_no_cond = self.model(x_without_spatial, timesteps, y) 106 | # out_uncond_without_spatial = self.model(x_without_spatial, timesteps, y_uncond) 107 | # out_without_spatial = out_uncond_without_spatial + (y['scale'].view(-1, 1, 1, 1) * (out_no_cond - out_uncond_without_spatial)) 108 | 109 | # # import pdb; pdb.set_trace() 110 | # return out_without_spatial + (kps_to_text_ratio) * (out_with_spatial - out_without_spatial) 111 | -------------------------------------------------------------------------------- /model/rotation2xyz.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/Mathux/ACTOR.git 2 | import torch 3 | import utils.rotation_conversions as geometry 4 | 5 | 6 | from model.smpl import SMPL, JOINTSTYPE_ROOT 7 | # from .get_model import JOINTSTYPES 8 | JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"] 9 | 10 | 11 | class Rotation2xyz: 12 | def __init__(self, device, dataset='amass'): 13 | self.device = device 14 | self.dataset = dataset 15 | self.smpl_model = SMPL().eval().to(device) 16 | 17 | def __call__(self, x, mask, pose_rep, translation, glob, 18 | jointstype, vertstrans, betas=None, beta=0, 19 | glob_rot=None, get_rotations_back=False, **kwargs): 20 | if pose_rep == "xyz": 21 | return x 22 | 23 | if mask is None: 24 | mask = torch.ones((x.shape[0], x.shape[-1]), dtype=bool, device=x.device) 25 | 26 | if not glob and glob_rot is None: 27 | raise TypeError("You must specify global rotation if glob is False") 28 | 29 | if jointstype not in JOINTSTYPES: 30 | raise NotImplementedError("This jointstype is not implemented.") 31 | 32 | if translation: 33 | x_translations = x[:, -1, :3] 34 | x_rotations = x[:, :-1] 35 | else: 36 | x_rotations = x 37 | 38 | x_rotations = x_rotations.permute(0, 3, 1, 2) 39 | nsamples, time, njoints, feats = x_rotations.shape 40 | 41 | # Compute rotations (convert only masked sequences output) 42 | if pose_rep == "rotvec": 43 | rotations = geometry.axis_angle_to_matrix(x_rotations[mask]) 44 | elif pose_rep == "rotmat": 45 | 
rotations = x_rotations[mask].view(-1, njoints, 3, 3) 46 | elif pose_rep == "rotquat": 47 | rotations = geometry.quaternion_to_matrix(x_rotations[mask]) 48 | elif pose_rep == "rot6d": 49 | rotations = geometry.rotation_6d_to_matrix(x_rotations[mask]) 50 | else: 51 | raise NotImplementedError("No geometry for this one.") 52 | 53 | if not glob: 54 | global_orient = torch.tensor(glob_rot, device=x.device) 55 | global_orient = geometry.axis_angle_to_matrix(global_orient).view(1, 1, 3, 3) 56 | global_orient = global_orient.repeat(len(rotations), 1, 1, 1) 57 | else: 58 | global_orient = rotations[:, 0] 59 | rotations = rotations[:, 1:] 60 | 61 | if betas is None: 62 | betas = torch.zeros([rotations.shape[0], self.smpl_model.num_betas], 63 | dtype=rotations.dtype, device=rotations.device) 64 | betas[:, 1] = beta 65 | # import ipdb; ipdb.set_trace() 66 | out = self.smpl_model(body_pose=rotations, global_orient=global_orient, betas=betas) 67 | 68 | # get the desirable joints 69 | joints = out[jointstype] 70 | 71 | x_xyz = torch.empty(nsamples, time, joints.shape[1], 3, device=x.device, dtype=x.dtype) 72 | x_xyz[~mask] = 0 73 | x_xyz[mask] = joints 74 | 75 | x_xyz = x_xyz.permute(0, 2, 3, 1).contiguous() 76 | 77 | # the first translation root at the origin on the prediction 78 | if jointstype != "vertices": 79 | rootindex = JOINTSTYPE_ROOT[jointstype] 80 | x_xyz = x_xyz - x_xyz[:, [rootindex], :, :] 81 | 82 | if translation and vertstrans: 83 | # the first translation root at the origin 84 | x_translations = x_translations - x_translations[:, :, [0]] 85 | 86 | # add the translation to all the joints 87 | x_xyz = x_xyz + x_translations[:, None, :, :] 88 | 89 | if get_rotations_back: 90 | return x_xyz, rotations, global_orient 91 | else: 92 | return x_xyz 93 | -------------------------------------------------------------------------------- /model/smpl.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/Mathux/ACTOR.git 2 | import numpy as np 3 | import torch 4 | 5 | import contextlib 6 | 7 | from smplx import SMPLLayer as _SMPLLayer 8 | from smplx.lbs import vertices2joints 9 | 10 | 11 | # action2motion_joints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 21, 24, 38] 12 | # change 0 and 8 13 | action2motion_joints = [8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38] 14 | 15 | from utils.config import SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA 16 | 17 | JOINTSTYPE_ROOT = {"a2m": 0, # action2motion 18 | "smpl": 0, 19 | "a2mpl": 0, # set(smpl, a2m) 20 | "vibe": 8} # 0 is the 8 position: OP MidHip below 21 | 22 | JOINT_MAP = { 23 | 'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17, 24 | 'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16, 25 | 'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0, 26 | 'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8, 27 | 'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7, 28 | 'OP REye': 25, 'OP LEye': 26, 'OP REar': 27, 29 | 'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30, 30 | 'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34, 31 | 'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45, 32 | 'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7, 33 | 'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17, 34 | 'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20, 35 | 'Neck (LSP)': 47, 'Top of Head (LSP)': 48, 36 | 'Pelvis (MPII)': 49, 'Thorax (MPII)': 50, 37 | 'Spine (H36M)': 51, 'Jaw (H36M)': 52, 38 | 'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26, 39 | 'Right 
Eye': 25, 'Left Ear': 28, 'Right Ear': 27 40 | } 41 | 42 | JOINT_NAMES = [ 43 | 'OP Nose', 'OP Neck', 'OP RShoulder', 44 | 'OP RElbow', 'OP RWrist', 'OP LShoulder', 45 | 'OP LElbow', 'OP LWrist', 'OP MidHip', 46 | 'OP RHip', 'OP RKnee', 'OP RAnkle', 47 | 'OP LHip', 'OP LKnee', 'OP LAnkle', 48 | 'OP REye', 'OP LEye', 'OP REar', 49 | 'OP LEar', 'OP LBigToe', 'OP LSmallToe', 50 | 'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel', 51 | 'Right Ankle', 'Right Knee', 'Right Hip', 52 | 'Left Hip', 'Left Knee', 'Left Ankle', 53 | 'Right Wrist', 'Right Elbow', 'Right Shoulder', 54 | 'Left Shoulder', 'Left Elbow', 'Left Wrist', 55 | 'Neck (LSP)', 'Top of Head (LSP)', 56 | 'Pelvis (MPII)', 'Thorax (MPII)', 57 | 'Spine (H36M)', 'Jaw (H36M)', 58 | 'Head (H36M)', 'Nose', 'Left Eye', 59 | 'Right Eye', 'Left Ear', 'Right Ear' 60 | ] 61 | 62 | 63 | # adapted from VIBE/SPIN to output smpl_joints, vibe joints and action2motion joints 64 | class SMPL(_SMPLLayer): 65 | """ Extension of the official SMPL implementation to support more joints """ 66 | 67 | def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs): 68 | kwargs["model_path"] = model_path 69 | 70 | # remove the verbosity for the 10-shapes beta parameters 71 | with contextlib.redirect_stdout(None): 72 | super(SMPL, self).__init__(**kwargs) 73 | 74 | J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA) 75 | self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)) 76 | vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES]) 77 | a2m_indexes = vibe_indexes[action2motion_joints] 78 | smpl_indexes = np.arange(24) 79 | a2mpl_indexes = np.unique(np.r_[smpl_indexes, a2m_indexes]) 80 | 81 | self.maps = {"vibe": vibe_indexes, 82 | "a2m": a2m_indexes, 83 | "smpl": smpl_indexes, 84 | "a2mpl": a2mpl_indexes} 85 | 86 | def forward(self, *args, **kwargs): 87 | smpl_output = super(SMPL, self).forward(*args, **kwargs) 88 | 89 | extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices) 90 | all_joints = torch.cat([smpl_output.joints, extra_joints], dim=1) 91 | 92 | output = {"vertices": smpl_output.vertices} 93 | 94 | for joinstype, indexes in self.maps.items(): 95 | output[joinstype] = all_joints[:, indexes] 96 | 97 | return output -------------------------------------------------------------------------------- /prepare/download_a2m_datasets.sh: -------------------------------------------------------------------------------- 1 | mkdir -p dataset/ 2 | cd dataset/ 3 | 4 | echo "The datasets will be stored in the 'dataset' folder\n" 5 | 6 | # HumanAct12 poses 7 | echo "Downloading the HumanAct12 poses dataset" 8 | gdown "https://drive.google.com/uc?id=1130gHSvNyJmii7f6pv5aY5IyQIWc3t7R" 9 | echo "Extracting the HumanAct12 poses dataset" 10 | tar xfzv HumanAct12Poses.tar.gz 11 | echo "Cleaning\n" 12 | rm HumanAct12Poses.tar.gz 13 | 14 | # Donwload UESTC poses estimated with VIBE 15 | echo "Downloading the UESTC poses estimated with VIBE" 16 | gdown "https://drive.google.com/uc?id=1LE-EmYNzECU8o7A2DmqDKtqDMucnSJsy" 17 | echo "Extracting the UESTC poses estimated with VIBE" 18 | tar xjvf uestc.tar.bz2 19 | echo "Cleaning\n" 20 | rm uestc.tar.bz2 21 | 22 | echo "Downloading done!" 
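# Usage note (an assumption about the intended workflow, not stated in the
# script itself): run it from the repository root so the relative dataset/
# path resolves, with gdown installed (it is pinned in requirements.txt), e.g.
#   bash prepare/download_a2m_datasets.sh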
23 | -------------------------------------------------------------------------------- /prepare/download_glove.sh: -------------------------------------------------------------------------------- 1 | echo -e "Downloading glove (in use by the evaluators, not by GMD itself)" 2 | gdown --fuzzy https://drive.google.com/file/d/1cmXKUT31pqd7_XpJAiWEo1K81TMYHA5n/view?usp=sharing 3 | rm -rf glove 4 | 5 | unzip glove.zip 6 | echo -e "Cleaning\n" 7 | rm glove.zip 8 | 9 | echo -e "Downloading done!" -------------------------------------------------------------------------------- /prepare/download_recognition_models.sh: -------------------------------------------------------------------------------- 1 | mkdir -p assets/actionrecognition/ 2 | cd assets/actionrecognition/ 3 | 4 | echo -e "Downloading the HumanAct12 action recognition model" 5 | wget https://raw.githubusercontent.com/EricGuo5513/action-to-motion/master/model_file/action_recognition_model_humanact12.tar -O humanact12_gru.tar 6 | echo -e 7 | 8 | echo -e "Downloading the UESTC action recognition model" 9 | gdown "https://drive.google.com/uc?id=1bSSD69s1dHY7Uk0RGbGc6p7uhUxSDSBK" 10 | echo -e 11 | 12 | echo -e "Downloading done!" 13 | -------------------------------------------------------------------------------- /prepare/download_recognition_unconstrained_models.sh: -------------------------------------------------------------------------------- 1 | mkdir -p assets/actionrecognition/ 2 | cd assets/actionrecognition/ 3 | 4 | echo -e "Downloading the HumanAct12 action recognition model, adjusted for the unconstrained setting." 5 | gdown "1xfigimkPxKt3a8zvn_ME_NAR6CyTqneK" 6 | echo -e 7 | 8 | echo -e "Downloading done!" 9 | -------------------------------------------------------------------------------- /prepare/download_smpl_files.sh: -------------------------------------------------------------------------------- 1 | mkdir -p body_models 2 | cd body_models/ 3 | 4 | echo -e "The smpl files will be stored in the 'body_models/smpl/' folder\n" 5 | gdown "https://drive.google.com/uc?id=1INYlGA76ak_cKGzvpOV2Pe6RkYTlXTW2" 6 | rm -rf smpl 7 | 8 | unzip smpl.zip 9 | echo -e "Cleaning\n" 10 | rm smpl.zip 11 | 12 | echo -e "Downloading done!" -------------------------------------------------------------------------------- /prepare/download_t2m_evaluators.sh: -------------------------------------------------------------------------------- 1 | echo -e "Downloading T2M evaluators" 2 | gdown --fuzzy https://drive.google.com/file/d/1DSaKqWX2HlwBtVH5l7DdW96jeYUIXsOP/view 3 | gdown --fuzzy https://drive.google.com/file/d/1tX79xk0fflp07EZ660Xz1RAFE33iEyJR/view 4 | rm -rf t2m 5 | rm -rf kit 6 | 7 | unzip t2m.zip 8 | unzip kit.zip 9 | echo -e "Cleaning\n" 10 | rm t2m.zip 11 | rm kit.zip 12 | 13 | echo -e "Downloading done!" -------------------------------------------------------------------------------- /prepare/download_unconstrained_datasets.sh: -------------------------------------------------------------------------------- 1 | mkdir -p dataset/HumanAct12Poses 2 | cd dataset/HumanAct12Poses 3 | 4 | echo "The datasets will be stored in the 'dataset' folder\n" 5 | 6 | # HumanAct12 poses unconstrained 7 | echo "Downloading the HumanAct12 unconstrained poses dataset" 8 | gdown "1KqOBTtLFgkvWSZb8ao-wdBMG7sTP3Q7d" 9 | 10 | echo "Downloading done!" 
11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | blobfile==2.0.2 2 | chumpy==0.70 3 | einops==0.6.1 4 | ffmpeg==1.4 5 | gdown==4.7.1 6 | human-body-prior==0.8.5.0 7 | matplotlib==3.1.3 8 | numpy==1.21.5 9 | nvidia-cublas-cu11==11.10.3.66 10 | nvidia-cuda-nvrtc-cu11==11.7.99 11 | nvidia-cuda-runtime-cu11==11.7.99 12 | nvidia-cudnn-cu11==8.5.0.96 13 | Pillow==9.2.0 14 | scikit-learn==1.0.2 15 | scipy==1.7.3 16 | seaborn==0.12.2 17 | six==1.16.0 18 | smplx==0.1.28 19 | spacy==3.3.1 20 | torch==1.13.1 21 | torchvision==0.14.1 22 | tqdm==4.66.1 23 | wandb==0.16.1 24 | 25 | # Also, must install Clip: 26 | ### pip install git+https://github.com/openai/CLIP.git 27 | -------------------------------------------------------------------------------- /sample/gmd/keyframe_pattern.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def get_kframes(ground_positions=None, pattern="square", interpolate=False): 4 | # ground_positions = None 5 | if ground_positions is not None: 6 | # Add frame index to ground_positions 7 | # k_positions = [1, 2, 3, 15, 30, 45, 60, 75, 90, 105, 120] 8 | # k_positions = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 9 | # 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 10 | # 45, 60, 75, 90, 11 | # 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120] 12 | # k_positions = [0, 1, 2, 3, 4, 95, 96, 97, 98, 99] 13 | k_positions = [ii for ii in range(1, 120, 1)] 14 | if 119 not in k_positions: 15 | k_positions.append(119) 16 | kframes = [] 17 | for k_posi in k_positions: 18 | kframes.append((k_posi, (float(ground_positions[k_posi - 1, 0, 0]), 19 | float(ground_positions[k_posi - 1, 0, 20 | 2])))) 21 | return kframes 22 | 23 | 24 | if pattern == "square": 25 | kframes = [ ( 1, (0.0, 0.0)), 26 | (30, (0.0, 3.0)), 27 | (45, (1.5, 3.0)), 28 | (60, (3.0, 3.0)), 29 | (75, (3.0, 1.5)), 30 | (90, (3.0, 0.0)), 31 | (105, (1.5, 0.0)), 32 | (119, (0.0, 0.0)) 33 | ] 34 | elif pattern == "inverse_N": 35 | kframes = [ ( 1, (0.0, 0.0)), 36 | # (30, (0.0- 2.0, 3.0- 2.0)), 37 | # (45, (1.5- 2.0, 1.5- 2.0)), 38 | # (60, (3.0- 2.0, 0.0- 2.0)), 39 | # (75, (3.0- 2.0, 1.5- 2.0)), 40 | # (90, (3.0- 2.0, 3.0- 2.0)), 41 | (30, (0.0, 3.0)), 42 | (45, (1.5, 1.5)), 43 | (60, (3.0, 0.0)), 44 | # (75, (3.0, 1.5)), 45 | (90, (3.0, 3.0)), 46 | # (105, (1.5, 0.0)), 47 | (119, (0.0, 0.0)) 48 | ] 49 | elif pattern == "3dots": 50 | kframes = [ ( 1, (0.0, 0.0)), 51 | # (29, (0.0, 2.0)), 52 | # (45, (0.0, 3.0)), 53 | # (31, (0.0, 2.0)), 54 | # (59, (2.0, 2.0)), 55 | # (59, (3.0, 3.0)), 56 | (59, (0.0, 3.0)), 57 | # (89, (3.0, 3.0)), 58 | # (89, (3.0, 0.0)), 59 | # (119, (0.0, 3.0)), 60 | # (91, (3.0, 0.0)), 61 | # (105, (1.5, 0.0)), 62 | (119, (3.0, 3.0)) 63 | ] 64 | elif pattern == "sdf": 65 | kframes = [ 66 | (1, (0.0, 0.0)), 67 | # (90, (2.0, 3.0)), 68 | # (91, (2.0, 3.0)), 69 | # (92, (2.0, 3.0)), 70 | # (93, (2.0, 3.0)), 71 | # (94, (2.0, 3.0)), 72 | # (116, (3.0, 4.5)), 73 | # (117, (3.0, 4.5)), 74 | # (118, (3.0, 4.5)), 75 | (119, (2.0, 2.0)), 76 | ] 77 | elif pattern == "zigzag": 78 | kframes = [ 79 | (1, (0.0, 0.0)), 80 | (40, (0.0,2.0)), 81 | (79, (2.0, 2.0)), 82 | # (119, (2.0, 3.0)), 83 | # (90, (2.0, 3.0)), 84 | # (91, (2.0, 3.0)), 85 | # (92, (2.0, 3.0)), 86 | # (93, (2.0, 3.0)), 87 | # (94, (3.0, 3.0)), 88 | # (94, (-1.5, 2.0)), 89 | (116, (2.0, 4.0)), 90 | # (117, (3.0, 4.5)), 
91 | # (118, (3.0, 4.5)), 92 | # (119, (0.0, 0.0)), 93 | ] 94 | else: 95 | # kframes = [ 96 | # (1, (0.0, 0.0)), 97 | # (80, (3.0, 5.0)), 98 | # ] 99 | kframes = [ 100 | (1, (0.0, 0.0)), 101 | # (30, (0.0, 2.0)), 102 | # (30, (0.0, 3.0)), 103 | # (45, (1.5, 3.0)), 104 | # (60, (2.2, 2.2)), 105 | 106 | (90, (2.0, 3.0)), 107 | (91, (2.0, 3.0)), 108 | (92, (2.0, 3.0)), 109 | (93, (2.0, 3.0)), 110 | (94, (2.0, 3.0)), 111 | 112 | # (60, (0.0, 3.0)), 113 | # (75, (2.5, 4)), 114 | # (120, (0.0, 4.0)), 115 | # (90, (3.0, 4.0)), 116 | # (91, (3.0, 4.0)), 117 | # (92, (3.0, 4.0)), 118 | # (93, (3.0, 4.0)), 119 | # (105, (1.5, 0.0)), 120 | (116, (3.0, 4.5)), 121 | (117, (3.0, 4.5)), 122 | (118, (3.0, 4.5)), 123 | (119, (3.0, 4.5)), 124 | # (180, (-3.0, 4.0)), 125 | # (196, (-3.0, 6.0)), 126 | ] 127 | 128 | # if interpolate: 129 | # kframes = interpolate_kps(kframes) 130 | return kframes 131 | 132 | 133 | def get_obstacles(): 134 | # Obstacles for obstacle avoidance task. Each one is a circle with radius 135 | # on the xz plane with center at (x, z) 136 | obs_list = [ 137 | # ((-0.2, 3.5) , 0.5), 138 | ((4, 1.5) , 0.7), 139 | ((0.7, 1.5) , 0.6), 140 | ] 141 | return obs_list 142 | 143 | 144 | def interpolate_kps(kframes): 145 | kframes_new = [] 146 | lastx, lasty = 0.0, 0.0 147 | last = 0 148 | for frame, loc in kframes: 149 | diff = frame - last 150 | for i in range(diff): 151 | kframes_new.append((last + i, (lastx + (loc[0] - lastx) * i / diff , lasty + (loc[1] - lasty) * i / diff))) 152 | # kframes_new.append((frame, loc)) 153 | lastx, lasty = loc 154 | last = frame 155 | # Add the last frame 156 | kframes_new.append((frame, loc)) 157 | kframes = kframes_new 158 | return kframes -------------------------------------------------------------------------------- /train/train_condmdi.py: -------------------------------------------------------------------------------- 1 | # This code is based on https://github.com/openai/guided-diffusion, 2 | # and is used to train a diffusion model on human motion sequences. 3 | 4 | import os 5 | import sys 6 | import json 7 | from pprint import pprint 8 | from utils.fixseed import fixseed 9 | from utils.parser_util import train_args 10 | from utils import dist_util 11 | from train.training_loop import TrainLoop 12 | from data_loaders.get_data import DatasetConfig, get_dataset_loader 13 | from utils.model_util import create_model_and_diffusion 14 | from configs import card 15 | import wandb 16 | 17 | 18 | def init_wandb(config, project_name=None, entity=None, tags=[], notes=None, **kwargs): 19 | if entity is None: 20 | assert ( 21 | "WANDB_ENTITY" in os.environ 22 | ), "Please either pass in \"entity\" to logging.init or set environment variable 'WANDB_ENTITY' to your wandb entity name." 23 | if project_name is None: 24 | assert ( 25 | "WANDB_PROJECT" in os.environ 26 | ), "Please either pass in \"project_name\" to logging.init or set environment variable 'WANDB_PROJECT' to your wandb project name." 
27 | tags.append(os.path.basename(sys.argv[0])) 28 | if "_MY_JOB_ID" in os.environ: 29 | x = f"(jobid:{os.environ['_MY_JOB_ID']})" 30 | notes = x if notes is None else notes + " " + x 31 | if len(config.resume_checkpoint) > 0: 32 | # FIXME: this is specific to the current project's setting 33 | run_id = config.resume_checkpoint.split("/")[-2] 34 | wandb.init(project=project_name, entity=entity, config=config, tags=tags, notes=notes, resume="allow", id=run_id, **kwargs) 35 | else: 36 | wandb.init(project=project_name, entity=entity, config=config, tags=tags, notes=notes, **kwargs) 37 | 38 | 39 | def main(): 40 | args = train_args(base_cls=card.motion_abs_unet_adagn_xl) # Choose the default full motion model from GMD 41 | init_wandb(config=args) 42 | args.save_dir = os.path.join("save", wandb.run.id) 43 | pprint(args.__dict__) 44 | fixseed(args.seed) 45 | 46 | if args.save_dir is None: 47 | raise FileNotFoundError('save_dir was not specified.') 48 | elif not os.path.exists(args.save_dir): 49 | os.makedirs(args.save_dir) 50 | args_path = os.path.join(args.save_dir, 'args.json') 51 | with open(args_path, 'w') as fw: 52 | json.dump(vars(args), fw, indent=4, sort_keys=True) 53 | 54 | dist_util.setup_dist(args.device) 55 | 56 | print("creating data loader...") 57 | data_conf = DatasetConfig( 58 | name=args.dataset, 59 | batch_size=args.batch_size, 60 | num_frames=args.num_frames, 61 | use_abs3d=args.abs_3d, 62 | traject_only=args.traj_only, 63 | use_random_projection=args.use_random_proj, 64 | random_projection_scale=args.random_proj_scale, 65 | augment_type=args.augment_type, 66 | std_scale_shift=args.std_scale_shift, 67 | drop_redundant=args.drop_redundant, 68 | ) 69 | 70 | data = get_dataset_loader(data_conf) 71 | 72 | print("creating model and diffusion...") 73 | model, diffusion = create_model_and_diffusion(args, data) 74 | model.to(dist_util.dev()) 75 | model.rot2xyz.smpl_model.eval() 76 | 77 | print('Total params: %.2fM' % 78 | (sum(p.numel() for p in model.parameters_wo_clip()) / 1000000.0)) 79 | print("Training...") 80 | TrainLoop(args, model, diffusion, data).run_loop() 81 | wandb.finish() 82 | 83 | 84 | if __name__ == "__main__": 85 | main() -------------------------------------------------------------------------------- /utils/PYTORCH3D_LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For PyTorch3D software 4 | 5 | Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | SMPL_DATA_PATH = "./body_models/smpl" 4 | 5 | SMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, "kintree_table.pkl") 6 | SMPL_MODEL_PATH = os.path.join(SMPL_DATA_PATH, "SMPL_NEUTRAL.pkl") 7 | JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(SMPL_DATA_PATH, 'J_regressor_extra.npy') 8 | 9 | ROT_CONVENTION_TO_ROT_NUMBER = { 10 | 'legacy': 23, 11 | 'no_hands': 21, 12 | 'full_hands': 51, 13 | 'mitten_hands': 33, 14 | } 15 | 16 | GENDERS = ['neutral', 'male', 'female'] 17 | NUM_BETAS = 10 -------------------------------------------------------------------------------- /utils/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import socket 6 | 7 | import torch as th 8 | import torch.distributed as dist 9 | 10 | # Change this to reflect your cluster layout. 11 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 12 | GPUS_PER_NODE = 8 13 | 14 | SETUP_RETRY_COUNT = 3 15 | 16 | used_device = 0 17 | 18 | def setup_dist(device=0): 19 | """ 20 | Setup a distributed process group. 21 | """ 22 | global used_device 23 | used_device = device 24 | if dist.is_initialized(): 25 | return 26 | # os.environ["CUDA_VISIBLE_DEVICES"] = str(device) # f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 27 | 28 | # comm = MPI.COMM_WORLD 29 | # backend = "gloo" if not th.cuda.is_available() else "nccl" 30 | 31 | # if backend == "gloo": 32 | # hostname = "localhost" 33 | # else: 34 | # hostname = socket.gethostbyname(socket.getfqdn()) 35 | # os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 36 | # os.environ["RANK"] = str(comm.rank) 37 | # os.environ["WORLD_SIZE"] = str(comm.size) 38 | 39 | # port = comm.bcast(_find_free_port(), root=used_device) 40 | # os.environ["MASTER_PORT"] = str(port) 41 | # dist.init_process_group(backend=backend, init_method="env://") 42 | 43 | 44 | def dev(): 45 | """ 46 | Get the device to use for torch.distributed. 47 | """ 48 | global used_device 49 | if th.cuda.is_available() and used_device>=0: 50 | return th.device(f"cuda:{used_device}") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 57 | """ 58 | return th.load(path, **kwargs) 59 | 60 | 61 | def sync_params(params): 62 | """ 63 | Synchronize a sequence of Tensors across ranks from rank 0. 
64 | """ 65 | for p in params: 66 | with th.no_grad(): 67 | dist.broadcast(p, 0) 68 | 69 | 70 | def _find_free_port(): 71 | try: 72 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 73 | s.bind(("", 0)) 74 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 75 | return s.getsockname()[1] 76 | finally: 77 | s.close() 78 | -------------------------------------------------------------------------------- /utils/fixseed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | 5 | 6 | def fixseed(seed): 7 | torch.backends.cudnn.benchmark = False 8 | random.seed(seed) 9 | np.random.seed(seed) 10 | torch.manual_seed(seed) 11 | torch.cuda.manual_seed(seed) 12 | 13 | 14 | # SEED = 10 15 | # EVALSEED = 0 16 | # # Provoc warning: not fully functionnal yet 17 | # # torch.set_deterministic(True) 18 | # torch.backends.cudnn.benchmark = False 19 | # fixseed(SEED) 20 | -------------------------------------------------------------------------------- /utils/generation_template.py: -------------------------------------------------------------------------------- 1 | from utils.parser_util import FullModelArgs 2 | 3 | 4 | def get_template(args: FullModelArgs, template_name="no"): 5 | # [no, trajectory, kps, sdf] 6 | if template_name == "mdm_legacy": 7 | updated_args = mdm_template(args) 8 | elif template_name == "no": 9 | updated_args = args 10 | elif template_name == "trajectory": 11 | updated_args = trajectory_template(args) 12 | elif template_name == "kps": 13 | updated_args = kps_template(args) 14 | elif template_name == "sdf": 15 | updated_args = sdf_template(args) 16 | elif template_name == "testing": 17 | updated_args = testing_template(args) 18 | else: 19 | raise NotImplementedError() 20 | return updated_args 21 | 22 | 23 | def mdm_template(args: FullModelArgs): 24 | # NOTE: backward compatible. Otherwise, get_template() is only allowed to change generate args. 
25 | MODEL_NAME = 'motion,trans_enc,x0,rel,normal'.split(',') 26 | # args.gen_avg_model = False 27 | args.motion_length = 6.0 28 | args.abs_3d = False 29 | args.gen_two_stages = False 30 | args.do_inpaint = True 31 | # This "mdm_legacy" mode is only used when we do trajectory imputing with mdm 32 | args.guidance_mode = "mdm_legacy" 33 | 34 | return args 35 | 36 | 37 | def trajectory_template(args: FullModelArgs): 38 | args.do_inpaint = True 39 | # Data flags 40 | # NOTE: this should already be in json for new model 41 | # May need to update json for previous model 42 | # args.use_random_proj = True 43 | # args.random_proj_scale = 10.0 44 | args.guidance_mode = "trajectory" # ["no", "trajectory", "kps", "sdf"] 45 | args.gen_two_stages = False 46 | return args 47 | 48 | 49 | def kps_template(args: FullModelArgs): 50 | args.do_inpaint = True 51 | args.guidance_mode = "kps" # ["no", "trajectory", "kps", "sdf"] 52 | args.gen_two_stages = True 53 | # NOTE: set imputation p2p mode here 54 | # args.p2p_impute = False 55 | args.p2p_impute = True 56 | 57 | return args 58 | 59 | 60 | def sdf_template(args: FullModelArgs): 61 | args.do_inpaint = True 62 | args.guidance_mode = "sdf" # ["no", "trajectory", "kps", "sdf"] 63 | args.gen_two_stages = True 64 | args.p2p_impute = False 65 | return args 66 | 67 | 68 | def testing_template(args: FullModelArgs): 69 | args.do_inpaint = False # True 70 | args.guidance_mode = "no" # ["no", "trajectory", "kps", "sdf"] 71 | # args.classifier_scale = 1.0 72 | args.gen_two_stages = False 73 | args.p2p_impute = False 74 | args.use_ddim = False # True 75 | # args.motion_length = 4.5 76 | args.interpolate_cond = False # True 77 | return args -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def to_numpy(tensor): 5 | if torch.is_tensor(tensor): 6 | return tensor.cpu().numpy() 7 | elif type(tensor).__module__ != 'numpy': 8 | raise ValueError("Cannot convert {} to numpy array".format( 9 | type(tensor))) 10 | return tensor 11 | 12 | 13 | def to_torch(ndarray): 14 | if type(ndarray).__module__ == 'numpy': 15 | return torch.from_numpy(ndarray) 16 | elif not torch.is_tensor(ndarray): 17 | raise ValueError("Cannot convert {} to torch tensor".format( 18 | type(ndarray))) 19 | return ndarray 20 | 21 | 22 | def cleanexit(): 23 | import sys 24 | import os 25 | try: 26 | sys.exit(0) 27 | except SystemExit: 28 | os._exit(0) 29 | 30 | def load_model_wo_clip(model, state_dict): 31 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 32 | assert len(unexpected_keys) == 0 33 | assert all([k.startswith('clip_model.') for k in missing_keys]) 34 | 35 | def freeze_joints(x, joints_to_freeze): 36 | # Freezes selected joint *rotations* as they appear in the first frame 37 | # x [bs, [root+n_joints], joint_dim(6), seqlen] 38 | frozen = x.detach().clone() 39 | frozen[:, joints_to_freeze, :, :] = frozen[:, joints_to_freeze, :, :1] 40 | return frozen 41 | -------------------------------------------------------------------------------- /utils/model_util.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from torch import nn 5 | from data_loaders.humanml.data.dataset import Text2MotionDatasetV2, HumanML3D, TextOnlyDataset 6 | 7 | from diffusion import gaussian_diffusion as gd 8 | from diffusion.respace import DiffusionConfig, 
SpacedDiffusion, space_timesteps 9 | from model.mdm import MDM 10 | from model.mdm_dit import MDM_DiT 11 | from model.mdm_unet import MDM_UNET 12 | from utils.parser_util import DataOptions, DiffusionOptions, ModelOptions, TrainingOptions 13 | from torch.utils.data import DataLoader 14 | 15 | FullModelOptions = Union[DataOptions, ModelOptions, DiffusionOptions, TrainingOptions] 16 | Datasets = Union[Text2MotionDatasetV2, HumanML3D, TextOnlyDataset] 17 | 18 | 19 | def load_model_wo_clip(model: nn.Module, state_dict): 20 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, 21 | strict=False) 22 | assert len(unexpected_keys) == 0, f'unexpected keys: {unexpected_keys}' 23 | assert all([k.startswith('clip_model.') for k in missing_keys]) 24 | 25 | 26 | def create_model_and_diffusion(args: FullModelOptions, data: DataLoader): 27 | if args.arch.startswith('dit'): 28 | # NOTE: adding 'two_head' in the args.arch would imply two_head=True 29 | # the model would predict both eps and x0. 30 | model = MDM_DiT(**get_model_args(args, data)) 31 | elif args.arch.startswith('unet'): 32 | assert 'two_head' not in args.arch, 'unet does not support two_head' 33 | model = MDM_UNET(**get_model_args(args, data)) 34 | else: 35 | model = MDM(**get_model_args(args, data)) 36 | diffusion = create_gaussian_diffusion(args) 37 | return model, diffusion 38 | 39 | 40 | def get_model_args(args: FullModelOptions, data: DataLoader): 41 | # default args 42 | clip_version = 'ViT-B/32' 43 | action_emb = 'tensor' 44 | if args.unconstrained: 45 | cond_mode = 'no_cond' 46 | elif args.dataset == 'amass': 47 | cond_mode = 'no_cond' 48 | elif args.dataset in ['kit', 'humanml']: 49 | cond_mode = 'text' 50 | else: 51 | cond_mode = 'action' 52 | if hasattr(data.dataset, 'num_actions'): 53 | num_actions = data.dataset.num_actions 54 | else: 55 | num_actions = 1 56 | 57 | # SMPL defaults 58 | data_rep = 'rot6d' 59 | njoints = 25 60 | nfeats = 6 61 | 62 | if args.dataset == 'humanml': 63 | data_rep = 'hml_vec' 64 | nfeats = 1 65 | if args.drop_redundant: 66 | njoints = 67 # 4 + 21 * 3 67 | else: 68 | njoints = 263 69 | elif args.dataset == 'kit': 70 | data_rep = 'hml_vec' 71 | njoints = 251 72 | nfeats = 1 73 | elif args.dataset == 'amass': 74 | data_rep = 'hml_vec' # FIXME: find what is the correct data rep 75 | njoints = 764 76 | nfeats = 1 77 | 78 | # Only produce trajectory (4 values: rot, x, z, y) 79 | if args.traj_only: 80 | njoints = 4 81 | nfeats = 1 82 | 83 | # whether to predict xstart and eps at the same time 84 | two_head = 'two_head' in args.arch 85 | 86 | return { 87 | 'modeltype': '', 88 | 'njoints': njoints, 89 | 'nfeats': nfeats, 90 | 'num_actions': num_actions, 91 | 'translation': True, 92 | 'pose_rep': 'rot6d', 93 | 'glob': True, 94 | 'glob_rot': True, 95 | 'latent_dim': args.latent_dim, 96 | 'ff_size': args.ff_size, 97 | 'num_layers': args.layers, 98 | 'num_heads': 4, 99 | 'dropout': 0.1, 100 | 'activation': "gelu", 101 | 'data_rep': data_rep, 102 | 'cond_mode': cond_mode, 103 | 'cond_mask_prob': args.cond_mask_prob, 104 | 'action_emb': action_emb, 105 | 'arch': args.arch, 106 | 'emb_trans_dec': args.emb_trans_dec, 107 | 'clip_version': clip_version, 108 | 'dataset': args.dataset, 109 | 'two_head': two_head, 110 | 'dim_mults': args.dim_mults, 111 | 'adagn': args.unet_adagn, 112 | 'zero': args.unet_zero, 113 | 'unet_out_mult': args.out_mult, 114 | 'tf_out_mult': args.out_mult, 115 | 'xz_only': args.xz_only, 116 | 'keyframe_conditioned': args.keyframe_conditioned, 117 | 'keyframe_selection_scheme': 
args.keyframe_selection_scheme, 118 |         'zero_keyframe_loss': args.zero_keyframe_loss, 119 |     } 120 | 121 | 122 | def create_gaussian_diffusion(args: FullModelOptions): 123 |     steps = 1000  # 1000 124 |     scale_beta = 1.  # no scaling 125 |     if args.use_ddim: 126 |         timestep_respacing = 'ddim100'  # 'ddim100'  # can be used for ddim sampling, we don't use it. 127 |     else: 128 |         timestep_respacing = ''  # can be used for ddim sampling, we don't use it. 129 |     learn_sigma = False 130 |     rescale_timesteps = False 131 | 132 |     betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta) 133 |     loss_type = gd.LossType.MSE 134 | 135 |     if not timestep_respacing: 136 |         timestep_respacing = [steps] 137 | 138 |     return SpacedDiffusion( 139 |         use_timesteps=space_timesteps(steps, timestep_respacing), 140 |         conf=DiffusionConfig( 141 |             betas=betas, 142 |             model_mean_type=(gd.ModelMeanType.EPSILON 143 |                              if not args.predict_xstart else 144 |                              gd.ModelMeanType.START_X), 145 |             model_var_type=( 146 |                 (gd.ModelVarType.FIXED_LARGE 147 |                  if not args.sigma_small else gd.ModelVarType.FIXED_SMALL) 148 |                 if not learn_sigma else gd.ModelVarType.LEARNED_RANGE), 149 |             loss_type=loss_type, 150 |             rescale_timesteps=rescale_timesteps, 151 |             lambda_vel=args.lambda_vel, 152 |             lambda_rcxyz=args.lambda_rcxyz, 153 |             lambda_fc=args.lambda_fc, 154 |             clip_range=args.clip_range, 155 |             train_trajectory_only_xz=args.xz_only, 156 |             use_random_proj=args.use_random_proj, 157 |             fp16=args.use_fp16, 158 |             traj_only=args.traj_only, 159 |             abs_3d=args.abs_3d, 160 |             apply_zero_mask=args.apply_zero_mask, 161 |             traj_extra_weight=args.traj_extra_weight, 162 |             time_weighted_loss=args.time_weighted_loss, 163 |             train_x0_as_eps=args.train_x0_as_eps, 164 |         ), 165 |     ) 166 | 167 | 168 | def load_saved_model(model, model_path, use_avg: bool=True):  # use_avg_model 169 |     state_dict = torch.load(model_path, map_location='cpu') 170 |     # Use average model when possible 171 |     if use_avg and 'model_avg' in state_dict.keys(): 172 |         # if use_avg_model: 173 |         print('loading avg model') 174 |         state_dict = state_dict['model_avg'] 175 |     else: 176 |         if 'model' in state_dict: 177 |             print('loading model without avg') 178 |             state_dict = state_dict['model'] 179 |         else: 180 |             print('checkpoint has no avg model') 181 |     load_model_wo_clip(model, state_dict) 182 |     return model 183 | -------------------------------------------------------------------------------- /visualize/joints2smpl/README.md: -------------------------------------------------------------------------------- 1 | # joints2smpl 2 | Fit an SMPL body model to 3D joint positions. 3 | 4 | ## Prerequisites 5 | We have tested the code on Ubuntu 18.04/20.04 with CUDA 10.2/11.3. 6 | 7 | ## Installation 8 | First, make sure that you have all dependencies in place. 9 | The simplest way to do so is to use [anaconda](https://www.anaconda.com/). 
10 | 11 | You can create an anaconda environment called `fit3d` using 12 | ``` 13 | conda env create -f environment.yaml 14 | conda activate fit3d 15 | ``` 16 | 17 | ## Download SMPL models 18 | Download [SMPL Female and Male](https://smpl.is.tue.mpg.de/) and [SMPL Neutral](https://smplify.is.tue.mpg.de/), rename the files, and extract them to `/smpl_models/smpl/`. Eventually, the `/smpl_models` folder should have the following structure: 19 | ``` 20 | smpl_models 21 | └-- smpl 22 | 	└-- SMPL_FEMALE.pkl 23 | 	└-- SMPL_MALE.pkl 24 | 	└-- SMPL_NEUTRAL.pkl 25 | ``` 26 | 27 | ## Demo 28 | ### Demo for sequences 29 | python fit_seq.py --files test_motion2.npy 30 | 31 | The results will be saved in ./demo/demo_results/ 32 | 33 | ## Citation 34 | If you find this project useful for your research, please consider citing: 35 | ``` 36 | @article{zuo2021sparsefusion, 37 |   title={Sparsefusion: Dynamic human avatar modeling from sparse rgbd images}, 38 |   author={Zuo, Xinxin and Wang, Sen and Zheng, Jiangbin and Yu, Weiwei and Gong, Minglun and Yang, Ruigang and Cheng, Li}, 39 |   journal={IEEE Transactions on Multimedia}, 40 |   volume={23}, 41 |   pages={1617--1629}, 42 |   year={2021} 43 | } 44 | ``` 45 | 46 | ## References 47 | We indicate inside each file whether a function or script is borrowed from an external source. Here are some great resources we 48 | benefit from: 49 | 50 | - The shape/pose prior and some functions are borrowed from [VIBE](https://github.com/mkocabas/VIBE). 51 | - The SMPL models and layers are from the [SMPL-X model](https://github.com/vchoutas/smplx). 52 | - Some functions are borrowed from [HMR-pytorch](https://github.com/MandyMo/pytorch_HMR). 53 | -------------------------------------------------------------------------------- /visualize/joints2smpl/environment.yaml: -------------------------------------------------------------------------------- 1 | name: fit3d 2 | channels: 3 |   - conda-forge 4 |   - pytorch 5 |   - defaults 6 |   - pytorch3d 7 |   - open3d-admin 8 |   - anaconda 9 | dependencies: 10 |   - pip=21.1.3 11 |   - numpy=1.20.3 12 |   - numpy-base=1.20.3 13 |   - matplotlib=3.4.2 14 |   - matplotlib-base=3.4.2 15 |   - pandas=1.3.1 16 |   - python=3.7.6 17 |   - pytorch=1.7.1 18 |   - tensorboardx=2.2 19 |   - cudatoolkit=10.2.89 20 |   - torchvision=0.8.2 21 |   - einops=0.3.0 22 |   - pytorch3d=0.4.0 23 |   - tqdm=4.61.2 24 |   - trimesh=3.9.24 25 |   - joblib=1.0.1 26 |   - open3d=0.13.0 27 |   - pip: 28 |     - h5py==2.9.0 29 |     - chumpy==0.70 30 |     - smplx==0.1.28 31 | -------------------------------------------------------------------------------- /visualize/joints2smpl/fit_seq.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import argparse 3 | import torch 4 | import os,sys 5 | from os import walk, listdir 6 | from os.path import isfile, join 7 | import numpy as np 8 | import joblib 9 | import smplx 10 | import trimesh 11 | import h5py 12 | from tqdm import tqdm 13 | 14 | sys.path.append(os.path.join(os.path.dirname(__file__), "src")) 15 | from smplify import SMPLify3D 16 | import config 17 | 18 | # parse arguments 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--batchSize', type=int, default=1, 21 |                     help='input batch size') 22 | parser.add_argument('--num_smplify_iters', type=int, default=100, 23 |                     help='num of smplify iters') 24 | parser.add_argument('--cuda', type=bool, default=False, 25 |                     help='enables cuda') 26 | parser.add_argument('--gpu_ids', type=int, default=0, 27 |                     help='choose gpu ids') 28 | 
parser.add_argument('--num_joints', type=int, default=22, 29 | help='joint number') 30 | parser.add_argument('--joint_category', type=str, default="AMASS", 31 | help='use correspondence') 32 | parser.add_argument('--fix_foot', type=str, default="False", 33 | help='fix foot or not') 34 | parser.add_argument('--data_folder', type=str, default="./demo/demo_data/", 35 | help='data in the folder') 36 | parser.add_argument('--save_folder', type=str, default="./demo/demo_results/", 37 | help='results save folder') 38 | parser.add_argument('--files', type=str, default="test_motion.npy", 39 | help='files use') 40 | opt = parser.parse_args() 41 | print(opt) 42 | 43 | # ---load predefined something 44 | device = torch.device("cuda:" + str(opt.gpu_ids) if opt.cuda else "cpu") 45 | print(config.SMPL_MODEL_DIR) 46 | smplmodel = smplx.create(config.SMPL_MODEL_DIR, 47 | model_type="smpl", gender="neutral", ext="pkl", 48 | batch_size=opt.batchSize).to(device) 49 | 50 | # ## --- load the mean pose as original ---- 51 | smpl_mean_file = config.SMPL_MEAN_FILE 52 | 53 | file = h5py.File(smpl_mean_file, 'r') 54 | init_mean_pose = torch.from_numpy(file['pose'][:]).unsqueeze(0).float() 55 | init_mean_shape = torch.from_numpy(file['shape'][:]).unsqueeze(0).float() 56 | cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).to(device) 57 | # 58 | pred_pose = torch.zeros(opt.batchSize, 72).to(device) 59 | pred_betas = torch.zeros(opt.batchSize, 10).to(device) 60 | pred_cam_t = torch.zeros(opt.batchSize, 3).to(device) 61 | keypoints_3d = torch.zeros(opt.batchSize, opt.num_joints, 3).to(device) 62 | 63 | # # #-------------initialize SMPLify 64 | smplify = SMPLify3D(smplxmodel=smplmodel, 65 | batch_size=opt.batchSize, 66 | joints_category=opt.joint_category, 67 | num_iters=opt.num_smplify_iters, 68 | device=device) 69 | #print("initialize SMPLify3D done!") 70 | 71 | 72 | purename = os.path.splitext(opt.files)[0] 73 | # --- load data --- 74 | data = np.load(opt.data_folder + "/" + purename + ".npy") # [nframes, njoints, 3] 75 | 76 | dir_save = os.path.join(opt.save_folder, purename) 77 | if not os.path.isdir(dir_save): 78 | os.makedirs(dir_save, exist_ok=True) 79 | 80 | # run the whole seqs 81 | num_seqs = data.shape[0] 82 | 83 | for idx in tqdm(range(num_seqs)): 84 | #print(idx) 85 | 86 | joints3d = data[idx] #*1.2 #scale problem [check first] 87 | keypoints_3d[0, :, :] = torch.Tensor(joints3d).to(device).float() 88 | 89 | if idx == 0: 90 | pred_betas[0, :] = init_mean_shape 91 | pred_pose[0, :] = init_mean_pose 92 | pred_cam_t[0, :] = cam_trans_zero 93 | else: 94 | data_param = joblib.load(dir_save + "/" + "%04d"%(idx-1) + ".pkl") 95 | pred_betas[0, :] = torch.from_numpy(data_param['beta']).unsqueeze(0).float() 96 | pred_pose[0, :] = torch.from_numpy(data_param['pose']).unsqueeze(0).float() 97 | pred_cam_t[0, :] = torch.from_numpy(data_param['cam']).unsqueeze(0).float() 98 | 99 | if opt.joint_category =="AMASS": 100 | confidence_input = torch.ones(opt.num_joints) 101 | # make sure the foot and ankle 102 | if opt.fix_foot == True: 103 | confidence_input[7] = 1.5 104 | confidence_input[8] = 1.5 105 | confidence_input[10] = 1.5 106 | confidence_input[11] = 1.5 107 | else: 108 | print("Such category not settle down!") 109 | 110 | # ----- from initial to fitting ------- 111 | new_opt_vertices, new_opt_joints, new_opt_pose, new_opt_betas, \ 112 | new_opt_cam_t, new_opt_joint_loss = smplify( 113 | pred_pose.detach(), 114 | pred_betas.detach(), 115 | pred_cam_t.detach(), 116 | keypoints_3d, 117 | conf_3d=confidence_input.to(device), 
118 | seq_ind=idx 119 | ) 120 | 121 | # # -- save the results to ply--- 122 | outputp = smplmodel(betas=new_opt_betas, global_orient=new_opt_pose[:, :3], body_pose=new_opt_pose[:, 3:], 123 | transl=new_opt_cam_t, return_verts=True) 124 | mesh_p = trimesh.Trimesh(vertices=outputp.vertices.detach().cpu().numpy().squeeze(), faces=smplmodel.faces, process=False) 125 | mesh_p.export(dir_save + "/" + "%04d"%idx + ".ply") 126 | 127 | # save the pkl 128 | param = {} 129 | param['beta'] = new_opt_betas.detach().cpu().numpy() 130 | param['pose'] = new_opt_pose.detach().cpu().numpy() 131 | param['cam'] = new_opt_cam_t.detach().cpu().numpy() 132 | joblib.dump(param, dir_save + "/" + "%04d"%idx + ".pkl", compress=3) 133 | -------------------------------------------------------------------------------- /visualize/joints2smpl/smpl_models/SMPL_downsample_index.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/visualize/joints2smpl/smpl_models/SMPL_downsample_index.pkl -------------------------------------------------------------------------------- /visualize/joints2smpl/smpl_models/neutral_smpl_mean_params.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/setarehc/diffusion-motion-inbetweening/0c58473066ef582b9464a66d5542e0d1b6ae66c8/visualize/joints2smpl/smpl_models/neutral_smpl_mean_params.h5 -------------------------------------------------------------------------------- /visualize/joints2smpl/src/config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Map joints Name to SMPL joints idx 4 | JOINT_MAP = { 5 | 'MidHip': 0, 6 | 'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10, 7 | 'RHip': 2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11, 8 | 'LShoulder': 16, 'LElbow': 18, 'LWrist': 20, 'LHand': 22, 9 | 'RShoulder': 17, 'RElbow': 19, 'RWrist': 21, 'RHand': 23, 10 | 'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15, 11 | 'LCollar':13, 'Rcollar' :14, 12 | 'Nose':24, 'REye':26, 'LEye':26, 'REar':27, 'LEar':28, 13 | 'LHeel': 31, 'RHeel': 34, 14 | 'OP RShoulder': 17, 'OP LShoulder': 16, 15 | 'OP RHip': 2, 'OP LHip': 1, 16 | 'OP Neck': 12, 17 | } 18 | 19 | full_smpl_idx = range(24) 20 | key_smpl_idx = [0, 1, 4, 7, 2, 5, 8, 17, 19, 21, 16, 18, 20] 21 | 22 | 23 | AMASS_JOINT_MAP = { 24 | 'MidHip': 0, 25 | 'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10, 26 | 'RHip': 2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11, 27 | 'LShoulder': 16, 'LElbow': 18, 'LWrist': 20, 28 | 'RShoulder': 17, 'RElbow': 19, 'RWrist': 21, 29 | 'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15, 30 | 'LCollar':13, 'Rcollar' :14, 31 | } 32 | amass_idx = range(22) 33 | amass_smpl_idx = range(22) 34 | 35 | 36 | SMPL_MODEL_DIR = "./body_models/" 37 | GMM_MODEL_DIR = "./visualize/joints2smpl/smpl_models/" 38 | SMPL_MEAN_FILE = "./visualize/joints2smpl/smpl_models/neutral_smpl_mean_params.h5" 39 | # for collsion 40 | Part_Seg_DIR = "./visualize/joints2smpl/smpl_models/smplx_parts_segm.pkl" -------------------------------------------------------------------------------- /visualize/render_mesh.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from visualize import vis_utils 4 | import shutil 5 | from tqdm import tqdm 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser() 
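    # The stick-figure mp4 is expected to sit next to the results.npy written by the sampling scripts; SMPL meshes are exported to <input>_obj/ and per-frame ground-location spheres to <input>_obj/loc/.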
9 | parser.add_argument("--input_path", type=str, required=True, help='stick figure mp4 file to be rendered.') 10 | parser.add_argument("--cuda", type=bool, default=True, help='') 11 | parser.add_argument("--device", type=int, default=0, help='') 12 | parser.add_argument("--sample", type=int, default=0, help='') 13 | params = parser.parse_args() 14 | 15 | assert params.input_path.endswith('.mp4') 16 | parsed_name = os.path.basename(params.input_path).replace('.mp4', '').replace('sample', '').replace('rep', '') 17 | sample_i = params.sample # 1 18 | rep_i = 0 # 6 19 | # sample_i, rep_i = [int(e) for e in parsed_name.split('_')] 20 | npy_path = os.path.join(os.path.dirname(params.input_path), 'results.npy') 21 | out_npy_path = params.input_path.replace('.mp4', '_smpl_params.npy') 22 | assert os.path.exists(npy_path) 23 | results_dir = params.input_path.replace('.mp4', '_obj') 24 | if os.path.exists(results_dir): 25 | shutil.rmtree(results_dir) 26 | os.makedirs(results_dir) 27 | os.makedirs(os.path.join(results_dir, "loc")) 28 | 29 | npy2obj = vis_utils.npy2obj(npy_path, sample_i, rep_i, 30 | device=params.device, cuda=params.cuda) 31 | 32 | print('Saving obj files to [{}]'.format(os.path.abspath(results_dir))) 33 | for frame_i in tqdm(range(npy2obj.real_num_frames)): 34 | npy2obj.save_obj(os.path.join(results_dir, 'frame{:03d}.obj'.format(frame_i)), frame_i) 35 | 36 | print('Saving SMPL params to [{}]'.format(os.path.abspath(out_npy_path))) 37 | npy2obj.save_npy(out_npy_path) 38 | -------------------------------------------------------------------------------- /visualize/simplify_loc2rot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | from visualize.joints2smpl.src import config 5 | import smplx 6 | import h5py 7 | from visualize.joints2smpl.src.smplify import SMPLify3D 8 | from tqdm import tqdm 9 | import utils.rotation_conversions as geometry 10 | import argparse 11 | 12 | 13 | class joints2smpl: 14 | 15 | def __init__(self, num_frames, device_id, cuda=True): 16 | self.device = torch.device("cuda:" + str(device_id) if cuda else "cpu") 17 | # self.device = torch.device("cpu") 18 | self.batch_size = num_frames 19 | self.num_joints = 22 # for HumanML3D 20 | self.joint_category = "AMASS" 21 | self.num_smplify_iters = 150 22 | self.fix_foot = False 23 | print(config.SMPL_MODEL_DIR) 24 | smplmodel = smplx.create(config.SMPL_MODEL_DIR, 25 | model_type="smpl", gender="neutral", ext="pkl", 26 | batch_size=self.batch_size).to(self.device) 27 | 28 | # ## --- load the mean pose as original ---- 29 | smpl_mean_file = config.SMPL_MEAN_FILE 30 | 31 | file = h5py.File(smpl_mean_file, 'r') 32 | self.init_mean_pose = torch.from_numpy(file['pose'][:]).unsqueeze(0).repeat(self.batch_size, 1).float().to(self.device) 33 | self.init_mean_shape = torch.from_numpy(file['shape'][:]).unsqueeze(0).repeat(self.batch_size, 1).float().to(self.device) 34 | self.cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(self.device) 35 | # 36 | 37 | # # #-------------initialize SMPLify 38 | self.smplify = SMPLify3D(smplxmodel=smplmodel, 39 | batch_size=self.batch_size, 40 | joints_category=self.joint_category, 41 | num_iters=self.num_smplify_iters, 42 | device=self.device) 43 | 44 | 45 | def npy2smpl(self, npy_path): 46 | out_path = npy_path.replace('.npy', '_rot.npy') 47 | motions = np.load(npy_path, allow_pickle=True)[None][0] 48 | # print_batch('', motions) 49 | n_samples = motions['motion'].shape[0] 50 | all_thetas = [] 
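        # Fit SMPL parameters to every generated joint sequence; joint2smpl returns rot6d joint rotations with the root translation appended as an extra row, matching the layout Rotation2xyz expects.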
51 | for sample_i in tqdm(range(n_samples)): 52 | thetas, _ = self.joint2smpl(motions['motion'][sample_i].transpose(2, 0, 1)) # [nframes, njoints, 3] 53 | all_thetas.append(thetas.cpu().numpy()) 54 | motions['motion'] = np.concatenate(all_thetas, axis=0) 55 | print('motions', motions['motion'].shape) 56 | 57 | print(f'Saving [{out_path}]') 58 | np.save(out_path, motions) 59 | exit() 60 | 61 | 62 | 63 | def joint2smpl(self, input_joints, init_params=None): 64 | _smplify = self.smplify # if init_params is None else self.smplify_fast 65 | pred_pose = torch.zeros(self.batch_size, 72).to(self.device) 66 | pred_betas = torch.zeros(self.batch_size, 10).to(self.device) 67 | pred_cam_t = torch.zeros(self.batch_size, 3).to(self.device) 68 | keypoints_3d = torch.zeros(self.batch_size, self.num_joints, 3).to(self.device) 69 | 70 | # run the whole seqs 71 | num_seqs = input_joints.shape[0] 72 | 73 | 74 | # joints3d = input_joints[idx] # *1.2 #scale problem [check first] 75 | keypoints_3d = torch.Tensor(input_joints).to(self.device).float() 76 | 77 | # if idx == 0: 78 | if init_params is None: 79 | pred_betas = self.init_mean_shape 80 | pred_pose = self.init_mean_pose 81 | pred_cam_t = self.cam_trans_zero 82 | else: 83 | pred_betas = init_params['betas'] 84 | pred_pose = init_params['pose'] 85 | pred_cam_t = init_params['cam'] 86 | 87 | if self.joint_category == "AMASS": 88 | confidence_input = torch.ones(self.num_joints) 89 | # make sure the foot and ankle 90 | if self.fix_foot == True: 91 | confidence_input[7] = 1.5 92 | confidence_input[8] = 1.5 93 | confidence_input[10] = 1.5 94 | confidence_input[11] = 1.5 95 | else: 96 | print("Such category not settle down!") 97 | 98 | new_opt_vertices, new_opt_joints, new_opt_pose, new_opt_betas, \ 99 | new_opt_cam_t, new_opt_joint_loss = _smplify( 100 | pred_pose.detach(), 101 | pred_betas.detach(), 102 | pred_cam_t.detach(), 103 | keypoints_3d, 104 | conf_3d=confidence_input.to(self.device), 105 | # seq_ind=idx 106 | ) 107 | 108 | thetas = new_opt_pose.reshape(self.batch_size, 24, 3) 109 | thetas = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(thetas)) # [bs, 24, 6] 110 | root_loc = torch.tensor(keypoints_3d[:, 0]) # [bs, 3] 111 | root_loc = torch.cat([root_loc, torch.zeros_like(root_loc)], dim=-1).unsqueeze(1) # [bs, 1, 6] 112 | thetas = torch.cat([thetas, root_loc], dim=1).unsqueeze(0).permute(0, 2, 3, 1) # [1, 25, 6, 196] 113 | 114 | return thetas.clone().detach(), {'pose': new_opt_joints[0, :24].flatten().clone().detach(), 'betas': new_opt_betas.clone().detach(), 'cam': new_opt_cam_t.clone().detach()} 115 | 116 | 117 | if __name__ == '__main__': 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument("--input_path", type=str, required=True, help='Blender file or dir with blender files') 120 | parser.add_argument("--cuda", type=bool, default=True, help='') 121 | parser.add_argument("--device", type=int, default=0, help='') 122 | params = parser.parse_args() 123 | 124 | simplify = joints2smpl(device_id=params.device, cuda=params.cuda) 125 | 126 | if os.path.isfile(params.input_path) and params.input_path.endswith('.npy'): 127 | simplify.npy2smpl(params.input_path) 128 | elif os.path.isdir(params.input_path): 129 | files = [os.path.join(params.input_path, f) for f in os.listdir(params.input_path) if f.endswith('.npy')] 130 | for f in files: 131 | simplify.npy2smpl(f) -------------------------------------------------------------------------------- /visualize/vis_utils.py: 
-------------------------------------------------------------------------------- 1 | from model.rotation2xyz import Rotation2xyz 2 | import numpy as np 3 | import trimesh 4 | from trimesh import Trimesh 5 | import os 6 | import torch 7 | from visualize.simplify_loc2rot import joints2smpl 8 | 9 | class npy2obj: 10 | def __init__(self, npy_path, sample_idx, rep_idx, device=0, cuda=True): 11 | self.npy_path = npy_path 12 | self.motions = np.load(self.npy_path, allow_pickle=True) 13 | if self.npy_path.endswith('.npz'): 14 | self.motions = self.motions['arr_0'] 15 | self.motions = self.motions[None][0] 16 | self.rot2xyz = Rotation2xyz(device='cpu') 17 | self.faces = self.rot2xyz.smpl_model.faces 18 | self.bs, self.njoints, self.nfeats, self.nframes = self.motions['motion'].shape 19 | self.opt_cache = {} 20 | self.sample_idx = sample_idx 21 | self.total_num_samples = self.motions['num_samples'] 22 | self.rep_idx = rep_idx 23 | self.absl_idx = self.rep_idx*self.total_num_samples + self.sample_idx 24 | self.num_frames = self.motions['motion'][self.absl_idx].shape[-1] 25 | self.j2s = joints2smpl(num_frames=self.num_frames, device_id=device, cuda=cuda) 26 | 27 | if self.nfeats == 3: 28 | print(f'Running SMPLify For sample [{sample_idx}], repetition [{rep_idx}], it may take a few minutes.') 29 | motion_tensor, opt_dict = self.j2s.joint2smpl(self.motions['motion'][self.absl_idx].transpose(2, 0, 1)) # [nframes, njoints, 3] 30 | self.motions['motion'] = motion_tensor.cpu().numpy() 31 | elif self.nfeats == 6: 32 | self.motions['motion'] = self.motions['motion'][[self.absl_idx]] 33 | self.bs, self.njoints, self.nfeats, self.nframes = self.motions['motion'].shape 34 | self.real_num_frames = self.motions['lengths'][self.absl_idx] 35 | 36 | self.vertices = self.rot2xyz(torch.tensor(self.motions['motion']), mask=None, 37 | pose_rep='rot6d', translation=True, glob=True, 38 | jointstype='vertices', 39 | # jointstype='smpl', # for joint locations 40 | vertstrans=True) 41 | self.root_loc = self.motions['motion'][:, -1, :3, :].reshape(1, 1, 3, -1) 42 | 43 | # import pdb; pdb.set_trace() 44 | # self.vertices += self.root_loc 45 | # self.vertices[:, :, 1, :] += self.root_loc[:, :, 1, :] 46 | 47 | def get_vertices(self, sample_i, frame_i): 48 | return self.vertices[sample_i, :, :, frame_i].squeeze().tolist() 49 | 50 | def get_trimesh(self, sample_i, frame_i): 51 | return Trimesh(vertices=self.get_vertices(sample_i, frame_i), 52 | faces=self.faces) 53 | 54 | def get_traj_sphere(self, mesh): 55 | # import pdb; pdb.set_trace() 56 | root_posi = np.copy(mesh.vertices).mean(0) # (6000, 3) 57 | # import pdb; pdb.set_trace() 58 | # root_posi[1] = mesh.vertices.min(0)[1] + 0.1 59 | root_posi[1] = self.vertices.numpy().min(axis=(0, 1, 3))[1] + 0.1 60 | mesh = trimesh.primitives.Sphere(radius=0.05, center=root_posi, transform=None, subdivisions=1) 61 | return mesh 62 | 63 | def save_obj(self, save_path, frame_i): 64 | mesh = self.get_trimesh(0, frame_i) 65 | ground_sph_mesh = self.get_traj_sphere(mesh) 66 | loc_obj_name = os.path.splitext(os.path.basename(save_path))[0] + "_ground_loc.obj" 67 | ground_save_path = os.path.join(os.path.dirname(save_path), "loc", loc_obj_name) 68 | with open(save_path, 'w') as fw: 69 | mesh.export(fw, 'obj') 70 | with open(ground_save_path, 'w') as fw: 71 | ground_sph_mesh.export(fw, 'obj') 72 | return save_path 73 | 74 | def save_npy(self, save_path): 75 | data_dict = { 76 | 'motion': self.motions['motion'][0, :, :, :self.real_num_frames], 77 | 'thetas': self.motions['motion'][0, :-1, :, 
:self.real_num_frames], 78 | 'root_translation': self.motions['motion'][0, -1, :3, :self.real_num_frames], 79 | 'faces': self.faces, 80 | 'vertices': self.vertices[0, :, :, :self.real_num_frames], 81 | 'text': self.motions['text'][0], 82 | 'length': self.real_num_frames, 83 | } 84 | np.save(save_path, data_dict) 85 | --------------------------------------------------------------------------------
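For reference, the following is a minimal sketch (not part of the repository) of how the mesh-export utilities in `visualize/vis_utils.py` are typically driven, mirroring `visualize/render_mesh.py`; the sample path under `save/` and the sample/repetition indices are assumptions for illustration only.

```
import os
from tqdm import tqdm
from visualize import vis_utils

# Assumed inputs: a stick-figure mp4 and the results.npy produced by the sampling scripts.
input_mp4 = "save/my_run/samples/sample00_rep00.mp4"  # hypothetical path
npy_path = os.path.join(os.path.dirname(input_mp4), "results.npy")
results_dir = input_mp4.replace(".mp4", "_obj")
os.makedirs(os.path.join(results_dir, "loc"), exist_ok=True)

# npy2obj runs SMPLify (joints2smpl) when the motion is stored as 3D joints,
# then converts the fitted rot6d parameters into mesh vertices via Rotation2xyz.
npy2obj = vis_utils.npy2obj(npy_path, sample_idx=0, rep_idx=0, device=0, cuda=True)
for frame_i in tqdm(range(npy2obj.real_num_frames)):
    # Writes frameXXX.obj plus a small ground-location sphere under results_dir/loc/.
    npy2obj.save_obj(os.path.join(results_dir, "frame{:03d}.obj".format(frame_i)), frame_i)
npy2obj.save_npy(input_mp4.replace(".mp4", "_smpl_params.npy"))
```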