├── .gitignore ├── README.md ├── assets ├── TitlePage.png ├── dance.gif ├── graphical_abstract.jpg ├── ib.gif ├── jump.gif ├── pc.gif └── walk.gif ├── cmib ├── data │ ├── lafan1_dataset.py │ ├── quaternion.py │ └── utils.py ├── lafan1 │ ├── __init__.py │ ├── benchmarks.py │ ├── extract.py │ └── utils.py ├── misc │ └── sampler.py ├── model │ ├── network.py │ ├── positional_encoding.py │ ├── preprocess.py │ └── skeleton.py └── vis │ └── pose.py ├── requirements.txt ├── run_cmib.py ├── run_test_multi.py ├── test_benchmark.py ├── test_condition_comp.py ├── train.sh └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Project Specific 2 | settings.json 3 | log/* 4 | model_weights/* 5 | results/* 6 | ubisoft-laforge-animation-dataset/ 7 | lafan_train*.pkl 8 | runs/ 9 | wandb/ 10 | processed_data*/ 11 | AMASS/ 12 | *.pkl 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | pip-wheel-metadata/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .nox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | *.py,cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
108 | __pypackages__/
109 |
110 | # Celery stuff
111 | celerybeat-schedule
112 | celerybeat.pid
113 |
114 | # SageMath parsed files
115 | *.sage.py
116 |
117 | # Environments
118 | .env
119 | .venv
120 | env/
121 | venv/
122 | ENV/
123 | env.bak/
124 | venv.bak/
125 |
126 | # Spyder project settings
127 | .spyderproject
128 | .spyproject
129 |
130 | # Rope project settings
131 | .ropeproject
132 |
133 | # mkdocs documentation
134 | /site
135 |
136 | # mypy
137 | .mypy_cache/
138 | .dmypy.json
139 | dmypy.json
140 |
141 | # Pyre type checker
142 | .pyre/
143 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Conditional Motion In-Betweening (CMIB)
2 |
3 | Official implementation of the paper: Conditional Motion In-betweening.
4 |
5 | [Paper](https://www.sciencedirect.com/science/article/pii/S0031320322003752) | [Project Page](https://jihoonerd.github.io/Conditional-Motion-In-Betweening/) | [YouTube](https://youtu.be/XAELcHOREJ8)
6 |
7 |

![Graphical Abstract](assets/graphical_abstract.jpg)

| in-betweening | pose-conditioned |
| :---: | :---: |
| ![in-betweening](assets/ib.gif) | ![pose-conditioned](assets/pc.gif) |

| walk | jump | dance |
| :---: | :---: | :---: |
| ![walk](assets/walk.gif) | ![jump](assets/jump.gif) | ![dance](assets/dance.gif) |
34 |
35 | ## Environments
36 |
37 | This repo is tested on the following environment:
38 |
39 | * Ubuntu 20.04
40 | * Python >= 3.7
41 | * PyTorch == 1.10.1
42 | * CUDA V11.3.109
43 |
44 | ## Install
45 |
46 | 1. Follow the [`LAFAN1`](https://github.com/ubisoft/ubisoft-laforge-animation-dataset) dataset's installation guide.
47 | *You need to install git lfs first before cloning the dataset repo.*
48 |
49 | 2. Run LAFAN1's `evaluate.py` to unzip and validate it. (Install `numpy` first if you don't have it)
50 | ```bash
51 | $ pip install numpy
52 | $ python ubisoft-laforge-animation-dataset/evaluate.py
53 | ```
54 | With this, you will have the unpacked LAFAN1 dataset under the `ubisoft-laforge-animation-dataset` folder.
55 |
56 | 3. Install the appropriate `pytorch` version for your device (CPU/GPU), then install the packages listed in `requirements.txt`.
57 |
58 | ## Trained Weights
59 |
60 | You can download trained weights from [here](https://works.do/FCqKVjy).
61 |
62 | ## Train from Scratch
63 |
64 | The training script is `trainer.py`.
65 |
66 | ```bash
67 | python trainer.py \
68 | --processed_data_dir="processed_data_80/" \
69 | --window=90 \
70 | --batch_size=32 \
71 | --epochs=5000 \
72 | --device=0 \
73 | --entity=cmib_exp \
74 | --exp_name="cmib_80" \
75 | --save_interval=50 \
76 | --learning_rate=0.0001 \
77 | --loss_cond_weight=1.5 \
78 | --loss_pos_weight=0.05 \
79 | --loss_rot_weight=2.0 \
80 | --from_idx=9 \
81 | --target_idx=88 \
82 | --interpolation='slerp'
83 |
84 | ```
85 |
86 | ## Inference
87 |
88 | You can use `run_cmib.py` for inference. Please refer to the help page of `run_cmib.py` for more details.
89 |
90 | ```bash
91 | python run_cmib.py --help
92 | ```
93 |
94 | ## Reference
95 |
96 | * LAFAN1 Dataset
97 | ```
98 | @article{harvey2020robust,
99 | author = {Félix G. Harvey and Mike Yurick and Derek Nowrouzezahrai and Christopher Pal},
100 | title = {Robust Motion In-Betweening},
101 | booktitle = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH)},
102 | publisher = {ACM},
103 | volume = {39},
104 | number = {4},
105 | year = {2020}
106 | }
107 | ```
108 |
109 | ## Citation
110 | ```
111 | @article{KIM2022108894,
112 | title = {Conditional Motion In-betweening},
113 | journal = {Pattern Recognition},
114 | pages = {108894},
115 | year = {2022},
116 | issn = {0031-3203},
117 | doi = {https://doi.org/10.1016/j.patcog.2022.108894},
118 | url = {https://www.sciencedirect.com/science/article/pii/S0031320322003752},
119 | author = {Jihoon Kim and Taehyun Byun and Seungyoun Shin and Jungdam Won and Sungjoon Choi},
120 | keywords = {motion in-betweening, conditional motion generation, generative model, motion data augmentation},
121 | abstract = {Motion in-betweening (MIB) is a process of generating intermediate skeletal movement between the given start and target poses while preserving the naturalness of the motion, such as periodic footstep motion while walking. Although state-of-the-art MIB methods are capable of producing plausible motions given sparse key-poses, they often lack the controllability to generate motions satisfying the semantic contexts required in practical applications. We focus on the method that can handle pose or semantic conditioned MIB tasks using a unified model. We also present a motion augmentation method to improve the quality of pose-conditioned motion generation via defining a distribution over smooth trajectories.
Our proposed method outperforms the existing state-of-the-art MIB method in pose prediction errors while providing additional controllability. Our code and results are available on our project web page: https://jihoonerd.github.io/Conditional-Motion-In-Betweening} 122 | } 123 | ``` 124 | 125 | ## Author 126 | 127 | * [Jihoon Kim](https://github.com/jihoonerd) 128 | * [Taehyun Byun](https://github.com/childtoy) 129 | -------------------------------------------------------------------------------- /assets/TitlePage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/assets/TitlePage.png -------------------------------------------------------------------------------- /assets/dance.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/assets/dance.gif -------------------------------------------------------------------------------- /assets/graphical_abstract.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/assets/graphical_abstract.jpg -------------------------------------------------------------------------------- /assets/ib.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/assets/ib.gif -------------------------------------------------------------------------------- /assets/jump.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/assets/jump.gif -------------------------------------------------------------------------------- /assets/pc.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/assets/pc.gif -------------------------------------------------------------------------------- /assets/walk.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/assets/walk.gif -------------------------------------------------------------------------------- /cmib/data/lafan1_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from cmib.lafan1 import extract, utils 3 | import numpy as np 4 | import pickle 5 | import os 6 | 7 | 8 | class LAFAN1Dataset(Dataset): 9 | def __init__( 10 | self, 11 | lafan_path: str, 12 | processed_data_dir: str, 13 | train: bool, 14 | device: str, 15 | window: int = 65, 16 | dataset: str = 'LAFAN' 17 | ): 18 | self.lafan_path = lafan_path 19 | 20 | self.train = train 21 | # 4.3: It contains actions performedby 5 subjects, with Subject 5 used as the test set. 
22 |         self.dataset = dataset
23 |
24 |         if self.dataset == 'LAFAN':
25 |             self.actors = (
26 |                 ["subject1", "subject2", "subject3", "subject4"] if train else ["subject5"]
27 |             )
28 |         elif self.dataset in ['HumanEva', 'PosePrior']:
29 |             self.actors = (
30 |                 ["subject1", "subject2"] if train else ["subject3"]
31 |             )
32 |         elif self.dataset in ['HUMAN4D']:
33 |             self.actors = (
34 |                 ["subject1", "subject2", "subject3", "subject4", "subject5", "subject6", "subject7"] if train else ["subject8"]
35 |             )
36 |         elif self.dataset in ["MPI_HDM05"]:
37 |             self.actors = (
38 |                 ["subject1", "subject2", "subject3"] if train else ["subject4"]
39 |             )
40 |         else:
41 |             raise ValueError("Invalid Dataset")
42 |
43 |         # 4.3: ... The training statistics for normalization are computed on windows of 50 frames offset by 20 frames.
44 |         self.window = window
45 |
46 |         # 4.3: Given the larger size of ... we sample our test windows from Subject 5 at every 40 frames.
47 |         # The training statistics for normalization are computed on windows of 50 frames offset by 20 frames.
48 |         self.offset = 20 if self.train else 40
49 |
50 |         self.device = device
51 |
52 |         pickle_name = "processed_train_data.pkl" if train else "processed_test_data.pkl"
53 |
54 |         if pickle_name in os.listdir(processed_data_dir):
55 |             with open(os.path.join(processed_data_dir, pickle_name), "rb") as f:
56 |                 self.data = pickle.load(f)
57 |         else:
58 |             self.data = self.load_lafan()  # Call this last
59 |             with open(os.path.join(processed_data_dir, pickle_name), "wb") as f:
60 |                 pickle.dump(self.data, f, pickle.HIGHEST_PROTOCOL)
61 |
62 |     @property
63 |     def root_v_dim(self):
64 |         return self.data["root_v"].shape[2]
65 |
66 |     @property
67 |     def local_q_dim(self):
68 |         return self.data["local_q"].shape[2] * self.data["local_q"].shape[3]
69 |
70 |     @property
71 |     def contact_dim(self):
72 |         return self.data["contact"].shape[2]
73 |
74 |     @property
75 |     def num_joints(self):
76 |         return self.data["global_pos"].shape[2]
77 |
78 |     def load_lafan(self):
79 |         # This uses the method provided with LAFAN1.
80 |         # X and Q are local positions/quaternions. Motions are rotated to make the 10th frame face the X+ direction.
81 |         # Refer to paper 3.1 Data formatting
82 |         X, Q, parents, contacts_l, contacts_r, seq_names = extract.get_lafan1_set(
83 |             self.lafan_path, self.actors, self.window, self.offset, self.train, self.dataset
84 |         )
85 |
86 |         # Retrieve global representations.
(global quaternion, global positions) 87 | _, global_pos = utils.quat_fk(Q, X, parents) 88 | 89 | input_data = {} 90 | input_data["local_q"] = Q # q_{t} 91 | input_data["local_q_offset"] = Q[:, -1, :, :] # lasst frame's quaternions 92 | input_data["q_target"] = Q[:, -1, :, :] # q_{T} 93 | 94 | input_data["root_v"] = ( 95 | global_pos[:, 1:, 0, :] - global_pos[:, :-1, 0, :] 96 | ) # \dot{r}_{t} 97 | input_data["root_p_offset"] = global_pos[ 98 | :, -1, 0, : 99 | ] # last frame's root positions 100 | input_data["root_p"] = global_pos[:, :, 0, :] 101 | 102 | input_data["contact"] = np.concatenate( 103 | [contacts_l, contacts_r], -1 104 | ) # Foot contact 105 | input_data["global_pos"] = global_pos[ 106 | :, :, :, : 107 | ] # global position (N, 50, 22, 30) why not just global_pos 108 | input_data["seq_names"] = seq_names 109 | return input_data 110 | 111 | def __len__(self): 112 | return self.data["global_pos"].shape[0] 113 | 114 | def __getitem__(self, index): 115 | query = {} 116 | query["local_q"] = self.data["local_q"][index].astype(np.float32) 117 | query["local_q_offset"] = self.data["local_q_offset"][index].astype(np.float32) 118 | query["q_target"] = self.data["q_target"][index].astype(np.float32) 119 | query["root_v"] = self.data["root_v"][index].astype(np.float32) 120 | query["root_p_offset"] = self.data["root_p_offset"][index].astype(np.float32) 121 | query["root_p"] = self.data["root_p"][index].astype(np.float32) 122 | query["contact"] = self.data["contact"][index].astype(np.float32) 123 | query["global_pos"] = self.data["global_pos"][index].astype(np.float32) 124 | return query 125 | -------------------------------------------------------------------------------- /cmib/data/quaternion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | 11 | # PyTorch-backed implementations 12 | 13 | 14 | def qmul(q, r): 15 | """ 16 | Multiply quaternion(s) q with quaternion(s) r. 17 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. 18 | Returns q*r as a tensor of shape (*, 4). 19 | """ 20 | assert q.shape[-1] == 4 21 | assert r.shape[-1] == 4 22 | 23 | original_shape = q.shape 24 | 25 | # Compute outer product 26 | terms = torch.bmm(r.reshape(-1, 4, 1), q.reshape(-1, 1, 4)) 27 | 28 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] 29 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] 30 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] 31 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] 32 | return torch.stack((w, x, y, z), dim=1).view(original_shape) 33 | 34 | 35 | def qrot(q, v): 36 | """ 37 | Rotate vector(s) v about the rotation described by quaternion(s) q. 38 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, 39 | where * denotes any number of dimensions. 40 | Returns a tensor of shape (*, 3). 
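
    Note added for clarity (not part of the original docstring): assuming unit
    quaternions, the implementation below uses the compact rotation form
    v' = v + 2 * (w * (u x v) + u x (u x v)),
    where u = (x, y, z) is the quaternion's vector part and w its scalar part.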
41 | """ 42 | assert q.shape[-1] == 4 43 | assert v.shape[-1] == 3 44 | assert q.shape[:-1] == v.shape[:-1] 45 | 46 | original_shape = list(v.shape) 47 | q = q.reshape(-1, 4) 48 | v = v.reshape(-1, 3) 49 | 50 | qvec = q[:, 1:] 51 | uv = torch.cross(qvec, v, dim=1) 52 | uuv = torch.cross(qvec, uv, dim=1) 53 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) 54 | 55 | 56 | def qeuler(q, order, epsilon=0): 57 | """ 58 | Convert quaternion(s) q to Euler angles. 59 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions. 60 | Returns a tensor of shape (*, 3). 61 | """ 62 | assert q.shape[-1] == 4 63 | 64 | original_shape = list(q.shape) 65 | original_shape[-1] = 3 66 | q = q.view(-1, 4) 67 | 68 | q0 = q[:, 0] 69 | q1 = q[:, 1] 70 | q2 = q[:, 2] 71 | q3 = q[:, 3] 72 | 73 | if order == "xyz": 74 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 75 | y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) 76 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 77 | elif order == "yzx": 78 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 79 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 80 | z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) 81 | elif order == "zxy": 82 | x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) 83 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 84 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) 85 | elif order == "xzy": 86 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 87 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 88 | z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) 89 | elif order == "yxz": 90 | x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) 91 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) 92 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 93 | elif order == "zyx": 94 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 95 | y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) 96 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 97 | else: 98 | raise 99 | 100 | return torch.stack((x, y, z), dim=1).view(original_shape) 101 | 102 | 103 | # Numpy-backed implementations 104 | 105 | 106 | def qmul_np(q, r): 107 | q = torch.from_numpy(q).contiguous() 108 | r = torch.from_numpy(r).contiguous() 109 | return qmul(q, r).numpy() 110 | 111 | 112 | def qrot_np(q, v): 113 | q = torch.from_numpy(q).contiguous() 114 | v = torch.from_numpy(v).contiguous() 115 | return qrot(q, v).numpy() 116 | 117 | 118 | def qeuler_np(q, order, epsilon=0, use_gpu=False): 119 | if use_gpu: 120 | q = torch.from_numpy(q).cuda() 121 | return qeuler(q, order, epsilon).cpu().numpy() 122 | else: 123 | q = torch.from_numpy(q).contiguous() 124 | return qeuler(q, order, epsilon).numpy() 125 | 126 | 127 | def qfix(q): 128 | """ 129 | Enforce quaternion continuity across the time dimension by selecting 130 | the representation (q or -q) with minimal distance (or, equivalently, maximal dot product) 131 | between two consecutive frames. 132 | 133 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. 134 | Returns a tensor of the same shape. 
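
    Worked example (illustrative values, not from the original docstring): with
    q[0] = [0.7, 0, 0, 0.7] and q[1] = [-0.7, 0, 0, -0.7] for some joint, the
    frame-to-frame dot product is negative, so frame 1 is flipped to
    [0.7, 0, 0, 0.7]; the motion is unchanged because q and -q encode the same rotation.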
135 | """ 136 | assert len(q.shape) == 3 137 | assert q.shape[-1] == 4 138 | 139 | result = q.copy() 140 | dot_products = np.sum(q[1:] * q[:-1], axis=2) 141 | mask = dot_products < 0 142 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool) 143 | result[1:][mask] *= -1 144 | return result 145 | 146 | 147 | def expmap_to_quaternion(e): 148 | """ 149 | Convert axis-angle rotations (aka exponential maps) to quaternions. 150 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". 151 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions. 152 | Returns a tensor of shape (*, 4). 153 | """ 154 | assert e.shape[-1] == 3 155 | 156 | original_shape = list(e.shape) 157 | original_shape[-1] = 4 158 | e = e.reshape(-1, 3) 159 | 160 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1) 161 | w = np.cos(0.5 * theta).reshape(-1, 1) 162 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e 163 | return np.concatenate((w, xyz), axis=1).reshape(original_shape) 164 | 165 | 166 | def euler_to_quaternion(e, order): 167 | """ 168 | Convert Euler angles to quaternions. 169 | """ 170 | assert e.shape[-1] == 3 171 | 172 | original_shape = list(e.shape) 173 | original_shape[-1] = 4 174 | 175 | e = e.reshape(-1, 3) 176 | 177 | x = e[:, 0] 178 | y = e[:, 1] 179 | z = e[:, 2] 180 | 181 | rx = np.stack( 182 | (np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1 183 | ) 184 | ry = np.stack( 185 | (np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1 186 | ) 187 | rz = np.stack( 188 | (np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1 189 | ) 190 | 191 | result = None 192 | for coord in order: 193 | if coord == "x": 194 | r = rx 195 | elif coord == "y": 196 | r = ry 197 | elif coord == "z": 198 | r = rz 199 | else: 200 | raise 201 | if result is None: 202 | result = r 203 | else: 204 | result = qmul_np(result, r) 205 | 206 | # Reverse antipodal representation to have a non-negative "w" 207 | if order in ["xyz", "yzx", "zxy"]: 208 | result *= -1 209 | 210 | return result.reshape(original_shape) 211 | -------------------------------------------------------------------------------- /cmib/data/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | from pathlib import Path 5 | import re 6 | import numpy as np 7 | from cmib.data.quaternion import euler_to_quaternion, qeuler_np 8 | 9 | 10 | def drop_end_quat(quaternions, skeleton): 11 | """ 12 | quaternions: [N,T,Joints,4] 13 | """ 14 | 15 | return quaternions[:, :, skeleton.has_children()] 16 | 17 | 18 | def write_json(filename, local_q, root_pos, joint_names): 19 | json_out = {} 20 | json_out["root_pos"] = root_pos.tolist() 21 | json_out["local_quat"] = local_q.tolist() 22 | json_out["joint_names"] = joint_names 23 | with open(filename, "w") as outfile: 24 | json.dump(json_out, outfile) 25 | 26 | 27 | def flip_bvh(bvh_folder: str, skip: str): 28 | """ 29 | Generate LR flip of existing bvh files. Assumes Z-forward. 30 | It does not flip files contains skip string in their name. 
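
    Hypothetical usage (the path and skip pattern below are assumptions, not taken
    from the original code): flip_bvh("ubisoft-laforge-animation-dataset/output/BVH",
    skip="subject5") writes a *_LRflip.bvh copy next to each BVH file whose name does
    not contain "subject5", skipping files that already have a flipped counterpart.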
31 | """ 32 | 33 | print("Left-Right Flipping Process...") 34 | 35 | # List files which are not flipped yet 36 | to_convert = [] 37 | not_convert = [] 38 | bvh_files = os.listdir(bvh_folder) 39 | for bvh_file in bvh_files: 40 | if "_LRflip.bvh" in bvh_file: 41 | continue 42 | if skip in bvh_file: 43 | not_convert.append(bvh_file) 44 | continue 45 | flipped_file = bvh_file.replace(".bvh", "_LRflip.bvh") 46 | if flipped_file in bvh_files: 47 | print(f"[SKIP: {bvh_file}] (flipped file already exists)") 48 | continue 49 | to_convert.append(bvh_file) 50 | 51 | print("Following files will be flipped: ") 52 | print(to_convert) 53 | print("Following files are not flipped: ") 54 | print(not_convert) 55 | 56 | for i, converting_fn in enumerate(to_convert): 57 | fout = open( 58 | os.path.join(bvh_folder, converting_fn.replace(".bvh", "_LRflip.bvh")), "w" 59 | ) 60 | file_read = open(os.path.join(bvh_folder, converting_fn), "r") 61 | file_lines = file_read.readlines() 62 | hierarchy_part = True 63 | for line in file_lines: 64 | if hierarchy_part: 65 | fout.write(line) 66 | if "Frame Time" in line: 67 | # This should be the last exact copy. Motion line comes next 68 | hierarchy_part = False 69 | else: 70 | # Followings are very helpful to understand which axis needs to be inverted 71 | # http://lo-th.github.io/olympe/BVH_player.html 72 | # https://quaternions.online/ 73 | str_to_num = line.split(" ")[:-1] # Extract number only 74 | motion_mat = np.array([float(x) for x in str_to_num]).reshape( 75 | 23, 3 76 | ) # Hips 6 Channel + 3 * 21 = 69 77 | motion_mat[0, 2] *= -1.0 # Invert translation Z axis (forward-backward) 78 | quat = euler_to_quaternion( 79 | np.radians(motion_mat[1:]), "zyx" 80 | ) # This function takes radians 81 | # Invert X-axis (Left-Right) / Quaternion representation: (w, x, y, z) 82 | quat[:, 0] *= -1.0 83 | quat[:, 1] *= -1.0 84 | motion_mat[1:] = np.degrees(qeuler_np(quat, "zyx")) 85 | 86 | # idx 0: Hips Wolrd coord, idx 1: Hips Rotation 87 | left_idx = [2, 3, 4, 5, 15, 16, 17, 18] # From 2: LeftUpLeg... 88 | right_idx = [6, 7, 8, 9, 19, 20, 21, 22] # From 6: RightUpLeg... 89 | motion_mat[left_idx + right_idx] = motion_mat[ 90 | right_idx + left_idx 91 | ].copy() 92 | motion_mat = np.round(motion_mat, decimals=6) 93 | motion_vector = np.reshape(motion_mat, (69,)) 94 | motion_part_str = "" 95 | for s in motion_vector: 96 | motion_part_str += str(s) + " " 97 | motion_part_str += "\n" 98 | fout.write(motion_part_str) 99 | print(f"[{i+1}/{len(to_convert)}] {converting_fn} flipped.") 100 | 101 | 102 | def increment_path(path, exist_ok=False, sep="", mkdir=False): 103 | # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. 
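    # Illustrative example (assumed filesystem state, not from the original code):
    # if "runs/exp" already exists and has no numbered siblings,
    # increment_path("runs/exp") returns Path("runs/exp2"); with mkdir=True the
    # resulting directory is created as well.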
104 | path = Path(path) # os-agnostic 105 | if path.exists() and not exist_ok: 106 | suffix = path.suffix 107 | path = path.with_suffix("") 108 | dirs = glob.glob(f"{path}{sep}*") # similar paths 109 | matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs] 110 | i = [int(m.groups()[0]) for m in matches if m] # indices 111 | n = max(i) + 1 if i else 2 # increment number 112 | path = Path(f"{path}{sep}{n}{suffix}") # update path 113 | dir = path if path.suffix == "" else path.parent # directory 114 | if not dir.exists() and mkdir: 115 | dir.mkdir(parents=True, exist_ok=True) # make directory 116 | return path 117 | 118 | 119 | def process_seq_names(seq_names, dataset): 120 | 121 | if dataset in ['HumanEva', 'HUMAN4D', 'MPI_HDM05']: 122 | processed_seqname = [x[:-1] for x in seq_names] 123 | elif dataset == 'PosePrior': 124 | processed_seqname = [] 125 | for seq in seq_names: 126 | if 'lar' in seq: 127 | pr_seq = 'lar' 128 | elif 'op' in seq: 129 | pr_seq = 'op' 130 | elif 'rom' in seq: 131 | pr_seq = 'rom' 132 | elif 'uar' in seq: 133 | pr_seq = 'uar' 134 | elif 'ulr' in seq: 135 | pr_seq = 'ulr' 136 | else: 137 | ValueError('Invlaid seq name') 138 | processed_seqname.append(pr_seq) 139 | else: 140 | ValueError('Invalid dataset name') 141 | return processed_seqname -------------------------------------------------------------------------------- /cmib/lafan1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/cmib/lafan1/__init__.py -------------------------------------------------------------------------------- /cmib/lafan1/benchmarks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from . import utils 4 | import torch 5 | 6 | np.set_printoptions(precision=3) 7 | 8 | 9 | def npss(gt_seq, pred_seq): 10 | # Fourier coefficients along the time dimension 11 | gt_fourier_coeffs = torch.real(torch.fft.fft(gt_seq, dim=1)) 12 | pred_fourier_coeffs = torch.real(torch.fft.fft(pred_seq, dim=1)) 13 | 14 | # Square of the Fourier coefficients 15 | gt_power = torch.square(gt_fourier_coeffs) 16 | pred_power = torch.square(pred_fourier_coeffs) 17 | 18 | # Sum of powers over time dimension 19 | gt_total_power = torch.sum(gt_power, dim=1) 20 | pred_total_power = torch.sum(pred_power, dim=1) 21 | 22 | # Normalize powers with totals 23 | gt_norm_power = gt_power / gt_total_power.unsqueeze(1) 24 | pred_norm_power = pred_power / pred_total_power.unsqueeze(1) 25 | 26 | # Cumulative sum over time 27 | cdf_gt_power = torch.cumsum(gt_norm_power, dim=1) 28 | cdf_pred_power = torch.cumsum(pred_norm_power, dim=1) 29 | 30 | # Earth mover distance 31 | emd = torch.norm((cdf_pred_power - cdf_gt_power), p=1, dim=1) 32 | 33 | # Weighted EMD 34 | power_weighted_emd = torch.sum(emd * gt_total_power) / torch.sum(gt_total_power) 35 | 36 | return power_weighted_emd 37 | 38 | 39 | def fast_npss(gt_seq, pred_seq): 40 | """ 41 | Computes Normalized Power Spectrum Similarity (NPSS). 42 | 43 | This is the metric proposed by Gropalakrishnan et al (2019). 44 | This implementation uses numpy parallelism for improved performance. 
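
    Sketch of the computation (paraphrasing the code below): with
    P = real(FFT(seq, axis=time))^2 per feature,
    P_norm = P / sum_t(P),
    EMD = || cumsum_t(P_norm_pred) - cumsum_t(P_norm_gt) ||_1, and
    NPSS = sum(EMD * sum_t(P_gt)) / sum(sum_t(P_gt)), summed over batch and features.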
45 | 46 | :param gt_seq: ground-truth array of shape : (Batchsize, Timesteps, Dimension) 47 | :param pred_seq: shape : (Batchsize, Timesteps, Dimension) 48 | :return: The average npss metric for the batch 49 | """ 50 | # Fourier coefficients along the time dimension 51 | gt_fourier_coeffs = np.real(np.fft.fft(gt_seq, axis=1)) 52 | pred_fourier_coeffs = np.real(np.fft.fft(pred_seq, axis=1)) 53 | 54 | # Square of the Fourier coefficients 55 | gt_power = np.square(gt_fourier_coeffs) 56 | pred_power = np.square(pred_fourier_coeffs) 57 | 58 | # Sum of powers over time dimension 59 | gt_total_power = np.sum(gt_power, axis=1) 60 | pred_total_power = np.sum(pred_power, axis=1) 61 | 62 | # Normalize powers with totals 63 | gt_norm_power = gt_power / gt_total_power[:, np.newaxis, :] 64 | pred_norm_power = pred_power / pred_total_power[:, np.newaxis, :] 65 | 66 | # Cumulative sum over time 67 | cdf_gt_power = np.cumsum(gt_norm_power, axis=1) 68 | cdf_pred_power = np.cumsum(pred_norm_power, axis=1) 69 | 70 | # Earth mover distance 71 | emd = np.linalg.norm((cdf_pred_power - cdf_gt_power), ord=1, axis=1) 72 | 73 | # Weighted EMD 74 | power_weighted_emd = np.average(emd, weights=gt_total_power) 75 | 76 | return power_weighted_emd 77 | 78 | 79 | def flatjoints(x): 80 | """ 81 | Shorthand for a common reshape pattern. Collapses all but the two first dimensions of a tensor. 82 | :param x: Data tensor of at least 3 dimensions. 83 | :return: The flattened tensor. 84 | """ 85 | return x.reshape((x.shape[0], x.shape[1], -1)) 86 | 87 | 88 | def benchmark_interpolation( 89 | X, Q, x_mean, x_std, offsets, parents, out_path=None, n_past=10, n_future=10 90 | ): 91 | """ 92 | Evaluate naive baselines (zero-velocity and interpolation) for transition generation on given data. 93 | :param X: Local positions array of shape (Batchsize, Timesteps, Joints, 3) 94 | :param Q: Local quaternions array of shape (B, T, J, 4) 95 | :param x_mean : Mean vector of local positions of shape (1, J*3, 1) 96 | :param out_path: Standard deviation vector of local positions (1, J*3, 1) 97 | :param offsets: Local bone offsets tensor of shape (1, 1, J, 3) 98 | :param parents: List of bone parents indices defining the hierarchy 99 | :param out_path: optional path for saving the results 100 | :param n_past: Number of frames used as past context 101 | :param n_future: Number of frames used as future context (only the first frame is used as the target) 102 | :return: Results dictionary 103 | """ 104 | 105 | trans_lengths = [5, 15, 30, 45] 106 | n_joints = 22 107 | res = {} 108 | 109 | for n_trans in trans_lengths: 110 | print("Computing errors for transition length = {}...".format(n_trans)) 111 | 112 | # Format the data for the current transition lengths. The number of samples and the offset stays unchanged. 113 | curr_window = n_trans + n_past + n_future 114 | curr_x = X[:, :curr_window, ...] 115 | curr_q = Q[:, :curr_window, ...] 116 | batchsize = curr_x.shape[0] 117 | 118 | # Ground-truth positions/quats/eulers 119 | gt_local_quats = curr_q 120 | gt_roots = curr_x[:, :, 0:1, :] 121 | gt_offsets = np.tile(offsets, [batchsize, curr_window, 1, 1]) 122 | gt_local_poses = np.concatenate([gt_roots, gt_offsets], axis=2) 123 | trans_gt_local_poses = gt_local_poses[:, n_past:-n_future, ...] 124 | trans_gt_local_quats = gt_local_quats[:, n_past:-n_future, ...] 
125 | # Local to global with Forward Kinematics (FK) 126 | trans_gt_global_quats, trans_gt_global_poses = utils.quat_fk( 127 | trans_gt_local_quats, trans_gt_local_poses, parents 128 | ) 129 | trans_gt_global_poses = trans_gt_global_poses.reshape( 130 | (trans_gt_global_poses.shape[0], -1, n_joints * 3) 131 | ).transpose([0, 2, 1]) 132 | # Normalize 133 | trans_gt_global_poses = (trans_gt_global_poses - x_mean) / x_std 134 | 135 | # Zero-velocity pos/quats 136 | zerov_trans_local_quats, zerov_trans_local_poses = ( 137 | np.zeros_like(trans_gt_local_quats), 138 | np.zeros_like(trans_gt_local_poses), 139 | ) 140 | zerov_trans_local_quats[:, :, :, :] = gt_local_quats[ 141 | :, n_past - 1 : n_past, :, : 142 | ] 143 | zerov_trans_local_poses[:, :, :, :] = gt_local_poses[ 144 | :, n_past - 1 : n_past, :, : 145 | ] 146 | # To global 147 | trans_zerov_global_quats, trans_zerov_global_poses = utils.quat_fk( 148 | zerov_trans_local_quats, zerov_trans_local_poses, parents 149 | ) 150 | trans_zerov_global_poses = trans_zerov_global_poses.reshape( 151 | (trans_zerov_global_poses.shape[0], -1, n_joints * 3) 152 | ).transpose([0, 2, 1]) 153 | # Normalize 154 | trans_zerov_global_poses = (trans_zerov_global_poses - x_mean) / x_std 155 | 156 | # Interpolation pos/quats 157 | r, q = curr_x[:, :, 0:1], curr_q 158 | inter_root, inter_local_quats = utils.interpolate_local(r, q, n_past, n_future) 159 | trans_inter_root = inter_root[:, 1:-1, :, :] 160 | trans_inter_offsets = np.tile(offsets, [batchsize, n_trans, 1, 1]) 161 | trans_inter_local_poses = np.concatenate( 162 | [trans_inter_root, trans_inter_offsets], axis=2 163 | ) 164 | inter_local_quats = inter_local_quats[:, 1:-1, :, :] 165 | # To global 166 | trans_interp_global_quats, trans_interp_global_poses = utils.quat_fk( 167 | inter_local_quats, trans_inter_local_poses, parents 168 | ) 169 | trans_interp_global_poses = trans_interp_global_poses.reshape( 170 | (trans_interp_global_poses.shape[0], -1, n_joints * 3) 171 | ).transpose([0, 2, 1]) 172 | # Normalize 173 | trans_interp_global_poses = (trans_interp_global_poses - x_mean) / x_std 174 | 175 | # Local quaternion loss 176 | res[("zerov_quat_loss", n_trans)] = np.mean( 177 | np.sqrt( 178 | np.sum( 179 | (trans_zerov_global_quats - trans_gt_global_quats) ** 2.0, 180 | axis=(2, 3), 181 | ) 182 | ) 183 | ) 184 | res[("interp_quat_loss", n_trans)] = np.mean( 185 | np.sqrt( 186 | np.sum( 187 | (trans_interp_global_quats - trans_gt_global_quats) ** 2.0, 188 | axis=(2, 3), 189 | ) 190 | ) 191 | ) 192 | 193 | # Global positions loss 194 | res[("zerov_pos_loss", n_trans)] = np.mean( 195 | np.sqrt( 196 | np.sum( 197 | (trans_zerov_global_poses - trans_gt_global_poses) ** 2.0, axis=1 198 | ) 199 | ) 200 | ) 201 | res[("interp_pos_loss", n_trans)] = np.mean( 202 | np.sqrt( 203 | np.sum( 204 | (trans_interp_global_poses - trans_gt_global_poses) ** 2.0, axis=1 205 | ) 206 | ) 207 | ) 208 | 209 | # NPSS loss on global quaternions 210 | res[("zerov_npss_loss", n_trans)] = fast_npss( 211 | flatjoints(trans_gt_global_quats), flatjoints(trans_zerov_global_quats) 212 | ) 213 | res[("interp_npss_loss", n_trans)] = fast_npss( 214 | flatjoints(trans_gt_global_quats), flatjoints(trans_interp_global_quats) 215 | ) 216 | 217 | print() 218 | avg_zerov_quat_losses = [res[("zerov_quat_loss", n)] for n in trans_lengths] 219 | avg_interp_quat_losses = [res[("interp_quat_loss", n)] for n in trans_lengths] 220 | print("=== Global quat losses ===") 221 | print( 222 | "{0: <16} | {1:6d} | {2:6d} | {3:6d} | {4:6d}".format("Lengths", 5, 
15, 30, 45) 223 | ) 224 | print( 225 | "{0: <16} | {1:6.2f} | {2:6.2f} | {3:6.2f} | {4:6.2f}".format( 226 | "Zero-V", *avg_zerov_quat_losses 227 | ) 228 | ) 229 | print( 230 | "{0: <16} | {1:6.2f} | {2:6.2f} | {3:6.2f} | {4:6.2f}".format( 231 | "Interp.", *avg_interp_quat_losses 232 | ) 233 | ) 234 | print() 235 | 236 | avg_zerov_pos_losses = [res[("zerov_pos_loss", n)] for n in trans_lengths] 237 | avg_interp_pos_losses = [res[("interp_pos_loss", n)] for n in trans_lengths] 238 | print("=== Global pos losses ===") 239 | print( 240 | "{0: <16} | {1:6d} | {2:6d} | {3:6d} | {4:6d}".format("Lengths", 5, 15, 30, 45) 241 | ) 242 | print( 243 | "{0: <16} | {1:6.3f} | {2:6.3f} | {3:6.3f} | {4:6.3f}".format( 244 | "Zero-V", *avg_zerov_pos_losses 245 | ) 246 | ) 247 | print( 248 | "{0: <16} | {1:6.3f} | {2:6.3f} | {3:6.3f} | {4:6.3f}".format( 249 | "Interp.", *avg_interp_pos_losses 250 | ) 251 | ) 252 | print() 253 | 254 | avg_zerov_npss_losses = [res[("zerov_npss_loss", n)] for n in trans_lengths] 255 | avg_interp_npss_losses = [res[("interp_npss_loss", n)] for n in trans_lengths] 256 | print("=== NPSS on global quats ===") 257 | print( 258 | "{0: <16} | {1:5d} | {2:5d} | {3:5d} | {4:5d}".format( 259 | "Lengths", 5, 15, 30, 45 260 | ) 261 | ) 262 | print( 263 | "{0: <16} | {1:5.4f} | {2:5.4f} | {3:5.4f} | {4:5.4f}".format( 264 | "Zero-V", *avg_zerov_npss_losses 265 | ) 266 | ) 267 | print( 268 | "{0: <16} | {1:5.4f} | {2:5.4f} | {3:5.4f} | {4:5.4f}".format( 269 | "Interp.", *avg_interp_npss_losses 270 | ) 271 | ) 272 | print() 273 | 274 | # Write to file is desired 275 | if out_path is not None: 276 | res_txt_file = open( 277 | os.path.join(out_path, "h36m_transitions_benchmark.txt"), "a" 278 | ) 279 | res_txt_file.write("\n=== Global quat losses ===\n") 280 | res_txt_file.write( 281 | "{0: <16} | {1:6d} | {2:6d} | {3:6d} | {4:6d}\n".format( 282 | "Lengths", 5, 15, 30, 45 283 | ) 284 | ) 285 | res_txt_file.write( 286 | "{0: <16} | {1:6.2f} | {2:6.2f} | {3:6.2f} | {4:6.2f}\n".format( 287 | "Zero-V", *avg_zerov_quat_losses 288 | ) 289 | ) 290 | res_txt_file.write( 291 | "{0: <16} | {1:6.2f} | {2:6.2f} | {3:6.2f} | {4:6.2f}\n".format( 292 | "Interp.", *avg_interp_quat_losses 293 | ) 294 | ) 295 | res_txt_file.write("\n\n") 296 | res_txt_file.write("=== Global pos losses ===\n") 297 | res_txt_file.write( 298 | "{0: <16} | {1:5d} | {2:5d} | {3:5d} | {4:5d}\n".format( 299 | "Lengths", 5, 15, 30, 45 300 | ) 301 | ) 302 | res_txt_file.write( 303 | "{0: <16} | {1:5.4f} | {2:5.4f} | {3:5.4f} | {4:5.4f}\n".format( 304 | "Zero-V", *avg_zerov_pos_losses 305 | ) 306 | ) 307 | res_txt_file.write( 308 | "{0: <16} | {1:5.4f} | {2:5.4f} | {3:5.4f} | {4:5.4f}\n".format( 309 | "Interp.", *avg_interp_pos_losses 310 | ) 311 | ) 312 | res_txt_file.write("\n\n") 313 | res_txt_file.write("=== NPSS on global quats ===\n") 314 | res_txt_file.write( 315 | "{0: <16} | {1:5d} | {2:5d} | {3:5d} | {4:5d}\n".format( 316 | "Lengths", 5, 15, 30, 45 317 | ) 318 | ) 319 | res_txt_file.write( 320 | "{0: <16} | {1:5.4f} | {2:5.4f} | {3:5.4f} | {4:5.4f}\n".format( 321 | "Zero-V", *avg_zerov_npss_losses 322 | ) 323 | ) 324 | res_txt_file.write( 325 | "{0: <16} | {1:5.4f} | {2:5.4f} | {3:5.4f} | {4:5.4f}\n".format( 326 | "Interp.", *avg_interp_npss_losses 327 | ) 328 | ) 329 | res_txt_file.write("\n\n\n\n") 330 | res_txt_file.close() 331 | 332 | return res 333 | -------------------------------------------------------------------------------- /cmib/lafan1/extract.py: 
-------------------------------------------------------------------------------- 1 | import re, os, ntpath 2 | import numpy as np 3 | from . import utils 4 | 5 | channelmap = {"Xrotation": "x", "Yrotation": "y", "Zrotation": "z"} 6 | 7 | channelmap_inv = { 8 | "x": "Xrotation", 9 | "y": "Yrotation", 10 | "z": "Zrotation", 11 | } 12 | 13 | ordermap = { 14 | "x": 0, 15 | "y": 1, 16 | "z": 2, 17 | } 18 | 19 | 20 | class Anim(object): 21 | """ 22 | A very basic animation object 23 | """ 24 | 25 | def __init__(self, quats, pos, offsets, parents, bones): 26 | """ 27 | :param quats: local quaternions tensor 28 | :param pos: local positions tensor 29 | :param offsets: local joint offsets 30 | :param parents: bone hierarchy 31 | :param bones: bone names 32 | """ 33 | self.quats = quats 34 | self.pos = pos 35 | self.offsets = offsets 36 | self.parents = parents 37 | self.bones = bones 38 | 39 | 40 | def read_bvh(filename, start=None, end=None, order=None): 41 | """ 42 | Reads a BVH file and extracts animation information. 43 | 44 | :param filename: BVh filename 45 | :param start: start frame 46 | :param end: end frame 47 | :param order: order of euler rotations 48 | :return: A simple Anim object conatining the extracted information. 49 | """ 50 | 51 | f = open(filename, "r") 52 | 53 | i = 0 54 | active = -1 55 | end_site = False 56 | 57 | names = [] 58 | orients = np.array([]).reshape((0, 4)) 59 | offsets = np.array([]).reshape((0, 3)) 60 | parents = np.array([], dtype=int) 61 | 62 | # Parse the file, line by line 63 | for line in f: 64 | 65 | if "HIERARCHY" in line: 66 | continue 67 | if "MOTION" in line: 68 | continue 69 | 70 | rmatch = re.match(r"ROOT (\w+)", line) 71 | if rmatch: 72 | names.append(rmatch.group(1)) 73 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 74 | orients = np.append(orients, np.array([[1, 0, 0, 0]]), axis=0) 75 | parents = np.append(parents, active) 76 | active = len(parents) - 1 77 | continue 78 | 79 | if "{" in line: 80 | continue 81 | 82 | if "}" in line: 83 | if end_site: 84 | end_site = False 85 | else: 86 | active = parents[active] 87 | continue 88 | 89 | offmatch = re.match( 90 | r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line 91 | ) 92 | if offmatch: 93 | if not end_site: 94 | offsets[active] = np.array([list(map(float, offmatch.groups()))]) 95 | continue 96 | 97 | chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line) 98 | if chanmatch: 99 | channels = int(chanmatch.group(1)) 100 | if order is None: 101 | channelis = 0 if channels == 3 else 3 102 | channelie = 3 if channels == 3 else 6 103 | parts = line.split()[2 + channelis : 2 + channelie] 104 | if any([p not in channelmap for p in parts]): 105 | continue 106 | order = "".join([channelmap[p] for p in parts]) 107 | continue 108 | 109 | jmatch = re.match("\s*JOINT\s+(\w+)", line) 110 | if jmatch: 111 | names.append(jmatch.group(1)) 112 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0) 113 | orients = np.append(orients, np.array([[1, 0, 0, 0]]), axis=0) 114 | parents = np.append(parents, active) 115 | active = len(parents) - 1 116 | continue 117 | 118 | if "End Site" in line: 119 | end_site = True 120 | continue 121 | 122 | fmatch = re.match("\s*Frames:\s+(\d+)", line) 123 | if fmatch: 124 | if start and end: 125 | fnum = (end - start) - 1 126 | else: 127 | fnum = int(fmatch.group(1)) 128 | positions = offsets[np.newaxis].repeat(fnum, axis=0) 129 | rotations = np.zeros((fnum, len(orients), 3)) 130 | continue 131 | 132 | fmatch = re.match("\s*Frame Time:\s+([\d\.]+)", line) 133 
| if fmatch: 134 | frametime = float(fmatch.group(1)) 135 | continue 136 | 137 | if (start and end) and (i < start or i >= end - 1): 138 | i += 1 139 | continue 140 | 141 | dmatch = line.strip().split(" ") 142 | if dmatch: 143 | data_block = np.array(list(map(float, dmatch))) 144 | N = len(parents) 145 | fi = i - start if start else i 146 | if channels == 3: 147 | positions[fi, 0:1] = data_block[0:3] 148 | rotations[fi, :] = data_block[3:].reshape(N, 3) 149 | elif channels == 6: 150 | data_block = data_block.reshape(N, 6) 151 | positions[fi, :] = data_block[:, 0:3] 152 | rotations[fi, :] = data_block[:, 3:6] 153 | elif channels == 9: 154 | positions[fi, 0] = data_block[0:3] 155 | data_block = data_block[3:].reshape(N - 1, 9) 156 | rotations[fi, 1:] = data_block[:, 3:6] 157 | positions[fi, 1:] += data_block[:, 0:3] * data_block[:, 6:9] 158 | else: 159 | raise Exception("Too many channels! %i" % channels) 160 | 161 | i += 1 162 | 163 | f.close() 164 | 165 | rotations = utils.euler_to_quat(np.radians(rotations), order=order) 166 | rotations = utils.remove_quat_discontinuities(rotations) 167 | 168 | return Anim(rotations, positions, offsets, parents, names) 169 | 170 | 171 | def get_lafan1_set(bvh_path, actors, window=50, offset=20, train=True, stats=False, datset='LAFAN'): 172 | """ 173 | Extract the same test set as in the article, given the location of the BVH files. 174 | 175 | :param bvh_path: Path to the dataset BVH files 176 | :param list: actor prefixes to use in set 177 | :param window: width of the sliding windows (in timesteps) 178 | :param offset: offset between windows (in timesteps) 179 | :return: tuple: 180 | X: local positions 181 | Q: local quaternions 182 | parents: list of parent indices defining the bone hierarchy 183 | contacts_l: binary tensor of left-foot contacts of shape (Batchsize, Timesteps, 2) 184 | contacts_r: binary tensor of right-foot contacts of shape (Batchsize, Timesteps, 2) 185 | """ 186 | npast = 10 187 | subjects = [] 188 | seq_names = [] 189 | X = [] 190 | Q = [] 191 | contacts_l = [] 192 | contacts_r = [] 193 | 194 | # Extract 195 | bvh_files = sorted(os.listdir(bvh_path)) 196 | 197 | for file in bvh_files: 198 | if file.endswith(".bvh"): 199 | file_info = ntpath.basename(file[:-4]).split("_") 200 | seq_name = file_info[0] 201 | subject = file_info[1] 202 | 203 | if (not train) and (file_info[-1] == "LRflip"): 204 | continue 205 | 206 | if stats and (file_info[-1] == "LRflip"): 207 | continue 208 | 209 | # seq_name, subject = ntpath.basename(file[:-4]).split("_") 210 | if subject in actors: 211 | print("Processing file {}".format(file)) 212 | seq_path = os.path.join(bvh_path, file) 213 | anim = read_bvh(seq_path) 214 | 215 | # Sliding windows 216 | i = 0 217 | while i + window < anim.pos.shape[0]: 218 | q, x = utils.quat_fk( 219 | anim.quats[i : i + window], 220 | anim.pos[i : i + window], 221 | anim.parents, 222 | ) 223 | # Extract contacts 224 | c_l, c_r = utils.extract_feet_contacts( 225 | x, [3, 4], [7, 8], velfactor=0.02 226 | ) 227 | X.append(anim.pos[i : i + window]) 228 | Q.append(anim.quats[i : i + window]) 229 | seq_names.append(seq_name) 230 | subjects.append(subjects) 231 | contacts_l.append(c_l) 232 | contacts_r.append(c_r) 233 | 234 | i += offset 235 | 236 | X = np.asarray(X) 237 | Q = np.asarray(Q) 238 | contacts_l = np.asarray(contacts_l) 239 | contacts_r = np.asarray(contacts_r) 240 | 241 | # Sequences around XZ = 0 242 | xzs = np.mean(X[:, :, 0, ::2], axis=1, keepdims=True) 243 | X[:, :, 0, 0] = X[:, :, 0, 0] - xzs[..., 0] 244 | X[:, :, 
0, 2] = X[:, :, 0, 2] - xzs[..., 1] 245 | 246 | # Unify facing on last seed frame 247 | X, Q = utils.rotate_at_frame(X, Q, anim.parents, n_past=npast) 248 | 249 | return X, Q, anim.parents, contacts_l, contacts_r, seq_names 250 | 251 | 252 | def get_train_stats(bvh_folder, train_set): 253 | """ 254 | Extract the same training set as in the paper in order to compute the normalizing statistics 255 | :return: Tuple of (local position mean vector, local position standard deviation vector, local joint offsets tensor) 256 | """ 257 | print("Building the train set...") 258 | xtrain, qtrain, parents, _, _, _ = get_lafan1_set( 259 | bvh_folder, train_set, window=50, offset=20, train=True, stats=True 260 | ) 261 | 262 | print("Computing stats...\n") 263 | # Joint offsets : are constant, so just take the first frame: 264 | offsets = xtrain[0:1, 0:1, 1:, :] # Shape : (1, 1, J, 3) 265 | 266 | # Global representation: 267 | q_glbl, x_glbl = utils.quat_fk(qtrain, xtrain, parents) 268 | 269 | # Global positions stats: 270 | x_mean = np.mean( 271 | x_glbl.reshape([x_glbl.shape[0], x_glbl.shape[1], -1]).transpose([0, 2, 1]), 272 | axis=(0, 2), 273 | keepdims=True, 274 | ) 275 | x_std = np.std( 276 | x_glbl.reshape([x_glbl.shape[0], x_glbl.shape[1], -1]).transpose([0, 2, 1]), 277 | axis=(0, 2), 278 | keepdims=True, 279 | ) 280 | 281 | return x_mean, x_std, offsets 282 | -------------------------------------------------------------------------------- /cmib/lafan1/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def length(x, axis=-1, keepdims=True): 5 | """ 6 | Computes vector norm along a tensor axis(axes) 7 | 8 | :param x: tensor 9 | :param axis: axis(axes) along which to compute the norm 10 | :param keepdims: indicates if the dimension(s) on axis should be kept 11 | :return: The length or vector of lengths. 
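
    Illustrative example (values assumed): length(np.array([[3.0, 4.0, 0.0]]))
    returns array([[5.]]), since keepdims=True preserves the reduced axis.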
12 | """ 13 | lgth = np.sqrt(np.sum(x * x, axis=axis, keepdims=keepdims)) 14 | return lgth 15 | 16 | 17 | def normalize(x, axis=-1, eps=1e-8): 18 | """ 19 | Normalizes a tensor over some axis (axes) 20 | 21 | :param x: data tensor 22 | :param axis: axis(axes) along which to compute the norm 23 | :param eps: epsilon to prevent numerical instabilities 24 | :return: The normalized tensor 25 | """ 26 | res = x / (length(x, axis=axis) + eps) 27 | return res 28 | 29 | 30 | def quat_normalize(x, eps=1e-8): 31 | """ 32 | Normalizes a quaternion tensor 33 | 34 | :param x: data tensor 35 | :param eps: epsilon to prevent numerical instabilities 36 | :return: The normalized quaternions tensor 37 | """ 38 | res = normalize(x, eps=eps) 39 | return res 40 | 41 | 42 | def angle_axis_to_quat(angle, axis): 43 | """ 44 | Converts from and angle-axis representation to a quaternion representation 45 | 46 | :param angle: angles tensor 47 | :param axis: axis tensor 48 | :return: quaternion tensor 49 | """ 50 | c = np.cos(angle / 2.0)[..., np.newaxis] 51 | s = np.sin(angle / 2.0)[..., np.newaxis] 52 | q = np.concatenate([c, s * axis], axis=-1) 53 | return q 54 | 55 | 56 | def euler_to_quat(e, order="zyx"): 57 | """ 58 | 59 | Converts from an euler representation to a quaternion representation 60 | 61 | :param e: euler tensor 62 | :param order: order of euler rotations 63 | :return: quaternion tensor 64 | """ 65 | axis = { 66 | "x": np.asarray([1, 0, 0], dtype=np.float32), 67 | "y": np.asarray([0, 1, 0], dtype=np.float32), 68 | "z": np.asarray([0, 0, 1], dtype=np.float32), 69 | } 70 | 71 | q0 = angle_axis_to_quat(e[..., 0], axis[order[0]]) 72 | q1 = angle_axis_to_quat(e[..., 1], axis[order[1]]) 73 | q2 = angle_axis_to_quat(e[..., 2], axis[order[2]]) 74 | 75 | return quat_mul(q0, quat_mul(q1, q2)) 76 | 77 | 78 | def quat_inv(q): 79 | """ 80 | Inverts a tensor of quaternions 81 | 82 | :param q: quaternion tensor 83 | :return: tensor of inverted quaternions 84 | """ 85 | res = np.asarray([1, -1, -1, -1], dtype=np.float32) * q 86 | return res 87 | 88 | 89 | def quat_fk(lrot, lpos, parents): 90 | """ 91 | Performs Forward Kinematics (FK) on local quaternions and local positions to retrieve global representations 92 | 93 | :param lrot: tensor of local quaternions with shape (..., Nb of joints, 4) 94 | :param lpos: tensor of local positions with shape (..., Nb of joints, 3) 95 | :param parents: list of parents indices 96 | :return: tuple of tensors of global quaternion, global positions 97 | """ 98 | gp, gr = [lpos[..., :1, :]], [lrot[..., :1, :]] 99 | for i in range(1, len(parents)): 100 | gp.append( 101 | quat_mul_vec(gr[parents[i]], lpos[..., i : i + 1, :]) + gp[parents[i]] 102 | ) 103 | gr.append(quat_mul(gr[parents[i]], lrot[..., i : i + 1, :])) 104 | 105 | res = np.concatenate(gr, axis=-2), np.concatenate(gp, axis=-2) 106 | return res 107 | 108 | 109 | def quat_ik(grot, gpos, parents): 110 | """ 111 | Performs Inverse Kinematics (IK) on global quaternions and global positions to retrieve local representations 112 | 113 | :param grot: tensor of global quaternions with shape (..., Nb of joints, 4) 114 | :param gpos: tensor of global positions with shape (..., Nb of joints, 3) 115 | :param parents: list of parents indices 116 | :return: tuple of tensors of local quaternion, local positions 117 | """ 118 | res = [ 119 | np.concatenate( 120 | [ 121 | grot[..., :1, :], 122 | quat_mul(quat_inv(grot[..., parents[1:], :]), grot[..., 1:, :]), 123 | ], 124 | axis=-2, 125 | ), 126 | np.concatenate( 127 | [ 128 | gpos[..., :1, 
:], 129 | quat_mul_vec( 130 | quat_inv(grot[..., parents[1:], :]), 131 | gpos[..., 1:, :] - gpos[..., parents[1:], :], 132 | ), 133 | ], 134 | axis=-2, 135 | ), 136 | ] 137 | 138 | return res 139 | 140 | 141 | def quat_mul(x, y): 142 | """ 143 | Performs quaternion multiplication on arrays of quaternions 144 | 145 | :param x: tensor of quaternions of shape (..., Nb of joints, 4) 146 | :param y: tensor of quaternions of shape (..., Nb of joints, 4) 147 | :return: The resulting quaternions 148 | """ 149 | x0, x1, x2, x3 = x[..., 0:1], x[..., 1:2], x[..., 2:3], x[..., 3:4] 150 | y0, y1, y2, y3 = y[..., 0:1], y[..., 1:2], y[..., 2:3], y[..., 3:4] 151 | 152 | res = np.concatenate( 153 | [ 154 | y0 * x0 - y1 * x1 - y2 * x2 - y3 * x3, 155 | y0 * x1 + y1 * x0 - y2 * x3 + y3 * x2, 156 | y0 * x2 + y1 * x3 + y2 * x0 - y3 * x1, 157 | y0 * x3 - y1 * x2 + y2 * x1 + y3 * x0, 158 | ], 159 | axis=-1, 160 | ) 161 | 162 | return res 163 | 164 | 165 | def quat_mul_vec(q, x): 166 | """ 167 | Performs multiplication of an array of 3D vectors by an array of quaternions (rotation). 168 | 169 | :param q: tensor of quaternions of shape (..., Nb of joints, 4) 170 | :param x: tensor of vectors of shape (..., Nb of joints, 3) 171 | :return: the resulting array of rotated vectors 172 | """ 173 | t = 2.0 * np.cross(q[..., 1:], x) 174 | res = x + q[..., 0][..., np.newaxis] * t + np.cross(q[..., 1:], t) 175 | 176 | return res 177 | 178 | 179 | def quat_slerp(x, y, a): 180 | """ 181 | Perfroms spherical linear interpolation (SLERP) between x and y, with proportion a 182 | 183 | :param x: quaternion tensor 184 | :param y: quaternion tensor 185 | :param a: indicator (between 0 and 1) of completion of the interpolation. 186 | :return: tensor of interpolation results 187 | """ 188 | len = np.sum(x * y, axis=-1) 189 | 190 | neg = len < 0.0 191 | len[neg] = -len[neg] 192 | y[neg] = -y[neg] 193 | 194 | a = np.zeros_like(x[..., 0]) + a 195 | amount0 = np.zeros(a.shape) 196 | amount1 = np.zeros(a.shape) 197 | 198 | linear = (1.0 - len) < 0.01 199 | omegas = np.arccos(len[~linear]) 200 | sinoms = np.sin(omegas) 201 | 202 | amount0[linear] = 1.0 - a[linear] 203 | amount0[~linear] = np.sin((1.0 - a[~linear]) * omegas) / sinoms 204 | 205 | amount1[linear] = a[linear] 206 | amount1[~linear] = np.sin(a[~linear] * omegas) / sinoms 207 | res = amount0[..., np.newaxis] * x + amount1[..., np.newaxis] * y 208 | 209 | return res 210 | 211 | 212 | def quat_between(x, y): 213 | """ 214 | Quaternion rotations between two 3D-vector arrays 215 | 216 | :param x: tensor of 3D vectors 217 | :param y: tensor of 3D vetcors 218 | :return: tensor of quaternions 219 | """ 220 | res = np.concatenate( 221 | [ 222 | np.sqrt(np.sum(x * x, axis=-1) * np.sum(y * y, axis=-1))[..., np.newaxis] 223 | + np.sum(x * y, axis=-1)[..., np.newaxis], 224 | np.cross(x, y), 225 | ], 226 | axis=-1, 227 | ) 228 | return res 229 | 230 | 231 | def interpolate_local(lcl_r_mb, lcl_q_mb, n_past, n_future): 232 | """ 233 | Performs interpolation between 2 frames of an animation sequence. 234 | 235 | The 2 frames are indirectly specified through n_past and n_future. 236 | SLERP is performed on the quaternions 237 | LERP is performed on the root's positions. 
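
    As implemented below, the interpolation weights are linspace(0, 1, n_trans + 2),
    so the first and last returned frames coincide with the two keyframes and only
    the n_trans frames in between are newly generated.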
238 | 239 | :param lcl_r_mb: Local/Global root positions (B, T, 1, 3) 240 | :param lcl_q_mb: Local quaternions (B, T, J, 4) 241 | :param n_past: Number of frames of past context 242 | :param n_future: Number of frames of future context 243 | :return: Interpolated root and quats 244 | """ 245 | # Extract last past frame and target frame 246 | start_lcl_r_mb = lcl_r_mb[:, n_past - 1, :, :][:, None, :, :] # (B, 1, J, 3) 247 | end_lcl_r_mb = lcl_r_mb[:, -n_future, :, :][:, None, :, :] 248 | 249 | start_lcl_q_mb = lcl_q_mb[:, n_past - 1, :, :] 250 | end_lcl_q_mb = lcl_q_mb[:, -n_future, :, :] 251 | 252 | # LERP Local Positions: 253 | n_trans = lcl_r_mb.shape[1] - (n_past + n_future) 254 | interp_ws = np.linspace(0.0, 1.0, num=n_trans + 2, dtype=np.float32) 255 | offset = end_lcl_r_mb - start_lcl_r_mb 256 | 257 | const_trans = np.tile(start_lcl_r_mb, [1, n_trans + 2, 1, 1]) 258 | inter_lcl_r_mb = const_trans + (interp_ws)[None, :, None, None] * offset 259 | 260 | # SLERP Local Quats: 261 | interp_ws = np.linspace(0.0, 1.0, num=n_trans + 2, dtype=np.float32) 262 | inter_lcl_q_mb = np.stack( 263 | [ 264 | ( 265 | quat_normalize( 266 | quat_slerp( 267 | quat_normalize(start_lcl_q_mb), quat_normalize(end_lcl_q_mb), w 268 | ) 269 | ) 270 | ) 271 | for w in interp_ws 272 | ], 273 | axis=1, 274 | ) 275 | 276 | return inter_lcl_r_mb, inter_lcl_q_mb 277 | 278 | 279 | def remove_quat_discontinuities(rotations): 280 | """ 281 | 282 | Removing quat discontinuities on the time dimension (removing flips) 283 | 284 | :param rotations: Array of quaternions of shape (T, J, 4) 285 | :return: The processed array without quaternion inversion. 286 | """ 287 | rots_inv = -rotations 288 | 289 | for i in range(1, rotations.shape[0]): 290 | # Compare dot products 291 | replace_mask = np.sum( 292 | rotations[i - 1 : i] * rotations[i : i + 1], axis=-1 293 | ) < np.sum(rotations[i - 1 : i] * rots_inv[i : i + 1], axis=-1) 294 | replace_mask = replace_mask[..., np.newaxis] 295 | rotations[i] = replace_mask * rots_inv[i] + (1.0 - replace_mask) * rotations[i] 296 | 297 | return rotations 298 | 299 | 300 | # Orient the data according to the las past keframe 301 | def rotate_at_frame(X, Q, parents, n_past=10): 302 | """ 303 | Re-orients the animation data according to the last frame of past context. 
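    In short (describing the code below): the root's facing direction at that frame is
    projected onto the XZ plane and the whole sequence is rotated about the vertical
    axis so that this direction aligns with X+, matching the "facing X+" convention
    noted in lafan1_dataset.py.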
304 | 305 | :param X: tensor of local positions of shape (Batchsize, Timesteps, Joints, 3) 306 | :param Q: tensor of local quaternions (Batchsize, Timesteps, Joints, 4) 307 | :param parents: list of parents' indices 308 | :param n_past: number of frames in the past context 309 | :return: The rotated positions X and quaternions Q 310 | """ 311 | # Get global quats and global poses (FK) 312 | global_q, global_x = quat_fk(Q, X, parents) 313 | 314 | key_glob_Q = global_q[:, n_past - 1 : n_past, 0:1, :] # (B, 1, 1, 4) 315 | forward = np.array([1, 0, 1])[np.newaxis, np.newaxis, np.newaxis, :] * quat_mul_vec( 316 | key_glob_Q, np.array([0, 1, 0])[np.newaxis, np.newaxis, np.newaxis, :] 317 | ) 318 | forward = normalize(forward) 319 | yrot = quat_normalize(quat_between(np.array([1, 0, 0]), forward)) 320 | new_glob_Q = quat_mul(quat_inv(yrot), global_q) 321 | new_glob_X = quat_mul_vec(quat_inv(yrot), global_x) 322 | 323 | # back to local quat-pos 324 | Q, X = quat_ik(new_glob_Q, new_glob_X, parents) 325 | 326 | return X, Q 327 | 328 | 329 | def extract_feet_contacts(pos, lfoot_idx, rfoot_idx, velfactor=0.02): 330 | """ 331 | Extracts binary tensors of feet contacts 332 | 333 | :param pos: tensor of global positions of shape (Timesteps, Joints, 3) 334 | :param lfoot_idx: indices list of left foot joints 335 | :param rfoot_idx: indices list of right foot joints 336 | :param velfactor: velocity threshold to consider a joint moving or not 337 | :return: binary tensors of left foot contacts and right foot contacts 338 | """ 339 | lfoot_xyz = (pos[1:, lfoot_idx, :] - pos[:-1, lfoot_idx, :]) ** 2 340 | contacts_l = np.sum(lfoot_xyz, axis=-1) < velfactor 341 | 342 | rfoot_xyz = (pos[1:, rfoot_idx, :] - pos[:-1, rfoot_idx, :]) ** 2 343 | contacts_r = np.sum(rfoot_xyz, axis=-1) < velfactor 344 | 345 | # Duplicate the last frame for shape consistency 346 | contacts_l = np.concatenate([contacts_l, contacts_l[-1:]], axis=0) 347 | contacts_r = np.concatenate([contacts_r, contacts_r[-1:]], axis=0) 348 | 349 | return contacts_l, contacts_r 350 | -------------------------------------------------------------------------------- /cmib/misc/sampler.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import imageio 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from PIL import Image 10 | from sklearn.preprocessing import LabelEncoder 11 | 12 | from cmib.data.lafan1_dataset import LAFAN1Dataset 13 | from cmib.data.utils import write_json 14 | from cmib.lafan1.utils import quat_ik 15 | from cmib.model.network import TransformerModel 16 | from cmib.model.preprocess import (lerp_input_repr, replace_constant, 17 | slerp_input_repr, vectorize_representation) 18 | from cmib.model.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets, joint_names, 19 | sk_parents) 20 | from cmib.vis.pose import plot_pose_with_stop 21 | 22 | 23 | def test(opt, device): 24 | 25 | save_dir = Path(os.path.join('runs', 'train', opt.exp_name)) 26 | wdir = save_dir / 'weights' 27 | weights = os.listdir(wdir) 28 | weights_paths = [wdir / weight for weight in weights] 29 | latest_weight = max(weights_paths , key = os.path.getctime) 30 | ckpt = torch.load(latest_weight, map_location=device) 31 | print(f"Loaded weight: {latest_weight}") 32 | 33 | # Load Skeleton 34 | skeleton_mocap = Skeleton(offsets=sk_offsets, parents=sk_parents, device=device) 35 | skeleton_mocap.remove_joints(sk_joints_to_remove) 36 | 37 | # Load LAFAN 
Dataset 38 | Path(opt.processed_data_dir).mkdir(parents=True, exist_ok=True) 39 | lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path, processed_data_dir=opt.processed_data_dir, train=False, device=device) 40 | total_data = lafan_dataset.data['global_pos'].shape[0] 41 | 42 | # Replace with noise to In-betweening Frames 43 | from_idx, target_idx = ckpt['from_idx'], ckpt['target_idx'] # default: 9-40, max: 48 44 | horizon = ckpt['horizon'] 45 | print(f"HORIZON: {horizon}") 46 | 47 | test_idx = [] 48 | for i in range(total_data): 49 | test_idx.append(i) 50 | 51 | # Compare Input data, Prediction, GT 52 | save_path = os.path.join(opt.save_path, 'sampler') 53 | for i in range(len(test_idx)): 54 | Path(save_path).mkdir(parents=True, exist_ok=True) 55 | 56 | start_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx] 57 | target_pose = lafan_dataset.data['global_pos'][test_idx[i], target_idx] 58 | gt_stopover_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx] 59 | 60 | gt_img_path = os.path.join(save_path) 61 | plot_pose_with_stop(start_pose, target_pose, target_pose, gt_stopover_pose, i, skeleton_mocap, save_dir=gt_img_path, prefix='gt') 62 | print(f"ID {test_idx[i]}: completed.") 63 | 64 | def parse_opt(): 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--project', default='runs/train', help='project/name') 67 | parser.add_argument('--exp_name', default='slerp_40', help='experiment name') 68 | parser.add_argument('--data_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH', help='BVH dataset path') 69 | parser.add_argument('--skeleton_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH/walk1_subject1.bvh', help='path to reference skeleton') 70 | parser.add_argument('--processed_data_dir', type=str, default='processed_data_original/', help='path to save pickled processed data') 71 | parser.add_argument('--save_path', type=str, default='runs/test', help='path to save model') 72 | parser.add_argument('--motion_type', type=str, default='jumps', help='motion type') 73 | opt = parser.parse_args() 74 | return opt 75 | 76 | if __name__ == "__main__": 77 | opt = parse_opt() 78 | device = torch.device("cpu") 79 | test(opt, device) 80 | -------------------------------------------------------------------------------- /cmib/model/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 4 | from cmib.model.positional_encoding import PositionalEmbedding 5 | 6 | 7 | class TransformerModel(nn.Module): 8 | def __init__( 9 | self, 10 | seq_len: int, 11 | d_model: int, 12 | nhead: int, 13 | d_hid: int, 14 | nlayers: int, 15 | dropout: float = 0.5, 16 | out_dim=91, 17 | num_labels=15 18 | ): 19 | super().__init__() 20 | self.model_type = "Transformer" 21 | self.seq_len = seq_len 22 | self.d_model = d_model 23 | self.nhead = nhead 24 | self.d_hid = d_hid 25 | self.nlayers = nlayers 26 | self.cond_emb = nn.Embedding(num_labels, d_model) 27 | self.pos_embedding = PositionalEmbedding(seq_len=seq_len, d_model=d_model) 28 | encoder_layers = TransformerEncoderLayer( 29 | d_model, nhead, d_hid, dropout, activation="gelu" 30 | ) 31 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 32 | self.decoder = nn.Linear(d_model, out_dim) 33 | 34 | self.init_weights() 35 | 36 | def init_weights(self) -> None: 37 | initrange = 0.1 38 | self.decoder.bias.data.zero_() 39 | 
self.decoder.weight.data.uniform_(-initrange, initrange) 40 | 41 | def forward(self, src: Tensor, src_mask: Tensor, cond_code: Tensor) -> Tensor: 42 | """ 43 | Args: 44 | src: Tensor, shape [seq_len, batch_size, embedding_dim] 45 | src_mask: Tensor, shape [seq_len, seq_len] 46 | 47 | Returns: 48 | output Tensor of shape [seq_len, batch_size, embedding_dim] 49 | """ 50 | cond_embedding = self.cond_emb(cond_code).permute(1, 0, 2) 51 | output = self.pos_embedding(src) 52 | output = torch.cat([cond_embedding, output], dim=0) 53 | output = self.transformer_encoder(output, src_mask) 54 | output = self.decoder(output) 55 | return output, cond_embedding 56 | -------------------------------------------------------------------------------- /cmib/model/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | import math 4 | 5 | 6 | class PositionalEmbedding(nn.Module): 7 | def __init__(self, seq_len: int = 32, d_model: int = 96): 8 | super().__init__() 9 | self.pos_emb = nn.Embedding(seq_len + 1, d_model) 10 | 11 | def forward(self, inputs): 12 | positions = ( 13 | torch.arange(inputs.size(0), device=inputs.device) 14 | .expand(inputs.size(1), inputs.size(0)) 15 | .contiguous() 16 | + 1 17 | ) 18 | outputs = inputs + self.pos_emb(positions).permute(1, 0, 2) 19 | return outputs 20 | 21 | 22 | class PositionalEncoding(nn.Module): 23 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): 24 | super().__init__() 25 | self.dropout = nn.Dropout(p=dropout) 26 | 27 | position = torch.arange(max_len).unsqueeze(1) 28 | div_term = torch.exp( 29 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) 30 | ) 31 | pe = torch.zeros(max_len, 1, d_model) 32 | pe[:, 0, 0::2] = torch.sin(position * div_term) 33 | pe[:, 0, 1::2] = torch.cos(position * div_term) 34 | self.register_buffer("pe", pe) 35 | 36 | def forward(self, x: Tensor) -> Tensor: 37 | """ 38 | Args: 39 | x: Tensor, shape [seq_len, batch_size, embedding_dim] 40 | """ 41 | x = x + self.pe[: x.size(0)] 42 | return self.dropout(x) 43 | -------------------------------------------------------------------------------- /cmib/model/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def replace_constant(minibatch_pose_input, mask_start_frame): 5 | 6 | seq_len = minibatch_pose_input.size(1) 7 | interpolated = ( 8 | torch.ones_like(minibatch_pose_input, device=minibatch_pose_input.device) * 0.1 9 | ) 10 | 11 | if mask_start_frame == 0 or mask_start_frame == (seq_len - 1): 12 | interpolate_start = minibatch_pose_input[:, 0, :] 13 | interpolate_end = minibatch_pose_input[:, seq_len - 1, :] 14 | 15 | interpolated[:, 0, :] = interpolate_start 16 | interpolated[:, seq_len - 1, :] = interpolate_end 17 | 18 | assert torch.allclose(interpolated[:, 0, :], interpolate_start) 19 | assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end) 20 | 21 | else: 22 | interpolate_start1 = minibatch_pose_input[:, 0, :] 23 | interpolate_end1 = minibatch_pose_input[:, mask_start_frame, :] 24 | 25 | interpolate_start2 = minibatch_pose_input[:, mask_start_frame, :] 26 | interpolate_end2 = minibatch_pose_input[:, seq_len - 1, :] 27 | 28 | interpolated[:, 0, :] = interpolate_start1 29 | interpolated[:, mask_start_frame, :] = interpolate_end1 30 | 31 | interpolated[:, mask_start_frame, :] = interpolate_start2 32 | interpolated[:, seq_len - 1, :] = interpolate_end2 33 | 34 | 
assert torch.allclose(interpolated[:, 0, :], interpolate_start1) 35 | assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_end1) 36 | 37 | assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_start2) 38 | assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end2) 39 | return interpolated 40 | 41 | 42 | def slerp(x, y, a): 43 | """ 44 | Performs spherical linear interpolation (SLERP) between x and y, with proportion a 45 | 46 | :param x: quaternion tensor 47 | :param y: quaternion tensor 48 | :param a: indicator (between 0 and 1) of completion of the interpolation. 49 | :return: tensor of interpolation results 50 | """ 51 | device = x.device 52 | len = torch.sum(x * y, dim=-1) 53 | 54 | neg = len < 0.0 55 | len[neg] = -len[neg] 56 | y[neg] = -y[neg] 57 | 58 | a = torch.zeros_like(x[..., 0]) + a 59 | amount0 = torch.zeros(a.shape, device=device) 60 | amount1 = torch.zeros(a.shape, device=device) 61 | 62 | linear = (1.0 - len) < 0.01 63 | omegas = torch.arccos(len[~linear]) 64 | sinoms = torch.sin(omegas) 65 | 66 | amount0[linear] = 1.0 - a[linear] 67 | amount0[~linear] = torch.sin((1.0 - a[~linear]) * omegas) / sinoms 68 | 69 | amount1[linear] = a[linear] 70 | amount1[~linear] = torch.sin(a[~linear] * omegas) / sinoms 71 | # res = amount0[..., np.newaxis] * x + amount1[..., np.newaxis] * y 72 | res = amount0.unsqueeze(3) * x + amount1.unsqueeze(3) * y 73 | 74 | return res 75 | 76 | 77 | def slerp_input_repr(minibatch_pose_input, mask_start_frame): 78 | seq_len = minibatch_pose_input.size(1) 79 | minibatch_pose_input = minibatch_pose_input.reshape( 80 | minibatch_pose_input.size(0), seq_len, -1, 4 81 | ) 82 | interpolated = torch.zeros_like( 83 | minibatch_pose_input, device=minibatch_pose_input.device 84 | ) 85 | 86 | if mask_start_frame == 0 or mask_start_frame == (seq_len - 1): 87 | interpolate_start = minibatch_pose_input[:, 0:1] 88 | interpolate_end = minibatch_pose_input[:, seq_len - 1 :] 89 | 90 | for i in range(seq_len): 91 | dt = 1 / (seq_len - 1) 92 | interpolated[:, i : i + 1, :] = slerp( 93 | interpolate_start, interpolate_end, dt * i 94 | ) 95 | 96 | assert torch.allclose(interpolated[:, 0:1], interpolate_start) 97 | assert torch.allclose(interpolated[:, seq_len - 1 :], interpolate_end) 98 | else: 99 | interpolate_start1 = minibatch_pose_input[:, 0:1] 100 | interpolate_end1 = minibatch_pose_input[ 101 | :, mask_start_frame : mask_start_frame + 1 102 | ] 103 | 104 | interpolate_start2 = minibatch_pose_input[ 105 | :, mask_start_frame : mask_start_frame + 1 106 | ] 107 | interpolate_end2 = minibatch_pose_input[:, seq_len - 1 :] 108 | 109 | for i in range(mask_start_frame + 1): 110 | dt = 1 / mask_start_frame 111 | interpolated[:, i : i + 1, :] = slerp( 112 | interpolate_start1, interpolate_end1, dt * i 113 | ) 114 | 115 | assert torch.allclose(interpolated[:, 0:1], interpolate_start1) 116 | assert torch.allclose( 117 | interpolated[:, mask_start_frame : mask_start_frame + 1], interpolate_end1 118 | ) 119 | 120 | for i in range(mask_start_frame, seq_len): 121 | dt = 1 / (seq_len - mask_start_frame - 1) 122 | interpolated[:, i : i + 1, :] = slerp( 123 | interpolate_start2, interpolate_end2, dt * (i - mask_start_frame) 124 | ) 125 | 126 | assert torch.allclose( 127 | interpolated[:, mask_start_frame : mask_start_frame + 1], interpolate_start2 128 | ) 129 | assert torch.allclose(interpolated[:, seq_len - 1 :], interpolate_end2) 130 | 131 | interpolated = torch.nn.functional.normalize(interpolated, p=2.0, dim=3) 132 | return
interpolated.reshape(minibatch_pose_input.size(0), seq_len, -1) 133 | 134 | 135 | def lerp_input_repr(minibatch_pose_input, mask_start_frame): 136 | seq_len = minibatch_pose_input.size(1) 137 | interpolated = torch.zeros_like( 138 | minibatch_pose_input, device=minibatch_pose_input.device 139 | ) 140 | 141 | if mask_start_frame == 0 or mask_start_frame == (seq_len - 1): 142 | interpolate_start = minibatch_pose_input[:, 0, :] 143 | interpolate_end = minibatch_pose_input[:, seq_len - 1, :] 144 | 145 | for i in range(seq_len): 146 | dt = 1 / (seq_len - 1) 147 | interpolated[:, i, :] = torch.lerp( 148 | interpolate_start, interpolate_end, dt * i 149 | ) 150 | 151 | assert torch.allclose(interpolated[:, 0, :], interpolate_start) 152 | assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end) 153 | else: 154 | interpolate_start1 = minibatch_pose_input[:, 0, :] 155 | interpolate_end1 = minibatch_pose_input[:, mask_start_frame, :] 156 | 157 | interpolate_start2 = minibatch_pose_input[:, mask_start_frame, :] 158 | interpolate_end2 = minibatch_pose_input[:, -1, :] 159 | 160 | for i in range(mask_start_frame + 1): 161 | dt = 1 / mask_start_frame 162 | interpolated[:, i, :] = torch.lerp( 163 | interpolate_start1, interpolate_end1, dt * i 164 | ) 165 | 166 | assert torch.allclose(interpolated[:, 0, :], interpolate_start1) 167 | assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_end1) 168 | 169 | for i in range(mask_start_frame, seq_len): 170 | dt = 1 / (seq_len - mask_start_frame - 1) 171 | interpolated[:, i, :] = torch.lerp( 172 | interpolate_start2, interpolate_end2, dt * (i - mask_start_frame) 173 | ) 174 | 175 | assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_start2) 176 | assert torch.allclose(interpolated[:, -1, :], interpolate_end2) 177 | return interpolated 178 | 179 | 180 | def vectorize_representation(global_position, global_rotation): 181 | 182 | batch_size = global_position.shape[0] 183 | seq_len = global_position.shape[1] 184 | 185 | global_pos_vec = global_position.reshape(batch_size, seq_len, -1).contiguous() 186 | global_rot_vec = global_rotation.reshape(batch_size, seq_len, -1).contiguous() 187 | 188 | global_pose_vec_gt = torch.cat([global_pos_vec, global_rot_vec], dim=2) 189 | return global_pose_vec_gt 190 | -------------------------------------------------------------------------------- /cmib/model/skeleton.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from cmib.data.quaternion import qmul, qrot 4 | import torch.nn as nn 5 | 6 | amass_offsets = [ 7 | [0.0, 0.0, 0.0], 8 | 9 | [0.058581, -0.082280, -0.017664], 10 | [0.043451, -0.386469, 0.008037], 11 | [-0.014790, -0.426874, -0.037428], 12 | [0.041054, -0.060286, 0.122042], 13 | [0.0, 0.0, 0.0], 14 | 15 | [-0.060310, -0.090513, -0.013543], 16 | [-0.043257, -0.383688, -0.004843], 17 | [0.019056, -0.420046, -0.034562], 18 | [-0.034840, -0.062106, 0.130323], 19 | [0.0, 0.0, 0.0], 20 | 21 | [0.004439, 0.124404, -0.038385], 22 | [0.004488, 0.137956, 0.026820], 23 | [-0.002265, 0.056032, 0.002855], 24 | [-0.013390, 0.211636, -0.033468], 25 | [0.010113, 0.088937, 0.050410], 26 | [0.0, 0.0, 0.0], 27 | 28 | [0.071702, 0.114000, -0.018898], 29 | [0.122921, 0.045205, -0.019046], 30 | [0.255332, -0.015649, -0.022946], 31 | [0.265709, 0.012698, -0.007375], 32 | [0.0, 0.0, 0.0], 33 | 34 | [-0.082954, 0.112472, -0.023707], 35 | [-0.113228, 0.046853, -0.008472], 36 | [-0.260127, -0.014369, -0.031269], 37 | 
[-0.269108, 0.006794, -0.006027], 38 | [0.0, 0.0, 0.0] 39 | ] 40 | 41 | sk_offsets = [ 42 | [-42.198200, 91.614723, -40.067841], 43 | 44 | [0.103456, 1.857829, 10.548506], 45 | [43.499992, -0.000038, -0.000002], 46 | [42.372192, 0.000015, -0.000007], 47 | [17.299999, -0.000002, 0.000003], 48 | [0.000000, 0.000000, 0.000000], 49 | 50 | [0.103457, 1.857829, -10.548503], 51 | [43.500042, -0.000027, 0.000008], 52 | [42.372257, -0.000008, 0.000014], 53 | [17.299992, -0.000005, 0.000004], 54 | [0.000000, 0.000000, 0.000000], 55 | 56 | [6.901968, -2.603733, -0.000001], 57 | [12.588099, 0.000002, 0.000000], 58 | [12.343206, 0.000000, -0.000001], 59 | [25.832886, -0.000004, 0.000003], 60 | [11.766620, 0.000005, -0.000001], 61 | [0.000000, 0.000000, 0.000000], 62 | 63 | [19.745899, -1.480370, 6.000108], 64 | [11.284125, -0.000009, -0.000018], 65 | [33.000050, 0.000004, 0.000032], 66 | [25.200008, 0.000015, 0.000008], 67 | [0.000000, 0.000000, 0.000000], 68 | 69 | [19.746099, -1.480375, -6.000073], 70 | [11.284138, -0.000015, -0.000012], 71 | [33.000092, 0.000017, 0.000013], 72 | [25.199780, 0.000135, 0.000422], 73 | [0.000000, 0.000000, 0.000000], 74 | ] 75 | 76 | sk_parents = [ 77 | -1, 78 | 0, 79 | 1, 80 | 2, 81 | 3, 82 | 4, 83 | 0, 84 | 6, 85 | 7, 86 | 8, 87 | 9, 88 | 0, 89 | 11, 90 | 12, 91 | 13, 92 | 14, 93 | 15, 94 | 13, 95 | 17, 96 | 18, 97 | 19, 98 | 20, 99 | 13, 100 | 22, 101 | 23, 102 | 24, 103 | 25, 104 | ] 105 | 106 | sk_joints_to_remove = [5, 10, 16, 21, 26] 107 | 108 | joint_names = [ 109 | "Hips", 110 | "LeftUpLeg", 111 | "LeftLeg", 112 | "LeftFoot", 113 | "LeftToe", 114 | "RightUpLeg", 115 | "RightLeg", 116 | "RightFoot", 117 | "RightToe", 118 | "Spine", 119 | "Spine1", 120 | "Spine2", 121 | "Neck", 122 | "Head", 123 | "LeftShoulder", 124 | "LeftArm", 125 | "LeftForeArm", 126 | "LeftHand", 127 | "RightShoulder", 128 | "RightArm", 129 | "RightForeArm", 130 | "RightHand", 131 | ] 132 | 133 | 134 | class Skeleton: 135 | def __init__( 136 | self, 137 | offsets, 138 | parents, 139 | joints_left=None, 140 | joints_right=None, 141 | bone_length=None, 142 | device=None, 143 | ): 144 | assert len(offsets) == len(parents) 145 | 146 | self._offsets = torch.Tensor(offsets).to(device) 147 | self._parents = np.array(parents) 148 | self._joints_left = joints_left 149 | self._joints_right = joints_right 150 | self._compute_metadata() 151 | 152 | def num_joints(self): 153 | return self._offsets.shape[0] 154 | 155 | def offsets(self): 156 | return self._offsets 157 | 158 | def parents(self): 159 | return self._parents 160 | 161 | def has_children(self): 162 | return self._has_children 163 | 164 | def children(self): 165 | return self._children 166 | 167 | def convert_to_global_pos(self, unit_vec_rerp): 168 | """ 169 | Convert the unit offset matrix to global position. 170 | First row(root) will have absolute position value in global coordinates. 
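Expects a (batch, seq_len, 22 * 3) unit-offset representation and returns global joint positions of shape (batch, seq_len, 22, 3), accumulated from bone-length-scaled unit offsets along the kinematic chain.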
171 | """ 172 | bone_length = self.get_bone_length_weight() 173 | batch_size = unit_vec_rerp.size(0) 174 | seq_len = unit_vec_rerp.size(1) 175 | unit_vec_table = unit_vec_rerp.reshape(batch_size, seq_len, 22, 3) 176 | global_position = torch.zeros_like(unit_vec_table, device=unit_vec_table.device) 177 | 178 | for i, parent in enumerate(self._parents): 179 | if parent == -1: # if root 180 | global_position[:, :, i] = unit_vec_table[:, :, i] 181 | 182 | else: 183 | global_position[:, :, i] = global_position[:, :, parent] + ( 184 | nn.functional.normalize(unit_vec_table[:, :, i], p=2.0, dim=-1) 185 | * bone_length[i] 186 | ) 187 | 188 | return global_position 189 | 190 | def convert_to_unit_offset_mat(self, global_position): 191 | """ 192 | Convert the global position of the skeleton to a unit offset matrix. 193 | First row(root) will have absolute position value in global coordinates. 194 | """ 195 | 196 | bone_length = self.get_bone_length_weight() 197 | unit_offset_mat = torch.zeros_like( 198 | global_position, device=global_position.device 199 | ) 200 | 201 | for i, parent in enumerate(self._parents): 202 | 203 | if parent == -1: # if root 204 | unit_offset_mat[:, :, i] = global_position[:, :, i] 205 | else: 206 | unit_offset_mat[:, :, i] = ( 207 | global_position[:, :, i] - global_position[:, :, parent] 208 | ) / bone_length[i] 209 | 210 | return unit_offset_mat 211 | 212 | def remove_joints(self, joints_to_remove): 213 | """ 214 | Remove the joints specified in 'joints_to_remove', both from the 215 | skeleton definition and from the dataset (which is modified in place). 216 | The rotations of removed joints are propagated along the kinematic chain. 217 | """ 218 | valid_joints = [] 219 | for joint in range(len(self._parents)): 220 | if joint not in joints_to_remove: 221 | valid_joints.append(joint) 222 | 223 | index_offsets = np.zeros(len(self._parents), dtype=int) 224 | new_parents = [] 225 | for i, parent in enumerate(self._parents): 226 | if i not in joints_to_remove: 227 | new_parents.append(parent - index_offsets[parent]) 228 | else: 229 | index_offsets[i:] += 1 230 | self._parents = np.array(new_parents) 231 | 232 | self._offsets = self._offsets[valid_joints] 233 | self._compute_metadata() 234 | 235 | def forward_kinematics(self, rotations, root_positions): 236 | """ 237 | Perform forward kinematics using the given trajectory and local rotations. 238 | Arguments (where N = batch size, L = sequence length, J = number of joints): 239 | -- rotations: (N, L, J, 4) tensor of unit quaternions describing the local rotations of each joint. 240 | -- root_positions: (N, L, 3) tensor describing the root joint positions. 
241 | """ 242 | assert len(rotations.shape) == 4 243 | assert rotations.shape[-1] == 4 244 | 245 | positions_world = [] 246 | rotations_world = [] 247 | 248 | expanded_offsets = self._offsets.expand( 249 | rotations.shape[0], 250 | rotations.shape[1], 251 | self._offsets.shape[0], 252 | self._offsets.shape[1], 253 | ) 254 | 255 | # Parallelize along the batch and time dimensions 256 | for i in range(self._offsets.shape[0]): 257 | if self._parents[i] == -1: 258 | positions_world.append(root_positions) 259 | rotations_world.append(rotations[:, :, 0]) 260 | else: 261 | positions_world.append( 262 | qrot(rotations_world[self._parents[i]], expanded_offsets[:, :, i]) 263 | + positions_world[self._parents[i]] 264 | ) 265 | if self._has_children[i]: 266 | rotations_world.append( 267 | qmul(rotations_world[self._parents[i]], rotations[:, :, i]) 268 | ) 269 | else: 270 | # This joint is a terminal node -> it would be useless to compute the transformation 271 | rotations_world.append(None) 272 | 273 | return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2) 274 | 275 | def forward_kinematics_with_rotation(self, rotations, root_positions): 276 | """ 277 | Perform forward kinematics using the given trajectory and local rotations. 278 | Arguments (where N = batch size, L = sequence length, J = number of joints): 279 | -- rotations: (N, L, J, 4) tensor of unit quaternions describing the local rotations of each joint. 280 | -- root_positions: (N, L, 3) tensor describing the root joint positions. 281 | """ 282 | assert len(rotations.shape) == 4 283 | assert rotations.shape[-1] == 4 284 | 285 | positions_world = [] 286 | rotations_world = [] 287 | 288 | expanded_offsets = self._offsets.expand( 289 | rotations.shape[0], 290 | rotations.shape[1], 291 | self._offsets.shape[0], 292 | self._offsets.shape[1], 293 | ) 294 | 295 | # Parallelize along the batch and time dimensions 296 | for i in range(self._offsets.shape[0]): 297 | if self._parents[i] == -1: 298 | positions_world.append(root_positions) 299 | rotations_world.append(rotations[:, :, 0]) 300 | else: 301 | positions_world.append( 302 | qrot(rotations_world[self._parents[i]], expanded_offsets[:, :, i]) 303 | + positions_world[self._parents[i]] 304 | ) 305 | if self._has_children[i]: 306 | rotations_world.append( 307 | qmul(rotations_world[self._parents[i]], rotations[:, :, i]) 308 | ) 309 | else: 310 | # This joint is a terminal node -> it would be useless to compute the transformation 311 | rotations_world.append( 312 | torch.Tensor([1, 0, 0, 0]) 313 | .expand(rotations.shape[0], rotations.shape[1], 4) 314 | .to(rotations.device) 315 | ) 316 | 317 | return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2), torch.stack( 318 | rotations_world, dim=3 319 | ).permute(0, 1, 3, 2) 320 | 321 | def get_bone_length_weight(self): 322 | bone_length = [] 323 | for i, parent in enumerate(self._parents): 324 | if parent == -1: 325 | bone_length.append(1) 326 | else: 327 | bone_length.append( 328 | torch.linalg.norm(self._offsets[i : i + 1], ord="fro").item() 329 | ) 330 | return torch.Tensor(bone_length) 331 | 332 | def joints_left(self): 333 | return self._joints_left 334 | 335 | def joints_right(self): 336 | return self._joints_right 337 | 338 | def _compute_metadata(self): 339 | self._has_children = np.zeros(len(self._parents)).astype(bool) 340 | for i, parent in enumerate(self._parents): 341 | if parent != -1: 342 | self._has_children[parent] = True 343 | 344 | self._children = [] 345 | for i, parent in enumerate(self._parents): 346 | 
self._children.append([]) 347 | for i, parent in enumerate(self._parents): 348 | if parent != -1: 349 | self._children[parent].append(i) 350 | -------------------------------------------------------------------------------- /cmib/vis/pose.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | 8 | def project_root_position(position_arr: np.array, file_name: str): 9 | """ 10 | Take batch of root arrays and porject it on 2D plane 11 | 12 | N: samples 13 | L: trajectory length 14 | J: joints 15 | 16 | position_arr: [N,L,J,3] 17 | """ 18 | 19 | root_joints = position_arr[:, :, 0] 20 | 21 | x_pos = root_joints[:, :, 0] 22 | y_pos = root_joints[:, :, 2] 23 | 24 | fig = plt.figure() 25 | 26 | for i in range(x_pos.shape[1]): 27 | 28 | if i == 0: 29 | plt.scatter(x_pos[:, i], y_pos[:, i], c="b") 30 | elif i == x_pos.shape[1] - 1: 31 | plt.scatter(x_pos[:, i], y_pos[:, i], c="r") 32 | else: 33 | plt.scatter(x_pos[:, i], y_pos[:, i], c="k", marker="*", s=1) 34 | 35 | plt.title(f"Root Position: {file_name}") 36 | plt.xlabel("X Axis") 37 | plt.ylabel("Y Axis") 38 | plt.xlim((-300, 300)) 39 | plt.ylim((-300, 300)) 40 | plt.grid() 41 | plt.savefig(f"{file_name}.png", dpi=200) 42 | 43 | 44 | def plot_single_pose( 45 | pose, 46 | frame_idx, 47 | skeleton, 48 | save_dir, 49 | prefix, 50 | ): 51 | fig = plt.figure() 52 | ax = fig.add_subplot(111, projection="3d") 53 | 54 | parent_idx = skeleton.parents() 55 | 56 | for i, p in enumerate(parent_idx): 57 | if i > 0: 58 | ax.plot( 59 | [pose[i, 0], pose[p, 0]], 60 | [pose[i, 2], pose[p, 2]], 61 | [pose[i, 1], pose[p, 1]], 62 | c="k", 63 | ) 64 | 65 | x_min = pose[:, 0].min() 66 | x_max = pose[:, 0].max() 67 | 68 | y_min = pose[:, 1].min() 69 | y_max = pose[:, 1].max() 70 | 71 | z_min = pose[:, 2].min() 72 | z_max = pose[:, 2].max() 73 | 74 | ax.set_xlim(x_min, x_max) 75 | ax.set_xlabel("$X$ Axis") 76 | 77 | ax.set_ylim(z_min, z_max) 78 | ax.set_ylabel("$Y$ Axis") 79 | 80 | ax.set_zlim(y_min, y_max) 81 | ax.set_zlabel("$Z$ Axis") 82 | 83 | plt.draw() 84 | 85 | title = f"{prefix}: {frame_idx}" 86 | plt.title(title) 87 | prefix = prefix 88 | pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True) 89 | plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60) 90 | plt.close() 91 | 92 | 93 | def plot_pose( 94 | start_pose, 95 | inbetween_pose, 96 | target_pose, 97 | frame_idx, 98 | skeleton, 99 | save_dir, 100 | prefix, 101 | ): 102 | fig = plt.figure() 103 | ax = fig.add_subplot(111, projection="3d") 104 | 105 | parent_idx = skeleton.parents() 106 | 107 | for i, p in enumerate(parent_idx): 108 | if i > 0: 109 | ax.plot( 110 | [start_pose[i, 0], start_pose[p, 0]], 111 | [start_pose[i, 2], start_pose[p, 2]], 112 | [start_pose[i, 1], start_pose[p, 1]], 113 | c="b", 114 | ) 115 | ax.plot( 116 | [inbetween_pose[i, 0], inbetween_pose[p, 0]], 117 | [inbetween_pose[i, 2], inbetween_pose[p, 2]], 118 | [inbetween_pose[i, 1], inbetween_pose[p, 1]], 119 | c="k", 120 | ) 121 | ax.plot( 122 | [target_pose[i, 0], target_pose[p, 0]], 123 | [target_pose[i, 2], target_pose[p, 2]], 124 | [target_pose[i, 1], target_pose[p, 1]], 125 | c="r", 126 | ) 127 | 128 | x_min = np.min( 129 | [start_pose[:, 0].min(), inbetween_pose[:, 0].min(), target_pose[:, 0].min()] 130 | ) 131 | x_max = np.max( 132 | [start_pose[:, 0].max(), inbetween_pose[:, 0].max(), target_pose[:, 0].max()] 133 | ) 134 | 135 | y_min = np.min( 136 | [start_pose[:, 
1].min(), inbetween_pose[:, 1].min(), target_pose[:, 1].min()] 137 | ) 138 | y_max = np.max( 139 | [start_pose[:, 1].max(), inbetween_pose[:, 1].max(), target_pose[:, 1].max()] 140 | ) 141 | 142 | z_min = np.min( 143 | [start_pose[:, 2].min(), inbetween_pose[:, 2].min(), target_pose[:, 2].min()] 144 | ) 145 | z_max = np.max( 146 | [start_pose[:, 2].max(), inbetween_pose[:, 2].max(), target_pose[:, 2].max()] 147 | ) 148 | 149 | ax.set_xlim(x_min, x_max) 150 | ax.set_xlabel("$X$ Axis") 151 | 152 | ax.set_ylim(z_min, z_max) 153 | ax.set_ylabel("$Y$ Axis") 154 | 155 | ax.set_zlim(y_min, y_max) 156 | ax.set_zlabel("$Z$ Axis") 157 | 158 | plt.draw() 159 | 160 | title = f"{prefix}: {frame_idx}" 161 | plt.title(title) 162 | prefix = prefix 163 | pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True) 164 | plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60) 165 | plt.close() 166 | 167 | 168 | def plot_pose_with_stop( 169 | start_pose, 170 | inbetween_pose, 171 | target_pose, 172 | stopover, 173 | frame_idx, 174 | skeleton, 175 | save_dir, 176 | prefix, 177 | ): 178 | fig = plt.figure() 179 | ax = fig.add_subplot(111, projection="3d") 180 | 181 | parent_idx = skeleton.parents() 182 | 183 | for i, p in enumerate(parent_idx): 184 | if i > 0: 185 | ax.plot( 186 | [start_pose[i, 0], start_pose[p, 0]], 187 | [start_pose[i, 2], start_pose[p, 2]], 188 | [start_pose[i, 1], start_pose[p, 1]], 189 | c="b", 190 | ) 191 | ax.plot( 192 | [inbetween_pose[i, 0], inbetween_pose[p, 0]], 193 | [inbetween_pose[i, 2], inbetween_pose[p, 2]], 194 | [inbetween_pose[i, 1], inbetween_pose[p, 1]], 195 | c="k", 196 | ) 197 | ax.plot( 198 | [target_pose[i, 0], target_pose[p, 0]], 199 | [target_pose[i, 2], target_pose[p, 2]], 200 | [target_pose[i, 1], target_pose[p, 1]], 201 | c="r", 202 | ) 203 | 204 | ax.plot( 205 | [stopover[i, 0], stopover[p, 0]], 206 | [stopover[i, 2], stopover[p, 2]], 207 | [stopover[i, 1], stopover[p, 1]], 208 | c="indigo", 209 | ) 210 | 211 | x_min = np.min( 212 | [start_pose[:, 0].min(), inbetween_pose[:, 0].min(), target_pose[:, 0].min()] 213 | ) 214 | x_max = np.max( 215 | [start_pose[:, 0].max(), inbetween_pose[:, 0].max(), target_pose[:, 0].max()] 216 | ) 217 | 218 | y_min = np.min( 219 | [start_pose[:, 1].min(), inbetween_pose[:, 1].min(), target_pose[:, 1].min()] 220 | ) 221 | y_max = np.max( 222 | [start_pose[:, 1].max(), inbetween_pose[:, 1].max(), target_pose[:, 1].max()] 223 | ) 224 | 225 | z_min = np.min( 226 | [start_pose[:, 2].min(), inbetween_pose[:, 2].min(), target_pose[:, 2].min()] 227 | ) 228 | z_max = np.max( 229 | [start_pose[:, 2].max(), inbetween_pose[:, 2].max(), target_pose[:, 2].max()] 230 | ) 231 | 232 | ax.set_xlim(x_min, x_max) 233 | ax.set_xlabel("$X$ Axis") 234 | 235 | ax.set_ylim(z_min, z_max) 236 | ax.set_ylabel("$Y$ Axis") 237 | 238 | ax.set_zlim(y_min, y_max) 239 | ax.set_zlabel("$Z$ Axis") 240 | 241 | plt.draw() 242 | 243 | title = f"{prefix}: {frame_idx}" 244 | plt.title(title) 245 | prefix = prefix 246 | pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True) 247 | plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60) 248 | plt.close() 249 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black==24.3.0 2 | certifi==2023.7.22 3 | charset-normalizer==2.0.9 4 | click==8.0.3 5 | configparser==5.2.0 6 | cycler==0.11.0 7 | docker-pycreds==0.4.0 8 | fonttools==4.43.0 9 | gitdb==4.0.9 
10 | GitPython==3.1.41 11 | idna==3.7 12 | imageio==2.13.3 13 | importlib-metadata==4.8.2 14 | joblib==1.2.0 15 | kiwisolver==1.3.2 16 | matplotlib==3.5.1 17 | mypy-extensions==0.4.3 18 | numpy==1.22.0 19 | packaging==21.3 20 | pathspec==0.9.0 21 | pathtools==0.1.2 22 | Pillow==10.3.0 23 | platformdirs==2.4.0 24 | promise==2.3 25 | protobuf==3.19.5 26 | psutil==5.8.0 27 | pyparsing==3.0.6 28 | python-dateutil==2.8.2 29 | PyYAML==6.0 30 | requests==2.31.0 31 | scikit-learn==1.0.1 32 | scipy==1.10.0 33 | sentry-sdk==1.14.0 34 | shortuuid==1.0.8 35 | six==1.16.0 36 | sklearn==0.0 37 | smmap==5.0.0 38 | subprocess32==3.5.4 39 | termcolor==1.1.0 40 | threadpoolctl==3.0.0 41 | tomli==1.2.3 42 | torch==1.13.1 43 | torchaudio==0.10.1+cu113 44 | torchvision==0.11.2+cu113 45 | tqdm==4.62.3 46 | typed-ast==1.5.1 47 | typing_extensions==4.0.1 48 | urllib3==1.26.18 49 | wandb==0.12.7 50 | yaspin==2.1.0 51 | zipp==3.6.0 52 | -------------------------------------------------------------------------------- /run_cmib.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import imageio 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from PIL import Image 10 | from sklearn.preprocessing import LabelEncoder 11 | 12 | from cmib.data.lafan1_dataset import LAFAN1Dataset 13 | from cmib.data.utils import write_json 14 | from cmib.lafan1.utils import quat_ik 15 | from cmib.model.network import TransformerModel 16 | from cmib.model.preprocess import (lerp_input_repr, replace_constant, 17 | slerp_input_repr, vectorize_representation) 18 | from cmib.model.skeleton import (Skeleton, joint_names, sk_joints_to_remove, 19 | sk_offsets, sk_parents) 20 | from cmib.vis.pose import plot_pose_with_stop 21 | 22 | 23 | def test(opt, device): 24 | 25 | save_dir = Path(os.path.join('runs', 'train', opt.exp_name)) 26 | wdir = save_dir / 'weights' 27 | weights = os.listdir(wdir) 28 | 29 | if opt.weight == 'latest': 30 | weights_paths = [wdir / weight for weight in weights] 31 | weight_path = max(weights_paths , key = os.path.getctime) 32 | else: 33 | weight_path = wdir / ('train-' + opt.weight + '.pt') 34 | ckpt = torch.load(weight_path, map_location=device) 35 | print(f"Loaded weight: {weight_path}") 36 | 37 | 38 | # Load Skeleton 39 | skeleton_mocap = Skeleton(offsets=sk_offsets, parents=sk_parents, device=device) 40 | skeleton_mocap.remove_joints(sk_joints_to_remove) 41 | 42 | # Load LAFAN Dataset 43 | Path(opt.processed_data_dir).mkdir(parents=True, exist_ok=True) 44 | test_window = ckpt['horizon'] - 1 + 10 45 | lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path, processed_data_dir=opt.processed_data_dir, train=False, device=device, window=test_window) 46 | total_data = lafan_dataset.data['global_pos'].shape[0] 47 | 48 | # Replace with noise to In-betweening Frames 49 | from_idx, target_idx = ckpt['from_idx'], ckpt['target_idx'] # default: 9-40, max: 48 50 | horizon = ckpt['horizon'] 51 | print(f"HORIZON: {horizon}") 52 | 53 | test_idx = [950, 1140, 2100] 54 | 55 | # Extract dimension from processed data 56 | pos_dim = lafan_dataset.num_joints * 3 57 | rot_dim = lafan_dataset.num_joints * 4 58 | repr_dim = pos_dim + rot_dim 59 | 60 | root_pos = torch.Tensor(lafan_dataset.data['root_p'][:, from_idx:target_idx+1]).to(device) 61 | local_q = torch.Tensor(lafan_dataset.data['local_q'][:, from_idx:target_idx+1]).to(device) 62 | local_q_normalized = nn.functional.normalize(local_q, p=2.0, dim=-1) 63 | 64 | # 
Replace testing inputs 65 | fixed = 0 66 | global_pos, global_q = skeleton_mocap.forward_kinematics_with_rotation(local_q_normalized, root_pos) 67 | 68 | interpolation = ckpt['interpolation'] 69 | print(f"Interpolation Mode: {interpolation}") 70 | 71 | if interpolation == 'constant': 72 | global_pose_vec_gt = vectorize_representation(global_pos, global_q) 73 | global_pose_vec_input = global_pose_vec_gt.clone().detach() 74 | pose_interpolated_input = replace_constant(global_pose_vec_input, fixed) 75 | input_pos = pose_interpolated_input[:,:,:pos_dim].detach().numpy() 76 | 77 | elif interpolation == 'slerp': 78 | global_pose_vec_gt = vectorize_representation(global_pos, global_q) 79 | global_pose_vec_input = global_pose_vec_gt.clone().detach() 80 | root_vec = global_pose_vec_input[:,:,:pos_dim] 81 | rot_vec = global_pose_vec_input[:,:,pos_dim:] 82 | root_lerped = lerp_input_repr(root_vec, fixed) 83 | rot_slerped = slerp_input_repr(rot_vec, fixed) 84 | pose_interpolated_input = torch.cat([root_lerped, rot_slerped], dim=2) 85 | input_pos = pose_interpolated_input[:,:,:pos_dim].detach().numpy() 86 | 87 | else: 88 | raise ValueError('Invalid interpolation method') 89 | 90 | pose_vectorized_input = pose_interpolated_input.permute(1,0,2) 91 | 92 | src_mask = torch.zeros((horizon, horizon), device=device).type(torch.bool) 93 | src_mask = src_mask.to(device) 94 | 95 | seq_categories = [x[:-1] for x in lafan_dataset.data['seq_names']] 96 | 97 | le = LabelEncoder() 98 | le.classes_ = np.load(os.path.join(save_dir, 'le_classes_.npy')) 99 | 100 | target_seq = opt.motion_type 101 | seq_id = np.where(le.classes_==target_seq)[0] 102 | conditioning_labels = np.expand_dims((np.repeat(seq_id[0], repeats=len(seq_categories))), axis=1) 103 | conditioning_labels = torch.Tensor(conditioning_labels).type(torch.int64).to(device) 104 | 105 | model = TransformerModel(seq_len=ckpt['horizon'], d_model=ckpt['d_model'], nhead=ckpt['nhead'], d_hid=ckpt['d_hid'], nlayers=ckpt['nlayers'], dropout=0.05, out_dim=repr_dim) 106 | model.load_state_dict(ckpt['transformer_encoder_state_dict']) 107 | model.eval() 108 | 109 | output, _ = model(pose_vectorized_input, src_mask, conditioning_labels) 110 | 111 | pred_global_pos = output[1:,:,:pos_dim].permute(1,0,2).reshape(total_data,horizon-1,22,3) 112 | global_pos_unit_vec = skeleton_mocap.convert_to_unit_offset_mat(pred_global_pos) 113 | pred_global_pos = skeleton_mocap.convert_to_global_pos(global_pos_unit_vec).detach().numpy() 114 | 115 | pred_global_rot = output[1:,:,pos_dim:].permute(1,0,2).reshape(total_data,horizon-1,22,4) 116 | pred_global_rot_normalized = nn.functional.normalize(pred_global_rot, p=2.0, dim=3).detach().numpy() 117 | 118 | clue = global_pos.clone().detach() 119 | 120 | # Compare Input data, Prediction, GT 121 | for i in range(len(test_idx)): 122 | save_path = os.path.join(opt.save_path, 'test_' + f'{test_idx[i]}') 123 | Path(save_path).mkdir(parents=True, exist_ok=True) 124 | pred_json_path = os.path.join(save_path, 'pred_json') 125 | Path(pred_json_path).mkdir(parents=True, exist_ok=True) 126 | gt_json_path = os.path.join(save_path, 'gt_json') 127 | Path(gt_json_path).mkdir(parents=True, exist_ok=True) 128 | 129 | start_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx] 130 | target_pose = lafan_dataset.data['global_pos'][test_idx[i], target_idx] 131 | stopover_pose = clue[test_idx[i],fixed] 132 | stopover_rot = global_q[test_idx[i],fixed] 133 | gt_stopover_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx + fixed] 134 | 135 | # 
Replace start/end with gt 136 | pred_global_pos[test_idx[i], 0] = start_pose 137 | 138 | gpos = pred_global_pos[test_idx[i]] 139 | grot = pred_global_rot_normalized[test_idx[i]] 140 | 141 | local_quaternion_stopover, local_positions_stopover = quat_ik(stopover_rot.detach().numpy(), stopover_pose.detach().numpy(), parents=skeleton_mocap.parents()) 142 | local_quaternion, local_positions = quat_ik(grot, gpos, parents=skeleton_mocap.parents()) 143 | 144 | img_aggr_list = [] 145 | 146 | write_json(filename=os.path.join(pred_json_path, f'start.json'), local_q=local_quaternion[0], root_pos=local_positions[0,0], joint_names=joint_names) 147 | write_json(filename=os.path.join(pred_json_path, f'target.json'), local_q=local_quaternion[-1], root_pos=local_positions[-1,0], joint_names=joint_names) 148 | write_json(filename=os.path.join(pred_json_path, f'stopover.json'), local_q=local_quaternion_stopover, root_pos=local_positions_stopover[0], joint_names=joint_names) 149 | 150 | write_json(filename=os.path.join(gt_json_path, f'start.json'), local_q=local_quaternion[0], root_pos=local_positions[0,0], joint_names=joint_names) 151 | write_json(filename=os.path.join(gt_json_path, f'target.json'), local_q=local_quaternion[-1], root_pos=local_positions[-1,0], joint_names=joint_names) 152 | write_json(filename=os.path.join(gt_json_path, f'stopover.json'), local_q=local_quaternion_stopover, root_pos=local_positions_stopover[0], joint_names=joint_names) 153 | 154 | for t in range(horizon-1): 155 | 156 | if opt.plot_image: 157 | input_img_path = os.path.join(save_path, 'input') 158 | pred_img_path = os.path.join(save_path, 'pred_img') 159 | gt_img_path = os.path.join(save_path, 'gt_img') 160 | 161 | plot_pose_with_stop(start_pose, input_pos[test_idx[i],t].reshape(lafan_dataset.num_joints, 3), target_pose, stopover_pose, t, skeleton_mocap, save_dir=input_img_path, prefix='input') 162 | plot_pose_with_stop(start_pose, pred_global_pos[test_idx[i],t].reshape(lafan_dataset.num_joints, 3), target_pose, stopover_pose, t, skeleton_mocap, save_dir=pred_img_path, prefix='pred') 163 | plot_pose_with_stop(start_pose, lafan_dataset.data['global_pos'][test_idx[i], t+from_idx], target_pose, gt_stopover_pose, t, skeleton_mocap, save_dir=gt_img_path, prefix='gt') 164 | 165 | input_img = Image.open(os.path.join(input_img_path, 'input'+str(t)+'.png'), 'r') 166 | pred_img = Image.open(os.path.join(pred_img_path, 'pred'+str(t)+'.png'), 'r') 167 | gt_img = Image.open(os.path.join(gt_img_path, 'gt'+str(t)+'.png'), 'r') 168 | 169 | img_aggr_list.append(np.concatenate([input_img, pred_img, gt_img.resize(pred_img.size)], 1)) 170 | 171 | write_json(filename=os.path.join(pred_json_path, f'{t:05}.json'), local_q=local_quaternion[t], root_pos=local_positions[t,0], joint_names=joint_names) 172 | write_json(filename=os.path.join(gt_json_path, f'{t:05}.json'), local_q=lafan_dataset.data['local_q'][test_idx[i], from_idx + t], root_pos=lafan_dataset.data['global_pos'][test_idx[i], from_idx + t, 0], joint_names=joint_names) 173 | 174 | # Save images 175 | if opt.plot_image: 176 | gif_path = os.path.join(save_path, f'img_{test_idx[i]}.gif') 177 | imageio.mimsave(gif_path, img_aggr_list, duration=0.1) 178 | print(f"ID {test_idx[i]}: test completed.") 179 | 180 | def parse_opt(): 181 | parser = argparse.ArgumentParser() 182 | parser.add_argument('--project', default='runs/train', help='project/name') 183 | parser.add_argument('--weight', default='latest') 184 | parser.add_argument('--exp_name', default='exp', help='experiment name') 185 | 
parser.add_argument('--data_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH', help='BVH dataset path') 186 | parser.add_argument('--skeleton_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH/walk1_subject1.bvh', help='path to reference skeleton') 187 | parser.add_argument('--processed_data_dir', type=str, default='processed_data_80/', help='path to save pickled processed data') 188 | parser.add_argument('--save_path', type=str, default='runs/test', help='path to save model') 189 | parser.add_argument('--motion_type', type=str, default='jumps', help='motion type') 190 | parser.add_argument('--plot_image', type=bool, default=False, help='plot image') 191 | opt = parser.parse_args() 192 | return opt 193 | 194 | if __name__ == "__main__": 195 | opt = parse_opt() 196 | device = torch.device("cpu") 197 | test(opt, device) 198 | -------------------------------------------------------------------------------- /run_test_multi.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | for epoch in range(1000, 1200, 10): 5 | print(f"Epochs: {epoch}") 6 | subprocess.run(["python", "test_benchmark.py", 7 | "--project", "runs/train", 8 | "--exp_name", "slerp30_qnorm_final_bc", 9 | "--weight", str(epoch), 10 | "--processed_data_dir", "processed_data_original_bc/"]) 11 | 12 | -------------------------------------------------------------------------------- /test_benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import pickle 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from sklearn.preprocessing import LabelEncoder 11 | 12 | from cmib.data.lafan1_dataset import LAFAN1Dataset 13 | from cmib.lafan1 import benchmarks, extract 14 | from cmib.data.utils import process_seq_names 15 | from cmib.model.network import TransformerModel 16 | from cmib.model.preprocess import (lerp_input_repr, replace_constant, 17 | slerp_input_repr, vectorize_representation) 18 | from cmib.model.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets, sk_parents, amass_offsets) 19 | 20 | 21 | def test(opt, device): 22 | 23 | save_dir = Path(os.path.join('runs', 'train', opt.exp_name)) 24 | test_dir = Path(os.path.join('runs', 'test', opt.exp_name)) 25 | test_dir.mkdir(exist_ok=True, parents=True) 26 | wdir = save_dir / 'weights' 27 | weights = os.listdir(wdir) 28 | 29 | if opt.weight == 'latest': 30 | weights_paths = [wdir / weight for weight in weights] 31 | weight_path = max(weights_paths , key = os.path.getctime) 32 | else: 33 | weight_path = wdir / ('train-' + opt.weight + '.pt') 34 | ckpt = torch.load(weight_path, map_location=device) 35 | print(f"Loaded weight: {weight_path}") 36 | 37 | # Load Skeleton 38 | offset = sk_offsets if opt.dataset == 'LAFAN' else amass_offsets 39 | skeleton_mocap = Skeleton(offsets=offset, parents=sk_parents, device=device) 40 | skeleton_mocap.remove_joints(sk_joints_to_remove) 41 | 42 | # Load LAFAN Dataset 43 | Path(opt.processed_data_dir).mkdir(parents=True, exist_ok=True) 44 | if ckpt['horizon'] < 60: 45 | test_window = 65 # Use default test window for 30-frame benchamrk setting. 
46 | else: 47 | test_window = ckpt['horizon'] - 1 + 10 48 | lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path, processed_data_dir=opt.processed_data_dir, train=False, device=device, window=test_window, dataset=opt.dataset) 49 | 50 | # Extract stats 51 | if opt.dataset == 'LAFAN': 52 | train_actors = ["subject1", "subject2", "subject3", "subject4"] 53 | elif opt.dataset in ['HumanEva', 'PosePrior']: 54 | train_actors = ["subject1", "subject2"] 55 | elif opt.dataset in ['HUMAN4D']: 56 | train_actors = ["subject1", "subject2", "subject3", "subject4", "subject5", "subject6", "subject7"] 57 | elif opt.dataset in ['MPI_HDM05']: 58 | train_actors = ["subject1", "subject2", "subject3"] 59 | else: 60 | ValueError("Invalid Dataset") 61 | 62 | bvh_folder = opt.data_path 63 | stats_file = os.path.join(opt.processed_data_dir, 'train_stats.pkl') 64 | 65 | if not os.path.exists(stats_file): 66 | x_mean, x_std, offsets = extract.get_train_stats(bvh_folder, train_actors) 67 | with open(stats_file, 'wb') as f: 68 | pickle.dump({ 69 | 'x_mean': x_mean, 70 | 'x_std': x_std, 71 | 'offsets': offsets, 72 | }, f, protocol=pickle.HIGHEST_PROTOCOL) 73 | else: 74 | print('Reusing stats file: ' + stats_file) 75 | with open(stats_file, 'rb') as f: 76 | stats = pickle.load(f) 77 | x_mean = stats['x_mean'] 78 | x_std = stats['x_std'] 79 | offsets = stats['offsets'] 80 | 81 | 82 | total_data = lafan_dataset.data['global_pos'].shape[0] 83 | 84 | from_idx, target_idx = ckpt['from_idx'], ckpt['target_idx'] # default: 9-38, max: 48 85 | horizon = ckpt['horizon'] 86 | print(f"HORIZON: {horizon}") 87 | 88 | test_idx = [] 89 | for i in range(total_data): 90 | test_idx.append(i) 91 | 92 | # Extract dimension from processed data 93 | pos_dim = lafan_dataset.num_joints * 3 94 | rot_dim = lafan_dataset.num_joints * 4 95 | repr_dim = pos_dim + rot_dim 96 | 97 | root_pos = torch.Tensor(lafan_dataset.data['root_p'][:, from_idx:target_idx+1]).to(device) 98 | local_q = torch.Tensor(lafan_dataset.data['local_q'][:, from_idx:target_idx+1]).to(device) 99 | local_q_normalized = nn.functional.normalize(local_q, p=2.0, dim=-1) 100 | 101 | # Replace testing inputs 102 | fixed = 0 103 | 104 | global_pos, global_q = skeleton_mocap.forward_kinematics_with_rotation(local_q_normalized, root_pos) 105 | global_pos[:,fixed] += torch.Tensor([0,0,0]).expand(global_pos.size(0),lafan_dataset.num_joints,3) 106 | 107 | interpolation = ckpt['interpolation'] 108 | 109 | if interpolation == 'constant': 110 | global_pose_vec_gt = vectorize_representation(global_pos, global_q) 111 | global_pose_vec_input = global_pose_vec_gt.clone().detach() 112 | pose_interpolated_input = replace_constant(global_pose_vec_input, fixed) 113 | 114 | elif interpolation == 'slerp': 115 | global_pose_vec_gt = vectorize_representation(global_pos, global_q) 116 | global_pose_vec_input = global_pose_vec_gt.clone().detach() 117 | root_vec = global_pose_vec_input[:,:,:pos_dim] 118 | rot_vec = global_pose_vec_input[:,:,pos_dim:] 119 | root_lerped = lerp_input_repr(root_vec, fixed) 120 | rot_slerped = slerp_input_repr(rot_vec, fixed) 121 | pose_interpolated_input = torch.cat([root_lerped, rot_slerped], dim=2) 122 | 123 | else: 124 | raise ValueError('Invalid interpolation method') 125 | 126 | pose_vectorized_input = pose_interpolated_input.permute(1,0,2) 127 | 128 | src_mask = torch.zeros((horizon, horizon), device=device).type(torch.bool) 129 | src_mask = src_mask.to(device) 130 | 131 | le = LabelEncoder() 132 | le.classes_ = np.load(os.path.join(save_dir, 'le_classes_.npy')) 133 | 
num_labels = len(le.classes_) 134 | 135 | model = TransformerModel(seq_len=ckpt['horizon'], d_model=ckpt['d_model'], nhead=ckpt['nhead'], d_hid=ckpt['d_hid'], nlayers=ckpt['nlayers'], dropout=0.0, out_dim=repr_dim, num_labels=num_labels) 136 | model.load_state_dict(ckpt['transformer_encoder_state_dict']) 137 | model.eval() 138 | 139 | l2p = [] 140 | l2q = [] 141 | 142 | pred_rot_npss = [] 143 | for i in range(len(test_idx)): 144 | print(f"Processing ID: {test_idx[i]}") 145 | 146 | seq_label = lafan_dataset.data['seq_names'][i][:-1] 147 | 148 | if opt.dataset == 'LAFAN': 149 | seq_label = [x[:-1] for x in lafan_dataset.data['seq_names']][i] 150 | else: 151 | seq_label = process_seq_names(lafan_dataset.data['seq_names'], dataset=opt.dataset)[i] 152 | 153 | match_class = np.where(le.classes_ == seq_label)[0] 154 | class_id = 0 if len(match_class) == 0 else match_class[0] 155 | 156 | conditioning_label = torch.Tensor([[class_id]]).type(torch.int64).to(device) 157 | cond_output, cond_gt = model(pose_vectorized_input[:, test_idx[i]:test_idx[i]+1, :], src_mask, conditioning_label) 158 | print(f"Condition: {le.classes_[class_id]}") 159 | 160 | output = cond_output 161 | 162 | pred_global_pos = output[1:,:,:pos_dim].permute(1,0,2).reshape(1,horizon-1,22,3) 163 | global_pos_unit_vec = skeleton_mocap.convert_to_unit_offset_mat(pred_global_pos) 164 | pred_global_pos = skeleton_mocap.convert_to_global_pos(global_pos_unit_vec).detach().numpy() 165 | 166 | # Replace start/end with gt 167 | gt_global_pos = lafan_dataset.data['global_pos'][test_idx[i]:test_idx[i]+1, from_idx:target_idx+1].reshape(1, -1, lafan_dataset.num_joints, 3) 168 | pred_global_pos[0,0] = gt_global_pos[0,0] 169 | pred_global_pos[0,-1] = gt_global_pos[0,-1] 170 | 171 | pred_global_rot = output[1:,:,pos_dim:].permute(1,0,2).reshape(1,horizon-1,22,4) 172 | pred_global_rot_normalized = nn.functional.normalize(pred_global_rot, p=2.0, dim=3) 173 | gt_global_rot = global_q[test_idx[i]:test_idx[i]+1] 174 | pred_global_rot_normalized[0,0] = gt_global_rot[0,0] 175 | pred_global_rot_normalized[0,-1] = gt_global_rot[0,-1] 176 | pred_rot_npss.append(pred_global_rot_normalized) 177 | 178 | # Normalize for L2P 179 | normalized_gt_pos = torch.Tensor((lafan_dataset.data['global_pos'][test_idx[i]:test_idx[i]+1, from_idx:target_idx+1].reshape(1, -1, lafan_dataset.num_joints * 3).transpose(0,2,1) - x_mean) / x_std) 180 | normalized_pred_pos = torch.Tensor((pred_global_pos.reshape(1, -1, lafan_dataset.num_joints * 3).transpose(0,2,1) - x_mean) / x_std) 181 | 182 | l2p.append(torch.mean(torch.norm(normalized_pred_pos[0] - normalized_gt_pos[0], dim=(0))).item()) 183 | l2q.append(torch.mean(torch.norm(pred_global_rot_normalized[0] - global_q[test_idx[i]], dim=(1,2))).item()) 184 | print(f"ID {test_idx[i]}: test completed.") 185 | 186 | l2p_mean = np.mean(l2p) 187 | l2q_mean = np.mean(l2q) 188 | 189 | # Drop end nodes for fair comparison 190 | pred_quaternions = torch.cat(pred_rot_npss, dim=0) 191 | npss_gt = global_q[:,:,skeleton_mocap.has_children()].reshape(global_q.shape[0],global_q.shape[1], -1) 192 | npss_pred = pred_quaternions[:,:,skeleton_mocap.has_children()].reshape(pred_quaternions.shape[0],pred_quaternions.shape[1], -1) 193 | npss = benchmarks.npss(npss_gt, npss_pred).item() 194 | 195 | print(f"TOTAL TEST DATA: {len(l2p)}") 196 | print(f"L2P: {l2p_mean}") 197 | print(f"L2Q: {l2q_mean}") 198 | print(f"NPSS: {npss}") 199 | 200 | benchmark_out = { 201 | 'total_data': len(l2p), 202 | 'L2P': l2p_mean, 203 | 'L2Q': l2q_mean, 204 | 'NPSS': npss 205 | } 
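# Note: L2P/L2Q above are mean L2 errors on normalized global positions and on global quaternions; NPSS is computed on rotations of non-terminal joints only (end nodes dropped for fair comparison).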
206 | 207 | with open(os.path.join(test_dir, f'benchmark_out-{opt.weight}.json'), 'w') as f: 208 | json.dump(benchmark_out, f) 209 | 210 | 211 | def parse_opt(): 212 | parser = argparse.ArgumentParser() 213 | parser.add_argument('--project', default='runs/train', help='project/name') 214 | parser.add_argument('--exp_name', default='HUMAN4D_80', help='experiment name') 215 | parser.add_argument('--weight', default='latest') 216 | parser.add_argument('--data_path', type=str, default='AMASS/PosePrior', help='BVH dataset path') 217 | parser.add_argument('--dataset', type=str, default='HUMAN4D', help='Dataset name') 218 | parser.add_argument('--processed_data_dir', type=str, default='processed_data_human4d_80/', help='path to save pickled processed data') 219 | opt = parser.parse_args() 220 | return opt 221 | 222 | if __name__ == "__main__": 223 | opt = parse_opt() 224 | device = torch.device("cpu") 225 | test(opt, device) 226 | -------------------------------------------------------------------------------- /test_condition_comp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import pickle 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from sklearn.preprocessing import LabelEncoder 11 | 12 | from cmib.data.lafan1_dataset import LAFAN1Dataset 13 | from cmib.data.utils import drop_end_quat 14 | from cmib.lafan1 import extract 15 | from cmib.model.network import TransformerModel 16 | from cmib.model.preprocess import (lerp_input_repr, replace_constant, 17 | slerp_input_repr, vectorize_representation) 18 | from cmib.model.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets, 19 | sk_parents) 20 | 21 | 22 | def test(opt, device): 23 | 24 | save_dir = Path(os.path.join('runs', 'train', opt.exp_name)) 25 | gt_motion = opt.data_path.split('/')[-1].split('_')[-1].lower() 26 | Path(os.path.join('cond_bch', gt_motion)).mkdir(parents=True, exist_ok=True) 27 | wdir = save_dir / 'weights' 28 | weights = os.listdir(wdir) 29 | weights_paths = [wdir / weight for weight in weights] 30 | latest_weight = max(weights_paths , key = os.path.getctime) 31 | ckpt = torch.load(latest_weight, map_location=device) 32 | print(f"Loaded weight: {latest_weight}") 33 | 34 | # Load Skeleton 35 | skeleton_mocap = Skeleton(offsets=sk_offsets, parents=sk_parents, device=device) 36 | skeleton_mocap.remove_joints(sk_joints_to_remove) 37 | 38 | # Load LAFAN Dataset 39 | processed_data_dir = 'condition_test_' + gt_motion 40 | Path(processed_data_dir).mkdir(parents=True, exist_ok=True) 41 | test_window = ckpt['horizon'] - 1 + 10 42 | print(f"Test Window: {test_window}") 43 | lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path, processed_data_dir=processed_data_dir, train=False, device=device, window=test_window) 44 | 45 | # Extract stats 46 | train_actors = ['subject1', 'subject2', 'subject3', 'subject4'] 47 | bvh_folder = os.path.join('ubisoft-laforge-animation-dataset', 'output', 'BVH') 48 | stats_file = os.path.join(opt.train_stat, 'train_stats.pkl') 49 | 50 | if not os.path.exists(stats_file): 51 | x_mean, x_std, offsets = extract.get_train_stats(bvh_folder, train_actors) 52 | with open(stats_file, 'wb') as f: 53 | pickle.dump({ 54 | 'x_mean': x_mean, 55 | 'x_std': x_std, 56 | 'offsets': offsets, 57 | }, f, protocol=pickle.HIGHEST_PROTOCOL) 58 | else: 59 | print('Reusing stats file: ' + stats_file) 60 | with open(stats_file, 'rb') as f: 61 | stats = pickle.load(f) 62 | x_mean = 

    total_data = lafan_dataset.data['global_pos'].shape[0]

    # In-betweening frame range and horizon come from the training checkpoint
    from_idx, target_idx = ckpt['from_idx'], ckpt['target_idx']
    horizon = ckpt['horizon']
    print(f"HORIZON: {horizon}")

    test_idx = list(range(total_data))

    # Extract dimensions from processed data
    pos_dim = lafan_dataset.num_joints * 3
    rot_dim = lafan_dataset.num_joints * 4
    repr_dim = pos_dim + rot_dim

    root_pos = torch.Tensor(lafan_dataset.data['root_p'][:, from_idx:target_idx+1]).to(device)
    local_q = torch.Tensor(lafan_dataset.data['local_q'][:, from_idx:target_idx+1]).to(device)
    local_q_normalized = nn.functional.normalize(local_q, p=2.0, dim=-1)
    global_pos, global_q = skeleton_mocap.forward_kinematics_with_rotation(local_q_normalized, root_pos)

    fixed = 0
    interpolation = ckpt['interpolation']

    if interpolation == 'constant':
        global_pose_vec_gt = vectorize_representation(global_pos, global_q)
        global_pose_vec_input = global_pose_vec_gt.clone().detach()
        pose_interpolated_input = replace_constant(global_pose_vec_input, fixed)

    elif interpolation == 'slerp':
        global_pose_vec_gt = vectorize_representation(global_pos, global_q)
        global_pose_vec_input = global_pose_vec_gt.clone().detach()
        root_vec = global_pose_vec_input[:, :, :pos_dim]
        rot_vec = global_pose_vec_input[:, :, pos_dim:]
        root_lerped = lerp_input_repr(root_vec, fixed)
        rot_slerped = slerp_input_repr(rot_vec, fixed)
        pose_interpolated_input = torch.cat([root_lerped, rot_slerped], dim=2)

    else:
        raise ValueError('Invalid interpolation method')

    pose_vectorized_input = pose_interpolated_input.permute(1, 0, 2)

    src_mask = torch.zeros((horizon, horizon), device=device).type(torch.bool)
    src_mask = src_mask.to(device)

    seq_categories = [x[:-1] for x in lafan_dataset.data['seq_names']]

    l1_loss = nn.L1Loss()

    le = LabelEncoder()
    le.classes_ = np.load(os.path.join(save_dir, 'le_classes_.npy'))

    model = TransformerModel(seq_len=ckpt['horizon'], d_model=ckpt['d_model'], nhead=ckpt['nhead'], d_hid=ckpt['d_hid'], nlayers=ckpt['nlayers'], dropout=0.0, out_dim=repr_dim)
    model.load_state_dict(ckpt['transformer_encoder_state_dict'])
    model.eval()

    testing_motions = ['walk', 'run', 'dance', 'jumps', 'fight']

    summary = {}
    for cond_motion in testing_motions:
        l2p = []
        l2q = []

        pred_rot_npss = []

        print(f"GT: {gt_motion}")
        print(f"Condition: {cond_motion}")

        bch_out = {}

        bch_out['cond_motion'] = cond_motion
        bch_out['gt_motion'] = gt_motion

        motion_index = np.where(le.classes_ == cond_motion)[0][0]

        # Condition the model on cond_motion, which may differ from the ground-truth motion
        conditioning_label = torch.Tensor([[motion_index] * total_data]).type(torch.int64).to(device).permute(1, 0)
        cond_output, _ = model(pose_vectorized_input[:, :, :], src_mask, conditioning_label)

        output = cond_output

        pred_global_pos = output[1:, :, :pos_dim].permute(1, 0, 2).reshape(total_data, horizon - 1, 22, 3)
        global_pos_unit_vec = skeleton_mocap.convert_to_unit_offset_mat(pred_global_pos)
        pred_global_pos = skeleton_mocap.convert_to_global_pos(global_pos_unit_vec).detach().numpy()

        gt_global_pos = lafan_dataset.data['global_pos'][:, from_idx:target_idx+1].reshape(1, -1, lafan_dataset.num_joints, 3)
        pred_global_pos[0, 0] = gt_global_pos[0, 0]
        pred_global_pos[0, -1] = gt_global_pos[0, -1]

        pred_global_rot = output[1:, :, pos_dim:].permute(1, 0, 2).reshape(total_data, horizon - 1, 22, 4)
        pred_global_rot_normalized = nn.functional.normalize(pred_global_rot, p=2.0, dim=3)
        gt_global_rot = global_q[:]
        pred_global_rot_normalized[0, 0] = gt_global_rot[0, 0]
        pred_global_rot_normalized[0, -1] = gt_global_rot[0, -1]
        pred_rot_npss.append(pred_global_rot_normalized)

        # Normalize for L2P
        normalized_gt_pos = torch.Tensor((lafan_dataset.data['global_pos'][:, from_idx:target_idx+1].reshape(total_data, -1, lafan_dataset.num_joints * 3).transpose(0, 2, 1) - x_mean) / x_std)
        normalized_pred_pos = torch.Tensor((pred_global_pos.reshape(total_data, -1, lafan_dataset.num_joints * 3).transpose(0, 2, 1) - x_mean) / x_std)

        l2p.append(torch.mean(torch.norm(normalized_pred_pos - normalized_gt_pos, dim=(1))).item())
        l2q.append(torch.mean(torch.norm(pred_global_rot_normalized - global_q, dim=(2, 3))).item())

        l2p_mean = np.mean(l2p)
        l2q_mean = np.mean(l2q)

        print(f"TOTAL TEST DATA: {total_data}")
        print(f"L2P: {l2p_mean}")
        print(f"L2Q: {l2q_mean}")
        print("=================")

        bch_out['L2P'] = l2p_mean
        bch_out['L2Q'] = l2q_mean
        bch_out['TotalData'] = total_data

        summary[cond_motion] = bch_out
        with open(os.path.join('cond_bch', gt_motion, f'{gt_motion}.txt'), 'w') as f:
            json.dump(summary, f)

def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--project', default='runs/train', help='project/name')
    parser.add_argument('--exp_name', default='train_60', help='experiment name')
    parser.add_argument('--data_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH_FIGHT', help='BVH dataset path')
    parser.add_argument('--skeleton_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH/walk1_subject1.bvh', help='path to reference skeleton')
    parser.add_argument('--save_path', type=str, default='runs/test', help='path to save results')
    parser.add_argument('--train_stat', default='processed_data_80', help='path to cached train statistics')
    opt = parser.parse_args()
    return opt

if __name__ == "__main__":
    opt = parse_opt()
    device = torch.device("cpu")
    test(opt, device)
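
# The loop above collects pred_rot_npss but, as dumped here, never reduces it to a score.
# Below is a small illustrative sketch of the standard NPSS metric (Gopalakrishnan et al., 2019)
# that those tensors could be fed into; the function name npss_sketch is hypothetical, and the
# implementation shipped in cmib/lafan1/benchmarks.py (if present) should be preferred.
def npss_sketch(gt_seq: np.ndarray, pred_seq: np.ndarray) -> float:
    """Normalized Power Spectrum Similarity over (batch, time, feature) arrays."""
    gt_power = np.square(np.abs(np.fft.fft(gt_seq, axis=1)))      # power spectrum per feature
    pred_power = np.square(np.abs(np.fft.fft(pred_seq, axis=1)))
    gt_total = np.sum(gt_power, axis=1, keepdims=True)            # total power per feature
    pred_total = np.sum(pred_power, axis=1, keepdims=True)
    gt_cdf = np.cumsum(gt_power / (gt_total + 1e-8), axis=1)      # normalized cumulative spectra
    pred_cdf = np.cumsum(pred_power / (pred_total + 1e-8), axis=1)
    emd = np.linalg.norm(pred_cdf - gt_cdf, ord=1, axis=1)        # 1D earth mover's distance
    return float(np.average(emd, weights=gt_total.squeeze(1)))    # power-weighted average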
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
python trainer.py \
    --data_path="AMASS/MPI_HDM05" \
    --dataset="MPI_HDM05" \
    --processed_data_dir="processed_data_mpi05_60/" \
    --window=70 \
    --batch_size=32 \
    --epochs=5000 \
    --device=1 \
    --exp_name="MPI_HDM05_60" \
    --save_interval=20 \
    --learning_rate=0.0001 \
    --loss_cond_weight=1.5 \
    --loss_pos_weight=0.05 \
    --loss_rot_weight=2.0 \
    --from_idx=9 \
    --target_idx=68 \
    --interpolation='slerp'
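
# How these numbers relate (derived from trainer.py below): the model in-betweens frames
# from --from_idx to --target_idx, so the modeled horizon is 68 - 9 + 1 = 60 frames,
# plus 1 conditioning token gives a sequence length of 61; --window=70 loads whole
# 70-frame clips, of which only frames 9..68 enter the model.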
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
import argparse
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import wandb
import yaml
from sklearn.preprocessing import LabelEncoder
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from cmib.data.lafan1_dataset import LAFAN1Dataset
from cmib.data.utils import flip_bvh, increment_path, process_seq_names
from cmib.model.network import TransformerModel
from cmib.model.preprocess import (lerp_input_repr, replace_constant,
                                   slerp_input_repr, vectorize_representation)
from cmib.model.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets, sk_parents, amass_offsets)


def train(opt, device):

    print(f"[DATASET: {opt.dataset}]")
    # Prepare Directories
    save_dir = Path(opt.save_dir)
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)

    # Save run settings
    with open(save_dir / 'opt.yaml', 'w') as f:
        yaml.safe_dump(vars(opt), f, sort_keys=True)

    epochs = opt.epochs
    save_interval = opt.save_interval

    # Loggers
    wandb.init(config=opt, project=opt.wandb_pj_name, entity=opt.entity, name=opt.exp_name, dir=opt.save_dir)

    # Load Skeleton
    offset = sk_offsets if opt.dataset == 'LAFAN' else amass_offsets
    skeleton_mocap = Skeleton(offsets=offset, parents=sk_parents, device=device)
    skeleton_mocap.remove_joints(sk_joints_to_remove)

    # Flip, load, and preprocess data using the LAFAN1 utilities
    if opt.dataset == 'LAFAN':
        flip_bvh(opt.data_path, skip='subject5')

    # Load LAFAN Dataset
    Path(opt.processed_data_dir).mkdir(parents=True, exist_ok=True)
    lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path, processed_data_dir=opt.processed_data_dir, train=True, device=device, window=opt.window, dataset=opt.dataset)

    from_idx, target_idx = opt.from_idx, opt.target_idx
    horizon = target_idx - from_idx + 1
    print(f"Horizon: {horizon}")
    horizon += 1  # Add one for conditioning token
    print(f"Horizon with Conditioning: {horizon}")
    print(f"Interpolation Mode: {opt.interpolation}")

    root_pos = torch.Tensor(lafan_dataset.data['root_p'][:, from_idx:target_idx+1]).to(device)
    local_q = torch.Tensor(lafan_dataset.data['local_q'][:, from_idx:target_idx+1]).to(device)
    local_q_normalized = nn.functional.normalize(local_q, p=2.0, dim=-1)

    global_pos, global_q = skeleton_mocap.forward_kinematics_with_rotation(local_q_normalized, root_pos)

    global_pose_vec_gt = vectorize_representation(global_pos, global_q)
    global_pose_vec_input = global_pose_vec_gt.clone().detach()

    if opt.dataset == 'LAFAN':
        seq_categories = [x[:-1] for x in lafan_dataset.data['seq_names']]
    else:
        seq_categories = process_seq_names(lafan_dataset.data['seq_names'], dataset=opt.dataset)

    le = LabelEncoder()
    le_np = le.fit_transform(seq_categories)
    seq_labels = torch.Tensor(le_np).type(torch.int64).unsqueeze(1).to(device)
    np.save(f'{save_dir}/le_classes_.npy', le.classes_)
    num_labels = len(seq_labels.squeeze().unique())

    tensor_dataset = TensorDataset(global_pose_vec_input, global_pose_vec_gt, seq_labels)
    lafan_data_loader = DataLoader(tensor_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0)

    pos_dim = lafan_dataset.num_joints * 3
    rot_dim = lafan_dataset.num_joints * 4
    repr_dim = pos_dim + rot_dim
    nhead = 7  # repr_dim = 154
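    # d_model is set to repr_dim = 22 joints x (3 position + 4 quaternion) = 154, and
    # multi-head attention requires d_model to be divisible by the number of heads;
    # 154 = 7 x 22, hence the slightly unusual choice of 7 attention heads.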
    transformer_encoder = TransformerModel(seq_len=horizon, d_model=repr_dim, nhead=nhead, d_hid=2048, nlayers=8, dropout=0.05, out_dim=repr_dim, num_labels=num_labels)
    transformer_encoder.to(device)

    l1_loss = nn.L1Loss()
    optim = AdamW(params=transformer_encoder.parameters(), lr=opt.learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=100, gamma=0.9)

    for epoch in range(1, epochs + 1):

        pbar = tqdm(lafan_data_loader, position=1, desc="Batch")

        recon_cond_loss = []
        recon_pos_loss = []
        recon_rot_loss = []
        total_loss_list = []

        for minibatch_pose_input, minibatch_pose_gt, seq_label in pbar:

            # Each minibatch is masked at 5 random start frames to vary the in-betweening task
            for _ in range(5):
                mask_start_frame = np.random.randint(0, horizon - 1)

                if opt.interpolation == 'constant':
                    pose_interpolated_input = replace_constant(minibatch_pose_input, mask_start_frame)
                elif opt.interpolation == 'slerp':
                    root_vec = minibatch_pose_input[:, :, :pos_dim]
                    rot_vec = minibatch_pose_input[:, :, pos_dim:]
                    root_lerped = lerp_input_repr(root_vec, mask_start_frame)
                    rot_slerped = slerp_input_repr(rot_vec, mask_start_frame)
                    pose_interpolated_input = torch.cat([root_lerped, rot_slerped], dim=2)
                else:
                    raise ValueError('Invalid interpolation method')

                pose_interpolated_input = pose_interpolated_input.permute(1, 0, 2)

                src_mask = torch.zeros((horizon, horizon), device=device).type(torch.bool)
                src_mask = src_mask.to(device)

                output, cond_gt = transformer_encoder(pose_interpolated_input, src_mask, seq_label)

                cond_pred = output[0:1, :, :]
                cond_loss = l1_loss(cond_pred, cond_gt)
                recon_cond_loss.append(opt.loss_cond_weight * cond_loss)

                pos_pred = output[1:, :, :pos_dim].permute(1, 0, 2)
                pos_gt = minibatch_pose_gt[:, :, :pos_dim]
                pos_loss = l1_loss(pos_pred, pos_gt)
                recon_pos_loss.append(opt.loss_pos_weight * pos_loss)

                rot_pred = output[1:, :, pos_dim:].permute(1, 0, 2)
                rot_pred_reshaped = rot_pred.reshape(rot_pred.shape[0], rot_pred.shape[1], lafan_dataset.num_joints, 4)
                rot_pred_normalized = nn.functional.normalize(rot_pred_reshaped, p=2.0, dim=3)

                rot_gt = minibatch_pose_gt[:, :, pos_dim:]
                rot_gt_reshaped = rot_gt.reshape(rot_gt.shape[0], rot_gt.shape[1], lafan_dataset.num_joints, 4)
                rot_loss = l1_loss(rot_pred_reshaped, rot_gt_reshaped)
                recon_rot_loss.append(opt.loss_rot_weight * rot_loss)

                total_g_loss = opt.loss_pos_weight * pos_loss + \
                               opt.loss_rot_weight * rot_loss + \
                               opt.loss_cond_weight * cond_loss

                total_loss_list.append(total_g_loss)

                optim.zero_grad()
                total_g_loss.backward()
                torch.nn.utils.clip_grad_norm_(transformer_encoder.parameters(), 1.0, error_if_nonfinite=False)
                optim.step()

        scheduler.step()

        # Log
        log_dict = {
            "Train/Loss/Condition Loss": torch.stack(recon_cond_loss).mean().item(),
            "Train/Loss/Position Loss": torch.stack(recon_pos_loss).mean().item(),
            "Train/Loss/Rotation Loss": torch.stack(recon_rot_loss).mean().item(),
            "Train/Loss/Total Loss": torch.stack(total_loss_list).mean().item(),
        }
        wandb.log(log_dict)

        # Save model
        if (epoch % save_interval) == 0:
            ckpt = {'epoch': epoch,
                    'transformer_encoder_state_dict': transformer_encoder.state_dict(),
                    'horizon': transformer_encoder.seq_len,
                    'from_idx': opt.from_idx,
                    'target_idx': opt.target_idx,
                    'd_model': transformer_encoder.d_model,
                    'nhead': transformer_encoder.nhead,
                    'd_hid': transformer_encoder.d_hid,
                    'nlayers': transformer_encoder.nlayers,
                    'optimizer_state_dict': optim.state_dict(),
                    'interpolation': opt.interpolation,
                    'loss': total_g_loss}
            torch.save(ckpt, os.path.join(wdir, f'train-{epoch}.pt'))
            print(f"[MODEL SAVED at {epoch} Epoch]")

    wandb.run.finish()
    torch.cuda.empty_cache()

def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--project', default='runs/train', help='project/name')
    parser.add_argument('--data_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH', help='BVH dataset path')
    parser.add_argument('--dataset', type=str, default='LAFAN', help='Dataset name')
    parser.add_argument('--processed_data_dir', type=str, default='processed_data_80/', help='path to save pickled processed data')
    parser.add_argument('--window', type=int, default=90, help='window size (frames loaded per clip)')
    parser.add_argument('--wandb_pj_name', type=str, default='cmib_train', help='project name')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--epochs', type=int, default=3000)
    parser.add_argument('--device', default='0', help='cuda device')
    parser.add_argument('--entity', default=None, help='W&B entity')
    parser.add_argument('--exp_name', default='exp', help='save to project/name')
    parser.add_argument('--save_interval', type=int, default=50, help='save a checkpoint every "save_interval" epochs')
    parser.add_argument('--learning_rate', type=float, default=0.0001, help='learning rate')
    parser.add_argument('--loss_cond_weight', type=float, default=1.5, help='loss_cond_weight')
    parser.add_argument('--loss_pos_weight', type=float, default=0.05, help='loss_pos_weight')
    parser.add_argument('--loss_rot_weight', type=float, default=2.0, help='loss_rot_weight')
    parser.add_argument('--from_idx', type=int, default=9, help='from idx')
    parser.add_argument('--target_idx', type=int, default=88, help='target idx')
    parser.add_argument('--interpolation', type=str, default='slerp', help='interpolation')
    opt = parser.parse_args()
    return opt

if __name__ == "__main__":
    opt = parse_opt()
    opt.save_dir = str(increment_path(Path(opt.project) / opt.exp_name))
    opt.exp_name = opt.save_dir.split('/')[-1]
    device = torch.device(f"cuda:{opt.device}" if torch.cuda.is_available() else "cpu")
    train(opt, device)

--------------------------------------------------------------------------------