├── .gitignore
├── README.md
├── assets
│   ├── TitlePage.png
│   ├── dance.gif
│   ├── graphical_abstract.jpg
│   ├── ib.gif
│   ├── jump.gif
│   ├── pc.gif
│   └── walk.gif
├── cmib
│   ├── data
│   │   ├── lafan1_dataset.py
│   │   ├── quaternion.py
│   │   └── utils.py
│   ├── lafan1
│   │   ├── __init__.py
│   │   ├── benchmarks.py
│   │   ├── extract.py
│   │   └── utils.py
│   ├── misc
│   │   └── sampler.py
│   ├── model
│   │   ├── network.py
│   │   ├── positional_encoding.py
│   │   ├── preprocess.py
│   │   └── skeleton.py
│   └── vis
│       └── pose.py
├── requirements.txt
├── run_cmib.py
├── run_test_multi.py
├── test_benchmark.py
├── test_condition_comp.py
├── train.sh
└── trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Project Specific
2 | settings.json
3 | log/*
4 | model_weights/*
5 | results/*
6 | ubisoft-laforge-animation-dataset/
7 | lafan_train*.pkl
8 | runs/
9 | wandb/
10 | processed_data*/
11 | AMASS/
12 | *.pkl
13 |
14 | # Byte-compiled / optimized / DLL files
15 | __pycache__/
16 | *.py[cod]
17 | *$py.class
18 |
19 | # C extensions
20 | *.so
21 |
22 | # Distribution / packaging
23 | .Python
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | pip-wheel-metadata/
37 | share/python-wheels/
38 | *.egg-info/
39 | .installed.cfg
40 | *.egg
41 | MANIFEST
42 |
43 | # PyInstaller
44 | # Usually these files are written by a python script from a template
45 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
46 | *.manifest
47 | *.spec
48 |
49 | # Installer logs
50 | pip-log.txt
51 | pip-delete-this-directory.txt
52 |
53 | # Unit test / coverage reports
54 | htmlcov/
55 | .tox/
56 | .nox/
57 | .coverage
58 | .coverage.*
59 | .cache
60 | nosetests.xml
61 | coverage.xml
62 | *.cover
63 | *.py,cover
64 | .hypothesis/
65 | .pytest_cache/
66 |
67 | # Translations
68 | *.mo
69 | *.pot
70 |
71 | # Django stuff:
72 | *.log
73 | local_settings.py
74 | db.sqlite3
75 | db.sqlite3-journal
76 |
77 | # Flask stuff:
78 | instance/
79 | .webassets-cache
80 |
81 | # Scrapy stuff:
82 | .scrapy
83 |
84 | # Sphinx documentation
85 | docs/_build/
86 |
87 | # PyBuilder
88 | target/
89 |
90 | # Jupyter Notebook
91 | .ipynb_checkpoints
92 |
93 | # IPython
94 | profile_default/
95 | ipython_config.py
96 |
97 | # pyenv
98 | .python-version
99 |
100 | # pipenv
101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
104 | # install all needed dependencies.
105 | #Pipfile.lock
106 |
107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
108 | __pypackages__/
109 |
110 | # Celery stuff
111 | celerybeat-schedule
112 | celerybeat.pid
113 |
114 | # SageMath parsed files
115 | *.sage.py
116 |
117 | # Environments
118 | .env
119 | .venv
120 | env/
121 | venv/
122 | ENV/
123 | env.bak/
124 | venv.bak/
125 |
126 | # Spyder project settings
127 | .spyderproject
128 | .spyproject
129 |
130 | # Rope project settings
131 | .ropeproject
132 |
133 | # mkdocs documentation
134 | /site
135 |
136 | # mypy
137 | .mypy_cache/
138 | .dmypy.json
139 | dmypy.json
140 |
141 | # Pyre type checker
142 | .pyre/
143 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Conditional Motion In-Betweening (CMIB)
2 |
3 | Official implementation of the paper: Conditional Motion In-Betweening.
4 |
5 | [Paper](https://www.sciencedirect.com/science/article/pii/S0031320322003752) | [Project Page](https://jihoonerd.github.io/Conditional-Motion-In-Betweening/) | [YouTube](https://youtu.be/XAELcHOREJ8)
6 |
7 |
8 | | in-betweening | pose-conditioned |
9 | | :---: | :---: |
10 | | ![in-betweening](assets/ib.gif) | ![pose-conditioned](assets/pc.gif) |
11 |
12 | | walk | jump | dance |
13 | | :---: | :---: | :---: |
14 | | ![walk](assets/walk.gif) | ![jump](assets/jump.gif) | ![dance](assets/dance.gif) |
15 |
35 | ## Environments
36 |
37 | This repo is tested on the following environment:
38 |
39 | * Ubuntu 20.04
40 | * Python >= 3.7
41 | * PyTorch == 1.10.1
42 | * Cuda V11.3.109
43 |
44 | ## Install
45 |
46 | 1. Follow [`LAFAN1`](https://github.com/ubisoft/ubisoft-laforge-animation-dataset) dataset's installation guide.
47 | *You need to install git lfs before cloning the dataset repo.*
48 |
49 | 2. Run LAFAN1's `evaluate.py` to unzip and validate it. (Install `numpy` first if you don't have it)
50 | ```bash
51 | $ pip install numpy
52 | $ python ubisoft-laforge-animation-dataset/evaluate.py
53 | ```
54 | With this, you will have the unpacked LAFAN1 dataset under the `ubisoft-laforge-animation-dataset` folder.
55 |
56 | 3. Install the appropriate `pytorch` version for your device (CPU/GPU), then install the packages listed in `requirements.txt` (see the sketch below).
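
For example, with a CUDA 11.3 build of PyTorch the install might look like the following sketch (the exact wheel version and index URL depend on your CUDA version and platform, so adjust accordingly):

```bash
# Install a PyTorch build matching your device (example: CUDA 11.3 wheels)
$ pip install torch==1.10.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113

# Install the remaining dependencies
$ pip install -r requirements.txt
```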
57 |
58 | ## Trained Weights
59 |
60 | You can download trained weights from [here](https://works.do/FCqKVjy).
61 |
62 | ## Train from Scratch
63 |
64 | The training script is `trainer.py`.
65 |
66 | ```bash
67 | python trainer.py \
68 | --processed_data_dir="processed_data_80/" \
69 | --window=90 \
70 | --batch_size=32 \
71 | --epochs=5000 \
72 | --device=0 \
73 | --entity=cmib_exp \
74 | --exp_name="cmib_80" \
75 | --save_interval=50 \
76 | --learning_rate=0.0001 \
77 | --loss_cond_weight=1.5 \
78 | --loss_pos_weight=0.05 \
79 | --loss_rot_weight=2.0 \
80 | --from_idx=9 \
81 | --target_idx=88 \
82 | --interpolation='slerp'
83 |
84 | ```
85 |
86 | ## Inference
87 |
88 | You can use `run_cmib.py` for inference. Please refer to help page of `run_cmib.py` for more details.
89 |
90 | ```bash
91 | python run_cmib.py --help
92 | ```
93 |
94 | ## Reference
95 |
96 | * LAFAN1 Dataset
97 | ```
98 | @article{harvey2020robust,
99 | author = {Félix G. Harvey and Mike Yurick and Derek Nowrouzezahrai and Christopher Pal},
100 | title = {Robust Motion In-Betweening},
101 | booktitle = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH)},
102 | publisher = {ACM},
103 | volume = {39},
104 | number = {4},
105 | year = {2020}
106 | }
107 | ```
108 |
109 | ## Citation
110 | ```
111 | @article{KIM2022108894,
112 | title = {Conditional Motion In-betweening},
113 | journal = {Pattern Recognition},
114 | pages = {108894},
115 | year = {2022},
116 | issn = {0031-3203},
117 | doi = {https://doi.org/10.1016/j.patcog.2022.108894},
118 | url = {https://www.sciencedirect.com/science/article/pii/S0031320322003752},
119 | author = {Jihoon Kim and Taehyun Byun and Seungyoun Shin and Jungdam Won and Sungjoon Choi},
120 | keywords = {motion in-betweening, conditional motion generation, generative model, motion data augmentation},
121 | abstract = {Motion in-betweening (MIB) is a process of generating intermediate skeletal movement between the given start and target poses while preserving the naturalness of the motion, such as periodic footstep motion while walking. Although state-of-the-art MIB methods are capable of producing plausible motions given sparse key-poses, they often lack the controllability to generate motions satisfying the semantic contexts required in practical applications. We focus on the method that can handle pose or semantic conditioned MIB tasks using a unified model. We also present a motion augmentation method to improve the quality of pose-conditioned motion generation via defining a distribution over smooth trajectories. Our proposed method outperforms the existing state-of-the-art MIB method in pose prediction errors while providing additional controllability. Our code and results are available on our project web page: https://jihoonerd.github.io/Conditional-Motion-In-Betweening}
122 | }
123 | ```
124 |
125 | ## Author
126 |
127 | * [Jihoon Kim](https://github.com/jihoonerd)
128 | * [Taehyun Byun](https://github.com/childtoy)
129 |
--------------------------------------------------------------------------------
/assets/TitlePage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/assets/TitlePage.png
--------------------------------------------------------------------------------
/assets/dance.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/assets/dance.gif
--------------------------------------------------------------------------------
/assets/graphical_abstract.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/assets/graphical_abstract.jpg
--------------------------------------------------------------------------------
/assets/ib.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/assets/ib.gif
--------------------------------------------------------------------------------
/assets/jump.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/assets/jump.gif
--------------------------------------------------------------------------------
/assets/pc.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/assets/pc.gif
--------------------------------------------------------------------------------
/assets/walk.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/assets/walk.gif
--------------------------------------------------------------------------------
/cmib/data/lafan1_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from cmib.lafan1 import extract, utils
3 | import numpy as np
4 | import pickle
5 | import os
6 |
7 |
8 | class LAFAN1Dataset(Dataset):
9 | def __init__(
10 | self,
11 | lafan_path: str,
12 | processed_data_dir: str,
13 | train: bool,
14 | device: str,
15 | window: int = 65,
16 | dataset: str = 'LAFAN'
17 | ):
18 | self.lafan_path = lafan_path
19 |
20 | self.train = train
21 | # 4.3: It contains actions performed by 5 subjects, with Subject 5 used as the test set.
22 | self.dataset = dataset
23 |
24 | if self.dataset == 'LAFAN':
25 | self.actors = (
26 | ["subject1", "subject2", "subject3", "subject4"] if train else ["subject5"]
27 | )
28 | elif self.dataset in ['HumanEva', 'PosePrior']:
29 | self.actors = (
30 | ["subject1", "subject2"] if train else ["subject3"]
31 | )
32 | elif self.dataset in ['HUMAN4D']:
33 | self.actors = (
34 | ["subject1", "subject2", "subject3", "subject4", "subject5", "subject6", "subject7"] if train else ["subject8"]
35 | )
36 | elif self.dataset in ["MPI_HDM05"]:
37 | self.actors = (
38 | ["subject1", "subject2", "subject3"] if train else ["subject4"]
39 | )
40 | else:
41 | raise ValueError("Invalid Dataset")
42 |
43 | # 4.3: ... The training statistics for normalization are computed on windows of 50 frames offset by 20 frames.
44 | self.window = window
45 |
46 | # 4.3: Given the larger size of ... we sample our test windows from Subject 5 at every 40 frames.
47 | # The training statistics for normalization are computed on windows of 50 frames offset by 20 frames.
48 | self.offset = 20 if self.train else 40
49 |
50 | self.device = device
51 |
52 | pickle_name = "processed_train_data.pkl" if train else "processed_test_data.pkl"
53 |
54 | if pickle_name in os.listdir(processed_data_dir):
55 | with open(os.path.join(processed_data_dir, pickle_name), "rb") as f:
56 | self.data = pickle.load(f)
57 | else:
58 | self.data = self.load_lafan() # Call this last
59 | with open(os.path.join(processed_data_dir, pickle_name), "wb") as f:
60 | pickle.dump(self.data, f, pickle.HIGHEST_PROTOCOL)
61 |
62 | @property
63 | def root_v_dim(self):
64 | return self.data["root_v"].shape[2]
65 |
66 | @property
67 | def local_q_dim(self):
68 | return self.data["local_q"].shape[2] * self.data["local_q"].shape[3]
69 |
70 | @property
71 | def contact_dim(self):
72 | return self.data["contact"].shape[2]
73 |
74 | @property
75 | def num_joints(self):
76 | return self.data["global_pos"].shape[2]
77 |
78 | def load_lafan(self):
79 | # This uses method provided with LAFAN1.
80 | # X and Q are local position/quaternion. Motions are rotated to make 10th frame facing X+ position.
81 | # Refer to paper 3.1 Data formatting
82 | X, Q, parents, contacts_l, contacts_r, seq_names = extract.get_lafan1_set(
83 | self.lafan_path, self.actors, self.window, self.offset, self.train, self.dataset
84 | )
85 |
86 | # Retrieve global representations. (global quaternion, global positions)
87 | _, global_pos = utils.quat_fk(Q, X, parents)
88 |
89 | input_data = {}
90 | input_data["local_q"] = Q # q_{t}
91 | input_data["local_q_offset"] = Q[:, -1, :, :] # last frame's quaternions
92 | input_data["q_target"] = Q[:, -1, :, :] # q_{T}
93 |
94 | input_data["root_v"] = (
95 | global_pos[:, 1:, 0, :] - global_pos[:, :-1, 0, :]
96 | ) # \dot{r}_{t}
97 | input_data["root_p_offset"] = global_pos[
98 | :, -1, 0, :
99 | ] # last frame's root positions
100 | input_data["root_p"] = global_pos[:, :, 0, :]
101 |
102 | input_data["contact"] = np.concatenate(
103 | [contacts_l, contacts_r], -1
104 | ) # Foot contact
105 | input_data["global_pos"] = global_pos[
106 | :, :, :, :
107 | ] # global positions (N, T, J, 3)
108 | input_data["seq_names"] = seq_names
109 | return input_data
110 |
111 | def __len__(self):
112 | return self.data["global_pos"].shape[0]
113 |
114 | def __getitem__(self, index):
115 | query = {}
116 | query["local_q"] = self.data["local_q"][index].astype(np.float32)
117 | query["local_q_offset"] = self.data["local_q_offset"][index].astype(np.float32)
118 | query["q_target"] = self.data["q_target"][index].astype(np.float32)
119 | query["root_v"] = self.data["root_v"][index].astype(np.float32)
120 | query["root_p_offset"] = self.data["root_p_offset"][index].astype(np.float32)
121 | query["root_p"] = self.data["root_p"][index].astype(np.float32)
122 | query["contact"] = self.data["contact"][index].astype(np.float32)
123 | query["global_pos"] = self.data["global_pos"][index].astype(np.float32)
124 | return query
125 |
--------------------------------------------------------------------------------
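
A minimal usage sketch for `LAFAN1Dataset`, wrapping it in a standard `DataLoader`. The paths below are placeholders: point `lafan_path` at your unpacked LAFAN1 BVH folder and `processed_data_dir` at an existing directory used to cache the processed pickle; the window of 90 mirrors the training command in the README.

```python
from torch.utils.data import DataLoader

from cmib.data.lafan1_dataset import LAFAN1Dataset

# Placeholder paths: adjust to your local setup.
lafan_dataset = LAFAN1Dataset(
    lafan_path="ubisoft-laforge-animation-dataset/output/BVH",
    processed_data_dir="processed_data_80/",
    train=True,
    device="cpu",
    window=90,
)

loader = DataLoader(lafan_dataset, batch_size=32, shuffle=True)
batch = next(iter(loader))
print(batch["local_q"].shape)     # (batch, window, joints, 4) local quaternions
print(batch["global_pos"].shape)  # (batch, window, joints, 3) global joint positions
```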
/cmib/data/quaternion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | import numpy as np
10 |
11 | # PyTorch-backed implementations
12 |
13 |
14 | def qmul(q, r):
15 | """
16 | Multiply quaternion(s) q with quaternion(s) r.
17 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
18 | Returns q*r as a tensor of shape (*, 4).
19 | """
20 | assert q.shape[-1] == 4
21 | assert r.shape[-1] == 4
22 |
23 | original_shape = q.shape
24 |
25 | # Compute outer product
26 | terms = torch.bmm(r.reshape(-1, 4, 1), q.reshape(-1, 1, 4))
27 |
28 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
29 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
30 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
31 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
32 | return torch.stack((w, x, y, z), dim=1).view(original_shape)
33 |
34 |
35 | def qrot(q, v):
36 | """
37 | Rotate vector(s) v about the rotation described by quaternion(s) q.
38 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
39 | where * denotes any number of dimensions.
40 | Returns a tensor of shape (*, 3).
41 | """
42 | assert q.shape[-1] == 4
43 | assert v.shape[-1] == 3
44 | assert q.shape[:-1] == v.shape[:-1]
45 |
46 | original_shape = list(v.shape)
47 | q = q.reshape(-1, 4)
48 | v = v.reshape(-1, 3)
49 |
50 | qvec = q[:, 1:]
51 | uv = torch.cross(qvec, v, dim=1)
52 | uuv = torch.cross(qvec, uv, dim=1)
53 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
54 |
55 |
56 | def qeuler(q, order, epsilon=0):
57 | """
58 | Convert quaternion(s) q to Euler angles.
59 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
60 | Returns a tensor of shape (*, 3).
61 | """
62 | assert q.shape[-1] == 4
63 |
64 | original_shape = list(q.shape)
65 | original_shape[-1] = 3
66 | q = q.view(-1, 4)
67 |
68 | q0 = q[:, 0]
69 | q1 = q[:, 1]
70 | q2 = q[:, 2]
71 | q3 = q[:, 3]
72 |
73 | if order == "xyz":
74 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
75 | y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
76 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
77 | elif order == "yzx":
78 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
79 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
80 | z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
81 | elif order == "zxy":
82 | x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
83 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
84 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
85 | elif order == "xzy":
86 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
87 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
88 | z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
89 | elif order == "yxz":
90 | x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
91 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
92 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
93 | elif order == "zyx":
94 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
95 | y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
96 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
97 | else:
98 | raise ValueError(f"Invalid Euler order: {order}")
99 |
100 | return torch.stack((x, y, z), dim=1).view(original_shape)
101 |
102 |
103 | # Numpy-backed implementations
104 |
105 |
106 | def qmul_np(q, r):
107 | q = torch.from_numpy(q).contiguous()
108 | r = torch.from_numpy(r).contiguous()
109 | return qmul(q, r).numpy()
110 |
111 |
112 | def qrot_np(q, v):
113 | q = torch.from_numpy(q).contiguous()
114 | v = torch.from_numpy(v).contiguous()
115 | return qrot(q, v).numpy()
116 |
117 |
118 | def qeuler_np(q, order, epsilon=0, use_gpu=False):
119 | if use_gpu:
120 | q = torch.from_numpy(q).cuda()
121 | return qeuler(q, order, epsilon).cpu().numpy()
122 | else:
123 | q = torch.from_numpy(q).contiguous()
124 | return qeuler(q, order, epsilon).numpy()
125 |
126 |
127 | def qfix(q):
128 | """
129 | Enforce quaternion continuity across the time dimension by selecting
130 | the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
131 | between two consecutive frames.
132 |
133 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
134 | Returns a tensor of the same shape.
135 | """
136 | assert len(q.shape) == 3
137 | assert q.shape[-1] == 4
138 |
139 | result = q.copy()
140 | dot_products = np.sum(q[1:] * q[:-1], axis=2)
141 | mask = dot_products < 0
142 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
143 | result[1:][mask] *= -1
144 | return result
145 |
146 |
147 | def expmap_to_quaternion(e):
148 | """
149 | Convert axis-angle rotations (aka exponential maps) to quaternions.
150 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
151 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
152 | Returns a tensor of shape (*, 4).
153 | """
154 | assert e.shape[-1] == 3
155 |
156 | original_shape = list(e.shape)
157 | original_shape[-1] = 4
158 | e = e.reshape(-1, 3)
159 |
160 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
161 | w = np.cos(0.5 * theta).reshape(-1, 1)
162 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
163 | return np.concatenate((w, xyz), axis=1).reshape(original_shape)
164 |
165 |
166 | def euler_to_quaternion(e, order):
167 | """
168 | Convert Euler angles to quaternions.
169 | """
170 | assert e.shape[-1] == 3
171 |
172 | original_shape = list(e.shape)
173 | original_shape[-1] = 4
174 |
175 | e = e.reshape(-1, 3)
176 |
177 | x = e[:, 0]
178 | y = e[:, 1]
179 | z = e[:, 2]
180 |
181 | rx = np.stack(
182 | (np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1
183 | )
184 | ry = np.stack(
185 | (np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1
186 | )
187 | rz = np.stack(
188 | (np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1
189 | )
190 |
191 | result = None
192 | for coord in order:
193 | if coord == "x":
194 | r = rx
195 | elif coord == "y":
196 | r = ry
197 | elif coord == "z":
198 | r = rz
199 | else:
200 | raise ValueError(f"Invalid axis in Euler order: {coord}")
201 | if result is None:
202 | result = r
203 | else:
204 | result = qmul_np(result, r)
205 |
206 | # Reverse antipodal representation to have a non-negative "w"
207 | if order in ["xyz", "yzx", "zxy"]:
208 | result *= -1
209 |
210 | return result.reshape(original_shape)
211 |
--------------------------------------------------------------------------------
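
A small round-trip sketch using the NumPy-backed helpers above (values are illustrative; quaternions follow the (w, x, y, z) convention used throughout this file):

```python
import numpy as np

from cmib.data.quaternion import euler_to_quaternion, qeuler_np, qrot_np

# 90-degree rotation about the Z axis, expressed as Euler angles in radians.
euler = np.array([[0.0, 0.0, np.pi / 2]])   # shape (1, 3), order "xyz"
quat = euler_to_quaternion(euler, "xyz")    # shape (1, 4), (w, x, y, z)

# Rotating the X unit vector should give (approximately) the Y unit vector.
v = np.array([[1.0, 0.0, 0.0]])
print(np.round(qrot_np(quat, v), 3))        # -> [[0. 1. 0.]]

# Converting back recovers the original Euler angles.
print(np.round(qeuler_np(quat, "xyz"), 3))  # -> [[0. 0. 1.571]]
```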
/cmib/data/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 | from pathlib import Path
5 | import re
6 | import numpy as np
7 | from cmib.data.quaternion import euler_to_quaternion, qeuler_np
8 |
9 |
10 | def drop_end_quat(quaternions, skeleton):
11 | """
12 | quaternions: [N,T,Joints,4]
13 | """
14 |
15 | return quaternions[:, :, skeleton.has_children()]
16 |
17 |
18 | def write_json(filename, local_q, root_pos, joint_names):
19 | json_out = {}
20 | json_out["root_pos"] = root_pos.tolist()
21 | json_out["local_quat"] = local_q.tolist()
22 | json_out["joint_names"] = joint_names
23 | with open(filename, "w") as outfile:
24 | json.dump(json_out, outfile)
25 |
26 |
27 | def flip_bvh(bvh_folder: str, skip: str):
28 | """
29 | Generate LR flip of existing bvh files. Assumes Z-forward.
30 | It does not flip files containing the skip string in their name.
31 | """
32 |
33 | print("Left-Right Flipping Process...")
34 |
35 | # List files which are not flipped yet
36 | to_convert = []
37 | not_convert = []
38 | bvh_files = os.listdir(bvh_folder)
39 | for bvh_file in bvh_files:
40 | if "_LRflip.bvh" in bvh_file:
41 | continue
42 | if skip in bvh_file:
43 | not_convert.append(bvh_file)
44 | continue
45 | flipped_file = bvh_file.replace(".bvh", "_LRflip.bvh")
46 | if flipped_file in bvh_files:
47 | print(f"[SKIP: {bvh_file}] (flipped file already exists)")
48 | continue
49 | to_convert.append(bvh_file)
50 |
51 | print("Following files will be flipped: ")
52 | print(to_convert)
53 | print("Following files are not flipped: ")
54 | print(not_convert)
55 |
56 | for i, converting_fn in enumerate(to_convert):
57 | fout = open(
58 | os.path.join(bvh_folder, converting_fn.replace(".bvh", "_LRflip.bvh")), "w"
59 | )
60 | file_read = open(os.path.join(bvh_folder, converting_fn), "r")
61 | file_lines = file_read.readlines()
62 | hierarchy_part = True
63 | for line in file_lines:
64 | if hierarchy_part:
65 | fout.write(line)
66 | if "Frame Time" in line:
67 | # This should be the last exact copy. Motion line comes next
68 | hierarchy_part = False
69 | else:
70 | # The following references are helpful for understanding which axes need to be inverted:
71 | # http://lo-th.github.io/olympe/BVH_player.html
72 | # https://quaternions.online/
73 | str_to_num = line.split(" ")[:-1] # Extract number only
74 | motion_mat = np.array([float(x) for x in str_to_num]).reshape(
75 | 23, 3
76 | ) # Hips 6 Channel + 3 * 21 = 69
77 | motion_mat[0, 2] *= -1.0 # Invert translation Z axis (forward-backward)
78 | quat = euler_to_quaternion(
79 | np.radians(motion_mat[1:]), "zyx"
80 | ) # This function takes radians
81 | # Invert X-axis (Left-Right) / Quaternion representation: (w, x, y, z)
82 | quat[:, 0] *= -1.0
83 | quat[:, 1] *= -1.0
84 | motion_mat[1:] = np.degrees(qeuler_np(quat, "zyx"))
85 |
86 | # idx 0: Hips world coord, idx 1: Hips rotation
87 | left_idx = [2, 3, 4, 5, 15, 16, 17, 18] # From 2: LeftUpLeg...
88 | right_idx = [6, 7, 8, 9, 19, 20, 21, 22] # From 6: RightUpLeg...
89 | motion_mat[left_idx + right_idx] = motion_mat[
90 | right_idx + left_idx
91 | ].copy()
92 | motion_mat = np.round(motion_mat, decimals=6)
93 | motion_vector = np.reshape(motion_mat, (69,))
94 | motion_part_str = ""
95 | for s in motion_vector:
96 | motion_part_str += str(s) + " "
97 | motion_part_str += "\n"
98 | fout.write(motion_part_str)
99 | print(f"[{i+1}/{len(to_convert)}] {converting_fn} flipped.")
100 |
101 |
102 | def increment_path(path, exist_ok=False, sep="", mkdir=False):
103 | # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
104 | path = Path(path) # os-agnostic
105 | if path.exists() and not exist_ok:
106 | suffix = path.suffix
107 | path = path.with_suffix("")
108 | dirs = glob.glob(f"{path}{sep}*") # similar paths
109 | matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
110 | i = [int(m.groups()[0]) for m in matches if m] # indices
111 | n = max(i) + 1 if i else 2 # increment number
112 | path = Path(f"{path}{sep}{n}{suffix}") # update path
113 | dir = path if path.suffix == "" else path.parent # directory
114 | if not dir.exists() and mkdir:
115 | dir.mkdir(parents=True, exist_ok=True) # make directory
116 | return path
117 |
118 |
119 | def process_seq_names(seq_names, dataset):
120 |
121 | if dataset in ['HumanEva', 'HUMAN4D', 'MPI_HDM05']:
122 | processed_seqname = [x[:-1] for x in seq_names]
123 | elif dataset == 'PosePrior':
124 | processed_seqname = []
125 | for seq in seq_names:
126 | if 'lar' in seq:
127 | pr_seq = 'lar'
128 | elif 'op' in seq:
129 | pr_seq = 'op'
130 | elif 'rom' in seq:
131 | pr_seq = 'rom'
132 | elif 'uar' in seq:
133 | pr_seq = 'uar'
134 | elif 'ulr' in seq:
135 | pr_seq = 'ulr'
136 | else:
137 | raise ValueError('Invalid seq name')
138 | processed_seqname.append(pr_seq)
139 | else:
140 | raise ValueError('Invalid dataset name')
141 | return processed_seqname
--------------------------------------------------------------------------------
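
A quick sketch of `increment_path`, which creates non-colliding output directories (the directory name is illustrative):

```python
from cmib.data.utils import increment_path

# First call returns and creates "runs/exp"; later calls return "runs/exp2",
# "runs/exp3", ... instead of overwriting an existing run directory.
save_dir = increment_path("runs/exp", exist_ok=False, mkdir=True)
print(save_dir)
```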
/cmib/lafan1/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jihoonerd/Conditional-Motion-In-Betweening/40f42c6d2d0e081e2162569180c5e2ad42ce659e/cmib/lafan1/__init__.py
--------------------------------------------------------------------------------
/cmib/lafan1/benchmarks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from . import utils
4 | import torch
5 |
6 | np.set_printoptions(precision=3)
7 |
8 |
9 | def npss(gt_seq, pred_seq):
10 | # Fourier coefficients along the time dimension
11 | gt_fourier_coeffs = torch.real(torch.fft.fft(gt_seq, dim=1))
12 | pred_fourier_coeffs = torch.real(torch.fft.fft(pred_seq, dim=1))
13 |
14 | # Square of the Fourier coefficients
15 | gt_power = torch.square(gt_fourier_coeffs)
16 | pred_power = torch.square(pred_fourier_coeffs)
17 |
18 | # Sum of powers over time dimension
19 | gt_total_power = torch.sum(gt_power, dim=1)
20 | pred_total_power = torch.sum(pred_power, dim=1)
21 |
22 | # Normalize powers with totals
23 | gt_norm_power = gt_power / gt_total_power.unsqueeze(1)
24 | pred_norm_power = pred_power / pred_total_power.unsqueeze(1)
25 |
26 | # Cumulative sum over time
27 | cdf_gt_power = torch.cumsum(gt_norm_power, dim=1)
28 | cdf_pred_power = torch.cumsum(pred_norm_power, dim=1)
29 |
30 | # Earth mover distance
31 | emd = torch.norm((cdf_pred_power - cdf_gt_power), p=1, dim=1)
32 |
33 | # Weighted EMD
34 | power_weighted_emd = torch.sum(emd * gt_total_power) / torch.sum(gt_total_power)
35 |
36 | return power_weighted_emd
37 |
38 |
39 | def fast_npss(gt_seq, pred_seq):
40 | """
41 | Computes Normalized Power Spectrum Similarity (NPSS).
42 |
43 | This is the metric proposed by Gopalakrishnan et al. (2019).
44 | This implementation uses numpy parallelism for improved performance.
45 |
46 | :param gt_seq: ground-truth array of shape : (Batchsize, Timesteps, Dimension)
47 | :param pred_seq: shape : (Batchsize, Timesteps, Dimension)
48 | :return: The average npss metric for the batch
49 | """
50 | # Fourier coefficients along the time dimension
51 | gt_fourier_coeffs = np.real(np.fft.fft(gt_seq, axis=1))
52 | pred_fourier_coeffs = np.real(np.fft.fft(pred_seq, axis=1))
53 |
54 | # Square of the Fourier coefficients
55 | gt_power = np.square(gt_fourier_coeffs)
56 | pred_power = np.square(pred_fourier_coeffs)
57 |
58 | # Sum of powers over time dimension
59 | gt_total_power = np.sum(gt_power, axis=1)
60 | pred_total_power = np.sum(pred_power, axis=1)
61 |
62 | # Normalize powers with totals
63 | gt_norm_power = gt_power / gt_total_power[:, np.newaxis, :]
64 | pred_norm_power = pred_power / pred_total_power[:, np.newaxis, :]
65 |
66 | # Cumulative sum over time
67 | cdf_gt_power = np.cumsum(gt_norm_power, axis=1)
68 | cdf_pred_power = np.cumsum(pred_norm_power, axis=1)
69 |
70 | # Earth mover distance
71 | emd = np.linalg.norm((cdf_pred_power - cdf_gt_power), ord=1, axis=1)
72 |
73 | # Weighted EMD
74 | power_weighted_emd = np.average(emd, weights=gt_total_power)
75 |
76 | return power_weighted_emd
77 |
78 |
79 | def flatjoints(x):
80 | """
81 | Shorthand for a common reshape pattern. Collapses all but the two first dimensions of a tensor.
82 | :param x: Data tensor of at least 3 dimensions.
83 | :return: The flattened tensor.
84 | """
85 | return x.reshape((x.shape[0], x.shape[1], -1))
86 |
87 |
88 | def benchmark_interpolation(
89 | X, Q, x_mean, x_std, offsets, parents, out_path=None, n_past=10, n_future=10
90 | ):
91 | """
92 | Evaluate naive baselines (zero-velocity and interpolation) for transition generation on given data.
93 | :param X: Local positions array of shape (Batchsize, Timesteps, Joints, 3)
94 | :param Q: Local quaternions array of shape (B, T, J, 4)
95 | :param x_mean: Mean vector of local positions of shape (1, J*3, 1)
96 | :param x_std: Standard deviation vector of local positions of shape (1, J*3, 1)
97 | :param offsets: Local bone offsets tensor of shape (1, 1, J, 3)
98 | :param parents: List of bone parents indices defining the hierarchy
99 | :param out_path: optional path for saving the results
100 | :param n_past: Number of frames used as past context
101 | :param n_future: Number of frames used as future context (only the first frame is used as the target)
102 | :return: Results dictionary
103 | """
104 |
105 | trans_lengths = [5, 15, 30, 45]
106 | n_joints = 22
107 | res = {}
108 |
109 | for n_trans in trans_lengths:
110 | print("Computing errors for transition length = {}...".format(n_trans))
111 |
112 | # Format the data for the current transition lengths. The number of samples and the offset stays unchanged.
113 | curr_window = n_trans + n_past + n_future
114 | curr_x = X[:, :curr_window, ...]
115 | curr_q = Q[:, :curr_window, ...]
116 | batchsize = curr_x.shape[0]
117 |
118 | # Ground-truth positions/quats/eulers
119 | gt_local_quats = curr_q
120 | gt_roots = curr_x[:, :, 0:1, :]
121 | gt_offsets = np.tile(offsets, [batchsize, curr_window, 1, 1])
122 | gt_local_poses = np.concatenate([gt_roots, gt_offsets], axis=2)
123 | trans_gt_local_poses = gt_local_poses[:, n_past:-n_future, ...]
124 | trans_gt_local_quats = gt_local_quats[:, n_past:-n_future, ...]
125 | # Local to global with Forward Kinematics (FK)
126 | trans_gt_global_quats, trans_gt_global_poses = utils.quat_fk(
127 | trans_gt_local_quats, trans_gt_local_poses, parents
128 | )
129 | trans_gt_global_poses = trans_gt_global_poses.reshape(
130 | (trans_gt_global_poses.shape[0], -1, n_joints * 3)
131 | ).transpose([0, 2, 1])
132 | # Normalize
133 | trans_gt_global_poses = (trans_gt_global_poses - x_mean) / x_std
134 |
135 | # Zero-velocity pos/quats
136 | zerov_trans_local_quats, zerov_trans_local_poses = (
137 | np.zeros_like(trans_gt_local_quats),
138 | np.zeros_like(trans_gt_local_poses),
139 | )
140 | zerov_trans_local_quats[:, :, :, :] = gt_local_quats[
141 | :, n_past - 1 : n_past, :, :
142 | ]
143 | zerov_trans_local_poses[:, :, :, :] = gt_local_poses[
144 | :, n_past - 1 : n_past, :, :
145 | ]
146 | # To global
147 | trans_zerov_global_quats, trans_zerov_global_poses = utils.quat_fk(
148 | zerov_trans_local_quats, zerov_trans_local_poses, parents
149 | )
150 | trans_zerov_global_poses = trans_zerov_global_poses.reshape(
151 | (trans_zerov_global_poses.shape[0], -1, n_joints * 3)
152 | ).transpose([0, 2, 1])
153 | # Normalize
154 | trans_zerov_global_poses = (trans_zerov_global_poses - x_mean) / x_std
155 |
156 | # Interpolation pos/quats
157 | r, q = curr_x[:, :, 0:1], curr_q
158 | inter_root, inter_local_quats = utils.interpolate_local(r, q, n_past, n_future)
159 | trans_inter_root = inter_root[:, 1:-1, :, :]
160 | trans_inter_offsets = np.tile(offsets, [batchsize, n_trans, 1, 1])
161 | trans_inter_local_poses = np.concatenate(
162 | [trans_inter_root, trans_inter_offsets], axis=2
163 | )
164 | inter_local_quats = inter_local_quats[:, 1:-1, :, :]
165 | # To global
166 | trans_interp_global_quats, trans_interp_global_poses = utils.quat_fk(
167 | inter_local_quats, trans_inter_local_poses, parents
168 | )
169 | trans_interp_global_poses = trans_interp_global_poses.reshape(
170 | (trans_interp_global_poses.shape[0], -1, n_joints * 3)
171 | ).transpose([0, 2, 1])
172 | # Normalize
173 | trans_interp_global_poses = (trans_interp_global_poses - x_mean) / x_std
174 |
175 | # Local quaternion loss
176 | res[("zerov_quat_loss", n_trans)] = np.mean(
177 | np.sqrt(
178 | np.sum(
179 | (trans_zerov_global_quats - trans_gt_global_quats) ** 2.0,
180 | axis=(2, 3),
181 | )
182 | )
183 | )
184 | res[("interp_quat_loss", n_trans)] = np.mean(
185 | np.sqrt(
186 | np.sum(
187 | (trans_interp_global_quats - trans_gt_global_quats) ** 2.0,
188 | axis=(2, 3),
189 | )
190 | )
191 | )
192 |
193 | # Global positions loss
194 | res[("zerov_pos_loss", n_trans)] = np.mean(
195 | np.sqrt(
196 | np.sum(
197 | (trans_zerov_global_poses - trans_gt_global_poses) ** 2.0, axis=1
198 | )
199 | )
200 | )
201 | res[("interp_pos_loss", n_trans)] = np.mean(
202 | np.sqrt(
203 | np.sum(
204 | (trans_interp_global_poses - trans_gt_global_poses) ** 2.0, axis=1
205 | )
206 | )
207 | )
208 |
209 | # NPSS loss on global quaternions
210 | res[("zerov_npss_loss", n_trans)] = fast_npss(
211 | flatjoints(trans_gt_global_quats), flatjoints(trans_zerov_global_quats)
212 | )
213 | res[("interp_npss_loss", n_trans)] = fast_npss(
214 | flatjoints(trans_gt_global_quats), flatjoints(trans_interp_global_quats)
215 | )
216 |
217 | print()
218 | avg_zerov_quat_losses = [res[("zerov_quat_loss", n)] for n in trans_lengths]
219 | avg_interp_quat_losses = [res[("interp_quat_loss", n)] for n in trans_lengths]
220 | print("=== Global quat losses ===")
221 | print(
222 | "{0: <16} | {1:6d} | {2:6d} | {3:6d} | {4:6d}".format("Lengths", 5, 15, 30, 45)
223 | )
224 | print(
225 | "{0: <16} | {1:6.2f} | {2:6.2f} | {3:6.2f} | {4:6.2f}".format(
226 | "Zero-V", *avg_zerov_quat_losses
227 | )
228 | )
229 | print(
230 | "{0: <16} | {1:6.2f} | {2:6.2f} | {3:6.2f} | {4:6.2f}".format(
231 | "Interp.", *avg_interp_quat_losses
232 | )
233 | )
234 | print()
235 |
236 | avg_zerov_pos_losses = [res[("zerov_pos_loss", n)] for n in trans_lengths]
237 | avg_interp_pos_losses = [res[("interp_pos_loss", n)] for n in trans_lengths]
238 | print("=== Global pos losses ===")
239 | print(
240 | "{0: <16} | {1:6d} | {2:6d} | {3:6d} | {4:6d}".format("Lengths", 5, 15, 30, 45)
241 | )
242 | print(
243 | "{0: <16} | {1:6.3f} | {2:6.3f} | {3:6.3f} | {4:6.3f}".format(
244 | "Zero-V", *avg_zerov_pos_losses
245 | )
246 | )
247 | print(
248 | "{0: <16} | {1:6.3f} | {2:6.3f} | {3:6.3f} | {4:6.3f}".format(
249 | "Interp.", *avg_interp_pos_losses
250 | )
251 | )
252 | print()
253 |
254 | avg_zerov_npss_losses = [res[("zerov_npss_loss", n)] for n in trans_lengths]
255 | avg_interp_npss_losses = [res[("interp_npss_loss", n)] for n in trans_lengths]
256 | print("=== NPSS on global quats ===")
257 | print(
258 | "{0: <16} | {1:5d} | {2:5d} | {3:5d} | {4:5d}".format(
259 | "Lengths", 5, 15, 30, 45
260 | )
261 | )
262 | print(
263 | "{0: <16} | {1:5.4f} | {2:5.4f} | {3:5.4f} | {4:5.4f}".format(
264 | "Zero-V", *avg_zerov_npss_losses
265 | )
266 | )
267 | print(
268 | "{0: <16} | {1:5.4f} | {2:5.4f} | {3:5.4f} | {4:5.4f}".format(
269 | "Interp.", *avg_interp_npss_losses
270 | )
271 | )
272 | print()
273 |
274 | # Write to file if desired
275 | if out_path is not None:
276 | res_txt_file = open(
277 | os.path.join(out_path, "h36m_transitions_benchmark.txt"), "a"
278 | )
279 | res_txt_file.write("\n=== Global quat losses ===\n")
280 | res_txt_file.write(
281 | "{0: <16} | {1:6d} | {2:6d} | {3:6d} | {4:6d}\n".format(
282 | "Lengths", 5, 15, 30, 45
283 | )
284 | )
285 | res_txt_file.write(
286 | "{0: <16} | {1:6.2f} | {2:6.2f} | {3:6.2f} | {4:6.2f}\n".format(
287 | "Zero-V", *avg_zerov_quat_losses
288 | )
289 | )
290 | res_txt_file.write(
291 | "{0: <16} | {1:6.2f} | {2:6.2f} | {3:6.2f} | {4:6.2f}\n".format(
292 | "Interp.", *avg_interp_quat_losses
293 | )
294 | )
295 | res_txt_file.write("\n\n")
296 | res_txt_file.write("=== Global pos losses ===\n")
297 | res_txt_file.write(
298 | "{0: <16} | {1:5d} | {2:5d} | {3:5d} | {4:5d}\n".format(
299 | "Lengths", 5, 15, 30, 45
300 | )
301 | )
302 | res_txt_file.write(
303 | "{0: <16} | {1:5.4f} | {2:5.4f} | {3:5.4f} | {4:5.4f}\n".format(
304 | "Zero-V", *avg_zerov_pos_losses
305 | )
306 | )
307 | res_txt_file.write(
308 | "{0: <16} | {1:5.4f} | {2:5.4f} | {3:5.4f} | {4:5.4f}\n".format(
309 | "Interp.", *avg_interp_pos_losses
310 | )
311 | )
312 | res_txt_file.write("\n\n")
313 | res_txt_file.write("=== NPSS on global quats ===\n")
314 | res_txt_file.write(
315 | "{0: <16} | {1:5d} | {2:5d} | {3:5d} | {4:5d}\n".format(
316 | "Lengths", 5, 15, 30, 45
317 | )
318 | )
319 | res_txt_file.write(
320 | "{0: <16} | {1:5.4f} | {2:5.4f} | {3:5.4f} | {4:5.4f}\n".format(
321 | "Zero-V", *avg_zerov_npss_losses
322 | )
323 | )
324 | res_txt_file.write(
325 | "{0: <16} | {1:5.4f} | {2:5.4f} | {3:5.4f} | {4:5.4f}\n".format(
326 | "Interp.", *avg_interp_npss_losses
327 | )
328 | )
329 | res_txt_file.write("\n\n\n\n")
330 | res_txt_file.close()
331 |
332 | return res
333 |
--------------------------------------------------------------------------------
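
A toy sketch of the NPSS metric on synthetic data (shapes and noise level are arbitrary; they only illustrate the `(Batchsize, Timesteps, Dimension)` input expected after `flatjoints`):

```python
import numpy as np

from cmib.lafan1.benchmarks import fast_npss, flatjoints

rng = np.random.default_rng(0)

# Fake global quaternions: (Batchsize, Timesteps, Joints, 4), flattened before scoring.
gt = rng.normal(size=(8, 30, 22, 4))
pred = gt + 0.05 * rng.normal(size=gt.shape)

score = fast_npss(flatjoints(gt), flatjoints(pred))
print(f"NPSS: {score:.4f}")  # 0.0 means identical normalized power spectra
```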
/cmib/lafan1/extract.py:
--------------------------------------------------------------------------------
1 | import re, os, ntpath
2 | import numpy as np
3 | from . import utils
4 |
5 | channelmap = {"Xrotation": "x", "Yrotation": "y", "Zrotation": "z"}
6 |
7 | channelmap_inv = {
8 | "x": "Xrotation",
9 | "y": "Yrotation",
10 | "z": "Zrotation",
11 | }
12 |
13 | ordermap = {
14 | "x": 0,
15 | "y": 1,
16 | "z": 2,
17 | }
18 |
19 |
20 | class Anim(object):
21 | """
22 | A very basic animation object
23 | """
24 |
25 | def __init__(self, quats, pos, offsets, parents, bones):
26 | """
27 | :param quats: local quaternions tensor
28 | :param pos: local positions tensor
29 | :param offsets: local joint offsets
30 | :param parents: bone hierarchy
31 | :param bones: bone names
32 | """
33 | self.quats = quats
34 | self.pos = pos
35 | self.offsets = offsets
36 | self.parents = parents
37 | self.bones = bones
38 |
39 |
40 | def read_bvh(filename, start=None, end=None, order=None):
41 | """
42 | Reads a BVH file and extracts animation information.
43 |
44 | :param filename: BVH filename
45 | :param start: start frame
46 | :param end: end frame
47 | :param order: order of euler rotations
48 | :return: A simple Anim object containing the extracted information.
49 | """
50 |
51 | f = open(filename, "r")
52 |
53 | i = 0
54 | active = -1
55 | end_site = False
56 |
57 | names = []
58 | orients = np.array([]).reshape((0, 4))
59 | offsets = np.array([]).reshape((0, 3))
60 | parents = np.array([], dtype=int)
61 |
62 | # Parse the file, line by line
63 | for line in f:
64 |
65 | if "HIERARCHY" in line:
66 | continue
67 | if "MOTION" in line:
68 | continue
69 |
70 | rmatch = re.match(r"ROOT (\w+)", line)
71 | if rmatch:
72 | names.append(rmatch.group(1))
73 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)
74 | orients = np.append(orients, np.array([[1, 0, 0, 0]]), axis=0)
75 | parents = np.append(parents, active)
76 | active = len(parents) - 1
77 | continue
78 |
79 | if "{" in line:
80 | continue
81 |
82 | if "}" in line:
83 | if end_site:
84 | end_site = False
85 | else:
86 | active = parents[active]
87 | continue
88 |
89 | offmatch = re.match(
90 | r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line
91 | )
92 | if offmatch:
93 | if not end_site:
94 | offsets[active] = np.array([list(map(float, offmatch.groups()))])
95 | continue
96 |
97 | chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line)
98 | if chanmatch:
99 | channels = int(chanmatch.group(1))
100 | if order is None:
101 | channelis = 0 if channels == 3 else 3
102 | channelie = 3 if channels == 3 else 6
103 | parts = line.split()[2 + channelis : 2 + channelie]
104 | if any([p not in channelmap for p in parts]):
105 | continue
106 | order = "".join([channelmap[p] for p in parts])
107 | continue
108 |
109 | jmatch = re.match("\s*JOINT\s+(\w+)", line)
110 | if jmatch:
111 | names.append(jmatch.group(1))
112 | offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)
113 | orients = np.append(orients, np.array([[1, 0, 0, 0]]), axis=0)
114 | parents = np.append(parents, active)
115 | active = len(parents) - 1
116 | continue
117 |
118 | if "End Site" in line:
119 | end_site = True
120 | continue
121 |
122 | fmatch = re.match("\s*Frames:\s+(\d+)", line)
123 | if fmatch:
124 | if start and end:
125 | fnum = (end - start) - 1
126 | else:
127 | fnum = int(fmatch.group(1))
128 | positions = offsets[np.newaxis].repeat(fnum, axis=0)
129 | rotations = np.zeros((fnum, len(orients), 3))
130 | continue
131 |
132 | fmatch = re.match("\s*Frame Time:\s+([\d\.]+)", line)
133 | if fmatch:
134 | frametime = float(fmatch.group(1))
135 | continue
136 |
137 | if (start and end) and (i < start or i >= end - 1):
138 | i += 1
139 | continue
140 |
141 | dmatch = line.strip().split(" ")
142 | if dmatch:
143 | data_block = np.array(list(map(float, dmatch)))
144 | N = len(parents)
145 | fi = i - start if start else i
146 | if channels == 3:
147 | positions[fi, 0:1] = data_block[0:3]
148 | rotations[fi, :] = data_block[3:].reshape(N, 3)
149 | elif channels == 6:
150 | data_block = data_block.reshape(N, 6)
151 | positions[fi, :] = data_block[:, 0:3]
152 | rotations[fi, :] = data_block[:, 3:6]
153 | elif channels == 9:
154 | positions[fi, 0] = data_block[0:3]
155 | data_block = data_block[3:].reshape(N - 1, 9)
156 | rotations[fi, 1:] = data_block[:, 3:6]
157 | positions[fi, 1:] += data_block[:, 0:3] * data_block[:, 6:9]
158 | else:
159 | raise Exception("Too many channels! %i" % channels)
160 |
161 | i += 1
162 |
163 | f.close()
164 |
165 | rotations = utils.euler_to_quat(np.radians(rotations), order=order)
166 | rotations = utils.remove_quat_discontinuities(rotations)
167 |
168 | return Anim(rotations, positions, offsets, parents, names)
169 |
170 |
171 | def get_lafan1_set(bvh_path, actors, window=50, offset=20, train=True, stats=False, dataset='LAFAN'):
172 | """
173 | Extract the same test set as in the article, given the location of the BVH files.
174 |
175 | :param bvh_path: Path to the dataset BVH files
176 | :param actors: actor prefixes to use in the set
177 | :param window: width of the sliding windows (in timesteps)
178 | :param offset: offset between windows (in timesteps)
179 | :return: tuple:
180 | X: local positions
181 | Q: local quaternions
182 | parents: list of parent indices defining the bone hierarchy
183 | contacts_l: binary tensor of left-foot contacts of shape (Batchsize, Timesteps, 2)
184 | contacts_r: binary tensor of right-foot contacts of shape (Batchsize, Timesteps, 2)
185 | """
186 | npast = 10
187 | subjects = []
188 | seq_names = []
189 | X = []
190 | Q = []
191 | contacts_l = []
192 | contacts_r = []
193 |
194 | # Extract
195 | bvh_files = sorted(os.listdir(bvh_path))
196 |
197 | for file in bvh_files:
198 | if file.endswith(".bvh"):
199 | file_info = ntpath.basename(file[:-4]).split("_")
200 | seq_name = file_info[0]
201 | subject = file_info[1]
202 |
203 | if (not train) and (file_info[-1] == "LRflip"):
204 | continue
205 |
206 | if stats and (file_info[-1] == "LRflip"):
207 | continue
208 |
209 | # seq_name, subject = ntpath.basename(file[:-4]).split("_")
210 | if subject in actors:
211 | print("Processing file {}".format(file))
212 | seq_path = os.path.join(bvh_path, file)
213 | anim = read_bvh(seq_path)
214 |
215 | # Sliding windows
216 | i = 0
217 | while i + window < anim.pos.shape[0]:
218 | q, x = utils.quat_fk(
219 | anim.quats[i : i + window],
220 | anim.pos[i : i + window],
221 | anim.parents,
222 | )
223 | # Extract contacts
224 | c_l, c_r = utils.extract_feet_contacts(
225 | x, [3, 4], [7, 8], velfactor=0.02
226 | )
227 | X.append(anim.pos[i : i + window])
228 | Q.append(anim.quats[i : i + window])
229 | seq_names.append(seq_name)
230 | subjects.append(subject)
231 | contacts_l.append(c_l)
232 | contacts_r.append(c_r)
233 |
234 | i += offset
235 |
236 | X = np.asarray(X)
237 | Q = np.asarray(Q)
238 | contacts_l = np.asarray(contacts_l)
239 | contacts_r = np.asarray(contacts_r)
240 |
241 | # Sequences around XZ = 0
242 | xzs = np.mean(X[:, :, 0, ::2], axis=1, keepdims=True)
243 | X[:, :, 0, 0] = X[:, :, 0, 0] - xzs[..., 0]
244 | X[:, :, 0, 2] = X[:, :, 0, 2] - xzs[..., 1]
245 |
246 | # Unify facing on last seed frame
247 | X, Q = utils.rotate_at_frame(X, Q, anim.parents, n_past=npast)
248 |
249 | return X, Q, anim.parents, contacts_l, contacts_r, seq_names
250 |
251 |
252 | def get_train_stats(bvh_folder, train_set):
253 | """
254 | Extract the same training set as in the paper in order to compute the normalizing statistics
255 | :return: Tuple of (local position mean vector, local position standard deviation vector, local joint offsets tensor)
256 | """
257 | print("Building the train set...")
258 | xtrain, qtrain, parents, _, _, _ = get_lafan1_set(
259 | bvh_folder, train_set, window=50, offset=20, train=True, stats=True
260 | )
261 |
262 | print("Computing stats...\n")
263 | # Joint offsets : are constant, so just take the first frame:
264 | offsets = xtrain[0:1, 0:1, 1:, :] # Shape : (1, 1, J, 3)
265 |
266 | # Global representation:
267 | q_glbl, x_glbl = utils.quat_fk(qtrain, xtrain, parents)
268 |
269 | # Global positions stats:
270 | x_mean = np.mean(
271 | x_glbl.reshape([x_glbl.shape[0], x_glbl.shape[1], -1]).transpose([0, 2, 1]),
272 | axis=(0, 2),
273 | keepdims=True,
274 | )
275 | x_std = np.std(
276 | x_glbl.reshape([x_glbl.shape[0], x_glbl.shape[1], -1]).transpose([0, 2, 1]),
277 | axis=(0, 2),
278 | keepdims=True,
279 | )
280 |
281 | return x_mean, x_std, offsets
282 |
--------------------------------------------------------------------------------
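
A usage sketch for `get_lafan1_set` (the BVH path is a placeholder; the actor list matches the training split used in `LAFAN1Dataset`):

```python
from cmib.lafan1 import extract

# Placeholder path: point at the folder containing the unpacked LAFAN1 BVH files.
X, Q, parents, contacts_l, contacts_r, seq_names = extract.get_lafan1_set(
    "ubisoft-laforge-animation-dataset/output/BVH",
    actors=["subject1", "subject2", "subject3", "subject4"],
    window=50,
    offset=20,
    train=True,
)
print(X.shape, Q.shape)  # roughly (N, 50, 22, 3) local positions, (N, 50, 22, 4) local quaternions
```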
/cmib/lafan1/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def length(x, axis=-1, keepdims=True):
5 | """
6 | Computes vector norm along a tensor axis(axes)
7 |
8 | :param x: tensor
9 | :param axis: axis(axes) along which to compute the norm
10 | :param keepdims: indicates if the dimension(s) on axis should be kept
11 | :return: The length or vector of lengths.
12 | """
13 | lgth = np.sqrt(np.sum(x * x, axis=axis, keepdims=keepdims))
14 | return lgth
15 |
16 |
17 | def normalize(x, axis=-1, eps=1e-8):
18 | """
19 | Normalizes a tensor over some axis (axes)
20 |
21 | :param x: data tensor
22 | :param axis: axis(axes) along which to compute the norm
23 | :param eps: epsilon to prevent numerical instabilities
24 | :return: The normalized tensor
25 | """
26 | res = x / (length(x, axis=axis) + eps)
27 | return res
28 |
29 |
30 | def quat_normalize(x, eps=1e-8):
31 | """
32 | Normalizes a quaternion tensor
33 |
34 | :param x: data tensor
35 | :param eps: epsilon to prevent numerical instabilities
36 | :return: The normalized quaternions tensor
37 | """
38 | res = normalize(x, eps=eps)
39 | return res
40 |
41 |
42 | def angle_axis_to_quat(angle, axis):
43 | """
44 | Converts from an angle-axis representation to a quaternion representation
45 |
46 | :param angle: angles tensor
47 | :param axis: axis tensor
48 | :return: quaternion tensor
49 | """
50 | c = np.cos(angle / 2.0)[..., np.newaxis]
51 | s = np.sin(angle / 2.0)[..., np.newaxis]
52 | q = np.concatenate([c, s * axis], axis=-1)
53 | return q
54 |
55 |
56 | def euler_to_quat(e, order="zyx"):
57 | """
58 |
59 | Converts from an euler representation to a quaternion representation
60 |
61 | :param e: euler tensor
62 | :param order: order of euler rotations
63 | :return: quaternion tensor
64 | """
65 | axis = {
66 | "x": np.asarray([1, 0, 0], dtype=np.float32),
67 | "y": np.asarray([0, 1, 0], dtype=np.float32),
68 | "z": np.asarray([0, 0, 1], dtype=np.float32),
69 | }
70 |
71 | q0 = angle_axis_to_quat(e[..., 0], axis[order[0]])
72 | q1 = angle_axis_to_quat(e[..., 1], axis[order[1]])
73 | q2 = angle_axis_to_quat(e[..., 2], axis[order[2]])
74 |
75 | return quat_mul(q0, quat_mul(q1, q2))
76 |
77 |
78 | def quat_inv(q):
79 | """
80 | Inverts a tensor of quaternions
81 |
82 | :param q: quaternion tensor
83 | :return: tensor of inverted quaternions
84 | """
85 | res = np.asarray([1, -1, -1, -1], dtype=np.float32) * q
86 | return res
87 |
88 |
89 | def quat_fk(lrot, lpos, parents):
90 | """
91 | Performs Forward Kinematics (FK) on local quaternions and local positions to retrieve global representations
92 |
93 | :param lrot: tensor of local quaternions with shape (..., Nb of joints, 4)
94 | :param lpos: tensor of local positions with shape (..., Nb of joints, 3)
95 | :param parents: list of parents indices
96 | :return: tuple of tensors of global quaternion, global positions
97 | """
98 | gp, gr = [lpos[..., :1, :]], [lrot[..., :1, :]]
99 | for i in range(1, len(parents)):
100 | gp.append(
101 | quat_mul_vec(gr[parents[i]], lpos[..., i : i + 1, :]) + gp[parents[i]]
102 | )
103 | gr.append(quat_mul(gr[parents[i]], lrot[..., i : i + 1, :]))
104 |
105 | res = np.concatenate(gr, axis=-2), np.concatenate(gp, axis=-2)
106 | return res
107 |
108 |
109 | def quat_ik(grot, gpos, parents):
110 | """
111 | Performs Inverse Kinematics (IK) on global quaternions and global positions to retrieve local representations
112 |
113 | :param grot: tensor of global quaternions with shape (..., Nb of joints, 4)
114 | :param gpos: tensor of global positions with shape (..., Nb of joints, 3)
115 | :param parents: list of parents indices
116 | :return: tuple of tensors of local quaternion, local positions
117 | """
118 | res = [
119 | np.concatenate(
120 | [
121 | grot[..., :1, :],
122 | quat_mul(quat_inv(grot[..., parents[1:], :]), grot[..., 1:, :]),
123 | ],
124 | axis=-2,
125 | ),
126 | np.concatenate(
127 | [
128 | gpos[..., :1, :],
129 | quat_mul_vec(
130 | quat_inv(grot[..., parents[1:], :]),
131 | gpos[..., 1:, :] - gpos[..., parents[1:], :],
132 | ),
133 | ],
134 | axis=-2,
135 | ),
136 | ]
137 |
138 | return res
139 |
140 |
141 | def quat_mul(x, y):
142 | """
143 | Performs quaternion multiplication on arrays of quaternions
144 |
145 | :param x: tensor of quaternions of shape (..., Nb of joints, 4)
146 | :param y: tensor of quaternions of shape (..., Nb of joints, 4)
147 | :return: The resulting quaternions
148 | """
149 | x0, x1, x2, x3 = x[..., 0:1], x[..., 1:2], x[..., 2:3], x[..., 3:4]
150 | y0, y1, y2, y3 = y[..., 0:1], y[..., 1:2], y[..., 2:3], y[..., 3:4]
151 |
152 | res = np.concatenate(
153 | [
154 | y0 * x0 - y1 * x1 - y2 * x2 - y3 * x3,
155 | y0 * x1 + y1 * x0 - y2 * x3 + y3 * x2,
156 | y0 * x2 + y1 * x3 + y2 * x0 - y3 * x1,
157 | y0 * x3 - y1 * x2 + y2 * x1 + y3 * x0,
158 | ],
159 | axis=-1,
160 | )
161 |
162 | return res
163 |
164 |
165 | def quat_mul_vec(q, x):
166 | """
167 | Performs multiplication of an array of 3D vectors by an array of quaternions (rotation).
168 |
169 | :param q: tensor of quaternions of shape (..., Nb of joints, 4)
170 | :param x: tensor of vectors of shape (..., Nb of joints, 3)
171 | :return: the resulting array of rotated vectors
172 | """
173 | t = 2.0 * np.cross(q[..., 1:], x)
174 | res = x + q[..., 0][..., np.newaxis] * t + np.cross(q[..., 1:], t)
175 |
176 | return res
177 |
178 |
179 | def quat_slerp(x, y, a):
180 | """
181 | Performs spherical linear interpolation (SLERP) between x and y, with proportion a
182 |
183 | :param x: quaternion tensor
184 | :param y: quaternion tensor
185 | :param a: indicator (between 0 and 1) of completion of the interpolation.
186 | :return: tensor of interpolation results
187 | """
188 | len = np.sum(x * y, axis=-1)
189 |
190 | neg = len < 0.0
191 | len[neg] = -len[neg]
192 | y[neg] = -y[neg]
193 |
194 | a = np.zeros_like(x[..., 0]) + a
195 | amount0 = np.zeros(a.shape)
196 | amount1 = np.zeros(a.shape)
197 |
198 | linear = (1.0 - len) < 0.01
199 | omegas = np.arccos(len[~linear])
200 | sinoms = np.sin(omegas)
201 |
202 | amount0[linear] = 1.0 - a[linear]
203 | amount0[~linear] = np.sin((1.0 - a[~linear]) * omegas) / sinoms
204 |
205 | amount1[linear] = a[linear]
206 | amount1[~linear] = np.sin(a[~linear] * omegas) / sinoms
207 | res = amount0[..., np.newaxis] * x + amount1[..., np.newaxis] * y
208 |
209 | return res
210 |
211 |
212 | def quat_between(x, y):
213 | """
214 | Quaternion rotations between two 3D-vector arrays
215 |
216 | :param x: tensor of 3D vectors
217 | :param y: tensor of 3D vectors
218 | :return: tensor of quaternions
219 | """
220 | res = np.concatenate(
221 | [
222 | np.sqrt(np.sum(x * x, axis=-1) * np.sum(y * y, axis=-1))[..., np.newaxis]
223 | + np.sum(x * y, axis=-1)[..., np.newaxis],
224 | np.cross(x, y),
225 | ],
226 | axis=-1,
227 | )
228 | return res
229 |
230 |
231 | def interpolate_local(lcl_r_mb, lcl_q_mb, n_past, n_future):
232 | """
233 | Performs interpolation between 2 frames of an animation sequence.
234 |
235 | The 2 frames are indirectly specified through n_past and n_future.
236 | SLERP is performed on the quaternions
237 | LERP is performed on the root's positions.
238 |
239 | :param lcl_r_mb: Local/Global root positions (B, T, 1, 3)
240 | :param lcl_q_mb: Local quaternions (B, T, J, 4)
241 | :param n_past: Number of frames of past context
242 | :param n_future: Number of frames of future context
243 | :return: Interpolated root and quats
244 | """
245 | # Extract last past frame and target frame
246 | start_lcl_r_mb = lcl_r_mb[:, n_past - 1, :, :][:, None, :, :] # (B, 1, J, 3)
247 | end_lcl_r_mb = lcl_r_mb[:, -n_future, :, :][:, None, :, :]
248 |
249 | start_lcl_q_mb = lcl_q_mb[:, n_past - 1, :, :]
250 | end_lcl_q_mb = lcl_q_mb[:, -n_future, :, :]
251 |
252 | # LERP Local Positions:
253 | n_trans = lcl_r_mb.shape[1] - (n_past + n_future)
254 | interp_ws = np.linspace(0.0, 1.0, num=n_trans + 2, dtype=np.float32)
255 | offset = end_lcl_r_mb - start_lcl_r_mb
256 |
257 | const_trans = np.tile(start_lcl_r_mb, [1, n_trans + 2, 1, 1])
258 | inter_lcl_r_mb = const_trans + (interp_ws)[None, :, None, None] * offset
259 |
260 | # SLERP Local Quats:
261 | interp_ws = np.linspace(0.0, 1.0, num=n_trans + 2, dtype=np.float32)
262 | inter_lcl_q_mb = np.stack(
263 | [
264 | (
265 | quat_normalize(
266 | quat_slerp(
267 | quat_normalize(start_lcl_q_mb), quat_normalize(end_lcl_q_mb), w
268 | )
269 | )
270 | )
271 | for w in interp_ws
272 | ],
273 | axis=1,
274 | )
275 |
276 | return inter_lcl_r_mb, inter_lcl_q_mb
277 |
278 |
279 | def remove_quat_discontinuities(rotations):
280 | """
281 |
282 | Removing quat discontinuities on the time dimension (removing flips)
283 |
284 | :param rotations: Array of quaternions of shape (T, J, 4)
285 | :return: The processed array without quaternion inversion.
286 | """
287 | rots_inv = -rotations
288 |
289 | for i in range(1, rotations.shape[0]):
290 | # Compare dot products
291 | replace_mask = np.sum(
292 | rotations[i - 1 : i] * rotations[i : i + 1], axis=-1
293 | ) < np.sum(rotations[i - 1 : i] * rots_inv[i : i + 1], axis=-1)
294 | replace_mask = replace_mask[..., np.newaxis]
295 | rotations[i] = replace_mask * rots_inv[i] + (1.0 - replace_mask) * rotations[i]
296 |
297 | return rotations
298 |
299 |
300 | # Orient the data according to the last past keyframe
301 | def rotate_at_frame(X, Q, parents, n_past=10):
302 | """
303 | Re-orients the animation data according to the last frame of past context.
304 |
305 | :param X: tensor of local positions of shape (Batchsize, Timesteps, Joints, 3)
306 | :param Q: tensor of local quaternions (Batchsize, Timesteps, Joints, 4)
307 | :param parents: list of parents' indices
308 | :param n_past: number of frames in the past context
309 | :return: The rotated positions X and quaternions Q
310 | """
311 | # Get global quats and global poses (FK)
312 | global_q, global_x = quat_fk(Q, X, parents)
313 |
314 | key_glob_Q = global_q[:, n_past - 1 : n_past, 0:1, :] # (B, 1, 1, 4)
315 | forward = np.array([1, 0, 1])[np.newaxis, np.newaxis, np.newaxis, :] * quat_mul_vec(
316 | key_glob_Q, np.array([0, 1, 0])[np.newaxis, np.newaxis, np.newaxis, :]
317 | )
318 | forward = normalize(forward)
319 | yrot = quat_normalize(quat_between(np.array([1, 0, 0]), forward))
320 | new_glob_Q = quat_mul(quat_inv(yrot), global_q)
321 | new_glob_X = quat_mul_vec(quat_inv(yrot), global_x)
322 |
323 | # back to local quat-pos
324 | Q, X = quat_ik(new_glob_Q, new_glob_X, parents)
325 |
326 | return X, Q
327 |
328 |
329 | def extract_feet_contacts(pos, lfoot_idx, rfoot_idx, velfactor=0.02):
330 | """
331 | Extracts binary tensors of feet contacts
332 |
333 | :param pos: tensor of global positions of shape (Timesteps, Joints, 3)
334 | :param lfoot_idx: indices list of left foot joints
335 | :param rfoot_idx: indices list of right foot joints
336 | :param velfactor: velocity threshold to consider a joint moving or not
337 | :return: binary tensors of left foot contacts and right foot contacts
338 | """
339 | lfoot_xyz = (pos[1:, lfoot_idx, :] - pos[:-1, lfoot_idx, :]) ** 2
340 | contacts_l = np.sum(lfoot_xyz, axis=-1) < velfactor
341 |
342 | rfoot_xyz = (pos[1:, rfoot_idx, :] - pos[:-1, rfoot_idx, :]) ** 2
343 | contacts_r = np.sum(rfoot_xyz, axis=-1) < velfactor
344 |
345 | # Duplicate the last frame for shape consistency
346 | contacts_l = np.concatenate([contacts_l, contacts_l[-1:]], axis=0)
347 | contacts_r = np.concatenate([contacts_r, contacts_r[-1:]], axis=0)
348 |
349 | return contacts_l, contacts_r
350 |
--------------------------------------------------------------------------------
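The quaternion and interpolation helpers in `cmib/lafan1/utils.py` above are easiest to follow with concrete shapes. A minimal sketch, assuming the `cmib` package is importable from the repository root; the quaternions and window sizes are illustrative only:

```python
import numpy as np

from cmib.lafan1.utils import interpolate_local, quat_normalize, quat_slerp

# Identity quaternion and a 90-degree rotation about the Y axis (w, x, y, z order).
q_id = np.array([[1.0, 0.0, 0.0, 0.0]], dtype=np.float32)
q_90 = np.array([[np.cos(np.pi / 4), 0.0, np.sin(np.pi / 4), 0.0]], dtype=np.float32)

# Halfway SLERP should land on a 45-degree rotation about Y.
q_half = quat_normalize(quat_slerp(q_id.copy(), q_90.copy(), 0.5))
print(q_half)  # approx [[0.9239, 0., 0.3827, 0.]]

# interpolate_local expects (B, T, 1, 3) root positions and (B, T, J, 4) local quats.
B, T, J = 2, 50, 22
lcl_r = np.random.randn(B, T, 1, 3).astype(np.float32)
lcl_q = quat_normalize(np.random.randn(B, T, J, 4).astype(np.float32))
inter_r, inter_q = interpolate_local(lcl_r, lcl_q, n_past=10, n_future=10)
print(inter_r.shape, inter_q.shape)  # (2, 32, 1, 3) (2, 32, 22, 4)
```

`interpolate_local` returns the two key frames plus the `n_trans` transition frames between them, i.e. 32 frames for a 50-frame window with 10 frames of past and 10 frames of future context.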
/cmib/misc/sampler.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from pathlib import Path
4 |
5 | import imageio
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from PIL import Image
10 | from sklearn.preprocessing import LabelEncoder
11 |
12 | from cmib.data.lafan1_dataset import LAFAN1Dataset
13 | from cmib.data.utils import write_json
14 | from cmib.lafan1.utils import quat_ik
15 | from cmib.model.network import TransformerModel
16 | from cmib.model.preprocess import (lerp_input_repr, replace_constant,
17 | slerp_input_repr, vectorize_representation)
18 | from cmib.model.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets, joint_names,
19 | sk_parents)
20 | from cmib.vis.pose import plot_pose_with_stop
21 |
22 |
23 | def test(opt, device):
24 |
25 | save_dir = Path(os.path.join('runs', 'train', opt.exp_name))
26 | wdir = save_dir / 'weights'
27 | weights = os.listdir(wdir)
28 | weights_paths = [wdir / weight for weight in weights]
29 | latest_weight = max(weights_paths , key = os.path.getctime)
30 | ckpt = torch.load(latest_weight, map_location=device)
31 | print(f"Loaded weight: {latest_weight}")
32 |
33 | # Load Skeleton
34 | skeleton_mocap = Skeleton(offsets=sk_offsets, parents=sk_parents, device=device)
35 | skeleton_mocap.remove_joints(sk_joints_to_remove)
36 |
37 | # Load LAFAN Dataset
38 | Path(opt.processed_data_dir).mkdir(parents=True, exist_ok=True)
39 | lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path, processed_data_dir=opt.processed_data_dir, train=False, device=device)
40 | total_data = lafan_dataset.data['global_pos'].shape[0]
41 |
42 | # In-betweening frames to be replaced with noise
43 | from_idx, target_idx = ckpt['from_idx'], ckpt['target_idx'] # default: 9-40, max: 48
44 | horizon = ckpt['horizon']
45 | print(f"HORIZON: {horizon}")
46 |
47 | test_idx = []
48 | for i in range(total_data):
49 | test_idx.append(i)
50 |
51 | # Compare Input data, Prediction, GT
52 | save_path = os.path.join(opt.save_path, 'sampler')
53 | for i in range(len(test_idx)):
54 | Path(save_path).mkdir(parents=True, exist_ok=True)
55 |
56 | start_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx]
57 | target_pose = lafan_dataset.data['global_pos'][test_idx[i], target_idx]
58 | gt_stopover_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx]
59 |
60 | gt_img_path = os.path.join(save_path)
61 | plot_pose_with_stop(start_pose, target_pose, target_pose, gt_stopover_pose, i, skeleton_mocap, save_dir=gt_img_path, prefix='gt')
62 | print(f"ID {test_idx[i]}: completed.")
63 |
64 | def parse_opt():
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument('--project', default='runs/train', help='project/name')
67 | parser.add_argument('--exp_name', default='slerp_40', help='experiment name')
68 | parser.add_argument('--data_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH', help='BVH dataset path')
69 | parser.add_argument('--skeleton_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH/walk1_subject1.bvh', help='path to reference skeleton')
70 | parser.add_argument('--processed_data_dir', type=str, default='processed_data_original/', help='path to save pickled processed data')
71 | parser.add_argument('--save_path', type=str, default='runs/test', help='path to save model')
72 | parser.add_argument('--motion_type', type=str, default='jumps', help='motion type')
73 | opt = parser.parse_args()
74 | return opt
75 |
76 | if __name__ == "__main__":
77 | opt = parse_opt()
78 | device = torch.device("cpu")
79 | test(opt, device)
80 |
--------------------------------------------------------------------------------
/cmib/model/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import TransformerEncoder, TransformerEncoderLayer
4 | from cmib.model.positional_encoding import PositionalEmbedding
5 |
6 |
7 | class TransformerModel(nn.Module):
8 | def __init__(
9 | self,
10 | seq_len: int,
11 | d_model: int,
12 | nhead: int,
13 | d_hid: int,
14 | nlayers: int,
15 | dropout: float = 0.5,
16 | out_dim=91,
17 | num_labels=15
18 | ):
19 | super().__init__()
20 | self.model_type = "Transformer"
21 | self.seq_len = seq_len
22 | self.d_model = d_model
23 | self.nhead = nhead
24 | self.d_hid = d_hid
25 | self.nlayers = nlayers
26 | self.cond_emb = nn.Embedding(num_labels, d_model)
27 | self.pos_embedding = PositionalEmbedding(seq_len=seq_len, d_model=d_model)
28 | encoder_layers = TransformerEncoderLayer(
29 | d_model, nhead, d_hid, dropout, activation="gelu"
30 | )
31 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
32 | self.decoder = nn.Linear(d_model, out_dim)
33 |
34 | self.init_weights()
35 |
36 | def init_weights(self) -> None:
37 | initrange = 0.1
38 | self.decoder.bias.data.zero_()
39 | self.decoder.weight.data.uniform_(-initrange, initrange)
40 |
41 | def forward(self, src: Tensor, src_mask: Tensor, cond_code: Tensor) -> Tensor:
42 | """
43 | Args:
44 | src: Tensor, shape [seq_len, batch_size, embedding_dim]
45 | src_mask: Tensor, shape [seq_len, seq_len]
46 | cond_code: Tensor of condition label ids, shape [batch_size, 1]
47 | Returns:
48 | (output, cond_embedding): output has one extra leading token (the condition) and feature dim out_dim
49 | """
50 | cond_embedding = self.cond_emb(cond_code).permute(1, 0, 2)
51 | output = self.pos_embedding(src)
52 | output = torch.cat([cond_embedding, output], dim=0)
53 | output = self.transformer_encoder(output, src_mask)
54 | output = self.decoder(output)
55 | return output, cond_embedding
56 |
--------------------------------------------------------------------------------
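`TransformerModel` above prepends one learned conditioning token to the pose sequence, so its output is one token longer than its input. A shape-only sketch, assuming 22 joints and 32 in-between frames; the head/layer sizes (`nhead=7`, `d_hid=1024`, `nlayers=4`) are illustrative and not taken from a released checkpoint:

```python
import torch

from cmib.model.network import TransformerModel

repr_dim = 22 * (3 + 4)   # vectorized global positions + quaternions = 154
horizon = 33              # 32 pose frames + 1 conditioning token
model = TransformerModel(seq_len=horizon, d_model=repr_dim, nhead=7,
                         d_hid=1024, nlayers=4, dropout=0.0,
                         out_dim=repr_dim, num_labels=15)

src = torch.randn(horizon - 1, 4, repr_dim)                 # [frames, batch, features]
src_mask = torch.zeros(horizon, horizon, dtype=torch.bool)  # no attention masking
cond_code = torch.zeros(4, 1, dtype=torch.int64)            # one label id per sample
output, cond_emb = model(src, src_mask, cond_code)
print(output.shape, cond_emb.shape)  # [33, 4, 154] and [1, 4, 154]
```

The test scripts slice `output[1:]` to drop the conditioning token before reshaping the remaining frames back into positions and rotations.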
/cmib/model/positional_encoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | import math
4 |
5 |
6 | class PositionalEmbedding(nn.Module):
7 | def __init__(self, seq_len: int = 32, d_model: int = 96):
8 | super().__init__()
9 | self.pos_emb = nn.Embedding(seq_len + 1, d_model)
10 |
11 | def forward(self, inputs):
12 | positions = (
13 | torch.arange(inputs.size(0), device=inputs.device)
14 | .expand(inputs.size(1), inputs.size(0))
15 | .contiguous()
16 | + 1
17 | )
18 | outputs = inputs + self.pos_emb(positions).permute(1, 0, 2)
19 | return outputs
20 |
21 |
22 | class PositionalEncoding(nn.Module):
23 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
24 | super().__init__()
25 | self.dropout = nn.Dropout(p=dropout)
26 |
27 | position = torch.arange(max_len).unsqueeze(1)
28 | div_term = torch.exp(
29 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
30 | )
31 | pe = torch.zeros(max_len, 1, d_model)
32 | pe[:, 0, 0::2] = torch.sin(position * div_term)
33 | pe[:, 0, 1::2] = torch.cos(position * div_term)
34 | self.register_buffer("pe", pe)
35 |
36 | def forward(self, x: Tensor) -> Tensor:
37 | """
38 | Args:
39 | x: Tensor, shape [seq_len, batch_size, embedding_dim]
40 | """
41 | x = x + self.pe[: x.size(0)]
42 | return self.dropout(x)
43 |
--------------------------------------------------------------------------------
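Both positional modules above keep the `[seq_len, batch, d_model]` layout and only add position information, the first through a learned per-frame embedding and the second through fixed sinusoids. A quick sketch with arbitrary sizes:

```python
import torch

from cmib.model.positional_encoding import PositionalEmbedding, PositionalEncoding

x = torch.randn(32, 4, 96)                    # [seq_len, batch, d_model]
learned = PositionalEmbedding(seq_len=32, d_model=96)
sinusoidal = PositionalEncoding(d_model=96, dropout=0.0)
print(learned(x).shape, sinusoidal(x).shape)  # both torch.Size([32, 4, 96])
```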
/cmib/model/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def replace_constant(minibatch_pose_input, mask_start_frame):
5 |
6 | seq_len = minibatch_pose_input.size(1)
7 | interpolated = (
8 | torch.ones_like(minibatch_pose_input, device=minibatch_pose_input.device) * 0.1
9 | )
10 |
11 | if mask_start_frame == 0 or mask_start_frame == (seq_len - 1):
12 | interpolate_start = minibatch_pose_input[:, 0, :]
13 | interpolate_end = minibatch_pose_input[:, seq_len - 1, :]
14 |
15 | interpolated[:, 0, :] = interpolate_start
16 | interpolated[:, seq_len - 1, :] = interpolate_end
17 |
18 | assert torch.allclose(interpolated[:, 0, :], interpolate_start)
19 | assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end)
20 |
21 | else:
22 | interpolate_start1 = minibatch_pose_input[:, 0, :]
23 | interpolate_end1 = minibatch_pose_input[:, mask_start_frame, :]
24 |
25 | interpolate_start2 = minibatch_pose_input[:, mask_start_frame, :]
26 | interpolate_end2 = minibatch_pose_input[:, seq_len - 1, :]
27 |
28 | interpolated[:, 0, :] = interpolate_start1
29 | interpolated[:, mask_start_frame, :] = interpolate_end1
30 |
31 | interpolated[:, mask_start_frame, :] = interpolate_start2
32 | interpolated[:, seq_len - 1, :] = interpolate_end2
33 |
34 | assert torch.allclose(interpolated[:, 0, :], interpolate_start1)
35 | assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_end1)
36 |
37 | assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_start2)
38 | assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end2)
39 | return interpolated
40 |
41 |
42 | def slerp(x, y, a):
43 | """
44 | Performs spherical linear interpolation (SLERP) between x and y, with proportion a
45 |
46 | :param x: quaternion tensor
47 | :param y: quaternion tensor
48 | :param a: indicator (between 0 and 1) of completion of the interpolation.
49 | :return: tensor of interpolation results
50 | """
51 | device = x.device
52 | len = torch.sum(x * y, dim=-1)
53 |
54 | neg = len < 0.0
55 | len[neg] = -len[neg]
56 | y[neg] = -y[neg]
57 |
58 | a = torch.zeros_like(x[..., 0]) + a
59 | amount0 = torch.zeros(a.shape, device=device)
60 | amount1 = torch.zeros(a.shape, device=device)
61 |
62 | linear = (1.0 - len) < 0.01
63 | omegas = torch.arccos(len[~linear])
64 | sinoms = torch.sin(omegas)
65 |
66 | amount0[linear] = 1.0 - a[linear]
67 | amount0[~linear] = torch.sin((1.0 - a[~linear]) * omegas) / sinoms
68 |
69 | amount1[linear] = a[linear]
70 | amount1[~linear] = torch.sin(a[~linear] * omegas) / sinoms
71 | # res = amount0[..., np.newaxis] * x + amount1[..., np.newaxis] * y
72 | res = amount0.unsqueeze(3) * x + amount1.unsqueeze(3) * y
73 |
74 | return res
75 |
76 |
77 | def slerp_input_repr(minibatch_pose_input, mask_start_frame):
78 | seq_len = minibatch_pose_input.size(1)
79 | minibatch_pose_input = minibatch_pose_input.reshape(
80 | minibatch_pose_input.size(0), seq_len, -1, 4
81 | )
82 | interpolated = torch.zeros_like(
83 | minibatch_pose_input, device=minibatch_pose_input.device
84 | )
85 |
86 | if mask_start_frame == 0 or mask_start_frame == (seq_len - 1):
87 | interpolate_start = minibatch_pose_input[:, 0:1]
88 | interpolate_end = minibatch_pose_input[:, seq_len - 1 :]
89 |
90 | for i in range(seq_len):
91 | dt = 1 / (seq_len - 1)
92 | interpolated[:, i : i + 1, :] = slerp(
93 | interpolate_start, interpolate_end, dt * i
94 | )
95 |
96 | assert torch.allclose(interpolated[:, 0:1], interpolate_start)
97 | assert torch.allclose(interpolated[:, seq_len - 1 :], interpolate_end)
98 | else:
99 | interpolate_start1 = minibatch_pose_input[:, 0:1]
100 | interpolate_end1 = minibatch_pose_input[
101 | :, mask_start_frame : mask_start_frame + 1
102 | ]
103 |
104 | interpolate_start2 = minibatch_pose_input[
105 | :, mask_start_frame : mask_start_frame + 1
106 | ]
107 | interpolate_end2 = minibatch_pose_input[:, seq_len - 1 :]
108 |
109 | for i in range(mask_start_frame + 1):
110 | dt = 1 / mask_start_frame
111 | interpolated[:, i : i + 1, :] = slerp(
112 | interpolate_start1, interpolate_end1, dt * i
113 | )
114 |
115 | assert torch.allclose(interpolated[:, 0:1], interpolate_start1)
116 | assert torch.allclose(
117 | interpolated[:, mask_start_frame : mask_start_frame + 1], interpolate_end1
118 | )
119 |
120 | for i in range(mask_start_frame, seq_len):
121 | dt = 1 / (seq_len - mask_start_frame - 1)
122 | interpolated[:, i : i + 1, :] = slerp(
123 | interpolate_start2, interpolate_end2, dt * (i - mask_start_frame)
124 | )
125 |
126 | assert torch.allclose(
127 | interpolated[:, mask_start_frame : mask_start_frame + 1], interpolate_start2
128 | )
129 | assert torch.allclose(interpolated[:, seq_len - 1 :], interpolate_end2)
130 |
131 | interpolated = torch.nn.functional.normalize(interpolated, p=2.0, dim=3)
132 | return interpolated.reshape(minibatch_pose_input.size(0), seq_len, -1)
133 |
134 |
135 | def lerp_input_repr(minibatch_pose_input, mask_start_frame):
136 | seq_len = minibatch_pose_input.size(1)
137 | interpolated = torch.zeros_like(
138 | minibatch_pose_input, device=minibatch_pose_input.device
139 | )
140 |
141 | if mask_start_frame == 0 or mask_start_frame == (seq_len - 1):
142 | interpolate_start = minibatch_pose_input[:, 0, :]
143 | interpolate_end = minibatch_pose_input[:, seq_len - 1, :]
144 |
145 | for i in range(seq_len):
146 | dt = 1 / (seq_len - 1)
147 | interpolated[:, i, :] = torch.lerp(
148 | interpolate_start, interpolate_end, dt * i
149 | )
150 |
151 | assert torch.allclose(interpolated[:, 0, :], interpolate_start)
152 | assert torch.allclose(interpolated[:, seq_len - 1, :], interpolate_end)
153 | else:
154 | interpolate_start1 = minibatch_pose_input[:, 0, :]
155 | interpolate_end1 = minibatch_pose_input[:, mask_start_frame, :]
156 |
157 | interpolate_start2 = minibatch_pose_input[:, mask_start_frame, :]
158 | interpolate_end2 = minibatch_pose_input[:, -1, :]
159 |
160 | for i in range(mask_start_frame + 1):
161 | dt = 1 / mask_start_frame
162 | interpolated[:, i, :] = torch.lerp(
163 | interpolate_start1, interpolate_end1, dt * i
164 | )
165 |
166 | assert torch.allclose(interpolated[:, 0, :], interpolate_start1)
167 | assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_end1)
168 |
169 | for i in range(mask_start_frame, seq_len):
170 | dt = 1 / (seq_len - mask_start_frame - 1)
171 | interpolated[:, i, :] = torch.lerp(
172 | interpolate_start2, interpolate_end2, dt * (i - mask_start_frame)
173 | )
174 |
175 | assert torch.allclose(interpolated[:, mask_start_frame, :], interpolate_start2)
176 | assert torch.allclose(interpolated[:, -1, :], interpolate_end2)
177 | return interpolated
178 |
179 |
180 | def vectorize_representation(global_position, global_rotation):
181 |
182 | batch_size = global_position.shape[0]
183 | seq_len = global_position.shape[1]
184 |
185 | global_pos_vec = global_position.reshape(batch_size, seq_len, -1).contiguous()
186 | global_rot_vec = global_rotation.reshape(batch_size, seq_len, -1).contiguous()
187 |
188 | global_pose_vec_gt = torch.cat([global_pos_vec, global_rot_vec], dim=2)
189 | return global_pose_vec_gt
190 |
--------------------------------------------------------------------------------
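The helpers above are combined by the test scripts into the interpolated model input: the position part of the pose vector is LERPed, the rotation part is SLERPed, and the two are concatenated back together. A sketch mirroring that 'slerp' branch, with illustrative tensor sizes:

```python
import torch

from cmib.model.preprocess import (lerp_input_repr, slerp_input_repr,
                                   vectorize_representation)

batch, frames, joints = 2, 32, 22
global_pos = torch.randn(batch, frames, joints, 3)
global_rot = torch.nn.functional.normalize(torch.randn(batch, frames, joints, 4), dim=-1)

pose_vec = vectorize_representation(global_pos, global_rot)  # (2, 32, 154)
pos_dim = joints * 3
pos_lerped = lerp_input_repr(pose_vec[:, :, :pos_dim], 0)    # linear in position
rot_slerped = slerp_input_repr(pose_vec[:, :, pos_dim:], 0)  # spherical in rotation
model_input = torch.cat([pos_lerped, rot_slerped], dim=2)
print(model_input.shape)                                     # torch.Size([2, 32, 154])
```

With `mask_start_frame=0` the whole window is interpolated between the first and last frames; a non-zero index splits the interpolation at that frame instead.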
/cmib/model/skeleton.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from cmib.data.quaternion import qmul, qrot
4 | import torch.nn as nn
5 |
6 | amass_offsets = [
7 | [0.0, 0.0, 0.0],
8 |
9 | [0.058581, -0.082280, -0.017664],
10 | [0.043451, -0.386469, 0.008037],
11 | [-0.014790, -0.426874, -0.037428],
12 | [0.041054, -0.060286, 0.122042],
13 | [0.0, 0.0, 0.0],
14 |
15 | [-0.060310, -0.090513, -0.013543],
16 | [-0.043257, -0.383688, -0.004843],
17 | [0.019056, -0.420046, -0.034562],
18 | [-0.034840, -0.062106, 0.130323],
19 | [0.0, 0.0, 0.0],
20 |
21 | [0.004439, 0.124404, -0.038385],
22 | [0.004488, 0.137956, 0.026820],
23 | [-0.002265, 0.056032, 0.002855],
24 | [-0.013390, 0.211636, -0.033468],
25 | [0.010113, 0.088937, 0.050410],
26 | [0.0, 0.0, 0.0],
27 |
28 | [0.071702, 0.114000, -0.018898],
29 | [0.122921, 0.045205, -0.019046],
30 | [0.255332, -0.015649, -0.022946],
31 | [0.265709, 0.012698, -0.007375],
32 | [0.0, 0.0, 0.0],
33 |
34 | [-0.082954, 0.112472, -0.023707],
35 | [-0.113228, 0.046853, -0.008472],
36 | [-0.260127, -0.014369, -0.031269],
37 | [-0.269108, 0.006794, -0.006027],
38 | [0.0, 0.0, 0.0]
39 | ]
40 |
41 | sk_offsets = [
42 | [-42.198200, 91.614723, -40.067841],
43 |
44 | [0.103456, 1.857829, 10.548506],
45 | [43.499992, -0.000038, -0.000002],
46 | [42.372192, 0.000015, -0.000007],
47 | [17.299999, -0.000002, 0.000003],
48 | [0.000000, 0.000000, 0.000000],
49 |
50 | [0.103457, 1.857829, -10.548503],
51 | [43.500042, -0.000027, 0.000008],
52 | [42.372257, -0.000008, 0.000014],
53 | [17.299992, -0.000005, 0.000004],
54 | [0.000000, 0.000000, 0.000000],
55 |
56 | [6.901968, -2.603733, -0.000001],
57 | [12.588099, 0.000002, 0.000000],
58 | [12.343206, 0.000000, -0.000001],
59 | [25.832886, -0.000004, 0.000003],
60 | [11.766620, 0.000005, -0.000001],
61 | [0.000000, 0.000000, 0.000000],
62 |
63 | [19.745899, -1.480370, 6.000108],
64 | [11.284125, -0.000009, -0.000018],
65 | [33.000050, 0.000004, 0.000032],
66 | [25.200008, 0.000015, 0.000008],
67 | [0.000000, 0.000000, 0.000000],
68 |
69 | [19.746099, -1.480375, -6.000073],
70 | [11.284138, -0.000015, -0.000012],
71 | [33.000092, 0.000017, 0.000013],
72 | [25.199780, 0.000135, 0.000422],
73 | [0.000000, 0.000000, 0.000000],
74 | ]
75 |
76 | sk_parents = [
77 | -1,
78 | 0,
79 | 1,
80 | 2,
81 | 3,
82 | 4,
83 | 0,
84 | 6,
85 | 7,
86 | 8,
87 | 9,
88 | 0,
89 | 11,
90 | 12,
91 | 13,
92 | 14,
93 | 15,
94 | 13,
95 | 17,
96 | 18,
97 | 19,
98 | 20,
99 | 13,
100 | 22,
101 | 23,
102 | 24,
103 | 25,
104 | ]
105 |
106 | sk_joints_to_remove = [5, 10, 16, 21, 26]
107 |
108 | joint_names = [
109 | "Hips",
110 | "LeftUpLeg",
111 | "LeftLeg",
112 | "LeftFoot",
113 | "LeftToe",
114 | "RightUpLeg",
115 | "RightLeg",
116 | "RightFoot",
117 | "RightToe",
118 | "Spine",
119 | "Spine1",
120 | "Spine2",
121 | "Neck",
122 | "Head",
123 | "LeftShoulder",
124 | "LeftArm",
125 | "LeftForeArm",
126 | "LeftHand",
127 | "RightShoulder",
128 | "RightArm",
129 | "RightForeArm",
130 | "RightHand",
131 | ]
132 |
133 |
134 | class Skeleton:
135 | def __init__(
136 | self,
137 | offsets,
138 | parents,
139 | joints_left=None,
140 | joints_right=None,
141 | bone_length=None,
142 | device=None,
143 | ):
144 | assert len(offsets) == len(parents)
145 |
146 | self._offsets = torch.Tensor(offsets).to(device)
147 | self._parents = np.array(parents)
148 | self._joints_left = joints_left
149 | self._joints_right = joints_right
150 | self._compute_metadata()
151 |
152 | def num_joints(self):
153 | return self._offsets.shape[0]
154 |
155 | def offsets(self):
156 | return self._offsets
157 |
158 | def parents(self):
159 | return self._parents
160 |
161 | def has_children(self):
162 | return self._has_children
163 |
164 | def children(self):
165 | return self._children
166 |
167 | def convert_to_global_pos(self, unit_vec_rerp):
168 | """
169 | Convert the unit offset matrix to global position.
170 | The first row (root) holds the absolute position in global coordinates.
171 | """
172 | bone_length = self.get_bone_length_weight()
173 | batch_size = unit_vec_rerp.size(0)
174 | seq_len = unit_vec_rerp.size(1)
175 | unit_vec_table = unit_vec_rerp.reshape(batch_size, seq_len, 22, 3)
176 | global_position = torch.zeros_like(unit_vec_table, device=unit_vec_table.device)
177 |
178 | for i, parent in enumerate(self._parents):
179 | if parent == -1: # if root
180 | global_position[:, :, i] = unit_vec_table[:, :, i]
181 |
182 | else:
183 | global_position[:, :, i] = global_position[:, :, parent] + (
184 | nn.functional.normalize(unit_vec_table[:, :, i], p=2.0, dim=-1)
185 | * bone_length[i]
186 | )
187 |
188 | return global_position
189 |
190 | def convert_to_unit_offset_mat(self, global_position):
191 | """
192 | Convert the global position of the skeleton to a unit offset matrix.
193 | The first row (root) holds the absolute position in global coordinates.
194 | """
195 |
196 | bone_length = self.get_bone_length_weight()
197 | unit_offset_mat = torch.zeros_like(
198 | global_position, device=global_position.device
199 | )
200 |
201 | for i, parent in enumerate(self._parents):
202 |
203 | if parent == -1: # if root
204 | unit_offset_mat[:, :, i] = global_position[:, :, i]
205 | else:
206 | unit_offset_mat[:, :, i] = (
207 | global_position[:, :, i] - global_position[:, :, parent]
208 | ) / bone_length[i]
209 |
210 | return unit_offset_mat
211 |
212 | def remove_joints(self, joints_to_remove):
213 | """
214 | Remove the joints specified in 'joints_to_remove' from the skeleton
215 | definition. Offsets are dropped and parent indices are shifted in place
216 | to account for the removed entries.
217 | """
218 | valid_joints = []
219 | for joint in range(len(self._parents)):
220 | if joint not in joints_to_remove:
221 | valid_joints.append(joint)
222 |
223 | index_offsets = np.zeros(len(self._parents), dtype=int)
224 | new_parents = []
225 | for i, parent in enumerate(self._parents):
226 | if i not in joints_to_remove:
227 | new_parents.append(parent - index_offsets[parent])
228 | else:
229 | index_offsets[i:] += 1
230 | self._parents = np.array(new_parents)
231 |
232 | self._offsets = self._offsets[valid_joints]
233 | self._compute_metadata()
234 |
235 | def forward_kinematics(self, rotations, root_positions):
236 | """
237 | Perform forward kinematics using the given trajectory and local rotations.
238 | Arguments (where N = batch size, L = sequence length, J = number of joints):
239 | -- rotations: (N, L, J, 4) tensor of unit quaternions describing the local rotations of each joint.
240 | -- root_positions: (N, L, 3) tensor describing the root joint positions.
241 | """
242 | assert len(rotations.shape) == 4
243 | assert rotations.shape[-1] == 4
244 |
245 | positions_world = []
246 | rotations_world = []
247 |
248 | expanded_offsets = self._offsets.expand(
249 | rotations.shape[0],
250 | rotations.shape[1],
251 | self._offsets.shape[0],
252 | self._offsets.shape[1],
253 | )
254 |
255 | # Parallelize along the batch and time dimensions
256 | for i in range(self._offsets.shape[0]):
257 | if self._parents[i] == -1:
258 | positions_world.append(root_positions)
259 | rotations_world.append(rotations[:, :, 0])
260 | else:
261 | positions_world.append(
262 | qrot(rotations_world[self._parents[i]], expanded_offsets[:, :, i])
263 | + positions_world[self._parents[i]]
264 | )
265 | if self._has_children[i]:
266 | rotations_world.append(
267 | qmul(rotations_world[self._parents[i]], rotations[:, :, i])
268 | )
269 | else:
270 | # This joint is a terminal node -> it would be useless to compute the transformation
271 | rotations_world.append(None)
272 |
273 | return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2)
274 |
275 | def forward_kinematics_with_rotation(self, rotations, root_positions):
276 | """
277 | Perform forward kinematics using the given trajectory and local rotations.
278 | Arguments (where N = batch size, L = sequence length, J = number of joints):
279 | -- rotations: (N, L, J, 4) tensor of unit quaternions describing the local rotations of each joint.
280 | -- root_positions: (N, L, 3) tensor describing the root joint positions.
281 | """
282 | assert len(rotations.shape) == 4
283 | assert rotations.shape[-1] == 4
284 |
285 | positions_world = []
286 | rotations_world = []
287 |
288 | expanded_offsets = self._offsets.expand(
289 | rotations.shape[0],
290 | rotations.shape[1],
291 | self._offsets.shape[0],
292 | self._offsets.shape[1],
293 | )
294 |
295 | # Parallelize along the batch and time dimensions
296 | for i in range(self._offsets.shape[0]):
297 | if self._parents[i] == -1:
298 | positions_world.append(root_positions)
299 | rotations_world.append(rotations[:, :, 0])
300 | else:
301 | positions_world.append(
302 | qrot(rotations_world[self._parents[i]], expanded_offsets[:, :, i])
303 | + positions_world[self._parents[i]]
304 | )
305 | if self._has_children[i]:
306 | rotations_world.append(
307 | qmul(rotations_world[self._parents[i]], rotations[:, :, i])
308 | )
309 | else:
310 | # This joint is a terminal node -> it would be useless to compute the transformation
311 | rotations_world.append(
312 | torch.Tensor([1, 0, 0, 0])
313 | .expand(rotations.shape[0], rotations.shape[1], 4)
314 | .to(rotations.device)
315 | )
316 |
317 | return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2), torch.stack(
318 | rotations_world, dim=3
319 | ).permute(0, 1, 3, 2)
320 |
321 | def get_bone_length_weight(self):
322 | bone_length = []
323 | for i, parent in enumerate(self._parents):
324 | if parent == -1:
325 | bone_length.append(1)
326 | else:
327 | bone_length.append(
328 | torch.linalg.norm(self._offsets[i : i + 1], ord="fro").item()
329 | )
330 | return torch.Tensor(bone_length)
331 |
332 | def joints_left(self):
333 | return self._joints_left
334 |
335 | def joints_right(self):
336 | return self._joints_right
337 |
338 | def _compute_metadata(self):
339 | self._has_children = np.zeros(len(self._parents)).astype(bool)
340 | for i, parent in enumerate(self._parents):
341 | if parent != -1:
342 | self._has_children[parent] = True
343 |
344 | self._children = []
345 | for i, parent in enumerate(self._parents):
346 | self._children.append([])
347 | for i, parent in enumerate(self._parents):
348 | if parent != -1:
349 | self._children[parent].append(i)
350 |
--------------------------------------------------------------------------------
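A small sketch of the `Skeleton` round trip used by the test scripts: build the LAFAN skeleton from `sk_offsets`/`sk_parents`, drop the end sites, run forward kinematics, and check that the unit-offset conversion and its inverse agree (identity rotations here, purely as a sanity check):

```python
import torch

from cmib.model.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets,
                                 sk_parents)

skeleton = Skeleton(offsets=sk_offsets, parents=sk_parents, device="cpu")
skeleton.remove_joints(sk_joints_to_remove)
print(skeleton.num_joints())            # 22 after removing the five end sites

batch, frames = 2, 8
rotations = torch.zeros(batch, frames, skeleton.num_joints(), 4)
rotations[..., 0] = 1.0                 # identity quaternions (w, x, y, z)
root_positions = torch.zeros(batch, frames, 3)

pos, rot = skeleton.forward_kinematics_with_rotation(rotations, root_positions)
unit_mat = skeleton.convert_to_unit_offset_mat(pos)
recovered = skeleton.convert_to_global_pos(unit_mat)
print(pos.shape, torch.allclose(pos, recovered, atol=1e-4))  # (2, 8, 22, 3) True
```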
/cmib/vis/pose.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 |
7 |
8 | def project_root_position(position_arr: np.array, file_name: str):
9 | """
10 | Take a batch of root position arrays and project them onto the 2D plane
11 |
12 | N: samples
13 | L: trajectory length
14 | J: joints
15 |
16 | position_arr: [N,L,J,3]
17 | """
18 |
19 | root_joints = position_arr[:, :, 0]
20 |
21 | x_pos = root_joints[:, :, 0]
22 | y_pos = root_joints[:, :, 2]
23 |
24 | fig = plt.figure()
25 |
26 | for i in range(x_pos.shape[1]):
27 |
28 | if i == 0:
29 | plt.scatter(x_pos[:, i], y_pos[:, i], c="b")
30 | elif i == x_pos.shape[1] - 1:
31 | plt.scatter(x_pos[:, i], y_pos[:, i], c="r")
32 | else:
33 | plt.scatter(x_pos[:, i], y_pos[:, i], c="k", marker="*", s=1)
34 |
35 | plt.title(f"Root Position: {file_name}")
36 | plt.xlabel("X Axis")
37 | plt.ylabel("Y Axis")
38 | plt.xlim((-300, 300))
39 | plt.ylim((-300, 300))
40 | plt.grid()
41 | plt.savefig(f"{file_name}.png", dpi=200)
42 |
43 |
44 | def plot_single_pose(
45 | pose,
46 | frame_idx,
47 | skeleton,
48 | save_dir,
49 | prefix,
50 | ):
51 | fig = plt.figure()
52 | ax = fig.add_subplot(111, projection="3d")
53 |
54 | parent_idx = skeleton.parents()
55 |
56 | for i, p in enumerate(parent_idx):
57 | if i > 0:
58 | ax.plot(
59 | [pose[i, 0], pose[p, 0]],
60 | [pose[i, 2], pose[p, 2]],
61 | [pose[i, 1], pose[p, 1]],
62 | c="k",
63 | )
64 |
65 | x_min = pose[:, 0].min()
66 | x_max = pose[:, 0].max()
67 |
68 | y_min = pose[:, 1].min()
69 | y_max = pose[:, 1].max()
70 |
71 | z_min = pose[:, 2].min()
72 | z_max = pose[:, 2].max()
73 |
74 | ax.set_xlim(x_min, x_max)
75 | ax.set_xlabel("$X$ Axis")
76 |
77 | ax.set_ylim(z_min, z_max)
78 | ax.set_ylabel("$Y$ Axis")
79 |
80 | ax.set_zlim(y_min, y_max)
81 | ax.set_zlabel("$Z$ Axis")
82 |
83 | plt.draw()
84 |
85 | title = f"{prefix}: {frame_idx}"
86 | plt.title(title)
87 | prefix = prefix
88 | pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
89 | plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60)
90 | plt.close()
91 |
92 |
93 | def plot_pose(
94 | start_pose,
95 | inbetween_pose,
96 | target_pose,
97 | frame_idx,
98 | skeleton,
99 | save_dir,
100 | prefix,
101 | ):
102 | fig = plt.figure()
103 | ax = fig.add_subplot(111, projection="3d")
104 |
105 | parent_idx = skeleton.parents()
106 |
107 | for i, p in enumerate(parent_idx):
108 | if i > 0:
109 | ax.plot(
110 | [start_pose[i, 0], start_pose[p, 0]],
111 | [start_pose[i, 2], start_pose[p, 2]],
112 | [start_pose[i, 1], start_pose[p, 1]],
113 | c="b",
114 | )
115 | ax.plot(
116 | [inbetween_pose[i, 0], inbetween_pose[p, 0]],
117 | [inbetween_pose[i, 2], inbetween_pose[p, 2]],
118 | [inbetween_pose[i, 1], inbetween_pose[p, 1]],
119 | c="k",
120 | )
121 | ax.plot(
122 | [target_pose[i, 0], target_pose[p, 0]],
123 | [target_pose[i, 2], target_pose[p, 2]],
124 | [target_pose[i, 1], target_pose[p, 1]],
125 | c="r",
126 | )
127 |
128 | x_min = np.min(
129 | [start_pose[:, 0].min(), inbetween_pose[:, 0].min(), target_pose[:, 0].min()]
130 | )
131 | x_max = np.max(
132 | [start_pose[:, 0].max(), inbetween_pose[:, 0].max(), target_pose[:, 0].max()]
133 | )
134 |
135 | y_min = np.min(
136 | [start_pose[:, 1].min(), inbetween_pose[:, 1].min(), target_pose[:, 1].min()]
137 | )
138 | y_max = np.max(
139 | [start_pose[:, 1].max(), inbetween_pose[:, 1].max(), target_pose[:, 1].max()]
140 | )
141 |
142 | z_min = np.min(
143 | [start_pose[:, 2].min(), inbetween_pose[:, 2].min(), target_pose[:, 2].min()]
144 | )
145 | z_max = np.max(
146 | [start_pose[:, 2].max(), inbetween_pose[:, 2].max(), target_pose[:, 2].max()]
147 | )
148 |
149 | ax.set_xlim(x_min, x_max)
150 | ax.set_xlabel("$X$ Axis")
151 |
152 | ax.set_ylim(z_min, z_max)
153 | ax.set_ylabel("$Y$ Axis")
154 |
155 | ax.set_zlim(y_min, y_max)
156 | ax.set_zlabel("$Z$ Axis")
157 |
158 | plt.draw()
159 |
160 | title = f"{prefix}: {frame_idx}"
161 | plt.title(title)
162 | prefix = prefix
163 | pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
164 | plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60)
165 | plt.close()
166 |
167 |
168 | def plot_pose_with_stop(
169 | start_pose,
170 | inbetween_pose,
171 | target_pose,
172 | stopover,
173 | frame_idx,
174 | skeleton,
175 | save_dir,
176 | prefix,
177 | ):
178 | fig = plt.figure()
179 | ax = fig.add_subplot(111, projection="3d")
180 |
181 | parent_idx = skeleton.parents()
182 |
183 | for i, p in enumerate(parent_idx):
184 | if i > 0:
185 | ax.plot(
186 | [start_pose[i, 0], start_pose[p, 0]],
187 | [start_pose[i, 2], start_pose[p, 2]],
188 | [start_pose[i, 1], start_pose[p, 1]],
189 | c="b",
190 | )
191 | ax.plot(
192 | [inbetween_pose[i, 0], inbetween_pose[p, 0]],
193 | [inbetween_pose[i, 2], inbetween_pose[p, 2]],
194 | [inbetween_pose[i, 1], inbetween_pose[p, 1]],
195 | c="k",
196 | )
197 | ax.plot(
198 | [target_pose[i, 0], target_pose[p, 0]],
199 | [target_pose[i, 2], target_pose[p, 2]],
200 | [target_pose[i, 1], target_pose[p, 1]],
201 | c="r",
202 | )
203 |
204 | ax.plot(
205 | [stopover[i, 0], stopover[p, 0]],
206 | [stopover[i, 2], stopover[p, 2]],
207 | [stopover[i, 1], stopover[p, 1]],
208 | c="indigo",
209 | )
210 |
211 | x_min = np.min(
212 | [start_pose[:, 0].min(), inbetween_pose[:, 0].min(), target_pose[:, 0].min()]
213 | )
214 | x_max = np.max(
215 | [start_pose[:, 0].max(), inbetween_pose[:, 0].max(), target_pose[:, 0].max()]
216 | )
217 |
218 | y_min = np.min(
219 | [start_pose[:, 1].min(), inbetween_pose[:, 1].min(), target_pose[:, 1].min()]
220 | )
221 | y_max = np.max(
222 | [start_pose[:, 1].max(), inbetween_pose[:, 1].max(), target_pose[:, 1].max()]
223 | )
224 |
225 | z_min = np.min(
226 | [start_pose[:, 2].min(), inbetween_pose[:, 2].min(), target_pose[:, 2].min()]
227 | )
228 | z_max = np.max(
229 | [start_pose[:, 2].max(), inbetween_pose[:, 2].max(), target_pose[:, 2].max()]
230 | )
231 |
232 | ax.set_xlim(x_min, x_max)
233 | ax.set_xlabel("$X$ Axis")
234 |
235 | ax.set_ylim(z_min, z_max)
236 | ax.set_ylabel("$Y$ Axis")
237 |
238 | ax.set_zlim(y_min, y_max)
239 | ax.set_zlabel("$Z$ Axis")
240 |
241 | plt.draw()
242 |
243 | title = f"{prefix}: {frame_idx}"
244 | plt.title(title)
245 | prefix = prefix
246 | pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
247 | plt.savefig(os.path.join(save_dir, prefix + str(frame_idx) + ".png"), dpi=60)
248 | plt.close()
249 |
--------------------------------------------------------------------------------
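The plotting helpers above take a `(J, 3)` array of global joint positions plus a `Skeleton` for the parent indices. A tiny sketch that writes one frame to an illustrative `pose_debug/` directory (the pose is random, so the figure is only a pipeline check):

```python
import numpy as np

from cmib.model.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets,
                                 sk_parents)
from cmib.vis.pose import plot_single_pose

skeleton = Skeleton(offsets=sk_offsets, parents=sk_parents, device="cpu")
skeleton.remove_joints(sk_joints_to_remove)

pose = np.random.randn(skeleton.num_joints(), 3) * 10.0  # fake (22, 3) global pose
plot_single_pose(pose, frame_idx=0, skeleton=skeleton, save_dir="pose_debug", prefix="demo")
```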
/requirements.txt:
--------------------------------------------------------------------------------
1 | black==24.3.0
2 | certifi==2023.7.22
3 | charset-normalizer==2.0.9
4 | click==8.0.3
5 | configparser==5.2.0
6 | cycler==0.11.0
7 | docker-pycreds==0.4.0
8 | fonttools==4.43.0
9 | gitdb==4.0.9
10 | GitPython==3.1.41
11 | idna==3.7
12 | imageio==2.13.3
13 | importlib-metadata==4.8.2
14 | joblib==1.2.0
15 | kiwisolver==1.3.2
16 | matplotlib==3.5.1
17 | mypy-extensions==0.4.3
18 | numpy==1.22.0
19 | packaging==21.3
20 | pathspec==0.9.0
21 | pathtools==0.1.2
22 | Pillow==10.3.0
23 | platformdirs==2.4.0
24 | promise==2.3
25 | protobuf==3.19.5
26 | psutil==5.8.0
27 | pyparsing==3.0.6
28 | python-dateutil==2.8.2
29 | PyYAML==6.0
30 | requests==2.31.0
31 | scikit-learn==1.0.1
32 | scipy==1.10.0
33 | sentry-sdk==1.14.0
34 | shortuuid==1.0.8
35 | six==1.16.0
36 | sklearn==0.0
37 | smmap==5.0.0
38 | subprocess32==3.5.4
39 | termcolor==1.1.0
40 | threadpoolctl==3.0.0
41 | tomli==1.2.3
42 | torch==1.13.1
43 | torchaudio==0.10.1+cu113
44 | torchvision==0.11.2+cu113
45 | tqdm==4.62.3
46 | typed-ast==1.5.1
47 | typing_extensions==4.0.1
48 | urllib3==1.26.18
49 | wandb==0.12.7
50 | yaspin==2.1.0
51 | zipp==3.6.0
52 |
--------------------------------------------------------------------------------
/run_cmib.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from pathlib import Path
4 |
5 | import imageio
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from PIL import Image
10 | from sklearn.preprocessing import LabelEncoder
11 |
12 | from cmib.data.lafan1_dataset import LAFAN1Dataset
13 | from cmib.data.utils import write_json
14 | from cmib.lafan1.utils import quat_ik
15 | from cmib.model.network import TransformerModel
16 | from cmib.model.preprocess import (lerp_input_repr, replace_constant,
17 | slerp_input_repr, vectorize_representation)
18 | from cmib.model.skeleton import (Skeleton, joint_names, sk_joints_to_remove,
19 | sk_offsets, sk_parents)
20 | from cmib.vis.pose import plot_pose_with_stop
21 |
22 |
23 | def test(opt, device):
24 |
25 | save_dir = Path(os.path.join('runs', 'train', opt.exp_name))
26 | wdir = save_dir / 'weights'
27 | weights = os.listdir(wdir)
28 |
29 | if opt.weight == 'latest':
30 | weights_paths = [wdir / weight for weight in weights]
31 | weight_path = max(weights_paths , key = os.path.getctime)
32 | else:
33 | weight_path = wdir / ('train-' + opt.weight + '.pt')
34 | ckpt = torch.load(weight_path, map_location=device)
35 | print(f"Loaded weight: {weight_path}")
36 |
37 |
38 | # Load Skeleton
39 | skeleton_mocap = Skeleton(offsets=sk_offsets, parents=sk_parents, device=device)
40 | skeleton_mocap.remove_joints(sk_joints_to_remove)
41 |
42 | # Load LAFAN Dataset
43 | Path(opt.processed_data_dir).mkdir(parents=True, exist_ok=True)
44 | test_window = ckpt['horizon'] - 1 + 10
45 | lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path, processed_data_dir=opt.processed_data_dir, train=False, device=device, window=test_window)
46 | total_data = lafan_dataset.data['global_pos'].shape[0]
47 |
48 | # In-betweening frames to be replaced with noise
49 | from_idx, target_idx = ckpt['from_idx'], ckpt['target_idx'] # default: 9-40, max: 48
50 | horizon = ckpt['horizon']
51 | print(f"HORIZON: {horizon}")
52 |
53 | test_idx = [950, 1140, 2100]
54 |
55 | # Extract dimension from processed data
56 | pos_dim = lafan_dataset.num_joints * 3
57 | rot_dim = lafan_dataset.num_joints * 4
58 | repr_dim = pos_dim + rot_dim
59 |
60 | root_pos = torch.Tensor(lafan_dataset.data['root_p'][:, from_idx:target_idx+1]).to(device)
61 | local_q = torch.Tensor(lafan_dataset.data['local_q'][:, from_idx:target_idx+1]).to(device)
62 | local_q_normalized = nn.functional.normalize(local_q, p=2.0, dim=-1)
63 |
64 | # Replace testing inputs
65 | fixed = 0
66 | global_pos, global_q = skeleton_mocap.forward_kinematics_with_rotation(local_q_normalized, root_pos)
67 |
68 | interpolation = ckpt['interpolation']
69 | print(f"Interpolation Mode: {interpolation}")
70 |
71 | if interpolation == 'constant':
72 | global_pose_vec_gt = vectorize_representation(global_pos, global_q)
73 | global_pose_vec_input = global_pose_vec_gt.clone().detach()
74 | pose_interpolated_input = replace_constant(global_pose_vec_input, fixed)
75 | input_pos = pose_interpolated_input[:,:,:pos_dim].detach().numpy()
76 |
77 | elif interpolation == 'slerp':
78 | global_pose_vec_gt = vectorize_representation(global_pos, global_q)
79 | global_pose_vec_input = global_pose_vec_gt.clone().detach()
80 | root_vec = global_pose_vec_input[:,:,:pos_dim]
81 | rot_vec = global_pose_vec_input[:,:,pos_dim:]
82 | root_lerped = lerp_input_repr(root_vec, fixed)
83 | rot_slerped = slerp_input_repr(rot_vec, fixed)
84 | pose_interpolated_input = torch.cat([root_lerped, rot_slerped], dim=2)
85 | input_pos = pose_interpolated_input[:,:,:pos_dim].detach().numpy()
86 |
87 | else:
88 | raise ValueError('Invalid interpolation method')
89 |
90 | pose_vectorized_input = pose_interpolated_input.permute(1,0,2)
91 |
92 | src_mask = torch.zeros((horizon, horizon), device=device).type(torch.bool)
93 | src_mask = src_mask.to(device)
94 |
95 | seq_categories = [x[:-1] for x in lafan_dataset.data['seq_names']]
96 |
97 | le = LabelEncoder()
98 | le.classes_ = np.load(os.path.join(save_dir, 'le_classes_.npy'))
99 |
100 | target_seq = opt.motion_type
101 | seq_id = np.where(le.classes_==target_seq)[0]
102 | conditioning_labels = np.expand_dims((np.repeat(seq_id[0], repeats=len(seq_categories))), axis=1)
103 | conditioning_labels = torch.Tensor(conditioning_labels).type(torch.int64).to(device)
104 |
105 | model = TransformerModel(seq_len=ckpt['horizon'], d_model=ckpt['d_model'], nhead=ckpt['nhead'], d_hid=ckpt['d_hid'], nlayers=ckpt['nlayers'], dropout=0.05, out_dim=repr_dim)
106 | model.load_state_dict(ckpt['transformer_encoder_state_dict'])
107 | model.eval()
108 |
109 | output, _ = model(pose_vectorized_input, src_mask, conditioning_labels)
110 |
111 | pred_global_pos = output[1:,:,:pos_dim].permute(1,0,2).reshape(total_data,horizon-1,22,3)
112 | global_pos_unit_vec = skeleton_mocap.convert_to_unit_offset_mat(pred_global_pos)
113 | pred_global_pos = skeleton_mocap.convert_to_global_pos(global_pos_unit_vec).detach().numpy()
114 |
115 | pred_global_rot = output[1:,:,pos_dim:].permute(1,0,2).reshape(total_data,horizon-1,22,4)
116 | pred_global_rot_normalized = nn.functional.normalize(pred_global_rot, p=2.0, dim=3).detach().numpy()
117 |
118 | clue = global_pos.clone().detach()
119 |
120 | # Compare Input data, Prediction, GT
121 | for i in range(len(test_idx)):
122 | save_path = os.path.join(opt.save_path, 'test_' + f'{test_idx[i]}')
123 | Path(save_path).mkdir(parents=True, exist_ok=True)
124 | pred_json_path = os.path.join(save_path, 'pred_json')
125 | Path(pred_json_path).mkdir(parents=True, exist_ok=True)
126 | gt_json_path = os.path.join(save_path, 'gt_json')
127 | Path(gt_json_path).mkdir(parents=True, exist_ok=True)
128 |
129 | start_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx]
130 | target_pose = lafan_dataset.data['global_pos'][test_idx[i], target_idx]
131 | stopover_pose = clue[test_idx[i],fixed]
132 | stopover_rot = global_q[test_idx[i],fixed]
133 | gt_stopover_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx + fixed]
134 |
135 | # Replace start/end with gt
136 | pred_global_pos[test_idx[i], 0] = start_pose
137 |
138 | gpos = pred_global_pos[test_idx[i]]
139 | grot = pred_global_rot_normalized[test_idx[i]]
140 |
141 | local_quaternion_stopover, local_positions_stopover = quat_ik(stopover_rot.detach().numpy(), stopover_pose.detach().numpy(), parents=skeleton_mocap.parents())
142 | local_quaternion, local_positions = quat_ik(grot, gpos, parents=skeleton_mocap.parents())
143 |
144 | img_aggr_list = []
145 |
146 | write_json(filename=os.path.join(pred_json_path, f'start.json'), local_q=local_quaternion[0], root_pos=local_positions[0,0], joint_names=joint_names)
147 | write_json(filename=os.path.join(pred_json_path, f'target.json'), local_q=local_quaternion[-1], root_pos=local_positions[-1,0], joint_names=joint_names)
148 | write_json(filename=os.path.join(pred_json_path, f'stopover.json'), local_q=local_quaternion_stopover, root_pos=local_positions_stopover[0], joint_names=joint_names)
149 |
150 | write_json(filename=os.path.join(gt_json_path, f'start.json'), local_q=local_quaternion[0], root_pos=local_positions[0,0], joint_names=joint_names)
151 | write_json(filename=os.path.join(gt_json_path, f'target.json'), local_q=local_quaternion[-1], root_pos=local_positions[-1,0], joint_names=joint_names)
152 | write_json(filename=os.path.join(gt_json_path, f'stopover.json'), local_q=local_quaternion_stopover, root_pos=local_positions_stopover[0], joint_names=joint_names)
153 |
154 | for t in range(horizon-1):
155 |
156 | if opt.plot_image:
157 | input_img_path = os.path.join(save_path, 'input')
158 | pred_img_path = os.path.join(save_path, 'pred_img')
159 | gt_img_path = os.path.join(save_path, 'gt_img')
160 |
161 | plot_pose_with_stop(start_pose, input_pos[test_idx[i],t].reshape(lafan_dataset.num_joints, 3), target_pose, stopover_pose, t, skeleton_mocap, save_dir=input_img_path, prefix='input')
162 | plot_pose_with_stop(start_pose, pred_global_pos[test_idx[i],t].reshape(lafan_dataset.num_joints, 3), target_pose, stopover_pose, t, skeleton_mocap, save_dir=pred_img_path, prefix='pred')
163 | plot_pose_with_stop(start_pose, lafan_dataset.data['global_pos'][test_idx[i], t+from_idx], target_pose, gt_stopover_pose, t, skeleton_mocap, save_dir=gt_img_path, prefix='gt')
164 |
165 | input_img = Image.open(os.path.join(input_img_path, 'input'+str(t)+'.png'), 'r')
166 | pred_img = Image.open(os.path.join(pred_img_path, 'pred'+str(t)+'.png'), 'r')
167 | gt_img = Image.open(os.path.join(gt_img_path, 'gt'+str(t)+'.png'), 'r')
168 |
169 | img_aggr_list.append(np.concatenate([input_img, pred_img, gt_img.resize(pred_img.size)], 1))
170 |
171 | write_json(filename=os.path.join(pred_json_path, f'{t:05}.json'), local_q=local_quaternion[t], root_pos=local_positions[t,0], joint_names=joint_names)
172 | write_json(filename=os.path.join(gt_json_path, f'{t:05}.json'), local_q=lafan_dataset.data['local_q'][test_idx[i], from_idx + t], root_pos=lafan_dataset.data['global_pos'][test_idx[i], from_idx + t, 0], joint_names=joint_names)
173 |
174 | # Save images
175 | if opt.plot_image:
176 | gif_path = os.path.join(save_path, f'img_{test_idx[i]}.gif')
177 | imageio.mimsave(gif_path, img_aggr_list, duration=0.1)
178 | print(f"ID {test_idx[i]}: test completed.")
179 |
180 | def parse_opt():
181 | parser = argparse.ArgumentParser()
182 | parser.add_argument('--project', default='runs/train', help='project/name')
183 | parser.add_argument('--weight', default='latest')
184 | parser.add_argument('--exp_name', default='exp', help='experiment name')
185 | parser.add_argument('--data_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH', help='BVH dataset path')
186 | parser.add_argument('--skeleton_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH/walk1_subject1.bvh', help='path to reference skeleton')
187 | parser.add_argument('--processed_data_dir', type=str, default='processed_data_80/', help='path to save pickled processed data')
188 | parser.add_argument('--save_path', type=str, default='runs/test', help='path to save model')
189 | parser.add_argument('--motion_type', type=str, default='jumps', help='motion type')
190 | parser.add_argument('--plot_image', type=bool, default=False, help='plot image')
191 | opt = parser.parse_args()
192 | return opt
193 |
194 | if __name__ == "__main__":
195 | opt = parse_opt()
196 | device = torch.device("cpu")
197 | test(opt, device)
198 |
--------------------------------------------------------------------------------
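`run_cmib.py` is driven entirely by the argparse flags defined above. An illustrative launcher in the same subprocess style as `run_test_multi.py` below; the experiment name, weight selection and paths are just the argparse defaults, not a reference to released weights:

```python
import subprocess

# Runs inference on CPU with the default LAFAN paths; note that argparse's
# type=bool treats any non-empty string (including "False") as True.
subprocess.run(["python", "run_cmib.py",
                "--exp_name", "exp",
                "--weight", "latest",
                "--data_path", "ubisoft-laforge-animation-dataset/output/BVH",
                "--processed_data_dir", "processed_data_80/",
                "--save_path", "runs/test",
                "--motion_type", "jumps",
                "--plot_image", "True"])
```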
/run_test_multi.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 |
4 | for epoch in range(1000, 1200, 10):
5 | print(f"Epochs: {epoch}")
6 | subprocess.run(["python", "test_benchmark.py",
7 | "--project", "runs/train",
8 | "--exp_name", "slerp30_qnorm_final_bc",
9 | "--weight", str(epoch),
10 | "--processed_data_dir", "processed_data_original_bc/"])
11 |
12 |
--------------------------------------------------------------------------------
/test_benchmark.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import pickle
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from sklearn.preprocessing import LabelEncoder
11 |
12 | from cmib.data.lafan1_dataset import LAFAN1Dataset
13 | from cmib.lafan1 import benchmarks, extract
14 | from cmib.data.utils import process_seq_names
15 | from cmib.model.network import TransformerModel
16 | from cmib.model.preprocess import (lerp_input_repr, replace_constant,
17 | slerp_input_repr, vectorize_representation)
18 | from cmib.model.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets, sk_parents, amass_offsets)
19 |
20 |
21 | def test(opt, device):
22 |
23 | save_dir = Path(os.path.join('runs', 'train', opt.exp_name))
24 | test_dir = Path(os.path.join('runs', 'test', opt.exp_name))
25 | test_dir.mkdir(exist_ok=True, parents=True)
26 | wdir = save_dir / 'weights'
27 | weights = os.listdir(wdir)
28 |
29 | if opt.weight == 'latest':
30 | weights_paths = [wdir / weight for weight in weights]
31 | weight_path = max(weights_paths , key = os.path.getctime)
32 | else:
33 | weight_path = wdir / ('train-' + opt.weight + '.pt')
34 | ckpt = torch.load(weight_path, map_location=device)
35 | print(f"Loaded weight: {weight_path}")
36 |
37 | # Load Skeleton
38 | offset = sk_offsets if opt.dataset == 'LAFAN' else amass_offsets
39 | skeleton_mocap = Skeleton(offsets=offset, parents=sk_parents, device=device)
40 | skeleton_mocap.remove_joints(sk_joints_to_remove)
41 |
42 | # Load LAFAN Dataset
43 | Path(opt.processed_data_dir).mkdir(parents=True, exist_ok=True)
44 | if ckpt['horizon'] < 60:
45 | test_window = 65 # Use default test window for 30-frame benchmark setting.
46 | else:
47 | test_window = ckpt['horizon'] - 1 + 10
48 | lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path, processed_data_dir=opt.processed_data_dir, train=False, device=device, window=test_window, dataset=opt.dataset)
49 |
50 | # Extract stats
51 | if opt.dataset == 'LAFAN':
52 | train_actors = ["subject1", "subject2", "subject3", "subject4"]
53 | elif opt.dataset in ['HumanEva', 'PosePrior']:
54 | train_actors = ["subject1", "subject2"]
55 | elif opt.dataset in ['HUMAN4D']:
56 | train_actors = ["subject1", "subject2", "subject3", "subject4", "subject5", "subject6", "subject7"]
57 | elif opt.dataset in ['MPI_HDM05']:
58 | train_actors = ["subject1", "subject2", "subject3"]
59 | else:
60 | raise ValueError("Invalid Dataset")
61 |
62 | bvh_folder = opt.data_path
63 | stats_file = os.path.join(opt.processed_data_dir, 'train_stats.pkl')
64 |
65 | if not os.path.exists(stats_file):
66 | x_mean, x_std, offsets = extract.get_train_stats(bvh_folder, train_actors)
67 | with open(stats_file, 'wb') as f:
68 | pickle.dump({
69 | 'x_mean': x_mean,
70 | 'x_std': x_std,
71 | 'offsets': offsets,
72 | }, f, protocol=pickle.HIGHEST_PROTOCOL)
73 | else:
74 | print('Reusing stats file: ' + stats_file)
75 | with open(stats_file, 'rb') as f:
76 | stats = pickle.load(f)
77 | x_mean = stats['x_mean']
78 | x_std = stats['x_std']
79 | offsets = stats['offsets']
80 |
81 |
82 | total_data = lafan_dataset.data['global_pos'].shape[0]
83 |
84 | from_idx, target_idx = ckpt['from_idx'], ckpt['target_idx'] # default: 9-38, max: 48
85 | horizon = ckpt['horizon']
86 | print(f"HORIZON: {horizon}")
87 |
88 | test_idx = []
89 | for i in range(total_data):
90 | test_idx.append(i)
91 |
92 | # Extract dimension from processed data
93 | pos_dim = lafan_dataset.num_joints * 3
94 | rot_dim = lafan_dataset.num_joints * 4
95 | repr_dim = pos_dim + rot_dim
96 |
97 | root_pos = torch.Tensor(lafan_dataset.data['root_p'][:, from_idx:target_idx+1]).to(device)
98 | local_q = torch.Tensor(lafan_dataset.data['local_q'][:, from_idx:target_idx+1]).to(device)
99 | local_q_normalized = nn.functional.normalize(local_q, p=2.0, dim=-1)
100 |
101 | # Replace testing inputs
102 | fixed = 0
103 |
104 | global_pos, global_q = skeleton_mocap.forward_kinematics_with_rotation(local_q_normalized, root_pos)
105 | global_pos[:,fixed] += torch.Tensor([0,0,0]).expand(global_pos.size(0),lafan_dataset.num_joints,3)  # zero offset; has no effect
106 |
107 | interpolation = ckpt['interpolation']
108 |
109 | if interpolation == 'constant':
110 | global_pose_vec_gt = vectorize_representation(global_pos, global_q)
111 | global_pose_vec_input = global_pose_vec_gt.clone().detach()
112 | pose_interpolated_input = replace_constant(global_pose_vec_input, fixed)
113 |
114 | elif interpolation == 'slerp':
115 | global_pose_vec_gt = vectorize_representation(global_pos, global_q)
116 | global_pose_vec_input = global_pose_vec_gt.clone().detach()
117 | root_vec = global_pose_vec_input[:,:,:pos_dim]
118 | rot_vec = global_pose_vec_input[:,:,pos_dim:]
119 | root_lerped = lerp_input_repr(root_vec, fixed)
120 | rot_slerped = slerp_input_repr(rot_vec, fixed)
121 | pose_interpolated_input = torch.cat([root_lerped, rot_slerped], dim=2)
122 |
123 | else:
124 | raise ValueError('Invalid interpolation method')
125 |
126 | pose_vectorized_input = pose_interpolated_input.permute(1,0,2)
127 |
128 | src_mask = torch.zeros((horizon, horizon), device=device).type(torch.bool)
129 | src_mask = src_mask.to(device)
130 |
131 | le = LabelEncoder()
132 | le.classes_ = np.load(os.path.join(save_dir, 'le_classes_.npy'))
133 | num_labels = len(le.classes_)
134 |
135 | model = TransformerModel(seq_len=ckpt['horizon'], d_model=ckpt['d_model'], nhead=ckpt['nhead'], d_hid=ckpt['d_hid'], nlayers=ckpt['nlayers'], dropout=0.0, out_dim=repr_dim, num_labels=num_labels)
136 | model.load_state_dict(ckpt['transformer_encoder_state_dict'])
137 | model.eval()
138 |
139 | l2p = []
140 | l2q = []
141 |
142 | pred_rot_npss = []
143 | for i in range(len(test_idx)):
144 | print(f"Processing ID: {test_idx[i]}")
145 |
146 | seq_label = lafan_dataset.data['seq_names'][i][:-1]
147 |
148 | if opt.dataset == 'LAFAN':
149 | seq_label = [x[:-1] for x in lafan_dataset.data['seq_names']][i]
150 | else:
151 | seq_label = process_seq_names(lafan_dataset.data['seq_names'], dataset=opt.dataset)[i]
152 |
153 | match_class = np.where(le.classes_ == seq_label)[0]
154 | class_id = 0 if len(match_class) == 0 else match_class[0]
155 |
156 | conditioning_label = torch.Tensor([[class_id]]).type(torch.int64).to(device)
157 | cond_output, cond_gt = model(pose_vectorized_input[:, test_idx[i]:test_idx[i]+1, :], src_mask, conditioning_label)
158 | print(f"Condition: {le.classes_[class_id]}")
159 |
160 | output = cond_output
161 |
162 | pred_global_pos = output[1:,:,:pos_dim].permute(1,0,2).reshape(1,horizon-1,22,3)
163 | global_pos_unit_vec = skeleton_mocap.convert_to_unit_offset_mat(pred_global_pos)
164 | pred_global_pos = skeleton_mocap.convert_to_global_pos(global_pos_unit_vec).detach().numpy()
165 |
166 | # Replace start/end with gt
167 | gt_global_pos = lafan_dataset.data['global_pos'][test_idx[i]:test_idx[i]+1, from_idx:target_idx+1].reshape(1, -1, lafan_dataset.num_joints, 3)
168 | pred_global_pos[0,0] = gt_global_pos[0,0]
169 | pred_global_pos[0,-1] = gt_global_pos[0,-1]
170 |
171 | pred_global_rot = output[1:,:,pos_dim:].permute(1,0,2).reshape(1,horizon-1,22,4)
172 | pred_global_rot_normalized = nn.functional.normalize(pred_global_rot, p=2.0, dim=3)
173 | gt_global_rot = global_q[test_idx[i]:test_idx[i]+1]
174 | pred_global_rot_normalized[0,0] = gt_global_rot[0,0]
175 | pred_global_rot_normalized[0,-1] = gt_global_rot[0,-1]
176 | pred_rot_npss.append(pred_global_rot_normalized)
177 |
178 | # Normalize for L2P
179 | normalized_gt_pos = torch.Tensor((lafan_dataset.data['global_pos'][test_idx[i]:test_idx[i]+1, from_idx:target_idx+1].reshape(1, -1, lafan_dataset.num_joints * 3).transpose(0,2,1) - x_mean) / x_std)
180 | normalized_pred_pos = torch.Tensor((pred_global_pos.reshape(1, -1, lafan_dataset.num_joints * 3).transpose(0,2,1) - x_mean) / x_std)
181 |
182 | l2p.append(torch.mean(torch.norm(normalized_pred_pos[0] - normalized_gt_pos[0], dim=(0))).item())
183 | l2q.append(torch.mean(torch.norm(pred_global_rot_normalized[0] - global_q[test_idx[i]], dim=(1,2))).item())
184 | print(f"ID {test_idx[i]}: test completed.")
185 |
186 | l2p_mean = np.mean(l2p)
187 | l2q_mean = np.mean(l2q)
188 |
189 | # Drop end nodes for fair comparison
190 | pred_quaternions = torch.cat(pred_rot_npss, dim=0)
191 | npss_gt = global_q[:,:,skeleton_mocap.has_children()].reshape(global_q.shape[0],global_q.shape[1], -1)
192 | npss_pred = pred_quaternions[:,:,skeleton_mocap.has_children()].reshape(pred_quaternions.shape[0],pred_quaternions.shape[1], -1)
193 | npss = benchmarks.npss(npss_gt, npss_pred).item()
194 |
195 | print(f"TOTAL TEST DATA: {len(l2p)}")
196 | print(f"L2P: {l2p_mean}")
197 | print(f"L2Q: {l2q_mean}")
198 | print(f"NPSS: {npss}")
199 |
200 | benchmark_out = {
201 | 'total_data': len(l2p),
202 | 'L2P': l2p_mean,
203 | 'L2Q': l2q_mean,
204 | 'NPSS': npss
205 | }
206 |
207 | with open(os.path.join(test_dir, f'benchmark_out-{opt.weight}.json'), 'w') as f:
208 | json.dump(benchmark_out, f)
209 |
210 |
211 | def parse_opt():
212 | parser = argparse.ArgumentParser()
213 | parser.add_argument('--project', default='runs/train', help='project/name')
214 | parser.add_argument('--exp_name', default='HUMAN4D_80', help='experiment name')
215 | parser.add_argument('--weight', default='latest')
216 | parser.add_argument('--data_path', type=str, default='AMASS/PosePrior', help='BVH dataset path')
217 | parser.add_argument('--dataset', type=str, default='HUMAN4D', help='Dataset name')
218 | parser.add_argument('--processed_data_dir', type=str, default='processed_data_human4d_80/', help='path to save pickled processed data')
219 | opt = parser.parse_args()
220 | return opt
221 |
222 | if __name__ == "__main__":
223 | opt = parse_opt()
224 | device = torch.device("cpu")
225 | test(opt, device)
226 |
--------------------------------------------------------------------------------
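For reference, the numbers printed by `test_benchmark.py` above correspond to the following per-window quantities, averaged over all test windows (this restates the computation in the loop, not an external definition):

$$
\mathrm{L2P} \;=\; \frac{1}{T}\sum_{t=1}^{T}\Big\lVert \tfrac{\hat{p}_t - p_t}{\sigma} \Big\rVert_2,
\qquad
\mathrm{L2Q} \;=\; \frac{1}{T}\sum_{t=1}^{T}\big\lVert \hat{q}_t - q_t \big\rVert_2,
$$

where \(p_t, \hat{p}_t\) stack all global joint positions of frame \(t\), \(q_t, \hat{q}_t\) stack all global unit quaternions, \(\sigma\) is the per-coordinate standard deviation loaded from `train_stats.pkl`, and the division is element-wise (the mean cancels in the difference). NPSS is then computed by `benchmarks.npss` on the ground-truth and predicted rotations after dropping leaf joints (those without children).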
/test_condition_comp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import pickle
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from sklearn.preprocessing import LabelEncoder
11 |
12 | from cmib.data.lafan1_dataset import LAFAN1Dataset
13 | from cmib.data.utils import drop_end_quat
14 | from cmib.lafan1 import extract
15 | from cmib.model.network import TransformerModel
16 | from cmib.model.preprocess import (lerp_input_repr, replace_constant,
17 | slerp_input_repr, vectorize_representation)
18 | from cmib.model.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets,
19 | sk_parents)
20 |
21 |
22 | def test(opt, device):
23 |
24 | save_dir = Path(os.path.join('runs', 'train', opt.exp_name))
25 | gt_motion = opt.data_path.split('/')[-1].split('_')[-1].lower()
26 | Path(os.path.join('cond_bch', gt_motion)).mkdir(parents=True, exist_ok=True)
27 | wdir = save_dir / 'weights'
28 | weights = os.listdir(wdir)
29 | weights_paths = [wdir / weight for weight in weights]
30 | latest_weight = max(weights_paths, key=os.path.getctime)
31 | ckpt = torch.load(latest_weight, map_location=device)
32 | print(f"Loaded weight: {latest_weight}")
33 |
34 | # Load Skeleton
35 | skeleton_mocap = Skeleton(offsets=sk_offsets, parents=sk_parents, device=device)
36 | skeleton_mocap.remove_joints(sk_joints_to_remove)
37 |
38 | # Load LAFAN Dataset
39 | processed_data_dir = 'condition_test_' + gt_motion
40 | Path(processed_data_dir).mkdir(parents=True, exist_ok=True)
41 | test_window = ckpt['horizon'] - 1 + 10
42 | print(f"Test Window: {test_window}")
43 | lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path, processed_data_dir=processed_data_dir, train=False, device=device, window=test_window)
44 |
45 | # Extract stats
46 | train_actors = ['subject1', 'subject2', 'subject3', 'subject4']
47 | bvh_folder = os.path.join('ubisoft-laforge-animation-dataset', 'output', 'BVH')
48 | stats_file = os.path.join(opt.train_stat, 'train_stats.pkl')
49 |
50 | if not os.path.exists(stats_file):
51 | x_mean, x_std, offsets = extract.get_train_stats(bvh_folder, train_actors)
52 | with open(stats_file, 'wb') as f:
53 | pickle.dump({
54 | 'x_mean': x_mean,
55 | 'x_std': x_std,
56 | 'offsets': offsets,
57 | }, f, protocol=pickle.HIGHEST_PROTOCOL)
58 | else:
59 | print('Reusing stats file: ' + stats_file)
60 | with open(stats_file, 'rb') as f:
61 | stats = pickle.load(f)
62 | x_mean = stats['x_mean']
63 | x_std = stats['x_std']
64 | offsets = stats['offsets']
65 |
66 |
67 | total_data = lafan_dataset.data['global_pos'].shape[0]
68 |
69 | # Replace the in-betweening frames with interpolated (or constant) input
70 | from_idx, target_idx = ckpt['from_idx'], ckpt['target_idx']
71 | horizon = ckpt['horizon']
72 | print(f"HORIZON: {horizon}")
73 |
74 | test_idx = []
75 | for i in range(total_data):
76 | test_idx.append(i)
77 |
78 | # Extract dimension from processed data
79 | pos_dim = lafan_dataset.num_joints * 3
80 | rot_dim = lafan_dataset.num_joints * 4
81 | repr_dim = pos_dim + rot_dim
82 |
83 | root_pos = torch.Tensor(lafan_dataset.data['root_p'][:, from_idx:target_idx+1]).to(device)
84 | local_q = torch.Tensor(lafan_dataset.data['local_q'][:, from_idx:target_idx+1]).to(device)
85 | local_q_normalized = nn.functional.normalize(local_q, p=2.0, dim=-1)
86 | global_pos, global_q = skeleton_mocap.forward_kinematics_with_rotation(local_q_normalized, root_pos)
87 |
88 | fixed = 0
89 | interpolation = ckpt['interpolation']
90 |
91 | if interpolation == 'constant':
92 | global_pose_vec_gt = vectorize_representation(global_pos, global_q)
93 | global_pose_vec_input = global_pose_vec_gt.clone().detach()
94 | pose_interpolated_input = replace_constant(global_pose_vec_input, fixed)
95 |
96 | elif interpolation == 'slerp':
97 | global_pose_vec_gt = vectorize_representation(global_pos, global_q)
98 | global_pose_vec_input = global_pose_vec_gt.clone().detach()
99 | root_vec = global_pose_vec_input[:,:,:pos_dim]
100 | rot_vec = global_pose_vec_input[:,:,pos_dim:]
101 | root_lerped = lerp_input_repr(root_vec, fixed)
102 | rot_slerped = slerp_input_repr(rot_vec, fixed)
103 | pose_interpolated_input = torch.cat([root_lerped, rot_slerped], dim=2)
104 |
105 | else:
106 | raise ValueError('Invalid interpolation method')
107 |
108 | pose_vectorized_input = pose_interpolated_input.permute(1,0,2)
109 |
110 | src_mask = torch.zeros((horizon, horizon), device=device).type(torch.bool)
111 | src_mask = src_mask.to(device)
112 |
113 | seq_categories = [x[:-1] for x in lafan_dataset.data['seq_names']]
114 |
115 | l1_loss = nn.L1Loss()
116 |
117 | le = LabelEncoder()
118 | le.classes_ = np.load(os.path.join(save_dir, 'le_classes_.npy'))
119 |
120 | model = TransformerModel(seq_len=ckpt['horizon'], d_model=ckpt['d_model'], nhead=ckpt['nhead'], d_hid=ckpt['d_hid'], nlayers=ckpt['nlayers'], dropout=0.0, out_dim=repr_dim)
121 | model.load_state_dict(ckpt['transformer_encoder_state_dict'])
122 | model.eval()
123 |
124 | testing_motions = ['walk', 'run', 'dance', 'jumps', 'fight']
125 |
126 | summary = {}
127 | for cond_motion in testing_motions:
128 | l2p = []
129 | l2q = []
130 |
131 | pred_rot_npss = []
132 |
133 | print(f"GT: {gt_motion}")
134 | print(f"Condition: {cond_motion}")
135 |
136 | bch_out = {}
137 |
138 | bch_out['cond_motion'] = cond_motion
139 | bch_out['gt_motion'] = gt_motion
140 |
141 | motion_index = np.where(le.classes_ == cond_motion)[0][0]
142 |
143 | conditioning_label = torch.Tensor([[motion_index] * total_data]).type(torch.int64).to(device).permute(1,0)
144 | cond_output, _ = model(pose_vectorized_input[:, :, :], src_mask, conditioning_label)
145 |
146 | output = cond_output
147 |
148 | pred_global_pos = output[1:,:,:pos_dim].permute(1,0,2).reshape(total_data,horizon-1,22,3)
149 | global_pos_unit_vec = skeleton_mocap.convert_to_unit_offset_mat(pred_global_pos)
150 | pred_global_pos = skeleton_mocap.convert_to_global_pos(global_pos_unit_vec).detach().numpy()
151 |
152 |     gt_global_pos = lafan_dataset.data['global_pos'][:, from_idx:target_idx+1].reshape(total_data, -1, lafan_dataset.num_joints, 3)
153 |     pred_global_pos[:, 0] = gt_global_pos[:, 0]  # replace the start keyframe with ground truth for every sequence
154 |     pred_global_pos[:, -1] = gt_global_pos[:, -1]  # replace the end keyframe with ground truth for every sequence
155 |
156 | pred_global_rot = output[1:,:,pos_dim:].permute(1,0,2).reshape(total_data,horizon-1,22,4)
157 | pred_global_rot_normalized = nn.functional.normalize(pred_global_rot, p=2.0, dim=3)
158 |     gt_global_rot = global_q
159 |     pred_global_rot_normalized[:, 0] = gt_global_rot[:, 0]  # ground-truth rotations at the start keyframe
160 |     pred_global_rot_normalized[:, -1] = gt_global_rot[:, -1]  # ground-truth rotations at the end keyframe
161 | pred_rot_npss.append(pred_global_rot_normalized)
162 |
163 | # Normalize for L2P
164 | normalized_gt_pos = torch.Tensor((lafan_dataset.data['global_pos'][:, from_idx:target_idx+1].reshape(total_data, -1, lafan_dataset.num_joints * 3).transpose(0,2,1) - x_mean) / x_std)
165 | normalized_pred_pos = torch.Tensor((pred_global_pos.reshape(total_data, -1, lafan_dataset.num_joints * 3).transpose(0,2,1) - x_mean) / x_std)
166 |
167 | l2p.append(torch.mean(torch.norm(normalized_pred_pos - normalized_gt_pos, dim=(1))).item())
168 | l2q.append(torch.mean(torch.norm(pred_global_rot_normalized - global_q, dim=(2,3))).item())
169 |
170 | l2p_mean = np.mean(l2p)
171 | l2q_mean = np.mean(l2q)
172 |
173 | print(f"TOTAL TEST DATA: {total_data}")
174 | print(f"L2P: {l2p_mean}")
175 | print(f"L2Q: {l2q_mean}")
176 | print("=================")
177 |
178 | bch_out['L2P'] = l2p_mean
179 | bch_out['L2Q'] = l2q_mean
180 | bch_out['TotalData'] = total_data
181 |
182 | summary[cond_motion] = bch_out
183 |     with open(os.path.join('cond_bch', gt_motion, f'{gt_motion}.txt'), 'w') as f:
184 | json.dump(summary, f)
185 |
186 | def parse_opt():
187 | parser = argparse.ArgumentParser()
188 | parser.add_argument('--project', default='runs/train', help='project/name')
189 | parser.add_argument('--exp_name', default='train_60', help='experiment name')
190 | parser.add_argument('--data_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH_FIGHT', help='BVH dataset path')
191 | parser.add_argument('--skeleton_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH/walk1_subject1.bvh', help='path to reference skeleton')
192 | parser.add_argument('--save_path', type=str, default='runs/test', help='path to save model')
193 | parser.add_argument('--train_stat', default='processed_data_80', help='directory containing train_stats.pkl')
194 | opt = parser.parse_args()
195 | return opt
196 |
197 | if __name__ == "__main__":
198 | opt = parse_opt()
199 | device = torch.device("cpu")
200 | test(opt, device)
201 |
--------------------------------------------------------------------------------
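Editor's note: in both test scripts the condition is a single integer class index per sequence, taken from the LabelEncoder classes that trainer.py saves as le_classes_.npy, and fed to the model as an extra token alongside the masked pose sequence. A minimal sketch of that label round-trip; the path and batch size below are illustrative assumptions.

import numpy as np
import torch
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
le.classes_ = np.load('runs/train/exp/le_classes_.npy')   # illustrative path to the saved classes

cond_motion = 'dance'                                      # condition imposed at test time
motion_index = int(np.where(le.classes_ == cond_motion)[0][0])

batch_size = 4                                             # illustrative
# one label per sample, shaped (batch, 1) as in the scripts above
conditioning_label = torch.full((batch_size, 1), motion_index, dtype=torch.int64)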
/train.sh:
--------------------------------------------------------------------------------
1 | python trainer.py \
2 | --data_path="AMASS/MPI_HDM05" \
3 | --dataset="MPI_HDM05" \
4 | --processed_data_dir="processed_data_mpi05_60/" \
5 | --window=70 \
6 | --batch_size=32 \
7 | --epochs=5000 \
8 | --device=1 \
9 | --exp_name="MPI_HDM05_60" \
10 | --save_interval=20 \
11 | --learning_rate=0.0001 \
12 | --loss_cond_weight=1.5 \
13 | --loss_pos_weight=0.05 \
14 | --loss_rot_weight=2.0 \
15 | --from_idx=9 \
16 | --target_idx=68 \
17 | --interpolation='slerp'
18 |
--------------------------------------------------------------------------------
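Editor's note on the flags above: as trainer.py below makes explicit, the transformer horizon is derived from --from_idx and --target_idx plus one extra token for the condition, and the sliced range must fit inside --window. A quick sanity check using the values from train.sh:

# Values mirror train.sh: window=70, from_idx=9, target_idx=68
window, from_idx, target_idx = 70, 9, 68
frames = target_idx - from_idx + 1      # 60 motion frames fed to the model
horizon = frames + 1                    # +1 conditioning token -> 61
assert target_idx + 1 <= window         # the slice from_idx:target_idx+1 must fit in the window
print(frames, horizon)                  # 60 61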
/trainer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import wandb
9 | import yaml
10 | from sklearn.preprocessing import LabelEncoder
11 | from torch.optim import AdamW
12 | from torch.utils.data import DataLoader, TensorDataset
13 | from tqdm import tqdm
14 |
15 | from cmib.data.lafan1_dataset import LAFAN1Dataset
16 | from cmib.data.utils import flip_bvh, increment_path, process_seq_names
17 | from cmib.model.network import TransformerModel
18 | from cmib.model.preprocess import (lerp_input_repr, replace_constant,
19 | slerp_input_repr, vectorize_representation)
20 | from cmib.model.skeleton import (Skeleton, sk_joints_to_remove, sk_offsets, sk_parents, amass_offsets)
21 |
22 |
23 | def train(opt, device):
24 |
25 | print(f"[DATASET: {opt.dataset}]")
26 | # Prepare Directories
27 | save_dir = Path(opt.save_dir)
28 | wdir = save_dir / 'weights'
29 | wdir.mkdir(parents=True, exist_ok=True)
30 |
31 | # Save run settings
32 | with open(save_dir / 'opt.yaml', 'w') as f:
33 | yaml.safe_dump(vars(opt), f, sort_keys=True)
34 |
35 | epochs = opt.epochs
36 | save_interval = opt.save_interval
37 |
38 | # Loggers
39 | wandb.init(config=opt, project=opt.wandb_pj_name, entity=opt.entity, name=opt.exp_name, dir=opt.save_dir)
40 |
41 | # Load Skeleton
42 | offset = sk_offsets if opt.dataset == 'LAFAN' else amass_offsets
43 | skeleton_mocap = Skeleton(offsets=offset, parents=sk_parents, device=device)
44 | skeleton_mocap.remove_joints(sk_joints_to_remove)
45 |
46 | # Flip, load, and preprocess the data using the LAFAN1 utilities
47 | if opt.dataset == 'LAFAN':
48 | flip_bvh(opt.data_path, skip='subject5')
49 |
50 | # Load LAFAN Dataset
51 | Path(opt.processed_data_dir).mkdir(parents=True, exist_ok=True)
52 | lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path, processed_data_dir=opt.processed_data_dir, train=True, device=device, window=opt.window, dataset=opt.dataset)
53 |
54 | from_idx, target_idx = opt.from_idx, opt.target_idx
55 | horizon = target_idx - from_idx + 1
56 | print(f"Horizon: {horizon}")
57 | horizon += 1 # Add one for conditioning token
58 | print(f"Horizon with Conditioning: {horizon}")
59 | print(f"Interpolation Mode: {opt.interpolation}")
60 |
61 | root_pos = torch.Tensor(lafan_dataset.data['root_p'][:, from_idx:target_idx+1]).to(device)
62 | local_q = torch.Tensor(lafan_dataset.data['local_q'][:, from_idx:target_idx+1]).to(device)
63 | local_q_normalized = nn.functional.normalize(local_q, p=2.0, dim=-1)
64 |
65 | global_pos, global_q = skeleton_mocap.forward_kinematics_with_rotation(local_q_normalized, root_pos)
66 |
67 | global_pose_vec_gt = vectorize_representation(global_pos, global_q)
68 | global_pose_vec_input = global_pose_vec_gt.clone().detach()
69 |
70 | if opt.dataset == 'LAFAN':
71 | seq_categories = [x[:-1] for x in lafan_dataset.data['seq_names']]
72 | else:
73 | seq_categories = process_seq_names(lafan_dataset.data['seq_names'], dataset=opt.dataset)
74 |
75 | le = LabelEncoder()
76 | le_np = le.fit_transform(seq_categories)
77 | seq_labels = torch.Tensor(le_np).type(torch.int64).unsqueeze(1).to(device)
78 | np.save(f'{save_dir}/le_classes_.npy', le.classes_)
79 | num_labels = len(seq_labels.squeeze().unique())
80 |
81 | tensor_dataset = TensorDataset(global_pose_vec_input, global_pose_vec_gt, seq_labels)
82 | lafan_data_loader = DataLoader(tensor_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0)
83 |
84 | pos_dim = lafan_dataset.num_joints * 3
85 | rot_dim = lafan_dataset.num_joints * 4
86 | repr_dim = pos_dim + rot_dim
87 | nhead = 7 # repr_dim = 154 = 22 joints * (3 pos + 4 quat); d_model must be divisible by nhead
88 |
89 | transformer_encoder = TransformerModel(seq_len=horizon, d_model=repr_dim, nhead=nhead, d_hid=2048, nlayers=8, dropout=0.05, out_dim=repr_dim, num_labels=num_labels)
90 | transformer_encoder.to(device)
91 |
92 | l1_loss = nn.L1Loss()
93 | optim = AdamW(params=transformer_encoder.parameters(), lr=opt.learning_rate)
94 | scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=100, gamma=0.9)
95 |
96 | for epoch in range(1, epochs + 1):
97 |
98 | pbar = tqdm(lafan_data_loader, position=1, desc="Batch")
99 |
100 | recon_cond_loss = []
101 | recon_pos_loss = []
102 | recon_rot_loss = []
103 | total_loss_list = []
104 |
105 | for minibatch_pose_input, minibatch_pose_gt, seq_label in pbar:
106 |
107 |             for _ in range(5):  # 5 optimization steps per minibatch, each with a freshly sampled keyframe mask
108 | mask_start_frame = np.random.randint(0, horizon-1)
109 |
110 | if opt.interpolation == 'constant':
111 | pose_interpolated_input = replace_constant(minibatch_pose_input, mask_start_frame)
112 | elif opt.interpolation == 'slerp':
113 | root_vec = minibatch_pose_input[:,:,:pos_dim]
114 | rot_vec = minibatch_pose_input[:,:,pos_dim:]
115 | root_lerped = lerp_input_repr(root_vec, mask_start_frame)
116 | rot_slerped = slerp_input_repr(rot_vec, mask_start_frame)
117 | pose_interpolated_input = torch.cat([root_lerped, rot_slerped], dim=2)
118 | else:
119 | raise ValueError('Invalid interpolation method')
120 |
121 | pose_interpolated_input = pose_interpolated_input.permute(1,0,2)
122 |
123 | src_mask = torch.zeros((horizon, horizon), device=device).type(torch.bool)
124 | src_mask = src_mask.to(device)
125 |
126 | output, cond_gt = transformer_encoder(pose_interpolated_input, src_mask, seq_label)
127 |
128 |                 cond_pred = output[0:1, :, :]  # the first output token corresponds to the conditioning embedding
129 | cond_loss = l1_loss(cond_pred, cond_gt)
130 | recon_cond_loss.append(opt.loss_cond_weight * cond_loss)
131 |
132 | pos_pred = output[1:,:,:pos_dim].permute(1,0,2)
133 | pos_gt = minibatch_pose_gt[:,:,:pos_dim]
134 | pos_loss = l1_loss(pos_pred, pos_gt)
135 | recon_pos_loss.append(opt.loss_pos_weight * pos_loss)
136 |
137 | rot_pred = output[1:,:,pos_dim:].permute(1,0,2)
138 | rot_pred_reshaped = rot_pred.reshape(rot_pred.shape[0], rot_pred.shape[1], lafan_dataset.num_joints, 4)
139 | rot_pred_normalized = nn.functional.normalize(rot_pred_reshaped, p=2.0, dim=3)
140 |
141 | rot_gt = minibatch_pose_gt[:,:,pos_dim:]
142 | rot_gt_reshaped = rot_gt.reshape(rot_gt.shape[0], rot_gt.shape[1], lafan_dataset.num_joints, 4)
143 | rot_loss = l1_loss(rot_pred_reshaped, rot_gt_reshaped)
144 | recon_rot_loss.append(opt.loss_rot_weight * rot_loss)
145 |
146 | total_g_loss = opt.loss_pos_weight * pos_loss + \
147 | opt.loss_rot_weight * rot_loss + \
148 | opt.loss_cond_weight * cond_loss
149 |
150 | total_loss_list.append(total_g_loss)
151 |
152 | optim.zero_grad()
153 | total_g_loss.backward()
154 | torch.nn.utils.clip_grad_norm_(transformer_encoder.parameters(), 1.0, error_if_nonfinite=False)
155 | optim.step()
156 |
157 | scheduler.step()
158 |
159 | # Log
160 | log_dict = {
161 | "Train/Loss/Condition Loss": torch.stack(recon_cond_loss).mean().item(),
162 | "Train/Loss/Position Loss": torch.stack(recon_pos_loss).mean().item(),
163 |         "Train/Loss/Rotation Loss": torch.stack(recon_rot_loss).mean().item(),
164 | "Train/Loss/Total Loss": torch.stack(total_loss_list).mean().item(),
165 | }
166 | wandb.log(log_dict)
167 |
168 | # Save model
169 | if (epoch % save_interval) == 0:
170 | ckpt = {'epoch': epoch,
171 | 'transformer_encoder_state_dict': transformer_encoder.state_dict(),
172 | 'horizon': transformer_encoder.seq_len,
173 | 'from_idx': opt.from_idx,
174 | 'target_idx': opt.target_idx,
175 | 'd_model': transformer_encoder.d_model,
176 | 'nhead': transformer_encoder.nhead,
177 | 'd_hid': transformer_encoder.d_hid,
178 | 'nlayers': transformer_encoder.nlayers,
179 | 'optimizer_state_dict': optim.state_dict(),
180 | 'interpolation': opt.interpolation,
181 | 'loss': total_g_loss}
182 | torch.save(ckpt, os.path.join(wdir, f'train-{epoch}.pt'))
183 | print(f"[MODEL SAVED at {epoch} Epoch]")
184 |
185 | wandb.run.finish()
186 | torch.cuda.empty_cache()
187 |
188 | def parse_opt():
189 | parser = argparse.ArgumentParser()
190 | parser.add_argument('--project', default='runs/train', help='project/name')
191 | parser.add_argument('--data_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH', help='BVH dataset path')
192 | parser.add_argument('--dataset', type=str, default='LAFAN', help='Dataset name')
193 | parser.add_argument('--processed_data_dir', type=str, default='processed_data_80/', help='path to save pickled processed data')
194 | parser.add_argument('--window', type=int, default=90, help='dataset window size (frames per training clip)')
195 | parser.add_argument('--wandb_pj_name', type=str, default='cmib_train', help='project name')
196 | parser.add_argument('--batch_size', type=int, default=32, help='batch size')
197 | parser.add_argument('--epochs', type=int, default=3000)
198 | parser.add_argument('--device', default='0', help='cuda device')
199 | parser.add_argument('--entity', default=None, help='W&B entity')
200 | parser.add_argument('--exp_name', default='exp', help='save to project/name')
201 | parser.add_argument('--save_interval', type=int, default=50, help='Save a model checkpoint every "save_interval" epochs')
202 | parser.add_argument('--learning_rate', type=float, default=0.0001, help='learning rate')
203 | parser.add_argument('--loss_cond_weight', type=float, default=1.5, help='loss_cond_weight')
204 | parser.add_argument('--loss_pos_weight', type=float, default=0.05, help='loss_pos_weight')
205 | parser.add_argument('--loss_rot_weight', type=float, default=2.0, help='loss_rot_weight')
206 | parser.add_argument('--from_idx', type=int, default=9, help='from idx')
207 | parser.add_argument('--target_idx', type=int, default=88, help='target idx')
208 | parser.add_argument('--interpolation', type=str, default='slerp', help='interpolation')
209 | opt = parser.parse_args()
210 | return opt
211 |
212 | if __name__ == "__main__":
213 | opt = parse_opt()
214 | opt.save_dir = str(increment_path(Path(opt.project) / opt.exp_name))
215 | opt.exp_name = opt.save_dir.split('/')[-1]
216 | device = torch.device(f"cuda:{opt.device}" if torch.cuda.is_available() else "cpu")
217 | train(opt, device)
218 |
--------------------------------------------------------------------------------
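Editor's note on the training scheme in trainer.py: for every minibatch a keyframe position is sampled at random and the in-between frames of the model input are replaced, either by a constant or by a lerp (positions) / slerp (rotations) between the surrounding keyframes, while the loss is taken against the untouched ground truth. The actual interpolation lives in cmib.model.preprocess (replace_constant, lerp_input_repr, slerp_input_repr); the sketch below only illustrates the linear case and is an assumption about their behaviour, not a drop-in replacement.

import torch

def lerp_between_keyframes(seq: torch.Tensor, start: int, end: int) -> torch.Tensor:
    """Keep the keyframes at `start` and `end`, linearly interpolate every frame in between.
    seq: (batch, frames, features). Illustrative only; the repository's
    lerp_input_repr / slerp_input_repr may handle masking details differently."""
    out = seq.clone()
    a, b = seq[:, start], seq[:, end]
    for t in range(start + 1, end):
        w = (t - start) / (end - start)
        out[:, t] = (1.0 - w) * a + w * b
    return out

# usage sketch: interpolate everything between the first and the last frame
dummy = torch.randn(2, 61, 154)
interpolated = lerp_between_keyframes(dummy, 0, dummy.shape[1] - 1)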